diff --git a/cmd/fleet/cron.go b/cmd/fleet/cron.go index ad07c5d92a6..fe0c68541b8 100644 --- a/cmd/fleet/cron.go +++ b/cmd/fleet/cron.go @@ -1462,6 +1462,7 @@ func newAppleMDMProfileManagerSchedule( instanceID string, ds fleet.Datastore, commander *apple_mdm.MDMAppleCommander, + redisKeyValue fleet.AdvancedKeyValueStore, logger *slog.Logger, certProfilesLimit int, ) (*schedule.Schedule, error) { @@ -1478,7 +1479,7 @@ func newAppleMDMProfileManagerSchedule( ctx, name, instanceID, defaultInterval, ds, ds, schedule.WithLogger(logger), schedule.WithJob("manage_apple_profiles", func(ctx context.Context) error { - return service.ReconcileAppleProfiles(ctx, ds, commander, logger, certProfilesLimit) + return service.ReconcileAppleProfiles(ctx, ds, commander, redisKeyValue, logger, certProfilesLimit) }), schedule.WithJob("manage_apple_declarations", func(ctx context.Context) error { return service.ReconcileAppleDeclarations(ctx, ds, commander, logger) diff --git a/cmd/fleet/cron_test.go b/cmd/fleet/cron_test.go index 015fbaf57db..74ef69c26dd 100644 --- a/cmd/fleet/cron_test.go +++ b/cmd/fleet/cron_test.go @@ -26,10 +26,11 @@ func TestNewAppleMDMProfileManagerWithoutConfig(t *testing.T) { ctx := context.Background() mdmStorage := &mdmmock.MDMAppleStore{} ds := new(mock.Store) + kv := new(mock.AdvancedKVStore) cmdr := apple_mdm.NewMDMAppleCommander(mdmStorage, nil) logger := slog.New(slog.DiscardHandler) - sch, err := newAppleMDMProfileManagerSchedule(ctx, "foo", ds, cmdr, logger, 0) + sch, err := newAppleMDMProfileManagerSchedule(ctx, "foo", ds, cmdr, kv, logger, 0) require.NotNil(t, sch) require.NoError(t, err) } diff --git a/cmd/fleet/serve.go b/cmd/fleet/serve.go index d37b32e4164..7ec36efc649 100644 --- a/cmd/fleet/serve.go +++ b/cmd/fleet/serve.go @@ -1258,6 +1258,7 @@ func runServeCmd(cmd *cobra.Command, configManager configpkg.Manager, debug, dev instanceID, ds, apple_mdm.NewMDMAppleCommander(mdmStorage, mdmPushService), + redis_key_value.New(redisPool), logger, config.MDM.CertificateProfilesLimit, ) diff --git a/server/fleet/mdm.go b/server/fleet/mdm.go index bc16f0c6bcb..ad675b8d590 100644 --- a/server/fleet/mdm.go +++ b/server/fleet/mdm.go @@ -31,6 +31,11 @@ const ( StickyMDMEnrollmentKeyPrefix = "sticky_mdm_enrollment_" // + host UUID StickyMDMEnrollmentTTL = 30 * time.Minute + + // MDMProfileProcessingKeyPrefix is used to indicate that a host is currently being processed for MDM profile installation. + // We wrap the key in braces to make Redis hash the keys to the same slot, avoding CrossSlot errors. + MDMProfileProcessingKeyPrefix = "{mdm_profile_processing}" // + :hostUUID + MDMProfileProcessingTTL = 1 * time.Minute // We use a low time here, to avoid letting it sit for too long in case of errors. ) // FleetVarName represents the name of a Fleet variable (without the FLEET_VAR_ prefix). diff --git a/server/fleet/service.go b/server/fleet/service.go index ae8bb1ea4af..b101c498d9d 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -1471,6 +1471,16 @@ type KeyValueStore interface { Get(ctx context.Context, key string) (*string, error) } +type AdvancedKeyValueStore interface { + KeyValueStore + + // MGet returns the values for the given keys. + // It returns a map of key to value, where the value is nil if the key doesn't exist. + // Important to use hashes for the keys to land in the same slot. + MGet(ctx context.Context, keys []string) (map[string]*string, error) + Delete(ctx context.Context, key string) error +} + const ( // BatchSetSoftwareInstallerStatusProcessing is the value returned for an ongoing BatchSetSoftwareInstallers operation. BatchSetSoftwareInstallersStatusProcessing = "processing" diff --git a/server/mock/redis_advanced/advanced_key_value_store.go b/server/mock/redis_advanced/advanced_key_value_store.go new file mode 100644 index 00000000000..dc12397ee48 --- /dev/null +++ b/server/mock/redis_advanced/advanced_key_value_store.go @@ -0,0 +1,65 @@ +// Automatically generated by mockimpl. DO NOT EDIT! + +package mock + +import ( + "context" + "sync" + "time" + + "github.com/fleetdm/fleet/v4/server/fleet" +) + +var _ fleet.AdvancedKeyValueStore = (*AdvancedKeyValueStore)(nil) + +type SetFunc func(ctx context.Context, key string, value string, expireTime time.Duration) error + +type GetFunc func(ctx context.Context, key string) (*string, error) + +type MGetFunc func(ctx context.Context, keys []string) (map[string]*string, error) + +type DeleteFunc func(ctx context.Context, key string) error + +type AdvancedKeyValueStore struct { + SetFunc SetFunc + SetFuncInvoked bool + + GetFunc GetFunc + GetFuncInvoked bool + + MGetFunc MGetFunc + MGetFuncInvoked bool + + DeleteFunc DeleteFunc + DeleteFuncInvoked bool + + mu sync.Mutex +} + +func (akv *AdvancedKeyValueStore) Set(ctx context.Context, key string, value string, expireTime time.Duration) error { + akv.mu.Lock() + akv.SetFuncInvoked = true + akv.mu.Unlock() + return akv.SetFunc(ctx, key, value, expireTime) +} + +func (akv *AdvancedKeyValueStore) Get(ctx context.Context, key string) (*string, error) { + akv.mu.Lock() + akv.GetFuncInvoked = true + akv.mu.Unlock() + return akv.GetFunc(ctx, key) +} + +func (akv *AdvancedKeyValueStore) MGet(ctx context.Context, keys []string) (map[string]*string, error) { + akv.mu.Lock() + akv.MGetFuncInvoked = true + akv.mu.Unlock() + return akv.MGetFunc(ctx, keys) +} + +func (akv *AdvancedKeyValueStore) Delete(ctx context.Context, key string) error { + akv.mu.Lock() + akv.DeleteFuncInvoked = true + akv.mu.Unlock() + return akv.DeleteFunc(ctx, key) +} diff --git a/server/mock/service.go b/server/mock/service.go index 6ac07e5cda4..25445ff5832 100644 --- a/server/mock/service.go +++ b/server/mock/service.go @@ -3,14 +3,21 @@ package mock import ( "github.com/fleetdm/fleet/v4/server/fleet" kvmock "github.com/fleetdm/fleet/v4/server/mock/redis" + akvmock "github.com/fleetdm/fleet/v4/server/mock/redis_advanced" svcmock "github.com/fleetdm/fleet/v4/server/mock/service" ) //go:generate go run ./mockimpl/impl.go -o service/service_mock.go "s *Service" "fleet.Service" //go:generate go run ./mockimpl/impl.go -o redis/key_value_store.go "kv *KeyValueStore" "fleet.KeyValueStore" +// We need to use a new folder to avoid multiple of the same functions +//go:generate go run ./mockimpl/impl.go -o redis_advanced/advanced_key_value_store.go "akv *AdvancedKeyValueStore" "fleet.AdvancedKeyValueStore" var _ fleet.Service = new(svcmock.Service) type KVStore struct { kvmock.KeyValueStore } + +type AdvancedKVStore struct { + akvmock.AdvancedKeyValueStore +} diff --git a/server/service/apple_mdm.go b/server/service/apple_mdm.go index d0a976166d8..d58bf85e6aa 100644 --- a/server/service/apple_mdm.go +++ b/server/service/apple_mdm.go @@ -3320,7 +3320,7 @@ type MDMAppleCheckinAndCommandService struct { vppInstaller fleet.AppleMDMVPPInstaller mdmLifecycle *mdmlifecycle.HostLifecycle commandHandlers map[string][]fleet.MDMCommandResultsHandler - keyValueStore fleet.KeyValueStore + keyValueStore fleet.AdvancedKeyValueStore newActivityFn mdmlifecycle.NewActivityFunc isPremium bool } @@ -3331,7 +3331,7 @@ func NewMDMAppleCheckinAndCommandService( vppInstaller fleet.AppleMDMVPPInstaller, isPremium bool, logger *slog.Logger, - keyValueStore fleet.KeyValueStore, + keyValueStore fleet.AdvancedKeyValueStore, newActivityFn mdmlifecycle.NewActivityFunc, ) *MDMAppleCheckinAndCommandService { mdmLifecycle := mdmlifecycle.New(ds, logger, newActivityFn) @@ -3406,8 +3406,8 @@ func (svc *MDMAppleCheckinAndCommandService) Authenticate(r *mdm.Request, m *mdm return err } - if !scepRenewalInProgress { - if svc.keyValueStore != nil { + if svc.keyValueStore != nil { + if !scepRenewalInProgress { // Set sticky key for MDM enrollments to avoid updating team id on orbit enrollments err = svc.keyValueStore.Set(r.Context, fleet.StickyMDMEnrollmentKeyPrefix+r.ID, "1", fleet.StickyMDMEnrollmentTTL) if err != nil { @@ -3415,6 +3415,12 @@ func (svc *MDMAppleCheckinAndCommandService) Authenticate(r *mdm.Request, m *mdm svc.logger.ErrorContext(r.Context, "failed to set sticky mdm enrollment key", "err", err, "host_uuid", r.ID) } } + + // Set profile processing flag, is being handled by the apple_mdm worker, it will be cleared later if it's a SCEP renewal. + if err := svc.keyValueStore.Set(r.Context, fleet.MDMProfileProcessingKeyPrefix+":"+r.ID, "1", fleet.MDMProfileProcessingTTL); err != nil { + svc.logger.ErrorContext(r.Context, "failed to set mdm profile processing key", "err", err, "host_uuid", r.ID) + // We do not want to fail here, just log the error to notify of issues + } } return nil @@ -3444,6 +3450,13 @@ func (svc *MDMAppleCheckinAndCommandService) TokenUpdate(r *mdm.Request, m *mdm. if !m.AwaitingConfiguration { // Normal SCEP renewal - device is NOT at Setup Assistant. Clean refs and short-circuit. svc.logger.InfoContext(r.Context, "cleaned SCEP refs, skipping setup experience and mdm lifecycle turn on action", "host_uuid", r.ID) + + // Clean up redis key for profile processing if set. + if svc.keyValueStore != nil { + if err := svc.keyValueStore.Delete(r.Context, fleet.MDMProfileProcessingKeyPrefix+":"+r.ID); err != nil { + svc.logger.ErrorContext(r.Context, "failed to delete mdm profile processing key", "err", err, "host_uuid", r.ID) + } + } return nil } @@ -4895,6 +4908,7 @@ func ReconcileAppleProfiles( ctx context.Context, ds fleet.Datastore, commander *apple_mdm.MDMAppleCommander, + redisKeyValue fleet.AdvancedKeyValueStore, logger *slog.Logger, certProfilesLimit int, ) error { @@ -5239,6 +5253,43 @@ func ReconcileAppleProfiles( }) } + // check if some of the hosts to install already is handled by the apple setup worker + // we want to batch check for 1k hosts at a time to avoid hitting query parameter limits + const isBeingSetupBatchSize = 1000 + for i := 0; i < len(hostProfiles); i += isBeingSetupBatchSize { + end := min(i+isBeingSetupBatchSize, len(hostProfiles)) + batch := hostProfiles[i:end] + hostUUIDs := make([]string, len(batch)) + hostUUIDToHostProfiles := make(map[string][]*fleet.MDMAppleBulkUpsertHostProfilePayload, len(batch)) + for j, hp := range batch { + hostUUIDs[j] = fleet.MDMProfileProcessingKeyPrefix + ":" + hp.HostUUID + hostUUIDToHostProfiles[hp.HostUUID] = append(hostUUIDToHostProfiles[hp.HostUUID], hp) + } + + setupHostUUIDs, err := redisKeyValue.MGet(ctx, hostUUIDs) + if err != nil { + return ctxerr.Wrap(ctx, err, "filtering hosts being set up") + } + for keyedHostUUID, exists := range setupHostUUIDs { + if exists != nil { + hostUUID := strings.TrimPrefix(keyedHostUUID, fleet.MDMProfileProcessingKeyPrefix+":") + logger.DebugContext(ctx, "skipping profile reconciliation for host being set up", "host_uuid", hostUUID) + hps, ok := hostUUIDToHostProfiles[hostUUID] + if !ok { + logger.DebugContext(ctx, "expected host uuid to be present but was not, do not skip profile reconciliation", "host_uuid", hostUUID) + continue + } + for _, hp := range hps { + // Clear out host profile status and installTargets to avoid iterating over them in ProcessAndEnqueueProfiles + hp.Status = nil + hp.CommandUUID = "" + hostProfilesToInstallMap[fleet.HostProfileUUID{HostUUID: hp.HostUUID, ProfileUUID: hp.ProfileUUID}] = hp + delete(installTargets, hp.ProfileUUID) + } + } + } + } + // delete all profiles that have a matching identifier to be installed. // This is to prevent sending both a `RemoveProfile` and an // `InstallProfile` for the same identifier, which can cause race diff --git a/server/service/apple_mdm_test.go b/server/service/apple_mdm_test.go index 659a7f7ab34..32cfa341264 100644 --- a/server/service/apple_mdm_test.go +++ b/server/service/apple_mdm_test.go @@ -2808,6 +2808,7 @@ func TestMDMAppleReconcileAppleProfiles(t *testing.T) { ctx := context.Background() mdmStorage := &mdmmock.MDMAppleStore{} ds := new(mock.Store) + kv := new(mock.AdvancedKVStore) pushFactory, _ := newMockAPNSPushProviderFactory() pusher := nanomdm_pushsvc.New( mdmStorage, @@ -2861,6 +2862,10 @@ func TestMDMAppleReconcileAppleProfiles(t *testing.T) { return baseProfilesToInstall, baseProfilesToRemove, nil } + kv.MGetFunc = func(ctx context.Context, keys []string) (map[string]*string, error) { + return map[string]*string{}, nil + } + ds.GetMDMAppleProfilesContentsFunc = func(ctx context.Context, profileUUIDs []string) (map[string]mobileconfig.Mobileconfig, error) { require.ElementsMatch(t, []string{p1, p2, p4, p5, p7}, profileUUIDs) // only those profiles that are to be installed @@ -3161,7 +3166,7 @@ func TestMDMAppleReconcileAppleProfiles(t *testing.T) { failedCount++ require.Len(t, payload, 0) } - err := ReconcileAppleProfiles(ctx, ds, cmdr, slog.New(slog.DiscardHandler), 0) + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 0) require.NoError(t, err) require.Equal(t, 1, failedCount) checkAndReset(t, true, &ds.ListMDMAppleProfilesToInstallAndRemoveFuncInvoked) @@ -3208,7 +3213,7 @@ func TestMDMAppleReconcileAppleProfiles(t *testing.T) { } enqueueFailForOp = fleet.MDMOperationTypeRemove - err := ReconcileAppleProfiles(ctx, ds, cmdr, slog.New(slog.DiscardHandler), 0) + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 0) require.NoError(t, err) require.Equal(t, 1, failedCount) checkAndReset(t, true, &ds.ListMDMAppleProfilesToInstallAndRemoveFuncInvoked) @@ -3281,7 +3286,7 @@ func TestMDMAppleReconcileAppleProfiles(t *testing.T) { } enqueueFailForOp = fleet.MDMOperationTypeInstall - err := ReconcileAppleProfiles(ctx, ds, cmdr, slog.New(slog.DiscardHandler), 0) + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 0) require.NoError(t, err) require.Equal(t, 1, failedCount) checkAndReset(t, true, &ds.ListMDMAppleProfilesToInstallAndRemoveFuncInvoked) @@ -3455,7 +3460,7 @@ func TestMDMAppleReconcileAppleProfiles(t *testing.T) { contents1 = originalContents1 expectedContents1 = originalExpectedContents1 }) - err := ReconcileAppleProfiles(ctx, ds, cmdr, slog.New(slog.DiscardHandler), 0) + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 0) require.NoError(t, err) assert.Equal(t, 2, upsertCount) // checkAndReset(t, true, &ds.GetAllCertificateAuthoritiesFuncInvoked) @@ -3483,7 +3488,7 @@ func TestMDMAppleReconcileAppleProfiles(t *testing.T) { ds.GetHostEmailsFunc = func(ctx context.Context, hostUUID string, source string) ([]string, error) { return nil, errors.New("GetHostEmailsFuncError") } - err := ReconcileAppleProfiles(ctx, ds, cmdr, slog.New(slog.Default().Handler()), 0) + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.Default().Handler()), 0) assert.ErrorContains(t, err, "GetHostEmailsFuncError") // checkAndReset(t, true, &ds.GetAllCertificateAuthoritiesFuncInvoked) checkAndReset(t, true, &ds.ListMDMAppleProfilesToInstallAndRemoveFuncInvoked) @@ -3545,7 +3550,7 @@ func TestMDMAppleReconcileAppleProfiles(t *testing.T) { hostUUIDs = append(hostUUIDs, p.HostUUID) } - err := ReconcileAppleProfiles(ctx, ds, cmdr, slog.New(slog.DiscardHandler), 0) + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 0) require.NoError(t, err) assert.Empty(t, hostUUIDs, "all host+profile combinations should be updated") require.Equal(t, 5, failedCount, "number of profiles with bad content") @@ -3563,6 +3568,7 @@ func TestReconcileAppleProfilesCAThrottle(t *testing.T) { ctx := t.Context() mdmStorage := &mdmmock.MDMAppleStore{} ds := new(mock.Store) + kv := new(mock.AdvancedKVStore) pushFactory, _ := newMockAPNSPushProviderFactory() pusher := nanomdm_pushsvc.New( mdmStorage, @@ -3617,6 +3623,10 @@ func TestReconcileAppleProfilesCAThrottle(t *testing.T) { }, nil } + kv.MGetFunc = func(ctx context.Context, keys []string) (map[string]*string, error) { + return make(map[string]*string), nil + } + ds.BulkDeleteMDMAppleHostsConfigProfilesFunc = func(ctx context.Context, payload []*fleet.MDMAppleProfilePayload) error { return nil } @@ -3692,7 +3702,7 @@ func TestReconcileAppleProfilesCAThrottle(t *testing.T) { t.Run("limit=0 sends all profiles", func(t *testing.T) { upsertedProfiles = nil bulkUpsertCallCount = 0 - err := ReconcileAppleProfiles(ctx, ds, cmdr, slog.New(slog.DiscardHandler), 0) + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 0) require.NoError(t, err) // All 10 host-profile pairs should be upserted (5 CA + 5 non-CA) @@ -3711,7 +3721,7 @@ func TestReconcileAppleProfilesCAThrottle(t *testing.T) { t.Run("limit=2 throttles CA profiles only", func(t *testing.T) { upsertedProfiles = nil bulkUpsertCallCount = 0 - err := ReconcileAppleProfiles(ctx, ds, cmdr, slog.New(slog.DiscardHandler), 2) + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 2) require.NoError(t, err) // Should have 2 CA + 5 non-CA = 7 host-profile pairs upserted @@ -3749,7 +3759,7 @@ func TestReconcileAppleProfilesCAThrottle(t *testing.T) { return recentProfilesToInstall, nil, nil } - err := ReconcileAppleProfiles(ctx, ds, cmdr, slog.New(slog.DiscardHandler), 2) + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 2) require.NoError(t, err) var caCount, nonCACount int @@ -3787,7 +3797,7 @@ func TestReconcileAppleProfilesCAThrottle(t *testing.T) { return nil, profilesToRemove, nil } - err := ReconcileAppleProfiles(ctx, ds, cmdr, slog.New(slog.DiscardHandler), 2) + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 2) require.NoError(t, err) var removeCount int @@ -3805,6 +3815,167 @@ func TestReconcileAppleProfilesCAThrottle(t *testing.T) { }) } +func TestReconcileAppleProfilesSkipsHostBeingProcessed(t *testing.T) { + ctx := t.Context() + mdmStorage := &mdmmock.MDMAppleStore{} + ds := new(mock.Store) + kv := new(mock.AdvancedKVStore) + pushFactory, _ := newMockAPNSPushProviderFactory() + pusher := nanomdm_pushsvc.New( + mdmStorage, + mdmStorage, + pushFactory, + NewNanoMDMLogger(slog.New(slog.DiscardHandler)), + ) + mdmConfig := config.MDMConfig{ + AppleSCEPCert: "./testdata/server.pem", + AppleSCEPKey: "./testdata/server.key", + } + ds.GetAllMDMConfigAssetsByNameFunc = func(ctx context.Context, assetNames []fleet.MDMAssetName, + _ sqlx.QueryerContext, + ) (map[fleet.MDMAssetName]fleet.MDMConfigAsset, error) { + _, pemCert, pemKey, err := mdmConfig.AppleSCEP() + require.NoError(t, err) + return map[fleet.MDMAssetName]fleet.MDMConfigAsset{ + fleet.MDMAssetCACert: {Value: pemCert}, + fleet.MDMAssetCAKey: {Value: pemKey}, + }, nil + } + + cmdr := apple_mdm.NewMDMAppleCommander(mdmStorage, pusher) + + profileUUID := "a" + uuid.NewString() + profileContent := []byte("regular profile content") + blockedHostUUID := "host-blocked" + nonSetupHostUUID := "host-non-setup" + + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{MDM: fleet.MDM{EnabledAndConfigured: true}}, nil + } + ds.ListMDMAppleProfilesToInstallAndRemoveFunc = func(ctx context.Context) ([]*fleet.MDMAppleProfilePayload, []*fleet.MDMAppleProfilePayload, error) { + return []*fleet.MDMAppleProfilePayload{ + {ProfileUUID: profileUUID, ProfileIdentifier: "com.test.profile", ProfileName: "Test Profile", HostUUID: blockedHostUUID, Scope: fleet.PayloadScopeSystem}, + {ProfileUUID: profileUUID, ProfileIdentifier: "com.test.profile", ProfileName: "Test Profile", HostUUID: nonSetupHostUUID, Scope: fleet.PayloadScopeSystem}, + }, nil, nil + } + ds.GetMDMAppleProfilesContentsFunc = func(ctx context.Context, profileUUIDs []string) (map[string]mobileconfig.Mobileconfig, error) { + return map[string]mobileconfig.Mobileconfig{profileUUID: profileContent}, nil + } + ds.BulkDeleteMDMAppleHostsConfigProfilesFunc = func(ctx context.Context, payload []*fleet.MDMAppleProfilePayload) error { + return nil + } + ds.GetNanoMDMUserEnrollmentFunc = func(ctx context.Context, hostUUID string) (*fleet.NanoEnrollment, error) { + return nil, nil + } + ds.GetGroupedCertificateAuthoritiesFunc = func(ctx context.Context, allCAs bool) (*fleet.GroupedCertificateAuthorities, error) { + return &fleet.GroupedCertificateAuthorities{}, nil + } + ds.AggregateEnrollSecretPerTeamFunc = func(ctx context.Context) ([]*fleet.EnrollSecret, error) { + return []*fleet.EnrollSecret{}, nil + } + ds.BulkUpsertMDMAppleConfigProfilesFunc = func(ctx context.Context, p []*fleet.MDMAppleConfigProfile) error { + return nil + } + mdmStorage.BulkDeleteHostUserCommandsWithoutResultsFunc = func(ctx context.Context, commandToIDs map[string][]string) error { + return nil + } + mdmStorage.EnqueueCommandFunc = func(ctx context.Context, id []string, cmd *mdm.CommandWithSubtype) (map[string]error, error) { + return nil, nil + } + mdmStorage.RetrievePushInfoFunc = func(ctx context.Context, tokens []string) (map[string]*mdm.Push, error) { + res := make(map[string]*mdm.Push, len(tokens)) + for _, t := range tokens { + res[t] = &mdm.Push{PushMagic: "", Token: []byte(t), Topic: ""} + } + return res, nil + } + mdmStorage.RetrievePushCertFunc = func(ctx context.Context, topic string) (*tls.Certificate, string, error) { + cert, err := tls.LoadX509KeyPair("testdata/server.pem", "testdata/server.key") + return &cert, "", err + } + mdmStorage.IsPushCertStaleFunc = func(ctx context.Context, topic string, staleToken string) (bool, error) { + return false, nil + } + mdmStorage.GetAllMDMConfigAssetsByNameFunc = func(ctx context.Context, assetNames []fleet.MDMAssetName, + _ sqlx.QueryerContext, + ) (map[fleet.MDMAssetName]fleet.MDMConfigAsset, error) { + certPEM, err := os.ReadFile("./testdata/server.pem") + require.NoError(t, err) + keyPEM, err := os.ReadFile("./testdata/server.key") + require.NoError(t, err) + return map[fleet.MDMAssetName]fleet.MDMConfigAsset{ + fleet.MDMAssetCACert: {Value: certPEM}, + fleet.MDMAssetCAKey: {Value: keyPEM}, + }, nil + } + + // Track what gets upserted and which hosts get commands enqueued + var upsertedProfiles []*fleet.MDMAppleBulkUpsertHostProfilePayload + var bulkUpsertCallCount int + ds.BulkUpsertMDMAppleHostProfilesFunc = func(ctx context.Context, payload []*fleet.MDMAppleBulkUpsertHostProfilePayload) error { + bulkUpsertCallCount++ + if bulkUpsertCallCount == 1 { + upsertedProfiles = payload + } + return nil + } + + // Simulate an in-memory KV store with TTL support + kvStore := make(map[string]string) + kv.MGetFunc = func(ctx context.Context, keys []string) (map[string]*string, error) { + result := make(map[string]*string, len(keys)) + for _, k := range keys { + if v, ok := kvStore[k]; ok { + result[k] = &v + } else { + result[k] = nil + } + } + return result, nil + } + + // verify host marked as going through setup does not get profiles reconciled + blockedKey := fleet.MDMProfileProcessingKeyPrefix + ":" + blockedHostUUID + kvStore[blockedKey] = "1" + + upsertedProfiles = nil + bulkUpsertCallCount = 0 + err := ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 0) + require.NoError(t, err) + + // Only the non setup host should have profiles with a pending status and command UUID; + // the blocked host should have its status/command cleared. + var pendingHosts []string + var skippedHosts []string + for _, p := range upsertedProfiles { + if p.Status != nil && *p.Status == fleet.MDMDeliveryPending && p.CommandUUID != "" { + pendingHosts = append(pendingHosts, p.HostUUID) + } else if p.Status == nil && p.CommandUUID == "" { + skippedHosts = append(skippedHosts, p.HostUUID) + } + } + assert.Contains(t, pendingHosts, nonSetupHostUUID, "non setup host should have profiles enqueued") + assert.NotContains(t, pendingHosts, blockedHostUUID, "blocked host should NOT have profiles enqueued") + assert.Contains(t, skippedHosts, blockedHostUUID, "blocked host should be skipped with nil status") + + // expire the key, the host that didn't get profiles before should do now + delete(kvStore, blockedKey) // simulate TTL expiry + + upsertedProfiles = nil + bulkUpsertCallCount = 0 + err = ReconcileAppleProfiles(ctx, ds, cmdr, kv, slog.New(slog.DiscardHandler), 0) + require.NoError(t, err) + + pendingHosts = nil + for _, p := range upsertedProfiles { + if p.Status != nil && *p.Status == fleet.MDMDeliveryPending && p.CommandUUID != "" { + pendingHosts = append(pendingHosts, p.HostUUID) + } + } + assert.Contains(t, pendingHosts, nonSetupHostUUID, "non setup host should still have profiles enqueued") + assert.Contains(t, pendingHosts, blockedHostUUID, "previously blocked host should now have profiles enqueued after key expiry") +} + func TestAppleMDMFileVaultEscrowFunctions(t *testing.T) { svc := Service{} diff --git a/server/service/integration_mdm_profiles_test.go b/server/service/integration_mdm_profiles_test.go index 0933fd98b04..6592294e5d4 100644 --- a/server/service/integration_mdm_profiles_test.go +++ b/server/service/integration_mdm_profiles_test.go @@ -34,6 +34,7 @@ import ( "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/service/contract" "github.com/fleetdm/fleet/v4/server/service/integrationtest/scep_server" + "github.com/fleetdm/fleet/v4/server/service/redis_key_value" "github.com/fleetdm/fleet/v4/server/test" "github.com/google/uuid" "github.com/jmoiron/sqlx" @@ -5222,6 +5223,7 @@ func (s *integrationMDMTestSuite) TestBatchSetMDMProfilesBackwardsCompat() { func (s *integrationMDMTestSuite) TestMDMBatchSetProfilesKeepsReservedNames() { t := s.T() ctx := context.Background() + kv := redis_key_value.New(s.redisPool) checkMacProfs := func(teamID *uint, names ...string) { var count int @@ -5263,7 +5265,7 @@ func (s *integrationMDMTestSuite) TestMDMBatchSetProfilesKeepsReservedNames() { if len(secrets) == 0 { require.NoError(t, s.ds.ApplyEnrollSecrets(ctx, nil, []*fleet.EnrollSecret{{Secret: t.Name()}})) } - require.NoError(t, ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger, 0)) + require.NoError(t, ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, kv, s.logger, 0)) // turn on disk encryption and os updates s.DoJSON("PATCH", "/api/latest/fleet/config", json.RawMessage(`{ @@ -5343,7 +5345,7 @@ func (s *integrationMDMTestSuite) TestMDMBatchSetProfilesKeepsReservedNames() { require.Equal(t, "14.6.1", tmResp.Team.Config.MDM.MacOSUpdates.MinimumVersion.Value) require.Equal(t, true, tmResp.Team.Config.MDM.MacOSUpdates.UpdateNewHosts.Value) - require.NoError(t, ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger, 0)) + require.NoError(t, ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, kv, s.logger, 0)) checkMacProfs(&tmResp.Team.ID, servermdm.ListFleetReservedMacOSProfileNames()...) checkWinProfs(&tmResp.Team.ID, servermdm.ListFleetReservedWindowsProfileNames()...) @@ -6337,8 +6339,6 @@ func (s *integrationMDMTestSuite) TestAppleProfileDeletion() { return err }) - // trigger a profile sync - s.awaitTriggerProfileSchedule(t) installs, removes := checkNextPayloads(t, mdmDevice, false) // verify that we received all profiles s.signedProfilesMatch( diff --git a/server/service/integration_mdm_test.go b/server/service/integration_mdm_test.go index 65ce68ce5db..c2b7fc50430 100644 --- a/server/service/integration_mdm_test.go +++ b/server/service/integration_mdm_test.go @@ -84,6 +84,7 @@ import ( "github.com/fleetdm/fleet/v4/server/service/integrationtest/scep_server" "github.com/fleetdm/fleet/v4/server/service/mock" "github.com/fleetdm/fleet/v4/server/service/osquery_utils" + "github.com/fleetdm/fleet/v4/server/service/redis_key_value" "github.com/fleetdm/fleet/v4/server/service/schedule" "github.com/fleetdm/fleet/v4/server/test" "github.com/fleetdm/fleet/v4/server/worker" @@ -140,6 +141,7 @@ type integrationMDMTestSuite struct { proxyCallbackURL string jwtSigningKey *rsa.PrivateKey softwareInstallerStore fleet.SoftwareInstallerStore + keyValueStore fleet.AdvancedKeyValueStore } // appleVPPConfigSrvConf is used to configure the mock server that mocks Apple's VPP endpoints. @@ -302,6 +304,8 @@ func (s *integrationMDMTestSuite) SetupSuite() { softwareTitleIconStore, err := filesystem.NewSoftwareTitleIconStore(iconDir) require.NoError(s.T(), err) + keyValueStore := redis_key_value.New(s.redisPool) + serverConfig := TestServerOpts{ License: &fleet.LicenseInfo{ Tier: fleet.TierPremium, @@ -321,6 +325,7 @@ func (s *integrationMDMTestSuite) SetupSuite() { BootstrapPackageStore: bootstrapPackageStore, androidMockClient: androidMockClient, androidModule: androidSvc, + KeyValueStore: keyValueStore, StartCronSchedules: []TestNewScheduleFunc{ func(ctx context.Context, ds fleet.Datastore) fleet.NewCronScheduleFunc { return func() (fleet.CronSchedule, error) { @@ -338,7 +343,7 @@ func (s *integrationMDMTestSuite) SetupSuite() { s.onProfileJobDone() }() } - err = ReconcileAppleProfiles(ctx, ds, mdmCommander, logger, 0) + err = ReconcileAppleProfiles(ctx, ds, mdmCommander, keyValueStore, logger, 0) require.NoError(s.T(), err) return err }), @@ -498,6 +503,7 @@ func (s *integrationMDMTestSuite) SetupSuite() { s.mdmCommander = mdmCommander s.logger = serverLogger s.androidAPIClient = androidMockClient + s.keyValueStore = keyValueStore fleetdmSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { status := s.fleetDMNextCSRStatus.Swap(http.StatusOK) @@ -12021,6 +12027,7 @@ func (s *integrationMDMTestSuite) TestSilentMigrationGotchas() { func (s *integrationMDMTestSuite) TestAPNsPushCron() { t := s.T() ctx := context.Background() + kv := redis_key_value.New(s.redisPool) s.Do("POST", "/api/v1/fleet/mdm/profiles/batch", batchSetMDMProfilesRequest{Profiles: []fleet.MDMProfileBatchPayload{ {Name: "N1", Contents: mobileconfigForTest("N1", "I1")}, @@ -12050,13 +12057,13 @@ func (s *integrationMDMTestSuite) TestAPNsPushCron() { } // trigger the reconciliation schedule - err := ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger, 0) + err := ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, kv, s.logger, 0) require.NoError(t, err) require.Len(t, recordedPushes, 1) recordedPushes = nil // triggering the schedule again doesn't send any more pushes - err = ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger, 0) + err = ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, kv, s.logger, 0) require.NoError(t, err) require.Len(t, recordedPushes, 0) recordedPushes = nil @@ -12088,6 +12095,7 @@ func (s *integrationMDMTestSuite) TestAPNsPushCron() { func (s *integrationMDMTestSuite) TestAPNsPushWithNotNow() { t := s.T() ctx := context.Background() + kv := redis_key_value.New(s.redisPool) // macOS host, MDM on _, macDevice := createHostThenEnrollMDM(s.ds, s.server.URL, t) @@ -12111,7 +12119,7 @@ func (s *integrationMDMTestSuite) TestAPNsPushWithNotNow() { } // trigger the reconciliation schedule - err := ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger, 0) + err := ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, kv, s.logger, 0) require.NoError(t, err) require.Len(t, recordedPushes, 1) recordedPushes = nil @@ -12132,7 +12140,7 @@ func (s *integrationMDMTestSuite) TestAPNsPushWithNotNow() { }}, http.StatusNoContent) // trigger the reconciliation schedule - err = ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger, 0) + err = ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, kv, s.logger, 0) require.NoError(t, err) require.Len(t, recordedPushes, 1) recordedPushes = nil @@ -12152,7 +12160,7 @@ func (s *integrationMDMTestSuite) TestAPNsPushWithNotNow() { assert.Nil(t, cmd) // A 'NotNow' command will not trigger a new push. Device is expected to check in again when conditions change. - err = ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, s.logger, 0) + err = ReconcileAppleProfiles(ctx, s.ds, s.mdmCommander, kv, s.logger, 0) require.NoError(t, err) require.Len(t, recordedPushes, 0) recordedPushes = nil diff --git a/server/service/redis_key_value/redis_key_value.go b/server/service/redis_key_value/redis_key_value.go index 010c24c19cc..a51a8100ec0 100644 --- a/server/service/redis_key_value/redis_key_value.go +++ b/server/service/redis_key_value/redis_key_value.go @@ -13,8 +13,8 @@ import ( redigo "github.com/gomodule/redigo/redis" ) -// RedisKeyValue is a basic key/value store with SET and GET operations -// Items are removed via expiration (defined in the SET operation). +// RedisKeyValue is a key/value store with basic SET and GET operations and advanced operations +// Items are removed via expiration (defined in the SET operation), or via the DEL command type RedisKeyValue struct { pool fleet.RedisPool testPrefix string // for tests, the key prefix to use to avoid conflicts @@ -56,3 +56,38 @@ func (r *RedisKeyValue) Get(ctx context.Context, key string) (*string, error) { } return &res, nil } + +func (r *RedisKeyValue) MGet(ctx context.Context, keys []string) (map[string]*string, error) { + conn := redis.ConfigureDoer(r.pool, r.pool.Get()) + defer conn.Close() + + redisKeys := make([]interface{}, len(keys)) + for i, key := range keys { + redisKeys[i] = r.testPrefix + prefix + key + } + + res, err := redigo.Strings(conn.Do("MGET", redisKeys...)) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "redis failed to mget") + } + + result := make(map[string]*string, len(keys)) + for i, key := range keys { + if res[i] == "" { + result[key] = nil + } else { + result[key] = &res[i] + } + } + return result, nil +} + +func (r *RedisKeyValue) Delete(ctx context.Context, key string) error { + conn := redis.ConfigureDoer(r.pool, r.pool.Get()) + defer conn.Close() + + if _, err := redigo.Int(conn.Do("DEL", r.testPrefix+prefix+key)); err != nil { + return ctxerr.Wrap(ctx, err, "redis failed to delete") + } + return nil +}