diff --git a/azureappconfiguration/azureappconfiguration.go b/azureappconfiguration/azureappconfiguration.go index c91f570..a25d941 100644 --- a/azureappconfiguration/azureappconfiguration.go +++ b/azureappconfiguration/azureappconfiguration.go @@ -21,23 +21,32 @@ import ( "strconv" "strings" "sync" + "sync/atomic" + "github.com/Azure/AppConfiguration-GoProvider/azureappconfiguration/internal/refresh" "github.com/Azure/AppConfiguration-GoProvider/azureappconfiguration/internal/tracing" "github.com/Azure/AppConfiguration-GoProvider/azureappconfiguration/internal/tree" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" decoder "github.com/go-viper/mapstructure/v2" "golang.org/x/sync/errgroup" ) // An AzureAppConfiguration is a configuration provider that stores and manages settings sourced from Azure App Configuration. type AzureAppConfiguration struct { - keyValues map[string]any - kvSelectors []Selector - trimPrefixes []string + keyValues map[string]any + kvSelectors []Selector + trimPrefixes []string + watchedSettings []WatchedSetting + + sentinelETags map[WatchedSetting]*azcore.ETag + kvRefreshTimer refresh.Condition + onRefreshSuccess []func() + tracingOptions tracing.Options clientManager *configurationClientManager resolver *keyVaultReferenceResolver - tracingOptions tracing.Options + refreshInProgress atomic.Bool } // Load initializes a new AzureAppConfiguration instance and loads the configuration data from @@ -56,6 +65,10 @@ func Load(ctx context.Context, authentication AuthenticationOptions, options *Op return nil, err } + if err := verifyOptions(options); err != nil { + return nil, err + } + if options == nil { options = &Options{} } @@ -77,9 +90,17 @@ func Load(ctx context.Context, authentication AuthenticationOptions, options *Op credential: options.KeyVaultOptions.Credential, } + if options.RefreshOptions.Enabled { + azappcfg.kvRefreshTimer = refresh.NewTimer(options.RefreshOptions.Interval) + azappcfg.watchedSettings = normalizedWatchedSettings(options.RefreshOptions.WatchedSettings) + azappcfg.sentinelETags = make(map[WatchedSetting]*azcore.ETag) + } + if err := azappcfg.load(ctx); err != nil { return nil, err } + // Set the initial load finished flag + azappcfg.tracingOptions.InitialLoadFinished = true return azappcfg, nil } @@ -150,14 +171,107 @@ func (azappcfg *AzureAppConfiguration) GetBytes(options *ConstructionOptions) ([ return json.Marshal(azappcfg.constructHierarchicalMap(options.Separator)) } +// Refresh manually triggers a refresh of the configuration from Azure App Configuration. +// It checks if any watched settings have changed, and if so, reloads all configuration data. +// +// The refresh only occurs if: +// - Refresh has been configured with RefreshOptions when the client was created +// - The configured refresh interval has elapsed since the last refresh +// - No other refresh operation is currently in progress +// +// If the configuration has changed, any callback functions registered with OnRefreshSuccess will be executed. +// +// Parameters: +// - ctx: The context for the operation. +// +// Returns: +// - An error if refresh is not configured, or if the refresh operation fails +func (azappcfg *AzureAppConfiguration) Refresh(ctx context.Context) error { + if azappcfg.kvRefreshTimer == nil { + return fmt.Errorf("refresh is not configured") + } + + // Try to set refreshInProgress to true, returning false if it was already true + if !azappcfg.refreshInProgress.CompareAndSwap(false, true) { + return nil // Another refresh is already in progress + } + + // Reset the flag when we're done + defer azappcfg.refreshInProgress.Store(false) + + // Check if it's time to perform a refresh based on the timer interval + if !azappcfg.kvRefreshTimer.ShouldRefresh() { + return nil + } + + // Attempt to refresh and check if any values were actually updated + refreshed, err := azappcfg.refreshKeyValues(ctx, azappcfg.newKeyValueRefreshClient()) + if err != nil { + return fmt.Errorf("failed to refresh configuration: %w", err) + } + + // Only execute callbacks if actual changes were applied + if refreshed { + for _, callback := range azappcfg.onRefreshSuccess { + if callback != nil { + callback() + } + } + } + + return nil +} + +// OnRefreshSuccess registers a callback function that will be executed whenever the configuration +// is successfully refreshed and actual changes were detected. +// +// Multiple callback functions can be registered, and they will be executed in the order they were added. +// Callbacks are only executed when configuration values actually change. They run synchronously +// in the thread that initiated the refresh. +// +// Parameters: +// - callback: A function with no parameters that will be called after a successful refresh +func (azappcfg *AzureAppConfiguration) OnRefreshSuccess(callback func()) { + azappcfg.onRefreshSuccess = append(azappcfg.onRefreshSuccess, callback) +} + func (azappcfg *AzureAppConfiguration) load(ctx context.Context) error { - keyValuesClient := &selectorSettingsClient{ - selectors: azappcfg.kvSelectors, - client: azappcfg.clientManager.staticClient.client, - tracingOptions: azappcfg.tracingOptions, + eg, egCtx := errgroup.WithContext(ctx) + eg.Go(func() error { + keyValuesClient := &selectorSettingsClient{ + selectors: azappcfg.kvSelectors, + client: azappcfg.clientManager.staticClient.client, + tracingOptions: azappcfg.tracingOptions, + } + return azappcfg.loadKeyValues(egCtx, keyValuesClient) + }) + + if azappcfg.kvRefreshTimer != nil && len(azappcfg.watchedSettings) > 0 { + eg.Go(func() error { + watchedClient := &watchedSettingClient{ + watchedSettings: azappcfg.watchedSettings, + client: azappcfg.clientManager.staticClient.client, + tracingOptions: azappcfg.tracingOptions, + } + return azappcfg.loadWatchedSettings(egCtx, watchedClient) + }) + } + + return eg.Wait() +} + +func (azappcfg *AzureAppConfiguration) loadWatchedSettings(ctx context.Context, settingsClient settingsClient) error { + settingsResponse, err := settingsClient.getSettings(ctx) + if err != nil { + return err } - return azappcfg.loadKeyValues(ctx, keyValuesClient) + // Store ETags for all watched settings + if settingsResponse != nil && settingsResponse.watchedETags != nil { + azappcfg.sentinelETags = settingsResponse.watchedETags + } + + return nil } func (azappcfg *AzureAppConfiguration) loadKeyValues(ctx context.Context, settingsClient settingsClient) error { @@ -246,6 +360,48 @@ func (azappcfg *AzureAppConfiguration) loadKeyValues(ctx context.Context, settin return nil } +// refreshKeyValues checks if any watched settings have changed and reloads configuration if needed +// Returns true if configuration was actually refreshed, false otherwise +func (azappcfg *AzureAppConfiguration) refreshKeyValues(ctx context.Context, refreshClient refreshClient) (bool, error) { + // Check if any ETags have changed + eTagChanged, err := refreshClient.monitor.checkIfETagChanged(ctx) + if err != nil { + return false, fmt.Errorf("failed to check if watched settings have changed: %w", err) + } + + if !eTagChanged { + // No changes detected, reset timer and return + azappcfg.kvRefreshTimer.Reset() + return false, nil + } + + // Use an errgroup to reload key values and watched settings concurrently + eg, egCtx := errgroup.WithContext(ctx) + + // Reload key values in one goroutine + eg.Go(func() error { + settingsClient := refreshClient.loader + return azappcfg.loadKeyValues(egCtx, settingsClient) + }) + + if len(azappcfg.watchedSettings) > 0 { + eg.Go(func() error { + watchedClient := refreshClient.sentinels + return azappcfg.loadWatchedSettings(egCtx, watchedClient) + }) + } + + // Wait for all reloads to complete + if err := eg.Wait(); err != nil { + // Don't reset the timer if reload failed + return false, fmt.Errorf("failed to reload configuration: %w", err) + } + + // Reset the timer only after successful refresh + azappcfg.kvRefreshTimer.Reset() + return true, nil +} + func (azappcfg *AzureAppConfiguration) trimPrefix(key string) string { result := key for _, prefix := range azappcfg.trimPrefixes { @@ -324,3 +480,38 @@ func configureTracingOptions(options *Options) tracing.Options { return tracingOption } + +func normalizedWatchedSettings(s []WatchedSetting) []WatchedSetting { + result := make([]WatchedSetting, len(s)) + for i, setting := range s { + // Make a copy of the setting + normalizedSetting := setting + if normalizedSetting.Label == "" { + normalizedSetting.Label = defaultLabel + } + + result[i] = normalizedSetting + } + + return result +} + +func (azappcfg *AzureAppConfiguration) newKeyValueRefreshClient() refreshClient { + return refreshClient{ + loader: &selectorSettingsClient{ + selectors: azappcfg.kvSelectors, + client: azappcfg.clientManager.staticClient.client, + tracingOptions: azappcfg.tracingOptions, + }, + monitor: &watchedSettingClient{ + eTags: azappcfg.sentinelETags, + client: azappcfg.clientManager.staticClient.client, + tracingOptions: azappcfg.tracingOptions, + }, + sentinels: &watchedSettingClient{ + watchedSettings: azappcfg.watchedSettings, + client: azappcfg.clientManager.staticClient.client, + tracingOptions: azappcfg.tracingOptions, + }, + } +} diff --git a/azureappconfiguration/constants.go b/azureappconfiguration/constants.go index 608489f..028387d 100644 --- a/azureappconfiguration/constants.go +++ b/azureappconfiguration/constants.go @@ -3,6 +3,8 @@ package azureappconfiguration +import "time" + // Configuration client constants const ( endpointKey string = "Endpoint" @@ -18,3 +20,9 @@ const ( secretReferenceContentType string = "application/vnd.microsoft.appconfig.keyvaultref+json;charset=utf-8" featureFlagContentType string = "application/vnd.microsoft.appconfig.ff+json;charset=utf-8" ) + +// Refresh interval constants +const ( + // minimalRefreshInterval is the minimum allowed refresh interval for key-value settings + minimalRefreshInterval time.Duration = time.Second +) diff --git a/azureappconfiguration/internal/refresh/refresh.go b/azureappconfiguration/internal/refresh/refresh.go new file mode 100644 index 0000000..ba48ce3 --- /dev/null +++ b/azureappconfiguration/internal/refresh/refresh.go @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package refresh + +import "time" + +// Timer manages the timing for refresh operations +type Timer struct { + interval time.Duration // How often refreshes should occur + nextRefreshTime time.Time // When the next refresh should occur +} + +// Condition interface defines the methods a refresh timer should implement +type Condition interface { + ShouldRefresh() bool + Reset() +} + +const ( + DefaultRefreshInterval time.Duration = 30 * time.Second +) + +// NewTimer creates a new refresh timer with the specified interval +// If interval is zero or negative, it falls back to the DefaultRefreshInterval +func NewTimer(interval time.Duration) *Timer { + // Use default interval if not specified or invalid + if interval <= 0 { + interval = DefaultRefreshInterval + } + + return &Timer{ + interval: interval, + nextRefreshTime: time.Now().Add(interval), + } +} + +// ShouldRefresh checks whether it's time for a refresh +func (rt *Timer) ShouldRefresh() bool { + return !time.Now().Before(rt.nextRefreshTime) +} + +// Reset resets the timer for the next refresh cycle +func (rt *Timer) Reset() { + rt.nextRefreshTime = time.Now().Add(rt.interval) +} diff --git a/azureappconfiguration/internal/tracing/tracing.go b/azureappconfiguration/internal/tracing/tracing.go index 3b56fa5..fdecf93 100644 --- a/azureappconfiguration/internal/tracing/tracing.go +++ b/azureappconfiguration/internal/tracing/tracing.go @@ -11,7 +11,7 @@ import ( ) type RequestType string - +type RequestTracingKey string type HostType string const ( @@ -32,7 +32,6 @@ const ( // Documentation : https://docs.microsoft.com/en-us/azure/service-fabric/service-fabric-environment-variables-reference EnvVarServiceFabric = "Fabric_NodeName" - RequestTracingKey = "Tracing" RequestTypeKey = "RequestType" HostTypeKey = "Host" KeyVaultConfiguredTag = "UsesKeyVault" @@ -50,6 +49,7 @@ const ( type Options struct { Enabled bool + InitialLoadFinished bool Host HostType KeyVaultConfigured bool UseAIConfiguration bool @@ -75,12 +75,10 @@ func CreateCorrelationContextHeader(ctx context.Context, options Options) http.H header := http.Header{} output := make([]string, 0) - if tracing := ctx.Value(RequestTracingKey); tracing != nil { - if tracing.(RequestType) == RequestTypeStartUp { - output = append(output, RequestTypeKey+"="+string(RequestTypeStartUp)) - } else if tracing.(RequestType) == RequestTypeWatch { - output = append(output, RequestTypeKey+"="+string(RequestTypeWatch)) - } + if !options.InitialLoadFinished { + output = append(output, RequestTypeKey+"="+string(RequestTypeStartUp)) + } else { + output = append(output, RequestTypeKey+"="+string(RequestTypeWatch)) } if options.Host != "" { diff --git a/azureappconfiguration/internal/tracing/tracing_test.go b/azureappconfiguration/internal/tracing/tracing_test.go index 03ac057..86336b0 100644 --- a/azureappconfiguration/internal/tracing/tracing_test.go +++ b/azureappconfiguration/internal/tracing/tracing_test.go @@ -20,14 +20,13 @@ func TestCreateCorrelationContextHeader(t *testing.T) { // The header should be empty but exist corrContext := header.Get(CorrelationContextHeader) - assert.Equal(t, "", corrContext) + assert.Equal(t, "RequestType=StartUp", corrContext) }) t.Run("with RequestTypeStartUp", func(t *testing.T) { - ctx := context.WithValue(context.Background(), RequestTracingKey, RequestTypeStartUp) options := Options{} - header := CreateCorrelationContextHeader(ctx, options) + header := CreateCorrelationContextHeader(context.Background(), options) // Should contain RequestTypeStartUp corrContext := header.Get(CorrelationContextHeader) @@ -35,10 +34,11 @@ func TestCreateCorrelationContextHeader(t *testing.T) { }) t.Run("with RequestTypeWatch", func(t *testing.T) { - ctx := context.WithValue(context.Background(), RequestTracingKey, RequestTypeWatch) - options := Options{} + options := Options{ + InitialLoadFinished: true, + } - header := CreateCorrelationContextHeader(ctx, options) + header := CreateCorrelationContextHeader(context.Background(), options) // Should contain RequestTypeWatch corrContext := header.Get(CorrelationContextHeader) @@ -132,7 +132,6 @@ func TestCreateCorrelationContextHeader(t *testing.T) { }) t.Run("with all options", func(t *testing.T) { - ctx := context.WithValue(context.Background(), RequestTracingKey, RequestTypeStartUp) options := Options{ Host: HostTypeAzureFunction, KeyVaultConfigured: true, @@ -140,7 +139,7 @@ func TestCreateCorrelationContextHeader(t *testing.T) { UseAIChatCompletionConfiguration: true, } - header := CreateCorrelationContextHeader(ctx, options) + header := CreateCorrelationContextHeader(context.Background(), options) // Check the complete header corrContext := header.Get(CorrelationContextHeader) @@ -168,13 +167,12 @@ func TestCreateCorrelationContextHeader(t *testing.T) { }) t.Run("delimiter handling", func(t *testing.T) { - ctx := context.WithValue(context.Background(), RequestTracingKey, RequestTypeStartUp) options := Options{ Host: HostTypeAzureWebApp, KeyVaultConfigured: true, } - header := CreateCorrelationContextHeader(ctx, options) + header := CreateCorrelationContextHeader(context.Background(), options) // Check the complete header corrContext := header.Get(CorrelationContextHeader) diff --git a/azureappconfiguration/options.go b/azureappconfiguration/options.go index f3f5371..a8e0d81 100644 --- a/azureappconfiguration/options.go +++ b/azureappconfiguration/options.go @@ -6,6 +6,7 @@ package azureappconfiguration import ( "context" "net/url" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig" @@ -21,13 +22,15 @@ type Options struct { // Selectors defines what key-values to load from Azure App Configuration // Each selector combines a key filter and label filter // If selectors are not provided, all key-values with no label are loaded by default. - Selectors []Selector + Selectors []Selector + // RefreshOptions contains optional parameters to configure the behavior of key-value settings refresh + RefreshOptions KeyValueRefreshOptions // KeyVaultOptions configures how Key Vault references are resolved. KeyVaultOptions KeyVaultOptions // ClientOptions provides options for configuring the underlying Azure App Configuration client. - ClientOptions *azappconfig.ClientOptions + ClientOptions *azappconfig.ClientOptions } // AuthenticationOptions contains parameters for authenticating with the Azure App Configuration service. @@ -35,11 +38,11 @@ type Options struct { type AuthenticationOptions struct { // Credential is a token credential for Azure EntraID Authenticaiton. // Required when Endpoint is provided. - Credential azcore.TokenCredential + Credential azcore.TokenCredential // Endpoint is the URL of the Azure App Configuration service. // Required when using token-based authentication with Credential. - Endpoint string + Endpoint string // ConnectionString is the connection string for the Azure App Configuration service. ConnectionString string @@ -49,7 +52,7 @@ type AuthenticationOptions struct { type Selector struct { // KeyFilter specifies which keys to retrieve from Azure App Configuration. // It can include wildcards, e.g. "app*" will match all keys starting with "app". - KeyFilter string + KeyFilter string // LabelFilter specifies which labels to retrieve from Azure App Configuration. // Empty string or omitted value will use the default no-label filter. @@ -57,6 +60,25 @@ type Selector struct { LabelFilter string } +// KeyValueRefreshOptions contains optional parameters to configure the behavior of key-value settings refresh +type KeyValueRefreshOptions struct { + // WatchedSettings specifies the key-value settings to watch for changes + WatchedSettings []WatchedSetting + + // Interval specifies the minimum time interval between consecutive refresh operations for the watched settings + // Must be greater than 1 second. If not provided, the default interval 30 seconds will be used + Interval time.Duration + + // Enabled specifies whether the provider should automatically refresh when the configuration is changed. + Enabled bool +} + +// WatchedSetting specifies the key and label of a key-value setting to watch for changes +type WatchedSetting struct { + Key string + Label string +} + // SecretResolver is an interface to resolve secrets from Key Vault references. // Implement this interface to provide custom secret resolution logic. type SecretResolver interface { diff --git a/azureappconfiguration/refresh_test.go b/azureappconfiguration/refresh_test.go new file mode 100644 index 0000000..c2ed49d --- /dev/null +++ b/azureappconfiguration/refresh_test.go @@ -0,0 +1,405 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package azureappconfiguration + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/Azure/AppConfiguration-GoProvider/azureappconfiguration/internal/refresh" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockETagsClient implements the eTagsClient interface for testing +type mockETagsClient struct { + changed bool + checkCallCount int + err error +} + +func (m *mockETagsClient) checkIfETagChanged(ctx context.Context) (bool, error) { + m.checkCallCount++ + if m.err != nil { + return false, m.err + } + return m.changed, nil +} + +// mockRefreshCondition implements the refreshtimer.RefreshCondition interface for testing +type mockRefreshCondition struct { + shouldRefresh bool + resetCalled bool +} + +func (m *mockRefreshCondition) ShouldRefresh() bool { + return m.shouldRefresh +} + +func (m *mockRefreshCondition) Reset() { + m.resetCalled = true +} + +func TestRefresh_NotConfigured(t *testing.T) { + // Setup a provider with no refresh configuration + azappcfg := &AzureAppConfiguration{} + + // Attempt to refresh + err := azappcfg.Refresh(context.Background()) + + // Verify that an error is returned + require.Error(t, err) + assert.Contains(t, err.Error(), "refresh is not configured") +} + +func TestRefresh_NotTimeToRefresh(t *testing.T) { + // Setup a provider with a timer that indicates it's not time to refresh + mockTimer := &mockRefreshCondition{shouldRefresh: false} + azappcfg := &AzureAppConfiguration{ + kvRefreshTimer: mockTimer, + } + + // Attempt to refresh + err := azappcfg.Refresh(context.Background()) + + // Verify no error and that we returned early + assert.NoError(t, err) + // Timer should not be reset if we're not refreshing + assert.False(t, mockTimer.resetCalled) +} + +func TestRefreshEnabled_EmptyWatchedSettings(t *testing.T) { + // Test verifying validation when refresh is enabled but no watched settings + options := &Options{ + RefreshOptions: KeyValueRefreshOptions{ + Enabled: true, // Enabled but without watched settings + WatchedSettings: []WatchedSetting{}, + }, + } + + // Verify error + err := verifyOptions(options) + require.Error(t, err) + assert.Contains(t, err.Error(), "watched settings cannot be empty") +} + +func TestRefreshEnabled_IntervalTooShort(t *testing.T) { + // Test verifying validation when refresh interval is too short + options := &Options{ + RefreshOptions: KeyValueRefreshOptions{ + Enabled: true, + Interval: 500 * time.Millisecond, // Too short, should be at least minimalRefreshInterval + WatchedSettings: []WatchedSetting{ + {Key: "test-key", Label: "test-label"}, + }, + }, + } + + // Verify error + err := verifyOptions(options) + require.Error(t, err) + assert.Contains(t, err.Error(), "key value refresh interval cannot be less than") +} + +func TestRefreshEnabled_EmptyWatchedSettingKey(t *testing.T) { + // Test verifying validation when a watched setting has an empty key + options := &Options{ + RefreshOptions: KeyValueRefreshOptions{ + Enabled: true, + WatchedSettings: []WatchedSetting{ + {Key: "", Label: "test-label"}, // Empty key should be rejected + }, + }, + } + + // Verify error + err := verifyOptions(options) + require.Error(t, err) + assert.Contains(t, err.Error(), "watched setting key cannot be empty") +} + +func TestRefreshEnabled_InvalidWatchedSettingKey(t *testing.T) { + // Test verifying validation when watched setting keys contain invalid chars + options := &Options{ + RefreshOptions: KeyValueRefreshOptions{ + Enabled: true, + WatchedSettings: []WatchedSetting{ + {Key: "test*key", Label: "test-label"}, // Key contains wildcard, not allowed + }, + }, + } + + // Verify error + err := verifyOptions(options) + require.Error(t, err) + assert.Contains(t, err.Error(), "watched setting key cannot contain") +} + +func TestRefreshEnabled_InvalidWatchedSettingLabel(t *testing.T) { + // Test verifying validation when watched setting labels contain invalid chars + options := &Options{ + RefreshOptions: KeyValueRefreshOptions{ + Enabled: true, + WatchedSettings: []WatchedSetting{ + {Key: "test-key", Label: "test*label"}, // Label contains wildcard, not allowed + }, + }, + } + + // Verify error + err := verifyOptions(options) + require.Error(t, err) + assert.Contains(t, err.Error(), "watched setting label cannot contain") +} + +func TestRefreshEnabled_ValidSettings(t *testing.T) { + // Test verifying valid refresh options pass validation + options := &Options{ + RefreshOptions: KeyValueRefreshOptions{ + Enabled: true, + Interval: 5 * time.Second, // Valid interval + WatchedSettings: []WatchedSetting{ + {Key: "test-key-1", Label: "test-label-1"}, + {Key: "test-key-2", Label: ""}, // Empty label should be normalized later + }, + }, + } + + // Verify no error + err := verifyOptions(options) + assert.NoError(t, err) +} + +func TestNormalizedWatchedSettings(t *testing.T) { + // Test the normalizedWatchedSettings function + settings := []WatchedSetting{ + {Key: "key1", Label: "label1"}, + {Key: "key2", Label: ""}, // Empty label should be set to defaultLabel + } + + normalized := normalizedWatchedSettings(settings) + + // Verify results + assert.Len(t, normalized, 2) + assert.Equal(t, "key1", normalized[0].Key) + assert.Equal(t, "label1", normalized[0].Label) + assert.Equal(t, "key2", normalized[1].Key) + assert.Equal(t, defaultLabel, normalized[1].Label) +} + +// Additional test to verify real RefreshTimer behavior +func TestRealRefreshTimer(t *testing.T) { + // Create a real refresh timer with a short interval + timer := refresh.NewTimer(100 * time.Millisecond) + + // Initially it should not be time to refresh + assert.False(t, timer.ShouldRefresh(), "New timer should not immediately indicate refresh needed") + + // After the interval passes, it should indicate time to refresh + time.Sleep(110 * time.Millisecond) + assert.True(t, timer.ShouldRefresh(), "Timer should indicate refresh needed after interval") + + // After reset, it should not be time to refresh again + timer.Reset() + assert.False(t, timer.ShouldRefresh(), "Timer should not indicate refresh needed right after reset") +} + +// mockKvRefreshClient implements the settingsClient interface for testing +type mockKvRefreshClient struct { + settings []azappconfig.Setting + watchedETags map[WatchedSetting]*azcore.ETag + getCallCount int + err error +} + +func (m *mockKvRefreshClient) getSettings(ctx context.Context) (*settingsResponse, error) { + m.getCallCount++ + if m.err != nil { + return nil, m.err + } + return &settingsResponse{ + settings: m.settings, + watchedETags: m.watchedETags, + }, nil +} + +// TestRefreshKeyValues_NoChanges tests when no ETags change is detected +func TestRefreshKeyValues_NoChanges(t *testing.T) { + // Setup mocks + mockTimer := &mockRefreshCondition{} + mockMonitor := &mockETagsClient{changed: false} + mockLoader := &mockKvRefreshClient{} + mockSentinels := &mockKvRefreshClient{} + + mockClient := refreshClient{ + loader: mockLoader, + monitor: mockMonitor, + sentinels: mockSentinels, + } + + // Setup provider + azappcfg := &AzureAppConfiguration{ + kvRefreshTimer: mockTimer, + } + + // Call refreshKeyValues + refreshed, err := azappcfg.refreshKeyValues(context.Background(), mockClient) + + // Verify results + assert.NoError(t, err) + assert.False(t, refreshed, "Should return false when no changes detected") + assert.Equal(t, 1, mockMonitor.checkCallCount, "Monitor should be called exactly once") + assert.Equal(t, 0, mockLoader.getCallCount, "Loader should not be called when no changes") + assert.Equal(t, 0, mockSentinels.getCallCount, "Sentinels should not be called when no changes") + assert.True(t, mockTimer.resetCalled, "Timer should be reset even when no changes") +} + +// TestRefreshKeyValues_ChangesDetected tests when ETags changed and reload succeeds +func TestRefreshKeyValues_ChangesDetected(t *testing.T) { + // Setup mocks for successful refresh + mockTimer := &mockRefreshCondition{} + mockMonitor := &mockETagsClient{changed: true} + mockLoader := &mockKvRefreshClient{} + mockSentinels := &mockKvRefreshClient{} + + mockClient := refreshClient{ + loader: mockLoader, + monitor: mockMonitor, + sentinels: mockSentinels, + } + + // Setup provider with watchedSettings + azappcfg := &AzureAppConfiguration{ + kvRefreshTimer: mockTimer, + watchedSettings: []WatchedSetting{{Key: "test", Label: "test"}}, + } + + // Call refreshKeyValues + refreshed, err := azappcfg.refreshKeyValues(context.Background(), mockClient) + + // Verify results + assert.NoError(t, err) + assert.True(t, refreshed, "Should return true when changes detected and applied") + assert.Equal(t, 1, mockMonitor.checkCallCount, "Monitor should be called exactly once") + assert.Equal(t, 1, mockLoader.getCallCount, "Loader should be called when changes detected") + assert.Equal(t, 1, mockSentinels.getCallCount, "Sentinels should be called when changes detected") + assert.True(t, mockTimer.resetCalled, "Timer should be reset after successful refresh") +} + +// TestRefreshKeyValues_LoaderError tests when loader client returns an error +func TestRefreshKeyValues_LoaderError(t *testing.T) { + // Setup mocks with loader error + mockTimer := &mockRefreshCondition{} + mockMonitor := &mockETagsClient{changed: true} + mockLoader := &mockKvRefreshClient{err: fmt.Errorf("loader error")} + mockSentinels := &mockKvRefreshClient{} + + mockClient := refreshClient{ + loader: mockLoader, + monitor: mockMonitor, + sentinels: mockSentinels, + } + + // Setup provider + azappcfg := &AzureAppConfiguration{ + kvRefreshTimer: mockTimer, + } + + // Call refreshKeyValues + refreshed, err := azappcfg.refreshKeyValues(context.Background(), mockClient) + + // Verify results + assert.Error(t, err) + assert.False(t, refreshed, "Should return false when error occurs") + assert.Contains(t, err.Error(), "loader error") + assert.Equal(t, 1, mockMonitor.checkCallCount, "Monitor should be called exactly once") + assert.Equal(t, 1, mockLoader.getCallCount, "Loader should be called when changes detected") + assert.False(t, mockTimer.resetCalled, "Timer should not be reset when error occurs") +} + +// TestRefreshKeyValues_SentinelError tests when sentinel client returns an error +func TestRefreshKeyValues_SentinelError(t *testing.T) { + // Setup mocks with sentinel error + mockTimer := &mockRefreshCondition{} + mockMonitor := &mockETagsClient{changed: true} + mockLoader := &mockKvRefreshClient{} + mockSentinels := &mockKvRefreshClient{err: fmt.Errorf("sentinel error")} + + mockClient := refreshClient{ + loader: mockLoader, + monitor: mockMonitor, + sentinels: mockSentinels, + } + + // Setup provider with watchedSettings to ensure sentinels are used + azappcfg := &AzureAppConfiguration{ + kvRefreshTimer: mockTimer, + watchedSettings: []WatchedSetting{{Key: "test", Label: "test"}}, + } + + // Call refreshKeyValues + refreshed, err := azappcfg.refreshKeyValues(context.Background(), mockClient) + + // Verify results + assert.Error(t, err) + assert.False(t, refreshed, "Should return false when error occurs") + assert.Contains(t, err.Error(), "sentinel error") + assert.Equal(t, 1, mockMonitor.checkCallCount, "Monitor should be called exactly once") + assert.Equal(t, 1, mockLoader.getCallCount, "Loader should be called when changes detected") + assert.Equal(t, 1, mockSentinels.getCallCount, "Sentinels should be called when changes detected") + assert.False(t, mockTimer.resetCalled, "Timer should not be reset when error occurs") +} + +// TestRefreshKeyValues_MonitorError tests when monitor client returns an error +func TestRefreshKeyValues_MonitorError(t *testing.T) { + // Setup mocks with monitor error + mockTimer := &mockRefreshCondition{} + mockMonitor := &mockETagsClient{err: fmt.Errorf("monitor error")} + mockLoader := &mockKvRefreshClient{} + mockSentinels := &mockKvRefreshClient{} + + mockClient := refreshClient{ + loader: mockLoader, + monitor: mockMonitor, + sentinels: mockSentinels, + } + + // Setup provider + azappcfg := &AzureAppConfiguration{ + kvRefreshTimer: mockTimer, + } + + // Call refreshKeyValues + refreshed, err := azappcfg.refreshKeyValues(context.Background(), mockClient) + + // Verify results + assert.Error(t, err) + assert.False(t, refreshed, "Should return false when error occurs") + assert.Contains(t, err.Error(), "monitor error") + assert.Equal(t, 1, mockMonitor.checkCallCount, "Monitor should be called exactly once") + assert.Equal(t, 0, mockLoader.getCallCount, "Loader should not be called when monitor fails") + assert.Equal(t, 0, mockSentinels.getCallCount, "Sentinels should not be called when monitor fails") + assert.False(t, mockTimer.resetCalled, "Timer should not be reset when error occurs") +} + +// TestRefresh_AlreadyInProgress tests the new atomic implementation of refresh status checking +func TestRefresh_AlreadyInProgress(t *testing.T) { + // Setup a provider with refresh already in progress + azappcfg := &AzureAppConfiguration{ + kvRefreshTimer: &mockRefreshCondition{}, + } + + // Manually set the refresh in progress flag + azappcfg.refreshInProgress.Store(true) + + // Attempt to refresh + err := azappcfg.Refresh(context.Background()) + + // Verify no error and that we returned early + assert.NoError(t, err) +} diff --git a/azureappconfiguration/settings_client.go b/azureappconfiguration/settings_client.go index 1889cfc..212177b 100644 --- a/azureappconfiguration/settings_client.go +++ b/azureappconfiguration/settings_client.go @@ -5,16 +5,19 @@ package azureappconfiguration import ( "context" + "errors" + "log" "github.com/Azure/AppConfiguration-GoProvider/azureappconfiguration/internal/tracing" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/data/azappconfig" ) type settingsResponse struct { - settings []azappconfig.Setting - // TODO: pageETags + settings []azappconfig.Setting + watchedETags map[WatchedSetting]*azcore.ETag } type selectorSettingsClient struct { @@ -23,10 +26,27 @@ type selectorSettingsClient struct { tracingOptions tracing.Options } +type watchedSettingClient struct { + watchedSettings []WatchedSetting + eTags map[WatchedSetting]*azcore.ETag + client *azappconfig.Client + tracingOptions tracing.Options +} + type settingsClient interface { getSettings(ctx context.Context) (*settingsResponse, error) } +type eTagsClient interface { + checkIfETagChanged(ctx context.Context) (bool, error) +} + +type refreshClient struct { + loader settingsClient + monitor eTagsClient + sentinels settingsClient +} + func (s *selectorSettingsClient) getSettings(ctx context.Context) (*settingsResponse, error) { if s.tracingOptions.Enabled { ctx = policy.WithHTTPHeader(ctx, tracing.CreateCorrelationContextHeader(ctx, s.tracingOptions)) @@ -55,3 +75,58 @@ func (s *selectorSettingsClient) getSettings(ctx context.Context) (*settingsResp settings: settings, }, nil } + +func (c *watchedSettingClient) getSettings(ctx context.Context) (*settingsResponse, error) { + if c.tracingOptions.Enabled { + ctx = policy.WithHTTPHeader(ctx, tracing.CreateCorrelationContextHeader(ctx, c.tracingOptions)) + } + + settings := make([]azappconfig.Setting, 0, len(c.watchedSettings)) + watchedETags := make(map[WatchedSetting]*azcore.ETag) + for _, watchedSetting := range c.watchedSettings { + response, err := c.client.GetSetting(ctx, watchedSetting.Key, &azappconfig.GetSettingOptions{Label: to.Ptr(watchedSetting.Label)}) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == 404 { + label := watchedSetting.Label + if label == "" || label == "\x00" { // NUL is escaped to \x00 in golang + label = "no" + } + // If the watched setting is not found, log and continue + log.Printf("Watched key '%s' with %s label does not exist", watchedSetting.Key, label) + continue + } + return nil, err + } + + settings = append(settings, response.Setting) + watchedETags[watchedSetting] = response.Setting.ETag + } + + return &settingsResponse{ + settings: settings, + watchedETags: watchedETags, + }, nil +} + +func (c *watchedSettingClient) checkIfETagChanged(ctx context.Context) (bool, error) { + if c.tracingOptions.Enabled { + ctx = policy.WithHTTPHeader(ctx, tracing.CreateCorrelationContextHeader(ctx, c.tracingOptions)) + } + + for watchedSetting, ETag := range c.eTags { + _, err := c.client.GetSetting(ctx, watchedSetting.Key, &azappconfig.GetSettingOptions{Label: to.Ptr(watchedSetting.Label), OnlyIfChanged: ETag}) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && (respErr.StatusCode == 404 || respErr.StatusCode == 304) { + continue + } + + return false, err + } + + return true, nil + } + + return false, nil +} diff --git a/azureappconfiguration/utils.go b/azureappconfiguration/utils.go index 48cd274..f567360 100644 --- a/azureappconfiguration/utils.go +++ b/azureappconfiguration/utils.go @@ -29,6 +29,32 @@ func verifyOptions(options *Options) error { return err } + if options.RefreshOptions.Enabled { + if options.RefreshOptions.Interval != 0 && + options.RefreshOptions.Interval < minimalRefreshInterval { + return fmt.Errorf("key value refresh interval cannot be less than %s", minimalRefreshInterval) + } + + if len(options.RefreshOptions.WatchedSettings) == 0 { + return fmt.Errorf("watched settings cannot be empty") + } + + for _, watchedSetting := range options.RefreshOptions.WatchedSettings { + if watchedSetting.Key == "" { + return fmt.Errorf("watched setting key cannot be empty") + } + + if strings.Contains(watchedSetting.Key, "*") || strings.Contains(watchedSetting.Key, ",") { + return fmt.Errorf("watched setting key cannot contain '*' or ','") + } + + if watchedSetting.Label != "" && + (strings.Contains(watchedSetting.Label, "*") || strings.Contains(watchedSetting.Label, ",")) { + return fmt.Errorf("watched setting label cannot contain '*' or ','") + } + } + } + return nil }