diff --git a/charts/ome-resources/README.md b/charts/ome-resources/README.md index 60989a56..4826cb6c 100644 --- a/charts/ome-resources/README.md +++ b/charts/ome-resources/README.md @@ -12,6 +12,9 @@ OME Resources and Controller | global.hub | string | `"ghcr.io/moirai-internal"` | | | modelAgent.health.port | int | `8080` | | | modelAgent.hostPath | string | `"/mnt/data/models"` | | +| modelAgent.integrity.checkInterval | string | `"10m"` | | +| modelAgent.integrity.deepCheckInterval | string | `"6h"` | | +| modelAgent.integrity.startupJitter | string | `"30s"` | | | modelAgent.image.pullPolicy | string | `"Always"` | | | modelAgent.image.repository | string | `"model-agent"` | | | modelAgent.image.tag | string | `"v0.1.2"` | | diff --git a/charts/ome-resources/templates/model-agent-daemonset/daemonset.yaml b/charts/ome-resources/templates/model-agent-daemonset/daemonset.yaml index 4bcadcef..341cfede 100644 --- a/charts/ome-resources/templates/model-agent-daemonset/daemonset.yaml +++ b/charts/ome-resources/templates/model-agent-daemonset/daemonset.yaml @@ -58,6 +58,12 @@ spec: - {{ .Values.modelAgent.hostPath }} - --num-download-worker - '2' + - --integrity-check-interval + - {{ .Values.modelAgent.integrity.checkInterval | quote }} + - --integrity-deep-check-interval + - {{ .Values.modelAgent.integrity.deepCheckInterval | quote }} + - --integrity-startup-jitter + - {{ .Values.modelAgent.integrity.startupJitter | quote }} env: - name: NODE_NAME valueFrom: diff --git a/charts/ome-resources/values.yaml b/charts/ome-resources/values.yaml index 67ad1b08..0403428f 100644 --- a/charts/ome-resources/values.yaml +++ b/charts/ome-resources/values.yaml @@ -144,6 +144,11 @@ modelAgent: health: port: 8080 + integrity: + checkInterval: 10m + deepCheckInterval: 6h + startupJitter: 30s + # Additional environment variables for the model-agent container # Examples: # env: diff --git a/cmd/model-agent/main.go b/cmd/model-agent/main.go index 08909beb..c4d28203 100644 --- a/cmd/model-agent/main.go +++ b/cmd/model-agent/main.go @@ -31,18 +31,21 @@ import ( // config holds all configuration parameters for the model agent type config struct { - port int - modelsRootDir string - modelsRootDirOnHost string - nodeName string - nodeLabelRetry int - concurrency int - multipartConcurrency int - downloadRetry int - downloadAuthType string - numDownloadWorker int - namespace string - logLevel string + port int + modelsRootDir string + modelsRootDirOnHost string + nodeName string + nodeLabelRetry int + concurrency int + multipartConcurrency int + downloadRetry int + downloadAuthType string + numDownloadWorker int + namespace string + logLevel string + integrityCheckInterval time.Duration + integrityDeepInterval time.Duration + integrityStartupJitter time.Duration } // Logger type alias for zap.SugaredLogger @@ -73,6 +76,10 @@ func init() { rootCmd.PersistentFlags().IntVar(&cfg.numDownloadWorker, "num-download-worker", 5, "Number of download workers") rootCmd.PersistentFlags().StringVar(&cfg.namespace, "namespace", "ome", "Kubernetes namespace to use") rootCmd.PersistentFlags().StringVar(&cfg.logLevel, "log-level", "info", "Log level (debug, info, warn, error)") + defaultIntegrityConfig := modelagent.DefaultIntegrityConfig() + rootCmd.PersistentFlags().DurationVar(&cfg.integrityCheckInterval, "integrity-check-interval", defaultIntegrityConfig.CheckInterval, "Interval for periodic Ready model artifact integrity checks; set <=0 to disable") + rootCmd.PersistentFlags().DurationVar(&cfg.integrityDeepInterval, "integrity-deep-check-interval", defaultIntegrityConfig.DeepCheckInterval, "Interval for deep checksum validation; set <=0 to disable checksum scans") + rootCmd.PersistentFlags().DurationVar(&cfg.integrityStartupJitter, "integrity-startup-jitter", defaultIntegrityConfig.StartupJitter, "Maximum deterministic startup jitter before the first integrity check") _ = v.BindPFlags(rootCmd.PersistentFlags()) v.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) @@ -268,6 +275,14 @@ func initializeComponents( logger, baseModelInformer.Lister(), clusterBaseModelInformer.Lister(), + baseModelInformer.Informer().HasSynced, + clusterBaseModelInformer.Informer().HasSynced, + scout.NodeShapeAlias(), + modelagent.IntegrityConfig{ + CheckInterval: v.GetDuration("integrity-check-interval"), + DeepCheckInterval: v.GetDuration("integrity-deep-check-interval"), + StartupJitter: v.GetDuration("integrity-startup-jitter"), + }, ) if err != nil { return nil, nil, fmt.Errorf("failed to create gopher: %w", err) diff --git a/cmd/model-agent/main_test.go b/cmd/model-agent/main_test.go index db035df7..16735ede 100644 --- a/cmd/model-agent/main_test.go +++ b/cmd/model-agent/main_test.go @@ -14,6 +14,8 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zaptest" "k8s.io/client-go/rest" + + "github.com/sgl-project/ome/pkg/modelagent" ) func setupTestEnv(t *testing.T) { @@ -113,6 +115,10 @@ func TestDefaultConfig(t *testing.T) { testCmd.Flags().StringVar(&cfg.downloadAuthType, "download-auth-type", "instance-principal", "authentication method for model download") testCmd.Flags().IntVar(&cfg.numDownloadWorker, "num-download-worker", 3, "number of download workers") testCmd.Flags().StringVar(&cfg.namespace, "namespace", "ome", "the namespace of the ome model agents daemon set") + defaultIntegrityConfig := modelagent.DefaultIntegrityConfig() + testCmd.Flags().DurationVar(&cfg.integrityCheckInterval, "integrity-check-interval", defaultIntegrityConfig.CheckInterval, "Model artifact integrity check interval") + testCmd.Flags().DurationVar(&cfg.integrityDeepInterval, "integrity-deep-check-interval", defaultIntegrityConfig.DeepCheckInterval, "Model artifact deep integrity check interval") + testCmd.Flags().DurationVar(&cfg.integrityStartupJitter, "integrity-startup-jitter", defaultIntegrityConfig.StartupJitter, "Model artifact integrity startup jitter") // Call initConfig to set cfg.nodeName initConfig(nil, nil) @@ -127,6 +133,9 @@ func TestDefaultConfig(t *testing.T) { assert.Equal(t, "instance-principal", cfg.downloadAuthType) assert.Equal(t, 3, cfg.numDownloadWorker) assert.Equal(t, "ome", cfg.namespace) + assert.Equal(t, defaultIntegrityConfig.CheckInterval, cfg.integrityCheckInterval) + assert.Equal(t, defaultIntegrityConfig.DeepCheckInterval, cfg.integrityDeepInterval) + assert.Equal(t, defaultIntegrityConfig.StartupJitter, cfg.integrityStartupJitter) } func TestInitializeLogger(t *testing.T) { diff --git a/config/model-agent/daemonset.yaml b/config/model-agent/daemonset.yaml index 4900e56e..0f3e6eee 100644 --- a/config/model-agent/daemonset.yaml +++ b/config/model-agent/daemonset.yaml @@ -51,6 +51,12 @@ spec: - '2' - --concurrency - '2' + - --integrity-check-interval + - 10m + - --integrity-deep-check-interval + - 6h + - --integrity-startup-jitter + - 30s env: - name: NODE_NAME valueFrom: diff --git a/pkg/controller/v1beta1/basemodel/controller.go b/pkg/controller/v1beta1/basemodel/controller.go index 4c5ddd47..b591f336 100644 --- a/pkg/controller/v1beta1/basemodel/controller.go +++ b/pkg/controller/v1beta1/basemodel/controller.go @@ -250,6 +250,7 @@ func handleModelDeletion(ctx context.Context, kubeClient client.Client, obj clie // updateModelStatus updates BaseModel status based on ConfigMap data func (r *BaseModelReconciler) updateModelStatus(ctx context.Context, baseModel *v1beta1.BaseModel) error { return processModelStatus(ctx, r.Client, r.Log, baseModel.Namespace, baseModel.Name, false, + &baseModel.Spec, func(ctx context.Context, config *modelagent.ModelConfig) error { return r.updateModelSpecWithRetry(ctx, baseModel, config) }, @@ -261,6 +262,7 @@ func (r *BaseModelReconciler) updateModelStatus(ctx context.Context, baseModel * // updateModelStatus updates ClusterBaseModel status based on ConfigMap data func (r *ClusterBaseModelReconciler) updateModelStatus(ctx context.Context, clusterBaseModel *v1beta1.ClusterBaseModel) error { return processModelStatus(ctx, r.Client, r.Log, "", clusterBaseModel.Name, true, + &clusterBaseModel.Spec, func(ctx context.Context, config *modelagent.ModelConfig) error { return r.updateModelSpecWithRetry(ctx, clusterBaseModel, config) }, @@ -271,6 +273,7 @@ func (r *ClusterBaseModelReconciler) updateModelStatus(ctx context.Context, clus // processModelStatus is a shared utility function for processing ConfigMaps and updating model status func processModelStatus(ctx context.Context, kubeClient client.Client, log logr.Logger, namespace, name string, isClusterScope bool, + spec *v1beta1.BaseModelSpec, specUpdateFunc func(context.Context, *modelagent.ModelConfig) error, statusUpdateFunc func(context.Context, []string, []string) error) error { @@ -332,6 +335,17 @@ func processModelStatus(ctx context.Context, kubeClient client.Client, log logr. log.V(1).Info("Processing model entry", "node", configMap.Name, "status", modelEntry.Status, "hasConfig", modelEntry.Config != nil, "hasProgress", modelEntry.Progress != nil) + if !modelEntry.MatchesStorageIdentity(spec) { + currentStorageURI, currentStoragePath, _ := modelagent.StorageIdentityForSpec(spec) + log.Info("Skipping stale model status entry", + "node", configMap.Name, + "entryStorageUri", modelEntry.StorageURI, + "currentStorageUri", currentStorageURI, + "entryStoragePath", modelEntry.StoragePath, + "currentStoragePath", currentStoragePath) + continue + } + // Update model spec with config if available if modelEntry.Config != nil { if err := specUpdateFunc(ctx, modelEntry.Config); err != nil { diff --git a/pkg/controller/v1beta1/basemodel/controller_test.go b/pkg/controller/v1beta1/basemodel/controller_test.go index 835b7a6f..a31705fc 100644 --- a/pkg/controller/v1beta1/basemodel/controller_test.go +++ b/pkg/controller/v1beta1/basemodel/controller_test.go @@ -874,6 +874,232 @@ func TestMapConfigMapToModelRequests(t *testing.T) { } } +func TestProcessModelStatusSkipsStaleStorageIdentity(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scheme := runtime.NewScheme() + g.Expect(v1beta1.AddToScheme(scheme)).NotTo(gomega.HaveOccurred()) + g.Expect(corev1.AddToScheme(scheme)).NotTo(gomega.HaveOccurred()) + + storageURI := "hf://Qwen/Qwen2.5-0.5B" + oldPath := "/raid/models/storage-identity/qwen2-5-0-5B-good" + currentPath := "/raid/models/storage-identity/qwen2-5-0-5b-caseflip" + currentSpec := &v1beta1.BaseModelSpec{ + Storage: &v1beta1.StorageSpec{ + StorageUri: &storageURI, + Path: ¤tPath, + }, + } + oldSpec := &v1beta1.BaseModelSpec{ + Storage: &v1beta1.StorageSpec{ + StorageUri: &storageURI, + Path: &oldPath, + }, + } + + staleEntry := modelagent.ModelEntry{ + Name: "stale-model", + Status: modelagent.ModelStatusReady, + Config: &modelagent.ModelConfig{ + ModelType: "qwen2", + }, + } + staleEntry.ApplyStorageIdentity(oldSpec) + staleEntryData, err := json.Marshal(staleEntry) + g.Expect(err).NotTo(gomega.HaveOccurred()) + + c := ctrlclientfake.NewClientBuilder(). + WithScheme(scheme). + WithObjects( + &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: constants.OMENamespace}}, + &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "worker-node-1"}}, + &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "worker-node-1", + Namespace: constants.OMENamespace, + Labels: map[string]string{ + constants.ModelStatusConfigMapLabel: "true", + }, + }, + Data: map[string]string{ + "clusterbasemodel.stale-model": string(staleEntryData), + }, + }, + ). + Build() + + specUpdated := false + var nodesReady, nodesFailed []string + err = processModelStatus( + context.Background(), + c, + ctrl.Log.WithName("test"), + "", + "stale-model", + true, + currentSpec, + func(context.Context, *modelagent.ModelConfig) error { + specUpdated = true + return nil + }, + func(_ context.Context, ready, failed []string) error { + nodesReady = ready + nodesFailed = failed + return nil + }, + ) + g.Expect(err).NotTo(gomega.HaveOccurred()) + g.Expect(specUpdated).To(gomega.BeFalse()) + g.Expect(nodesReady).To(gomega.BeEmpty()) + g.Expect(nodesFailed).To(gomega.BeEmpty()) +} + +func TestProcessModelStatusSkipsStaleFailedStorageIdentity(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scheme := runtime.NewScheme() + g.Expect(v1beta1.AddToScheme(scheme)).NotTo(gomega.HaveOccurred()) + g.Expect(corev1.AddToScheme(scheme)).NotTo(gomega.HaveOccurred()) + + storageURI := "hf://Qwen/Qwen2.5-0.5B" + oldPath := "/raid/models/storage-identity/qwen2-5-0-5B-old" + currentPath := "/raid/models/storage-identity/qwen2-5-0-5b-current" + currentSpec := &v1beta1.BaseModelSpec{ + Storage: &v1beta1.StorageSpec{ + StorageUri: &storageURI, + Path: ¤tPath, + }, + } + oldSpec := &v1beta1.BaseModelSpec{ + Storage: &v1beta1.StorageSpec{ + StorageUri: &storageURI, + Path: &oldPath, + }, + } + + staleEntry := modelagent.ModelEntry{ + Name: "stale-model", + Status: modelagent.ModelStatusFailed, + } + staleEntry.ApplyStorageIdentity(oldSpec) + staleEntryData, err := json.Marshal(staleEntry) + g.Expect(err).NotTo(gomega.HaveOccurred()) + + c := ctrlclientfake.NewClientBuilder(). + WithScheme(scheme). + WithObjects( + &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: constants.OMENamespace}}, + &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "worker-node-1"}}, + &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "worker-node-1", + Namespace: constants.OMENamespace, + Labels: map[string]string{ + constants.ModelStatusConfigMapLabel: "true", + }, + }, + Data: map[string]string{ + "clusterbasemodel.stale-model": string(staleEntryData), + }, + }, + ). + Build() + + var nodesReady, nodesFailed []string + err = processModelStatus( + context.Background(), + c, + ctrl.Log.WithName("test"), + "", + "stale-model", + true, + currentSpec, + func(context.Context, *modelagent.ModelConfig) error { + return nil + }, + func(_ context.Context, ready, failed []string) error { + nodesReady = ready + nodesFailed = failed + return nil + }, + ) + g.Expect(err).NotTo(gomega.HaveOccurred()) + g.Expect(nodesReady).To(gomega.BeEmpty()) + g.Expect(nodesFailed).To(gomega.BeEmpty()) +} + +func TestProcessModelStatusAcceptsLegacyStorageIdentity(t *testing.T) { + g := gomega.NewGomegaWithT(t) + + scheme := runtime.NewScheme() + g.Expect(v1beta1.AddToScheme(scheme)).NotTo(gomega.HaveOccurred()) + g.Expect(corev1.AddToScheme(scheme)).NotTo(gomega.HaveOccurred()) + + storageURI := "hf://Qwen/Qwen2.5-0.5B" + storagePath := "/raid/models/storage-identity/qwen2-5-0-5b" + currentSpec := &v1beta1.BaseModelSpec{ + Storage: &v1beta1.StorageSpec{ + StorageUri: &storageURI, + Path: &storagePath, + }, + } + + legacyEntry := modelagent.ModelEntry{ + Name: "legacy-model", + Status: modelagent.ModelStatusReady, + Config: &modelagent.ModelConfig{ + ModelType: "qwen2", + }, + } + legacyEntryData, err := json.Marshal(legacyEntry) + g.Expect(err).NotTo(gomega.HaveOccurred()) + + c := ctrlclientfake.NewClientBuilder(). + WithScheme(scheme). + WithObjects( + &corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: constants.OMENamespace}}, + &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: "worker-node-1"}}, + &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "worker-node-1", + Namespace: constants.OMENamespace, + Labels: map[string]string{ + constants.ModelStatusConfigMapLabel: "true", + }, + }, + Data: map[string]string{ + "clusterbasemodel.legacy-model": string(legacyEntryData), + }, + }, + ). + Build() + + specUpdated := false + var nodesReady, nodesFailed []string + err = processModelStatus( + context.Background(), + c, + ctrl.Log.WithName("test"), + "", + "legacy-model", + true, + currentSpec, + func(context.Context, *modelagent.ModelConfig) error { + specUpdated = true + return nil + }, + func(_ context.Context, ready, failed []string) error { + nodesReady = ready + nodesFailed = failed + return nil + }, + ) + g.Expect(err).NotTo(gomega.HaveOccurred()) + g.Expect(specUpdated).To(gomega.BeTrue()) + g.Expect(nodesReady).To(gomega.Equal([]string{"worker-node-1"})) + g.Expect(nodesFailed).To(gomega.BeEmpty()) +} + func TestUpdateSpecWithConfig(t *testing.T) { g := gomega.NewGomegaWithT(t) diff --git a/pkg/modelagent/artifact_manifest.go b/pkg/modelagent/artifact_manifest.go new file mode 100644 index 00000000..701e2828 --- /dev/null +++ b/pkg/modelagent/artifact_manifest.go @@ -0,0 +1,286 @@ +package modelagent + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/sgl-project/ome/pkg/apis/ome/v1beta1" + "github.com/sgl-project/ome/pkg/utils/storage" +) + +const ( + artifactManifestVersion = 1 + artifactManifestFileExt = ".json" + artifactHashAlgorithm = "sha256" +) + +type artifactManifest struct { + Version int `json:"version"` + StorageURI string `json:"storageUri,omitempty"` + StoragePath string `json:"storagePath,omitempty"` + SourceType string `json:"sourceType,omitempty"` + ArtifactRoot string `json:"artifactRoot"` + CreatedAt string `json:"createdAt"` + Files []artifactFileEntry `json:"files"` +} + +type artifactFileEntry struct { + Path string `json:"path"` + Size int64 `json:"size"` + Hash string `json:"hash,omitempty"` + HashAlgorithm string `json:"hashAlgorithm,omitempty"` +} + +func validateArtifactManifest(ctx context.Context, spec *v1beta1.BaseModelSpec, modelRootDir, modelPath string, deep bool) integrityReport { + manifest, root, err := loadArtifactManifest(spec, modelRootDir, modelPath) + if err != nil { + if os.IsNotExist(err) { + return inconclusiveReport(integrityReasonManifestError, err.Error()) + } + return failureReport(integrityReasonManifestError, err.Error()) + } + if manifest.Version != artifactManifestVersion { + return failureReport(integrityReasonManifestError, fmt.Sprintf("unsupported artifact manifest version %d", manifest.Version)) + } + if len(manifest.Files) == 0 { + return failureReport(integrityReasonManifestError, "artifact manifest does not contain any files") + } + + var bytesScanned int64 + for _, file := range manifest.Files { + relPath, err := cleanManifestRelativePath(file.Path) + if err != nil { + return failureReport(integrityReasonManifestError, err.Error()) + } + localPath := filepath.Join(root, relPath) + info, err := os.Stat(localPath) + if err != nil { + if os.IsNotExist(err) { + return failureReport(integrityReasonMissingWeight, fmt.Sprintf("manifest file is missing: %s", file.Path)) + } + return failureReport(integrityReasonManifestError, err.Error()) + } + if !info.Mode().IsRegular() { + return failureReport(integrityReasonManifestError, fmt.Sprintf("manifest path is not a regular file: %s", file.Path)) + } + if info.Size() != file.Size { + return failureReport(integrityReasonSizeMismatch, + fmt.Sprintf("manifest file size mismatch for %s: expected %d got %d", file.Path, file.Size, info.Size())) + } + if deep && file.Hash != "" { + if file.HashAlgorithm != artifactHashAlgorithm { + return failureReport(integrityReasonManifestError, + fmt.Sprintf("unsupported manifest hash algorithm %s for %s", file.HashAlgorithm, file.Path)) + } + hash, scanned, err := hashFile(ctx, localPath) + bytesScanned += scanned + if err != nil { + return failureReport(integrityReasonManifestError, err.Error()) + } + if hash != file.Hash { + return failureReport(integrityReasonChecksumMismatch, + fmt.Sprintf("manifest checksum mismatch for %s", file.Path)) + } + } + } + return successReport(integrityReasonOK, bytesScanned) +} + +func createArtifactManifest(ctx context.Context, spec *v1beta1.BaseModelSpec, modelRootDir, modelPath string, storageType storage.StorageType) (integrityReport, error) { + root, err := resolveManifestRoot(modelPath) + if err != nil { + return integrityReport{}, err + } + + manifestPath, storageURI, storagePath, err := artifactManifestPath(spec, modelRootDir, modelPath) + if err != nil { + return integrityReport{}, err + } + manifest := artifactManifest{ + Version: artifactManifestVersion, + StorageURI: storageURI, + StoragePath: storagePath, + SourceType: string(storageType), + ArtifactRoot: root, + CreatedAt: time.Now().UTC().Format(time.RFC3339), + } + + var bytesScanned int64 + err = filepath.WalkDir(root, func(path string, entry os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if entry.IsDir() { + if entry.Name() == artifactManifestDir { + return filepath.SkipDir + } + return nil + } + + info, err := entry.Info() + if err != nil { + return err + } + if !info.Mode().IsRegular() { + return nil + } + + relPath, err := filepath.Rel(root, path) + if err != nil { + return err + } + hash, scanned, err := hashFile(ctx, path) + if err != nil { + return err + } + bytesScanned += scanned + manifest.Files = append(manifest.Files, artifactFileEntry{ + Path: filepath.ToSlash(relPath), + Size: info.Size(), + Hash: hash, + HashAlgorithm: artifactHashAlgorithm, + }) + return nil + }) + if err != nil { + return integrityReport{}, err + } + if len(manifest.Files) == 0 { + return integrityReport{}, fmt.Errorf("no regular files found under model artifact path %s", modelPath) + } + sort.Slice(manifest.Files, func(i, j int) bool { + return manifest.Files[i].Path < manifest.Files[j].Path + }) + + if err := writeArtifactManifest(manifestPath, manifest); err != nil { + return integrityReport{}, err + } + return successReport(integrityReasonBaselineCreated, bytesScanned), nil +} + +func loadArtifactManifest(spec *v1beta1.BaseModelSpec, modelRootDir, modelPath string) (*artifactManifest, string, error) { + root, err := resolveManifestRoot(modelPath) + if err != nil { + return nil, "", err + } + manifestPath, _, _, err := artifactManifestPath(spec, modelRootDir, modelPath) + if err != nil { + return nil, "", err + } + data, err := os.ReadFile(manifestPath) + if err != nil { + return nil, root, err + } + var manifest artifactManifest + if err := json.Unmarshal(data, &manifest); err != nil { + return nil, root, err + } + return &manifest, root, nil +} + +func artifactManifestPath(spec *v1beta1.BaseModelSpec, modelRootDir, modelPath string) (manifestPath, storageURI, storagePath string, err error) { + if modelRootDir == "" { + return "", "", "", fmt.Errorf("model root directory is empty") + } + storageURI, storagePath, ok := StorageIdentityForSpec(spec) + if !ok { + return "", "", "", fmt.Errorf("model storage identity is missing") + } + if storagePath == "" { + storagePath = modelPath + } + sum := sha256.Sum256([]byte(storageURI + "\n" + storagePath)) + filename := hex.EncodeToString(sum[:]) + artifactManifestFileExt + return filepath.Join(modelRootDir, artifactManifestDir, "integrity", filename), storageURI, storagePath, nil +} + +func writeArtifactManifest(manifestPath string, manifest artifactManifest) error { + manifestDir := filepath.Dir(manifestPath) + if err := os.MkdirAll(manifestDir, 0755); err != nil { + return err + } + + tmpFile, err := os.CreateTemp(manifestDir, "artifact-manifest-*.tmp") + if err != nil { + return err + } + tmpName := tmpFile.Name() + encoder := json.NewEncoder(tmpFile) + encoder.SetIndent("", " ") + encodeErr := encoder.Encode(manifest) + closeErr := tmpFile.Close() + if encodeErr != nil { + _ = os.Remove(tmpName) + return encodeErr + } + if closeErr != nil { + _ = os.Remove(tmpName) + return closeErr + } + return os.Rename(tmpName, manifestPath) +} + +func resolveManifestRoot(modelPath string) (string, error) { + if modelPath == "" { + return "", fmt.Errorf("model artifact path is empty") + } + resolved, err := filepath.EvalSymlinks(modelPath) + if err != nil { + return "", err + } + return resolved, nil +} + +func cleanManifestRelativePath(path string) (string, error) { + if path == "" { + return "", fmt.Errorf("manifest contains an empty file path") + } + cleaned := filepath.Clean(filepath.FromSlash(path)) + if filepath.IsAbs(cleaned) || cleaned == ".." || strings.HasPrefix(cleaned, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("manifest contains an invalid relative path: %s", path) + } + return cleaned, nil +} + +func hashFile(ctx context.Context, path string) (string, int64, error) { + file, err := os.Open(path) + if err != nil { + return "", 0, err + } + defer file.Close() + + hasher := sha256.New() + buffer := make([]byte, 1024*1024) + var total int64 + for { + select { + case <-ctx.Done(): + return "", total, ctx.Err() + default: + } + + n, readErr := file.Read(buffer) + if n > 0 { + total += int64(n) + if _, err := hasher.Write(buffer[:n]); err != nil { + return "", total, err + } + } + if readErr == io.EOF { + break + } + if readErr != nil { + return "", total, readErr + } + } + return hex.EncodeToString(hasher.Sum(nil)), total, nil +} diff --git a/pkg/modelagent/configmap_reconciler.go b/pkg/modelagent/configmap_reconciler.go index 8d49e4d4..34068902 100644 --- a/pkg/modelagent/configmap_reconciler.go +++ b/pkg/modelagent/configmap_reconciler.go @@ -34,6 +34,8 @@ const ( type CacheEntry struct { ModelName string // Name of the model ModelStatus ModelStatus // Current status of the model + StorageURI string // Source URI used for the local artifact + StoragePath string // Local path used for the artifact ModelMetadata *ModelMetadata // Model metadata if available } @@ -224,8 +226,10 @@ func (c *ConfigMapReconciler) recreateConfigMap(ctx context.Context) { for modelID, cacheEntry := range c.modelCache { // Create model entry from cache data modelEntry := &ModelEntry{ - Name: cacheEntry.ModelName, - Status: cacheEntry.ModelStatus, + Name: cacheEntry.ModelName, + Status: cacheEntry.ModelStatus, + StorageURI: cacheEntry.StorageURI, + StoragePath: cacheEntry.StoragePath, } // Convert metadata to ModelConfig if available @@ -278,8 +282,10 @@ func (c *ConfigMapReconciler) recreateConfigMap(ctx context.Context) { func (c *ConfigMapReconciler) restoreModelInConfigMap(modelID string, cacheEntry *CacheEntry) { // Construct model entry from cache data modelEntry := &ModelEntry{ - Name: cacheEntry.ModelName, - Status: cacheEntry.ModelStatus, + Name: cacheEntry.ModelName, + Status: cacheEntry.ModelStatus, + StorageURI: cacheEntry.StorageURI, + StoragePath: cacheEntry.StoragePath, } // Convert metadata to ModelConfig if available @@ -403,6 +409,7 @@ func (c *ConfigMapReconciler) ReconcileModelStatus(ctx context.Context, statusOp } // Get existing cache entry or create a new one + spec := getModelSpec(statusOp.BaseModel, statusOp.ClusterBaseModel) cacheEntry, exists := c.modelCache[modelID] if !exists { // Extract model name for the cache entry @@ -419,9 +426,13 @@ func (c *ConfigMapReconciler) ReconcileModelStatus(ctx context.Context, statusOp } c.modelCache[modelID] = cacheEntry } else { + if !cacheEntry.matchesStorageIdentity(spec) || (statusOp.ModelStatus == ModelStatusUpdating && !cacheEntry.hasStorageIdentity()) { + cacheEntry.ModelMetadata = nil + } // Just update the status in existing entry cacheEntry.ModelStatus = statusOp.ModelStatus } + cacheEntry.applyStorageIdentity(spec) c.cacheMutex.Unlock() c.logger.Infof("Successfully updated ConfigMap and cache for %s with status: %s", modelInfo, statusOp.ModelStatus) @@ -462,6 +473,41 @@ func getModelID(baseModel *v1beta1.BaseModel, clusterBaseModel *v1beta1.ClusterB return constants.GetModelConfigMapKey(namespace, modelName, isClusterBaseModel) } +func getModelSpec(baseModel *v1beta1.BaseModel, clusterBaseModel *v1beta1.ClusterBaseModel) *v1beta1.BaseModelSpec { + if baseModel != nil { + return &baseModel.Spec + } + if clusterBaseModel != nil { + return &clusterBaseModel.Spec + } + return nil +} + +func (entry *CacheEntry) applyStorageIdentity(spec *v1beta1.BaseModelSpec) { + if entry == nil { + return + } + storageURI, storagePath, ok := StorageIdentityForSpec(spec) + if !ok { + entry.StorageURI = "" + entry.StoragePath = "" + return + } + entry.StorageURI = storageURI + entry.StoragePath = storagePath +} + +func (entry *CacheEntry) hasStorageIdentity() bool { + return entry != nil && hasStorageIdentityFields(entry.StorageURI, entry.StoragePath) +} + +func (entry *CacheEntry) matchesStorageIdentity(spec *v1beta1.BaseModelSpec) bool { + if entry == nil { + return false + } + return matchesStorageIdentityFields(entry.StorageURI, entry.StoragePath, spec) +} + // ReconcileModelMetadata updates the ConfigMap with model metadata func (c *ConfigMapReconciler) ReconcileModelMetadata(ctx context.Context, op *ConfigMapMetadataOp) error { modelInfo := getConfigMapModelInfo(op.BaseModel, op.ClusterBaseModel) @@ -507,6 +553,7 @@ func (c *ConfigMapReconciler) ReconcileModelMetadata(ctx context.Context, op *Co // Update the metadata cacheEntry.ModelMetadata = &op.ModelMetadata } + cacheEntry.applyStorageIdentity(getModelSpec(op.BaseModel, op.ClusterBaseModel)) c.cacheMutex.Unlock() c.logger.Infof("Successfully updated ConfigMap and cache for %s with metadata", modelInfo) @@ -584,7 +631,15 @@ func (c *ConfigMapReconciler) updateModelProgressInConfigMap(ctx context.Context } // Update the progress (can be nil to clear it) + if !modelEntry.MatchesStorageIdentity(getModelSpec(op.BaseModel, op.ClusterBaseModel)) || + (op.Progress != nil && !modelEntry.hasStorageIdentity()) { + modelEntry.Config = nil + } + if op.Progress != nil { + modelEntry.Status = ModelStatusUpdating + } modelEntry.Progress = op.Progress + modelEntry.ApplyStorageIdentity(getModelSpec(op.BaseModel, op.ClusterBaseModel)) // Marshal the model entry back to JSON entryJSON, err := json.Marshal(modelEntry) @@ -768,6 +823,11 @@ func (c *ConfigMapReconciler) updateModelStatusInConfigMap(ctx context.Context, Config: nil, } } else { + if !modelEntry.MatchesStorageIdentity(getModelSpec(op.BaseModel, op.ClusterBaseModel)) || + (op.ModelStatus == ModelStatusUpdating && !modelEntry.hasStorageIdentity()) { + modelEntry.Config = nil + modelEntry.Progress = nil + } // Update just the status, preserving the config modelEntry.Status = op.ModelStatus // Clear progress when status becomes Ready or Failed (download complete) @@ -784,6 +844,7 @@ func (c *ConfigMapReconciler) updateModelStatusInConfigMap(ctx context.Context, Config: nil, } } + modelEntry.ApplyStorageIdentity(getModelSpec(op.BaseModel, op.ClusterBaseModel)) // For 'ModelStatusDeleted' status, we might want to entirely remove the entry if op.ModelStatus == ModelStatusDeleted { @@ -851,6 +912,7 @@ func (c *ConfigMapReconciler) updateModelMetadataInConfigMap(ctx context.Context // Update the config in the model entry modelEntry.Config = modelConfig + modelEntry.ApplyStorageIdentity(getModelSpec(op.BaseModel, op.ClusterBaseModel)) // Marshal the model entry back to JSON entryJSON, err := json.Marshal(modelEntry) diff --git a/pkg/modelagent/configmap_reconciler_test.go b/pkg/modelagent/configmap_reconciler_test.go index 4fded244..493e7499 100644 --- a/pkg/modelagent/configmap_reconciler_test.go +++ b/pkg/modelagent/configmap_reconciler_test.go @@ -256,6 +256,187 @@ func TestUpdateModelStatusInConfigMap(t *testing.T) { } +func TestUpdateModelStatusInConfigMapStoresStorageIdentity(t *testing.T) { + reconciler, _, _ := setupConfigMapTest(t) + configMap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-node", + Namespace: "test-namespace", + }, + Data: make(map[string]string), + } + + storageURI := "hf://google/gemma-4-31B-it" + storagePath := "/raid/models/google/gemma-4-31B-it" + baseModel := createTestBaseModelCM() + baseModel.Spec.Storage = &v1beta1.StorageSpec{ + StorageUri: &storageURI, + Path: &storagePath, + } + + op := &ConfigMapStatusOp{ + BaseModel: baseModel, + ModelStatus: ModelStatusReady, + } + + err := reconciler.updateModelStatusInConfigMap(context.Background(), configMap, op, true) + assert.NoError(t, err) + + key := reconciler.getModelConfigMapKey(baseModel, nil) + var modelEntry ModelEntry + err = json.Unmarshal([]byte(configMap.Data[key]), &modelEntry) + assert.NoError(t, err) + assert.Equal(t, storageURI, modelEntry.StorageURI) + assert.Equal(t, storagePath, modelEntry.StoragePath) +} + +func TestUpdateModelStatusInConfigMapClearsConfigOnStorageIdentityChange(t *testing.T) { + reconciler, _, _ := setupConfigMapTest(t) + + oldStorageURI := "hf://google/gemma-4-31B-it" + oldStoragePath := "/raid/models/google/gemma-4-31B-it" + newStorageURI := "hf://google/gemma-4-31b-it" + newStoragePath := "/raid/models/google/gemma-4-31b-it" + baseModel := createTestBaseModelCM() + baseModel.Spec.Storage = &v1beta1.StorageSpec{ + StorageUri: &newStorageURI, + Path: &newStoragePath, + } + + existingEntry := ModelEntry{ + Name: baseModel.Name, + Status: ModelStatusReady, + StorageURI: oldStorageURI, + StoragePath: oldStoragePath, + Config: &ModelConfig{ + ModelType: "gemma", + }, + Progress: &DownloadProgress{ + Phase: "Downloading", + }, + } + existingEntryJSON, err := json.Marshal(existingEntry) + assert.NoError(t, err) + + key := reconciler.getModelConfigMapKey(baseModel, nil) + configMap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-node", + Namespace: "test-namespace", + }, + Data: map[string]string{ + key: string(existingEntryJSON), + }, + } + + op := &ConfigMapStatusOp{ + BaseModel: baseModel, + ModelStatus: ModelStatusUpdating, + } + err = reconciler.updateModelStatusInConfigMap(context.Background(), configMap, op, true) + assert.NoError(t, err) + + var modelEntry ModelEntry + err = json.Unmarshal([]byte(configMap.Data[key]), &modelEntry) + assert.NoError(t, err) + assert.Equal(t, ModelStatusUpdating, modelEntry.Status) + assert.Equal(t, newStorageURI, modelEntry.StorageURI) + assert.Equal(t, newStoragePath, modelEntry.StoragePath) + assert.Nil(t, modelEntry.Config) + assert.Nil(t, modelEntry.Progress) +} + +func TestUpdateModelProgressInConfigMapStoresStorageIdentity(t *testing.T) { + reconciler, _, _ := setupConfigMapTest(t) + + storageURI := "hf://google/gemma-4-31b-it" + storagePath := "/raid/models/google/gemma-4-31b-it" + baseModel := createTestBaseModelCM() + baseModel.Spec.Storage = &v1beta1.StorageSpec{ + StorageUri: &storageURI, + Path: &storagePath, + } + + key := reconciler.getModelConfigMapKey(baseModel, nil) + existingEntry := ModelEntry{ + Name: baseModel.Name, + Status: ModelStatusReady, + Config: &ModelConfig{ + ModelType: "gemma", + }, + } + existingEntryJSON, err := json.Marshal(existingEntry) + assert.NoError(t, err) + + configMap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-node", + Namespace: "test-namespace", + }, + Data: map[string]string{ + key: string(existingEntryJSON), + }, + } + + progress := &DownloadProgress{ + Phase: "Downloading", + TotalFiles: 10, + CompletedFiles: 1, + } + op := &ConfigMapProgressOp{ + BaseModel: baseModel, + Progress: progress, + } + + err = reconciler.updateModelProgressInConfigMap(context.Background(), configMap, op, true) + assert.NoError(t, err) + + var modelEntry ModelEntry + err = json.Unmarshal([]byte(configMap.Data[key]), &modelEntry) + assert.NoError(t, err) + assert.Equal(t, ModelStatusUpdating, modelEntry.Status) + assert.Equal(t, storageURI, modelEntry.StorageURI) + assert.Equal(t, storagePath, modelEntry.StoragePath) + assert.Nil(t, modelEntry.Config) + assert.Equal(t, progress, modelEntry.Progress) +} + +func TestReconcileModelStatusClearsCachedMetadataOnStorageIdentityChange(t *testing.T) { + reconciler, _, _ := setupConfigMapTest(t) + ctx := context.Background() + + oldStorageURI := "hf://google/gemma-4-31B-it" + oldStoragePath := "/raid/models/google/gemma-4-31B-it" + newStorageURI := "hf://google/gemma-4-31b-it" + newStoragePath := "/raid/models/google/gemma-4-31b-it" + baseModel := createTestBaseModelCM() + baseModel.Spec.Storage = &v1beta1.StorageSpec{ + StorageUri: &newStorageURI, + Path: &newStoragePath, + } + modelID := reconciler.getModelConfigMapKey(baseModel, nil) + + reconciler.modelCache[modelID] = &CacheEntry{ + ModelName: baseModel.Name, + ModelStatus: ModelStatusReady, + StorageURI: oldStorageURI, + StoragePath: oldStoragePath, + ModelMetadata: &ModelMetadata{ModelType: "gemma"}, + } + + err := reconciler.ReconcileModelStatus(ctx, &ConfigMapStatusOp{ + BaseModel: baseModel, + ModelStatus: ModelStatusUpdating, + }) + assert.NoError(t, err) + + cacheEntry := reconciler.modelCache[modelID] + assert.Equal(t, ModelStatusUpdating, cacheEntry.ModelStatus) + assert.Equal(t, newStorageURI, cacheEntry.StorageURI) + assert.Equal(t, newStoragePath, cacheEntry.StoragePath) + assert.Nil(t, cacheEntry.ModelMetadata) +} + // TestUpdateModelMetadataInConfigMap tests the updateModelMetadataInConfigMap method func TestUpdateModelMetadataInConfigMap(t *testing.T) { // Setup test environment diff --git a/pkg/modelagent/gopher.go b/pkg/modelagent/gopher.go index 36f97df4..afff9b2f 100644 --- a/pkg/modelagent/gopher.go +++ b/pkg/modelagent/gopher.go @@ -12,10 +12,10 @@ import ( "k8s.io/apimachinery/pkg/labels" - "github.com/oracle/oci-go-sdk/v65/objectstorage" "go.uber.org/zap" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/cache" "github.com/sgl-project/ome/pkg/apis/ome/v1beta1" omev1beta1lister "github.com/sgl-project/ome/pkg/client/listers/ome/v1beta1" @@ -59,6 +59,10 @@ type Gopher struct { configMapMutex sync.Mutex // Mutex to coordinate ConfigMap access baseModelLister omev1beta1lister.BaseModelLister clusterBaseModelLister omev1beta1lister.ClusterBaseModelLister + baseModelSynced cache.InformerSynced + clusterBaseModelSynced cache.InformerSynced + nodeShapeAlias string + integrityConfig IntegrityConfig // Track active downloads for cancellation activeDownloads map[string]context.CancelFunc // key: model UID @@ -83,11 +87,16 @@ func NewGopher( metrics *Metrics, logger *zap.SugaredLogger, baseModelLister omev1beta1lister.BaseModelLister, - clusterBaseModelLister omev1beta1lister.ClusterBaseModelLister) (*Gopher, error) { + clusterBaseModelLister omev1beta1lister.ClusterBaseModelLister, + baseModelSynced cache.InformerSynced, + clusterBaseModelSynced cache.InformerSynced, + nodeShapeAlias string, + integrityConfig IntegrityConfig) (*Gopher, error) { if xetConfig == nil { return nil, fmt.Errorf("xet hugging face config cannot be nil") } + normalizedIntegrityConfig := integrityConfig.normalized() return &Gopher{ modelConfigParser: modelConfigParser, @@ -105,6 +114,10 @@ func NewGopher( activeDownloads: make(map[string]context.CancelFunc), baseModelLister: baseModelLister, clusterBaseModelLister: clusterBaseModelLister, + baseModelSynced: baseModelSynced, + clusterBaseModelSynced: clusterBaseModelSynced, + nodeShapeAlias: nodeShapeAlias, + integrityConfig: normalizedIntegrityConfig, }, nil } @@ -113,6 +126,9 @@ func (s *Gopher) Run(stopCh <-chan struct{}, numWorker int) { s.configMapReconciler.StartReconciliation() s.logger.Info("Started ConfigMap reconciliation service") + // Start artifact integrity reconciliation for models already marked Ready on this node. + go s.runIntegrityReconciliationLoop(stopCh) + // Start worker goroutines for i := 0; i < numWorker; i++ { go s.runWorker() @@ -316,14 +332,21 @@ func (s *Gopher) processTask(task *GopherTask) error { // Record time for metrics downloadStartTime := time.Now() + artifactPath := "" switch storageType { case storage.StorageTypeOCI: osUri, err := getTargetDirPath(&baseModelSpec) - destPath := getDestPath(&baseModelSpec, s.modelRootDir) if err != nil { s.logger.Errorf("Failed to get target directory path for model %s: %v", modelInfo, err) return err } + destPath, err := resolveArtifactPath(&baseModelSpec, s.modelRootDir) + if err != nil { + s.logger.Errorf("Failed to resolve artifact path for model %s: %v", modelInfo, err) + s.markModelOnNodeFailed(task) + return err + } + artifactPath = destPath err = utils.Retry(s.downloadRetry, 100*time.Millisecond, func() error { downloadErr := s.downloadModel(ctx, osUri, destPath, task) if downloadErr != nil { @@ -373,6 +396,11 @@ func (s *Gopher) processTask(task *GopherTask) error { s.logger.Infof("Skipping download for model %s", modelInfo) case storage.StorageTypeHuggingFace: s.logger.Infof("Starting Hugging Face download for model %s", modelInfo) + artifactPath, err = resolveArtifactPath(&baseModelSpec, s.modelRootDir) + if err != nil { + s.markModelOnNodeFailed(task) + return err + } // Handle Hugging Face model download if err := s.processHuggingFaceModel(ctx, task, baseModelSpec, modelInfo, modelType, namespace, name); err != nil { @@ -386,6 +414,11 @@ func (s *Gopher) processTask(task *GopherTask) error { return nil case storage.StorageTypeLocal: s.logger.Infof("Processing local storage type for model %s", modelInfo) + artifactPath, err = resolveArtifactPath(&baseModelSpec, s.modelRootDir) + if err != nil { + s.markModelOnNodeFailed(task) + return err + } // For local storage, we just need to validate the path exists and parse model config if err := s.processLocalStorageModel(ctx, task, baseModelSpec, modelInfo, modelType, namespace, name); err != nil { return err @@ -393,6 +426,12 @@ func (s *Gopher) processTask(task *GopherTask) error { default: return fmt.Errorf("unknown storage type %s", storageType) } + if err := s.ensureArtifactManifest(ctx, task, &baseModelSpec, storageType, artifactPath); err != nil { + s.logger.Errorf("Failed to create or validate artifact manifest for model %s: %v", modelInfo, err) + s.metrics.RecordFailedDownload(modelType, namespace, name, "integrity_manifest_error") + s.markModelOnNodeFailed(task) + return err + } // Calculate download duration downloadDuration := time.Since(downloadStartTime) @@ -435,7 +474,10 @@ func (s *Gopher) processTask(task *GopherTask) error { switch storageType { case storage.StorageTypeOCI: s.logger.Infof("Starting deletion for model %s", modelInfo) - destPath := getDestPath(&baseModelSpec, s.modelRootDir) + destPath, err := resolveArtifactPath(&baseModelSpec, s.modelRootDir) + if err != nil { + return err + } // check if it needs to skip artifact deletion isSkippingDeletion, _, _, _ := s.isSkippingArtifactDeletion(ctx, task, destPath, false) if !isSkippingDeletion { @@ -454,8 +496,10 @@ func (s *Gopher) processTask(task *GopherTask) error { s.logger.Infof("Skipping deletion for model %s", modelInfo) case storage.StorageTypeHuggingFace: s.logger.Infof("Removing Hugging Face model %s", modelInfo) - // Use getDestPath to get the same path used during download - destPath := getDestPath(&baseModelSpec, s.modelRootDir) + destPath, err := resolveArtifactPath(&baseModelSpec, s.modelRootDir) + if err != nil { + return err + } // check if it needs to skip artifact deletion isSkippingDeletion, isRemoveParent, parentName, parentDir := s.isSkippingArtifactDeletion(ctx, task, destPath, true) @@ -661,22 +705,6 @@ func (s *Gopher) getHuggingFaceToken(task *GopherTask, baseModelSpec v1beta1.Bas return hfToken } -func getDestPath(baseModel *v1beta1.BaseModelSpec, modelRootDir string) string { - - storagePath := *baseModel.Storage.StorageUri - destPath := *baseModel.Storage.Path - - if len(destPath) == 0 { - if strings.HasSuffix(modelRootDir, "/") { - return modelRootDir + storagePath - } else { - return modelRootDir + "/" + storagePath - } - } - - return destPath -} - // getTargetDirPath determines the target directory path for a model based on its storage configuration func getTargetDirPath(baseModel *v1beta1.BaseModelSpec) (*ociobjectstore.ObjectURI, error) { @@ -778,22 +806,12 @@ func (s *Gopher) downloadModel(ctx context.Context, uri *ociobjectstore.ObjectUR s.logger.Infof("Done with list all %d objects in model bucket folder", len(objects)) - // Shape filtering for TensorRTLLM - if task.TensorRTLLMShapeFilter != nil && task.TensorRTLLMShapeFilter.IsTensorrtLLMModel && task.TensorRTLLMShapeFilter.ModelType == string(constants.ServingBaseModel) { + objects, filtered, err := filterObjectsForTensorRTLLM(objects, task.TensorRTLLMShapeFilter) + if err != nil { + return err + } + if filtered { s.logger.Infof("TensorRTLLM Serving model detected. Start filtering model files that doesn't belong to the node shape %s in model bucket folder", task.TensorRTLLMShapeFilter.ShapeAlias) - shapeFilteredObjects := make([]objectstorage.ObjectSummary, 0) - for _, object := range objects { - if object.Name != nil { - if strings.Contains(*object.Name, fmt.Sprintf("/%s/", task.TensorRTLLMShapeFilter.ShapeAlias)) { - shapeFilteredObjects = append(shapeFilteredObjects, object) - } - } - } - objects = shapeFilteredObjects - - if len(objects) == 0 { - return fmt.Errorf("no suitable objects found for shape %s", task.TensorRTLLMShapeFilter.ShapeAlias) - } s.logger.Infof("Found %d objects applicable for shape %s", len(objects), task.TensorRTLLMShapeFilter.ShapeAlias) } @@ -969,7 +987,13 @@ func (s *Gopher) processHuggingFaceModel(ctx context.Context, task *GopherTask, } // Create destination path - destPath := getDestPath(&baseModelSpec, s.modelRootDir) + destPath, err := resolveArtifactPath(&baseModelSpec, s.modelRootDir) + if err != nil { + s.logger.Errorf("Failed to resolve destination path for model %s: %v", modelInfo, err) + s.metrics.RecordFailedDownload(modelType, namespace, name, "invalid_storage_path") + s.markModelOnNodeFailed(task) + return err + } // fetch sha value based on model ID from Huggingface model API shaStr, isShaAvailable := s.fetchSha(ctx, hfComponents.ModelID, name) diff --git a/pkg/modelagent/integrity_reconciler.go b/pkg/modelagent/integrity_reconciler.go new file mode 100644 index 00000000..d98c332a --- /dev/null +++ b/pkg/modelagent/integrity_reconciler.go @@ -0,0 +1,815 @@ +package modelagent + +import ( + "context" + "encoding/json" + "fmt" + "hash/fnv" + "os" + "path/filepath" + "sort" + "strings" + "time" + + "github.com/oracle/oci-go-sdk/v65/objectstorage" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/client-go/tools/cache" + + "github.com/sgl-project/ome/pkg/apis/ome/v1beta1" + "github.com/sgl-project/ome/pkg/constants" + hfmodelconfig "github.com/sgl-project/ome/pkg/hfutil/modelconfig" + "github.com/sgl-project/ome/pkg/ociobjectstore" + "github.com/sgl-project/ome/pkg/utils/storage" +) + +const ( + defaultIntegrityCheckInterval = 10 * time.Minute + defaultIntegrityDeepCheckInterval = 6 * time.Hour + defaultIntegrityStartupJitter = 30 * time.Second + + artifactManifestDir = ".ome" +) + +type IntegrityConfig struct { + CheckInterval time.Duration + DeepCheckInterval time.Duration + StartupJitter time.Duration +} + +func DefaultIntegrityConfig() IntegrityConfig { + return IntegrityConfig{ + CheckInterval: defaultIntegrityCheckInterval, + DeepCheckInterval: defaultIntegrityDeepCheckInterval, + StartupJitter: defaultIntegrityStartupJitter, + } +} + +func (c IntegrityConfig) normalized() IntegrityConfig { + if c.StartupJitter < 0 { + c.StartupJitter = 0 + } + return c +} + +func (c IntegrityConfig) enabled() bool { + return c.CheckInterval > 0 +} + +func (c IntegrityConfig) deepEnabled() bool { + return c.enabled() && c.DeepCheckInterval > 0 +} + +type integrityCheckType string + +const ( + integrityCheckBasic integrityCheckType = "basic" + integrityCheckDeep integrityCheckType = "deep" +) + +type integrityResult string + +const ( + integrityResultSuccess integrityResult = "success" + integrityResultFailure integrityResult = "failure" + integrityResultInconclusive integrityResult = "inconclusive" + integrityResultSkipped integrityResult = "skipped" +) + +type integrityReason string + +const ( + integrityReasonOK integrityReason = "ok" + integrityReasonBaselineCreated integrityReason = "baseline_created" + integrityReasonMissingPath integrityReason = "missing_path" + integrityReasonPathNotDirectory integrityReason = "path_not_directory" + integrityReasonMissingConfig integrityReason = "missing_config" + integrityReasonMissingWeight integrityReason = "missing_weight" + integrityReasonSizeMismatch integrityReason = "size_mismatch" + integrityReasonChecksumMismatch integrityReason = "checksum_mismatch" + integrityReasonSafetensorsCorrupt integrityReason = "safetensors_corrupt" + integrityReasonParseError integrityReason = "parse_error" + integrityReasonMetadataError integrityReason = "metadata_error" + integrityReasonManifestError integrityReason = "manifest_error" + integrityReasonStorageError integrityReason = "storage_error" + integrityReasonMarkFailedError integrityReason = "mark_failed_error" + integrityReasonSkippedStale integrityReason = "skipped_stale" +) + +type integrityReport struct { + Result integrityResult + Reason integrityReason + Message string + BytesScanned int64 +} + +func successReport(reason integrityReason, bytesScanned int64) integrityReport { + return integrityReport{Result: integrityResultSuccess, Reason: reason, BytesScanned: bytesScanned} +} + +func failureReport(reason integrityReason, message string) integrityReport { + return integrityReport{Result: integrityResultFailure, Reason: reason, Message: message} +} + +func inconclusiveReport(reason integrityReason, message string) integrityReport { + return integrityReport{Result: integrityResultInconclusive, Reason: reason, Message: message} +} + +type integrityModelRef struct { + Key string + BaseModel *v1beta1.BaseModel + ClusterBaseModel *v1beta1.ClusterBaseModel +} + +func (r integrityModelRef) spec() *v1beta1.BaseModelSpec { + if r.BaseModel != nil { + return &r.BaseModel.Spec + } + if r.ClusterBaseModel != nil { + return &r.ClusterBaseModel.Spec + } + return nil +} + +func (r integrityModelRef) task() *GopherTask { + return &GopherTask{ + BaseModel: r.BaseModel, + ClusterBaseModel: r.ClusterBaseModel, + } +} + +func (r integrityModelRef) modelTypeNamespaceName() (string, string, string) { + return GetModelTypeNamespaceAndName(r.task()) +} + +func (r integrityModelRef) logName() string { + return getModelInfoForLogging(r.task()) +} + +func (s *Gopher) runIntegrityReconciliationLoop(stopCh <-chan struct{}) { + if !s.integrityConfig.enabled() { + s.logger.Info("Model artifact integrity reconciliation is disabled") + return + } + + s.logger.Infof("Starting model artifact integrity reconciliation with interval %v, deep interval %v", + s.integrityConfig.CheckInterval, s.integrityConfig.DeepCheckInterval) + + if !cache.WaitForCacheSync(stopCh, s.baseModelSynced, s.clusterBaseModelSynced) { + s.logger.Warn("Stopping model artifact integrity reconciliation because model informer caches did not sync") + return + } + + if jitter := deterministicIntegrityJitter(s.nodeLabelReconciler.nodeName, s.integrityConfig.StartupJitter); jitter > 0 { + s.logger.Infof("Waiting %v before first model artifact integrity check", jitter) + timer := time.NewTimer(jitter) + select { + case <-timer.C: + case <-stopCh: + timer.Stop() + s.logger.Info("Stopping model artifact integrity reconciliation before first check") + return + } + } + + ticker := time.NewTicker(s.integrityConfig.CheckInterval) + defer ticker.Stop() + + var lastDeepCheck time.Time + for { + checkType := s.integrityCheckTypeForCycle(&lastDeepCheck) + s.reconcileReadyModelIntegrity(context.Background(), checkType) + + select { + case <-ticker.C: + case <-stopCh: + s.logger.Info("Stopping model artifact integrity reconciliation") + return + } + } +} + +func (s *Gopher) integrityCheckTypeForCycle(lastDeepCheck *time.Time) integrityCheckType { + if !s.integrityConfig.deepEnabled() { + return integrityCheckBasic + } + if lastDeepCheck.IsZero() || time.Since(*lastDeepCheck) >= s.integrityConfig.DeepCheckInterval { + *lastDeepCheck = time.Now() + return integrityCheckDeep + } + return integrityCheckBasic +} + +func (s *Gopher) reconcileReadyModelIntegrity(ctx context.Context, checkType integrityCheckType) { + start := time.Now() + summary := struct { + candidates int + success int + failure int + inconclusive int + skipped int + bytesScanned int64 + }{} + + cm, err := s.configMapReconciler.getConfigMap(ctx) + if err != nil { + s.logger.Warnf("Skipping model artifact integrity check: failed to read node ConfigMap: %v", err) + return + } + + modelRefs, err := s.buildIntegrityModelRefIndex() + if err != nil { + s.logger.Warnf("Skipping model artifact integrity check: failed to list model resources: %v", err) + return + } + + for key, raw := range cm.Data { + var entry ModelEntry + if err := json.Unmarshal([]byte(raw), &entry); err != nil { + s.logger.Warnf("Skipping unparsable model status entry %s during artifact integrity check: %v", key, err) + summary.skipped++ + continue + } + if entry.Status != ModelStatusReady { + continue + } + + ref, ok := modelRefs[key] + if !ok { + s.logger.Debugf("Skipping Ready model status entry %s during artifact integrity check: model no longer exists or key is unknown", key) + summary.skipped++ + continue + } + + spec := ref.spec() + if spec == nil || spec.Storage == nil || spec.Storage.StorageUri == nil { + s.recordIntegrityResult(ref, "", checkType, inconclusiveReport(integrityReasonStorageError, "model storage spec is missing"), 0) + summary.inconclusive++ + continue + } + + storageType, err := storage.GetStorageType(*spec.Storage.StorageUri) + if err != nil { + s.recordIntegrityResult(ref, "", checkType, inconclusiveReport(integrityReasonStorageError, err.Error()), 0) + summary.inconclusive++ + continue + } + + if entry.hasStorageIdentity() && !entry.MatchesStorageIdentity(spec) { + report := integrityReport{Result: integrityResultSkipped, Reason: integrityReasonSkippedStale} + s.recordIntegrityResult(ref, string(storageType), checkType, report, 0) + s.logger.Infow("Skipping stale Ready model status entry during artifact integrity check", + "model", ref.logName(), + "key", key, + "entryStorageUri", entry.StorageURI, + "entryStoragePath", entry.StoragePath) + summary.skipped++ + continue + } + + summary.candidates++ + candidateStart := time.Now() + report := s.validateReadyModelArtifact(ctx, ref, entry, storageType, checkType) + summary.bytesScanned += report.BytesScanned + s.recordIntegrityResult(ref, string(storageType), checkType, report, time.Since(candidateStart)) + + switch report.Result { + case integrityResultSuccess: + summary.success++ + if !entry.hasStorageIdentity() { + if err := s.backfillReadyStorageIdentity(ctx, ref); err != nil { + s.logger.Warnf("Artifact integrity check succeeded for %s but failed to backfill storage identity: %v", ref.logName(), err) + } + } + case integrityResultFailure: + summary.failure++ + s.logger.Warnf("Artifact integrity check failed for %s: reason=%s message=%s", ref.logName(), report.Reason, report.Message) + if err := s.markIntegrityFailureIfCurrent(ctx, key, ref); err != nil { + s.logger.Warnf("Failed to mark %s as Failed after artifact integrity failure: %v", ref.logName(), err) + markReport := integrityReport{Result: integrityResultFailure, Reason: integrityReasonMarkFailedError, Message: err.Error()} + s.recordIntegrityResult(ref, string(storageType), checkType, markReport, 0) + } + case integrityResultInconclusive: + summary.inconclusive++ + s.logger.Warnf("Artifact integrity check inconclusive for %s: reason=%s message=%s", ref.logName(), report.Reason, report.Message) + } + } + + s.logger.Infof("Completed model artifact integrity reconciliation checkType=%s candidates=%d success=%d failure=%d inconclusive=%d skipped=%d bytesScanned=%d duration=%v", + checkType, summary.candidates, summary.success, summary.failure, summary.inconclusive, summary.skipped, summary.bytesScanned, time.Since(start).Round(time.Millisecond)) +} + +func (s *Gopher) buildIntegrityModelRefIndex() (map[string]integrityModelRef, error) { + refs := make(map[string]integrityModelRef) + + clusterBaseModels, err := s.clusterBaseModelLister.List(labels.Everything()) + if err != nil { + return nil, err + } + for _, clusterBaseModel := range clusterBaseModels { + key := constants.GetModelConfigMapKey("", clusterBaseModel.Name, true) + refs[key] = integrityModelRef{Key: key, ClusterBaseModel: clusterBaseModel} + } + + baseModels, err := s.baseModelLister.List(labels.Everything()) + if err != nil { + return nil, err + } + for _, baseModel := range baseModels { + key := constants.GetModelConfigMapKey(baseModel.Namespace, baseModel.Name, false) + refs[key] = integrityModelRef{Key: key, BaseModel: baseModel} + } + + return refs, nil +} + +func (s *Gopher) validateReadyModelArtifact(ctx context.Context, ref integrityModelRef, entry ModelEntry, storageType storage.StorageType, checkType integrityCheckType) integrityReport { + spec := ref.spec() + modelPath := entry.StoragePath + if modelPath == "" { + resolvedPath, err := resolveArtifactPath(spec, s.modelRootDir) + if err != nil { + return inconclusiveReport(integrityReasonStorageError, err.Error()) + } + modelPath = resolvedPath + } + + switch storageType { + case storage.StorageTypeOCI: + return s.validateOCIArtifact(ctx, ref, modelPath, checkType) + case storage.StorageTypeHuggingFace, storage.StorageTypeLocal: + return s.validateFilesystemArtifact(ctx, ref, modelPath, storageType, checkType) + case storage.StorageTypeVendor, storage.StorageTypePVC: + return inconclusiveReport(integrityReasonStorageError, fmt.Sprintf("periodic integrity validation is not supported for storage type %s", storageType)) + default: + return inconclusiveReport(integrityReasonStorageError, fmt.Sprintf("unknown storage type %s", storageType)) + } +} + +func (s *Gopher) validateOCIArtifact(ctx context.Context, ref integrityModelRef, modelPath string, checkType integrityCheckType) integrityReport { + spec := ref.spec() + fsReport := s.validateFilesystemArtifact(ctx, ref, modelPath, storage.StorageTypeOCI, checkType) + if fsReport.Result == integrityResultFailure { + return fsReport + } + + osUri, err := getTargetDirPath(spec) + if err != nil { + return inconclusiveReport(integrityReasonStorageError, err.Error()) + } + + ociOSDataStore, err := s.createOCIOSDataStore(*spec) + if err != nil { + return inconclusiveReport(integrityReasonStorageError, err.Error()) + } + + objects, err := ociOSDataStore.ListObjects(*osUri) + if err != nil { + return inconclusiveReport(integrityReasonMetadataError, err.Error()) + } + + filter := tensorRTLLMShapeFilterForSpec(spec, s.nodeShapeAlias) + objects, _, err = filterObjectsForTensorRTLLM(objects, filter) + if err != nil { + return failureReport(integrityReasonMissingWeight, err.Error()) + } + if len(objects) == 0 { + return inconclusiveReport(integrityReasonMetadataError, "no OCI objects found for model source") + } + + for _, obj := range objects { + if obj.Name == nil { + continue + } + source := ociobjectstore.ObjectURI{ + Namespace: osUri.Namespace, + BucketName: osUri.BucketName, + ObjectName: *obj.Name, + Prefix: osUri.Prefix, + } + localPath := filepath.Join(modelPath, ociobjectstore.TrimObjectPrefix(source.ObjectName, source.Prefix)) + var result ociobjectstore.LocalCopyValidationResult + var err error + if checkType == integrityCheckDeep { + result, err = ociOSDataStore.ValidateLocalCopy(source, localPath) + } else { + result, err = validateOCIObjectSummaryLocalCopy(*obj.Name, obj.Size, localPath) + } + if err != nil { + return inconclusiveReport(integrityReasonMetadataError, err.Error()) + } + switch result.State { + case ociobjectstore.LocalCopyValidationInvalid: + return failureReport(ociValidationReason(result.Reason), fmt.Sprintf("%s: %s", source.ObjectName, result.Message)) + case ociobjectstore.LocalCopyValidationInconclusive: + if checkType == integrityCheckDeep && fsReport.BytesScanned == 0 { + return inconclusiveReport(integrityReasonMetadataError, fmt.Sprintf("%s: %s", source.ObjectName, result.Message)) + } + s.logger.Debugf("OCI local copy validation was inconclusive for %s: %s", source.ObjectName, result.Message) + } + } + + return fsReport +} + +func ociValidationReason(reason ociobjectstore.LocalCopyValidationReason) integrityReason { + switch reason { + case ociobjectstore.LocalCopyValidationReasonMissing: + return integrityReasonMissingWeight + case ociobjectstore.LocalCopyValidationReasonSizeMismatch: + return integrityReasonSizeMismatch + case ociobjectstore.LocalCopyValidationReasonChecksumMismatch: + return integrityReasonChecksumMismatch + default: + return integrityReasonMetadataError + } +} + +func validateOCIObjectSummaryLocalCopy(objectName string, objectSize *int64, localPath string) (ociobjectstore.LocalCopyValidationResult, error) { + fileInfo, err := os.Stat(localPath) + if err != nil { + if os.IsNotExist(err) { + return ociobjectstore.LocalCopyValidationResult{ + State: ociobjectstore.LocalCopyValidationInvalid, + Reason: ociobjectstore.LocalCopyValidationReasonMissing, + Message: fmt.Sprintf("local file for %s does not exist", objectName), + }, nil + } + return ociobjectstore.LocalCopyValidationResult{}, err + } + if objectSize != nil && fileInfo.Size() != *objectSize { + return ociobjectstore.LocalCopyValidationResult{ + State: ociobjectstore.LocalCopyValidationInvalid, + Reason: ociobjectstore.LocalCopyValidationReasonSizeMismatch, + Message: fmt.Sprintf("file size mismatch for %s: expected %d got %d", objectName, *objectSize, fileInfo.Size()), + }, nil + } + return ociobjectstore.LocalCopyValidationResult{ + State: ociobjectstore.LocalCopyValidationValid, + Reason: ociobjectstore.LocalCopyValidationReasonOK, + }, nil +} + +func (s *Gopher) validateFilesystemArtifact(ctx context.Context, ref integrityModelRef, modelPath string, storageType storage.StorageType, checkType integrityCheckType) integrityReport { + if err := validateArtifactPath(modelPath); err != nil { + if os.IsNotExist(err) { + return failureReport(integrityReasonMissingPath, err.Error()) + } + return failureReport(integrityReasonPathNotDirectory, err.Error()) + } + + if storageType == storage.StorageTypeHuggingFace || storageType == storage.StorageTypeLocal { + if report := s.validateModelConfigReadOnly(modelPath, ref); report.Result == integrityResultFailure { + return report + } + if report := validateWeightArtifacts(modelPath); report.Result == integrityResultFailure { + return report + } + } + + report := validateArtifactManifest(ctx, ref.spec(), s.modelRootDir, modelPath, checkType == integrityCheckDeep) + shouldCreateManifest := s.integrityConfig.deepEnabled() && checkType == integrityCheckDeep + if report.Result == integrityResultInconclusive && report.Reason == integrityReasonManifestError && shouldCreateManifest { + created, err := createArtifactManifest(ctx, ref.spec(), s.modelRootDir, modelPath, storageType) + if err != nil { + return inconclusiveReport(integrityReasonManifestError, err.Error()) + } + return successReport(integrityReasonBaselineCreated, created.BytesScanned) + } + if report.Result == integrityResultInconclusive && report.Reason == integrityReasonManifestError { + return successReport(integrityReasonOK, 0) + } + return report +} + +func (s *Gopher) validateModelConfigReadOnly(modelPath string, ref integrityModelRef) integrityReport { + if s.modelConfigParser == nil { + return inconclusiveReport(integrityReasonMetadataError, "model config parser is not initialized") + } + if s.modelConfigParser.shouldSkipConfigParsing(ref.BaseModel, ref.ClusterBaseModel) { + return successReport(integrityReasonOK, 0) + } + if _, err := s.modelConfigParser.ParseModelConfig(modelPath, nil, nil); err != nil { + if strings.Contains(err.Error(), "no model_index.json or config.json") { + return failureReport(integrityReasonMissingConfig, err.Error()) + } + return failureReport(integrityReasonParseError, err.Error()) + } + return successReport(integrityReasonOK, 0) +} + +func validateArtifactPath(modelPath string) error { + info, err := os.Stat(modelPath) + if err != nil { + return err + } + if !info.IsDir() { + return fmt.Errorf("model artifact path is not a directory: %s", modelPath) + } + return nil +} + +func validateWeightArtifacts(modelPath string) integrityReport { + indexReport, indexFound := validateWeightIndexFiles(modelPath) + if indexFound { + return indexReport + } + + weightFound := false + var parseError error + err := filepath.WalkDir(modelPath, func(path string, entry os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if entry.IsDir() { + if entry.Name() == artifactManifestDir { + return filepath.SkipDir + } + return nil + } + if !isWeightFile(path) { + return nil + } + weightFound = true + if err := validateWeightFile(path); err != nil { + parseError = err + return err + } + return nil + }) + if err != nil { + if parseError != nil { + return failureReport(integrityReasonSafetensorsCorrupt, parseError.Error()) + } + return failureReport(integrityReasonParseError, err.Error()) + } + if !weightFound { + return failureReport(integrityReasonMissingWeight, fmt.Sprintf("no model weight files found in %s", modelPath)) + } + return successReport(integrityReasonOK, 0) +} + +func validateWeightIndexFiles(modelPath string) (integrityReport, bool) { + foundIndex := false + var validationErr error + err := filepath.WalkDir(modelPath, func(path string, entry os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if entry.IsDir() { + if entry.Name() == artifactManifestDir { + return filepath.SkipDir + } + return nil + } + if !isWeightIndexFile(entry.Name()) { + return nil + } + shards, err := readWeightMapShards(path) + if err != nil { + validationErr = err + return err + } + if len(shards) == 0 { + return nil + } + foundIndex = true + for _, shard := range shards { + shardPath := filepath.Join(filepath.Dir(path), shard) + if err := validateWeightFile(shardPath); err != nil { + validationErr = err + return err + } + } + return nil + }) + if err != nil { + reason := integrityReasonParseError + if validationErr != nil && strings.Contains(validationErr.Error(), "safetensors") { + reason = integrityReasonSafetensorsCorrupt + } + if validationErr == nil { + validationErr = err + } + return failureReport(reason, validationErr.Error()), foundIndex + } + if foundIndex { + return successReport(integrityReasonOK, 0), true + } + return successReport(integrityReasonOK, 0), false +} + +func readWeightMapShards(indexPath string) ([]string, error) { + data, err := os.ReadFile(indexPath) + if err != nil { + return nil, err + } + var index struct { + WeightMap map[string]string `json:"weight_map"` + } + if err := json.Unmarshal(data, &index); err != nil { + return nil, err + } + if len(index.WeightMap) == 0 { + return nil, nil + } + seen := make(map[string]struct{}) + for _, shard := range index.WeightMap { + if shard == "" { + return nil, fmt.Errorf("empty shard filename found in %s", indexPath) + } + seen[shard] = struct{}{} + } + shards := make([]string, 0, len(seen)) + for shard := range seen { + shards = append(shards, shard) + } + sort.Strings(shards) + return shards, nil +} + +func validateWeightFile(path string) error { + info, err := os.Stat(path) + if err != nil { + return err + } + if !info.Mode().IsRegular() { + return fmt.Errorf("weight artifact is not a regular file: %s", path) + } + if info.Size() == 0 { + return fmt.Errorf("weight artifact is empty: %s", path) + } + if strings.HasSuffix(path, ".safetensors") { + if _, err := hfmodelconfig.ParseSafetensors(path); err != nil { + return err + } + } + return nil +} + +func isWeightFile(path string) bool { + switch strings.ToLower(filepath.Ext(path)) { + case ".safetensors", ".bin", ".gguf", ".pt", ".pth", ".onnx", ".ckpt", ".msgpack", ".h5", ".pb", ".engine", ".plan": + return true + default: + return false + } +} + +func isWeightIndexFile(name string) bool { + lowerName := strings.ToLower(name) + return strings.HasSuffix(lowerName, ".safetensors.index.json") || strings.HasSuffix(lowerName, ".bin.index.json") +} + +func (s *Gopher) ensureArtifactManifest(ctx context.Context, task *GopherTask, spec *v1beta1.BaseModelSpec, storageType storage.StorageType, modelPath string) error { + if !s.integrityConfig.deepEnabled() { + return nil + } + if storageType == storage.StorageTypeVendor || storageType == storage.StorageTypePVC { + return nil + } + if modelPath == "" { + resolvedPath, err := resolveArtifactPath(spec, s.modelRootDir) + if err != nil { + return err + } + modelPath = resolvedPath + } + if err := validateArtifactPath(modelPath); err != nil { + return err + } + if storageType == storage.StorageTypeHuggingFace || storageType == storage.StorageTypeLocal { + ref := integrityModelRef{} + if task != nil { + ref.BaseModel = task.BaseModel + ref.ClusterBaseModel = task.ClusterBaseModel + } + if report := s.validateModelConfigReadOnly(modelPath, ref); report.Result == integrityResultFailure { + return fmt.Errorf("artifact manifest validation failed: %s", report.Message) + } + if report := validateWeightArtifacts(modelPath); report.Result == integrityResultFailure { + return fmt.Errorf("artifact manifest validation failed: %s", report.Message) + } + } + if _, err := createArtifactManifest(ctx, spec, s.modelRootDir, modelPath, storageType); err != nil { + return err + } + return nil +} + +func (s *Gopher) backfillReadyStorageIdentity(ctx context.Context, ref integrityModelRef) error { + return s.configMapReconciler.ReconcileModelStatus(ctx, &ConfigMapStatusOp{ + ModelStatus: ModelStatusReady, + BaseModel: ref.BaseModel, + ClusterBaseModel: ref.ClusterBaseModel, + }) +} + +func (s *Gopher) markIntegrityFailureIfCurrent(ctx context.Context, key string, ref integrityModelRef) error { + s.configMapMutex.Lock() + defer s.configMapMutex.Unlock() + + cm, err := s.configMapReconciler.getConfigMap(ctx) + if err != nil { + return err + } + raw, ok := cm.Data[key] + if !ok { + return fmt.Errorf("model entry %s disappeared before marking Failed", key) + } + var current ModelEntry + if err := json.Unmarshal([]byte(raw), ¤t); err != nil { + return err + } + if current.Status != ModelStatusReady { + return fmt.Errorf("model entry %s is no longer Ready; current status is %s", key, current.Status) + } + if !current.MatchesStorageIdentity(ref.spec()) { + return fmt.Errorf("model entry %s storage identity changed before marking Failed", key) + } + op := &NodeLabelOp{ + ModelStateOnNode: Failed, + BaseModel: ref.BaseModel, + ClusterBaseModel: ref.ClusterBaseModel, + } + if err := s.nodeLabelReconciler.ReconcileNodeLabels(op); err != nil { + return err + } + return s.configMapReconciler.ReconcileModelStatus(ctx, &ConfigMapStatusOp{ + ModelStatus: ModelStatusFailed, + BaseModel: ref.BaseModel, + ClusterBaseModel: ref.ClusterBaseModel, + }) +} + +func (s *Gopher) recordIntegrityResult(ref integrityModelRef, storageType string, checkType integrityCheckType, report integrityReport, duration time.Duration) { + modelType, namespace, name := ref.modelTypeNamespaceName() + s.metrics.RecordIntegrityCheck(modelType, namespace, name, storageType, string(checkType), string(report.Result), string(report.Reason), duration, report.BytesScanned) +} + +func resolveArtifactPath(spec *v1beta1.BaseModelSpec, modelRootDir string) (string, error) { + if spec == nil || spec.Storage == nil || spec.Storage.StorageUri == nil { + return "", fmt.Errorf("model storage URI is missing") + } + storageURI := *spec.Storage.StorageUri + storageType, err := storage.GetStorageType(storageURI) + if err != nil { + return "", err + } + if spec.Storage.Path != nil && *spec.Storage.Path != "" { + return *spec.Storage.Path, nil + } + if storageType == storage.StorageTypeLocal { + localComponents, err := storage.ParseLocalStorageURI(storageURI) + if err != nil { + return "", err + } + return localComponents.Path, nil + } + if strings.HasSuffix(modelRootDir, "/") { + return modelRootDir + storageURI, nil + } + return modelRootDir + "/" + storageURI, nil +} + +func tensorRTLLMShapeFilterForSpec(spec *v1beta1.BaseModelSpec, nodeShapeAlias string) *TensorRTLLMShapeFilter { + if spec == nil { + return nil + } + modelType := string(constants.ServingBaseModel) + if spec.AdditionalMetadata != nil { + if modelTypeFromMetadata, ok := spec.AdditionalMetadata["type"]; ok { + modelType = modelTypeFromMetadata + } + } + return &TensorRTLLMShapeFilter{ + IsTensorrtLLMModel: spec.ModelFormat.Name == constants.TensorRTLLM, + ShapeAlias: nodeShapeAlias, + ModelType: modelType, + } +} + +func filterObjectsForTensorRTLLM(objects []objectstorage.ObjectSummary, filter *TensorRTLLMShapeFilter) ([]objectstorage.ObjectSummary, bool, error) { + if filter == nil || !filter.IsTensorrtLLMModel || filter.ModelType != string(constants.ServingBaseModel) { + return objects, false, nil + } + + filtered := make([]objectstorage.ObjectSummary, 0) + for _, object := range objects { + if object.Name != nil && strings.Contains(*object.Name, fmt.Sprintf("/%s/", filter.ShapeAlias)) { + filtered = append(filtered, object) + } + } + if len(filtered) == 0 { + return nil, true, fmt.Errorf("no suitable objects found for shape %s", filter.ShapeAlias) + } + return filtered, true, nil +} + +func deterministicIntegrityJitter(nodeName string, maxJitter time.Duration) time.Duration { + if maxJitter <= 0 || nodeName == "" { + return 0 + } + hash := fnv.New64a() + _, _ = hash.Write([]byte(nodeName)) + return time.Duration(hash.Sum64() % uint64(maxJitter)) +} diff --git a/pkg/modelagent/integrity_reconciler_test.go b/pkg/modelagent/integrity_reconciler_test.go new file mode 100644 index 00000000..60357d1a --- /dev/null +++ b/pkg/modelagent/integrity_reconciler_test.go @@ -0,0 +1,284 @@ +package modelagent + +import ( + "context" + "encoding/binary" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/kubernetes/fake" + "k8s.io/client-go/tools/cache" + + "github.com/sgl-project/ome/pkg/apis/ome/v1beta1" + omev1beta1lister "github.com/sgl-project/ome/pkg/client/listers/ome/v1beta1" + "github.com/sgl-project/ome/pkg/constants" + hfmodelconfig "github.com/sgl-project/ome/pkg/hfutil/modelconfig" + "github.com/sgl-project/ome/pkg/ociobjectstore" + "github.com/sgl-project/ome/pkg/utils/storage" +) + +func TestValidateFilesystemArtifactDetectsSameSizeCorruption(t *testing.T) { + modelPath := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(modelPath, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + weightPath := filepath.Join(modelPath, "model.safetensors") + writeTinySafetensors(t, weightPath) + + gopher := newIntegrityTestGopher(t, nil, nil) + report := gopher.validateFilesystemArtifact(context.Background(), integrityModelRef{ + BaseModel: testIntegrityBaseModel("default", "llama", "hf://meta-llama/llama", modelPath), + }, modelPath, storage.StorageTypeHuggingFace, integrityCheckDeep) + require.Equal(t, integrityResultSuccess, report.Result) + require.Equal(t, integrityReasonBaselineCreated, report.Reason) + + data, err := os.ReadFile(weightPath) + require.NoError(t, err) + data[len(data)-1] ^= 0xff + require.NoError(t, os.WriteFile(weightPath, data, 0644)) + + report = gopher.validateFilesystemArtifact(context.Background(), integrityModelRef{ + BaseModel: testIntegrityBaseModel("default", "llama", "hf://meta-llama/llama", modelPath), + }, modelPath, storage.StorageTypeHuggingFace, integrityCheckDeep) + assert.Equal(t, integrityResultFailure, report.Result) + assert.Equal(t, integrityReasonChecksumMismatch, report.Reason) +} + +func TestReconcileReadyModelIntegrityMarksMissingLocalPathFailed(t *testing.T) { + nodeName := "test-node" + namespace := constants.OMENamespace + missingPath := filepath.Join(t.TempDir(), "missing") + baseModel := testIntegrityBaseModel("default", "llama", "local:///models/llama", missingPath) + key := constants.GetModelConfigMapKey(baseModel.Namespace, baseModel.Name, false) + entry := ModelEntry{ + Name: baseModel.Name, + Status: ModelStatusReady, + StorageURI: *baseModel.Spec.Storage.StorageUri, + StoragePath: missingPath, + } + cm := testIntegrityConfigMap(t, nodeName, namespace, key, entry) + node := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: nodeName, Labels: map[string]string{}}} + + gopher := newIntegrityTestGopher(t, []*v1beta1.BaseModel{baseModel}, nil, node, cm) + gopher.reconcileReadyModelIntegrity(context.Background(), integrityCheckBasic) + + updatedCM, err := gopher.kubeClient.CoreV1().ConfigMaps(namespace).Get(context.Background(), nodeName, metav1.GetOptions{}) + require.NoError(t, err) + var updatedEntry ModelEntry + require.NoError(t, json.Unmarshal([]byte(updatedCM.Data[key]), &updatedEntry)) + assert.Equal(t, ModelStatusFailed, updatedEntry.Status) + + updatedNode, err := gopher.kubeClient.CoreV1().Nodes().Get(context.Background(), nodeName, metav1.GetOptions{}) + require.NoError(t, err) + assert.Equal(t, string(Failed), updatedNode.Labels[constants.GetBaseModelLabel(baseModel.Namespace, baseModel.Name)]) +} + +func TestReconcileReadyModelIntegritySkipsStaleStorageIdentity(t *testing.T) { + nodeName := "test-node" + namespace := constants.OMENamespace + missingPath := filepath.Join(t.TempDir(), "missing") + currentStorageURI := "local:///models/current" + baseModel := testIntegrityBaseModel("default", "llama", currentStorageURI, missingPath) + key := constants.GetModelConfigMapKey(baseModel.Namespace, baseModel.Name, false) + entry := ModelEntry{ + Name: baseModel.Name, + Status: ModelStatusReady, + StorageURI: "local:///models/old", + StoragePath: "/old/path", + } + cm := testIntegrityConfigMap(t, nodeName, namespace, key, entry) + node := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: nodeName, Labels: map[string]string{}}} + + gopher := newIntegrityTestGopher(t, []*v1beta1.BaseModel{baseModel}, nil, node, cm) + gopher.reconcileReadyModelIntegrity(context.Background(), integrityCheckBasic) + + updatedCM, err := gopher.kubeClient.CoreV1().ConfigMaps(namespace).Get(context.Background(), nodeName, metav1.GetOptions{}) + require.NoError(t, err) + var updatedEntry ModelEntry + require.NoError(t, json.Unmarshal([]byte(updatedCM.Data[key]), &updatedEntry)) + assert.Equal(t, ModelStatusReady, updatedEntry.Status) +} + +func TestReconcileReadyModelIntegrityBackfillsLegacyStorageIdentity(t *testing.T) { + nodeName := "test-node" + namespace := constants.OMENamespace + modelPath := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(modelPath, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + writeTinySafetensors(t, filepath.Join(modelPath, "model.safetensors")) + + storageURI := "local:///models/llama" + baseModel := testIntegrityBaseModel("default", "llama", storageURI, modelPath) + key := constants.GetModelConfigMapKey(baseModel.Namespace, baseModel.Name, false) + entry := ModelEntry{ + Name: baseModel.Name, + Status: ModelStatusReady, + } + cm := testIntegrityConfigMap(t, nodeName, namespace, key, entry) + node := &corev1.Node{ObjectMeta: metav1.ObjectMeta{Name: nodeName, Labels: map[string]string{}}} + + gopher := newIntegrityTestGopher(t, []*v1beta1.BaseModel{baseModel}, nil, node, cm) + gopher.reconcileReadyModelIntegrity(context.Background(), integrityCheckBasic) + + updatedCM, err := gopher.kubeClient.CoreV1().ConfigMaps(namespace).Get(context.Background(), nodeName, metav1.GetOptions{}) + require.NoError(t, err) + var updatedEntry ModelEntry + require.NoError(t, json.Unmarshal([]byte(updatedCM.Data[key]), &updatedEntry)) + assert.Equal(t, ModelStatusReady, updatedEntry.Status) + assert.Equal(t, storageURI, updatedEntry.StorageURI) + assert.Equal(t, modelPath, updatedEntry.StoragePath) +} + +func TestBuildIntegrityModelRefIndexUsesGeneratedConfigMapKeys(t *testing.T) { + longName := "this-is-a-very-long-model-name-that-requires-hashing-to-fit-config-map-key-limits" + baseModel := testIntegrityBaseModel("default", longName, "local:///models/long", "/models/long") + gopher := newIntegrityTestGopher(t, []*v1beta1.BaseModel{baseModel}, nil) + + refs, err := gopher.buildIntegrityModelRefIndex() + require.NoError(t, err) + key := constants.GetModelConfigMapKey(baseModel.Namespace, baseModel.Name, false) + ref, ok := refs[key] + require.True(t, ok) + assert.Equal(t, baseModel.Name, ref.BaseModel.Name) +} + +func TestIntegrityCheckTypeForCycle(t *testing.T) { + gopher := &Gopher{ + integrityConfig: IntegrityConfig{ + CheckInterval: time.Minute, + DeepCheckInterval: time.Hour, + }, + } + var lastDeep time.Time + assert.Equal(t, integrityCheckDeep, gopher.integrityCheckTypeForCycle(&lastDeep)) + assert.False(t, lastDeep.IsZero()) + assert.Equal(t, integrityCheckBasic, gopher.integrityCheckTypeForCycle(&lastDeep)) + lastDeep = time.Now().Add(-2 * time.Hour) + assert.Equal(t, integrityCheckDeep, gopher.integrityCheckTypeForCycle(&lastDeep)) + + gopher.integrityConfig.DeepCheckInterval = 0 + assert.Equal(t, integrityCheckBasic, gopher.integrityCheckTypeForCycle(&lastDeep)) +} + +func TestValidateOCIObjectSummaryLocalCopy(t *testing.T) { + localPath := filepath.Join(t.TempDir(), "model.safetensors") + require.NoError(t, os.WriteFile(localPath, []byte("abc"), 0644)) + + size := int64(3) + result, err := validateOCIObjectSummaryLocalCopy("model.safetensors", &size, localPath) + require.NoError(t, err) + assert.Equal(t, ociobjectstore.LocalCopyValidationValid, result.State) + + size = 4 + result, err = validateOCIObjectSummaryLocalCopy("model.safetensors", &size, localPath) + require.NoError(t, err) + assert.Equal(t, ociobjectstore.LocalCopyValidationInvalid, result.State) + assert.Equal(t, ociobjectstore.LocalCopyValidationReasonSizeMismatch, result.Reason) + + result, err = validateOCIObjectSummaryLocalCopy("missing.safetensors", nil, filepath.Join(t.TempDir(), "missing")) + require.NoError(t, err) + assert.Equal(t, ociobjectstore.LocalCopyValidationInvalid, result.State) + assert.Equal(t, ociobjectstore.LocalCopyValidationReasonMissing, result.Reason) +} + +func TestEnsureArtifactManifestReturnsCreateError(t *testing.T) { + modelPath := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(modelPath, "config.json"), []byte(`{"model_type":"llama"}`), 0644)) + writeTinySafetensors(t, filepath.Join(modelPath, "model.safetensors")) + + modelRootFile := filepath.Join(t.TempDir(), "model-root-file") + require.NoError(t, os.WriteFile(modelRootFile, []byte("not a directory"), 0644)) + + baseModel := testIntegrityBaseModel("default", "llama", "local:///models/llama", modelPath) + gopher := newIntegrityTestGopher(t, []*v1beta1.BaseModel{baseModel}, nil) + gopher.modelRootDir = modelRootFile + + err := gopher.ensureArtifactManifest(context.Background(), (&integrityModelRef{BaseModel: baseModel}).task(), &baseModel.Spec, storage.StorageTypeLocal, modelPath) + require.Error(t, err) +} + +func newIntegrityTestGopher(t *testing.T, baseModels []*v1beta1.BaseModel, clusterBaseModels []*v1beta1.ClusterBaseModel, objects ...runtime.Object) *Gopher { + t.Helper() + logger := zaptest.NewLogger(t).Sugar() + kubeClient := fake.NewSimpleClientset(objects...) + + baseIndexer := cache.NewIndexer(cache.MetaNamespaceKeyFunc, cache.Indexers{}) + for _, baseModel := range baseModels { + require.NoError(t, baseIndexer.Add(baseModel)) + } + clusterIndexer := cache.NewIndexer(cache.MetaNamespaceKeyFunc, cache.Indexers{}) + for _, clusterBaseModel := range clusterBaseModels { + require.NoError(t, clusterIndexer.Add(clusterBaseModel)) + } + + parser := &ModelConfigParser{ + logger: logger, + loadModelConfig: func(_ string) (hfmodelconfig.HuggingFaceModel, error) { + return createDefaultMockModel(), nil + }, + } + return &Gopher{ + modelConfigParser: parser, + configMapReconciler: NewConfigMapReconciler("test-node", constants.OMENamespace, kubeClient, logger), + kubeClient: kubeClient, + nodeLabelReconciler: NewNodeLabelReconciler("test-node", kubeClient, 1, logger), + metrics: NewMetrics(prometheus.NewRegistry()), + logger: logger, + modelRootDir: t.TempDir(), + baseModelLister: omev1beta1lister.NewBaseModelLister(baseIndexer), + clusterBaseModelLister: omev1beta1lister.NewClusterBaseModelLister(clusterIndexer), + nodeShapeAlias: "H100", + integrityConfig: DefaultIntegrityConfig(), + } +} + +func testIntegrityBaseModel(namespace, name, storageURI, storagePath string) *v1beta1.BaseModel { + return &v1beta1.BaseModel{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + Spec: v1beta1.BaseModelSpec{ + Storage: &v1beta1.StorageSpec{ + StorageUri: &storageURI, + Path: &storagePath, + }, + ModelFormat: v1beta1.ModelFormat{Name: "safetensors"}, + }, + } +} + +func testIntegrityConfigMap(t *testing.T, name, namespace, key string, entry ModelEntry) *corev1.ConfigMap { + t.Helper() + entryJSON, err := json.Marshal(entry) + require.NoError(t, err) + return &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Labels: map[string]string{ + constants.ModelStatusConfigMapLabel: "true", + }, + }, + Data: map[string]string{ + key: string(entryJSON), + }, + } +} + +func writeTinySafetensors(t *testing.T, path string) { + t.Helper() + header := []byte(`{"weight":{"dtype":"F32","shape":[1],"data_offsets":[0,4]}}`) + content := make([]byte, 8+len(header)+4) + binary.LittleEndian.PutUint64(content[:8], uint64(len(header))) + copy(content[8:], header) + copy(content[8+len(header):], []byte{1, 2, 3, 4}) + require.NoError(t, os.WriteFile(path, content, 0644)) +} diff --git a/pkg/modelagent/metrics.go b/pkg/modelagent/metrics.go index e3f769b7..28c675ed 100644 --- a/pkg/modelagent/metrics.go +++ b/pkg/modelagent/metrics.go @@ -20,12 +20,15 @@ type Metrics struct { modelVerificationsTotal *prometheus.CounterVec mdChecksumsFailedTotal *prometheus.CounterVec rateLimitCounter *prometheus.CounterVec + integrityChecksTotal *prometheus.CounterVec + integrityBytesScannedTotal *prometheus.CounterVec // Histogram metrics modelDownloadDuration *prometheus.HistogramVec modelVerificationDuration prometheus.Histogram modelDownloadBytesTransferred *prometheus.CounterVec rateLimitWaitDuration *prometheus.HistogramVec + integrityCheckDuration *prometheus.HistogramVec // Go runtime metrics goGoroutines prometheus.Gauge @@ -172,6 +175,20 @@ func NewMetrics(registerer prometheus.Registerer) *Metrics { }, []string{"model_type", "namespace", "name"}, ), + integrityChecksTotal: promauto.With(registerer).NewCounterVec( + prometheus.CounterOpts{ + Name: "model_agent_integrity_checks_total", + Help: "The total number of periodic model artifact integrity checks", + }, + []string{"model_type", "namespace", "name", "storage_type", "check_type", "result", "reason"}, + ), + integrityBytesScannedTotal: promauto.With(registerer).NewCounterVec( + prometheus.CounterOpts{ + Name: "model_agent_integrity_bytes_scanned_total", + Help: "The total bytes scanned by periodic model artifact integrity checks", + }, + []string{"model_type", "namespace", "name", "storage_type", "check_type"}, + ), modelDownloadDuration: promauto.With(registerer).NewHistogramVec( prometheus.HistogramOpts{ Name: "model_agent_download_duration_seconds", @@ -200,6 +217,14 @@ func NewMetrics(registerer prometheus.Registerer) *Metrics { }, []string{"model_type", "namespace", "name"}, ), + integrityCheckDuration: promauto.With(registerer).NewHistogramVec( + prometheus.HistogramOpts{ + Name: "model_agent_integrity_check_duration_seconds", + Help: "The duration of periodic model artifact integrity checks in seconds", + Buckets: prometheus.ExponentialBuckets(0.1, 2, 12), + }, + []string{"model_type", "namespace", "name", "storage_type", "check_type", "result"}, + ), // Store Go runtime metrics goGoroutines: goGoroutines, goThreads: goThreads, @@ -259,6 +284,18 @@ func (m *Metrics) RecordRateLimit(modelType, namespace, name string, waitDuratio m.rateLimitWaitDuration.WithLabelValues(modelType, namespace, name).Observe(waitDuration.Seconds()) } +// RecordIntegrityCheck records a periodic model artifact integrity check. +func (m *Metrics) RecordIntegrityCheck(modelType, namespace, name, storageType, checkType, result, reason string, duration time.Duration, bytesScanned int64) { + if m == nil { + return + } + m.integrityChecksTotal.WithLabelValues(modelType, namespace, name, storageType, checkType, result, reason).Inc() + m.integrityCheckDuration.WithLabelValues(modelType, namespace, name, storageType, checkType, result).Observe(duration.Seconds()) + if bytesScanned > 0 { + m.integrityBytesScannedTotal.WithLabelValues(modelType, namespace, name, storageType, checkType).Add(float64(bytesScanned)) + } +} + // RegisterMetricsHandler registers the metrics HTTP handler func RegisterMetricsHandler(mux *http.ServeMux) { mux.Handle("/metrics", promhttp.Handler()) diff --git a/pkg/modelagent/metrics_test.go b/pkg/modelagent/metrics_test.go index 07c9349c..7ac213b4 100644 --- a/pkg/modelagent/metrics_test.go +++ b/pkg/modelagent/metrics_test.go @@ -29,6 +29,14 @@ func TestNewMetrics_RegistersMetrics(t *testing.T) { if count == 0 { t.Error("modelDownloadDuration did not record observation") } + + metrics.RecordIntegrityCheck("testtype", "testns", "testmodel", "HUGGINGFACE", "deep", "success", "ok", time.Second, 1024) + if got := testutil.ToFloat64(metrics.integrityChecksTotal.WithLabelValues("testtype", "testns", "testmodel", "HUGGINGFACE", "deep", "success", "ok")); got != 1 { + t.Errorf("integrityChecksTotal did not increment, got = %v, want = 1", got) + } + if got := testutil.ToFloat64(metrics.integrityBytesScannedTotal.WithLabelValues("testtype", "testns", "testmodel", "HUGGINGFACE", "deep")); got != 1024 { + t.Errorf("integrityBytesScannedTotal did not record bytes, got = %v, want = 1024", got) + } } func TestGoRuntimeMetrics_AreSet(t *testing.T) { diff --git a/pkg/modelagent/model_data.go b/pkg/modelagent/model_data.go index bf468808..f3f3db24 100644 --- a/pkg/modelagent/model_data.go +++ b/pkg/modelagent/model_data.go @@ -99,10 +99,12 @@ func (p *DownloadProgress) Percentage() float64 { // ModelEntry represents an entry in the node model ConfigMap // This is the top-level structure stored for each model in the ConfigMap type ModelEntry struct { - Name string `json:"name"` // Name of the model - Status ModelStatus `json:"status"` // Current status of the model on this node - Config *ModelConfig `json:"config,omitempty"` // Model configuration, may be nil if just tracking status - Progress *DownloadProgress `json:"progress,omitempty"` // Download progress, nil when not downloading + Name string `json:"name"` // Name of the model + Status ModelStatus `json:"status"` // Current status of the model on this node + StorageURI string `json:"storageUri,omitempty"` // Source URI used for the local artifact + StoragePath string `json:"storagePath,omitempty"` // Local path used for the artifact + Config *ModelConfig `json:"config,omitempty"` // Model configuration, may be nil if just tracking status + Progress *DownloadProgress `json:"progress,omitempty"` // Download progress, nil when not downloading } // ConvertMetadataToModelConfig converts internal ModelMetadata to a client-facing ModelConfig @@ -183,3 +185,64 @@ func ConvertMetadataToModelConfig(metadata ModelMetadata) *ModelConfig { Artifact: artifact, } } + +// StorageIdentityForSpec returns the storage fields that identify the model artifact on disk. +func StorageIdentityForSpec(spec *v1beta1.BaseModelSpec) (storageURI string, storagePath string, ok bool) { + if spec == nil || spec.Storage == nil { + return "", "", false + } + + if spec.Storage.StorageUri != nil { + storageURI = *spec.Storage.StorageUri + } + if spec.Storage.Path != nil { + storagePath = *spec.Storage.Path + } + + if storageURI == "" && storagePath == "" { + return "", "", false + } + return storageURI, storagePath, true +} + +// ApplyStorageIdentity records the current storage identity on a ConfigMap entry. +func (entry *ModelEntry) ApplyStorageIdentity(spec *v1beta1.BaseModelSpec) { + if entry == nil { + return + } + storageURI, storagePath, ok := StorageIdentityForSpec(spec) + if !ok { + entry.StorageURI = "" + entry.StoragePath = "" + return + } + entry.StorageURI = storageURI + entry.StoragePath = storagePath +} + +func (entry *ModelEntry) hasStorageIdentity() bool { + return entry != nil && hasStorageIdentityFields(entry.StorageURI, entry.StoragePath) +} + +// MatchesStorageIdentity reports whether a ConfigMap entry describes the current storage source and path. +func (entry *ModelEntry) MatchesStorageIdentity(spec *v1beta1.BaseModelSpec) bool { + if entry == nil { + return false + } + return matchesStorageIdentityFields(entry.StorageURI, entry.StoragePath, spec) +} + +func hasStorageIdentityFields(storageURI, storagePath string) bool { + return storageURI != "" || storagePath != "" +} + +func matchesStorageIdentityFields(entryStorageURI, entryStoragePath string, spec *v1beta1.BaseModelSpec) bool { + storageURI, storagePath, ok := StorageIdentityForSpec(spec) + if !ok { + return true + } + if !hasStorageIdentityFields(entryStorageURI, entryStoragePath) { + return true + } + return entryStorageURI == storageURI && entryStoragePath == storagePath +} diff --git a/pkg/modelagent/model_data_test.go b/pkg/modelagent/model_data_test.go index 368830a7..0016c29a 100644 --- a/pkg/modelagent/model_data_test.go +++ b/pkg/modelagent/model_data_test.go @@ -224,8 +224,10 @@ func TestConvertMetadataToModelConfig(t *testing.T) { func TestModelEntryMarshaling(t *testing.T) { // Test model entry JSON marshaling and unmarshaling modelEntry := ModelEntry{ - Name: "llama-70b", - Status: ModelStatusReady, + Name: "llama-70b", + Status: ModelStatusReady, + StorageURI: "hf://meta-llama/llama-70b", + StoragePath: "/raid/models/meta-llama/llama-70b", Config: &ModelConfig{ ModelType: "llama", ModelArchitecture: "LlamaModel", @@ -255,11 +257,67 @@ func TestModelEntryMarshaling(t *testing.T) { if unmarshaled.Status != modelEntry.Status { t.Errorf("Status mismatch: got %s, want %s", unmarshaled.Status, modelEntry.Status) } + if unmarshaled.StorageURI != modelEntry.StorageURI { + t.Errorf("StorageURI mismatch: got %s, want %s", unmarshaled.StorageURI, modelEntry.StorageURI) + } + if unmarshaled.StoragePath != modelEntry.StoragePath { + t.Errorf("StoragePath mismatch: got %s, want %s", unmarshaled.StoragePath, modelEntry.StoragePath) + } if !reflect.DeepEqual(unmarshaled.Config, modelEntry.Config) { t.Errorf("ConfigAttr mismatch: got %+v, want %+v", unmarshaled.Config, modelEntry.Config) } } +func TestStorageIdentityForSpec(t *testing.T) { + storageURI := "hf://google/gemma-4-31B-it" + storagePath := "/raid/models/google/gemma-4-31B-it" + spec := &v1beta1.BaseModelSpec{ + Storage: &v1beta1.StorageSpec{ + StorageUri: &storageURI, + Path: &storagePath, + }, + } + + actualURI, actualPath, ok := StorageIdentityForSpec(spec) + if !ok { + t.Fatal("expected storage identity to exist") + } + if actualURI != storageURI { + t.Fatalf("storage URI mismatch: got %s, want %s", actualURI, storageURI) + } + if actualPath != storagePath { + t.Fatalf("storage path mismatch: got %s, want %s", actualPath, storagePath) + } + + entry := &ModelEntry{} + entry.ApplyStorageIdentity(spec) + if entry.StorageURI != storageURI { + t.Fatalf("entry storage URI mismatch: got %s, want %s", entry.StorageURI, storageURI) + } + if entry.StoragePath != storagePath { + t.Fatalf("entry storage path mismatch: got %s, want %s", entry.StoragePath, storagePath) + } + if !entry.MatchesStorageIdentity(spec) { + t.Fatal("expected entry to match its source spec") + } + + updatedPath := "/raid/models/google/gemma-4-31b-it" + updatedSpec := &v1beta1.BaseModelSpec{ + Storage: &v1beta1.StorageSpec{ + StorageUri: &storageURI, + Path: &updatedPath, + }, + } + if entry.MatchesStorageIdentity(updatedSpec) { + t.Fatal("expected entry not to match a spec with a different storage path") + } + + legacyEntry := &ModelEntry{} + if !legacyEntry.MatchesStorageIdentity(spec) { + t.Fatal("expected legacy entry without storage identity to match for upgrade compatibility") + } +} + func TestModelStatusConstants(t *testing.T) { // Verify model status constants tests := []struct { diff --git a/pkg/modelagent/node_label_reconciler.go b/pkg/modelagent/node_label_reconciler.go index 0c0659e5..1676b48c 100644 --- a/pkg/modelagent/node_label_reconciler.go +++ b/pkg/modelagent/node_label_reconciler.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "strings" "time" "go.uber.org/zap" @@ -35,13 +34,6 @@ type NodeLabelReconciler struct { logger *zap.SugaredLogger // Logger for recording operations } -// patchStringValue represents a JSON patch operation for node labels -type patchStringValue struct { - Op string `json:"op,omitempty"` - Path string `json:"path"` - Value string `json:"value,omitempty"` -} - // ModelStateOnNode represents the model state in legacy format // Maintained for backward compatibility with existing codepaths type ModelStateOnNode string @@ -121,24 +113,18 @@ func (n *NodeLabelReconciler) applyNodeLabelOperation(op *NodeLabelOp) error { } // Generate patch payload - payloadBytes, err := getNodeLabelPatchPayloadBytes(op) + payloadBytes, err := getNodeLabelMergePatchPayloadBytes(op) if err != nil { n.logger.Errorf("Failed to get node label patch payload for %s: %v", modelInfo, err) return nil // Don't retry for payload generation issues } n.logger.Debugf("Generated node label patch payload for %s: %s", modelInfo, string(payloadBytes)) - // Skip empty patch operations - if len(payloadBytes) <= 2 { // Just "[]" for empty patch - n.logger.Infof("Empty patch payload for %s, skipping operation", modelInfo) - return nil - } - // Apply the patch _, err = n.kubeClient.CoreV1().Nodes().Patch( context.TODO(), n.nodeName, - types.JSONPatchType, + types.MergePatchType, payloadBytes, metav1.PatchOptions{}, ) @@ -153,14 +139,6 @@ func (n *NodeLabelReconciler) applyNodeLabelOperation(op *NodeLabelOp) error { n.logger.Warnf("Conflict during patch operation for node %s and model %s, will retry: %v", n.nodeName, modelInfo, err) return err // Return error to trigger retry } else if errors.IsInvalid(err) || errors.IsBadRequest(err) { - // For delete operations that fail with "not found" patch path errors, consider it already done - if op.ModelStateOnNode == Deleted && strings.Contains(err.Error(), "not found") { - n.logger.Infof("Label %s already removed from node %s for %s - considering delete operation successful", - labelKey, n.nodeName, modelInfo) - return nil - } - - // Other invalid request, could be malformed patch, log but don't retry n.logger.Warnf("Invalid patch request for node %s and model %s: %v", n.nodeName, modelInfo, err) return nil // Don't retry for bad requests } @@ -205,45 +183,26 @@ func getModelLabelKey(op *NodeLabelOp) (string, error) { return labelKey, nil } -// getNodeLabelPatchPayloadBytes generates the JSON patch for node labels -func getNodeLabelPatchPayloadBytes(op *NodeLabelOp) ([]byte, error) { +func getNodeLabelMergePatchPayloadBytes(op *NodeLabelOp) ([]byte, error) { labelKey, err := getModelLabelKey(op) if err != nil { return []byte{}, err } - var payload []patchStringValue + labels := map[string]interface{}{} switch op.ModelStateOnNode { - case Ready: - payload = []patchStringValue{{ - Op: "add", - Path: fmt.Sprintf("/metadata/labels/%s", strings.ReplaceAll(labelKey, "/", "~1")), - Value: string(Ready), - }} - case Updating: - payload = []patchStringValue{{ - Op: "add", - Path: fmt.Sprintf("/metadata/labels/%s", strings.ReplaceAll(labelKey, "/", "~1")), - Value: string(Updating), - }} - case Failed: - payload = []patchStringValue{{ - Op: "add", - Path: fmt.Sprintf("/metadata/labels/%s", strings.ReplaceAll(labelKey, "/", "~1")), - Value: string(Failed), - }} + case Ready, Updating, Failed: + labels[labelKey] = string(op.ModelStateOnNode) case Deleted: - payload = []patchStringValue{{ - Op: "remove", - Path: fmt.Sprintf("/metadata/labels/%s", strings.ReplaceAll(labelKey, "/", "~1")), - }} + labels[labelKey] = nil default: break } - payloadBytes, err := json.Marshal(payload) - if err != nil { - return nil, err + payload := map[string]interface{}{ + "metadata": map[string]interface{}{ + "labels": labels, + }, } - return payloadBytes, nil + return json.Marshal(payload) } diff --git a/pkg/modelagent/node_label_reconciler_test.go b/pkg/modelagent/node_label_reconciler_test.go index e143258e..1764bd08 100644 --- a/pkg/modelagent/node_label_reconciler_test.go +++ b/pkg/modelagent/node_label_reconciler_test.go @@ -4,16 +4,16 @@ import ( "context" "encoding/json" "errors" - "fmt" - "strings" "testing" "github.com/stretchr/testify/assert" "go.uber.org/zap" "go.uber.org/zap/zaptest" corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes/fake" ktesting "k8s.io/client-go/testing" @@ -159,89 +159,72 @@ func TestGetNodeLabelModelInfo(t *testing.T) { assert.Equal(t, "unknown model", info) } -// TestGetNodeLabelPatchPayloadBytes tests the getNodeLabelPatchPayloadBytes function -func TestGetNodeLabelPatchPayloadBytes(t *testing.T) { - // Test with BaseModel and Ready state +func TestGetNodeLabelMergePatchPayloadBytes(t *testing.T) { baseModel := createTestBaseModel() - op := &NodeLabelOp{ - BaseModel: baseModel, - ClusterBaseModel: nil, - ModelStateOnNode: Ready, - } - - payload, err := getNodeLabelPatchPayloadBytes(op) - assert.NoError(t, err) - - // Verify JSON patch structure - var patches []patchStringValue - err = json.Unmarshal(payload, &patches) - assert.NoError(t, err) - assert.Len(t, patches, 1) - - labelKey := constants.GetBaseModelLabel(baseModel.Namespace, baseModel.Name) - expectedPath := fmt.Sprintf("/metadata/labels/%s", strings.ReplaceAll(labelKey, "/", "~1")) - - assert.Equal(t, "add", patches[0].Op) - assert.Equal(t, expectedPath, patches[0].Path) - assert.Equal(t, "Ready", patches[0].Value) - - // Test with ClusterBaseModel and Updating state clusterBaseModel := createTestClusterBaseModel() - op = &NodeLabelOp{ - BaseModel: nil, - ClusterBaseModel: clusterBaseModel, - ModelStateOnNode: Updating, - } - - payload, err = getNodeLabelPatchPayloadBytes(op) - assert.NoError(t, err) - - err = json.Unmarshal(payload, &patches) - assert.NoError(t, err) - assert.Len(t, patches, 1) - - labelKey = constants.GetClusterBaseModelLabel(clusterBaseModel.Name) - expectedPath = fmt.Sprintf("/metadata/labels/%s", strings.ReplaceAll(labelKey, "/", "~1")) - - assert.Equal(t, "add", patches[0].Op) - assert.Equal(t, expectedPath, patches[0].Path) - assert.Equal(t, "Updating", patches[0].Value) - - // Test with Failed state - op.ModelStateOnNode = Failed - payload, err = getNodeLabelPatchPayloadBytes(op) - assert.NoError(t, err) - - err = json.Unmarshal(payload, &patches) - assert.NoError(t, err) - assert.Len(t, patches, 1) - assert.Equal(t, "add", patches[0].Op) - // The Failed enum is converted to a string, so we need to compare with "Failed" - assert.Equal(t, "Failed", patches[0].Value) - - // Test with Deleted state (should be "remove" operation) - op.ModelStateOnNode = Deleted - payload, err = getNodeLabelPatchPayloadBytes(op) - assert.NoError(t, err) - // For a remove operation, let's verify the raw JSON doesn't contain a value field - var jsonMap []map[string]interface{} - err = json.Unmarshal(payload, &jsonMap) - assert.NoError(t, err) - assert.Len(t, jsonMap, 1) - assert.Equal(t, "remove", jsonMap[0]["op"]) - // In a remove operation, the value field should not exist in the JSON at all - _, valueExists := jsonMap[0]["value"] - assert.False(t, valueExists, "value field should not exist in remove operation") + tests := []struct { + name string + op *NodeLabelOp + labelKey string + expectedValue interface{} + }{ + { + name: "BaseModel Ready", + op: &NodeLabelOp{ + BaseModel: baseModel, + ClusterBaseModel: nil, + ModelStateOnNode: Ready, + }, + labelKey: constants.GetBaseModelLabel(baseModel.Namespace, baseModel.Name), + expectedValue: string(Ready), + }, + { + name: "BaseModel Failed", + op: &NodeLabelOp{ + BaseModel: baseModel, + ClusterBaseModel: nil, + ModelStateOnNode: Failed, + }, + labelKey: constants.GetBaseModelLabel(baseModel.Namespace, baseModel.Name), + expectedValue: string(Failed), + }, + { + name: "ClusterBaseModel Updating", + op: &NodeLabelOp{ + BaseModel: nil, + ClusterBaseModel: clusterBaseModel, + ModelStateOnNode: Updating, + }, + labelKey: constants.GetClusterBaseModelLabel(clusterBaseModel.Name), + expectedValue: string(Updating), + }, + { + name: "ClusterBaseModel Deleted", + op: &NodeLabelOp{ + BaseModel: nil, + ClusterBaseModel: clusterBaseModel, + ModelStateOnNode: Deleted, + }, + labelKey: constants.GetClusterBaseModelLabel(clusterBaseModel.Name), + expectedValue: nil, + }, + } - // Test with no model (should return error) - op = &NodeLabelOp{ - BaseModel: nil, - ClusterBaseModel: nil, - ModelStateOnNode: Ready, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, err := getNodeLabelMergePatchPayloadBytes(tt.op) + assert.NoError(t, err) + + var patch map[string]map[string]map[string]interface{} + err = json.Unmarshal(payload, &patch) + assert.NoError(t, err) + assert.Equal(t, tt.expectedValue, patch["metadata"]["labels"][tt.labelKey]) + assert.Len(t, patch["metadata"]["labels"], 1) + }) } - _, err = getNodeLabelPatchPayloadBytes(op) + _, err := getNodeLabelMergePatchPayloadBytes(&NodeLabelOp{ModelStateOnNode: Ready}) assert.Error(t, err) assert.Contains(t, err.Error(), "empty op without any models") } @@ -256,7 +239,7 @@ func TestApplyNodeLabelOperation(t *testing.T) { // Let the default reactor handle the action, but capture the patch for verification patchAction := action.(ktesting.PatchAction) assert.Equal(t, "test-node", patchAction.GetName()) - assert.Equal(t, types.JSONPatchType, patchAction.GetPatchType()) + assert.Equal(t, types.MergePatchType, patchAction.GetPatchType()) // Return default reactor response return false, nil, nil @@ -285,6 +268,27 @@ func TestApplyNodeLabelOperation(t *testing.T) { assert.Contains(t, err.Error(), "test error") } +func TestApplyNodeLabelOperationWithMissingLabels(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + kubeClient := fake.NewSimpleClientset(&corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-without-labels", + }, + }) + reconciler := NewNodeLabelReconciler("node-without-labels", kubeClient, 1, logger) + baseModel := createTestBaseModel() + + err := reconciler.applyNodeLabelOperation(&NodeLabelOp{ + BaseModel: baseModel, + ModelStateOnNode: Failed, + }) + assert.NoError(t, err) + + node, err := kubeClient.CoreV1().Nodes().Get(context.TODO(), "node-without-labels", metav1.GetOptions{}) + assert.NoError(t, err) + assert.Equal(t, string(Failed), node.Labels[constants.GetBaseModelLabel(baseModel.Namespace, baseModel.Name)]) +} + // TestReconcileNodeLabels tests the ReconcileNodeLabels method func TestReconcileNodeLabels(t *testing.T) { // Test successful patching @@ -400,16 +404,15 @@ func TestIdempotentOperations(t *testing.T) { err = reconciler.applyNodeLabelOperation(op) assert.NoError(t, err) - // Test case 3: Empty patch payloads - // Create a mock NodeLabelOp that will result in an empty patch - emptyPatchOp := &NodeLabelOp{ + // Test case 3: invalid model refs are ignored before patching + invalidModelOp := &NodeLabelOp{ BaseModel: nil, - ClusterBaseModel: nil, // This will cause an empty patch + ClusterBaseModel: nil, ModelStateOnNode: Ready, } reconciler.nodeName = "test-node" - err = reconciler.applyNodeLabelOperation(emptyPatchOp) - assert.NoError(t, err) // Should gracefully handle empty patch + err = reconciler.applyNodeLabelOperation(invalidModelOp) + assert.NoError(t, err) } // TestNodeLabelErrorHandling tests error handling in applyNodeLabelOperation @@ -419,20 +422,17 @@ func TestNodeLabelErrorHandling(t *testing.T) { // Prepare a model for tests clusterBaseModel := createTestClusterBaseModel() - // Setup a delete operation for a label that will fail with "not found" error + // Setup a delete operation for a label that will fail with a non-retryable request error op := &NodeLabelOp{ ClusterBaseModel: clusterBaseModel, ModelStateOnNode: Deleted, } - // Mock the patch to fail with a "not found" error + // Mock the patch to fail with a bad request error kubeClient.PrependReactor("patch", "nodes", func(action ktesting.Action) (bool, runtime.Object, error) { patchAction := action.(ktesting.PatchAction) if patchAction.GetName() == "test-node" { - // Return a "not found" error for delete operations - if string(patchAction.GetPatch()) != "" && strings.Contains(string(patchAction.GetPatch()), "remove") { - return true, nil, errors.New("the server rejected our request: path not found") - } + return true, nil, apierrors.NewBadRequest("patch rejected") } return false, nil, nil }) @@ -445,7 +445,7 @@ func TestNodeLabelErrorHandling(t *testing.T) { kubeClient.PrependReactor("patch", "nodes", func(action ktesting.Action) (bool, runtime.Object, error) { patchAction := action.(ktesting.PatchAction) if patchAction.GetName() == "test-node" { - return true, nil, errors.New("Operation cannot be fulfilled on nodes \"test-node\": the object has been modified") + return true, nil, apierrors.NewConflict(schema.GroupResource{Resource: "nodes"}, "test-node", errors.New("object modified")) } return false, nil, nil }) diff --git a/pkg/modelagent/scout.go b/pkg/modelagent/scout.go index 85b1ec5b..178bb1e0 100644 --- a/pkg/modelagent/scout.go +++ b/pkg/modelagent/scout.go @@ -132,6 +132,10 @@ func NewScout(ctx context.Context, nodeName string, return scout, nil } +func (w *Scout) NodeShapeAlias() string { + return w.nodeShapeAlias +} + func (w *Scout) Run(stopCh <-chan struct{}) error { defer runtime.HandleCrash() diff --git a/pkg/ociobjectstore/os_data_store.go b/pkg/ociobjectstore/os_data_store.go index 6ff96646..07168c58 100644 --- a/pkg/ociobjectstore/os_data_store.go +++ b/pkg/ociobjectstore/os_data_store.go @@ -27,6 +27,30 @@ import ( * OCIOSDataStore used to perform data store operations with Object Storage */ +type LocalCopyValidationState string + +const ( + LocalCopyValidationValid LocalCopyValidationState = "Valid" + LocalCopyValidationInvalid LocalCopyValidationState = "Invalid" + LocalCopyValidationInconclusive LocalCopyValidationState = "Inconclusive" +) + +type LocalCopyValidationReason string + +const ( + LocalCopyValidationReasonOK LocalCopyValidationReason = "OK" + LocalCopyValidationReasonMissing LocalCopyValidationReason = "Missing" + LocalCopyValidationReasonSizeMismatch LocalCopyValidationReason = "SizeMismatch" + LocalCopyValidationReasonChecksumMismatch LocalCopyValidationReason = "ChecksumMismatch" + LocalCopyValidationReasonChecksumMissing LocalCopyValidationReason = "ChecksumMissing" +) + +type LocalCopyValidationResult struct { + State LocalCopyValidationState + Reason LocalCopyValidationReason + Message string +} + // OCIOSDataStore performs data store operations against Oracle Object Storage. // It provides file upload, download, listing, and validation methods. type OCIOSDataStore struct { @@ -449,19 +473,35 @@ func (cds *OCIOSDataStore) ListObjects(target ObjectURI) ([]objectstorage.Object // IsLocalCopyValid checks whether a local file matches the expected object in size and MD5 checksum. // If the object was uploaded via multipart and lacks a standard MD5, it attempts to verify via custom metadata. // -// Returns true if the local file is a valid, verified copy of the object. +// Returns true for valid and checksum-inconclusive local copies to preserve the +// existing download skip behavior. Use ValidateLocalCopy when callers need to +// distinguish valid from inconclusive. func (cds *OCIOSDataStore) IsLocalCopyValid(source ObjectURI, localFilePath string) (bool, error) { + result, err := cds.ValidateLocalCopy(source, localFilePath) + if err != nil { + return false, err + } + return result.State != LocalCopyValidationInvalid, nil +} + +// ValidateLocalCopy checks whether a local file matches the expected object and reports +// whether the result is valid, invalid, or inconclusive when the remote checksum is unavailable. +func (cds *OCIOSDataStore) ValidateLocalCopy(source ObjectURI, localFilePath string) (LocalCopyValidationResult, error) { fileInfo, err := os.Stat(localFilePath) if err != nil { if os.IsNotExist(err) { - return false, nil + return LocalCopyValidationResult{ + State: LocalCopyValidationInvalid, + Reason: LocalCopyValidationReasonMissing, + Message: "local file does not exist", + }, nil } - return false, err + return LocalCopyValidationResult{}, err } headResponse, err := cds.HeadObject(source) if err != nil { - return false, fmt.Errorf("failed to get object metadata: %w", err) + return LocalCopyValidationResult{}, fmt.Errorf("failed to get object metadata: %w", err) } objectMd5 := headResponse.ContentMd5 objectLength := headResponse.ContentLength @@ -470,7 +510,11 @@ func (cds *OCIOSDataStore) IsLocalCopyValid(source ObjectURI, localFilePath stri cds.logger.Warnf("File size mismatch for %s: expected %d, got %d", localFilePath, *objectLength, fileInfo.Size()) // File size mismatch should always return false - return false, nil + return LocalCopyValidationResult{ + State: LocalCopyValidationInvalid, + Reason: LocalCopyValidationReasonSizeMismatch, + Message: fmt.Sprintf("file size mismatch: expected %d got %d", *objectLength, fileInfo.Size()), + }, nil } if objectMd5 == nil && headResponse.OpcMultipartMd5 != nil && isMultipartMd5(*headResponse.OpcMultipartMd5) { @@ -481,13 +525,24 @@ func (cds *OCIOSDataStore) IsLocalCopyValid(source ObjectURI, localFilePath stri objectMd5 = &realMd5 } else { cds.logger.Warnf("No MD5 in metadata for multipart object %s; skipping integrity check", source.ObjectName) - return true, nil + return LocalCopyValidationResult{ + State: LocalCopyValidationInconclusive, + Reason: LocalCopyValidationReasonChecksumMissing, + Message: "multipart object does not expose a comparable MD5 checksum", + }, nil } } + if objectMd5 == nil { + return LocalCopyValidationResult{ + State: LocalCopyValidationInconclusive, + Reason: LocalCopyValidationReasonChecksumMissing, + Message: "object does not expose a comparable MD5 checksum", + }, nil + } file, err := os.Open(localFilePath) if err != nil { - return false, err + return LocalCopyValidationResult{}, err } defer func(file *os.File) { err := file.Close() @@ -498,18 +553,25 @@ func (cds *OCIOSDataStore) IsLocalCopyValid(source ObjectURI, localFilePath stri hash := md5.New() if _, err := io.Copy(hash, file); err != nil { - return false, err + return LocalCopyValidationResult{}, err } localMd5 := base64.StdEncoding.EncodeToString(hash.Sum(nil)) if *objectMd5 == localMd5 { - return true, nil + return LocalCopyValidationResult{ + State: LocalCopyValidationValid, + Reason: LocalCopyValidationReasonOK, + }, nil } cds.logger.Warnf("MD5 mismatch for %s: expected %s, got %s", localFilePath, *objectMd5, localMd5) - return false, nil + return LocalCopyValidationResult{ + State: LocalCopyValidationInvalid, + Reason: LocalCopyValidationReasonChecksumMismatch, + Message: fmt.Sprintf("MD5 mismatch: expected %s got %s", *objectMd5, localMd5), + }, nil } // isMultipartMd5 determines if the given object MD5 string represents a multipart upload checksum. diff --git a/pkg/ociobjectstore/os_data_store_test.go b/pkg/ociobjectstore/os_data_store_test.go index ba366a75..e419977e 100644 --- a/pkg/ociobjectstore/os_data_store_test.go +++ b/pkg/ociobjectstore/os_data_store_test.go @@ -1,20 +1,57 @@ package ociobjectstore import ( + "crypto/md5" + "encoding/base64" "fmt" + "io" + "net/http" "os" "path/filepath" "strings" "testing" "time" + "github.com/oracle/oci-go-sdk/v65/common" + "github.com/oracle/oci-go-sdk/v65/objectstorage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/sgl-project/ome/pkg/logging" "github.com/sgl-project/ome/pkg/principals" testingPkg "github.com/sgl-project/ome/pkg/testing" ) +type noopRequestSigner struct{} + +func (noopRequestSigner) Sign(_ *http.Request) error { + return nil +} + +type fakeHeadObjectDispatcher struct { + contentLength int64 + contentMD5 string +} + +func (f fakeHeadObjectDispatcher) Do(req *http.Request) (*http.Response, error) { + header := http.Header{} + header.Set("Content-Length", fmt.Sprintf("%d", f.contentLength)) + if f.contentMD5 != "" { + header.Set("Content-MD5", f.contentMD5) + } + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Header: header, + Body: io.NopCloser(strings.NewReader("")), + ContentLength: f.contentLength, + Request: req, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + }, nil +} + func TestNewOCIOSDataStore(t *testing.T) { t.Run("Nil config", func(t *testing.T) { cds, err := NewOCIOSDataStore(nil) @@ -46,6 +83,86 @@ func TestNewOCIOSDataStore(t *testing.T) { }) } +func TestValidateLocalCopy(t *testing.T) { + localPath := filepath.Join(t.TempDir(), "model.safetensors") + content := []byte("model-data") + require.NoError(t, os.WriteFile(localPath, content, 0644)) + + t.Run("Valid", func(t *testing.T) { + sum := md5.Sum(content) + store := newValidationTestStore(t, int64(len(content)), base64.StdEncoding.EncodeToString(sum[:])) + + result, err := store.ValidateLocalCopy(testObjectURI(), localPath) + require.NoError(t, err) + assert.Equal(t, LocalCopyValidationValid, result.State) + assert.Equal(t, LocalCopyValidationReasonOK, result.Reason) + }) + + t.Run("Missing", func(t *testing.T) { + store := newValidationTestStore(t, int64(len(content)), "") + + result, err := store.ValidateLocalCopy(testObjectURI(), filepath.Join(t.TempDir(), "missing")) + require.NoError(t, err) + assert.Equal(t, LocalCopyValidationInvalid, result.State) + assert.Equal(t, LocalCopyValidationReasonMissing, result.Reason) + }) + + t.Run("SizeMismatch", func(t *testing.T) { + store := newValidationTestStore(t, int64(len(content)+1), "") + + result, err := store.ValidateLocalCopy(testObjectURI(), localPath) + require.NoError(t, err) + assert.Equal(t, LocalCopyValidationInvalid, result.State) + assert.Equal(t, LocalCopyValidationReasonSizeMismatch, result.Reason) + }) + + t.Run("ChecksumMismatch", func(t *testing.T) { + sum := md5.Sum([]byte("different")) + store := newValidationTestStore(t, int64(len(content)), base64.StdEncoding.EncodeToString(sum[:])) + + result, err := store.ValidateLocalCopy(testObjectURI(), localPath) + require.NoError(t, err) + assert.Equal(t, LocalCopyValidationInvalid, result.State) + assert.Equal(t, LocalCopyValidationReasonChecksumMismatch, result.Reason) + }) + + t.Run("ChecksumMissing", func(t *testing.T) { + store := newValidationTestStore(t, int64(len(content)), "") + + result, err := store.ValidateLocalCopy(testObjectURI(), localPath) + require.NoError(t, err) + assert.Equal(t, LocalCopyValidationInconclusive, result.State) + assert.Equal(t, LocalCopyValidationReasonChecksumMissing, result.Reason) + }) +} + +func newValidationTestStore(t *testing.T, contentLength int64, contentMD5 string) *OCIOSDataStore { + t.Helper() + client := objectstorage.ObjectStorageClient{ + BaseClient: common.BaseClient{ + HTTPClient: fakeHeadObjectDispatcher{ + contentLength: contentLength, + contentMD5: contentMD5, + }, + Signer: noopRequestSigner{}, + Host: "https://objectstorage.test", + UserAgent: "ome-test", + }, + } + return &OCIOSDataStore{ + logger: logging.Discard(), + Client: &client, + } +} + +func testObjectURI() ObjectURI { + return ObjectURI{ + Namespace: "namespace", + BucketName: "bucket", + ObjectName: "model.safetensors", + } +} + func TestIsMultipartMd5(t *testing.T) { tests := []struct { name string diff --git a/site/content/en/docs/administration/model-agent.md b/site/content/en/docs/administration/model-agent.md index 17fda06a..4f092d1a 100644 --- a/site/content/en/docs/administration/model-agent.md +++ b/site/content/en/docs/administration/model-agent.md @@ -107,6 +107,14 @@ For models using SafeTensors format: - **Node Labeling**: Apply labels to nodes indicating model availability - **Metric Emission**: Update Prometheus metrics for monitoring +### 7. Post-Ready Integrity Reconciliation + +The Model Agent periodically rechecks models that are already marked `Ready` on the local node. Each pass reads the node-scoped model status ConfigMap, validates only current `Ready` entries, and marks the node `Failed` for that model when required artifacts are missing or corrupted. + +Basic checks validate the model path, required config files, model weight presence, and persisted manifest file sizes. Deep checks run less frequently and compare SHA256 manifest hashes; OCI-backed models also reuse Object Storage metadata validation, including MD5 checks where available. + +For legacy Ready models without an existing manifest, the agent first performs conservative path/config/weight validation and then writes the baseline manifest during a deep check. Same-size corruption that already exists before that first baseline cannot be detected by the manifest. + ## Configuration Reference ### Command Line Arguments @@ -140,6 +148,9 @@ The Model Agent supports extensive configuration through command-line arguments: | `--node-name` | `$NODE_NAME` | Name of the current node (usually from environment) | | `--namespace` | `ome` | Kubernetes namespace for ConfigMaps and status tracking | | `--node-label-retry` | 5 | Number of retries for updating node labels | +| `--integrity-check-interval` | `10m` | Interval for periodic Ready model artifact checks; set `<=0` to disable | +| `--integrity-deep-check-interval` | `6h` | Interval for checksum validation; set `<=0` to disable deep scans | +| `--integrity-startup-jitter` | `30s` | Maximum deterministic startup jitter before the first integrity check | #### Logging and Monitoring @@ -169,6 +180,9 @@ The Model Agent also supports configuration through environment variables: | `OCI_CONFIG_FILE` | Path to OCI configuration file | | `HUGGINGFACE_TOKEN` | Default Hugging Face access token | | `INSTANCE_TYPE_MAP` | JSON mapping of cloud instance types to GPU short names (e.g., `{"BM.GPU.H100.8": "H100"}`) | +| `INTEGRITY_CHECK_INTERVAL` | Periodic Ready model artifact check interval | +| `INTEGRITY_DEEP_CHECK_INTERVAL` | Deep checksum validation interval | +| `INTEGRITY_STARTUP_JITTER` | Maximum startup jitter before the first integrity check | ## Advanced Download Features @@ -402,6 +416,12 @@ model_agent_verification_duration_seconds 12.34 # MD5 checksum failures model_agent_md5_checksum_failed_total{model_type="llama", namespace="default", name="llama-70b"} 0 + +# Periodic artifact integrity checks +model_agent_integrity_checks_total{model_type="basemodel", namespace="default", name="llama-70b", storage_type="HUGGINGFACE", check_type="basic", result="success", reason="ok"} 1 + +# Deep integrity bytes scanned +model_agent_integrity_bytes_scanned_total{model_type="basemodel", namespace="default", name="llama-70b", storage_type="HUGGINGFACE", check_type="deep"} 140737488355328 ``` #### Runtime Metrics