diff --git a/.github/workflows/pullpreview.yml b/.github/workflows/pullpreview.yml index 49b59ea..4ed20ab 100644 --- a/.github/workflows/pullpreview.yml +++ b/.github/workflows/pullpreview.yml @@ -161,4 +161,64 @@ jobs: ttl: 1h env: HCLOUD_TOKEN: "${{ secrets.HCLOUD_TOKEN }}" - HETZNER_CA_KEY: "${{ secrets.HETZNER_CA_KEY }}" + PULLPREVIEW_CA_KEY: "${{ secrets.PULLPREVIEW_CA_KEY || secrets.HETZNER_CA_KEY }}" + + deploy_smoke_ec2: + runs-on: ubuntu-slim + if: github.event_name == 'schedule' || github.event.label.name == 'pullpreview' || contains(github.event.pull_request.labels.*.name, 'pullpreview') + timeout-minutes: 35 + steps: + - uses: actions/checkout@v6 + + - name: Deploy smoke app on EC2 + id: pullpreview + uses: "./" + with: + admins: "@collaborators/push" + app_path: ./examples/workflow-smoke + provider: ec2 + region: us-east-1 + image: al2023-ami-2023 + instance_type: t3.small + dns: rev3.click + max_domain_length: 30 + # required here because the mysql image is private in GHCR + registries: docker://${{ secrets.GHCR_PAT }}@ghcr.io + proxy_tls: web:8080 + ttl: 1h + env: + AWS_ACCESS_KEY_ID: "${{ secrets.AWS_ACCESS_KEY_ID }}" + AWS_SECRET_ACCESS_KEY: "${{ secrets.AWS_SECRET_ACCESS_KEY }}" + AWS_REGION: "us-east-1" + PULLPREVIEW_CA_KEY: "${{ secrets.PULLPREVIEW_CA_KEY || secrets.HETZNER_CA_KEY }}" + + - name: Assert deploy and seed state on EC2 + if: steps.pullpreview.outputs.live == 'true' + shell: bash + env: + PREVIEW_URL: ${{ steps.pullpreview.outputs.url }} + run: | + set -euo pipefail + + if [[ "${PREVIEW_URL}" != https://* ]]; then + echo "::error::Expected https preview URL when proxy_tls is enabled, got ${PREVIEW_URL}" + exit 1 + fi + + response="" + for attempt in $(seq 1 60); do + response="$(curl -fsSL --max-time 15 "${PREVIEW_URL}" || true)" + if printf '%s' "${response}" | grep -q 'Hello World Deploy 1' && \ + printf '%s' "${response}" | grep -q 'seed_count=1' && \ + printf '%s' "${response}" | grep -q 'seed_label=persisted'; then + echo "EC2 smoke checks passed for ${PREVIEW_URL}" + exit 0 + fi + + echo "Attempt ${attempt}/60: waiting for EC2 smoke response from ${PREVIEW_URL}" + sleep 5 + done + + echo "::error::Unexpected response from ${PREVIEW_URL}" + printf '%s\n' "${response}" + exit 1 diff --git a/Makefile b/Makefile index 75246d0..608ff5c 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ GO ?= mise exec -- go -GO_TEST ?= $(GO) test ./internal/providers ./internal/pullpreview ./internal/providers/hetzner +GO_TEST ?= $(GO) test ./internal/providers ./internal/pullpreview ./internal/providers/hetzner ./internal/providers/ec2 DIST_DIR := dist BIN_NAME := pullpreview GO_LDFLAGS ?= -s -w diff --git a/README.md b/README.md index a478473..4f4bad7 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ is made to Pull Requests labelled with the `pullpreview` label. When triggered, it will: 1. Check out the repository code -2. Provision a preview instance (Lightsail by default, or Hetzner with `provider: hetzner`), with docker and docker-compose set up +2. Provision a preview instance (Lightsail by default, or `provider: hetzner` / `provider: ec2`), with docker and docker-compose set up 3. Continuously deploy the specified pull requests using your docker-compose file(s) 4. Report the preview instance URL in the GitHub UI @@ -118,11 +118,11 @@ All supported `with:` inputs from `action.yml`: | `compose_files` | `docker-compose.yml` | Comma-separated Compose files passed to deploy. | | `compose_options` | `--build` | Additional options appended to `docker compose up`. | | `license` | `""` | PullPreview license key. | -| `instance_type` | `small` | Provider-specific instance size (`small` for Lightsail, `cpx21` for Hetzner). | +| `instance_type` | `""` | Provider-specific instance size (defaults: Lightsail `small`, Hetzner `cpx21`, EC2 `t3.small`). | | `region` | `` | Optional provider region/datacenter override (`AWS_REGION`/Hetzner location). If empty, provider defaults apply. | -| `image` | `ubuntu-24.04` | Instance image for Hetzner (provider-specific) and ignored for AWS. | +| `image` | `""` | Provider image selector: Hetzner image name, or EC2 AMI ID / AMI name prefix. | | `deployment_variant` | `""` | Optional short suffix to run multiple preview environments per PR (max 4 chars). | -| `provider` | `lightsail` | Cloud provider (`lightsail`, `hetzner`). | +| `provider` | `lightsail` | Cloud provider (`lightsail`, `hetzner`, `ec2`). | | `registries` | `""` | Private registry credentials, e.g. `docker://user:password@ghcr.io`. | | `proxy_tls` | `""` | Automatic HTTPS forwarding with Caddy + Let's Encrypt (`service:port`, e.g. `web:80`). | | `pre_script` | `""` | Path to a local shell script (relative to `app_path`) executed inline over SSH before compose deploy (should be self-contained). | @@ -133,8 +133,11 @@ Notes: - `proxy_tls` forces URL/output/comment links to HTTPS on port `443`, injects a Caddy proxy service, and suppresses firewall exposure for port `80`. **When using `proxy_tls`, it is strongly recommended to set `dns` to a [custom domain](https://github.com/pullpreview/action/wiki/Using-a-custom-domain) or one of the built-in `revN.click` alternatives** to avoid hitting shared Let's Encrypt rate limits on `my.preview.run`. - `admins: "@collaborators/push"` uses GitHub API collaborators with push permission (first page, up to 100 users; warning is logged if more exist). - SSH key fetches are cached between runs in the action cache. -- For Hetzner, configure credentials and defaults via action inputs and environment: `HCLOUD_TOKEN` (required), `HETZNER_CA_KEY` (required), optional `region` and `image` (`region` defaults to `nbg1`, `image` defaults to `ubuntu-24.04`). `instance_type` defaults to `cpx21` when provider is Hetzner. -- `HETZNER_CA_KEY` must be an SSH private key (RSA or Ed25519) for the instance-access CA. PullPreview signs a per-run ephemeral login key with this CA key and uses SSH certificates (`...-cert.pub`) instead of reusing a persistent private key across runs. +- CA key env is `PULLPREVIEW_CA_KEY` (canonical). For Hetzner, legacy `HETZNER_CA_KEY` is still accepted if canonical is unset. +- For Hetzner, configure credentials and defaults via action inputs and environment: `HCLOUD_TOKEN` (required), `PULLPREVIEW_CA_KEY` (required; legacy `HETZNER_CA_KEY` fallback), optional `region` and `image` (`region` defaults to `nbg1`, `image` defaults to `ubuntu-24.04`). `instance_type` defaults to `cpx21` when provider is Hetzner. +- For EC2, configure AWS credentials plus `PULLPREVIEW_CA_KEY` and set `provider: ec2`. PullPreview requires a pre-existing public subnet tagged `pullpreview-enabled=true` in the selected region. +- For EC2 `image`: when `image` starts with `ami-`, it is used directly. Otherwise `image` is treated as an AMI name prefix and PullPreview selects the newest available match from owners `self` + `amazon`. If `image` is empty, PullPreview uses the default Amazon Linux 2023 prefix. +- `PULLPREVIEW_CA_KEY` must be an SSH private key (RSA or Ed25519) for the instance-access CA. PullPreview signs a per-run ephemeral login key with this CA key and uses SSH certificates (`...-cert.pub`) instead of reusing a persistent private key across runs. - Generate a CA key once for your repository secret: ```bash @@ -142,7 +145,7 @@ ssh-keygen -t rsa -b 3072 -m PEM -N "" -f hetzner_ca_key ``` - **Let's Encrypt rate limits**: Let's Encrypt allows a maximum of [50 certificates per registered domain per week](https://letsencrypt.org/docs/rate-limits/#new-certificates-per-registered-domain). If you use `proxy_tls` and hit this limit on the default `my.preview.run` domain, switch to one of the built-in alternatives: `rev1.click`, `rev2.click`, ... `rev9.click`. Set `dns: rev1.click` in your workflow inputs. You can also use a [custom domain](https://github.com/pullpreview/action/wiki/Using-a-custom-domain). -- For local CLI runs, set `HCLOUD_TOKEN` and `HETZNER_CA_KEY` (for example via `.env`) when using `provider: hetzner` to avoid relying on action inputs. +- For local CLI runs, set provider-specific credentials plus `PULLPREVIEW_CA_KEY` (for example via `.env`). ## Example @@ -221,10 +224,39 @@ jobs: ttl: 1h env: HCLOUD_TOKEN: "${{ secrets.HCLOUD_TOKEN }}" - HETZNER_CA_KEY: "${{ secrets.HETZNER_CA_KEY }}" + PULLPREVIEW_CA_KEY: "${{ secrets.PULLPREVIEW_CA_KEY }}" ``` +## EC2 example + +```yaml +# .github/workflows/pullpreview-ec2.yml +name: PullPreview +on: + pull_request: + types: [labeled, unlabeled, synchronize, closed, reopened, opened] + +jobs: + deploy_ec2: + runs-on: ubuntu-slim + if: github.event.label.name == 'pullpreview' || contains(github.event.pull_request.labels.*.name, 'pullpreview') + steps: + - uses: actions/checkout@v5 + - uses: pullpreview/action@v6 + with: + provider: ec2 + # optional: AMI ID or AMI name prefix + image: al2023-ami-2023 + # optional: raw EC2 instance type + instance_type: t3.small + env: + AWS_ACCESS_KEY_ID: "${{ secrets.AWS_ACCESS_KEY_ID }}" + AWS_SECRET_ACCESS_KEY: "${{ secrets.AWS_SECRET_ACCESS_KEY }}" + AWS_REGION: "us-east-1" + PULLPREVIEW_CA_KEY: "${{ secrets.PULLPREVIEW_CA_KEY }}" +``` + ## CLI usage (installed binary) Pull the released CLI binary from GitHub Releases, install it in your PATH, then use: diff --git a/action.yml b/action.yml index 2d24ce7..7f45450 100644 --- a/action.yml +++ b/action.yml @@ -50,9 +50,9 @@ inputs: required: false default: "" instance_type: - description: "Instance type to use" + description: "Instance type to use (provider-specific)" required: false - default: "small" + default: "" region: description: "Provider region (AWS region or Hetzner location), overrides provider defaults" required: false @@ -62,11 +62,11 @@ inputs: required: false default: "" image: - description: "Instance image (Hetzner only; ignored by AWS)" + description: "Provider image selector (Hetzner image name, or EC2 AMI ID/name prefix)" required: false - default: "ubuntu-24.04" + default: "" provider: - description: "Cloud provider to use: lightsail, hetzner" + description: "Cloud provider to use: lightsail, hetzner, ec2" required: false default: "lightsail" registries: diff --git a/cmd/pullpreview/main.go b/cmd/pullpreview/main.go index 34241c3..3db66bb 100644 --- a/cmd/pullpreview/main.go +++ b/cmd/pullpreview/main.go @@ -12,6 +12,7 @@ import ( "syscall" "github.com/pullpreview/action/internal/providers" + _ "github.com/pullpreview/action/internal/providers/ec2" _ "github.com/pullpreview/action/internal/providers/hetzner" _ "github.com/pullpreview/action/internal/providers/lightsail" "github.com/pullpreview/action/internal/pullpreview" @@ -199,7 +200,7 @@ func registerCommonFlags(fs *flag.FlagSet) *commonFlagValues { fs.StringVar(&values.options.ProxyTLS, "proxy-tls", "", "Enable automatic HTTPS proxying with Let's Encrypt (format: service:port, e.g. web:80)") fs.StringVar(&values.options.DNS, "dns", "my.preview.run", "DNS suffix to use") fs.StringVar(&values.ports, "ports", "80/tcp,443/tcp", "Ports to open for external access") - fs.StringVar(&values.options.InstanceType, "instance-type", "small", "Instance type to use") + fs.StringVar(&values.options.InstanceType, "instance-type", "", "Instance type to use") fs.StringVar(&values.options.DefaultPort, "default-port", "80", "Default port for URL") fs.Var(&values.tags, "tags", "Tags to add to the instance (key:value), comma-separated") fs.StringVar(&values.composeFiles, "compose-files", "docker-compose.yml", "Compose files to use") diff --git a/dist/pullpreview-linux-amd64 b/dist/pullpreview-linux-amd64 index b80af2b..6074623 100755 Binary files a/dist/pullpreview-linux-amd64 and b/dist/pullpreview-linux-amd64 differ diff --git a/go.mod b/go.mod index 645dc87..56b140e 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.25.1 require ( github.com/aws/aws-sdk-go-v2 v1.41.1 github.com/aws/aws-sdk-go-v2/config v1.32.7 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.289.0 github.com/aws/aws-sdk-go-v2/service/lightsail v1.50.11 github.com/google/go-github/v60 v60.0.0 github.com/hetznercloud/hcloud-go/v2 v2.36.0 diff --git a/go.sum b/go.sum index ff959d5..783ccef 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 h1:WWLqlh79iO48yLkj1v github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17/go.mod h1:EhG22vHRrvF8oXSTYStZhJc1aUgKtnJe+aOiFEV90cM= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.289.0 h1:Ftj1M28RtAjgHpycBeQaFhfGx+aQ/swYEz+tBtIh9nE= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.289.0/go.mod h1:Uy+C+Sc58jozdoL1McQr8bDsEvNFx+/nBY+vpO1HVUY= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17 h1:RuNSMoozM8oXlgLG/n6WLaFGoea7/CddrCfIiSA+xdY= diff --git a/internal/providers/ec2/ec2.go b/internal/providers/ec2/ec2.go new file mode 100644 index 0000000..4302f6a --- /dev/null +++ b/internal/providers/ec2/ec2.go @@ -0,0 +1,1334 @@ +package ec2 + +import ( + "context" + "encoding/base64" + "fmt" + "net" + "os" + "os/exec" + "regexp" + "sort" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + ec2svc "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "golang.org/x/crypto/ssh" + + "github.com/pullpreview/action/internal/providers/sshca" + "github.com/pullpreview/action/internal/pullpreview" +) + +const ( + defaultEC2SSHRetries = 12 + defaultEC2SSHInterval = 10 * time.Second + defaultEC2SSHCertTTL = 12 * time.Hour +) + +var ec2SecurityGroupNameSanitizer = regexp.MustCompile(`[^a-zA-Z0-9-]+`) + +type ec2Client interface { + DescribeInstances(context.Context, *ec2svc.DescribeInstancesInput) (*ec2svc.DescribeInstancesOutput, error) + RunInstances(context.Context, *ec2svc.RunInstancesInput) (*ec2svc.RunInstancesOutput, error) + TerminateInstances(context.Context, *ec2svc.TerminateInstancesInput) (*ec2svc.TerminateInstancesOutput, error) + StartInstances(context.Context, *ec2svc.StartInstancesInput) (*ec2svc.StartInstancesOutput, error) + ModifyInstanceAttribute(context.Context, *ec2svc.ModifyInstanceAttributeInput) (*ec2svc.ModifyInstanceAttributeOutput, error) + + DescribeSubnets(context.Context, *ec2svc.DescribeSubnetsInput) (*ec2svc.DescribeSubnetsOutput, error) + DescribeInstanceTypes(context.Context, *ec2svc.DescribeInstanceTypesInput) (*ec2svc.DescribeInstanceTypesOutput, error) + DescribeImages(context.Context, *ec2svc.DescribeImagesInput) (*ec2svc.DescribeImagesOutput, error) + + DescribeSecurityGroups(context.Context, *ec2svc.DescribeSecurityGroupsInput) (*ec2svc.DescribeSecurityGroupsOutput, error) + CreateSecurityGroup(context.Context, *ec2svc.CreateSecurityGroupInput) (*ec2svc.CreateSecurityGroupOutput, error) + DeleteSecurityGroup(context.Context, *ec2svc.DeleteSecurityGroupInput) (*ec2svc.DeleteSecurityGroupOutput, error) + AuthorizeSecurityGroupIngress(context.Context, *ec2svc.AuthorizeSecurityGroupIngressInput) (*ec2svc.AuthorizeSecurityGroupIngressOutput, error) + RevokeSecurityGroupIngress(context.Context, *ec2svc.RevokeSecurityGroupIngressInput) (*ec2svc.RevokeSecurityGroupIngressOutput, error) + + CreateTags(context.Context, *ec2svc.CreateTagsInput) (*ec2svc.CreateTagsOutput, error) + CreateKeyPair(context.Context, *ec2svc.CreateKeyPairInput) (*ec2svc.CreateKeyPairOutput, error) + DeleteKeyPair(context.Context, *ec2svc.DeleteKeyPairInput) (*ec2svc.DeleteKeyPairOutput, error) +} + +type ec2ClientAdapter struct { + client *ec2svc.Client +} + +func (a ec2ClientAdapter) DescribeInstances(ctx context.Context, input *ec2svc.DescribeInstancesInput) (*ec2svc.DescribeInstancesOutput, error) { + return a.client.DescribeInstances(ctx, input) +} + +func (a ec2ClientAdapter) RunInstances(ctx context.Context, input *ec2svc.RunInstancesInput) (*ec2svc.RunInstancesOutput, error) { + return a.client.RunInstances(ctx, input) +} + +func (a ec2ClientAdapter) TerminateInstances(ctx context.Context, input *ec2svc.TerminateInstancesInput) (*ec2svc.TerminateInstancesOutput, error) { + return a.client.TerminateInstances(ctx, input) +} + +func (a ec2ClientAdapter) StartInstances(ctx context.Context, input *ec2svc.StartInstancesInput) (*ec2svc.StartInstancesOutput, error) { + return a.client.StartInstances(ctx, input) +} + +func (a ec2ClientAdapter) ModifyInstanceAttribute(ctx context.Context, input *ec2svc.ModifyInstanceAttributeInput) (*ec2svc.ModifyInstanceAttributeOutput, error) { + return a.client.ModifyInstanceAttribute(ctx, input) +} + +func (a ec2ClientAdapter) DescribeSubnets(ctx context.Context, input *ec2svc.DescribeSubnetsInput) (*ec2svc.DescribeSubnetsOutput, error) { + return a.client.DescribeSubnets(ctx, input) +} + +func (a ec2ClientAdapter) DescribeInstanceTypes(ctx context.Context, input *ec2svc.DescribeInstanceTypesInput) (*ec2svc.DescribeInstanceTypesOutput, error) { + return a.client.DescribeInstanceTypes(ctx, input) +} + +func (a ec2ClientAdapter) DescribeImages(ctx context.Context, input *ec2svc.DescribeImagesInput) (*ec2svc.DescribeImagesOutput, error) { + return a.client.DescribeImages(ctx, input) +} + +func (a ec2ClientAdapter) DescribeSecurityGroups(ctx context.Context, input *ec2svc.DescribeSecurityGroupsInput) (*ec2svc.DescribeSecurityGroupsOutput, error) { + return a.client.DescribeSecurityGroups(ctx, input) +} + +func (a ec2ClientAdapter) CreateSecurityGroup(ctx context.Context, input *ec2svc.CreateSecurityGroupInput) (*ec2svc.CreateSecurityGroupOutput, error) { + return a.client.CreateSecurityGroup(ctx, input) +} + +func (a ec2ClientAdapter) DeleteSecurityGroup(ctx context.Context, input *ec2svc.DeleteSecurityGroupInput) (*ec2svc.DeleteSecurityGroupOutput, error) { + return a.client.DeleteSecurityGroup(ctx, input) +} + +func (a ec2ClientAdapter) AuthorizeSecurityGroupIngress(ctx context.Context, input *ec2svc.AuthorizeSecurityGroupIngressInput) (*ec2svc.AuthorizeSecurityGroupIngressOutput, error) { + return a.client.AuthorizeSecurityGroupIngress(ctx, input) +} + +func (a ec2ClientAdapter) RevokeSecurityGroupIngress(ctx context.Context, input *ec2svc.RevokeSecurityGroupIngressInput) (*ec2svc.RevokeSecurityGroupIngressOutput, error) { + return a.client.RevokeSecurityGroupIngress(ctx, input) +} + +func (a ec2ClientAdapter) CreateTags(ctx context.Context, input *ec2svc.CreateTagsInput) (*ec2svc.CreateTagsOutput, error) { + return a.client.CreateTags(ctx, input) +} + +func (a ec2ClientAdapter) CreateKeyPair(ctx context.Context, input *ec2svc.CreateKeyPairInput) (*ec2svc.CreateKeyPairOutput, error) { + return a.client.CreateKeyPair(ctx, input) +} + +func (a ec2ClientAdapter) DeleteKeyPair(ctx context.Context, input *ec2svc.DeleteKeyPairInput) (*ec2svc.DeleteKeyPairOutput, error) { + return a.client.DeleteKeyPair(ctx, input) +} + +var runSSHCommand = func(ctx context.Context, keyFile, certFile, user, host string) ([]byte, error) { + args := []string{ + "-o", "BatchMode=yes", + "-o", "IdentitiesOnly=yes", + "-o", "IdentityAgent=none", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "LogLevel=ERROR", + "-o", "ConnectTimeout=8", + "-i", keyFile, + } + if strings.TrimSpace(certFile) != "" { + args = append(args, "-o", fmt.Sprintf("CertificateFile=%s", certFile)) + } + args = append(args, + fmt.Sprintf("%s@%s", user, host), + "sh", "-lc", "test -f /etc/pullpreview/ready && id -nG | tr ' ' '\\n' | grep -qx docker", + ) + cmd := exec.CommandContext(ctx, "ssh", args...) + return cmd.CombinedOutput() +} + +type Provider struct { + client ec2Client + ctx context.Context + region string + image string + sshUser string + caSigner ssh.Signer + caPublicKey string + sshRetryCount int + sshRetryDelay time.Duration + logger *pullpreview.Logger +} + +func newProviderWithClient(ctx context.Context, cfg Config, logger *pullpreview.Logger, client ec2Client) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + if client == nil { + return nil, fmt.Errorf("client cannot be nil") + } + parsedCA, err := sshca.Parse(cfg.CAKey, cfg.CAKeyEnv) + if err != nil { + return nil, err + } + if logger != nil { + logger.Infof("EC2 SSH CA pre-check passed (%s)", parsedCA.Source) + } + return &Provider{ + client: client, + ctx: pullpreview.EnsureContext(ctx), + region: cfg.Region, + image: cfg.Image, + sshUser: cfg.SSHUsername, + caSigner: parsedCA.Signer, + caPublicKey: parsedCA.PublicKey, + sshRetryCount: defaultEC2SSHRetries, + sshRetryDelay: defaultEC2SSHInterval, + logger: logger, + }, nil +} + +func (p *Provider) Name() string { + return "ec2" +} + +func (p *Provider) DisplayName() string { + return "AWS EC2" +} + +func (p *Provider) SupportsSnapshots() bool { + return false +} + +func (p *Provider) SupportsRestore() bool { + return false +} + +func (p *Provider) SupportsFirewall() bool { + return true +} + +func (p *Provider) Username() string { + return p.sshUser +} + +func (p *Provider) BuildUserData(options pullpreview.UserDataOptions) (string, error) { + lines := []string{ + "#!/usr/bin/env bash", + "set -xe ; set -o pipefail", + } + homeDir := pullpreview.HomeDirForUser(options.Username) + lines = append(lines, fmt.Sprintf("mkdir -p %s/.ssh", homeDir)) + if len(options.SSHPublicKeys) > 0 { + lines = append(lines, fmt.Sprintf("echo '%s' >> %s/.ssh/authorized_keys", strings.Join(options.SSHPublicKeys, "\n"), homeDir)) + lines = append(lines, + fmt.Sprintf("chown -R %s:%s %s/.ssh", options.Username, options.Username, homeDir), + fmt.Sprintf("chmod 0700 %s/.ssh && chmod 0600 %s/.ssh/authorized_keys", homeDir, homeDir), + ) + } + lines = append(lines, + fmt.Sprintf("mkdir -p %s && chown -R %s:%s %s", options.AppPath, options.Username, options.Username, options.AppPath), + "mkdir -p /etc/profile.d", + fmt.Sprintf("echo 'cd %s' > /etc/profile.d/pullpreview.sh", options.AppPath), + "if command -v dnf >/dev/null 2>&1; then", + " if grep -qi 'Amazon Linux' /etc/os-release; then", + " dnf -y install docker", + " else", + " dnf -y install dnf-plugins-core", + " dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo", + " dnf -y install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin", + " fi", + " systemctl enable --now docker || systemctl restart docker", + "elif command -v yum >/dev/null 2>&1; then", + " yum -y install yum-utils", + " yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo", + " yum -y install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin", + " systemctl enable --now docker || systemctl restart docker", + "elif command -v apt-get >/dev/null 2>&1; then", + " mkdir -p /etc/apt/keyrings", + " install -m 0755 -d /etc/apt/keyrings", + " apt-get update", + " apt-get install -y ca-certificates curl gnupg lsb-release", + " if grep -qi ubuntu /etc/os-release; then DISTRO=ubuntu; else DISTRO=debian; fi", + " curl -fsSL https://download.docker.com/linux/$DISTRO/gpg -o /etc/apt/keyrings/docker.asc", + " chmod a+r /etc/apt/keyrings/docker.asc", + " echo \"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/$DISTRO $(lsb_release -cs) stable\" > /etc/apt/sources.list.d/docker.list", + " apt-get update", + " apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin", + " systemctl enable --now docker || systemctl restart docker", + "else", + " echo 'unsupported OS family; expected dnf, yum, or apt'", + " exit 1", + "fi", + fmt.Sprintf("usermod -aG docker %s || true", options.Username), + "mkdir -p /etc/ssh/sshd_config.d", + fmt.Sprintf("cat <<'EOF' > /etc/ssh/pullpreview-user-ca.pub\n%s\nEOF", p.caPublicKey), + "cat <<'EOF' > /etc/ssh/sshd_config.d/pullpreview.conf", + "TrustedUserCAKeys /etc/ssh/pullpreview-user-ca.pub", + "EOF", + "systemctl restart ssh || systemctl restart sshd || true", + "mkdir -p /etc/pullpreview && touch /etc/pullpreview/ready", + fmt.Sprintf("chown -R %s:%s /etc/pullpreview", options.Username, options.Username), + ) + return strings.Join(lines, "\n"), nil +} + +func (p *Provider) Launch(name string, opts pullpreview.LaunchOptions) (pullpreview.AccessDetails, error) { + for { + existing, err := p.instanceByName(name) + if err != nil { + return pullpreview.AccessDetails{}, err + } + if existing == nil { + return p.createInstance(name, opts) + } + if err := p.ensureInstanceRunning(existing); err != nil { + return pullpreview.AccessDetails{}, err + } + existing, err = p.instanceByID(aws.ToString(existing.InstanceId)) + if err != nil { + return pullpreview.AccessDetails{}, err + } + if existing == nil { + continue + } + publicIP := p.publicIPAddress(existing) + if publicIP == "" { + if p.logger != nil { + p.logger.Warnf("Existing EC2 instance %q missing public IP; recreating", name) + } + if err := p.destroyInstanceAndSecurityGroups(existing, name); err != nil { + return pullpreview.AccessDetails{}, err + } + continue + } + sgID, _, err := p.ensureSecurityGroup(name, aws.ToString(existing.VpcId), opts.Ports, opts.CIDRs) + if err != nil { + return pullpreview.AccessDetails{}, err + } + if err := p.ensureInstanceSecurityGroup(existing, sgID); err != nil { + return pullpreview.AccessDetails{}, err + } + privateKey, cert, err := p.generateSignedAccessCredentials() + if err != nil { + return pullpreview.AccessDetails{}, err + } + if err := p.validateSSHAccessWithRetry(existing, privateKey, cert, 0); err != nil { + if p.logger != nil { + p.logger.Warnf("Existing EC2 instance %q SSH cert check failed; recreating (%v)", name, err) + } + if err := p.destroyInstanceAndSecurityGroups(existing, name); err != nil { + return pullpreview.AccessDetails{}, err + } + continue + } + if p.logger != nil { + p.logger.Infof("Reusing existing EC2 instance %s with cert-based SSH credentials", name) + } + return pullpreview.AccessDetails{ + Username: p.sshUser, + IPAddress: publicIP, + PrivateKey: strings.TrimSpace(privateKey), + CertKey: strings.TrimSpace(cert), + }, nil + } +} + +func (p *Provider) createInstance(name string, opts pullpreview.LaunchOptions) (pullpreview.AccessDetails, error) { + instanceType := resolveEC2InstanceType(opts.Size) + supportedArchs, err := p.instanceTypeArchitectures(instanceType) + if err != nil { + return pullpreview.AccessDetails{}, err + } + image, err := p.resolveImage(supportedArchs) + if err != nil { + return pullpreview.AccessDetails{}, err + } + subnet, err := p.findTaggedPublicSubnet() + if err != nil { + return pullpreview.AccessDetails{}, err + } + vpcID := aws.ToString(subnet.VpcId) + sgID, sgCreated, err := p.ensureSecurityGroup(name, vpcID, opts.Ports, opts.CIDRs) + if err != nil { + return pullpreview.AccessDetails{}, err + } + keyName, bootstrapPrivateKey, err := p.createBootstrapKey(name) + if err != nil { + return pullpreview.AccessDetails{}, p.cleanupCreateFailure(name, "", keyName, "", false, err) + } + + instanceTags := mergeTags(map[string]string{"stack": pullpreview.StackName}, opts.Tags) + instanceTags["pullpreview_instance_name"] = name + instanceTags["Name"] = name + + userData := base64.StdEncoding.EncodeToString([]byte(opts.UserData)) + runOut, err := p.client.RunInstances(p.ctx, &ec2svc.RunInstancesInput{ + ImageId: image.ImageId, + InstanceType: ec2types.InstanceType(instanceType), + MinCount: ptrInt32(1), + MaxCount: ptrInt32(1), + KeyName: aws.String(keyName), + UserData: aws.String(userData), + NetworkInterfaces: []ec2types.InstanceNetworkInterfaceSpecification{ + { + DeviceIndex: ptrInt32(0), + SubnetId: subnet.SubnetId, + Groups: []string{sgID}, + AssociatePublicIpAddress: aws.Bool(true), + }, + }, + TagSpecifications: []ec2types.TagSpecification{ + { + ResourceType: ec2types.ResourceTypeInstance, + Tags: toEC2Tags(instanceTags), + }, + }, + }) + if err != nil { + return pullpreview.AccessDetails{}, p.cleanupCreateFailure(name, "", keyName, sgID, sgCreated, err) + } + if len(runOut.Instances) == 0 || runOut.Instances[0].InstanceId == nil { + return pullpreview.AccessDetails{}, p.cleanupCreateFailure(name, "", keyName, sgID, sgCreated, fmt.Errorf("ec2 did not return created instance")) + } + instanceID := aws.ToString(runOut.Instances[0].InstanceId) + instance, err := p.waitForInstanceState(instanceID, ec2types.InstanceStateNameRunning, 60, 5*time.Second) + if err != nil { + return pullpreview.AccessDetails{}, p.cleanupCreateFailure(name, instanceID, keyName, sgID, sgCreated, err) + } + if err := p.validateSSHAccessWithRetry(instance, bootstrapPrivateKey, "", 0); err != nil { + return pullpreview.AccessDetails{}, p.cleanupCreateFailure(name, instanceID, keyName, sgID, sgCreated, err) + } + if err := p.deleteKeyPairIfExists(keyName); err != nil && p.logger != nil { + p.logger.Warnf("Unable to delete temporary EC2 key pair %s: %v", keyName, err) + } + + privateKey, cert, err := p.generateSignedAccessCredentials() + if err != nil { + return pullpreview.AccessDetails{}, p.cleanupCreateFailure(name, instanceID, "", sgID, sgCreated, err) + } + if err := p.validateSSHAccessWithRetry(instance, privateKey, cert, 0); err != nil { + return pullpreview.AccessDetails{}, p.cleanupCreateFailure(name, instanceID, "", sgID, sgCreated, err) + } + publicIP := p.publicIPAddress(instance) + if publicIP == "" { + return pullpreview.AccessDetails{}, p.cleanupCreateFailure(name, instanceID, "", sgID, sgCreated, fmt.Errorf("created instance missing public IP")) + } + + return pullpreview.AccessDetails{ + Username: p.sshUser, + IPAddress: publicIP, + PrivateKey: strings.TrimSpace(privateKey), + CertKey: strings.TrimSpace(cert), + }, nil +} + +func (p *Provider) cleanupCreateFailure(name, instanceID, keyName, securityGroupID string, deleteSecurityGroup bool, cause error) error { + if strings.TrimSpace(instanceID) != "" { + if err := p.terminateInstanceAndWait(instanceID); err != nil && p.logger != nil { + p.logger.Warnf("Create cleanup: unable to terminate EC2 instance %s: %v", instanceID, err) + } + } + if strings.TrimSpace(keyName) != "" { + if err := p.deleteKeyPairIfExists(keyName); err != nil && p.logger != nil { + p.logger.Warnf("Create cleanup: unable to delete EC2 key pair %s: %v", keyName, err) + } + } + if deleteSecurityGroup && strings.TrimSpace(securityGroupID) != "" { + if err := p.deleteSecurityGroupByID(securityGroupID); err != nil && p.logger != nil { + p.logger.Warnf("Create cleanup: unable to delete security group %s: %v", securityGroupID, err) + } + } + if cause != nil { + return cause + } + return fmt.Errorf("create cleanup failed for %q", name) +} + +func (p *Provider) Terminate(name string) error { + instance, err := p.instanceByName(name) + if err != nil { + return err + } + if instance != nil { + if err := p.terminateInstanceAndWait(aws.ToString(instance.InstanceId)); err != nil { + return err + } + } + if err := p.deleteSecurityGroupsForInstance(name); err != nil && p.logger != nil { + p.logger.Warnf("Unable to delete EC2 security group for %s: %v", name, err) + } + return nil +} + +func (p *Provider) Running(name string) (bool, error) { + instance, err := p.instanceByName(name) + if err != nil { + return false, err + } + if instance == nil { + return false, nil + } + return instanceStateName(instance) == ec2types.InstanceStateNameRunning, nil +} + +func (p *Provider) ListInstances(tags map[string]string) ([]pullpreview.InstanceSummary, error) { + filters := []ec2types.Filter{ + {Name: aws.String("instance-state-name"), Values: []string{"pending", "running", "stopping", "stopped"}}, + } + for key, value := range tags { + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + if key == "" || value == "" { + continue + } + filters = append(filters, ec2types.Filter{Name: aws.String("tag:" + key), Values: []string{value}}) + } + instances, err := p.describeInstancesAll(&ec2svc.DescribeInstancesInput{Filters: filters}) + if err != nil { + return nil, err + } + result := make([]pullpreview.InstanceSummary, 0, len(instances)) + for _, instance := range instances { + summary := pullpreview.InstanceSummary{ + Name: instanceName(instance), + PublicIP: p.publicIPAddress(&instance), + Size: string(instance.InstanceType), + Region: p.region, + CreatedAt: aws.ToTime(instance.LaunchTime), + Tags: tagsToMap(instance.Tags), + } + if instance.Placement != nil { + summary.Zone = aws.ToString(instance.Placement.AvailabilityZone) + } + result = append(result, summary) + } + return result, nil +} + +func (p *Provider) resolveImage(supportedArchs map[string]struct{}) (ec2types.Image, error) { + value := strings.TrimSpace(p.image) + if strings.HasPrefix(value, "ami-") { + output, err := p.client.DescribeImages(p.ctx, &ec2svc.DescribeImagesInput{ImageIds: []string{value}}) + if err != nil { + return ec2types.Image{}, err + } + if len(output.Images) == 0 { + return ec2types.Image{}, fmt.Errorf("ami %q not found", value) + } + image := output.Images[0] + if image.State != ec2types.ImageStateAvailable { + return ec2types.Image{}, fmt.Errorf("ami %q is not available (state=%s)", value, image.State) + } + if err := ensureAMIArchitectureCompatible(image, supportedArchs, value); err != nil { + return ec2types.Image{}, err + } + return image, nil + } + + prefix := value + if prefix == "" { + prefix = defaultEC2ImagePrefix + } + output, err := p.client.DescribeImages(p.ctx, &ec2svc.DescribeImagesInput{ + Owners: []string{"self", "amazon"}, + Filters: []ec2types.Filter{ + {Name: aws.String("name"), Values: []string{prefix + "*"}}, + {Name: aws.String("state"), Values: []string{string(ec2types.ImageStateAvailable)}}, + }, + }) + if err != nil { + return ec2types.Image{}, err + } + if len(output.Images) == 0 { + return ec2types.Image{}, fmt.Errorf("no available AMI matched prefix %q (owners: self, amazon)", prefix) + } + images := append([]ec2types.Image{}, output.Images...) + sort.Slice(images, func(i, j int) bool { + left := parseImageCreationDate(images[i].CreationDate) + right := parseImageCreationDate(images[j].CreationDate) + if !left.Equal(right) { + return left.After(right) + } + return aws.ToString(images[i].ImageId) > aws.ToString(images[j].ImageId) + }) + selected := images[0] + if err := ensureAMIArchitectureCompatible(selected, supportedArchs, prefix); err != nil { + return ec2types.Image{}, err + } + return selected, nil +} + +func ensureAMIArchitectureCompatible(image ec2types.Image, supportedArchs map[string]struct{}, imageSource string) error { + arch := strings.TrimSpace(string(image.Architecture)) + if arch == "" { + return fmt.Errorf("selected AMI %q from %q has empty architecture", aws.ToString(image.ImageId), imageSource) + } + if _, ok := supportedArchs[arch]; ok { + return nil + } + allowed := make([]string, 0, len(supportedArchs)) + for value := range supportedArchs { + allowed = append(allowed, value) + } + sort.Strings(allowed) + return fmt.Errorf("selected AMI %q architecture %q is incompatible with instance type supported architectures [%s]", aws.ToString(image.ImageId), arch, strings.Join(allowed, ", ")) +} + +func parseImageCreationDate(raw *string) time.Time { + value := strings.TrimSpace(aws.ToString(raw)) + if value == "" { + return time.Time{} + } + parsed, err := time.Parse(time.RFC3339, value) + if err == nil { + return parsed + } + parsed, err = time.Parse("2006-01-02T15:04:05.000Z", value) + if err == nil { + return parsed + } + return time.Time{} +} + +func (p *Provider) instanceTypeArchitectures(instanceType string) (map[string]struct{}, error) { + output, err := p.client.DescribeInstanceTypes(p.ctx, &ec2svc.DescribeInstanceTypesInput{ + InstanceTypes: []ec2types.InstanceType{ec2types.InstanceType(instanceType)}, + }) + if err != nil { + return nil, err + } + if len(output.InstanceTypes) == 0 { + return nil, fmt.Errorf("instance type %q not found", instanceType) + } + architectures := map[string]struct{}{} + for _, arch := range output.InstanceTypes[0].ProcessorInfo.SupportedArchitectures { + value := strings.TrimSpace(string(arch)) + if value == "" { + continue + } + architectures[value] = struct{}{} + } + if len(architectures) == 0 { + return nil, fmt.Errorf("instance type %q does not report supported architectures", instanceType) + } + return architectures, nil +} + +func (p *Provider) findTaggedPublicSubnet() (*ec2types.Subnet, error) { + output, err := p.client.DescribeSubnets(p.ctx, &ec2svc.DescribeSubnetsInput{ + Filters: []ec2types.Filter{ + {Name: aws.String("tag:pullpreview-enabled"), Values: []string{"true"}}, + {Name: aws.String("state"), Values: []string{"available"}}, + }, + }) + if err != nil { + return nil, err + } + publicSubnets := make([]ec2types.Subnet, 0, len(output.Subnets)) + for _, subnet := range output.Subnets { + if aws.ToBool(subnet.MapPublicIpOnLaunch) { + publicSubnets = append(publicSubnets, subnet) + } + } + if len(publicSubnets) == 0 { + return nil, fmt.Errorf("no public subnet with tag pullpreview-enabled=true found in region %s", p.region) + } + sort.Slice(publicSubnets, func(i, j int) bool { + return aws.ToString(publicSubnets[i].SubnetId) < aws.ToString(publicSubnets[j].SubnetId) + }) + selected := publicSubnets[0] + return &selected, nil +} + +func (p *Provider) createBootstrapKey(name string) (string, string, error) { + keyName := fmt.Sprintf("pullpreview-%s-%d", sanitizeSecurityGroupName(name), time.Now().UnixNano()) + output, err := p.client.CreateKeyPair(p.ctx, &ec2svc.CreateKeyPairInput{KeyName: aws.String(keyName)}) + if err != nil { + return "", "", err + } + privateKey := strings.TrimSpace(aws.ToString(output.KeyMaterial)) + if privateKey == "" { + return "", "", fmt.Errorf("ec2 create-key-pair returned empty private key material") + } + return keyName, privateKey, nil +} + +func (p *Provider) deleteKeyPairIfExists(keyName string) error { + keyName = strings.TrimSpace(keyName) + if keyName == "" { + return nil + } + _, err := p.client.DeleteKeyPair(p.ctx, &ec2svc.DeleteKeyPairInput{KeyName: aws.String(keyName)}) + if err != nil && containsAWSError(err, "InvalidKeyPair.NotFound") { + return nil + } + return err +} + +func (p *Provider) validateSSHAccessWithRetry(instance *ec2types.Instance, privateKey, certKey string, attempts int) error { + if attempts <= 0 { + if p.sshRetryCount > 0 { + attempts = p.sshRetryCount + } else { + attempts = 1 + } + } + delay := p.sshRetryDelay + if delay <= 0 { + delay = defaultEC2SSHInterval + } + var lastErr error + for i := 0; i < attempts; i++ { + if err := p.validateSSHAccess(instance, privateKey, certKey); err == nil { + return nil + } else { + lastErr = err + } + if i < attempts-1 { + if p.logger != nil { + p.logger.Warnf("SSH access validation failed for EC2 instance %q (attempt %d/%d): %v", instanceNameValue(instance), i+1, attempts, lastErr) + } + time.Sleep(delay) + } + } + return fmt.Errorf("ssh access validation failed for instance %q after %d attempts: %w", instanceNameValue(instance), attempts, lastErr) +} + +func (p *Provider) validateSSHAccess(instance *ec2types.Instance, privateKey, certKey string) error { + privateKey = strings.TrimSpace(privateKey) + if privateKey == "" { + return fmt.Errorf("empty private key") + } + publicIP := p.publicIPAddress(instance) + if publicIP == "" { + return fmt.Errorf("instance %q missing public IP", instanceNameValue(instance)) + } + keyFile, err := os.CreateTemp("", "pullpreview-ec2-key-*") + if err != nil { + return err + } + if err := keyFile.Close(); err != nil { + _ = os.Remove(keyFile.Name()) + return err + } + if err := os.WriteFile(keyFile.Name(), []byte(privateKey+"\n"), 0600); err != nil { + _ = os.Remove(keyFile.Name()) + return err + } + if err := os.Chmod(keyFile.Name(), 0600); err != nil { + _ = os.Remove(keyFile.Name()) + return err + } + certFile := "" + if strings.TrimSpace(certKey) != "" { + certFile = keyFile.Name() + "-cert.pub" + if err := os.WriteFile(certFile, []byte(strings.TrimSpace(certKey)+"\n"), 0600); err != nil { + _ = os.Remove(keyFile.Name()) + return err + } + defer os.Remove(certFile) + } + defer os.Remove(keyFile.Name()) + + output, err := runSSHCommand(p.ctx, keyFile.Name(), certFile, p.sshUser, publicIP) + if err != nil { + return fmt.Errorf("%s: %w", strings.TrimSpace(string(output)), err) + } + return nil +} + +func (p *Provider) generateSignedAccessCredentials() (string, string, error) { + _, privateKey, signer, err := sshca.GenerateSSHKeyPairWithSigner() + if err != nil { + return "", "", err + } + cert, err := sshca.GenerateUserCertificate(p.caSigner, signer, p.sshUser, defaultEC2SSHCertTTL) + if err != nil { + return "", "", err + } + return privateKey, cert, nil +} + +func (p *Provider) instanceByName(name string) (*ec2types.Instance, error) { + instances, err := p.describeInstancesAll(&ec2svc.DescribeInstancesInput{ + Filters: []ec2types.Filter{ + {Name: aws.String("tag:pullpreview_instance_name"), Values: []string{strings.TrimSpace(name)}}, + {Name: aws.String("instance-state-name"), Values: []string{"pending", "running", "stopping", "stopped"}}, + }, + }) + if err != nil { + return nil, err + } + if len(instances) == 0 { + return nil, nil + } + sort.Slice(instances, func(i, j int) bool { + left := aws.ToTime(instances[i].LaunchTime) + right := aws.ToTime(instances[j].LaunchTime) + if !left.Equal(right) { + return left.After(right) + } + return aws.ToString(instances[i].InstanceId) > aws.ToString(instances[j].InstanceId) + }) + selected := instances[0] + return &selected, nil +} + +func (p *Provider) instanceByID(instanceID string) (*ec2types.Instance, error) { + instanceID = strings.TrimSpace(instanceID) + if instanceID == "" { + return nil, nil + } + instances, err := p.describeInstancesAll(&ec2svc.DescribeInstancesInput{InstanceIds: []string{instanceID}}) + if err != nil { + if containsAWSError(err, "InvalidInstanceID.NotFound") { + return nil, nil + } + return nil, err + } + if len(instances) == 0 { + return nil, nil + } + instance := instances[0] + return &instance, nil +} + +func (p *Provider) describeInstancesAll(input *ec2svc.DescribeInstancesInput) ([]ec2types.Instance, error) { + if input == nil { + input = &ec2svc.DescribeInstancesInput{} + } + instances := []ec2types.Instance{} + token := input.NextToken + for { + copyInput := *input + copyInput.NextToken = token + output, err := p.client.DescribeInstances(p.ctx, ©Input) + if err != nil { + return nil, err + } + for _, reservation := range output.Reservations { + instances = append(instances, reservation.Instances...) + } + if output.NextToken == nil || strings.TrimSpace(*output.NextToken) == "" { + break + } + token = output.NextToken + } + return instances, nil +} + +func (p *Provider) ensureInstanceRunning(instance *ec2types.Instance) error { + if instance == nil { + return nil + } + state := instanceStateName(instance) + if state == ec2types.InstanceStateNameRunning { + return nil + } + if state == ec2types.InstanceStateNameStopped { + _, err := p.client.StartInstances(p.ctx, &ec2svc.StartInstancesInput{InstanceIds: []string{aws.ToString(instance.InstanceId)}}) + if err != nil { + return err + } + _, err = p.waitForInstanceState(aws.ToString(instance.InstanceId), ec2types.InstanceStateNameRunning, 60, 5*time.Second) + return err + } + if state == ec2types.InstanceStateNameStopping { + _, err := p.waitForInstanceState(aws.ToString(instance.InstanceId), ec2types.InstanceStateNameStopped, 60, 5*time.Second) + if err != nil { + return err + } + _, err = p.client.StartInstances(p.ctx, &ec2svc.StartInstancesInput{InstanceIds: []string{aws.ToString(instance.InstanceId)}}) + if err != nil { + return err + } + _, err = p.waitForInstanceState(aws.ToString(instance.InstanceId), ec2types.InstanceStateNameRunning, 60, 5*time.Second) + return err + } + if state == ec2types.InstanceStateNamePending { + _, err := p.waitForInstanceState(aws.ToString(instance.InstanceId), ec2types.InstanceStateNameRunning, 60, 5*time.Second) + return err + } + return fmt.Errorf("instance %q is in unsupported state %s", instanceNameValue(instance), state) +} + +func (p *Provider) waitForInstanceState(instanceID string, desired ec2types.InstanceStateName, attempts int, delay time.Duration) (*ec2types.Instance, error) { + if attempts <= 0 { + attempts = 1 + } + if delay <= 0 { + delay = 5 * time.Second + } + var lastState ec2types.InstanceStateName + for i := 0; i < attempts; i++ { + instance, err := p.instanceByID(instanceID) + if err != nil { + return nil, err + } + if instance != nil { + lastState = instanceStateName(instance) + if lastState == desired { + return instance, nil + } + } + if desired == ec2types.InstanceStateNameTerminated && instance == nil { + return nil, nil + } + if i < attempts-1 { + time.Sleep(delay) + } + } + return nil, fmt.Errorf("timeout waiting for instance %s state %s (last=%s)", instanceID, desired, lastState) +} + +func (p *Provider) terminateInstanceAndWait(instanceID string) error { + instanceID = strings.TrimSpace(instanceID) + if instanceID == "" { + return nil + } + _, err := p.client.TerminateInstances(p.ctx, &ec2svc.TerminateInstancesInput{InstanceIds: []string{instanceID}}) + if err != nil { + if containsAWSError(err, "InvalidInstanceID.NotFound") { + return nil + } + return err + } + _, err = p.waitForInstanceState(instanceID, ec2types.InstanceStateNameTerminated, 80, 5*time.Second) + if err != nil { + return err + } + return nil +} + +func (p *Provider) ensureSecurityGroup(name, vpcID string, ports, cidrs []string) (string, bool, error) { + if strings.TrimSpace(vpcID) == "" { + return "", false, fmt.Errorf("missing VPC for security group setup") + } + groups, err := p.securityGroupsForInstance(name, vpcID) + if err != nil { + return "", false, err + } + created := false + groupID := "" + if len(groups) == 0 { + groupName := securityGroupName(name) + createdOut, err := p.client.CreateSecurityGroup(p.ctx, &ec2svc.CreateSecurityGroupInput{ + GroupName: aws.String(groupName), + Description: aws.String("PullPreview preview environment access"), + VpcId: aws.String(vpcID), + }) + if err != nil { + if containsAWSError(err, "InvalidGroup.Duplicate") { + lookup, lookupErr := p.client.DescribeSecurityGroups(p.ctx, &ec2svc.DescribeSecurityGroupsInput{ + Filters: []ec2types.Filter{ + {Name: aws.String("group-name"), Values: []string{groupName}}, + {Name: aws.String("vpc-id"), Values: []string{vpcID}}, + }, + }) + if lookupErr != nil { + return "", false, lookupErr + } + if len(lookup.SecurityGroups) > 0 { + groupID = aws.ToString(lookup.SecurityGroups[0].GroupId) + } + if strings.TrimSpace(groupID) == "" { + return "", false, err + } + } else { + return "", false, err + } + } + if strings.TrimSpace(groupID) == "" { + groupID = aws.ToString(createdOut.GroupId) + } + if strings.TrimSpace(groupID) == "" { + return "", false, fmt.Errorf("create security group returned empty group id") + } + _, _ = p.client.CreateTags(p.ctx, &ec2svc.CreateTagsInput{ + Resources: []string{groupID}, + Tags: toEC2Tags(map[string]string{ + "Name": groupName, + "stack": pullpreview.StackName, + "pullpreview_instance_name": strings.TrimSpace(name), + }), + }) + created = true + } else { + sort.Slice(groups, func(i, j int) bool { + return aws.ToString(groups[i].GroupId) < aws.ToString(groups[j].GroupId) + }) + groupID = aws.ToString(groups[0].GroupId) + } + rules, err := parseSecurityGroupIngressRules(ports, cidrs) + if err != nil { + return "", false, err + } + if err := p.syncSecurityGroupRules(groupID, rules); err != nil { + return "", false, err + } + return groupID, created, nil +} + +func (p *Provider) securityGroupsForInstance(name, vpcID string) ([]ec2types.SecurityGroup, error) { + filters := []ec2types.Filter{ + {Name: aws.String("tag:pullpreview_instance_name"), Values: []string{strings.TrimSpace(name)}}, + } + if strings.TrimSpace(vpcID) != "" { + filters = append(filters, ec2types.Filter{Name: aws.String("vpc-id"), Values: []string{strings.TrimSpace(vpcID)}}) + } + output, err := p.client.DescribeSecurityGroups(p.ctx, &ec2svc.DescribeSecurityGroupsInput{Filters: filters}) + if err != nil { + return nil, err + } + return output.SecurityGroups, nil +} + +func (p *Provider) syncSecurityGroupRules(groupID string, rules []ec2types.IpPermission) error { + groupID = strings.TrimSpace(groupID) + if groupID == "" { + return fmt.Errorf("missing security group id") + } + output, err := p.client.DescribeSecurityGroups(p.ctx, &ec2svc.DescribeSecurityGroupsInput{GroupIds: []string{groupID}}) + if err != nil { + return err + } + if len(output.SecurityGroups) == 0 { + return fmt.Errorf("security group %s not found", groupID) + } + existing := output.SecurityGroups[0].IpPermissions + if len(existing) > 0 { + _, err = p.client.RevokeSecurityGroupIngress(p.ctx, &ec2svc.RevokeSecurityGroupIngressInput{ + GroupId: aws.String(groupID), + IpPermissions: existing, + }) + if err != nil && !containsAWSError(err, "InvalidPermission.NotFound") { + return err + } + } + if len(rules) > 0 { + _, err = p.client.AuthorizeSecurityGroupIngress(p.ctx, &ec2svc.AuthorizeSecurityGroupIngressInput{ + GroupId: aws.String(groupID), + IpPermissions: rules, + }) + if err != nil { + return err + } + } + return nil +} + +func (p *Provider) ensureInstanceSecurityGroup(instance *ec2types.Instance, groupID string) error { + if instance == nil { + return fmt.Errorf("missing instance") + } + groupID = strings.TrimSpace(groupID) + if groupID == "" { + return fmt.Errorf("missing security group") + } + groups := []string{} + hasGroup := false + for _, group := range instance.SecurityGroups { + id := strings.TrimSpace(aws.ToString(group.GroupId)) + if id == "" { + continue + } + if id == groupID { + hasGroup = true + } + groups = append(groups, id) + } + if hasGroup { + return nil + } + groups = append(groups, groupID) + _, err := p.client.ModifyInstanceAttribute(p.ctx, &ec2svc.ModifyInstanceAttributeInput{ + InstanceId: instance.InstanceId, + Groups: groups, + }) + return err +} + +func (p *Provider) deleteSecurityGroupsForInstance(name string) error { + groups, err := p.securityGroupsForInstance(name, "") + if err != nil { + return err + } + var firstErr error + for _, group := range groups { + groupID := aws.ToString(group.GroupId) + if err := p.deleteSecurityGroupByID(groupID); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +func (p *Provider) deleteSecurityGroupByID(groupID string) error { + groupID = strings.TrimSpace(groupID) + if groupID == "" { + return nil + } + _, err := p.client.DeleteSecurityGroup(p.ctx, &ec2svc.DeleteSecurityGroupInput{GroupId: aws.String(groupID)}) + if err != nil && (containsAWSError(err, "InvalidGroup.NotFound") || containsAWSError(err, "DependencyViolation")) { + return nil + } + return err +} + +func (p *Provider) destroyInstanceAndSecurityGroups(instance *ec2types.Instance, name string) error { + if instance != nil { + if err := p.terminateInstanceAndWait(aws.ToString(instance.InstanceId)); err != nil { + return fmt.Errorf("failed to delete instance %q: %w", name, err) + } + } + if err := p.deleteSecurityGroupsForInstance(name); err != nil && p.logger != nil { + p.logger.Warnf("Unable to cleanup security groups for %s: %v", name, err) + } + return nil +} + +func (p *Provider) publicIPAddress(instance *ec2types.Instance) string { + if instance == nil { + return "" + } + return strings.TrimSpace(aws.ToString(instance.PublicIpAddress)) +} + +func resolveEC2InstanceType(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return defaultEC2InstanceType + } + return value +} + +func parseSecurityGroupIngressRules(ports, cidrs []string) ([]ec2types.IpPermission, error) { + normalizedCIDRs := normalizeCIDRs(cidrs) + rules := map[string]ec2types.IpPermission{} + for _, raw := range ports { + start, end, protocol, err := parseFirewallPort(raw) + if err != nil { + return nil, err + } + key := fmt.Sprintf("%d-%d/%s/%s", start, end, protocol, strings.Join(normalizedCIDRs, ",")) + if _, exists := rules[key]; exists { + continue + } + rules[key] = buildIPPermission(start, end, protocol, normalizedCIDRs) + } + const sshPort = 22 + sshCIDRs := []string{"0.0.0.0/0"} + sshKey := fmt.Sprintf("%d-%d/tcp/%s", sshPort, sshPort, strings.Join(sshCIDRs, ",")) + if _, exists := rules[sshKey]; !exists { + rules[sshKey] = buildIPPermission(sshPort, sshPort, "tcp", sshCIDRs) + } + result := make([]ec2types.IpPermission, 0, len(rules)) + for _, rule := range rules { + result = append(result, rule) + } + return result, nil +} + +func buildIPPermission(start, end int, protocol string, cidrs []string) ec2types.IpPermission { + permission := ec2types.IpPermission{IpProtocol: aws.String(protocol)} + permission.FromPort = ptrInt32(int32(start)) + permission.ToPort = ptrInt32(int32(end)) + for _, cidr := range cidrs { + if strings.Contains(cidr, ":") { + permission.Ipv6Ranges = append(permission.Ipv6Ranges, ec2types.Ipv6Range{CidrIpv6: aws.String(cidr)}) + continue + } + permission.IpRanges = append(permission.IpRanges, ec2types.IpRange{CidrIp: aws.String(cidr)}) + } + return permission +} + +func parseFirewallPort(raw string) (int, int, string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return 0, 0, "", fmt.Errorf("empty port definition") + } + portRange := raw + protocol := "tcp" + if idx := strings.Index(raw, "/"); idx >= 0 { + portRange = strings.TrimSpace(raw[:idx]) + protocol = strings.ToLower(strings.TrimSpace(raw[idx+1:])) + } + if protocol == "" { + protocol = "tcp" + } + if protocol != "tcp" && protocol != "udp" && protocol != "icmp" { + return 0, 0, "", fmt.Errorf("unsupported protocol %s in port definition %q", protocol, raw) + } + if strings.Contains(portRange, "-") { + parts := strings.SplitN(portRange, "-", 2) + if len(parts) != 2 || strings.TrimSpace(parts[0]) == "" || strings.TrimSpace(parts[1]) == "" { + return 0, 0, "", fmt.Errorf("invalid port range %q", raw) + } + start, err := mustParsePort(parts[0]) + if err != nil { + return 0, 0, "", fmt.Errorf("invalid port range %q: %w", raw, err) + } + end, err := mustParsePort(parts[1]) + if err != nil { + return 0, 0, "", fmt.Errorf("invalid port range %q: %w", raw, err) + } + return start, end, protocol, nil + } + port, err := mustParsePort(portRange) + if err != nil { + return 0, 0, "", fmt.Errorf("invalid port %q: %w", raw, err) + } + return port, port, protocol, nil +} + +func mustParsePort(raw string) (int, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return 0, fmt.Errorf("empty port") + } + value, err := strconv.Atoi(raw) + if err != nil { + return 0, err + } + if value <= 0 || value > 65535 { + return 0, fmt.Errorf("invalid port %d", value) + } + return value, nil +} + +func normalizeCIDRs(raw []string) []string { + if len(raw) == 0 { + raw = []string{"0.0.0.0/0"} + } + normalized := []string{} + seen := map[string]struct{}{} + for _, value := range raw { + value = strings.TrimSpace(value) + if value == "" { + continue + } + parsed := parseCIDR(value) + if parsed == "" { + continue + } + if _, ok := seen[parsed]; ok { + continue + } + seen[parsed] = struct{}{} + normalized = append(normalized, parsed) + } + if len(normalized) == 0 { + normalized = append(normalized, "0.0.0.0/0") + } + sort.Strings(normalized) + return normalized +} + +func parseCIDR(value string) string { + if _, parsed, err := net.ParseCIDR(value); err == nil { + return parsed.String() + } + ip := net.ParseIP(value) + if ip == nil { + return "" + } + if ip.To4() != nil { + return fmt.Sprintf("%s/32", ip.String()) + } + return fmt.Sprintf("%s/128", ip.String()) +} + +func mergeTags(base, extra map[string]string) map[string]string { + result := map[string]string{} + for key, value := range base { + result[key] = value + } + for key, value := range extra { + result[key] = value + } + return result +} + +func toEC2Tags(input map[string]string) []ec2types.Tag { + tags := make([]ec2types.Tag, 0, len(input)) + for key, value := range input { + k := strings.TrimSpace(key) + v := strings.TrimSpace(value) + if k == "" || v == "" { + continue + } + tags = append(tags, ec2types.Tag{Key: aws.String(k), Value: aws.String(v)}) + } + return tags +} + +func tagsToMap(input []ec2types.Tag) map[string]string { + result := map[string]string{} + for _, tag := range input { + result[aws.ToString(tag.Key)] = aws.ToString(tag.Value) + } + return result +} + +func instanceName(instance ec2types.Instance) string { + if name := instanceNameValue(&instance); strings.TrimSpace(name) != "" { + return name + } + return aws.ToString(instance.InstanceId) +} + +func instanceNameValue(instance *ec2types.Instance) string { + if instance == nil { + return "" + } + for _, tag := range instance.Tags { + if strings.EqualFold(strings.TrimSpace(aws.ToString(tag.Key)), "Name") { + return strings.TrimSpace(aws.ToString(tag.Value)) + } + } + for _, tag := range instance.Tags { + if strings.EqualFold(strings.TrimSpace(aws.ToString(tag.Key)), "pullpreview_instance_name") { + return strings.TrimSpace(aws.ToString(tag.Value)) + } + } + return strings.TrimSpace(aws.ToString(instance.InstanceId)) +} + +func instanceStateName(instance *ec2types.Instance) ec2types.InstanceStateName { + if instance == nil || instance.State == nil { + return "" + } + return instance.State.Name +} + +func securityGroupName(instanceName string) string { + name := sanitizeSecurityGroupName(instanceName) + if len(name) > 240 { + name = name[:240] + } + if name == "" { + name = "instance" + } + return "pullpreview-" + name +} + +func sanitizeSecurityGroupName(value string) string { + value = strings.TrimSpace(strings.ToLower(value)) + if value == "" { + return "instance" + } + value = ec2SecurityGroupNameSanitizer.ReplaceAllString(value, "-") + value = strings.Trim(value, "-") + if value == "" { + value = "instance" + } + return value +} + +func containsAWSError(err error, code string) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), code) +} + +func ptrInt32(value int32) *int32 { + return &value +} diff --git a/internal/providers/ec2/ec2_test.go b/internal/providers/ec2/ec2_test.go new file mode 100644 index 0000000..ee29da0 --- /dev/null +++ b/internal/providers/ec2/ec2_test.go @@ -0,0 +1,588 @@ +package ec2 + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + ec2svc "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/pullpreview/action/internal/providers/sshca" + "github.com/pullpreview/action/internal/pullpreview" +) + +func TestParseConfigFromEnv(t *testing.T) { + caKey := mustTestCAKey(t) + + cfgRaw, err := ParseConfigFromEnv(map[string]string{ + "REGION": "us-west-2", + "IMAGE": "my-prefix", + "PULLPREVIEW_CA_KEY": caKey, + }) + if err != nil { + t.Fatalf("ParseConfigFromEnv() error: %v", err) + } + cfg := cfgRaw.(Config) + if cfg.Region != "us-west-2" { + t.Fatalf("expected region override, got %q", cfg.Region) + } + if cfg.Image != "my-prefix" { + t.Fatalf("expected image prefix, got %q", cfg.Image) + } + if cfg.CAKey != caKey { + t.Fatalf("expected CA key, got %q", cfg.CAKey) + } + if cfg.SSHUsername != defaultEC2SSHUser { + t.Fatalf("expected default username %q, got %q", defaultEC2SSHUser, cfg.SSHUsername) + } + + cfgRaw, err = ParseConfigFromEnv(map[string]string{ + "AWS_REGION": "eu-central-1", + "PULLPREVIEW_CA_KEY": caKey, + }) + if err != nil { + t.Fatalf("ParseConfigFromEnv() fallback region error: %v", err) + } + cfg = cfgRaw.(Config) + if cfg.Region != "eu-central-1" { + t.Fatalf("expected AWS_REGION fallback, got %q", cfg.Region) + } + + if _, err := ParseConfigFromEnv(map[string]string{"REGION": "us-east-1"}); err == nil { + t.Fatalf("expected missing CA key error") + } +} + +func TestResolveImageByAMIID(t *testing.T) { + provider, fake := newTestProvider(t, Config{ + Region: "us-east-1", + Image: "ami-1234567890", + CAKey: mustTestCAKey(t), + CAKeyEnv: "PULLPREVIEW_CA_KEY", + SSHUsername: defaultEC2SSHUser, + }) + + fake.describeImagesFn = func(_ context.Context, input *ec2svc.DescribeImagesInput) (*ec2svc.DescribeImagesOutput, error) { + if len(input.ImageIds) != 1 || input.ImageIds[0] != "ami-1234567890" { + return nil, fmt.Errorf("unexpected image id lookup: %#v", input.ImageIds) + } + return &ec2svc.DescribeImagesOutput{Images: []ec2types.Image{ + {ImageId: aws.String("ami-1234567890"), State: ec2types.ImageStateAvailable, Architecture: ec2types.ArchitectureValuesX8664}, + }}, nil + } + + image, err := provider.resolveImage(map[string]struct{}{"x86_64": {}}) + if err != nil { + t.Fatalf("resolveImage() error: %v", err) + } + if aws.ToString(image.ImageId) != "ami-1234567890" { + t.Fatalf("unexpected selected image: %q", aws.ToString(image.ImageId)) + } +} + +func TestResolveImagePrefixUsesAvailabilityFilterOnlyAndNewest(t *testing.T) { + provider, fake := newTestProvider(t, Config{ + Region: "us-east-1", + Image: "pullpreview-app", + CAKey: mustTestCAKey(t), + CAKeyEnv: "PULLPREVIEW_CA_KEY", + SSHUsername: defaultEC2SSHUser, + }) + + fake.describeImagesFn = func(_ context.Context, input *ec2svc.DescribeImagesInput) (*ec2svc.DescribeImagesOutput, error) { + if len(input.Owners) != 2 || input.Owners[0] != "self" || input.Owners[1] != "amazon" { + return nil, fmt.Errorf("unexpected owners: %#v", input.Owners) + } + hasStateFilter := false + hasArchFilter := false + for _, filter := range input.Filters { + if aws.ToString(filter.Name) == "state" { + hasStateFilter = true + } + if aws.ToString(filter.Name) == "architecture" { + hasArchFilter = true + } + } + if !hasStateFilter { + return nil, fmt.Errorf("missing state=available filter") + } + if hasArchFilter { + return nil, fmt.Errorf("unexpected architecture filter") + } + return &ec2svc.DescribeImagesOutput{Images: []ec2types.Image{ + {ImageId: aws.String("ami-older"), State: ec2types.ImageStateAvailable, CreationDate: aws.String("2026-01-01T00:00:00Z"), Architecture: ec2types.ArchitectureValuesX8664}, + {ImageId: aws.String("ami-newest"), State: ec2types.ImageStateAvailable, CreationDate: aws.String("2026-01-02T00:00:00Z"), Architecture: ec2types.ArchitectureValuesX8664}, + }}, nil + } + + image, err := provider.resolveImage(map[string]struct{}{"x86_64": {}}) + if err != nil { + t.Fatalf("resolveImage() error: %v", err) + } + if aws.ToString(image.ImageId) != "ami-newest" { + t.Fatalf("expected newest AMI to be selected, got %q", aws.ToString(image.ImageId)) + } +} + +func TestResolveImageFailsFastOnArchMismatchWithoutFallback(t *testing.T) { + provider, fake := newTestProvider(t, Config{ + Region: "us-east-1", + Image: "pullpreview-app", + CAKey: mustTestCAKey(t), + CAKeyEnv: "PULLPREVIEW_CA_KEY", + SSHUsername: defaultEC2SSHUser, + }) + + fake.describeImagesFn = func(_ context.Context, _ *ec2svc.DescribeImagesInput) (*ec2svc.DescribeImagesOutput, error) { + return &ec2svc.DescribeImagesOutput{Images: []ec2types.Image{ + {ImageId: aws.String("ami-newest-arm"), State: ec2types.ImageStateAvailable, CreationDate: aws.String("2026-01-03T00:00:00Z"), Architecture: ec2types.ArchitectureValuesArm64}, + {ImageId: aws.String("ami-older-x86"), State: ec2types.ImageStateAvailable, CreationDate: aws.String("2026-01-01T00:00:00Z"), Architecture: ec2types.ArchitectureValuesX8664}, + }}, nil + } + + _, err := provider.resolveImage(map[string]struct{}{"x86_64": {}}) + if err == nil { + t.Fatalf("expected architecture mismatch failure") + } + if !strings.Contains(err.Error(), "ami-newest-arm") { + t.Fatalf("expected failure to reference newest incompatible AMI, got: %v", err) + } + if strings.Contains(err.Error(), "ami-older-x86") { + t.Fatalf("expected no fallback to older compatible AMI, got: %v", err) + } +} + +func TestFindTaggedPublicSubnetFailsWhenNoPublicSubnet(t *testing.T) { + provider, fake := newTestProvider(t, Config{ + Region: "us-east-1", + CAKey: mustTestCAKey(t), + CAKeyEnv: "PULLPREVIEW_CA_KEY", + SSHUsername: defaultEC2SSHUser, + }) + + fake.describeSubnetsFn = func(_ context.Context, _ *ec2svc.DescribeSubnetsInput) (*ec2svc.DescribeSubnetsOutput, error) { + return &ec2svc.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{ + {SubnetId: aws.String("subnet-private"), MapPublicIpOnLaunch: aws.Bool(false)}, + }}, nil + } + + _, err := provider.findTaggedPublicSubnet() + if err == nil { + t.Fatalf("expected missing tagged public subnet error") + } +} + +func TestLaunchReusesExistingInstance(t *testing.T) { + provider, fake := newTestProvider(t, Config{ + Region: "us-east-1", + CAKey: mustTestCAKey(t), + CAKeyEnv: "PULLPREVIEW_CA_KEY", + SSHUsername: defaultEC2SSHUser, + }) + instance := ec2types.Instance{ + InstanceId: aws.String("i-1"), + InstanceType: ec2types.InstanceTypeT3Small, + VpcId: aws.String("vpc-1"), + PublicIpAddress: aws.String("203.0.113.10"), + SecurityGroups: []ec2types.GroupIdentifier{{GroupId: aws.String("sg-1")}}, + State: &ec2types.InstanceState{Name: ec2types.InstanceStateNameRunning}, + Tags: []ec2types.Tag{{Key: aws.String("Name"), Value: aws.String("gh-1-pr-1")}, {Key: aws.String("pullpreview_instance_name"), Value: aws.String("gh-1-pr-1")}}, + LaunchTime: aws.Time(time.Now()), + Placement: &ec2types.Placement{AvailabilityZone: aws.String("us-east-1a")}, + PrivateIpAddress: aws.String("10.0.0.1"), + } + + fake.describeInstancesFn = func(_ context.Context, _ *ec2svc.DescribeInstancesInput) (*ec2svc.DescribeInstancesOutput, error) { + return &ec2svc.DescribeInstancesOutput{Reservations: []ec2types.Reservation{{Instances: []ec2types.Instance{instance}}}}, nil + } + fake.describeSecurityGroupsFn = func(_ context.Context, input *ec2svc.DescribeSecurityGroupsInput) (*ec2svc.DescribeSecurityGroupsOutput, error) { + if len(input.GroupIds) > 0 { + return &ec2svc.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{{GroupId: aws.String("sg-1"), IpPermissions: nil}}}, nil + } + return &ec2svc.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{{GroupId: aws.String("sg-1")}}}, nil + } + fake.authorizeSecurityGroupIngressFn = func(_ context.Context, _ *ec2svc.AuthorizeSecurityGroupIngressInput) (*ec2svc.AuthorizeSecurityGroupIngressOutput, error) { + return &ec2svc.AuthorizeSecurityGroupIngressOutput{}, nil + } + + origRunSSH := runSSHCommand + runSSHCommand = func(context.Context, string, string, string, string) ([]byte, error) { + return []byte("ok"), nil + } + defer func() { runSSHCommand = origRunSSH }() + + access, err := provider.Launch("gh-1-pr-1", pullpreview.LaunchOptions{Ports: []string{"80/tcp"}, CIDRs: []string{"0.0.0.0/0"}}) + if err != nil { + t.Fatalf("Launch() error: %v", err) + } + if access.IPAddress != "203.0.113.10" { + t.Fatalf("unexpected access ip: %q", access.IPAddress) + } + if strings.TrimSpace(access.CertKey) == "" { + t.Fatalf("expected cert key on reused launch") + } + if fake.runInstancesCalls != 0 { + t.Fatalf("expected no RunInstances calls, got %d", fake.runInstancesCalls) + } +} + +func TestLaunchRecreatesWhenExistingSSHValidationFails(t *testing.T) { + provider, fake := newTestProvider(t, Config{ + Region: "us-east-1", + CAKey: mustTestCAKey(t), + CAKeyEnv: "PULLPREVIEW_CA_KEY", + SSHUsername: defaultEC2SSHUser, + }) + provider.sshRetryCount = 1 + provider.sshRetryDelay = 1 * time.Millisecond + + oldInstance := ec2types.Instance{ + InstanceId: aws.String("i-old"), + InstanceType: ec2types.InstanceTypeT3Small, + VpcId: aws.String("vpc-1"), + PublicIpAddress: aws.String("203.0.113.10"), + SecurityGroups: []ec2types.GroupIdentifier{{GroupId: aws.String("sg-old")}}, + State: &ec2types.InstanceState{Name: ec2types.InstanceStateNameRunning}, + Tags: []ec2types.Tag{ + {Key: aws.String("Name"), Value: aws.String("gh-9-pr-9")}, + {Key: aws.String("pullpreview_instance_name"), Value: aws.String("gh-9-pr-9")}, + }, + LaunchTime: aws.Time(time.Now().Add(-time.Hour)), + Placement: &ec2types.Placement{AvailabilityZone: aws.String("us-east-1a")}, + } + newInstance := ec2types.Instance{ + InstanceId: aws.String("i-new"), + InstanceType: ec2types.InstanceTypeT3Small, + VpcId: aws.String("vpc-1"), + PublicIpAddress: aws.String("198.51.100.20"), + SecurityGroups: []ec2types.GroupIdentifier{{GroupId: aws.String("sg-new")}}, + State: &ec2types.InstanceState{Name: ec2types.InstanceStateNameRunning}, + Tags: []ec2types.Tag{ + {Key: aws.String("Name"), Value: aws.String("gh-9-pr-9")}, + {Key: aws.String("pullpreview_instance_name"), Value: aws.String("gh-9-pr-9")}, + }, + LaunchTime: aws.Time(time.Now()), + Placement: &ec2types.Placement{AvailabilityZone: aws.String("us-east-1a")}, + } + + state := "existing" + sgDeleted := false + + fake.describeInstancesFn = func(_ context.Context, input *ec2svc.DescribeInstancesInput) (*ec2svc.DescribeInstancesOutput, error) { + if len(input.InstanceIds) > 0 { + id := input.InstanceIds[0] + switch id { + case "i-old": + if state == "existing" { + return &ec2svc.DescribeInstancesOutput{Reservations: []ec2types.Reservation{{Instances: []ec2types.Instance{oldInstance}}}}, nil + } + return &ec2svc.DescribeInstancesOutput{}, nil + case "i-new": + if state == "created" { + return &ec2svc.DescribeInstancesOutput{Reservations: []ec2types.Reservation{{Instances: []ec2types.Instance{newInstance}}}}, nil + } + return &ec2svc.DescribeInstancesOutput{}, nil + } + return &ec2svc.DescribeInstancesOutput{}, nil + } + if state == "existing" { + return &ec2svc.DescribeInstancesOutput{Reservations: []ec2types.Reservation{{Instances: []ec2types.Instance{oldInstance}}}}, nil + } + if state == "created" { + return &ec2svc.DescribeInstancesOutput{Reservations: []ec2types.Reservation{{Instances: []ec2types.Instance{newInstance}}}}, nil + } + return &ec2svc.DescribeInstancesOutput{}, nil + } + fake.describeSecurityGroupsFn = func(_ context.Context, input *ec2svc.DescribeSecurityGroupsInput) (*ec2svc.DescribeSecurityGroupsOutput, error) { + if len(input.GroupIds) > 0 { + groupID := input.GroupIds[0] + return &ec2svc.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{{GroupId: aws.String(groupID)}}}, nil + } + if state == "existing" && !sgDeleted { + return &ec2svc.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{{GroupId: aws.String("sg-old")}}}, nil + } + if state == "created" { + return &ec2svc.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{{GroupId: aws.String("sg-new")}}}, nil + } + return &ec2svc.DescribeSecurityGroupsOutput{}, nil + } + fake.authorizeSecurityGroupIngressFn = func(_ context.Context, _ *ec2svc.AuthorizeSecurityGroupIngressInput) (*ec2svc.AuthorizeSecurityGroupIngressOutput, error) { + return &ec2svc.AuthorizeSecurityGroupIngressOutput{}, nil + } + fake.terminateInstancesFn = func(_ context.Context, _ *ec2svc.TerminateInstancesInput) (*ec2svc.TerminateInstancesOutput, error) { + state = "creating" + return &ec2svc.TerminateInstancesOutput{}, nil + } + fake.deleteSecurityGroupFn = func(_ context.Context, _ *ec2svc.DeleteSecurityGroupInput) (*ec2svc.DeleteSecurityGroupOutput, error) { + sgDeleted = true + return &ec2svc.DeleteSecurityGroupOutput{}, nil + } + fake.describeSubnetsFn = func(_ context.Context, _ *ec2svc.DescribeSubnetsInput) (*ec2svc.DescribeSubnetsOutput, error) { + return &ec2svc.DescribeSubnetsOutput{Subnets: []ec2types.Subnet{{SubnetId: aws.String("subnet-1"), VpcId: aws.String("vpc-1"), MapPublicIpOnLaunch: aws.Bool(true)}}}, nil + } + fake.describeImagesFn = func(_ context.Context, _ *ec2svc.DescribeImagesInput) (*ec2svc.DescribeImagesOutput, error) { + return &ec2svc.DescribeImagesOutput{Images: []ec2types.Image{ + {ImageId: aws.String("ami-123"), State: ec2types.ImageStateAvailable, CreationDate: aws.String("2026-01-02T00:00:00Z"), Architecture: ec2types.ArchitectureValuesX8664}, + }}, nil + } + fake.runInstancesFn = func(_ context.Context, _ *ec2svc.RunInstancesInput) (*ec2svc.RunInstancesOutput, error) { + state = "created" + return &ec2svc.RunInstancesOutput{Instances: []ec2types.Instance{{InstanceId: aws.String("i-new")}}}, nil + } + + origRunSSH := runSSHCommand + runSSHCommand = func(_ context.Context, _ string, certFile string, _ string, host string) ([]byte, error) { + if host == "203.0.113.10" && strings.TrimSpace(certFile) != "" { + return nil, fmt.Errorf("cert rejected") + } + return []byte("ok"), nil + } + defer func() { runSSHCommand = origRunSSH }() + + access, err := provider.Launch("gh-9-pr-9", pullpreview.LaunchOptions{Ports: []string{"80/tcp"}, CIDRs: []string{"0.0.0.0/0"}}) + if err != nil { + t.Fatalf("Launch() error: %v", err) + } + if access.IPAddress != "198.51.100.20" { + t.Fatalf("expected recreated instance IP, got %q", access.IPAddress) + } + if fake.runInstancesCalls == 0 { + t.Fatalf("expected create path to run after stale SSH failure") + } + if fake.terminateInstancesCalls == 0 { + t.Fatalf("expected stale instance termination") + } +} + +func TestTerminateAndListInstances(t *testing.T) { + provider, fake := newTestProvider(t, Config{ + Region: "us-east-1", + CAKey: mustTestCAKey(t), + CAKeyEnv: "PULLPREVIEW_CA_KEY", + SSHUsername: defaultEC2SSHUser, + }) + instance := ec2types.Instance{ + InstanceId: aws.String("i-terminate"), + InstanceType: ec2types.InstanceTypeT3Small, + PublicIpAddress: aws.String("198.51.100.8"), + VpcId: aws.String("vpc-1"), + State: &ec2types.InstanceState{Name: ec2types.InstanceStateNameRunning}, + Tags: []ec2types.Tag{ + {Key: aws.String("Name"), Value: aws.String("gh-2-pr-5")}, + {Key: aws.String("pullpreview_instance_name"), Value: aws.String("gh-2-pr-5")}, + {Key: aws.String("repo_name"), Value: aws.String("action")}, + {Key: aws.String("org_name"), Value: aws.String("pullpreview")}, + {Key: aws.String("stack"), Value: aws.String(pullpreview.StackName)}, + }, + LaunchTime: aws.Time(time.Unix(0, 0)), + Placement: &ec2types.Placement{AvailabilityZone: aws.String("us-east-1a")}, + } + + fake.describeInstancesFn = func(_ context.Context, input *ec2svc.DescribeInstancesInput) (*ec2svc.DescribeInstancesOutput, error) { + if len(input.InstanceIds) > 0 { + return &ec2svc.DescribeInstancesOutput{}, nil + } + return &ec2svc.DescribeInstancesOutput{Reservations: []ec2types.Reservation{{Instances: []ec2types.Instance{instance}}}}, nil + } + fake.terminateInstancesFn = func(_ context.Context, _ *ec2svc.TerminateInstancesInput) (*ec2svc.TerminateInstancesOutput, error) { + fake.terminateInstancesCalls++ + return &ec2svc.TerminateInstancesOutput{}, nil + } + fake.describeSecurityGroupsFn = func(_ context.Context, _ *ec2svc.DescribeSecurityGroupsInput) (*ec2svc.DescribeSecurityGroupsOutput, error) { + return &ec2svc.DescribeSecurityGroupsOutput{SecurityGroups: []ec2types.SecurityGroup{{GroupId: aws.String("sg-x")}}}, nil + } + fake.deleteSecurityGroupFn = func(_ context.Context, _ *ec2svc.DeleteSecurityGroupInput) (*ec2svc.DeleteSecurityGroupOutput, error) { + fake.deleteSecurityGroupCalls++ + return &ec2svc.DeleteSecurityGroupOutput{}, nil + } + + if err := provider.Terminate("gh-2-pr-5"); err != nil { + t.Fatalf("Terminate() error: %v", err) + } + if fake.terminateInstancesCalls == 0 { + t.Fatalf("expected TerminateInstances call") + } + if fake.deleteSecurityGroupCalls == 0 { + t.Fatalf("expected security group cleanup call") + } + + instances, err := provider.ListInstances(map[string]string{"stack": pullpreview.StackName}) + if err != nil { + t.Fatalf("ListInstances() error: %v", err) + } + if len(instances) != 1 || instances[0].Name != "gh-2-pr-5" { + t.Fatalf("unexpected list result: %#v", instances) + } +} + +func newTestProvider(t *testing.T, cfg Config) (*Provider, *fakeEC2Client) { + t.Helper() + if strings.TrimSpace(cfg.CAKey) == "" { + cfg.CAKey = mustTestCAKey(t) + } + if cfg.CAKeyEnv == "" { + cfg.CAKeyEnv = "PULLPREVIEW_CA_KEY" + } + if cfg.Region == "" { + cfg.Region = "us-east-1" + } + if cfg.SSHUsername == "" { + cfg.SSHUsername = defaultEC2SSHUser + } + fake := &fakeEC2Client{} + provider, err := newProviderWithClient(context.Background(), cfg, nil, fake) + if err != nil { + t.Fatalf("newProviderWithClient() error: %v", err) + } + return provider, fake +} + +func mustTestCAKey(t *testing.T) string { + t.Helper() + _, privateKey, _, err := sshca.GenerateSSHKeyPairWithSigner() + if err != nil { + t.Fatalf("GenerateSSHKeyPairWithSigner() error: %v", err) + } + return privateKey +} + +type fakeEC2Client struct { + describeInstancesFn func(context.Context, *ec2svc.DescribeInstancesInput) (*ec2svc.DescribeInstancesOutput, error) + runInstancesFn func(context.Context, *ec2svc.RunInstancesInput) (*ec2svc.RunInstancesOutput, error) + terminateInstancesFn func(context.Context, *ec2svc.TerminateInstancesInput) (*ec2svc.TerminateInstancesOutput, error) + startInstancesFn func(context.Context, *ec2svc.StartInstancesInput) (*ec2svc.StartInstancesOutput, error) + modifyInstanceAttributeFn func(context.Context, *ec2svc.ModifyInstanceAttributeInput) (*ec2svc.ModifyInstanceAttributeOutput, error) + describeSubnetsFn func(context.Context, *ec2svc.DescribeSubnetsInput) (*ec2svc.DescribeSubnetsOutput, error) + describeInstanceTypesFn func(context.Context, *ec2svc.DescribeInstanceTypesInput) (*ec2svc.DescribeInstanceTypesOutput, error) + describeImagesFn func(context.Context, *ec2svc.DescribeImagesInput) (*ec2svc.DescribeImagesOutput, error) + describeSecurityGroupsFn func(context.Context, *ec2svc.DescribeSecurityGroupsInput) (*ec2svc.DescribeSecurityGroupsOutput, error) + createSecurityGroupFn func(context.Context, *ec2svc.CreateSecurityGroupInput) (*ec2svc.CreateSecurityGroupOutput, error) + deleteSecurityGroupFn func(context.Context, *ec2svc.DeleteSecurityGroupInput) (*ec2svc.DeleteSecurityGroupOutput, error) + authorizeSecurityGroupIngressFn func(context.Context, *ec2svc.AuthorizeSecurityGroupIngressInput) (*ec2svc.AuthorizeSecurityGroupIngressOutput, error) + revokeSecurityGroupIngressFn func(context.Context, *ec2svc.RevokeSecurityGroupIngressInput) (*ec2svc.RevokeSecurityGroupIngressOutput, error) + createTagsFn func(context.Context, *ec2svc.CreateTagsInput) (*ec2svc.CreateTagsOutput, error) + createKeyPairFn func(context.Context, *ec2svc.CreateKeyPairInput) (*ec2svc.CreateKeyPairOutput, error) + deleteKeyPairFn func(context.Context, *ec2svc.DeleteKeyPairInput) (*ec2svc.DeleteKeyPairOutput, error) + + runInstancesCalls int + terminateInstancesCalls int + deleteSecurityGroupCalls int +} + +func (f *fakeEC2Client) DescribeInstances(ctx context.Context, input *ec2svc.DescribeInstancesInput) (*ec2svc.DescribeInstancesOutput, error) { + if f.describeInstancesFn != nil { + return f.describeInstancesFn(ctx, input) + } + return &ec2svc.DescribeInstancesOutput{}, nil +} + +func (f *fakeEC2Client) RunInstances(ctx context.Context, input *ec2svc.RunInstancesInput) (*ec2svc.RunInstancesOutput, error) { + f.runInstancesCalls++ + if f.runInstancesFn != nil { + return f.runInstancesFn(ctx, input) + } + return &ec2svc.RunInstancesOutput{}, nil +} + +func (f *fakeEC2Client) TerminateInstances(ctx context.Context, input *ec2svc.TerminateInstancesInput) (*ec2svc.TerminateInstancesOutput, error) { + f.terminateInstancesCalls++ + if f.terminateInstancesFn != nil { + return f.terminateInstancesFn(ctx, input) + } + return &ec2svc.TerminateInstancesOutput{}, nil +} + +func (f *fakeEC2Client) StartInstances(ctx context.Context, input *ec2svc.StartInstancesInput) (*ec2svc.StartInstancesOutput, error) { + if f.startInstancesFn != nil { + return f.startInstancesFn(ctx, input) + } + return &ec2svc.StartInstancesOutput{}, nil +} + +func (f *fakeEC2Client) ModifyInstanceAttribute(ctx context.Context, input *ec2svc.ModifyInstanceAttributeInput) (*ec2svc.ModifyInstanceAttributeOutput, error) { + if f.modifyInstanceAttributeFn != nil { + return f.modifyInstanceAttributeFn(ctx, input) + } + return &ec2svc.ModifyInstanceAttributeOutput{}, nil +} + +func (f *fakeEC2Client) DescribeSubnets(ctx context.Context, input *ec2svc.DescribeSubnetsInput) (*ec2svc.DescribeSubnetsOutput, error) { + if f.describeSubnetsFn != nil { + return f.describeSubnetsFn(ctx, input) + } + return &ec2svc.DescribeSubnetsOutput{}, nil +} + +func (f *fakeEC2Client) DescribeInstanceTypes(ctx context.Context, input *ec2svc.DescribeInstanceTypesInput) (*ec2svc.DescribeInstanceTypesOutput, error) { + if f.describeInstanceTypesFn != nil { + return f.describeInstanceTypesFn(ctx, input) + } + return &ec2svc.DescribeInstanceTypesOutput{InstanceTypes: []ec2types.InstanceTypeInfo{{ProcessorInfo: &ec2types.ProcessorInfo{SupportedArchitectures: []ec2types.ArchitectureType{ec2types.ArchitectureTypeX8664}}}}}, nil +} + +func (f *fakeEC2Client) DescribeImages(ctx context.Context, input *ec2svc.DescribeImagesInput) (*ec2svc.DescribeImagesOutput, error) { + if f.describeImagesFn != nil { + return f.describeImagesFn(ctx, input) + } + return &ec2svc.DescribeImagesOutput{}, nil +} + +func (f *fakeEC2Client) DescribeSecurityGroups(ctx context.Context, input *ec2svc.DescribeSecurityGroupsInput) (*ec2svc.DescribeSecurityGroupsOutput, error) { + if f.describeSecurityGroupsFn != nil { + return f.describeSecurityGroupsFn(ctx, input) + } + return &ec2svc.DescribeSecurityGroupsOutput{}, nil +} + +func (f *fakeEC2Client) CreateSecurityGroup(ctx context.Context, input *ec2svc.CreateSecurityGroupInput) (*ec2svc.CreateSecurityGroupOutput, error) { + if f.createSecurityGroupFn != nil { + return f.createSecurityGroupFn(ctx, input) + } + return &ec2svc.CreateSecurityGroupOutput{GroupId: aws.String("sg-created")}, nil +} + +func (f *fakeEC2Client) DeleteSecurityGroup(ctx context.Context, input *ec2svc.DeleteSecurityGroupInput) (*ec2svc.DeleteSecurityGroupOutput, error) { + f.deleteSecurityGroupCalls++ + if f.deleteSecurityGroupFn != nil { + return f.deleteSecurityGroupFn(ctx, input) + } + return &ec2svc.DeleteSecurityGroupOutput{}, nil +} + +func (f *fakeEC2Client) AuthorizeSecurityGroupIngress(ctx context.Context, input *ec2svc.AuthorizeSecurityGroupIngressInput) (*ec2svc.AuthorizeSecurityGroupIngressOutput, error) { + if f.authorizeSecurityGroupIngressFn != nil { + return f.authorizeSecurityGroupIngressFn(ctx, input) + } + return &ec2svc.AuthorizeSecurityGroupIngressOutput{}, nil +} + +func (f *fakeEC2Client) RevokeSecurityGroupIngress(ctx context.Context, input *ec2svc.RevokeSecurityGroupIngressInput) (*ec2svc.RevokeSecurityGroupIngressOutput, error) { + if f.revokeSecurityGroupIngressFn != nil { + return f.revokeSecurityGroupIngressFn(ctx, input) + } + return &ec2svc.RevokeSecurityGroupIngressOutput{}, nil +} + +func (f *fakeEC2Client) CreateTags(ctx context.Context, input *ec2svc.CreateTagsInput) (*ec2svc.CreateTagsOutput, error) { + if f.createTagsFn != nil { + return f.createTagsFn(ctx, input) + } + return &ec2svc.CreateTagsOutput{}, nil +} + +func (f *fakeEC2Client) CreateKeyPair(ctx context.Context, input *ec2svc.CreateKeyPairInput) (*ec2svc.CreateKeyPairOutput, error) { + if f.createKeyPairFn != nil { + return f.createKeyPairFn(ctx, input) + } + return &ec2svc.CreateKeyPairOutput{KeyMaterial: aws.String("PRIVATE")}, nil +} + +func (f *fakeEC2Client) DeleteKeyPair(ctx context.Context, input *ec2svc.DeleteKeyPairInput) (*ec2svc.DeleteKeyPairOutput, error) { + if f.deleteKeyPairFn != nil { + return f.deleteKeyPairFn(ctx, input) + } + return &ec2svc.DeleteKeyPairOutput{}, nil +} diff --git a/internal/providers/ec2/provider.go b/internal/providers/ec2/provider.go new file mode 100644 index 0000000..db6a23c --- /dev/null +++ b/internal/providers/ec2/provider.go @@ -0,0 +1,98 @@ +package ec2 + +import ( + "context" + "fmt" + "strings" + + awsconfig "github.com/aws/aws-sdk-go-v2/config" + ec2svc "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/pullpreview/action/internal/providers" + "github.com/pullpreview/action/internal/providers/sshca" + "github.com/pullpreview/action/internal/pullpreview" +) + +const ( + defaultEC2Region = "us-east-1" + defaultEC2ImagePrefix = "al2023-ami-2023" + defaultEC2InstanceType = "t3.small" + defaultEC2SSHUser = "ec2-user" +) + +type Config struct { + Region string + Image string + CAKey string + CAKeyEnv string + SSHUsername string +} + +func (c Config) ProviderName() string { + return "ec2" +} + +func (c Config) ProviderDisplayName() string { + return "AWS EC2" +} + +func (c Config) Validate() error { + if strings.TrimSpace(c.Region) == "" { + return fmt.Errorf("AWS region is required") + } + if strings.TrimSpace(c.CAKey) == "" { + return fmt.Errorf("PULLPREVIEW_CA_KEY is required for provider=ec2") + } + if strings.TrimSpace(c.SSHUsername) == "" { + return fmt.Errorf("ssh username is required") + } + return nil +} + +func ParseConfigFromEnv(env map[string]string) (pullpreview.ProviderConfig, error) { + region := strings.TrimSpace(env["REGION"]) + if region == "" { + region = strings.TrimSpace(env["AWS_REGION"]) + } + if region == "" { + region = defaultEC2Region + } + caResolution := sshca.ResolveFromEnv(env, "PULLPREVIEW_CA_KEY") + cfg := Config{ + Region: region, + Image: strings.TrimSpace(env["IMAGE"]), + CAKey: strings.TrimSpace(caResolution.Value), + CAKeyEnv: caResolution.EnvKey, + SSHUsername: defaultEC2SSHUser, + } + if _, err := sshca.Parse(cfg.CAKey, cfg.CAKeyEnv); err != nil { + return cfg, err + } + return cfg, cfg.Validate() +} + +func New(ctx context.Context, cfg Config, logger *pullpreview.Logger) (*Provider, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + awsCfg, err := awsconfig.LoadDefaultConfig(pullpreview.EnsureContext(ctx), awsconfig.WithRegion(cfg.Region)) + if err != nil { + return nil, err + } + return newProviderWithClient(ctx, cfg, logger, ec2ClientAdapter{client: ec2svc.NewFromConfig(awsCfg)}) +} + +func NewWithContext(ctx context.Context, cfg pullpreview.ProviderConfig, logger *pullpreview.Logger) (pullpreview.Provider, error) { + typed, ok := cfg.(Config) + if !ok { + pointer, ok := cfg.(*Config) + if !ok { + return nil, fmt.Errorf("invalid ec2 configuration type") + } + typed = *pointer + } + return New(ctx, typed, logger) +} + +func init() { + _ = providers.RegisterProvider("ec2", "AWS EC2", ParseConfigFromEnv, NewWithContext) +} diff --git a/internal/providers/hetzner/hetzner.go b/internal/providers/hetzner/hetzner.go index a51c361..46f8349 100644 --- a/internal/providers/hetzner/hetzner.go +++ b/internal/providers/hetzner/hetzner.go @@ -2,11 +2,7 @@ package hetzner import ( "context" - "crypto/rand" - "crypto/rsa" - "crypto/x509" "encoding/json" - "encoding/pem" "fmt" "net" "os" @@ -22,6 +18,7 @@ import ( "github.com/hetznercloud/hcloud-go/v2/hcloud" "github.com/pullpreview/action/internal/providers" + "github.com/pullpreview/action/internal/providers/sshca" "github.com/pullpreview/action/internal/pullpreview" ) @@ -147,6 +144,8 @@ type Config struct { Location string Image string CAKey string + CAKeyEnv string + UsedLegacyCAKey bool SSHUsername string SSHKeysCacheDir string } @@ -172,7 +171,7 @@ func (c Config) Validate() error { return fmt.Errorf("HCLOUD_TOKEN is required") } if strings.TrimSpace(c.CAKey) == "" { - return fmt.Errorf("HETZNER_CA_KEY is required") + return fmt.Errorf("PULLPREVIEW_CA_KEY is required (legacy HETZNER_CA_KEY is also supported for provider=hetzner)") } if strings.TrimSpace(c.Location) == "" { return fmt.Errorf("location is required") @@ -193,7 +192,8 @@ func ParseConfigFromEnv(env map[string]string) (pullpreview.ProviderConfig, erro if image == "" { image = defaultHetznerImage } - caKey := strings.TrimSpace(env["HETZNER_CA_KEY"]) + caResolution := sshca.ResolveFromEnv(env, "PULLPREVIEW_CA_KEY", "HETZNER_CA_KEY") + caKey := strings.TrimSpace(caResolution.Value) sshUser := defaultHetznerSSHUser sshKeysCacheDir := strings.TrimSpace(env["PULLPREVIEW_SSH_KEYS_CACHE_DIR"]) cfg := Config{ @@ -201,10 +201,12 @@ func ParseConfigFromEnv(env map[string]string) (pullpreview.ProviderConfig, erro Location: location, Image: image, CAKey: caKey, + CAKeyEnv: caResolution.EnvKey, + UsedLegacyCAKey: caResolution.UsedLegacy, SSHUsername: sshUser, SSHKeysCacheDir: sshKeysCacheDir, } - if _, _, _, _, err := parseHetznerCAKey(caKey); err != nil { + if _, _, _, _, err := parseHetznerCAKey(caKey, cfg.CAKeyEnv); err != nil { return cfg, err } return cfg, cfg.Validate() @@ -249,11 +251,14 @@ func newProviderWithContext(ctx context.Context, cfg Config, logger *pullpreview if client == nil { return nil, fmt.Errorf("client cannot be nil") } - caSigner, caPublicKey, caSource, _, err := parseHetznerCAKey(cfg.CAKey) + caSigner, caPublicKey, caSource, _, err := parseHetznerCAKey(cfg.CAKey, cfg.CAKeyEnv) if err != nil { return nil, err } if logger != nil { + if cfg.UsedLegacyCAKey { + logger.Warnf("HETZNER_CA_KEY is deprecated; use PULLPREVIEW_CA_KEY instead") + } logger.Infof("Hetzner SSH CA pre-check passed (%s)", caSource) } return &Provider{ @@ -827,78 +832,22 @@ func (p *Provider) generateSignedAccessCredentials() (string, string, error) { if err != nil { return "", "", err } - cert, err := generateUserCertificate(p.caSigner, signer, p.sshUser, defaultHetznerSSHCertTTL) + cert, err := sshca.GenerateUserCertificate(p.caSigner, signer, p.sshUser, defaultHetznerSSHCertTTL) if err != nil { return "", "", err } return privateKey, cert, nil } -func generateUserCertificate(caSigner ssh.Signer, userSigner ssh.Signer, principal string, ttl time.Duration) (string, error) { - if caSigner == nil { - return "", fmt.Errorf("missing CA signer") - } - if userSigner == nil { - return "", fmt.Errorf("missing user signer") - } - publicKey := userSigner.PublicKey() - if publicKey == nil { - return "", fmt.Errorf("user signer has no public key") - } - cert := &ssh.Certificate{ - Key: publicKey, - Serial: uint64(time.Now().UnixNano()), - CertType: ssh.UserCert, - KeyId: fmt.Sprintf("pullpreview-%s-%d", sanitizeNameForHetzner(principal), time.Now().UnixNano()), - ValidPrincipals: []string{principal}, - ValidAfter: uint64(time.Now().Add(-time.Minute).Unix()), - ValidBefore: uint64(time.Now().Add(ttl).Unix()), - } - if err := cert.SignCert(rand.Reader, caSigner); err != nil { - return "", err - } - return strings.TrimSpace(string(ssh.MarshalAuthorizedKey(cert))), nil -} - -func parseHetznerCAKey(raw string) (ssh.Signer, string, string, bool, error) { - raw = strings.TrimSpace(raw) - if raw == "" { - return nil, "", "", false, fmt.Errorf("HETZNER_CA_KEY is required") - } - - caSource := "inline HETZNER_CA_KEY" - caSourceFromFile := false - data := []byte(raw) - - if info, err := os.Stat(raw); err == nil { - if info.IsDir() { - return nil, "", "", false, fmt.Errorf("HETZNER_CA_KEY %q refers to a directory", raw) - } - caSource = raw - caSourceFromFile = true - data, err = os.ReadFile(raw) - if err != nil { - return nil, "", "", false, fmt.Errorf("failed to read HETZNER_CA_KEY from %q: %w", raw, err) - } +func parseHetznerCAKey(raw string, envName string) (ssh.Signer, string, string, bool, error) { + if strings.TrimSpace(envName) == "" { + envName = "PULLPREVIEW_CA_KEY" } - - signer, err := ssh.ParsePrivateKey(data) + parsed, err := sshca.Parse(raw, envName) if err != nil { - prefix := "inline HETZNER_CA_KEY" - if caSourceFromFile { - prefix = fmt.Sprintf("HETZNER_CA_KEY file %q", caSource) - } - return nil, "", caSource, caSourceFromFile, fmt.Errorf("invalid %s: %w", prefix, err) + return nil, "", "", false, err } - publicKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey()))) - if publicKey == "" { - errPrefix := "inline HETZNER_CA_KEY" - if caSourceFromFile { - errPrefix = fmt.Sprintf("HETZNER_CA_KEY file %q", caSource) - } - return nil, "", caSource, caSourceFromFile, fmt.Errorf("invalid %s: unable to derive public key", errPrefix) - } - return signer, publicKey, caSource, caSourceFromFile, nil + return parsed.Signer, parsed.PublicKey, parsed.Source, parsed.SourceFromFile, nil } func (p *Provider) cachePath(name string) string { @@ -1056,26 +1005,7 @@ func generateSSHKeyPair(_ string) (string, string, error) { } func generateSSHKeyPairWithSigner(_ string) (string, string, ssh.Signer, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return "", "", nil, err - } - private := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - }) - if private == nil { - return "", "", nil, fmt.Errorf("unable to marshal private key") - } - public, err := ssh.NewPublicKey(&key.PublicKey) - if err != nil { - return "", "", nil, err - } - signer, err := ssh.NewSignerFromKey(key) - if err != nil { - return "", "", nil, err - } - return strings.TrimSpace(string(ssh.MarshalAuthorizedKey(public))), strings.TrimSpace(string(private)), signer, nil + return sshca.GenerateSSHKeyPairWithSigner() } func parseFirewallRules(ports, cidrs []string) ([]hcloud.FirewallRule, error) { diff --git a/internal/providers/hetzner/hetzner_test.go b/internal/providers/hetzner/hetzner_test.go index 39a8f06..80ba204 100644 --- a/internal/providers/hetzner/hetzner_test.go +++ b/internal/providers/hetzner/hetzner_test.go @@ -20,10 +20,10 @@ func TestParseConfigFromEnv(t *testing.T) { t.Fatalf("failed to generate ca key: %v", err) } cfgRaw, err := ParseConfigFromEnv(map[string]string{ - "HCLOUD_TOKEN": "abc", - "REGION": "fra1", - "IMAGE": "debian-12", - "HETZNER_CA_KEY": caKey, + "HCLOUD_TOKEN": "abc", + "REGION": "fra1", + "IMAGE": "debian-12", + "PULLPREVIEW_CA_KEY": caKey, }) if err != nil { t.Fatalf("ParseConfigFromEnv() error: %v", err) @@ -54,6 +54,25 @@ func TestParseConfigFromEnv(t *testing.T) { if cfg.Location != defaultHetznerLocation || cfg.Image != defaultHetznerImage || cfg.CAKey != caKeyFile || cfg.SSHUsername != defaultHetznerSSHUser { t.Fatalf("expected defaults and file-backed CA key path, got %#v", cfg) } + if !cfg.UsedLegacyCAKey { + t.Fatalf("expected legacy env fallback to be marked") + } + + cfgRaw, err = ParseConfigFromEnv(map[string]string{ + "HCLOUD_TOKEN": "priority", + "PULLPREVIEW_CA_KEY": caKey, + "HETZNER_CA_KEY": "not-a-key", + }) + if err != nil { + t.Fatalf("ParseConfigFromEnv() with canonical priority error: %v", err) + } + cfg = cfgRaw.(Config) + if cfg.CAKey != caKey { + t.Fatalf("expected canonical CA key precedence, got %q", cfg.CAKey) + } + if cfg.UsedLegacyCAKey { + t.Fatalf("did not expect legacy marker when canonical env is present") + } if _, err := ParseConfigFromEnv(map[string]string{"HCLOUD_TOKEN": "fallback", "HETZNER_CA_KEY": ""}); err == nil { t.Fatalf("expected missing CA key error") @@ -61,10 +80,10 @@ func TestParseConfigFromEnv(t *testing.T) { if _, err := ParseConfigFromEnv(map[string]string{"HCLOUD_TOKEN": "fallback"}); err == nil { t.Fatalf("expected missing CA key error") } - if _, err := ParseConfigFromEnv(map[string]string{"HETZNER_CA_KEY": caKey}); err == nil { + if _, err := ParseConfigFromEnv(map[string]string{"PULLPREVIEW_CA_KEY": caKey}); err == nil { t.Fatalf("expected missing token error") } - if _, err := ParseConfigFromEnv(map[string]string{"HCLOUD_TOKEN": "fallback", "HETZNER_CA_KEY": "not-a-key"}); err == nil { + if _, err := ParseConfigFromEnv(map[string]string{"HCLOUD_TOKEN": "fallback", "PULLPREVIEW_CA_KEY": "not-a-key"}); err == nil { t.Fatalf("expected invalid CA key error") } } diff --git a/internal/providers/sshca/sshca.go b/internal/providers/sshca/sshca.go new file mode 100644 index 0000000..4fdb7b4 --- /dev/null +++ b/internal/providers/sshca/sshca.go @@ -0,0 +1,177 @@ +package sshca + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "strings" + "time" + + "golang.org/x/crypto/ssh" +) + +type Resolution struct { + Value string + EnvKey string + UsedLegacy bool +} + +type Parsed struct { + Signer ssh.Signer + PublicKey string + Source string + SourceFromFile bool +} + +func ResolveFromEnv(env map[string]string, canonical string, legacy ...string) Resolution { + canonical = strings.TrimSpace(canonical) + if canonical != "" { + if value := strings.TrimSpace(env[canonical]); value != "" { + return Resolution{Value: value, EnvKey: canonical} + } + } + for _, key := range legacy { + key = strings.TrimSpace(key) + if key == "" { + continue + } + if value := strings.TrimSpace(env[key]); value != "" { + return Resolution{Value: value, EnvKey: key, UsedLegacy: true} + } + } + return Resolution{EnvKey: canonical} +} + +func Parse(raw string, envName string) (Parsed, error) { + raw = strings.TrimSpace(raw) + envName = strings.TrimSpace(envName) + if envName == "" { + envName = "PULLPREVIEW_CA_KEY" + } + if raw == "" { + return Parsed{}, fmt.Errorf("%s is required", envName) + } + + source := "inline " + envName + sourceFromFile := false + data := []byte(raw) + + if info, err := os.Stat(raw); err == nil { + if info.IsDir() { + return Parsed{}, fmt.Errorf("%s %q refers to a directory", envName, raw) + } + source = raw + sourceFromFile = true + data, err = os.ReadFile(raw) + if err != nil { + return Parsed{}, fmt.Errorf("failed to read %s from %q: %w", envName, raw, err) + } + } + + signer, err := ssh.ParsePrivateKey(data) + if err != nil { + prefix := "inline " + envName + if sourceFromFile { + prefix = fmt.Sprintf("%s file %q", envName, source) + } + return Parsed{}, fmt.Errorf("invalid %s: %w", prefix, err) + } + publicKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey()))) + if publicKey == "" { + errPrefix := "inline " + envName + if sourceFromFile { + errPrefix = fmt.Sprintf("%s file %q", envName, source) + } + return Parsed{}, fmt.Errorf("invalid %s: unable to derive public key", errPrefix) + } + + return Parsed{ + Signer: signer, + PublicKey: publicKey, + Source: source, + SourceFromFile: sourceFromFile, + }, nil +} + +func GenerateSSHKeyPairWithSigner() (string, string, ssh.Signer, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", nil, err + } + private := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + if private == nil { + return "", "", nil, fmt.Errorf("unable to marshal private key") + } + public, err := ssh.NewPublicKey(&key.PublicKey) + if err != nil { + return "", "", nil, err + } + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + return "", "", nil, err + } + return strings.TrimSpace(string(ssh.MarshalAuthorizedKey(public))), strings.TrimSpace(string(private)), signer, nil +} + +func GenerateUserCertificate(caSigner ssh.Signer, userSigner ssh.Signer, principal string, ttl time.Duration) (string, error) { + if caSigner == nil { + return "", fmt.Errorf("missing CA signer") + } + if userSigner == nil { + return "", fmt.Errorf("missing user signer") + } + publicKey := userSigner.PublicKey() + if publicKey == nil { + return "", fmt.Errorf("user signer has no public key") + } + principal = strings.TrimSpace(principal) + if principal == "" { + principal = "user" + } + if ttl <= 0 { + ttl = time.Hour + } + cert := &ssh.Certificate{ + Key: publicKey, + Serial: uint64(time.Now().UnixNano()), + CertType: ssh.UserCert, + KeyId: fmt.Sprintf("pullpreview-%s-%d", sanitizePrincipal(principal), time.Now().UnixNano()), + ValidPrincipals: []string{principal}, + ValidAfter: uint64(time.Now().Add(-time.Minute).Unix()), + ValidBefore: uint64(time.Now().Add(ttl).Unix()), + } + if err := cert.SignCert(rand.Reader, caSigner); err != nil { + return "", err + } + return strings.TrimSpace(string(ssh.MarshalAuthorizedKey(cert))), nil +} + +func sanitizePrincipal(value string) string { + value = strings.TrimSpace(strings.ToLower(value)) + if value == "" { + return "user" + } + value = strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z': + return r + case r >= '0' && r <= '9': + return r + case r == '-': + return r + default: + return '-' + } + }, value) + value = strings.Trim(value, "-") + if value == "" { + value = "user" + } + return value +}