Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions charts/ome-resources/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ OME Resources and Controller
| global.hub | string | `"ghcr.io/moirai-internal"` | |
| modelAgent.health.port | int | `8080` | |
| modelAgent.hostPath | string | `"/mnt/data/models"` | |
| modelAgent.integrity.checkInterval | string | `"10m"` | |
| modelAgent.integrity.deepCheckInterval | string | `"6h"` | |
| modelAgent.integrity.startupJitter | string | `"30s"` | |
| modelAgent.image.pullPolicy | string | `"Always"` | |
| modelAgent.image.repository | string | `"model-agent"` | |
| modelAgent.image.tag | string | `"v0.1.2"` | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ spec:
- {{ .Values.modelAgent.hostPath }}
- --num-download-worker
- '2'
- --integrity-check-interval
- {{ .Values.modelAgent.integrity.checkInterval | quote }}
- --integrity-deep-check-interval
- {{ .Values.modelAgent.integrity.deepCheckInterval | quote }}
- --integrity-startup-jitter
- {{ .Values.modelAgent.integrity.startupJitter | quote }}
env:
- name: NODE_NAME
valueFrom:
Expand Down
5 changes: 5 additions & 0 deletions charts/ome-resources/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ modelAgent:
health:
port: 8080

# Model artifact integrity checking. Each value is a duration string that the
# DaemonSet template passes, quoted, to the matching model-agent flag.
integrity:
# Interval for periodic integrity checks of Ready model artifacts
# (--integrity-check-interval); set <=0 to disable.
checkInterval: 10m
# Interval for deep checksum validation (--integrity-deep-check-interval);
# set <=0 to disable checksum scans.
deepCheckInterval: 6h
# Maximum deterministic startup jitter before the first integrity check
# (--integrity-startup-jitter).
startupJitter: 30s

