diff --git a/USAGE.md b/USAGE.md index 00e30ffa..0d7e1dc2 100644 --- a/USAGE.md +++ b/USAGE.md @@ -273,6 +273,7 @@ WARNING: Use of this option runs against security best practices. It is recommen To configure the default flag values of `aws-vault` and its subcommands: * `AWS_VAULT_BACKEND`: Secret backend to use (see the flag `--backend`) * `AWS_VAULT_BIOMETRICS`: Use biometric authentication using TouchID, if supported (see the flag `--biometrics`) +* `AWS_VAULT_PARALLEL_SAFE`: Enable cross-process locking for keychain and cached credentials (see the flag `--parallel-safe`) * `AWS_VAULT_KEYCHAIN_NAME`: Name of macOS keychain to use (see the flag `--keychain`) * `AWS_VAULT_AUTO_LOGOUT`: Enable auto-logout when doing `login` (see the flag `--auto-logout`) * `AWS_VAULT_PROMPT`: Prompt driver to use (see the flag `--prompt`) @@ -634,6 +635,35 @@ role_arn=arn:aws:iam::123456789013:role/AnotherRole source_profile=Administrator-123456789012] ``` +## Parallel-safe mode + +When running many `aws-vault` processes in parallel (e.g. Terraform with hundreds of `credential_process` invocations), concurrent access to the secret store and SSO browser flows can cause errors: + +- **Browser storms**: Multiple processes each open a browser tab for the same SSO login, overwhelming AWS and triggering HTTP 500 errors. +- **Secret store races**: Concurrent writes to the same keyring entry cause "item already exists" errors or partial reads. + +The `--parallel-safe` flag (or `AWS_VAULT_PARALLEL_SAFE=true`) enables cross-process locking to prevent these issues: + +- **SSO token lock**: Only one process per SSO Start URL opens a browser tab; others wait for the cached token. +- **Session cache lock**: Only one process writes back to a given session cache entry at a time. +- **Keyring lock**: All keyring read/write operations are serialized across processes. + +This applies to **all backends** (keychain, file, pass, secret-service, etc.). + +### Trade-offs + +- Keyring operations are serialized, which adds a small amount of latency per operation. In practice this is negligible because the operations themselves are fast. +- **All concurrent invocations must use `--parallel-safe`**. If some processes enable it and others don't, the unprotected processes ignore the locks entirely. This is undefined behavior and may still cause races. Set `AWS_VAULT_PARALLEL_SAFE=true` in your environment to ensure consistent use. + +### The `login` command + +The `login` command is intentionally excluded from `--parallel-safe`. Console login sessions are inherently single-use — you cannot meaningfully log in to multiple AWS consoles in parallel — so there is no concurrent-access problem for `--parallel-safe` to solve. The `exec`, `export`, and `rotate` commands all support `--parallel-safe`. + +### Limitations + +- The keyring lock wait loop cannot be cancelled by the caller because the `keyring.Keyring` interface is not context-aware. If a lock holder hangs (e.g. a stuck `gpg` subprocess in the `pass` backend), waiters will time out after 2 minutes rather than waiting indefinitely. +- SSO rate-limit retries (HTTP 429 on `GetRoleCredentials`) will retry for up to 5 minutes before giving up with an error. + ## Assuming roles with web identities AWS supports assuming roles using [web identity federation and OpenID Connect](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-role.html#cli-configure-role-oidc), including login using Amazon, Google, Facebook or any other OpenID Connect server. The configuration options are as follows: diff --git a/cli/exec.go b/cli/exec.go index a2c21f46..7d9f3895 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -34,6 +34,7 @@ type ExecCommandInput struct { NoSession bool UseStdout bool ShowHelpMessages bool + ParallelSafe bool } func (input ExecCommandInput) validate() error { @@ -121,6 +122,7 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) { StringsVar(&input.Args) cmd.Action(func(c *kingpin.ParseContext) (err error) { + input.ParallelSafe = a.ParallelSafe input.Config.MfaPromptMethod = a.PromptDriver(hasBackgroundServer(input)) input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration input.Config.AssumeRoleDuration = input.SessionDuration @@ -155,6 +157,7 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) { Config: input.Config, SessionDuration: input.SessionDuration, NoSession: input.NoSession, + ParallelSafe: input.ParallelSafe, } err = ExportCommand(exportCommandInput, f, keyring) @@ -185,7 +188,13 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke return 0, fmt.Errorf("Error loading config: %w", err) } - credsProvider, err := vault.NewTempCredentialsProvider(config, &vault.CredentialKeyring{Keyring: keyring}, input.NoSession, false) + credsProvider, err := vault.NewTempCredentialsProviderWithOptions( + config, + &vault.CredentialKeyring{Keyring: keyring}, + input.NoSession, + false, + vault.TempCredentialsOptions{ParallelSafe: input.ParallelSafe}, + ) if err != nil { return 0, fmt.Errorf("Error getting temporary credentials: %w", err) } diff --git a/cli/exec_test.go b/cli/exec_test.go index fc22c442..9d87ee4f 100644 --- a/cli/exec_test.go +++ b/cli/exec_test.go @@ -9,7 +9,7 @@ import ( func ExampleExecCommand() { app := kingpin.New("aws-vault", "") awsVault := ConfigureGlobals(app) - awsVault.keyringImpl = keyring.NewArrayKeyring([]keyring.Item{ + awsVault.rawKeyringImpl = keyring.NewArrayKeyring([]keyring.Item{ {Key: "llamas", Data: []byte(`{"AccessKeyID":"ABC","SecretAccessKey":"XYZ"}`)}, }) ConfigureExecCommand(app, awsVault) diff --git a/cli/export.go b/cli/export.go index d5723eb0..52ee6032 100644 --- a/cli/export.go +++ b/cli/export.go @@ -23,6 +23,7 @@ type ExportCommandInput struct { SessionDuration time.Duration NoSession bool UseStdout bool + ParallelSafe bool } var ( @@ -66,6 +67,7 @@ func ConfigureExportCommand(app *kingpin.Application, a *AwsVault) { StringVar(&input.ProfileName) cmd.Action(func(c *kingpin.ParseContext) (err error) { + input.ParallelSafe = a.ParallelSafe input.Config.MfaPromptMethod = a.PromptDriver(false) input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration input.Config.AssumeRoleDuration = input.SessionDuration @@ -108,7 +110,13 @@ func ExportCommand(input ExportCommandInput, f *vault.ConfigFile, keyring keyrin } ckr := &vault.CredentialKeyring{Keyring: keyring} - credsProvider, err := vault.NewTempCredentialsProvider(config, ckr, input.NoSession, false) + credsProvider, err := vault.NewTempCredentialsProviderWithOptions( + config, + ckr, + input.NoSession, + false, + vault.TempCredentialsOptions{ParallelSafe: input.ParallelSafe}, + ) if err != nil { return fmt.Errorf("Error getting temporary credentials: %w", err) } diff --git a/cli/export_test.go b/cli/export_test.go index 8f02e398..94a67de4 100644 --- a/cli/export_test.go +++ b/cli/export_test.go @@ -9,7 +9,7 @@ import ( func ExampleExportCommand() { app := kingpin.New("aws-vault", "") awsVault := ConfigureGlobals(app) - awsVault.keyringImpl = keyring.NewArrayKeyring([]keyring.Item{ + awsVault.rawKeyringImpl = keyring.NewArrayKeyring([]keyring.Item{ {Key: "llamas", Data: []byte(`{"AccessKeyID":"ABC","SecretAccessKey":"XYZ"}`)}, }) ConfigureExportCommand(app, awsVault) diff --git a/cli/global.go b/cli/global.go index 0a2f618c..eac08fa5 100644 --- a/cli/global.go +++ b/cli/global.go @@ -37,9 +37,11 @@ type AwsVault struct { KeyringConfig keyring.Config KeyringBackend string promptDriver string + ParallelSafe bool - keyringImpl keyring.Keyring - awsConfigFile *vault.ConfigFile + rawKeyringImpl keyring.Keyring + keyringImpl keyring.Keyring + awsConfigFile *vault.ConfigFile UseBiometrics bool } @@ -68,18 +70,85 @@ func (a *AwsVault) PromptDriver(avoidTerminalPrompt bool) string { } func (a *AwsVault) Keyring() (keyring.Keyring, error) { - if a.keyringImpl == nil { + raw, err := a.rawKeyring() + if err != nil { + return nil, err + } + if a.ParallelSafe { + if a.keyringImpl == nil { + lockKey := a.keyringLockKey() + a.keyringImpl = vault.NewLockedKeyring(raw, lockKey) + } + return a.keyringImpl, nil + } + return raw, nil +} + +// RawKeyring returns the keyring without the parallel-safe lock wrapper. +// Used by commands like login that are excluded from --parallel-safe. +func (a *AwsVault) RawKeyring() (keyring.Keyring, error) { + return a.rawKeyring() +} + +func (a *AwsVault) rawKeyring() (keyring.Keyring, error) { + if a.rawKeyringImpl == nil { if a.KeyringBackend != "" { a.KeyringConfig.AllowedBackends = []keyring.BackendType{keyring.BackendType(a.KeyringBackend)} } var err error - a.keyringImpl, err = keyring.Open(a.KeyringConfig) + a.rawKeyringImpl, err = keyring.Open(a.KeyringConfig) if err != nil { return nil, err } } + return a.rawKeyringImpl, nil +} - return a.keyringImpl, nil +// keyringLockKey returns a backend-specific key for the cross-process keyring +// lock. Different backends (and different configurations of the same backend) +// produce different keys so they don't contend on the same lock file. +func (a *AwsVault) keyringLockKey() string { + backend := a.KeyringBackend + switch keyring.BackendType(backend) { + case keyring.KeychainBackend: + if a.KeyringConfig.KeychainName != "" { + return backend + ":" + a.KeyringConfig.KeychainName + } + case keyring.FileBackend: + if a.KeyringConfig.FileDir != "" { + return backend + ":" + a.KeyringConfig.FileDir + } + case keyring.PassBackend: + key := backend + if a.KeyringConfig.PassDir != "" { + key += ":" + a.KeyringConfig.PassDir + } + if a.KeyringConfig.PassPrefix != "" { + key += ":" + a.KeyringConfig.PassPrefix + } + return key + case keyring.SecretServiceBackend: + if a.KeyringConfig.LibSecretCollectionName != "" { + return backend + ":" + a.KeyringConfig.LibSecretCollectionName + } + case keyring.KWalletBackend: + if a.KeyringConfig.KWalletFolder != "" { + return backend + ":" + a.KeyringConfig.KWalletFolder + } + case keyring.WinCredBackend: + if a.KeyringConfig.WinCredPrefix != "" { + return backend + ":" + a.KeyringConfig.WinCredPrefix + } + case keyring.OPBackend, keyring.OPConnectBackend, keyring.OPDesktopBackend: + if a.KeyringConfig.OPVaultID != "" { + return backend + ":" + a.KeyringConfig.OPVaultID + } + } + // Fall back to backend name, which is always set (defaults to first available). + if backend != "" { + return backend + } + return "aws-vault" } func (a *AwsVault) AwsConfigFile() (*vault.ConfigFile, error) { @@ -201,6 +270,10 @@ func ConfigureGlobals(app *kingpin.Application) *AwsVault { Envar("AWS_VAULT_BIOMETRICS"). BoolVar(&a.UseBiometrics) + app.Flag("parallel-safe", "Enable cross-process locking for keyring operations, session caching, and SSO browser flows"). + Envar("AWS_VAULT_PARALLEL_SAFE"). + BoolVar(&a.ParallelSafe) + app.PreAction(func(c *kingpin.ParseContext) error { if !a.Debug { log.SetOutput(io.Discard) diff --git a/cli/global_test.go b/cli/global_test.go new file mode 100644 index 00000000..19afbf5a --- /dev/null +++ b/cli/global_test.go @@ -0,0 +1,177 @@ +package cli + +import ( + "testing" + + "github.com/byteness/keyring" +) + +func TestKeyringLockKey(t *testing.T) { + tests := []struct { + name string + backend string + config keyring.Config + want string + }{ + // Keychain backend + { + name: "keychain with keychain name", + backend: "keychain", + config: keyring.Config{KeychainName: "my-keychain"}, + want: "keychain:my-keychain", + }, + { + name: "keychain with empty keychain name", + backend: "keychain", + config: keyring.Config{}, + want: "keychain", + }, + + // File backend + { + name: "file with file dir", + backend: "file", + config: keyring.Config{FileDir: "/tmp/keys"}, + want: "file:/tmp/keys", + }, + { + name: "file with empty file dir", + backend: "file", + config: keyring.Config{}, + want: "file", + }, + + // Pass backend: dir and prefix combinations + { + name: "pass with dir and prefix", + backend: "pass", + config: keyring.Config{PassDir: "/store", PassPrefix: "aws"}, + want: "pass:/store:aws", + }, + { + name: "pass with dir only", + backend: "pass", + config: keyring.Config{PassDir: "/store"}, + want: "pass:/store", + }, + { + name: "pass with prefix only", + backend: "pass", + config: keyring.Config{PassPrefix: "aws"}, + want: "pass:aws", + }, + { + name: "pass with neither dir nor prefix", + backend: "pass", + config: keyring.Config{}, + want: "pass", + }, + + // Secret-service backend + { + name: "secret-service with collection name", + backend: "secret-service", + config: keyring.Config{LibSecretCollectionName: "awsvault"}, + want: "secret-service:awsvault", + }, + { + name: "secret-service with empty collection name", + backend: "secret-service", + config: keyring.Config{}, + want: "secret-service", + }, + + // KWallet backend + { + name: "kwallet with folder", + backend: "kwallet", + config: keyring.Config{KWalletFolder: "aws-vault"}, + want: "kwallet:aws-vault", + }, + { + name: "kwallet with empty folder", + backend: "kwallet", + config: keyring.Config{}, + want: "kwallet", + }, + + // WinCred backend + { + name: "wincred with prefix", + backend: "wincred", + config: keyring.Config{WinCredPrefix: "aws-vault"}, + want: "wincred:aws-vault", + }, + { + name: "wincred with empty prefix", + backend: "wincred", + config: keyring.Config{}, + want: "wincred", + }, + + // 1Password backends (all share OPVaultID) + { + name: "op with vault ID", + backend: "op", + config: keyring.Config{OPVaultID: "vault-123"}, + want: "op:vault-123", + }, + { + name: "op with empty vault ID", + backend: "op", + config: keyring.Config{}, + want: "op", + }, + { + name: "op-connect with vault ID", + backend: "op-connect", + config: keyring.Config{OPVaultID: "vault-456"}, + want: "op-connect:vault-456", + }, + { + name: "op-connect with empty vault ID", + backend: "op-connect", + config: keyring.Config{}, + want: "op-connect", + }, + { + name: "op-desktop with vault ID", + backend: "op-desktop", + config: keyring.Config{OPVaultID: "vault-789"}, + want: "op-desktop:vault-789", + }, + { + name: "op-desktop with empty vault ID", + backend: "op-desktop", + config: keyring.Config{}, + want: "op-desktop", + }, + + // Fallback cases + { + name: "unknown backend falls back to backend name", + backend: "some-unknown-backend", + config: keyring.Config{}, + want: "some-unknown-backend", + }, + { + name: "empty backend falls back to aws-vault", + backend: "", + config: keyring.Config{}, + want: "aws-vault", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &AwsVault{ + KeyringBackend: tt.backend, + KeyringConfig: tt.config, + } + got := a.keyringLockKey() + if got != tt.want { + t.Errorf("keyringLockKey() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/cli/list_test.go b/cli/list_test.go index b1055e74..18e0bfb6 100644 --- a/cli/list_test.go +++ b/cli/list_test.go @@ -9,7 +9,7 @@ import ( func ExampleListCommand() { app := kingpin.New("aws-vault", "") awsVault := ConfigureGlobals(app) - awsVault.keyringImpl = keyring.NewArrayKeyring([]keyring.Item{ + awsVault.rawKeyringImpl = keyring.NewArrayKeyring([]keyring.Item{ {Key: "llamas", Data: []byte(`{"AccessKeyID":"ABC","SecretAccessKey":"XYZ"}`)}, }) ConfigureListCommand(app, awsVault) diff --git a/cli/login.go b/cli/login.go index 08a4ef3e..5bb0ec13 100644 --- a/cli/login.go +++ b/cli/login.go @@ -35,6 +35,10 @@ type LoginCommandInput struct { func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) { input := LoginCommandInput{} + // NOTE: login intentionally does not use --parallel-safe. The login command + // opens a browser for an interactive console session — you cannot meaningfully + // log in to multiple AWS consoles in parallel, so there is no concurrent-access + // problem for --parallel-safe to solve here. cmd := app.Command("login", "Generate a login link for the AWS Console.") cmd.Flag("duration", "Duration of the assume-role or federated session. Defaults to 1h"). @@ -75,7 +79,10 @@ func ConfigureLoginCommand(app *kingpin.Application, a *AwsVault) { input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration input.Config.AssumeRoleDuration = input.SessionDuration input.Config.GetFederationTokenDuration = input.SessionDuration - keyring, err := a.Keyring() + // Login uses the raw keyring without the parallel-safe lock wrapper. + // Console login is inherently single-use — there's no concurrent-access + // problem for --parallel-safe to solve here. + keyring, err := a.RawKeyring() if err != nil { return err } diff --git a/cli/rotate.go b/cli/rotate.go index 97a4462e..751d897d 100644 --- a/cli/rotate.go +++ b/cli/rotate.go @@ -14,9 +14,10 @@ import ( ) type RotateCommandInput struct { - NoSession bool - ProfileName string - Config vault.ProfileConfig + NoSession bool + ProfileName string + Config vault.ProfileConfig + ParallelSafe bool } func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) { @@ -34,6 +35,7 @@ func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) { StringVar(&input.ProfileName) cmd.Action(func(c *kingpin.ParseContext) (err error) { + input.ParallelSafe = a.ParallelSafe input.Config.MfaPromptMethod = a.PromptDriver(false) f, err := a.AwsConfigFile() @@ -97,7 +99,13 @@ func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyrin credsProvider = vault.NewMasterCredentialsProvider(ckr, config.ProfileName) } else { // Can't always disable sessions completely, might need to use session for MFA-Protected API Access - credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, input.NoSession, true) + credsProvider, err = vault.NewTempCredentialsProviderWithOptions( + config, + ckr, + input.NoSession, + true, + vault.TempCredentialsOptions{ParallelSafe: input.ParallelSafe}, + ) if err != nil { return fmt.Errorf("Error getting temporary credentials: %w", err) } diff --git a/go.mod b/go.mod index 717c6967..6c95a0e7 100644 --- a/go.mod +++ b/go.mod @@ -12,9 +12,11 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 + github.com/aws/smithy-go v1.24.2 github.com/byteness/keyring v1.9.0 github.com/charmbracelet/huh v1.0.0 github.com/charmbracelet/lipgloss v1.1.0 + github.com/gofrs/flock v0.13.0 github.com/google/go-cmp v0.7.0 github.com/mattn/go-isatty v0.0.21 github.com/mattn/go-tty v0.0.7 @@ -35,7 +37,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect - github.com/aws/smithy-go v1.24.2 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/byteness/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/byteness/go-libsecret v0.0.0-20260108215642-107379d3dee0 // indirect @@ -74,6 +75,7 @@ require ( github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/tetratelabs/wabin v0.0.0-20230304001439-f6f874872834 // indirect github.com/tetratelabs/wazero v1.11.0 // indirect github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect diff --git a/go.sum b/go.sum index e5e46070..f1381686 100644 --- a/go.sum +++ b/go.sum @@ -116,6 +116,8 @@ github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= +github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= +github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -126,6 +128,8 @@ github.com/ianlancetaylor/demangle v0.0.0-20251118225945-96ee0021ea0f h1:Fnl4pzx github.com/ianlancetaylor/demangle v0.0.0-20251118225945-96ee0021ea0f/go.mod h1:gx7rwoVhcfuVKG5uya9Hs3Sxj7EIvldVofAWIUtGouw= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= @@ -154,7 +158,6 @@ github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELU github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/noamcohen97/touchid-go v0.3.0 h1:fcXxVCizysD7KHRR6hrURt3nyNIs5JBGSbOIidD/3wo= github.com/noamcohen97/touchid-go v0.3.0/go.mod h1:X9MRNIBGEmPqwpDm1G3fQOAQX7fwBlhzUbnkDTxuta0= @@ -166,6 +169,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -236,8 +241,9 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLFVxaq6wH4YuVdsUOr75U= gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/ini.v1 v1.67.1 h1:tVBILHy0R6e4wkYOn3XmiITt/hEVH4TFMYvAX2Ytz6k= gopkg.in/ini.v1 v1.67.1/go.mod h1:x/cyOwCgZqOkJoDIJ3c1KNHMo10+nLGAhh+kn3Zizss= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/vault/cachedsessionprovider.go b/vault/cachedsessionprovider.go index 1a382d6b..a8c3dc4a 100644 --- a/vault/cachedsessionprovider.go +++ b/vault/cachedsessionprovider.go @@ -2,11 +2,15 @@ package vault import ( "context" + "errors" + "fmt" "log" + "os" "time" "github.com/aws/aws-sdk-go-v2/aws" ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/byteness/keyring" ) type StsSessionProvider interface { @@ -21,23 +25,147 @@ type CachedSessionProvider struct { SessionProvider StsSessionProvider Keyring *SessionKeyring ExpiryWindow time.Duration + UseSessionLock bool + sessionLock ProcessLock + sessionLockWait time.Duration + sessionLockLog time.Duration + sessionLockTimeout time.Duration + sessionNow func() time.Time + sessionSleep func(context.Context, time.Duration) error + sessionLogf lockLogger +} + +const ( + // defaultSessionLockWaitDelay is the polling interval between lock attempts. + // 100ms keeps latency low for the typical case where the lock holder + // finishes quickly (STS call + cache write). + defaultSessionLockWaitDelay = 100 * time.Millisecond + + // defaultSessionLockLogEvery controls how often we emit a debug log while + // waiting for the lock. 15s avoids log spam while still showing progress. + defaultSessionLockLogEvery = 15 * time.Second + + // defaultSessionLockWarnAfter is the delay before printing a user-visible + // "waiting for lock" message to stderr. 5s is long enough to avoid + // flashing the message on normal lock contention, short enough to + // reassure the user that the process isn't hung. + defaultSessionLockWarnAfter = 5 * time.Second + + // defaultSessionLockTimeout is a safety net: if the lock holder is hung, + // waiters give up after this duration rather than blocking indefinitely. + // 2 minutes matches the keyring lock timeout. + defaultSessionLockTimeout = 2 * time.Minute +) + + +// NewCachedSessionProvider creates a CachedSessionProvider with production +// defaults for all internal dependencies. Tests can override unexported fields +// (sessionLock, sessionNow, etc.) after construction to inject mocks. +func NewCachedSessionProvider(key SessionMetadata, provider StsSessionProvider, keyring *SessionKeyring, expiryWindow time.Duration, useSessionLock bool) *CachedSessionProvider { + return &CachedSessionProvider{ + SessionKey: key, + SessionProvider: provider, + Keyring: keyring, + ExpiryWindow: expiryWindow, + UseSessionLock: useSessionLock, + sessionLock: NewDefaultLock("aws-vault.session", key.StringForMatching()), + sessionLockWait: defaultSessionLockWaitDelay, + sessionLockLog: defaultSessionLockLogEvery, + sessionLockTimeout: defaultSessionLockTimeout, + sessionNow: time.Now, + sessionSleep: defaultContextSleep, + sessionLogf: log.Printf, + } } func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { - creds, err := p.Keyring.Get(p.SessionKey) + creds, cached, err := p.getCachedSession() + if err != nil && !errors.Is(err, keyring.ErrKeyNotFound) { + log.Printf("Reading cached session for %s: %v; will refresh", p.SessionKey.ProfileName, err) + } + if err == nil && cached { + return creds, nil + } + + if !p.UseSessionLock { + return p.getSessionWithoutLock(ctx) + } + + return p.getSessionWithLock(ctx) +} + +func (p *CachedSessionProvider) getCachedSession() (creds *ststypes.Credentials, cached bool, err error) { + creds, err = p.Keyring.Get(p.SessionKey) + if err != nil { + return nil, false, err + } + if time.Until(*creds.Expiration) < p.ExpiryWindow { + return nil, false, nil + } + log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String()) + return creds, true, nil +} + +func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststypes.Credentials, error) { + waitCtx, cancel := context.WithTimeout(ctx, p.sessionLockTimeout) + defer cancel() + + return withProcessLock(waitCtx, p.sessionLock, lockWaiterOpts{ + LockPath: p.sessionLock.Path(), + WarnMsg: "Waiting for session lock at %s\n", + LogMsg: "Waiting for session lock at %s", + WaitDelay: p.sessionLockWait, + LogEvery: p.sessionLockLog, + WarnAfter: defaultSessionLockWarnAfter, + Now: p.sessionNow, + Sleep: p.sessionSleep, + Logf: p.sessionLogf, + Warnf: func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) + }, + }, "session", func() (processLockResult[*ststypes.Credentials], error) { + creds, cached, err := p.getCachedSession() + if err != nil && !errors.Is(err, keyring.ErrKeyNotFound) { + log.Printf("Reading cached session for %s: %v; will try lock", p.SessionKey.ProfileName, err) + } + if err == nil && cached { + return processLockResult[*ststypes.Credentials]{value: creds, ok: true}, nil + } + return processLockResult[*ststypes.Credentials]{}, nil + }, func(_ context.Context) (*ststypes.Credentials, error) { + // Use the caller's ctx (not the lock-wait timeout ctx) for actual + // work. The lock-wait timeout bounds how long we wait for the lock, + // not how long the work takes once we hold it. + + // Recheck cache after acquiring lock — another process may have filled it. + creds, cached, cacheErr := p.getCachedSession() + if cacheErr == nil && cached { + return creds, nil + } - if err != nil || time.Until(*creds.Expiration) < p.ExpiryWindow { - // lookup missed, we need to create a new one. - creds, err = p.SessionProvider.RetrieveStsCredentials(ctx) + creds, err := p.SessionProvider.RetrieveStsCredentials(ctx) if err != nil { return nil, err } - err = p.Keyring.Set(p.SessionKey, creds) - if err != nil { + if err = p.Keyring.Set(p.SessionKey, creds); err != nil { return nil, err } - } else { - log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String()) + return creds, nil + }) +} + +func (p *CachedSessionProvider) getSessionWithoutLock(ctx context.Context) (*ststypes.Credentials, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + creds, err := p.SessionProvider.RetrieveStsCredentials(ctx) + if err != nil { + return nil, err + } + + if err = p.Keyring.Set(p.SessionKey, creds); err != nil { + return nil, err } return creds, nil diff --git a/vault/cachedsessionprovider_lock_test.go b/vault/cachedsessionprovider_lock_test.go new file mode 100644 index 00000000..919fc447 --- /dev/null +++ b/vault/cachedsessionprovider_lock_test.go @@ -0,0 +1,303 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/byteness/keyring" +) + +type testSessionProvider struct { + creds *types.Credentials + calls int + onRetrieve func() +} + +func (p *testSessionProvider) RetrieveStsCredentials(context.Context) (*types.Credentials, error) { + p.calls++ + if p.onRetrieve != nil { + p.onRetrieve() + } + return p.creds, nil +} + +func (p *testSessionProvider) Retrieve(context.Context) (aws.Credentials, error) { + return aws.Credentials{}, nil +} + +type lockCheckingKeyring struct { + keyring.Keyring + setCalls int + setLock *testLock +} + +func (k *lockCheckingKeyring) Set(item keyring.Item) error { + k.setCalls++ + if k.setLock != nil && !k.setLock.locked { + return fmt.Errorf("lock not held during cache set") + } + return k.Keyring.Set(item) +} + +func newTestSessionKey() SessionMetadata { + return SessionMetadata{ + Type: "sso.GetRoleCredentials", + ProfileName: "test-profile", + MfaSerial: "https://sso.example", + } +} + +func newTestCreds(expires time.Time) *types.Credentials { + return &types.Credentials{ + AccessKeyId: aws.String("AKIATEST"), + SecretAccessKey: aws.String("secret"), + SessionToken: aws.String("token"), + Expiration: aws.Time(expires), + } +} + +func TestCachedSession_CacheHit_NoLock(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + if err := sk.Set(key, creds); err != nil { + t.Fatalf("set cache: %v", err) + } + + lock := &testLock{} + provider := &testSessionProvider{ + onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called on cache hit") }, + } + + p := NewCachedSessionProvider(key, provider, sk, 0, true) + p.sessionLock = lock + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } + if provider.calls != 0 { + t.Fatalf("expected no provider calls, got %d", provider.calls) + } +} + +func TestCachedSession_LockDisabled_SkipsLock(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + lock := &testLock{tryResults: []bool{true}} + provider := &testSessionProvider{creds: creds} + + p := NewCachedSessionProvider(key, provider, sk, 0, false) + p.sessionLock = lock + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } + if provider.calls != 1 { + t.Fatalf("expected 1 provider call, got %d", provider.calls) + } + if _, err := sk.Get(key); err != nil { + t.Fatalf("expected cached credentials, got %v", err) + } +} + +func TestCachedSession_LockMiss_ThenCacheHit_NoRefresh(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + lock := &testLock{tryResults: []bool{false}} + + provider := &testSessionProvider{ + onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called when cache fills while waiting") }, + } + + p := NewCachedSessionProvider(key, provider, sk, 0, true) + p.sessionLock = lock + p.sessionLockWait = 5 * time.Second + p.sessionSleep = func(ctx context.Context, d time.Duration) error { + return sk.Set(key, creds) + } + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.tryCalls != 1 { + t.Fatalf("expected 1 lock attempt, got %d", lock.tryCalls) + } + if provider.calls != 0 { + t.Fatalf("expected no provider calls, got %d", provider.calls) + } +} + +func TestCachedSession_LockAcquired_RecheckCache(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + lock := &testLock{tryResults: []bool{true}} + lock.onTry = func(l *testLock) { + if l.locked { + _ = sk.Set(key, creds) + } + } + + provider := &testSessionProvider{ + onRetrieve: func() { t.Fatal("RetrieveStsCredentials should not be called when cache fills after lock") }, + } + + p := NewCachedSessionProvider(key, provider, sk, 0, true) + p.sessionLock = lock + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } + if provider.calls != 0 { + t.Fatalf("expected no provider calls, got %d", provider.calls) + } +} + +func TestCachedSession_LockHeldThroughCacheSet(t *testing.T) { + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + lock := &testLock{tryResults: []bool{true}} + wrappedKeyring := &lockCheckingKeyring{ + Keyring: keyring.NewArrayKeyring(nil), + setLock: lock, + } + sk := &SessionKeyring{Keyring: wrappedKeyring} + provider := &testSessionProvider{creds: creds} + + p := NewCachedSessionProvider(key, provider, sk, 0, true) + p.sessionLock = lock + + _, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if wrappedKeyring.setCalls != 1 { + t.Fatalf("expected cache set once, got %d", wrappedKeyring.setCalls) + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } + if provider.calls != 1 { + t.Fatalf("expected 1 provider call, got %d", provider.calls) + } +} + +func TestCachedSession_LockWaitLogs(t *testing.T) { + lock := &testLock{tryResults: []bool{false, false, false, false}} + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + key := newTestSessionKey() + provider := &testSessionProvider{} + + ctx, cancel := context.WithCancel(context.Background()) + clock := &testClock{now: time.Unix(0, 0), cancel: cancel, cancelAfter: 4} + var logTimes []time.Time + + p := NewCachedSessionProvider(key, provider, sk, 0, true) + p.sessionLock = lock + p.sessionLockWait = 5 * time.Second + p.sessionLockLog = 15 * time.Second + p.sessionNow = clock.Now + p.sessionSleep = clock.Sleep + p.sessionLogf = func(string, ...any) { + logTimes = append(logTimes, clock.Now()) + } + + _, err := p.RetrieveStsCredentials(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation, got %v", err) + } + if len(logTimes) != 2 { + t.Fatalf("expected 2 log entries, got %d", len(logTimes)) + } + if !logTimes[0].Equal(time.Unix(0, 0)) { + t.Fatalf("unexpected first log time: %s", logTimes[0]) + } + if !logTimes[1].Equal(time.Unix(15, 0)) { + t.Fatalf("unexpected second log time: %s", logTimes[1]) + } +} + +// ctxAwareSessionProvider is a test StsSessionProvider that respects context +// cancellation, used to verify the work context is not the lock-wait timeout. +type ctxAwareSessionProvider struct { + creds *types.Credentials + delay time.Duration +} + +func (p *ctxAwareSessionProvider) RetrieveStsCredentials(ctx context.Context) (*types.Credentials, error) { + select { + case <-time.After(p.delay): + return p.creds, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (p *ctxAwareSessionProvider) Retrieve(context.Context) (aws.Credentials, error) { + return aws.Credentials{}, nil +} + +func TestCachedSession_WorkNotCancelledByLockTimeout(t *testing.T) { + // The lock-wait timeout should only bound how long we wait for the + // lock, not how long the work takes. Simulate a provider that takes + // longer than the lock-wait timeout and verify it completes. + key := newTestSessionKey() + creds := newTestCreds(time.Now().Add(time.Hour)) + kr := keyring.NewArrayKeyring(nil) + sk := &SessionKeyring{Keyring: kr} + lock := &testLock{tryResults: []bool{true}} + + provider := &ctxAwareSessionProvider{creds: creds, delay: 50 * time.Millisecond} + + p := NewCachedSessionProvider(key, provider, sk, 0, true) + p.sessionLock = lock + p.sessionLockTimeout = 10 * time.Millisecond + + got, err := p.RetrieveStsCredentials(context.Background()) + if err != nil { + t.Fatalf("unexpected error (work cancelled by lock-wait timeout?): %v", err) + } + if aws.ToString(got.AccessKeyId) != aws.ToString(creds.AccessKeyId) { + t.Fatalf("unexpected credentials returned") + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} diff --git a/vault/lock_test.go b/vault/lock_test.go new file mode 100644 index 00000000..8a5a0f27 --- /dev/null +++ b/vault/lock_test.go @@ -0,0 +1,66 @@ +package vault + +import ( + "context" + "time" +) + +type testLock struct { + tryResults []bool + tryCalls int + unlockCalls int + locked bool + path string + onTry func(*testLock) +} + +func (l *testLock) TryLock() (bool, error) { + l.tryCalls++ + locked := false + if l.tryCalls <= len(l.tryResults) { + locked = l.tryResults[l.tryCalls-1] + } + if locked { + l.locked = true + } + if l.onTry != nil { + l.onTry(l) + } + return locked, nil +} + +func (l *testLock) Unlock() error { + l.unlockCalls++ + l.locked = false + return nil +} + +func (l *testLock) Path() string { + if l.path != "" { + return l.path + } + return "/tmp/aws-vault.lock" +} + +type testClock struct { + now time.Time + sleepCalls int + cancelAfter int + cancel context.CancelFunc +} + +func (c *testClock) Now() time.Time { + return c.now +} + +func (c *testClock) Sleep(ctx context.Context, d time.Duration) error { + c.sleepCalls++ + c.now = c.now.Add(d) + if c.cancel != nil && 0 < c.cancelAfter && c.cancelAfter <= c.sleepCalls { + c.cancel() + } + if ctx.Err() != nil { + return ctx.Err() + } + return nil +} diff --git a/vault/lock_waiter.go b/vault/lock_waiter.go new file mode 100644 index 00000000..7a6488c9 --- /dev/null +++ b/vault/lock_waiter.go @@ -0,0 +1,60 @@ +package vault + +import ( + "context" + "time" +) + +type lockLogger func(string, ...any) + +// lockWaiterOpts configures a lockWaiter. All fields are required except +// Now, Sleep, and Warnf which have sensible defaults. +type lockWaiterOpts struct { + LockPath string + WarnMsg string + LogMsg string + WaitDelay time.Duration + LogEvery time.Duration + WarnAfter time.Duration + Now func() time.Time + Sleep func(context.Context, time.Duration) error + Logf lockLogger + Warnf lockLogger +} + +type lockWaiter struct { + opts lockWaiterOpts + + lastLog time.Time + waitStart time.Time + warned bool +} + +func newLockWaiter(opts lockWaiterOpts) *lockWaiter { + if opts.Now == nil { + opts.Now = time.Now + } + if opts.Sleep == nil { + opts.Sleep = defaultContextSleep + } + return &lockWaiter{opts: opts} +} + +func (w *lockWaiter) sleepAfterMiss(ctx context.Context) error { + now := w.opts.Now() + if w.waitStart.IsZero() { + w.waitStart = now + } + if !w.warned && w.opts.WarnAfter <= now.Sub(w.waitStart) { + if w.opts.Warnf != nil { + w.opts.Warnf(w.opts.WarnMsg, w.opts.LockPath) + } + w.warned = true + } + if w.opts.Logf != nil && (w.lastLog.IsZero() || w.opts.LogEvery <= now.Sub(w.lastLog)) { + w.opts.Logf(w.opts.LogMsg, w.opts.LockPath) + w.lastLog = now + } + + return w.opts.Sleep(ctx, w.opts.WaitDelay) +} diff --git a/vault/locked_keyring.go b/vault/locked_keyring.go new file mode 100644 index 00000000..51530de6 --- /dev/null +++ b/vault/locked_keyring.go @@ -0,0 +1,157 @@ +package vault + +import ( + "context" + "fmt" + "log" + "os" + "sync" + "time" + + "github.com/byteness/keyring" +) + +type lockedKeyring struct { + inner keyring.Keyring + lock ProcessLock + // mu serializes in-process access. The flock only coordinates across + // processes; without this mutex, concurrent goroutines in the same + // process could race on the try-lock loop. + // + // NOTE: mu.Lock() blocks without a timeout. The 2-minute timeout + // (lockTimeout) only applies to the flock wait loop inside withLock. + // If a keyring operation hangs while holding mu (e.g. a stuck gpg + // subprocess), other goroutines in the same process will block + // indefinitely. The keyring.Keyring interface is not context-aware, + // so there is no clean way to cancel in-flight operations. + mu sync.Mutex + + lockKey string + lockTimeout time.Duration + lockWait time.Duration + lockLog time.Duration + warnAfter time.Duration + lockNow func() time.Time + lockSleep func(context.Context, time.Duration) error + lockLogf lockLogger +} + +const ( + // defaultKeyringLockWaitDelay is the polling interval between lock attempts. + // 100ms keeps latency low for the typical case where the lock holder + // finishes a single keyring read/write quickly. + defaultKeyringLockWaitDelay = 100 * time.Millisecond + + // defaultKeyringLockLogEvery controls how often we emit a debug log while + // waiting for the lock. 15s avoids log spam while still showing progress. + defaultKeyringLockLogEvery = 15 * time.Second + + // defaultKeyringLockWarnAfter is the delay before printing a user-visible + // "waiting for lock" message to stderr. 5s is long enough to avoid + // flashing the message on normal lock contention, short enough to + // reassure the user that the process isn't hung. + defaultKeyringLockWarnAfter = 5 * time.Second + + // defaultKeyringLockTimeout is a safety net: the keyring.Keyring interface + // is not context-aware, so if the lock holder is hung (e.g. a stuck gpg + // subprocess in the pass backend), waiters give up after this duration + // rather than blocking indefinitely. 2 minutes is generous enough for any + // reasonable keyring operation. + defaultKeyringLockTimeout = 2 * time.Minute +) + +// NewLockedKeyring wraps the provided keyring with a cross-process lock +// to serialize keyring operations. +func NewLockedKeyring(kr keyring.Keyring, lockKey string) keyring.Keyring { + return &lockedKeyring{ + inner: kr, + lock: NewDefaultLock("aws-vault.keyring", lockKey), + lockKey: lockKey, + lockTimeout: defaultKeyringLockTimeout, + lockWait: defaultKeyringLockWaitDelay, + lockLog: defaultKeyringLockLogEvery, + warnAfter: defaultKeyringLockWarnAfter, + lockNow: time.Now, + lockSleep: defaultContextSleep, + lockLogf: log.Printf, + } +} + + +func (k *lockedKeyring) withLock(fn func() error) error { + k.mu.Lock() + defer k.mu.Unlock() + + // The keyring.Keyring interface is not context-aware, so we cannot cancel + // in-flight keyring operations. This timeout is a safety net for the lock-wait + // loop: if the lock holder is hung (e.g. a stuck gpg subprocess in the pass + // backend), waiters will eventually give up rather than blocking indefinitely. + ctx, cancel := context.WithTimeout(context.Background(), k.lockTimeout) + defer cancel() + + _, err := withProcessLock(ctx, k.lock, lockWaiterOpts{ + LockPath: k.lock.Path(), + WarnMsg: "Waiting for keyring lock at %s\n", + LogMsg: "Waiting for keyring lock at %s", + WaitDelay: k.lockWait, + LogEvery: k.lockLog, + WarnAfter: k.warnAfter, + Now: k.lockNow, + Sleep: k.lockSleep, + Logf: k.lockLogf, + Warnf: func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) + }, + }, "keyring", nil, func(_ context.Context) (struct{}, error) { + return struct{}{}, fn() + }) + return err +} + +func (k *lockedKeyring) Get(key string) (keyring.Item, error) { + var item keyring.Item + if err := k.withLock(func() error { + var err error + item, err = k.inner.Get(key) + return err + }); err != nil { + return keyring.Item{}, err + } + return item, nil +} + +func (k *lockedKeyring) GetMetadata(key string) (keyring.Metadata, error) { + var meta keyring.Metadata + if err := k.withLock(func() error { + var err error + meta, err = k.inner.GetMetadata(key) + return err + }); err != nil { + return keyring.Metadata{}, err + } + return meta, nil +} + +func (k *lockedKeyring) Set(item keyring.Item) error { + return k.withLock(func() error { + return k.inner.Set(item) + }) +} + +func (k *lockedKeyring) Remove(key string) error { + return k.withLock(func() error { + return k.inner.Remove(key) + }) +} + +func (k *lockedKeyring) Keys() ([]string, error) { + var keys []string + if err := k.withLock(func() error { + var err error + keys, err = k.inner.Keys() + return err + }); err != nil { + return nil, err + } + return keys, nil +} diff --git a/vault/locked_keyring_test.go b/vault/locked_keyring_test.go new file mode 100644 index 00000000..cfdb5b57 --- /dev/null +++ b/vault/locked_keyring_test.go @@ -0,0 +1,136 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/byteness/keyring" +) + +// testUnlockErrLock is a testLock variant whose Unlock returns a configured error. +type testUnlockErrLock struct { + testLock + unlockErr error +} + +func (l *testUnlockErrLock) Unlock() error { + l.unlockCalls++ + l.locked = false + return l.unlockErr +} + +func newTestLockedKeyring(inner keyring.Keyring, lock ProcessLock, clock *testClock) *lockedKeyring { + return &lockedKeyring{ + inner: inner, + lock: lock, + lockKey: "test", + lockTimeout: defaultKeyringLockTimeout, + lockWait: 100 * time.Millisecond, + lockLog: 15 * time.Second, + warnAfter: 5 * time.Second, + lockNow: clock.Now, + lockSleep: clock.Sleep, + lockLogf: func(string, ...any) {}, + } +} + +func TestLockedKeyring_LockWaitRetries(t *testing.T) { + // Lock fails twice, then succeeds on the third attempt. + lock := &testLock{tryResults: []bool{false, false, true}} + kr := keyring.NewArrayKeyring([]keyring.Item{ + {Key: "foo", Data: []byte("bar")}, + }) + clock := &testClock{now: time.Unix(0, 0)} + + lk := newTestLockedKeyring(kr, lock, clock) + + item, err := lk.Get("foo") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(item.Data) != "bar" { + t.Fatalf("unexpected data: %s", string(item.Data)) + } + if lock.tryCalls != 3 { + t.Fatalf("expected 3 lock attempts, got %d", lock.tryCalls) + } + if clock.sleepCalls != 2 { + t.Fatalf("expected 2 sleep calls, got %d", clock.sleepCalls) + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} + +func TestLockedKeyring_Timeout(t *testing.T) { + // Lock is never acquired. With a short lockTimeout the context should + // time out and withLock should return context.DeadlineExceeded. + lock := &testLock{} // tryResults is empty so TryLock always returns false + kr := keyring.NewArrayKeyring(nil) + clock := &testClock{now: time.Unix(0, 0)} + + lk := newTestLockedKeyring(kr, lock, clock) + // Use a very short real timeout so the test completes quickly. + lk.lockTimeout = 50 * time.Millisecond + // Use real sleep so the context deadline fires from wall-clock time. + lk.lockSleep = defaultContextSleep + lk.lockWait = 10 * time.Millisecond + + _, err := lk.Keys() + if err == nil { + t.Fatal("expected timeout error, got nil") + } + // The context should have been cancelled via timeout. The error will + // be context.DeadlineExceeded because withLock uses context.WithTimeout. + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context.DeadlineExceeded, got: %v", err) + } + // Verify that at least one lock attempt was made before timing out. + if lock.tryCalls < 1 { + t.Fatalf("expected at least 1 lock attempt, got %d", lock.tryCalls) + } +} + +func TestLockedKeyring_UnlockErrorJoined(t *testing.T) { + // Both the work function and Unlock return errors; they should be joined + // via errors.Join so that errors.Is can unwrap both. + workErr := fmt.Errorf("work failed") + unlockErr := fmt.Errorf("unlock broken") + + lock := &testUnlockErrLock{ + testLock: testLock{tryResults: []bool{true}}, + unlockErr: unlockErr, + } + + // Use a keyring whose Remove always fails with workErr. + inner := &failingKeyring{removeErr: workErr} + clock := &testClock{now: time.Unix(0, 0)} + + lk := newTestLockedKeyring(inner, lock, clock) + + err := lk.Remove("anything") + if err == nil { + t.Fatal("expected error, got nil") + } + if !errors.Is(err, workErr) { + t.Fatalf("expected joined error to contain work error, got: %v", err) + } + // The unlock error is wrapped as "unlock keyring lock: " + // using %w, so errors.Is can unwrap through the wrapping. + if !errors.Is(err, unlockErr) { + t.Fatalf("expected joined error to contain unlock error, got: %v", err) + } +} + +// failingKeyring is a keyring.Keyring that returns configured errors. +type failingKeyring struct { + keyring.Keyring + removeErr error +} + +func (k *failingKeyring) Remove(string) error { + return k.removeErr +} diff --git a/vault/process_lock.go b/vault/process_lock.go new file mode 100644 index 00000000..93526ec6 --- /dev/null +++ b/vault/process_lock.go @@ -0,0 +1,68 @@ +package vault + +import ( + "context" + "crypto/sha256" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/gofrs/flock" +) + +// ProcessLock coordinates work across processes. +type ProcessLock interface { + TryLock() (bool, error) + Unlock() error + Path() string +} + +type fileProcessLock struct { + lock *flock.Flock +} + +// NewFileLock creates a lock at the provided path. +func NewFileLock(path string) ProcessLock { + return &fileProcessLock{lock: flock.New(path)} +} + +func (l *fileProcessLock) TryLock() (bool, error) { + return l.lock.TryLock() +} + +func (l *fileProcessLock) Unlock() error { + return l.lock.Unlock() +} + +func (l *fileProcessLock) Path() string { + return l.lock.Path() +} + +func defaultLockPath(filename string) string { + return filepath.Join(os.TempDir(), filename) +} + +func hashedLockFilename(prefix, key string) string { + sum := sha256.Sum256([]byte(key)) + return fmt.Sprintf("%s.%x.lock", prefix, sum) +} + +// NewDefaultLock creates a ProcessLock in the system temp directory. +// The lock file name is derived from the prefix and a SHA-256 hash of key. +func NewDefaultLock(prefix, key string) ProcessLock { + return NewFileLock(defaultLockPath(hashedLockFilename(prefix, key))) +} + +// defaultContextSleep sleeps for d, respecting ctx cancellation. +// Shared by all lock-wait loops. +func defaultContextSleep(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/vault/process_lock_loop.go b/vault/process_lock_loop.go new file mode 100644 index 00000000..b3066049 --- /dev/null +++ b/vault/process_lock_loop.go @@ -0,0 +1,68 @@ +package vault + +import ( + "context" + "errors" + "fmt" +) + +// processLockResult is the result of a cache check or locked work function. +// ok indicates whether a cached result was found. +type processLockResult[T any] struct { + value T + ok bool +} + +// withProcessLock implements the try/sleep/recheck lock protocol. +// +// On each iteration it calls checkCache; if that returns ok=true, the cached +// value is returned without acquiring the lock. Otherwise it tries the lock: +// if acquired, it calls doWork under the lock (unlocking on return). If the +// lock is not acquired, it sleeps and retries. +// +// checkCache may be nil, in which case the cache check is skipped. +func withProcessLock[T any]( + ctx context.Context, + lock ProcessLock, + waiterOpts lockWaiterOpts, + lockName string, + checkCache func() (processLockResult[T], error), + doWork func(ctx context.Context) (T, error), +) (T, error) { + waiter := newLockWaiter(waiterOpts) + + for { + if checkCache != nil { + result, err := checkCache() + if err != nil { + var zero T + return zero, err + } + if result.ok { + return result.value, nil + } + } + if ctx.Err() != nil { + var zero T + return zero, ctx.Err() + } + + locked, err := lock.TryLock() + if err != nil { + var zero T + return zero, err + } + if locked { + result, workErr := doWork(ctx) + if unlockErr := lock.Unlock(); unlockErr != nil { + return result, errors.Join(workErr, fmt.Errorf("unlock %s lock: %w", lockName, unlockErr)) + } + return result, workErr + } + + if err = waiter.sleepAfterMiss(ctx); err != nil { + var zero T + return zero, err + } + } +} diff --git a/vault/ssorolecredentialsprovider.go b/vault/ssorolecredentialsprovider.go index 9854f4ee..f12ba217 100644 --- a/vault/ssorolecredentialsprovider.go +++ b/vault/ssorolecredentialsprovider.go @@ -5,11 +5,13 @@ import ( "errors" "fmt" "log" + "math/rand" "net/http" "os" + "strconv" + "strings" "time" - "github.com/byteness/keyring" "github.com/aws/aws-sdk-go-v2/aws" awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" "github.com/aws/aws-sdk-go-v2/service/sso" @@ -17,6 +19,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ssooidc" ssooidctypes "github.com/aws/aws-sdk-go-v2/service/ssooidc/types" ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/byteness/keyring" "github.com/skratchdot/open-golang/open" ) @@ -28,19 +31,97 @@ type OIDCTokenCacher interface { // SSORoleCredentialsProvider creates temporary credentials for an SSO Role. type SSORoleCredentialsProvider struct { - OIDCClient *ssooidc.Client - OIDCTokenCache OIDCTokenCacher - StartURL string - SSOClient *sso.Client - AccountID string - RoleName string - UseStdout bool + OIDCClient *ssooidc.Client + OIDCTokenCache OIDCTokenCacher + StartURL string + SSOClient *sso.Client + AccountID string + RoleName string + UseStdout bool + UseSSOTokenLock bool + ssoTokenLock ProcessLock + ssoLockWait time.Duration + ssoLockLog time.Duration + ssoLockTimeout time.Duration + ssoNow func() time.Time + ssoSleep func(context.Context, time.Duration) error + ssoLogf lockLogger + newOIDCTokenFn func(context.Context) (*ssooidc.CreateTokenOutput, error) } func millisecondsTimeValue(v int64) time.Time { return time.Unix(0, v*int64(time.Millisecond)) } +const ( + // defaultSSOLockWaitDelay is the polling interval between lock attempts. + // 100ms keeps latency low for the typical case where the lock holder + // finishes quickly (browser auth + token cache write). + defaultSSOLockWaitDelay = 100 * time.Millisecond + + // defaultSSOLockLogEvery controls how often we emit a debug log while + // waiting for the lock. 15s avoids log spam while still showing progress + // during long waits (e.g. slow browser auth). + defaultSSOLockLogEvery = 15 * time.Second + + // defaultSSOLockWarnAfter is the delay before printing a user-visible + // "waiting for lock" message to stderr. 5s is long enough to avoid + // flashing the message on normal lock contention, short enough to + // reassure the user that the process isn't hung. + defaultSSOLockWarnAfter = 5 * time.Second + + // defaultSSOLockTimeout is a safety net: if the lock holder is hung + // (e.g. a browser auth that was abandoned), waiters give up after this + // duration rather than blocking indefinitely. 2 minutes matches the + // keyring lock timeout. + defaultSSOLockTimeout = 2 * time.Minute + + // ssoRetryTimeout is a pathological safety net: if GetRoleCredentials is still + // returning 429s after this duration, give up and surface the error to the user. + // 5 minutes is generous but accommodates burst-heavy credential_process workloads + // (e.g. Terraform with hundreds of parallel invocations). + ssoRetryTimeout = 5 * time.Minute + + // ssoRetryBase is the initial backoff delay before the first retry. + // 200ms is short enough to avoid unnecessary latency on transient 429s + // while still giving the SSO service breathing room. + ssoRetryBase = 200 * time.Millisecond + + // ssoRetryMax caps the exponential backoff so that individual waits + // don't grow unreasonably large between attempts. + ssoRetryMax = 5 * time.Second + + // ssoRetryAfterJitterMin and ssoRetryAfterJitterMax define the full-jitter + // range as a multiplier of the base delay. 0.5x-1.5x ensures retries + // spread across a wide window — some fire earlier than the base delay, + // some later — which decorrelates concurrent processes that all received + // the same Retry-After header from the SSO service. + ssoRetryAfterJitterMin = 0.5 + ssoRetryAfterJitterMax = 1.5 +) + + +// initSSODefaults sets production defaults for all internal dependencies. +// Called by the constructor; tests can override unexported fields afterward. +func (p *SSORoleCredentialsProvider) initSSODefaults() { + p.ssoLockWait = defaultSSOLockWaitDelay + p.ssoLockLog = defaultSSOLockLogEvery + p.ssoLockTimeout = defaultSSOLockTimeout + p.ssoNow = time.Now + p.ssoSleep = defaultContextSleep + p.ssoLogf = log.Printf + p.newOIDCTokenFn = p.newOIDCToken +} + +// EnableSSOTokenLock creates the SSO token lock for cross-process coordination. +// Called at construction time when parallelSafe is true. +func (p *SSORoleCredentialsProvider) EnableSSOTokenLock() { + p.UseSSOTokenLock = true + if !p.UseStdout && p.ssoTokenLock == nil { + p.ssoTokenLock = NewDefaultLock("aws-vault.sso", p.StartURL) + } +} + // Retrieve generates a new set of temporary credentials using SSO GetRoleCredentials. func (p *SSORoleCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { creds, err := p.getRoleCredentials(ctx) @@ -63,42 +144,75 @@ func (p *SSORoleCredentialsProvider) getRoleCredentials(ctx context.Context) (*s return nil, err } - resp, err := p.SSOClient.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{ - AccessToken: token.AccessToken, - AccountId: aws.String(p.AccountID), - RoleName: aws.String(p.RoleName), - }) - if err != nil { + baseDelay, maxDelay := ssoRetryBase, ssoRetryMax + deadline := p.ssoNow().Add(ssoRetryTimeout) + attempt := 0 + rateLimitCount := 0 + var maxRetryAfterSeen time.Duration + for { + attempt++ + resp, err := p.SSOClient.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{ + AccessToken: token.AccessToken, + AccountId: aws.String(p.AccountID), + RoleName: aws.String(p.RoleName), + }) + if err == nil { + log.Printf("Got credentials %s for SSO role %s (account: %s), expires in %s", FormatKeyForDisplay(*resp.RoleCredentials.AccessKeyId), p.RoleName, p.AccountID, time.Until(millisecondsTimeValue(resp.RoleCredentials.Expiration)).String()) + return resp.RoleCredentials, nil + } + if cached && p.OIDCTokenCache != nil { var rspError *awshttp.ResponseError - if !errors.As(err, &rspError) { - return nil, err + if errors.As(err, &rspError) && rspError.HTTPStatusCode() == http.StatusUnauthorized { + // Cached token rejected: drop it and retry with a fresh access token. + // This should only happen once because the cache is cleared before retrying. + if err = p.OIDCTokenCache.Remove(p.StartURL); err != nil { + return nil, err + } + token, cached, err = p.getOIDCToken(ctx) + if err != nil { + return nil, err + } + attempt = 0 + continue } + } - // If the error is a 401, remove the cached oidc token and try - // again. This is a recursive call but it should only happen once - // due to the cache being cleared before retrying. - if rspError.HTTPStatusCode() == http.StatusUnauthorized { - err = p.OIDCTokenCache.Remove(p.StartURL) - if err != nil { + if isSSORateLimitError(err) { + rateLimitCount++ + remaining := deadline.Sub(p.ssoNow()) + if 0 < remaining { + var delay time.Duration + if retryAfter, ok := retryAfterFromError(err); ok { + if maxRetryAfterSeen < retryAfter { + maxRetryAfterSeen = retryAfter + } + delay = jitterRetryAfter(retryAfter) + } else { + delay = jitteredBackoff(baseDelay, maxDelay, attempt) + } + if remaining < delay { + delay = remaining + } + log.Printf("SSO rate limited for role %s (account: %s); backing off %s, attempt %d (%d 429s, max retry-after %s)", p.RoleName, p.AccountID, delay, attempt, rateLimitCount, maxRetryAfterSeen) + if err = p.ssoSleep(ctx, delay); err != nil { return nil, err } - return p.getRoleCredentials(ctx) + continue } + return nil, fmt.Errorf("SSO rate limited for role %s (account: %s) persistently for %s (%d 429s, max retry-after %s); giving up — try again later: %w", p.RoleName, p.AccountID, ssoRetryTimeout, rateLimitCount, maxRetryAfterSeen, err) } + return nil, err } - log.Printf("Got credentials %s for SSO role %s (account: %s), expires in %s", FormatKeyForDisplay(*resp.RoleCredentials.AccessKeyId), p.RoleName, p.AccountID, time.Until(millisecondsTimeValue(resp.RoleCredentials.Expiration)).String()) - - return resp.RoleCredentials, nil } func (p *SSORoleCredentialsProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { - return p.getRoleCredentialsAsStsCredemtials(ctx) + return p.getRoleCredentialsAsStsCredentials(ctx) } -// getRoleCredentialsAsStsCredemtials returns getRoleCredentials as sts.Credentials because sessions.Store expects it -func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredemtials(ctx context.Context) (*ststypes.Credentials, error) { +// getRoleCredentialsAsStsCredentials returns getRoleCredentials as sts.Credentials because sessions.Store expects it +func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { creds, err := p.getRoleCredentials(ctx) if err != nil { return nil, err @@ -113,27 +227,112 @@ func (p *SSORoleCredentialsProvider) getRoleCredentialsAsStsCredemtials(ctx cont } func (p *SSORoleCredentialsProvider) getOIDCToken(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { - if p.OIDCTokenCache != nil { - token, err = p.OIDCTokenCache.Get(p.StartURL) - if err != nil && err != keyring.ErrKeyNotFound { - return nil, false, err - } - if token != nil { - return token, true, nil - } + token, cached, err = p.getCachedOIDCToken() + if err != nil || token != nil { + return token, cached, err } - token, err = p.newOIDCToken(ctx) + + if p.UseStdout { + return p.createAndCacheOIDCToken(ctx) + } + + if !p.UseSSOTokenLock { + return p.createAndCacheOIDCToken(ctx) + } + + return p.getOIDCTokenWithLock(ctx) +} + +func (p *SSORoleCredentialsProvider) getCachedOIDCToken() (token *ssooidc.CreateTokenOutput, cached bool, err error) { + if p.OIDCTokenCache == nil { + return nil, false, nil + } + + token, err = p.OIDCTokenCache.Get(p.StartURL) + if err != nil && err != keyring.ErrKeyNotFound { + return nil, false, err + } + if token != nil { + return token, true, nil + } + + return nil, false, nil +} + +func (p *SSORoleCredentialsProvider) createAndCacheOIDCToken(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { + token, err = p.newOIDCTokenFn(ctx) if err != nil { return nil, false, err } if p.OIDCTokenCache != nil { - err = p.OIDCTokenCache.Set(p.StartURL, token) - if err != nil { + if err = p.OIDCTokenCache.Set(p.StartURL, token); err != nil { return nil, false, err } } - return token, false, err + + return token, false, nil +} + +type oidcTokenResult struct { + token *ssooidc.CreateTokenOutput + cached bool +} + +func (p *SSORoleCredentialsProvider) getOIDCTokenWithLock(ctx context.Context) (token *ssooidc.CreateTokenOutput, cached bool, err error) { + waitCtx, cancel := context.WithTimeout(ctx, p.ssoLockTimeout) + defer cancel() + + result, err := withProcessLock(waitCtx, p.ssoTokenLock, lockWaiterOpts{ + LockPath: p.ssoTokenLock.Path(), + WarnMsg: "Waiting for SSO lock at %s\n", + LogMsg: "Waiting for SSO lock at %s", + WaitDelay: p.ssoLockWait, + LogEvery: p.ssoLockLog, + WarnAfter: defaultSSOLockWarnAfter, + Now: p.ssoNow, + Sleep: p.ssoSleep, + Logf: p.ssoLogf, + Warnf: func(format string, args ...any) { + fmt.Fprintf(os.Stderr, format, args...) + }, + }, "SSO token", func() (processLockResult[oidcTokenResult], error) { + token, cached, err := p.getCachedOIDCToken() + if err != nil { + return processLockResult[oidcTokenResult]{}, err + } + if token != nil { + return processLockResult[oidcTokenResult]{value: oidcTokenResult{token, cached}, ok: true}, nil + } + return processLockResult[oidcTokenResult]{}, nil + }, func(_ context.Context) (oidcTokenResult, error) { + // Use the caller's ctx (not the lock-wait timeout ctx) for actual + // work. The lock-wait timeout bounds how long we wait for the lock, + // not how long the work takes once we hold it. + + // Recheck cache after acquiring lock — another process may have filled it. + token, cached, err := p.getCachedOIDCToken() + if err != nil { + return oidcTokenResult{}, err + } + if token != nil { + return oidcTokenResult{token, cached}, nil + } + + token, err = p.newOIDCTokenFn(ctx) + if err != nil { + return oidcTokenResult{}, err + } + + if p.OIDCTokenCache != nil { + if err = p.OIDCTokenCache.Set(p.StartURL, token); err != nil { + return oidcTokenResult{}, err + } + } + + return oidcTokenResult{token, false}, nil + }) + return result.token, result.cached, err } func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc.CreateTokenOutput, error) { @@ -190,7 +389,9 @@ func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc var ape *ssooidctypes.AuthorizationPendingException if errors.As(err, &ape) { - time.Sleep(retryInterval) + if sleepErr := p.ssoSleep(ctx, retryInterval); sleepErr != nil { + return nil, sleepErr + } continue } @@ -201,3 +402,84 @@ func (p *SSORoleCredentialsProvider) newOIDCToken(ctx context.Context) (*ssooidc return t, nil } } + +func retryAfterFromError(err error) (time.Duration, bool) { + var rspError *awshttp.ResponseError + if errors.As(err, &rspError) { + if rspError.Response != nil { + if d, ok := parseRetryAfter(rspError.Response.Header.Get("Retry-After")); ok { + return d, true + } + } + } + return 0, false +} + +func parseRetryAfter(value string) (time.Duration, bool) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return 0, false + } + if secs, err := strconv.Atoi(trimmed); err == nil { + if secs < 0 { + return 0, false + } + return time.Duration(secs) * time.Second, true + } + if t, err := http.ParseTime(trimmed); err == nil { + d := time.Until(t) + if d < 0 { + d = 0 + } + return d, true + } + return 0, false +} + +func isSSORateLimitError(err error) bool { + var tooMany *ssotypes.TooManyRequestsException + if errors.As(err, &tooMany) { + return true + } + var rspError *awshttp.ResponseError + if errors.As(err, &rspError) && rspError.HTTPStatusCode() == http.StatusTooManyRequests { + return true + } + return false +} + +func jitterRetryAfter(base time.Duration) time.Duration { + if base <= 0 { + return 0 + } + return jitterDelay(base) +} + +func jitteredBackoff(base, max time.Duration, attempt int) time.Duration { + if attempt < 1 { + attempt = 1 + } + capDelay := base << uint(attempt-1) + if max < capDelay { + capDelay = max + } + if capDelay < base { + // Overflow: large shift wrapped negative; clamp to max, not base, + // so late retries stay backed off instead of becoming aggressive. + capDelay = max + } + return jitterDelay(capDelay) +} + +func jitterDelay(base time.Duration) time.Duration { + if base <= 0 { + return 0 + } + min := ssoRetryAfterJitterMin + max := ssoRetryAfterJitterMax + if max < min { + max = min + } + factor := min + rand.Float64()*(max-min) + return time.Duration(float64(base) * factor) +} diff --git a/vault/ssorolecredentialsprovider_lock_test.go b/vault/ssorolecredentialsprovider_lock_test.go new file mode 100644 index 00000000..ffbd15a2 --- /dev/null +++ b/vault/ssorolecredentialsprovider_lock_test.go @@ -0,0 +1,319 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssooidc" + "github.com/byteness/keyring" +) + +type testTokenCache struct { + token *ssooidc.CreateTokenOutput + setCalls int + setLock *testLock +} + +func (c *testTokenCache) Get(string) (*ssooidc.CreateTokenOutput, error) { + if c.token == nil { + return nil, keyring.ErrKeyNotFound + } + return c.token, nil +} + +func (c *testTokenCache) Set(_ string, token *ssooidc.CreateTokenOutput) error { + c.setCalls++ + if c.setLock != nil && !c.setLock.locked { + return fmt.Errorf("lock not held during cache set") + } + c.token = token + return nil +} + +func (c *testTokenCache) Remove(string) error { + c.token = nil + return nil +} + +func newTestSSORoleProvider() *SSORoleCredentialsProvider { + p := &SSORoleCredentialsProvider{ + StartURL: "https://sso.example", + } + p.initSSODefaults() + return p +} + +func TestGetOIDCToken_CacheHit_NoLock(t *testing.T) { + cachedToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("cached")} + cache := &testTokenCache{token: cachedToken} + lock := &testLock{} + + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called on cache hit") + return nil, nil + } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !cached { + t.Fatalf("expected cached token") + } + if token != cachedToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } +} + +func TestGetOIDCToken_LockDisabled_SkipsLock(t *testing.T) { + freshToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("fresh")} + cache := &testTokenCache{} + lock := &testLock{tryResults: []bool{true}} + + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = false + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + return freshToken, nil + } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cached { + t.Fatalf("expected non-cached token") + } + if token != freshToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } + if cache.setCalls != 1 { + t.Fatalf("expected cache set once, got %d", cache.setCalls) + } +} + +func TestGetOIDCToken_LockMiss_ThenCacheHit_NoLock(t *testing.T) { + cachedToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("cached")} + cache := &testTokenCache{} + lock := &testLock{tryResults: []bool{false}} + clock := &testClock{now: time.Unix(0, 0)} + + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true + p.ssoLockWait = 5 * time.Second + p.ssoNow = clock.Now + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called when cache fills while waiting") + return nil, nil + } + p.ssoSleep = func(ctx context.Context, d time.Duration) error { + clock.now = clock.now.Add(d) + cache.token = cachedToken + return nil + } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !cached { + t.Fatalf("expected cached token") + } + if token != cachedToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 1 { + t.Fatalf("expected 1 lock attempt, got %d", lock.tryCalls) + } + if lock.unlockCalls != 0 { + t.Fatalf("expected no unlocks, got %d", lock.unlockCalls) + } +} + +func TestGetOIDCToken_LockAcquired_RecheckCache(t *testing.T) { + cachedToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("cached")} + cache := &testTokenCache{} + lock := &testLock{tryResults: []bool{true}} + lock.onTry = func(l *testLock) { + if l.locked { + cache.token = cachedToken + } + } + + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called when cache is filled after lock") + return nil, nil + } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !cached { + t.Fatalf("expected cached token") + } + if token != cachedToken { + t.Fatalf("unexpected token returned") + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} + +func TestGetOIDCToken_LockHeldThroughCacheSet(t *testing.T) { + freshToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("fresh")} + lock := &testLock{tryResults: []bool{true}} + cache := &testTokenCache{setLock: lock} + + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + return freshToken, nil + } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cached { + t.Fatalf("expected non-cached token") + } + if token != freshToken { + t.Fatalf("unexpected token returned") + } + if cache.setCalls != 1 { + t.Fatalf("expected cache set once, got %d", cache.setCalls) + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} + +func TestGetOIDCToken_UseStdout_SkipsLock(t *testing.T) { + freshToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("fresh")} + lock := &testLock{tryResults: []bool{true}} + cache := &testTokenCache{} + + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseStdout = true + p.UseSSOTokenLock = true + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + return freshToken, nil + } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cached { + t.Fatalf("expected non-cached token") + } + if token != freshToken { + t.Fatalf("unexpected token returned") + } + if lock.tryCalls != 0 { + t.Fatalf("expected no lock attempts, got %d", lock.tryCalls) + } +} + +func TestGetOIDCToken_LockWaitLogs(t *testing.T) { + lock := &testLock{tryResults: []bool{false, false, false, false}} + cache := &testTokenCache{} + ctx, cancel := context.WithCancel(context.Background()) + clock := &testClock{now: time.Unix(0, 0), cancel: cancel, cancelAfter: 4} + var logTimes []time.Time + + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true + p.ssoLockWait = 5 * time.Second + p.ssoLockLog = 15 * time.Second + p.ssoNow = clock.Now + p.ssoSleep = clock.Sleep + p.ssoLogf = func(string, ...any) { + logTimes = append(logTimes, clock.Now()) + } + p.newOIDCTokenFn = func(context.Context) (*ssooidc.CreateTokenOutput, error) { + t.Fatal("newOIDCToken should not be called when lock never acquired") + return nil, nil + } + + _, _, err := p.getOIDCToken(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context cancellation, got %v", err) + } + if len(logTimes) != 2 { + t.Fatalf("expected 2 log entries, got %d", len(logTimes)) + } + if !logTimes[0].Equal(time.Unix(0, 0)) { + t.Fatalf("unexpected first log time: %s", logTimes[0]) + } + if !logTimes[1].Equal(time.Unix(15, 0)) { + t.Fatalf("unexpected second log time: %s", logTimes[1]) + } +} + +func TestGetOIDCToken_WorkNotCancelledByLockTimeout(t *testing.T) { + // The lock-wait timeout should only bound how long we wait for the + // lock, not how long the work takes. Simulate work that takes longer + // than the lock-wait timeout and verify it completes. + freshToken := &ssooidc.CreateTokenOutput{AccessToken: aws.String("fresh")} + lock := &testLock{tryResults: []bool{true}} + cache := &testTokenCache{setLock: lock} + + p := newTestSSORoleProvider() + p.OIDCTokenCache = cache + p.ssoTokenLock = lock + p.UseSSOTokenLock = true + p.ssoLockTimeout = 10 * time.Millisecond + p.newOIDCTokenFn = func(ctx context.Context) (*ssooidc.CreateTokenOutput, error) { + // Work takes longer than the lock-wait timeout. + // If the timeout incorrectly cancels work, ctx.Err() fires. + select { + case <-time.After(50 * time.Millisecond): + return freshToken, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + token, cached, err := p.getOIDCToken(context.Background()) + if err != nil { + t.Fatalf("unexpected error (work cancelled by lock-wait timeout?): %v", err) + } + if cached { + t.Fatalf("expected non-cached token") + } + if token != freshToken { + t.Fatalf("unexpected token returned") + } + if lock.unlockCalls != 1 { + t.Fatalf("expected 1 unlock, got %d", lock.unlockCalls) + } +} diff --git a/vault/ssorolecredentialsprovider_retry_test.go b/vault/ssorolecredentialsprovider_retry_test.go new file mode 100644 index 00000000..fc3464df --- /dev/null +++ b/vault/ssorolecredentialsprovider_retry_test.go @@ -0,0 +1,232 @@ +package vault + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/aws/aws-sdk-go-v2/service/sso" + "github.com/aws/aws-sdk-go-v2/service/ssooidc" + ssotypes "github.com/aws/aws-sdk-go-v2/service/sso/types" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +func TestRetryAfterFromErrorSeconds(t *testing.T) { + header := http.Header{} + header.Set("Retry-After", "120") + resp := &http.Response{StatusCode: http.StatusTooManyRequests, Header: header} + err := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: resp}, + }, + } + + delay, ok := retryAfterFromError(err) + if !ok { + t.Fatal("expected retry-after delay to be detected") + } + if delay != 120*time.Second { + t.Fatalf("expected 120s retry-after, got %s", delay) + } +} + +func TestRetryAfterFromErrorMissingHeader(t *testing.T) { + resp := &http.Response{StatusCode: http.StatusTooManyRequests, Header: http.Header{}} + err := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: resp}, + }, + } + + delay, ok := retryAfterFromError(err) + if ok { + t.Fatalf("expected retry-after to be absent, got %s", delay) + } +} + +func TestIsSSORateLimitError(t *testing.T) { + if !isSSORateLimitError(&ssotypes.TooManyRequestsException{}) { + t.Fatal("expected TooManyRequestsException to be rate limit error") + } + + resp := &http.Response{StatusCode: http.StatusTooManyRequests} + err := &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: resp}, + }, + } + if !isSSORateLimitError(err) { + t.Fatal("expected HTTP 429 response error to be rate limit error") + } + + if isSSORateLimitError(errors.New("boom")) { + t.Fatal("expected non-rate-limit error to be false") + } +} + +func TestJitterDelayRange(t *testing.T) { + base := 10 * time.Second + min := time.Duration(float64(base) * ssoRetryAfterJitterMin) + max := time.Duration(float64(base) * ssoRetryAfterJitterMax) + + for i := 0; i < 10; i++ { + delay := jitterDelay(base) + if delay < min || max < delay { + t.Fatalf("expected delay in range %s-%s, got %s", min, max, delay) + } + } +} + +func TestJitteredBackoffProgression(t *testing.T) { + base := 200 * time.Millisecond + max := 5 * time.Second + + // Each attempt should double the cap: 200ms, 400ms, 800ms, 1600ms, 3200ms, 5000ms (capped) + for attempt := 1; attempt <= 8; attempt++ { + expectedCap := base << uint(attempt-1) + if max < expectedCap { + expectedCap = max + } + minDelay := time.Duration(float64(expectedCap) * ssoRetryAfterJitterMin) + maxDelay := time.Duration(float64(expectedCap) * ssoRetryAfterJitterMax) + + for i := 0; i < 20; i++ { + delay := jitteredBackoff(base, max, attempt) + if delay < minDelay || maxDelay < delay { + t.Fatalf("attempt %d: expected delay in range %s-%s, got %s", + attempt, minDelay, maxDelay, delay) + } + } + } +} + +func TestJitteredBackoffRespectsMax(t *testing.T) { + base := 200 * time.Millisecond + max := 5 * time.Second + + // At high attempt numbers the cap should be max, not overflow. + // This includes attempts 37+ where base<<(attempt-1) overflows int64; + // the delay must stay clamped to max, not collapse to base. + minDelay := time.Duration(float64(max) * ssoRetryAfterJitterMin) + maxDelay := time.Duration(float64(max) * ssoRetryAfterJitterMax) + for attempt := 20; attempt <= 60; attempt++ { + for i := 0; i < 10; i++ { + delay := jitteredBackoff(base, max, attempt) + if maxDelay < delay { + t.Fatalf("attempt %d: delay %s exceeds max jittered cap %s", attempt, delay, maxDelay) + } + if delay < minDelay { + t.Fatalf("attempt %d: delay %s below min jittered cap %s (overflow regression)", attempt, delay, minDelay) + } + } + } +} + +func TestJitteredBackoffDoublesPerAttempt(t *testing.T) { + base := 1 * time.Second + max := 1 * time.Hour // very high max so we never hit the cap + + // Verify the cap doubles by checking that the median of many samples + // roughly doubles. Instead, verify the deterministic cap calculation: + // cap(attempt) = base << (attempt-1) + for attempt := 1; attempt <= 5; attempt++ { + expectedCap := base << uint(attempt-1) + minBound := time.Duration(float64(expectedCap) * ssoRetryAfterJitterMin) + maxBound := time.Duration(float64(expectedCap) * ssoRetryAfterJitterMax) + + delay := jitteredBackoff(base, max, attempt) + if delay < minBound || maxBound < delay { + t.Fatalf("attempt %d: expected delay in [%s, %s], got %s", + attempt, minBound, maxBound, delay) + } + } +} + +func TestJitterRetryAfterRange(t *testing.T) { + base := 2 * time.Second + minDelay := time.Duration(float64(base) * ssoRetryAfterJitterMin) + maxDelay := time.Duration(float64(base) * ssoRetryAfterJitterMax) + + for i := 0; i < 50; i++ { + delay := jitterRetryAfter(base) + if delay < minDelay || maxDelay < delay { + t.Fatalf("jitterRetryAfter(%s): expected delay in range %s-%s, got %s", + base, minDelay, maxDelay, delay) + } + } +} + +func TestJitterRetryAfterZeroBase(t *testing.T) { + delay := jitterRetryAfter(0) + if delay != 0 { + t.Fatalf("expected 0 delay for zero base, got %s", delay) + } +} + +func TestJitterRetryAfterNegativeBase(t *testing.T) { + delay := jitterRetryAfter(-1 * time.Second) + if delay != 0 { + t.Fatalf("expected 0 delay for negative base, got %s", delay) + } +} + +func TestGetRoleCredentialsTimeoutOnPersistentRateLimit(t *testing.T) { + // Set up an HTTP server that always returns 429 with a TooManyRequestsException + // body that the AWS SDK will deserialize into a TooManyRequestsException error. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Amzn-Errortype", "TooManyRequestsException") + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprint(w, `{"__type":"TooManyRequestsException","message":"Rate exceeded"}`) + })) + defer srv.Close() + + // Disable SDK retries so our retry loop handles them + ssoClient := sso.New(sso.Options{ + Region: "us-east-1", + BaseEndpoint: aws.String(srv.URL), + RetryMaxAttempts: 1, + }) + + startTime := time.Unix(1000000, 0) + clock := &testClock{now: startTime} + + p := newTestSSORoleProvider() + p.SSOClient = ssoClient + p.AccountID = "123456789012" + p.RoleName = "TestRole" + p.ssoNow = clock.Now + p.ssoSleep = clock.Sleep + p.ssoLogf = func(string, ...any) {} // suppress log output + + // Provide a cached OIDC token so getOIDCToken succeeds + cache := &testTokenCache{ + token: &ssooidc.CreateTokenOutput{AccessToken: aws.String("test-token")}, + } + p.OIDCTokenCache = cache + + _, err := p.getRoleCredentials(context.Background()) + if err == nil { + t.Fatal("expected error after timeout, got nil") + } + + if !strings.Contains(err.Error(), "persistently") { + t.Fatalf("expected timeout error mentioning 'persistently', got: %v", err) + } + if !strings.Contains(err.Error(), ssoRetryTimeout.String()) { + t.Fatalf("expected error to mention timeout duration %s, got: %v", ssoRetryTimeout, err) + } + + // Verify the clock advanced past the retry timeout + elapsed := clock.now.Sub(time.Unix(1000000, 0)) + if elapsed < ssoRetryTimeout { + t.Fatalf("expected clock to advance at least %s, advanced %s", ssoRetryTimeout, elapsed) + } +} diff --git a/vault/vault.go b/vault/vault.go index 786fad6e..f4daa204 100644 --- a/vault/vault.go +++ b/vault/vault.go @@ -51,7 +51,7 @@ func NewMasterCredentialsProvider(k *CredentialKeyring, credentialsName string) return &KeyringProvider{k, credentialsName} } -func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { +func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig, useSessionCache bool, parallelSafe bool) (aws.CredentialsProvider, error) { cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints, config.EndpointURL) sessionTokenProvider := &SessionTokenProvider{ @@ -61,23 +61,24 @@ func NewSessionTokenProvider(credsProvider aws.CredentialsProvider, k keyring.Ke } if useSessionCache { - return &CachedSessionProvider{ - SessionKey: SessionMetadata{ + return NewCachedSessionProvider( + SessionMetadata{ Type: "sts.GetSessionToken", ProfileName: config.ProfileName, MfaSerial: config.MfaSerial, }, - Keyring: &SessionKeyring{Keyring: k}, - ExpiryWindow: defaultExpirationWindow, - SessionProvider: sessionTokenProvider, - }, nil + sessionTokenProvider, + &SessionKeyring{Keyring: k}, + defaultExpirationWindow, + parallelSafe, + ), nil } return sessionTokenProvider, nil } // NewAssumeRoleProvider returns a provider that generates credentials using AssumeRole -func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { +func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyring, config *ProfileConfig, useSessionCache bool, parallelSafe bool) (aws.CredentialsProvider, error) { cfg := NewAwsConfigWithCredsProvider(credsProvider, config.Region, config.STSRegionalEndpoints, config.EndpointURL) p := &AssumeRoleProvider{ @@ -93,16 +94,17 @@ func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyr } if useSessionCache && config.MfaSerial != "" { - return &CachedSessionProvider{ - SessionKey: SessionMetadata{ + return NewCachedSessionProvider( + SessionMetadata{ Type: "sts.AssumeRole", ProfileName: config.ProfileName, MfaSerial: config.MfaSerial, }, - Keyring: &SessionKeyring{Keyring: k}, - ExpiryWindow: defaultExpirationWindow, - SessionProvider: p, - }, nil + p, + &SessionKeyring{Keyring: k}, + defaultExpirationWindow, + parallelSafe, + ), nil } return p, nil @@ -110,7 +112,7 @@ func NewAssumeRoleProvider(credsProvider aws.CredentialsProvider, k keyring.Keyr // NewAssumeRoleWithWebIdentityProvider returns a provider that generates // credentials using AssumeRoleWithWebIdentity -func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { +func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool, parallelSafe bool) (aws.CredentialsProvider, error) { cfg := NewAwsConfig(config.Region, config.STSRegionalEndpoints, config.EndpointURL) p := &AssumeRoleWithWebIdentityProvider{ @@ -123,22 +125,23 @@ func NewAssumeRoleWithWebIdentityProvider(k keyring.Keyring, config *ProfileConf } if useSessionCache { - return &CachedSessionProvider{ - SessionKey: SessionMetadata{ + return NewCachedSessionProvider( + SessionMetadata{ Type: "sts.AssumeRoleWithWebIdentity", ProfileName: config.ProfileName, }, - Keyring: &SessionKeyring{Keyring: k}, - ExpiryWindow: defaultExpirationWindow, - SessionProvider: p, - }, nil + p, + &SessionKeyring{Keyring: k}, + defaultExpirationWindow, + parallelSafe, + ), nil } return p, nil } // NewSSORoleCredentialsProvider creates a provider for SSO credentials -func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { +func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool, parallelSafe bool) (aws.CredentialsProvider, error) { cfg := NewAwsConfig(config.SSORegion, config.STSRegionalEndpoints, config.EndpointURL) ssoRoleCredentialsProvider := &SSORoleCredentialsProvider{ @@ -149,19 +152,24 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, use RoleName: config.SSORoleName, UseStdout: config.SSOUseStdout, } + ssoRoleCredentialsProvider.initSSODefaults() + if parallelSafe { + ssoRoleCredentialsProvider.EnableSSOTokenLock() + } if useSessionCache { ssoRoleCredentialsProvider.OIDCTokenCache = OIDCTokenKeyring{Keyring: k} - return &CachedSessionProvider{ - SessionKey: SessionMetadata{ + return NewCachedSessionProvider( + SessionMetadata{ Type: "sso.GetRoleCredentials", ProfileName: config.ProfileName, MfaSerial: config.SSOStartURL, }, - Keyring: &SessionKeyring{Keyring: k}, - ExpiryWindow: defaultExpirationWindow, - SessionProvider: ssoRoleCredentialsProvider, - }, nil + ssoRoleCredentialsProvider, + &SessionKeyring{Keyring: k}, + defaultExpirationWindow, + parallelSafe, + ), nil } return ssoRoleCredentialsProvider, nil @@ -169,21 +177,22 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, use // NewCredentialProcessProvider creates a provider to retrieve credentials from an external // executable as described in https://docs.aws.amazon.com/cli/latest/topic/config-vars.html#sourcing-credentials-from-external-processes -func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) { +func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool, parallelSafe bool) (aws.CredentialsProvider, error) { credentialProcessProvider := &CredentialProcessProvider{ CredentialProcess: config.CredentialProcess, } if useSessionCache { - return &CachedSessionProvider{ - SessionKey: SessionMetadata{ + return NewCachedSessionProvider( + SessionMetadata{ Type: "credential_process", ProfileName: config.ProfileName, }, - Keyring: &SessionKeyring{Keyring: k}, - ExpiryWindow: defaultExpirationWindow, - SessionProvider: credentialProcessProvider, - }, nil + credentialProcessProvider, + &SessionKeyring{Keyring: k}, + defaultExpirationWindow, + parallelSafe, + ), nil } return credentialProcessProvider, nil @@ -230,6 +239,8 @@ type TempCredentialsCreator struct { DisableCache bool // DisableSessionsForProfile is a profile for which sessions should not be used DisableSessionsForProfile string + // ParallelSafe enables cross-process locking for cached credentials. + ParallelSafe bool chainedMfa string } @@ -263,7 +274,7 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, } t.chainedMfa = config.MfaSerial log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) - sourcecredsProvider, err = NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + sourcecredsProvider, err = NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) if !config.HasRole() || err != nil { return sourcecredsProvider, err } @@ -275,7 +286,7 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, config.MfaSerial = "" } log.Printf("profile %s: using AssumeRole %s", config.ProfileName, mfaDetails(isMfaChained, config)) - return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + return NewAssumeRoleProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) } if isMasterCredentialsProvider(sourcecredsProvider) { @@ -283,7 +294,7 @@ func (t *TempCredentialsCreator) getSourceCredWithSession(config *ProfileConfig, if canUseGetSessionToken { t.chainedMfa = config.MfaSerial log.Printf("profile %s: using GetSessionToken %s", config.ProfileName, mfaDetails(false, config)) - return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache) + return NewSessionTokenProvider(sourcecredsProvider, t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) } log.Printf("profile %s: skipping GetSessionToken because %s", config.ProfileName, reason) } @@ -303,17 +314,17 @@ func (t *TempCredentialsCreator) GetProviderForProfile(config *ProfileConfig) (a if config.HasSSOStartURL() { log.Printf("profile %s: using SSO role credentials", config.ProfileName) - return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache) + return NewSSORoleCredentialsProvider(t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) } if config.HasWebIdentity() { log.Printf("profile %s: using web identity", config.ProfileName) - return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache) + return NewAssumeRoleWithWebIdentityProvider(t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) } if config.HasCredentialProcess() { log.Printf("profile %s: using credential process", config.ProfileName) - return NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache) + return NewCredentialProcessProvider(t.Keyring.Keyring, config, !t.DisableCache, t.ParallelSafe) } return nil, fmt.Errorf("profile %s: credentials missing", config.ProfileName) @@ -359,12 +370,23 @@ func mfaDetails(mfaChained bool, config *ProfileConfig) string { return "" } -// NewTempCredentialsProvider creates a credential provider for the given config +// TempCredentialsOptions controls how temporary credential providers are created. +type TempCredentialsOptions struct { + ParallelSafe bool +} + +// NewTempCredentialsProvider creates a credential provider for the given config. func NewTempCredentialsProvider(config *ProfileConfig, keyring *CredentialKeyring, disableSessions bool, disableCache bool) (aws.CredentialsProvider, error) { + return NewTempCredentialsProviderWithOptions(config, keyring, disableSessions, disableCache, TempCredentialsOptions{}) +} + +// NewTempCredentialsProviderWithOptions creates a credential provider for the given config with options. +func NewTempCredentialsProviderWithOptions(config *ProfileConfig, keyring *CredentialKeyring, disableSessions bool, disableCache bool, options TempCredentialsOptions) (aws.CredentialsProvider, error) { t := TempCredentialsCreator{ Keyring: keyring, DisableSessions: disableSessions, DisableCache: disableCache, + ParallelSafe: options.ParallelSafe, } return t.GetProviderForProfile(config) } diff --git a/vault/vault_test.go b/vault/vault_test.go index 75812bab..8ec3d165 100644 --- a/vault/vault_test.go +++ b/vault/vault_test.go @@ -4,8 +4,8 @@ import ( "os" "testing" - "github.com/byteness/keyring" "github.com/byteness/aws-vault/v7/vault" + "github.com/byteness/keyring" ) func TestUsageWebIdentityExample(t *testing.T) { @@ -123,3 +123,115 @@ sso_registration_scopes=sso:account:access t.Fatalf("Expected AccountID to be 2160xxxx, got %s", ssoProvider.AccountID) } } + +func TestTempCredentialsProviderParallelSafeGetSessionToken(t *testing.T) { + config := &vault.ProfileConfig{ + ProfileName: "creds", + Region: "us-east-1", + MfaToken: "123456", // provide token to avoid interactive prompt + } + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{ + {Key: "creds", Data: []byte(`{"AccessKeyID":"AKIAIOSFODNN7EXAMPLE","SecretAccessKey":"secret"}`)}, + })} + provider, err := vault.NewTempCredentialsProviderWithOptions( + config, + ckr, + false, + false, + vault.TempCredentialsOptions{ParallelSafe: true}, + ) + if err != nil { + t.Fatal(err) + } + + cached, ok := provider.(*vault.CachedSessionProvider) + if !ok { + t.Fatalf("Expected CachedSessionProvider, got %T", provider) + } + if !cached.UseSessionLock { + t.Fatalf("Expected UseSessionLock to be true") + } + _, ok = cached.SessionProvider.(*vault.SessionTokenProvider) + if !ok { + t.Fatalf("Expected SessionTokenProvider, got %T", cached.SessionProvider) + } +} + +func TestTempCredentialsProviderParallelSafeAssumeRole(t *testing.T) { + config := &vault.ProfileConfig{ + ProfileName: "role", + Region: "us-east-1", + RoleARN: "arn:aws:iam::222222222222:role/role", + MfaSerial: "arn:aws:iam::111111111111:mfa/user", + MfaToken: "123456", // provide token to avoid interactive prompt + SourceProfileName: "source", + SourceProfile: &vault.ProfileConfig{ + ProfileName: "source", + Region: "us-east-1", + }, + } + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{ + {Key: "source", Data: []byte(`{"AccessKeyID":"AKIAIOSFODNN7EXAMPLE","SecretAccessKey":"secret"}`)}, + })} + provider, err := vault.NewTempCredentialsProviderWithOptions( + config, + ckr, + true, // disableSessions: skip GetSessionToken so AssumeRole gets the MFA + false, + vault.TempCredentialsOptions{ParallelSafe: true}, + ) + if err != nil { + t.Fatal(err) + } + + cached, ok := provider.(*vault.CachedSessionProvider) + if !ok { + t.Fatalf("Expected CachedSessionProvider, got %T", provider) + } + if !cached.UseSessionLock { + t.Fatalf("Expected UseSessionLock to be true") + } + _, ok = cached.SessionProvider.(*vault.AssumeRoleProvider) + if !ok { + t.Fatalf("Expected AssumeRoleProvider, got %T", cached.SessionProvider) + } +} + +func TestTempCredentialsProviderParallelSafeSSOLocks(t *testing.T) { + config := &vault.ProfileConfig{ + ProfileName: "sso-profile", + SSOStartURL: "https://sso.example/start", + SSORegion: "us-east-1", + SSOAccountID: "123456789012", + SSORoleName: "Role", + } + + ckr := &vault.CredentialKeyring{Keyring: keyring.NewArrayKeyring([]keyring.Item{})} + provider, err := vault.NewTempCredentialsProviderWithOptions( + config, + ckr, + false, + false, + vault.TempCredentialsOptions{ParallelSafe: true}, + ) + if err != nil { + t.Fatal(err) + } + + cached, ok := provider.(*vault.CachedSessionProvider) + if !ok { + t.Fatalf("Expected CachedSessionProvider, got %T", provider) + } + if !cached.UseSessionLock { + t.Fatalf("Expected UseSessionLock to be true") + } + ssoProvider, ok := cached.SessionProvider.(*vault.SSORoleCredentialsProvider) + if !ok { + t.Fatalf("Expected SSORoleCredentialsProvider, got %T", cached.SessionProvider) + } + if !ssoProvider.UseSSOTokenLock { + t.Fatalf("Expected UseSSOTokenLock to be true") + } +}