From f6a612b1808faff28dceef9b2a2d0613b0e14fbf Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Wed, 15 Apr 2026 11:17:31 -0700 Subject: [PATCH 01/22] chore: add backend client and inventory skeletons Signed-off-by: Jingxiang Zhang --- internal/agentstate/state.go | 34 +++ internal/attestationloop/manager.go | 112 ++++++++++ internal/attestationloop/mapper/backend.go | 54 +++++ internal/attestationloop/sink/backend.go | 46 ++++ internal/attestationloop/source/source.go | 28 +++ internal/attestationloop/types.go | 71 +++++++ internal/backendclient/client.go | 233 +++++++++++++++++++++ internal/backendclient/client_test.go | 150 +++++++++++++ internal/backendclient/errors.go | 37 ++++ internal/backendclient/types.go | 137 ++++++++++++ internal/inventory/manager.go | 69 ++++++ internal/inventory/mapper/backend.go | 54 +++++ internal/inventory/sink/backend.go | 46 ++++ internal/inventory/source/source.go | 45 ++++ internal/inventory/types.go | 126 +++++++++++ internal/store/memory.go | 103 +++++++++ 16 files changed, 1345 insertions(+) create mode 100644 internal/agentstate/state.go create mode 100644 internal/attestationloop/manager.go create mode 100644 internal/attestationloop/mapper/backend.go create mode 100644 internal/attestationloop/sink/backend.go create mode 100644 internal/attestationloop/source/source.go create mode 100644 internal/attestationloop/types.go create mode 100644 internal/backendclient/client.go create mode 100644 internal/backendclient/client_test.go create mode 100644 internal/backendclient/errors.go create mode 100644 internal/backendclient/types.go create mode 100644 internal/inventory/manager.go create mode 100644 internal/inventory/mapper/backend.go create mode 100644 internal/inventory/sink/backend.go create mode 100644 internal/inventory/source/source.go create mode 100644 internal/inventory/types.go create mode 100644 internal/store/memory.go diff --git a/internal/agentstate/state.go b/internal/agentstate/state.go new file mode 100644 index 00000000..6149e0a3 --- /dev/null +++ b/internal/agentstate/state.go @@ -0,0 +1,34 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package agentstate centralizes access to local persisted agent state. +package agentstate + +import "context" + +// State provides local persisted metadata/state access for backend workflows. +type State interface { + GetBackendBaseURL(ctx context.Context) (string, error) + SetBackendBaseURL(ctx context.Context, value string) error + + GetJWT(ctx context.Context) (string, error) + SetJWT(ctx context.Context, value string) error + + GetSAK(ctx context.Context) (string, error) + SetSAK(ctx context.Context, value string) error + + GetNodeID(ctx context.Context) (string, error) + SetNodeID(ctx context.Context, value string) error +} diff --git a/internal/attestationloop/manager.go b/internal/attestationloop/manager.go new file mode 100644 index 00000000..2f1c3701 --- /dev/null +++ b/internal/attestationloop/manager.go @@ -0,0 +1,112 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attestationloop + +import ( + "context" + "fmt" + "time" +) + +// JWTProvider retrieves the current backend JWT. +type JWTProvider interface { + GetJWT(ctx context.Context) (string, error) + SetJWT(ctx context.Context, value string) error +} + +// Manager coordinates periodic attestation collection into a store. +type Manager interface { + Run(ctx context.Context) error + CollectOnce(ctx context.Context) (*Result, error) +} + +type manager struct { + nodeIDProvider func(context.Context) (string, error) + jwtProvider JWTProvider + nonceProvider NonceProvider + collector EvidenceCollector + store StateStore + interval time.Duration +} + +// NewManager creates an attestation loop manager skeleton. +func NewManager( + nodeIDProvider func(context.Context) (string, error), + jwtProvider JWTProvider, + nonceProvider NonceProvider, + collector EvidenceCollector, + store StateStore, + interval time.Duration, +) Manager { + return &manager{ + nodeIDProvider: nodeIDProvider, + jwtProvider: jwtProvider, + nonceProvider: nonceProvider, + collector: collector, + store: store, + interval: interval, + } +} + +func (m *manager) Run(ctx context.Context) error { + if _, err := m.CollectOnce(ctx); err != nil { + return err + } + return fmt.Errorf("attestation loop run loop not implemented") +} + +func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { + if m.nodeIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil { + return nil, fmt.Errorf("attestation loop dependencies are incomplete") + } + + nodeID, err := m.nodeIDProvider(ctx) + if err != nil { + return nil, err + } + jwt, err := m.jwtProvider.GetJWT(ctx) + if err != nil { + return nil, err + } + nonce, refreshTS, refreshedJWT, err := m.nonceProvider.GetNonce(ctx, nodeID, jwt) + if err != nil { + return nil, err + } + if refreshedJWT != "" && refreshedJWT != jwt { + if err := m.jwtProvider.SetJWT(ctx, refreshedJWT); err != nil { + return nil, err + } + } + sdkResp, err := m.collector.Collect(ctx, nonce) + if err != nil { + return nil, err + } + result := &Result{ + CollectedAt: time.Now().UTC(), + NodeID: nodeID, + NonceRefreshTimestamp: refreshTS, + Success: true, + } + if sdkResp != nil { + result.SDKResponse = *sdkResp + } + if m.store != nil { + if err := m.store.PutAttestation(ctx, *result); err != nil { + return nil, err + } + } + return result, nil +} diff --git a/internal/attestationloop/mapper/backend.go b/internal/attestationloop/mapper/backend.go new file mode 100644 index 00000000..cbaae521 --- /dev/null +++ b/internal/attestationloop/mapper/backend.go @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mapper contains attestation loop payload mappers. +package mapper + +import ( + "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +// ToAttestationRequest maps an attestation result to the backend attestation contract. +func ToAttestationRequest(r attestationloop.Result) backendclient.AttestationRequest { + req := backendclient.AttestationRequest{ + AttestationData: backendclient.AttestationData{ + NonceRefreshTimestamp: r.NonceRefreshTimestamp, + Success: r.Success, + ErrorMessage: r.ErrorMessage, + SDKResponse: backendclient.AttestationSDKResponse{ + ResultCode: r.SDKResponse.ResultCode, + ResultMessage: r.SDKResponse.ResultMessage, + }, + }, + } + + if len(r.SDKResponse.Evidences) > 0 { + req.AttestationData.SDKResponse.Evidences = make([]backendclient.EvidenceItem, 0, len(r.SDKResponse.Evidences)) + for _, ev := range r.SDKResponse.Evidences { + req.AttestationData.SDKResponse.Evidences = append(req.AttestationData.SDKResponse.Evidences, backendclient.EvidenceItem{ + Arch: ev.Arch, + Certificate: ev.Certificate, + DriverVersion: ev.DriverVersion, + Evidence: ev.Evidence, + Nonce: ev.Nonce, + VBIOSVersion: ev.VBIOSVersion, + Version: ev.Version, + }) + } + } + + return req +} diff --git a/internal/attestationloop/sink/backend.go b/internal/attestationloop/sink/backend.go new file mode 100644 index 00000000..9c73ee88 --- /dev/null +++ b/internal/attestationloop/sink/backend.go @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package sink contains attestation loop sink implementations. +package sink + +import ( + "context" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop" + "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop/mapper" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +type backendSink struct { + client backendclient.Client + jwt func(context.Context) (string, error) +} + +// NewBackendSink creates the backend attestation sink skeleton. +func NewBackendSink(client backendclient.Client, jwt func(context.Context) (string, error)) attestationloop.Sink { + return &backendSink{ + client: client, + jwt: jwt, + } +} + +func (s *backendSink) Export(ctx context.Context, result attestationloop.Result) error { + jwt, err := s.jwt(ctx) + if err != nil { + return err + } + return s.client.SubmitAttestation(ctx, result.NodeID, mapper.ToAttestationRequest(result), jwt) +} diff --git a/internal/attestationloop/source/source.go b/internal/attestationloop/source/source.go new file mode 100644 index 00000000..82ee02d8 --- /dev/null +++ b/internal/attestationloop/source/source.go @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package source contains attestation loop collection adapters. +package source + +import ( + "context" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop" +) + +// NVAttestCollector is the local attestation evidence collector dependency. +type NVAttestCollector interface { + Collect(ctx context.Context, nonce string) (*attestationloop.SDKResponse, error) +} diff --git a/internal/attestationloop/types.go b/internal/attestationloop/types.go new file mode 100644 index 00000000..09c0a933 --- /dev/null +++ b/internal/attestationloop/types.go @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package attestationloop owns attestation collection and sync state orchestration. +package attestationloop + +import ( + "context" + "time" +) + +// Result is the agent-owned attestation state model for the new backend sync loop. +type Result struct { + CollectedAt time.Time + NodeID string + NonceRefreshTimestamp time.Time + Success bool + ErrorMessage string + SDKResponse SDKResponse +} + +type SDKResponse struct { + Evidences []EvidenceItem + ResultCode int + ResultMessage string +} + +type EvidenceItem struct { + Arch string + Certificate string + DriverVersion string + Evidence string + Nonce string + VBIOSVersion string + Version string +} + +// NonceProvider retrieves a backend nonce for a node. +type NonceProvider interface { + GetNonce(ctx context.Context, nodeID, jwt string) (nonce string, refreshTS time.Time, refreshedJWT string, err error) +} + +// EvidenceCollector collects attestation evidence from local tooling. +type EvidenceCollector interface { + Collect(ctx context.Context, nonce string) (*SDKResponse, error) +} + +// Sink exports attestation results to an external destination. +type Sink interface { + Export(ctx context.Context, result Result) error +} + +// StateStore is the attestation loop view of local transient store state. +type StateStore interface { + PutAttestation(ctx context.Context, result Result) error + GetAttestation(ctx context.Context) (Result, bool, error) + MarkAttestationExported(ctx context.Context, key string, at time.Time) error + WasAttestationExported(ctx context.Context, key string) (bool, error) +} diff --git a/internal/backendclient/client.go b/internal/backendclient/client.go new file mode 100644 index 00000000..6ef16b10 --- /dev/null +++ b/internal/backendclient/client.go @@ -0,0 +1,233 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package backendclient provides the agent-facing client for backend workflows. +package backendclient + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" +) + +const ( + userAgent = "fleet-intelligence-agent" + maxResponseBodyBytes = 1 << 20 +) + +// Client is the backend workflow client used by enrollment, inventory, and attestation paths. +type Client interface { + Enroll(ctx context.Context, sakToken string) (jwt string, err error) + UpsertNode(ctx context.Context, nodeID string, req NodeUpsertRequest, jwt string) error + GetNonce(ctx context.Context, nodeID string, jwt string) (*NonceResponse, error) + SubmitAttestation(ctx context.Context, nodeID string, req AttestationRequest, jwt string) error + RefreshToken(ctx context.Context, jwt string) (newJWT string, err error) +} + +type client struct { + httpClient *http.Client + baseURL *url.URL +} + +// New creates a backend client from a validated backend base URL. +func New(rawBaseURL string) (Client, error) { + baseURL, err := endpoint.ValidateBackendEndpoint(rawBaseURL) + if err != nil { + return nil, fmt.Errorf("invalid backend base URL: %w", err) + } + + return NewWithHTTPClient(baseURL, &http.Client{Timeout: 30 * time.Second}), nil +} + +// NewWithHTTPClient creates a backend client with an explicit HTTP client. +func NewWithHTTPClient(baseURL *url.URL, httpClient *http.Client) Client { + if httpClient == nil { + httpClient = &http.Client{Timeout: 30 * time.Second} + } + return &client{ + httpClient: httpClient, + baseURL: baseURL, + } +} + +func (c *client) Enroll(ctx context.Context, sakToken string) (string, error) { + if sakToken == "" { + return "", fmt.Errorf("sakToken cannot be empty") + } + + var resp struct { + JWTAssertion string `json:"jwtAssertion"` + } + if err := c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "enroll"}, sakToken, nil, &resp); err != nil { + return "", mapEnrollError(err) + } + if resp.JWTAssertion == "" { + return "", fmt.Errorf("enrollment response missing jwtAssertion field") + } + return resp.JWTAssertion, nil +} + +func (c *client) UpsertNode(ctx context.Context, nodeID string, req NodeUpsertRequest, jwt string) error { + if nodeID == "" { + return fmt.Errorf("nodeID cannot be empty") + } + if jwt == "" { + return fmt.Errorf("jwt cannot be empty") + } + return c.doJSON(ctx, http.MethodPut, []string{"v1", "agent", "nodes", nodeID}, jwt, req, nil) +} + +func (c *client) GetNonce(ctx context.Context, nodeID string, jwt string) (*NonceResponse, error) { + if nodeID == "" { + return nil, fmt.Errorf("nodeID cannot be empty") + } + if jwt == "" { + return nil, fmt.Errorf("jwt cannot be empty") + } + + var resp NonceResponse + if err := c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "nodes", nodeID, "nonce"}, jwt, nil, &resp); err != nil { + return nil, err + } + if resp.Nonce == "" { + return nil, fmt.Errorf("nonce response missing nonce field") + } + return &resp, nil +} + +func (c *client) SubmitAttestation(ctx context.Context, nodeID string, req AttestationRequest, jwt string) error { + if nodeID == "" { + return fmt.Errorf("nodeID cannot be empty") + } + if jwt == "" { + return fmt.Errorf("jwt cannot be empty") + } + return c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "nodes", nodeID, "attestation"}, jwt, req, nil) +} + +func (c *client) RefreshToken(ctx context.Context, jwt string) (string, error) { + if jwt == "" { + return "", fmt.Errorf("jwt cannot be empty") + } + + var resp struct { + JWTAssertion string `json:"jwtAssertion"` + } + req := struct { + JWTAssertion string `json:"jwtAssertion"` + }{ + JWTAssertion: jwt, + } + if err := c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "token"}, "", req, &resp); err != nil { + return "", err + } + if resp.JWTAssertion == "" { + return "", fmt.Errorf("token refresh response missing jwtAssertion field") + } + return resp.JWTAssertion, nil +} + +func (c *client) doJSON(ctx context.Context, method string, pathElems []string, bearerToken string, reqBody any, respBody any) error { + requestURL, err := endpoint.JoinPath(c.baseURL, pathElems...) + if err != nil { + return fmt.Errorf("failed to construct request URL: %w", err) + } + + var bodyReader io.Reader + if reqBody != nil { + payload, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + bodyReader = bytes.NewReader(payload) + } + + req, err := http.NewRequestWithContext(ctx, method, requestURL, bodyReader) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %w", err) + } + req.Header.Set("User-Agent", userAgent) + if reqBody != nil { + req.Header.Set("Content-Type", "application/json") + } + if bearerToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken)) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to make backend request: %w", err) + } + defer resp.Body.Close() + + data, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes+1)) + if err != nil { + return fmt.Errorf("failed to read backend response: %w", err) + } + if len(data) > maxResponseBodyBytes { + return fmt.Errorf("backend response too large (max %d bytes)", maxResponseBodyBytes) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return &HTTPStatusError{ + StatusCode: resp.StatusCode, + Body: string(bytes.TrimSpace(data)), + } + } + + if respBody == nil || len(bytes.TrimSpace(data)) == 0 { + return nil + } + if err := json.Unmarshal(data, respBody); err != nil { + return fmt.Errorf("failed to parse backend response: %w", err) + } + return nil +} + +func mapEnrollError(err error) error { + var statusErr *HTTPStatusError + if !errors.As(err, &statusErr) { + return err + } + + switch statusErr.StatusCode { + case http.StatusBadRequest: + return fmt.Errorf("the token used in the enrollment is not in the correct format") + case http.StatusUnauthorized: + return fmt.Errorf("the token used in the enrollment is incorrect") + case http.StatusForbidden: + return fmt.Errorf("the token used in the enrollment is incorrect/expired") + case http.StatusNotFound: + return fmt.Errorf("the endpoint is not found") + case http.StatusTooManyRequests: + return fmt.Errorf("please retry after some time; server is under heavy load") + case http.StatusBadGateway: + return fmt.Errorf("some temporary issue caused enrollment to fail") + case http.StatusServiceUnavailable: + return fmt.Errorf("service is unavailable currently") + case http.StatusGatewayTimeout: + return fmt.Errorf("service is experiencing load and is slow to respond") + default: + return err + } +} diff --git a/internal/backendclient/client_test.go b/internal/backendclient/client_test.go new file mode 100644 index 00000000..b6911e33 --- /dev/null +++ b/internal/backendclient/client_test.go @@ -0,0 +1,150 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backendclient + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClient_Enroll(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/v1/agent/enroll", r.URL.Path) + require.Equal(t, "Bearer sak-token", r.Header.Get("Authorization")) + require.Equal(t, userAgent, r.Header.Get("User-Agent")) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"jwtAssertion": "jwt-token"}) + })) + defer server.Close() + + c, err := New(server.URL) + require.NoError(t, err) + c = NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + + jwt, err := c.Enroll(context.Background(), "sak-token") + require.NoError(t, err) + require.Equal(t, "jwt-token", jwt) +} + +func TestClient_UpsertNode(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPut, r.Method) + require.Equal(t, "/v1/agent/nodes/node-1", r.URL.Path) + require.Equal(t, "Bearer jwt-token", r.Header.Get("Authorization")) + + var req NodeUpsertRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + require.Equal(t, "node-1", req.Hostname) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + err := c.UpsertNode(context.Background(), "node-1", NodeUpsertRequest{Hostname: "node-1"}, "jwt-token") + require.NoError(t, err) +} + +func TestClient_GetNonce(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/v1/agent/nodes/node-1/nonce", r.URL.Path) + require.Equal(t, "Bearer jwt-token", r.Header.Get("Authorization")) + _ = json.NewEncoder(w).Encode(NonceResponse{ + Nonce: "abc123", + JWTAssertion: "new-jwt", + }) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + resp, err := c.GetNonce(context.Background(), "node-1", "jwt-token") + require.NoError(t, err) + require.Equal(t, "abc123", resp.Nonce) + require.Equal(t, "new-jwt", resp.JWTAssertion) +} + +func TestClient_SubmitAttestation(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/v1/agent/nodes/node-1/attestation", r.URL.Path) + require.Equal(t, "Bearer jwt-token", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + err := c.SubmitAttestation(context.Background(), "node-1", AttestationRequest{}, "jwt-token") + require.NoError(t, err) +} + +func TestClient_RefreshToken(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/v1/agent/token", r.URL.Path) + + var req struct { + JWTAssertion string `json:"jwtAssertion"` + } + require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + require.Equal(t, "jwt-token", req.JWTAssertion) + + _ = json.NewEncoder(w).Encode(map[string]string{"jwtAssertion": "new-jwt-token"}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + jwt, err := c.RefreshToken(context.Background(), "jwt-token") + require.NoError(t, err) + require.Equal(t, "new-jwt-token", jwt) +} + +func TestClient_EnrollMapsHTTPStatus(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad token", http.StatusUnauthorized) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.Enroll(context.Background(), "sak-token") + require.Error(t, err) + require.Contains(t, err.Error(), "incorrect") +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + parsed, err := url.Parse(raw) + require.NoError(t, err) + return parsed +} diff --git a/internal/backendclient/errors.go b/internal/backendclient/errors.go new file mode 100644 index 00000000..29411e92 --- /dev/null +++ b/internal/backendclient/errors.go @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backendclient + +import ( + "errors" + "fmt" +) + +// ErrNotImplemented is returned by skeleton backend client methods. +var ErrNotImplemented = errors.New("backend client not implemented") + +// HTTPStatusError captures a non-2xx backend response. +type HTTPStatusError struct { + StatusCode int + Body string +} + +func (e *HTTPStatusError) Error() string { + if e.Body != "" { + return fmt.Sprintf("backend request failed with status %d: %s", e.StatusCode, e.Body) + } + return fmt.Sprintf("backend request failed with status %d", e.StatusCode) +} diff --git a/internal/backendclient/types.go b/internal/backendclient/types.go new file mode 100644 index 00000000..7d9ef5e3 --- /dev/null +++ b/internal/backendclient/types.go @@ -0,0 +1,137 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backendclient + +import "time" + +// NodeUpsertRequest is the backend DTO for node inventory upserts. +type NodeUpsertRequest struct { + Hostname string `json:"hostname"` + Resources NodeResources `json:"resources"` + FleetintVersion string `json:"gpuHealthVersion"` + GPUDriverVersion string `json:"gpuDriverVersion"` + CUDAVersion string `json:"cudaVersion"` + DCGMVersion string `json:"dcgmVersion"` + ContainerRuntimeVersion string `json:"containerRuntimeVersion"` + KernelVersion string `json:"kernelVersion"` + OSImage string `json:"osImage"` + OperatingSystem string `json:"operatingSystem"` + SystemUUID string `json:"systemUUID"` + MachineID string `json:"machineId"` + BootID string `json:"bootID"` + NetPrivateIP string `json:"netPrivateIP,omitempty"` + NetPublicIP string `json:"netPublicIP,omitempty"` + InventoryHash string `json:"inventoryHash,omitempty"` +} + +type NodeResources struct { + CPUInfo CPUInfo `json:"cpuInfo"` + MemoryInfo MemoryInfo `json:"memoryInfo"` + GPUInfo GPUInfo `json:"gpuInfo"` + DiskInfo DiskInfo `json:"diskInfo"` + NICInfo NICInfo `json:"nicInfo"` +} + +type CPUInfo struct { + Type string `json:"type"` + Manufacturer string `json:"manufacturer"` + Architecture string `json:"architecture"` + LogicalCores int64 `json:"logicalCores"` +} + +type MemoryInfo struct { + TotalBytes uint64 `json:"totalBytes"` +} + +type GPUInfo struct { + Product string `json:"product"` + Manufacturer string `json:"manufacturer"` + Architecture string `json:"architecture"` + Memory string `json:"memory"` + GPUs []GPUDevice `json:"gpus"` +} + +type GPUDevice struct { + UUID string `json:"uuid"` + BusID string `json:"busID"` + SN string `json:"sn"` + MinorID string `json:"minorID"` + BoardID int `json:"boardID"` + VBIOSVersion string `json:"vbiosVersion"` + ChassisSN string `json:"chassisSN"` + GPUIndex string `json:"gpuIndex,omitempty"` +} + +type DiskInfo struct { + ContainerRootDisk string `json:"containerRootDisk"` + BlockDevices []BlockDevice `json:"blockDevices"` +} + +type BlockDevice struct { + Name string `json:"name"` + Type string `json:"type"` + Size int64 `json:"size"` + WWN string `json:"wwn"` + MountPoint string `json:"mountPoint"` + FSType string `json:"fsType"` + PartUUID string `json:"partUUID"` + Parents []string `json:"parents"` +} + +type NICInfo struct { + PrivateIPInterfaces []NICInterface `json:"privateIPInterfaces"` +} + +type NICInterface struct { + Interface string `json:"interface"` + MAC string `json:"mac"` + IP string `json:"ip"` +} + +// NonceResponse is the backend DTO for node-scoped nonce responses. +type NonceResponse struct { + Nonce string `json:"nonce"` + JWTAssertion string `json:"jwtAssertion,omitempty"` + NonceRefreshTimestamp time.Time `json:"nonceRefreshTimestamp"` +} + +// AttestationRequest is the backend DTO for attestation submission. +type AttestationRequest struct { + AttestationData AttestationData `json:"attestationData"` +} + +type AttestationData struct { + SDKResponse AttestationSDKResponse `json:"sdkResponse"` + NonceRefreshTimestamp time.Time `json:"nonceRefreshTimestamp"` + Success bool `json:"success"` + ErrorMessage string `json:"errorMessage,omitempty"` +} + +type AttestationSDKResponse struct { + Evidences []EvidenceItem `json:"evidences"` + ResultCode int `json:"resultCode"` + ResultMessage string `json:"resultMessage"` +} + +type EvidenceItem struct { + Arch string `json:"arch"` + Certificate string `json:"certificate"` + DriverVersion string `json:"driverVersion"` + Evidence string `json:"evidence"` + Nonce string `json:"nonce"` + VBIOSVersion string `json:"vbiosVersion"` + Version string `json:"version"` +} diff --git a/internal/inventory/manager.go b/internal/inventory/manager.go new file mode 100644 index 00000000..efbb8a70 --- /dev/null +++ b/internal/inventory/manager.go @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inventory + +import ( + "context" + "fmt" + "time" +) + +// Manager coordinates periodic inventory collection into a store. +type Manager interface { + Run(ctx context.Context) error + CollectOnce(ctx context.Context) (*Snapshot, error) +} + +type manager struct { + source Source + store StateStore + interval time.Duration +} + +// NewManager creates an inventory manager skeleton. +func NewManager(source Source, store StateStore, interval time.Duration) Manager { + return &manager{ + source: source, + store: store, + interval: interval, + } +} + +func (m *manager) Run(ctx context.Context) error { + if _, err := m.CollectOnce(ctx); err != nil { + return err + } + return fmt.Errorf("inventory manager run loop not implemented") +} + +func (m *manager) CollectOnce(ctx context.Context) (*Snapshot, error) { + if m.source == nil { + return nil, fmt.Errorf("inventory source is required") + } + snap, err := m.source.Collect(ctx) + if err != nil { + return nil, err + } + if snap == nil { + return nil, fmt.Errorf("inventory source returned nil snapshot") + } + if m.store != nil { + if err := m.store.PutInventory(ctx, *snap); err != nil { + return nil, err + } + } + return snap, nil +} diff --git a/internal/inventory/mapper/backend.go b/internal/inventory/mapper/backend.go new file mode 100644 index 00000000..d389d1d9 --- /dev/null +++ b/internal/inventory/mapper/backend.go @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mapper contains inventory payload mappers. +package mapper + +import ( + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" +) + +// ToNodeUpsertRequest maps an inventory snapshot to the backend node-upsert contract. +func ToNodeUpsertRequest(s inventory.Snapshot) backendclient.NodeUpsertRequest { + return backendclient.NodeUpsertRequest{ + Hostname: s.Hostname, + MachineID: s.MachineID, + SystemUUID: s.SystemUUID, + BootID: s.BootID, + OperatingSystem: s.OperatingSystem, + OSImage: s.OSImage, + KernelVersion: s.KernelVersion, + FleetintVersion: s.FleetintVersion, + GPUDriverVersion: s.GPUDriverVersion, + CUDAVersion: s.CUDAVersion, + DCGMVersion: s.DCGMVersion, + ContainerRuntimeVersion: s.ContainerRuntimeVersion, + NetPrivateIP: s.NetPrivateIP, + NetPublicIP: s.NetPublicIP, + InventoryHash: s.InventoryHash, + Resources: backendclient.NodeResources{ + CPUInfo: backendclient.CPUInfo{ + Type: s.Resources.CPUInfo.Type, + Manufacturer: s.Resources.CPUInfo.Manufacturer, + Architecture: s.Resources.CPUInfo.Architecture, + LogicalCores: s.Resources.CPUInfo.LogicalCores, + }, + MemoryInfo: backendclient.MemoryInfo{ + TotalBytes: s.Resources.MemoryInfo.TotalBytes, + }, + }, + } +} diff --git a/internal/inventory/sink/backend.go b/internal/inventory/sink/backend.go new file mode 100644 index 00000000..15efba7d --- /dev/null +++ b/internal/inventory/sink/backend.go @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package sink contains inventory sink implementations. +package sink + +import ( + "context" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/mapper" +) + +type backendSink struct { + client backendclient.Client + jwt func(context.Context) (string, error) +} + +// NewBackendSink creates the backend inventory sink skeleton. +func NewBackendSink(client backendclient.Client, jwt func(context.Context) (string, error)) inventory.Sink { + return &backendSink{ + client: client, + jwt: jwt, + } +} + +func (s *backendSink) Export(ctx context.Context, snap inventory.Snapshot) error { + jwt, err := s.jwt(ctx) + if err != nil { + return err + } + return s.client.UpsertNode(ctx, snap.NodeID, mapper.ToNodeUpsertRequest(snap), jwt) +} diff --git a/internal/inventory/source/source.go b/internal/inventory/source/source.go new file mode 100644 index 00000000..20c0fec4 --- /dev/null +++ b/internal/inventory/source/source.go @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package source contains inventory collection adapters. +package source + +import ( + "context" + "fmt" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" +) + +// MachineInfoCollector is the local machine inventory collector dependency. +type MachineInfoCollector interface { + Collect(ctx context.Context) (*inventory.Snapshot, error) +} + +type machineInfoSource struct { + collector MachineInfoCollector +} + +// NewMachineInfoSource wraps the machine inventory collector as an inventory source. +func NewMachineInfoSource(collector MachineInfoCollector) inventory.Source { + return &machineInfoSource{collector: collector} +} + +func (s *machineInfoSource) Collect(ctx context.Context) (*inventory.Snapshot, error) { + if s.collector == nil { + return nil, fmt.Errorf("machine info collector is required") + } + return s.collector.Collect(ctx) +} diff --git a/internal/inventory/types.go b/internal/inventory/types.go new file mode 100644 index 00000000..4e5ebe3a --- /dev/null +++ b/internal/inventory/types.go @@ -0,0 +1,126 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package inventory owns inventory collection and sync state. +package inventory + +import ( + "context" + "time" +) + +// Snapshot is the agent-owned inventory state model. +type Snapshot struct { + CollectedAt time.Time + NodeID string + InventoryHash string + Hostname string + MachineID string + SystemUUID string + BootID string + OperatingSystem string + OSImage string + KernelVersion string + FleetintVersion string + GPUDriverVersion string + CUDAVersion string + DCGMVersion string + ContainerRuntimeVersion string + NetPrivateIP string + NetPublicIP string + Resources Resources +} + +type Resources struct { + CPUInfo CPUInfo + MemoryInfo MemoryInfo + GPUInfo GPUInfo + DiskInfo DiskInfo + NICInfo NICInfo +} + +type CPUInfo struct { + Type string + Manufacturer string + Architecture string + LogicalCores int64 +} + +type MemoryInfo struct { + TotalBytes uint64 +} + +type GPUInfo struct { + Product string + Manufacturer string + Architecture string + Memory string + GPUs []GPUDevice +} + +type GPUDevice struct { + UUID string + BusID string + SN string + MinorID string + BoardID int + VBIOSVersion string + ChassisSN string + GPUIndex string +} + +type DiskInfo struct { + ContainerRootDisk string + BlockDevices []BlockDevice +} + +type BlockDevice struct { + Name string + Type string + Size int64 + WWN string + MountPoint string + FSType string + PartUUID string + Parents []string +} + +type NICInfo struct { + PrivateIPInterfaces []NICInterface +} + +type NICInterface struct { + Interface string + MAC string + IP string +} + +// Source collects inventory from local providers. +type Source interface { + Collect(ctx context.Context) (*Snapshot, error) +} + +// Sink exports inventory snapshots to an external destination. +type Sink interface { + Export(ctx context.Context, snap Snapshot) error +} + +// StateStore is the inventory package view of local transient store state. +type StateStore interface { + PutInventory(ctx context.Context, snap Snapshot) error + GetInventory(ctx context.Context) (Snapshot, bool, error) + MarkInventoryExported(ctx context.Context, hash string, at time.Time) error + LastExportedInventoryHash(ctx context.Context) (string, error) +} diff --git a/internal/store/memory.go b/internal/store/memory.go new file mode 100644 index 00000000..fbe00f86 --- /dev/null +++ b/internal/store/memory.go @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package store contains transient in-agent state stores. +package store + +import ( + "context" + "sync" + "time" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" +) + +// MemoryStore is the initial in-memory implementation for inventory and attestation state. +type MemoryStore struct { + mu sync.RWMutex + + inventory inventory.Snapshot + hasInventory bool + lastInventoryHash string + lastInventorySyncTS time.Time + + attestation attestationloop.Result + hasAttestation bool + exportedAttestationKeys map[string]time.Time +} + +// NewMemoryStore creates an empty in-memory state store. +func NewMemoryStore() *MemoryStore { + return &MemoryStore{ + exportedAttestationKeys: make(map[string]time.Time), + } +} + +func (s *MemoryStore) PutInventory(_ context.Context, snap inventory.Snapshot) error { + s.mu.Lock() + defer s.mu.Unlock() + s.inventory = snap + s.hasInventory = true + return nil +} + +func (s *MemoryStore) GetInventory(_ context.Context) (inventory.Snapshot, bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() + return s.inventory, s.hasInventory, nil +} + +func (s *MemoryStore) MarkInventoryExported(_ context.Context, hash string, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.lastInventoryHash = hash + s.lastInventorySyncTS = at + return nil +} + +func (s *MemoryStore) LastExportedInventoryHash(_ context.Context) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + return s.lastInventoryHash, nil +} + +func (s *MemoryStore) PutAttestation(_ context.Context, result attestationloop.Result) error { + s.mu.Lock() + defer s.mu.Unlock() + s.attestation = result + s.hasAttestation = true + return nil +} + +func (s *MemoryStore) GetAttestation(_ context.Context) (attestationloop.Result, bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() + return s.attestation, s.hasAttestation, nil +} + +func (s *MemoryStore) MarkAttestationExported(_ context.Context, key string, at time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.exportedAttestationKeys[key] = at + return nil +} + +func (s *MemoryStore) WasAttestationExported(_ context.Context, key string) (bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() + _, ok := s.exportedAttestationKeys[key] + return ok, nil +} From de7635f1ff9204b76ebd6a56e84880811d54c797 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Wed, 15 Apr 2026 11:49:08 -0700 Subject: [PATCH 02/22] fix: address initial review feedback Signed-off-by: Jingxiang Zhang --- internal/agentstate/state.go | 8 ++-- internal/attestationloop/sink/backend.go | 10 +++++ internal/backendclient/client_test.go | 12 ++++-- internal/backendclient/errors.go | 22 +++++++++- internal/inventory/mapper/backend.go | 51 ++++++++++++++++++++++++ internal/inventory/sink/backend.go | 10 +++++ 6 files changed, 104 insertions(+), 9 deletions(-) diff --git a/internal/agentstate/state.go b/internal/agentstate/state.go index 6149e0a3..74ce21c9 100644 --- a/internal/agentstate/state.go +++ b/internal/agentstate/state.go @@ -20,15 +20,15 @@ import "context" // State provides local persisted metadata/state access for backend workflows. type State interface { - GetBackendBaseURL(ctx context.Context) (string, error) + GetBackendBaseURL(ctx context.Context) (value string, ok bool, err error) SetBackendBaseURL(ctx context.Context, value string) error - GetJWT(ctx context.Context) (string, error) + GetJWT(ctx context.Context) (value string, ok bool, err error) SetJWT(ctx context.Context, value string) error - GetSAK(ctx context.Context) (string, error) + GetSAK(ctx context.Context) (value string, ok bool, err error) SetSAK(ctx context.Context, value string) error - GetNodeID(ctx context.Context) (string, error) + GetNodeID(ctx context.Context) (value string, ok bool, err error) SetNodeID(ctx context.Context, value string) error } diff --git a/internal/attestationloop/sink/backend.go b/internal/attestationloop/sink/backend.go index 9c73ee88..6b1399fb 100644 --- a/internal/attestationloop/sink/backend.go +++ b/internal/attestationloop/sink/backend.go @@ -18,6 +18,7 @@ package sink import ( "context" + "fmt" "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop" "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop/mapper" @@ -38,9 +39,18 @@ func NewBackendSink(client backendclient.Client, jwt func(context.Context) (stri } func (s *backendSink) Export(ctx context.Context, result attestationloop.Result) error { + if s.jwt == nil { + return fmt.Errorf("attestation backend export requires jwt provider") + } + if s.client == nil { + return fmt.Errorf("attestation backend export requires backend client") + } jwt, err := s.jwt(ctx) if err != nil { return err } + if jwt == "" { + return fmt.Errorf("attestation backend export received empty jwt") + } return s.client.SubmitAttestation(ctx, result.NodeID, mapper.ToAttestationRequest(result), jwt) } diff --git a/internal/backendclient/client_test.go b/internal/backendclient/client_test.go index b6911e33..72d9a7ad 100644 --- a/internal/backendclient/client_test.go +++ b/internal/backendclient/client_test.go @@ -39,15 +39,21 @@ func TestClient_Enroll(t *testing.T) { })) defer server.Close() - c, err := New(server.URL) - require.NoError(t, err) - c = NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) jwt, err := c.Enroll(context.Background(), "sak-token") require.NoError(t, err) require.Equal(t, "jwt-token", jwt) } +func TestNew(t *testing.T) { + t.Parallel() + + c, err := New("https://backend.example.com") + require.NoError(t, err) + require.NotNil(t, c) +} + func TestClient_UpsertNode(t *testing.T) { t.Parallel() diff --git a/internal/backendclient/errors.go b/internal/backendclient/errors.go index 29411e92..890c2dd4 100644 --- a/internal/backendclient/errors.go +++ b/internal/backendclient/errors.go @@ -18,6 +18,7 @@ package backendclient import ( "errors" "fmt" + "strings" ) // ErrNotImplemented is returned by skeleton backend client methods. @@ -30,8 +31,25 @@ type HTTPStatusError struct { } func (e *HTTPStatusError) Error() string { - if e.Body != "" { - return fmt.Sprintf("backend request failed with status %d: %s", e.StatusCode, e.Body) + body := sanitizeErrorBody(e.Body) + if body != "" { + return fmt.Sprintf("backend request failed with status %d: %s", e.StatusCode, body) } return fmt.Sprintf("backend request failed with status %d", e.StatusCode) } + +func sanitizeErrorBody(body string) string { + const maxLen = 200 + + body = strings.TrimSpace(body) + if body == "" { + return "" + } + body = strings.ReplaceAll(body, "\n", " ") + body = strings.ReplaceAll(body, "\r", " ") + body = strings.Join(strings.Fields(body), " ") + if len(body) <= maxLen { + return body + } + return body[:maxLen] + "...(truncated)" +} diff --git a/internal/inventory/mapper/backend.go b/internal/inventory/mapper/backend.go index d389d1d9..3bb5d93f 100644 --- a/internal/inventory/mapper/backend.go +++ b/internal/inventory/mapper/backend.go @@ -23,6 +23,43 @@ import ( // ToNodeUpsertRequest maps an inventory snapshot to the backend node-upsert contract. func ToNodeUpsertRequest(s inventory.Snapshot) backendclient.NodeUpsertRequest { + gpus := make([]backendclient.GPUDevice, 0, len(s.Resources.GPUInfo.GPUs)) + for _, gpu := range s.Resources.GPUInfo.GPUs { + gpus = append(gpus, backendclient.GPUDevice{ + UUID: gpu.UUID, + BusID: gpu.BusID, + SN: gpu.SN, + MinorID: gpu.MinorID, + BoardID: gpu.BoardID, + VBIOSVersion: gpu.VBIOSVersion, + ChassisSN: gpu.ChassisSN, + GPUIndex: gpu.GPUIndex, + }) + } + + blockDevices := make([]backendclient.BlockDevice, 0, len(s.Resources.DiskInfo.BlockDevices)) + for _, disk := range s.Resources.DiskInfo.BlockDevices { + blockDevices = append(blockDevices, backendclient.BlockDevice{ + Name: disk.Name, + Type: disk.Type, + Size: disk.Size, + WWN: disk.WWN, + MountPoint: disk.MountPoint, + FSType: disk.FSType, + PartUUID: disk.PartUUID, + Parents: append([]string(nil), disk.Parents...), + }) + } + + interfaces := make([]backendclient.NICInterface, 0, len(s.Resources.NICInfo.PrivateIPInterfaces)) + for _, nic := range s.Resources.NICInfo.PrivateIPInterfaces { + interfaces = append(interfaces, backendclient.NICInterface{ + Interface: nic.Interface, + MAC: nic.MAC, + IP: nic.IP, + }) + } + return backendclient.NodeUpsertRequest{ Hostname: s.Hostname, MachineID: s.MachineID, @@ -49,6 +86,20 @@ func ToNodeUpsertRequest(s inventory.Snapshot) backendclient.NodeUpsertRequest { MemoryInfo: backendclient.MemoryInfo{ TotalBytes: s.Resources.MemoryInfo.TotalBytes, }, + GPUInfo: backendclient.GPUInfo{ + Product: s.Resources.GPUInfo.Product, + Manufacturer: s.Resources.GPUInfo.Manufacturer, + Architecture: s.Resources.GPUInfo.Architecture, + Memory: s.Resources.GPUInfo.Memory, + GPUs: gpus, + }, + DiskInfo: backendclient.DiskInfo{ + ContainerRootDisk: s.Resources.DiskInfo.ContainerRootDisk, + BlockDevices: blockDevices, + }, + NICInfo: backendclient.NICInfo{ + PrivateIPInterfaces: interfaces, + }, }, } } diff --git a/internal/inventory/sink/backend.go b/internal/inventory/sink/backend.go index 15efba7d..1e0e007b 100644 --- a/internal/inventory/sink/backend.go +++ b/internal/inventory/sink/backend.go @@ -18,6 +18,7 @@ package sink import ( "context" + "fmt" "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" @@ -38,9 +39,18 @@ func NewBackendSink(client backendclient.Client, jwt func(context.Context) (stri } func (s *backendSink) Export(ctx context.Context, snap inventory.Snapshot) error { + if s.jwt == nil { + return fmt.Errorf("inventory backend export requires jwt provider") + } + if s.client == nil { + return fmt.Errorf("inventory backend export requires backend client") + } jwt, err := s.jwt(ctx) if err != nil { return err } + if jwt == "" { + return fmt.Errorf("inventory backend export received empty jwt") + } return s.client.UpsertNode(ctx, snap.NodeID, mapper.ToNodeUpsertRequest(snap), jwt) } From 31682f6bd02eb21dd602245feab423f3c7279360 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Wed, 15 Apr 2026 14:07:01 -0700 Subject: [PATCH 03/22] fix: address lint and toolchain drift Signed-off-by: Jingxiang Zhang --- .github/workflows/ci.yml | 2 +- .github/workflows/release.yml | 2 +- .golangci.yml | 2 +- internal/attestationloop/manager.go | 2 +- internal/attestationloop/mapper/backend.go | 7 +++- internal/attestationloop/sink/backend.go | 5 ++- internal/attestationloop/types.go | 6 +-- internal/backendclient/client.go | 18 ++++++--- internal/backendclient/client_test.go | 4 +- internal/inventory/manager.go | 2 +- internal/inventory/mapper/backend.go | 7 +++- internal/inventory/sink/backend.go | 5 ++- internal/inventory/types.go | 6 +-- internal/store/memory.go | 46 +++++++++++++++------- 14 files changed, 74 insertions(+), 40 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2ac732bd..7ec52836 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: branches: [ main, 'release/**' ] env: - GO_VERSION: '1.26.1' + GO_VERSION: '1.26.2' GOFLAGS: '-trimpath' jobs: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index edc1403f..882e1970 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,7 +6,7 @@ on: - 'v*' # Triggers on version tags like v1.0.0, v2.1.3, etc. env: - GO_VERSION: '1.26.1' + GO_VERSION: '1.26.2' permissions: contents: write # Needed for creating GitHub releases diff --git a/.golangci.yml b/.golangci.yml index 4ddb2e22..67e46a3a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,6 +1,6 @@ version: "2" run: - go: 1.26.1 + go: 1.26.2 linters: default: none enable: diff --git a/internal/attestationloop/manager.go b/internal/attestationloop/manager.go index 2f1c3701..a52a9bc1 100644 --- a/internal/attestationloop/manager.go +++ b/internal/attestationloop/manager.go @@ -104,7 +104,7 @@ func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { result.SDKResponse = *sdkResp } if m.store != nil { - if err := m.store.PutAttestation(ctx, *result); err != nil { + if err := m.store.PutAttestation(ctx, result); err != nil { return nil, err } } diff --git a/internal/attestationloop/mapper/backend.go b/internal/attestationloop/mapper/backend.go index cbaae521..1f2160fb 100644 --- a/internal/attestationloop/mapper/backend.go +++ b/internal/attestationloop/mapper/backend.go @@ -22,8 +22,11 @@ import ( ) // ToAttestationRequest maps an attestation result to the backend attestation contract. -func ToAttestationRequest(r attestationloop.Result) backendclient.AttestationRequest { - req := backendclient.AttestationRequest{ +func ToAttestationRequest(r *attestationloop.Result) *backendclient.AttestationRequest { + if r == nil { + return nil + } + req := &backendclient.AttestationRequest{ AttestationData: backendclient.AttestationData{ NonceRefreshTimestamp: r.NonceRefreshTimestamp, Success: r.Success, diff --git a/internal/attestationloop/sink/backend.go b/internal/attestationloop/sink/backend.go index 6b1399fb..65bd56bb 100644 --- a/internal/attestationloop/sink/backend.go +++ b/internal/attestationloop/sink/backend.go @@ -38,13 +38,16 @@ func NewBackendSink(client backendclient.Client, jwt func(context.Context) (stri } } -func (s *backendSink) Export(ctx context.Context, result attestationloop.Result) error { +func (s *backendSink) Export(ctx context.Context, result *attestationloop.Result) error { if s.jwt == nil { return fmt.Errorf("attestation backend export requires jwt provider") } if s.client == nil { return fmt.Errorf("attestation backend export requires backend client") } + if result == nil { + return fmt.Errorf("attestation backend export requires attestation result") + } jwt, err := s.jwt(ctx) if err != nil { return err diff --git a/internal/attestationloop/types.go b/internal/attestationloop/types.go index 09c0a933..8491d6a0 100644 --- a/internal/attestationloop/types.go +++ b/internal/attestationloop/types.go @@ -59,13 +59,13 @@ type EvidenceCollector interface { // Sink exports attestation results to an external destination. type Sink interface { - Export(ctx context.Context, result Result) error + Export(ctx context.Context, result *Result) error } // StateStore is the attestation loop view of local transient store state. type StateStore interface { - PutAttestation(ctx context.Context, result Result) error - GetAttestation(ctx context.Context) (Result, bool, error) + PutAttestation(ctx context.Context, result *Result) error + GetAttestation(ctx context.Context) (*Result, bool, error) MarkAttestationExported(ctx context.Context, key string, at time.Time) error WasAttestationExported(ctx context.Context, key string) (bool, error) } diff --git a/internal/backendclient/client.go b/internal/backendclient/client.go index 6ef16b10..73f5dddc 100644 --- a/internal/backendclient/client.go +++ b/internal/backendclient/client.go @@ -31,16 +31,16 @@ import ( ) const ( - userAgent = "fleet-intelligence-agent" - maxResponseBodyBytes = 1 << 20 + userAgent = "fleet-intelligence-agent" + maxResponseBodyBytes = 1 << 20 ) // Client is the backend workflow client used by enrollment, inventory, and attestation paths. type Client interface { Enroll(ctx context.Context, sakToken string) (jwt string, err error) - UpsertNode(ctx context.Context, nodeID string, req NodeUpsertRequest, jwt string) error + UpsertNode(ctx context.Context, nodeID string, req *NodeUpsertRequest, jwt string) error GetNonce(ctx context.Context, nodeID string, jwt string) (*NonceResponse, error) - SubmitAttestation(ctx context.Context, nodeID string, req AttestationRequest, jwt string) error + SubmitAttestation(ctx context.Context, nodeID string, req *AttestationRequest, jwt string) error RefreshToken(ctx context.Context, jwt string) (newJWT string, err error) } @@ -87,13 +87,16 @@ func (c *client) Enroll(ctx context.Context, sakToken string) (string, error) { return resp.JWTAssertion, nil } -func (c *client) UpsertNode(ctx context.Context, nodeID string, req NodeUpsertRequest, jwt string) error { +func (c *client) UpsertNode(ctx context.Context, nodeID string, req *NodeUpsertRequest, jwt string) error { if nodeID == "" { return fmt.Errorf("nodeID cannot be empty") } if jwt == "" { return fmt.Errorf("jwt cannot be empty") } + if req == nil { + return fmt.Errorf("node upsert request cannot be nil") + } return c.doJSON(ctx, http.MethodPut, []string{"v1", "agent", "nodes", nodeID}, jwt, req, nil) } @@ -115,13 +118,16 @@ func (c *client) GetNonce(ctx context.Context, nodeID string, jwt string) (*Nonc return &resp, nil } -func (c *client) SubmitAttestation(ctx context.Context, nodeID string, req AttestationRequest, jwt string) error { +func (c *client) SubmitAttestation(ctx context.Context, nodeID string, req *AttestationRequest, jwt string) error { if nodeID == "" { return fmt.Errorf("nodeID cannot be empty") } if jwt == "" { return fmt.Errorf("jwt cannot be empty") } + if req == nil { + return fmt.Errorf("attestation request cannot be nil") + } return c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "nodes", nodeID, "attestation"}, jwt, req, nil) } diff --git a/internal/backendclient/client_test.go b/internal/backendclient/client_test.go index 72d9a7ad..b9d4cb76 100644 --- a/internal/backendclient/client_test.go +++ b/internal/backendclient/client_test.go @@ -70,7 +70,7 @@ func TestClient_UpsertNode(t *testing.T) { defer server.Close() c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) - err := c.UpsertNode(context.Background(), "node-1", NodeUpsertRequest{Hostname: "node-1"}, "jwt-token") + err := c.UpsertNode(context.Background(), "node-1", &NodeUpsertRequest{Hostname: "node-1"}, "jwt-token") require.NoError(t, err) } @@ -107,7 +107,7 @@ func TestClient_SubmitAttestation(t *testing.T) { defer server.Close() c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) - err := c.SubmitAttestation(context.Background(), "node-1", AttestationRequest{}, "jwt-token") + err := c.SubmitAttestation(context.Background(), "node-1", &AttestationRequest{}, "jwt-token") require.NoError(t, err) } diff --git a/internal/inventory/manager.go b/internal/inventory/manager.go index efbb8a70..1310f002 100644 --- a/internal/inventory/manager.go +++ b/internal/inventory/manager.go @@ -61,7 +61,7 @@ func (m *manager) CollectOnce(ctx context.Context) (*Snapshot, error) { return nil, fmt.Errorf("inventory source returned nil snapshot") } if m.store != nil { - if err := m.store.PutInventory(ctx, *snap); err != nil { + if err := m.store.PutInventory(ctx, snap); err != nil { return nil, err } } diff --git a/internal/inventory/mapper/backend.go b/internal/inventory/mapper/backend.go index 3bb5d93f..7f68d017 100644 --- a/internal/inventory/mapper/backend.go +++ b/internal/inventory/mapper/backend.go @@ -22,7 +22,10 @@ import ( ) // ToNodeUpsertRequest maps an inventory snapshot to the backend node-upsert contract. -func ToNodeUpsertRequest(s inventory.Snapshot) backendclient.NodeUpsertRequest { +func ToNodeUpsertRequest(s *inventory.Snapshot) *backendclient.NodeUpsertRequest { + if s == nil { + return nil + } gpus := make([]backendclient.GPUDevice, 0, len(s.Resources.GPUInfo.GPUs)) for _, gpu := range s.Resources.GPUInfo.GPUs { gpus = append(gpus, backendclient.GPUDevice{ @@ -60,7 +63,7 @@ func ToNodeUpsertRequest(s inventory.Snapshot) backendclient.NodeUpsertRequest { }) } - return backendclient.NodeUpsertRequest{ + return &backendclient.NodeUpsertRequest{ Hostname: s.Hostname, MachineID: s.MachineID, SystemUUID: s.SystemUUID, diff --git a/internal/inventory/sink/backend.go b/internal/inventory/sink/backend.go index 1e0e007b..f0ac82e9 100644 --- a/internal/inventory/sink/backend.go +++ b/internal/inventory/sink/backend.go @@ -38,13 +38,16 @@ func NewBackendSink(client backendclient.Client, jwt func(context.Context) (stri } } -func (s *backendSink) Export(ctx context.Context, snap inventory.Snapshot) error { +func (s *backendSink) Export(ctx context.Context, snap *inventory.Snapshot) error { if s.jwt == nil { return fmt.Errorf("inventory backend export requires jwt provider") } if s.client == nil { return fmt.Errorf("inventory backend export requires backend client") } + if snap == nil { + return fmt.Errorf("inventory backend export requires inventory snapshot") + } jwt, err := s.jwt(ctx) if err != nil { return err diff --git a/internal/inventory/types.go b/internal/inventory/types.go index 4e5ebe3a..131cf3fe 100644 --- a/internal/inventory/types.go +++ b/internal/inventory/types.go @@ -114,13 +114,13 @@ type Source interface { // Sink exports inventory snapshots to an external destination. type Sink interface { - Export(ctx context.Context, snap Snapshot) error + Export(ctx context.Context, snap *Snapshot) error } // StateStore is the inventory package view of local transient store state. type StateStore interface { - PutInventory(ctx context.Context, snap Snapshot) error - GetInventory(ctx context.Context) (Snapshot, bool, error) + PutInventory(ctx context.Context, snap *Snapshot) error + GetInventory(ctx context.Context) (*Snapshot, bool, error) MarkInventoryExported(ctx context.Context, hash string, at time.Time) error LastExportedInventoryHash(ctx context.Context) (string, error) } diff --git a/internal/store/memory.go b/internal/store/memory.go index fbe00f86..a93ddb00 100644 --- a/internal/store/memory.go +++ b/internal/store/memory.go @@ -29,14 +29,14 @@ import ( type MemoryStore struct { mu sync.RWMutex - inventory inventory.Snapshot - hasInventory bool - lastInventoryHash string - lastInventorySyncTS time.Time + inventory *inventory.Snapshot + hasInventory bool + lastInventoryHash string + lastInventorySyncTS time.Time - attestation attestationloop.Result - hasAttestation bool - exportedAttestationKeys map[string]time.Time + attestation *attestationloop.Result + hasAttestation bool + exportedAttestationKeys map[string]time.Time } // NewMemoryStore creates an empty in-memory state store. @@ -46,18 +46,26 @@ func NewMemoryStore() *MemoryStore { } } -func (s *MemoryStore) PutInventory(_ context.Context, snap inventory.Snapshot) error { +func (s *MemoryStore) PutInventory(_ context.Context, snap *inventory.Snapshot) error { + if snap == nil { + return nil + } s.mu.Lock() defer s.mu.Unlock() - s.inventory = snap + cloned := *snap + s.inventory = &cloned s.hasInventory = true return nil } -func (s *MemoryStore) GetInventory(_ context.Context) (inventory.Snapshot, bool, error) { +func (s *MemoryStore) GetInventory(_ context.Context) (*inventory.Snapshot, bool, error) { s.mu.RLock() defer s.mu.RUnlock() - return s.inventory, s.hasInventory, nil + if !s.hasInventory || s.inventory == nil { + return nil, false, nil + } + cloned := *s.inventory + return &cloned, true, nil } func (s *MemoryStore) MarkInventoryExported(_ context.Context, hash string, at time.Time) error { @@ -74,18 +82,26 @@ func (s *MemoryStore) LastExportedInventoryHash(_ context.Context) (string, erro return s.lastInventoryHash, nil } -func (s *MemoryStore) PutAttestation(_ context.Context, result attestationloop.Result) error { +func (s *MemoryStore) PutAttestation(_ context.Context, result *attestationloop.Result) error { + if result == nil { + return nil + } s.mu.Lock() defer s.mu.Unlock() - s.attestation = result + cloned := *result + s.attestation = &cloned s.hasAttestation = true return nil } -func (s *MemoryStore) GetAttestation(_ context.Context) (attestationloop.Result, bool, error) { +func (s *MemoryStore) GetAttestation(_ context.Context) (*attestationloop.Result, bool, error) { s.mu.RLock() defer s.mu.RUnlock() - return s.attestation, s.hasAttestation, nil + if !s.hasAttestation || s.attestation == nil { + return nil, false, nil + } + cloned := *s.attestation + return &cloned, true, nil } func (s *MemoryStore) MarkAttestationExported(_ context.Context, key string, at time.Time) error { From 31c9df04fa40e2bec7dd65519299ff4029a4ef7d Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Wed, 15 Apr 2026 14:15:39 -0700 Subject: [PATCH 04/22] refactor: model attestationloop as a workflow Signed-off-by: Jingxiang Zhang --- .../attestationloop/{mapper => }/backend.go | 11 +--- internal/attestationloop/manager.go | 48 +++++++++++++-- internal/attestationloop/sink/backend.go | 59 ------------------- internal/attestationloop/source/source.go | 28 --------- internal/attestationloop/types.go | 10 ++-- 5 files changed, 50 insertions(+), 106 deletions(-) rename internal/attestationloop/{mapper => }/backend.go (80%) delete mode 100644 internal/attestationloop/sink/backend.go delete mode 100644 internal/attestationloop/source/source.go diff --git a/internal/attestationloop/mapper/backend.go b/internal/attestationloop/backend.go similarity index 80% rename from internal/attestationloop/mapper/backend.go rename to internal/attestationloop/backend.go index 1f2160fb..bb8d7330 100644 --- a/internal/attestationloop/mapper/backend.go +++ b/internal/attestationloop/backend.go @@ -13,16 +13,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package mapper contains attestation loop payload mappers. -package mapper +package attestationloop -import ( - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop" - "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" -) +import "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" -// ToAttestationRequest maps an attestation result to the backend attestation contract. -func ToAttestationRequest(r *attestationloop.Result) *backendclient.AttestationRequest { +func toAttestationRequest(r *Result) *backendclient.AttestationRequest { if r == nil { return nil } diff --git a/internal/attestationloop/manager.go b/internal/attestationloop/manager.go index a52a9bc1..c31a489b 100644 --- a/internal/attestationloop/manager.go +++ b/internal/attestationloop/manager.go @@ -19,6 +19,8 @@ import ( "context" "fmt" "time" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" ) // JWTProvider retrieves the current backend JWT. @@ -38,6 +40,7 @@ type manager struct { jwtProvider JWTProvider nonceProvider NonceProvider collector EvidenceCollector + submitter Submitter store StateStore interval time.Duration } @@ -48,6 +51,7 @@ func NewManager( jwtProvider JWTProvider, nonceProvider NonceProvider, collector EvidenceCollector, + submitter Submitter, store StateStore, interval time.Duration, ) Manager { @@ -56,6 +60,7 @@ func NewManager( jwtProvider: jwtProvider, nonceProvider: nonceProvider, collector: collector, + submitter: submitter, store: store, interval: interval, } @@ -69,7 +74,7 @@ func (m *manager) Run(ctx context.Context) error { } func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { - if m.nodeIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil { + if m.nodeIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil || m.submitter == nil { return nil, fmt.Errorf("attestation loop dependencies are incomplete") } @@ -89,16 +94,19 @@ func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { if err := m.jwtProvider.SetJWT(ctx, refreshedJWT); err != nil { return nil, err } + jwt = refreshedJWT } sdkResp, err := m.collector.Collect(ctx, nonce) - if err != nil { - return nil, err - } result := &Result{ CollectedAt: time.Now().UTC(), NodeID: nodeID, NonceRefreshTimestamp: refreshTS, - Success: true, + } + if err != nil { + result.Success = false + result.ErrorMessage = err.Error() + } else { + result.Success = true } if sdkResp != nil { result.SDKResponse = *sdkResp @@ -108,5 +116,35 @@ func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { return nil, err } } + if err := m.submitter.Submit(ctx, result, jwt); err != nil { + return nil, err + } return result, nil } + +type backendSubmitter struct { + client BackendClient +} + +// BackendClient is the backend client view required by the attestation workflow. +type BackendClient interface { + SubmitAttestation(ctx context.Context, nodeID string, req *backendclient.AttestationRequest, jwt string) error +} + +// NewBackendSubmitter creates a backend submitter backed by the agent backend client. +func NewBackendSubmitter(client BackendClient) Submitter { + return &backendSubmitter{client: client} +} + +func (s *backendSubmitter) Submit(ctx context.Context, result *Result, jwt string) error { + if s.client == nil { + return fmt.Errorf("attestation submission requires backend client") + } + if result == nil { + return fmt.Errorf("attestation submission requires result") + } + if jwt == "" { + return fmt.Errorf("attestation submission requires jwt") + } + return s.client.SubmitAttestation(ctx, result.NodeID, toAttestationRequest(result), jwt) +} diff --git a/internal/attestationloop/sink/backend.go b/internal/attestationloop/sink/backend.go deleted file mode 100644 index 65bd56bb..00000000 --- a/internal/attestationloop/sink/backend.go +++ /dev/null @@ -1,59 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package sink contains attestation loop sink implementations. -package sink - -import ( - "context" - "fmt" - - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop" - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop/mapper" - "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" -) - -type backendSink struct { - client backendclient.Client - jwt func(context.Context) (string, error) -} - -// NewBackendSink creates the backend attestation sink skeleton. -func NewBackendSink(client backendclient.Client, jwt func(context.Context) (string, error)) attestationloop.Sink { - return &backendSink{ - client: client, - jwt: jwt, - } -} - -func (s *backendSink) Export(ctx context.Context, result *attestationloop.Result) error { - if s.jwt == nil { - return fmt.Errorf("attestation backend export requires jwt provider") - } - if s.client == nil { - return fmt.Errorf("attestation backend export requires backend client") - } - if result == nil { - return fmt.Errorf("attestation backend export requires attestation result") - } - jwt, err := s.jwt(ctx) - if err != nil { - return err - } - if jwt == "" { - return fmt.Errorf("attestation backend export received empty jwt") - } - return s.client.SubmitAttestation(ctx, result.NodeID, mapper.ToAttestationRequest(result), jwt) -} diff --git a/internal/attestationloop/source/source.go b/internal/attestationloop/source/source.go deleted file mode 100644 index 82ee02d8..00000000 --- a/internal/attestationloop/source/source.go +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package source contains attestation loop collection adapters. -package source - -import ( - "context" - - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop" -) - -// NVAttestCollector is the local attestation evidence collector dependency. -type NVAttestCollector interface { - Collect(ctx context.Context, nonce string) (*attestationloop.SDKResponse, error) -} diff --git a/internal/attestationloop/types.go b/internal/attestationloop/types.go index 8491d6a0..7c74a4fb 100644 --- a/internal/attestationloop/types.go +++ b/internal/attestationloop/types.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package attestationloop owns attestation collection and sync state orchestration. +// Package attestationloop owns the backend attestation workflow. package attestationloop import ( @@ -57,15 +57,13 @@ type EvidenceCollector interface { Collect(ctx context.Context, nonce string) (*SDKResponse, error) } -// Sink exports attestation results to an external destination. -type Sink interface { - Export(ctx context.Context, result *Result) error +// Submitter submits attestation results to the backend. +type Submitter interface { + Submit(ctx context.Context, result *Result, jwt string) error } // StateStore is the attestation loop view of local transient store state. type StateStore interface { PutAttestation(ctx context.Context, result *Result) error GetAttestation(ctx context.Context) (*Result, bool, error) - MarkAttestationExported(ctx context.Context, key string, at time.Time) error - WasAttestationExported(ctx context.Context, key string) (bool, error) } From fb816473a2e21fd49235f2afbb1bfa47c5c9cb91 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Wed, 15 Apr 2026 15:05:44 -0700 Subject: [PATCH 05/22] feat: add inventory backend sync flow Signed-off-by: Jingxiang Zhang --- cmd/fleetint/enroll.go | 41 ++++++- cmd/fleetint/enroll_test.go | 12 +- cmd/fleetint/unenroll.go | 3 +- internal/agentstate/sqlite.go | 133 +++++++++++++++++++++ internal/agentstate/sqlite_test.go | 108 +++++++++++++++++ internal/attestationloop/manager.go | 16 +-- internal/attestationloop/types.go | 6 - internal/backendclient/client_test.go | 137 ++++++++++++++++++++++ internal/backendclient/errors_test.go | 46 ++++++++ internal/inventory/hash.go | 41 +++++++ internal/inventory/hash_test.go | 50 ++++++++ internal/inventory/manager.go | 56 +++++++-- internal/inventory/manager_run_test.go | 75 ++++++++++++ internal/inventory/manager_test.go | 103 ++++++++++++++++ internal/inventory/mapper/backend_test.go | 110 +++++++++++++++++ internal/inventory/sink/backend.go | 40 ++++--- internal/inventory/sink/backend_test.go | 135 +++++++++++++++++++++ internal/inventory/source/source.go | 99 +++++++++++++++- internal/inventory/source/source_test.go | 114 ++++++++++++++++++ internal/inventory/types.go | 12 +- internal/store/memory.go | 119 ------------------- 21 files changed, 1287 insertions(+), 169 deletions(-) create mode 100644 internal/agentstate/sqlite.go create mode 100644 internal/agentstate/sqlite_test.go create mode 100644 internal/backendclient/errors_test.go create mode 100644 internal/inventory/hash.go create mode 100644 internal/inventory/hash_test.go create mode 100644 internal/inventory/manager_run_test.go create mode 100644 internal/inventory/manager_test.go create mode 100644 internal/inventory/mapper/backend_test.go create mode 100644 internal/inventory/sink/backend_test.go create mode 100644 internal/inventory/source/source_test.go delete mode 100644 internal/store/memory.go diff --git a/cmd/fleetint/enroll.go b/cmd/fleetint/enroll.go index 42d2b504..26878d9d 100644 --- a/cmd/fleetint/enroll.go +++ b/cmd/fleetint/enroll.go @@ -23,12 +23,18 @@ import ( "strings" pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" + nvidianvml "github.com/NVIDIA/fleet-intelligence-sdk/pkg/nvidia-query/nvml" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" "github.com/urfave/cli" + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" "github.com/NVIDIA/fleet-intelligence-agent/internal/enrollment" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" + inventorysink "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/sink" + inventorysource "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/source" + "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) var ( @@ -36,6 +42,7 @@ var ( return enrollment.PerformEnrollment(context.Background(), enrollEndpoint, sakToken) } storeEnrollmentConfig = storeConfigInMetadata + performInventorySync = syncInventoryOnce ) // resolveToken returns the SAK token from --token, --token-file, or stdin. @@ -138,14 +145,17 @@ func enrollCommand(cliContext *cli.Context) error { } // Store endpoints and JWT token in metadata table - if err := storeEnrollmentConfig(enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken); err != nil { + if err := storeEnrollmentConfig(baseURL.String(), enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken); err != nil { return fmt.Errorf("failed to store configuration: %w", err) } + if err := performInventorySync(context.Background()); err != nil { + fmt.Fprintf(writerFromContext(cliContext), "Post-enroll inventory sync failed: %v\n", err) + } return nil } -func storeConfigInMetadata(enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { +func storeConfigInMetadata(baseURL, enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { stateFile, err := config.DefaultStateFile() if err != nil { return fmt.Errorf("failed to get state file path: %w", err) @@ -168,6 +178,9 @@ func storeConfigInMetadata(enrollEndpoint, metricsEndpoint, logsEndpoint, nonceE if err := pkgmetadata.SetMetadata(context.Background(), dbRW, pkgmetadata.MetadataKeyToken, jwtToken); err != nil { return fmt.Errorf("failed to set JWT token: %w", err) } + if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "backend_base_url", baseURL); err != nil { + return fmt.Errorf("failed to set backend base URL: %w", err) + } if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "enroll_endpoint", enrollEndpoint); err != nil { return fmt.Errorf("failed to set enroll endpoint: %w", err) } @@ -186,3 +199,27 @@ func storeConfigInMetadata(enrollEndpoint, metricsEndpoint, logsEndpoint, nonceE return nil } + +type machineInfoCollectorFunc func(context.Context) (*machineinfo.MachineInfo, error) + +func (f machineInfoCollectorFunc) Collect(ctx context.Context) (*machineinfo.MachineInfo, error) { + return f(ctx) +} + +func syncInventoryOnce(ctx context.Context) error { + state := agentstate.NewSQLite() + sink := inventorysink.NewBackendSink(state) + + nvmlInstance, err := nvidianvml.New() + if err != nil { + return fmt.Errorf("initialize nvml for inventory sync: %w", err) + } + defer func() { _ = nvmlInstance.Shutdown() }() + + src := inventorysource.NewMachineInfoSource(machineInfoCollectorFunc(func(context.Context) (*machineinfo.MachineInfo, error) { + return machineinfo.GetMachineInfo(nvmlInstance) + })) + manager := inventory.NewManager(src, sink, 0) + _, err = manager.CollectOnce(ctx) + return err +} diff --git a/cmd/fleetint/enroll_test.go b/cmd/fleetint/enroll_test.go index faf63b8f..681f99c8 100644 --- a/cmd/fleetint/enroll_test.go +++ b/cmd/fleetint/enroll_test.go @@ -17,6 +17,7 @@ package main import ( "bytes" + "context" "fmt" "os" "path/filepath" @@ -52,10 +53,12 @@ func TestEnrollCommandBlocksOnFailedPrecheck(t *testing.T) { originalRunPrecheck := runPrecheck originalPerformEnrollment := performEnrollment originalStoreConfig := storeEnrollmentConfig + originalInventorySync := performInventorySync t.Cleanup(func() { runPrecheck = originalRunPrecheck performEnrollment = originalPerformEnrollment storeEnrollmentConfig = originalStoreConfig + performInventorySync = originalInventorySync }) enrollmentCalled := false @@ -70,9 +73,10 @@ func TestEnrollCommandBlocksOnFailedPrecheck(t *testing.T) { enrollmentCalled = true return "jwt-token", nil } - storeEnrollmentConfig = func(enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { + storeEnrollmentConfig = func(baseURL, enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { return nil } + performInventorySync = func(context.Context) error { return nil } out := &bytes.Buffer{} app := App() @@ -90,10 +94,12 @@ func TestEnrollCommandForceBypassesFailedPrecheck(t *testing.T) { originalRunPrecheck := runPrecheck originalPerformEnrollment := performEnrollment originalStoreConfig := storeEnrollmentConfig + originalInventorySync := performInventorySync t.Cleanup(func() { runPrecheck = originalRunPrecheck performEnrollment = originalPerformEnrollment storeEnrollmentConfig = originalStoreConfig + performInventorySync = originalInventorySync }) enrollmentCalled := false @@ -108,9 +114,10 @@ func TestEnrollCommandForceBypassesFailedPrecheck(t *testing.T) { enrollmentCalled = true return "jwt-token", nil } - storeEnrollmentConfig = func(enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { + storeEnrollmentConfig = func(baseURL, enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { return nil } + performInventorySync = func(context.Context) error { return nil } app := App() app.Writer = &bytes.Buffer{} @@ -130,6 +137,7 @@ func TestStoreConfigInMetadataSecuresFreshStateFile(t *testing.T) { t.Setenv("HOME", tmpHome) err := storeConfigInMetadata( + "https://example.com", "https://example.com/api/v1/health/enroll", "https://example.com/api/v1/health/metrics", "https://example.com/api/v1/health/logs", diff --git a/cmd/fleetint/unenroll.go b/cmd/fleetint/unenroll.go index 6bb8b6d2..381b31d1 100644 --- a/cmd/fleetint/unenroll.go +++ b/cmd/fleetint/unenroll.go @@ -67,6 +67,7 @@ func removeEnrollmentMetadata(ctx context.Context, dbRW *sql.DB) error { keysToDelete := []string{ pkgmetadata.MetadataKeyToken, "sak_token", + "backend_base_url", "enroll_endpoint", "metrics_endpoint", "logs_endpoint", @@ -74,7 +75,7 @@ func removeEnrollmentMetadata(ctx context.Context, dbRW *sql.DB) error { } // Build batch delete query - query := "DELETE FROM gpud_metadata WHERE key IN (?, ?, ?, ?, ?, ?)" + query := "DELETE FROM gpud_metadata WHERE key IN (?, ?, ?, ?, ?, ?, ?)" // Convert string slice to []interface{} for ExecContext args := make([]interface{}, len(keysToDelete)) diff --git a/internal/agentstate/sqlite.go b/internal/agentstate/sqlite.go new file mode 100644 index 00000000..f384d6ed --- /dev/null +++ b/internal/agentstate/sqlite.go @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package agentstate + +import ( + "context" + "database/sql" + "fmt" + + pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/config" +) + +const metadataKeyBackendBaseURL = "backend_base_url" + +type sqliteState struct { + stateFileFn func() (string, error) +} + +// NewSQLite returns a State backed by the agent sqlite metadata database. +func NewSQLite() State { + return &sqliteState{stateFileFn: config.DefaultStateFile} +} + +func (s *sqliteState) GetBackendBaseURL(ctx context.Context) (string, bool, error) { + return s.getMetadata(ctx, metadataKeyBackendBaseURL) +} + +func (s *sqliteState) SetBackendBaseURL(ctx context.Context, value string) error { + return s.setMetadata(ctx, metadataKeyBackendBaseURL, value) +} + +func (s *sqliteState) GetJWT(ctx context.Context) (string, bool, error) { + return s.getMetadata(ctx, pkgmetadata.MetadataKeyToken) +} + +func (s *sqliteState) SetJWT(ctx context.Context, value string) error { + return s.setMetadata(ctx, pkgmetadata.MetadataKeyToken, value) +} + +func (s *sqliteState) GetSAK(ctx context.Context) (string, bool, error) { + return s.getMetadata(ctx, "sak_token") +} + +func (s *sqliteState) SetSAK(ctx context.Context, value string) error { + return s.setMetadata(ctx, "sak_token", value) +} + +func (s *sqliteState) GetNodeID(ctx context.Context) (string, bool, error) { + return s.getMetadata(ctx, pkgmetadata.MetadataKeyMachineID) +} + +func (s *sqliteState) SetNodeID(ctx context.Context, value string) error { + return s.setMetadata(ctx, pkgmetadata.MetadataKeyMachineID, value) +} + +func (s *sqliteState) getMetadata(ctx context.Context, key string) (string, bool, error) { + db, err := s.openReadOnly() + if err != nil { + return "", false, err + } + defer db.Close() + + value, err := pkgmetadata.ReadMetadata(ctx, db, key) + if err != nil { + return "", false, fmt.Errorf("read metadata %q: %w", key, err) + } + if value == "" { + return "", false, nil + } + return value, true, nil +} + +func (s *sqliteState) setMetadata(ctx context.Context, key, value string) error { + db, err := s.openReadWrite() + if err != nil { + return err + } + defer db.Close() + + if err := pkgmetadata.CreateTableMetadata(ctx, db); err != nil { + return fmt.Errorf("create metadata table: %w", err) + } + if err := pkgmetadata.SetMetadata(ctx, db, key, value); err != nil { + return fmt.Errorf("set metadata %q: %w", key, err) + } + stateFile, err := s.stateFileFn() + if err == nil { + if err := config.SecureStateFilePermissions(stateFile); err != nil { + return fmt.Errorf("secure state file permissions: %w", err) + } + } + return nil +} + +func (s *sqliteState) openReadOnly() (*sql.DB, error) { + stateFile, err := s.stateFileFn() + if err != nil { + return nil, fmt.Errorf("get state file path: %w", err) + } + db, err := sqlite.Open(stateFile, sqlite.WithReadOnly(true)) + if err != nil { + return nil, fmt.Errorf("open state database read-only: %w", err) + } + return db, nil +} + +func (s *sqliteState) openReadWrite() (*sql.DB, error) { + stateFile, err := s.stateFileFn() + if err != nil { + return nil, fmt.Errorf("get state file path: %w", err) + } + db, err := sqlite.Open(stateFile) + if err != nil { + return nil, fmt.Errorf("open state database: %w", err) + } + return db, nil +} diff --git a/internal/agentstate/sqlite_test.go b/internal/agentstate/sqlite_test.go new file mode 100644 index 00000000..ee7abc97 --- /dev/null +++ b/internal/agentstate/sqlite_test.go @@ -0,0 +1,108 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package agentstate + +import ( + "context" + "errors" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func newTestSQLiteState(t *testing.T) *sqliteState { + t.Helper() + stateFile := filepath.Join(t.TempDir(), "agent.state") + return &sqliteState{ + stateFileFn: func() (string, error) { + return stateFile, nil + }, + } +} + +func TestSQLiteStateRoundTrip(t *testing.T) { + t.Parallel() + + ctx := context.Background() + state := newTestSQLiteState(t) + + err := state.SetBackendBaseURL(ctx, "https://backend.example.com") + require.NoError(t, err) + err = state.SetJWT(ctx, "jwt-token") + require.NoError(t, err) + err = state.SetSAK(ctx, "sak-token") + require.NoError(t, err) + err = state.SetNodeID(ctx, "node-1") + require.NoError(t, err) + + value, ok, err := state.GetBackendBaseURL(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "https://backend.example.com", value) + + value, ok, err = state.GetJWT(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "jwt-token", value) + + value, ok, err = state.GetSAK(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "sak-token", value) + + value, ok, err = state.GetNodeID(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "node-1", value) +} + +func TestSQLiteStateMissingValue(t *testing.T) { + t.Parallel() + + ctx := context.Background() + state := newTestSQLiteState(t) + + err := state.SetJWT(ctx, "jwt-token") + require.NoError(t, err) + + value, ok, err := state.GetBackendBaseURL(ctx) + require.NoError(t, err) + require.False(t, ok) + require.Empty(t, value) +} + +func TestSQLiteStateStateFileErrors(t *testing.T) { + t.Parallel() + + boom := errors.New("boom") + state := &sqliteState{ + stateFileFn: func() (string, error) { + return "", boom + }, + } + + _, _, err := state.GetJWT(context.Background()) + require.ErrorIs(t, err, boom) + + err = state.SetJWT(context.Background(), "jwt-token") + require.ErrorIs(t, err, boom) +} + +func TestNewSQLite(t *testing.T) { + t.Parallel() + require.NotNil(t, NewSQLite()) +} diff --git a/internal/attestationloop/manager.go b/internal/attestationloop/manager.go index c31a489b..68fb4e53 100644 --- a/internal/attestationloop/manager.go +++ b/internal/attestationloop/manager.go @@ -18,6 +18,7 @@ package attestationloop import ( "context" "fmt" + "sync" "time" "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" @@ -36,13 +37,15 @@ type Manager interface { } type manager struct { + mu sync.RWMutex nodeIDProvider func(context.Context) (string, error) jwtProvider JWTProvider nonceProvider NonceProvider collector EvidenceCollector submitter Submitter - store StateStore interval time.Duration + + lastResult *Result } // NewManager creates an attestation loop manager skeleton. @@ -52,7 +55,6 @@ func NewManager( nonceProvider NonceProvider, collector EvidenceCollector, submitter Submitter, - store StateStore, interval time.Duration, ) Manager { return &manager{ @@ -61,7 +63,6 @@ func NewManager( nonceProvider: nonceProvider, collector: collector, submitter: submitter, - store: store, interval: interval, } } @@ -111,11 +112,10 @@ func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { if sdkResp != nil { result.SDKResponse = *sdkResp } - if m.store != nil { - if err := m.store.PutAttestation(ctx, result); err != nil { - return nil, err - } - } + m.mu.Lock() + cloned := *result + m.lastResult = &cloned + m.mu.Unlock() if err := m.submitter.Submit(ctx, result, jwt); err != nil { return nil, err } diff --git a/internal/attestationloop/types.go b/internal/attestationloop/types.go index 7c74a4fb..eb327486 100644 --- a/internal/attestationloop/types.go +++ b/internal/attestationloop/types.go @@ -61,9 +61,3 @@ type EvidenceCollector interface { type Submitter interface { Submit(ctx context.Context, result *Result, jwt string) error } - -// StateStore is the attestation loop view of local transient store state. -type StateStore interface { - PutAttestation(ctx context.Context, result *Result) error - GetAttestation(ctx context.Context) (*Result, bool, error) -} diff --git a/internal/backendclient/client_test.go b/internal/backendclient/client_test.go index b9d4cb76..75735d63 100644 --- a/internal/backendclient/client_test.go +++ b/internal/backendclient/client_test.go @@ -18,9 +18,11 @@ package backendclient import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "net/url" + "strings" "testing" "github.com/stretchr/testify/require" @@ -148,6 +150,141 @@ func TestClient_EnrollMapsHTTPStatus(t *testing.T) { require.Contains(t, err.Error(), "incorrect") } +func TestClient_ValidationErrors(t *testing.T) { + t.Parallel() + + c := NewWithHTTPClient(mustParseURL(t, "https://backend.example.com"), nil) + + _, err := c.Enroll(context.Background(), "") + require.ErrorContains(t, err, "sakToken cannot be empty") + + err = c.UpsertNode(context.Background(), "", &NodeUpsertRequest{}, "jwt") + require.ErrorContains(t, err, "nodeID cannot be empty") + err = c.UpsertNode(context.Background(), "node-1", nil, "jwt") + require.ErrorContains(t, err, "cannot be nil") + err = c.UpsertNode(context.Background(), "node-1", &NodeUpsertRequest{}, "") + require.ErrorContains(t, err, "jwt cannot be empty") + + _, err = c.GetNonce(context.Background(), "", "jwt") + require.ErrorContains(t, err, "nodeID cannot be empty") + _, err = c.GetNonce(context.Background(), "node-1", "") + require.ErrorContains(t, err, "jwt cannot be empty") + + err = c.SubmitAttestation(context.Background(), "", &AttestationRequest{}, "jwt") + require.ErrorContains(t, err, "nodeID cannot be empty") + err = c.SubmitAttestation(context.Background(), "node-1", nil, "jwt") + require.ErrorContains(t, err, "cannot be nil") + err = c.SubmitAttestation(context.Background(), "node-1", &AttestationRequest{}, "") + require.ErrorContains(t, err, "jwt cannot be empty") + + _, err = c.RefreshToken(context.Background(), "") + require.ErrorContains(t, err, "jwt cannot be empty") +} + +func TestClient_ResponseValidationAndErrors(t *testing.T) { + t.Parallel() + + t.Run("missing jwt assertion in enroll", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.Enroll(context.Background(), "sak-token") + require.ErrorContains(t, err, "missing jwtAssertion") + }) + + t.Run("missing nonce field", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.GetNonce(context.Background(), "node-1", "jwt-token") + require.ErrorContains(t, err, "missing nonce") + }) + + t.Run("missing refresh jwt assertion", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.RefreshToken(context.Background(), "jwt-token") + require.ErrorContains(t, err, "missing jwtAssertion") + }) + + t.Run("invalid json response", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("{invalid")) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + err := c.UpsertNode(context.Background(), "node-1", &NodeUpsertRequest{Hostname: "node-1"}, "jwt-token") + require.NoError(t, err) + + _, err = c.GetNonce(context.Background(), "node-1", "jwt-token") + require.ErrorContains(t, err, "failed to parse backend response") + }) + + t.Run("http client error", func(t *testing.T) { + c := NewWithHTTPClient(mustParseURL(t, "https://backend.example.com"), &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("network boom") + }), + }) + err := c.UpsertNode(context.Background(), "node-1", &NodeUpsertRequest{Hostname: "node-1"}, "jwt-token") + require.ErrorContains(t, err, "failed to make backend request") + }) + + t.Run("oversized response body", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(strings.Repeat("a", maxResponseBodyBytes+10))) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.GetNonce(context.Background(), "node-1", "jwt-token") + require.ErrorContains(t, err, "response too large") + }) +} + +func TestMapEnrollErrorStatuses(t *testing.T) { + t.Parallel() + + cases := map[int]string{ + http.StatusBadRequest: "correct format", + http.StatusUnauthorized: "incorrect", + http.StatusForbidden: "incorrect/expired", + http.StatusNotFound: "not found", + http.StatusTooManyRequests: "retry after some time", + http.StatusBadGateway: "temporary issue", + http.StatusServiceUnavailable: "unavailable", + http.StatusGatewayTimeout: "slow to respond", + } + + for status, want := range cases { + got := mapEnrollError(&HTTPStatusError{StatusCode: status}) + require.ErrorContains(t, got, want) + } + + other := &HTTPStatusError{StatusCode: http.StatusTeapot} + require.Equal(t, other, mapEnrollError(other)) + + plain := errors.New("plain") + require.Equal(t, plain, mapEnrollError(plain)) +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + func mustParseURL(t *testing.T, raw string) *url.URL { t.Helper() parsed, err := url.Parse(raw) diff --git a/internal/backendclient/errors_test.go b/internal/backendclient/errors_test.go new file mode 100644 index 00000000..ebee5327 --- /dev/null +++ b/internal/backendclient/errors_test.go @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backendclient + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHTTPStatusErrorError(t *testing.T) { + t.Parallel() + + err := (&HTTPStatusError{StatusCode: 500, Body: "line1\nline2"}).Error() + require.Contains(t, err, "status 500") + require.Contains(t, err, "line1 line2") + + err = (&HTTPStatusError{StatusCode: 404}).Error() + require.Equal(t, "backend request failed with status 404", err) +} + +func TestSanitizeErrorBody(t *testing.T) { + t.Parallel() + + require.Empty(t, sanitizeErrorBody(" \n\t ")) + require.Equal(t, "hello world", sanitizeErrorBody(" hello \n world ")) + + body := strings.Repeat("a", 220) + got := sanitizeErrorBody(body) + require.True(t, strings.HasSuffix(got, "...(truncated)")) + require.LessOrEqual(t, len(got), 214) +} diff --git a/internal/inventory/hash.go b/internal/inventory/hash.go new file mode 100644 index 00000000..d0284811 --- /dev/null +++ b/internal/inventory/hash.go @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inventory + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "time" +) + +// ComputeHash returns a deterministic hash for the stable inventory contents. +func ComputeHash(snap *Snapshot) (string, error) { + if snap == nil { + return "", fmt.Errorf("inventory snapshot is nil") + } + normalized := *snap + normalized.CollectedAt = time.Time{} + normalized.InventoryHash = "" + + payload, err := json.Marshal(normalized) + if err != nil { + return "", fmt.Errorf("marshal inventory snapshot: %w", err) + } + sum := sha256.Sum256(payload) + return hex.EncodeToString(sum[:]), nil +} diff --git a/internal/inventory/hash_test.go b/internal/inventory/hash_test.go new file mode 100644 index 00000000..8a91d7f6 --- /dev/null +++ b/internal/inventory/hash_test.go @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inventory + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestComputeHashIgnoresCollectedAtAndExistingHash(t *testing.T) { + base := &Snapshot{ + CollectedAt: time.Unix(100, 0).UTC(), + NodeID: "node-1", + InventoryHash: "old-hash", + Hostname: "host-a", + MachineID: "machine-id", + Resources: Resources{ + CPUInfo: CPUInfo{Type: "Xeon", LogicalCores: 64}, + }, + } + other := *base + other.CollectedAt = time.Unix(200, 0).UTC() + other.InventoryHash = "different-old-hash" + + hash1, err := ComputeHash(base) + require.NoError(t, err) + hash2, err := ComputeHash(&other) + require.NoError(t, err) + require.Equal(t, hash1, hash2) + + other.Hostname = "host-b" + hash3, err := ComputeHash(&other) + require.NoError(t, err) + require.NotEqual(t, hash1, hash3) +} diff --git a/internal/inventory/manager.go b/internal/inventory/manager.go index 1310f002..1668a4dd 100644 --- a/internal/inventory/manager.go +++ b/internal/inventory/manager.go @@ -17,7 +17,9 @@ package inventory import ( "context" + "errors" "fmt" + "sync" "time" ) @@ -28,16 +30,20 @@ type Manager interface { } type manager struct { + mu sync.RWMutex source Source - store StateStore + sink Sink interval time.Duration + + lastSnapshot *Snapshot + lastExportedHash string } -// NewManager creates an inventory manager skeleton. -func NewManager(source Source, store StateStore, interval time.Duration) Manager { +// NewManager creates an inventory manager. +func NewManager(source Source, sink Sink, interval time.Duration) Manager { return &manager{ source: source, - store: store, + sink: sink, interval: interval, } } @@ -46,7 +52,24 @@ func (m *manager) Run(ctx context.Context) error { if _, err := m.CollectOnce(ctx); err != nil { return err } - return fmt.Errorf("inventory manager run loop not implemented") + + if m.interval <= 0 { + return nil + } + + ticker := time.NewTicker(m.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if _, err := m.CollectOnce(ctx); err != nil { + return err + } + } + } } func (m *manager) CollectOnce(ctx context.Context) (*Snapshot, error) { @@ -60,10 +83,29 @@ func (m *manager) CollectOnce(ctx context.Context) (*Snapshot, error) { if snap == nil { return nil, fmt.Errorf("inventory source returned nil snapshot") } - if m.store != nil { - if err := m.store.PutInventory(ctx, snap); err != nil { + hash, err := ComputeHash(snap) + if err != nil { + return nil, err + } + snap.InventoryHash = hash + + m.mu.Lock() + cloned := *snap + m.lastSnapshot = &cloned + shouldExport := m.sink != nil && m.lastExportedHash != hash + m.mu.Unlock() + + if shouldExport { + if err := m.sink.Export(ctx, snap); err != nil { + if errors.Is(err, ErrNotReady) { + return snap, nil + } return nil, err } + m.mu.Lock() + m.lastExportedHash = hash + m.mu.Unlock() } + return snap, nil } diff --git a/internal/inventory/manager_run_test.go b/internal/inventory/manager_run_test.go new file mode 100644 index 00000000..9bbce141 --- /dev/null +++ b/internal/inventory/manager_run_test.go @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inventory + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type errSource struct{ err error } + +func (s errSource) Collect(context.Context) (*Snapshot, error) { return nil, s.err } + +type nilSnapshotSource struct{} + +func (nilSnapshotSource) Collect(context.Context) (*Snapshot, error) { return nil, nil } + +func TestManagerCollectOnceErrors(t *testing.T) { + _, err := NewManager(nil, nil, 0).CollectOnce(context.Background()) + require.ErrorContains(t, err, "inventory source is required") + + _, err = NewManager(errSource{err: errors.New("boom")}, nil, 0).CollectOnce(context.Background()) + require.ErrorContains(t, err, "boom") + + _, err = NewManager(nilSnapshotSource{}, nil, 0).CollectOnce(context.Background()) + require.ErrorContains(t, err, "nil snapshot") +} + +func TestManagerRunWithZeroInterval(t *testing.T) { + src := &fakeSource{ + snapshots: []*Snapshot{{NodeID: "node-1", MachineID: "machine-1", Hostname: "host-a"}}, + } + sink := &fakeSink{} + + err := NewManager(src, sink, 0).Run(context.Background()) + require.NoError(t, err) + require.Len(t, sink.exported, 1) +} + +func TestManagerRunStopsOnContextCancel(t *testing.T) { + src := &fakeSource{ + snapshots: []*Snapshot{{NodeID: "node-1", MachineID: "machine-1", Hostname: "host-a"}}, + } + sink := &fakeSink{} + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan error, 1) + go func() { + done <- NewManager(src, sink, 10*time.Millisecond).Run(ctx) + }() + + time.Sleep(25 * time.Millisecond) + cancel() + + err := <-done + require.ErrorIs(t, err, context.Canceled) + require.NotEmpty(t, sink.exported) +} diff --git a/internal/inventory/manager_test.go b/internal/inventory/manager_test.go new file mode 100644 index 00000000..fcc44e75 --- /dev/null +++ b/internal/inventory/manager_test.go @@ -0,0 +1,103 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inventory + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type fakeSource struct { + snapshots []*Snapshot + index int +} + +func (f *fakeSource) Collect(context.Context) (*Snapshot, error) { + if len(f.snapshots) == 0 { + return nil, nil + } + if f.index >= len(f.snapshots) { + last := *f.snapshots[len(f.snapshots)-1] + return &last, nil + } + snap := *f.snapshots[f.index] + f.index++ + return &snap, nil +} + +type fakeSink struct { + exported []*Snapshot +} + +func (f *fakeSink) Export(_ context.Context, snap *Snapshot) error { + cloned := *snap + f.exported = append(f.exported, &cloned) + return nil +} + +func TestManagerCollectOnceExportsOnlyWhenInventoryChanges(t *testing.T) { + src := &fakeSource{ + snapshots: []*Snapshot{ + { + CollectedAt: time.Unix(100, 0).UTC(), + NodeID: "node-1", + Hostname: "host-a", + MachineID: "machine-id", + Resources: Resources{ + CPUInfo: CPUInfo{Type: "Xeon", LogicalCores: 64}, + }, + }, + { + CollectedAt: time.Unix(200, 0).UTC(), + NodeID: "node-1", + Hostname: "host-a", + MachineID: "machine-id", + Resources: Resources{ + CPUInfo: CPUInfo{Type: "Xeon", LogicalCores: 64}, + }, + }, + { + CollectedAt: time.Unix(300, 0).UTC(), + NodeID: "node-1", + Hostname: "host-b", + MachineID: "machine-id", + Resources: Resources{ + CPUInfo: CPUInfo{Type: "Xeon", LogicalCores: 64}, + }, + }, + }, + } + sink := &fakeSink{} + mgr := NewManager(src, sink, 0) + + snap1, err := mgr.CollectOnce(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, snap1.InventoryHash) + require.Len(t, sink.exported, 1) + + snap2, err := mgr.CollectOnce(context.Background()) + require.NoError(t, err) + require.Equal(t, snap1.InventoryHash, snap2.InventoryHash) + require.Len(t, sink.exported, 1) + + snap3, err := mgr.CollectOnce(context.Background()) + require.NoError(t, err) + require.NotEqual(t, snap1.InventoryHash, snap3.InventoryHash) + require.Len(t, sink.exported, 2) +} diff --git a/internal/inventory/mapper/backend_test.go b/internal/inventory/mapper/backend_test.go new file mode 100644 index 00000000..8dc5868d --- /dev/null +++ b/internal/inventory/mapper/backend_test.go @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapper + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" +) + +func TestToNodeUpsertRequestNil(t *testing.T) { + require.Nil(t, ToNodeUpsertRequest(nil)) +} + +func TestToNodeUpsertRequest(t *testing.T) { + req := ToNodeUpsertRequest(&inventory.Snapshot{ + NodeID: "node-1", + Hostname: "host-a", + MachineID: "machine-id", + SystemUUID: "uuid-1", + BootID: "boot-1", + OperatingSystem: "linux", + OSImage: "Ubuntu", + KernelVersion: "6.5.0", + FleetintVersion: "1.2.3", + GPUDriverVersion: "550.54.15", + CUDAVersion: "12.4", + DCGMVersion: "4.2.3", + ContainerRuntimeVersion: "containerd://1.7.13", + NetPrivateIP: "10.0.0.10", + NetPublicIP: "203.0.113.10", + InventoryHash: "hash-1", + Resources: inventory.Resources{ + CPUInfo: inventory.CPUInfo{ + Type: "Xeon", + Manufacturer: "Intel", + Architecture: "x86_64", + LogicalCores: 64, + }, + MemoryInfo: inventory.MemoryInfo{ + TotalBytes: 1024, + }, + GPUInfo: inventory.GPUInfo{ + Product: "H100", + Manufacturer: "NVIDIA", + Architecture: "Hopper", + Memory: "80GB", + GPUs: []inventory.GPUDevice{{ + UUID: "GPU-1", + BusID: "0000:01:00.0", + SN: "serial", + MinorID: "1", + BoardID: 7, + VBIOSVersion: "vbios", + ChassisSN: "chassis", + GPUIndex: "0", + }}, + }, + DiskInfo: inventory.DiskInfo{ + ContainerRootDisk: "/dev/nvme0n1", + BlockDevices: []inventory.BlockDevice{{ + Name: "nvme0n1", + Type: "disk", + Size: 2048, + WWN: "wwn", + MountPoint: "/", + FSType: "ext4", + PartUUID: "part-uuid", + Parents: []string{"parent0"}, + }}, + }, + NICInfo: inventory.NICInfo{ + PrivateIPInterfaces: []inventory.NICInterface{{ + Interface: "eth0", + MAC: "00:11:22:33:44:55", + IP: "10.0.0.10", + }}, + }, + }, + }) + + require.NotNil(t, req) + require.Equal(t, "host-a", req.Hostname) + require.Equal(t, "machine-id", req.MachineID) + require.Equal(t, "203.0.113.10", req.NetPublicIP) + require.Equal(t, "hash-1", req.InventoryHash) + require.Equal(t, int64(64), req.Resources.CPUInfo.LogicalCores) + require.Equal(t, uint64(1024), req.Resources.MemoryInfo.TotalBytes) + require.Len(t, req.Resources.GPUInfo.GPUs, 1) + require.Equal(t, 7, req.Resources.GPUInfo.GPUs[0].BoardID) + require.Len(t, req.Resources.DiskInfo.BlockDevices, 1) + require.Equal(t, "parent0", req.Resources.DiskInfo.BlockDevices[0].Parents[0]) + require.Len(t, req.Resources.NICInfo.PrivateIPInterfaces, 1) + require.Equal(t, "eth0", req.Resources.NICInfo.PrivateIPInterfaces[0].Interface) +} diff --git a/internal/inventory/sink/backend.go b/internal/inventory/sink/backend.go index f0ac82e9..73e4fd64 100644 --- a/internal/inventory/sink/backend.go +++ b/internal/inventory/sink/backend.go @@ -20,40 +20,52 @@ import ( "context" "fmt" + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/mapper" ) type backendSink struct { - client backendclient.Client - jwt func(context.Context) (string, error) + state agentstate.State + clientFactory func(rawBaseURL string) (backendclient.Client, error) } -// NewBackendSink creates the backend inventory sink skeleton. -func NewBackendSink(client backendclient.Client, jwt func(context.Context) (string, error)) inventory.Sink { +// NewBackendSink creates the backend inventory sink. +func NewBackendSink(state agentstate.State) inventory.Sink { return &backendSink{ - client: client, - jwt: jwt, + state: state, + clientFactory: backendclient.New, } } func (s *backendSink) Export(ctx context.Context, snap *inventory.Snapshot) error { - if s.jwt == nil { - return fmt.Errorf("inventory backend export requires jwt provider") + if s.state == nil { + return fmt.Errorf("inventory backend export requires agent state") } - if s.client == nil { - return fmt.Errorf("inventory backend export requires backend client") + if s.clientFactory == nil { + return fmt.Errorf("inventory backend export requires backend client factory") } if snap == nil { return fmt.Errorf("inventory backend export requires inventory snapshot") } - jwt, err := s.jwt(ctx) + baseURL, ok, err := s.state.GetBackendBaseURL(ctx) if err != nil { return err } - if jwt == "" { - return fmt.Errorf("inventory backend export received empty jwt") + if !ok || baseURL == "" { + return inventory.ErrNotReady } - return s.client.UpsertNode(ctx, snap.NodeID, mapper.ToNodeUpsertRequest(snap), jwt) + jwt, ok, err := s.state.GetJWT(ctx) + if err != nil { + return err + } + if !ok || jwt == "" { + return inventory.ErrNotReady + } + client, err := s.clientFactory(baseURL) + if err != nil { + return fmt.Errorf("create backend client: %w", err) + } + return client.UpsertNode(ctx, snap.NodeID, mapper.ToNodeUpsertRequest(snap), jwt) } diff --git a/internal/inventory/sink/backend_test.go b/internal/inventory/sink/backend_test.go new file mode 100644 index 00000000..a4ed0d0f --- /dev/null +++ b/internal/inventory/sink/backend_test.go @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sink + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" +) + +type fakeState struct { + baseURL string + jwt string + err error +} + +func (f fakeState) GetBackendBaseURL(context.Context) (string, bool, error) { + if f.err != nil { + return "", false, f.err + } + return f.baseURL, f.baseURL != "", nil +} +func (f fakeState) SetBackendBaseURL(context.Context, string) error { return nil } +func (f fakeState) GetJWT(context.Context) (string, bool, error) { + if f.err != nil { + return "", false, f.err + } + return f.jwt, f.jwt != "", nil +} +func (f fakeState) SetJWT(context.Context, string) error { return nil } +func (f fakeState) GetSAK(context.Context) (string, bool, error) { return "", false, nil } +func (f fakeState) SetSAK(context.Context, string) error { return nil } +func (f fakeState) GetNodeID(context.Context) (string, bool, error) { return "", false, nil } +func (f fakeState) SetNodeID(context.Context, string) error { return nil } + +type fakeClient struct { + nodeID string + req *backendclient.NodeUpsertRequest + jwt string +} + +func (f *fakeClient) Enroll(context.Context, string) (string, error) { return "", nil } +func (f *fakeClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { + return nil, nil +} +func (f *fakeClient) SubmitAttestation(context.Context, string, *backendclient.AttestationRequest, string) error { + return nil +} +func (f *fakeClient) RefreshToken(context.Context, string) (string, error) { return "", nil } +func (f *fakeClient) UpsertNode(_ context.Context, nodeID string, req *backendclient.NodeUpsertRequest, jwt string) error { + f.nodeID = nodeID + f.req = req + f.jwt = jwt + return nil +} + +func TestBackendSinkExportNotReady(t *testing.T) { + s := &backendSink{ + state: fakeState{}, + clientFactory: backendclient.New, + } + + err := s.Export(context.Background(), &inventory.Snapshot{NodeID: "node-1"}) + require.ErrorIs(t, err, inventory.ErrNotReady) +} + +func TestBackendSinkExportErrors(t *testing.T) { + err := (&backendSink{}).Export(context.Background(), &inventory.Snapshot{NodeID: "node-1"}) + require.ErrorContains(t, err, "agent state") + + err = (&backendSink{state: fakeState{baseURL: "https://example.com", jwt: "jwt"}}).Export(context.Background(), &inventory.Snapshot{NodeID: "node-1"}) + require.ErrorContains(t, err, "client factory") + + err = (&backendSink{ + state: fakeState{err: errors.New("state error")}, + clientFactory: backendclient.New, + }).Export(context.Background(), &inventory.Snapshot{NodeID: "node-1"}) + require.ErrorContains(t, err, "state error") + + err = (&backendSink{ + state: fakeState{baseURL: "https://example.com", jwt: "jwt"}, + clientFactory: backendclient.New, + }).Export(context.Background(), nil) + require.ErrorContains(t, err, "inventory snapshot") + + err = (&backendSink{ + state: fakeState{baseURL: "https://example.com", jwt: "jwt"}, + clientFactory: func(string) (backendclient.Client, error) { + return nil, errors.New("client factory error") + }, + }).Export(context.Background(), &inventory.Snapshot{NodeID: "node-1"}) + require.ErrorContains(t, err, "create backend client") +} + +func TestBackendSinkExportUsesState(t *testing.T) { + client := &fakeClient{} + s := &backendSink{ + state: fakeState{ + baseURL: "https://example.com", + jwt: "jwt-token", + }, + clientFactory: func(string) (backendclient.Client, error) { + return client, nil + }, + } + + err := s.Export(context.Background(), &inventory.Snapshot{ + NodeID: "node-1", + Hostname: "host-a", + MachineID: "machine-id", + }) + require.NoError(t, err) + require.Equal(t, "node-1", client.nodeID) + require.Equal(t, "jwt-token", client.jwt) + require.NotNil(t, client.req) + require.Equal(t, "host-a", client.req.Hostname) +} diff --git a/internal/inventory/source/source.go b/internal/inventory/source/source.go index 20c0fec4..8d84a99a 100644 --- a/internal/inventory/source/source.go +++ b/internal/inventory/source/source.go @@ -19,13 +19,15 @@ package source import ( "context" "fmt" + "time" "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" + "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) // MachineInfoCollector is the local machine inventory collector dependency. type MachineInfoCollector interface { - Collect(ctx context.Context) (*inventory.Snapshot, error) + Collect(ctx context.Context) (*machineinfo.MachineInfo, error) } type machineInfoSource struct { @@ -41,5 +43,98 @@ func (s *machineInfoSource) Collect(ctx context.Context) (*inventory.Snapshot, e if s.collector == nil { return nil, fmt.Errorf("machine info collector is required") } - return s.collector.Collect(ctx) + info, err := s.collector.Collect(ctx) + if err != nil { + return nil, err + } + if info == nil { + return nil, fmt.Errorf("machine info collector returned nil machine info") + } + + snap := &inventory.Snapshot{ + CollectedAt: time.Now().UTC(), + NodeID: info.MachineID, + Hostname: info.Hostname, + MachineID: info.MachineID, + SystemUUID: info.SystemUUID, + BootID: info.BootID, + OperatingSystem: info.OperatingSystem, + OSImage: info.OSImage, + KernelVersion: info.KernelVersion, + FleetintVersion: info.FleetintVersion, + GPUDriverVersion: info.GPUDriverVersion, + CUDAVersion: info.CUDAVersion, + DCGMVersion: info.DCGMVersion, + ContainerRuntimeVersion: info.ContainerRuntimeVersion, + } + + if info.CPUInfo != nil { + snap.Resources.CPUInfo = inventory.CPUInfo{ + Type: info.CPUInfo.Type, + Manufacturer: info.CPUInfo.Manufacturer, + Architecture: info.CPUInfo.Architecture, + LogicalCores: info.CPUInfo.LogicalCores, + } + } + if info.MemoryInfo != nil { + snap.Resources.MemoryInfo = inventory.MemoryInfo{ + TotalBytes: info.MemoryInfo.TotalBytes, + } + } + if info.GPUInfo != nil { + snap.Resources.GPUInfo = inventory.GPUInfo{ + Product: info.GPUInfo.Product, + Manufacturer: info.GPUInfo.Manufacturer, + Architecture: info.GPUInfo.Architecture, + Memory: info.GPUInfo.Memory, + } + if len(info.GPUInfo.GPUs) > 0 { + snap.Resources.GPUInfo.GPUs = make([]inventory.GPUDevice, 0, len(info.GPUInfo.GPUs)) + for _, gpu := range info.GPUInfo.GPUs { + snap.Resources.GPUInfo.GPUs = append(snap.Resources.GPUInfo.GPUs, inventory.GPUDevice{ + UUID: gpu.UUID, + BusID: gpu.BusID, + SN: gpu.SN, + MinorID: gpu.MinorID, + BoardID: int(gpu.BoardID), + VBIOSVersion: gpu.VBIOSVersion, + ChassisSN: gpu.ChassisSN, + GPUIndex: gpu.GPUIndex, + }) + } + } + } + if info.DiskInfo != nil { + snap.Resources.DiskInfo = inventory.DiskInfo{ + ContainerRootDisk: info.DiskInfo.ContainerRootDisk, + } + if len(info.DiskInfo.BlockDevices) > 0 { + snap.Resources.DiskInfo.BlockDevices = make([]inventory.BlockDevice, 0, len(info.DiskInfo.BlockDevices)) + for _, disk := range info.DiskInfo.BlockDevices { + snap.Resources.DiskInfo.BlockDevices = append(snap.Resources.DiskInfo.BlockDevices, inventory.BlockDevice{ + Name: disk.Name, + Type: disk.Type, + Size: disk.Size, + WWN: disk.WWN, + MountPoint: disk.MountPoint, + FSType: disk.FSType, + PartUUID: disk.PartUUID, + Parents: append([]string(nil), disk.Parents...), + }) + } + } + } + if info.NICInfo != nil && len(info.NICInfo.PrivateIPInterfaces) > 0 { + snap.Resources.NICInfo.PrivateIPInterfaces = make([]inventory.NICInterface, 0, len(info.NICInfo.PrivateIPInterfaces)) + for _, nic := range info.NICInfo.PrivateIPInterfaces { + snap.Resources.NICInfo.PrivateIPInterfaces = append(snap.Resources.NICInfo.PrivateIPInterfaces, inventory.NICInterface{ + Interface: nic.Interface, + MAC: nic.MAC, + IP: nic.IP, + }) + } + snap.NetPrivateIP = info.NICInfo.PrivateIPInterfaces[0].IP + } + + return snap, nil } diff --git a/internal/inventory/source/source_test.go b/internal/inventory/source/source_test.go new file mode 100644 index 00000000..836d40f7 --- /dev/null +++ b/internal/inventory/source/source_test.go @@ -0,0 +1,114 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package source + +import ( + "context" + "testing" + + apiv1 "github.com/NVIDIA/fleet-intelligence-sdk/api/v1" + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" +) + +type fakeMachineInfoCollector struct { + info *machineinfo.MachineInfo + err error +} + +func (f fakeMachineInfoCollector) Collect(context.Context) (*machineinfo.MachineInfo, error) { + return f.info, f.err +} + +func TestMachineInfoSourceCollect(t *testing.T) { + src := NewMachineInfoSource(fakeMachineInfoCollector{ + info: &machineinfo.MachineInfo{ + FleetintVersion: "1.2.3", + GPUDriverVersion: "550.54.15", + CUDAVersion: "12.4", + DCGMVersion: "4.2.3", + ContainerRuntimeVersion: "containerd://1.7.13", + KernelVersion: "6.5.0", + OSImage: "Ubuntu 22.04", + OperatingSystem: "linux", + SystemUUID: "system-uuid", + MachineID: "machine-id", + BootID: "boot-id", + Hostname: "host-a", + CPUInfo: &apiv1.MachineCPUInfo{ + Type: "Xeon", + Manufacturer: "Intel", + Architecture: "x86_64", + LogicalCores: 64, + }, + MemoryInfo: &apiv1.MachineMemoryInfo{ + TotalBytes: 1024, + }, + GPUInfo: &apiv1.MachineGPUInfo{ + Product: "H100", + Manufacturer: "NVIDIA", + Architecture: "Hopper", + Memory: "80GB", + GPUs: []apiv1.MachineGPUInstance{{ + UUID: "GPU-1", + GPUIndex: "0", + BusID: "0000:01:00.0", + SN: "serial", + MinorID: "1", + BoardID: 7, + VBIOSVersion: "vbios", + ChassisSN: "chassis", + }}, + }, + DiskInfo: &apiv1.MachineDiskInfo{ + ContainerRootDisk: "/dev/nvme0n1", + BlockDevices: []apiv1.MachineDiskDevice{{ + Name: "nvme0n1", + Type: "disk", + Size: 2048, + WWN: "wwn", + MountPoint: "/", + FSType: "ext4", + PartUUID: "part-uuid", + Parents: []string{"parent0"}, + }}, + }, + NICInfo: &apiv1.MachineNICInfo{ + PrivateIPInterfaces: []apiv1.MachineNetworkInterface{{ + Interface: "eth0", + MAC: "00:11:22:33:44:55", + IP: "10.0.0.10", + }}, + }, + }, + }) + + snap, err := src.Collect(context.Background()) + require.NoError(t, err) + require.NotNil(t, snap) + require.Equal(t, "machine-id", snap.NodeID) + require.Equal(t, "host-a", snap.Hostname) + require.Equal(t, "10.0.0.10", snap.NetPrivateIP) + require.Equal(t, "Xeon", snap.Resources.CPUInfo.Type) + require.Equal(t, uint64(1024), snap.Resources.MemoryInfo.TotalBytes) + require.Equal(t, "H100", snap.Resources.GPUInfo.Product) + require.Len(t, snap.Resources.GPUInfo.GPUs, 1) + require.Equal(t, 7, snap.Resources.GPUInfo.GPUs[0].BoardID) + require.Equal(t, "/dev/nvme0n1", snap.Resources.DiskInfo.ContainerRootDisk) + require.Len(t, snap.Resources.DiskInfo.BlockDevices, 1) + require.Equal(t, "eth0", snap.Resources.NICInfo.PrivateIPInterfaces[0].Interface) +} diff --git a/internal/inventory/types.go b/internal/inventory/types.go index 131cf3fe..998debee 100644 --- a/internal/inventory/types.go +++ b/internal/inventory/types.go @@ -18,9 +18,13 @@ package inventory import ( "context" + "errors" "time" ) +// ErrNotReady indicates inventory export cannot proceed because backend state is not ready yet. +var ErrNotReady = errors.New("inventory backend not ready") + // Snapshot is the agent-owned inventory state model. type Snapshot struct { CollectedAt time.Time @@ -116,11 +120,3 @@ type Source interface { type Sink interface { Export(ctx context.Context, snap *Snapshot) error } - -// StateStore is the inventory package view of local transient store state. -type StateStore interface { - PutInventory(ctx context.Context, snap *Snapshot) error - GetInventory(ctx context.Context) (*Snapshot, bool, error) - MarkInventoryExported(ctx context.Context, hash string, at time.Time) error - LastExportedInventoryHash(ctx context.Context) (string, error) -} diff --git a/internal/store/memory.go b/internal/store/memory.go deleted file mode 100644 index a93ddb00..00000000 --- a/internal/store/memory.go +++ /dev/null @@ -1,119 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package store contains transient in-agent state stores. -package store - -import ( - "context" - "sync" - "time" - - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestationloop" - "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" -) - -// MemoryStore is the initial in-memory implementation for inventory and attestation state. -type MemoryStore struct { - mu sync.RWMutex - - inventory *inventory.Snapshot - hasInventory bool - lastInventoryHash string - lastInventorySyncTS time.Time - - attestation *attestationloop.Result - hasAttestation bool - exportedAttestationKeys map[string]time.Time -} - -// NewMemoryStore creates an empty in-memory state store. -func NewMemoryStore() *MemoryStore { - return &MemoryStore{ - exportedAttestationKeys: make(map[string]time.Time), - } -} - -func (s *MemoryStore) PutInventory(_ context.Context, snap *inventory.Snapshot) error { - if snap == nil { - return nil - } - s.mu.Lock() - defer s.mu.Unlock() - cloned := *snap - s.inventory = &cloned - s.hasInventory = true - return nil -} - -func (s *MemoryStore) GetInventory(_ context.Context) (*inventory.Snapshot, bool, error) { - s.mu.RLock() - defer s.mu.RUnlock() - if !s.hasInventory || s.inventory == nil { - return nil, false, nil - } - cloned := *s.inventory - return &cloned, true, nil -} - -func (s *MemoryStore) MarkInventoryExported(_ context.Context, hash string, at time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - s.lastInventoryHash = hash - s.lastInventorySyncTS = at - return nil -} - -func (s *MemoryStore) LastExportedInventoryHash(_ context.Context) (string, error) { - s.mu.RLock() - defer s.mu.RUnlock() - return s.lastInventoryHash, nil -} - -func (s *MemoryStore) PutAttestation(_ context.Context, result *attestationloop.Result) error { - if result == nil { - return nil - } - s.mu.Lock() - defer s.mu.Unlock() - cloned := *result - s.attestation = &cloned - s.hasAttestation = true - return nil -} - -func (s *MemoryStore) GetAttestation(_ context.Context) (*attestationloop.Result, bool, error) { - s.mu.RLock() - defer s.mu.RUnlock() - if !s.hasAttestation || s.attestation == nil { - return nil, false, nil - } - cloned := *s.attestation - return &cloned, true, nil -} - -func (s *MemoryStore) MarkAttestationExported(_ context.Context, key string, at time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - s.exportedAttestationKeys[key] = at - return nil -} - -func (s *MemoryStore) WasAttestationExported(_ context.Context, key string) (bool, error) { - s.mu.RLock() - defer s.mu.RUnlock() - _, ok := s.exportedAttestationKeys[key] - return ok, nil -} From 4def98c78c02a6e40970a0dd7f2041576eaea823 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Thu, 16 Apr 2026 10:11:35 -0700 Subject: [PATCH 06/22] refactor: use backend base URL for agent workflows Signed-off-by: Jingxiang Zhang --- cmd/fleetint/enroll.go | 131 +------- cmd/fleetint/enroll_test.go | 63 +--- cmd/fleetint/status.go | 70 ++++- internal/agentstate/sqlite.go | 25 +- internal/agentstate/sqlite_test.go | 15 + internal/attestation/attestation.go | 33 +- internal/attestation/attestation_test.go | 64 +--- internal/endpoint/endpoint.go | 17 + internal/enrollment/enrollment.go | 170 +++++----- internal/enrollment/enrollment_test.go | 381 +++++++---------------- internal/exporter/exporter.go | 104 +++++-- internal/exporter/exporter_test.go | 114 +++++-- 12 files changed, 496 insertions(+), 691 deletions(-) diff --git a/cmd/fleetint/enroll.go b/cmd/fleetint/enroll.go index 26878d9d..a4740e46 100644 --- a/cmd/fleetint/enroll.go +++ b/cmd/fleetint/enroll.go @@ -22,28 +22,12 @@ import ( "os" "strings" - pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" - nvidianvml "github.com/NVIDIA/fleet-intelligence-sdk/pkg/nvidia-query/nvml" - "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" "github.com/urfave/cli" - "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" - "github.com/NVIDIA/fleet-intelligence-agent/internal/config" - "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" "github.com/NVIDIA/fleet-intelligence-agent/internal/enrollment" - "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" - inventorysink "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/sink" - inventorysource "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/source" - "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) -var ( - performEnrollment = func(enrollEndpoint, sakToken string) (string, error) { - return enrollment.PerformEnrollment(context.Background(), enrollEndpoint, sakToken) - } - storeEnrollmentConfig = storeConfigInMetadata - performInventorySync = syncInventoryOnce -) +var performEnrollWorkflow = enrollment.Enroll // resolveToken returns the SAK token from --token, --token-file, or stdin. func resolveToken(cliContext *cli.Context) (string, error) { @@ -55,7 +39,7 @@ func resolveToken(cliContext *cli.Context) (string, error) { } if tokenFile != "" { - const maxTokenSize = 1 << 20 // 1 MiB -- SAK tokens are small; anything larger is a mistake + const maxTokenSize = 1 << 20 var raw []byte var err error if tokenFile == "-" { @@ -112,114 +96,5 @@ func enrollCommand(cliContext *cli.Context) error { fmt.Fprintln(writerFromContext(cliContext), "Proceeding with enrollment because --force was provided") } - baseURL, err := endpoint.ValidateBackendEndpoint(baseEndpoint) - if err != nil { - return fmt.Errorf("invalid enrollment endpoint: %w", err) - } - - // Construct enroll endpoint - enrollEndpoint, err := endpoint.JoinPath(baseURL, "api", "v1", "health", "enroll") - if err != nil { - return fmt.Errorf("failed to construct enroll endpoint: %w", err) - } - - // Make enrollment request to get JWT token - jwtToken, err := performEnrollment(enrollEndpoint, sakToken) - if err != nil { - // Error already printed to stderr by PerformEnrollment - return err - } - - // Construct other endpoints using url.JoinPath for proper URL handling - metricsEndpoint, err := endpoint.JoinPath(baseURL, "api", "v1", "health", "metrics") - if err != nil { - return fmt.Errorf("failed to construct metrics endpoint: %w", err) - } - logsEndpoint, err := endpoint.JoinPath(baseURL, "api", "v1", "health", "logs") - if err != nil { - return fmt.Errorf("failed to construct logs endpoint: %w", err) - } - nonceEndpoint, err := endpoint.JoinPath(baseURL, "api", "v1", "health", "nonce") - if err != nil { - return fmt.Errorf("failed to construct nonce endpoint: %w", err) - } - - // Store endpoints and JWT token in metadata table - if err := storeEnrollmentConfig(baseURL.String(), enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken); err != nil { - return fmt.Errorf("failed to store configuration: %w", err) - } - if err := performInventorySync(context.Background()); err != nil { - fmt.Fprintf(writerFromContext(cliContext), "Post-enroll inventory sync failed: %v\n", err) - } - - return nil -} - -func storeConfigInMetadata(baseURL, enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { - stateFile, err := config.DefaultStateFile() - if err != nil { - return fmt.Errorf("failed to get state file path: %w", err) - } - - dbRW, err := sqlite.Open(stateFile) - if err != nil { - return fmt.Errorf("failed to open state database: %w", err) - } - defer dbRW.Close() - - if err := pkgmetadata.CreateTableMetadata(context.Background(), dbRW); err != nil { - return fmt.Errorf("failed to create metadata table: %w", err) - } - - // Store SAK token (for JWT refresh), JWT token (for API calls), and all endpoints - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "sak_token", sakToken); err != nil { - return fmt.Errorf("failed to set SAK token: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, pkgmetadata.MetadataKeyToken, jwtToken); err != nil { - return fmt.Errorf("failed to set JWT token: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "backend_base_url", baseURL); err != nil { - return fmt.Errorf("failed to set backend base URL: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "enroll_endpoint", enrollEndpoint); err != nil { - return fmt.Errorf("failed to set enroll endpoint: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "metrics_endpoint", metricsEndpoint); err != nil { - return fmt.Errorf("failed to set metrics endpoint: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "logs_endpoint", logsEndpoint); err != nil { - return fmt.Errorf("failed to set logs endpoint: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "nonce_endpoint", nonceEndpoint); err != nil { - return fmt.Errorf("failed to set nonce endpoint: %w", err) - } - if err := config.SecureStateFilePermissions(stateFile); err != nil { - return fmt.Errorf("failed to secure state database permissions: %w", err) - } - - return nil -} - -type machineInfoCollectorFunc func(context.Context) (*machineinfo.MachineInfo, error) - -func (f machineInfoCollectorFunc) Collect(ctx context.Context) (*machineinfo.MachineInfo, error) { - return f(ctx) -} - -func syncInventoryOnce(ctx context.Context) error { - state := agentstate.NewSQLite() - sink := inventorysink.NewBackendSink(state) - - nvmlInstance, err := nvidianvml.New() - if err != nil { - return fmt.Errorf("initialize nvml for inventory sync: %w", err) - } - defer func() { _ = nvmlInstance.Shutdown() }() - - src := inventorysource.NewMachineInfoSource(machineInfoCollectorFunc(func(context.Context) (*machineinfo.MachineInfo, error) { - return machineinfo.GetMachineInfo(nvmlInstance) - })) - manager := inventory.NewManager(src, sink, 0) - _, err = manager.CollectOnce(ctx) - return err + return performEnrollWorkflow(context.Background(), baseEndpoint, sakToken) } diff --git a/cmd/fleetint/enroll_test.go b/cmd/fleetint/enroll_test.go index 681f99c8..e191530b 100644 --- a/cmd/fleetint/enroll_test.go +++ b/cmd/fleetint/enroll_test.go @@ -19,8 +19,6 @@ import ( "bytes" "context" "fmt" - "os" - "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -51,14 +49,10 @@ func TestEnrollCommandPrecheckError(t *testing.T) { func TestEnrollCommandBlocksOnFailedPrecheck(t *testing.T) { originalRunPrecheck := runPrecheck - originalPerformEnrollment := performEnrollment - originalStoreConfig := storeEnrollmentConfig - originalInventorySync := performInventorySync + originalEnrollWorkflow := performEnrollWorkflow t.Cleanup(func() { runPrecheck = originalRunPrecheck - performEnrollment = originalPerformEnrollment - storeEnrollmentConfig = originalStoreConfig - performInventorySync = originalInventorySync + performEnrollWorkflow = originalEnrollWorkflow }) enrollmentCalled := false @@ -69,14 +63,10 @@ func TestEnrollCommandBlocksOnFailedPrecheck(t *testing.T) { }, }, nil } - performEnrollment = func(enrollEndpoint, sakToken string) (string, error) { + performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string) error { enrollmentCalled = true - return "jwt-token", nil - } - storeEnrollmentConfig = func(baseURL, enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { return nil } - performInventorySync = func(context.Context) error { return nil } out := &bytes.Buffer{} app := App() @@ -92,14 +82,10 @@ func TestEnrollCommandBlocksOnFailedPrecheck(t *testing.T) { func TestEnrollCommandForceBypassesFailedPrecheck(t *testing.T) { originalRunPrecheck := runPrecheck - originalPerformEnrollment := performEnrollment - originalStoreConfig := storeEnrollmentConfig - originalInventorySync := performInventorySync + originalEnrollWorkflow := performEnrollWorkflow t.Cleanup(func() { runPrecheck = originalRunPrecheck - performEnrollment = originalPerformEnrollment - storeEnrollmentConfig = originalStoreConfig - performInventorySync = originalInventorySync + performEnrollWorkflow = originalEnrollWorkflow }) enrollmentCalled := false @@ -110,14 +96,10 @@ func TestEnrollCommandForceBypassesFailedPrecheck(t *testing.T) { }, }, nil } - performEnrollment = func(enrollEndpoint, sakToken string) (string, error) { + performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string) error { enrollmentCalled = true - return "jwt-token", nil - } - storeEnrollmentConfig = func(baseURL, enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { return nil } - performInventorySync = func(context.Context) error { return nil } app := App() app.Writer = &bytes.Buffer{} @@ -127,36 +109,3 @@ func TestEnrollCommandForceBypassesFailedPrecheck(t *testing.T) { require.NoError(t, err) assert.True(t, enrollmentCalled) } - -func TestStoreConfigInMetadataSecuresFreshStateFile(t *testing.T) { - if os.Geteuid() == 0 { - t.Skip("test expects non-root default state path resolution") - } - - tmpHome := t.TempDir() - t.Setenv("HOME", tmpHome) - - err := storeConfigInMetadata( - "https://example.com", - "https://example.com/api/v1/health/enroll", - "https://example.com/api/v1/health/metrics", - "https://example.com/api/v1/health/logs", - "https://example.com/api/v1/health/nonce", - "jwt-token", - "sak-token", - ) - require.NoError(t, err) - - stateFile := filepath.Join(tmpHome, ".fleetint", "fleetint.state") - for _, candidate := range []string{stateFile, stateFile + "-wal", stateFile + "-shm"} { - info, err := os.Stat(candidate) - if os.IsNotExist(err) { - if candidate == stateFile { - require.NoError(t, err) - } - continue - } - require.NoError(t, err) - assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) - } -} diff --git a/cmd/fleetint/status.go b/cmd/fleetint/status.go index 6c61d4fc..627b0af0 100644 --- a/cmd/fleetint/status.go +++ b/cmd/fleetint/status.go @@ -36,6 +36,12 @@ import ( "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" ) +type enrollmentStatus struct { + baseURL string + metricsEndpoint string + logsEndpoint string +} + func statusCommand(cliContext *cli.Context) error { logLevel := cliContext.String("log-level") serverURL := cliContext.String("server-url") @@ -78,25 +84,21 @@ func statusCommand(cliContext *cli.Context) error { defer dbRO.Close() log.Logger.Debugw("successfully opened state file for reading") - metricsEndpoint, err := pkgmetadata.ReadMetadata(rootCtx, dbRO, "metrics_endpoint") - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("failed to read metrics endpoint: %w", err) - } - logsEndpoint, err := pkgmetadata.ReadMetadata(rootCtx, dbRO, "logs_endpoint") - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("failed to read logs endpoint: %w", err) + enrollment, err := readEnrollmentStatus(rootCtx, dbRO) + if err != nil { + return err } - if metricsEndpoint != "" || logsEndpoint != "" { + if enrollment.baseURL != "" || enrollment.metricsEndpoint != "" || enrollment.logsEndpoint != "" { fmt.Printf("%s enrolled\n", cmdutil.CheckMark) - if metricsEndpoint != "" { - fmt.Printf(" metrics endpoint: %s\n", metricsEndpoint) + if enrollment.metricsEndpoint != "" { + fmt.Printf(" metrics endpoint: %s\n", enrollment.metricsEndpoint) } - if logsEndpoint != "" { - fmt.Printf(" logs endpoint: %s\n", logsEndpoint) + if enrollment.logsEndpoint != "" { + fmt.Printf(" logs endpoint: %s\n", enrollment.logsEndpoint) } } else { - fmt.Printf("%s not enrolled (no endpoint configured)\n", cmdutil.WarningSign) + fmt.Printf("%s not enrolled (no backend or legacy endpoints configured)\n", cmdutil.WarningSign) } var active bool @@ -159,3 +161,45 @@ func statusCommand(cliContext *cli.Context) error { fmt.Printf("%s successfully checked fleetint health\n", cmdutil.CheckMark) return nil } + +func readEnrollmentStatus(ctx context.Context, dbRO *sql.DB) (*enrollmentStatus, error) { + baseURL, err := pkgmetadata.ReadMetadata(ctx, dbRO, "backend_base_url") + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("failed to read backend base URL: %w", err) + } + + status := &enrollmentStatus{baseURL: baseURL} + if baseURL != "" { + validated, err := endpoint.ValidateBackendEndpoint(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid backend base URL in metadata: %w", err) + } + status.metricsEndpoint, err = endpoint.JoinPath(validated, "api", "v1", "health", "metrics") + if err != nil { + return nil, fmt.Errorf("failed to construct metrics endpoint: %w", err) + } + status.logsEndpoint, err = endpoint.JoinPath(validated, "api", "v1", "health", "logs") + if err != nil { + return nil, fmt.Errorf("failed to construct logs endpoint: %w", err) + } + return status, nil + } + + status.metricsEndpoint, err = readLegacyEndpoint(ctx, dbRO, "metrics_endpoint") + if err != nil { + return nil, err + } + status.logsEndpoint, err = readLegacyEndpoint(ctx, dbRO, "logs_endpoint") + if err != nil { + return nil, err + } + return status, nil +} + +func readLegacyEndpoint(ctx context.Context, dbRO *sql.DB, key string) (string, error) { + value, err := pkgmetadata.ReadMetadata(ctx, dbRO, key) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return "", fmt.Errorf("failed to read %s: %w", key, err) + } + return value, nil +} diff --git a/internal/agentstate/sqlite.go b/internal/agentstate/sqlite.go index f384d6ed..fa25babb 100644 --- a/internal/agentstate/sqlite.go +++ b/internal/agentstate/sqlite.go @@ -24,6 +24,7 @@ import ( "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" + "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" ) const metadataKeyBackendBaseURL = "backend_base_url" @@ -38,7 +39,29 @@ func NewSQLite() State { } func (s *sqliteState) GetBackendBaseURL(ctx context.Context) (string, bool, error) { - return s.getMetadata(ctx, metadataKeyBackendBaseURL) + db, err := s.openReadOnly() + if err != nil { + return "", false, err + } + defer db.Close() + + if value, err := pkgmetadata.ReadMetadata(ctx, db, metadataKeyBackendBaseURL); err == nil && value != "" { + return value, true, nil + } + + for _, key := range []string{"enroll_endpoint", "metrics_endpoint", "logs_endpoint", "nonce_endpoint"} { + value, err := pkgmetadata.ReadMetadata(ctx, db, key) + if err != nil || value == "" { + continue + } + baseURL, err := endpoint.DeriveBackendBaseURL(value) + if err != nil { + return "", false, fmt.Errorf("derive backend base URL from metadata %q: %w", key, err) + } + return baseURL, true, nil + } + + return "", false, nil } func (s *sqliteState) SetBackendBaseURL(ctx context.Context, value string) error { diff --git a/internal/agentstate/sqlite_test.go b/internal/agentstate/sqlite_test.go index ee7abc97..442f37d6 100644 --- a/internal/agentstate/sqlite_test.go +++ b/internal/agentstate/sqlite_test.go @@ -85,6 +85,21 @@ func TestSQLiteStateMissingValue(t *testing.T) { require.Empty(t, value) } +func TestSQLiteStateGetBackendBaseURLFallsBackToLegacyEndpoints(t *testing.T) { + t.Parallel() + + ctx := context.Background() + state := newTestSQLiteState(t) + + err := state.setMetadata(ctx, "metrics_endpoint", "https://backend.example.com/api/v1/health/metrics") + require.NoError(t, err) + + value, ok, err := state.GetBackendBaseURL(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "https://backend.example.com", value) +} + func TestSQLiteStateStateFileErrors(t *testing.T) { t.Parallel() diff --git a/internal/attestation/attestation.go b/internal/attestation/attestation.go index e68984d0..e07441f0 100644 --- a/internal/attestation/attestation.go +++ b/internal/attestation/attestation.go @@ -370,16 +370,29 @@ func (m *Manager) getNonce(jwtToken string, machineId string) (string, time.Time } func (m *Manager) getValidatedNonceEndpoint(ctx context.Context) (string, error) { - nonceEndpoint := m.getNonceEndpointFromMetadata(ctx) - if nonceEndpoint == "" { - return "", fmt.Errorf("nonce endpoint not found in metadata") + baseEndpoint := m.getEndpointFromMetadata(ctx) + if baseEndpoint != "" { + validated, err := endpoint.ValidateBackendEndpoint(baseEndpoint) + if err != nil { + return "", fmt.Errorf("invalid backend endpoint: %w", err) + } + + joined, err := endpoint.JoinPath(validated, "api", "v1", "health", "nonce") + if err != nil { + return "", fmt.Errorf("failed to construct nonce endpoint: %w", err) + } + return joined, nil } - validated, err := endpoint.ValidateBackendEndpoint(nonceEndpoint) + legacyNonceEndpoint := m.getLegacyNonceEndpointFromMetadata(ctx) + if legacyNonceEndpoint == "" { + return "", fmt.Errorf("backend endpoint not found in metadata") + } + + validated, err := endpoint.ValidateBackendEndpoint(legacyNonceEndpoint) if err != nil { return "", fmt.Errorf("invalid nonce endpoint: %w", err) } - return validated.String(), nil } @@ -421,8 +434,8 @@ func (m *Manager) getEndpointFromMetadata(ctx context.Context) string { } defer dbRO.Close() - // Load endpoint from metadata - if endpoint, err := pkgmetadata.ReadMetadata(ctx, dbRO, pkgmetadata.MetadataKeyEndpoint); err == nil && endpoint != "" { + // Load backend base URL from metadata + if endpoint, err := pkgmetadata.ReadMetadata(ctx, dbRO, "backend_base_url"); err == nil && endpoint != "" { return endpoint } @@ -430,8 +443,7 @@ func (m *Manager) getEndpointFromMetadata(ctx context.Context) string { return "" } -// getNonceEndpointFromMetadata retrieves the nonce endpoint from the metadata database -func (m *Manager) getNonceEndpointFromMetadata(ctx context.Context) string { +func (m *Manager) getLegacyNonceEndpointFromMetadata(ctx context.Context) string { stateFile, err := defaultStateFileFn() if err != nil { log.Logger.Debugw("failed to get state file path", "error", err) @@ -445,12 +457,11 @@ func (m *Manager) getNonceEndpointFromMetadata(ctx context.Context) string { } defer dbRO.Close() - // Load nonce endpoint from metadata if endpoint, err := pkgmetadata.ReadMetadata(ctx, dbRO, "nonce_endpoint"); err == nil && endpoint != "" { return endpoint } - log.Logger.Debugw("Nonce endpoint not found in metadata") + log.Logger.Debugw("legacy nonce endpoint not found in metadata") return "" } diff --git a/internal/attestation/attestation_test.go b/internal/attestation/attestation_test.go index 3f11e87e..e43c80e7 100644 --- a/internal/attestation/attestation_test.go +++ b/internal/attestation/attestation_test.go @@ -24,7 +24,6 @@ import ( "net/http/httptest" "path/filepath" "strings" - "sync/atomic" "testing" "time" @@ -303,63 +302,10 @@ func TestManager_GetNonce_ServerError(t *testing.T) { assert.True(t, response.NonceRefreshTimestamp.IsZero()) } -func TestManager_GetNonce_RejectsNonOKStatus(t *testing.T) { - manager := newTestManager(t) - var redirectTargetCalled atomic.Bool - var server *httptest.Server - server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/redirected" { - redirectTargetCalled.Store(true) - w.WriteHeader(http.StatusOK) - return - } - http.Redirect(w, r, server.URL+"/redirected", http.StatusFound) - })) - defer server.Close() - - useDefaultTransport(t, server.Client().Transport) - stateFile := setupAttestationMetadataDB(t, map[string]string{ - "nonce_endpoint": server.URL, - }) - useTestStateFile(t, stateFile) - - nonce, refresh, err := manager.getNonce("test-jwt-token", "test-machine-id") - - require.Error(t, err) - assert.Contains(t, err.Error(), "nonce endpoint returned HTTP 302") - assert.Empty(t, nonce) - assert.True(t, refresh.IsZero()) - assert.False(t, redirectTargetCalled.Load()) -} - -func TestManager_GetNonce_RejectsServerErrorPayload(t *testing.T) { - manager := newTestManager(t) - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(map[string]string{ - "error": "Invalid token", - })) - })) - defer server.Close() - - useDefaultTransport(t, server.Client().Transport) - stateFile := setupAttestationMetadataDB(t, map[string]string{ - "nonce_endpoint": server.URL, - }) - useTestStateFile(t, stateFile) - - nonce, refresh, err := manager.getNonce("test-jwt-token", "test-machine-id") - - require.Error(t, err) - assert.Contains(t, err.Error(), "nonce endpoint returned error: Invalid token") - assert.Empty(t, nonce) - assert.True(t, refresh.IsZero()) -} - -func TestManager_GetValidatedNonceEndpoint_UsesStoredNonceEndpoint(t *testing.T) { +func TestManager_GetValidatedNonceEndpoint_DerivesFromStoredBackendBaseURL(t *testing.T) { manager := newTestManager(t) stateFile := setupAttestationMetadataDB(t, map[string]string{ - "nonce_endpoint": "https://backend.example.com/api/v1/health/nonce", + "backend_base_url": "https://backend.example.com", }) useTestStateFile(t, stateFile) @@ -368,16 +314,16 @@ func TestManager_GetValidatedNonceEndpoint_UsesStoredNonceEndpoint(t *testing.T) assert.Equal(t, "https://backend.example.com/api/v1/health/nonce", got) } -func TestManager_GetValidatedNonceEndpoint_RejectsTamperedStoredNonceEndpoint(t *testing.T) { +func TestManager_GetValidatedNonceEndpoint_RejectsInvalidStoredBackendBaseURL(t *testing.T) { manager := newTestManager(t) stateFile := setupAttestationMetadataDB(t, map[string]string{ - "nonce_endpoint": "http://evil.example.com/api/v1/health/nonce", + "backend_base_url": "http://evil.example.com", }) useTestStateFile(t, stateFile) _, err := manager.getValidatedNonceEndpoint(context.Background()) require.Error(t, err) - assert.Contains(t, err.Error(), "invalid nonce endpoint") + assert.Contains(t, err.Error(), "invalid backend endpoint") assert.Contains(t, err.Error(), "https") } diff --git a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go index 8d10e713..31b1d9b4 100644 --- a/internal/endpoint/endpoint.go +++ b/internal/endpoint/endpoint.go @@ -121,6 +121,23 @@ func ValidateBackendEndpoint(raw string) (*url.URL, error) { return parsed, nil } +// DeriveBackendBaseURL converts a legacy HTTPS endpoint URL into its backend base URL. +// For example, "https://backend.example.com/api/v1/health/metrics" becomes +// "https://backend.example.com". +func DeriveBackendBaseURL(raw string) (string, error) { + parsed, err := parseURL(raw) + if err != nil { + return "", err + } + if err := requireScheme(parsed, "https"); err != nil { + return "", err + } + return (&url.URL{ + Scheme: parsed.Scheme, + Host: parsed.Host, + }).String(), nil +} + // JoinPath appends path elements to a validated base URL. func JoinPath(base *url.URL, elems ...string) (string, error) { return url.JoinPath(base.String(), elems...) diff --git a/internal/enrollment/enrollment.go b/internal/enrollment/enrollment.go index 67c26175..d999b5df 100644 --- a/internal/enrollment/enrollment.go +++ b/internal/enrollment/enrollment.go @@ -13,123 +13,109 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package enrollment provides shared enrollment functionality for the Fleet Intelligence agent +// Package enrollment provides shared enrollment functionality for the Fleet Intelligence agent. package enrollment import ( "context" - "encoding/json" "fmt" - "io" - "net/http" - "os" - "time" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" + pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" + nvidianvml "github.com/NVIDIA/fleet-intelligence-sdk/pkg/nvidia-query/nvml" + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" + "github.com/NVIDIA/fleet-intelligence-agent/internal/config" + "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" + inventorysink "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/sink" + inventorysource "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/source" + "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) -const maxEnrollmentResponseSize = 1 << 20 - -// EnrollResponse represents the response from the enrollment endpoint -type EnrollResponse struct { - JWTToken string `json:"jwt_assertion"` -} +var ( + newBackendClient = backendclient.New + syncInventoryAfterEnroll = syncInventoryOnce +) -// PerformEnrollment performs the enrollment request to get a new JWT token -func PerformEnrollment(ctx context.Context, enrollEndpoint, sakToken string) (string, error) { - if enrollEndpoint == "" { - return "", fmt.Errorf("enrollEndpoint cannot be empty") - } - if sakToken == "" { - return "", fmt.Errorf("sakToken cannot be empty") +// Enroll runs the full enrollment workflow and performs a best-effort initial inventory sync. +func Enroll(ctx context.Context, baseEndpoint, sakToken string) error { + baseURL, err := endpoint.ValidateBackendEndpoint(baseEndpoint) + if err != nil { + return fmt.Errorf("invalid enrollment endpoint: %w", err) } - // Use the provided enrollment endpoint directly - enrollURL := enrollEndpoint - - // Create HTTP client with timeout. Disable redirects so a compromised - // backend cannot bounce us to an internal service (SSRF). - client := &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - }, + client, err := newBackendClient(baseURL.String()) + if err != nil { + return fmt.Errorf("failed to create backend client: %w", err) } - - // Create HTTP request with empty body - req, err := http.NewRequestWithContext(ctx, "POST", enrollURL, nil) + jwtToken, err := client.Enroll(ctx, sakToken) if err != nil { - return "", fmt.Errorf("failed to create HTTP request: %w", err) + return err } + if err := storeConfigInMetadata(ctx, baseURL.String(), jwtToken, sakToken); err != nil { + return fmt.Errorf("failed to store configuration: %w", err) + } + if err := syncInventoryAfterEnroll(ctx); err != nil { + log.Logger.Warnw("post-enroll inventory sync failed", "error", err) + } + return nil +} - // Set headers (no Content-Type since no body is sent) - req.Header.Set("User-Agent", "fleet-intelligence-agent") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", sakToken)) - - // Make the request - resp, err := client.Do(req) +func storeConfigInMetadata(ctx context.Context, baseURL, jwtToken, sakToken string) error { + stateFile, err := config.DefaultStateFile() if err != nil { - log.Logger.Errorw("Enrollment request failed", "error", err, "url", enrollURL) - return "", fmt.Errorf("failed to make enrollment request: %w", err) + return fmt.Errorf("failed to get state file path: %w", err) } - defer resp.Body.Close() - // Read at most max+1 bytes so oversized responses fail explicitly instead of - // surfacing later as a truncated JSON parse error. - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxEnrollmentResponseSize+1)) + dbRW, err := sqlite.Open(stateFile) if err != nil { - log.Logger.Errorw("Failed to read enrollment response body", "error", err) - return "", fmt.Errorf("failed to read enrollment response: %w", err) - } - if len(respBody) > maxEnrollmentResponseSize { - err = fmt.Errorf("enrollment response too large (max %d bytes)", maxEnrollmentResponseSize) - log.Logger.Errorw("Failed to read enrollment response body", "error", err) - return "", err + return fmt.Errorf("failed to open state database: %w", err) } + defer dbRW.Close() - // Check response status and return specific error messages - if resp.StatusCode != http.StatusOK { - var errMsg string - switch resp.StatusCode { - case http.StatusBadRequest: // 400 - errMsg = "The token used in the enrollment is not in the correct format. Please check the token. If all else fails, generate a new token by going to the UI" - case http.StatusUnauthorized: // 401 - errMsg = "The token used in the enrollment is incorrect. Please generate a new token by going to the UI or make sure you are using the correct token" - case http.StatusForbidden: // 403 - errMsg = "The token used in the enrollment is incorrect/expired. Please generate a new token by going to the UI or make sure you are using the correct token" - case http.StatusNotFound: // 404 - errMsg = "The endpoint is not found. Please use the correct endpoint" - case http.StatusTooManyRequests: // 429 - errMsg = "Please retry after some time. Server is under heavy load" - case http.StatusBadGateway: // 502 - errMsg = "Some temporary issue caused enrollment to fail. Please try again" - case http.StatusServiceUnavailable: // 503 - errMsg = "Service is unavailable currently. Please try again" - case http.StatusGatewayTimeout: // 504 - errMsg = "Service is experiencing load and is slow to respond. Please try again maybe after a few minutes" - default: - errMsg = fmt.Sprintf("enrollment request failed with status %d", resp.StatusCode) - } - - // Print error to stderr - fmt.Fprintf(os.Stderr, "Enrollment failed: %s\n", errMsg) - return "", fmt.Errorf("%s", errMsg) + if err := pkgmetadata.CreateTableMetadata(ctx, dbRW); err != nil { + return fmt.Errorf("failed to create metadata table: %w", err) } - // Parse response - var enrollResp EnrollResponse - if err := json.Unmarshal(respBody, &enrollResp); err != nil { - log.Logger.Errorw("Failed to parse enrollment response JSON", "error", err) - return "", fmt.Errorf("failed to parse enrollment response: %w", err) + state := agentstate.NewSQLite() + if err := state.SetSAK(ctx, sakToken); err != nil { + return fmt.Errorf("failed to set SAK token: %w", err) } - - // Validate JWT token is present - if enrollResp.JWTToken == "" { - log.Logger.Errorw("Enrollment response missing jwt-token field") - return "", fmt.Errorf("enrollment response missing jwt-token field") + if err := state.SetJWT(ctx, jwtToken); err != nil { + return fmt.Errorf("failed to set JWT token: %w", err) + } + if err := state.SetBackendBaseURL(ctx, baseURL); err != nil { + return fmt.Errorf("failed to set backend base URL: %w", err) } + if err := config.SecureStateFilePermissions(stateFile); err != nil { + return fmt.Errorf("failed to secure state database permissions: %w", err) + } + return nil +} + +type machineInfoCollectorFunc func(context.Context) (*machineinfo.MachineInfo, error) - // Print success to stdout - fmt.Fprintf(os.Stdout, "Enrollment succeeded\n") - return enrollResp.JWTToken, nil +func (f machineInfoCollectorFunc) Collect(ctx context.Context) (*machineinfo.MachineInfo, error) { + return f(ctx) +} + +func syncInventoryOnce(ctx context.Context) error { + state := agentstate.NewSQLite() + sink := inventorysink.NewBackendSink(state) + + nvmlInstance, err := nvidianvml.New() + if err != nil { + return fmt.Errorf("initialize nvml for inventory sync: %w", err) + } + defer func() { _ = nvmlInstance.Shutdown() }() + + src := inventorysource.NewMachineInfoSource(machineInfoCollectorFunc(func(context.Context) (*machineinfo.MachineInfo, error) { + return machineinfo.GetMachineInfo(nvmlInstance) + })) + manager := inventory.NewManager(src, sink, 0) + _, err = manager.CollectOnce(ctx) + return err } diff --git a/internal/enrollment/enrollment_test.go b/internal/enrollment/enrollment_test.go index 74a1cc35..338c8285 100644 --- a/internal/enrollment/enrollment_test.go +++ b/internal/enrollment/enrollment_test.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,318 +17,151 @@ package enrollment import ( "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" + "errors" + "os" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" -) - -func TestPerformEnrollment_Success(t *testing.T) { - expectedToken := "test-jwt-token-12345" - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify request method - assert.Equal(t, "POST", r.Method) - - // Verify headers - assert.Equal(t, "fleet-intelligence-agent", r.Header.Get("User-Agent")) - assert.Equal(t, "Bearer test-sak-token", r.Header.Get("Authorization")) - - // Send successful response - response := EnrollResponse{ - JWTToken: expectedToken, - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - require.NoError(t, err) - })) - defer server.Close() - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" + "github.com/NVIDIA/fleet-intelligence-agent/internal/config" +) - require.NoError(t, err) - assert.Equal(t, expectedToken, token) +type fakeBackendClient struct { + enrollSAK string + enrollJWT string + enrollErr error } -func TestPerformEnrollment_EmptyEndpoint(t *testing.T) { - ctx := context.Background() - token, err := PerformEnrollment(ctx, "", "test-sak-token") - - require.Error(t, err) - assert.Contains(t, err.Error(), "enrollEndpoint cannot be empty") - assert.Empty(t, token) +func (f *fakeBackendClient) Enroll(_ context.Context, sakToken string) (string, error) { + f.enrollSAK = sakToken + return f.enrollJWT, f.enrollErr } -func TestPerformEnrollment_EmptyToken(t *testing.T) { - ctx := context.Background() - token, err := PerformEnrollment(ctx, "http://example.com", "") - - require.Error(t, err) - assert.Contains(t, err.Error(), "sakToken cannot be empty") - assert.Empty(t, token) +func (f *fakeBackendClient) UpsertNode(context.Context, string, *backendclient.NodeUpsertRequest, string) error { + return nil } -func TestPerformEnrollment_HTTPStatusCodes(t *testing.T) { - tests := []struct { - name string - statusCode int - expectedErrMsg string - }{ - { - name: "BadRequest_400", - statusCode: http.StatusBadRequest, - expectedErrMsg: "The token used in the enrollment is not in the correct format", - }, - { - name: "Unauthorized_401", - statusCode: http.StatusUnauthorized, - expectedErrMsg: "The token used in the enrollment is incorrect", - }, - { - name: "Forbidden_403", - statusCode: http.StatusForbidden, - expectedErrMsg: "The token used in the enrollment is incorrect/expired", - }, - { - name: "NotFound_404", - statusCode: http.StatusNotFound, - expectedErrMsg: "The endpoint is not found", - }, - { - name: "TooManyRequests_429", - statusCode: http.StatusTooManyRequests, - expectedErrMsg: "Please retry after some time", - }, - { - name: "BadGateway_502", - statusCode: http.StatusBadGateway, - expectedErrMsg: "Some temporary issue caused enrollment to fail", - }, - { - name: "ServiceUnavailable_503", - statusCode: http.StatusServiceUnavailable, - expectedErrMsg: "Service is unavailable currently", - }, - { - name: "GatewayTimeout_504", - statusCode: http.StatusGatewayTimeout, - expectedErrMsg: "Service is experiencing load and is slow to respond", - }, - { - name: "InternalServerError_500", - statusCode: http.StatusInternalServerError, - expectedErrMsg: "enrollment request failed with status 500", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(tt.statusCode) - })) - defer server.Close() - - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") - - require.Error(t, err) - assert.Contains(t, err.Error(), tt.expectedErrMsg) - assert.Empty(t, token) - }) - } +func (f *fakeBackendClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { + return nil, nil } -func TestPerformEnrollment_DoesNotFollowRedirects(t *testing.T) { - redirectTargetCalled := false - redirectTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - redirectTargetCalled = true - w.WriteHeader(http.StatusOK) - })) - defer redirectTarget.Close() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, redirectTarget.URL, http.StatusFound) - })) - defer server.Close() - - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") - - require.Error(t, err) - assert.Contains(t, err.Error(), "status 302") - assert.Empty(t, token) - assert.False(t, redirectTargetCalled) +func (f *fakeBackendClient) SubmitAttestation(context.Context, string, *backendclient.AttestationRequest, string) error { + return nil } -func TestPerformEnrollment_MissingJWTToken(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Send response without JWT token - response := EnrollResponse{ - JWTToken: "", - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - require.NoError(t, err) - })) - defer server.Close() - - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") - - require.Error(t, err) - assert.Contains(t, err.Error(), "enrollment response missing jwt-token field") - assert.Empty(t, token) +func (f *fakeBackendClient) RefreshToken(context.Context, string) (string, error) { + return "", nil } -func TestPerformEnrollment_InvalidJSON(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte("invalid json")) - require.NoError(t, err) - })) - defer server.Close() - - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") - - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to parse enrollment response") - assert.Empty(t, token) -} +func TestEnrollWorkflow(t *testing.T) { + originalFactory := newBackendClient + originalSync := syncInventoryAfterEnroll + t.Cleanup(func() { + newBackendClient = originalFactory + syncInventoryAfterEnroll = originalSync + }) + + client := &fakeBackendClient{enrollJWT: "jwt-token"} + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + require.Equal(t, "https://example.com", rawBaseURL) + return client, nil + } -func TestPerformEnrollment_ResponseTooLarge(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, err := w.Write(make([]byte, maxEnrollmentResponseSize+1)) - require.NoError(t, err) - })) - defer server.Close() + syncCalled := false + syncInventoryAfterEnroll = func(ctx context.Context) error { + syncCalled = true + return nil + } - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) - require.Error(t, err) - assert.Contains(t, err.Error(), "enrollment response too large") - assert.Empty(t, token) + err := Enroll(context.Background(), "https://example.com", "sak-token") + require.NoError(t, err) + require.Equal(t, "sak-token", client.enrollSAK) + require.True(t, syncCalled) } -func TestPerformEnrollment_ContextCancellation(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Simulate slow response - time.Sleep(2 * time.Second) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") +func TestEnrollWorkflowErrors(t *testing.T) { + t.Run("invalid endpoint", func(t *testing.T) { + err := Enroll(context.Background(), "http://example.com", "sak-token") + require.Error(t, err) + }) + + t.Run("backend client creation", func(t *testing.T) { + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + return nil, errors.New("factory boom") + } - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to make enrollment request") - assert.Empty(t, token) -} + err := Enroll(context.Background(), "https://example.com", "sak-token") + require.ErrorContains(t, err, "failed to create backend client") + }) -func TestPerformEnrollment_ServerUnavailable(t *testing.T) { - // Use an invalid URL that will fail to connect - ctx := context.Background() - token, err := PerformEnrollment(ctx, "http://localhost:99999", "test-sak-token") + t.Run("enroll error", func(t *testing.T) { + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + return &fakeBackendClient{enrollErr: errors.New("enroll boom")}, nil + } - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to make enrollment request") - assert.Empty(t, token) + err := Enroll(context.Background(), "https://example.com", "sak-token") + require.ErrorContains(t, err, "enroll boom") + }) } -func TestPerformEnrollment_RequestBodyEmpty(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify no body is sent (Content-Length should be 0 or not set) - assert.Equal(t, int64(0), r.ContentLength) +func TestEnrollWorkflowInventorySyncFailureIsNonFatal(t *testing.T) { + originalFactory := newBackendClient + originalSync := syncInventoryAfterEnroll + t.Cleanup(func() { + newBackendClient = originalFactory + syncInventoryAfterEnroll = originalSync + }) - response := EnrollResponse{ - JWTToken: "test-token", - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - require.NoError(t, err) - })) - defer server.Close() + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + return &fakeBackendClient{enrollJWT: "jwt-token"}, nil + } + syncInventoryAfterEnroll = func(ctx context.Context) error { + return errors.New("inventory failed") + } - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + err := Enroll(context.Background(), "https://example.com", "sak-token") require.NoError(t, err) - assert.NotEmpty(t, token) } -func TestEnrollResponse_JSONSerialization(t *testing.T) { - tests := []struct { - name string - response EnrollResponse - }{ - { - name: "with_token", - response: EnrollResponse{ - JWTToken: "test-jwt-token", - }, - }, - { - name: "empty_token", - response: EnrollResponse{ - JWTToken: "", - }, - }, +func TestStoreConfigInMetadataSecuresFreshStateFile(t *testing.T) { + if os.Geteuid() == 0 { + t.Skip("test expects non-root default state path resolution") } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test marshaling - data, err := json.Marshal(tt.response) - assert.NoError(t, err) - - // Test unmarshaling - var unmarshaled EnrollResponse - err = json.Unmarshal(data, &unmarshaled) - assert.NoError(t, err) - assert.Equal(t, tt.response.JWTToken, unmarshaled.JWTToken) - }) - } -} + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) -func TestPerformEnrollment_MultipleSuccessiveRequests(t *testing.T) { - requestCount := 0 + err := storeConfigInMetadata( + context.Background(), + "https://example.com", + "jwt-token", + "sak-token", + ) + require.NoError(t, err) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount++ - response := EnrollResponse{ - JWTToken: fmt.Sprintf("token-%d", requestCount), + stateFile, err := config.DefaultStateFile() + require.NoError(t, err) + for _, candidate := range []string{stateFile, stateFile + "-wal", stateFile + "-shm"} { + info, err := os.Stat(candidate) + if os.IsNotExist(err) { + if candidate == stateFile { + require.NoError(t, err) + } + continue } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - require.NoError(t, err) - })) - defer server.Close() - - ctx := context.Background() - - // Make multiple requests - for i := 1; i <= 3; i++ { - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") require.NoError(t, err) - assert.Equal(t, fmt.Sprintf("token-%d", i), token) + assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) } - - assert.Equal(t, 3, requestCount) } diff --git a/internal/exporter/exporter.go b/internal/exporter/exporter.go index 5b7e9b97..185dde18 100644 --- a/internal/exporter/exporter.go +++ b/internal/exporter/exporter.go @@ -31,8 +31,8 @@ import ( pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" - "github.com/NVIDIA/fleet-intelligence-agent/internal/enrollment" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/collector" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/converter" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/writer" @@ -42,6 +42,8 @@ import ( // Ensure healthExporter implements the Exporter interface var _ Exporter = (*healthExporter)(nil) +var newBackendClient = backendclient.New + // healthExporter implements the Exporter interface with improved architecture type healthExporter struct { ctx context.Context @@ -267,17 +269,35 @@ func (e *healthExporter) refreshConfigFromMetadata(ctx context.Context) { return } - // Load metrics endpoint (update even if empty to handle un-enrollment) - if metricsEndpoint, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "metrics_endpoint"); err == nil { - if metricsEndpoint != "" { - validated, validateErr := endpoint.ValidateBackendEndpoint(metricsEndpoint) - if validateErr != nil { - log.Logger.Errorw("ignoring invalid metrics endpoint from metadata", "error", validateErr) - metricsEndpoint = "" + metricsEndpoint := "" + logsEndpoint := "" + baseURL, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "backend_base_url") + if err != nil { + log.Logger.Errorw("failed to read backend base URL from metadata", "error", err) + baseURL = "" + } + if baseURL != "" { + validated, validateErr := endpoint.ValidateBackendEndpoint(baseURL) + if validateErr != nil { + log.Logger.Errorw("ignoring invalid backend base URL from metadata", "error", validateErr) + } else { + if joined, joinErr := endpoint.JoinPath(validated, "api", "v1", "health", "metrics"); joinErr == nil { + metricsEndpoint = joined } else { - metricsEndpoint = validated.String() + log.Logger.Errorw("failed to derive metrics endpoint from backend base URL", "error", joinErr) + } + if joined, joinErr := endpoint.JoinPath(validated, "api", "v1", "health", "logs"); joinErr == nil { + logsEndpoint = joined + } else { + log.Logger.Errorw("failed to derive logs endpoint from backend base URL", "error", joinErr) } } + } else { + metricsEndpoint = e.readValidatedEndpoint(ctx, "metrics_endpoint") + logsEndpoint = e.readValidatedEndpoint(ctx, "logs_endpoint") + } + + { if e.options.config.MetricsEndpoint != metricsEndpoint { e.options.config.MetricsEndpoint = metricsEndpoint if metricsEndpoint == "" { @@ -286,21 +306,9 @@ func (e *healthExporter) refreshConfigFromMetadata(ctx context.Context) { log.Logger.Infow("updated metrics endpoint from metadata", "metrics_endpoint", metricsEndpoint) } } - } else { - log.Logger.Errorw("failed to read metrics endpoint from metadata", "error", err) } - // Load logs endpoint (update even if empty to handle un-enrollment) - if logsEndpoint, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "logs_endpoint"); err == nil { - if logsEndpoint != "" { - validated, validateErr := endpoint.ValidateBackendEndpoint(logsEndpoint) - if validateErr != nil { - log.Logger.Errorw("ignoring invalid logs endpoint from metadata", "error", validateErr) - logsEndpoint = "" - } else { - logsEndpoint = validated.String() - } - } + { if e.options.config.LogsEndpoint != logsEndpoint { e.options.config.LogsEndpoint = logsEndpoint if logsEndpoint == "" { @@ -309,8 +317,6 @@ func (e *healthExporter) refreshConfigFromMetadata(ctx context.Context) { log.Logger.Infow("updated logs endpoint from metadata", "logs_endpoint", logsEndpoint) } } - } else { - log.Logger.Errorw("failed to read logs endpoint from metadata", "error", err) } // Load auth token (update even if empty to handle un-enrollment) @@ -355,14 +361,24 @@ func (e *healthExporter) refreshJWTToken(ctx context.Context) (string, error) { return "", fmt.Errorf("no SAK token available for JWT refresh") } - // Read enroll endpoint from metadata (stored during enrollment) - enrollEndpoint, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "enroll_endpoint") - if err != nil || enrollEndpoint == "" { - return "", fmt.Errorf("no enroll endpoint available for JWT refresh") + baseURL, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "backend_base_url") + if err == nil && baseURL != "" { + // use configured base URL + } else { + baseURL, err = e.readLegacyBackendBaseURL(ctx) + if err != nil { + return "", err + } + if baseURL == "" { + return "", fmt.Errorf("no backend base URL available for JWT refresh") + } } - // Perform enrollment to get new JWT token using the shared function - newJWT, err := enrollment.PerformEnrollment(ctx, enrollEndpoint, sakToken) + client, err := newBackendClient(baseURL) + if err != nil { + return "", fmt.Errorf("failed to create backend client for JWT refresh: %w", err) + } + newJWT, err := client.Enroll(ctx, sakToken) if err != nil { return "", fmt.Errorf("failed to refresh JWT token: %w", err) } @@ -378,3 +394,31 @@ func (e *healthExporter) refreshJWTToken(ctx context.Context) (string, error) { log.Logger.Infow("Successfully refreshed JWT token") return newJWT, nil } + +func (e *healthExporter) readValidatedEndpoint(ctx context.Context, key string) string { + value, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, key) + if err != nil || value == "" { + return "" + } + validated, err := endpoint.ValidateBackendEndpoint(value) + if err != nil { + log.Logger.Errorw("ignoring invalid legacy endpoint from metadata", "key", key, "error", err) + return "" + } + return validated.String() +} + +func (e *healthExporter) readLegacyBackendBaseURL(ctx context.Context) (string, error) { + for _, key := range []string{"enroll_endpoint", "metrics_endpoint", "logs_endpoint", "nonce_endpoint"} { + value, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, key) + if err != nil || value == "" { + continue + } + baseURL, err := endpoint.DeriveBackendBaseURL(value) + if err != nil { + return "", fmt.Errorf("invalid legacy %s for JWT refresh: %w", key, err) + } + return baseURL, nil + } + return "", nil +} diff --git a/internal/exporter/exporter_test.go b/internal/exporter/exporter_test.go index a30f3bb2..9ea2362c 100644 --- a/internal/exporter/exporter_test.go +++ b/internal/exporter/exporter_test.go @@ -38,6 +38,7 @@ import ( pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" pkgmetrics "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metrics" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/collector" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/writer" @@ -706,10 +707,7 @@ func TestRefreshConfigFromMetadata(t *testing.T) { ctx := context.Background() - // Insert test data into metadata table using SetMetadata - err := pkgmetadata.SetMetadata(ctx, tmpDB, "metrics_endpoint", "https://new-metrics.example.com") - require.NoError(t, err) - err = pkgmetadata.SetMetadata(ctx, tmpDB, "logs_endpoint", "https://new-logs.example.com") + err := pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "https://backend.example.com") require.NoError(t, err) err = pkgmetadata.SetMetadata(ctx, tmpDB, pkgmetadata.MetadataKeyToken, "new-test-token") require.NoError(t, err) @@ -736,8 +734,8 @@ func TestRefreshConfigFromMetadata(t *testing.T) { he.refreshConfigFromMetadata(ctx) // Verify config was updated - assert.Equal(t, "https://new-metrics.example.com", he.options.config.MetricsEndpoint) - assert.Equal(t, "https://new-logs.example.com", he.options.config.LogsEndpoint) + assert.Equal(t, "https://backend.example.com/api/v1/health/metrics", he.options.config.MetricsEndpoint) + assert.Equal(t, "https://backend.example.com/api/v1/health/logs", he.options.config.LogsEndpoint) assert.Equal(t, "new-test-token", he.options.config.AuthToken) // Cleanup @@ -751,10 +749,7 @@ func TestRefreshConfigFromMetadata(t *testing.T) { ctx := context.Background() - // Insert empty values using SetMetadata - err := pkgmetadata.SetMetadata(ctx, tmpDB, "metrics_endpoint", "") - require.NoError(t, err) - err = pkgmetadata.SetMetadata(ctx, tmpDB, "logs_endpoint", "") + err := pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "") require.NoError(t, err) cfg := &config.HealthExporterConfig{ @@ -792,9 +787,7 @@ func TestRefreshConfigFromMetadata(t *testing.T) { ctx := context.Background() - err := pkgmetadata.SetMetadata(ctx, tmpDB, "metrics_endpoint", "http://bad-metrics.example.com") - require.NoError(t, err) - err = pkgmetadata.SetMetadata(ctx, tmpDB, "logs_endpoint", "https://user@bad-logs.example.com") + err := pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "http://bad-backend.example.com") require.NoError(t, err) cfg := &config.HealthExporterConfig{ @@ -932,19 +925,16 @@ func TestRefreshJWTToken(t *testing.T) { }) t.Run("refreshes token successfully", func(t *testing.T) { - // Mock enrollment server expectedToken := "new-jwt-token" - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify request - assert.Equal(t, http.MethodPost, r.Method) - assert.Equal(t, "Bearer test-sak-token", r.Header.Get("Authorization")) - - // Return new token - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"jwt_assertion": "%s"}`, expectedToken) - })) - defer server.Close() + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + assert.Equal(t, "https://backend.example.com", rawBaseURL) + return &fakeJWTRefreshClient{ + expectedSAK: "test-sak-token", + token: expectedToken, + }, nil + } tmpDB := setupTestDB(t) defer tmpDB.Close() @@ -953,7 +943,7 @@ func TestRefreshJWTToken(t *testing.T) { // Setup metadata err := pkgmetadata.SetMetadata(ctx, tmpDB, "sak_token", "test-sak-token") require.NoError(t, err) - err = pkgmetadata.SetMetadata(ctx, tmpDB, "enroll_endpoint", server.URL) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "https://backend.example.com") require.NoError(t, err) cfg := &config.HealthExporterConfig{ @@ -984,6 +974,78 @@ func TestRefreshJWTToken(t *testing.T) { err = exporter.Stop() require.NoError(t, err) }) + + t.Run("refreshes token successfully using legacy enroll endpoint", func(t *testing.T) { + expectedToken := "new-jwt-token" + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + assert.Equal(t, "https://backend.example.com", rawBaseURL) + return &fakeJWTRefreshClient{ + expectedSAK: "test-sak-token", + token: expectedToken, + }, nil + } + + tmpDB := setupTestDB(t) + defer tmpDB.Close() + + ctx := context.Background() + err := pkgmetadata.SetMetadata(ctx, tmpDB, "sak_token", "test-sak-token") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "enroll_endpoint", "https://backend.example.com/api/v1/health/enroll") + require.NoError(t, err) + + cfg := &config.HealthExporterConfig{ + Interval: metav1.Duration{Duration: 1 * time.Minute}, + Timeout: metav1.Duration{Duration: 30 * time.Second}, + } + + exporter, err := New(ctx, + WithConfig(cfg), + WithDatabaseConnections(tmpDB, tmpDB), + WithMachineID("test-machine-id"), + ) + require.NoError(t, err) + require.NotNil(t, exporter) + + he := exporter.(*healthExporter) + + token, err := he.refreshJWTToken(ctx) + require.NoError(t, err) + assert.Equal(t, expectedToken, token) + + err = exporter.Stop() + require.NoError(t, err) + }) +} + +type fakeJWTRefreshClient struct { + expectedSAK string + token string +} + +func (f *fakeJWTRefreshClient) Enroll(_ context.Context, sakToken string) (string, error) { + if f.expectedSAK != "" && sakToken != f.expectedSAK { + return "", fmt.Errorf("unexpected sak token %q", sakToken) + } + return f.token, nil +} + +func (f *fakeJWTRefreshClient) UpsertNode(context.Context, string, *backendclient.NodeUpsertRequest, string) error { + return nil +} + +func (f *fakeJWTRefreshClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { + return nil, nil +} + +func (f *fakeJWTRefreshClient) SubmitAttestation(context.Context, string, *backendclient.AttestationRequest, string) error { + return nil +} + +func (f *fakeJWTRefreshClient) RefreshToken(context.Context, string) (string, error) { + return "", nil } // TestIntegration provides integration tests From be13d03b4e12a76ef164d67a8969eff2d8c184fd Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Thu, 16 Apr 2026 10:59:01 -0700 Subject: [PATCH 07/22] feat: implement backend attestation workflow Signed-off-by: Jingxiang Zhang --- internal/attestationloop/backend.go | 63 +++- internal/attestationloop/backend_test.go | 202 +++++++++++ internal/attestationloop/collector.go | 101 ++++++ internal/attestationloop/collector_test.go | 129 +++++++ internal/attestationloop/manager.go | 95 +++++- internal/attestationloop/manager_test.go | 176 ++++++++++ internal/attestationloop/nonce.go | 59 ++++ internal/attestationloop/nonce_test.go | 51 +++ internal/attestationloop/types.go | 34 ++ internal/exporter/collector/collector.go | 66 +--- internal/exporter/collector/collector_test.go | 319 +----------------- internal/exporter/converter/otlp.go | 20 -- internal/exporter/converter/otlp_test.go | 112 ------ internal/exporter/exporter.go | 41 +-- 14 files changed, 939 insertions(+), 529 deletions(-) create mode 100644 internal/attestationloop/backend_test.go create mode 100644 internal/attestationloop/collector.go create mode 100644 internal/attestationloop/collector_test.go create mode 100644 internal/attestationloop/manager_test.go create mode 100644 internal/attestationloop/nonce.go create mode 100644 internal/attestationloop/nonce_test.go diff --git a/internal/attestationloop/backend.go b/internal/attestationloop/backend.go index bb8d7330..c5f63ec4 100644 --- a/internal/attestationloop/backend.go +++ b/internal/attestationloop/backend.go @@ -15,7 +15,16 @@ package attestationloop -import "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +import ( + "context" + "fmt" + "time" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +var newBackendClient = backendclient.New func toAttestationRequest(r *Result) *backendclient.AttestationRequest { if r == nil { @@ -50,3 +59,55 @@ func toAttestationRequest(r *Result) *backendclient.AttestationRequest { return req } + +type stateBackendClientFactory struct { + state agentstate.State +} + +func (f *stateBackendClientFactory) client(ctx context.Context) (backendclient.Client, error) { + if f.state == nil { + return nil, fmt.Errorf("backend client factory requires agent state") + } + baseURL, ok, err := f.state.GetBackendBaseURL(ctx) + if err != nil { + return nil, err + } + if !ok || baseURL == "" { + return nil, fmt.Errorf("backend base URL not available in agent state") + } + return newBackendClient(baseURL) +} + +type stateNonceProvider struct { + factory *stateBackendClientFactory +} + +// NewStateNonceProvider creates a nonce provider that resolves backend state dynamically. +func NewStateNonceProvider(state agentstate.State) NonceProvider { + return &stateNonceProvider{factory: &stateBackendClientFactory{state: state}} +} + +func (p *stateNonceProvider) GetNonce(ctx context.Context, nodeID, jwt string) (string, time.Time, string, error) { + client, err := p.factory.client(ctx) + if err != nil { + return "", time.Time{}, "", err + } + return NewBackendNonceProvider(client).GetNonce(ctx, nodeID, jwt) +} + +type stateSubmitter struct { + factory *stateBackendClientFactory +} + +// NewStateBackendSubmitter creates a submitter that resolves backend state dynamically. +func NewStateBackendSubmitter(state agentstate.State) Submitter { + return &stateSubmitter{factory: &stateBackendClientFactory{state: state}} +} + +func (s *stateSubmitter) Submit(ctx context.Context, result *Result, jwt string) error { + client, err := s.factory.client(ctx) + if err != nil { + return err + } + return NewBackendSubmitter(client).Submit(ctx, result, jwt) +} diff --git a/internal/attestationloop/backend_test.go b/internal/attestationloop/backend_test.go new file mode 100644 index 00000000..290e436d --- /dev/null +++ b/internal/attestationloop/backend_test.go @@ -0,0 +1,202 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attestationloop + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +type stubState struct { + baseURL string + baseOK bool + baseErr error + jwt string + jwtOK bool + jwtErr error + setJWT string + nodeID string + nodeOK bool + nodeErr error +} + +func (s *stubState) GetBackendBaseURL(context.Context) (string, bool, error) { + return s.baseURL, s.baseOK, s.baseErr +} +func (s *stubState) SetBackendBaseURL(context.Context, string) error { return nil } +func (s *stubState) GetJWT(context.Context) (string, bool, error) { return s.jwt, s.jwtOK, s.jwtErr } +func (s *stubState) SetJWT(_ context.Context, v string) error { s.setJWT = v; s.jwt = v; return nil } +func (s *stubState) GetSAK(context.Context) (string, bool, error) { return "", false, nil } +func (s *stubState) SetSAK(context.Context, string) error { return nil } +func (s *stubState) GetNodeID(context.Context) (string, bool, error) { + return s.nodeID, s.nodeOK, s.nodeErr +} +func (s *stubState) SetNodeID(context.Context, string) error { return nil } + +type recordingClient struct { + lastNodeID string + lastJWT string + lastReq *backendclient.AttestationRequest + nonceResp *backendclient.NonceResponse +} + +func (c *recordingClient) Enroll(context.Context, string) (string, error) { return "", nil } +func (c *recordingClient) UpsertNode(context.Context, string, *backendclient.NodeUpsertRequest, string) error { + return nil +} +func (c *recordingClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { + return c.nonceResp, nil +} +func (c *recordingClient) SubmitAttestation(_ context.Context, nodeID string, req *backendclient.AttestationRequest, jwt string) error { + c.lastNodeID = nodeID + c.lastJWT = jwt + c.lastReq = req + return nil +} +func (c *recordingClient) RefreshToken(context.Context, string) (string, error) { return "", nil } + +func TestToAttestationRequest(t *testing.T) { + refreshTS := time.Now().UTC() + req := toAttestationRequest(&Result{ + NonceRefreshTimestamp: refreshTS, + Success: true, + ErrorMessage: "", + SDKResponse: SDKResponse{ + ResultCode: 7, + ResultMessage: "ok", + Evidences: []EvidenceItem{{ + Arch: "BLACKWELL", + Certificate: "cert", + DriverVersion: "575.1", + Evidence: "blob", + Nonce: "nonce", + VBIOSVersion: "vbios", + Version: "1.0", + }}, + }, + }) + require.NotNil(t, req) + require.Equal(t, refreshTS, req.AttestationData.NonceRefreshTimestamp) + require.True(t, req.AttestationData.Success) + require.Len(t, req.AttestationData.SDKResponse.Evidences, 1) + require.Equal(t, "BLACKWELL", req.AttestationData.SDKResponse.Evidences[0].Arch) + require.Nil(t, toAttestationRequest(nil)) +} + +func TestStateBackendClientFactory(t *testing.T) { + orig := newBackendClient + t.Cleanup(func() { newBackendClient = orig }) + + client := &recordingClient{} + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + require.Equal(t, "https://backend.example.com", rawBaseURL) + return client, nil + } + + factory := &stateBackendClientFactory{state: &stubState{baseURL: "https://backend.example.com", baseOK: true}} + got, err := factory.client(context.Background()) + require.NoError(t, err) + require.Equal(t, client, got) + + _, err = (&stateBackendClientFactory{}).client(context.Background()) + require.ErrorContains(t, err, "requires agent state") + + _, err = (&stateBackendClientFactory{state: &stubState{baseErr: errors.New("boom")}}).client(context.Background()) + require.ErrorContains(t, err, "boom") + + _, err = (&stateBackendClientFactory{state: &stubState{baseOK: false}}).client(context.Background()) + require.ErrorContains(t, err, "backend base URL not available") +} + +func TestStateProvidersAndSubmitter(t *testing.T) { + orig := newBackendClient + t.Cleanup(func() { newBackendClient = orig }) + + recording := &recordingClient{ + nonceResp: &backendclient.NonceResponse{ + Nonce: "abc123", + NonceRefreshTimestamp: time.Unix(10, 0).UTC(), + JWTAssertion: "new-jwt", + }, + } + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + require.Equal(t, "https://backend.example.com", rawBaseURL) + return recording, nil + } + + state := &stubState{ + baseURL: "https://backend.example.com", baseOK: true, + jwt: "jwt-token", jwtOK: true, + nodeID: "node-1", nodeOK: true, + } + + jwtProvider := NewStateJWTProvider(state) + jwt, err := jwtProvider.GetJWT(context.Background()) + require.NoError(t, err) + require.Equal(t, "jwt-token", jwt) + require.NoError(t, jwtProvider.SetJWT(context.Background(), "updated")) + require.Equal(t, "updated", state.setJWT) + + nodeID, err := NewStateNodeIDProvider(state)(context.Background()) + require.NoError(t, err) + require.Equal(t, "node-1", nodeID) + + nonce, ts, refreshedJWT, err := NewStateNonceProvider(state).GetNonce(context.Background(), "node-1", "jwt-token") + require.NoError(t, err) + require.Equal(t, "abc123", nonce) + require.Equal(t, time.Unix(10, 0).UTC(), ts) + require.Equal(t, "new-jwt", refreshedJWT) + + result := &Result{ + NodeID: "node-1", + SDKResponse: SDKResponse{ + ResultCode: 1, + Evidences: []EvidenceItem{{Arch: "BLACKWELL"}}, + }, + } + err = NewStateBackendSubmitter(state).Submit(context.Background(), result, "jwt-token") + require.NoError(t, err) + require.Equal(t, "node-1", recording.lastNodeID) + require.Equal(t, "jwt-token", recording.lastJWT) + require.NotNil(t, recording.lastReq) + require.Equal(t, "BLACKWELL", recording.lastReq.AttestationData.SDKResponse.Evidences[0].Arch) +} + +func TestLegacyAttestationData(t *testing.T) { + result := &Result{ + NonceRefreshTimestamp: time.Unix(20, 0).UTC(), + Success: false, + ErrorMessage: "boom", + SDKResponse: SDKResponse{ + ResultCode: 9, + ResultMessage: "bad", + Evidences: []EvidenceItem{{Arch: "BLACKWELL"}}, + }, + } + legacy := result.LegacyAttestationData() + require.NotNil(t, legacy) + require.False(t, legacy.Success) + require.Equal(t, "boom", legacy.ErrorMessage) + require.Equal(t, 9, legacy.SDKResponse.ResultCode) + require.Len(t, legacy.SDKResponse.Evidences, 1) + require.Nil(t, (*Result)(nil).LegacyAttestationData()) +} diff --git a/internal/attestationloop/collector.go b/internal/attestationloop/collector.go new file mode 100644 index 00000000..1294e1da --- /dev/null +++ b/internal/attestationloop/collector.go @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attestationloop + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os/exec" + "time" +) + +var execCommandContext = exec.CommandContext + +type cliEvidenceCollector struct { + timeout time.Duration +} + +// NewCLIEvidenceCollector creates an evidence collector backed by the nvattest CLI. +func NewCLIEvidenceCollector(timeout time.Duration) EvidenceCollector { + if timeout <= 0 { + timeout = 60 * time.Second + } + return &cliEvidenceCollector{timeout: timeout} +} + +func (c *cliEvidenceCollector) Collect(ctx context.Context, nonce string) (*SDKResponse, error) { + if err := validateNonce(nonce); err != nil { + return nil, fmt.Errorf("invalid nonce received from backend: %w", err) + } + + runCtx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + + cmd := execCommandContext( + runCtx, + "nvattest", + "collect-evidence", + "--gpu-evidence-source=corelib", + "--nonce", nonce, + "--gpu-architecture", "blackwell", + "--format", "json", + ) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil && stdout.Len() == 0 { + return nil, fmt.Errorf("attestation CLI execution failed: %w (stderr: %s)", err, stderr.String()) + } + + var response SDKResponse + if parseErr := json.Unmarshal(stdout.Bytes(), &response); parseErr != nil { + errText := "" + if err != nil { + errText = err.Error() + } + return nil, fmt.Errorf( + "failed to parse CLI response: %w (stderr: %s), stdout: %s, error: %s", + parseErr, stderr.String(), stdout.String(), errText, + ) + } + return &response, nil +} + +func validateNonce(nonce string) error { + if nonce == "" { + return fmt.Errorf("nonce is empty") + } + const maxLen = 512 + if len(nonce) > maxLen { + return fmt.Errorf("nonce length %d exceeds maximum of %d characters", len(nonce), maxLen) + } + for i, c := range nonce { + switch { + case c >= '0' && c <= '9', + c >= 'a' && c <= 'z', + c >= 'A' && c <= 'Z', + c == '-', c == '_', c == '=', c == '+', c == '/': + default: + return fmt.Errorf("nonce contains invalid character %q at position %d", c, i) + } + } + return nil +} diff --git a/internal/attestationloop/collector_test.go b/internal/attestationloop/collector_test.go new file mode 100644 index 00000000..b75ae9a2 --- /dev/null +++ b/internal/attestationloop/collector_test.go @@ -0,0 +1,129 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attestationloop + +import ( + "context" + "errors" + "os" + "os/exec" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +func TestValidateNonce(t *testing.T) { + require.NoError(t, validateNonce("abc123-_=/+")) + require.Error(t, validateNonce("")) + require.Error(t, validateNonce("bad nonce")) +} + +func TestCLIEvidenceCollectorRejectsInvalidNonce(t *testing.T) { + collector := NewCLIEvidenceCollector(time.Second) + _, err := collector.Collect(context.Background(), "bad nonce") + require.Error(t, err) +} + +func TestCLIEvidenceCollectorParsesResponse(t *testing.T) { + original := execCommandContext + t.Cleanup(func() { execCommandContext = original }) + execCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + argv := append([]string{"-test.run=^TestHelperProcess$", "--", name}, args...) + cmd := exec.CommandContext(ctx, os.Args[0], argv...) + cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1") + return cmd + } + + collector := NewCLIEvidenceCollector(time.Second) + resp, err := collector.Collect(context.Background(), "abc123") + require.NoError(t, err) + require.Equal(t, 200, resp.ResultCode) + require.Equal(t, "ok", resp.ResultMessage) +} + +func TestCLIEvidenceCollectorExecutionAndParseErrors(t *testing.T) { + original := execCommandContext + t.Cleanup(func() { execCommandContext = original }) + + execCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + argv := append([]string{"-test.run=^TestHelperProcessError$", "--", name}, args...) + cmd := exec.CommandContext(ctx, os.Args[0], argv...) + cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1", "HELPER_MODE=stderr_only") + return cmd + } + collector := NewCLIEvidenceCollector(time.Second) + _, err := collector.Collect(context.Background(), "abc123") + require.Error(t, err) + + execCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + argv := append([]string{"-test.run=^TestHelperProcessError$", "--", name}, args...) + cmd := exec.CommandContext(ctx, os.Args[0], argv...) + cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1", "HELPER_MODE=bad_json") + return cmd + } + _, err = collector.Collect(context.Background(), "abc123") + require.Error(t, err) +} + +func TestBackendNonceProviderErrors(t *testing.T) { + _, _, _, err := NewBackendNonceProvider(nil).GetNonce(context.Background(), "node", "jwt") + require.ErrorContains(t, err, "backend client") + + client := &testNonceClient{} + _, _, _, err = NewBackendNonceProvider(client).GetNonce(context.Background(), "", "jwt") + require.ErrorContains(t, err, "node ID") + _, _, _, err = NewBackendNonceProvider(client).GetNonce(context.Background(), "node", "") + require.ErrorContains(t, err, "jwt") + + nilClient := &nilNonceClient{} + _, _, _, err = NewBackendNonceProvider(nilClient).GetNonce(context.Background(), "node", "jwt") + require.ErrorContains(t, err, "nil") +} + +type nilNonceClient struct{} + +func (c *nilNonceClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { + return nil, nil +} + +func TestHelperProcessError(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + switch os.Getenv("HELPER_MODE") { + case "stderr_only": + _, _ = os.Stderr.WriteString("boom") + os.Exit(1) + case "bad_json": + _, _ = os.Stdout.WriteString("{") + _, _ = os.Stderr.WriteString("warn") + os.Exit(1) + default: + _ = errors.New("unused") + os.Exit(2) + } +} + +func TestHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + _, _ = os.Stdout.WriteString(`{"evidences":[],"resultCode":200,"resultMessage":"ok"}`) + os.Exit(0) +} diff --git a/internal/attestationloop/manager.go b/internal/attestationloop/manager.go index 68fb4e53..a4140510 100644 --- a/internal/attestationloop/manager.go +++ b/internal/attestationloop/manager.go @@ -21,6 +21,9 @@ import ( "sync" "time" + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" ) @@ -34,6 +37,8 @@ type JWTProvider interface { type Manager interface { Run(ctx context.Context) error CollectOnce(ctx context.Context) (*Result, error) + LastResult() *Result + IsResultUpdated(since time.Time) bool } type manager struct { @@ -45,7 +50,8 @@ type manager struct { submitter Submitter interval time.Duration - lastResult *Result + lastResult *Result + lastUpdated time.Time } // NewManager creates an attestation loop manager skeleton. @@ -68,10 +74,29 @@ func NewManager( } func (m *manager) Run(ctx context.Context) error { + if m.nodeIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil || m.submitter == nil { + return fmt.Errorf("attestation loop dependencies are incomplete") + } if _, err := m.CollectOnce(ctx); err != nil { - return err + log.Logger.Warnw("initial attestation workflow failed", "error", err) + } + if m.interval <= 0 { + return nil + } + + ticker := time.NewTicker(m.interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + if _, err := m.CollectOnce(ctx); err != nil { + log.Logger.Warnw("periodic attestation workflow failed", "error", err) + } + } } - return fmt.Errorf("attestation loop run loop not implemented") } func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { @@ -115,6 +140,7 @@ func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { m.mu.Lock() cloned := *result m.lastResult = &cloned + m.lastUpdated = time.Now().UTC() m.mu.Unlock() if err := m.submitter.Submit(ctx, result, jwt); err != nil { return nil, err @@ -122,6 +148,22 @@ func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { return result, nil } +func (m *manager) LastResult() *Result { + m.mu.RLock() + defer m.mu.RUnlock() + if m.lastResult == nil { + return nil + } + cloned := *m.lastResult + return &cloned +} + +func (m *manager) IsResultUpdated(since time.Time) bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.lastUpdated.After(since) +} + type backendSubmitter struct { client BackendClient } @@ -148,3 +190,50 @@ func (s *backendSubmitter) Submit(ctx context.Context, result *Result, jwt strin } return s.client.SubmitAttestation(ctx, result.NodeID, toAttestationRequest(result), jwt) } + +type stateJWTProvider struct { + state agentstate.State +} + +// NewStateJWTProvider returns a JWT provider backed by persisted agent state. +func NewStateJWTProvider(state agentstate.State) JWTProvider { + return &stateJWTProvider{state: state} +} + +func (p *stateJWTProvider) GetJWT(ctx context.Context) (string, error) { + if p.state == nil { + return "", fmt.Errorf("jwt provider requires agent state") + } + value, ok, err := p.state.GetJWT(ctx) + if err != nil { + return "", err + } + if !ok || value == "" { + return "", fmt.Errorf("jwt not available in agent state") + } + return value, nil +} + +func (p *stateJWTProvider) SetJWT(ctx context.Context, value string) error { + if p.state == nil { + return fmt.Errorf("jwt provider requires agent state") + } + return p.state.SetJWT(ctx, value) +} + +// NewStateNodeIDProvider returns a node ID provider backed by persisted agent state. +func NewStateNodeIDProvider(state agentstate.State) func(context.Context) (string, error) { + return func(ctx context.Context) (string, error) { + if state == nil { + return "", fmt.Errorf("node ID provider requires agent state") + } + value, ok, err := state.GetNodeID(ctx) + if err != nil { + return "", err + } + if !ok || value == "" { + return "", fmt.Errorf("node ID not available in agent state") + } + return value, nil + } +} diff --git a/internal/attestationloop/manager_test.go b/internal/attestationloop/manager_test.go new file mode 100644 index 00000000..c4c0082d --- /dev/null +++ b/internal/attestationloop/manager_test.go @@ -0,0 +1,176 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attestationloop + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type testJWTProvider struct { + jwt string + setJWT string +} + +func (p *testJWTProvider) GetJWT(context.Context) (string, error) { return p.jwt, nil } +func (p *testJWTProvider) SetJWT(_ context.Context, value string) error { + p.setJWT = value + p.jwt = value + return nil +} + +type testNonceProvider struct { + nonce string + refreshTS time.Time + refreshedJWT string + err error +} + +func (p *testNonceProvider) GetNonce(context.Context, string, string) (string, time.Time, string, error) { + return p.nonce, p.refreshTS, p.refreshedJWT, p.err +} + +type testEvidenceCollector struct { + resp *SDKResponse + err error +} + +func (c *testEvidenceCollector) Collect(context.Context, string) (*SDKResponse, error) { + return c.resp, c.err +} + +type submitted struct { + result *Result + jwt string +} + +type testSubmitter struct { + submitted submitted + err error +} + +func (s *testSubmitter) Submit(_ context.Context, result *Result, jwt string) error { + s.submitted = submitted{result: result, jwt: jwt} + return s.err +} + +func TestCollectOnceSuccess(t *testing.T) { + refreshTS := time.Now().UTC() + jwtProvider := &testJWTProvider{jwt: "old-jwt"} + submitter := &testSubmitter{} + manager := NewManager( + func(context.Context) (string, error) { return "node-1", nil }, + jwtProvider, + &testNonceProvider{nonce: "abc123", refreshTS: refreshTS, refreshedJWT: "new-jwt"}, + &testEvidenceCollector{resp: &SDKResponse{ResultCode: 200, ResultMessage: "ok"}}, + submitter, + 0, + ) + + result, err := manager.CollectOnce(context.Background()) + require.NoError(t, err) + require.True(t, result.Success) + require.Equal(t, "node-1", result.NodeID) + require.Equal(t, refreshTS, result.NonceRefreshTimestamp) + require.Equal(t, "new-jwt", jwtProvider.setJWT) + require.NotNil(t, submitter.submitted.result) + require.Equal(t, "new-jwt", submitter.submitted.jwt) + require.True(t, submitter.submitted.result.Success) +} + +func TestCollectOnceCollectorFailureStillSubmitsFailureResult(t *testing.T) { + submitter := &testSubmitter{} + manager := NewManager( + func(context.Context) (string, error) { return "node-1", nil }, + &testJWTProvider{jwt: "jwt-token"}, + &testNonceProvider{nonce: "abc123"}, + &testEvidenceCollector{err: errors.New("collect failed")}, + submitter, + 0, + ) + + result, err := manager.CollectOnce(context.Background()) + require.NoError(t, err) + require.False(t, result.Success) + require.Equal(t, "collect failed", result.ErrorMessage) + require.NotNil(t, submitter.submitted.result) + require.False(t, submitter.submitted.result.Success) +} + +func TestCollectOnceMissingDependencies(t *testing.T) { + _, err := NewManager(nil, nil, nil, nil, nil, 0).CollectOnce(context.Background()) + require.Error(t, err) +} + +func TestManagerRunAndCachedResult(t *testing.T) { + submitter := &testSubmitter{} + mgr := NewManager( + func(context.Context) (string, error) { return "node-1", nil }, + &testJWTProvider{jwt: "jwt-token"}, + &testNonceProvider{nonce: "abc123"}, + &testEvidenceCollector{resp: &SDKResponse{ResultCode: 200}}, + submitter, + 5*time.Millisecond, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + require.NoError(t, mgr.Run(ctx)) + + last := mgr.LastResult() + require.NotNil(t, last) + require.Equal(t, "node-1", last.NodeID) + require.True(t, mgr.IsResultUpdated(time.Time{})) +} + +func TestManagerHelpersAndSubmitterErrors(t *testing.T) { + mgr := NewManager( + func(context.Context) (string, error) { return "node-1", nil }, + &testJWTProvider{jwt: "jwt-token"}, + &testNonceProvider{nonce: "abc123"}, + &testEvidenceCollector{resp: &SDKResponse{}}, + &testSubmitter{}, + 0, + ) + require.Nil(t, mgr.LastResult()) + require.False(t, mgr.IsResultUpdated(time.Now().UTC())) + + err := NewBackendSubmitter(nil).Submit(context.Background(), &Result{}, "jwt") + require.ErrorContains(t, err, "backend client") + err = NewBackendSubmitter(&recordingClient{}).Submit(context.Background(), nil, "jwt") + require.ErrorContains(t, err, "requires result") + err = NewBackendSubmitter(&recordingClient{}).Submit(context.Background(), &Result{}, "") + require.ErrorContains(t, err, "requires jwt") +} + +func TestStateJWTProviderAndNodeIDProviderErrors(t *testing.T) { + _, err := NewStateJWTProvider(nil).GetJWT(context.Background()) + require.ErrorContains(t, err, "requires agent state") + err = NewStateJWTProvider(nil).SetJWT(context.Background(), "x") + require.ErrorContains(t, err, "requires agent state") + + _, err = NewStateJWTProvider(&stubState{}).GetJWT(context.Background()) + require.ErrorContains(t, err, "jwt not available") + + _, err = NewStateNodeIDProvider(nil)(context.Background()) + require.ErrorContains(t, err, "requires agent state") + _, err = NewStateNodeIDProvider(&stubState{})(context.Background()) + require.ErrorContains(t, err, "node ID not available") +} diff --git a/internal/attestationloop/nonce.go b/internal/attestationloop/nonce.go new file mode 100644 index 00000000..a27c795c --- /dev/null +++ b/internal/attestationloop/nonce.go @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attestationloop + +import ( + "context" + "fmt" + "time" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +// NonceBackendClient is the backend client view required by the nonce provider. +type NonceBackendClient interface { + GetNonce(ctx context.Context, nodeID string, jwt string) (*backendclient.NonceResponse, error) +} + +type backendNonceProvider struct { + client NonceBackendClient +} + +// NewBackendNonceProvider creates a nonce provider backed by the agent backend client. +func NewBackendNonceProvider(client NonceBackendClient) NonceProvider { + return &backendNonceProvider{client: client} +} + +func (p *backendNonceProvider) GetNonce(ctx context.Context, nodeID, jwt string) (string, time.Time, string, error) { + if p.client == nil { + return "", time.Time{}, "", fmt.Errorf("nonce provider requires backend client") + } + if nodeID == "" { + return "", time.Time{}, "", fmt.Errorf("nonce provider requires node ID") + } + if jwt == "" { + return "", time.Time{}, "", fmt.Errorf("nonce provider requires jwt") + } + + resp, err := p.client.GetNonce(ctx, nodeID, jwt) + if err != nil { + return "", time.Time{}, "", err + } + if resp == nil { + return "", time.Time{}, "", fmt.Errorf("nonce response is nil") + } + return resp.Nonce, resp.NonceRefreshTimestamp, resp.JWTAssertion, nil +} diff --git a/internal/attestationloop/nonce_test.go b/internal/attestationloop/nonce_test.go new file mode 100644 index 00000000..dfc0f9e3 --- /dev/null +++ b/internal/attestationloop/nonce_test.go @@ -0,0 +1,51 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attestationloop + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +type testNonceClient struct { + resp *backendclient.NonceResponse +} + +func (c *testNonceClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { + return c.resp, nil +} + +func TestBackendNonceProvider(t *testing.T) { + refreshTS := time.Now().UTC() + provider := NewBackendNonceProvider(&testNonceClient{ + resp: &backendclient.NonceResponse{ + Nonce: "abc123", + NonceRefreshTimestamp: refreshTS, + JWTAssertion: "new-jwt", + }, + }) + + nonce, ts, jwt, err := provider.GetNonce(context.Background(), "node-1", "jwt-token") + require.NoError(t, err) + require.Equal(t, "abc123", nonce) + require.Equal(t, refreshTS, ts) + require.Equal(t, "new-jwt", jwt) +} diff --git a/internal/attestationloop/types.go b/internal/attestationloop/types.go index eb327486..3766663a 100644 --- a/internal/attestationloop/types.go +++ b/internal/attestationloop/types.go @@ -19,6 +19,8 @@ package attestationloop import ( "context" "time" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" ) // Result is the agent-owned attestation state model for the new backend sync loop. @@ -61,3 +63,35 @@ type EvidenceCollector interface { type Submitter interface { Submit(ctx context.Context, result *Result, jwt string) error } + +// LegacyAttestationData converts the workflow result into the legacy attestation payload shape +// still consumed by the exporter collector path. +func (r *Result) LegacyAttestationData() *attestation.AttestationData { + if r == nil { + return nil + } + data := &attestation.AttestationData{ + NonceRefreshTimestamp: r.NonceRefreshTimestamp, + Success: r.Success, + ErrorMessage: r.ErrorMessage, + SDKResponse: attestation.AttestationSDKResponse{ + ResultCode: r.SDKResponse.ResultCode, + ResultMessage: r.SDKResponse.ResultMessage, + }, + } + if len(r.SDKResponse.Evidences) > 0 { + data.SDKResponse.Evidences = make([]attestation.EvidenceItem, 0, len(r.SDKResponse.Evidences)) + for _, ev := range r.SDKResponse.Evidences { + data.SDKResponse.Evidences = append(data.SDKResponse.Evidences, attestation.EvidenceItem{ + Arch: ev.Arch, + Certificate: ev.Certificate, + DriverVersion: ev.DriverVersion, + Evidence: ev.Evidence, + Nonce: ev.Nonce, + VBIOSVersion: ev.VBIOSVersion, + Version: ev.Version, + }) + } + } + return data +} diff --git a/internal/exporter/collector/collector.go b/internal/exporter/collector/collector.go index acbe48ec..ae472c11 100644 --- a/internal/exporter/collector/collector.go +++ b/internal/exporter/collector/collector.go @@ -31,7 +31,6 @@ import ( nvidianvml "github.com/NVIDIA/fleet-intelligence-sdk/pkg/nvidia-query/nvml" "github.com/google/uuid" - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) @@ -52,15 +51,14 @@ func GenerateEventID() string { // HealthData represents the collected health data type HealthData struct { - CollectionID string - MachineID string - Timestamp time.Time - MachineInfo *machineinfo.MachineInfo - Metrics pkgmetrics.Metrics - Events eventstore.Events - ComponentData map[string]interface{} - AttestationData *attestation.AttestationData - ConfigEntries []config.ConfigEntry + CollectionID string + MachineID string + Timestamp time.Time + MachineInfo *machineinfo.MachineInfo + Metrics pkgmetrics.Metrics + Events eventstore.Events + ComponentData map[string]interface{} + ConfigEntries []config.ConfigEntry } // Collector defines the interface for collecting health data @@ -70,17 +68,15 @@ type Collector interface { // collector implements the Collector interface type collector struct { - config *config.HealthExporterConfig - configEntries []config.ConfigEntry // Cached config entries computed once at startup - metricsStore pkgmetrics.Store - eventStore eventstore.Store - componentsRegistry components.Registry - nvmlInstance nvidianvml.Instance - attestationManager *attestation.Manager - lastAttestationCollection time.Time - machineID string // Agent's stable identity from server initialization - dcgmGPUIndexes map[string]string // UUID → DCGM device ID override for GPU indices - machineInfoProvider machineInfoProvider + config *config.HealthExporterConfig + configEntries []config.ConfigEntry // Cached config entries computed once at startup + metricsStore pkgmetrics.Store + eventStore eventstore.Store + componentsRegistry components.Registry + nvmlInstance nvidianvml.Instance + machineID string // Agent's stable identity from server initialization + dcgmGPUIndexes map[string]string // UUID → DCGM device ID override for GPU indices + machineInfoProvider machineInfoProvider } // New creates a new health data collector @@ -92,7 +88,7 @@ func New( eventStore eventstore.Store, componentsRegistry components.Registry, nvmlInstance nvidianvml.Instance, - attestationManager *attestation.Manager, + _ any, machineID string, dcgmGPUIndexes map[string]string, ) Collector { @@ -120,7 +116,6 @@ func New( eventStore: eventStore, componentsRegistry: componentsRegistry, nvmlInstance: nvmlInstance, - attestationManager: attestationManager, machineID: machineID, dcgmGPUIndexes: dcgmGPUIndexes, machineInfoProvider: provider, @@ -170,12 +165,6 @@ func (c *collector) Collect(ctx context.Context) (*HealthData, error) { } } - // Collect attestation data if provider is available - // Attestation is always enabled if manager is available - if err := c.collectAttestationData(data); err != nil { - log.Logger.Errorw("Failed to collect attestation data", "error", err) - } - // Collect config data if err := c.collectConfigData(data); err != nil { log.Logger.Errorw("Failed to collect config data", "error", err) @@ -336,25 +325,6 @@ func (c *collector) collectComponentData(data *HealthData) error { return nil } -// collectAttestationData collects attestation data from the attestation manager if available and updated -func (c *collector) collectAttestationData(data *HealthData) error { - if c.attestationManager == nil { - log.Logger.Debugw("No attestation manager available, skipping attestation data collection") - return nil - } - - // Get latest attestation data (success or failure info) - attestationData := c.attestationManager.GetAttestationData() - data.AttestationData = attestationData - - // Update collection timestamp if data was newly updated - if c.attestationManager.IsAttestationDataUpdated(c.lastAttestationCollection) { - c.lastAttestationCollection = time.Now().UTC() - } - - return nil -} - // collectConfigData returns cached agent configuration entries // Config entries are computed once at startup since there's no dynamic config reload func (c *collector) collectConfigData(data *HealthData) error { diff --git a/internal/exporter/collector/collector_test.go b/internal/exporter/collector/collector_test.go index a851d598..853e3551 100644 --- a/internal/exporter/collector/collector_test.go +++ b/internal/exporter/collector/collector_test.go @@ -34,310 +34,10 @@ import ( "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) -func TestCollector_AttestationDataCollection(t *testing.T) { - tests := []struct { - name string - description string - testLogic func(t *testing.T) - }{ - { - name: "first_collection_always_collects", - description: "First collection should always collect attestation data even if empty", - testLogic: testFirstCollectionAlwaysCollects, - }, - { - name: "subsequent_collection_skips_when_no_update", - description: "Subsequent collections should skip when attestation data hasn't been updated", - testLogic: testSubsequentCollectionSkipsWhenNoUpdate, - }, - { - name: "collection_after_attestation_update", - description: "Collection should happen after attestation data is updated", - testLogic: testCollectionAfterAttestationUpdate, - }, - { - name: "nil_attestation_manager_skips_collection", - description: "Collector should skip attestation collection when manager is nil", - testLogic: testNilAttestationManagerSkipsCollection, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Log("Testing:", tt.description) - tt.testLogic(t) - }) - } -} - -func testFirstCollectionAlwaysCollects(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, cfg) // nil nvmlInstance, 20s for testing - - // Create collector - testCollector := createTestCollector(attestationManager) - - // Start attestation manager to populate some data - attestationManager.Start() - defer attestationManager.Stop() - - // Wait a bit for attestation to run - time.Sleep(100 * time.Millisecond) - - // Collect data for the first time - data, err := testCollector.Collect(ctx) - - require.NoError(t, err) - require.NotNil(t, data) - - // First collection should work (but may not have attestation data due to test environment) - if data.AttestationData != nil { - t.Log("First collection successfully populated attestation data") - } else { - t.Log("First collection did not populate attestation data - this is expected when NVML/nonce fails in test environment") - } - - // Check if lastAttestationCollection was updated only if attestation data was collected - collectorImpl := testCollector.(*collector) - if data.AttestationData != nil { - assert.False(t, collectorImpl.lastAttestationCollection.IsZero(), - "lastAttestationCollection should be set after successful collection") - } else { - assert.True(t, collectorImpl.lastAttestationCollection.IsZero(), - "lastAttestationCollection should remain zero when attestation fails") - } -} - -func testSubsequentCollectionSkipsWhenNoUpdate(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, cfg) // nil nvmlInstance, 20s for testing - - // Create collector - testCollector := createTestCollector(attestationManager) - collectorImpl := testCollector.(*collector) - - // Start attestation manager - attestationManager.Start() - defer attestationManager.Stop() - - // Wait for attestation to run and populate data - time.Sleep(100 * time.Millisecond) - - // First collection - data1, err := testCollector.Collect(ctx) - require.NoError(t, err) - require.NotNil(t, data1) - - firstCollectionTime := collectorImpl.lastAttestationCollection - // In test environment, this will be zero since attestation fails - t.Logf("First collection time: %v", firstCollectionTime) - - // Verify first collection has attestation data - // Verify first collection has attestation data (or logs why it doesn't) - if data1.AttestationData != nil { - assert.Empty(t, data1.AttestationData.SDKResponse.Evidences, "Until Attestation is available in public release, this should be empty") - } else { - t.Log("First collection did not populate attestation data - this is expected when NVML/nonce fails") - } - - // Sleep a little to ensure time difference - time.Sleep(10 * time.Millisecond) - - // Second collection (should skip attestation since no update) - data2, err := testCollector.Collect(ctx) - require.NoError(t, err) - require.NotNil(t, data2) - - secondCollectionTime := collectorImpl.lastAttestationCollection - - // lastAttestationCollection should remain the same (indicating skip) - assert.Equal(t, firstCollectionTime, secondCollectionTime, - "lastAttestationCollection should not change when attestation collection is skipped") -} - -func testCollectionAfterAttestationUpdate(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, cfg) // nil nvmlInstance, 20s for testing - - // Create collector - testCollector := createTestCollector(attestationManager) - collectorImpl := testCollector.(*collector) - - // Start attestation manager with faster interval for testing (20 seconds) - attestationManager.Start() - defer attestationManager.Stop() - - // Wait for first attestation to run - time.Sleep(100 * time.Millisecond) - - // First collection - data1, err := testCollector.Collect(ctx) - require.NoError(t, err) - require.NotNil(t, data1) - - firstCollectionTime := collectorImpl.lastAttestationCollection - - // Verify first collection has attestation data - // Verify first collection has attestation data (or logs why it doesn't) - if data1.AttestationData != nil { - assert.Empty(t, data1.AttestationData.SDKResponse.Evidences, "Until Attestation is available in public release, this should be empty") - } else { - t.Log("First collection did not populate attestation data - this is expected when NVML/nonce fails") - } - - // Wait for attestation to run again (it's set to 20 seconds in the test) - t.Log("Waiting for attestation to refresh...") - time.Sleep(10 * time.Second) - - // Second collection (should collect since attestation was updated) - data2, err := testCollector.Collect(ctx) - require.NoError(t, err) - require.NotNil(t, data2) - - secondCollectionTime := collectorImpl.lastAttestationCollection - - // In test environment, both times will be zero since attestation fails - t.Logf("First collection time: %v, Second collection time: %v", firstCollectionTime, secondCollectionTime) - - // In a real environment with working NVML/HTTP, both collections would have evidence data - // In test environment, they will be nil due to missing dependencies - if data1.AttestationData != nil && data2.AttestationData != nil { - assert.Empty(t, data1.AttestationData.SDKResponse.Evidences, "Until Attestation is available in public release, this should be empty") - assert.Empty(t, data2.AttestationData.SDKResponse.Evidences, "Until Attestation is available in public release, this should be empty") - t.Log("Both collections successfully have attestation data") - } else { - t.Log("Collections do not have attestation data - expected in test environment without real dependencies") - } -} - -func testNilAttestationManagerSkipsCollection(t *testing.T) { - ctx := context.Background() - - // Create collector with nil attestation manager - testCollector := createTestCollectorWithNilAttestation() - - // Collection should skip gracefully - data, err := testCollector.Collect(ctx) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Nil(t, data.AttestationData, "Should not collect attestation data when manager is nil") -} - -func TestCollector_AttestationDataCollection_WithMockData(t *testing.T) { - // This test verifies collection behavior when attestation is unavailable - ctx := context.Background() - attestationCfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, attestationCfg) - testCollector := createTestCollector(attestationManager) - collectorImpl := testCollector.(*collector) - - // Verify that collection works when no attestation data is available - data1, err := testCollector.Collect(ctx) - require.NoError(t, err) - require.NotNil(t, data1) - - // Should be nil since no attestation data is available - assert.Nil(t, data1.AttestationData, "Should be nil when no attestation data available") - assert.True(t, collectorImpl.lastAttestationCollection.IsZero(), "Should remain zero") - - t.Log("Successfully tested collection with no attestation data") -} - -func TestAttestationManager_UpdateTracking(t *testing.T) { - ctx := context.Background() - attestationCfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := attestation.NewManager(ctx, nil, attestationCfg) // nil nvmlInstance for testing - - // Initially, no updates - baseTime := time.Now().UTC() - assert.False(t, manager.IsAttestationDataUpdated(baseTime), - "Should return false before any attestation runs") - - // Start the manager and test the update tracking - manager.Start() - defer manager.Stop() - - // Give it time to attempt attestation - time.Sleep(100 * time.Millisecond) - - // In test environment this may still be false due to NVML/HTTP failures, but that's expected - updated := manager.IsAttestationDataUpdated(baseTime) - t.Logf("Attestation updated after start: %v", updated) - - // The important part is that the method doesn't crash and returns a boolean - assert.IsType(t, false, updated, "IsAttestationDataUpdated should return a boolean") -} - -// Helper functions - -func createTestCollector(attestationManager *attestation.Manager) Collector { - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: false, - IncludeMetrics: false, - IncludeEvents: false, - IncludeComponentData: false, - } - - return New( - cfg, - nil, // fullConfig - nil, // allComponentNames - nil, // metricsStore - nil, // eventStore - nil, // componentsRegistry - nil, // nvmlInstance - attestationManager, - "test-machine-id", - nil, // dcgmGPUIndexes - ) -} - -func createTestCollectorWithNilAttestation() Collector { - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: false, - IncludeMetrics: false, - IncludeEvents: false, - IncludeComponentData: false, - } - - return New( - cfg, - nil, // fullConfig - nil, // allComponentNames - nil, // metricsStore - nil, // eventStore - nil, // componentsRegistry - nil, // nvmlInstance - nil, // attestationManager (nil for testing) - "test-machine-id", - nil, // dcgmGPUIndexes - ) -} - func TestGenerateCollectionID(t *testing.T) { // Generate multiple collection IDs id1 := GenerateCollectionID() @@ -376,20 +76,13 @@ func TestGenerateEventID(t *testing.T) { } func TestNew(t *testing.T) { - ctx := context.Background() cfg := &config.HealthExporterConfig{ IncludeMachineInfo: true, IncludeMetrics: true, IncludeEvents: true, IncludeComponentData: true, } - attestationCfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, attestationCfg) - - c := New(cfg, nil, nil, nil, nil, nil, nil, attestationManager, "test-machine-id", nil) + c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil) assert.NotNil(t, c, "Collector should be created") @@ -424,7 +117,6 @@ func TestCollector_Collect_BasicFlow(t *testing.T) { assert.Empty(t, data.Metrics, "Metrics should be empty when disabled") assert.Empty(t, data.Events, "Events should be empty when disabled") assert.Empty(t, data.ComponentData, "ComponentData should be empty when disabled") - assert.Nil(t, data.AttestationData, "AttestationData should be nil when disabled") } func TestCollector_CollectMachineInfo_NoNVML(t *testing.T) { @@ -1005,15 +697,9 @@ func TestCollector_AllFeaturesEnabled(t *testing.T) { components: []components.Component{mockComp}, } - attestationCfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, attestationCfg) - mockEventStore := &mockEventStore{} - collector := New(cfg, nil, nil, mockMetricsStore, mockEventStore, mockRegistry, nil, attestationManager, "test-machine-id", nil) + collector := New(cfg, nil, nil, mockMetricsStore, mockEventStore, mockRegistry, nil, nil, "test-machine-id", nil) data, err := collector.Collect(ctx) require.NoError(t, err) @@ -1028,7 +714,6 @@ func TestCollector_AllFeaturesEnabled(t *testing.T) { assert.Len(t, data.Events, 1) assert.Len(t, data.ComponentData, 1) // MachineInfo will be nil without NVML - // AttestationData may be nil in test environment } // ============================================================================= diff --git a/internal/exporter/converter/otlp.go b/internal/exporter/converter/otlp.go index 89bd9be8..97d48cea 100644 --- a/internal/exporter/converter/otlp.go +++ b/internal/exporter/converter/otlp.go @@ -137,12 +137,6 @@ func (c *otlpConverter) createOTLPResource(data *collector.HealthData) *resource attributes = append(attributes, machineInfoAttributes...) } - // Add attestation data attributes if available using reflection - if data.AttestationData != nil { - attestationAttributes := convertStructToOTLPAttributesWithPrefix(data.AttestationData, "attestation") - attributes = append(attributes, attestationAttributes...) - } - return &resourcev1.Resource{ Attributes: attributes, } @@ -211,12 +205,6 @@ func (c *otlpConverter) convertMetricsToOTLP(data *collector.HealthData) []*metr Value: &commonv1.AnyValue_IntValue{IntValue: int64(len(data.ComponentData))}, }, }, - { - Key: "attestation_evidences_count", - Value: &commonv1.AnyValue{ - Value: &commonv1.AnyValue_IntValue{IntValue: int64(c.getAttestationEvidencesCount(data))}, - }, - }, }, }, }, @@ -646,11 +634,3 @@ func convertStructToOTLPAttributesWithPrefix(v interface{}, prefix string) []*co return attributes } - -// getAttestationEvidencesCount returns the count of attestation evidences -func (c *otlpConverter) getAttestationEvidencesCount(data *collector.HealthData) int { - if data.AttestationData == nil { - return 0 - } - return len(data.AttestationData.SDKResponse.Evidences) -} diff --git a/internal/exporter/converter/otlp_test.go b/internal/exporter/converter/otlp_test.go index ee475cee..ac5ef682 100644 --- a/internal/exporter/converter/otlp_test.go +++ b/internal/exporter/converter/otlp_test.go @@ -28,7 +28,6 @@ import ( metricsv1 "go.opentelemetry.io/proto/otlp/metrics/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/collector" "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) @@ -360,57 +359,6 @@ func TestOTLPConverter_Convert_WithMachineInfo(t *testing.T) { assert.Equal(t, "4.2.3", findAttribute(t, rm.Resource.Attributes, "dcgmVersion").GetStringValue()) } -func TestOTLPConverter_Convert_WithAttestationData(t *testing.T) { - attestationData := &attestation.AttestationData{ - Success: true, - SDKResponse: attestation.AttestationSDKResponse{ - Evidences: []attestation.EvidenceItem{ - { - Arch: "BLACKWELL", - Certificate: "test-cert", - DriverVersion: "575.28", - Evidence: "test-evidence", - Nonce: "test-nonce", - VBIOSVersion: "96.00.AF.00.01", - Version: "1.0", - }, - }, - ResultCode: 0, - ResultMessage: "Ok", - }, - NonceRefreshTimestamp: time.Date(2025, 11, 5, 12, 0, 0, 0, time.UTC), - } - - data := &collector.HealthData{ - Timestamp: time.Now(), - MachineID: "test-machine", - AttestationData: attestationData, - } - - converter := NewOTLPConverter() - otlpData := converter.Convert(data) - - require.NotNil(t, otlpData) - require.NotNil(t, otlpData.Logs) - - // Should NOT have attestation logs - rl := otlpData.Logs.ResourceLogs[0] - logs := rl.ScopeLogs[0].LogRecords - assert.Empty(t, logs, "Should not have attestation logs") - - // Should have attestation data in resource attributes - rm := otlpData.Metrics.ResourceMetrics[0] - attrs := rm.Resource.Attributes - foundAttestation := false - for _, attr := range attrs { - if contains(attr.Key, "attestation") { - foundAttestation = true - break - } - } - assert.True(t, foundAttestation, "Should have attestation data in resource attributes") -} - func TestOTLPConverter_ConvertStructToOTLPAttributes(t *testing.T) { type TestStruct struct { StringField string @@ -748,55 +696,6 @@ func TestBuildGPUUUIDToIndexMap(t *testing.T) { }) } -func TestOTLPConverter_GetAttestationEvidencesCount(t *testing.T) { - tests := []struct { - name string - data *collector.HealthData - expectedCount int - }{ - { - name: "with_evidences", - data: &collector.HealthData{ - AttestationData: &attestation.AttestationData{ - SDKResponse: attestation.AttestationSDKResponse{ - Evidences: []attestation.EvidenceItem{ - {Arch: "BLACKWELL"}, - {Arch: "HOPPER"}, - }, - }, - }, - }, - expectedCount: 2, - }, - { - name: "nil_attestation", - data: &collector.HealthData{ - AttestationData: nil, - }, - expectedCount: 0, - }, - { - name: "empty_evidences", - data: &collector.HealthData{ - AttestationData: &attestation.AttestationData{ - SDKResponse: attestation.AttestationSDKResponse{ - Evidences: []attestation.EvidenceItem{}, - }, - }, - }, - expectedCount: 0, - }, - } - - converter := &otlpConverter{} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - count := converter.getAttestationEvidencesCount(tt.data) - assert.Equal(t, tt.expectedCount, count) - }) - } -} - func TestOTLPConverter_SummaryMetric(t *testing.T) { data := &collector.HealthData{ Timestamp: time.Now(), @@ -901,17 +800,6 @@ func TestOTLPConverter_Convert_AllData(t *testing.T) { MachineInfo: &machineinfo.MachineInfo{ FleetintVersion: "0.1.5", }, - AttestationData: &attestation.AttestationData{ - Success: true, - SDKResponse: attestation.AttestationSDKResponse{ - Evidences: []attestation.EvidenceItem{ - {Arch: "BLACKWELL", VBIOSVersion: "96.00"}, - }, - ResultCode: 0, - ResultMessage: "Ok", - }, - NonceRefreshTimestamp: time.Now(), - }, } converter := NewOTLPConverter() diff --git a/internal/exporter/exporter.go b/internal/exporter/exporter.go index 185dde18..2eff6cfc 100644 --- a/internal/exporter/exporter.go +++ b/internal/exporter/exporter.go @@ -30,7 +30,6 @@ import ( "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/collector" @@ -46,13 +45,12 @@ var newBackendClient = backendclient.New // healthExporter implements the Exporter interface with improved architecture type healthExporter struct { - ctx context.Context - cancel context.CancelFunc - options *exporterOptions - collector collector.Collector - fileWriter writer.FileWriter - httpWriter writer.HTTPWriter - attestationManager *attestation.Manager + ctx context.Context + cancel context.CancelFunc + options *exporterOptions + collector collector.Collector + fileWriter writer.FileWriter + httpWriter writer.HTTPWriter // Last export timestamp for tracking lastExport time.Time @@ -78,10 +76,6 @@ func New(ctx context.Context, opts ...ExporterOption) (Exporter, error) { } options.setDefaults() - // Create attestation manager (always enabled) - attestationManager := attestation.NewManager(cctx, options.nvmlInstance, &options.config.Attestation) - log.Logger.Infow("Attestation manager created", "interval", options.config.Attestation.Interval.Duration, "jitter_enabled", options.config.Attestation.JitterEnabled) - // Get all component names for config export allComponentNames := registry.AllComponentNames() @@ -93,7 +87,7 @@ func New(ctx context.Context, opts ...ExporterOption) (Exporter, error) { options.eventStore, options.componentsRegistry, options.nvmlInstance, - attestationManager, + nil, options.machineID, options.dcgmGPUIndexes, ) @@ -105,13 +99,12 @@ func New(ctx context.Context, opts ...ExporterOption) (Exporter, error) { httpWriter := writer.NewHTTPWriter(options.httpClient, otlpConverter) exporter := &healthExporter{ - ctx: cctx, - cancel: cancel, - options: options, - collector: dataCollector, - fileWriter: fileWriter, - httpWriter: httpWriter, - attestationManager: attestationManager, + ctx: cctx, + cancel: cancel, + options: options, + collector: dataCollector, + fileWriter: fileWriter, + httpWriter: httpWriter, } // Set JWT refresh function on the HTTP writer @@ -129,11 +122,6 @@ func (e *healthExporter) Start() error { log.Logger.Infow("Starting health exporter") - // Start the attestation manager if enabled - if e.attestationManager != nil { - e.attestationManager.Start() - } - // Start the health export ticker go func() { ticker := time.NewTicker(e.options.config.Interval.Duration) @@ -160,9 +148,6 @@ func (e *healthExporter) Start() error { // Stop gracefully shuts down the exporter func (e *healthExporter) Stop() error { log.Logger.Infow("Stopping health exporter") - if e.attestationManager != nil { - e.attestationManager.Stop() - } e.cancel() return nil } From 89b7bc21bd61febf3da43f5869f9c24095e73e06 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Thu, 16 Apr 2026 13:17:54 -0700 Subject: [PATCH 08/22] refactor: move agent config to inventory payload Signed-off-by: Jingxiang Zhang --- internal/backendclient/types.go | 9 + internal/config/config.go | 11 + internal/config/config_test.go | 16 ++ internal/enrollment/enrollment.go | 23 +- internal/exporter/collector/collector.go | 116 +++----- internal/exporter/collector/collector_test.go | 252 ------------------ internal/exporter/converter/otlp.go | 46 ++-- internal/exporter/converter/otlp_test.go | 96 ++++--- internal/exporter/options.go | 4 - internal/exporter/options_test.go | 12 - internal/inventory/mapper/backend.go | 9 +- internal/inventory/mapper/backend_test.go | 12 + internal/inventory/source/source.go | 17 +- internal/inventory/source/source_test.go | 28 ++ internal/inventory/types.go | 9 + 15 files changed, 231 insertions(+), 429 deletions(-) diff --git a/internal/backendclient/types.go b/internal/backendclient/types.go index 7d9ef5e3..54e6d36d 100644 --- a/internal/backendclient/types.go +++ b/internal/backendclient/types.go @@ -20,6 +20,7 @@ import "time" // NodeUpsertRequest is the backend DTO for node inventory upserts. type NodeUpsertRequest struct { Hostname string `json:"hostname"` + AgentConfig AgentConfig `json:"agentConfig,omitempty"` Resources NodeResources `json:"resources"` FleetintVersion string `json:"gpuHealthVersion"` GPUDriverVersion string `json:"gpuDriverVersion"` @@ -45,6 +46,14 @@ type NodeResources struct { NICInfo NICInfo `json:"nicInfo"` } +type AgentConfig struct { + TotalComponents int64 `json:"totalComponents,omitempty"` + APIVersion string `json:"apiVersion,omitempty"` + RetentionPeriodSeconds int64 `json:"retentionPeriodSeconds,omitempty"` + EnabledComponents []string `json:"enabledComponents,omitempty"` + DisabledComponents []string `json:"disabledComponents,omitempty"` +} + type CPUInfo struct { Type string `json:"type"` Manufacturer string `json:"manufacturer"` diff --git a/internal/config/config.go b/internal/config/config.go index f2eae596..5bc3ef26 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -264,6 +264,17 @@ func (config *Config) ToConfigEntries(allComponentNames []string) []ConfigEntry return entries } +// InventoryAgentConfig returns the useful, non-sensitive subset of agent config that should be +// persisted with inventory rather than exported through telemetry. +func (config *Config) InventoryAgentConfig(allComponentNames []string) (apiVersion string, retentionPeriodSeconds int64, enabled, disabled []string) { + if config == nil { + return "", 0, nil, nil + } + + enabled, disabled = config.getComponentLists(allComponentNames) + return config.APIVersion, int64(config.RetentionPeriod.Seconds()), enabled, disabled +} + // getComponentLists computes enabled/disabled lists from config rules against all available components. func (config *Config) getComponentLists(allComponentNames []string) (enabled, disabled []string) { enabled, disabled = []string{}, []string{} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 3d229b91..8ed9ad11 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -822,3 +822,19 @@ func TestGetComponentLists(t *testing.T) { assert.Equal(t, "[]", string(disabledJSON)) }) } + +func TestInventoryAgentConfig(t *testing.T) { + allComponents := []string{"cpu", "disk", "memory", "gpu"} + cfg := &Config{ + APIVersion: "v1", + RetentionPeriod: metav1.Duration{Duration: 24 * time.Hour}, + Components: []string{"*", "-memory", "-disk"}, + } + + apiVersion, retentionPeriodSeconds, enabled, disabled := cfg.InventoryAgentConfig(allComponents) + + assert.Equal(t, "v1", apiVersion) + assert.Equal(t, int64(86400), retentionPeriodSeconds) + assert.ElementsMatch(t, []string{"cpu", "gpu"}, enabled) + assert.ElementsMatch(t, []string{"memory", "disk"}, disabled) +} diff --git a/internal/enrollment/enrollment.go b/internal/enrollment/enrollment.go index d999b5df..585ffa7d 100644 --- a/internal/enrollment/enrollment.go +++ b/internal/enrollment/enrollment.go @@ -33,6 +33,7 @@ import ( inventorysink "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/sink" inventorysource "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/source" "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" + "github.com/NVIDIA/fleet-intelligence-agent/internal/registry" ) var ( @@ -105,6 +106,13 @@ func (f machineInfoCollectorFunc) Collect(ctx context.Context) (*machineinfo.Mac func syncInventoryOnce(ctx context.Context) error { state := agentstate.NewSQLite() sink := inventorysink.NewBackendSink(state) + allComponents := registry.AllComponentNames() + + cfg, err := config.Default(ctx) + if err != nil { + return fmt.Errorf("load default config for inventory sync: %w", err) + } + apiVersion, retentionPeriodSeconds, enabledComponents, disabledComponents := cfg.InventoryAgentConfig(allComponents) nvmlInstance, err := nvidianvml.New() if err != nil { @@ -112,9 +120,18 @@ func syncInventoryOnce(ctx context.Context) error { } defer func() { _ = nvmlInstance.Shutdown() }() - src := inventorysource.NewMachineInfoSource(machineInfoCollectorFunc(func(context.Context) (*machineinfo.MachineInfo, error) { - return machineinfo.GetMachineInfo(nvmlInstance) - })) + src := inventorysource.NewMachineInfoSourceWithAgentConfig( + machineInfoCollectorFunc(func(context.Context) (*machineinfo.MachineInfo, error) { + return machineinfo.GetMachineInfo(nvmlInstance) + }), + &inventory.AgentConfig{ + TotalComponents: int64(len(allComponents)), + APIVersion: apiVersion, + RetentionPeriodSeconds: retentionPeriodSeconds, + EnabledComponents: enabledComponents, + DisabledComponents: disabledComponents, + }, + ) manager := inventory.NewManager(src, sink, 0) _, err = manager.CollectOnce(ctx) return err diff --git a/internal/exporter/collector/collector.go b/internal/exporter/collector/collector.go index ae472c11..b7cd5f70 100644 --- a/internal/exporter/collector/collector.go +++ b/internal/exporter/collector/collector.go @@ -35,8 +35,6 @@ import ( "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) -const initialMachineInfoWait = 5 * time.Second - // GenerateCollectionID generates a unique identifier for a data collection cycle func GenerateCollectionID() string { bytes := make([]byte, 16) @@ -51,14 +49,14 @@ func GenerateEventID() string { // HealthData represents the collected health data type HealthData struct { - CollectionID string - MachineID string - Timestamp time.Time - MachineInfo *machineinfo.MachineInfo - Metrics pkgmetrics.Metrics - Events eventstore.Events - ComponentData map[string]interface{} - ConfigEntries []config.ConfigEntry + CollectionID string + MachineID string + Timestamp time.Time + MachineInfo *machineinfo.MachineInfo + GPUUUIDToIndex map[string]string + Metrics pkgmetrics.Metrics + Events eventstore.Events + ComponentData map[string]interface{} } // Collector defines the interface for collecting health data @@ -68,15 +66,13 @@ type Collector interface { // collector implements the Collector interface type collector struct { - config *config.HealthExporterConfig - configEntries []config.ConfigEntry // Cached config entries computed once at startup - metricsStore pkgmetrics.Store - eventStore eventstore.Store - componentsRegistry components.Registry - nvmlInstance nvidianvml.Instance - machineID string // Agent's stable identity from server initialization - dcgmGPUIndexes map[string]string // UUID → DCGM device ID override for GPU indices - machineInfoProvider machineInfoProvider + config *config.HealthExporterConfig + metricsStore pkgmetrics.Store + eventStore eventstore.Store + componentsRegistry components.Registry + nvmlInstance nvidianvml.Instance + machineID string // Agent's stable identity from server initialization + dcgmGPUIndexes map[string]string // UUID → DCGM device ID override for GPU indices } // New creates a new health data collector @@ -92,33 +88,14 @@ func New( machineID string, dcgmGPUIndexes map[string]string, ) Collector { - // Compute config entries once at startup (no dynamic config reload) - var configEntries []config.ConfigEntry - if fullConfig != nil { - configEntries = fullConfig.ToConfigEntries(allComponentNames) - log.Logger.Infow("Config entries computed at startup", "entries_count", len(configEntries)) - } - - var provider machineInfoProvider - if cfg != nil && cfg.IncludeMachineInfo && nvmlInstance != nil { - var opts []machineinfo.MachineInfoOption - if len(dcgmGPUIndexes) > 0 { - opts = append(opts, machineinfo.WithDCGMGPUIndexes(dcgmGPUIndexes)) - } - provider = newCachedMachineInfoProvider(nvmlInstance, 0, opts...) - provider.RefreshAsync(context.Background()) - } - return &collector{ - config: cfg, - configEntries: configEntries, - metricsStore: metricsStore, - eventStore: eventStore, - componentsRegistry: componentsRegistry, - nvmlInstance: nvmlInstance, - machineID: machineID, - dcgmGPUIndexes: dcgmGPUIndexes, - machineInfoProvider: provider, + config: cfg, + metricsStore: metricsStore, + eventStore: eventStore, + componentsRegistry: componentsRegistry, + nvmlInstance: nvmlInstance, + machineID: machineID, + dcgmGPUIndexes: dcgmGPUIndexes, } } @@ -134,14 +111,10 @@ func (c *collector) Collect(ctx context.Context) (*HealthData, error) { } data := &HealthData{ - CollectionID: collectionID, - MachineID: c.machineID, - Timestamp: time.Now().UTC(), - } - - // Collect machine info if enabled - if c.config.IncludeMachineInfo { - c.collectMachineInfo(ctx, data) + CollectionID: collectionID, + MachineID: c.machineID, + Timestamp: time.Now().UTC(), + GPUUUIDToIndex: cloneStringMap(c.dcgmGPUIndexes), } // Collect metrics if enabled @@ -165,31 +138,9 @@ func (c *collector) Collect(ctx context.Context) (*HealthData, error) { } } - // Collect config data - if err := c.collectConfigData(data); err != nil { - log.Logger.Errorw("Failed to collect config data", "error", err) - } - return data, nil } -// collectMachineInfo reads cached machine info and triggers a best-effort refresh. -func (c *collector) collectMachineInfo(ctx context.Context, data *HealthData) { - if c.machineInfoProvider == nil { - return - } - - if _, ok := c.machineInfoProvider.Get(); !ok { - c.machineInfoProvider.WaitForInitialRefresh(ctx, initialMachineInfoWait) - } - - if machineInfo, ok := c.machineInfoProvider.Get(); ok { - data.MachineInfo = machineInfo - } - - c.machineInfoProvider.RefreshAsync(ctx) -} - // collectMetrics collects metrics data from the metrics store func (c *collector) collectMetrics(ctx context.Context, data *HealthData) error { if c.metricsStore == nil { @@ -325,15 +276,14 @@ func (c *collector) collectComponentData(data *HealthData) error { return nil } -// collectConfigData returns cached agent configuration entries -// Config entries are computed once at startup since there's no dynamic config reload -func (c *collector) collectConfigData(data *HealthData) error { - if len(c.configEntries) == 0 { - log.Logger.Debugw("No config entries available, skipping config data collection") +func cloneStringMap(in map[string]string) map[string]string { + if len(in) == 0 { return nil } - // Return cached config entries (computed once at startup) - data.ConfigEntries = c.configEntries - return nil + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out } diff --git a/internal/exporter/collector/collector_test.go b/internal/exporter/collector/collector_test.go index 853e3551..0ba0bb78 100644 --- a/internal/exporter/collector/collector_test.go +++ b/internal/exporter/collector/collector_test.go @@ -19,7 +19,6 @@ import ( "context" "encoding/hex" "errors" - "sync" "sync/atomic" "testing" "time" @@ -136,180 +135,6 @@ func TestCollector_CollectMachineInfo_NoNVML(t *testing.T) { assert.Nil(t, data.MachineInfo, "MachineInfo should be nil without NVML") } -func TestCollector_CollectMachineInfo_UsesCachedValue(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - expected := &machineinfo.MachineInfo{Hostname: "cached-host"} - provider := &mockMachineInfoProvider{ - cached: expected, - } - c.machineInfoProvider = provider - - data, err := c.Collect(ctx) - - require.NoError(t, err) - require.NotNil(t, data) - require.NotNil(t, data.MachineInfo) - assert.Equal(t, expected, data.MachineInfo) - assert.Equal(t, int32(1), provider.refreshCalls.Load()) -} - -func TestCollector_CollectMachineInfo_WaitsBrieflyForInitialRefresh(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - provider := newMockMachineInfoProvider() - c.machineInfoProvider = provider - - go func() { - time.Sleep(50 * time.Millisecond) - provider.setCached(&machineinfo.MachineInfo{Hostname: "prewarmed-host"}) - provider.markInitialRefreshDone() - }() - - data, err := c.Collect(ctx) - - require.NoError(t, err) - require.NotNil(t, data) - require.NotNil(t, data.MachineInfo) - assert.Equal(t, "prewarmed-host", data.MachineInfo.Hostname) -} - -func TestCollector_CollectMachineInfo_RefreshDoesNotBlockMetrics(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - IncludeMetrics: true, - MetricsLookback: metav1.Duration{Duration: 5 * time.Minute}, - } - - c := New(cfg, nil, nil, &mockMetricsStore{ - metrics: pkgmetrics.Metrics{ - {Component: "gpu", Name: "temperature", Value: 70, UnixMilliseconds: time.Now().UnixMilli()}, - }, - }, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - - blocker := make(chan struct{}) - provider := newMockMachineInfoProvider() - provider.refreshFn = func(parent context.Context) { - provider.markInitialRefreshDone() - <-blocker - } - c.machineInfoProvider = provider - - start := time.Now() - data, err := c.Collect(ctx) - elapsed := time.Since(start) - close(blocker) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Len(t, data.Metrics, 1) - assert.Nil(t, data.MachineInfo) - assert.GreaterOrEqual(t, elapsed, 4900*time.Millisecond) - assert.Less(t, elapsed, 5500*time.Millisecond) -} - -func TestCollector_CollectMachineInfo_InitialWaitDoesNotRepeatAfterFirstRefresh(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - provider := newMockMachineInfoProvider() - provider.markInitialRefreshDone() - provider.refreshFn = func(parent context.Context) {} - c.machineInfoProvider = provider - - start := time.Now() - data, err := c.Collect(ctx) - elapsed := time.Since(start) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Less(t, elapsed, 200*time.Millisecond) -} - -func TestCollector_CollectMachineInfo_InitialWaitDoesNotRepeatAfterTimeout(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - provider := newMockMachineInfoProvider() - provider.refreshFn = func(parent context.Context) {} - c.machineInfoProvider = provider - - start := time.Now() - data, err := c.Collect(ctx) - firstElapsed := time.Since(start) - - require.NoError(t, err) - require.NotNil(t, data) - assert.GreaterOrEqual(t, firstElapsed, 4900*time.Millisecond) - assert.Less(t, firstElapsed, 5500*time.Millisecond) - - start = time.Now() - data, err = c.Collect(ctx) - secondElapsed := time.Since(start) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Less(t, secondElapsed, 200*time.Millisecond) -} - -func TestCollector_CollectMachineInfo_InitialWaitHonorsContextCancellation(t *testing.T) { - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - provider := newMockMachineInfoProvider() - provider.refreshFn = func(parent context.Context) {} - c.machineInfoProvider = provider - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - start := time.Now() - data, err := c.Collect(ctx) - elapsed := time.Since(start) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Less(t, elapsed, 200*time.Millisecond) -} - -func TestCollector_CollectMachineInfo_RetainsLastGoodOnRefreshFailure(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - expected := &machineinfo.MachineInfo{Hostname: "last-good"} - provider := newMockMachineInfoProvider() - provider.cached = expected - provider.initialRefreshOnce.Do(func() { close(provider.initialRefreshDone) }) - provider.refreshFn = func(parent context.Context) {} - c.machineInfoProvider = provider - - data, err := c.Collect(ctx) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Equal(t, expected, data.MachineInfo) -} - func TestCachedMachineInfoProvider_DeduplicatesConcurrentRefresh(t *testing.T) { originalGetMachineInfo := getMachineInfo defer func() { @@ -725,83 +550,6 @@ type mockMetricsStore struct { shouldError bool } -type mockMachineInfoProvider struct { - mu sync.RWMutex - cached *machineinfo.MachineInfo - refreshFn func(context.Context) - refreshCalls atomic.Int32 - initialWaited bool - initialRefreshDone chan struct{} - initialRefreshOnce sync.Once -} - -func newMockMachineInfoProvider() *mockMachineInfoProvider { - return &mockMachineInfoProvider{ - initialRefreshDone: make(chan struct{}), - } -} - -func (m *mockMachineInfoProvider) Get() (*machineinfo.MachineInfo, bool) { - m.mu.RLock() - defer m.mu.RUnlock() - - if m.cached == nil { - return nil, false - } - return m.cached, true -} - -func (m *mockMachineInfoProvider) RefreshAsync(parent context.Context) { - m.refreshCalls.Add(1) - if m.refreshFn == nil { - return - } - go func() { - defer func() { - _ = recover() - }() - m.refreshFn(parent) - }() -} - -func (m *mockMachineInfoProvider) WaitForInitialRefresh(ctx context.Context, maxWait time.Duration) bool { - if maxWait <= 0 { - return false - } - - m.mu.Lock() - if m.initialWaited { - m.mu.Unlock() - return false - } - m.initialWaited = true - m.mu.Unlock() - - timer := time.NewTimer(maxWait) - defer timer.Stop() - - select { - case <-m.initialRefreshDone: - return true - case <-ctx.Done(): - return false - case <-timer.C: - return false - } -} - -func (m *mockMachineInfoProvider) setCached(info *machineinfo.MachineInfo) { - m.mu.Lock() - defer m.mu.Unlock() - m.cached = info -} - -func (m *mockMachineInfoProvider) markInitialRefreshDone() { - m.initialRefreshOnce.Do(func() { - close(m.initialRefreshDone) - }) -} - func (m *mockMetricsStore) Read(ctx context.Context, opts ...pkgmetrics.OpOption) (pkgmetrics.Metrics, error) { if m.shouldError { return nil, errors.New("mock metrics store error") diff --git a/internal/exporter/converter/otlp.go b/internal/exporter/converter/otlp.go index 97d48cea..ac33cc88 100644 --- a/internal/exporter/converter/otlp.go +++ b/internal/exporter/converter/otlp.go @@ -19,6 +19,7 @@ package converter import ( "encoding/json" "fmt" + "os" "reflect" "strings" "time" @@ -32,6 +33,8 @@ import ( "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/collector" ) +var osHostname = os.Hostname + // OTLPData holds both metrics and logs for OTLP export type OTLPData struct { Metrics *metricsv1.MetricsData @@ -98,7 +101,7 @@ func (c *otlpConverter) Convert(data *collector.HealthData) *OTLPData { } } -// createOTLPResource creates OTLP resource with machine info, agent config, and identification +// createOTLPResource creates a minimal OTLP resource for telemetry correlation. func (c *otlpConverter) createOTLPResource(data *collector.HealthData) *resourcev1.Resource { attributes := []*commonv1.KeyValue{ { @@ -113,35 +116,33 @@ func (c *otlpConverter) createOTLPResource(data *collector.HealthData) *resource Value: &commonv1.AnyValue_StringValue{StringValue: data.MachineID}, }, }, - { - Key: "agentConfig.totalComponents", - Value: &commonv1.AnyValue{ - Value: &commonv1.AnyValue_IntValue{IntValue: int64(len(data.ComponentData))}, - }, - }, } - // Add agent config entries as resource attributes - for _, entry := range data.ConfigEntries { + if hostname := resolveOTLPHostname(); hostname != "" { attributes = append(attributes, &commonv1.KeyValue{ - Key: "agentConfig." + entry.Key, + Key: "host.name", Value: &commonv1.AnyValue{ - Value: &commonv1.AnyValue_StringValue{StringValue: entry.Value}, + Value: &commonv1.AnyValue_StringValue{StringValue: hostname}, }, }) } - // Add machine info attributes if available using reflection - if data.MachineInfo != nil { - machineInfoAttributes := convertStructToOTLPAttributes(data.MachineInfo) - attributes = append(attributes, machineInfoAttributes...) - } - return &resourcev1.Resource{ Attributes: attributes, } } +func resolveOTLPHostname() string { + if hostname := strings.TrimSpace(os.Getenv("HOSTNAME")); hostname != "" { + return hostname + } + hostname, err := osHostname() + if err != nil { + return "" + } + return strings.TrimSpace(hostname) +} + // convertMetricsToOTLP converts health metrics to OTLP metrics format func (c *otlpConverter) convertMetricsToOTLP(data *collector.HealthData) []*metricsv1.Metric { var otlpMetrics []*metricsv1.Metric @@ -247,16 +248,17 @@ func (c *otlpConverter) convertLabelsToOTLPAttributes(labels map[string]string, return attributes } -// buildGPUUUIDToIndexMap builds a UUID → GPU index lookup from MachineInfo. +// buildGPUUUIDToIndexMap builds a UUID → GPU index lookup from the collector snapshot. func buildGPUUUIDToIndexMap(data *collector.HealthData) map[string]string { m := make(map[string]string) - if data.MachineInfo == nil || data.MachineInfo.GPUInfo == nil { + if len(data.GPUUUIDToIndex) == 0 { return m } - for _, gpu := range data.MachineInfo.GPUInfo.GPUs { - if gpu.UUID != "" && gpu.GPUIndex != "" { - m[gpu.UUID] = gpu.GPUIndex + for uuid, gpuIndex := range data.GPUUUIDToIndex { + if uuid == "" || gpuIndex == "" { + continue } + m[uuid] = gpuIndex } return m } diff --git a/internal/exporter/converter/otlp_test.go b/internal/exporter/converter/otlp_test.go index ac5ef682..529e037e 100644 --- a/internal/exporter/converter/otlp_test.go +++ b/internal/exporter/converter/otlp_test.go @@ -16,6 +16,8 @@ package converter import ( + "errors" + "os" "testing" "time" @@ -308,21 +310,12 @@ func TestOTLPConverter_Convert_WithComponentData(t *testing.T) { assert.True(t, found, "Should find component data log") } -func TestOTLPConverter_Convert_WithMachineInfo(t *testing.T) { +func TestOTLPConverter_Convert_IgnoresMachineInfoInResource(t *testing.T) { data := &collector.HealthData{ Timestamp: time.Now(), MachineID: "test-machine", MachineInfo: &machineinfo.MachineInfo{ - FleetintVersion: "0.1.5", - DCGMVersion: "4.2.3", - OSImage: "Ubuntu 22.04", - KernelVersion: "5.15.0", - CPUInfo: &apiv1.MachineCPUInfo{ - Type: "Intel", - Manufacturer: "Intel", - Architecture: "x86_64", - LogicalCores: 8, - }, + DCGMVersion: "4.2.3", }, } @@ -332,31 +325,13 @@ func TestOTLPConverter_Convert_WithMachineInfo(t *testing.T) { require.NotNil(t, otlpData) require.NotNil(t, otlpData.Metrics) - // Check resource has machine info attributes rm := otlpData.Metrics.ResourceMetrics[0] assert.NotNil(t, rm.Resource) assert.Greater(t, len(rm.Resource.Attributes), 0) - // Verify machine info is embedded in resource attributes - // The attributes may have different keys based on how machine info is flattened - attrCount := len(rm.Resource.Attributes) - assert.Greater(t, attrCount, 2, "Should have multiple resource attributes including machine info") - - // Check that some attributes exist (the exact key names may vary) - attrKeys := make([]string, 0, len(rm.Resource.Attributes)) for _, attr := range rm.Resource.Attributes { - attrKeys = append(attrKeys, attr.Key) - } - // Should have at least service.name and machine.id - hasServiceName := false - for _, key := range attrKeys { - if key == "service.name" { - hasServiceName = true - break - } + assert.NotEqual(t, "dcgmVersion", attr.Key) } - assert.True(t, hasServiceName, "Should have service.name attribute") - assert.Equal(t, "4.2.3", findAttribute(t, rm.Resource.Attributes, "dcgmVersion").GetStringValue()) } func TestOTLPConverter_ConvertStructToOTLPAttributes(t *testing.T) { @@ -647,13 +622,9 @@ func TestOTLPConverter_ConvertLabelsToOTLPAttributes_EnrichesGPUIndex(t *testing func TestBuildGPUUUIDToIndexMap(t *testing.T) { t.Run("builds map from machine info", func(t *testing.T) { data := &collector.HealthData{ - MachineInfo: &machineinfo.MachineInfo{ - GPUInfo: &apiv1.MachineGPUInfo{ - GPUs: []apiv1.MachineGPUInstance{ - {UUID: "GPU-abc-123", GPUIndex: "0"}, - {UUID: "GPU-def-456", GPUIndex: "1"}, - }, - }, + GPUUUIDToIndex: map[string]string{ + "GPU-abc-123": "0", + "GPU-def-456": "1", }, } @@ -669,9 +640,9 @@ func TestBuildGPUUUIDToIndexMap(t *testing.T) { assert.Empty(t, m) }) - t.Run("returns empty map when GPU info is nil", func(t *testing.T) { + t.Run("returns empty map when mapping is nil", func(t *testing.T) { data := &collector.HealthData{ - MachineInfo: &machineinfo.MachineInfo{}, + GPUUUIDToIndex: nil, } m := buildGPUUUIDToIndexMap(data) assert.Empty(t, m) @@ -679,14 +650,10 @@ func TestBuildGPUUUIDToIndexMap(t *testing.T) { t.Run("skips entries with empty uuid or index", func(t *testing.T) { data := &collector.HealthData{ - MachineInfo: &machineinfo.MachineInfo{ - GPUInfo: &apiv1.MachineGPUInfo{ - GPUs: []apiv1.MachineGPUInstance{ - {UUID: "GPU-abc-123", GPUIndex: "0"}, - {UUID: "", GPUIndex: "1"}, - {UUID: "GPU-ghi-789", GPUIndex: ""}, - }, - }, + GPUUUIDToIndex: map[string]string{ + "GPU-abc-123": "0", + "": "1", + "GPU-ghi-789": "", }, } @@ -780,6 +747,37 @@ func TestOTLPConverter_Interface(t *testing.T) { assert.NotNil(t, converter) } +func TestResolveOTLPHostname(t *testing.T) { + origHostEnv, hadHostEnv := os.LookupEnv("HOSTNAME") + origOSHostname := osHostname + t.Cleanup(func() { + osHostname = origOSHostname + if hadHostEnv { + _ = os.Setenv("HOSTNAME", origHostEnv) + } else { + _ = os.Unsetenv("HOSTNAME") + } + }) + + t.Run("falls back to hostname env", func(t *testing.T) { + _ = os.Setenv("HOSTNAME", "pod-host-a") + osHostname = func() (string, error) { return "os-host-a", nil } + assert.Equal(t, "pod-host-a", resolveOTLPHostname()) + }) + + t.Run("falls back to os hostname", func(t *testing.T) { + _ = os.Unsetenv("HOSTNAME") + osHostname = func() (string, error) { return "os-host-a", nil } + assert.Equal(t, "os-host-a", resolveOTLPHostname()) + }) + + t.Run("returns empty on hostname error", func(t *testing.T) { + _ = os.Unsetenv("HOSTNAME") + osHostname = func() (string, error) { return "", errors.New("boom") } + assert.Equal(t, "", resolveOTLPHostname()) + }) +} + func TestOTLPConverter_Convert_AllData(t *testing.T) { // Test with all data types combined data := &collector.HealthData{ @@ -797,9 +795,6 @@ func TestOTLPConverter_Convert_AllData(t *testing.T) { "reason": "All OK", }, }, - MachineInfo: &machineinfo.MachineInfo{ - FleetintVersion: "0.1.5", - }, } converter := NewOTLPConverter() @@ -818,7 +813,6 @@ func TestOTLPConverter_Convert_AllData(t *testing.T) { rl := otlpData.Logs.ResourceLogs[0] assert.Greater(t, len(rl.ScopeLogs[0].LogRecords), 0) - // Verify resource has attributes from machine info assert.Greater(t, len(rm.Resource.Attributes), 0) } diff --git a/internal/exporter/options.go b/internal/exporter/options.go index f89e8d95..dfc22438 100644 --- a/internal/exporter/options.go +++ b/internal/exporter/options.go @@ -180,10 +180,6 @@ func (c *exporterOptions) validate() error { return errors.New("components registry is required when IncludeComponentData is enabled") } - if c.config.IncludeMachineInfo && c.nvmlInstance == nil { - return errors.New("NVML instance is required when IncludeMachineInfo is enabled") - } - // Machine ID is always required - it should be set by server via WithMachineID if c.machineID == "" { return errors.New("machine ID is required - must be set via WithMachineID()") diff --git a/internal/exporter/options_test.go b/internal/exporter/options_test.go index e8d0cad8..3f614770 100644 --- a/internal/exporter/options_test.go +++ b/internal/exporter/options_test.go @@ -307,18 +307,6 @@ func TestExporterOptionsValidate(t *testing.T) { wantErr: true, expectedErr: "components registry is required when IncludeComponentData is enabled", }, - { - name: "machine info enabled but no NVML instance", - setupOpts: func() *exporterOptions { - return &exporterOptions{ - config: &config.HealthExporterConfig{ - IncludeMachineInfo: true, - }, - } - }, - wantErr: true, - expectedErr: "NVML instance is required when IncludeMachineInfo is enabled", - }, } for _, tt := range tests { diff --git a/internal/inventory/mapper/backend.go b/internal/inventory/mapper/backend.go index 7f68d017..d5966430 100644 --- a/internal/inventory/mapper/backend.go +++ b/internal/inventory/mapper/backend.go @@ -64,7 +64,14 @@ func ToNodeUpsertRequest(s *inventory.Snapshot) *backendclient.NodeUpsertRequest } return &backendclient.NodeUpsertRequest{ - Hostname: s.Hostname, + Hostname: s.Hostname, + AgentConfig: backendclient.AgentConfig{ + TotalComponents: s.AgentConfig.TotalComponents, + APIVersion: s.AgentConfig.APIVersion, + RetentionPeriodSeconds: s.AgentConfig.RetentionPeriodSeconds, + EnabledComponents: append([]string(nil), s.AgentConfig.EnabledComponents...), + DisabledComponents: append([]string(nil), s.AgentConfig.DisabledComponents...), + }, MachineID: s.MachineID, SystemUUID: s.SystemUUID, BootID: s.BootID, diff --git a/internal/inventory/mapper/backend_test.go b/internal/inventory/mapper/backend_test.go index 8dc5868d..4fb88192 100644 --- a/internal/inventory/mapper/backend_test.go +++ b/internal/inventory/mapper/backend_test.go @@ -45,6 +45,13 @@ func TestToNodeUpsertRequest(t *testing.T) { NetPrivateIP: "10.0.0.10", NetPublicIP: "203.0.113.10", InventoryHash: "hash-1", + AgentConfig: inventory.AgentConfig{ + TotalComponents: 30, + APIVersion: "v1", + RetentionPeriodSeconds: 86400, + EnabledComponents: []string{"cpu", "gpu"}, + DisabledComponents: []string{"disk"}, + }, Resources: inventory.Resources{ CPUInfo: inventory.CPUInfo{ Type: "Xeon", @@ -99,6 +106,11 @@ func TestToNodeUpsertRequest(t *testing.T) { require.Equal(t, "machine-id", req.MachineID) require.Equal(t, "203.0.113.10", req.NetPublicIP) require.Equal(t, "hash-1", req.InventoryHash) + require.Equal(t, int64(30), req.AgentConfig.TotalComponents) + require.Equal(t, "v1", req.AgentConfig.APIVersion) + require.Equal(t, int64(86400), req.AgentConfig.RetentionPeriodSeconds) + require.Equal(t, []string{"cpu", "gpu"}, req.AgentConfig.EnabledComponents) + require.Equal(t, []string{"disk"}, req.AgentConfig.DisabledComponents) require.Equal(t, int64(64), req.Resources.CPUInfo.LogicalCores) require.Equal(t, uint64(1024), req.Resources.MemoryInfo.TotalBytes) require.Len(t, req.Resources.GPUInfo.GPUs, 1) diff --git a/internal/inventory/source/source.go b/internal/inventory/source/source.go index 8d84a99a..997b0fbc 100644 --- a/internal/inventory/source/source.go +++ b/internal/inventory/source/source.go @@ -31,7 +31,8 @@ type MachineInfoCollector interface { } type machineInfoSource struct { - collector MachineInfoCollector + collector MachineInfoCollector + agentConfig inventory.AgentConfig } // NewMachineInfoSource wraps the machine inventory collector as an inventory source. @@ -39,6 +40,19 @@ func NewMachineInfoSource(collector MachineInfoCollector) inventory.Source { return &machineInfoSource{collector: collector} } +// NewMachineInfoSourceWithAgentConfig wraps the machine inventory collector and attaches useful +// agent configuration that should travel with inventory rather than OTLP telemetry. +func NewMachineInfoSourceWithAgentConfig(collector MachineInfoCollector, agentConfig *inventory.AgentConfig) inventory.Source { + var cfg inventory.AgentConfig + if agentConfig != nil { + cfg = *agentConfig + } + return &machineInfoSource{ + collector: collector, + agentConfig: cfg, + } +} + func (s *machineInfoSource) Collect(ctx context.Context) (*inventory.Snapshot, error) { if s.collector == nil { return nil, fmt.Errorf("machine info collector is required") @@ -66,6 +80,7 @@ func (s *machineInfoSource) Collect(ctx context.Context) (*inventory.Snapshot, e CUDAVersion: info.CUDAVersion, DCGMVersion: info.DCGMVersion, ContainerRuntimeVersion: info.ContainerRuntimeVersion, + AgentConfig: s.agentConfig, } if info.CPUInfo != nil { diff --git a/internal/inventory/source/source_test.go b/internal/inventory/source/source_test.go index 836d40f7..43c6ece3 100644 --- a/internal/inventory/source/source_test.go +++ b/internal/inventory/source/source_test.go @@ -22,6 +22,7 @@ import ( apiv1 "github.com/NVIDIA/fleet-intelligence-sdk/api/v1" "github.com/stretchr/testify/require" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) @@ -112,3 +113,30 @@ func TestMachineInfoSourceCollect(t *testing.T) { require.Len(t, snap.Resources.DiskInfo.BlockDevices, 1) require.Equal(t, "eth0", snap.Resources.NICInfo.PrivateIPInterfaces[0].Interface) } + +func TestMachineInfoSourceCollectWithAgentConfig(t *testing.T) { + src := NewMachineInfoSourceWithAgentConfig( + fakeMachineInfoCollector{ + info: &machineinfo.MachineInfo{ + MachineID: "machine-id", + Hostname: "host-a", + }, + }, + &inventory.AgentConfig{ + TotalComponents: 42, + APIVersion: "v1", + RetentionPeriodSeconds: 86400, + EnabledComponents: []string{"cpu", "gpu"}, + DisabledComponents: []string{"disk"}, + }, + ) + + snap, err := src.Collect(context.Background()) + require.NoError(t, err) + require.NotNil(t, snap) + require.Equal(t, int64(42), snap.AgentConfig.TotalComponents) + require.Equal(t, "v1", snap.AgentConfig.APIVersion) + require.Equal(t, int64(86400), snap.AgentConfig.RetentionPeriodSeconds) + require.Equal(t, []string{"cpu", "gpu"}, snap.AgentConfig.EnabledComponents) + require.Equal(t, []string{"disk"}, snap.AgentConfig.DisabledComponents) +} diff --git a/internal/inventory/types.go b/internal/inventory/types.go index 998debee..5e10116e 100644 --- a/internal/inventory/types.go +++ b/internal/inventory/types.go @@ -44,9 +44,18 @@ type Snapshot struct { ContainerRuntimeVersion string NetPrivateIP string NetPublicIP string + AgentConfig AgentConfig Resources Resources } +type AgentConfig struct { + TotalComponents int64 + APIVersion string + RetentionPeriodSeconds int64 + EnabledComponents []string + DisabledComponents []string +} + type Resources struct { CPUInfo CPUInfo MemoryInfo MemoryInfo From 960dc02921c8897dc3d77b55bdfda71193446875 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Thu, 16 Apr 2026 15:41:14 -0700 Subject: [PATCH 09/22] refactor: simplify inventory and attestation workflows Signed-off-by: Jingxiang Zhang --- SECURITY.md | 2 +- cmd/fleetint/run.go | 35 +- .../helm/fleet-intelligence-agent/README.md | 5 +- .../helm/fleet-intelligence-agent/values.yaml | 5 +- deployments/packages/systemd/fleetint.env | 5 +- docs/configuration.md | 8 +- internal/attestation/attestation.go | 613 -------------- internal/attestation/attestation_test.go | 778 ------------------ .../backend.go | 4 +- .../backend_test.go | 32 +- .../collector.go | 2 +- .../collector_test.go | 2 +- .../manager.go | 118 ++- .../manager_test.go | 88 +- .../{attestationloop => attestation}/nonce.go | 2 +- .../nonce_test.go | 2 +- .../{attestationloop => attestation}/types.go | 46 +- internal/config/config.go | 99 ++- internal/config/config_test.go | 115 ++- internal/config/default.go | 17 +- internal/enrollment/enrollment.go | 2 +- internal/exporter/collector/collector_test.go | 2 - internal/inventory/manager.go | 107 ++- internal/inventory/manager_run_test.go | 54 +- internal/inventory/manager_test.go | 2 +- internal/inventory/types.go | 7 + internal/server/server.go | 113 +++ internal/server/server_test.go | 88 ++ 28 files changed, 797 insertions(+), 1556 deletions(-) delete mode 100644 internal/attestation/attestation.go delete mode 100644 internal/attestation/attestation_test.go rename internal/{attestationloop => attestation}/backend.go (96%) rename internal/{attestationloop => attestation}/backend_test.go (90%) rename internal/{attestationloop => attestation}/collector.go (99%) rename internal/{attestationloop => attestation}/collector_test.go (99%) rename internal/{attestationloop => attestation}/manager.go (69%) rename internal/{attestationloop => attestation}/manager_test.go (64%) rename internal/{attestationloop => attestation}/nonce.go (98%) rename internal/{attestationloop => attestation}/nonce_test.go (98%) rename internal/{attestationloop => attestation}/types.go (59%) diff --git a/SECURITY.md b/SECURITY.md index 693f32a7..a32655dd 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -49,7 +49,7 @@ This repository contains the `fleetint` host agent. The notes below are intended - Exporter token refresh and endpoint reload: [`internal/exporter/exporter.go`](internal/exporter/exporter.go) - Local HTTP server routes: [`internal/server/server.go`](internal/server/server.go) - Optional fault injection handler: [`internal/server/handlers_inject_fault.go`](internal/server/handlers_inject_fault.go) -- Remote attestation and `nvattest` invocation: [`internal/attestation/attestation.go`](internal/attestation/attestation.go) +- Remote attestation and `nvattest` invocation: [`internal/attestation/manager.go`](internal/attestation/manager.go), [`internal/attestation/collector.go`](internal/attestation/collector.go) ### Threat Model diff --git a/cmd/fleetint/run.go b/cmd/fleetint/run.go index 466fbbb1..aaede648 100644 --- a/cmd/fleetint/run.go +++ b/cmd/fleetint/run.go @@ -202,15 +202,6 @@ func configureHealthExporterFromEnv(cfg *config.Config) error { return err } - // FLEETINT_ATTESTATION_JITTER_ENABLED - Enable/disable attestation jitter - if err := setBoolFromEnv("FLEETINT_ATTESTATION_JITTER_ENABLED", &he.Attestation.JitterEnabled, "set attestation jitter enabled from env", "attestation_jitter_enabled"); err != nil { - return err - } - - if err := setDurationFromEnv("FLEETINT_ATTESTATION_INTERVAL", &he.Attestation.Interval, "set attestation interval from env", "attestation_interval", 0, 0); err != nil { - return err - } - // Lookbacks if err := setDurationFromEnv("FLEETINT_METRICS_LOOKBACK", &he.MetricsLookback, "set health exporter metrics lookback from env", "metrics_lookback", 0, 0); err != nil { return err @@ -231,6 +222,29 @@ func configureHealthExporterFromEnv(cfg *config.Config) error { return nil } +func configureLoopConfigFromEnv(cfg *config.Config) error { + if cfg.Inventory != nil { + if err := setBoolFromEnv("FLEETINT_INVENTORY_ENABLED", &cfg.Inventory.Enabled, "set inventory enabled from env", "inventory_enabled"); err != nil { + return err + } + if err := setDurationFromEnv("FLEETINT_INVENTORY_INTERVAL", &cfg.Inventory.Interval, "set inventory interval from env", "inventory_interval", time.Minute, 0); err != nil { + return err + } + } + if cfg.Attestation != nil { + if err := setBoolFromEnv("FLEETINT_ATTESTATION_ENABLED", &cfg.Attestation.Enabled, "set attestation enabled from env", "attestation_enabled"); err != nil { + return err + } + if err := setDurationFromEnv("FLEETINT_ATTESTATION_INITIAL_INTERVAL", &cfg.Attestation.InitialInterval, "set attestation initial interval from env", "attestation_initial_interval", time.Minute, 0); err != nil { + return err + } + if err := setDurationFromEnv("FLEETINT_ATTESTATION_INTERVAL", &cfg.Attestation.Interval, "set attestation interval from env", "attestation_interval", time.Minute, 0); err != nil { + return err + } + } + return nil +} + func runCommand(cliContext *cli.Context) error { logLevel := cliContext.String("log-level") logFile := cliContext.String("log-file") @@ -310,6 +324,9 @@ func runCommand(cliContext *cli.Context) error { if err := configureHealthExporterFromEnv(cfg); err != nil { return fmt.Errorf("failed to configure health exporter from environment variables: %w", err) } + if err := configureLoopConfigFromEnv(cfg); err != nil { + return fmt.Errorf("failed to configure loop settings from environment variables: %w", err) + } log.Logger.Infow("health exporter configuration", "cfg", cfg.HealthExporter) if listenAddress != "" { diff --git a/deployments/helm/fleet-intelligence-agent/README.md b/deployments/helm/fleet-intelligence-agent/README.md index ddec247a..c6d75b28 100644 --- a/deployments/helm/fleet-intelligence-agent/README.md +++ b/deployments/helm/fleet-intelligence-agent/README.md @@ -33,7 +33,10 @@ Common values (defaults from `values.yaml`): | `env.FLEETINT_EVENTS_LOOKBACK` | `"1m"` | Lookback window for events export. | | `env.FLEETINT_CHECK_INTERVAL` | `"1m"` | Health check frequency (1s to 24h). | | `env.FLEETINT_RETRY_MAX_ATTEMPTS` | `"3"` | Max retry attempts for failed exports. | -| `env.FLEETINT_ATTESTATION_JITTER_ENABLED` | `"true"` | Enable/disable attestation jitter. | +| `env.FLEETINT_INVENTORY_ENABLED` | `"true"` | Enable or disable the inventory loop. | +| `env.FLEETINT_INVENTORY_INTERVAL` | `"1h"` | Inventory loop interval override. | +| `env.FLEETINT_ATTESTATION_ENABLED` | `"true"` | Enable or disable the attestation loop. | +| `env.FLEETINT_ATTESTATION_INITIAL_INTERVAL` | `"5m"` | Initial attestation bootstrap interval before the first successful attestation. | | `env.FLEETINT_ATTESTATION_INTERVAL` | `"24h"` | Attestation interval override. | | `env.HTTP_PROXY` | `""` | Optional HTTP proxy for outbound requests. | | `env.HTTPS_PROXY` | `""` | Optional HTTPS proxy for outbound requests. | diff --git a/deployments/helm/fleet-intelligence-agent/values.yaml b/deployments/helm/fleet-intelligence-agent/values.yaml index 0e1b75ee..f7d2150b 100644 --- a/deployments/helm/fleet-intelligence-agent/values.yaml +++ b/deployments/helm/fleet-intelligence-agent/values.yaml @@ -40,7 +40,10 @@ env: FLEETINT_EVENTS_LOOKBACK: "1m" FLEETINT_CHECK_INTERVAL: "1m" FLEETINT_RETRY_MAX_ATTEMPTS: "3" - FLEETINT_ATTESTATION_JITTER_ENABLED: "true" + FLEETINT_INVENTORY_ENABLED: "true" + FLEETINT_INVENTORY_INTERVAL: "1h" + FLEETINT_ATTESTATION_ENABLED: "true" + FLEETINT_ATTESTATION_INITIAL_INTERVAL: "5m" FLEETINT_ATTESTATION_INTERVAL: "24h" HTTP_PROXY: "" HTTPS_PROXY: "" diff --git a/deployments/packages/systemd/fleetint.env b/deployments/packages/systemd/fleetint.env index e589d782..b06c47d7 100644 --- a/deployments/packages/systemd/fleetint.env +++ b/deployments/packages/systemd/fleetint.env @@ -11,7 +11,10 @@ FLEETINT_METRICS_LOOKBACK="1m" FLEETINT_EVENTS_LOOKBACK="1m" FLEETINT_CHECK_INTERVAL="1m" FLEETINT_RETRY_MAX_ATTEMPTS="3" -FLEETINT_ATTESTATION_JITTER_ENABLED="true" +FLEETINT_INVENTORY_ENABLED="true" +FLEETINT_INVENTORY_INTERVAL="1h" +FLEETINT_ATTESTATION_ENABLED="true" +FLEETINT_ATTESTATION_INITIAL_INTERVAL="5m" FLEETINT_ATTESTATION_INTERVAL="24h" HTTP_PROXY="" HTTPS_PROXY="" diff --git a/docs/configuration.md b/docs/configuration.md index 9ec52f38..b3324345 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -50,7 +50,10 @@ These environment variables are read by `fleetint run` at startup. | `FLEETINT_EVENTS_LOOKBACK` | Lookback window for events included in each export. | `1m` | `/etc/default/fleetint` | `env.FLEETINT_EVENTS_LOOKBACK` | | `FLEETINT_CHECK_INTERVAL` | Health check interval for monitored components. Valid range: `1s` to `24h`. | `1m` | `/etc/default/fleetint` | `env.FLEETINT_CHECK_INTERVAL` | | `FLEETINT_RETRY_MAX_ATTEMPTS` | Maximum retry attempts for failed exports. Minimum: `0`. | `3` | `/etc/default/fleetint` | `env.FLEETINT_RETRY_MAX_ATTEMPTS` | -| `FLEETINT_ATTESTATION_JITTER_ENABLED` | Enable random startup jitter for attestation scheduling. | `true` | `/etc/default/fleetint` | `env.FLEETINT_ATTESTATION_JITTER_ENABLED` | +| `FLEETINT_INVENTORY_ENABLED` | Enable or disable the inventory loop. | `true` | `/etc/default/fleetint` | `env.FLEETINT_INVENTORY_ENABLED` | +| `FLEETINT_INVENTORY_INTERVAL` | Inventory loop interval override. Minimum: `1m`. | `1h` | `/etc/default/fleetint` | `env.FLEETINT_INVENTORY_INTERVAL` | +| `FLEETINT_ATTESTATION_ENABLED` | Enable or disable the attestation loop. | `true` | `/etc/default/fleetint` | `env.FLEETINT_ATTESTATION_ENABLED` | +| `FLEETINT_ATTESTATION_INITIAL_INTERVAL` | Initial attestation bootstrap interval before the first successful attestation. Minimum: `1m`. | `5m` | `/etc/default/fleetint` | `env.FLEETINT_ATTESTATION_INITIAL_INTERVAL` | | `FLEETINT_ATTESTATION_INTERVAL` | Attestation interval override. | `24h` | `/etc/default/fleetint` | `env.FLEETINT_ATTESTATION_INTERVAL` | | `HTTP_PROXY` | Proxy URL for outbound HTTP requests. | empty | `/etc/default/fleetint` | `env.HTTP_PROXY` | | `HTTPS_PROXY` | Proxy URL for outbound HTTPS requests. | empty | `/etc/default/fleetint` | `env.HTTPS_PROXY` | @@ -58,7 +61,7 @@ These environment variables are read by `fleetint run` at startup. Notes: - Duration-valued environment variables use Go duration syntax such as `30s`, `1m`, `10m`, or `24h`. -- These environment variables modify the health exporter configuration used by `fleetint run`. +- These environment variables modify the telemetry exporter configuration and runtime loop intervals used by `fleetint run`. - `DCGM_URL` and `DCGM_URL_IS_UNIX_SOCKET` configure connectivity to DCGM HostEngine for DCGM-backed components. ### Bare Metal Example @@ -95,6 +98,7 @@ env: FLEETINT_COLLECT_INTERVAL: "2m" FLEETINT_INCLUDE_EVENTS: "false" FLEETINT_CHECK_INTERVAL: "30s" + FLEETINT_INVENTORY_INTERVAL: "2m" HTTPS_PROXY: "http://proxy.example.com:3128" ``` diff --git a/internal/attestation/attestation.go b/internal/attestation/attestation.go deleted file mode 100644 index e07441f0..00000000 --- a/internal/attestation/attestation.go +++ /dev/null @@ -1,613 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package attestation provides functionality for GPU attestation -package attestation - -import ( - "bytes" - "context" - "crypto/rand" - "encoding/json" - "fmt" - "math/big" - "net/http" - "os/exec" - "sync" - "time" - - pkgfile "github.com/NVIDIA/fleet-intelligence-sdk/pkg/file" - "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" - pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" - nvidianvml "github.com/NVIDIA/fleet-intelligence-sdk/pkg/nvidia-query/nvml" - "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" - - "github.com/NVIDIA/fleet-intelligence-agent/internal/config" - "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" -) - -var defaultStateFileFn = config.DefaultStateFile - -// EvidenceItem represents a single evidence item from the attestation SDK -type EvidenceItem struct { - Arch string `json:"arch"` - Certificate string `json:"certificate"` - DriverVersion string `json:"driver_version"` - Evidence string `json:"evidence"` - Nonce string `json:"nonce"` - VBIOSVersion string `json:"vbios_version"` - Version string `json:"version"` -} - -// AttestationSDKResponse represents the complete response from the attestation SDK -type AttestationSDKResponse struct { - Evidences []EvidenceItem `json:"evidences"` - ResultCode int `json:"result_code"` - ResultMessage string `json:"result_message"` -} - -// AttestationData represents the complete attestation information including SDK response and timestamp -type AttestationData struct { - SDKResponse AttestationSDKResponse `json:"sdk_response"` - NonceRefreshTimestamp time.Time `json:"nonce_refresh_timestamp"` - Success bool `json:"success"` - ErrorMessage string `json:"error_message,omitempty"` -} - -// Manager manages the attestation process with configurable intervals -type Manager struct { - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup - cache *cache - nvmlInstance nvidianvml.Instance - config *config.AttestationConfig -} - -// cache holds the latest attestation results with thread-safe access -type cache struct { - mu sync.RWMutex - attestationData *AttestationData - lastUpdated time.Time -} - -// GetAttestationData returns the current attestation data thread-safely -func (c *cache) GetAttestationData() *AttestationData { - c.mu.RLock() - defer c.mu.RUnlock() - return c.attestationData -} - -// updateAttestationData updates the attestation cache thread-safely -func (c *cache) updateAttestationData(attestationData *AttestationData) { - c.mu.Lock() - defer c.mu.Unlock() - c.attestationData = attestationData - c.lastUpdated = time.Now().UTC() -} - -// NewManager creates a new attestation manager -func NewManager(ctx context.Context, nvmlInstance nvidianvml.Instance, config *config.AttestationConfig) *Manager { - cctx, cancel := context.WithCancel(ctx) - - // Use 24 hours as default if not specified or invalid - if config.Interval.Duration <= 0 { - config.Interval.Duration = 24 * time.Hour - } - - return &Manager{ - ctx: cctx, - cancel: cancel, - cache: &cache{}, - nvmlInstance: nvmlInstance, - config: config, - } -} - -// GetAttestationData returns the current attestation data directly -func (m *Manager) GetAttestationData() *AttestationData { - return m.cache.GetAttestationData() -} - -// IsAttestationDataUpdated checks if attestation data has been updated since the given time -func (m *Manager) IsAttestationDataUpdated(since time.Time) bool { - return m.cache.isUpdatedSince(since) -} - -// isUpdatedSince checks if the cache has been updated since the given time -func (c *cache) isUpdatedSince(since time.Time) bool { - c.mu.RLock() - defer c.mu.RUnlock() - return c.lastUpdated.After(since) -} - -// retryInterval is the shorter interval used when agent is not enrolled yet -const retryInterval = 5 * time.Minute - -// Start begins the attestation loop with jitter to prevent thundering herd -// Uses dynamic intervals: shorter retry interval when not enrolled, normal interval otherwise -func (m *Manager) Start() { - log.Logger.Infow("Starting attestation manager with thundering herd prevention") - - m.wg.Add(1) - go func() { - defer m.wg.Done() - // Add initial jitter to spread out startup requests (0-30 minutes) if enabled - var initialJitter time.Duration - if m.config.JitterEnabled { - initialJitter = m.calculateJitter(30 * time.Minute) - log.Logger.Infow("Adding initial startup jitter to prevent thundering herd", - "jitter_duration", initialJitter) - } else { - log.Logger.Infow("Startup jitter disabled for testing") - } - - // Wait for initial jitter before first attestation - select { - case <-m.ctx.Done(): - log.Logger.Infow("Context done during initial jitter") - return - case <-time.After(initialJitter): - // Continue to first attestation - } - - // Run first attestation with additional jitter - shouldRetrySoon := m.runAttestationWithJitter() - - // Create ticker with configurable interval (default 24 hours) - // Will be reset after each attestation based on result - ticker := time.NewTicker(m.getNextInterval(shouldRetrySoon)) - defer ticker.Stop() - - log.Logger.Infow("Attestation ticker started", "interval", m.getNextInterval(shouldRetrySoon)) - - for { - select { - case <-m.ctx.Done(): - log.Logger.Infow("Context done, stopping attestation manager") - return - case <-ticker.C: - shouldRetrySoon = m.runAttestationWithJitter() - nextInterval := m.getNextInterval(shouldRetrySoon) - ticker.Reset(nextInterval) - } - } - }() -} - -// getNextInterval returns the appropriate interval based on whether we should retry soon -func (m *Manager) getNextInterval(shouldRetrySoon bool) time.Duration { - if shouldRetrySoon { - // Use the shorter of retryInterval and configured interval - // to avoid slowing down attestation in fast-retry environments (e.g., testing) - interval := retryInterval - if m.config.Interval.Duration < retryInterval { - interval = m.config.Interval.Duration - } - log.Logger.Infow("Agent not enrolled, using retry interval", - "retry_interval", interval) - return interval - } - log.Logger.Infow("Using normal attestation interval", - "interval", m.config.Interval.Duration) - return m.config.Interval.Duration -} - -// Stop gracefully shuts down the attestation manager and waits for the -// background goroutine to exit. This ensures that any in-progress call -// to defaultStateFileFn (or any other shared state) finishes before Stop -// returns, which prevents data races in tests and orderly cleanup in production. -func (m *Manager) Stop() { - log.Logger.Infow("Stopping attestation manager") - m.cancel() - m.wg.Wait() -} - -// runAttestation performs the attestation process and updates the cache -// Returns true if attestation should be retried soon (e.g., agent not enrolled yet) -func (m *Manager) runAttestation() bool { - log.Logger.Infow("Starting attestation process") - - // Always update cache with result (success or failure) so server knows status - attestationData := &AttestationData{} - - // Step 1: Get machine ID - log.Logger.Debugw("Getting machine ID in Attestation") - machineId, err := m.getMachineId() - if err != nil { - log.Logger.Errorw("Failed to get machine ID in Attestation", "error", err) - // SDKResponse and NonceRefreshTimestamp are nil - attestationData.Success = false - attestationData.ErrorMessage = err.Error() - m.cache.updateAttestationData(attestationData) - return false - } - - log.Logger.Debugw("Machine ID retrieved in Attestation", - "machine_id", machineId) - - // Step 2: Load JWT token from metadata database - jwtToken := m.getJWTTokenFromMetadata(m.ctx) - if jwtToken == "" { - if endpoint := m.getEndpointFromMetadata(m.ctx); endpoint != "" { - log.Logger.Errorw("JWT token not found in metadata", "endpoint", endpoint) - // SDKResponse and NonceRefreshTimestamp are nil - attestationData.Success = false - attestationData.ErrorMessage = "JWT token not found in agent metadata while agent is enrolled" - m.cache.updateAttestationData(attestationData) - return false - } else { - log.Logger.Infow("No backend endpoint found in metadata, agent not enrolled, will retry soon") - // SDKResponse and NonceRefreshTimestamp are nil - attestationData.Success = false - attestationData.ErrorMessage = "No backend endpoint found in metadata, agent is not enrolled" - m.cache.updateAttestationData(attestationData) - return true // Retry soon - agent may enroll shortly - } - } - - // Step 3: Get nonce by calling the envoy endpoint - nonce, nonceRefreshTimestamp, err := m.getNonce(jwtToken, machineId) - if err != nil { - // if agent is not enrolled, it will return in step 2. so here we can directly return the nonce error - log.Logger.Errorw("Failed to get nonce", "error", err) - // SDKResponse and NonceRefreshTimestamp are nil - attestationData.Success = false - attestationData.ErrorMessage = err.Error() - m.cache.updateAttestationData(attestationData) - return false - } - - // Update nonce refresh timestamp with actual server response - attestationData.NonceRefreshTimestamp = nonceRefreshTimestamp - - // Step 4: Get evidences from attestation SDK - log.Logger.Debugw("Getting evidences with nonce") - sdkResponse, err := m.getEvidences(nonce) - if err != nil { - log.Logger.Errorw("Failed to get evidences from attestation SDK", "error", err) - // SDKResponse - attestationData.Success = false - attestationData.ErrorMessage = err.Error() - m.cache.updateAttestationData(attestationData) - return false - } - - // Success case: populate all data - attestationData.SDKResponse = *sdkResponse - attestationData.Success = true - attestationData.ErrorMessage = "" - log.Logger.Debugw("Attestation data", "attestation_data", attestationData) - - // Update the attestation cache - m.cache.updateAttestationData(attestationData) - return false -} - -func (m *Manager) getNonce(jwtToken string, machineId string) (string, time.Time, error) { - endpointURL, err := m.getValidatedNonceEndpoint(m.ctx) - if err != nil { - return "", time.Time{}, err - } - - // Request payload (only include machine ID, JWT token goes in header) - requestBody, err := json.Marshal(map[string]string{ - "uuid": machineId, - }) - if err != nil { - log.Logger.Debugw("failed to marshal request body in nonce endpoint request", "error", err) - return "", time.Time{}, err - } - - // Create HTTP request tied to the manager context so that Stop() cancellation - // unblocks the request and prevents wg.Wait() from hanging indefinitely. - req, err := http.NewRequestWithContext(m.ctx, "POST", endpointURL, bytes.NewBuffer(requestBody)) - if err != nil { - log.Logger.Debugw("failed to create HTTP request in nonce endpoint request", "error", err) - return "", time.Time{}, err - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwtToken)) - - // Make the HTTP request. Disable redirects so a compromised backend - // cannot bounce us to an internal service (SSRF). - client := &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - }, - } - resp, err := client.Do(req) - if err != nil { - log.Logger.Debugw("failed to make POST request in nonce endpoint request", "error", err) - return "", time.Time{}, err - } - defer resp.Body.Close() - - // Parsing the response - var response struct { - Nonce string `json:"nonce"` - NonceRefreshTimestamp time.Time `json:"nonce_refresh_timestamp"` - Error string `json:"error"` - } - - log.Logger.Debugw("HTTP Response received:", - "status_code", resp.StatusCode, - "status", resp.Status, - "content_type", resp.Header.Get("Content-Type")) - - if resp.StatusCode != http.StatusOK { - return "", time.Time{}, fmt.Errorf("nonce endpoint returned HTTP %d", resp.StatusCode) - } - - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { - log.Logger.Debugw("failed to decode response in nonce endpoint request", "error", err) - return "", time.Time{}, err - } - - if response.Error != "" { - log.Logger.Debugw("error from server in nonce endpoint request", "error", response.Error) - return "", time.Time{}, fmt.Errorf("nonce endpoint returned error: %s", response.Error) - } - - log.Logger.Debugw("Nonce received from server", "nonce_refresh_timestamp", response.NonceRefreshTimestamp) - - return response.Nonce, response.NonceRefreshTimestamp, nil -} - -func (m *Manager) getValidatedNonceEndpoint(ctx context.Context) (string, error) { - baseEndpoint := m.getEndpointFromMetadata(ctx) - if baseEndpoint != "" { - validated, err := endpoint.ValidateBackendEndpoint(baseEndpoint) - if err != nil { - return "", fmt.Errorf("invalid backend endpoint: %w", err) - } - - joined, err := endpoint.JoinPath(validated, "api", "v1", "health", "nonce") - if err != nil { - return "", fmt.Errorf("failed to construct nonce endpoint: %w", err) - } - return joined, nil - } - - legacyNonceEndpoint := m.getLegacyNonceEndpointFromMetadata(ctx) - if legacyNonceEndpoint == "" { - return "", fmt.Errorf("backend endpoint not found in metadata") - } - - validated, err := endpoint.ValidateBackendEndpoint(legacyNonceEndpoint) - if err != nil { - return "", fmt.Errorf("invalid nonce endpoint: %w", err) - } - return validated.String(), nil -} - -// getJWTTokenFromMetadata retrieves the JWT token from the metadata database -func (m *Manager) getJWTTokenFromMetadata(ctx context.Context) string { - stateFile, err := defaultStateFileFn() - if err != nil { - log.Logger.Debugw("failed to get state file path", "error", err) - return "" - } - - dbRO, err := sqlite.Open(stateFile) - if err != nil { - log.Logger.Debugw("failed to open state database", "error", err) - return "" - } - defer dbRO.Close() - - // Load JWT token from metadata - if token, err := pkgmetadata.ReadMetadata(ctx, dbRO, pkgmetadata.MetadataKeyToken); err == nil && token != "" { - return token - } - - log.Logger.Debugw("JWT token not found in metadata") - return "" -} - -func (m *Manager) getEndpointFromMetadata(ctx context.Context) string { - stateFile, err := defaultStateFileFn() - if err != nil { - log.Logger.Debugw("failed to get state file path", "error", err) - return "" - } - - dbRO, err := sqlite.Open(stateFile) - if err != nil { - log.Logger.Debugw("failed to open state database", "error", err) - return "" - } - defer dbRO.Close() - - // Load backend base URL from metadata - if endpoint, err := pkgmetadata.ReadMetadata(ctx, dbRO, "backend_base_url"); err == nil && endpoint != "" { - return endpoint - } - - log.Logger.Debugw("backend endpoint not found in metadata") - return "" -} - -func (m *Manager) getLegacyNonceEndpointFromMetadata(ctx context.Context) string { - stateFile, err := defaultStateFileFn() - if err != nil { - log.Logger.Debugw("failed to get state file path", "error", err) - return "" - } - - dbRO, err := sqlite.Open(stateFile) - if err != nil { - log.Logger.Debugw("failed to open state database", "error", err) - return "" - } - defer dbRO.Close() - - if endpoint, err := pkgmetadata.ReadMetadata(ctx, dbRO, "nonce_endpoint"); err == nil && endpoint != "" { - return endpoint - } - - log.Logger.Debugw("legacy nonce endpoint not found in metadata") - return "" -} - -func (m *Manager) getMachineId() (string, error) { - stateFile, err := defaultStateFileFn() - if err != nil { - return "", fmt.Errorf("failed to get state file path: %w", err) - } - - dbRO, err := sqlite.Open(stateFile) - if err != nil { - return "", fmt.Errorf("failed to open state database: %w", err) - } - defer dbRO.Close() - - machineID, err := pkgmetadata.ReadMachineID(m.ctx, dbRO) - if err != nil { - return "", fmt.Errorf("failed to read machine ID from metadata: %w", err) - } - - if machineID == "" { - return "", fmt.Errorf("machine ID not found in metadata") - } - - return machineID, nil -} - -// validateNonce verifies that a nonce returned by the backend is safe to forward -// as a command-line argument to nvattest. It enforces an allowlist of characters -// (hex, base64url, and common padding/separator symbols) and a maximum length so -// that a compromised backend cannot craft an argument that exploits nvattest's own -// argument parser. -func validateNonce(nonce string) error { - if nonce == "" { - return fmt.Errorf("nonce is empty") - } - const maxLen = 512 - if len(nonce) > maxLen { - return fmt.Errorf("nonce length %d exceeds maximum of %d characters", len(nonce), maxLen) - } - for i, c := range nonce { - switch { - case c >= '0' && c <= '9', - c >= 'a' && c <= 'z', - c >= 'A' && c <= 'Z', - c == '-', c == '_', c == '=', c == '+', c == '/': - // allowed - default: - return fmt.Errorf("nonce contains invalid character %q at position %d", c, i) - } - } - return nil -} - -func (m *Manager) getEvidences(nonce string) (*AttestationSDKResponse, error) { - if err := validateNonce(nonce); err != nil { - return nil, fmt.Errorf("invalid nonce received from backend: %w", err) - } - - log.Logger.Infow("Calling attestation SDK CLI") - - // Execute nvattest command - // Set timeout to prevent hanging, derived from manager context to respect cancellation - ctx, cancel := context.WithTimeout(m.ctx, 60*time.Second) - defer cancel() - - nvattestPath, err := pkgfile.LocateExecutable("nvattest") - if err != nil { - return nil, fmt.Errorf("failed to locate attestation CLI: %w", err) - } - cmd := exec.CommandContext(ctx, nvattestPath, "collect-evidence", "--gpu-evidence-source=corelib", "--nonce", nonce, "--gpu-architecture", "blackwell", "--format", "json") - - // Capture stdout (JSON) and stderr (logs) separately - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - runErr := cmd.Run() - - log.Logger.Debugw("Attestation CLI completed", "exit_error", runErr, "stdout", stdout.String(), "stderr", stderr.String()) - - if runErr != nil { - // If stdout is empty, it means the command failed completely (e.g. command not found) - if stdout.Len() == 0 { - return nil, fmt.Errorf("attestation CLI execution failed: %w (stderr: %s)", runErr, stderr.String()) - } - // If stdout has content, we continue to try parsing the JSON response - log.Logger.Warnw("Attestation CLI returned error exit code but has output, attempting to parse", "error", runErr) - } - - // Parse the JSON response from stdout (clean JSON without log messages) - var response AttestationSDKResponse - if parseErr := json.Unmarshal(stdout.Bytes(), &response); parseErr != nil { - log.Logger.Debugw("Failed to parse attestation CLI JSON response", "parse_error", parseErr) - return nil, fmt.Errorf("failed to parse CLI response: %w (stderr: %s, stdout: %s, exec: %v)", parseErr, stderr.String(), stdout.String(), runErr) - } - - log.Logger.Infow("Successfully called attestation SDK", - "result_code", response.ResultCode, - "result_message", response.ResultMessage, - "evidences_count", len(response.Evidences)) - - return &response, nil -} - -// calculateJitter returns a random duration between 0 and maxJitter to prevent thundering herd -func (m *Manager) calculateJitter(maxJitter time.Duration) time.Duration { - if maxJitter <= 0 { - return 0 - } - - // Generate cryptographically secure random number - maxMs := int64(maxJitter / time.Millisecond) - if maxMs <= 0 { - return 0 - } - - randomMs, err := rand.Int(rand.Reader, big.NewInt(maxMs)) - if err != nil { - log.Logger.Warnw("Failed to generate secure random jitter, using fallback", "error", err) - // Fallback to simple time-based pseudo-random - return time.Duration(time.Now().UnixNano()%maxMs) * time.Millisecond - } - - return time.Duration(randomMs.Int64()) * time.Millisecond -} - -// runAttestationWithJitter runs attestation with additional per-request jitter -// Returns true if attestation should be retried soon (e.g., agent not enrolled yet) -func (m *Manager) runAttestationWithJitter() bool { - if !m.config.JitterEnabled { - log.Logger.Infow("Running attestation immediately (jitter disabled)") - return m.runAttestation() - } - - // Add significant jitter (0–30 minutes) for each request to spread load across a window, - // reducing thundering herd risk across many agents. - requestJitter := m.calculateJitter(30 * time.Minute) - log.Logger.Infow("Adding request jitter for thundering herd prevention", - "jitter_duration", requestJitter, - "max_jitter", "30 minutes") - - select { - case <-m.ctx.Done(): - return false - case <-time.After(requestJitter): - return m.runAttestation() - } -} diff --git a/internal/attestation/attestation_test.go b/internal/attestation/attestation_test.go deleted file mode 100644 index e43c80e7..00000000 --- a/internal/attestation/attestation_test.go +++ /dev/null @@ -1,778 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package attestation - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "path/filepath" - "strings" - "testing" - "time" - - pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" - "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/NVIDIA/fleet-intelligence-agent/internal/config" -) - -func useMissingStateFile(t *testing.T) { - t.Helper() - - orig := defaultStateFileFn - defaultStateFileFn = func() (string, error) { - return filepath.Join(t.TempDir(), "missing", "fleetint.state"), nil - } - t.Cleanup(func() { - defaultStateFileFn = orig - }) -} - -func TestManager_NewManager(t *testing.T) { - ctx := context.Background() - - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) // nil nvmlInstance, 20s for testing, no jitter - require.NotNil(t, manager) - assert.NotNil(t, manager.ctx) - assert.NotNil(t, manager.cancel) - assert.NotNil(t, manager.cache) - assert.Nil(t, manager.nvmlInstance) // Should be nil as passed -} - -func TestManager_StartStop(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Start should not block (Start() doesn't return error) - manager.Start() - - // Give it a moment to start - time.Sleep(50 * time.Millisecond) - - // Stop should work cleanly - manager.Stop() - - // Verify context is canceled - select { - case <-manager.ctx.Done(): - // Expected - context should be canceled - case <-time.After(100 * time.Millisecond): - t.Error("Expected context to be canceled after Stop()") - } -} - -func TestManager_GetAttestationData(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Initially should have empty data - attestationData := manager.GetAttestationData() - assert.Nil(t, attestationData) - - // Manually update cache to test getter - testAttestationData := &AttestationData{ - SDKResponse: AttestationSDKResponse{ - Evidences: []EvidenceItem{ - { - Arch: "turing", - Certificate: "test_cert", - DriverVersion: "550.90.07", - Evidence: "test_evidence", - Nonce: "test_nonce", - VBIOSVersion: "90.17.A9.00.0B", - Version: "1.0", - }, - }, - ResultCode: 0, - ResultMessage: "Ok", - }, - NonceRefreshTimestamp: time.Now().UTC(), - } - - manager.cache.updateAttestationData(testAttestationData) - - // Now should return the test data - attestationData = manager.GetAttestationData() - assert.Equal(t, testAttestationData, attestationData) -} - -func TestManager_IsAttestationDataUpdated(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - baseTime := time.Now().UTC() - - // Initially, no updates - assert.False(t, manager.IsAttestationDataUpdated(baseTime)) - - // Update the cache - testData := &AttestationData{ - SDKResponse: AttestationSDKResponse{ - Evidences: []EvidenceItem{{ - Arch: "turing", - Certificate: "test", - DriverVersion: "550.90.07", - Evidence: "test", - Nonce: "test_nonce", - VBIOSVersion: "90.17.A9.00.0B", - Version: "1.0", - }}, - ResultCode: 0, - ResultMessage: "Ok", - }, - NonceRefreshTimestamp: time.Now().UTC(), - } - manager.cache.updateAttestationData(testData) - - // Should now show as updated since baseTime - assert.True(t, manager.IsAttestationDataUpdated(baseTime)) - - // But not updated compared to a future time - futureTime := time.Now().Add(1 * time.Hour) - assert.False(t, manager.IsAttestationDataUpdated(futureTime)) -} - -func TestManager_GetMachineId_NoDatabase(t *testing.T) { - useMissingStateFile(t) - - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // getMachineId will fail because there's no database with machine ID in test environment - machineId, err := manager.getMachineId() - - // Expected to fail in test environment without proper database setup - assert.Error(t, err) - assert.Empty(t, machineId) -} - -// Define the response struct for testing -type testNonceResponse struct { - Nonce string `json:"nonce"` - NonceRefreshTimestamp time.Time `json:"nonceRefreshTimestamp"` - Error string `json:"error,omitempty"` -} - -func useDefaultTransport(t *testing.T, transport http.RoundTripper) { - t.Helper() - - orig := http.DefaultTransport - http.DefaultTransport = transport - t.Cleanup(func() { - http.DefaultTransport = orig - }) -} - -func TestManager_GetNonce_MockHTTP(t *testing.T) { - // Test the nonce parsing logic without actually calling the private method - // This tests the HTTP interaction pattern - - expectedNonce := "test-nonce-12345" - expectedTimestamp := time.Now().UTC() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify request method and content type - assert.Equal(t, "POST", r.Method) - assert.Equal(t, "application/json", r.Header.Get("Content-Type")) - - // Verify Bearer authorization header - authHeader := r.Header.Get("Authorization") - assert.Equal(t, "Bearer test-jwt-token", authHeader, "Should have Bearer authorization header") - - // Verify request body (should only contain uuid, not token) - var requestBody map[string]string - err := json.NewDecoder(r.Body).Decode(&requestBody) - require.NoError(t, err) - assert.Equal(t, "test-machine-id", requestBody["uuid"]) - assert.NotContains(t, requestBody, "token", "Token should not be in request body when using Bearer auth") - - // Send successful response - response := testNonceResponse{ - Nonce: expectedNonce, - NonceRefreshTimestamp: expectedTimestamp, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("Failed to encode response: %v", err) - } - })) - defer server.Close() - - // Test HTTP request/response parsing manually with Bearer auth - url := server.URL + "/nonce" - requestBody, err := json.Marshal(map[string]string{ - "uuid": "test-machine-id", - }) - require.NoError(t, err) - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-jwt-token") - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - var response testNonceResponse - err = json.NewDecoder(resp.Body).Decode(&response) - require.NoError(t, err) - - assert.Equal(t, expectedNonce, response.Nonce) - assert.Equal(t, expectedTimestamp.Unix(), response.NonceRefreshTimestamp.Unix()) -} - -func TestManager_GetNonce_ServerError(t *testing.T) { - // Test server error response parsing - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - response := testNonceResponse{ - Error: "Invalid token", - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("Failed to encode response: %v", err) - } - })) - defer server.Close() - - // Test error response parsing with Bearer auth - url := server.URL + "/nonce" - requestBody, err := json.Marshal(map[string]string{ - "uuid": "test-machine-id", - }) - require.NoError(t, err) - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer invalid-token") - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - var response testNonceResponse - err = json.NewDecoder(resp.Body).Decode(&response) - require.NoError(t, err) - - assert.Equal(t, "Invalid token", response.Error) - assert.Empty(t, response.Nonce) - assert.True(t, response.NonceRefreshTimestamp.IsZero()) -} - -func TestManager_GetValidatedNonceEndpoint_DerivesFromStoredBackendBaseURL(t *testing.T) { - manager := newTestManager(t) - stateFile := setupAttestationMetadataDB(t, map[string]string{ - "backend_base_url": "https://backend.example.com", - }) - useTestStateFile(t, stateFile) - - got, err := manager.getValidatedNonceEndpoint(context.Background()) - require.NoError(t, err) - assert.Equal(t, "https://backend.example.com/api/v1/health/nonce", got) -} - -func TestManager_GetValidatedNonceEndpoint_RejectsInvalidStoredBackendBaseURL(t *testing.T) { - manager := newTestManager(t) - stateFile := setupAttestationMetadataDB(t, map[string]string{ - "backend_base_url": "http://evil.example.com", - }) - useTestStateFile(t, stateFile) - - _, err := manager.getValidatedNonceEndpoint(context.Background()) - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid backend endpoint") - assert.Contains(t, err.Error(), "https") -} - -func newTestManager(t *testing.T) *Manager { - t.Helper() - - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - return NewManager(ctx, nil, cfg) -} - -func setupAttestationMetadataDB(t *testing.T, entries map[string]string) string { - t.Helper() - - stateFile := filepath.Join(t.TempDir(), "fleetint.state") - db, err := sqlite.Open(stateFile) - require.NoError(t, err) - - err = pkgmetadata.CreateTableMetadata(context.Background(), db) - require.NoError(t, err) - - for key, value := range entries { - err = pkgmetadata.SetMetadata(context.Background(), db, key, value) - require.NoError(t, err) - } - - err = db.Close() - require.NoError(t, err) - - return stateFile -} - -func useTestStateFile(t *testing.T, stateFile string) { - t.Helper() - - orig := defaultStateFileFn - defaultStateFileFn = func() (string, error) { - return stateFile, nil - } - t.Cleanup(func() { - defaultStateFileFn = orig - }) -} - -func TestManager_GetEvidences(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - testNonce := "931d8dd0add203ac3d8b4fbde75e115278eefcdceac5b87671a748f32364dfcb" - - sdkResponse, err := manager.getEvidences(testNonce) - - // In CI environment, the nvattest binary might not exist or directory might not exist - if err != nil { - // If binary is missing, this is expected in CI - if strings.Contains(err.Error(), "executable file not found") || strings.Contains(err.Error(), "no such file or directory") || strings.Contains(err.Error(), "not found in PATH") || strings.Contains(err.Error(), "The following arguments were not expected") { - t.Log("Attestation CLI binary or directory not found (expected in CI)") - return - } - // If it's another error, fail the test - require.NoError(t, err, "Unexpected error running attestation CLI") - } - - assert.NotNil(t, sdkResponse) - - // In test environment with binary present but no GPU (or mock), CLI may fail - // Check for expected real CLI response structure - if sdkResponse.ResultCode == 0 { - // Success case (when running on real attestation-capable hardware) - assert.Equal(t, "Ok", sdkResponse.ResultMessage) - assert.NotEmpty(t, sdkResponse.Evidences, "Should have evidences on success") - t.Log("Attestation CLI succeeded - running on attestation-capable hardware") - } else { - // Expected failure case (test environment without proper attestation hardware) - // We expect a structured error response from the CLI - t.Logf("Attestation CLI failed as expected in test environment: %s (Code: %d)", - sdkResponse.ResultMessage, sdkResponse.ResultCode) - } -} - -func TestCache_ThreadSafety(t *testing.T) { - cache := &cache{} - - // Test concurrent reads and writes - done := make(chan bool, 10) - - // Start multiple goroutines writing to cache - for i := 0; i < 5; i++ { - go func(id int) { - defer func() { done <- true }() - - for j := 0; j < 10; j++ { - testData := &AttestationData{ - SDKResponse: AttestationSDKResponse{ - Evidences: []EvidenceItem{ - { - Arch: "BLACKWELL", - Certificate: fmt.Sprintf("cert-%d-%d", id, j), - DriverVersion: "575.28", - Evidence: fmt.Sprintf("evidence-%d-%d", id, j), - Nonce: fmt.Sprintf("nonce-%d-%d", id, j), - VBIOSVersion: "96.00.AF.00.01", - Version: "1.0", - }, - }, - ResultCode: 0, - ResultMessage: "Ok", - //NonceRefreshTimestamp: time.Now().UTC().Add(time.Duration(id*j) * time.Millisecond), - }, - NonceRefreshTimestamp: time.Now().UTC().Add(time.Duration(id*j) * time.Millisecond), - } - cache.updateAttestationData(testData) - time.Sleep(time.Millisecond) // Small delay to increase chance of concurrent access - } - }(i) - } - - // Start multiple goroutines reading from cache - for i := 0; i < 5; i++ { - go func(id int) { - defer func() { done <- true }() - - for j := 0; j < 10; j++ { - cache.GetAttestationData() - baseTime := time.Now().UTC().Add(-time.Duration(j) * time.Second) - cache.isUpdatedSince(baseTime) - time.Sleep(time.Millisecond) - } - }(i) - } - - // Wait for all goroutines to complete - for i := 0; i < 10; i++ { - select { - case <-done: - // Good, goroutine completed - case <-time.After(15 * time.Second): - t.Fatal("Goroutines did not complete within timeout - possible deadlock") - } - } - - // Verify cache is still functional - attestationData := cache.GetAttestationData() - assert.NotNil(t, attestationData) // Should have some data from the concurrent writes -} - -func TestManager_RunAttestation_WithFallback(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Since runAttestation is called by Start(), we need to test it differently - // We'll test the fallback behavior by checking that it handles errors gracefully - - // The method should not panic when NVML is not available - assert.NotPanics(t, func() { - manager.runAttestation() - }) - - // After running with fallbacks, cache should contain failure information - attestationData := manager.GetAttestationData() - if attestationData != nil { - // If we have attestation data, it should indicate failure - assert.False(t, attestationData.Success, "Should indicate failure") - assert.NotEmpty(t, attestationData.ErrorMessage, "Should have error message") - t.Log("Attestation failed as expected with error:", attestationData.ErrorMessage) - } else { - t.Log("No attestation data available - this can happen if attestation manager didn't run") - } -} - -func TestManager_IntegrationTest(t *testing.T) { - // This is a more comprehensive test that tests the update tracking functionality - - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Test the update tracking over time - baseTime := time.Now().UTC() - - // Initially no updates - assert.False(t, manager.IsAttestationDataUpdated(baseTime)) - - // Manually update cache (simulating successful attestation) - testAttestationData := &AttestationData{ - SDKResponse: AttestationSDKResponse{ - Evidences: []EvidenceItem{ - { - Arch: "BLACKWELL", - Certificate: "integration-cert", - DriverVersion: "575.28", - Evidence: "integration-evidence", - Nonce: "integration-nonce", - VBIOSVersion: "96.00.AF.00.01", - Version: "1.0", - }, - }, - ResultCode: 0, - ResultMessage: "Ok", - }, - NonceRefreshTimestamp: time.Now().UTC(), - } - manager.cache.updateAttestationData(testAttestationData) - - // Now should show updates - assert.True(t, manager.IsAttestationDataUpdated(baseTime)) - - // Data should be retrievable - attestationData := manager.GetAttestationData() - assert.Equal(t, testAttestationData, attestationData) - - // Test Start/Stop functionality - manager.Start() - time.Sleep(50 * time.Millisecond) // Let it start - manager.Stop() - - // Context should be done - select { - case <-manager.ctx.Done(): - // Expected - case <-time.After(100 * time.Millisecond): - t.Error("Context should be canceled after Stop()") - } -} - -func TestEvidenceItem_JSONSerialization(t *testing.T) { - evidence := EvidenceItem{ - Evidence: "test-evidence-data", - Certificate: "test-certificate-data", - } - - // Test marshaling - data, err := json.Marshal(evidence) - assert.NoError(t, err) - assert.Contains(t, string(data), "test-evidence-data") - assert.Contains(t, string(data), "test-certificate-data") - - // Test unmarshaling - var unmarshaled EvidenceItem - err = json.Unmarshal(data, &unmarshaled) - assert.NoError(t, err) - assert.Equal(t, evidence.Evidence, unmarshaled.Evidence) - assert.Equal(t, evidence.Certificate, unmarshaled.Certificate) -} - -func TestManager_CalculateJitter(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: true, - } - manager := NewManager(ctx, nil, cfg) - - // Test with 0 max jitter - jitter := manager.calculateJitter(0) - assert.Equal(t, time.Duration(0), jitter) - - // Test with positive max jitter - maxJitter := 100 * time.Millisecond - for i := 0; i < 10; i++ { - jitter = manager.calculateJitter(maxJitter) - assert.GreaterOrEqual(t, jitter, time.Duration(0)) - assert.Less(t, jitter, maxJitter) - } -} - -func TestValidateNonce(t *testing.T) { - tests := []struct { - name string - nonce string - wantErr string - }{ - {name: "valid_hex", nonce: "abcdef0123456789"}, - {name: "valid_base64", nonce: "dGVzdA=="}, - {name: "valid_base64url", nonce: "abc-def_ghi+jkl/mno="}, - {name: "empty", nonce: "", wantErr: "nonce is empty"}, - {name: "too_long", nonce: strings.Repeat("a", 513), wantErr: "exceeds maximum"}, - {name: "max_length_ok", nonce: strings.Repeat("a", 512)}, - {name: "space", nonce: "abc def", wantErr: "invalid character"}, - {name: "newline", nonce: "abc\ndef", wantErr: "invalid character"}, - {name: "semicolon", nonce: "abc;def", wantErr: "invalid character"}, - {name: "shell_metachar", nonce: "$(whoami)", wantErr: "invalid character"}, - {name: "flag_like_valid_chars", nonce: "--output=/etc/passwd"}, // all chars are in the base64url allowlist; safe because nvattest receives it as --nonce value, not a flag - {name: "pipe", nonce: "abc|def", wantErr: "invalid character"}, - {name: "null_byte", nonce: "abc\x00def", wantErr: "invalid character"}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - err := validateNonce(tc.nonce) - if tc.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.wantErr) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestManager_RunAttestation_ReturnsRetrySoon(t *testing.T) { - useMissingStateFile(t) - - // This test verifies that runAttestation returns the correct retry hint - // When agent is not enrolled, it should return true (retry soon) - // When there's a real failure, it should return false (normal interval) - - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 24 * time.Hour}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Run attestation - in test environment, this will fail at getMachineId - // (which is a real failure, not "not enrolled"), so it should return false - shouldRetrySoon := manager.runAttestation() - - // Since we don't have the metadata database with machine ID in test environment, - // attestation fails with a real error (not "not enrolled"), so shouldRetrySoon should be false - assert.False(t, shouldRetrySoon, "Should return false for real failures (not 'not enrolled')") - - // Verify the cache has the failure info - attestationData := manager.GetAttestationData() - require.NotNil(t, attestationData) - assert.False(t, attestationData.Success) - assert.NotEmpty(t, attestationData.ErrorMessage) - - // The error should NOT be about enrollment - assert.NotContains(t, attestationData.ErrorMessage, "not enrolled", - "Error should be about machine ID, not enrollment") -} - -func TestRetryInterval_Constant(t *testing.T) { - // Verify the retry interval constant is set appropriately - assert.Equal(t, 5*time.Minute, retryInterval, - "Retry interval should be 5 minutes for quick recovery after enrollment") -} - -func TestManager_GetNextInterval(t *testing.T) { - tests := []struct { - name string - configInterval time.Duration - shouldRetrySoon bool - expectedInterval time.Duration - }{ - { - name: "normal interval when not retrying", - configInterval: 24 * time.Hour, - shouldRetrySoon: false, - expectedInterval: 24 * time.Hour, - }, - { - name: "retry interval when config is longer", - configInterval: 24 * time.Hour, - shouldRetrySoon: true, - expectedInterval: 5 * time.Minute, // retryInterval - }, - { - name: "config interval when config is shorter than retry", - configInterval: 20 * time.Second, - shouldRetrySoon: true, - expectedInterval: 20 * time.Second, // use config, not retryInterval - }, - { - name: "config interval when config equals retry", - configInterval: 5 * time.Minute, - shouldRetrySoon: true, - expectedInterval: 5 * time.Minute, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: tt.configInterval}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - interval := manager.getNextInterval(tt.shouldRetrySoon) - assert.Equal(t, tt.expectedInterval, interval) - }) - } -} - -func TestManager_RunAttestationWithJitter_Disabled(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Should run immediately (we can't easily verify it ran without mocking runAttestation, - // but we can ensure it doesn't panic and covers the code path) - done := make(chan bool) - go func() { - manager.runAttestationWithJitter() - done <- true - }() - - select { - case <-done: - // Success - case <-time.After(15 * time.Second): - t.Error("runAttestationWithJitter should return immediately when jitter is disabled") - } -} - -func TestManager_RunAttestationWithJitter_Enabled(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: true, - } - manager := NewManager(ctx, nil, cfg) - - // We can't easily wait for the random jitter, but we can verify it respects context cancellation - done := make(chan bool) - go func() { - manager.runAttestationWithJitter() - done <- true - }() - - // Cancel context to force exit - cancel() - - select { - case <-done: - // Success - case <-time.After(15 * time.Second): - t.Error("runAttestationWithJitter should return when context is canceled") - } -} diff --git a/internal/attestationloop/backend.go b/internal/attestation/backend.go similarity index 96% rename from internal/attestationloop/backend.go rename to internal/attestation/backend.go index c5f63ec4..388d38b0 100644 --- a/internal/attestationloop/backend.go +++ b/internal/attestation/backend.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package attestationloop +package attestation import ( "context" @@ -73,7 +73,7 @@ func (f *stateBackendClientFactory) client(ctx context.Context) (backendclient.C return nil, err } if !ok || baseURL == "" { - return nil, fmt.Errorf("backend base URL not available in agent state") + return nil, fmt.Errorf("%w: backend base URL not available in agent state", ErrNotEnrolled) } return newBackendClient(baseURL) } diff --git a/internal/attestationloop/backend_test.go b/internal/attestation/backend_test.go similarity index 90% rename from internal/attestationloop/backend_test.go rename to internal/attestation/backend_test.go index 290e436d..cf4a1498 100644 --- a/internal/attestationloop/backend_test.go +++ b/internal/attestation/backend_test.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package attestationloop +package attestation import ( "context" @@ -181,22 +181,18 @@ func TestStateProvidersAndSubmitter(t *testing.T) { require.Equal(t, "BLACKWELL", recording.lastReq.AttestationData.SDKResponse.Evidences[0].Arch) } -func TestLegacyAttestationData(t *testing.T) { - result := &Result{ - NonceRefreshTimestamp: time.Unix(20, 0).UTC(), - Success: false, - ErrorMessage: "boom", - SDKResponse: SDKResponse{ - ResultCode: 9, - ResultMessage: "bad", - Evidences: []EvidenceItem{{Arch: "BLACKWELL"}}, - }, +func TestStateProvidersPropagateBackendClientConstructionErrors(t *testing.T) { + orig := newBackendClient + t.Cleanup(func() { newBackendClient = orig }) + + newBackendClient = func(string) (backendclient.Client, error) { + return nil, errors.New("construct failed") } - legacy := result.LegacyAttestationData() - require.NotNil(t, legacy) - require.False(t, legacy.Success) - require.Equal(t, "boom", legacy.ErrorMessage) - require.Equal(t, 9, legacy.SDKResponse.ResultCode) - require.Len(t, legacy.SDKResponse.Evidences, 1) - require.Nil(t, (*Result)(nil).LegacyAttestationData()) + state := &stubState{baseURL: "https://backend.example.com", baseOK: true} + + _, _, _, err := NewStateNonceProvider(state).GetNonce(context.Background(), "node-1", "jwt-token") + require.ErrorContains(t, err, "construct failed") + + err = NewStateBackendSubmitter(state).Submit(context.Background(), &Result{NodeID: "node-1"}, "jwt-token") + require.ErrorContains(t, err, "construct failed") } diff --git a/internal/attestationloop/collector.go b/internal/attestation/collector.go similarity index 99% rename from internal/attestationloop/collector.go rename to internal/attestation/collector.go index 1294e1da..7f747bbf 100644 --- a/internal/attestationloop/collector.go +++ b/internal/attestation/collector.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package attestationloop +package attestation import ( "bytes" diff --git a/internal/attestationloop/collector_test.go b/internal/attestation/collector_test.go similarity index 99% rename from internal/attestationloop/collector_test.go rename to internal/attestation/collector_test.go index b75ae9a2..171578a5 100644 --- a/internal/attestationloop/collector_test.go +++ b/internal/attestation/collector_test.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package attestationloop +package attestation import ( "context" diff --git a/internal/attestationloop/manager.go b/internal/attestation/manager.go similarity index 69% rename from internal/attestationloop/manager.go rename to internal/attestation/manager.go index a4140510..52d67880 100644 --- a/internal/attestationloop/manager.go +++ b/internal/attestation/manager.go @@ -13,11 +13,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package attestationloop +package attestation import ( "context" + "crypto/rand" + "errors" "fmt" + "math/big" "sync" "time" @@ -48,7 +51,7 @@ type manager struct { nonceProvider NonceProvider collector EvidenceCollector submitter Submitter - interval time.Duration + config AttestationConfig lastResult *Result lastUpdated time.Time @@ -61,7 +64,7 @@ func NewManager( nonceProvider NonceProvider, collector EvidenceCollector, submitter Submitter, - interval time.Duration, + cfg AttestationConfig, ) Manager { return &manager{ nodeIDProvider: nodeIDProvider, @@ -69,7 +72,7 @@ func NewManager( nonceProvider: nonceProvider, collector: collector, submitter: submitter, - interval: interval, + config: cfg, } } @@ -77,25 +80,42 @@ func (m *manager) Run(ctx context.Context) error { if m.nodeIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil || m.submitter == nil { return fmt.Errorf("attestation loop dependencies are incomplete") } - if _, err := m.CollectOnce(ctx); err != nil { - log.Logger.Warnw("initial attestation workflow failed", "error", err) - } - if m.interval <= 0 { + if m.config.Interval <= 0 { return nil } + startupInterval := m.config.Interval + if m.config.InitialInterval > 0 { + startupInterval = m.config.InitialInterval + } + if m.config.JitterEnabled { + jitter := calculateJitter(initialJitterCap(startupInterval)) + log.Logger.Infow("adding initial attestation startup jitter", "jitter_duration", jitter) + if err := sleepWithContext(ctx, jitter); err != nil { + return err + } + } - ticker := time.NewTicker(m.interval) - defer ticker.Stop() - + firstSuccess := false for { - select { - case <-ctx.Done(): - return nil - case <-ticker.C: - if _, err := m.CollectOnce(ctx); err != nil { - log.Logger.Warnw("periodic attestation workflow failed", "error", err) + _, err := m.CollectOnce(ctx) + nextInterval := m.config.Interval + if err == nil { + firstSuccess = true + } else { + log.Logger.Warnw("attestation workflow failed", "error", err) + switch { + case !firstSuccess && errors.Is(err, ErrNotEnrolled) && m.config.InitialInterval > 0: + nextInterval = m.config.InitialInterval + case m.config.RetryInterval > 0 && (nextInterval <= 0 || m.config.RetryInterval < nextInterval): + nextInterval = m.config.RetryInterval + if m.config.JitterEnabled { + nextInterval += calculateJitter(retryJitterCap(m.config.RetryInterval)) + } } } + if err := sleepWithContext(ctx, nextInterval); err != nil { + return err + } } } @@ -164,6 +184,66 @@ func (m *manager) IsResultUpdated(since time.Time) bool { return m.lastUpdated.After(since) } +func sleepWithContext(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } + } + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func calculateJitter(maxJitter time.Duration) time.Duration { + if maxJitter <= 0 { + return 0 + } + maxMs := int64(maxJitter / time.Millisecond) + if maxMs <= 0 { + return 0 + } + randomMs, err := rand.Int(rand.Reader, big.NewInt(maxMs)) + if err != nil { + log.Logger.Warnw("failed to generate secure attestation jitter, using fallback", "error", err) + return time.Duration(time.Now().UnixNano()%maxMs) * time.Millisecond + } + return time.Duration(randomMs.Int64()) * time.Millisecond +} + +func initialJitterCap(interval time.Duration) time.Duration { + if interval <= 0 { + return 0 + } + jitter := interval / 4 + const maxInitialJitter = 30 * time.Minute + if jitter > maxInitialJitter { + return maxInitialJitter + } + return jitter +} + +func retryJitterCap(retryInterval time.Duration) time.Duration { + if retryInterval <= 0 { + return 0 + } + jitter := retryInterval / 2 + const maxRetryJitter = 5 * time.Minute + if jitter > maxRetryJitter { + return maxRetryJitter + } + return jitter +} + type backendSubmitter struct { client BackendClient } @@ -209,7 +289,7 @@ func (p *stateJWTProvider) GetJWT(ctx context.Context) (string, error) { return "", err } if !ok || value == "" { - return "", fmt.Errorf("jwt not available in agent state") + return "", fmt.Errorf("%w: jwt not available in agent state", ErrNotEnrolled) } return value, nil } @@ -232,7 +312,7 @@ func NewStateNodeIDProvider(state agentstate.State) func(context.Context) (strin return "", err } if !ok || value == "" { - return "", fmt.Errorf("node ID not available in agent state") + return "", fmt.Errorf("%w: node ID not available in agent state", ErrNotEnrolled) } return value, nil } diff --git a/internal/attestationloop/manager_test.go b/internal/attestation/manager_test.go similarity index 64% rename from internal/attestationloop/manager_test.go rename to internal/attestation/manager_test.go index c4c0082d..5f9c8d8c 100644 --- a/internal/attestationloop/manager_test.go +++ b/internal/attestation/manager_test.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package attestationloop +package attestation import ( "context" @@ -81,7 +81,7 @@ func TestCollectOnceSuccess(t *testing.T) { &testNonceProvider{nonce: "abc123", refreshTS: refreshTS, refreshedJWT: "new-jwt"}, &testEvidenceCollector{resp: &SDKResponse{ResultCode: 200, ResultMessage: "ok"}}, submitter, - 0, + AttestationConfig{}, ) result, err := manager.CollectOnce(context.Background()) @@ -103,7 +103,7 @@ func TestCollectOnceCollectorFailureStillSubmitsFailureResult(t *testing.T) { &testNonceProvider{nonce: "abc123"}, &testEvidenceCollector{err: errors.New("collect failed")}, submitter, - 0, + AttestationConfig{}, ) result, err := manager.CollectOnce(context.Background()) @@ -115,7 +115,7 @@ func TestCollectOnceCollectorFailureStillSubmitsFailureResult(t *testing.T) { } func TestCollectOnceMissingDependencies(t *testing.T) { - _, err := NewManager(nil, nil, nil, nil, nil, 0).CollectOnce(context.Background()) + _, err := NewManager(nil, nil, nil, nil, nil, AttestationConfig{}).CollectOnce(context.Background()) require.Error(t, err) } @@ -127,12 +127,12 @@ func TestManagerRunAndCachedResult(t *testing.T) { &testNonceProvider{nonce: "abc123"}, &testEvidenceCollector{resp: &SDKResponse{ResultCode: 200}}, submitter, - 5*time.Millisecond, + AttestationConfig{InitialInterval: 5 * time.Millisecond, Interval: 5 * time.Millisecond}, ) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() - require.NoError(t, mgr.Run(ctx)) + require.ErrorIs(t, mgr.Run(ctx), context.DeadlineExceeded) last := mgr.LastResult() require.NotNil(t, last) @@ -140,6 +140,28 @@ func TestManagerRunAndCachedResult(t *testing.T) { require.True(t, mgr.IsResultUpdated(time.Time{})) } +func TestManagerRunUsesRetryIntervalOnFailure(t *testing.T) { + submitter := &testSubmitter{} + collector := &testEvidenceCollector{err: errors.New("collect failed")} + mgr := NewManager( + func(context.Context) (string, error) { return "node-1", nil }, + &testJWTProvider{jwt: "jwt-token"}, + &testNonceProvider{nonce: "abc123"}, + collector, + submitter, + AttestationConfig{InitialInterval: 5 * time.Millisecond, Interval: time.Hour, RetryInterval: 5 * time.Millisecond}, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + require.ErrorIs(t, mgr.Run(ctx), context.DeadlineExceeded) + + last := mgr.LastResult() + require.NotNil(t, last) + require.False(t, last.Success) + require.Equal(t, "collect failed", last.ErrorMessage) +} + func TestManagerHelpersAndSubmitterErrors(t *testing.T) { mgr := NewManager( func(context.Context) (string, error) { return "node-1", nil }, @@ -147,7 +169,7 @@ func TestManagerHelpersAndSubmitterErrors(t *testing.T) { &testNonceProvider{nonce: "abc123"}, &testEvidenceCollector{resp: &SDKResponse{}}, &testSubmitter{}, - 0, + AttestationConfig{}, ) require.Nil(t, mgr.LastResult()) require.False(t, mgr.IsResultUpdated(time.Now().UTC())) @@ -174,3 +196,55 @@ func TestStateJWTProviderAndNodeIDProviderErrors(t *testing.T) { _, err = NewStateNodeIDProvider(&stubState{})(context.Background()) require.ErrorContains(t, err, "node ID not available") } + +func TestSleepWithContext(t *testing.T) { + require.NoError(t, sleepWithContext(context.Background(), 0)) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + require.ErrorIs(t, sleepWithContext(ctx, 0), context.Canceled) + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + require.ErrorIs(t, sleepWithContext(ctx, time.Hour), context.Canceled) +} + +func TestAttestationJitterHelpers(t *testing.T) { + require.Equal(t, time.Duration(0), initialJitterCap(0)) + require.Equal(t, 75*time.Second, initialJitterCap(5*time.Minute)) + require.Equal(t, 30*time.Minute, initialJitterCap(4*time.Hour)) + + require.Equal(t, time.Duration(0), retryJitterCap(0)) + require.Equal(t, 150*time.Second, retryJitterCap(5*time.Minute)) + require.Equal(t, 5*time.Minute, retryJitterCap(20*time.Minute)) + + require.Equal(t, time.Duration(0), calculateJitter(0)) + jitter := calculateJitter(50 * time.Millisecond) + require.GreaterOrEqual(t, jitter, time.Duration(0)) + require.Less(t, jitter, 50*time.Millisecond) +} + +func TestManagerRunUsesInitialIntervalWithoutJitterWhenNotEnrolled(t *testing.T) { + mgr := NewManager( + func(context.Context) (string, error) { return "", ErrNotEnrolled }, + &testJWTProvider{jwt: "jwt-token"}, + &testNonceProvider{nonce: "abc123"}, + &testEvidenceCollector{resp: &SDKResponse{ResultCode: 200}}, + &testSubmitter{}, + AttestationConfig{ + InitialInterval: 5 * time.Millisecond, + Interval: time.Hour, + RetryInterval: time.Minute, + }, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + start := time.Now() + err := mgr.Run(ctx) + elapsed := time.Since(start) + + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, elapsed, 15*time.Millisecond) + require.Less(t, elapsed, 100*time.Millisecond) +} diff --git a/internal/attestationloop/nonce.go b/internal/attestation/nonce.go similarity index 98% rename from internal/attestationloop/nonce.go rename to internal/attestation/nonce.go index a27c795c..975fda4a 100644 --- a/internal/attestationloop/nonce.go +++ b/internal/attestation/nonce.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package attestationloop +package attestation import ( "context" diff --git a/internal/attestationloop/nonce_test.go b/internal/attestation/nonce_test.go similarity index 98% rename from internal/attestationloop/nonce_test.go rename to internal/attestation/nonce_test.go index dfc0f9e3..8c09bc28 100644 --- a/internal/attestationloop/nonce_test.go +++ b/internal/attestation/nonce_test.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package attestationloop +package attestation import ( "context" diff --git a/internal/attestationloop/types.go b/internal/attestation/types.go similarity index 59% rename from internal/attestationloop/types.go rename to internal/attestation/types.go index 3766663a..bc224763 100644 --- a/internal/attestationloop/types.go +++ b/internal/attestation/types.go @@ -13,16 +13,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package attestationloop owns the backend attestation workflow. -package attestationloop +// Package attestation owns the backend attestation workflow. +package attestation import ( "context" + "errors" "time" - - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" ) +// ErrNotEnrolled indicates attestation cannot run yet because the agent is not enrolled. +var ErrNotEnrolled = errors.New("agent not enrolled") + // Result is the agent-owned attestation state model for the new backend sync loop. type Result struct { CollectedAt time.Time @@ -64,34 +66,10 @@ type Submitter interface { Submit(ctx context.Context, result *Result, jwt string) error } -// LegacyAttestationData converts the workflow result into the legacy attestation payload shape -// still consumed by the exporter collector path. -func (r *Result) LegacyAttestationData() *attestation.AttestationData { - if r == nil { - return nil - } - data := &attestation.AttestationData{ - NonceRefreshTimestamp: r.NonceRefreshTimestamp, - Success: r.Success, - ErrorMessage: r.ErrorMessage, - SDKResponse: attestation.AttestationSDKResponse{ - ResultCode: r.SDKResponse.ResultCode, - ResultMessage: r.SDKResponse.ResultMessage, - }, - } - if len(r.SDKResponse.Evidences) > 0 { - data.SDKResponse.Evidences = make([]attestation.EvidenceItem, 0, len(r.SDKResponse.Evidences)) - for _, ev := range r.SDKResponse.Evidences { - data.SDKResponse.Evidences = append(data.SDKResponse.Evidences, attestation.EvidenceItem{ - Arch: ev.Arch, - Certificate: ev.Certificate, - DriverVersion: ev.DriverVersion, - Evidence: ev.Evidence, - Nonce: ev.Nonce, - VBIOSVersion: ev.VBIOSVersion, - Version: ev.Version, - }) - } - } - return data +// AttestationConfig controls periodic attestation workflow scheduling. +type AttestationConfig struct { + InitialInterval time.Duration + Interval time.Duration + RetryInterval time.Duration + JitterEnabled bool } diff --git a/internal/config/config.go b/internal/config/config.go index 5bc3ef26..a200aae9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -65,17 +65,36 @@ type Config struct { // SECURITY: Only accessible from localhost (127.0.0.0/8 or ::1). Disabled by default. EnableFaultInjection bool `json:"enable_fault_injection"` + // Inventory controls the periodic inventory loop. + Inventory *InventoryConfig `json:"inventory,omitempty"` + + // Attestation controls the periodic attestation loop. + Attestation *AttestationConfig `json:"attestation,omitempty"` + // Health Exporter Configuration HealthExporter *HealthExporterConfig `json:"health_exporter,omitempty"` } -// AttestationConfig holds configuration for the attestation process -type AttestationConfig struct { - // Interval is how often to run attestation (default: 24 hours) +// InventoryConfig holds configuration for the periodic inventory loop. +type InventoryConfig struct { + // Enabled controls whether the periodic inventory loop runs. + Enabled bool `json:"enabled"` + + // Interval is how often to collect and export inventory. Interval metav1.Duration `json:"interval"` +} - // JitterEnabled controls whether to add random jitter to attestation schedule - JitterEnabled bool `json:"jitter_enabled"` +// AttestationConfig holds configuration for the periodic attestation loop. +type AttestationConfig struct { + // Enabled controls whether the periodic attestation loop runs. + Enabled bool `json:"enabled"` + + // InitialInterval is how often to check enrollment and attempt the first attestation run + // before switching to the steady-state interval after the first successful attestation. + InitialInterval metav1.Duration `json:"initial_interval"` + + // Interval is how often to run attestation. + Interval metav1.Duration `json:"interval"` } // HealthExporterConfig holds configuration for the health data exporter @@ -86,9 +105,6 @@ type HealthExporterConfig struct { // LogsEndpoint is the specific endpoint for sending logs/events data LogsEndpoint string `json:"logs_endpoint"` - // Attestation configuration - Attestation AttestationConfig `json:"attestation"` - // AuthToken is the authentication token for HTTP requests AuthToken string `json:"auth_token,omitempty"` @@ -151,6 +167,21 @@ func (config *Config) Validate() error { return fmt.Errorf("retention_period must be at least 1 minute, got %v", config.RetentionPeriod.Duration) } + if err := validateLoopConfig("inventory", config.Inventory); err != nil { + return err + } + if err := validateLoopConfig("attestation", config.Attestation); err != nil { + return err + } + if config.Attestation != nil && config.Attestation.Enabled { + if config.Attestation.InitialInterval.Duration <= 0 { + return errors.New("attestation.initial_interval is required when attestation is enabled") + } + if config.Attestation.InitialInterval.Duration < time.Minute { + return fmt.Errorf("attestation.initial_interval must be at least 1 minute, got %v", config.Attestation.InitialInterval.Duration) + } + } + // Validate health exporter configuration if present if config.HealthExporter != nil { // Validate health check interval @@ -183,6 +214,51 @@ func (config *Config) Validate() error { return nil } +func validateLoopConfig(name string, cfg interface { + GetEnabled() bool + GetInterval() time.Duration +}) error { + if cfg == nil || !cfg.GetEnabled() { + return nil + } + if cfg.GetInterval() <= 0 { + return fmt.Errorf("%s.interval is required when %s is enabled", name, name) + } + if cfg.GetInterval() < time.Minute { + return fmt.Errorf("%s.interval must be at least 1 minute, got %v", name, cfg.GetInterval()) + } + return nil +} + +func (c *InventoryConfig) GetEnabled() bool { + return c != nil && c.Enabled +} + +func (c *InventoryConfig) GetInterval() time.Duration { + if c == nil { + return 0 + } + return c.Interval.Duration +} + +func (c *AttestationConfig) GetEnabled() bool { + return c != nil && c.Enabled +} + +func (c *AttestationConfig) GetInterval() time.Duration { + if c == nil { + return 0 + } + return c.Interval.Duration +} + +func (c *AttestationConfig) GetInitialInterval() time.Duration { + if c == nil { + return 0 + } + return c.InitialInterval.Duration +} + // ShouldEnable returns true if the component should be enabled. // If no components are specified, all components are enabled by default. func (config *Config) ShouldEnable(componentName string) bool { @@ -356,13 +432,6 @@ func extractHealthExporterEntries(cfg *HealthExporterConfig) []ConfigEntry { strValue = fmt.Sprintf("%d", int64(duration.Seconds())) } else if d, ok := value.Interface().(time.Duration); ok { strValue = fmt.Sprintf("%d", int64(d.Seconds())) - } else if att, ok := value.Interface().(AttestationConfig); ok { - // Handle nested AttestationConfig - add individual fields - entries = append(entries, - ConfigEntry{Key: "health_exporter.attestation_interval", Value: fmt.Sprintf("%d", int64(att.Interval.Seconds()))}, - ConfigEntry{Key: "health_exporter.attestation_jitter_enabled", Value: fmt.Sprintf("%t", att.JitterEnabled)}, - ) - continue } default: strValue = fmt.Sprintf("%v", value.Interface()) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 8ed9ad11..5d251c2a 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -39,6 +39,13 @@ func TestDefault(t *testing.T) { assert.Equal(t, DefaultAPIVersion, cfg.APIVersion) assert.Equal(t, DefaultListenAddress, cfg.Address) assert.Equal(t, DefaultRetentionPeriod, cfg.RetentionPeriod) + require.NotNil(t, cfg.Inventory) + assert.True(t, cfg.Inventory.Enabled) + assert.Equal(t, metav1.Duration{Duration: 1 * time.Hour}, cfg.Inventory.Interval) + require.NotNil(t, cfg.Attestation) + assert.True(t, cfg.Attestation.Enabled) + assert.Equal(t, metav1.Duration{Duration: 5 * time.Minute}, cfg.Attestation.InitialInterval) + assert.Equal(t, metav1.Duration{Duration: 24 * time.Hour}, cfg.Attestation.Interval) // State path should be set assert.NotEmpty(t, cfg.State, "State path should be set") @@ -85,6 +92,97 @@ func TestConfigValidation(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "retention_period must be at least 1 minute") }) + + t.Run("inventory sync enabled without interval", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Inventory: &InventoryConfig{ + Enabled: true, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "inventory.interval is required") + }) + + t.Run("inventory sync interval too short", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Inventory: &InventoryConfig{ + Enabled: true, + Interval: metav1.Duration{Duration: 500 * time.Millisecond}, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "inventory.interval must be at least 1 minute") + }) + + t.Run("attestation enabled without interval", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Attestation: &AttestationConfig{ + Enabled: true, + InitialInterval: metav1.Duration{Duration: 5 * time.Minute}, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "attestation.interval is required") + }) + + t.Run("attestation interval too short", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Attestation: &AttestationConfig{ + Enabled: true, + InitialInterval: metav1.Duration{Duration: 5 * time.Minute}, + Interval: metav1.Duration{Duration: 500 * time.Millisecond}, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "attestation.interval must be at least 1 minute") + }) + + t.Run("attestation enabled without initial interval", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Attestation: &AttestationConfig{ + Enabled: true, + Interval: metav1.Duration{Duration: 24 * time.Hour}, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "attestation.initial_interval is required") + }) + + t.Run("attestation initial interval too short", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Attestation: &AttestationConfig{ + Enabled: true, + InitialInterval: metav1.Duration{Duration: 500 * time.Millisecond}, + Interval: metav1.Duration{Duration: 24 * time.Hour}, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "attestation.initial_interval must be at least 1 minute") + }) } func TestComponentSelection(t *testing.T) { @@ -555,9 +653,6 @@ func TestDefaultWithHealthExporter(t *testing.T) { require.NoError(t, err) assert.NotNil(t, cfg.HealthExporter) - // Attestation is always enabled, so we just check configuration - assert.Equal(t, metav1.Duration{Duration: 24 * time.Hour}, cfg.HealthExporter.Attestation.Interval) - assert.True(t, cfg.HealthExporter.Attestation.JitterEnabled) assert.Equal(t, metav1.Duration{Duration: 1 * time.Minute}, cfg.HealthExporter.Interval) assert.Equal(t, metav1.Duration{Duration: 30 * time.Second}, cfg.HealthExporter.Timeout) assert.True(t, cfg.HealthExporter.IncludeMetrics) @@ -634,12 +729,8 @@ func TestToConfigEntries(t *testing.T) { RetentionPeriod: metav1.Duration{Duration: 24 * time.Hour}, Components: []string{}, HealthExporter: &HealthExporterConfig{ - MetricsEndpoint: "https://example.com/metrics", - LogsEndpoint: "https://example.com/logs", - Attestation: AttestationConfig{ - Interval: metav1.Duration{Duration: 24 * time.Hour}, - JitterEnabled: true, - }, + MetricsEndpoint: "https://example.com/metrics", + LogsEndpoint: "https://example.com/logs", Interval: metav1.Duration{Duration: 1 * time.Minute}, Timeout: metav1.Duration{Duration: 30 * time.Second}, IncludeMetrics: true, @@ -662,7 +753,6 @@ func TestToConfigEntries(t *testing.T) { // Check for health exporter entries foundMetricsEndpoint := false foundLogsEndpoint := false - foundAttestation := false foundAuthToken := false for _, entry := range entries { @@ -672,10 +762,6 @@ func TestToConfigEntries(t *testing.T) { if entry.Key == "health_exporter.logs_endpoint" { foundLogsEndpoint = true } - if entry.Key == "health_exporter.attestation_jitter_enabled" { - foundAttestation = true - assert.Equal(t, "true", entry.Value) - } if entry.Key == "auth_token" || entry.Key == "health_exporter.auth_token" { foundAuthToken = true } @@ -684,7 +770,6 @@ func TestToConfigEntries(t *testing.T) { // Endpoints are excluded - they're enrollment-assigned, not user config assert.False(t, foundMetricsEndpoint, "metrics_endpoint should not be exported") assert.False(t, foundLogsEndpoint, "logs_endpoint should not be exported") - assert.True(t, foundAttestation) assert.False(t, foundAuthToken, "auth_token should not be exported") }) diff --git a/internal/config/default.go b/internal/config/default.go index ec55391a..9d413dbc 100644 --- a/internal/config/default.go +++ b/internal/config/default.go @@ -75,17 +75,22 @@ func Default(ctx context.Context, opts ...OpOption) (*Config, error) { Address: DefaultListenAddress, RetentionPeriod: DefaultRetentionPeriod, EnableFaultInjection: false, // Disabled by default for security + Inventory: &InventoryConfig{ + Enabled: true, + Interval: metav1.Duration{Duration: 1 * time.Hour}, + }, + Attestation: &AttestationConfig{ + Enabled: true, + InitialInterval: metav1.Duration{Duration: 5 * time.Minute}, + Interval: metav1.Duration{Duration: 24 * time.Hour}, + }, NvidiaToolOverwrites: nvidiacommon.ToolOverwrites{ InfinibandClassRootDir: options.InfinibandClassRootDir, }, // Health exporter is enabled by default HealthExporter: &HealthExporterConfig{ - MetricsEndpoint: "", - LogsEndpoint: "", - Attestation: AttestationConfig{ - Interval: metav1.Duration{Duration: 24 * time.Hour}, - JitterEnabled: true, - }, + MetricsEndpoint: "", + LogsEndpoint: "", AuthToken: "", Interval: metav1.Duration{Duration: 1 * time.Minute}, Timeout: metav1.Duration{Duration: 30 * time.Second}, diff --git a/internal/enrollment/enrollment.go b/internal/enrollment/enrollment.go index 585ffa7d..d1c90df4 100644 --- a/internal/enrollment/enrollment.go +++ b/internal/enrollment/enrollment.go @@ -132,7 +132,7 @@ func syncInventoryOnce(ctx context.Context) error { DisabledComponents: disabledComponents, }, ) - manager := inventory.NewManager(src, sink, 0) + manager := inventory.NewManager(src, sink, inventory.InventoryConfig{}) _, err = manager.CollectOnce(ctx) return err } diff --git a/internal/exporter/collector/collector_test.go b/internal/exporter/collector/collector_test.go index 0ba0bb78..c64c150a 100644 --- a/internal/exporter/collector/collector_test.go +++ b/internal/exporter/collector/collector_test.go @@ -96,7 +96,6 @@ func TestCollector_Collect_BasicFlow(t *testing.T) { IncludeMetrics: false, IncludeEvents: false, IncludeComponentData: false, - Attestation: config.AttestationConfig{}, } collector := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil) @@ -492,7 +491,6 @@ func TestCollector_AllFeaturesEnabled(t *testing.T) { IncludeMetrics: true, IncludeEvents: true, IncludeComponentData: true, - Attestation: config.AttestationConfig{}, MetricsLookback: metav1.Duration{Duration: 5 * time.Minute}, EventsLookback: metav1.Duration{Duration: 5 * time.Minute}, } diff --git a/internal/inventory/manager.go b/internal/inventory/manager.go index 1668a4dd..b7fa1491 100644 --- a/internal/inventory/manager.go +++ b/internal/inventory/manager.go @@ -17,10 +17,14 @@ package inventory import ( "context" + "crypto/rand" "errors" "fmt" + "math/big" "sync" "time" + + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" ) // Manager coordinates periodic inventory collection into a store. @@ -30,44 +34,45 @@ type Manager interface { } type manager struct { - mu sync.RWMutex - source Source - sink Sink - interval time.Duration + mu sync.RWMutex + source Source + sink Sink + config InventoryConfig lastSnapshot *Snapshot lastExportedHash string } // NewManager creates an inventory manager. -func NewManager(source Source, sink Sink, interval time.Duration) Manager { +func NewManager(source Source, sink Sink, cfg InventoryConfig) Manager { return &manager{ - source: source, - sink: sink, - interval: interval, + source: source, + sink: sink, + config: cfg, } } func (m *manager) Run(ctx context.Context) error { if _, err := m.CollectOnce(ctx); err != nil { - return err + log.Logger.Warnw("initial inventory collection failed", "error", err) } - - if m.interval <= 0 { + if m.config.Interval <= 0 { return nil } - - ticker := time.NewTicker(m.interval) - defer ticker.Stop() + if m.config.JitterEnabled { + if err := sleepWithContext(ctx, calculateJitter(initialJitterCap(m.config.Interval))); err != nil { + return err + } + } for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - if _, err := m.CollectOnce(ctx); err != nil { - return err - } + _, err := m.CollectOnce(ctx) + nextInterval := m.config.Interval + if err != nil && m.config.RetryInterval > 0 && m.config.RetryInterval < nextInterval { + nextInterval = m.config.RetryInterval + calculateJitter(retryJitterCap(m.config.RetryInterval)) + } + if err := sleepWithContext(ctx, nextInterval); err != nil { + return err } } } @@ -109,3 +114,63 @@ func (m *manager) CollectOnce(ctx context.Context) (*Snapshot, error) { return snap, nil } + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } + } + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func calculateJitter(maxJitter time.Duration) time.Duration { + if maxJitter <= 0 { + return 0 + } + maxMs := int64(maxJitter / time.Millisecond) + if maxMs <= 0 { + return 0 + } + randomMs, err := rand.Int(rand.Reader, big.NewInt(maxMs)) + if err != nil { + log.Logger.Warnw("failed to generate secure inventory jitter, using fallback", "error", err) + return time.Duration(time.Now().UnixNano()%maxMs) * time.Millisecond + } + return time.Duration(randomMs.Int64()) * time.Millisecond +} + +func initialJitterCap(interval time.Duration) time.Duration { + if interval <= 0 { + return 0 + } + jitter := interval / 4 + const maxInitialJitter = 30 * time.Minute + if jitter > maxInitialJitter { + return maxInitialJitter + } + return jitter +} + +func retryJitterCap(retryInterval time.Duration) time.Duration { + if retryInterval <= 0 { + return 0 + } + jitter := retryInterval / 2 + const maxRetryJitter = 5 * time.Minute + if jitter > maxRetryJitter { + return maxRetryJitter + } + return jitter +} diff --git a/internal/inventory/manager_run_test.go b/internal/inventory/manager_run_test.go index 9bbce141..1048236c 100644 --- a/internal/inventory/manager_run_test.go +++ b/internal/inventory/manager_run_test.go @@ -33,13 +33,13 @@ type nilSnapshotSource struct{} func (nilSnapshotSource) Collect(context.Context) (*Snapshot, error) { return nil, nil } func TestManagerCollectOnceErrors(t *testing.T) { - _, err := NewManager(nil, nil, 0).CollectOnce(context.Background()) + _, err := NewManager(nil, nil, InventoryConfig{}).CollectOnce(context.Background()) require.ErrorContains(t, err, "inventory source is required") - _, err = NewManager(errSource{err: errors.New("boom")}, nil, 0).CollectOnce(context.Background()) + _, err = NewManager(errSource{err: errors.New("boom")}, nil, InventoryConfig{}).CollectOnce(context.Background()) require.ErrorContains(t, err, "boom") - _, err = NewManager(nilSnapshotSource{}, nil, 0).CollectOnce(context.Background()) + _, err = NewManager(nilSnapshotSource{}, nil, InventoryConfig{}).CollectOnce(context.Background()) require.ErrorContains(t, err, "nil snapshot") } @@ -49,7 +49,7 @@ func TestManagerRunWithZeroInterval(t *testing.T) { } sink := &fakeSink{} - err := NewManager(src, sink, 0).Run(context.Background()) + err := NewManager(src, sink, InventoryConfig{}).Run(context.Background()) require.NoError(t, err) require.Len(t, sink.exported, 1) } @@ -63,7 +63,7 @@ func TestManagerRunStopsOnContextCancel(t *testing.T) { done := make(chan error, 1) go func() { - done <- NewManager(src, sink, 10*time.Millisecond).Run(ctx) + done <- NewManager(src, sink, InventoryConfig{Interval: 10 * time.Millisecond}).Run(ctx) }() time.Sleep(25 * time.Millisecond) @@ -73,3 +73,47 @@ func TestManagerRunStopsOnContextCancel(t *testing.T) { require.ErrorIs(t, err, context.Canceled) require.NotEmpty(t, sink.exported) } + +func TestSleepWithContext(t *testing.T) { + require.NoError(t, sleepWithContext(context.Background(), 0)) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + require.ErrorIs(t, sleepWithContext(ctx, 0), context.Canceled) + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + require.ErrorIs(t, sleepWithContext(ctx, time.Hour), context.Canceled) +} + +func TestInventoryJitterHelpers(t *testing.T) { + require.Equal(t, time.Duration(0), initialJitterCap(0)) + require.Equal(t, 15*time.Second, initialJitterCap(time.Minute)) + require.Equal(t, 30*time.Minute, initialJitterCap(4*time.Hour)) + + require.Equal(t, time.Duration(0), retryJitterCap(0)) + require.Equal(t, 30*time.Second, retryJitterCap(time.Minute)) + require.Equal(t, 5*time.Minute, retryJitterCap(20*time.Minute)) + + require.Equal(t, time.Duration(0), calculateJitter(0)) + jitter := calculateJitter(50 * time.Millisecond) + require.GreaterOrEqual(t, jitter, time.Duration(0)) + require.Less(t, jitter, 50*time.Millisecond) +} + +func TestManagerRunUsesRetryIntervalWithoutJitter(t *testing.T) { + src := errSource{err: errors.New("boom")} + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + + start := time.Now() + err := NewManager(src, nil, InventoryConfig{ + Interval: time.Hour, + RetryInterval: 5 * time.Millisecond, + }).Run(ctx) + elapsed := time.Since(start) + + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, elapsed, 15*time.Millisecond) + require.Less(t, elapsed, 100*time.Millisecond) +} diff --git a/internal/inventory/manager_test.go b/internal/inventory/manager_test.go index fcc44e75..9533031d 100644 --- a/internal/inventory/manager_test.go +++ b/internal/inventory/manager_test.go @@ -84,7 +84,7 @@ func TestManagerCollectOnceExportsOnlyWhenInventoryChanges(t *testing.T) { }, } sink := &fakeSink{} - mgr := NewManager(src, sink, 0) + mgr := NewManager(src, sink, InventoryConfig{}) snap1, err := mgr.CollectOnce(context.Background()) require.NoError(t, err) diff --git a/internal/inventory/types.go b/internal/inventory/types.go index 5e10116e..5f75cb9a 100644 --- a/internal/inventory/types.go +++ b/internal/inventory/types.go @@ -129,3 +129,10 @@ type Source interface { type Sink interface { Export(ctx context.Context, snap *Snapshot) error } + +// InventoryConfig controls periodic inventory scheduling. +type InventoryConfig struct { + Interval time.Duration + RetryInterval time.Duration + JitterEnabled bool +} diff --git a/internal/server/server.go b/internal/server/server.go index 8649c478..178f249f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -53,8 +53,14 @@ import ( nvidianvml "github.com/NVIDIA/fleet-intelligence-sdk/pkg/nvidia-query/nvml" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" + "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" + inventorysink "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/sink" + inventorysource "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/source" + "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" "github.com/NVIDIA/fleet-intelligence-agent/internal/registry" ) @@ -87,6 +93,12 @@ type Server struct { machineID string } +type inventoryMachineInfoCollectorFunc func(context.Context) (*machineinfo.MachineInfo, error) + +func (f inventoryMachineInfoCollectorFunc) Collect(ctx context.Context) (*machineinfo.MachineInfo, error) { + return f(ctx) +} + // initializeDatabases opens and initializes database connections func initializeDatabases(ctx context.Context, cfg *config.Config) (*sql.DB, *sql.DB, error) { stateFile := ":memory:" @@ -164,6 +176,33 @@ func getHealthCheckInterval(config *config.Config) time.Duration { return healthCheckInterval } +func getInventorySyncInterval(config *config.Config) time.Duration { + if config == nil { + return 0 + } + if config.Inventory != nil { + if !config.Inventory.Enabled { + return 0 + } + return config.Inventory.Interval.Duration + } + return 0 +} + +func getAttestationInterval(config *config.Config) time.Duration { + if config == nil || config.Attestation == nil || !config.Attestation.Enabled { + return 0 + } + return config.Attestation.Interval.Duration +} + +func getAttestationTimeout(config *config.Config) time.Duration { + if config == nil || config.HealthExporter == nil { + return 0 + } + return config.HealthExporter.Timeout.Duration +} + // shouldEnableComponent determines if a component should be enabled based on configuration func shouldEnableComponent(name string, enabledByDefault bool, config *config.Config) bool { shouldEnable := enabledByDefault @@ -337,6 +376,9 @@ func New(ctx context.Context, auditLogger log.AuditLogger, config *config.Config } } + s.startInventoryLoop(ctx, config, nvmlInstance, dcgmGPUIndexes) + s.startAttestationLoop(ctx, config) + // Create and start health exporter with all dependencies if enabled if config.HealthExporter != nil { var err error @@ -368,6 +410,77 @@ func New(ctx context.Context, auditLogger log.AuditLogger, config *config.Config return s, nil } +func (s *Server) startInventoryLoop( + ctx context.Context, + cfg *config.Config, + nvmlInstance nvidianvml.Instance, + dcgmGPUIndexes map[string]string, +) { + interval := getInventorySyncInterval(cfg) + if interval <= 0 { + return + } + + allComponents := registry.AllComponentNames() + apiVersion, retentionPeriodSeconds, enabledComponents, disabledComponents := cfg.InventoryAgentConfig(allComponents) + + source := inventorysource.NewMachineInfoSourceWithAgentConfig( + inventoryMachineInfoCollectorFunc(func(context.Context) (*machineinfo.MachineInfo, error) { + return machineinfo.GetMachineInfo(nvmlInstance, machineinfo.WithDCGMGPUIndexes(dcgmGPUIndexes)) + }), + &inventory.AgentConfig{ + TotalComponents: int64(len(allComponents)), + APIVersion: apiVersion, + RetentionPeriodSeconds: retentionPeriodSeconds, + EnabledComponents: enabledComponents, + DisabledComponents: disabledComponents, + }, + ) + sink := inventorysink.NewBackendSink(agentstate.NewSQLite()) + manager := inventory.NewManager(source, sink, inventory.InventoryConfig{ + Interval: interval, + RetryInterval: inventoryRetryInterval, + JitterEnabled: true, + }) + + go func() { + if err := manager.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.Logger.Errorw("inventory loop manager exited", "error", err) + } + }() +} + +const attestationRetryInterval = 5 * time.Minute +const inventoryRetryInterval = 1 * time.Minute + +func (s *Server) startAttestationLoop(ctx context.Context, cfg *config.Config) { + interval := getAttestationInterval(cfg) + if interval <= 0 { + return + } + + state := agentstate.NewSQLite() + manager := attestation.NewManager( + attestation.NewStateNodeIDProvider(state), + attestation.NewStateJWTProvider(state), + attestation.NewStateNonceProvider(state), + attestation.NewCLIEvidenceCollector(getAttestationTimeout(cfg)), + attestation.NewStateBackendSubmitter(state), + attestation.AttestationConfig{ + InitialInterval: cfg.Attestation.InitialInterval.Duration, + Interval: interval, + RetryInterval: attestationRetryInterval, + JitterEnabled: true, + }, + ) + + go func() { + if err := manager.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.Logger.Errorw("attestation loop exited", "error", err) + } + }() +} + // GetHealthExporter returns the health exporter instance (for offline mode access) func (s *Server) GetHealthExporter() exporter.Exporter { return s.healthExporter diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 5964fdd9..6eb1f6f3 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -79,6 +79,94 @@ func TestGetHealthCheckInterval(t *testing.T) { } } +func TestGetInventorySyncInterval(t *testing.T) { + tests := []struct { + name string + config *config.Config + expected time.Duration + }{ + { + name: "nil config", + config: nil, + expected: 0, + }, + { + name: "uses inventory sync config", + config: &config.Config{ + Inventory: &config.InventoryConfig{ + Enabled: true, + Interval: metav1.Duration{Duration: 2 * time.Minute}, + }, + }, + expected: 2 * time.Minute, + }, + { + name: "disabled inventory sync", + config: &config.Config{ + Inventory: &config.InventoryConfig{ + Enabled: false, + Interval: metav1.Duration{Duration: 2 * time.Minute}, + }, + }, + expected: 0, + }, + { + name: "no inventory or exporter config", + config: &config.Config{}, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, getInventorySyncInterval(tt.config)) + }) + } +} + +func TestGetAttestationSettings(t *testing.T) { + tests := []struct { + name string + config *config.Config + wantInterval time.Duration + wantTimeout time.Duration + }{ + { + name: "nil config", + config: nil, + wantInterval: 0, + wantTimeout: 0, + }, + { + name: "no exporter", + config: &config.Config{}, + wantInterval: 0, + wantTimeout: 0, + }, + { + name: "uses health exporter attestation settings", + config: &config.Config{ + HealthExporter: &config.HealthExporterConfig{ + Timeout: metav1.Duration{Duration: 45 * time.Second}, + }, + Attestation: &config.AttestationConfig{ + Enabled: true, + Interval: metav1.Duration{Duration: 6 * time.Hour}, + }, + }, + wantInterval: 6 * time.Hour, + wantTimeout: 45 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.wantInterval, getAttestationInterval(tt.config)) + assert.Equal(t, tt.wantTimeout, getAttestationTimeout(tt.config)) + }) + } +} + // TestShouldEnableComponent tests the shouldEnableComponent function. func TestShouldEnableComponent(t *testing.T) { tests := []struct { From b60d84516cbb9413de8b816328dac42521701474 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Mon, 20 Apr 2026 11:14:44 -0700 Subject: [PATCH 10/22] refactor: align inventory payload and runtime flow Signed-off-by: Jingxiang Zhang --- internal/backendclient/types.go | 3 +-- internal/config/config.go | 6 ++--- internal/config/config_test.go | 4 +-- internal/endpoint/endpoint.go | 24 ++++++++++++++--- internal/endpoint/endpoint_test.go | 24 ++++++++++++++++- internal/enrollment/enrollment.go | 3 +-- internal/enrollment/enrollment_test.go | 19 ++++++++++++++ internal/exporter/converter/csv.go | 2 +- internal/exporter/converter/csv_test.go | 10 +++---- internal/inventory/hash_test.go | 1 - internal/inventory/manager_run_test.go | 4 +-- internal/inventory/manager_test.go | 3 --- internal/inventory/mapper/backend.go | 3 +-- internal/inventory/mapper/backend_test.go | 5 +--- internal/inventory/sink/backend.go | 9 ++++++- internal/inventory/sink/backend_test.go | 30 ++++++++++++--------- internal/inventory/source/source.go | 4 +-- internal/inventory/source/source_test.go | 27 ++++++++++++++----- internal/inventory/types.go | 4 +-- internal/machineinfo/machineinfo.go | 13 +++++---- internal/machineinfo/machineinfo_test.go | 32 +++++++++++------------ internal/server/server.go | 3 +-- 22 files changed, 150 insertions(+), 83 deletions(-) diff --git a/internal/backendclient/types.go b/internal/backendclient/types.go index 54e6d36d..2ef8504f 100644 --- a/internal/backendclient/types.go +++ b/internal/backendclient/types.go @@ -22,7 +22,7 @@ type NodeUpsertRequest struct { Hostname string `json:"hostname"` AgentConfig AgentConfig `json:"agentConfig,omitempty"` Resources NodeResources `json:"resources"` - FleetintVersion string `json:"gpuHealthVersion"` + AgentVersion string `json:"agentVersion"` GPUDriverVersion string `json:"gpuDriverVersion"` CUDAVersion string `json:"cudaVersion"` DCGMVersion string `json:"dcgmVersion"` @@ -48,7 +48,6 @@ type NodeResources struct { type AgentConfig struct { TotalComponents int64 `json:"totalComponents,omitempty"` - APIVersion string `json:"apiVersion,omitempty"` RetentionPeriodSeconds int64 `json:"retentionPeriodSeconds,omitempty"` EnabledComponents []string `json:"enabledComponents,omitempty"` DisabledComponents []string `json:"disabledComponents,omitempty"` diff --git a/internal/config/config.go b/internal/config/config.go index a200aae9..711cc239 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -342,13 +342,13 @@ func (config *Config) ToConfigEntries(allComponentNames []string) []ConfigEntry // InventoryAgentConfig returns the useful, non-sensitive subset of agent config that should be // persisted with inventory rather than exported through telemetry. -func (config *Config) InventoryAgentConfig(allComponentNames []string) (apiVersion string, retentionPeriodSeconds int64, enabled, disabled []string) { +func (config *Config) InventoryAgentConfig(allComponentNames []string) (retentionPeriodSeconds int64, enabled, disabled []string) { if config == nil { - return "", 0, nil, nil + return 0, nil, nil } enabled, disabled = config.getComponentLists(allComponentNames) - return config.APIVersion, int64(config.RetentionPeriod.Seconds()), enabled, disabled + return int64(config.RetentionPeriod.Seconds()), enabled, disabled } // getComponentLists computes enabled/disabled lists from config rules against all available components. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 5d251c2a..c6467633 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -916,9 +916,7 @@ func TestInventoryAgentConfig(t *testing.T) { Components: []string{"*", "-memory", "-disk"}, } - apiVersion, retentionPeriodSeconds, enabled, disabled := cfg.InventoryAgentConfig(allComponents) - - assert.Equal(t, "v1", apiVersion) + retentionPeriodSeconds, enabled, disabled := cfg.InventoryAgentConfig(allComponents) assert.Equal(t, int64(86400), retentionPeriodSeconds) assert.ElementsMatch(t, []string{"cpu", "gpu"}, enabled) assert.ElementsMatch(t, []string{"memory", "disk"}, disabled) diff --git a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go index 31b1d9b4..321fade4 100644 --- a/internal/endpoint/endpoint.go +++ b/internal/endpoint/endpoint.go @@ -109,19 +109,28 @@ func AgentBaseURL(serverURL *url.URL) *url.URL { return serverURL } -// ValidateBackendEndpoint validates a trusted backend HTTPS endpoint. +// ValidateBackendEndpoint validates a trusted backend endpoint. +// Production backends must use HTTPS. HTTP is accepted only for loopback hosts +// to support local development and smoke tests. func ValidateBackendEndpoint(raw string) (*url.URL, error) { parsed, err := parseURL(raw) if err != nil { return nil, err } + if parsed.Scheme == "http" { + host := parsed.Hostname() + if host == "" || !isLoopbackHost(host) { + return nil, fmt.Errorf("backend endpoint over http must use localhost or a loopback IP, got %q", host) + } + return parsed, nil + } if err := requireScheme(parsed, "https"); err != nil { return nil, err } return parsed, nil } -// DeriveBackendBaseURL converts a legacy HTTPS endpoint URL into its backend base URL. +// DeriveBackendBaseURL converts a legacy backend endpoint URL into its backend base URL. // For example, "https://backend.example.com/api/v1/health/metrics" becomes // "https://backend.example.com". func DeriveBackendBaseURL(raw string) (string, error) { @@ -129,8 +138,15 @@ func DeriveBackendBaseURL(raw string) (string, error) { if err != nil { return "", err } - if err := requireScheme(parsed, "https"); err != nil { - return "", err + if parsed.Scheme == "http" { + host := parsed.Hostname() + if host == "" || !isLoopbackHost(host) { + return "", fmt.Errorf("backend endpoint over http must use localhost or a loopback IP, got %q", host) + } + } else { + if err := requireScheme(parsed, "https"); err != nil { + return "", err + } } return (&url.URL{ Scheme: parsed.Scheme, diff --git a/internal/endpoint/endpoint_test.go b/internal/endpoint/endpoint_test.go index fe9fde54..3a084821 100644 --- a/internal/endpoint/endpoint_test.go +++ b/internal/endpoint/endpoint_test.go @@ -194,7 +194,29 @@ func TestValidateBackendEndpoint(t *testing.T) { _, err := ValidateBackendEndpoint("https://example.com/base") require.NoError(t, err) + _, err = ValidateBackendEndpoint("http://localhost:8080") + require.NoError(t, err) + + _, err = ValidateBackendEndpoint("http://127.0.0.1:8080") + require.NoError(t, err) + _, err = ValidateBackendEndpoint("http://example.com/base") require.Error(t, err) - assert.Contains(t, err.Error(), "https") + assert.Contains(t, err.Error(), "loopback") +} + +func TestDeriveBackendBaseURL(t *testing.T) { + t.Parallel() + + got, err := DeriveBackendBaseURL("https://backend.example.com/api/v1/health/metrics") + require.NoError(t, err) + assert.Equal(t, "https://backend.example.com", got) + + got, err = DeriveBackendBaseURL("http://localhost:8080/api/v1/health/metrics") + require.NoError(t, err) + assert.Equal(t, "http://localhost:8080", got) + + _, err = DeriveBackendBaseURL("http://example.com/api/v1/health/metrics") + require.Error(t, err) + assert.Contains(t, err.Error(), "loopback") } diff --git a/internal/enrollment/enrollment.go b/internal/enrollment/enrollment.go index d1c90df4..1ce73778 100644 --- a/internal/enrollment/enrollment.go +++ b/internal/enrollment/enrollment.go @@ -112,7 +112,7 @@ func syncInventoryOnce(ctx context.Context) error { if err != nil { return fmt.Errorf("load default config for inventory sync: %w", err) } - apiVersion, retentionPeriodSeconds, enabledComponents, disabledComponents := cfg.InventoryAgentConfig(allComponents) + retentionPeriodSeconds, enabledComponents, disabledComponents := cfg.InventoryAgentConfig(allComponents) nvmlInstance, err := nvidianvml.New() if err != nil { @@ -126,7 +126,6 @@ func syncInventoryOnce(ctx context.Context) error { }), &inventory.AgentConfig{ TotalComponents: int64(len(allComponents)), - APIVersion: apiVersion, RetentionPeriodSeconds: retentionPeriodSeconds, EnabledComponents: enabledComponents, DisabledComponents: disabledComponents, diff --git a/internal/enrollment/enrollment_test.go b/internal/enrollment/enrollment_test.go index 338c8285..53467e74 100644 --- a/internal/enrollment/enrollment_test.go +++ b/internal/enrollment/enrollment_test.go @@ -90,6 +90,25 @@ func TestEnrollWorkflowErrors(t *testing.T) { require.Error(t, err) }) + t.Run("localhost http endpoint allowed", func(t *testing.T) { + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + + called := false + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + called = true + require.Equal(t, "http://localhost:8080", rawBaseURL) + return &fakeBackendClient{enrollJWT: "jwt-token"}, nil + } + + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + err := Enroll(context.Background(), "http://localhost:8080", "sak-token") + require.NoError(t, err) + require.True(t, called) + }) + t.Run("backend client creation", func(t *testing.T) { originalFactory := newBackendClient t.Cleanup(func() { newBackendClient = originalFactory }) diff --git a/internal/exporter/converter/csv.go b/internal/exporter/converter/csv.go index 8d12a650..d8c913ab 100644 --- a/internal/exporter/converter/csv.go +++ b/internal/exporter/converter/csv.go @@ -282,7 +282,7 @@ func (c *csvConverter) writeMachineInfoCSV(outputDir, filename string, data *col // Header and basic system info records = append(records, []string{"attribute_name", "attribute_value"}, - []string{"Fleetint Version", i.FleetintVersion}, + []string{"Agent Version", i.AgentVersion}, []string{"Container Runtime Version", i.ContainerRuntimeVersion}, []string{"OS Image", i.OSImage}, []string{"Kernel Version", i.KernelVersion}, diff --git a/internal/exporter/converter/csv_test.go b/internal/exporter/converter/csv_test.go index 6a5905c8..1881c364 100644 --- a/internal/exporter/converter/csv_test.go +++ b/internal/exporter/converter/csv_test.go @@ -177,10 +177,10 @@ func TestCSVConverter_Convert_MachineInfo(t *testing.T) { timestamp := "20251105_120000" machineInfo := &machineinfo.MachineInfo{ - FleetintVersion: "0.1.5", - DCGMVersion: "4.2.3", - OSImage: "Ubuntu 22.04", - KernelVersion: "5.15.0", + AgentVersion: "0.1.5", + DCGMVersion: "4.2.3", + OSImage: "Ubuntu 22.04", + KernelVersion: "5.15.0", CPUInfo: &apiv1.MachineCPUInfo{ Type: "Intel", Manufacturer: "Intel", @@ -279,7 +279,7 @@ func TestCSVConverter_Convert_AllData(t *testing.T) { }, }, MachineInfo: &machineinfo.MachineInfo{ - FleetintVersion: "0.1.5", + AgentVersion: "0.1.5", }, } diff --git a/internal/inventory/hash_test.go b/internal/inventory/hash_test.go index 8a91d7f6..49b5a9c9 100644 --- a/internal/inventory/hash_test.go +++ b/internal/inventory/hash_test.go @@ -25,7 +25,6 @@ import ( func TestComputeHashIgnoresCollectedAtAndExistingHash(t *testing.T) { base := &Snapshot{ CollectedAt: time.Unix(100, 0).UTC(), - NodeID: "node-1", InventoryHash: "old-hash", Hostname: "host-a", MachineID: "machine-id", diff --git a/internal/inventory/manager_run_test.go b/internal/inventory/manager_run_test.go index 1048236c..3f9440ae 100644 --- a/internal/inventory/manager_run_test.go +++ b/internal/inventory/manager_run_test.go @@ -45,7 +45,7 @@ func TestManagerCollectOnceErrors(t *testing.T) { func TestManagerRunWithZeroInterval(t *testing.T) { src := &fakeSource{ - snapshots: []*Snapshot{{NodeID: "node-1", MachineID: "machine-1", Hostname: "host-a"}}, + snapshots: []*Snapshot{{MachineID: "machine-1", Hostname: "host-a"}}, } sink := &fakeSink{} @@ -56,7 +56,7 @@ func TestManagerRunWithZeroInterval(t *testing.T) { func TestManagerRunStopsOnContextCancel(t *testing.T) { src := &fakeSource{ - snapshots: []*Snapshot{{NodeID: "node-1", MachineID: "machine-1", Hostname: "host-a"}}, + snapshots: []*Snapshot{{MachineID: "machine-1", Hostname: "host-a"}}, } sink := &fakeSink{} ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/inventory/manager_test.go b/internal/inventory/manager_test.go index 9533031d..d72b8b2d 100644 --- a/internal/inventory/manager_test.go +++ b/internal/inventory/manager_test.go @@ -56,7 +56,6 @@ func TestManagerCollectOnceExportsOnlyWhenInventoryChanges(t *testing.T) { snapshots: []*Snapshot{ { CollectedAt: time.Unix(100, 0).UTC(), - NodeID: "node-1", Hostname: "host-a", MachineID: "machine-id", Resources: Resources{ @@ -65,7 +64,6 @@ func TestManagerCollectOnceExportsOnlyWhenInventoryChanges(t *testing.T) { }, { CollectedAt: time.Unix(200, 0).UTC(), - NodeID: "node-1", Hostname: "host-a", MachineID: "machine-id", Resources: Resources{ @@ -74,7 +72,6 @@ func TestManagerCollectOnceExportsOnlyWhenInventoryChanges(t *testing.T) { }, { CollectedAt: time.Unix(300, 0).UTC(), - NodeID: "node-1", Hostname: "host-b", MachineID: "machine-id", Resources: Resources{ diff --git a/internal/inventory/mapper/backend.go b/internal/inventory/mapper/backend.go index d5966430..60dd9d78 100644 --- a/internal/inventory/mapper/backend.go +++ b/internal/inventory/mapper/backend.go @@ -67,7 +67,6 @@ func ToNodeUpsertRequest(s *inventory.Snapshot) *backendclient.NodeUpsertRequest Hostname: s.Hostname, AgentConfig: backendclient.AgentConfig{ TotalComponents: s.AgentConfig.TotalComponents, - APIVersion: s.AgentConfig.APIVersion, RetentionPeriodSeconds: s.AgentConfig.RetentionPeriodSeconds, EnabledComponents: append([]string(nil), s.AgentConfig.EnabledComponents...), DisabledComponents: append([]string(nil), s.AgentConfig.DisabledComponents...), @@ -78,7 +77,7 @@ func ToNodeUpsertRequest(s *inventory.Snapshot) *backendclient.NodeUpsertRequest OperatingSystem: s.OperatingSystem, OSImage: s.OSImage, KernelVersion: s.KernelVersion, - FleetintVersion: s.FleetintVersion, + AgentVersion: s.AgentVersion, GPUDriverVersion: s.GPUDriverVersion, CUDAVersion: s.CUDAVersion, DCGMVersion: s.DCGMVersion, diff --git a/internal/inventory/mapper/backend_test.go b/internal/inventory/mapper/backend_test.go index 4fb88192..93f27021 100644 --- a/internal/inventory/mapper/backend_test.go +++ b/internal/inventory/mapper/backend_test.go @@ -29,7 +29,6 @@ func TestToNodeUpsertRequestNil(t *testing.T) { func TestToNodeUpsertRequest(t *testing.T) { req := ToNodeUpsertRequest(&inventory.Snapshot{ - NodeID: "node-1", Hostname: "host-a", MachineID: "machine-id", SystemUUID: "uuid-1", @@ -37,7 +36,7 @@ func TestToNodeUpsertRequest(t *testing.T) { OperatingSystem: "linux", OSImage: "Ubuntu", KernelVersion: "6.5.0", - FleetintVersion: "1.2.3", + AgentVersion: "1.2.3", GPUDriverVersion: "550.54.15", CUDAVersion: "12.4", DCGMVersion: "4.2.3", @@ -47,7 +46,6 @@ func TestToNodeUpsertRequest(t *testing.T) { InventoryHash: "hash-1", AgentConfig: inventory.AgentConfig{ TotalComponents: 30, - APIVersion: "v1", RetentionPeriodSeconds: 86400, EnabledComponents: []string{"cpu", "gpu"}, DisabledComponents: []string{"disk"}, @@ -107,7 +105,6 @@ func TestToNodeUpsertRequest(t *testing.T) { require.Equal(t, "203.0.113.10", req.NetPublicIP) require.Equal(t, "hash-1", req.InventoryHash) require.Equal(t, int64(30), req.AgentConfig.TotalComponents) - require.Equal(t, "v1", req.AgentConfig.APIVersion) require.Equal(t, int64(86400), req.AgentConfig.RetentionPeriodSeconds) require.Equal(t, []string{"cpu", "gpu"}, req.AgentConfig.EnabledComponents) require.Equal(t, []string{"disk"}, req.AgentConfig.DisabledComponents) diff --git a/internal/inventory/sink/backend.go b/internal/inventory/sink/backend.go index 73e4fd64..59f05591 100644 --- a/internal/inventory/sink/backend.go +++ b/internal/inventory/sink/backend.go @@ -63,9 +63,16 @@ func (s *backendSink) Export(ctx context.Context, snap *inventory.Snapshot) erro if !ok || jwt == "" { return inventory.ErrNotReady } + nodeUUID, ok, err := s.state.GetNodeID(ctx) + if err != nil { + return err + } + if !ok || nodeUUID == "" { + return inventory.ErrNotReady + } client, err := s.clientFactory(baseURL) if err != nil { return fmt.Errorf("create backend client: %w", err) } - return client.UpsertNode(ctx, snap.NodeID, mapper.ToNodeUpsertRequest(snap), jwt) + return client.UpsertNode(ctx, nodeUUID, mapper.ToNodeUpsertRequest(snap), jwt) } diff --git a/internal/inventory/sink/backend_test.go b/internal/inventory/sink/backend_test.go index a4ed0d0f..ad23fee2 100644 --- a/internal/inventory/sink/backend_test.go +++ b/internal/inventory/sink/backend_test.go @@ -29,6 +29,7 @@ import ( type fakeState struct { baseURL string jwt string + nodeID string err error } @@ -45,11 +46,16 @@ func (f fakeState) GetJWT(context.Context) (string, bool, error) { } return f.jwt, f.jwt != "", nil } -func (f fakeState) SetJWT(context.Context, string) error { return nil } -func (f fakeState) GetSAK(context.Context) (string, bool, error) { return "", false, nil } -func (f fakeState) SetSAK(context.Context, string) error { return nil } -func (f fakeState) GetNodeID(context.Context) (string, bool, error) { return "", false, nil } -func (f fakeState) SetNodeID(context.Context, string) error { return nil } +func (f fakeState) SetJWT(context.Context, string) error { return nil } +func (f fakeState) GetSAK(context.Context) (string, bool, error) { return "", false, nil } +func (f fakeState) SetSAK(context.Context, string) error { return nil } +func (f fakeState) GetNodeID(context.Context) (string, bool, error) { + if f.err != nil { + return "", false, f.err + } + return f.nodeID, f.nodeID != "", nil +} +func (f fakeState) SetNodeID(context.Context, string) error { return nil } type fakeClient struct { nodeID string @@ -78,21 +84,21 @@ func TestBackendSinkExportNotReady(t *testing.T) { clientFactory: backendclient.New, } - err := s.Export(context.Background(), &inventory.Snapshot{NodeID: "node-1"}) + err := s.Export(context.Background(), &inventory.Snapshot{}) require.ErrorIs(t, err, inventory.ErrNotReady) } func TestBackendSinkExportErrors(t *testing.T) { - err := (&backendSink{}).Export(context.Background(), &inventory.Snapshot{NodeID: "node-1"}) + err := (&backendSink{}).Export(context.Background(), &inventory.Snapshot{}) require.ErrorContains(t, err, "agent state") - err = (&backendSink{state: fakeState{baseURL: "https://example.com", jwt: "jwt"}}).Export(context.Background(), &inventory.Snapshot{NodeID: "node-1"}) + err = (&backendSink{state: fakeState{baseURL: "https://example.com", jwt: "jwt"}}).Export(context.Background(), &inventory.Snapshot{}) require.ErrorContains(t, err, "client factory") err = (&backendSink{ state: fakeState{err: errors.New("state error")}, clientFactory: backendclient.New, - }).Export(context.Background(), &inventory.Snapshot{NodeID: "node-1"}) + }).Export(context.Background(), &inventory.Snapshot{}) require.ErrorContains(t, err, "state error") err = (&backendSink{ @@ -102,11 +108,11 @@ func TestBackendSinkExportErrors(t *testing.T) { require.ErrorContains(t, err, "inventory snapshot") err = (&backendSink{ - state: fakeState{baseURL: "https://example.com", jwt: "jwt"}, + state: fakeState{baseURL: "https://example.com", jwt: "jwt", nodeID: "node-1"}, clientFactory: func(string) (backendclient.Client, error) { return nil, errors.New("client factory error") }, - }).Export(context.Background(), &inventory.Snapshot{NodeID: "node-1"}) + }).Export(context.Background(), &inventory.Snapshot{}) require.ErrorContains(t, err, "create backend client") } @@ -116,6 +122,7 @@ func TestBackendSinkExportUsesState(t *testing.T) { state: fakeState{ baseURL: "https://example.com", jwt: "jwt-token", + nodeID: "node-1", }, clientFactory: func(string) (backendclient.Client, error) { return client, nil @@ -123,7 +130,6 @@ func TestBackendSinkExportUsesState(t *testing.T) { } err := s.Export(context.Background(), &inventory.Snapshot{ - NodeID: "node-1", Hostname: "host-a", MachineID: "machine-id", }) diff --git a/internal/inventory/source/source.go b/internal/inventory/source/source.go index 997b0fbc..79e631b8 100644 --- a/internal/inventory/source/source.go +++ b/internal/inventory/source/source.go @@ -67,7 +67,6 @@ func (s *machineInfoSource) Collect(ctx context.Context) (*inventory.Snapshot, e snap := &inventory.Snapshot{ CollectedAt: time.Now().UTC(), - NodeID: info.MachineID, Hostname: info.Hostname, MachineID: info.MachineID, SystemUUID: info.SystemUUID, @@ -75,14 +74,13 @@ func (s *machineInfoSource) Collect(ctx context.Context) (*inventory.Snapshot, e OperatingSystem: info.OperatingSystem, OSImage: info.OSImage, KernelVersion: info.KernelVersion, - FleetintVersion: info.FleetintVersion, + AgentVersion: info.AgentVersion, GPUDriverVersion: info.GPUDriverVersion, CUDAVersion: info.CUDAVersion, DCGMVersion: info.DCGMVersion, ContainerRuntimeVersion: info.ContainerRuntimeVersion, AgentConfig: s.agentConfig, } - if info.CPUInfo != nil { snap.Resources.CPUInfo = inventory.CPUInfo{ Type: info.CPUInfo.Type, diff --git a/internal/inventory/source/source_test.go b/internal/inventory/source/source_test.go index 43c6ece3..b8d59496 100644 --- a/internal/inventory/source/source_test.go +++ b/internal/inventory/source/source_test.go @@ -38,7 +38,7 @@ func (f fakeMachineInfoCollector) Collect(context.Context) (*machineinfo.Machine func TestMachineInfoSourceCollect(t *testing.T) { src := NewMachineInfoSource(fakeMachineInfoCollector{ info: &machineinfo.MachineInfo{ - FleetintVersion: "1.2.3", + AgentVersion: "1.2.3", GPUDriverVersion: "550.54.15", CUDAVersion: "12.4", DCGMVersion: "4.2.3", @@ -101,7 +101,7 @@ func TestMachineInfoSourceCollect(t *testing.T) { snap, err := src.Collect(context.Background()) require.NoError(t, err) require.NotNil(t, snap) - require.Equal(t, "machine-id", snap.NodeID) + require.Equal(t, "machine-id", snap.MachineID) require.Equal(t, "host-a", snap.Hostname) require.Equal(t, "10.0.0.10", snap.NetPrivateIP) require.Equal(t, "Xeon", snap.Resources.CPUInfo.Type) @@ -118,13 +118,13 @@ func TestMachineInfoSourceCollectWithAgentConfig(t *testing.T) { src := NewMachineInfoSourceWithAgentConfig( fakeMachineInfoCollector{ info: &machineinfo.MachineInfo{ - MachineID: "machine-id", - Hostname: "host-a", + MachineID: "machine-id", + SystemUUID: "system-uuid", + Hostname: "host-a", }, }, &inventory.AgentConfig{ TotalComponents: 42, - APIVersion: "v1", RetentionPeriodSeconds: 86400, EnabledComponents: []string{"cpu", "gpu"}, DisabledComponents: []string{"disk"}, @@ -134,9 +134,24 @@ func TestMachineInfoSourceCollectWithAgentConfig(t *testing.T) { snap, err := src.Collect(context.Background()) require.NoError(t, err) require.NotNil(t, snap) + require.Equal(t, "machine-id", snap.MachineID) require.Equal(t, int64(42), snap.AgentConfig.TotalComponents) - require.Equal(t, "v1", snap.AgentConfig.APIVersion) require.Equal(t, int64(86400), snap.AgentConfig.RetentionPeriodSeconds) require.Equal(t, []string{"cpu", "gpu"}, snap.AgentConfig.EnabledComponents) require.Equal(t, []string{"disk"}, snap.AgentConfig.DisabledComponents) } + +func TestMachineInfoSourceCollectIgnoresSystemUUIDForMachineID(t *testing.T) { + src := NewMachineInfoSource(fakeMachineInfoCollector{ + info: &machineinfo.MachineInfo{ + MachineID: "machine-id", + SystemUUID: "system-uuid", + Hostname: "host-a", + }, + }) + + snap, err := src.Collect(context.Background()) + require.NoError(t, err) + require.NotNil(t, snap) + require.Equal(t, "machine-id", snap.MachineID) +} diff --git a/internal/inventory/types.go b/internal/inventory/types.go index 5f75cb9a..3aac2d9f 100644 --- a/internal/inventory/types.go +++ b/internal/inventory/types.go @@ -28,7 +28,6 @@ var ErrNotReady = errors.New("inventory backend not ready") // Snapshot is the agent-owned inventory state model. type Snapshot struct { CollectedAt time.Time - NodeID string InventoryHash string Hostname string MachineID string @@ -37,7 +36,7 @@ type Snapshot struct { OperatingSystem string OSImage string KernelVersion string - FleetintVersion string + AgentVersion string GPUDriverVersion string CUDAVersion string DCGMVersion string @@ -50,7 +49,6 @@ type Snapshot struct { type AgentConfig struct { TotalComponents int64 - APIVersion string RetentionPeriodSeconds int64 EnabledComponents []string DisabledComponents []string diff --git a/internal/machineinfo/machineinfo.go b/internal/machineinfo/machineinfo.go index 812f60b9..a6613cae 100644 --- a/internal/machineinfo/machineinfo.go +++ b/internal/machineinfo/machineinfo.go @@ -37,10 +37,10 @@ import ( var getDCGMVersion = dcgmversion.DetectHostengineVersion -// MachineInfo is a custom struct that replaces GPUdVersion with FleetintVersion +// MachineInfo is a custom struct that replaces GPUdVersion with AgentVersion. type MachineInfo struct { - // FleetintVersion represents the current version of Fleet Intelligence agent - FleetintVersion string `json:"fleetintVersion,omitempty"` + // AgentVersion represents the current version of the Fleet Intelligence agent. + AgentVersion string `json:"agentVersion,omitempty"` // GPUDriverVersion represents the current version of GPU driver installed GPUDriverVersion string `json:"gpuDriverVersion,omitempty"` // CUDAVersion represents the current version of cuda library. @@ -124,9 +124,9 @@ func GetMachineInfo(nvmlInstance nvidianvml.Instance, opts ...MachineInfoOption) dcgmVersion, _ := getDCGMVersion() - // Convert to our custom MachineInfo struct with Fleet Intelligence version + // Convert to our custom MachineInfo struct with the agent version. return &MachineInfo{ - FleetintVersion: version.Version, + AgentVersion: version.Version, GPUDriverVersion: gpudInfo.GPUDriverVersion, CUDAVersion: gpudInfo.CUDAVersion, DCGMVersion: dcgmVersion, @@ -156,8 +156,7 @@ func (i *MachineInfo) RenderTable(wr io.Writer) { table := tablewriter.NewWriter(wr) table.SetAlignment(tablewriter.ALIGN_CENTER) - // Show Fleetint Version instead of GPUd Version - table.Append([]string{"Fleetint Version", i.FleetintVersion}) + table.Append([]string{"Agent Version", i.AgentVersion}) table.Append([]string{"Container Runtime Version", i.ContainerRuntimeVersion}) table.Append([]string{"OS Image", i.OSImage}) table.Append([]string{"Kernel Version", i.KernelVersion}) diff --git a/internal/machineinfo/machineinfo_test.go b/internal/machineinfo/machineinfo_test.go index 4f10a5db..68b8c880 100644 --- a/internal/machineinfo/machineinfo_test.go +++ b/internal/machineinfo/machineinfo_test.go @@ -54,7 +54,7 @@ func TestGetMachineInfo(t *testing.T) { validate: func(t *testing.T, info *MachineInfo) { assert.NotNil(t, info) // The version should be set from the version package - assert.Equal(t, version.Version, info.FleetintVersion) + assert.Equal(t, version.Version, info.AgentVersion) assert.Equal(t, "4.2.3", info.DCGMVersion) // Other fields should be populated by the underlying GetMachineInfo assert.NotEmpty(t, info.Hostname) @@ -85,7 +85,7 @@ func TestMachineInfoStruct(t *testing.T) { now := metav1.Now() testInfo := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", GPUDriverVersion: "550.54.15", CUDAVersion: "12.4", DCGMVersion: "4.2.3", @@ -128,7 +128,7 @@ func TestMachineInfoStruct(t *testing.T) { } // Verify all fields are set correctly - assert.Equal(t, "1.0.0-test", testInfo.FleetintVersion) + assert.Equal(t, "1.0.0-test", testInfo.AgentVersion) assert.Equal(t, "550.54.15", testInfo.GPUDriverVersion) assert.Equal(t, "12.4", testInfo.CUDAVersion) assert.Equal(t, "4.2.3", testInfo.DCGMVersion) @@ -199,7 +199,7 @@ func TestRenderTable_Empty(t *testing.T) { // TestRenderTable_BasicFields tests RenderTable with basic fields func TestRenderTable_BasicFields(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", ContainerRuntimeVersion: "containerd://1.7.13", OSImage: "Ubuntu 22.04.4 LTS", KernelVersion: "6.5.0-28-generic", @@ -222,7 +222,7 @@ func TestRenderTable_BasicFields(t *testing.T) { assert.Contains(t, output, "4.2.3") // Verify labels are present - assert.Contains(t, output, "Fleetint Version") + assert.Contains(t, output, "Agent Version") assert.Contains(t, output, "Container Runtime Version") assert.Contains(t, output, "OS Image") assert.Contains(t, output, "Kernel Version") @@ -233,7 +233,7 @@ func TestRenderTable_BasicFields(t *testing.T) { // TestRenderTable_WithCPUInfo tests RenderTable with CPU information func TestRenderTable_WithCPUInfo(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", CPUInfo: &apiv1.MachineCPUInfo{ Type: "Intel(R) Xeon(R) CPU", Manufacturer: "GenuineIntel", @@ -264,7 +264,7 @@ func TestRenderTable_WithCPUInfo(t *testing.T) { // TestRenderTable_WithMemoryInfo tests RenderTable with memory information func TestRenderTable_WithMemoryInfo(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", MemoryInfo: &apiv1.MachineMemoryInfo{ TotalBytes: 137438953472, // 128 GiB }, @@ -285,7 +285,7 @@ func TestRenderTable_WithMemoryInfo(t *testing.T) { // TestRenderTable_WithGPUInfo tests RenderTable with GPU information func TestRenderTable_WithGPUInfo(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", GPUDriverVersion: "550.54.15", GPUInfo: &apiv1.MachineGPUInfo{ Product: "NVIDIA A100-SXM4-80GB", @@ -319,7 +319,7 @@ func TestRenderTable_WithGPUInfo(t *testing.T) { // TestRenderTable_WithNICInfo tests RenderTable with network interface information func TestRenderTable_WithNICInfo(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", NICInfo: &apiv1.MachineNICInfo{ PrivateIPInterfaces: []apiv1.MachineNetworkInterface{ { @@ -358,7 +358,7 @@ func TestRenderTable_WithNICInfo(t *testing.T) { // TestRenderTable_WithDiskInfo tests RenderTable with disk information func TestRenderTable_WithDiskInfo(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", DiskInfo: &apiv1.MachineDiskInfo{ ContainerRootDisk: "/dev/sda1", }, @@ -380,7 +380,7 @@ func TestRenderTable_Complete(t *testing.T) { now := metav1.NewTime(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)) info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", GPUDriverVersion: "550.54.15", CUDAVersion: "12.4", ContainerRuntimeVersion: "containerd://1.7.13", @@ -466,7 +466,7 @@ func TestMachineInfo_JSONMarshaling(t *testing.T) { // The actual marshaling is tested implicitly through the struct definition info := &MachineInfo{ - FleetintVersion: "1.0.0", + AgentVersion: "1.0.0", GPUDriverVersion: "550.54.15", CUDAVersion: "12.4", DCGMVersion: "4.2.3", @@ -482,7 +482,7 @@ func TestMachineInfo_JSONMarshaling(t *testing.T) { // Verify all fields have proper json tags assert.NotNil(t, info) - assert.NotEmpty(t, info.FleetintVersion) + assert.NotEmpty(t, info.AgentVersion) assert.NotEmpty(t, info.GPUDriverVersion) assert.NotEmpty(t, info.CUDAVersion) assert.NotEmpty(t, info.DCGMVersion) @@ -506,7 +506,7 @@ func TestGetMachineInfo_DCGMVersionBestEffort(t *testing.T) { // TestRenderTable_WithNilSubStructs tests that RenderTable handles nil sub-structs gracefully func TestRenderTable_WithNilSubStructs(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", // All sub-structs are intentionally nil CPUInfo: nil, MemoryInfo: nil, @@ -526,13 +526,13 @@ func TestRenderTable_WithNilSubStructs(t *testing.T) { // Should still show basic info assert.Contains(t, output, "1.0.0-test") - assert.Contains(t, output, "Fleetint Version") + assert.Contains(t, output, "Agent Version") } // TestRenderTable_EmptyNICList tests RenderTable with empty NIC list func TestRenderTable_EmptyNICList(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", NICInfo: &apiv1.MachineNICInfo{ PrivateIPInterfaces: []apiv1.MachineNetworkInterface{}, }, diff --git a/internal/server/server.go b/internal/server/server.go index 178f249f..f2432e5a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -422,7 +422,7 @@ func (s *Server) startInventoryLoop( } allComponents := registry.AllComponentNames() - apiVersion, retentionPeriodSeconds, enabledComponents, disabledComponents := cfg.InventoryAgentConfig(allComponents) + retentionPeriodSeconds, enabledComponents, disabledComponents := cfg.InventoryAgentConfig(allComponents) source := inventorysource.NewMachineInfoSourceWithAgentConfig( inventoryMachineInfoCollectorFunc(func(context.Context) (*machineinfo.MachineInfo, error) { @@ -430,7 +430,6 @@ func (s *Server) startInventoryLoop( }), &inventory.AgentConfig{ TotalComponents: int64(len(allComponents)), - APIVersion: apiVersion, RetentionPeriodSeconds: retentionPeriodSeconds, EnabledComponents: enabledComponents, DisabledComponents: disabledComponents, From 8d55dd99e0999ed8509e6885a148763db235918b Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Tue, 21 Apr 2026 13:57:34 -0700 Subject: [PATCH 11/22] fix: review comments Signed-off-by: Jingxiang Zhang --- cmd/fleetint/unenroll.go | 4 ++- cmd/fleetint/unenroll_test.go | 58 ++++++++++++++++++++++++++++++ internal/agentstate/sqlite.go | 30 ++++++++++++---- internal/agentstate/sqlite_test.go | 40 +++++++++++++++++++++ internal/enrollment/enrollment.go | 10 +++--- 5 files changed, 131 insertions(+), 11 deletions(-) create mode 100644 cmd/fleetint/unenroll_test.go diff --git a/cmd/fleetint/unenroll.go b/cmd/fleetint/unenroll.go index 381b31d1..ad8d12ef 100644 --- a/cmd/fleetint/unenroll.go +++ b/cmd/fleetint/unenroll.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "fmt" + "strings" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" @@ -75,7 +76,8 @@ func removeEnrollmentMetadata(ctx context.Context, dbRW *sql.DB) error { } // Build batch delete query - query := "DELETE FROM gpud_metadata WHERE key IN (?, ?, ?, ?, ?, ?, ?)" + placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(keysToDelete)), ", ") + query := fmt.Sprintf("DELETE FROM gpud_metadata WHERE key IN (%s)", placeholders) // Convert string slice to []interface{} for ExecContext args := make([]interface{}, len(keysToDelete)) diff --git a/cmd/fleetint/unenroll_test.go b/cmd/fleetint/unenroll_test.go new file mode 100644 index 00000000..d8f7adb2 --- /dev/null +++ b/cmd/fleetint/unenroll_test.go @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "path/filepath" + "testing" + + pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" + "github.com/stretchr/testify/require" +) + +func TestRemoveEnrollmentMetadata(t *testing.T) { + t.Parallel() + + db, err := sqlite.Open(filepath.Join(t.TempDir(), "agent.state")) + require.NoError(t, err) + defer db.Close() + + ctx := context.Background() + require.NoError(t, pkgmetadata.CreateTableMetadata(ctx, db)) + + for key, value := range map[string]string{ + pkgmetadata.MetadataKeyToken: "jwt-token", + "sak_token": "sak-token", + "backend_base_url": "https://backend.example.com", + "enroll_endpoint": "https://backend.example.com/api/v1/enroll", + "metrics_endpoint": "https://backend.example.com/api/v1/health/metrics", + "logs_endpoint": "https://backend.example.com/api/v1/health/logs", + "nonce_endpoint": "https://backend.example.com/api/v1/attest/nonce", + "keep_me": "still-here", + } { + require.NoError(t, pkgmetadata.SetMetadata(ctx, db, key, value)) + } + + require.NoError(t, removeEnrollmentMetadata(ctx, db)) + + for _, key := range []string{ + pkgmetadata.MetadataKeyToken, + "sak_token", + "backend_base_url", + "enroll_endpoint", + "metrics_endpoint", + "logs_endpoint", + "nonce_endpoint", + } { + value, err := pkgmetadata.ReadMetadata(ctx, db, key) + require.NoError(t, err) + require.Empty(t, value) + } + + value, err := pkgmetadata.ReadMetadata(ctx, db, "keep_me") + require.NoError(t, err) + require.Equal(t, "still-here", value) +} diff --git a/internal/agentstate/sqlite.go b/internal/agentstate/sqlite.go index fa25babb..cc969004 100644 --- a/internal/agentstate/sqlite.go +++ b/internal/agentstate/sqlite.go @@ -18,10 +18,13 @@ package agentstate import ( "context" "database/sql" + "errors" "fmt" + "strings" pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" + sqlite3 "github.com/mattn/go-sqlite3" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" @@ -65,6 +68,9 @@ func (s *sqliteState) GetBackendBaseURL(ctx context.Context) (string, bool, erro } func (s *sqliteState) SetBackendBaseURL(ctx context.Context, value string) error { + if _, err := endpoint.ValidateBackendEndpoint(value); err != nil { + return fmt.Errorf("validate backend base URL: %w", err) + } return s.setMetadata(ctx, metadataKeyBackendBaseURL, value) } @@ -101,6 +107,9 @@ func (s *sqliteState) getMetadata(ctx context.Context, key string) (string, bool value, err := pkgmetadata.ReadMetadata(ctx, db, key) if err != nil { + if isMetadataAbsentErr(err) { + return "", false, nil + } return "", false, fmt.Errorf("read metadata %q: %w", key, err) } if value == "" { @@ -116,18 +125,18 @@ func (s *sqliteState) setMetadata(ctx context.Context, key, value string) error } defer db.Close() - if err := pkgmetadata.CreateTableMetadata(ctx, db); err != nil { - return fmt.Errorf("create metadata table: %w", err) - } - if err := pkgmetadata.SetMetadata(ctx, db, key, value); err != nil { - return fmt.Errorf("set metadata %q: %w", key, err) - } stateFile, err := s.stateFileFn() if err == nil { if err := config.SecureStateFilePermissions(stateFile); err != nil { return fmt.Errorf("secure state file permissions: %w", err) } } + if err := pkgmetadata.CreateTableMetadata(ctx, db); err != nil { + return fmt.Errorf("create metadata table: %w", err) + } + if err := pkgmetadata.SetMetadata(ctx, db, key, value); err != nil { + return fmt.Errorf("set metadata %q: %w", key, err) + } return nil } @@ -143,6 +152,15 @@ func (s *sqliteState) openReadOnly() (*sql.DB, error) { return db, nil } +func isMetadataAbsentErr(err error) bool { + if errors.Is(err, sql.ErrNoRows) { + return true + } + + var sqliteErr sqlite3.Error + return errors.As(err, &sqliteErr) && sqliteErr.Code == sqlite3.ErrError && strings.Contains(strings.ToLower(sqliteErr.Error()), "no such table") +} + func (s *sqliteState) openReadWrite() (*sql.DB, error) { stateFile, err := s.stateFileFn() if err != nil { diff --git a/internal/agentstate/sqlite_test.go b/internal/agentstate/sqlite_test.go index 442f37d6..3a8d6554 100644 --- a/internal/agentstate/sqlite_test.go +++ b/internal/agentstate/sqlite_test.go @@ -21,6 +21,7 @@ import ( "path/filepath" "testing" + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" "github.com/stretchr/testify/require" ) @@ -85,6 +86,45 @@ func TestSQLiteStateMissingValue(t *testing.T) { require.Empty(t, value) } +func TestSQLiteStateMissingMetadataTableIsTreatedAsAbsent(t *testing.T) { + t.Parallel() + + ctx := context.Background() + state := newTestSQLiteState(t) + + stateFile, err := state.stateFileFn() + require.NoError(t, err) + + db, err := sqlite.Open(stateFile) + require.NoError(t, err) + _, err = db.Exec("PRAGMA user_version = 1") + require.NoError(t, err) + require.NoError(t, db.Close()) + + for _, get := range []func(context.Context) (string, bool, error){ + state.GetJWT, + state.GetSAK, + state.GetNodeID, + } { + value, ok, err := get(ctx) + require.NoError(t, err) + require.False(t, ok) + require.Empty(t, value) + } +} + +func TestSQLiteStateSetBackendBaseURLValidatesInput(t *testing.T) { + t.Parallel() + + state := newTestSQLiteState(t) + + err := state.SetBackendBaseURL(context.Background(), "http://example.com") + require.Error(t, err) + + err = state.SetBackendBaseURL(context.Background(), "not-a-url") + require.Error(t, err) +} + func TestSQLiteStateGetBackendBaseURLFallsBackToLegacyEndpoints(t *testing.T) { t.Parallel() diff --git a/internal/enrollment/enrollment.go b/internal/enrollment/enrollment.go index 1ce73778..13e25895 100644 --- a/internal/enrollment/enrollment.go +++ b/internal/enrollment/enrollment.go @@ -77,18 +77,20 @@ func storeConfigInMetadata(ctx context.Context, baseURL, jwtToken, sakToken stri } defer dbRW.Close() + if err := config.SecureStateFilePermissions(stateFile); err != nil { + return fmt.Errorf("failed to secure state database permissions: %w", err) + } if err := pkgmetadata.CreateTableMetadata(ctx, dbRW); err != nil { return fmt.Errorf("failed to create metadata table: %w", err) } - state := agentstate.NewSQLite() - if err := state.SetSAK(ctx, sakToken); err != nil { + if err := pkgmetadata.SetMetadata(ctx, dbRW, "sak_token", sakToken); err != nil { return fmt.Errorf("failed to set SAK token: %w", err) } - if err := state.SetJWT(ctx, jwtToken); err != nil { + if err := pkgmetadata.SetMetadata(ctx, dbRW, pkgmetadata.MetadataKeyToken, jwtToken); err != nil { return fmt.Errorf("failed to set JWT token: %w", err) } - if err := state.SetBackendBaseURL(ctx, baseURL); err != nil { + if err := pkgmetadata.SetMetadata(ctx, dbRW, "backend_base_url", baseURL); err != nil { return fmt.Errorf("failed to set backend base URL: %w", err) } if err := config.SecureStateFilePermissions(stateFile); err != nil { From 44938f060e98b2722f795e9ab20bb5f829cec030 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Tue, 21 Apr 2026 14:14:34 -0700 Subject: [PATCH 12/22] chore: add license header to unenroll test Signed-off-by: Jingxiang Zhang --- cmd/fleetint/unenroll_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/cmd/fleetint/unenroll_test.go b/cmd/fleetint/unenroll_test.go index d8f7adb2..b52c9e5e 100644 --- a/cmd/fleetint/unenroll_test.go +++ b/cmd/fleetint/unenroll_test.go @@ -1,5 +1,17 @@ // SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. package main From 47c3759d96889a07759b01a5fab3e14923b3fde1 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Tue, 21 Apr 2026 14:55:53 -0700 Subject: [PATCH 13/22] fix: harden backend state and inventory workflows Signed-off-by: Jingxiang Zhang --- cmd/fleetint/status.go | 25 ++-- cmd/fleetint/unenroll.go | 5 +- internal/agentstate/sqlite.go | 34 ++++-- internal/agentstate/sqlite_test.go | 19 +++ internal/agentstate/state.go | 5 + internal/attestation/backend.go | 15 +++ internal/attestation/backend_test.go | 5 + internal/attestation/manager.go | 4 + internal/attestation/manager_test.go | 8 +- internal/attestation/nonce_test.go | 15 ++- internal/backendclient/client.go | 11 ++ internal/backendclient/client_test.go | 139 ++++++++++++++++++---- internal/enrollment/enrollment.go | 7 +- internal/inventory/manager.go | 34 ++++-- internal/inventory/manager_run_test.go | 8 +- internal/inventory/manager_test.go | 49 ++++++++ internal/inventory/mapper/backend_test.go | 4 + internal/inventory/source/source_test.go | 11 ++ 18 files changed, 324 insertions(+), 74 deletions(-) diff --git a/cmd/fleetint/status.go b/cmd/fleetint/status.go index 627b0af0..edb7666a 100644 --- a/cmd/fleetint/status.go +++ b/cmd/fleetint/status.go @@ -31,6 +31,7 @@ import ( "github.com/NVIDIA/fleet-intelligence-sdk/pkg/systemd" "github.com/urfave/cli" + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" "github.com/NVIDIA/fleet-intelligence-agent/internal/cmdutil" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" @@ -163,7 +164,7 @@ func statusCommand(cliContext *cli.Context) error { } func readEnrollmentStatus(ctx context.Context, dbRO *sql.DB) (*enrollmentStatus, error) { - baseURL, err := pkgmetadata.ReadMetadata(ctx, dbRO, "backend_base_url") + baseURL, err := pkgmetadata.ReadMetadata(ctx, dbRO, agentstate.MetadataKeyBackendBaseURL) if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, fmt.Errorf("failed to read backend base URL: %w", err) } @@ -172,17 +173,19 @@ func readEnrollmentStatus(ctx context.Context, dbRO *sql.DB) (*enrollmentStatus, if baseURL != "" { validated, err := endpoint.ValidateBackendEndpoint(baseURL) if err != nil { - return nil, fmt.Errorf("invalid backend base URL in metadata: %w", err) - } - status.metricsEndpoint, err = endpoint.JoinPath(validated, "api", "v1", "health", "metrics") - if err != nil { - return nil, fmt.Errorf("failed to construct metrics endpoint: %w", err) - } - status.logsEndpoint, err = endpoint.JoinPath(validated, "api", "v1", "health", "logs") - if err != nil { - return nil, fmt.Errorf("failed to construct logs endpoint: %w", err) + log.Logger.Warnw("ignoring invalid backend base URL in metadata", "backend_base_url", baseURL, "error", err) + status.baseURL = "" + } else { + status.metricsEndpoint, err = endpoint.JoinPath(validated, "api", "v1", "health", "metrics") + if err != nil { + return nil, fmt.Errorf("failed to construct metrics endpoint: %w", err) + } + status.logsEndpoint, err = endpoint.JoinPath(validated, "api", "v1", "health", "logs") + if err != nil { + return nil, fmt.Errorf("failed to construct logs endpoint: %w", err) + } + return status, nil } - return status, nil } status.metricsEndpoint, err = readLegacyEndpoint(ctx, dbRO, "metrics_endpoint") diff --git a/cmd/fleetint/unenroll.go b/cmd/fleetint/unenroll.go index ad8d12ef..a968a033 100644 --- a/cmd/fleetint/unenroll.go +++ b/cmd/fleetint/unenroll.go @@ -26,6 +26,7 @@ import ( "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" "github.com/urfave/cli" + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" ) @@ -67,8 +68,8 @@ func removeEnrollmentMetadata(ctx context.Context, dbRW *sql.DB) error { // List of metadata keys to delete keysToDelete := []string{ pkgmetadata.MetadataKeyToken, - "sak_token", - "backend_base_url", + agentstate.MetadataKeySAKToken, + agentstate.MetadataKeyBackendBaseURL, "enroll_endpoint", "metrics_endpoint", "logs_endpoint", diff --git a/internal/agentstate/sqlite.go b/internal/agentstate/sqlite.go index cc969004..f8d694ff 100644 --- a/internal/agentstate/sqlite.go +++ b/internal/agentstate/sqlite.go @@ -30,8 +30,6 @@ import ( "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" ) -const metadataKeyBackendBaseURL = "backend_base_url" - type sqliteState struct { stateFileFn func() (string, error) } @@ -48,20 +46,32 @@ func (s *sqliteState) GetBackendBaseURL(ctx context.Context) (string, bool, erro } defer db.Close() - if value, err := pkgmetadata.ReadMetadata(ctx, db, metadataKeyBackendBaseURL); err == nil && value != "" { + value, err := pkgmetadata.ReadMetadata(ctx, db, MetadataKeyBackendBaseURL) + switch { + case err == nil && value != "": return value, true, nil + case err == nil || isMetadataAbsentErr(err): + // fall through to legacy endpoint keys + default: + return "", false, fmt.Errorf("read metadata %q: %w", MetadataKeyBackendBaseURL, err) } for _, key := range []string{"enroll_endpoint", "metrics_endpoint", "logs_endpoint", "nonce_endpoint"} { value, err := pkgmetadata.ReadMetadata(ctx, db, key) - if err != nil || value == "" { + switch { + case err == nil && value == "": continue + case err == nil: + baseURL, err := endpoint.DeriveBackendBaseURL(value) + if err != nil { + return "", false, fmt.Errorf("derive backend base URL from metadata %q: %w", key, err) + } + return baseURL, true, nil + case isMetadataAbsentErr(err): + continue + default: + return "", false, fmt.Errorf("read metadata %q: %w", key, err) } - baseURL, err := endpoint.DeriveBackendBaseURL(value) - if err != nil { - return "", false, fmt.Errorf("derive backend base URL from metadata %q: %w", key, err) - } - return baseURL, true, nil } return "", false, nil @@ -71,7 +81,7 @@ func (s *sqliteState) SetBackendBaseURL(ctx context.Context, value string) error if _, err := endpoint.ValidateBackendEndpoint(value); err != nil { return fmt.Errorf("validate backend base URL: %w", err) } - return s.setMetadata(ctx, metadataKeyBackendBaseURL, value) + return s.setMetadata(ctx, MetadataKeyBackendBaseURL, value) } func (s *sqliteState) GetJWT(ctx context.Context) (string, bool, error) { @@ -83,11 +93,11 @@ func (s *sqliteState) SetJWT(ctx context.Context, value string) error { } func (s *sqliteState) GetSAK(ctx context.Context) (string, bool, error) { - return s.getMetadata(ctx, "sak_token") + return s.getMetadata(ctx, MetadataKeySAKToken) } func (s *sqliteState) SetSAK(ctx context.Context, value string) error { - return s.setMetadata(ctx, "sak_token", value) + return s.setMetadata(ctx, MetadataKeySAKToken, value) } func (s *sqliteState) GetNodeID(ctx context.Context) (string, bool, error) { diff --git a/internal/agentstate/sqlite_test.go b/internal/agentstate/sqlite_test.go index 3a8d6554..fede66d0 100644 --- a/internal/agentstate/sqlite_test.go +++ b/internal/agentstate/sqlite_test.go @@ -17,6 +17,7 @@ package agentstate import ( "context" + "database/sql" "errors" "path/filepath" "testing" @@ -161,3 +162,21 @@ func TestNewSQLite(t *testing.T) { t.Parallel() require.NotNil(t, NewSQLite()) } + +func TestSQLiteStateGetBackendBaseURLPropagatesReadErrors(t *testing.T) { + t.Parallel() + + ctx := context.Background() + state := newTestSQLiteState(t) + + stateFile, err := state.stateFileFn() + require.NoError(t, err) + + db, err := sqlite.Open(stateFile) + require.NoError(t, err) + require.NoError(t, db.Close()) + + _, _, err = state.GetBackendBaseURL(ctx) + require.Error(t, err) + require.NotErrorIs(t, err, sql.ErrNoRows) +} diff --git a/internal/agentstate/state.go b/internal/agentstate/state.go index 74ce21c9..e28a9dfd 100644 --- a/internal/agentstate/state.go +++ b/internal/agentstate/state.go @@ -18,6 +18,11 @@ package agentstate import "context" +const ( + MetadataKeyBackendBaseURL = "backend_base_url" + MetadataKeySAKToken = "sak_token" +) + // State provides local persisted metadata/state access for backend workflows. type State interface { GetBackendBaseURL(ctx context.Context) (value string, ok bool, err error) diff --git a/internal/attestation/backend.go b/internal/attestation/backend.go index 388d38b0..3dfd92f9 100644 --- a/internal/attestation/backend.go +++ b/internal/attestation/backend.go @@ -105,6 +105,21 @@ func NewStateBackendSubmitter(state agentstate.State) Submitter { } func (s *stateSubmitter) Submit(ctx context.Context, result *Result, jwt string) error { + if result == nil { + return fmt.Errorf("attestation submission requires result") + } + if result.NodeID == "" { + nodeID, ok, err := s.factory.state.GetNodeID(ctx) + if err != nil { + return err + } + if !ok || nodeID == "" { + return fmt.Errorf("%w: node ID not available in agent state", ErrNotEnrolled) + } + cloned := *result + cloned.NodeID = nodeID + result = &cloned + } client, err := s.factory.client(ctx) if err != nil { return err diff --git a/internal/attestation/backend_test.go b/internal/attestation/backend_test.go index cf4a1498..38e40054 100644 --- a/internal/attestation/backend_test.go +++ b/internal/attestation/backend_test.go @@ -179,6 +179,11 @@ func TestStateProvidersAndSubmitter(t *testing.T) { require.Equal(t, "jwt-token", recording.lastJWT) require.NotNil(t, recording.lastReq) require.Equal(t, "BLACKWELL", recording.lastReq.AttestationData.SDKResponse.Evidences[0].Arch) + + recording.lastNodeID = "" + err = NewStateBackendSubmitter(state).Submit(context.Background(), &Result{}, "jwt-token") + require.NoError(t, err) + require.Equal(t, "node-1", recording.lastNodeID) } func TestStateProvidersPropagateBackendClientConstructionErrors(t *testing.T) { diff --git a/internal/attestation/manager.go b/internal/attestation/manager.go index 52d67880..4497e425 100644 --- a/internal/attestation/manager.go +++ b/internal/attestation/manager.go @@ -143,6 +143,7 @@ func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { jwt = refreshedJWT } sdkResp, err := m.collector.Collect(ctx, nonce) + collectErr := err result := &Result{ CollectedAt: time.Now().UTC(), NodeID: nodeID, @@ -165,6 +166,9 @@ func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { if err := m.submitter.Submit(ctx, result, jwt); err != nil { return nil, err } + if collectErr != nil { + return result, collectErr + } return result, nil } diff --git a/internal/attestation/manager_test.go b/internal/attestation/manager_test.go index 5f9c8d8c..ab0dfa3f 100644 --- a/internal/attestation/manager_test.go +++ b/internal/attestation/manager_test.go @@ -50,9 +50,11 @@ func (p *testNonceProvider) GetNonce(context.Context, string, string) (string, t type testEvidenceCollector struct { resp *SDKResponse err error + n int } func (c *testEvidenceCollector) Collect(context.Context, string) (*SDKResponse, error) { + c.n++ return c.resp, c.err } @@ -64,9 +66,11 @@ type submitted struct { type testSubmitter struct { submitted submitted err error + count int } func (s *testSubmitter) Submit(_ context.Context, result *Result, jwt string) error { + s.count++ s.submitted = submitted{result: result, jwt: jwt} return s.err } @@ -107,7 +111,7 @@ func TestCollectOnceCollectorFailureStillSubmitsFailureResult(t *testing.T) { ) result, err := manager.CollectOnce(context.Background()) - require.NoError(t, err) + require.ErrorContains(t, err, "collect failed") require.False(t, result.Success) require.Equal(t, "collect failed", result.ErrorMessage) require.NotNil(t, submitter.submitted.result) @@ -160,6 +164,8 @@ func TestManagerRunUsesRetryIntervalOnFailure(t *testing.T) { require.NotNil(t, last) require.False(t, last.Success) require.Equal(t, "collect failed", last.ErrorMessage) + require.GreaterOrEqual(t, collector.n, 2) + require.GreaterOrEqual(t, submitter.count, 2) } func TestManagerHelpersAndSubmitterErrors(t *testing.T) { diff --git a/internal/attestation/nonce_test.go b/internal/attestation/nonce_test.go index 8c09bc28..77273f48 100644 --- a/internal/attestation/nonce_test.go +++ b/internal/attestation/nonce_test.go @@ -26,26 +26,33 @@ import ( ) type testNonceClient struct { - resp *backendclient.NonceResponse + resp *backendclient.NonceResponse + gotNodeID string + gotJWT string } -func (c *testNonceClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { +func (c *testNonceClient) GetNonce(_ context.Context, nodeID, jwt string) (*backendclient.NonceResponse, error) { + c.gotNodeID = nodeID + c.gotJWT = jwt return c.resp, nil } func TestBackendNonceProvider(t *testing.T) { refreshTS := time.Now().UTC() - provider := NewBackendNonceProvider(&testNonceClient{ + client := &testNonceClient{ resp: &backendclient.NonceResponse{ Nonce: "abc123", NonceRefreshTimestamp: refreshTS, JWTAssertion: "new-jwt", }, - }) + } + provider := NewBackendNonceProvider(client) nonce, ts, jwt, err := provider.GetNonce(context.Background(), "node-1", "jwt-token") require.NoError(t, err) require.Equal(t, "abc123", nonce) require.Equal(t, refreshTS, ts) require.Equal(t, "new-jwt", jwt) + require.Equal(t, "node-1", client.gotNodeID) + require.Equal(t, "jwt-token", client.gotJWT) } diff --git a/internal/backendclient/client.go b/internal/backendclient/client.go index 73f5dddc..573d49ec 100644 --- a/internal/backendclient/client.go +++ b/internal/backendclient/client.go @@ -35,6 +35,9 @@ const ( maxResponseBodyBytes = 1 << 20 ) +var errRedirectNotAllowed = errors.New("backend redirects are not allowed") +var errNilBaseURL = errors.New("backend base URL is required") + // Client is the backend workflow client used by enrollment, inventory, and attestation paths. type Client interface { Enroll(ctx context.Context, sakToken string) (jwt string, err error) @@ -64,6 +67,11 @@ func NewWithHTTPClient(baseURL *url.URL, httpClient *http.Client) Client { if httpClient == nil { httpClient = &http.Client{Timeout: 30 * time.Second} } + if httpClient.CheckRedirect == nil { + httpClient.CheckRedirect = func(*http.Request, []*http.Request) error { + return errRedirectNotAllowed + } + } return &client{ httpClient: httpClient, baseURL: baseURL, @@ -154,6 +162,9 @@ func (c *client) RefreshToken(ctx context.Context, jwt string) (string, error) { } func (c *client) doJSON(ctx context.Context, method string, pathElems []string, bearerToken string, reqBody any, respBody any) error { + if c.baseURL == nil { + return errNilBaseURL + } requestURL, err := endpoint.JoinPath(c.baseURL, pathElems...) if err != nil { return fmt.Errorf("failed to construct request URL: %w", err) diff --git a/internal/backendclient/client_test.go b/internal/backendclient/client_test.go index 75735d63..86cded32 100644 --- a/internal/backendclient/client_test.go +++ b/internal/backendclient/client_test.go @@ -23,19 +23,27 @@ import ( "net/http/httptest" "net/url" "strings" + "sync" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestClient_Enroll(t *testing.T) { t.Parallel() + var ( + gotMethod string + gotPath string + gotAuth string + gotUA string + ) server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.Equal(t, "/v1/agent/enroll", r.URL.Path) - require.Equal(t, "Bearer sak-token", r.Header.Get("Authorization")) - require.Equal(t, userAgent, r.Header.Get("User-Agent")) + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + gotUA = r.Header.Get("User-Agent") w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]string{"jwtAssertion": "jwt-token"}) })) @@ -46,6 +54,10 @@ func TestClient_Enroll(t *testing.T) { jwt, err := c.Enroll(context.Background(), "sak-token") require.NoError(t, err) require.Equal(t, "jwt-token", jwt) + require.Equal(t, http.MethodPost, gotMethod) + require.Equal(t, "/v1/agent/enroll", gotPath) + require.Equal(t, "Bearer sak-token", gotAuth) + require.Equal(t, userAgent, gotUA) } func TestNew(t *testing.T) { @@ -59,14 +71,17 @@ func TestNew(t *testing.T) { func TestClient_UpsertNode(t *testing.T) { t.Parallel() + var ( + gotMethod string + gotPath string + gotAuth string + gotReq NodeUpsertRequest + ) server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPut, r.Method) - require.Equal(t, "/v1/agent/nodes/node-1", r.URL.Path) - require.Equal(t, "Bearer jwt-token", r.Header.Get("Authorization")) - - var req NodeUpsertRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) - require.Equal(t, "node-1", req.Hostname) + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + _ = json.NewDecoder(r.Body).Decode(&gotReq) w.WriteHeader(http.StatusOK) })) defer server.Close() @@ -74,15 +89,24 @@ func TestClient_UpsertNode(t *testing.T) { c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) err := c.UpsertNode(context.Background(), "node-1", &NodeUpsertRequest{Hostname: "node-1"}, "jwt-token") require.NoError(t, err) + require.Equal(t, http.MethodPut, gotMethod) + require.Equal(t, "/v1/agent/nodes/node-1", gotPath) + require.Equal(t, "Bearer jwt-token", gotAuth) + require.Equal(t, "node-1", gotReq.Hostname) } func TestClient_GetNonce(t *testing.T) { t.Parallel() + var ( + gotMethod string + gotPath string + gotAuth string + ) server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.Equal(t, "/v1/agent/nodes/node-1/nonce", r.URL.Path) - require.Equal(t, "Bearer jwt-token", r.Header.Get("Authorization")) + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") _ = json.NewEncoder(w).Encode(NonceResponse{ Nonce: "abc123", JWTAssertion: "new-jwt", @@ -95,15 +119,23 @@ func TestClient_GetNonce(t *testing.T) { require.NoError(t, err) require.Equal(t, "abc123", resp.Nonce) require.Equal(t, "new-jwt", resp.JWTAssertion) + require.Equal(t, http.MethodPost, gotMethod) + require.Equal(t, "/v1/agent/nodes/node-1/nonce", gotPath) + require.Equal(t, "Bearer jwt-token", gotAuth) } func TestClient_SubmitAttestation(t *testing.T) { t.Parallel() + var ( + gotMethod string + gotPath string + gotAuth string + ) server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.Equal(t, "/v1/agent/nodes/node-1/attestation", r.URL.Path) - require.Equal(t, "Bearer jwt-token", r.Header.Get("Authorization")) + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") w.WriteHeader(http.StatusOK) })) defer server.Close() @@ -111,20 +143,25 @@ func TestClient_SubmitAttestation(t *testing.T) { c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) err := c.SubmitAttestation(context.Background(), "node-1", &AttestationRequest{}, "jwt-token") require.NoError(t, err) + require.Equal(t, http.MethodPost, gotMethod) + require.Equal(t, "/v1/agent/nodes/node-1/attestation", gotPath) + require.Equal(t, "Bearer jwt-token", gotAuth) } func TestClient_RefreshToken(t *testing.T) { t.Parallel() - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.Equal(t, "/v1/agent/token", r.URL.Path) - - var req struct { + var ( + gotMethod string + gotPath string + gotReq struct { JWTAssertion string `json:"jwtAssertion"` } - require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) - require.Equal(t, "jwt-token", req.JWTAssertion) + ) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + _ = json.NewDecoder(r.Body).Decode(&gotReq) _ = json.NewEncoder(w).Encode(map[string]string{"jwtAssertion": "new-jwt-token"}) })) @@ -134,6 +171,9 @@ func TestClient_RefreshToken(t *testing.T) { jwt, err := c.RefreshToken(context.Background(), "jwt-token") require.NoError(t, err) require.Equal(t, "new-jwt-token", jwt) + require.Equal(t, http.MethodPost, gotMethod) + require.Equal(t, "/v1/agent/token", gotPath) + require.Equal(t, "jwt-token", gotReq.JWTAssertion) } func TestClient_EnrollMapsHTTPStatus(t *testing.T) { @@ -179,6 +219,31 @@ func TestClient_ValidationErrors(t *testing.T) { _, err = c.RefreshToken(context.Background(), "") require.ErrorContains(t, err, "jwt cannot be empty") + + c = NewWithHTTPClient(nil, nil) + _, err = c.Enroll(context.Background(), "sak-token") + require.ErrorIs(t, err, errNilBaseURL) +} + +func TestClient_RejectsRedirects(t *testing.T) { + t.Parallel() + + redirected := false + target := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + redirected = true + w.WriteHeader(http.StatusOK) + })) + defer target.Close() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, target.URL, http.StatusTemporaryRedirect) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.Enroll(context.Background(), "sak-token") + require.ErrorContains(t, err, errRedirectNotAllowed.Error()) + require.False(t, redirected) } func TestClient_ResponseValidationAndErrors(t *testing.T) { @@ -253,6 +318,32 @@ func TestClient_ResponseValidationAndErrors(t *testing.T) { }) } +func TestClient_HandlerAssertionsDoNotRace(t *testing.T) { + t.Parallel() + + var ( + mu sync.Mutex + method string + ) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + method = r.Method + mu.Unlock() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"jwtAssertion": "jwt-token"}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.Enroll(context.Background(), "sak-token") + require.NoError(t, err) + + mu.Lock() + gotMethod := method + mu.Unlock() + assert.Equal(t, http.MethodPost, gotMethod) +} + func TestMapEnrollErrorStatuses(t *testing.T) { t.Parallel() diff --git a/internal/enrollment/enrollment.go b/internal/enrollment/enrollment.go index 13e25895..820dd8be 100644 --- a/internal/enrollment/enrollment.go +++ b/internal/enrollment/enrollment.go @@ -84,18 +84,15 @@ func storeConfigInMetadata(ctx context.Context, baseURL, jwtToken, sakToken stri return fmt.Errorf("failed to create metadata table: %w", err) } - if err := pkgmetadata.SetMetadata(ctx, dbRW, "sak_token", sakToken); err != nil { + if err := pkgmetadata.SetMetadata(ctx, dbRW, agentstate.MetadataKeySAKToken, sakToken); err != nil { return fmt.Errorf("failed to set SAK token: %w", err) } if err := pkgmetadata.SetMetadata(ctx, dbRW, pkgmetadata.MetadataKeyToken, jwtToken); err != nil { return fmt.Errorf("failed to set JWT token: %w", err) } - if err := pkgmetadata.SetMetadata(ctx, dbRW, "backend_base_url", baseURL); err != nil { + if err := pkgmetadata.SetMetadata(ctx, dbRW, agentstate.MetadataKeyBackendBaseURL, baseURL); err != nil { return fmt.Errorf("failed to set backend base URL: %w", err) } - if err := config.SecureStateFilePermissions(stateFile); err != nil { - return fmt.Errorf("failed to secure state database permissions: %w", err) - } return nil } diff --git a/internal/inventory/manager.go b/internal/inventory/manager.go index b7fa1491..8d4dc38f 100644 --- a/internal/inventory/manager.go +++ b/internal/inventory/manager.go @@ -34,10 +34,11 @@ type Manager interface { } type manager struct { - mu sync.RWMutex - source Source - sink Sink - config InventoryConfig + mu sync.RWMutex + exportMu sync.Mutex + source Source + sink Sink + config InventoryConfig lastSnapshot *Snapshot lastExportedHash string @@ -97,19 +98,26 @@ func (m *manager) CollectOnce(ctx context.Context) (*Snapshot, error) { m.mu.Lock() cloned := *snap m.lastSnapshot = &cloned - shouldExport := m.sink != nil && m.lastExportedHash != hash m.mu.Unlock() - if shouldExport { - if err := m.sink.Export(ctx, snap); err != nil { - if errors.Is(err, ErrNotReady) { - return snap, nil + if m.sink != nil { + m.exportMu.Lock() + defer m.exportMu.Unlock() + + m.mu.RLock() + alreadyExported := m.lastExportedHash == hash + m.mu.RUnlock() + if !alreadyExported { + if err := m.sink.Export(ctx, snap); err != nil { + if errors.Is(err, ErrNotReady) { + return snap, nil + } + return nil, err } - return nil, err + m.mu.Lock() + m.lastExportedHash = hash + m.mu.Unlock() } - m.mu.Lock() - m.lastExportedHash = hash - m.mu.Unlock() } return snap, nil diff --git a/internal/inventory/manager_run_test.go b/internal/inventory/manager_run_test.go index 3f9440ae..6eca8215 100644 --- a/internal/inventory/manager_run_test.go +++ b/internal/inventory/manager_run_test.go @@ -58,7 +58,7 @@ func TestManagerRunStopsOnContextCancel(t *testing.T) { src := &fakeSource{ snapshots: []*Snapshot{{MachineID: "machine-1", Hostname: "host-a"}}, } - sink := &fakeSink{} + sink := &fakeSink{ready: make(chan struct{}, 1)} ctx, cancel := context.WithCancel(context.Background()) done := make(chan error, 1) @@ -66,7 +66,11 @@ func TestManagerRunStopsOnContextCancel(t *testing.T) { done <- NewManager(src, sink, InventoryConfig{Interval: 10 * time.Millisecond}).Run(ctx) }() - time.Sleep(25 * time.Millisecond) + select { + case <-sink.ready: + case <-time.After(250 * time.Millisecond): + t.Fatal("timed out waiting for inventory export") + } cancel() err := <-done diff --git a/internal/inventory/manager_test.go b/internal/inventory/manager_test.go index d72b8b2d..e4b1b830 100644 --- a/internal/inventory/manager_test.go +++ b/internal/inventory/manager_test.go @@ -17,6 +17,7 @@ package inventory import ( "context" + "sync" "testing" "time" @@ -24,11 +25,14 @@ import ( ) type fakeSource struct { + mu sync.Mutex snapshots []*Snapshot index int } func (f *fakeSource) Collect(context.Context) (*Snapshot, error) { + f.mu.Lock() + defer f.mu.Unlock() if len(f.snapshots) == 0 { return nil, nil } @@ -42,12 +46,22 @@ func (f *fakeSource) Collect(context.Context) (*Snapshot, error) { } type fakeSink struct { + mu sync.Mutex exported []*Snapshot + ready chan struct{} } func (f *fakeSink) Export(_ context.Context, snap *Snapshot) error { + f.mu.Lock() + defer f.mu.Unlock() cloned := *snap f.exported = append(f.exported, &cloned) + if f.ready != nil { + select { + case f.ready <- struct{}{}: + default: + } + } return nil } @@ -98,3 +112,38 @@ func TestManagerCollectOnceExportsOnlyWhenInventoryChanges(t *testing.T) { require.NotEqual(t, snap1.InventoryHash, snap3.InventoryHash) require.Len(t, sink.exported, 2) } + +func TestManagerCollectOnceConcurrentExportSingleHash(t *testing.T) { + src := &fakeSource{ + snapshots: []*Snapshot{{ + CollectedAt: time.Unix(100, 0).UTC(), + Hostname: "host-a", + MachineID: "machine-id", + Resources: Resources{ + CPUInfo: CPUInfo{Type: "Xeon", LogicalCores: 64}, + }, + }}, + } + sink := &fakeSink{} + mgr := NewManager(src, sink, InventoryConfig{}) + + var ( + wg sync.WaitGroup + errs = make(chan error, 8) + ) + for range 8 { + wg.Add(1) + go func() { + defer wg.Done() + _, err := mgr.CollectOnce(context.Background()) + errs <- err + }() + } + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(t, err) + } + require.Len(t, sink.exported, 1) +} diff --git a/internal/inventory/mapper/backend_test.go b/internal/inventory/mapper/backend_test.go index 93f27021..59b540eb 100644 --- a/internal/inventory/mapper/backend_test.go +++ b/internal/inventory/mapper/backend_test.go @@ -110,10 +110,14 @@ func TestToNodeUpsertRequest(t *testing.T) { require.Equal(t, []string{"disk"}, req.AgentConfig.DisabledComponents) require.Equal(t, int64(64), req.Resources.CPUInfo.LogicalCores) require.Equal(t, uint64(1024), req.Resources.MemoryInfo.TotalBytes) + require.Equal(t, "H100", req.Resources.GPUInfo.Product) + require.Equal(t, "NVIDIA", req.Resources.GPUInfo.Manufacturer) require.Len(t, req.Resources.GPUInfo.GPUs, 1) require.Equal(t, 7, req.Resources.GPUInfo.GPUs[0].BoardID) + require.Equal(t, "/dev/nvme0n1", req.Resources.DiskInfo.ContainerRootDisk) require.Len(t, req.Resources.DiskInfo.BlockDevices, 1) require.Equal(t, "parent0", req.Resources.DiskInfo.BlockDevices[0].Parents[0]) require.Len(t, req.Resources.NICInfo.PrivateIPInterfaces, 1) require.Equal(t, "eth0", req.Resources.NICInfo.PrivateIPInterfaces[0].Interface) + require.Equal(t, "10.0.0.10", req.Resources.NICInfo.PrivateIPInterfaces[0].IP) } diff --git a/internal/inventory/source/source_test.go b/internal/inventory/source/source_test.go index b8d59496..3596b3dd 100644 --- a/internal/inventory/source/source_test.go +++ b/internal/inventory/source/source_test.go @@ -155,3 +155,14 @@ func TestMachineInfoSourceCollectIgnoresSystemUUIDForMachineID(t *testing.T) { require.NotNil(t, snap) require.Equal(t, "machine-id", snap.MachineID) } + +func TestMachineInfoSourceCollectErrors(t *testing.T) { + _, err := NewMachineInfoSource(nil).Collect(context.Background()) + require.ErrorContains(t, err, "collector is required") + + _, err = NewMachineInfoSource(fakeMachineInfoCollector{err: context.DeadlineExceeded}).Collect(context.Background()) + require.ErrorIs(t, err, context.DeadlineExceeded) + + _, err = NewMachineInfoSource(fakeMachineInfoCollector{}).Collect(context.Background()) + require.ErrorContains(t, err, "nil machine info") +} From 466c374217634e4d46deebcaba44c042e9cb7542 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Tue, 21 Apr 2026 15:00:28 -0700 Subject: [PATCH 14/22] refactor: rename agent state node uuid accessors Signed-off-by: Jingxiang Zhang --- internal/agentstate/sqlite.go | 4 ++-- internal/agentstate/sqlite_test.go | 6 +++--- internal/agentstate/state.go | 4 ++-- internal/attestation/backend.go | 2 +- internal/attestation/backend_test.go | 4 ++-- internal/attestation/manager.go | 2 +- internal/inventory/sink/backend.go | 2 +- internal/inventory/sink/backend_test.go | 4 ++-- 8 files changed, 14 insertions(+), 14 deletions(-) diff --git a/internal/agentstate/sqlite.go b/internal/agentstate/sqlite.go index f8d694ff..285b0644 100644 --- a/internal/agentstate/sqlite.go +++ b/internal/agentstate/sqlite.go @@ -100,11 +100,11 @@ func (s *sqliteState) SetSAK(ctx context.Context, value string) error { return s.setMetadata(ctx, MetadataKeySAKToken, value) } -func (s *sqliteState) GetNodeID(ctx context.Context) (string, bool, error) { +func (s *sqliteState) GetNodeUUID(ctx context.Context) (string, bool, error) { return s.getMetadata(ctx, pkgmetadata.MetadataKeyMachineID) } -func (s *sqliteState) SetNodeID(ctx context.Context, value string) error { +func (s *sqliteState) SetNodeUUID(ctx context.Context, value string) error { return s.setMetadata(ctx, pkgmetadata.MetadataKeyMachineID, value) } diff --git a/internal/agentstate/sqlite_test.go b/internal/agentstate/sqlite_test.go index fede66d0..175b7ad6 100644 --- a/internal/agentstate/sqlite_test.go +++ b/internal/agentstate/sqlite_test.go @@ -48,7 +48,7 @@ func TestSQLiteStateRoundTrip(t *testing.T) { require.NoError(t, err) err = state.SetSAK(ctx, "sak-token") require.NoError(t, err) - err = state.SetNodeID(ctx, "node-1") + err = state.SetNodeUUID(ctx, "node-1") require.NoError(t, err) value, ok, err := state.GetBackendBaseURL(ctx) @@ -66,7 +66,7 @@ func TestSQLiteStateRoundTrip(t *testing.T) { require.True(t, ok) require.Equal(t, "sak-token", value) - value, ok, err = state.GetNodeID(ctx) + value, ok, err = state.GetNodeUUID(ctx) require.NoError(t, err) require.True(t, ok) require.Equal(t, "node-1", value) @@ -105,7 +105,7 @@ func TestSQLiteStateMissingMetadataTableIsTreatedAsAbsent(t *testing.T) { for _, get := range []func(context.Context) (string, bool, error){ state.GetJWT, state.GetSAK, - state.GetNodeID, + state.GetNodeUUID, } { value, ok, err := get(ctx) require.NoError(t, err) diff --git a/internal/agentstate/state.go b/internal/agentstate/state.go index e28a9dfd..2429c76e 100644 --- a/internal/agentstate/state.go +++ b/internal/agentstate/state.go @@ -34,6 +34,6 @@ type State interface { GetSAK(ctx context.Context) (value string, ok bool, err error) SetSAK(ctx context.Context, value string) error - GetNodeID(ctx context.Context) (value string, ok bool, err error) - SetNodeID(ctx context.Context, value string) error + GetNodeUUID(ctx context.Context) (value string, ok bool, err error) + SetNodeUUID(ctx context.Context, value string) error } diff --git a/internal/attestation/backend.go b/internal/attestation/backend.go index 3dfd92f9..2eae8880 100644 --- a/internal/attestation/backend.go +++ b/internal/attestation/backend.go @@ -109,7 +109,7 @@ func (s *stateSubmitter) Submit(ctx context.Context, result *Result, jwt string) return fmt.Errorf("attestation submission requires result") } if result.NodeID == "" { - nodeID, ok, err := s.factory.state.GetNodeID(ctx) + nodeID, ok, err := s.factory.state.GetNodeUUID(ctx) if err != nil { return err } diff --git a/internal/attestation/backend_test.go b/internal/attestation/backend_test.go index 38e40054..7ecfd2e9 100644 --- a/internal/attestation/backend_test.go +++ b/internal/attestation/backend_test.go @@ -47,10 +47,10 @@ func (s *stubState) GetJWT(context.Context) (string, bool, error) { return s. func (s *stubState) SetJWT(_ context.Context, v string) error { s.setJWT = v; s.jwt = v; return nil } func (s *stubState) GetSAK(context.Context) (string, bool, error) { return "", false, nil } func (s *stubState) SetSAK(context.Context, string) error { return nil } -func (s *stubState) GetNodeID(context.Context) (string, bool, error) { +func (s *stubState) GetNodeUUID(context.Context) (string, bool, error) { return s.nodeID, s.nodeOK, s.nodeErr } -func (s *stubState) SetNodeID(context.Context, string) error { return nil } +func (s *stubState) SetNodeUUID(context.Context, string) error { return nil } type recordingClient struct { lastNodeID string diff --git a/internal/attestation/manager.go b/internal/attestation/manager.go index 4497e425..0c695f94 100644 --- a/internal/attestation/manager.go +++ b/internal/attestation/manager.go @@ -311,7 +311,7 @@ func NewStateNodeIDProvider(state agentstate.State) func(context.Context) (strin if state == nil { return "", fmt.Errorf("node ID provider requires agent state") } - value, ok, err := state.GetNodeID(ctx) + value, ok, err := state.GetNodeUUID(ctx) if err != nil { return "", err } diff --git a/internal/inventory/sink/backend.go b/internal/inventory/sink/backend.go index 59f05591..a9b1dc1d 100644 --- a/internal/inventory/sink/backend.go +++ b/internal/inventory/sink/backend.go @@ -63,7 +63,7 @@ func (s *backendSink) Export(ctx context.Context, snap *inventory.Snapshot) erro if !ok || jwt == "" { return inventory.ErrNotReady } - nodeUUID, ok, err := s.state.GetNodeID(ctx) + nodeUUID, ok, err := s.state.GetNodeUUID(ctx) if err != nil { return err } diff --git a/internal/inventory/sink/backend_test.go b/internal/inventory/sink/backend_test.go index ad23fee2..1eefbc89 100644 --- a/internal/inventory/sink/backend_test.go +++ b/internal/inventory/sink/backend_test.go @@ -49,13 +49,13 @@ func (f fakeState) GetJWT(context.Context) (string, bool, error) { func (f fakeState) SetJWT(context.Context, string) error { return nil } func (f fakeState) GetSAK(context.Context) (string, bool, error) { return "", false, nil } func (f fakeState) SetSAK(context.Context, string) error { return nil } -func (f fakeState) GetNodeID(context.Context) (string, bool, error) { +func (f fakeState) GetNodeUUID(context.Context) (string, bool, error) { if f.err != nil { return "", false, f.err } return f.nodeID, f.nodeID != "", nil } -func (f fakeState) SetNodeID(context.Context, string) error { return nil } +func (f fakeState) SetNodeUUID(context.Context, string) error { return nil } type fakeClient struct { nodeID string From 8e7d86e87727e4f12610168383ca132a9431eb87 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Tue, 21 Apr 2026 15:13:52 -0700 Subject: [PATCH 15/22] refactor: rename node id to node uuid Signed-off-by: Jingxiang Zhang --- internal/attestation/backend.go | 14 +++---- internal/attestation/backend_test.go | 50 ++++++++++++------------- internal/attestation/collector_test.go | 2 +- internal/attestation/manager.go | 50 ++++++++++++------------- internal/attestation/manager_test.go | 12 +++--- internal/attestation/nonce.go | 10 ++--- internal/attestation/nonce_test.go | 12 +++--- internal/attestation/types.go | 4 +- internal/backendclient/client.go | 30 +++++++-------- internal/backendclient/client_test.go | 6 +-- internal/inventory/sink/backend_test.go | 30 +++++++-------- internal/server/server.go | 2 +- 12 files changed, 111 insertions(+), 111 deletions(-) diff --git a/internal/attestation/backend.go b/internal/attestation/backend.go index 2eae8880..d9c52811 100644 --- a/internal/attestation/backend.go +++ b/internal/attestation/backend.go @@ -87,12 +87,12 @@ func NewStateNonceProvider(state agentstate.State) NonceProvider { return &stateNonceProvider{factory: &stateBackendClientFactory{state: state}} } -func (p *stateNonceProvider) GetNonce(ctx context.Context, nodeID, jwt string) (string, time.Time, string, error) { +func (p *stateNonceProvider) GetNonce(ctx context.Context, nodeUUID, jwt string) (string, time.Time, string, error) { client, err := p.factory.client(ctx) if err != nil { return "", time.Time{}, "", err } - return NewBackendNonceProvider(client).GetNonce(ctx, nodeID, jwt) + return NewBackendNonceProvider(client).GetNonce(ctx, nodeUUID, jwt) } type stateSubmitter struct { @@ -108,16 +108,16 @@ func (s *stateSubmitter) Submit(ctx context.Context, result *Result, jwt string) if result == nil { return fmt.Errorf("attestation submission requires result") } - if result.NodeID == "" { - nodeID, ok, err := s.factory.state.GetNodeUUID(ctx) + if result.NodeUUID == "" { + nodeUUID, ok, err := s.factory.state.GetNodeUUID(ctx) if err != nil { return err } - if !ok || nodeID == "" { - return fmt.Errorf("%w: node ID not available in agent state", ErrNotEnrolled) + if !ok || nodeUUID == "" { + return fmt.Errorf("%w: node UUID not available in agent state", ErrNotEnrolled) } cloned := *result - cloned.NodeID = nodeID + cloned.NodeUUID = nodeUUID result = &cloned } client, err := s.factory.client(ctx) diff --git a/internal/attestation/backend_test.go b/internal/attestation/backend_test.go index 7ecfd2e9..ed43ae4d 100644 --- a/internal/attestation/backend_test.go +++ b/internal/attestation/backend_test.go @@ -27,16 +27,16 @@ import ( ) type stubState struct { - baseURL string - baseOK bool - baseErr error - jwt string - jwtOK bool - jwtErr error - setJWT string - nodeID string - nodeOK bool - nodeErr error + baseURL string + baseOK bool + baseErr error + jwt string + jwtOK bool + jwtErr error + setJWT string + nodeUUID string + nodeOK bool + nodeErr error } func (s *stubState) GetBackendBaseURL(context.Context) (string, bool, error) { @@ -48,15 +48,15 @@ func (s *stubState) SetJWT(_ context.Context, v string) error { s.setJWT func (s *stubState) GetSAK(context.Context) (string, bool, error) { return "", false, nil } func (s *stubState) SetSAK(context.Context, string) error { return nil } func (s *stubState) GetNodeUUID(context.Context) (string, bool, error) { - return s.nodeID, s.nodeOK, s.nodeErr + return s.nodeUUID, s.nodeOK, s.nodeErr } func (s *stubState) SetNodeUUID(context.Context, string) error { return nil } type recordingClient struct { - lastNodeID string - lastJWT string - lastReq *backendclient.AttestationRequest - nonceResp *backendclient.NonceResponse + lastNodeUUID string + lastJWT string + lastReq *backendclient.AttestationRequest + nonceResp *backendclient.NonceResponse } func (c *recordingClient) Enroll(context.Context, string) (string, error) { return "", nil } @@ -66,8 +66,8 @@ func (c *recordingClient) UpsertNode(context.Context, string, *backendclient.Nod func (c *recordingClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { return c.nonceResp, nil } -func (c *recordingClient) SubmitAttestation(_ context.Context, nodeID string, req *backendclient.AttestationRequest, jwt string) error { - c.lastNodeID = nodeID +func (c *recordingClient) SubmitAttestation(_ context.Context, nodeUUID string, req *backendclient.AttestationRequest, jwt string) error { + c.lastNodeUUID = nodeUUID c.lastJWT = jwt c.lastReq = req return nil @@ -146,7 +146,7 @@ func TestStateProvidersAndSubmitter(t *testing.T) { state := &stubState{ baseURL: "https://backend.example.com", baseOK: true, jwt: "jwt-token", jwtOK: true, - nodeID: "node-1", nodeOK: true, + nodeUUID: "node-1", nodeOK: true, } jwtProvider := NewStateJWTProvider(state) @@ -156,9 +156,9 @@ func TestStateProvidersAndSubmitter(t *testing.T) { require.NoError(t, jwtProvider.SetJWT(context.Background(), "updated")) require.Equal(t, "updated", state.setJWT) - nodeID, err := NewStateNodeIDProvider(state)(context.Background()) + nodeUUID, err := NewStateNodeUUIDProvider(state)(context.Background()) require.NoError(t, err) - require.Equal(t, "node-1", nodeID) + require.Equal(t, "node-1", nodeUUID) nonce, ts, refreshedJWT, err := NewStateNonceProvider(state).GetNonce(context.Background(), "node-1", "jwt-token") require.NoError(t, err) @@ -167,7 +167,7 @@ func TestStateProvidersAndSubmitter(t *testing.T) { require.Equal(t, "new-jwt", refreshedJWT) result := &Result{ - NodeID: "node-1", + NodeUUID: "node-1", SDKResponse: SDKResponse{ ResultCode: 1, Evidences: []EvidenceItem{{Arch: "BLACKWELL"}}, @@ -175,15 +175,15 @@ func TestStateProvidersAndSubmitter(t *testing.T) { } err = NewStateBackendSubmitter(state).Submit(context.Background(), result, "jwt-token") require.NoError(t, err) - require.Equal(t, "node-1", recording.lastNodeID) + require.Equal(t, "node-1", recording.lastNodeUUID) require.Equal(t, "jwt-token", recording.lastJWT) require.NotNil(t, recording.lastReq) require.Equal(t, "BLACKWELL", recording.lastReq.AttestationData.SDKResponse.Evidences[0].Arch) - recording.lastNodeID = "" + recording.lastNodeUUID = "" err = NewStateBackendSubmitter(state).Submit(context.Background(), &Result{}, "jwt-token") require.NoError(t, err) - require.Equal(t, "node-1", recording.lastNodeID) + require.Equal(t, "node-1", recording.lastNodeUUID) } func TestStateProvidersPropagateBackendClientConstructionErrors(t *testing.T) { @@ -198,6 +198,6 @@ func TestStateProvidersPropagateBackendClientConstructionErrors(t *testing.T) { _, _, _, err := NewStateNonceProvider(state).GetNonce(context.Background(), "node-1", "jwt-token") require.ErrorContains(t, err, "construct failed") - err = NewStateBackendSubmitter(state).Submit(context.Background(), &Result{NodeID: "node-1"}, "jwt-token") + err = NewStateBackendSubmitter(state).Submit(context.Background(), &Result{NodeUUID: "node-1"}, "jwt-token") require.ErrorContains(t, err, "construct failed") } diff --git a/internal/attestation/collector_test.go b/internal/attestation/collector_test.go index 171578a5..93423924 100644 --- a/internal/attestation/collector_test.go +++ b/internal/attestation/collector_test.go @@ -87,7 +87,7 @@ func TestBackendNonceProviderErrors(t *testing.T) { client := &testNonceClient{} _, _, _, err = NewBackendNonceProvider(client).GetNonce(context.Background(), "", "jwt") - require.ErrorContains(t, err, "node ID") + require.ErrorContains(t, err, "node UUID") _, _, _, err = NewBackendNonceProvider(client).GetNonce(context.Background(), "node", "") require.ErrorContains(t, err, "jwt") diff --git a/internal/attestation/manager.go b/internal/attestation/manager.go index 0c695f94..000e2092 100644 --- a/internal/attestation/manager.go +++ b/internal/attestation/manager.go @@ -45,13 +45,13 @@ type Manager interface { } type manager struct { - mu sync.RWMutex - nodeIDProvider func(context.Context) (string, error) - jwtProvider JWTProvider - nonceProvider NonceProvider - collector EvidenceCollector - submitter Submitter - config AttestationConfig + mu sync.RWMutex + nodeUUIDProvider func(context.Context) (string, error) + jwtProvider JWTProvider + nonceProvider NonceProvider + collector EvidenceCollector + submitter Submitter + config AttestationConfig lastResult *Result lastUpdated time.Time @@ -59,7 +59,7 @@ type manager struct { // NewManager creates an attestation loop manager skeleton. func NewManager( - nodeIDProvider func(context.Context) (string, error), + nodeUUIDProvider func(context.Context) (string, error), jwtProvider JWTProvider, nonceProvider NonceProvider, collector EvidenceCollector, @@ -67,17 +67,17 @@ func NewManager( cfg AttestationConfig, ) Manager { return &manager{ - nodeIDProvider: nodeIDProvider, - jwtProvider: jwtProvider, - nonceProvider: nonceProvider, - collector: collector, - submitter: submitter, - config: cfg, + nodeUUIDProvider: nodeUUIDProvider, + jwtProvider: jwtProvider, + nonceProvider: nonceProvider, + collector: collector, + submitter: submitter, + config: cfg, } } func (m *manager) Run(ctx context.Context) error { - if m.nodeIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil || m.submitter == nil { + if m.nodeUUIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil || m.submitter == nil { return fmt.Errorf("attestation loop dependencies are incomplete") } if m.config.Interval <= 0 { @@ -120,11 +120,11 @@ func (m *manager) Run(ctx context.Context) error { } func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { - if m.nodeIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil || m.submitter == nil { + if m.nodeUUIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil || m.submitter == nil { return nil, fmt.Errorf("attestation loop dependencies are incomplete") } - nodeID, err := m.nodeIDProvider(ctx) + nodeUUID, err := m.nodeUUIDProvider(ctx) if err != nil { return nil, err } @@ -132,7 +132,7 @@ func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { if err != nil { return nil, err } - nonce, refreshTS, refreshedJWT, err := m.nonceProvider.GetNonce(ctx, nodeID, jwt) + nonce, refreshTS, refreshedJWT, err := m.nonceProvider.GetNonce(ctx, nodeUUID, jwt) if err != nil { return nil, err } @@ -146,7 +146,7 @@ func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { collectErr := err result := &Result{ CollectedAt: time.Now().UTC(), - NodeID: nodeID, + NodeUUID: nodeUUID, NonceRefreshTimestamp: refreshTS, } if err != nil { @@ -254,7 +254,7 @@ type backendSubmitter struct { // BackendClient is the backend client view required by the attestation workflow. type BackendClient interface { - SubmitAttestation(ctx context.Context, nodeID string, req *backendclient.AttestationRequest, jwt string) error + SubmitAttestation(ctx context.Context, nodeUUID string, req *backendclient.AttestationRequest, jwt string) error } // NewBackendSubmitter creates a backend submitter backed by the agent backend client. @@ -272,7 +272,7 @@ func (s *backendSubmitter) Submit(ctx context.Context, result *Result, jwt strin if jwt == "" { return fmt.Errorf("attestation submission requires jwt") } - return s.client.SubmitAttestation(ctx, result.NodeID, toAttestationRequest(result), jwt) + return s.client.SubmitAttestation(ctx, result.NodeUUID, toAttestationRequest(result), jwt) } type stateJWTProvider struct { @@ -305,18 +305,18 @@ func (p *stateJWTProvider) SetJWT(ctx context.Context, value string) error { return p.state.SetJWT(ctx, value) } -// NewStateNodeIDProvider returns a node ID provider backed by persisted agent state. -func NewStateNodeIDProvider(state agentstate.State) func(context.Context) (string, error) { +// NewStateNodeUUIDProvider returns a node UUID provider backed by persisted agent state. +func NewStateNodeUUIDProvider(state agentstate.State) func(context.Context) (string, error) { return func(ctx context.Context) (string, error) { if state == nil { - return "", fmt.Errorf("node ID provider requires agent state") + return "", fmt.Errorf("node UUID provider requires agent state") } value, ok, err := state.GetNodeUUID(ctx) if err != nil { return "", err } if !ok || value == "" { - return "", fmt.Errorf("%w: node ID not available in agent state", ErrNotEnrolled) + return "", fmt.Errorf("%w: node UUID not available in agent state", ErrNotEnrolled) } return value, nil } diff --git a/internal/attestation/manager_test.go b/internal/attestation/manager_test.go index ab0dfa3f..fb7b5d12 100644 --- a/internal/attestation/manager_test.go +++ b/internal/attestation/manager_test.go @@ -91,7 +91,7 @@ func TestCollectOnceSuccess(t *testing.T) { result, err := manager.CollectOnce(context.Background()) require.NoError(t, err) require.True(t, result.Success) - require.Equal(t, "node-1", result.NodeID) + require.Equal(t, "node-1", result.NodeUUID) require.Equal(t, refreshTS, result.NonceRefreshTimestamp) require.Equal(t, "new-jwt", jwtProvider.setJWT) require.NotNil(t, submitter.submitted.result) @@ -140,7 +140,7 @@ func TestManagerRunAndCachedResult(t *testing.T) { last := mgr.LastResult() require.NotNil(t, last) - require.Equal(t, "node-1", last.NodeID) + require.Equal(t, "node-1", last.NodeUUID) require.True(t, mgr.IsResultUpdated(time.Time{})) } @@ -188,7 +188,7 @@ func TestManagerHelpersAndSubmitterErrors(t *testing.T) { require.ErrorContains(t, err, "requires jwt") } -func TestStateJWTProviderAndNodeIDProviderErrors(t *testing.T) { +func TestStateJWTProviderAndNodeUUIDProviderErrors(t *testing.T) { _, err := NewStateJWTProvider(nil).GetJWT(context.Background()) require.ErrorContains(t, err, "requires agent state") err = NewStateJWTProvider(nil).SetJWT(context.Background(), "x") @@ -197,10 +197,10 @@ func TestStateJWTProviderAndNodeIDProviderErrors(t *testing.T) { _, err = NewStateJWTProvider(&stubState{}).GetJWT(context.Background()) require.ErrorContains(t, err, "jwt not available") - _, err = NewStateNodeIDProvider(nil)(context.Background()) + _, err = NewStateNodeUUIDProvider(nil)(context.Background()) require.ErrorContains(t, err, "requires agent state") - _, err = NewStateNodeIDProvider(&stubState{})(context.Background()) - require.ErrorContains(t, err, "node ID not available") + _, err = NewStateNodeUUIDProvider(&stubState{})(context.Background()) + require.ErrorContains(t, err, "node UUID not available") } func TestSleepWithContext(t *testing.T) { diff --git a/internal/attestation/nonce.go b/internal/attestation/nonce.go index 975fda4a..074e7120 100644 --- a/internal/attestation/nonce.go +++ b/internal/attestation/nonce.go @@ -25,7 +25,7 @@ import ( // NonceBackendClient is the backend client view required by the nonce provider. type NonceBackendClient interface { - GetNonce(ctx context.Context, nodeID string, jwt string) (*backendclient.NonceResponse, error) + GetNonce(ctx context.Context, nodeUUID string, jwt string) (*backendclient.NonceResponse, error) } type backendNonceProvider struct { @@ -37,18 +37,18 @@ func NewBackendNonceProvider(client NonceBackendClient) NonceProvider { return &backendNonceProvider{client: client} } -func (p *backendNonceProvider) GetNonce(ctx context.Context, nodeID, jwt string) (string, time.Time, string, error) { +func (p *backendNonceProvider) GetNonce(ctx context.Context, nodeUUID, jwt string) (string, time.Time, string, error) { if p.client == nil { return "", time.Time{}, "", fmt.Errorf("nonce provider requires backend client") } - if nodeID == "" { - return "", time.Time{}, "", fmt.Errorf("nonce provider requires node ID") + if nodeUUID == "" { + return "", time.Time{}, "", fmt.Errorf("nonce provider requires node UUID") } if jwt == "" { return "", time.Time{}, "", fmt.Errorf("nonce provider requires jwt") } - resp, err := p.client.GetNonce(ctx, nodeID, jwt) + resp, err := p.client.GetNonce(ctx, nodeUUID, jwt) if err != nil { return "", time.Time{}, "", err } diff --git a/internal/attestation/nonce_test.go b/internal/attestation/nonce_test.go index 77273f48..2cd9cc09 100644 --- a/internal/attestation/nonce_test.go +++ b/internal/attestation/nonce_test.go @@ -26,13 +26,13 @@ import ( ) type testNonceClient struct { - resp *backendclient.NonceResponse - gotNodeID string - gotJWT string + resp *backendclient.NonceResponse + gotNodeUUID string + gotJWT string } -func (c *testNonceClient) GetNonce(_ context.Context, nodeID, jwt string) (*backendclient.NonceResponse, error) { - c.gotNodeID = nodeID +func (c *testNonceClient) GetNonce(_ context.Context, nodeUUID, jwt string) (*backendclient.NonceResponse, error) { + c.gotNodeUUID = nodeUUID c.gotJWT = jwt return c.resp, nil } @@ -53,6 +53,6 @@ func TestBackendNonceProvider(t *testing.T) { require.Equal(t, "abc123", nonce) require.Equal(t, refreshTS, ts) require.Equal(t, "new-jwt", jwt) - require.Equal(t, "node-1", client.gotNodeID) + require.Equal(t, "node-1", client.gotNodeUUID) require.Equal(t, "jwt-token", client.gotJWT) } diff --git a/internal/attestation/types.go b/internal/attestation/types.go index bc224763..eebd3b28 100644 --- a/internal/attestation/types.go +++ b/internal/attestation/types.go @@ -28,7 +28,7 @@ var ErrNotEnrolled = errors.New("agent not enrolled") // Result is the agent-owned attestation state model for the new backend sync loop. type Result struct { CollectedAt time.Time - NodeID string + NodeUUID string NonceRefreshTimestamp time.Time Success bool ErrorMessage string @@ -53,7 +53,7 @@ type EvidenceItem struct { // NonceProvider retrieves a backend nonce for a node. type NonceProvider interface { - GetNonce(ctx context.Context, nodeID, jwt string) (nonce string, refreshTS time.Time, refreshedJWT string, err error) + GetNonce(ctx context.Context, nodeUUID, jwt string) (nonce string, refreshTS time.Time, refreshedJWT string, err error) } // EvidenceCollector collects attestation evidence from local tooling. diff --git a/internal/backendclient/client.go b/internal/backendclient/client.go index 573d49ec..31958e2b 100644 --- a/internal/backendclient/client.go +++ b/internal/backendclient/client.go @@ -41,9 +41,9 @@ var errNilBaseURL = errors.New("backend base URL is required") // Client is the backend workflow client used by enrollment, inventory, and attestation paths. type Client interface { Enroll(ctx context.Context, sakToken string) (jwt string, err error) - UpsertNode(ctx context.Context, nodeID string, req *NodeUpsertRequest, jwt string) error - GetNonce(ctx context.Context, nodeID string, jwt string) (*NonceResponse, error) - SubmitAttestation(ctx context.Context, nodeID string, req *AttestationRequest, jwt string) error + UpsertNode(ctx context.Context, nodeUUID string, req *NodeUpsertRequest, jwt string) error + GetNonce(ctx context.Context, nodeUUID string, jwt string) (*NonceResponse, error) + SubmitAttestation(ctx context.Context, nodeUUID string, req *AttestationRequest, jwt string) error RefreshToken(ctx context.Context, jwt string) (newJWT string, err error) } @@ -95,9 +95,9 @@ func (c *client) Enroll(ctx context.Context, sakToken string) (string, error) { return resp.JWTAssertion, nil } -func (c *client) UpsertNode(ctx context.Context, nodeID string, req *NodeUpsertRequest, jwt string) error { - if nodeID == "" { - return fmt.Errorf("nodeID cannot be empty") +func (c *client) UpsertNode(ctx context.Context, nodeUUID string, req *NodeUpsertRequest, jwt string) error { + if nodeUUID == "" { + return fmt.Errorf("nodeUUID cannot be empty") } if jwt == "" { return fmt.Errorf("jwt cannot be empty") @@ -105,19 +105,19 @@ func (c *client) UpsertNode(ctx context.Context, nodeID string, req *NodeUpsertR if req == nil { return fmt.Errorf("node upsert request cannot be nil") } - return c.doJSON(ctx, http.MethodPut, []string{"v1", "agent", "nodes", nodeID}, jwt, req, nil) + return c.doJSON(ctx, http.MethodPut, []string{"v1", "agent", "nodes", nodeUUID}, jwt, req, nil) } -func (c *client) GetNonce(ctx context.Context, nodeID string, jwt string) (*NonceResponse, error) { - if nodeID == "" { - return nil, fmt.Errorf("nodeID cannot be empty") +func (c *client) GetNonce(ctx context.Context, nodeUUID string, jwt string) (*NonceResponse, error) { + if nodeUUID == "" { + return nil, fmt.Errorf("nodeUUID cannot be empty") } if jwt == "" { return nil, fmt.Errorf("jwt cannot be empty") } var resp NonceResponse - if err := c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "nodes", nodeID, "nonce"}, jwt, nil, &resp); err != nil { + if err := c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "nodes", nodeUUID, "nonce"}, jwt, nil, &resp); err != nil { return nil, err } if resp.Nonce == "" { @@ -126,9 +126,9 @@ func (c *client) GetNonce(ctx context.Context, nodeID string, jwt string) (*Nonc return &resp, nil } -func (c *client) SubmitAttestation(ctx context.Context, nodeID string, req *AttestationRequest, jwt string) error { - if nodeID == "" { - return fmt.Errorf("nodeID cannot be empty") +func (c *client) SubmitAttestation(ctx context.Context, nodeUUID string, req *AttestationRequest, jwt string) error { + if nodeUUID == "" { + return fmt.Errorf("nodeUUID cannot be empty") } if jwt == "" { return fmt.Errorf("jwt cannot be empty") @@ -136,7 +136,7 @@ func (c *client) SubmitAttestation(ctx context.Context, nodeID string, req *Atte if req == nil { return fmt.Errorf("attestation request cannot be nil") } - return c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "nodes", nodeID, "attestation"}, jwt, req, nil) + return c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "nodes", nodeUUID, "attestation"}, jwt, req, nil) } func (c *client) RefreshToken(ctx context.Context, jwt string) (string, error) { diff --git a/internal/backendclient/client_test.go b/internal/backendclient/client_test.go index 86cded32..8d8c4670 100644 --- a/internal/backendclient/client_test.go +++ b/internal/backendclient/client_test.go @@ -199,19 +199,19 @@ func TestClient_ValidationErrors(t *testing.T) { require.ErrorContains(t, err, "sakToken cannot be empty") err = c.UpsertNode(context.Background(), "", &NodeUpsertRequest{}, "jwt") - require.ErrorContains(t, err, "nodeID cannot be empty") + require.ErrorContains(t, err, "nodeUUID cannot be empty") err = c.UpsertNode(context.Background(), "node-1", nil, "jwt") require.ErrorContains(t, err, "cannot be nil") err = c.UpsertNode(context.Background(), "node-1", &NodeUpsertRequest{}, "") require.ErrorContains(t, err, "jwt cannot be empty") _, err = c.GetNonce(context.Background(), "", "jwt") - require.ErrorContains(t, err, "nodeID cannot be empty") + require.ErrorContains(t, err, "nodeUUID cannot be empty") _, err = c.GetNonce(context.Background(), "node-1", "") require.ErrorContains(t, err, "jwt cannot be empty") err = c.SubmitAttestation(context.Background(), "", &AttestationRequest{}, "jwt") - require.ErrorContains(t, err, "nodeID cannot be empty") + require.ErrorContains(t, err, "nodeUUID cannot be empty") err = c.SubmitAttestation(context.Background(), "node-1", nil, "jwt") require.ErrorContains(t, err, "cannot be nil") err = c.SubmitAttestation(context.Background(), "node-1", &AttestationRequest{}, "") diff --git a/internal/inventory/sink/backend_test.go b/internal/inventory/sink/backend_test.go index 1eefbc89..f144b57f 100644 --- a/internal/inventory/sink/backend_test.go +++ b/internal/inventory/sink/backend_test.go @@ -27,10 +27,10 @@ import ( ) type fakeState struct { - baseURL string - jwt string - nodeID string - err error + baseURL string + jwt string + nodeUUID string + err error } func (f fakeState) GetBackendBaseURL(context.Context) (string, bool, error) { @@ -53,14 +53,14 @@ func (f fakeState) GetNodeUUID(context.Context) (string, bool, error) { if f.err != nil { return "", false, f.err } - return f.nodeID, f.nodeID != "", nil + return f.nodeUUID, f.nodeUUID != "", nil } func (f fakeState) SetNodeUUID(context.Context, string) error { return nil } type fakeClient struct { - nodeID string - req *backendclient.NodeUpsertRequest - jwt string + nodeUUID string + req *backendclient.NodeUpsertRequest + jwt string } func (f *fakeClient) Enroll(context.Context, string) (string, error) { return "", nil } @@ -71,8 +71,8 @@ func (f *fakeClient) SubmitAttestation(context.Context, string, *backendclient.A return nil } func (f *fakeClient) RefreshToken(context.Context, string) (string, error) { return "", nil } -func (f *fakeClient) UpsertNode(_ context.Context, nodeID string, req *backendclient.NodeUpsertRequest, jwt string) error { - f.nodeID = nodeID +func (f *fakeClient) UpsertNode(_ context.Context, nodeUUID string, req *backendclient.NodeUpsertRequest, jwt string) error { + f.nodeUUID = nodeUUID f.req = req f.jwt = jwt return nil @@ -108,7 +108,7 @@ func TestBackendSinkExportErrors(t *testing.T) { require.ErrorContains(t, err, "inventory snapshot") err = (&backendSink{ - state: fakeState{baseURL: "https://example.com", jwt: "jwt", nodeID: "node-1"}, + state: fakeState{baseURL: "https://example.com", jwt: "jwt", nodeUUID: "node-1"}, clientFactory: func(string) (backendclient.Client, error) { return nil, errors.New("client factory error") }, @@ -120,9 +120,9 @@ func TestBackendSinkExportUsesState(t *testing.T) { client := &fakeClient{} s := &backendSink{ state: fakeState{ - baseURL: "https://example.com", - jwt: "jwt-token", - nodeID: "node-1", + baseURL: "https://example.com", + jwt: "jwt-token", + nodeUUID: "node-1", }, clientFactory: func(string) (backendclient.Client, error) { return client, nil @@ -134,7 +134,7 @@ func TestBackendSinkExportUsesState(t *testing.T) { MachineID: "machine-id", }) require.NoError(t, err) - require.Equal(t, "node-1", client.nodeID) + require.Equal(t, "node-1", client.nodeUUID) require.Equal(t, "jwt-token", client.jwt) require.NotNil(t, client.req) require.Equal(t, "host-a", client.req.Hostname) diff --git a/internal/server/server.go b/internal/server/server.go index f2432e5a..c1177786 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -460,7 +460,7 @@ func (s *Server) startAttestationLoop(ctx context.Context, cfg *config.Config) { state := agentstate.NewSQLite() manager := attestation.NewManager( - attestation.NewStateNodeIDProvider(state), + attestation.NewStateNodeUUIDProvider(state), attestation.NewStateJWTProvider(state), attestation.NewStateNonceProvider(state), attestation.NewCLIEvidenceCollector(getAttestationTimeout(cfg)), From 142a8394eaefb0f8053e97ac3950553fe3387665 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Tue, 21 Apr 2026 16:06:14 -0700 Subject: [PATCH 16/22] fix: normalize backend enroll endpoint Signed-off-by: Jingxiang Zhang --- internal/enrollment/enrollment.go | 19 ++++++++++- internal/enrollment/enrollment_test.go | 47 ++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/internal/enrollment/enrollment.go b/internal/enrollment/enrollment.go index 820dd8be..4cf067b7 100644 --- a/internal/enrollment/enrollment.go +++ b/internal/enrollment/enrollment.go @@ -19,6 +19,7 @@ package enrollment import ( "context" "fmt" + "net/url" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" @@ -43,7 +44,7 @@ var ( // Enroll runs the full enrollment workflow and performs a best-effort initial inventory sync. func Enroll(ctx context.Context, baseEndpoint, sakToken string) error { - baseURL, err := endpoint.ValidateBackendEndpoint(baseEndpoint) + baseURL, err := normalizeBackendBaseURL(baseEndpoint) if err != nil { return fmt.Errorf("invalid enrollment endpoint: %w", err) } @@ -65,6 +66,22 @@ func Enroll(ctx context.Context, baseEndpoint, sakToken string) error { return nil } +func normalizeBackendBaseURL(raw string) (*url.URL, error) { + baseURL, err := endpoint.ValidateBackendEndpoint(raw) + if err != nil { + return nil, err + } + if baseURL.Path == "" || baseURL.Path == "/" { + return baseURL, nil + } + + normalized, err := endpoint.DeriveBackendBaseURL(raw) + if err != nil { + return nil, err + } + return endpoint.ValidateBackendEndpoint(normalized) +} + func storeConfigInMetadata(ctx context.Context, baseURL, jwtToken, sakToken string) error { stateFile, err := config.DefaultStateFile() if err != nil { diff --git a/internal/enrollment/enrollment_test.go b/internal/enrollment/enrollment_test.go index 53467e74..027a7413 100644 --- a/internal/enrollment/enrollment_test.go +++ b/internal/enrollment/enrollment_test.go @@ -84,6 +84,29 @@ func TestEnrollWorkflow(t *testing.T) { require.True(t, syncCalled) } +func TestEnrollWorkflowNormalizesLegacyEndpointToBaseURL(t *testing.T) { + originalFactory := newBackendClient + originalSync := syncInventoryAfterEnroll + t.Cleanup(func() { + newBackendClient = originalFactory + syncInventoryAfterEnroll = originalSync + }) + + client := &fakeBackendClient{enrollJWT: "jwt-token"} + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + require.Equal(t, "https://example.com", rawBaseURL) + return client, nil + } + syncInventoryAfterEnroll = func(context.Context) error { return nil } + + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + err := Enroll(context.Background(), "https://example.com/api/v1/health/metrics", "sak-token") + require.NoError(t, err) + require.Equal(t, "sak-token", client.enrollSAK) +} + func TestEnrollWorkflowErrors(t *testing.T) { t.Run("invalid endpoint", func(t *testing.T) { err := Enroll(context.Background(), "http://example.com", "sak-token") @@ -130,6 +153,30 @@ func TestEnrollWorkflowErrors(t *testing.T) { err := Enroll(context.Background(), "https://example.com", "sak-token") require.ErrorContains(t, err, "enroll boom") }) + + t.Run("localhost legacy endpoint allowed", func(t *testing.T) { + originalFactory := newBackendClient + originalSync := syncInventoryAfterEnroll + t.Cleanup(func() { + newBackendClient = originalFactory + syncInventoryAfterEnroll = originalSync + }) + + called := false + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + called = true + require.Equal(t, "http://localhost:8080", rawBaseURL) + return &fakeBackendClient{enrollJWT: "jwt-token"}, nil + } + syncInventoryAfterEnroll = func(context.Context) error { return nil } + + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + err := Enroll(context.Background(), "http://localhost:8080/api/v1/health/enroll", "sak-token") + require.NoError(t, err) + require.True(t, called) + }) } func TestEnrollWorkflowInventorySyncFailureIsNonFatal(t *testing.T) { From 5ce40b0183edca58345a22e6dc3c4c80dd094365 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Tue, 21 Apr 2026 16:11:48 -0700 Subject: [PATCH 17/22] fix: delay inventory loop after initial collect Signed-off-by: Jingxiang Zhang --- internal/inventory/manager.go | 26 +++++++++++------- internal/inventory/manager_run_test.go | 37 ++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/internal/inventory/manager.go b/internal/inventory/manager.go index 8d4dc38f..a720375b 100644 --- a/internal/inventory/manager.go +++ b/internal/inventory/manager.go @@ -54,28 +54,36 @@ func NewManager(source Source, sink Sink, cfg InventoryConfig) Manager { } func (m *manager) Run(ctx context.Context) error { - if _, err := m.CollectOnce(ctx); err != nil { + _, err := m.CollectOnce(ctx) + if err != nil { log.Logger.Warnw("initial inventory collection failed", "error", err) } if m.config.Interval <= 0 { return nil } + nextInterval := m.nextInterval(err) if m.config.JitterEnabled { - if err := sleepWithContext(ctx, calculateJitter(initialJitterCap(m.config.Interval))); err != nil { - return err - } + nextInterval += calculateJitter(initialJitterCap(nextInterval)) } for { - _, err := m.CollectOnce(ctx) - nextInterval := m.config.Interval - if err != nil && m.config.RetryInterval > 0 && m.config.RetryInterval < nextInterval { - nextInterval = m.config.RetryInterval + calculateJitter(retryJitterCap(m.config.RetryInterval)) - } if err := sleepWithContext(ctx, nextInterval); err != nil { return err } + _, err = m.CollectOnce(ctx) + nextInterval = m.nextInterval(err) + } +} + +func (m *manager) nextInterval(err error) time.Duration { + nextInterval := m.config.Interval + if err != nil && m.config.RetryInterval > 0 && m.config.RetryInterval < nextInterval { + nextInterval = m.config.RetryInterval + if m.config.JitterEnabled { + nextInterval += calculateJitter(retryJitterCap(m.config.RetryInterval)) + } } + return nextInterval } func (m *manager) CollectOnce(ctx context.Context) (*Snapshot, error) { diff --git a/internal/inventory/manager_run_test.go b/internal/inventory/manager_run_test.go index 6eca8215..ca21b8c1 100644 --- a/internal/inventory/manager_run_test.go +++ b/internal/inventory/manager_run_test.go @@ -32,6 +32,17 @@ type nilSnapshotSource struct{} func (nilSnapshotSource) Collect(context.Context) (*Snapshot, error) { return nil, nil } +type countingSource struct { + collectCh chan struct{} +} + +func (s *countingSource) Collect(context.Context) (*Snapshot, error) { + if s.collectCh != nil { + s.collectCh <- struct{}{} + } + return &Snapshot{MachineID: "machine-1", Hostname: "host-a"}, nil +} + func TestManagerCollectOnceErrors(t *testing.T) { _, err := NewManager(nil, nil, InventoryConfig{}).CollectOnce(context.Background()) require.ErrorContains(t, err, "inventory source is required") @@ -121,3 +132,29 @@ func TestManagerRunUsesRetryIntervalWithoutJitter(t *testing.T) { require.GreaterOrEqual(t, elapsed, 15*time.Millisecond) require.Less(t, elapsed, 100*time.Millisecond) } + +func TestManagerRunWaitsIntervalBeforeSecondCollect(t *testing.T) { + src := &countingSource{collectCh: make(chan struct{}, 4)} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- NewManager(src, nil, InventoryConfig{Interval: 50 * time.Millisecond}).Run(ctx) + }() + + select { + case <-src.collectCh: + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for initial collection") + } + + select { + case <-src.collectCh: + t.Fatal("second collection happened before interval elapsed") + case <-time.After(20 * time.Millisecond): + } + + cancel() + require.ErrorIs(t, <-done, context.Canceled) +} From 31a8997618e6b16424bf38fff51bcafa2a70dfd0 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Tue, 21 Apr 2026 16:17:42 -0700 Subject: [PATCH 18/22] fix: tolerate invalid exporter backend metadata Signed-off-by: Jingxiang Zhang --- internal/exporter/exporter.go | 15 +++- internal/exporter/exporter_test.go | 125 +++++++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 3 deletions(-) diff --git a/internal/exporter/exporter.go b/internal/exporter/exporter.go index 2eff6cfc..dd8301ad 100644 --- a/internal/exporter/exporter.go +++ b/internal/exporter/exporter.go @@ -265,16 +265,20 @@ func (e *healthExporter) refreshConfigFromMetadata(ctx context.Context) { validated, validateErr := endpoint.ValidateBackendEndpoint(baseURL) if validateErr != nil { log.Logger.Errorw("ignoring invalid backend base URL from metadata", "error", validateErr) + metricsEndpoint = e.readValidatedEndpoint(ctx, "metrics_endpoint") + logsEndpoint = e.readValidatedEndpoint(ctx, "logs_endpoint") } else { if joined, joinErr := endpoint.JoinPath(validated, "api", "v1", "health", "metrics"); joinErr == nil { metricsEndpoint = joined } else { log.Logger.Errorw("failed to derive metrics endpoint from backend base URL", "error", joinErr) + metricsEndpoint = e.readValidatedEndpoint(ctx, "metrics_endpoint") } if joined, joinErr := endpoint.JoinPath(validated, "api", "v1", "health", "logs"); joinErr == nil { logsEndpoint = joined } else { log.Logger.Errorw("failed to derive logs endpoint from backend base URL", "error", joinErr) + logsEndpoint = e.readValidatedEndpoint(ctx, "logs_endpoint") } } } else { @@ -348,8 +352,12 @@ func (e *healthExporter) refreshJWTToken(ctx context.Context) (string, error) { baseURL, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "backend_base_url") if err == nil && baseURL != "" { - // use configured base URL - } else { + if _, validateErr := endpoint.ValidateBackendEndpoint(baseURL); validateErr != nil { + log.Logger.Errorw("ignoring invalid backend base URL for JWT refresh", "backend_base_url", baseURL, "error", validateErr) + baseURL = "" + } + } + if err != nil || baseURL == "" { baseURL, err = e.readLegacyBackendBaseURL(ctx) if err != nil { return "", err @@ -401,7 +409,8 @@ func (e *healthExporter) readLegacyBackendBaseURL(ctx context.Context) (string, } baseURL, err := endpoint.DeriveBackendBaseURL(value) if err != nil { - return "", fmt.Errorf("invalid legacy %s for JWT refresh: %w", key, err) + log.Logger.Errorw("ignoring invalid legacy backend endpoint for JWT refresh", "key", key, "value", value, "error", err) + continue } return baseURL, nil } diff --git a/internal/exporter/exporter_test.go b/internal/exporter/exporter_test.go index 9ea2362c..2a60edba 100644 --- a/internal/exporter/exporter_test.go +++ b/internal/exporter/exporter_test.go @@ -814,6 +814,41 @@ func TestRefreshConfigFromMetadata(t *testing.T) { err = exporter.Stop() require.NoError(t, err) }) + + t.Run("falls back to legacy endpoints when backend base URL is invalid", func(t *testing.T) { + tmpDB := setupTestDB(t) + defer tmpDB.Close() + + ctx := context.Background() + + err := pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "http://bad-backend.example.com") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "metrics_endpoint", "https://legacy.example.com/api/v1/health/metrics") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "logs_endpoint", "https://legacy.example.com/api/v1/health/logs") + require.NoError(t, err) + + cfg := &config.HealthExporterConfig{ + Interval: metav1.Duration{Duration: 1 * time.Minute}, + Timeout: metav1.Duration{Duration: 30 * time.Second}, + } + + exporter, err := New(ctx, + WithConfig(cfg), + WithDatabaseConnections(tmpDB, tmpDB), + WithMachineID("test-machine-id"), + ) + require.NoError(t, err) + + he := exporter.(*healthExporter) + he.refreshConfigFromMetadata(ctx) + + assert.Equal(t, "https://legacy.example.com/api/v1/health/metrics", he.options.config.MetricsEndpoint) + assert.Equal(t, "https://legacy.example.com/api/v1/health/logs", he.options.config.LogsEndpoint) + + err = exporter.Stop() + require.NoError(t, err) + }) } // TestUpdateTokenInMetadata tests the updateTokenInMetadata function @@ -1018,6 +1053,96 @@ func TestRefreshJWTToken(t *testing.T) { err = exporter.Stop() require.NoError(t, err) }) + + t.Run("refreshes token using legacy endpoint when backend base URL is invalid", func(t *testing.T) { + expectedToken := "new-jwt-token" + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + assert.Equal(t, "https://backend.example.com", rawBaseURL) + return &fakeJWTRefreshClient{ + expectedSAK: "test-sak-token", + token: expectedToken, + }, nil + } + + tmpDB := setupTestDB(t) + defer tmpDB.Close() + + ctx := context.Background() + err := pkgmetadata.SetMetadata(ctx, tmpDB, "sak_token", "test-sak-token") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "http://bad-backend.example.com") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "enroll_endpoint", "https://backend.example.com/api/v1/health/enroll") + require.NoError(t, err) + + cfg := &config.HealthExporterConfig{ + Interval: metav1.Duration{Duration: 1 * time.Minute}, + Timeout: metav1.Duration{Duration: 30 * time.Second}, + } + + exporter, err := New(ctx, + WithConfig(cfg), + WithDatabaseConnections(tmpDB, tmpDB), + WithMachineID("test-machine-id"), + ) + require.NoError(t, err) + + he := exporter.(*healthExporter) + + token, err := he.refreshJWTToken(ctx) + require.NoError(t, err) + assert.Equal(t, expectedToken, token) + + err = exporter.Stop() + require.NoError(t, err) + }) + + t.Run("refreshes token using later legacy endpoint when earlier one is malformed", func(t *testing.T) { + expectedToken := "new-jwt-token" + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + assert.Equal(t, "https://backend.example.com", rawBaseURL) + return &fakeJWTRefreshClient{ + expectedSAK: "test-sak-token", + token: expectedToken, + }, nil + } + + tmpDB := setupTestDB(t) + defer tmpDB.Close() + + ctx := context.Background() + err := pkgmetadata.SetMetadata(ctx, tmpDB, "sak_token", "test-sak-token") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "enroll_endpoint", "https://backend.example.com?bad=query") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "metrics_endpoint", "https://backend.example.com/api/v1/health/metrics") + require.NoError(t, err) + + cfg := &config.HealthExporterConfig{ + Interval: metav1.Duration{Duration: 1 * time.Minute}, + Timeout: metav1.Duration{Duration: 30 * time.Second}, + } + + exporter, err := New(ctx, + WithConfig(cfg), + WithDatabaseConnections(tmpDB, tmpDB), + WithMachineID("test-machine-id"), + ) + require.NoError(t, err) + + he := exporter.(*healthExporter) + + token, err := he.refreshJWTToken(ctx) + require.NoError(t, err) + assert.Equal(t, expectedToken, token) + + err = exporter.Stop() + require.NoError(t, err) + }) } type fakeJWTRefreshClient struct { From 58f42fe9d2d64c94458db0429355f29b3e83aca6 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Wed, 22 Apr 2026 11:31:04 -0700 Subject: [PATCH 19/22] fix: send LogicalCores and TotalBytes as strings in inventory payload NodeResources on the backend read path uses string for these fields (matching the existing OTel/Kafka write path). Sending integers caused json.Unmarshal to fail silently, leaving the entire resources block empty in the node detail API response. Signed-off-by: Jingxiang Zhang --- internal/backendclient/types.go | 4 ++-- internal/inventory/mapper/backend.go | 6 ++++-- internal/inventory/mapper/backend_test.go | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/internal/backendclient/types.go b/internal/backendclient/types.go index 2ef8504f..8747bea7 100644 --- a/internal/backendclient/types.go +++ b/internal/backendclient/types.go @@ -57,11 +57,11 @@ type CPUInfo struct { Type string `json:"type"` Manufacturer string `json:"manufacturer"` Architecture string `json:"architecture"` - LogicalCores int64 `json:"logicalCores"` + LogicalCores string `json:"logicalCores"` } type MemoryInfo struct { - TotalBytes uint64 `json:"totalBytes"` + TotalBytes string `json:"totalBytes"` } type GPUInfo struct { diff --git a/internal/inventory/mapper/backend.go b/internal/inventory/mapper/backend.go index 60dd9d78..7ab224ad 100644 --- a/internal/inventory/mapper/backend.go +++ b/internal/inventory/mapper/backend.go @@ -17,6 +17,8 @@ package mapper import ( + "strconv" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" ) @@ -90,10 +92,10 @@ func ToNodeUpsertRequest(s *inventory.Snapshot) *backendclient.NodeUpsertRequest Type: s.Resources.CPUInfo.Type, Manufacturer: s.Resources.CPUInfo.Manufacturer, Architecture: s.Resources.CPUInfo.Architecture, - LogicalCores: s.Resources.CPUInfo.LogicalCores, + LogicalCores: strconv.FormatInt(s.Resources.CPUInfo.LogicalCores, 10), }, MemoryInfo: backendclient.MemoryInfo{ - TotalBytes: s.Resources.MemoryInfo.TotalBytes, + TotalBytes: strconv.FormatUint(s.Resources.MemoryInfo.TotalBytes, 10), }, GPUInfo: backendclient.GPUInfo{ Product: s.Resources.GPUInfo.Product, diff --git a/internal/inventory/mapper/backend_test.go b/internal/inventory/mapper/backend_test.go index 59b540eb..383fd5f9 100644 --- a/internal/inventory/mapper/backend_test.go +++ b/internal/inventory/mapper/backend_test.go @@ -108,8 +108,8 @@ func TestToNodeUpsertRequest(t *testing.T) { require.Equal(t, int64(86400), req.AgentConfig.RetentionPeriodSeconds) require.Equal(t, []string{"cpu", "gpu"}, req.AgentConfig.EnabledComponents) require.Equal(t, []string{"disk"}, req.AgentConfig.DisabledComponents) - require.Equal(t, int64(64), req.Resources.CPUInfo.LogicalCores) - require.Equal(t, uint64(1024), req.Resources.MemoryInfo.TotalBytes) + require.Equal(t, "64", req.Resources.CPUInfo.LogicalCores) + require.Equal(t, "1024", req.Resources.MemoryInfo.TotalBytes) require.Equal(t, "H100", req.Resources.GPUInfo.Product) require.Equal(t, "NVIDIA", req.Resources.GPUInfo.Manufacturer) require.Len(t, req.Resources.GPUInfo.GPUs, 1) From b895ee4b1b9c0eae7f61b180173cbc030d8b243f Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Wed, 22 Apr 2026 11:50:34 -0700 Subject: [PATCH 20/22] refactor: drop InventoryHash from node upsert wire type InventoryHash is agent-internal state used for change detection only. The backend never stores or acts on it, so there is no reason to send it. Signed-off-by: Jingxiang Zhang --- internal/backendclient/types.go | 1 - internal/inventory/mapper/backend.go | 1 - internal/inventory/mapper/backend_test.go | 2 -- 3 files changed, 4 deletions(-) diff --git a/internal/backendclient/types.go b/internal/backendclient/types.go index 8747bea7..7fe771d1 100644 --- a/internal/backendclient/types.go +++ b/internal/backendclient/types.go @@ -35,7 +35,6 @@ type NodeUpsertRequest struct { BootID string `json:"bootID"` NetPrivateIP string `json:"netPrivateIP,omitempty"` NetPublicIP string `json:"netPublicIP,omitempty"` - InventoryHash string `json:"inventoryHash,omitempty"` } type NodeResources struct { diff --git a/internal/inventory/mapper/backend.go b/internal/inventory/mapper/backend.go index 7ab224ad..864da9cb 100644 --- a/internal/inventory/mapper/backend.go +++ b/internal/inventory/mapper/backend.go @@ -86,7 +86,6 @@ func ToNodeUpsertRequest(s *inventory.Snapshot) *backendclient.NodeUpsertRequest ContainerRuntimeVersion: s.ContainerRuntimeVersion, NetPrivateIP: s.NetPrivateIP, NetPublicIP: s.NetPublicIP, - InventoryHash: s.InventoryHash, Resources: backendclient.NodeResources{ CPUInfo: backendclient.CPUInfo{ Type: s.Resources.CPUInfo.Type, diff --git a/internal/inventory/mapper/backend_test.go b/internal/inventory/mapper/backend_test.go index 383fd5f9..8388aac1 100644 --- a/internal/inventory/mapper/backend_test.go +++ b/internal/inventory/mapper/backend_test.go @@ -43,7 +43,6 @@ func TestToNodeUpsertRequest(t *testing.T) { ContainerRuntimeVersion: "containerd://1.7.13", NetPrivateIP: "10.0.0.10", NetPublicIP: "203.0.113.10", - InventoryHash: "hash-1", AgentConfig: inventory.AgentConfig{ TotalComponents: 30, RetentionPeriodSeconds: 86400, @@ -103,7 +102,6 @@ func TestToNodeUpsertRequest(t *testing.T) { require.Equal(t, "host-a", req.Hostname) require.Equal(t, "machine-id", req.MachineID) require.Equal(t, "203.0.113.10", req.NetPublicIP) - require.Equal(t, "hash-1", req.InventoryHash) require.Equal(t, int64(30), req.AgentConfig.TotalComponents) require.Equal(t, int64(86400), req.AgentConfig.RetentionPeriodSeconds) require.Equal(t, []string{"cpu", "gpu"}, req.AgentConfig.EnabledComponents) From 26cf57e2be50011eb9d6f57fbafc245acb76bf97 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Wed, 22 Apr 2026 12:22:04 -0700 Subject: [PATCH 21/22] refactor: drop NetPublicIP from inventory wire type The agent never collected a public IP, so this field was always empty. Remove it from the snapshot type and the backend wire contract. Signed-off-by: Jingxiang Zhang --- internal/backendclient/types.go | 1 - internal/inventory/mapper/backend.go | 1 - internal/inventory/mapper/backend_test.go | 2 -- internal/inventory/types.go | 1 - 4 files changed, 5 deletions(-) diff --git a/internal/backendclient/types.go b/internal/backendclient/types.go index 7fe771d1..e17d4d65 100644 --- a/internal/backendclient/types.go +++ b/internal/backendclient/types.go @@ -34,7 +34,6 @@ type NodeUpsertRequest struct { MachineID string `json:"machineId"` BootID string `json:"bootID"` NetPrivateIP string `json:"netPrivateIP,omitempty"` - NetPublicIP string `json:"netPublicIP,omitempty"` } type NodeResources struct { diff --git a/internal/inventory/mapper/backend.go b/internal/inventory/mapper/backend.go index 864da9cb..e76bdac5 100644 --- a/internal/inventory/mapper/backend.go +++ b/internal/inventory/mapper/backend.go @@ -85,7 +85,6 @@ func ToNodeUpsertRequest(s *inventory.Snapshot) *backendclient.NodeUpsertRequest DCGMVersion: s.DCGMVersion, ContainerRuntimeVersion: s.ContainerRuntimeVersion, NetPrivateIP: s.NetPrivateIP, - NetPublicIP: s.NetPublicIP, Resources: backendclient.NodeResources{ CPUInfo: backendclient.CPUInfo{ Type: s.Resources.CPUInfo.Type, diff --git a/internal/inventory/mapper/backend_test.go b/internal/inventory/mapper/backend_test.go index 8388aac1..c870e370 100644 --- a/internal/inventory/mapper/backend_test.go +++ b/internal/inventory/mapper/backend_test.go @@ -42,7 +42,6 @@ func TestToNodeUpsertRequest(t *testing.T) { DCGMVersion: "4.2.3", ContainerRuntimeVersion: "containerd://1.7.13", NetPrivateIP: "10.0.0.10", - NetPublicIP: "203.0.113.10", AgentConfig: inventory.AgentConfig{ TotalComponents: 30, RetentionPeriodSeconds: 86400, @@ -101,7 +100,6 @@ func TestToNodeUpsertRequest(t *testing.T) { require.NotNil(t, req) require.Equal(t, "host-a", req.Hostname) require.Equal(t, "machine-id", req.MachineID) - require.Equal(t, "203.0.113.10", req.NetPublicIP) require.Equal(t, int64(30), req.AgentConfig.TotalComponents) require.Equal(t, int64(86400), req.AgentConfig.RetentionPeriodSeconds) require.Equal(t, []string{"cpu", "gpu"}, req.AgentConfig.EnabledComponents) diff --git a/internal/inventory/types.go b/internal/inventory/types.go index 3aac2d9f..58075d82 100644 --- a/internal/inventory/types.go +++ b/internal/inventory/types.go @@ -42,7 +42,6 @@ type Snapshot struct { DCGMVersion string ContainerRuntimeVersion string NetPrivateIP string - NetPublicIP string AgentConfig AgentConfig Resources Resources } From f9db3b9a3dd92bc22ad5b14a49383fb02810abd0 Mon Sep 17 00:00:00 2001 From: Jingxiang Zhang Date: Wed, 22 Apr 2026 14:13:21 -0700 Subject: [PATCH 22/22] feat: add X-Agent-Mode header on OTLP requests Signals to the backend receiver that inventory is written directly via the agent API, so legacy node extraction from OTLP resource attributes can be suppressed. Signed-off-by: Jingxiang Zhang --- internal/exporter/writer/http.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/exporter/writer/http.go b/internal/exporter/writer/http.go index dff82093..c5aa7730 100644 --- a/internal/exporter/writer/http.go +++ b/internal/exporter/writer/http.go @@ -228,6 +228,7 @@ func (w *httpWriter) sendOTLPRequest(ctx context.Context, reqData []byte, dataTy req.Header.Set("X-Machine-ID", machineID) req.Header.Set("X-Data-Type", dataType) req.Header.Set("X-Collection-ID", collectionID) + req.Header.Set("X-Agent-Mode", "direct-inventory-write") if authToken != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken))