Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions cmd/rds-iam-psql/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# rds-iam-psql

A CLI that launches an interactive `psql` session from a required RDS IAM URL:
- positional `postgres+rds-iam://...` DSN
- optional `-debug-aws` flag

## Why?

RDS IAM authentication lets you connect using AWS credentials instead of a static DB password. IAM auth tokens are short-lived and inconvenient to generate manually. This tool resolves a fresh DSN through `pgutils` and opens `psql` for you.

## Installation

```bash
go install github.com/corbaltcode/go-libraries/cmd/rds-iam-psql@latest
```

Or build from source:

```bash
cd ./cmd/rds-iam-psql
go build
```

## Prerequisites

- **psql** installed and available in your PATH
- **AWS credentials** configured (env vars, `~/.aws/credentials`, IAM role, etc.)
- **AWS region** configured for SDK resolution (for example: `AWS_REGION`, shared config profile, or runtime role config)
- **RDS IAM authentication enabled** on your database instance
- A DB user configured for IAM auth (for example: `CREATE USER myuser WITH LOGIN; GRANT rds_iam TO myuser;`)

## Usage

```bash
rds-iam-psql [-debug-aws] '<postgres+rds-iam-url>'
```

- Flags must come before the DSN (standard Go flag parsing behavior).
- `<postgres+rds-iam-url>` may omit the database path. When omitted, `pgutils` defaults the database name to the username.

### Flags

| Flag | Default | Description |
|------|---------|-------------|
| `-debug-aws` | `false` | Print STS caller identity before connecting |

## Examples

Basic IAM URL:

```bash
./rds-iam-psql 'postgres+rds-iam://server@acremins-test.cicxifnkufnd.us-east-1.rds.amazonaws.com:5432/postgres'
```

IAM URL with cross-account role assumption:

```bash
rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp?assume_role_arn=arn:aws:iam::123456789012:role/db-connect&assume_role_session_name=rds-iam-psql'
```

With AWS identity debugging:

```bash
rds-iam-psql -debug-aws 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432/myapp'
```

Without explicit database name (defaults to username):

```bash
rds-iam-psql 'postgres+rds-iam://app_user@mydb.abc123.us-east-1.rds.amazonaws.com:5432'
```

## Changing Search Path In psql

If you need to change the schema search path, do it from the interactive `psql` session after connecting:

```sql
SHOW search_path;
SET search_path TO app_schema, public;
```

This applies to the current session. If you need a persistent default, configure it in Postgres (for example with `ALTER ROLE ... SET search_path ...`).

## How It Works

1. Parses and validates the positional IAM URL.
2. Builds a `pgutils` connection string provider from the IAM URL.
3. If `-debug-aws` is set, runs STS `GetCallerIdentity` and prints the caller ARN.
4. Resolves an IAM tokenized DSN from the provider and launches `psql` with the IAM generated connection string.

## Setting Up IAM Auth on RDS

1. Enable IAM authentication on your RDS instance
2. Create a database user and grant IAM privileges:
```sql
CREATE USER myuser WITH LOGIN;
GRANT rds_iam TO myuser;
```
3. Attach an IAM policy allowing `rds-db:connect` to your AWS user/role:
```json
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "rds-db:connect",
"Resource": "arn:aws:rds-db:<region>:<account-id>:dbuser:<dbi-resource-id>/<db-user>"
}
]
}
```
184 changes: 184 additions & 0 deletions cmd/rds-iam-psql/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package main

import (
"bytes"
"context"
"errors"
"flag"
"fmt"
"io"
"log"
"net/url"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"

"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/corbaltcode/go-libraries/pgutils"
)

const usageTemplate = `Usage:
%[2]s [-debug-aws] 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'

Notes:
Flags must come before the DSN (standard Go flag parsing).
Database path is optional. If omitted, the database name defaults to the username.

Flags:
%[1]s

Examples:
%[2]s 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'
%[2]s -debug-aws 'postgres+rds-iam://myuser@mydb.abc123.us-east-1.rds.amazonaws.com:5432/mydb'
`

