Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/fleet/cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,7 @@ func newAppleMDMProfileManagerSchedule(
instanceID string,
ds fleet.Datastore,
commander *apple_mdm.MDMAppleCommander,
redisKeyValue fleet.AdvancedKeyValueStore,
logger *slog.Logger,
certProfilesLimit int,
) (*schedule.Schedule, error) {
Expand All @@ -1478,7 +1479,7 @@ func newAppleMDMProfileManagerSchedule(
ctx, name, instanceID, defaultInterval, ds, ds,
schedule.WithLogger(logger),
schedule.WithJob("manage_apple_profiles", func(ctx context.Context) error {
return service.ReconcileAppleProfiles(ctx, ds, commander, logger, certProfilesLimit)
return service.ReconcileAppleProfiles(ctx, ds, commander, redisKeyValue, logger, certProfilesLimit)
}),
schedule.WithJob("manage_apple_declarations", func(ctx context.Context) error {
return service.ReconcileAppleDeclarations(ctx, ds, commander, logger)
Expand Down
3 changes: 2 additions & 1 deletion cmd/fleet/cron_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ func TestNewAppleMDMProfileManagerWithoutConfig(t *testing.T) {
ctx := context.Background()
mdmStorage := &mdmmock.MDMAppleStore{}
ds := new(mock.Store)
kv := new(mock.AdvancedKVStore)
cmdr := apple_mdm.NewMDMAppleCommander(mdmStorage, nil)
logger := slog.New(slog.DiscardHandler)

sch, err := newAppleMDMProfileManagerSchedule(ctx, "foo", ds, cmdr, logger, 0)
sch, err := newAppleMDMProfileManagerSchedule(ctx, "foo", ds, cmdr, kv, logger, 0)
require.NotNil(t, sch)
require.NoError(t, err)
}
Expand Down
1 change: 1 addition & 0 deletions cmd/fleet/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,7 @@ func runServeCmd(cmd *cobra.Command, configManager configpkg.Manager, debug, dev
instanceID,
ds,
apple_mdm.NewMDMAppleCommander(mdmStorage, mdmPushService),
redis_key_value.New(redisPool),
logger,
config.MDM.CertificateProfilesLimit,
)
Expand Down
5 changes: 5 additions & 0 deletions server/fleet/mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ const (

StickyMDMEnrollmentKeyPrefix = "sticky_mdm_enrollment_" // + host UUID
StickyMDMEnrollmentTTL = 30 * time.Minute

// MDMProfileProcessingKeyPrefix is used to indicate that a host is currently being processed for MDM profile installation.
// We wrap the key in braces to make Redis hash the keys to the same slot, avoding CrossSlot errors.
MDMProfileProcessingKeyPrefix = "{mdm_profile_processing}" // + :hostUUID
MDMProfileProcessingTTL = 1 * time.Minute // We use a low time here, to avoid letting it sit for too long in case of errors.
)

// FleetVarName represents the name of a Fleet variable (without the FLEET_VAR_ prefix).
Expand Down
10 changes: 10 additions & 0 deletions server/fleet/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,16 @@ type KeyValueStore interface {
Get(ctx context.Context, key string) (*string, error)
}

type AdvancedKeyValueStore interface {
KeyValueStore

// MGet returns the values for the given keys.
// It returns a map of key to value, where the value is nil if the key doesn't exist.
// Important to use hashes for the keys to land in the same slot.
MGet(ctx context.Context, keys []string) (map[string]*string, error)
Delete(ctx context.Context, key string) error
}

const (
// BatchSetSoftwareInstallerStatusProcessing is the value returned for an ongoing BatchSetSoftwareInstallers operation.
BatchSetSoftwareInstallersStatusProcessing = "processing"
Expand Down
65 changes: 65 additions & 0 deletions server/mock/redis_advanced/advanced_key_value_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Automatically generated by mockimpl. DO NOT EDIT!

package mock

import (
"context"
"sync"
"time"

"github.com/fleetdm/fleet/v4/server/fleet"
)

var _ fleet.AdvancedKeyValueStore = (*AdvancedKeyValueStore)(nil)

type SetFunc func(ctx context.Context, key string, value string, expireTime time.Duration) error

type GetFunc func(ctx context.Context, key string) (*string, error)

type MGetFunc func(ctx context.Context, keys []string) (map[string]*string, error)

type DeleteFunc func(ctx context.Context, key string) error

type AdvancedKeyValueStore struct {
SetFunc SetFunc
SetFuncInvoked bool

GetFunc GetFunc
GetFuncInvoked bool

MGetFunc MGetFunc
MGetFuncInvoked bool

DeleteFunc DeleteFunc
DeleteFuncInvoked bool

mu sync.Mutex
}

func (akv *AdvancedKeyValueStore) Set(ctx context.Context, key string, value string, expireTime time.Duration) error {
akv.mu.Lock()
akv.SetFuncInvoked = true
akv.mu.Unlock()
return akv.SetFunc(ctx, key, value, expireTime)
}

func (akv *AdvancedKeyValueStore) Get(ctx context.Context, key string) (*string, error) {
akv.mu.Lock()
akv.GetFuncInvoked = true
akv.mu.Unlock()
return akv.GetFunc(ctx, key)
}

func (akv *AdvancedKeyValueStore) MGet(ctx context.Context, keys []string) (map[string]*string, error) {
akv.mu.Lock()
akv.MGetFuncInvoked = true
akv.mu.Unlock()
return akv.MGetFunc(ctx, keys)
}

func (akv *AdvancedKeyValueStore) Delete(ctx context.Context, key string) error {
akv.mu.Lock()
akv.DeleteFuncInvoked = true
akv.mu.Unlock()
return akv.DeleteFunc(ctx, key)
}
7 changes: 7 additions & 0 deletions server/mock/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@ package mock
import (
"github.com/fleetdm/fleet/v4/server/fleet"
kvmock "github.com/fleetdm/fleet/v4/server/mock/redis"
akvmock "github.com/fleetdm/fleet/v4/server/mock/redis_advanced"
svcmock "github.com/fleetdm/fleet/v4/server/mock/service"
)

//go:generate go run ./mockimpl/impl.go -o service/service_mock.go "s *Service" "fleet.Service"
//go:generate go run ./mockimpl/impl.go -o redis/key_value_store.go "kv *KeyValueStore" "fleet.KeyValueStore"
// We need to use a new folder to avoid multiple of the same functions
//go:generate go run ./mockimpl/impl.go -o redis_advanced/advanced_key_value_store.go "akv *AdvancedKeyValueStore" "fleet.AdvancedKeyValueStore"

var _ fleet.Service = new(svcmock.Service)

type KVStore struct {
kvmock.KeyValueStore
}

type AdvancedKVStore struct {
akvmock.AdvancedKeyValueStore
}
59 changes: 55 additions & 4 deletions server/service/apple_mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3320,7 +3320,7 @@ type MDMAppleCheckinAndCommandService struct {
vppInstaller fleet.AppleMDMVPPInstaller
mdmLifecycle *mdmlifecycle.HostLifecycle
commandHandlers map[string][]fleet.MDMCommandResultsHandler
keyValueStore fleet.KeyValueStore
keyValueStore fleet.AdvancedKeyValueStore
newActivityFn mdmlifecycle.NewActivityFunc
isPremium bool
}
Expand All @@ -3331,7 +3331,7 @@ func NewMDMAppleCheckinAndCommandService(
vppInstaller fleet.AppleMDMVPPInstaller,
isPremium bool,
logger *slog.Logger,
keyValueStore fleet.KeyValueStore,
keyValueStore fleet.AdvancedKeyValueStore,
newActivityFn mdmlifecycle.NewActivityFunc,
) *MDMAppleCheckinAndCommandService {
mdmLifecycle := mdmlifecycle.New(ds, logger, newActivityFn)
Expand Down Expand Up @@ -3406,15 +3406,21 @@ func (svc *MDMAppleCheckinAndCommandService) Authenticate(r *mdm.Request, m *mdm
return err
}

if !scepRenewalInProgress {
if svc.keyValueStore != nil {
if svc.keyValueStore != nil {
if !scepRenewalInProgress {
// Set sticky key for MDM enrollments to avoid updating team id on orbit enrollments
err = svc.keyValueStore.Set(r.Context, fleet.StickyMDMEnrollmentKeyPrefix+r.ID, "1", fleet.StickyMDMEnrollmentTTL)
if err != nil {
// We do not want to fail here, just log the error to notify
svc.logger.ErrorContext(r.Context, "failed to set sticky mdm enrollment key", "err", err, "host_uuid", r.ID)
}
}

// Set profile processing flag, is being handled by the apple_mdm worker, it will be cleared later if it's a SCEP renewal.
if err := svc.keyValueStore.Set(r.Context, fleet.MDMProfileProcessingKeyPrefix+":"+r.ID, "1", fleet.MDMProfileProcessingTTL); err != nil {
svc.logger.ErrorContext(r.Context, "failed to set mdm profile processing key", "err", err, "host_uuid", r.ID)
// We do not want to fail here, just log the error to notify of issues
}
}

return nil
Expand Down Expand Up @@ -3444,6 +3450,13 @@ func (svc *MDMAppleCheckinAndCommandService) TokenUpdate(r *mdm.Request, m *mdm.
if !m.AwaitingConfiguration {
// Normal SCEP renewal - device is NOT at Setup Assistant. Clean refs and short-circuit.
svc.logger.InfoContext(r.Context, "cleaned SCEP refs, skipping setup experience and mdm lifecycle turn on action", "host_uuid", r.ID)

// Clean up redis key for profile processing if set.
if svc.keyValueStore != nil {
if err := svc.keyValueStore.Delete(r.Context, fleet.MDMProfileProcessingKeyPrefix+":"+r.ID); err != nil {
svc.logger.ErrorContext(r.Context, "failed to delete mdm profile processing key", "err", err, "host_uuid", r.ID)
}
}
return nil
}

Expand Down Expand Up @@ -4895,6 +4908,7 @@ func ReconcileAppleProfiles(
ctx context.Context,
ds fleet.Datastore,
commander *apple_mdm.MDMAppleCommander,
redisKeyValue fleet.AdvancedKeyValueStore,
logger *slog.Logger,
certProfilesLimit int,
) error {
Expand Down Expand Up @@ -5239,6 +5253,43 @@ func ReconcileAppleProfiles(
})
}

// check if some of the hosts to install already is handled by the apple setup worker
// we want to batch check for 1k hosts at a time to avoid hitting query parameter limits
const isBeingSetupBatchSize = 1000
for i := 0; i < len(hostProfiles); i += isBeingSetupBatchSize {
end := min(i+isBeingSetupBatchSize, len(hostProfiles))
batch := hostProfiles[i:end]
hostUUIDs := make([]string, len(batch))
hostUUIDToHostProfiles := make(map[string][]*fleet.MDMAppleBulkUpsertHostProfilePayload, len(batch))
for j, hp := range batch {
hostUUIDs[j] = fleet.MDMProfileProcessingKeyPrefix + ":" + hp.HostUUID
hostUUIDToHostProfiles[hp.HostUUID] = append(hostUUIDToHostProfiles[hp.HostUUID], hp)
}

setupHostUUIDs, err := redisKeyValue.MGet(ctx, hostUUIDs)
if err != nil {
return ctxerr.Wrap(ctx, err, "filtering hosts being set up")
}
for keyedHostUUID, exists := range setupHostUUIDs {
if exists != nil {
hostUUID := strings.TrimPrefix(keyedHostUUID, fleet.MDMProfileProcessingKeyPrefix+":")
logger.DebugContext(ctx, "skipping profile reconciliation for host being set up", "host_uuid", hostUUID)
hps, ok := hostUUIDToHostProfiles[hostUUID]
if !ok {
logger.DebugContext(ctx, "expected host uuid to be present but was not, do not skip profile reconciliation", "host_uuid", hostUUID)
continue
}
for _, hp := range hps {
// Clear out host profile status and installTargets to avoid iterating over them in ProcessAndEnqueueProfiles
hp.Status = nil
hp.CommandUUID = ""
hostProfilesToInstallMap[fleet.HostProfileUUID{HostUUID: hp.HostUUID, ProfileUUID: hp.ProfileUUID}] = hp
delete(installTargets, hp.ProfileUUID)
}
}
}
}

// delete all profiles that have a matching identifier to be installed.
// This is to prevent sending both a `RemoveProfile` and an
// `InstallProfile` for the same identifier, which can cause race
Expand Down
Loading
Loading