diff --git a/server/cmd/gram/start.go b/server/cmd/gram/start.go index b01e60a8db..7a7e838dbb 100644 --- a/server/cmd/gram/start.go +++ b/server/cmd/gram/start.go @@ -692,7 +692,6 @@ func newStartCommand() *cli.Command { rbacEnabled, challengeLoggingEnabled, roleClient, - cache.NewRedisCacheAdapter(redisClient), authz.EngineOpts{DevMode: c.String("environment") == "local"}, ) @@ -945,7 +944,8 @@ func newStartCommand() *cli.Command { about.Attach(mux, about.NewService(logger, tracerProvider)) external.AttachWebhookHandler(mux, external.NewWebhookHandler(logger, tracerProvider, newWorkOSWebhooksClient(c), temporalEnv)) - access.Attach(mux, access.NewService(logger, tracerProvider, db, chDB, sessionManager, roleClient, authzEngine, productFeatures, auditLogger)) + roleManager := access.NewRoleManager(logger, db, roleClient, auditLogger) + access.Attach(mux, access.NewService(logger, tracerProvider, db, chDB, sessionManager, roleManager, authzEngine, productFeatures, auditLogger)) assistants.Attach(mux, assistantsSvc) assistantmemories.Attach(mux, assistantmemories.NewService( logger, diff --git a/server/cmd/gram/worker.go b/server/cmd/gram/worker.go index 81617ef8b2..ca192e5c28 100644 --- a/server/cmd/gram/worker.go +++ b/server/cmd/gram/worker.go @@ -508,7 +508,6 @@ func newWorkerCommand() *cli.Command { rbacEnabled, challengeLoggingEnabled, workos.NewStubClient(), - cache.NewRedisCacheAdapter(redisClient), authz.EngineOpts{DevMode: c.String("environment") == "local"}, ) diff --git a/server/internal/access/createrole_test.go b/server/internal/access/createrole_test.go index d584992fc1..b1d9215636 100644 --- a/server/internal/access/createrole_test.go +++ b/server/internal/access/createrole_test.go @@ -39,12 +39,6 @@ func TestService_CreateRole(t *testing.T) { CreatedAt: mockRoleTimestamp, UpdatedAt: mockRoleTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "member"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "member"), - // user_workos_only has never logged into Gram — should not be counted - mockMember(mockidp.MockOrgID, "membership_workos_only", "user_workos_only", "member"), - }, nil).Once() ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "org-custom-builder").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -60,8 +54,12 @@ func TestService_CreateRole(t *testing.T) { CreatedAt: mockMembershipTimestamp, }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember("", "membership_2", "user_2", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_workos_only", "user_workos_only", authz.SystemRoleMember)) role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ Name: "Custom Builder", @@ -74,14 +72,19 @@ func TestService_CreateRole(t *testing.T) { }) require.NoError(t, err) require.Equal(t, "Custom Builder", role.Name) + require.NotEmpty(t, role.ID) + require.NotEqual(t, "role_1", role.ID) require.Equal(t, "Can build selected resources", role.Description) require.False(t, role.IsSystem) require.Equal(t, 2, role.MemberCount) - require.Equal(t, mockRoleTimestamp, role.CreatedAt) - require.Equal(t, mockRoleTimestamp, role.UpdatedAt) + require.NotEmpty(t, role.CreatedAt) + require.NotEmpty(t, role.UpdatedAt) require.Len(t, role.Grants, 2) + roundtrip, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: role.ID}) + require.NoError(t, err) + require.Equal(t, role.ID, roundtrip.ID) - grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "org-custom-builder")) + grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "organization:"+role.ID)) require.Len(t, grants, 3) } @@ -93,33 +96,53 @@ func TestService_CreateRole_WorkOSCreateFailure(t *testing.T) { Name: "Custom Builder", Slug: "org-custom-builder", Description: "Can build selected resources", - }).Return((*thirdpartyworkos.Role)(nil), errors.New("workos unavailable")).Once() + }).Return((*thirdpartyworkos.Role)(nil), errors.New("workos unavailable")).Times(3) - _, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ + role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ Name: "Custom Builder", Description: "Can build selected resources", Grants: []*gen.RoleGrant{ {Scope: string(authz.ScopeProjectRead), Selectors: []*gen.Selector{{ResourceKind: "project", ResourceID: "project-1"}}}, }, }) - require.Error(t, err) - require.Contains(t, err.Error(), "create role in workos") + require.NoError(t, err) + require.Equal(t, "Custom Builder", role.Name) } -func TestService_CreateRole_ContinuesAfterConflictWhenRoleAlreadyExists(t *testing.T) { +func TestService_CreateRole_WorkOSConflictFailure(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + ti.roles.On("CreateRole", mock.Anything, mockidp.MockOrgID, thirdpartyworkos.CreateRoleOpts{ + Name: "Custom Builder", + Slug: "org-custom-builder", + Description: "Can build selected resources", + }).Return((*thirdpartyworkos.Role)(nil), &thirdpartyworkos.APIError{Method: "POST", Path: "/authorization/organizations/org_workos_test/roles", StatusCode: 409, Body: "role already exists"}).Once() + + role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ + Name: "Custom Builder", + Description: "Can build selected resources", + Grants: []*gen.RoleGrant{ + {Scope: string(authz.ScopeProjectRead), Selectors: []*gen.Selector{{ResourceKind: "project", ResourceID: "project-1"}}}, + }, + }) + require.NoError(t, err) + require.Equal(t, "Custom Builder", role.Name) +} + +func TestService_CreateRole_WorkOSConflictUsesLocalRole(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) - require.NotNil(t, authCtx) - existingRole := mockRole("role_existing", "Custom Builder", "org-custom-builder", "Can build selected resources") + + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_1", "Custom Builder", "org-custom-builder", "Can build selected resources")) ti.roles.On("CreateRole", mock.Anything, mockidp.MockOrgID, thirdpartyworkos.CreateRoleOpts{ Name: "Custom Builder", Slug: "org-custom-builder", Description: "Can build selected resources", }).Return((*thirdpartyworkos.Role)(nil), &thirdpartyworkos.APIError{Method: "POST", Path: "/authorization/organizations/org_workos_test/roles", StatusCode: 409, Body: "role already exists"}).Once() - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{existingRole}, nil).Once() role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{ Name: "Custom Builder", @@ -129,12 +152,8 @@ func TestService_CreateRole_ContinuesAfterConflictWhenRoleAlreadyExists(t *testi }, }) require.NoError(t, err) - require.Equal(t, "role_existing", role.ID) + require.Equal(t, roleID, role.ID) require.Equal(t, "Custom Builder", role.Name) - - grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "org-custom-builder")) - require.Len(t, grants, 1) - require.Equal(t, authCtx.ActiveOrganizationID, grants[0].OrganizationID) } func TestService_CreateRole_RejectsEmptySlug(t *testing.T) { @@ -198,7 +217,7 @@ func TestService_CreateRole_AuditLog(t *testing.T) { require.Equal(t, beforeCount+1, afterCount) } -func TestService_CreateRole_GrantSyncFailureDoesNotAssignMembers(t *testing.T) { +func TestService_CreateRole_LocalRoleWriteFailureDoesNotAssignMembers(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -206,35 +225,24 @@ func TestService_CreateRole_GrantSyncFailureDoesNotAssignMembers(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("CreateRole", mock.Anything, mockidp.MockOrgID, thirdpartyworkos.CreateRoleOpts{ - Name: "Broken Builder", - Slug: "org-broken-builder", - Description: "Will fail grant sync", - }).Run(func(mock.Arguments) { - ti.conn.Close() - }).Return(&thirdpartyworkos.Role{ - ID: "role_1", - Name: "Broken Builder", - Slug: "org-broken-builder", - Description: "Will fail grant sync", - CreatedAt: mockRoleTimestamp, - UpdatedAt: mockRoleTimestamp, - }, nil).Once() - inspectConn, err := pgxpool.New(ctx, ti.conn.Config().ConnString()) require.NoError(t, err) t.Cleanup(inspectConn.Close) - _, err = ti.service.CreateRole(ctx, &gen.CreateRolePayload{ + ti.conn.Close() + _, err = ti.service.roleMgr.CreateRole(ctx, authCtx.ActiveOrganizationID, mockidp.MockOrgID, accessAuditActor{ + Principal: urn.NewPrincipal(urn.PrincipalTypeUser, authCtx.UserID), + DisplayName: authCtx.Email, + }, &gen.CreateRolePayload{ Name: "Broken Builder", - Description: "Will fail grant sync", + Description: "Will fail local write", Grants: []*gen.RoleGrant{ {Scope: string(authz.ScopeProjectRead), Selectors: []*gen.Selector{{ResourceKind: "project", ResourceID: "project-1"}}}, }, MemberIds: []string{"local_user_1", "local_user_2"}, }) require.Error(t, err) - require.Contains(t, err.Error(), "sync grants for created role") + require.Contains(t, err.Error(), "role transaction") grants, err := accessrepo.New(inspectConn).ListPrincipalGrantsByOrg(ctx, accessrepo.ListPrincipalGrantsByOrgParams{ OrganizationID: authCtx.ActiveOrganizationID, diff --git a/server/internal/access/deleterole_test.go b/server/internal/access/deleterole_test.go index 16f852903d..eda775180b 100644 --- a/server/internal/access/deleterole_test.go +++ b/server/internal/access/deleterole_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" gen "github.com/speakeasy-api/gram/server/gen/access" + accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/audit" "github.com/speakeasy-api/gram/server/internal/audit/audittest" "github.com/speakeasy-api/gram/server/internal/authz" @@ -25,19 +26,25 @@ func TestService_DeleteRole(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeMCPConnect, authz.WildcardResource) - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.NoError(t, err) grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) require.Empty(t, grants) + + role, err := accessrepo.New(ti.conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosSlug: "custom-builder", + }) + require.NoError(t, err) + require.True(t, role.Deleted) + require.False(t, role.WorkosDeleted) + require.False(t, role.WorkosDeletedAt.Valid) } func TestService_DeleteRole_ReassignsMembersToDefault(t *testing.T) { @@ -48,14 +55,14 @@ func TestService_DeleteRole_ReassignsMembersToDefault(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_other", "user_3", "admin"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + memberRoleID := seedGlobalRole(t, ctx, ti.conn, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_other", "user_3", "admin")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", authz.SystemRoleMember).Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -72,11 +79,30 @@ func TestService_DeleteRole_ReassignsMembersToDefault(t *testing.T) { }, nil).Once() ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) + require.NoError(t, err) + + members, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) + membersByID := map[string]*gen.AccessMember{} + for _, member := range members.Members { + membersByID[member.ID] = member + } + require.Equal(t, memberRoleID, membersByID["local_user_1"].RoleID) + require.Equal(t, memberRoleID, membersByID["local_user_2"].RoleID) + + roles, err := ti.service.ListRoles(ctx, &gen.ListRolesPayload{}) + require.NoError(t, err) + for _, role := range roles.Roles { + if role.ID == memberRoleID { + require.Equal(t, 2, role.MemberCount) + return + } + } + require.Fail(t, "member role not found") } -func TestService_DeleteRole_ReassignFailureHaltsDelete(t *testing.T) { +func TestService_DeleteRole_ReassignFailureDoesNotHaltDelete(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -84,25 +110,21 @@ func TestService_DeleteRole_ReassignFailureHaltsDelete(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - }, nil).Once() - ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", authz.SystemRoleMember).Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) + ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", authz.SystemRoleMember).Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Times(3) + ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) - require.Error(t, err) - require.Contains(t, err.Error(), "reassign member to default role") + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) + require.NoError(t, err) - // Grants must remain since reassignment failed before grant cleanup ran. grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) - require.Len(t, grants, 1) + require.Empty(t, grants) } -func TestService_DeleteRole_PartialReassignFailureStopsLoop(t *testing.T) { +func TestService_DeleteRole_PartialReassignFailureContinuesDelete(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -110,15 +132,10 @@ func TestService_DeleteRole_PartialReassignFailureStopsLoop(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - // Iteration order over the slice is deterministic, so seed an explicit - // success-then-failure pair to exercise the partial-failure cache flush. - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", authz.SystemRoleMember).Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -126,26 +143,23 @@ func TestService_DeleteRole_PartialReassignFailureStopsLoop(t *testing.T) { RoleSlug: authz.SystemRoleMember, CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("UpdateMemberRole", mock.Anything, "membership_2", authz.SystemRoleMember).Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Once() + ti.roles.On("UpdateMemberRole", mock.Anything, "membership_2", authz.SystemRoleMember).Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Times(3) + ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) - require.Error(t, err) - require.Contains(t, err.Error(), "reassign member to default role") + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) + require.NoError(t, err) - // The mock's AssertExpectations (registered in newMockRoleProvider) verifies - // that DeleteRole was never called and the loop stopped at the first failure. grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) - require.Len(t, grants, 1) + require.Empty(t, grants) } func TestService_DeleteRole_NotFound(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{}, nil).Once() - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_missing"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "00000000-0000-0000-0000-000000000001"}) require.Error(t, err) require.Contains(t, err.Error(), "role not found") } @@ -154,11 +168,11 @@ func TestService_DeleteRole_SystemRole(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_admin"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.Error(t, err) require.Contains(t, err.Error(), "system roles cannot be deleted") } @@ -170,20 +184,15 @@ func TestService_DeleteRole_WorkOSDeleteFailure(t *testing.T) { authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() - ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(errors.New("workos unavailable")).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(errors.New("workos unavailable")).Times(3) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) - require.Error(t, err) - require.Contains(t, err.Error(), "delete role in workos") + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) + require.NoError(t, err) grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) require.Empty(t, grants) - } func TestService_DeleteRole_AuditLog(t *testing.T) { @@ -196,14 +205,11 @@ func TestService_DeleteRole_AuditLog(t *testing.T) { beforeCount, err := audittest.AuditLogCountByAction(ctx, ti.conn, audit.ActionAccessRoleDelete) require.NoError(t, err) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Audit Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Audit Builder", "custom-builder", "Old description")) ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - err = ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err = ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.NoError(t, err) record, err := audittest.LatestAuditLogByAction(ctx, ti.conn, audit.ActionAccessRoleDelete) diff --git a/server/internal/access/getrole_test.go b/server/internal/access/getrole_test.go index 05d236cb3a..f33c62745c 100644 --- a/server/internal/access/getrole_test.go +++ b/server/internal/access/getrole_test.go @@ -1,17 +1,13 @@ package access import ( - "errors" - mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" "testing" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" gen "github.com/speakeasy-api/gram/server/gen/access" "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/contextvalues" - thirdpartyworkos "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" ) @@ -23,27 +19,22 @@ func TestService_GetRole(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_3", "user_3", "admin"), - // user_workos_only has never logged into Gram — should not be counted - mockMember(mockidp.MockOrgID, "membership_workos_only", "user_workos_only", "custom-builder"), - }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + customID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", "user3@test.com", "User 3", "user_3", "membership_3") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember("", "membership_2", "user_2", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", mockMember("", "membership_3", "user_3", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_workos_only", "user_workos_only", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeMCPConnect, authz.WildcardResource) - role, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: "role_custom"}) + role, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: customID}) require.NoError(t, err) - require.Equal(t, "role_custom", role.ID) + require.Equal(t, customID, role.ID) require.Equal(t, "Custom Builder", role.Name) require.Equal(t, "Can build selected resources", role.Description) require.False(t, role.IsSystem) @@ -66,9 +57,8 @@ func TestService_GetRole_NotFound(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{}, nil).Once() - _, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: "role_missing"}) + _, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: "00000000-0000-0000-0000-000000000001"}) require.Error(t, err) require.Contains(t, err.Error(), "role not found") } @@ -103,13 +93,12 @@ func TestService_GetRole_OrganizationNotLinkedToWorkOS(t *testing.T) { require.Contains(t, err.Error(), "organization is not linked to WorkOS") } -func TestService_GetRole_WorkOSListRolesFailure(t *testing.T) { +func TestService_GetRole_InvalidID(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role(nil), errors.New("workos unavailable")).Once() _, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: "role_custom"}) require.Error(t, err) - require.Contains(t, err.Error(), "list roles from workos") + require.Contains(t, err.Error(), "invalid role ID") } diff --git a/server/internal/access/impl.go b/server/internal/access/impl.go index db7a11d089..a5a9fef1e8 100644 --- a/server/internal/access/impl.go +++ b/server/internal/access/impl.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "log/slog" - "regexp" "strings" "time" @@ -29,35 +28,18 @@ import ( chrepo "github.com/speakeasy-api/gram/server/internal/authz/repo" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/database" "github.com/speakeasy-api/gram/server/internal/middleware" "github.com/speakeasy-api/gram/server/internal/o11y" "github.com/speakeasy-api/gram/server/internal/oops" orgrepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" "github.com/speakeasy-api/gram/server/internal/productfeatures" pfRepo "github.com/speakeasy-api/gram/server/internal/productfeatures/repo" - "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" usersrepo "github.com/speakeasy-api/gram/server/internal/users/repo" ) -var ( - errConnectedUserNotFound = errors.New("connected user not found") - // Custom role names become stable slugs and user-facing identifiers, so keep - // them to a predictable ASCII set instead of silently normalizing symbols. - validRoleNamePattern = regexp.MustCompile(`^[A-Za-z0-9 _-]+$`) -) - -type RoleProvider interface { - ListRoles(ctx context.Context, orgID string) ([]workos.Role, error) - CreateRole(ctx context.Context, orgID string, opts workos.CreateRoleOpts) (*workos.Role, error) - UpdateRole(ctx context.Context, orgID string, roleSlug string, opts workos.UpdateRoleOpts) (*workos.Role, error) - DeleteRole(ctx context.Context, orgID string, roleSlug string) error - ListMembers(ctx context.Context, orgID string) ([]workos.Member, error) - UpdateMemberRole(ctx context.Context, membershipID string, roleSlug string) (*workos.Member, error) - GetUser(ctx context.Context, userID string) (*workos.User, error) - ListOrgUsers(ctx context.Context, orgID string) (map[string]workos.User, error) - GetOrgMembership(ctx context.Context, workOSUserID, workOSOrgID string) (*workos.Member, error) -} +var errConnectedUserNotFound = errors.New("connected user not found") // FeatureCacheWriter updates the Redis cache entry for a feature flag after a // direct DB write, keeping the cache consistent with the authoritative state. @@ -72,7 +54,7 @@ type Service struct { chConn driver.Conn auth *auth.Auth authz *authz.Engine - roles RoleProvider + roleMgr *RoleManager featureCache FeatureCacheWriter audit *audit.Logger } @@ -86,7 +68,7 @@ func NewService( db *pgxpool.Pool, chConn driver.Conn, sessions *sessions.Manager, - roles RoleProvider, + roleMgr *RoleManager, authz *authz.Engine, featureCache FeatureCacheWriter, auditLogger *audit.Logger, @@ -100,7 +82,7 @@ func NewService( chConn: chConn, auth: auth.New(logger, db, sessions, authz), authz: authz, - roles: roles, + roleMgr: roleMgr, featureCache: featureCache, audit: auditLogger, } @@ -120,10 +102,9 @@ func (s *Service) APIKeyAuth(ctx context.Context, key string, schema *security.A return s.auth.Authorize(ctx, key, schema) } -// ListRoles treats WorkOS as the source of truth for role records while Gram -// remains the source of truth for role grants. +// ListRoles reads local role records and enriches them with Gram's local grant state. func (s *Service) ListRoles(ctx context.Context, _ *gen.ListRolesPayload) (*gen.ListRolesResult, error) { - ac, workosOrgID, err := s.roleOrgContext(ctx) + ac, _, err := s.roleOrgContext(ctx) if err != nil { return nil, err } @@ -135,36 +116,13 @@ func (s *Service) ListRoles(ctx context.Context, _ *gen.ListRolesPayload) (*gen. attr.UserID(ac.UserID), ) - wRoles, err := s.roles.ListRoles(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, s.logger) - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, s.logger) - } - memberCounts, err := s.localMemberCounts(ctx, ac.ActiveOrganizationID, members) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "count local members by role").Log(ctx, s.logger) - } - - roles := make([]*gen.Role, 0, len(wRoles)) - for _, wr := range wRoles { - role, err := buildRole(ctx, s.logger, s.db, ac.ActiveOrganizationID, wr, memberCounts[wr.Slug]) - if err != nil { - return nil, err - } - roles = append(roles, role) - } - - return &gen.ListRolesResult{Roles: roles}, nil + return s.roleMgr.ListRoles(ctx, ac.ActiveOrganizationID) } -// GetRole returns the WorkOS role definition enriched with Gram's local grant +// GetRole returns the role definition enriched with Gram's local grant // state so callers see the complete effective role configuration in one place. func (s *Service) GetRole(ctx context.Context, payload *gen.GetRolePayload) (*gen.Role, error) { - ac, workosOrgID, err := s.roleOrgContext(ctx) + ac, _, err := s.roleOrgContext(ctx) if err != nil { return nil, err } @@ -177,35 +135,9 @@ func (s *Service) GetRole(ctx context.Context, payload *gen.GetRolePayload) (*ge attr.AccessRoleID(payload.ID), ) - wRoles, err := s.roles.ListRoles(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, s.logger) - } - - role, ok := findRoleByID(wRoles, payload.ID) - if !ok { - return nil, oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, s.logger) - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, s.logger) - } - - memberCounts, err := s.localMemberCounts(ctx, ac.ActiveOrganizationID, members) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "count local members by role").Log(ctx, s.logger) - } - - return buildRole(ctx, s.logger, s.db, ac.ActiveOrganizationID, role, memberCounts[role.Slug]) + return s.roleMgr.GetRoleByID(ctx, ac.ActiveOrganizationID, payload.ID) } -// CreateRole creates a role for a user of a given organization. -// It is an idempotent operation intentionally ordered so that member assignment happens last. -// If WorkOS role creation succeeds but local grant sync fails, we return an -// error with no users assigned to the new role. That leaves a partially -// created role behind, but keeps the outcome safe and retryable: repeating the -// request can finish configuration without having granted accidental access. func (s *Service) CreateRole(ctx context.Context, payload *gen.CreateRolePayload) (*gen.Role, error) { ac, workosOrgID, err := s.roleOrgContext(ctx) if err != nil { @@ -215,136 +147,22 @@ func (s *Service) CreateRole(ctx context.Context, payload *gen.CreateRolePayload return nil, err } - roleSlug, slugErr := slugify(payload.Name) - if slugErr != nil { - return nil, slugErr - } - logger := s.logger.With( - attr.SlogOrganizationID(ac.ActiveOrganizationID), - attr.SlogUserID(ac.UserID), - attr.SlogAccessRoleSlug(roleSlug), - ) - trace.SpanFromContext(ctx).SetAttributes( - attr.OrganizationID(ac.ActiveOrganizationID), - attr.UserID(ac.UserID), - attr.AccessRoleSlug(roleSlug), - ) - - wr, err := s.roles.CreateRole(ctx, workosOrgID, workos.CreateRoleOpts{ - Name: payload.Name, - Slug: roleSlug, - Description: payload.Description, - }) - var apiErr *workos.APIError - switch { - case errors.As(err, &apiErr) && apiErr.StatusCode == 409: - wRoles, listErr := s.roles.ListRoles(ctx, workosOrgID) - if listErr != nil { - return nil, oops.E(oops.CodeUnexpected, listErr, "list roles after create conflict").Log(ctx, logger) - } - - var existingRole workos.Role - ok := false - for _, candidate := range wRoles { - if candidate.Slug == roleSlug { - existingRole = candidate - ok = true - break - } - } - if !ok { - return nil, oops.E(oops.CodeUnexpected, err, "create role in workos").Log(ctx, logger) - } - - wr = &existingRole - case err != nil: - return nil, oops.E(oops.CodeUnexpected, err, "create role in workos").Log(ctx, logger) - } trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), - attr.AccessRoleID(wr.ID), ) - // Stop before assigning members if grant sync fails. That can leave behind a - // newly created WorkOS role with no local grants, but it avoids assigning users - // to a role whose effective permissions are incomplete or unknown. Returning an - // error makes the setup retryable without creating accidental access. - if err := authz.SyncGrants(ctx, s.logger, s.db, ac.ActiveOrganizationID, wr.Slug, roleGrantPayloads(payload.Grants)); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "sync grants for created role").Log(ctx, logger) - } - - assignedWorkosIDs := make([]string, 0, len(payload.MemberIds)) - if len(payload.MemberIds) > 0 { - // payload.MemberIds are Gram user IDs (returned by ListMembers). - // Resolve them to WorkOS user IDs so we can look up WorkOS memberships. - gramToWorkos, err := gramToWorkosIDMap(ctx, s.db, payload.MemberIds) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "resolve gram user ids to workos ids").Log(ctx, logger) - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) - } - - membershipByUser := membershipsByUserID(members) - - for _, gramID := range payload.MemberIds { - workosID, ok := gramToWorkos[gramID] - if !ok { - continue - } - membershipID, ok := membershipByUser[workosID] - if !ok { - continue - } - - if _, err := s.roles.UpdateMemberRole(ctx, membershipID, wr.Slug); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "assign members to created role").Log(ctx, logger) - } - - assignedWorkosIDs = append(assignedWorkosIDs, workosID) - } - s.authz.InvalidateAllRoleCaches(ctx, ac.ActiveOrganizationID) - } - - // Only count assigned members who have local Gram accounts and are - // connected to this org, consistent with how ListMembers filters users. - assignedCount := 0 - if len(assignedWorkosIDs) > 0 { - localRows, err := usersrepo.New(s.db).GetConnectedUsersByWorkosIDs(ctx, usersrepo.GetConnectedUsersByWorkosIDsParams{ - WorkosIds: assignedWorkosIDs, - OrganizationID: ac.ActiveOrganizationID, - }) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, fmt.Errorf("get connected users by workos ids: %w", err), "resolve local assigned members").Log(ctx, logger) - } - assignedCount = len(localRows) - } - - createdRole, err := buildRole(ctx, logger, s.db, ac.ActiveOrganizationID, *wr, assignedCount) + created, err := s.roleMgr.CreateRole(ctx, ac.ActiveOrganizationID, workosOrgID, accessAuditActor{ + Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), + DisplayName: ac.Email, + }, payload) if err != nil { return nil, err } - if err := s.audit.LogAccessRoleCreate(ctx, s.db, audit.LogAccessRoleCreateEvent{ - OrganizationID: ac.ActiveOrganizationID, - Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), - ActorDisplayName: ac.Email, - ActorSlug: nil, - RoleID: wr.ID, - RoleName: createdRole.Name, - RoleSlug: wr.Slug, - }); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "log access role creation").Log(ctx, logger) - } - - return createdRole, nil + return created.Role, nil } -// UpdateRole preserves the same split of responsibilities as creation: WorkOS -// owns role identity and membership, while Gram owns the role's grant set. func (s *Service) UpdateRole(ctx context.Context, payload *gen.UpdateRolePayload) (*gen.Role, error) { ac, workosOrgID, err := s.roleOrgContext(ctx) if err != nil { @@ -353,138 +171,24 @@ func (s *Service) UpdateRole(ctx context.Context, payload *gen.UpdateRolePayload if err := s.authz.Require(ctx, authz.Check{Scope: authz.ScopeOrgAdmin, ResourceKind: "", ResourceID: ac.ActiveOrganizationID, Dimensions: nil}); err != nil { return nil, err } - logger := s.logger.With( - attr.SlogOrganizationID(ac.ActiveOrganizationID), - attr.SlogUserID(ac.UserID), - attr.SlogAccessRoleID(payload.ID), - ) trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), attr.AccessRoleID(payload.ID), ) - wRoles, err := s.roles.ListRoles(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, logger) - } - - currentRole, ok := findRoleByID(wRoles, payload.ID) - if !ok { - return nil, oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, logger) - } - logger = logger.With(attr.SlogAccessRoleSlug(currentRole.Slug)) - trace.SpanFromContext(ctx).SetAttributes( - attr.OrganizationID(ac.ActiveOrganizationID), - attr.UserID(ac.UserID), - attr.AccessRoleSlug(currentRole.Slug), - ) - sysRole := isSystemRole(currentRole.Slug) - if sysRole && (payload.Name != nil || payload.Description != nil || payload.Grants != nil) { - return nil, oops.E(oops.CodeBadRequest, nil, "system role properties cannot be updated, only member assignment is allowed").Log(ctx, logger) - } - if sysRole && payload.MemberIds == nil { - return nil, oops.E(oops.CodeBadRequest, nil, "system role update requires member_ids").Log(ctx, logger) - } - if payload.Name != nil { - if _, err := slugify(*payload.Name); err != nil { - return nil, err - } - } - - membersBefore, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) - } - memberCountsBefore, err := s.localMemberCounts(ctx, ac.ActiveOrganizationID, membersBefore) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "count local members by role").Log(ctx, logger) - } - existingRole, err := buildRole(ctx, logger, s.db, ac.ActiveOrganizationID, currentRole, memberCountsBefore[currentRole.Slug]) + updated, err := s.roleMgr.UpdateRole(ctx, ac.ActiveOrganizationID, workosOrgID, accessAuditActor{ + Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), + DisplayName: ac.Email, + }, payload) if err != nil { return nil, err } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(updated.Role.Slug)) - // System roles are immutable in WorkOS — skip the role update call and grant sync. - var updatedRole *workos.Role - if sysRole { - updatedRole = ¤tRole - } else { - updatedRole, err = s.roles.UpdateRole(ctx, workosOrgID, currentRole.Slug, workos.UpdateRoleOpts{ - Name: payload.Name, - Description: payload.Description, - }) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "update role in workos").Log(ctx, logger) - } - - // As with role creation, member reassignment happens after local grant sync so - // a failed sync never leaves users attached to a role with incomplete access. - if payload.Grants != nil { - if err := authz.SyncGrants(ctx, s.logger, s.db, ac.ActiveOrganizationID, currentRole.Slug, roleGrantPayloads(payload.Grants)); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "sync grants for updated role").Log(ctx, logger) - } - } - } - - if payload.MemberIds != nil { - gramToWorkos, err := gramToWorkosIDMap(ctx, s.db, payload.MemberIds) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "resolve gram user ids to workos ids").Log(ctx, logger) - } - - membershipByUser := membershipsByUserID(membersBefore) - - for _, gramID := range payload.MemberIds { - workosID, ok := gramToWorkos[gramID] - if !ok { - continue - } - membershipID, ok := membershipByUser[workosID] - if !ok { - continue - } - - if _, err := s.roles.UpdateMemberRole(ctx, membershipID, currentRole.Slug); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "assign members to updated role").Log(ctx, logger) - } - } - s.authz.InvalidateAllRoleCaches(ctx, ac.ActiveOrganizationID) - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) - } - - memberCounts, err := s.localMemberCounts(ctx, ac.ActiveOrganizationID, members) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "count local members by role").Log(ctx, logger) - } - updatedRoleView, err := buildRole(ctx, logger, s.db, ac.ActiveOrganizationID, *updatedRole, memberCounts[updatedRole.Slug]) - if err != nil { - return nil, err - } - - if err := s.audit.LogAccessRoleUpdate(ctx, s.db, audit.LogAccessRoleUpdateEvent{ - OrganizationID: ac.ActiveOrganizationID, - Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), - ActorDisplayName: ac.Email, - ActorSlug: nil, - RoleID: updatedRole.ID, - RoleName: updatedRoleView.Name, - RoleSlug: updatedRole.Slug, - RoleSnapshotBefore: existingRole, - RoleSnapshotAfter: updatedRoleView, - }); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "log access role update").Log(ctx, logger) - } - - return updatedRoleView, nil + return updated.After, nil } -// DeleteRole removes local grants before deleting the WorkOS role so retries can -// still complete cleanup if the external delete fails. func (s *Service) DeleteRole(ctx context.Context, payload *gen.DeleteRolePayload) error { ac, workosOrgID, err := s.roleOrgContext(ctx) if err != nil { @@ -493,81 +197,20 @@ func (s *Service) DeleteRole(ctx context.Context, payload *gen.DeleteRolePayload if err := s.authz.Require(ctx, authz.Check{Scope: authz.ScopeOrgAdmin, ResourceKind: "", ResourceID: ac.ActiveOrganizationID, Dimensions: nil}); err != nil { return err } - logger := s.logger.With( - attr.SlogOrganizationID(ac.ActiveOrganizationID), - attr.SlogUserID(ac.UserID), - attr.SlogAccessRoleID(payload.ID), - ) trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), attr.AccessRoleID(payload.ID), ) - wRoles, err := s.roles.ListRoles(ctx, workosOrgID) - if err != nil { - return oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, logger) - } - - currentRole, ok := findRoleByID(wRoles, payload.ID) - if !ok { - return oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, logger) - } - logger = logger.With(attr.SlogAccessRoleSlug(currentRole.Slug)) - trace.SpanFromContext(ctx).SetAttributes( - attr.OrganizationID(ac.ActiveOrganizationID), - attr.UserID(ac.UserID), - attr.AccessRoleSlug(currentRole.Slug), - ) - if isSystemRole(currentRole.Slug) { - return oops.E(oops.CodeBadRequest, nil, "system roles cannot be deleted").Log(ctx, logger) - } - - // WorkOS rejects deleting a role that still has members assigned, so move - // any assigned members to the default member role first. - members, err := s.roles.ListMembers(ctx, workosOrgID) + deleted, err := s.roleMgr.DeleteRole(ctx, ac.ActiveOrganizationID, workosOrgID, payload.ID, accessAuditActor{ + Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), + DisplayName: ac.Email, + }) if err != nil { - return oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) - } - reassigned := false - for _, m := range members { - if m.RoleSlug != currentRole.Slug { - continue - } - if _, err := s.roles.UpdateMemberRole(ctx, m.ID, authz.SystemRoleMember); err != nil { - if reassigned { - s.authz.InvalidateAllRoleCaches(ctx, ac.ActiveOrganizationID) - } - return oops.E(oops.CodeUnexpected, err, "reassign member to default role").Log(ctx, logger) - } - reassigned = true - } - if reassigned { - s.authz.InvalidateAllRoleCaches(ctx, ac.ActiveOrganizationID) - } - - if _, err := repo.New(s.db).DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ - OrganizationID: ac.ActiveOrganizationID, - PrincipalUrn: urn.NewPrincipal(urn.PrincipalTypeRole, currentRole.Slug), - }); err != nil { - return oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, logger) - } - - if err := s.roles.DeleteRole(ctx, workosOrgID, currentRole.Slug); err != nil { - return oops.E(oops.CodeUnexpected, err, "delete role in workos").Log(ctx, logger) - } - - if err := s.audit.LogAccessRoleDelete(ctx, s.db, audit.LogAccessRoleDeleteEvent{ - OrganizationID: ac.ActiveOrganizationID, - Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), - ActorDisplayName: ac.Email, - ActorSlug: nil, - RoleID: currentRole.ID, - RoleName: currentRole.Name, - RoleSlug: currentRole.Slug, - }); err != nil { - return oops.E(oops.CodeUnexpected, err, "log access role deletion").Log(ctx, logger) + return err } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSlug(deleted.Slug)) return nil } @@ -603,8 +246,7 @@ func (s *Service) ListScopes(ctx context.Context, _ *gen.ListScopesPayload) (*ge // ListMembers follows the original access API contract by returning WorkOS user // identifiers while decorating them with the role information the UI needs. func (s *Service) ListMembers(ctx context.Context, _ *gen.ListMembersPayload) (*gen.ListMembersResult, error) { - _, workosOrgID, err := s.roleOrgContext(ctx) - ac, _ := s.authContext(ctx) + ac, _, err := s.roleOrgContext(ctx) if err != nil { return nil, err } @@ -616,71 +258,7 @@ func (s *Service) ListMembers(ctx context.Context, _ *gen.ListMembersPayload) (* attr.UserID(ac.UserID), ) - roles, err := s.roles.ListRoles(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, s.logger) - } - roleIDBySlug := make(map[string]string, len(roles)) - for _, role := range roles { - roleIDBySlug[role.Slug] = role.ID - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, s.logger) - } - - users, err := s.roles.ListOrgUsers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list org users from workos").Log(ctx, s.logger) - } - - // Batch-resolve WorkOS user IDs to local Gram users, filtering to only - // those connected to this organization via organization_user_relationships. - // This single joined query prevents the list from surfacing users who - // exist in Gram but aren't connected to the current org (which would - // cause UpdateMemberRole to fail). - workosIDs := make([]string, 0, len(users)) - for workosUID := range users { - workosIDs = append(workosIDs, workosUID) - } - localUserRows, err := usersrepo.New(s.db).GetConnectedUsersByWorkosIDs(ctx, usersrepo.GetConnectedUsersByWorkosIDsParams{ - WorkosIds: workosIDs, - OrganizationID: ac.ActiveOrganizationID, - }) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "resolve connected users by workos ids").Log(ctx, s.logger) - } - localUsers := make(map[string]usersrepo.User, len(localUserRows)) - for _, u := range localUserRows { - if u.WorkosID.Valid { - localUsers[u.WorkosID.String] = u - } - } - - result := make([]*gen.AccessMember, 0, len(members)) - for _, member := range members { - user, ok := users[member.UserID] - if !ok { - continue - } - - local, ok := localUsers[member.UserID] - if !ok { - continue - } - - result = append(result, &gen.AccessMember{ - ID: local.ID, - Name: formatUserName(user), - Email: user.Email, - PhotoURL: conv.FromPGText[string](local.PhotoUrl), - RoleID: roleIDBySlug[member.RoleSlug], - JoinedAt: conv.Default(member.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), - }) - } - - return &gen.ListMembersResult{Members: result}, nil + return s.roleMgr.ListMembers(ctx, ac.ActiveOrganizationID) } // ListGrants returns the effective grants for the current user by combining @@ -700,7 +278,7 @@ func (s *Service) ListGrants(ctx context.Context, _ *gen.ListGrantsPayload) (*ge return &gen.ListUserGrantsResult{Grants: allScopesGrants()}, nil } - ac, workosOrgID, err := s.roleOrgContext(ctx) + ac, _, err := s.roleOrgContext(ctx) if err != nil { return nil, err } @@ -734,12 +312,20 @@ func (s *Service) ListGrants(ctx context.Context, _ *gen.ListGrantsPayload) (*ge } principals := []urn.Principal{urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID)} - roleSlugs, err := s.memberRoleSlugs(ctx, workosOrgID, connectedUser.WorkosID.String) + rolePrincipals, err := s.roleMgr.MemberRolePrincipals(ctx, ac.ActiveOrganizationID, connectedUser.WorkosID.String) if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) + return nil, oops.E(oops.CodeUnexpected, err, "list member roles").Log(ctx, logger) } - for _, roleSlug := range roleSlugs { - principals = append(principals, urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug)) + roleSlugs := make([]string, 0, len(rolePrincipals)) + for _, role := range rolePrincipals { + // Effective-grant responses must include grants stored under either the + // canonical role URN or the legacy role slug during the migration. + rolePrincipalURNs, err := authz.RolePrincipals(role.RoleSlug, role.PrincipalUrn) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "build role principals").Log(ctx, logger) + } + principals = append(principals, rolePrincipalURNs...) + roleSlugs = append(roleSlugs, role.RoleSlug) } if len(roleSlugs) == 1 { logger = logger.With(attr.SlogAccessRoleSlug(roleSlugs[0])) @@ -757,19 +343,13 @@ func (s *Service) ListGrants(ctx context.Context, _ *gen.ListGrantsPayload) (*ge // UpdateMemberRole is intentionally stricter than member listing: it only // mutates access for users Gram knows are connected to the local organization. func (s *Service) UpdateMemberRole(ctx context.Context, payload *gen.UpdateMemberRolePayload) (*gen.AccessMember, error) { - ac, workosOrgID, err := s.roleOrgContext(ctx) + ac, _, err := s.roleOrgContext(ctx) if err != nil { return nil, err } if err := s.authz.Require(ctx, authz.Check{Scope: authz.ScopeOrgAdmin, ResourceKind: "", ResourceID: ac.ActiveOrganizationID, Dimensions: nil}); err != nil { return nil, err } - logger := s.logger.With( - attr.SlogOrganizationID(ac.ActiveOrganizationID), - attr.SlogUserID(ac.UserID), - attr.SlogAccessMemberID(payload.UserID), - attr.SlogAccessRoleID(payload.RoleID), - ) trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), @@ -777,109 +357,21 @@ func (s *Service) UpdateMemberRole(ctx context.Context, payload *gen.UpdateMembe attr.AccessRoleID(payload.RoleID), ) - roles, err := s.roles.ListRoles(ctx, workosOrgID) + memberUpdate, err := s.roleMgr.UpdateMemberRole(ctx, ac.ActiveOrganizationID, payload.UserID, payload.RoleID, accessAuditActor{ + Principal: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), + DisplayName: ac.Email, + }) if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list roles from workos").Log(ctx, logger) - } - - roleSlug := "" - for _, role := range roles { - if role.ID == payload.RoleID { - roleSlug = role.Slug - break - } - } - if roleSlug == "" { - return nil, oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, logger) + return nil, err } - logger = logger.With(attr.SlogAccessRoleSlug(roleSlug)) + roleSlug := memberUpdate.RoleSlug trace.SpanFromContext(ctx).SetAttributes( attr.OrganizationID(ac.ActiveOrganizationID), attr.UserID(ac.UserID), attr.AccessRoleSlug(roleSlug), ) - connectedUser, err := connectedUser(ctx, s.db, ac.ActiveOrganizationID, payload.UserID) - switch { - case errors.Is(err, errConnectedUserNotFound): - return nil, oops.E(oops.CodeNotFound, nil, "member has not joined this organization").Log(ctx, logger) - case err != nil: - return nil, oops.E(oops.CodeUnexpected, err, "load connected user").Log(ctx, logger) - } - if !connectedUser.WorkosID.Valid || connectedUser.WorkosID.String == "" { - return nil, oops.E(oops.CodeBadRequest, nil, "member is not linked to WorkOS").Log(ctx, logger) - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list members from workos").Log(ctx, logger) - } - - roleIDBySlug := make(map[string]string, len(roles)) - for _, role := range roles { - roleIDBySlug[role.Slug] = role.ID - } - - membershipID := "" - var existingMember workos.Member - for _, member := range members { - if member.UserID == connectedUser.WorkosID.String { - membershipID = member.ID - existingMember = member - break - } - } - if membershipID == "" { - return nil, oops.E(oops.CodeNotFound, nil, "member not found").Log(ctx, logger) - } - - updatedMember, err := s.roles.UpdateMemberRole(ctx, membershipID, roleSlug) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "update member role in workos").Log(ctx, logger) - } - s.authz.InvalidateRoleCache(ctx, payload.UserID, ac.ActiveOrganizationID) - - users, err := s.roles.ListOrgUsers(ctx, workosOrgID) - if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "list org users from workos").Log(ctx, logger) - } - user, ok := users[updatedMember.UserID] - if !ok { - return nil, oops.E(oops.CodeNotFound, nil, "member user not found").Log(ctx, logger) - } - - beforeMember := &gen.AccessMember{ - ID: connectedUser.ID, - Name: formatUserName(user), - Email: user.Email, - PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), - RoleID: roleIDBySlug[existingMember.RoleSlug], - JoinedAt: conv.Default(existingMember.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), - } - afterMember := &gen.AccessMember{ - ID: connectedUser.ID, - Name: formatUserName(user), - Email: user.Email, - PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), - RoleID: payload.RoleID, - JoinedAt: conv.Default(updatedMember.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), - } - - if err := s.audit.LogAccessMemberRoleUpdate(ctx, s.db, audit.LogAccessMemberRoleUpdateEvent{ - OrganizationID: ac.ActiveOrganizationID, - Actor: urn.NewPrincipal(urn.PrincipalTypeUser, ac.UserID), - ActorDisplayName: ac.Email, - ActorSlug: nil, - MemberID: connectedUser.ID, - MemberName: afterMember.Name, - MemberEmail: afterMember.Email, - MemberSnapshotBefore: beforeMember, - MemberSnapshotAfter: afterMember, - }); err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "log access member role update").Log(ctx, logger) - } - - return afterMember, nil + return memberUpdate.After, nil } func (s *Service) authContext(ctx context.Context) (*contextvalues.AuthContext, error) { @@ -985,138 +477,6 @@ func authzSelectorToGen(sel authz.Selector) *gen.Selector { return s } -func formatUserName(user workos.User) string { - switch { - case user.FirstName != "" && user.LastName != "": - return user.FirstName + " " + user.LastName - case user.FirstName != "": - return user.FirstName - case user.LastName != "": - return user.LastName - default: - return user.Email - } -} - -// localMemberCounts counts WorkOS members per role slug, but only for members -// who have a local Gram account and are connected to the given organization -// via organization_user_relationships. This ensures counts match what -// ListMembers returns. -func (s *Service) localMemberCounts(ctx context.Context, organizationID string, members []workos.Member) (map[string]int, error) { - workosIDs := make([]string, 0, len(members)) - for _, m := range members { - workosIDs = append(workosIDs, m.UserID) - } - localRows, err := usersrepo.New(s.db).GetConnectedUsersByWorkosIDs(ctx, usersrepo.GetConnectedUsersByWorkosIDsParams{ - WorkosIds: workosIDs, - OrganizationID: organizationID, - }) - if err != nil { - return nil, fmt.Errorf("get connected users by workos ids: %w", err) - } - localSet := make(map[string]struct{}, len(localRows)) - for _, u := range localRows { - if u.WorkosID.Valid { - localSet[u.WorkosID.String] = struct{}{} - } - } - counts := make(map[string]int) - for _, m := range members { - if _, ok := localSet[m.UserID]; ok { - counts[m.RoleSlug]++ - } - } - return counts, nil -} - -// gramToWorkosIDMap resolves Gram user IDs to WorkOS user IDs. -// Dashboard sends Gram IDs (from ListMembers), but WorkOS membership lookups -// require WorkOS user IDs. -func gramToWorkosIDMap(ctx context.Context, db *pgxpool.Pool, gramIDs []string) (map[string]string, error) { - users, err := usersrepo.New(db).GetUsersByIDs(ctx, gramIDs) - if err != nil { - return nil, fmt.Errorf("get users by ids: %w", err) - } - m := make(map[string]string, len(users)) - for _, u := range users { - if u.WorkosID.Valid && u.WorkosID.String != "" { - m[u.ID] = u.WorkosID.String - } - } - return m, nil -} - -func membershipsByUserID(members []workos.Member) map[string]string { - membershipByUser := make(map[string]string, len(members)) - for _, member := range members { - membershipByUser[member.UserID] = member.ID - } - - return membershipByUser -} - -func (s *Service) memberRoleSlugs(ctx context.Context, workosOrgID string, workosUserID string) ([]string, error) { - if workosUserID == "" { - return nil, nil - } - - members, err := s.roles.ListMembers(ctx, workosOrgID) - if err != nil { - return nil, fmt.Errorf("list members for role lookup: %w", err) - } - - roleSlugs := make([]string, 0, len(members)) - seenRoleSlugs := make(map[string]struct{}, len(members)) - - for _, member := range members { - if member.UserID != workosUserID || member.RoleSlug == "" { - continue - } - if _, ok := seenRoleSlugs[member.RoleSlug]; ok { - continue - } - - seenRoleSlugs[member.RoleSlug] = struct{}{} - roleSlugs = append(roleSlugs, member.RoleSlug) - } - - return roleSlugs, nil -} - -func findRoleByID(roles []workos.Role, id string) (workos.Role, bool) { - for _, role := range roles { - if role.ID == id { - return role, true - } - } - - var zero workos.Role - return zero, false -} - -func buildRole(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, organizationID string, role workos.Role, memberCount int) (*gen.Role, error) { - grants, err := authz.GrantsForRole(ctx, logger, db, organizationID, role.Slug) - if err != nil { - return nil, err - } - genGrants := make([]*gen.RoleGrant, 0, len(grants)) - for _, g := range grants { - genGrants = append(genGrants, scopedGrantToGenRoleGrant(g)) - } - - return &gen.Role{ - ID: role.ID, - Name: role.Name, - Slug: role.Slug, - Description: role.Description, - IsSystem: isSystemRole(role.Slug), - Grants: genGrants, - MemberCount: memberCount, - CreatedAt: conv.Default(role.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), - UpdatedAt: conv.Default(role.UpdatedAt, time.Time{}.UTC().Format(time.RFC3339)), - }, nil -} - func scopedGrantToGenRoleGrant(g *authz.ScopedGrant) *gen.RoleGrant { var selectors []*gen.Selector for _, sel := range g.Selectors { @@ -1154,7 +514,7 @@ func listRoleGrantsFromGrants(grants []authz.Grant) []*gen.ListRoleGrant { return out } -func connectedUser(ctx context.Context, db *pgxpool.Pool, organizationID string, userID string) (usersrepo.User, error) { +func connectedUser(ctx context.Context, db database.DBTX, organizationID string, userID string) (usersrepo.User, error) { hasRelationship, err := orgrepo.New(db).HasOrganizationUserRelationship(ctx, orgrepo.HasOrganizationUserRelationshipParams{ OrganizationID: organizationID, UserID: conv.ToPGText(userID), @@ -1269,21 +629,6 @@ func (s *Service) requireSuperAdmin(ctx context.Context) (*contextvalues.AuthCon return ac, nil } -func slugify(name string) (string, error) { - slug := conv.ToSlug(strings.ReplaceAll(name, "_", " ")) - if slug == "" { - return "", oops.E(oops.CodeBadRequest, nil, "role name must contain at least one letter or digit") - } - if !validRoleNamePattern.MatchString(name) { - return "", oops.E(oops.CodeBadRequest, nil, "role name contains invalid characters") - } - if !strings.HasPrefix(slug, "org-") { - slug = "org-" + slug - } - - return slug, nil -} - type challengeUserInfo struct { email string photoURL *string diff --git a/server/internal/access/listmembers_test.go b/server/internal/access/listmembers_test.go index 2522ba3212..44f3680bfc 100644 --- a/server/internal/access/listmembers_test.go +++ b/server/internal/access/listmembers_test.go @@ -1,16 +1,12 @@ package access import ( - "errors" - mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" "testing" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" gen "github.com/speakeasy-api/gram/server/gen/access" "github.com/speakeasy-api/gram/server/internal/contextvalues" - thirdpartyworkos "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" ) func TestService_ListMembers(t *testing.T) { @@ -23,22 +19,14 @@ func TestService_ListMembers(t *testing.T) { seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "grace@example.com", "Grace", "user_2", "membership_2") - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - "user_2": mockUser("user_2", "Grace", "", "grace@example.com"), - }, nil).Once() + adminID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember("", "membership_2", "user_2", "custom-builder")) result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Len(t, result.Members, 2) + require.Len(t, result.Members, 3) byID := map[string]*gen.AccessMember{} for _, member := range result.Members { @@ -48,14 +36,16 @@ func TestService_ListMembers(t *testing.T) { // IDs should be Gram user IDs, not WorkOS user IDs. require.Equal(t, "Ada Lovelace", byID["local_user_1"].Name) require.Equal(t, "ada@example.com", byID["local_user_1"].Email) - require.Equal(t, "role_admin", byID["local_user_1"].RoleID) - require.Equal(t, "2024-11-15T15:04:05Z", byID["local_user_1"].JoinedAt) + require.Equal(t, adminID, byID["local_user_1"].RoleID) + require.NotEmpty(t, byID["local_user_1"].JoinedAt) require.Equal(t, "Grace", byID["local_user_2"].Name) - require.Equal(t, "role_builder", byID["local_user_2"].RoleID) + require.Equal(t, builderID, byID["local_user_2"].RoleID) + + require.Empty(t, byID[authCtx.UserID].RoleID) } -func TestService_ListMembers_ExcludesDisconnectedUsers(t *testing.T) { +func TestService_ListMembers_SkipsMembersWithoutLocalUser(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -67,38 +57,49 @@ func TestService_ListMembers_ExcludesDisconnectedUsers(t *testing.T) { // to this org — no row in organization_user_relationships. seedDisconnectedUser(t, ctx, ti.conn, "local_user_2", "grace@example.com", "Grace", "user_2") - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "admin"), - }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - "user_2": mockUser("user_2", "Grace", "", "grace@example.com"), - }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_2", "user_2", "admin")) result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Len(t, result.Members, 1, "disconnected user should be excluded") - require.Equal(t, "local_user_1", result.Members[0].ID) - require.Equal(t, "Ada Lovelace", result.Members[0].Name) + require.Len(t, result.Members, 2) + + byID := map[string]*gen.AccessMember{} + for _, member := range result.Members { + byID[member.ID] = member + } + require.Equal(t, "Ada Lovelace", byID["local_user_1"].Name) + require.Nil(t, byID["user_2"]) + require.Empty(t, byID[authCtx.UserID].RoleID) } -func TestService_ListMembers_WorkOSUsersFailure(t *testing.T) { +func TestService_ListMembers_IncludesConnectedUsersWithoutRoleAssignments(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User(nil), errors.New("workos unavailable")).Once() - - _, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) - require.Error(t, err) - require.Contains(t, err.Error(), "list org users from workos") + authCtx, _ := contextvalues.GetAuthContext(ctx) + + result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) + require.NoError(t, err) + require.Len(t, result.Members, 1) + + require.Equal(t, authCtx.UserID, result.Members[0].ID) + require.Empty(t, result.Members[0].RoleID) + require.NotEmpty(t, result.Members[0].Name) +} + +func TestService_ListMembers_UsesDatabaseOnly(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, _ := contextvalues.GetAuthContext(ctx) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_1", "user_1", "admin")) + + result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) + require.NoError(t, err) + require.Len(t, result.Members, 1) + require.Equal(t, authCtx.UserID, result.Members[0].ID) + require.Empty(t, result.Members[0].RoleID) } diff --git a/server/internal/access/listroles_test.go b/server/internal/access/listroles_test.go index 3409b8ca8b..3bf95ed3f0 100644 --- a/server/internal/access/listroles_test.go +++ b/server/internal/access/listroles_test.go @@ -1,16 +1,13 @@ package access import ( - mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" "testing" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" gen "github.com/speakeasy-api/gram/server/gen/access" "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/contextvalues" - thirdpartyworkos "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" ) @@ -22,21 +19,16 @@ func TestService_ListRoles(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_3", "user_3", "custom-builder"), - // user_workos_only has never logged into Gram — should not be counted - mockMember(mockidp.MockOrgID, "membership_workos_only", "user_workos_only", "custom-builder"), - }, nil).Once() + adminID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + customID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", "user3@test.com", "User 3", "user_3", "membership_3") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember("", "membership_2", "user_2", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", mockMember("", "membership_3", "user_3", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_workos_only", "user_workos_only", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "admin"), authz.ScopeOrgAdmin, authz.WildcardResource) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-2") @@ -51,7 +43,7 @@ func TestService_ListRoles(t *testing.T) { rolesByID[role.ID] = role } - adminRole := rolesByID["role_admin"] + adminRole := rolesByID[adminID] require.NotNil(t, adminRole) require.Equal(t, "Admin", adminRole.Name) require.True(t, adminRole.IsSystem) @@ -62,7 +54,7 @@ func TestService_ListRoles(t *testing.T) { require.Equal(t, string(authz.ScopeOrgAdmin), adminRole.Grants[0].Scope) require.Nil(t, adminRole.Grants[0].Selectors) - customRole := rolesByID["role_custom"] + customRole := rolesByID[customID] require.NotNil(t, customRole) require.Equal(t, "Custom Builder", customRole.Name) require.False(t, customRole.IsSystem) @@ -96,14 +88,10 @@ func TestService_ListRoles_ExcludesDisconnectedUsersFromMemberCounts(t *testing. // (no organization_user_relationships row). Should not inflate member count. seedDisconnectedUser(t, ctx, ti.conn, "local_user_2", "user2@test.com", "User 2", "user_2") - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - // user_2 appears in WorkOS members but is disconnected locally. - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "admin"), - }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) + // user_2 appears in role assignments but is disconnected locally. + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember("", "membership_2", "user_2", "admin")) result, err := ti.service.ListRoles(ctx, &gen.ListRolesPayload{}) require.NoError(t, err) @@ -111,3 +99,21 @@ func TestService_ListRoles_ExcludesDisconnectedUsersFromMemberCounts(t *testing. // Only user_1 should be counted — user_2 has a local account but no org connection. require.Equal(t, 1, result.Roles[0].MemberCount) } + +func TestService_ListRoles_DoesNotCountConnectedUsersWithoutAssignments(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + _, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + memberID := seedGlobalRole(t, ctx, ti.conn, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) + + result, err := ti.service.ListRoles(ctx, &gen.ListRolesPayload{}) + require.NoError(t, err) + require.Len(t, result.Roles, 1) + + require.Equal(t, memberID, result.Roles[0].ID) + require.Equal(t, "Member", result.Roles[0].Name) + require.Equal(t, 0, result.Roles[0].MemberCount) +} diff --git a/server/internal/access/listusergrants_test.go b/server/internal/access/listusergrants_test.go index fe4c3d3f10..9deabe67d7 100644 --- a/server/internal/access/listusergrants_test.go +++ b/server/internal/access/listusergrants_test.go @@ -1,12 +1,8 @@ package access import ( - "errors" "testing" - mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" - - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" gen "github.com/speakeasy-api/gram/server/gen/access" @@ -15,7 +11,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/conv" orgrepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" - thirdpartyworkos "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" ) @@ -42,13 +37,11 @@ func TestService_ListGrants(t *testing.T) { ctx = contextvalues.SetAuthContext(ctx, authCtx) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, authCtx.UserID, "member@example.com", "Member User", "workos_user_member", "membership_1") + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, authCtx.UserID, mockMember("", "membership_1", "workos_user_member", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeUser, authCtx.UserID), authz.ScopeProjectRead, "project_123") seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeMCPConnect, "tool_456") - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "workos_user_member", "custom-builder"), - }, nil).Once() - result, err := ti.service.ListGrants(ctx, &gen.ListGrantsPayload{}) require.NoError(t, err) require.Len(t, result.Grants, 2) @@ -62,7 +55,7 @@ func TestService_ListGrants(t *testing.T) { require.Equal(t, "tool_456", byScope["mcp:connect"].Selectors[0].ResourceID) } -func TestService_ListGrants_MultipleRoles(t *testing.T) { +func TestService_ListGrants_RoleGrants(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -73,13 +66,10 @@ func TestService_ListGrants_MultipleRoles(t *testing.T) { ctx = contextvalues.SetAuthContext(ctx, authCtx) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, authCtx.UserID, "member@example.com", "Member User", "workos_user_member", "membership_1") + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, authCtx.UserID, mockMember("", "membership_1", "workos_user_member", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project_123") - seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-mcp"), authz.ScopeMCPConnect, "tool_456") - - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "workos_user_member", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "workos_user_member", "custom-mcp"), - }, nil).Once() + seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeMCPConnect, "tool_456") result, err := ti.service.ListGrants(ctx, &gen.ListGrantsPayload{}) require.NoError(t, err) @@ -187,7 +177,7 @@ func TestService_ListGrants_RBACDisabledReturnsFullAccess(t *testing.T) { ctx = contextvalues.SetAuthContext(ctx, authCtx) chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - ti.service.authz = authz.NewEngine(ti.service.logger, ti.conn, chConn, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, ti.roles, nil) + ti.service.authz = authz.NewEngine(ti.service.logger, ti.conn, chConn, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, ti.roles) result, err := ti.service.ListGrants(ctx, &gen.ListGrantsPayload{}) require.NoError(t, err) @@ -233,7 +223,7 @@ func TestService_ListGrants_EnterpriseWithoutSessionReturnsFullAccess(t *testing } } -func TestService_ListGrants_WorkOSMembersFailure(t *testing.T) { +func TestService_ListGrants_NoRoleAssignments(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) @@ -245,9 +235,7 @@ func TestService_ListGrants_WorkOSMembersFailure(t *testing.T) { seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, authCtx.UserID, "member@example.com", "Member User", "workos_user_member", "membership_1") - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member(nil), errors.New("workos unavailable")).Once() - - _, err := ti.service.ListGrants(ctx, &gen.ListGrantsPayload{}) - require.Error(t, err) - require.Contains(t, err.Error(), "list members from workos") + result, err := ti.service.ListGrants(ctx, &gen.ListGrantsPayload{}) + require.NoError(t, err) + require.Empty(t, result.Grants) } diff --git a/server/internal/access/mock_role_test.go b/server/internal/access/mock_role_test.go index b6e9a8c516..c02ba4e410 100644 --- a/server/internal/access/mock_role_test.go +++ b/server/internal/access/mock_role_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -25,19 +26,21 @@ func newMockRoleProvider(t *testing.T) *MockRoleProvider { roles := &MockRoleProvider{} t.Cleanup(func() { - require.True(t, roles.AssertExpectations(t)) + require.Eventually(t, func() bool { + return roles.AssertExpectations(mockExpectationProbe{}) + }, 2*time.Second, 10*time.Millisecond) }) return roles } -func (m *MockRoleProvider) ListRoles(ctx context.Context, orgID string) ([]thirdpartyworkos.Role, error) { - args := m.Called(ctx, orgID) - if roles, ok := args.Get(0).([]thirdpartyworkos.Role); ok { - return roles, mockErr(args, 1) - } - return nil, mockErr(args, 1) -} +type mockExpectationProbe struct{} + +func (mockExpectationProbe) Logf(string, ...any) {} + +func (mockExpectationProbe) Errorf(string, ...any) {} + +func (mockExpectationProbe) FailNow() {} func (m *MockRoleProvider) CreateRole(ctx context.Context, orgID string, opts thirdpartyworkos.CreateRoleOpts) (*thirdpartyworkos.Role, error) { args := m.Called(ctx, orgID, opts) @@ -60,14 +63,6 @@ func (m *MockRoleProvider) DeleteRole(ctx context.Context, orgID string, roleSlu return mockErr(args, 0) } -func (m *MockRoleProvider) ListMembers(ctx context.Context, orgID string) ([]thirdpartyworkos.Member, error) { - args := m.Called(ctx, orgID) - if members, ok := args.Get(0).([]thirdpartyworkos.Member); ok { - return members, mockErr(args, 1) - } - return nil, mockErr(args, 1) -} - func (m *MockRoleProvider) UpdateMemberRole(ctx context.Context, membershipID string, roleSlug string) (*thirdpartyworkos.Member, error) { args := m.Called(ctx, membershipID, roleSlug) if member, ok := args.Get(0).(*thirdpartyworkos.Member); ok { @@ -76,22 +71,6 @@ func (m *MockRoleProvider) UpdateMemberRole(ctx context.Context, membershipID st return nil, mockErr(args, 1) } -func (m *MockRoleProvider) GetUser(ctx context.Context, userID string) (*thirdpartyworkos.User, error) { - args := m.Called(ctx, userID) - if user, ok := args.Get(0).(*thirdpartyworkos.User); ok { - return user, mockErr(args, 1) - } - return nil, mockErr(args, 1) -} - -func (m *MockRoleProvider) ListOrgUsers(ctx context.Context, orgID string) (map[string]thirdpartyworkos.User, error) { - args := m.Called(ctx, orgID) - if users, ok := args.Get(0).(map[string]thirdpartyworkos.User); ok { - return users, mockErr(args, 1) - } - return nil, mockErr(args, 1) -} - func (m *MockRoleProvider) GetOrgMembership(ctx context.Context, workOSUserID, workOSOrgID string) (*thirdpartyworkos.Member, error) { args := m.Called(ctx, workOSUserID, workOSOrgID) if member, ok := args.Get(0).(*thirdpartyworkos.Member); ok { @@ -133,12 +112,3 @@ func mockMember(orgID, membershipID, userID, roleSlug string) thirdpartyworkos.M CreatedAt: mockMembershipTimestamp, } } - -func mockUser(id, firstName, lastName, email string) thirdpartyworkos.User { - return thirdpartyworkos.User{ - ID: id, - FirstName: firstName, - LastName: lastName, - Email: email, - } -} diff --git a/server/internal/access/queries.sql b/server/internal/access/queries.sql index 96eea91df9..0a48ddd5f1 100644 --- a/server/internal/access/queries.sql +++ b/server/internal/access/queries.sql @@ -73,9 +73,8 @@ WHERE deleted_at IS NULL ORDER BY workos_slug; -- name: UpsertGlobalRole :exec --- Upsert an environment-level WorkOS role. Caller must have already passed --- the row through ShouldProcessEvent. Resurrects a previously soft-deleted --- role on conflict. +-- Upsert an environment-level role. WorkOS sync callers pass an event ID; +-- local/bootstrap callers pass NULL so an existing WorkOS event cursor is preserved. INSERT INTO global_roles ( workos_slug, workos_name, @@ -95,7 +94,7 @@ ON CONFLICT (workos_slug) DO UPDATE SET workos_name = EXCLUDED.workos_name, workos_description = EXCLUDED.workos_description, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, global_roles.workos_last_event_id), deleted_at = NULL, workos_deleted_at = NULL, updated_at = clock_timestamp(); @@ -122,10 +121,10 @@ WHERE organization_id = @organization_id AND deleted_at IS NULL ORDER BY workos_slug; --- name: UpsertOrganizationRole :exec --- Upsert an org-scoped WorkOS role. Caller must have already passed the row --- through ShouldProcessEvent. Resurrects a previously soft-deleted role on --- conflict. +-- name: UpsertOrganizationRole :one +-- Upsert an org-scoped role. WorkOS sync callers pass an event ID; local role +-- lifecycle callers pass NULL so an existing WorkOS event cursor is preserved. +WITH upserted AS ( INSERT INTO organization_roles ( organization_id, workos_slug, @@ -147,10 +146,35 @@ ON CONFLICT (organization_id, workos_slug) DO UPDATE SET workos_name = EXCLUDED.workos_name, workos_description = EXCLUDED.workos_description, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_roles.workos_last_event_id), deleted_at = NULL, workos_deleted_at = NULL, - updated_at = clock_timestamp(); + updated_at = clock_timestamp() +RETURNING + id, + organization_id, + workos_slug, + workos_name, + workos_description, + workos_created_at, + workos_updated_at +) +SELECT + upserted.id, + ('role:organization:' || upserted.id::text)::text AS role_urn, + upserted.workos_slug, + upserted.workos_name, + upserted.workos_description, + upserted.workos_created_at, + upserted.workos_updated_at, + COUNT(ora.id)::bigint AS member_count +FROM upserted +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = upserted.organization_id + AND ora.role_urn = 'role:organization:' || upserted.id::text + AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL +GROUP BY upserted.id, upserted.workos_slug, upserted.workos_name, upserted.workos_description, upserted.workos_created_at, upserted.workos_updated_at; -- name: MarkOrganizationRoleDeleted :execrows UPDATE organization_roles @@ -161,3 +185,382 @@ SET workos_deleted_at = @workos_deleted_at, WHERE organization_id = @organization_id AND workos_slug = @workos_slug AND deleted_at IS NULL; + +-- name: MarkOrganizationRoleDeletedLocally :execrows +UPDATE organization_roles +SET deleted_at = clock_timestamp(), + updated_at = clock_timestamp() +WHERE organization_id = @organization_id + AND workos_slug = @workos_slug + AND deleted_at IS NULL; + +-- name: ListActiveOrganizationRoles :many +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE deleted IS FALSE + AND workos_deleted IS FALSE + UNION ALL + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = @organization_id + AND deleted IS FALSE + AND workos_deleted IS FALSE +) +SELECT + active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, + COUNT(DISTINCT ora.id)::bigint AS member_count +FROM active_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = @organization_id + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text + AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +ORDER BY active_roles.workos_slug; + +-- name: GetActiveOrganizationRoleBySlug :one +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = @workos_slug + AND deleted IS FALSE + AND workos_deleted IS FALSE + UNION ALL + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = @organization_id + AND organization_roles.workos_slug = @workos_slug + AND deleted IS FALSE + AND workos_deleted IS FALSE +) +SELECT + active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, + COUNT(DISTINCT ora.id)::bigint AS member_count +FROM active_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = @organization_id + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text + AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +-- Organization roles shadow global roles if a slug ever collides. +ORDER BY active_roles.role_kind DESC +LIMIT 1; + +-- name: GetOrganizationRoleByID :one +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.id = sqlc.arg(id) + AND deleted IS FALSE + AND workos_deleted IS FALSE +UNION ALL + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = @organization_id + AND organization_roles.id = sqlc.arg(id) + AND deleted IS FALSE + AND workos_deleted IS FALSE +) +SELECT + active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, + COUNT(DISTINCT ora.id)::bigint AS member_count +FROM active_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = @organization_id + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text + AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +-- Organization roles shadow global roles if an ID ever collides. +ORDER BY active_roles.role_kind DESC +LIMIT 1; + +-- name: ListOrganizationRoleAssignmentsBySlug :many +SELECT + ora.user_id, + ora.workos_user_id, + ora.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.created_at +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = @organization_id + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) = @workos_role_slug + AND ora.deleted_at IS NULL +ORDER BY ora.workos_user_id; + +-- name: GetOrganizationRoleAssignmentByWorkosUser :one +SELECT + our.user_id, + users.workos_id::text AS workos_user_id, + our.workos_membership_id, + COALESCE(organization_roles.id::text, global_roles.id::text, '')::text AS role_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug, '')::text AS role_slug, + COALESCE(ora.created_at, our.created_at)::timestamptz AS created_at +FROM organization_user_relationships AS our +JOIN users + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT * + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = our.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE our.organization_id = @organization_id + AND users.workos_id = sqlc.arg(workos_user_id)::text + AND users.workos_id IS NOT NULL + AND our.deleted IS FALSE +LIMIT 1; + +-- name: ListOrganizationRoleAssignmentsByWorkosUsers :many +SELECT + our.user_id, + users.workos_id::text AS workos_user_id, + our.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug, '')::text AS role_slug, + COALESCE(ora.created_at, our.created_at)::timestamptz AS created_at +FROM organization_user_relationships AS our +JOIN users + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT * + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = our.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE our.organization_id = @organization_id + AND users.workos_id = ANY(@workos_user_ids::text[]) + AND users.workos_id IS NOT NULL + AND our.workos_membership_id IS NOT NULL + AND our.deleted IS FALSE +ORDER BY users.workos_id, role_slug; + +-- name: ListAccessMembers :many +SELECT + users.id, + users.display_name, + users.email, + users.photo_url, + COALESCE(organization_roles.id::text, global_roles.id::text, '')::text AS role_id, + COALESCE(ora.created_at, our.created_at)::timestamptz AS joined_at +FROM organization_user_relationships AS our +JOIN users + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT * + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = our.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE our.organization_id = @organization_id + AND our.deleted IS FALSE +ORDER BY users.email, users.id; + +-- name: ListMemberRolePrincipalsByWorkosUser :many +SELECT + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.role_urn::text AS principal_urn +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = @organization_id + AND ora.workos_user_id = @workos_user_id + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL + AND ora.deleted_at IS NULL +ORDER BY role_slug; + +-- name: ListOrganizationRoleAssignmentRecordsByWorkosUser :many +SELECT + id, + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id, + created_at, + updated_at, + deleted_at +FROM organization_role_assignments +WHERE organization_id = @organization_id + AND workos_user_id = @workos_user_id +ORDER BY created_at; + +-- name: UpsertOrganizationRoleAssignment :execrows +WITH input_role_urn AS ( + SELECT role_urn + FROM ( + SELECT 'role:organization:' || id::text AS role_urn, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_roles.organization_id = @organization_id + AND organization_roles.workos_slug = sqlc.arg(workos_role_slug) + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = sqlc.arg(workos_role_slug) + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE + ) roles + -- Keep assignment writes aligned with role reads: org roles shadow global roles. + ORDER BY role_kind DESC + LIMIT 1 +) +INSERT INTO organization_role_assignments ( + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id +) +SELECT + @organization_id, + @workos_user_id, + @user_id, + input_role_urn.role_urn, + @workos_membership_id, + @workos_updated_at, + @workos_last_event_id +FROM input_role_urn +ON CONFLICT (organization_id, workos_user_id, role_urn) WHERE deleted_at IS NULL DO UPDATE SET + user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), + workos_membership_id = EXCLUDED.workos_membership_id, + workos_updated_at = EXCLUDED.workos_updated_at, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_role_assignments.workos_last_event_id), + updated_at = clock_timestamp(); + +-- name: ReplaceOrganizationRoleAssignment :one +WITH input_role_urn AS ( + SELECT role_urn + FROM ( + SELECT 'role:organization:' || id::text AS role_urn, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_roles.organization_id = @organization_id + AND organization_roles.workos_slug = sqlc.arg(workos_role_slug) + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = sqlc.arg(workos_role_slug) + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE + ) roles + -- Keep assignment writes aligned with role reads: org roles shadow global roles. + ORDER BY role_kind DESC + LIMIT 1 +), +upserted AS ( + INSERT INTO organization_role_assignments ( + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id + ) + SELECT + @organization_id, + @workos_user_id, + @user_id, + input_role_urn.role_urn, + @workos_membership_id, + @workos_updated_at, + @workos_last_event_id + FROM input_role_urn + ON CONFLICT (organization_id, workos_user_id, role_urn) WHERE deleted_at IS NULL DO UPDATE SET + user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), + workos_membership_id = EXCLUDED.workos_membership_id, + workos_updated_at = EXCLUDED.workos_updated_at, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_role_assignments.workos_last_event_id), + updated_at = clock_timestamp() + RETURNING role_urn +), +deleted AS ( +UPDATE organization_role_assignments +SET deleted_at = clock_timestamp(), + updated_at = clock_timestamp() +WHERE organization_role_assignments.organization_id = @organization_id + AND organization_role_assignments.workos_user_id = @workos_user_id + AND EXISTS (SELECT 1 FROM upserted) + AND organization_role_assignments.role_urn NOT IN (SELECT role_urn FROM upserted) + AND organization_role_assignments.deleted_at IS NULL + RETURNING 1 +) +SELECT COUNT(*)::bigint FROM upserted; diff --git a/server/internal/access/rbac_test.go b/server/internal/access/rbac_test.go index c7a9163398..6dc2158684 100644 --- a/server/internal/access/rbac_test.go +++ b/server/internal/access/rbac_test.go @@ -33,14 +33,9 @@ func TestService_ListRoles_AllowsOrgReadGrant(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgRead, Selector: authz.NewSelector(authz.ScopeOrgRead, testAccessAuthContext(t, ctx).ActiveOrganizationID)}) - - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() + authCtx := testAccessAuthContext(t, ctx) + ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgRead, Selector: authz.NewSelector(authz.ScopeOrgRead, authCtx.ActiveOrganizationID)}) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) result, err := ti.service.ListRoles(ctx, &gen.ListRolesPayload{}) require.NoError(t, err) @@ -66,17 +61,12 @@ func TestService_GetRole_AllowsOrgReadGrant(t *testing.T) { authCtx := testAccessAuthContext(t, ctx) ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgRead, Selector: authz.NewSelector(authz.ScopeOrgRead, authCtx.ActiveOrganizationID)}) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build selected resources")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - role, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: "role_custom"}) + role, err := ti.service.GetRole(ctx, &gen.GetRolePayload{ID: roleID}) require.NoError(t, err) - require.Equal(t, "role_custom", role.ID) + require.Equal(t, roleID, role.ID) } func TestService_ListScopes_ForbiddenWithoutOrgReadGrant(t *testing.T) { @@ -122,20 +112,12 @@ func TestService_ListMembers_AllowsOrgReadGrant(t *testing.T) { ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgRead, Selector: authz.NewSelector(authz.ScopeOrgRead, authCtx.ActiveOrganizationID)}) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) result, err := ti.service.ListMembers(ctx, &gen.ListMembersPayload{}) require.NoError(t, err) - require.Len(t, result.Members, 1) + require.Len(t, result.Members, 2) } func TestService_CreateRole_ForbiddenWithoutOrgAdminGrant(t *testing.T) { @@ -171,7 +153,8 @@ func TestService_CreateRole_AllowsOrgAdminGrant(t *testing.T) { role, err := ti.service.CreateRole(ctx, &gen.CreateRolePayload{Name: "Allowed", Description: "Allowed"}) require.NoError(t, err) - require.Equal(t, "role_allowed", role.ID) + require.NotEmpty(t, role.ID) + require.NotEqual(t, "role_allowed", role.ID) } func TestService_UpdateRole_ForbiddenWithoutOrgAdminGrant(t *testing.T) { @@ -194,10 +177,7 @@ func TestService_UpdateRole_AllowsOrgAdminGrant(t *testing.T) { ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgAdmin, Selector: authz.NewSelector(authz.ScopeOrgAdmin, authCtx.ActiveOrganizationID)}) name := "Updated" - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old")) ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{Name: &name}).Return(&thirdpartyworkos.Role{ ID: "role_custom", Name: name, @@ -206,9 +186,8 @@ func TestService_UpdateRole_AllowsOrgAdminGrant(t *testing.T) { CreatedAt: mockRoleTimestamp, UpdatedAt: mockRoleTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() - role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: "role_custom", Name: &name}) + role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: roleID, Name: &name}) require.NoError(t, err) require.Equal(t, name, role.Name) } @@ -232,14 +211,11 @@ func TestService_DeleteRole_AllowsOrgAdminGrant(t *testing.T) { authCtx := testAccessAuthContext(t, ctx) ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgAdmin, Selector: authz.NewSelector(authz.ScopeOrgAdmin, authCtx.ActiveOrganizationID)}) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) ti.roles.On("DeleteRole", mock.Anything, mockidp.MockOrgID, "custom-builder").Return(nil).Once() seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-1") - err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: "role_custom"}) + err := ti.service.DeleteRole(ctx, &gen.DeleteRolePayload{ID: roleID}) require.NoError(t, err) } @@ -262,12 +238,10 @@ func TestService_UpdateMemberRole_AllowsOrgAdminGrant(t *testing.T) { authCtx := testAccessAuthContext(t, ctx) ctx = withRBACGrants(t, ctx, authz.Grant{Scope: authz.ScopeOrgAdmin, Selector: authz.NewSelector(authz.ScopeOrgAdmin, authCtx.ActiveOrganizationID)}) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -275,14 +249,10 @@ func TestService_UpdateMemberRole_AllowsOrgAdminGrant(t *testing.T) { RoleSlug: "custom-builder", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - }, nil).Once() - seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_builder"}) + member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.NoError(t, err) - require.Equal(t, "role_builder", member.RoleID) + require.Equal(t, builderID, member.RoleID) } func withRBACGrants(t *testing.T, ctx context.Context, grants ...authz.Grant) context.Context { diff --git a/server/internal/access/repo/models.go b/server/internal/access/repo/models.go index 8faea64552..376864a673 100644 --- a/server/internal/access/repo/models.go +++ b/server/internal/access/repo/models.go @@ -61,3 +61,17 @@ type OrganizationRole struct { DeletedAt pgtype.Timestamptz Deleted bool } + +type OrganizationRoleAssignment struct { + ID uuid.UUID + OrganizationID string + WorkosUserID string + UserID pgtype.Text + RoleUrn string + WorkosMembershipID pgtype.Text + WorkosUpdatedAt pgtype.Timestamptz + WorkosLastEventID pgtype.Text + CreatedAt pgtype.Timestamptz + UpdatedAt pgtype.Timestamptz + DeletedAt pgtype.Timestamptz +} diff --git a/server/internal/access/repo/queries.sql.go b/server/internal/access/repo/queries.sql.go index 530d7c0df1..ea7dc7904c 100644 --- a/server/internal/access/repo/queries.sql.go +++ b/server/internal/access/repo/queries.sql.go @@ -54,6 +54,74 @@ func (q *Queries) DeletePrincipalGrantsByPrincipal(ctx context.Context, arg Dele return result.RowsAffected(), nil } +const getActiveOrganizationRoleBySlug = `-- name: GetActiveOrganizationRoleBySlug :one +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = $2 + AND deleted IS FALSE + AND workos_deleted IS FALSE + UNION ALL + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = $1 + AND organization_roles.workos_slug = $2 + AND deleted IS FALSE + AND workos_deleted IS FALSE +) +SELECT + active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, + COUNT(DISTINCT ora.id)::bigint AS member_count +FROM active_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = $1 + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text + AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +ORDER BY active_roles.role_kind DESC +LIMIT 1 +` + +type GetActiveOrganizationRoleBySlugParams struct { + OrganizationID string + WorkosSlug string +} + +type GetActiveOrganizationRoleBySlugRow struct { + ID uuid.UUID + RoleUrn string + WorkosSlug string + WorkosName string + WorkosDescription pgtype.Text + WorkosCreatedAt pgtype.Timestamptz + WorkosUpdatedAt pgtype.Timestamptz + MemberCount int64 +} + +// Organization roles shadow global roles if a slug ever collides. +func (q *Queries) GetActiveOrganizationRoleBySlug(ctx context.Context, arg GetActiveOrganizationRoleBySlugParams) (GetActiveOrganizationRoleBySlugRow, error) { + row := q.db.QueryRow(ctx, getActiveOrganizationRoleBySlug, arg.OrganizationID, arg.WorkosSlug) + var i GetActiveOrganizationRoleBySlugRow + err := row.Scan( + &i.ID, + &i.RoleUrn, + &i.WorkosSlug, + &i.WorkosName, + &i.WorkosDescription, + &i.WorkosCreatedAt, + &i.WorkosUpdatedAt, + &i.MemberCount, + ) + return i, err +} + const getGlobalRoleBySlug = `-- name: GetGlobalRoleBySlug :one SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, workos_deleted_at, workos_deleted, workos_last_event_id, created_at, updated_at, deleted_at, deleted FROM global_roles @@ -81,6 +149,138 @@ func (q *Queries) GetGlobalRoleBySlug(ctx context.Context, workosSlug string) (G return i, err } +const getOrganizationRoleAssignmentByWorkosUser = `-- name: GetOrganizationRoleAssignmentByWorkosUser :one +SELECT + our.user_id, + users.workos_id::text AS workos_user_id, + our.workos_membership_id, + COALESCE(organization_roles.id::text, global_roles.id::text, '')::text AS role_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug, '')::text AS role_slug, + COALESCE(ora.created_at, our.created_at)::timestamptz AS created_at +FROM organization_user_relationships AS our +JOIN users + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT id, organization_id, workos_user_id, user_id, role_urn, workos_membership_id, workos_updated_at, workos_last_event_id, created_at, updated_at, deleted_at + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = our.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE our.organization_id = $1 + AND users.workos_id = $2::text + AND users.workos_id IS NOT NULL + AND our.deleted IS FALSE +LIMIT 1 +` + +type GetOrganizationRoleAssignmentByWorkosUserParams struct { + OrganizationID string + WorkosUserID string +} + +type GetOrganizationRoleAssignmentByWorkosUserRow struct { + UserID pgtype.Text + WorkosUserID string + WorkosMembershipID pgtype.Text + RoleID string + RoleSlug string + CreatedAt pgtype.Timestamptz +} + +func (q *Queries) GetOrganizationRoleAssignmentByWorkosUser(ctx context.Context, arg GetOrganizationRoleAssignmentByWorkosUserParams) (GetOrganizationRoleAssignmentByWorkosUserRow, error) { + row := q.db.QueryRow(ctx, getOrganizationRoleAssignmentByWorkosUser, arg.OrganizationID, arg.WorkosUserID) + var i GetOrganizationRoleAssignmentByWorkosUserRow + err := row.Scan( + &i.UserID, + &i.WorkosUserID, + &i.WorkosMembershipID, + &i.RoleID, + &i.RoleSlug, + &i.CreatedAt, + ) + return i, err +} + +const getOrganizationRoleByID = `-- name: GetOrganizationRoleByID :one +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.id = $2 + AND deleted IS FALSE + AND workos_deleted IS FALSE +UNION ALL + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = $1 + AND organization_roles.id = $2 + AND deleted IS FALSE + AND workos_deleted IS FALSE +) +SELECT + active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, + COUNT(DISTINCT ora.id)::bigint AS member_count +FROM active_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = $1 + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text + AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +ORDER BY active_roles.role_kind DESC +LIMIT 1 +` + +type GetOrganizationRoleByIDParams struct { + OrganizationID string + ID uuid.UUID +} + +type GetOrganizationRoleByIDRow struct { + ID uuid.UUID + RoleUrn string + WorkosSlug string + WorkosName string + WorkosDescription pgtype.Text + WorkosCreatedAt pgtype.Timestamptz + WorkosUpdatedAt pgtype.Timestamptz + MemberCount int64 +} + +// Organization roles shadow global roles if an ID ever collides. +func (q *Queries) GetOrganizationRoleByID(ctx context.Context, arg GetOrganizationRoleByIDParams) (GetOrganizationRoleByIDRow, error) { + row := q.db.QueryRow(ctx, getOrganizationRoleByID, arg.OrganizationID, arg.ID) + var i GetOrganizationRoleByIDRow + err := row.Scan( + &i.ID, + &i.RoleUrn, + &i.WorkosSlug, + &i.WorkosName, + &i.WorkosDescription, + &i.WorkosCreatedAt, + &i.WorkosUpdatedAt, + &i.MemberCount, + ) + return i, err +} + const getOrganizationRoleBySlug = `-- name: GetOrganizationRoleBySlug :one SELECT id, organization_id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, workos_deleted_at, workos_deleted, workos_last_event_id, created_at, updated_at, deleted_at, deleted FROM organization_roles @@ -229,6 +429,148 @@ func (q *Queries) InsertChallengeResolutions(ctx context.Context, arg InsertChal return items, nil } +const listAccessMembers = `-- name: ListAccessMembers :many +SELECT + users.id, + users.display_name, + users.email, + users.photo_url, + COALESCE(organization_roles.id::text, global_roles.id::text, '')::text AS role_id, + COALESCE(ora.created_at, our.created_at)::timestamptz AS joined_at +FROM organization_user_relationships AS our +JOIN users + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT id, organization_id, workos_user_id, user_id, role_urn, workos_membership_id, workos_updated_at, workos_last_event_id, created_at, updated_at, deleted_at + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = our.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE our.organization_id = $1 + AND our.deleted IS FALSE +ORDER BY users.email, users.id +` + +type ListAccessMembersRow struct { + ID string + DisplayName string + Email string + PhotoUrl pgtype.Text + RoleID string + JoinedAt pgtype.Timestamptz +} + +func (q *Queries) ListAccessMembers(ctx context.Context, organizationID string) ([]ListAccessMembersRow, error) { + rows, err := q.db.Query(ctx, listAccessMembers, organizationID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListAccessMembersRow + for rows.Next() { + var i ListAccessMembersRow + if err := rows.Scan( + &i.ID, + &i.DisplayName, + &i.Email, + &i.PhotoUrl, + &i.RoleID, + &i.JoinedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listActiveOrganizationRoles = `-- name: ListActiveOrganizationRoles :many +WITH active_roles AS ( + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'global'::text AS role_kind + FROM global_roles + WHERE deleted IS FALSE + AND workos_deleted IS FALSE + UNION ALL + SELECT id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_id = $1 + AND deleted IS FALSE + AND workos_deleted IS FALSE +) +SELECT + active_roles.id, + ('role:' || active_roles.role_kind || ':' || active_roles.id::text)::text AS role_urn, + active_roles.workos_slug, + active_roles.workos_name, + active_roles.workos_description, + active_roles.workos_created_at, + active_roles.workos_updated_at, + COUNT(DISTINCT ora.id)::bigint AS member_count +FROM active_roles +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = $1 + AND ora.role_urn = 'role:' || active_roles.role_kind || ':' || active_roles.id::text + AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL +GROUP BY active_roles.id, active_roles.role_kind, active_roles.workos_slug, active_roles.workos_name, active_roles.workos_description, active_roles.workos_created_at, active_roles.workos_updated_at +ORDER BY active_roles.workos_slug +` + +type ListActiveOrganizationRolesRow struct { + ID uuid.UUID + RoleUrn string + WorkosSlug string + WorkosName string + WorkosDescription pgtype.Text + WorkosCreatedAt pgtype.Timestamptz + WorkosUpdatedAt pgtype.Timestamptz + MemberCount int64 +} + +func (q *Queries) ListActiveOrganizationRoles(ctx context.Context, organizationID string) ([]ListActiveOrganizationRolesRow, error) { + rows, err := q.db.Query(ctx, listActiveOrganizationRoles, organizationID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListActiveOrganizationRolesRow + for rows.Next() { + var i ListActiveOrganizationRolesRow + if err := rows.Scan( + &i.ID, + &i.RoleUrn, + &i.WorkosSlug, + &i.WorkosName, + &i.WorkosDescription, + &i.WorkosCreatedAt, + &i.WorkosUpdatedAt, + &i.MemberCount, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listChallengeResolutions = `-- name: ListChallengeResolutions :many SELECT id, organization_id, challenge_id, principal_urn, scope, resource_kind, resource_id, resolution_type, role_slug, resolved_by, created_at FROM authz_challenge_resolutions @@ -317,6 +659,250 @@ func (q *Queries) ListGlobalRoles(ctx context.Context) ([]GlobalRole, error) { return items, nil } +const listMemberRolePrincipalsByWorkosUser = `-- name: ListMemberRolePrincipalsByWorkosUser :many +SELECT + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.role_urn::text AS principal_urn +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = $1 + AND ora.workos_user_id = $2 + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) IS NOT NULL + AND ora.deleted_at IS NULL +ORDER BY role_slug +` + +type ListMemberRolePrincipalsByWorkosUserParams struct { + OrganizationID string + WorkosUserID string +} + +type ListMemberRolePrincipalsByWorkosUserRow struct { + RoleSlug string + PrincipalUrn string +} + +func (q *Queries) ListMemberRolePrincipalsByWorkosUser(ctx context.Context, arg ListMemberRolePrincipalsByWorkosUserParams) ([]ListMemberRolePrincipalsByWorkosUserRow, error) { + rows, err := q.db.Query(ctx, listMemberRolePrincipalsByWorkosUser, arg.OrganizationID, arg.WorkosUserID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListMemberRolePrincipalsByWorkosUserRow + for rows.Next() { + var i ListMemberRolePrincipalsByWorkosUserRow + if err := rows.Scan(&i.RoleSlug, &i.PrincipalUrn); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listOrganizationRoleAssignmentRecordsByWorkosUser = `-- name: ListOrganizationRoleAssignmentRecordsByWorkosUser :many +SELECT + id, + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id, + created_at, + updated_at, + deleted_at +FROM organization_role_assignments +WHERE organization_id = $1 + AND workos_user_id = $2 +ORDER BY created_at +` + +type ListOrganizationRoleAssignmentRecordsByWorkosUserParams struct { + OrganizationID string + WorkosUserID string +} + +func (q *Queries) ListOrganizationRoleAssignmentRecordsByWorkosUser(ctx context.Context, arg ListOrganizationRoleAssignmentRecordsByWorkosUserParams) ([]OrganizationRoleAssignment, error) { + rows, err := q.db.Query(ctx, listOrganizationRoleAssignmentRecordsByWorkosUser, arg.OrganizationID, arg.WorkosUserID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []OrganizationRoleAssignment + for rows.Next() { + var i OrganizationRoleAssignment + if err := rows.Scan( + &i.ID, + &i.OrganizationID, + &i.WorkosUserID, + &i.UserID, + &i.RoleUrn, + &i.WorkosMembershipID, + &i.WorkosUpdatedAt, + &i.WorkosLastEventID, + &i.CreatedAt, + &i.UpdatedAt, + &i.DeletedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listOrganizationRoleAssignmentsBySlug = `-- name: ListOrganizationRoleAssignmentsBySlug :many +SELECT + ora.user_id, + ora.workos_user_id, + ora.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug)::text AS role_slug, + ora.created_at +FROM organization_role_assignments AS ora +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = ora.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE ora.organization_id = $1 + AND COALESCE(organization_roles.workos_slug, global_roles.workos_slug) = $2 + AND ora.deleted_at IS NULL +ORDER BY ora.workos_user_id +` + +type ListOrganizationRoleAssignmentsBySlugParams struct { + OrganizationID string + WorkosRoleSlug string +} + +type ListOrganizationRoleAssignmentsBySlugRow struct { + UserID pgtype.Text + WorkosUserID string + WorkosMembershipID pgtype.Text + RoleSlug string + CreatedAt pgtype.Timestamptz +} + +func (q *Queries) ListOrganizationRoleAssignmentsBySlug(ctx context.Context, arg ListOrganizationRoleAssignmentsBySlugParams) ([]ListOrganizationRoleAssignmentsBySlugRow, error) { + rows, err := q.db.Query(ctx, listOrganizationRoleAssignmentsBySlug, arg.OrganizationID, arg.WorkosRoleSlug) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListOrganizationRoleAssignmentsBySlugRow + for rows.Next() { + var i ListOrganizationRoleAssignmentsBySlugRow + if err := rows.Scan( + &i.UserID, + &i.WorkosUserID, + &i.WorkosMembershipID, + &i.RoleSlug, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const listOrganizationRoleAssignmentsByWorkosUsers = `-- name: ListOrganizationRoleAssignmentsByWorkosUsers :many +SELECT + our.user_id, + users.workos_id::text AS workos_user_id, + our.workos_membership_id, + COALESCE(organization_roles.workos_slug, global_roles.workos_slug, '')::text AS role_slug, + COALESCE(ora.created_at, our.created_at)::timestamptz AS created_at +FROM organization_user_relationships AS our +JOIN users + ON users.id = our.user_id +LEFT JOIN LATERAL ( + SELECT id, organization_id, workos_user_id, user_id, role_urn, workos_membership_id, workos_updated_at, workos_last_event_id, created_at, updated_at, deleted_at + FROM organization_role_assignments + WHERE organization_role_assignments.organization_id = our.organization_id + AND organization_role_assignments.workos_user_id = users.workos_id + AND organization_role_assignments.deleted_at IS NULL + ORDER BY organization_role_assignments.created_at + LIMIT 1 +) AS ora ON TRUE +LEFT JOIN organization_roles + ON ora.role_urn = 'role:organization:' || organization_roles.id::text + AND organization_roles.organization_id = our.organization_id + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE +LEFT JOIN global_roles + ON ora.role_urn = 'role:global:' || global_roles.id::text + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE +WHERE our.organization_id = $1 + AND users.workos_id = ANY($2::text[]) + AND users.workos_id IS NOT NULL + AND our.workos_membership_id IS NOT NULL + AND our.deleted IS FALSE +ORDER BY users.workos_id, role_slug +` + +type ListOrganizationRoleAssignmentsByWorkosUsersParams struct { + OrganizationID string + WorkosUserIds []string +} + +type ListOrganizationRoleAssignmentsByWorkosUsersRow struct { + UserID pgtype.Text + WorkosUserID string + WorkosMembershipID pgtype.Text + RoleSlug string + CreatedAt pgtype.Timestamptz +} + +func (q *Queries) ListOrganizationRoleAssignmentsByWorkosUsers(ctx context.Context, arg ListOrganizationRoleAssignmentsByWorkosUsersParams) ([]ListOrganizationRoleAssignmentsByWorkosUsersRow, error) { + rows, err := q.db.Query(ctx, listOrganizationRoleAssignmentsByWorkosUsers, arg.OrganizationID, arg.WorkosUserIds) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListOrganizationRoleAssignmentsByWorkosUsersRow + for rows.Next() { + var i ListOrganizationRoleAssignmentsByWorkosUsersRow + if err := rows.Scan( + &i.UserID, + &i.WorkosUserID, + &i.WorkosMembershipID, + &i.RoleSlug, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listOrganizationRolesByOrg = `-- name: ListOrganizationRolesByOrg :many SELECT id, organization_id, workos_slug, workos_name, workos_description, workos_created_at, workos_updated_at, workos_deleted_at, workos_deleted, workos_last_event_id, created_at, updated_at, deleted_at, deleted FROM organization_roles @@ -474,6 +1060,115 @@ func (q *Queries) MarkOrganizationRoleDeleted(ctx context.Context, arg MarkOrgan return result.RowsAffected(), nil } +const markOrganizationRoleDeletedLocally = `-- name: MarkOrganizationRoleDeletedLocally :execrows +UPDATE organization_roles +SET deleted_at = clock_timestamp(), + updated_at = clock_timestamp() +WHERE organization_id = $1 + AND workos_slug = $2 + AND deleted_at IS NULL +` + +type MarkOrganizationRoleDeletedLocallyParams struct { + OrganizationID string + WorkosSlug string +} + +func (q *Queries) MarkOrganizationRoleDeletedLocally(ctx context.Context, arg MarkOrganizationRoleDeletedLocallyParams) (int64, error) { + result, err := q.db.Exec(ctx, markOrganizationRoleDeletedLocally, arg.OrganizationID, arg.WorkosSlug) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil +} + +const replaceOrganizationRoleAssignment = `-- name: ReplaceOrganizationRoleAssignment :one +WITH input_role_urn AS ( + SELECT role_urn + FROM ( + SELECT 'role:organization:' || id::text AS role_urn, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_roles.organization_id = $1 + AND organization_roles.workos_slug = $2 + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = $2 + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE + ) roles + -- Keep assignment writes aligned with role reads: org roles shadow global roles. + ORDER BY role_kind DESC + LIMIT 1 +), +upserted AS ( + INSERT INTO organization_role_assignments ( + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id + ) + SELECT + $1, + $3, + $4, + input_role_urn.role_urn, + $5, + $6, + $7 + FROM input_role_urn + ON CONFLICT (organization_id, workos_user_id, role_urn) WHERE deleted_at IS NULL DO UPDATE SET + user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), + workos_membership_id = EXCLUDED.workos_membership_id, + workos_updated_at = EXCLUDED.workos_updated_at, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_role_assignments.workos_last_event_id), + updated_at = clock_timestamp() + RETURNING role_urn +), +deleted AS ( +UPDATE organization_role_assignments +SET deleted_at = clock_timestamp(), + updated_at = clock_timestamp() +WHERE organization_role_assignments.organization_id = $1 + AND organization_role_assignments.workos_user_id = $3 + AND EXISTS (SELECT 1 FROM upserted) + AND organization_role_assignments.role_urn NOT IN (SELECT role_urn FROM upserted) + AND organization_role_assignments.deleted_at IS NULL + RETURNING 1 +) +SELECT COUNT(*)::bigint FROM upserted +` + +type ReplaceOrganizationRoleAssignmentParams struct { + OrganizationID string + WorkosRoleSlug string + WorkosUserID string + UserID pgtype.Text + WorkosMembershipID pgtype.Text + WorkosUpdatedAt pgtype.Timestamptz + WorkosLastEventID pgtype.Text +} + +func (q *Queries) ReplaceOrganizationRoleAssignment(ctx context.Context, arg ReplaceOrganizationRoleAssignmentParams) (int64, error) { + row := q.db.QueryRow(ctx, replaceOrganizationRoleAssignment, + arg.OrganizationID, + arg.WorkosRoleSlug, + arg.WorkosUserID, + arg.UserID, + arg.WorkosMembershipID, + arg.WorkosUpdatedAt, + arg.WorkosLastEventID, + ) + var column_1 int64 + err := row.Scan(&column_1) + return column_1, err +} + const upsertGlobalRole = `-- name: UpsertGlobalRole :exec INSERT INTO global_roles ( workos_slug, @@ -494,7 +1189,7 @@ ON CONFLICT (workos_slug) DO UPDATE SET workos_name = EXCLUDED.workos_name, workos_description = EXCLUDED.workos_description, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, global_roles.workos_last_event_id), deleted_at = NULL, workos_deleted_at = NULL, updated_at = clock_timestamp() @@ -509,9 +1204,8 @@ type UpsertGlobalRoleParams struct { WorkosLastEventID pgtype.Text } -// Upsert an environment-level WorkOS role. Caller must have already passed -// the row through ShouldProcessEvent. Resurrects a previously soft-deleted -// role on conflict. +// Upsert an environment-level role. WorkOS sync callers pass an event ID; +// local/bootstrap callers pass NULL so an existing WorkOS event cursor is preserved. func (q *Queries) UpsertGlobalRole(ctx context.Context, arg UpsertGlobalRoleParams) error { _, err := q.db.Exec(ctx, upsertGlobalRole, arg.WorkosSlug, @@ -524,7 +1218,8 @@ func (q *Queries) UpsertGlobalRole(ctx context.Context, arg UpsertGlobalRolePara return err } -const upsertOrganizationRole = `-- name: UpsertOrganizationRole :exec +const upsertOrganizationRole = `-- name: UpsertOrganizationRole :one +WITH upserted AS ( INSERT INTO organization_roles ( organization_id, workos_slug, @@ -546,10 +1241,35 @@ ON CONFLICT (organization_id, workos_slug) DO UPDATE SET workos_name = EXCLUDED.workos_name, workos_description = EXCLUDED.workos_description, workos_updated_at = EXCLUDED.workos_updated_at, - workos_last_event_id = EXCLUDED.workos_last_event_id, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_roles.workos_last_event_id), deleted_at = NULL, workos_deleted_at = NULL, updated_at = clock_timestamp() +RETURNING + id, + organization_id, + workos_slug, + workos_name, + workos_description, + workos_created_at, + workos_updated_at +) +SELECT + upserted.id, + ('role:organization:' || upserted.id::text)::text AS role_urn, + upserted.workos_slug, + upserted.workos_name, + upserted.workos_description, + upserted.workos_created_at, + upserted.workos_updated_at, + COUNT(ora.id)::bigint AS member_count +FROM upserted +LEFT JOIN organization_role_assignments AS ora + ON ora.organization_id = upserted.organization_id + AND ora.role_urn = 'role:organization:' || upserted.id::text + AND ora.user_id IS NOT NULL + AND ora.deleted_at IS NULL +GROUP BY upserted.id, upserted.workos_slug, upserted.workos_name, upserted.workos_description, upserted.workos_created_at, upserted.workos_updated_at ` type UpsertOrganizationRoleParams struct { @@ -562,11 +1282,21 @@ type UpsertOrganizationRoleParams struct { WorkosLastEventID pgtype.Text } -// Upsert an org-scoped WorkOS role. Caller must have already passed the row -// through ShouldProcessEvent. Resurrects a previously soft-deleted role on -// conflict. -func (q *Queries) UpsertOrganizationRole(ctx context.Context, arg UpsertOrganizationRoleParams) error { - _, err := q.db.Exec(ctx, upsertOrganizationRole, +type UpsertOrganizationRoleRow struct { + ID uuid.UUID + RoleUrn string + WorkosSlug string + WorkosName string + WorkosDescription pgtype.Text + WorkosCreatedAt pgtype.Timestamptz + WorkosUpdatedAt pgtype.Timestamptz + MemberCount int64 +} + +// Upsert an org-scoped role. WorkOS sync callers pass an event ID; local role +// lifecycle callers pass NULL so an existing WorkOS event cursor is preserved. +func (q *Queries) UpsertOrganizationRole(ctx context.Context, arg UpsertOrganizationRoleParams) (UpsertOrganizationRoleRow, error) { + row := q.db.QueryRow(ctx, upsertOrganizationRole, arg.OrganizationID, arg.WorkosSlug, arg.WorkosName, @@ -575,7 +1305,91 @@ func (q *Queries) UpsertOrganizationRole(ctx context.Context, arg UpsertOrganiza arg.WorkosUpdatedAt, arg.WorkosLastEventID, ) - return err + var i UpsertOrganizationRoleRow + err := row.Scan( + &i.ID, + &i.RoleUrn, + &i.WorkosSlug, + &i.WorkosName, + &i.WorkosDescription, + &i.WorkosCreatedAt, + &i.WorkosUpdatedAt, + &i.MemberCount, + ) + return i, err +} + +const upsertOrganizationRoleAssignment = `-- name: UpsertOrganizationRoleAssignment :execrows +WITH input_role_urn AS ( + SELECT role_urn + FROM ( + SELECT 'role:organization:' || id::text AS role_urn, 'organization'::text AS role_kind + FROM organization_roles + WHERE organization_roles.organization_id = $1 + AND organization_roles.workos_slug = $7 + AND organization_roles.deleted IS FALSE + AND organization_roles.workos_deleted IS FALSE + UNION ALL + SELECT 'role:global:' || id::text AS role_urn, 'global'::text AS role_kind + FROM global_roles + WHERE global_roles.workos_slug = $7 + AND global_roles.deleted IS FALSE + AND global_roles.workos_deleted IS FALSE + ) roles + -- Keep assignment writes aligned with role reads: org roles shadow global roles. + ORDER BY role_kind DESC + LIMIT 1 +) +INSERT INTO organization_role_assignments ( + organization_id, + workos_user_id, + user_id, + role_urn, + workos_membership_id, + workos_updated_at, + workos_last_event_id +) +SELECT + $1, + $2, + $3, + input_role_urn.role_urn, + $4, + $5, + $6 +FROM input_role_urn +ON CONFLICT (organization_id, workos_user_id, role_urn) WHERE deleted_at IS NULL DO UPDATE SET + user_id = COALESCE(EXCLUDED.user_id, organization_role_assignments.user_id), + workos_membership_id = EXCLUDED.workos_membership_id, + workos_updated_at = EXCLUDED.workos_updated_at, + workos_last_event_id = COALESCE(EXCLUDED.workos_last_event_id, organization_role_assignments.workos_last_event_id), + updated_at = clock_timestamp() +` + +type UpsertOrganizationRoleAssignmentParams struct { + OrganizationID string + WorkosUserID string + UserID pgtype.Text + WorkosMembershipID pgtype.Text + WorkosUpdatedAt pgtype.Timestamptz + WorkosLastEventID pgtype.Text + WorkosRoleSlug string +} + +func (q *Queries) UpsertOrganizationRoleAssignment(ctx context.Context, arg UpsertOrganizationRoleAssignmentParams) (int64, error) { + result, err := q.db.Exec(ctx, upsertOrganizationRoleAssignment, + arg.OrganizationID, + arg.WorkosUserID, + arg.UserID, + arg.WorkosMembershipID, + arg.WorkosUpdatedAt, + arg.WorkosLastEventID, + arg.WorkosRoleSlug, + ) + if err != nil { + return 0, err + } + return result.RowsAffected(), nil } const upsertPrincipalGrant = `-- name: UpsertPrincipalGrant :one diff --git a/server/internal/access/role_manager.go b/server/internal/access/role_manager.go new file mode 100644 index 0000000000..514d520044 --- /dev/null +++ b/server/internal/access/role_manager.go @@ -0,0 +1,950 @@ +package access + +import ( + "context" + "errors" + "fmt" + "log/slog" + "regexp" + "strings" + "time" + + "github.com/cenkalti/backoff/v5" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "go.opentelemetry.io/otel/trace" + + gen "github.com/speakeasy-api/gram/server/gen/access" + "github.com/speakeasy-api/gram/server/internal/access/repo" + "github.com/speakeasy-api/gram/server/internal/attr" + "github.com/speakeasy-api/gram/server/internal/audit" + "github.com/speakeasy-api/gram/server/internal/authz" + "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/o11y" + "github.com/speakeasy-api/gram/server/internal/oops" + "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" + "github.com/speakeasy-api/gram/server/internal/urn" + usersrepo "github.com/speakeasy-api/gram/server/internal/users/repo" +) + +var ErrRoleNotFound = errors.New("role not found") + +var validRoleNamePattern = regexp.MustCompile(`^[A-Za-z0-9 _-]+$`) + +const workOSSyncAttempts = 3 + +type RoleProvider interface { + CreateRole(ctx context.Context, orgID string, opts workos.CreateRoleOpts) (*workos.Role, error) + UpdateRole(ctx context.Context, orgID string, roleSlug string, opts workos.UpdateRoleOpts) (*workos.Role, error) + DeleteRole(ctx context.Context, orgID string, roleSlug string) error + UpdateMemberRole(ctx context.Context, membershipID string, roleSlug string) (*workos.Member, error) + GetOrgMembership(ctx context.Context, workOSUserID, workOSOrgID string) (*workos.Member, error) +} + +// RoleManager owns role reads and writes against local records, then syncs successful writes to WorkOS. +type RoleManager struct { + db *pgxpool.Pool + logger *slog.Logger + roles RoleProvider + audit *audit.Logger +} + +// NewRoleManager wires the role manager to the local DB, the WorkOS role client, and the audit logger. +func NewRoleManager(logger *slog.Logger, db *pgxpool.Pool, roles RoleProvider, auditLogger *audit.Logger) *RoleManager { + return &RoleManager{ + db: db, + logger: logger.With(attr.SlogComponent("access.role_manager")), + roles: roles, + audit: auditLogger, + } +} + +// ListRoles returns active roles for an organization from local records and enriches them with local grants and member counts. +func (r *RoleManager) ListRoles(ctx context.Context, gramOrgID string) (*gen.ListRolesResult, error) { + rows, err := repo.New(r.db).ListActiveOrganizationRoles(ctx, gramOrgID) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "list roles").Log(ctx, r.logger) + } + + roles := make([]*gen.Role, 0, len(rows)) + for _, row := range rows { + role, err := r.roleViewFromLocalRole(ctx, gramOrgID, localRole{ + ID: row.ID.String(), + PrincipalURN: row.RoleUrn, + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), + }) + if err != nil { + return nil, err + } + roles = append(roles, role) + } + + return &gen.ListRolesResult{Roles: roles}, nil +} + +// GetRoleByID returns one active role from the local role table with local grants and member count. +func (r *RoleManager) GetRoleByID(ctx context.Context, gramOrgID, id string) (*gen.Role, error) { + role, err := r.getLocalRoleByID(ctx, gramOrgID, id) + if err != nil { + return nil, err + } + + return r.roleViewFromLocalRole(ctx, gramOrgID, role) +} + +// ListMembers returns locally known organization members and includes a role ID only when a local assignment exists. +func (r *RoleManager) ListMembers(ctx context.Context, gramOrgID string) (*gen.ListMembersResult, error) { + rows, err := repo.New(r.db).ListAccessMembers(ctx, gramOrgID) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) + } + + result := make([]*gen.AccessMember, 0, len(rows)) + for _, row := range rows { + result = append(result, &gen.AccessMember{ + ID: row.ID, + Name: conv.Default(row.DisplayName, row.Email), + Email: row.Email, + PhotoURL: conv.FromPGText[string](row.PhotoUrl), + RoleID: row.RoleID, + JoinedAt: conv.FromPGTimestamptz(row.JoinedAt), + }) + } + + return &gen.ListMembersResult{Members: result}, nil +} + +type roleCreateResult struct { + Role *gen.Role + Slug string +} + +type workosSync func(context.Context) + +type accessAuditActor struct { + Principal urn.Principal + DisplayName *string +} + +// CreateRole creates the local role, grants, optional assignments, and audit entry atomically, then best-effort syncs WorkOS after commit. +func (r *RoleManager) CreateRole(ctx context.Context, gramOrgID, workosOrgID string, actor accessAuditActor, payload *gen.CreateRolePayload) (roleCreateResult, error) { + roleSlug, err := slugify(payload.Name) + if err != nil { + return roleCreateResult{}, err + } + + tx, err := r.db.Begin(ctx) + if err != nil { + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "begin role transaction").Log(ctx, r.logger) + } + defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) + + now := time.Now().UTC() + createdRow, err := repo.New(tx).UpsertOrganizationRole(ctx, repo.UpsertOrganizationRoleParams{ + OrganizationID: gramOrgID, + WorkosSlug: roleSlug, + WorkosName: payload.Name, + WorkosDescription: conv.ToPGTextEmpty(payload.Description), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + if err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role record").Log(ctx, r.logger) + } + createdRole := localRole{ + ID: createdRow.ID.String(), + PrincipalURN: createdRow.RoleUrn, + Name: createdRow.WorkosName, + Slug: createdRow.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](createdRow.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(createdRow.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(createdRow.WorkosUpdatedAt), + MemberCount: int(createdRow.MemberCount), + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleID(createdRole.ID)) + + if _, err := authz.SyncGrantsTx(ctx, tx, gramOrgID, roleSlug, createdRole.PrincipalURN, roleGrantPayloads(payload.Grants)); err != nil { + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "sync grants for created role").Log(ctx, r.logger) + } + + workosSyncs := []workosSync{func(ctx context.Context) { + r.syncWorkOS(ctx, "create role in workos", func() error { + _, err := r.roles.CreateRole(ctx, workosOrgID, workos.CreateRoleOpts{ + Name: payload.Name, + Slug: roleSlug, + Description: payload.Description, + }) + var apiErr *workos.APIError + if errors.As(err, &apiErr) && apiErr.StatusCode == 409 { + return nil + } + if err == nil { + return nil + } + return fmt.Errorf("create role in workos: %w", err) + }) + }} + + if len(payload.MemberIds) > 0 { + var memberSyncs []workosSync + if _, memberSyncs, err = r.assignMembersToRoleTx(ctx, tx, gramOrgID, roleSlug, payload.MemberIds); err != nil { + return roleCreateResult{}, err + } + workosSyncs = append(workosSyncs, memberSyncs...) + createdRole, err = r.getLocalRoleBySlugTx(ctx, tx, gramOrgID, roleSlug) + if err != nil { + return roleCreateResult{}, err + } + } + + if err := r.audit.LogAccessRoleCreate(ctx, tx, audit.LogAccessRoleCreateEvent{ + OrganizationID: gramOrgID, + Actor: actor.Principal, + ActorDisplayName: actor.DisplayName, + ActorSlug: nil, + RoleID: createdRole.ID, + RoleName: createdRole.Name, + RoleSlug: createdRole.Slug, + }); err != nil { + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "log access role creation").Log(ctx, r.logger) + } + + if err := tx.Commit(ctx); err != nil { + return roleCreateResult{}, oops.E(oops.CodeUnexpected, err, "commit role transaction").Log(ctx, r.logger) + } + + r.runWorkOSSyncs(ctx, workosSyncs) + + role, err := r.roleViewFromLocalRole(ctx, gramOrgID, createdRole) + if err != nil { + return roleCreateResult{}, err + } + + return roleCreateResult{Role: role, Slug: roleSlug}, nil +} + +type localRole struct { + ID string + PrincipalURN string + Name string + Slug string + Description string + CreatedAt string + UpdatedAt string + MemberCount int +} + +type roleUpdateResult struct { + Before *gen.Role + After *gen.Role + Role localRole +} + +// UpdateRole updates an existing local role, optional grants/assignments, and audit entry atomically, then best-effort syncs WorkOS after commit. +func (r *RoleManager) UpdateRole(ctx context.Context, gramOrgID, workosOrgID string, actor accessAuditActor, payload *gen.UpdateRolePayload) (roleUpdateResult, error) { + currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, payload.ID) + if err != nil { + return roleUpdateResult{}, err + } + existingRole, err := r.roleViewFromLocalRole(ctx, gramOrgID, currentRole) + if err != nil { + return roleUpdateResult{}, err + } + + sysRole := isSystemRole(currentRole.Slug) + if sysRole && (payload.Name != nil || payload.Description != nil || payload.Grants != nil) { + return roleUpdateResult{}, oops.E(oops.CodeBadRequest, nil, "system role properties cannot be updated, only member assignment is allowed").Log(ctx, r.logger) + } + if sysRole && payload.MemberIds == nil { + return roleUpdateResult{}, oops.E(oops.CodeBadRequest, nil, "system role update requires member_ids").Log(ctx, r.logger) + } + if payload.Name != nil { + if _, err := slugify(*payload.Name); err != nil { + return roleUpdateResult{}, err + } + } + + tx, err := r.db.Begin(ctx) + if err != nil { + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "begin role transaction").Log(ctx, r.logger) + } + defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) + + updatedRole := currentRole + var workosSyncs []workosSync + var updatedGrants []*gen.RoleGrant + if !sysRole { + localRecord := currentRole + if payload.Name != nil { + localRecord.Name = *payload.Name + } + if payload.Description != nil { + localRecord.Description = *payload.Description + } + localRecord.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + updatedRow, err := repo.New(tx).UpsertOrganizationRole(ctx, repo.UpsertOrganizationRoleParams{ + OrganizationID: gramOrgID, + WorkosSlug: localRecord.Slug, + WorkosName: localRecord.Name, + WorkosDescription: conv.ToPGTextEmpty(localRecord.Description), + WorkosCreatedAt: conv.ToPGTimestamptz(workosTimeOrNow(localRecord.CreatedAt)), + WorkosUpdatedAt: conv.ToPGTimestamptz(workosTimeOrNow(localRecord.UpdatedAt)), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + if err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "upsert local role record").Log(ctx, r.logger) + } + updatedRole = localRole{ + ID: updatedRow.ID.String(), + PrincipalURN: updatedRow.RoleUrn, + Name: updatedRow.WorkosName, + Slug: updatedRow.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](updatedRow.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(updatedRow.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(updatedRow.WorkosUpdatedAt), + MemberCount: int(updatedRow.MemberCount), + } + + if payload.Grants != nil { + syncedGrants, err := authz.SyncGrantsTx(ctx, tx, gramOrgID, currentRole.Slug, currentRole.PrincipalURN, roleGrantPayloads(payload.Grants)) + if err != nil { + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "sync grants for updated role").Log(ctx, r.logger) + } + updatedGrants = make([]*gen.RoleGrant, 0, len(syncedGrants)) + for _, grant := range syncedGrants { + updatedGrants = append(updatedGrants, scopedGrantToGenRoleGrant(grant)) + } + } + + workosSyncs = append(workosSyncs, func(ctx context.Context) { + r.syncWorkOS(ctx, "update role in workos", func() error { + _, err := r.roles.UpdateRole(ctx, workosOrgID, currentRole.Slug, workos.UpdateRoleOpts{ + Name: payload.Name, + Description: payload.Description, + }) + if err == nil { + return nil + } + return fmt.Errorf("update role in workos: %w", err) + }) + }) + } + + if payload.MemberIds != nil { + var memberSyncs []workosSync + if _, memberSyncs, err = r.assignMembersToRoleTx(ctx, tx, gramOrgID, currentRole.Slug, payload.MemberIds); err != nil { + return roleUpdateResult{}, err + } + workosSyncs = append(workosSyncs, memberSyncs...) + updatedRole, err = r.getLocalRoleByIDTx(ctx, tx, gramOrgID, payload.ID) + if err != nil { + return roleUpdateResult{}, err + } + } + + updatedRoleView := &gen.Role{ + ID: updatedRole.ID, + Name: updatedRole.Name, + Slug: updatedRole.Slug, + Description: updatedRole.Description, + IsSystem: isSystemRole(updatedRole.Slug), + Grants: existingRole.Grants, + MemberCount: updatedRole.MemberCount, + CreatedAt: conv.Default(updatedRole.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), + UpdatedAt: conv.Default(updatedRole.UpdatedAt, time.Time{}.UTC().Format(time.RFC3339)), + } + if updatedGrants != nil { + updatedRoleView.Grants = updatedGrants + } + + if err := r.audit.LogAccessRoleUpdate(ctx, tx, audit.LogAccessRoleUpdateEvent{ + OrganizationID: gramOrgID, + Actor: actor.Principal, + ActorDisplayName: actor.DisplayName, + ActorSlug: nil, + RoleID: updatedRole.ID, + RoleName: updatedRoleView.Name, + RoleSlug: updatedRole.Slug, + RoleSnapshotBefore: existingRole, + RoleSnapshotAfter: updatedRoleView, + }); err != nil { + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "log access role update").Log(ctx, r.logger) + } + + if err := tx.Commit(ctx); err != nil { + return roleUpdateResult{}, oops.E(oops.CodeUnexpected, err, "commit role transaction").Log(ctx, r.logger) + } + + r.runWorkOSSyncs(ctx, workosSyncs) + + return roleUpdateResult{Before: existingRole, After: updatedRoleView, Role: updatedRole}, nil +} + +// DeleteRole deletes a custom local role, reassignment records, grants, and audit entry atomically, then best-effort syncs WorkOS after commit. +func (r *RoleManager) DeleteRole(ctx context.Context, gramOrgID, workosOrgID, roleID string, actor accessAuditActor) (localRole, error) { + currentRole, err := r.getLocalRoleByID(ctx, gramOrgID, roleID) + if err != nil { + return localRole{}, err + } + if isSystemRole(currentRole.Slug) { + return localRole{}, oops.E(oops.CodeBadRequest, nil, "system roles cannot be deleted").Log(ctx, r.logger) + } + + tx, err := r.db.Begin(ctx) + if err != nil { + return localRole{}, oops.E(oops.CodeUnexpected, err, "begin role transaction").Log(ctx, r.logger) + } + defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) + + rows, err := repo.New(tx).ListOrganizationRoleAssignmentsBySlug(ctx, repo.ListOrganizationRoleAssignmentsBySlugParams{ + OrganizationID: gramOrgID, + WorkosRoleSlug: currentRole.Slug, + }) + if err != nil { + return localRole{}, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + + var workosSyncs []workosSync + for _, row := range rows { + membershipID := conv.FromPGTextOrEmpty[string](row.WorkosMembershipID) + if row.WorkosUserID != "" { + replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, repo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: gramOrgID, + WorkosUserID: row.WorkosUserID, + WorkosRoleSlug: authz.SystemRoleMember, + UserID: row.UserID, + WorkosMembershipID: conv.ToPGTextEmpty(membershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + if err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return localRole{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) + } + if replaced == 0 { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return localRole{}, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) + } + } + workosSyncs = append(workosSyncs, func(ctx context.Context) { + r.syncWorkOS(ctx, "reassign member to default role in workos", func() error { + _, err := r.roles.UpdateMemberRole(ctx, membershipID, authz.SystemRoleMember) + if err == nil { + return nil + } + return fmt.Errorf("reassign member to default role in workos: %w", err) + }) + }) + } + + deletedCount, err := repo.New(tx).MarkOrganizationRoleDeletedLocally(ctx, repo.MarkOrganizationRoleDeletedLocallyParams{ + OrganizationID: gramOrgID, + WorkosSlug: currentRole.Slug, + }) + if err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return localRole{}, oops.E(oops.CodeUnexpected, err, "mark local role record deleted").Log(ctx, r.logger) + } + if deletedCount == 0 { + return localRole{}, oops.E(oops.CodeNotFound, nil, "role not found").Log(ctx, r.logger) + } + + if err := authz.DeleteRoleGrants(ctx, repo.New(tx), gramOrgID, currentRole.Slug, currentRole.PrincipalURN); err != nil { + return localRole{}, oops.E(oops.CodeUnexpected, err, "delete grants for deleted role").Log(ctx, r.logger) + } + + if err := r.audit.LogAccessRoleDelete(ctx, tx, audit.LogAccessRoleDeleteEvent{ + OrganizationID: gramOrgID, + Actor: actor.Principal, + ActorDisplayName: actor.DisplayName, + ActorSlug: nil, + RoleID: currentRole.ID, + RoleName: currentRole.Name, + RoleSlug: currentRole.Slug, + }); err != nil { + return localRole{}, oops.E(oops.CodeUnexpected, err, "log access role deletion").Log(ctx, r.logger) + } + + if err := tx.Commit(ctx); err != nil { + return localRole{}, oops.E(oops.CodeUnexpected, err, "commit role transaction").Log(ctx, r.logger) + } + + workosSyncs = append(workosSyncs, func(ctx context.Context) { + r.syncWorkOS(ctx, "delete role in workos", func() error { + if err := r.roles.DeleteRole(ctx, workosOrgID, currentRole.Slug); err != nil { + return fmt.Errorf("delete role in workos: %w", err) + } + return nil + }) + }) + + r.runWorkOSSyncs(ctx, workosSyncs) + + return currentRole, nil +} + +type memberRoleUpdateContext struct { + RoleSlug string + MembershipID string + WorkosUserID string + UserID string + Before *gen.AccessMember + After *gen.AccessMember +} + +// UpdateMemberRole changes one member's local role assignment and audit entry atomically, then best-effort syncs WorkOS after commit. +func (r *RoleManager) UpdateMemberRole(ctx context.Context, gramOrgID, userID, roleID string, actor accessAuditActor) (memberRoleUpdateContext, error) { + tx, err := r.db.Begin(ctx) + if err != nil { + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "begin role transaction").Log(ctx, r.logger) + } + defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) + + role, err := r.getLocalRoleByIDTx(ctx, tx, gramOrgID, roleID) + if err != nil { + return memberRoleUpdateContext{}, err + } + + connectedUser, err := connectedUser(ctx, tx, gramOrgID, userID) + switch { + case errors.Is(err, errConnectedUserNotFound): + return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member has not joined this organization").Log(ctx, r.logger) + case err != nil: + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "load connected user").Log(ctx, r.logger) + } + if !connectedUser.WorkosID.Valid || connectedUser.WorkosID.String == "" { + return memberRoleUpdateContext{}, oops.E(oops.CodeBadRequest, nil, "member is not linked to WorkOS").Log(ctx, r.logger) + } + + existing, err := repo.New(tx).GetOrganizationRoleAssignmentByWorkosUser(ctx, repo.GetOrganizationRoleAssignmentByWorkosUserParams{ + OrganizationID: gramOrgID, + WorkosUserID: connectedUser.WorkosID.String, + }) + switch { + case errors.Is(err, pgx.ErrNoRows): + return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member not found").Log(ctx, r.logger) + case err != nil: + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "load member role assignment").Log(ctx, r.logger) + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + + membershipID := conv.FromPGTextOrEmpty[string](existing.WorkosMembershipID) + if membershipID == "" { + return memberRoleUpdateContext{}, oops.E(oops.CodeNotFound, nil, "member is missing local WorkOS membership linkage").Log(ctx, r.logger) + } + + if existing.WorkosUserID != "" && role.Slug != "" { + replaced, err := repo.New(tx).ReplaceOrganizationRoleAssignment(ctx, repo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: gramOrgID, + WorkosUserID: existing.WorkosUserID, + WorkosRoleSlug: role.Slug, + UserID: conv.ToPGTextEmpty(connectedUser.ID), + WorkosMembershipID: conv.ToPGTextEmpty(membershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + if err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) + } + if replaced == 0 { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) + } + } + + memberName := conv.Default(connectedUser.DisplayName, connectedUser.Email) + result := memberRoleUpdateContext{ + RoleSlug: role.Slug, + MembershipID: membershipID, + WorkosUserID: connectedUser.WorkosID.String, + UserID: connectedUser.ID, + Before: &gen.AccessMember{ + ID: connectedUser.ID, + Name: memberName, + Email: connectedUser.Email, + PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), + RoleID: existing.RoleID, + JoinedAt: conv.FromPGTimestamptz(existing.CreatedAt), + }, + After: &gen.AccessMember{ + ID: connectedUser.ID, + Name: memberName, + Email: connectedUser.Email, + PhotoURL: conv.FromPGText[string](connectedUser.PhotoUrl), + RoleID: role.ID, + JoinedAt: conv.FromPGTimestamptz(existing.CreatedAt), + }, + } + + if err := r.audit.LogAccessMemberRoleUpdate(ctx, tx, audit.LogAccessMemberRoleUpdateEvent{ + OrganizationID: gramOrgID, + Actor: actor.Principal, + ActorDisplayName: actor.DisplayName, + ActorSlug: nil, + MemberID: result.UserID, + MemberName: result.After.Name, + MemberEmail: result.After.Email, + MemberSnapshotBefore: result.Before, + MemberSnapshotAfter: result.After, + }); err != nil { + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "log access member role update").Log(ctx, r.logger) + } + + if err := tx.Commit(ctx); err != nil { + return memberRoleUpdateContext{}, oops.E(oops.CodeUnexpected, err, "commit role transaction").Log(ctx, r.logger) + } + + r.runWorkOSSyncs(ctx, []workosSync{ + func(ctx context.Context) { + r.syncWorkOS(ctx, "update member role in workos", func() error { + _, err := r.roles.UpdateMemberRole(ctx, membershipID, role.Slug) + if err == nil { + return nil + } + return fmt.Errorf("update member role in workos: %w", err) + }) + }, + }) + + return result, nil +} + +// MemberRolePrincipals returns role slug and principal URN for each role assigned to a WorkOS user in this org. +func (r *RoleManager) MemberRolePrincipals(ctx context.Context, gramOrgID, workosUserID string) ([]repo.ListMemberRolePrincipalsByWorkosUserRow, error) { + if workosUserID == "" { + return nil, nil + } + + rows, err := repo.New(r.db).ListMemberRolePrincipalsByWorkosUser(ctx, repo.ListMemberRolePrincipalsByWorkosUserParams{ + OrganizationID: gramOrgID, + WorkosUserID: workosUserID, + }) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "list member roles").Log(ctx, r.logger) + } + + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + return rows, nil +} + +// getLocalRoleByID loads one local role record by Gram role ID. +func (r *RoleManager) getLocalRoleByID(ctx context.Context, gramOrgID, id string) (localRole, error) { + return r.getLocalRoleByIDTx(ctx, r.db, gramOrgID, id) +} + +func (r *RoleManager) getLocalRoleByIDTx(ctx context.Context, dbtx repo.DBTX, gramOrgID, id string) (localRole, error) { + roleID, err := uuid.Parse(id) + if err != nil { + return localRole{}, oops.E(oops.CodeBadRequest, err, "invalid role ID").Log(ctx, r.logger) + } + + row, err := repo.New(dbtx).GetOrganizationRoleByID(ctx, repo.GetOrganizationRoleByIDParams{ + ID: roleID, + OrganizationID: gramOrgID, + }) + switch { + case errors.Is(err, pgx.ErrNoRows): + return localRole{}, oops.E(oops.CodeNotFound, ErrRoleNotFound, "role not found").Log(ctx, r.logger) + case err != nil: + return localRole{}, oops.E(oops.CodeUnexpected, err, "get role").Log(ctx, r.logger) + } + + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + return localRole{ + ID: row.ID.String(), + PrincipalURN: row.RoleUrn, + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), + }, nil +} + +func (r *RoleManager) getLocalRoleBySlugTx(ctx context.Context, dbtx repo.DBTX, gramOrgID, slug string) (localRole, error) { + row, err := repo.New(dbtx).GetActiveOrganizationRoleBySlug(ctx, repo.GetActiveOrganizationRoleBySlugParams{ + OrganizationID: gramOrgID, + WorkosSlug: slug, + }) + switch { + case errors.Is(err, pgx.ErrNoRows): + return localRole{}, oops.E(oops.CodeNotFound, ErrRoleNotFound, "role not found").Log(ctx, r.logger) + case err != nil: + return localRole{}, oops.E(oops.CodeUnexpected, err, "get role").Log(ctx, r.logger) + } + + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + return localRole{ + ID: row.ID.String(), + PrincipalURN: row.RoleUrn, + Name: row.WorkosName, + Slug: row.WorkosSlug, + Description: conv.FromPGTextOrEmpty[string](row.WorkosDescription), + CreatedAt: conv.FromPGTimestamptz(row.WorkosCreatedAt), + UpdatedAt: conv.FromPGTimestamptz(row.WorkosUpdatedAt), + MemberCount: int(row.MemberCount), + }, nil +} + +type memberAssignmentTarget struct { + UserID string + WorkosUserID string + MembershipID string +} + +// requestedMemberAssignment groups input IDs that resolve to the same WorkOS user. +type requestedMemberAssignment struct { + InputIDs []string + UserID string +} + +func (r *RoleManager) memberAssignmentTargetsTx(ctx context.Context, dbtx repo.DBTX, gramOrgID string, memberIDs []string) ([]memberAssignmentTarget, error) { + if len(memberIDs) == 0 { + return nil, nil + } + + users, err := usersrepo.New(dbtx).GetUsersByIDs(ctx, memberIDs) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "resolve users by ids").Log(ctx, r.logger) + } + usersByID := make(map[string]usersrepo.User, len(users)) + for _, user := range users { + usersByID[user.ID] = user + } + + requestedInputs := make(map[string]struct{}, len(memberIDs)) + requestedByWorkosID := make(map[string]requestedMemberAssignment, len(memberIDs)) + workosIDs := make([]string, 0, len(memberIDs)) + for _, id := range memberIDs { + if _, ok := requestedInputs[id]; ok { + continue + } + requestedInputs[id] = struct{}{} + + workosID := id + userID := "" + if user, ok := usersByID[id]; ok { + userID = user.ID + if !user.WorkosID.Valid || user.WorkosID.String == "" { + return nil, oops.E(oops.CodeBadRequest, nil, "member %s is not linked to WorkOS", id).Log(ctx, r.logger) + } + workosID = user.WorkosID.String + } + if workosID == "" { + continue + } + requested, ok := requestedByWorkosID[workosID] + if ok { + requested.InputIDs = append(requested.InputIDs, id) + if requested.UserID == "" { + requested.UserID = userID + } + requestedByWorkosID[workosID] = requested + continue + } + requestedByWorkosID[workosID] = requestedMemberAssignment{ + InputIDs: []string{id}, + UserID: userID, + } + workosIDs = append(workosIDs, workosID) + } + + assignmentRows, err := repo.New(dbtx).ListOrganizationRoleAssignmentsByWorkosUsers(ctx, repo.ListOrganizationRoleAssignmentsByWorkosUsersParams{ + OrganizationID: gramOrgID, + WorkosUserIds: workosIDs, + }) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "list members").Log(ctx, r.logger) + } + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleSource("db")) + targets := make([]memberAssignmentTarget, 0, len(memberIDs)) + resolvedWorkosIDs := make(map[string]struct{}, len(requestedByWorkosID)) + resolvedInputs := make(map[string]struct{}, len(requestedInputs)) + for _, row := range assignmentRows { + requested, ok := requestedByWorkosID[row.WorkosUserID] + if !ok { + continue + } + if _, ok := resolvedWorkosIDs[row.WorkosUserID]; ok { + continue + } + userID := conv.FromPGTextOrEmpty[string](row.UserID) + if userID == "" { + userID = requested.UserID + } + resolvedWorkosIDs[row.WorkosUserID] = struct{}{} + for _, inputID := range requested.InputIDs { + resolvedInputs[inputID] = struct{}{} + } + targets = append(targets, memberAssignmentTarget{ + UserID: userID, + WorkosUserID: row.WorkosUserID, + MembershipID: conv.FromPGTextOrEmpty[string](row.WorkosMembershipID), + }) + } + if len(resolvedInputs) != len(requestedInputs) { + return nil, oops.E(oops.CodeBadRequest, nil, "member is missing local WorkOS membership linkage").Log(ctx, r.logger) + } + + return targets, nil +} + +func (r *RoleManager) assignMembersToRoleTx(ctx context.Context, dbtx repo.DBTX, gramOrgID, roleSlug string, memberIDs []string) (int, []workosSync, error) { + targets, err := r.memberAssignmentTargetsTx(ctx, dbtx, gramOrgID, memberIDs) + if err != nil { + return 0, nil, err + } + + assignedCount := 0 + workosSyncs := make([]workosSync, 0, len(targets)) + for _, target := range targets { + if target.WorkosUserID != "" && roleSlug != "" { + replaced, err := repo.New(dbtx).ReplaceOrganizationRoleAssignment(ctx, repo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: gramOrgID, + WorkosUserID: target.WorkosUserID, + WorkosRoleSlug: roleSlug, + UserID: conv.ToPGTextEmpty(target.UserID), + WorkosMembershipID: conv.ToPGTextEmpty(target.MembershipID), + WorkosUpdatedAt: conv.ToPGTimestamptz(time.Now().UTC()), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + if err != nil { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return 0, nil, oops.E(oops.CodeUnexpected, err, "upsert local role assignment record").Log(ctx, r.logger) + } + if replaced == 0 { + trace.SpanFromContext(ctx).SetAttributes(attr.AccessRoleDBWriteFailed(true)) + return 0, nil, oops.E(oops.CodeUnexpected, nil, "upsert local role assignment record").Log(ctx, r.logger) + } + } + assignedCount++ + membershipID := target.MembershipID + workosSyncs = append(workosSyncs, func(ctx context.Context) { + r.syncWorkOS(ctx, "assign member to role in workos", func() error { + _, err := r.roles.UpdateMemberRole(ctx, membershipID, roleSlug) + if err == nil { + return nil + } + return fmt.Errorf("assign member to role in workos: %w", err) + }) + }) + } + return assignedCount, workosSyncs, nil +} + +// runWorkOSSyncs starts best-effort WorkOS writes after the local transaction commits. +func (r *RoleManager) runWorkOSSyncs(ctx context.Context, syncs []workosSync) { + if len(syncs) == 0 { + return + } + syncCtx := context.WithoutCancel(ctx) + go func() { + defer func() { + if recovered := recover(); recovered != nil { + r.logger.ErrorContext(syncCtx, "workos sync panic", attr.SlogError(fmt.Errorf("%v", recovered))) + } + }() + for _, syncWorkOS := range syncs { + syncWorkOS(syncCtx) + } + }() +} + +// syncWorkOS runs a bounded best-effort WorkOS write after the local database already accepted the change. +func (r *RoleManager) syncWorkOS(ctx context.Context, operation string, fn func() error) { + exp := backoff.NewExponentialBackOff() + exp.InitialInterval = 100 * time.Millisecond + exp.MaxInterval = 300 * time.Millisecond + exp.RandomizationFactor = 0 + + _, err := backoff.Retry(ctx, func() (struct{}, error) { + err := fn() + if err == nil { + return struct{}{}, nil + } + if !retryWorkOSError(err) { + return struct{}{}, backoff.Permanent(err) + } + return struct{}{}, err + }, backoff.WithBackOff(exp), backoff.WithMaxTries(workOSSyncAttempts)) + if err == nil { + return + } + + r.logger.ErrorContext(ctx, "workos sync failed: "+operation, attr.SlogError(err)) +} + +// retryWorkOSError reports whether a WorkOS sync failure is worth retrying in-process. +func retryWorkOSError(err error) bool { + var apiErr *workos.APIError + if !errors.As(err, &apiErr) { + return true + } + return apiErr.StatusCode == 429 || apiErr.StatusCode >= 500 +} + +// roleViewFromLocalRole converts a local role record into the public API role view and attaches local grants. +func (r *RoleManager) roleViewFromLocalRole(ctx context.Context, organizationID string, role localRole) (*gen.Role, error) { + // Role grant reads intentionally include both canonical role URNs and + // legacy role slugs until old principal_grants rows are backfilled. + grants, err := authz.GrantsForRole(ctx, r.logger, r.db, organizationID, role.Slug, role.PrincipalURN) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "load role grants").Log(ctx, r.logger) + } + genGrants := make([]*gen.RoleGrant, 0, len(grants)) + for _, g := range grants { + genGrants = append(genGrants, scopedGrantToGenRoleGrant(g)) + } + + return &gen.Role{ + ID: role.ID, + Name: role.Name, + Slug: role.Slug, + Description: role.Description, + IsSystem: isSystemRole(role.Slug), + Grants: genGrants, + MemberCount: role.MemberCount, + CreatedAt: conv.Default(role.CreatedAt, time.Time{}.UTC().Format(time.RFC3339)), + UpdatedAt: conv.Default(role.UpdatedAt, time.Time{}.UTC().Format(time.RFC3339)), + }, nil +} + +// workosTimeOrNow parses a WorkOS RFC3339 timestamp or returns the current UTC time when WorkOS omits or malforms it. +func workosTimeOrNow(value string) time.Time { + if value == "" { + return time.Now().UTC() + } + t, err := time.Parse(time.RFC3339, value) + if err != nil { + return time.Now().UTC() + } + return t.UTC() +} + +// slugify validates a role name and turns it into Gram's WorkOS role slug format. +func slugify(name string) (string, error) { + slug := conv.ToSlug(strings.ReplaceAll(name, "_", " ")) + if slug == "" { + return "", oops.E(oops.CodeBadRequest, nil, "role name must contain at least one letter or digit") + } + if !validRoleNamePattern.MatchString(name) { + return "", oops.E(oops.CodeBadRequest, nil, "role name contains invalid characters") + } + if !strings.HasPrefix(slug, "org-") { + slug = "org-" + slug + } + + return slug, nil +} diff --git a/server/internal/access/role_manager_test.go b/server/internal/access/role_manager_test.go new file mode 100644 index 0000000000..1dcc30dd3f --- /dev/null +++ b/server/internal/access/role_manager_test.go @@ -0,0 +1,213 @@ +package access + +import ( + "testing" + "time" + + mockidp "github.com/speakeasy-api/gram/dev-idp/pkg/testidp" + "github.com/stretchr/testify/require" + + accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" + "github.com/speakeasy-api/gram/server/internal/contextvalues" + "github.com/speakeasy-api/gram/server/internal/conv" +) + +func TestRoleManager_ListRoles(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + adminID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + customID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) + + roles, err := ti.service.roleMgr.ListRoles(ctx, authCtx.ActiveOrganizationID) + require.NoError(t, err) + require.Len(t, roles.Roles, 2) + + bySlug := map[string]string{} + for _, role := range roles.Roles { + if role.Name == "Admin" { + bySlug["admin"] = role.ID + } + if role.Name == "Custom Builder" { + bySlug["custom-builder"] = role.ID + } + } + require.Equal(t, adminID, bySlug["admin"]) + require.Equal(t, customID, bySlug["custom-builder"]) +} + +func TestRoleManager_GetRoleByID(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + customID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) + + role, err := ti.service.roleMgr.GetRoleByID(ctx, authCtx.ActiveOrganizationID, customID) + require.NoError(t, err) + require.Equal(t, customID, role.ID) + require.Equal(t, "Custom Builder", role.Name) + + _, err = ti.service.roleMgr.GetRoleByID(ctx, authCtx.ActiveOrganizationID, "not-a-uuid") + require.Error(t, err) +} + +func TestRoleManager_MembersAndCounts(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "u1@example.com", "User 1", "user_1", "membership_1") + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "u2@example.com", "User 2", "user_2", "membership_2") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_3", "user_3", "custom-builder")) + + manager := ti.service.roleMgr + members, err := manager.ListMembers(ctx, authCtx.ActiveOrganizationID) + require.NoError(t, err) + require.Len(t, members.Members, 3) + + rolePrincipals, err := manager.MemberRolePrincipals(ctx, authCtx.ActiveOrganizationID, "user_2") + require.NoError(t, err) + slugs := make([]string, 0, len(rolePrincipals)) + for _, role := range rolePrincipals { + slugs = append(slugs, role.RoleSlug) + } + require.Equal(t, []string{"custom-builder"}, slugs) + + roles, err := manager.ListRoles(ctx, authCtx.ActiveOrganizationID) + require.NoError(t, err) + counts := make(map[string]int, len(roles.Roles)) + for _, role := range roles.Roles { + counts[role.Name] = role.MemberCount + } + require.Equal(t, 1, counts["Admin"]) + require.Equal(t, 1, counts["Custom Builder"]) +} + +func TestRoleManager_AssignMembersToRoleAcceptsConnectedMemberWithoutAssignment(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "u1@example.com", "User 1", "user_1", "membership_1") + + assigned, _, err := ti.service.roleMgr.assignMembersToRoleTx(ctx, ti.conn, authCtx.ActiveOrganizationID, "custom-builder", []string{"local_user_1"}) + require.NoError(t, err) + require.Equal(t, 1, assigned) +} + +func TestRoleManager_LocalRoleWritePreservesWorkOSLastEventID(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + now := time.Now().UTC() + _, err := accessrepo.New(ti.conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosSlug: "custom-builder", + WorkosName: "Custom Builder", + WorkosDescription: conv.ToPGTextEmpty("Before"), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGText("event_01SEED"), + }) + require.NoError(t, err) + + _, err = accessrepo.New(ti.conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosSlug: "custom-builder", + WorkosName: "Custom Builder", + WorkosDescription: conv.ToPGTextEmpty("After"), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now.Add(time.Minute)), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + + row, err := accessrepo.New(ti.conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosSlug: "custom-builder", + }) + require.NoError(t, err) + require.Equal(t, "event_01SEED", row.WorkosLastEventID.String) + + replaced, err := accessrepo.New(ti.conn).ReplaceOrganizationRoleAssignment(ctx, accessrepo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosUserID: "user_1", + WorkosRoleSlug: "custom-builder", + UserID: conv.ToPGTextEmpty(""), + WorkosMembershipID: conv.ToPGTextEmpty("membership_1"), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGText("event_02SEED"), + }) + require.NoError(t, err) + require.Equal(t, int64(1), replaced) + + replaced, err = accessrepo.New(ti.conn).ReplaceOrganizationRoleAssignment(ctx, accessrepo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosUserID: "user_1", + WorkosRoleSlug: "custom-builder", + UserID: conv.ToPGTextEmpty(""), + WorkosMembershipID: conv.ToPGTextEmpty("membership_1"), + WorkosUpdatedAt: conv.ToPGTimestamptz(now.Add(time.Minute)), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + require.Equal(t, int64(1), replaced) + + assignments, err := accessrepo.New(ti.conn).ListOrganizationRoleAssignmentRecordsByWorkosUser(ctx, accessrepo.ListOrganizationRoleAssignmentRecordsByWorkosUserParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosUserID: "user_1", + }) + require.NoError(t, err) + require.Len(t, assignments, 1) + require.False(t, assignments[0].DeletedAt.Valid) + require.Equal(t, "event_02SEED", assignments[0].WorkosLastEventID.String) +} + +func TestRoleManager_ReplaceRoleAssignmentSoftDeletesPreviousRole(t *testing.T) { + t.Parallel() + + ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", "member")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Can build")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "u1@example.com", "User 1", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "member")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) + + assignments, err := accessrepo.New(ti.conn).ListOrganizationRoleAssignmentRecordsByWorkosUser(ctx, accessrepo.ListOrganizationRoleAssignmentRecordsByWorkosUserParams{ + OrganizationID: authCtx.ActiveOrganizationID, + WorkosUserID: "user_1", + }) + require.NoError(t, err) + activeCount := 0 + deletedCount := 0 + for _, assignment := range assignments { + if assignment.DeletedAt.Valid { + deletedCount++ + continue + } + activeCount++ + } + require.Equal(t, 1, activeCount) + require.Equal(t, 1, deletedCount) +} diff --git a/server/internal/access/setup_internal_test.go b/server/internal/access/setup_internal_test.go index 51f01742e6..57795a36b4 100644 --- a/server/internal/access/setup_internal_test.go +++ b/server/internal/access/setup_internal_test.go @@ -3,6 +3,7 @@ package access import ( "context" "testing" + "time" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" @@ -30,7 +31,7 @@ func newInternalTestService(t *testing.T) (context.Context, *Service, *pgxpool.P conn, err := res.CloneTestDatabase(t, "testdb") require.NoError(t, err) - return ctx, &Service{tracer: nil, logger: logger, db: conn, chConn: nil, auth: nil, authz: nil, roles: nil, featureCache: nil}, conn + return ctx, &Service{tracer: nil, logger: logger, db: conn, chConn: nil, auth: nil, authz: nil, roleMgr: nil, featureCache: nil}, conn } func seedInternalOrganization(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID string) { @@ -59,3 +60,23 @@ func seedInternalGrant(t *testing.T, ctx context.Context, conn *pgxpool.Pool, or }) require.NoError(t, err) } + +func seedInternalRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID string, roleSlug string) urn.Principal { + t.Helper() + + now := time.Now().UTC() + row, err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + OrganizationID: organizationID, + WorkosSlug: roleSlug, + WorkosName: roleSlug, + WorkosDescription: conv.ToPGTextEmpty(""), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + + principal, err := urn.ParsePrincipal(row.RoleUrn) + require.NoError(t, err) + return principal +} diff --git a/server/internal/access/setup_test.go b/server/internal/access/setup_test.go index 3d533b324d..b020579d61 100644 --- a/server/internal/access/setup_test.go +++ b/server/internal/access/setup_test.go @@ -5,6 +5,7 @@ import ( "log" "os" "testing" + "time" "github.com/ClickHouse/clickhouse-go/v2" "github.com/jackc/pgx/v5/pgxpool" @@ -90,7 +91,8 @@ func newTestAccessService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := NewService(logger, tracerProvider, conn, chConn, sessionManager, roles, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), noopFeatureCacheWriter{}, auditLogger) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) + svc := NewService(logger, tracerProvider, conn, chConn, sessionManager, NewRoleManager(logger, conn, roles, auditLogger), authzEngine, noopFeatureCacheWriter{}, auditLogger) return ctx, &testInstance{ service: svc, @@ -139,6 +141,85 @@ func listPrincipalGrants(t *testing.T, ctx context.Context, conn *pgxpool.Pool, return grants } +func seedRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID string, role workos.Role) string { + t.Helper() + + createdAt, err := time.Parse(time.RFC3339, role.CreatedAt) + require.NoError(t, err) + updatedAt, err := time.Parse(time.RFC3339, role.UpdatedAt) + require.NoError(t, err) + + _, err = accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + OrganizationID: organizationID, + WorkosSlug: role.Slug, + WorkosName: role.Name, + WorkosDescription: conv.ToPGTextEmpty(role.Description), + WorkosCreatedAt: conv.ToPGTimestamptz(createdAt), + WorkosUpdatedAt: conv.ToPGTimestamptz(updatedAt), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + + row, err := accessrepo.New(conn).GetOrganizationRoleBySlug(ctx, accessrepo.GetOrganizationRoleBySlugParams{ + OrganizationID: organizationID, + WorkosSlug: role.Slug, + }) + require.NoError(t, err) + + return row.ID.String() +} + +func seedGlobalRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, role workos.Role) string { + t.Helper() + + createdAt, err := time.Parse(time.RFC3339, role.CreatedAt) + require.NoError(t, err) + updatedAt, err := time.Parse(time.RFC3339, role.UpdatedAt) + require.NoError(t, err) + + err = accessrepo.New(conn).UpsertGlobalRole(ctx, accessrepo.UpsertGlobalRoleParams{ + WorkosSlug: role.Slug, + WorkosName: role.Name, + WorkosDescription: conv.ToPGTextEmpty(role.Description), + WorkosCreatedAt: conv.ToPGTimestamptz(createdAt), + WorkosUpdatedAt: conv.ToPGTimestamptz(updatedAt), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + + row, err := accessrepo.New(conn).GetGlobalRoleBySlug(ctx, role.Slug) + require.NoError(t, err) + + return row.ID.String() +} + +func seedRoleAssignment(t *testing.T, ctx context.Context, conn *pgxpool.Pool, organizationID, userID string, member workos.Member) { + t.Helper() + + updatedAt := time.Now().UTC() + if member.UpdatedAt != "" { + parsed, err := time.Parse(time.RFC3339, member.UpdatedAt) + require.NoError(t, err) + updatedAt = parsed + } else if member.CreatedAt != "" { + parsed, err := time.Parse(time.RFC3339, member.CreatedAt) + require.NoError(t, err) + updatedAt = parsed + } + + replaced, err := accessrepo.New(conn).ReplaceOrganizationRoleAssignment(ctx, accessrepo.ReplaceOrganizationRoleAssignmentParams{ + OrganizationID: organizationID, + WorkosUserID: member.UserID, + WorkosRoleSlug: member.RoleSlug, + UserID: conv.ToPGTextEmpty(userID), + WorkosMembershipID: conv.ToPGTextEmpty(member.ID), + WorkosUpdatedAt: conv.ToPGTimestamptz(updatedAt), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }) + require.NoError(t, err) + require.Equal(t, int64(1), replaced) +} + // seedDisconnectedUser creates a user in the users table with a workos_id but // does NOT insert into organization_user_relationships, simulating a WorkOS // user who hasn't been connected to the Gram org. diff --git a/server/internal/access/syncgrants_test.go b/server/internal/access/syncgrants_test.go index 1af0707124..d6ac5b01ec 100644 --- a/server/internal/access/syncgrants_test.go +++ b/server/internal/access/syncgrants_test.go @@ -16,14 +16,15 @@ func TestService_syncGrants_replacesRoleGrants(t *testing.T) { ctx, svc, conn := newInternalTestService(t) organizationID := "org_sync_grants_replace" roleSlug := "custom-editor" - rolePrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) - seedInternalOrganization(t, ctx, conn, organizationID) - seedInternalGrant(t, ctx, conn, organizationID, rolePrincipal, string(authz.ScopeProjectRead), "project-old") - seedInternalGrant(t, ctx, conn, organizationID, rolePrincipal, string(authz.ScopeProjectWrite), "project-stale") + rolePrincipal := seedInternalRole(t, ctx, conn, organizationID, roleSlug) + legacyRolePrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) + + seedInternalGrant(t, ctx, conn, organizationID, legacyRolePrincipal, string(authz.ScopeProjectRead), "project-old") + seedInternalGrant(t, ctx, conn, organizationID, legacyRolePrincipal, string(authz.ScopeProjectWrite), "project-stale") seedInternalGrant(t, ctx, conn, organizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "other-role"), string(authz.ScopeProjectRead), "project-other") - err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, []*authz.RoleGrant{ + err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, rolePrincipal.String(), []*authz.RoleGrant{ { Scope: string(authz.ScopeProjectRead), Selectors: nil, @@ -61,6 +62,12 @@ func TestService_syncGrants_replacesRoleGrants(t *testing.T) { string(authz.ScopeMCPConnect) + "|tool:analytics", string(authz.ScopeMCPConnect) + "|tool:payments", }, got) + legacyRows, err := accessrepo.New(conn).ListPrincipalGrantsByOrg(ctx, accessrepo.ListPrincipalGrantsByOrgParams{ + OrganizationID: organizationID, + PrincipalUrn: legacyRolePrincipal.String(), + }) + require.NoError(t, err) + require.Empty(t, legacyRows) otherRows, err := accessrepo.New(conn).ListPrincipalGrantsByOrg(ctx, accessrepo.ListPrincipalGrantsByOrgParams{ OrganizationID: organizationID, @@ -79,12 +86,11 @@ func TestService_syncGrants_emptySelectorsCreatesNoGrant(t *testing.T) { ctx, svc, conn := newInternalTestService(t) organizationID := "org_sync_grants_empty_sel" roleSlug := "custom-empty-sel" - rolePrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) - seedInternalOrganization(t, ctx, conn, organizationID) + rolePrincipal := seedInternalRole(t, ctx, conn, organizationID, roleSlug) // Empty non-nil selectors = no access (not wildcard). - err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, []*authz.RoleGrant{ + err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, rolePrincipal.String(), []*authz.RoleGrant{ { Scope: string(authz.ScopeMCPConnect), Selectors: []authz.Selector{}, @@ -106,13 +112,12 @@ func TestService_syncGrants_clearsRoleGrantsWhenEmpty(t *testing.T) { ctx, svc, conn := newInternalTestService(t) organizationID := "org_sync_grants_clear" roleSlug := "custom-viewer" - rolePrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) - seedInternalOrganization(t, ctx, conn, organizationID) + rolePrincipal := seedInternalRole(t, ctx, conn, organizationID, roleSlug) seedInternalGrant(t, ctx, conn, organizationID, rolePrincipal, string(authz.ScopeProjectRead), authz.WildcardResource) seedInternalGrant(t, ctx, conn, organizationID, rolePrincipal, string(authz.ScopeMCPRead), "tool:payments") - err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, nil) + err := authz.SyncGrants(ctx, svc.logger, conn, organizationID, roleSlug, rolePrincipal.String(), nil) require.NoError(t, err) rows, err := accessrepo.New(conn).ListPrincipalGrantsByOrg(ctx, accessrepo.ListPrincipalGrantsByOrgParams{ diff --git a/server/internal/access/updatememberrole_test.go b/server/internal/access/updatememberrole_test.go index 7b3f1df054..c18119d41a 100644 --- a/server/internal/access/updatememberrole_test.go +++ b/server/internal/access/updatememberrole_test.go @@ -23,13 +23,10 @@ func TestService_UpdateMemberRole(t *testing.T) { require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -37,28 +34,23 @@ func TestService_UpdateMemberRole(t *testing.T) { RoleSlug: "custom-builder", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - }, nil).Once() - seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_builder"}) + member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.NoError(t, err) require.Equal(t, "local_user_1", member.ID) require.Equal(t, "Ada Lovelace", member.Name) require.Equal(t, "ada@example.com", member.Email) - require.Equal(t, "role_builder", member.RoleID) + require.Equal(t, builderID, member.RoleID) require.Nil(t, member.PhotoURL) - require.Equal(t, mockMembershipTimestamp, member.JoinedAt) + require.NotEmpty(t, member.JoinedAt) } func TestService_UpdateMemberRole_RoleNotFound(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{}, nil).Once() - _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_missing"}) + _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "00000000-0000-0000-0000-000000000001"}) require.Error(t, err) require.Contains(t, err.Error(), "role not found") } @@ -70,12 +62,10 @@ func TestService_UpdateMemberRole_MemberNotFound(t *testing.T) { authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "user_missing", RoleID: "role_builder"}) + _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "user_missing", RoleID: builderID}) require.Error(t, err) require.Contains(t, err.Error(), "member has not joined this organization") } @@ -87,15 +77,12 @@ func TestService_UpdateMemberRole_WorkOSMembershipNotFound(t *testing.T) { authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() - seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "") - _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_builder"}) + _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.Error(t, err) - require.Contains(t, err.Error(), "member not found") + require.Contains(t, err.Error(), "member is missing local WorkOS membership linkage") } func TestService_UpdateMemberRole_WorkOSFailure(t *testing.T) { @@ -105,18 +92,15 @@ func TestService_UpdateMemberRole_WorkOSFailure(t *testing.T) { authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) require.NotNil(t, authCtx) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() - ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Once() + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) + ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return((*thirdpartyworkos.Member)(nil), errors.New("workos unavailable")).Times(3) - _, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_builder"}) - require.Error(t, err) - require.Contains(t, err.Error(), "update member role in workos") + member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) + require.NoError(t, err) + require.Equal(t, builderID, member.RoleID) } func TestService_UpdateMemberRole_AuditLog(t *testing.T) { @@ -129,13 +113,10 @@ func TestService_UpdateMemberRole_AuditLog(t *testing.T) { beforeCount, err := audittest.AuditLogCountByAction(ctx, ti.conn, audit.ActionAccessMemberRoleUpdate) require.NoError(t, err) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockSystemRole("role_admin", "Admin", "admin"), - mockRole("role_builder", "Builder", "custom-builder", ""), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() + adminID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + builderID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_builder", "Builder", "custom-builder", "")) + seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember("", "membership_1", "user_1", "admin")) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -143,12 +124,8 @@ func TestService_UpdateMemberRole_AuditLog(t *testing.T) { RoleSlug: "custom-builder", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListOrgUsers", mock.Anything, mockidp.MockOrgID).Return(map[string]thirdpartyworkos.User{ - "user_1": mockUser("user_1", "Ada", "Lovelace", "ada@example.com"), - }, nil).Once() - seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "ada@example.com", "Ada Lovelace", "user_1", "membership_1") - member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: "role_builder"}) + member, err := ti.service.UpdateMemberRole(ctx, &gen.UpdateMemberRolePayload{UserID: "local_user_1", RoleID: builderID}) require.NoError(t, err) require.NotNil(t, member) @@ -165,8 +142,8 @@ func TestService_UpdateMemberRole_AuditLog(t *testing.T) { require.NoError(t, err) afterSnapshot, err := audittest.DecodeAuditData(record.AfterSnapshot) require.NoError(t, err) - require.Equal(t, "role_admin", beforeSnapshot["RoleID"]) - require.Equal(t, "role_builder", afterSnapshot["RoleID"]) + require.Equal(t, adminID, beforeSnapshot["RoleID"]) + require.Equal(t, builderID, afterSnapshot["RoleID"]) afterCount, err := audittest.AuditLogCountByAction(ctx, ti.conn, audit.ActionAccessMemberRoleUpdate) require.NoError(t, err) diff --git a/server/internal/access/updaterole_test.go b/server/internal/access/updaterole_test.go index a8f71aa7b2..523b7365d6 100644 --- a/server/internal/access/updaterole_test.go +++ b/server/internal/access/updaterole_test.go @@ -28,9 +28,8 @@ func TestService_UpdateRole(t *testing.T) { name := "Platform Builder" description := "Updated description" - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{ Name: &name, Description: &description, @@ -42,11 +41,6 @@ func TestService_UpdateRole(t *testing.T) { CreatedAt: mockRoleTimestamp, UpdatedAt: mockRoleTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "member"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "member"), - mockMember(mockidp.MockOrgID, "membership_3", "user_3", "custom-builder"), - }, nil).Once() ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "custom-builder").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -61,20 +55,17 @@ func TestService_UpdateRole(t *testing.T) { RoleSlug: "custom-builder", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_3", "user_3", "custom-builder"), - // user_workos_only has never logged into Gram — should not be counted - mockMember(mockidp.MockOrgID, "membership_workos_only", "user_workos_only", "custom-builder"), - }, nil).Once() seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", "user3@test.com", "User 3", "user_3", "membership_3") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember(mockidp.MockOrgID, "membership_2", "user_2", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_3", mockMember(mockidp.MockOrgID, "membership_3", "user_3", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "", mockMember(mockidp.MockOrgID, "membership_workos_only", "user_workos_only", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-old") role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_custom", + ID: roleID, Name: &name, Description: &description, Grants: []*gen.RoleGrant{ @@ -84,16 +75,17 @@ func TestService_UpdateRole(t *testing.T) { MemberIds: []string{"local_user_1", "local_user_2"}, }) require.NoError(t, err) - require.Equal(t, "role_custom", role.ID) + require.Equal(t, roleID, role.ID) require.Equal(t, name, role.Name) require.Equal(t, description, role.Description) require.False(t, role.IsSystem) require.Equal(t, 3, role.MemberCount) require.Equal(t, mockRoleTimestamp, role.CreatedAt) - require.Equal(t, mockRoleTimestamp, role.UpdatedAt) + require.NotEmpty(t, role.UpdatedAt) + require.NotEqual(t, mockRoleTimestamp, role.UpdatedAt) require.Len(t, role.Grants, 2) - grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder")) + grants := listPrincipalGrants(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "organization:"+roleID)) require.Len(t, grants, 3) } @@ -106,13 +98,8 @@ func TestService_UpdateRole_SystemRole_MemberAssignment(t *testing.T) { require.NotNil(t, authCtx) // admin and member are system roles — WorkOS UpdateRole must NOT be called. - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_admin", "Admin", "admin", "Full access"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "member"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "member"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "admin").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -120,20 +107,18 @@ func TestService_UpdateRole_SystemRole_MemberAssignment(t *testing.T) { RoleSlug: "admin", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "member"), - }, nil).Once() seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", authz.SystemRoleMember)) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember(mockidp.MockOrgID, "membership_2", "user_2", authz.SystemRoleMember)) role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_admin", + ID: roleID, MemberIds: []string{"local_user_1"}, }) require.NoError(t, err) - require.Equal(t, "role_admin", role.ID) + require.Equal(t, roleID, role.ID) require.True(t, role.IsSystem) require.Equal(t, 1, role.MemberCount) @@ -145,14 +130,14 @@ func TestService_UpdateRole_SystemRole_RejectsPropertyChanges(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_member", "Member", "member", "Default role"), - }, nil) + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) name := "Custom Name" _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_member", + ID: roleID, Name: &name, }) require.Error(t, err) @@ -160,14 +145,14 @@ func TestService_UpdateRole_SystemRole_RejectsPropertyChanges(t *testing.T) { description := "Custom description" _, err = ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_member", + ID: roleID, Description: &description, }) require.Error(t, err) require.Contains(t, err.Error(), "system role properties cannot be updated") _, err = ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_member", + ID: roleID, Grants: []*gen.RoleGrant{{Scope: string(authz.ScopeProjectRead)}}, }) require.Error(t, err) @@ -178,13 +163,13 @@ func TestService_UpdateRole_SystemRole_RejectsNoopUpdate(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_admin", "Admin", "admin", "Full access"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_admin", + ID: roleID, }) require.Error(t, err) require.Contains(t, err.Error(), "system role update requires member_ids") @@ -200,12 +185,8 @@ func TestService_UpdateRole_SystemRole_AuditLog(t *testing.T) { beforeCount, err := audittest.AuditLogCountByAction(ctx, ti.conn, audit.ActionAccessRoleUpdate) require.NoError(t, err) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_admin", "Admin", "admin", "Full access"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "member"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_admin", "Admin", "admin")) + seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockSystemRole("role_member", "Member", authz.SystemRoleMember)) ti.roles.On("UpdateMemberRole", mock.Anything, "membership_1", "admin").Return(&thirdpartyworkos.Member{ ID: "membership_1", UserID: "user_1", @@ -213,14 +194,12 @@ func TestService_UpdateRole_SystemRole_AuditLog(t *testing.T) { RoleSlug: "admin", CreatedAt: mockMembershipTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "admin"), - }, nil).Once() seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", authz.SystemRoleMember)) role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_admin", + ID: roleID, MemberIds: []string{"local_user_1"}, }) require.NoError(t, err) @@ -240,9 +219,8 @@ func TestService_UpdateRole_NotFound(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{}, nil).Once() - _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: "role_missing"}) + _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: "00000000-0000-0000-0000-000000000001"}) require.Error(t, err) require.Contains(t, err.Error(), "role not found") } @@ -251,15 +229,14 @@ func TestService_UpdateRole_WorkOSUpdateFailure(t *testing.T) { t.Parallel() ctx, ti := newTestAccessService(t) - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Custom Builder", "custom-builder", "Old description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{}, nil).Once() - ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{}).Return((*thirdpartyworkos.Role)(nil), errors.New("workos unavailable")).Once() + authCtx, ok := contextvalues.GetAuthContext(ctx) + require.True(t, ok) + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Custom Builder", "custom-builder", "Old description")) + ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{}).Return((*thirdpartyworkos.Role)(nil), errors.New("workos unavailable")).Times(3) - _, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: "role_custom"}) - require.Error(t, err) - require.Contains(t, err.Error(), "update role in workos") + role, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ID: roleID}) + require.NoError(t, err) + require.Equal(t, roleID, role.ID) } func TestService_UpdateRole_AuditLog(t *testing.T) { @@ -274,12 +251,7 @@ func TestService_UpdateRole_AuditLog(t *testing.T) { name := "Audit Builder" description := "After description" - ti.roles.On("ListRoles", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Role{ - mockRole("role_custom", "Before Builder", "custom-builder", "Before description"), - }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - }, nil).Once() + roleID := seedRole(t, ctx, ti.conn, authCtx.ActiveOrganizationID, mockRole("role_custom", "Before Builder", "custom-builder", "Before description")) ti.roles.On("UpdateRole", mock.Anything, mockidp.MockOrgID, "custom-builder", thirdpartyworkos.UpdateRoleOpts{ Name: &name, Description: &description, @@ -291,17 +263,15 @@ func TestService_UpdateRole_AuditLog(t *testing.T) { CreatedAt: mockRoleTimestamp, UpdatedAt: mockRoleTimestamp, }, nil).Once() - ti.roles.On("ListMembers", mock.Anything, mockidp.MockOrgID).Return([]thirdpartyworkos.Member{ - mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder"), - mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder"), - }, nil).Once() seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", "user1@test.com", "User 1", "user_1", "membership_1") seedConnectedUser(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", "user2@test.com", "User 2", "user_2", "membership_2") + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_1", mockMember(mockidp.MockOrgID, "membership_1", "user_1", "custom-builder")) + seedRoleAssignment(t, ctx, ti.conn, authCtx.ActiveOrganizationID, "local_user_2", mockMember(mockidp.MockOrgID, "membership_2", "user_2", "custom-builder")) seedGrant(t, ctx, ti.conn, authCtx.ActiveOrganizationID, urn.NewPrincipal(urn.PrincipalTypeRole, "custom-builder"), authz.ScopeProjectRead, "project-old") updated, err := ti.service.UpdateRole(ctx, &gen.UpdateRolePayload{ - ID: "role_custom", + ID: roleID, Name: &name, Description: &description, Grants: []*gen.RoleGrant{{ diff --git a/server/internal/assets/setup_test.go b/server/internal/assets/setup_test.go index 7455f20504..3898aa6ebe 100644 --- a/server/internal/assets/setup_test.go +++ b/server/internal/assets/setup_test.go @@ -101,7 +101,6 @@ func newTestAssetsService(t *testing.T) (context.Context, *testInstance) { authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), - cache.NoopCache, ), auditLogger, ) diff --git a/server/internal/assistantmemories/impl_test.go b/server/internal/assistantmemories/impl_test.go index 5c0dcb1a9d..5b2e044964 100644 --- a/server/internal/assistantmemories/impl_test.go +++ b/server/internal/assistantmemories/impl_test.go @@ -13,7 +13,6 @@ import ( gen "github.com/speakeasy-api/gram/server/gen/assistant_memories" "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/memory" "github.com/speakeasy-api/gram/server/internal/memory/repo" @@ -64,7 +63,7 @@ func newTestHarness(t *testing.T) (*testHarness, context.Context) { mem := &fakeMemory{listFn: nil, getFn: nil, deleteFn: nil} - authzEngine := authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := &Service{ tracer: tracerProvider.Tracer("test"), @@ -135,7 +134,7 @@ func TestListAssistantMemories_RBACDenied(t *testing.T) { h, ctx := newTestHarness(t) logger := testenv.NewLogger(t) - h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx) _, err := h.svc.ListAssistantMemories(ctx, &gen.ListAssistantMemoriesPayload{ @@ -304,7 +303,7 @@ func TestGetAssistantMemory_RBACDenied(t *testing.T) { h, ctx := newTestHarness(t) logger := testenv.NewLogger(t) - h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx) _, err := h.svc.GetAssistantMemory(ctx, &gen.GetAssistantMemoryPayload{ @@ -375,7 +374,7 @@ func TestDeleteAssistantMemory_RBACDenied(t *testing.T) { h, ctx := newTestHarness(t) logger := testenv.NewLogger(t) - h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + h.svc.authz = authz.NewEngine(logger, nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.NewGrant(authz.ScopeProjectRead, h.projectID.String())) diff --git a/server/internal/assistants/impl_test.go b/server/internal/assistants/impl_test.go index 3e0111b9a6..7e1ec5ae85 100644 --- a/server/internal/assistants/impl_test.go +++ b/server/internal/assistants/impl_test.go @@ -13,7 +13,6 @@ import ( "github.com/speakeasy-api/gram/server/gen/types" "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/oops" projectsRepo "github.com/speakeasy-api/gram/server/internal/projects/repo" @@ -254,7 +253,7 @@ func newRBACServiceWithConn(t *testing.T, dbName string) (*Service, context.Cont logger := testenv.NewLogger(t) chConn, err := assistantsInfra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) service := &Service{ tracer: testenv.NewTracerProvider(t).Tracer("test"), logger: logger, diff --git a/server/internal/attr/conventions.go b/server/internal/attr/conventions.go index 5c53dc2c60..23e23b429c 100644 --- a/server/internal/attr/conventions.go +++ b/server/internal/attr/conventions.go @@ -210,6 +210,8 @@ const ( AccessMemberIDKey = attribute.Key("gram.access.member.id") AccessRoleIDKey = attribute.Key("gram.access.role.id") AccessRoleSlugKey = attribute.Key("gram.access.role.slug") + AccessRoleSourceKey = attribute.Key("gram.access.role.source") + AccessRoleDBWriteFailedKey = attribute.Key("gram.access.role.db_write_failed_post_workos") OrganizationRoleAssignmentIDKey = attribute.Key("gram.org.role_assignment.id") OrganizationIDKey = attribute.Key("gram.org.id") OrganizationSlugKey = attribute.Key("gram.org.slug") @@ -911,6 +913,16 @@ func SlogAccessRoleID(v string) slog.Attr { return slog.String(string(Acces func AccessRoleSlug(v string) attribute.KeyValue { return AccessRoleSlugKey.String(v) } func SlogAccessRoleSlug(v string) slog.Attr { return slog.String(string(AccessRoleSlugKey), v) } +func AccessRoleSource(v string) attribute.KeyValue { return AccessRoleSourceKey.String(v) } +func SlogAccessRoleSource(v string) slog.Attr { return slog.String(string(AccessRoleSourceKey), v) } + +func AccessRoleDBWriteFailed(v bool) attribute.KeyValue { + return AccessRoleDBWriteFailedKey.Bool(v) +} +func SlogAccessRoleDBWriteFailed(v bool) slog.Attr { + return slog.Bool(string(AccessRoleDBWriteFailedKey), v) +} + func OrganizationRoleAssignmentID(v string) attribute.KeyValue { return OrganizationRoleAssignmentIDKey.String(v) } diff --git a/server/internal/auditapi/setup_test.go b/server/internal/auditapi/setup_test.go index 02e7c1571b..41b0d1af10 100644 --- a/server/internal/auditapi/setup_test.go +++ b/server/internal/auditapi/setup_test.go @@ -67,7 +67,7 @@ func newTestAuditService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) return ctx, &testInstance{ service: auditapi.NewService(logger, tracerProvider, conn, sessionManager, authzEngine), diff --git a/server/internal/auth/e2e_test.go b/server/internal/auth/e2e_test.go index 7e2ee95d99..64bc1756a0 100644 --- a/server/internal/auth/e2e_test.go +++ b/server/internal/auth/e2e_test.go @@ -162,7 +162,7 @@ func newE2EAuthService(t *testing.T, userInfo *MockUserInfo, fetcher *mockWorkOS require.NoError(t, err) nonceStore := cache.NewRedisCacheAdapter(redisClient) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := auth.NewService(logger, tracerProvider, conn, sessionManager, resolver, authConfigs, authzEngine, billingClient, noopCancelScheduler{}, posthogClient, nonceStore) ti := newTestAuthServiceResult(t, svc, conn, sessionManager, resolver, mockServer, authConfigs, nonceStore) diff --git a/server/internal/auth/impl.go b/server/internal/auth/impl.go index a1ca824041..bc13ecd0c8 100644 --- a/server/internal/auth/impl.go +++ b/server/internal/auth/impl.go @@ -418,7 +418,6 @@ func (s *Service) acceptPendingInvitationForMember(ctx context.Context, organiza }); err != nil { return fmt.Errorf("sync invite role assignments: %w", err) } - s.authz.InvalidateRoleCache(ctx, gramUserID, invite.OrganizationID) } if err := s.sessions.InvalidateUserInfoCache(ctx, gramUserID); err != nil { diff --git a/server/internal/auth/setup_test.go b/server/internal/auth/setup_test.go index d18b6b0f3a..9d80218e6a 100644 --- a/server/internal/auth/setup_test.go +++ b/server/internal/auth/setup_test.go @@ -160,7 +160,7 @@ func newTestAuthService(t *testing.T, userInfo *MockUserInfo) (context.Context, require.NoError(t, err) nonceStore := cache.NewRedisCacheAdapter(redisClient) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := auth.NewService(logger, tracerProvider, conn, sessionManager, resolver, authConfigs, authzEngine, billingClient, noopCancelScheduler{}, posthog, nonceStore) return ctx, newTestAuthServiceResult(t, svc, conn, sessionManager, resolver, mockServer, authConfigs, nonceStore) @@ -208,7 +208,7 @@ func newTestAuthServiceWithAuthz(t *testing.T, userInfo *MockUserInfo) (context. require.NoError(t, err) nonceStore := cache.NewRedisCacheAdapter(redisClient) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := auth.NewService(logger, tracerProvider, conn, sessionManager, resolver, authConfigs, authzEngine, billingClient, noopCancelScheduler{}, posthog, nonceStore) return ctx, newTestAuthServiceResult(t, svc, conn, sessionManager, resolver, mockServer, authConfigs, nonceStore) diff --git a/server/internal/authz/context_test.go b/server/internal/authz/context_test.go index 2efc754cbc..c4b3d065a6 100644 --- a/server/internal/authz/context_test.go +++ b/server/internal/authz/context_test.go @@ -7,7 +7,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" @@ -24,7 +23,7 @@ func TestPrepareContext_loadsUserGrants(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) @@ -49,7 +48,7 @@ func TestPrepareContext_skipsNonSessionAuth(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) @@ -71,7 +70,7 @@ func TestPrepareContext_loadsAssistantPrincipalGrants(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) @@ -101,7 +100,7 @@ func TestShouldEnforce_assistantPrincipalOnEnterpriseOrgEnforces(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) @@ -126,7 +125,7 @@ func TestShouldEnforce_assistantPrincipalOnNonEnterpriseSkips(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) @@ -150,7 +149,7 @@ func TestPrepareContext_skipsNonEnterpriseOrgs(t *testing.T) { conn := newTestDB(t) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) authCtx, ok := contextvalues.GetAuthContext(ctx) require.True(t, ok) diff --git a/server/internal/authz/engine.go b/server/internal/authz/engine.go index 8cb46751c1..2d66071eee 100644 --- a/server/internal/authz/engine.go +++ b/server/internal/authz/engine.go @@ -5,16 +5,14 @@ import ( "errors" "fmt" "log/slog" - "time" "github.com/ClickHouse/clickhouse-go/v2" "github.com/jackc/pgx/v5/pgxpool" + accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/attr" authzrepo "github.com/speakeasy-api/gram/server/internal/authz/repo" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/oops" - orgrepo "github.com/speakeasy-api/gram/server/internal/organizations/repo" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" usersrepo "github.com/speakeasy-api/gram/server/internal/users/repo" @@ -31,28 +29,6 @@ type EngineOpts struct { DevMode bool } -// roleSlugCache is the Redis cache entry for a resolved role slug. -// Key is org-first so DeleteByPrefix on "role-slug:{orgID}:" invalidates the whole org. -type roleSlugCache struct { - UserID string - OrgID string - Slug string -} - -var _ cache.CacheableObject[roleSlugCache] = (*roleSlugCache)(nil) - -func (r roleSlugCache) CacheKey() string { - return "role-slug:" + r.OrgID + ":" + r.UserID -} - -func (r roleSlugCache) TTL() time.Duration { - return 5 * time.Minute -} - -func (r roleSlugCache) AdditionalCacheKeys() []string { - return nil -} - // ChallengeLoggingEnabled checks whether authz challenge logging to ClickHouse // is enabled for a given organization. Same signature as IsRBACEnabled. type ChallengeLoggingEnabled func(ctx context.Context, organizationID string) (bool, error) @@ -65,10 +41,9 @@ type Engine struct { challengeLoggingEnabled ChallengeLoggingEnabled isDev bool membership MembershipFetcher - roleCache cache.TypedCacheObject[roleSlugCache] } -func NewEngine(logger *slog.Logger, db *pgxpool.Pool, chDB clickhouse.Conn, isEnabled IsRBACEnabled, challengeLogging ChallengeLoggingEnabled, membership MembershipFetcher, roleCache cache.Cache, opts ...EngineOpts) *Engine { +func NewEngine(logger *slog.Logger, db *pgxpool.Pool, chDB clickhouse.Conn, isEnabled IsRBACEnabled, challengeLogging ChallengeLoggingEnabled, membership MembershipFetcher, opts ...EngineOpts) *Engine { var devMode bool if len(opts) > 0 { devMode = opts[0].DevMode @@ -84,7 +59,6 @@ func NewEngine(logger *slog.Logger, db *pgxpool.Pool, chDB clickhouse.Conn, isEn challengeLoggingEnabled: challengeLogging, isDev: devMode, membership: membership, - roleCache: cache.NewTypedObjectCache[roleSlugCache](logger.With(attr.SlogCacheNamespace("authz-role-slug")), roleCache, cache.SuffixNone), } } @@ -158,19 +132,25 @@ func (e *Engine) PrepareContext(ctx context.Context) (context.Context, error) { principals := []urn.Principal{urn.NewPrincipal(urn.PrincipalTypeUser, authCtx.UserID)} - roleSlug, err := e.resolveRoleSlug(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) + rolePrincipals, err := e.resolveRolePrincipals(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) if err != nil { e.logger.ErrorContext( ctx, - "failed to resolve role for authz grants", + "failed to resolve roles for authz grants", attr.SlogOrganizationID(authCtx.ActiveOrganizationID), attr.SlogUserID(authCtx.UserID), attr.SlogError(err), ) - return ctx, fmt.Errorf("resolve role slug: %w", err) - } - if roleSlug != "" { - principals = append(principals, urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug)) + return ctx, fmt.Errorf("resolve role slugs: %w", err) + } + for _, role := range rolePrincipals { + // Load grants for both canonical role:: principals and + // legacy role: principals while existing grant rows are backfilled. + rolePrincipalURNs, err := RolePrincipals(role.RoleSlug, role.PrincipalUrn) + if err != nil { + return ctx, fmt.Errorf("build role principals: %w", err) + } + principals = append(principals, rolePrincipalURNs...) } grants, err := LoadGrants(ctx, e.db, authCtx.ActiveOrganizationID, principals) @@ -188,77 +168,27 @@ func (e *Engine) PrepareContext(ctx context.Context) (context.Context, error) { return GrantsToContext(ctx, grants), nil } -func (e *Engine) resolveRoleSlug(ctx context.Context, userID, orgID string) (string, error) { - cacheKey := roleSlugCache{UserID: userID, OrgID: orgID, Slug: ""}.CacheKey() - if cached, err := e.roleCache.Get(ctx, cacheKey); err == nil { - return cached.Slug, nil - } - +func (e *Engine) resolveRolePrincipals(ctx context.Context, userID, orgID string) ([]accessrepo.ListMemberRolePrincipalsByWorkosUserRow, error) { user, err := usersrepo.New(e.db).GetUser(ctx, userID) if err != nil { - return "", fmt.Errorf("get user: %w", err) + return nil, fmt.Errorf("get user: %w", err) } if !user.WorkosID.Valid || user.WorkosID.String == "" { - e.storeRoleSlugCache(ctx, userID, orgID, "") - return "", nil + return nil, nil } - org, err := orgrepo.New(e.db).GetOrganizationMetadata(ctx, orgID) + // Role assignments are local source-of-truth records. They are written by + // the access write path and by WorkOS sync, so invitation/admin-console + // changes can lag until the sync job catches up. + rolePrincipals, err := accessrepo.New(e.db).ListMemberRolePrincipalsByWorkosUser(ctx, accessrepo.ListMemberRolePrincipalsByWorkosUserParams{ + OrganizationID: orgID, + WorkosUserID: user.WorkosID.String, + }) if err != nil { - return "", fmt.Errorf("get org: %w", err) - } - if !org.WorkosID.Valid || org.WorkosID.String == "" { - e.storeRoleSlugCache(ctx, userID, orgID, "") - return "", nil + return nil, fmt.Errorf("list member role slugs: %w", err) } - member, err := e.membership.GetOrgMembership(ctx, user.WorkosID.String, org.WorkosID.String) - if err != nil { - return "", fmt.Errorf("get org membership: %w", err) - } - if member == nil { - e.storeRoleSlugCache(ctx, userID, orgID, "") - return "", nil - } - - e.storeRoleSlugCache(ctx, userID, orgID, member.RoleSlug) - - return member.RoleSlug, nil -} - -func (e *Engine) storeRoleSlugCache(ctx context.Context, userID, orgID, slug string) { - entry := roleSlugCache{UserID: userID, OrgID: orgID, Slug: slug} - if err := e.roleCache.Store(ctx, entry); err != nil { - e.logger.WarnContext(ctx, "failed to cache role slug", - attr.SlogUserID(userID), - attr.SlogOrganizationID(orgID), - attr.SlogError(err), - ) - } -} - -// InvalidateRoleCache removes the cached role slug for a single user. Call -// this after updating a specific member's role via UpdateMemberRole. -func (e *Engine) InvalidateRoleCache(ctx context.Context, userID, orgID string) { - entry := roleSlugCache{UserID: userID, OrgID: orgID, Slug: ""} - if err := e.roleCache.Delete(ctx, entry); err != nil { - e.logger.WarnContext(ctx, "failed to invalidate cached role slug", - attr.SlogUserID(userID), - attr.SlogOrganizationID(orgID), - attr.SlogError(err), - ) - } -} - -// InvalidateAllRoleCaches removes all cached role slugs for an org. Call this -// after bulk role reassignments where individual user IDs aren't tracked. -func (e *Engine) InvalidateAllRoleCaches(ctx context.Context, orgID string) { - if err := e.roleCache.DeleteByPrefix(ctx, "role-slug:"+orgID+":"); err != nil { - e.logger.WarnContext(ctx, "failed to invalidate cached role slugs for org", - attr.SlogOrganizationID(orgID), - attr.SlogError(err), - ) - } + return rolePrincipals, nil } func (e *Engine) Require(ctx context.Context, checks ...Check) error { diff --git a/server/internal/authz/engine_test.go b/server/internal/authz/engine_test.go index 2d927f5972..264083518b 100644 --- a/server/internal/authz/engine_test.go +++ b/server/internal/authz/engine_test.go @@ -2,10 +2,7 @@ package authz import ( "context" - "encoding/json" "errors" - "fmt" - "strings" "testing" "time" @@ -13,7 +10,6 @@ import ( "github.com/stretchr/testify/require" authzrepo "github.com/speakeasy-api/gram/server/internal/authz/repo" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/testenv" @@ -43,7 +39,7 @@ func TestEngineRequire_requiresAuthContext(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) err = engine.Require(t.Context(), Check{Scope: ScopeProjectRead, ResourceID: "proj_123"}) var oopsErr *oops.ShareableError @@ -56,7 +52,7 @@ func TestEngineRequire_skipsWhenRBACFeatureDisabled(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient()) err = engine.Require(enterpriseSessionCtx(t), Check{Scope: ScopeProjectRead, ResourceID: "proj_123"}) require.NoError(t, err) @@ -67,7 +63,7 @@ func TestEngineRequire_mapsDeniedToForbidden(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), nil) err = engine.Require(ctx, Check{Scope: ScopeProjectRead, ResourceID: "proj_123"}) @@ -81,7 +77,7 @@ func TestEngineRequire_mapsMissingGrantsToUnexpected(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) err = engine.Require(enterpriseSessionCtx(t), Check{Scope: ScopeProjectRead, ResourceID: "proj_123"}) var oopsErr *oops.ShareableError @@ -95,7 +91,7 @@ func TestEngineRequire_returnsUnexpectedWhenFeatureCheckFails(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, failingRBAC(errors.New("boom")), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, failingRBAC(errors.New("boom")), staticChallengeLogging(true), workos.NewStubClient()) err = engine.Require(enterpriseSessionCtx(t), Check{Scope: ScopeProjectRead, ResourceID: "proj_123"}) var oopsErr *oops.ShareableError @@ -103,39 +99,12 @@ func TestEngineRequire_returnsUnexpectedWhenFeatureCheckFails(t *testing.T) { require.Equal(t, oops.CodeUnexpected, oopsErr.Code) } -func TestResolveRoleSlug_cachesEmptyMembershipResult(t *testing.T) { - t.Parallel() - - ctx := enterpriseTestCtx(t.Context()) - conn := newTestDB(t) - authCtx, ok := contextvalues.GetAuthContext(ctx) - require.True(t, ok) - require.NotNil(t, authCtx) - - seedOrganization(t, ctx, conn, authCtx.ActiveOrganizationID) - seedConnectedUser(t, ctx, conn, authCtx.ActiveOrganizationID, authCtx.UserID, "test@example.com", "Test User", "user_workos_test", "membership_test") - - membership := &countingMembershipFetcher{} - chConn, err := newClickhouseClient(t) - require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, staticRBAC(true), staticChallengeLogging(true), membership, newMapCache()) - - roleSlug, err := engine.resolveRoleSlug(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) - require.NoError(t, err) - require.Empty(t, roleSlug) - - roleSlug, err = engine.resolveRoleSlug(ctx, authCtx.UserID, authCtx.ActiveOrganizationID) - require.NoError(t, err) - require.Empty(t, roleSlug) - require.Equal(t, 1, membership.calls) -} - func TestEngineRequireAny_mapsDeniedToForbidden(t *testing.T) { t.Parallel() chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{NewGrant(ScopeMCPConnect, "tool_a")}) err = engine.RequireAny(ctx, @@ -152,7 +121,7 @@ func TestEngineFilter_returnsAllowedSubset(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{NewGrant(ScopeProjectRead, "proj_123")}) resourceIDs, err := engine.Filter(ctx, []Check{ @@ -170,7 +139,7 @@ func TestEngineFilter_logsSingleAggregateChallenge(t *testing.T) { ctx := GrantsToContext(enterpriseSessionCtxWithOrg(t, orgID), []Grant{NewGrant(ScopeProjectRead, "proj_allowed")}) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) resourceIDs, err := engine.Filter(ctx, []Check{ {Scope: ScopeProjectRead, ResourceID: "proj_allowed"}, @@ -219,7 +188,7 @@ func TestEngineFilter_logsDenyWhenNoMatches(t *testing.T) { ctx := GrantsToContext(enterpriseSessionCtxWithOrg(t, orgID), []Grant{NewGrant(ScopeProjectRead, "proj_other")}) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) resourceIDs, err := engine.Filter(ctx, []Check{ {Scope: ScopeProjectRead, ResourceID: "proj_a"}, @@ -262,7 +231,7 @@ func TestEngineFilter_skipsLogWhenNoChecks(t *testing.T) { ctx := GrantsToContext(enterpriseSessionCtxWithOrg(t, orgID), []Grant{NewGrant(ScopeProjectRead, WildcardResource)}) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) resourceIDs, err := engine.Filter(ctx, nil) require.NoError(t, err) @@ -288,7 +257,7 @@ func TestEngineRequire_denyGrantBlocksAccess(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ NewGrant(ScopeProjectRead, WildcardResource), NewDenyGrant(ScopeProjectRead, "proj_secret"), @@ -310,7 +279,7 @@ func TestEngineRequireAny_denySkipsToNextCheck(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ NewGrant(ScopeMCPConnect, WildcardResource), NewDenyGrant(ScopeMCPConnect, "tool_blocked"), @@ -329,7 +298,7 @@ func TestEngineRequireAny_allDeniedReturnsForbidden(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ NewGrant(ScopeMCPConnect, WildcardResource), NewDenyGrant(ScopeMCPConnect, WildcardResource), @@ -349,7 +318,7 @@ func TestEngineFilter_denyExcludesResources(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ NewGrant(ScopeProjectRead, WildcardResource), NewDenyGrant(ScopeProjectRead, "proj_secret"), @@ -369,7 +338,7 @@ func TestEngineFindMatched_denyReturnsFalse(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ NewGrant(ScopeMCPConnect, WildcardResource), NewDenyGrant(ScopeMCPConnect, "tool_blocked"), @@ -389,7 +358,7 @@ func TestEngineFilter_withDimensions(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ {Scope: ScopeMCPConnect, Selector: Selector{ "resource_kind": "mcp", @@ -412,7 +381,7 @@ func TestEngineFilter_withDisposition(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ {Scope: ScopeMCPConnect, Selector: Selector{ "resource_kind": "mcp", @@ -435,7 +404,7 @@ func TestEngineFilter_serverLevelGrantAllowsAllDimensions(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ NewGrant(ScopeMCPConnect, "toolsetA"), }) @@ -454,7 +423,7 @@ func TestEngineFilter_projectScopedGrantMatchesServersInProject(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ {Scope: ScopeMCPConnect, Selector: Selector{ "resource_kind": "mcp", @@ -477,7 +446,7 @@ func TestEngineRequire_projectScopedGrantAllowsToolsInProject(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ {Scope: ScopeMCPConnect, Selector: Selector{ "resource_kind": "mcp", @@ -508,7 +477,7 @@ func TestEngineRequire_projectScopedMCPReadGrant(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ {Scope: ScopeMCPRead, Selector: Selector{ "resource_kind": "mcp", @@ -533,7 +502,7 @@ func TestEngineFilter_projectAndServerGrantsCombine(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ // Project-scoped grant for proj_A {Scope: ScopeMCPConnect, Selector: Selector{ @@ -559,7 +528,7 @@ func TestEngineRequire_rejectsInvalidCheck(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{NewGrant(ScopeProjectRead, WildcardResource)}) err = engine.Require(ctx, Check{Scope: ScopeProjectRead, ResourceID: ""}) @@ -574,7 +543,7 @@ func TestEngineRequire_requiresChecks(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{NewGrant(ScopeProjectRead, WildcardResource)}) err = engine.Require(ctx) @@ -589,7 +558,7 @@ func TestEngineRequire_skipsForAPIKeyAuth(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) sessionID := "session_123" ctx := contextvalues.SetAuthContext(t.Context(), &contextvalues.AuthContext{ ActiveOrganizationID: "org_123", @@ -616,7 +585,7 @@ func TestEngineFilter_skipsForNonEnterpriseAccount(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) sessionID := "session_123" ctx := contextvalues.SetAuthContext(t.Context(), &contextvalues.AuthContext{ ActiveOrganizationID: "org_123", @@ -647,7 +616,7 @@ func TestEngineFindMatched_returnsParallelBools(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{NewGrant(ScopeProjectRead, "proj_123")}) matched, err := engine.FindMatched(ctx, []Check{ @@ -666,7 +635,7 @@ func TestEngineFindMatched_preservesOrderAcrossMixedMatches(t *testing.T) { // exactly, with no implicit reordering. chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{ NewGrant(ScopeProjectRead, "proj_b"), NewGrant(ScopeProjectRead, "proj_d"), @@ -687,7 +656,7 @@ func TestEngineFindMatched_returnsAllTrueWhenEnforcementDisabled(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) sessionID := "session_123" ctx := contextvalues.SetAuthContext(t.Context(), &contextvalues.AuthContext{ ActiveOrganizationID: "org_123", @@ -719,7 +688,7 @@ func TestEngineFindMatched_emptyInputReturnsEmptySlice(t *testing.T) { orgID := "org_" + uuid.NewString() chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtxWithOrg(t, orgID), []Grant{NewGrant(ScopeProjectRead, WildcardResource)}) matched, err := engine.FindMatched(ctx, nil) @@ -744,7 +713,7 @@ func TestEngineFindMatched_missingGrantsReturnsError(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) _, err = engine.FindMatched(enterpriseSessionCtx(t), []Check{ {Scope: ScopeProjectRead, ResourceID: "proj_123"}, @@ -760,7 +729,7 @@ func TestEngineFindMatched_rejectsInvalidCheck(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) ctx := GrantsToContext(enterpriseSessionCtx(t), []Grant{NewGrant(ScopeProjectRead, WildcardResource)}) _, err = engine.FindMatched(ctx, []Check{{Scope: ScopeProjectRead, ResourceID: ""}}) @@ -777,7 +746,7 @@ func TestEngineFindMatched_logsSingleAggregateChallenge(t *testing.T) { ctx := GrantsToContext(enterpriseSessionCtxWithOrg(t, orgID), []Grant{NewGrant(ScopeProjectRead, "proj_allowed")}) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) matched, err := engine.FindMatched(ctx, []Check{ {Scope: ScopeProjectRead, ResourceID: "proj_allowed"}, @@ -819,81 +788,6 @@ func TestEngineFindMatched_logsSingleAggregateChallenge(t *testing.T) { }, 5*time.Second, 100*time.Millisecond) } -type countingMembershipFetcher struct { - calls int -} - -func (c *countingMembershipFetcher) GetOrgMembership(context.Context, string, string) (*workos.Member, error) { - c.calls++ - return nil, nil -} - -type mapCache struct { - items map[string][]byte -} - -func newMapCache() *mapCache { - return &mapCache{items: map[string][]byte{}} -} - -func (m *mapCache) Get(_ context.Context, key string, value any) error { - item, ok := m.items[key] - if !ok { - return errors.New("cache miss") - } - if err := json.Unmarshal(item, value); err != nil { - return fmt.Errorf("unmarshal cache item: %w", err) - } - return nil -} - -func (m *mapCache) GetAndDelete(ctx context.Context, key string, value any) error { - if err := m.Get(ctx, key, value); err != nil { - return err - } - delete(m.items, key) - return nil -} - -func (m *mapCache) Set(_ context.Context, key string, value any, _ time.Duration) error { - item, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("marshal cache item: %w", err) - } - m.items[key] = item - return nil -} - -func (m *mapCache) Update(ctx context.Context, key string, value any) error { - return m.Set(ctx, key, value, 0) -} - -func (m *mapCache) Delete(_ context.Context, key string) error { - delete(m.items, key) - return nil -} - -func (m *mapCache) Expire(_ context.Context, _ string, _ time.Duration) error { - return nil -} - -func (m *mapCache) ListAppend(context.Context, string, any, time.Duration) error { - return errors.New("not implemented") -} - -func (m *mapCache) ListRange(context.Context, string, int64, int64, any) error { - return errors.New("not implemented") -} - -func (m *mapCache) DeleteByPrefix(_ context.Context, prefix string) error { - for key := range m.items { - if strings.HasPrefix(key, prefix) { - delete(m.items, key) - } - } - return nil -} - func enterpriseSessionCtx(t *testing.T) context.Context { t.Helper() return enterpriseSessionCtxWithOrg(t, "org_123") @@ -952,7 +846,7 @@ func TestPrepareContext_adminImpersonationGrantsAllScopes(t *testing.T) { chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(true), staticChallengeLogging(true), workos.NewStubClient()) // Build a context that looks like admin impersonation: enterprise account, // IsAdmin flag, and AdminOverride pointing at the target org. @@ -989,7 +883,7 @@ func TestCanUseOverride_devPlusAdmin(t *testing.T) { t.Parallel() chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache, EngineOpts{DevMode: true}) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), EngineOpts{DevMode: true}) ctx := scopeOverrideCtx(t, true, "pro") enforce, err := engine.ShouldEnforce(ctx) @@ -1001,7 +895,7 @@ func TestCanUseOverride_devPlusNonAdmin(t *testing.T) { t.Parallel() chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache, EngineOpts{DevMode: true}) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), EngineOpts{DevMode: true}) ctx := scopeOverrideCtx(t, false, "pro") enforce, err := engine.ShouldEnforce(ctx) @@ -1013,7 +907,7 @@ func TestCanUseOverride_prodPlusAdmin(t *testing.T) { t.Parallel() chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient()) ctx := scopeOverrideCtx(t, true, "pro") enforce, err := engine.ShouldEnforce(ctx) @@ -1025,7 +919,7 @@ func TestCanUseOverride_prodPlusNonAdmin(t *testing.T) { t.Parallel() chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), nil, chConn, staticRBAC(false), staticChallengeLogging(true), workos.NewStubClient()) ctx := scopeOverrideCtx(t, false, "pro") enforce, err := engine.ShouldEnforce(ctx) @@ -1042,6 +936,7 @@ func TestSyncGrants_denyEffectSurvivesDBRoundTrip(t *testing.T) { ctx := t.Context() orgID := "org_deny_roundtrip" roleSlug := "deny-test-role" + rolePrincipalURN := "role:organization:" + uuid.NewString() seedOrganization(t, ctx, conn, orgID) @@ -1050,10 +945,10 @@ func TestSyncGrants_denyEffectSurvivesDBRoundTrip(t *testing.T) { {Scope: string(ScopeMCPConnect), Effect: PolicyEffectDeny, Selectors: []Selector{NewSelector(ScopeMCPConnect, "server_blocked")}}, // deny specific } - err := SyncGrants(ctx, testenv.NewLogger(t), conn, orgID, roleSlug, grants) + err := SyncGrants(ctx, testenv.NewLogger(t), conn, orgID, roleSlug, rolePrincipalURN, grants) require.NoError(t, err) - scoped, err := GrantsForRole(ctx, testenv.NewLogger(t), conn, orgID, roleSlug) + scoped, err := GrantsForRole(ctx, testenv.NewLogger(t), conn, orgID, roleSlug, rolePrincipalURN) require.NoError(t, err) var allowGrant, denyGrant *ScopedGrant diff --git a/server/internal/authz/grants.go b/server/internal/authz/grants.go index 70cb6c3da3..39550040e4 100644 --- a/server/internal/authz/grants.go +++ b/server/internal/authz/grants.go @@ -3,10 +3,13 @@ package authz import ( "cmp" "context" + "errors" "fmt" "log/slog" "slices" + "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" @@ -81,7 +84,40 @@ var SystemRoleGrants = map[string][]*RoleGrant{ // SeedSystemRoleGrants upserts the fixed grant sets for all system roles. func SeedSystemRoleGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, organizationID string) error { for roleSlug, grants := range SystemRoleGrants { - if err := SyncGrants(ctx, logger, db, organizationID, roleSlug, grants); err != nil { + existingRole, err := repo.New(db).GetGlobalRoleBySlug(ctx, roleSlug) + seedRole := false + switch { + case err == nil: + seedRole = existingRole.Deleted + case errors.Is(err, pgx.ErrNoRows): + seedRole = true + default: + return fmt.Errorf("load %s role: %w", roleSlug, err) + } + if seedRole { + name := roleSlug + description := "" + switch roleSlug { + case SystemRoleAdmin: + name = "Admin" + description = "Administrator role" + case SystemRoleMember: + name = "Member" + description = "Member role" + } + now := time.Now().UTC() + if err := repo.New(db).UpsertGlobalRole(ctx, repo.UpsertGlobalRoleParams{ + WorkosSlug: roleSlug, + WorkosName: name, + WorkosDescription: conv.ToPGTextEmpty(description), + WorkosCreatedAt: conv.ToPGTimestamptz(now), + WorkosUpdatedAt: conv.ToPGTimestamptz(now), + WorkosLastEventID: conv.ToPGTextEmpty(""), + }); err != nil { + return fmt.Errorf("seed %s role: %w", roleSlug, err) + } + } + if err := SyncGrants(ctx, logger, db, organizationID, roleSlug, "", grants); err != nil { return fmt.Errorf("seed %s grants: %w", roleSlug, err) } } @@ -102,28 +138,59 @@ type ScopedGrant struct { Selectors []Selector } -func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgID string, roleSlug string, grants []*RoleGrant) error { +func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgID string, roleSlug string, rolePrincipalURN string, grants []*RoleGrant) error { if orgID == "" { return fmt.Errorf("organization id is required") } - principalURN := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) - tx, err := db.Begin(ctx) if err != nil { return fmt.Errorf("begin grant sync transaction: %w", err) } defer o11y.NoLogDefer(func() error { return tx.Rollback(ctx) }) - q := repo.New(tx) + if _, err := SyncGrantsTx(ctx, tx, orgID, roleSlug, rolePrincipalURN, grants); err != nil { + return err + } - if _, err := q.DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ - OrganizationID: orgID, - PrincipalUrn: principalURN, - }); err != nil { - return fmt.Errorf("delete grants for role %q: %w", roleSlug, err) + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("commit grant sync transaction: %w", err) + } + + return nil +} + +func SyncGrantsTx(ctx context.Context, dbtx repo.DBTX, orgID string, roleSlug string, rolePrincipalURN string, grants []*RoleGrant) ([]*ScopedGrant, error) { + if orgID == "" { + return nil, fmt.Errorf("organization id is required") } + if rolePrincipalURN == "" { + role, err := repo.New(dbtx).GetActiveOrganizationRoleBySlug(ctx, repo.GetActiveOrganizationRoleBySlugParams{ + OrganizationID: orgID, + WorkosSlug: roleSlug, + }) + if err != nil { + return nil, fmt.Errorf("resolve role principal for %q: %w", roleSlug, err) + } + rolePrincipalURN = role.RoleUrn + } + + principalURN, err := urn.ParsePrincipal(rolePrincipalURN) + if err != nil { + return nil, fmt.Errorf("parse role principal urn %q: %w", rolePrincipalURN, err) + } + + q := repo.New(dbtx) + // During the role-principal migration, replace grants for both the new + // role:: principal and the legacy role: principal. New + // writes below only insert the canonical URN form. + if err := DeleteRoleGrants(ctx, q, orgID, roleSlug, rolePrincipalURN); err != nil { + return nil, err + } + + grantRows := make([]Grant, 0, len(grants)) + seenGrants := make(map[string]struct{}, len(grants)) for _, grant := range grants { if grant == nil { continue @@ -138,7 +205,7 @@ func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgI sel := NewSelector(scope, WildcardResource) selBytes, err := sel.MarshalJSON() if err != nil { - return fmt.Errorf("marshal wildcard selector for %q: %w", grant.Scope, err) + return nil, fmt.Errorf("marshal wildcard selector for %q: %w", grant.Scope, err) } if _, err := q.UpsertPrincipalGrant(ctx, repo.UpsertPrincipalGrantParams{ OrganizationID: orgID, @@ -147,19 +214,29 @@ func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgI Effect: effectToPgtype(effect), Selectors: selBytes, }); err != nil { - return fmt.Errorf("upsert unrestricted grant %q for role %q: %w", grant.Scope, roleSlug, err) + return nil, fmt.Errorf("upsert unrestricted grant %q for role %q: %w", grant.Scope, roleSlug, err) + } + grantKey := grant.Scope + "\x00" + string(selBytes) + if _, ok := seenGrants[grantKey]; !ok { + seenGrants[grantKey] = struct{}{} + grantRows = append(grantRows, Grant{ + PrincipalUrn: principalURN.String(), + Scope: scope, + Effect: effect, + Selector: sel, + }) } continue } for _, sel := range grant.Selectors { if err := ValidateSelector(scope, sel); err != nil { - return fmt.Errorf("invalid selector for scope %q: %w", grant.Scope, err) + return nil, fmt.Errorf("invalid selector for scope %q: %w", grant.Scope, err) } selBytes, err := sel.MarshalJSON() if err != nil { - return fmt.Errorf("marshal selector for scope %q: %w", grant.Scope, err) + return nil, fmt.Errorf("marshal selector for scope %q: %w", grant.Scope, err) } if _, err := q.UpsertPrincipalGrant(ctx, repo.UpsertPrincipalGrantParams{ OrganizationID: orgID, @@ -168,37 +245,112 @@ func SyncGrants(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgI Effect: effectToPgtype(effect), Selectors: selBytes, }); err != nil { - return fmt.Errorf("upsert grant %q for role %q: %w", grant.Scope, roleSlug, err) + return nil, fmt.Errorf("upsert grant %q for role %q: %w", grant.Scope, roleSlug, err) + } + grantKey := grant.Scope + "\x00" + string(selBytes) + if _, ok := seenGrants[grantKey]; !ok { + seenGrants[grantKey] = struct{}{} + grantRows = append(grantRows, Grant{ + PrincipalUrn: principalURN.String(), + Scope: scope, + Effect: effect, + Selector: sel, + }) } } } - if err := tx.Commit(ctx); err != nil { - return fmt.Errorf("commit grant sync transaction: %w", err) + return GrantsToScopedGrants(grantRows), nil +} + +func DeleteRoleGrants(ctx context.Context, q *repo.Queries, orgID, roleSlug, rolePrincipalURN string) error { + principalURN, err := urn.ParsePrincipal(rolePrincipalURN) + if err != nil { + return fmt.Errorf("parse role principal urn %q: %w", rolePrincipalURN, err) + } + + if _, err := q.DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ + OrganizationID: orgID, + PrincipalUrn: principalURN, + }); err != nil { + return fmt.Errorf("delete grants for role %q: %w", roleSlug, err) + } + + legacyPrincipalURN := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) + if legacyPrincipalURN.String() == principalURN.String() { + return nil + } + // Legacy grants were keyed by role slug (role:). Keep deleting that + // principal alongside the canonical role URN until the backfill is complete. + if _, err := q.DeletePrincipalGrantsByPrincipal(ctx, repo.DeletePrincipalGrantsByPrincipalParams{ + OrganizationID: orgID, + PrincipalUrn: legacyPrincipalURN, + }); err != nil { + return fmt.Errorf("delete legacy grants for role %q: %w", roleSlug, err) } return nil } -func GrantsForRole(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgID string, roleSlug string) ([]*ScopedGrant, error) { - rows, err := repo.New(db).ListPrincipalGrantsByOrg(ctx, repo.ListPrincipalGrantsByOrgParams{ +func RolePrincipals(roleSlug, rolePrincipalURN string) ([]urn.Principal, error) { + // TODO(AGE-1954): drop legacy role: principals after the role-principal backfill is complete. + // Keep the legacy role: principal until the role-principal backfill + // has moved existing grants to role::. + principals := make([]urn.Principal, 0, 2) + if rolePrincipalURN != "" { + principal, err := urn.ParsePrincipal(rolePrincipalURN) + if err != nil { + return nil, fmt.Errorf("parse role principal urn %q: %w", rolePrincipalURN, err) + } + principals = append(principals, principal) + } + + legacyPrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug) + if rolePrincipalURN == "" || legacyPrincipal.String() != rolePrincipalURN { + principals = append(principals, legacyPrincipal) + } + + return principals, nil +} + +func GrantsForRole(ctx context.Context, logger *slog.Logger, db *pgxpool.Pool, orgID string, roleSlug string, rolePrincipalURN string) ([]*ScopedGrant, error) { + // TODO(AGE-1954): remove dual-read after legacy role: grants are backfilled. + // During the role-principal migration, reads include both the canonical + // role:: principal and the legacy role: principal. + principals, err := RolePrincipals(roleSlug, rolePrincipalURN) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "build role principals").Log(ctx, logger) + } + principalURNs, err := parsePrincipalURNs(principals) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "build role principals").Log(ctx, logger) + } + + rows, err := repo.New(db).GetPrincipalGrants(ctx, repo.GetPrincipalGrantsParams{ OrganizationID: orgID, - PrincipalUrn: urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug).String(), + PrincipalUrns: principalURNs, }) if err != nil { return nil, oops.E(oops.CodeUnexpected, err, "list grants for role").Log(ctx, logger) } - rolePrincipalURN := urn.NewPrincipal(urn.PrincipalTypeRole, roleSlug).String() + scoped, err := scopedGrantsFromGrantRows(rows) + if err != nil { + return nil, oops.E(oops.CodeUnexpected, err, "unmarshal grant selector").Log(ctx, logger) + } + + return scoped, nil +} +func scopedGrantsFromGrantRows(rows []repo.GetPrincipalGrantsRow) ([]*ScopedGrant, error) { grantRows := make([]Grant, 0, len(rows)) for _, row := range rows { selectors, err := SelectorFromRow(row.Selectors) if err != nil { - return nil, oops.E(oops.CodeUnexpected, err, "unmarshal grant selector").Log(ctx, logger) + return nil, err } grantRows = append(grantRows, Grant{ - PrincipalUrn: rolePrincipalURN, + PrincipalUrn: row.PrincipalUrn.String(), Scope: Scope(row.Scope), Effect: effectFromNullable(row.Effect), Selector: selectors, diff --git a/server/internal/authz/integration_test.go b/server/internal/authz/integration_test.go index e98bf05484..77194495c1 100644 --- a/server/internal/authz/integration_test.go +++ b/server/internal/authz/integration_test.go @@ -5,7 +5,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" @@ -31,7 +30,7 @@ func TestRequire_withLoadedGrantsFromContext(t *testing.T) { ctx = GrantsToContext(ctx, grants) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) err = engine.Require(ctx, Check{Scope: ScopeProjectRead, ResourceID: "proj:123"}, @@ -65,7 +64,7 @@ func TestFilter_withLoadedGrantsFromContext(t *testing.T) { ctx = GrantsToContext(ctx, grants) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) projectIDs, err := engine.Filter(ctx, []Check{ {Scope: ScopeProjectRead, ResourceID: "proj:123"}, @@ -104,7 +103,7 @@ func TestFilter_withDimensions(t *testing.T) { ctx = GrantsToContext(ctx, grants) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) results, err := engine.Filter(ctx, []Check{ MCPToolCallCheck("toolsetX", MCPToolCallDimensions{Tool: "allowed_tool", Disposition: ""}), diff --git a/server/internal/authz/load_test.go b/server/internal/authz/load_test.go index c09a83f51a..1e39a976ee 100644 --- a/server/internal/authz/load_test.go +++ b/server/internal/authz/load_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/speakeasy-api/gram/server/internal/cache" + accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/testenv" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" "github.com/speakeasy-api/gram/server/internal/urn" @@ -30,11 +30,31 @@ func TestLoadGrants_loadsUserAndRoleGrants(t *testing.T) { ctx = GrantsToContext(ctx, grants) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) require.NoError(t, engine.Require(ctx, Check{Scope: ScopeProjectRead, ResourceID: "proj:123"})) require.NoError(t, engine.Require(ctx, Check{Scope: ScopeMCPConnect, ResourceID: "toolA"})) } +func TestSeedSystemRoleGrantsBootstrapsGlobalRoles(t *testing.T) { + t.Parallel() + + ctx := t.Context() + conn := newTestDB(t) + organizationID := "org_seed_system_roles" + seedOrganization(t, ctx, conn, organizationID) + + err := SeedSystemRoleGrants(ctx, testenv.NewLogger(t), conn, organizationID) + require.NoError(t, err) + + adminRole, err := accessrepo.New(conn).GetGlobalRoleBySlug(ctx, SystemRoleAdmin) + require.NoError(t, err) + require.Equal(t, "Admin", adminRole.WorkosName) + + grants, err := GrantsForRole(ctx, testenv.NewLogger(t), conn, organizationID, SystemRoleAdmin, "role:global:"+adminRole.ID.String()) + require.NoError(t, err) + require.NotEmpty(t, grants) +} + func TestLoadGrants_rejectsEmptyOrganizationID(t *testing.T) { t.Parallel() @@ -93,7 +113,7 @@ func TestLoadGrants_returnsEmptyGrantSetWhenNoRowsMatch(t *testing.T) { ctx = GrantsToContext(ctx, grants) chConn, err := newClickhouseClient(t) require.NoError(t, err) - engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient(), cache.NoopCache) + engine := NewEngine(testenv.NewLogger(t), conn, chConn, rbacAlwaysEnabled, challengeLoggingAlwaysEnabled, workos.NewStubClient()) projectIDs, err := engine.Filter(ctx, []Check{ {Scope: ScopeProjectRead, ResourceID: "proj:123"}, }) diff --git a/server/internal/background/activities/backfill_workos_organization.go b/server/internal/background/activities/backfill_workos_organization.go index 56d535a15d..e4d2459510 100644 --- a/server/internal/background/activities/backfill_workos_organization.go +++ b/server/internal/background/activities/backfill_workos_organization.go @@ -196,7 +196,7 @@ func backfillOrganizationRoles(ctx context.Context, logger *slog.Logger, dbtx pg continue } - if err := repo.UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + if _, err := repo.UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ OrganizationID: organizationID, WorkosSlug: role.Slug, WorkosName: role.Name, diff --git a/server/internal/background/activities/backfill_workos_test.go b/server/internal/background/activities/backfill_workos_test.go index c9bb5d74fa..9eed9407c5 100644 --- a/server/internal/background/activities/backfill_workos_test.go +++ b/server/internal/background/activities/backfill_workos_test.go @@ -394,7 +394,7 @@ func seedOrganizationRoleWithCursor(t *testing.T, ctx context.Context, conn *pgx t.Helper() updatedAt := time.Date(2026, 5, 7, 10, 0, 0, 0, time.UTC) - err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + _, err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ OrganizationID: organizationID, WorkosSlug: slug, WorkosName: name, diff --git a/server/internal/background/activities/process_workos_org_events.go b/server/internal/background/activities/process_workos_org_events.go index 38775a4114..e9189e75c1 100644 --- a/server/internal/background/activities/process_workos_org_events.go +++ b/server/internal/background/activities/process_workos_org_events.go @@ -15,6 +15,7 @@ import ( accessrepo "github.com/speakeasy-api/gram/server/internal/access/repo" "github.com/speakeasy-api/gram/server/internal/attr" "github.com/speakeasy-api/gram/server/internal/auth/orgslug" + "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/conv" "github.com/speakeasy-api/gram/server/internal/database" "github.com/speakeasy-api/gram/server/internal/o11y" @@ -490,7 +491,7 @@ func upsertOrganizationRole(ctx context.Context, logger *slog.Logger, dbtx datab return nil } - if err := repo.UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + if _, err := repo.UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ OrganizationID: org.ID, WorkosSlug: payload.Slug, WorkosName: payload.Name, @@ -558,10 +559,8 @@ func handleRoleDeleted(ctx context.Context, logger *slog.Logger, dbtx database.D return fmt.Errorf("mark organization role %q deleted: %w", payload.Slug, err) } - if _, err := repo.DeletePrincipalGrantsByPrincipal(ctx, accessrepo.DeletePrincipalGrantsByPrincipalParams{ - OrganizationID: org.ID, - PrincipalUrn: urn.NewPrincipal(urn.PrincipalTypeRole, payload.Slug), - }); err != nil { + rolePrincipal := urn.NewPrincipal(urn.PrincipalTypeRole, "organization:"+existing.ID.String()) + if err := authz.DeleteRoleGrants(ctx, repo, org.ID, payload.Slug, rolePrincipal.String()); err != nil { return fmt.Errorf("delete grants for role %q: %w", payload.Slug, err) } diff --git a/server/internal/background/activities/process_workos_org_events_test.go b/server/internal/background/activities/process_workos_org_events_test.go index aef20f75bc..8e4b4be8d4 100644 --- a/server/internal/background/activities/process_workos_org_events_test.go +++ b/server/internal/background/activities/process_workos_org_events_test.go @@ -1027,7 +1027,7 @@ func seedOrganizationRole(t *testing.T, ctx context.Context, conn *pgxpool.Pool, t.Helper() eventTime := time.Date(2026, 5, 6, 10, 0, 0, 0, time.UTC) - err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ + _, err := accessrepo.New(conn).UpsertOrganizationRole(ctx, accessrepo.UpsertOrganizationRoleParams{ OrganizationID: organizationID, WorkosSlug: slug, WorkosName: slug, diff --git a/server/internal/collections/setup_test.go b/server/internal/collections/setup_test.go index 8f52687328..7a92083df9 100644 --- a/server/internal/collections/setup_test.go +++ b/server/internal/collections/setup_test.go @@ -76,7 +76,7 @@ func newTestCollectionsService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) auditLogger := audit.NewLogger() svc := collections.NewService(logger, tracerProvider, conn, sessionManager, authzEngine, auditLogger, testenv.DefaultSiteURL(t)) diff --git a/server/internal/conv/from.go b/server/internal/conv/from.go index 060d232b40..a27bb9348a 100644 --- a/server/internal/conv/from.go +++ b/server/internal/conv/from.go @@ -163,6 +163,16 @@ func PtrToPGTimestamptz(t *time.Time) pgtype.Timestamptz { return pgtype.Timestamptz{Time: *t, Valid: true, InfinityModifier: pgtype.Finite} } +// FromPGTimestamptz converts a pgtype.Timestamptz to an RFC3339 UTC string. If +// the value is not valid, it returns an empty string. +func FromPGTimestamptz(t pgtype.Timestamptz) string { + if !t.Valid { + return "" + } + + return t.Time.UTC().Format(time.RFC3339) +} + // PtrToPGInterval converts a *time.Duration to a pgtype.Interval. If the // pointer is nil, the result has Valid set to false (which becomes SQL NULL). func PtrToPGInterval(d *time.Duration) pgtype.Interval { diff --git a/server/internal/conv/from_test.go b/server/internal/conv/from_test.go index d2cd0a18af..8a3a567c31 100644 --- a/server/internal/conv/from_test.go +++ b/server/internal/conv/from_test.go @@ -2,6 +2,7 @@ package conv_test import ( "testing" + "time" "github.com/jackc/pgx/v5/pgtype" "github.com/speakeasy-api/gram/server/internal/conv" @@ -45,6 +46,20 @@ func TestPtrInt32ToInt_Nil(t *testing.T) { require.Nil(t, result) } +func TestFromPGTimestamptz_Valid(t *testing.T) { + t.Parallel() + + input := pgtype.Timestamptz{Time: time.Date(2024, 11, 15, 15, 4, 5, 0, time.FixedZone("test", 2*60*60)), Valid: true} + + require.Equal(t, "2024-11-15T13:04:05Z", conv.FromPGTimestamptz(input)) +} + +func TestFromPGTimestamptz_Invalid(t *testing.T) { + t.Parallel() + + require.Empty(t, conv.FromPGTimestamptz(pgtype.Timestamptz{})) +} + func TestURLToSlug_HostAndPath(t *testing.T) { t.Parallel() diff --git a/server/internal/customdomains/setup_test.go b/server/internal/customdomains/setup_test.go index 515eea1a2a..35589b46b6 100644 --- a/server/internal/customdomains/setup_test.go +++ b/server/internal/customdomains/setup_test.go @@ -80,7 +80,7 @@ func newTestCustomDomainsService(t *testing.T) (context.Context, *serviceTestIns temporal := &stubTemporalClient{} chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) auditLogger := audit.NewLogger() svc := customdomains.NewService(logger, tracerProvider, conn, sessionManager, temporal, authzEngine, auditLogger) diff --git a/server/internal/deployments/setup_test.go b/server/internal/deployments/setup_test.go index 3b73748857..2c918e5e9c 100644 --- a/server/internal/deployments/setup_test.go +++ b/server/internal/deployments/setup_test.go @@ -106,7 +106,7 @@ func newTestDeploymentService(t *testing.T, assetStorage assets.BlobStore) (cont chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, posthog, testenv.DefaultSiteURL(t), mcpRegistryClient, authzEngine, auditLogger) assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret", authzEngine, auditLogger) packagesSvc := packages.NewService(logger, tracerProvider, conn, sessionManager, authzEngine) diff --git a/server/internal/environments/setup_test.go b/server/internal/environments/setup_test.go index 7bfdd9e5c4..dc4cc7aa30 100644 --- a/server/internal/environments/setup_test.go +++ b/server/internal/environments/setup_test.go @@ -74,7 +74,7 @@ func newTestEnvironmentService(t *testing.T) (context.Context, *testInstance) { require.NoError(t, err) auditLogger := audit.NewLogger() - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := environments.NewService(logger, tracerProvider, conn, sessionManager, enc, authzEngine, auditLogger) return ctx, &testInstance{ diff --git a/server/internal/externalmcp/setup_test.go b/server/internal/externalmcp/setup_test.go index 94501825fd..acbfb7222d 100644 --- a/server/internal/externalmcp/setup_test.go +++ b/server/internal/externalmcp/setup_test.go @@ -72,7 +72,7 @@ func newTestExternalMCPService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := externalmcp.NewService(logger, tracerProvider, conn, sessionManager, mcpRegistryClient, authzEngine) return ctx, &testInstance{ diff --git a/server/internal/functions/setup_test.go b/server/internal/functions/setup_test.go index 8fcde5600b..c184440068 100644 --- a/server/internal/functions/setup_test.go +++ b/server/internal/functions/setup_test.go @@ -116,7 +116,7 @@ func newTestFunctionsService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := functions.NewService(logger, tracerProvider, conn, enc, tigrisStore) deploymentsSvc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, ph, testenv.DefaultSiteURL(t), mcpRegistryClient, authzEngine, auditLogger) assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret", authzEngine, auditLogger) diff --git a/server/internal/hooks/setup_test.go b/server/internal/hooks/setup_test.go index e526ac8917..ff2169170d 100644 --- a/server/internal/hooks/setup_test.go +++ b/server/internal/hooks/setup_test.go @@ -80,7 +80,7 @@ func newTestHooksService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) chatWriter, chatWriterShutdown := chat.NewChatMessageWriter(logger, conn, nil) t.Cleanup(func() { _ = chatWriterShutdown(t.Context()) }) shadowMCPClient := shadowmcp.NewClient(logger, conn, cacheAdapter) diff --git a/server/internal/keys/setup_test.go b/server/internal/keys/setup_test.go index cb10e8a28b..4249fa69d1 100644 --- a/server/internal/keys/setup_test.go +++ b/server/internal/keys/setup_test.go @@ -78,7 +78,7 @@ func newTestKeysService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) auditLogger := audit.NewLogger() svc := keys.NewService(logger, tracerProvider, conn, sessionManager, "local", authzEngine, auditLogger) keyAuth := auth.NewKeyAuth(conn, logger, billingClient) diff --git a/server/internal/mcp/handle_get_server_test.go b/server/internal/mcp/handle_get_server_test.go index 73a580a1d4..59d01f33a1 100644 --- a/server/internal/mcp/handle_get_server_test.go +++ b/server/internal/mcp/handle_get_server_test.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/require" "github.com/speakeasy-api/gram/server/internal/authz" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/mcpmetadata" "github.com/speakeasy-api/gram/server/internal/thirdparty/workos" ) @@ -33,7 +32,7 @@ func TestHandleGetServer_ContentNegotiation(t *testing.T) { testInstance.serverURL, testInstance.siteURL, testInstance.cacheAdapter, - authz.NewEngine(testInstance.logger, testInstance.conn, chConn, nil, nil, workos.NewStubClient(), cache.NoopCache), + authz.NewEngine(testInstance.logger, testInstance.conn, chConn, nil, nil, workos.NewStubClient()), testInstance.audit, ) diff --git a/server/internal/mcp/rbac_test.go b/server/internal/mcp/rbac_test.go index 87456f1057..31ed3ef5da 100644 --- a/server/internal/mcp/rbac_test.go +++ b/server/internal/mcp/rbac_test.go @@ -11,7 +11,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/conv" "github.com/speakeasy-api/gram/server/internal/oops" @@ -27,7 +26,7 @@ func TestServePublic_RBAC_PrivateMCP_DeniedWithNoGrants(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx) err = authzEngine.Require(ctx, authz.Check{Scope: authz.ScopeMCPConnect, ResourceID: toolset.ID.String()}) @@ -44,7 +43,7 @@ func TestServePublic_RBAC_PrivateMCP_DeniedWithUnrelatedGrant(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{Scope: authz.ScopeMCPConnect, Selector: authz.NewSelector(authz.ScopeMCPConnect, uuid.NewString())}) err = authzEngine.Require(ctx, authz.Check{Scope: authz.ScopeMCPConnect, ResourceID: toolset.ID.String()}) @@ -61,7 +60,7 @@ func TestServePublic_RBAC_PrivateMCP_AllowedWithWriteGrant(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{Scope: authz.ScopeMCPWrite, Selector: authz.NewSelector(authz.ScopeMCPWrite, toolset.ID.String())}) err = authzEngine.Require(ctx, authz.Check{Scope: authz.ScopeMCPConnect, ResourceID: toolset.ID.String()}) @@ -76,7 +75,7 @@ func TestServePublic_RBAC_PrivateMCP_AllowedWithConnectGrant(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{Scope: authz.ScopeMCPConnect, Selector: authz.NewSelector(authz.ScopeMCPConnect, toolset.ID.String())}) err = authzEngine.Require(ctx, authz.Check{Scope: authz.ScopeMCPConnect, ResourceID: toolset.ID.String()}) @@ -109,7 +108,7 @@ func TestServePublic_RBAC_ToolLevelGrant_AllowsMatchingTool(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, Selector: authz.Selector{ @@ -136,7 +135,7 @@ func TestServePublic_RBAC_ToolLevelGrant_DeniesWrongTool(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, Selector: authz.Selector{ @@ -165,7 +164,7 @@ func TestServePublic_RBAC_ServerLevelGrant_AllowsAnyTool(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Server-level grant (no tool key) should allow any tool. ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, @@ -188,7 +187,7 @@ func TestServePublic_RBAC_ToolLevelGrant_DeniesWrongServer(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, Selector: authz.Selector{ @@ -217,7 +216,7 @@ func TestServePublic_RBAC_ProjectScopedGrant_AllowsServerInProject(t *testing.T) chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Project-scoped grant: resource_id=*, project_id=toolset's project ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, @@ -241,7 +240,7 @@ func TestServePublic_RBAC_ProjectScopedGrant_DeniesServerInOtherProject(t *testi chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Project-scoped grant for a *different* project otherProjectID := uuid.New() ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ @@ -268,7 +267,7 @@ func TestServePublic_RBAC_ProjectScopedGrant_AllowsToolCallInProject(t *testing. chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ Scope: authz.ScopeMCPConnect, Selector: authz.Selector{ diff --git a/server/internal/mcp/rpc_tools_list_test.go b/server/internal/mcp/rpc_tools_list_test.go index f384ece810..ff6b4f4432 100644 --- a/server/internal/mcp/rpc_tools_list_test.go +++ b/server/internal/mcp/rpc_tools_list_test.go @@ -13,7 +13,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" deployments_repo "github.com/speakeasy-api/gram/server/internal/deployments/repo" "github.com/speakeasy-api/gram/server/internal/oops" @@ -211,7 +210,7 @@ func TestServePublic_RBAC_ToolsList_FiltersToGrantedTools(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Grant mcp:connect only for "allowed_tool". ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ @@ -248,7 +247,7 @@ func TestServePublic_RBAC_ToolsList_ServerLevelGrantReturnsAll(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Server-level grant (no tool dimension) — all tools allowed. ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ @@ -278,7 +277,7 @@ func TestServePublic_RBAC_ToolsList_NoGrantsDenied(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // mcp:connect grant for a DIFFERENT toolset — should not match. ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ @@ -303,7 +302,7 @@ func TestServePublic_RBAC_ToolsList_MultipleToolGrants(t *testing.T) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Grant access to tool_a and tool_c but not tool_b. ctx = authztest.WithExactGrants(t, ctx, @@ -348,7 +347,7 @@ func TestServePublic_RBAC_ToolsList_DispositionGrant_AllowsMatchingDisposition(t chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Grant mcp:connect scoped to read_only disposition only. ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ @@ -386,7 +385,7 @@ func TestServePublic_RBAC_ToolsList_DisabledRBACAllowsAll(t *testing.T) { // Engine with RBAC disabled — simulates org without RBAC feature flag. chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // No grants in context at all. With RBAC disabled, every tool should pass. for _, tool := range []string{"tool_one", "tool_two", "tool_three"} { @@ -406,7 +405,7 @@ func TestServePublic_RBAC_ToolsList_DispositionGrant_ServerLevelAllowsAll(t *tes chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(ti.logger, ti.conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) // Server-level grant (no disposition key) — all dispositions allowed. ctx = authztest.WithExactGrants(t, ctx, authz.Grant{ diff --git a/server/internal/mcp/setup_test.go b/server/internal/mcp/setup_test.go index 26068a5a4c..a912bdd7c5 100644 --- a/server/internal/mcp/setup_test.go +++ b/server/internal/mcp/setup_test.go @@ -162,7 +162,7 @@ func newTestMCPServiceWithIdentityResolver(t *testing.T, identityResolver mcp.Id chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) telemLogger := telemetry.NewLogger(ctx, logger, chConn, logsEnabled, toolIOLogsEnabled) telemService := telemetry.NewService( diff --git a/server/internal/mcpendpoints/setup_test.go b/server/internal/mcpendpoints/setup_test.go index fd061348b7..3d57ff66c5 100644 --- a/server/internal/mcpendpoints/setup_test.go +++ b/server/internal/mcpendpoints/setup_test.go @@ -79,7 +79,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := mcpendpoints.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), auditLogger) + svc := mcpendpoints.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), auditLogger) return ctx, &testInstance{ service: svc, diff --git a/server/internal/mcpmetadata/setup_test.go b/server/internal/mcpmetadata/setup_test.go index 025fc995e8..a985f84668 100644 --- a/server/internal/mcpmetadata/setup_test.go +++ b/server/internal/mcpmetadata/setup_test.go @@ -89,7 +89,7 @@ func newTestMCPMetadataService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := mcpmetadata.NewService(logger, tracerProvider, conn, sessionManager, serverURL, siteURL, cacheAdapter, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), auditLogger) + svc := mcpmetadata.NewService(logger, tracerProvider, conn, sessionManager, serverURL, siteURL, cacheAdapter, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), auditLogger) return ctx, &testInstance{ service: svc, diff --git a/server/internal/mcpservers/setup_test.go b/server/internal/mcpservers/setup_test.go index 6a7de85436..90ab41a553 100644 --- a/server/internal/mcpservers/setup_test.go +++ b/server/internal/mcpservers/setup_test.go @@ -77,7 +77,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := mcpservers.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), auditLogger) + svc := mcpservers.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), auditLogger) return ctx, &testInstance{ service: svc, diff --git a/server/internal/organizations/setup_test.go b/server/internal/organizations/setup_test.go index df5ef5a17f..94554ad4d3 100644 --- a/server/internal/organizations/setup_test.go +++ b/server/internal/organizations/setup_test.go @@ -107,7 +107,7 @@ func newTestOrganizationsService(t *testing.T) (context.Context, *testInstance) chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, thirdpartyworkos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, thirdpartyworkos.NewStubClient()) auditLogger := audit.NewLogger() @@ -163,7 +163,7 @@ func newTestOrganizationsServiceRBAC(t *testing.T) (context.Context, *testInstan chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, thirdpartyworkos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, thirdpartyworkos.NewStubClient()) auditLogger := audit.NewLogger() @@ -217,7 +217,7 @@ func newTestOrganizationsServiceWithEmail(t *testing.T) (context.Context, *testI chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, thirdpartyworkos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, thirdpartyworkos.NewStubClient()) auditLogger := audit.NewLogger() diff --git a/server/internal/packages/setup_test.go b/server/internal/packages/setup_test.go index a09eb67186..ccb5c4ee2e 100644 --- a/server/internal/packages/setup_test.go +++ b/server/internal/packages/setup_test.go @@ -70,7 +70,7 @@ func newTestPackagesService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := packages.NewService(logger, tracerProvider, conn, sessionManager, authzEngine) return ctx, &testInstance{ diff --git a/server/internal/plugins/setup_test.go b/server/internal/plugins/setup_test.go index 86280926e5..4085932e01 100644 --- a/server/internal/plugins/setup_test.go +++ b/server/internal/plugins/setup_test.go @@ -93,7 +93,7 @@ func newTestPluginsService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := plugins.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), auditLogger, nil, "local", "https://app.getgram.ai") + svc := plugins.NewService(logger, tracerProvider, conn, sessionManager, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), auditLogger, nil, "local", "https://app.getgram.ai") return ctx, &testInstance{ service: svc, @@ -147,7 +147,7 @@ func newTestPluginsServiceWithGitHub(t *testing.T, ghClient plugins.GitHubPublis tracerProvider, conn, sessionManager, - authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), + authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), auditLogger, ghConfig, "local", diff --git a/server/internal/productfeatures/setup_test.go b/server/internal/productfeatures/setup_test.go index 1599f2022c..4d68968047 100644 --- a/server/internal/productfeatures/setup_test.go +++ b/server/internal/productfeatures/setup_test.go @@ -86,7 +86,7 @@ func newTestProductFeaturesService(t *testing.T) (context.Context, *testInstance chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := productfeatures.NewService(logger, tracerProvider, conn, sessionManager, redisClient, authzEngine) return ctx, &testInstance{ diff --git a/server/internal/projects/setup_test.go b/server/internal/projects/setup_test.go index 3b98287bca..427b312205 100644 --- a/server/internal/projects/setup_test.go +++ b/server/internal/projects/setup_test.go @@ -106,7 +106,6 @@ func newTestProjectsService(t *testing.T, enableRBAC bool) (context.Context, *te }, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), - cache.NoopCache, ), auditLogger, ) diff --git a/server/internal/remotemcp/setup_test.go b/server/internal/remotemcp/setup_test.go index db4411260c..8897c5c46a 100644 --- a/server/internal/remotemcp/setup_test.go +++ b/server/internal/remotemcp/setup_test.go @@ -101,7 +101,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { auditLogger := audit.NewLogger() - svc := remotemcp.NewService(logger, tracerProvider, conn, sessionManager, enc, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), servicePolicy, auditLogger) + svc := remotemcp.NewService(logger, tracerProvider, conn, sessionManager, enc, authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), servicePolicy, auditLogger) return ctx, &testInstance{ service: svc, diff --git a/server/internal/remotesessions/clienthandlers.go b/server/internal/remotesessions/clienthandlers.go index c128d43a9e..dec87a9802 100644 --- a/server/internal/remotesessions/clienthandlers.go +++ b/server/internal/remotesessions/clienthandlers.go @@ -515,4 +515,3 @@ func (s *Service) DeleteRemoteSessionClient(ctx context.Context, payload *gen.De return nil } - diff --git a/server/internal/remotesessions/setup_test.go b/server/internal/remotesessions/setup_test.go index 4ba2e481cb..5837096334 100644 --- a/server/internal/remotesessions/setup_test.go +++ b/server/internal/remotesessions/setup_test.go @@ -92,7 +92,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { tracerProvider, conn, sessionManager, - authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), + authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), enc, envEntries, guardianPolicy, diff --git a/server/internal/resources/setup_test.go b/server/internal/resources/setup_test.go index cfbc788804..8900e4abb2 100644 --- a/server/internal/resources/setup_test.go +++ b/server/internal/resources/setup_test.go @@ -69,7 +69,7 @@ func newTestResourcesService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := resources.NewService(logger, tracerProvider, conn, sessionManager, authzEngine) return ctx, &testInstance{ diff --git a/server/internal/risk/setup_test.go b/server/internal/risk/setup_test.go index 9dd37d66cd..46dabe815e 100644 --- a/server/internal/risk/setup_test.go +++ b/server/internal/risk/setup_test.go @@ -89,7 +89,7 @@ func newTestRiskService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) shadowMCPClient := shadowmcp.NewClient(logger, conn, cache.NewRedisCacheAdapter(redisClient)) auditLogger := audit.NewLogger() diff --git a/server/internal/telemetry/setup_test.go b/server/internal/telemetry/setup_test.go index 0d2e8efe41..b24390f433 100644 --- a/server/internal/telemetry/setup_test.go +++ b/server/internal/telemetry/setup_test.go @@ -114,7 +114,7 @@ func newTestLogsService(t *testing.T) (context.Context, *testInstance) { posthogClient := posthog.New(ctx, logger, "test-posthog-key", "test-posthog-host", "") telemLogger := telemetry.NewLogger(ctx, logger, chConn, logsEnabled, toolIOLogsEnabled) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := telemetry.NewService(logger, tracerProvider, conn, chConn, sessionManager, chatSessionsManager, logsEnabled, sessionCaptureEnabled, posthogClient, authzEngine) return ctx, &testInstance{ diff --git a/server/internal/templates/setup_test.go b/server/internal/templates/setup_test.go index cc7e6d1181..6efe62a90f 100644 --- a/server/internal/templates/setup_test.go +++ b/server/internal/templates/setup_test.go @@ -80,7 +80,7 @@ func newTestTemplateService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) auditLogger := audit.NewLogger() svc := templates.NewService(logger, tracerProvider, conn, sessionManager, toolsetsSvc, authzEngine, auditLogger) diff --git a/server/internal/tools/setup_test.go b/server/internal/tools/setup_test.go index b2b0e46091..e08bc462ae 100644 --- a/server/internal/tools/setup_test.go +++ b/server/internal/tools/setup_test.go @@ -115,7 +115,7 @@ func newTestToolsService(t *testing.T, assetStorage assets.BlobStore) (context.C chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) toolsSvc := tools.NewService(logger, tracerProvider, conn, sessionManager, authzEngine, nil, nil) deploymentsSvc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, posthog, testenv.DefaultSiteURL(t), mcpRegistryClient, authzEngine, auditLogger) assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret", authzEngine, auditLogger) diff --git a/server/internal/toolsets/setup_test.go b/server/internal/toolsets/setup_test.go index ebb69c82cd..273a9a23ca 100644 --- a/server/internal/toolsets/setup_test.go +++ b/server/internal/toolsets/setup_test.go @@ -124,7 +124,7 @@ func newTestToolsetsService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) svc := toolsets.NewService(logger, tracerProvider, conn, sessionManager, nil, authzEngine, auditLogger) deploymentsSvc := deployments.NewService(logger, tracerProvider, conn, temporalEnv, sessionManager, assetStorage, posthog, testenv.DefaultSiteURL(t), mcpRegistryClient, authzEngine, auditLogger) assetsSvc := assets.NewService(logger, tracerProvider, guardianPolicy, conn, sessionManager, chatSessionsManager, assetStorage, "test-jwt-secret", authzEngine, auditLogger) diff --git a/server/internal/triggers/setup_test.go b/server/internal/triggers/setup_test.go index b6872220e8..3b053b890f 100644 --- a/server/internal/triggers/setup_test.go +++ b/server/internal/triggers/setup_test.go @@ -117,7 +117,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { tracerProvider, conn, sessionManager, - authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), + authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), app, auditLogger, ) diff --git a/server/internal/usage/impl_test.go b/server/internal/usage/impl_test.go index 31b71aeba2..231aa7f1e3 100644 --- a/server/internal/usage/impl_test.go +++ b/server/internal/usage/impl_test.go @@ -17,7 +17,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" "github.com/speakeasy-api/gram/server/internal/billing" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/testenv" @@ -144,7 +143,7 @@ func newTestService(t *testing.T, billingRepo billing.Repository, orgID string, chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, db, chConn, rbacDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, db, chConn, rbacDisabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) return &Service{ tracer: tp.Tracer("test"), @@ -204,7 +203,7 @@ func TestGetPeriodUsage_CacheHit(t *testing.T) { cached := sampleUsage(42, 2, 10) billingMock := &mockBillingRepo{} - billingMock.On("GetStoredPeriodUsage", mock.Anything, orgID).Return(cached, nil) + billingMock.On("GetStoredPeriodUsage", mock.Anything, orgID).Return(cached) // GetPeriodUsage should NOT be called svc := newTestService(t, billingMock, orgID, 5) @@ -226,7 +225,7 @@ func TestGetPeriodUsage_CacheMissFallback(t *testing.T) { billingMock := &mockBillingRepo{} billingMock.On("GetStoredPeriodUsage", mock.Anything, orgID).Return(nil, fmt.Errorf("cache miss")) - billingMock.On("GetPeriodUsage", mock.Anything, orgID).Return(fresh, nil) + billingMock.On("GetPeriodUsage", mock.Anything, orgID).Return(fresh) svc := newTestService(t, billingMock, orgID, 3) ctx := testAuthContext(orgID) @@ -280,7 +279,7 @@ func TestGetPeriodUsage_ActualServerCountFromDB(t *testing.T) { cached.ActualEnabledServerCount = 999 // cached value should be overridden billingMock := &mockBillingRepo{} - billingMock.On("GetStoredPeriodUsage", mock.Anything, orgID).Return(cached, nil) + billingMock.On("GetStoredPeriodUsage", mock.Anything, orgID).Return(cached) svc := newTestService(t, billingMock, orgID, 7) // DB says 7 ctx := testAuthContext(orgID) diff --git a/server/internal/usersessions/setup_test.go b/server/internal/usersessions/setup_test.go index 296dcaf094..656b19437d 100644 --- a/server/internal/usersessions/setup_test.go +++ b/server/internal/usersessions/setup_test.go @@ -88,7 +88,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { conn, sessionManager, chatSessionsManager, - authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache), + authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()), audit.NewLogger(), ) diff --git a/server/internal/variations/setup_test.go b/server/internal/variations/setup_test.go index 9b8e02cfbc..7a7a83f301 100644 --- a/server/internal/variations/setup_test.go +++ b/server/internal/variations/setup_test.go @@ -70,7 +70,7 @@ func newTestVariationsService(t *testing.T) (context.Context, *testInstance) { chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) auditLogger := audit.NewLogger() svc := variations.NewService(logger, tracerProvider, conn, sessionManager, authzEngine, auditLogger) diff --git a/server/internal/xmcp/setup_test.go b/server/internal/xmcp/setup_test.go index 7a281978aa..b9ee385dc7 100644 --- a/server/internal/xmcp/setup_test.go +++ b/server/internal/xmcp/setup_test.go @@ -118,7 +118,7 @@ func newTestService(t *testing.T) (context.Context, *testInstance) { enc := testenv.NewEncryptionClient(t) chConn, err := infra.NewClickhouseClient(t) require.NoError(t, err) - authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + authzEngine := authz.NewEngine(logger, conn, chConn, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) mcpMetadataRepo := mcpmetadatarepo.New(conn) env := environments.NewEnvironmentEntries(logger, conn, enc, mcpMetadataRepo) diff --git a/server/internal/xmcp/tools_call_authz_interceptor_test.go b/server/internal/xmcp/tools_call_authz_interceptor_test.go index 1944adebfb..836608b821 100644 --- a/server/internal/xmcp/tools_call_authz_interceptor_test.go +++ b/server/internal/xmcp/tools_call_authz_interceptor_test.go @@ -8,7 +8,6 @@ import ( "github.com/speakeasy-api/gram/server/internal/authz" "github.com/speakeasy-api/gram/server/internal/authztest" - "github.com/speakeasy-api/gram/server/internal/cache" "github.com/speakeasy-api/gram/server/internal/contextvalues" "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/remotemcp/proxy" @@ -24,7 +23,7 @@ const ( func newAuthzEngineForTest(t *testing.T) *authz.Engine { t.Helper() - return authz.NewEngine(testenv.NewLogger(t), nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient(), cache.NoopCache) + return authz.NewEngine(testenv.NewLogger(t), nil, nil, authztest.RBACAlwaysEnabled, authztest.ChallengeLoggingAlwaysDisabled, workos.NewStubClient()) } func authzAuthContext(t *testing.T) *contextvalues.AuthContext {