From 1b86c130acc19189465681f5e4dd15b8b043bd38 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Fri, 20 Mar 2026 09:22:26 -0400 Subject: [PATCH 01/30] feat(lock): add cross-process lock infrastructure Add file-based process locks (via github.com/gofrs/flock) and a reusable lock-wait loop for serializing keychain, session-cache, and SSO-token operations across concurrent aws-vault processes. New files: - process_lock.go: ProcessLock interface + file-lock implementation - lock_waiter.go: poll-sleep-log wait loop with configurable timing - keychain_lock.go, session_lock.go, sso_lock.go: typed lock wrappers - locked_keyring.go: keyring.Keyring decorator that serializes all read/write operations behind a process lock - lock_test.go: shared test helpers for lock-based tests Co-Authored-By: Claude Opus 4.6 (1M context) --- go.mod | 1 + go.sum | 2 + vault/keychain_lock.go | 21 +++++ vault/lock_test.go | 66 +++++++++++++++ vault/lock_waiter.go | 85 +++++++++++++++++++ vault/locked_keyring.go | 176 ++++++++++++++++++++++++++++++++++++++++ vault/process_lock.go | 47 +++++++++++ vault/session_lock.go | 21 +++++ vault/sso_lock.go | 22 +++++ 9 files changed, 441 insertions(+) create mode 100644 vault/keychain_lock.go create mode 100644 vault/lock_test.go create mode 100644 vault/lock_waiter.go create mode 100644 vault/locked_keyring.go create mode 100644 vault/process_lock.go create mode 100644 vault/session_lock.go create mode 100644 vault/sso_lock.go diff --git a/go.mod b/go.mod index 717c6967..fcc411f3 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/byteness/keyring v1.9.0 github.com/charmbracelet/huh v1.0.0 github.com/charmbracelet/lipgloss v1.1.0 + github.com/gofrs/flock v0.8.1 github.com/google/go-cmp v0.7.0 github.com/mattn/go-isatty v0.0.21 github.com/mattn/go-tty v0.0.7 diff --git a/go.sum b/go.sum index e5e46070..ec69f876 100644 --- a/go.sum +++ b/go.sum @@ -116,6 +116,8 @@ github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= +github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= +github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/vault/keychain_lock.go b/vault/keychain_lock.go new file mode 100644 index 00000000..36b8e0ea --- /dev/null +++ b/vault/keychain_lock.go @@ -0,0 +1,21 @@ +package vault + +const keyringLockFilenamePrefix = "aws-vault.keyring" + +// KeyringLock coordinates keyring access across processes. +type KeyringLock = ProcessLock + +// NewDefaultKeyringLock creates a lock in the system temp directory. +// This only coordinates processes that share the same temp dir; differing TMPDIRs/users are out of scope. +func NewDefaultKeyringLock(lockKey string) KeyringLock { + return NewKeyringLock(defaultLockPath(keyringLockFilename(lockKey))) +} + +// NewKeyringLock creates a lock at the provided path. +func NewKeyringLock(path string) KeyringLock { + return NewFileLock(path) +} + +func keyringLockFilename(lockKey string) string { + return hashedLockFilename(keyringLockFilenamePrefix, lockKey) +} diff --git a/vault/lock_test.go b/vault/lock_test.go new file mode 100644 index 00000000..18a2607b --- /dev/null +++ b/vault/lock_test.go @@ -0,0 +1,66 @@ +package vault + +import ( + "context" + "time" +) + +type testLock struct { + tryResults []bool + tryCalls int + unlockCalls int + locked bool + path string + onTry func(*testLock) +} + +func (l *testLock) TryLock() (bool, error) { + l.tryCalls++ + locked := false + if l.tryCalls <= len(l.tryResults) { + locked = l.tryResults[l.tryCalls-1] + } + if locked { + l.locked = true + } + if l.onTry != nil { + l.onTry(l) + } + return locked, nil +} + +func (l *testLock) Unlock() error { + l.unlockCalls++ + l.locked = false + return nil +} + +func (l *testLock) Path() string { + if l.path != "" { + return l.path + } + return "/tmp/aws-vault.lock" +} + +type testClock struct { + now time.Time + sleepCalls int + cancelAfter int + cancel context.CancelFunc +} + +func (c *testClock) Now() time.Time { + return c.now +} + +func (c *testClock) Sleep(ctx context.Context, d time.Duration) error { + c.sleepCalls++ + c.now = c.now.Add(d) + if c.cancel != nil && c.cancelAfter > 0 && c.sleepCalls >= c.cancelAfter { + c.cancel() + } + if ctx.Err() != nil { + return ctx.Err() + } + return nil +} diff --git a/vault/lock_waiter.go b/vault/lock_waiter.go new file mode 100644 index 00000000..5cdac42a --- /dev/null +++ b/vault/lock_waiter.go @@ -0,0 +1,85 @@ +package vault + +import ( + "context" + "time" +) + +type lockLogger func(string, ...any) + +type lockWaiter struct { + lock ProcessLock + waitDelay time.Duration + logEvery time.Duration + warnAfter time.Duration + now func() time.Time + sleep func(context.Context, time.Duration) error + logf lockLogger + warnf lockLogger + warnMsg string + logMsg string + + lastLog time.Time + waitStart time.Time + warned bool +} + +func newLockWaiter( + lock ProcessLock, + warnMsg string, + logMsg string, + waitDelay time.Duration, + logEvery time.Duration, + warnAfter time.Duration, + now func() time.Time, + sleep func(context.Context, time.Duration) error, + logf lockLogger, + warnf lockLogger, +) *lockWaiter { + if now == nil { + now = time.Now + } + if sleep == nil { + sleep = func(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } + } + } + return &lockWaiter{ + lock: lock, + waitDelay: waitDelay, + logEvery: logEvery, + warnAfter: warnAfter, + now: now, + sleep: sleep, + logf: logf, + warnf: warnf, + warnMsg: warnMsg, + logMsg: logMsg, + } +} + +func (w *lockWaiter) sleepAfterMiss(ctx context.Context) error { + now := w.now() + if w.waitStart.IsZero() { + w.waitStart = now + } + if !w.warned && now.Sub(w.waitStart) >= w.warnAfter { + if w.warnf != nil { + w.warnf(w.warnMsg, w.lock.Path()) + } + w.warned = true + } + if w.logf != nil && (w.lastLog.IsZero() || now.Sub(w.lastLog) >= w.logEvery) { + w.logf(w.logMsg, w.lock.Path()) + w.lastLog = now + } + + return w.sleep(ctx, w.waitDelay) +} diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go new file mode 100644 index 00000000..084d5c03 --- /dev/null +++ b/vault/locked_keyring.go @@ -0,0 +1,176 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "sync" + "time" + + "github.com/byteness/keyring" +) + +type lockedKeyring struct { + inner keyring.Keyring + lock KeyringLock + mu sync.Mutex + + lockKey string + lockWait time.Duration + lockLog time.Duration + warnAfter time.Duration + lockNow func() time.Time + lockSleep func(context.Context, time.Duration) error + lockLogf func(string, ...any) +} + +const ( + defaultKeyringLockWaitDelay = 100 * time.Millisecond + defaultKeyringLockLogEvery = 15 * time.Second + defaultKeyringLockWarnAfter = 5 * time.Second + defaultKeyringLockTimeout = 2 * time.Minute +) + +// NewLockedKeyring wraps the provided keyring with a cross-process lock +// to serialize keyring operations. +func NewLockedKeyring(kr keyring.Keyring, lockKey string) keyring.Keyring { + return &lockedKeyring{ + inner: kr, + lock: NewDefaultKeyringLock(lockKey), + lockKey: lockKey, + } +} + +func (k *lockedKeyring) ensureLockDependencies() { + if k.lock == nil { + lockKey := k.lockKey + if lockKey == "" { + lockKey = "aws-vault" + } + k.lock = NewDefaultKeyringLock(lockKey) + } + if k.lockWait == 0 { + k.lockWait = defaultKeyringLockWaitDelay + } + if k.lockLog == 0 { + k.lockLog = defaultKeyringLockLogEvery + } + if k.warnAfter == 0 { + k.warnAfter = defaultKeyringLockWarnAfter + } + if k.lockNow == nil { + k.lockNow = time.Now + } + if k.lockSleep == nil { + k.lockSleep = func(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } + } + } + if k.lockLogf == nil { + k.lockLogf = log.Printf + } +} + +func (k *lockedKeyring) withLock(fn func() error) error { + k.ensureLockDependencies() + + k.mu.Lock() + defer k.mu.Unlock() + + waiter := newLockWaiter( + k.lock, + "Waiting for keyring lock at %s\n", + "Waiting for keyring lock at %s", + k.lockWait, + k.lockLog, + k.warnAfter, + k.lockNow, + k.lockSleep, + k.lockLogf, + func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) + }, + ) + + // The keyring.Keyring interface is not context-aware, so we cannot cancel + // in-flight keyring operations. This timeout is a safety net for the lock-wait + // loop: if the lock holder is hung (e.g. a stuck gpg subprocess in the pass + // backend), waiters will eventually give up rather than blocking indefinitely. + ctx, cancel := context.WithTimeout(context.Background(), defaultKeyringLockTimeout) + defer cancel() + + for { + locked, err := k.lock.TryLock() + if err != nil { + return err + } + if locked { + fnErr := fn() + if unlockErr := k.lock.Unlock(); unlockErr != nil { + return errors.Join(fnErr, fmt.Errorf("unlock keyring lock: %w", unlockErr)) + } + return fnErr + } + + if err = waiter.sleepAfterMiss(ctx); err != nil { + return err + } + } +} + +func (k *lockedKeyring) Get(key string) (keyring.Item, error) { + var item keyring.Item + if err := k.withLock(func() error { + var err error + item, err = k.inner.Get(key) + return err + }); err != nil { + return keyring.Item{}, err + } + return item, nil +} + +func (k *lockedKeyring) GetMetadata(key string) (keyring.Metadata, error) { + var meta keyring.Metadata + if err := k.withLock(func() error { + var err error + meta, err = k.inner.GetMetadata(key) + return err + }); err != nil { + return keyring.Metadata{}, err + } + return meta, nil +} + +func (k *lockedKeyring) Set(item keyring.Item) error { + return k.withLock(func() error { + return k.inner.Set(item) + }) +} + +func (k *lockedKeyring) Remove(key string) error { + return k.withLock(func() error { + return k.inner.Remove(key) + }) +} + +func (k *lockedKeyring) Keys() ([]string, error) { + var keys []string + if err := k.withLock(func() error { + var err error + keys, err = k.inner.Keys() + return err + }); err != nil { + return nil, err + } + return keys, nil +} diff --git a/vault/process_lock.go b/vault/process_lock.go new file mode 100644 index 00000000..0582fe3c --- /dev/null +++ b/vault/process_lock.go @@ -0,0 +1,47 @@ +package vault + +import ( + "crypto/sha256" + "fmt" + "os" + "path/filepath" + + "github.com/gofrs/flock" +) + +// ProcessLock coordinates work across processes. +type ProcessLock interface { + TryLock() (bool, error) + Unlock() error + Path() string +} + +type fileProcessLock struct { + lock *flock.Flock +} + +// NewFileLock creates a lock at the provided path. +func NewFileLock(path string) ProcessLock { + return &fileProcessLock{lock: flock.New(path)} +} + +func (l *fileProcessLock) TryLock() (bool, error) { + return l.lock.TryLock() +} + +func (l *fileProcessLock) Unlock() error { + return l.lock.Unlock() +} + +func (l *fileProcessLock) Path() string { + return l.lock.Path() +} + +func defaultLockPath(filename string) string { + return filepath.Join(os.TempDir(), filename) +} + +func hashedLockFilename(prefix, key string) string { + sum := sha256.Sum256([]byte(key)) + return fmt.Sprintf("%s.%x.lock", prefix, sum) +} diff --git a/vault/session_lock.go b/vault/session_lock.go new file mode 100644 index 00000000..5d77e240 --- /dev/null +++ b/vault/session_lock.go @@ -0,0 +1,21 @@ +package vault + +const sessionLockFilenamePrefix = "aws-vault.session" + +// SessionCacheLock coordinates session cache refreshes across processes. +type SessionCacheLock = ProcessLock + +// NewDefaultSessionCacheLock creates a lock in the system temp directory. +// This only coordinates processes that share the same temp dir; differing TMPDIRs/users are out of scope. +func NewDefaultSessionCacheLock(lockKey string) SessionCacheLock { + return NewSessionCacheLock(defaultLockPath(sessionLockFilename(lockKey))) +} + +// NewSessionCacheLock creates a lock at the provided path. +func NewSessionCacheLock(path string) SessionCacheLock { + return NewFileLock(path) +} + +func sessionLockFilename(lockKey string) string { + return hashedLockFilename(sessionLockFilenamePrefix, lockKey) +} diff --git a/vault/sso_lock.go b/vault/sso_lock.go new file mode 100644 index 00000000..06a2868c --- /dev/null +++ b/vault/sso_lock.go @@ -0,0 +1,22 @@ +package vault + +const ssoLockFilenamePrefix = "aws-vault.sso" + +// SSOTokenLock coordinates the SSO device flow across processes. +type SSOTokenLock = ProcessLock + +// NewDefaultSSOTokenLock creates a lock in the system temp directory keyed by startURL. +// Processes sharing the same StartURL serialize; different StartURLs lock independently. +// This only coordinates processes that share the same temp dir; differing TMPDIRs/users are out of scope. +func NewDefaultSSOTokenLock(startURL string) SSOTokenLock { + return NewSSOTokenLock(defaultLockPath(ssoLockFilename(startURL))) +} + +// NewSSOTokenLock creates a lock at the provided path. +func NewSSOTokenLock(path string) SSOTokenLock { + return NewFileLock(path) +} + +func ssoLockFilename(startURL string) string { + return hashedLockFilename(ssoLockFilenamePrefix, startURL) +} From fa92e9fca09ff7cc4a16231bd3a660eac3b19c45 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Fri, 20 Mar 2026 09:23:58 -0400 Subject: [PATCH 02/30] feat: add parallel-safe mode for SSO and session caching Wire the lock infrastructure into the SSO browser flow and session cache so only one aws-vault process refreshes credentials at a time. - SSO token lock: serializes OIDC browser auth per StartURL so concurrent processes don't each open a browser tab. - Session cache lock: serializes cache writes so concurrent processes don't race on "item already exists" keyring errors. - Keyring lock: wraps the keyring with a LockedKeyring decorator when --parallel-safe is enabled to serialize all keyring operations. - New --parallel-safe flag / AWS_VAULT_PARALLEL_SAFE env var threaded through exec, export, login, and rotate commands. - NewTempCredentialsProviderWithOptions for opt-in parallel safety. Co-Authored-By: Claude Opus 4.6 (1M context) --- cli/exec.go | 11 +- cli/export.go | 10 +- cli/global.go | 12 + cli/rotate.go | 16 +- vault/cachedsessionprovider.go | 150 ++++++++- vault/cachedsessionprovider_lock_test.go | 291 +++++++++++++++++ vault/ssorolecredentialsprovider.go | 175 +++++++++- vault/ssorolecredentialsprovider_lock_test.go | 303 ++++++++++++++++++ vault/vault.go | 71 +++- vault/vault_test.go | 39 ++- 10 files changed, 1039 insertions(+), 39 deletions(-) create mode 100644 vault/cachedsessionprovider_lock_test.go create mode 100644 vault/ssorolecredentialsprovider_lock_test.go diff --git a/cli/exec.go b/cli/exec.go index a2c21f46..7d9f3895 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -34,6 +34,7 @@ type ExecCommandInput struct { NoSession bool UseStdout bool ShowHelpMessages bool + ParallelSafe bool } func (input ExecCommandInput) validate() error { @@ -121,6 +122,7 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) { StringsVar(&input.Args) cmd.Action(func(c *kingpin.ParseContext) (err error) { + input.ParallelSafe = a.ParallelSafe input.Config.MfaPromptMethod = a.PromptDriver(hasBackgroundServer(input)) input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration input.Config.AssumeRoleDuration = input.SessionDuration @@ -155,6 +157,7 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) { Config: input.Config, SessionDuration: input.SessionDuration, NoSession: input.NoSession, + ParallelSafe: input.ParallelSafe, } err = ExportCommand(exportCommandInput, f, keyring) @@ -185,7 +188,13 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke return 0, fmt.Errorf("Error loading config: %w", err) } - credsProvider, err := vault.NewTempCredentialsProvider(config, &vault.CredentialKeyring{Keyring: keyring}, input.NoSession, false) + credsProvider, err := vault.NewTempCredentialsProviderWithOptions( + config, + &vault.CredentialKeyring{Keyring: keyring}, + input.NoSession, + false, + vault.TempCredentialsOptions{ParallelSafe: input.ParallelSafe}, + ) if err != nil { return 0, fmt.Errorf("Error getting temporary credentials: %w", err) } diff --git a/cli/export.go b/cli/export.go index d5723eb0..52ee6032 100644 --- a/cli/export.go +++ b/cli/export.go @@ -23,6 +23,7 @@ type ExportCommandInput struct { SessionDuration time.Duration NoSession bool UseStdout bool + ParallelSafe bool } var ( @@ -66,6 +67,7 @@ func ConfigureExportCommand(app *kingpin.Application, a *AwsVault) { StringVar(&input.ProfileName) cmd.Action(func(c *kingpin.ParseContext) (err error) { + input.ParallelSafe = a.ParallelSafe input.Config.MfaPromptMethod = a.PromptDriver(false) input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration input.Config.AssumeRoleDuration = input.SessionDuration @@ -108,7 +110,13 @@ func ExportCommand(input ExportCommandInput, f *vault.ConfigFile, keyring keyrin } ckr := &vault.CredentialKeyring{Keyring: keyring} - credsProvider, err := vault.NewTempCredentialsProvider(config, ckr, input.NoSession, false) + credsProvider, err := vault.NewTempCredentialsProviderWithOptions( + config, + ckr, + input.NoSession, + false, + vault.TempCredentialsOptions{ParallelSafe: input.ParallelSafe}, + ) if err != nil { return fmt.Errorf("Error getting temporary credentials: %w", err) } diff --git a/cli/global.go b/cli/global.go index 0a2f618c..05bffee7 100644 --- a/cli/global.go +++ b/cli/global.go @@ -37,6 +37,7 @@ type AwsVault struct { KeyringConfig keyring.Config KeyringBackend string promptDriver string + ParallelSafe bool keyringImpl keyring.Keyring awsConfigFile *vault.ConfigFile @@ -77,6 +78,13 @@ func (a *AwsVault) Keyring() (keyring.Keyring, error) { if err != nil { return nil, err } + if a.ParallelSafe { + lockKey := a.KeyringConfig.KeychainName + if lockKey == "" { + lockKey = "aws-vault" + } + a.keyringImpl = vault.NewLockedKeyring(a.keyringImpl, lockKey) + } } return a.keyringImpl, nil @@ -201,6 +209,10 @@ func ConfigureGlobals(app *kingpin.Application) *AwsVault { Envar("AWS_VAULT_BIOMETRICS"). BoolVar(&a.UseBiometrics) + app.Flag("parallel-safe", "Enable cross-process locking for keychain and cached credentials"). + Envar("AWS_VAULT_PARALLEL_SAFE"). + BoolVar(&a.ParallelSafe) + app.PreAction(func(c *kingpin.ParseContext) error { if !a.Debug { log.SetOutput(io.Discard) diff --git a/cli/rotate.go b/cli/rotate.go index 97a4462e..751d897d 100644 --- a/cli/rotate.go +++ b/cli/rotate.go @@ -14,9 +14,10 @@ import ( ) type RotateCommandInput struct { - NoSession bool - ProfileName string - Config vault.ProfileConfig + NoSession bool + ProfileName string + Config vault.ProfileConfig + ParallelSafe bool } func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) { @@ -34,6 +35,7 @@ func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) { StringVar(&input.ProfileName) cmd.Action(func(c *kingpin.ParseContext) (err error) { + input.ParallelSafe = a.ParallelSafe input.Config.MfaPromptMethod = a.PromptDriver(false) f, err := a.AwsConfigFile() @@ -97,7 +99,13 @@ func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyrin credsProvider = vault.NewMasterCredentialsProvider(ckr, config.ProfileName) } else { // Can't always disable sessions completely, might need to use session for MFA-Protected API Access - credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, input.NoSession, true) + credsProvider, err = vault.NewTempCredentialsProviderWithOptions( + config, + ckr, + input.NoSession, + true, + vault.TempCredentialsOptions{ParallelSafe: input.ParallelSafe}, + ) if err != nil { return fmt.Errorf("Error getting temporary credentials: %w", err) } diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index 1a382d6b..ca22d21a 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -2,7 +2,10 @@ package vault import ( "context" + "errors" + "fmt" "log" + "os" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -21,23 +24,154 @@ type CachedSessionProvider struct { SessionProvider StsSessionProvider Keyring *SessionKeyring ExpiryWindow time.Duration + UseSessionLock bool + sessionLock SessionCacheLock + sessionLockWait time.Duration + sessionLockLog time.Duration + sessionNow func() time.Time + sessionSleep func(context.Context, time.Duration) error + sessionLogf func(string, ...any) +} + +const ( + defaultSessionLockWaitDelay = 100 * time.Millisecond + defaultSessionLockLogEvery = 15 * time.Second + defaultSessionLockWarnAfter = 5 * time.Second +) + +func defaultSessionSleep(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func (p *CachedSessionProvider) ensureSessionDependencies() { + if p.sessionLock == nil { + p.sessionLock = NewDefaultSessionCacheLock(p.SessionKey.StringForMatching()) + } + if p.sessionLockWait == 0 { + p.sessionLockWait = defaultSessionLockWaitDelay + } + if p.sessionLockLog == 0 { + p.sessionLockLog = defaultSessionLockLogEvery + } + if p.sessionNow == nil { + p.sessionNow = time.Now + } + if p.sessionSleep == nil { + p.sessionSleep = defaultSessionSleep + } + if p.sessionLogf == nil { + p.sessionLogf = log.Printf + } } func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { - creds, err := p.Keyring.Get(p.SessionKey) + creds, cached, err := p.getCachedSession() + if err == nil && cached { + return creds, nil + } + + if !p.UseSessionLock { + return p.getSessionWithoutLock(ctx) + } + + p.ensureSessionDependencies() + + return p.getSessionWithLock(ctx) +} + +func (p *CachedSessionProvider) getCachedSession() (creds *ststypes.Credentials, cached bool, err error) { + creds, err = p.Keyring.Get(p.SessionKey) + if err != nil { + return nil, false, err + } + if time.Until(*creds.Expiration) < p.ExpiryWindow { + return nil, false, nil + } + log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String()) + return creds, true, nil +} - if err != nil || time.Until(*creds.Expiration) < p.ExpiryWindow { - // lookup missed, we need to create a new one. - creds, err = p.SessionProvider.RetrieveStsCredentials(ctx) +func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststypes.Credentials, error) { + waiter := newLockWaiter( + p.sessionLock, + "Waiting for session lock at %s\n", + "Waiting for session lock at %s", + p.sessionLockWait, + p.sessionLockLog, + defaultSessionLockWarnAfter, + p.sessionNow, + p.sessionSleep, + p.sessionLogf, + func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) + }, + ) + + for { + creds, cached, err := p.getCachedSession() + if err == nil && cached { + return creds, nil + } + if ctx.Err() != nil { + return nil, ctx.Err() + } + + locked, err := p.sessionLock.TryLock() if err != nil { return nil, err } - err = p.Keyring.Set(p.SessionKey, creds) - if err != nil { + if locked { + return p.doLockedSessionWork(ctx) + } + if err = waiter.sleepAfterMiss(ctx); err != nil { return nil, err } - } else { - log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String()) + } +} + +func (p *CachedSessionProvider) doLockedSessionWork(ctx context.Context) (creds *ststypes.Credentials, err error) { + defer func() { + if unlockErr := p.sessionLock.Unlock(); unlockErr != nil { + err = errors.Join(err, fmt.Errorf("unlock session lock: %w", unlockErr)) + } + }() + + creds, cached, cacheErr := p.getCachedSession() + if cacheErr == nil && cached { + return creds, nil + } + + creds, err = p.SessionProvider.RetrieveStsCredentials(ctx) + if err != nil { + return nil, err + } + if err = p.Keyring.Set(p.SessionKey, creds); err != nil { + return nil, err + } + + return creds, nil +} + +func (p *CachedSessionProvider) getSessionWithoutLock(ctx context.Context) (*ststypes.Credentials, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + creds, err := p.SessionProvider.RetrieveStsCredentials(ctx) + if err != nil { + return nil, err + } + + if err = p.Keyring.Set(p.SessionKey, creds); err != nil { + return nil, err } return creds, nil diff --git a/vault/cachedsessionprovider_lock_test.go b/vault/cachedsessionprovider_lock_test.go new file mode 100644 index 00000000..e5ef6ead --- /dev/null +++ b/vault/cachedsessionprovider_lock_test.go @@ -0,0 +1,291 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/byteness/keyring" +) + +type testSessionProvider struct { + creds *types.Credentials + calls int + onRetrieve func() +} + +func (p *testSessionProvider) RetrieveStsCredentials(context.Context) (*types.Credentials, error) { + p.calls++ + if p.onRetrieve != nil { + p.onRetrieve() + } + return p.creds, nil +} + +func (p *testSessionProvider) Retrieve(context.Context) (aws.Credentials, error) { + return aws.Credentials{}, nil +} + +type lockCheckingKeyring struct { + keyring.Keyring + setCalls int + setLock *testLock +} + +func (k *lockCheckingKeyring) Set(item keyring.Item) error { + k.setCalls++ + if k.setLock != nil && !k.setLock.locked { + return fmt.Errorf("lock not held during cache set") + } + return k.Keyring.Set(item) +} + +func newTestSessionKey() SessionMetadata { + return SessionMetadata{ + Type: "sso.GetRoleCredentials", + ProfileName: "test-profile", + MfaSerial: "https://sso.example", + } +} + +func newTestCreds(expires time.Time) *types.Credentials { + return &types.Credentials{ + AccessKeyId: aws.String("AKIATEST"), + SecretAccessKey: aws.String("secret"), + SessionToken: aws.String("token"), + Expiration: aws.Time(expires), + } +} + +func TestCachedSession_CacheHit_NoLock(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + if err := sk.Set(key, creds); err != nil { + t.Fatalf("set cache: %v", err) + } + + lock := &testLock{} + provider := &testSessionProvider{ + onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called on cache hit") }, + } + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + UseSessionLock: true, + sessionLock: lock, + } + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } + if provider.calls != 0 { + t.Fatalf("expected no provider calls, got %d", provider.calls) + } +} + +func TestCachedSession_LockDisabled_SkipsLock(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + lock := &testLock{tryResults: []bool{true}} + provider := &testSessionProvider{creds: creds} + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + UseSessionLock: false, + sessionLock: lock, + } + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } + if provider.calls != 1 { + t.Fatalf("expected 1 provider call, got %d", provider.calls) + } + if _, err := sk.Get(key); err != nil { + t.Fatalf("expected cached credentials, got %v", err) + } +} + +func TestCachedSession_LockMiss_ThenCacheHit_NoRefresh(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + lock := &testLock{tryResults: []bool{false}} + + provider := &testSessionProvider{ + onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called when cache fills while waiting") }, + } + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + UseSessionLock: true, + sessionLock: lock, + sessionLockWait: 5 * time.Second, + } + p.sessionSleep = func(ctx context.Context, d time.Duration) error { + return sk.Set(key, creds) + } + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.tryCalls != 1 { + t.Fatalf("expected 1 lock attempt, got %d", lock.tryCalls) + } + if provider.calls != 0 { + t.Fatalf("expected no provider calls, got %d", provider.calls) + } +} + +func TestCachedSession_LockAcquired_RecheckCache(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + lock := &testLock{tryResults: []bool{true}} + lock.onTry = func(l *testLock) { + if l.locked { + _ = sk.Set(key, creds) + } + } + + provider := &testSessionProvider{ + onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called when cache fills after lock") }, + } + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + UseSessionLock: true, + sessionLock: lock, + } + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } + if provider.calls != 0 { + t.Fatalf("expected no provider calls, got %d", provider.calls) + } +} + +func TestCachedSession_LockHeldThroughCacheSet(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + lock := &testLock{tryResults: []bool{true}} + wrappedKeyring := &lockCheckingKeyring{ + Keyring: keyring.NewArrayKeyring(nil), + setLock: lock, + } + sk := &SessionKeyring{Keyring: wrappedKeyring} + provider := &testSessionProvider{creds: creds} + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + UseSessionLock: true, + sessionLock: lock, + } + + _, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if wrappedKeyring.setCalls != 1 { + t.Fatalf("expected cache set once, got %d", wrappedKeyring.setCalls) + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } + if provider.calls != 1 { + t.Fatalf("expected 1 provider call, got %d", provider.calls) + } +} + +func TestCachedSession_LockWaitLogs(t *testing.T) { + lock := &testLock{tryResults: []bool{false, false, false, false}} + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + key := newTestSessionKey() + provider := &testSessionProvider{} + + ctx, cancel := context.WithCancel(context.Background()) + clock := &testClock{now: time.Unix(0, 0), cancel: cancel, cancelAfter: 4} + var logTimes []time.Time + + p := &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: sk, + ExpiryWindow: 0, + UseSessionLock: true, + sessionLock: lock, + sessionLockWait: 5 * time.Second, + sessionLockLog: 15 * time.Second, + sessionNow: clock.Now, + } + p.sessionSleep = clock.Sleep + p.sessionLogf = func(string, ...any) { + logTimes = append(logTimes, clock.Now()) + } + + _, err := p.RetrieveStsCredentials(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation, got %v", err) + } + if len(logTimes) != 2 { + t.Fatalf("expected 2 log entries, got %d", len(logTimes)) + } + if !logTimes[0].Equal(time.Unix(0, 0)) { + t.Fatalf("unexpected first log time: %s", logTimes[0]) + } + if !logTimes[1].Equal(time.Unix(15, 0)) { + t.Fatalf("unexpected second log time: %s", logTimes[1]) + } +} diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 9854f4ee..722382f4 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -9,7 +9,6 @@ import ( "os" "time" - "github.com/byteness/keyring" "github.com/aws/aws-sdk-go-v2/aws" awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/service/sso" @@ -17,6 +16,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ssooidc" ssooidctypes "github.com/aws/aws-sdk-go-v2/service/ssooidc/types" ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/byteness/keyring" "github.com/skratchdot/open-golang/open" ) @@ -28,19 +28,69 @@ type OIDCTokenCacher interface { // SSORoleCredentialsProvider creates temporary credentials for an SSO Role. type SSORoleCredentialsProvider struct { - OIDCClient *ssooidc.Client - OIDCTokenCache OIDCTokenCacher - StartURL string - SSOClient *sso.Client - AccountID string - RoleName string - UseStdout bool + OIDCClient *ssooidc.Client + OIDCTokenCache OIDCTokenCacher + StartURL string + SSOClient *sso.Client + AccountID string + RoleName string + UseStdout bool + UseSSOTokenLock bool + ssoTokenLock SSOTokenLock + ssoLockWait time.Duration + ssoLockLog time.Duration + ssoNow func() time.Time + ssoSleep func(context.Context, time.Duration) error + ssoLogf func(string, ...any) + newOIDCTokenFn func(context.Context) (*ssooidc.CreateTokenOutput, error) } func millisecondsTimeValue(v int64) time.Time { return time.Unix(0, v*int64(time.Millisecond)) } +const ( + defaultSSOLockWaitDelay = 100 * time.Millisecond + defaultSSOLockLogEvery = 15 * time.Second + defaultSSOLockWarnAfter = 5 * time.Second +) + +func defaultSSOSleep(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func (p *SSORoleCredentialsProvider) ensureSSODependencies() { + if p.ssoTokenLock == nil && !p.UseStdout && p.UseSSOTokenLock { + p.ssoTokenLock = NewDefaultSSOTokenLock(p.StartURL) + } + if p.ssoLockWait == 0 { + p.ssoLockWait = defaultSSOLockWaitDelay + } + if p.ssoLockLog == 0 { + p.ssoLockLog = defaultSSOLockLogEvery + } + if p.ssoNow == nil { + p.ssoNow = time.Now + } + if p.ssoSleep == nil { + p.ssoSleep = defaultSSOSleep + } + if p.ssoLogf == nil { + p.ssoLogf = log.Printf + } + if p.newOIDCTokenFn == nil { + p.newOIDCTokenFn = p.newOIDCToken + } +} + // Retrieve generates a new set of temporary credentials using SSO GetRoleCredentials. func (p *SSORoleCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { creds, err := p.getRoleCredentials(ctx) @@ -58,6 +108,8 @@ func (p *SSORoleCredentialsProvider) Retrieve(ctx context.Context) (aws.Credenti } func (p *SSORoleCredentialsProvider) getRoleCredentials(ctx context.Context) (*ssotypes.RoleCredentials, error) { + p.ensureSSODependencies() + token, cached, err := p.getOIDCToken(ctx) if err != nil { return nil, err @@ -113,27 +165,118 @@ func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredemtials(ctx cont } func (p *SSORoleCredentialsProvider) getOIDCToken(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { + p.ensureSSODependencies() + + token, cached, err = p.getCachedOIDCToken() + if err != nil || token != nil { + return token, cached, err + } + + if p.UseStdout { + return p.createAndCacheOIDCToken(ctx) + } + + if !p.UseSSOTokenLock { + return p.createAndCacheOIDCToken(ctx) + } + + return p.getOIDCTokenWithLock(ctx) +} + +func (p *SSORoleCredentialsProvider) getCachedOIDCToken() (token *ssooidc.CreateTokenOutput, cached bool, err error) { + if p.OIDCTokenCache == nil { + return nil, false, nil + } + + token, err = p.OIDCTokenCache.Get(p.StartURL) + if err != nil && err != keyring.ErrKeyNotFound { + return nil, false, err + } + if token != nil { + return token, true, nil + } + + return nil, false, nil +} + +func (p *SSORoleCredentialsProvider) createAndCacheOIDCToken(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { + token, err = p.newOIDCTokenFn(ctx) + if err != nil { + return nil, false, err + } + if p.OIDCTokenCache != nil { - token, err = p.OIDCTokenCache.Get(p.StartURL) - if err != nil && err != keyring.ErrKeyNotFound { + if err = p.OIDCTokenCache.Set(p.StartURL, token); err != nil { return nil, false, err } - if token != nil { - return token, true, nil + } + + return token, false, nil +} + +func (p *SSORoleCredentialsProvider) getOIDCTokenWithLock(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { + waiter := newLockWaiter( + p.ssoTokenLock, + "Waiting for SSO lock at %s\n", + "Waiting for SSO lock at %s", + p.ssoLockWait, + p.ssoLockLog, + defaultSSOLockWarnAfter, + p.ssoNow, + p.ssoSleep, + p.ssoLogf, + func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) + }, + ) + + for { + token, cached, err = p.getCachedOIDCToken() + if err != nil || token != nil { + return token, cached, err + } + if ctx.Err() != nil { + return nil, false, ctx.Err() + } + + locked, err := p.ssoTokenLock.TryLock() + if err != nil { + return nil, false, err + } + if locked { + return p.doLockedOIDCTokenWork(ctx) + } + + if err = waiter.sleepAfterMiss(ctx); err != nil { + return nil, false, err } } - token, err = p.newOIDCToken(ctx) +} + +func (p *SSORoleCredentialsProvider) doLockedOIDCTokenWork(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { + defer func() { + if unlockErr := p.ssoTokenLock.Unlock(); unlockErr != nil { + err = errors.Join(err, fmt.Errorf("unlock SSO token lock: %w", unlockErr)) + } + }() + + token, cached, err = p.getCachedOIDCToken() + if err != nil || token != nil { + return token, cached, err + } + + token, err = p.newOIDCTokenFn(ctx) if err != nil { return nil, false, err } if p.OIDCTokenCache != nil { - err = p.OIDCTokenCache.Set(p.StartURL, token) - if err != nil { + if err = p.OIDCTokenCache.Set(p.StartURL, token); err != nil { return nil, false, err } } - return token, false, err + + return token, false, nil } func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc.CreateTokenOutput, error) { diff --git a/vault/ssorolecredentialsprovider_lock_test.go b/vault/ssorolecredentialsprovider_lock_test.go new file mode 100644 index 00000000..0729a2b6 --- /dev/null +++ b/vault/ssorolecredentialsprovider_lock_test.go @@ -0,0 +1,303 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssooidc" + "github.com/byteness/keyring" +) + +type testTokenCache struct { + token *ssooidc.CreateTokenOutput + setCalls int + setLock *testLock +} + +func (c *testTokenCache) Get(string) (*ssooidc.CreateTokenOutput, error) { + if c.token == nil { + return nil, keyring.ErrKeyNotFound + } + return c.token, nil +} + +func (c *testTokenCache) Set(_ string, token *ssooidc.CreateTokenOutput) error { + c.setCalls++ + if c.setLock != nil && !c.setLock.locked { + return fmt.Errorf("lock not held during cache set") + } + c.token = token + return nil +} + +func (c *testTokenCache) Remove(string) error { + c.token = nil + return nil +} + +func TestGetOIDCToken_CacheHit_NoLock(t *testing.T) { + cachedToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("cached")} + cache := &testTokenCache{token: cachedToken} + lock := &testLock{} + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + UseSSOTokenLock: true, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called on cache hit") + return nil, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(context.Context, time.Duration) error { return nil } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !cached { + t.Fatalf("expected cached token") + } + if token != cachedToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } +} + +func TestGetOIDCToken_LockDisabled_SkipsLock(t *testing.T) { + freshToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("fresh")} + cache := &testTokenCache{} + lock := &testLock{tryResults: []bool{true}} + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + UseSSOTokenLock: false, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + return freshToken, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(context.Context, time.Duration) error { return nil } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cached { + t.Fatalf("expected non-cached token") + } + if token != freshToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } + if cache.setCalls != 1 { + t.Fatalf("expected cache set once, got %d", cache.setCalls) + } +} + +func TestGetOIDCToken_LockMiss_ThenCacheHit_NoLock(t *testing.T) { + cachedToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("cached")} + cache := &testTokenCache{} + lock := &testLock{tryResults: []bool{false}} + clock := &testClock{now: time.Unix(0, 0)} + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + UseSSOTokenLock: true, + ssoLockWait: 5 * time.Second, + ssoNow: clock.Now, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called when cache fills while waiting") + return nil, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(ctx context.Context, d time.Duration) error { + clock.now = clock.now.Add(d) + cache.token = cachedToken + return nil + } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !cached { + t.Fatalf("expected cached token") + } + if token != cachedToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 1 { + t.Fatalf("expected 1 lock attempt, got %d", lock.tryCalls) + } + if lock.unlockCalls != 0 { + t.Fatalf("expected no unlocks, got %d", lock.unlockCalls) + } +} + +func TestGetOIDCToken_LockAcquired_RecheckCache(t *testing.T) { + cachedToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("cached")} + cache := &testTokenCache{} + lock := &testLock{tryResults: []bool{true}} + lock.onTry = func(l *testLock) { + if l.locked { + cache.token = cachedToken + } + } + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + UseSSOTokenLock: true, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called when cache is filled after lock") + return nil, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(context.Context, time.Duration) error { return nil } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !cached { + t.Fatalf("expected cached token") + } + if token != cachedToken { + t.Fatalf("unexpected token returned") + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} + +func TestGetOIDCToken_LockHeldThroughCacheSet(t *testing.T) { + freshToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("fresh")} + lock := &testLock{tryResults: []bool{true}} + cache := &testTokenCache{setLock: lock} + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + UseSSOTokenLock: true, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + return freshToken, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(context.Context, time.Duration) error { return nil } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cached { + t.Fatalf("expected non-cached token") + } + if token != freshToken { + t.Fatalf("unexpected token returned") + } + if cache.setCalls != 1 { + t.Fatalf("expected cache set once, got %d", cache.setCalls) + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} + +func TestGetOIDCToken_UseStdout_SkipsLock(t *testing.T) { + freshToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("fresh")} + lock := &testLock{tryResults: []bool{true}} + cache := &testTokenCache{} + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: true, + UseSSOTokenLock: true, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + return freshToken, nil + } + p.ssoLogf = func(string, ...any) {} + p.ssoSleep = func(context.Context, time.Duration) error { return nil } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cached { + t.Fatalf("expected non-cached token") + } + if token != freshToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } +} + +func TestGetOIDCToken_LockWaitLogs(t *testing.T) { + lock := &testLock{tryResults: []bool{false, false, false, false}} + cache := &testTokenCache{} + ctx, cancel := context.WithCancel(context.Background()) + clock := &testClock{now: time.Unix(0, 0), cancel: cancel, cancelAfter: 4} + var logTimes []time.Time + + p := &SSORoleCredentialsProvider{ + OIDCTokenCache: cache, + StartURL: "https://sso.example", + ssoTokenLock: lock, + UseStdout: false, + UseSSOTokenLock: true, + ssoLockWait: 5 * time.Second, + ssoLockLog: 15 * time.Second, + ssoNow: clock.Now, + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called when lock never acquired") + return nil, nil + } + p.ssoSleep = clock.Sleep + p.ssoLogf = func(string, ...any) { + logTimes = append(logTimes, clock.Now()) + } + + _, _, err := p.getOIDCToken(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation, got %v", err) + } + if len(logTimes) != 2 { + t.Fatalf("expected 2 log entries, got %d", len(logTimes)) + } + if !logTimes[0].Equal(time.Unix(0, 0)) { + t.Fatalf("unexpected first log time: %s", logTimes[0]) + } + if !logTimes[1].Equal(time.Unix(15, 0)) { + t.Fatalf("unexpected second log time: %s", logTimes[1]) + } +} diff --git a/vault/vault.go b/vault/vault.go index 786fad6e..74358605 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -230,6 +230,8 @@ type TempCredentialsCreator struct { DisableCache bool // DisableSessionsForProfile is a profile for which sessions should not be used DisableSessionsForProfile string + // ParallelSafe enables cross-process locking for cached credentials. + ParallelSafe bool chainedMfa string } @@ -253,17 +255,19 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, if err != nil { return nil, err } + sourcecredsProvider = t.applyParallelSafety(sourcecredsProvider) if hasStoredCredentials || !config.HasRole() { if canUseGetSessionToken, reason := t.canUseGetSessionToken(config); !canUseGetSessionToken { log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) if !config.HasRole() { - return sourcecredsProvider, nil + return t.applyParallelSafety(sourcecredsProvider), nil } } t.chainedMfa = config.MfaSerial log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) sourcecredsProvider, err = NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + sourcecredsProvider = t.applyParallelSafety(sourcecredsProvider) if !config.HasRole() || err != nil { return sourcecredsProvider, err } @@ -275,7 +279,11 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, config.MfaSerial = "" } log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config)) - return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + provider, err := NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + if err != nil { + return nil, err + } + return t.applyParallelSafety(provider), nil } if isMasterCredentialsProvider(sourcecredsProvider) { @@ -283,12 +291,16 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, if canUseGetSessionToken { t.chainedMfa = config.MfaSerial log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) - return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + provider, err := NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + if err != nil { + return nil, err + } + return t.applyParallelSafety(provider), nil } log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) } - return sourcecredsProvider, nil + return t.applyParallelSafety(sourcecredsProvider), nil } func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) { @@ -303,22 +315,54 @@ func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (a if config.HasSSOStartURL() { log.Printf("profile %s: using SSO role credentials", config.ProfileName) - return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache) + provider, err := NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache) + if err != nil { + return nil, err + } + return t.applyParallelSafety(provider), nil } if config.HasWebIdentity() { log.Printf("profile %s: using web identity", config.ProfileName) - return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache) + provider, err := NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache) + if err != nil { + return nil, err + } + return t.applyParallelSafety(provider), nil } if config.HasCredentialProcess() { log.Printf("profile %s: using credential process", config.ProfileName) - return NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache) + provider, err := NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache) + if err != nil { + return nil, err + } + return t.applyParallelSafety(provider), nil } return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName) } +func (t *TempCredentialsCreator) applyParallelSafety(provider aws.CredentialsProvider) aws.CredentialsProvider { + if !t.ParallelSafe { + return provider + } + + if cached, ok := provider.(*CachedSessionProvider); ok { + cached.UseSessionLock = true + if ssoProvider, ok := cached.SessionProvider.(*SSORoleCredentialsProvider); ok { + ssoProvider.UseSSOTokenLock = true + } + return provider + } + + if ssoProvider, ok := provider.(*SSORoleCredentialsProvider); ok { + ssoProvider.UseSSOTokenLock = true + } + + return provider +} + // canUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason func (t *TempCredentialsCreator) canUseGetSessionToken(c *ProfileConfig) (bool, string) { if t.DisableSessions { @@ -359,12 +403,23 @@ func mfaDetails(mfaChained bool, config *ProfileConfig) string { return "" } -// NewTempCredentialsProvider creates a credential provider for the given config +// TempCredentialsOptions controls how temporary credential providers are created. +type TempCredentialsOptions struct { + ParallelSafe bool +} + +// NewTempCredentialsProvider creates a credential provider for the given config. func NewTempCredentialsProvider(config *ProfileConfig, keyring *CredentialKeyring, disableSessions bool, disableCache bool) (aws.CredentialsProvider, error) { + return NewTempCredentialsProviderWithOptions(config, keyring, disableSessions, disableCache, TempCredentialsOptions{}) +} + +// NewTempCredentialsProviderWithOptions creates a credential provider for the given config with options. +func NewTempCredentialsProviderWithOptions(config *ProfileConfig, keyring *CredentialKeyring, disableSessions bool, disableCache bool, options TempCredentialsOptions) (aws.CredentialsProvider, error) { t := TempCredentialsCreator{ Keyring: keyring, DisableSessions: disableSessions, DisableCache: disableCache, + ParallelSafe: options.ParallelSafe, } return t.GetProviderForProfile(config) } diff --git a/vault/vault_test.go b/vault/vault_test.go index 75812bab..3615d7e8 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -4,8 +4,8 @@ import ( "os" "testing" - "github.com/byteness/keyring" "github.com/byteness/aws-vault/v7/vault" + "github.com/byteness/keyring" ) func TestUsageWebIdentityExample(t *testing.T) { @@ -123,3 +123,40 @@ sso_registration_scopes=sso:account:access t.Fatalf("Expected AccountID to be 2160xxxx, got %s", ssoProvider.AccountID) } } + +func TestTempCredentialsProviderParallelSafeSSOLocks(t *testing.T) { + config := &vault.ProfileConfig{ + ProfileName: "sso-profile", + SSOStartURL: "https://sso.example/start", + SSORegion: "us-east-1", + SSOAccountID: "123456789012", + SSORoleName: "Role", + } + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})} + provider, err := vault.NewTempCredentialsProviderWithOptions( + config, + ckr, + false, + false, + vault.TempCredentialsOptions{ParallelSafe: true}, + ) + if err != nil { + t.Fatal(err) + } + + cached, ok := provider.(*vault.CachedSessionProvider) + if !ok { + t.Fatalf("Expected CachedSessionProvider, got %T", provider) + } + if !cached.UseSessionLock { + t.Fatalf("Expected UseSessionLock to be true") + } + ssoProvider, ok := cached.SessionProvider.(*vault.SSORoleCredentialsProvider) + if !ok { + t.Fatalf("Expected SSORoleCredentialsProvider, got %T", cached.SessionProvider) + } + if !ssoProvider.UseSSOTokenLock { + t.Fatalf("Expected UseSSOTokenLock to be true") + } +} From 191e52638fe1274d79c0025d605efb5ad141abe1 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Fri, 20 Mar 2026 09:24:26 -0400 Subject: [PATCH 03/30] feat(sso): add rate-limit retry with jittered backoff When many parallel aws-vault processes hit GetRoleCredentials simultaneously, AWS returns HTTP 429 (TooManyRequests). Add a retry loop with Retry-After header support and exponential backoff with jitter so the SSO token exchange is resilient under heavy load. Gives up after 5 minutes of persistent 429s as a safety net. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/ssorolecredentialsprovider.go | 158 ++++++++++++++++-- .../ssorolecredentialsprovider_retry_test.go | 78 +++++++++ 2 files changed, 218 insertions(+), 18 deletions(-) create mode 100644 vault/ssorolecredentialsprovider_retry_test.go diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 722382f4..ed3d9d36 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -5,8 +5,11 @@ import ( "errors" "fmt" "log" + "math/rand" "net/http" "os" + "strconv" + "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -53,6 +56,13 @@ const ( defaultSSOLockWaitDelay = 100 * time.Millisecond defaultSSOLockLogEvery = 15 * time.Second defaultSSOLockWarnAfter = 5 * time.Second + // ssoRetryTimeout is a pathological safety net: if GetRoleCredentials is still + // returning 429s after this duration, give up and surface the error to the user. + ssoRetryTimeout = 5 * time.Minute + ssoRetryBase = 200 * time.Millisecond + ssoRetryMax = 5 * time.Second + ssoRetryAfterJitterMin = 1.1 + ssoRetryAfterJitterMax = 1.3 ) func defaultSSOSleep(ctx context.Context, d time.Duration) error { @@ -115,34 +125,67 @@ func (p *SSORoleCredentialsProvider) getRoleCredentials(ctx context.Context) (*s return nil, err } - resp, err := p.SSOClient.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{ - AccessToken: token.AccessToken, - AccountId: aws.String(p.AccountID), - RoleName: aws.String(p.RoleName), - }) - if err != nil { + baseDelay, maxDelay := ssoRetryBase, ssoRetryMax + deadline := p.ssoNow().Add(ssoRetryTimeout) + attempt := 0 + rateLimitCount := 0 + var maxRetryAfterSeen time.Duration + for { + attempt++ + resp, err := p.SSOClient.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{ + AccessToken: token.AccessToken, + AccountId: aws.String(p.AccountID), + RoleName: aws.String(p.RoleName), + }) + if err == nil { + log.Printf("Got credentials %s for SSO role %s (account: %s), expires in %s", FormatKeyForDisplay(*resp.RoleCredentials.AccessKeyId), p.RoleName, p.AccountID, time.Until(millisecondsTimeValue(resp.RoleCredentials.Expiration)).String()) + return resp.RoleCredentials, nil + } + if cached && p.OIDCTokenCache != nil { var rspError *awshttp.ResponseError - if !errors.As(err, &rspError) { - return nil, err + if errors.As(err, &rspError) && rspError.HTTPStatusCode() == http.StatusUnauthorized { + // Cached token rejected: drop it and retry with a fresh access token. + // This should only happen once because the cache is cleared before retrying. + if err = p.OIDCTokenCache.Remove(p.StartURL); err != nil { + return nil, err + } + token, cached, err = p.getOIDCToken(ctx) + if err != nil { + return nil, err + } + attempt = 0 + continue } + } - // If the error is a 401, remove the cached oidc token and try - // again. This is a recursive call but it should only happen once - // due to the cache being cleared before retrying. - if rspError.HTTPStatusCode() == http.StatusUnauthorized { - err = p.OIDCTokenCache.Remove(p.StartURL) - if err != nil { + if isSSORateLimitError(err) { + rateLimitCount++ + remaining := time.Until(deadline) + if 0 < remaining { + var delay time.Duration + if retryAfter, ok := retryAfterFromError(err); ok { + if maxRetryAfterSeen < retryAfter { + maxRetryAfterSeen = retryAfter + } + delay = jitterRetryAfter(retryAfter) + } else { + delay = jitteredBackoff(baseDelay, maxDelay, attempt) + } + if remaining < delay { + delay = remaining + } + log.Printf("SSO rate limited for role %s (account: %s); backing off %s, attempt %d (%d 429s, max retry-after %s)", p.RoleName, p.AccountID, delay, attempt, rateLimitCount, maxRetryAfterSeen) + if err = p.ssoSleep(ctx, delay); err != nil { return nil, err } - return p.getRoleCredentials(ctx) + continue } + return nil, fmt.Errorf("SSO rate limited for role %s (account: %s) persistently for %s (%d 429s, max retry-after %s); giving up — try again later: %w", p.RoleName, p.AccountID, ssoRetryTimeout, rateLimitCount, maxRetryAfterSeen, err) } + return nil, err } - log.Printf("Got credentials %s for SSO role %s (account: %s), expires in %s", FormatKeyForDisplay(*resp.RoleCredentials.AccessKeyId), p.RoleName, p.AccountID, time.Until(millisecondsTimeValue(resp.RoleCredentials.Expiration)).String()) - - return resp.RoleCredentials, nil } func (p *SSORoleCredentialsProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { @@ -344,3 +387,82 @@ func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc return t, nil } } + +func retryAfterFromError(err error) (time.Duration, bool) { + var rspError *awshttp.ResponseError + if errors.As(err, &rspError) { + if rspError.Response != nil { + if d, ok := parseRetryAfter(rspError.Response.Header.Get("Retry-After")); ok { + return d, true + } + } + } + return 0, false +} + +func parseRetryAfter(value string) (time.Duration, bool) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return 0, false + } + if secs, err := strconv.Atoi(trimmed); err == nil { + if secs < 0 { + return 0, false + } + return time.Duration(secs) * time.Second, true + } + if t, err := http.ParseTime(trimmed); err == nil { + d := time.Until(t) + if d < 0 { + d = 0 + } + return d, true + } + return 0, false +} + +func isSSORateLimitError(err error) bool { + var tooMany *ssotypes.TooManyRequestsException + if errors.As(err, &tooMany) { + return true + } + var rspError *awshttp.ResponseError + if errors.As(err, &rspError) && rspError.HTTPStatusCode() == http.StatusTooManyRequests { + return true + } + return false +} + +func jitterRetryAfter(base time.Duration) time.Duration { + if base <= 0 { + return 0 + } + return jitterDelay(base) +} + +func jitteredBackoff(base, max time.Duration, attempt int) time.Duration { + if attempt < 1 { + attempt = 1 + } + capDelay := base << uint(attempt-1) + if capDelay > max { + capDelay = max + } + if capDelay < base { + capDelay = base + } + return jitterDelay(capDelay) +} + +func jitterDelay(base time.Duration) time.Duration { + if base <= 0 { + return 0 + } + min := ssoRetryAfterJitterMin + max := ssoRetryAfterJitterMax + if max < min { + max = min + } + factor := min + rand.Float64()*(max-min) + return time.Duration(float64(base) * factor) +} diff --git a/vault/ssorolecredentialsprovider_retry_test.go b/vault/ssorolecredentialsprovider_retry_test.go new file mode 100644 index 00000000..71bc9a20 --- /dev/null +++ b/vault/ssorolecredentialsprovider_retry_test.go @@ -0,0 +1,78 @@ +package vault + +import ( + "errors" + "net/http" + "testing" + "time" + + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + ssotypes "github.com/aws/aws-sdk-go-v2/service/sso/types" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +func TestRetryAfterFromErrorSeconds(t *testing.T) { + header := http.Header{} + header.Set("Retry-After", "120") + resp := &http.Response{StatusCode: http.StatusTooManyRequests, Header: header} + err := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: resp}, + }, + } + + delay, ok := retryAfterFromError(err) + if !ok { + t.Fatal("expected retry-after delay to be detected") + } + if delay != 120*time.Second { + t.Fatalf("expected 120s retry-after, got %s", delay) + } +} + +func TestRetryAfterFromErrorMissingHeader(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusTooManyRequests, Header: http.Header{}} + err := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: resp}, + }, + } + + delay, ok := retryAfterFromError(err) + if ok { + t.Fatalf("expected retry-after to be absent, got %s", delay) + } +} + +func TestIsSSORateLimitError(t *testing.T) { + if !isSSORateLimitError(&ssotypes.TooManyRequestsException{}) { + t.Fatal("expected TooManyRequestsException to be rate limit error") + } + + resp := &http.Response{StatusCode: http.StatusTooManyRequests} + err := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: resp}, + }, + } + if !isSSORateLimitError(err) { + t.Fatal("expected HTTP 429 response error to be rate limit error") + } + + if isSSORateLimitError(errors.New("boom")) { + t.Fatal("expected non-rate-limit error to be false") + } +} + +func TestJitterDelayRange(t *testing.T) { + base := 10 * time.Second + min := time.Duration(float64(base) * ssoRetryAfterJitterMin) + max := time.Duration(float64(base) * ssoRetryAfterJitterMax) + + for i := 0; i < 10; i++ { + delay := jitterDelay(base) + if delay < min || delay > max { + t.Fatalf("expected delay in range %s-%s, got %s", min, max, delay) + } + } +} From 5abae24865e5021080e228982965768215a8a109 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Fri, 20 Mar 2026 09:24:37 -0400 Subject: [PATCH 04/30] docs: document parallel-safe mode in USAGE.md Add a section explaining the --parallel-safe flag, what it protects against (browser storms, keyring races), its trade-offs (serialized keyring ops, all-or-nothing opt-in), and current limitations (lock wait timeout, SSO retry timeout). Co-Authored-By: Claude Opus 4.6 (1M context) --- USAGE.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/USAGE.md b/USAGE.md index 00e30ffa..899e09f3 100644 --- a/USAGE.md +++ b/USAGE.md @@ -273,6 +273,7 @@ WARNING: Use of this option runs against security best practices. It is recommen To configure the default flag values of `aws-vault` and its subcommands: * `AWS_VAULT_BACKEND`: Secret backend to use (see the flag `--backend`) * `AWS_VAULT_BIOMETRICS`: Use biometric authentication using TouchID, if supported (see the flag `--biometrics`) +* `AWS_VAULT_PARALLEL_SAFE`: Enable cross-process locking for keychain and cached credentials (see the flag `--parallel-safe`) * `AWS_VAULT_KEYCHAIN_NAME`: Name of macOS keychain to use (see the flag `--keychain`) * `AWS_VAULT_AUTO_LOGOUT`: Enable auto-logout when doing `login` (see the flag `--auto-logout`) * `AWS_VAULT_PROMPT`: Prompt driver to use (see the flag `--prompt`) @@ -634,6 +635,31 @@ role_arn=arn:aws:iam::123456789013:role/AnotherRole source_profile=Administrator-123456789012] ``` +## Parallel-safe mode + +When running many `aws-vault` processes in parallel (e.g. Terraform with hundreds of `credential_process` invocations), concurrent access to the secret store and SSO browser flows can cause errors: + +- **Browser storms**: Multiple processes each open a browser tab for the same SSO login, overwhelming AWS and triggering HTTP 500 errors. +- **Secret store races**: Concurrent writes to the same keyring entry cause "item already exists" errors or partial reads. + +The `--parallel-safe` flag (or `AWS_VAULT_PARALLEL_SAFE=true`) enables cross-process locking to prevent these issues: + +- **SSO token lock**: Only one process per SSO Start URL opens a browser tab; others wait for the cached token. +- **Session cache lock**: Only one process writes back to a given session cache entry at a time. +- **Keyring lock**: All keyring read/write operations are serialized across processes. + +This applies to **all backends** (keychain, file, pass, secret-service, etc.). + +### Trade-offs + +- Keyring operations are serialized, which adds a small amount of latency per operation. In practice this is negligible because the operations themselves are fast. +- **All concurrent invocations must use `--parallel-safe`**. If some processes enable it and others don't, the unprotected processes ignore the locks entirely. This is undefined behavior and may still cause races. Set `AWS_VAULT_PARALLEL_SAFE=true` in your environment to ensure consistent use. + +### Limitations + +- The keyring lock wait loop cannot be cancelled by the caller because the `keyring.Keyring` interface is not context-aware. If a lock holder hangs (e.g. a stuck `gpg` subprocess in the `pass` backend), waiters will time out after 2 minutes rather than waiting indefinitely. +- SSO rate-limit retries (HTTP 429 on `GetRoleCredentials`) will retry for up to 5 minutes before giving up with an error. + ## Assuming roles with web identities AWS supports assuming roles using [web identity federation and OpenID Connect](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-role.html#cli-configure-role-oidc), including login using Amazon, Google, Facebook or any other OpenID Connect server. The configuration options are as follows: From 6e4312060b3fd629daad4dd5c2968967265b4070 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 10:03:45 -0400 Subject: [PATCH 05/30] fix(sso): address code review feedback on providers - Rename shadowed err to sleepErr in getSessionWithLock - Add rationale comments to retry/backoff constants - Add doc comments to ensureSessionDependencies/ensureSSODependencies explaining the lazy-init pattern Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/cachedsessionprovider.go | 52 ++++---- vault/cachedsessionprovider_lock_test.go | 74 ++++-------- vault/locked_keyring.go | 77 ++++++------ vault/ssorolecredentialsprovider.go | 74 +++++++----- vault/ssorolecredentialsprovider_lock_test.go | 113 +++++++----------- vault/vault.go | 65 +++++----- 6 files changed, 211 insertions(+), 244 deletions(-) diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index ca22d21a..a985da1a 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -34,8 +34,19 @@ type CachedSessionProvider struct { } const ( + // defaultSessionLockWaitDelay is the polling interval between lock attempts. + // 100ms keeps latency low for the typical case where the lock holder + // finishes quickly (STS call + cache write). defaultSessionLockWaitDelay = 100 * time.Millisecond - defaultSessionLockLogEvery = 15 * time.Second + + // defaultSessionLockLogEvery controls how often we emit a debug log while + // waiting for the lock. 15s avoids log spam while still showing progress. + defaultSessionLockLogEvery = 15 * time.Second + + // defaultSessionLockWarnAfter is the delay before printing a user-visible + // "waiting for lock" message to stderr. 5s is long enough to avoid + // flashing the message on normal lock contention, short enough to + // reassure the user that the process isn't hung. defaultSessionLockWarnAfter = 5 * time.Second ) @@ -51,24 +62,21 @@ func defaultSessionSleep(ctx context.Context, d time.Duration) error { } } -func (p *CachedSessionProvider) ensureSessionDependencies() { - if p.sessionLock == nil { - p.sessionLock = NewDefaultSessionCacheLock(p.SessionKey.StringForMatching()) - } - if p.sessionLockWait == 0 { - p.sessionLockWait = defaultSessionLockWaitDelay - } - if p.sessionLockLog == 0 { - p.sessionLockLog = defaultSessionLockLogEvery - } - if p.sessionNow == nil { - p.sessionNow = time.Now - } - if p.sessionSleep == nil { - p.sessionSleep = defaultSessionSleep - } - if p.sessionLogf == nil { - p.sessionLogf = log.Printf +// NewCachedSessionProvider creates a CachedSessionProvider with production +// defaults for all internal dependencies. Tests can override unexported fields +// (sessionLock, sessionNow, etc.) after construction to inject mocks. +func NewCachedSessionProvider(key SessionMetadata, provider StsSessionProvider, keyring *SessionKeyring, expiryWindow time.Duration) *CachedSessionProvider { + return &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: keyring, + ExpiryWindow: expiryWindow, + sessionLock: NewDefaultSessionCacheLock(key.StringForMatching()), + sessionLockWait: defaultSessionLockWaitDelay, + sessionLockLog: defaultSessionLockLogEvery, + sessionNow: time.Now, + sessionSleep: defaultSessionSleep, + sessionLogf: log.Printf, } } @@ -82,8 +90,6 @@ func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*st return p.getSessionWithoutLock(ctx) } - p.ensureSessionDependencies() - return p.getSessionWithLock(ctx) } @@ -131,8 +137,8 @@ func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststyp if locked { return p.doLockedSessionWork(ctx) } - if err = waiter.sleepAfterMiss(ctx); err != nil { - return nil, err + if sleepErr := waiter.sleepAfterMiss(ctx); sleepErr != nil { + return nil, sleepErr } } } diff --git a/vault/cachedsessionprovider_lock_test.go b/vault/cachedsessionprovider_lock_test.go index e5ef6ead..198b96f9 100644 --- a/vault/cachedsessionprovider_lock_test.go +++ b/vault/cachedsessionprovider_lock_test.go @@ -75,14 +75,9 @@ func TestCachedSession_CacheHit_NoLock(t *testing.T) { onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called on cache hit") }, } - p := &CachedSessionProvider{ - SessionKey: key, - SessionProvider: provider, - Keyring: sk, - ExpiryWindow: 0, - UseSessionLock: true, - sessionLock: lock, - } + p := NewCachedSessionProvider(key, provider, sk, 0) + p.UseSessionLock = true + p.sessionLock = lock got, err := p.RetrieveStsCredentials(context.Background()) if err != nil { @@ -107,14 +102,9 @@ func TestCachedSession_LockDisabled_SkipsLock(t *testing.T) { lock := &testLock{tryResults: []bool{true}} provider := &testSessionProvider{creds: creds} - p := &CachedSessionProvider{ - SessionKey: key, - SessionProvider: provider, - Keyring: sk, - ExpiryWindow: 0, - UseSessionLock: false, - sessionLock: lock, - } + p := NewCachedSessionProvider(key, provider, sk, 0) + p.UseSessionLock = false + p.sessionLock = lock got, err := p.RetrieveStsCredentials(context.Background()) if err != nil { @@ -145,15 +135,10 @@ func TestCachedSession_LockMiss_ThenCacheHit_NoRefresh(t *testing.T) { onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called when cache fills while waiting") }, } - p := &CachedSessionProvider{ - SessionKey: key, - SessionProvider: provider, - Keyring: sk, - ExpiryWindow: 0, - UseSessionLock: true, - sessionLock: lock, - sessionLockWait: 5 * time.Second, - } + p := NewCachedSessionProvider(key, provider, sk, 0) + p.UseSessionLock = true + p.sessionLock = lock + p.sessionLockWait = 5 * time.Second p.sessionSleep = func(ctx context.Context, d time.Duration) error { return sk.Set(key, creds) } @@ -189,14 +174,9 @@ func TestCachedSession_LockAcquired_RecheckCache(t *testing.T) { onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called when cache fills after lock") }, } - p := &CachedSessionProvider{ - SessionKey: key, - SessionProvider: provider, - Keyring: sk, - ExpiryWindow: 0, - UseSessionLock: true, - sessionLock: lock, - } + p := NewCachedSessionProvider(key, provider, sk, 0) + p.UseSessionLock = true + p.sessionLock = lock got, err := p.RetrieveStsCredentials(context.Background()) if err != nil { @@ -224,14 +204,9 @@ func TestCachedSession_LockHeldThroughCacheSet(t *testing.T) { sk := &SessionKeyring{Keyring: wrappedKeyring} provider := &testSessionProvider{creds: creds} - p := &CachedSessionProvider{ - SessionKey: key, - SessionProvider: provider, - Keyring: sk, - ExpiryWindow: 0, - UseSessionLock: true, - sessionLock: lock, - } + p := NewCachedSessionProvider(key, provider, sk, 0) + p.UseSessionLock = true + p.sessionLock = lock _, err := p.RetrieveStsCredentials(context.Background()) if err != nil { @@ -259,17 +234,12 @@ func TestCachedSession_LockWaitLogs(t *testing.T) { clock := &testClock{now: time.Unix(0, 0), cancel: cancel, cancelAfter: 4} var logTimes []time.Time - p := &CachedSessionProvider{ - SessionKey: key, - SessionProvider: provider, - Keyring: sk, - ExpiryWindow: 0, - UseSessionLock: true, - sessionLock: lock, - sessionLockWait: 5 * time.Second, - sessionLockLog: 15 * time.Second, - sessionNow: clock.Now, - } + p := NewCachedSessionProvider(key, provider, sk, 0) + p.UseSessionLock = true + p.sessionLock = lock + p.sessionLockWait = 5 * time.Second + p.sessionLockLog = 15 * time.Second + p.sessionNow = clock.Now p.sessionSleep = clock.Sleep p.sessionLogf = func(string, ...any) { logTimes = append(logTimes, clock.Now()) diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go index 084d5c03..914f2981 100644 --- a/vault/locked_keyring.go +++ b/vault/locked_keyring.go @@ -27,62 +27,57 @@ type lockedKeyring struct { } const ( + // defaultKeyringLockWaitDelay is the polling interval between lock attempts. + // 100ms keeps latency low for the typical case where the lock holder + // finishes a single keyring read/write quickly. defaultKeyringLockWaitDelay = 100 * time.Millisecond - defaultKeyringLockLogEvery = 15 * time.Second + + // defaultKeyringLockLogEvery controls how often we emit a debug log while + // waiting for the lock. 15s avoids log spam while still showing progress. + defaultKeyringLockLogEvery = 15 * time.Second + + // defaultKeyringLockWarnAfter is the delay before printing a user-visible + // "waiting for lock" message to stderr. 5s is long enough to avoid + // flashing the message on normal lock contention, short enough to + // reassure the user that the process isn't hung. defaultKeyringLockWarnAfter = 5 * time.Second - defaultKeyringLockTimeout = 2 * time.Minute + + // defaultKeyringLockTimeout is a safety net: the keyring.Keyring interface + // is not context-aware, so if the lock holder is hung (e.g. a stuck gpg + // subprocess in the pass backend), waiters give up after this duration + // rather than blocking indefinitely. 2 minutes is generous enough for any + // reasonable keyring operation. + defaultKeyringLockTimeout = 2 * time.Minute ) // NewLockedKeyring wraps the provided keyring with a cross-process lock // to serialize keyring operations. func NewLockedKeyring(kr keyring.Keyring, lockKey string) keyring.Keyring { return &lockedKeyring{ - inner: kr, - lock: NewDefaultKeyringLock(lockKey), - lockKey: lockKey, + inner: kr, + lock: NewDefaultKeyringLock(lockKey), + lockKey: lockKey, + lockWait: defaultKeyringLockWaitDelay, + lockLog: defaultKeyringLockLogEvery, + warnAfter: defaultKeyringLockWarnAfter, + lockNow: time.Now, + lockSleep: defaultLockedKeyringSleep, + lockLogf: log.Printf, } } -func (k *lockedKeyring) ensureLockDependencies() { - if k.lock == nil { - lockKey := k.lockKey - if lockKey == "" { - lockKey = "aws-vault" - } - k.lock = NewDefaultKeyringLock(lockKey) - } - if k.lockWait == 0 { - k.lockWait = defaultKeyringLockWaitDelay - } - if k.lockLog == 0 { - k.lockLog = defaultKeyringLockLogEvery - } - if k.warnAfter == 0 { - k.warnAfter = defaultKeyringLockWarnAfter - } - if k.lockNow == nil { - k.lockNow = time.Now - } - if k.lockSleep == nil { - k.lockSleep = func(ctx context.Context, d time.Duration) error { - timer := time.NewTimer(d) - defer timer.Stop() - select { - case <-ctx.Done(): - return ctx.Err() - case <-timer.C: - return nil - } - } - } - if k.lockLogf == nil { - k.lockLogf = log.Printf +func defaultLockedKeyringSleep(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil } } func (k *lockedKeyring) withLock(fn func() error) error { - k.ensureLockDependencies() - k.mu.Lock() defer k.mu.Unlock() diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index ed3d9d36..c6bdf252 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -53,14 +53,40 @@ func millisecondsTimeValue(v int64) time.Time { } const ( + // defaultSSOLockWaitDelay is the polling interval between lock attempts. + // 100ms keeps latency low for the typical case where the lock holder + // finishes quickly (browser auth + token cache write). defaultSSOLockWaitDelay = 100 * time.Millisecond - defaultSSOLockLogEvery = 15 * time.Second + + // defaultSSOLockLogEvery controls how often we emit a debug log while + // waiting for the lock. 15s avoids log spam while still showing progress + // during long waits (e.g. slow browser auth). + defaultSSOLockLogEvery = 15 * time.Second + + // defaultSSOLockWarnAfter is the delay before printing a user-visible + // "waiting for lock" message to stderr. 5s is long enough to avoid + // flashing the message on normal lock contention, short enough to + // reassure the user that the process isn't hung. defaultSSOLockWarnAfter = 5 * time.Second + // ssoRetryTimeout is a pathological safety net: if GetRoleCredentials is still // returning 429s after this duration, give up and surface the error to the user. - ssoRetryTimeout = 5 * time.Minute - ssoRetryBase = 200 * time.Millisecond - ssoRetryMax = 5 * time.Second + // 5 minutes is generous but accommodates burst-heavy credential_process workloads + // (e.g. Terraform with hundreds of parallel invocations). + ssoRetryTimeout = 5 * time.Minute + + // ssoRetryBase is the initial backoff delay before the first retry. + // 200ms is short enough to avoid unnecessary latency on transient 429s + // while still giving the SSO service breathing room. + ssoRetryBase = 200 * time.Millisecond + + // ssoRetryMax caps the exponential backoff so that individual waits + // don't grow unreasonably large between attempts. + ssoRetryMax = 5 * time.Second + + // ssoRetryAfterJitterMin and ssoRetryAfterJitterMax add ±10-30% jitter + // to Retry-After values to decorrelate concurrent processes that all + // received the same Retry-After header from the SSO service. ssoRetryAfterJitterMin = 1.1 ssoRetryAfterJitterMax = 1.3 ) @@ -77,28 +103,24 @@ func defaultSSOSleep(ctx context.Context, d time.Duration) error { } } -func (p *SSORoleCredentialsProvider) ensureSSODependencies() { - if p.ssoTokenLock == nil && !p.UseStdout && p.UseSSOTokenLock { +// initSSODefaults sets production defaults for all internal dependencies. +// Called by the constructor; tests can override unexported fields afterward. +func (p *SSORoleCredentialsProvider) initSSODefaults() { + p.ssoLockWait = defaultSSOLockWaitDelay + p.ssoLockLog = defaultSSOLockLogEvery + p.ssoNow = time.Now + p.ssoSleep = defaultSSOSleep + p.ssoLogf = log.Printf + p.newOIDCTokenFn = p.newOIDCToken +} + +// EnableSSOTokenLock creates the SSO token lock for cross-process coordination. +// Called by applyParallelSafety after setting UseSSOTokenLock. +func (p *SSORoleCredentialsProvider) EnableSSOTokenLock() { + p.UseSSOTokenLock = true + if !p.UseStdout && p.ssoTokenLock == nil { p.ssoTokenLock = NewDefaultSSOTokenLock(p.StartURL) } - if p.ssoLockWait == 0 { - p.ssoLockWait = defaultSSOLockWaitDelay - } - if p.ssoLockLog == 0 { - p.ssoLockLog = defaultSSOLockLogEvery - } - if p.ssoNow == nil { - p.ssoNow = time.Now - } - if p.ssoSleep == nil { - p.ssoSleep = defaultSSOSleep - } - if p.ssoLogf == nil { - p.ssoLogf = log.Printf - } - if p.newOIDCTokenFn == nil { - p.newOIDCTokenFn = p.newOIDCToken - } } // Retrieve generates a new set of temporary credentials using SSO GetRoleCredentials. @@ -118,8 +140,6 @@ func (p *SSORoleCredentialsProvider) Retrieve(ctx context.Context) (aws.Credenti } func (p *SSORoleCredentialsProvider) getRoleCredentials(ctx context.Context) (*ssotypes.RoleCredentials, error) { - p.ensureSSODependencies() - token, cached, err := p.getOIDCToken(ctx) if err != nil { return nil, err @@ -208,8 +228,6 @@ func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredemtials(ctx cont } func (p *SSORoleCredentialsProvider) getOIDCToken(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { - p.ensureSSODependencies() - token, cached, err = p.getCachedOIDCToken() if err != nil || token != nil { return token, cached, err diff --git a/vault/ssorolecredentialsprovider_lock_test.go b/vault/ssorolecredentialsprovider_lock_test.go index 0729a2b6..df89b6a8 100644 --- a/vault/ssorolecredentialsprovider_lock_test.go +++ b/vault/ssorolecredentialsprovider_lock_test.go @@ -39,24 +39,27 @@ func (c *testTokenCache) Remove(string) error { return nil } +func newTestSSORoleProvider() *SSORoleCredentialsProvider { + p := &SSORoleCredentialsProvider{ + StartURL: "https://sso.example", + } + p.initSSODefaults() + return p +} + func TestGetOIDCToken_CacheHit_NoLock(t *testing.T) { cachedToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("cached")} cache := &testTokenCache{token: cachedToken} lock := &testLock{} - p := &SSORoleCredentialsProvider{ - OIDCTokenCache: cache, - StartURL: "https://sso.example", - ssoTokenLock: lock, - UseStdout: false, - UseSSOTokenLock: true, - } + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { t.Fatal("newOIDCToken should not be called on cache hit") return nil, nil } - p.ssoLogf = func(string, ...any) {} - p.ssoSleep = func(context.Context, time.Duration) error { return nil } token, cached, err := p.getOIDCToken(context.Background()) if err != nil { @@ -78,18 +81,13 @@ func TestGetOIDCToken_LockDisabled_SkipsLock(t *testing.T) { cache := &testTokenCache{} lock := &testLock{tryResults: []bool{true}} - p := &SSORoleCredentialsProvider{ - OIDCTokenCache: cache, - StartURL: "https://sso.example", - ssoTokenLock: lock, - UseStdout: false, - UseSSOTokenLock: false, - } + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = false p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { return freshToken, nil } - p.ssoLogf = func(string, ...any) {} - p.ssoSleep = func(context.Context, time.Duration) error { return nil } token, cached, err := p.getOIDCToken(context.Background()) if err != nil { @@ -115,20 +113,16 @@ func TestGetOIDCToken_LockMiss_ThenCacheHit_NoLock(t *testing.T) { lock := &testLock{tryResults: []bool{false}} clock := &testClock{now: time.Unix(0, 0)} - p := &SSORoleCredentialsProvider{ - OIDCTokenCache: cache, - StartURL: "https://sso.example", - ssoTokenLock: lock, - UseStdout: false, - UseSSOTokenLock: true, - ssoLockWait: 5 * time.Second, - ssoNow: clock.Now, - } + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true + p.ssoLockWait = 5 * time.Second + p.ssoNow = clock.Now p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { t.Fatal("newOIDCToken should not be called when cache fills while waiting") return nil, nil } - p.ssoLogf = func(string, ...any) {} p.ssoSleep = func(ctx context.Context, d time.Duration) error { clock.now = clock.now.Add(d) cache.token = cachedToken @@ -163,19 +157,14 @@ func TestGetOIDCToken_LockAcquired_RecheckCache(t *testing.T) { } } - p := &SSORoleCredentialsProvider{ - OIDCTokenCache: cache, - StartURL: "https://sso.example", - ssoTokenLock: lock, - UseStdout: false, - UseSSOTokenLock: true, - } + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { t.Fatal("newOIDCToken should not be called when cache is filled after lock") return nil, nil } - p.ssoLogf = func(string, ...any) {} - p.ssoSleep = func(context.Context, time.Duration) error { return nil } token, cached, err := p.getOIDCToken(context.Background()) if err != nil { @@ -197,18 +186,13 @@ func TestGetOIDCToken_LockHeldThroughCacheSet(t *testing.T) { lock := &testLock{tryResults: []bool{true}} cache := &testTokenCache{setLock: lock} - p := &SSORoleCredentialsProvider{ - OIDCTokenCache: cache, - StartURL: "https://sso.example", - ssoTokenLock: lock, - UseStdout: false, - UseSSOTokenLock: true, - } + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { return freshToken, nil } - p.ssoLogf = func(string, ...any) {} - p.ssoSleep = func(context.Context, time.Duration) error { return nil } token, cached, err := p.getOIDCToken(context.Background()) if err != nil { @@ -233,18 +217,14 @@ func TestGetOIDCToken_UseStdout_SkipsLock(t *testing.T) { lock := &testLock{tryResults: []bool{true}} cache := &testTokenCache{} - p := &SSORoleCredentialsProvider{ - OIDCTokenCache: cache, - StartURL: "https://sso.example", - ssoTokenLock: lock, - UseStdout: true, - UseSSOTokenLock: true, - } + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseStdout = true + p.UseSSOTokenLock = true p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { return freshToken, nil } - p.ssoLogf = func(string, ...any) {} - p.ssoSleep = func(context.Context, time.Duration) error { return nil } token, cached, err := p.getOIDCToken(context.Background()) if err != nil { @@ -268,24 +248,21 @@ func TestGetOIDCToken_LockWaitLogs(t *testing.T) { clock := &testClock{now: time.Unix(0, 0), cancel: cancel, cancelAfter: 4} var logTimes []time.Time - p := &SSORoleCredentialsProvider{ - OIDCTokenCache: cache, - StartURL: "https://sso.example", - ssoTokenLock: lock, - UseStdout: false, - UseSSOTokenLock: true, - ssoLockWait: 5 * time.Second, - ssoLockLog: 15 * time.Second, - ssoNow: clock.Now, + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true + p.ssoLockWait = 5 * time.Second + p.ssoLockLog = 15 * time.Second + p.ssoNow = clock.Now + p.ssoSleep = clock.Sleep + p.ssoLogf = func(string, ...any) { + logTimes = append(logTimes, clock.Now()) } p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { t.Fatal("newOIDCToken should not be called when lock never acquired") return nil, nil } - p.ssoSleep = clock.Sleep - p.ssoLogf = func(string, ...any) { - logTimes = append(logTimes, clock.Now()) - } _, _, err := p.getOIDCToken(ctx) if !errors.Is(err, context.Canceled) { diff --git a/vault/vault.go b/vault/vault.go index 74358605..efff61ab 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -61,16 +61,16 @@ func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Ke } if useSessionCache { - return &CachedSessionProvider{ - SessionKey: SessionMetadata{ + return NewCachedSessionProvider( + SessionMetadata{ Type: "sts.GetSessionToken", ProfileName: config.ProfileName, MfaSerial: config.MfaSerial, }, - Keyring: &SessionKeyring{Keyring: k}, - ExpiryWindow: defaultExpirationWindow, - SessionProvider: sessionTokenProvider, - }, nil + sessionTokenProvider, + &SessionKeyring{Keyring: k}, + defaultExpirationWindow, + ), nil } return sessionTokenProvider, nil @@ -93,16 +93,16 @@ func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyr } if useSessionCache && config.MfaSerial != "" { - return &CachedSessionProvider{ - SessionKey: SessionMetadata{ + return NewCachedSessionProvider( + SessionMetadata{ Type: "sts.AssumeRole", ProfileName: config.ProfileName, MfaSerial: config.MfaSerial, }, - Keyring: &SessionKeyring{Keyring: k}, - ExpiryWindow: defaultExpirationWindow, - SessionProvider: p, - }, nil + p, + &SessionKeyring{Keyring: k}, + defaultExpirationWindow, + ), nil } return p, nil @@ -123,15 +123,15 @@ func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConf } if useSessionCache { - return &CachedSessionProvider{ - SessionKey: SessionMetadata{ + return NewCachedSessionProvider( + SessionMetadata{ Type: "sts.AssumeRoleWithWebIdentity", ProfileName: config.ProfileName, }, - Keyring: &SessionKeyring{Keyring: k}, - ExpiryWindow: defaultExpirationWindow, - SessionProvider: p, - }, nil + p, + &SessionKeyring{Keyring: k}, + defaultExpirationWindow, + ), nil } return p, nil @@ -149,19 +149,20 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, use RoleName: config.SSORoleName, UseStdout: config.SSOUseStdout, } + ssoRoleCredentialsProvider.initSSODefaults() if useSessionCache { ssoRoleCredentialsProvider.OIDCTokenCache = OIDCTokenKeyring{Keyring: k} - return &CachedSessionProvider{ - SessionKey: SessionMetadata{ + return NewCachedSessionProvider( + SessionMetadata{ Type: "sso.GetRoleCredentials", ProfileName: config.ProfileName, MfaSerial: config.SSOStartURL, }, - Keyring: &SessionKeyring{Keyring: k}, - ExpiryWindow: defaultExpirationWindow, - SessionProvider: ssoRoleCredentialsProvider, - }, nil + ssoRoleCredentialsProvider, + &SessionKeyring{Keyring: k}, + defaultExpirationWindow, + ), nil } return ssoRoleCredentialsProvider, nil @@ -175,15 +176,15 @@ func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig, useS } if useSessionCache { - return &CachedSessionProvider{ - SessionKey: SessionMetadata{ + return NewCachedSessionProvider( + SessionMetadata{ Type: "credential_process", ProfileName: config.ProfileName, }, - Keyring: &SessionKeyring{Keyring: k}, - ExpiryWindow: defaultExpirationWindow, - SessionProvider: credentialProcessProvider, - }, nil + credentialProcessProvider, + &SessionKeyring{Keyring: k}, + defaultExpirationWindow, + ), nil } return credentialProcessProvider, nil @@ -351,13 +352,13 @@ func (t *TempCredentialsCreator) applyParallelSafety(provider aws.CredentialsPro if cached, ok := provider.(*CachedSessionProvider); ok { cached.UseSessionLock = true if ssoProvider, ok := cached.SessionProvider.(*SSORoleCredentialsProvider); ok { - ssoProvider.UseSSOTokenLock = true + ssoProvider.EnableSSOTokenLock() } return provider } if ssoProvider, ok := provider.(*SSORoleCredentialsProvider); ok { - ssoProvider.UseSSOTokenLock = true + ssoProvider.EnableSSOTokenLock() } return provider From 6bed045504c31d3d3a82a5ff69613c6e9e4fa8fe Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 10:03:48 -0400 Subject: [PATCH 06/30] chore(deps): bump gofrs/flock from v0.8.1 to v0.13.0 Co-Authored-By: Claude Opus 4.6 (1M context) --- go.mod | 5 +++-- go.sum | 12 ++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index fcc411f3..6c95a0e7 100644 --- a/go.mod +++ b/go.mod @@ -12,10 +12,11 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 + github.com/aws/smithy-go v1.24.2 github.com/byteness/keyring v1.9.0 github.com/charmbracelet/huh v1.0.0 github.com/charmbracelet/lipgloss v1.1.0 - github.com/gofrs/flock v0.8.1 + github.com/gofrs/flock v0.13.0 github.com/google/go-cmp v0.7.0 github.com/mattn/go-isatty v0.0.21 github.com/mattn/go-tty v0.0.7 @@ -36,7 +37,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect - github.com/aws/smithy-go v1.24.2 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/byteness/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/byteness/go-libsecret v0.0.0-20260108215642-107379d3dee0 // indirect @@ -75,6 +75,7 @@ require ( github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/tetratelabs/wabin v0.0.0-20230304001439-f6f874872834 // indirect github.com/tetratelabs/wazero v1.11.0 // indirect github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect diff --git a/go.sum b/go.sum index ec69f876..f1381686 100644 --- a/go.sum +++ b/go.sum @@ -116,8 +116,8 @@ github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= -github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= -github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= +github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= +github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -128,6 +128,8 @@ github.com/ianlancetaylor/demangle v0.0.0-20251118225945-96ee0021ea0f h1:Fnl4pzx github.com/ianlancetaylor/demangle v0.0.0-20251118225945-96ee0021ea0f/go.mod h1:gx7rwoVhcfuVKG5uya9Hs3Sxj7EIvldVofAWIUtGouw= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= @@ -156,7 +158,6 @@ github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELU github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/noamcohen97/touchid-go v0.3.0 h1:fcXxVCizysD7KHRR6hrURt3nyNIs5JBGSbOIidD/3wo= github.com/noamcohen97/touchid-go v0.3.0/go.mod h1:X9MRNIBGEmPqwpDm1G3fQOAQX7fwBlhzUbnkDTxuta0= @@ -168,6 +169,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -238,8 +241,9 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLFVxaq6wH4YuVdsUOr75U= gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/ini.v1 v1.67.1 h1:tVBILHy0R6e4wkYOn3XmiITt/hEVH4TFMYvAX2Ytz6k= gopkg.in/ini.v1 v1.67.1/go.mod h1:x/cyOwCgZqOkJoDIJ3c1KNHMo10+nLGAhh+kn3Zizss= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 89ddfa357e511a375c32ede21cc9ea2a2f87fcd3 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 10:03:51 -0400 Subject: [PATCH 07/30] docs: explain why login is excluded from --parallel-safe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Login opens a browser for an interactive console session — you cannot meaningfully log in to multiple consoles in parallel — so there is no concurrent-access problem for --parallel-safe to solve. Document this in the code comment, --help output, and USAGE.md. Co-Authored-By: Claude Opus 4.6 (1M context) --- USAGE.md | 4 ++++ cli/login.go | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/USAGE.md b/USAGE.md index 899e09f3..0d7e1dc2 100644 --- a/USAGE.md +++ b/USAGE.md @@ -655,6 +655,10 @@ This applies to **all backends** (keychain, file, pass, secret-service, etc.). - Keyring operations are serialized, which adds a small amount of latency per operation. In practice this is negligible because the operations themselves are fast. - **All concurrent invocations must use `--parallel-safe`**. If some processes enable it and others don't, the unprotected processes ignore the locks entirely. This is undefined behavior and may still cause races. Set `AWS_VAULT_PARALLEL_SAFE=true` in your environment to ensure consistent use. +### The `login` command + +The `login` command is intentionally excluded from `--parallel-safe`. Console login sessions are inherently single-use — you cannot meaningfully log in to multiple AWS consoles in parallel — so there is no concurrent-access problem for `--parallel-safe` to solve. The `exec`, `export`, and `rotate` commands all support `--parallel-safe`. + ### Limitations - The keyring lock wait loop cannot be cancelled by the caller because the `keyring.Keyring` interface is not context-aware. If a lock holder hangs (e.g. a stuck `gpg` subprocess in the `pass` backend), waiters will time out after 2 minutes rather than waiting indefinitely. diff --git a/cli/login.go b/cli/login.go index 08a4ef3e..c4351dbe 100644 --- a/cli/login.go +++ b/cli/login.go @@ -35,7 +35,11 @@ type LoginCommandInput struct { func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) { input := LoginCommandInput{} - cmd := app.Command("login", "Generate a login link for the AWS Console.") + // NOTE: login intentionally does not use --parallel-safe. The login command + // opens a browser for an interactive console session — you cannot meaningfully + // log in to multiple AWS consoles in parallel, so there is no concurrent-access + // problem for --parallel-safe to solve here. + cmd := app.Command("login", "Generate a login link for the AWS Console. Note: --parallel-safe does not apply to login because console sessions are inherently single-use.") cmd.Flag("duration", "Duration of the assume-role or federated session. Defaults to 1h"). Short('d'). From 7c8cb8d54a51dd63d4f6691332d6d68ce4e28385 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 10:03:55 -0400 Subject: [PATCH 08/30] feat(lock): derive keyring lock key from backend config The lock key previously fell back to the fixed string "aws-vault" for non-keychain backends, causing unintended contention on shared systems. Now each backend type incorporates its specific config (directory, prefix, collection name, vault ID, etc.) into the lock key so different backends and configurations lock independently. Co-Authored-By: Claude Opus 4.6 (1M context) --- cli/global.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++----- cli/login.go | 2 +- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/cli/global.go b/cli/global.go index 05bffee7..266bfe9c 100644 --- a/cli/global.go +++ b/cli/global.go @@ -79,10 +79,7 @@ func (a *AwsVault) Keyring() (keyring.Keyring, error) { return nil, err } if a.ParallelSafe { - lockKey := a.KeyringConfig.KeychainName - if lockKey == "" { - lockKey = "aws-vault" - } + lockKey := a.keyringLockKey() a.keyringImpl = vault.NewLockedKeyring(a.keyringImpl, lockKey) } } @@ -90,6 +87,53 @@ func (a *AwsVault) Keyring() (keyring.Keyring, error) { return a.keyringImpl, nil } +// keyringLockKey returns a backend-specific key for the cross-process keyring +// lock. Different backends (and different configurations of the same backend) +// produce different keys so they don't contend on the same lock file. +func (a *AwsVault) keyringLockKey() string { + backend := a.KeyringBackend + switch keyring.BackendType(backend) { + case keyring.KeychainBackend: + if a.KeyringConfig.KeychainName != "" { + return backend + ":" + a.KeyringConfig.KeychainName + } + case keyring.FileBackend: + if a.KeyringConfig.FileDir != "" { + return backend + ":" + a.KeyringConfig.FileDir + } + case keyring.PassBackend: + key := backend + if a.KeyringConfig.PassDir != "" { + key += ":" + a.KeyringConfig.PassDir + } + if a.KeyringConfig.PassPrefix != "" { + key += ":" + a.KeyringConfig.PassPrefix + } + return key + case keyring.SecretServiceBackend: + if a.KeyringConfig.LibSecretCollectionName != "" { + return backend + ":" + a.KeyringConfig.LibSecretCollectionName + } + case keyring.KWalletBackend: + if a.KeyringConfig.KWalletFolder != "" { + return backend + ":" + a.KeyringConfig.KWalletFolder + } + case keyring.WinCredBackend: + if a.KeyringConfig.WinCredPrefix != "" { + return backend + ":" + a.KeyringConfig.WinCredPrefix + } + case keyring.OPBackend, keyring.OPConnectBackend, keyring.OPDesktopBackend: + if a.KeyringConfig.OPVaultID != "" { + return backend + ":" + a.KeyringConfig.OPVaultID + } + } + // Fall back to backend name, which is always set (defaults to first available). + if backend != "" { + return backend + } + return "aws-vault" +} + func (a *AwsVault) AwsConfigFile() (*vault.ConfigFile, error) { if a.awsConfigFile == nil { var err error @@ -209,7 +253,7 @@ func ConfigureGlobals(app *kingpin.Application) *AwsVault { Envar("AWS_VAULT_BIOMETRICS"). BoolVar(&a.UseBiometrics) - app.Flag("parallel-safe", "Enable cross-process locking for keychain and cached credentials"). + app.Flag("parallel-safe", "Enable cross-process locking for keychain and cached credentials (applies to exec, export, rotate; not login)"). Envar("AWS_VAULT_PARALLEL_SAFE"). BoolVar(&a.ParallelSafe) diff --git a/cli/login.go b/cli/login.go index c4351dbe..40375fa3 100644 --- a/cli/login.go +++ b/cli/login.go @@ -39,7 +39,7 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) { // opens a browser for an interactive console session — you cannot meaningfully // log in to multiple AWS consoles in parallel, so there is no concurrent-access // problem for --parallel-safe to solve here. - cmd := app.Command("login", "Generate a login link for the AWS Console. Note: --parallel-safe does not apply to login because console sessions are inherently single-use.") + cmd := app.Command("login", "Generate a login link for the AWS Console.") cmd.Flag("duration", "Duration of the assume-role or federated session. Defaults to 1h"). Short('d'). From 090dc369a7a4b05d0026dc77a46849252bc660b7 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:31:23 -0400 Subject: [PATCH 09/30] style: flip > and >= comparisons to < and <= per repo convention Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/lock_test.go | 2 +- vault/lock_waiter.go | 4 ++-- vault/ssorolecredentialsprovider_retry_test.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vault/lock_test.go b/vault/lock_test.go index 18a2607b..8a5a0f27 100644 --- a/vault/lock_test.go +++ b/vault/lock_test.go @@ -56,7 +56,7 @@ func (c *testClock) Now() time.Time { func (c *testClock) Sleep(ctx context.Context, d time.Duration) error { c.sleepCalls++ c.now = c.now.Add(d) - if c.cancel != nil && c.cancelAfter > 0 && c.sleepCalls >= c.cancelAfter { + if c.cancel != nil && 0 < c.cancelAfter && c.cancelAfter <= c.sleepCalls { c.cancel() } if ctx.Err() != nil { diff --git a/vault/lock_waiter.go b/vault/lock_waiter.go index 5cdac42a..2c8bb18a 100644 --- a/vault/lock_waiter.go +++ b/vault/lock_waiter.go @@ -70,13 +70,13 @@ func (w *lockWaiter) sleepAfterMiss(ctx context.Context) error { if w.waitStart.IsZero() { w.waitStart = now } - if !w.warned && now.Sub(w.waitStart) >= w.warnAfter { + if !w.warned && w.warnAfter <= now.Sub(w.waitStart) { if w.warnf != nil { w.warnf(w.warnMsg, w.lock.Path()) } w.warned = true } - if w.logf != nil && (w.lastLog.IsZero() || now.Sub(w.lastLog) >= w.logEvery) { + if w.logf != nil && (w.lastLog.IsZero() || w.logEvery <= now.Sub(w.lastLog)) { w.logf(w.logMsg, w.lock.Path()) w.lastLog = now } diff --git a/vault/ssorolecredentialsprovider_retry_test.go b/vault/ssorolecredentialsprovider_retry_test.go index 71bc9a20..a4c8a7bd 100644 --- a/vault/ssorolecredentialsprovider_retry_test.go +++ b/vault/ssorolecredentialsprovider_retry_test.go @@ -71,7 +71,7 @@ func TestJitterDelayRange(t *testing.T) { for i := 0; i < 10; i++ { delay := jitterDelay(base) - if delay < min || delay > max { + if delay < min || max < delay { t.Fatalf("expected delay in range %s-%s, got %s", min, max, delay) } } From e1d98d199c509bd8093f26a8ae195790cd4327d0 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:31:30 -0400 Subject: [PATCH 10/30] refactor: deduplicate defaultXxxSleep into shared defaultContextSleep Three identical sleep-with-context-cancellation functions collapsed into one in process_lock.go. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/cachedsessionprovider.go | 15 ++------------- vault/locked_keyring.go | 19 ++++++------------- vault/process_lock.go | 15 +++++++++++++++ 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index a985da1a..263dd8d8 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -30,7 +30,7 @@ type CachedSessionProvider struct { sessionLockLog time.Duration sessionNow func() time.Time sessionSleep func(context.Context, time.Duration) error - sessionLogf func(string, ...any) + sessionLogf lockLogger } const ( @@ -50,17 +50,6 @@ const ( defaultSessionLockWarnAfter = 5 * time.Second ) -func defaultSessionSleep(ctx context.Context, d time.Duration) error { - timer := time.NewTimer(d) - defer timer.Stop() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-timer.C: - return nil - } -} // NewCachedSessionProvider creates a CachedSessionProvider with production // defaults for all internal dependencies. Tests can override unexported fields @@ -75,7 +64,7 @@ func NewCachedSessionProvider(key SessionMetadata, provider StsSessionProvider, sessionLockWait: defaultSessionLockWaitDelay, sessionLockLog: defaultSessionLockLogEvery, sessionNow: time.Now, - sessionSleep: defaultSessionSleep, + sessionSleep: defaultContextSleep, sessionLogf: log.Printf, } } diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go index 914f2981..23c9d037 100644 --- a/vault/locked_keyring.go +++ b/vault/locked_keyring.go @@ -15,7 +15,10 @@ import ( type lockedKeyring struct { inner keyring.Keyring lock KeyringLock - mu sync.Mutex + // mu serializes in-process access. The flock only coordinates across + // processes; without this mutex, concurrent goroutines in the same + // process could race on the try-lock loop. + mu sync.Mutex lockKey string lockWait time.Duration @@ -23,7 +26,7 @@ type lockedKeyring struct { warnAfter time.Duration lockNow func() time.Time lockSleep func(context.Context, time.Duration) error - lockLogf func(string, ...any) + lockLogf lockLogger } const ( @@ -61,21 +64,11 @@ func NewLockedKeyring(kr keyring.Keyring, lockKey string) keyring.Keyring { lockLog: defaultKeyringLockLogEvery, warnAfter: defaultKeyringLockWarnAfter, lockNow: time.Now, - lockSleep: defaultLockedKeyringSleep, + lockSleep: defaultContextSleep, lockLogf: log.Printf, } } -func defaultLockedKeyringSleep(ctx context.Context, d time.Duration) error { - timer := time.NewTimer(d) - defer timer.Stop() - select { - case <-ctx.Done(): - return ctx.Err() - case <-timer.C: - return nil - } -} func (k *lockedKeyring) withLock(fn func() error) error { k.mu.Lock() diff --git a/vault/process_lock.go b/vault/process_lock.go index 0582fe3c..4c0eb8cc 100644 --- a/vault/process_lock.go +++ b/vault/process_lock.go @@ -1,10 +1,12 @@ package vault import ( + "context" "crypto/sha256" "fmt" "os" "path/filepath" + "time" "github.com/gofrs/flock" ) @@ -45,3 +47,16 @@ func hashedLockFilename(prefix, key string) string { sum := sha256.Sum256([]byte(key)) return fmt.Sprintf("%s.%x.lock", prefix, sum) } + +// defaultContextSleep sleeps for d, respecting ctx cancellation. +// Shared by all lock-wait loops. +func defaultContextSleep(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} From 9f334b356b7c2426a69a48e586038f9a21e59a8a Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:31:46 -0400 Subject: [PATCH 11/30] fix(sso): typo, lockLogger type, comparison flip, and sleep dedup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename getRoleCredentialsAsStsCredemtials → getRoleCredentialsAsStsCredentials - Use lockLogger type instead of func(string, ...any) for ssoLogf - Flip capDelay > max to max < capDelay per repo convention - Replace defaultSSOSleep with shared defaultContextSleep Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/ssorolecredentialsprovider.go | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index c6bdf252..3d0df95b 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -44,7 +44,7 @@ type SSORoleCredentialsProvider struct { ssoLockLog time.Duration ssoNow func() time.Time ssoSleep func(context.Context, time.Duration) error - ssoLogf func(string, ...any) + ssoLogf lockLogger newOIDCTokenFn func(context.Context) (*ssooidc.CreateTokenOutput, error) } @@ -91,17 +91,6 @@ const ( ssoRetryAfterJitterMax = 1.3 ) -func defaultSSOSleep(ctx context.Context, d time.Duration) error { - timer := time.NewTimer(d) - defer timer.Stop() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-timer.C: - return nil - } -} // initSSODefaults sets production defaults for all internal dependencies. // Called by the constructor; tests can override unexported fields afterward. @@ -109,7 +98,7 @@ func (p *SSORoleCredentialsProvider) initSSODefaults() { p.ssoLockWait = defaultSSOLockWaitDelay p.ssoLockLog = defaultSSOLockLogEvery p.ssoNow = time.Now - p.ssoSleep = defaultSSOSleep + p.ssoSleep = defaultContextSleep p.ssoLogf = log.Printf p.newOIDCTokenFn = p.newOIDCToken } @@ -209,11 +198,11 @@ func (p *SSORoleCredentialsProvider) getRoleCredentials(ctx context.Context) (*s } func (p *SSORoleCredentialsProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { - return p.getRoleCredentialsAsStsCredemtials(ctx) + return p.getRoleCredentialsAsStsCredentials(ctx) } -// getRoleCredentialsAsStsCredemtials returns getRoleCredentials as sts.Credentials because sessions.Store expects it -func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredemtials(ctx context.Context) (*ststypes.Credentials, error) { +// getRoleCredentialsAsStsCredentials returns getRoleCredentials as sts.Credentials because sessions.Store expects it +func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { creds, err := p.getRoleCredentials(ctx) if err != nil { return nil, err @@ -463,7 +452,7 @@ func jitteredBackoff(base, max time.Duration, attempt int) time.Duration { attempt = 1 } capDelay := base << uint(attempt-1) - if capDelay > max { + if max < capDelay { capDelay = max } if capDelay < base { From c877087f55d050b2b7e28d7506f4c5295234bc0c Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:33:45 -0400 Subject: [PATCH 12/30] refactor(lock): collapse three lock-type files into NewDefaultLock MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit keychain_lock.go, session_lock.go, and sso_lock.go were structurally identical — each defined a type alias, two constructors, and a filename helper differing only by a string prefix. Replace all three with a single NewDefaultLock(prefix, key) in process_lock.go and use ProcessLock directly instead of the type aliases. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/cachedsessionprovider.go | 4 ++-- vault/keychain_lock.go | 21 --------------------- vault/locked_keyring.go | 4 ++-- vault/process_lock.go | 6 ++++++ vault/session_lock.go | 21 --------------------- vault/sso_lock.go | 22 ---------------------- vault/ssorolecredentialsprovider.go | 4 ++-- 7 files changed, 12 insertions(+), 70 deletions(-) delete mode 100644 vault/keychain_lock.go delete mode 100644 vault/session_lock.go delete mode 100644 vault/sso_lock.go diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index 263dd8d8..ecbb91ef 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -25,7 +25,7 @@ type CachedSessionProvider struct { Keyring *SessionKeyring ExpiryWindow time.Duration UseSessionLock bool - sessionLock SessionCacheLock + sessionLock ProcessLock sessionLockWait time.Duration sessionLockLog time.Duration sessionNow func() time.Time @@ -60,7 +60,7 @@ func NewCachedSessionProvider(key SessionMetadata, provider StsSessionProvider, SessionProvider: provider, Keyring: keyring, ExpiryWindow: expiryWindow, - sessionLock: NewDefaultSessionCacheLock(key.StringForMatching()), + sessionLock: NewDefaultLock("aws-vault.session", key.StringForMatching()), sessionLockWait: defaultSessionLockWaitDelay, sessionLockLog: defaultSessionLockLogEvery, sessionNow: time.Now, diff --git a/vault/keychain_lock.go b/vault/keychain_lock.go deleted file mode 100644 index 36b8e0ea..00000000 --- a/vault/keychain_lock.go +++ /dev/null @@ -1,21 +0,0 @@ -package vault - -const keyringLockFilenamePrefix = "aws-vault.keyring" - -// KeyringLock coordinates keyring access across processes. -type KeyringLock = ProcessLock - -// NewDefaultKeyringLock creates a lock in the system temp directory. -// This only coordinates processes that share the same temp dir; differing TMPDIRs/users are out of scope. -func NewDefaultKeyringLock(lockKey string) KeyringLock { - return NewKeyringLock(defaultLockPath(keyringLockFilename(lockKey))) -} - -// NewKeyringLock creates a lock at the provided path. -func NewKeyringLock(path string) KeyringLock { - return NewFileLock(path) -} - -func keyringLockFilename(lockKey string) string { - return hashedLockFilename(keyringLockFilenamePrefix, lockKey) -} diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go index 23c9d037..85b65541 100644 --- a/vault/locked_keyring.go +++ b/vault/locked_keyring.go @@ -14,7 +14,7 @@ import ( type lockedKeyring struct { inner keyring.Keyring - lock KeyringLock + lock ProcessLock // mu serializes in-process access. The flock only coordinates across // processes; without this mutex, concurrent goroutines in the same // process could race on the try-lock loop. @@ -58,7 +58,7 @@ const ( func NewLockedKeyring(kr keyring.Keyring, lockKey string) keyring.Keyring { return &lockedKeyring{ inner: kr, - lock: NewDefaultKeyringLock(lockKey), + lock: NewDefaultLock("aws-vault.keyring", lockKey), lockKey: lockKey, lockWait: defaultKeyringLockWaitDelay, lockLog: defaultKeyringLockLogEvery, diff --git a/vault/process_lock.go b/vault/process_lock.go index 4c0eb8cc..93526ec6 100644 --- a/vault/process_lock.go +++ b/vault/process_lock.go @@ -48,6 +48,12 @@ func hashedLockFilename(prefix, key string) string { return fmt.Sprintf("%s.%x.lock", prefix, sum) } +// NewDefaultLock creates a ProcessLock in the system temp directory. +// The lock file name is derived from the prefix and a SHA-256 hash of key. +func NewDefaultLock(prefix, key string) ProcessLock { + return NewFileLock(defaultLockPath(hashedLockFilename(prefix, key))) +} + // defaultContextSleep sleeps for d, respecting ctx cancellation. // Shared by all lock-wait loops. func defaultContextSleep(ctx context.Context, d time.Duration) error { diff --git a/vault/session_lock.go b/vault/session_lock.go deleted file mode 100644 index 5d77e240..00000000 --- a/vault/session_lock.go +++ /dev/null @@ -1,21 +0,0 @@ -package vault - -const sessionLockFilenamePrefix = "aws-vault.session" - -// SessionCacheLock coordinates session cache refreshes across processes. -type SessionCacheLock = ProcessLock - -// NewDefaultSessionCacheLock creates a lock in the system temp directory. -// This only coordinates processes that share the same temp dir; differing TMPDIRs/users are out of scope. -func NewDefaultSessionCacheLock(lockKey string) SessionCacheLock { - return NewSessionCacheLock(defaultLockPath(sessionLockFilename(lockKey))) -} - -// NewSessionCacheLock creates a lock at the provided path. -func NewSessionCacheLock(path string) SessionCacheLock { - return NewFileLock(path) -} - -func sessionLockFilename(lockKey string) string { - return hashedLockFilename(sessionLockFilenamePrefix, lockKey) -} diff --git a/vault/sso_lock.go b/vault/sso_lock.go deleted file mode 100644 index 06a2868c..00000000 --- a/vault/sso_lock.go +++ /dev/null @@ -1,22 +0,0 @@ -package vault - -const ssoLockFilenamePrefix = "aws-vault.sso" - -// SSOTokenLock coordinates the SSO device flow across processes. -type SSOTokenLock = ProcessLock - -// NewDefaultSSOTokenLock creates a lock in the system temp directory keyed by startURL. -// Processes sharing the same StartURL serialize; different StartURLs lock independently. -// This only coordinates processes that share the same temp dir; differing TMPDIRs/users are out of scope. -func NewDefaultSSOTokenLock(startURL string) SSOTokenLock { - return NewSSOTokenLock(defaultLockPath(ssoLockFilename(startURL))) -} - -// NewSSOTokenLock creates a lock at the provided path. -func NewSSOTokenLock(path string) SSOTokenLock { - return NewFileLock(path) -} - -func ssoLockFilename(startURL string) string { - return hashedLockFilename(ssoLockFilenamePrefix, startURL) -} diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 3d0df95b..6d70b44b 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -39,7 +39,7 @@ type SSORoleCredentialsProvider struct { RoleName string UseStdout bool UseSSOTokenLock bool - ssoTokenLock SSOTokenLock + ssoTokenLock ProcessLock ssoLockWait time.Duration ssoLockLog time.Duration ssoNow func() time.Time @@ -108,7 +108,7 @@ func (p *SSORoleCredentialsProvider) initSSODefaults() { func (p *SSORoleCredentialsProvider) EnableSSOTokenLock() { p.UseSSOTokenLock = true if !p.UseStdout && p.ssoTokenLock == nil { - p.ssoTokenLock = NewDefaultSSOTokenLock(p.StartURL) + p.ssoTokenLock = NewDefaultLock("aws-vault.sso", p.StartURL) } } From 5e01e27d7f17c61573cd42f7347e5403700327a0 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:35:08 -0400 Subject: [PATCH 13/30] refactor(lock): replace newLockWaiter positional params with options struct The 10-parameter constructor was hard to read and easy to mis-order. Replace with lockWaiterOpts struct for named fields. Also use the shared defaultContextSleep instead of an inline copy. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/cachedsessionprovider.go | 24 ++++----- vault/lock_waiter.go | 83 ++++++++++------------------- vault/locked_keyring.go | 24 ++++----- vault/ssorolecredentialsprovider.go | 24 ++++----- 4 files changed, 65 insertions(+), 90 deletions(-) diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index ecbb91ef..4e4994cf 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -95,20 +95,20 @@ func (p *CachedSessionProvider) getCachedSession() (creds *ststypes.Credentials, } func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststypes.Credentials, error) { - waiter := newLockWaiter( - p.sessionLock, - "Waiting for session lock at %s\n", - "Waiting for session lock at %s", - p.sessionLockWait, - p.sessionLockLog, - defaultSessionLockWarnAfter, - p.sessionNow, - p.sessionSleep, - p.sessionLogf, - func(format string, args ...any) { + waiter := newLockWaiter(lockWaiterOpts{ + Lock: p.sessionLock, + WarnMsg: "Waiting for session lock at %s\n", + LogMsg: "Waiting for session lock at %s", + WaitDelay: p.sessionLockWait, + LogEvery: p.sessionLockLog, + WarnAfter: defaultSessionLockWarnAfter, + Now: p.sessionNow, + Sleep: p.sessionSleep, + Logf: p.sessionLogf, + Warnf: func(format string, args ...any) { fmt.Fprintf(os.Stderr, format, args...) }, - ) + }) for { creds, cached, err := p.getCachedSession() diff --git a/vault/lock_waiter.go b/vault/lock_waiter.go index 2c8bb18a..1c5a2665 100644 --- a/vault/lock_waiter.go +++ b/vault/lock_waiter.go @@ -7,79 +7,54 @@ import ( type lockLogger func(string, ...any) +// lockWaiterOpts configures a lockWaiter. All fields are required except +// Now, Sleep, and Warnf which have sensible defaults. +type lockWaiterOpts struct { + Lock ProcessLock + WarnMsg string + LogMsg string + WaitDelay time.Duration + LogEvery time.Duration + WarnAfter time.Duration + Now func() time.Time + Sleep func(context.Context, time.Duration) error + Logf lockLogger + Warnf lockLogger +} + type lockWaiter struct { - lock ProcessLock - waitDelay time.Duration - logEvery time.Duration - warnAfter time.Duration - now func() time.Time - sleep func(context.Context, time.Duration) error - logf lockLogger - warnf lockLogger - warnMsg string - logMsg string + opts lockWaiterOpts lastLog time.Time waitStart time.Time warned bool } -func newLockWaiter( - lock ProcessLock, - warnMsg string, - logMsg string, - waitDelay time.Duration, - logEvery time.Duration, - warnAfter time.Duration, - now func() time.Time, - sleep func(context.Context, time.Duration) error, - logf lockLogger, - warnf lockLogger, -) *lockWaiter { - if now == nil { - now = time.Now - } - if sleep == nil { - sleep = func(ctx context.Context, d time.Duration) error { - timer := time.NewTimer(d) - defer timer.Stop() - select { - case <-ctx.Done(): - return ctx.Err() - case <-timer.C: - return nil - } - } +func newLockWaiter(opts lockWaiterOpts) *lockWaiter { + if opts.Now == nil { + opts.Now = time.Now } - return &lockWaiter{ - lock: lock, - waitDelay: waitDelay, - logEvery: logEvery, - warnAfter: warnAfter, - now: now, - sleep: sleep, - logf: logf, - warnf: warnf, - warnMsg: warnMsg, - logMsg: logMsg, + if opts.Sleep == nil { + opts.Sleep = defaultContextSleep } + return &lockWaiter{opts: opts} } func (w *lockWaiter) sleepAfterMiss(ctx context.Context) error { - now := w.now() + now := w.opts.Now() if w.waitStart.IsZero() { w.waitStart = now } - if !w.warned && w.warnAfter <= now.Sub(w.waitStart) { - if w.warnf != nil { - w.warnf(w.warnMsg, w.lock.Path()) + if !w.warned && w.opts.WarnAfter <= now.Sub(w.waitStart) { + if w.opts.Warnf != nil { + w.opts.Warnf(w.opts.WarnMsg, w.opts.Lock.Path()) } w.warned = true } - if w.logf != nil && (w.lastLog.IsZero() || w.logEvery <= now.Sub(w.lastLog)) { - w.logf(w.logMsg, w.lock.Path()) + if w.opts.Logf != nil && (w.lastLog.IsZero() || w.opts.LogEvery <= now.Sub(w.lastLog)) { + w.opts.Logf(w.opts.LogMsg, w.opts.Lock.Path()) w.lastLog = now } - return w.sleep(ctx, w.waitDelay) + return w.opts.Sleep(ctx, w.opts.WaitDelay) } diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go index 85b65541..5d693ffd 100644 --- a/vault/locked_keyring.go +++ b/vault/locked_keyring.go @@ -74,20 +74,20 @@ func (k *lockedKeyring) withLock(fn func() error) error { k.mu.Lock() defer k.mu.Unlock() - waiter := newLockWaiter( - k.lock, - "Waiting for keyring lock at %s\n", - "Waiting for keyring lock at %s", - k.lockWait, - k.lockLog, - k.warnAfter, - k.lockNow, - k.lockSleep, - k.lockLogf, - func(format string, args ...any) { + waiter := newLockWaiter(lockWaiterOpts{ + Lock: k.lock, + WarnMsg: "Waiting for keyring lock at %s\n", + LogMsg: "Waiting for keyring lock at %s", + WaitDelay: k.lockWait, + LogEvery: k.lockLog, + WarnAfter: k.warnAfter, + Now: k.lockNow, + Sleep: k.lockSleep, + Logf: k.lockLogf, + Warnf: func(format string, args ...any) { fmt.Fprintf(os.Stderr, format, args...) }, - ) + }) // The keyring.Keyring interface is not context-aware, so we cannot cancel // in-flight keyring operations. This timeout is a safety net for the lock-wait diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 6d70b44b..293fba7f 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -265,20 +265,20 @@ func (p *SSORoleCredentialsProvider) createAndCacheOIDCToken(ctx context.Context } func (p *SSORoleCredentialsProvider) getOIDCTokenWithLock(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { - waiter := newLockWaiter( - p.ssoTokenLock, - "Waiting for SSO lock at %s\n", - "Waiting for SSO lock at %s", - p.ssoLockWait, - p.ssoLockLog, - defaultSSOLockWarnAfter, - p.ssoNow, - p.ssoSleep, - p.ssoLogf, - func(format string, args ...any) { + waiter := newLockWaiter(lockWaiterOpts{ + Lock: p.ssoTokenLock, + WarnMsg: "Waiting for SSO lock at %s\n", + LogMsg: "Waiting for SSO lock at %s", + WaitDelay: p.ssoLockWait, + LogEvery: p.ssoLockLog, + WarnAfter: defaultSSOLockWarnAfter, + Now: p.ssoNow, + Sleep: p.ssoSleep, + Logf: p.ssoLogf, + Warnf: func(format string, args ...any) { fmt.Fprintf(os.Stderr, format, args...) }, - ) + }) for { token, cached, err = p.getCachedOIDCToken() From a9ed64bb5b942fb48f9812fa07e8cff1159ff551 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:36:08 -0400 Subject: [PATCH 14/30] fix(sso): widen jitter range from 1.1x-1.3x to 0.5x-1.5x MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old range always added 10-30% to the base delay, never reducing it. This meant concurrent processes that received the same Retry-After header would all retry at nearly the same time. The new 0.5x-1.5x range provides full jitter — some fire earlier, some later — properly decorrelating parallel retries. Note: Go 1.20+ auto-seeds math/rand, so the seeding concern does not apply to this codebase (go 1.25). Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/ssorolecredentialsprovider.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 293fba7f..1d36d054 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -84,11 +84,13 @@ const ( // don't grow unreasonably large between attempts. ssoRetryMax = 5 * time.Second - // ssoRetryAfterJitterMin and ssoRetryAfterJitterMax add ±10-30% jitter - // to Retry-After values to decorrelate concurrent processes that all - // received the same Retry-After header from the SSO service. - ssoRetryAfterJitterMin = 1.1 - ssoRetryAfterJitterMax = 1.3 + // ssoRetryAfterJitterMin and ssoRetryAfterJitterMax define the full-jitter + // range as a multiplier of the base delay. 0.5x-1.5x ensures retries + // spread across a wide window — some fire earlier than the base delay, + // some later — which decorrelates concurrent processes that all received + // the same Retry-After header from the SSO service. + ssoRetryAfterJitterMin = 0.5 + ssoRetryAfterJitterMax = 1.5 ) From 642e797df456bbbd0fe12807f921fb6f27068e2b Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:37:26 -0400 Subject: [PATCH 15/30] fix(lock): add 2-minute timeout to session and SSO lock-wait loops The keyring lock-wait loop already had a 2-minute timeout safety net, but the session and SSO loops would block indefinitely if the lock holder hung. Add matching context.WithTimeout to both, consistent with the keyring's defaultKeyringLockTimeout. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/cachedsessionprovider.go | 8 ++++++++ vault/ssorolecredentialsprovider.go | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index 4e4994cf..34c6c11d 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -48,6 +48,11 @@ const ( // flashing the message on normal lock contention, short enough to // reassure the user that the process isn't hung. defaultSessionLockWarnAfter = 5 * time.Second + + // defaultSessionLockTimeout is a safety net: if the lock holder is hung, + // waiters give up after this duration rather than blocking indefinitely. + // 2 minutes matches the keyring lock timeout. + defaultSessionLockTimeout = 2 * time.Minute ) @@ -95,6 +100,9 @@ func (p *CachedSessionProvider) getCachedSession() (creds *ststypes.Credentials, } func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststypes.Credentials, error) { + ctx, cancel := context.WithTimeout(ctx, defaultSessionLockTimeout) + defer cancel() + waiter := newLockWaiter(lockWaiterOpts{ Lock: p.sessionLock, WarnMsg: "Waiting for session lock at %s\n", diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 1d36d054..f1337b69 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -69,6 +69,12 @@ const ( // reassure the user that the process isn't hung. defaultSSOLockWarnAfter = 5 * time.Second + // defaultSSOLockTimeout is a safety net: if the lock holder is hung + // (e.g. a browser auth that was abandoned), waiters give up after this + // duration rather than blocking indefinitely. 2 minutes matches the + // keyring lock timeout. + defaultSSOLockTimeout = 2 * time.Minute + // ssoRetryTimeout is a pathological safety net: if GetRoleCredentials is still // returning 429s after this duration, give up and surface the error to the user. // 5 minutes is generous but accommodates burst-heavy credential_process workloads @@ -267,6 +273,9 @@ func (p *SSORoleCredentialsProvider) createAndCacheOIDCToken(ctx context.Context } func (p *SSORoleCredentialsProvider) getOIDCTokenWithLock(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultSSOLockTimeout) + defer cancel() + waiter := newLockWaiter(lockWaiterOpts{ Lock: p.ssoTokenLock, WarnMsg: "Waiting for SSO lock at %s\n", From 9e7f8d1dde18dccb04f876f8f492be5b6912d571 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:38:33 -0400 Subject: [PATCH 16/30] fix: exclude login from --parallel-safe keyring wrapping Login was documented as excluded from --parallel-safe, but Keyring() unconditionally wrapped with LockedKeyring when the flag was set. Split into rawKeyring (shared base) and Keyring (adds lock wrapper). Login now calls RawKeyring() to bypass the lock. Co-Authored-By: Claude Opus 4.6 (1M context) --- cli/global.go | 37 +++++++++++++++++++++++++++---------- cli/login.go | 5 ++++- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/cli/global.go b/cli/global.go index 266bfe9c..92241a9d 100644 --- a/cli/global.go +++ b/cli/global.go @@ -39,8 +39,9 @@ type AwsVault struct { promptDriver string ParallelSafe bool - keyringImpl keyring.Keyring - awsConfigFile *vault.ConfigFile + rawKeyringImpl keyring.Keyring + keyringImpl keyring.Keyring + awsConfigFile *vault.ConfigFile UseBiometrics bool } @@ -69,22 +70,38 @@ func (a *AwsVault) PromptDriver(avoidTerminalPrompt bool) string { } func (a *AwsVault) Keyring() (keyring.Keyring, error) { - if a.keyringImpl == nil { + raw, err := a.rawKeyring() + if err != nil { + return nil, err + } + if a.ParallelSafe { + if a.keyringImpl == nil { + lockKey := a.keyringLockKey() + a.keyringImpl = vault.NewLockedKeyring(raw, lockKey) + } + return a.keyringImpl, nil + } + return raw, nil +} + +// RawKeyring returns the keyring without the parallel-safe lock wrapper. +// Used by commands like login that are excluded from --parallel-safe. +func (a *AwsVault) RawKeyring() (keyring.Keyring, error) { + return a.rawKeyring() +} + +func (a *AwsVault) rawKeyring() (keyring.Keyring, error) { + if a.rawKeyringImpl == nil { if a.KeyringBackend != "" { a.KeyringConfig.AllowedBackends = []keyring.BackendType{keyring.BackendType(a.KeyringBackend)} } var err error - a.keyringImpl, err = keyring.Open(a.KeyringConfig) + a.rawKeyringImpl, err = keyring.Open(a.KeyringConfig) if err != nil { return nil, err } - if a.ParallelSafe { - lockKey := a.keyringLockKey() - a.keyringImpl = vault.NewLockedKeyring(a.keyringImpl, lockKey) - } } - - return a.keyringImpl, nil + return a.rawKeyringImpl, nil } // keyringLockKey returns a backend-specific key for the cross-process keyring diff --git a/cli/login.go b/cli/login.go index 40375fa3..5bb0ec13 100644 --- a/cli/login.go +++ b/cli/login.go @@ -79,7 +79,10 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) { input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration input.Config.AssumeRoleDuration = input.SessionDuration input.Config.GetFederationTokenDuration = input.SessionDuration - keyring, err := a.Keyring() + // Login uses the raw keyring without the parallel-safe lock wrapper. + // Console login is inherently single-use — there's no concurrent-access + // problem for --parallel-safe to solve here. + keyring, err := a.RawKeyring() if err != nil { return err } From 6c1387afb6513442f9dc7c713fe3b1128a5f4290 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:42:11 -0400 Subject: [PATCH 17/30] refactor(lock): extract withProcessLock generic helper The try/sleep/recheck lock protocol was copy-pasted in three places. Extract a single generic withProcessLock[T] that encodes the protocol once: check cache, try lock, do work under lock, sleep and retry. Also fixes a bug where the SSO retry loop used time.Until(deadline) (real clock) instead of deadline.Sub(p.ssoNow()) (injected clock), making the timeout untestable with fake clocks. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/cachedsessionprovider.go | 50 ++++++-------------- vault/locked_keyring.go | 38 +++++---------- vault/process_lock_loop.go | 68 +++++++++++++++++++++++++++ vault/ssorolecredentialsprovider.go | 71 ++++++++++++----------------- 4 files changed, 122 insertions(+), 105 deletions(-) create mode 100644 vault/process_lock_loop.go diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index 34c6c11d..5242c285 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -2,7 +2,6 @@ package vault import ( "context" - "errors" "fmt" "log" "os" @@ -103,7 +102,7 @@ func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststyp ctx, cancel := context.WithTimeout(ctx, defaultSessionLockTimeout) defer cancel() - waiter := newLockWaiter(lockWaiterOpts{ + return withProcessLock(ctx, p.sessionLock, lockWaiterOpts{ Lock: p.sessionLock, WarnMsg: "Waiting for session lock at %s\n", LogMsg: "Waiting for session lock at %s", @@ -116,51 +115,28 @@ func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststyp Warnf: func(format string, args ...any) { fmt.Fprintf(os.Stderr, format, args...) }, - }) - - for { + }, "session", func() (processLockResult[*ststypes.Credentials], error) { creds, cached, err := p.getCachedSession() if err == nil && cached { - return creds, nil + return processLockResult[*ststypes.Credentials]{value: creds, ok: true}, nil } - if ctx.Err() != nil { - return nil, ctx.Err() + return processLockResult[*ststypes.Credentials]{}, nil + }, func(ctx context.Context) (*ststypes.Credentials, error) { + // Recheck cache after acquiring lock — another process may have filled it. + creds, cached, cacheErr := p.getCachedSession() + if cacheErr == nil && cached { + return creds, nil } - locked, err := p.sessionLock.TryLock() + creds, err := p.SessionProvider.RetrieveStsCredentials(ctx) if err != nil { return nil, err } - if locked { - return p.doLockedSessionWork(ctx) - } - if sleepErr := waiter.sleepAfterMiss(ctx); sleepErr != nil { - return nil, sleepErr - } - } -} - -func (p *CachedSessionProvider) doLockedSessionWork(ctx context.Context) (creds *ststypes.Credentials, err error) { - defer func() { - if unlockErr := p.sessionLock.Unlock(); unlockErr != nil { - err = errors.Join(err, fmt.Errorf("unlock session lock: %w", unlockErr)) + if err = p.Keyring.Set(p.SessionKey, creds); err != nil { + return nil, err } - }() - - creds, cached, cacheErr := p.getCachedSession() - if cacheErr == nil && cached { return creds, nil - } - - creds, err = p.SessionProvider.RetrieveStsCredentials(ctx) - if err != nil { - return nil, err - } - if err = p.Keyring.Set(p.SessionKey, creds); err != nil { - return nil, err - } - - return creds, nil + }) } func (p *CachedSessionProvider) getSessionWithoutLock(ctx context.Context) (*ststypes.Credentials, error) { diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go index 5d693ffd..38d503a1 100644 --- a/vault/locked_keyring.go +++ b/vault/locked_keyring.go @@ -2,7 +2,6 @@ package vault import ( "context" - "errors" "fmt" "log" "os" @@ -74,7 +73,14 @@ func (k *lockedKeyring) withLock(fn func() error) error { k.mu.Lock() defer k.mu.Unlock() - waiter := newLockWaiter(lockWaiterOpts{ + // The keyring.Keyring interface is not context-aware, so we cannot cancel + // in-flight keyring operations. This timeout is a safety net for the lock-wait + // loop: if the lock holder is hung (e.g. a stuck gpg subprocess in the pass + // backend), waiters will eventually give up rather than blocking indefinitely. + ctx, cancel := context.WithTimeout(context.Background(), defaultKeyringLockTimeout) + defer cancel() + + _, err := withProcessLock(ctx, k.lock, lockWaiterOpts{ Lock: k.lock, WarnMsg: "Waiting for keyring lock at %s\n", LogMsg: "Waiting for keyring lock at %s", @@ -87,32 +93,10 @@ func (k *lockedKeyring) withLock(fn func() error) error { Warnf: func(format string, args ...any) { fmt.Fprintf(os.Stderr, format, args...) }, + }, "keyring", nil, func(ctx context.Context) (struct{}, error) { + return struct{}{}, fn() }) - - // The keyring.Keyring interface is not context-aware, so we cannot cancel - // in-flight keyring operations. This timeout is a safety net for the lock-wait - // loop: if the lock holder is hung (e.g. a stuck gpg subprocess in the pass - // backend), waiters will eventually give up rather than blocking indefinitely. - ctx, cancel := context.WithTimeout(context.Background(), defaultKeyringLockTimeout) - defer cancel() - - for { - locked, err := k.lock.TryLock() - if err != nil { - return err - } - if locked { - fnErr := fn() - if unlockErr := k.lock.Unlock(); unlockErr != nil { - return errors.Join(fnErr, fmt.Errorf("unlock keyring lock: %w", unlockErr)) - } - return fnErr - } - - if err = waiter.sleepAfterMiss(ctx); err != nil { - return err - } - } + return err } func (k *lockedKeyring) Get(key string) (keyring.Item, error) { diff --git a/vault/process_lock_loop.go b/vault/process_lock_loop.go new file mode 100644 index 00000000..b3066049 --- /dev/null +++ b/vault/process_lock_loop.go @@ -0,0 +1,68 @@ +package vault + +import ( + "context" + "errors" + "fmt" +) + +// processLockResult is the result of a cache check or locked work function. +// ok indicates whether a cached result was found. +type processLockResult[T any] struct { + value T + ok bool +} + +// withProcessLock implements the try/sleep/recheck lock protocol. +// +// On each iteration it calls checkCache; if that returns ok=true, the cached +// value is returned without acquiring the lock. Otherwise it tries the lock: +// if acquired, it calls doWork under the lock (unlocking on return). If the +// lock is not acquired, it sleeps and retries. +// +// checkCache may be nil, in which case the cache check is skipped. +func withProcessLock[T any]( + ctx context.Context, + lock ProcessLock, + waiterOpts lockWaiterOpts, + lockName string, + checkCache func() (processLockResult[T], error), + doWork func(ctx context.Context) (T, error), +) (T, error) { + waiter := newLockWaiter(waiterOpts) + + for { + if checkCache != nil { + result, err := checkCache() + if err != nil { + var zero T + return zero, err + } + if result.ok { + return result.value, nil + } + } + if ctx.Err() != nil { + var zero T + return zero, ctx.Err() + } + + locked, err := lock.TryLock() + if err != nil { + var zero T + return zero, err + } + if locked { + result, workErr := doWork(ctx) + if unlockErr := lock.Unlock(); unlockErr != nil { + return result, errors.Join(workErr, fmt.Errorf("unlock %s lock: %w", lockName, unlockErr)) + } + return result, workErr + } + + if err = waiter.sleepAfterMiss(ctx); err != nil { + var zero T + return zero, err + } + } +} diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index f1337b69..7c2412e5 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -178,7 +178,7 @@ func (p *SSORoleCredentialsProvider) getRoleCredentials(ctx context.Context) (*s if isSSORateLimitError(err) { rateLimitCount++ - remaining := time.Until(deadline) + remaining := deadline.Sub(p.ssoNow()) if 0 < remaining { var delay time.Duration if retryAfter, ok := retryAfterFromError(err); ok { @@ -272,11 +272,16 @@ func (p *SSORoleCredentialsProvider) createAndCacheOIDCToken(ctx context.Context return token, false, nil } +type oidcTokenResult struct { + token *ssooidc.CreateTokenOutput + cached bool +} + func (p *SSORoleCredentialsProvider) getOIDCTokenWithLock(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { ctx, cancel := context.WithTimeout(ctx, defaultSSOLockTimeout) defer cancel() - waiter := newLockWaiter(lockWaiterOpts{ + result, err := withProcessLock(ctx, p.ssoTokenLock, lockWaiterOpts{ Lock: p.ssoTokenLock, WarnMsg: "Waiting for SSO lock at %s\n", LogMsg: "Waiting for SSO lock at %s", @@ -289,55 +294,39 @@ func (p *SSORoleCredentialsProvider) getOIDCTokenWithLock(ctx context.Context) ( Warnf: func(format string, args ...any) { fmt.Fprintf(os.Stderr, format, args...) }, - }) - - for { - token, cached, err = p.getCachedOIDCToken() - if err != nil || token != nil { - return token, cached, err + }, "SSO token", func() (processLockResult[oidcTokenResult], error) { + token, cached, err := p.getCachedOIDCToken() + if err != nil { + return processLockResult[oidcTokenResult]{}, err } - if ctx.Err() != nil { - return nil, false, ctx.Err() + if token != nil { + return processLockResult[oidcTokenResult]{value: oidcTokenResult{token, cached}, ok: true}, nil } - - locked, err := p.ssoTokenLock.TryLock() + return processLockResult[oidcTokenResult]{}, nil + }, func(ctx context.Context) (oidcTokenResult, error) { + // Recheck cache after acquiring lock — another process may have filled it. + token, cached, err := p.getCachedOIDCToken() if err != nil { - return nil, false, err + return oidcTokenResult{}, err } - if locked { - return p.doLockedOIDCTokenWork(ctx) + if token != nil { + return oidcTokenResult{token, cached}, nil } - if err = waiter.sleepAfterMiss(ctx); err != nil { - return nil, false, err - } - } -} - -func (p *SSORoleCredentialsProvider) doLockedOIDCTokenWork(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { - defer func() { - if unlockErr := p.ssoTokenLock.Unlock(); unlockErr != nil { - err = errors.Join(err, fmt.Errorf("unlock SSO token lock: %w", unlockErr)) + token, err = p.newOIDCTokenFn(ctx) + if err != nil { + return oidcTokenResult{}, err } - }() - token, cached, err = p.getCachedOIDCToken() - if err != nil || token != nil { - return token, cached, err - } - - token, err = p.newOIDCTokenFn(ctx) - if err != nil { - return nil, false, err - } - - if p.OIDCTokenCache != nil { - if err = p.OIDCTokenCache.Set(p.StartURL, token); err != nil { - return nil, false, err + if p.OIDCTokenCache != nil { + if err = p.OIDCTokenCache.Set(p.StartURL, token); err != nil { + return oidcTokenResult{}, err + } } - } - return token, false, nil + return oidcTokenResult{token, false}, nil + }) + return result.token, result.cached, err } func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc.CreateTokenOutput, error) { From 94abec6b9214903311327e74d720af855ef10691 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:42:33 -0400 Subject: [PATCH 18/30] test(sso): add retry loop tests for backoff, jitter, and timeout - TestJitteredBackoffProgression: verify cap doubles per attempt - TestJitteredBackoffRespectsMax: verify cap doesn't exceed max - TestJitteredBackoffDoublesPerAttempt: verify deterministic cap - TestJitterRetryAfterRange: verify 0.5x-1.5x jitter range - TestJitterRetryAfterZeroBase/NegativeBase: edge cases - TestGetRoleCredentialsTimeoutOnPersistentRateLimit: end-to-end test with fake SSO server and fake clock Co-Authored-By: Claude Opus 4.6 (1M context) --- .../ssorolecredentialsprovider_retry_test.go | 151 ++++++++++++++++++ 1 file changed, 151 insertions(+) diff --git a/vault/ssorolecredentialsprovider_retry_test.go b/vault/ssorolecredentialsprovider_retry_test.go index a4c8a7bd..0aedb844 100644 --- a/vault/ssorolecredentialsprovider_retry_test.go +++ b/vault/ssorolecredentialsprovider_retry_test.go @@ -1,12 +1,19 @@ package vault import ( + "context" "errors" + "fmt" "net/http" + "net/http/httptest" + "strings" "testing" "time" + "github.com/aws/aws-sdk-go-v2/aws" awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/aws/aws-sdk-go-v2/service/sso" + "github.com/aws/aws-sdk-go-v2/service/ssooidc" ssotypes "github.com/aws/aws-sdk-go-v2/service/sso/types" smithyhttp "github.com/aws/smithy-go/transport/http" ) @@ -76,3 +83,147 @@ func TestJitterDelayRange(t *testing.T) { } } } + +func TestJitteredBackoffProgression(t *testing.T) { + base := 200 * time.Millisecond + max := 5 * time.Second + + // Each attempt should double the cap: 200ms, 400ms, 800ms, 1600ms, 3200ms, 5000ms (capped) + for attempt := 1; attempt <= 8; attempt++ { + expectedCap := base << uint(attempt-1) + if max < expectedCap { + expectedCap = max + } + minDelay := time.Duration(float64(expectedCap) * ssoRetryAfterJitterMin) + maxDelay := time.Duration(float64(expectedCap) * ssoRetryAfterJitterMax) + + for i := 0; i < 20; i++ { + delay := jitteredBackoff(base, max, attempt) + if delay < minDelay || maxDelay < delay { + t.Fatalf("attempt %d: expected delay in range %s-%s, got %s", + attempt, minDelay, maxDelay, delay) + } + } + } +} + +func TestJitteredBackoffRespectsMax(t *testing.T) { + base := 200 * time.Millisecond + max := 5 * time.Second + + // At high attempt numbers the cap should be max, not overflow + maxDelay := time.Duration(float64(max) * ssoRetryAfterJitterMax) + for attempt := 20; attempt <= 30; attempt++ { + for i := 0; i < 10; i++ { + delay := jitteredBackoff(base, max, attempt) + if maxDelay < delay { + t.Fatalf("attempt %d: delay %s exceeds max jittered cap %s", attempt, delay, maxDelay) + } + if delay < 0 { + t.Fatalf("attempt %d: negative delay %s", attempt, delay) + } + } + } +} + +func TestJitteredBackoffDoublesPerAttempt(t *testing.T) { + base := 1 * time.Second + max := 1 * time.Hour // very high max so we never hit the cap + + // Verify the cap doubles by checking that the median of many samples + // roughly doubles. Instead, verify the deterministic cap calculation: + // cap(attempt) = base << (attempt-1) + for attempt := 1; attempt <= 5; attempt++ { + expectedCap := base << uint(attempt-1) + minBound := time.Duration(float64(expectedCap) * ssoRetryAfterJitterMin) + maxBound := time.Duration(float64(expectedCap) * ssoRetryAfterJitterMax) + + delay := jitteredBackoff(base, max, attempt) + if delay < minBound || maxBound < delay { + t.Fatalf("attempt %d: expected delay in [%s, %s], got %s", + attempt, minBound, maxBound, delay) + } + } +} + +func TestJitterRetryAfterRange(t *testing.T) { + base := 2 * time.Second + minDelay := time.Duration(float64(base) * ssoRetryAfterJitterMin) + maxDelay := time.Duration(float64(base) * ssoRetryAfterJitterMax) + + for i := 0; i < 50; i++ { + delay := jitterRetryAfter(base) + if delay < minDelay || maxDelay < delay { + t.Fatalf("jitterRetryAfter(%s): expected delay in range %s-%s, got %s", + base, minDelay, maxDelay, delay) + } + } +} + +func TestJitterRetryAfterZeroBase(t *testing.T) { + delay := jitterRetryAfter(0) + if delay != 0 { + t.Fatalf("expected 0 delay for zero base, got %s", delay) + } +} + +func TestJitterRetryAfterNegativeBase(t *testing.T) { + delay := jitterRetryAfter(-1 * time.Second) + if delay != 0 { + t.Fatalf("expected 0 delay for negative base, got %s", delay) + } +} + +func TestGetRoleCredentialsTimeoutOnPersistentRateLimit(t *testing.T) { + // Set up an HTTP server that always returns 429 with a TooManyRequestsException + // body that the AWS SDK will deserialize into a TooManyRequestsException error. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Amzn-Errortype", "TooManyRequestsException") + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprint(w, `{"__type":"TooManyRequestsException","message":"Rate exceeded"}`) + })) + defer srv.Close() + + // Disable SDK retries so our retry loop handles them + ssoClient := sso.New(sso.Options{ + Region: "us-east-1", + BaseEndpoint: aws.String(srv.URL), + RetryMaxAttempts: 1, + }) + + startTime := time.Unix(1000000, 0) + clock := &testClock{now: startTime} + + p := newTestSSORoleProvider() + p.SSOClient = ssoClient + p.AccountID = "123456789012" + p.RoleName = "TestRole" + p.ssoNow = clock.Now + p.ssoSleep = clock.Sleep + p.ssoLogf = func(string, ...any) {} // suppress log output + + // Provide a cached OIDC token so getOIDCToken succeeds + cache := &testTokenCache{ + token: &ssooidc.CreateTokenOutput{AccessToken: aws.String("test-token")}, + } + p.OIDCTokenCache = cache + + _, err := p.getRoleCredentials(context.Background()) + if err == nil { + t.Fatal("expected error after timeout, got nil") + } + + if !strings.Contains(err.Error(), "persistently") { + t.Fatalf("expected timeout error mentioning 'persistently', got: %v", err) + } + if !strings.Contains(err.Error(), ssoRetryTimeout.String()) { + t.Fatalf("expected error to mention timeout duration %s, got: %v", ssoRetryTimeout, err) + } + + // Verify the clock advanced past the retry timeout + elapsed := clock.now.Sub(time.Unix(1000000, 0)) + if elapsed < ssoRetryTimeout { + t.Fatalf("expected clock to advance at least %s, advanced %s", ssoRetryTimeout, elapsed) + } +} From 4e6d70061dc328ac21dd47bfec68944c69d588ab Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:42:41 -0400 Subject: [PATCH 19/30] test(cli): add table-driven tests for keyringLockKey() Cover all backend types (keychain, file, pass, secret-service, kwallet, wincred, op variants) with config set and empty, plus the fallback to backend name and "aws-vault". Co-Authored-By: Claude Opus 4.6 (1M context) --- cli/global_test.go | 177 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 cli/global_test.go diff --git a/cli/global_test.go b/cli/global_test.go new file mode 100644 index 00000000..19afbf5a --- /dev/null +++ b/cli/global_test.go @@ -0,0 +1,177 @@ +package cli + +import ( + "testing" + + "github.com/byteness/keyring" +) + +func TestKeyringLockKey(t *testing.T) { + tests := []struct { + name string + backend string + config keyring.Config + want string + }{ + // Keychain backend + { + name: "keychain with keychain name", + backend: "keychain", + config: keyring.Config{KeychainName: "my-keychain"}, + want: "keychain:my-keychain", + }, + { + name: "keychain with empty keychain name", + backend: "keychain", + config: keyring.Config{}, + want: "keychain", + }, + + // File backend + { + name: "file with file dir", + backend: "file", + config: keyring.Config{FileDir: "/tmp/keys"}, + want: "file:/tmp/keys", + }, + { + name: "file with empty file dir", + backend: "file", + config: keyring.Config{}, + want: "file", + }, + + // Pass backend: dir and prefix combinations + { + name: "pass with dir and prefix", + backend: "pass", + config: keyring.Config{PassDir: "/store", PassPrefix: "aws"}, + want: "pass:/store:aws", + }, + { + name: "pass with dir only", + backend: "pass", + config: keyring.Config{PassDir: "/store"}, + want: "pass:/store", + }, + { + name: "pass with prefix only", + backend: "pass", + config: keyring.Config{PassPrefix: "aws"}, + want: "pass:aws", + }, + { + name: "pass with neither dir nor prefix", + backend: "pass", + config: keyring.Config{}, + want: "pass", + }, + + // Secret-service backend + { + name: "secret-service with collection name", + backend: "secret-service", + config: keyring.Config{LibSecretCollectionName: "awsvault"}, + want: "secret-service:awsvault", + }, + { + name: "secret-service with empty collection name", + backend: "secret-service", + config: keyring.Config{}, + want: "secret-service", + }, + + // KWallet backend + { + name: "kwallet with folder", + backend: "kwallet", + config: keyring.Config{KWalletFolder: "aws-vault"}, + want: "kwallet:aws-vault", + }, + { + name: "kwallet with empty folder", + backend: "kwallet", + config: keyring.Config{}, + want: "kwallet", + }, + + // WinCred backend + { + name: "wincred with prefix", + backend: "wincred", + config: keyring.Config{WinCredPrefix: "aws-vault"}, + want: "wincred:aws-vault", + }, + { + name: "wincred with empty prefix", + backend: "wincred", + config: keyring.Config{}, + want: "wincred", + }, + + // 1Password backends (all share OPVaultID) + { + name: "op with vault ID", + backend: "op", + config: keyring.Config{OPVaultID: "vault-123"}, + want: "op:vault-123", + }, + { + name: "op with empty vault ID", + backend: "op", + config: keyring.Config{}, + want: "op", + }, + { + name: "op-connect with vault ID", + backend: "op-connect", + config: keyring.Config{OPVaultID: "vault-456"}, + want: "op-connect:vault-456", + }, + { + name: "op-connect with empty vault ID", + backend: "op-connect", + config: keyring.Config{}, + want: "op-connect", + }, + { + name: "op-desktop with vault ID", + backend: "op-desktop", + config: keyring.Config{OPVaultID: "vault-789"}, + want: "op-desktop:vault-789", + }, + { + name: "op-desktop with empty vault ID", + backend: "op-desktop", + config: keyring.Config{}, + want: "op-desktop", + }, + + // Fallback cases + { + name: "unknown backend falls back to backend name", + backend: "some-unknown-backend", + config: keyring.Config{}, + want: "some-unknown-backend", + }, + { + name: "empty backend falls back to aws-vault", + backend: "", + config: keyring.Config{}, + want: "aws-vault", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &AwsVault{ + KeyringBackend: tt.backend, + KeyringConfig: tt.config, + } + got := a.keyringLockKey() + if got != tt.want { + t.Errorf("keyringLockKey() = %q, want %q", got, tt.want) + } + }) + } +} From d372514337b66c36dcaf05b59e20198d852a1e83 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:47:08 -0400 Subject: [PATCH 20/30] refactor: move parallel-safe locking into provider constructors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Eliminate applyParallelSafety spray pattern — each provider constructor now accepts parallelSafe and configures locking at construction time. This removes the post-construction mutation of UseSessionLock and EnableSSOTokenLock, making providers fully configured at birth. The observable behavior is unchanged: same lock files, same serialization guarantees, same test assertions. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/cachedsessionprovider.go | 3 +- vault/cachedsessionprovider_lock_test.go | 18 ++---- vault/ssorolecredentialsprovider.go | 2 +- vault/vault.go | 76 +++++++----------------- 4 files changed, 30 insertions(+), 69 deletions(-) diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index 5242c285..72c42f4b 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -58,12 +58,13 @@ const ( // NewCachedSessionProvider creates a CachedSessionProvider with production // defaults for all internal dependencies. Tests can override unexported fields // (sessionLock, sessionNow, etc.) after construction to inject mocks. -func NewCachedSessionProvider(key SessionMetadata, provider StsSessionProvider, keyring *SessionKeyring, expiryWindow time.Duration) *CachedSessionProvider { +func NewCachedSessionProvider(key SessionMetadata, provider StsSessionProvider, keyring *SessionKeyring, expiryWindow time.Duration, useSessionLock bool) *CachedSessionProvider { return &CachedSessionProvider{ SessionKey: key, SessionProvider: provider, Keyring: keyring, ExpiryWindow: expiryWindow, + UseSessionLock: useSessionLock, sessionLock: NewDefaultLock("aws-vault.session", key.StringForMatching()), sessionLockWait: defaultSessionLockWaitDelay, sessionLockLog: defaultSessionLockLogEvery, diff --git a/vault/cachedsessionprovider_lock_test.go b/vault/cachedsessionprovider_lock_test.go index 198b96f9..c3661fd9 100644 --- a/vault/cachedsessionprovider_lock_test.go +++ b/vault/cachedsessionprovider_lock_test.go @@ -75,8 +75,7 @@ func TestCachedSession_CacheHit_NoLock(t *testing.T) { onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called on cache hit") }, } - p := NewCachedSessionProvider(key, provider, sk, 0) - p.UseSessionLock = true + p := NewCachedSessionProvider(key, provider, sk, 0, true) p.sessionLock = lock got, err := p.RetrieveStsCredentials(context.Background()) @@ -102,8 +101,7 @@ func TestCachedSession_LockDisabled_SkipsLock(t *testing.T) { lock := &testLock{tryResults: []bool{true}} provider := &testSessionProvider{creds: creds} - p := NewCachedSessionProvider(key, provider, sk, 0) - p.UseSessionLock = false + p := NewCachedSessionProvider(key, provider, sk, 0, false) p.sessionLock = lock got, err := p.RetrieveStsCredentials(context.Background()) @@ -135,8 +133,7 @@ func TestCachedSession_LockMiss_ThenCacheHit_NoRefresh(t *testing.T) { onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called when cache fills while waiting") }, } - p := NewCachedSessionProvider(key, provider, sk, 0) - p.UseSessionLock = true + p := NewCachedSessionProvider(key, provider, sk, 0, true) p.sessionLock = lock p.sessionLockWait = 5 * time.Second p.sessionSleep = func(ctx context.Context, d time.Duration) error { @@ -174,8 +171,7 @@ func TestCachedSession_LockAcquired_RecheckCache(t *testing.T) { onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called when cache fills after lock") }, } - p := NewCachedSessionProvider(key, provider, sk, 0) - p.UseSessionLock = true + p := NewCachedSessionProvider(key, provider, sk, 0, true) p.sessionLock = lock got, err := p.RetrieveStsCredentials(context.Background()) @@ -204,8 +200,7 @@ func TestCachedSession_LockHeldThroughCacheSet(t *testing.T) { sk := &SessionKeyring{Keyring: wrappedKeyring} provider := &testSessionProvider{creds: creds} - p := NewCachedSessionProvider(key, provider, sk, 0) - p.UseSessionLock = true + p := NewCachedSessionProvider(key, provider, sk, 0, true) p.sessionLock = lock _, err := p.RetrieveStsCredentials(context.Background()) @@ -234,8 +229,7 @@ func TestCachedSession_LockWaitLogs(t *testing.T) { clock := &testClock{now: time.Unix(0, 0), cancel: cancel, cancelAfter: 4} var logTimes []time.Time - p := NewCachedSessionProvider(key, provider, sk, 0) - p.UseSessionLock = true + p := NewCachedSessionProvider(key, provider, sk, 0, true) p.sessionLock = lock p.sessionLockWait = 5 * time.Second p.sessionLockLog = 15 * time.Second diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 7c2412e5..6e1c899a 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -112,7 +112,7 @@ func (p *SSORoleCredentialsProvider) initSSODefaults() { } // EnableSSOTokenLock creates the SSO token lock for cross-process coordination. -// Called by applyParallelSafety after setting UseSSOTokenLock. +// Called at construction time when parallelSafe is true. func (p *SSORoleCredentialsProvider) EnableSSOTokenLock() { p.UseSSOTokenLock = true if !p.UseStdout && p.ssoTokenLock == nil { diff --git a/vault/vault.go b/vault/vault.go index efff61ab..f4daa204 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -51,7 +51,7 @@ func NewMasterCredentialsProvider(k *CredentialKeyring, credentialsName string) return &KeyringProvider{k, credentialsName} } -func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { +func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig, useSessionCache bool, parallelSafe bool) (aws.CredentialsProvider, error) { cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints, config.EndpointURL) sessionTokenProvider := &SessionTokenProvider{ @@ -70,6 +70,7 @@ func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Ke sessionTokenProvider, &SessionKeyring{Keyring: k}, defaultExpirationWindow, + parallelSafe, ), nil } @@ -77,7 +78,7 @@ func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Ke } // NewAssumeRoleProvider returns a provider that generates credentials using AssumeRole -func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { +func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig, useSessionCache bool, parallelSafe bool) (aws.CredentialsProvider, error) { cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints, config.EndpointURL) p := &AssumeRoleProvider{ @@ -102,6 +103,7 @@ func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyr p, &SessionKeyring{Keyring: k}, defaultExpirationWindow, + parallelSafe, ), nil } @@ -110,7 +112,7 @@ func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyr // NewAssumeRoleWithWebIdentityProvider returns a provider that generates // credentials using AssumeRoleWithWebIdentity -func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { +func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool, parallelSafe bool) (aws.CredentialsProvider, error) { cfg := NewAwsConfig(config.Region, config.STSRegionalEndpoints, config.EndpointURL) p := &AssumeRoleWithWebIdentityProvider{ @@ -131,6 +133,7 @@ func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConf p, &SessionKeyring{Keyring: k}, defaultExpirationWindow, + parallelSafe, ), nil } @@ -138,7 +141,7 @@ func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConf } // NewSSORoleCredentialsProvider creates a provider for SSO credentials -func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { +func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool, parallelSafe bool) (aws.CredentialsProvider, error) { cfg := NewAwsConfig(config.SSORegion, config.STSRegionalEndpoints, config.EndpointURL) ssoRoleCredentialsProvider := &SSORoleCredentialsProvider{ @@ -150,6 +153,9 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, use UseStdout: config.SSOUseStdout, } ssoRoleCredentialsProvider.initSSODefaults() + if parallelSafe { + ssoRoleCredentialsProvider.EnableSSOTokenLock() + } if useSessionCache { ssoRoleCredentialsProvider.OIDCTokenCache = OIDCTokenKeyring{Keyring: k} @@ -162,6 +168,7 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, use ssoRoleCredentialsProvider, &SessionKeyring{Keyring: k}, defaultExpirationWindow, + parallelSafe, ), nil } @@ -170,7 +177,7 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, use // NewCredentialProcessProvider creates a provider to retrieve credentials from an external // executable as described in https://docs.aws.amazon.com/cli/latest/topic/config-vars.html#sourcing-credentials-from-external-processes -func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { +func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool, parallelSafe bool) (aws.CredentialsProvider, error) { credentialProcessProvider := &CredentialProcessProvider{ CredentialProcess: config.CredentialProcess, } @@ -184,6 +191,7 @@ func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig, useS credentialProcessProvider, &SessionKeyring{Keyring: k}, defaultExpirationWindow, + parallelSafe, ), nil } @@ -256,19 +264,17 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, if err != nil { return nil, err } - sourcecredsProvider = t.applyParallelSafety(sourcecredsProvider) if hasStoredCredentials || !config.HasRole() { if canUseGetSessionToken, reason := t.canUseGetSessionToken(config); !canUseGetSessionToken { log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) if !config.HasRole() { - return t.applyParallelSafety(sourcecredsProvider), nil + return sourcecredsProvider, nil } } t.chainedMfa = config.MfaSerial log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) - sourcecredsProvider, err = NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) - sourcecredsProvider = t.applyParallelSafety(sourcecredsProvider) + sourcecredsProvider, err = NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) if !config.HasRole() || err != nil { return sourcecredsProvider, err } @@ -280,11 +286,7 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, config.MfaSerial = "" } log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config)) - provider, err := NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) - if err != nil { - return nil, err - } - return t.applyParallelSafety(provider), nil + return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) } if isMasterCredentialsProvider(sourcecredsProvider) { @@ -292,16 +294,12 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, if canUseGetSessionToken { t.chainedMfa = config.MfaSerial log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) - provider, err := NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) - if err != nil { - return nil, err - } - return t.applyParallelSafety(provider), nil + return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) } log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) } - return t.applyParallelSafety(sourcecredsProvider), nil + return sourcecredsProvider, nil } func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (aws.CredentialsProvider, error) { @@ -316,54 +314,22 @@ func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (a if config.HasSSOStartURL() { log.Printf("profile %s: using SSO role credentials", config.ProfileName) - provider, err := NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache) - if err != nil { - return nil, err - } - return t.applyParallelSafety(provider), nil + return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) } if config.HasWebIdentity() { log.Printf("profile %s: using web identity", config.ProfileName) - provider, err := NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache) - if err != nil { - return nil, err - } - return t.applyParallelSafety(provider), nil + return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) } if config.HasCredentialProcess() { log.Printf("profile %s: using credential process", config.ProfileName) - provider, err := NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache) - if err != nil { - return nil, err - } - return t.applyParallelSafety(provider), nil + return NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) } return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName) } -func (t *TempCredentialsCreator) applyParallelSafety(provider aws.CredentialsProvider) aws.CredentialsProvider { - if !t.ParallelSafe { - return provider - } - - if cached, ok := provider.(*CachedSessionProvider); ok { - cached.UseSessionLock = true - if ssoProvider, ok := cached.SessionProvider.(*SSORoleCredentialsProvider); ok { - ssoProvider.EnableSSOTokenLock() - } - return provider - } - - if ssoProvider, ok := provider.(*SSORoleCredentialsProvider); ok { - ssoProvider.EnableSSOTokenLock() - } - - return provider -} - // canUseGetSessionToken determines if GetSessionToken should be used, and if not returns a reason func (t *TempCredentialsCreator) canUseGetSessionToken(c *ProfileConfig) (bool, string) { if t.DisableSessions { From a3a50bf0f5326a5c88990b57f0ba024322d889d0 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:48:23 -0400 Subject: [PATCH 21/30] test(lock): add lockedKeyring tests for lock-wait, timeout, and error joining Tests verify: - Lock retry behavior when lock is not immediately available - Timeout after configurable duration when lock is never acquired - errors.Join of work error and unlock error Also makes lockTimeout a struct field (was hardcoded constant) so tests can inject shorter timeouts. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/locked_keyring.go | 36 +++++----- vault/locked_keyring_test.go | 136 +++++++++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+), 17 deletions(-) create mode 100644 vault/locked_keyring_test.go diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go index 38d503a1..fb257435 100644 --- a/vault/locked_keyring.go +++ b/vault/locked_keyring.go @@ -19,13 +19,14 @@ type lockedKeyring struct { // process could race on the try-lock loop. mu sync.Mutex - lockKey string - lockWait time.Duration - lockLog time.Duration - warnAfter time.Duration - lockNow func() time.Time - lockSleep func(context.Context, time.Duration) error - lockLogf lockLogger + lockKey string + lockTimeout time.Duration + lockWait time.Duration + lockLog time.Duration + warnAfter time.Duration + lockNow func() time.Time + lockSleep func(context.Context, time.Duration) error + lockLogf lockLogger } const ( @@ -56,15 +57,16 @@ const ( // to serialize keyring operations. func NewLockedKeyring(kr keyring.Keyring, lockKey string) keyring.Keyring { return &lockedKeyring{ - inner: kr, - lock: NewDefaultLock("aws-vault.keyring", lockKey), - lockKey: lockKey, - lockWait: defaultKeyringLockWaitDelay, - lockLog: defaultKeyringLockLogEvery, - warnAfter: defaultKeyringLockWarnAfter, - lockNow: time.Now, - lockSleep: defaultContextSleep, - lockLogf: log.Printf, + inner: kr, + lock: NewDefaultLock("aws-vault.keyring", lockKey), + lockKey: lockKey, + lockTimeout: defaultKeyringLockTimeout, + lockWait: defaultKeyringLockWaitDelay, + lockLog: defaultKeyringLockLogEvery, + warnAfter: defaultKeyringLockWarnAfter, + lockNow: time.Now, + lockSleep: defaultContextSleep, + lockLogf: log.Printf, } } @@ -77,7 +79,7 @@ func (k *lockedKeyring) withLock(fn func() error) error { // in-flight keyring operations. This timeout is a safety net for the lock-wait // loop: if the lock holder is hung (e.g. a stuck gpg subprocess in the pass // backend), waiters will eventually give up rather than blocking indefinitely. - ctx, cancel := context.WithTimeout(context.Background(), defaultKeyringLockTimeout) + ctx, cancel := context.WithTimeout(context.Background(), k.lockTimeout) defer cancel() _, err := withProcessLock(ctx, k.lock, lockWaiterOpts{ diff --git a/vault/locked_keyring_test.go b/vault/locked_keyring_test.go new file mode 100644 index 00000000..cfdb5b57 --- /dev/null +++ b/vault/locked_keyring_test.go @@ -0,0 +1,136 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/byteness/keyring" +) + +// testUnlockErrLock is a testLock variant whose Unlock returns a configured error. +type testUnlockErrLock struct { + testLock + unlockErr error +} + +func (l *testUnlockErrLock) Unlock() error { + l.unlockCalls++ + l.locked = false + return l.unlockErr +} + +func newTestLockedKeyring(inner keyring.Keyring, lock ProcessLock, clock *testClock) *lockedKeyring { + return &lockedKeyring{ + inner: inner, + lock: lock, + lockKey: "test", + lockTimeout: defaultKeyringLockTimeout, + lockWait: 100 * time.Millisecond, + lockLog: 15 * time.Second, + warnAfter: 5 * time.Second, + lockNow: clock.Now, + lockSleep: clock.Sleep, + lockLogf: func(string, ...any) {}, + } +} + +func TestLockedKeyring_LockWaitRetries(t *testing.T) { + // Lock fails twice, then succeeds on the third attempt. + lock := &testLock{tryResults: []bool{false, false, true}} + kr := keyring.NewArrayKeyring([]keyring.Item{ + {Key: "foo", Data: []byte("bar")}, + }) + clock := &testClock{now: time.Unix(0, 0)} + + lk := newTestLockedKeyring(kr, lock, clock) + + item, err := lk.Get("foo") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(item.Data) != "bar" { + t.Fatalf("unexpected data: %s", string(item.Data)) + } + if lock.tryCalls != 3 { + t.Fatalf("expected 3 lock attempts, got %d", lock.tryCalls) + } + if clock.sleepCalls != 2 { + t.Fatalf("expected 2 sleep calls, got %d", clock.sleepCalls) + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} + +func TestLockedKeyring_Timeout(t *testing.T) { + // Lock is never acquired. With a short lockTimeout the context should + // time out and withLock should return context.DeadlineExceeded. + lock := &testLock{} // tryResults is empty so TryLock always returns false + kr := keyring.NewArrayKeyring(nil) + clock := &testClock{now: time.Unix(0, 0)} + + lk := newTestLockedKeyring(kr, lock, clock) + // Use a very short real timeout so the test completes quickly. + lk.lockTimeout = 50 * time.Millisecond + // Use real sleep so the context deadline fires from wall-clock time. + lk.lockSleep = defaultContextSleep + lk.lockWait = 10 * time.Millisecond + + _, err := lk.Keys() + if err == nil { + t.Fatal("expected timeout error, got nil") + } + // The context should have been cancelled via timeout. The error will + // be context.DeadlineExceeded because withLock uses context.WithTimeout. + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context.DeadlineExceeded, got: %v", err) + } + // Verify that at least one lock attempt was made before timing out. + if lock.tryCalls < 1 { + t.Fatalf("expected at least 1 lock attempt, got %d", lock.tryCalls) + } +} + +func TestLockedKeyring_UnlockErrorJoined(t *testing.T) { + // Both the work function and Unlock return errors; they should be joined + // via errors.Join so that errors.Is can unwrap both. + workErr := fmt.Errorf("work failed") + unlockErr := fmt.Errorf("unlock broken") + + lock := &testUnlockErrLock{ + testLock: testLock{tryResults: []bool{true}}, + unlockErr: unlockErr, + } + + // Use a keyring whose Remove always fails with workErr. + inner := &failingKeyring{removeErr: workErr} + clock := &testClock{now: time.Unix(0, 0)} + + lk := newTestLockedKeyring(inner, lock, clock) + + err := lk.Remove("anything") + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, workErr) { + t.Fatalf("expected joined error to contain work error, got: %v", err) + } + // The unlock error is wrapped as "unlock keyring lock: " + // using %w, so errors.Is can unwrap through the wrapping. + if !errors.Is(err, unlockErr) { + t.Fatalf("expected joined error to contain unlock error, got: %v", err) + } +} + +// failingKeyring is a keyring.Keyring that returns configured errors. +type failingKeyring struct { + keyring.Keyring + removeErr error +} + +func (k *failingKeyring) Remove(string) error { + return k.removeErr +} From 400265c2bcea2e8e7caed5f57601675c6718c88b Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:50:04 -0400 Subject: [PATCH 22/30] test: add parallel-safe tests for GetSessionToken provider path Verify that UseSessionLock is set on the CachedSessionProvider when ParallelSafe=true for the GetSessionToken (stored credentials) path. This complements the existing SSO path test. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/vault_test.go | 97 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/vault/vault_test.go b/vault/vault_test.go index 3615d7e8..569ba886 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -124,6 +124,103 @@ sso_registration_scopes=sso:account:access } } +func TestTempCredentialsProviderParallelSafeGetSessionToken(t *testing.T) { + f := newConfigFile(t, []byte(` +[profile creds] +region = us-east-1 +`)) + defer os.Remove(f) + configFile, err := vault.LoadConfig(f) + if err != nil { + t.Fatal(err) + } + configLoader := &vault.ConfigLoader{ + File: configFile, + ActiveProfile: "creds", + BaseConfig: vault.ProfileConfig{MfaPromptMethod: "terminal"}, + } + config, err := configLoader.GetProfileConfig("creds") + if err != nil { + t.Fatalf("Should have found a profile: %v", err) + } + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{ + {Key: "creds", Data: []byte(`{"AccessKeyID":"AKIAIOSFODNN7EXAMPLE","SecretAccessKey":"secret"}`)}, + })} + provider, err := vault.NewTempCredentialsProviderWithOptions( + config, + ckr, + false, + false, + vault.TempCredentialsOptions{ParallelSafe: true}, + ) + if err != nil { + t.Fatal(err) + } + + cached, ok := provider.(*vault.CachedSessionProvider) + if !ok { + t.Fatalf("Expected CachedSessionProvider, got %T", provider) + } + if !cached.UseSessionLock { + t.Fatalf("Expected UseSessionLock to be true") + } + _, ok = cached.SessionProvider.(*vault.SessionTokenProvider) + if !ok { + t.Fatalf("Expected SessionTokenProvider, got %T", cached.SessionProvider) + } +} + +func TestTempCredentialsProviderParallelSafeAssumeRole(t *testing.T) { + f := newConfigFile(t, []byte(` +[profile source] +region = us-east-1 + +[profile role] +source_profile = source +role_arn = arn:aws:iam::222222222222:role/role +mfa_serial = arn:aws:iam::111111111111:mfa/user +region = us-east-1 +`)) + defer os.Remove(f) + configFile, err := vault.LoadConfig(f) + if err != nil { + t.Fatal(err) + } + configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "role"} + config, err := configLoader.GetProfileConfig("role") + if err != nil { + t.Fatalf("Should have found a profile: %v", err) + } + config.MfaToken = "123456" // avoid interactive MFA prompt + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{ + {Key: "source", Data: []byte(`{"AccessKeyID":"AKIAIOSFODNN7EXAMPLE","SecretAccessKey":"secret"}`)}, + })} + provider, err := vault.NewTempCredentialsProviderWithOptions( + config, + ckr, + true, // disableSessions: skip GetSessionToken so AssumeRole gets the MFA + false, + vault.TempCredentialsOptions{ParallelSafe: true}, + ) + if err != nil { + t.Fatal(err) + } + + cached, ok := provider.(*vault.CachedSessionProvider) + if !ok { + t.Fatalf("Expected CachedSessionProvider, got %T", provider) + } + if !cached.UseSessionLock { + t.Fatalf("Expected UseSessionLock to be true") + } + _, ok = cached.SessionProvider.(*vault.AssumeRoleProvider) + if !ok { + t.Fatalf("Expected AssumeRoleProvider, got %T", cached.SessionProvider) + } +} + func TestTempCredentialsProviderParallelSafeSSOLocks(t *testing.T) { config := &vault.ProfileConfig{ ProfileName: "sso-profile", From d0a294102fbfb310fbd7a47b5f4fc4684fe47d45 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:50:44 -0400 Subject: [PATCH 23/30] fix(test): simplify parallel-safe provider tests Use ProfileConfig directly instead of config file parsing to avoid needing MfaPromptMethod. Add AssumeRole path test. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/vault_test.go | 52 +++++++++++++-------------------------------- 1 file changed, 15 insertions(+), 37 deletions(-) diff --git a/vault/vault_test.go b/vault/vault_test.go index 569ba886..8ec3d165 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -125,23 +125,10 @@ sso_registration_scopes=sso:account:access } func TestTempCredentialsProviderParallelSafeGetSessionToken(t *testing.T) { - f := newConfigFile(t, []byte(` -[profile creds] -region = us-east-1 -`)) - defer os.Remove(f) - configFile, err := vault.LoadConfig(f) - if err != nil { - t.Fatal(err) - } - configLoader := &vault.ConfigLoader{ - File: configFile, - ActiveProfile: "creds", - BaseConfig: vault.ProfileConfig{MfaPromptMethod: "terminal"}, - } - config, err := configLoader.GetProfileConfig("creds") - if err != nil { - t.Fatalf("Should have found a profile: %v", err) + config := &vault.ProfileConfig{ + ProfileName: "creds", + Region: "us-east-1", + MfaToken: "123456", // provide token to avoid interactive prompt } ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{ @@ -172,27 +159,18 @@ region = us-east-1 } func TestTempCredentialsProviderParallelSafeAssumeRole(t *testing.T) { - f := newConfigFile(t, []byte(` -[profile source] -region = us-east-1 - -[profile role] -source_profile = source -role_arn = arn:aws:iam::222222222222:role/role -mfa_serial = arn:aws:iam::111111111111:mfa/user -region = us-east-1 -`)) - defer os.Remove(f) - configFile, err := vault.LoadConfig(f) - if err != nil { - t.Fatal(err) - } - configLoader := &vault.ConfigLoader{File: configFile, ActiveProfile: "role"} - config, err := configLoader.GetProfileConfig("role") - if err != nil { - t.Fatalf("Should have found a profile: %v", err) + config := &vault.ProfileConfig{ + ProfileName: "role", + Region: "us-east-1", + RoleARN: "arn:aws:iam::222222222222:role/role", + MfaSerial: "arn:aws:iam::111111111111:mfa/user", + MfaToken: "123456", // provide token to avoid interactive prompt + SourceProfileName: "source", + SourceProfile: &vault.ProfileConfig{ + ProfileName: "source", + Region: "us-east-1", + }, } - config.MfaToken = "123456" // avoid interactive MFA prompt ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{ {Key: "source", Data: []byte(`{"AccessKeyID":"AKIAIOSFODNN7EXAMPLE","SecretAccessKey":"secret"}`)}, From 3727f5c03d9dc3fafc28be58168d499403dbc0f3 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Wed, 8 Apr 2026 09:29:30 -0400 Subject: [PATCH 24/30] fix(sso): clamp jitteredBackoff overflow to max instead of base MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit At attempt ~37, base << uint(attempt-1) overflows int64 and produces a negative value. The existing guard caught this but reset to base (200ms) instead of max (5s), making late retries more aggressive under sustained 429s — exactly the wrong behavior during a rate-limit storm. Also extends TestJitteredBackoffRespectsMax to cover attempts 20-60 with both min and max bounds, catching the overflow regression. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/ssorolecredentialsprovider.go | 4 +++- vault/ssorolecredentialsprovider_retry_test.go | 11 +++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 6e1c899a..35aeed0c 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -456,7 +456,9 @@ func jitteredBackoff(base, max time.Duration, attempt int) time.Duration { capDelay = max } if capDelay < base { - capDelay = base + // Overflow: large shift wrapped negative; clamp to max, not base, + // so late retries stay backed off instead of becoming aggressive. + capDelay = max } return jitterDelay(capDelay) } diff --git a/vault/ssorolecredentialsprovider_retry_test.go b/vault/ssorolecredentialsprovider_retry_test.go index 0aedb844..fc3464df 100644 --- a/vault/ssorolecredentialsprovider_retry_test.go +++ b/vault/ssorolecredentialsprovider_retry_test.go @@ -111,16 +111,19 @@ func TestJitteredBackoffRespectsMax(t *testing.T) { base := 200 * time.Millisecond max := 5 * time.Second - // At high attempt numbers the cap should be max, not overflow + // At high attempt numbers the cap should be max, not overflow. + // This includes attempts 37+ where base<<(attempt-1) overflows int64; + // the delay must stay clamped to max, not collapse to base. + minDelay := time.Duration(float64(max) * ssoRetryAfterJitterMin) maxDelay := time.Duration(float64(max) * ssoRetryAfterJitterMax) - for attempt := 20; attempt <= 30; attempt++ { + for attempt := 20; attempt <= 60; attempt++ { for i := 0; i < 10; i++ { delay := jitteredBackoff(base, max, attempt) if maxDelay < delay { t.Fatalf("attempt %d: delay %s exceeds max jittered cap %s", attempt, delay, maxDelay) } - if delay < 0 { - t.Fatalf("attempt %d: negative delay %s", attempt, delay) + if delay < minDelay { + t.Fatalf("attempt %d: delay %s below min jittered cap %s (overflow regression)", attempt, delay, minDelay) } } } From 0e0793d5c20c8d326d04b5f9dac51afdf89a0721 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Wed, 8 Apr 2026 09:29:40 -0400 Subject: [PATCH 25/30] docs(lock): note that lockedKeyring.mu blocks without timeout The in-process sync.Mutex is not covered by the 2-minute flock timeout. If a keyring operation hangs (e.g. a stuck gpg subprocess), other goroutines in the same process block indefinitely on mu.Lock(). Document this limitation since the keyring.Keyring interface is not context-aware. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/locked_keyring.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go index fb257435..d7c8a87f 100644 --- a/vault/locked_keyring.go +++ b/vault/locked_keyring.go @@ -17,6 +17,13 @@ type lockedKeyring struct { // mu serializes in-process access. The flock only coordinates across // processes; without this mutex, concurrent goroutines in the same // process could race on the try-lock loop. + // + // NOTE: mu.Lock() blocks without a timeout. The 2-minute timeout + // (lockTimeout) only applies to the flock wait loop inside withLock. + // If a keyring operation hangs while holding mu (e.g. a stuck gpg + // subprocess), other goroutines in the same process will block + // indefinitely. The keyring.Keyring interface is not context-aware, + // so there is no clean way to cancel in-flight operations. mu sync.Mutex lockKey string From 6a334be65efd1226821056ecd4c48e2517edf939 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Wed, 8 Apr 2026 09:29:46 -0400 Subject: [PATCH 26/30] fix(cli): simplify --parallel-safe flag description The previous description enumerated specific commands (exec, export, rotate) but the keyring wrapping via a.Keyring() is app-global. Describe what the flag does without enumerating commands to avoid inaccuracy. Co-Authored-By: Claude Opus 4.6 (1M context) --- cli/global.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/global.go b/cli/global.go index 92241a9d..eac08fa5 100644 --- a/cli/global.go +++ b/cli/global.go @@ -270,7 +270,7 @@ func ConfigureGlobals(app *kingpin.Application) *AwsVault { Envar("AWS_VAULT_BIOMETRICS"). BoolVar(&a.UseBiometrics) - app.Flag("parallel-safe", "Enable cross-process locking for keychain and cached credentials (applies to exec, export, rotate; not login)"). + app.Flag("parallel-safe", "Enable cross-process locking for keyring operations, session caching, and SSO browser flows"). Envar("AWS_VAULT_PARALLEL_SAFE"). BoolVar(&a.ParallelSafe) From c2d068b262f03b4cd6e37fa60ec6d99a2ec1e515 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:41:33 -0400 Subject: [PATCH 27/30] fix(lock): separate lock-wait timeout from work context MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The lock-wait timeout (context.WithTimeout) was wrapping the entire withProcessLock call, which meant it capped both the time waiting for the lock AND the time doing work under the lock. For SSO browser auth, this could cancel a legitimate login that takes longer than 2 minutes. Fix: doWork closures now capture the caller's ctx (via `func(_ context.Context)`) instead of receiving the timeout-constrained waitCtx. The lock-wait timeout only bounds the polling loop, not the work done once the lock is acquired. Also replaces the redundant `Lock ProcessLock` field in lockWaiterOpts with `LockPath string` — the field was only used for `.Path()` in log messages, and passing the full lock created a risk of divergence with the positional lock parameter in withProcessLock. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/cachedsessionprovider.go | 30 +++++++----- vault/cachedsessionprovider_lock_test.go | 48 +++++++++++++++++++ vault/lock_waiter.go | 6 +-- vault/locked_keyring.go | 4 +- vault/ssorolecredentialsprovider.go | 14 ++++-- vault/ssorolecredentialsprovider_lock_test.go | 39 +++++++++++++++ 6 files changed, 120 insertions(+), 21 deletions(-) diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index 72c42f4b..9036bc45 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -24,10 +24,11 @@ type CachedSessionProvider struct { Keyring *SessionKeyring ExpiryWindow time.Duration UseSessionLock bool - sessionLock ProcessLock - sessionLockWait time.Duration - sessionLockLog time.Duration - sessionNow func() time.Time + sessionLock ProcessLock + sessionLockWait time.Duration + sessionLockLog time.Duration + sessionLockTimeout time.Duration + sessionNow func() time.Time sessionSleep func(context.Context, time.Duration) error sessionLogf lockLogger } @@ -65,10 +66,11 @@ func NewCachedSessionProvider(key SessionMetadata, provider StsSessionProvider, Keyring: keyring, ExpiryWindow: expiryWindow, UseSessionLock: useSessionLock, - sessionLock: NewDefaultLock("aws-vault.session", key.StringForMatching()), - sessionLockWait: defaultSessionLockWaitDelay, - sessionLockLog: defaultSessionLockLogEvery, - sessionNow: time.Now, + sessionLock: NewDefaultLock("aws-vault.session", key.StringForMatching()), + sessionLockWait: defaultSessionLockWaitDelay, + sessionLockLog: defaultSessionLockLogEvery, + sessionLockTimeout: defaultSessionLockTimeout, + sessionNow: time.Now, sessionSleep: defaultContextSleep, sessionLogf: log.Printf, } @@ -100,11 +102,11 @@ func (p *CachedSessionProvider) getCachedSession() (creds *ststypes.Credentials, } func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststypes.Credentials, error) { - ctx, cancel := context.WithTimeout(ctx, defaultSessionLockTimeout) + waitCtx, cancel := context.WithTimeout(ctx, p.sessionLockTimeout) defer cancel() - return withProcessLock(ctx, p.sessionLock, lockWaiterOpts{ - Lock: p.sessionLock, + return withProcessLock(waitCtx, p.sessionLock, lockWaiterOpts{ + LockPath: p.sessionLock.Path(), WarnMsg: "Waiting for session lock at %s\n", LogMsg: "Waiting for session lock at %s", WaitDelay: p.sessionLockWait, @@ -122,7 +124,11 @@ func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststyp return processLockResult[*ststypes.Credentials]{value: creds, ok: true}, nil } return processLockResult[*ststypes.Credentials]{}, nil - }, func(ctx context.Context) (*ststypes.Credentials, error) { + }, func(_ context.Context) (*ststypes.Credentials, error) { + // Use the caller's ctx (not the lock-wait timeout ctx) for actual + // work. The lock-wait timeout bounds how long we wait for the lock, + // not how long the work takes once we hold it. + // Recheck cache after acquiring lock — another process may have filled it. creds, cached, cacheErr := p.getCachedSession() if cacheErr == nil && cached { diff --git a/vault/cachedsessionprovider_lock_test.go b/vault/cachedsessionprovider_lock_test.go index c3661fd9..919fc447 100644 --- a/vault/cachedsessionprovider_lock_test.go +++ b/vault/cachedsessionprovider_lock_test.go @@ -253,3 +253,51 @@ func TestCachedSession_LockWaitLogs(t *testing.T) { t.Fatalf("unexpected second log time: %s", logTimes[1]) } } + +// ctxAwareSessionProvider is a test StsSessionProvider that respects context +// cancellation, used to verify the work context is not the lock-wait timeout. +type ctxAwareSessionProvider struct { + creds *types.Credentials + delay time.Duration +} + +func (p *ctxAwareSessionProvider) RetrieveStsCredentials(ctx context.Context) (*types.Credentials, error) { + select { + case <-time.After(p.delay): + return p.creds, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (p *ctxAwareSessionProvider) Retrieve(context.Context) (aws.Credentials, error) { + return aws.Credentials{}, nil +} + +func TestCachedSession_WorkNotCancelledByLockTimeout(t *testing.T) { + // The lock-wait timeout should only bound how long we wait for the + // lock, not how long the work takes. Simulate a provider that takes + // longer than the lock-wait timeout and verify it completes. + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + lock := &testLock{tryResults: []bool{true}} + + provider := &ctxAwareSessionProvider{creds: creds, delay: 50 * time.Millisecond} + + p := NewCachedSessionProvider(key, provider, sk, 0, true) + p.sessionLock = lock + p.sessionLockTimeout = 10 * time.Millisecond + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error (work cancelled by lock-wait timeout?): %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} diff --git a/vault/lock_waiter.go b/vault/lock_waiter.go index 1c5a2665..7a6488c9 100644 --- a/vault/lock_waiter.go +++ b/vault/lock_waiter.go @@ -10,7 +10,7 @@ type lockLogger func(string, ...any) // lockWaiterOpts configures a lockWaiter. All fields are required except // Now, Sleep, and Warnf which have sensible defaults. type lockWaiterOpts struct { - Lock ProcessLock + LockPath string WarnMsg string LogMsg string WaitDelay time.Duration @@ -47,12 +47,12 @@ func (w *lockWaiter) sleepAfterMiss(ctx context.Context) error { } if !w.warned && w.opts.WarnAfter <= now.Sub(w.waitStart) { if w.opts.Warnf != nil { - w.opts.Warnf(w.opts.WarnMsg, w.opts.Lock.Path()) + w.opts.Warnf(w.opts.WarnMsg, w.opts.LockPath) } w.warned = true } if w.opts.Logf != nil && (w.lastLog.IsZero() || w.opts.LogEvery <= now.Sub(w.lastLog)) { - w.opts.Logf(w.opts.LogMsg, w.opts.Lock.Path()) + w.opts.Logf(w.opts.LogMsg, w.opts.LockPath) w.lastLog = now } diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go index d7c8a87f..51530de6 100644 --- a/vault/locked_keyring.go +++ b/vault/locked_keyring.go @@ -90,7 +90,7 @@ func (k *lockedKeyring) withLock(fn func() error) error { defer cancel() _, err := withProcessLock(ctx, k.lock, lockWaiterOpts{ - Lock: k.lock, + LockPath: k.lock.Path(), WarnMsg: "Waiting for keyring lock at %s\n", LogMsg: "Waiting for keyring lock at %s", WaitDelay: k.lockWait, @@ -102,7 +102,7 @@ func (k *lockedKeyring) withLock(fn func() error) error { Warnf: func(format string, args ...any) { fmt.Fprintf(os.Stderr, format, args...) }, - }, "keyring", nil, func(ctx context.Context) (struct{}, error) { + }, "keyring", nil, func(_ context.Context) (struct{}, error) { return struct{}{}, fn() }) return err diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 35aeed0c..0c01a13a 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -42,6 +42,7 @@ type SSORoleCredentialsProvider struct { ssoTokenLock ProcessLock ssoLockWait time.Duration ssoLockLog time.Duration + ssoLockTimeout time.Duration ssoNow func() time.Time ssoSleep func(context.Context, time.Duration) error ssoLogf lockLogger @@ -105,6 +106,7 @@ const ( func (p *SSORoleCredentialsProvider) initSSODefaults() { p.ssoLockWait = defaultSSOLockWaitDelay p.ssoLockLog = defaultSSOLockLogEvery + p.ssoLockTimeout = defaultSSOLockTimeout p.ssoNow = time.Now p.ssoSleep = defaultContextSleep p.ssoLogf = log.Printf @@ -278,11 +280,11 @@ type oidcTokenResult struct { } func (p *SSORoleCredentialsProvider) getOIDCTokenWithLock(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { - ctx, cancel := context.WithTimeout(ctx, defaultSSOLockTimeout) + waitCtx, cancel := context.WithTimeout(ctx, p.ssoLockTimeout) defer cancel() - result, err := withProcessLock(ctx, p.ssoTokenLock, lockWaiterOpts{ - Lock: p.ssoTokenLock, + result, err := withProcessLock(waitCtx, p.ssoTokenLock, lockWaiterOpts{ + LockPath: p.ssoTokenLock.Path(), WarnMsg: "Waiting for SSO lock at %s\n", LogMsg: "Waiting for SSO lock at %s", WaitDelay: p.ssoLockWait, @@ -303,7 +305,11 @@ func (p *SSORoleCredentialsProvider) getOIDCTokenWithLock(ctx context.Context) ( return processLockResult[oidcTokenResult]{value: oidcTokenResult{token, cached}, ok: true}, nil } return processLockResult[oidcTokenResult]{}, nil - }, func(ctx context.Context) (oidcTokenResult, error) { + }, func(_ context.Context) (oidcTokenResult, error) { + // Use the caller's ctx (not the lock-wait timeout ctx) for actual + // work. The lock-wait timeout bounds how long we wait for the lock, + // not how long the work takes once we hold it. + // Recheck cache after acquiring lock — another process may have filled it. token, cached, err := p.getCachedOIDCToken() if err != nil { diff --git a/vault/ssorolecredentialsprovider_lock_test.go b/vault/ssorolecredentialsprovider_lock_test.go index df89b6a8..ffbd15a2 100644 --- a/vault/ssorolecredentialsprovider_lock_test.go +++ b/vault/ssorolecredentialsprovider_lock_test.go @@ -278,3 +278,42 @@ func TestGetOIDCToken_LockWaitLogs(t *testing.T) { t.Fatalf("unexpected second log time: %s", logTimes[1]) } } + +func TestGetOIDCToken_WorkNotCancelledByLockTimeout(t *testing.T) { + // The lock-wait timeout should only bound how long we wait for the + // lock, not how long the work takes. Simulate work that takes longer + // than the lock-wait timeout and verify it completes. + freshToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("fresh")} + lock := &testLock{tryResults: []bool{true}} + cache := &testTokenCache{setLock: lock} + + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true + p.ssoLockTimeout = 10 * time.Millisecond + p.newOIDCTokenFn = func(ctx context.Context) (*ssooidc.CreateTokenOutput, error) { + // Work takes longer than the lock-wait timeout. + // If the timeout incorrectly cancels work, ctx.Err() fires. + select { + case <-time.After(50 * time.Millisecond): + return freshToken, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error (work cancelled by lock-wait timeout?): %v", err) + } + if cached { + t.Fatalf("expected non-cached token") + } + if token != freshToken { + t.Fatalf("unexpected token returned") + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} From 0933895faf1a4aaf381b27db79534eb77035967d Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:41:51 -0400 Subject: [PATCH 28/30] fix(sso): use context-aware sleep in OIDC device-flow polling The OIDC device-flow polling loop used time.Sleep which doesn't respond to context cancellation. While holding the SSO flock, this blocked other processes from attempting the lock and prevented prompt cancellation on Ctrl+C. Switch to p.ssoSleep (defaultContextSleep) so the sleep respects the ctx passed to newOIDCToken. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/ssorolecredentialsprovider.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 0c01a13a..f12ba217 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -389,7 +389,9 @@ func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc var ape *ssooidctypes.AuthorizationPendingException if errors.As(err, &ape) { - time.Sleep(retryInterval) + if sleepErr := p.ssoSleep(ctx, retryInterval); sleepErr != nil { + return nil, sleepErr + } continue } From eed6b2143a97339a2ff1969af03836e88a95869e Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:42:21 -0400 Subject: [PATCH 29/30] fix(session): log non-trivial cache read errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit getCachedSession errors were silently swallowed — the code fell through to re-fetch credentials on any error, which is correct self-healing behavior but made real keyring problems (permissions, corruption) invisible. Add log.Printf for errors that aren't keyring.ErrKeyNotFound so they surface in debug mode without changing the fallthrough behavior. Co-Authored-By: Claude Opus 4.6 (1M context) --- vault/cachedsessionprovider.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index 9036bc45..a8c3dc4a 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -2,6 +2,7 @@ package vault import ( "context" + "errors" "fmt" "log" "os" @@ -9,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/byteness/keyring" ) type StsSessionProvider interface { @@ -78,6 +80,9 @@ func NewCachedSessionProvider(key SessionMetadata, provider StsSessionProvider, func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { creds, cached, err := p.getCachedSession() + if err != nil && !errors.Is(err, keyring.ErrKeyNotFound) { + log.Printf("Reading cached session for %s: %v; will refresh", p.SessionKey.ProfileName, err) + } if err == nil && cached { return creds, nil } @@ -120,6 +125,9 @@ func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststyp }, }, "session", func() (processLockResult[*ststypes.Credentials], error) { creds, cached, err := p.getCachedSession() + if err != nil && !errors.Is(err, keyring.ErrKeyNotFound) { + log.Printf("Reading cached session for %s: %v; will try lock", p.SessionKey.ProfileName, err) + } if err == nil && cached { return processLockResult[*ststypes.Credentials]{value: creds, ok: true}, nil } From 952e2d91a589cd8ee15ec7b86a127eefdd1cd196 Mon Sep 17 00:00:00 2001 From: Tim Visher <194828183+timvisher-dd@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:54:34 -0400 Subject: [PATCH 30/30] fix(test): use rawKeyringImpl in CLI example tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Keyring() refactor split keyringImpl (parallel-safe wrapper) from rawKeyringImpl (underlying keyring). Example tests that inject a mock keyring were still setting keyringImpl, but Keyring() now calls rawKeyring() first — which tried keyring.Open and failed on CI where secret-service is unavailable. Set rawKeyringImpl instead. Co-Authored-By: Claude Opus 4.6 (1M context) --- cli/exec_test.go | 2 +- cli/export_test.go | 2 +- cli/list_test.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cli/exec_test.go b/cli/exec_test.go index fc22c442..9d87ee4f 100644 --- a/cli/exec_test.go +++ b/cli/exec_test.go @@ -9,7 +9,7 @@ import ( func ExampleExecCommand() { app := kingpin.New("aws-vault", "") awsVault := ConfigureGlobals(app) - awsVault.keyringImpl = keyring.NewArrayKeyring([]keyring.Item{ + awsVault.rawKeyringImpl = keyring.NewArrayKeyring([]keyring.Item{ {Key: "llamas", Data: []byte(`{"AccessKeyID":"ABC","SecretAccessKey":"XYZ"}`)}, }) ConfigureExecCommand(app, awsVault) diff --git a/cli/export_test.go b/cli/export_test.go index 8f02e398..94a67de4 100644 --- a/cli/export_test.go +++ b/cli/export_test.go @@ -9,7 +9,7 @@ import ( func ExampleExportCommand() { app := kingpin.New("aws-vault", "") awsVault := ConfigureGlobals(app) - awsVault.keyringImpl = keyring.NewArrayKeyring([]keyring.Item{ + awsVault.rawKeyringImpl = keyring.NewArrayKeyring([]keyring.Item{ {Key: "llamas", Data: []byte(`{"AccessKeyID":"ABC","SecretAccessKey":"XYZ"}`)}, }) ConfigureExportCommand(app, awsVault) diff --git a/cli/list_test.go b/cli/list_test.go index b1055e74..18e0bfb6 100644 --- a/cli/list_test.go +++ b/cli/list_test.go @@ -9,7 +9,7 @@ import ( func ExampleListCommand() { app := kingpin.New("aws-vault", "") awsVault := ConfigureGlobals(app) - awsVault.keyringImpl = keyring.NewArrayKeyring([]keyring.Item{ + awsVault.rawKeyringImpl = keyring.NewArrayKeyring([]keyring.Item{ {Key: "llamas", Data: []byte(`{"AccessKeyID":"ABC","SecretAccessKey":"XYZ"}`)}, }) ConfigureListCommand(app, awsVault)