# Additional environment variables for the model-agent container
# Examples:
# env:
Expand Down
39 changes: 27 additions & 12 deletions cmd/model-agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,21 @@ import (

// config holds all configuration parameters for the model agent.
// Its fields are the targets of the command-line flags registered on rootCmd;
// the integrity* fields back --integrity-check-interval,
// --integrity-deep-check-interval and --integrity-startup-jitter.
type config struct {
// NOTE(review): this listing looks like a rendered diff with the +/- markers
// stripped: the first run of fields (12 lines) appears to be the pre-change
// side and the second run (15 lines) the re-aligned post-change side. Only
// one copy can exist in the real struct; confirm against main.go.
port int
modelsRootDir string
modelsRootDirOnHost string
nodeName string
nodeLabelRetry int
concurrency int
multipartConcurrency int
downloadRetry int
downloadAuthType string
numDownloadWorker int
namespace string
logLevel string
// Second run of fields (NOTE(review): presumably the post-change side, see
// the note at the top of the struct).
port int
modelsRootDir string
modelsRootDirOnHost string
nodeName string
nodeLabelRetry int
concurrency int
multipartConcurrency int
downloadRetry int
downloadAuthType string
// numDownloadWorker is bound to --num-download-worker: number of download workers.
numDownloadWorker int
// namespace is bound to --namespace: the Kubernetes namespace to use.
namespace string
// logLevel is bound to --log-level: one of debug, info, warn, error.
logLevel string
// integrityCheckInterval is bound to --integrity-check-interval: the interval
// for periodic integrity checks of Ready model artifacts; <=0 disables them.
integrityCheckInterval time.Duration
// integrityDeepInterval is bound to --integrity-deep-check-interval: the
// interval for deep checksum validation; <=0 disables checksum scans.
integrityDeepInterval time.Duration
// integrityStartupJitter is bound to --integrity-startup-jitter: the maximum
// deterministic startup jitter before the first integrity check.
integrityStartupJitter time.Duration
}

// Logger type alias for zap.SugaredLogger
Expand Down Expand Up @@ -73,6 +76,10 @@ func init() {
rootCmd.PersistentFlags().IntVar(&cfg.numDownloadWorker, "num-download-worker", 5, "Number of download workers")
rootCmd.PersistentFlags().StringVar(&cfg.namespace, "namespace", "ome", "Kubernetes namespace to use")
rootCmd.PersistentFlags().StringVar(&cfg.logLevel, "log-level", "info", "Log level (debug, info, warn, error)")
defaultIntegrityConfig := modelagent.DefaultIntegrityConfig()
rootCmd.PersistentFlags().DurationVar(&cfg.integrityCheckInterval, "integrity-check-interval", defaultIntegrityConfig.CheckInterval, "Interval for periodic Ready model artifact integrity checks; set <=0 to disable")
rootCmd.PersistentFlags().DurationVar(&cfg.integrityDeepInterval, "integrity-deep-check-interval", defaultIntegrityConfig.DeepCheckInterval, "Interval for deep checksum validation; set <=0 to disable checksum scans")
rootCmd.PersistentFlags().DurationVar(&cfg.integrityStartupJitter, "integrity-startup-jitter", defaultIntegrityConfig.StartupJitter, "Maximum deterministic startup jitter before the first integrity check")

_ = v.BindPFlags(rootCmd.PersistentFlags())
v.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
Expand Down Expand Up @@ -268,6 +275,14 @@ func initializeComponents(
logger,
baseModelInformer.Lister(),
clusterBaseModelInformer.Lister(),
baseModelInformer.Informer().HasSynced,
clusterBaseModelInformer.Informer().HasSynced,
scout.NodeShapeAlias(),
modelagent.IntegrityConfig{
CheckInterval: v.GetDuration("integrity-check-interval"),
DeepCheckInterval: v.GetDuration("integrity-deep-check-interval"),
StartupJitter: v.GetDuration("integrity-startup-jitter"),
},
)
if err != nil {
return nil, nil, fmt.Errorf("failed to create gopher: %w", err)
Expand Down
9 changes: 9 additions & 0 deletions cmd/model-agent/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
"k8s.io/client-go/rest"

"github.com/sgl-project/ome/pkg/modelagent"
)

func setupTestEnv(t *testing.T) {
Expand Down Expand Up @@ -113,6 +115,10 @@ func TestDefaultConfig(t *testing.T) {
testCmd.Flags().StringVar(&cfg.downloadAuthType, "download-auth-type", "instance-principal", "authentication method for model download")
testCmd.Flags().IntVar(&cfg.numDownloadWorker, "num-download-worker", 3, "number of download workers")
testCmd.Flags().StringVar(&cfg.namespace, "namespace", "ome", "the namespace of the ome model agents daemon set")
defaultIntegrityConfig := modelagent.DefaultIntegrityConfig()
testCmd.Flags().DurationVar(&cfg.integrityCheckInterval, "integrity-check-interval", defaultIntegrityConfig.CheckInterval, "Model artifact integrity check interval")
testCmd.Flags().DurationVar(&cfg.integrityDeepInterval, "integrity-deep-check-interval", defaultIntegrityConfig.DeepCheckInterval, "Model artifact deep integrity check interval")
testCmd.Flags().DurationVar(&cfg.integrityStartupJitter, "integrity-startup-jitter", defaultIntegrityConfig.StartupJitter, "Model artifact integrity startup jitter")

// Call initConfig to set cfg.nodeName
initConfig(nil, nil)
Expand All @@ -127,6 +133,9 @@ func TestDefaultConfig(t *testing.T) {
assert.Equal(t, "instance-principal", cfg.downloadAuthType)
assert.Equal(t, 3, cfg.numDownloadWorker)
assert.Equal(t, "ome", cfg.namespace)
assert.Equal(t, defaultIntegrityConfig.CheckInterval, cfg.integrityCheckInterval)
assert.Equal(t, defaultIntegrityConfig.DeepCheckInterval, cfg.integrityDeepInterval)
assert.Equal(t, defaultIntegrityConfig.StartupJitter, cfg.integrityStartupJitter)
}

func TestInitializeLogger(t *testing.T) {
Expand Down
6 changes: 6 additions & 0 deletions config/model-agent/daemonset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ spec:
- '2'
- --concurrency
- '2'
- --integrity-check-interval
- 10m
- --integrity-deep-check-interval
- 6h
- --integrity-startup-jitter
- 30s
env:
- name: NODE_NAME
valueFrom:
Expand Down
14 changes: 14 additions & 0 deletions pkg/controller/v1beta1/basemodel/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ func handleModelDeletion(ctx context.Context, kubeClient client.Client, obj clie
// updateModelStatus updates BaseModel status based on ConfigMap data
func (r *BaseModelReconciler) updateModelStatus(ctx context.Context, baseModel *v1beta1.BaseModel) error {
return processModelStatus(ctx, r.Client, r.Log, baseModel.Namespace, baseModel.Name, false,
&baseModel.Spec,
func(ctx context.Context, config *modelagent.ModelConfig) error {
return r.updateModelSpecWithRetry(ctx, baseModel, config)
},
Expand All @@ -261,6 +262,7 @@ func (r *BaseModelReconciler) updateModelStatus(ctx context.Context, baseModel *
// updateModelStatus updates ClusterBaseModel status based on ConfigMap data
func (r *ClusterBaseModelReconciler) updateModelStatus(ctx context.Context, clusterBaseModel *v1beta1.ClusterBaseModel) error {
return processModelStatus(ctx, r.Client, r.Log, "", clusterBaseModel.Name, true,
&clusterBaseModel.Spec,
func(ctx context.Context, config *modelagent.ModelConfig) error {
return r.updateModelSpecWithRetry(ctx, clusterBaseModel, config)
},
Expand All @@ -271,6 +273,7 @@ func (r *ClusterBaseModelReconciler) updateModelStatus(ctx context.Context, clus

// processModelStatus is a shared utility function for processing ConfigMaps and updating model status
func processModelStatus(ctx context.Context, kubeClient client.Client, log logr.Logger, namespace, name string, isClusterScope bool,
spec *v1beta1.BaseModelSpec,
specUpdateFunc func(context.Context, *modelagent.ModelConfig) error,
statusUpdateFunc func(context.Context, []string, []string) error) error {

Expand Down Expand Up @@ -332,6 +335,17 @@ func processModelStatus(ctx context.Context, kubeClient client.Client, log logr.

log.V(1).Info("Processing model entry", "node", configMap.Name, "status", modelEntry.Status, "hasConfig", modelEntry.Config != nil, "hasProgress", modelEntry.Progress != nil)

if !modelEntry.MatchesStorageIdentity(spec) {
currentStorageURI, currentStoragePath, _ := modelagent.StorageIdentityForSpec(spec)
log.Info("Skipping stale model status entry",
"node", configMap.Name,
"entryStorageUri", modelEntry.StorageURI,
"currentStorageUri", currentStorageURI,
"entryStoragePath", modelEntry.StoragePath,
"currentStoragePath", currentStoragePath)
continue
}

// Update model spec with config if available
if modelEntry.Config != nil {
if err := specUpdateFunc(ctx, modelEntry.Config); err != nil {
Expand Down
Loading