Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions cli/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type ExecCommandInput struct {
NoSession bool
UseStdout bool
ShowHelpMessages bool
UseProfileEnv bool
}

func (input ExecCommandInput) validate() error {
Expand Down Expand Up @@ -108,6 +109,9 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) {
OverrideDefaultFromEnvar("AWS_VAULT_STDOUT").
BoolVar(&input.UseStdout)

cmd.Flag("profile-env", "Set AWS_PROFILE instead of injecting AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY").
BoolVar(&input.UseProfileEnv)

cmd.Arg("profile", "Name of the profile").
//Required().
Default(os.Getenv("AWS_PROFILE")).
Expand Down Expand Up @@ -220,8 +224,20 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke
}
printHelpMessage(subshellHelp, input.ShowHelpMessages)
} else {
if err = addCredsToEnv(credsProvider, input.ProfileName, &cmdEnv); err != nil {
return 0, err
if input.UseProfileEnv {
if _, err = credsProvider.Retrieve(context.TODO()); err != nil {
return 0, fmt.Errorf("Failed to get credentials for %s: %w", input.ProfileName, err)
}
if config.HasSSOStartURL() {
if err = vault.SyncOIDCTokenToStandardCache(config, keyring); err != nil {
log.Printf("Warning: failed to sync OIDC token to standard cache: %s", err)
}
}
cmdEnv.Set("AWS_PROFILE", input.ProfileName)
} else {
if err = addCredsToEnv(credsProvider, input.ProfileName, &cmdEnv); err != nil {
return 0, err
}
}
printHelpMessage(subshellHelp, input.ShowHelpMessages)

Expand Down
7 changes: 4 additions & 3 deletions vault/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ func (c *ConfigFile) parseFile() error {
log.Printf("Parsing config file %s", c.Path)

f, err := ini.LoadSources(ini.LoadOptions{
AllowNestedValues: true,
InsensitiveSections: false,
InsensitiveKeys: true,
AllowNestedValues: true,
InsensitiveSections: false,
InsensitiveKeys: true,
IgnoreInlineComment: true,
}, c.Path)
if err != nil {
return fmt.Errorf("Error parsing config file %s: %w", c.Path, err)
Expand Down
83 changes: 82 additions & 1 deletion vault/vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package vault

import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials/ssocreds"
"github.com/aws/aws-sdk-go-v2/service/sso"
"github.com/aws/aws-sdk-go-v2/service/ssooidc"
"github.com/aws/aws-sdk-go-v2/service/sts"
Expand All @@ -32,7 +35,7 @@ func NewAwsConfig(region, stsRegionalEndpoints, endpointURL string) aws.Config {
func NewAwsConfigWithCredsProvider(credsProvider aws.CredentialsProvider, region, stsRegionalEndpoints, endpointURL string) aws.Config {
return aws.Config{
Region: region,
Credentials: credsProvider,
Credentials: aws.NewCredentialsCache(credsProvider),
EndpointResolverWithOptions: getSTSEndpointResolver(stsRegionalEndpoints, endpointURL),
}
}
Expand Down Expand Up @@ -167,6 +170,84 @@ func NewSSORoleCredentialsProvider(k keyring.Keyring, config *ProfileConfig, use
return ssoRoleCredentialsProvider, nil
}

// ssoTokenCacheKey returns the key used to compute the standard SSO cache file path.
// For profiles using [sso-session] this is the session name; for legacy profiles it is the start URL.
func ssoTokenCacheKey(config *ProfileConfig) string {
if config.SSOSession != "" {
return config.SSOSession
}
return config.SSOStartURL
}

// NewStandardCachedSSOCredentialsProvider returns an ssocreds.Provider that reads the SSO
// access token from the standard AWS CLI cache file (~/.aws/sso/cache/<sha1>.json).
// Returns nil, nil if the standard token file does not exist.
func NewStandardCachedSSOCredentialsProvider(config *ProfileConfig) (aws.CredentialsProvider, error) {
tokenFilepath, err := ssocreds.StandardCachedTokenFilepath(ssoTokenCacheKey(config))
if err != nil {
return nil, err
}

if _, err := os.Stat(tokenFilepath); os.IsNotExist(err) {
return nil, nil
}

cfg := NewAwsConfig(config.SSORegion, config.STSRegionalEndpoints, config.EndpointURL)

return ssocreds.New(
sso.NewFromConfig(cfg),
config.SSOAccountID,
config.SSORoleName,
config.SSOStartURL,
func(o *ssocreds.Options) {
o.CachedTokenFilepath = tokenFilepath
},
), nil
}

// SyncOIDCTokenToStandardCache writes the OIDC access token for the given profile
// from the keyring to the standard AWS SSO cache file (~/.aws/sso/cache/<sha1>.json),
// so that other AWS tools that read the standard file location can use it.
// Returns nil without error if the standard cache file already exists.
func SyncOIDCTokenToStandardCache(config *ProfileConfig, k keyring.Keyring) error {
tokenFilepath, err := ssocreds.StandardCachedTokenFilepath(ssoTokenCacheKey(config))
if err != nil {
return err
}

token, err := (OIDCTokenKeyring{Keyring: k}).Get(config.SSOStartURL)
if err != nil {
return fmt.Errorf("OIDC token not found in keyring for %s: %w", config.SSOStartURL, err)
}

expiration := time.Now().Add(time.Duration(token.ExpiresIn) * time.Second)

type cachedToken struct {
AccessToken string `json:"accessToken"`
ExpiresAt string `json:"expiresAt"`
RefreshToken string `json:"refreshToken,omitempty"`
}

t := cachedToken{
AccessToken: aws.ToString(token.AccessToken),
ExpiresAt: expiration.UTC().Format(time.RFC3339),
}
if token.RefreshToken != nil {
t.RefreshToken = aws.ToString(token.RefreshToken)
}

b, err := json.Marshal(t)
if err != nil {
return err
}

if err := os.MkdirAll(filepath.Dir(tokenFilepath), 0700); err != nil {
return err
}

return os.WriteFile(tokenFilepath, b, 0600)
}

// NewCredentialProcessProvider creates a provider to retrieve credentials from an external
// executable as described in https://docs.aws.amazon.com/cli/latest/topic/config-vars.html#sourcing-credentials-from-external-processes
func NewCredentialProcessProvider(k keyring.Keyring, config *ProfileConfig, useSessionCache bool) (aws.CredentialsProvider, error) {
Expand Down
Loading