func main() {
rawURL, debugAWS, err := parseCLIArgs(os.Args[1:], os.Args[0])
if err != nil {
if errors.Is(err, flag.ErrHelp) {
printUsage(os.Stdout, os.Args[0])
return
}
fmt.Fprintf(os.Stderr, "%v\n\n", err)
printUsage(os.Stderr, os.Args[0])
os.Exit(2)
}

if err := validateRDSIAMURL(rawURL); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(2)
}

ctx := context.Background()

if debugAWS {
cfg, err := awsconfig.LoadDefaultConfig(ctx)
if err != nil {
log.Fatalf("failed to load AWS config: %v", err)
}
if err := printCallerIdentity(ctx, cfg); err != nil {
log.Fatalf("AWS credentials check failed: %v", err)
}
}

connectionStringProvider, err := pgutils.NewConnectionStringProviderFromURLString(ctx, rawURL)
if err != nil {
log.Fatalf("failed to create connection string provider: %v", err)
}

dsnWithToken, err := connectionStringProvider.ConnectionString(ctx)
if err != nil {
log.Fatalf("failed to get connection string from provider: %v", err)
}

cmd := exec.Command("psql", dsnWithToken)

cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

// Ignore SIGINT in the wrapper so interactive Ctrl-C can be handled by psql.
// Forward SIGTERM to the child process.
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
defer signal.Stop(sigCh)

if err := cmd.Start(); err != nil {
log.Fatalf("failed to start psql: %v", err)
}

waitCh := make(chan error, 1)
go func() { waitCh <- cmd.Wait() }()

for {
select {
case sig := <-sigCh:
switch sig {
case os.Interrupt:
continue
case syscall.SIGTERM:
if cmd.Process != nil {
_ = cmd.Process.Signal(syscall.SIGTERM)
}
}
case err := <-waitCh:
if err == nil {
return
}
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
log.Fatalf("psql failed: %v", err)
}
}
}

func newFlagSet(bin string, output io.Writer) (fs *flag.FlagSet, debugAWSFlag *bool) {
fs = flag.NewFlagSet(bin, flag.ContinueOnError)
fs.SetOutput(output)

return fs,
fs.Bool("debug-aws", false, "Print AWS caller identity before connecting")
}

func printUsage(output io.Writer, bin string) {
fs, _ := newFlagSet(bin, io.Discard)

var defaults bytes.Buffer
fs.SetOutput(&defaults)
fs.PrintDefaults()

fmt.Fprintf(output, usageTemplate, strings.TrimRight(defaults.String(), "\n"), bin)
}

func parseCLIArgs(args []string, bin string) (rawURL string, debugAWS bool, err error) {
fs, debugAWSFlag := newFlagSet(bin, io.Discard)

if err := fs.Parse(args); err != nil {
return "", false, err
}

positionals := fs.Args()
if len(positionals) != 1 {
return "", false, fmt.Errorf("expected exactly one positional RDS IAM connection URL argument, got %d", len(positionals))
}

return positionals[0], *debugAWSFlag, nil
}

func validateRDSIAMURL(rawURL string) error {
parsedURL, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("failed to parse positional connection URL: %w", err)
}
if parsedURL.Scheme != "postgres+rds-iam" {
return fmt.Errorf("unsupported connection URL scheme %q (expected postgres+rds-iam)", parsedURL.Scheme)
}
if parsedURL.User == nil || strings.TrimSpace(parsedURL.User.Username()) == "" {
return fmt.Errorf("connection URL must include a database username")
}
if _, ok := parsedURL.User.Password(); ok {
return fmt.Errorf("connection URL must not include a password for postgres+rds-iam")
}
if strings.TrimSpace(parsedURL.Host) == "" {
return fmt.Errorf("connection URL must include a database host")
}

return nil
}

func printCallerIdentity(ctx context.Context, cfg aws.Config) error {
stsClient := sts.NewFromConfig(cfg)

out, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return fmt.Errorf("STS GetCallerIdentity failed (creds invalid/expired or STS not allowed): %w", err)
}

fmt.Fprintf(os.Stderr, "Caller ARN: %s\n", aws.ToString(out.Arn))
return nil
}