diff --git a/cmd/rds-iam-psql/README.md b/cmd/rds-iam-psql/README.md new file mode 100644 index 0000000..5904fcf --- /dev/null +++ b/cmd/rds-iam-psql/README.md @@ -0,0 +1,111 @@ +# rds-iam-psql + +A CLI that launches an interactive `psql` session from a required RDS IAM URL: +- positional `postgres+rds-iam://...` DSN +- optional `-debug-aws` flag + +## Why? + +RDS IAM authentication lets you connect using AWS credentials instead of a static DB password. IAM auth tokens are short-lived and inconvenient to generate manually. This tool resolves a fresh DSN through `pgutils` and opens `psql` for you. + +## Installation + +```bash +go install github.com/corbaltcode/go-libraries/cmd/rds-iam-psql@latest +``` + +Or build from source: + +```bash +cd ./cmd/rds-iam-psql +go build +``` + +## Prerequisites + +- **psql** installed and available in your PATH +- **AWS credentials** configured (env vars, `~/.aws/credentials`, IAM role, etc.) +- **AWS region** configured for SDK resolution (for example: `AWS_REGION`, shared config profile, or runtime role config) +- **RDS IAM authentication enabled** on your database instance +- A DB user configured for IAM auth (for example: `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`) + +## Usage + +```bash +rds-iam-psql [-debug-aws] '' +``` + +- Flags must come before the DSN (standard Go flag parsing behavior). +- `` may omit the database path. When omitted, `pgutils` defaults the database name to the username. + +### Flags + +| Flag | Default | Description | +|------|---------|-------------| +| `-debug-aws` | `false` | Print STS caller identity before connecting | + +## Examples + +Basic IAM URL: + +```bash +./rds-iam-psql 'postgres+rds-iam://server@acremins-test.cicxifnkufnd.us-east-1.rds.amazonaws.com:5432/postgres' +``` + +IAM URL with cross-account role assumption: + +```bash +rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp?assume_role_arn=arn:aws:iam::123456789012:role/db-connect&assume_role_session_name=rds-iam-psql' +``` + +With AWS identity debugging: + +```bash +rds-iam-psql -debug-aws 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp' +``` + +Without explicit database name (defaults to username): + +```bash +rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432' +``` + +## Changing Search Path In psql + +If you need to change the schema search path, do it from the interactive `psql` session after connecting: + +```sql +SHOW search_path; +SET search_path TO app_schema, public; +``` + +This applies to the current session. If you need a persistent default, configure it in Postgres (for example with `ALTER ROLE ... SET search_path ...`). + +## How It Works + +1. Parses and validates the positional IAM URL. +2. Builds a `pgutils` connection string provider from the IAM URL. +3. If `-debug-aws` is set, runs STS `GetCallerIdentity` and prints the caller ARN. +4. Resolves an IAM tokenized DSN from the provider and launches `psql` with the IAM generated connection string. + +## Setting Up IAM Auth on RDS + +1. Enable IAM authentication on your RDS instance +2. Create a database user and grant IAM privileges: + ```sql + CREATE USER myuser WITH LOGIN; + GRANT rds_iam TO myuser; + ``` +3. Attach an IAM policy allowing `rds-db:connect` to your AWS user/role: + ```json + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": "rds-db:connect", + "Resource": "arn:aws:rds-db:::dbuser:/" + } + ] + } + ``` diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go new file mode 100644 index 0000000..03d1c4a --- /dev/null +++ b/cmd/rds-iam-psql/main.go @@ -0,0 +1,184 @@ +package main + +import ( + "bytes" + "context" + "errors" + "flag" + "fmt" + "io" + "log" + "net/url" + "os" + "os/exec" + "os/signal" + "strings" + "syscall" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/corbaltcode/go-libraries/pgutils" +) + +const usageTemplate = `Usage: + %[2]s [-debug-aws] 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb' + +Notes: + Flags must come before the DSN (standard Go flag parsing). + Database path is optional. If omitted, the database name defaults to the username. + +Flags: +%[1]s + +Examples: + %[2]s 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb' + %[2]s -debug-aws 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb' +` + +func main() { + rawURL, debugAWS, err := parseCLIArgs(os.Args[1:], os.Args[0]) + if err != nil { + if errors.Is(err, flag.ErrHelp) { + printUsage(os.Stdout, os.Args[0]) + return + } + fmt.Fprintf(os.Stderr, "%v\n\n", err) + printUsage(os.Stderr, os.Args[0]) + os.Exit(2) + } + + if err := validateRDSIAMURL(rawURL); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(2) + } + + ctx := context.Background() + + if debugAWS { + cfg, err := awsconfig.LoadDefaultConfig(ctx) + if err != nil { + log.Fatalf("failed to load AWS config: %v", err) + } + if err := printCallerIdentity(ctx, cfg); err != nil { + log.Fatalf("AWS credentials check failed: %v", err) + } + } + + connectionStringProvider, err := pgutils.NewConnectionStringProviderFromURLString(ctx, rawURL) + if err != nil { + log.Fatalf("failed to create connection string provider: %v", err) + } + + dsnWithToken, err := connectionStringProvider.ConnectionString(ctx) + if err != nil { + log.Fatalf("failed to get connection string from provider: %v", err) + } + + cmd := exec.Command("psql", dsnWithToken) + + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Ignore SIGINT in the wrapper so interactive Ctrl-C can be handled by psql. + // Forward SIGTERM to the child process. + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + if err := cmd.Start(); err != nil { + log.Fatalf("failed to start psql: %v", err) + } + + waitCh := make(chan error, 1) + go func() { waitCh <- cmd.Wait() }() + + for { + select { + case sig := <-sigCh: + switch sig { + case os.Interrupt: + continue + case syscall.SIGTERM: + if cmd.Process != nil { + _ = cmd.Process.Signal(syscall.SIGTERM) + } + } + case err := <-waitCh: + if err == nil { + return + } + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + log.Fatalf("psql failed: %v", err) + } + } +} + +func newFlagSet(bin string, output io.Writer) (fs *flag.FlagSet, debugAWSFlag *bool) { + fs = flag.NewFlagSet(bin, flag.ContinueOnError) + fs.SetOutput(output) + + return fs, + fs.Bool("debug-aws", false, "Print AWS caller identity before connecting") +} + +func printUsage(output io.Writer, bin string) { + fs, _ := newFlagSet(bin, io.Discard) + + var defaults bytes.Buffer + fs.SetOutput(&defaults) + fs.PrintDefaults() + + fmt.Fprintf(output, usageTemplate, strings.TrimRight(defaults.String(), "\n"), bin) +} + +func parseCLIArgs(args []string, bin string) (rawURL string, debugAWS bool, err error) { + fs, debugAWSFlag := newFlagSet(bin, io.Discard) + + if err := fs.Parse(args); err != nil { + return "", false, err + } + + positionals := fs.Args() + if len(positionals) != 1 { + return "", false, fmt.Errorf("expected exactly one positional RDS IAM connection URL argument, got %d", len(positionals)) + } + + return positionals[0], *debugAWSFlag, nil +} + +func validateRDSIAMURL(rawURL string) error { + parsedURL, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("failed to parse positional connection URL: %w", err) + } + if parsedURL.Scheme != "postgres+rds-iam" { + return fmt.Errorf("unsupported connection URL scheme %q (expected postgres+rds-iam)", parsedURL.Scheme) + } + if parsedURL.User == nil || strings.TrimSpace(parsedURL.User.Username()) == "" { + return fmt.Errorf("connection URL must include a database username") + } + if _, ok := parsedURL.User.Password(); ok { + return fmt.Errorf("connection URL must not include a password for postgres+rds-iam") + } + if strings.TrimSpace(parsedURL.Host) == "" { + return fmt.Errorf("connection URL must include a database host") + } + + return nil +} + +func printCallerIdentity(ctx context.Context, cfg aws.Config) error { + stsClient := sts.NewFromConfig(cfg) + + out, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err) + } + + fmt.Fprintf(os.Stderr, "Caller ARN: %s\n", aws.ToString(out.Arn)) + return nil +}