From c973fc79a6eeb5196a9bc6ea810bdf520d9e78e7 Mon Sep 17 00:00:00 2001 From: Max Anderson Date: Mon, 1 Jun 2026 11:33:54 -0400 Subject: [PATCH 1/2] feat: implement the jotsmith OIDC issuer CLI Initial implementation of the full command surface from PRD.md: - setup: configure an existing Storage Account + Key Vault as an OIDC issuer (enable static website hosting, create the RSA signing key, publish the discovery document and JWKS). Idempotent. - token mint: assemble standard and custom claims and sign inside Key Vault (RS256); stdout is the compact JWT and nothing else. - token verify: live discovery + JWKS round-trip with RS256 signature and standard-claim checks (60s clock skew). - token decode: inspect a JWT without verifying it. - key rotate: snap-cutover rotation (ADR-0005). - doctor: read-only audit with --repair (re-upload docs / re-enable static website) and --json output. - discovery show / jwks show / config show. - destroy: tear down the issuer state, keeping the Azure resources. - completion: bash/zsh/fish/powershell scripts. Private key material never leaves Key Vault; iss always comes from config; kid is the RFC 7638 thumbprint. Built on urfave/cli v3 and the modular Azure SDK with DefaultAzureCredential. Includes unit tests throughout plus integration tests behind the integration build tag, CI, golangci-lint v2, and goreleaser config. OpenCode session ID: ses_17c986ef9ffeHc0S3r1wQcmLDT --- .github/workflows/ci.yml | 39 +++ .gitignore | 13 + .golangci.yml | 26 ++ .goreleaser.yaml | 55 ++++ AGENTS.md | 14 ++ cmd/jotsmith/main.go | 40 +++ docs/agents/domain.md | 38 +++ docs/agents/issue-tracker.md | 29 +++ docs/agents/triage-labels.md | 15 ++ go.mod | 28 +++ go.sum | 59 +++++ internal/azurex/destroy.go | 44 ++++ internal/azurex/doc.go | 10 + internal/azurex/errors.go | 50 ++++ internal/azurex/integration_test.go | 114 +++++++++ internal/azurex/manager.go | 96 +++++++ internal/azurex/provider.go | 345 +++++++++++++++++++++++++ internal/azurex/signer.go | 32 +++ internal/azurex/types.go | 130 ++++++++++ internal/cli/azure.go | 25 ++ internal/cli/completion.go | 95 +++++++ internal/cli/completion_test.go | 59 +++++ internal/cli/config.go | 68 +++++ internal/cli/config_test.go | 82 ++++++ internal/cli/context.go | 27 ++ internal/cli/destroy.go | 112 +++++++++ internal/cli/destroy_test.go | 157 ++++++++++++ internal/cli/discovery.go | 45 ++++ internal/cli/discovery_jwks_test.go | 58 +++++ internal/cli/doc.go | 7 + internal/cli/doctor.go | 374 ++++++++++++++++++++++++++++ internal/cli/doctor_repair_test.go | 165 ++++++++++++ internal/cli/doctor_test.go | 206 +++++++++++++++ internal/cli/errors.go | 44 ++++ internal/cli/integration_test.go | 134 ++++++++++ internal/cli/io.go | 24 ++ internal/cli/jwks.go | 66 +++++ internal/cli/key.go | 132 ++++++++++ internal/cli/key_test.go | 143 +++++++++++ internal/cli/logging.go | 137 ++++++++++ internal/cli/prompt.go | 28 +++ internal/cli/root.go | 104 ++++++++ internal/cli/root_test.go | 120 +++++++++ internal/cli/setup.go | 234 +++++++++++++++++ internal/cli/setup_test.go | 229 +++++++++++++++++ internal/cli/token.go | 371 +++++++++++++++++++++++++++ internal/cli/token_claims_test.go | 201 +++++++++++++++ internal/cli/token_mint_test.go | 277 ++++++++++++++++++++ internal/cli/token_test.go | 100 ++++++++ internal/cli/verify.go | 120 +++++++++ internal/cli/verify_test.go | 142 +++++++++++ internal/cli/version.go | 20 ++ internal/config/config.go | 172 +++++++++++++ internal/config/config_test.go | 154 ++++++++++++ internal/config/doc.go | 6 + internal/config/errors.go | 35 +++ internal/jwk/doc.go | 6 + internal/jwk/jwk.go | 89 +++++++ internal/jwk/jwk_test.go | 67 +++++ internal/oidc/doc.go | 4 + internal/oidc/oidc.go | 40 +++ internal/oidc/oidc_test.go | 52 ++++ internal/sign/b64.go | 11 + internal/sign/decode.go | 83 ++++++ internal/sign/decode_test.go | 116 +++++++++ internal/sign/doc.go | 9 + internal/sign/mint.go | 68 +++++ internal/sign/mint_test.go | 110 ++++++++ internal/sign/verify.go | 185 ++++++++++++++ internal/sign/verify_test.go | 198 +++++++++++++++ 70 files changed, 6688 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 .golangci.yml create mode 100644 .goreleaser.yaml create mode 100644 cmd/jotsmith/main.go create mode 100644 docs/agents/domain.md create mode 100644 docs/agents/issue-tracker.md create mode 100644 docs/agents/triage-labels.md create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/azurex/destroy.go create mode 100644 internal/azurex/doc.go create mode 100644 internal/azurex/errors.go create mode 100644 internal/azurex/integration_test.go create mode 100644 internal/azurex/manager.go create mode 100644 internal/azurex/provider.go create mode 100644 internal/azurex/signer.go create mode 100644 internal/azurex/types.go create mode 100644 internal/cli/azure.go create mode 100644 internal/cli/completion.go create mode 100644 internal/cli/completion_test.go create mode 100644 internal/cli/config.go create mode 100644 internal/cli/config_test.go create mode 100644 internal/cli/context.go create mode 100644 internal/cli/destroy.go create mode 100644 internal/cli/destroy_test.go create mode 100644 internal/cli/discovery.go create mode 100644 internal/cli/discovery_jwks_test.go create mode 100644 internal/cli/doc.go create mode 100644 internal/cli/doctor.go create mode 100644 internal/cli/doctor_repair_test.go create mode 100644 internal/cli/doctor_test.go create mode 100644 internal/cli/errors.go create mode 100644 internal/cli/integration_test.go create mode 100644 internal/cli/io.go create mode 100644 internal/cli/jwks.go create mode 100644 internal/cli/key.go create mode 100644 internal/cli/key_test.go create mode 100644 internal/cli/logging.go create mode 100644 internal/cli/prompt.go create mode 100644 internal/cli/root.go create mode 100644 internal/cli/root_test.go create mode 100644 internal/cli/setup.go create mode 100644 internal/cli/setup_test.go create mode 100644 internal/cli/token.go create mode 100644 internal/cli/token_claims_test.go create mode 100644 internal/cli/token_mint_test.go create mode 100644 internal/cli/token_test.go create mode 100644 internal/cli/verify.go create mode 100644 internal/cli/verify_test.go create mode 100644 internal/cli/version.go create mode 100644 internal/config/config.go create mode 100644 internal/config/config_test.go create mode 100644 internal/config/doc.go create mode 100644 internal/config/errors.go create mode 100644 internal/jwk/doc.go create mode 100644 internal/jwk/jwk.go create mode 100644 internal/jwk/jwk_test.go create mode 100644 internal/oidc/doc.go create mode 100644 internal/oidc/oidc.go create mode 100644 internal/oidc/oidc_test.go create mode 100644 internal/sign/b64.go create mode 100644 internal/sign/decode.go create mode 100644 internal/sign/decode_test.go create mode 100644 internal/sign/doc.go create mode 100644 internal/sign/mint.go create mode 100644 internal/sign/mint_test.go create mode 100644 internal/sign/verify.go create mode 100644 internal/sign/verify_test.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..356061d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,39 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +permissions: + contents: read + +env: + GO_VERSION: "1.25" + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5 + with: + go-version: ${{ env.GO_VERSION }} + - name: golangci-lint + uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8 + with: + version: v2.11.3 + + test: + name: Test & Build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + - uses: actions/setup-go@40f1582b2485089dde7abd97c1529aa768e1baff # v5 + with: + go-version: ${{ env.GO_VERSION }} + - name: Build + run: go build ./... + - name: Test + run: go test ./... diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cf3cf86 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +# Build output +/dist/ +/jotsmith +/bin/ + +# Test / coverage +*.out +coverage.txt + +# Editor / OS +.DS_Store +.idea/ +.vscode/ diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..3c83c9a --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,26 @@ +version: "2" + +run: + timeout: 5m + +linters: + default: standard + enable: + - bodyclose + - errorlint + - misspell + - unconvert + exclusions: + generated: lax + presets: + - common-false-positives + - std-error-handling + +formatters: + enable: + - gofmt + - goimports + settings: + goimports: + local-prefixes: + - github.com/MaxAnderson95/jotsmith diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 0000000..9341ab0 --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,55 @@ +version: 2 + +project_name: jotsmith + +before: + hooks: + - go mod tidy + +builds: + - id: jotsmith + main: ./cmd/jotsmith + binary: jotsmith + env: + - CGO_ENABLED=0 + goos: + - linux + - darwin + - windows + goarch: + - amd64 + - arm64 + ldflags: + - -s -w + - -X main.version={{ .Version }} + - -X main.commit={{ .Commit }} + - -X main.date={{ .Date }} + +archives: + - id: default + formats: + - tar.gz + name_template: >- + {{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }} + format_overrides: + - goos: windows + formats: + - zip + +checksum: + name_template: checksums.txt + +snapshot: + version_template: "{{ incpatch .Version }}-next" + +changelog: + sort: asc + filters: + exclude: + - "^docs:" + - "^test:" + - "^chore:" + +release: + draft: false + prerelease: auto diff --git a/AGENTS.md b/AGENTS.md index 667257d..7b8fad6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -90,3 +90,17 @@ Tests live next to the code as `*_test.go`. Integration tests that touch real Az - Multi-user / shared-tenant. - Telemetry / metrics / OTel export. (Structured logs only, to stderr.) - Token revocation list. JWTs are self-contained. + +## Agent skills + +### Issue tracker + +GitHub Issues on `MaxAnderson95/jotsmith`, accessed via `gh`. See `docs/agents/issue-tracker.md`. + +### Triage labels + +Canonical five-role vocabulary (`needs-triage`, `needs-info`, `ready-for-agent`, `ready-for-human`, `wontfix`). See `docs/agents/triage-labels.md`. + +### Domain docs + +Single-context: `CONTEXT.md` + `docs/adr/` at the repo root. See `docs/agents/domain.md`. diff --git a/cmd/jotsmith/main.go b/cmd/jotsmith/main.go new file mode 100644 index 0000000..81b0514 --- /dev/null +++ b/cmd/jotsmith/main.go @@ -0,0 +1,40 @@ +// Command jotsmith stands up a personal OIDC issuer in Azure (Storage Account +// static website + Key Vault) and mints JWTs of arbitrary shape against it. +// +// See PRD.md and CONTEXT.md at the repo root for the product surface and +// domain language. +package main + +import ( + "context" + "fmt" + "os" + + "github.com/MaxAnderson95/jotsmith/internal/cli" +) + +// Build metadata, overridden at release time via -ldflags +// (-X main.version=... -X main.commit=... -X main.date=...). +var ( + version = "dev" + commit = "none" + date = "unknown" +) + +func main() { + streams := cli.DefaultIOStreams() + root := cli.NewRootCommand(streams, cli.BuildInfo{ + Version: version, + Commit: commit, + Date: date, + }) + + if err := root.Run(context.Background(), os.Args); err != nil { + // Some commands (e.g. doctor --json) signal a non-zero exit without a + // human message; printing an empty line would be noise. + if msg := err.Error(); msg != "" { + fmt.Fprintln(streams.Err, msg) + } + os.Exit(cli.ExitCode(err)) + } +} diff --git a/docs/agents/domain.md b/docs/agents/domain.md new file mode 100644 index 0000000..a31ffa0 --- /dev/null +++ b/docs/agents/domain.md @@ -0,0 +1,38 @@ +# Domain Docs + +How the engineering skills should consume this repo's domain documentation when exploring the codebase. + +## Before exploring, read these + +- **`CONTEXT.md`** at the repo root — the ubiquitous-language glossary. Use the terms it defines; avoid the listed synonyms. +- **`docs/adr/`** — read ADRs that touch the area you're about to work in. The five v1 ADRs (`0001`–`0005`) cover decisions that are hard to reverse and surprising without context. +- **`PRD.md`** at the repo root is canonical for product surface. If code disagrees with the PRD, the bug is in the code. + +If any of these files don't exist for a topic, **proceed silently** — don't flag their absence or suggest creating them upfront. `grill-with-docs` creates them lazily when terms or decisions actually get resolved. + +## File structure (single-context) + +``` +/ +├── PRD.md +├── CONTEXT.md +├── AGENTS.md +└── docs/adr/ + ├── 0001-issuer-url-raw-static-website.md + ├── 0002-setup-configures-existing-resources-only.md + ├── 0003-rs256-only-with-multikey-ready-schema.md + ├── 0004-single-issuer-per-config.md + └── 0005-snap-cutover-rotation.md +``` + +## Use the glossary's vocabulary + +When your output names a domain concept (issue title, refactor proposal, hypothesis, test name), use the term as defined in `CONTEXT.md`. Don't drift to synonyms the glossary explicitly avoids — `mint` not "sign"/"issue"/"generate", `setup` not "init"/"configure", `doctor` not "check"/"audit"/"lint", `rotate` not "cycle"/"refresh", `Issuer URL` not "endpoint"/"IdP URL". + +If the concept you need isn't in the glossary yet, that's a signal — either you're inventing language the project doesn't use (reconsider) or there's a real gap (note it for `grill-with-docs`). + +## Flag ADR conflicts + +If your output contradicts an existing ADR, surface it explicitly rather than silently overriding: + +> _Contradicts ADR-0002 (setup configures existing resources only) — but worth reopening because…_ diff --git a/docs/agents/issue-tracker.md b/docs/agents/issue-tracker.md new file mode 100644 index 0000000..4beb990 --- /dev/null +++ b/docs/agents/issue-tracker.md @@ -0,0 +1,29 @@ +# Issue tracker: GitHub + +Issues and PRDs for this repo live as GitHub issues on `MaxAnderson95/jotsmith`. Use the `gh` CLI for all operations. + +## Conventions + +- **Create an issue**: `gh issue create --title "..." --body-file `. Use `--body-file` for any non-trivial body (backticks, code fences, tables) per the global AGENTS.md rules. +- **Read an issue**: prefer `gh api repos/MaxAnderson95/jotsmith/issues/` for reads — `gh issue view` can fail on repos whose GraphQL queries touch deprecated Projects-Classic fields. If `gh issue view` errors, fall back to `gh api` for the rest of the task. +- **List issues**: `gh issue list --state open --json number,title,body,labels --jq '[.[] | {number, title, labels: [.labels[].name]}]'` with `--label` and `--state` filters. +- **Comment on an issue**: `gh issue comment --body-file ` (or `--body "..."` for trivial bodies). +- **Apply / remove labels**: `gh issue edit --add-label "..."` / `--remove-label "..."`. +- **Close**: `gh issue close --comment "..."`. + +`gh` infers the repo from `git remote -v` automatically inside a clone. + +## Issue body footer + +Every issue body created or edited via `gh` must carry the opencode session-ID footer block at the bottom, per the global AGENTS.md rules: + +- On creation: `_Originally created in OpenCode session ID: ses_XXXXXXX_` +- On edit by a different session: append `_Edited by OpenCode session ID: ses_YYYYYYYY_` (append-only, never rewrite). + +## When a skill says "publish to the issue tracker" + +Create a GitHub issue on `MaxAnderson95/jotsmith`. + +## When a skill says "fetch the relevant ticket" + +Run `gh api repos/MaxAnderson95/jotsmith/issues/` for the body, and `gh api repos/MaxAnderson95/jotsmith/issues//comments` for comments. Verify edits with `gh api repos/MaxAnderson95/jotsmith/issues/ --jq '{title, body}'`. diff --git a/docs/agents/triage-labels.md b/docs/agents/triage-labels.md new file mode 100644 index 0000000..6ee1649 --- /dev/null +++ b/docs/agents/triage-labels.md @@ -0,0 +1,15 @@ +# Triage Labels + +The skills speak in terms of five canonical triage roles. This file maps those roles to the actual label strings used in this repo's issue tracker. + +| Canonical role | Label in this repo | Meaning | +| ----------------- | ------------------ | ---------------------------------------- | +| `needs-triage` | `needs-triage` | Maintainer needs to evaluate this issue | +| `needs-info` | `needs-info` | Waiting on reporter for more information | +| `ready-for-agent` | `ready-for-agent` | Fully specified, ready for an AFK agent | +| `ready-for-human` | `ready-for-human` | Requires human implementation | +| `wontfix` | `wontfix` | Will not be actioned | + +When a skill mentions a role (e.g. "apply the AFK-ready triage label"), use the corresponding label string from this table. + +Edit the right-hand column if the vocabulary ever changes. diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3fdfac1 --- /dev/null +++ b/go.mod @@ -0,0 +1,28 @@ +module github.com/MaxAnderson95/jotsmith + +go 1.25.0 + +require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.5.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 + github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.5.0 + github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.7.0 + github.com/google/uuid v1.6.0 + github.com/urfave/cli/v3 v3.9.0 +) + +require ( + github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.7.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + golang.org/x/crypto v0.51.0 // indirect + golang.org/x/net v0.54.0 // indirect + golang.org/x/sys v0.44.0 // indirect + golang.org/x/text v0.37.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4a62151 --- /dev/null +++ b/go.sum @@ -0,0 +1,59 @@ +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 h1:jHb/wfvRikGdxMXYV3QG/SzUOPYN9KEUUuC0Yd0/vC0= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1/go.mod h1:pzBXCYn05zvYIrwLgtK8Ap8QcjRg+0i76tMQdWN6wOk= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0/go.mod h1:7dCRMLwisfRH3dBupKeNCioWYUZ4SS09Z14H+7i8ZoY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0 h1:2qsIIvxVT+uE6yrNldntJKlLRgxGbZ85kgtz5SNBhMw= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.0/go.mod h1:AW8VEadnhw9xox+VaVd9sP7NjzOAnaZBLRH6Tq3cJ38= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.5.0 h1:nnQ9vXH039UrEFxi08pPuZBE7VfqSJt343uJLw0rhWI= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault v1.5.0/go.mod h1:4YIVtzMFVsPwBvitCDX7J9sqthSj43QD1sP6fYc1egc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0 h1:Dd+RhdJn0OTtVGaeDLZpcumkIVCtA/3/Fo42+eoYvVM= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.2.0/go.mod h1:5kakwfW5CjC9KK+Q4wjXAg+ShuIm2mBMua0ZFj2C8PE= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0 h1:wxQx2Bt4xzPIKvW59WQf1tJNx/ZZKPfN+EhPX3Z6CYY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions v1.3.0/go.mod h1:TpiwjwnW/khS0LKs4vW5UmmT9OWcxaveS8U7+tlknzo= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.5.0 h1:MaKvxE6D0KkjOg6Wd9M00iqP5PR0kUxCfiezes4JweM= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.5.0/go.mod h1:i2h9fsTFKZorh8RdV2IcSUf/Qj98GlTkrTvUbX/s8as= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 h1:nCYfgcSyHZXJI8J0IWE5MsCGlb2xp9fJiXyxWgmOFg4= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0/go.mod h1:ucUjca2JtSZboY8IoUqyQyuuXvwbMBVwFOm0vdQPNhA= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.7.0 h1:BM85pSYlVYQHdq00nxyPoOkyLF5NArJG3bOsrmbwr4k= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.7.0/go.mod h1:QYjP2cB7ZYtS/8jAbE0VSBZde/tjExqGjp+8JY6/+ts= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.7.2 h1:RHK7bS+HQMslb1sZpAokUt+zTVmue0hKSs2C791hhzU= +github.com/AzureAD/microsoft-authentication-library-for-go v1.7.2/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/urfave/cli/v3 v3.9.0 h1:AV9lIiPv3ukYnxunaCUsHnEozptYmDN2F0+yWqLMn/c= +github.com/urfave/cli/v3 v3.9.0/go.mod h1:ysVLtOEmg2tOy6PknnYVhDoouyC/6N42TMeoMzskhso= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= +golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= +golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/azurex/destroy.go b/internal/azurex/destroy.go new file mode 100644 index 0000000..c74c231 --- /dev/null +++ b/internal/azurex/destroy.go @@ -0,0 +1,44 @@ +package azurex + +import "context" + +// DeleteSigningKey soft-deletes the configured signing key. Key Vault retains a +// recoverable copy (subject to the vault's soft-delete retention) rather than +// purging it, so the caller can recover or purge it separately. +// +// If the key does not exist (already deleted), it returns a *NotFoundError so +// destroy can treat that as an idempotent no-op. +func (p *Provider) DeleteSigningKey(ctx context.Context) (DeletedKey, error) { + client, err := p.keysClient(ctx) + if err != nil { + return DeletedKey{}, err + } + resp, err := client.DeleteKey(ctx, p.target.KeyName, nil) + if err != nil { + if isNotFound(err) { + return DeletedKey{}, &NotFoundError{Kind: "signing key", Resource: p.target.KeyName} + } + return DeletedKey{}, &ResourceError{Op: "delete signing key", Resource: p.target.KeyName, Err: err} + } + return DeletedKey{Name: p.target.KeyName, RecoveryID: deref(resp.RecoveryID)}, nil +} + +// DeleteBlob deletes the blob at $web/. If the blob (or container) does +// not exist it returns a *NotFoundError so destroy can treat that as an +// idempotent no-op. +func (p *Provider) DeleteBlob(ctx context.Context, path string) error { + client, err := p.blobClient(ctx) + if err != nil { + return err + } + if _, err := client.DeleteBlob(ctx, webContainer, path, nil); err != nil { + if isBlobNotFound(err) { + return &NotFoundError{Kind: "blob", Resource: webContainer + "/" + path} + } + return &ResourceError{Op: "delete blob", Resource: webContainer + "/" + path, Err: err} + } + return nil +} + +// ensure Provider satisfies the DestroyManager contract at compile time. +var _ DestroyManager = (*Provider)(nil) diff --git a/internal/azurex/doc.go b/internal/azurex/doc.go new file mode 100644 index 0000000..cae352a --- /dev/null +++ b/internal/azurex/doc.go @@ -0,0 +1,10 @@ +// Package azurex holds thin adapters over the Azure SDK for the control-plane +// and data-plane operations jotsmith performs: credential resolution, +// subscription and storage-account reads, blob get/put, and Key Vault key +// reads, creation, and signing. +// +// The only control-plane mutation jotsmith is ever allowed to make is enabling +// static website hosting on an existing Storage Account (ADR-0002). Private key +// material never leaves Key Vault: signing goes through the data-plane Sign +// API. +package azurex diff --git a/internal/azurex/errors.go b/internal/azurex/errors.go new file mode 100644 index 0000000..498eaea --- /dev/null +++ b/internal/azurex/errors.go @@ -0,0 +1,50 @@ +package azurex + +import ( + "errors" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" +) + +// ResourceError wraps an underlying SDK error with the operation and the Azure +// resource identity it was performed against, so error messages always name +// what jotsmith was trying to touch. +type ResourceError struct { + Op string // e.g. "get storage account" + Resource string // e.g. "jotsmithmax" + Err error +} + +func (e *ResourceError) Error() string { + return fmt.Sprintf("%s %q: %v", e.Op, e.Resource, e.Err) +} + +func (e *ResourceError) Unwrap() error { return e.Err } + +// NotFoundError indicates the named Azure resource (account, vault, key, or +// blob) does not exist. +type NotFoundError struct { + Kind string // "storage account", "key vault", "signing key", "blob" + Resource string +} + +func (e *NotFoundError) Error() string { + return fmt.Sprintf("%s %q not found", e.Kind, e.Resource) +} + +// isNotFound reports whether err represents an HTTP 404 from an ARM/data-plane +// call. +func isNotFound(err error) bool { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + return respErr.StatusCode == 404 + } + return false +} + +// isBlobNotFound reports whether err is a blob-service "not found". +func isBlobNotFound(err error) bool { + return bloberror.HasCode(err, bloberror.BlobNotFound, bloberror.ContainerNotFound) +} diff --git a/internal/azurex/integration_test.go b/internal/azurex/integration_test.go new file mode 100644 index 0000000..adcb497 --- /dev/null +++ b/internal/azurex/integration_test.go @@ -0,0 +1,114 @@ +//go:build integration + +// Integration tests for the azurex read paths. These talk to real Azure and +// are excluded from the default build; run them with: +// +// go test -tags integration ./internal/azurex/... +// +// Required environment variables: +// +// JOTSMITH_TEST_SUBSCRIPTION - subscription ID containing the test resources +// JOTSMITH_TEST_STORAGE_ACCOUNT- name of an existing GPv2 storage account +// JOTSMITH_TEST_KEY_VAULT - name of an existing RBAC-mode key vault +// JOTSMITH_TEST_KEY_NAME - name of an existing RSA signing key (optional) +// +// Minimum RBAC for the running principal (DefaultAzureCredential): +// +// Reader on the subscription (list/get resources) +// Storage Blob Data Reader on the storage account (read $web blobs + service props) +// Key Vault Crypto User on the key vault (get/sign keys) +// +// These are read-only; the write paths exercised by setup/rotate/destroy need +// the elevated roles documented in PRD §5.2. +package azurex + +import ( + "context" + "os" + "strings" + "testing" + "time" + + "github.com/MaxAnderson95/jotsmith/internal/jwk" + "github.com/MaxAnderson95/jotsmith/internal/sign" +) + +func integrationTarget(t *testing.T) Target { + t.Helper() + sub := os.Getenv("JOTSMITH_TEST_SUBSCRIPTION") + sa := os.Getenv("JOTSMITH_TEST_STORAGE_ACCOUNT") + kv := os.Getenv("JOTSMITH_TEST_KEY_VAULT") + if sub == "" || sa == "" || kv == "" { + t.Skip("integration env not set: JOTSMITH_TEST_SUBSCRIPTION/STORAGE_ACCOUNT/KEY_VAULT") + } + keyName := os.Getenv("JOTSMITH_TEST_KEY_NAME") + if keyName == "" { + keyName = "signing-key" + } + return Target{ + SubscriptionID: sub, + StorageAccount: sa, + KeyVault: kv, + KeyName: keyName, + DiscoveryPath: ".well-known/openid-configuration", + JWKSPath: ".well-known/jwks.json", + } +} + +func TestIntegrationReadPaths(t *testing.T) { + ctx := context.Background() + p, err := NewProvider(ctx, integrationTarget(t), nil) + if err != nil { + t.Fatalf("NewProvider: %v", err) + } + if _, err := p.GetSubscription(ctx); err != nil { + t.Fatalf("GetSubscription: %v", err) + } + if _, err := p.GetStorageAccount(ctx); err != nil { + t.Fatalf("GetStorageAccount: %v", err) + } + if _, err := p.GetVault(ctx); err != nil { + t.Fatalf("GetVault: %v", err) + } +} + +// TestIntegrationSignMint exercises the real Key Vault Sign path: it derives the +// kid from the live key, mints a token through the same sign.Mint code path the +// CLI uses, and asserts the result parses as a compact JWT. Requires the running +// principal to have a signing role (Key Vault Crypto User or Officer) on the +// configured key. +func TestIntegrationSignMint(t *testing.T) { + ctx := context.Background() + p, err := NewProvider(ctx, integrationTarget(t), nil) + if err != nil { + t.Fatalf("NewProvider: %v", err) + } + + key, err := p.GetSigningKey(ctx) + if err != nil { + t.Fatalf("GetSigningKey: %v", err) + } + kid := jwk.Thumbprint(key.N, key.E) + + now := time.Now() + token, err := sign.Mint(ctx, p, kid, map[string]any{ + "iss": "https://integration.test", + "sub": "integration", + "iat": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + }) + if err != nil { + t.Fatalf("sign.Mint: %v", err) + } + + if parts := strings.Split(token, "."); len(parts) != 3 { + t.Fatalf("minted token is not compact 3-segment form: %q", token) + } + decoded, err := sign.Decode(token) + if err != nil { + t.Fatalf("decoding minted token: %v", err) + } + if decoded.SignatureBytes == 0 { + t.Error("minted token has an empty signature") + } +} diff --git a/internal/azurex/manager.go b/internal/azurex/manager.go new file mode 100644 index 0000000..68b76cb --- /dev/null +++ b/internal/azurex/manager.go @@ -0,0 +1,96 @@ +package azurex + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service" +) + +// EnableStaticWebsite turns on static website hosting for the storage account's +// blob service. This is the single control-plane-ish mutation jotsmith is ever +// allowed to make (ADR-0002). Index and error documents are left unset, as +// jotsmith serves only /.well-known/* documents, not a browsable site. +// +// Per the Set Blob Service Properties contract, omitting the other root +// elements (logging, metrics, CORS) preserves their existing settings. +func (p *Provider) EnableStaticWebsite(ctx context.Context) error { + client, err := p.blobClient(ctx) + if err != nil { + return err + } + _, err = client.ServiceClient().SetProperties(ctx, &service.SetPropertiesOptions{ + StaticWebsite: &service.StaticWebsite{Enabled: to.Ptr(true)}, + }) + if err != nil { + return &ResourceError{Op: "enable static website hosting", Resource: p.target.StorageAccount, Err: err} + } + return nil +} + +// UploadBlob writes data to $web/, overwriting any existing blob, with the +// given Content-Type and Cache-Control headers. +func (p *Provider) UploadBlob(ctx context.Context, path, contentType, cacheControl string, data []byte) error { + client, err := p.blobClient(ctx) + if err != nil { + return err + } + _, err = client.UploadBuffer(ctx, webContainer, path, data, &azblob.UploadBufferOptions{ + HTTPHeaders: &blob.HTTPHeaders{ + BlobContentType: to.Ptr(contentType), + BlobCacheControl: to.Ptr(cacheControl), + }, + }) + if err != nil { + return &ResourceError{Op: "upload blob", Resource: webContainer + "/" + path, Err: err} + } + return nil +} + +// CreateRSAKey creates an RSA key (or, if one already exists under the same +// name, a new version of it) with the sign and verify operations enabled, and +// returns the new public material. +func (p *Provider) CreateRSAKey(ctx context.Context, bits int) (Key, error) { + client, err := p.keysClient(ctx) + if err != nil { + return Key{}, err + } + resp, err := client.CreateKey(ctx, p.target.KeyName, azkeys.CreateKeyParameters{ + Kty: to.Ptr(azkeys.KeyTypeRSA), + KeySize: to.Ptr(int32(bits)), + KeyOps: []*azkeys.KeyOperation{ + to.Ptr(azkeys.KeyOperationSign), + to.Ptr(azkeys.KeyOperationVerify), + }, + }, nil) + if err != nil { + return Key{}, &ResourceError{Op: "create signing key", Resource: p.target.KeyName, Err: err} + } + return keyFromBundle(p.target.KeyName, resp.Attributes, resp.Key), nil +} + +func keyFromBundle(name string, attrs *azkeys.KeyAttributes, jwk *azkeys.JSONWebKey) Key { + out := Key{Name: name} + if attrs != nil { + out.Enabled = deref(attrs.Enabled) + } + if jwk != nil { + out.N = jwk.N + out.E = jwk.E + out.Ops = opStrings(jwk.KeyOps) + if jwk.KID != nil { + out.Version = jwk.KID.Version() + } + } + return out +} + +// ensure Provider satisfies the manager contracts at compile time. +var ( + _ SetupManager = (*Provider)(nil) + _ RotateManager = (*Provider)(nil) + _ RepairManager = (*Provider)(nil) +) diff --git a/internal/azurex/provider.go b/internal/azurex/provider.go new file mode 100644 index 0000000..c5452e0 --- /dev/null +++ b/internal/azurex/provider.go @@ -0,0 +1,345 @@ +package azurex + +import ( + "context" + "fmt" + "io" + "log/slog" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armsubscriptions" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" +) + +// webContainer is the container static-website hosting creates. ".well-known" +// cannot be an Azure container name (can't start with "."), so jotsmith +// publishes documents as folders inside $web (see AGENTS.md gotchas). +const webContainer = "$web" + +// Provider is the concrete, Azure-backed implementation of Inspector. It is the +// single seam where SDK calls happen. Control-plane reads are memoized so the +// data-plane clients (which need endpoints discovered from ARM) can be built +// lazily and so doctor's independent checks don't refetch. +// +// Provider performs read-only operations only; the one allowed control-plane +// mutation (enabling static website hosting) and all data-plane writes are +// added in later slices. +type Provider struct { + target Target + cred *azidentity.DefaultAzureCredential + log *slog.Logger + + subscriptions *armsubscriptions.Client + accounts *armstorage.AccountsClient + vaults *armkeyvault.VaultsClient + + accountDone bool + accountVal StorageAccount + accountErr error + + vaultDone bool + vaultVal Vault + vaultErr error + + blob *azblob.Client + blobErr error + + keys *azkeys.Client + keysErr error +} + +// NewProvider resolves DefaultAzureCredential and builds the ARM clients for +// the given target. It does not contact Azure beyond credential construction; +// resource reads happen lazily. +func NewProvider(_ context.Context, target Target, log *slog.Logger) (*Provider, error) { + if log == nil { + log = slog.New(slog.NewTextHandler(io.Discard, nil)) + } + + cred, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return nil, fmt.Errorf("resolving Azure credential (DefaultAzureCredential): %w", err) + } + log.Debug("resolved Azure credential via DefaultAzureCredential") + + subs, err := armsubscriptions.NewClient(cred, nil) + if err != nil { + return nil, fmt.Errorf("creating subscriptions client: %w", err) + } + accounts, err := armstorage.NewAccountsClient(target.SubscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("creating storage accounts client: %w", err) + } + vaults, err := armkeyvault.NewVaultsClient(target.SubscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("creating key vaults client: %w", err) + } + + return &Provider{ + target: target, + cred: cred, + log: log, + subscriptions: subs, + accounts: accounts, + vaults: vaults, + }, nil +} + +// GetSubscription confirms the configured subscription is accessible. +func (p *Provider) GetSubscription(ctx context.Context) (Subscription, error) { + resp, err := p.subscriptions.Get(ctx, p.target.SubscriptionID, nil) + if err != nil { + return Subscription{}, &ResourceError{Op: "get subscription", Resource: p.target.SubscriptionID, Err: err} + } + return Subscription{ + ID: deref(resp.SubscriptionID), + DisplayName: deref(resp.DisplayName), + State: string(deref(resp.State)), + }, nil +} + +// GetStorageAccount finds the configured account subscription-wide (to discover +// its resource group) and reads its full properties. +func (p *Provider) GetStorageAccount(ctx context.Context) (StorageAccount, error) { + if p.accountDone { + return p.accountVal, p.accountErr + } + p.accountDone = true + p.accountVal, p.accountErr = p.fetchAccount(ctx) + return p.accountVal, p.accountErr +} + +func (p *Provider) fetchAccount(ctx context.Context) (StorageAccount, error) { + name := p.target.StorageAccount + rg, err := p.findResourceGroup(ctx, name) + if err != nil { + return StorageAccount{}, err + } + + resp, err := p.accounts.GetProperties(ctx, rg, name, nil) + if err != nil { + return StorageAccount{}, &ResourceError{Op: "get storage account", Resource: name, Err: err} + } + acct := resp.Account + + out := StorageAccount{ + Name: name, + ResourceGroup: rg, + Kind: string(deref(acct.Kind)), + Location: deref(acct.Location), + } + if acct.SKU != nil { + out.SKUName = string(deref(acct.SKU.Name)) + } + if acct.Properties != nil && acct.Properties.PrimaryEndpoints != nil { + // Strip the trailing slash on the web endpoint so it compares cleanly + // against the canonical (slash-free) issuer. The blob endpoint is kept + // verbatim for the data-plane client. + out.WebEndpoint = strings.TrimRight(deref(acct.Properties.PrimaryEndpoints.Web), "/") + out.BlobEndpoint = deref(acct.Properties.PrimaryEndpoints.Blob) + } + return out, nil +} + +// findResourceGroup enumerates storage accounts in the subscription and parses +// the resource group from the matching account's resource ID. jotsmith only +// asks the user for the account name, never the resource group. +func (p *Provider) findResourceGroup(ctx context.Context, name string) (string, error) { + pager := p.accounts.NewListPager(nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return "", &ResourceError{Op: "list storage accounts in subscription", Resource: p.target.SubscriptionID, Err: err} + } + for _, acct := range page.Value { + if acct == nil || acct.Name == nil || *acct.Name != name { + continue + } + id, err := arm.ParseResourceID(deref(acct.ID)) + if err != nil { + return "", &ResourceError{Op: "parse storage account resource ID", Resource: name, Err: err} + } + return id.ResourceGroupName, nil + } + } + return "", &NotFoundError{Kind: "storage account", Resource: name} +} + +// GetStaticWebsite reads the blob service's static-website configuration. +func (p *Provider) GetStaticWebsite(ctx context.Context) (StaticWebsite, error) { + client, err := p.blobClient(ctx) + if err != nil { + return StaticWebsite{}, err + } + resp, err := client.ServiceClient().GetProperties(ctx, nil) + if err != nil { + return StaticWebsite{}, &ResourceError{Op: "get blob service properties", Resource: p.target.StorageAccount, Err: err} + } + enabled := resp.StaticWebsite != nil && deref(resp.StaticWebsite.Enabled) + return StaticWebsite{Enabled: enabled}, nil +} + +// GetDiscoveryDocument downloads the published discovery document blob. +func (p *Provider) GetDiscoveryDocument(ctx context.Context) ([]byte, error) { + return p.getBlob(ctx, p.target.DiscoveryPath) +} + +// GetJWKS downloads the published JWKS blob. +func (p *Provider) GetJWKS(ctx context.Context) ([]byte, error) { + return p.getBlob(ctx, p.target.JWKSPath) +} + +func (p *Provider) getBlob(ctx context.Context, path string) ([]byte, error) { + client, err := p.blobClient(ctx) + if err != nil { + return nil, err + } + resp, err := client.DownloadStream(ctx, webContainer, path, nil) + if err != nil { + if isBlobNotFound(err) { + return nil, &NotFoundError{Kind: "blob", Resource: webContainer + "/" + path} + } + return nil, &ResourceError{Op: "download blob", Resource: webContainer + "/" + path, Err: err} + } + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, &ResourceError{Op: "read blob body", Resource: webContainer + "/" + path, Err: err} + } + return data, nil +} + +func (p *Provider) blobClient(ctx context.Context) (*azblob.Client, error) { + if p.blob != nil || p.blobErr != nil { + return p.blob, p.blobErr + } + acct, err := p.GetStorageAccount(ctx) + if err != nil { + p.blobErr = err + return nil, err + } + if acct.BlobEndpoint == "" { + p.blobErr = &ResourceError{Op: "resolve blob endpoint", Resource: acct.Name, Err: fmt.Errorf("account has no primary blob endpoint")} + return nil, p.blobErr + } + client, err := azblob.NewClient(acct.BlobEndpoint, p.cred, nil) + if err != nil { + p.blobErr = &ResourceError{Op: "create blob client", Resource: acct.Name, Err: err} + return nil, p.blobErr + } + p.blob = client + return client, nil +} + +// GetVault finds the configured vault subscription-wide and reads its +// properties, including whether it is in Azure RBAC mode. +func (p *Provider) GetVault(ctx context.Context) (Vault, error) { + if p.vaultDone { + return p.vaultVal, p.vaultErr + } + p.vaultDone = true + p.vaultVal, p.vaultErr = p.fetchVault(ctx) + return p.vaultVal, p.vaultErr +} + +func (p *Provider) fetchVault(ctx context.Context) (Vault, error) { + name := p.target.KeyVault + rg, err := p.findVaultResourceGroup(ctx, name) + if err != nil { + return Vault{}, err + } + resp, err := p.vaults.Get(ctx, rg, name, nil) + if err != nil { + return Vault{}, &ResourceError{Op: "get key vault", Resource: name, Err: err} + } + + out := Vault{Name: name, ResourceGroup: rg} + if resp.Properties != nil { + out.URI = deref(resp.Properties.VaultURI) + out.RBACEnabled = deref(resp.Properties.EnableRbacAuthorization) + } + return out, nil +} + +func (p *Provider) findVaultResourceGroup(ctx context.Context, name string) (string, error) { + pager := p.vaults.NewListBySubscriptionPager(nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return "", &ResourceError{Op: "list key vaults in subscription", Resource: p.target.SubscriptionID, Err: err} + } + for _, v := range page.Value { + if v == nil || v.Name == nil || *v.Name != name { + continue + } + id, err := arm.ParseResourceID(deref(v.ID)) + if err != nil { + return "", &ResourceError{Op: "parse key vault resource ID", Resource: name, Err: err} + } + return id.ResourceGroupName, nil + } + } + return "", &NotFoundError{Kind: "key vault", Resource: name} +} + +// GetSigningKey reads the configured signing key's public material and status. +func (p *Provider) GetSigningKey(ctx context.Context) (Key, error) { + client, err := p.keysClient(ctx) + if err != nil { + return Key{}, err + } + resp, err := client.GetKey(ctx, p.target.KeyName, "", nil) + if err != nil { + if isNotFound(err) { + return Key{}, &NotFoundError{Kind: "signing key", Resource: p.target.KeyName} + } + return Key{}, &ResourceError{Op: "get signing key", Resource: p.target.KeyName, Err: err} + } + + return keyFromBundle(p.target.KeyName, resp.Attributes, resp.Key), nil +} + +func (p *Provider) keysClient(ctx context.Context) (*azkeys.Client, error) { + if p.keys != nil || p.keysErr != nil { + return p.keys, p.keysErr + } + vault, err := p.GetVault(ctx) + if err != nil { + p.keysErr = err + return nil, err + } + if vault.URI == "" { + p.keysErr = &ResourceError{Op: "resolve key vault URI", Resource: vault.Name, Err: fmt.Errorf("vault has no URI")} + return nil, p.keysErr + } + client, err := azkeys.NewClient(vault.URI, p.cred, nil) + if err != nil { + p.keysErr = &ResourceError{Op: "create key vault data-plane client", Resource: vault.Name, Err: err} + return nil, p.keysErr + } + p.keys = client + return client, nil +} + +func opStrings(ops []*azkeys.KeyOperation) []string { + out := make([]string, 0, len(ops)) + for _, op := range ops { + if op != nil { + out = append(out, string(*op)) + } + } + return out +} + +func deref[T any](p *T) T { + var zero T + if p != nil { + return *p + } + return zero +} diff --git a/internal/azurex/signer.go b/internal/azurex/signer.go new file mode 100644 index 0000000..f6478a5 --- /dev/null +++ b/internal/azurex/signer.go @@ -0,0 +1,32 @@ +package azurex + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" +) + +// Sign creates an RS256 signature over digest using the current version of the +// configured Key Vault key. digest must be the SHA-256 of the JWT signing +// input; Key Vault's Sign API signs a precomputed digest, not raw bytes. +// +// Only the digest is sent to Key Vault — the private key never leaves the +// vault, which is the central security invariant of the product. +func (p *Provider) Sign(ctx context.Context, digest []byte) ([]byte, error) { + client, err := p.keysClient(ctx) + if err != nil { + return nil, err + } + resp, err := client.Sign(ctx, p.target.KeyName, "", azkeys.SignParameters{ + Algorithm: to.Ptr(azkeys.SignatureAlgorithmRS256), + Value: digest, + }, nil) + if err != nil { + return nil, &ResourceError{Op: "sign digest with key", Resource: p.target.KeyName, Err: err} + } + return resp.Result, nil +} + +// ensure Provider satisfies the Signer contract at compile time. +var _ Signer = (*Provider)(nil) diff --git a/internal/azurex/types.go b/internal/azurex/types.go new file mode 100644 index 0000000..bd9af14 --- /dev/null +++ b/internal/azurex/types.go @@ -0,0 +1,130 @@ +package azurex + +import "context" + +// Target identifies the Azure resources a Provider operates on. It is assembled +// from the loaded config so azurex never imports the config package. +type Target struct { + SubscriptionID string + StorageAccount string + KeyVault string + KeyName string + DiscoveryPath string + JWKSPath string +} + +// Subscription is the subset of Azure subscription state jotsmith cares about. +type Subscription struct { + ID string + DisplayName string + State string +} + +// StorageAccount captures the control-plane storage-account state doctor and +// setup need. WebEndpoint and BlobEndpoint have any trailing slash stripped. +type StorageAccount struct { + Name string + ResourceGroup string + Kind string + SKUName string + Location string + WebEndpoint string + BlobEndpoint string +} + +// StaticWebsite is the static-website hosting state of a storage account's blob +// service. +type StaticWebsite struct { + Enabled bool +} + +// Vault captures the control-plane Key Vault state jotsmith needs, notably +// whether it is in Azure RBAC mode (jotsmith refuses legacy access-policy +// vaults). +type Vault struct { + Name string + ResourceGroup string + URI string + RBACEnabled bool +} + +// Key is the public material and status of a Key Vault signing key. N and E are +// the raw big-endian RSA modulus and exponent bytes; private material never +// leaves Key Vault. +type Key struct { + Name string + Version string + Enabled bool + Ops []string + N []byte + E []byte +} + +// Inspector is the read-only Azure surface jotsmith's doctor (and the +// discovery/jwks show commands) depend on. It is implemented by *Provider and, +// in tests, by fakes. +type Inspector interface { + GetSubscription(ctx context.Context) (Subscription, error) + GetStorageAccount(ctx context.Context) (StorageAccount, error) + GetStaticWebsite(ctx context.Context) (StaticWebsite, error) + GetDiscoveryDocument(ctx context.Context) ([]byte, error) + GetJWKS(ctx context.Context) ([]byte, error) + GetVault(ctx context.Context) (Vault, error) + GetSigningKey(ctx context.Context) (Key, error) +} + +// SetupManager is the read+write Azure surface jotsmith setup depends on: the +// read-only Inspector plus the narrow set of mutations setup performs. The only +// control-plane-ish mutation is EnableStaticWebsite (ADR-0002). +type SetupManager interface { + Inspector + EnableStaticWebsite(ctx context.Context) error + UploadBlob(ctx context.Context, path, contentType, cacheControl string, data []byte) error + CreateRSAKey(ctx context.Context, bits int) (Key, error) +} + +// Signer is the Key Vault signing surface token mint depends on: read the +// current key's public material (to derive the kid) and sign a SHA-256 digest +// with it. The private key never leaves Key Vault (the digest is computed +// client-side and only the digest is sent to KV). +type Signer interface { + GetSigningKey(ctx context.Context) (Key, error) + Sign(ctx context.Context, digest []byte) ([]byte, error) +} + +// RotateManager is the Azure surface `key rotate` depends on: read the current +// key (to report the before kid), create a new key version, read the published +// discovery document (to decide whether it needs refreshing), and upload the +// new JWKS / discovery blobs. It is a subset of SetupManager, so *Provider +// satisfies it. +type RotateManager interface { + GetSigningKey(ctx context.Context) (Key, error) + CreateRSAKey(ctx context.Context, bits int) (Key, error) + GetDiscoveryDocument(ctx context.Context) ([]byte, error) + UploadBlob(ctx context.Context, path, contentType, cacheControl string, data []byte) error +} + +// RepairManager is the Azure surface `doctor --repair` depends on: the +// read-only Inspector plus the in-place fixes doctor is allowed to perform +// (re-enable static website hosting, re-upload the discovery/JWKS blobs). It is +// a subset of SetupManager, so *Provider satisfies it. +type RepairManager interface { + Inspector + EnableStaticWebsite(ctx context.Context) error + UploadBlob(ctx context.Context, path, contentType, cacheControl string, data []byte) error +} + +// DestroyManager is the Azure surface `destroy` depends on: soft-delete the +// signing key and delete the published blobs. It never deletes the Storage +// Account or Key Vault themselves (ADR-0002). +type DestroyManager interface { + DeleteSigningKey(ctx context.Context) (DeletedKey, error) + DeleteBlob(ctx context.Context, path string) error +} + +// DeletedKey is the result of a soft delete: the key name plus the recovery ID +// the user can purge or recover with. +type DeletedKey struct { + Name string + RecoveryID string +} diff --git a/internal/cli/azure.go b/internal/cli/azure.go new file mode 100644 index 0000000..0382c85 --- /dev/null +++ b/internal/cli/azure.go @@ -0,0 +1,25 @@ +package cli + +import ( + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" +) + +// keyVaultError wraps a Key Vault data-plane failure with a hint naming the +// role the signing principal needs, mirroring the failure modes in PRD §8. +func keyVaultError(err error, vault string) error { + return failuref("%v\n hint: ensure the signing principal has the 'Key Vault Crypto Officer' role on key vault %q", err, vault) +} + +// targetFrom builds the azurex.Target a Provider operates on from a loaded +// config. +func targetFrom(cfg *config.Config) azurex.Target { + return azurex.Target{ + SubscriptionID: cfg.SubscriptionID, + StorageAccount: cfg.StorageAccount, + KeyVault: cfg.KeyVault, + KeyName: cfg.KeyName, + DiscoveryPath: cfg.DiscoveryPath, + JWKSPath: cfg.JWKSPath, + } +} diff --git a/internal/cli/completion.go b/internal/cli/completion.go new file mode 100644 index 0000000..6357b59 --- /dev/null +++ b/internal/cli/completion.go @@ -0,0 +1,95 @@ +package cli + +import ( + "context" + "io" + "strings" + + ucli "github.com/urfave/cli/v3" +) + +const appName = "jotsmith" + +// supportedShells maps the shell names jotsmith accepts to the keys urfave's +// built-in completion renderer uses. jotsmith exposes the friendly name +// "powershell"; urfave's renderer registers it as "pwsh". +var supportedShells = []struct{ name, urfave string }{ + {"bash", "bash"}, + {"zsh", "zsh"}, + {"fish", "fish"}, + {"powershell", "pwsh"}, +} + +const completionDescription = `Output a shell completion script to stdout for the named shell. + +Install (pick your shell): + bash: source <(jotsmith completion bash) + zsh: source <(jotsmith completion zsh) + fish: jotsmith completion fish | source + powershell: jotsmith completion powershell | Out-String | Invoke-Expression + +Example: + jotsmith completion zsh` + +// configureCompletion reshapes urfave's auto-generated completion command so it +// writes the script to stdout, is visible in help, accepts the friendly shell +// name "powershell", and exits 2 (usage error) on an unknown or missing shell. +// Runtime completion via the --generate-shell-completion flag is unaffected. +func configureCompletion(streams IOStreams) ucli.ConfigureShellCompletionCommand { + return func(cc *ucli.Command) { + cc.Hidden = false + cc.Usage = "Output a shell completion script (bash, zsh, fish, powershell)" + cc.ArgsUsage = "bash|zsh|fish|powershell" + cc.Description = completionDescription + cc.Writer = streams.Out + cc.Action = func(ctx context.Context, cmd *ucli.Command) error { + if cmd.Args().Len() != 1 { + return usageErrorf("completion requires exactly one shell argument: one of %s", shellNames()) + } + return runCompletion(ctx, cmd.Args().First(), streams) + } + } +} + +func runCompletion(ctx context.Context, shell string, streams IOStreams) error { + urfaveShell := "" + for _, s := range supportedShells { + if s.name == shell { + urfaveShell = s.urfave + break + } + } + if urfaveShell == "" { + return usageErrorf("unknown shell %q; supported shells: %s", shell, shellNames()) + } + if err := renderShellCompletion(ctx, urfaveShell, streams.Out); err != nil { + return failuref("generating %s completion script: %v", shell, err) + } + return nil +} + +// renderShellCompletion delegates to urfave's built-in completion renderer by +// running a throwaway command tree whose completion command writes to w. This +// reuses the exact scripts urfave ships rather than vendoring copies. +func renderShellCompletion(ctx context.Context, urfaveShell string, w io.Writer) error { + helper := &ucli.Command{ + Name: appName, + Writer: w, + ErrWriter: io.Discard, + // Setting this hook makes urfave add its completion command; the hook + // redirects that command's output to w (it otherwise defaults to + // os.Stdout, not the parent's writer). + ConfigureShellCompletionCommand: func(cc *ucli.Command) { + cc.Writer = w + }, + } + return helper.Run(ctx, []string{appName, "completion", urfaveShell}) +} + +func shellNames() string { + names := make([]string, 0, len(supportedShells)) + for _, s := range supportedShells { + names = append(names, s.name) + } + return strings.Join(names, ", ") +} diff --git a/internal/cli/completion_test.go b/internal/cli/completion_test.go new file mode 100644 index 0000000..ae31ef2 --- /dev/null +++ b/internal/cli/completion_test.go @@ -0,0 +1,59 @@ +package cli + +import ( + "strings" + "testing" +) + +func TestCompletion_EachShellEmitsScriptToStdout(t *testing.T) { + for _, shell := range []string{"bash", "zsh", "fish", "powershell"} { + t.Run(shell, func(t *testing.T) { + stdout, stderr, err := run(t, "completion", shell) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if strings.TrimSpace(stdout) == "" { + t.Errorf("expected a non-empty %s completion script on stdout", shell) + } + if stderr != "" { + t.Errorf("expected empty stderr on success, got:\n%s", stderr) + } + }) + } +} + +func TestCompletion_UnknownShellIsUsageError(t *testing.T) { + stdout, _, err := run(t, "completion", "tcsh") + if err == nil { + t.Fatal("expected an error for an unknown shell") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } + if stdout != "" { + t.Errorf("expected empty stdout on usage error, got:\n%s", stdout) + } + if !strings.Contains(err.Error(), "bash, zsh, fish, powershell") { + t.Errorf("error should list supported shells, got: %v", err) + } +} + +func TestCompletion_MissingShellIsUsageError(t *testing.T) { + _, _, err := run(t, "completion") + if err == nil { + t.Fatal("expected a usage error when no shell is given") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } +} + +func TestCompletion_TooManyArgsIsUsageError(t *testing.T) { + _, _, err := run(t, "completion", "bash", "extra") + if err == nil { + t.Fatal("expected a usage error for too many arguments") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } +} diff --git a/internal/cli/config.go b/internal/cli/config.go new file mode 100644 index 0000000..413ca69 --- /dev/null +++ b/internal/cli/config.go @@ -0,0 +1,68 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + + ucli "github.com/urfave/cli/v3" + + "github.com/MaxAnderson95/jotsmith/internal/config" +) + +// loadConfig resolves the config path from the global --config flag (with its +// JOTSMITH_CONFIG / XDG fallbacks) and loads it. Any load error is mapped to a +// usage error (exit code 2), as documented for every config-backed command. +func loadConfig(ctx context.Context, cmd *ucli.Command) (*config.Config, error) { + path := config.ResolvePath(cmd.String("config")) + cfg, err := config.Load(path, loggerFrom(ctx)) + if err != nil { + return nil, usageErrorf("%v", err) + } + return cfg, nil +} + +func configCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "config", + Usage: "Inspect the jotsmith config file", + Commands: []*ucli.Command{ + configShowCommand(streams), + }, + } +} + +func configShowCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "show", + Usage: "Print the resolved config, or its path with --path", + Description: "Print the loaded config as pretty JSON to stdout.\n\n" + + "Examples:\n" + + " jotsmith config show # print the config contents\n" + + " jotsmith config show --path # print the resolved config file path", + Flags: []ucli.Flag{ + &ucli.BoolFlag{ + Name: "path", + Usage: "print the resolved config file path instead of its contents", + }, + }, + Action: func(ctx context.Context, cmd *ucli.Command) error { + if cmd.Bool("path") { + fmt.Fprintln(streams.Out, config.ResolvePath(cmd.String("config"))) + return nil + } + + cfg, err := loadConfig(ctx, cmd) + if err != nil { + return err + } + + b, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return failuref("rendering config as JSON: %v", err) + } + fmt.Fprintln(streams.Out, string(b)) + return nil + }, + } +} diff --git a/internal/cli/config_test.go b/internal/cli/config_test.go new file mode 100644 index 0000000..7b1f8e8 --- /dev/null +++ b/internal/cli/config_test.go @@ -0,0 +1,82 @@ +package cli + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +const testConfigBody = `{ + "version": 1, + "subscription_id": "00000000-0000-0000-0000-000000000000", + "storage_account": "jotsmithmax", + "key_vault": "jotsmith-max-kv", + "key_name": "signing-key", + "issuer": "https://jotsmithmax.z13.web.core.windows.net" +}` + +func writeTempConfig(t *testing.T, body string) string { + t.Helper() + path := filepath.Join(t.TempDir(), "config.json") + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("writing temp config: %v", err) + } + return path +} + +func TestConfigShowPrintsJSONToStdout(t *testing.T) { + path := writeTempConfig(t, testConfigBody) + stdout, _, err := run(t, "config", "show", "--config", path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var got map[string]any + if err := json.Unmarshal([]byte(stdout), &got); err != nil { + t.Fatalf("stdout is not valid JSON: %v\n%s", err, stdout) + } + if got["storage_account"] != "jotsmithmax" { + t.Errorf("unexpected storage_account: %v", got["storage_account"]) + } + // Path defaults should be surfaced in the effective config. + if got["jwks_path"] != ".well-known/jwks.json" { + t.Errorf("expected default jwks_path, got %v", got["jwks_path"]) + } +} + +func TestConfigShowPathPrintsResolvedPath(t *testing.T) { + path := writeTempConfig(t, testConfigBody) + stdout, _, err := run(t, "config", "show", "--path", "--config", path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if strings.TrimSpace(stdout) != path { + t.Errorf("expected resolved path %q, got %q", path, strings.TrimSpace(stdout)) + } +} + +func TestConfigShowPathDoesNotRequireExistingFile(t *testing.T) { + missing := filepath.Join(t.TempDir(), "absent.json") + stdout, _, err := run(t, "config", "show", "--path", "--config", missing) + if err != nil { + t.Fatalf("--path should not require the file to exist: %v", err) + } + if strings.TrimSpace(stdout) != missing { + t.Errorf("expected %q, got %q", missing, strings.TrimSpace(stdout)) + } +} + +func TestConfigShowLoadErrorExitsTwo(t *testing.T) { + missing := filepath.Join(t.TempDir(), "absent.json") + stdout, _, err := run(t, "config", "show", "--config", missing) + if err == nil { + t.Fatal("expected an error loading a missing config") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } + if stdout != "" { + t.Errorf("expected no stdout on load error, got:\n%s", stdout) + } +} diff --git a/internal/cli/context.go b/internal/cli/context.go new file mode 100644 index 0000000..5bfb7e4 --- /dev/null +++ b/internal/cli/context.go @@ -0,0 +1,27 @@ +package cli + +import ( + "context" + "log/slog" +) + +type ctxKey int + +const loggerKey ctxKey = iota + +func withLogger(ctx context.Context, l *slog.Logger) context.Context { + return context.WithValue(ctx, loggerKey, l) +} + +// loggerFrom returns the logger stored in ctx by the root Before hook, falling +// back to a discard logger so commands invoked in isolation never panic. +func loggerFrom(ctx context.Context) *slog.Logger { + if l, ok := ctx.Value(loggerKey).(*slog.Logger); ok && l != nil { + return l + } + return slog.New(slog.NewTextHandler(discard{}, nil)) +} + +type discard struct{} + +func (discard) Write(p []byte) (int, error) { return len(p), nil } diff --git a/internal/cli/destroy.go b/internal/cli/destroy.go new file mode 100644 index 0000000..ff8be7b --- /dev/null +++ b/internal/cli/destroy.go @@ -0,0 +1,112 @@ +package cli + +import ( + "context" + "errors" + "fmt" + "os" + + ucli "github.com/urfave/cli/v3" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/oidc" +) + +func destroyCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "destroy", + Usage: "Tear down the issuer-shaped state (keeps the Azure resources)", + Description: "Soft-delete the signing key and delete the published discovery / JWKS blobs. " + + "The Storage Account and Key Vault themselves are never deleted. The local config is " + + "kept unless --all is given.\n\n" + + "Prompts for confirmation unless --yes. Re-running against an already-destroyed issuer " + + "is a no-op that still exits 0.\n\n" + + "Example:\n" + + " jotsmith destroy --yes", + Flags: []ucli.Flag{ + &ucli.BoolFlag{Name: "yes", Usage: "skip the confirmation prompt"}, + &ucli.BoolFlag{Name: "all", Usage: "also delete the local config file"}, + }, + Action: func(ctx context.Context, cmd *ucli.Command) error { + cfg, err := loadConfig(ctx, cmd) + if err != nil { + return err + } + mgr, err := azurex.NewProvider(ctx, targetFrom(cfg), loggerFrom(ctx)) + if err != nil { + return failuref("%v", err) + } + configPath := config.ResolvePath(cmd.String("config")) + return runDestroy(ctx, mgr, cfg, configPath, cmd.Bool("yes"), cmd.Bool("all"), streams) + }, + } +} + +// runDestroy soft-deletes the signing key and the published blobs. It is +// idempotent: resources that are already gone are reported as WARN and do not +// fail the command. Anything actually deleted is reported to stderr; nothing is +// written to stdout. +func runDestroy(ctx context.Context, mgr azurex.DestroyManager, cfg *config.Config, configPath string, yes, all bool, streams IOStreams) error { + ok, cerr := confirm(streams, yes, fmt.Sprintf("Destroy the issuer state for %q? This soft-deletes the signing key and removes the published documents.", cfg.Issuer)) + if cerr != nil { + return cerr + } + if !ok { + fmt.Fprintln(streams.Err, "destroy aborted") + return nil + } + + fmt.Fprintln(streams.Err, "destroying issuer state:") + + deleted, err := mgr.DeleteSigningKey(ctx) + switch { + case err == nil: + fmt.Fprintf(streams.Err, " [deleted] signing key %q (soft delete; recover or purge it via `az keyvault key recover/purge --name %s`)\n", deleted.Name, deleted.Name) + case isAlreadyGone(err): + fmt.Fprintf(streams.Err, " [warn] signing key %q was already absent\n", cfg.KeyName) + default: + return keyVaultError(err, cfg.KeyVault) + } + + for _, path := range []string{cfg.DiscoveryPath, cfg.JWKSPath} { + url := oidc.JoinURL(cfg.Issuer, path) + switch derr := mgr.DeleteBlob(ctx, path); { + case derr == nil: + fmt.Fprintf(streams.Err, " [deleted] %s\n", url) + case isAlreadyGone(derr): + fmt.Fprintf(streams.Err, " [warn] %s was already absent\n", url) + default: + return failuref("%v", derr) + } + } + + if all { + if rerr := removeConfig(configPath, streams); rerr != nil { + return rerr + } + } else { + fmt.Fprintf(streams.Err, " [kept] config file %s (pass --all to remove it)\n", configPath) + } + + return nil +} + +func removeConfig(path string, streams IOStreams) error { + switch err := os.Remove(path); { + case err == nil: + fmt.Fprintf(streams.Err, " [deleted] config file %s\n", path) + case errors.Is(err, os.ErrNotExist): + fmt.Fprintf(streams.Err, " [warn] config file %s was already absent\n", path) + default: + return failuref("removing config file %s: %v", path, err) + } + return nil +} + +// isAlreadyGone reports whether err means the resource was already absent, which +// destroy treats as an idempotent success rather than a failure. +func isAlreadyGone(err error) bool { + var notFound *azurex.NotFoundError + return errors.As(err, ¬Found) +} diff --git a/internal/cli/destroy_test.go b/internal/cli/destroy_test.go new file mode 100644 index 0000000..c1043bd --- /dev/null +++ b/internal/cli/destroy_test.go @@ -0,0 +1,157 @@ +package cli + +import ( + "context" + "errors" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" +) + +// fakeDestroyer records destroy operations and can simulate per-resource +// outcomes (deleted, already-gone, or a hard error). +type fakeDestroyer struct { + keyErr error + keyDeleted bool + deletedPaths []string + blobErr map[string]error // per-path error; *NotFoundError means already gone +} + +func (f *fakeDestroyer) DeleteSigningKey(context.Context) (azurex.DeletedKey, error) { + if f.keyErr != nil { + return azurex.DeletedKey{}, f.keyErr + } + f.keyDeleted = true + return azurex.DeletedKey{Name: "signing-key", RecoveryID: "https://kv/deletedkeys/signing-key"}, nil +} + +func (f *fakeDestroyer) DeleteBlob(_ context.Context, path string) error { + if err := f.blobErr[path]; err != nil { + return err + } + f.deletedPaths = append(f.deletedPaths, path) + return nil +} + +func yesStreams() (IOStreams, *strings.Builder, *strings.Builder) { + var out, errBuf strings.Builder + return IOStreams{In: strings.NewReader(""), Out: &out, Err: &errBuf}, &out, &errBuf +} + +func TestRunDestroy_DeletesKeyAndBlobsNothingToStdout(t *testing.T) { + f := &fakeDestroyer{} + streams, out, errBuf := yesStreams() + + if err := runDestroy(context.Background(), f, testConfig(), "/tmp/cfg.json", true, false, streams); err != nil { + t.Fatalf("runDestroy: %v", err) + } + if !f.keyDeleted { + t.Error("signing key was not deleted") + } + if len(f.deletedPaths) != 2 { + t.Errorf("expected both blobs deleted, got %v", f.deletedPaths) + } + if out.Len() != 0 { + t.Errorf("destroy must write nothing to stdout, got:\n%s", out.String()) + } + es := errBuf.String() + if !strings.Contains(es, "signing key") || !strings.Contains(es, "recover or purge") { + t.Errorf("stderr should report the key deletion + recovery hint, got:\n%s", es) + } + if !strings.Contains(es, config.DefaultJWKSPath) { + t.Errorf("stderr should report the deleted blob URLs, got:\n%s", es) + } +} + +func TestRunDestroy_NonTTYWithoutYesFails(t *testing.T) { + f := &fakeDestroyer{} + streams, _, _ := yesStreams() + err := runDestroy(context.Background(), f, testConfig(), "/tmp/cfg.json", false, false, streams) + if err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 refusal without --yes on non-TTY, got %v (code %d)", err, ExitCode(err)) + } + if f.keyDeleted || len(f.deletedPaths) != 0 { + t.Error("nothing should be deleted when confirmation is refused") + } +} + +func TestRunDestroy_IdempotentWhenAlreadyGone(t *testing.T) { + f := &fakeDestroyer{ + keyErr: &azurex.NotFoundError{Kind: "signing key", Resource: "signing-key"}, + blobErr: map[string]error{ + config.DefaultDiscoveryPath: &azurex.NotFoundError{Kind: "blob", Resource: config.DefaultDiscoveryPath}, + config.DefaultJWKSPath: &azurex.NotFoundError{Kind: "blob", Resource: config.DefaultJWKSPath}, + }, + } + streams, _, errBuf := yesStreams() + if err := runDestroy(context.Background(), f, testConfig(), "/tmp/cfg.json", true, false, streams); err != nil { + t.Fatalf("idempotent destroy should exit 0, got %v", err) + } + if c := strings.Count(errBuf.String(), "already absent"); c != 3 { + t.Errorf("expected 3 'already absent' WARN lines (key + 2 blobs), got %d:\n%s", c, errBuf.String()) + } +} + +func TestRunDestroy_HardKeyErrorFails(t *testing.T) { + f := &fakeDestroyer{keyErr: errors.New("403 forbidden")} + streams, _, _ := yesStreams() + err := runDestroy(context.Background(), f, testConfig(), "/tmp/cfg.json", true, false, streams) + if err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 on hard KV error, got %v (code %d)", err, ExitCode(err)) + } +} + +func TestRunDestroy_AllRemovesConfig(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(cfgPath, []byte("{}"), 0o600); err != nil { + t.Fatal(err) + } + + f := &fakeDestroyer{} + streams, _, errBuf := yesStreams() + if err := runDestroy(context.Background(), f, testConfig(), cfgPath, true, true, streams); err != nil { + t.Fatalf("runDestroy --all: %v", err) + } + if _, err := os.Stat(cfgPath); !os.IsNotExist(err) { + t.Errorf("--all should remove the config file, stat err=%v", err) + } + if !strings.Contains(errBuf.String(), "config file") { + t.Errorf("stderr should report config deletion, got:\n%s", errBuf.String()) + } +} + +func TestRunDestroy_WithoutAllKeepsConfig(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(cfgPath, []byte("{}"), 0o600); err != nil { + t.Fatal(err) + } + + f := &fakeDestroyer{} + streams, _, errBuf := yesStreams() + if err := runDestroy(context.Background(), f, testConfig(), cfgPath, true, false, streams); err != nil { + t.Fatalf("runDestroy: %v", err) + } + if _, err := os.Stat(cfgPath); err != nil { + t.Errorf("config should be kept without --all, stat err=%v", err) + } + if !strings.Contains(errBuf.String(), "pass --all to remove it") { + t.Errorf("stderr should note the config was kept, got:\n%s", errBuf.String()) + } +} + +func TestDestroyMissingConfigIsUsageError(t *testing.T) { + t.Setenv("JOTSMITH_CONFIG", filepath.Join(t.TempDir(), "absent.json")) + _, _, err := run(t, "destroy", "--yes") + if err == nil { + t.Fatal("expected a usage error when config is missing") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } +} diff --git a/internal/cli/discovery.go b/internal/cli/discovery.go new file mode 100644 index 0000000..1441b91 --- /dev/null +++ b/internal/cli/discovery.go @@ -0,0 +1,45 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + + ucli "github.com/urfave/cli/v3" + + "github.com/MaxAnderson95/jotsmith/internal/oidc" +) + +func discoveryCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "discovery", + Usage: "Show the OIDC discovery document", + Commands: []*ucli.Command{ + discoveryShowCommand(streams), + }, + } +} + +func discoveryShowCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "show", + Usage: "Print the discovery document jotsmith would publish", + Description: "Render and print the discovery document to stdout exactly as setup / " + + "doctor --repair would upload it. Does not fetch from the network.\n\n" + + "Example:\n" + + " jotsmith discovery show", + Action: func(ctx context.Context, cmd *ucli.Command) error { + cfg, err := loadConfig(ctx, cmd) + if err != nil { + return err + } + doc := oidc.Render(cfg.Issuer, cfg.JWKSPath) + b, err := json.MarshalIndent(doc, "", " ") + if err != nil { + return failuref("rendering discovery document: %v", err) + } + fmt.Fprintln(streams.Out, string(b)) + return nil + }, + } +} diff --git a/internal/cli/discovery_jwks_test.go b/internal/cli/discovery_jwks_test.go new file mode 100644 index 0000000..747c5ca --- /dev/null +++ b/internal/cli/discovery_jwks_test.go @@ -0,0 +1,58 @@ +package cli + +import ( + "context" + "encoding/json" + "testing" + + "github.com/MaxAnderson95/jotsmith/internal/jwk" +) + +func TestDiscoveryShowPrintsRenderedDoc(t *testing.T) { + path := writeTempConfig(t, testConfigBody) + stdout, stderr, err := run(t, "discovery", "show", "--config", path) + if err != nil { + t.Fatalf("unexpected error: %v (stderr: %s)", err, stderr) + } + var doc struct { + Issuer string `json:"issuer"` + JWKSURI string `json:"jwks_uri"` + } + if err := json.Unmarshal([]byte(stdout), &doc); err != nil { + t.Fatalf("stdout not valid discovery JSON: %v\n%s", err, stdout) + } + if doc.Issuer != "https://jotsmithmax.z13.web.core.windows.net" { + t.Errorf("issuer = %q", doc.Issuer) + } + if doc.JWKSURI != "https://jotsmithmax.z13.web.core.windows.net/.well-known/jwks.json" { + t.Errorf("jwks_uri = %q", doc.JWKSURI) + } +} + +func TestBuildJWKS_FromInspector(t *testing.T) { + insp := healthyInspector(t) + set, err := buildJWKS(context.Background(), insp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(set.Keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(set.Keys)) + } + want := jwk.Thumbprint(insp.key.N, insp.key.E) + if set.Keys[0].Kid != want { + t.Errorf("kid = %q, want %q", set.Keys[0].Kid, want) + } + if set.Keys[0].Kty != "RSA" || set.Keys[0].Alg != "RS256" { + t.Errorf("unexpected JWK fields: %+v", set.Keys[0]) + } +} + +func TestBuildJWKS_PropagatesError(t *testing.T) { + insp := healthyInspector(t) + insp.key = healthyInspector(t).key + insp.keyErr = context.DeadlineExceeded + _, err := buildJWKS(context.Background(), insp) + if err == nil { + t.Fatal("expected error to propagate from inspector") + } +} diff --git a/internal/cli/doc.go b/internal/cli/doc.go new file mode 100644 index 0000000..0cea42e --- /dev/null +++ b/internal/cli/doc.go @@ -0,0 +1,7 @@ +// Package cli defines the jotsmith command tree built on urfave/cli/v3. +// +// Each top-level noun lives in its own file (root.go, plus token.go, key.go, +// etc. as they are added) and is wired into the root command in +// cmd/jotsmith/main.go. Every command inherits the global flags declared on the +// root: --config, --log-level, and --no-color. +package cli diff --git a/internal/cli/doctor.go b/internal/cli/doctor.go new file mode 100644 index 0000000..0526d9a --- /dev/null +++ b/internal/cli/doctor.go @@ -0,0 +1,374 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + "io" + "slices" + + ucli "github.com/urfave/cli/v3" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/jwk" + "github.com/MaxAnderson95/jotsmith/internal/oidc" +) + +// Check names. Defined as constants so the inspect, repair, and report stages +// refer to the same identifiers. +const ( + checkCredential = "Azure credential resolvable" + checkSubscription = "Subscription accessible" + checkStorageAccount = "Storage account exists" + checkStaticWebsite = "Static website hosting enabled" + checkIssuerEndpoint = "Issuer matches web endpoint" + checkDiscovery = "Discovery document published" + checkJWKS = "JWKS published" + checkVaultRBAC = "Key Vault in RBAC mode" + checkSigningKey = "Signing key valid" +) + +// checkStatus is the outcome of a single doctor check. +type checkStatus string + +const ( + statusPass checkStatus = "PASS" + statusWarn checkStatus = "WARN" + statusFail checkStatus = "FAIL" +) + +type check struct { + Name string `json:"name"` + Status checkStatus `json:"status"` + Message string `json:"message"` + Repaired bool `json:"repaired,omitempty"` +} + +// doctorReport is the --json output shape. +type doctorReport struct { + Checks []check `json:"checks"` + Summary doctorSummary `json:"summary"` +} + +type doctorSummary struct { + Pass int `json:"pass"` + Warn int `json:"warn"` + Fail int `json:"fail"` + Repaired int `json:"repaired"` +} + +func anyFailed(checks []check) bool { + for _, c := range checks { + if c.Status == statusFail { + return true + } + } + return false +} + +// checkFailed reports whether the named check is present and FAILed. +func checkFailed(checks []check, name string) bool { + for _, c := range checks { + if c.Name == name { + return c.Status == statusFail + } + } + return false +} + +func doctorCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "doctor", + Usage: "Audit Azure state against the local config", + Description: "Audit the configured Azure resources against the local config and report " + + "PASS / WARN / FAIL per check. Read-only by default.\n\n" + + "--repair fixes in place every FAIL it can: it re-enables static website hosting and " + + "re-uploads the discovery document / JWKS when they are missing, malformed, or stale. " + + "FAILs needing human action (legacy access-policy Key Vault, issuer/web-endpoint " + + "mismatch) are reported but not changed.\n\n" + + "--json emits a single machine-readable JSON object on stdout and suppresses the " + + "pretty report.\n\n" + + "Examples:\n" + + " jotsmith doctor\n" + + " jotsmith doctor --repair\n" + + " jotsmith doctor --json", + Flags: []ucli.Flag{ + &ucli.BoolFlag{Name: "e2e", Usage: "additionally mint and verify a short-lived token end to end"}, + &ucli.BoolFlag{Name: "repair", Usage: "fix in place every FAIL the tool knows how to fix"}, + &ucli.BoolFlag{Name: "json", Usage: "emit machine-readable JSON on stdout instead of the pretty report"}, + }, + Action: func(ctx context.Context, cmd *ucli.Command) error { + cfg, err := loadConfig(ctx, cmd) + if err != nil { + return err + } + opts := doctorOptions{ + e2e: cmd.Bool("e2e"), + repair: cmd.Bool("repair"), + json: cmd.Bool("json"), + } + color := !noColor(cmd) && isTerminal(streams.Err) + + mgr, err := azurex.NewProvider(ctx, targetFrom(cfg), loggerFrom(ctx)) + if err != nil { + credFail := []check{{Name: checkCredential, Status: statusFail, Message: err.Error()}} + if opts.json { + _ = writeDoctorJSON(streams.Out, credFail) + return ucli.Exit("", exitFailure) + } + renderDoctor(streams.Err, color, credFail) + return ucli.Exit("doctor: Azure credential could not be resolved", exitFailure) + } + return runDoctor(ctx, mgr, cfg, opts, color, streams) + }, + } +} + +type doctorOptions struct { + e2e bool + repair bool + json bool +} + +// runDoctor inspects the issuer, optionally repairs fixable FAILs, then renders +// the report. Exit code is 0 when no FAIL remains after repairs, 1 otherwise. +func runDoctor(ctx context.Context, mgr azurex.RepairManager, cfg *config.Config, opts doctorOptions, color bool, streams IOStreams) error { + checks := inspect(ctx, mgr, cfg) + + if opts.repair { + repaired := repairChecks(ctx, mgr, cfg, checks) + // Re-read the live state so the report reflects the repairs. + checks = inspect(ctx, mgr, cfg) + for i := range checks { + if repaired[checks[i].Name] && checks[i].Status == statusPass { + checks[i].Repaired = true + } + } + } + + // The credential check is implicit: reaching here means the provider was + // built, so the credential resolved. It is not repairable. + full := append([]check{{Name: checkCredential, Status: statusPass, Message: "resolved via DefaultAzureCredential"}}, checks...) + + if opts.e2e { + full = append(full, check{ + Name: "End-to-end mint and verify", + Status: statusWarn, + Message: "skipped: end-to-end check is not yet implemented", + }) + } + + if opts.json { + if err := writeDoctorJSON(streams.Out, full); err != nil { + return failuref("rendering doctor JSON: %v", err) + } + } else { + renderDoctor(streams.Err, color, full) + } + + if anyFailed(full) { + if opts.json { + return ucli.Exit("", exitFailure) + } + return ucli.Exit("doctor: one or more checks failed", exitFailure) + } + return nil +} + +// repairChecks performs the in-place fixes for fixable FAILs and returns the set +// of check names a fix was successfully applied for. It mirrors setup's +// idempotent upload logic. Human-action FAILs (legacy vault, issuer mismatch) +// are intentionally never touched. +func repairChecks(ctx context.Context, mgr azurex.RepairManager, cfg *config.Config, checks []check) map[string]bool { + repaired := map[string]bool{} + + if checkFailed(checks, checkStaticWebsite) { + if err := mgr.EnableStaticWebsite(ctx); err == nil { + repaired[checkStaticWebsite] = true + } + } + + if checkFailed(checks, checkDiscovery) { + if doc, err := json.MarshalIndent(oidc.Render(cfg.Issuer, cfg.JWKSPath), "", " "); err == nil { + if uerr := mgr.UploadBlob(ctx, cfg.DiscoveryPath, contentTypeJSON, cacheControl, doc); uerr == nil { + repaired[checkDiscovery] = true + } + } + } + + // A malformed/stale JWKS and a thumbprint mismatch are both fixed by + // re-uploading the JWKS for the current key — but only when the key itself + // is usable (a disabled or sign-less key is a human-action FAIL). + if checkFailed(checks, checkJWKS) || checkFailed(checks, checkSigningKey) { + if key, err := mgr.GetSigningKey(ctx); err == nil && key.Enabled && slices.Contains(key.Ops, "sign") { + if doc, derr := json.MarshalIndent(jwk.NewSet(jwk.FromRSA(key.N, key.E)), "", " "); derr == nil { + if uerr := mgr.UploadBlob(ctx, cfg.JWKSPath, contentTypeJSON, cacheControl, doc); uerr == nil { + repaired[checkJWKS] = true + repaired[checkSigningKey] = true + } + } + } + } + + return repaired +} + +// inspect runs the read-only checks from PRD §6.6 against insp and returns one +// result per check, in report order. It never panics on Azure errors: every +// error becomes a FAIL with a descriptive message. +func inspect(ctx context.Context, insp azurex.Inspector, cfg *config.Config) []check { + var checks []check + add := func(name string, status checkStatus, format string, args ...any) { + checks = append(checks, check{Name: name, Status: status, Message: fmt.Sprintf(format, args...)}) + } + + if sub, err := insp.GetSubscription(ctx); err != nil { + add(checkSubscription, statusFail, "%v", err) + } else { + add(checkSubscription, statusPass, "%s (%s)", sub.DisplayName, sub.ID) + } + + acct, acctErr := insp.GetStorageAccount(ctx) + if acctErr != nil { + add(checkStorageAccount, statusFail, "%v", acctErr) + } else { + add(checkStorageAccount, statusPass, "%s (resource group %s, kind %s)", acct.Name, acct.ResourceGroup, acct.Kind) + } + + switch sw, err := insp.GetStaticWebsite(ctx); { + case err != nil: + add(checkStaticWebsite, statusFail, "%v", err) + case !sw.Enabled: + add(checkStaticWebsite, statusFail, "static website hosting is disabled on %q", cfg.StorageAccount) + default: + add(checkStaticWebsite, statusPass, "enabled") + } + + switch { + case acctErr != nil: + add(checkIssuerEndpoint, statusFail, "cannot verify: storage account is unavailable") + case acct.WebEndpoint == "": + add(checkIssuerEndpoint, statusFail, "storage account has no primary web endpoint") + case acct.WebEndpoint != cfg.Issuer: + add(checkIssuerEndpoint, statusFail, "config issuer %q does not match storage web endpoint %q; run `jotsmith setup --force-issuer-rewrite` to update the config issuer", cfg.Issuer, acct.WebEndpoint) + default: + add(checkIssuerEndpoint, statusPass, "%s", cfg.Issuer) + } + + if data, err := insp.GetDiscoveryDocument(ctx); err != nil { + add(checkDiscovery, statusFail, "%v", err) + } else { + var doc struct { + Issuer string `json:"issuer"` + } + switch { + case json.Unmarshal(data, &doc) != nil: + add(checkDiscovery, statusFail, "published discovery document is not valid JSON") + case doc.Issuer != cfg.Issuer: + add(checkDiscovery, statusFail, "published issuer %q does not match config issuer %q", doc.Issuer, cfg.Issuer) + default: + add(checkDiscovery, statusPass, "issuer %s", doc.Issuer) + } + } + + publishedKid := "" + if data, err := insp.GetJWKS(ctx); err != nil { + add(checkJWKS, statusFail, "%v", err) + } else { + var set jwk.Set + switch { + case json.Unmarshal(data, &set) != nil: + add(checkJWKS, statusFail, "published JWKS is not valid JSON") + case len(set.Keys) == 0: + add(checkJWKS, statusFail, "published JWKS contains no keys") + case set.Keys[0].Kty != "RSA" || set.Keys[0].N == "" || set.Keys[0].E == "": + add(checkJWKS, statusFail, "published JWKS key is not a valid RSA JWK") + default: + publishedKid = set.Keys[0].Kid + add(checkJWKS, statusPass, "1 RSA key, kid %s", publishedKid) + } + } + + if v, err := insp.GetVault(ctx); err != nil { + add(checkVaultRBAC, statusFail, "%v", err) + } else if !v.RBACEnabled { + add(checkVaultRBAC, statusFail, "key vault %q is in legacy access-policy mode; jotsmith requires Azure RBAC mode", v.Name) + } else { + add(checkVaultRBAC, statusPass, "%s", v.Name) + } + + switch k, err := insp.GetSigningKey(ctx); { + case err != nil: + add(checkSigningKey, statusFail, "%v", err) + case !k.Enabled: + add(checkSigningKey, statusFail, "signing key %q is disabled", k.Name) + case !slices.Contains(k.Ops, "sign"): + add(checkSigningKey, statusFail, "signing key %q lacks the sign operation", k.Name) + default: + computed := jwk.Thumbprint(k.N, k.E) + switch { + case publishedKid == "": + add(checkSigningKey, statusWarn, "thumbprint %s computed, but published JWKS kid is unavailable to compare", computed) + case computed != publishedKid: + add(checkSigningKey, statusFail, "key thumbprint %s does not match published JWKS kid %s", computed, publishedKid) + default: + add(checkSigningKey, statusPass, "enabled, can sign, thumbprint %s matches JWKS", computed) + } + } + + return checks +} + +func writeDoctorJSON(w io.Writer, checks []check) error { + report := doctorReport{Checks: checks, Summary: summarize(checks)} + b, err := json.MarshalIndent(report, "", " ") + if err != nil { + return err + } + _, err = fmt.Fprintln(w, string(b)) + return err +} + +func summarize(checks []check) doctorSummary { + var s doctorSummary + for _, c := range checks { + switch c.Status { + case statusPass: + s.Pass++ + case statusWarn: + s.Warn++ + case statusFail: + s.Fail++ + } + if c.Repaired { + s.Repaired++ + } + } + return s +} + +func renderDoctor(w io.Writer, color bool, checks []check) { + for _, c := range checks { + suffix := "" + if c.Repaired { + suffix = " (repaired)" + } + fmt.Fprintf(w, "%s %s: %s%s\n", statusLabel(c.Status, color), c.Name, c.Message, suffix) + } +} + +func statusLabel(s checkStatus, color bool) string { + text := "[" + string(s) + "]" + if !color { + return text + } + switch s { + case statusPass: + return "\x1b[32m" + text + "\x1b[0m" + case statusWarn: + return "\x1b[33m" + text + "\x1b[0m" + default: + return "\x1b[31m" + text + "\x1b[0m" + } +} diff --git a/internal/cli/doctor_repair_test.go b/internal/cli/doctor_repair_test.go new file mode 100644 index 0000000..1aa2912 --- /dev/null +++ b/internal/cli/doctor_repair_test.go @@ -0,0 +1,165 @@ +package cli + +import ( + "context" + "encoding/json" + "path/filepath" + "strings" + "testing" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" +) + +// healthyRepairManager wraps a fully-healthy inspector in a fakeManager so +// doctor --repair can drive the mutating methods. +func healthyRepairManager(t *testing.T) *fakeManager { + t.Helper() + return &fakeManager{fakeInspector: healthyInspector(t)} +} + +func TestRunDoctor_RepairReuploadsMissingBlobs(t *testing.T) { + mgr := healthyRepairManager(t) + mgr.disco, mgr.discoErr = nil, &azurex.NotFoundError{Kind: "blob", Resource: config.DefaultDiscoveryPath} + mgr.jwks, mgr.jwksErr = nil, &azurex.NotFoundError{Kind: "blob", Resource: config.DefaultJWKSPath} + + streams, out, _ := mintStreams() + err := runDoctor(context.Background(), mgr, testConfig(), doctorOptions{repair: true, json: true}, false, streams) + if err != nil { + t.Fatalf("expected exit 0 after repair, got %v", err) + } + + report := parseReport(t, out.String()) + if c := reportCheck(t, report, checkDiscovery); c.Status != statusPass || !c.Repaired { + t.Errorf("discovery should be PASS+repaired, got %s repaired=%v", c.Status, c.Repaired) + } + if c := reportCheck(t, report, checkJWKS); c.Status != statusPass || !c.Repaired { + t.Errorf("JWKS should be PASS+repaired, got %s repaired=%v", c.Status, c.Repaired) + } + if _, ok := mgr.uploads[config.DefaultDiscoveryPath]; !ok { + t.Error("discovery doc was not re-uploaded") + } + if _, ok := mgr.uploads[config.DefaultJWKSPath]; !ok { + t.Error("JWKS was not re-uploaded") + } +} + +func TestRunDoctor_RepairReenablesStaticWebsite(t *testing.T) { + mgr := healthyRepairManager(t) + mgr.sw.Enabled = false + + streams, out, _ := mintStreams() + if err := runDoctor(context.Background(), mgr, testConfig(), doctorOptions{repair: true, json: true}, false, streams); err != nil { + t.Fatalf("expected exit 0 after repair, got %v", err) + } + report := parseReport(t, out.String()) + if c := reportCheck(t, report, checkStaticWebsite); c.Status != statusPass || !c.Repaired { + t.Errorf("static website should be PASS+repaired, got %s repaired=%v", c.Status, c.Repaired) + } +} + +func TestRunDoctor_DoesNotRepairHumanActionFails(t *testing.T) { + t.Run("legacy access-policy vault", func(t *testing.T) { + mgr := healthyRepairManager(t) + mgr.vault.RBACEnabled = false + streams, out, _ := mintStreams() + err := runDoctor(context.Background(), mgr, testConfig(), doctorOptions{repair: true, json: true}, false, streams) + if err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit 1 for unrepairable vault, got %v", err) + } + if c := reportCheck(t, parseReport(t, out.String()), checkVaultRBAC); c.Status != statusFail || c.Repaired { + t.Errorf("vault check should stay FAIL and not be marked repaired, got %s repaired=%v", c.Status, c.Repaired) + } + }) + + t.Run("issuer/endpoint mismatch advises setup --force-issuer-rewrite", func(t *testing.T) { + mgr := healthyRepairManager(t) + mgr.acct.WebEndpoint = "https://wrong.z99.web.core.windows.net" + streams, out, _ := mintStreams() + err := runDoctor(context.Background(), mgr, testConfig(), doctorOptions{repair: true, json: true}, false, streams) + if err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit 1 for endpoint mismatch, got %v", err) + } + c := reportCheck(t, parseReport(t, out.String()), checkIssuerEndpoint) + if c.Status != statusFail { + t.Fatalf("endpoint check should stay FAIL, got %s", c.Status) + } + if !strings.Contains(c.Message, "setup --force-issuer-rewrite") { + t.Errorf("endpoint FAIL should advise setup --force-issuer-rewrite, got: %s", c.Message) + } + }) +} + +func TestRunDoctor_JSONShapeAndStdoutPurity(t *testing.T) { + mgr := healthyRepairManager(t) + streams, out, errBuf := mintStreams() + if err := runDoctor(context.Background(), mgr, testConfig(), doctorOptions{json: true}, false, streams); err != nil { + t.Fatalf("healthy doctor should exit 0, got %v", err) + } + + // stdout is exactly one JSON object; stderr carries no pretty report. + trimmed := strings.TrimSpace(out.String()) + if !strings.HasPrefix(trimmed, "{") || !strings.HasSuffix(trimmed, "}") { + t.Errorf("stdout should be a single JSON object, got:\n%s", out.String()) + } + if strings.Contains(errBuf.String(), "[PASS]") { + t.Errorf("--json must suppress the pretty stderr report, got:\n%s", errBuf.String()) + } + + report := parseReport(t, out.String()) + total := report.Summary.Pass + report.Summary.Warn + report.Summary.Fail + if total != len(report.Checks) { + t.Errorf("summary totals (%d) do not match number of checks (%d)", total, len(report.Checks)) + } + if report.Summary.Fail != 0 { + t.Errorf("healthy issuer should have zero FAILs, got %d", report.Summary.Fail) + } + // The credential check is always present and first. + if len(report.Checks) == 0 || report.Checks[0].Name != checkCredential { + t.Errorf("first check should be %q", checkCredential) + } +} + +func TestRunDoctor_JSONReportsRepairedCount(t *testing.T) { + mgr := healthyRepairManager(t) + mgr.sw.Enabled = false + streams, out, _ := mintStreams() + if err := runDoctor(context.Background(), mgr, testConfig(), doctorOptions{repair: true, json: true}, false, streams); err != nil { + t.Fatalf("expected exit 0, got %v", err) + } + report := parseReport(t, out.String()) + if report.Summary.Repaired < 1 { + t.Errorf("expected at least one repaired check in summary, got %d", report.Summary.Repaired) + } +} + +func TestDoctorMissingConfigIsUsageError(t *testing.T) { + t.Setenv("JOTSMITH_CONFIG", filepath.Join(t.TempDir(), "absent.json")) + _, _, err := run(t, "doctor", "--json") + if err == nil { + t.Fatal("expected a usage error when config is missing") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } +} + +func parseReport(t *testing.T, stdout string) doctorReport { + t.Helper() + var r doctorReport + if err := json.Unmarshal([]byte(stdout), &r); err != nil { + t.Fatalf("stdout is not a valid doctor report: %v\n%s", err, stdout) + } + return r +} + +func reportCheck(t *testing.T, r doctorReport, name string) check { + t.Helper() + for _, c := range r.Checks { + if c.Name == name { + return c + } + } + t.Fatalf("check %q not found in report", name) + return check{} +} diff --git a/internal/cli/doctor_test.go b/internal/cli/doctor_test.go new file mode 100644 index 0000000..b806d65 --- /dev/null +++ b/internal/cli/doctor_test.go @@ -0,0 +1,206 @@ +package cli + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "math/big" + "testing" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/jwk" +) + +type fakeInspector struct { + sub azurex.Subscription + subErr error + acct azurex.StorageAccount + acctErr error + sw azurex.StaticWebsite + swErr error + disco []byte + discoErr error + jwks []byte + jwksErr error + vault azurex.Vault + vaultErr error + key azurex.Key + keyErr error +} + +func (f *fakeInspector) GetSubscription(context.Context) (azurex.Subscription, error) { + return f.sub, f.subErr +} +func (f *fakeInspector) GetStorageAccount(context.Context) (azurex.StorageAccount, error) { + return f.acct, f.acctErr +} +func (f *fakeInspector) GetStaticWebsite(context.Context) (azurex.StaticWebsite, error) { + return f.sw, f.swErr +} +func (f *fakeInspector) GetDiscoveryDocument(context.Context) ([]byte, error) { + return f.disco, f.discoErr +} +func (f *fakeInspector) GetJWKS(context.Context) ([]byte, error) { return f.jwks, f.jwksErr } +func (f *fakeInspector) GetVault(context.Context) (azurex.Vault, error) { + return f.vault, f.vaultErr +} +func (f *fakeInspector) GetSigningKey(context.Context) (azurex.Key, error) { + return f.key, f.keyErr +} + +const issuer = "https://jotsmithmax.z13.web.core.windows.net" + +func testConfig() *config.Config { + return &config.Config{ + Version: 1, + SubscriptionID: "sub-123", + StorageAccount: "jotsmithmax", + KeyVault: "jotsmith-kv", + KeyName: "signing-key", + Issuer: issuer, + JWKSPath: config.DefaultJWKSPath, + DiscoveryPath: config.DefaultDiscoveryPath, + } +} + +// healthyInspector returns a fake whose every check passes, using a real RSA +// key so the published JWKS kid matches the signing key thumbprint. +func healthyInspector(t *testing.T) *fakeInspector { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generating RSA key: %v", err) + } + n := priv.N.Bytes() + e := big.NewInt(int64(priv.E)).Bytes() + + jwksJSON, err := json.Marshal(jwk.NewSet(jwk.FromRSA(n, e))) + if err != nil { + t.Fatalf("marshaling jwks: %v", err) + } + discoJSON := []byte(`{"issuer":"` + issuer + `","jwks_uri":"` + issuer + `/.well-known/jwks.json"}`) + + return &fakeInspector{ + sub: azurex.Subscription{ID: "sub-123", DisplayName: "Test Sub", State: "Enabled"}, + acct: azurex.StorageAccount{Name: "jotsmithmax", ResourceGroup: "rg", Kind: "StorageV2", WebEndpoint: issuer, BlobEndpoint: "https://jotsmithmax.blob.core.windows.net/"}, + sw: azurex.StaticWebsite{Enabled: true}, + disco: discoJSON, + jwks: jwksJSON, + vault: azurex.Vault{Name: "jotsmith-kv", RBACEnabled: true, URI: "https://jotsmith-kv.vault.azure.net/"}, + key: azurex.Key{Name: "signing-key", Enabled: true, Ops: []string{"sign", "verify"}, N: n, E: e}, + } +} + +func findCheck(t *testing.T, checks []check, name string) check { + t.Helper() + for _, c := range checks { + if c.Name == name { + return c + } + } + t.Fatalf("check %q not found in %v", name, checks) + return check{} +} + +func TestInspect_AllHealthy(t *testing.T) { + checks := inspect(context.Background(), healthyInspector(t), testConfig()) + if anyFailed(checks) { + for _, c := range checks { + if c.Status == statusFail { + t.Errorf("unexpected FAIL: %s: %s", c.Name, c.Message) + } + } + } + if got := findCheck(t, checks, "Signing key valid"); got.Status != statusPass { + t.Errorf("expected signing key PASS, got %s: %s", got.Status, got.Message) + } +} + +func TestInspect_LegacyAccessPolicyVaultFails(t *testing.T) { + f := healthyInspector(t) + f.vault.RBACEnabled = false + checks := inspect(context.Background(), f, testConfig()) + got := findCheck(t, checks, "Key Vault in RBAC mode") + if got.Status != statusFail { + t.Fatalf("expected FAIL for legacy vault, got %s", got.Status) + } +} + +func TestInspect_EndpointMismatchFails(t *testing.T) { + f := healthyInspector(t) + f.acct.WebEndpoint = "https://wrong.z99.web.core.windows.net" + checks := inspect(context.Background(), f, testConfig()) + got := findCheck(t, checks, "Issuer matches web endpoint") + if got.Status != statusFail { + t.Fatalf("expected FAIL for endpoint mismatch, got %s: %s", got.Status, got.Message) + } +} + +func TestInspect_MissingDiscoveryBlobFailsNoCrash(t *testing.T) { + f := healthyInspector(t) + f.disco = nil + f.discoErr = &azurex.NotFoundError{Kind: "blob", Resource: "$web/.well-known/openid-configuration"} + checks := inspect(context.Background(), f, testConfig()) + got := findCheck(t, checks, "Discovery document published") + if got.Status != statusFail { + t.Fatalf("expected FAIL for missing discovery doc, got %s", got.Status) + } +} + +func TestInspect_ThumbprintMismatchFails(t *testing.T) { + f := healthyInspector(t) + // Replace the published JWKS with a different key so the kid won't match. + other, _ := rsa.GenerateKey(rand.Reader, 2048) + on := other.N.Bytes() + oe := big.NewInt(int64(other.E)).Bytes() + f.jwks, _ = json.Marshal(jwk.NewSet(jwk.FromRSA(on, oe))) + + checks := inspect(context.Background(), f, testConfig()) + got := findCheck(t, checks, "Signing key valid") + if got.Status != statusFail { + t.Fatalf("expected FAIL for thumbprint mismatch, got %s: %s", got.Status, got.Message) + } +} + +func TestInspect_DisabledKeyFails(t *testing.T) { + f := healthyInspector(t) + f.key.Enabled = false + checks := inspect(context.Background(), f, testConfig()) + got := findCheck(t, checks, "Signing key valid") + if got.Status != statusFail { + t.Fatalf("expected FAIL for disabled key, got %s", got.Status) + } +} + +func TestRenderDoctorNoColor(t *testing.T) { + checks := []check{ + {Name: "A", Status: statusPass, Message: "ok"}, + {Name: "B", Status: statusFail, Message: "broken"}, + } + var buf testBuffer + renderDoctor(&buf, false, checks) + out := buf.String() + if !contains(out, "[PASS] A: ok") || !contains(out, "[FAIL] B: broken") { + t.Errorf("unexpected render:\n%s", out) + } + if contains(out, "\x1b[") { + t.Errorf("expected no ANSI color, got:\n%q", out) + } +} + +// minimal io.Writer for capturing render output without importing bytes here +type testBuffer struct{ b []byte } + +func (t *testBuffer) Write(p []byte) (int, error) { t.b = append(t.b, p...); return len(p), nil } +func (t *testBuffer) String() string { return string(t.b) } + +func contains(s, sub string) bool { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} diff --git a/internal/cli/errors.go b/internal/cli/errors.go new file mode 100644 index 0000000..68a2c7e --- /dev/null +++ b/internal/cli/errors.go @@ -0,0 +1,44 @@ +package cli + +import ( + "errors" + "fmt" + + ucli "github.com/urfave/cli/v3" +) + +// Exit codes used across jotsmith commands: +// +// 0 — success +// 1 — runtime failure (Azure error, verification failure, I/O error) +// 2 — usage error (bad flags, missing required input, invalid arguments) +const ( + exitOK = 0 + exitFailure = 1 + exitUsage = 2 +) + +// usageErrorf wraps a usage problem (exit code 2). +func usageErrorf(format string, args ...any) error { + return ucli.Exit(fmt.Sprintf(format, args...), exitUsage) +} + +// failuref wraps a runtime failure (exit code 1). +func failuref(format string, args ...any) error { + return ucli.Exit(fmt.Sprintf(format, args...), exitFailure) +} + +// ExitCode maps an error returned by the root command to a process exit code. +// Errors carrying an explicit code (via ucli.Exit) use it; everything else — +// including urfave's own flag-parsing and required-flag errors — is treated as +// a usage error. +func ExitCode(err error) int { + if err == nil { + return exitOK + } + var ec ucli.ExitCoder + if errors.As(err, &ec) { + return ec.ExitCode() + } + return exitUsage +} diff --git a/internal/cli/integration_test.go b/internal/cli/integration_test.go new file mode 100644 index 0000000..b43c5b2 --- /dev/null +++ b/internal/cli/integration_test.go @@ -0,0 +1,134 @@ +//go:build integration + +// Integration tests for the CLI command flows that need a live, already-set-up +// issuer. They are excluded from the default build; run them with: +// +// JOTSMITH_CONFIG=/path/to/real/config.json go test -tags integration ./internal/cli/... +// +// The config must point at an issuer previously stood up with `jotsmith setup` +// (reachable storage static website + a signing key the running principal can +// use). Without JOTSMITH_CONFIG the tests skip. +package cli + +import ( + "context" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" +) + +func liveConfig(t *testing.T) *config.Config { + t.Helper() + path := os.Getenv("JOTSMITH_CONFIG") + if path == "" { + t.Skip("set JOTSMITH_CONFIG to a real issuer config to run CLI integration tests") + } + cfg, err := config.Load(path, nil) + if err != nil { + t.Fatalf("loading config %s: %v", path, err) + } + return cfg +} + +// TestIntegrationMintThenVerify mints a short-lived token against the live +// issuer and verifies it end-to-end over HTTPS through the discovery + JWKS +// fetch path. +func TestIntegrationMintThenVerify(t *testing.T) { + ctx := context.Background() + cfg := liveConfig(t) + + signer, err := azurex.NewProvider(ctx, targetFrom(cfg), nil) + if err != nil { + t.Fatalf("NewProvider: %v", err) + } + + mintStream, out, _ := mintStreams() + if err := runMint(ctx, signer, cfg, mintOptions{sub: "integration", exp: "5m"}, mintStream); err != nil { + t.Fatalf("runMint: %v", err) + } + token := strings.TrimSpace(out.String()) + + verifyStream, _, errBuf := mintStreams() + client := &http.Client{Timeout: verifyHTTPTimeout} + if err := runVerify(ctx, client, cfg, token, "", "", verifyStream); err != nil { + t.Fatalf("runVerify failed: %v (stderr: %s)", err, errBuf.String()) + } +} + +// TestIntegrationDoctorRepairsDrift introduces drift by overwriting the +// published discovery document with garbage, then asserts `doctor --repair` +// re-uploads it and a subsequent read-only `doctor` passes. +func TestIntegrationDoctorRepairsDrift(t *testing.T) { + ctx := context.Background() + cfg := liveConfig(t) + + mgr, err := azurex.NewProvider(ctx, targetFrom(cfg), nil) + if err != nil { + t.Fatalf("NewProvider: %v", err) + } + + if err := mgr.UploadBlob(ctx, cfg.DiscoveryPath, "application/json", "no-cache", []byte("not json")); err != nil { + t.Fatalf("introducing drift: %v", err) + } + + repairStream, _, _ := mintStreams() + if err := runDoctor(ctx, mgr, cfg, doctorOptions{repair: true}, false, repairStream); err != nil { + t.Fatalf("doctor --repair should fix drift and exit 0: %v", err) + } + + checkStream, _, _ := mintStreams() + if err := runDoctor(ctx, mgr, cfg, doctorOptions{}, false, checkStream); err != nil { + t.Fatalf("doctor should pass after repair: %v", err) + } +} + +// TestIntegrationRotateInvalidatesPriorToken mints a token, rotates the signing +// key, and asserts the original token no longer verifies. This is destructive — +// it rotates the live signing key — so it requires an explicit opt-in beyond the +// integration tag. +func TestIntegrationRotateInvalidatesPriorToken(t *testing.T) { + if os.Getenv("JOTSMITH_TEST_ALLOW_ROTATE") == "" { + t.Skip("set JOTSMITH_TEST_ALLOW_ROTATE=1 to run the destructive rotation test") + } + ctx := context.Background() + cfg := liveConfig(t) + + signer, err := azurex.NewProvider(ctx, targetFrom(cfg), nil) + if err != nil { + t.Fatalf("NewProvider: %v", err) + } + client := &http.Client{Timeout: verifyHTTPTimeout} + + mintStream, out, _ := mintStreams() + if err := runMint(ctx, signer, cfg, mintOptions{sub: "integration-rotate", exp: "5m"}, mintStream); err != nil { + t.Fatalf("runMint: %v", err) + } + token := strings.TrimSpace(out.String()) + + preStream, _, preErr := mintStreams() + if err := runVerify(ctx, client, cfg, token, "", "", preStream); err != nil { + t.Fatalf("token should verify before rotation: %v (stderr: %s)", err, preErr.String()) + } + + rotateStream, _, _ := mintStreams() + if err := runRotate(ctx, signer, cfg, true, rotateStream); err != nil { + t.Fatalf("runRotate: %v", err) + } + + // Allow the no-cache JWKS upload to become visible, then confirm the prior + // token no longer verifies (its kid is gone from the published JWKS). + var lastErr error + for attempt := 0; attempt < 5; attempt++ { + s, _, _ := mintStreams() + if lastErr = runVerify(ctx, client, cfg, token, "", "", s); lastErr != nil { + return // expected: verification now fails + } + time.Sleep(2 * time.Second) + } + t.Fatalf("token minted under the prior key still verified after rotation (lastErr=%v)", lastErr) +} diff --git a/internal/cli/io.go b/internal/cli/io.go new file mode 100644 index 0000000..5aa83a6 --- /dev/null +++ b/internal/cli/io.go @@ -0,0 +1,24 @@ +package cli + +import ( + "io" + "os" +) + +// IOStreams bundles the three standard streams so commands and tests can inject +// their own. +// +// Out is reserved for primary, machine-consumable command output (e.g. the +// compact JWT from `token mint`, the JSON from `config show`). Everything else +// — logs, prompts, help, the `--verbose` token preview, and error messages — +// goes to Err. This split is what keeps stdout pipe-safe. +type IOStreams struct { + In io.Reader + Out io.Writer + Err io.Writer +} + +// DefaultIOStreams returns the process standard streams. +func DefaultIOStreams() IOStreams { + return IOStreams{In: os.Stdin, Out: os.Stdout, Err: os.Stderr} +} diff --git a/internal/cli/jwks.go b/internal/cli/jwks.go new file mode 100644 index 0000000..87e8263 --- /dev/null +++ b/internal/cli/jwks.go @@ -0,0 +1,66 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + + ucli "github.com/urfave/cli/v3" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/jwk" +) + +func jwksCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "jwks", + Usage: "Show the JWKS", + Commands: []*ucli.Command{ + jwksShowCommand(streams), + }, + } +} + +func jwksShowCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "show", + Usage: "Print the JWKS computed from the current Key Vault key", + Description: "Fetch the current Key Vault public key, build its JWK, and print the JWKS " + + "to stdout exactly as setup / doctor --repair would upload it. The keys array " + + "always has exactly one entry in v1.\n\n" + + "Example:\n" + + " jotsmith jwks show", + Action: func(ctx context.Context, cmd *ucli.Command) error { + cfg, err := loadConfig(ctx, cmd) + if err != nil { + return err + } + insp, err := azurex.NewProvider(ctx, targetFrom(cfg), loggerFrom(ctx)) + if err != nil { + return failuref("%v", err) + } + + set, err := buildJWKS(ctx, insp) + if err != nil { + return keyVaultError(err, cfg.KeyVault) + } + + b, err := json.MarshalIndent(set, "", " ") + if err != nil { + return failuref("rendering JWKS: %v", err) + } + fmt.Fprintln(streams.Out, string(b)) + return nil + }, + } +} + +// buildJWKS reads the current signing key's public material and wraps it in a +// single-entry JWKS. +func buildJWKS(ctx context.Context, insp azurex.Inspector) (jwk.Set, error) { + key, err := insp.GetSigningKey(ctx) + if err != nil { + return jwk.Set{}, err + } + return jwk.NewSet(jwk.FromRSA(key.N, key.E)), nil +} diff --git a/internal/cli/key.go b/internal/cli/key.go new file mode 100644 index 0000000..2c4b7e9 --- /dev/null +++ b/internal/cli/key.go @@ -0,0 +1,132 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + + ucli "github.com/urfave/cli/v3" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/jwk" + "github.com/MaxAnderson95/jotsmith/internal/oidc" +) + +func keyCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "key", + Usage: "Manage the signing key", + Commands: []*ucli.Command{ + keyRotateCommand(streams), + }, + } +} + +func keyRotateCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "rotate", + Usage: "Rotate the signing key (snap-cutover)", + Description: "Create a new version of the Key Vault signing key, recompute its JWK and kid, " + + "and replace the published JWKS with a single-entry array for the new key.\n\n" + + "This is a snap-cutover: the moment rotation completes, every token minted under the " + + "prior key stops verifying. There is no overlap window.\n\n" + + "Prompts for confirmation unless --yes.\n\n" + + "Example:\n" + + " jotsmith key rotate --yes", + Flags: []ucli.Flag{ + &ucli.BoolFlag{Name: "yes", Usage: "skip the confirmation prompt"}, + }, + Action: func(ctx context.Context, cmd *ucli.Command) error { + cfg, err := loadConfig(ctx, cmd) + if err != nil { + return err + } + mgr, err := azurex.NewProvider(ctx, targetFrom(cfg), loggerFrom(ctx)) + if err != nil { + return failuref("%v", err) + } + return runRotate(ctx, mgr, cfg, cmd.Bool("yes"), streams) + }, + } +} + +// runRotate performs the snap-cutover rotation. It is decoupled from the +// concrete provider (via azurex.RotateManager) so it is unit-testable. +// +// Ordering matters for the failure mode: the new key version is created first, +// then the JWKS is replaced. If the JWKS upload fails after the version exists, +// the command exits non-zero and the published JWKS still advertises the old +// kid, which `doctor` reports as drift. +func runRotate(ctx context.Context, mgr azurex.RotateManager, cfg *config.Config, yes bool, streams IOStreams) error { + beforeKid, err := currentKID(ctx, mgr, cfg) + if err != nil { + return err + } + + ok, cerr := confirm(streams, yes, "Rotate the signing key? Every token minted with the current key will stop verifying immediately.") + if cerr != nil { + return cerr + } + if !ok { + fmt.Fprintln(streams.Err, "rotation aborted") + return nil + } + + newKey, err := mgr.CreateRSAKey(ctx, rsaKeyBits) + if err != nil { + return keyVaultError(err, cfg.KeyVault) + } + jwkEntry := jwk.FromRSA(newKey.N, newKey.E) + + jwksDoc, err := json.MarshalIndent(jwk.NewSet(jwkEntry), "", " ") + if err != nil { + return failuref("rendering JWKS: %v", err) + } + if uerr := mgr.UploadBlob(ctx, cfg.JWKSPath, contentTypeJSON, cacheControl, jwksDoc); uerr != nil { + return failuref("new key version was created but the JWKS upload failed (run `jotsmith doctor` to inspect the drift): %v", uerr) + } + + if rerr := refreshDiscoveryIfChanged(ctx, mgr, cfg); rerr != nil { + return rerr + } + + fmt.Fprintln(streams.Err, "key rotated:") + fmt.Fprintf(streams.Err, " before kid: %s\n", beforeKid) + fmt.Fprintf(streams.Err, " after kid: %s\n", jwkEntry.Kid) + return nil +} + +// currentKID returns the thumbprint of the current signing key, or "(none)" if +// no key exists yet (rotation will create the first version). +func currentKID(ctx context.Context, mgr azurex.RotateManager, cfg *config.Config) (string, error) { + key, err := mgr.GetSigningKey(ctx) + if err != nil { + var notFound *azurex.NotFoundError + if errors.As(err, ¬Found) { + return "(none)", nil + } + return "", keyVaultError(err, cfg.KeyVault) + } + return jwk.Thumbprint(key.N, key.E), nil +} + +// refreshDiscoveryIfChanged uploads the discovery document only when its +// rendered bytes differ from what is already published (rotation does not +// change the discovery doc, so normally this is a no-op). +func refreshDiscoveryIfChanged(ctx context.Context, mgr azurex.RotateManager, cfg *config.Config) error { + desired, err := json.MarshalIndent(oidc.Render(cfg.Issuer, cfg.JWKSPath), "", " ") + if err != nil { + return failuref("rendering discovery document: %v", err) + } + current, derr := mgr.GetDiscoveryDocument(ctx) + if derr == nil && bytes.Equal(bytes.TrimSpace(current), bytes.TrimSpace(desired)) { + return nil + } + if uerr := mgr.UploadBlob(ctx, cfg.DiscoveryPath, contentTypeJSON, cacheControl, desired); uerr != nil { + return failuref("refreshing discovery document: %v", uerr) + } + return nil +} diff --git a/internal/cli/key_test.go b/internal/cli/key_test.go new file mode 100644 index 0000000..de216eb --- /dev/null +++ b/internal/cli/key_test.go @@ -0,0 +1,143 @@ +package cli + +import ( + "context" + "encoding/json" + "errors" + "path/filepath" + "strings" + "testing" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/jwk" + "github.com/MaxAnderson95/jotsmith/internal/oidc" +) + +// rotateManager builds a fake with a current ("before") key and a distinct +// "after" key returned by CreateRSAKey, plus a published discovery doc that +// already matches what rotation would render (so no spurious refresh). +func rotateManager(t *testing.T) (mgr *fakeManager, beforeKid, afterKid string) { + t.Helper() + bn, be := rsaMaterial(t) + an, ae := rsaMaterial(t) + disco, err := json.MarshalIndent(oidc.Render(issuer, config.DefaultJWKSPath), "", " ") + if err != nil { + t.Fatal(err) + } + mgr = &fakeManager{ + fakeInspector: &fakeInspector{ + key: azurex.Key{Name: "signing-key", Enabled: true, Ops: []string{"sign", "verify"}, N: bn, E: be}, + disco: disco, + }, + createdKey: azurex.Key{Name: "signing-key", Enabled: true, Ops: []string{"sign", "verify"}, N: an, E: ae}, + } + return mgr, jwk.Thumbprint(bn, be), jwk.Thumbprint(an, ae) +} + +func TestRunRotate_HappyPath(t *testing.T) { + mgr, beforeKid, afterKid := rotateManager(t) + streams, out, errBuf := mintStreams() + + if err := runRotate(context.Background(), mgr, testConfig(), true, streams); err != nil { + t.Fatalf("runRotate: %v", err) + } + if mgr.createCalls != 1 { + t.Errorf("expected one CreateRSAKey call, got %d", mgr.createCalls) + } + if out.Len() != 0 { + t.Errorf("rotate must write nothing to stdout, got:\n%s", out.String()) + } + + // Discovery doc already matched, so only the JWKS should have been uploaded. + if len(mgr.uploads) != 1 { + t.Errorf("expected only the JWKS upload, got %d uploads", len(mgr.uploads)) + } + jwksRaw, ok := mgr.uploads[config.DefaultJWKSPath] + if !ok { + t.Fatal("JWKS was not uploaded") + } + var set jwk.Set + if err := json.Unmarshal(jwksRaw, &set); err != nil { + t.Fatalf("uploaded JWKS is not valid: %v", err) + } + if len(set.Keys) != 1 || set.Keys[0].Kid != afterKid { + t.Errorf("JWKS should be a single entry for the new kid %q, got %#v", afterKid, set.Keys) + } + + if !strings.Contains(errBuf.String(), beforeKid) || !strings.Contains(errBuf.String(), afterKid) { + t.Errorf("stderr should report before/after kids, got:\n%s", errBuf.String()) + } +} + +func TestRunRotate_NonTTYWithoutYesFails(t *testing.T) { + mgr, _, _ := rotateManager(t) + streams, _, _ := mintStreams() // In is a non-TTY strings.Reader + + err := runRotate(context.Background(), mgr, testConfig(), false, streams) + if err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 refusal without --yes on non-TTY, got %v (code %d)", err, ExitCode(err)) + } + if mgr.createCalls != 0 { + t.Error("no key should be created when confirmation is refused") + } + if len(mgr.uploads) != 0 { + t.Error("nothing should be uploaded when confirmation is refused") + } +} + +func TestRunRotate_JWKSUploadFailureLeavesDrift(t *testing.T) { + mgr, _, _ := rotateManager(t) + mgr.uploadErr = errors.New("403 forbidden") + streams, _, _ := mintStreams() + + err := runRotate(context.Background(), mgr, testConfig(), true, streams) + if err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 on JWKS upload failure, got %v (code %d)", err, ExitCode(err)) + } + // The key version was created before the upload, so doctor would see drift. + if mgr.createCalls != 1 { + t.Errorf("expected the new key version to have been created, got %d calls", mgr.createCalls) + } +} + +func TestRunRotate_RefreshesDiscoveryWhenChanged(t *testing.T) { + t.Run("stale discovery is rewritten", func(t *testing.T) { + mgr, _, _ := rotateManager(t) + mgr.disco = []byte(`{"issuer":"stale"}`) + streams, _, _ := mintStreams() + if err := runRotate(context.Background(), mgr, testConfig(), true, streams); err != nil { + t.Fatal(err) + } + if _, ok := mgr.uploads[config.DefaultDiscoveryPath]; !ok { + t.Error("stale discovery doc should have been refreshed") + } + if len(mgr.uploads) != 2 { + t.Errorf("expected JWKS + discovery uploads, got %d", len(mgr.uploads)) + } + }) + + t.Run("missing discovery is written", func(t *testing.T) { + mgr, _, _ := rotateManager(t) + mgr.disco = nil + mgr.discoErr = &azurex.NotFoundError{Kind: "blob", Resource: config.DefaultDiscoveryPath} + streams, _, _ := mintStreams() + if err := runRotate(context.Background(), mgr, testConfig(), true, streams); err != nil { + t.Fatal(err) + } + if _, ok := mgr.uploads[config.DefaultDiscoveryPath]; !ok { + t.Error("missing discovery doc should have been written") + } + }) +} + +func TestKeyRotateMissingConfigIsUsageError(t *testing.T) { + t.Setenv("JOTSMITH_CONFIG", filepath.Join(t.TempDir(), "absent.json")) + _, _, err := run(t, "key", "rotate", "--yes") + if err == nil { + t.Fatal("expected a usage error when config is missing") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } +} diff --git a/internal/cli/logging.go b/internal/cli/logging.go new file mode 100644 index 0000000..347c99f --- /dev/null +++ b/internal/cli/logging.go @@ -0,0 +1,137 @@ +package cli + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "strings" + "sync" +) + +// LevelTrace is the most verbose level, below slog.LevelDebug. slog has no +// native trace level, so jotsmith defines one. +const LevelTrace = slog.Level(-8) + +// parseLogLevel maps a user-facing level string to an slog.Level. An empty +// string defaults to info. +func parseLogLevel(s string) (slog.Level, error) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "error": + return slog.LevelError, nil + case "warn", "warning": + return slog.LevelWarn, nil + case "", "info": + return slog.LevelInfo, nil + case "debug": + return slog.LevelDebug, nil + case "trace": + return LevelTrace, nil + default: + return 0, fmt.Errorf("invalid log level %q (want one of: error, warn, info, debug, trace)", s) + } +} + +// newLogger builds an slog.Logger that writes a compact, optionally colorized +// line per record to w. Color is only emitted when requested and w is a +// terminal, so redirected/captured output stays clean. +func newLogger(w io.Writer, level slog.Level, color bool) *slog.Logger { + return slog.New(&prettyHandler{ + mu: &sync.Mutex{}, + w: w, + level: level, + color: color && isTerminal(w), + }) +} + +type prettyHandler struct { + mu *sync.Mutex + w io.Writer + level slog.Level + color bool + attrs string +} + +func (h *prettyHandler) Enabled(_ context.Context, l slog.Level) bool { + return l >= h.level +} + +func (h *prettyHandler) Handle(_ context.Context, r slog.Record) error { + var b strings.Builder + b.WriteString(h.levelLabel(r.Level)) + b.WriteByte(' ') + b.WriteString(r.Message) + b.WriteString(h.attrs) + r.Attrs(func(a slog.Attr) bool { + writeAttr(&b, a) + return true + }) + b.WriteByte('\n') + + h.mu.Lock() + defer h.mu.Unlock() + _, err := io.WriteString(h.w, b.String()) + return err +} + +func (h *prettyHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return h + } + var b strings.Builder + b.WriteString(h.attrs) + for _, a := range attrs { + writeAttr(&b, a) + } + nh := *h + nh.attrs = b.String() + return &nh +} + +// WithGroup is a no-op: jotsmith's log records do not use attribute groups. +func (h *prettyHandler) WithGroup(string) slog.Handler { return h } + +func writeAttr(b *strings.Builder, a slog.Attr) { + if a.Equal(slog.Attr{}) { + return + } + b.WriteByte(' ') + b.WriteString(a.Key) + b.WriteByte('=') + fmt.Fprint(b, a.Value.Any()) +} + +func (h *prettyHandler) levelLabel(l slog.Level) string { + var name, color string + switch { + case l < slog.LevelDebug: + name, color = "TRACE", "\x1b[90m" + case l < slog.LevelInfo: + name, color = "DEBUG", "\x1b[36m" + case l < slog.LevelWarn: + name, color = "INFO", "\x1b[32m" + case l < slog.LevelError: + name, color = "WARN", "\x1b[33m" + default: + name, color = "ERROR", "\x1b[31m" + } + if h.color { + return color + name + "\x1b[0m" + } + return name +} + +// isTerminal reports whether v (an *os.File stream) is a character device (a +// TTY). Non-file streams — such as the buffers used in tests — are never TTYs. +func isTerminal(v any) bool { + f, ok := v.(*os.File) + if !ok { + return false + } + fi, err := f.Stat() + if err != nil { + return false + } + return fi.Mode()&os.ModeCharDevice != 0 +} diff --git a/internal/cli/prompt.go b/internal/cli/prompt.go new file mode 100644 index 0000000..64e5b28 --- /dev/null +++ b/internal/cli/prompt.go @@ -0,0 +1,28 @@ +package cli + +import ( + "bufio" + "fmt" + "strings" +) + +// confirm asks the user to approve a destructive operation. If yes is true it +// returns true without prompting. Otherwise it requires an interactive terminal +// on stdin: when stdin is not a TTY it returns an error rather than hanging or +// silently proceeding, so non-interactive callers must pass --yes explicitly. +func confirm(streams IOStreams, yes bool, question string) (bool, error) { + if yes { + return true, nil + } + if !isTerminal(streams.In) { + return false, failuref("refusing to prompt because stdin is not a terminal; re-run with --yes to confirm: %s", question) + } + + fmt.Fprintf(streams.Err, "%s [y/N]: ", question) + line, err := bufio.NewReader(streams.In).ReadString('\n') + if err != nil && line == "" { + return false, nil + } + answer := strings.ToLower(strings.TrimSpace(line)) + return answer == "y" || answer == "yes", nil +} diff --git a/internal/cli/root.go b/internal/cli/root.go new file mode 100644 index 0000000..5b07418 --- /dev/null +++ b/internal/cli/root.go @@ -0,0 +1,104 @@ +package cli + +import ( + "context" + "os" + + ucli "github.com/urfave/cli/v3" +) + +// NewRootCommand builds the jotsmith root command, wired with the global flags +// every subcommand inherits and the version/help routing the CLI promises: +// +// - jotsmith --version -> version info on stdout, exit 0 +// - jotsmith / help -> usage on stderr, exit 0 +// +// Help and usage are routed to stderr (root Writer = streams.Err) so stdout +// stays reserved for machine-consumable output. The version printer writes to +// streams.Out explicitly. +// +// The root installs a no-op ExitErrHandler so that Run never calls os.Exit on +// its own; exit-code mapping is centralized in ExitCode and applied by main. +func NewRootCommand(streams IOStreams, build BuildInfo) *ucli.Command { + root := &ucli.Command{ + Name: "jotsmith", + Usage: "Stand up a personal OIDC issuer in Azure and mint JWTs against it", + Version: build.Version, + HideVersion: true, // rendered manually so --version lands on stdout + EnableShellCompletion: true, + ConfigureShellCompletionCommand: configureCompletion(streams), + Reader: streams.In, + Writer: streams.Err, + ErrWriter: streams.Err, + Flags: globalFlags(), + Commands: []*ucli.Command{ + configCommand(streams), + tokenCommand(streams), + doctorCommand(streams), + discoveryCommand(streams), + jwksCommand(streams), + setupCommand(streams), + keyCommand(streams), + destroyCommand(streams), + }, + ExitErrHandler: func(context.Context, *ucli.Command, error) {}, + } + + root.Before = func(ctx context.Context, cmd *ucli.Command) (context.Context, error) { + level, err := parseLogLevel(cmd.String("log-level")) + if err != nil { + return ctx, usageErrorf("%v", err) + } + logger := newLogger(streams.Err, level, !noColor(cmd)) + return withLogger(ctx, logger), nil + } + + root.Action = func(_ context.Context, cmd *ucli.Command) error { + if cmd.Bool("version") { + printVersion(streams.Out, cmd.Name, build) + return nil + } + // No subcommand given: show usage on stderr and exit 0. + _ = ucli.ShowAppHelp(cmd) + return nil + } + + return root +} + +func globalFlags() []ucli.Flag { + return []ucli.Flag{ + &ucli.StringFlag{ + Name: "config", + Usage: "path to the jotsmith config file", + Sources: ucli.EnvVars("JOTSMITH_CONFIG"), + }, + &ucli.StringFlag{ + Name: "log-level", + Usage: "log verbosity: error, warn, info, debug, or trace", + Value: "info", + Sources: ucli.EnvVars("JOTSMITH_LOG_LEVEL"), + }, + &ucli.BoolFlag{ + Name: "no-color", + Usage: "disable ANSI color in stderr output", + }, + &ucli.BoolFlag{ + Name: "version", + Usage: "print version information and exit", + }, + } +} + +// noColor honors both the --no-color flag and the presence (per the NO_COLOR +// convention, any value including empty) of the NO_COLOR environment variable. +// NO_COLOR is read directly rather than bound to the bool flag because the +// convention is presence-based and would otherwise fail to parse values like +// "yes". +func noColor(cmd *ucli.Command) bool { + if cmd.Bool("no-color") { + return true + } + _, present := os.LookupEnv("NO_COLOR") + return present +} diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go new file mode 100644 index 0000000..02bb53c --- /dev/null +++ b/internal/cli/root_test.go @@ -0,0 +1,120 @@ +package cli + +import ( + "bytes" + "context" + "strings" + "testing" +) + +// run builds a fresh root command wired to in-memory buffers and executes it +// with the given args (excluding the program name, which is prepended). +func run(t *testing.T, args ...string) (stdout, stderr string, err error) { + t.Helper() + var out, errBuf bytes.Buffer + streams := IOStreams{In: strings.NewReader(""), Out: &out, Err: &errBuf} + root := NewRootCommand(streams, BuildInfo{Version: "test", Commit: "abc1234", Date: "2026-01-01"}) + err = root.Run(context.Background(), append([]string{"jotsmith"}, args...)) + return out.String(), errBuf.String(), err +} + +func TestVersionGoesToStdout(t *testing.T) { + stdout, stderr, err := run(t, "--version") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(stdout, "jotsmith version test") { + t.Errorf("stdout missing version line:\n%s", stdout) + } + if !strings.Contains(stdout, "commit: abc1234") { + t.Errorf("stdout missing commit:\n%s", stdout) + } + if !strings.Contains(stdout, "built: 2026-01-01") { + t.Errorf("stdout missing build date:\n%s", stdout) + } + if stderr != "" { + t.Errorf("expected empty stderr for --version, got:\n%s", stderr) + } +} + +func TestNoArgsPrintsUsageToStderrExitZero(t *testing.T) { + stdout, stderr, err := run(t) + if err != nil { + t.Fatalf("expected exit 0, got error: %v", err) + } + if stdout != "" { + t.Errorf("expected empty stdout, got:\n%s", stdout) + } + if !strings.Contains(stderr, "USAGE") && !strings.Contains(stderr, "NAME:") { + t.Errorf("expected usage on stderr, got:\n%s", stderr) + } +} + +func TestHelpPrintsUsageToStderrExitZero(t *testing.T) { + stdout, stderr, err := run(t, "help") + if err != nil { + t.Fatalf("expected exit 0, got error: %v", err) + } + if stdout != "" { + t.Errorf("expected empty stdout from help, got:\n%s", stdout) + } + if stderr == "" { + t.Error("expected usage text on stderr from help") + } +} + +func TestUnknownFlagIsUsageError(t *testing.T) { + stdout, _, err := run(t, "--bogus") + if err == nil { + t.Fatal("expected an error for an unknown flag") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit code %d, got %d", exitUsage, got) + } + if stdout != "" { + t.Errorf("expected empty stdout on usage error, got:\n%s", stdout) + } +} + +func TestInvalidLogLevelIsUsageError(t *testing.T) { + _, _, err := run(t, "--log-level", "shout") + if err == nil { + t.Fatal("expected an error for an invalid log level") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit code %d, got %d", exitUsage, got) + } +} + +func TestGlobalFlagsAreRegistered(t *testing.T) { + streams := IOStreams{In: strings.NewReader(""), Out: &bytes.Buffer{}, Err: &bytes.Buffer{}} + root := NewRootCommand(streams, BuildInfo{Version: "test"}) + want := []string{"config", "log-level", "no-color"} + have := map[string]bool{} + for _, f := range root.Flags { + for _, n := range f.Names() { + have[n] = true + } + } + for _, w := range want { + if !have[w] { + t.Errorf("global flag %q not registered", w) + } + } +} + +func TestParseLogLevel(t *testing.T) { + cases := map[string]bool{ + "error": true, "warn": true, "info": true, "debug": true, "trace": true, + "": true, "INFO": true, "bogus": false, + } + for in, ok := range cases { + _, err := parseLogLevel(in) + if ok && err != nil { + t.Errorf("parseLogLevel(%q) unexpected error: %v", in, err) + } + if !ok && err == nil { + t.Errorf("parseLogLevel(%q) expected error, got nil", in) + } + } +} diff --git a/internal/cli/setup.go b/internal/cli/setup.go new file mode 100644 index 0000000..3799f12 --- /dev/null +++ b/internal/cli/setup.go @@ -0,0 +1,234 @@ +package cli + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + + ucli "github.com/urfave/cli/v3" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/jwk" + "github.com/MaxAnderson95/jotsmith/internal/oidc" +) + +const ( + contentTypeJSON = "application/json" + cacheControl = "no-cache" + rsaKeyBits = 2048 +) + +type setupOptions struct { + subscription string + storageAccount string + keyVault string + keyName string + force bool + forceIssuerRewrite bool + yes bool + configPath string +} + +func setupCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "setup", + Usage: "Configure an existing Storage Account + Key Vault as an OIDC issuer", + Description: "Bring an existing Storage Account and Key Vault into the state required to be a " + + "working issuer and write the local config. Never provisions Azure resources.\n\n" + + "Example:\n" + + " jotsmith setup --subscription --storage-account --key-vault ", + Flags: []ucli.Flag{ + &ucli.StringFlag{Name: "subscription", Usage: "Azure subscription ID", Required: true}, + &ucli.StringFlag{Name: "storage-account", Usage: "existing storage account name", Required: true}, + &ucli.StringFlag{Name: "key-vault", Usage: "existing key vault name (RBAC mode)", Required: true}, + &ucli.StringFlag{Name: "key-name", Usage: "signing key name", Value: config.DefaultKeyName}, + &ucli.BoolFlag{Name: "force", Usage: "rotate the key (same effect as `key rotate`)"}, + &ucli.BoolFlag{Name: "force-issuer-rewrite", Usage: "rewrite the config issuer when the storage web endpoint legitimately changed"}, + &ucli.BoolFlag{Name: "yes", Usage: "skip confirmation prompts"}, + }, + Action: func(ctx context.Context, cmd *ucli.Command) error { + opts := setupOptions{ + subscription: cmd.String("subscription"), + storageAccount: cmd.String("storage-account"), + keyVault: cmd.String("key-vault"), + keyName: cmd.String("key-name"), + force: cmd.Bool("force"), + forceIssuerRewrite: cmd.Bool("force-issuer-rewrite"), + yes: cmd.Bool("yes"), + configPath: config.ResolvePath(cmd.String("config")), + } + log := loggerFrom(ctx) + + mgr, err := azurex.NewProvider(ctx, azurex.Target{ + SubscriptionID: opts.subscription, + StorageAccount: opts.storageAccount, + KeyVault: opts.keyVault, + KeyName: opts.keyName, + DiscoveryPath: config.DefaultDiscoveryPath, + JWKSPath: config.DefaultJWKSPath, + }, log) + if err != nil { + return failuref("%v", err) + } + + existing := loadExistingConfig(opts.configPath, log) + return runSetup(ctx, mgr, opts, existing, streams, log) + }, + } +} + +// loadExistingConfig returns the config at path if it loads cleanly, or nil if +// it is absent or unreadable (setup is allowed to run with no prior config). +func loadExistingConfig(path string, log *slog.Logger) *config.Config { + cfg, err := config.Load(path, log) + if err != nil { + return nil + } + return cfg +} + +// runSetup performs the setup workflow against mgr. It is decoupled from the +// concrete Azure provider so the orchestration is unit-testable. +func runSetup(ctx context.Context, mgr azurex.SetupManager, opts setupOptions, existing *config.Config, streams IOStreams, log *slog.Logger) error { + if _, err := mgr.GetSubscription(ctx); err != nil { + return failuref("subscription is not accessible: %v", err) + } + + acct, err := mgr.GetStorageAccount(ctx) + if err != nil { + return failuref("%v", err) + } + // GPv2 (StorageV2) is required for static website hosting; all current + // Azure regions that offer GPv2 support static websites, so the empty + // web-endpoint check below catches any account that cannot host one. + if acct.Kind != "StorageV2" { + return failuref("storage account %q is kind %q; static website hosting requires a general-purpose v2 (StorageV2) account", acct.Name, acct.Kind) + } + if acct.WebEndpoint == "" { + return failuref("storage account %q does not expose a static-website endpoint", acct.Name) + } + issuer := acct.WebEndpoint + + if existing != nil && existing.Issuer != issuer { + if !opts.forceIssuerRewrite { + return failuref("storage web endpoint %q differs from the configured issuer %q; re-run with --force-issuer-rewrite to update it (this invalidates every consumer's trust policy)", issuer, existing.Issuer) + } + ok, cerr := confirm(streams, opts.yes, fmt.Sprintf("Rewrite issuer from %q to %q? Consumers trusting the old issuer will reject all tokens.", existing.Issuer, issuer)) + if cerr != nil { + return cerr + } + if !ok { + fmt.Fprintln(streams.Err, "setup aborted") + return nil + } + } + + vault, err := mgr.GetVault(ctx) + if err != nil { + return failuref("%v", err) + } + if !vault.RBACEnabled { + return failuref("key vault %q is in legacy access-policy mode; jotsmith requires Azure RBAC mode", vault.Name) + } + + if sw, swErr := mgr.GetStaticWebsite(ctx); swErr != nil { + return failuref("%v", swErr) + } else if !sw.Enabled { + log.Info("enabling static website hosting", "storage_account", acct.Name) + if eerr := mgr.EnableStaticWebsite(ctx); eerr != nil { + return failuref("%v", eerr) + } + } + + key, err := resolveSigningKey(ctx, mgr, opts, streams, log) + if err != nil { + return err + } + + jwkEntry := jwk.FromRSA(key.N, key.E) + discoveryDoc, err := json.MarshalIndent(oidc.Render(issuer, config.DefaultJWKSPath), "", " ") + if err != nil { + return failuref("rendering discovery document: %v", err) + } + jwksDoc, err := json.MarshalIndent(jwk.NewSet(jwkEntry), "", " ") + if err != nil { + return failuref("rendering JWKS: %v", err) + } + + if uerr := mgr.UploadBlob(ctx, config.DefaultDiscoveryPath, contentTypeJSON, cacheControl, discoveryDoc); uerr != nil { + return failuref("%v", uerr) + } + if uerr := mgr.UploadBlob(ctx, config.DefaultJWKSPath, contentTypeJSON, cacheControl, jwksDoc); uerr != nil { + return failuref("%v", uerr) + } + + cfg := &config.Config{ + Version: config.SupportedVersion, + SubscriptionID: opts.subscription, + StorageAccount: opts.storageAccount, + KeyVault: opts.keyVault, + KeyName: opts.keyName, + Issuer: issuer, + JWKSPath: config.DefaultJWKSPath, + DiscoveryPath: config.DefaultDiscoveryPath, + } + if werr := config.Write(opts.configPath, cfg, log); werr != nil { + return failuref("writing config: %v", werr) + } + + printSetupSummary(streams, cfg, jwkEntry.Kid) + return nil +} + +// resolveSigningKey keeps an existing enabled key, creates one if absent, or +// rotates (creates a new version) when --force is set. +func resolveSigningKey(ctx context.Context, mgr azurex.SetupManager, opts setupOptions, streams IOStreams, log *slog.Logger) (azurex.Key, error) { + key, err := mgr.GetSigningKey(ctx) + + var notFound *azurex.NotFoundError + switch { + case errors.As(err, ¬Found): + log.Info("creating signing key", "key_name", opts.keyName) + key, err = mgr.CreateRSAKey(ctx, rsaKeyBits) + if err != nil { + return azurex.Key{}, keyVaultError(err, opts.keyVault) + } + return key, nil + + case err != nil: + return azurex.Key{}, keyVaultError(err, opts.keyVault) + + case opts.force: + ok, cerr := confirm(streams, opts.yes, "Rotate the signing key? Tokens minted with the current key will stop verifying immediately.") + if cerr != nil { + return azurex.Key{}, cerr + } + if !ok { + return azurex.Key{}, failuref("setup aborted") + } + log.Info("rotating signing key", "key_name", opts.keyName) + key, err = mgr.CreateRSAKey(ctx, rsaKeyBits) + if err != nil { + return azurex.Key{}, keyVaultError(err, opts.keyVault) + } + return key, nil + + case !key.Enabled: + return azurex.Key{}, failuref("signing key %q exists but is disabled; re-run with --force to create a new version", opts.keyName) + + default: + return key, nil + } +} + +func printSetupSummary(streams IOStreams, cfg *config.Config, kid string) { + fmt.Fprintln(streams.Err, "setup complete:") + fmt.Fprintf(streams.Err, " issuer URL: %s\n", cfg.Issuer) + fmt.Fprintf(streams.Err, " kid: %s\n", kid) + fmt.Fprintf(streams.Err, " discovery: %s\n", oidc.JoinURL(cfg.Issuer, cfg.DiscoveryPath)) + fmt.Fprintf(streams.Err, " JWKS: %s\n", oidc.JoinURL(cfg.Issuer, cfg.JWKSPath)) + fmt.Fprintf(streams.Err, " config: written\n") +} diff --git a/internal/cli/setup_test.go b/internal/cli/setup_test.go new file mode 100644 index 0000000..ef8a1cd --- /dev/null +++ b/internal/cli/setup_test.go @@ -0,0 +1,229 @@ +package cli + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "log/slog" + "math/big" + "path/filepath" + "strings" + "testing" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" +) + +func rsaMaterial(t *testing.T) (n, e []byte) { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generating RSA key: %v", err) + } + return priv.N.Bytes(), big.NewInt(int64(priv.E)).Bytes() +} + +type fakeManager struct { + *fakeInspector + enableErr error + uploads map[string][]byte + uploadErr error + createCalls int + createdKey azurex.Key + createErr error +} + +func (m *fakeManager) EnableStaticWebsite(context.Context) error { + if m.enableErr != nil { + return m.enableErr + } + m.sw.Enabled = true + return nil +} + +func (m *fakeManager) UploadBlob(_ context.Context, path, _, _ string, data []byte) error { + if m.uploadErr != nil { + return m.uploadErr + } + if m.uploads == nil { + m.uploads = map[string][]byte{} + } + m.uploads[path] = data + // Reflect the upload in the inspector's view so a subsequent read (e.g. a + // doctor --repair re-inspect) sees the repaired blob. + switch path { + case config.DefaultDiscoveryPath: + m.disco, m.discoErr = data, nil + case config.DefaultJWKSPath: + m.jwks, m.jwksErr = data, nil + } + return nil +} + +func (m *fakeManager) CreateRSAKey(context.Context, int) (azurex.Key, error) { + if m.createErr != nil { + return azurex.Key{}, m.createErr + } + m.createCalls++ + return m.createdKey, nil +} + +// healthyManager returns a fake whose subscription/account/vault are all good, +// with no signing key yet (so a fresh setup creates one). +func healthyManager(t *testing.T) *fakeManager { + t.Helper() + n, e := rsaMaterial(t) + return &fakeManager{ + fakeInspector: &fakeInspector{ + sub: azurex.Subscription{ID: "sub-123", DisplayName: "Test"}, + acct: azurex.StorageAccount{Name: "jotsmithmax", ResourceGroup: "rg", Kind: "StorageV2", WebEndpoint: issuer, BlobEndpoint: "https://jotsmithmax.blob.core.windows.net/"}, + sw: azurex.StaticWebsite{Enabled: false}, + vault: azurex.Vault{Name: "jotsmith-kv", RBACEnabled: true, URI: "https://jotsmith-kv.vault.azure.net/"}, + keyErr: &azurex.NotFoundError{Kind: "signing key", Resource: "signing-key"}, + }, + createdKey: azurex.Key{Name: "signing-key", Enabled: true, Ops: []string{"sign", "verify"}, N: n, E: e}, + } +} + +func setupOpts(t *testing.T) (setupOptions, IOStreams) { + t.Helper() + return setupOptions{ + subscription: "sub-123", + storageAccount: "jotsmithmax", + keyVault: "jotsmith-kv", + keyName: "signing-key", + configPath: filepath.Join(t.TempDir(), "config.json"), + }, IOStreams{ + In: strings.NewReader(""), + Out: &bytes.Buffer{}, + Err: &bytes.Buffer{}, + } +} + +func discardLogger() *slog.Logger { return slog.New(slog.NewTextHandler(discard{}, nil)) } + +func TestRunSetup_FreshCreatesKeyUploadsAndWritesConfig(t *testing.T) { + mgr := healthyManager(t) + opts, streams := setupOpts(t) + + if err := runSetup(context.Background(), mgr, opts, nil, streams, discardLogger()); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if mgr.createCalls != 1 { + t.Errorf("expected one CreateRSAKey call, got %d", mgr.createCalls) + } + if _, ok := mgr.uploads[config.DefaultDiscoveryPath]; !ok { + t.Error("discovery document was not uploaded") + } + if _, ok := mgr.uploads[config.DefaultJWKSPath]; !ok { + t.Error("JWKS was not uploaded") + } + cfg, err := config.Load(opts.configPath, nil) + if err != nil { + t.Fatalf("config not written/loadable: %v", err) + } + if cfg.Issuer != issuer { + t.Errorf("issuer frozen incorrectly: %q", cfg.Issuer) + } + // stdout must stay empty; the summary goes to stderr. + if out := streams.Out.(*bytes.Buffer).String(); out != "" { + t.Errorf("expected empty stdout, got: %s", out) + } +} + +func TestRunSetup_IdempotentKeepsEnabledKey(t *testing.T) { + mgr := healthyManager(t) + n, e := rsaMaterial(t) + mgr.keyErr = nil + mgr.key = azurex.Key{Name: "signing-key", Enabled: true, Ops: []string{"sign", "verify"}, N: n, E: e} + opts, streams := setupOpts(t) + + if err := runSetup(context.Background(), mgr, opts, nil, streams, discardLogger()); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if mgr.createCalls != 0 { + t.Errorf("expected no key creation on idempotent run, got %d", mgr.createCalls) + } + if len(mgr.uploads) != 2 { + t.Errorf("expected refresh-upload of both docs, got %d uploads", len(mgr.uploads)) + } +} + +func TestRunSetup_LegacyVaultFails(t *testing.T) { + mgr := healthyManager(t) + mgr.vault.RBACEnabled = false + opts, streams := setupOpts(t) + err := runSetup(context.Background(), mgr, opts, nil, streams, discardLogger()) + if err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 failure for legacy vault, got %v (code %d)", err, ExitCode(err)) + } +} + +func TestRunSetup_NonGPv2Fails(t *testing.T) { + mgr := healthyManager(t) + mgr.acct.Kind = "Storage" + opts, streams := setupOpts(t) + if err := runSetup(context.Background(), mgr, opts, nil, streams, discardLogger()); err == nil { + t.Fatal("expected failure for non-GPv2 account") + } +} + +func TestRunSetup_ForceWithoutYesOnNonTTYFails(t *testing.T) { + mgr := healthyManager(t) + n, e := rsaMaterial(t) + mgr.keyErr = nil + mgr.key = azurex.Key{Name: "signing-key", Enabled: true, Ops: []string{"sign", "verify"}, N: n, E: e} + opts, streams := setupOpts(t) + opts.force = true // rotation requires confirmation; stdin is not a TTY and --yes absent + + if err := runSetup(context.Background(), mgr, opts, nil, streams, discardLogger()); err == nil { + t.Fatal("expected refusal to prompt for --force without --yes on non-TTY") + } + if mgr.createCalls != 0 { + t.Error("key must not be rotated when confirmation is refused") + } +} + +func TestRunSetup_ForceWithYesRotates(t *testing.T) { + mgr := healthyManager(t) + n, e := rsaMaterial(t) + mgr.keyErr = nil + mgr.key = azurex.Key{Name: "signing-key", Enabled: true, Ops: []string{"sign", "verify"}, N: n, E: e} + opts, streams := setupOpts(t) + opts.force = true + opts.yes = true + + if err := runSetup(context.Background(), mgr, opts, nil, streams, discardLogger()); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if mgr.createCalls != 1 { + t.Errorf("expected rotation to create a new key version, got %d calls", mgr.createCalls) + } +} + +func TestRunSetup_IssuerRewriteGuard(t *testing.T) { + mgr := healthyManager(t) + opts, streams := setupOpts(t) + existing := &config.Config{ + Version: 1, SubscriptionID: "sub-123", StorageAccount: "jotsmithmax", + KeyVault: "jotsmith-kv", KeyName: "signing-key", + Issuer: "https://old.z99.web.core.windows.net", + JWKSPath: config.DefaultJWKSPath, + DiscoveryPath: config.DefaultDiscoveryPath, + } + + // Without --force-issuer-rewrite: refuse. + if err := runSetup(context.Background(), mgr, opts, existing, streams, discardLogger()); err == nil { + t.Fatal("expected refusal to rewrite issuer without --force-issuer-rewrite") + } + + // With --force-issuer-rewrite + --yes: proceed. + mgr2 := healthyManager(t) + opts2, streams2 := setupOpts(t) + opts2.forceIssuerRewrite = true + opts2.yes = true + if err := runSetup(context.Background(), mgr2, opts2, existing, streams2, discardLogger()); err != nil { + t.Fatalf("expected issuer rewrite to proceed with flag+yes, got %v", err) + } +} diff --git a/internal/cli/token.go b/internal/cli/token.go new file mode 100644 index 0000000..e405ef6 --- /dev/null +++ b/internal/cli/token.go @@ -0,0 +1,371 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + "strings" + "time" + + "github.com/google/uuid" + ucli "github.com/urfave/cli/v3" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/jwk" + "github.com/MaxAnderson95/jotsmith/internal/sign" +) + +// defaultExp is the token lifetime when --exp is omitted. maxExp is the longest +// lifetime allowed without --allow-long-lived (PRD §10 OQ #3: tokens are +// short-lived by default for a test tool). +const ( + defaultExp = "15m" + maxExp = 24 * time.Hour +) + +func tokenCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "token", + Usage: "Mint, verify, and inspect JWTs", + Commands: []*ucli.Command{ + tokenMintCommand(streams), + tokenVerifyCommand(streams), + tokenDecodeCommand(streams), + }, + } +} + +// mintOptions is the resolved flag set for `token mint`. iss is deliberately +// absent: it always comes from config and can never be supplied by the user. +type mintOptions struct { + sub string + aud []string + exp string + iat string + nbf string + jti string + claim []string + claimJSON []string + claimsFile string + allowLongLived bool + verbose bool + + // Whether each standard-claim flag was explicitly set. A custom claim + // source only overrides a standard claim the user did not set; an explicit + // standard flag always wins. + audSet bool + iatSet bool + nbfSet bool + expSet bool + jtiSet bool +} + +func tokenMintCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "mint", + Usage: "Mint a signed JWT against the configured issuer", + Description: "Assemble standard and custom claims, sign the token inside Key Vault, and " + + "write the compact JWT to stdout (and nothing else).\n\n" + + "Custom claims merge lowest-to-highest precedence:\n" + + " --claims-file < --claim-json < --claim\n" + + "Standard-claim flags (--sub, --aud, --exp, --iat, --nbf, --jti) always win over any " + + "custom source. The iss claim always comes from config; there is no --iss flag and a " + + "custom iss is ignored.\n\n" + + "Examples:\n" + + " jotsmith token mint --sub octo-sts --aud sigstore --exp 5m\n" + + " jotsmith token mint --sub ci --claim repo=acme/app --claim-json admin=true", + Flags: []ucli.Flag{ + &ucli.StringFlag{Name: "sub", Usage: "subject (sub) claim", Required: true}, + &ucli.StringSliceFlag{Name: "aud", Usage: "audience (aud); repeat for multiple"}, + &ucli.StringFlag{Name: "exp", Usage: "expiry as a duration (15m, 1h) relative to iat, or an RFC3339 timestamp", Value: defaultExp}, + &ucli.StringFlag{Name: "iat", Usage: "issued-at as an RFC3339 timestamp (default: now)"}, + &ucli.StringFlag{Name: "nbf", Usage: "not-before as an RFC3339 timestamp (default: iat)"}, + &ucli.StringFlag{Name: "jti", Usage: "JWT ID (default: a fresh UUID v4)"}, + &ucli.StringSliceFlag{Name: "claim", Usage: "custom string claim key=value; repeatable"}, + &ucli.StringSliceFlag{Name: "claim-json", Usage: "custom JSON claim key=; repeatable (numbers, booleans, arrays, objects)"}, + &ucli.StringFlag{Name: "claims-file", Usage: "path to a JSON object whose members are merged in as claims"}, + &ucli.BoolFlag{Name: "allow-long-lived", Usage: "permit exp more than 24h in the future"}, + &ucli.BoolFlag{Name: "verbose", Usage: "print the decoded token and metadata to stderr after minting"}, + }, + Action: func(ctx context.Context, cmd *ucli.Command) error { + cfg, err := loadConfig(ctx, cmd) + if err != nil { + return err + } + opts := mintOptions{ + sub: cmd.String("sub"), + aud: cmd.StringSlice("aud"), + exp: cmd.String("exp"), + iat: cmd.String("iat"), + nbf: cmd.String("nbf"), + jti: cmd.String("jti"), + claim: cmd.StringSlice("claim"), + claimJSON: cmd.StringSlice("claim-json"), + claimsFile: cmd.String("claims-file"), + allowLongLived: cmd.Bool("allow-long-lived"), + verbose: cmd.Bool("verbose"), + audSet: cmd.IsSet("aud"), + iatSet: cmd.IsSet("iat"), + nbfSet: cmd.IsSet("nbf"), + expSet: cmd.IsSet("exp"), + jtiSet: cmd.IsSet("jti"), + } + + signer, err := azurex.NewProvider(ctx, targetFrom(cfg), loggerFrom(ctx)) + if err != nil { + return failuref("%v", err) + } + return runMint(ctx, signer, cfg, opts, streams) + }, + } +} + +// runMint resolves the claim times, derives the kid from the current signing +// key, assembles and signs the JWT, and writes it to stdout. It is decoupled +// from the concrete provider (via azurex.Signer) so it is unit-testable. +func runMint(ctx context.Context, signer azurex.Signer, cfg *config.Config, opts mintOptions, streams IOStreams) error { + // Validate all user input (times + custom claims) before touching Azure so + // bad flags fail fast without a network round-trip. + iat, nbf, exp, err := resolveClaimTimes(opts) + if err != nil { + return err + } + claims, err := mergeCustomClaims(opts) + if err != nil { + return err + } + applyStandardClaims(claims, cfg, opts, iat, nbf, exp) + + key, err := signer.GetSigningKey(ctx) + if err != nil { + return keyVaultError(err, cfg.KeyVault) + } + if !key.Enabled { + return failuref("signing key %q is disabled in key vault %q; enable it or rotate the key before minting", cfg.KeyName, cfg.KeyVault) + } + kid := jwk.Thumbprint(key.N, key.E) + + token, err := sign.Mint(ctx, signer, kid, claims) + if err != nil { + return keyVaultError(err, cfg.KeyVault) + } + + // stdout receives exactly the compact JWT plus one newline; nothing else. + fmt.Fprintln(streams.Out, token) + + if opts.verbose { + printMintVerbose(streams, token, kid, iat, nbf, exp) + } + return nil +} + +// mergeCustomClaims builds the custom-claim base map in precedence order +// (lowest to highest): --claims-file < --claim-json < --claim. Standard claims +// are layered on top afterwards by applyStandardClaims. +func mergeCustomClaims(opts mintOptions) (map[string]any, error) { + claims := map[string]any{} + + if opts.claimsFile != "" { + if err := mergeClaimsFile(claims, opts.claimsFile); err != nil { + return nil, err + } + } + for _, kv := range opts.claimJSON { + k, raw, ok := splitClaim(kv) + if !ok { + return nil, usageErrorf("invalid --claim-json %q: expected key=", kv) + } + var v any + if err := json.Unmarshal([]byte(raw), &v); err != nil { + return nil, usageErrorf("invalid JSON for --claim-json %q: %v", k, err) + } + claims[k] = v + } + for _, kv := range opts.claim { + k, val, ok := splitClaim(kv) + if !ok { + return nil, usageErrorf("invalid --claim %q: expected key=value", kv) + } + claims[k] = val + } + return claims, nil +} + +// mergeClaimsFile reads a JSON object from path and merges its members into +// claims. A missing file, non-object root, or invalid JSON is a usage error. +func mergeClaimsFile(claims map[string]any, path string) error { + data, err := os.ReadFile(path) + if err != nil { + return usageErrorf("reading --claims-file %q: %v", path, err) + } + if trimmed := bytes.TrimSpace(data); len(trimmed) == 0 || trimmed[0] != '{' { + return usageErrorf("--claims-file %q must contain a JSON object", path) + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return usageErrorf("invalid JSON in --claims-file %q: %v", path, err) + } + for k, v := range m { + claims[k] = v + } + return nil +} + +// splitClaim splits "key=value" on the first '='. A missing '=' or an empty key +// is invalid. +func splitClaim(s string) (key, value string, ok bool) { + i := strings.IndexByte(s, '=') + if i <= 0 { + return "", "", false + } + return s[:i], s[i+1:], true +} + +// applyStandardClaims layers the standard claims onto the custom-claim base. +// A standard claim flag that was explicitly set always wins; if it was not set, +// the custom value (if any) is kept and otherwise the computed default applies. +// iss always comes from config and overrides any custom value. +func applyStandardClaims(claims map[string]any, cfg *config.Config, opts mintOptions, iat, nbf, exp time.Time) { + // sub is a required flag, so it is always explicitly set and always wins. + claims["sub"] = opts.sub + + setStandardClaim(claims, "iat", opts.iatSet, iat.Unix()) + setStandardClaim(claims, "nbf", opts.nbfSet, nbf.Unix()) + setStandardClaim(claims, "exp", opts.expSet, exp.Unix()) + + if opts.jtiSet { + claims["jti"] = opts.jti + } else if _, ok := claims["jti"]; !ok { + claims["jti"] = uuid.NewString() + } + + if opts.audSet { + setAudience(claims, opts.aud) + } + + // iss is never user-supplied: config wins, even over a custom claim. + claims["iss"] = cfg.Issuer +} + +// setStandardClaim writes val when the flag was explicitly set (standard wins), +// or when no custom source already provided the claim (apply the default). +func setStandardClaim(claims map[string]any, key string, flagSet bool, val any) { + if flagSet { + claims[key] = val + return + } + if _, ok := claims[key]; !ok { + claims[key] = val + } +} + +// setAudience encodes the audience per JWT convention: a single value is a +// string, multiple values are an array. +func setAudience(claims map[string]any, aud []string) { + switch len(aud) { + case 0: + // nothing to set + case 1: + claims["aud"] = aud[0] + default: + claims["aud"] = aud + } +} + +// resolveClaimTimes computes iat, nbf, and exp from the flags, applying the +// defaults (iat=now, nbf=iat, exp=iat+15m) and enforcing exp > iat and the +// 24h long-lived guard. All input errors are usage errors (exit code 2). +func resolveClaimTimes(opts mintOptions) (iat, nbf, exp time.Time, err error) { + now := time.Now() + + iat = now + if opts.iat != "" { + iat, err = time.Parse(time.RFC3339, opts.iat) + if err != nil { + return time.Time{}, time.Time{}, time.Time{}, usageErrorf("invalid --iat %q: must be an RFC3339 timestamp", opts.iat) + } + } + + nbf = iat + if opts.nbf != "" { + nbf, err = time.Parse(time.RFC3339, opts.nbf) + if err != nil { + return time.Time{}, time.Time{}, time.Time{}, usageErrorf("invalid --nbf %q: must be an RFC3339 timestamp", opts.nbf) + } + } + + exp, err = parseExpiry(opts.exp, iat) + if err != nil { + return time.Time{}, time.Time{}, time.Time{}, err + } + if !exp.After(iat) { + return time.Time{}, time.Time{}, time.Time{}, usageErrorf("--exp %q resolves to %s, which is not after iat %s", opts.exp, exp.UTC().Format(time.RFC3339), iat.UTC().Format(time.RFC3339)) + } + if exp.Sub(iat) > maxExp && !opts.allowLongLived { + return time.Time{}, time.Time{}, time.Time{}, usageErrorf("--exp %q exceeds the 24h maximum; pass --allow-long-lived to override", opts.exp) + } + return iat, nbf, exp, nil +} + +// parseExpiry interprets the --exp value as either a Go duration relative to iat +// or an absolute RFC3339 timestamp. +func parseExpiry(raw string, iat time.Time) (time.Time, error) { + if d, derr := time.ParseDuration(raw); derr == nil { + return iat.Add(d), nil + } + if t, terr := time.Parse(time.RFC3339, raw); terr == nil { + return t, nil + } + return time.Time{}, usageErrorf("invalid --exp %q: must be a duration (e.g. 15m, 1h) or an RFC3339 timestamp", raw) +} + +func printMintVerbose(streams IOStreams, token, kid string, iat, nbf, exp time.Time) { + fmt.Fprintln(streams.Err, "minted token:") + if decoded, err := sign.Decode(token); err == nil { + if h, herr := json.MarshalIndent(decoded.Header, " ", " "); herr == nil { + fmt.Fprintf(streams.Err, " header: %s\n", h) + } + if p, perr := json.MarshalIndent(decoded.Payload, " ", " "); perr == nil { + fmt.Fprintf(streams.Err, " payload: %s\n", p) + } + fmt.Fprintf(streams.Err, " signature bytes: %d\n", decoded.SignatureBytes) + } + fmt.Fprintf(streams.Err, " kid: %s\n", kid) + fmt.Fprintf(streams.Err, " iat: %s\n", iat.UTC().Format(time.RFC3339)) + fmt.Fprintf(streams.Err, " nbf: %s\n", nbf.UTC().Format(time.RFC3339)) + fmt.Fprintf(streams.Err, " exp: %s\n", exp.UTC().Format(time.RFC3339)) +} + +func tokenDecodeCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "decode", + Usage: "Decode a JWT's header and payload without verifying it", + ArgsUsage: "", + Description: "Decode (do not verify) a compact JWT and print its header and payload as " + + "JSON to stdout, along with the byte length of the signature. Works on tokens " + + "from any issuer.\n\n" + + "Example:\n" + + " jotsmith token decode eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJtZSJ9.SflKx...", + Action: func(_ context.Context, cmd *ucli.Command) error { + args := cmd.Args() + if args.Len() != 1 { + return usageErrorf("token decode requires exactly one argument: the JWT to decode") + } + + decoded, err := sign.Decode(args.First()) + if err != nil { + return failuref("%v", err) + } + + b, err := json.MarshalIndent(decoded, "", " ") + if err != nil { + return failuref("rendering decoded token: %v", err) + } + fmt.Fprintln(streams.Out, string(b)) + return nil + }, + } +} diff --git a/internal/cli/token_claims_test.go b/internal/cli/token_claims_test.go new file mode 100644 index 0000000..3d1ad61 --- /dev/null +++ b/internal/cli/token_claims_test.go @@ -0,0 +1,201 @@ +package cli + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +// baseMintOpts mirrors the CLI defaults for a minimal valid mint: --sub set, +// --exp defaulted, no standard flag explicitly set. Tests layer custom-claim +// fields and *Set bools on top. +func baseMintOpts() mintOptions { + return mintOptions{sub: "me", exp: defaultExp} +} + +func mintPayload(t *testing.T, opts mintOptions) map[string]any { + t.Helper() + streams, out, _ := mintStreams() + if err := runMint(context.Background(), newFakeSigner(t), testConfig(), opts, streams); err != nil { + t.Fatalf("runMint: %v", err) + } + return decodePayload(t, strings.TrimSpace(out.String())) +} + +func TestMint_ClaimStringValues(t *testing.T) { + opts := baseMintOpts() + opts.claim = []string{"repo=acme/app", "empty="} + p := mintPayload(t, opts) + if p["repo"] != "acme/app" { + t.Errorf("repo claim: got %#v", p["repo"]) + } + if p["empty"] != "" { + t.Errorf("empty claim should be the empty string, got %#v", p["empty"]) + } +} + +func TestMint_ClaimJSONTypedValues(t *testing.T) { + opts := baseMintOpts() + opts.claimJSON = []string{ + `admin=true`, + `count=42`, + `roles=["a","b"]`, + `meta={"k":"v"}`, + } + p := mintPayload(t, opts) + if p["admin"] != true { + t.Errorf("admin should be bool true, got %#v", p["admin"]) + } + if p["count"] != float64(42) { + t.Errorf("count should be number 42, got %#v", p["count"]) + } + roles, ok := p["roles"].([]any) + if !ok || len(roles) != 2 || roles[0] != "a" { + t.Errorf("roles should be [\"a\",\"b\"], got %#v", p["roles"]) + } + meta, ok := p["meta"].(map[string]any) + if !ok || meta["k"] != "v" { + t.Errorf("meta should be {\"k\":\"v\"}, got %#v", p["meta"]) + } +} + +func TestMint_ClaimsFileMergedAndValidated(t *testing.T) { + dir := t.TempDir() + good := filepath.Join(dir, "claims.json") + if err := os.WriteFile(good, []byte(`{"team":"platform","tier":3}`), 0o600); err != nil { + t.Fatal(err) + } + opts := baseMintOpts() + opts.claimsFile = good + p := mintPayload(t, opts) + if p["team"] != "platform" || p["tier"] != float64(3) { + t.Errorf("claims-file not merged: %#v", p) + } + + // non-object root → usage error + arr := filepath.Join(dir, "arr.json") + if err := os.WriteFile(arr, []byte(`[1,2,3]`), 0o600); err != nil { + t.Fatal(err) + } + o := baseMintOpts() + o.claimsFile = arr + if err := mustMintError(t, o); ExitCode(err) != exitUsage { + t.Errorf("non-object claims-file should be a usage error, got code %d", ExitCode(err)) + } + + // invalid JSON → usage error + bad := filepath.Join(dir, "bad.json") + if err := os.WriteFile(bad, []byte(`{not json`), 0o600); err != nil { + t.Fatal(err) + } + o = baseMintOpts() + o.claimsFile = bad + if err := mustMintError(t, o); ExitCode(err) != exitUsage { + t.Errorf("invalid claims-file JSON should be a usage error, got code %d", ExitCode(err)) + } + + // missing file → usage error + o = baseMintOpts() + o.claimsFile = filepath.Join(dir, "nope.json") + if err := mustMintError(t, o); ExitCode(err) != exitUsage { + t.Errorf("missing claims-file should be a usage error, got code %d", ExitCode(err)) + } +} + +func TestMint_MergePrecedence(t *testing.T) { + dir := t.TempDir() + file := filepath.Join(dir, "c.json") + if err := os.WriteFile(file, []byte(`{"k":"from-file","only_file":"f"}`), 0o600); err != nil { + t.Fatal(err) + } + opts := baseMintOpts() + opts.claimsFile = file + opts.claimJSON = []string{`k="from-json"`, `only_json="j"`} + opts.claim = []string{"k=from-claim"} + + p := mintPayload(t, opts) + if p["k"] != "from-claim" { + t.Errorf("--claim should win over --claim-json and --claims-file, got %#v", p["k"]) + } + if p["only_file"] != "f" || p["only_json"] != "j" { + t.Errorf("non-conflicting layers should survive: %#v", p) + } +} + +func TestMint_StandardFlagsBeatCustom(t *testing.T) { + // sub flag wins over a custom sub. + opts := baseMintOpts() + opts.claim = []string{"sub=hacker"} + if p := mintPayload(t, opts); p["sub"] != "me" { + t.Errorf("--sub must win over custom sub, got %#v", p["sub"]) + } + + // explicit --exp wins over a custom exp; absent --exp lets custom through. + opts = baseMintOpts() + opts.iat = "2020-01-01T00:00:00Z" + opts.iatSet = true + opts.exp = "1h" + opts.expSet = true + opts.claimJSON = []string{"exp=999999"} + if p := mintPayload(t, opts); p["exp"] != float64(1577840400) { + t.Errorf("explicit --exp must win over custom exp, got %#v", p["exp"]) + } +} + +func TestMint_CustomAcceptedWhenStandardFlagAbsent(t *testing.T) { + // exp not passed as a flag → a custom exp is honored. + opts := baseMintOpts() + opts.claimJSON = []string{"exp=999999"} + if p := mintPayload(t, opts); p["exp"] != float64(999999) { + t.Errorf("custom exp should be accepted when --exp absent, got %#v", p["exp"]) + } + + // aud not passed as a flag → a custom aud is honored. + opts = baseMintOpts() + opts.claim = []string{"aud=from-custom"} + if p := mintPayload(t, opts); p["aud"] != "from-custom" { + t.Errorf("custom aud should be accepted when --aud absent, got %#v", p["aud"]) + } +} + +func TestMint_IssAlwaysFromConfig(t *testing.T) { + opts := baseMintOpts() + opts.claim = []string{"iss=https://evil.example"} + opts.claimJSON = []string{`iss="https://also-evil.example"`} + if p := mintPayload(t, opts); p["iss"] != issuer { + t.Errorf("iss must always come from config, got %#v", p["iss"]) + } +} + +func TestMint_InvalidClaimSyntaxIsUsageError(t *testing.T) { + o := baseMintOpts() + o.claim = []string{"no-equals-sign"} + if err := mustMintError(t, o); ExitCode(err) != exitUsage { + t.Errorf("--claim without '=' should be a usage error, got code %d", ExitCode(err)) + } + + o = baseMintOpts() + o.claim = []string{"=novalue"} + if err := mustMintError(t, o); ExitCode(err) != exitUsage { + t.Errorf("--claim with empty key should be a usage error, got code %d", ExitCode(err)) + } + + o = baseMintOpts() + o.claimJSON = []string{"k=not valid json"} + if err := mustMintError(t, o); ExitCode(err) != exitUsage { + t.Errorf("--claim-json with invalid JSON should be a usage error, got code %d", ExitCode(err)) + } +} + +// mustMintError runs mint expecting a non-nil error and returns it. +func mustMintError(t *testing.T, opts mintOptions) error { + t.Helper() + streams, _, _ := mintStreams() + err := runMint(context.Background(), newFakeSigner(t), testConfig(), opts, streams) + if err == nil { + t.Fatal("expected an error, got nil") + } + return err +} diff --git a/internal/cli/token_mint_test.go b/internal/cli/token_mint_test.go new file mode 100644 index 0000000..e2ed837 --- /dev/null +++ b/internal/cli/token_mint_test.go @@ -0,0 +1,277 @@ +package cli + +import ( + "bytes" + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "math/big" + "strings" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/sign" +) + +// fakeSigner satisfies azurex.Signer with a real RSA key so minted tokens carry +// verifiable signatures. +type fakeSigner struct { + priv *rsa.PrivateKey + key azurex.Key + keyErr error + signErr error +} + +func (s *fakeSigner) GetSigningKey(context.Context) (azurex.Key, error) { + return s.key, s.keyErr +} + +func (s *fakeSigner) Sign(_ context.Context, digest []byte) ([]byte, error) { + if s.signErr != nil { + return nil, s.signErr + } + return rsa.SignPKCS1v15(rand.Reader, s.priv, crypto.SHA256, digest) +} + +func newFakeSigner(t *testing.T) *fakeSigner { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generating key: %v", err) + } + return &fakeSigner{ + priv: priv, + key: azurex.Key{ + Name: "signing-key", Enabled: true, Ops: []string{"sign", "verify"}, + N: priv.N.Bytes(), E: big.NewInt(int64(priv.E)).Bytes(), + }, + } +} + +func mintStreams() (IOStreams, *bytes.Buffer, *bytes.Buffer) { + var out, errBuf bytes.Buffer + return IOStreams{In: strings.NewReader(""), Out: &out, Err: &errBuf}, &out, &errBuf +} + +// decodePayload returns the minted token's payload as a generic map. +func decodePayload(t *testing.T, token string) map[string]any { + t.Helper() + decoded, err := sign.Decode(token) + if err != nil { + t.Fatalf("decoding minted token: %v", err) + } + var m map[string]any + if err := json.Unmarshal(decoded.Payload, &m); err != nil { + t.Fatalf("unmarshaling payload: %v", err) + } + return m +} + +func TestRunMint_StdoutIsExactlyOneJWT(t *testing.T) { + signer := newFakeSigner(t) + streams, out, errBuf := mintStreams() + + if err := runMint(context.Background(), signer, testConfig(), mintOptions{sub: "me", exp: defaultExp}, streams); err != nil { + t.Fatalf("runMint: %v", err) + } + + stdout := out.String() + if !strings.HasSuffix(stdout, "\n") { + t.Fatal("stdout must end with a newline") + } + trimmed := strings.TrimSuffix(stdout, "\n") + if strings.Contains(trimmed, "\n") { + t.Errorf("stdout has more than one line:\n%q", stdout) + } + if _, err := sign.Decode(trimmed); err != nil { + t.Errorf("stdout is not a valid compact JWT: %v", err) + } + if errBuf.Len() != 0 { + t.Errorf("expected empty stderr without --verbose, got:\n%s", errBuf.String()) + } +} + +func TestRunMint_SignatureVerifiesAndIssFromConfig(t *testing.T) { + signer := newFakeSigner(t) + streams, out, _ := mintStreams() + + if err := runMint(context.Background(), signer, testConfig(), mintOptions{sub: "me", exp: "5m"}, streams); err != nil { + t.Fatalf("runMint: %v", err) + } + token := strings.TrimSpace(out.String()) + + parts := strings.Split(token, ".") + signingInput := parts[0] + "." + parts[1] + digest := sha256.Sum256([]byte(signingInput)) + sig, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + t.Fatalf("decoding signature: %v", err) + } + if err := rsa.VerifyPKCS1v15(&signer.priv.PublicKey, crypto.SHA256, digest[:], sig); err != nil { + t.Errorf("minted signature does not verify: %v", err) + } + + payload := decodePayload(t, token) + if payload["iss"] != issuer { + t.Errorf("iss not taken from config: got %v want %s", payload["iss"], issuer) + } +} + +func TestRunMint_AudienceCardinality(t *testing.T) { + t.Run("absent omits aud", func(t *testing.T) { + streams, out, _ := mintStreams() + if err := runMint(context.Background(), newFakeSigner(t), testConfig(), mintOptions{sub: "me", exp: defaultExp}, streams); err != nil { + t.Fatal(err) + } + if _, ok := decodePayload(t, strings.TrimSpace(out.String()))["aud"]; ok { + t.Error("aud should be omitted when no --aud given") + } + }) + t.Run("single is a string", func(t *testing.T) { + streams, out, _ := mintStreams() + if err := runMint(context.Background(), newFakeSigner(t), testConfig(), mintOptions{sub: "me", exp: defaultExp, aud: []string{"one"}, audSet: true}, streams); err != nil { + t.Fatal(err) + } + if got := decodePayload(t, strings.TrimSpace(out.String()))["aud"]; got != "one" { + t.Errorf("aud should be the string \"one\", got %#v", got) + } + }) + t.Run("multiple is an array", func(t *testing.T) { + streams, out, _ := mintStreams() + if err := runMint(context.Background(), newFakeSigner(t), testConfig(), mintOptions{sub: "me", exp: defaultExp, aud: []string{"a", "b"}, audSet: true}, streams); err != nil { + t.Fatal(err) + } + got, ok := decodePayload(t, strings.TrimSpace(out.String()))["aud"].([]any) + if !ok || len(got) != 2 || got[0] != "a" || got[1] != "b" { + t.Errorf("aud should be [\"a\",\"b\"], got %#v", decodePayload(t, strings.TrimSpace(out.String()))["aud"]) + } + }) +} + +func TestRunMint_JTIDefaultsToUUID(t *testing.T) { + streams, out, _ := mintStreams() + if err := runMint(context.Background(), newFakeSigner(t), testConfig(), mintOptions{sub: "me", exp: defaultExp}, streams); err != nil { + t.Fatal(err) + } + jti, _ := decodePayload(t, strings.TrimSpace(out.String()))["jti"].(string) + if _, err := uuid.Parse(jti); err != nil { + t.Errorf("default jti %q is not a UUID: %v", jti, err) + } + + streams, out, _ = mintStreams() + if err := runMint(context.Background(), newFakeSigner(t), testConfig(), mintOptions{sub: "me", exp: defaultExp, jti: "fixed-id", jtiSet: true}, streams); err != nil { + t.Fatal(err) + } + if got := decodePayload(t, strings.TrimSpace(out.String()))["jti"]; got != "fixed-id" { + t.Errorf("explicit --jti not honored: got %v", got) + } +} + +func TestRunMint_VerboseWritesToStderrOnly(t *testing.T) { + streams, out, errBuf := mintStreams() + if err := runMint(context.Background(), newFakeSigner(t), testConfig(), mintOptions{sub: "me", exp: defaultExp, verbose: true}, streams); err != nil { + t.Fatal(err) + } + // stdout is still exactly one JWT line. + if strings.Count(strings.TrimSuffix(out.String(), "\n"), "\n") != 0 { + t.Errorf("stdout polluted by verbose output:\n%s", out.String()) + } + if !strings.Contains(errBuf.String(), "minted token:") { + t.Errorf("verbose preview missing from stderr:\n%s", errBuf.String()) + } +} + +func TestRunMint_DisabledKeyFails(t *testing.T) { + signer := newFakeSigner(t) + signer.key.Enabled = false + streams, _, _ := mintStreams() + err := runMint(context.Background(), signer, testConfig(), mintOptions{sub: "me", exp: defaultExp}, streams) + if err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 failure for disabled key, got %v (code %d)", err, ExitCode(err)) + } +} + +func TestRunMint_KeyVaultErrorsAreFailures(t *testing.T) { + t.Run("get key error", func(t *testing.T) { + signer := newFakeSigner(t) + signer.keyErr = errors.New("forbidden") + streams, _, _ := mintStreams() + if err := runMint(context.Background(), signer, testConfig(), mintOptions{sub: "me", exp: defaultExp}, streams); err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 failure, got %v (code %d)", err, ExitCode(err)) + } + }) + t.Run("sign error", func(t *testing.T) { + signer := newFakeSigner(t) + signer.signErr = errors.New("crypto officer role missing") + streams, _, _ := mintStreams() + if err := runMint(context.Background(), signer, testConfig(), mintOptions{sub: "me", exp: defaultExp}, streams); err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 failure, got %v (code %d)", err, ExitCode(err)) + } + }) +} + +func TestRunMint_LongLivedGuard(t *testing.T) { + streams, _, _ := mintStreams() + err := runMint(context.Background(), newFakeSigner(t), testConfig(), mintOptions{sub: "me", exp: "48h"}, streams) + if err == nil || ExitCode(err) != exitUsage { + t.Fatalf("expected usage error for exp>24h, got %v (code %d)", err, ExitCode(err)) + } + + streams, out, _ := mintStreams() + if err := runMint(context.Background(), newFakeSigner(t), testConfig(), mintOptions{sub: "me", exp: "48h", allowLongLived: true}, streams); err != nil { + t.Fatalf("expected --allow-long-lived to permit 48h, got %v", err) + } + if out.Len() == 0 { + t.Error("expected a token on stdout with --allow-long-lived") + } +} + +func TestResolveClaimTimes(t *testing.T) { + t.Run("defaults", func(t *testing.T) { + iat, nbf, exp, err := resolveClaimTimes(mintOptions{exp: defaultExp}) + if err != nil { + t.Fatal(err) + } + if !nbf.Equal(iat) { + t.Error("nbf should default to iat") + } + if d := exp.Sub(iat); d != 15*time.Minute { + t.Errorf("default exp should be iat+15m, got +%s", d) + } + }) + t.Run("absolute RFC3339 exp", func(t *testing.T) { + _, _, exp, err := resolveClaimTimes(mintOptions{ + iat: "2026-01-01T00:00:00Z", + exp: "2026-01-01T01:00:00Z", + }) + if err != nil { + t.Fatal(err) + } + if want := time.Date(2026, 1, 1, 1, 0, 0, 0, time.UTC); !exp.Equal(want) { + t.Errorf("absolute exp: got %s want %s", exp, want) + } + }) + t.Run("exp not after iat", func(t *testing.T) { + if _, _, _, err := resolveClaimTimes(mintOptions{exp: "0s"}); err == nil || ExitCode(err) != exitUsage { + t.Fatalf("expected usage error, got %v", err) + } + }) + t.Run("bad iat", func(t *testing.T) { + if _, _, _, err := resolveClaimTimes(mintOptions{iat: "nope", exp: defaultExp}); err == nil || ExitCode(err) != exitUsage { + t.Fatalf("expected usage error for bad iat, got %v", err) + } + }) + t.Run("bad exp", func(t *testing.T) { + if _, _, _, err := resolveClaimTimes(mintOptions{exp: "banana"}); err == nil || ExitCode(err) != exitUsage { + t.Fatalf("expected usage error for bad exp, got %v", err) + } + }) +} diff --git a/internal/cli/token_test.go b/internal/cli/token_test.go new file mode 100644 index 0000000..c7f41e1 --- /dev/null +++ b/internal/cli/token_test.go @@ -0,0 +1,100 @@ +package cli + +import ( + "encoding/base64" + "encoding/json" + "strings" + "testing" +) + +func b64url(s string) string { + return base64.RawURLEncoding.EncodeToString([]byte(s)) +} + +func TestTokenDecodePrintsJSONToStdout(t *testing.T) { + tok := b64url(`{"alg":"RS256","typ":"JWT","kid":"k1"}`) + "." + + b64url(`{"sub":"me","aud":["a","b"]}`) + "." + + base64.RawURLEncoding.EncodeToString([]byte("0123456789")) + + stdout, stderr, err := run(t, "token", "decode", tok) + if err != nil { + t.Fatalf("unexpected error: %v (stderr: %s)", err, stderr) + } + + var out struct { + Header map[string]any `json:"header"` + Payload map[string]any `json:"payload"` + SignatureBytes int `json:"signature_bytes"` + } + if err := json.Unmarshal([]byte(stdout), &out); err != nil { + t.Fatalf("stdout is not the expected JSON object: %v\n%s", err, stdout) + } + if out.Header["kid"] != "k1" { + t.Errorf("header kid = %v, want k1", out.Header["kid"]) + } + if out.SignatureBytes != 10 { + t.Errorf("signature_bytes = %d, want 10", out.SignatureBytes) + } +} + +func TestTokenDecodeMalformedExitsOne(t *testing.T) { + stdout, _, err := run(t, "token", "decode", "only.two") + if err == nil { + t.Fatal("expected an error for a malformed token") + } + if got := ExitCode(err); got != exitFailure { + t.Errorf("expected exit %d, got %d", exitFailure, got) + } + if stdout != "" { + t.Errorf("expected empty stdout on failure, got:\n%s", stdout) + } +} + +func TestTokenDecodeMissingArgIsUsageError(t *testing.T) { + _, _, err := run(t, "token", "decode") + if err == nil { + t.Fatal("expected a usage error when no token is given") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } +} + +func TestTokenMintRequiresSub(t *testing.T) { + stdout, _, err := run(t, "token", "mint") + if err == nil { + t.Fatal("expected a usage error when --sub is missing") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } + if stdout != "" { + t.Errorf("expected empty stdout on usage error, got:\n%s", stdout) + } +} + +func TestTokenMintRejectsIssFlag(t *testing.T) { + // --iss must never exist; supplying it is an unknown-flag usage error. + stdout, _, err := run(t, "token", "mint", "--sub", "me", "--iss", "https://evil.example") + if err == nil { + t.Fatal("expected an unknown-flag error for --iss") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } + if stdout != "" { + t.Errorf("expected empty stdout on usage error, got:\n%s", stdout) + } +} + +func TestTokenDecodeStdoutIsOnlyJSON(t *testing.T) { + tok := b64url(`{"alg":"none"}`) + "." + b64url(`{"sub":"x"}`) + "." + stdout, _, err := run(t, "token", "decode", tok) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + trimmed := strings.TrimRight(stdout, "\n") + if !strings.HasPrefix(trimmed, "{") || !strings.HasSuffix(trimmed, "}") { + t.Errorf("stdout should be exactly one JSON object, got:\n%s", stdout) + } +} diff --git a/internal/cli/verify.go b/internal/cli/verify.go new file mode 100644 index 0000000..13b17d4 --- /dev/null +++ b/internal/cli/verify.go @@ -0,0 +1,120 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + ucli "github.com/urfave/cli/v3" + + "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/jwk" + "github.com/MaxAnderson95/jotsmith/internal/oidc" + "github.com/MaxAnderson95/jotsmith/internal/sign" +) + +// verifyHTTPTimeout bounds each discovery/JWKS fetch so a hung issuer endpoint +// fails with a clear error rather than blocking forever. +const verifyHTTPTimeout = 10 * time.Second + +func tokenVerifyCommand(streams IOStreams) *ucli.Command { + return &ucli.Command{ + Name: "verify", + Usage: "Verify a JWT against the configured issuer over HTTPS", + ArgsUsage: "", + Description: "Fetch the issuer's discovery document and JWKS over HTTPS, verify the token's " + + "RS256 signature against the published key, and check its standard claims (iss, exp, " + + "nbf, iat) with ±60s clock skew. Prints OK and the decoded claims to stderr; nothing " + + "is written to stdout.\n\n" + + "Example:\n" + + " jotsmith token verify \"$(jotsmith token mint --sub me)\" --aud sigstore", + Flags: []ucli.Flag{ + &ucli.StringFlag{Name: "aud", Usage: "require this audience to be present in the token"}, + &ucli.StringFlag{Name: "sub", Usage: "require this exact subject"}, + }, + Action: func(ctx context.Context, cmd *ucli.Command) error { + args := cmd.Args() + if args.Len() != 1 { + return usageErrorf("token verify requires exactly one argument: the JWT to verify") + } + cfg, err := loadConfig(ctx, cmd) + if err != nil { + return err + } + client := &http.Client{Timeout: verifyHTTPTimeout} + return runVerify(ctx, client, cfg, args.First(), cmd.String("aud"), cmd.String("sub"), streams) + }, + } +} + +// runVerify fetches the issuer's JWKS, verifies the token, and reports the +// result to stderr. Verification and fetch failures are runtime failures (exit +// code 1); nothing is ever written to stdout. +func runVerify(ctx context.Context, client *http.Client, cfg *config.Config, token, aud, sub string, streams IOStreams) error { + set, err := fetchJWKS(ctx, client, cfg) + if err != nil { + return failuref("%v", err) + } + + decoded, err := sign.Verify(token, sign.VerifyOptions{ + Issuer: cfg.Issuer, + Keys: set, + ExpectedAud: aud, + ExpectedSub: sub, + Now: time.Now(), + Skew: sign.DefaultSkew, + }) + if err != nil { + return failuref("verification failed: %v", err) + } + + fmt.Fprintln(streams.Err, "OK") + if b, e := json.MarshalIndent(decoded, "", " "); e == nil { + fmt.Fprintln(streams.Err, string(b)) + } + return nil +} + +// fetchJWKS fetches the discovery document from the configured issuer, validates +// that its issuer field equals the configured issuer, then fetches and parses +// the JWKS the document points to. +func fetchJWKS(ctx context.Context, client *http.Client, cfg *config.Config) (jwk.Set, error) { + discoURL := oidc.JoinURL(cfg.Issuer, cfg.DiscoveryPath) + var disco oidc.Discovery + if err := getJSON(ctx, client, discoURL, &disco); err != nil { + return jwk.Set{}, fmt.Errorf("fetching discovery document: %w", err) + } + if disco.Issuer != cfg.Issuer { + return jwk.Set{}, fmt.Errorf("discovery issuer %q does not match configured issuer %q", disco.Issuer, cfg.Issuer) + } + if disco.JWKSURI == "" { + return jwk.Set{}, fmt.Errorf("discovery document has no jwks_uri") + } + + var set jwk.Set + if err := getJSON(ctx, client, disco.JWKSURI, &set); err != nil { + return jwk.Set{}, fmt.Errorf("fetching JWKS: %w", err) + } + return set, nil +} + +func getJSON(ctx context.Context, client *http.Client, url string, dst any) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("building request for %s: %w", url, err) + } + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("GET %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("GET %s: unexpected status %s", url, resp.Status) + } + if err := json.NewDecoder(resp.Body).Decode(dst); err != nil { + return fmt.Errorf("decoding %s: %w", url, err) + } + return nil +} diff --git a/internal/cli/verify_test.go b/internal/cli/verify_test.go new file mode 100644 index 0000000..ed83105 --- /dev/null +++ b/internal/cli/verify_test.go @@ -0,0 +1,142 @@ +package cli + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/jwk" + "github.com/MaxAnderson95/jotsmith/internal/oidc" +) + +// issuerServer stands up an httptest server that publishes a discovery document +// and JWKS for the given key, with its issuer set to the server's own URL. It +// returns the server and a config pointing at it. +func issuerServer(t *testing.T, signer *fakeSigner) (*httptest.Server, *config.Config) { + t.Helper() + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + set := jwk.NewSet(jwk.FromRSA(signer.key.N, signer.key.E)) + disco := oidc.Render(srv.URL, config.DefaultJWKSPath) + + mux.HandleFunc("/"+config.DefaultDiscoveryPath, func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(disco) + }) + mux.HandleFunc("/"+config.DefaultJWKSPath, func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(set) + }) + + cfg := &config.Config{ + Version: 1, SubscriptionID: "sub-123", StorageAccount: "jotsmithmax", + KeyVault: "jotsmith-kv", KeyName: "signing-key", + Issuer: srv.URL, + JWKSPath: config.DefaultJWKSPath, + DiscoveryPath: config.DefaultDiscoveryPath, + } + return srv, cfg +} + +// mintFor mints a token through runMint against cfg using signer. +func mintFor(t *testing.T, signer *fakeSigner, cfg *config.Config, opts mintOptions) string { + t.Helper() + streams, out, _ := mintStreams() + if err := runMint(context.Background(), signer, cfg, opts, streams); err != nil { + t.Fatalf("runMint: %v", err) + } + return strings.TrimSpace(out.String()) +} + +func TestRunVerify_HappyPathToStderrOnly(t *testing.T) { + signer := newFakeSigner(t) + _, cfg := issuerServer(t, signer) + token := mintFor(t, signer, cfg, mintOptions{sub: "me", exp: defaultExp}) + + streams, out, errBuf := mintStreams() + client := &http.Client{Timeout: verifyHTTPTimeout} + if err := runVerify(context.Background(), client, cfg, token, "", "", streams); err != nil { + t.Fatalf("runVerify: %v", err) + } + if out.Len() != 0 { + t.Errorf("verify must write nothing to stdout, got:\n%s", out.String()) + } + if !strings.Contains(errBuf.String(), "OK") { + t.Errorf("expected OK on stderr, got:\n%s", errBuf.String()) + } +} + +func TestRunVerify_AudAndSubChecks(t *testing.T) { + signer := newFakeSigner(t) + _, cfg := issuerServer(t, signer) + token := mintFor(t, signer, cfg, mintOptions{sub: "ci", exp: defaultExp, aud: []string{"sigstore"}, audSet: true}) + + client := &http.Client{Timeout: verifyHTTPTimeout} + + streams, _, _ := mintStreams() + if err := runVerify(context.Background(), client, cfg, token, "sigstore", "ci", streams); err != nil { + t.Fatalf("expected match, got %v", err) + } + + streams, _, _ = mintStreams() + if err := runVerify(context.Background(), client, cfg, token, "wrong-aud", "", streams); err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 for wrong aud, got %v (code %d)", err, ExitCode(err)) + } +} + +func TestRunVerify_DiscoveryIssuerMismatch(t *testing.T) { + signer := newFakeSigner(t) + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + // Discovery advertises a different issuer than the config points at. + disco := oidc.Render("https://imposter.example", config.DefaultJWKSPath) + mux.HandleFunc("/"+config.DefaultDiscoveryPath, func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(disco) + }) + + cfg := &config.Config{ + Version: 1, SubscriptionID: "s", StorageAccount: "a", KeyVault: "k", KeyName: "signing-key", + Issuer: srv.URL, JWKSPath: config.DefaultJWKSPath, DiscoveryPath: config.DefaultDiscoveryPath, + } + token := mintFor(t, signer, &config.Config{ + Version: 1, SubscriptionID: "s", StorageAccount: "a", KeyVault: "k", KeyName: "signing-key", + Issuer: srv.URL, JWKSPath: config.DefaultJWKSPath, DiscoveryPath: config.DefaultDiscoveryPath, + }, mintOptions{sub: "me", exp: defaultExp}) + + client := &http.Client{Timeout: verifyHTTPTimeout} + streams, _, _ := mintStreams() + if err := runVerify(context.Background(), client, cfg, token, "", "", streams); err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 for issuer mismatch, got %v (code %d)", err, ExitCode(err)) + } +} + +func TestRunVerify_FetchErrorIsFailure(t *testing.T) { + cfg := &config.Config{ + Version: 1, SubscriptionID: "s", StorageAccount: "a", KeyVault: "k", KeyName: "signing-key", + Issuer: "http://127.0.0.1:0", // unroutable: connection refused + JWKSPath: config.DefaultJWKSPath, + DiscoveryPath: config.DefaultDiscoveryPath, + } + client := &http.Client{Timeout: 500 * time.Millisecond} + streams, _, _ := mintStreams() + if err := runVerify(context.Background(), client, cfg, "a.b.c", "", "", streams); err == nil || ExitCode(err) != exitFailure { + t.Fatalf("expected exit-1 on fetch error, got %v (code %d)", err, ExitCode(err)) + } +} + +func TestTokenVerifyMissingArgIsUsageError(t *testing.T) { + _, _, err := run(t, "token", "verify") + if err == nil { + t.Fatal("expected a usage error when no token is given") + } + if got := ExitCode(err); got != exitUsage { + t.Errorf("expected exit %d, got %d", exitUsage, got) + } +} diff --git a/internal/cli/version.go b/internal/cli/version.go new file mode 100644 index 0000000..3efbe22 --- /dev/null +++ b/internal/cli/version.go @@ -0,0 +1,20 @@ +package cli + +import ( + "fmt" + "io" +) + +// BuildInfo carries version metadata injected at build time via -ldflags. +type BuildInfo struct { + Version string + Commit string + Date string +} + +// printVersion writes human-readable version information to w (stdout). +func printVersion(w io.Writer, name string, b BuildInfo) { + fmt.Fprintf(w, "%s version %s\n", name, b.Version) + fmt.Fprintf(w, "commit: %s\n", b.Commit) + fmt.Fprintf(w, "built: %s\n", b.Date) +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..79217c9 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,172 @@ +package config + +import ( + "encoding/json" + "errors" + "fmt" + "io/fs" + "log/slog" + "net/url" + "os" + "path/filepath" + "strings" +) + +// Default blob paths inside the $web container. Configurable in the file but +// defaulted so users never have to set them (PRD §5.3). +const ( + DefaultJWKSPath = ".well-known/jwks.json" + DefaultDiscoveryPath = ".well-known/openid-configuration" + DefaultKeyName = "signing-key" + + // SupportedVersion is the only config schema version v1 accepts. + SupportedVersion = 1 +) + +// Config represents exactly one issuer (ADR-0004). The field order matches the +// PRD §5.3 schema so marshaled output reads naturally. +type Config struct { + Version int `json:"version"` + SubscriptionID string `json:"subscription_id"` + StorageAccount string `json:"storage_account"` + KeyVault string `json:"key_vault"` + KeyName string `json:"key_name"` + Issuer string `json:"issuer"` + JWKSPath string `json:"jwks_path"` + DiscoveryPath string `json:"discovery_path"` +} + +// DefaultPath returns the config path under XDG_CONFIG_HOME (or $HOME/.config) +// per PRD §5.3. +func DefaultPath() string { + base := os.Getenv("XDG_CONFIG_HOME") + if base == "" { + home, err := os.UserHomeDir() + if err != nil { + home = "" + } + base = filepath.Join(home, ".config") + } + return filepath.Join(base, "jotsmith", "config.json") +} + +// ResolvePath implements the path-resolution precedence (highest wins): +// explicit (the --config flag) > JOTSMITH_CONFIG env > the XDG/HOME default. +func ResolvePath(explicit string) string { + if explicit != "" { + return explicit + } + if env := os.Getenv("JOTSMITH_CONFIG"); env != "" { + return env + } + return DefaultPath() +} + +// Write validates cfg and writes it as pretty JSON to path, creating the parent +// directory if needed. The file is written 0600 since it records the issuer's +// Azure coordinates. Validation runs first so a half-baked config is never +// persisted. +func Write(path string, cfg *Config, log *slog.Logger) error { + if err := cfg.Validate(log); err != nil { + return err + } + data, err := json.MarshalIndent(cfg, "", " ") + if err != nil { + return fmt.Errorf("marshaling config: %w", err) + } + data = append(data, '\n') + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return fmt.Errorf("creating config directory: %w", err) + } + if err := os.WriteFile(path, data, 0o600); err != nil { + return fmt.Errorf("writing config %s: %w", path, err) + } + return nil +} + +// Load reads, parses, and validates the config at path. It returns a +// *NotFoundError if the file is absent, a *SyntaxError if the JSON is +// malformed, and a *ValidationError if a required field is missing or invalid. +// +// log may be nil. When non-nil it receives a debug record if the issuer is +// normalized (e.g. a trailing slash is stripped). +func Load(path string, log *slog.Logger) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil, &NotFoundError{Path: path} + } + return nil, fmt.Errorf("reading config %s: %w", path, err) + } + + var c Config + if err := json.Unmarshal(data, &c); err != nil { + return nil, &SyntaxError{Path: path, Err: err} + } + + if err := c.Validate(log); err != nil { + return nil, err + } + return &c, nil +} + +// Validate checks required fields and normalizes the config in place. It is +// safe to call on a freshly constructed Config (e.g. one assembled by setup +// before writing it out). +func (c *Config) Validate(log *slog.Logger) error { + if c.Version == 0 { + return &ValidationError{Field: "version", Msg: "is required"} + } + if c.Version != SupportedVersion { + return &ValidationError{Field: "version", Msg: fmt.Sprintf("unsupported version %d (only %d is supported)", c.Version, SupportedVersion)} + } + + required := []struct { + field string + value string + }{ + {"subscription_id", c.SubscriptionID}, + {"storage_account", c.StorageAccount}, + {"key_vault", c.KeyVault}, + {"key_name", c.KeyName}, + {"issuer", c.Issuer}, + } + for _, r := range required { + if strings.TrimSpace(r.value) == "" { + return &ValidationError{Field: r.field, Msg: "is required"} + } + } + + normalized, err := normalizeIssuer(c.Issuer, log) + if err != nil { + return err + } + c.Issuer = normalized + + if c.JWKSPath == "" { + c.JWKSPath = DefaultJWKSPath + } + if c.DiscoveryPath == "" { + c.DiscoveryPath = DefaultDiscoveryPath + } + return nil +} + +// normalizeIssuer validates the issuer as an absolute URL and strips a trailing +// slash so it matches the canonical iss string used in tokens and the discovery +// document (CONTEXT.md: "Issuer URL ... has no trailing slash"). +func normalizeIssuer(raw string, log *slog.Logger) (string, error) { + u, err := url.Parse(raw) + if err != nil { + return "", &ValidationError{Field: "issuer", Msg: fmt.Sprintf("is not a valid URL: %v", err)} + } + if u.Scheme == "" || u.Host == "" { + return "", &ValidationError{Field: "issuer", Msg: "must be an absolute URL with a scheme and host"} + } + + trimmed := strings.TrimRight(raw, "/") + if trimmed != raw && log != nil { + log.Debug("normalized issuer: stripped trailing slash", "configured", raw, "normalized", trimmed) + } + return trimmed, nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..c182e51 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,154 @@ +package config + +import ( + "errors" + "os" + "path/filepath" + "testing" +) + +func writeConfig(t *testing.T, body string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "config.json") + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatalf("writing temp config: %v", err) + } + return path +} + +const validBody = `{ + "version": 1, + "subscription_id": "00000000-0000-0000-0000-000000000000", + "storage_account": "jotsmithmax", + "key_vault": "jotsmith-max-kv", + "key_name": "signing-key", + "issuer": "https://jotsmithmax.z13.web.core.windows.net", + "jwks_path": ".well-known/jwks.json", + "discovery_path": ".well-known/openid-configuration" +}` + +func TestResolvePath_FlagOverridesEnv(t *testing.T) { + t.Setenv("JOTSMITH_CONFIG", "/from/env.json") + if got := ResolvePath("/from/flag.json"); got != "/from/flag.json" { + t.Errorf("flag should win, got %q", got) + } +} + +func TestResolvePath_EnvOverridesXDG(t *testing.T) { + t.Setenv("XDG_CONFIG_HOME", "/xdg") + t.Setenv("JOTSMITH_CONFIG", "/from/env.json") + if got := ResolvePath(""); got != "/from/env.json" { + t.Errorf("env should win over XDG, got %q", got) + } +} + +func TestResolvePath_XDGFallback(t *testing.T) { + t.Setenv("JOTSMITH_CONFIG", "") + t.Setenv("XDG_CONFIG_HOME", "/xdg") + want := filepath.Join("/xdg", "jotsmith", "config.json") + if got := ResolvePath(""); got != want { + t.Errorf("XDG fallback wrong: got %q want %q", got, want) + } +} + +func TestResolvePath_HomeFallbackWhenXDGUnset(t *testing.T) { + t.Setenv("JOTSMITH_CONFIG", "") + t.Setenv("XDG_CONFIG_HOME", "") + t.Setenv("HOME", "/home/tester") + want := filepath.Join("/home/tester", ".config", "jotsmith", "config.json") + if got := ResolvePath(""); got != want { + t.Errorf("HOME fallback wrong: got %q want %q", got, want) + } +} + +func TestLoad_MissingFile(t *testing.T) { + _, err := Load(filepath.Join(t.TempDir(), "nope.json"), nil) + var nfe *NotFoundError + if !errors.As(err, &nfe) { + t.Fatalf("expected *NotFoundError, got %T: %v", err, err) + } +} + +func TestLoad_InvalidJSON(t *testing.T) { + path := writeConfig(t, "{ this is not json ]") + _, err := Load(path, nil) + var se *SyntaxError + if !errors.As(err, &se) { + t.Fatalf("expected *SyntaxError, got %T: %v", err, err) + } +} + +func TestLoad_MissingRequiredField(t *testing.T) { + body := `{"version":1,"storage_account":"sa","key_vault":"kv","key_name":"k","issuer":"https://x.example"}` + _, err := Load(writeConfig(t, body), nil) + var ve *ValidationError + if !errors.As(err, &ve) { + t.Fatalf("expected *ValidationError, got %T: %v", err, err) + } + if ve.Field != "subscription_id" { + t.Errorf("expected offending field subscription_id, got %q", ve.Field) + } +} + +func TestLoad_UnsupportedVersion(t *testing.T) { + body := `{"version":2,"subscription_id":"s","storage_account":"sa","key_vault":"kv","key_name":"k","issuer":"https://x.example"}` + _, err := Load(writeConfig(t, body), nil) + var ve *ValidationError + if !errors.As(err, &ve) || ve.Field != "version" { + t.Fatalf("expected version ValidationError, got %T: %v", err, err) + } +} + +func TestLoad_MissingVersionIsRequired(t *testing.T) { + body := `{"subscription_id":"s","storage_account":"sa","key_vault":"kv","key_name":"k","issuer":"https://x.example"}` + _, err := Load(writeConfig(t, body), nil) + var ve *ValidationError + if !errors.As(err, &ve) || ve.Field != "version" { + t.Fatalf("expected version ValidationError, got %T: %v", err, err) + } +} + +func TestLoad_TrailingSlashIssuerCorrected(t *testing.T) { + body := `{"version":1,"subscription_id":"s","storage_account":"sa","key_vault":"kv","key_name":"k","issuer":"https://x.example/"}` + cfg, err := Load(writeConfig(t, body), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Issuer != "https://x.example" { + t.Errorf("trailing slash not stripped: %q", cfg.Issuer) + } +} + +func TestLoad_InvalidIssuerURL(t *testing.T) { + body := `{"version":1,"subscription_id":"s","storage_account":"sa","key_vault":"kv","key_name":"k","issuer":"not-a-url"}` + _, err := Load(writeConfig(t, body), nil) + var ve *ValidationError + if !errors.As(err, &ve) || ve.Field != "issuer" { + t.Fatalf("expected issuer ValidationError, got %T: %v", err, err) + } +} + +func TestLoad_ValidAppliesPathDefaults(t *testing.T) { + body := `{"version":1,"subscription_id":"s","storage_account":"sa","key_vault":"kv","key_name":"k","issuer":"https://x.example"}` + cfg, err := Load(writeConfig(t, body), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.JWKSPath != DefaultJWKSPath { + t.Errorf("jwks_path default not applied: %q", cfg.JWKSPath) + } + if cfg.DiscoveryPath != DefaultDiscoveryPath { + t.Errorf("discovery_path default not applied: %q", cfg.DiscoveryPath) + } +} + +func TestLoad_FullValid(t *testing.T) { + cfg, err := Load(writeConfig(t, validBody), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.StorageAccount != "jotsmithmax" || cfg.KeyName != "signing-key" { + t.Errorf("unexpected config: %+v", cfg) + } +} diff --git a/internal/config/doc.go b/internal/config/doc.go new file mode 100644 index 0000000..cf737fd --- /dev/null +++ b/internal/config/doc.go @@ -0,0 +1,6 @@ +// Package config loads, validates, and renders the jotsmith config file. +// +// A config file represents exactly one issuer (ADR-0004): one Storage +// Account, one Key Vault, one Issuer URL. The schema is flat and versioned; +// see PRD §5.3. +package config diff --git a/internal/config/errors.go b/internal/config/errors.go new file mode 100644 index 0000000..9625c9f --- /dev/null +++ b/internal/config/errors.go @@ -0,0 +1,35 @@ +package config + +import "fmt" + +// NotFoundError indicates the config file does not exist at the resolved path. +type NotFoundError struct { + Path string +} + +func (e *NotFoundError) Error() string { + return fmt.Sprintf("config file not found at %s (run `jotsmith setup` first)", e.Path) +} + +// SyntaxError indicates the config file exists but is not valid JSON. +type SyntaxError struct { + Path string + Err error +} + +func (e *SyntaxError) Error() string { + return fmt.Sprintf("config file %s is not valid JSON: %v", e.Path, e.Err) +} + +func (e *SyntaxError) Unwrap() error { return e.Err } + +// ValidationError indicates a required field is missing or a field value is +// invalid. The offending field is always named. +type ValidationError struct { + Field string + Msg string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("invalid config: field %q %s", e.Field, e.Msg) +} diff --git a/internal/jwk/doc.go b/internal/jwk/doc.go new file mode 100644 index 0000000..c871487 --- /dev/null +++ b/internal/jwk/doc.go @@ -0,0 +1,6 @@ +// Package jwk constructs JSON Web Keys from RSA public-key material and +// computes the RFC 7638 thumbprint used as the kid. +// +// The JWKS is always serialized as an array (length 1 in v1) to leave room for +// future multi-key support without a schema migration (ADR-0003). +package jwk diff --git a/internal/jwk/jwk.go b/internal/jwk/jwk.go new file mode 100644 index 0000000..9d15a71 --- /dev/null +++ b/internal/jwk/jwk.go @@ -0,0 +1,89 @@ +package jwk + +import ( + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "fmt" + "math/big" +) + +// b64 is base64url without padding, as JWK members and thumbprints require +// (RFC 7515 / RFC 7638). +var b64 = base64.RawURLEncoding + +// JWK is a single JSON Web Key entry. v1 only ever emits RSA signing keys, so +// the fields are the minimal RSA public-key set plus the derived kid. No +// x5c/x5t members: jotsmith has no certificate chain (PRD §7). +type JWK struct { + Kty string `json:"kty"` + Use string `json:"use"` + Alg string `json:"alg"` + Kid string `json:"kid"` + N string `json:"n"` + E string `json:"e"` +} + +// Set is a JSON Web Key Set. Keys is always an array (length 1 in v1) to leave +// room for future multi-key support without a schema migration (ADR-0003). +type Set struct { + Keys []JWK `json:"keys"` +} + +// FromRSA builds a signing JWK from the raw big-endian RSA modulus (n) and +// public exponent (e) bytes, deriving kid as the RFC 7638 thumbprint. +func FromRSA(n, e []byte) JWK { + return JWK{ + Kty: "RSA", + Use: "sig", + Alg: "RS256", + Kid: Thumbprint(n, e), + N: b64.EncodeToString(n), + E: b64.EncodeToString(e), + } +} + +// NewSet wraps a single JWK in a Set. +func NewSet(k JWK) Set { + return Set{Keys: []JWK{k}} +} + +// RSAPublicKey reconstructs the RSA public key from the JWK's base64url n and e +// members, so a verifier can check signatures. It rejects non-RSA keys and an +// exponent that does not fit Go's int-sized rsa.PublicKey.E. +func (k JWK) RSAPublicKey() (*rsa.PublicKey, error) { + if k.Kty != "RSA" { + return nil, fmt.Errorf("unsupported key type %q (only RSA)", k.Kty) + } + nBytes, err := b64.DecodeString(k.N) + if err != nil { + return nil, fmt.Errorf("decoding modulus (n): %w", err) + } + eBytes, err := b64.DecodeString(k.E) + if err != nil { + return nil, fmt.Errorf("decoding exponent (e): %w", err) + } + if len(nBytes) == 0 || len(eBytes) == 0 { + return nil, fmt.Errorf("modulus or exponent is empty") + } + e := new(big.Int).SetBytes(eBytes) + if !e.IsInt64() || e.Int64() > int64(^uint32(0)) { + return nil, fmt.Errorf("exponent out of supported range") + } + return &rsa.PublicKey{ + N: new(big.Int).SetBytes(nBytes), + E: int(e.Int64()), + }, nil +} + +// Thumbprint computes the RFC 7638 §3 thumbprint of an RSA public key: SHA-256 +// over the canonical JSON containing only the required members (e, kty, n) in +// lexicographic order with no whitespace, base64url-encoded without padding. +// +// The member values for n and e are themselves base64url encodings of the raw +// key bytes, matching the JWK representation, so consumers compute the same kid. +func Thumbprint(n, e []byte) string { + canonical := `{"e":"` + b64.EncodeToString(e) + `","kty":"RSA","n":"` + b64.EncodeToString(n) + `"}` + sum := sha256.Sum256([]byte(canonical)) + return b64.EncodeToString(sum[:]) +} diff --git a/internal/jwk/jwk_test.go b/internal/jwk/jwk_test.go new file mode 100644 index 0000000..cf98a85 --- /dev/null +++ b/internal/jwk/jwk_test.go @@ -0,0 +1,67 @@ +package jwk + +import ( + "encoding/base64" + "encoding/json" + "testing" +) + +// TestThumbprint_RFC7638Example verifies the thumbprint matches the worked +// example in RFC 7638 §3.1 byte-for-byte. The n and e below are the exact +// base64url member values from the RFC's intermediate JSON object. +func TestThumbprint_RFC7638Example(t *testing.T) { + const nB64 = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2" + + "aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCi" + + "FV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65Y" + + "GjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n" + + "91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_x" + + "BniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw" + const eB64 = "AQAB" + const wantKid = "NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs" + + n, err := base64.RawURLEncoding.DecodeString(nB64) + if err != nil { + t.Fatalf("decoding n: %v", err) + } + e, err := base64.RawURLEncoding.DecodeString(eB64) + if err != nil { + t.Fatalf("decoding e: %v", err) + } + + if got := Thumbprint(n, e); got != wantKid { + t.Errorf("Thumbprint mismatch:\n got: %s\nwant: %s", got, wantKid) + } +} + +func TestFromRSA_Shape(t *testing.T) { + n := []byte{0x01, 0x02, 0x03} + e := []byte{0x01, 0x00, 0x01} + k := FromRSA(n, e) + + if k.Kty != "RSA" || k.Use != "sig" || k.Alg != "RS256" { + t.Errorf("unexpected fixed fields: %+v", k) + } + if k.Kid != Thumbprint(n, e) { + t.Errorf("kid is not the thumbprint") + } + if k.N != base64.RawURLEncoding.EncodeToString(n) || k.E != base64.RawURLEncoding.EncodeToString(e) { + t.Errorf("n/e not base64url encoded: %+v", k) + } +} + +func TestSet_MarshalsKeysArray(t *testing.T) { + set := NewSet(FromRSA([]byte{1, 2, 3}, []byte{1, 0, 1})) + b, err := json.Marshal(set) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var probe struct { + Keys []map[string]any `json:"keys"` + } + if err := json.Unmarshal(b, &probe); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(probe.Keys) != 1 { + t.Fatalf("expected keys array of length 1, got %d", len(probe.Keys)) + } +} diff --git a/internal/oidc/doc.go b/internal/oidc/doc.go new file mode 100644 index 0000000..a7eeadb --- /dev/null +++ b/internal/oidc/doc.go @@ -0,0 +1,4 @@ +// Package oidc renders the OpenID Connect discovery document jotsmith +// publishes at /.well-known/openid-configuration. See PRD §7 for the exact +// shape. +package oidc diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go new file mode 100644 index 0000000..6f3bc5c --- /dev/null +++ b/internal/oidc/oidc.go @@ -0,0 +1,40 @@ +package oidc + +import "strings" + +// Discovery is the OpenID Connect discovery document jotsmith publishes at +// /.well-known/openid-configuration. Field order matches PRD §7. +// +// The endpoints a full IdP would advertise (authorization, token, userinfo, +// registration) are deliberately omitted: jotsmith has nothing to point them +// at, and real workload-identity issuers (e.g. GitHub Actions) omit them too. +type Discovery struct { + Issuer string `json:"issuer"` + JWKSURI string `json:"jwks_uri"` + ResponseTypesSupported []string `json:"response_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` + ScopesSupported []string `json:"scopes_supported"` + ClaimsSupported []string `json:"claims_supported"` +} + +// Render builds the discovery document for the given issuer and JWKS blob path. +// id_token_signing_alg_values_supported is an array of one ("RS256") per +// ADR-0003. +func Render(issuer, jwksPath string) Discovery { + return Discovery{ + Issuer: issuer, + JWKSURI: JoinURL(issuer, jwksPath), + ResponseTypesSupported: []string{"id_token"}, + SubjectTypesSupported: []string{"public"}, + IDTokenSigningAlgValuesSupported: []string{"RS256"}, + ScopesSupported: []string{"openid"}, + ClaimsSupported: []string{"iss", "sub", "aud", "exp", "iat", "nbf", "jti"}, + } +} + +// JoinURL joins a base URL and a path with exactly one slash between them, +// regardless of whether base has a trailing slash or path a leading one. +func JoinURL(base, path string) string { + return strings.TrimRight(base, "/") + "/" + strings.TrimLeft(path, "/") +} diff --git a/internal/oidc/oidc_test.go b/internal/oidc/oidc_test.go new file mode 100644 index 0000000..0839c83 --- /dev/null +++ b/internal/oidc/oidc_test.go @@ -0,0 +1,52 @@ +package oidc + +import ( + "encoding/json" + "testing" +) + +func TestJoinURL(t *testing.T) { + cases := []struct{ base, path, want string }{ + {"https://x.example", ".well-known/jwks.json", "https://x.example/.well-known/jwks.json"}, + {"https://x.example/", ".well-known/jwks.json", "https://x.example/.well-known/jwks.json"}, + {"https://x.example/", "/.well-known/jwks.json", "https://x.example/.well-known/jwks.json"}, + {"https://x.example", "/.well-known/jwks.json", "https://x.example/.well-known/jwks.json"}, + } + for _, c := range cases { + if got := JoinURL(c.base, c.path); got != c.want { + t.Errorf("JoinURL(%q,%q) = %q, want %q", c.base, c.path, got, c.want) + } + } +} + +func TestRender(t *testing.T) { + d := Render("https://jotsmithmax.z13.web.core.windows.net", ".well-known/jwks.json") + + if d.Issuer != "https://jotsmithmax.z13.web.core.windows.net" { + t.Errorf("issuer = %q", d.Issuer) + } + if d.JWKSURI != "https://jotsmithmax.z13.web.core.windows.net/.well-known/jwks.json" { + t.Errorf("jwks_uri = %q", d.JWKSURI) + } + if len(d.IDTokenSigningAlgValuesSupported) != 1 || d.IDTokenSigningAlgValuesSupported[0] != "RS256" { + t.Errorf("alg values = %v", d.IDTokenSigningAlgValuesSupported) + } + + // Confirm the JSON keys match the spec surface exactly. + b, _ := json.Marshal(d) + var got map[string]json.RawMessage + if err := json.Unmarshal(b, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + for _, key := range []string{ + "issuer", "jwks_uri", "response_types_supported", "subject_types_supported", + "id_token_signing_alg_values_supported", "scopes_supported", "claims_supported", + } { + if _, ok := got[key]; !ok { + t.Errorf("discovery JSON missing key %q", key) + } + } + if _, ok := got["token_endpoint"]; ok { + t.Error("discovery must not advertise token_endpoint") + } +} diff --git a/internal/sign/b64.go b/internal/sign/b64.go new file mode 100644 index 0000000..0509099 --- /dev/null +++ b/internal/sign/b64.go @@ -0,0 +1,11 @@ +package sign + +import "encoding/base64" + +// b64 is the base64url encoding JWTs use: the URL-safe alphabet with no padding +// (RFC 7515 §2, "base64url"). +var b64 = base64.RawURLEncoding + +func decodeSegment(s string) ([]byte, error) { + return b64.DecodeString(s) +} diff --git a/internal/sign/decode.go b/internal/sign/decode.go new file mode 100644 index 0000000..ba4acfa --- /dev/null +++ b/internal/sign/decode.go @@ -0,0 +1,83 @@ +package sign + +import ( + "encoding/json" + "fmt" + "strings" +) + +// Decoded is the result of inspecting a compact JWT without verifying it. +// Header and Payload are preserved verbatim as raw JSON so re-rendering keeps +// the original key order, number representation, and unicode. +type Decoded struct { + Header json.RawMessage `json:"header"` + Payload json.RawMessage `json:"payload"` + SignatureBytes int `json:"signature_bytes"` +} + +// DecodeError reports which compact-JWT segment failed to decode and why. The +// Segment is one of "format", "header", "payload", or "signature". +type DecodeError struct { + Segment string + Err error +} + +func (e *DecodeError) Error() string { + if e.Segment == "format" { + return fmt.Sprintf("malformed JWT: %v", e.Err) + } + return fmt.Sprintf("malformed JWT %s segment: %v", e.Segment, e.Err) +} + +func (e *DecodeError) Unwrap() error { return e.Err } + +// Decode splits a compact-serialized JWT and base64url-decodes the header and +// payload as JSON. It performs no signature verification; it only reports the +// byte length of the decoded signature segment. +// +// A three-segment token with an empty signature segment (alg:none) decodes +// cleanly with SignatureBytes == 0. Any other segment count, invalid base64url, +// or invalid JSON in the header or payload is reported as a *DecodeError naming +// the failing segment. +func Decode(token string) (*Decoded, error) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil, &DecodeError{ + Segment: "format", + Err: fmt.Errorf("expected 3 dot-separated segments, got %d", len(parts)), + } + } + + header, err := decodeJSONSegment(parts[0]) + if err != nil { + return nil, &DecodeError{Segment: "header", Err: err} + } + payload, err := decodeJSONSegment(parts[1]) + if err != nil { + return nil, &DecodeError{Segment: "payload", Err: err} + } + sig, err := decodeSegment(parts[2]) + if err != nil { + return nil, &DecodeError{Segment: "signature", Err: fmt.Errorf("not valid base64url: %w", err)} + } + + return &Decoded{ + Header: header, + Payload: payload, + SignatureBytes: len(sig), + }, nil +} + +func decodeJSONSegment(seg string) (json.RawMessage, error) { + raw, err := decodeSegment(seg) + if err != nil { + return nil, fmt.Errorf("not valid base64url: %w", err) + } + // Unmarshalling into json.RawMessage validates the bytes are a single + // well-formed JSON value and rejects trailing garbage. + var msg json.RawMessage + if err := json.Unmarshal(raw, &msg); err != nil { + return nil, fmt.Errorf("not valid JSON: %w", err) + } + return msg, nil +} diff --git a/internal/sign/decode_test.go b/internal/sign/decode_test.go new file mode 100644 index 0000000..ec63b51 --- /dev/null +++ b/internal/sign/decode_test.go @@ -0,0 +1,116 @@ +package sign + +import ( + "bytes" + "encoding/json" + "errors" + "strings" + "testing" +) + +// makeToken assembles a compact token from raw segment strings, base64url +// encoding each. A nil signature produces an empty third segment (alg:none). +func makeToken(header, payload string, sig []byte) string { + return b64.EncodeToString([]byte(header)) + "." + + b64.EncodeToString([]byte(payload)) + "." + + b64.EncodeToString(sig) +} + +func TestDecode_WellFormed(t *testing.T) { + sig := bytes.Repeat([]byte{0xAB}, 256) // RS256 / RSA-2048 signature length + tok := makeToken(`{"alg":"RS256","typ":"JWT","kid":"abc"}`, `{"sub":"me","aud":"x"}`, sig) + + d, err := Decode(tok) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if d.SignatureBytes != 256 { + t.Errorf("signature_bytes = %d, want 256", d.SignatureBytes) + } + var hdr map[string]any + if err := json.Unmarshal(d.Header, &hdr); err != nil { + t.Fatalf("header not valid JSON: %v", err) + } + if hdr["alg"] != "RS256" { + t.Errorf("alg = %v, want RS256", hdr["alg"]) + } +} + +func TestDecode_AlgNoneZeroSignature(t *testing.T) { + tok := makeToken(`{"alg":"none","typ":"JWT"}`, `{"sub":"me"}`, nil) + if !strings.HasSuffix(tok, ".") { + t.Fatalf("expected trailing dot for empty signature, got %q", tok) + } + d, err := Decode(tok) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if d.SignatureBytes != 0 { + t.Errorf("signature_bytes = %d, want 0", d.SignatureBytes) + } +} + +func TestDecode_SegmentCountMismatch(t *testing.T) { + for _, tok := range []string{"only.two", "a.b.c.d", "noseparators"} { + _, err := Decode(tok) + var de *DecodeError + if !errors.As(err, &de) || de.Segment != "format" { + t.Errorf("Decode(%q): expected format DecodeError, got %T %v", tok, err, err) + } + } +} + +func TestDecode_InvalidBase64URLHeader(t *testing.T) { + tok := "!!!notbase64!!!." + b64.EncodeToString([]byte(`{}`)) + ".AAAA" + _, err := Decode(tok) + var de *DecodeError + if !errors.As(err, &de) || de.Segment != "header" { + t.Fatalf("expected header DecodeError, got %T %v", err, err) + } +} + +func TestDecode_InvalidJSONInValidBase64URL(t *testing.T) { + // Valid base64url, but the decoded bytes are not valid JSON. + tok := b64.EncodeToString([]byte("not json")) + "." + b64.EncodeToString([]byte(`{}`)) + ".AAAA" + _, err := Decode(tok) + var de *DecodeError + if !errors.As(err, &de) || de.Segment != "header" { + t.Fatalf("expected header DecodeError, got %T %v", err, err) + } + if !strings.Contains(de.Error(), "not valid JSON") { + t.Errorf("error should mention invalid JSON, got %q", de.Error()) + } +} + +func TestDecode_InvalidJSONPayload(t *testing.T) { + tok := b64.EncodeToString([]byte(`{"alg":"RS256"}`)) + "." + b64.EncodeToString([]byte("{broken")) + ".AAAA" + _, err := Decode(tok) + var de *DecodeError + if !errors.As(err, &de) || de.Segment != "payload" { + t.Fatalf("expected payload DecodeError, got %T %v", err, err) + } +} + +func TestDecode_NestedAndUnicodePreserved(t *testing.T) { + payload := `{"sub":"me","roles":["a","b"],"meta":{"n":42,"deep":{"x":true}},"name":"héllo 世界"}` + tok := makeToken(`{"alg":"RS256","typ":"JWT"}`, payload, []byte{0x01, 0x02}) + + d, err := Decode(tok) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if d.SignatureBytes != 2 { + t.Errorf("signature_bytes = %d, want 2", d.SignatureBytes) + } + var got map[string]any + if err := json.Unmarshal(d.Payload, &got); err != nil { + t.Fatalf("payload not valid JSON: %v", err) + } + if got["name"] != "héllo 世界" { + t.Errorf("unicode not preserved: %v", got["name"]) + } + meta, ok := got["meta"].(map[string]any) + if !ok || meta["n"].(float64) != 42 { + t.Errorf("nested object not preserved: %v", got["meta"]) + } +} diff --git a/internal/sign/doc.go b/internal/sign/doc.go new file mode 100644 index 0000000..22dce45 --- /dev/null +++ b/internal/sign/doc.go @@ -0,0 +1,9 @@ +// Package sign holds jotsmith's compact-JWT mechanics: assembly, signing via +// Key Vault, and the inverse operations (decode and verify) that the token +// subcommands share. +// +// When signing, the signing input (base64url(header).base64url(payload)) is +// hashed with SHA-256 client-side; only the digest is sent to Key Vault's Sign +// API. The returned signature is base64url-encoded and appended. No private key +// material ever touches the CLI process. +package sign diff --git a/internal/sign/mint.go b/internal/sign/mint.go new file mode 100644 index 0000000..278db71 --- /dev/null +++ b/internal/sign/mint.go @@ -0,0 +1,68 @@ +package sign + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/json" + "fmt" +) + +// KeyVaultSigner signs a SHA-256 digest with the issuer's RSA key inside Key +// Vault and returns the raw signature bytes. It is satisfied by +// *azurex.Provider. It is declared here, rather than imported from azurex, so +// the sign package carries no Azure SDK dependency and Mint stays unit-testable +// with a fake. +type KeyVaultSigner interface { + Sign(ctx context.Context, digest []byte) ([]byte, error) +} + +// header is the fixed JWT header. The struct field order (alg, typ, kid) is the +// canonical serialization order; v1 only ever emits RS256. +type header struct { + Alg string `json:"alg"` + Typ string `json:"typ"` + Kid string `json:"kid"` +} + +// Mint assembles a compact RS256 JWT from kid and claims, signing it inside Key +// Vault. The header is fixed to {"alg":"RS256","typ":"JWT","kid":kid}; claims is +// marshaled as the payload (encoding/json sorts map keys, yielding +// deterministic, whitespace-free JSON). +// +// The signing input is hashed with SHA-256 client-side; only that digest is +// sent to Key Vault. No private key material ever touches this process. +func Mint(ctx context.Context, signer KeyVaultSigner, kid string, claims map[string]any) (string, error) { + headerJSON, err := marshalCanonical(header{Alg: "RS256", Typ: "JWT", Kid: kid}) + if err != nil { + return "", fmt.Errorf("marshaling header: %w", err) + } + payloadJSON, err := marshalCanonical(claims) + if err != nil { + return "", fmt.Errorf("marshaling claims: %w", err) + } + + signingInput := b64.EncodeToString(headerJSON) + "." + b64.EncodeToString(payloadJSON) + digest := sha256.Sum256([]byte(signingInput)) + + sig, err := signer.Sign(ctx, digest[:]) + if err != nil { + return "", err + } + return signingInput + "." + b64.EncodeToString(sig), nil +} + +// marshalCanonical marshals v to compact JSON with no whitespace and without +// HTML-escaping <, >, & (which encoding/json does by default). JWT claim values +// such as audience URLs must round-trip verbatim, so HTML escaping is disabled. +func marshalCanonical(v any) ([]byte, error) { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) + if err := enc.Encode(v); err != nil { + return nil, err + } + // Encoder.Encode appends a trailing newline; strip it so the segment is the + // bare JSON value. + return bytes.TrimRight(buf.Bytes(), "\n"), nil +} diff --git a/internal/sign/mint_test.go b/internal/sign/mint_test.go new file mode 100644 index 0000000..0090e49 --- /dev/null +++ b/internal/sign/mint_test.go @@ -0,0 +1,110 @@ +package sign + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "errors" + "strings" + "testing" +) + +// fakeKV signs with a real RSA key the same way Key Vault does for RS256: +// RSASSA-PKCS1-v1_5 over a precomputed SHA-256 digest. +type fakeKV struct { + priv *rsa.PrivateKey + gotDigest []byte + err error +} + +func (f *fakeKV) Sign(_ context.Context, digest []byte) ([]byte, error) { + if f.err != nil { + return nil, f.err + } + f.gotDigest = append([]byte(nil), digest...) + return rsa.SignPKCS1v15(rand.Reader, f.priv, crypto.SHA256, digest) +} + +func TestMint_CanonicalEncodingAndVerifiableSignature(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generating key: %v", err) + } + kv := &fakeKV{priv: priv} + + claims := map[string]any{ + "iss": "https://issuer.example", + "sub": "me", + "exp": int64(123), + } + token, err := Mint(context.Background(), kv, "test-kid", claims) + if err != nil { + t.Fatalf("Mint: %v", err) + } + + parts := strings.Split(token, ".") + if len(parts) != 3 { + t.Fatalf("expected 3 segments, got %d (%q)", len(parts), token) + } + + header, err := b64.DecodeString(parts[0]) + if err != nil { + t.Fatalf("decoding header: %v", err) + } + wantHeader := `{"alg":"RS256","typ":"JWT","kid":"test-kid"}` + if string(header) != wantHeader { + t.Errorf("header not canonical:\n got %s\nwant %s", header, wantHeader) + } + + payload, err := b64.DecodeString(parts[1]) + if err != nil { + t.Fatalf("decoding payload: %v", err) + } + // encoding/json sorts map keys; no whitespace. + wantPayload := `{"exp":123,"iss":"https://issuer.example","sub":"me"}` + if string(payload) != wantPayload { + t.Errorf("payload not canonical:\n got %s\nwant %s", payload, wantPayload) + } + + // The digest sent to KV must be the SHA-256 of the signing input — proof the + // hashing happens client-side and no key material is needed locally. + signingInput := parts[0] + "." + parts[1] + wantDigest := sha256.Sum256([]byte(signingInput)) + if string(kv.gotDigest) != string(wantDigest[:]) { + t.Error("digest sent to KV is not the SHA-256 of the signing input") + } + + // The returned signature must verify against the public key. + sig, err := b64.DecodeString(parts[2]) + if err != nil { + t.Fatalf("decoding signature: %v", err) + } + if err := rsa.VerifyPKCS1v15(&priv.PublicKey, crypto.SHA256, wantDigest[:], sig); err != nil { + t.Errorf("signature does not verify: %v", err) + } +} + +func TestMint_DoesNotHTMLEscapeClaims(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + token, err := Mint(context.Background(), &fakeKV{priv: priv}, "k", map[string]any{ + "aud": "a&bd", + }) + if err != nil { + t.Fatalf("Mint: %v", err) + } + payload, _ := b64.DecodeString(strings.Split(token, ".")[1]) + if !strings.Contains(string(payload), "a&bd") { + t.Errorf("expected unescaped claim value, got %s", payload) + } +} + +func TestMint_PropagatesSignerError(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + sentinel := errors.New("kv down") + _, err := Mint(context.Background(), &fakeKV{priv: priv, err: sentinel}, "k", map[string]any{"sub": "x"}) + if !errors.Is(err, sentinel) { + t.Fatalf("expected signer error to propagate, got %v", err) + } +} diff --git a/internal/sign/verify.go b/internal/sign/verify.go new file mode 100644 index 0000000..86a4cec --- /dev/null +++ b/internal/sign/verify.go @@ -0,0 +1,185 @@ +package sign + +import ( + "crypto" + "crypto/rsa" + "crypto/sha256" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/MaxAnderson95/jotsmith/internal/jwk" +) + +// DefaultSkew is the clock-skew tolerance applied to time-based claim checks +// (PRD §6.3: ±60 seconds). +const DefaultSkew = 60 * time.Second + +// VerifyOptions configures token verification. Keys is the JWKS already fetched +// from the issuer; the HTTP round-trip lives in the CLI so this package stays +// pure and unit-testable. A zero Now means time.Now(). +type VerifyOptions struct { + Issuer string + Keys jwk.Set + ExpectedAud string + ExpectedSub string + Now time.Time + Skew time.Duration +} + +// VerifyError reports which verification check failed. It is distinct from +// DecodeError so callers can tell a malformed token from a failed check. +type VerifyError struct{ Reason string } + +func (e *VerifyError) Error() string { return e.Reason } + +// Verify checks a compact JWT's RS256 signature against the JWKS and validates +// its standard claims (iss, exp, nbf, iat) with clock-skew tolerance, plus the +// optional aud/sub expectations. On success it returns the decoded token. +// +// Only RS256 is accepted; unsigned tokens and any other alg are rejected. +func Verify(token string, opts VerifyOptions) (*Decoded, error) { + now := opts.Now + if now.IsZero() { + now = time.Now() + } + skew := opts.Skew + + parts := strings.Split(token, ".") + if len(parts) != 3 { + return nil, &VerifyError{Reason: fmt.Sprintf("malformed JWT: expected 3 segments, got %d", len(parts))} + } + + kid, err := requireRS256(parts[0]) + if err != nil { + return nil, err + } + + pub, err := publicKeyForKID(opts.Keys, kid) + if err != nil { + return nil, err + } + + sig, err := decodeSegment(parts[2]) + if err != nil { + return nil, &VerifyError{Reason: "signature is not valid base64url"} + } + digest := sha256.Sum256([]byte(parts[0] + "." + parts[1])) + if err := rsa.VerifyPKCS1v15(pub, crypto.SHA256, digest[:], sig); err != nil { + return nil, &VerifyError{Reason: "signature verification failed"} + } + + payloadBytes, err := decodeSegment(parts[1]) + if err != nil { + return nil, &VerifyError{Reason: "payload is not valid base64url"} + } + if err := verifyClaims(payloadBytes, opts, now, skew); err != nil { + return nil, err + } + + headerBytes, _ := decodeSegment(parts[0]) + return &Decoded{ + Header: json.RawMessage(headerBytes), + Payload: json.RawMessage(payloadBytes), + SignatureBytes: len(sig), + }, nil +} + +// requireRS256 parses the header segment and returns the kid, rejecting any alg +// other than RS256 (including "none"). +func requireRS256(headerSeg string) (string, error) { + raw, err := decodeSegment(headerSeg) + if err != nil { + return "", &VerifyError{Reason: "header is not valid base64url"} + } + var hdr struct { + Alg string `json:"alg"` + Kid string `json:"kid"` + } + if err := json.Unmarshal(raw, &hdr); err != nil { + return "", &VerifyError{Reason: "header is not valid JSON"} + } + if hdr.Alg != "RS256" { + return "", &VerifyError{Reason: fmt.Sprintf("unsupported alg %q (only RS256 is accepted)", hdr.Alg)} + } + return hdr.Kid, nil +} + +// publicKeyForKID finds the JWK whose kid matches and reconstructs its RSA +// public key. +func publicKeyForKID(set jwk.Set, kid string) (*rsa.PublicKey, error) { + for i := range set.Keys { + if set.Keys[i].Kid == kid { + pub, err := set.Keys[i].RSAPublicKey() + if err != nil { + return nil, &VerifyError{Reason: fmt.Sprintf("reconstructing public key for kid %q: %v", kid, err)} + } + return pub, nil + } + } + return nil, &VerifyError{Reason: fmt.Sprintf("no JWK in the issuer's JWKS matches kid %q", kid)} +} + +// verifyClaims checks iss, the time-based claims with skew, and the optional +// aud/sub expectations. +func verifyClaims(payload []byte, opts VerifyOptions, now time.Time, skew time.Duration) error { + var claims struct { + Iss string `json:"iss"` + Sub string `json:"sub"` + Exp int64 `json:"exp"` + Nbf int64 `json:"nbf"` + Iat int64 `json:"iat"` + Aud json.RawMessage `json:"aud"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return &VerifyError{Reason: "payload is not valid JSON"} + } + + if claims.Iss != opts.Issuer { + return &VerifyError{Reason: fmt.Sprintf("token iss %q does not match expected issuer %q", claims.Iss, opts.Issuer)} + } + if claims.Exp != 0 { + if exp := time.Unix(claims.Exp, 0); now.After(exp.Add(skew)) { + return &VerifyError{Reason: fmt.Sprintf("token expired at %s", exp.UTC().Format(time.RFC3339))} + } + } + if claims.Nbf != 0 { + if nbf := time.Unix(claims.Nbf, 0); now.Before(nbf.Add(-skew)) { + return &VerifyError{Reason: fmt.Sprintf("token is not valid until %s (nbf)", nbf.UTC().Format(time.RFC3339))} + } + } + if claims.Iat != 0 { + if iat := time.Unix(claims.Iat, 0); iat.After(now.Add(skew)) { + return &VerifyError{Reason: fmt.Sprintf("token iat %s is in the future", iat.UTC().Format(time.RFC3339))} + } + } + if opts.ExpectedAud != "" && !audienceContains(claims.Aud, opts.ExpectedAud) { + return &VerifyError{Reason: fmt.Sprintf("audience does not include %q", opts.ExpectedAud)} + } + if opts.ExpectedSub != "" && claims.Sub != opts.ExpectedSub { + return &VerifyError{Reason: fmt.Sprintf("subject %q does not match expected %q", claims.Sub, opts.ExpectedSub)} + } + return nil +} + +// audienceContains reports whether the raw aud claim (a JSON string or array of +// strings) includes want. +func audienceContains(raw json.RawMessage, want string) bool { + if len(raw) == 0 { + return false + } + var single string + if err := json.Unmarshal(raw, &single); err == nil { + return single == want + } + var many []string + if err := json.Unmarshal(raw, &many); err == nil { + for _, a := range many { + if a == want { + return true + } + } + } + return false +} diff --git a/internal/sign/verify_test.go b/internal/sign/verify_test.go new file mode 100644 index 0000000..e662b58 --- /dev/null +++ b/internal/sign/verify_test.go @@ -0,0 +1,198 @@ +package sign + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "errors" + "math/big" + "testing" + "time" + + "github.com/MaxAnderson95/jotsmith/internal/jwk" +) + +// signedToken mints a token with priv and returns it together with a JWKS that +// publishes priv's public key under its thumbprint kid. +func signedToken(t *testing.T, priv *rsa.PrivateKey, claims map[string]any) (string, jwk.Set) { + t.Helper() + n := priv.N.Bytes() + e := big.NewInt(int64(priv.E)).Bytes() + kid := jwk.Thumbprint(n, e) + token, err := Mint(context.Background(), &fakeKV{priv: priv}, kid, claims) + if err != nil { + t.Fatalf("Mint: %v", err) + } + return token, jwk.NewSet(jwk.FromRSA(n, e)) +} + +const verifyIssuer = "https://issuer.example" + +func validClaims(now time.Time) map[string]any { + return map[string]any{ + "iss": verifyIssuer, + "sub": "me", + "iat": now.Unix(), + "nbf": now.Unix(), + "exp": now.Add(15 * time.Minute).Unix(), + "aud": "sigstore", + } +} + +func TestVerify_ValidToken(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now() + token, set := signedToken(t, priv, validClaims(now)) + + decoded, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now, Skew: DefaultSkew}) + if err != nil { + t.Fatalf("expected valid token, got %v", err) + } + if decoded.SignatureBytes == 0 { + t.Error("expected a non-empty signature") + } +} + +func TestVerify_BadKID(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now() + token, set := signedToken(t, priv, validClaims(now)) + // Replace the published kid so nothing matches the token header. + set.Keys[0].Kid = "different-kid" + + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now}); !isVerifyErr(err) { + t.Fatalf("expected a verify error for bad kid, got %v", err) + } +} + +func TestVerify_SignatureMismatch(t *testing.T) { + privA, _ := rsa.GenerateKey(rand.Reader, 2048) + privB, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now() + token, _ := signedToken(t, privA, validClaims(now)) + + // Publish key B but force its kid to collide with the token's (key A's) kid, + // so the verifier reconstructs the wrong public key. + kidA := jwk.Thumbprint(privA.N.Bytes(), big.NewInt(int64(privA.E)).Bytes()) + jwkB := jwk.FromRSA(privB.N.Bytes(), big.NewInt(int64(privB.E)).Bytes()) + jwkB.Kid = kidA + set := jwk.NewSet(jwkB) + + _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now}) + if !isVerifyErr(err) { + t.Fatalf("expected signature verification failure, got %v", err) + } +} + +func TestVerify_Expired(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now() + claims := validClaims(now) + claims["exp"] = now.Add(-10 * time.Minute).Unix() + token, set := signedToken(t, priv, claims) + + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now, Skew: DefaultSkew}); !isVerifyErr(err) { + t.Fatalf("expected expired error, got %v", err) + } +} + +func TestVerify_NotYetValid(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now() + claims := validClaims(now) + claims["nbf"] = now.Add(10 * time.Minute).Unix() + token, set := signedToken(t, priv, claims) + + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now, Skew: DefaultSkew}); !isVerifyErr(err) { + t.Fatalf("expected not-yet-valid error, got %v", err) + } +} + +func TestVerify_WrongIssuer(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now() + token, set := signedToken(t, priv, validClaims(now)) + + if _, err := Verify(token, VerifyOptions{Issuer: "https://other.example", Keys: set, Now: now}); !isVerifyErr(err) { + t.Fatalf("expected wrong-issuer error, got %v", err) + } +} + +func TestVerify_Audience(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now() + + t.Run("string aud matches", func(t *testing.T) { + token, set := signedToken(t, priv, validClaims(now)) + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now, ExpectedAud: "sigstore"}); err != nil { + t.Fatalf("expected match, got %v", err) + } + }) + t.Run("string aud mismatch", func(t *testing.T) { + token, set := signedToken(t, priv, validClaims(now)) + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now, ExpectedAud: "nope"}); !isVerifyErr(err) { + t.Fatalf("expected aud mismatch, got %v", err) + } + }) + t.Run("array aud contains", func(t *testing.T) { + claims := validClaims(now) + claims["aud"] = []string{"a", "sigstore", "c"} + token, set := signedToken(t, priv, claims) + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now, ExpectedAud: "sigstore"}); err != nil { + t.Fatalf("expected match in array, got %v", err) + } + }) + t.Run("array aud missing", func(t *testing.T) { + claims := validClaims(now) + claims["aud"] = []string{"a", "b"} + token, set := signedToken(t, priv, claims) + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now, ExpectedAud: "sigstore"}); !isVerifyErr(err) { + t.Fatalf("expected aud-not-found, got %v", err) + } + }) +} + +func TestVerify_WrongSubject(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now() + token, set := signedToken(t, priv, validClaims(now)) + + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now, ExpectedSub: "someone-else"}); !isVerifyErr(err) { + t.Fatalf("expected subject mismatch, got %v", err) + } + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now, ExpectedSub: "me"}); err != nil { + t.Fatalf("expected subject match, got %v", err) + } +} + +func TestVerify_RejectsNonRS256(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now() + _, set := signedToken(t, priv, validClaims(now)) + + enc := base64.RawURLEncoding.EncodeToString + token := enc([]byte(`{"alg":"none","typ":"JWT"}`)) + "." + enc([]byte(`{"iss":"`+verifyIssuer+`"}`)) + "." + + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now}); !isVerifyErr(err) { + t.Fatalf("expected alg rejection, got %v", err) + } +} + +func TestVerify_SkewToleratesRecentExpiry(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now() + claims := validClaims(now) + // Expired 30s ago — within the ±60s skew, so it should still verify. + claims["exp"] = now.Add(-30 * time.Second).Unix() + token, set := signedToken(t, priv, claims) + + if _, err := Verify(token, VerifyOptions{Issuer: verifyIssuer, Keys: set, Now: now, Skew: DefaultSkew}); err != nil { + t.Fatalf("expected skew to tolerate 30s-stale exp, got %v", err) + } +} + +func isVerifyErr(err error) bool { + var ve *VerifyError + return errors.As(err, &ve) +} From a411afa53d51651e642dfc85276d51cbe7f11ea0 Mon Sep 17 00:00:00 2001 From: Max Anderson Date: Mon, 1 Jun 2026 11:51:14 -0400 Subject: [PATCH 2/2] fix: harden issuer publishing invariants Centralize discovery and JWKS rendering so setup, rotate, doctor repair, and show commands use the same document path. Enforce the v1 single RS256 JWK invariant during doctor and publishing, remove arbitrary signing-key bit-size plumbing, canonicalize discovery issuer rendering, and make doctor --e2e perform a real mint/verify round trip.\n\nOpenCode session ID: ses_17c2d0269ffeCyZ490A503eFP2 --- internal/azurex/manager.go | 12 +++-- internal/azurex/types.go | 4 +- internal/cli/discovery.go | 8 +-- internal/cli/doctor.go | 25 ++++----- internal/cli/doctor_e2e.go | 49 ++++++++++++++++++ internal/cli/doctor_repair_test.go | 53 +++++++++++++++++++ internal/cli/doctor_test.go | 12 +++++ internal/cli/jwks.go | 17 +----- internal/cli/key.go | 13 ++--- internal/cli/key_test.go | 4 +- internal/cli/publish.go | 54 +++++++++++++++++++ internal/cli/setup.go | 21 +++----- internal/cli/setup_test.go | 4 +- internal/jwk/jwk.go | 83 ++++++++++++++++++++++++------ internal/jwk/jwk_test.go | 24 +++++++++ internal/oidc/oidc.go | 9 +++- internal/oidc/oidc_test.go | 2 +- 17 files changed, 306 insertions(+), 88 deletions(-) create mode 100644 internal/cli/doctor_e2e.go create mode 100644 internal/cli/publish.go diff --git a/internal/azurex/manager.go b/internal/azurex/manager.go index 68b76cb..654d274 100644 --- a/internal/azurex/manager.go +++ b/internal/azurex/manager.go @@ -10,6 +10,8 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/service" ) +const signingKeyBits = 2048 + // EnableStaticWebsite turns on static website hosting for the storage account's // blob service. This is the single control-plane-ish mutation jotsmith is ever // allowed to make (ADR-0002). Index and error documents are left unset, as @@ -50,17 +52,17 @@ func (p *Provider) UploadBlob(ctx context.Context, path, contentType, cacheContr return nil } -// CreateRSAKey creates an RSA key (or, if one already exists under the same -// name, a new version of it) with the sign and verify operations enabled, and -// returns the new public material. -func (p *Provider) CreateRSAKey(ctx context.Context, bits int) (Key, error) { +// CreateSigningKey creates the v1 signing key shape: RSA 2048 with sign and +// verify operations enabled. If a key already exists under the same name, Key +// Vault creates a new version and returns its public material. +func (p *Provider) CreateSigningKey(ctx context.Context) (Key, error) { client, err := p.keysClient(ctx) if err != nil { return Key{}, err } resp, err := client.CreateKey(ctx, p.target.KeyName, azkeys.CreateKeyParameters{ Kty: to.Ptr(azkeys.KeyTypeRSA), - KeySize: to.Ptr(int32(bits)), + KeySize: to.Ptr(int32(signingKeyBits)), KeyOps: []*azkeys.KeyOperation{ to.Ptr(azkeys.KeyOperationSign), to.Ptr(azkeys.KeyOperationVerify), diff --git a/internal/azurex/types.go b/internal/azurex/types.go index bd9af14..eed2da7 100644 --- a/internal/azurex/types.go +++ b/internal/azurex/types.go @@ -80,7 +80,7 @@ type SetupManager interface { Inspector EnableStaticWebsite(ctx context.Context) error UploadBlob(ctx context.Context, path, contentType, cacheControl string, data []byte) error - CreateRSAKey(ctx context.Context, bits int) (Key, error) + CreateSigningKey(ctx context.Context) (Key, error) } // Signer is the Key Vault signing surface token mint depends on: read the @@ -99,7 +99,7 @@ type Signer interface { // satisfies it. type RotateManager interface { GetSigningKey(ctx context.Context) (Key, error) - CreateRSAKey(ctx context.Context, bits int) (Key, error) + CreateSigningKey(ctx context.Context) (Key, error) GetDiscoveryDocument(ctx context.Context) ([]byte, error) UploadBlob(ctx context.Context, path, contentType, cacheControl string, data []byte) error } diff --git a/internal/cli/discovery.go b/internal/cli/discovery.go index 1441b91..cd84ca3 100644 --- a/internal/cli/discovery.go +++ b/internal/cli/discovery.go @@ -2,12 +2,9 @@ package cli import ( "context" - "encoding/json" "fmt" ucli "github.com/urfave/cli/v3" - - "github.com/MaxAnderson95/jotsmith/internal/oidc" ) func discoveryCommand(streams IOStreams) *ucli.Command { @@ -33,10 +30,9 @@ func discoveryShowCommand(streams IOStreams) *ucli.Command { if err != nil { return err } - doc := oidc.Render(cfg.Issuer, cfg.JWKSPath) - b, err := json.MarshalIndent(doc, "", " ") + b, err := renderDiscoveryJSON(cfg.Issuer, cfg.JWKSPath) if err != nil { - return failuref("rendering discovery document: %v", err) + return failuref("%v", err) } fmt.Fprintln(streams.Out, string(b)) return nil diff --git a/internal/cli/doctor.go b/internal/cli/doctor.go index 0526d9a..8250868 100644 --- a/internal/cli/doctor.go +++ b/internal/cli/doctor.go @@ -12,7 +12,6 @@ import ( "github.com/MaxAnderson95/jotsmith/internal/azurex" "github.com/MaxAnderson95/jotsmith/internal/config" "github.com/MaxAnderson95/jotsmith/internal/jwk" - "github.com/MaxAnderson95/jotsmith/internal/oidc" ) // Check names. Defined as constants so the inspect, repair, and report stages @@ -27,6 +26,7 @@ const ( checkJWKS = "JWKS published" checkVaultRBAC = "Key Vault in RBAC mode" checkSigningKey = "Signing key valid" + checkE2E = "End-to-end mint and verify" ) // checkStatus is the outcome of a single doctor check. @@ -152,11 +152,7 @@ func runDoctor(ctx context.Context, mgr azurex.RepairManager, cfg *config.Config full := append([]check{{Name: checkCredential, Status: statusPass, Message: "resolved via DefaultAzureCredential"}}, checks...) if opts.e2e { - full = append(full, check{ - Name: "End-to-end mint and verify", - Status: statusWarn, - Message: "skipped: end-to-end check is not yet implemented", - }) + full = append(full, inspectE2E(ctx, mgr, cfg)) } if opts.json { @@ -190,7 +186,7 @@ func repairChecks(ctx context.Context, mgr azurex.RepairManager, cfg *config.Con } if checkFailed(checks, checkDiscovery) { - if doc, err := json.MarshalIndent(oidc.Render(cfg.Issuer, cfg.JWKSPath), "", " "); err == nil { + if doc, err := renderDiscoveryJSON(cfg.Issuer, cfg.JWKSPath); err == nil { if uerr := mgr.UploadBlob(ctx, cfg.DiscoveryPath, contentTypeJSON, cacheControl, doc); uerr == nil { repaired[checkDiscovery] = true } @@ -202,7 +198,7 @@ func repairChecks(ctx context.Context, mgr azurex.RepairManager, cfg *config.Con // is usable (a disabled or sign-less key is a human-action FAIL). if checkFailed(checks, checkJWKS) || checkFailed(checks, checkSigningKey) { if key, err := mgr.GetSigningKey(ctx); err == nil && key.Enabled && slices.Contains(key.Ops, "sign") { - if doc, derr := json.MarshalIndent(jwk.NewSet(jwk.FromRSA(key.N, key.E)), "", " "); derr == nil { + if doc, _, derr := renderJWKSJSON(key); derr == nil { if uerr := mgr.UploadBlob(ctx, cfg.JWKSPath, contentTypeJSON, cacheControl, doc); uerr == nil { repaired[checkJWKS] = true repaired[checkSigningKey] = true @@ -277,15 +273,12 @@ func inspect(ctx context.Context, insp azurex.Inspector, cfg *config.Config) []c add(checkJWKS, statusFail, "%v", err) } else { var set jwk.Set - switch { - case json.Unmarshal(data, &set) != nil: + if err := json.Unmarshal(data, &set); err != nil { add(checkJWKS, statusFail, "published JWKS is not valid JSON") - case len(set.Keys) == 0: - add(checkJWKS, statusFail, "published JWKS contains no keys") - case set.Keys[0].Kty != "RSA" || set.Keys[0].N == "" || set.Keys[0].E == "": - add(checkJWKS, statusFail, "published JWKS key is not a valid RSA JWK") - default: - publishedKid = set.Keys[0].Kid + } else if entry, serr := set.SingleRS256(); serr != nil { + add(checkJWKS, statusFail, "published JWKS is invalid for jotsmith v1: %v", serr) + } else { + publishedKid = entry.Kid add(checkJWKS, statusPass, "1 RSA key, kid %s", publishedKid) } } diff --git a/internal/cli/doctor_e2e.go b/internal/cli/doctor_e2e.go new file mode 100644 index 0000000..e7b8ba3 --- /dev/null +++ b/internal/cli/doctor_e2e.go @@ -0,0 +1,49 @@ +package cli + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/jwk" + "github.com/MaxAnderson95/jotsmith/internal/sign" +) + +func inspectE2E(ctx context.Context, mgr azurex.RepairManager, cfg *config.Config) check { + signer, ok := mgr.(azurex.Signer) + if !ok { + return check{Name: checkE2E, Status: statusWarn, Message: "skipped: Azure manager cannot sign test tokens"} + } + + key, err := signer.GetSigningKey(ctx) + if err != nil { + return check{Name: checkE2E, Status: statusFail, Message: fmt.Sprintf("getting signing key: %v", err)} + } + if !key.Enabled { + return check{Name: checkE2E, Status: statusFail, Message: fmt.Sprintf("signing key %q is disabled", key.Name)} + } + + now := time.Now().UTC() + token, err := sign.Mint(ctx, signer, jwk.Thumbprint(key.N, key.E), map[string]any{ + "iss": cfg.Issuer, + "sub": "jotsmith-doctor", + "iat": now.Unix(), + "nbf": now.Unix(), + "exp": now.Add(5 * time.Minute).Unix(), + }) + if err != nil { + return check{Name: checkE2E, Status: statusFail, Message: fmt.Sprintf("minting test token: %v", err)} + } + + set, err := fetchJWKS(ctx, &http.Client{Timeout: verifyHTTPTimeout}, cfg) + if err != nil { + return check{Name: checkE2E, Status: statusFail, Message: fmt.Sprintf("fetching live discovery/JWKS: %v", err)} + } + if _, err := sign.Verify(token, sign.VerifyOptions{Issuer: cfg.Issuer, Keys: set, Now: now, Skew: sign.DefaultSkew}); err != nil { + return check{Name: checkE2E, Status: statusFail, Message: fmt.Sprintf("verifying test token: %v", err)} + } + return check{Name: checkE2E, Status: statusPass, Message: "minted and verified a short-lived token via live discovery"} +} diff --git a/internal/cli/doctor_repair_test.go b/internal/cli/doctor_repair_test.go index 1aa2912..70fb19b 100644 --- a/internal/cli/doctor_repair_test.go +++ b/internal/cli/doctor_repair_test.go @@ -3,12 +3,15 @@ package cli import ( "context" "encoding/json" + "net/http" + "net/http/httptest" "path/filepath" "strings" "testing" "github.com/MaxAnderson95/jotsmith/internal/azurex" "github.com/MaxAnderson95/jotsmith/internal/config" + "github.com/MaxAnderson95/jotsmith/internal/jwk" ) // healthyRepairManager wraps a fully-healthy inspector in a fakeManager so @@ -133,6 +136,56 @@ func TestRunDoctor_JSONReportsRepairedCount(t *testing.T) { } } +type doctorE2EManager struct { + *fakeManager + signer *fakeSigner +} + +func (m *doctorE2EManager) Sign(ctx context.Context, digest []byte) ([]byte, error) { + return m.signer.Sign(ctx, digest) +} + +func TestRunDoctor_E2EMintsAndVerifies(t *testing.T) { + signer := newFakeSigner(t) + set := jwk.NewSet(jwk.FromRSA(signer.key.N, signer.key.E)) + + serverURL := "" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", contentTypeJSON) + switch r.URL.Path { + case "/.well-known/openid-configuration": + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": serverURL, + "jwks_uri": serverURL + "/.well-known/jwks.json", + }) + case "/.well-known/jwks.json": + _ = json.NewEncoder(w).Encode(set) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + serverURL = srv.URL + + cfg := testConfig() + cfg.Issuer = srv.URL + manager := healthyRepairManager(t) + manager.acct.WebEndpoint = srv.URL + manager.key = signer.key + manager.disco = []byte(`{"issuer":"` + srv.URL + `","jwks_uri":"` + srv.URL + `/.well-known/jwks.json"}`) + manager.jwks, _ = json.Marshal(set) + + streams, out, _ := mintStreams() + err := runDoctor(context.Background(), &doctorE2EManager{fakeManager: manager, signer: signer}, cfg, doctorOptions{e2e: true, json: true}, false, streams) + if err != nil { + t.Fatalf("expected e2e doctor to pass, got %v", err) + } + report := parseReport(t, out.String()) + if c := reportCheck(t, report, checkE2E); c.Status != statusPass { + t.Fatalf("expected E2E PASS, got %s: %s", c.Status, c.Message) + } +} + func TestDoctorMissingConfigIsUsageError(t *testing.T) { t.Setenv("JOTSMITH_CONFIG", filepath.Join(t.TempDir(), "absent.json")) _, _, err := run(t, "doctor", "--json") diff --git a/internal/cli/doctor_test.go b/internal/cli/doctor_test.go index b806d65..fbb62f2 100644 --- a/internal/cli/doctor_test.go +++ b/internal/cli/doctor_test.go @@ -164,6 +164,18 @@ func TestInspect_ThumbprintMismatchFails(t *testing.T) { } } +func TestInspect_MultipleJWKSKeysFails(t *testing.T) { + f := healthyInspector(t) + entry := jwk.FromRSA(f.key.N, f.key.E) + f.jwks, _ = json.Marshal(jwk.Set{Keys: []jwk.JWK{entry, entry}}) + + checks := inspect(context.Background(), f, testConfig()) + got := findCheck(t, checks, "JWKS published") + if got.Status != statusFail { + t.Fatalf("expected FAIL for multi-key JWKS, got %s: %s", got.Status, got.Message) + } +} + func TestInspect_DisabledKeyFails(t *testing.T) { f := healthyInspector(t) f.key.Enabled = false diff --git a/internal/cli/jwks.go b/internal/cli/jwks.go index 87e8263..ee9cf0f 100644 --- a/internal/cli/jwks.go +++ b/internal/cli/jwks.go @@ -2,13 +2,11 @@ package cli import ( "context" - "encoding/json" "fmt" ucli "github.com/urfave/cli/v3" "github.com/MaxAnderson95/jotsmith/internal/azurex" - "github.com/MaxAnderson95/jotsmith/internal/jwk" ) func jwksCommand(streams IOStreams) *ucli.Command { @@ -44,23 +42,12 @@ func jwksShowCommand(streams IOStreams) *ucli.Command { if err != nil { return keyVaultError(err, cfg.KeyVault) } - - b, err := json.MarshalIndent(set, "", " ") + b, err := renderJWKSetJSON(set) if err != nil { - return failuref("rendering JWKS: %v", err) + return failuref("%v", err) } fmt.Fprintln(streams.Out, string(b)) return nil }, } } - -// buildJWKS reads the current signing key's public material and wraps it in a -// single-entry JWKS. -func buildJWKS(ctx context.Context, insp azurex.Inspector) (jwk.Set, error) { - key, err := insp.GetSigningKey(ctx) - if err != nil { - return jwk.Set{}, err - } - return jwk.NewSet(jwk.FromRSA(key.N, key.E)), nil -} diff --git a/internal/cli/key.go b/internal/cli/key.go index 2c4b7e9..74db3e3 100644 --- a/internal/cli/key.go +++ b/internal/cli/key.go @@ -3,7 +3,6 @@ package cli import ( "bytes" "context" - "encoding/json" "errors" "fmt" @@ -12,7 +11,6 @@ import ( "github.com/MaxAnderson95/jotsmith/internal/azurex" "github.com/MaxAnderson95/jotsmith/internal/config" "github.com/MaxAnderson95/jotsmith/internal/jwk" - "github.com/MaxAnderson95/jotsmith/internal/oidc" ) func keyCommand(streams IOStreams) *ucli.Command { @@ -75,15 +73,14 @@ func runRotate(ctx context.Context, mgr azurex.RotateManager, cfg *config.Config return nil } - newKey, err := mgr.CreateRSAKey(ctx, rsaKeyBits) + newKey, err := mgr.CreateSigningKey(ctx) if err != nil { return keyVaultError(err, cfg.KeyVault) } - jwkEntry := jwk.FromRSA(newKey.N, newKey.E) - jwksDoc, err := json.MarshalIndent(jwk.NewSet(jwkEntry), "", " ") + jwksDoc, jwkEntry, err := renderJWKSJSON(newKey) if err != nil { - return failuref("rendering JWKS: %v", err) + return failuref("%v", err) } if uerr := mgr.UploadBlob(ctx, cfg.JWKSPath, contentTypeJSON, cacheControl, jwksDoc); uerr != nil { return failuref("new key version was created but the JWKS upload failed (run `jotsmith doctor` to inspect the drift): %v", uerr) @@ -117,9 +114,9 @@ func currentKID(ctx context.Context, mgr azurex.RotateManager, cfg *config.Confi // rendered bytes differ from what is already published (rotation does not // change the discovery doc, so normally this is a no-op). func refreshDiscoveryIfChanged(ctx context.Context, mgr azurex.RotateManager, cfg *config.Config) error { - desired, err := json.MarshalIndent(oidc.Render(cfg.Issuer, cfg.JWKSPath), "", " ") + desired, err := renderDiscoveryJSON(cfg.Issuer, cfg.JWKSPath) if err != nil { - return failuref("rendering discovery document: %v", err) + return failuref("%v", err) } current, derr := mgr.GetDiscoveryDocument(ctx) if derr == nil && bytes.Equal(bytes.TrimSpace(current), bytes.TrimSpace(desired)) { diff --git a/internal/cli/key_test.go b/internal/cli/key_test.go index de216eb..4cce045 100644 --- a/internal/cli/key_test.go +++ b/internal/cli/key_test.go @@ -15,7 +15,7 @@ import ( ) // rotateManager builds a fake with a current ("before") key and a distinct -// "after" key returned by CreateRSAKey, plus a published discovery doc that +// "after" key returned by CreateSigningKey, plus a published discovery doc that // already matches what rotation would render (so no spurious refresh). func rotateManager(t *testing.T) (mgr *fakeManager, beforeKid, afterKid string) { t.Helper() @@ -43,7 +43,7 @@ func TestRunRotate_HappyPath(t *testing.T) { t.Fatalf("runRotate: %v", err) } if mgr.createCalls != 1 { - t.Errorf("expected one CreateRSAKey call, got %d", mgr.createCalls) + t.Errorf("expected one CreateSigningKey call, got %d", mgr.createCalls) } if out.Len() != 0 { t.Errorf("rotate must write nothing to stdout, got:\n%s", out.String()) diff --git a/internal/cli/publish.go b/internal/cli/publish.go new file mode 100644 index 0000000..aac7bda --- /dev/null +++ b/internal/cli/publish.go @@ -0,0 +1,54 @@ +package cli + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/MaxAnderson95/jotsmith/internal/azurex" + "github.com/MaxAnderson95/jotsmith/internal/jwk" + "github.com/MaxAnderson95/jotsmith/internal/oidc" +) + +const ( + contentTypeJSON = "application/json" + cacheControl = "no-cache" +) + +func renderDiscoveryJSON(issuer, jwksPath string) ([]byte, error) { + doc, err := json.MarshalIndent(oidc.Render(issuer, jwksPath), "", " ") + if err != nil { + return nil, fmt.Errorf("rendering discovery document: %w", err) + } + return doc, nil +} + +func renderJWKSJSON(key azurex.Key) ([]byte, jwk.JWK, error) { + entry := jwk.FromRSA(key.N, key.E) + doc, err := renderJWKSetJSON(jwk.NewSet(entry)) + if err != nil { + return nil, jwk.JWK{}, err + } + return doc, entry, nil +} + +// buildJWKS reads the current signing key's public material and wraps it in the +// single-entry JWKS shape required by jotsmith v1. +func buildJWKS(ctx context.Context, insp azurex.Inspector) (jwk.Set, error) { + key, err := insp.GetSigningKey(ctx) + if err != nil { + return jwk.Set{}, err + } + return jwk.NewSet(jwk.FromRSA(key.N, key.E)), nil +} + +func renderJWKSetJSON(set jwk.Set) ([]byte, error) { + if _, err := set.SingleRS256(); err != nil { + return nil, fmt.Errorf("rendering JWKS: %w", err) + } + doc, err := json.MarshalIndent(set, "", " ") + if err != nil { + return nil, fmt.Errorf("rendering JWKS: %w", err) + } + return doc, nil +} diff --git a/internal/cli/setup.go b/internal/cli/setup.go index 3799f12..4abb462 100644 --- a/internal/cli/setup.go +++ b/internal/cli/setup.go @@ -2,7 +2,6 @@ package cli import ( "context" - "encoding/json" "errors" "fmt" "log/slog" @@ -11,16 +10,9 @@ import ( "github.com/MaxAnderson95/jotsmith/internal/azurex" "github.com/MaxAnderson95/jotsmith/internal/config" - "github.com/MaxAnderson95/jotsmith/internal/jwk" "github.com/MaxAnderson95/jotsmith/internal/oidc" ) -const ( - contentTypeJSON = "application/json" - cacheControl = "no-cache" - rsaKeyBits = 2048 -) - type setupOptions struct { subscription string storageAccount string @@ -148,14 +140,13 @@ func runSetup(ctx context.Context, mgr azurex.SetupManager, opts setupOptions, e return err } - jwkEntry := jwk.FromRSA(key.N, key.E) - discoveryDoc, err := json.MarshalIndent(oidc.Render(issuer, config.DefaultJWKSPath), "", " ") + discoveryDoc, err := renderDiscoveryJSON(issuer, config.DefaultJWKSPath) if err != nil { - return failuref("rendering discovery document: %v", err) + return failuref("%v", err) } - jwksDoc, err := json.MarshalIndent(jwk.NewSet(jwkEntry), "", " ") + jwksDoc, jwkEntry, err := renderJWKSJSON(key) if err != nil { - return failuref("rendering JWKS: %v", err) + return failuref("%v", err) } if uerr := mgr.UploadBlob(ctx, config.DefaultDiscoveryPath, contentTypeJSON, cacheControl, discoveryDoc); uerr != nil { @@ -192,7 +183,7 @@ func resolveSigningKey(ctx context.Context, mgr azurex.SetupManager, opts setupO switch { case errors.As(err, ¬Found): log.Info("creating signing key", "key_name", opts.keyName) - key, err = mgr.CreateRSAKey(ctx, rsaKeyBits) + key, err = mgr.CreateSigningKey(ctx) if err != nil { return azurex.Key{}, keyVaultError(err, opts.keyVault) } @@ -210,7 +201,7 @@ func resolveSigningKey(ctx context.Context, mgr azurex.SetupManager, opts setupO return azurex.Key{}, failuref("setup aborted") } log.Info("rotating signing key", "key_name", opts.keyName) - key, err = mgr.CreateRSAKey(ctx, rsaKeyBits) + key, err = mgr.CreateSigningKey(ctx) if err != nil { return azurex.Key{}, keyVaultError(err, opts.keyVault) } diff --git a/internal/cli/setup_test.go b/internal/cli/setup_test.go index ef8a1cd..1edc958 100644 --- a/internal/cli/setup_test.go +++ b/internal/cli/setup_test.go @@ -61,7 +61,7 @@ func (m *fakeManager) UploadBlob(_ context.Context, path, _, _ string, data []by return nil } -func (m *fakeManager) CreateRSAKey(context.Context, int) (azurex.Key, error) { +func (m *fakeManager) CreateSigningKey(context.Context) (azurex.Key, error) { if m.createErr != nil { return azurex.Key{}, m.createErr } @@ -111,7 +111,7 @@ func TestRunSetup_FreshCreatesKeyUploadsAndWritesConfig(t *testing.T) { t.Fatalf("unexpected error: %v", err) } if mgr.createCalls != 1 { - t.Errorf("expected one CreateRSAKey call, got %d", mgr.createCalls) + t.Errorf("expected one CreateSigningKey call, got %d", mgr.createCalls) } if _, ok := mgr.uploads[config.DefaultDiscoveryPath]; !ok { t.Error("discovery document was not uploaded") diff --git a/internal/jwk/jwk.go b/internal/jwk/jwk.go index 9d15a71..5d5e956 100644 --- a/internal/jwk/jwk.go +++ b/internal/jwk/jwk.go @@ -8,6 +8,12 @@ import ( "math/big" ) +const ( + KeyTypeRSA = "RSA" + UseSig = "sig" + AlgRS256 = "RS256" +) + // b64 is base64url without padding, as JWK members and thumbprints require // (RFC 7515 / RFC 7638). var b64 = base64.RawURLEncoding @@ -34,9 +40,9 @@ type Set struct { // public exponent (e) bytes, deriving kid as the RFC 7638 thumbprint. func FromRSA(n, e []byte) JWK { return JWK{ - Kty: "RSA", - Use: "sig", - Alg: "RS256", + Kty: KeyTypeRSA, + Use: UseSig, + Alg: AlgRS256, Kid: Thumbprint(n, e), N: b64.EncodeToString(n), E: b64.EncodeToString(e), @@ -48,23 +54,52 @@ func NewSet(k JWK) Set { return Set{Keys: []JWK{k}} } +// SingleRS256 returns the only JWK in a v1 jotsmith JWKS, enforcing the +// snap-cutover invariant: exactly one RSA signing key advertising RS256, with a +// kid derived from the key's RFC 7638 thumbprint. +func (s Set) SingleRS256() (JWK, error) { + switch len(s.Keys) { + case 0: + return JWK{}, fmt.Errorf("contains no keys") + case 1: + // continue below + default: + return JWK{}, fmt.Errorf("contains %d keys; jotsmith v1 requires exactly one", len(s.Keys)) + } + + k := s.Keys[0] + if k.Kty != KeyTypeRSA { + return JWK{}, fmt.Errorf("key kty is %q, want %q", k.Kty, KeyTypeRSA) + } + if k.Use != UseSig { + return JWK{}, fmt.Errorf("key use is %q, want %q", k.Use, UseSig) + } + if k.Alg != AlgRS256 { + return JWK{}, fmt.Errorf("key alg is %q, want %q", k.Alg, AlgRS256) + } + if k.Kid == "" { + return JWK{}, fmt.Errorf("key kid is empty") + } + nBytes, eBytes, err := k.publicMaterial() + if err != nil { + return JWK{}, err + } + if _, err := k.RSAPublicKey(); err != nil { + return JWK{}, err + } + if want := Thumbprint(nBytes, eBytes); k.Kid != want { + return JWK{}, fmt.Errorf("key kid %q does not match RFC 7638 thumbprint %q", k.Kid, want) + } + return k, nil +} + // RSAPublicKey reconstructs the RSA public key from the JWK's base64url n and e // members, so a verifier can check signatures. It rejects non-RSA keys and an // exponent that does not fit Go's int-sized rsa.PublicKey.E. func (k JWK) RSAPublicKey() (*rsa.PublicKey, error) { - if k.Kty != "RSA" { - return nil, fmt.Errorf("unsupported key type %q (only RSA)", k.Kty) - } - nBytes, err := b64.DecodeString(k.N) + nBytes, eBytes, err := k.publicMaterial() if err != nil { - return nil, fmt.Errorf("decoding modulus (n): %w", err) - } - eBytes, err := b64.DecodeString(k.E) - if err != nil { - return nil, fmt.Errorf("decoding exponent (e): %w", err) - } - if len(nBytes) == 0 || len(eBytes) == 0 { - return nil, fmt.Errorf("modulus or exponent is empty") + return nil, err } e := new(big.Int).SetBytes(eBytes) if !e.IsInt64() || e.Int64() > int64(^uint32(0)) { @@ -76,6 +111,24 @@ func (k JWK) RSAPublicKey() (*rsa.PublicKey, error) { }, nil } +func (k JWK) publicMaterial() ([]byte, []byte, error) { + if k.Kty != KeyTypeRSA { + return nil, nil, fmt.Errorf("unsupported key type %q (only RSA)", k.Kty) + } + nBytes, err := b64.DecodeString(k.N) + if err != nil { + return nil, nil, fmt.Errorf("decoding modulus (n): %w", err) + } + eBytes, err := b64.DecodeString(k.E) + if err != nil { + return nil, nil, fmt.Errorf("decoding exponent (e): %w", err) + } + if len(nBytes) == 0 || len(eBytes) == 0 { + return nil, nil, fmt.Errorf("modulus or exponent is empty") + } + return nBytes, eBytes, nil +} + // Thumbprint computes the RFC 7638 §3 thumbprint of an RSA public key: SHA-256 // over the canonical JSON containing only the required members (e, kty, n) in // lexicographic order with no whitespace, base64url-encoded without padding. diff --git a/internal/jwk/jwk_test.go b/internal/jwk/jwk_test.go index cf98a85..ca31e5d 100644 --- a/internal/jwk/jwk_test.go +++ b/internal/jwk/jwk_test.go @@ -3,6 +3,7 @@ package jwk import ( "encoding/base64" "encoding/json" + "strings" "testing" ) @@ -65,3 +66,26 @@ func TestSet_MarshalsKeysArray(t *testing.T) { t.Fatalf("expected keys array of length 1, got %d", len(probe.Keys)) } } + +func TestSet_SingleRS256(t *testing.T) { + k := FromRSA([]byte{1, 2, 3}, []byte{1, 0, 1}) + + got, err := NewSet(k).SingleRS256() + if err != nil { + t.Fatalf("expected valid single-key JWKS, got: %v", err) + } + if got.Kid != k.Kid { + t.Errorf("returned kid = %q, want %q", got.Kid, k.Kid) + } + + multi := Set{Keys: []JWK{k, k}} + if _, err := multi.SingleRS256(); err == nil || !strings.Contains(err.Error(), "exactly one") { + t.Fatalf("expected multi-key JWKS to fail exact single-key invariant, got %v", err) + } + + badKid := k + badKid.Kid = "wrong" + if _, err := NewSet(badKid).SingleRS256(); err == nil || !strings.Contains(err.Error(), "thumbprint") { + t.Fatalf("expected non-thumbprint kid to fail, got %v", err) + } +} diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 6f3bc5c..a95329c 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -22,6 +22,7 @@ type Discovery struct { // id_token_signing_alg_values_supported is an array of one ("RS256") per // ADR-0003. func Render(issuer, jwksPath string) Discovery { + issuer = CanonicalIssuer(issuer) return Discovery{ Issuer: issuer, JWKSURI: JoinURL(issuer, jwksPath), @@ -33,8 +34,14 @@ func Render(issuer, jwksPath string) Discovery { } } +// CanonicalIssuer strips trailing slashes so the discovery issuer exactly +// matches the iss claim value consumers validate. +func CanonicalIssuer(issuer string) string { + return strings.TrimRight(issuer, "/") +} + // JoinURL joins a base URL and a path with exactly one slash between them, // regardless of whether base has a trailing slash or path a leading one. func JoinURL(base, path string) string { - return strings.TrimRight(base, "/") + "/" + strings.TrimLeft(path, "/") + return CanonicalIssuer(base) + "/" + strings.TrimLeft(path, "/") } diff --git a/internal/oidc/oidc_test.go b/internal/oidc/oidc_test.go index 0839c83..1f4c8a6 100644 --- a/internal/oidc/oidc_test.go +++ b/internal/oidc/oidc_test.go @@ -20,7 +20,7 @@ func TestJoinURL(t *testing.T) { } func TestRender(t *testing.T) { - d := Render("https://jotsmithmax.z13.web.core.windows.net", ".well-known/jwks.json") + d := Render("https://jotsmithmax.z13.web.core.windows.net/", ".well-known/jwks.json") if d.Issuer != "https://jotsmithmax.z13.web.core.windows.net" { t.Errorf("issuer = %q", d.Issuer)