From 6289f50f5e580b3991d38b5d1027920516cbab97 Mon Sep 17 00:00:00 2001 From: tgmendes Date: Wed, 13 May 2026 14:47:38 +0100 Subject: [PATCH 01/12] chore: use local database state for role management --- server/cmd/gram/start.go | 3 +- server/internal/access/createrole_test.go | 28 +- server/internal/access/deleterole_test.go | 76 +- server/internal/access/getrole_test.go | 33 +- server/internal/access/impl.go | 693 ++--------------- server/internal/access/listmembers_test.go | 60 +- server/internal/access/listroles_test.go | 36 +- server/internal/access/listusergrants_test.go | 32 +- server/internal/access/mock_role_test.go | 41 - server/internal/access/queries.sql | 133 ++++ server/internal/access/rbac_test.go | 67 +- server/internal/access/repo/queries.sql.go | 310 ++++++++ server/internal/access/role_manager.go | 721 ++++++++++++++++++ server/internal/access/role_manager_test.go | 85 +++ server/internal/access/setup_internal_test.go | 2 +- server/internal/access/setup_test.go | 58 +- .../internal/access/updatememberrole_test.go | 71 +- server/internal/access/updaterole_test.go | 103 +-- server/internal/attr/conventions.go | 12 + server/internal/conv/from.go | 10 + server/internal/conv/from_test.go | 15 + 21 files changed, 1557 insertions(+), 1032 deletions(-) create mode 100644 server/internal/access/role_manager.go create mode 100644 server/internal/access/role_manager_test.go diff --git a/server/cmd/gram/start.go b/server/cmd/gram/start.go index dcef683797..7f583a89f1 100644 --- a/server/cmd/gram/start.go +++ b/server/cmd/gram/start.go @@ -866,7 +866,8 @@ func newStartCommand() *cli.Command { about.Attach(mux, about.NewService(logger, tracerProvider)) external.AttachWebhookHandler(mux, external.NewWebhookHandler(logger, tracerProvider, newWorkOSWebhooksClient(c), temporalEnv)) - access.Attach(mux, access.NewService(logger, tracerProvider, db, chDB, sessionManager, roleClient, authzEngine, productFeatures, auditLogger)) + roleManager := access.NewRoleManager(logger, db, roleClient, authzEngine) + access.Attach(mux, access.NewService(logger, tracerProvider, db, chDB, sessionManager, roleManager, authzEngine, productFeatures, auditLogger)) assistants.Attach(mux, assistantsSvc) assistantmemories.Attach(mux, assistantmemories.NewService( logger, diff --git a/server/internal/access/createrole_test.go b/server/internal/access/createrole_test.go index d584992fc1..dbdfcb324b 100644 --- a/server/internal/access/createrole_test.go +++ b/server/internal/access/createrole_test.go @@ -39,12 +39,6 @@ func TestService_CreateRole(t *testing.T) { CreatedAt: mockRoleTimestamp, UpdatedAt: mockRoleTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "member"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "member"), - // user_workos_only has never logged into Gram — should not be counted - mockMember(mockidp.MockOrgID, "membership_workos_only", "user_workos_only", "member"), - }, nil).Once() ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "org-custom-builder").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -60,8 +54,12 @@ func TestService_CreateRole(t *testing.T) { CreatedAt: mockMembershipTimestamp, }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember("", "membership_2", "user_2", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_workos_only", "user_workos_only", authz.SystemRoleMember)) role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ Name: "Custom Builder", @@ -106,35 +104,25 @@ func TestService_CreateRole_WorkOSCreateFailure(t *testing.T) { require.Contains(t, err.Error(), "create role in workos") } -func TestService_CreateRole_ContinuesAfterConflictWhenRoleAlreadyExists(t *testing.T) { +func TestService_CreateRole_WorkOSConflictFailure(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - authCtx, ok := contextvalues.GetAuthContext(ctx) - require.True(t, ok) - require.NotNil(t, authCtx) - existingRole := mockRole("role_existing", "Custom Builder", "org-custom-builder", "Can build selected resources") ti.roles.On("CreateRole", mock.Anything, mockidp.MockOrgID, thirdpartyworkos.CreateRoleOpts{ Name: "Custom Builder", Slug: "org-custom-builder", Description: "Can build selected resources", }).Return((*thirdpartyworkos.Role)(nil), &thirdpartyworkos.APIError{Method: "POST", Path: "/authorization/organizations/org_workos_test/roles", StatusCode: 409, Body: "role already exists"}).Once() - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{existingRole}, nil).Once() - role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ + _, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ Name: "Custom Builder", Description: "Can build selected resources", Grants: []*gen.RoleGrant{ {Scope: string(authz.ScopeProjectRead), Selectors: []*gen.Selector{{ResourceKind: "project", ResourceID: "project-1"}}}, }, }) - require.NoError(t, err) - require.Equal(t, "role_existing", role.ID) - require.Equal(t, "Custom Builder", role.Name) - - grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "org-custom-builder")) - require.Len(t, grants, 1) - require.Equal(t, authCtx.ActiveOrganizationID, grants[0].OrganizationID) + require.Error(t, err) + require.Contains(t, err.Error(), "create role in workos") } func TestService_CreateRole_RejectsEmptySlug(t *testing.T) { diff --git a/server/internal/access/deleterole_test.go b/server/internal/access/deleterole_test.go index 16f852903d..9ec1c05421 100644 --- a/server/internal/access/deleterole_test.go +++ b/server/internal/access/deleterole_test.go @@ -25,15 +25,12 @@ func TestService_DeleteRole(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeMCPConnect, authz.WildcardResource) - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.NoError(t, err) grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) @@ -48,14 +45,12 @@ func TestService_DeleteRole_ReassignsMembersToDefault(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_other", "user_3", "admin"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_other", "user_3", "admin")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", authz.SystemRoleMember).Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -72,7 +67,7 @@ func TestService_DeleteRole_ReassignsMembersToDefault(t *testing.T) { }, nil).Once() ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.NoError(t, err) } @@ -84,16 +79,12 @@ func TestService_DeleteRole_ReassignFailureHaltsDelete(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", authz.SystemRoleMember).Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.Error(t, err) require.Contains(t, err.Error(), "reassign member to default role") @@ -110,15 +101,10 @@ func TestService_DeleteRole_PartialReassignFailureStopsLoop(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - // Iteration order over the slice is deterministic, so seed an explicit - // success-then-failure pair to exercise the partial-failure cache flush. - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", authz.SystemRoleMember).Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -129,7 +115,7 @@ func TestService_DeleteRole_PartialReassignFailureStopsLoop(t *testing.T) { ti.roles.On("UpdateMemberRole", mock.Anything, "membership_2", authz.SystemRoleMember).Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.Error(t, err) require.Contains(t, err.Error(), "reassign member to default role") @@ -143,9 +129,8 @@ func TestService_DeleteRole_NotFound(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{}, nil).Once() - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_missing"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "00000000-0000-0000-0000-000000000001"}) require.Error(t, err) require.Contains(t, err.Error(), "role not found") } @@ -154,11 +139,11 @@ func TestService_DeleteRole_SystemRole(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_admin"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.Error(t, err) require.Contains(t, err.Error(), "system roles cannot be deleted") } @@ -170,20 +155,16 @@ func TestService_DeleteRole_WorkOSDeleteFailure(t *testing.T) { authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(errors.New("workos unavailable")).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.Error(t, err) require.Contains(t, err.Error(), "delete role in workos") grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) - require.Empty(t, grants) - + require.Len(t, grants, 1) } func TestService_DeleteRole_AuditLog(t *testing.T) { @@ -196,14 +177,11 @@ func TestService_DeleteRole_AuditLog(t *testing.T) { beforeCount, err := audittest.AuditLogCountByAction(ctx, ti.conn, audit.ActionAccessRoleDelete) require.NoError(t, err) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Audit Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Audit Builder", "custom-builder", "Old description")) ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - err = ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err = ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.NoError(t, err) record, err := audittest.LatestAuditLogByAction(ctx, ti.conn, audit.ActionAccessRoleDelete) diff --git a/server/internal/access/getrole_test.go b/server/internal/access/getrole_test.go index 05d236cb3a..f33c62745c 100644 --- a/server/internal/access/getrole_test.go +++ b/server/internal/access/getrole_test.go @@ -1,17 +1,13 @@ package access import ( - "errors" - mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" "testing" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" gen "github.com/speakeasy-api/gram/server/gen/access" "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/contextvalues" - thirdpartyworkos "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" ) @@ -23,27 +19,22 @@ func TestService_GetRole(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_3", "user_3", "admin"), - // user_workos_only has never logged into Gram — should not be counted - mockMember(mockidp.MockOrgID, "membership_workos_only", "user_workos_only", "custom-builder"), - }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + customID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", "user3@test.com", "User 3", "user_3", "membership_3") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember("", "membership_2", "user_2", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", mockMember("", "membership_3", "user_3", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_workos_only", "user_workos_only", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeMCPConnect, authz.WildcardResource) - role, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: "role_custom"}) + role, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: customID}) require.NoError(t, err) - require.Equal(t, "role_custom", role.ID) + require.Equal(t, customID, role.ID) require.Equal(t, "Custom Builder", role.Name) require.Equal(t, "Can build selected resources", role.Description) require.False(t, role.IsSystem) @@ -66,9 +57,8 @@ func TestService_GetRole_NotFound(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{}, nil).Once() - _, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: "role_missing"}) + _, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: "00000000-0000-0000-0000-000000000001"}) require.Error(t, err) require.Contains(t, err.Error(), "role not found") } @@ -103,13 +93,12 @@ func TestService_GetRole_OrganizationNotLinkedToWorkOS(t *testing.T) { require.Contains(t, err.Error(), "organization is not linked to WorkOS") } -func TestService_GetRole_WorkOSListRolesFailure(t *testing.T) { +func TestService_GetRole_InvalidID(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role(nil), errors.New("workos unavailable")).Once() _, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: "role_custom"}) require.Error(t, err) - require.Contains(t, err.Error(), "list roles from workos") + require.Contains(t, err.Error(), "invalid role ID") } diff --git a/server/internal/access/impl.go b/server/internal/access/impl.go index 979fb03c3c..b3cff10018 100644 --- a/server/internal/access/impl.go +++ b/server/internal/access/impl.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "log/slog" - "regexp" "strings" "time" @@ -35,29 +34,11 @@ import ( orgrepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" "github.com/speakeasy-api/gram/server/internal/productfeatures" pfRepo "github.com/speakeasy-api/gram/server/internal/productfeatures/repo" - "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" usersrepo "github.com/speakeasy-api/gram/server/internal/users/repo" ) -var ( - errConnectedUserNotFound = errors.New("connected user not found") - // Custom role names become stable slugs and user-facing identifiers, so keep - // them to a predictable ASCII set instead of silently normalizing symbols. - validRoleNamePattern = regexp.MustCompile(`^[A-Za-z0-9 _-]+$`) -) - -type RoleProvider interface { - ListRoles(ctx context.Context, orgID string) ([]workos.Role, error) - CreateRole(ctx context.Context, orgID string, opts workos.CreateRoleOpts) (*workos.Role, error) - UpdateRole(ctx context.Context, orgID string, roleSlug string, opts workos.UpdateRoleOpts) (*workos.Role, error) - DeleteRole(ctx context.Context, orgID string, roleSlug string) error - ListMembers(ctx context.Context, orgID string) ([]workos.Member, error) - UpdateMemberRole(ctx context.Context, membershipID string, roleSlug string) (*workos.Member, error) - GetUser(ctx context.Context, userID string) (*workos.User, error) - ListOrgUsers(ctx context.Context, orgID string) (map[string]workos.User, error) - GetOrgMembership(ctx context.Context, workOSUserID, workOSOrgID string) (*workos.Member, error) -} +var errConnectedUserNotFound = errors.New("connected user not found") // FeatureCacheWriter updates the Redis cache entry for a feature flag after a // direct DB write, keeping the cache consistent with the authoritative state. @@ -72,7 +53,7 @@ type Service struct { chConn driver.Conn auth *auth.Auth authz *authz.Engine - roles RoleProvider + roleMgr *RoleManager featureCache FeatureCacheWriter audit *audit.Logger } @@ -86,7 +67,7 @@ func NewService( db *pgxpool.Pool, chConn driver.Conn, sessions *sessions.Manager, - roles RoleProvider, + roleMgr *RoleManager, authz *authz.Engine, featureCache FeatureCacheWriter, auditLogger *audit.Logger, @@ -100,7 +81,7 @@ func NewService( chConn: chConn, auth: auth.New(logger, db, sessions, authz), authz: authz, - roles: roles, + roleMgr: roleMgr, featureCache: featureCache, audit: auditLogger, } @@ -120,10 +101,9 @@ func (s *Service) APIKeyAuth(ctx context.Context, key string, schema *security.A return s.auth.Authorize(ctx, key, schema) } -// ListRoles treats WorkOS as the source of truth for role records while Gram -// remains the source of truth for role grants. +// ListRoles reads local role records and enriches them with Gram's local grant state. func (s *Service) ListRoles(ctx context.Context, _ *gen.ListRolesPayload) (*gen.ListRolesResult, error) { - ac, workosOrgID, err := s.roleOrgContext(ctx) + ac, _, err := s.roleOrgContext(ctx) if err != nil { return nil, err } @@ -135,36 +115,13 @@ func (s *Service) ListRoles(ctx context.Context, _ *gen.ListRolesPayload) (*gen. attr.UserID(ac.UserID), ) - wRoles, err := s.roles.ListRoles(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, s.logger) - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, s.logger) - } - memberCounts, err := s.localMemberCounts(ctx, ac.ActiveOrganizationID, members) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "count local members by role").Log(ctx, s.logger) - } - - roles := make([]*gen.Role, 0, len(wRoles)) - for _, wr := range wRoles { - role, err := buildRole(ctx, s.logger, s.db, ac.ActiveOrganizationID, wr, memberCounts[wr.Slug]) - if err != nil { - return nil, err - } - roles = append(roles, role) - } - - return &gen.ListRolesResult{Roles: roles}, nil + return s.roleMgr.ListRoles(ctx, ac.ActiveOrganizationID) } -// GetRole returns the WorkOS role definition enriched with Gram's local grant +// GetRole returns the role definition enriched with Gram's local grant // state so callers see the complete effective role configuration in one place. func (s *Service) GetRole(ctx context.Context, payload *gen.GetRolePayload) (*gen.Role, error) { - ac, workosOrgID, err := s.roleOrgContext(ctx) + ac, _, err := s.roleOrgContext(ctx) if err != nil { return nil, err } @@ -177,35 +134,9 @@ func (s *Service) GetRole(ctx context.Context, payload *gen.GetRolePayload) (*ge attr.AccessRoleID(payload.ID), ) - wRoles, err := s.roles.ListRoles(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, s.logger) - } - - role, ok := findRoleByID(wRoles, payload.ID) - if !ok { - return nil, oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, s.logger) - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, s.logger) - } - - memberCounts, err := s.localMemberCounts(ctx, ac.ActiveOrganizationID, members) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "count local members by role").Log(ctx, s.logger) - } - - return buildRole(ctx, s.logger, s.db, ac.ActiveOrganizationID, role, memberCounts[role.Slug]) + return s.roleMgr.GetRoleByID(ctx, ac.ActiveOrganizationID, payload.ID) } -// CreateRole creates a role for a user of a given organization. -// It is an idempotent operation intentionally ordered so that member assignment happens last. -// If WorkOS role creation succeeds but local grant sync fails, we return an -// error with no users assigned to the new role. That leaves a partially -// created role behind, but keeps the outcome safe and retryable: repeating the -// request can finish configuration without having granted accidental access. func (s *Service) CreateRole(ctx context.Context, payload *gen.CreateRolePayload) (*gen.Role, error) { ac, workosOrgID, err := s.roleOrgContext(ctx) if err != nil { @@ -215,136 +146,36 @@ func (s *Service) CreateRole(ctx context.Context, payload *gen.CreateRolePayload return nil, err } - roleSlug, slugErr := slugify(payload.Name) - if slugErr != nil { - return nil, slugErr - } logger := s.logger.With( attr.SlogOrganizationID(ac.ActiveOrganizationID), attr.SlogUserID(ac.UserID), - attr.SlogAccessRoleSlug(roleSlug), - ) - trace.SpanFromContext(ctx).SetAttributes( - attr.OrganizationID(ac.ActiveOrganizationID), - attr.UserID(ac.UserID), - attr.AccessRoleSlug(roleSlug), ) - - wr, err := s.roles.CreateRole(ctx, workosOrgID, workos.CreateRoleOpts{ - Name: payload.Name, - Slug: roleSlug, - Description: payload.Description, - }) - var apiErr *workos.APIError - switch { - case errors.As(err, &apiErr) && apiErr.StatusCode == 409: - wRoles, listErr := s.roles.ListRoles(ctx, workosOrgID) - if listErr != nil { - return nil, oops.E(oops.CodeUnexpected, listErr, "list roles after create conflict").Log(ctx, logger) - } - - var existingRole workos.Role - ok := false - for _, candidate := range wRoles { - if candidate.Slug == roleSlug { - existingRole = candidate - ok = true - break - } - } - if !ok { - return nil, oops.E(oops.CodeUnexpected, err, "create role in workos").Log(ctx, logger) - } - - wr = &existingRole - case err != nil: - return nil, oops.E(oops.CodeUnexpected, err, "create role in workos").Log(ctx, logger) - } trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), - attr.AccessRoleID(wr.ID), ) - // Stop before assigning members if grant sync fails. That can leave behind a - // newly created WorkOS role with no local grants, but it avoids assigning users - // to a role whose effective permissions are incomplete or unknown. Returning an - // error makes the setup retryable without creating accidental access. - if err := authz.SyncGrants(ctx, s.logger, s.db, ac.ActiveOrganizationID, wr.Slug, roleGrantPayloads(payload.Grants)); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "sync grants for created role").Log(ctx, logger) - } - - assignedWorkosIDs := make([]string, 0, len(payload.MemberIds)) - if len(payload.MemberIds) > 0 { - // payload.MemberIds are Gram user IDs (returned by ListMembers). - // Resolve them to WorkOS user IDs so we can look up WorkOS memberships. - gramToWorkos, err := gramToWorkosIDMap(ctx, s.db, payload.MemberIds) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "resolve gram user ids to workos ids").Log(ctx, logger) - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) - } - - membershipByUser := membershipsByUserID(members) - - for _, gramID := range payload.MemberIds { - workosID, ok := gramToWorkos[gramID] - if !ok { - continue - } - membershipID, ok := membershipByUser[workosID] - if !ok { - continue - } - - if _, err := s.roles.UpdateMemberRole(ctx, membershipID, wr.Slug); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "assign members to created role").Log(ctx, logger) - } - - assignedWorkosIDs = append(assignedWorkosIDs, workosID) - } - s.authz.InvalidateAllRoleCaches(ctx, ac.ActiveOrganizationID) - } - - // Only count assigned members who have local Gram accounts and are - // connected to this org, consistent with how ListMembers filters users. - assignedCount := 0 - if len(assignedWorkosIDs) > 0 { - localRows, err := usersrepo.New(s.db).GetConnectedUsersByWorkosIDs(ctx, usersrepo.GetConnectedUsersByWorkosIDsParams{ - WorkosIds: assignedWorkosIDs, - OrganizationID: ac.ActiveOrganizationID, - }) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, fmt.Errorf("get connected users by workos ids: %w", err), "resolve local assigned members").Log(ctx, logger) - } - assignedCount = len(localRows) - } - - createdRole, err := buildRole(ctx, logger, s.db, ac.ActiveOrganizationID, *wr, assignedCount) + created, err := s.roleMgr.CreateRole(ctx, ac.ActiveOrganizationID, workosOrgID, payload) if err != nil { return nil, err } + logger = logger.With(attr.SlogAccessRoleSlug(created.Slug)) if err := s.audit.LogAccessRoleCreate(ctx, s.db, audit.LogAccessRoleCreateEvent{ OrganizationID: ac.ActiveOrganizationID, Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), ActorDisplayName: ac.Email, ActorSlug: nil, - RoleID: wr.ID, - RoleName: createdRole.Name, - RoleSlug: wr.Slug, + RoleID: created.Role.ID, + RoleName: created.Role.Name, + RoleSlug: created.Slug, }); err != nil { return nil, oops.E(oops.CodeUnexpected, err, "log access role creation").Log(ctx, logger) } - return createdRole, nil + return created.Role, nil } -// UpdateRole preserves the same split of responsibilities as creation: WorkOS -// owns role identity and membership, while Gram owns the role's grant set. func (s *Service) UpdateRole(ctx context.Context, payload *gen.UpdateRolePayload) (*gen.Role, error) { ac, workosOrgID, err := s.roleOrgContext(ctx) if err != nil { @@ -364,127 +195,30 @@ func (s *Service) UpdateRole(ctx context.Context, payload *gen.UpdateRolePayload attr.AccessRoleID(payload.ID), ) - wRoles, err := s.roles.ListRoles(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, logger) - } - - currentRole, ok := findRoleByID(wRoles, payload.ID) - if !ok { - return nil, oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, logger) - } - logger = logger.With(attr.SlogAccessRoleSlug(currentRole.Slug)) - trace.SpanFromContext(ctx).SetAttributes( - attr.OrganizationID(ac.ActiveOrganizationID), - attr.UserID(ac.UserID), - attr.AccessRoleSlug(currentRole.Slug), - ) - sysRole := isSystemRole(currentRole.Slug) - if sysRole && (payload.Name != nil || payload.Description != nil || payload.Grants != nil) { - return nil, oops.E(oops.CodeBadRequest, nil, "system role properties cannot be updated, only member assignment is allowed").Log(ctx, logger) - } - if sysRole && payload.MemberIds == nil { - return nil, oops.E(oops.CodeBadRequest, nil, "system role update requires member_ids").Log(ctx, logger) - } - if payload.Name != nil { - if _, err := slugify(*payload.Name); err != nil { - return nil, err - } - } - - membersBefore, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) - } - memberCountsBefore, err := s.localMemberCounts(ctx, ac.ActiveOrganizationID, membersBefore) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "count local members by role").Log(ctx, logger) - } - existingRole, err := buildRole(ctx, logger, s.db, ac.ActiveOrganizationID, currentRole, memberCountsBefore[currentRole.Slug]) - if err != nil { - return nil, err - } - - // System roles are immutable in WorkOS — skip the role update call and grant sync. - var updatedRole *workos.Role - if sysRole { - updatedRole = ¤tRole - } else { - updatedRole, err = s.roles.UpdateRole(ctx, workosOrgID, currentRole.Slug, workos.UpdateRoleOpts{ - Name: payload.Name, - Description: payload.Description, - }) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "update role in workos").Log(ctx, logger) - } - - // As with role creation, member reassignment happens after local grant sync so - // a failed sync never leaves users attached to a role with incomplete access. - if payload.Grants != nil { - if err := authz.SyncGrants(ctx, s.logger, s.db, ac.ActiveOrganizationID, currentRole.Slug, roleGrantPayloads(payload.Grants)); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "sync grants for updated role").Log(ctx, logger) - } - } - } - - if payload.MemberIds != nil { - gramToWorkos, err := gramToWorkosIDMap(ctx, s.db, payload.MemberIds) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "resolve gram user ids to workos ids").Log(ctx, logger) - } - - membershipByUser := membershipsByUserID(membersBefore) - - for _, gramID := range payload.MemberIds { - workosID, ok := gramToWorkos[gramID] - if !ok { - continue - } - membershipID, ok := membershipByUser[workosID] - if !ok { - continue - } - - if _, err := s.roles.UpdateMemberRole(ctx, membershipID, currentRole.Slug); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "assign members to updated role").Log(ctx, logger) - } - } - s.authz.InvalidateAllRoleCaches(ctx, ac.ActiveOrganizationID) - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) - } - - memberCounts, err := s.localMemberCounts(ctx, ac.ActiveOrganizationID, members) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "count local members by role").Log(ctx, logger) - } - updatedRoleView, err := buildRole(ctx, logger, s.db, ac.ActiveOrganizationID, *updatedRole, memberCounts[updatedRole.Slug]) + updated, err := s.roleMgr.UpdateRole(ctx, ac.ActiveOrganizationID, workosOrgID, payload) if err != nil { return nil, err } + logger = logger.With(attr.SlogAccessRoleSlug(updated.Role.Slug)) + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(updated.Role.Slug)) if err := s.audit.LogAccessRoleUpdate(ctx, s.db, audit.LogAccessRoleUpdateEvent{ OrganizationID: ac.ActiveOrganizationID, Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), ActorDisplayName: ac.Email, ActorSlug: nil, - RoleID: updatedRole.ID, - RoleName: updatedRoleView.Name, - RoleSlug: updatedRole.Slug, - RoleSnapshotBefore: existingRole, - RoleSnapshotAfter: updatedRoleView, + RoleID: updated.Role.ID, + RoleName: updated.After.Name, + RoleSlug: updated.Role.Slug, + RoleSnapshotBefore: updated.Before, + RoleSnapshotAfter: updated.After, }); err != nil { return nil, oops.E(oops.CodeUnexpected, err, "log access role update").Log(ctx, logger) } - return updatedRoleView, nil + return updated.After, nil } -// DeleteRole removes local grants before deleting the WorkOS role so retries can -// still complete cleanup if the external delete fails. func (s *Service) DeleteRole(ctx context.Context, payload *gen.DeleteRolePayload) error { ac, workosOrgID, err := s.roleOrgContext(ctx) if err != nil { @@ -504,67 +238,21 @@ func (s *Service) DeleteRole(ctx context.Context, payload *gen.DeleteRolePayload attr.AccessRoleID(payload.ID), ) - wRoles, err := s.roles.ListRoles(ctx, workosOrgID) - if err != nil { - return oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, logger) - } - - currentRole, ok := findRoleByID(wRoles, payload.ID) - if !ok { - return oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, logger) - } - logger = logger.With(attr.SlogAccessRoleSlug(currentRole.Slug)) - trace.SpanFromContext(ctx).SetAttributes( - attr.OrganizationID(ac.ActiveOrganizationID), - attr.UserID(ac.UserID), - attr.AccessRoleSlug(currentRole.Slug), - ) - if isSystemRole(currentRole.Slug) { - return oops.E(oops.CodeBadRequest, nil, "system roles cannot be deleted").Log(ctx, logger) - } - - // WorkOS rejects deleting a role that still has members assigned, so move - // any assigned members to the default member role first. - members, err := s.roles.ListMembers(ctx, workosOrgID) + deletedRole, err := s.roleMgr.DeleteRole(ctx, ac.ActiveOrganizationID, workosOrgID, payload.ID) if err != nil { - return oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) - } - reassigned := false - for _, m := range members { - if m.RoleSlug != currentRole.Slug { - continue - } - if _, err := s.roles.UpdateMemberRole(ctx, m.ID, authz.SystemRoleMember); err != nil { - if reassigned { - s.authz.InvalidateAllRoleCaches(ctx, ac.ActiveOrganizationID) - } - return oops.E(oops.CodeUnexpected, err, "reassign member to default role").Log(ctx, logger) - } - reassigned = true - } - if reassigned { - s.authz.InvalidateAllRoleCaches(ctx, ac.ActiveOrganizationID) - } - - if _, err := repo.New(s.db).DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ - OrganizationID: ac.ActiveOrganizationID, - PrincipalUrn: urn.NewPrincipal(urn.PrincipalTypeRole, currentRole.Slug), - }); err != nil { - return oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, logger) - } - - if err := s.roles.DeleteRole(ctx, workosOrgID, currentRole.Slug); err != nil { - return oops.E(oops.CodeUnexpected, err, "delete role in workos").Log(ctx, logger) + return err } + logger = logger.With(attr.SlogAccessRoleSlug(deletedRole.Slug)) + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(deletedRole.Slug)) if err := s.audit.LogAccessRoleDelete(ctx, s.db, audit.LogAccessRoleDeleteEvent{ OrganizationID: ac.ActiveOrganizationID, Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), ActorDisplayName: ac.Email, ActorSlug: nil, - RoleID: currentRole.ID, - RoleName: currentRole.Name, - RoleSlug: currentRole.Slug, + RoleID: deletedRole.ID, + RoleName: deletedRole.Name, + RoleSlug: deletedRole.Slug, }); err != nil { return oops.E(oops.CodeUnexpected, err, "log access role deletion").Log(ctx, logger) } @@ -603,8 +291,7 @@ func (s *Service) ListScopes(ctx context.Context, _ *gen.ListScopesPayload) (*ge // ListMembers follows the original access API contract by returning WorkOS user // identifiers while decorating them with the role information the UI needs. func (s *Service) ListMembers(ctx context.Context, _ *gen.ListMembersPayload) (*gen.ListMembersResult, error) { - _, workosOrgID, err := s.roleOrgContext(ctx) - ac, _ := s.authContext(ctx) + ac, _, err := s.roleOrgContext(ctx) if err != nil { return nil, err } @@ -616,71 +303,7 @@ func (s *Service) ListMembers(ctx context.Context, _ *gen.ListMembersPayload) (* attr.UserID(ac.UserID), ) - roles, err := s.roles.ListRoles(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, s.logger) - } - roleIDBySlug := make(map[string]string, len(roles)) - for _, role := range roles { - roleIDBySlug[role.Slug] = role.ID - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, s.logger) - } - - users, err := s.roles.ListOrgUsers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list org users from workos").Log(ctx, s.logger) - } - - // Batch-resolve WorkOS user IDs to local Gram users, filtering to only - // those connected to this organization via organization_user_relationships. - // This single joined query prevents the list from surfacing users who - // exist in Gram but aren't connected to the current org (which would - // cause UpdateMemberRole to fail). - workosIDs := make([]string, 0, len(users)) - for workosUID := range users { - workosIDs = append(workosIDs, workosUID) - } - localUserRows, err := usersrepo.New(s.db).GetConnectedUsersByWorkosIDs(ctx, usersrepo.GetConnectedUsersByWorkosIDsParams{ - WorkosIds: workosIDs, - OrganizationID: ac.ActiveOrganizationID, - }) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "resolve connected users by workos ids").Log(ctx, s.logger) - } - localUsers := make(map[string]usersrepo.User, len(localUserRows)) - for _, u := range localUserRows { - if u.WorkosID.Valid { - localUsers[u.WorkosID.String] = u - } - } - - result := make([]*gen.AccessMember, 0, len(members)) - for _, member := range members { - user, ok := users[member.UserID] - if !ok { - continue - } - - local, ok := localUsers[member.UserID] - if !ok { - continue - } - - result = append(result, &gen.AccessMember{ - ID: local.ID, - Name: formatUserName(user), - Email: user.Email, - PhotoURL: conv.FromPGText[string](local.PhotoUrl), - RoleID: roleIDBySlug[member.RoleSlug], - JoinedAt: conv.Default(member.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), - }) - } - - return &gen.ListMembersResult{Members: result}, nil + return s.roleMgr.ListMembers(ctx, ac.ActiveOrganizationID) } // ListGrants returns the effective grants for the current user by combining @@ -700,7 +323,7 @@ func (s *Service) ListGrants(ctx context.Context, _ *gen.ListGrantsPayload) (*ge return &gen.ListUserGrantsResult{Grants: allScopesGrants()}, nil } - ac, workosOrgID, err := s.roleOrgContext(ctx) + ac, _, err := s.roleOrgContext(ctx) if err != nil { return nil, err } @@ -734,9 +357,9 @@ func (s *Service) ListGrants(ctx context.Context, _ *gen.ListGrantsPayload) (*ge } principals := []urn.Principal{urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID)} - roleSlugs, err := s.memberRoleSlugs(ctx, workosOrgID, connectedUser.WorkosID.String) + roleSlugs, err := s.roleMgr.MemberRoleSlugs(ctx, ac.ActiveOrganizationID, connectedUser.WorkosID.String) if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) + return nil, oops.E(oops.CodeUnexpected, err, "list member roles").Log(ctx, logger) } for _, roleSlug := range roleSlugs { principals = append(principals, urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug)) @@ -757,7 +380,7 @@ func (s *Service) ListGrants(ctx context.Context, _ *gen.ListGrantsPayload) (*ge // UpdateMemberRole is intentionally stricter than member listing: it only // mutates access for users Gram knows are connected to the local organization. func (s *Service) UpdateMemberRole(ctx context.Context, payload *gen.UpdateMemberRolePayload) (*gen.AccessMember, error) { - ac, workosOrgID, err := s.roleOrgContext(ctx) + ac, _, err := s.roleOrgContext(ctx) if err != nil { return nil, err } @@ -777,21 +400,11 @@ func (s *Service) UpdateMemberRole(ctx context.Context, payload *gen.UpdateMembe attr.AccessRoleID(payload.RoleID), ) - roles, err := s.roles.ListRoles(ctx, workosOrgID) + memberUpdate, err := s.roleMgr.UpdateMemberRole(ctx, ac.ActiveOrganizationID, payload.UserID, payload.RoleID) if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, logger) - } - - roleSlug := "" - for _, role := range roles { - if role.ID == payload.RoleID { - roleSlug = role.Slug - break - } - } - if roleSlug == "" { - return nil, oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, logger) + return nil, err } + roleSlug := memberUpdate.RoleSlug logger = logger.With(attr.SlogAccessRoleSlug(roleSlug)) trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), @@ -799,87 +412,21 @@ func (s *Service) UpdateMemberRole(ctx context.Context, payload *gen.UpdateMembe attr.AccessRoleSlug(roleSlug), ) - connectedUser, err := connectedUser(ctx, s.db, ac.ActiveOrganizationID, payload.UserID) - switch { - case errors.Is(err, errConnectedUserNotFound): - return nil, oops.E(oops.CodeNotFound, nil, "member has not joined this organization").Log(ctx, logger) - case err != nil: - return nil, oops.E(oops.CodeUnexpected, err, "load connected user").Log(ctx, logger) - } - if !connectedUser.WorkosID.Valid || connectedUser.WorkosID.String == "" { - return nil, oops.E(oops.CodeBadRequest, nil, "member is not linked to WorkOS").Log(ctx, logger) - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) - } - - roleIDBySlug := make(map[string]string, len(roles)) - for _, role := range roles { - roleIDBySlug[role.Slug] = role.ID - } - - membershipID := "" - var existingMember workos.Member - for _, member := range members { - if member.UserID == connectedUser.WorkosID.String { - membershipID = member.ID - existingMember = member - break - } - } - if membershipID == "" { - return nil, oops.E(oops.CodeNotFound, nil, "member not found").Log(ctx, logger) - } - - updatedMember, err := s.roles.UpdateMemberRole(ctx, membershipID, roleSlug) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "update member role in workos").Log(ctx, logger) - } - s.authz.InvalidateRoleCache(ctx, payload.UserID, ac.ActiveOrganizationID) - - users, err := s.roles.ListOrgUsers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list org users from workos").Log(ctx, logger) - } - user, ok := users[updatedMember.UserID] - if !ok { - return nil, oops.E(oops.CodeNotFound, nil, "member user not found").Log(ctx, logger) - } - - beforeMember := &gen.AccessMember{ - ID: connectedUser.ID, - Name: formatUserName(user), - Email: user.Email, - PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), - RoleID: roleIDBySlug[existingMember.RoleSlug], - JoinedAt: conv.Default(existingMember.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), - } - afterMember := &gen.AccessMember{ - ID: connectedUser.ID, - Name: formatUserName(user), - Email: user.Email, - PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), - RoleID: payload.RoleID, - JoinedAt: conv.Default(updatedMember.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), - } - if err := s.audit.LogAccessMemberRoleUpdate(ctx, s.db, audit.LogAccessMemberRoleUpdateEvent{ OrganizationID: ac.ActiveOrganizationID, Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), ActorDisplayName: ac.Email, ActorSlug: nil, - MemberID: connectedUser.ID, - MemberName: afterMember.Name, - MemberEmail: afterMember.Email, - MemberSnapshotBefore: beforeMember, - MemberSnapshotAfter: afterMember, + MemberID: memberUpdate.UserID, + MemberName: memberUpdate.After.Name, + MemberEmail: memberUpdate.After.Email, + MemberSnapshotBefore: memberUpdate.Before, + MemberSnapshotAfter: memberUpdate.After, }); err != nil { return nil, oops.E(oops.CodeUnexpected, err, "log access member role update").Log(ctx, logger) } - return afterMember, nil + return memberUpdate.After, nil } func (s *Service) authContext(ctx context.Context) (*contextvalues.AuthContext, error) { @@ -980,137 +527,6 @@ func authzSelectorToGen(sel authz.Selector) *gen.Selector { return s } -func formatUserName(user workos.User) string { - switch { - case user.FirstName != "" && user.LastName != "": - return user.FirstName + " " + user.LastName - case user.FirstName != "": - return user.FirstName - case user.LastName != "": - return user.LastName - default: - return user.Email - } -} - -// localMemberCounts counts WorkOS members per role slug, but only for members -// who have a local Gram account and are connected to the given organization -// via organization_user_relationships. This ensures counts match what -// ListMembers returns. -func (s *Service) localMemberCounts(ctx context.Context, organizationID string, members []workos.Member) (map[string]int, error) { - workosIDs := make([]string, 0, len(members)) - for _, m := range members { - workosIDs = append(workosIDs, m.UserID) - } - localRows, err := usersrepo.New(s.db).GetConnectedUsersByWorkosIDs(ctx, usersrepo.GetConnectedUsersByWorkosIDsParams{ - WorkosIds: workosIDs, - OrganizationID: organizationID, - }) - if err != nil { - return nil, fmt.Errorf("get connected users by workos ids: %w", err) - } - localSet := make(map[string]struct{}, len(localRows)) - for _, u := range localRows { - if u.WorkosID.Valid { - localSet[u.WorkosID.String] = struct{}{} - } - } - counts := make(map[string]int) - for _, m := range members { - if _, ok := localSet[m.UserID]; ok { - counts[m.RoleSlug]++ - } - } - return counts, nil -} - -// gramToWorkosIDMap resolves Gram user IDs to WorkOS user IDs. -// Dashboard sends Gram IDs (from ListMembers), but WorkOS membership lookups -// require WorkOS user IDs. -func gramToWorkosIDMap(ctx context.Context, db *pgxpool.Pool, gramIDs []string) (map[string]string, error) { - users, err := usersrepo.New(db).GetUsersByIDs(ctx, gramIDs) - if err != nil { - return nil, fmt.Errorf("get users by ids: %w", err) - } - m := make(map[string]string, len(users)) - for _, u := range users { - if u.WorkosID.Valid && u.WorkosID.String != "" { - m[u.ID] = u.WorkosID.String - } - } - return m, nil -} - -func membershipsByUserID(members []workos.Member) map[string]string { - membershipByUser := make(map[string]string, len(members)) - for _, member := range members { - membershipByUser[member.UserID] = member.ID - } - - return membershipByUser -} - -func (s *Service) memberRoleSlugs(ctx context.Context, workosOrgID string, workosUserID string) ([]string, error) { - if workosUserID == "" { - return nil, nil - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, fmt.Errorf("list members for role lookup: %w", err) - } - - roleSlugs := make([]string, 0, len(members)) - seenRoleSlugs := make(map[string]struct{}, len(members)) - - for _, member := range members { - if member.UserID != workosUserID || member.RoleSlug == "" { - continue - } - if _, ok := seenRoleSlugs[member.RoleSlug]; ok { - continue - } - - seenRoleSlugs[member.RoleSlug] = struct{}{} - roleSlugs = append(roleSlugs, member.RoleSlug) - } - - return roleSlugs, nil -} - -func findRoleByID(roles []workos.Role, id string) (workos.Role, bool) { - for _, role := range roles { - if role.ID == id { - return role, true - } - } - - var zero workos.Role - return zero, false -} - -func buildRole(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, organizationID string, role workos.Role, memberCount int) (*gen.Role, error) { - grants, err := authz.GrantsForRole(ctx, logger, db, organizationID, role.Slug) - if err != nil { - return nil, err - } - genGrants := make([]*gen.RoleGrant, 0, len(grants)) - for _, g := range grants { - genGrants = append(genGrants, scopedGrantToGenRoleGrant(g)) - } - - return &gen.Role{ - ID: role.ID, - Name: role.Name, - Description: role.Description, - IsSystem: isSystemRole(role.Slug), - Grants: genGrants, - MemberCount: memberCount, - CreatedAt: conv.Default(role.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), - UpdatedAt: conv.Default(role.UpdatedAt, time.Time{}.UTC().Format(time.RFC3339)), - }, nil -} - func scopedGrantToGenRoleGrant(g *authz.ScopedGrant) *gen.RoleGrant { var selectors []*gen.Selector for _, sel := range g.Selectors { @@ -1263,21 +679,6 @@ func (s *Service) requireSuperAdmin(ctx context.Context) (*contextvalues.AuthCon return ac, nil } -func slugify(name string) (string, error) { - slug := conv.ToSlug(strings.ReplaceAll(name, "_", " ")) - if slug == "" { - return "", oops.E(oops.CodeBadRequest, nil, "role name must contain at least one letter or digit") - } - if !validRoleNamePattern.MatchString(name) { - return "", oops.E(oops.CodeBadRequest, nil, "role name contains invalid characters") - } - if !strings.HasPrefix(slug, "org-") { - slug = "org-" + slug - } - - return slug, nil -} - type challengeUserInfo struct { email string photoURL *string diff --git a/server/internal/access/listmembers_test.go b/server/internal/access/listmembers_test.go index 2522ba3212..89c8eb8d73 100644 --- a/server/internal/access/listmembers_test.go +++ b/server/internal/access/listmembers_test.go @@ -1,16 +1,12 @@ package access import ( - "errors" - mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" "testing" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" gen "github.com/speakeasy-api/gram/server/gen/access" "github.com/speakeasy-api/gram/server/internal/contextvalues" - thirdpartyworkos "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" ) func TestService_ListMembers(t *testing.T) { @@ -23,18 +19,10 @@ func TestService_ListMembers(t *testing.T) { seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "grace@example.com", "Grace", "user_2", "membership_2") - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - "user_2": mockUser("user_2", "Grace", "", "grace@example.com"), - }, nil).Once() + adminID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember("", "membership_2", "user_2", "custom-builder")) result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) @@ -48,11 +36,11 @@ func TestService_ListMembers(t *testing.T) { // IDs should be Gram user IDs, not WorkOS user IDs. require.Equal(t, "Ada Lovelace", byID["local_user_1"].Name) require.Equal(t, "ada@example.com", byID["local_user_1"].Email) - require.Equal(t, "role_admin", byID["local_user_1"].RoleID) - require.Equal(t, "2024-11-15T15:04:05Z", byID["local_user_1"].JoinedAt) + require.Equal(t, adminID, byID["local_user_1"].RoleID) + require.NotEmpty(t, byID["local_user_1"].JoinedAt) require.Equal(t, "Grace", byID["local_user_2"].Name) - require.Equal(t, "role_builder", byID["local_user_2"].RoleID) + require.Equal(t, builderID, byID["local_user_2"].RoleID) } func TestService_ListMembers_ExcludesDisconnectedUsers(t *testing.T) { @@ -67,17 +55,9 @@ func TestService_ListMembers_ExcludesDisconnectedUsers(t *testing.T) { // to this org — no row in organization_user_relationships. seedDisconnectedUser(t, ctx, ti.conn, "local_user_2", "grace@example.com", "Grace", "user_2") - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "admin"), - }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - "user_2": mockUser("user_2", "Grace", "", "grace@example.com"), - }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_2", "user_2", "admin")) result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) @@ -86,19 +66,15 @@ func TestService_ListMembers_ExcludesDisconnectedUsers(t *testing.T) { require.Equal(t, "Ada Lovelace", result.Members[0].Name) } -func TestService_ListMembers_WorkOSUsersFailure(t *testing.T) { +func TestService_ListMembers_UsesDatabaseOnly(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User(nil), errors.New("workos unavailable")).Once() - - _, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) - require.Error(t, err) - require.Contains(t, err.Error(), "list org users from workos") + authCtx, _ := contextvalues.GetAuthContext(ctx) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_1", "user_1", "admin")) + + result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) + require.NoError(t, err) + require.Empty(t, result.Members) } diff --git a/server/internal/access/listroles_test.go b/server/internal/access/listroles_test.go index 3409b8ca8b..e6a51873b7 100644 --- a/server/internal/access/listroles_test.go +++ b/server/internal/access/listroles_test.go @@ -1,16 +1,13 @@ package access import ( - mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" "testing" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" gen "github.com/speakeasy-api/gram/server/gen/access" "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/contextvalues" - thirdpartyworkos "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" ) @@ -22,21 +19,16 @@ func TestService_ListRoles(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_3", "user_3", "custom-builder"), - // user_workos_only has never logged into Gram — should not be counted - mockMember(mockidp.MockOrgID, "membership_workos_only", "user_workos_only", "custom-builder"), - }, nil).Once() + adminID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + customID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", "user3@test.com", "User 3", "user_3", "membership_3") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember("", "membership_2", "user_2", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", mockMember("", "membership_3", "user_3", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_workos_only", "user_workos_only", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "admin"), authz.ScopeOrgAdmin, authz.WildcardResource) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-2") @@ -51,7 +43,7 @@ func TestService_ListRoles(t *testing.T) { rolesByID[role.ID] = role } - adminRole := rolesByID["role_admin"] + adminRole := rolesByID[adminID] require.NotNil(t, adminRole) require.Equal(t, "Admin", adminRole.Name) require.True(t, adminRole.IsSystem) @@ -62,7 +54,7 @@ func TestService_ListRoles(t *testing.T) { require.Equal(t, string(authz.ScopeOrgAdmin), adminRole.Grants[0].Scope) require.Nil(t, adminRole.Grants[0].Selectors) - customRole := rolesByID["role_custom"] + customRole := rolesByID[customID] require.NotNil(t, customRole) require.Equal(t, "Custom Builder", customRole.Name) require.False(t, customRole.IsSystem) @@ -96,14 +88,10 @@ func TestService_ListRoles_ExcludesDisconnectedUsersFromMemberCounts(t *testing. // (no organization_user_relationships row). Should not inflate member count. seedDisconnectedUser(t, ctx, ti.conn, "local_user_2", "user2@test.com", "User 2", "user_2") - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - // user_2 appears in WorkOS members but is disconnected locally. - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "admin"), - }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) + // user_2 appears in role assignments but is disconnected locally. + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_2", "user_2", "admin")) result, err := ti.service.ListRoles(ctx, &gen.ListRolesPayload{}) require.NoError(t, err) diff --git a/server/internal/access/listusergrants_test.go b/server/internal/access/listusergrants_test.go index 81f6949db4..98f29955bb 100644 --- a/server/internal/access/listusergrants_test.go +++ b/server/internal/access/listusergrants_test.go @@ -1,19 +1,14 @@ package access import ( - "errors" "testing" - mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" - - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" gen "github.com/speakeasy-api/gram/server/gen/access" "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" "github.com/speakeasy-api/gram/server/internal/contextvalues" - thirdpartyworkos "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" ) @@ -40,13 +35,11 @@ func TestService_ListGrants(t *testing.T) { ctx = contextvalues.SetAuthContext(ctx, authCtx) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, authCtx.UserID, "member@example.com", "Member User", "workos_user_member", "membership_1") + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, authCtx.UserID, mockMember("", "membership_1", "workos_user_member", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeUser, authCtx.UserID), authz.ScopeProjectRead, "project_123") seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeMCPConnect, "tool_456") - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "workos_user_member", "custom-builder"), - }, nil).Once() - result, err := ti.service.ListGrants(ctx, &gen.ListGrantsPayload{}) require.NoError(t, err) require.Len(t, result.Grants, 2) @@ -60,7 +53,7 @@ func TestService_ListGrants(t *testing.T) { require.Equal(t, "tool_456", byScope["mcp:connect"].Selectors[0].ResourceID) } -func TestService_ListGrants_MultipleRoles(t *testing.T) { +func TestService_ListGrants_RoleGrants(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -71,13 +64,10 @@ func TestService_ListGrants_MultipleRoles(t *testing.T) { ctx = contextvalues.SetAuthContext(ctx, authCtx) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, authCtx.UserID, "member@example.com", "Member User", "workos_user_member", "membership_1") + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, authCtx.UserID, mockMember("", "membership_1", "workos_user_member", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project_123") - seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-mcp"), authz.ScopeMCPConnect, "tool_456") - - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "workos_user_member", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "workos_user_member", "custom-mcp"), - }, nil).Once() + seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeMCPConnect, "tool_456") result, err := ti.service.ListGrants(ctx, &gen.ListGrantsPayload{}) require.NoError(t, err) @@ -223,7 +213,7 @@ func TestService_ListGrants_EnterpriseWithoutSessionReturnsFullAccess(t *testing } } -func TestService_ListGrants_WorkOSMembersFailure(t *testing.T) { +func TestService_ListGrants_NoRoleAssignments(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -235,9 +225,7 @@ func TestService_ListGrants_WorkOSMembersFailure(t *testing.T) { seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, authCtx.UserID, "member@example.com", "Member User", "workos_user_member", "membership_1") - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member(nil), errors.New("workos unavailable")).Once() - - _, err := ti.service.ListGrants(ctx, &gen.ListGrantsPayload{}) - require.Error(t, err) - require.Contains(t, err.Error(), "list members from workos") + result, err := ti.service.ListGrants(ctx, &gen.ListGrantsPayload{}) + require.NoError(t, err) + require.Empty(t, result.Grants) } diff --git a/server/internal/access/mock_role_test.go b/server/internal/access/mock_role_test.go index b6e9a8c516..224256313f 100644 --- a/server/internal/access/mock_role_test.go +++ b/server/internal/access/mock_role_test.go @@ -31,14 +31,6 @@ func newMockRoleProvider(t *testing.T) *MockRoleProvider { return roles } -func (m *MockRoleProvider) ListRoles(ctx context.Context, orgID string) ([]thirdpartyworkos.Role, error) { - args := m.Called(ctx, orgID) - if roles, ok := args.Get(0).([]thirdpartyworkos.Role); ok { - return roles, mockErr(args, 1) - } - return nil, mockErr(args, 1) -} - func (m *MockRoleProvider) CreateRole(ctx context.Context, orgID string, opts thirdpartyworkos.CreateRoleOpts) (*thirdpartyworkos.Role, error) { args := m.Called(ctx, orgID, opts) if role, ok := args.Get(0).(*thirdpartyworkos.Role); ok { @@ -60,14 +52,6 @@ func (m *MockRoleProvider) DeleteRole(ctx context.Context, orgID string, roleSlu return mockErr(args, 0) } -func (m *MockRoleProvider) ListMembers(ctx context.Context, orgID string) ([]thirdpartyworkos.Member, error) { - args := m.Called(ctx, orgID) - if members, ok := args.Get(0).([]thirdpartyworkos.Member); ok { - return members, mockErr(args, 1) - } - return nil, mockErr(args, 1) -} - func (m *MockRoleProvider) UpdateMemberRole(ctx context.Context, membershipID string, roleSlug string) (*thirdpartyworkos.Member, error) { args := m.Called(ctx, membershipID, roleSlug) if member, ok := args.Get(0).(*thirdpartyworkos.Member); ok { @@ -76,22 +60,6 @@ func (m *MockRoleProvider) UpdateMemberRole(ctx context.Context, membershipID st return nil, mockErr(args, 1) } -func (m *MockRoleProvider) GetUser(ctx context.Context, userID string) (*thirdpartyworkos.User, error) { - args := m.Called(ctx, userID) - if user, ok := args.Get(0).(*thirdpartyworkos.User); ok { - return user, mockErr(args, 1) - } - return nil, mockErr(args, 1) -} - -func (m *MockRoleProvider) ListOrgUsers(ctx context.Context, orgID string) (map[string]thirdpartyworkos.User, error) { - args := m.Called(ctx, orgID) - if users, ok := args.Get(0).(map[string]thirdpartyworkos.User); ok { - return users, mockErr(args, 1) - } - return nil, mockErr(args, 1) -} - func (m *MockRoleProvider) GetOrgMembership(ctx context.Context, workOSUserID, workOSOrgID string) (*thirdpartyworkos.Member, error) { args := m.Called(ctx, workOSUserID, workOSOrgID) if member, ok := args.Get(0).(*thirdpartyworkos.Member); ok { @@ -133,12 +101,3 @@ func mockMember(orgID, membershipID, userID, roleSlug string) thirdpartyworkos.M CreatedAt: mockMembershipTimestamp, } } - -func mockUser(id, firstName, lastName, email string) thirdpartyworkos.User { - return thirdpartyworkos.User{ - ID: id, - FirstName: firstName, - LastName: lastName, - Email: email, - } -} diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index b53e3944a5..8f8885311d 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -148,3 +148,136 @@ SET workos_deleted_at = @workos_deleted_at, WHERE organization_id = @organization_id AND workos_slug = @workos_slug AND deleted_at IS NULL; + +-- name: ListActiveOrganizationRoles :many +SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at +FROM global_roles +WHERE deleted IS FALSE + AND workos_deleted IS FALSE +UNION ALL +SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at +FROM organization_roles +WHERE organization_id = @organization_id + AND deleted IS FALSE + AND workos_deleted IS FALSE +ORDER BY workos_slug; + +-- name: GetOrganizationRoleByID :one +SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at +FROM global_roles +WHERE global_roles.id = sqlc.arg(id) + AND deleted IS FALSE + AND workos_deleted IS FALSE +UNION ALL +SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at +FROM organization_roles +WHERE organization_id = @organization_id + AND organization_roles.id = sqlc.arg(id) + AND deleted IS FALSE + AND workos_deleted IS FALSE +LIMIT 1; + +-- name: ListOrganizationRoleAssignmentsForOrg :many +SELECT + ora.user_id, + ora.workos_user_id, + ora.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.created_at +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = @organization_id + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL +ORDER BY ora.workos_user_id, role_slug; + +-- name: ListMemberRoleSlugsByWorkosUser :many +SELECT DISTINCT COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = @organization_id + AND ora.workos_user_id = @workos_user_id + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL +ORDER BY role_slug; + +-- name: CountMembersByRoleForOrg :many +SELECT + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + COUNT(*)::bigint AS member_count +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = @organization_id + AND ora.user_id IS NOT NULL + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL +GROUP BY role_slug; + +-- name: ReplaceOrganizationRoleAssignment :exec +WITH input_role_urn AS ( + SELECT 'role:organization:' || id::text AS role_urn + FROM organization_roles + WHERE organization_roles.organization_id = @organization_id + AND organization_roles.workos_slug = sqlc.arg(workos_role_slug) + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn + FROM global_roles + WHERE global_roles.workos_slug = sqlc.arg(workos_role_slug) + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +), +upserted AS ( + INSERT INTO organization_role_assignments ( + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id + ) + SELECT + @organization_id, + @workos_user_id, + @user_id, + input_role_urn.role_urn, + @workos_membership_id, + @workos_updated_at, + @workos_last_event_id + FROM input_role_urn + ON CONFLICT (organization_id, workos_user_id, role_urn) DO UPDATE SET + user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), + workos_membership_id = EXCLUDED.workos_membership_id, + workos_updated_at = EXCLUDED.workos_updated_at, + workos_last_event_id = EXCLUDED.workos_last_event_id, + updated_at = clock_timestamp() + RETURNING role_urn +) +DELETE FROM organization_role_assignments +WHERE organization_role_assignments.organization_id = @organization_id + AND organization_role_assignments.workos_user_id = @workos_user_id + AND EXISTS (SELECT 1 FROM upserted) + AND organization_role_assignments.role_urn NOT IN (SELECT role_urn FROM upserted); diff --git a/server/internal/access/rbac_test.go b/server/internal/access/rbac_test.go index c7a9163398..c1948af065 100644 --- a/server/internal/access/rbac_test.go +++ b/server/internal/access/rbac_test.go @@ -33,14 +33,9 @@ func TestService_ListRoles_AllowsOrgReadGrant(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgRead, Selector: authz.NewSelector(authz.ScopeOrgRead, testAccessAuthContext(t, ctx).ActiveOrganizationID)}) - - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() + authCtx := testAccessAuthContext(t, ctx) + ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgRead, Selector: authz.NewSelector(authz.ScopeOrgRead, authCtx.ActiveOrganizationID)}) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) result, err := ti.service.ListRoles(ctx, &gen.ListRolesPayload{}) require.NoError(t, err) @@ -66,17 +61,12 @@ func TestService_GetRole_AllowsOrgReadGrant(t *testing.T) { authCtx := testAccessAuthContext(t, ctx) ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgRead, Selector: authz.NewSelector(authz.ScopeOrgRead, authCtx.ActiveOrganizationID)}) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - role, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: "role_custom"}) + role, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: roleID}) require.NoError(t, err) - require.Equal(t, "role_custom", role.ID) + require.Equal(t, roleID, role.ID) } func TestService_ListScopes_ForbiddenWithoutOrgReadGrant(t *testing.T) { @@ -122,16 +112,8 @@ func TestService_ListMembers_AllowsOrgReadGrant(t *testing.T) { ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgRead, Selector: authz.NewSelector(authz.ScopeOrgRead, authCtx.ActiveOrganizationID)}) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) @@ -194,10 +176,7 @@ func TestService_UpdateRole_AllowsOrgAdminGrant(t *testing.T) { ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgAdmin, Selector: authz.NewSelector(authz.ScopeOrgAdmin, authCtx.ActiveOrganizationID)}) name := "Updated" - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old")) ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{Name: &name}).Return(&thirdpartyworkos.Role{ ID: "role_custom", Name: name, @@ -206,9 +185,8 @@ func TestService_UpdateRole_AllowsOrgAdminGrant(t *testing.T) { CreatedAt: mockRoleTimestamp, UpdatedAt: mockRoleTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() - role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: "role_custom", Name: &name}) + role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: roleID, Name: &name}) require.NoError(t, err) require.Equal(t, name, role.Name) } @@ -232,14 +210,11 @@ func TestService_DeleteRole_AllowsOrgAdminGrant(t *testing.T) { authCtx := testAccessAuthContext(t, ctx) ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgAdmin, Selector: authz.NewSelector(authz.ScopeOrgAdmin, authCtx.ActiveOrganizationID)}) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.NoError(t, err) } @@ -262,12 +237,10 @@ func TestService_UpdateMemberRole_AllowsOrgAdminGrant(t *testing.T) { authCtx := testAccessAuthContext(t, ctx) ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgAdmin, Selector: authz.NewSelector(authz.ScopeOrgAdmin, authCtx.ActiveOrganizationID)}) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -275,14 +248,10 @@ func TestService_UpdateMemberRole_AllowsOrgAdminGrant(t *testing.T) { RoleSlug: "custom-builder", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - }, nil).Once() - seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_builder"}) + member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.NoError(t, err) - require.Equal(t, "role_builder", member.RoleID) + require.Equal(t, builderID, member.RoleID) } func withRBACGrants(t *testing.T, ctx context.Context, grants ...authz.Grant) context.Context { diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index d778577c69..43786c78ce 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -13,6 +13,51 @@ import ( "github.com/speakeasy-api/gram/server/internal/urn" ) +const countMembersByRoleForOrg = `-- name: CountMembersByRoleForOrg :many +SELECT + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + COUNT(*)::bigint AS member_count +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = $1 + AND ora.user_id IS NOT NULL + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL +GROUP BY role_slug +` + +type CountMembersByRoleForOrgRow struct { + RoleSlug string + MemberCount int64 +} + +func (q *Queries) CountMembersByRoleForOrg(ctx context.Context, organizationID string) ([]CountMembersByRoleForOrgRow, error) { + rows, err := q.db.Query(ctx, countMembersByRoleForOrg, organizationID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []CountMembersByRoleForOrgRow + for rows.Next() { + var i CountMembersByRoleForOrgRow + if err := rows.Scan(&i.RoleSlug, &i.MemberCount); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const deletePrincipalGrant = `-- name: DeletePrincipalGrant :execrows DELETE FROM principal_grants WHERE id = $1 @@ -81,6 +126,50 @@ func (q *Queries) GetGlobalRoleBySlug(ctx context.Context, workosSlug string) (G return i, err } +const getOrganizationRoleByID = `-- name: GetOrganizationRoleByID :one +SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at +FROM global_roles +WHERE global_roles.id = $1 + AND deleted IS FALSE + AND workos_deleted IS FALSE +UNION ALL +SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at +FROM organization_roles +WHERE organization_id = $2 + AND organization_roles.id = $1 + AND deleted IS FALSE + AND workos_deleted IS FALSE +LIMIT 1 +` + +type GetOrganizationRoleByIDParams struct { + ID uuid.UUID + OrganizationID string +} + +type GetOrganizationRoleByIDRow struct { + ID uuid.UUID + WorkosSlug string + WorkosName string + WorkosDescription pgtype.Text + WorkosCreatedAt pgtype.Timestamptz + WorkosUpdatedAt pgtype.Timestamptz +} + +func (q *Queries) GetOrganizationRoleByID(ctx context.Context, arg GetOrganizationRoleByIDParams) (GetOrganizationRoleByIDRow, error) { + row := q.db.QueryRow(ctx, getOrganizationRoleByID, arg.ID, arg.OrganizationID) + var i GetOrganizationRoleByIDRow + err := row.Scan( + &i.ID, + &i.WorkosSlug, + &i.WorkosName, + &i.WorkosDescription, + &i.WorkosCreatedAt, + &i.WorkosUpdatedAt, + ) + return i, err +} + const getOrganizationRoleBySlug = `-- name: GetOrganizationRoleBySlug :one SELECT id, organization_id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, workos_deleted_at, workos_deleted, workos_last_event_id, created_at, updated_at, deleted_at, deleted FROM organization_roles @@ -223,6 +312,56 @@ func (q *Queries) InsertChallengeResolutions(ctx context.Context, arg InsertChal return items, nil } +const listActiveOrganizationRoles = `-- name: ListActiveOrganizationRoles :many +SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at +FROM global_roles +WHERE deleted IS FALSE + AND workos_deleted IS FALSE +UNION ALL +SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at +FROM organization_roles +WHERE organization_id = $1 + AND deleted IS FALSE + AND workos_deleted IS FALSE +ORDER BY workos_slug +` + +type ListActiveOrganizationRolesRow struct { + ID uuid.UUID + WorkosSlug string + WorkosName string + WorkosDescription pgtype.Text + WorkosCreatedAt pgtype.Timestamptz + WorkosUpdatedAt pgtype.Timestamptz +} + +func (q *Queries) ListActiveOrganizationRoles(ctx context.Context, organizationID string) ([]ListActiveOrganizationRolesRow, error) { + rows, err := q.db.Query(ctx, listActiveOrganizationRoles, organizationID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListActiveOrganizationRolesRow + for rows.Next() { + var i ListActiveOrganizationRolesRow + if err := rows.Scan( + &i.ID, + &i.WorkosSlug, + &i.WorkosName, + &i.WorkosDescription, + &i.WorkosCreatedAt, + &i.WorkosUpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listChallengeResolutions = `-- name: ListChallengeResolutions :many SELECT id, organization_id, challenge_id, principal_urn, scope, resource_kind, resource_id, resolution_type, role_slug, resolved_by, created_at FROM authz_challenge_resolutions @@ -270,6 +409,105 @@ func (q *Queries) ListChallengeResolutions(ctx context.Context, arg ListChalleng return items, nil } +const listMemberRoleSlugsByWorkosUser = `-- name: ListMemberRoleSlugsByWorkosUser :many +SELECT DISTINCT COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = $1 + AND ora.workos_user_id = $2 + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL +ORDER BY role_slug +` + +type ListMemberRoleSlugsByWorkosUserParams struct { + OrganizationID string + WorkosUserID string +} + +func (q *Queries) ListMemberRoleSlugsByWorkosUser(ctx context.Context, arg ListMemberRoleSlugsByWorkosUserParams) ([]string, error) { + rows, err := q.db.Query(ctx, listMemberRoleSlugsByWorkosUser, arg.OrganizationID, arg.WorkosUserID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var role_slug string + if err := rows.Scan(&role_slug); err != nil { + return nil, err + } + items = append(items, role_slug) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listOrganizationRoleAssignmentsForOrg = `-- name: ListOrganizationRoleAssignmentsForOrg :many +SELECT + ora.user_id, + ora.workos_user_id, + ora.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.created_at +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = $1 + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL +ORDER BY ora.workos_user_id, role_slug +` + +type ListOrganizationRoleAssignmentsForOrgRow struct { + UserID pgtype.Text + WorkosUserID string + WorkosMembershipID pgtype.Text + RoleSlug string + CreatedAt pgtype.Timestamptz +} + +func (q *Queries) ListOrganizationRoleAssignmentsForOrg(ctx context.Context, organizationID string) ([]ListOrganizationRoleAssignmentsForOrgRow, error) { + rows, err := q.db.Query(ctx, listOrganizationRoleAssignmentsForOrg, organizationID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListOrganizationRoleAssignmentsForOrgRow + for rows.Next() { + var i ListOrganizationRoleAssignmentsForOrgRow + if err := rows.Scan( + &i.UserID, + &i.WorkosUserID, + &i.WorkosMembershipID, + &i.RoleSlug, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listPrincipalGrantsByOrg = `-- name: ListPrincipalGrantsByOrg :many SELECT id, organization_id, principal_urn, principal_type, scope, selectors, created_at, updated_at @@ -382,6 +620,78 @@ func (q *Queries) MarkOrganizationRoleDeleted(ctx context.Context, arg MarkOrgan return result.RowsAffected(), nil } +const replaceOrganizationRoleAssignment = `-- name: ReplaceOrganizationRoleAssignment :exec +WITH input_role_urn AS ( + SELECT 'role:organization:' || id::text AS role_urn + FROM organization_roles + WHERE organization_roles.organization_id = $1 + AND organization_roles.workos_slug = $3 + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn + FROM global_roles + WHERE global_roles.workos_slug = $3 + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +), +upserted AS ( + INSERT INTO organization_role_assignments ( + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id + ) + SELECT + $1, + $2, + $4, + input_role_urn.role_urn, + $5, + $6, + $7 + FROM input_role_urn + ON CONFLICT (organization_id, workos_user_id, role_urn) DO UPDATE SET + user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), + workos_membership_id = EXCLUDED.workos_membership_id, + workos_updated_at = EXCLUDED.workos_updated_at, + workos_last_event_id = EXCLUDED.workos_last_event_id, + updated_at = clock_timestamp() + RETURNING role_urn +) +DELETE FROM organization_role_assignments +WHERE organization_role_assignments.organization_id = $1 + AND organization_role_assignments.workos_user_id = $2 + AND EXISTS (SELECT 1 FROM upserted) + AND organization_role_assignments.role_urn NOT IN (SELECT role_urn FROM upserted) +` + +type ReplaceOrganizationRoleAssignmentParams struct { + OrganizationID string + WorkosUserID string + WorkosRoleSlug string + UserID pgtype.Text + WorkosMembershipID pgtype.Text + WorkosUpdatedAt pgtype.Timestamptz + WorkosLastEventID pgtype.Text +} + +func (q *Queries) ReplaceOrganizationRoleAssignment(ctx context.Context, arg ReplaceOrganizationRoleAssignmentParams) error { + _, err := q.db.Exec(ctx, replaceOrganizationRoleAssignment, + arg.OrganizationID, + arg.WorkosUserID, + arg.WorkosRoleSlug, + arg.UserID, + arg.WorkosMembershipID, + arg.WorkosUpdatedAt, + arg.WorkosLastEventID, + ) + return err +} + const upsertGlobalRole = `-- name: UpsertGlobalRole :exec INSERT INTO global_roles ( workos_slug, diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go new file mode 100644 index 0000000000..81546edca9 --- /dev/null +++ b/server/internal/access/role_manager.go @@ -0,0 +1,721 @@ +package access + +import ( + "context" + "errors" + "log/slog" + "regexp" + "strings" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "go.opentelemetry.io/otel/trace" + + gen "github.com/speakeasy-api/gram/server/gen/access" + "github.com/speakeasy-api/gram/server/internal/access/repo" + "github.com/speakeasy-api/gram/server/internal/attr" + "github.com/speakeasy-api/gram/server/internal/authz" + "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/oops" + "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" + "github.com/speakeasy-api/gram/server/internal/urn" + usersrepo "github.com/speakeasy-api/gram/server/internal/users/repo" +) + +var ErrRoleNotFound = errors.New("role not found") + +var validRoleNamePattern = regexp.MustCompile(`^[A-Za-z0-9 _-]+$`) + +type RoleProvider interface { + CreateRole(ctx context.Context, orgID string, opts workos.CreateRoleOpts) (*workos.Role, error) + UpdateRole(ctx context.Context, orgID string, roleSlug string, opts workos.UpdateRoleOpts) (*workos.Role, error) + DeleteRole(ctx context.Context, orgID string, roleSlug string) error + UpdateMemberRole(ctx context.Context, membershipID string, roleSlug string) (*workos.Member, error) + GetOrgMembership(ctx context.Context, workOSUserID, workOSOrgID string) (*workos.Member, error) +} + +// RoleManager owns role reads from local records and role writes through WorkOS. +type RoleManager struct { + db *pgxpool.Pool + logger *slog.Logger + roles RoleProvider + authz *authz.Engine +} + +// NewRoleManager wires the role manager to the local DB, the WorkOS role client, and the authz engine. +func NewRoleManager(logger *slog.Logger, db *pgxpool.Pool, roles RoleProvider, authzEngine *authz.Engine) *RoleManager { + return &RoleManager{ + db: db, + logger: logger.With(attr.SlogComponent("access.role_manager")), + roles: roles, + authz: authzEngine, + } +} + +// ListRoles returns active roles for an organization from local records and enriches them with local grants and member counts. +func (r *RoleManager) ListRoles(ctx context.Context, gramOrgID string) (*gen.ListRolesResult, error) { + rows, err := repo.New(r.db).ListActiveOrganizationRoles(ctx, gramOrgID) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "list roles").Log(ctx, r.logger) + } + + memberCounts, err := r.memberCounts(ctx, gramOrgID) + if err != nil { + return nil, err + } + + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + roles := make([]*gen.Role, 0, len(rows)) + for _, row := range rows { + role, err := r.roleViewFromLocalRole(ctx, gramOrgID, localRoleFromActiveRow(row), memberCounts[row.WorkosSlug]) + if err != nil { + return nil, err + } + roles = append(roles, role) + } + + return &gen.ListRolesResult{Roles: roles}, nil +} + +// GetRoleByID returns one active role from the local role table with local grants and member count. +func (r *RoleManager) GetRoleByID(ctx context.Context, gramOrgID, id string) (*gen.Role, error) { + role, err := r.getLocalRoleByID(ctx, gramOrgID, id) + if err != nil { + return nil, err + } + + memberCounts, err := r.memberCounts(ctx, gramOrgID) + if err != nil { + return nil, err + } + + return r.roleViewFromLocalRole(ctx, gramOrgID, role, memberCounts[role.Slug]) +} + +type localRoleAssignment struct { + UserID string + WorkosUserID string + MembershipID string + RoleSlug string + CreatedAt string +} + +// ListMembers returns locally known organization members with role IDs resolved from local role assignments. +func (r *RoleManager) ListMembers(ctx context.Context, gramOrgID string) (*gen.ListMembersResult, error) { + roleRows, err := repo.New(r.db).ListActiveOrganizationRoles(ctx, gramOrgID) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "list roles").Log(ctx, r.logger) + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + roles := make(map[string]string, len(roleRows)) + for _, row := range roleRows { + roles[row.WorkosSlug] = row.ID.String() + } + + assignmentRows, err := repo.New(r.db).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + assignments := make([]localRoleAssignment, 0, len(assignmentRows)) + for _, row := range assignmentRows { + assignments = append(assignments, localRoleAssignment{ + UserID: conv.FromPGTextOrEmpty[string](row.UserID), + WorkosUserID: row.WorkosUserID, + MembershipID: conv.FromPGTextOrEmpty[string](row.WorkosMembershipID), + RoleSlug: row.RoleSlug, + CreatedAt: conv.FromPGTimestamptz(row.CreatedAt), + }) + } + + userIDs := make([]string, 0, len(assignments)) + for _, assignment := range assignments { + if assignment.UserID != "" { + userIDs = append(userIDs, assignment.UserID) + } + } + localRows, err := usersrepo.New(r.db).GetUsersByIDs(ctx, userIDs) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "resolve users by ids").Log(ctx, r.logger) + } + localUsers := make(map[string]usersrepo.User, len(localRows)) + for _, u := range localRows { + localUsers[u.ID] = u + } + + result := make([]*gen.AccessMember, 0, len(assignments)) + for _, assignment := range assignments { + user, ok := localUsers[assignment.UserID] + if !ok { + continue + } + + result = append(result, &gen.AccessMember{ + ID: user.ID, + Name: conv.Default(user.DisplayName, user.Email), + Email: user.Email, + PhotoURL: conv.FromPGText[string](user.PhotoUrl), + RoleID: roles[assignment.RoleSlug], + JoinedAt: assignment.CreatedAt, + }) + } + + return &gen.ListMembersResult{Members: result}, nil +} + +type roleCreateResult struct { + Role *gen.Role + Slug string +} + +// CreateRole creates a WorkOS role, upserts the local role record, syncs local grants, and optionally assigns members. +func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID string, payload *gen.CreateRolePayload) (roleCreateResult, error) { + roleSlug, err := slugify(payload.Name) + if err != nil { + return roleCreateResult{}, err + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(roleSlug)) + + wr, err := r.roles.CreateRole(ctx, workosOrgID, workos.CreateRoleOpts{ + Name: payload.Name, + Slug: roleSlug, + Description: payload.Description, + }) + var apiErr *workos.APIError + switch { + case errors.As(err, &apiErr) && apiErr.StatusCode == 409: + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "create role in workos").Log(ctx, r.logger) + case err != nil: + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "create role in workos").Log(ctx, r.logger) + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleID(wr.ID)) + if err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, *wr)); err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + _ = oops.E(oops.CodeUnexpected, err, "upsert local role record after workos write").Log(ctx, r.logger) + } + + if err := authz.SyncGrants(ctx, r.logger, r.db, gramOrgID, wr.Slug, roleGrantPayloads(payload.Grants)); err != nil { + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "sync grants for created role").Log(ctx, r.logger) + } + + assignedCount := 0 + if len(payload.MemberIds) > 0 { + assignedCount, err = r.assignMembersToRole(ctx, gramOrgID, wr.Slug, payload.MemberIds) + if err != nil { + return roleCreateResult{}, err + } + } + + role, err := r.roleViewFromLocalRole(ctx, gramOrgID, localRoleFromWorkOS(*wr), assignedCount) + if err != nil { + return roleCreateResult{}, err + } + + return roleCreateResult{Role: role, Slug: wr.Slug}, nil +} + +type localRole struct { + ID string + Name string + Slug string + Description string + CreatedAt string + UpdatedAt string +} + +type roleUpdateResult struct { + Before *gen.Role + After *gen.Role + Role localRole +} + +// UpdateRole updates an existing role and optionally replaces its assigned members. +func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID string, payload *gen.UpdateRolePayload) (roleUpdateResult, error) { + currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, payload.ID) + if err != nil { + return roleUpdateResult{}, err + } + memberCountsBefore, err := r.memberCounts(ctx, gramOrgID) + if err != nil { + return roleUpdateResult{}, err + } + existingRole, err := r.roleViewFromLocalRole(ctx, gramOrgID, currentRole, memberCountsBefore[currentRole.Slug]) + if err != nil { + return roleUpdateResult{}, err + } + + sysRole := isSystemRole(currentRole.Slug) + if sysRole && (payload.Name != nil || payload.Description != nil || payload.Grants != nil) { + return roleUpdateResult{}, oops.E(oops.CodeBadRequest, nil, "system role properties cannot be updated, only member assignment is allowed").Log(ctx, r.logger) + } + if sysRole && payload.MemberIds == nil { + return roleUpdateResult{}, oops.E(oops.CodeBadRequest, nil, "system role update requires member_ids").Log(ctx, r.logger) + } + if payload.Name != nil { + if _, err := slugify(*payload.Name); err != nil { + return roleUpdateResult{}, err + } + } + + updatedRole := currentRole + if !sysRole { + wRole, err := r.roles.UpdateRole(ctx, workosOrgID, currentRole.Slug, workos.UpdateRoleOpts{ + Name: payload.Name, + Description: payload.Description, + }) + if err != nil { + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "update role in workos").Log(ctx, r.logger) + } + updatedRole = localRoleFromWorkOS(*wRole) + if err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, *wRole)); err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + _ = oops.E(oops.CodeUnexpected, err, "upsert local role record after workos write").Log(ctx, r.logger) + } + + if payload.Grants != nil { + if err := authz.SyncGrants(ctx, r.logger, r.db, gramOrgID, currentRole.Slug, roleGrantPayloads(payload.Grants)); err != nil { + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "sync grants for updated role").Log(ctx, r.logger) + } + } + } + + if payload.MemberIds != nil { + if _, err := r.assignMembersToRole(ctx, gramOrgID, currentRole.Slug, payload.MemberIds); err != nil { + return roleUpdateResult{}, err + } + } + + memberCounts, err := r.memberCounts(ctx, gramOrgID) + if err != nil { + return roleUpdateResult{}, err + } + updatedRoleView, err := r.roleViewFromLocalRole(ctx, gramOrgID, updatedRole, memberCounts[updatedRole.Slug]) + if err != nil { + return roleUpdateResult{}, err + } + + return roleUpdateResult{Before: existingRole, After: updatedRoleView, Role: updatedRole}, nil +} + +// DeleteRole deletes a custom role after moving assigned members to the default member role. +// Side effects: reads local records; writes WorkOS membership reassignments before deleting the WorkOS role; upserts local assignment records, marks the local role deleted, invalidates role caches, and deletes local role grants. +func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, roleID string) (localRole, error) { + currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, roleID) + if err != nil { + return localRole{}, err + } + if isSystemRole(currentRole.Slug) { + return localRole{}, oops.E(oops.CodeBadRequest, nil, "system roles cannot be deleted").Log(ctx, r.logger) + } + + rows, err := repo.New(r.db).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) + if err != nil { + return localRole{}, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + + reassigned := false + for _, row := range rows { + if row.RoleSlug != currentRole.Slug { + continue + } + membershipID := conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) + if _, err := r.roles.UpdateMemberRole(ctx, membershipID, authz.SystemRoleMember); err != nil { + if reassigned { + r.authz.InvalidateAllRoleCaches(ctx, gramOrgID) + } + return localRole{}, oops.E(oops.CodeUnexpected, err, "reassign member to default role").Log(ctx, r.logger) + } + if row.WorkosUserID != "" { + if err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, row.WorkosUserID, authz.SystemRoleMember, "", membershipID)); err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + _ = oops.E(oops.CodeUnexpected, err, "upsert local role assignment record after workos write").Log(ctx, r.logger) + } + } + reassigned = true + } + if reassigned { + r.authz.InvalidateAllRoleCaches(ctx, gramOrgID) + } + + if err := r.roles.DeleteRole(ctx, workosOrgID, currentRole.Slug); err != nil { + return localRole{}, oops.E(oops.CodeUnexpected, err, "delete role in workos").Log(ctx, r.logger) + } + if _, err := repo.New(r.db).MarkOrganizationRoleDeleted(ctx, repo.MarkOrganizationRoleDeletedParams{ + OrganizationID: gramOrgID, + WorkosSlug: currentRole.Slug, + WorkosDeletedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }); err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + _ = oops.E(oops.CodeUnexpected, err, "mark local role record deleted after workos write").Log(ctx, r.logger) + } + + if _, err := repo.New(r.db).DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ + OrganizationID: gramOrgID, + PrincipalUrn: urn.NewPrincipal(urn.PrincipalTypeRole, currentRole.Slug), + }); err != nil { + return localRole{}, oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, r.logger) + } + + return currentRole, nil +} + +type memberRoleUpdateContext struct { + RoleSlug string + MembershipID string + WorkosUserID string + UserID string + Before *gen.AccessMember + After *gen.AccessMember +} + +// UpdateMemberRole changes one member's role assignment. +// Side effects: reads local user, role, and local assignment records; writes the WorkOS membership first, upserts the local assignment record, and invalidates that member's role cache. +func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, roleID string) (memberRoleUpdateContext, error) { + role, err := r.getLocalRoleByID(ctx, gramOrgID, roleID) + if err != nil { + return memberRoleUpdateContext{}, err + } + + connectedUser, err := connectedUser(ctx, r.db, gramOrgID, userID) + switch { + case errors.Is(err, errConnectedUserNotFound): + return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member has not joined this organization").Log(ctx, r.logger) + case err != nil: + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "load connected user").Log(ctx, r.logger) + } + if !connectedUser.WorkosID.Valid || connectedUser.WorkosID.String == "" { + return memberRoleUpdateContext{}, oops.E(oops.CodeBadRequest, nil, "member is not linked to WorkOS").Log(ctx, r.logger) + } + + roleRows, err := repo.New(r.db).ListActiveOrganizationRoles(ctx, gramOrgID) + if err != nil { + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "list roles").Log(ctx, r.logger) + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + roleIDBySlug := make(map[string]string, len(roleRows)) + for _, row := range roleRows { + roleIDBySlug[row.WorkosSlug] = row.ID.String() + } + + assignmentRows, err := repo.New(r.db).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) + if err != nil { + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + var existing localRoleAssignment + for _, row := range assignmentRows { + if row.WorkosUserID == connectedUser.WorkosID.String { + existing = localRoleAssignment{ + UserID: conv.FromPGTextOrEmpty[string](row.UserID), + WorkosUserID: row.WorkosUserID, + MembershipID: conv.FromPGTextOrEmpty[string](row.WorkosMembershipID), + RoleSlug: row.RoleSlug, + CreatedAt: conv.FromPGTimestamptz(row.CreatedAt), + } + break + } + } + if existing.MembershipID == "" { + return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member not found").Log(ctx, r.logger) + } + + updatedMember, err := r.roles.UpdateMemberRole(ctx, existing.MembershipID, role.Slug) + if err != nil { + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "update member role in workos").Log(ctx, r.logger) + } + if updatedMember.UserID != "" && role.Slug != "" { + if err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, updatedMember.UserID, role.Slug, connectedUser.ID, updatedMember.ID)); err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + _ = oops.E(oops.CodeUnexpected, err, "upsert local role assignment record after workos write").Log(ctx, r.logger) + } + } + r.authz.InvalidateRoleCache(ctx, userID, gramOrgID) + + memberName := conv.Default(connectedUser.DisplayName, connectedUser.Email) + return memberRoleUpdateContext{ + RoleSlug: role.Slug, + MembershipID: existing.MembershipID, + WorkosUserID: connectedUser.WorkosID.String, + UserID: connectedUser.ID, + Before: &gen.AccessMember{ + ID: connectedUser.ID, + Name: memberName, + Email: connectedUser.Email, + PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), + RoleID: roleIDBySlug[existing.RoleSlug], + JoinedAt: existing.CreatedAt, + }, + After: &gen.AccessMember{ + ID: connectedUser.ID, + Name: memberName, + Email: connectedUser.Email, + PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), + RoleID: roleID, + JoinedAt: existing.CreatedAt, + }, + }, nil +} + +// MemberRoleSlugs returns local role slugs assigned to a WorkOS user inside an organization. +// Side effects: reads Postgres local assignment records; does not call WorkOS. +func (r *RoleManager) MemberRoleSlugs(ctx context.Context, gramOrgID, workosUserID string) ([]string, error) { + if workosUserID == "" { + return nil, nil + } + + roleSlugs, err := repo.New(r.db).ListMemberRoleSlugsByWorkosUser(ctx, repo.ListMemberRoleSlugsByWorkosUserParams{ + OrganizationID: gramOrgID, + WorkosUserID: workosUserID, + }) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "list member roles").Log(ctx, r.logger) + } + + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + return roleSlugs, nil +} + +// getLocalRoleByID loads one local role record by Gram role ID. +// Side effects: reads Postgres local role records; does not call WorkOS. +func (r *RoleManager) getLocalRoleByID(ctx context.Context, gramOrgID, id string) (localRole, error) { + roleID, err := uuid.Parse(id) + if err != nil { + return localRole{}, oops.E(oops.CodeBadRequest, err, "invalid role ID").Log(ctx, r.logger) + } + + row, err := repo.New(r.db).GetOrganizationRoleByID(ctx, repo.GetOrganizationRoleByIDParams{ + ID: roleID, + OrganizationID: gramOrgID, + }) + switch { + case errors.Is(err, pgx.ErrNoRows): + return localRole{}, oops.E(oops.CodeNotFound, ErrRoleNotFound, "role not found").Log(ctx, r.logger) + case err != nil: + return localRole{}, oops.E(oops.CodeUnexpected, err, "get role").Log(ctx, r.logger) + } + + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + return localRoleFromRoleRow(row), nil +} + +type memberAssignmentTarget struct { + UserID string + WorkosUserID string + MembershipID string +} + +// memberAssignmentTargets resolves Gram user IDs to WorkOS membership IDs using local user and local assignment records. +func (r *RoleManager) memberAssignmentTargets(ctx context.Context, gramOrgID string, memberIDs []string) ([]memberAssignmentTarget, error) { + if len(memberIDs) == 0 { + return nil, nil + } + + users, err := usersrepo.New(r.db).GetUsersByIDs(ctx, memberIDs) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "resolve users by ids").Log(ctx, r.logger) + } + workosByGramID := make(map[string]string, len(users)) + for _, user := range users { + if user.WorkosID.Valid && user.WorkosID.String != "" { + workosByGramID[user.ID] = user.WorkosID.String + } + } + + assignmentRows, err := repo.New(r.db).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + membershipByWorkosID := make(map[string]string, len(assignmentRows)) + for _, row := range assignmentRows { + membershipByWorkosID[row.WorkosUserID] = conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) + } + + targets := make([]memberAssignmentTarget, 0, len(memberIDs)) + for _, gramID := range memberIDs { + workosID, ok := workosByGramID[gramID] + if !ok { + continue + } + membershipID, ok := membershipByWorkosID[workosID] + if !ok { + continue + } + targets = append(targets, memberAssignmentTarget{ + UserID: gramID, + WorkosUserID: workosID, + MembershipID: membershipID, + }) + } + + return targets, nil +} + +// assignMembersToRole moves each requested member to the given WorkOS role and mirrors the result locally. +// Side effects: reads local users and assignments, writes WorkOS memberships, upserts local assignment records, and invalidates org role caches when any member is assigned. +func (r *RoleManager) assignMembersToRole(ctx context.Context, gramOrgID, roleSlug string, memberIDs []string) (int, error) { + targets, err := r.memberAssignmentTargets(ctx, gramOrgID, memberIDs) + if err != nil { + return 0, err + } + + assignedCount := 0 + for _, target := range targets { + if _, err := r.roles.UpdateMemberRole(ctx, target.MembershipID, roleSlug); err != nil { + return 0, oops.E(oops.CodeUnexpected, err, "assign members to role").Log(ctx, r.logger) + } + if target.WorkosUserID != "" && roleSlug != "" { + if err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, target.WorkosUserID, roleSlug, target.UserID, target.MembershipID)); err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + _ = oops.E(oops.CodeUnexpected, err, "upsert local role assignment record after workos write").Log(ctx, r.logger) + } + } + assignedCount++ + } + if assignedCount > 0 { + r.authz.InvalidateAllRoleCaches(ctx, gramOrgID) + } + + return assignedCount, nil +} + +// memberCounts returns the number of locally connected members per role slug. +// Side effects: reads Postgres local assignment records; does not call WorkOS. +func (r *RoleManager) memberCounts(ctx context.Context, gramOrgID string) (map[string]int, error) { + rows, err := repo.New(r.db).CountMembersByRoleForOrg(ctx, gramOrgID) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "count local members by role").Log(ctx, r.logger) + } + + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + counts := make(map[string]int, len(rows)) + for _, row := range rows { + counts[row.RoleSlug] = int(row.MemberCount) + } + return counts, nil +} + +// roleViewFromLocalRole converts a local role record into the public API role view and attaches local grants. +// Side effects: reads Postgres grants; does not call WorkOS. +func (r *RoleManager) roleViewFromLocalRole(ctx context.Context, organizationID string, role localRole, memberCount int) (*gen.Role, error) { + grants, err := authz.GrantsForRole(ctx, r.logger, r.db, organizationID, role.Slug) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "load role grants").Log(ctx, r.logger) + } + genGrants := make([]*gen.RoleGrant, 0, len(grants)) + for _, g := range grants { + genGrants = append(genGrants, scopedGrantToGenRoleGrant(g)) + } + + return &gen.Role{ + ID: role.ID, + Name: role.Name, + Description: role.Description, + IsSystem: isSystemRole(role.Slug), + Grants: genGrants, + MemberCount: memberCount, + CreatedAt: conv.Default(role.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), + UpdatedAt: conv.Default(role.UpdatedAt, time.Time{}.UTC().Format(time.RFC3339)), + }, nil +} + +// localRoleFromActiveRow converts a sqlc active-role row into the manager's internal local role record shape. +// Side effects: none. +func localRoleFromActiveRow(row repo.ListActiveOrganizationRolesRow) localRole { + return localRole{ + ID: row.ID.String(), + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + } +} + +// localRoleFromRoleRow converts a sqlc role lookup row into the manager's internal local role record shape. +// Side effects: none. +func localRoleFromRoleRow(row repo.GetOrganizationRoleByIDRow) localRole { + return localRole{ + ID: row.ID.String(), + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + } +} + +// localRoleFromWorkOS converts a WorkOS role response into the manager's internal local role record shape. +// Side effects: none. +func localRoleFromWorkOS(role workos.Role) localRole { + return localRole{ + ID: role.ID, + Name: role.Name, + Slug: role.Slug, + Description: role.Description, + CreatedAt: role.CreatedAt, + UpdatedAt: role.UpdatedAt, + } +} + +// organizationRoleParams builds the SQL parameters for storing a local role record from a WorkOS write response. +// Side effects: reads the clock for updated_at and possibly created_at fallback. +func organizationRoleParams(gramOrgID string, role workos.Role) repo.UpsertOrganizationRoleParams { + return repo.UpsertOrganizationRoleParams{ + OrganizationID: gramOrgID, + WorkosSlug: role.Slug, + WorkosName: role.Name, + WorkosDescription: conv.ToPGTextEmpty(role.Description), + WorkosCreatedAt: conv.ToPGTimestamptz(workosTimeOrNow(role.CreatedAt)), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosLastEventID: conv.ToPGTextEmpty(""), + } +} + +// replaceRoleAssignmentParams builds SQL parameters for storing the authoritative local role assignment after a WorkOS write. +// Side effects: reads the clock for updated_at. +func replaceRoleAssignmentParams(gramOrgID, workosUserID, roleSlug, userID, membershipID string) repo.ReplaceOrganizationRoleAssignmentParams { + return repo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: gramOrgID, + WorkosUserID: workosUserID, + WorkosRoleSlug: roleSlug, + UserID: conv.ToPGTextEmpty(userID), + WorkosMembershipID: conv.ToPGTextEmpty(membershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosLastEventID: conv.ToPGTextEmpty(""), + } +} + +// workosTimeOrNow parses a WorkOS RFC3339 timestamp or returns the current UTC time when WorkOS omits or malforms it. +// Side effects: reads the clock only when a fallback is needed. +func workosTimeOrNow(value string) time.Time { + if value == "" { + return time.Now().UTC() + } + t, err := time.Parse(time.RFC3339, value) + if err != nil { + return time.Now().UTC() + } + return t.UTC() +} + +// slugify validates a role name and turns it into Gram's WorkOS role slug format. +// Side effects: none. +func slugify(name string) (string, error) { + slug := conv.ToSlug(strings.ReplaceAll(name, "_", " ")) + if slug == "" { + return "", oops.E(oops.CodeBadRequest, nil, "role name must contain at least one letter or digit") + } + if !validRoleNamePattern.MatchString(name) { + return "", oops.E(oops.CodeBadRequest, nil, "role name contains invalid characters") + } + if !strings.HasPrefix(slug, "org-") { + slug = "org-" + slug + } + + return slug, nil +} diff --git a/server/internal/access/role_manager_test.go b/server/internal/access/role_manager_test.go new file mode 100644 index 0000000000..012426e88b --- /dev/null +++ b/server/internal/access/role_manager_test.go @@ -0,0 +1,85 @@ +package access + +import ( + "testing" + + mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" + "github.com/stretchr/testify/require" + + "github.com/speakeasy-api/gram/server/internal/contextvalues" +) + +func TestRoleManager_ListRoles(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + adminID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + customID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) + + roles, err := ti.service.roleMgr.ListRoles(ctx, authCtx.ActiveOrganizationID) + require.NoError(t, err) + require.Len(t, roles.Roles, 2) + + bySlug := map[string]string{} + for _, role := range roles.Roles { + if role.Name == "Admin" { + bySlug["admin"] = role.ID + } + if role.Name == "Custom Builder" { + bySlug["custom-builder"] = role.ID + } + } + require.Equal(t, adminID, bySlug["admin"]) + require.Equal(t, customID, bySlug["custom-builder"]) +} + +func TestRoleManager_GetRoleByID(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + customID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) + + role, err := ti.service.roleMgr.GetRoleByID(ctx, authCtx.ActiveOrganizationID, customID) + require.NoError(t, err) + require.Equal(t, customID, role.ID) + require.Equal(t, "Custom Builder", role.Name) + + _, err = ti.service.roleMgr.GetRoleByID(ctx, authCtx.ActiveOrganizationID, "not-a-uuid") + require.Error(t, err) +} + +func TestRoleManager_MembersAndCounts(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "u1@example.com", "User 1", "user_1", "membership_1") + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "u2@example.com", "User 2", "user_2", "membership_2") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_3", "user_3", "custom-builder")) + + manager := ti.service.roleMgr + members, err := manager.ListMembers(ctx, authCtx.ActiveOrganizationID) + require.NoError(t, err) + require.Len(t, members.Members, 2) + + slugs, err := manager.MemberRoleSlugs(ctx, authCtx.ActiveOrganizationID, "user_2") + require.NoError(t, err) + require.Equal(t, []string{"custom-builder"}, slugs) + + counts, err := manager.memberCounts(ctx, authCtx.ActiveOrganizationID) + require.NoError(t, err) + require.Equal(t, 1, counts["admin"]) + require.Equal(t, 1, counts["custom-builder"]) +} diff --git a/server/internal/access/setup_internal_test.go b/server/internal/access/setup_internal_test.go index 51f01742e6..a1ba2d34cb 100644 --- a/server/internal/access/setup_internal_test.go +++ b/server/internal/access/setup_internal_test.go @@ -30,7 +30,7 @@ func newInternalTestService(t *testing.T) (context.Context, *Service, *pgxpool.P conn, err := res.CloneTestDatabase(t, "testdb") require.NoError(t, err) - return ctx, &Service{tracer: nil, logger: logger, db: conn, chConn: nil, auth: nil, authz: nil, roles: nil, featureCache: nil}, conn + return ctx, &Service{tracer: nil, logger: logger, db: conn, chConn: nil, auth: nil, authz: nil, roleMgr: nil, featureCache: nil}, conn } func seedInternalOrganization(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID string) { diff --git a/server/internal/access/setup_test.go b/server/internal/access/setup_test.go index 627d875f02..8f77a1f6a7 100644 --- a/server/internal/access/setup_test.go +++ b/server/internal/access/setup_test.go @@ -5,6 +5,7 @@ import ( "log" "os" "testing" + "time" "github.com/ClickHouse/clickhouse-go/v2" "github.com/jackc/pgx/v5/pgxpool" @@ -94,7 +95,8 @@ func newTestAccessService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := NewService(logger, tracerProvider, conn, chConn, sessionManager, roles, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), noopFeatureCacheWriter{}, auditLogger) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + svc := NewService(logger, tracerProvider, conn, chConn, sessionManager, NewRoleManager(logger, conn, roles, authzEngine), authzEngine, noopFeatureCacheWriter{}, auditLogger) return ctx, &testInstance{ service: svc, @@ -143,6 +145,60 @@ func listPrincipalGrants(t *testing.T, ctx context.Context, conn *pgxpool.Pool, return grants } +func seedRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID string, role workos.Role) string { + t.Helper() + + createdAt, err := time.Parse(time.RFC3339, role.CreatedAt) + require.NoError(t, err) + updatedAt, err := time.Parse(time.RFC3339, role.UpdatedAt) + require.NoError(t, err) + + err = accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + OrganizationID: organizationID, + WorkosSlug: role.Slug, + WorkosName: role.Name, + WorkosDescription: conv.ToPGTextEmpty(role.Description), + WorkosCreatedAt: conv.ToPGTimestamptz(createdAt), + WorkosUpdatedAt: conv.ToPGTimestamptz(updatedAt), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + + row, err := accessrepo.New(conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ + OrganizationID: organizationID, + WorkosSlug: role.Slug, + }) + require.NoError(t, err) + + return row.ID.String() +} + +func seedRoleAssignment(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID, userID string, member workos.Member) { + t.Helper() + + updatedAt := time.Now().UTC() + if member.UpdatedAt != "" { + parsed, err := time.Parse(time.RFC3339, member.UpdatedAt) + require.NoError(t, err) + updatedAt = parsed + } else if member.CreatedAt != "" { + parsed, err := time.Parse(time.RFC3339, member.CreatedAt) + require.NoError(t, err) + updatedAt = parsed + } + + err := accessrepo.New(conn).ReplaceOrganizationRoleAssignment(ctx, accessrepo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: organizationID, + WorkosUserID: member.UserID, + WorkosRoleSlug: member.RoleSlug, + UserID: conv.ToPGTextEmpty(userID), + WorkosMembershipID: conv.ToPGTextEmpty(member.ID), + WorkosUpdatedAt: conv.ToPGTimestamptz(updatedAt), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) +} + // seedDisconnectedUser creates a user in the users table with a workos_id but // does NOT insert into organization_user_relationships, simulating a WorkOS // user who hasn't been connected to the Gram org. diff --git a/server/internal/access/updatememberrole_test.go b/server/internal/access/updatememberrole_test.go index 7b3f1df054..1b8821c266 100644 --- a/server/internal/access/updatememberrole_test.go +++ b/server/internal/access/updatememberrole_test.go @@ -23,13 +23,10 @@ func TestService_UpdateMemberRole(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -37,28 +34,23 @@ func TestService_UpdateMemberRole(t *testing.T) { RoleSlug: "custom-builder", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - }, nil).Once() - seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_builder"}) + member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.NoError(t, err) require.Equal(t, "local_user_1", member.ID) require.Equal(t, "Ada Lovelace", member.Name) require.Equal(t, "ada@example.com", member.Email) - require.Equal(t, "role_builder", member.RoleID) + require.Equal(t, builderID, member.RoleID) require.Nil(t, member.PhotoURL) - require.Equal(t, mockMembershipTimestamp, member.JoinedAt) + require.NotEmpty(t, member.JoinedAt) } func TestService_UpdateMemberRole_RoleNotFound(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{}, nil).Once() - _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_missing"}) + _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "00000000-0000-0000-0000-000000000001"}) require.Error(t, err) require.Contains(t, err.Error(), "role not found") } @@ -70,12 +62,10 @@ func TestService_UpdateMemberRole_MemberNotFound(t *testing.T) { authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "user_missing", RoleID: "role_builder"}) + _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "user_missing", RoleID: builderID}) require.Error(t, err) require.Contains(t, err.Error(), "member has not joined this organization") } @@ -87,13 +77,10 @@ func TestService_UpdateMemberRole_WorkOSMembershipNotFound(t *testing.T) { authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_builder"}) + _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.Error(t, err) require.Contains(t, err.Error(), "member not found") } @@ -105,16 +92,13 @@ func TestService_UpdateMemberRole_WorkOSFailure(t *testing.T) { authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() - ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Once() + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) + ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Once() - _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_builder"}) + _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.Error(t, err) require.Contains(t, err.Error(), "update member role in workos") } @@ -129,13 +113,10 @@ func TestService_UpdateMemberRole_AuditLog(t *testing.T) { beforeCount, err := audittest.AuditLogCountByAction(ctx, ti.conn, audit.ActionAccessMemberRoleUpdate) require.NoError(t, err) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() + adminID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -143,12 +124,8 @@ func TestService_UpdateMemberRole_AuditLog(t *testing.T) { RoleSlug: "custom-builder", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - }, nil).Once() - seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_builder"}) + member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.NoError(t, err) require.NotNil(t, member) @@ -165,8 +142,8 @@ func TestService_UpdateMemberRole_AuditLog(t *testing.T) { require.NoError(t, err) afterSnapshot, err := audittest.DecodeAuditData(record.AfterSnapshot) require.NoError(t, err) - require.Equal(t, "role_admin", beforeSnapshot["RoleID"]) - require.Equal(t, "role_builder", afterSnapshot["RoleID"]) + require.Equal(t, adminID, beforeSnapshot["RoleID"]) + require.Equal(t, builderID, afterSnapshot["RoleID"]) afterCount, err := audittest.AuditLogCountByAction(ctx, ti.conn, audit.ActionAccessMemberRoleUpdate) require.NoError(t, err) diff --git a/server/internal/access/updaterole_test.go b/server/internal/access/updaterole_test.go index a8f71aa7b2..9bd47b78e2 100644 --- a/server/internal/access/updaterole_test.go +++ b/server/internal/access/updaterole_test.go @@ -28,9 +28,8 @@ func TestService_UpdateRole(t *testing.T) { name := "Platform Builder" description := "Updated description" - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{ Name: &name, Description: &description, @@ -42,11 +41,6 @@ func TestService_UpdateRole(t *testing.T) { CreatedAt: mockRoleTimestamp, UpdatedAt: mockRoleTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "member"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "member"), - mockMember(mockidp.MockOrgID, "membership_3", "user_3", "custom-builder"), - }, nil).Once() ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -61,20 +55,17 @@ func TestService_UpdateRole(t *testing.T) { RoleSlug: "custom-builder", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_3", "user_3", "custom-builder"), - // user_workos_only has never logged into Gram — should not be counted - mockMember(mockidp.MockOrgID, "membership_workos_only", "user_workos_only", "custom-builder"), - }, nil).Once() seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", "user3@test.com", "User 3", "user_3", "membership_3") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember(mockidp.MockOrgID, "membership_2", "user_2", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", mockMember(mockidp.MockOrgID, "membership_3", "user_3", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_workos_only", "user_workos_only", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-old") role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_custom", + ID: roleID, Name: &name, Description: &description, Grants: []*gen.RoleGrant{ @@ -106,13 +97,8 @@ func TestService_UpdateRole_SystemRole_MemberAssignment(t *testing.T) { require.NotNil(t, authCtx) // admin and member are system roles — WorkOS UpdateRole must NOT be called. - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_admin", "Admin", "admin", "Full access"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "member"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "member"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "admin").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -120,20 +106,18 @@ func TestService_UpdateRole_SystemRole_MemberAssignment(t *testing.T) { RoleSlug: "admin", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "member"), - }, nil).Once() seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember(mockidp.MockOrgID, "membership_2", "user_2", authz.SystemRoleMember)) role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_admin", + ID: roleID, MemberIds: []string{"local_user_1"}, }) require.NoError(t, err) - require.Equal(t, "role_admin", role.ID) + require.Equal(t, roleID, role.ID) require.True(t, role.IsSystem) require.Equal(t, 1, role.MemberCount) @@ -145,14 +129,14 @@ func TestService_UpdateRole_SystemRole_RejectsPropertyChanges(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_member", "Member", "member", "Default role"), - }, nil) + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) name := "Custom Name" _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_member", + ID: roleID, Name: &name, }) require.Error(t, err) @@ -160,14 +144,14 @@ func TestService_UpdateRole_SystemRole_RejectsPropertyChanges(t *testing.T) { description := "Custom description" _, err = ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_member", + ID: roleID, Description: &description, }) require.Error(t, err) require.Contains(t, err.Error(), "system role properties cannot be updated") _, err = ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_member", + ID: roleID, Grants: []*gen.RoleGrant{{Scope: string(authz.ScopeProjectRead)}}, }) require.Error(t, err) @@ -178,13 +162,13 @@ func TestService_UpdateRole_SystemRole_RejectsNoopUpdate(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_admin", "Admin", "admin", "Full access"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_admin", + ID: roleID, }) require.Error(t, err) require.Contains(t, err.Error(), "system role update requires member_ids") @@ -200,12 +184,8 @@ func TestService_UpdateRole_SystemRole_AuditLog(t *testing.T) { beforeCount, err := audittest.AuditLogCountByAction(ctx, ti.conn, audit.ActionAccessRoleUpdate) require.NoError(t, err) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_admin", "Admin", "admin", "Full access"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "member"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "admin").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -213,14 +193,12 @@ func TestService_UpdateRole_SystemRole_AuditLog(t *testing.T) { RoleSlug: "admin", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", authz.SystemRoleMember)) role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_admin", + ID: roleID, MemberIds: []string{"local_user_1"}, }) require.NoError(t, err) @@ -240,9 +218,8 @@ func TestService_UpdateRole_NotFound(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{}, nil).Once() - _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: "role_missing"}) + _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: "00000000-0000-0000-0000-000000000001"}) require.Error(t, err) require.Contains(t, err.Error(), "role not found") } @@ -251,13 +228,12 @@ func TestService_UpdateRole_WorkOSUpdateFailure(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{}).Return((*thirdpartyworkos.Role)(nil), errors.New("workos unavailable")).Once() - _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: "role_custom"}) + _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: roleID}) require.Error(t, err) require.Contains(t, err.Error(), "update role in workos") } @@ -274,12 +250,7 @@ func TestService_UpdateRole_AuditLog(t *testing.T) { name := "Audit Builder" description := "After description" - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Before Builder", "custom-builder", "Before description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Before Builder", "custom-builder", "Before description")) ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{ Name: &name, Description: &description, @@ -291,17 +262,15 @@ func TestService_UpdateRole_AuditLog(t *testing.T) { CreatedAt: mockRoleTimestamp, UpdatedAt: mockRoleTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - }, nil).Once() seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-old") updated, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_custom", + ID: roleID, Name: &name, Description: &description, Grants: []*gen.RoleGrant{{ diff --git a/server/internal/attr/conventions.go b/server/internal/attr/conventions.go index 5f6ff246d0..9324bcba7a 100644 --- a/server/internal/attr/conventions.go +++ b/server/internal/attr/conventions.go @@ -196,6 +196,8 @@ const ( AccessMemberIDKey = attribute.Key("gram.access.member.id") AccessRoleIDKey = attribute.Key("gram.access.role.id") AccessRoleSlugKey = attribute.Key("gram.access.role.slug") + AccessRoleSourceKey = attribute.Key("gram.access.role.source") + AccessRoleDBWriteFailedKey = attribute.Key("gram.access.role.db_write_failed_post_workos") OrganizationIDKey = attribute.Key("gram.org.id") OrganizationSlugKey = attribute.Key("gram.org.slug") WorkOSOrganizationIDKey = attribute.Key("gram.workos.organization_id") @@ -843,6 +845,16 @@ func SlogAccessRoleID(v string) slog.Attr { return slog.String(string(Acces func AccessRoleSlug(v string) attribute.KeyValue { return AccessRoleSlugKey.String(v) } func SlogAccessRoleSlug(v string) slog.Attr { return slog.String(string(AccessRoleSlugKey), v) } +func AccessRoleSource(v string) attribute.KeyValue { return AccessRoleSourceKey.String(v) } +func SlogAccessRoleSource(v string) slog.Attr { return slog.String(string(AccessRoleSourceKey), v) } + +func AccessRoleDBWriteFailed(v bool) attribute.KeyValue { + return AccessRoleDBWriteFailedKey.Bool(v) +} +func SlogAccessRoleDBWriteFailed(v bool) slog.Attr { + return slog.Bool(string(AccessRoleDBWriteFailedKey), v) +} + func OrganizationID(v string) attribute.KeyValue { return OrganizationIDKey.String(v) } func SlogOrganizationID(v string) slog.Attr { return slog.String(string(OrganizationIDKey), v) } diff --git a/server/internal/conv/from.go b/server/internal/conv/from.go index 060d232b40..a27bb9348a 100644 --- a/server/internal/conv/from.go +++ b/server/internal/conv/from.go @@ -163,6 +163,16 @@ func PtrToPGTimestamptz(t *time.Time) pgtype.Timestamptz { return pgtype.Timestamptz{Time: *t, Valid: true, InfinityModifier: pgtype.Finite} } +// FromPGTimestamptz converts a pgtype.Timestamptz to an RFC3339 UTC string. If +// the value is not valid, it returns an empty string. +func FromPGTimestamptz(t pgtype.Timestamptz) string { + if !t.Valid { + return "" + } + + return t.Time.UTC().Format(time.RFC3339) +} + // PtrToPGInterval converts a *time.Duration to a pgtype.Interval. If the // pointer is nil, the result has Valid set to false (which becomes SQL NULL). func PtrToPGInterval(d *time.Duration) pgtype.Interval { diff --git a/server/internal/conv/from_test.go b/server/internal/conv/from_test.go index d2cd0a18af..8a3a567c31 100644 --- a/server/internal/conv/from_test.go +++ b/server/internal/conv/from_test.go @@ -2,6 +2,7 @@ package conv_test import ( "testing" + "time" "github.com/jackc/pgx/v5/pgtype" "github.com/speakeasy-api/gram/server/internal/conv" @@ -45,6 +46,20 @@ func TestPtrInt32ToInt_Nil(t *testing.T) { require.Nil(t, result) } +func TestFromPGTimestamptz_Valid(t *testing.T) { + t.Parallel() + + input := pgtype.Timestamptz{Time: time.Date(2024, 11, 15, 15, 4, 5, 0, time.FixedZone("test", 2*60*60)), Valid: true} + + require.Equal(t, "2024-11-15T13:04:05Z", conv.FromPGTimestamptz(input)) +} + +func TestFromPGTimestamptz_Invalid(t *testing.T) { + t.Parallel() + + require.Empty(t, conv.FromPGTimestamptz(pgtype.Timestamptz{})) +} + func TestURLToSlug_HostAndPath(t *testing.T) { t.Parallel() From d98f1506f6c6bf643fc83d5ba78c42996b419624 Mon Sep 17 00:00:00 2001 From: tgmendes Date: Thu, 14 May 2026 17:22:10 +0100 Subject: [PATCH 02/12] chore: update logic to treat local DB as source of truth --- server/cmd/gram/start.go | 1 - server/cmd/gram/worker.go | 1 - server/internal/access/createrole_test.go | 73 ++-- server/internal/access/deleterole_test.go | 31 +- server/internal/access/listmembers_test.go | 17 +- server/internal/access/listusergrants_test.go | 2 +- server/internal/access/queries.sql | 10 +- server/internal/access/rbac_test.go | 3 +- server/internal/access/repo/queries.sql.go | 28 +- server/internal/access/role_manager.go | 319 +++++++++++++----- server/internal/access/role_manager_test.go | 2 +- server/internal/access/setup_test.go | 5 +- .../internal/access/updatememberrole_test.go | 8 +- server/internal/access/updaterole_test.go | 13 +- server/internal/assets/setup_test.go | 1 - .../internal/assistantmemories/impl_test.go | 9 +- server/internal/assistants/impl_test.go | 3 +- server/internal/auditapi/setup_test.go | 2 +- server/internal/auth/setup_test.go | 4 +- server/internal/authz/context_test.go | 13 +- server/internal/authz/engine.go | 92 +---- server/internal/authz/engine_test.go | 114 ++----- server/internal/authz/integration_test.go | 7 +- server/internal/authz/load_test.go | 5 +- server/internal/collections/setup_test.go | 2 +- server/internal/customdomains/setup_test.go | 2 +- server/internal/deployments/setup_test.go | 2 +- server/internal/environments/setup_test.go | 2 +- server/internal/externalmcp/setup_test.go | 2 +- server/internal/functions/setup_test.go | 2 +- server/internal/hooks/setup_test.go | 2 +- server/internal/keys/setup_test.go | 2 +- server/internal/mcp/handle_get_server_test.go | 3 +- server/internal/mcp/rbac_test.go | 23 +- server/internal/mcp/rpc_tools_list_test.go | 15 +- server/internal/mcp/setup_test.go | 2 +- server/internal/mcpendpoints/setup_test.go | 2 +- server/internal/mcpmetadata/setup_test.go | 2 +- server/internal/mcpservers/setup_test.go | 2 +- server/internal/organizations/setup_test.go | 8 +- server/internal/packages/setup_test.go | 2 +- server/internal/plugins/setup_test.go | 4 +- server/internal/productfeatures/setup_test.go | 2 +- server/internal/projects/setup_test.go | 1 - server/internal/remotemcp/setup_test.go | 2 +- server/internal/resources/setup_test.go | 2 +- server/internal/risk/setup_test.go | 2 +- server/internal/telemetry/setup_test.go | 2 +- server/internal/templates/setup_test.go | 2 +- server/internal/tools/setup_test.go | 2 +- server/internal/toolsets/setup_test.go | 2 +- server/internal/triggers/setup_test.go | 2 +- server/internal/usage/impl_test.go | 9 +- server/internal/usersessions/setup_test.go | 2 +- server/internal/variations/setup_test.go | 2 +- server/internal/xmcp/setup_test.go | 4 +- .../xmcp/tools_call_authz_interceptor_test.go | 3 +- 57 files changed, 463 insertions(+), 418 deletions(-) diff --git a/server/cmd/gram/start.go b/server/cmd/gram/start.go index 7f583a89f1..4d3ffaeb4c 100644 --- a/server/cmd/gram/start.go +++ b/server/cmd/gram/start.go @@ -659,7 +659,6 @@ func newStartCommand() *cli.Command { rbacEnabled, challengeLoggingEnabled, roleClient, - cache.NewRedisCacheAdapter(redisClient), authz.EngineOpts{DevMode: c.String("environment") == "local"}, ) diff --git a/server/cmd/gram/worker.go b/server/cmd/gram/worker.go index a27eead8ea..e206fcb898 100644 --- a/server/cmd/gram/worker.go +++ b/server/cmd/gram/worker.go @@ -486,7 +486,6 @@ func newWorkerCommand() *cli.Command { rbacEnabled, challengeLoggingEnabled, workos.NewStubClient(), - cache.NewRedisCacheAdapter(redisClient), authz.EngineOpts{DevMode: c.String("environment") == "local"}, ) diff --git a/server/internal/access/createrole_test.go b/server/internal/access/createrole_test.go index dbdfcb324b..dc5390826b 100644 --- a/server/internal/access/createrole_test.go +++ b/server/internal/access/createrole_test.go @@ -72,12 +72,17 @@ func TestService_CreateRole(t *testing.T) { }) require.NoError(t, err) require.Equal(t, "Custom Builder", role.Name) + require.NotEmpty(t, role.ID) + require.NotEqual(t, "role_1", role.ID) require.Equal(t, "Can build selected resources", role.Description) require.False(t, role.IsSystem) require.Equal(t, 2, role.MemberCount) - require.Equal(t, mockRoleTimestamp, role.CreatedAt) - require.Equal(t, mockRoleTimestamp, role.UpdatedAt) + require.NotEmpty(t, role.CreatedAt) + require.NotEmpty(t, role.UpdatedAt) require.Len(t, role.Grants, 2) + roundtrip, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: role.ID}) + require.NoError(t, err) + require.Equal(t, role.ID, roundtrip.ID) grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "org-custom-builder")) require.Len(t, grants, 3) @@ -91,17 +96,17 @@ func TestService_CreateRole_WorkOSCreateFailure(t *testing.T) { Name: "Custom Builder", Slug: "org-custom-builder", Description: "Can build selected resources", - }).Return((*thirdpartyworkos.Role)(nil), errors.New("workos unavailable")).Once() + }).Return((*thirdpartyworkos.Role)(nil), errors.New("workos unavailable")).Times(3) - _, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ + role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ Name: "Custom Builder", Description: "Can build selected resources", Grants: []*gen.RoleGrant{ {Scope: string(authz.ScopeProjectRead), Selectors: []*gen.Selector{{ResourceKind: "project", ResourceID: "project-1"}}}, }, }) - require.Error(t, err) - require.Contains(t, err.Error(), "create role in workos") + require.NoError(t, err) + require.Equal(t, "Custom Builder", role.Name) } func TestService_CreateRole_WorkOSConflictFailure(t *testing.T) { @@ -114,15 +119,41 @@ func TestService_CreateRole_WorkOSConflictFailure(t *testing.T) { Description: "Can build selected resources", }).Return((*thirdpartyworkos.Role)(nil), &thirdpartyworkos.APIError{Method: "POST", Path: "/authorization/organizations/org_workos_test/roles", StatusCode: 409, Body: "role already exists"}).Once() - _, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ + role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ Name: "Custom Builder", Description: "Can build selected resources", Grants: []*gen.RoleGrant{ {Scope: string(authz.ScopeProjectRead), Selectors: []*gen.Selector{{ResourceKind: "project", ResourceID: "project-1"}}}, }, }) - require.Error(t, err) - require.Contains(t, err.Error(), "create role in workos") + require.NoError(t, err) + require.Equal(t, "Custom Builder", role.Name) +} + +func TestService_CreateRole_WorkOSConflictUsesLocalRole(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_1", "Custom Builder", "org-custom-builder", "Can build selected resources")) + ti.roles.On("CreateRole", mock.Anything, mockidp.MockOrgID, thirdpartyworkos.CreateRoleOpts{ + Name: "Custom Builder", + Slug: "org-custom-builder", + Description: "Can build selected resources", + }).Return((*thirdpartyworkos.Role)(nil), &thirdpartyworkos.APIError{Method: "POST", Path: "/authorization/organizations/org_workos_test/roles", StatusCode: 409, Body: "role already exists"}).Once() + + role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ + Name: "Custom Builder", + Description: "Can build selected resources", + Grants: []*gen.RoleGrant{ + {Scope: string(authz.ScopeProjectRead), Selectors: []*gen.Selector{{ResourceKind: "project", ResourceID: "project-1"}}}, + }, + }) + require.NoError(t, err) + require.Equal(t, roleID, role.ID) + require.Equal(t, "Custom Builder", role.Name) } func TestService_CreateRole_RejectsEmptySlug(t *testing.T) { @@ -186,7 +217,7 @@ func TestService_CreateRole_AuditLog(t *testing.T) { require.Equal(t, beforeCount+1, afterCount) } -func TestService_CreateRole_GrantSyncFailureDoesNotAssignMembers(t *testing.T) { +func TestService_CreateRole_LocalRoleWriteFailureDoesNotAssignMembers(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -194,35 +225,21 @@ func TestService_CreateRole_GrantSyncFailureDoesNotAssignMembers(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("CreateRole", mock.Anything, mockidp.MockOrgID, thirdpartyworkos.CreateRoleOpts{ - Name: "Broken Builder", - Slug: "org-broken-builder", - Description: "Will fail grant sync", - }).Run(func(mock.Arguments) { - ti.conn.Close() - }).Return(&thirdpartyworkos.Role{ - ID: "role_1", - Name: "Broken Builder", - Slug: "org-broken-builder", - Description: "Will fail grant sync", - CreatedAt: mockRoleTimestamp, - UpdatedAt: mockRoleTimestamp, - }, nil).Once() - inspectConn, err := pgxpool.New(ctx, ti.conn.Config().ConnString()) require.NoError(t, err) t.Cleanup(inspectConn.Close) - _, err = ti.service.CreateRole(ctx, &gen.CreateRolePayload{ + ti.conn.Close() + _, err = ti.service.roleMgr.CreateRole(ctx, authCtx.ActiveOrganizationID, mockidp.MockOrgID, &gen.CreateRolePayload{ Name: "Broken Builder", - Description: "Will fail grant sync", + Description: "Will fail local write", Grants: []*gen.RoleGrant{ {Scope: string(authz.ScopeProjectRead), Selectors: []*gen.Selector{{ResourceKind: "project", ResourceID: "project-1"}}}, }, MemberIds: []string{"local_user_1", "local_user_2"}, }) require.Error(t, err) - require.Contains(t, err.Error(), "sync grants for created role") + require.Contains(t, err.Error(), "upsert local role record") grants, err := accessrepo.New(inspectConn).ListPrincipalGrantsByOrg(ctx, accessrepo.ListPrincipalGrantsByOrgParams{ OrganizationID: authCtx.ActiveOrganizationID, diff --git a/server/internal/access/deleterole_test.go b/server/internal/access/deleterole_test.go index 9ec1c05421..b534987530 100644 --- a/server/internal/access/deleterole_test.go +++ b/server/internal/access/deleterole_test.go @@ -71,7 +71,7 @@ func TestService_DeleteRole_ReassignsMembersToDefault(t *testing.T) { require.NoError(t, err) } -func TestService_DeleteRole_ReassignFailureHaltsDelete(t *testing.T) { +func TestService_DeleteRole_ReassignFailureDoesNotHaltDelete(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -80,20 +80,20 @@ func TestService_DeleteRole_ReassignFailureHaltsDelete(t *testing.T) { require.NotNil(t, authCtx) roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) - ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", authz.SystemRoleMember).Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Once() + ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", authz.SystemRoleMember).Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Times(3) + ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) - require.Error(t, err) - require.Contains(t, err.Error(), "reassign member to default role") + require.NoError(t, err) - // Grants must remain since reassignment failed before grant cleanup ran. grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) - require.Len(t, grants, 1) + require.Empty(t, grants) } -func TestService_DeleteRole_PartialReassignFailureStopsLoop(t *testing.T) { +func TestService_DeleteRole_PartialReassignFailureContinuesDelete(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -112,17 +112,15 @@ func TestService_DeleteRole_PartialReassignFailureStopsLoop(t *testing.T) { RoleSlug: authz.SystemRoleMember, CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("UpdateMemberRole", mock.Anything, "membership_2", authz.SystemRoleMember).Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Once() + ti.roles.On("UpdateMemberRole", mock.Anything, "membership_2", authz.SystemRoleMember).Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Times(3) + ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) - require.Error(t, err) - require.Contains(t, err.Error(), "reassign member to default role") + require.NoError(t, err) - // The mock's AssertExpectations (registered in newMockRoleProvider) verifies - // that DeleteRole was never called and the loop stopped at the first failure. grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) - require.Len(t, grants, 1) + require.Empty(t, grants) } func TestService_DeleteRole_NotFound(t *testing.T) { @@ -156,15 +154,14 @@ func TestService_DeleteRole_WorkOSDeleteFailure(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) - ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(errors.New("workos unavailable")).Once() + ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(errors.New("workos unavailable")).Times(3) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) - require.Error(t, err) - require.Contains(t, err.Error(), "delete role in workos") + require.NoError(t, err) grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) - require.Len(t, grants, 1) + require.Empty(t, grants) } func TestService_DeleteRole_AuditLog(t *testing.T) { diff --git a/server/internal/access/listmembers_test.go b/server/internal/access/listmembers_test.go index 89c8eb8d73..619c32c610 100644 --- a/server/internal/access/listmembers_test.go +++ b/server/internal/access/listmembers_test.go @@ -43,7 +43,7 @@ func TestService_ListMembers(t *testing.T) { require.Equal(t, builderID, byID["local_user_2"].RoleID) } -func TestService_ListMembers_ExcludesDisconnectedUsers(t *testing.T) { +func TestService_ListMembers_IncludesWorkOSOnlyUsers(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -61,9 +61,15 @@ func TestService_ListMembers_ExcludesDisconnectedUsers(t *testing.T) { result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Len(t, result.Members, 1, "disconnected user should be excluded") - require.Equal(t, "local_user_1", result.Members[0].ID) - require.Equal(t, "Ada Lovelace", result.Members[0].Name) + require.Len(t, result.Members, 2) + + byID := map[string]*gen.AccessMember{} + for _, member := range result.Members { + byID[member.ID] = member + } + require.Equal(t, "Ada Lovelace", byID["local_user_1"].Name) + require.Equal(t, "user_2", byID["user_2"].ID) + require.Equal(t, "user_2", byID["user_2"].Name) } func TestService_ListMembers_UsesDatabaseOnly(t *testing.T) { @@ -76,5 +82,6 @@ func TestService_ListMembers_UsesDatabaseOnly(t *testing.T) { result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Empty(t, result.Members) + require.Len(t, result.Members, 1) + require.Equal(t, "user_1", result.Members[0].ID) } diff --git a/server/internal/access/listusergrants_test.go b/server/internal/access/listusergrants_test.go index 98f29955bb..68a0c077e9 100644 --- a/server/internal/access/listusergrants_test.go +++ b/server/internal/access/listusergrants_test.go @@ -167,7 +167,7 @@ func TestService_ListGrants_RBACDisabledReturnsFullAccess(t *testing.T) { ctx = contextvalues.SetAuthContext(ctx, authCtx) chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - ti.service.authz = authz.NewEngine(ti.service.logger, ti.conn, chConn, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, ti.roles, nil) + ti.service.authz = authz.NewEngine(ti.service.logger, ti.conn, chConn, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, ti.roles) result, err := ti.service.ListGrants(ctx, &gen.ListGrantsPayload{}) require.NoError(t, err) diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index 8f8885311d..0a564119f7 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -234,7 +234,7 @@ WHERE ora.organization_id = @organization_id AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL GROUP BY role_slug; --- name: ReplaceOrganizationRoleAssignment :exec +-- name: ReplaceOrganizationRoleAssignment :one WITH input_role_urn AS ( SELECT 'role:organization:' || id::text AS role_urn FROM organization_roles @@ -275,9 +275,13 @@ upserted AS ( workos_last_event_id = EXCLUDED.workos_last_event_id, updated_at = clock_timestamp() RETURNING role_urn -) +), +deleted AS ( DELETE FROM organization_role_assignments WHERE organization_role_assignments.organization_id = @organization_id AND organization_role_assignments.workos_user_id = @workos_user_id AND EXISTS (SELECT 1 FROM upserted) - AND organization_role_assignments.role_urn NOT IN (SELECT role_urn FROM upserted); + AND organization_role_assignments.role_urn NOT IN (SELECT role_urn FROM upserted) + RETURNING 1 +) +SELECT COUNT(*)::bigint FROM upserted; diff --git a/server/internal/access/rbac_test.go b/server/internal/access/rbac_test.go index c1948af065..c815b640b7 100644 --- a/server/internal/access/rbac_test.go +++ b/server/internal/access/rbac_test.go @@ -153,7 +153,8 @@ func TestService_CreateRole_AllowsOrgAdminGrant(t *testing.T) { role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{Name: "Allowed", Description: "Allowed"}) require.NoError(t, err) - require.Equal(t, "role_allowed", role.ID) + require.NotEmpty(t, role.ID) + require.NotEqual(t, "role_allowed", role.ID) } func TestService_UpdateRole_ForbiddenWithoutOrgAdminGrant(t *testing.T) { diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index 43786c78ce..ac4a54faf4 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -620,18 +620,18 @@ func (q *Queries) MarkOrganizationRoleDeleted(ctx context.Context, arg MarkOrgan return result.RowsAffected(), nil } -const replaceOrganizationRoleAssignment = `-- name: ReplaceOrganizationRoleAssignment :exec +const replaceOrganizationRoleAssignment = `-- name: ReplaceOrganizationRoleAssignment :one WITH input_role_urn AS ( SELECT 'role:organization:' || id::text AS role_urn FROM organization_roles WHERE organization_roles.organization_id = $1 - AND organization_roles.workos_slug = $3 + AND organization_roles.workos_slug = $2 AND organization_roles.deleted IS FALSE AND organization_roles.workos_deleted IS FALSE UNION ALL SELECT 'role:global:' || id::text AS role_urn FROM global_roles - WHERE global_roles.workos_slug = $3 + WHERE global_roles.workos_slug = $2 AND global_roles.deleted IS FALSE AND global_roles.workos_deleted IS FALSE ), @@ -647,7 +647,7 @@ upserted AS ( ) SELECT $1, - $2, + $3, $4, input_role_urn.role_urn, $5, @@ -661,35 +661,41 @@ upserted AS ( workos_last_event_id = EXCLUDED.workos_last_event_id, updated_at = clock_timestamp() RETURNING role_urn -) +), +deleted AS ( DELETE FROM organization_role_assignments WHERE organization_role_assignments.organization_id = $1 - AND organization_role_assignments.workos_user_id = $2 + AND organization_role_assignments.workos_user_id = $3 AND EXISTS (SELECT 1 FROM upserted) AND organization_role_assignments.role_urn NOT IN (SELECT role_urn FROM upserted) + RETURNING 1 +) +SELECT COUNT(*)::bigint FROM upserted ` type ReplaceOrganizationRoleAssignmentParams struct { OrganizationID string - WorkosUserID string WorkosRoleSlug string + WorkosUserID string UserID pgtype.Text WorkosMembershipID pgtype.Text WorkosUpdatedAt pgtype.Timestamptz WorkosLastEventID pgtype.Text } -func (q *Queries) ReplaceOrganizationRoleAssignment(ctx context.Context, arg ReplaceOrganizationRoleAssignmentParams) error { - _, err := q.db.Exec(ctx, replaceOrganizationRoleAssignment, +func (q *Queries) ReplaceOrganizationRoleAssignment(ctx context.Context, arg ReplaceOrganizationRoleAssignmentParams) (int64, error) { + row := q.db.QueryRow(ctx, replaceOrganizationRoleAssignment, arg.OrganizationID, - arg.WorkosUserID, arg.WorkosRoleSlug, + arg.WorkosUserID, arg.UserID, arg.WorkosMembershipID, arg.WorkosUpdatedAt, arg.WorkosLastEventID, ) - return err + var column_1 int64 + err := row.Scan(&column_1) + return column_1, err } const upsertGlobalRole = `-- name: UpsertGlobalRole :exec diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index 81546edca9..3e7b51eebb 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -3,6 +3,7 @@ package access import ( "context" "errors" + "fmt" "log/slog" "regexp" "strings" @@ -28,6 +29,8 @@ var ErrRoleNotFound = errors.New("role not found") var validRoleNamePattern = regexp.MustCompile(`^[A-Za-z0-9 _-]+$`) +const workOSSyncAttempts = 3 + type RoleProvider interface { CreateRole(ctx context.Context, orgID string, opts workos.CreateRoleOpts) (*workos.Role, error) UpdateRole(ctx context.Context, orgID string, roleSlug string, opts workos.UpdateRoleOpts) (*workos.Role, error) @@ -36,7 +39,7 @@ type RoleProvider interface { GetOrgMembership(ctx context.Context, workOSUserID, workOSOrgID string) (*workos.Member, error) } -// RoleManager owns role reads from local records and role writes through WorkOS. +// RoleManager owns role reads and writes against local records, then syncs successful writes to WorkOS. type RoleManager struct { db *pgxpool.Pool logger *slog.Logger @@ -149,6 +152,14 @@ func (r *RoleManager) ListMembers(ctx context.Context, gramOrgID string) (*gen.L for _, assignment := range assignments { user, ok := localUsers[assignment.UserID] if !ok { + result = append(result, &gen.AccessMember{ + ID: assignment.WorkosUserID, + Name: assignment.WorkosUserID, + Email: "", + PhotoURL: nil, + RoleID: roles[assignment.RoleSlug], + JoinedAt: assignment.CreatedAt, + }) continue } @@ -170,7 +181,8 @@ type roleCreateResult struct { Slug string } -// CreateRole creates a WorkOS role, upserts the local role record, syncs local grants, and optionally assigns members. +// CreateRole creates the local role record, syncs local grants, optionally assigns members, and then best-effort syncs WorkOS. +// Side effects: writes Postgres role/grant/assignment records, invalidates local authz state, and logs WorkOS sync failures after bounded retries. func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID string, payload *gen.CreateRolePayload) (roleCreateResult, error) { roleSlug, err := slugify(payload.Name) if err != nil { @@ -178,42 +190,60 @@ func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID str } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(roleSlug)) - wr, err := r.roles.CreateRole(ctx, workosOrgID, workos.CreateRoleOpts{ + now := time.Now().UTC().Format(time.RFC3339) + if err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ + ID: "", Name: payload.Name, Slug: roleSlug, Description: payload.Description, - }) - var apiErr *workos.APIError - switch { - case errors.As(err, &apiErr) && apiErr.StatusCode == 409: - return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "create role in workos").Log(ctx, r.logger) - case err != nil: - return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "create role in workos").Log(ctx, r.logger) - } - trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleID(wr.ID)) - if err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, *wr)); err != nil { + CreatedAt: now, + UpdatedAt: now, + })); err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - _ = oops.E(oops.CodeUnexpected, err, "upsert local role record after workos write").Log(ctx, r.logger) + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role record").Log(ctx, r.logger) } + createdRole, err := r.getLocalRoleBySlug(ctx, gramOrgID, roleSlug) + if err != nil { + return roleCreateResult{}, err + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleID(createdRole.ID)) - if err := authz.SyncGrants(ctx, r.logger, r.db, gramOrgID, wr.Slug, roleGrantPayloads(payload.Grants)); err != nil { + if err := authz.SyncGrants(ctx, r.logger, r.db, gramOrgID, roleSlug, roleGrantPayloads(payload.Grants)); err != nil { return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "sync grants for created role").Log(ctx, r.logger) } - assignedCount := 0 + r.syncWorkOS(ctx, "create role in workos", func() error { + _, err := r.roles.CreateRole(ctx, workosOrgID, workos.CreateRoleOpts{ + Name: payload.Name, + Slug: roleSlug, + Description: payload.Description, + }) + var apiErr *workos.APIError + if errors.As(err, &apiErr) && apiErr.StatusCode == 409 { + return nil + } + if err == nil { + return nil + } + return fmt.Errorf("create role in workos: %w", err) + }) + if len(payload.MemberIds) > 0 { - assignedCount, err = r.assignMembersToRole(ctx, gramOrgID, wr.Slug, payload.MemberIds) - if err != nil { + if _, err := r.assignMembersToRole(ctx, gramOrgID, roleSlug, payload.MemberIds); err != nil { return roleCreateResult{}, err } } - role, err := r.roleViewFromLocalRole(ctx, gramOrgID, localRoleFromWorkOS(*wr), assignedCount) + memberCounts, err := r.memberCounts(ctx, gramOrgID) + if err != nil { + return roleCreateResult{}, err + } + role, err := r.roleViewFromLocalRole(ctx, gramOrgID, createdRole, memberCounts[createdRole.Slug]) if err != nil { return roleCreateResult{}, err } - return roleCreateResult{Role: role, Slug: wr.Slug}, nil + return roleCreateResult{Role: role, Slug: roleSlug}, nil } type localRole struct { @@ -231,7 +261,8 @@ type roleUpdateResult struct { Role localRole } -// UpdateRole updates an existing role and optionally replaces its assigned members. +// UpdateRole updates an existing local role and optionally replaces its assigned members, then best-effort syncs WorkOS. +// Side effects: writes Postgres role/grant/assignment records, invalidates local authz state, and logs WorkOS sync failures after bounded retries. func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID string, payload *gen.UpdateRolePayload) (roleUpdateResult, error) { currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, payload.ID) if err != nil { @@ -261,17 +292,28 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str updatedRole := currentRole if !sysRole { - wRole, err := r.roles.UpdateRole(ctx, workosOrgID, currentRole.Slug, workos.UpdateRoleOpts{ - Name: payload.Name, - Description: payload.Description, - }) - if err != nil { - return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "update role in workos").Log(ctx, r.logger) + localRecord := currentRole + if payload.Name != nil { + localRecord.Name = *payload.Name } - updatedRole = localRoleFromWorkOS(*wRole) - if err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, *wRole)); err != nil { + if payload.Description != nil { + localRecord.Description = *payload.Description + } + localRecord.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + if err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ + ID: "", + Name: localRecord.Name, + Slug: localRecord.Slug, + Description: localRecord.Description, + CreatedAt: localRecord.CreatedAt, + UpdatedAt: localRecord.UpdatedAt, + })); err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - _ = oops.E(oops.CodeUnexpected, err, "upsert local role record after workos write").Log(ctx, r.logger) + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role record").Log(ctx, r.logger) + } + updatedRole, err = r.getLocalRoleBySlug(ctx, gramOrgID, localRecord.Slug) + if err != nil { + return roleUpdateResult{}, err } if payload.Grants != nil { @@ -279,6 +321,17 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "sync grants for updated role").Log(ctx, r.logger) } } + + r.syncWorkOS(ctx, "update role in workos", func() error { + _, err := r.roles.UpdateRole(ctx, workosOrgID, currentRole.Slug, workos.UpdateRoleOpts{ + Name: payload.Name, + Description: payload.Description, + }) + if err == nil { + return nil + } + return fmt.Errorf("update role in workos: %w", err) + }) } if payload.MemberIds != nil { @@ -299,8 +352,8 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str return roleUpdateResult{Before: existingRole, After: updatedRoleView, Role: updatedRole}, nil } -// DeleteRole deletes a custom role after moving assigned members to the default member role. -// Side effects: reads local records; writes WorkOS membership reassignments before deleting the WorkOS role; upserts local assignment records, marks the local role deleted, invalidates role caches, and deletes local role grants. +// DeleteRole deletes a custom local role after moving assigned members to the default member role, then best-effort syncs WorkOS. +// Side effects: writes Postgres assignment/role/grant records, invalidates local authz state, and logs WorkOS sync failures after bounded retries. func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, roleID string) (localRole, error) { currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, roleID) if err != nil { @@ -316,33 +369,34 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) - reassigned := false for _, row := range rows { if row.RoleSlug != currentRole.Slug { continue } membershipID := conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) - if _, err := r.roles.UpdateMemberRole(ctx, membershipID, authz.SystemRoleMember); err != nil { - if reassigned { - r.authz.InvalidateAllRoleCaches(ctx, gramOrgID) - } - return localRole{}, oops.E(oops.CodeUnexpected, err, "reassign member to default role").Log(ctx, r.logger) - } if row.WorkosUserID != "" { - if err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, row.WorkosUserID, authz.SystemRoleMember, "", membershipID)); err != nil { + replaced, err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, row.WorkosUserID, authz.SystemRoleMember, "", membershipID)) + if err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return localRole{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) + } + if replaced == 0 { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - _ = oops.E(oops.CodeUnexpected, err, "upsert local role assignment record after workos write").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) } } - reassigned = true - } - if reassigned { - r.authz.InvalidateAllRoleCaches(ctx, gramOrgID) + if userID := conv.FromPGTextOrEmpty[string](row.UserID); userID != "" { + r.authz.InvalidateRoleCache(ctx, userID, gramOrgID) + } + r.syncWorkOS(ctx, "reassign member to default role in workos", func() error { + _, err := r.roles.UpdateMemberRole(ctx, membershipID, authz.SystemRoleMember) + if err == nil { + return nil + } + return fmt.Errorf("reassign member to default role in workos: %w", err) + }) } - if err := r.roles.DeleteRole(ctx, workosOrgID, currentRole.Slug); err != nil { - return localRole{}, oops.E(oops.CodeUnexpected, err, "delete role in workos").Log(ctx, r.logger) - } if _, err := repo.New(r.db).MarkOrganizationRoleDeleted(ctx, repo.MarkOrganizationRoleDeletedParams{ OrganizationID: gramOrgID, WorkosSlug: currentRole.Slug, @@ -350,7 +404,7 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro WorkosLastEventID: conv.ToPGTextEmpty(""), }); err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - _ = oops.E(oops.CodeUnexpected, err, "mark local role record deleted after workos write").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeUnexpected, err, "mark local role record deleted").Log(ctx, r.logger) } if _, err := repo.New(r.db).DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ @@ -360,6 +414,14 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro return localRole{}, oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, r.logger) } + r.authz.InvalidateAllRoleCaches(ctx, gramOrgID) + r.syncWorkOS(ctx, "delete role in workos", func() error { + if err := r.roles.DeleteRole(ctx, workosOrgID, currentRole.Slug); err != nil { + return fmt.Errorf("delete role in workos: %w", err) + } + return nil + }) + return currentRole, nil } @@ -372,8 +434,8 @@ type memberRoleUpdateContext struct { After *gen.AccessMember } -// UpdateMemberRole changes one member's role assignment. -// Side effects: reads local user, role, and local assignment records; writes the WorkOS membership first, upserts the local assignment record, and invalidates that member's role cache. +// UpdateMemberRole changes one member's local role assignment, then best-effort syncs WorkOS. +// Side effects: writes a Postgres assignment record, invalidates local authz state, and logs WorkOS sync failures after bounded retries. func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, roleID string) (memberRoleUpdateContext, error) { role, err := r.getLocalRoleByID(ctx, gramOrgID, roleID) if err != nil { @@ -423,17 +485,25 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member not found").Log(ctx, r.logger) } - updatedMember, err := r.roles.UpdateMemberRole(ctx, existing.MembershipID, role.Slug) - if err != nil { - return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "update member role in workos").Log(ctx, r.logger) - } - if updatedMember.UserID != "" && role.Slug != "" { - if err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, updatedMember.UserID, role.Slug, connectedUser.ID, updatedMember.ID)); err != nil { + if existing.WorkosUserID != "" && role.Slug != "" { + replaced, err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, existing.WorkosUserID, role.Slug, connectedUser.ID, existing.MembershipID)) + if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - _ = oops.E(oops.CodeUnexpected, err, "upsert local role assignment record after workos write").Log(ctx, r.logger) + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) + } + if replaced == 0 { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) } } r.authz.InvalidateRoleCache(ctx, userID, gramOrgID) + r.syncWorkOS(ctx, "update member role in workos", func() error { + _, err := r.roles.UpdateMemberRole(ctx, existing.MembershipID, role.Slug) + if err == nil { + return nil + } + return fmt.Errorf("update member role in workos: %w", err) + }) memberName := conv.Default(connectedUser.DisplayName, connectedUser.Email) return memberRoleUpdateContext{ @@ -502,6 +572,26 @@ func (r *RoleManager) getLocalRoleByID(ctx context.Context, gramOrgID, id string return localRoleFromRoleRow(row), nil } +// getLocalRoleBySlug loads one local organization role record by WorkOS slug. +// Side effects: reads Postgres local role records; does not call WorkOS. +func (r *RoleManager) getLocalRoleBySlug(ctx context.Context, gramOrgID, slug string) (localRole, error) { + row, err := repo.New(r.db).GetOrganizationRoleBySlug(ctx, repo.GetOrganizationRoleBySlugParams{ + OrganizationID: gramOrgID, + WorkosSlug: slug, + }) + switch { + case errors.Is(err, pgx.ErrNoRows): + return localRole{}, oops.E(oops.CodeNotFound, ErrRoleNotFound, "role not found").Log(ctx, r.logger) + case err != nil: + return localRole{}, oops.E(oops.CodeUnexpected, err, "get role").Log(ctx, r.logger) + case row.Deleted || row.WorkosDeleted: + return localRole{}, oops.E(oops.CodeNotFound, ErrRoleNotFound, "role not found").Log(ctx, r.logger) + } + + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + return localRoleFromOrganizationRole(row), nil +} + type memberAssignmentTarget struct { UserID string WorkosUserID string @@ -513,6 +603,10 @@ func (r *RoleManager) memberAssignmentTargets(ctx context.Context, gramOrgID str if len(memberIDs) == 0 { return nil, nil } + requested := make(map[string]struct{}, len(memberIDs)) + for _, id := range memberIDs { + requested[id] = struct{}{} + } users, err := usersrepo.New(r.db).GetUsersByIDs(ctx, memberIDs) if err != nil { @@ -536,27 +630,38 @@ func (r *RoleManager) memberAssignmentTargets(ctx context.Context, gramOrgID str } targets := make([]memberAssignmentTarget, 0, len(memberIDs)) - for _, gramID := range memberIDs { - workosID, ok := workosByGramID[gramID] - if !ok { + for _, row := range assignmentRows { + userID := conv.FromPGTextOrEmpty[string](row.UserID) + membershipID := conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) + if _, ok := requested[row.WorkosUserID]; ok { + targets = append(targets, memberAssignmentTarget{ + UserID: userID, + WorkosUserID: row.WorkosUserID, + MembershipID: membershipID, + }) continue } - membershipID, ok := membershipByWorkosID[workosID] + + workosID, ok := workosByGramID[userID] if !ok { continue } - targets = append(targets, memberAssignmentTarget{ - UserID: gramID, - WorkosUserID: workosID, - MembershipID: membershipID, - }) + if _, ok := requested[userID]; ok { + if _, ok := membershipByWorkosID[workosID]; ok { + targets = append(targets, memberAssignmentTarget{ + UserID: userID, + WorkosUserID: workosID, + MembershipID: membershipID, + }) + } + } } return targets, nil } -// assignMembersToRole moves each requested member to the given WorkOS role and mirrors the result locally. -// Side effects: reads local users and assignments, writes WorkOS memberships, upserts local assignment records, and invalidates org role caches when any member is assigned. +// assignMembersToRole moves each requested member to the given local role and best-effort syncs WorkOS. +// Side effects: reads local users and assignments, writes Postgres assignment records, invalidates local authz state, and logs WorkOS sync failures after bounded retries. func (r *RoleManager) assignMembersToRole(ctx context.Context, gramOrgID, roleSlug string, memberIDs []string) (int, error) { targets, err := r.memberAssignmentTargets(ctx, gramOrgID, memberIDs) if err != nil { @@ -565,16 +670,26 @@ func (r *RoleManager) assignMembersToRole(ctx context.Context, gramOrgID, roleSl assignedCount := 0 for _, target := range targets { - if _, err := r.roles.UpdateMemberRole(ctx, target.MembershipID, roleSlug); err != nil { - return 0, oops.E(oops.CodeUnexpected, err, "assign members to role").Log(ctx, r.logger) - } if target.WorkosUserID != "" && roleSlug != "" { - if err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, target.WorkosUserID, roleSlug, target.UserID, target.MembershipID)); err != nil { + replaced, err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, target.WorkosUserID, roleSlug, target.UserID, target.MembershipID)) + if err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return 0, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) + } + if replaced == 0 { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - _ = oops.E(oops.CodeUnexpected, err, "upsert local role assignment record after workos write").Log(ctx, r.logger) + return 0, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) } } assignedCount++ + r.authz.InvalidateRoleCache(ctx, target.UserID, gramOrgID) + r.syncWorkOS(ctx, "assign member to role in workos", func() error { + _, err := r.roles.UpdateMemberRole(ctx, target.MembershipID, roleSlug) + if err == nil { + return nil + } + return fmt.Errorf("assign member to role in workos: %w", err) + }) } if assignedCount > 0 { r.authz.InvalidateAllRoleCaches(ctx, gramOrgID) @@ -583,6 +698,40 @@ func (r *RoleManager) assignMembersToRole(ctx context.Context, gramOrgID, roleSl return assignedCount, nil } +// syncWorkOS runs a bounded best-effort WorkOS write after the local database already accepted the change. +// Side effects: calls WorkOS, waits briefly between retryable failures, and logs the final failure without returning it. +func (r *RoleManager) syncWorkOS(ctx context.Context, operation string, fn func() error) { + var err error + for attempt := 1; attempt <= workOSSyncAttempts; attempt++ { + err = fn() + if err == nil { + return + } + if !retryWorkOSError(err) || attempt == workOSSyncAttempts { + break + } + + select { + case <-ctx.Done(): + err = ctx.Err() + attempt = workOSSyncAttempts + case <-time.After(time.Duration(attempt) * 100 * time.Millisecond): + } + } + + r.logger.ErrorContext(ctx, "workos sync failed: "+operation, attr.SlogError(err)) +} + +// retryWorkOSError reports whether a WorkOS sync failure is worth retrying in-process. +// Side effects: none. +func retryWorkOSError(err error) bool { + var apiErr *workos.APIError + if !errors.As(err, &apiErr) { + return true + } + return apiErr.StatusCode == 429 || apiErr.StatusCode >= 500 +} + // memberCounts returns the number of locally connected members per role slug. // Side effects: reads Postgres local assignment records; does not call WorkOS. func (r *RoleManager) memberCounts(ctx context.Context, gramOrgID string) (map[string]int, error) { @@ -649,20 +798,20 @@ func localRoleFromRoleRow(row repo.GetOrganizationRoleByIDRow) localRole { } } -// localRoleFromWorkOS converts a WorkOS role response into the manager's internal local role record shape. +// localRoleFromOrganizationRole converts an organization role row into the manager's internal local role record shape. // Side effects: none. -func localRoleFromWorkOS(role workos.Role) localRole { +func localRoleFromOrganizationRole(row repo.OrganizationRole) localRole { return localRole{ - ID: role.ID, - Name: role.Name, - Slug: role.Slug, - Description: role.Description, - CreatedAt: role.CreatedAt, - UpdatedAt: role.UpdatedAt, + ID: row.ID.String(), + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), } } -// organizationRoleParams builds the SQL parameters for storing a local role record from a WorkOS write response. +// organizationRoleParams builds the SQL parameters for storing the authoritative local role record. // Side effects: reads the clock for updated_at and possibly created_at fallback. func organizationRoleParams(gramOrgID string, role workos.Role) repo.UpsertOrganizationRoleParams { return repo.UpsertOrganizationRoleParams{ @@ -671,12 +820,12 @@ func organizationRoleParams(gramOrgID string, role workos.Role) repo.UpsertOrgan WorkosName: role.Name, WorkosDescription: conv.ToPGTextEmpty(role.Description), WorkosCreatedAt: conv.ToPGTimestamptz(workosTimeOrNow(role.CreatedAt)), - WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosUpdatedAt: conv.ToPGTimestamptz(workosTimeOrNow(role.UpdatedAt)), WorkosLastEventID: conv.ToPGTextEmpty(""), } } -// replaceRoleAssignmentParams builds SQL parameters for storing the authoritative local role assignment after a WorkOS write. +// replaceRoleAssignmentParams builds SQL parameters for storing the authoritative local role assignment. // Side effects: reads the clock for updated_at. func replaceRoleAssignmentParams(gramOrgID, workosUserID, roleSlug, userID, membershipID string) repo.ReplaceOrganizationRoleAssignmentParams { return repo.ReplaceOrganizationRoleAssignmentParams{ diff --git a/server/internal/access/role_manager_test.go b/server/internal/access/role_manager_test.go index 012426e88b..b8c23e7907 100644 --- a/server/internal/access/role_manager_test.go +++ b/server/internal/access/role_manager_test.go @@ -72,7 +72,7 @@ func TestRoleManager_MembersAndCounts(t *testing.T) { manager := ti.service.roleMgr members, err := manager.ListMembers(ctx, authCtx.ActiveOrganizationID) require.NoError(t, err) - require.Len(t, members.Members, 2) + require.Len(t, members.Members, 3) slugs, err := manager.MemberRoleSlugs(ctx, authCtx.ActiveOrganizationID, "user_2") require.NoError(t, err) diff --git a/server/internal/access/setup_test.go b/server/internal/access/setup_test.go index 8f77a1f6a7..2ab999b78a 100644 --- a/server/internal/access/setup_test.go +++ b/server/internal/access/setup_test.go @@ -95,7 +95,7 @@ func newTestAccessService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := NewService(logger, tracerProvider, conn, chConn, sessionManager, NewRoleManager(logger, conn, roles, authzEngine), authzEngine, noopFeatureCacheWriter{}, auditLogger) return ctx, &testInstance{ @@ -187,7 +187,7 @@ func seedRoleAssignment(t *testing.T, ctx context.Context, conn *pgxpool.Pool, o updatedAt = parsed } - err := accessrepo.New(conn).ReplaceOrganizationRoleAssignment(ctx, accessrepo.ReplaceOrganizationRoleAssignmentParams{ + replaced, err := accessrepo.New(conn).ReplaceOrganizationRoleAssignment(ctx, accessrepo.ReplaceOrganizationRoleAssignmentParams{ OrganizationID: organizationID, WorkosUserID: member.UserID, WorkosRoleSlug: member.RoleSlug, @@ -197,6 +197,7 @@ func seedRoleAssignment(t *testing.T, ctx context.Context, conn *pgxpool.Pool, o WorkosLastEventID: conv.ToPGTextEmpty(""), }) require.NoError(t, err) + require.Equal(t, int64(1), replaced) } // seedDisconnectedUser creates a user in the users table with a workos_id but diff --git a/server/internal/access/updatememberrole_test.go b/server/internal/access/updatememberrole_test.go index 1b8821c266..21886e44a4 100644 --- a/server/internal/access/updatememberrole_test.go +++ b/server/internal/access/updatememberrole_test.go @@ -96,11 +96,11 @@ func TestService_UpdateMemberRole_WorkOSFailure(t *testing.T) { seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) - ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Once() + ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Times(3) - _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) - require.Error(t, err) - require.Contains(t, err.Error(), "update member role in workos") + member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) + require.NoError(t, err) + require.Equal(t, builderID, member.RoleID) } func TestService_UpdateMemberRole_AuditLog(t *testing.T) { diff --git a/server/internal/access/updaterole_test.go b/server/internal/access/updaterole_test.go index 9bd47b78e2..f1adfada8a 100644 --- a/server/internal/access/updaterole_test.go +++ b/server/internal/access/updaterole_test.go @@ -75,13 +75,14 @@ func TestService_UpdateRole(t *testing.T) { MemberIds: []string{"local_user_1", "local_user_2"}, }) require.NoError(t, err) - require.Equal(t, "role_custom", role.ID) + require.Equal(t, roleID, role.ID) require.Equal(t, name, role.Name) require.Equal(t, description, role.Description) require.False(t, role.IsSystem) require.Equal(t, 3, role.MemberCount) require.Equal(t, mockRoleTimestamp, role.CreatedAt) - require.Equal(t, mockRoleTimestamp, role.UpdatedAt) + require.NotEmpty(t, role.UpdatedAt) + require.NotEqual(t, mockRoleTimestamp, role.UpdatedAt) require.Len(t, role.Grants, 2) grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) @@ -231,11 +232,11 @@ func TestService_UpdateRole_WorkOSUpdateFailure(t *testing.T) { authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) - ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{}).Return((*thirdpartyworkos.Role)(nil), errors.New("workos unavailable")).Once() + ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{}).Return((*thirdpartyworkos.Role)(nil), errors.New("workos unavailable")).Times(3) - _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: roleID}) - require.Error(t, err) - require.Contains(t, err.Error(), "update role in workos") + role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: roleID}) + require.NoError(t, err) + require.Equal(t, roleID, role.ID) } func TestService_UpdateRole_AuditLog(t *testing.T) { diff --git a/server/internal/assets/setup_test.go b/server/internal/assets/setup_test.go index 57238c54a6..a11baf95e2 100644 --- a/server/internal/assets/setup_test.go +++ b/server/internal/assets/setup_test.go @@ -101,7 +101,6 @@ func newTestAssetsService(t *testing.T) (context.Context, *testInstance) { authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), - cache.NoopCache, ), auditLogger, ) diff --git a/server/internal/assistantmemories/impl_test.go b/server/internal/assistantmemories/impl_test.go index d5c0867e54..275e9c18a1 100644 --- a/server/internal/assistantmemories/impl_test.go +++ b/server/internal/assistantmemories/impl_test.go @@ -13,7 +13,6 @@ import ( gen "github.com/speakeasy-api/gram/server/gen/assistant_memories" "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/memory" "github.com/speakeasy-api/gram/server/internal/memory/repo" @@ -76,7 +75,7 @@ func newTestHarness(t *testing.T) (*testHarness, context.Context) { mem := &fakeMemory{listFn: nil, getFn: nil, deleteFn: nil} features := &fakeFeatures{enabled: true, err: nil} - authzEngine := authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := &Service{ tracer: tracerProvider.Tracer("test"), @@ -167,7 +166,7 @@ func TestListAssistantMemories_RBACDenied(t *testing.T) { h, ctx := newTestHarness(t) logger := testenv.NewLogger(t) - h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx) _, err := h.svc.ListAssistantMemories(ctx, &gen.ListAssistantMemoriesPayload{ @@ -349,7 +348,7 @@ func TestGetAssistantMemory_RBACDenied(t *testing.T) { h, ctx := newTestHarness(t) logger := testenv.NewLogger(t) - h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx) _, err := h.svc.GetAssistantMemory(ctx, &gen.GetAssistantMemoryPayload{ @@ -433,7 +432,7 @@ func TestDeleteAssistantMemory_RBACDenied(t *testing.T) { h, ctx := newTestHarness(t) logger := testenv.NewLogger(t) - h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.NewGrant(authz.ScopeProjectRead, h.projectID.String())) diff --git a/server/internal/assistants/impl_test.go b/server/internal/assistants/impl_test.go index f685ae708b..de094f06c8 100644 --- a/server/internal/assistants/impl_test.go +++ b/server/internal/assistants/impl_test.go @@ -13,7 +13,6 @@ import ( "github.com/speakeasy-api/gram/server/gen/types" "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/oops" projectsRepo "github.com/speakeasy-api/gram/server/internal/projects/repo" @@ -254,7 +253,7 @@ func newRBACServiceWithConn(t *testing.T, dbName string) (*Service, context.Cont logger := testenv.NewLogger(t) chConn, err := assistantsInfra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) service := &Service{ tracer: testenv.NewTracerProvider(t).Tracer("test"), logger: logger, diff --git a/server/internal/auditapi/setup_test.go b/server/internal/auditapi/setup_test.go index 51501d55ec..34df18315c 100644 --- a/server/internal/auditapi/setup_test.go +++ b/server/internal/auditapi/setup_test.go @@ -71,7 +71,7 @@ func newTestAuditService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) return ctx, &testInstance{ service: auditapi.NewService(logger, tracerProvider, conn, sessionManager, authzEngine), diff --git a/server/internal/auth/setup_test.go b/server/internal/auth/setup_test.go index 5736fbf9b6..90fb4c3d15 100644 --- a/server/internal/auth/setup_test.go +++ b/server/internal/auth/setup_test.go @@ -322,7 +322,7 @@ func newTestAuthService(t *testing.T, userInfo *MockUserInfo) (context.Context, chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := auth.NewService(logger, tracerProvider, conn, sessionManager, authConfigs, authzEngine, billingClient, noopCancelScheduler{}, posthog) return ctx, newTestAuthServiceResult(t, svc, conn, sessionManager, mockServer, authConfigs) @@ -366,7 +366,7 @@ func newTestAuthServiceWithAuthz(t *testing.T, userInfo *MockUserInfo) (context. chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := auth.NewService(logger, tracerProvider, conn, sessionManager, authConfigs, authzEngine, billingClient, noopCancelScheduler{}, posthog) return ctx, newTestAuthServiceResult(t, svc, conn, sessionManager, mockServer, authConfigs) diff --git a/server/internal/authz/context_test.go b/server/internal/authz/context_test.go index 2efc754cbc..c4b3d065a6 100644 --- a/server/internal/authz/context_test.go +++ b/server/internal/authz/context_test.go @@ -7,7 +7,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" @@ -24,7 +23,7 @@ func TestPrepareContext_loadsUserGrants(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) @@ -49,7 +48,7 @@ func TestPrepareContext_skipsNonSessionAuth(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) @@ -71,7 +70,7 @@ func TestPrepareContext_loadsAssistantPrincipalGrants(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) @@ -101,7 +100,7 @@ func TestShouldEnforce_assistantPrincipalOnEnterpriseOrgEnforces(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) @@ -126,7 +125,7 @@ func TestShouldEnforce_assistantPrincipalOnNonEnterpriseSkips(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) @@ -150,7 +149,7 @@ func TestPrepareContext_skipsNonEnterpriseOrgs(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) diff --git a/server/internal/authz/engine.go b/server/internal/authz/engine.go index 9959c29506..db211a25f2 100644 --- a/server/internal/authz/engine.go +++ b/server/internal/authz/engine.go @@ -5,16 +5,14 @@ import ( "errors" "fmt" "log/slog" - "time" "github.com/ClickHouse/clickhouse-go/v2" "github.com/jackc/pgx/v5/pgxpool" + accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/attr" authzrepo "github.com/speakeasy-api/gram/server/internal/authz/repo" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/oops" - orgrepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" usersrepo "github.com/speakeasy-api/gram/server/internal/users/repo" @@ -31,28 +29,6 @@ type EngineOpts struct { DevMode bool } -// roleSlugCache is the Redis cache entry for a resolved role slug. -// Key is org-first so DeleteByPrefix on "role-slug:{orgID}:" invalidates the whole org. -type roleSlugCache struct { - UserID string - OrgID string - Slug string -} - -var _ cache.CacheableObject[roleSlugCache] = (*roleSlugCache)(nil) - -func (r roleSlugCache) CacheKey() string { - return "role-slug:" + r.OrgID + ":" + r.UserID -} - -func (r roleSlugCache) TTL() time.Duration { - return 5 * time.Minute -} - -func (r roleSlugCache) AdditionalCacheKeys() []string { - return nil -} - // ChallengeLoggingEnabled checks whether authz challenge logging to ClickHouse // is enabled for a given organization. Same signature as IsRBACEnabled. type ChallengeLoggingEnabled func(ctx context.Context, organizationID string) (bool, error) @@ -65,10 +41,9 @@ type Engine struct { challengeLoggingEnabled ChallengeLoggingEnabled isDev bool membership MembershipFetcher - roleCache cache.TypedCacheObject[roleSlugCache] } -func NewEngine(logger *slog.Logger, db *pgxpool.Pool, chDB clickhouse.Conn, isEnabled IsRBACEnabled, challengeLogging ChallengeLoggingEnabled, membership MembershipFetcher, roleCache cache.Cache, opts ...EngineOpts) *Engine { +func NewEngine(logger *slog.Logger, db *pgxpool.Pool, chDB clickhouse.Conn, isEnabled IsRBACEnabled, challengeLogging ChallengeLoggingEnabled, membership MembershipFetcher, opts ...EngineOpts) *Engine { var devMode bool if len(opts) > 0 { devMode = opts[0].DevMode @@ -84,7 +59,6 @@ func NewEngine(logger *slog.Logger, db *pgxpool.Pool, chDB clickhouse.Conn, isEn challengeLoggingEnabled: challengeLogging, isDev: devMode, membership: membership, - roleCache: cache.NewTypedObjectCache[roleSlugCache](logger.With(attr.SlogCacheNamespace("authz-role-slug")), roleCache, cache.SuffixNone), } } @@ -179,76 +153,36 @@ func (e *Engine) PrepareContext(ctx context.Context) (context.Context, error) { } func (e *Engine) resolveRoleSlug(ctx context.Context, userID, orgID string) (string, error) { - cacheKey := roleSlugCache{UserID: userID, OrgID: orgID, Slug: ""}.CacheKey() - if cached, err := e.roleCache.Get(ctx, cacheKey); err == nil { - return cached.Slug, nil - } - user, err := usersrepo.New(e.db).GetUser(ctx, userID) if err != nil { return "", fmt.Errorf("get user: %w", err) } if !user.WorkosID.Valid || user.WorkosID.String == "" { - e.storeRoleSlugCache(ctx, userID, orgID, "") return "", nil } - org, err := orgrepo.New(e.db).GetOrganizationMetadata(ctx, orgID) + roleSlugs, err := accessrepo.New(e.db).ListMemberRoleSlugsByWorkosUser(ctx, accessrepo.ListMemberRoleSlugsByWorkosUserParams{ + OrganizationID: orgID, + WorkosUserID: user.WorkosID.String, + }) if err != nil { - return "", fmt.Errorf("get org: %w", err) + return "", fmt.Errorf("list member role slugs: %w", err) } - if !org.WorkosID.Valid || org.WorkosID.String == "" { - e.storeRoleSlugCache(ctx, userID, orgID, "") + if len(roleSlugs) == 0 { return "", nil } - member, err := e.membership.GetOrgMembership(ctx, user.WorkosID.String, org.WorkosID.String) - if err != nil { - return "", fmt.Errorf("get org membership: %w", err) - } - if member == nil { - e.storeRoleSlugCache(ctx, userID, orgID, "") - return "", nil - } - - e.storeRoleSlugCache(ctx, userID, orgID, member.RoleSlug) - - return member.RoleSlug, nil -} - -func (e *Engine) storeRoleSlugCache(ctx context.Context, userID, orgID, slug string) { - entry := roleSlugCache{UserID: userID, OrgID: orgID, Slug: slug} - if err := e.roleCache.Store(ctx, entry); err != nil { - e.logger.WarnContext(ctx, "failed to cache role slug", - attr.SlogUserID(userID), - attr.SlogOrganizationID(orgID), - attr.SlogError(err), - ) - } + return roleSlugs[0], nil } -// InvalidateRoleCache removes the cached role slug for a single user. Call -// this after updating a specific member's role via UpdateMemberRole. +// InvalidateRoleCache is retained for callers that used to clear the Redis role cache. +// Role resolution now reads Postgres directly, so this is intentionally a no-op. func (e *Engine) InvalidateRoleCache(ctx context.Context, userID, orgID string) { - entry := roleSlugCache{UserID: userID, OrgID: orgID, Slug: ""} - if err := e.roleCache.Delete(ctx, entry); err != nil { - e.logger.WarnContext(ctx, "failed to invalidate cached role slug", - attr.SlogUserID(userID), - attr.SlogOrganizationID(orgID), - attr.SlogError(err), - ) - } } -// InvalidateAllRoleCaches removes all cached role slugs for an org. Call this -// after bulk role reassignments where individual user IDs aren't tracked. +// InvalidateAllRoleCaches is retained for callers that used to clear the Redis role cache. +// Role resolution now reads Postgres directly, so this is intentionally a no-op. func (e *Engine) InvalidateAllRoleCaches(ctx context.Context, orgID string) { - if err := e.roleCache.DeleteByPrefix(ctx, "role-slug:"+orgID+":"); err != nil { - e.logger.WarnContext(ctx, "failed to invalidate cached role slugs for org", - attr.SlogOrganizationID(orgID), - attr.SlogError(err), - ) - } } func (e *Engine) Require(ctx context.Context, checks ...Check) error { diff --git a/server/internal/authz/engine_test.go b/server/internal/authz/engine_test.go index b10465bf3d..5634a65003 100644 --- a/server/internal/authz/engine_test.go +++ b/server/internal/authz/engine_test.go @@ -2,10 +2,7 @@ package authz import ( "context" - "encoding/json" "errors" - "fmt" - "strings" "testing" "time" @@ -13,7 +10,6 @@ import ( "github.com/stretchr/testify/require" authzrepo "github.com/speakeasy-api/gram/server/internal/authz/repo" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/testenv" @@ -43,7 +39,7 @@ func TestEngineRequire_requiresAuthContext(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) err = engine.Require(t.Context(), Check{Scope: ScopeProjectRead, ResourceID: "proj_123"}) var oopsErr *oops.ShareableError @@ -56,7 +52,7 @@ func TestEngineRequire_skipsWhenRBACFeatureDisabled(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient()) err = engine.Require(enterpriseSessionCtx(t), Check{Scope: ScopeProjectRead, ResourceID: "proj_123"}) require.NoError(t, err) @@ -67,7 +63,7 @@ func TestEngineRequire_mapsDeniedToForbidden(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), nil) err = engine.Require(ctx, Check{Scope: ScopeProjectRead, ResourceID: "proj_123"}) @@ -81,7 +77,7 @@ func TestEngineRequire_mapsMissingGrantsToUnexpected(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) err = engine.Require(enterpriseSessionCtx(t), Check{Scope: ScopeProjectRead, ResourceID: "proj_123"}) var oopsErr *oops.ShareableError @@ -95,7 +91,7 @@ func TestEngineRequire_returnsUnexpectedWhenFeatureCheckFails(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, failingRBAC(errors.New("boom")), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, failingRBAC(errors.New("boom")), staticChallengeLogging(true), workos.NewStubClient()) err = engine.Require(enterpriseSessionCtx(t), Check{Scope: ScopeProjectRead, ResourceID: "proj_123"}) var oopsErr *oops.ShareableError @@ -103,7 +99,7 @@ func TestEngineRequire_returnsUnexpectedWhenFeatureCheckFails(t *testing.T) { require.Equal(t, oops.CodeUnexpected, oopsErr.Code) } -func TestResolveRoleSlug_cachesEmptyMembershipResult(t *testing.T) { +func TestResolveRoleSlug_readsLocalAssignmentsOnly(t *testing.T) { t.Parallel() ctx := enterpriseTestCtx(t.Context()) @@ -118,7 +114,7 @@ func TestResolveRoleSlug_cachesEmptyMembershipResult(t *testing.T) { membership := &countingMembershipFetcher{} chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, staticRBAC(true), staticChallengeLogging(true), membership, newMapCache()) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, staticRBAC(true), staticChallengeLogging(true), membership) roleSlug, err := engine.resolveRoleSlug(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) require.NoError(t, err) @@ -127,7 +123,7 @@ func TestResolveRoleSlug_cachesEmptyMembershipResult(t *testing.T) { roleSlug, err = engine.resolveRoleSlug(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) require.NoError(t, err) require.Empty(t, roleSlug) - require.Equal(t, 1, membership.calls) + require.Equal(t, 0, membership.calls) } func TestEngineRequireAny_mapsDeniedToForbidden(t *testing.T) { @@ -135,7 +131,7 @@ func TestEngineRequireAny_mapsDeniedToForbidden(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{NewGrant(ScopeMCPConnect, "tool_a")}) err = engine.RequireAny(ctx, @@ -152,7 +148,7 @@ func TestEngineFilter_returnsAllowedSubset(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{NewGrant(ScopeProjectRead, "proj_123")}) resourceIDs, err := engine.Filter(ctx, []Check{ @@ -170,7 +166,7 @@ func TestEngineFilter_logsSingleAggregateChallenge(t *testing.T) { ctx := GrantsToContext(enterpriseSessionCtxWithOrg(t, orgID), []Grant{NewGrant(ScopeProjectRead, "proj_allowed")}) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) resourceIDs, err := engine.Filter(ctx, []Check{ {Scope: ScopeProjectRead, ResourceID: "proj_allowed"}, @@ -219,7 +215,7 @@ func TestEngineFilter_logsDenyWhenNoMatches(t *testing.T) { ctx := GrantsToContext(enterpriseSessionCtxWithOrg(t, orgID), []Grant{NewGrant(ScopeProjectRead, "proj_other")}) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) resourceIDs, err := engine.Filter(ctx, []Check{ {Scope: ScopeProjectRead, ResourceID: "proj_a"}, @@ -262,7 +258,7 @@ func TestEngineFilter_skipsLogWhenNoChecks(t *testing.T) { ctx := GrantsToContext(enterpriseSessionCtxWithOrg(t, orgID), []Grant{NewGrant(ScopeProjectRead, WildcardResource)}) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) resourceIDs, err := engine.Filter(ctx, nil) require.NoError(t, err) @@ -286,7 +282,7 @@ func TestEngineFilter_withDimensions(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ {Scope: ScopeMCPConnect, Selector: Selector{ "resource_kind": "mcp", @@ -309,7 +305,7 @@ func TestEngineFilter_withDisposition(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ {Scope: ScopeMCPConnect, Selector: Selector{ "resource_kind": "mcp", @@ -332,7 +328,7 @@ func TestEngineFilter_serverLevelGrantAllowsAllDimensions(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ NewGrant(ScopeMCPConnect, "toolsetA"), }) @@ -351,7 +347,7 @@ func TestEngineFilter_projectScopedGrantMatchesServersInProject(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ {Scope: ScopeMCPConnect, Selector: Selector{ "resource_kind": "mcp", @@ -374,7 +370,7 @@ func TestEngineRequire_projectScopedGrantAllowsToolsInProject(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ {Scope: ScopeMCPConnect, Selector: Selector{ "resource_kind": "mcp", @@ -405,7 +401,7 @@ func TestEngineRequire_projectScopedMCPReadGrant(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ {Scope: ScopeMCPRead, Selector: Selector{ "resource_kind": "mcp", @@ -430,7 +426,7 @@ func TestEngineFilter_projectAndServerGrantsCombine(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ // Project-scoped grant for proj_A {Scope: ScopeMCPConnect, Selector: Selector{ @@ -456,7 +452,7 @@ func TestEngineRequire_rejectsInvalidCheck(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{NewGrant(ScopeProjectRead, WildcardResource)}) err = engine.Require(ctx, Check{Scope: ScopeProjectRead, ResourceID: ""}) @@ -471,7 +467,7 @@ func TestEngineRequire_requiresChecks(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{NewGrant(ScopeProjectRead, WildcardResource)}) err = engine.Require(ctx) @@ -486,7 +482,7 @@ func TestEngineRequire_skipsForAPIKeyAuth(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) sessionID := "session_123" ctx := contextvalues.SetAuthContext(t.Context(), &contextvalues.AuthContext{ ActiveOrganizationID: "org_123", @@ -513,7 +509,7 @@ func TestEngineFilter_skipsForNonEnterpriseAccount(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) sessionID := "session_123" ctx := contextvalues.SetAuthContext(t.Context(), &contextvalues.AuthContext{ ActiveOrganizationID: "org_123", @@ -548,60 +544,6 @@ func (c *countingMembershipFetcher) GetOrgMembership(context.Context, string, st return nil, nil } -type mapCache struct { - items map[string][]byte -} - -func newMapCache() *mapCache { - return &mapCache{items: map[string][]byte{}} -} - -func (m *mapCache) Get(_ context.Context, key string, value any) error { - item, ok := m.items[key] - if !ok { - return errors.New("cache miss") - } - if err := json.Unmarshal(item, value); err != nil { - return fmt.Errorf("unmarshal cache item: %w", err) - } - return nil -} - -func (m *mapCache) Set(_ context.Context, key string, value any, _ time.Duration) error { - item, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("marshal cache item: %w", err) - } - m.items[key] = item - return nil -} - -func (m *mapCache) Update(ctx context.Context, key string, value any) error { - return m.Set(ctx, key, value, 0) -} - -func (m *mapCache) Delete(_ context.Context, key string) error { - delete(m.items, key) - return nil -} - -func (m *mapCache) ListAppend(context.Context, string, any, time.Duration) error { - return errors.New("not implemented") -} - -func (m *mapCache) ListRange(context.Context, string, int64, int64, any) error { - return errors.New("not implemented") -} - -func (m *mapCache) DeleteByPrefix(_ context.Context, prefix string) error { - for key := range m.items { - if strings.HasPrefix(key, prefix) { - delete(m.items, key) - } - } - return nil -} - func enterpriseSessionCtx(t *testing.T) context.Context { t.Helper() return enterpriseSessionCtxWithOrg(t, "org_123") @@ -654,7 +596,7 @@ func TestCanUseOverride_devPlusAdmin(t *testing.T) { t.Parallel() chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache, EngineOpts{DevMode: true}) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), EngineOpts{DevMode: true}) ctx := scopeOverrideCtx(t, true, "pro") enforce, err := engine.ShouldEnforce(ctx) @@ -666,7 +608,7 @@ func TestCanUseOverride_devPlusNonAdmin(t *testing.T) { t.Parallel() chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache, EngineOpts{DevMode: true}) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), EngineOpts{DevMode: true}) ctx := scopeOverrideCtx(t, false, "pro") enforce, err := engine.ShouldEnforce(ctx) @@ -678,7 +620,7 @@ func TestCanUseOverride_prodPlusAdmin(t *testing.T) { t.Parallel() chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient()) ctx := scopeOverrideCtx(t, true, "pro") enforce, err := engine.ShouldEnforce(ctx) @@ -690,7 +632,7 @@ func TestCanUseOverride_prodPlusNonAdmin(t *testing.T) { t.Parallel() chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient()) ctx := scopeOverrideCtx(t, false, "pro") enforce, err := engine.ShouldEnforce(ctx) diff --git a/server/internal/authz/integration_test.go b/server/internal/authz/integration_test.go index e98bf05484..77194495c1 100644 --- a/server/internal/authz/integration_test.go +++ b/server/internal/authz/integration_test.go @@ -5,7 +5,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" @@ -31,7 +30,7 @@ func TestRequire_withLoadedGrantsFromContext(t *testing.T) { ctx = GrantsToContext(ctx, grants) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) err = engine.Require(ctx, Check{Scope: ScopeProjectRead, ResourceID: "proj:123"}, @@ -65,7 +64,7 @@ func TestFilter_withLoadedGrantsFromContext(t *testing.T) { ctx = GrantsToContext(ctx, grants) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) projectIDs, err := engine.Filter(ctx, []Check{ {Scope: ScopeProjectRead, ResourceID: "proj:123"}, @@ -104,7 +103,7 @@ func TestFilter_withDimensions(t *testing.T) { ctx = GrantsToContext(ctx, grants) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) results, err := engine.Filter(ctx, []Check{ MCPToolCallCheck("toolsetX", MCPToolCallDimensions{Tool: "allowed_tool", Disposition: ""}), diff --git a/server/internal/authz/load_test.go b/server/internal/authz/load_test.go index c09a83f51a..a6b51eb5cc 100644 --- a/server/internal/authz/load_test.go +++ b/server/internal/authz/load_test.go @@ -5,7 +5,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" @@ -30,7 +29,7 @@ func TestLoadGrants_loadsUserAndRoleGrants(t *testing.T) { ctx = GrantsToContext(ctx, grants) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) require.NoError(t, engine.Require(ctx, Check{Scope: ScopeProjectRead, ResourceID: "proj:123"})) require.NoError(t, engine.Require(ctx, Check{Scope: ScopeMCPConnect, ResourceID: "toolA"})) } @@ -93,7 +92,7 @@ func TestLoadGrants_returnsEmptyGrantSetWhenNoRowsMatch(t *testing.T) { ctx = GrantsToContext(ctx, grants) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) projectIDs, err := engine.Filter(ctx, []Check{ {Scope: ScopeProjectRead, ResourceID: "proj:123"}, }) diff --git a/server/internal/collections/setup_test.go b/server/internal/collections/setup_test.go index cee749b9a6..4b24b7325e 100644 --- a/server/internal/collections/setup_test.go +++ b/server/internal/collections/setup_test.go @@ -80,7 +80,7 @@ func newTestCollectionsService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) auditLogger := audit.NewLogger() svc := collections.NewService(logger, tracerProvider, conn, sessionManager, authzEngine, testenv.DefaultSiteURL(t)) diff --git a/server/internal/customdomains/setup_test.go b/server/internal/customdomains/setup_test.go index 9c33549559..ba34969956 100644 --- a/server/internal/customdomains/setup_test.go +++ b/server/internal/customdomains/setup_test.go @@ -84,7 +84,7 @@ func newTestCustomDomainsService(t *testing.T) (context.Context, *serviceTestIns temporal := &stubTemporalClient{} chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) auditLogger := audit.NewLogger() svc := customdomains.NewService(logger, tracerProvider, conn, sessionManager, temporal, authzEngine, auditLogger) diff --git a/server/internal/deployments/setup_test.go b/server/internal/deployments/setup_test.go index 630f541d1e..8ec879e7a3 100644 --- a/server/internal/deployments/setup_test.go +++ b/server/internal/deployments/setup_test.go @@ -106,7 +106,7 @@ func newTestDeploymentService(t *testing.T, assetStorage assets.BlobStore) (cont chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, posthog, testenv.DefaultSiteURL(t), mcpRegistryClient, authzEngine, auditLogger) assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret", authzEngine, auditLogger) packagesSvc := packages.NewService(logger, tracerProvider, conn, sessionManager, authzEngine) diff --git a/server/internal/environments/setup_test.go b/server/internal/environments/setup_test.go index d96404ea9b..d461f0e656 100644 --- a/server/internal/environments/setup_test.go +++ b/server/internal/environments/setup_test.go @@ -78,7 +78,7 @@ func newTestEnvironmentService(t *testing.T) (context.Context, *testInstance) { require.NoError(t, err) auditLogger := audit.NewLogger() - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := environments.NewService(logger, tracerProvider, conn, sessionManager, enc, authzEngine, auditLogger) return ctx, &testInstance{ diff --git a/server/internal/externalmcp/setup_test.go b/server/internal/externalmcp/setup_test.go index 9ddbc7dfc1..fa8c9653e4 100644 --- a/server/internal/externalmcp/setup_test.go +++ b/server/internal/externalmcp/setup_test.go @@ -76,7 +76,7 @@ func newTestExternalMCPService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := externalmcp.NewService(logger, tracerProvider, conn, sessionManager, mcpRegistryClient, authzEngine) return ctx, &testInstance{ diff --git a/server/internal/functions/setup_test.go b/server/internal/functions/setup_test.go index f36a752d53..08fdda33fd 100644 --- a/server/internal/functions/setup_test.go +++ b/server/internal/functions/setup_test.go @@ -116,7 +116,7 @@ func newTestFunctionsService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := functions.NewService(logger, tracerProvider, conn, enc, tigrisStore) deploymentsSvc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, ph, testenv.DefaultSiteURL(t), mcpRegistryClient, authzEngine, auditLogger) assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret", authzEngine, auditLogger) diff --git a/server/internal/hooks/setup_test.go b/server/internal/hooks/setup_test.go index 7735ee8d8f..144e845e7a 100644 --- a/server/internal/hooks/setup_test.go +++ b/server/internal/hooks/setup_test.go @@ -83,7 +83,7 @@ func newTestHooksService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) chatWriter, chatWriterShutdown := chat.NewChatMessageWriter(logger, conn, nil) t.Cleanup(func() { _ = chatWriterShutdown(t.Context()) }) shadowMCPClient := shadowmcp.NewClient(logger, conn, cacheAdapter) diff --git a/server/internal/keys/setup_test.go b/server/internal/keys/setup_test.go index 7f120542dc..9bf5a1611e 100644 --- a/server/internal/keys/setup_test.go +++ b/server/internal/keys/setup_test.go @@ -82,7 +82,7 @@ func newTestKeysService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) auditLogger := audit.NewLogger() svc := keys.NewService(logger, tracerProvider, conn, sessionManager, "local", authzEngine, auditLogger) keyAuth := auth.NewKeyAuth(conn, logger, billingClient) diff --git a/server/internal/mcp/handle_get_server_test.go b/server/internal/mcp/handle_get_server_test.go index 73a580a1d4..59d01f33a1 100644 --- a/server/internal/mcp/handle_get_server_test.go +++ b/server/internal/mcp/handle_get_server_test.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/require" "github.com/speakeasy-api/gram/server/internal/authz" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/mcpmetadata" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" ) @@ -33,7 +32,7 @@ func TestHandleGetServer_ContentNegotiation(t *testing.T) { testInstance.serverURL, testInstance.siteURL, testInstance.cacheAdapter, - authz.NewEngine(testInstance.logger, testInstance.conn, chConn, nil, nil, workos.NewStubClient(), cache.NoopCache), + authz.NewEngine(testInstance.logger, testInstance.conn, chConn, nil, nil, workos.NewStubClient()), testInstance.audit, ) diff --git a/server/internal/mcp/rbac_test.go b/server/internal/mcp/rbac_test.go index 87456f1057..31ed3ef5da 100644 --- a/server/internal/mcp/rbac_test.go +++ b/server/internal/mcp/rbac_test.go @@ -11,7 +11,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/conv" "github.com/speakeasy-api/gram/server/internal/oops" @@ -27,7 +26,7 @@ func TestServePublic_RBAC_PrivateMCP_DeniedWithNoGrants(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx) err = authzEngine.Require(ctx, authz.Check{Scope: authz.ScopeMCPConnect, ResourceID: toolset.ID.String()}) @@ -44,7 +43,7 @@ func TestServePublic_RBAC_PrivateMCP_DeniedWithUnrelatedGrant(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{Scope: authz.ScopeMCPConnect, Selector: authz.NewSelector(authz.ScopeMCPConnect, uuid.NewString())}) err = authzEngine.Require(ctx, authz.Check{Scope: authz.ScopeMCPConnect, ResourceID: toolset.ID.String()}) @@ -61,7 +60,7 @@ func TestServePublic_RBAC_PrivateMCP_AllowedWithWriteGrant(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{Scope: authz.ScopeMCPWrite, Selector: authz.NewSelector(authz.ScopeMCPWrite, toolset.ID.String())}) err = authzEngine.Require(ctx, authz.Check{Scope: authz.ScopeMCPConnect, ResourceID: toolset.ID.String()}) @@ -76,7 +75,7 @@ func TestServePublic_RBAC_PrivateMCP_AllowedWithConnectGrant(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{Scope: authz.ScopeMCPConnect, Selector: authz.NewSelector(authz.ScopeMCPConnect, toolset.ID.String())}) err = authzEngine.Require(ctx, authz.Check{Scope: authz.ScopeMCPConnect, ResourceID: toolset.ID.String()}) @@ -109,7 +108,7 @@ func TestServePublic_RBAC_ToolLevelGrant_AllowsMatchingTool(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, Selector: authz.Selector{ @@ -136,7 +135,7 @@ func TestServePublic_RBAC_ToolLevelGrant_DeniesWrongTool(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, Selector: authz.Selector{ @@ -165,7 +164,7 @@ func TestServePublic_RBAC_ServerLevelGrant_AllowsAnyTool(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Server-level grant (no tool key) should allow any tool. ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, @@ -188,7 +187,7 @@ func TestServePublic_RBAC_ToolLevelGrant_DeniesWrongServer(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, Selector: authz.Selector{ @@ -217,7 +216,7 @@ func TestServePublic_RBAC_ProjectScopedGrant_AllowsServerInProject(t *testing.T) chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Project-scoped grant: resource_id=*, project_id=toolset's project ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, @@ -241,7 +240,7 @@ func TestServePublic_RBAC_ProjectScopedGrant_DeniesServerInOtherProject(t *testi chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Project-scoped grant for a *different* project otherProjectID := uuid.New() ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ @@ -268,7 +267,7 @@ func TestServePublic_RBAC_ProjectScopedGrant_AllowsToolCallInProject(t *testing. chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, Selector: authz.Selector{ diff --git a/server/internal/mcp/rpc_tools_list_test.go b/server/internal/mcp/rpc_tools_list_test.go index f384ece810..ff6b4f4432 100644 --- a/server/internal/mcp/rpc_tools_list_test.go +++ b/server/internal/mcp/rpc_tools_list_test.go @@ -13,7 +13,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" deployments_repo "github.com/speakeasy-api/gram/server/internal/deployments/repo" "github.com/speakeasy-api/gram/server/internal/oops" @@ -211,7 +210,7 @@ func TestServePublic_RBAC_ToolsList_FiltersToGrantedTools(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Grant mcp:connect only for "allowed_tool". ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ @@ -248,7 +247,7 @@ func TestServePublic_RBAC_ToolsList_ServerLevelGrantReturnsAll(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Server-level grant (no tool dimension) — all tools allowed. ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ @@ -278,7 +277,7 @@ func TestServePublic_RBAC_ToolsList_NoGrantsDenied(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // mcp:connect grant for a DIFFERENT toolset — should not match. ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ @@ -303,7 +302,7 @@ func TestServePublic_RBAC_ToolsList_MultipleToolGrants(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Grant access to tool_a and tool_c but not tool_b. ctx = authztest.WithExactGrants(t, ctx, @@ -348,7 +347,7 @@ func TestServePublic_RBAC_ToolsList_DispositionGrant_AllowsMatchingDisposition(t chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Grant mcp:connect scoped to read_only disposition only. ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ @@ -386,7 +385,7 @@ func TestServePublic_RBAC_ToolsList_DisabledRBACAllowsAll(t *testing.T) { // Engine with RBAC disabled — simulates org without RBAC feature flag. chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // No grants in context at all. With RBAC disabled, every tool should pass. for _, tool := range []string{"tool_one", "tool_two", "tool_three"} { @@ -406,7 +405,7 @@ func TestServePublic_RBAC_ToolsList_DispositionGrant_ServerLevelAllowsAll(t *tes chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Server-level grant (no disposition key) — all dispositions allowed. ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ diff --git a/server/internal/mcp/setup_test.go b/server/internal/mcp/setup_test.go index bb0b9733c8..e7d856f7f2 100644 --- a/server/internal/mcp/setup_test.go +++ b/server/internal/mcp/setup_test.go @@ -131,7 +131,7 @@ func newTestMCPService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) telemLogger := telemetry.NewLogger(ctx, logger, chConn, logsEnabled, toolIOLogsEnabled) telemService := telemetry.NewService( diff --git a/server/internal/mcpendpoints/setup_test.go b/server/internal/mcpendpoints/setup_test.go index 7b0a378e0a..1a024e8754 100644 --- a/server/internal/mcpendpoints/setup_test.go +++ b/server/internal/mcpendpoints/setup_test.go @@ -83,7 +83,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := mcpendpoints.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), auditLogger) + svc := mcpendpoints.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), auditLogger) return ctx, &testInstance{ service: svc, diff --git a/server/internal/mcpmetadata/setup_test.go b/server/internal/mcpmetadata/setup_test.go index d08360f57d..54f730b67b 100644 --- a/server/internal/mcpmetadata/setup_test.go +++ b/server/internal/mcpmetadata/setup_test.go @@ -93,7 +93,7 @@ func newTestMCPMetadataService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := mcpmetadata.NewService(logger, tracerProvider, conn, sessionManager, serverURL, siteURL, cacheAdapter, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), auditLogger) + svc := mcpmetadata.NewService(logger, tracerProvider, conn, sessionManager, serverURL, siteURL, cacheAdapter, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), auditLogger) return ctx, &testInstance{ service: svc, diff --git a/server/internal/mcpservers/setup_test.go b/server/internal/mcpservers/setup_test.go index 6ca66573de..18e9f72b2d 100644 --- a/server/internal/mcpservers/setup_test.go +++ b/server/internal/mcpservers/setup_test.go @@ -81,7 +81,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := mcpservers.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), auditLogger) + svc := mcpservers.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), auditLogger) return ctx, &testInstance{ service: svc, diff --git a/server/internal/organizations/setup_test.go b/server/internal/organizations/setup_test.go index 6b7836a9a9..cc5ebc85a6 100644 --- a/server/internal/organizations/setup_test.go +++ b/server/internal/organizations/setup_test.go @@ -106,7 +106,7 @@ func newTestOrganizationsService(t *testing.T) (context.Context, *testInstance) chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, thirdpartyworkos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, thirdpartyworkos.NewStubClient()) svc := organizations.NewService(logger, tracerProvider, conn, sessionManager, orgs, stubOrgFeatures{}, authzEngine) return ctx, &testInstance{ @@ -155,7 +155,7 @@ func newTestOrganizationsServiceRBAC(t *testing.T) (context.Context, *testInstan chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, thirdpartyworkos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, thirdpartyworkos.NewStubClient()) svc := organizations.NewService(logger, tracerProvider, conn, sessionManager, orgs, stubOrgFeaturesEnabled{}, authzEngine) return ctx, &testInstance{ @@ -168,11 +168,11 @@ func newTestOrganizationsServiceRBAC(t *testing.T) (context.Context, *testInstan // expectWorkOSOrgAdminRole stubs a successful WorkOS admin membership check for the session user. func expectWorkOSOrgAdminRole(t *testing.T, orgs *MockOrganizationProvider) { t.Helper() - orgs.On("GetOrgMembership", mock.Anything, testAuthUserWorkOSID, mockidp.MockOrgID).Return(&thirdpartyworkos.Member{RoleSlug: "admin"}, nil).Once() + orgs.On("GetOrgMembership", mock.Anything, testAuthUserWorkOSID, mockidp.MockOrgID).Return(&thirdpartyworkos.Member{RoleSlug: "admin"}).Once() } // expectWorkOSOrgNonAdminRole stubs WorkOS membership with a non-admin role. func expectWorkOSOrgNonAdminRole(t *testing.T, orgs *MockOrganizationProvider) { t.Helper() - orgs.On("GetOrgMembership", mock.Anything, testAuthUserWorkOSID, mockidp.MockOrgID).Return(&thirdpartyworkos.Member{RoleSlug: "member"}, nil).Once() + orgs.On("GetOrgMembership", mock.Anything, testAuthUserWorkOSID, mockidp.MockOrgID).Return(&thirdpartyworkos.Member{RoleSlug: "member"}).Once() } diff --git a/server/internal/packages/setup_test.go b/server/internal/packages/setup_test.go index 39e4663681..19f8db4708 100644 --- a/server/internal/packages/setup_test.go +++ b/server/internal/packages/setup_test.go @@ -74,7 +74,7 @@ func newTestPackagesService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := packages.NewService(logger, tracerProvider, conn, sessionManager, authzEngine) return ctx, &testInstance{ diff --git a/server/internal/plugins/setup_test.go b/server/internal/plugins/setup_test.go index 6804a41e59..cedbffb4c2 100644 --- a/server/internal/plugins/setup_test.go +++ b/server/internal/plugins/setup_test.go @@ -97,7 +97,7 @@ func newTestPluginsService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := plugins.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), auditLogger, nil, "local", "https://app.getgram.ai") + svc := plugins.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), auditLogger, nil, "local", "https://app.getgram.ai") return ctx, &testInstance{ service: svc, @@ -154,7 +154,7 @@ func newTestPluginsServiceWithGitHub(t *testing.T, ghClient plugins.GitHubPublis tracerProvider, conn, sessionManager, - authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), + authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), auditLogger, ghConfig, "local", diff --git a/server/internal/productfeatures/setup_test.go b/server/internal/productfeatures/setup_test.go index 76a5ede88b..9be41ff232 100644 --- a/server/internal/productfeatures/setup_test.go +++ b/server/internal/productfeatures/setup_test.go @@ -90,7 +90,7 @@ func newTestProductFeaturesService(t *testing.T) (context.Context, *testInstance chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := productfeatures.NewService(logger, tracerProvider, conn, sessionManager, redisClient, authzEngine) return ctx, &testInstance{ diff --git a/server/internal/projects/setup_test.go b/server/internal/projects/setup_test.go index 9d127dceae..a029e24ff5 100644 --- a/server/internal/projects/setup_test.go +++ b/server/internal/projects/setup_test.go @@ -110,7 +110,6 @@ func newTestProjectsService(t *testing.T, enableRBAC bool) (context.Context, *te }, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), - cache.NoopCache, ), auditLogger, ) diff --git a/server/internal/remotemcp/setup_test.go b/server/internal/remotemcp/setup_test.go index aaa848d9fc..fec5f1a540 100644 --- a/server/internal/remotemcp/setup_test.go +++ b/server/internal/remotemcp/setup_test.go @@ -107,7 +107,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := remotemcp.NewService(logger, tracerProvider, conn, sessionManager, enc, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), servicePolicy, auditLogger) + svc := remotemcp.NewService(logger, tracerProvider, conn, sessionManager, enc, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), servicePolicy, auditLogger) return ctx, &testInstance{ service: svc, diff --git a/server/internal/resources/setup_test.go b/server/internal/resources/setup_test.go index e3889b4ce3..cf2dac4236 100644 --- a/server/internal/resources/setup_test.go +++ b/server/internal/resources/setup_test.go @@ -73,7 +73,7 @@ func newTestResourcesService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := resources.NewService(logger, tracerProvider, conn, sessionManager, authzEngine) return ctx, &testInstance{ diff --git a/server/internal/risk/setup_test.go b/server/internal/risk/setup_test.go index fbb40faca4..c0ce42dd1a 100644 --- a/server/internal/risk/setup_test.go +++ b/server/internal/risk/setup_test.go @@ -93,7 +93,7 @@ func newTestRiskService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) shadowMCPClient := shadowmcp.NewClient(logger, conn, cache.NewRedisCacheAdapter(redisClient)) auditLogger := audit.NewLogger() diff --git a/server/internal/telemetry/setup_test.go b/server/internal/telemetry/setup_test.go index 8b902b4963..7a8b90ee42 100644 --- a/server/internal/telemetry/setup_test.go +++ b/server/internal/telemetry/setup_test.go @@ -118,7 +118,7 @@ func newTestLogsService(t *testing.T) (context.Context, *testInstance) { posthogClient := posthog.New(ctx, logger, "test-posthog-key", "test-posthog-host", "") telemLogger := telemetry.NewLogger(ctx, logger, chConn, logsEnabled, toolIOLogsEnabled) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := telemetry.NewService(logger, tracerProvider, conn, chConn, sessionManager, chatSessionsManager, logsEnabled, sessionCaptureEnabled, posthogClient, authzEngine) return ctx, &testInstance{ diff --git a/server/internal/templates/setup_test.go b/server/internal/templates/setup_test.go index 5c4fcbeecd..0801bd55b0 100644 --- a/server/internal/templates/setup_test.go +++ b/server/internal/templates/setup_test.go @@ -84,7 +84,7 @@ func newTestTemplateService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) auditLogger := audit.NewLogger() svc := templates.NewService(logger, tracerProvider, conn, sessionManager, toolsetsSvc, authzEngine, auditLogger) diff --git a/server/internal/tools/setup_test.go b/server/internal/tools/setup_test.go index dbf1b66c40..d39ab04911 100644 --- a/server/internal/tools/setup_test.go +++ b/server/internal/tools/setup_test.go @@ -115,7 +115,7 @@ func newTestToolsService(t *testing.T, assetStorage assets.BlobStore) (context.C chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) toolsSvc := tools.NewService(logger, tracerProvider, conn, sessionManager, authzEngine, nil, nil) deploymentsSvc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, posthog, testenv.DefaultSiteURL(t), mcpRegistryClient, authzEngine, auditLogger) assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret", authzEngine, auditLogger) diff --git a/server/internal/toolsets/setup_test.go b/server/internal/toolsets/setup_test.go index 3372aa41c2..81c7ca16b7 100644 --- a/server/internal/toolsets/setup_test.go +++ b/server/internal/toolsets/setup_test.go @@ -124,7 +124,7 @@ func newTestToolsetsService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := toolsets.NewService(logger, tracerProvider, conn, sessionManager, nil, authzEngine, auditLogger) deploymentsSvc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, posthog, testenv.DefaultSiteURL(t), mcpRegistryClient, authzEngine, auditLogger) assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret", authzEngine, auditLogger) diff --git a/server/internal/triggers/setup_test.go b/server/internal/triggers/setup_test.go index ca543445b6..837a6f092e 100644 --- a/server/internal/triggers/setup_test.go +++ b/server/internal/triggers/setup_test.go @@ -121,7 +121,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { tracerProvider, conn, sessionManager, - authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), + authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), app, auditLogger, ) diff --git a/server/internal/usage/impl_test.go b/server/internal/usage/impl_test.go index 31b71aeba2..231aa7f1e3 100644 --- a/server/internal/usage/impl_test.go +++ b/server/internal/usage/impl_test.go @@ -17,7 +17,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" "github.com/speakeasy-api/gram/server/internal/billing" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/testenv" @@ -144,7 +143,7 @@ func newTestService(t *testing.T, billingRepo billing.Repository, orgID string, chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, db, chConn, rbacDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, db, chConn, rbacDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) return &Service{ tracer: tp.Tracer("test"), @@ -204,7 +203,7 @@ func TestGetPeriodUsage_CacheHit(t *testing.T) { cached := sampleUsage(42, 2, 10) billingMock := &mockBillingRepo{} - billingMock.On("GetStoredPeriodUsage", mock.Anything, orgID).Return(cached, nil) + billingMock.On("GetStoredPeriodUsage", mock.Anything, orgID).Return(cached) // GetPeriodUsage should NOT be called svc := newTestService(t, billingMock, orgID, 5) @@ -226,7 +225,7 @@ func TestGetPeriodUsage_CacheMissFallback(t *testing.T) { billingMock := &mockBillingRepo{} billingMock.On("GetStoredPeriodUsage", mock.Anything, orgID).Return(nil, fmt.Errorf("cache miss")) - billingMock.On("GetPeriodUsage", mock.Anything, orgID).Return(fresh, nil) + billingMock.On("GetPeriodUsage", mock.Anything, orgID).Return(fresh) svc := newTestService(t, billingMock, orgID, 3) ctx := testAuthContext(orgID) @@ -280,7 +279,7 @@ func TestGetPeriodUsage_ActualServerCountFromDB(t *testing.T) { cached.ActualEnabledServerCount = 999 // cached value should be overridden billingMock := &mockBillingRepo{} - billingMock.On("GetStoredPeriodUsage", mock.Anything, orgID).Return(cached, nil) + billingMock.On("GetStoredPeriodUsage", mock.Anything, orgID).Return(cached) svc := newTestService(t, billingMock, orgID, 7) // DB says 7 ctx := testAuthContext(orgID) diff --git a/server/internal/usersessions/setup_test.go b/server/internal/usersessions/setup_test.go index 43a1d8edf3..aeb09cde7d 100644 --- a/server/internal/usersessions/setup_test.go +++ b/server/internal/usersessions/setup_test.go @@ -92,7 +92,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { conn, sessionManager, chatSessionsManager, - authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), + authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), audit.NewLogger(), ) diff --git a/server/internal/variations/setup_test.go b/server/internal/variations/setup_test.go index 72f9fb81a9..9bcb0ede21 100644 --- a/server/internal/variations/setup_test.go +++ b/server/internal/variations/setup_test.go @@ -74,7 +74,7 @@ func newTestVariationsService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) auditLogger := audit.NewLogger() svc := variations.NewService(logger, tracerProvider, conn, sessionManager, authzEngine, auditLogger) diff --git a/server/internal/xmcp/setup_test.go b/server/internal/xmcp/setup_test.go index 087ad44f0b..e02abbdc3f 100644 --- a/server/internal/xmcp/setup_test.go +++ b/server/internal/xmcp/setup_test.go @@ -116,7 +116,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { enc := testenv.NewEncryptionClient(t) chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) mcpMetadataRepo := mcpmetadatarepo.New(conn) env := environments.NewEnvironmentEntries(logger, conn, enc, mcpMetadataRepo) @@ -321,7 +321,7 @@ func seedOAuthProxyServer(t *testing.T, ctx context.Context, ti *testInstance, p // seedRemoteMCPEndpointWithOAuthProxy wires up a remote-backed mcp_server // configured for the OAuth-proxy token-swap flow. The proxy resolution // is currently stubbed in mcp.Service.ResolveOAuthProxyUpstreamToken -// (returns "", nil), so this seeding is enough to drive the auth-switch +// (returns ""), so this seeding is enough to drive the auth-switch // branch in xmcp; once the resolver is implemented it will exercise the // full token-swap path. func seedRemoteMCPEndpointWithOAuthProxy(t *testing.T, ctx context.Context, ti *testInstance, projectID uuid.UUID, upstreamURL string) (slug string) { diff --git a/server/internal/xmcp/tools_call_authz_interceptor_test.go b/server/internal/xmcp/tools_call_authz_interceptor_test.go index 1944adebfb..836608b821 100644 --- a/server/internal/xmcp/tools_call_authz_interceptor_test.go +++ b/server/internal/xmcp/tools_call_authz_interceptor_test.go @@ -8,7 +8,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/remotemcp/proxy" @@ -24,7 +23,7 @@ const ( func newAuthzEngineForTest(t *testing.T) *authz.Engine { t.Helper() - return authz.NewEngine(testenv.NewLogger(t), nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + return authz.NewEngine(testenv.NewLogger(t), nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) } func authzAuthContext(t *testing.T) *contextvalues.AuthContext { From 0d050cbcccaa230cb153a48ac5b8221153f76dc8 Mon Sep 17 00:00:00 2001 From: tgmendes Date: Fri, 15 May 2026 10:52:23 +0100 Subject: [PATCH 03/12] chore: code improvements plus fetch total members when listing members --- server/internal/access/deleterole_test.go | 10 + server/internal/access/listmembers_test.go | 10 +- server/internal/access/queries.sql | 122 ++++++++---- server/internal/access/repo/queries.sql.go | 204 +++++++++++++------- server/internal/access/role_manager.go | 145 ++++++-------- server/internal/access/role_manager_test.go | 28 ++- 6 files changed, 313 insertions(+), 206 deletions(-) diff --git a/server/internal/access/deleterole_test.go b/server/internal/access/deleterole_test.go index b534987530..ebc6e77221 100644 --- a/server/internal/access/deleterole_test.go +++ b/server/internal/access/deleterole_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" gen "github.com/speakeasy-api/gram/server/gen/access" + accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/audit" "github.com/speakeasy-api/gram/server/internal/audit/audittest" "github.com/speakeasy-api/gram/server/internal/authz" @@ -35,6 +36,15 @@ func TestService_DeleteRole(t *testing.T) { grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) require.Empty(t, grants) + + role, err := accessrepo.New(ti.conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosSlug: "custom-builder", + }) + require.NoError(t, err) + require.True(t, role.Deleted) + require.False(t, role.WorkosDeleted) + require.False(t, role.WorkosDeletedAt.Valid) } func TestService_DeleteRole_ReassignsMembersToDefault(t *testing.T) { diff --git a/server/internal/access/listmembers_test.go b/server/internal/access/listmembers_test.go index 619c32c610..1771325b46 100644 --- a/server/internal/access/listmembers_test.go +++ b/server/internal/access/listmembers_test.go @@ -43,7 +43,7 @@ func TestService_ListMembers(t *testing.T) { require.Equal(t, builderID, byID["local_user_2"].RoleID) } -func TestService_ListMembers_IncludesWorkOSOnlyUsers(t *testing.T) { +func TestService_ListMembers_SkipsMembersWithoutLocalUser(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -61,15 +61,14 @@ func TestService_ListMembers_IncludesWorkOSOnlyUsers(t *testing.T) { result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Len(t, result.Members, 2) + require.Len(t, result.Members, 1) byID := map[string]*gen.AccessMember{} for _, member := range result.Members { byID[member.ID] = member } require.Equal(t, "Ada Lovelace", byID["local_user_1"].Name) - require.Equal(t, "user_2", byID["user_2"].ID) - require.Equal(t, "user_2", byID["user_2"].Name) + require.Nil(t, byID["user_2"]) } func TestService_ListMembers_UsesDatabaseOnly(t *testing.T) { @@ -82,6 +81,5 @@ func TestService_ListMembers_UsesDatabaseOnly(t *testing.T) { result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Len(t, result.Members, 1) - require.Equal(t, "user_1", result.Members[0].ID) + require.Empty(t, result.Members) } diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index 0a564119f7..1418002ac8 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -149,32 +149,93 @@ WHERE organization_id = @organization_id AND workos_slug = @workos_slug AND deleted_at IS NULL; +-- name: MarkOrganizationRoleDeletedLocally :execrows +UPDATE organization_roles +SET workos_last_event_id = @workos_last_event_id, + deleted_at = clock_timestamp(), + updated_at = clock_timestamp() +WHERE organization_id = @organization_id + AND workos_slug = @workos_slug + AND deleted_at IS NULL; + -- name: ListActiveOrganizationRoles :many -SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at -FROM global_roles -WHERE deleted IS FALSE - AND workos_deleted IS FALSE -UNION ALL -SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE deleted IS FALSE + AND workos_deleted IS FALSE + UNION ALL + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = @organization_id + AND deleted IS FALSE + AND workos_deleted IS FALSE +) +SELECT + active_roles.id, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, + COUNT(ora.id)::bigint AS member_count +FROM active_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = @organization_id + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text + AND ora.user_id IS NOT NULL +GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +ORDER BY active_roles.workos_slug; + +-- name: GetActiveOrganizationRoleBySlug :one +SELECT + organization_roles.id, + organization_roles.workos_slug, + organization_roles.workos_name, + organization_roles.workos_description, + organization_roles.workos_created_at, + organization_roles.workos_updated_at, + COUNT(ora.id)::bigint AS member_count FROM organization_roles -WHERE organization_id = @organization_id - AND deleted IS FALSE - AND workos_deleted IS FALSE -ORDER BY workos_slug; +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = organization_roles.organization_id + AND ora.role_urn = 'role:organization:' || organization_roles.id::text + AND ora.user_id IS NOT NULL +WHERE organization_roles.organization_id = @organization_id + AND organization_roles.workos_slug = @workos_slug + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +GROUP BY organization_roles.id, organization_roles.workos_slug, organization_roles.workos_name, organization_roles.workos_description, organization_roles.workos_created_at, organization_roles.workos_updated_at; -- name: GetOrganizationRoleByID :one -SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at -FROM global_roles -WHERE global_roles.id = sqlc.arg(id) - AND deleted IS FALSE - AND workos_deleted IS FALSE +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.id = sqlc.arg(id) + AND deleted IS FALSE + AND workos_deleted IS FALSE UNION ALL -SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at -FROM organization_roles -WHERE organization_id = @organization_id - AND organization_roles.id = sqlc.arg(id) - AND deleted IS FALSE - AND workos_deleted IS FALSE + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = @organization_id + AND organization_roles.id = sqlc.arg(id) + AND deleted IS FALSE + AND workos_deleted IS FALSE +) +SELECT + active_roles.id, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, + COUNT(ora.id)::bigint AS member_count +FROM active_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = @organization_id + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text + AND ora.user_id IS NOT NULL +GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at LIMIT 1; -- name: ListOrganizationRoleAssignmentsForOrg :many @@ -215,25 +276,6 @@ WHERE ora.organization_id = @organization_id AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL ORDER BY role_slug; --- name: CountMembersByRoleForOrg :many -SELECT - COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, - COUNT(*)::bigint AS member_count -FROM organization_role_assignments AS ora -LEFT JOIN organization_roles - ON ora.role_urn = 'role:organization:' || organization_roles.id::text - AND organization_roles.organization_id = ora.organization_id - AND organization_roles.deleted IS FALSE - AND organization_roles.workos_deleted IS FALSE -LEFT JOIN global_roles - ON ora.role_urn = 'role:global:' || global_roles.id::text - AND global_roles.deleted IS FALSE - AND global_roles.workos_deleted IS FALSE -WHERE ora.organization_id = @organization_id - AND ora.user_id IS NOT NULL - AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL -GROUP BY role_slug; - -- name: ReplaceOrganizationRoleAssignment :one WITH input_role_urn AS ( SELECT 'role:organization:' || id::text AS role_urn diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index ac4a54faf4..d2fa1f9fbe 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -13,51 +13,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/urn" ) -const countMembersByRoleForOrg = `-- name: CountMembersByRoleForOrg :many -SELECT - COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, - COUNT(*)::bigint AS member_count -FROM organization_role_assignments AS ora -LEFT JOIN organization_roles - ON ora.role_urn = 'role:organization:' || organization_roles.id::text - AND organization_roles.organization_id = ora.organization_id - AND organization_roles.deleted IS FALSE - AND organization_roles.workos_deleted IS FALSE -LEFT JOIN global_roles - ON ora.role_urn = 'role:global:' || global_roles.id::text - AND global_roles.deleted IS FALSE - AND global_roles.workos_deleted IS FALSE -WHERE ora.organization_id = $1 - AND ora.user_id IS NOT NULL - AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL -GROUP BY role_slug -` - -type CountMembersByRoleForOrgRow struct { - RoleSlug string - MemberCount int64 -} - -func (q *Queries) CountMembersByRoleForOrg(ctx context.Context, organizationID string) ([]CountMembersByRoleForOrgRow, error) { - rows, err := q.db.Query(ctx, countMembersByRoleForOrg, organizationID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []CountMembersByRoleForOrgRow - for rows.Next() { - var i CountMembersByRoleForOrgRow - if err := rows.Scan(&i.RoleSlug, &i.MemberCount); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const deletePrincipalGrant = `-- name: DeletePrincipalGrant :execrows DELETE FROM principal_grants WHERE id = $1 @@ -99,6 +54,57 @@ func (q *Queries) DeletePrincipalGrantsByPrincipal(ctx context.Context, arg Dele return result.RowsAffected(), nil } +const getActiveOrganizationRoleBySlug = `-- name: GetActiveOrganizationRoleBySlug :one +SELECT + organization_roles.id, + organization_roles.workos_slug, + organization_roles.workos_name, + organization_roles.workos_description, + organization_roles.workos_created_at, + organization_roles.workos_updated_at, + COUNT(ora.id)::bigint AS member_count +FROM organization_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = organization_roles.organization_id + AND ora.role_urn = 'role:organization:' || organization_roles.id::text + AND ora.user_id IS NOT NULL +WHERE organization_roles.organization_id = $1 + AND organization_roles.workos_slug = $2 + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +GROUP BY organization_roles.id, organization_roles.workos_slug, organization_roles.workos_name, organization_roles.workos_description, organization_roles.workos_created_at, organization_roles.workos_updated_at +` + +type GetActiveOrganizationRoleBySlugParams struct { + OrganizationID string + WorkosSlug string +} + +type GetActiveOrganizationRoleBySlugRow struct { + ID uuid.UUID + WorkosSlug string + WorkosName string + WorkosDescription pgtype.Text + WorkosCreatedAt pgtype.Timestamptz + WorkosUpdatedAt pgtype.Timestamptz + MemberCount int64 +} + +func (q *Queries) GetActiveOrganizationRoleBySlug(ctx context.Context, arg GetActiveOrganizationRoleBySlugParams) (GetActiveOrganizationRoleBySlugRow, error) { + row := q.db.QueryRow(ctx, getActiveOrganizationRoleBySlug, arg.OrganizationID, arg.WorkosSlug) + var i GetActiveOrganizationRoleBySlugRow + err := row.Scan( + &i.ID, + &i.WorkosSlug, + &i.WorkosName, + &i.WorkosDescription, + &i.WorkosCreatedAt, + &i.WorkosUpdatedAt, + &i.MemberCount, + ) + return i, err +} + const getGlobalRoleBySlug = `-- name: GetGlobalRoleBySlug :one SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, workos_deleted_at, workos_deleted, workos_last_event_id, created_at, updated_at, deleted_at, deleted FROM global_roles @@ -127,24 +133,40 @@ func (q *Queries) GetGlobalRoleBySlug(ctx context.Context, workosSlug string) (G } const getOrganizationRoleByID = `-- name: GetOrganizationRoleByID :one -SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at -FROM global_roles -WHERE global_roles.id = $1 - AND deleted IS FALSE - AND workos_deleted IS FALSE +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.id = $2 + AND deleted IS FALSE + AND workos_deleted IS FALSE UNION ALL -SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at -FROM organization_roles -WHERE organization_id = $2 - AND organization_roles.id = $1 - AND deleted IS FALSE - AND workos_deleted IS FALSE + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = $1 + AND organization_roles.id = $2 + AND deleted IS FALSE + AND workos_deleted IS FALSE +) +SELECT + active_roles.id, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, + COUNT(ora.id)::bigint AS member_count +FROM active_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = $1 + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text + AND ora.user_id IS NOT NULL +GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at LIMIT 1 ` type GetOrganizationRoleByIDParams struct { - ID uuid.UUID OrganizationID string + ID uuid.UUID } type GetOrganizationRoleByIDRow struct { @@ -154,10 +176,11 @@ type GetOrganizationRoleByIDRow struct { WorkosDescription pgtype.Text WorkosCreatedAt pgtype.Timestamptz WorkosUpdatedAt pgtype.Timestamptz + MemberCount int64 } func (q *Queries) GetOrganizationRoleByID(ctx context.Context, arg GetOrganizationRoleByIDParams) (GetOrganizationRoleByIDRow, error) { - row := q.db.QueryRow(ctx, getOrganizationRoleByID, arg.ID, arg.OrganizationID) + row := q.db.QueryRow(ctx, getOrganizationRoleByID, arg.OrganizationID, arg.ID) var i GetOrganizationRoleByIDRow err := row.Scan( &i.ID, @@ -166,6 +189,7 @@ func (q *Queries) GetOrganizationRoleByID(ctx context.Context, arg GetOrganizati &i.WorkosDescription, &i.WorkosCreatedAt, &i.WorkosUpdatedAt, + &i.MemberCount, ) return i, err } @@ -313,17 +337,33 @@ func (q *Queries) InsertChallengeResolutions(ctx context.Context, arg InsertChal } const listActiveOrganizationRoles = `-- name: ListActiveOrganizationRoles :many -SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at -FROM global_roles -WHERE deleted IS FALSE - AND workos_deleted IS FALSE -UNION ALL -SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at -FROM organization_roles -WHERE organization_id = $1 - AND deleted IS FALSE - AND workos_deleted IS FALSE -ORDER BY workos_slug +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE deleted IS FALSE + AND workos_deleted IS FALSE + UNION ALL + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = $1 + AND deleted IS FALSE + AND workos_deleted IS FALSE +) +SELECT + active_roles.id, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, + COUNT(ora.id)::bigint AS member_count +FROM active_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = $1 + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text + AND ora.user_id IS NOT NULL +GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +ORDER BY active_roles.workos_slug ` type ListActiveOrganizationRolesRow struct { @@ -333,6 +373,7 @@ type ListActiveOrganizationRolesRow struct { WorkosDescription pgtype.Text WorkosCreatedAt pgtype.Timestamptz WorkosUpdatedAt pgtype.Timestamptz + MemberCount int64 } func (q *Queries) ListActiveOrganizationRoles(ctx context.Context, organizationID string) ([]ListActiveOrganizationRolesRow, error) { @@ -351,6 +392,7 @@ func (q *Queries) ListActiveOrganizationRoles(ctx context.Context, organizationI &i.WorkosDescription, &i.WorkosCreatedAt, &i.WorkosUpdatedAt, + &i.MemberCount, ); err != nil { return nil, err } @@ -620,6 +662,30 @@ func (q *Queries) MarkOrganizationRoleDeleted(ctx context.Context, arg MarkOrgan return result.RowsAffected(), nil } +const markOrganizationRoleDeletedLocally = `-- name: MarkOrganizationRoleDeletedLocally :execrows +UPDATE organization_roles +SET workos_last_event_id = $1, + deleted_at = clock_timestamp(), + updated_at = clock_timestamp() +WHERE organization_id = $2 + AND workos_slug = $3 + AND deleted_at IS NULL +` + +type MarkOrganizationRoleDeletedLocallyParams struct { + WorkosLastEventID pgtype.Text + OrganizationID string + WorkosSlug string +} + +func (q *Queries) MarkOrganizationRoleDeletedLocally(ctx context.Context, arg MarkOrganizationRoleDeletedLocallyParams) (int64, error) { + result, err := q.db.Exec(ctx, markOrganizationRoleDeletedLocally, arg.WorkosLastEventID, arg.OrganizationID, arg.WorkosSlug) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} + const replaceOrganizationRoleAssignment = `-- name: ReplaceOrganizationRoleAssignment :one WITH input_role_urn AS ( SELECT 'role:organization:' || id::text AS role_urn diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index 3e7b51eebb..ed3853b7b4 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -64,15 +64,9 @@ func (r *RoleManager) ListRoles(ctx context.Context, gramOrgID string) (*gen.Lis return nil, oops.E(oops.CodeUnexpected, err, "list roles").Log(ctx, r.logger) } - memberCounts, err := r.memberCounts(ctx, gramOrgID) - if err != nil { - return nil, err - } - - trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) roles := make([]*gen.Role, 0, len(rows)) for _, row := range rows { - role, err := r.roleViewFromLocalRole(ctx, gramOrgID, localRoleFromActiveRow(row), memberCounts[row.WorkosSlug]) + role, err := r.roleViewFromLocalRole(ctx, gramOrgID, localRoleFromActiveRow(row)) if err != nil { return nil, err } @@ -89,12 +83,7 @@ func (r *RoleManager) GetRoleByID(ctx context.Context, gramOrgID, id string) (*g return nil, err } - memberCounts, err := r.memberCounts(ctx, gramOrgID) - if err != nil { - return nil, err - } - - return r.roleViewFromLocalRole(ctx, gramOrgID, role, memberCounts[role.Slug]) + return r.roleViewFromLocalRole(ctx, gramOrgID, role) } type localRoleAssignment struct { @@ -111,7 +100,7 @@ func (r *RoleManager) ListMembers(ctx context.Context, gramOrgID string) (*gen.L if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "list roles").Log(ctx, r.logger) } - trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + roles := make(map[string]string, len(roleRows)) for _, row := range roleRows { roles[row.WorkosSlug] = row.ID.String() @@ -121,7 +110,7 @@ func (r *RoleManager) ListMembers(ctx context.Context, gramOrgID string) (*gen.L if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) } - trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + assignments := make([]localRoleAssignment, 0, len(assignmentRows)) for _, row := range assignmentRows { assignments = append(assignments, localRoleAssignment{ @@ -152,14 +141,6 @@ func (r *RoleManager) ListMembers(ctx context.Context, gramOrgID string) (*gen.L for _, assignment := range assignments { user, ok := localUsers[assignment.UserID] if !ok { - result = append(result, &gen.AccessMember{ - ID: assignment.WorkosUserID, - Name: assignment.WorkosUserID, - Email: "", - PhotoURL: nil, - RoleID: roles[assignment.RoleSlug], - JoinedAt: assignment.CreatedAt, - }) continue } @@ -232,13 +213,12 @@ func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID str if _, err := r.assignMembersToRole(ctx, gramOrgID, roleSlug, payload.MemberIds); err != nil { return roleCreateResult{}, err } + createdRole, err = r.getLocalRoleBySlug(ctx, gramOrgID, roleSlug) + if err != nil { + return roleCreateResult{}, err + } } - - memberCounts, err := r.memberCounts(ctx, gramOrgID) - if err != nil { - return roleCreateResult{}, err - } - role, err := r.roleViewFromLocalRole(ctx, gramOrgID, createdRole, memberCounts[createdRole.Slug]) + role, err := r.roleViewFromLocalRole(ctx, gramOrgID, createdRole) if err != nil { return roleCreateResult{}, err } @@ -253,6 +233,7 @@ type localRole struct { Description string CreatedAt string UpdatedAt string + MemberCount int } type roleUpdateResult struct { @@ -268,11 +249,7 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str if err != nil { return roleUpdateResult{}, err } - memberCountsBefore, err := r.memberCounts(ctx, gramOrgID) - if err != nil { - return roleUpdateResult{}, err - } - existingRole, err := r.roleViewFromLocalRole(ctx, gramOrgID, currentRole, memberCountsBefore[currentRole.Slug]) + existingRole, err := r.roleViewFromLocalRole(ctx, gramOrgID, currentRole) if err != nil { return roleUpdateResult{}, err } @@ -338,13 +315,13 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str if _, err := r.assignMembersToRole(ctx, gramOrgID, currentRole.Slug, payload.MemberIds); err != nil { return roleUpdateResult{}, err } + updatedRole, err = r.getLocalRoleByID(ctx, gramOrgID, payload.ID) + if err != nil { + return roleUpdateResult{}, err + } } - memberCounts, err := r.memberCounts(ctx, gramOrgID) - if err != nil { - return roleUpdateResult{}, err - } - updatedRoleView, err := r.roleViewFromLocalRole(ctx, gramOrgID, updatedRole, memberCounts[updatedRole.Slug]) + updatedRoleView, err := r.roleViewFromLocalRole(ctx, gramOrgID, updatedRole) if err != nil { return roleUpdateResult{}, err } @@ -397,10 +374,9 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro }) } - if _, err := repo.New(r.db).MarkOrganizationRoleDeleted(ctx, repo.MarkOrganizationRoleDeletedParams{ + if _, err := repo.New(r.db).MarkOrganizationRoleDeletedLocally(ctx, repo.MarkOrganizationRoleDeletedLocallyParams{ OrganizationID: gramOrgID, WorkosSlug: currentRole.Slug, - WorkosDeletedAt: conv.ToPGTimestamptz(time.Now().UTC()), WorkosLastEventID: conv.ToPGTextEmpty(""), }); err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) @@ -575,7 +551,7 @@ func (r *RoleManager) getLocalRoleByID(ctx context.Context, gramOrgID, id string // getLocalRoleBySlug loads one local organization role record by WorkOS slug. // Side effects: reads Postgres local role records; does not call WorkOS. func (r *RoleManager) getLocalRoleBySlug(ctx context.Context, gramOrgID, slug string) (localRole, error) { - row, err := repo.New(r.db).GetOrganizationRoleBySlug(ctx, repo.GetOrganizationRoleBySlugParams{ + row, err := repo.New(r.db).GetActiveOrganizationRoleBySlug(ctx, repo.GetActiveOrganizationRoleBySlugParams{ OrganizationID: gramOrgID, WorkosSlug: slug, }) @@ -584,12 +560,10 @@ func (r *RoleManager) getLocalRoleBySlug(ctx context.Context, gramOrgID, slug st return localRole{}, oops.E(oops.CodeNotFound, ErrRoleNotFound, "role not found").Log(ctx, r.logger) case err != nil: return localRole{}, oops.E(oops.CodeUnexpected, err, "get role").Log(ctx, r.logger) - case row.Deleted || row.WorkosDeleted: - return localRole{}, oops.E(oops.CodeNotFound, ErrRoleNotFound, "role not found").Log(ctx, r.logger) } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) - return localRoleFromOrganizationRole(row), nil + return localRoleFromSlugRow(row), nil } type memberAssignmentTarget struct { @@ -613,9 +587,13 @@ func (r *RoleManager) memberAssignmentTargets(ctx context.Context, gramOrgID str return nil, oops.E(oops.CodeUnexpected, err, "resolve users by ids").Log(ctx, r.logger) } workosByGramID := make(map[string]string, len(users)) + requestedByWorkosID := make(map[string]string, len(users)) for _, user := range users { if user.WorkosID.Valid && user.WorkosID.String != "" { workosByGramID[user.ID] = user.WorkosID.String + if _, ok := requested[user.ID]; ok { + requestedByWorkosID[user.WorkosID.String] = user.ID + } } } @@ -624,37 +602,43 @@ func (r *RoleManager) memberAssignmentTargets(ctx context.Context, gramOrgID str return nil, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) - membershipByWorkosID := make(map[string]string, len(assignmentRows)) - for _, row := range assignmentRows { - membershipByWorkosID[row.WorkosUserID] = conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) - } - targets := make([]memberAssignmentTarget, 0, len(memberIDs)) + resolved := make(map[string]struct{}, len(requested)) + seenWorkosID := make(map[string]struct{}, len(memberIDs)) for _, row := range assignmentRows { userID := conv.FromPGTextOrEmpty[string](row.UserID) membershipID := conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) - if _, ok := requested[row.WorkosUserID]; ok { - targets = append(targets, memberAssignmentTarget{ - UserID: userID, - WorkosUserID: row.WorkosUserID, - MembershipID: membershipID, - }) + requestedID := "" + if _, ok := requested[userID]; ok { + requestedID = userID + } else if gramID, ok := requestedByWorkosID[row.WorkosUserID]; ok { + requestedID = gramID + } else if _, ok := requested[row.WorkosUserID]; ok { + requestedID = row.WorkosUserID + } else { continue } - - workosID, ok := workosByGramID[userID] - if !ok { + workosID := row.WorkosUserID + if userWorkosID, ok := workosByGramID[userID]; ok { + workosID = userWorkosID + } + if workosID == "" || membershipID == "" { continue } - if _, ok := requested[userID]; ok { - if _, ok := membershipByWorkosID[workosID]; ok { - targets = append(targets, memberAssignmentTarget{ - UserID: userID, - WorkosUserID: workosID, - MembershipID: membershipID, - }) - } + if _, ok := seenWorkosID[workosID]; ok { + resolved[requestedID] = struct{}{} + continue } + seenWorkosID[workosID] = struct{}{} + resolved[requestedID] = struct{}{} + targets = append(targets, memberAssignmentTarget{ + UserID: userID, + WorkosUserID: workosID, + MembershipID: membershipID, + }) + } + if len(resolved) != len(requested) { + return nil, oops.E(oops.CodeBadRequest, nil, "member role assignment not found; wait for WorkOS sync to complete").Log(ctx, r.logger) } return targets, nil @@ -732,25 +716,9 @@ func retryWorkOSError(err error) bool { return apiErr.StatusCode == 429 || apiErr.StatusCode >= 500 } -// memberCounts returns the number of locally connected members per role slug. -// Side effects: reads Postgres local assignment records; does not call WorkOS. -func (r *RoleManager) memberCounts(ctx context.Context, gramOrgID string) (map[string]int, error) { - rows, err := repo.New(r.db).CountMembersByRoleForOrg(ctx, gramOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "count local members by role").Log(ctx, r.logger) - } - - trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) - counts := make(map[string]int, len(rows)) - for _, row := range rows { - counts[row.RoleSlug] = int(row.MemberCount) - } - return counts, nil -} - // roleViewFromLocalRole converts a local role record into the public API role view and attaches local grants. // Side effects: reads Postgres grants; does not call WorkOS. -func (r *RoleManager) roleViewFromLocalRole(ctx context.Context, organizationID string, role localRole, memberCount int) (*gen.Role, error) { +func (r *RoleManager) roleViewFromLocalRole(ctx context.Context, organizationID string, role localRole) (*gen.Role, error) { grants, err := authz.GrantsForRole(ctx, r.logger, r.db, organizationID, role.Slug) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "load role grants").Log(ctx, r.logger) @@ -766,7 +734,7 @@ func (r *RoleManager) roleViewFromLocalRole(ctx context.Context, organizationID Description: role.Description, IsSystem: isSystemRole(role.Slug), Grants: genGrants, - MemberCount: memberCount, + MemberCount: role.MemberCount, CreatedAt: conv.Default(role.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), UpdatedAt: conv.Default(role.UpdatedAt, time.Time{}.UTC().Format(time.RFC3339)), }, nil @@ -782,6 +750,7 @@ func localRoleFromActiveRow(row repo.ListActiveOrganizationRolesRow) localRole { Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), } } @@ -795,12 +764,13 @@ func localRoleFromRoleRow(row repo.GetOrganizationRoleByIDRow) localRole { Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), } } -// localRoleFromOrganizationRole converts an organization role row into the manager's internal local role record shape. +// localRoleFromSlugRow converts a sqlc role slug lookup row into the manager's internal local role record shape. // Side effects: none. -func localRoleFromOrganizationRole(row repo.OrganizationRole) localRole { +func localRoleFromSlugRow(row repo.GetActiveOrganizationRoleBySlugRow) localRole { return localRole{ ID: row.ID.String(), Name: row.WorkosName, @@ -808,6 +778,7 @@ func localRoleFromOrganizationRole(row repo.OrganizationRole) localRole { Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), } } diff --git a/server/internal/access/role_manager_test.go b/server/internal/access/role_manager_test.go index b8c23e7907..b4fe9d343a 100644 --- a/server/internal/access/role_manager_test.go +++ b/server/internal/access/role_manager_test.go @@ -72,14 +72,34 @@ func TestRoleManager_MembersAndCounts(t *testing.T) { manager := ti.service.roleMgr members, err := manager.ListMembers(ctx, authCtx.ActiveOrganizationID) require.NoError(t, err) - require.Len(t, members.Members, 3) + require.Len(t, members.Members, 2) slugs, err := manager.MemberRoleSlugs(ctx, authCtx.ActiveOrganizationID, "user_2") require.NoError(t, err) require.Equal(t, []string{"custom-builder"}, slugs) - counts, err := manager.memberCounts(ctx, authCtx.ActiveOrganizationID) + roles, err := manager.ListRoles(ctx, authCtx.ActiveOrganizationID) require.NoError(t, err) - require.Equal(t, 1, counts["admin"]) - require.Equal(t, 1, counts["custom-builder"]) + counts := make(map[string]int, len(roles.Roles)) + for _, role := range roles.Roles { + counts[role.Name] = role.MemberCount + } + require.Equal(t, 1, counts["Admin"]) + require.Equal(t, 1, counts["Custom Builder"]) +} + +func TestRoleManager_AssignMembersToRoleRequiresLocalAssignment(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "u1@example.com", "User 1", "user_1", "membership_1") + + assigned, err := ti.service.roleMgr.assignMembersToRole(ctx, authCtx.ActiveOrganizationID, "custom-builder", []string{"local_user_1"}) + require.Error(t, err) + require.Equal(t, 0, assigned) + require.Contains(t, err.Error(), "member role assignment not found") } From c414ddb6fd3af461671b6db6683f8c0272cfa65e Mon Sep 17 00:00:00 2001 From: tgmendes Date: Fri, 15 May 2026 11:03:01 +0100 Subject: [PATCH 04/12] chore: update list members to automatically join data --- server/internal/access/queries.sql | 24 +++++++++ server/internal/access/repo/queries.sql.go | 61 ++++++++++++++++++++++ server/internal/access/role_manager.go | 60 ++++----------------- 3 files changed, 94 insertions(+), 51 deletions(-) diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index 1418002ac8..9f5c85a6e2 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -259,6 +259,30 @@ WHERE ora.organization_id = @organization_id AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL ORDER BY ora.workos_user_id, role_slug; +-- name: ListAccessMembers :many +SELECT + users.id, + users.display_name, + users.email, + users.photo_url, + COALESCE(organization_roles.id, global_roles.id)::uuid AS role_id, + ora.created_at AS joined_at +FROM organization_role_assignments AS ora +JOIN users + ON users.id = ora.user_id +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = @organization_id + AND COALESCE(organization_roles.id, global_roles.id) IS NOT NULL +ORDER BY users.email, users.id; + -- name: ListMemberRoleSlugsByWorkosUser :many SELECT DISTINCT COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug FROM organization_role_assignments AS ora diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index d2fa1f9fbe..92b34c6aa9 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -336,6 +336,67 @@ func (q *Queries) InsertChallengeResolutions(ctx context.Context, arg InsertChal return items, nil } +const listAccessMembers = `-- name: ListAccessMembers :many +SELECT + users.id, + users.display_name, + users.email, + users.photo_url, + COALESCE(organization_roles.id, global_roles.id)::uuid AS role_id, + ora.created_at AS joined_at +FROM organization_role_assignments AS ora +JOIN users + ON users.id = ora.user_id +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = $1 + AND COALESCE(organization_roles.id, global_roles.id) IS NOT NULL +ORDER BY users.email, users.id +` + +type ListAccessMembersRow struct { + ID string + DisplayName string + Email string + PhotoUrl pgtype.Text + RoleID uuid.UUID + JoinedAt pgtype.Timestamptz +} + +func (q *Queries) ListAccessMembers(ctx context.Context, organizationID string) ([]ListAccessMembersRow, error) { + rows, err := q.db.Query(ctx, listAccessMembers, organizationID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListAccessMembersRow + for rows.Next() { + var i ListAccessMembersRow + if err := rows.Scan( + &i.ID, + &i.DisplayName, + &i.Email, + &i.PhotoUrl, + &i.RoleID, + &i.JoinedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listActiveOrganizationRoles = `-- name: ListActiveOrganizationRoles :many WITH active_roles AS ( SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index ed3853b7b4..4df271f4ee 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -96,61 +96,20 @@ type localRoleAssignment struct { // ListMembers returns locally known organization members with role IDs resolved from local role assignments. func (r *RoleManager) ListMembers(ctx context.Context, gramOrgID string) (*gen.ListMembersResult, error) { - roleRows, err := repo.New(r.db).ListActiveOrganizationRoles(ctx, gramOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles").Log(ctx, r.logger) - } - - roles := make(map[string]string, len(roleRows)) - for _, row := range roleRows { - roles[row.WorkosSlug] = row.ID.String() - } - - assignmentRows, err := repo.New(r.db).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) + rows, err := repo.New(r.db).ListAccessMembers(ctx, gramOrgID) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) } - assignments := make([]localRoleAssignment, 0, len(assignmentRows)) - for _, row := range assignmentRows { - assignments = append(assignments, localRoleAssignment{ - UserID: conv.FromPGTextOrEmpty[string](row.UserID), - WorkosUserID: row.WorkosUserID, - MembershipID: conv.FromPGTextOrEmpty[string](row.WorkosMembershipID), - RoleSlug: row.RoleSlug, - CreatedAt: conv.FromPGTimestamptz(row.CreatedAt), - }) - } - - userIDs := make([]string, 0, len(assignments)) - for _, assignment := range assignments { - if assignment.UserID != "" { - userIDs = append(userIDs, assignment.UserID) - } - } - localRows, err := usersrepo.New(r.db).GetUsersByIDs(ctx, userIDs) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "resolve users by ids").Log(ctx, r.logger) - } - localUsers := make(map[string]usersrepo.User, len(localRows)) - for _, u := range localRows { - localUsers[u.ID] = u - } - - result := make([]*gen.AccessMember, 0, len(assignments)) - for _, assignment := range assignments { - user, ok := localUsers[assignment.UserID] - if !ok { - continue - } - + result := make([]*gen.AccessMember, 0, len(rows)) + for _, row := range rows { result = append(result, &gen.AccessMember{ - ID: user.ID, - Name: conv.Default(user.DisplayName, user.Email), - Email: user.Email, - PhotoURL: conv.FromPGText[string](user.PhotoUrl), - RoleID: roles[assignment.RoleSlug], - JoinedAt: assignment.CreatedAt, + ID: row.ID, + Name: conv.Default(row.DisplayName, row.Email), + Email: row.Email, + PhotoURL: conv.FromPGText[string](row.PhotoUrl), + RoleID: row.RoleID.String(), + JoinedAt: conv.FromPGTimestamptz(row.JoinedAt), }) } @@ -169,7 +128,6 @@ func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID str if err != nil { return roleCreateResult{}, err } - trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(roleSlug)) now := time.Now().UTC().Format(time.RFC3339) if err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ From d7e50d3a104e1bfd93b18a4d97ed3686572bba28 Mon Sep 17 00:00:00 2001 From: tgmendes Date: Fri, 15 May 2026 11:07:18 +0100 Subject: [PATCH 05/12] return data when upserting instead of relying on separate query --- server/internal/access/queries.sql | 28 +++++++++- server/internal/access/repo/queries.sql.go | 52 +++++++++++++++++-- server/internal/access/role_manager.go | 32 ++++++++---- server/internal/access/setup_test.go | 2 +- .../activities/process_workos_org_events.go | 2 +- .../process_workos_org_events_test.go | 2 +- 6 files changed, 99 insertions(+), 19 deletions(-) diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index 9f5c85a6e2..2b3def7788 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -109,10 +109,11 @@ FROM organization_roles WHERE organization_id = @organization_id AND workos_slug = @workos_slug; --- name: UpsertOrganizationRole :exec +-- name: UpsertOrganizationRole :one -- Upsert an org-scoped WorkOS role. Caller must have already passed the row -- through ShouldProcessEvent. Resurrects a previously soft-deleted role on -- conflict. +WITH upserted AS ( INSERT INTO organization_roles ( organization_id, workos_slug, @@ -137,7 +138,30 @@ ON CONFLICT (organization_id, workos_slug) DO UPDATE SET workos_last_event_id = EXCLUDED.workos_last_event_id, deleted_at = NULL, workos_deleted_at = NULL, - updated_at = clock_timestamp(); + updated_at = clock_timestamp() +RETURNING + id, + organization_id, + workos_slug, + workos_name, + workos_description, + workos_created_at, + workos_updated_at +) +SELECT + upserted.id, + upserted.workos_slug, + upserted.workos_name, + upserted.workos_description, + upserted.workos_created_at, + upserted.workos_updated_at, + COUNT(ora.id)::bigint AS member_count +FROM upserted +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = upserted.organization_id + AND ora.role_urn = 'role:organization:' || upserted.id::text + AND ora.user_id IS NOT NULL +GROUP BY upserted.id, upserted.workos_slug, upserted.workos_name, upserted.workos_description, upserted.workos_created_at, upserted.workos_updated_at; -- name: MarkOrganizationRoleDeleted :execrows UPDATE organization_roles diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index 92b34c6aa9..eb835b4feb 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -875,7 +875,8 @@ func (q *Queries) UpsertGlobalRole(ctx context.Context, arg UpsertGlobalRolePara return err } -const upsertOrganizationRole = `-- name: UpsertOrganizationRole :exec +const upsertOrganizationRole = `-- name: UpsertOrganizationRole :one +WITH upserted AS ( INSERT INTO organization_roles ( organization_id, workos_slug, @@ -901,6 +902,29 @@ ON CONFLICT (organization_id, workos_slug) DO UPDATE SET deleted_at = NULL, workos_deleted_at = NULL, updated_at = clock_timestamp() +RETURNING + id, + organization_id, + workos_slug, + workos_name, + workos_description, + workos_created_at, + workos_updated_at +) +SELECT + upserted.id, + upserted.workos_slug, + upserted.workos_name, + upserted.workos_description, + upserted.workos_created_at, + upserted.workos_updated_at, + COUNT(ora.id)::bigint AS member_count +FROM upserted +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = upserted.organization_id + AND ora.role_urn = 'role:organization:' || upserted.id::text + AND ora.user_id IS NOT NULL +GROUP BY upserted.id, upserted.workos_slug, upserted.workos_name, upserted.workos_description, upserted.workos_created_at, upserted.workos_updated_at ` type UpsertOrganizationRoleParams struct { @@ -913,11 +937,21 @@ type UpsertOrganizationRoleParams struct { WorkosLastEventID pgtype.Text } +type UpsertOrganizationRoleRow struct { + ID uuid.UUID + WorkosSlug string + WorkosName string + WorkosDescription pgtype.Text + WorkosCreatedAt pgtype.Timestamptz + WorkosUpdatedAt pgtype.Timestamptz + MemberCount int64 +} + // Upsert an org-scoped WorkOS role. Caller must have already passed the row // through ShouldProcessEvent. Resurrects a previously soft-deleted role on // conflict. -func (q *Queries) UpsertOrganizationRole(ctx context.Context, arg UpsertOrganizationRoleParams) error { - _, err := q.db.Exec(ctx, upsertOrganizationRole, +func (q *Queries) UpsertOrganizationRole(ctx context.Context, arg UpsertOrganizationRoleParams) (UpsertOrganizationRoleRow, error) { + row := q.db.QueryRow(ctx, upsertOrganizationRole, arg.OrganizationID, arg.WorkosSlug, arg.WorkosName, @@ -926,7 +960,17 @@ func (q *Queries) UpsertOrganizationRole(ctx context.Context, arg UpsertOrganiza arg.WorkosUpdatedAt, arg.WorkosLastEventID, ) - return err + var i UpsertOrganizationRoleRow + err := row.Scan( + &i.ID, + &i.WorkosSlug, + &i.WorkosName, + &i.WorkosDescription, + &i.WorkosCreatedAt, + &i.WorkosUpdatedAt, + &i.MemberCount, + ) + return i, err } const upsertPrincipalGrant = `-- name: UpsertPrincipalGrant :one diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index 4df271f4ee..b38a1fa684 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -130,20 +130,26 @@ func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID str } now := time.Now().UTC().Format(time.RFC3339) - if err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ + createdRow, err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ ID: "", Name: payload.Name, Slug: roleSlug, Description: payload.Description, CreatedAt: now, UpdatedAt: now, - })); err != nil { + })) + if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role record").Log(ctx, r.logger) } - createdRole, err := r.getLocalRoleBySlug(ctx, gramOrgID, roleSlug) - if err != nil { - return roleCreateResult{}, err + createdRole := localRole{ + ID: createdRow.ID.String(), + Name: createdRow.WorkosName, + Slug: createdRow.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](createdRow.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(createdRow.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(createdRow.WorkosUpdatedAt), + MemberCount: int(createdRow.MemberCount), } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleID(createdRole.ID)) @@ -235,20 +241,26 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str localRecord.Description = *payload.Description } localRecord.UpdatedAt = time.Now().UTC().Format(time.RFC3339) - if err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ + updatedRow, err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ ID: "", Name: localRecord.Name, Slug: localRecord.Slug, Description: localRecord.Description, CreatedAt: localRecord.CreatedAt, UpdatedAt: localRecord.UpdatedAt, - })); err != nil { + })) + if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role record").Log(ctx, r.logger) } - updatedRole, err = r.getLocalRoleBySlug(ctx, gramOrgID, localRecord.Slug) - if err != nil { - return roleUpdateResult{}, err + updatedRole = localRole{ + ID: updatedRow.ID.String(), + Name: updatedRow.WorkosName, + Slug: updatedRow.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](updatedRow.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(updatedRow.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(updatedRow.WorkosUpdatedAt), + MemberCount: int(updatedRow.MemberCount), } if payload.Grants != nil { diff --git a/server/internal/access/setup_test.go b/server/internal/access/setup_test.go index 2ab999b78a..dd421191d9 100644 --- a/server/internal/access/setup_test.go +++ b/server/internal/access/setup_test.go @@ -153,7 +153,7 @@ func seedRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizatio updatedAt, err := time.Parse(time.RFC3339, role.UpdatedAt) require.NoError(t, err) - err = accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + _, err = accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ OrganizationID: organizationID, WorkosSlug: role.Slug, WorkosName: role.Name, diff --git a/server/internal/background/activities/process_workos_org_events.go b/server/internal/background/activities/process_workos_org_events.go index 039e6ed755..567bd53eaa 100644 --- a/server/internal/background/activities/process_workos_org_events.go +++ b/server/internal/background/activities/process_workos_org_events.go @@ -378,7 +378,7 @@ func upsertOrganizationRole(ctx context.Context, logger *slog.Logger, dbtx datab return nil } - if err := repo.UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + if _, err := repo.UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ OrganizationID: org.ID, WorkosSlug: payload.Slug, WorkosName: payload.Name, diff --git a/server/internal/background/activities/process_workos_org_events_test.go b/server/internal/background/activities/process_workos_org_events_test.go index 7c133cc6d4..d21c2a6c93 100644 --- a/server/internal/background/activities/process_workos_org_events_test.go +++ b/server/internal/background/activities/process_workos_org_events_test.go @@ -566,7 +566,7 @@ func seedOrganizationRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, t.Helper() eventTime := time.Date(2026, 5, 6, 10, 0, 0, 0, time.UTC) - err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + _, err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ OrganizationID: organizationID, WorkosSlug: slug, WorkosName: slug, From 342c47e52042a58c3f8d95bb476afa9703e6ea5f Mon Sep 17 00:00:00 2001 From: tgmendes Date: Mon, 18 May 2026 10:00:37 +0100 Subject: [PATCH 06/12] chore: improve code --- server/cmd/gram/start.go | 2 +- server/internal/access/createrole_test.go | 8 +- server/internal/access/impl.go | 104 +---- server/internal/access/mock_role_test.go | 13 +- server/internal/access/queries.sql | 79 +++- server/internal/access/repo/queries.sql.go | 106 ++++- server/internal/access/role_manager.go | 458 +++++++++++++------- server/internal/access/role_manager_test.go | 2 +- server/internal/access/setup_test.go | 2 +- server/internal/authz/engine.go | 34 +- server/internal/authz/engine_test.go | 35 +- server/internal/authz/grants.go | 60 ++- server/internal/authz/setup_test.go | 36 ++ 13 files changed, 633 insertions(+), 306 deletions(-) diff --git a/server/cmd/gram/start.go b/server/cmd/gram/start.go index 4d3ffaeb4c..f1dee1630b 100644 --- a/server/cmd/gram/start.go +++ b/server/cmd/gram/start.go @@ -865,7 +865,7 @@ func newStartCommand() *cli.Command { about.Attach(mux, about.NewService(logger, tracerProvider)) external.AttachWebhookHandler(mux, external.NewWebhookHandler(logger, tracerProvider, newWorkOSWebhooksClient(c), temporalEnv)) - roleManager := access.NewRoleManager(logger, db, roleClient, authzEngine) + roleManager := access.NewRoleManager(logger, db, roleClient, auditLogger) access.Attach(mux, access.NewService(logger, tracerProvider, db, chDB, sessionManager, roleManager, authzEngine, productFeatures, auditLogger)) assistants.Attach(mux, assistantsSvc) assistantmemories.Attach(mux, assistantmemories.NewService( diff --git a/server/internal/access/createrole_test.go b/server/internal/access/createrole_test.go index dc5390826b..626d7ef153 100644 --- a/server/internal/access/createrole_test.go +++ b/server/internal/access/createrole_test.go @@ -230,7 +230,11 @@ func TestService_CreateRole_LocalRoleWriteFailureDoesNotAssignMembers(t *testing t.Cleanup(inspectConn.Close) ti.conn.Close() - _, err = ti.service.roleMgr.CreateRole(ctx, authCtx.ActiveOrganizationID, mockidp.MockOrgID, &gen.CreateRolePayload{ + _, err = ti.service.roleMgr.CreateRole(ctx, authCtx.ActiveOrganizationID, mockidp.MockOrgID, accessAuditActor{ + Principal: urn.NewPrincipal(urn.PrincipalTypeUser, authCtx.UserID), + DisplayName: authCtx.Email, + Slug: nil, + }, &gen.CreateRolePayload{ Name: "Broken Builder", Description: "Will fail local write", Grants: []*gen.RoleGrant{ @@ -239,7 +243,7 @@ func TestService_CreateRole_LocalRoleWriteFailureDoesNotAssignMembers(t *testing MemberIds: []string{"local_user_1", "local_user_2"}, }) require.Error(t, err) - require.Contains(t, err.Error(), "upsert local role record") + require.Contains(t, err.Error(), "role transaction") grants, err := accessrepo.New(inspectConn).ListPrincipalGrantsByOrg(ctx, accessrepo.ListPrincipalGrantsByOrgParams{ OrganizationID: authCtx.ActiveOrganizationID, diff --git a/server/internal/access/impl.go b/server/internal/access/impl.go index b3cff10018..5437f0df01 100644 --- a/server/internal/access/impl.go +++ b/server/internal/access/impl.go @@ -28,6 +28,7 @@ import ( chrepo "github.com/speakeasy-api/gram/server/internal/authz/repo" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/database" "github.com/speakeasy-api/gram/server/internal/middleware" "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/oops" @@ -146,32 +147,19 @@ func (s *Service) CreateRole(ctx context.Context, payload *gen.CreateRolePayload return nil, err } - logger := s.logger.With( - attr.SlogOrganizationID(ac.ActiveOrganizationID), - attr.SlogUserID(ac.UserID), - ) trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), ) - created, err := s.roleMgr.CreateRole(ctx, ac.ActiveOrganizationID, workosOrgID, payload) + created, err := s.roleMgr.CreateRole(ctx, ac.ActiveOrganizationID, workosOrgID, accessAuditActor{ + Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), + DisplayName: ac.Email, + Slug: nil, + }, payload) if err != nil { return nil, err } - logger = logger.With(attr.SlogAccessRoleSlug(created.Slug)) - - if err := s.audit.LogAccessRoleCreate(ctx, s.db, audit.LogAccessRoleCreateEvent{ - OrganizationID: ac.ActiveOrganizationID, - Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), - ActorDisplayName: ac.Email, - ActorSlug: nil, - RoleID: created.Role.ID, - RoleName: created.Role.Name, - RoleSlug: created.Slug, - }); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "log access role creation").Log(ctx, logger) - } return created.Role, nil } @@ -184,38 +172,22 @@ func (s *Service) UpdateRole(ctx context.Context, payload *gen.UpdateRolePayload if err := s.authz.Require(ctx, authz.Check{Scope: authz.ScopeOrgAdmin, ResourceKind: "", ResourceID: ac.ActiveOrganizationID, Dimensions: nil}); err != nil { return nil, err } - logger := s.logger.With( - attr.SlogOrganizationID(ac.ActiveOrganizationID), - attr.SlogUserID(ac.UserID), - attr.SlogAccessRoleID(payload.ID), - ) trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), attr.AccessRoleID(payload.ID), ) - updated, err := s.roleMgr.UpdateRole(ctx, ac.ActiveOrganizationID, workosOrgID, payload) + updated, err := s.roleMgr.UpdateRole(ctx, ac.ActiveOrganizationID, workosOrgID, accessAuditActor{ + Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), + DisplayName: ac.Email, + Slug: nil, + }, payload) if err != nil { return nil, err } - logger = logger.With(attr.SlogAccessRoleSlug(updated.Role.Slug)) trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(updated.Role.Slug)) - if err := s.audit.LogAccessRoleUpdate(ctx, s.db, audit.LogAccessRoleUpdateEvent{ - OrganizationID: ac.ActiveOrganizationID, - Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), - ActorDisplayName: ac.Email, - ActorSlug: nil, - RoleID: updated.Role.ID, - RoleName: updated.After.Name, - RoleSlug: updated.Role.Slug, - RoleSnapshotBefore: updated.Before, - RoleSnapshotAfter: updated.After, - }); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "log access role update").Log(ctx, logger) - } - return updated.After, nil } @@ -227,36 +199,23 @@ func (s *Service) DeleteRole(ctx context.Context, payload *gen.DeleteRolePayload if err := s.authz.Require(ctx, authz.Check{Scope: authz.ScopeOrgAdmin, ResourceKind: "", ResourceID: ac.ActiveOrganizationID, Dimensions: nil}); err != nil { return err } - logger := s.logger.With( - attr.SlogOrganizationID(ac.ActiveOrganizationID), - attr.SlogUserID(ac.UserID), - attr.SlogAccessRoleID(payload.ID), - ) trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), attr.AccessRoleID(payload.ID), ) - deletedRole, err := s.roleMgr.DeleteRole(ctx, ac.ActiveOrganizationID, workosOrgID, payload.ID) + deleted, err := s.roleMgr.DeleteRole(ctx, ac.ActiveOrganizationID, workosOrgID, payload.ID, accessAuditActor{ + Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), + DisplayName: ac.Email, + Slug: nil, + }) if err != nil { return err } - logger = logger.With(attr.SlogAccessRoleSlug(deletedRole.Slug)) + deletedRole := deleted.Role trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(deletedRole.Slug)) - if err := s.audit.LogAccessRoleDelete(ctx, s.db, audit.LogAccessRoleDeleteEvent{ - OrganizationID: ac.ActiveOrganizationID, - Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), - ActorDisplayName: ac.Email, - ActorSlug: nil, - RoleID: deletedRole.ID, - RoleName: deletedRole.Name, - RoleSlug: deletedRole.Slug, - }); err != nil { - return oops.E(oops.CodeUnexpected, err, "log access role deletion").Log(ctx, logger) - } - return nil } @@ -387,12 +346,6 @@ func (s *Service) UpdateMemberRole(ctx context.Context, payload *gen.UpdateMembe if err := s.authz.Require(ctx, authz.Check{Scope: authz.ScopeOrgAdmin, ResourceKind: "", ResourceID: ac.ActiveOrganizationID, Dimensions: nil}); err != nil { return nil, err } - logger := s.logger.With( - attr.SlogOrganizationID(ac.ActiveOrganizationID), - attr.SlogUserID(ac.UserID), - attr.SlogAccessMemberID(payload.UserID), - attr.SlogAccessRoleID(payload.RoleID), - ) trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), @@ -400,32 +353,21 @@ func (s *Service) UpdateMemberRole(ctx context.Context, payload *gen.UpdateMembe attr.AccessRoleID(payload.RoleID), ) - memberUpdate, err := s.roleMgr.UpdateMemberRole(ctx, ac.ActiveOrganizationID, payload.UserID, payload.RoleID) + memberUpdate, err := s.roleMgr.UpdateMemberRole(ctx, ac.ActiveOrganizationID, payload.UserID, payload.RoleID, accessAuditActor{ + Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), + DisplayName: ac.Email, + Slug: nil, + }) if err != nil { return nil, err } roleSlug := memberUpdate.RoleSlug - logger = logger.With(attr.SlogAccessRoleSlug(roleSlug)) trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), attr.AccessRoleSlug(roleSlug), ) - if err := s.audit.LogAccessMemberRoleUpdate(ctx, s.db, audit.LogAccessMemberRoleUpdateEvent{ - OrganizationID: ac.ActiveOrganizationID, - Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), - ActorDisplayName: ac.Email, - ActorSlug: nil, - MemberID: memberUpdate.UserID, - MemberName: memberUpdate.After.Name, - MemberEmail: memberUpdate.After.Email, - MemberSnapshotBefore: memberUpdate.Before, - MemberSnapshotAfter: memberUpdate.After, - }); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "log access member role update").Log(ctx, logger) - } - return memberUpdate.After, nil } @@ -564,7 +506,7 @@ func listRoleGrantsFromGrants(grants []authz.Grant) []*gen.ListRoleGrant { return out } -func connectedUser(ctx context.Context, db *pgxpool.Pool, organizationID string, userID string) (usersrepo.User, error) { +func connectedUser(ctx context.Context, db database.DBTX, organizationID string, userID string) (usersrepo.User, error) { hasRelationship, err := orgrepo.New(db).HasOrganizationUserRelationship(ctx, orgrepo.HasOrganizationUserRelationshipParams{ OrganizationID: organizationID, UserID: userID, diff --git a/server/internal/access/mock_role_test.go b/server/internal/access/mock_role_test.go index 224256313f..c02ba4e410 100644 --- a/server/internal/access/mock_role_test.go +++ b/server/internal/access/mock_role_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -25,12 +26,22 @@ func newMockRoleProvider(t *testing.T) *MockRoleProvider { roles := &MockRoleProvider{} t.Cleanup(func() { - require.True(t, roles.AssertExpectations(t)) + require.Eventually(t, func() bool { + return roles.AssertExpectations(mockExpectationProbe{}) + }, 2*time.Second, 10*time.Millisecond) }) return roles } +type mockExpectationProbe struct{} + +func (mockExpectationProbe) Logf(string, ...any) {} + +func (mockExpectationProbe) Errorf(string, ...any) {} + +func (mockExpectationProbe) FailNow() {} + func (m *MockRoleProvider) CreateRole(ctx context.Context, orgID string, opts thirdpartyworkos.CreateRoleOpts) (*thirdpartyworkos.Role, error) { args := m.Called(ctx, orgID, opts) if role, ok := args.Get(0).(*thirdpartyworkos.Role); ok { diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index 2b3def7788..3cf2b4286d 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -212,24 +212,35 @@ GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, ac ORDER BY active_roles.workos_slug; -- name: GetActiveOrganizationRoleBySlug :one +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = @workos_slug + AND deleted IS FALSE + AND workos_deleted IS FALSE + UNION ALL + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = @organization_id + AND organization_roles.workos_slug = @workos_slug + AND deleted IS FALSE + AND workos_deleted IS FALSE +) SELECT - organization_roles.id, - organization_roles.workos_slug, - organization_roles.workos_name, - organization_roles.workos_description, - organization_roles.workos_created_at, - organization_roles.workos_updated_at, + active_roles.id, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, COUNT(ora.id)::bigint AS member_count -FROM organization_roles +FROM active_roles LEFT JOIN organization_role_assignments AS ora - ON ora.organization_id = organization_roles.organization_id - AND ora.role_urn = 'role:organization:' || organization_roles.id::text + ON ora.organization_id = @organization_id + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL -WHERE organization_roles.organization_id = @organization_id - AND organization_roles.workos_slug = @workos_slug - AND organization_roles.deleted IS FALSE - AND organization_roles.workos_deleted IS FALSE -GROUP BY organization_roles.id, organization_roles.workos_slug, organization_roles.workos_name, organization_roles.workos_description, organization_roles.workos_created_at, organization_roles.workos_updated_at; +GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +LIMIT 1; -- name: GetOrganizationRoleByID :one WITH active_roles AS ( @@ -324,6 +335,46 @@ WHERE ora.organization_id = @organization_id AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL ORDER BY role_slug; +-- name: UpsertOrganizationRoleAssignment :execrows +WITH input_role_urn AS ( + SELECT 'role:organization:' || id::text AS role_urn + FROM organization_roles + WHERE organization_roles.organization_id = @organization_id + AND organization_roles.workos_slug = sqlc.arg(workos_role_slug) + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn + FROM global_roles + WHERE global_roles.workos_slug = sqlc.arg(workos_role_slug) + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +) +INSERT INTO organization_role_assignments ( + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id +) +SELECT + @organization_id, + @workos_user_id, + @user_id, + input_role_urn.role_urn, + @workos_membership_id, + @workos_updated_at, + @workos_last_event_id +FROM input_role_urn +ON CONFLICT (organization_id, workos_user_id, role_urn) DO UPDATE SET + user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), + workos_membership_id = EXCLUDED.workos_membership_id, + workos_updated_at = EXCLUDED.workos_updated_at, + workos_last_event_id = EXCLUDED.workos_last_event_id, + updated_at = clock_timestamp(); + -- name: ReplaceOrganizationRoleAssignment :one WITH input_role_urn AS ( SELECT 'role:organization:' || id::text AS role_urn diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index eb835b4feb..a01b124d8a 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -55,24 +55,35 @@ func (q *Queries) DeletePrincipalGrantsByPrincipal(ctx context.Context, arg Dele } const getActiveOrganizationRoleBySlug = `-- name: GetActiveOrganizationRoleBySlug :one +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = $2 + AND deleted IS FALSE + AND workos_deleted IS FALSE + UNION ALL + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = $1 + AND organization_roles.workos_slug = $2 + AND deleted IS FALSE + AND workos_deleted IS FALSE +) SELECT - organization_roles.id, - organization_roles.workos_slug, - organization_roles.workos_name, - organization_roles.workos_description, - organization_roles.workos_created_at, - organization_roles.workos_updated_at, + active_roles.id, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, COUNT(ora.id)::bigint AS member_count -FROM organization_roles +FROM active_roles LEFT JOIN organization_role_assignments AS ora - ON ora.organization_id = organization_roles.organization_id - AND ora.role_urn = 'role:organization:' || organization_roles.id::text + ON ora.organization_id = $1 + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL -WHERE organization_roles.organization_id = $1 - AND organization_roles.workos_slug = $2 - AND organization_roles.deleted IS FALSE - AND organization_roles.workos_deleted IS FALSE -GROUP BY organization_roles.id, organization_roles.workos_slug, organization_roles.workos_name, organization_roles.workos_description, organization_roles.workos_created_at, organization_roles.workos_updated_at +GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +LIMIT 1 ` type GetActiveOrganizationRoleBySlugParams struct { @@ -973,6 +984,73 @@ func (q *Queries) UpsertOrganizationRole(ctx context.Context, arg UpsertOrganiza return i, err } +const upsertOrganizationRoleAssignment = `-- name: UpsertOrganizationRoleAssignment :execrows +WITH input_role_urn AS ( + SELECT 'role:organization:' || id::text AS role_urn + FROM organization_roles + WHERE organization_roles.organization_id = $1 + AND organization_roles.workos_slug = $7 + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn + FROM global_roles + WHERE global_roles.workos_slug = $7 + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +) +INSERT INTO organization_role_assignments ( + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id +) +SELECT + $1, + $2, + $3, + input_role_urn.role_urn, + $4, + $5, + $6 +FROM input_role_urn +ON CONFLICT (organization_id, workos_user_id, role_urn) DO UPDATE SET + user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), + workos_membership_id = EXCLUDED.workos_membership_id, + workos_updated_at = EXCLUDED.workos_updated_at, + workos_last_event_id = EXCLUDED.workos_last_event_id, + updated_at = clock_timestamp() +` + +type UpsertOrganizationRoleAssignmentParams struct { + OrganizationID string + WorkosUserID string + UserID pgtype.Text + WorkosMembershipID pgtype.Text + WorkosUpdatedAt pgtype.Timestamptz + WorkosLastEventID pgtype.Text + WorkosRoleSlug string +} + +func (q *Queries) UpsertOrganizationRoleAssignment(ctx context.Context, arg UpsertOrganizationRoleAssignmentParams) (int64, error) { + result, err := q.db.Exec(ctx, upsertOrganizationRoleAssignment, + arg.OrganizationID, + arg.WorkosUserID, + arg.UserID, + arg.WorkosMembershipID, + arg.WorkosUpdatedAt, + arg.WorkosLastEventID, + arg.WorkosRoleSlug, + ) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} + const upsertPrincipalGrant = `-- name: UpsertPrincipalGrant :one INSERT INTO principal_grants (organization_id, principal_urn, scope, selectors) VALUES ($1, $2, $3, $4) diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index b38a1fa684..fed74db866 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/cenkalti/backoff/v5" "github.com/google/uuid" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" @@ -17,8 +18,10 @@ import ( gen "github.com/speakeasy-api/gram/server/gen/access" "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/attr" + "github.com/speakeasy-api/gram/server/internal/audit" "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" @@ -44,16 +47,16 @@ type RoleManager struct { db *pgxpool.Pool logger *slog.Logger roles RoleProvider - authz *authz.Engine + audit *audit.Logger } -// NewRoleManager wires the role manager to the local DB, the WorkOS role client, and the authz engine. -func NewRoleManager(logger *slog.Logger, db *pgxpool.Pool, roles RoleProvider, authzEngine *authz.Engine) *RoleManager { +// NewRoleManager wires the role manager to the local DB, the WorkOS role client, and the audit logger. +func NewRoleManager(logger *slog.Logger, db *pgxpool.Pool, roles RoleProvider, auditLogger *audit.Logger) *RoleManager { return &RoleManager{ db: db, logger: logger.With(attr.SlogComponent("access.role_manager")), roles: roles, - authz: authzEngine, + audit: auditLogger, } } @@ -121,16 +124,30 @@ type roleCreateResult struct { Slug string } -// CreateRole creates the local role record, syncs local grants, optionally assigns members, and then best-effort syncs WorkOS. -// Side effects: writes Postgres role/grant/assignment records, invalidates local authz state, and logs WorkOS sync failures after bounded retries. -func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID string, payload *gen.CreateRolePayload) (roleCreateResult, error) { +type workosSync func(context.Context) + +type accessAuditActor struct { + Principal urn.Principal + DisplayName *string + Slug *string +} + +// CreateRole creates the local role, grants, optional assignments, and audit entry atomically, then best-effort syncs WorkOS after commit. +// Side effects: writes Postgres role/grant/assignment/audit records and logs WorkOS sync failures after bounded retries. +func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID string, actor accessAuditActor, payload *gen.CreateRolePayload) (roleCreateResult, error) { roleSlug, err := slugify(payload.Name) if err != nil { return roleCreateResult{}, err } + tx, err := r.db.Begin(ctx) + if err != nil { + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "begin role transaction").Log(ctx, r.logger) + } + defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) + now := time.Now().UTC().Format(time.RFC3339) - createdRow, err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ + createdRow, err := repo.New(tx).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ ID: "", Name: payload.Name, Slug: roleSlug, @@ -142,46 +159,61 @@ func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID str trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role record").Log(ctx, r.logger) } - createdRole := localRole{ - ID: createdRow.ID.String(), - Name: createdRow.WorkosName, - Slug: createdRow.WorkosSlug, - Description: conv.FromPGTextOrEmpty[string](createdRow.WorkosDescription), - CreatedAt: conv.FromPGTimestamptz(createdRow.WorkosCreatedAt), - UpdatedAt: conv.FromPGTimestamptz(createdRow.WorkosUpdatedAt), - MemberCount: int(createdRow.MemberCount), - } + createdRole := localRoleFromUpsertRow(createdRow) trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleID(createdRole.ID)) - if err := authz.SyncGrants(ctx, r.logger, r.db, gramOrgID, roleSlug, roleGrantPayloads(payload.Grants)); err != nil { + if _, err := authz.SyncGrantsTx(ctx, tx, gramOrgID, roleSlug, roleGrantPayloads(payload.Grants)); err != nil { return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "sync grants for created role").Log(ctx, r.logger) } - r.syncWorkOS(ctx, "create role in workos", func() error { - _, err := r.roles.CreateRole(ctx, workosOrgID, workos.CreateRoleOpts{ - Name: payload.Name, - Slug: roleSlug, - Description: payload.Description, + workosSyncs := []workosSync{func(ctx context.Context) { + r.syncWorkOS(ctx, "create role in workos", func() error { + _, err := r.roles.CreateRole(ctx, workosOrgID, workos.CreateRoleOpts{ + Name: payload.Name, + Slug: roleSlug, + Description: payload.Description, + }) + var apiErr *workos.APIError + if errors.As(err, &apiErr) && apiErr.StatusCode == 409 { + return nil + } + if err == nil { + return nil + } + return fmt.Errorf("create role in workos: %w", err) }) - var apiErr *workos.APIError - if errors.As(err, &apiErr) && apiErr.StatusCode == 409 { - return nil - } - if err == nil { - return nil - } - return fmt.Errorf("create role in workos: %w", err) - }) + }} if len(payload.MemberIds) > 0 { - if _, err := r.assignMembersToRole(ctx, gramOrgID, roleSlug, payload.MemberIds); err != nil { + var memberSyncs []workosSync + if _, memberSyncs, err = r.assignMembersToRoleTx(ctx, tx, gramOrgID, roleSlug, payload.MemberIds); err != nil { return roleCreateResult{}, err } - createdRole, err = r.getLocalRoleBySlug(ctx, gramOrgID, roleSlug) + workosSyncs = append(workosSyncs, memberSyncs...) + createdRole, err = r.getLocalRoleBySlugTx(ctx, tx, gramOrgID, roleSlug) if err != nil { return roleCreateResult{}, err } } + + if err := r.audit.LogAccessRoleCreate(ctx, tx, audit.LogAccessRoleCreateEvent{ + OrganizationID: gramOrgID, + Actor: actor.Principal, + ActorDisplayName: actor.DisplayName, + ActorSlug: actor.Slug, + RoleID: createdRole.ID, + RoleName: createdRole.Name, + RoleSlug: createdRole.Slug, + }); err != nil { + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "log access role creation").Log(ctx, r.logger) + } + + if err := tx.Commit(ctx); err != nil { + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "commit role transaction").Log(ctx, r.logger) + } + + r.runWorkOSSyncs(ctx, workosSyncs) + role, err := r.roleViewFromLocalRole(ctx, gramOrgID, createdRole) if err != nil { return roleCreateResult{}, err @@ -206,9 +238,9 @@ type roleUpdateResult struct { Role localRole } -// UpdateRole updates an existing local role and optionally replaces its assigned members, then best-effort syncs WorkOS. -// Side effects: writes Postgres role/grant/assignment records, invalidates local authz state, and logs WorkOS sync failures after bounded retries. -func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID string, payload *gen.UpdateRolePayload) (roleUpdateResult, error) { +// UpdateRole updates an existing local role, optional grants/assignments, and audit entry atomically, then best-effort syncs WorkOS after commit. +// Side effects: writes Postgres role/grant/assignment/audit records and logs WorkOS sync failures after bounded retries. +func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID string, actor accessAuditActor, payload *gen.UpdateRolePayload) (roleUpdateResult, error) { currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, payload.ID) if err != nil { return roleUpdateResult{}, err @@ -231,7 +263,15 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str } } + tx, err := r.db.Begin(ctx) + if err != nil { + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "begin role transaction").Log(ctx, r.logger) + } + defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) + updatedRole := currentRole + var workosSyncs []workosSync + var updatedGrants []*gen.RoleGrant if !sysRole { localRecord := currentRole if payload.Name != nil { @@ -241,7 +281,7 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str localRecord.Description = *payload.Description } localRecord.UpdatedAt = time.Now().UTC().Format(time.RFC3339) - updatedRow, err := repo.New(r.db).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ + updatedRow, err := repo.New(tx).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ ID: "", Name: localRecord.Name, Slug: localRecord.Slug, @@ -253,122 +293,172 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role record").Log(ctx, r.logger) } - updatedRole = localRole{ - ID: updatedRow.ID.String(), - Name: updatedRow.WorkosName, - Slug: updatedRow.WorkosSlug, - Description: conv.FromPGTextOrEmpty[string](updatedRow.WorkosDescription), - CreatedAt: conv.FromPGTimestamptz(updatedRow.WorkosCreatedAt), - UpdatedAt: conv.FromPGTimestamptz(updatedRow.WorkosUpdatedAt), - MemberCount: int(updatedRow.MemberCount), - } + updatedRole = localRoleFromUpsertRow(updatedRow) if payload.Grants != nil { - if err := authz.SyncGrants(ctx, r.logger, r.db, gramOrgID, currentRole.Slug, roleGrantPayloads(payload.Grants)); err != nil { + syncedGrants, err := authz.SyncGrantsTx(ctx, tx, gramOrgID, currentRole.Slug, roleGrantPayloads(payload.Grants)) + if err != nil { return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "sync grants for updated role").Log(ctx, r.logger) } + updatedGrants = make([]*gen.RoleGrant, 0, len(syncedGrants)) + for _, grant := range syncedGrants { + updatedGrants = append(updatedGrants, scopedGrantToGenRoleGrant(grant)) + } } - r.syncWorkOS(ctx, "update role in workos", func() error { - _, err := r.roles.UpdateRole(ctx, workosOrgID, currentRole.Slug, workos.UpdateRoleOpts{ - Name: payload.Name, - Description: payload.Description, + workosSyncs = append(workosSyncs, func(ctx context.Context) { + r.syncWorkOS(ctx, "update role in workos", func() error { + _, err := r.roles.UpdateRole(ctx, workosOrgID, currentRole.Slug, workos.UpdateRoleOpts{ + Name: payload.Name, + Description: payload.Description, + }) + if err == nil { + return nil + } + return fmt.Errorf("update role in workos: %w", err) }) - if err == nil { - return nil - } - return fmt.Errorf("update role in workos: %w", err) }) } if payload.MemberIds != nil { - if _, err := r.assignMembersToRole(ctx, gramOrgID, currentRole.Slug, payload.MemberIds); err != nil { + var memberSyncs []workosSync + if _, memberSyncs, err = r.assignMembersToRoleTx(ctx, tx, gramOrgID, currentRole.Slug, payload.MemberIds); err != nil { return roleUpdateResult{}, err } - updatedRole, err = r.getLocalRoleByID(ctx, gramOrgID, payload.ID) + workosSyncs = append(workosSyncs, memberSyncs...) + updatedRole, err = r.getLocalRoleByIDTx(ctx, tx, gramOrgID, payload.ID) if err != nil { return roleUpdateResult{}, err } } - updatedRoleView, err := r.roleViewFromLocalRole(ctx, gramOrgID, updatedRole) - if err != nil { - return roleUpdateResult{}, err + updatedRoleView := roleViewFromLocalRoleAndGrants(updatedRole, existingRole.Grants) + if updatedGrants != nil { + updatedRoleView.Grants = updatedGrants + } + + if err := r.audit.LogAccessRoleUpdate(ctx, tx, audit.LogAccessRoleUpdateEvent{ + OrganizationID: gramOrgID, + Actor: actor.Principal, + ActorDisplayName: actor.DisplayName, + ActorSlug: actor.Slug, + RoleID: updatedRole.ID, + RoleName: updatedRoleView.Name, + RoleSlug: updatedRole.Slug, + RoleSnapshotBefore: existingRole, + RoleSnapshotAfter: updatedRoleView, + }); err != nil { + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "log access role update").Log(ctx, r.logger) + } + + if err := tx.Commit(ctx); err != nil { + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "commit role transaction").Log(ctx, r.logger) } + r.runWorkOSSyncs(ctx, workosSyncs) + return roleUpdateResult{Before: existingRole, After: updatedRoleView, Role: updatedRole}, nil } -// DeleteRole deletes a custom local role after moving assigned members to the default member role, then best-effort syncs WorkOS. -// Side effects: writes Postgres assignment/role/grant records, invalidates local authz state, and logs WorkOS sync failures after bounded retries. -func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, roleID string) (localRole, error) { +type roleDeleteResult struct { + Role localRole +} + +// DeleteRole deletes a custom local role, reassignment records, grants, and audit entry atomically, then best-effort syncs WorkOS after commit. +// Side effects: writes Postgres assignment/role/grant/audit records and logs WorkOS sync failures after bounded retries. +func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, roleID string, actor accessAuditActor) (roleDeleteResult, error) { currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, roleID) if err != nil { - return localRole{}, err + return roleDeleteResult{}, err } if isSystemRole(currentRole.Slug) { - return localRole{}, oops.E(oops.CodeBadRequest, nil, "system roles cannot be deleted").Log(ctx, r.logger) + return roleDeleteResult{}, oops.E(oops.CodeBadRequest, nil, "system roles cannot be deleted").Log(ctx, r.logger) + } + + tx, err := r.db.Begin(ctx) + if err != nil { + return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "begin role transaction").Log(ctx, r.logger) } + defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) - rows, err := repo.New(r.db).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) + rows, err := repo.New(tx).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) if err != nil { - return localRole{}, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) + return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + var workosSyncs []workosSync for _, row := range rows { if row.RoleSlug != currentRole.Slug { continue } membershipID := conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) if row.WorkosUserID != "" { - replaced, err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, row.WorkosUserID, authz.SystemRoleMember, "", membershipID)) + replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, row.WorkosUserID, authz.SystemRoleMember, "", membershipID)) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - return localRole{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) + return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) } if replaced == 0 { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - return localRole{}, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) + return roleDeleteResult{}, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) } } - if userID := conv.FromPGTextOrEmpty[string](row.UserID); userID != "" { - r.authz.InvalidateRoleCache(ctx, userID, gramOrgID) - } - r.syncWorkOS(ctx, "reassign member to default role in workos", func() error { - _, err := r.roles.UpdateMemberRole(ctx, membershipID, authz.SystemRoleMember) - if err == nil { - return nil - } - return fmt.Errorf("reassign member to default role in workos: %w", err) + workosSyncs = append(workosSyncs, func(ctx context.Context) { + r.syncWorkOS(ctx, "reassign member to default role in workos", func() error { + _, err := r.roles.UpdateMemberRole(ctx, membershipID, authz.SystemRoleMember) + if err == nil { + return nil + } + return fmt.Errorf("reassign member to default role in workos: %w", err) + }) }) } - if _, err := repo.New(r.db).MarkOrganizationRoleDeletedLocally(ctx, repo.MarkOrganizationRoleDeletedLocallyParams{ + if _, err := repo.New(tx).MarkOrganizationRoleDeletedLocally(ctx, repo.MarkOrganizationRoleDeletedLocallyParams{ OrganizationID: gramOrgID, WorkosSlug: currentRole.Slug, WorkosLastEventID: conv.ToPGTextEmpty(""), }); err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - return localRole{}, oops.E(oops.CodeUnexpected, err, "mark local role record deleted").Log(ctx, r.logger) + return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "mark local role record deleted").Log(ctx, r.logger) } - if _, err := repo.New(r.db).DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ + if _, err := repo.New(tx).DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ OrganizationID: gramOrgID, PrincipalUrn: urn.NewPrincipal(urn.PrincipalTypeRole, currentRole.Slug), }); err != nil { - return localRole{}, oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, r.logger) + return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, r.logger) } - r.authz.InvalidateAllRoleCaches(ctx, gramOrgID) - r.syncWorkOS(ctx, "delete role in workos", func() error { - if err := r.roles.DeleteRole(ctx, workosOrgID, currentRole.Slug); err != nil { - return fmt.Errorf("delete role in workos: %w", err) - } - return nil + if err := r.audit.LogAccessRoleDelete(ctx, tx, audit.LogAccessRoleDeleteEvent{ + OrganizationID: gramOrgID, + Actor: actor.Principal, + ActorDisplayName: actor.DisplayName, + ActorSlug: actor.Slug, + RoleID: currentRole.ID, + RoleName: currentRole.Name, + RoleSlug: currentRole.Slug, + }); err != nil { + return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "log access role deletion").Log(ctx, r.logger) + } + + if err := tx.Commit(ctx); err != nil { + return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "commit role transaction").Log(ctx, r.logger) + } + + workosSyncs = append(workosSyncs, func(ctx context.Context) { + r.syncWorkOS(ctx, "delete role in workos", func() error { + if err := r.roles.DeleteRole(ctx, workosOrgID, currentRole.Slug); err != nil { + return fmt.Errorf("delete role in workos: %w", err) + } + return nil + }) }) - return currentRole, nil + r.runWorkOSSyncs(ctx, workosSyncs) + + return roleDeleteResult{Role: currentRole}, nil } type memberRoleUpdateContext struct { @@ -380,15 +470,21 @@ type memberRoleUpdateContext struct { After *gen.AccessMember } -// UpdateMemberRole changes one member's local role assignment, then best-effort syncs WorkOS. -// Side effects: writes a Postgres assignment record, invalidates local authz state, and logs WorkOS sync failures after bounded retries. -func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, roleID string) (memberRoleUpdateContext, error) { - role, err := r.getLocalRoleByID(ctx, gramOrgID, roleID) +// UpdateMemberRole changes one member's local role assignment and audit entry atomically, then best-effort syncs WorkOS after commit. +// Side effects: writes a Postgres assignment/audit record and logs WorkOS sync failures after bounded retries. +func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, roleID string, actor accessAuditActor) (memberRoleUpdateContext, error) { + tx, err := r.db.Begin(ctx) + if err != nil { + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "begin role transaction").Log(ctx, r.logger) + } + defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) + + role, err := r.getLocalRoleByIDTx(ctx, tx, gramOrgID, roleID) if err != nil { return memberRoleUpdateContext{}, err } - connectedUser, err := connectedUser(ctx, r.db, gramOrgID, userID) + connectedUser, err := connectedUser(ctx, tx, gramOrgID, userID) switch { case errors.Is(err, errConnectedUserNotFound): return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member has not joined this organization").Log(ctx, r.logger) @@ -399,7 +495,7 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r return memberRoleUpdateContext{}, oops.E(oops.CodeBadRequest, nil, "member is not linked to WorkOS").Log(ctx, r.logger) } - roleRows, err := repo.New(r.db).ListActiveOrganizationRoles(ctx, gramOrgID) + roleRows, err := repo.New(tx).ListActiveOrganizationRoles(ctx, gramOrgID) if err != nil { return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "list roles").Log(ctx, r.logger) } @@ -409,7 +505,7 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r roleIDBySlug[row.WorkosSlug] = row.ID.String() } - assignmentRows, err := repo.New(r.db).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) + assignmentRows, err := repo.New(tx).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) if err != nil { return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) } @@ -432,7 +528,7 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r } if existing.WorkosUserID != "" && role.Slug != "" { - replaced, err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, existing.WorkosUserID, role.Slug, connectedUser.ID, existing.MembershipID)) + replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, existing.WorkosUserID, role.Slug, connectedUser.ID, existing.MembershipID)) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) @@ -442,17 +538,9 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) } } - r.authz.InvalidateRoleCache(ctx, userID, gramOrgID) - r.syncWorkOS(ctx, "update member role in workos", func() error { - _, err := r.roles.UpdateMemberRole(ctx, existing.MembershipID, role.Slug) - if err == nil { - return nil - } - return fmt.Errorf("update member role in workos: %w", err) - }) memberName := conv.Default(connectedUser.DisplayName, connectedUser.Email) - return memberRoleUpdateContext{ + result := memberRoleUpdateContext{ RoleSlug: role.Slug, MembershipID: existing.MembershipID, WorkosUserID: connectedUser.WorkosID.String, @@ -473,7 +561,39 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r RoleID: roleID, JoinedAt: existing.CreatedAt, }, - }, nil + } + + if err := r.audit.LogAccessMemberRoleUpdate(ctx, tx, audit.LogAccessMemberRoleUpdateEvent{ + OrganizationID: gramOrgID, + Actor: actor.Principal, + ActorDisplayName: actor.DisplayName, + ActorSlug: actor.Slug, + MemberID: result.UserID, + MemberName: result.After.Name, + MemberEmail: result.After.Email, + MemberSnapshotBefore: result.Before, + MemberSnapshotAfter: result.After, + }); err != nil { + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "log access member role update").Log(ctx, r.logger) + } + + if err := tx.Commit(ctx); err != nil { + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "commit role transaction").Log(ctx, r.logger) + } + + r.runWorkOSSyncs(ctx, []workosSync{ + func(ctx context.Context) { + r.syncWorkOS(ctx, "update member role in workos", func() error { + _, err := r.roles.UpdateMemberRole(ctx, existing.MembershipID, role.Slug) + if err == nil { + return nil + } + return fmt.Errorf("update member role in workos: %w", err) + }) + }, + }) + + return result, nil } // MemberRoleSlugs returns local role slugs assigned to a WorkOS user inside an organization. @@ -498,12 +618,16 @@ func (r *RoleManager) MemberRoleSlugs(ctx context.Context, gramOrgID, workosUser // getLocalRoleByID loads one local role record by Gram role ID. // Side effects: reads Postgres local role records; does not call WorkOS. func (r *RoleManager) getLocalRoleByID(ctx context.Context, gramOrgID, id string) (localRole, error) { + return r.getLocalRoleByIDTx(ctx, r.db, gramOrgID, id) +} + +func (r *RoleManager) getLocalRoleByIDTx(ctx context.Context, dbtx repo.DBTX, gramOrgID, id string) (localRole, error) { roleID, err := uuid.Parse(id) if err != nil { return localRole{}, oops.E(oops.CodeBadRequest, err, "invalid role ID").Log(ctx, r.logger) } - row, err := repo.New(r.db).GetOrganizationRoleByID(ctx, repo.GetOrganizationRoleByIDParams{ + row, err := repo.New(dbtx).GetOrganizationRoleByID(ctx, repo.GetOrganizationRoleByIDParams{ ID: roleID, OrganizationID: gramOrgID, }) @@ -518,10 +642,8 @@ func (r *RoleManager) getLocalRoleByID(ctx context.Context, gramOrgID, id string return localRoleFromRoleRow(row), nil } -// getLocalRoleBySlug loads one local organization role record by WorkOS slug. -// Side effects: reads Postgres local role records; does not call WorkOS. -func (r *RoleManager) getLocalRoleBySlug(ctx context.Context, gramOrgID, slug string) (localRole, error) { - row, err := repo.New(r.db).GetActiveOrganizationRoleBySlug(ctx, repo.GetActiveOrganizationRoleBySlugParams{ +func (r *RoleManager) getLocalRoleBySlugTx(ctx context.Context, dbtx repo.DBTX, gramOrgID, slug string) (localRole, error) { + row, err := repo.New(dbtx).GetActiveOrganizationRoleBySlug(ctx, repo.GetActiveOrganizationRoleBySlugParams{ OrganizationID: gramOrgID, WorkosSlug: slug, }) @@ -542,8 +664,7 @@ type memberAssignmentTarget struct { MembershipID string } -// memberAssignmentTargets resolves Gram user IDs to WorkOS membership IDs using local user and local assignment records. -func (r *RoleManager) memberAssignmentTargets(ctx context.Context, gramOrgID string, memberIDs []string) ([]memberAssignmentTarget, error) { +func (r *RoleManager) memberAssignmentTargetsTx(ctx context.Context, dbtx repo.DBTX, gramOrgID string, memberIDs []string) ([]memberAssignmentTarget, error) { if len(memberIDs) == 0 { return nil, nil } @@ -552,7 +673,7 @@ func (r *RoleManager) memberAssignmentTargets(ctx context.Context, gramOrgID str requested[id] = struct{}{} } - users, err := usersrepo.New(r.db).GetUsersByIDs(ctx, memberIDs) + users, err := usersrepo.New(dbtx).GetUsersByIDs(ctx, memberIDs) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "resolve users by ids").Log(ctx, r.logger) } @@ -567,7 +688,7 @@ func (r *RoleManager) memberAssignmentTargets(ctx context.Context, gramOrgID str } } - assignmentRows, err := repo.New(r.db).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) + assignmentRows, err := repo.New(dbtx).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) } @@ -614,63 +735,75 @@ func (r *RoleManager) memberAssignmentTargets(ctx context.Context, gramOrgID str return targets, nil } -// assignMembersToRole moves each requested member to the given local role and best-effort syncs WorkOS. -// Side effects: reads local users and assignments, writes Postgres assignment records, invalidates local authz state, and logs WorkOS sync failures after bounded retries. -func (r *RoleManager) assignMembersToRole(ctx context.Context, gramOrgID, roleSlug string, memberIDs []string) (int, error) { - targets, err := r.memberAssignmentTargets(ctx, gramOrgID, memberIDs) +func (r *RoleManager) assignMembersToRoleTx(ctx context.Context, dbtx repo.DBTX, gramOrgID, roleSlug string, memberIDs []string) (int, []workosSync, error) { + targets, err := r.memberAssignmentTargetsTx(ctx, dbtx, gramOrgID, memberIDs) if err != nil { - return 0, err + return 0, nil, err } assignedCount := 0 + workosSyncs := make([]workosSync, 0, len(targets)) for _, target := range targets { if target.WorkosUserID != "" && roleSlug != "" { - replaced, err := repo.New(r.db).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, target.WorkosUserID, roleSlug, target.UserID, target.MembershipID)) + replaced, err := repo.New(dbtx).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, target.WorkosUserID, roleSlug, target.UserID, target.MembershipID)) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - return 0, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) + return 0, nil, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) } if replaced == 0 { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - return 0, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) + return 0, nil, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) } } assignedCount++ - r.authz.InvalidateRoleCache(ctx, target.UserID, gramOrgID) - r.syncWorkOS(ctx, "assign member to role in workos", func() error { - _, err := r.roles.UpdateMemberRole(ctx, target.MembershipID, roleSlug) - if err == nil { - return nil - } - return fmt.Errorf("assign member to role in workos: %w", err) + membershipID := target.MembershipID + workosSyncs = append(workosSyncs, func(ctx context.Context) { + r.syncWorkOS(ctx, "assign member to role in workos", func() error { + _, err := r.roles.UpdateMemberRole(ctx, membershipID, roleSlug) + if err == nil { + return nil + } + return fmt.Errorf("assign member to role in workos: %w", err) + }) }) } - if assignedCount > 0 { - r.authz.InvalidateAllRoleCaches(ctx, gramOrgID) - } + return assignedCount, workosSyncs, nil +} - return assignedCount, nil +// runWorkOSSyncs starts best-effort WorkOS writes after the local transaction commits. +// Side effects: launches a goroutine that outlives request cancellation. +func (r *RoleManager) runWorkOSSyncs(ctx context.Context, syncs []workosSync) { + if len(syncs) == 0 { + return + } + syncCtx := context.WithoutCancel(ctx) + go func() { + for _, syncWorkOS := range syncs { + syncWorkOS(syncCtx) + } + }() } // syncWorkOS runs a bounded best-effort WorkOS write after the local database already accepted the change. // Side effects: calls WorkOS, waits briefly between retryable failures, and logs the final failure without returning it. func (r *RoleManager) syncWorkOS(ctx context.Context, operation string, fn func() error) { - var err error - for attempt := 1; attempt <= workOSSyncAttempts; attempt++ { - err = fn() + exp := backoff.NewExponentialBackOff() + exp.InitialInterval = 100 * time.Millisecond + exp.MaxInterval = 300 * time.Millisecond + exp.RandomizationFactor = 0 + + _, err := backoff.Retry(ctx, func() (struct{}, error) { + err := fn() if err == nil { - return - } - if !retryWorkOSError(err) || attempt == workOSSyncAttempts { - break + return struct{}{}, nil } - - select { - case <-ctx.Done(): - err = ctx.Err() - attempt = workOSSyncAttempts - case <-time.After(time.Duration(attempt) * 100 * time.Millisecond): + if !retryWorkOSError(err) { + return struct{}{}, backoff.Permanent(err) } + return struct{}{}, err + }, backoff.WithBackOff(exp), backoff.WithMaxTries(workOSSyncAttempts)) + if err == nil { + return } r.logger.ErrorContext(ctx, "workos sync failed: "+operation, attr.SlogError(err)) @@ -710,6 +843,19 @@ func (r *RoleManager) roleViewFromLocalRole(ctx context.Context, organizationID }, nil } +func roleViewFromLocalRoleAndGrants(role localRole, grants []*gen.RoleGrant) *gen.Role { + return &gen.Role{ + ID: role.ID, + Name: role.Name, + Description: role.Description, + IsSystem: isSystemRole(role.Slug), + Grants: grants, + MemberCount: role.MemberCount, + CreatedAt: conv.Default(role.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), + UpdatedAt: conv.Default(role.UpdatedAt, time.Time{}.UTC().Format(time.RFC3339)), + } +} + // localRoleFromActiveRow converts a sqlc active-role row into the manager's internal local role record shape. // Side effects: none. func localRoleFromActiveRow(row repo.ListActiveOrganizationRolesRow) localRole { @@ -738,6 +884,18 @@ func localRoleFromRoleRow(row repo.GetOrganizationRoleByIDRow) localRole { } } +func localRoleFromUpsertRow(row repo.UpsertOrganizationRoleRow) localRole { + return localRole{ + ID: row.ID.String(), + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), + } +} + // localRoleFromSlugRow converts a sqlc role slug lookup row into the manager's internal local role record shape. // Side effects: none. func localRoleFromSlugRow(row repo.GetActiveOrganizationRoleBySlugRow) localRole { diff --git a/server/internal/access/role_manager_test.go b/server/internal/access/role_manager_test.go index b4fe9d343a..804fd6d4b3 100644 --- a/server/internal/access/role_manager_test.go +++ b/server/internal/access/role_manager_test.go @@ -98,7 +98,7 @@ func TestRoleManager_AssignMembersToRoleRequiresLocalAssignment(t *testing.T) { seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "u1@example.com", "User 1", "user_1", "membership_1") - assigned, err := ti.service.roleMgr.assignMembersToRole(ctx, authCtx.ActiveOrganizationID, "custom-builder", []string{"local_user_1"}) + assigned, _, err := ti.service.roleMgr.assignMembersToRoleTx(ctx, ti.conn, authCtx.ActiveOrganizationID, "custom-builder", []string{"local_user_1"}) require.Error(t, err) require.Equal(t, 0, assigned) require.Contains(t, err.Error(), "member role assignment not found") diff --git a/server/internal/access/setup_test.go b/server/internal/access/setup_test.go index dd421191d9..a79cc75379 100644 --- a/server/internal/access/setup_test.go +++ b/server/internal/access/setup_test.go @@ -96,7 +96,7 @@ func newTestAccessService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) - svc := NewService(logger, tracerProvider, conn, chConn, sessionManager, NewRoleManager(logger, conn, roles, authzEngine), authzEngine, noopFeatureCacheWriter{}, auditLogger) + svc := NewService(logger, tracerProvider, conn, chConn, sessionManager, NewRoleManager(logger, conn, roles, auditLogger), authzEngine, noopFeatureCacheWriter{}, auditLogger) return ctx, &testInstance{ service: svc, diff --git a/server/internal/authz/engine.go b/server/internal/authz/engine.go index db211a25f2..b8c4e639c7 100644 --- a/server/internal/authz/engine.go +++ b/server/internal/authz/engine.go @@ -122,18 +122,18 @@ func (e *Engine) PrepareContext(ctx context.Context) (context.Context, error) { principals := []urn.Principal{urn.NewPrincipal(urn.PrincipalTypeUser, authCtx.UserID)} - roleSlug, err := e.resolveRoleSlug(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) + roleSlugs, err := e.resolveRoleSlugs(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) if err != nil { e.logger.ErrorContext( ctx, - "failed to resolve role for authz grants", + "failed to resolve roles for authz grants", attr.SlogOrganizationID(authCtx.ActiveOrganizationID), attr.SlogUserID(authCtx.UserID), attr.SlogError(err), ) - return ctx, fmt.Errorf("resolve role slug: %w", err) + return ctx, fmt.Errorf("resolve role slugs: %w", err) } - if roleSlug != "" { + for _, roleSlug := range roleSlugs { principals = append(principals, urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug)) } @@ -152,37 +152,27 @@ func (e *Engine) PrepareContext(ctx context.Context) (context.Context, error) { return GrantsToContext(ctx, grants), nil } -func (e *Engine) resolveRoleSlug(ctx context.Context, userID, orgID string) (string, error) { +func (e *Engine) resolveRoleSlugs(ctx context.Context, userID, orgID string) ([]string, error) { user, err := usersrepo.New(e.db).GetUser(ctx, userID) if err != nil { - return "", fmt.Errorf("get user: %w", err) + return nil, fmt.Errorf("get user: %w", err) } if !user.WorkosID.Valid || user.WorkosID.String == "" { - return "", nil + return nil, nil } + // Role assignments are local source-of-truth records. They are written by + // the access write path and by WorkOS sync, so invitation/admin-console + // changes can lag until the sync job catches up. roleSlugs, err := accessrepo.New(e.db).ListMemberRoleSlugsByWorkosUser(ctx, accessrepo.ListMemberRoleSlugsByWorkosUserParams{ OrganizationID: orgID, WorkosUserID: user.WorkosID.String, }) if err != nil { - return "", fmt.Errorf("list member role slugs: %w", err) + return nil, fmt.Errorf("list member role slugs: %w", err) } - if len(roleSlugs) == 0 { - return "", nil - } - - return roleSlugs[0], nil -} - -// InvalidateRoleCache is retained for callers that used to clear the Redis role cache. -// Role resolution now reads Postgres directly, so this is intentionally a no-op. -func (e *Engine) InvalidateRoleCache(ctx context.Context, userID, orgID string) { -} -// InvalidateAllRoleCaches is retained for callers that used to clear the Redis role cache. -// Role resolution now reads Postgres directly, so this is intentionally a no-op. -func (e *Engine) InvalidateAllRoleCaches(ctx context.Context, orgID string) { + return roleSlugs, nil } func (e *Engine) Require(ctx context.Context, checks ...Check) error { diff --git a/server/internal/authz/engine_test.go b/server/internal/authz/engine_test.go index 5634a65003..25bd68f312 100644 --- a/server/internal/authz/engine_test.go +++ b/server/internal/authz/engine_test.go @@ -99,7 +99,7 @@ func TestEngineRequire_returnsUnexpectedWhenFeatureCheckFails(t *testing.T) { require.Equal(t, oops.CodeUnexpected, oopsErr.Code) } -func TestResolveRoleSlug_readsLocalAssignmentsOnly(t *testing.T) { +func TestResolveRoleSlugs_readsLocalAssignmentsOnly(t *testing.T) { t.Parallel() ctx := enterpriseTestCtx(t.Context()) @@ -116,16 +116,41 @@ func TestResolveRoleSlug_readsLocalAssignmentsOnly(t *testing.T) { require.NoError(t, err) engine := NewEngine(testenv.NewLogger(t), conn, chConn, staticRBAC(true), staticChallengeLogging(true), membership) - roleSlug, err := engine.resolveRoleSlug(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) + roleSlugs, err := engine.resolveRoleSlugs(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) require.NoError(t, err) - require.Empty(t, roleSlug) + require.Empty(t, roleSlugs) - roleSlug, err = engine.resolveRoleSlug(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) + roleSlugs, err = engine.resolveRoleSlugs(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) require.NoError(t, err) - require.Empty(t, roleSlug) + require.Empty(t, roleSlugs) require.Equal(t, 0, membership.calls) } +func TestResolveRoleSlugs_returnsAllLocalAssignments(t *testing.T) { + t.Parallel() + + ctx := enterpriseTestCtx(t.Context()) + conn := newTestDB(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + require.NotNil(t, authCtx) + + seedOrganization(t, ctx, conn, authCtx.ActiveOrganizationID) + seedConnectedUser(t, ctx, conn, authCtx.ActiveOrganizationID, authCtx.UserID, "test@example.com", "Test User", "user_workos_test", "membership_test") + seedRole(t, ctx, conn, authCtx.ActiveOrganizationID, "custom-alpha") + seedRole(t, ctx, conn, authCtx.ActiveOrganizationID, "custom-beta") + seedRoleAssignment(t, ctx, conn, authCtx.ActiveOrganizationID, authCtx.UserID, "user_workos_test", "membership_test", "custom-alpha") + seedRoleAssignment(t, ctx, conn, authCtx.ActiveOrganizationID, authCtx.UserID, "user_workos_test", "membership_test", "custom-beta") + + chConn, err := newClickhouseClient(t) + require.NoError(t, err) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, staticRBAC(true), staticChallengeLogging(true), &countingMembershipFetcher{}) + + roleSlugs, err := engine.resolveRoleSlugs(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) + require.NoError(t, err) + require.ElementsMatch(t, []string{"custom-alpha", "custom-beta"}, roleSlugs) +} + func TestEngineRequireAny_mapsDeniedToForbidden(t *testing.T) { t.Parallel() diff --git a/server/internal/authz/grants.go b/server/internal/authz/grants.go index c0546f805b..092f051f1d 100644 --- a/server/internal/authz/grants.go +++ b/server/internal/authz/grants.go @@ -77,21 +77,37 @@ func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgI return fmt.Errorf("organization id is required") } - principalURN := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) - tx, err := db.Begin(ctx) if err != nil { return fmt.Errorf("begin grant sync transaction: %w", err) } defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) - q := repo.New(tx) + if _, err := SyncGrantsTx(ctx, tx, orgID, roleSlug, grants); err != nil { + return err + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("commit grant sync transaction: %w", err) + } + + return nil +} + +func SyncGrantsTx(ctx context.Context, dbtx repo.DBTX, orgID string, roleSlug string, grants []*RoleGrant) ([]*ScopedGrant, error) { + if orgID == "" { + return nil, fmt.Errorf("organization id is required") + } + + principalURN := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) + + q := repo.New(dbtx) if _, err := q.DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ OrganizationID: orgID, PrincipalUrn: principalURN, }); err != nil { - return fmt.Errorf("delete grants for role %q: %w", roleSlug, err) + return nil, fmt.Errorf("delete grants for role %q: %w", roleSlug, err) } for _, grant := range grants { @@ -107,7 +123,7 @@ func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgI sel := NewSelector(scope, WildcardResource) selBytes, err := sel.MarshalJSON() if err != nil { - return fmt.Errorf("marshal wildcard selector for %q: %w", grant.Scope, err) + return nil, fmt.Errorf("marshal wildcard selector for %q: %w", grant.Scope, err) } if _, err := q.UpsertPrincipalGrant(ctx, repo.UpsertPrincipalGrantParams{ OrganizationID: orgID, @@ -115,19 +131,19 @@ func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgI Scope: grant.Scope, Selectors: selBytes, }); err != nil { - return fmt.Errorf("upsert unrestricted grant %q for role %q: %w", grant.Scope, roleSlug, err) + return nil, fmt.Errorf("upsert unrestricted grant %q for role %q: %w", grant.Scope, roleSlug, err) } continue } for _, sel := range grant.Selectors { if err := ValidateSelector(scope, sel); err != nil { - return fmt.Errorf("invalid selector for scope %q: %w", grant.Scope, err) + return nil, fmt.Errorf("invalid selector for scope %q: %w", grant.Scope, err) } selBytes, err := sel.MarshalJSON() if err != nil { - return fmt.Errorf("marshal selector for scope %q: %w", grant.Scope, err) + return nil, fmt.Errorf("marshal selector for scope %q: %w", grant.Scope, err) } if _, err := q.UpsertPrincipalGrant(ctx, repo.UpsertPrincipalGrantParams{ OrganizationID: orgID, @@ -135,16 +151,25 @@ func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgI Scope: grant.Scope, Selectors: selBytes, }); err != nil { - return fmt.Errorf("upsert grant %q for role %q: %w", grant.Scope, roleSlug, err) + return nil, fmt.Errorf("upsert grant %q for role %q: %w", grant.Scope, roleSlug, err) } } } - if err := tx.Commit(ctx); err != nil { - return fmt.Errorf("commit grant sync transaction: %w", err) + rows, err := q.ListPrincipalGrantsByOrg(ctx, repo.ListPrincipalGrantsByOrgParams{ + OrganizationID: orgID, + PrincipalUrn: principalURN.String(), + }) + if err != nil { + return nil, fmt.Errorf("list synced grants for role %q: %w", roleSlug, err) } - return nil + scoped, err := scopedGrantsFromRows(principalURN.String(), rows) + if err != nil { + return nil, fmt.Errorf("load synced grants for role %q: %w", roleSlug, err) + } + + return scoped, nil } func GrantsForRole(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgID string, roleSlug string) ([]*ScopedGrant, error) { @@ -156,13 +181,20 @@ func GrantsForRole(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, o return nil, oops.E(oops.CodeUnexpected, err, "list grants for role").Log(ctx, logger) } - rolePrincipalURN := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug).String() + scoped, err := scopedGrantsFromRows(urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug).String(), rows) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "unmarshal grant selector").Log(ctx, logger) + } + + return scoped, nil +} +func scopedGrantsFromRows(rolePrincipalURN string, rows []repo.ListPrincipalGrantsByOrgRow) ([]*ScopedGrant, error) { grantRows := make([]Grant, 0, len(rows)) for _, row := range rows { selectors, err := SelectorFromRow(row.Selectors) if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "unmarshal grant selector").Log(ctx, logger) + return nil, err } grantRows = append(grantRows, Grant{ PrincipalUrn: rolePrincipalURN, diff --git a/server/internal/authz/setup_test.go b/server/internal/authz/setup_test.go index 0bcd2d2e22..2c81b8e5b4 100644 --- a/server/internal/authz/setup_test.go +++ b/server/internal/authz/setup_test.go @@ -5,6 +5,7 @@ import ( "log" "os" "testing" + "time" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" @@ -145,3 +146,38 @@ func seedConnectedUser(t *testing.T, ctx context.Context, conn *pgxpool.Pool, or }) require.NoError(t, err) } + +func seedRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID string, slug string) string { + t.Helper() + + now := time.Now().UTC() + role, err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + OrganizationID: organizationID, + WorkosSlug: slug, + WorkosName: slug, + WorkosDescription: conv.ToPGTextEmpty(""), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + + return "role:organization:" + role.ID.String() +} + +func seedRoleAssignment(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID string, userID string, workosUserID string, workosMembershipID string, roleSlug string) { + t.Helper() + + seedRole(t, ctx, conn, organizationID, roleSlug) + rows, err := accessrepo.New(conn).UpsertOrganizationRoleAssignment(ctx, accessrepo.UpsertOrganizationRoleAssignmentParams{ + OrganizationID: organizationID, + WorkosUserID: workosUserID, + UserID: conv.ToPGTextEmpty(userID), + WorkosMembershipID: conv.ToPGTextEmpty(workosMembershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosLastEventID: conv.ToPGTextEmpty(""), + WorkosRoleSlug: roleSlug, + }) + require.NoError(t, err) + require.Equal(t, int64(1), rows) +} From 7ec1a533fae5682dafba89d68582cd8ab118884f Mon Sep 17 00:00:00 2001 From: tgmendes Date: Mon, 18 May 2026 16:01:09 +0100 Subject: [PATCH 07/12] chore: use local role principals for access management --- server/internal/access/createrole_test.go | 2 +- server/internal/access/impl.go | 12 +- server/internal/access/queries.sql | 83 ++++++- server/internal/access/repo/queries.sql.go | 219 +++++++++++++++++- server/internal/access/role_manager.go | 207 ++++++++--------- server/internal/access/role_manager_test.go | 6 +- server/internal/access/setup_internal_test.go | 21 ++ server/internal/access/syncgrants_test.go | 27 ++- server/internal/access/updaterole_test.go | 2 +- server/internal/authz/engine.go | 16 +- server/internal/authz/engine_test.go | 22 +- server/internal/authz/grants.go | 115 +++++++-- .../activities/process_workos_org_events.go | 7 +- 13 files changed, 565 insertions(+), 174 deletions(-) diff --git a/server/internal/access/createrole_test.go b/server/internal/access/createrole_test.go index 626d7ef153..079e50ff7d 100644 --- a/server/internal/access/createrole_test.go +++ b/server/internal/access/createrole_test.go @@ -84,7 +84,7 @@ func TestService_CreateRole(t *testing.T) { require.NoError(t, err) require.Equal(t, role.ID, roundtrip.ID) - grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "org-custom-builder")) + grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "organization:"+role.ID)) require.Len(t, grants, 3) } diff --git a/server/internal/access/impl.go b/server/internal/access/impl.go index 5437f0df01..74fdebca85 100644 --- a/server/internal/access/impl.go +++ b/server/internal/access/impl.go @@ -316,12 +316,18 @@ func (s *Service) ListGrants(ctx context.Context, _ *gen.ListGrantsPayload) (*ge } principals := []urn.Principal{urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID)} - roleSlugs, err := s.roleMgr.MemberRoleSlugs(ctx, ac.ActiveOrganizationID, connectedUser.WorkosID.String) + rolePrincipals, err := s.roleMgr.MemberRolePrincipals(ctx, ac.ActiveOrganizationID, connectedUser.WorkosID.String) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "list member roles").Log(ctx, logger) } - for _, roleSlug := range roleSlugs { - principals = append(principals, urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug)) + roleSlugs := make([]string, 0, len(rolePrincipals)) + for _, role := range rolePrincipals { + rolePrincipalURNs, err := authz.RolePrincipals(role.RoleSlug, role.PrincipalUrn) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "build role principals").Log(ctx, logger) + } + principals = append(principals, rolePrincipalURNs...) + roleSlugs = append(roleSlugs, role.RoleSlug) } if len(roleSlugs) == 1 { logger = logger.With(attr.SlogAccessRoleSlug(roleSlugs[0])) diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index 3cf2b4286d..9b0ee4cdb2 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -150,6 +150,7 @@ RETURNING ) SELECT upserted.id, + ('role:organization:' || upserted.id::text)::text AS role_urn, upserted.workos_slug, upserted.workos_name, upserted.workos_description, @@ -197,6 +198,7 @@ WITH active_roles AS ( ) SELECT active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, @@ -208,7 +210,7 @@ LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = @organization_id AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL -GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at ORDER BY active_roles.workos_slug; -- name: GetActiveOrganizationRoleBySlug :one @@ -228,6 +230,7 @@ WITH active_roles AS ( ) SELECT active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, @@ -239,7 +242,7 @@ LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = @organization_id AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL -GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at LIMIT 1; -- name: GetOrganizationRoleByID :one @@ -259,6 +262,7 @@ UNION ALL ) SELECT active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, @@ -270,7 +274,7 @@ LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = @organization_id AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL -GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at LIMIT 1; -- name: ListOrganizationRoleAssignmentsForOrg :many @@ -294,6 +298,73 @@ WHERE ora.organization_id = @organization_id AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL ORDER BY ora.workos_user_id, role_slug; +-- name: ListOrganizationRoleAssignmentsBySlug :many +SELECT + ora.user_id, + ora.workos_user_id, + ora.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.created_at +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = @organization_id + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) = @workos_role_slug +ORDER BY ora.workos_user_id; + +-- name: GetOrganizationRoleAssignmentByWorkosUser :one +SELECT + ora.user_id, + ora.workos_user_id, + ora.workos_membership_id, + COALESCE(organization_roles.id, global_roles.id)::uuid AS role_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.created_at +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = @organization_id + AND ora.workos_user_id = @workos_user_id + AND COALESCE(organization_roles.id, global_roles.id) IS NOT NULL +ORDER BY ora.created_at +LIMIT 1; + +-- name: ListOrganizationRoleAssignmentsByWorkosUsers :many +SELECT + ora.user_id, + ora.workos_user_id, + ora.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.created_at +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = @organization_id + AND ora.workos_user_id = ANY(@workos_user_ids::text[]) + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL +ORDER BY ora.workos_user_id, role_slug; + -- name: ListAccessMembers :many SELECT users.id, @@ -318,8 +389,10 @@ WHERE ora.organization_id = @organization_id AND COALESCE(organization_roles.id, global_roles.id) IS NOT NULL ORDER BY users.email, users.id; --- name: ListMemberRoleSlugsByWorkosUser :many -SELECT DISTINCT COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug +-- name: ListMemberRolePrincipalsByWorkosUser :many +SELECT + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.role_urn::text AS principal_urn FROM organization_role_assignments AS ora LEFT JOIN organization_roles ON ora.role_urn = 'role:organization:' || organization_roles.id::text diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index a01b124d8a..c00cdc1a38 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -71,6 +71,7 @@ WITH active_roles AS ( ) SELECT active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, @@ -82,7 +83,7 @@ LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = $1 AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL -GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at LIMIT 1 ` @@ -93,6 +94,7 @@ type GetActiveOrganizationRoleBySlugParams struct { type GetActiveOrganizationRoleBySlugRow struct { ID uuid.UUID + RoleUrn string WorkosSlug string WorkosName string WorkosDescription pgtype.Text @@ -106,6 +108,7 @@ func (q *Queries) GetActiveOrganizationRoleBySlug(ctx context.Context, arg GetAc var i GetActiveOrganizationRoleBySlugRow err := row.Scan( &i.ID, + &i.RoleUrn, &i.WorkosSlug, &i.WorkosName, &i.WorkosDescription, @@ -143,6 +146,59 @@ func (q *Queries) GetGlobalRoleBySlug(ctx context.Context, workosSlug string) (G return i, err } +const getOrganizationRoleAssignmentByWorkosUser = `-- name: GetOrganizationRoleAssignmentByWorkosUser :one +SELECT + ora.user_id, + ora.workos_user_id, + ora.workos_membership_id, + COALESCE(organization_roles.id, global_roles.id)::uuid AS role_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.created_at +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = $1 + AND ora.workos_user_id = $2 + AND COALESCE(organization_roles.id, global_roles.id) IS NOT NULL +ORDER BY ora.created_at +LIMIT 1 +` + +type GetOrganizationRoleAssignmentByWorkosUserParams struct { + OrganizationID string + WorkosUserID string +} + +type GetOrganizationRoleAssignmentByWorkosUserRow struct { + UserID pgtype.Text + WorkosUserID string + WorkosMembershipID pgtype.Text + RoleID uuid.UUID + RoleSlug string + CreatedAt pgtype.Timestamptz +} + +func (q *Queries) GetOrganizationRoleAssignmentByWorkosUser(ctx context.Context, arg GetOrganizationRoleAssignmentByWorkosUserParams) (GetOrganizationRoleAssignmentByWorkosUserRow, error) { + row := q.db.QueryRow(ctx, getOrganizationRoleAssignmentByWorkosUser, arg.OrganizationID, arg.WorkosUserID) + var i GetOrganizationRoleAssignmentByWorkosUserRow + err := row.Scan( + &i.UserID, + &i.WorkosUserID, + &i.WorkosMembershipID, + &i.RoleID, + &i.RoleSlug, + &i.CreatedAt, + ) + return i, err +} + const getOrganizationRoleByID = `-- name: GetOrganizationRoleByID :one WITH active_roles AS ( SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind @@ -160,6 +216,7 @@ UNION ALL ) SELECT active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, @@ -171,7 +228,7 @@ LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = $1 AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL -GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at LIMIT 1 ` @@ -182,6 +239,7 @@ type GetOrganizationRoleByIDParams struct { type GetOrganizationRoleByIDRow struct { ID uuid.UUID + RoleUrn string WorkosSlug string WorkosName string WorkosDescription pgtype.Text @@ -195,6 +253,7 @@ func (q *Queries) GetOrganizationRoleByID(ctx context.Context, arg GetOrganizati var i GetOrganizationRoleByIDRow err := row.Scan( &i.ID, + &i.RoleUrn, &i.WorkosSlug, &i.WorkosName, &i.WorkosDescription, @@ -423,6 +482,7 @@ WITH active_roles AS ( ) SELECT active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, @@ -434,12 +494,13 @@ LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = $1 AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL -GROUP BY active_roles.id, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at ORDER BY active_roles.workos_slug ` type ListActiveOrganizationRolesRow struct { ID uuid.UUID + RoleUrn string WorkosSlug string WorkosName string WorkosDescription pgtype.Text @@ -459,6 +520,7 @@ func (q *Queries) ListActiveOrganizationRoles(ctx context.Context, organizationI var i ListActiveOrganizationRolesRow if err := rows.Scan( &i.ID, + &i.RoleUrn, &i.WorkosSlug, &i.WorkosName, &i.WorkosDescription, @@ -523,8 +585,10 @@ func (q *Queries) ListChallengeResolutions(ctx context.Context, arg ListChalleng return items, nil } -const listMemberRoleSlugsByWorkosUser = `-- name: ListMemberRoleSlugsByWorkosUser :many -SELECT DISTINCT COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug +const listMemberRolePrincipalsByWorkosUser = `-- name: ListMemberRolePrincipalsByWorkosUser :many +SELECT + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.role_urn::text AS principal_urn FROM organization_role_assignments AS ora LEFT JOIN organization_roles ON ora.role_urn = 'role:organization:' || organization_roles.id::text @@ -541,24 +605,152 @@ WHERE ora.organization_id = $1 ORDER BY role_slug ` -type ListMemberRoleSlugsByWorkosUserParams struct { +type ListMemberRolePrincipalsByWorkosUserParams struct { OrganizationID string WorkosUserID string } -func (q *Queries) ListMemberRoleSlugsByWorkosUser(ctx context.Context, arg ListMemberRoleSlugsByWorkosUserParams) ([]string, error) { - rows, err := q.db.Query(ctx, listMemberRoleSlugsByWorkosUser, arg.OrganizationID, arg.WorkosUserID) +type ListMemberRolePrincipalsByWorkosUserRow struct { + RoleSlug string + PrincipalUrn string +} + +func (q *Queries) ListMemberRolePrincipalsByWorkosUser(ctx context.Context, arg ListMemberRolePrincipalsByWorkosUserParams) ([]ListMemberRolePrincipalsByWorkosUserRow, error) { + rows, err := q.db.Query(ctx, listMemberRolePrincipalsByWorkosUser, arg.OrganizationID, arg.WorkosUserID) if err != nil { return nil, err } defer rows.Close() - var items []string + var items []ListMemberRolePrincipalsByWorkosUserRow for rows.Next() { - var role_slug string - if err := rows.Scan(&role_slug); err != nil { + var i ListMemberRolePrincipalsByWorkosUserRow + if err := rows.Scan(&i.RoleSlug, &i.PrincipalUrn); err != nil { return nil, err } - items = append(items, role_slug) + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listOrganizationRoleAssignmentsBySlug = `-- name: ListOrganizationRoleAssignmentsBySlug :many +SELECT + ora.user_id, + ora.workos_user_id, + ora.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.created_at +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = $1 + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) = $2 +ORDER BY ora.workos_user_id +` + +type ListOrganizationRoleAssignmentsBySlugParams struct { + OrganizationID string + WorkosRoleSlug string +} + +type ListOrganizationRoleAssignmentsBySlugRow struct { + UserID pgtype.Text + WorkosUserID string + WorkosMembershipID pgtype.Text + RoleSlug string + CreatedAt pgtype.Timestamptz +} + +func (q *Queries) ListOrganizationRoleAssignmentsBySlug(ctx context.Context, arg ListOrganizationRoleAssignmentsBySlugParams) ([]ListOrganizationRoleAssignmentsBySlugRow, error) { + rows, err := q.db.Query(ctx, listOrganizationRoleAssignmentsBySlug, arg.OrganizationID, arg.WorkosRoleSlug) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListOrganizationRoleAssignmentsBySlugRow + for rows.Next() { + var i ListOrganizationRoleAssignmentsBySlugRow + if err := rows.Scan( + &i.UserID, + &i.WorkosUserID, + &i.WorkosMembershipID, + &i.RoleSlug, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listOrganizationRoleAssignmentsByWorkosUsers = `-- name: ListOrganizationRoleAssignmentsByWorkosUsers :many +SELECT + ora.user_id, + ora.workos_user_id, + ora.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.created_at +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = $1 + AND ora.workos_user_id = ANY($2::text[]) + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL +ORDER BY ora.workos_user_id, role_slug +` + +type ListOrganizationRoleAssignmentsByWorkosUsersParams struct { + OrganizationID string + WorkosUserIds []string +} + +type ListOrganizationRoleAssignmentsByWorkosUsersRow struct { + UserID pgtype.Text + WorkosUserID string + WorkosMembershipID pgtype.Text + RoleSlug string + CreatedAt pgtype.Timestamptz +} + +func (q *Queries) ListOrganizationRoleAssignmentsByWorkosUsers(ctx context.Context, arg ListOrganizationRoleAssignmentsByWorkosUsersParams) ([]ListOrganizationRoleAssignmentsByWorkosUsersRow, error) { + rows, err := q.db.Query(ctx, listOrganizationRoleAssignmentsByWorkosUsers, arg.OrganizationID, arg.WorkosUserIds) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListOrganizationRoleAssignmentsByWorkosUsersRow + for rows.Next() { + var i ListOrganizationRoleAssignmentsByWorkosUsersRow + if err := rows.Scan( + &i.UserID, + &i.WorkosUserID, + &i.WorkosMembershipID, + &i.RoleSlug, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) } if err := rows.Err(); err != nil { return nil, err @@ -924,6 +1116,7 @@ RETURNING ) SELECT upserted.id, + ('role:organization:' || upserted.id::text)::text AS role_urn, upserted.workos_slug, upserted.workos_name, upserted.workos_description, @@ -950,6 +1143,7 @@ type UpsertOrganizationRoleParams struct { type UpsertOrganizationRoleRow struct { ID uuid.UUID + RoleUrn string WorkosSlug string WorkosName string WorkosDescription pgtype.Text @@ -974,6 +1168,7 @@ func (q *Queries) UpsertOrganizationRole(ctx context.Context, arg UpsertOrganiza var i UpsertOrganizationRoleRow err := row.Scan( &i.ID, + &i.RoleUrn, &i.WorkosSlug, &i.WorkosName, &i.WorkosDescription, diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index fed74db866..b6c4157bcd 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -89,14 +89,6 @@ func (r *RoleManager) GetRoleByID(ctx context.Context, gramOrgID, id string) (*g return r.roleViewFromLocalRole(ctx, gramOrgID, role) } -type localRoleAssignment struct { - UserID string - WorkosUserID string - MembershipID string - RoleSlug string - CreatedAt string -} - // ListMembers returns locally known organization members with role IDs resolved from local role assignments. func (r *RoleManager) ListMembers(ctx context.Context, gramOrgID string) (*gen.ListMembersResult, error) { rows, err := repo.New(r.db).ListAccessMembers(ctx, gramOrgID) @@ -133,7 +125,6 @@ type accessAuditActor struct { } // CreateRole creates the local role, grants, optional assignments, and audit entry atomically, then best-effort syncs WorkOS after commit. -// Side effects: writes Postgres role/grant/assignment/audit records and logs WorkOS sync failures after bounded retries. func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID string, actor accessAuditActor, payload *gen.CreateRolePayload) (roleCreateResult, error) { roleSlug, err := slugify(payload.Name) if err != nil { @@ -162,7 +153,7 @@ func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID str createdRole := localRoleFromUpsertRow(createdRow) trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleID(createdRole.ID)) - if _, err := authz.SyncGrantsTx(ctx, tx, gramOrgID, roleSlug, roleGrantPayloads(payload.Grants)); err != nil { + if _, err := authz.SyncGrantsTx(ctx, tx, gramOrgID, roleSlug, createdRole.PrincipalURN, roleGrantPayloads(payload.Grants)); err != nil { return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "sync grants for created role").Log(ctx, r.logger) } @@ -223,13 +214,14 @@ func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID str } type localRole struct { - ID string - Name string - Slug string - Description string - CreatedAt string - UpdatedAt string - MemberCount int + ID string + PrincipalURN string + Name string + Slug string + Description string + CreatedAt string + UpdatedAt string + MemberCount int } type roleUpdateResult struct { @@ -239,7 +231,6 @@ type roleUpdateResult struct { } // UpdateRole updates an existing local role, optional grants/assignments, and audit entry atomically, then best-effort syncs WorkOS after commit. -// Side effects: writes Postgres role/grant/assignment/audit records and logs WorkOS sync failures after bounded retries. func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID string, actor accessAuditActor, payload *gen.UpdateRolePayload) (roleUpdateResult, error) { currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, payload.ID) if err != nil { @@ -296,7 +287,7 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str updatedRole = localRoleFromUpsertRow(updatedRow) if payload.Grants != nil { - syncedGrants, err := authz.SyncGrantsTx(ctx, tx, gramOrgID, currentRole.Slug, roleGrantPayloads(payload.Grants)) + syncedGrants, err := authz.SyncGrantsTx(ctx, tx, gramOrgID, currentRole.Slug, currentRole.PrincipalURN, roleGrantPayloads(payload.Grants)) if err != nil { return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "sync grants for updated role").Log(ctx, r.logger) } @@ -365,7 +356,6 @@ type roleDeleteResult struct { } // DeleteRole deletes a custom local role, reassignment records, grants, and audit entry atomically, then best-effort syncs WorkOS after commit. -// Side effects: writes Postgres assignment/role/grant/audit records and logs WorkOS sync failures after bounded retries. func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, roleID string, actor accessAuditActor) (roleDeleteResult, error) { currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, roleID) if err != nil { @@ -381,7 +371,10 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro } defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) - rows, err := repo.New(tx).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) + rows, err := repo.New(tx).ListOrganizationRoleAssignmentsBySlug(ctx, repo.ListOrganizationRoleAssignmentsBySlugParams{ + OrganizationID: gramOrgID, + WorkosRoleSlug: currentRole.Slug, + }) if err != nil { return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) } @@ -389,9 +382,6 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro var workosSyncs []workosSync for _, row := range rows { - if row.RoleSlug != currentRole.Slug { - continue - } membershipID := conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) if row.WorkosUserID != "" { replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, row.WorkosUserID, authz.SystemRoleMember, "", membershipID)) @@ -424,10 +414,7 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "mark local role record deleted").Log(ctx, r.logger) } - if _, err := repo.New(tx).DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ - OrganizationID: gramOrgID, - PrincipalUrn: urn.NewPrincipal(urn.PrincipalTypeRole, currentRole.Slug), - }); err != nil { + if err := authz.DeleteRoleGrants(ctx, repo.New(tx), gramOrgID, currentRole.Slug, currentRole.PrincipalURN); err != nil { return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, r.logger) } @@ -471,7 +458,6 @@ type memberRoleUpdateContext struct { } // UpdateMemberRole changes one member's local role assignment and audit entry atomically, then best-effort syncs WorkOS after commit. -// Side effects: writes a Postgres assignment/audit record and logs WorkOS sync failures after bounded retries. func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, roleID string, actor accessAuditActor) (memberRoleUpdateContext, error) { tx, err := r.db.Begin(ctx) if err != nil { @@ -495,40 +481,26 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r return memberRoleUpdateContext{}, oops.E(oops.CodeBadRequest, nil, "member is not linked to WorkOS").Log(ctx, r.logger) } - roleRows, err := repo.New(tx).ListActiveOrganizationRoles(ctx, gramOrgID) - if err != nil { - return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "list roles").Log(ctx, r.logger) + existing, err := repo.New(tx).GetOrganizationRoleAssignmentByWorkosUser(ctx, repo.GetOrganizationRoleAssignmentByWorkosUserParams{ + OrganizationID: gramOrgID, + WorkosUserID: connectedUser.WorkosID.String, + }) + switch { + case errors.Is(err, pgx.ErrNoRows): + return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member not found").Log(ctx, r.logger) + case err != nil: + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "load member role assignment").Log(ctx, r.logger) } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) - roleIDBySlug := make(map[string]string, len(roleRows)) - for _, row := range roleRows { - roleIDBySlug[row.WorkosSlug] = row.ID.String() - } - assignmentRows, err := repo.New(tx).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) - if err != nil { - return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) - } - trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) - var existing localRoleAssignment - for _, row := range assignmentRows { - if row.WorkosUserID == connectedUser.WorkosID.String { - existing = localRoleAssignment{ - UserID: conv.FromPGTextOrEmpty[string](row.UserID), - WorkosUserID: row.WorkosUserID, - MembershipID: conv.FromPGTextOrEmpty[string](row.WorkosMembershipID), - RoleSlug: row.RoleSlug, - CreatedAt: conv.FromPGTimestamptz(row.CreatedAt), - } - break - } - } - if existing.MembershipID == "" { + membershipID := conv.FromPGTextOrEmpty[string](existing.WorkosMembershipID) + if membershipID == "" { + // WorkOS sync must attach membership IDs before role changes can be propagated upstream. return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member not found").Log(ctx, r.logger) } if existing.WorkosUserID != "" && role.Slug != "" { - replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, existing.WorkosUserID, role.Slug, connectedUser.ID, existing.MembershipID)) + replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, existing.WorkosUserID, role.Slug, connectedUser.ID, membershipID)) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) @@ -542,7 +514,7 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r memberName := conv.Default(connectedUser.DisplayName, connectedUser.Email) result := memberRoleUpdateContext{ RoleSlug: role.Slug, - MembershipID: existing.MembershipID, + MembershipID: membershipID, WorkosUserID: connectedUser.WorkosID.String, UserID: connectedUser.ID, Before: &gen.AccessMember{ @@ -550,16 +522,16 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r Name: memberName, Email: connectedUser.Email, PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), - RoleID: roleIDBySlug[existing.RoleSlug], - JoinedAt: existing.CreatedAt, + RoleID: existing.RoleID.String(), + JoinedAt: conv.FromPGTimestamptz(existing.CreatedAt), }, After: &gen.AccessMember{ ID: connectedUser.ID, Name: memberName, Email: connectedUser.Email, PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), - RoleID: roleID, - JoinedAt: existing.CreatedAt, + RoleID: role.ID, + JoinedAt: conv.FromPGTimestamptz(existing.CreatedAt), }, } @@ -584,7 +556,7 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r r.runWorkOSSyncs(ctx, []workosSync{ func(ctx context.Context) { r.syncWorkOS(ctx, "update member role in workos", func() error { - _, err := r.roles.UpdateMemberRole(ctx, existing.MembershipID, role.Slug) + _, err := r.roles.UpdateMemberRole(ctx, membershipID, role.Slug) if err == nil { return nil } @@ -596,14 +568,12 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r return result, nil } -// MemberRoleSlugs returns local role slugs assigned to a WorkOS user inside an organization. -// Side effects: reads Postgres local assignment records; does not call WorkOS. -func (r *RoleManager) MemberRoleSlugs(ctx context.Context, gramOrgID, workosUserID string) ([]string, error) { +func (r *RoleManager) MemberRolePrincipals(ctx context.Context, gramOrgID, workosUserID string) ([]repo.ListMemberRolePrincipalsByWorkosUserRow, error) { if workosUserID == "" { return nil, nil } - roleSlugs, err := repo.New(r.db).ListMemberRoleSlugsByWorkosUser(ctx, repo.ListMemberRoleSlugsByWorkosUserParams{ + rows, err := repo.New(r.db).ListMemberRolePrincipalsByWorkosUser(ctx, repo.ListMemberRolePrincipalsByWorkosUserParams{ OrganizationID: gramOrgID, WorkosUserID: workosUserID, }) @@ -612,11 +582,10 @@ func (r *RoleManager) MemberRoleSlugs(ctx context.Context, gramOrgID, workosUser } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) - return roleSlugs, nil + return rows, nil } // getLocalRoleByID loads one local role record by Gram role ID. -// Side effects: reads Postgres local role records; does not call WorkOS. func (r *RoleManager) getLocalRoleByID(ctx context.Context, gramOrgID, id string) (localRole, error) { return r.getLocalRoleByIDTx(ctx, r.db, gramOrgID, id) } @@ -687,8 +656,27 @@ func (r *RoleManager) memberAssignmentTargetsTx(ctx context.Context, dbtx repo.D } } } + workosIDs := make([]string, 0, len(memberIDs)) + seenRequestedWorkosID := make(map[string]struct{}, len(memberIDs)) + for _, id := range memberIDs { + workosID := id + if userWorkosID, ok := workosByGramID[id]; ok { + workosID = userWorkosID + } + if workosID == "" { + continue + } + if _, ok := seenRequestedWorkosID[workosID]; ok { + continue + } + seenRequestedWorkosID[workosID] = struct{}{} + workosIDs = append(workosIDs, workosID) + } - assignmentRows, err := repo.New(dbtx).ListOrganizationRoleAssignmentsForOrg(ctx, gramOrgID) + assignmentRows, err := repo.New(dbtx).ListOrganizationRoleAssignmentsByWorkosUsers(ctx, repo.ListOrganizationRoleAssignmentsByWorkosUsersParams{ + OrganizationID: gramOrgID, + WorkosUserIds: workosIDs, + }) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) } @@ -704,6 +692,9 @@ func (r *RoleManager) memberAssignmentTargetsTx(ctx context.Context, dbtx repo.D requestedID = userID } else if gramID, ok := requestedByWorkosID[row.WorkosUserID]; ok { requestedID = gramID + if userID == "" { + userID = gramID + } } else if _, ok := requested[row.WorkosUserID]; ok { requestedID = row.WorkosUserID } else { @@ -771,13 +762,17 @@ func (r *RoleManager) assignMembersToRoleTx(ctx context.Context, dbtx repo.DBTX, } // runWorkOSSyncs starts best-effort WorkOS writes after the local transaction commits. -// Side effects: launches a goroutine that outlives request cancellation. func (r *RoleManager) runWorkOSSyncs(ctx context.Context, syncs []workosSync) { if len(syncs) == 0 { return } syncCtx := context.WithoutCancel(ctx) go func() { + defer func() { + if recovered := recover(); recovered != nil { + r.logger.ErrorContext(syncCtx, "workos sync panic", attr.SlogError(fmt.Errorf("%v", recovered))) + } + }() for _, syncWorkOS := range syncs { syncWorkOS(syncCtx) } @@ -785,7 +780,6 @@ func (r *RoleManager) runWorkOSSyncs(ctx context.Context, syncs []workosSync) { } // syncWorkOS runs a bounded best-effort WorkOS write after the local database already accepted the change. -// Side effects: calls WorkOS, waits briefly between retryable failures, and logs the final failure without returning it. func (r *RoleManager) syncWorkOS(ctx context.Context, operation string, fn func() error) { exp := backoff.NewExponentialBackOff() exp.InitialInterval = 100 * time.Millisecond @@ -810,7 +804,6 @@ func (r *RoleManager) syncWorkOS(ctx context.Context, operation string, fn func( } // retryWorkOSError reports whether a WorkOS sync failure is worth retrying in-process. -// Side effects: none. func retryWorkOSError(err error) bool { var apiErr *workos.APIError if !errors.As(err, &apiErr) { @@ -820,9 +813,8 @@ func retryWorkOSError(err error) bool { } // roleViewFromLocalRole converts a local role record into the public API role view and attaches local grants. -// Side effects: reads Postgres grants; does not call WorkOS. func (r *RoleManager) roleViewFromLocalRole(ctx context.Context, organizationID string, role localRole) (*gen.Role, error) { - grants, err := authz.GrantsForRole(ctx, r.logger, r.db, organizationID, role.Slug) + grants, err := authz.GrantsForRole(ctx, r.logger, r.db, organizationID, role.Slug, role.PrincipalURN) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "load role grants").Log(ctx, r.logger) } @@ -857,61 +849,61 @@ func roleViewFromLocalRoleAndGrants(role localRole, grants []*gen.RoleGrant) *ge } // localRoleFromActiveRow converts a sqlc active-role row into the manager's internal local role record shape. -// Side effects: none. func localRoleFromActiveRow(row repo.ListActiveOrganizationRolesRow) localRole { return localRole{ - ID: row.ID.String(), - Name: row.WorkosName, - Slug: row.WorkosSlug, - Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), - CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), - UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), - MemberCount: int(row.MemberCount), + ID: row.ID.String(), + PrincipalURN: row.RoleUrn, + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), } } // localRoleFromRoleRow converts a sqlc role lookup row into the manager's internal local role record shape. -// Side effects: none. func localRoleFromRoleRow(row repo.GetOrganizationRoleByIDRow) localRole { return localRole{ - ID: row.ID.String(), - Name: row.WorkosName, - Slug: row.WorkosSlug, - Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), - CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), - UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), - MemberCount: int(row.MemberCount), + ID: row.ID.String(), + PrincipalURN: row.RoleUrn, + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), } } func localRoleFromUpsertRow(row repo.UpsertOrganizationRoleRow) localRole { return localRole{ - ID: row.ID.String(), - Name: row.WorkosName, - Slug: row.WorkosSlug, - Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), - CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), - UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), - MemberCount: int(row.MemberCount), + ID: row.ID.String(), + PrincipalURN: row.RoleUrn, + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), } } // localRoleFromSlugRow converts a sqlc role slug lookup row into the manager's internal local role record shape. -// Side effects: none. func localRoleFromSlugRow(row repo.GetActiveOrganizationRoleBySlugRow) localRole { return localRole{ - ID: row.ID.String(), - Name: row.WorkosName, - Slug: row.WorkosSlug, - Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), - CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), - UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), - MemberCount: int(row.MemberCount), + ID: row.ID.String(), + PrincipalURN: row.RoleUrn, + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), } } // organizationRoleParams builds the SQL parameters for storing the authoritative local role record. -// Side effects: reads the clock for updated_at and possibly created_at fallback. func organizationRoleParams(gramOrgID string, role workos.Role) repo.UpsertOrganizationRoleParams { return repo.UpsertOrganizationRoleParams{ OrganizationID: gramOrgID, @@ -925,7 +917,6 @@ func organizationRoleParams(gramOrgID string, role workos.Role) repo.UpsertOrgan } // replaceRoleAssignmentParams builds SQL parameters for storing the authoritative local role assignment. -// Side effects: reads the clock for updated_at. func replaceRoleAssignmentParams(gramOrgID, workosUserID, roleSlug, userID, membershipID string) repo.ReplaceOrganizationRoleAssignmentParams { return repo.ReplaceOrganizationRoleAssignmentParams{ OrganizationID: gramOrgID, @@ -939,7 +930,6 @@ func replaceRoleAssignmentParams(gramOrgID, workosUserID, roleSlug, userID, memb } // workosTimeOrNow parses a WorkOS RFC3339 timestamp or returns the current UTC time when WorkOS omits or malforms it. -// Side effects: reads the clock only when a fallback is needed. func workosTimeOrNow(value string) time.Time { if value == "" { return time.Now().UTC() @@ -952,7 +942,6 @@ func workosTimeOrNow(value string) time.Time { } // slugify validates a role name and turns it into Gram's WorkOS role slug format. -// Side effects: none. func slugify(name string) (string, error) { slug := conv.ToSlug(strings.ReplaceAll(name, "_", " ")) if slug == "" { diff --git a/server/internal/access/role_manager_test.go b/server/internal/access/role_manager_test.go index 804fd6d4b3..be2d88826d 100644 --- a/server/internal/access/role_manager_test.go +++ b/server/internal/access/role_manager_test.go @@ -74,8 +74,12 @@ func TestRoleManager_MembersAndCounts(t *testing.T) { require.NoError(t, err) require.Len(t, members.Members, 2) - slugs, err := manager.MemberRoleSlugs(ctx, authCtx.ActiveOrganizationID, "user_2") + rolePrincipals, err := manager.MemberRolePrincipals(ctx, authCtx.ActiveOrganizationID, "user_2") require.NoError(t, err) + slugs := make([]string, 0, len(rolePrincipals)) + for _, role := range rolePrincipals { + slugs = append(slugs, role.RoleSlug) + } require.Equal(t, []string{"custom-builder"}, slugs) roles, err := manager.ListRoles(ctx, authCtx.ActiveOrganizationID) diff --git a/server/internal/access/setup_internal_test.go b/server/internal/access/setup_internal_test.go index a1ba2d34cb..57795a36b4 100644 --- a/server/internal/access/setup_internal_test.go +++ b/server/internal/access/setup_internal_test.go @@ -3,6 +3,7 @@ package access import ( "context" "testing" + "time" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" @@ -59,3 +60,23 @@ func seedInternalGrant(t *testing.T, ctx context.Context, conn *pgxpool.Pool, or }) require.NoError(t, err) } + +func seedInternalRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID string, roleSlug string) urn.Principal { + t.Helper() + + now := time.Now().UTC() + row, err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + OrganizationID: organizationID, + WorkosSlug: roleSlug, + WorkosName: roleSlug, + WorkosDescription: conv.ToPGTextEmpty(""), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + + principal, err := urn.ParsePrincipal(row.RoleUrn) + require.NoError(t, err) + return principal +} diff --git a/server/internal/access/syncgrants_test.go b/server/internal/access/syncgrants_test.go index 1af0707124..d6ac5b01ec 100644 --- a/server/internal/access/syncgrants_test.go +++ b/server/internal/access/syncgrants_test.go @@ -16,14 +16,15 @@ func TestService_syncGrants_replacesRoleGrants(t *testing.T) { ctx, svc, conn := newInternalTestService(t) organizationID := "org_sync_grants_replace" roleSlug := "custom-editor" - rolePrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) - seedInternalOrganization(t, ctx, conn, organizationID) - seedInternalGrant(t, ctx, conn, organizationID, rolePrincipal, string(authz.ScopeProjectRead), "project-old") - seedInternalGrant(t, ctx, conn, organizationID, rolePrincipal, string(authz.ScopeProjectWrite), "project-stale") + rolePrincipal := seedInternalRole(t, ctx, conn, organizationID, roleSlug) + legacyRolePrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) + + seedInternalGrant(t, ctx, conn, organizationID, legacyRolePrincipal, string(authz.ScopeProjectRead), "project-old") + seedInternalGrant(t, ctx, conn, organizationID, legacyRolePrincipal, string(authz.ScopeProjectWrite), "project-stale") seedInternalGrant(t, ctx, conn, organizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "other-role"), string(authz.ScopeProjectRead), "project-other") - err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, []*authz.RoleGrant{ + err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, rolePrincipal.String(), []*authz.RoleGrant{ { Scope: string(authz.ScopeProjectRead), Selectors: nil, @@ -61,6 +62,12 @@ func TestService_syncGrants_replacesRoleGrants(t *testing.T) { string(authz.ScopeMCPConnect) + "|tool:analytics", string(authz.ScopeMCPConnect) + "|tool:payments", }, got) + legacyRows, err := accessrepo.New(conn).ListPrincipalGrantsByOrg(ctx, accessrepo.ListPrincipalGrantsByOrgParams{ + OrganizationID: organizationID, + PrincipalUrn: legacyRolePrincipal.String(), + }) + require.NoError(t, err) + require.Empty(t, legacyRows) otherRows, err := accessrepo.New(conn).ListPrincipalGrantsByOrg(ctx, accessrepo.ListPrincipalGrantsByOrgParams{ OrganizationID: organizationID, @@ -79,12 +86,11 @@ func TestService_syncGrants_emptySelectorsCreatesNoGrant(t *testing.T) { ctx, svc, conn := newInternalTestService(t) organizationID := "org_sync_grants_empty_sel" roleSlug := "custom-empty-sel" - rolePrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) - seedInternalOrganization(t, ctx, conn, organizationID) + rolePrincipal := seedInternalRole(t, ctx, conn, organizationID, roleSlug) // Empty non-nil selectors = no access (not wildcard). - err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, []*authz.RoleGrant{ + err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, rolePrincipal.String(), []*authz.RoleGrant{ { Scope: string(authz.ScopeMCPConnect), Selectors: []authz.Selector{}, @@ -106,13 +112,12 @@ func TestService_syncGrants_clearsRoleGrantsWhenEmpty(t *testing.T) { ctx, svc, conn := newInternalTestService(t) organizationID := "org_sync_grants_clear" roleSlug := "custom-viewer" - rolePrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) - seedInternalOrganization(t, ctx, conn, organizationID) + rolePrincipal := seedInternalRole(t, ctx, conn, organizationID, roleSlug) seedInternalGrant(t, ctx, conn, organizationID, rolePrincipal, string(authz.ScopeProjectRead), authz.WildcardResource) seedInternalGrant(t, ctx, conn, organizationID, rolePrincipal, string(authz.ScopeMCPRead), "tool:payments") - err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, nil) + err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, rolePrincipal.String(), nil) require.NoError(t, err) rows, err := accessrepo.New(conn).ListPrincipalGrantsByOrg(ctx, accessrepo.ListPrincipalGrantsByOrgParams{ diff --git a/server/internal/access/updaterole_test.go b/server/internal/access/updaterole_test.go index f1adfada8a..523b7365d6 100644 --- a/server/internal/access/updaterole_test.go +++ b/server/internal/access/updaterole_test.go @@ -85,7 +85,7 @@ func TestService_UpdateRole(t *testing.T) { require.NotEqual(t, mockRoleTimestamp, role.UpdatedAt) require.Len(t, role.Grants, 2) - grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) + grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "organization:"+roleID)) require.Len(t, grants, 3) } diff --git a/server/internal/authz/engine.go b/server/internal/authz/engine.go index b8c4e639c7..eede372284 100644 --- a/server/internal/authz/engine.go +++ b/server/internal/authz/engine.go @@ -122,7 +122,7 @@ func (e *Engine) PrepareContext(ctx context.Context) (context.Context, error) { principals := []urn.Principal{urn.NewPrincipal(urn.PrincipalTypeUser, authCtx.UserID)} - roleSlugs, err := e.resolveRoleSlugs(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) + rolePrincipals, err := e.resolveRolePrincipals(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) if err != nil { e.logger.ErrorContext( ctx, @@ -133,8 +133,12 @@ func (e *Engine) PrepareContext(ctx context.Context) (context.Context, error) { ) return ctx, fmt.Errorf("resolve role slugs: %w", err) } - for _, roleSlug := range roleSlugs { - principals = append(principals, urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug)) + for _, role := range rolePrincipals { + rolePrincipalURNs, err := RolePrincipals(role.RoleSlug, role.PrincipalUrn) + if err != nil { + return ctx, fmt.Errorf("build role principals: %w", err) + } + principals = append(principals, rolePrincipalURNs...) } grants, err := LoadGrants(ctx, e.db, authCtx.ActiveOrganizationID, principals) @@ -152,7 +156,7 @@ func (e *Engine) PrepareContext(ctx context.Context) (context.Context, error) { return GrantsToContext(ctx, grants), nil } -func (e *Engine) resolveRoleSlugs(ctx context.Context, userID, orgID string) ([]string, error) { +func (e *Engine) resolveRolePrincipals(ctx context.Context, userID, orgID string) ([]accessrepo.ListMemberRolePrincipalsByWorkosUserRow, error) { user, err := usersrepo.New(e.db).GetUser(ctx, userID) if err != nil { return nil, fmt.Errorf("get user: %w", err) @@ -164,7 +168,7 @@ func (e *Engine) resolveRoleSlugs(ctx context.Context, userID, orgID string) ([] // Role assignments are local source-of-truth records. They are written by // the access write path and by WorkOS sync, so invitation/admin-console // changes can lag until the sync job catches up. - roleSlugs, err := accessrepo.New(e.db).ListMemberRoleSlugsByWorkosUser(ctx, accessrepo.ListMemberRoleSlugsByWorkosUserParams{ + rolePrincipals, err := accessrepo.New(e.db).ListMemberRolePrincipalsByWorkosUser(ctx, accessrepo.ListMemberRolePrincipalsByWorkosUserParams{ OrganizationID: orgID, WorkosUserID: user.WorkosID.String, }) @@ -172,7 +176,7 @@ func (e *Engine) resolveRoleSlugs(ctx context.Context, userID, orgID string) ([] return nil, fmt.Errorf("list member role slugs: %w", err) } - return roleSlugs, nil + return rolePrincipals, nil } func (e *Engine) Require(ctx context.Context, checks ...Check) error { diff --git a/server/internal/authz/engine_test.go b/server/internal/authz/engine_test.go index 25bd68f312..fe31b6c124 100644 --- a/server/internal/authz/engine_test.go +++ b/server/internal/authz/engine_test.go @@ -116,17 +116,17 @@ func TestResolveRoleSlugs_readsLocalAssignmentsOnly(t *testing.T) { require.NoError(t, err) engine := NewEngine(testenv.NewLogger(t), conn, chConn, staticRBAC(true), staticChallengeLogging(true), membership) - roleSlugs, err := engine.resolveRoleSlugs(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) + rolePrincipals, err := engine.resolveRolePrincipals(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) require.NoError(t, err) - require.Empty(t, roleSlugs) + require.Empty(t, rolePrincipals) - roleSlugs, err = engine.resolveRoleSlugs(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) + rolePrincipals, err = engine.resolveRolePrincipals(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) require.NoError(t, err) - require.Empty(t, roleSlugs) + require.Empty(t, rolePrincipals) require.Equal(t, 0, membership.calls) } -func TestResolveRoleSlugs_returnsAllLocalAssignments(t *testing.T) { +func TestResolveRolePrincipals_returnsAllLocalAssignments(t *testing.T) { t.Parallel() ctx := enterpriseTestCtx(t.Context()) @@ -146,9 +146,19 @@ func TestResolveRoleSlugs_returnsAllLocalAssignments(t *testing.T) { require.NoError(t, err) engine := NewEngine(testenv.NewLogger(t), conn, chConn, staticRBAC(true), staticChallengeLogging(true), &countingMembershipFetcher{}) - roleSlugs, err := engine.resolveRoleSlugs(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) + rolePrincipals, err := engine.resolveRolePrincipals(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) require.NoError(t, err) + roleSlugs := make([]string, 0, len(rolePrincipals)) + principalURNs := make([]string, 0, len(rolePrincipals)) + for _, role := range rolePrincipals { + roleSlugs = append(roleSlugs, role.RoleSlug) + principalURNs = append(principalURNs, role.PrincipalUrn) + } require.ElementsMatch(t, []string{"custom-alpha", "custom-beta"}, roleSlugs) + require.Len(t, principalURNs, 2) + for _, principalURN := range principalURNs { + require.Contains(t, principalURN, "role:organization:") + } } func TestEngineRequireAny_mapsDeniedToForbidden(t *testing.T) { diff --git a/server/internal/authz/grants.go b/server/internal/authz/grants.go index 092f051f1d..4ac07eceb3 100644 --- a/server/internal/authz/grants.go +++ b/server/internal/authz/grants.go @@ -53,7 +53,7 @@ var SystemRoleGrants = map[string][]*RoleGrant{ // SeedSystemRoleGrants upserts the fixed grant sets for all system roles. func SeedSystemRoleGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, organizationID string) error { for roleSlug, grants := range SystemRoleGrants { - if err := SyncGrants(ctx, logger, db, organizationID, roleSlug, grants); err != nil { + if err := SyncGrants(ctx, logger, db, organizationID, roleSlug, "", grants); err != nil { return fmt.Errorf("seed %s grants: %w", roleSlug, err) } } @@ -72,7 +72,7 @@ type ScopedGrant struct { Selectors []Selector } -func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgID string, roleSlug string, grants []*RoleGrant) error { +func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgID string, roleSlug string, rolePrincipalURN string, grants []*RoleGrant) error { if orgID == "" { return fmt.Errorf("organization id is required") } @@ -83,7 +83,7 @@ func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgI } defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) - if _, err := SyncGrantsTx(ctx, tx, orgID, roleSlug, grants); err != nil { + if _, err := SyncGrantsTx(ctx, tx, orgID, roleSlug, rolePrincipalURN, grants); err != nil { return err } @@ -94,20 +94,30 @@ func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgI return nil } -func SyncGrantsTx(ctx context.Context, dbtx repo.DBTX, orgID string, roleSlug string, grants []*RoleGrant) ([]*ScopedGrant, error) { +func SyncGrantsTx(ctx context.Context, dbtx repo.DBTX, orgID string, roleSlug string, rolePrincipalURN string, grants []*RoleGrant) ([]*ScopedGrant, error) { if orgID == "" { return nil, fmt.Errorf("organization id is required") } - principalURN := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) + if rolePrincipalURN == "" { + role, err := repo.New(dbtx).GetActiveOrganizationRoleBySlug(ctx, repo.GetActiveOrganizationRoleBySlugParams{ + OrganizationID: orgID, + WorkosSlug: roleSlug, + }) + if err != nil { + return nil, fmt.Errorf("resolve role principal for %q: %w", roleSlug, err) + } + rolePrincipalURN = role.RoleUrn + } - q := repo.New(dbtx) + principalURN, err := urn.ParsePrincipal(rolePrincipalURN) + if err != nil { + return nil, fmt.Errorf("parse role principal urn %q: %w", rolePrincipalURN, err) + } - if _, err := q.DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ - OrganizationID: orgID, - PrincipalUrn: principalURN, - }); err != nil { - return nil, fmt.Errorf("delete grants for role %q: %w", roleSlug, err) + q := repo.New(dbtx) + if err := DeleteRoleGrants(ctx, q, orgID, roleSlug, rolePrincipalURN); err != nil { + return nil, err } for _, grant := range grants { @@ -172,16 +182,74 @@ func SyncGrantsTx(ctx context.Context, dbtx repo.DBTX, orgID string, roleSlug st return scoped, nil } -func GrantsForRole(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgID string, roleSlug string) ([]*ScopedGrant, error) { - rows, err := repo.New(db).ListPrincipalGrantsByOrg(ctx, repo.ListPrincipalGrantsByOrgParams{ +func DeleteRoleGrants(ctx context.Context, q *repo.Queries, orgID, roleSlug, rolePrincipalURN string) error { + principalURN, err := urn.ParsePrincipal(rolePrincipalURN) + if err != nil { + return fmt.Errorf("parse role principal urn %q: %w", rolePrincipalURN, err) + } + + if _, err := q.DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ + OrganizationID: orgID, + PrincipalUrn: principalURN, + }); err != nil { + return fmt.Errorf("delete grants for role %q: %w", roleSlug, err) + } + + legacyPrincipalURN := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) + if legacyPrincipalURN.String() == principalURN.String() { + return nil + } + if _, err := q.DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ + OrganizationID: orgID, + PrincipalUrn: legacyPrincipalURN, + }); err != nil { + return fmt.Errorf("delete legacy grants for role %q: %w", roleSlug, err) + } + + return nil +} + +func RolePrincipals(roleSlug, rolePrincipalURN string) ([]urn.Principal, error) { + // Keep the legacy role: principal until the role-principal backfill + // has moved existing grants to role::. + principals := make([]urn.Principal, 0, 2) + if rolePrincipalURN != "" { + principal, err := urn.ParsePrincipal(rolePrincipalURN) + if err != nil { + return nil, fmt.Errorf("parse role principal urn %q: %w", rolePrincipalURN, err) + } + principals = append(principals, principal) + } + + legacyPrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) + if rolePrincipalURN == "" || legacyPrincipal.String() != rolePrincipalURN { + principals = append(principals, legacyPrincipal) + } + + return principals, nil +} + +func GrantsForRole(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgID string, roleSlug string, rolePrincipalURN string) ([]*ScopedGrant, error) { + // During the role-principal migration, reads include both the canonical + // role:: principal and the legacy role: principal. + principals, err := RolePrincipals(roleSlug, rolePrincipalURN) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "build role principals").Log(ctx, logger) + } + principalURNs, err := parsePrincipalURNs(principals) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "build role principals").Log(ctx, logger) + } + + rows, err := repo.New(db).GetPrincipalGrants(ctx, repo.GetPrincipalGrantsParams{ OrganizationID: orgID, - PrincipalUrn: urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug).String(), + PrincipalUrns: principalURNs, }) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "list grants for role").Log(ctx, logger) } - scoped, err := scopedGrantsFromRows(urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug).String(), rows) + scoped, err := scopedGrantsFromGrantRows(rows) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "unmarshal grant selector").Log(ctx, logger) } @@ -189,6 +257,23 @@ func GrantsForRole(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, o return scoped, nil } +func scopedGrantsFromGrantRows(rows []repo.GetPrincipalGrantsRow) ([]*ScopedGrant, error) { + grantRows := make([]Grant, 0, len(rows)) + for _, row := range rows { + selectors, err := SelectorFromRow(row.Selectors) + if err != nil { + return nil, err + } + grantRows = append(grantRows, Grant{ + PrincipalUrn: row.PrincipalUrn.String(), + Scope: Scope(row.Scope), + Selector: selectors, + }) + } + + return GrantsToScopedGrants(grantRows), nil +} + func scopedGrantsFromRows(rolePrincipalURN string, rows []repo.ListPrincipalGrantsByOrgRow) ([]*ScopedGrant, error) { grantRows := make([]Grant, 0, len(rows)) for _, row := range rows { diff --git a/server/internal/background/activities/process_workos_org_events.go b/server/internal/background/activities/process_workos_org_events.go index 567bd53eaa..b659738215 100644 --- a/server/internal/background/activities/process_workos_org_events.go +++ b/server/internal/background/activities/process_workos_org_events.go @@ -14,6 +14,7 @@ import ( accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/attr" + "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/conv" "github.com/speakeasy-api/gram/server/internal/database" "github.com/speakeasy-api/gram/server/internal/o11y" @@ -443,10 +444,8 @@ func handleRoleDeleted(ctx context.Context, logger *slog.Logger, dbtx database.D return fmt.Errorf("mark organization role %q deleted: %w", payload.Slug, err) } - if _, err := repo.DeletePrincipalGrantsByPrincipal(ctx, accessrepo.DeletePrincipalGrantsByPrincipalParams{ - OrganizationID: org.ID, - PrincipalUrn: urn.NewPrincipal(urn.PrincipalTypeRole, payload.Slug), - }); err != nil { + rolePrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, "organization:"+existing.ID.String()) + if err := authz.DeleteRoleGrants(ctx, repo, org.ID, payload.Slug, rolePrincipal.String()); err != nil { return fmt.Errorf("delete grants for role %q: %w", payload.Slug, err) } From 644bd8ec0e2df8917868aaf4674bf019afe352f9 Mon Sep 17 00:00:00 2001 From: tgmendes Date: Tue, 19 May 2026 11:38:16 +0100 Subject: [PATCH 08/12] chore: clean up local role management queries --- server/internal/access/queries.sql | 26 +----- server/internal/access/repo/queries.sql.go | 61 +----------- server/internal/access/role_manager.go | 103 +++++++++++---------- server/internal/authz/grants.go | 54 +++++------ 4 files changed, 89 insertions(+), 155 deletions(-) diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index 9b0ee4cdb2..33711a6193 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -243,6 +243,8 @@ LEFT JOIN organization_role_assignments AS ora AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +-- Organization roles shadow global roles if a slug ever collides. +ORDER BY active_roles.role_kind DESC LIMIT 1; -- name: GetOrganizationRoleByID :one @@ -275,29 +277,10 @@ LEFT JOIN organization_role_assignments AS ora AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +-- Organization roles shadow global roles if an ID ever collides. +ORDER BY active_roles.role_kind DESC LIMIT 1; --- name: ListOrganizationRoleAssignmentsForOrg :many -SELECT - ora.user_id, - ora.workos_user_id, - ora.workos_membership_id, - COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, - ora.created_at -FROM organization_role_assignments AS ora -LEFT JOIN organization_roles - ON ora.role_urn = 'role:organization:' || organization_roles.id::text - AND organization_roles.organization_id = ora.organization_id - AND organization_roles.deleted IS FALSE - AND organization_roles.workos_deleted IS FALSE -LEFT JOIN global_roles - ON ora.role_urn = 'role:global:' || global_roles.id::text - AND global_roles.deleted IS FALSE - AND global_roles.workos_deleted IS FALSE -WHERE ora.organization_id = @organization_id - AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL -ORDER BY ora.workos_user_id, role_slug; - -- name: ListOrganizationRoleAssignmentsBySlug :many SELECT ora.user_id, @@ -362,6 +345,7 @@ LEFT JOIN global_roles AND global_roles.workos_deleted IS FALSE WHERE ora.organization_id = @organization_id AND ora.workos_user_id = ANY(@workos_user_ids::text[]) + AND ora.workos_membership_id IS NOT NULL AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL ORDER BY ora.workos_user_id, role_slug; diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index c00cdc1a38..bdd9b047d3 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -84,6 +84,7 @@ LEFT JOIN organization_role_assignments AS ora AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +ORDER BY active_roles.role_kind DESC LIMIT 1 ` @@ -103,6 +104,7 @@ type GetActiveOrganizationRoleBySlugRow struct { MemberCount int64 } +// Organization roles shadow global roles if a slug ever collides. func (q *Queries) GetActiveOrganizationRoleBySlug(ctx context.Context, arg GetActiveOrganizationRoleBySlugParams) (GetActiveOrganizationRoleBySlugRow, error) { row := q.db.QueryRow(ctx, getActiveOrganizationRoleBySlug, arg.OrganizationID, arg.WorkosSlug) var i GetActiveOrganizationRoleBySlugRow @@ -229,6 +231,7 @@ LEFT JOIN organization_role_assignments AS ora AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +ORDER BY active_roles.role_kind DESC LIMIT 1 ` @@ -248,6 +251,7 @@ type GetOrganizationRoleByIDRow struct { MemberCount int64 } +// Organization roles shadow global roles if an ID ever collides. func (q *Queries) GetOrganizationRoleByID(ctx context.Context, arg GetOrganizationRoleByIDParams) (GetOrganizationRoleByIDRow, error) { row := q.db.QueryRow(ctx, getOrganizationRoleByID, arg.OrganizationID, arg.ID) var i GetOrganizationRoleByIDRow @@ -715,6 +719,7 @@ LEFT JOIN global_roles AND global_roles.workos_deleted IS FALSE WHERE ora.organization_id = $1 AND ora.workos_user_id = ANY($2::text[]) + AND ora.workos_membership_id IS NOT NULL AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL ORDER BY ora.workos_user_id, role_slug ` @@ -758,62 +763,6 @@ func (q *Queries) ListOrganizationRoleAssignmentsByWorkosUsers(ctx context.Conte return items, nil } -const listOrganizationRoleAssignmentsForOrg = `-- name: ListOrganizationRoleAssignmentsForOrg :many -SELECT - ora.user_id, - ora.workos_user_id, - ora.workos_membership_id, - COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, - ora.created_at -FROM organization_role_assignments AS ora -LEFT JOIN organization_roles - ON ora.role_urn = 'role:organization:' || organization_roles.id::text - AND organization_roles.organization_id = ora.organization_id - AND organization_roles.deleted IS FALSE - AND organization_roles.workos_deleted IS FALSE -LEFT JOIN global_roles - ON ora.role_urn = 'role:global:' || global_roles.id::text - AND global_roles.deleted IS FALSE - AND global_roles.workos_deleted IS FALSE -WHERE ora.organization_id = $1 - AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL -ORDER BY ora.workos_user_id, role_slug -` - -type ListOrganizationRoleAssignmentsForOrgRow struct { - UserID pgtype.Text - WorkosUserID string - WorkosMembershipID pgtype.Text - RoleSlug string - CreatedAt pgtype.Timestamptz -} - -func (q *Queries) ListOrganizationRoleAssignmentsForOrg(ctx context.Context, organizationID string) ([]ListOrganizationRoleAssignmentsForOrgRow, error) { - rows, err := q.db.Query(ctx, listOrganizationRoleAssignmentsForOrg, organizationID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ListOrganizationRoleAssignmentsForOrgRow - for rows.Next() { - var i ListOrganizationRoleAssignmentsForOrgRow - if err := rows.Scan( - &i.UserID, - &i.WorkosUserID, - &i.WorkosMembershipID, - &i.RoleSlug, - &i.CreatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const listPrincipalGrantsByOrg = `-- name: ListPrincipalGrantsByOrg :many SELECT id, organization_id, principal_urn, principal_type, scope, selectors, created_at, updated_at diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index b6c4157bcd..3803c92db2 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -405,14 +405,18 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro }) } - if _, err := repo.New(tx).MarkOrganizationRoleDeletedLocally(ctx, repo.MarkOrganizationRoleDeletedLocallyParams{ + deletedCount, err := repo.New(tx).MarkOrganizationRoleDeletedLocally(ctx, repo.MarkOrganizationRoleDeletedLocallyParams{ OrganizationID: gramOrgID, WorkosSlug: currentRole.Slug, WorkosLastEventID: conv.ToPGTextEmpty(""), - }); err != nil { + }) + if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "mark local role record deleted").Log(ctx, r.logger) } + if deletedCount == 0 { + return roleDeleteResult{}, oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, r.logger) + } if err := authz.DeleteRoleGrants(ctx, repo.New(tx), gramOrgID, currentRole.Slug, currentRole.PrincipalURN); err != nil { return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, r.logger) @@ -568,6 +572,7 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r return result, nil } +// MemberRolePrincipals returns role slug and principal URN for each role assigned to a WorkOS user in this org. func (r *RoleManager) MemberRolePrincipals(ctx context.Context, gramOrgID, workosUserID string) ([]repo.ListMemberRolePrincipalsByWorkosUserRow, error) { if workosUserID == "" { return nil, nil @@ -633,43 +638,60 @@ type memberAssignmentTarget struct { MembershipID string } +// requestedMemberAssignment groups input IDs that resolve to the same WorkOS user. +type requestedMemberAssignment struct { + InputIDs []string + UserID string +} + func (r *RoleManager) memberAssignmentTargetsTx(ctx context.Context, dbtx repo.DBTX, gramOrgID string, memberIDs []string) ([]memberAssignmentTarget, error) { if len(memberIDs) == 0 { return nil, nil } - requested := make(map[string]struct{}, len(memberIDs)) - for _, id := range memberIDs { - requested[id] = struct{}{} - } users, err := usersrepo.New(dbtx).GetUsersByIDs(ctx, memberIDs) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "resolve users by ids").Log(ctx, r.logger) } - workosByGramID := make(map[string]string, len(users)) - requestedByWorkosID := make(map[string]string, len(users)) + usersByID := make(map[string]usersrepo.User, len(users)) for _, user := range users { - if user.WorkosID.Valid && user.WorkosID.String != "" { - workosByGramID[user.ID] = user.WorkosID.String - if _, ok := requested[user.ID]; ok { - requestedByWorkosID[user.WorkosID.String] = user.ID - } - } + usersByID[user.ID] = user } + + requestedInputs := make(map[string]struct{}, len(memberIDs)) + requestedByWorkosID := make(map[string]requestedMemberAssignment, len(memberIDs)) workosIDs := make([]string, 0, len(memberIDs)) - seenRequestedWorkosID := make(map[string]struct{}, len(memberIDs)) for _, id := range memberIDs { + if _, ok := requestedInputs[id]; ok { + continue + } + requestedInputs[id] = struct{}{} + workosID := id - if userWorkosID, ok := workosByGramID[id]; ok { - workosID = userWorkosID + userID := "" + if user, ok := usersByID[id]; ok { + userID = user.ID + if !user.WorkosID.Valid || user.WorkosID.String == "" { + return nil, oops.E(oops.CodeBadRequest, nil, "member %s is not linked to WorkOS", id).Log(ctx, r.logger) + } + workosID = user.WorkosID.String } if workosID == "" { continue } - if _, ok := seenRequestedWorkosID[workosID]; ok { + requested, ok := requestedByWorkosID[workosID] + if ok { + requested.InputIDs = append(requested.InputIDs, id) + if requested.UserID == "" { + requested.UserID = userID + } + requestedByWorkosID[workosID] = requested continue } - seenRequestedWorkosID[workosID] = struct{}{} + requestedByWorkosID[workosID] = requestedMemberAssignment{ + InputIDs: []string{id}, + UserID: userID, + } workosIDs = append(workosIDs, workosID) } @@ -682,44 +704,31 @@ func (r *RoleManager) memberAssignmentTargetsTx(ctx context.Context, dbtx repo.D } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) targets := make([]memberAssignmentTarget, 0, len(memberIDs)) - resolved := make(map[string]struct{}, len(requested)) - seenWorkosID := make(map[string]struct{}, len(memberIDs)) + resolvedWorkosIDs := make(map[string]struct{}, len(requestedByWorkosID)) + resolvedInputs := make(map[string]struct{}, len(requestedInputs)) for _, row := range assignmentRows { - userID := conv.FromPGTextOrEmpty[string](row.UserID) - membershipID := conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) - requestedID := "" - if _, ok := requested[userID]; ok { - requestedID = userID - } else if gramID, ok := requestedByWorkosID[row.WorkosUserID]; ok { - requestedID = gramID - if userID == "" { - userID = gramID - } - } else if _, ok := requested[row.WorkosUserID]; ok { - requestedID = row.WorkosUserID - } else { + requested, ok := requestedByWorkosID[row.WorkosUserID] + if !ok { continue } - workosID := row.WorkosUserID - if userWorkosID, ok := workosByGramID[userID]; ok { - workosID = userWorkosID - } - if workosID == "" || membershipID == "" { + if _, ok := resolvedWorkosIDs[row.WorkosUserID]; ok { continue } - if _, ok := seenWorkosID[workosID]; ok { - resolved[requestedID] = struct{}{} - continue + userID := conv.FromPGTextOrEmpty[string](row.UserID) + if userID == "" { + userID = requested.UserID + } + resolvedWorkosIDs[row.WorkosUserID] = struct{}{} + for _, inputID := range requested.InputIDs { + resolvedInputs[inputID] = struct{}{} } - seenWorkosID[workosID] = struct{}{} - resolved[requestedID] = struct{}{} targets = append(targets, memberAssignmentTarget{ UserID: userID, - WorkosUserID: workosID, - MembershipID: membershipID, + WorkosUserID: row.WorkosUserID, + MembershipID: conv.FromPGTextOrEmpty[string](row.WorkosMembershipID), }) } - if len(resolved) != len(requested) { + if len(resolvedInputs) != len(requestedInputs) { return nil, oops.E(oops.CodeBadRequest, nil, "member role assignment not found; wait for WorkOS sync to complete").Log(ctx, r.logger) } diff --git a/server/internal/authz/grants.go b/server/internal/authz/grants.go index 4ac07eceb3..fb9d0794a5 100644 --- a/server/internal/authz/grants.go +++ b/server/internal/authz/grants.go @@ -120,6 +120,8 @@ func SyncGrantsTx(ctx context.Context, dbtx repo.DBTX, orgID string, roleSlug st return nil, err } + grantRows := make([]Grant, 0, len(grants)) + seenGrants := make(map[string]struct{}, len(grants)) for _, grant := range grants { if grant == nil { continue @@ -143,6 +145,15 @@ func SyncGrantsTx(ctx context.Context, dbtx repo.DBTX, orgID string, roleSlug st }); err != nil { return nil, fmt.Errorf("upsert unrestricted grant %q for role %q: %w", grant.Scope, roleSlug, err) } + grantKey := grant.Scope + "\x00" + string(selBytes) + if _, ok := seenGrants[grantKey]; !ok { + seenGrants[grantKey] = struct{}{} + grantRows = append(grantRows, Grant{ + PrincipalUrn: principalURN.String(), + Scope: scope, + Selector: sel, + }) + } continue } @@ -163,23 +174,19 @@ func SyncGrantsTx(ctx context.Context, dbtx repo.DBTX, orgID string, roleSlug st }); err != nil { return nil, fmt.Errorf("upsert grant %q for role %q: %w", grant.Scope, roleSlug, err) } + grantKey := grant.Scope + "\x00" + string(selBytes) + if _, ok := seenGrants[grantKey]; !ok { + seenGrants[grantKey] = struct{}{} + grantRows = append(grantRows, Grant{ + PrincipalUrn: principalURN.String(), + Scope: scope, + Selector: sel, + }) + } } } - rows, err := q.ListPrincipalGrantsByOrg(ctx, repo.ListPrincipalGrantsByOrgParams{ - OrganizationID: orgID, - PrincipalUrn: principalURN.String(), - }) - if err != nil { - return nil, fmt.Errorf("list synced grants for role %q: %w", roleSlug, err) - } - - scoped, err := scopedGrantsFromRows(principalURN.String(), rows) - if err != nil { - return nil, fmt.Errorf("load synced grants for role %q: %w", roleSlug, err) - } - - return scoped, nil + return GrantsToScopedGrants(grantRows), nil } func DeleteRoleGrants(ctx context.Context, q *repo.Queries, orgID, roleSlug, rolePrincipalURN string) error { @@ -210,6 +217,7 @@ func DeleteRoleGrants(ctx context.Context, q *repo.Queries, orgID, roleSlug, rol } func RolePrincipals(roleSlug, rolePrincipalURN string) ([]urn.Principal, error) { + // TODO(AGE-1954): drop legacy role: principals after the role-principal backfill is complete. // Keep the legacy role: principal until the role-principal backfill // has moved existing grants to role::. principals := make([]urn.Principal, 0, 2) @@ -230,6 +238,7 @@ func RolePrincipals(roleSlug, rolePrincipalURN string) ([]urn.Principal, error) } func GrantsForRole(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgID string, roleSlug string, rolePrincipalURN string) ([]*ScopedGrant, error) { + // TODO(AGE-1954): remove dual-read after legacy role: grants are backfilled. // During the role-principal migration, reads include both the canonical // role:: principal and the legacy role: principal. principals, err := RolePrincipals(roleSlug, rolePrincipalURN) @@ -274,23 +283,6 @@ func scopedGrantsFromGrantRows(rows []repo.GetPrincipalGrantsRow) ([]*ScopedGran return GrantsToScopedGrants(grantRows), nil } -func scopedGrantsFromRows(rolePrincipalURN string, rows []repo.ListPrincipalGrantsByOrgRow) ([]*ScopedGrant, error) { - grantRows := make([]Grant, 0, len(rows)) - for _, row := range rows { - selectors, err := SelectorFromRow(row.Selectors) - if err != nil { - return nil, err - } - grantRows = append(grantRows, Grant{ - PrincipalUrn: rolePrincipalURN, - Scope: Scope(row.Scope), - Selector: selectors, - }) - } - - return GrantsToScopedGrants(grantRows), nil -} - type scopeAgg struct { unrestricted bool selectors []Selector From 49e03c2d956938952c054ebad7b790c4fd73b37d Mon Sep 17 00:00:00 2001 From: tgmendes Date: Thu, 21 May 2026 10:06:16 +0100 Subject: [PATCH 09/12] fix: list unassigned access members --- server/internal/access/listmembers_test.go | 26 ++++- server/internal/access/listroles_test.go | 18 +++ server/internal/access/queries.sql | 103 +++++++++++------ server/internal/access/rbac_test.go | 2 +- server/internal/access/repo/queries.sql.go | 107 ++++++++++++------ server/internal/access/role_manager.go | 6 +- server/internal/access/role_manager_test.go | 9 +- server/internal/access/setup_test.go | 24 ++++ .../internal/access/updatememberrole_test.go | 2 +- 9 files changed, 214 insertions(+), 83 deletions(-) diff --git a/server/internal/access/listmembers_test.go b/server/internal/access/listmembers_test.go index 1771325b46..44f3680bfc 100644 --- a/server/internal/access/listmembers_test.go +++ b/server/internal/access/listmembers_test.go @@ -26,7 +26,7 @@ func TestService_ListMembers(t *testing.T) { result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Len(t, result.Members, 2) + require.Len(t, result.Members, 3) byID := map[string]*gen.AccessMember{} for _, member := range result.Members { @@ -41,6 +41,8 @@ func TestService_ListMembers(t *testing.T) { require.Equal(t, "Grace", byID["local_user_2"].Name) require.Equal(t, builderID, byID["local_user_2"].RoleID) + + require.Empty(t, byID[authCtx.UserID].RoleID) } func TestService_ListMembers_SkipsMembersWithoutLocalUser(t *testing.T) { @@ -61,7 +63,7 @@ func TestService_ListMembers_SkipsMembersWithoutLocalUser(t *testing.T) { result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Len(t, result.Members, 1) + require.Len(t, result.Members, 2) byID := map[string]*gen.AccessMember{} for _, member := range result.Members { @@ -69,6 +71,22 @@ func TestService_ListMembers_SkipsMembersWithoutLocalUser(t *testing.T) { } require.Equal(t, "Ada Lovelace", byID["local_user_1"].Name) require.Nil(t, byID["user_2"]) + require.Empty(t, byID[authCtx.UserID].RoleID) +} + +func TestService_ListMembers_IncludesConnectedUsersWithoutRoleAssignments(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, _ := contextvalues.GetAuthContext(ctx) + + result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) + require.NoError(t, err) + require.Len(t, result.Members, 1) + + require.Equal(t, authCtx.UserID, result.Members[0].ID) + require.Empty(t, result.Members[0].RoleID) + require.NotEmpty(t, result.Members[0].Name) } func TestService_ListMembers_UsesDatabaseOnly(t *testing.T) { @@ -81,5 +99,7 @@ func TestService_ListMembers_UsesDatabaseOnly(t *testing.T) { result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Empty(t, result.Members) + require.Len(t, result.Members, 1) + require.Equal(t, authCtx.UserID, result.Members[0].ID) + require.Empty(t, result.Members[0].RoleID) } diff --git a/server/internal/access/listroles_test.go b/server/internal/access/listroles_test.go index e6a51873b7..3bf95ed3f0 100644 --- a/server/internal/access/listroles_test.go +++ b/server/internal/access/listroles_test.go @@ -99,3 +99,21 @@ func TestService_ListRoles_ExcludesDisconnectedUsersFromMemberCounts(t *testing. // Only user_1 should be counted — user_2 has a local account but no org connection. require.Equal(t, 1, result.Roles[0].MemberCount) } + +func TestService_ListRoles_DoesNotCountConnectedUsersWithoutAssignments(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + _, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + memberID := seedGlobalRole(t, ctx, ti.conn, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) + + result, err := ti.service.ListRoles(ctx, &gen.ListRolesPayload{}) + require.NoError(t, err) + require.Len(t, result.Roles, 1) + + require.Equal(t, memberID, result.Roles[0].ID) + require.Equal(t, "Member", result.Roles[0].Name) + require.Equal(t, 0, result.Roles[0].MemberCount) +} diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index 90f22813be..c5dd496ad8 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -217,12 +217,13 @@ SELECT active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at, - COUNT(ora.id)::bigint AS member_count + COUNT(DISTINCT ora.id)::bigint AS member_count FROM active_roles LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = @organization_id AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at ORDER BY active_roles.workos_slug; @@ -249,12 +250,13 @@ SELECT active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at, - COUNT(ora.id)::bigint AS member_count + COUNT(DISTINCT ora.id)::bigint AS member_count FROM active_roles LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = @organization_id AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at -- Organization roles shadow global roles if a slug ever collides. ORDER BY active_roles.role_kind DESC @@ -283,12 +285,13 @@ SELECT active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at, - COUNT(ora.id)::bigint AS member_count + COUNT(DISTINCT ora.id)::bigint AS member_count FROM active_roles LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = @organization_id AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at -- Organization roles shadow global roles if an ID ever collides. ORDER BY active_roles.role_kind DESC @@ -317,50 +320,73 @@ ORDER BY ora.workos_user_id; -- name: GetOrganizationRoleAssignmentByWorkosUser :one SELECT - ora.user_id, - ora.workos_user_id, - ora.workos_membership_id, - COALESCE(organization_roles.id, global_roles.id)::uuid AS role_id, - COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, - ora.created_at -FROM organization_role_assignments AS ora + our.user_id, + users.workos_id::text AS workos_user_id, + our.workos_membership_id, + COALESCE(organization_roles.id::text, global_roles.id::text, '')::text AS role_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug, '')::text AS role_slug, + COALESCE(ora.created_at, our.created_at)::timestamptz AS created_at +FROM organization_user_relationships AS our +JOIN users + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT * + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE LEFT JOIN organization_roles ON ora.role_urn = 'role:organization:' || organization_roles.id::text - AND organization_roles.organization_id = ora.organization_id + AND organization_roles.organization_id = our.organization_id AND organization_roles.deleted IS FALSE AND organization_roles.workos_deleted IS FALSE LEFT JOIN global_roles ON ora.role_urn = 'role:global:' || global_roles.id::text AND global_roles.deleted IS FALSE AND global_roles.workos_deleted IS FALSE -WHERE ora.organization_id = @organization_id - AND ora.workos_user_id = @workos_user_id - AND COALESCE(organization_roles.id, global_roles.id) IS NOT NULL -ORDER BY ora.created_at +WHERE our.organization_id = @organization_id + AND users.workos_id = sqlc.arg(workos_user_id)::text + AND users.workos_id IS NOT NULL + AND our.deleted IS FALSE LIMIT 1; -- name: ListOrganizationRoleAssignmentsByWorkosUsers :many SELECT - ora.user_id, - ora.workos_user_id, - ora.workos_membership_id, - COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, - ora.created_at -FROM organization_role_assignments AS ora + our.user_id, + users.workos_id::text AS workos_user_id, + our.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug, '')::text AS role_slug, + COALESCE(ora.created_at, our.created_at)::timestamptz AS created_at +FROM organization_user_relationships AS our +JOIN users + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT * + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE LEFT JOIN organization_roles ON ora.role_urn = 'role:organization:' || organization_roles.id::text - AND organization_roles.organization_id = ora.organization_id + AND organization_roles.organization_id = our.organization_id AND organization_roles.deleted IS FALSE AND organization_roles.workos_deleted IS FALSE LEFT JOIN global_roles ON ora.role_urn = 'role:global:' || global_roles.id::text AND global_roles.deleted IS FALSE AND global_roles.workos_deleted IS FALSE -WHERE ora.organization_id = @organization_id - AND ora.workos_user_id = ANY(@workos_user_ids::text[]) - AND ora.workos_membership_id IS NOT NULL - AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL -ORDER BY ora.workos_user_id, role_slug; +WHERE our.organization_id = @organization_id + AND users.workos_id = ANY(@workos_user_ids::text[]) + AND users.workos_id IS NOT NULL + AND our.workos_membership_id IS NOT NULL + AND our.deleted IS FALSE +ORDER BY users.workos_id, role_slug; -- name: ListAccessMembers :many SELECT @@ -368,22 +394,31 @@ SELECT users.display_name, users.email, users.photo_url, - COALESCE(organization_roles.id, global_roles.id)::uuid AS role_id, - ora.created_at AS joined_at -FROM organization_role_assignments AS ora + COALESCE(organization_roles.id::text, global_roles.id::text, '')::text AS role_id, + COALESCE(ora.created_at, our.created_at)::timestamptz AS joined_at +FROM organization_user_relationships AS our JOIN users - ON users.id = ora.user_id + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT * + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE LEFT JOIN organization_roles ON ora.role_urn = 'role:organization:' || organization_roles.id::text - AND organization_roles.organization_id = ora.organization_id + AND organization_roles.organization_id = our.organization_id AND organization_roles.deleted IS FALSE AND organization_roles.workos_deleted IS FALSE LEFT JOIN global_roles ON ora.role_urn = 'role:global:' || global_roles.id::text AND global_roles.deleted IS FALSE AND global_roles.workos_deleted IS FALSE -WHERE ora.organization_id = @organization_id - AND COALESCE(organization_roles.id, global_roles.id) IS NOT NULL +WHERE our.organization_id = @organization_id + AND our.deleted IS FALSE ORDER BY users.email, users.id; -- name: ListMemberRolePrincipalsByWorkosUser :many diff --git a/server/internal/access/rbac_test.go b/server/internal/access/rbac_test.go index c815b640b7..6dc2158684 100644 --- a/server/internal/access/rbac_test.go +++ b/server/internal/access/rbac_test.go @@ -117,7 +117,7 @@ func TestService_ListMembers_AllowsOrgReadGrant(t *testing.T) { result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Len(t, result.Members, 1) + require.Len(t, result.Members, 2) } func TestService_CreateRole_ForbiddenWithoutOrgAdminGrant(t *testing.T) { diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index bbfef019df..ba8d6d60c8 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -77,12 +77,13 @@ SELECT active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at, - COUNT(ora.id)::bigint AS member_count + COUNT(DISTINCT ora.id)::bigint AS member_count FROM active_roles LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = $1 AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at ORDER BY active_roles.role_kind DESC LIMIT 1 @@ -150,26 +151,37 @@ func (q *Queries) GetGlobalRoleBySlug(ctx context.Context, workosSlug string) (G const getOrganizationRoleAssignmentByWorkosUser = `-- name: GetOrganizationRoleAssignmentByWorkosUser :one SELECT - ora.user_id, - ora.workos_user_id, - ora.workos_membership_id, - COALESCE(organization_roles.id, global_roles.id)::uuid AS role_id, - COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, - ora.created_at -FROM organization_role_assignments AS ora + our.user_id, + users.workos_id::text AS workos_user_id, + our.workos_membership_id, + COALESCE(organization_roles.id::text, global_roles.id::text, '')::text AS role_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug, '')::text AS role_slug, + COALESCE(ora.created_at, our.created_at)::timestamptz AS created_at +FROM organization_user_relationships AS our +JOIN users + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT id, organization_id, workos_user_id, user_id, role_urn, workos_membership_id, workos_updated_at, workos_last_event_id, created_at, updated_at, deleted_at + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE LEFT JOIN organization_roles ON ora.role_urn = 'role:organization:' || organization_roles.id::text - AND organization_roles.organization_id = ora.organization_id + AND organization_roles.organization_id = our.organization_id AND organization_roles.deleted IS FALSE AND organization_roles.workos_deleted IS FALSE LEFT JOIN global_roles ON ora.role_urn = 'role:global:' || global_roles.id::text AND global_roles.deleted IS FALSE AND global_roles.workos_deleted IS FALSE -WHERE ora.organization_id = $1 - AND ora.workos_user_id = $2 - AND COALESCE(organization_roles.id, global_roles.id) IS NOT NULL -ORDER BY ora.created_at +WHERE our.organization_id = $1 + AND users.workos_id = $2::text + AND users.workos_id IS NOT NULL + AND our.deleted IS FALSE LIMIT 1 ` @@ -182,7 +194,7 @@ type GetOrganizationRoleAssignmentByWorkosUserRow struct { UserID pgtype.Text WorkosUserID string WorkosMembershipID pgtype.Text - RoleID uuid.UUID + RoleID string RoleSlug string CreatedAt pgtype.Timestamptz } @@ -224,12 +236,13 @@ SELECT active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at, - COUNT(ora.id)::bigint AS member_count + COUNT(DISTINCT ora.id)::bigint AS member_count FROM active_roles LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = $1 AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at ORDER BY active_roles.role_kind DESC LIMIT 1 @@ -422,22 +435,31 @@ SELECT users.display_name, users.email, users.photo_url, - COALESCE(organization_roles.id, global_roles.id)::uuid AS role_id, - ora.created_at AS joined_at -FROM organization_role_assignments AS ora + COALESCE(organization_roles.id::text, global_roles.id::text, '')::text AS role_id, + COALESCE(ora.created_at, our.created_at)::timestamptz AS joined_at +FROM organization_user_relationships AS our JOIN users - ON users.id = ora.user_id + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT id, organization_id, workos_user_id, user_id, role_urn, workos_membership_id, workos_updated_at, workos_last_event_id, created_at, updated_at, deleted_at + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE LEFT JOIN organization_roles ON ora.role_urn = 'role:organization:' || organization_roles.id::text - AND organization_roles.organization_id = ora.organization_id + AND organization_roles.organization_id = our.organization_id AND organization_roles.deleted IS FALSE AND organization_roles.workos_deleted IS FALSE LEFT JOIN global_roles ON ora.role_urn = 'role:global:' || global_roles.id::text AND global_roles.deleted IS FALSE AND global_roles.workos_deleted IS FALSE -WHERE ora.organization_id = $1 - AND COALESCE(organization_roles.id, global_roles.id) IS NOT NULL +WHERE our.organization_id = $1 + AND our.deleted IS FALSE ORDER BY users.email, users.id ` @@ -446,7 +468,7 @@ type ListAccessMembersRow struct { DisplayName string Email string PhotoUrl pgtype.Text - RoleID uuid.UUID + RoleID string JoinedAt pgtype.Timestamptz } @@ -498,12 +520,13 @@ SELECT active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at, - COUNT(ora.id)::bigint AS member_count + COUNT(DISTINCT ora.id)::bigint AS member_count FROM active_roles LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = $1 AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at ORDER BY active_roles.workos_slug ` @@ -749,26 +772,38 @@ func (q *Queries) ListOrganizationRoleAssignmentsBySlug(ctx context.Context, arg const listOrganizationRoleAssignmentsByWorkosUsers = `-- name: ListOrganizationRoleAssignmentsByWorkosUsers :many SELECT - ora.user_id, - ora.workos_user_id, - ora.workos_membership_id, - COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, - ora.created_at -FROM organization_role_assignments AS ora + our.user_id, + users.workos_id::text AS workos_user_id, + our.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug, '')::text AS role_slug, + COALESCE(ora.created_at, our.created_at)::timestamptz AS created_at +FROM organization_user_relationships AS our +JOIN users + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT id, organization_id, workos_user_id, user_id, role_urn, workos_membership_id, workos_updated_at, workos_last_event_id, created_at, updated_at, deleted_at + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE LEFT JOIN organization_roles ON ora.role_urn = 'role:organization:' || organization_roles.id::text - AND organization_roles.organization_id = ora.organization_id + AND organization_roles.organization_id = our.organization_id AND organization_roles.deleted IS FALSE AND organization_roles.workos_deleted IS FALSE LEFT JOIN global_roles ON ora.role_urn = 'role:global:' || global_roles.id::text AND global_roles.deleted IS FALSE AND global_roles.workos_deleted IS FALSE -WHERE ora.organization_id = $1 - AND ora.workos_user_id = ANY($2::text[]) - AND ora.workos_membership_id IS NOT NULL - AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL -ORDER BY ora.workos_user_id, role_slug +WHERE our.organization_id = $1 + AND users.workos_id = ANY($2::text[]) + AND users.workos_id IS NOT NULL + AND our.workos_membership_id IS NOT NULL + AND our.deleted IS FALSE +ORDER BY users.workos_id, role_slug ` type ListOrganizationRoleAssignmentsByWorkosUsersParams struct { diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index 6099e8c9a7..0e9bbc4278 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -89,7 +89,7 @@ func (r *RoleManager) GetRoleByID(ctx context.Context, gramOrgID, id string) (*g return r.roleViewFromLocalRole(ctx, gramOrgID, role) } -// ListMembers returns locally known organization members with role IDs resolved from local role assignments. +// ListMembers returns locally known organization members and includes a role ID only when a local assignment exists. func (r *RoleManager) ListMembers(ctx context.Context, gramOrgID string) (*gen.ListMembersResult, error) { rows, err := repo.New(r.db).ListAccessMembers(ctx, gramOrgID) if err != nil { @@ -103,7 +103,7 @@ func (r *RoleManager) ListMembers(ctx context.Context, gramOrgID string) (*gen.L Name: conv.Default(row.DisplayName, row.Email), Email: row.Email, PhotoURL: conv.FromPGText[string](row.PhotoUrl), - RoleID: row.RoleID.String(), + RoleID: row.RoleID, JoinedAt: conv.FromPGTimestamptz(row.JoinedAt), }) } @@ -528,7 +528,7 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r Name: memberName, Email: connectedUser.Email, PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), - RoleID: existing.RoleID.String(), + RoleID: existing.RoleID, JoinedAt: conv.FromPGTimestamptz(existing.CreatedAt), }, After: &gen.AccessMember{ diff --git a/server/internal/access/role_manager_test.go b/server/internal/access/role_manager_test.go index be2d88826d..3badd4096a 100644 --- a/server/internal/access/role_manager_test.go +++ b/server/internal/access/role_manager_test.go @@ -72,7 +72,7 @@ func TestRoleManager_MembersAndCounts(t *testing.T) { manager := ti.service.roleMgr members, err := manager.ListMembers(ctx, authCtx.ActiveOrganizationID) require.NoError(t, err) - require.Len(t, members.Members, 2) + require.Len(t, members.Members, 3) rolePrincipals, err := manager.MemberRolePrincipals(ctx, authCtx.ActiveOrganizationID, "user_2") require.NoError(t, err) @@ -92,7 +92,7 @@ func TestRoleManager_MembersAndCounts(t *testing.T) { require.Equal(t, 1, counts["Custom Builder"]) } -func TestRoleManager_AssignMembersToRoleRequiresLocalAssignment(t *testing.T) { +func TestRoleManager_AssignMembersToRoleAcceptsConnectedMemberWithoutAssignment(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -103,7 +103,6 @@ func TestRoleManager_AssignMembersToRoleRequiresLocalAssignment(t *testing.T) { seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "u1@example.com", "User 1", "user_1", "membership_1") assigned, _, err := ti.service.roleMgr.assignMembersToRoleTx(ctx, ti.conn, authCtx.ActiveOrganizationID, "custom-builder", []string{"local_user_1"}) - require.Error(t, err) - require.Equal(t, 0, assigned) - require.Contains(t, err.Error(), "member role assignment not found") + require.NoError(t, err) + require.Equal(t, 1, assigned) } diff --git a/server/internal/access/setup_test.go b/server/internal/access/setup_test.go index 86cfa5abab..b020579d61 100644 --- a/server/internal/access/setup_test.go +++ b/server/internal/access/setup_test.go @@ -169,6 +169,30 @@ func seedRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizatio return row.ID.String() } +func seedGlobalRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, role workos.Role) string { + t.Helper() + + createdAt, err := time.Parse(time.RFC3339, role.CreatedAt) + require.NoError(t, err) + updatedAt, err := time.Parse(time.RFC3339, role.UpdatedAt) + require.NoError(t, err) + + err = accessrepo.New(conn).UpsertGlobalRole(ctx, accessrepo.UpsertGlobalRoleParams{ + WorkosSlug: role.Slug, + WorkosName: role.Name, + WorkosDescription: conv.ToPGTextEmpty(role.Description), + WorkosCreatedAt: conv.ToPGTimestamptz(createdAt), + WorkosUpdatedAt: conv.ToPGTimestamptz(updatedAt), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + + row, err := accessrepo.New(conn).GetGlobalRoleBySlug(ctx, role.Slug) + require.NoError(t, err) + + return row.ID.String() +} + func seedRoleAssignment(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID, userID string, member workos.Member) { t.Helper() diff --git a/server/internal/access/updatememberrole_test.go b/server/internal/access/updatememberrole_test.go index 21886e44a4..ac25969eed 100644 --- a/server/internal/access/updatememberrole_test.go +++ b/server/internal/access/updatememberrole_test.go @@ -78,7 +78,7 @@ func TestService_UpdateMemberRole_WorkOSMembershipNotFound(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) - seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "") _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.Error(t, err) From e4240a5e3cb9af229c060b7d6c4f1cbb9b175e5e Mon Sep 17 00:00:00 2001 From: tgmendes Date: Thu, 21 May 2026 10:34:25 +0100 Subject: [PATCH 10/12] chore: clarify role principal migration --- server/internal/access/impl.go | 5 +- server/internal/access/role_manager.go | 262 ++++++++++++------------- server/internal/authz/engine.go | 2 + server/internal/authz/grants.go | 5 + 4 files changed, 132 insertions(+), 142 deletions(-) diff --git a/server/internal/access/impl.go b/server/internal/access/impl.go index d0cf381174..7f78d6c07b 100644 --- a/server/internal/access/impl.go +++ b/server/internal/access/impl.go @@ -213,8 +213,7 @@ func (s *Service) DeleteRole(ctx context.Context, payload *gen.DeleteRolePayload if err != nil { return err } - deletedRole := deleted.Role - trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(deletedRole.Slug)) + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(deleted.Slug)) return nil } @@ -322,6 +321,8 @@ func (s *Service) ListGrants(ctx context.Context, _ *gen.ListGrantsPayload) (*ge } roleSlugs := make([]string, 0, len(rolePrincipals)) for _, role := range rolePrincipals { + // Effective-grant responses must include grants stored under either the + // canonical role URN or the legacy role slug during the migration. rolePrincipalURNs, err := authz.RolePrincipals(role.RoleSlug, role.PrincipalUrn) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "build role principals").Log(ctx, logger) diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index 0e9bbc4278..ac223393a2 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -69,7 +69,16 @@ func (r *RoleManager) ListRoles(ctx context.Context, gramOrgID string) (*gen.Lis roles := make([]*gen.Role, 0, len(rows)) for _, row := range rows { - role, err := r.roleViewFromLocalRole(ctx, gramOrgID, localRoleFromActiveRow(row)) + role, err := r.roleViewFromLocalRole(ctx, gramOrgID, localRole{ + ID: row.ID.String(), + PrincipalURN: row.RoleUrn, + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), + }) if err != nil { return nil, err } @@ -137,21 +146,30 @@ func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID str } defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) - now := time.Now().UTC().Format(time.RFC3339) - createdRow, err := repo.New(tx).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ - ID: "", - Type: "", - Name: payload.Name, - Slug: roleSlug, - Description: payload.Description, - CreatedAt: now, - UpdatedAt: now, - })) + now := time.Now().UTC() + createdRow, err := repo.New(tx).UpsertOrganizationRole(ctx, repo.UpsertOrganizationRoleParams{ + OrganizationID: gramOrgID, + WorkosSlug: roleSlug, + WorkosName: payload.Name, + WorkosDescription: conv.ToPGTextEmpty(payload.Description), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role record").Log(ctx, r.logger) } - createdRole := localRoleFromUpsertRow(createdRow) + createdRole := localRole{ + ID: createdRow.ID.String(), + PrincipalURN: createdRow.RoleUrn, + Name: createdRow.WorkosName, + Slug: createdRow.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](createdRow.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(createdRow.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(createdRow.WorkosUpdatedAt), + MemberCount: int(createdRow.MemberCount), + } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleID(createdRole.ID)) if _, err := authz.SyncGrantsTx(ctx, tx, gramOrgID, roleSlug, createdRole.PrincipalURN, roleGrantPayloads(payload.Grants)); err != nil { @@ -273,20 +291,29 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str localRecord.Description = *payload.Description } localRecord.UpdatedAt = time.Now().UTC().Format(time.RFC3339) - updatedRow, err := repo.New(tx).UpsertOrganizationRole(ctx, organizationRoleParams(gramOrgID, workos.Role{ - ID: "", - Type: "", - Name: localRecord.Name, - Slug: localRecord.Slug, - Description: localRecord.Description, - CreatedAt: localRecord.CreatedAt, - UpdatedAt: localRecord.UpdatedAt, - })) + updatedRow, err := repo.New(tx).UpsertOrganizationRole(ctx, repo.UpsertOrganizationRoleParams{ + OrganizationID: gramOrgID, + WorkosSlug: localRecord.Slug, + WorkosName: localRecord.Name, + WorkosDescription: conv.ToPGTextEmpty(localRecord.Description), + WorkosCreatedAt: conv.ToPGTimestamptz(workosTimeOrNow(localRecord.CreatedAt)), + WorkosUpdatedAt: conv.ToPGTimestamptz(workosTimeOrNow(localRecord.UpdatedAt)), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role record").Log(ctx, r.logger) } - updatedRole = localRoleFromUpsertRow(updatedRow) + updatedRole = localRole{ + ID: updatedRow.ID.String(), + PrincipalURN: updatedRow.RoleUrn, + Name: updatedRow.WorkosName, + Slug: updatedRow.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](updatedRow.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(updatedRow.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(updatedRow.WorkosUpdatedAt), + MemberCount: int(updatedRow.MemberCount), + } if payload.Grants != nil { syncedGrants, err := authz.SyncGrantsTx(ctx, tx, gramOrgID, currentRole.Slug, currentRole.PrincipalURN, roleGrantPayloads(payload.Grants)) @@ -325,7 +352,17 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str } } - updatedRoleView := roleViewFromLocalRoleAndGrants(updatedRole, existingRole.Grants) + updatedRoleView := &gen.Role{ + ID: updatedRole.ID, + Name: updatedRole.Name, + Slug: updatedRole.Slug, + Description: updatedRole.Description, + IsSystem: isSystemRole(updatedRole.Slug), + Grants: existingRole.Grants, + MemberCount: updatedRole.MemberCount, + CreatedAt: conv.Default(updatedRole.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), + UpdatedAt: conv.Default(updatedRole.UpdatedAt, time.Time{}.UTC().Format(time.RFC3339)), + } if updatedGrants != nil { updatedRoleView.Grants = updatedGrants } @@ -353,23 +390,19 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str return roleUpdateResult{Before: existingRole, After: updatedRoleView, Role: updatedRole}, nil } -type roleDeleteResult struct { - Role localRole -} - // DeleteRole deletes a custom local role, reassignment records, grants, and audit entry atomically, then best-effort syncs WorkOS after commit. -func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, roleID string, actor accessAuditActor) (roleDeleteResult, error) { +func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, roleID string, actor accessAuditActor) (localRole, error) { currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, roleID) if err != nil { - return roleDeleteResult{}, err + return localRole{}, err } if isSystemRole(currentRole.Slug) { - return roleDeleteResult{}, oops.E(oops.CodeBadRequest, nil, "system roles cannot be deleted").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeBadRequest, nil, "system roles cannot be deleted").Log(ctx, r.logger) } tx, err := r.db.Begin(ctx) if err != nil { - return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "begin role transaction").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeUnexpected, err, "begin role transaction").Log(ctx, r.logger) } defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) @@ -378,7 +411,7 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro WorkosRoleSlug: currentRole.Slug, }) if err != nil { - return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) @@ -386,14 +419,22 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro for _, row := range rows { membershipID := conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) if row.WorkosUserID != "" { - replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, row.WorkosUserID, authz.SystemRoleMember, "", membershipID)) + replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, repo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: gramOrgID, + WorkosUserID: row.WorkosUserID, + WorkosRoleSlug: authz.SystemRoleMember, + UserID: conv.ToPGTextEmpty(""), + WorkosMembershipID: conv.ToPGTextEmpty(membershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) } if replaced == 0 { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - return roleDeleteResult{}, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) } } workosSyncs = append(workosSyncs, func(ctx context.Context) { @@ -414,14 +455,14 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro }) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) - return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "mark local role record deleted").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeUnexpected, err, "mark local role record deleted").Log(ctx, r.logger) } if deletedCount == 0 { - return roleDeleteResult{}, oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, r.logger) } if err := authz.DeleteRoleGrants(ctx, repo.New(tx), gramOrgID, currentRole.Slug, currentRole.PrincipalURN); err != nil { - return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, r.logger) } if err := r.audit.LogAccessRoleDelete(ctx, tx, audit.LogAccessRoleDeleteEvent{ @@ -433,11 +474,11 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro RoleName: currentRole.Name, RoleSlug: currentRole.Slug, }); err != nil { - return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "log access role deletion").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeUnexpected, err, "log access role deletion").Log(ctx, r.logger) } if err := tx.Commit(ctx); err != nil { - return roleDeleteResult{}, oops.E(oops.CodeUnexpected, err, "commit role transaction").Log(ctx, r.logger) + return localRole{}, oops.E(oops.CodeUnexpected, err, "commit role transaction").Log(ctx, r.logger) } workosSyncs = append(workosSyncs, func(ctx context.Context) { @@ -451,7 +492,7 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro r.runWorkOSSyncs(ctx, workosSyncs) - return roleDeleteResult{Role: currentRole}, nil + return currentRole, nil } type memberRoleUpdateContext struct { @@ -506,7 +547,15 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r } if existing.WorkosUserID != "" && role.Slug != "" { - replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, existing.WorkosUserID, role.Slug, connectedUser.ID, membershipID)) + replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, repo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: gramOrgID, + WorkosUserID: existing.WorkosUserID, + WorkosRoleSlug: role.Slug, + UserID: conv.ToPGTextEmpty(connectedUser.ID), + WorkosMembershipID: conv.ToPGTextEmpty(membershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) @@ -615,7 +664,16 @@ func (r *RoleManager) getLocalRoleByIDTx(ctx context.Context, dbtx repo.DBTX, gr } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) - return localRoleFromRoleRow(row), nil + return localRole{ + ID: row.ID.String(), + PrincipalURN: row.RoleUrn, + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), + }, nil } func (r *RoleManager) getLocalRoleBySlugTx(ctx context.Context, dbtx repo.DBTX, gramOrgID, slug string) (localRole, error) { @@ -631,7 +689,16 @@ func (r *RoleManager) getLocalRoleBySlugTx(ctx context.Context, dbtx repo.DBTX, } trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) - return localRoleFromSlugRow(row), nil + return localRole{ + ID: row.ID.String(), + PrincipalURN: row.RoleUrn, + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), + }, nil } type memberAssignmentTarget struct { @@ -747,7 +814,15 @@ func (r *RoleManager) assignMembersToRoleTx(ctx context.Context, dbtx repo.DBTX, workosSyncs := make([]workosSync, 0, len(targets)) for _, target := range targets { if target.WorkosUserID != "" && roleSlug != "" { - replaced, err := repo.New(dbtx).ReplaceOrganizationRoleAssignment(ctx, replaceRoleAssignmentParams(gramOrgID, target.WorkosUserID, roleSlug, target.UserID, target.MembershipID)) + replaced, err := repo.New(dbtx).ReplaceOrganizationRoleAssignment(ctx, repo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: gramOrgID, + WorkosUserID: target.WorkosUserID, + WorkosRoleSlug: roleSlug, + UserID: conv.ToPGTextEmpty(target.UserID), + WorkosMembershipID: conv.ToPGTextEmpty(target.MembershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) return 0, nil, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) @@ -825,6 +900,8 @@ func retryWorkOSError(err error) bool { // roleViewFromLocalRole converts a local role record into the public API role view and attaches local grants. func (r *RoleManager) roleViewFromLocalRole(ctx context.Context, organizationID string, role localRole) (*gen.Role, error) { + // Role grant reads intentionally include both canonical role URNs and + // legacy role slugs until old principal_grants rows are backfilled. grants, err := authz.GrantsForRole(ctx, r.logger, r.db, organizationID, role.Slug, role.PrincipalURN) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "load role grants").Log(ctx, r.logger) @@ -847,101 +924,6 @@ func (r *RoleManager) roleViewFromLocalRole(ctx context.Context, organizationID }, nil } -func roleViewFromLocalRoleAndGrants(role localRole, grants []*gen.RoleGrant) *gen.Role { - return &gen.Role{ - ID: role.ID, - Name: role.Name, - Slug: role.Slug, - Description: role.Description, - IsSystem: isSystemRole(role.Slug), - Grants: grants, - MemberCount: role.MemberCount, - CreatedAt: conv.Default(role.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), - UpdatedAt: conv.Default(role.UpdatedAt, time.Time{}.UTC().Format(time.RFC3339)), - } -} - -// localRoleFromActiveRow converts a sqlc active-role row into the manager's internal local role record shape. -func localRoleFromActiveRow(row repo.ListActiveOrganizationRolesRow) localRole { - return localRole{ - ID: row.ID.String(), - PrincipalURN: row.RoleUrn, - Name: row.WorkosName, - Slug: row.WorkosSlug, - Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), - CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), - UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), - MemberCount: int(row.MemberCount), - } -} - -// localRoleFromRoleRow converts a sqlc role lookup row into the manager's internal local role record shape. -func localRoleFromRoleRow(row repo.GetOrganizationRoleByIDRow) localRole { - return localRole{ - ID: row.ID.String(), - PrincipalURN: row.RoleUrn, - Name: row.WorkosName, - Slug: row.WorkosSlug, - Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), - CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), - UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), - MemberCount: int(row.MemberCount), - } -} - -func localRoleFromUpsertRow(row repo.UpsertOrganizationRoleRow) localRole { - return localRole{ - ID: row.ID.String(), - PrincipalURN: row.RoleUrn, - Name: row.WorkosName, - Slug: row.WorkosSlug, - Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), - CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), - UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), - MemberCount: int(row.MemberCount), - } -} - -// localRoleFromSlugRow converts a sqlc role slug lookup row into the manager's internal local role record shape. -func localRoleFromSlugRow(row repo.GetActiveOrganizationRoleBySlugRow) localRole { - return localRole{ - ID: row.ID.String(), - PrincipalURN: row.RoleUrn, - Name: row.WorkosName, - Slug: row.WorkosSlug, - Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), - CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), - UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), - MemberCount: int(row.MemberCount), - } -} - -// organizationRoleParams builds the SQL parameters for storing the authoritative local role record. -func organizationRoleParams(gramOrgID string, role workos.Role) repo.UpsertOrganizationRoleParams { - return repo.UpsertOrganizationRoleParams{ - OrganizationID: gramOrgID, - WorkosSlug: role.Slug, - WorkosName: role.Name, - WorkosDescription: conv.ToPGTextEmpty(role.Description), - WorkosCreatedAt: conv.ToPGTimestamptz(workosTimeOrNow(role.CreatedAt)), - WorkosUpdatedAt: conv.ToPGTimestamptz(workosTimeOrNow(role.UpdatedAt)), - WorkosLastEventID: conv.ToPGTextEmpty(""), - } -} - -// replaceRoleAssignmentParams builds SQL parameters for storing the authoritative local role assignment. -func replaceRoleAssignmentParams(gramOrgID, workosUserID, roleSlug, userID, membershipID string) repo.ReplaceOrganizationRoleAssignmentParams { - return repo.ReplaceOrganizationRoleAssignmentParams{ - OrganizationID: gramOrgID, - WorkosUserID: workosUserID, - WorkosRoleSlug: roleSlug, - UserID: conv.ToPGTextEmpty(userID), - WorkosMembershipID: conv.ToPGTextEmpty(membershipID), - WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), - WorkosLastEventID: conv.ToPGTextEmpty(""), - } -} - // workosTimeOrNow parses a WorkOS RFC3339 timestamp or returns the current UTC time when WorkOS omits or malforms it. func workosTimeOrNow(value string) time.Time { if value == "" { diff --git a/server/internal/authz/engine.go b/server/internal/authz/engine.go index 6a475f921b..2d66071eee 100644 --- a/server/internal/authz/engine.go +++ b/server/internal/authz/engine.go @@ -144,6 +144,8 @@ func (e *Engine) PrepareContext(ctx context.Context) (context.Context, error) { return ctx, fmt.Errorf("resolve role slugs: %w", err) } for _, role := range rolePrincipals { + // Load grants for both canonical role:: principals and + // legacy role: principals while existing grant rows are backfilled. rolePrincipalURNs, err := RolePrincipals(role.RoleSlug, role.PrincipalUrn) if err != nil { return ctx, fmt.Errorf("build role principals: %w", err) diff --git a/server/internal/authz/grants.go b/server/internal/authz/grants.go index 2f34428e51..ca4f30f2eb 100644 --- a/server/internal/authz/grants.go +++ b/server/internal/authz/grants.go @@ -146,6 +146,9 @@ func SyncGrantsTx(ctx context.Context, dbtx repo.DBTX, orgID string, roleSlug st } q := repo.New(dbtx) + // During the role-principal migration, replace grants for both the new + // role:: principal and the legacy role: principal. New + // writes below only insert the canonical URN form. if err := DeleteRoleGrants(ctx, q, orgID, roleSlug, rolePrincipalURN); err != nil { return nil, err } @@ -241,6 +244,8 @@ func DeleteRoleGrants(ctx context.Context, q *repo.Queries, orgID, roleSlug, rol if legacyPrincipalURN.String() == principalURN.String() { return nil } + // Legacy grants were keyed by role slug (role:). Keep deleting that + // principal alongside the canonical role URN until the backfill is complete. if _, err := q.DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ OrganizationID: orgID, PrincipalUrn: legacyPrincipalURN, From c9839aea18cb1a87c5c37a226422303c8eed6da6 Mon Sep 17 00:00:00 2001 From: tgmendes Date: Thu, 21 May 2026 11:01:44 +0100 Subject: [PATCH 11/12] chore: harden local role state sync --- server/internal/access/createrole_test.go | 1 - server/internal/access/impl.go | 4 - server/internal/access/queries.sql | 47 +++++--- server/internal/access/repo/models.go | 14 +++ server/internal/access/repo/queries.sql.go | 96 ++++++++++++---- server/internal/access/role_manager.go | 19 ++-- server/internal/access/role_manager_test.go | 105 ++++++++++++++++++ .../internal/access/updatememberrole_test.go | 2 +- server/internal/authz/grants.go | 36 ++++++ server/internal/authz/load_test.go | 21 ++++ 10 files changed, 296 insertions(+), 49 deletions(-) diff --git a/server/internal/access/createrole_test.go b/server/internal/access/createrole_test.go index 079e50ff7d..b1d9215636 100644 --- a/server/internal/access/createrole_test.go +++ b/server/internal/access/createrole_test.go @@ -233,7 +233,6 @@ func TestService_CreateRole_LocalRoleWriteFailureDoesNotAssignMembers(t *testing _, err = ti.service.roleMgr.CreateRole(ctx, authCtx.ActiveOrganizationID, mockidp.MockOrgID, accessAuditActor{ Principal: urn.NewPrincipal(urn.PrincipalTypeUser, authCtx.UserID), DisplayName: authCtx.Email, - Slug: nil, }, &gen.CreateRolePayload{ Name: "Broken Builder", Description: "Will fail local write", diff --git a/server/internal/access/impl.go b/server/internal/access/impl.go index 7f78d6c07b..a5a9fef1e8 100644 --- a/server/internal/access/impl.go +++ b/server/internal/access/impl.go @@ -155,7 +155,6 @@ func (s *Service) CreateRole(ctx context.Context, payload *gen.CreateRolePayload created, err := s.roleMgr.CreateRole(ctx, ac.ActiveOrganizationID, workosOrgID, accessAuditActor{ Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), DisplayName: ac.Email, - Slug: nil, }, payload) if err != nil { return nil, err @@ -181,7 +180,6 @@ func (s *Service) UpdateRole(ctx context.Context, payload *gen.UpdateRolePayload updated, err := s.roleMgr.UpdateRole(ctx, ac.ActiveOrganizationID, workosOrgID, accessAuditActor{ Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), DisplayName: ac.Email, - Slug: nil, }, payload) if err != nil { return nil, err @@ -208,7 +206,6 @@ func (s *Service) DeleteRole(ctx context.Context, payload *gen.DeleteRolePayload deleted, err := s.roleMgr.DeleteRole(ctx, ac.ActiveOrganizationID, workosOrgID, payload.ID, accessAuditActor{ Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), DisplayName: ac.Email, - Slug: nil, }) if err != nil { return err @@ -363,7 +360,6 @@ func (s *Service) UpdateMemberRole(ctx context.Context, payload *gen.UpdateMembe memberUpdate, err := s.roleMgr.UpdateMemberRole(ctx, ac.ActiveOrganizationID, payload.UserID, payload.RoleID, accessAuditActor{ Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), DisplayName: ac.Email, - Slug: nil, }) if err != nil { return nil, err diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index c5dd496ad8..5387dae4f7 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -73,9 +73,8 @@ WHERE deleted_at IS NULL ORDER BY workos_slug; -- name: UpsertGlobalRole :exec --- Upsert an environment-level WorkOS role. Caller must have already passed --- the row through ShouldProcessEvent. Resurrects a previously soft-deleted --- role on conflict. +-- Upsert an environment-level role. WorkOS sync callers pass an event ID; +-- local/bootstrap callers pass NULL so an existing WorkOS event cursor is preserved. INSERT INTO global_roles ( workos_slug, workos_name, @@ -95,7 +94,7 @@ ON CONFLICT (workos_slug) DO UPDATE SET workos_name = EXCLUDED.workos_name, workos_description = EXCLUDED.workos_description, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, global_roles.workos_last_event_id), deleted_at = NULL, workos_deleted_at = NULL, updated_at = clock_timestamp(); @@ -123,9 +122,8 @@ WHERE organization_id = @organization_id ORDER BY workos_slug; -- name: UpsertOrganizationRole :one --- Upsert an org-scoped WorkOS role. Caller must have already passed the row --- through ShouldProcessEvent. Resurrects a previously soft-deleted role on --- conflict. +-- Upsert an org-scoped role. WorkOS sync callers pass an event ID; local role +-- lifecycle callers pass NULL so an existing WorkOS event cursor is preserved. WITH upserted AS ( INSERT INTO organization_roles ( organization_id, @@ -148,7 +146,7 @@ ON CONFLICT (organization_id, workos_slug) DO UPDATE SET workos_name = EXCLUDED.workos_name, workos_description = EXCLUDED.workos_description, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_roles.workos_last_event_id), deleted_at = NULL, workos_deleted_at = NULL, updated_at = clock_timestamp() @@ -175,6 +173,7 @@ LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = upserted.organization_id AND ora.role_urn = 'role:organization:' || upserted.id::text AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL GROUP BY upserted.id, upserted.workos_slug, upserted.workos_name, upserted.workos_description, upserted.workos_created_at, upserted.workos_updated_at; -- name: MarkOrganizationRoleDeleted :execrows @@ -189,8 +188,7 @@ WHERE organization_id = @organization_id -- name: MarkOrganizationRoleDeletedLocally :execrows UPDATE organization_roles -SET workos_last_event_id = @workos_last_event_id, - deleted_at = clock_timestamp(), +SET deleted_at = clock_timestamp(), updated_at = clock_timestamp() WHERE organization_id = @organization_id AND workos_slug = @workos_slug @@ -316,6 +314,7 @@ LEFT JOIN global_roles AND global_roles.workos_deleted IS FALSE WHERE ora.organization_id = @organization_id AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) = @workos_role_slug + AND ora.deleted_at IS NULL ORDER BY ora.workos_user_id; -- name: GetOrganizationRoleAssignmentByWorkosUser :one @@ -438,8 +437,27 @@ LEFT JOIN global_roles WHERE ora.organization_id = @organization_id AND ora.workos_user_id = @workos_user_id AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL + AND ora.deleted_at IS NULL ORDER BY role_slug; +-- name: ListOrganizationRoleAssignmentRecordsByWorkosUser :many +SELECT + id, + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id, + created_at, + updated_at, + deleted_at +FROM organization_role_assignments +WHERE organization_id = @organization_id + AND workos_user_id = @workos_user_id +ORDER BY created_at; + -- name: UpsertOrganizationRoleAssignment :execrows WITH input_role_urn AS ( SELECT 'role:organization:' || id::text AS role_urn @@ -477,7 +495,7 @@ ON CONFLICT (organization_id, workos_user_id, role_urn) WHERE deleted_at IS NULL user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), workos_membership_id = EXCLUDED.workos_membership_id, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_role_assignments.workos_last_event_id), updated_at = clock_timestamp(); -- name: ReplaceOrganizationRoleAssignment :one @@ -518,16 +536,19 @@ upserted AS ( user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), workos_membership_id = EXCLUDED.workos_membership_id, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_role_assignments.workos_last_event_id), updated_at = clock_timestamp() RETURNING role_urn ), deleted AS ( -DELETE FROM organization_role_assignments +UPDATE organization_role_assignments +SET deleted_at = clock_timestamp(), + updated_at = clock_timestamp() WHERE organization_role_assignments.organization_id = @organization_id AND organization_role_assignments.workos_user_id = @workos_user_id AND EXISTS (SELECT 1 FROM upserted) AND organization_role_assignments.role_urn NOT IN (SELECT role_urn FROM upserted) + AND organization_role_assignments.deleted_at IS NULL RETURNING 1 ) SELECT COUNT(*)::bigint FROM upserted; diff --git a/server/internal/access/repo/models.go b/server/internal/access/repo/models.go index 8faea64552..376864a673 100644 --- a/server/internal/access/repo/models.go +++ b/server/internal/access/repo/models.go @@ -61,3 +61,17 @@ type OrganizationRole struct { DeletedAt pgtype.Timestamptz Deleted bool } + +type OrganizationRoleAssignment struct { + ID uuid.UUID + OrganizationID string + WorkosUserID string + UserID pgtype.Text + RoleUrn string + WorkosMembershipID pgtype.Text + WorkosUpdatedAt pgtype.Timestamptz + WorkosLastEventID pgtype.Text + CreatedAt pgtype.Timestamptz + UpdatedAt pgtype.Timestamptz + DeletedAt pgtype.Timestamptz +} diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index ba8d6d60c8..2eb53d281d 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -676,6 +676,7 @@ LEFT JOIN global_roles WHERE ora.organization_id = $1 AND ora.workos_user_id = $2 AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL + AND ora.deleted_at IS NULL ORDER BY role_slug ` @@ -709,6 +710,62 @@ func (q *Queries) ListMemberRolePrincipalsByWorkosUser(ctx context.Context, arg return items, nil } +const listOrganizationRoleAssignmentRecordsByWorkosUser = `-- name: ListOrganizationRoleAssignmentRecordsByWorkosUser :many +SELECT + id, + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id, + created_at, + updated_at, + deleted_at +FROM organization_role_assignments +WHERE organization_id = $1 + AND workos_user_id = $2 +ORDER BY created_at +` + +type ListOrganizationRoleAssignmentRecordsByWorkosUserParams struct { + OrganizationID string + WorkosUserID string +} + +func (q *Queries) ListOrganizationRoleAssignmentRecordsByWorkosUser(ctx context.Context, arg ListOrganizationRoleAssignmentRecordsByWorkosUserParams) ([]OrganizationRoleAssignment, error) { + rows, err := q.db.Query(ctx, listOrganizationRoleAssignmentRecordsByWorkosUser, arg.OrganizationID, arg.WorkosUserID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []OrganizationRoleAssignment + for rows.Next() { + var i OrganizationRoleAssignment + if err := rows.Scan( + &i.ID, + &i.OrganizationID, + &i.WorkosUserID, + &i.UserID, + &i.RoleUrn, + &i.WorkosMembershipID, + &i.WorkosUpdatedAt, + &i.WorkosLastEventID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listOrganizationRoleAssignmentsBySlug = `-- name: ListOrganizationRoleAssignmentsBySlug :many SELECT ora.user_id, @@ -728,6 +785,7 @@ LEFT JOIN global_roles AND global_roles.workos_deleted IS FALSE WHERE ora.organization_id = $1 AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) = $2 + AND ora.deleted_at IS NULL ORDER BY ora.workos_user_id ` @@ -1004,22 +1062,20 @@ func (q *Queries) MarkOrganizationRoleDeleted(ctx context.Context, arg MarkOrgan const markOrganizationRoleDeletedLocally = `-- name: MarkOrganizationRoleDeletedLocally :execrows UPDATE organization_roles -SET workos_last_event_id = $1, - deleted_at = clock_timestamp(), +SET deleted_at = clock_timestamp(), updated_at = clock_timestamp() -WHERE organization_id = $2 - AND workos_slug = $3 +WHERE organization_id = $1 + AND workos_slug = $2 AND deleted_at IS NULL ` type MarkOrganizationRoleDeletedLocallyParams struct { - WorkosLastEventID pgtype.Text - OrganizationID string - WorkosSlug string + OrganizationID string + WorkosSlug string } func (q *Queries) MarkOrganizationRoleDeletedLocally(ctx context.Context, arg MarkOrganizationRoleDeletedLocallyParams) (int64, error) { - result, err := q.db.Exec(ctx, markOrganizationRoleDeletedLocally, arg.WorkosLastEventID, arg.OrganizationID, arg.WorkosSlug) + result, err := q.db.Exec(ctx, markOrganizationRoleDeletedLocally, arg.OrganizationID, arg.WorkosSlug) if err != nil { return 0, err } @@ -1064,16 +1120,19 @@ upserted AS ( user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), workos_membership_id = EXCLUDED.workos_membership_id, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_role_assignments.workos_last_event_id), updated_at = clock_timestamp() RETURNING role_urn ), deleted AS ( -DELETE FROM organization_role_assignments +UPDATE organization_role_assignments +SET deleted_at = clock_timestamp(), + updated_at = clock_timestamp() WHERE organization_role_assignments.organization_id = $1 AND organization_role_assignments.workos_user_id = $3 AND EXISTS (SELECT 1 FROM upserted) AND organization_role_assignments.role_urn NOT IN (SELECT role_urn FROM upserted) + AND organization_role_assignments.deleted_at IS NULL RETURNING 1 ) SELECT COUNT(*)::bigint FROM upserted @@ -1124,7 +1183,7 @@ ON CONFLICT (workos_slug) DO UPDATE SET workos_name = EXCLUDED.workos_name, workos_description = EXCLUDED.workos_description, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, global_roles.workos_last_event_id), deleted_at = NULL, workos_deleted_at = NULL, updated_at = clock_timestamp() @@ -1139,9 +1198,8 @@ type UpsertGlobalRoleParams struct { WorkosLastEventID pgtype.Text } -// Upsert an environment-level WorkOS role. Caller must have already passed -// the row through ShouldProcessEvent. Resurrects a previously soft-deleted -// role on conflict. +// Upsert an environment-level role. WorkOS sync callers pass an event ID; +// local/bootstrap callers pass NULL so an existing WorkOS event cursor is preserved. func (q *Queries) UpsertGlobalRole(ctx context.Context, arg UpsertGlobalRoleParams) error { _, err := q.db.Exec(ctx, upsertGlobalRole, arg.WorkosSlug, @@ -1177,7 +1235,7 @@ ON CONFLICT (organization_id, workos_slug) DO UPDATE SET workos_name = EXCLUDED.workos_name, workos_description = EXCLUDED.workos_description, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_roles.workos_last_event_id), deleted_at = NULL, workos_deleted_at = NULL, updated_at = clock_timestamp() @@ -1204,6 +1262,7 @@ LEFT JOIN organization_role_assignments AS ora ON ora.organization_id = upserted.organization_id AND ora.role_urn = 'role:organization:' || upserted.id::text AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL GROUP BY upserted.id, upserted.workos_slug, upserted.workos_name, upserted.workos_description, upserted.workos_created_at, upserted.workos_updated_at ` @@ -1228,9 +1287,8 @@ type UpsertOrganizationRoleRow struct { MemberCount int64 } -// Upsert an org-scoped WorkOS role. Caller must have already passed the row -// through ShouldProcessEvent. Resurrects a previously soft-deleted role on -// conflict. +// Upsert an org-scoped role. WorkOS sync callers pass an event ID; local role +// lifecycle callers pass NULL so an existing WorkOS event cursor is preserved. func (q *Queries) UpsertOrganizationRole(ctx context.Context, arg UpsertOrganizationRoleParams) (UpsertOrganizationRoleRow, error) { row := q.db.QueryRow(ctx, upsertOrganizationRole, arg.OrganizationID, @@ -1292,7 +1350,7 @@ ON CONFLICT (organization_id, workos_user_id, role_urn) WHERE deleted_at IS NULL user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), workos_membership_id = EXCLUDED.workos_membership_id, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_role_assignments.workos_last_event_id), updated_at = clock_timestamp() ` diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index ac223393a2..d93a9b87a9 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -130,7 +130,6 @@ type workosSync func(context.Context) type accessAuditActor struct { Principal urn.Principal DisplayName *string - Slug *string } // CreateRole creates the local role, grants, optional assignments, and audit entry atomically, then best-effort syncs WorkOS after commit. @@ -210,7 +209,7 @@ func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID str OrganizationID: gramOrgID, Actor: actor.Principal, ActorDisplayName: actor.DisplayName, - ActorSlug: actor.Slug, + ActorSlug: nil, RoleID: createdRole.ID, RoleName: createdRole.Name, RoleSlug: createdRole.Slug, @@ -371,7 +370,7 @@ func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID str OrganizationID: gramOrgID, Actor: actor.Principal, ActorDisplayName: actor.DisplayName, - ActorSlug: actor.Slug, + ActorSlug: nil, RoleID: updatedRole.ID, RoleName: updatedRoleView.Name, RoleSlug: updatedRole.Slug, @@ -449,9 +448,8 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro } deletedCount, err := repo.New(tx).MarkOrganizationRoleDeletedLocally(ctx, repo.MarkOrganizationRoleDeletedLocallyParams{ - OrganizationID: gramOrgID, - WorkosSlug: currentRole.Slug, - WorkosLastEventID: conv.ToPGTextEmpty(""), + OrganizationID: gramOrgID, + WorkosSlug: currentRole.Slug, }) if err != nil { trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) @@ -469,7 +467,7 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro OrganizationID: gramOrgID, Actor: actor.Principal, ActorDisplayName: actor.DisplayName, - ActorSlug: actor.Slug, + ActorSlug: nil, RoleID: currentRole.ID, RoleName: currentRole.Name, RoleSlug: currentRole.Slug, @@ -542,8 +540,7 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r membershipID := conv.FromPGTextOrEmpty[string](existing.WorkosMembershipID) if membershipID == "" { - // WorkOS sync must attach membership IDs before role changes can be propagated upstream. - return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member not found").Log(ctx, r.logger) + return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member is missing local WorkOS membership linkage").Log(ctx, r.logger) } if existing.WorkosUserID != "" && role.Slug != "" { @@ -594,7 +591,7 @@ func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, r OrganizationID: gramOrgID, Actor: actor.Principal, ActorDisplayName: actor.DisplayName, - ActorSlug: actor.Slug, + ActorSlug: nil, MemberID: result.UserID, MemberName: result.After.Name, MemberEmail: result.After.Email, @@ -798,7 +795,7 @@ func (r *RoleManager) memberAssignmentTargetsTx(ctx context.Context, dbtx repo.D }) } if len(resolvedInputs) != len(requestedInputs) { - return nil, oops.E(oops.CodeBadRequest, nil, "member role assignment not found; wait for WorkOS sync to complete").Log(ctx, r.logger) + return nil, oops.E(oops.CodeBadRequest, nil, "member is missing local WorkOS membership linkage").Log(ctx, r.logger) } return targets, nil diff --git a/server/internal/access/role_manager_test.go b/server/internal/access/role_manager_test.go index 3badd4096a..1dcc30dd3f 100644 --- a/server/internal/access/role_manager_test.go +++ b/server/internal/access/role_manager_test.go @@ -2,11 +2,14 @@ package access import ( "testing" + "time" mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" "github.com/stretchr/testify/require" + accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/contextvalues" + "github.com/speakeasy-api/gram/server/internal/conv" ) func TestRoleManager_ListRoles(t *testing.T) { @@ -106,3 +109,105 @@ func TestRoleManager_AssignMembersToRoleAcceptsConnectedMemberWithoutAssignment( require.NoError(t, err) require.Equal(t, 1, assigned) } + +func TestRoleManager_LocalRoleWritePreservesWorkOSLastEventID(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + now := time.Now().UTC() + _, err := accessrepo.New(ti.conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosSlug: "custom-builder", + WorkosName: "Custom Builder", + WorkosDescription: conv.ToPGTextEmpty("Before"), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGText("event_01SEED"), + }) + require.NoError(t, err) + + _, err = accessrepo.New(ti.conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosSlug: "custom-builder", + WorkosName: "Custom Builder", + WorkosDescription: conv.ToPGTextEmpty("After"), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now.Add(time.Minute)), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + + row, err := accessrepo.New(ti.conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosSlug: "custom-builder", + }) + require.NoError(t, err) + require.Equal(t, "event_01SEED", row.WorkosLastEventID.String) + + replaced, err := accessrepo.New(ti.conn).ReplaceOrganizationRoleAssignment(ctx, accessrepo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosUserID: "user_1", + WorkosRoleSlug: "custom-builder", + UserID: conv.ToPGTextEmpty(""), + WorkosMembershipID: conv.ToPGTextEmpty("membership_1"), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGText("event_02SEED"), + }) + require.NoError(t, err) + require.Equal(t, int64(1), replaced) + + replaced, err = accessrepo.New(ti.conn).ReplaceOrganizationRoleAssignment(ctx, accessrepo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosUserID: "user_1", + WorkosRoleSlug: "custom-builder", + UserID: conv.ToPGTextEmpty(""), + WorkosMembershipID: conv.ToPGTextEmpty("membership_1"), + WorkosUpdatedAt: conv.ToPGTimestamptz(now.Add(time.Minute)), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + require.Equal(t, int64(1), replaced) + + assignments, err := accessrepo.New(ti.conn).ListOrganizationRoleAssignmentRecordsByWorkosUser(ctx, accessrepo.ListOrganizationRoleAssignmentRecordsByWorkosUserParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosUserID: "user_1", + }) + require.NoError(t, err) + require.Len(t, assignments, 1) + require.False(t, assignments[0].DeletedAt.Valid) + require.Equal(t, "event_02SEED", assignments[0].WorkosLastEventID.String) +} + +func TestRoleManager_ReplaceRoleAssignmentSoftDeletesPreviousRole(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", "member")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "u1@example.com", "User 1", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "member")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) + + assignments, err := accessrepo.New(ti.conn).ListOrganizationRoleAssignmentRecordsByWorkosUser(ctx, accessrepo.ListOrganizationRoleAssignmentRecordsByWorkosUserParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosUserID: "user_1", + }) + require.NoError(t, err) + activeCount := 0 + deletedCount := 0 + for _, assignment := range assignments { + if assignment.DeletedAt.Valid { + deletedCount++ + continue + } + activeCount++ + } + require.Equal(t, 1, activeCount) + require.Equal(t, 1, deletedCount) +} diff --git a/server/internal/access/updatememberrole_test.go b/server/internal/access/updatememberrole_test.go index ac25969eed..c18119d41a 100644 --- a/server/internal/access/updatememberrole_test.go +++ b/server/internal/access/updatememberrole_test.go @@ -82,7 +82,7 @@ func TestService_UpdateMemberRole_WorkOSMembershipNotFound(t *testing.T) { _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.Error(t, err) - require.Contains(t, err.Error(), "member not found") + require.Contains(t, err.Error(), "member is missing local WorkOS membership linkage") } func TestService_UpdateMemberRole_WorkOSFailure(t *testing.T) { diff --git a/server/internal/authz/grants.go b/server/internal/authz/grants.go index ca4f30f2eb..39550040e4 100644 --- a/server/internal/authz/grants.go +++ b/server/internal/authz/grants.go @@ -3,10 +3,13 @@ package authz import ( "cmp" "context" + "errors" "fmt" "log/slog" "slices" + "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" @@ -81,6 +84,39 @@ var SystemRoleGrants = map[string][]*RoleGrant{ // SeedSystemRoleGrants upserts the fixed grant sets for all system roles. func SeedSystemRoleGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, organizationID string) error { for roleSlug, grants := range SystemRoleGrants { + existingRole, err := repo.New(db).GetGlobalRoleBySlug(ctx, roleSlug) + seedRole := false + switch { + case err == nil: + seedRole = existingRole.Deleted + case errors.Is(err, pgx.ErrNoRows): + seedRole = true + default: + return fmt.Errorf("load %s role: %w", roleSlug, err) + } + if seedRole { + name := roleSlug + description := "" + switch roleSlug { + case SystemRoleAdmin: + name = "Admin" + description = "Administrator role" + case SystemRoleMember: + name = "Member" + description = "Member role" + } + now := time.Now().UTC() + if err := repo.New(db).UpsertGlobalRole(ctx, repo.UpsertGlobalRoleParams{ + WorkosSlug: roleSlug, + WorkosName: name, + WorkosDescription: conv.ToPGTextEmpty(description), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }); err != nil { + return fmt.Errorf("seed %s role: %w", roleSlug, err) + } + } if err := SyncGrants(ctx, logger, db, organizationID, roleSlug, "", grants); err != nil { return fmt.Errorf("seed %s grants: %w", roleSlug, err) } diff --git a/server/internal/authz/load_test.go b/server/internal/authz/load_test.go index a6b51eb5cc..1e39a976ee 100644 --- a/server/internal/authz/load_test.go +++ b/server/internal/authz/load_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/require" + accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" @@ -34,6 +35,26 @@ func TestLoadGrants_loadsUserAndRoleGrants(t *testing.T) { require.NoError(t, engine.Require(ctx, Check{Scope: ScopeMCPConnect, ResourceID: "toolA"})) } +func TestSeedSystemRoleGrantsBootstrapsGlobalRoles(t *testing.T) { + t.Parallel() + + ctx := t.Context() + conn := newTestDB(t) + organizationID := "org_seed_system_roles" + seedOrganization(t, ctx, conn, organizationID) + + err := SeedSystemRoleGrants(ctx, testenv.NewLogger(t), conn, organizationID) + require.NoError(t, err) + + adminRole, err := accessrepo.New(conn).GetGlobalRoleBySlug(ctx, SystemRoleAdmin) + require.NoError(t, err) + require.Equal(t, "Admin", adminRole.WorkosName) + + grants, err := GrantsForRole(ctx, testenv.NewLogger(t), conn, organizationID, SystemRoleAdmin, "role:global:"+adminRole.ID.String()) + require.NoError(t, err) + require.NotEmpty(t, grants) +} + func TestLoadGrants_rejectsEmptyOrganizationID(t *testing.T) { t.Parallel() From 7250b24900726dae627cab18d4535c9bf59b75d4 Mon Sep 17 00:00:00 2001 From: tgmendes Date: Thu, 21 May 2026 16:22:38 +0100 Subject: [PATCH 12/12] fix: preserve local user on role deletion reassignment --- server/internal/access/deleterole_test.go | 27 ++++++++-- server/internal/access/queries.sql | 60 +++++++++++++--------- server/internal/access/repo/queries.sql.go | 60 +++++++++++++--------- server/internal/access/role_manager.go | 2 +- 4 files changed, 97 insertions(+), 52 deletions(-) diff --git a/server/internal/access/deleterole_test.go b/server/internal/access/deleterole_test.go index ebc6e77221..eda775180b 100644 --- a/server/internal/access/deleterole_test.go +++ b/server/internal/access/deleterole_test.go @@ -56,10 +56,12 @@ func TestService_DeleteRole_ReassignsMembersToDefault(t *testing.T) { require.NotNil(t, authCtx) roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) - seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) + memberRoleID := seedGlobalRole(t, ctx, ti.conn, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) - seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) - seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder")) seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_other", "user_3", "admin")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", authz.SystemRoleMember).Return(&thirdpartyworkos.Member{ ID: "membership_1", @@ -79,6 +81,25 @@ func TestService_DeleteRole_ReassignsMembersToDefault(t *testing.T) { err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.NoError(t, err) + + members, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) + require.NoError(t, err) + membersByID := map[string]*gen.AccessMember{} + for _, member := range members.Members { + membersByID[member.ID] = member + } + require.Equal(t, memberRoleID, membersByID["local_user_1"].RoleID) + require.Equal(t, memberRoleID, membersByID["local_user_2"].RoleID) + + roles, err := ti.service.ListRoles(ctx, &gen.ListRolesPayload{}) + require.NoError(t, err) + for _, role := range roles.Roles { + if role.ID == memberRoleID { + require.Equal(t, 2, role.MemberCount) + return + } + } + require.Fail(t, "member role not found") } func TestService_DeleteRole_ReassignFailureDoesNotHaltDelete(t *testing.T) { diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index 5387dae4f7..0a48ddd5f1 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -460,18 +460,24 @@ ORDER BY created_at; -- name: UpsertOrganizationRoleAssignment :execrows WITH input_role_urn AS ( - SELECT 'role:organization:' || id::text AS role_urn - FROM organization_roles - WHERE organization_roles.organization_id = @organization_id - AND organization_roles.workos_slug = sqlc.arg(workos_role_slug) - AND organization_roles.deleted IS FALSE - AND organization_roles.workos_deleted IS FALSE - UNION ALL - SELECT 'role:global:' || id::text AS role_urn - FROM global_roles - WHERE global_roles.workos_slug = sqlc.arg(workos_role_slug) - AND global_roles.deleted IS FALSE - AND global_roles.workos_deleted IS FALSE + SELECT role_urn + FROM ( + SELECT 'role:organization:' || id::text AS role_urn, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_roles.organization_id = @organization_id + AND organization_roles.workos_slug = sqlc.arg(workos_role_slug) + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = sqlc.arg(workos_role_slug) + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE + ) roles + -- Keep assignment writes aligned with role reads: org roles shadow global roles. + ORDER BY role_kind DESC + LIMIT 1 ) INSERT INTO organization_role_assignments ( organization_id, @@ -500,18 +506,24 @@ ON CONFLICT (organization_id, workos_user_id, role_urn) WHERE deleted_at IS NULL -- name: ReplaceOrganizationRoleAssignment :one WITH input_role_urn AS ( - SELECT 'role:organization:' || id::text AS role_urn - FROM organization_roles - WHERE organization_roles.organization_id = @organization_id - AND organization_roles.workos_slug = sqlc.arg(workos_role_slug) - AND organization_roles.deleted IS FALSE - AND organization_roles.workos_deleted IS FALSE - UNION ALL - SELECT 'role:global:' || id::text AS role_urn - FROM global_roles - WHERE global_roles.workos_slug = sqlc.arg(workos_role_slug) - AND global_roles.deleted IS FALSE - AND global_roles.workos_deleted IS FALSE + SELECT role_urn + FROM ( + SELECT 'role:organization:' || id::text AS role_urn, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_roles.organization_id = @organization_id + AND organization_roles.workos_slug = sqlc.arg(workos_role_slug) + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = sqlc.arg(workos_role_slug) + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE + ) roles + -- Keep assignment writes aligned with role reads: org roles shadow global roles. + ORDER BY role_kind DESC + LIMIT 1 ), upserted AS ( INSERT INTO organization_role_assignments ( diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index 2eb53d281d..ea7dc7904c 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -1084,18 +1084,24 @@ func (q *Queries) MarkOrganizationRoleDeletedLocally(ctx context.Context, arg Ma const replaceOrganizationRoleAssignment = `-- name: ReplaceOrganizationRoleAssignment :one WITH input_role_urn AS ( - SELECT 'role:organization:' || id::text AS role_urn - FROM organization_roles - WHERE organization_roles.organization_id = $1 - AND organization_roles.workos_slug = $2 - AND organization_roles.deleted IS FALSE - AND organization_roles.workos_deleted IS FALSE - UNION ALL - SELECT 'role:global:' || id::text AS role_urn - FROM global_roles - WHERE global_roles.workos_slug = $2 - AND global_roles.deleted IS FALSE - AND global_roles.workos_deleted IS FALSE + SELECT role_urn + FROM ( + SELECT 'role:organization:' || id::text AS role_urn, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_roles.organization_id = $1 + AND organization_roles.workos_slug = $2 + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = $2 + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE + ) roles + -- Keep assignment writes aligned with role reads: org roles shadow global roles. + ORDER BY role_kind DESC + LIMIT 1 ), upserted AS ( INSERT INTO organization_role_assignments ( @@ -1315,18 +1321,24 @@ func (q *Queries) UpsertOrganizationRole(ctx context.Context, arg UpsertOrganiza const upsertOrganizationRoleAssignment = `-- name: UpsertOrganizationRoleAssignment :execrows WITH input_role_urn AS ( - SELECT 'role:organization:' || id::text AS role_urn - FROM organization_roles - WHERE organization_roles.organization_id = $1 - AND organization_roles.workos_slug = $7 - AND organization_roles.deleted IS FALSE - AND organization_roles.workos_deleted IS FALSE - UNION ALL - SELECT 'role:global:' || id::text AS role_urn - FROM global_roles - WHERE global_roles.workos_slug = $7 - AND global_roles.deleted IS FALSE - AND global_roles.workos_deleted IS FALSE + SELECT role_urn + FROM ( + SELECT 'role:organization:' || id::text AS role_urn, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_roles.organization_id = $1 + AND organization_roles.workos_slug = $7 + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = $7 + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE + ) roles + -- Keep assignment writes aligned with role reads: org roles shadow global roles. + ORDER BY role_kind DESC + LIMIT 1 ) INSERT INTO organization_role_assignments ( organization_id, diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go index d93a9b87a9..514d520044 100644 --- a/server/internal/access/role_manager.go +++ b/server/internal/access/role_manager.go @@ -422,7 +422,7 @@ func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, ro OrganizationID: gramOrgID, WorkosUserID: row.WorkosUserID, WorkosRoleSlug: authz.SystemRoleMember, - UserID: conv.ToPGTextEmpty(""), + UserID: row.UserID, WorkosMembershipID: conv.ToPGTextEmpty(membershipID), WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), WorkosLastEventID: conv.ToPGTextEmpty(""),