From f8dcc535da28e322aa1aa5cf91132fa661b8bada Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Fri, 12 Dec 2025 05:27:36 -0800 Subject: [PATCH 01/10] init --- cmd/rds-iam-psql/README.md | 114 ++++++++++++++++++++++++++++++ cmd/rds-iam-psql/main.go | 139 +++++++++++++++++++++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 cmd/rds-iam-psql/README.md create mode 100644 cmd/rds-iam-psql/main.go diff --git a/cmd/rds-iam-psql/README.md b/cmd/rds-iam-psql/README.md new file mode 100644 index 0000000..b6dae57 --- /dev/null +++ b/cmd/rds-iam-psql/README.md @@ -0,0 +1,114 @@ +# rds-iam-psql + +A simple CLI tool that bridges AWS RDS IAM authentication into an interactive `psql` session. It generates a short-lived IAM auth token and launches `psql` with the token as the password, so you never have to manage database passwords. + +## Why? + +RDS IAM authentication lets you connect to PostgreSQL using your AWS credentials instead of a static database password. However, the auth tokens are temporary (15 minutes) and cumbersome to generate manually. This tool handles token generation automatically and drops you into a familiar `psql` shell. + +## Installation + +```bash +go install github.com/corbaltcode/go-libraries/cmd/rds-iam-psql@latest +``` + +Or build from source: + +```bash +cd ./cmd/rds-iam-psql +go build +``` + +## Prerequisites + +- **psql** installed and available in your PATH +- **AWS credentials** configured (via environment variables, `~/.aws/credentials`, IAM role, etc.) +- **RDS IAM authentication enabled** on your database instance +- A database user configured for IAM authentication (created with `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`) + +## Usage + +```bash +rds-iam-psql -host -user -db [options] +``` + +### Required Flags + +| Flag | Description | +|------|-------------| +| `-host` | RDS endpoint hostname (without port), e.g. `mydb.abc123.us-east-1.rds.amazonaws.com` | +| `-user` | Database username configured for IAM auth | +| `-db` | Database name to connect to | + +### Optional Flags + +| Flag | Default | Description | +|------|---------|-------------| +| `-port` | `5432` | PostgreSQL port | +| `-region` | auto | AWS region. If omitted, inferred from AWS config or the hostname | +| `-profile` | | AWS shared config profile to use (e.g. `dev`, `prod`) | +| `-psql` | `psql` | Path to the `psql` binary | +| `-sslmode` | `require` | SSL mode (`require`, `verify-full`, etc.) | +| `-search-path` | | PostgreSQL `search_path` to set on connection (e.g. `myschema,public`) | + +## Examples + +Basic connection: + +```bash +rds-iam-psql -host mydb.abc123.us-east-1.rds.amazonaws.com -user app_user -db myapp +``` + +With a specific AWS profile and schema: + +```bash +rds-iam-psql \ + -host mydb.abc123.us-east-1.rds.amazonaws.com \ + -user app_user \ + -db myapp \ + -profile production \ + -search-path "app_schema,public" +``` + +Using a non-standard port and explicit region: + +```bash +rds-iam-psql \ + -host mydb.abc123.us-east-1.rds.amazonaws.com \ + -port 5433 \ + -user admin \ + -db postgres \ + -region us-east-1 +``` + +## How It Works + +1. Loads your AWS credentials from the standard credential chain +2. Generates a temporary RDS IAM auth token using `auth.BuildAuthToken` +3. Launches `psql` with: + - `PGPASSWORD` set to the auth token + - `PGSSLMODE` set according to `-sslmode` + - `PGOPTIONS` set if `-search-path` is provided +4. Attaches stdin/stdout/stderr for interactive use + +## Setting Up IAM Auth on RDS + +1. Enable IAM authentication on your RDS instance +2. Create a database user and grant IAM privileges: + ```sql + CREATE USER myuser WITH LOGIN; + GRANT rds_iam TO myuser; + ``` +3. Attach an IAM policy allowing `rds-db:connect` to your AWS user/role: + ```json + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": "rds-db:connect", + "Resource": "arn:aws:rds-db:::dbuser:/" + } + ] + } + ``` diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go new file mode 100644 index 0000000..d076047 --- /dev/null +++ b/cmd/rds-iam-psql/main.go @@ -0,0 +1,139 @@ +// rds-iam-psql.go +package main + +import ( + "context" + "flag" + "fmt" + "log" + "os" + "os/exec" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/rds/auth" +) + +func main() { + var ( + host = flag.String("host", "", "RDS PostgreSQL endpoint hostname (no port, e.g. mydb.abc123.us-east-1.rds.amazonaws.com)") + port = flag.Int("port", 5432, "RDS PostgreSQL port (default 5432)") + user = flag.String("user", "", "Database user name") + dbName = flag.String("db", "", "Database name") + region = flag.String("region", "", "AWS region for the RDS instance (e.g. us-east-1). If empty, uses AWS config or tries to infer from host.") + profile = flag.String("profile", "", "Optional AWS shared config profile (e.g. dev)") + psqlPath = flag.String("psql", "psql", "Path to psql binary") + sslMode = flag.String("sslmode", "require", "PGSSLMODE for psql (e.g. require, verify-full)") + searchPath = flag.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')") + ) + flag.Parse() + + if *host == "" || *user == "" || *dbName == "" { + log.Fatalf("host, user, and db are required\n\nUsage example:\n %s -host mydb.abc123.us-east-1.rds.amazonaws.com -port 5432 -user myuser -db mydb -search-path \"login,public\" -region us-east-1\n", os.Args[0]) + } + + ctx := context.Background() + + // Load AWS config (standard RDS/IAM auth expects your AWS creds, *not* the DB password). + var cfg aws.Config + var err error + if *profile != "" { + cfg, err = awsconfig.LoadDefaultConfig(ctx, awsconfig.WithSharedConfigProfile(*profile)) + } else { + cfg, err = awsconfig.LoadDefaultConfig(ctx) + } + if err != nil { + log.Fatalf("failed to load AWS config: %v", err) + } + + awsRegion := *region + if awsRegion == "" { + awsRegion = cfg.Region + } + if awsRegion == "" { + // Last resort: try to infer from the hostname if it looks like a standard RDS endpoint. + if inferred := inferRegionFromHost(*host); inferred != "" { + awsRegion = inferred + } + } + + if awsRegion == "" { + log.Fatalf("AWS region is not set; pass -region or set AWS_REGION / configure your AWS profile") + } + + endpointWithPort := fmt.Sprintf("%s:%d", *host, *port) + + // Generate the IAM auth token. + authToken, err := auth.BuildAuthToken(ctx, endpointWithPort, awsRegion, *user, cfg.Credentials) + if err != nil { + log.Fatalf("failed to build RDS IAM auth token: %v", err) + } + + // Prepare psql command. We pass the token through PGPASSWORD and SSL mode via PGSSLMODE. + cmd := exec.Command( + *psqlPath, + "--host", *host, + "--port", fmt.Sprintf("%d", *port), + "--username", *user, + "--dbname", *dbName, + ) + + // Attach stdio so it behaves like an interactive shell. + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Inherit existing env and add PG vars. + env := os.Environ() + env = append(env, + "PGPASSWORD="+authToken, + "PGSSLMODE="+*sslMode, + ) + + // If a search path is provided, wire it through PGOPTIONS. + if sp := strings.TrimSpace(*searchPath); sp != "" { + // Build our addition: one -c flag. + add := "-c search_path=" + sp + + // Check if PGOPTIONS already exists; if so, append. + found := false + for i, e := range env { + if strings.HasPrefix(e, "PGOPTIONS=") { + current := strings.TrimPrefix(e, "PGOPTIONS=") + if strings.TrimSpace(current) == "" { + env[i] = "PGOPTIONS=" + add + } else { + env[i] = "PGOPTIONS=" + current + " " + add + } + found = true + break + } + } + if !found { + env = append(env, "PGOPTIONS="+add) + } + } + + cmd.Env = env + + if err := cmd.Run(); err != nil { + // psql will print its own error messages; just propagate the exit code. + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + log.Fatalf("failed to run psql: %v", err) + } +} + +// inferRegionFromHost tries to pull the AWS region out of a typical RDS hostname like +// "mydb.abc123.us-east-1.rds.amazonaws.com". If it can't, it returns "". +func inferRegionFromHost(host string) string { + parts := strings.Split(host, ".") + for i := 0; i < len(parts); i++ { + if parts[i] == "rds" && i > 0 { + return parts[i-1] + } + } + return "" +} From 200599b9d491e3bc1642c4bc2b2869af087cc182 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 18 Dec 2025 20:25:02 -0800 Subject: [PATCH 02/10] Fix ctrl-c issue, and remove parsing region from db hostname --- cmd/rds-iam-psql/main.go | 64 +++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index d076047..659518c 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -1,4 +1,3 @@ -// rds-iam-psql.go package main import ( @@ -8,7 +7,9 @@ import ( "log" "os" "os/exec" + "os/signal" "strings" + "syscall" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" @@ -51,12 +52,6 @@ func main() { if awsRegion == "" { awsRegion = cfg.Region } - if awsRegion == "" { - // Last resort: try to infer from the hostname if it looks like a standard RDS endpoint. - if inferred := inferRegionFromHost(*host); inferred != "" { - awsRegion = inferred - } - } if awsRegion == "" { log.Fatalf("AWS region is not set; pass -region or set AWS_REGION / configure your AWS profile") @@ -93,10 +88,8 @@ func main() { // If a search path is provided, wire it through PGOPTIONS. if sp := strings.TrimSpace(*searchPath); sp != "" { - // Build our addition: one -c flag. add := "-c search_path=" + sp - // Check if PGOPTIONS already exists; if so, append. found := false for i, e := range env { if strings.HasPrefix(e, "PGOPTIONS=") { @@ -117,23 +110,46 @@ func main() { cmd.Env = env - if err := cmd.Run(); err != nil { - // psql will print its own error messages; just propagate the exit code. - if exitErr, ok := err.(*exec.ExitError); ok { - os.Exit(exitErr.ExitCode()) - } - log.Fatalf("failed to run psql: %v", err) + // --- Ctrl-C handling --- + // The key idea: keep psql in the same foreground process group so it can read + // from the terminal. We intercept SIGINT only to prevent THIS wrapper from + // exiting; psql will still receive SIGINT normally and cancel the current + // query / line as expected. + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(sigCh) + + if err := cmd.Start(); err != nil { + log.Fatalf("failed to start psql: %v", err) } -} -// inferRegionFromHost tries to pull the AWS region out of a typical RDS hostname like -// "mydb.abc123.us-east-1.rds.amazonaws.com". If it can't, it returns "". -func inferRegionFromHost(host string) string { - parts := strings.Split(host, ".") - for i := 0; i < len(parts); i++ { - if parts[i] == "rds" && i > 0 { - return parts[i-1] + waitCh := make(chan error, 1) + go func() { waitCh <- cmd.Wait() }() + + for { + select { + case sig := <-sigCh: + switch sig { + case os.Interrupt: + // Swallow SIGINT so this wrapper doesn't exit. + // psql still gets SIGINT (same terminal foreground process group). + continue + case syscall.SIGTERM: + // If we're being terminated, pass it through to psql and exit accordingly. + if cmd.Process != nil { + _ = cmd.Process.Signal(syscall.SIGTERM) + } + } + case err := <-waitCh: + // psql exited; now we exit with the same code. + if err == nil { + return + } + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + log.Fatalf("psql failed: %v", err) } } - return "" } + From a59d002953481f76fe932baff31e3e0a396a7293 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 18 Dec 2025 20:41:17 -0800 Subject: [PATCH 03/10] Sts check --- cmd/rds-iam-psql/main.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index 659518c..43a6870 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -14,6 +14,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" + "github.com/aws/aws-sdk-go-v2/service/sts" ) func main() { @@ -48,6 +49,11 @@ func main() { log.Fatalf("failed to load AWS config: %v", err) } + // Fail fast + print identity (account/arn/role-ish). + if err := printCallerIdentity(ctx, cfg); err != nil { + log.Fatalf("AWS credentials check failed: %v", err) + } + awsRegion := *region if awsRegion == "" { awsRegion = cfg.Region @@ -153,3 +159,19 @@ func main() { } } +func printCallerIdentity(ctx context.Context, cfg aws.Config) error { + stsClient := sts.NewFromConfig(cfg) + + out, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err) + } + + account := aws.ToString(out.Account) + arn := aws.ToString(out.Arn) + + fmt.Printf("AWS Account: %s\n", account) + fmt.Printf("Caller ARN: %s\n", arn) + + return nil +} From ece0181e437f3a12a80fa91967c914b11f5b0254 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 18 Dec 2025 20:41:39 -0800 Subject: [PATCH 04/10] White space --- cmd/rds-iam-psql/main.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index 43a6870..f6213b0 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -175,3 +175,4 @@ func printCallerIdentity(ctx context.Context, cfg aws.Config) error { return nil } + From ac386aa4f076709e4c7c354311e39cb71fdf7728 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 18 Dec 2025 20:45:21 -0800 Subject: [PATCH 05/10] tighter sts print --- cmd/rds-iam-psql/main.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index f6213b0..d6e2c6f 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -167,12 +167,7 @@ func printCallerIdentity(ctx context.Context, cfg aws.Config) error { return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err) } - account := aws.ToString(out.Account) - arn := aws.ToString(out.Arn) - - fmt.Printf("AWS Account: %s\n", account) - fmt.Printf("Caller ARN: %s\n", arn) - + fmt.Printf("Caller ARN: %s\n", aws.ToString(out.Arn)) return nil } From d46322af740cc377805bf3abf1d9db4b37c84fd6 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Thu, 18 Dec 2025 20:45:41 -0800 Subject: [PATCH 06/10] go fmt --- cmd/rds-iam-psql/main.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index d6e2c6f..8406c0d 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -170,4 +170,3 @@ func printCallerIdentity(ctx context.Context, cfg aws.Config) error { fmt.Printf("Caller ARN: %s\n", aws.ToString(out.Arn)) return nil } - From b2884c959839b60a9f7cfaf6a3a562d9d8539982 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Wed, 18 Feb 2026 19:15:23 -0800 Subject: [PATCH 07/10] use connetor --- cmd/rds-iam-psql/README.md | 85 ++++++----- cmd/rds-iam-psql/main.go | 139 +++++++++++------- pgutils/connector.go | 288 +++++++++++++++++++++---------------- pgutils/listener.go | 7 +- 4 files changed, 307 insertions(+), 212 deletions(-) diff --git a/cmd/rds-iam-psql/README.md b/cmd/rds-iam-psql/README.md index b6dae57..a316106 100644 --- a/cmd/rds-iam-psql/README.md +++ b/cmd/rds-iam-psql/README.md @@ -1,10 +1,14 @@ # rds-iam-psql -A simple CLI tool that bridges AWS RDS IAM authentication into an interactive `psql` session. It generates a short-lived IAM auth token and launches `psql` with the token as the password, so you never have to manage database passwords. +A CLI that launches an interactive `psql` session from either: +- a positional connection URL, or +- individual `-host/-port/-user/-db` flags. + +It supports standard PostgreSQL URLs and `pgutils` custom IAM URLs (`postgres+rds-iam://...`). ## Why? -RDS IAM authentication lets you connect to PostgreSQL using your AWS credentials instead of a static database password. However, the auth tokens are temporary (15 minutes) and cumbersome to generate manually. This tool handles token generation automatically and drops you into a familiar `psql` shell. +RDS IAM authentication lets you connect using AWS credentials instead of a static DB password. IAM auth tokens are short-lived and inconvenient to generate manually. This tool resolves a fresh DSN through `pgutils` and opens `psql` for you. ## Installation @@ -22,74 +26,85 @@ go build ## Prerequisites - **psql** installed and available in your PATH -- **AWS credentials** configured (via environment variables, `~/.aws/credentials`, IAM role, etc.) -- **RDS IAM authentication enabled** on your database instance -- A database user configured for IAM authentication (created with `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`) +- For IAM URLs (`postgres+rds-iam://...`), **AWS credentials** configured (env vars, `~/.aws/credentials`, IAM role, etc.) +- For IAM URLs (`postgres+rds-iam://...`), **AWS_REGION** set +- For IAM URLs (`postgres+rds-iam://...`), **RDS IAM authentication enabled** on your database instance +- For IAM URLs (`postgres+rds-iam://...`), a DB user configured for IAM auth (for example: `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`) ## Usage ```bash -rds-iam-psql -host -user -db [options] +rds-iam-psql [connection-url] [options] ``` -### Required Flags +```bash +rds-iam-psql -host -user -db [options] +``` -| Flag | Description | -|------|-------------| -| `-host` | RDS endpoint hostname (without port), e.g. `mydb.abc123.us-east-1.rds.amazonaws.com` | -| `-user` | Database username configured for IAM auth | -| `-db` | Database name to connect to | +`connection-url` supports: +- `postgres+rds-iam://user@host:5432/dbname` +- `postgres://user:pass@host:5432/dbname?...` +- `postgresql://user:pass@host:5432/dbname?...` -### Optional Flags +If `connection-url` is provided, do not combine it with `-host/-port/-user/-db`. + +### Flags | Flag | Default | Description | |------|---------|-------------| +| `-host` | | Endpoint hostname (required if `connection-url` is not provided) | | `-port` | `5432` | PostgreSQL port | -| `-region` | auto | AWS region. If omitted, inferred from AWS config or the hostname | -| `-profile` | | AWS shared config profile to use (e.g. `dev`, `prod`) | +| `-user` | | DB username (required if `connection-url` is not provided) | +| `-db` | | DB name (required if `connection-url` is not provided) | | `-psql` | `psql` | Path to the `psql` binary | | `-sslmode` | `require` | SSL mode (`require`, `verify-full`, etc.) | | `-search-path` | | PostgreSQL `search_path` to set on connection (e.g. `myschema,public`) | ## Examples -Basic connection: +Positional IAM URL (your requested form): + +```bash +./rds-iam-psql 'postgres+rds-iam://server@acremins-test.cicxifnkufnd.us-east-1.rds.amazonaws.com:5432/postgres' +``` + +IAM URL with cross-account role assumption: + +```bash +rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp?assume_role_arn=arn:aws:iam::123456789012:role/db-connect&assume_role_session_name=rds-iam-psql' +``` + +Flag-based IAM connection: ```bash rds-iam-psql -host mydb.abc123.us-east-1.rds.amazonaws.com -user app_user -db myapp ``` -With a specific AWS profile and schema: +Standard PostgreSQL URL (non-IAM): ```bash -rds-iam-psql \ - -host mydb.abc123.us-east-1.rds.amazonaws.com \ - -user app_user \ - -db myapp \ - -profile production \ - -search-path "app_schema,public" +rds-iam-psql 'postgresql://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable' ``` -Using a non-standard port and explicit region: +With search path: ```bash rds-iam-psql \ -host mydb.abc123.us-east-1.rds.amazonaws.com \ - -port 5433 \ - -user admin \ - -db postgres \ - -region us-east-1 + -user app_user \ + -db myapp \ + -search-path "app_schema,public" ``` ## How It Works -1. Loads your AWS credentials from the standard credential chain -2. Generates a temporary RDS IAM auth token using `auth.BuildAuthToken` -3. Launches `psql` with: - - `PGPASSWORD` set to the auth token - - `PGSSLMODE` set according to `-sslmode` - - `PGOPTIONS` set if `-search-path` is provided -4. Attaches stdin/stdout/stderr for interactive use +1. Parses input from either positional URL or `-host/-port/-user/-db`. +2. Builds a `pgutils.ConnectionStringProvider` from the URL. +3. For IAM URLs, validates AWS auth context (including `AWS_REGION`). +4. Resolves a DSN from the provider and launches `psql` with: +- `PGPASSWORD` set from the URL password/token +- `PGSSLMODE` set from `-sslmode` +- `PGOPTIONS` set when `-search-path` is provided ## Setting Up IAM Auth on RDS diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index 8406c0d..e554ea8 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -5,16 +5,19 @@ import ( "flag" "fmt" "log" + "net" + "net/url" "os" "os/exec" "os/signal" + "strconv" "strings" "syscall" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/corbaltcode/go-libraries/pgutils" ) func main() { @@ -23,76 +26,80 @@ func main() { port = flag.Int("port", 5432, "RDS PostgreSQL port (default 5432)") user = flag.String("user", "", "Database user name") dbName = flag.String("db", "", "Database name") - region = flag.String("region", "", "AWS region for the RDS instance (e.g. us-east-1). If empty, uses AWS config or tries to infer from host.") - profile = flag.String("profile", "", "Optional AWS shared config profile (e.g. dev)") psqlPath = flag.String("psql", "psql", "Path to psql binary") sslMode = flag.String("sslmode", "require", "PGSSLMODE for psql (e.g. require, verify-full)") searchPath = flag.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')") ) flag.Parse() - if *host == "" || *user == "" || *dbName == "" { - log.Fatalf("host, user, and db are required\n\nUsage example:\n %s -host mydb.abc123.us-east-1.rds.amazonaws.com -port 5432 -user myuser -db mydb -search-path \"login,public\" -region us-east-1\n", os.Args[0]) + args := flag.Args() + if len(args) > 1 { + log.Fatalf("expected at most one positional connection URL argument, got %d", len(args)) } - ctx := context.Background() - - // Load AWS config (standard RDS/IAM auth expects your AWS creds, *not* the DB password). - var cfg aws.Config - var err error - if *profile != "" { - cfg, err = awsconfig.LoadDefaultConfig(ctx, awsconfig.WithSharedConfigProfile(*profile)) - } else { - cfg, err = awsconfig.LoadDefaultConfig(ctx) + connectionURLArg := "" + if len(args) == 1 { + connectionURLArg = args[0] } + + rawURL, usesIAM, err := buildRawURL(connectionURLArg, *host, *port, *user, *dbName) if err != nil { - log.Fatalf("failed to load AWS config: %v", err) + log.Fatalf("%v\n\nUsage examples:\n %s -host mydb.abc123.us-east-1.rds.amazonaws.com -port 5432 -user myuser -db mydb -search-path \"login,public\"\n %s 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'\n", err, os.Args[0], os.Args[0]) } - // Fail fast + print identity (account/arn/role-ish). - if err := printCallerIdentity(ctx, cfg); err != nil { - log.Fatalf("AWS credentials check failed: %v", err) - } + ctx := context.Background() - awsRegion := *region - if awsRegion == "" { - awsRegion = cfg.Region + connectionStringProvider, err := pgutils.NewConnectionStringProviderFromURLString(ctx, rawURL) + if err != nil { + log.Fatalf("failed to create connection string provider: %v", err) } - if awsRegion == "" { - log.Fatalf("AWS region is not set; pass -region or set AWS_REGION / configure your AWS profile") + if usesIAM { + if os.Getenv("AWS_REGION") == "" { + log.Fatalf("AWS_REGION must be set for IAM auth") + } + + cfg, err := awsconfig.LoadDefaultConfig(ctx) + if err != nil { + log.Fatalf("failed to load AWS config: %v", err) + } + if err := printCallerIdentity(ctx, cfg); err != nil { + log.Fatalf("AWS credentials check failed: %v", err) + } } - endpointWithPort := fmt.Sprintf("%s:%d", *host, *port) + dsnWithToken, err := connectionStringProvider.ConnectionString(ctx) + if err != nil { + log.Fatalf("failed to get connection string from provider: %v", err) + } - // Generate the IAM auth token. - authToken, err := auth.BuildAuthToken(ctx, endpointWithPort, awsRegion, *user, cfg.Credentials) + parsedURL, err := url.Parse(dsnWithToken) if err != nil { - log.Fatalf("failed to build RDS IAM auth token: %v", err) + log.Fatalf("failed to parse connection string from provider: %v", err) } - // Prepare psql command. We pass the token through PGPASSWORD and SSL mode via PGSSLMODE. - cmd := exec.Command( - *psqlPath, - "--host", *host, - "--port", fmt.Sprintf("%d", *port), - "--username", *user, - "--dbname", *dbName, - ) + password := "" + if parsedURL.User != nil { + var ok bool + password, ok = parsedURL.User.Password() + if ok { + parsedURL.User = url.User(parsedURL.User.Username()) + } + } + + // Pass DSN to psql without password in argv, and provide password via env. + cmd := exec.Command(*psqlPath, parsedURL.String()) - // Attach stdio so it behaves like an interactive shell. cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - // Inherit existing env and add PG vars. env := os.Environ() - env = append(env, - "PGPASSWORD="+authToken, - "PGSSLMODE="+*sslMode, - ) + if password != "" { + env = append(env, "PGPASSWORD="+password) + } + env = append(env, "PGSSLMODE="+*sslMode) - // If a search path is provided, wire it through PGOPTIONS. if sp := strings.TrimSpace(*searchPath); sp != "" { add := "-c search_path=" + sp @@ -116,11 +123,8 @@ func main() { cmd.Env = env - // --- Ctrl-C handling --- - // The key idea: keep psql in the same foreground process group so it can read - // from the terminal. We intercept SIGINT only to prevent THIS wrapper from - // exiting; psql will still receive SIGINT normally and cancel the current - // query / line as expected. + // Keep psql in the foreground process group. Swallow SIGINT in wrapper so + // psql handles Ctrl-C directly. sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) defer signal.Stop(sigCh) @@ -137,17 +141,13 @@ func main() { case sig := <-sigCh: switch sig { case os.Interrupt: - // Swallow SIGINT so this wrapper doesn't exit. - // psql still gets SIGINT (same terminal foreground process group). continue case syscall.SIGTERM: - // If we're being terminated, pass it through to psql and exit accordingly. if cmd.Process != nil { _ = cmd.Process.Signal(syscall.SIGTERM) } } case err := <-waitCh: - // psql exited; now we exit with the same code. if err == nil { return } @@ -159,6 +159,41 @@ func main() { } } +func buildRawURL(connectionURLArg, host string, port int, user, dbName string) (string, bool, error) { + if connectionURLArg != "" { + if host != "" || user != "" || dbName != "" || port != 5432 { + return "", false, fmt.Errorf("positional connection URL cannot be combined with -host, -port, -user, or -db") + } + parsedURL, err := url.Parse(connectionURLArg) + if err != nil { + return "", false, fmt.Errorf("failed to parse positional connection URL: %w", err) + } + switch parsedURL.Scheme { + case "postgres+rds-iam": + return connectionURLArg, true, nil + case "postgres", "postgresql": + return connectionURLArg, false, nil + default: + return "", false, fmt.Errorf("unsupported connection URL scheme %q (expected postgres, postgresql, or postgres+rds-iam)", parsedURL.Scheme) + } + } + + if host == "" || user == "" || dbName == "" { + return "", false, fmt.Errorf("host, user, and db are required when no positional connection URL is provided") + } + if port <= 0 { + return "", false, fmt.Errorf("invalid port: %d", port) + } + + iamURL := &url.URL{ + Scheme: "postgres+rds-iam", + User: url.User(user), + Host: net.JoinHostPort(host, strconv.Itoa(port)), + Path: "/" + dbName, + } + return iamURL.String(), true, nil +} + func printCallerIdentity(ctx context.Context, cfg aws.Config) error { stsClient := sts.NewFromConfig(cfg) diff --git a/pgutils/connector.go b/pgutils/connector.go index 21dce91..1a0b773 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -5,8 +5,9 @@ import ( "errors" "fmt" "log" + "net" "net/url" - "time" + "strings" "database/sql" "database/sql/driver" @@ -20,109 +21,161 @@ import ( "github.com/lib/pq" ) -type baseConnectionStringProvider interface { - getBaseConnectionString(ctx context.Context) (string, error) -} +const defaultPostgresPort = "5432" + +var pqDriver = &pq.Driver{} -type PostgresqlConnector struct { - baseConnectionStringProvider - searchPath string +// ConnectionStringProvider returns a Postgres connection string for use by clients +// that need a DSN (e.g., pq.Listener) or to build a connector. +type ConnectionStringProvider interface { + ConnectionString(ctx context.Context) (string, error) } -func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector { - return &PostgresqlConnector{ - baseConnectionStringProvider: conn.baseConnectionStringProvider, - searchPath: searchPath, - } +type connectionStringProviderFunc func(context.Context) (string, error) + +func (f connectionStringProviderFunc) ConnectionString(ctx context.Context) (string, error) { + return f(ctx) } -func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { - dsn, err := conn.GetConnectionString(ctx) +// NewConnectionStringProviderFromURLString parses rawURL and constructs a provider. +// +// Standard Postgres example: +// +// postgres://user:pass@host:5432/dbname?sslmode=require +// +// IAM example 1: +// +// postgres+rds-iam://user@host:5432/dbname +// +// IAM example 2 (cross-account): +// +// postgres+rds-iam://user@host:5432/dbname?assume_role_arn=...&assume_role_session_name=... +// +// For postgres+rds-iam, the provider generates a fresh IAM auth token on each ConnectionString(ctx) call. +func NewConnectionStringProviderFromURLString(ctx context.Context, rawURL string) (ConnectionStringProvider, error) { + u, err := url.Parse(rawURL) if err != nil { - return nil, fmt.Errorf("get connection string: %w", err) + return nil, fmt.Errorf("parsing URL: %w", err) } - pqConnector, err := pq.NewConnector(dsn) - if err != nil { - return nil, fmt.Errorf("create pq connector: %w", err) + + switch u.Scheme { + case "postgres", "postgresql": + return &staticConnectionStringProvider{connectionString: u.String()}, nil + case "postgres+rds-iam": + return newIAMConnectionStringProviderFromURL(ctx, u) + default: + return nil, fmt.Errorf("unsupported URL scheme: %q (expected postgres, postgresql, or postgres+rds-iam)", u.Scheme) } +} - return pqConnector.Connect(ctx) +// ToConnector wraps a ConnectionStringProvider as a driver.Connector. +// Each Connect(ctx) call asks the provider for a fresh DSN. +func ToConnector(provider ConnectionStringProvider) driver.Connector { + return &postgresqlConnector{connectionStringProvider: provider} } -func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) { - dsn, err := conn.getBaseConnectionString(ctx) - if err != nil { - return "", fmt.Errorf("get base connection string: %w", err) +// WithSchemaSearchPath returns a ConnectionStringProvider that appends search_path +// to the DSN produced by the underlying provider. +func WithSchemaSearchPath(provider ConnectionStringProvider, searchPath string) ConnectionStringProvider { + return connectionStringProviderFunc(func(ctx context.Context) (string, error) { + dsn, err := provider.ConnectionString(ctx) + if err != nil { + return "", fmt.Errorf("ConnectionString failed: %w", err) + } + + dsnWithPath, err := addSearchPathToURL(dsn, searchPath) + if err != nil { + return "", fmt.Errorf("applying schema search path failed: %w", err) + } + + return dsnWithPath, nil + }) +} + +// ConnectDB opens a connection using the connector and verifies it with a ping +func ConnectDB(conn driver.Connector) (*sqlx.DB, error) { + sqlDB := sql.OpenDB(conn) + db := sqlx.NewDb(sqlDB, "postgres") + if err := db.Ping(); err != nil { + db.Close() + return nil, err } - if conn.searchPath == "" { - return dsn, nil + return db, nil +} + +// MustConnectDB is like ConnectDB but panics on error +func MustConnectDB(conn driver.Connector) *sqlx.DB { + db, err := ConnectDB(conn) + if err != nil { + panic(err) } + return db +} - // Add search path - u, err := url.Parse(dsn) +// addSearchPathToURL returns a copy of u with search_path set in the query string. +// It returns an error if search_path is already present. +func addSearchPathToURL(rawURL string, searchPath string) (string, error) { + u, err := url.Parse(rawURL) if err != nil { - return "", fmt.Errorf("parse DSN URL: %w", err) + return "", fmt.Errorf("url string failed to parse while adding search path: %w", err) + } + + if searchPath == "" { + return u.String(), nil } + q := u.Query() if v := q.Get("search_path"); v != "" { return "", fmt.Errorf("search_path already set to %q", v) } - q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed + q.Set("search_path", searchPath) u.RawQuery = q.Encode() return u.String(), nil } -func (c *PostgresqlConnector) Driver() driver.Driver { - return &pq.Driver{} +type postgresqlConnector struct { + connectionStringProvider ConnectionStringProvider } -type staticConnectionStringProvider struct { - connectionString string -} +func (c *postgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { + dsn, err := c.connectionStringProvider.ConnectionString(ctx) + if err != nil { + return nil, fmt.Errorf("getting connection string from provider: %w", err) + } + pqConnector, err := pq.NewConnector(dsn) + if err != nil { + return nil, fmt.Errorf("creating pq connector: %w", err) + } -func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { - return p.connectionString, nil + return pqConnector.Connect(ctx) } -func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector { - return &PostgresqlConnector{ - baseConnectionStringProvider: &staticConnectionStringProvider{connectionString}, - } +func (c *postgresqlConnector) Driver() driver.Driver { + return pqDriver } -type IAMAuthConfig struct { - RDSEndpoint string - User string - Database string - - // Optional: cross-account role assumption. - // Set this to a role ARN in the RDS account (Account A) that has rds-db:connect. - AssumeRoleARN string - - // Optional: if your trust policy requires an external ID. - AssumeRoleExternalID string - - // Optional: override the default session name. - AssumeRoleSessionName string - - // Optional: override STS assume role duration. - // If zero, SDK default is used. - AssumeRoleDuration time.Duration +type staticConnectionStringProvider struct { + connectionString string } -type iamAuthConnectionStringProvider struct { - IAMAuthConfig +func (p *staticConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { + return p.connectionString, nil +} - region string - creds aws.CredentialsProvider +type rdsIAMConnectionStringProvider struct { + RDSEndpoint string + Region string + User string + Database string + CredentialsProvider aws.CredentialsProvider } -func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { - authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds) +func (p *rdsIAMConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { + authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.Region, p.User, p.CredentialsProvider) if err != nil { return "", fmt.Errorf("building auth token: %w", err) } - log.Printf("Signing RDS IAM token for \n Endpoint: %s \n User: %s \n Database: %s", p.RDSEndpoint, p.User, p.Database) + log.Printf("Signing RDS IAM token for Endpoint: %s User: %s Database: %s", p.RDSEndpoint, p.User, p.Database) dsnURL := &url.URL{ Scheme: "postgresql", @@ -134,9 +187,43 @@ func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Co return dsnURL.String(), nil } -func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) { - if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" { - return nil, errors.New("RDS endpoint, user, and database are required") +func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL) (ConnectionStringProvider, error) { + user := "" + if u.User != nil { + user = u.User.Username() + if _, hasPw := u.User.Password(); hasPw { + return nil, errors.New("postgres+rds-iam URL must not include a password") + } + } + if user == "" { + return nil, errors.New("postgres+rds-iam URL missing username") + } + + host := u.Hostname() + if host == "" { + return nil, errors.New("postgres+rds-iam URL missing host") + } + + port := u.Port() + if port == "" { + port = defaultPostgresPort + } + + // Match libpq/psql defaulting: if dbname isn't specified, dbname defaults to username. + dbName := strings.TrimPrefix(u.Path, "/") + if dbName == "" { + dbName = user + } + + q := u.Query() + supportedParams := map[string]struct{}{ + "assume_role_arn": {}, + "assume_role_session_name": {}, + } + for k := range q { + if _, ok := supportedParams[k]; !ok { + return nil, fmt.Errorf("postgres+rds-iam URL has unsupported query parameter: %s", k) + } } awsCfg, err := awsconfig.LoadDefaultConfig(ctx) @@ -149,66 +236,25 @@ func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) } creds := awsCfg.Credentials - - // Cross-account support: - // If AssumeRoleARN is set, assume a role in the RDS account (Account A) - // using the ECS task role creds from Account B as the source credentials. - if cfg.AssumeRoleARN != "" { - log.Printf("RDS IAM Assuming Role: %s for \n Endpoint: %s \n User: %s \n Database: %s", cfg.AssumeRoleARN, cfg.RDSEndpoint, cfg.User, cfg.Database) + assumeRoleARN := q.Get("assume_role_arn") + if assumeRoleARN != "" { stsClient := sts.NewFromConfig(awsCfg) - - sessionName := cfg.AssumeRoleSessionName + sessionName := q.Get("assume_role_session_name") if sessionName == "" { sessionName = "pgutils-rds-iam" } - - assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, cfg.AssumeRoleARN, func(assumeRoleOpts *stscreds.AssumeRoleOptions) { - assumeRoleOpts.RoleSessionName = sessionName - - if cfg.AssumeRoleExternalID != "" { - assumeRoleOpts.ExternalID = aws.String(cfg.AssumeRoleExternalID) - } - - if cfg.AssumeRoleDuration != 0 { - assumeRoleOpts.Duration = cfg.AssumeRoleDuration - } + log.Printf("RDS IAM Assuming Role: %s with session name: %s for Host: %s User: %s Database: %s", assumeRoleARN, sessionName, host, user, dbName) + assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, assumeRoleARN, func(opts *stscreds.AssumeRoleOptions) { + opts.RoleSessionName = sessionName }) - - // Cache to avoid calling STS too frequently. creds = aws.NewCredentialsCache(assumeProvider) } - return &PostgresqlConnector{ - baseConnectionStringProvider: &iamAuthConnectionStringProvider{ - IAMAuthConfig: *cfg, - region: awsCfg.Region, - creds: creds, - }, + return &rdsIAMConnectionStringProvider{ + Region: awsCfg.Region, + RDSEndpoint: net.JoinHostPort(host, port), + User: user, + Database: dbName, + CredentialsProvider: creds, }, nil } - -// Provides missing sqlx.OpenDB -func OpenDB(conn *PostgresqlConnector) *sqlx.DB { - sqlDB := sql.OpenDB(conn) - return sqlx.NewDb(sqlDB, "postgres") -} - -// ConnectDB opens a connection using the connector and verifies it with a ping -func ConnectDB(conn *PostgresqlConnector) (*sqlx.DB, error) { - db := OpenDB(conn) - if err := db.Ping(); err != nil { - db.Close() - return nil, err - } - return db, nil -} - -// MustConnectDB is like ConnectDB but panics on error -func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { - db, err := ConnectDB(conn) - if err != nil { - panic(err) - } - return db -} - diff --git a/pgutils/listener.go b/pgutils/listener.go index 958462c..d1a7d06 100644 --- a/pgutils/listener.go +++ b/pgutils/listener.go @@ -69,7 +69,7 @@ func listenerEventToString(t pq.ListenerEventType) string { // The callback is invoked from the listener goroutine; it MUST NOT block // for long periods. If you need to do heavy work, offload it to another // goroutine. -func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string, callback func(*pq.Notification), onClose func()) error { +func Listen(ctx context.Context, provider ConnectionStringProvider, pgChannelName string, callback func(*pq.Notification), onClose func()) error { if callback == nil { return fmt.Errorf("listener callback cannot be nil") } @@ -77,9 +77,9 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string reconnectEventCh := make(chan struct{}, 1) // We just need a single reconnect event to trigger, so buffer size of 1 makeListener := func() (*pq.Listener, error) { - url, err := conn.GetConnectionString(ctx) + url, err := provider.ConnectionString(ctx) if err != nil { - return nil, fmt.Errorf("get url: %w", err) + return nil, fmt.Errorf("error getting connection string from provider: %w", err) } cb := func(t pq.ListenerEventType, e error) { @@ -174,4 +174,3 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string return nil } - From edc2f009b90a3cc334672e99e5fa3bfb1cf4c15e Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Mon, 9 Mar 2026 00:38:30 -0700 Subject: [PATCH 08/10] refactor --- cmd/rds-iam-psql/README.md | 71 +++++-------- cmd/rds-iam-psql/main.go | 211 +++++++++++++++++++++++-------------- 2 files changed, 156 insertions(+), 126 deletions(-) diff --git a/cmd/rds-iam-psql/README.md b/cmd/rds-iam-psql/README.md index a316106..c92cd36 100644 --- a/cmd/rds-iam-psql/README.md +++ b/cmd/rds-iam-psql/README.md @@ -1,10 +1,9 @@ # rds-iam-psql -A CLI that launches an interactive `psql` session from either: -- a positional connection URL, or -- individual `-host/-port/-user/-db` flags. - -It supports standard PostgreSQL URLs and `pgutils` custom IAM URLs (`postgres+rds-iam://...`). +A CLI that launches an interactive `psql` session from a required RDS IAM URL: +- positional `postgres+rds-iam://...` DSN +- optional `-search-path` flag +- optional `-debug-aws` flag ## Why? @@ -26,43 +25,30 @@ go build ## Prerequisites - **psql** installed and available in your PATH -- For IAM URLs (`postgres+rds-iam://...`), **AWS credentials** configured (env vars, `~/.aws/credentials`, IAM role, etc.) -- For IAM URLs (`postgres+rds-iam://...`), **AWS_REGION** set -- For IAM URLs (`postgres+rds-iam://...`), **RDS IAM authentication enabled** on your database instance -- For IAM URLs (`postgres+rds-iam://...`), a DB user configured for IAM auth (for example: `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`) +- **AWS credentials** configured (env vars, `~/.aws/credentials`, IAM role, etc.) +- **AWS region** configured for SDK resolution (for example: `AWS_REGION`, shared config profile, or runtime role config) +- **RDS IAM authentication enabled** on your database instance +- A DB user configured for IAM auth (for example: `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`) ## Usage ```bash -rds-iam-psql [connection-url] [options] +rds-iam-psql [-search-path "schema,public"] [-debug-aws] '' ``` -```bash -rds-iam-psql -host -user -db [options] -``` - -`connection-url` supports: -- `postgres+rds-iam://user@host:5432/dbname` -- `postgres://user:pass@host:5432/dbname?...` -- `postgresql://user:pass@host:5432/dbname?...` - -If `connection-url` is provided, do not combine it with `-host/-port/-user/-db`. +- Flags must come before the DSN (standard Go flag parsing behavior). +- `` may omit the database path. When omitted, `pgutils` defaults the database name to the username. ### Flags | Flag | Default | Description | |------|---------|-------------| -| `-host` | | Endpoint hostname (required if `connection-url` is not provided) | -| `-port` | `5432` | PostgreSQL port | -| `-user` | | DB username (required if `connection-url` is not provided) | -| `-db` | | DB name (required if `connection-url` is not provided) | -| `-psql` | `psql` | Path to the `psql` binary | -| `-sslmode` | `require` | SSL mode (`require`, `verify-full`, etc.) | | `-search-path` | | PostgreSQL `search_path` to set on connection (e.g. `myschema,public`) | +| `-debug-aws` | `false` | Print STS caller identity before connecting | ## Examples -Positional IAM URL (your requested form): +Basic IAM URL: ```bash ./rds-iam-psql 'postgres+rds-iam://server@acremins-test.cicxifnkufnd.us-east-1.rds.amazonaws.com:5432/postgres' @@ -74,37 +60,34 @@ IAM URL with cross-account role assumption: rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp?assume_role_arn=arn:aws:iam::123456789012:role/db-connect&assume_role_session_name=rds-iam-psql' ``` -Flag-based IAM connection: +With search path: ```bash -rds-iam-psql -host mydb.abc123.us-east-1.rds.amazonaws.com -user app_user -db myapp +rds-iam-psql \ + -search-path "app_schema,public" \ + 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp' ``` -Standard PostgreSQL URL (non-IAM): +With AWS identity debugging: ```bash -rds-iam-psql 'postgresql://postgres:postgres@127.0.0.1:5432/postgres?sslmode=disable' +rds-iam-psql -debug-aws 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp' ``` -With search path: +Without explicit database name (defaults to username): ```bash -rds-iam-psql \ - -host mydb.abc123.us-east-1.rds.amazonaws.com \ - -user app_user \ - -db myapp \ - -search-path "app_schema,public" +rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432' ``` ## How It Works -1. Parses input from either positional URL or `-host/-port/-user/-db`. -2. Builds a `pgutils.ConnectionStringProvider` from the URL. -3. For IAM URLs, validates AWS auth context (including `AWS_REGION`). -4. Resolves a DSN from the provider and launches `psql` with: -- `PGPASSWORD` set from the URL password/token -- `PGSSLMODE` set from `-sslmode` -- `PGOPTIONS` set when `-search-path` is provided +1. Parses and validates the positional IAM URL. +2. Builds a `pgutils` connection string provider from the IAM URL. +3. If `-search-path` is set, adds libpq `options=-csearch_path=...` to the connection URI before launching `psql`. +4. If `-debug-aws` is set, runs STS `GetCallerIdentity` and prints the caller ARN. +5. Resolves an IAM tokenized DSN from the provider and launches `psql` with: +- `PGPASSWORD` set from the generated token ## Setting Up IAM Auth on RDS diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index e554ea8..6f7b89e 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -1,16 +1,17 @@ package main import ( + "bytes" "context" + "errors" "flag" "fmt" + "io" "log" - "net" "net/url" "os" "os/exec" "os/signal" - "strconv" "strings" "syscall" @@ -20,45 +21,42 @@ import ( "github.com/corbaltcode/go-libraries/pgutils" ) -func main() { - var ( - host = flag.String("host", "", "RDS PostgreSQL endpoint hostname (no port, e.g. mydb.abc123.us-east-1.rds.amazonaws.com)") - port = flag.Int("port", 5432, "RDS PostgreSQL port (default 5432)") - user = flag.String("user", "", "Database user name") - dbName = flag.String("db", "", "Database name") - psqlPath = flag.String("psql", "psql", "Path to psql binary") - sslMode = flag.String("sslmode", "require", "PGSSLMODE for psql (e.g. require, verify-full)") - searchPath = flag.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')") - ) - flag.Parse() +const usageTemplate = `Usage: + %[2]s [-search-path "login,public"] [-debug-aws] 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb' - args := flag.Args() - if len(args) > 1 { - log.Fatalf("expected at most one positional connection URL argument, got %d", len(args)) - } +Notes: + Flags must come before the DSN (standard Go flag parsing). + Database path is optional. If omitted, the database name defaults to the username. - connectionURLArg := "" - if len(args) == 1 { - connectionURLArg = args[0] - } +Flags: +%[1]s - rawURL, usesIAM, err := buildRawURL(connectionURLArg, *host, *port, *user, *dbName) +Examples: + %[2]s 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb' + %[2]s -search-path "login,public" 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432' + %[2]s -debug-aws 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb' +` + +func main() { + rawURL, searchPath, debugAWS, err := parseCLIArgs(os.Args[1:], os.Args[0]) if err != nil { - log.Fatalf("%v\n\nUsage examples:\n %s -host mydb.abc123.us-east-1.rds.amazonaws.com -port 5432 -user myuser -db mydb -search-path \"login,public\"\n %s 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'\n", err, os.Args[0], os.Args[0]) + if errors.Is(err, flag.ErrHelp) { + printUsage(os.Stdout, os.Args[0]) + return + } + fmt.Fprintf(os.Stderr, "%v\n\n", err) + printUsage(os.Stderr, os.Args[0]) + os.Exit(2) } - ctx := context.Background() - - connectionStringProvider, err := pgutils.NewConnectionStringProviderFromURLString(ctx, rawURL) - if err != nil { - log.Fatalf("failed to create connection string provider: %v", err) + if err := validateRDSIAMURL(rawURL); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(2) } - if usesIAM { - if os.Getenv("AWS_REGION") == "" { - log.Fatalf("AWS_REGION must be set for IAM auth") - } + ctx := context.Background() + if debugAWS { cfg, err := awsconfig.LoadDefaultConfig(ctx) if err != nil { log.Fatalf("failed to load AWS config: %v", err) @@ -68,6 +66,11 @@ func main() { } } + connectionStringProvider, err := pgutils.NewConnectionStringProviderFromURLString(ctx, rawURL) + if err != nil { + log.Fatalf("failed to create connection string provider: %v", err) + } + dsnWithToken, err := connectionStringProvider.ConnectionString(ctx) if err != nil { log.Fatalf("failed to get connection string from provider: %v", err) @@ -78,6 +81,11 @@ func main() { log.Fatalf("failed to parse connection string from provider: %v", err) } + if err := addSearchPathToPSQLURL(parsedURL, searchPath); err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(2) + } + password := "" if parsedURL.User != nil { var ok bool @@ -88,7 +96,7 @@ func main() { } // Pass DSN to psql without password in argv, and provide password via env. - cmd := exec.Command(*psqlPath, parsedURL.String()) + cmd := exec.Command("psql", parsedURL.String()) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -98,28 +106,6 @@ func main() { if password != "" { env = append(env, "PGPASSWORD="+password) } - env = append(env, "PGSSLMODE="+*sslMode) - - if sp := strings.TrimSpace(*searchPath); sp != "" { - add := "-c search_path=" + sp - - found := false - for i, e := range env { - if strings.HasPrefix(e, "PGOPTIONS=") { - current := strings.TrimPrefix(e, "PGOPTIONS=") - if strings.TrimSpace(current) == "" { - env[i] = "PGOPTIONS=" + add - } else { - env[i] = "PGOPTIONS=" + current + " " + add - } - found = true - break - } - } - if !found { - env = append(env, "PGOPTIONS="+add) - } - } cmd.Env = env @@ -159,39 +145,100 @@ func main() { } } -func buildRawURL(connectionURLArg, host string, port int, user, dbName string) (string, bool, error) { - if connectionURLArg != "" { - if host != "" || user != "" || dbName != "" || port != 5432 { - return "", false, fmt.Errorf("positional connection URL cannot be combined with -host, -port, -user, or -db") - } - parsedURL, err := url.Parse(connectionURLArg) - if err != nil { - return "", false, fmt.Errorf("failed to parse positional connection URL: %w", err) - } - switch parsedURL.Scheme { - case "postgres+rds-iam": - return connectionURLArg, true, nil - case "postgres", "postgresql": - return connectionURLArg, false, nil - default: - return "", false, fmt.Errorf("unsupported connection URL scheme %q (expected postgres, postgresql, or postgres+rds-iam)", parsedURL.Scheme) - } +func newFlagSet(bin string, output io.Writer) (fs *flag.FlagSet, searchPathFlag *string, debugAWSFlag *bool) { + fs = flag.NewFlagSet(bin, flag.ContinueOnError) + fs.SetOutput(output) + + return fs, + fs.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')"), + fs.Bool("debug-aws", false, "Print AWS caller identity before connecting") +} + +func printUsage(output io.Writer, bin string) { + fs, _, _ := newFlagSet(bin, io.Discard) + + var defaults bytes.Buffer + fs.SetOutput(&defaults) + fs.PrintDefaults() + + fmt.Fprintf(output, usageTemplate, strings.TrimRight(defaults.String(), "\n"), bin) +} + +func parseCLIArgs(args []string, bin string) (rawURL string, searchPath string, debugAWS bool, err error) { + fs, searchPathFlag, debugAWSFlag := newFlagSet(bin, io.Discard) + + if err := fs.Parse(args); err != nil { + return "", "", false, err + } + + positionals := fs.Args() + if len(positionals) != 1 { + return "", "", false, fmt.Errorf("expected exactly one positional RDS IAM connection URL argument, got %d", len(positionals)) + } + + return positionals[0], *searchPathFlag, *debugAWSFlag, nil +} + +func addSearchPathToPSQLURL(u *url.URL, searchPath string) error { + normalized, err := normalizeSearchPath(searchPath) + if err != nil { + return err + } + if normalized == "" { + return nil + } + + query := u.Query() + add := "-csearch_path=" + normalized + + existing := strings.TrimSpace(query.Get("options")) + if existing == "" { + query.Set("options", add) + } else { + query.Set("options", existing+" "+add) + } + + u.RawQuery = query.Encode() + return nil +} + +func normalizeSearchPath(searchPath string) (string, error) { + if strings.TrimSpace(searchPath) == "" { + return "", nil } - if host == "" || user == "" || dbName == "" { - return "", false, fmt.Errorf("host, user, and db are required when no positional connection URL is provided") + parts := strings.Split(searchPath, ",") + cleaned := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + cleaned = append(cleaned, p) + } } - if port <= 0 { - return "", false, fmt.Errorf("invalid port: %d", port) + + if len(cleaned) == 0 { + return "", fmt.Errorf("search path cannot be empty") } - iamURL := &url.URL{ - Scheme: "postgres+rds-iam", - User: url.User(user), - Host: net.JoinHostPort(host, strconv.Itoa(port)), - Path: "/" + dbName, + return strings.Join(cleaned, ","), nil +} + +func validateRDSIAMURL(rawURL string) error { + parsedURL, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("failed to parse positional connection URL: %w", err) + } + if parsedURL.Scheme != "postgres+rds-iam" { + return fmt.Errorf("unsupported connection URL scheme %q (expected postgres+rds-iam)", parsedURL.Scheme) } - return iamURL.String(), true, nil + if parsedURL.User == nil || strings.TrimSpace(parsedURL.User.Username()) == "" { + return fmt.Errorf("connection URL must include a database username") + } + if strings.TrimSpace(parsedURL.Host) == "" { + return fmt.Errorf("connection URL must include a database host") + } + + return nil } func printCallerIdentity(ctx context.Context, cfg aws.Config) error { From da15554a9d7adc4f4f121be68ff3bed0af954681 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Mon, 9 Mar 2026 01:33:05 -0700 Subject: [PATCH 09/10] remove searchpath --- cmd/rds-iam-psql/README.md | 28 +++++----- cmd/rds-iam-psql/main.go | 102 ++++++------------------------------- 2 files changed, 30 insertions(+), 100 deletions(-) diff --git a/cmd/rds-iam-psql/README.md b/cmd/rds-iam-psql/README.md index c92cd36..968f8c9 100644 --- a/cmd/rds-iam-psql/README.md +++ b/cmd/rds-iam-psql/README.md @@ -2,7 +2,6 @@ A CLI that launches an interactive `psql` session from a required RDS IAM URL: - positional `postgres+rds-iam://...` DSN -- optional `-search-path` flag - optional `-debug-aws` flag ## Why? @@ -33,7 +32,7 @@ go build ## Usage ```bash -rds-iam-psql [-search-path "schema,public"] [-debug-aws] '' +rds-iam-psql [-debug-aws] '' ``` - Flags must come before the DSN (standard Go flag parsing behavior). @@ -43,7 +42,6 @@ rds-iam-psql [-search-path "schema,public"] [-debug-aws] ' | Flag | Default | Description | |------|---------|-------------| -| `-search-path` | | PostgreSQL `search_path` to set on connection (e.g. `myschema,public`) | | `-debug-aws` | `false` | Print STS caller identity before connecting | ## Examples @@ -60,14 +58,6 @@ IAM URL with cross-account role assumption: rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp?assume_role_arn=arn:aws:iam::123456789012:role/db-connect&assume_role_session_name=rds-iam-psql' ``` -With search path: - -```bash -rds-iam-psql \ - -search-path "app_schema,public" \ - 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp' -``` - With AWS identity debugging: ```bash @@ -80,13 +70,23 @@ Without explicit database name (defaults to username): rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432' ``` +## Changing Search Path In psql + +If you need to change the schema search path, do it from the interactive `psql` session after connecting: + +```sql +SHOW search_path; +SET search_path TO app_schema, public; +``` + +This applies to the current session. If you need a persistent default, configure it in Postgres (for example with `ALTER ROLE ... SET search_path ...`). + ## How It Works 1. Parses and validates the positional IAM URL. 2. Builds a `pgutils` connection string provider from the IAM URL. -3. If `-search-path` is set, adds libpq `options=-csearch_path=...` to the connection URI before launching `psql`. -4. If `-debug-aws` is set, runs STS `GetCallerIdentity` and prints the caller ARN. -5. Resolves an IAM tokenized DSN from the provider and launches `psql` with: +3. If `-debug-aws` is set, runs STS `GetCallerIdentity` and prints the caller ARN. +4. Resolves an IAM tokenized DSN from the provider and launches `psql` with: - `PGPASSWORD` set from the generated token ## Setting Up IAM Auth on RDS diff --git a/cmd/rds-iam-psql/main.go b/cmd/rds-iam-psql/main.go index 6f7b89e..03d1c4a 100644 --- a/cmd/rds-iam-psql/main.go +++ b/cmd/rds-iam-psql/main.go @@ -22,7 +22,7 @@ import ( ) const usageTemplate = `Usage: - %[2]s [-search-path "login,public"] [-debug-aws] 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb' + %[2]s [-debug-aws] 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb' Notes: Flags must come before the DSN (standard Go flag parsing). @@ -33,12 +33,11 @@ Flags: Examples: %[2]s 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb' - %[2]s -search-path "login,public" 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432' %[2]s -debug-aws 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb' ` func main() { - rawURL, searchPath, debugAWS, err := parseCLIArgs(os.Args[1:], os.Args[0]) + rawURL, debugAWS, err := parseCLIArgs(os.Args[1:], os.Args[0]) if err != nil { if errors.Is(err, flag.ErrHelp) { printUsage(os.Stdout, os.Args[0]) @@ -76,41 +75,14 @@ func main() { log.Fatalf("failed to get connection string from provider: %v", err) } - parsedURL, err := url.Parse(dsnWithToken) - if err != nil { - log.Fatalf("failed to parse connection string from provider: %v", err) - } - - if err := addSearchPathToPSQLURL(parsedURL, searchPath); err != nil { - fmt.Fprintf(os.Stderr, "%v\n", err) - os.Exit(2) - } - - password := "" - if parsedURL.User != nil { - var ok bool - password, ok = parsedURL.User.Password() - if ok { - parsedURL.User = url.User(parsedURL.User.Username()) - } - } - - // Pass DSN to psql without password in argv, and provide password via env. - cmd := exec.Command("psql", parsedURL.String()) + cmd := exec.Command("psql", dsnWithToken) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - env := os.Environ() - if password != "" { - env = append(env, "PGPASSWORD="+password) - } - - cmd.Env = env - - // Keep psql in the foreground process group. Swallow SIGINT in wrapper so - // psql handles Ctrl-C directly. + // Ignore SIGINT in the wrapper so interactive Ctrl-C can be handled by psql. + // Forward SIGTERM to the child process. sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) defer signal.Stop(sigCh) @@ -145,17 +117,16 @@ func main() { } } -func newFlagSet(bin string, output io.Writer) (fs *flag.FlagSet, searchPathFlag *string, debugAWSFlag *bool) { +func newFlagSet(bin string, output io.Writer) (fs *flag.FlagSet, debugAWSFlag *bool) { fs = flag.NewFlagSet(bin, flag.ContinueOnError) fs.SetOutput(output) return fs, - fs.String("search-path", "", "Optional PostgreSQL search_path to set (e.g. 'myschema,public')"), fs.Bool("debug-aws", false, "Print AWS caller identity before connecting") } func printUsage(output io.Writer, bin string) { - fs, _, _ := newFlagSet(bin, io.Discard) + fs, _ := newFlagSet(bin, io.Discard) var defaults bytes.Buffer fs.SetOutput(&defaults) @@ -164,63 +135,19 @@ func printUsage(output io.Writer, bin string) { fmt.Fprintf(output, usageTemplate, strings.TrimRight(defaults.String(), "\n"), bin) } -func parseCLIArgs(args []string, bin string) (rawURL string, searchPath string, debugAWS bool, err error) { - fs, searchPathFlag, debugAWSFlag := newFlagSet(bin, io.Discard) +func parseCLIArgs(args []string, bin string) (rawURL string, debugAWS bool, err error) { + fs, debugAWSFlag := newFlagSet(bin, io.Discard) if err := fs.Parse(args); err != nil { - return "", "", false, err + return "", false, err } positionals := fs.Args() if len(positionals) != 1 { - return "", "", false, fmt.Errorf("expected exactly one positional RDS IAM connection URL argument, got %d", len(positionals)) + return "", false, fmt.Errorf("expected exactly one positional RDS IAM connection URL argument, got %d", len(positionals)) } - return positionals[0], *searchPathFlag, *debugAWSFlag, nil -} - -func addSearchPathToPSQLURL(u *url.URL, searchPath string) error { - normalized, err := normalizeSearchPath(searchPath) - if err != nil { - return err - } - if normalized == "" { - return nil - } - - query := u.Query() - add := "-csearch_path=" + normalized - - existing := strings.TrimSpace(query.Get("options")) - if existing == "" { - query.Set("options", add) - } else { - query.Set("options", existing+" "+add) - } - - u.RawQuery = query.Encode() - return nil -} - -func normalizeSearchPath(searchPath string) (string, error) { - if strings.TrimSpace(searchPath) == "" { - return "", nil - } - - parts := strings.Split(searchPath, ",") - cleaned := make([]string, 0, len(parts)) - for _, p := range parts { - p = strings.TrimSpace(p) - if p != "" { - cleaned = append(cleaned, p) - } - } - - if len(cleaned) == 0 { - return "", fmt.Errorf("search path cannot be empty") - } - - return strings.Join(cleaned, ","), nil + return positionals[0], *debugAWSFlag, nil } func validateRDSIAMURL(rawURL string) error { @@ -234,6 +161,9 @@ func validateRDSIAMURL(rawURL string) error { if parsedURL.User == nil || strings.TrimSpace(parsedURL.User.Username()) == "" { return fmt.Errorf("connection URL must include a database username") } + if _, ok := parsedURL.User.Password(); ok { + return fmt.Errorf("connection URL must not include a password for postgres+rds-iam") + } if strings.TrimSpace(parsedURL.Host) == "" { return fmt.Errorf("connection URL must include a database host") } @@ -249,6 +179,6 @@ func printCallerIdentity(ctx context.Context, cfg aws.Config) error { return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err) } - fmt.Printf("Caller ARN: %s\n", aws.ToString(out.Arn)) + fmt.Fprintf(os.Stderr, "Caller ARN: %s\n", aws.ToString(out.Arn)) return nil } From 36ee21bf28dd8e2c2a994c608ef02dfce9617539 Mon Sep 17 00:00:00 2001 From: Andrew Cremins Date: Mon, 9 Mar 2026 01:38:21 -0700 Subject: [PATCH 10/10] Refactor readme --- cmd/rds-iam-psql/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cmd/rds-iam-psql/README.md b/cmd/rds-iam-psql/README.md index 968f8c9..5904fcf 100644 --- a/cmd/rds-iam-psql/README.md +++ b/cmd/rds-iam-psql/README.md @@ -86,8 +86,7 @@ This applies to the current session. If you need a persistent default, configure 1. Parses and validates the positional IAM URL. 2. Builds a `pgutils` connection string provider from the IAM URL. 3. If `-debug-aws` is set, runs STS `GetCallerIdentity` and prints the caller ARN. -4. Resolves an IAM tokenized DSN from the provider and launches `psql` with: -- `PGPASSWORD` set from the generated token +4. Resolves an IAM tokenized DSN from the provider and launches `psql` with the IAM generated connection string. ## Setting Up IAM Auth on RDS