diff --git a/.mise-tasks/backfill/workos.sh b/.mise-tasks/backfill/workos.sh new file mode 100755 index 0000000000..e1fe512de5 --- /dev/null +++ b/.mise-tasks/backfill/workos.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +#MISE description="Run the local WorkOS backfill script" +#MISE dir="{{ config_root }}" + +set -euo pipefail + +exec go run ./server/cmd/workos-backfill "$@" diff --git a/docs/runbooks/workos-backfill.md b/docs/runbooks/workos-backfill.md new file mode 100644 index 0000000000..328243edc0 --- /dev/null +++ b/docs/runbooks/workos-backfill.md @@ -0,0 +1,292 @@ +# WorkOS Backfill + +Use this runbook to run the local WorkOS backfill script against local, dev, or +prod databases. The script syncs WorkOS snapshot state into Gram for: + +- global roles +- organization metadata +- organization roles +- users +- organization memberships +- organization role assignments + +The entrypoint is `server/cmd/workos-backfill`. + +## Prerequisites + +Set a WorkOS API key before running the script: + +```sh +export WORKOS_API_KEY='sk_test_...' +``` + +For prod, use the live WorkOS key and add `--environment=prod`. For local and +dev, the real WorkOS endpoint requires a test key that starts with `sk_test_`. + +For local database access, set `GRAM_DATABASE_URL` or pass `--database-url`. +For dev and prod Cloud SQL access, run the script with `--cloudsql-proxy`; +that mode ignores `GRAM_DATABASE_URL` and builds a local proxy URL. + +If you need to point at a non-standard WorkOS API endpoint, set +`WORKOS_API_URL` or pass `--workos-endpoint`. + +When `--cloudsql-proxy` is used, the script follows the `gram-infra` access +pattern: it starts the proxy, creates your IAM database user if needed, grants +read-only access for preflight/validate or write access for write phases, then +connects as your active `gcloud` account. + +## Commands + +Run preflight first. This is read-only and is the default phase: + +```sh +mise backfill:workos +``` + +Equivalent explicit command: + +```sh +mise backfill:workos --phase=preflight --environment=dev --cloudsql-proxy +``` + +Limit the scope while testing: + +```sh +mise backfill:workos --phase=preflight --environment=dev --cloudsql-proxy --limit=5 +``` + +Process organizations in deterministic batches. Organizations are sorted by +WorkOS organization ID, then the offset and size are applied: + +```sh +mise backfill:workos --phase=preflight --environment=dev --cloudsql-proxy --page-size=25 --page-offset=0 +mise backfill:workos --phase=preflight --environment=dev --cloudsql-proxy --page-size=25 --page-offset=25 +``` + +Run a specific WorkOS organization: + +```sh +mise backfill:workos --phase=preflight --environment=dev --cloudsql-proxy --workos-org-id=org_... +``` + +Multiple organizations can be repeated or comma-separated: + +```sh +mise backfill:workos --phase=preflight --environment=dev --cloudsql-proxy --workos-org-id=org_1,org_2 +``` + +Run global roles only: + +```sh +mise backfill:workos --phase=global-roles --environment=dev --cloudsql-proxy --dry-run=false +``` + +Run organizations, users, memberships, and assignments: + +```sh +mise backfill:workos --phase=organizations --environment=dev --cloudsql-proxy --dry-run=false +``` + +Run everything: + +```sh +mise backfill:workos --phase=all --environment=dev --cloudsql-proxy --dry-run=false +``` + +Validate after writes: + +```sh +mise backfill:workos --phase=validate --environment=dev --cloudsql-proxy +``` + +For prod, the command requires explicit prod confirmation: + +```sh +mise backfill:workos --phase=all --environment=prod --cloudsql-proxy --dry-run=false --confirm-prod=prod +``` + +Non-prod writes prompt for `backfill` unless `--auto-approve` is passed. Prod +writes never skip the prod confirmation. + +## Useful Flags + +- `--phase`: `preflight`, `global-roles`, `organizations`, `validate`, or `all`. +- `--environment`: `local`, `dev`, or `prod`. +- `--cloudsql-proxy`: start a local Cloud SQL proxy for dev/prod DB access. +- `--cloudsql-port`: local proxy port; defaults to a free port. +- `--cloudsql-db-name`: database name; defaults to `gram`. +- `--dry-run`: defaults to `true`; set `--dry-run=false` to write. +- `--workos-org-id`: process selected WorkOS organizations only. +- `--limit`: cap the number of WorkOS organizations inspected. +- `--page-size`: process at most this many organizations after the offset. +- `--page-offset`: skip this many organizations after deterministic sorting. +- `--statement-timeout`: Postgres `statement_timeout`; defaults to `30m`. +- `--breakpoint-before-write`: pause after preflight and before DB writes. +- `--pause-after-each`: pause after each organization backfill. +- `--auto-approve`: skip the non-prod `backfill` prompt. +- `--confirm-prod=prod`: required for non-interactive prod access. + +## Interpreting Output + +The script prints a preflight plan before it writes. + +```text +Global role preflight: + workos_global_roles: 3 + role_rows: affected=0 create=0 update=0 delete=0 noop=3 stale_skip=0 +``` + +`affected` means rows that would mutate the database. It is +`create + update + delete`. + +```text +Organization preflight: + workos_orgs: 119 + expected_organization_roles: 39 + expected_users: 101 + expected_memberships: 101 + skipped_unlinked_without_external_id: 4 + organization_rows: affected=115 create=115 update=0 delete=0 noop=0 stale_skip=4 + role_rows: affected=37 create=37 update=0 delete=0 noop=0 stale_skip=0 + user_rows: affected=87 create=87 update=0 delete=0 noop=0 stale_skip=10 + membership_rows: affected=87 create=87 update=0 delete=0 noop=0 stale_skip=10 + assignment_rows: affected=87 create=87 update=0 delete=0 noop=0 stale_skip=10 +``` + +Row states: + +- `create`: the local row does not exist and will be inserted. +- `update`: the local row exists and at least one synced field will change. +- `delete`: the local row is absent from the WorkOS snapshot and will be soft-deleted. +- `noop`: the local row already matches the WorkOS snapshot. +- `stale_skip`: the local row has newer synced WorkOS state, or the script cannot safely resolve the local row. + +The sample section shows representative organizations and the dominant row +state for each entity type: + +```text +sample: + org_... -> gram_org_id org=create:1 roles=noop:0 users=create:2 memberships=create:2 assignments=create:2 name="Example" +``` + +When updates or deletes are planned, the script prints a capped +`planned_change_details` section with field-level changes: + +```text +planned_change_details: showing=1 total=1 + update user user_123 + email: "old@example.com" -> "new@example.com" +``` + +After writes, the completion report includes both organization-level progress +and row outcomes for the successfully written and validated organizations: + +```text +Organization backfill complete. + scanned: 119 + written: 115 + validated: 115 + skipped: 4 + skipped_noop: 0 + failed: 0 + validation_failures: 0 + organization_rows: affected=115 create=115 update=0 delete=0 noop=0 stale_skip=0 + role_rows: affected=37 create=37 update=0 delete=0 noop=0 stale_skip=0 + user_rows: affected=87 create=87 update=0 delete=0 noop=0 stale_skip=0 + membership_rows: affected=87 create=87 update=0 delete=0 noop=0 stale_skip=0 + assignment_rows: affected=87 create=87 update=0 delete=0 noop=0 stale_skip=0 +``` + +`written` counts organizations whose write transaction completed and whose +validation passed. It is not a row count. + +## How It Works + +Preflight loads the WorkOS snapshot first. For each selected organization, it +fetches the WorkOS organization, roles, users, and memberships. It then compares +that snapshot to the current database and classifies expected row changes. + +Writes are split by phase: + +- `global-roles` syncs global WorkOS roles. +- `organizations` syncs organization metadata, organization roles, users, + memberships, and organization role assignments. +- `all` runs both phases. + +Each organization write runs inside a database transaction. The write path: + +1. Resolves the local organization by `workos_id`, or by WorkOS `external_id` + when no local `workos_id` row exists. +2. Skips unlinked WorkOS organizations with no `external_id`. +3. Upserts organization metadata and preserves an existing slug. New slugs use + Gram's normal unique organization slug generation. +4. Upserts organization roles and soft-deletes roles missing from the WorkOS + snapshot. +5. Resolves each WorkOS user by existing local `workos_id`, then by WorkOS + `external_id`. +6. Skips users that cannot be resolved to a local Gram user ID. +7. Upserts users, memberships, and role assignments for resolved users. +8. Commits the transaction and validates the expected rows. + +When `--cloudsql-proxy` is set, the script derives the Cloud SQL instance from +`--environment`, starts `cloud-sql-proxy` with `--auto-iam-authn`, picks a free +local port unless `--cloudsql-port` is provided, reads the `*_gram_db_password` +secret to prepare IAM user access, and connects to `127.0.0.1` as the active +`gcloud` account. + +The script sets `lock_timeout=5s` and `statement_timeout=5min` for DB sessions. +Preflight, validate, and dry-run sessions are read-only. + +## Debugging + +Use `--breakpoint-before-write` to pause after preflight and before writes: + +```sh +mise backfill:workos --phase=organizations --environment=dev --cloudsql-proxy --dry-run=false --breakpoint-before-write +``` + +Use `--pause-after-each` when stepping through a small batch: + +```sh +mise backfill:workos --phase=organizations --environment=dev --cloudsql-proxy --dry-run=false --limit=1 --pause-after-each +``` + +For VSCode, launch `server/cmd/workos-backfill` as a Go program and pass the +same args. Useful breakpoints: + +- `runOrganizationBackfill` in `server/cmd/workos-backfill/main.go`, on the + `backfill.Do(...)` call, before an organization transaction starts. +- `BackfillWorkOSOrganization.Do` in `server/cmd/workos-backfill/backfill.go`, + where the WorkOS snapshot is fetched and the transaction begins. +- `backfillWorkOSUser` in `server/cmd/workos-backfill/backfill_user.go`, when + debugging skipped users or user ID resolution. + +When stopped before an organization write, inspect: + +- `org.workosOrgID` +- `org.gramOrgID` +- `org.orgChanges` +- `org.roleChanges` +- `org.userChanges` +- `org.membershipChanges` +- `org.assignmentChanges` +- `org.changeDetails` + +## Safety Checklist + +1. Run `preflight` first. +2. Use `--workos-org-id` or `--limit` while debugging. +3. Confirm `affected`, `create`, `update`, `delete`, `noop`, and `stale_skip` + counts match expectations. +4. Review `planned_change_details` for updates and deletes. +5. Run writes in phases when possible: `global-roles`, then `organizations`. +6. Run `validate` after writes. +7. For prod, capture the preflight output before writing. + +## Troubleshooting + +If the command fails with `SQLSTATE 42501` or `permission denied for table ...`, +the Cloud SQL proxy connected successfully, but the automatic grant step did not +apply the expected table privileges. Check that your `gcloud` account can manage +Cloud SQL users and read the `*_gram_db_password` secret for the target +environment. diff --git a/go.mod b/go.mod index dd8d72ec66..5f4c5ae6c6 100644 --- a/go.mod +++ b/go.mod @@ -85,6 +85,7 @@ require ( golang.org/x/net v0.53.0 golang.org/x/sync v0.20.0 golang.org/x/sys v0.43.0 + golang.org/x/term v0.42.0 golang.org/x/tools v0.44.0 google.golang.org/api v0.274.0 google.golang.org/grpc v1.80.0 @@ -328,7 +329,6 @@ require ( golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect golang.org/x/mod v0.35.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect - golang.org/x/term v0.42.0 // indirect golang.org/x/text v0.36.0 // indirect golang.org/x/time v0.15.0 // indirect google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 // indirect diff --git a/server/internal/background/activities/backfill_workos_organization.go b/server/cmd/workos-backfill/backfill.go similarity index 56% rename from server/internal/background/activities/backfill_workos_organization.go rename to server/cmd/workos-backfill/backfill.go index efa820bdc4..00afc6b0b3 100644 --- a/server/internal/background/activities/backfill_workos_organization.go +++ b/server/cmd/workos-backfill/backfill.go @@ -1,4 +1,4 @@ -package activities +package main import ( "context" @@ -10,27 +10,24 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" - "github.com/workos/workos-go/v6/pkg/events" accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/attr" + "github.com/speakeasy-api/gram/server/internal/auth/orgslug" "github.com/speakeasy-api/gram/server/internal/conv" "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/oops" orgrepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" - usersrepo "github.com/speakeasy-api/gram/server/internal/users/repo" ) -type WorkOSClient interface { +type Client interface { GetOrganization(ctx context.Context, orgID string) (*workos.Organization, error) - ListOrganizations(ctx context.Context) ([]workos.Organization, error) ListRoles(ctx context.Context, orgID string) ([]workos.Role, error) + ListOrgUsers(ctx context.Context, orgID string) (map[string]workos.User, error) ListOrgMemberships(ctx context.Context, orgID string) ([]workos.Member, error) ListGlobalRoles(ctx context.Context) ([]workos.Role, error) - ListEvents(ctx context.Context, opts events.ListEventsOpts) (events.ListEventsResponse, error) - UpdateUserExternalID(ctx context.Context, workosUserID, externalID string) error } type BackfillWorkOSOrganizationParams struct { @@ -40,7 +37,7 @@ type BackfillWorkOSOrganizationParams struct { type BackfillWorkOSOrganization struct { logger *slog.Logger db *pgxpool.Pool - workos WorkOSClient + workos Client } type backfillWorkOSMember struct { @@ -48,7 +45,7 @@ type backfillWorkOSMember struct { updatedAt time.Time } -func NewBackfillWorkOSOrganization(logger *slog.Logger, db *pgxpool.Pool, workosClient WorkOSClient) *BackfillWorkOSOrganization { +func NewBackfillWorkOSOrganization(logger *slog.Logger, db *pgxpool.Pool, workosClient Client) *BackfillWorkOSOrganization { return &BackfillWorkOSOrganization{ logger: logger.With(attr.SlogComponent("backfill_workos_organization")), db: db, @@ -73,6 +70,11 @@ func (b *BackfillWorkOSOrganization) Do(ctx context.Context, params BackfillWork return oops.E(oops.CodeUnexpected, err, "list WorkOS organization roles").Log(ctx, logger) } + users, err := b.workos.ListOrgUsers(ctx, params.WorkOSOrganizationID) + if err != nil { + return oops.E(oops.CodeUnexpected, err, "list WorkOS organization users").Log(ctx, logger) + } + members, err := b.workos.ListOrgMemberships(ctx, params.WorkOSOrganizationID) if err != nil { return oops.E(oops.CodeUnexpected, err, "list WorkOS organization memberships").Log(ctx, logger) @@ -104,7 +106,15 @@ func (b *BackfillWorkOSOrganization) Do(ctx context.Context, params BackfillWork logger.DebugContext(ctx, "skipping WorkOS organization backfill for unlinked organization with no external_id") return nil } - org.ID = workosOrg.ExternalID + org, err = orgQueries.GetOrganizationMetadata(ctx, workosOrg.ExternalID) + switch { + case errors.Is(err, pgx.ErrNoRows): + org.ID = workosOrg.ExternalID + case err != nil: + return fmt.Errorf("get organization by external id %q: %w", workosOrg.ExternalID, err) + case org.WorkosID.Valid && org.WorkosID.String != params.WorkOSOrganizationID: + return fmt.Errorf("workos organization %q resolved to gram organization %q with different workos_id %q", params.WorkOSOrganizationID, org.ID, org.WorkosID.String) + } case err != nil: return fmt.Errorf("get organization by workos id %q: %w", params.WorkOSOrganizationID, err) } @@ -117,7 +127,19 @@ func (b *BackfillWorkOSOrganization) Do(ctx context.Context, params BackfillWork return err } for _, member := range parsedMembers { - if err := backfillOrganizationMember(ctx, tx, org.ID, member); err != nil { + user, ok := users[member.member.UserID] + if !ok { + return fmt.Errorf("missing WorkOS user %q for membership %q", member.member.UserID, member.member.ID) + } + gramUserID, userResolved, err := backfillWorkOSUser(ctx, logger, tx, user) + if err != nil { + return fmt.Errorf("backfill WorkOS user %q: %w", user.ID, err) + } + if !userResolved { + continue + } + + if err := backfillOrganizationMember(ctx, tx, org.ID, member, gramUserID); err != nil { return err } } @@ -138,14 +160,23 @@ func backfillOrganizationMetadata(ctx context.Context, repo *orgrepo.Queries, or if org.WorkosUpdatedAt.Valid { rowUpdatedAt = &org.WorkosUpdatedAt.Time } - if !ShouldProcessEvent(lastEventID, rowUpdatedAt, "", updatedAt) { + if !shouldProcessEvent(lastEventID, rowUpdatedAt, "", updatedAt) { return org, nil } + slug := org.Slug + if slug == "" { + var err error + slug, err = uniqueOrganizationSlug(ctx, repo, workosOrg.Name, workosOrg.ID) + if err != nil { + return orgrepo.OrganizationMetadatum{}, err + } + } + updatedOrg, err := repo.UpsertOrganizationMetadataFromWorkOS(ctx, orgrepo.UpsertOrganizationMetadataFromWorkOSParams{ ID: org.ID, Name: workosOrg.Name, - Slug: conv.ToSlug(workosOrg.Name), + Slug: slug, WorkosID: conv.ToPGText(workosOrg.ID), WorkosUpdatedAt: conv.ToPGTimestamptz(updatedAt), WorkosLastEventID: conv.ToPGText(""), @@ -157,6 +188,18 @@ func backfillOrganizationMetadata(ctx context.Context, repo *orgrepo.Queries, or return updatedOrg, nil } +func uniqueOrganizationSlug(ctx context.Context, repo orgslug.Lookup, name, fallback string) (string, error) { + base := orgslug.Slugify(name) + if base == "" { + base = fallback + } + slug, err := orgslug.FindUnique(ctx, repo, base) + if err != nil { + return "", fmt.Errorf("find unique organization slug: %w", err) + } + return slug, nil +} + func backfillOrganizationRoles(ctx context.Context, logger *slog.Logger, dbtx pgx.Tx, organizationID string, roles []workos.Role) error { repo := accessrepo.New(dbtx) snapshotSlugs := make(map[string]time.Time) @@ -191,7 +234,7 @@ func backfillOrganizationRoles(ctx context.Context, logger *slog.Logger, dbtx pg if existing.WorkosUpdatedAt.Valid { rowUpdatedAt = &existing.WorkosUpdatedAt.Time } - if !ShouldProcessEvent(lastEventID, rowUpdatedAt, "", updatedAt) { + if !shouldProcessEvent(lastEventID, rowUpdatedAt, "", updatedAt) { continue } @@ -226,7 +269,7 @@ func backfillOrganizationRoles(ctx context.Context, logger *slog.Logger, dbtx pg rowUpdatedAt = &localRole.WorkosUpdatedAt.Time } deletedAt := time.Now().UTC() - if !ShouldProcessEvent(lastEventID, rowUpdatedAt, "", deletedAt) { + if !shouldProcessEvent(lastEventID, rowUpdatedAt, "", deletedAt) { continue } @@ -250,42 +293,62 @@ func backfillOrganizationRoles(ctx context.Context, logger *slog.Logger, dbtx pg return nil } -func backfillOrganizationMember(ctx context.Context, dbtx pgx.Tx, organizationID string, parsed backfillWorkOSMember) error { +func backfillOrganizationMember(ctx context.Context, dbtx pgx.Tx, organizationID string, parsed backfillWorkOSMember, gramUserID string) error { member := parsed.member orgQueries := orgrepo.New(dbtx) - gramUserID, err := usersrepo.New(dbtx).GetUserIDByWorkosID(ctx, conv.ToPGText(member.UserID)) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - return fmt.Errorf("get user by workos id %q: %w", member.UserID, err) + if err := repairMissingWorkOSMembershipFields(ctx, dbtx, organizationID, gramUserID, member, parsed.updatedAt); err != nil { + return err } - cursor, err := latestMembershipCursor(ctx, orgQueries, organizationID, gramUserID, member.UserID) + relationshipCursor, err := latestRelationshipCursor(ctx, orgQueries, organizationID, gramUserID) if err != nil { return err } - if !ShouldProcessEvent(cursor.lastEventID, cursor.updatedAt, "", parsed.updatedAt) { - return nil - } - - if err := orgQueries.UpsertWorkOSMembership(ctx, orgrepo.UpsertWorkOSMembershipParams{ - OrganizationID: organizationID, - UserID: conv.ToPGTextEmpty(gramUserID), - WorkosUserID: conv.ToPGText(member.UserID), - WorkosMembershipID: conv.ToPGText(member.ID), - WorkosUpdatedAt: conv.ToPGTimestamptz(parsed.updatedAt), - WorkosLastEventID: conv.ToPGText(""), - }); err != nil { - return fmt.Errorf("upsert organization membership %q: %w", member.ID, err) + if shouldProcessEvent(relationshipCursor.lastEventID, relationshipCursor.updatedAt, "", parsed.updatedAt) { + updatedExisting, err := updateExistingWorkOSMembership(ctx, dbtx, organizationID, gramUserID, member, parsed.updatedAt) + if err != nil { + return err + } + if !updatedExisting { + if err := mergeConflictingWorkOSMembership(ctx, dbtx, organizationID, gramUserID, member.ID); err != nil { + return err + } + if err := orgQueries.UpsertWorkOSMembership(ctx, orgrepo.UpsertWorkOSMembershipParams{ + OrganizationID: organizationID, + UserID: conv.ToPGText(gramUserID), + WorkosUserID: conv.ToPGText(member.UserID), + WorkosMembershipID: conv.ToPGText(member.ID), + WorkosUpdatedAt: conv.ToPGTimestamptz(parsed.updatedAt), + WorkosLastEventID: conv.ToPGText(""), + }); err != nil { + return fmt.Errorf("upsert organization membership %q: %w", member.ID, err) + } + } } roleSlugs := []string{} if member.RoleSlug != "" { + roleExists, err := activeAssignmentRoleExists(ctx, dbtx, organizationID, member.RoleSlug) + if err != nil { + return err + } + if !roleExists { + return nil + } roleSlugs = []string{member.RoleSlug} } + assignmentCursor, err := latestAssignmentCursor(ctx, orgQueries, organizationID, member.UserID) + if err != nil { + return err + } + if !shouldProcessEvent(assignmentCursor.lastEventID, assignmentCursor.updatedAt, "", parsed.updatedAt) { + return nil + } if err := orgQueries.SyncUserOrganizationRoleAssignments(ctx, orgrepo.SyncUserOrganizationRoleAssignmentsParams{ OrganizationID: organizationID, WorkosUserID: member.UserID, - UserID: conv.ToPGTextEmpty(gramUserID), + UserID: conv.ToPGText(gramUserID), WorkosMembershipID: conv.ToPGText(member.ID), WorkosUpdatedAt: conv.ToPGTimestamptz(parsed.updatedAt), WorkosLastEventID: conv.ToPGText(""), @@ -297,19 +360,133 @@ func backfillOrganizationMember(ctx context.Context, dbtx pgx.Tx, organizationID return nil } +func repairMissingWorkOSMembershipFields(ctx context.Context, dbtx pgx.Tx, organizationID, gramUserID string, member workos.Member, updatedAt time.Time) error { + if gramUserID == "" { + return nil + } + + _, err := dbtx.Exec(ctx, ` +UPDATE organization_user_relationships +SET user_id = COALESCE(user_id, $2), + workos_user_id = COALESCE(workos_user_id, $3), + workos_updated_at = COALESCE(workos_updated_at, $5), + updated_at = CASE + WHEN user_id IS NULL OR workos_user_id IS NULL OR workos_updated_at IS NULL THEN clock_timestamp() + ELSE updated_at + END +WHERE organization_id = $1 + AND workos_membership_id = $4 + AND deleted IS FALSE + AND (user_id IS NULL OR workos_user_id IS NULL OR workos_updated_at IS NULL) + AND NOT EXISTS ( + SELECT 1 + FROM organization_user_relationships target + WHERE target.organization_id = $1 + AND target.user_id = $2 + AND target.id <> organization_user_relationships.id + )`, + organizationID, + conv.ToPGText(gramUserID), + conv.ToPGText(member.UserID), + conv.ToPGText(member.ID), + conv.ToPGTimestamptz(updatedAt), + ) + if err != nil { + return fmt.Errorf("repair missing WorkOS membership fields %q: %w", member.ID, err) + } + + _, err = dbtx.Exec(ctx, ` +UPDATE organization_user_relationships +SET workos_membership_id = $4, + workos_user_id = COALESCE(workos_user_id, $3), + workos_updated_at = COALESCE(workos_updated_at, $5), + updated_at = clock_timestamp() +WHERE organization_id = $1 + AND user_id = $2 + AND deleted IS FALSE + AND workos_membership_id IS NULL + AND NOT EXISTS ( + SELECT 1 + FROM organization_user_relationships owner + WHERE owner.workos_membership_id = $4 + AND owner.deleted IS FALSE + )`, + organizationID, + conv.ToPGText(gramUserID), + conv.ToPGText(member.UserID), + conv.ToPGText(member.ID), + conv.ToPGTimestamptz(updatedAt), + ) + if err != nil { + return fmt.Errorf("repair missing WorkOS membership id %q: %w", member.ID, err) + } + return nil +} + +func updateExistingWorkOSMembership(ctx context.Context, dbtx pgx.Tx, organizationID string, gramUserID string, member workos.Member, updatedAt time.Time) (bool, error) { + tag, err := dbtx.Exec(ctx, ` +UPDATE organization_user_relationships +SET user_id = COALESCE($2, user_id), + workos_user_id = $3, + workos_updated_at = $5, + workos_last_event_id = $6, + deleted_at = NULL, + updated_at = clock_timestamp() +WHERE organization_id = $1 + AND workos_membership_id = $4 + AND deleted IS FALSE + AND NOT EXISTS ( + SELECT 1 + FROM organization_user_relationships target + WHERE target.organization_id = $1 + AND target.user_id = $2 + AND target.id <> organization_user_relationships.id + )`, + organizationID, + conv.ToPGText(gramUserID), + conv.ToPGText(member.UserID), + conv.ToPGText(member.ID), + conv.ToPGTimestamptz(updatedAt), + conv.ToPGText(""), + ) + if err != nil { + return false, fmt.Errorf("update existing WorkOS membership %q: %w", member.ID, err) + } + return tag.RowsAffected() > 0, nil +} + +func mergeConflictingWorkOSMembership(ctx context.Context, dbtx pgx.Tx, organizationID, gramUserID, workosMembershipID string) error { + if gramUserID == "" { + return nil + } + _, err := dbtx.Exec(ctx, ` +UPDATE organization_user_relationships +SET deleted_at = COALESCE(deleted_at, clock_timestamp()), + updated_at = clock_timestamp() +WHERE workos_membership_id = $3 + AND deleted IS FALSE + AND EXISTS ( + SELECT 1 + FROM organization_user_relationships target + WHERE target.organization_id = $1 + AND target.user_id = $2 + ) + AND NOT ( + organization_id = $1 + AND user_id IS NOT DISTINCT FROM $2 + )`, organizationID, gramUserID, workosMembershipID) + if err != nil { + return fmt.Errorf("merge conflicting WorkOS membership %q: %w", workosMembershipID, err) + } + return nil +} + type membershipCursor struct { lastEventID *string updatedAt *time.Time } -// latestMembershipCursor returns the newest local WorkOS state for a membership -// before applying a snapshot. Membership backfill writes two local shapes: -// organization_user_relationships when the WorkOS user is linked to a Gram -// user, and organization_role_assignments even when the user is still unknown -// locally. Both can be updated by event processing, so the snapshot must compare -// against the freshest cursor/timestamp from both tables before it overwrites -// either table. -func latestMembershipCursor(ctx context.Context, repo *orgrepo.Queries, organizationID, gramUserID, workosUserID string) (membershipCursor, error) { +func latestAssignmentCursor(ctx context.Context, repo *orgrepo.Queries, organizationID, workosUserID string) (membershipCursor, error) { var cursor membershipCursor assignments, err := repo.ListOrganizationRoleAssignmentsByWorkOSUser(ctx, orgrepo.ListOrganizationRoleAssignmentsByWorkOSUserParams{ @@ -320,13 +497,17 @@ func latestMembershipCursor(ctx context.Context, repo *orgrepo.Queries, organiza return membershipCursor{}, fmt.Errorf("list organization role assignments for WorkOS user %q: %w", workosUserID, err) } for _, assignment := range assignments { + if assignment.DeletedAt.Valid { + continue + } moveMembershipCursor(&cursor, assignment.WorkosLastEventID, assignment.WorkosUpdatedAt) } - // No relationship row exists when the WorkOS user is not linked to a Gram user. - if gramUserID == "" { - return cursor, nil - } + return cursor, nil +} + +func latestRelationshipCursor(ctx context.Context, repo *orgrepo.Queries, organizationID, gramUserID string) (membershipCursor, error) { + var cursor membershipCursor existing, err := repo.GetOrganizationRelationshipForUser(ctx, orgrepo.GetOrganizationRelationshipForUserParams{ OrganizationID: organizationID, @@ -352,6 +533,16 @@ func parseWorkOSTime(raw string) (time.Time, error) { return t, nil } +func shouldProcessEvent(rowLastEventID *string, rowWorkOSUpdatedAt *time.Time, eventID string, eventUpdatedAt time.Time) bool { + if rowLastEventID == nil || *rowLastEventID == "" { + if rowWorkOSUpdatedAt == nil { + return true + } + return !eventUpdatedAt.Before(*rowWorkOSUpdatedAt) + } + return eventID > *rowLastEventID +} + // moveMembershipCursor tracks per-field upper bounds rather than a coherent // row state. Backfill only uses the cursor as a conservative skip signal, so any // newer event ID or updated timestamp from either local membership shape should diff --git a/server/cmd/workos-backfill/backfill_test.go b/server/cmd/workos-backfill/backfill_test.go new file mode 100644 index 0000000000..aa8cd6e227 --- /dev/null +++ b/server/cmd/workos-backfill/backfill_test.go @@ -0,0 +1,1121 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" + + accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" + "github.com/speakeasy-api/gram/server/internal/conv" + orgrepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" + "github.com/speakeasy-api/gram/server/internal/testenv" + "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" + usersrepo "github.com/speakeasy-api/gram/server/internal/users/repo" +) + +var infra *testenv.Environment + +func TestMain(m *testing.M) { + res, cleanup, err := testenv.Launch(context.Background(), testenv.LaunchOptions{Postgres: true}) + if err != nil { + log.Fatalf("Failed to launch test infrastructure: %v", err) + } + + infra = res + + code := m.Run() + + if err := cleanup(); err != nil { + log.Fatalf("Failed to cleanup test infrastructure: %v", err) + } + + os.Exit(code) +} + +func newBackfillTestConn(t *testing.T, name string) *pgxpool.Pool { + t.Helper() + + conn, err := infra.CloneTestDatabase(t, name) + require.NoError(t, err) + return conn +} + +func TestApplyOrganizationWindow_PageOffsetAndSize(t *testing.T) { + t.Parallel() + + orgs := []workos.Organization{ + {ID: "org_1", Name: "One", ExternalID: "", CreatedAt: "", UpdatedAt: ""}, + {ID: "org_2", Name: "Two", ExternalID: "", CreatedAt: "", UpdatedAt: ""}, + {ID: "org_3", Name: "Three", ExternalID: "", CreatedAt: "", UpdatedAt: ""}, + {ID: "org_4", Name: "Four", ExternalID: "", CreatedAt: "", UpdatedAt: ""}, + } + + window := applyOrganizationWindow(orgs, options{ + phase: phasePreflight, + environment: envLocal, + databaseURL: "", + cloudSQLProxy: false, + cloudSQLPort: 0, + cloudSQLDBName: "gram", + workosAPIKey: "", + workosEndpoint: "", + workosOrgIDs: nil, + limit: 0, + pageSize: 2, + pageOffset: 1, + statementTimeout: defaultStatementTimeout, + dryRun: true, + autoApprove: false, + pauseAfterEach: false, + confirmProd: "", + breakpointBefore: false, + }) + + require.Equal(t, []workos.Organization{ + {ID: "org_2", Name: "Two", ExternalID: "", CreatedAt: "", UpdatedAt: ""}, + {ID: "org_3", Name: "Three", ExternalID: "", CreatedAt: "", UpdatedAt: ""}, + }, window) +} + +func TestRunOrganizationBackfill_SkipsNoopOrganization(t *testing.T) { + t.Parallel() + + rep := runOrganizationBackfill( + context.Background(), + testenv.NewLogger(t), + nil, + nil, + options{ + phase: phaseOrganizations, + environment: envLocal, + databaseURL: "", + cloudSQLProxy: false, + cloudSQLPort: 0, + cloudSQLDBName: "gram", + workosAPIKey: "", + workosEndpoint: "", + workosOrgIDs: nil, + limit: 0, + pageSize: 0, + pageOffset: 0, + statementTimeout: defaultStatementTimeout, + dryRun: false, + autoApprove: false, + pauseAfterEach: false, + confirmProd: "", + breakpointBefore: false, + }, + []orgExpectation{{ + workosOrgID: "org_noop", + gramOrgID: "gram_noop", + name: "Noop", + skipped: false, + roles: nil, + users: nil, + members: nil, + orgChanges: changeCounts{ + Create: 0, + Update: 0, + Noop: 1, + Delete: 0, + StaleSkip: 0, + }, + roleChanges: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + userChanges: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + membershipChanges: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + assignmentChanges: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + changeDetails: nil, + }}, + ) + + require.Equal(t, report{ + scanned: 1, + skipped: 0, + skippedNoop: 1, + written: 0, + validated: 0, + failed: 0, + validationFailures: 0, + organizationRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + roleRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + userRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + membershipRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + assignmentRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + }, rep) +} + +func TestBackfillWorkOSOrganization_CreatesUnlinkedOrganizationWithExternalID(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_create_org_external_id") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_from_workos_external_id" + const workosOrgID = "org_01JBACKFILLCREATE" + + workosClient := newWorkOSSnapshotClient(t, ctx, + workos.Organization{ + ID: workosOrgID, + Name: "Backfill Created Org", + ExternalID: organizationID, + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }, + nil, + nil, + ) + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + + err := activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + org, err := orgrepo.New(conn).GetOrganizationByWorkosID(ctx, conv.ToPGText(workosOrgID)) + require.NoError(t, err) + require.Equal(t, organizationID, org.ID) + require.Equal(t, "Backfill Created Org", org.Name) + require.Equal(t, "backfill-created-org", org.Slug) + require.Equal(t, workosOrgID, org.WorkosID.String) + require.Empty(t, org.WorkosLastEventID.String) +} + +func TestBackfillWorkOSOrganization_CreatesUniqueSlugOnNameCollision(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_create_org_slug_collision") + logger := testenv.NewLogger(t) + + const existingOrganizationID = "gram_org_existing_tester" + const organizationID = "gram_org_new_tester" + const workosOrgID = "org_01JBACKFILLSLUGCOLLISION" + + _, err := orgrepo.New(conn).UpsertOrganizationMetadata(ctx, orgrepo.UpsertOrganizationMetadataParams{ + ID: existingOrganizationID, + Name: "tester", + Slug: "tester", + WorkosID: conv.ToPGText("org_01JEXISTINGTESTER"), + Whitelisted: pgtype.Bool{Bool: false, Valid: false}, + }) + require.NoError(t, err) + + workosClient := newWorkOSSnapshotClient(t, ctx, + workos.Organization{ + ID: workosOrgID, + Name: "tester", + ExternalID: organizationID, + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }, + nil, + nil, + ) + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + + err = activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + org, err := orgrepo.New(conn).GetOrganizationByWorkosID(ctx, conv.ToPGText(workosOrgID)) + require.NoError(t, err) + require.Equal(t, organizationID, org.ID) + require.Equal(t, "tester", org.Name) + require.NotEqual(t, "tester", org.Slug) + require.Contains(t, org.Slug, "tester-") + require.Len(t, org.Slug, len("tester-")+4) +} + +func TestBackfillWorkOSOrganization_ExternalIDChangeDoesNotChangeOrganizationID(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_external_id_immutable") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_original_external_id" + const changedExternalID = "gram_org_changed_external_id" + const workosOrgID = "org_01JBACKFILLIMMUTABLE" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + + workosClient := newWorkOSSnapshotClient(t, ctx, + workos.Organization{ + ID: workosOrgID, + Name: "Backfill Immutable Org", + ExternalID: changedExternalID, + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }, + nil, + nil, + ) + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + + err := activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + org, err := orgrepo.New(conn).GetOrganizationByWorkosID(ctx, conv.ToPGText(workosOrgID)) + require.NoError(t, err) + require.Equal(t, organizationID, org.ID) + require.Equal(t, "Backfill Immutable Org", org.Name) + + _, err = orgrepo.New(conn).GetOrganizationMetadata(ctx, changedExternalID) + require.ErrorIs(t, err, pgx.ErrNoRows) +} + +func TestBackfillWorkOSOrganization_BackfillsUserAndSyncsSingleRoleAssignment(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_user_single_role") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_user" + const workosOrgID = "org_01JBACKFILLUSER" + const workosUserID = "user_01JBACKFILLUSER" + const gramUserID = "gram_user_01JBACKFILLUSER" + const membershipID = "mem_01JBACKFILLUSER" + const roleSlug = "org-support" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + + workosClient := newWorkOSSnapshotClient(t, ctx, + workos.Organization{ + ID: workosOrgID, + Name: "Backfill User", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }, + []workos.Role{{ + ID: "role_01JSUPPORT", + Name: "Support", + Slug: roleSlug, + Description: "Support operators", + Type: "OrganizationRole", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }}, + []workos.Member{{ + ID: membershipID, + UserID: workosUserID, + OrganizationID: workosOrgID, + Organization: "Backfill User", + RoleSlug: roleSlug, + Status: "active", + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + }}, + ) + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + + err := activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + role, err := accessrepo.New(conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ + OrganizationID: organizationID, + WorkosSlug: roleSlug, + }) + require.NoError(t, err) + + assignments, err := orgrepo.New(conn).ListOrganizationRoleAssignmentsByWorkOSUser(ctx, orgrepo.ListOrganizationRoleAssignmentsByWorkOSUserParams{ + OrganizationID: organizationID, + WorkosUserID: workosUserID, + }) + require.NoError(t, err) + require.Len(t, assignments, 1) + require.Equal(t, fmt.Sprintf("role:organization:%s", role.ID.String()), assignments[0].RoleUrn) + require.True(t, assignments[0].UserID.Valid) + require.Equal(t, gramUserID, assignments[0].UserID.String) + require.Equal(t, membershipID, assignments[0].WorkosMembershipID.String) + require.Empty(t, assignments[0].WorkosLastEventID.String) + + relationship, err := orgrepo.New(conn).GetRelationshipByMembershipID(ctx, conv.ToPGText(membershipID)) + require.NoError(t, err) + require.True(t, relationship.UserID.Valid) + require.Equal(t, gramUserID, relationship.UserID.String) + require.Equal(t, workosUserID, relationship.WorkosUserID.String) +} + +func TestBackfillWorkOSOrganization_LinksExistingLocalUserByExternalID(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_existing_user_external_id") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_existing_user" + const workosOrgID = "org_01JBACKFILLEXISTINGUSER" + const workosUserID = "user_01JBACKFILLEXISTINGUSER" + const gramUserID = "gram_user_01JBACKFILLEXISTINGUSER" + const membershipID = "mem_01JBACKFILLEXISTINGUSER" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + _, err := usersrepo.New(conn).UpsertUser(ctx, usersrepo.UpsertUserParams{ + ID: gramUserID, + Email: "old@example.com", + DisplayName: "Old Name", + PhotoUrl: conv.ToPGTextEmpty(""), + Admin: false, + }) + require.NoError(t, err) + + workosClient := workos.NewStubClient() + workosClient.UpsertOrganization(workos.Organization{ + ID: workosOrgID, + Name: "Backfill Existing User", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }) + workosClient.UpsertUser(workosOrgID, workos.User{ + ID: workosUserID, + FirstName: "Existing", + LastName: "User", + Email: "existing@example.com", + ProfilePictureURL: "", + ExternalID: gramUserID, + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + }) + workosClient.UpsertOrganizationMembership(workos.Member{ + ID: membershipID, + UserID: workosUserID, + OrganizationID: workosOrgID, + Organization: "Backfill Existing User", + RoleSlug: "", + Status: "active", + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + }) + + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + err = activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + user, err := usersrepo.New(conn).GetUser(ctx, gramUserID) + require.NoError(t, err) + require.Equal(t, "existing@example.com", user.Email) + require.Equal(t, "Existing User", user.DisplayName) + require.Equal(t, workosUserID, user.WorkosID.String) + + relationship, err := orgrepo.New(conn).GetRelationshipByMembershipID(ctx, conv.ToPGText(membershipID)) + require.NoError(t, err) + require.True(t, relationship.UserID.Valid) + require.Equal(t, gramUserID, relationship.UserID.String) + require.Equal(t, workosUserID, relationship.WorkosUserID.String) +} + +func TestBackfillWorkOSOrganization_MergesExistingMembershipPlaceholder(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_merge_membership_placeholder") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_membership_placeholder" + const workosOrgID = "org_01JBACKFILLMEMPLACEHOLDER" + const workosUserID = "user_01JBACKFILLMEMPLACEHOLDER" + const gramUserID = "gram_user_01JBACKFILLMEMPLACEHOLDER" + const membershipID = "om_01JBACKFILLMEMPLACEHOLDER" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + _, err := usersrepo.New(conn).UpsertUser(ctx, usersrepo.UpsertUserParams{ + ID: gramUserID, + Email: "membership-placeholder@example.com", + DisplayName: "Membership Placeholder", + PhotoUrl: conv.ToPGTextEmpty(""), + Admin: false, + }) + require.NoError(t, err) + err = orgrepo.New(conn).UpsertWorkOSMembership(ctx, orgrepo.UpsertWorkOSMembershipParams{ + OrganizationID: organizationID, + UserID: conv.ToPGTextEmpty(""), + WorkosUserID: conv.ToPGText(workosUserID), + WorkosMembershipID: conv.ToPGText(membershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Date(2026, 5, 7, 10, 0, 0, 0, time.UTC)), + WorkosLastEventID: conv.ToPGText(""), + }) + require.NoError(t, err) + _, err = orgrepo.New(conn).UpsertOrganizationUserRelationship(ctx, orgrepo.UpsertOrganizationUserRelationshipParams{ + OrganizationID: organizationID, + UserID: conv.ToPGText(gramUserID), + }) + require.NoError(t, err) + + workosClient := workos.NewStubClient() + workosClient.UpsertOrganization(workos.Organization{ + ID: workosOrgID, + Name: "Backfill Membership Placeholder", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }) + workosClient.UpsertUser(workosOrgID, workos.User{ + ID: workosUserID, + FirstName: "Membership", + LastName: "Placeholder", + Email: "membership-placeholder@example.com", + ProfilePictureURL: "", + ExternalID: gramUserID, + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + }) + workosClient.UpsertOrganizationMembership(workos.Member{ + ID: membershipID, + UserID: workosUserID, + OrganizationID: workosOrgID, + Organization: "Backfill Membership Placeholder", + RoleSlug: "", + Status: "active", + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + }) + + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + err = activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + relationship, err := orgrepo.New(conn).GetOrganizationRelationshipForUser(ctx, orgrepo.GetOrganizationRelationshipForUserParams{ + OrganizationID: organizationID, + UserID: conv.ToPGText(gramUserID), + }) + require.NoError(t, err) + require.Equal(t, membershipID, relationship.WorkosMembershipID.String) + require.Equal(t, workosUserID, relationship.WorkosUserID.String) + require.False(t, relationship.Deleted) +} + +func TestBackfillWorkOSOrganization_UpdatesExistingMembershipWorkOSFields(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_update_existing_membership_fields") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_existing_membership_fields" + const workosOrgID = "org_01JBACKFILLEXISTINGMEMFIELDS" + const workosUserID = "user_01JBACKFILLEXISTINGMEMFIELDS" + const gramUserID = "gram_user_01JBACKFILLEXISTINGMEMFIELDS" + const membershipID = "om_01JBACKFILLEXISTINGMEMFIELDS" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + _, err := usersrepo.New(conn).UpsertUser(ctx, usersrepo.UpsertUserParams{ + ID: gramUserID, + Email: "existing-membership-fields@example.com", + DisplayName: "Existing Membership Fields", + PhotoUrl: conv.ToPGTextEmpty(""), + Admin: false, + }) + require.NoError(t, err) + err = orgrepo.New(conn).UpsertWorkOSMembership(ctx, orgrepo.UpsertWorkOSMembershipParams{ + OrganizationID: organizationID, + UserID: conv.ToPGText(gramUserID), + WorkosUserID: conv.ToPGTextEmpty(""), + WorkosMembershipID: conv.ToPGText(membershipID), + WorkosUpdatedAt: pgtype.Timestamptz{Time: time.Time{}, Valid: false}, + WorkosLastEventID: conv.ToPGText("event_01JNEWER"), + }) + require.NoError(t, err) + + workosClient := workos.NewStubClient() + workosClient.UpsertOrganization(workos.Organization{ + ID: workosOrgID, + Name: "Backfill Existing Membership Fields", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }) + workosClient.UpsertUser(workosOrgID, workos.User{ + ID: workosUserID, + FirstName: "Existing", + LastName: "Membership", + Email: "existing-membership-fields@example.com", + ProfilePictureURL: "", + ExternalID: gramUserID, + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + }) + workosClient.UpsertOrganizationMembership(workos.Member{ + ID: membershipID, + UserID: workosUserID, + OrganizationID: workosOrgID, + Organization: "Backfill Existing Membership Fields", + RoleSlug: "", + Status: "active", + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + }) + + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + err = activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + relationship, err := orgrepo.New(conn).GetRelationshipByMembershipID(ctx, conv.ToPGText(membershipID)) + require.NoError(t, err) + require.Equal(t, organizationID, relationship.OrganizationID) + require.Equal(t, gramUserID, relationship.UserID.String) + require.Equal(t, workosUserID, relationship.WorkosUserID.String) + require.True(t, relationship.WorkosUpdatedAt.Valid) + require.Equal(t, "event_01JNEWER", relationship.WorkosLastEventID.String) + require.False(t, relationship.Deleted) +} + +func TestBackfillWorkOSOrganization_RepairsMissingMembershipID(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_repair_missing_membership_id") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_missing_membership_id" + const workosOrgID = "org_01JBACKFILLMISSINGMEMID" + const workosUserID = "user_01JBACKFILLMISSINGMEMID" + const gramUserID = "gram_user_01JBACKFILLMISSINGMEMID" + const membershipID = "om_01JBACKFILLMISSINGMEMID" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + _, err := usersrepo.New(conn).UpsertUser(ctx, usersrepo.UpsertUserParams{ + ID: gramUserID, + Email: "missing-membership-id@example.com", + DisplayName: "Missing Membership ID", + PhotoUrl: conv.ToPGTextEmpty(""), + Admin: false, + }) + require.NoError(t, err) + _, err = orgrepo.New(conn).UpsertOrganizationUserRelationship(ctx, orgrepo.UpsertOrganizationUserRelationshipParams{ + OrganizationID: organizationID, + UserID: conv.ToPGText(gramUserID), + }) + require.NoError(t, err) + err = orgrepo.New(conn).SetOrganizationRelationshipWorkOSCursor(ctx, orgrepo.SetOrganizationRelationshipWorkOSCursorParams{ + OrganizationID: organizationID, + UserID: conv.ToPGText(gramUserID), + WorkosUpdatedAt: pgtype.Timestamptz{Time: time.Time{}, Valid: false}, + WorkosLastEventID: conv.ToPGText("event_01JNEWER"), + }) + require.NoError(t, err) + + workosClient := workos.NewStubClient() + workosClient.UpsertOrganization(workos.Organization{ + ID: workosOrgID, + Name: "Backfill Missing Membership ID", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }) + workosClient.UpsertUser(workosOrgID, workos.User{ + ID: workosUserID, + FirstName: "Missing", + LastName: "Membership", + Email: "missing-membership-id@example.com", + ProfilePictureURL: "", + ExternalID: gramUserID, + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + }) + workosClient.UpsertOrganizationMembership(workos.Member{ + ID: membershipID, + UserID: workosUserID, + OrganizationID: workosOrgID, + Organization: "Backfill Missing Membership ID", + RoleSlug: "", + Status: "active", + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + }) + + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + err = activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + relationship, err := orgrepo.New(conn).GetRelationshipByMembershipID(ctx, conv.ToPGText(membershipID)) + require.NoError(t, err) + require.Equal(t, organizationID, relationship.OrganizationID) + require.Equal(t, gramUserID, relationship.UserID.String) + require.Equal(t, workosUserID, relationship.WorkosUserID.String) + require.True(t, relationship.WorkosUpdatedAt.Valid) + require.Equal(t, "event_01JNEWER", relationship.WorkosLastEventID.String) + require.False(t, relationship.Deleted) +} + +func TestBackfillWorkOSOrganization_ValidationSkipsUnresolvableAssignmentRole(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_unresolvable_assignment_role") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_unresolvable_role" + const workosOrgID = "org_01JBACKFILLUNRESOLVABLEROLE" + const workosUserID = "user_01JBACKFILLUNRESOLVABLEROLE" + const gramUserID = "gram_user_01JBACKFILLUNRESOLVABLEROLE" + const membershipID = "om_01JBACKFILLUNRESOLVABLEROLE" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + + workosClient := workos.NewStubClient() + workosClient.UpsertOrganization(workos.Organization{ + ID: workosOrgID, + Name: "Backfill Unresolvable Role", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }) + user := workos.User{ + ID: workosUserID, + FirstName: "Missing", + LastName: "Role", + Email: "missing-role@example.com", + ProfilePictureURL: "", + ExternalID: gramUserID, + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + } + member := workos.Member{ + ID: membershipID, + UserID: workosUserID, + OrganizationID: workosOrgID, + Organization: "Backfill Unresolvable Role", + RoleSlug: "member", + Status: "active", + CreatedAt: "2026-05-07T11:05:00Z", + UpdatedAt: "2026-05-07T11:05:00Z", + } + workosClient.UpsertUser(workosOrgID, user) + workosClient.UpsertOrganizationMembership(member) + + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + err := activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + assignments, err := orgrepo.New(conn).ListOrganizationRoleAssignmentsByWorkOSUser(ctx, orgrepo.ListOrganizationRoleAssignmentsByWorkOSUserParams{ + OrganizationID: organizationID, + WorkosUserID: workosUserID, + }) + require.NoError(t, err) + require.Empty(t, assignments) + + err = validateOrganization(ctx, conn, orgExpectation{ + workosOrgID: workosOrgID, + gramOrgID: organizationID, + name: "Backfill Unresolvable Role", + skipped: false, + roles: nil, + users: map[string]workos.User{ + workosUserID: user, + }, + members: []workos.Member{member}, + orgChanges: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + roleChanges: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + userChanges: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + membershipChanges: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + assignmentChanges: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 1}, + changeDetails: nil, + }) + require.NoError(t, err) +} + +func TestBackfillWorkOSOrganization_MembershipWithNewerEventSkipsRoleSnapshot(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_membership_newer_event_wins") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_membership_event_wins" + const workosOrgID = "org_01JBACKFILLMEMEVENT" + const workosUserID = "user_01JBACKFILLMEMEVENT" + const membershipID = "mem_01JBACKFILLMEMEVENT" + const roleSlug = "org-member" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + seedOrganizationRoleWithCursor(t, ctx, conn, organizationID, roleSlug, "Member", "") + err := orgrepo.New(conn).SyncUserOrganizationRoleAssignments(ctx, orgrepo.SyncUserOrganizationRoleAssignmentsParams{ + OrganizationID: organizationID, + WorkosUserID: workosUserID, + WorkosRoleSlugs: []string{roleSlug}, + UserID: conv.ToPGTextEmpty(""), + WorkosMembershipID: conv.ToPGText(membershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Date(2026, 5, 7, 12, 0, 0, 0, time.UTC)), + WorkosLastEventID: conv.ToPGText("event_99FRESH"), + }) + require.NoError(t, err) + + workosClient := newWorkOSSnapshotClient(t, ctx, + workos.Organization{ + ID: workosOrgID, + Name: "Backfill Membership Event Wins", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }, + []workos.Role{{ + ID: "role_01JMEMBER", + Name: "Member", + Slug: roleSlug, + Description: "", + Type: "OrganizationRole", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }}, + []workos.Member{{ + ID: membershipID, + UserID: workosUserID, + OrganizationID: workosOrgID, + Organization: "Backfill Membership Event Wins", + RoleSlug: "", + Status: "active", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }}, + ) + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + + err = activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + assignments, err := orgrepo.New(conn).ListOrganizationRoleAssignmentsByWorkOSUser(ctx, orgrepo.ListOrganizationRoleAssignmentsByWorkOSUserParams{ + OrganizationID: organizationID, + WorkosUserID: workosUserID, + }) + require.NoError(t, err) + require.Len(t, assignments, 1) + require.Equal(t, "event_99FRESH", assignments[0].WorkosLastEventID.String) +} + +func TestBackfillWorkOSOrganization_NewerRelationshipDoesNotSkipMissingAssignment(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_relationship_newer_assignment_missing") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_relationship_newer" + const workosOrgID = "org_01JBACKFILLRELNEWER" + const workosUserID = "user_01JBACKFILLRELNEWER" + const gramUserID = "gram_user_01JBACKFILLRELNEWER" + const membershipID = "om_01JBACKFILLRELNEWER" + const roleSlug = "org-member" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + _, err := usersrepo.New(conn).UpsertUser(ctx, usersrepo.UpsertUserParams{ + ID: gramUserID, + Email: "relationship-newer@example.com", + DisplayName: "Relationship Newer", + PhotoUrl: conv.ToPGTextEmpty(""), + Admin: false, + }) + require.NoError(t, err) + err = orgrepo.New(conn).UpsertWorkOSMembership(ctx, orgrepo.UpsertWorkOSMembershipParams{ + OrganizationID: organizationID, + UserID: conv.ToPGText(gramUserID), + WorkosUserID: conv.ToPGText(workosUserID), + WorkosMembershipID: conv.ToPGText(membershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Date(2026, 5, 8, 12, 0, 0, 0, time.UTC)), + WorkosLastEventID: conv.ToPGText("event_99RELATIONSHIP"), + }) + require.NoError(t, err) + + workosClient := newWorkOSSnapshotClient(t, ctx, + workos.Organization{ + ID: workosOrgID, + Name: "Backfill Relationship Newer", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }, + []workos.Role{{ + ID: "role_01JRELNEWER", + Name: "Member", + Slug: roleSlug, + Description: "", + Type: "OrganizationRole", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }}, + []workos.Member{{ + ID: membershipID, + UserID: workosUserID, + OrganizationID: workosOrgID, + Organization: "Backfill Relationship Newer", + RoleSlug: roleSlug, + Status: "active", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }}, + ) + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + + err = activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + relationship, err := orgrepo.New(conn).GetRelationshipByMembershipID(ctx, conv.ToPGText(membershipID)) + require.NoError(t, err) + require.Equal(t, "event_99RELATIONSHIP", relationship.WorkosLastEventID.String) + + role, err := accessrepo.New(conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ + OrganizationID: organizationID, + WorkosSlug: roleSlug, + }) + require.NoError(t, err) + + assignments, err := orgrepo.New(conn).ListOrganizationRoleAssignmentsByWorkOSUser(ctx, orgrepo.ListOrganizationRoleAssignmentsByWorkOSUserParams{ + OrganizationID: organizationID, + WorkosUserID: workosUserID, + }) + require.NoError(t, err) + require.Len(t, assignments, 1) + require.Equal(t, fmt.Sprintf("role:organization:%s", role.ID.String()), assignments[0].RoleUrn) + require.Equal(t, membershipID, assignments[0].WorkosMembershipID.String) +} + +func TestBackfillWorkOSOrganization_RecreatesDeletedAssignmentFromSnapshot(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_recreate_deleted_assignment") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_recreate_deleted_assignment" + const workosOrgID = "org_01JBACKFILLRECREATEDELETED" + const workosUserID = "user_01JBACKFILLRECREATEDELETED" + const gramUserID = "gram_user_01JBACKFILLRECREATEDELETED" + const membershipID = "om_01JBACKFILLRECREATEDELETED" + const roleSlug = "org-member" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + seedOrganizationRoleWithCursor(t, ctx, conn, organizationID, roleSlug, "Member", "") + _, err := usersrepo.New(conn).UpsertUser(ctx, usersrepo.UpsertUserParams{ + ID: gramUserID, + Email: "recreate-deleted@example.com", + DisplayName: "Recreate Deleted", + PhotoUrl: conv.ToPGTextEmpty(""), + Admin: false, + }) + require.NoError(t, err) + err = orgrepo.New(conn).SyncUserOrganizationRoleAssignments(ctx, orgrepo.SyncUserOrganizationRoleAssignmentsParams{ + OrganizationID: organizationID, + WorkosUserID: workosUserID, + WorkosRoleSlugs: []string{roleSlug}, + UserID: conv.ToPGText(gramUserID), + WorkosMembershipID: conv.ToPGText(membershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Date(2026, 5, 7, 11, 0, 0, 0, time.UTC)), + WorkosLastEventID: conv.ToPGText("event_01DELETESETUP"), + }) + require.NoError(t, err) + err = orgrepo.New(conn).SyncUserOrganizationRoleAssignments(ctx, orgrepo.SyncUserOrganizationRoleAssignmentsParams{ + OrganizationID: organizationID, + WorkosUserID: workosUserID, + WorkosRoleSlugs: []string{}, + UserID: conv.ToPGText(gramUserID), + WorkosMembershipID: conv.ToPGText(membershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Date(2026, 5, 7, 12, 0, 0, 0, time.UTC)), + WorkosLastEventID: conv.ToPGText("event_02DELETESETUP"), + }) + require.NoError(t, err) + + workosClient := newWorkOSSnapshotClient(t, ctx, + workos.Organization{ + ID: workosOrgID, + Name: "Backfill Recreate Deleted Assignment", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }, + []workos.Role{{ + ID: "role_01JRECREATE", + Name: "Member", + Slug: roleSlug, + Description: "", + Type: "OrganizationRole", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }}, + []workos.Member{{ + ID: membershipID, + UserID: workosUserID, + OrganizationID: workosOrgID, + Organization: "Backfill Recreate Deleted Assignment", + RoleSlug: roleSlug, + Status: "active", + CreatedAt: "2026-05-07T13:00:00Z", + UpdatedAt: "2026-05-07T13:00:00Z", + }}, + ) + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + + err = activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + assignments, err := orgrepo.New(conn).ListOrganizationRoleAssignmentsByWorkOSUser(ctx, orgrepo.ListOrganizationRoleAssignmentsByWorkOSUserParams{ + OrganizationID: organizationID, + WorkosUserID: workosUserID, + }) + require.NoError(t, err) + activeAssignments := 0 + deletedAssignments := 0 + for _, assignment := range assignments { + if assignment.DeletedAt.Valid { + deletedAssignments++ + } else { + activeAssignments++ + require.Equal(t, membershipID, assignment.WorkosMembershipID.String) + } + } + require.Equal(t, 1, deletedAssignments) + require.Equal(t, 1, activeAssignments) +} + +func TestBackfillWorkOSOrganization_RoleWithLastEventIDSkipsSnapshot(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_role_last_event_wins") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_event_wins" + const workosOrgID = "org_01JBACKFILLEVENTWINS" + const roleSlug = "org-billing" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + seedOrganizationRoleWithCursor(t, ctx, conn, organizationID, roleSlug, "Billing From Event", "event_01JNEWER") + + workosClient := newWorkOSSnapshotClient(t, ctx, + workos.Organization{ + ID: workosOrgID, + Name: "Backfill Event Wins", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }, + []workos.Role{{ + ID: "role_01JBILLING", + Name: "Billing From Snapshot", + Slug: roleSlug, + Description: "", + Type: "OrganizationRole", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T12:00:00Z", + }}, + nil, + ) + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + + err := activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + role, err := accessrepo.New(conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ + OrganizationID: organizationID, + WorkosSlug: roleSlug, + }) + require.NoError(t, err) + require.Equal(t, "Billing From Event", role.WorkosName) + require.Equal(t, "event_01JNEWER", role.WorkosLastEventID.String) +} + +func TestBackfillWorkOSOrganization_MissingRoleSoftDeleted(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newBackfillTestConn(t, "workos_backfill_role_deleted") + logger := testenv.NewLogger(t) + + const organizationID = "gram_org_backfill_delete_role" + const workosOrgID = "org_01JBACKFILLDELETE" + const roleSlug = "org-obsolete" + + seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) + seedOrganizationRoleWithCursor(t, ctx, conn, organizationID, roleSlug, "Obsolete", "") + + workosClient := newWorkOSSnapshotClient(t, ctx, + workos.Organization{ + ID: workosOrgID, + Name: "Backfill Delete Role", + ExternalID: "", + CreatedAt: "2026-05-07T11:00:00Z", + UpdatedAt: "2026-05-07T11:00:00Z", + }, + nil, + nil, + ) + activity := NewBackfillWorkOSOrganization(logger, conn, workosClient) + + err := activity.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + role, err := accessrepo.New(conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ + OrganizationID: organizationID, + WorkosSlug: roleSlug, + }) + require.NoError(t, err) + require.True(t, role.Deleted) + require.True(t, role.WorkosDeleted) + require.Empty(t, role.WorkosLastEventID.String) +} + +func newWorkOSSnapshotClient(t *testing.T, ctx context.Context, org workos.Organization, roles []workos.Role, members []workos.Member) *workos.StubClient { + t.Helper() + + client := workos.NewStubClient() + client.UpsertOrganization(org) + for _, role := range roles { + _, err := client.CreateRole(ctx, org.ID, workos.CreateRoleOpts{ + Name: role.Name, + Slug: role.Slug, + Description: role.Description, + }) + require.NoError(t, err) + } + for _, member := range members { + client.UpsertUser(org.ID, workos.User{ + ID: member.UserID, + FirstName: "Test", + LastName: "User", + Email: member.UserID + "@example.com", + ProfilePictureURL: "", + ExternalID: "gram_" + member.UserID, + CreatedAt: member.CreatedAt, + UpdatedAt: member.UpdatedAt, + }) + client.UpsertOrganizationMembership(member) + } + + return client +} + +func seedLinkedWorkOSOrganization(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID, workosOrgID string) { + t.Helper() + + _, err := orgrepo.New(conn).UpsertOrganizationMetadata(ctx, orgrepo.UpsertOrganizationMetadataParams{ + ID: organizationID, + Name: organizationID, + Slug: organizationID, + WorkosID: conv.ToPGText(workosOrgID), + Whitelisted: pgtype.Bool{Bool: false, Valid: false}, + }) + require.NoError(t, err) +} + +func seedOrganizationRoleWithCursor(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID, slug, name, lastEventID string) { + t.Helper() + + updatedAt := time.Date(2026, 5, 7, 10, 0, 0, 0, time.UTC) + err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + OrganizationID: organizationID, + WorkosSlug: slug, + WorkosName: name, + WorkosDescription: conv.ToPGText(""), + WorkosCreatedAt: conv.ToPGTimestamptz(updatedAt), + WorkosUpdatedAt: conv.ToPGTimestamptz(updatedAt), + WorkosLastEventID: conv.ToPGText(lastEventID), + }) + require.NoError(t, err) +} diff --git a/server/cmd/workos-backfill/backfill_user.go b/server/cmd/workos-backfill/backfill_user.go new file mode 100644 index 0000000000..4bc9a4af0e --- /dev/null +++ b/server/cmd/workos-backfill/backfill_user.go @@ -0,0 +1,203 @@ +package main + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/jackc/pgx/v5" + + "github.com/speakeasy-api/gram/server/internal/attr" + "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" + usersrepo "github.com/speakeasy-api/gram/server/internal/users/repo" +) + +func backfillWorkOSUser(ctx context.Context, logger *slog.Logger, dbtx pgx.Tx, user workos.User) (string, bool, error) { + createdAt, err := parseWorkOSTime(user.CreatedAt) + if err != nil { + return "", false, fmt.Errorf("parse user %q created_at: %w", user.ID, err) + } + updatedAt, err := parseWorkOSTime(user.UpdatedAt) + if err != nil { + return "", false, fmt.Errorf("parse user %q updated_at: %w", user.ID, err) + } + + existingUser, found, err := getUserByWorkOSID(ctx, dbtx, user.ID) + if err != nil { + return "", false, err + } + + gramUserID := user.ExternalID + if found { + gramUserID = existingUser.ID + } else if user.ExternalID != "" { + existingUser, found, err = findUserByID(ctx, dbtx, user.ExternalID) + if err != nil { + return "", false, err + } + if found { + gramUserID = existingUser.ID + } + } + if gramUserID == "" { + logger.WarnContext(ctx, "skipping WorkOS user backfill without local user ID", attr.SlogWorkOSUserID(user.ID)) + return "", false, nil + } + + if found && existingUser.WorkosID.Valid && existingUser.WorkosID.String != user.ID { + return "", false, fmt.Errorf("local user %q is already linked to different WorkOS user %q", existingUser.ID, existingUser.WorkosID.String) + } + if found && existingUser.WorkosUpdatedAt.Valid && !shouldProcessEvent(nil, &existingUser.WorkosUpdatedAt.Time, "", updatedAt) { + return gramUserID, true, nil + } + + if found && (!existingUser.WorkosID.Valid || existingUser.WorkosID.String == user.ID) { + if err := updateSyncedUserByID(ctx, dbtx, gramUserID, user, createdAt, updatedAt); err != nil { + return "", false, err + } + return gramUserID, true, nil + } + + if _, err := usersrepo.New(dbtx).UpsertSyncedUser(ctx, usersrepo.UpsertSyncedUserParams{ + ID: gramUserID, + Email: user.Email, + DisplayName: displayNameFromWorkOSUser(user), + PhotoUrl: conv.ToPGTextEmpty(user.ProfilePictureURL), + WorkosID: conv.ToPGText(user.ID), + WorkosCreatedAt: conv.ToPGTimestamptz(createdAt), + WorkosUpdatedAt: conv.ToPGTimestamptz(updatedAt), + }); err != nil { + return "", false, fmt.Errorf("upsert synced user: %w", err) + } + + if user.ExternalID == "" { + logger.WarnContext(ctx, "WorkOS user missing external ID during backfill", attr.SlogWorkOSUserID(user.ID), attr.SlogUserID(gramUserID)) + } + + return gramUserID, true, nil +} + +func updateSyncedUserByID(ctx context.Context, dbtx pgx.Tx, gramUserID string, user workos.User, createdAt, updatedAt time.Time) error { + tag, err := dbtx.Exec(ctx, ` +UPDATE users +SET email = $2, + display_name = $3, + photo_url = $4, + workos_id = COALESCE(workos_id, $5), + workos_created_at = COALESCE(workos_created_at, $6), + workos_updated_at = $7, + workos_deleted_at = NULL, + deleted_at = NULL, + updated_at = clock_timestamp() +WHERE id = $1 + AND (workos_id IS NULL OR workos_id = $5) + AND (workos_updated_at IS NULL OR $7 >= workos_updated_at)`, + gramUserID, + user.Email, + displayNameFromWorkOSUser(user), + conv.ToPGTextEmpty(user.ProfilePictureURL), + conv.ToPGText(user.ID), + conv.ToPGTimestamptz(createdAt), + conv.ToPGTimestamptz(updatedAt), + ) + if err != nil { + return fmt.Errorf("update synced user %q by local id: %w", gramUserID, err) + } + if tag.RowsAffected() == 0 { + return fmt.Errorf("update synced user %q by local id: no rows updated", gramUserID) + } + return nil +} + +func getUserByWorkOSID(ctx context.Context, dbtx pgx.Tx, workosUserID string) (usersrepo.User, bool, error) { + users, err := usersrepo.New(dbtx).GetUsersByWorkosIDs(ctx, []string{workosUserID}) + var zero usersrepo.User + switch { + case err != nil: + return zero, false, fmt.Errorf("get user by WorkOS ID: %w", err) + case len(users) == 0: + return zero, false, nil + default: + return users[0], true, nil + } +} + +type queryRower interface { + QueryRow(ctx context.Context, sql string, args ...any) pgx.Row +} + +func findUserByWorkOSID(ctx context.Context, db queryRower, workosUserID string) (usersrepo.User, bool, error) { + var user usersrepo.User + err := db.QueryRow(ctx, ` +SELECT id, email, display_name, photo_url, admin, last_login, workos_id, workos_created_at, workos_updated_at, workos_deleted_at, deleted_at, created_at, updated_at +FROM users +WHERE workos_id = $1 +LIMIT 1`, workosUserID).Scan( + &user.ID, + &user.Email, + &user.DisplayName, + &user.PhotoUrl, + &user.Admin, + &user.LastLogin, + &user.WorkosID, + &user.WorkosCreatedAt, + &user.WorkosUpdatedAt, + &user.WorkosDeletedAt, + &user.DeletedAt, + &user.CreatedAt, + &user.UpdatedAt, + ) + var zero usersrepo.User + switch { + case errors.Is(err, pgx.ErrNoRows): + return zero, false, nil + case err != nil: + return zero, false, fmt.Errorf("get user by WorkOS ID: %w", err) + default: + return user, true, nil + } +} + +func findUserByID(ctx context.Context, db queryRower, userID string) (usersrepo.User, bool, error) { + var user usersrepo.User + err := db.QueryRow(ctx, ` +SELECT id, email, display_name, photo_url, admin, last_login, workos_id, workos_created_at, workos_updated_at, workos_deleted_at, deleted_at, created_at, updated_at +FROM users +WHERE id = $1 +LIMIT 1`, userID).Scan( + &user.ID, + &user.Email, + &user.DisplayName, + &user.PhotoUrl, + &user.Admin, + &user.LastLogin, + &user.WorkosID, + &user.WorkosCreatedAt, + &user.WorkosUpdatedAt, + &user.WorkosDeletedAt, + &user.DeletedAt, + &user.CreatedAt, + &user.UpdatedAt, + ) + var zero usersrepo.User + switch { + case errors.Is(err, pgx.ErrNoRows): + return zero, false, nil + case err != nil: + return zero, false, fmt.Errorf("get user by ID: %w", err) + default: + return user, true, nil + } +} + +func displayNameFromWorkOSUser(user workos.User) string { + displayName := strings.TrimSpace(strings.Join([]string{user.FirstName, user.LastName}, " ")) + if displayName != "" { + return displayName + } + return user.Email +} diff --git a/server/cmd/workos-backfill/cloudsql.go b/server/cmd/workos-backfill/cloudsql.go new file mode 100644 index 0000000000..7f00065bda --- /dev/null +++ b/server/cmd/workos-backfill/cloudsql.go @@ -0,0 +1,379 @@ +package main + +import ( + "context" + "errors" + "fmt" + "net" + "net/url" + "os/exec" + "strconv" + "strings" + "sync" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type cloudSQLInstance struct { + projectID string + region string + instance string +} + +type cloudSQLAccessMode string + +const ( + cloudSQLAccessModeRead cloudSQLAccessMode = "read-only" + cloudSQLAccessModeWrite cloudSQLAccessMode = "write" +) + +func startCloudSQLProxy(ctx context.Context, opts options, readOnly bool) (string, func(), error) { + instance, err := cloudSQLInstanceForEnvironment(opts.environment) + if err != nil { + return "", nil, err + } + port := opts.cloudSQLPort + if port == 0 { + port, err = freeLocalPort() + if err != nil { + return "", nil, err + } + } + user, err := activeGCloudAccount(ctx) + if err != nil { + return "", nil, err + } + + instancePath := fmt.Sprintf("%s:%s:%s", instance.projectID, instance.region, instance.instance) + fmt.Printf("Starting Cloud SQL proxy for %s on 127.0.0.1:%d as %s\n", instancePath, port, user) + + // #nosec G204 -- instancePath is selected from hardcoded dev/prod Cloud SQL config and port is validated before use. + cmd := exec.CommandContext(ctx, "cloud-sql-proxy", fmt.Sprintf("%s?port=%d", instancePath, port), "--auto-iam-authn") + if err := cmd.Start(); err != nil { + return "", nil, fmt.Errorf("start cloud-sql-proxy: %w", err) + } + + errCh := make(chan error, 1) + go func() { + errCh <- cmd.Wait() + }() + + var cleanupOnce sync.Once + cleanup := func() { + cleanupOnce.Do(func() { + if cmd.ProcessState != nil { + return + } + select { + case <-errCh: + return + default: + } + if cmd.Process == nil { + return + } + _ = cmd.Process.Kill() + select { + case <-errCh: + case <-time.After(5 * time.Second): + _ = cmd.Process.Kill() + <-errCh + } + }) + } + + if err := waitForTCP(ctx, "127.0.0.1", port, errCh); err != nil { + cleanup() + return "", nil, err + } + + mode := cloudSQLModeForReadOnly(readOnly) + if err := prepareCloudSQLIAMAccess(ctx, instance, port, strings.TrimSpace(opts.cloudSQLDBName), user, mode); err != nil { + cleanup() + return "", nil, err + } + + databaseURL := cloudSQLDatabaseURL("127.0.0.1", port, strings.TrimSpace(opts.cloudSQLDBName), user) + return databaseURL, cleanup, nil +} + +func cloudSQLInstanceForEnvironment(env environment) (cloudSQLInstance, error) { + switch env { + case envDev: + return cloudSQLInstance{ + projectID: "linen-analyst-344721", + region: "us-west1", + instance: "gram-dev-instance", + }, nil + case envProd: + return cloudSQLInstance{ + projectID: "speakeasy-prod-354914", + region: "us-west1", + instance: "gram-prod-instance", + }, nil + default: + return cloudSQLInstance{ + projectID: "", + region: "", + instance: "", + }, fmt.Errorf("cloud sql proxy is not configured for environment %q", env) + } +} + +func freeLocalPort() (int, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return 0, fmt.Errorf("find free local port: %w", err) + } + defer func() { + _ = listener.Close() + }() + addr, ok := listener.Addr().(*net.TCPAddr) + if !ok { + return 0, errors.New("free local port listener did not return a TCP address") + } + return addr.Port, nil +} + +func activeGCloudAccount(ctx context.Context) (string, error) { + output, err := gcloudOutput(ctx, "auth", "list", "--format=value(ACCOUNT)") + if err != nil { + return "", err + } + var fallback string + for line := range strings.SplitSeq(output, "\n") { + account := strings.TrimSpace(line) + if account == "" { + continue + } + if fallback == "" { + fallback = account + } + if strings.Contains(strings.ToLower(account), "speakeasy") { + return account, nil + } + } + if fallback != "" { + return fallback, nil + } + return "", errors.New("no active gcloud account found") +} + +func gcloudOutput(ctx context.Context, args ...string) (string, error) { + // #nosec G204 -- callers pass fixed gcloud subcommands with validated env-derived values. + cmd := exec.CommandContext(ctx, "gcloud", args...) + out, err := cmd.Output() + if err == nil { + return string(out), nil + } + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + return "", fmt.Errorf("gcloud %s: %w: %s", strings.Join(args, " "), err, strings.TrimSpace(string(exitErr.Stderr))) + } + return "", fmt.Errorf("gcloud %s: %w", strings.Join(args, " "), err) +} + +func prepareCloudSQLIAMAccess(ctx context.Context, instance cloudSQLInstance, port int, dbName, user string, mode cloudSQLAccessMode) error { + fmt.Printf("Preparing Cloud SQL IAM database access mode=%s user=%s\n", mode, user) + if err := ensureCloudSQLIAMUser(ctx, instance, user); err != nil { + return err + } + + password, err := cloudSQLAdminPassword(ctx, instance) + if err != nil { + return err + } + adminURL := cloudSQLAdminDatabaseURL("127.0.0.1", port, dbName, password) + adminDB, err := connectDB(ctx, adminURL, false, defaultStatementTimeout) + if err != nil { + return fmt.Errorf("connect Cloud SQL admin database: %w", err) + } + defer adminDB.Close() + + if err := grantCloudSQLIAMUserAccess(ctx, adminDB, user, mode); err != nil { + return err + } + return nil +} + +func ensureCloudSQLIAMUser(ctx context.Context, instance cloudSQLInstance, user string) error { + fmt.Println("Looking up Cloud SQL IAM database user") + output, err := gcloudOutput(ctx, + "sql", + "users", + "list", + "--instance", instance.instance, + "--format=value(NAME)", + "--project", instance.projectID, + ) + if err != nil { + return err + } + for line := range strings.SplitSeq(output, "\n") { + if strings.TrimSpace(line) == user { + return nil + } + } + + fmt.Printf("Creating Cloud SQL IAM database user %s on %s\n", user, instance.instance) + _, err = gcloudOutput(ctx, + "sql", + "users", + "create", user, + "--instance", instance.instance, + "--type=cloud_iam_user", + "--project", instance.projectID, + ) + if err != nil { + return fmt.Errorf("create Cloud SQL IAM database user %s: %w", user, err) + } + return nil +} + +func cloudSQLAdminPassword(ctx context.Context, instance cloudSQLInstance) (string, error) { + envName, err := cloudSQLSecretEnvironment(instance) + if err != nil { + return "", err + } + output, err := gcloudOutput(ctx, + "secrets", + "versions", + "access", "latest", + "--secret", fmt.Sprintf("%s_gram_db_password", envName), + "--project", instance.projectID, + ) + if err != nil { + return "", fmt.Errorf("read Cloud SQL admin password secret: %w", err) + } + password := strings.TrimSpace(output) + if password == "" { + return "", errors.New("cloud SQL admin password secret was empty") + } + return password, nil +} + +func cloudSQLSecretEnvironment(instance cloudSQLInstance) (string, error) { + switch instance.instance { + case "gram-dev-instance": + return string(envDev), nil + case "gram-prod-instance": + return string(envProd), nil + default: + return "", fmt.Errorf("no Cloud SQL password secret configured for instance %q", instance.instance) + } +} + +type dbExecer interface { + Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) +} + +func grantCloudSQLIAMUserAccess(ctx context.Context, db dbExecer, user string, mode cloudSQLAccessMode) error { + schema := pgx.Identifier{"gram"}.Sanitize() + grantee := pgx.Identifier{user}.Sanitize() + statements := []string{ + fmt.Sprintf("REVOKE ALL ON ALL TABLES IN SCHEMA %s FROM %s", schema, grantee), + fmt.Sprintf("REVOKE USAGE, CREATE ON SCHEMA %s FROM %s", schema, grantee), + } + switch mode { + case cloudSQLAccessModeRead: + statements = append(statements, + fmt.Sprintf("GRANT USAGE ON SCHEMA %s TO %s", schema, grantee), + fmt.Sprintf("GRANT SELECT ON ALL TABLES IN SCHEMA %s TO %s", schema, grantee), + ) + case cloudSQLAccessModeWrite: + statements = append(statements, + fmt.Sprintf("GRANT USAGE, CREATE ON SCHEMA %s TO %s", schema, grantee), + fmt.Sprintf("GRANT SELECT, INSERT, UPDATE ON ALL TABLES IN SCHEMA %s TO %s", schema, grantee), + ) + default: + return fmt.Errorf("unsupported Cloud SQL access mode %q", mode) + } + + for _, statement := range statements { + if _, err := db.Exec(ctx, statement); err != nil { + return fmt.Errorf("grant Cloud SQL IAM user access with %q: %w", statement, err) + } + } + return nil +} + +func waitForTCP(ctx context.Context, host string, port int, proxyErr <-chan error) error { + address := net.JoinHostPort(host, strconv.Itoa(port)) + deadline := time.After(10 * time.Second) + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + conn, err := net.DialTimeout("tcp", address, 200*time.Millisecond) + if err == nil { + _ = conn.Close() + return nil + } + + select { + case <-ctx.Done(): + return fmt.Errorf("wait for cloud-sql-proxy: %w", ctx.Err()) + case err := <-proxyErr: + return fmt.Errorf("cloud-sql-proxy exited before accepting connections: %w", err) + case <-deadline: + return fmt.Errorf("cloud-sql-proxy did not accept connections on %s within 10s", address) + case <-ticker.C: + } + } +} + +func cloudSQLDatabaseURL(host string, port int, dbName, user string) string { + u := url.URL{ + Scheme: "postgres", + User: url.User(user), + Host: net.JoinHostPort(host, strconv.Itoa(port)), + Path: "/" + dbName, + } + q := u.Query() + q.Set("sslmode", "disable") + q.Set("search_path", "gram") + u.RawQuery = q.Encode() + return u.String() +} + +func cloudSQLAdminDatabaseURL(host string, port int, dbName, password string) string { + u := url.URL{ + Scheme: "postgres", + User: url.UserPassword("gram", password), + Host: net.JoinHostPort(host, strconv.Itoa(port)), + Path: "/" + dbName, + } + q := u.Query() + q.Set("sslmode", "disable") + q.Set("search_path", "gram") + u.RawQuery = q.Encode() + return u.String() +} + +func cloudSQLProxyHint() string { + return "; cloud sql proxy is enabled, so check that gcloud auth is active and the Cloud SQL proxy can reach the instance" +} + +func cloudSQLRunHint(err error, opts options) string { + if !opts.cloudSQLProxy { + return "" + } + message := err.Error() + if !strings.Contains(message, "SQLSTATE 42501") && !strings.Contains(message, "permission denied") { + return "" + } + return fmt.Sprintf(` + +Cloud SQL permission hint: + The script attempted to grant %s access to your IAM database user before connecting. + Check that the gram admin password secret is current and that your gcloud account can manage Cloud SQL users.`, cloudSQLModeForReadOnly(opts.dryRun || opts.phase == phasePreflight || opts.phase == phaseValidate)) +} + +func cloudSQLModeForReadOnly(readOnly bool) cloudSQLAccessMode { + if readOnly { + return cloudSQLAccessModeRead + } + return cloudSQLAccessModeWrite +} diff --git a/server/internal/background/activities/backfill_workos_global_roles.go b/server/cmd/workos-backfill/global_roles.go similarity index 93% rename from server/internal/background/activities/backfill_workos_global_roles.go rename to server/cmd/workos-backfill/global_roles.go index 31df03cc92..14008cf585 100644 --- a/server/internal/background/activities/backfill_workos_global_roles.go +++ b/server/cmd/workos-backfill/global_roles.go @@ -1,4 +1,4 @@ -package activities +package main import ( "context" @@ -20,10 +20,10 @@ import ( type BackfillWorkOSGlobalRoles struct { logger *slog.Logger db *pgxpool.Pool - workos WorkOSClient + workos Client } -func NewBackfillWorkOSGlobalRoles(logger *slog.Logger, db *pgxpool.Pool, workosClient WorkOSClient) *BackfillWorkOSGlobalRoles { +func NewBackfillWorkOSGlobalRoles(logger *slog.Logger, db *pgxpool.Pool, workosClient Client) *BackfillWorkOSGlobalRoles { return &BackfillWorkOSGlobalRoles{ logger: logger.With(attr.SlogComponent("backfill_workos_global_roles")), db: db, @@ -69,7 +69,7 @@ func (b *BackfillWorkOSGlobalRoles) Do(ctx context.Context) error { if existing.WorkosUpdatedAt.Valid { rowUpdatedAt = &existing.WorkosUpdatedAt.Time } - if !ShouldProcessEvent(lastEventID, rowUpdatedAt, "", updatedAt) { + if !shouldProcessEvent(lastEventID, rowUpdatedAt, "", updatedAt) { continue } @@ -103,7 +103,7 @@ func (b *BackfillWorkOSGlobalRoles) Do(ctx context.Context) error { rowUpdatedAt = &localRole.WorkosUpdatedAt.Time } deletedAt := time.Now().UTC() - if !ShouldProcessEvent(lastEventID, rowUpdatedAt, "", deletedAt) { + if !shouldProcessEvent(lastEventID, rowUpdatedAt, "", deletedAt) { continue } diff --git a/server/cmd/workos-backfill/main.go b/server/cmd/workos-backfill/main.go new file mode 100644 index 0000000000..fb03144af6 --- /dev/null +++ b/server/cmd/workos-backfill/main.go @@ -0,0 +1,2661 @@ +package main + +import ( + "bufio" + "context" + "errors" + "flag" + "fmt" + "log/slog" + "os" + "sort" + "strings" + "syscall" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + "go.opentelemetry.io/otel/trace/noop" + "golang.org/x/term" + + "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/guardian" + "github.com/speakeasy-api/gram/server/internal/o11y" + "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" +) + +const sampleSize = 5 +const updateDetailLimit = 20 +const changeSummarySampleLimit = 3 +const defaultStatementTimeout = 30 * time.Minute + +type phase string + +const ( + phasePreflight phase = "preflight" + phaseGlobalRoles phase = "global-roles" + phaseOrganizations phase = "organizations" + phaseValidate phase = "validate" + phaseAll phase = "all" +) + +type environment string + +const ( + envLocal environment = "local" + envDev environment = "dev" + envProd environment = "prod" +) + +type options struct { + phase phase + environment environment + databaseURL string + cloudSQLProxy bool + cloudSQLPort int + cloudSQLDBName string + workosAPIKey string + workosEndpoint string + workosOrgIDs []string + limit int + pageSize int + pageOffset int + statementTimeout time.Duration + dryRun bool + autoApprove bool + pauseAfterEach bool + confirmProd string + breakpointBefore bool +} + +type orgExpectation struct { + workosOrgID string + gramOrgID string + name string + skipped bool + roles []workos.Role + users map[string]workos.User + members []workos.Member + orgChanges changeCounts + roleChanges changeCounts + userChanges changeCounts + membershipChanges changeCounts + assignmentChanges changeCounts + changeDetails []changeDetail +} + +type changeDetail struct { + Entity string + ID string + Action string + Fields []fieldChange +} + +type fieldChange struct { + Name string + Before string + After string +} + +type changeSummaryGroup struct { + Entity string + Action string + Risk string + Fields []string + Count int + Samples []changeDetail +} + +type report struct { + scanned int + skipped int + skippedNoop int + written int + validated int + failed int + validationFailures int + organizationRows changeCounts + roleRows changeCounts + userRows changeCounts + membershipRows changeCounts + assignmentRows changeCounts +} + +type changeCounts struct { + Create int + Update int + Noop int + Delete int + StaleSkip int +} + +func (c changeCounts) Add(other changeCounts) changeCounts { + return changeCounts{ + Create: c.Create + other.Create, + Update: c.Update + other.Update, + Noop: c.Noop + other.Noop, + Delete: c.Delete + other.Delete, + StaleSkip: c.StaleSkip + other.StaleSkip, + } +} + +func (c changeCounts) Mutating() int { + return c.Create + c.Update + c.Delete +} + +type stringList []string + +func (s *stringList) String() string { + return strings.Join(*s, ",") +} + +func (s *stringList) Set(value string) error { + for part := range strings.SplitSeq(value, ",") { + part = strings.TrimSpace(part) + if part != "" { + *s = append(*s, part) + } + } + return nil +} + +func main() { + ctx := context.Background() + opts := parseFlags() + if err := run(ctx, opts); err != nil { + fmt.Fprintf(os.Stderr, "workos-backfill: %v%s\n", err, cloudSQLRunHint(err, opts)) + os.Exit(1) + } +} + +func parseFlags() options { + opts := options{ + phase: phasePreflight, + environment: envLocal, + databaseURL: strings.TrimSpace(os.Getenv("GRAM_DATABASE_URL")), + cloudSQLProxy: false, + cloudSQLPort: 0, + cloudSQLDBName: "gram", + workosAPIKey: strings.TrimSpace(firstNonEmpty(os.Getenv("WORKOS_API_KEY"), os.Getenv("WORK_OS_SECRET_KEY"))), + workosEndpoint: strings.TrimSpace(os.Getenv("WORKOS_API_URL")), + workosOrgIDs: nil, + limit: 0, + pageSize: 0, + pageOffset: 0, + statementTimeout: defaultStatementTimeout, + dryRun: true, + autoApprove: false, + pauseAfterEach: false, + confirmProd: "", + breakpointBefore: false, + } + + var rawPhase string + var rawEnv string + var orgIDs stringList + flag.StringVar(&rawPhase, "phase", string(opts.phase), "phase to run: preflight, global-roles, organizations, validate, all") + flag.StringVar(&rawEnv, "environment", string(opts.environment), "target environment: local, dev, prod") + flag.StringVar(&opts.databaseURL, "database-url", opts.databaseURL, "Postgres connection URL (defaults to GRAM_DATABASE_URL)") + flag.BoolVar(&opts.cloudSQLProxy, "cloudsql-proxy", opts.cloudSQLProxy, "start a local Cloud SQL proxy and connect through it") + flag.IntVar(&opts.cloudSQLPort, "cloudsql-port", opts.cloudSQLPort, "local Cloud SQL proxy port (defaults to a free port)") + flag.StringVar(&opts.cloudSQLDBName, "cloudsql-db-name", opts.cloudSQLDBName, "Cloud SQL database name") + flag.StringVar(&opts.workosAPIKey, "workos-api-key", opts.workosAPIKey, "WorkOS API key (defaults to WORKOS_API_KEY or WORK_OS_SECRET_KEY)") + flag.StringVar(&opts.workosEndpoint, "workos-endpoint", opts.workosEndpoint, "WorkOS API endpoint override (defaults to WORKOS_API_URL)") + flag.Var(&orgIDs, "workos-org-id", "WorkOS organization id to process; repeat or comma-separate") + flag.IntVar(&opts.limit, "limit", opts.limit, "maximum organizations to inspect or backfill") + flag.IntVar(&opts.pageSize, "page-size", opts.pageSize, "number of organizations to inspect or backfill after page offset (0 means all remaining)") + flag.IntVar(&opts.pageOffset, "page-offset", opts.pageOffset, "number of organizations to skip after deterministic sorting") + flag.DurationVar(&opts.statementTimeout, "statement-timeout", opts.statementTimeout, "Postgres statement_timeout for each DB connection") + flag.BoolVar(&opts.dryRun, "dry-run", opts.dryRun, "inspect and validate without DB writes") + flag.BoolVar(&opts.autoApprove, "auto-approve", opts.autoApprove, "skip non-prod write confirmations") + flag.BoolVar(&opts.pauseAfterEach, "pause-after-each", opts.pauseAfterEach, "wait for Enter after each organization") + flag.BoolVar(&opts.breakpointBefore, "breakpoint-before-write", opts.breakpointBefore, "wait for Enter after preflight before writes") + flag.StringVar(&opts.confirmProd, "confirm-prod", opts.confirmProd, "must be set to prod for non-interactive prod access") + flag.Parse() + + opts.phase = phase(rawPhase) + opts.environment = environment(rawEnv) + opts.workosOrgIDs = orgIDs + must(validateOptions(opts)) + return opts +} + +func validateOptions(opts options) error { + switch opts.phase { + case phasePreflight, phaseGlobalRoles, phaseOrganizations, phaseValidate, phaseAll: + default: + return fmt.Errorf("invalid phase %q", opts.phase) + } + + switch opts.environment { + case envLocal, envDev, envProd: + default: + return fmt.Errorf("invalid environment %q", opts.environment) + } + + if opts.cloudSQLProxy && opts.environment == envLocal { + return errors.New("--cloudsql-proxy requires --environment=dev or --environment=prod") + } + if opts.cloudSQLPort < 0 || opts.cloudSQLPort > 65535 { + return errors.New("--cloudsql-port must be between 0 and 65535") + } + if strings.TrimSpace(opts.cloudSQLDBName) == "" { + return errors.New("--cloudsql-db-name must be non-empty") + } + if opts.databaseURL == "" && !opts.cloudSQLProxy { + return errors.New("--database-url or GRAM_DATABASE_URL is required") + } + if opts.workosAPIKey == "" { + return errors.New("--workos-api-key, WORKOS_API_KEY, or WORK_OS_SECRET_KEY is required") + } + if opts.limit < 0 { + return errors.New("--limit must be non-negative") + } + if opts.pageSize < 0 { + return errors.New("--page-size must be non-negative") + } + if opts.pageOffset < 0 { + return errors.New("--page-offset must be non-negative") + } + if opts.statementTimeout <= 0 { + return errors.New("--statement-timeout must be positive") + } + if opts.workosEndpoint == "" { + if opts.environment == envProd { + if strings.HasPrefix(opts.workosAPIKey, "sk_test_") || !strings.HasPrefix(opts.workosAPIKey, "sk_") { + return errors.New("prod WorkOS key must be live and start with sk_, not sk_test_") + } + } else if !strings.HasPrefix(opts.workosAPIKey, "sk_test_") { + return fmt.Errorf("%s WorkOS key must start with sk_test_ when using the real WorkOS endpoint", opts.environment) + } + } + + return nil +} + +func run(ctx context.Context, opts options) error { + logger := slog.New(o11y.NewLogHandler(&o11y.LogHandlerOptions{ + RawLevel: "info", + Pretty: true, + DataDogAttr: false, + })) + + if opts.environment == envProd { + if err := confirmProdAccess(opts); err != nil { + return err + } + } + + readOnly := opts.dryRun || opts.phase == phasePreflight || opts.phase == phaseValidate + databaseURL := opts.databaseURL + var cleanupCloudSQLProxy func() + if opts.cloudSQLProxy { + var err error + databaseURL, cleanupCloudSQLProxy, err = startCloudSQLProxy(ctx, opts, readOnly) + if err != nil { + return err + } + defer cleanupCloudSQLProxy() + } + + db, err := connectDB(ctx, databaseURL, readOnly, opts.statementTimeout) + if err != nil { + if opts.cloudSQLProxy { + return fmt.Errorf("%w%s", err, cloudSQLProxyHint()) + } + return err + } + defer db.Close() + + workosClient, err := newWorkOSClient(opts) + if err != nil { + return err + } + + fmt.Printf("WorkOS backfill phase=%s environment=%s dry_run=%t read_only_db=%t\n", opts.phase, opts.environment, opts.dryRun, readOnly) + fmt.Printf("Database statement_timeout: %s\n", opts.statementTimeout) + if opts.pageOffset > 0 || opts.pageSize > 0 { + fmt.Printf("Organization page: offset=%d size=%d\n", opts.pageOffset, opts.pageSize) + } + if opts.cloudSQLProxy { + fmt.Println("Database connection: local Cloud SQL proxy") + } + if opts.workosEndpoint != "" { + fmt.Printf("WorkOS endpoint override: %s\n", opts.workosEndpoint) + } + + var success = true + if opts.phase == phaseGlobalRoles || opts.phase == phaseValidate || opts.phase == phaseAll || opts.phase == phasePreflight { + globalRoles, err := workosClient.ListGlobalRoles(ctx) + if err != nil { + return fmt.Errorf("list WorkOS global roles: %w", err) + } + globalRoleChanges, err := classifyGlobalRoleChanges(ctx, db, globalRoles) + if err != nil { + return err + } + globalRoleDetails, err := collectGlobalRoleChangeDetails(ctx, db, globalRoles) + if err != nil { + return err + } + printGlobalRolePlan(globalRoles, globalRoleChanges, globalRoleDetails) + + if opts.phase == phaseGlobalRoles || opts.phase == phaseAll { + if opts.dryRun { + fmt.Println("Dry-run enabled: global role backfill writes skipped.") + } else if globalRoleChanges.Mutating() == 0 { + fmt.Println("Global role backfill skipped: no planned row changes.") + } else { + if err := confirmWrite(opts, fmt.Sprintf("global role changes: create=%d update=%d delete=%d noop=%d stale_skip=%d", + globalRoleChanges.Create, + globalRoleChanges.Update, + globalRoleChanges.Delete, + globalRoleChanges.Noop, + globalRoleChanges.StaleSkip, + )); err != nil { + return err + } + if opts.breakpointBefore { + waitForEnter("Breakpoint before global role writes. Press Enter to continue.") + } + if err := NewBackfillWorkOSGlobalRoles(logger, db, workosClient).Do(ctx); err != nil { + return fmt.Errorf("backfill WorkOS global roles: %w", err) + } + } + } + + if shouldValidate(opts) && (opts.phase == phaseGlobalRoles || opts.phase == phaseValidate || opts.phase == phaseAll) { + rep := validateGlobalRoles(ctx, db, globalRoles) + printReport("Global role validation complete.", rep) + success = rep.validationFailures == 0 && rep.failed == 0 && success + } else if opts.dryRun && (opts.phase == phaseGlobalRoles || opts.phase == phaseAll) { + fmt.Println("Dry-run enabled: global role validation skipped because writes were not performed.") + } + } + + if opts.phase == phasePreflight || opts.phase == phaseOrganizations || opts.phase == phaseValidate || opts.phase == phaseAll { + orgs, err := buildOrganizationPlan(ctx, db, workosClient, opts) + if err != nil { + return err + } + printOrganizationPlan(orgs) + + if opts.phase == phasePreflight { + return nil + } + if opts.phase == phaseOrganizations || opts.phase == phaseAll { + if opts.dryRun { + fmt.Println("Dry-run enabled: organization backfill writes skipped.") + } else { + if plannedOrganizationMutations(orgs) > 0 { + if err := confirmWrite(opts, organizationSummary(orgs)); err != nil { + return err + } + if opts.breakpointBefore { + waitForEnter("Breakpoint before organization writes. Press Enter to continue.") + } + } + rep := runOrganizationBackfill(ctx, logger, db, workosClient, opts, orgs) + printReport("Organization backfill complete.", rep) + success = rep.failed == 0 && rep.validationFailures == 0 && success + } + } + if shouldValidate(opts) && (opts.phase == phaseValidate || opts.phase == phaseAll) { + rep := validateOrganizations(ctx, db, orgs) + printReport("Organization validation complete.", rep) + success = rep.validationFailures == 0 && rep.failed == 0 && success + } else if opts.dryRun && (opts.phase == phaseOrganizations || opts.phase == phaseAll) { + fmt.Println("Dry-run enabled: organization validation skipped because writes were not performed.") + } + } + + if !success { + return errors.New("backfill completed with failures") + } + return nil +} + +func shouldValidate(opts options) bool { + return opts.phase == phaseValidate || !opts.dryRun +} + +func connectDB(ctx context.Context, databaseURL string, readOnly bool, statementTimeout time.Duration) (*pgxpool.Pool, error) { + cfg, err := pgxpool.ParseConfig(databaseURL) + if err != nil { + return nil, fmt.Errorf("parse database URL: %w", err) + } + statementTimeoutMs := max(1, statementTimeout.Milliseconds()) + cfg.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + if _, err := conn.Exec(ctx, "SET lock_timeout = '5s'"); err != nil { + return fmt.Errorf("set lock_timeout: %w", err) + } + if _, err := conn.Exec(ctx, fmt.Sprintf("SET statement_timeout = %d", statementTimeoutMs)); err != nil { + return fmt.Errorf("set statement_timeout: %w", err) + } + if readOnly { + if _, err := conn.Exec(ctx, "SET default_transaction_read_only = on"); err != nil { + return fmt.Errorf("set default_transaction_read_only: %w", err) + } + } + return nil + } + + db, err := pgxpool.NewWithConfig(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("connect database: %w", err) + } + if err := db.Ping(ctx); err != nil { + db.Close() + return nil, fmt.Errorf("ping database: %w", err) + } + return db, nil +} + +func retryTransientDBDisconnect(ctx context.Context, label string, fn func() error) error { + const attempts = 3 + var err error + for attempt := 1; attempt <= attempts; attempt++ { + err = fn() + if err == nil { + return nil + } + if !isTransientDBDisconnect(err) || attempt == attempts { + return err + } + + delay := time.Duration(attempt) * 500 * time.Millisecond + fmt.Fprintf(os.Stderr, "WARN transient database disconnect during %s; retrying in %s: %v\n", label, delay, err) + timer := time.NewTimer(delay) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + return fmt.Errorf("%s: %w", label, ctx.Err()) + case <-timer.C: + } + } + return err +} + +func isTransientDBDisconnect(err error) bool { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "57P01", "57P02", "57P03", "08000", "08003", "08006", "08007", "08P01": + return true + default: + return false + } + } + return strings.Contains(err.Error(), "SQLSTATE 57P01") || + strings.Contains(err.Error(), "terminating connection due to administrator command") || + strings.Contains(err.Error(), "conn closed") +} + +func newWorkOSClient(opts options) (*workos.Client, error) { + tracerProvider := noop.NewTracerProvider() + policy := guardian.NewDefaultPolicy(tracerProvider) + if opts.workosEndpoint != "" { + unsafePolicy, err := guardian.NewUnsafePolicy(tracerProvider, nil) + if err != nil { + return nil, fmt.Errorf("create unsafe guardian policy: %w", err) + } + policy = unsafePolicy + } + + return workos.NewClient(policy, opts.workosAPIKey, workos.ClientOpts{ + Endpoint: opts.workosEndpoint, + HTTPClient: nil, + }), nil +} + +func buildOrganizationPlan(ctx context.Context, db *pgxpool.Pool, workosClient *workos.Client, opts options) ([]orgExpectation, error) { + workosOrgs, err := selectedOrganizations(ctx, workosClient, opts) + if err != nil { + return nil, err + } + fmt.Printf("Planning organization backfill for %d WorkOS organizations\n", len(workosOrgs)) + + out := make([]orgExpectation, 0, len(workosOrgs)) + for i, org := range workosOrgs { + fmt.Printf("[%d/%d] plan %s name=%q\n", i+1, len(workosOrgs), org.ID, org.Name) + expectation, err := planOrganizationWithRetry(ctx, db, workosClient, org) + if err != nil { + return nil, err + } + out = append(out, expectation) + } + + return out, nil +} + +func planOrganizationWithRetry(ctx context.Context, db *pgxpool.Pool, workosClient *workos.Client, org workos.Organization) (orgExpectation, error) { + var expectation orgExpectation + err := retryTransientDBDisconnect(ctx, fmt.Sprintf("plan organization %s", org.ID), func() error { + next, err := planOrganization(ctx, db, workosClient, org) + if err != nil { + return err + } + expectation = next + return nil + }) + return expectation, err +} + +func planOrganization(ctx context.Context, db *pgxpool.Pool, workosClient *workos.Client, org workos.Organization) (orgExpectation, error) { + var zero orgExpectation + gramOrgID, skipped, err := expectedGramOrgID(ctx, db, org) + if err != nil { + return zero, err + } + roles, err := workosClient.ListRoles(ctx, org.ID) + if err != nil { + return zero, fmt.Errorf("list roles for %s: %w", org.ID, err) + } + users, err := workosClient.ListOrgUsers(ctx, org.ID) + if err != nil { + return zero, fmt.Errorf("list users for %s: %w", org.ID, err) + } + members, err := workosClient.ListOrgMemberships(ctx, org.ID) + if err != nil { + return zero, fmt.Errorf("list memberships for %s: %w", org.ID, err) + } + + orgChanges, err := classifyOrganizationMetadataChange(ctx, db, org, gramOrgID, skipped) + if err != nil { + return zero, err + } + roleChanges, err := classifyOrganizationRoleChanges(ctx, db, gramOrgID, skipped, roles) + if err != nil { + return zero, err + } + userChanges, err := classifyUserChanges(ctx, db, skipped, users) + if err != nil { + return zero, err + } + membershipChanges, err := classifyMembershipChanges(ctx, db, gramOrgID, skipped, users, members) + if err != nil { + return zero, err + } + assignmentChanges, err := classifyAssignmentChanges(ctx, db, gramOrgID, skipped, roles, users, members) + if err != nil { + return zero, err + } + changeDetails, err := collectOrganizationChangeDetails(ctx, db, org, gramOrgID, skipped, roles, users, members) + if err != nil { + return zero, err + } + + return orgExpectation{ + workosOrgID: org.ID, + gramOrgID: gramOrgID, + name: org.Name, + skipped: skipped, + roles: roles, + users: users, + members: members, + orgChanges: orgChanges, + roleChanges: roleChanges, + userChanges: userChanges, + membershipChanges: membershipChanges, + assignmentChanges: assignmentChanges, + changeDetails: changeDetails, + }, nil +} + +func selectedOrganizations(ctx context.Context, workosClient *workos.Client, opts options) ([]workos.Organization, error) { + if len(opts.workosOrgIDs) > 0 { + fmt.Printf("Loading %d selected WorkOS organizations\n", len(opts.workosOrgIDs)) + out := make([]workos.Organization, 0, len(opts.workosOrgIDs)) + for i, orgID := range opts.workosOrgIDs { + fmt.Printf("[%d/%d] get WorkOS organization %s\n", i+1, len(opts.workosOrgIDs), orgID) + org, err := workosClient.GetOrganization(ctx, orgID) + if err != nil { + return nil, fmt.Errorf("get WorkOS organization %s: %w", orgID, err) + } + out = append(out, *org) + } + return applyOrganizationWindow(out, opts), nil + } + + fmt.Println("Listing WorkOS organizations") + orgs, err := workosClient.ListOrganizations(ctx) + if err != nil { + return nil, fmt.Errorf("list WorkOS organizations: %w", err) + } + fmt.Printf("Listed %d WorkOS organizations\n", len(orgs)) + sort.Slice(orgs, func(i, j int) bool { return orgs[i].ID < orgs[j].ID }) + return applyOrganizationWindow(orgs, opts), nil +} + +func applyOrganizationWindow(orgs []workos.Organization, opts options) []workos.Organization { + originalLen := len(orgs) + if opts.pageOffset > 0 { + if opts.pageOffset >= len(orgs) { + fmt.Printf("Applying organization page offset: %d of %d leaves 0 organizations\n", opts.pageOffset, originalLen) + return orgs[:0] + } + orgs = orgs[opts.pageOffset:] + fmt.Printf("Applying organization page offset: skipped %d of %d\n", opts.pageOffset, originalLen) + } + if opts.pageSize > 0 && len(orgs) > opts.pageSize { + fmt.Printf("Applying organization page size: %d of %d remaining\n", opts.pageSize, len(orgs)) + orgs = orgs[:opts.pageSize] + } + if opts.limit > 0 && len(orgs) > opts.limit { + fmt.Printf("Applying organization limit: %d of %d\n", opts.limit, len(orgs)) + orgs = orgs[:opts.limit] + } + return orgs +} + +func expectedGramOrgID(ctx context.Context, db *pgxpool.Pool, org workos.Organization) (string, bool, error) { + var id string + err := db.QueryRow(ctx, "SELECT id FROM organization_metadata WHERE workos_id = $1 LIMIT 1", org.ID).Scan(&id) + switch { + case err == nil: + return id, false, nil + case errors.Is(err, pgx.ErrNoRows): + if org.ExternalID == "" { + return "", true, nil + } + return org.ExternalID, false, nil + default: + return "", false, fmt.Errorf("lookup local organization by workos id %s: %w", org.ID, err) + } +} + +func classifyGlobalRoleChanges(ctx context.Context, db *pgxpool.Pool, roles []workos.Role) (changeCounts, error) { + counts := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + snapshotSlugs := make(map[string]struct{}, len(roles)) + + for _, role := range roles { + snapshotSlugs[role.Slug] = struct{}{} + + updatedAt, err := parseWorkOSTime(role.UpdatedAt) + if err != nil { + return changeCounts{}, fmt.Errorf("parse global role %q updated_at: %w", role.Slug, err) + } + + change, err := classifyRoleRow(ctx, db, "global_roles", "TRUE", nil, role, updatedAt) + if err != nil { + return changeCounts{}, fmt.Errorf("classify global role %q: %w", role.Slug, err) + } + counts = addChange(counts, change) + } + + deleteCounts, err := classifyMissingGlobalRoleDeletes(ctx, db, snapshotSlugs) + if err != nil { + return changeCounts{}, err + } + return counts.Add(deleteCounts), nil +} + +func classifyMissingGlobalRoleDeletes(ctx context.Context, db *pgxpool.Pool, snapshotSlugs map[string]struct{}) (changeCounts, error) { + counts := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + rows, err := db.Query(ctx, ` +SELECT workos_slug, workos_updated_at, workos_last_event_id +FROM global_roles +WHERE deleted_at IS NULL`) + if err != nil { + return changeCounts{}, fmt.Errorf("query local global roles: %w", err) + } + defer rows.Close() + + now := time.Now().UTC() + for rows.Next() { + var slug string + var updatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + if err := rows.Scan(&slug, &updatedAt, &lastEventID); err != nil { + return changeCounts{}, fmt.Errorf("scan local global role: %w", err) + } + if _, ok := snapshotSlugs[slug]; ok { + continue + } + if shouldProcessEvent(textPtr(lastEventID), timePtr(updatedAt), "", now) { + counts.Delete++ + } else { + counts.StaleSkip++ + } + } + if err := rows.Err(); err != nil { + return changeCounts{}, fmt.Errorf("iterate local global roles: %w", err) + } + return counts, nil +} + +func classifyOrganizationMetadataChange(ctx context.Context, db *pgxpool.Pool, org workos.Organization, gramOrgID string, skipped bool) (changeCounts, error) { + counts := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + if skipped { + counts.StaleSkip = 1 + return counts, nil + } + + updatedAt, err := parseWorkOSTime(org.UpdatedAt) + if err != nil { + return changeCounts{}, fmt.Errorf("parse organization %q updated_at: %w", org.ID, err) + } + + var existingID string + var name string + var slug string + var workosID pgtype.Text + var rowUpdatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + var disabledAt pgtype.Timestamptz + err = db.QueryRow(ctx, ` +SELECT id, name, slug, workos_id, workos_updated_at, workos_last_event_id, disabled_at +FROM organization_metadata +WHERE workos_id = $1 OR id = $2 +ORDER BY CASE WHEN workos_id = $1 THEN 0 ELSE 1 END +LIMIT 1`, org.ID, gramOrgID).Scan(&existingID, &name, &slug, &workosID, &rowUpdatedAt, &lastEventID, &disabledAt) + switch { + case errors.Is(err, pgx.ErrNoRows): + counts.Create = 1 + return counts, nil + case err != nil: + return changeCounts{}, fmt.Errorf("query organization metadata %q: %w", org.ID, err) + } + + if !shouldProcessEvent(textPtr(lastEventID), timePtr(rowUpdatedAt), "", updatedAt) { + counts.StaleSkip = 1 + return counts, nil + } + + if existingID == gramOrgID && + name == org.Name && + slug != "" && + workosID.Valid && + workosID.String == org.ID && + !disabledAt.Valid { + counts.Noop = 1 + return counts, nil + } + + counts.Update = 1 + return counts, nil +} + +func classifyOrganizationRoleChanges(ctx context.Context, db *pgxpool.Pool, organizationID string, skipped bool, roles []workos.Role) (changeCounts, error) { + counts := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + if skipped { + return counts, nil + } + + snapshotSlugs := map[string]struct{}{} + for _, role := range roles { + if role.Type != "OrganizationRole" { + continue + } + snapshotSlugs[role.Slug] = struct{}{} + + updatedAt, err := parseWorkOSTime(role.UpdatedAt) + if err != nil { + return changeCounts{}, fmt.Errorf("parse organization role %q updated_at: %w", role.Slug, err) + } + change, err := classifyRoleRow(ctx, db, "organization_roles", "organization_id = $1", []any{organizationID}, role, updatedAt) + if err != nil { + return changeCounts{}, fmt.Errorf("classify organization role %q: %w", role.Slug, err) + } + counts = addChange(counts, change) + } + + deleteCounts, err := classifyMissingOrganizationRoleDeletes(ctx, db, organizationID, snapshotSlugs) + if err != nil { + return changeCounts{}, err + } + return counts.Add(deleteCounts), nil +} + +func classifyMissingOrganizationRoleDeletes(ctx context.Context, db *pgxpool.Pool, organizationID string, snapshotSlugs map[string]struct{}) (changeCounts, error) { + counts := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + rows, err := db.Query(ctx, ` +SELECT workos_slug, workos_updated_at, workos_last_event_id +FROM organization_roles +WHERE organization_id = $1 + AND deleted_at IS NULL`, organizationID) + if err != nil { + return changeCounts{}, fmt.Errorf("query local organization roles: %w", err) + } + defer rows.Close() + + now := time.Now().UTC() + for rows.Next() { + var slug string + var updatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + if err := rows.Scan(&slug, &updatedAt, &lastEventID); err != nil { + return changeCounts{}, fmt.Errorf("scan local organization role: %w", err) + } + if _, ok := snapshotSlugs[slug]; ok { + continue + } + if shouldProcessEvent(textPtr(lastEventID), timePtr(updatedAt), "", now) { + counts.Delete++ + } else { + counts.StaleSkip++ + } + } + if err := rows.Err(); err != nil { + return changeCounts{}, fmt.Errorf("iterate local organization roles: %w", err) + } + return counts, nil +} + +func classifyRoleRow(ctx context.Context, db *pgxpool.Pool, table, predicate string, args []any, role workos.Role, updatedAt time.Time) (string, error) { + queryArgs := append([]any{}, args...) + queryArgs = append(queryArgs, role.Slug) + slugArg := len(queryArgs) + query := fmt.Sprintf(` +SELECT workos_name, workos_description, workos_updated_at, workos_last_event_id, workos_deleted, deleted +FROM %s +WHERE %s + AND workos_slug = $%d +LIMIT 1`, table, predicate, slugArg) // #nosec G201 -- table and predicate are fixed call-site constants. + + var name string + var description pgtype.Text + var rowUpdatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + var workosDeleted bool + var deleted bool + err := db.QueryRow(ctx, query, queryArgs...).Scan(&name, &description, &rowUpdatedAt, &lastEventID, &workosDeleted, &deleted) + switch { + case errors.Is(err, pgx.ErrNoRows): + return "create", nil + case err != nil: + return "", fmt.Errorf("query local role: %w", err) + } + + if !shouldProcessEvent(textPtr(lastEventID), timePtr(rowUpdatedAt), "", updatedAt) { + return "stale_skip", nil + } + if deleted || workosDeleted || name != role.Name || !pgTextEmptyEqual(description, role.Description) || !pgTimeEqual(rowUpdatedAt, updatedAt) { + return "update", nil + } + return "noop", nil +} + +func classifyUserChanges(ctx context.Context, db *pgxpool.Pool, skipped bool, users map[string]workos.User) (changeCounts, error) { + counts := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + if skipped { + return counts, nil + } + for _, user := range users { + updatedAt, err := parseWorkOSTime(user.UpdatedAt) + if err != nil { + counts.StaleSkip++ + continue + } + change, err := classifyUserRow(ctx, db, user, updatedAt) + if err != nil { + return changeCounts{}, err + } + counts = addChange(counts, change) + } + return counts, nil +} + +func classifyUserRow(ctx context.Context, db *pgxpool.Pool, user workos.User, updatedAt time.Time) (string, error) { + existing, found, err := findUserByWorkOSID(ctx, db, user.ID) + if err != nil { + return "", err + } + if !found { + if user.ExternalID == "" { + return "stale_skip", nil + } + existing, found, err = findUserByID(ctx, db, user.ExternalID) + if err != nil { + return "", err + } + if !found { + return "create", nil + } + if existing.WorkosID.Valid && existing.WorkosID.String != user.ID { + return "", fmt.Errorf("local user %q is already linked to different WorkOS user %q", existing.ID, existing.WorkosID.String) + } + } + if existing.WorkosUpdatedAt.Valid && !shouldProcessEvent(nil, &existing.WorkosUpdatedAt.Time, "", updatedAt) { + return "stale_skip", nil + } + if existing.Email == user.Email && + existing.DisplayName == displayNameFromWorkOSUser(user) && + pgTextEmptyEqual(existing.PhotoUrl, user.ProfilePictureURL) && + existing.WorkosID.Valid && + existing.WorkosID.String == user.ID && + pgTimeEqual(existing.WorkosUpdatedAt, updatedAt) && + !existing.DeletedAt.Valid && + !existing.WorkosDeletedAt.Valid { + return "noop", nil + } + return "update", nil +} + +func classifyMembershipChanges(ctx context.Context, db *pgxpool.Pool, organizationID string, skipped bool, users map[string]workos.User, members []workos.Member) (changeCounts, error) { + counts := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + if skipped { + return counts, nil + } + for _, member := range members { + gramUserID, resolved, err := expectedGramUserID(ctx, db, users[member.UserID]) + if err != nil { + return changeCounts{}, err + } + if !resolved { + counts.StaleSkip++ + continue + } + updatedAt, err := parseWorkOSTime(member.UpdatedAt) + if err != nil { + counts.StaleSkip++ + continue + } + change, err := classifyMembershipRow(ctx, db, organizationID, member, gramUserID, updatedAt) + if err != nil { + return changeCounts{}, err + } + counts = addChange(counts, change) + } + return counts, nil +} + +func classifyMembershipRow(ctx context.Context, db *pgxpool.Pool, organizationID string, member workos.Member, gramUserID string, updatedAt time.Time) (string, error) { + var userID pgtype.Text + var workosUserID pgtype.Text + var rowUpdatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + var deletedAt pgtype.Timestamptz + err := db.QueryRow(ctx, ` +SELECT user_id, workos_user_id, workos_updated_at, workos_last_event_id, deleted_at +FROM organization_user_relationships +WHERE organization_id = $1 + AND workos_membership_id = $2 +ORDER BY updated_at DESC +LIMIT 1`, organizationID, member.ID).Scan(&userID, &workosUserID, &rowUpdatedAt, &lastEventID, &deletedAt) + switch { + case errors.Is(err, pgx.ErrNoRows): + missingIDChange, err := classifyMissingMembershipIDRepair(ctx, db, organizationID, member, gramUserID, updatedAt) + if err != nil { + return "", err + } + return missingIDChange, nil + case err != nil: + return "", fmt.Errorf("query local membership %q: %w", member.ID, err) + } + if membershipNeedsMissingFieldRepair(userID, workosUserID, rowUpdatedAt, deletedAt, gramUserID, member) { + return "update", nil + } + if !shouldProcessEvent(textPtr(lastEventID), timePtr(rowUpdatedAt), "", updatedAt) { + return "stale_skip", nil + } + if deletedAt.Valid || + !userID.Valid || + userID.String != gramUserID || + !workosUserID.Valid || + workosUserID.String != member.UserID || + !pgTimeEqual(rowUpdatedAt, updatedAt) { + return "update", nil + } + return "noop", nil +} + +func classifyMissingMembershipIDRepair(ctx context.Context, db *pgxpool.Pool, organizationID string, member workos.Member, gramUserID string, updatedAt time.Time) (string, error) { + var workosMembershipID pgtype.Text + var workosUserID pgtype.Text + var rowUpdatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + var deletedAt pgtype.Timestamptz + err := db.QueryRow(ctx, ` +SELECT workos_membership_id, workos_user_id, workos_updated_at, workos_last_event_id, deleted_at +FROM organization_user_relationships +WHERE organization_id = $1 + AND user_id = $2 +ORDER BY updated_at DESC +LIMIT 1`, organizationID, conv.ToPGText(gramUserID)).Scan(&workosMembershipID, &workosUserID, &rowUpdatedAt, &lastEventID, &deletedAt) + switch { + case errors.Is(err, pgx.ErrNoRows): + return "create", nil + case err != nil: + return "", fmt.Errorf("query local membership by user %q: %w", member.ID, err) + } + if deletedAt.Valid { + return "update", nil + } + if !workosMembershipID.Valid { + return "update", nil + } + if !shouldProcessEvent(textPtr(lastEventID), timePtr(rowUpdatedAt), "", updatedAt) { + return "stale_skip", nil + } + if workosMembershipID.String != member.ID || + !workosUserID.Valid || + workosUserID.String != member.UserID || + !pgTimeEqual(rowUpdatedAt, updatedAt) { + return "update", nil + } + return "noop", nil +} + +func classifyAssignmentChanges(ctx context.Context, db *pgxpool.Pool, organizationID string, skipped bool, roles []workos.Role, users map[string]workos.User, members []workos.Member) (changeCounts, error) { + counts := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + if skipped { + return counts, nil + } + for _, member := range members { + gramUserID, resolved, err := expectedGramUserID(ctx, db, users[member.UserID]) + if err != nil { + return changeCounts{}, err + } + if !resolved { + counts.StaleSkip++ + continue + } + if member.RoleSlug != "" { + available, err := plannedAssignmentRoleAvailable(ctx, db, organizationID, roles, member.RoleSlug) + if err != nil { + return changeCounts{}, err + } + if !available { + counts.StaleSkip++ + continue + } + } + activeAssignments, missingUserAssignments, err := countActiveAssignments(ctx, db, organizationID, member.ID, gramUserID) + if err != nil { + return changeCounts{}, err + } + if member.RoleSlug == "" { + if activeAssignments > 0 { + counts.Delete += activeAssignments + } else { + counts.Noop++ + } + continue + } + if activeAssignments == 0 { + counts.Create++ + } else if missingUserAssignments > 0 { + counts.Update += missingUserAssignments + counts.Noop += activeAssignments - missingUserAssignments + } else { + counts.Noop += activeAssignments + } + } + return counts, nil +} + +func countActiveAssignments(ctx context.Context, db *pgxpool.Pool, organizationID, membershipID, gramUserID string) (int, int, error) { + var count int + var missingUserCount int + if err := db.QueryRow(ctx, ` +SELECT count(*)::int +FROM organization_role_assignments +WHERE organization_id = $1 + AND workos_membership_id = $2 + AND deleted_at IS NULL`, organizationID, membershipID).Scan(&count); err != nil { + return 0, 0, fmt.Errorf("count active role assignments for membership %q: %w", membershipID, err) + } + if err := db.QueryRow(ctx, ` +SELECT count(*)::int +FROM organization_role_assignments +WHERE organization_id = $1 + AND workos_membership_id = $2 + AND deleted_at IS NULL + AND (user_id IS NULL OR user_id <> $3)`, organizationID, membershipID, gramUserID).Scan(&missingUserCount); err != nil { + return 0, 0, fmt.Errorf("count role assignments missing user for membership %q: %w", membershipID, err) + } + return count, missingUserCount, nil +} + +func plannedAssignmentRoleAvailable(ctx context.Context, db *pgxpool.Pool, organizationID string, roles []workos.Role, roleSlug string) (bool, error) { + for _, role := range roles { + if role.Slug != roleSlug || role.Type != "OrganizationRole" { + continue + } + updatedAt, err := parseWorkOSTime(role.UpdatedAt) + if err != nil { + return false, fmt.Errorf("parse organization role %q updated_at: %w", role.Slug, err) + } + change, err := classifyRoleRow(ctx, db, "organization_roles", "organization_id = $1", []any{organizationID}, role, updatedAt) + if err != nil { + return false, fmt.Errorf("classify organization role %q: %w", role.Slug, err) + } + if change == "create" || change == "update" || change == "noop" { + return true, nil + } + return activeAssignmentRoleExists(ctx, db, organizationID, roleSlug) + } + + return activeGlobalRoleExists(ctx, db, roleSlug) +} + +func activeAssignmentRoleExists(ctx context.Context, db queryRower, organizationID string, roleSlug string) (bool, error) { + var exists bool + err := db.QueryRow(ctx, ` +SELECT EXISTS ( + SELECT 1 + FROM organization_roles + WHERE organization_id = $1 + AND workos_slug = $2 + AND deleted IS FALSE + AND workos_deleted IS FALSE +) OR EXISTS ( + SELECT 1 + FROM global_roles + WHERE workos_slug = $2 + AND deleted IS FALSE + AND workos_deleted IS FALSE +)`, organizationID, roleSlug).Scan(&exists) + if err != nil { + return false, fmt.Errorf("check active role for assignment slug %q: %w", roleSlug, err) + } + return exists, nil +} + +func activeGlobalRoleExists(ctx context.Context, db queryRower, roleSlug string) (bool, error) { + var exists bool + err := db.QueryRow(ctx, ` +SELECT EXISTS ( + SELECT 1 + FROM global_roles + WHERE workos_slug = $1 + AND deleted IS FALSE + AND workos_deleted IS FALSE +)`, roleSlug).Scan(&exists) + if err != nil { + return false, fmt.Errorf("check active global role slug %q: %w", roleSlug, err) + } + return exists, nil +} + +func addChange(counts changeCounts, change string) changeCounts { + switch change { + case "create": + counts.Create++ + case "update": + counts.Update++ + case "noop": + counts.Noop++ + case "delete": + counts.Delete++ + case "stale_skip": + counts.StaleSkip++ + } + return counts +} + +func collectGlobalRoleChangeDetails(ctx context.Context, db *pgxpool.Pool, roles []workos.Role) ([]changeDetail, error) { + details := make([]changeDetail, 0) + snapshotSlugs := make(map[string]struct{}, len(roles)) + for _, role := range roles { + snapshotSlugs[role.Slug] = struct{}{} + updatedAt, err := parseWorkOSTime(role.UpdatedAt) + if err != nil { + return nil, fmt.Errorf("parse global role %q updated_at: %w", role.Slug, err) + } + + detail, ok, err := collectGlobalRoleUpdateDetail(ctx, db, role, updatedAt) + if err != nil { + return nil, err + } + if ok { + details = append(details, detail) + continue + } + change, err := classifyRoleRow(ctx, db, "global_roles", "TRUE", nil, role, updatedAt) + if err != nil { + return nil, fmt.Errorf("classify global role %q: %w", role.Slug, err) + } + if change == "create" { + details = append(details, roleCreateDetail("global_role", role, updatedAt)) + } + } + + deleteDetails, err := collectMissingGlobalRoleDeleteDetails(ctx, db, snapshotSlugs) + if err != nil { + return nil, err + } + return append(details, deleteDetails...), nil +} + +func collectGlobalRoleUpdateDetail(ctx context.Context, db *pgxpool.Pool, role workos.Role, updatedAt time.Time) (changeDetail, bool, error) { + var name string + var description pgtype.Text + var rowUpdatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + var workosDeleted bool + var deleted bool + err := db.QueryRow(ctx, ` +SELECT workos_name, workos_description, workos_updated_at, workos_last_event_id, workos_deleted, deleted +FROM global_roles +WHERE workos_slug = $1 +LIMIT 1`, role.Slug).Scan(&name, &description, &rowUpdatedAt, &lastEventID, &workosDeleted, &deleted) + switch { + case errors.Is(err, pgx.ErrNoRows): + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + case err != nil: + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, fmt.Errorf("query local global role %q: %w", role.Slug, err) + } + if !shouldProcessEvent(textPtr(lastEventID), timePtr(rowUpdatedAt), "", updatedAt) { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + + fields := make([]fieldChange, 0) + fields = appendFieldChange(fields, "workos_name", name, role.Name) + if !pgTextEmptyEqual(description, role.Description) { + fields = appendFieldChange(fields, "workos_description", pgTextDisplay(description), role.Description) + } + fields = appendFieldChange(fields, "workos_updated_at", pgTimeDisplay(rowUpdatedAt), timeDisplay(updatedAt)) + fields = appendFieldChange(fields, "workos_deleted", boolDisplay(workosDeleted), "false") + fields = appendFieldChange(fields, "deleted", boolDisplay(deleted), "false") + if len(fields) == 0 { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + return changeDetail{Entity: "global_role", ID: role.Slug, Action: "update", Fields: fields}, true, nil +} + +func collectMissingGlobalRoleDeleteDetails(ctx context.Context, db *pgxpool.Pool, snapshotSlugs map[string]struct{}) ([]changeDetail, error) { + rows, err := db.Query(ctx, ` +SELECT workos_slug, workos_updated_at, workos_last_event_id +FROM global_roles +WHERE deleted_at IS NULL`) + if err != nil { + return nil, fmt.Errorf("query local global roles: %w", err) + } + defer rows.Close() + + details := make([]changeDetail, 0) + now := time.Now().UTC() + for rows.Next() { + var slug string + var updatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + if err := rows.Scan(&slug, &updatedAt, &lastEventID); err != nil { + return nil, fmt.Errorf("scan local global role: %w", err) + } + if _, ok := snapshotSlugs[slug]; ok { + continue + } + if shouldProcessEvent(textPtr(lastEventID), timePtr(updatedAt), "", now) { + details = append(details, changeDetail{ + Entity: "global_role", + ID: slug, + Action: "delete", + Fields: []fieldChange{{ + Name: "deleted_at", + Before: "", + After: "now", + }}, + }) + } + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate local global roles: %w", err) + } + return details, nil +} + +func collectOrganizationChangeDetails(ctx context.Context, db *pgxpool.Pool, org workos.Organization, gramOrgID string, skipped bool, roles []workos.Role, users map[string]workos.User, members []workos.Member) ([]changeDetail, error) { + if skipped { + return nil, nil + } + details := make([]changeDetail, 0) + + orgDetail, ok, err := collectOrganizationUpdateDetail(ctx, db, org, gramOrgID) + if err != nil { + return nil, err + } + if ok { + details = append(details, orgDetail) + } else { + orgChanges, err := classifyOrganizationMetadataChange(ctx, db, org, gramOrgID, skipped) + if err != nil { + return nil, err + } + if orgChanges.Create > 0 { + details = append(details, organizationCreateDetail(org, gramOrgID)) + } + } + + roleDetails, err := collectOrganizationRoleChangeDetails(ctx, db, gramOrgID, roles) + if err != nil { + return nil, err + } + details = append(details, roleDetails...) + + userDetails, err := collectUserChangeDetails(ctx, db, users) + if err != nil { + return nil, err + } + details = append(details, userDetails...) + + membershipDetails, err := collectMembershipChangeDetails(ctx, db, gramOrgID, users, members) + if err != nil { + return nil, err + } + details = append(details, membershipDetails...) + + assignmentDetails, err := collectAssignmentChangeDetails(ctx, db, gramOrgID, roles, users, members) + if err != nil { + return nil, err + } + details = append(details, assignmentDetails...) + + return details, nil +} + +func collectOrganizationUpdateDetail(ctx context.Context, db *pgxpool.Pool, org workos.Organization, gramOrgID string) (changeDetail, bool, error) { + updatedAt, err := parseWorkOSTime(org.UpdatedAt) + if err != nil { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, fmt.Errorf("parse organization %q updated_at: %w", org.ID, err) + } + + var existingID string + var name string + var slug string + var workosID pgtype.Text + var rowUpdatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + var disabledAt pgtype.Timestamptz + err = db.QueryRow(ctx, ` +SELECT id, name, slug, workos_id, workos_updated_at, workos_last_event_id, disabled_at +FROM organization_metadata +WHERE workos_id = $1 OR id = $2 +ORDER BY CASE WHEN workos_id = $1 THEN 0 ELSE 1 END +LIMIT 1`, org.ID, gramOrgID).Scan(&existingID, &name, &slug, &workosID, &rowUpdatedAt, &lastEventID, &disabledAt) + switch { + case errors.Is(err, pgx.ErrNoRows): + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + case err != nil: + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, fmt.Errorf("query organization metadata %q: %w", org.ID, err) + } + if !shouldProcessEvent(textPtr(lastEventID), timePtr(rowUpdatedAt), "", updatedAt) { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + + fields := make([]fieldChange, 0) + fields = appendFieldChange(fields, "id", existingID, gramOrgID) + fields = appendFieldChange(fields, "name", name, org.Name) + if slug == "" { + fields = appendFieldChange(fields, "slug", "", "generated unique slug") + } + fields = appendFieldChange(fields, "workos_id", pgTextDisplay(workosID), org.ID) + fields = appendFieldChange(fields, "workos_updated_at", pgTimeDisplay(rowUpdatedAt), timeDisplay(updatedAt)) + if disabledAt.Valid { + fields = appendFieldChange(fields, "disabled_at", pgTimeDisplay(disabledAt), "") + } + if len(fields) == 0 { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + return changeDetail{Entity: "organization", ID: org.ID, Action: "update", Fields: fields}, true, nil +} + +func collectOrganizationRoleChangeDetails(ctx context.Context, db *pgxpool.Pool, organizationID string, roles []workos.Role) ([]changeDetail, error) { + details := make([]changeDetail, 0) + snapshotSlugs := make(map[string]struct{}, len(roles)) + for _, role := range roles { + if role.Type != "OrganizationRole" { + continue + } + snapshotSlugs[role.Slug] = struct{}{} + updatedAt, err := parseWorkOSTime(role.UpdatedAt) + if err != nil { + return nil, fmt.Errorf("parse organization role %q updated_at: %w", role.Slug, err) + } + detail, ok, err := collectOrganizationRoleUpdateDetail(ctx, db, organizationID, role, updatedAt) + if err != nil { + return nil, err + } + if ok { + details = append(details, detail) + continue + } + change, err := classifyRoleRow(ctx, db, "organization_roles", "organization_id = $1", []any{organizationID}, role, updatedAt) + if err != nil { + return nil, fmt.Errorf("classify organization role %q: %w", role.Slug, err) + } + if change == "create" { + details = append(details, roleCreateDetail("organization_role", role, updatedAt)) + } + } + deleteDetails, err := collectMissingOrganizationRoleDeleteDetails(ctx, db, organizationID, snapshotSlugs) + if err != nil { + return nil, err + } + return append(details, deleteDetails...), nil +} + +func collectOrganizationRoleUpdateDetail(ctx context.Context, db *pgxpool.Pool, organizationID string, role workos.Role, updatedAt time.Time) (changeDetail, bool, error) { + var name string + var description pgtype.Text + var rowUpdatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + var workosDeleted bool + var deleted bool + err := db.QueryRow(ctx, ` +SELECT workos_name, workos_description, workos_updated_at, workos_last_event_id, workos_deleted, deleted +FROM organization_roles +WHERE organization_id = $1 + AND workos_slug = $2 +LIMIT 1`, organizationID, role.Slug).Scan(&name, &description, &rowUpdatedAt, &lastEventID, &workosDeleted, &deleted) + switch { + case errors.Is(err, pgx.ErrNoRows): + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + case err != nil: + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, fmt.Errorf("query local organization role %q: %w", role.Slug, err) + } + if !shouldProcessEvent(textPtr(lastEventID), timePtr(rowUpdatedAt), "", updatedAt) { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + + fields := make([]fieldChange, 0) + fields = appendFieldChange(fields, "workos_name", name, role.Name) + if !pgTextEmptyEqual(description, role.Description) { + fields = appendFieldChange(fields, "workos_description", pgTextDisplay(description), role.Description) + } + fields = appendFieldChange(fields, "workos_updated_at", pgTimeDisplay(rowUpdatedAt), timeDisplay(updatedAt)) + fields = appendFieldChange(fields, "workos_deleted", boolDisplay(workosDeleted), "false") + fields = appendFieldChange(fields, "deleted", boolDisplay(deleted), "false") + if len(fields) == 0 { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + return changeDetail{Entity: "organization_role", ID: role.Slug, Action: "update", Fields: fields}, true, nil +} + +func collectMissingOrganizationRoleDeleteDetails(ctx context.Context, db *pgxpool.Pool, organizationID string, snapshotSlugs map[string]struct{}) ([]changeDetail, error) { + rows, err := db.Query(ctx, ` +SELECT workos_slug, workos_updated_at, workos_last_event_id +FROM organization_roles +WHERE organization_id = $1 + AND deleted_at IS NULL`, organizationID) + if err != nil { + return nil, fmt.Errorf("query local organization roles: %w", err) + } + defer rows.Close() + + details := make([]changeDetail, 0) + now := time.Now().UTC() + for rows.Next() { + var slug string + var updatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + if err := rows.Scan(&slug, &updatedAt, &lastEventID); err != nil { + return nil, fmt.Errorf("scan local organization role: %w", err) + } + if _, ok := snapshotSlugs[slug]; ok { + continue + } + if shouldProcessEvent(textPtr(lastEventID), timePtr(updatedAt), "", now) { + details = append(details, changeDetail{ + Entity: "organization_role", + ID: slug, + Action: "delete", + Fields: []fieldChange{{ + Name: "deleted_at", + Before: "", + After: "now", + }}, + }) + } + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate local organization roles: %w", err) + } + return details, nil +} + +func collectUserChangeDetails(ctx context.Context, db *pgxpool.Pool, users map[string]workos.User) ([]changeDetail, error) { + details := make([]changeDetail, 0) + for _, user := range users { + updatedAt, err := parseWorkOSTime(user.UpdatedAt) + if err != nil { + continue + } + createdAt, err := parseWorkOSTime(user.CreatedAt) + if err != nil { + continue + } + existing, found, err := findUserByWorkOSID(ctx, db, user.ID) + if err != nil { + return nil, err + } + if !found && user.ExternalID != "" { + existing, found, err = findUserByID(ctx, db, user.ExternalID) + if err != nil { + return nil, err + } + if found && existing.WorkosID.Valid && existing.WorkosID.String != user.ID { + return nil, fmt.Errorf("local user %q is already linked to different WorkOS user %q", existing.ID, existing.WorkosID.String) + } + } + if !found { + if user.ExternalID != "" { + details = append(details, userCreateDetail(user, user.ExternalID, createdAt, updatedAt)) + } + continue + } + if existing.WorkosUpdatedAt.Valid && !shouldProcessEvent(nil, &existing.WorkosUpdatedAt.Time, "", updatedAt) { + continue + } + fields := make([]fieldChange, 0) + fields = appendFieldChange(fields, "email", existing.Email, user.Email) + fields = appendFieldChange(fields, "display_name", existing.DisplayName, displayNameFromWorkOSUser(user)) + if !pgTextEmptyEqual(existing.PhotoUrl, user.ProfilePictureURL) { + fields = appendFieldChange(fields, "photo_url", pgTextDisplay(existing.PhotoUrl), user.ProfilePictureURL) + } + fields = appendFieldChange(fields, "workos_id", pgTextDisplay(existing.WorkosID), user.ID) + if !existing.WorkosCreatedAt.Valid { + fields = appendFieldChange(fields, "workos_created_at", pgTimeDisplay(existing.WorkosCreatedAt), timeDisplay(createdAt)) + } + fields = appendFieldChange(fields, "workos_updated_at", pgTimeDisplay(existing.WorkosUpdatedAt), timeDisplay(updatedAt)) + if existing.DeletedAt.Valid { + fields = appendFieldChange(fields, "deleted_at", pgTimeDisplay(existing.DeletedAt), "") + } + if existing.WorkosDeletedAt.Valid { + fields = appendFieldChange(fields, "workos_deleted_at", pgTimeDisplay(existing.WorkosDeletedAt), "") + } + if len(fields) > 0 { + details = append(details, changeDetail{Entity: "user", ID: user.ID, Action: "update", Fields: fields}) + } + } + return details, nil +} + +func collectMembershipChangeDetails(ctx context.Context, db *pgxpool.Pool, organizationID string, users map[string]workos.User, members []workos.Member) ([]changeDetail, error) { + details := make([]changeDetail, 0) + for _, member := range members { + gramUserID, resolved, err := expectedGramUserID(ctx, db, users[member.UserID]) + if err != nil { + return nil, err + } + if !resolved { + continue + } + updatedAt, err := parseWorkOSTime(member.UpdatedAt) + if err != nil { + continue + } + detail, ok, err := collectMembershipUpdateDetail(ctx, db, organizationID, member, gramUserID, updatedAt) + if err != nil { + return nil, err + } + if ok { + details = append(details, detail) + continue + } + change, err := classifyMembershipRow(ctx, db, organizationID, member, gramUserID, updatedAt) + if err != nil { + return nil, err + } + if change == "create" { + details = append(details, membershipCreateDetail(organizationID, member, gramUserID, updatedAt)) + } + } + return details, nil +} + +func collectMembershipUpdateDetail(ctx context.Context, db *pgxpool.Pool, organizationID string, member workos.Member, gramUserID string, updatedAt time.Time) (changeDetail, bool, error) { + var userID pgtype.Text + var workosUserID pgtype.Text + var rowUpdatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + var deletedAt pgtype.Timestamptz + err := db.QueryRow(ctx, ` +SELECT user_id, workos_user_id, workos_updated_at, workos_last_event_id, deleted_at +FROM organization_user_relationships +WHERE organization_id = $1 + AND workos_membership_id = $2 +ORDER BY updated_at DESC +LIMIT 1`, organizationID, member.ID).Scan(&userID, &workosUserID, &rowUpdatedAt, &lastEventID, &deletedAt) + switch { + case errors.Is(err, pgx.ErrNoRows): + return collectMissingMembershipIDRepairDetail(ctx, db, organizationID, member, gramUserID, updatedAt) + case err != nil: + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, fmt.Errorf("query local membership %q: %w", member.ID, err) + } + if membershipNeedsMissingFieldRepair(userID, workosUserID, rowUpdatedAt, deletedAt, gramUserID, member) { + fields := missingMembershipFieldRepairs(userID, workosUserID, rowUpdatedAt, gramUserID, member, updatedAt) + return changeDetail{Entity: "membership", ID: member.ID, Action: "update", Fields: fields}, true, nil + } + if !shouldProcessEvent(textPtr(lastEventID), timePtr(rowUpdatedAt), "", updatedAt) { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + + fields := make([]fieldChange, 0) + fields = appendFieldChange(fields, "user_id", pgTextDisplay(userID), gramUserID) + fields = appendFieldChange(fields, "workos_user_id", pgTextDisplay(workosUserID), member.UserID) + fields = appendFieldChange(fields, "workos_updated_at", pgTimeDisplay(rowUpdatedAt), timeDisplay(updatedAt)) + if deletedAt.Valid { + fields = appendFieldChange(fields, "deleted_at", pgTimeDisplay(deletedAt), "") + } + if len(fields) == 0 { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + return changeDetail{Entity: "membership", ID: member.ID, Action: "update", Fields: fields}, true, nil +} + +func collectMissingMembershipIDRepairDetail(ctx context.Context, db *pgxpool.Pool, organizationID string, member workos.Member, gramUserID string, updatedAt time.Time) (changeDetail, bool, error) { + var workosMembershipID pgtype.Text + var workosUserID pgtype.Text + var rowUpdatedAt pgtype.Timestamptz + var lastEventID pgtype.Text + var deletedAt pgtype.Timestamptz + err := db.QueryRow(ctx, ` +SELECT workos_membership_id, workos_user_id, workos_updated_at, workos_last_event_id, deleted_at +FROM organization_user_relationships +WHERE organization_id = $1 + AND user_id = $2 +ORDER BY updated_at DESC +LIMIT 1`, organizationID, conv.ToPGText(gramUserID)).Scan(&workosMembershipID, &workosUserID, &rowUpdatedAt, &lastEventID, &deletedAt) + switch { + case errors.Is(err, pgx.ErrNoRows): + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + case err != nil: + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, fmt.Errorf("query local membership by user %q: %w", member.ID, err) + } + if deletedAt.Valid { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + if workosMembershipID.Valid && !shouldProcessEvent(textPtr(lastEventID), timePtr(rowUpdatedAt), "", updatedAt) { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + + fields := make([]fieldChange, 0) + if !workosMembershipID.Valid { + fields = appendFieldChange(fields, "workos_membership_id", pgTextDisplay(workosMembershipID), member.ID) + } + if !workosUserID.Valid && member.UserID != "" { + fields = appendFieldChange(fields, "workos_user_id", pgTextDisplay(workosUserID), member.UserID) + } + if !rowUpdatedAt.Valid { + fields = appendFieldChange(fields, "workos_updated_at", pgTimeDisplay(rowUpdatedAt), timeDisplay(updatedAt)) + } + if len(fields) == 0 { + return changeDetail{Entity: "", ID: "", Action: "", Fields: nil}, false, nil + } + return changeDetail{Entity: "membership", ID: member.ID, Action: "update", Fields: fields}, true, nil +} + +func membershipNeedsMissingFieldRepair(userID pgtype.Text, workosUserID pgtype.Text, rowUpdatedAt pgtype.Timestamptz, deletedAt pgtype.Timestamptz, gramUserID string, member workos.Member) bool { + if deletedAt.Valid { + return false + } + return !userID.Valid && gramUserID != "" || + !workosUserID.Valid && member.UserID != "" || + !rowUpdatedAt.Valid +} + +func missingMembershipFieldRepairs(userID pgtype.Text, workosUserID pgtype.Text, rowUpdatedAt pgtype.Timestamptz, gramUserID string, member workos.Member, updatedAt time.Time) []fieldChange { + fields := make([]fieldChange, 0) + if !userID.Valid && gramUserID != "" { + fields = appendFieldChange(fields, "user_id", pgTextDisplay(userID), gramUserID) + } + if !workosUserID.Valid && member.UserID != "" { + fields = appendFieldChange(fields, "workos_user_id", pgTextDisplay(workosUserID), member.UserID) + } + if !rowUpdatedAt.Valid { + fields = appendFieldChange(fields, "workos_updated_at", pgTimeDisplay(rowUpdatedAt), timeDisplay(updatedAt)) + } + return fields +} + +func collectAssignmentChangeDetails(ctx context.Context, db *pgxpool.Pool, organizationID string, roles []workos.Role, users map[string]workos.User, members []workos.Member) ([]changeDetail, error) { + details := make([]changeDetail, 0) + for _, member := range members { + gramUserID, resolved, err := expectedGramUserID(ctx, db, users[member.UserID]) + if err != nil { + return nil, err + } + if !resolved { + continue + } + roleAvailable := false + if member.RoleSlug != "" { + roleAvailable, err = plannedAssignmentRoleAvailable(ctx, db, organizationID, roles, member.RoleSlug) + if err != nil { + return nil, err + } + if !roleAvailable { + continue + } + } + rows, err := db.Query(ctx, ` +SELECT id, user_id +FROM organization_role_assignments +WHERE organization_id = $1 + AND workos_membership_id = $2 + AND deleted_at IS NULL`, organizationID, member.ID) + if err != nil { + return nil, fmt.Errorf("query active role assignments for membership %q: %w", member.ID, err) + } + activeAssignments := 0 + for rows.Next() { + var id pgtype.UUID + var userID pgtype.Text + if err := rows.Scan(&id, &userID); err != nil { + rows.Close() + return nil, fmt.Errorf("scan role assignment for membership %q: %w", member.ID, err) + } + activeAssignments++ + if member.RoleSlug == "" { + details = append(details, changeDetail{ + Entity: "role_assignment", + ID: uuidDisplay(id), + Action: "delete", + Fields: []fieldChange{{ + Name: "deleted_at", + Before: "", + After: "now", + }}, + }) + continue + } + if !userID.Valid || userID.String != gramUserID { + details = append(details, changeDetail{ + Entity: "role_assignment", + ID: uuidDisplay(id), + Action: "update", + Fields: []fieldChange{{ + Name: "user_id", + Before: pgTextDisplay(userID), + After: gramUserID, + }}, + }) + } + } + if err := rows.Err(); err != nil { + rows.Close() + return nil, fmt.Errorf("iterate role assignments for membership %q: %w", member.ID, err) + } + rows.Close() + if activeAssignments == 0 && member.RoleSlug != "" && roleAvailable { + details = append(details, roleAssignmentCreateDetail(organizationID, member, gramUserID)) + } + } + return details, nil +} + +func runOrganizationBackfill(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, workosClient Client, opts options, orgs []orgExpectation) report { + rep := report{ + scanned: len(orgs), + skipped: 0, + skippedNoop: 0, + written: 0, + validated: 0, + failed: 0, + validationFailures: 0, + organizationRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + roleRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + userRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + membershipRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + assignmentRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + } + backfill := NewBackfillWorkOSOrganization(logger, db, workosClient) + for i, org := range orgs { + if org.skipped { + rep.skipped++ + fmt.Printf("[%d/%d] skip %s: no local row and no WorkOS external_id\n", i+1, len(orgs), org.workosOrgID) + continue + } + if plannedOrganizationMutation(org) == 0 { + rep.skippedNoop++ + fmt.Printf("[%d/%d] skip noop %s -> %s\n", i+1, len(orgs), org.workosOrgID, org.gramOrgID) + continue + } + + fmt.Printf("[%d/%d] backfill %s -> %s\n", i+1, len(orgs), org.workosOrgID, org.gramOrgID) + if err := backfill.Do(ctx, BackfillWorkOSOrganizationParams{WorkOSOrganizationID: org.workosOrgID}); err != nil { + rep.failed++ + fmt.Fprintf(os.Stderr, " failed: %v\n", err) + } else { + rep.written++ + rep.organizationRows = rep.organizationRows.Add(org.orgChanges) + rep.roleRows = rep.roleRows.Add(org.roleChanges) + rep.userRows = rep.userRows.Add(org.userChanges) + rep.membershipRows = rep.membershipRows.Add(org.membershipChanges) + rep.assignmentRows = rep.assignmentRows.Add(org.assignmentChanges) + if err := validateOrganization(ctx, db, org); err != nil { + rep.validationFailures++ + fmt.Fprintf(os.Stderr, " validation failed: %v\n", err) + } else { + rep.validated++ + } + } + + if opts.pauseAfterEach { + waitForEnter("Paused after organization. Press Enter to continue.") + } + } + return rep +} + +func validateOrganizations(ctx context.Context, db *pgxpool.Pool, orgs []orgExpectation) report { + rep := report{ + scanned: len(orgs), + skipped: 0, + skippedNoop: 0, + written: 0, + validated: 0, + failed: 0, + validationFailures: 0, + organizationRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + roleRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + userRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + membershipRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + assignmentRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + } + for _, org := range orgs { + if org.skipped { + rep.skipped++ + continue + } + if err := validateOrganization(ctx, db, org); err != nil { + rep.validationFailures++ + fmt.Fprintf(os.Stderr, "validation failed for %s: %v\n", org.workosOrgID, err) + continue + } + rep.validated++ + } + return rep +} + +func validateOrganization(ctx context.Context, db *pgxpool.Pool, org orgExpectation) error { + var gramOrgID string + if err := db.QueryRow(ctx, "SELECT id FROM organization_metadata WHERE workos_id = $1 LIMIT 1", org.workosOrgID).Scan(&gramOrgID); err != nil { + return fmt.Errorf("organization_metadata missing workos_id=%s: %w", org.workosOrgID, err) + } + if gramOrgID != org.gramOrgID { + return fmt.Errorf("organization id mismatch: got %s, expected %s", gramOrgID, org.gramOrgID) + } + + expectedRoleSlugs := make([]string, 0, len(org.roles)) + for _, role := range org.roles { + if role.Type == "OrganizationRole" { + expectedRoleSlugs = append(expectedRoleSlugs, role.Slug) + } + } + if err := requireSlugs(ctx, db, "organization_roles", "organization_id = $1", []any{gramOrgID}, expectedRoleSlugs); err != nil { + return fmt.Errorf("organization roles: %w", err) + } + + expectedUserIDs := make([]string, 0, len(org.users)) + resolvableWorkOSUserIDs := make(map[string]struct{}, len(org.users)) + for _, user := range org.users { + gramUserID, ok, err := expectedGramUserID(ctx, db, user) + if err != nil { + return err + } + if !ok { + continue + } + expectedUserIDs = append(expectedUserIDs, user.ID) + resolvableWorkOSUserIDs[user.ID] = struct{}{} + _ = gramUserID + } + if err := requireUsers(ctx, db, expectedUserIDs); err != nil { + return err + } + + membershipIDs := make([]string, 0, len(org.members)) + roleMembershipIDs := make([]string, 0, len(org.members)) + for _, member := range org.members { + if _, ok := resolvableWorkOSUserIDs[member.UserID]; !ok { + continue + } + membershipIDs = append(membershipIDs, member.ID) + if member.RoleSlug != "" { + roleExists, err := activeAssignmentRoleExists(ctx, db, gramOrgID, member.RoleSlug) + if err != nil { + return err + } + if roleExists { + roleMembershipIDs = append(roleMembershipIDs, member.ID) + } + } + } + if err := requireMemberships(ctx, db, gramOrgID, membershipIDs); err != nil { + return err + } + if err := requireRoleAssignments(ctx, db, gramOrgID, roleMembershipIDs); err != nil { + return err + } + + return nil +} + +func expectedGramUserID(ctx context.Context, db *pgxpool.Pool, user workos.User) (string, bool, error) { + if user.ID == "" { + return "", false, nil + } + existing, found, err := findUserByWorkOSID(ctx, db, user.ID) + if err != nil { + return "", false, err + } + if found { + return existing.ID, true, nil + } + if user.ExternalID == "" { + return "", false, nil + } + return user.ExternalID, true, nil +} + +func validateGlobalRoles(ctx context.Context, db *pgxpool.Pool, roles []workos.Role) report { + rep := report{ + scanned: len(roles), + skipped: 0, + skippedNoop: 0, + written: 0, + validated: 0, + failed: 0, + validationFailures: 0, + organizationRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + roleRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + userRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + membershipRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + assignmentRows: changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}, + } + expectedSlugs := make([]string, 0, len(roles)) + for _, role := range roles { + expectedSlugs = append(expectedSlugs, role.Slug) + } + if err := requireSlugs(ctx, db, "global_roles", "TRUE", nil, expectedSlugs); err != nil { + rep.validationFailures = 1 + fmt.Fprintf(os.Stderr, "global role validation failed: %v\n", err) + return rep + } + rep.validated = len(roles) + return rep +} + +func requireSlugs(ctx context.Context, db *pgxpool.Pool, table, predicate string, args []any, expected []string) error { + expectedSet := set(expected) + query := fmt.Sprintf("SELECT workos_slug FROM %s WHERE %s AND deleted IS FALSE AND workos_deleted IS FALSE", table, predicate) // #nosec G201 -- table and predicate are fixed call-site constants. + rows, err := db.Query(ctx, query, args...) + if err != nil { + return fmt.Errorf("query active role slugs: %w", err) + } + defer rows.Close() + + actualSet := map[string]bool{} + for rows.Next() { + var slug string + if err := rows.Scan(&slug); err != nil { + return fmt.Errorf("scan role slug: %w", err) + } + actualSet[slug] = true + } + if err := rows.Err(); err != nil { + return fmt.Errorf("iterate role slugs: %w", err) + } + + missing := difference(expectedSet, actualSet) + extra := difference(actualSet, expectedSet) + if len(missing) > 0 || len(extra) > 0 { + return fmt.Errorf("slug mismatch: missing=%v extra=%v", missing, extra) + } + return nil +} + +func requireMemberships(ctx context.Context, db *pgxpool.Pool, orgID string, membershipIDs []string) error { + if len(membershipIDs) == 0 { + return nil + } + rows, err := db.Query(ctx, ` +SELECT workos_membership_id +FROM organization_user_relationships +WHERE organization_id = $1 + AND workos_membership_id = ANY($2::text[]) + AND deleted IS FALSE`, orgID, membershipIDs) + if err != nil { + return fmt.Errorf("query active memberships: %w", err) + } + defer rows.Close() + + actual := map[string]bool{} + for rows.Next() { + var membershipID string + if err := rows.Scan(&membershipID); err != nil { + return fmt.Errorf("scan membership id: %w", err) + } + actual[membershipID] = true + } + if err := rows.Err(); err != nil { + return fmt.Errorf("iterate memberships: %w", err) + } + + if missing := difference(set(membershipIDs), actual); len(missing) > 0 { + return fmt.Errorf("missing memberships: %v", missing) + } + return nil +} + +func requireUsers(ctx context.Context, db *pgxpool.Pool, workosUserIDs []string) error { + if len(workosUserIDs) == 0 { + return nil + } + rows, err := db.Query(ctx, ` +SELECT workos_id +FROM users +WHERE workos_id = ANY($1::text[]) + AND deleted_at IS NULL`, workosUserIDs) + if err != nil { + return fmt.Errorf("query active users: %w", err) + } + defer rows.Close() + + actual := map[string]bool{} + for rows.Next() { + var workosUserID string + if err := rows.Scan(&workosUserID); err != nil { + return fmt.Errorf("scan user workos id: %w", err) + } + actual[workosUserID] = true + } + if err := rows.Err(); err != nil { + return fmt.Errorf("iterate users: %w", err) + } + + if missing := difference(set(workosUserIDs), actual); len(missing) > 0 { + return fmt.Errorf("missing users: %v", missing) + } + return nil +} + +func requireRoleAssignments(ctx context.Context, db *pgxpool.Pool, orgID string, membershipIDs []string) error { + if len(membershipIDs) == 0 { + return nil + } + rows, err := db.Query(ctx, ` +SELECT DISTINCT workos_membership_id +FROM organization_role_assignments +WHERE organization_id = $1 + AND workos_membership_id = ANY($2::text[]) + AND deleted_at IS NULL`, orgID, membershipIDs) + if err != nil { + return fmt.Errorf("query active role assignments: %w", err) + } + defer rows.Close() + + actual := map[string]bool{} + for rows.Next() { + var membershipID string + if err := rows.Scan(&membershipID); err != nil { + return fmt.Errorf("scan role assignment membership id: %w", err) + } + actual[membershipID] = true + } + if err := rows.Err(); err != nil { + return fmt.Errorf("iterate role assignments: %w", err) + } + + if missing := difference(set(membershipIDs), actual); len(missing) > 0 { + return fmt.Errorf("missing role assignments for memberships: %v", missing) + } + return nil +} + +func printOrganizationPlan(orgs []orgExpectation) { + var roles int + var users int + var memberships int + var skipped int + orgChanges := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + roleChanges := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + userChanges := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + membershipChanges := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + assignmentChanges := changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0} + changeDetails := make([]changeDetail, 0) + for _, org := range orgs { + if org.skipped { + skipped++ + } + orgChanges = orgChanges.Add(org.orgChanges) + roleChanges = roleChanges.Add(org.roleChanges) + userChanges = userChanges.Add(org.userChanges) + membershipChanges = membershipChanges.Add(org.membershipChanges) + assignmentChanges = assignmentChanges.Add(org.assignmentChanges) + for _, role := range org.roles { + if role.Type == "OrganizationRole" { + roles++ + } + } + users += len(org.users) + memberships += len(org.members) + changeDetails = append(changeDetails, org.changeDetails...) + } + + fmt.Println("Organization preflight:") + fmt.Printf(" workos_orgs: %d\n", len(orgs)) + fmt.Printf(" expected_organization_roles: %d\n", roles) + fmt.Printf(" expected_users: %d\n", users) + fmt.Printf(" expected_memberships: %d\n", memberships) + fmt.Printf(" skipped_unlinked_without_external_id: %d\n", skipped) + printChangeCounts(" organization_rows", orgChanges) + printChangeCounts(" role_rows", roleChanges) + printChangeCounts(" user_rows", userChanges) + printChangeCounts(" membership_rows", membershipChanges) + printChangeCounts(" assignment_rows", assignmentChanges) + printSamples(orgs) + printChangeSummary(" planned_change_summary", changeDetails) + printChangeDetails(" planned_change_details", changeDetails) +} + +func printGlobalRolePlan(roles []workos.Role, changes changeCounts, details []changeDetail) { + fmt.Println("Global role preflight:") + fmt.Printf(" workos_global_roles: %d\n", len(roles)) + printChangeCounts(" role_rows", changes) + for _, role := range sampleRoles(roles) { + fmt.Printf(" %s (%s)\n", role.Slug, role.Name) + } + printChangeSummary(" planned_change_summary", details) + printChangeDetails(" planned_change_details", details) +} + +func printSamples(orgs []orgExpectation) { + limit := min(len(orgs), sampleSize) + if limit == 0 { + return + } + fmt.Println(" sample:") + for _, org := range orgs[:limit] { + status := org.gramOrgID + if org.skipped { + status = "skip" + } + fmt.Printf(" %s -> %s org=%s roles=%s users=%s memberships=%s assignments=%s name=%q\n", + org.workosOrgID, + status, + formatDominantChange(org.orgChanges), + formatDominantChange(org.roleChanges), + formatDominantChange(org.userChanges), + formatDominantChange(org.membershipChanges), + formatDominantChange(org.assignmentChanges), + org.name, + ) + } + if len(orgs) > limit { + fmt.Printf(" ... and %d more\n", len(orgs)-limit) + } +} + +func organizationSummary(orgs []orgExpectation) string { + var skipped int + for _, org := range orgs { + if org.skipped { + skipped++ + } + } + return fmt.Sprintf("apply %d planned row changes; skip %d unlinked WorkOS organizations without external_id", plannedOrganizationMutations(orgs), skipped) +} + +func plannedOrganizationMutations(orgs []orgExpectation) int { + var mutating int + for _, org := range orgs { + if !org.skipped { + mutating += plannedOrganizationMutation(org) + } + } + return mutating +} + +func plannedOrganizationMutation(org orgExpectation) int { + return org.orgChanges.Mutating() + + org.roleChanges.Mutating() + + org.userChanges.Mutating() + + org.membershipChanges.Mutating() + + org.assignmentChanges.Mutating() +} + +func printChangeCounts(label string, counts changeCounts) { + fmt.Printf("%s: affected=%d create=%d update=%d delete=%d noop=%d stale_skip=%d\n", + label, + counts.Mutating(), + counts.Create, + counts.Update, + counts.Delete, + counts.Noop, + counts.StaleSkip, + ) +} + +func printChangeSummary(label string, details []changeDetail) { + if len(details) == 0 { + return + } + groups := summarizeChangeDetails(details) + fmt.Printf("%s: groups=%d changed_records=%d\n", label, len(groups), len(details)) + for _, group := range groups { + fmt.Printf(" %s %s risk=%s fields=%s count=%d\n", + group.Action, + group.Entity, + group.Risk, + strings.Join(group.Fields, ","), + group.Count, + ) + for _, sample := range group.Samples { + fmt.Printf(" sample %s\n", sample.ID) + for _, field := range sampleSummaryFields(sample.Fields) { + fmt.Printf(" %s: %q -> %q\n", field.Name, field.Before, field.After) + } + } + if group.Count > len(group.Samples) { + fmt.Printf(" ... and %d more\n", group.Count-len(group.Samples)) + } + } +} + +func summarizeChangeDetails(details []changeDetail) []changeSummaryGroup { + groupsByKey := make(map[string]*changeSummaryGroup) + for _, detail := range details { + fields := changeFieldNames(detail.Fields) + risk := changeRisk(detail.Action, fields) + key := strings.Join([]string{risk, detail.Entity, detail.Action, strings.Join(fields, ",")}, "|") + group, ok := groupsByKey[key] + if !ok { + group = &changeSummaryGroup{ + Entity: detail.Entity, + Action: detail.Action, + Risk: risk, + Fields: fields, + Count: 0, + Samples: nil, + } + groupsByKey[key] = group + } + group.Count++ + if len(group.Samples) < changeSummarySampleLimit { + group.Samples = append(group.Samples, detail) + } + } + + groups := make([]changeSummaryGroup, 0, len(groupsByKey)) + for _, group := range groupsByKey { + groups = append(groups, *group) + } + sort.Slice(groups, func(i, j int) bool { + leftRisk := changeRiskRank(groups[i].Risk) + rightRisk := changeRiskRank(groups[j].Risk) + if leftRisk != rightRisk { + return leftRisk < rightRisk + } + if groups[i].Count != groups[j].Count { + return groups[i].Count > groups[j].Count + } + if groups[i].Entity != groups[j].Entity { + return groups[i].Entity < groups[j].Entity + } + if groups[i].Action != groups[j].Action { + return groups[i].Action < groups[j].Action + } + return strings.Join(groups[i].Fields, ",") < strings.Join(groups[j].Fields, ",") + }) + return groups +} + +func changeFieldNames(fields []fieldChange) []string { + names := make([]string, 0, len(fields)) + seen := make(map[string]struct{}, len(fields)) + for _, field := range fields { + if _, ok := seen[field.Name]; ok { + continue + } + seen[field.Name] = struct{}{} + names = append(names, field.Name) + } + sort.Strings(names) + if len(names) == 0 { + return []string{""} + } + return names +} + +func sampleSummaryFields(fields []fieldChange) []fieldChange { + if len(fields) <= changeSummarySampleLimit { + return fields + } + ranked := append([]fieldChange(nil), fields...) + sort.SliceStable(ranked, func(i, j int) bool { + return changeFieldRiskRank(ranked[i].Name) < changeFieldRiskRank(ranked[j].Name) + }) + return ranked[:changeSummarySampleLimit] +} + +func changeRisk(action string, fields []string) string { + if action == "delete" { + return "critical" + } + risk := "metadata_only" + for _, field := range fields { + fieldRisk := changeFieldRisk(field) + if changeRiskRank(fieldRisk) < changeRiskRank(risk) { + risk = fieldRisk + } + } + return risk +} + +func changeFieldRisk(field string) string { + switch field { + case "deleted", "deleted_at", "disabled_at", "workos_deleted", "workos_deleted_at": + return "critical" + case "organization_id", "role_id", "user_id", "workos_id", "workos_membership_id", "workos_slug", "workos_user_id": + return "identity" + case "email", "id", "name", "slug", "workos_description", "workos_name": + return "display" + case "display_name", "photo_url": + return "profile" + case "workos_created_at", "workos_last_event_id", "workos_updated_at": + return "metadata_only" + default: + return "normal" + } +} + +func changeFieldRiskRank(field string) int { + return changeRiskRank(changeFieldRisk(field)) +} + +func changeRiskRank(risk string) int { + switch risk { + case "critical": + return 0 + case "identity": + return 1 + case "display": + return 2 + case "profile": + return 3 + case "normal": + return 4 + case "metadata_only": + return 5 + default: + return 6 + } +} + +func printChangeDetails(label string, details []changeDetail) { + if len(details) == 0 { + return + } + fmt.Printf("%s: showing=%d total=%d\n", label, min(len(details), updateDetailLimit), len(details)) + limit := min(len(details), updateDetailLimit) + for _, detail := range details[:limit] { + fmt.Printf(" %s %s %s\n", detail.Action, detail.Entity, detail.ID) + for _, field := range detail.Fields { + fmt.Printf(" %s: %q -> %q\n", field.Name, field.Before, field.After) + } + } + if len(details) > limit { + fmt.Printf(" ... and %d more changed records\n", len(details)-limit) + } +} + +func organizationCreateDetail(org workos.Organization, gramOrgID string) changeDetail { + fields := []fieldChange{ + {Name: "id", Before: "", After: gramOrgID}, + {Name: "name", Before: "", After: org.Name}, + {Name: "slug", Before: "", After: "generated unique slug"}, + {Name: "workos_id", Before: "", After: org.ID}, + } + if updatedAt, err := parseWorkOSTime(org.UpdatedAt); err == nil { + fields = append(fields, fieldChange{Name: "workos_updated_at", Before: "", After: timeDisplay(updatedAt)}) + } + return changeDetail{Entity: "organization", ID: org.ID, Action: "create", Fields: fields} +} + +func roleCreateDetail(entity string, role workos.Role, updatedAt time.Time) changeDetail { + fields := []fieldChange{ + {Name: "workos_slug", Before: "", After: role.Slug}, + {Name: "workos_name", Before: "", After: role.Name}, + {Name: "workos_description", Before: "", After: role.Description}, + } + if createdAt, err := parseWorkOSTime(role.CreatedAt); err == nil { + fields = append(fields, fieldChange{Name: "workos_created_at", Before: "", After: timeDisplay(createdAt)}) + } + fields = append(fields, fieldChange{Name: "workos_updated_at", Before: "", After: timeDisplay(updatedAt)}) + return changeDetail{Entity: entity, ID: role.Slug, Action: "create", Fields: fields} +} + +func userCreateDetail(user workos.User, gramUserID string, createdAt, updatedAt time.Time) changeDetail { + return changeDetail{ + Entity: "user", + ID: user.ID, + Action: "create", + Fields: []fieldChange{ + {Name: "id", Before: "", After: gramUserID}, + {Name: "email", Before: "", After: user.Email}, + {Name: "display_name", Before: "", After: displayNameFromWorkOSUser(user)}, + {Name: "photo_url", Before: "", After: user.ProfilePictureURL}, + {Name: "workos_id", Before: "", After: user.ID}, + {Name: "workos_created_at", Before: "", After: timeDisplay(createdAt)}, + {Name: "workos_updated_at", Before: "", After: timeDisplay(updatedAt)}, + }, + } +} + +func membershipCreateDetail(organizationID string, member workos.Member, gramUserID string, updatedAt time.Time) changeDetail { + return changeDetail{ + Entity: "membership", + ID: member.ID, + Action: "create", + Fields: []fieldChange{ + {Name: "organization_id", Before: "", After: organizationID}, + {Name: "user_id", Before: "", After: gramUserID}, + {Name: "workos_user_id", Before: "", After: member.UserID}, + {Name: "workos_membership_id", Before: "", After: member.ID}, + {Name: "workos_updated_at", Before: "", After: timeDisplay(updatedAt)}, + }, + } +} + +func roleAssignmentCreateDetail(organizationID string, member workos.Member, gramUserID string) changeDetail { + return changeDetail{ + Entity: "role_assignment", + ID: member.ID, + Action: "create", + Fields: []fieldChange{ + {Name: "organization_id", Before: "", After: organizationID}, + {Name: "user_id", Before: "", After: gramUserID}, + {Name: "workos_membership_id", Before: "", After: member.ID}, + {Name: "workos_slug", Before: "", After: member.RoleSlug}, + }, + } +} + +func appendFieldChange(fields []fieldChange, name, before, after string) []fieldChange { + if before == after { + return fields + } + return append(fields, fieldChange{Name: name, Before: before, After: after}) +} + +func pgTextDisplay(value pgtype.Text) string { + if !value.Valid { + return "" + } + return value.String +} + +func pgTimeDisplay(value pgtype.Timestamptz) string { + if !value.Valid { + return "" + } + return timeDisplay(value.Time) +} + +func timeDisplay(value time.Time) string { + if value.IsZero() { + return "" + } + return value.UTC().Format(time.RFC3339Nano) +} + +func boolDisplay(value bool) string { + return fmt.Sprintf("%t", value) +} + +func uuidDisplay(value pgtype.UUID) string { + if !value.Valid { + return "" + } + return fmt.Sprintf("%x-%x-%x-%x-%x", value.Bytes[0:4], value.Bytes[4:6], value.Bytes[6:8], value.Bytes[8:10], value.Bytes[10:16]) +} + +func formatDominantChange(counts changeCounts) string { + switch { + case counts.Create > 0: + return fmt.Sprintf("create:%d", counts.Create) + case counts.Update > 0: + return fmt.Sprintf("update:%d", counts.Update) + case counts.Delete > 0: + return fmt.Sprintf("delete:%d", counts.Delete) + case counts.StaleSkip > 0: + return fmt.Sprintf("stale_skip:%d", counts.StaleSkip) + default: + return fmt.Sprintf("noop:%d", counts.Noop) + } +} + +func printReport(title string, rep report) { + fmt.Println(title) + fmt.Printf(" scanned: %d\n", rep.scanned) + fmt.Printf(" written: %d\n", rep.written) + fmt.Printf(" validated: %d\n", rep.validated) + fmt.Printf(" skipped: %d\n", rep.skipped) + fmt.Printf(" skipped_noop: %d\n", rep.skippedNoop) + fmt.Printf(" failed: %d\n", rep.failed) + fmt.Printf(" validation_failures: %d\n", rep.validationFailures) + if reportHasRowOutcomes(rep) { + printChangeCounts(" organization_rows", rep.organizationRows) + printChangeCounts(" role_rows", rep.roleRows) + printChangeCounts(" user_rows", rep.userRows) + printChangeCounts(" membership_rows", rep.membershipRows) + printChangeCounts(" assignment_rows", rep.assignmentRows) + } +} + +func reportHasRowOutcomes(rep report) bool { + return rep.organizationRows != (changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}) || + rep.roleRows != (changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}) || + rep.userRows != (changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}) || + rep.membershipRows != (changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}) || + rep.assignmentRows != (changeCounts{Create: 0, Update: 0, Noop: 0, Delete: 0, StaleSkip: 0}) +} + +func confirmWrite(opts options, summary string) error { + if opts.dryRun || opts.phase == phasePreflight || opts.phase == phaseValidate { + return nil + } + fmt.Printf("Write preflight: %s\n", summary) + fmt.Println(" DB changes run with lock_timeout=5s and statement_timeout=5min.") + fmt.Println(" WorkOS snapshot writes are not automatically reversible.") + if opts.autoApprove && opts.environment != envProd { + return nil + } + return promptExact("Type backfill to continue: ", "backfill") +} + +func confirmProdAccess(opts options) error { + if opts.confirmProd == "prod" { + return nil + } + if !term.IsTerminal(syscall.Stdin) || !term.IsTerminal(syscall.Stdout) { + return errors.New("prod access requires --confirm-prod=prod in non-interactive mode") + } + return promptExact("You are connecting to prod. Type prod to continue: ", "prod") +} + +func promptExact(prompt, want string) error { + fmt.Print(prompt) + reader := bufio.NewReader(os.Stdin) + got, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("read confirmation: %w", err) + } + if strings.TrimSpace(got) != want { + return fmt.Errorf("confirmation did not match %q", want) + } + return nil +} + +func waitForEnter(message string) { + fmt.Println(message) + _, _ = bufio.NewReader(os.Stdin).ReadString('\n') +} + +func sampleRoles(roles []workos.Role) []workos.Role { + if len(roles) <= sampleSize { + return roles + } + return roles[:sampleSize] +} + +func set(items []string) map[string]bool { + out := make(map[string]bool, len(items)) + for _, item := range items { + if item != "" { + out[item] = true + } + } + return out +} + +func difference(left, right map[string]bool) []string { + out := make([]string, 0) + for item := range left { + if !right[item] { + out = append(out, item) + } + } + sort.Strings(out) + return out +} + +func textPtr(value pgtype.Text) *string { + if !value.Valid { + return nil + } + return &value.String +} + +func timePtr(value pgtype.Timestamptz) *time.Time { + if !value.Valid { + return nil + } + return &value.Time +} + +func pgTextEmptyEqual(value pgtype.Text, want string) bool { + if want == "" { + return !value.Valid + } + return value.Valid && value.String == want +} + +func pgTimeEqual(value pgtype.Timestamptz, want time.Time) bool { + return value.Valid && value.Time.Equal(want) +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if value != "" { + return value + } + } + return "" +} + +func must(err error) { + if err != nil { + fmt.Fprintf(os.Stderr, "workos-backfill: %v\n", err) + flag.Usage() + os.Exit(2) + } +} diff --git a/server/cmd/workos-role-dump/main.go b/server/cmd/workos-role-dump/main.go new file mode 100644 index 0000000000..c504e5b230 --- /dev/null +++ b/server/cmd/workos-role-dump/main.go @@ -0,0 +1,239 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "net/http" + "net/url" + "os" + "sort" + "strings" + + "go.opentelemetry.io/otel/trace/noop" + + "github.com/speakeasy-api/gram/server/internal/guardian" + "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" +) + +const defaultWorkOSEndpoint = "https://api.workos.com" + +type options struct { + workosAPIKey string + workosEndpoint string + workosOrgIDs []string +} + +type stringList []string + +func (s *stringList) String() string { + return strings.Join(*s, ",") +} + +func (s *stringList) Set(value string) error { + for part := range strings.SplitSeq(value, ",") { + part = strings.TrimSpace(part) + if part != "" { + *s = append(*s, part) + } + } + return nil +} + +type rawRoleListResponse struct { + Data []workos.Role `json:"data"` + ListMetadata struct { + After string `json:"after"` + } `json:"list_metadata"` +} + +func main() { + ctx := context.Background() + if err := run(ctx, parseFlags()); err != nil { + fmt.Fprintf(os.Stderr, "workos-role-dump: %v\n", err) + os.Exit(1) + } +} + +func parseFlags() options { + opts := options{ + workosAPIKey: strings.TrimSpace(firstNonEmpty(os.Getenv("WORKOS_API_KEY"), os.Getenv("WORK_OS_SECRET_KEY"))), + workosEndpoint: strings.TrimSpace(firstNonEmpty(os.Getenv("WORKOS_API_URL"), defaultWorkOSEndpoint)), + workosOrgIDs: nil, + } + + var orgIDs stringList + flag.StringVar(&opts.workosAPIKey, "workos-api-key", opts.workosAPIKey, "WorkOS API key (defaults to WORKOS_API_KEY or WORK_OS_SECRET_KEY)") + flag.StringVar(&opts.workosEndpoint, "workos-endpoint", opts.workosEndpoint, "WorkOS API endpoint override (defaults to WORKOS_API_URL or api.workos.com)") + flag.Var(&orgIDs, "workos-org-id", "WorkOS organization id to inspect; repeat or comma-separate") + flag.Parse() + + opts.workosOrgIDs = orgIDs + must(validateOptions(opts)) + return opts +} + +func validateOptions(opts options) error { + if opts.workosAPIKey == "" { + return errors.New("--workos-api-key, WORKOS_API_KEY, or WORK_OS_SECRET_KEY is required") + } + if opts.workosEndpoint == "" { + return errors.New("--workos-endpoint or WORKOS_API_URL must be non-empty") + } + if len(opts.workosOrgIDs) == 0 { + return errors.New("--workos-org-id is required") + } + return nil +} + +func run(ctx context.Context, opts options) error { + policy := guardian.NewDefaultPolicy(noop.NewTracerProvider()) + sdkClient := workos.NewClient(policy, opts.workosAPIKey, workos.ClientOpts{ + Endpoint: opts.workosEndpoint, + HTTPClient: nil, + }) + httpClient := policy.PooledClient() + + for _, orgID := range opts.workosOrgIDs { + fmt.Printf("WorkOS organization %s\n", orgID) + + sdkRoles, err := sdkClient.ListRoles(ctx, orgID) + if err != nil { + return fmt.Errorf("list roles through SDK wrapper for %s: %w", orgID, err) + } + rawRoles, err := listRawAuthorizationRoles(ctx, httpClient, opts, orgID) + if err != nil { + return fmt.Errorf("list raw authorization roles for %s: %w", orgID, err) + } + + printRoleList(" sdk ListOrganizationRoles", sdkRoles) + printRoleList(" raw /authorization/organizations/{orgID}/roles", rawRoles) + printRoleDiff(sdkRoles, rawRoles) + } + return nil +} + +func listRawAuthorizationRoles(ctx context.Context, httpClient *guardian.HTTPClient, opts options, orgID string) ([]workos.Role, error) { + roles := make([]workos.Role, 0) + var after string + for { + reqURL, err := rawRolesURL(opts.workosEndpoint, orgID, after) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return nil, fmt.Errorf("create raw roles request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+opts.workosAPIKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("send raw roles request: %w", err) + } + body, readErr := io.ReadAll(resp.Body) + if closeErr := resp.Body.Close(); closeErr != nil && readErr == nil { + readErr = closeErr + } + if readErr != nil { + return nil, fmt.Errorf("read raw roles response: %w", readErr) + } + if resp.StatusCode >= http.StatusBadRequest { + return nil, fmt.Errorf("raw roles request failed status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var parsed rawRoleListResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("decode raw roles response: %w", err) + } + roles = append(roles, parsed.Data...) + if parsed.ListMetadata.After == "" { + return roles, nil + } + after = parsed.ListMetadata.After + } +} + +func rawRolesURL(endpoint, orgID, after string) (string, error) { + base, err := url.Parse(endpoint) + if err != nil { + return "", fmt.Errorf("parse WorkOS endpoint: %w", err) + } + path, err := url.JoinPath(base.Path, "/authorization/organizations", orgID, "roles") + if err != nil { + return "", fmt.Errorf("build raw roles path: %w", err) + } + base.Path = path + q := base.Query() + q.Set("limit", "100") + if after != "" { + q.Set("after", after) + } + base.RawQuery = q.Encode() + return base.String(), nil +} + +func printRoleList(title string, roles []workos.Role) { + roles = sortedRoles(roles) + fmt.Printf("%s count=%d\n", title, len(roles)) + for _, role := range roles { + fmt.Printf(" slug=%q type=%q name=%q updated_at=%q\n", role.Slug, role.Type, role.Name, role.UpdatedAt) + } +} + +func printRoleDiff(sdkRoles []workos.Role, rawRoles []workos.Role) { + sdkSet := roleSlugSet(sdkRoles) + rawSet := roleSlugSet(rawRoles) + missingFromSDK := difference(rawSet, sdkSet) + missingFromRaw := difference(sdkSet, rawSet) + fmt.Println(" diff") + fmt.Printf(" raw_missing_from_sdk=%v\n", missingFromSDK) + fmt.Printf(" sdk_missing_from_raw=%v\n", missingFromRaw) +} + +func sortedRoles(roles []workos.Role) []workos.Role { + out := append([]workos.Role(nil), roles...) + sort.Slice(out, func(i, j int) bool { + return out[i].Slug < out[j].Slug + }) + return out +} + +func roleSlugSet(roles []workos.Role) map[string]bool { + out := make(map[string]bool, len(roles)) + for _, role := range roles { + out[role.Slug] = true + } + return out +} + +func difference(left map[string]bool, right map[string]bool) []string { + out := make([]string, 0) + for value := range left { + if !right[value] { + out = append(out, value) + } + } + sort.Strings(out) + return out +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func must(err error) { + if err != nil { + fmt.Fprintf(os.Stderr, "workos-role-dump: %v\n", err) + os.Exit(2) + } +} diff --git a/server/internal/background/activities.go b/server/internal/background/activities.go index de81e3dbe3..844bc9ddeb 100644 --- a/server/internal/background/activities.go +++ b/server/internal/background/activities.go @@ -78,9 +78,6 @@ type Activities struct { reapSoftDeletedAssistantMems *activities.ReapSoftDeletedAssistantMemories signalAssistantCoordinator *activities.SignalAssistantCoordinator signalAssistantThread *activities.SignalAssistantThread - listWorkOSOrganizations *activities.ListWorkOSOrganizations - backfillWorkOSOrganization *activities.BackfillWorkOSOrganization - backfillWorkOSGlobalRoles *activities.BackfillWorkOSGlobalRoles processWorkOSOrganizationEvents *activities.ProcessWorkOSOrganizationEvents processWorkOSGlobalRoleEvents *activities.ProcessWorkOSGlobalRoleEvents processWorkOSUserEvents *activities.ProcessWorkOSUserEvents @@ -163,9 +160,6 @@ func NewActivities( reapSoftDeletedAssistantMems: activities.NewReapSoftDeletedAssistantMemories(logger, db), signalAssistantCoordinator: activities.NewSignalAssistantCoordinator(&AssistantWorkflowSignaler{TemporalEnv: temporalEnv}), signalAssistantThread: activities.NewSignalAssistantThread(&AssistantWorkflowSignaler{TemporalEnv: temporalEnv}), - listWorkOSOrganizations: activities.NewListWorkOSOrganizations(logger, workosClient), - backfillWorkOSOrganization: activities.NewBackfillWorkOSOrganization(logger, db, workosClient), - backfillWorkOSGlobalRoles: activities.NewBackfillWorkOSGlobalRoles(logger, db, workosClient), processWorkOSOrganizationEvents: activities.NewProcessWorkOSOrganizationEvents(logger, db, workosClient), processWorkOSGlobalRoleEvents: activities.NewProcessWorkOSGlobalRoleEvents(logger, db, workosClient), processWorkOSUserEvents: activities.NewProcessWorkOSUserEvents(logger, db, workosClient), @@ -175,18 +169,6 @@ func NewActivities( } } -func (a *Activities) ListWorkOSOrganizations(ctx context.Context) ([]string, error) { - return a.listWorkOSOrganizations.Do(ctx) -} - -func (a *Activities) BackfillWorkOSOrganization(ctx context.Context, params activities.BackfillWorkOSOrganizationParams) error { - return a.backfillWorkOSOrganization.Do(ctx, params) -} - -func (a *Activities) BackfillWorkOSGlobalRoles(ctx context.Context) error { - return a.backfillWorkOSGlobalRoles.Do(ctx) -} - func (a *Activities) ProcessWorkOSOrganizationEvents(ctx context.Context, params activities.ProcessWorkOSOrganizationEventsParams) (*activities.ProcessWorkOSOrganizationEventsResult, error) { return a.processWorkOSOrganizationEvents.Do(ctx, params) } diff --git a/server/internal/background/activities/backfill_workos_test.go b/server/internal/background/activities/backfill_workos_test.go deleted file mode 100644 index 4159b7176c..0000000000 --- a/server/internal/background/activities/backfill_workos_test.go +++ /dev/null @@ -1,368 +0,0 @@ -package activities_test - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/stretchr/testify/require" - - accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" - "github.com/speakeasy-api/gram/server/internal/background/activities" - "github.com/speakeasy-api/gram/server/internal/conv" - orgrepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" - "github.com/speakeasy-api/gram/server/internal/testenv" - "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" -) - -func TestBackfillWorkOSOrganization_CreatesUnlinkedOrganizationWithExternalID(t *testing.T) { - t.Parallel() - - ctx := context.Background() - conn := newOrgEventsTestConn(t, "workos_backfill_create_org_external_id") - logger := testenv.NewLogger(t) - - const organizationID = "gram_org_from_workos_external_id" - const workosOrgID = "org_01JBACKFILLCREATE" - - workosClient := newWorkOSSnapshotClient(t, ctx, - workos.Organization{ - ID: workosOrgID, - Name: "Backfill Created Org", - ExternalID: organizationID, - CreatedAt: "2026-05-07T11:00:00Z", - UpdatedAt: "2026-05-07T11:00:00Z", - }, - nil, - nil, - ) - activity := activities.NewBackfillWorkOSOrganization(logger, conn, workosClient) - - err := activity.Do(ctx, activities.BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) - require.NoError(t, err) - - org, err := orgrepo.New(conn).GetOrganizationByWorkosID(ctx, conv.ToPGText(workosOrgID)) - require.NoError(t, err) - require.Equal(t, organizationID, org.ID) - require.Equal(t, "Backfill Created Org", org.Name) - require.Equal(t, "backfill-created-org", org.Slug) - require.Equal(t, workosOrgID, org.WorkosID.String) - require.Empty(t, org.WorkosLastEventID.String) -} - -func TestBackfillWorkOSOrganization_ExternalIDChangeDoesNotChangeOrganizationID(t *testing.T) { - t.Parallel() - - ctx := context.Background() - conn := newOrgEventsTestConn(t, "workos_backfill_external_id_immutable") - logger := testenv.NewLogger(t) - - const organizationID = "gram_org_original_external_id" - const changedExternalID = "gram_org_changed_external_id" - const workosOrgID = "org_01JBACKFILLIMMUTABLE" - - seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) - - workosClient := newWorkOSSnapshotClient(t, ctx, - workos.Organization{ - ID: workosOrgID, - Name: "Backfill Immutable Org", - ExternalID: changedExternalID, - CreatedAt: "2026-05-07T11:00:00Z", - UpdatedAt: "2026-05-07T11:00:00Z", - }, - nil, - nil, - ) - activity := activities.NewBackfillWorkOSOrganization(logger, conn, workosClient) - - err := activity.Do(ctx, activities.BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) - require.NoError(t, err) - - org, err := orgrepo.New(conn).GetOrganizationByWorkosID(ctx, conv.ToPGText(workosOrgID)) - require.NoError(t, err) - require.Equal(t, organizationID, org.ID) - require.Equal(t, "Backfill Immutable Org", org.Name) - - _, err = orgrepo.New(conn).GetOrganizationMetadata(ctx, changedExternalID) - require.ErrorIs(t, err, pgx.ErrNoRows) -} - -func TestBackfillWorkOSOrganization_UnknownUserSyncsSingleRoleAssignment(t *testing.T) { - t.Parallel() - - ctx := context.Background() - conn := newOrgEventsTestConn(t, "workos_backfill_unknown_user_single_role") - logger := testenv.NewLogger(t) - - const organizationID = "gram_org_backfill_unknown_user" - const workosOrgID = "org_01JBACKFILLUNKNOWN" - const workosUserID = "user_01JBACKFILLUNKNOWN" - const membershipID = "mem_01JBACKFILLUNKNOWN" - const roleSlug = "org-support" - - seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) - - workosClient := newWorkOSSnapshotClient(t, ctx, - workos.Organization{ - ID: workosOrgID, - Name: "Backfill Unknown User", - ExternalID: "", - CreatedAt: "2026-05-07T11:00:00Z", - UpdatedAt: "2026-05-07T11:00:00Z", - }, - []workos.Role{{ - ID: "role_01JSUPPORT", - Name: "Support", - Slug: roleSlug, - Description: "Support operators", - Type: "OrganizationRole", - CreatedAt: "2026-05-07T11:00:00Z", - UpdatedAt: "2026-05-07T11:00:00Z", - }}, - []workos.Member{{ - ID: membershipID, - UserID: workosUserID, - OrganizationID: workosOrgID, - Organization: "Backfill Unknown User", - RoleSlug: roleSlug, - Status: "active", - CreatedAt: "2026-05-07T11:05:00Z", - UpdatedAt: "2026-05-07T11:05:00Z", - }}, - ) - activity := activities.NewBackfillWorkOSOrganization(logger, conn, workosClient) - - err := activity.Do(ctx, activities.BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) - require.NoError(t, err) - - role, err := accessrepo.New(conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ - OrganizationID: organizationID, - WorkosSlug: roleSlug, - }) - require.NoError(t, err) - - assignments, err := orgrepo.New(conn).ListOrganizationRoleAssignmentsByWorkOSUser(ctx, orgrepo.ListOrganizationRoleAssignmentsByWorkOSUserParams{ - OrganizationID: organizationID, - WorkosUserID: workosUserID, - }) - require.NoError(t, err) - require.Len(t, assignments, 1) - require.Equal(t, fmt.Sprintf("role:organization:%s", role.ID.String()), assignments[0].RoleUrn) - require.False(t, assignments[0].UserID.Valid) - require.Equal(t, membershipID, assignments[0].WorkosMembershipID.String) - require.Empty(t, assignments[0].WorkosLastEventID.String) - - relationship, err := orgrepo.New(conn).GetRelationshipByMembershipID(ctx, conv.ToPGText(membershipID)) - require.NoError(t, err) - require.False(t, relationship.UserID.Valid) - require.Equal(t, workosUserID, relationship.WorkosUserID.String) -} - -func TestBackfillWorkOSOrganization_MembershipWithNewerEventSkipsRoleSnapshot(t *testing.T) { - t.Parallel() - - ctx := context.Background() - conn := newOrgEventsTestConn(t, "workos_backfill_membership_newer_event_wins") - logger := testenv.NewLogger(t) - - const organizationID = "gram_org_backfill_membership_event_wins" - const workosOrgID = "org_01JBACKFILLMEMEVENT" - const workosUserID = "user_01JBACKFILLMEMEVENT" - const membershipID = "mem_01JBACKFILLMEMEVENT" - const roleSlug = "org-member" - - seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) - seedOrganizationRoleWithCursor(t, ctx, conn, organizationID, roleSlug, "Member", "") - err := orgrepo.New(conn).SyncUserOrganizationRoleAssignments(ctx, orgrepo.SyncUserOrganizationRoleAssignmentsParams{ - OrganizationID: organizationID, - WorkosUserID: workosUserID, - WorkosRoleSlugs: []string{roleSlug}, - UserID: conv.ToPGTextEmpty(""), - WorkosMembershipID: conv.ToPGText(membershipID), - WorkosUpdatedAt: conv.ToPGTimestamptz(time.Date(2026, 5, 7, 12, 0, 0, 0, time.UTC)), - WorkosLastEventID: conv.ToPGText("event_99FRESH"), - }) - require.NoError(t, err) - - workosClient := newWorkOSSnapshotClient(t, ctx, - workos.Organization{ - ID: workosOrgID, - Name: "Backfill Membership Event Wins", - ExternalID: "", - CreatedAt: "2026-05-07T11:00:00Z", - UpdatedAt: "2026-05-07T11:00:00Z", - }, - []workos.Role{{ - ID: "role_01JMEMBER", - Name: "Member", - Slug: roleSlug, - Description: "", - Type: "OrganizationRole", - CreatedAt: "2026-05-07T11:00:00Z", - UpdatedAt: "2026-05-07T11:00:00Z", - }}, - []workos.Member{{ - ID: membershipID, - UserID: workosUserID, - OrganizationID: workosOrgID, - Organization: "Backfill Membership Event Wins", - RoleSlug: "", - Status: "active", - CreatedAt: "2026-05-07T11:00:00Z", - UpdatedAt: "2026-05-07T11:00:00Z", - }}, - ) - activity := activities.NewBackfillWorkOSOrganization(logger, conn, workosClient) - - err = activity.Do(ctx, activities.BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) - require.NoError(t, err) - - assignments, err := orgrepo.New(conn).ListOrganizationRoleAssignmentsByWorkOSUser(ctx, orgrepo.ListOrganizationRoleAssignmentsByWorkOSUserParams{ - OrganizationID: organizationID, - WorkosUserID: workosUserID, - }) - require.NoError(t, err) - require.Len(t, assignments, 1) - require.Equal(t, "event_99FRESH", assignments[0].WorkosLastEventID.String) -} - -func TestBackfillWorkOSOrganization_RoleWithLastEventIDSkipsSnapshot(t *testing.T) { - t.Parallel() - - ctx := context.Background() - conn := newOrgEventsTestConn(t, "workos_backfill_role_last_event_wins") - logger := testenv.NewLogger(t) - - const organizationID = "gram_org_backfill_event_wins" - const workosOrgID = "org_01JBACKFILLEVENTWINS" - const roleSlug = "org-billing" - - seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) - seedOrganizationRoleWithCursor(t, ctx, conn, organizationID, roleSlug, "Billing From Event", "event_01JNEWER") - - workosClient := newWorkOSSnapshotClient(t, ctx, - workos.Organization{ - ID: workosOrgID, - Name: "Backfill Event Wins", - ExternalID: "", - CreatedAt: "2026-05-07T11:00:00Z", - UpdatedAt: "2026-05-07T11:00:00Z", - }, - []workos.Role{{ - ID: "role_01JBILLING", - Name: "Billing From Snapshot", - Slug: roleSlug, - Description: "", - Type: "OrganizationRole", - CreatedAt: "2026-05-07T11:00:00Z", - UpdatedAt: "2026-05-07T12:00:00Z", - }}, - nil, - ) - activity := activities.NewBackfillWorkOSOrganization(logger, conn, workosClient) - - err := activity.Do(ctx, activities.BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) - require.NoError(t, err) - - role, err := accessrepo.New(conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ - OrganizationID: organizationID, - WorkosSlug: roleSlug, - }) - require.NoError(t, err) - require.Equal(t, "Billing From Event", role.WorkosName) - require.Equal(t, "event_01JNEWER", role.WorkosLastEventID.String) -} - -func TestBackfillWorkOSOrganization_MissingRoleSoftDeleted(t *testing.T) { - t.Parallel() - - ctx := context.Background() - conn := newOrgEventsTestConn(t, "workos_backfill_role_deleted") - logger := testenv.NewLogger(t) - - const organizationID = "gram_org_backfill_delete_role" - const workosOrgID = "org_01JBACKFILLDELETE" - const roleSlug = "org-obsolete" - - seedLinkedWorkOSOrganization(t, ctx, conn, organizationID, workosOrgID) - seedOrganizationRoleWithCursor(t, ctx, conn, organizationID, roleSlug, "Obsolete", "") - - workosClient := newWorkOSSnapshotClient(t, ctx, - workos.Organization{ - ID: workosOrgID, - Name: "Backfill Delete Role", - ExternalID: "", - CreatedAt: "2026-05-07T11:00:00Z", - UpdatedAt: "2026-05-07T11:00:00Z", - }, - nil, - nil, - ) - activity := activities.NewBackfillWorkOSOrganization(logger, conn, workosClient) - - err := activity.Do(ctx, activities.BackfillWorkOSOrganizationParams{WorkOSOrganizationID: workosOrgID}) - require.NoError(t, err) - - role, err := accessrepo.New(conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ - OrganizationID: organizationID, - WorkosSlug: roleSlug, - }) - require.NoError(t, err) - require.True(t, role.Deleted) - require.True(t, role.WorkosDeleted) - require.Empty(t, role.WorkosLastEventID.String) -} - -func newWorkOSSnapshotClient(t *testing.T, ctx context.Context, org workos.Organization, roles []workos.Role, members []workos.Member) *workos.StubClient { - t.Helper() - - client := workos.NewStubClient() - client.UpsertOrganization(org) - for _, role := range roles { - _, err := client.CreateRole(ctx, org.ID, workos.CreateRoleOpts{ - Name: role.Name, - Slug: role.Slug, - Description: role.Description, - }) - require.NoError(t, err) - } - for _, member := range members { - client.UpsertOrganizationMembership(member) - } - - return client -} - -func seedLinkedWorkOSOrganization(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID, workosOrgID string) { - t.Helper() - - _, err := orgrepo.New(conn).UpsertOrganizationMetadata(ctx, orgrepo.UpsertOrganizationMetadataParams{ - ID: organizationID, - Name: organizationID, - Slug: organizationID, - WorkosID: conv.ToPGText(workosOrgID), - Whitelisted: pgtype.Bool{Bool: false, Valid: false}, - }) - require.NoError(t, err) -} - -func seedOrganizationRoleWithCursor(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID, slug, name, lastEventID string) { - t.Helper() - - updatedAt := time.Date(2026, 5, 7, 10, 0, 0, 0, time.UTC) - err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ - OrganizationID: organizationID, - WorkosSlug: slug, - WorkosName: name, - WorkosDescription: conv.ToPGText(""), - WorkosCreatedAt: conv.ToPGTimestamptz(updatedAt), - WorkosUpdatedAt: conv.ToPGTimestamptz(updatedAt), - WorkosLastEventID: conv.ToPGText(lastEventID), - }) - require.NoError(t, err) -} diff --git a/server/internal/background/activities/list_workos_organizations.go b/server/internal/background/activities/list_workos_organizations.go deleted file mode 100644 index 88ca02feed..0000000000 --- a/server/internal/background/activities/list_workos_organizations.go +++ /dev/null @@ -1,37 +0,0 @@ -package activities - -import ( - "context" - "log/slog" - - "github.com/speakeasy-api/gram/server/internal/attr" - "github.com/speakeasy-api/gram/server/internal/oops" -) - -type ListWorkOSOrganizations struct { - logger *slog.Logger - workos WorkOSClient -} - -func NewListWorkOSOrganizations(logger *slog.Logger, workosClient WorkOSClient) *ListWorkOSOrganizations { - return &ListWorkOSOrganizations{ - logger: logger.With(attr.SlogComponent("list_workos_organizations")), - workos: workosClient, - } -} - -func (l *ListWorkOSOrganizations) Do(ctx context.Context) ([]string, error) { - orgs, err := l.workos.ListOrganizations(ctx) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list WorkOS organizations").Log(ctx, l.logger) - } - - orgIDs := make([]string, 0, len(orgs)) - for _, org := range orgs { - if org.ID != "" { - orgIDs = append(orgIDs, org.ID) - } - } - - return orgIDs, nil -} diff --git a/server/internal/background/activities/process_workos_org_events.go b/server/internal/background/activities/process_workos_org_events.go index 074b2ac574..f47d18352d 100644 --- a/server/internal/background/activities/process_workos_org_events.go +++ b/server/internal/background/activities/process_workos_org_events.go @@ -14,6 +14,7 @@ import ( accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/attr" + "github.com/speakeasy-api/gram/server/internal/auth/orgslug" "github.com/speakeasy-api/gram/server/internal/conv" "github.com/speakeasy-api/gram/server/internal/database" "github.com/speakeasy-api/gram/server/internal/o11y" @@ -236,6 +237,8 @@ func handleOrganizationUpsert(ctx context.Context, logger *slog.Logger, dbtx dat repo := orgrepo.New(dbtx) row, err := repo.GetOrganizationByWorkosID(ctx, conv.ToPGText(payload.ID)) + organizationID := payload.ExternalID + var slug string switch { case errors.Is(err, pgx.ErrNoRows): if payload.ExternalID == "" { @@ -247,8 +250,26 @@ func handleOrganizationUpsert(ctx context.Context, logger *slog.Logger, dbtx dat logger.WarnContext(ctx, "skipping organization event for unlinked org with no external_id", attr.SlogWorkOSOrganizationID(payload.ID)) return nil } + row, err = repo.GetOrganizationMetadata(ctx, payload.ExternalID) + switch { + case errors.Is(err, pgx.ErrNoRows): + slug, err = uniqueOrganizationSlug(ctx, repo, payload.Name, payload.ID) + if err != nil { + return err + } + case err != nil: + return fmt.Errorf("get organization by external id %q: %w", payload.ExternalID, err) + case row.WorkosID.Valid && row.WorkosID.String != payload.ID: + return fmt.Errorf("workos organization %q resolved to gram organization %q with different workos_id %q", payload.ID, row.ID, row.WorkosID.String) + default: + organizationID = row.ID + slug = row.Slug + } case err != nil: return fmt.Errorf("get organization by workos id %q: %w", payload.ID, err) + default: + organizationID = row.ID + slug = row.Slug } var lastEventID *string @@ -263,15 +284,17 @@ func handleOrganizationUpsert(ctx context.Context, logger *slog.Logger, dbtx dat return nil } - organizationID := payload.ExternalID - if err == nil { - organizationID = row.ID + if slug == "" { + slug, err = uniqueOrganizationSlug(ctx, repo, payload.Name, payload.ID) + if err != nil { + return err + } } _, err = repo.UpsertOrganizationMetadataFromWorkOS(ctx, orgrepo.UpsertOrganizationMetadataFromWorkOSParams{ ID: organizationID, Name: payload.Name, - Slug: conv.ToSlug(payload.Name), + Slug: slug, WorkosID: conv.ToPGText(payload.ID), WorkosUpdatedAt: conv.ToPGTimestamptz(payload.UpdatedAt), WorkosLastEventID: conv.ToPGText(event.ID), @@ -283,6 +306,18 @@ func handleOrganizationUpsert(ctx context.Context, logger *slog.Logger, dbtx dat return nil } +func uniqueOrganizationSlug(ctx context.Context, repo orgslug.Lookup, name, fallback string) (string, error) { + base := orgslug.Slugify(name) + if base == "" { + base = fallback + } + slug, err := orgslug.FindUnique(ctx, repo, base) + if err != nil { + return "", fmt.Errorf("find unique organization slug: %w", err) + } + return slug, nil +} + func handleOrganizationDeleted(ctx context.Context, logger *slog.Logger, dbtx database.DBTX, event events.Event) error { var payload workosOrganizationEventPayload if err := json.Unmarshal(event.Data, &payload); err != nil { @@ -447,29 +482,3 @@ func handleRoleDeleted(ctx context.Context, logger *slog.Logger, dbtx database.D return nil } - -// ShouldProcessEvent decides whether a WorkOS event should be applied to a -// row, guarding against duplicate-apply when the sync replays history (e.g. -// reconcile schedule overlapping with webhook delivery, or a manual -// backfill). -// -// Algorithm: -// -// - If the row has no recorded last_event_id, it has not yet been touched -// by an event-driven update. Use the row's workos_updated_at as the -// baseline: apply the event only if its payload's updated_at is at least -// as recent as the row. -// - Otherwise, compare event IDs lexicographically. WorkOS event IDs are -// time-ordered (ULIDs), so a strictly greater ID means the event is -// newer than the last one we applied. -// -// Inputs are nilable to model NULL columns directly. -func ShouldProcessEvent(rowLastEventID *string, rowWorkOSUpdatedAt *time.Time, eventID string, eventUpdatedAt time.Time) bool { - if rowLastEventID == nil || *rowLastEventID == "" { - if rowWorkOSUpdatedAt == nil { - return true - } - return !eventUpdatedAt.Before(*rowWorkOSUpdatedAt) - } - return eventID > *rowLastEventID -} diff --git a/server/internal/background/activities/process_workos_org_events_test.go b/server/internal/background/activities/process_workos_org_events_test.go index 759c21d15f..38e890f2dc 100644 --- a/server/internal/background/activities/process_workos_org_events_test.go +++ b/server/internal/background/activities/process_workos_org_events_test.go @@ -237,6 +237,48 @@ func TestProcessWorkOSOrganizationEvents_OrganizationCreatedAndUpdated(t *testin require.False(t, row.DisabledAt.Valid) } +func TestProcessWorkOSOrganizationEvents_OrganizationCreateUsesUniqueSlugOnCollision(t *testing.T) { + t.Parallel() + + ctx := context.Background() + conn := newOrgEventsTestConn(t, "workos_org_events_org_slug_collision") + logger := testenv.NewLogger(t) + + const workosOrgID = "org_01HZSLUGCOLLISION" + const externalID = "sb_01HZSLUGCOLLISION" + + err := orgrepo.New(conn).CreateOrganizationMetadata(ctx, orgrepo.CreateOrganizationMetadataParams{ + ID: "existing-tester-org", + Name: "tester", + Slug: "tester", + }) + require.NoError(t, err) + + stub := newWorkOSClientWithEvents([][]events.Event{ + { + { + ID: "event_01HZSLUG", + Event: "organization.created", + CreatedAt: time.Now(), + Data: []byte(`{"id":"` + workosOrgID + `","object":"organization","name":"tester","external_id":"` + externalID + + `","updated_at":"2026-05-06T10:00:00Z"}`), + }, + }, + }) + + activity := activities.NewProcessWorkOSOrganizationEvents(logger, conn, stub) + _, err = activity.Do(ctx, activities.ProcessWorkOSOrganizationEventsParams{WorkOSOrganizationID: workosOrgID}) + require.NoError(t, err) + + row, err := orgrepo.New(conn).GetOrganizationByWorkosID(ctx, conv.ToPGText(workosOrgID)) + require.NoError(t, err) + require.Equal(t, externalID, row.ID) + require.Equal(t, "tester", row.Name) + require.NotEqual(t, "tester", row.Slug) + require.Contains(t, row.Slug, "tester-") + require.Len(t, row.Slug, len("tester-")+4) +} + func TestProcessWorkOSOrganizationEvents_OrganizationUpdateSkippedWhenStale(t *testing.T) { t.Parallel() diff --git a/server/internal/background/activities/should_process.go b/server/internal/background/activities/should_process.go new file mode 100644 index 0000000000..1c57101411 --- /dev/null +++ b/server/internal/background/activities/should_process.go @@ -0,0 +1,15 @@ +package activities + +import "time" + +// ShouldProcessEvent decides whether a WorkOS event should be applied to a +// row, guarding against duplicate-apply when the sync replays history. +func ShouldProcessEvent(rowLastEventID *string, rowWorkOSUpdatedAt *time.Time, eventID string, eventUpdatedAt time.Time) bool { + if rowLastEventID == nil || *rowLastEventID == "" { + if rowWorkOSUpdatedAt == nil { + return true + } + return !eventUpdatedAt.Before(*rowWorkOSUpdatedAt) + } + return eventID > *rowLastEventID +} diff --git a/server/internal/background/activities/workos_client.go b/server/internal/background/activities/workos_client.go new file mode 100644 index 0000000000..2d9fff4396 --- /dev/null +++ b/server/internal/background/activities/workos_client.go @@ -0,0 +1,12 @@ +package activities + +import ( + "context" + + "github.com/workos/workos-go/v6/pkg/events" +) + +type WorkOSClient interface { + ListEvents(ctx context.Context, opts events.ListEventsOpts) (events.ListEventsResponse, error) + UpdateUserExternalID(ctx context.Context, workosUserID, externalID string) error +} diff --git a/server/internal/background/backfill_workos.go b/server/internal/background/backfill_workos.go deleted file mode 100644 index 492cbe472a..0000000000 --- a/server/internal/background/backfill_workos.go +++ /dev/null @@ -1,79 +0,0 @@ -package background - -import ( - "context" - "fmt" - "time" - - "go.temporal.io/api/enums/v1" - "go.temporal.io/sdk/client" - "go.temporal.io/sdk/temporal" - "go.temporal.io/sdk/workflow" - - "github.com/speakeasy-api/gram/server/internal/background/activities" - tenv "github.com/speakeasy-api/gram/server/internal/temporal" -) - -type BackfillWorkOSParams struct { - WorkOSOrganizationID string `json:"workos_organization_id,omitempty"` -} - -func ExecuteBackfillWorkOSWorkflow(ctx context.Context, env *tenv.Environment, params BackfillWorkOSParams) (client.WorkflowRun, error) { - workflowID := fmt.Sprintf("v1:backfill-workos:%d", time.Now().Unix()) - if params.WorkOSOrganizationID != "" { - workflowID = fmt.Sprintf("v1:backfill-workos:%s:%d", params.WorkOSOrganizationID, time.Now().Unix()) - } - - return env.Client().ExecuteWorkflow(ctx, - client.StartWorkflowOptions{ - ID: workflowID, - TaskQueue: string(env.Queue()), - WorkflowExecutionTimeout: 2 * time.Hour, - WorkflowIDReusePolicy: enums.WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE, - }, - BackfillWorkOSWorkflow, - params, - ) -} - -func BackfillWorkOSWorkflow(ctx workflow.Context, params BackfillWorkOSParams) error { - var a *Activities - ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{ - StartToCloseTimeout: 10 * time.Minute, - RetryPolicy: &temporal.RetryPolicy{ - InitialInterval: time.Second, - MaximumInterval: time.Minute, - BackoffCoefficient: 2, - MaximumAttempts: 5, - }, - }) - - if err := workflow.ExecuteActivity(ctx, a.BackfillWorkOSGlobalRoles).Get(ctx, nil); err != nil { - workflow.GetLogger(ctx).Warn("global role backfill failed", "error", err) - } - - if params.WorkOSOrganizationID != "" { - if err := workflow.ExecuteActivity(ctx, a.BackfillWorkOSOrganization, activities.BackfillWorkOSOrganizationParams{ - WorkOSOrganizationID: params.WorkOSOrganizationID, - }).Get(ctx, nil); err != nil { - return fmt.Errorf("backfill WorkOS organization %q: %w", params.WorkOSOrganizationID, err) - } - - return nil - } - - var orgIDs []string - if err := workflow.ExecuteActivity(ctx, a.ListWorkOSOrganizations).Get(ctx, &orgIDs); err != nil { - return fmt.Errorf("list WorkOS organizations: %w", err) - } - - for _, orgID := range orgIDs { - if err := workflow.ExecuteActivity(ctx, a.BackfillWorkOSOrganization, activities.BackfillWorkOSOrganizationParams{ - WorkOSOrganizationID: orgID, - }).Get(ctx, nil); err != nil { - workflow.GetLogger(ctx).Warn("WorkOS organization backfill failed", "workos_org_id", orgID, "error", err) - } - } - - return nil -} diff --git a/server/internal/background/backfill_workos_test.go b/server/internal/background/backfill_workos_test.go deleted file mode 100644 index 09dd906a5d..0000000000 --- a/server/internal/background/backfill_workos_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package background - -import ( - "context" - "testing" - - "github.com/stretchr/testify/require" - "go.temporal.io/sdk/activity" - "go.temporal.io/sdk/testsuite" - - "github.com/speakeasy-api/gram/server/internal/background/activities" -) - -func TestBackfillWorkOSWorkflow_BackfillsAllOrganizations(t *testing.T) { - t.Parallel() - - var suite testsuite.WorkflowTestSuite - env := suite.NewTestWorkflowEnvironment() - - env.RegisterActivityWithOptions( - func(context.Context) ([]string, error) { - return []string{"org_01", "org_02"}, nil - }, - activity.RegisterOptions{Name: "ListWorkOSOrganizations"}, - ) - - var globalRolesBackfilled bool - env.RegisterActivityWithOptions( - func(context.Context) error { - globalRolesBackfilled = true - return nil - }, - activity.RegisterOptions{Name: "BackfillWorkOSGlobalRoles"}, - ) - - var backfilledOrgIDs []string - env.RegisterActivityWithOptions( - func(_ context.Context, params activities.BackfillWorkOSOrganizationParams) error { - backfilledOrgIDs = append(backfilledOrgIDs, params.WorkOSOrganizationID) - return nil - }, - activity.RegisterOptions{Name: "BackfillWorkOSOrganization"}, - ) - - env.ExecuteWorkflow(BackfillWorkOSWorkflow, BackfillWorkOSParams{}) - - require.True(t, env.IsWorkflowCompleted()) - require.NoError(t, env.GetWorkflowError()) - require.True(t, globalRolesBackfilled) - require.Equal(t, []string{"org_01", "org_02"}, backfilledOrgIDs) -} - -func TestBackfillWorkOSWorkflow_BackfillsSingleOrganization(t *testing.T) { - t.Parallel() - - var suite testsuite.WorkflowTestSuite - env := suite.NewTestWorkflowEnvironment() - - var globalRolesBackfilled bool - env.RegisterActivityWithOptions( - func(context.Context) error { - globalRolesBackfilled = true - return nil - }, - activity.RegisterOptions{Name: "BackfillWorkOSGlobalRoles"}, - ) - - var backfilledOrgIDs []string - env.RegisterActivityWithOptions( - func(_ context.Context, params activities.BackfillWorkOSOrganizationParams) error { - backfilledOrgIDs = append(backfilledOrgIDs, params.WorkOSOrganizationID) - return nil - }, - activity.RegisterOptions{Name: "BackfillWorkOSOrganization"}, - ) - - env.ExecuteWorkflow(BackfillWorkOSWorkflow, BackfillWorkOSParams{WorkOSOrganizationID: "org_01TARGET"}) - - require.True(t, env.IsWorkflowCompleted()) - require.NoError(t, env.GetWorkflowError()) - require.True(t, globalRolesBackfilled) - require.Equal(t, []string{"org_01TARGET"}, backfilledOrgIDs) -} diff --git a/server/internal/background/worker.go b/server/internal/background/worker.go index 8c0235a76b..b800a64966 100644 --- a/server/internal/background/worker.go +++ b/server/internal/background/worker.go @@ -288,9 +288,6 @@ func NewTemporalWorker( temporalWorker.RegisterActivity(activities.SignalAssistantThread) temporalWorker.RegisterActivity(activities.CancelAssistantsSubscription) // WorkOS sync activities - temporalWorker.RegisterActivity(activities.ListWorkOSOrganizations) - temporalWorker.RegisterActivity(activities.BackfillWorkOSOrganization) - temporalWorker.RegisterActivity(activities.BackfillWorkOSGlobalRoles) temporalWorker.RegisterActivity(activities.ProcessWorkOSOrganizationEvents) temporalWorker.RegisterActivity(activities.ProcessWorkOSGlobalRoleEvents) temporalWorker.RegisterActivity(activities.ProcessWorkOSUserEvents) @@ -331,7 +328,6 @@ func NewTemporalWorker( temporalWorker.RegisterWorkflow(ProcessWorkOSGlobalRoleEventsWorkflowDebounced) temporalWorker.RegisterWorkflow(ProcessWorkOSUserEventsWorkflow) temporalWorker.RegisterWorkflow(ProcessWorkOSUserEventsWorkflowDebounced) - temporalWorker.RegisterWorkflow(BackfillWorkOSWorkflow) // Assistants signup followups temporalWorker.RegisterWorkflow(CancelAssistantsSubscriptionWorkflow) // Outbox -> Relay workflow and GC diff --git a/server/internal/organizations/queries.sql b/server/internal/organizations/queries.sql index 24fdbf7d6f..e2cc22c483 100644 --- a/server/internal/organizations/queries.sql +++ b/server/internal/organizations/queries.sql @@ -172,6 +172,14 @@ FROM organization_user_relationships WHERE organization_id = @organization_id AND user_id = @user_id; +-- name: SetOrganizationRelationshipWorkOSCursor :exec +UPDATE organization_user_relationships +SET workos_updated_at = @workos_updated_at, + workos_last_event_id = @workos_last_event_id, + updated_at = clock_timestamp() +WHERE organization_id = @organization_id + AND user_id = @user_id; + -- name: GetRelationshipByMembershipID :one SELECT * FROM organization_user_relationships @@ -495,4 +503,4 @@ RETURNING id, svix_app_id, webhooks_enabled; -- name: GetSvixAppID :one SELECT svix_app_id FROM organization_metadata -WHERE id = @id AND svix_app_id IS NOT NULL; \ No newline at end of file +WHERE id = @id AND svix_app_id IS NOT NULL; diff --git a/server/internal/organizations/repo/queries.sql.go b/server/internal/organizations/repo/queries.sql.go index f4ed0bbf5b..2b9c6d5538 100644 --- a/server/internal/organizations/repo/queries.sql.go +++ b/server/internal/organizations/repo/queries.sql.go @@ -752,6 +752,32 @@ func (q *Queries) SetOrgWorkosID(ctx context.Context, arg SetOrgWorkosIDParams) return i, err } +const setOrganizationRelationshipWorkOSCursor = `-- name: SetOrganizationRelationshipWorkOSCursor :exec +UPDATE organization_user_relationships +SET workos_updated_at = $1, + workos_last_event_id = $2, + updated_at = clock_timestamp() +WHERE organization_id = $3 + AND user_id = $4 +` + +type SetOrganizationRelationshipWorkOSCursorParams struct { + WorkosUpdatedAt pgtype.Timestamptz + WorkosLastEventID pgtype.Text + OrganizationID string + UserID pgtype.Text +} + +func (q *Queries) SetOrganizationRelationshipWorkOSCursor(ctx context.Context, arg SetOrganizationRelationshipWorkOSCursorParams) error { + _, err := q.db.Exec(ctx, setOrganizationRelationshipWorkOSCursor, + arg.WorkosUpdatedAt, + arg.WorkosLastEventID, + arg.OrganizationID, + arg.UserID, + ) + return err +} + const setUserWorkOSMemberships = `-- name: SetUserWorkOSMemberships :exec WITH input_memberships AS ( SELECT unnest($2::text[]) AS workos_org_id, diff --git a/server/internal/thirdparty/workos/client.go b/server/internal/thirdparty/workos/client.go index 9dafcc0fe7..03024de0d0 100644 --- a/server/internal/thirdparty/workos/client.go +++ b/server/internal/thirdparty/workos/client.go @@ -172,6 +172,8 @@ func convertUser(u usermanagement.User) User { Email: u.Email, ProfilePictureURL: u.ProfilePictureURL, ExternalID: u.ExternalID, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, } } diff --git a/server/internal/thirdparty/workos/stub.go b/server/internal/thirdparty/workos/stub.go index ca8703d6d3..fe85e7a54e 100644 --- a/server/internal/thirdparty/workos/stub.go +++ b/server/internal/thirdparty/workos/stub.go @@ -232,6 +232,14 @@ func (s *StubClient) UpsertOrganizationMembership(member Member) { state.memberships[member.ID] = member } +func (s *StubClient) UpsertUser(orgID string, user User) { + s.mut.Lock() + defer s.mut.Unlock() + + state := s.orgState(orgID) + state.users[user.ID] = user +} + func (s *StubClient) UpdateMemberRole(_ context.Context, membershipID string, roleSlug string) (*Member, error) { s.mut.Lock() defer s.mut.Unlock() diff --git a/server/internal/thirdparty/workos/user.go b/server/internal/thirdparty/workos/user.go index b34831a577..b2268fe948 100644 --- a/server/internal/thirdparty/workos/user.go +++ b/server/internal/thirdparty/workos/user.go @@ -28,6 +28,8 @@ type User struct { Email string ProfilePictureURL string ExternalID string + CreatedAt string + UpdatedAt string } // ListMembers lists all active organization memberships for the given org.