diff --git a/core/membership/mocks/policy_service.go b/core/membership/mocks/policy_service.go index ab81d7b31..c2e1b4dac 100644 --- a/core/membership/mocks/policy_service.go +++ b/core/membership/mocks/policy_service.go @@ -127,6 +127,54 @@ func (_c *PolicyService_Delete_Call) RunAndReturn(run func(context.Context, stri return _c } +// DeleteWithMinRoleGuard provides a mock function with given fields: ctx, id, guardRoleID +func (_m *PolicyService) DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error { + ret := _m.Called(ctx, id, guardRoleID) + + if len(ret) == 0 { + panic("no return value specified for DeleteWithMinRoleGuard") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, id, guardRoleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PolicyService_DeleteWithMinRoleGuard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteWithMinRoleGuard' +type PolicyService_DeleteWithMinRoleGuard_Call struct { + *mock.Call +} + +// DeleteWithMinRoleGuard is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - guardRoleID string +func (_e *PolicyService_Expecter) DeleteWithMinRoleGuard(ctx interface{}, id interface{}, guardRoleID interface{}) *PolicyService_DeleteWithMinRoleGuard_Call { + return &PolicyService_DeleteWithMinRoleGuard_Call{Call: _e.mock.On("DeleteWithMinRoleGuard", ctx, id, guardRoleID)} +} + +func (_c *PolicyService_DeleteWithMinRoleGuard_Call) Run(run func(ctx context.Context, id string, guardRoleID string)) *PolicyService_DeleteWithMinRoleGuard_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *PolicyService_DeleteWithMinRoleGuard_Call) Return(_a0 error) *PolicyService_DeleteWithMinRoleGuard_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *PolicyService_DeleteWithMinRoleGuard_Call) RunAndReturn(run func(context.Context, string, string) error) *PolicyService_DeleteWithMinRoleGuard_Call { + _c.Call.Return(run) + return _c +} + // List provides a mock function with given fields: ctx, flt func (_m *PolicyService) List(ctx context.Context, flt policy.Filter) ([]policy.Policy, error) { ret := _m.Called(ctx, flt) diff --git a/core/membership/service.go b/core/membership/service.go index c81b0ebc1..b4fc49f93 100644 --- a/core/membership/service.go +++ b/core/membership/service.go @@ -28,6 +28,7 @@ type PolicyService interface { Create(ctx context.Context, pol policy.Policy) (policy.Policy, error) List(ctx context.Context, flt policy.Filter) ([]policy.Policy, error) Delete(ctx context.Context, id string) error + DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error } type RelationService interface { @@ -217,11 +218,12 @@ func (s *Service) SetOrganizationMemberRole(ctx context.Context, orgID, principa return nil } - if err := s.validateMinOwnerConstraint(ctx, orgID, resolvedRoleID, existing); err != nil { + ownerRoleID, err := s.validateMinOwnerConstraint(ctx, orgID, resolvedRoleID, existing) + if err != nil { return err } - if err := s.replacePolicy(ctx, orgID, schema.OrganizationNamespace, principalID, principalType, resolvedRoleID, existing); err != nil { + if err := s.replacePolicyWithOwnerGuard(ctx, orgID, schema.OrganizationNamespace, principalID, principalType, resolvedRoleID, existing, ownerRoleID); err != nil { return err } @@ -271,7 +273,8 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal return ErrNotMember } - if err = s.validateMinOwnerConstraint(ctx, orgID, "", orgPolicies); err != nil { + ownerRoleID, err := s.validateMinOwnerConstraint(ctx, orgID, "", orgPolicies) + if err != nil { return err } @@ -303,30 +306,53 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal return fmt.Errorf("list all principal policies: %w", err) } - // delete sub-resource policies first (projects, groups), then relations, - // then org policies last — so a retry after partial failure won't hit ErrNotMember - var orgPolicyIDs []string - var errs error + // Phase 1: classify policies by type (no deletions yet) + var orgPoliciesToDelete []policy.Policy + var subResourcePolicies []policy.Policy for _, pol := range allPolicies { switch pol.ResourceType { case schema.OrganizationNamespace: if pol.ResourceID == orgID { - orgPolicyIDs = append(orgPolicyIDs, pol.ID) + orgPoliciesToDelete = append(orgPoliciesToDelete, pol) } case schema.ProjectNamespace: if _, ok := orgProjectIDSet[pol.ResourceID]; ok { - if err := s.policyService.Delete(ctx, pol.ID); err != nil { - errs = errors.Join(errs, fmt.Errorf("delete project policy %s: %w", pol.ID, err)) - } + subResourcePolicies = append(subResourcePolicies, pol) } case schema.GroupNamespace: if _, ok := orgGroupIDSet[pol.ResourceID]; ok { - if err := s.policyService.Delete(ctx, pol.ID); err != nil { - errs = errors.Join(errs, fmt.Errorf("delete group policy %s: %w", pol.ID, err)) + subResourcePolicies = append(subResourcePolicies, pol) + } + } + } + + // Phase 2: guarded org owner policy delete runs first — if the guard rejects + // (last owner), we return early before touching any other policies or relations + for _, pol := range orgPoliciesToDelete { + if pol.RoleID == ownerRoleID { + if err := s.policyService.DeleteWithMinRoleGuard(ctx, pol.ID, ownerRoleID); err != nil { + if errors.Is(err, policy.ErrLastRoleGuard) { + return ErrLastOwnerRole } + return fmt.Errorf("delete org policy %s: %w", pol.ID, err) + } + } + } + + // Phase 3: guard passed — safe to delete remaining org policies, sub-resource policies + for _, pol := range orgPoliciesToDelete { + if pol.RoleID != ownerRoleID { + if err := s.policyService.Delete(ctx, pol.ID); err != nil { + return fmt.Errorf("delete org policy %s: %w", pol.ID, err) } } } + var errs error + for _, pol := range subResourcePolicies { + if err := s.policyService.Delete(ctx, pol.ID); err != nil { + errs = errors.Join(errs, fmt.Errorf("delete sub-resource policy %s: %w", pol.ID, err)) + } + } if errs != nil { s.log.Error("partial failure removing member: some policies could not be deleted, manual cleanup may be needed", "org_id", orgID, @@ -337,7 +363,7 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal return errs } - // remove relations at group level + // guarded deletes passed — safe to clean up relations for _, g := range orgGroups { if err := s.removeRelations(ctx, g.ID, schema.GroupNamespace, principalID, principalType); err != nil { s.log.Error("partial failure removing member: group relation cleanup failed, manual cleanup may be needed", @@ -351,7 +377,6 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal } } - // remove relations at org level if err := s.removeRelations(ctx, orgID, schema.OrganizationNamespace, principalID, principalType); err != nil { s.log.Error("partial failure removing member: org relation cleanup failed, manual cleanup may be needed", "org_id", orgID, @@ -362,7 +387,6 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal return fmt.Errorf("remove org relations: %w", err) } - // remove identity link for service users (serviceuser#org@organization) if principalType == schema.ServiceUserPrincipal { err := s.relationService.Delete(ctx, relation.Relation{ Object: relation.Object{ID: principalID, Namespace: schema.ServiceUserPrincipal}, @@ -374,20 +398,6 @@ func (s *Service) RemoveOrganizationMember(ctx context.Context, orgID, principal } } - // delete org-level policies last - for _, policyID := range orgPolicyIDs { - if err := s.policyService.Delete(ctx, policyID); err != nil { - s.log.Error("partial failure removing member: org policy deletion failed, manual cleanup may be needed", - "org_id", orgID, - "policy_id", policyID, - "principal_id", principalID, - "principal_type", principalType, - "error", err, - ) - return fmt.Errorf("delete org policy %s: %w", policyID, err) - } - } - s.auditOrgMemberRemoved(ctx, org, principalID, targetAuditType) audit.GetAuditor(ctx, org.ID).Log(audit.OrgMemberDeletedEvent, audit.Target{ ID: principalID, @@ -411,15 +421,16 @@ func (s *Service) removeRelations(ctx context.Context, resourceID, resourceType, } // validateMinOwnerConstraint ensures the org always has at least one owner after a role change. -func (s *Service) validateMinOwnerConstraint(ctx context.Context, orgID, newRoleID string, existing []policy.Policy) error { +// Returns the resolved owner role ID for reuse by callers. +func (s *Service) validateMinOwnerConstraint(ctx context.Context, orgID, newRoleID string, existing []policy.Policy) (string, error) { ownerRole, err := s.roleService.Get(ctx, schema.RoleOrganizationOwner) if err != nil { - return fmt.Errorf("get owner role: %w", err) + return "", fmt.Errorf("get owner role: %w", err) } // no constraint if promoting to owner if newRoleID == ownerRole.ID { - return nil + return ownerRole.ID, nil } // no constraint if user is not currently an owner @@ -431,7 +442,7 @@ func (s *Service) validateMinOwnerConstraint(ctx context.Context, orgID, newRole } } if !isCurrentlyOwner { - return nil + return ownerRole.ID, nil } // user is owner, being demoted — make sure at least one other owner remains @@ -440,10 +451,44 @@ func (s *Service) validateMinOwnerConstraint(ctx context.Context, orgID, newRole RoleID: ownerRole.ID, }) if err != nil { - return fmt.Errorf("list owner policies: %w", err) + return "", fmt.Errorf("list owner policies: %w", err) } if len(ownerPolicies) <= 1 { - return ErrLastOwnerRole + return "", ErrLastOwnerRole + } + return ownerRole.ID, nil +} + +// replacePolicyWithOwnerGuard deletes existing policies using an atomic SQL guard +// that prevents removing the last owner, then creates the new policy. +func (s *Service) replacePolicyWithOwnerGuard(ctx context.Context, resourceID, resourceType, principalID, principalType, roleID string, existing []policy.Policy, ownerRoleID string) error { + for _, p := range existing { + if p.RoleID == ownerRoleID { + err := s.policyService.DeleteWithMinRoleGuard(ctx, p.ID, ownerRoleID) + if err != nil { + if errors.Is(err, policy.ErrLastRoleGuard) { + return ErrLastOwnerRole + } + return fmt.Errorf("delete policy %s: %w", p.ID, err) + } + } else { + if err := s.policyService.Delete(ctx, p.ID); err != nil { + return fmt.Errorf("delete policy %s: %w", p.ID, err) + } + } + } + + _, err := s.createPolicy(ctx, resourceID, resourceType, principalID, principalType, roleID) + if err != nil { + s.log.ErrorContext(ctx, "membership state inconsistent: old policies deleted but new policy creation failed, needs manual fix", + "resource_id", resourceID, + "resource_type", resourceType, + "principal_id", principalID, + "principal_type", principalType, + "role_id", roleID, + "error", err, + ) + return err } return nil } diff --git a/core/membership/service_test.go b/core/membership/service_test.go index 960102619..75d58dc44 100644 --- a/core/membership/service_test.go +++ b/core/membership/service_test.go @@ -405,6 +405,22 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { roleID: managerRoleID, wantErr: membership.ErrLastOwnerRole, }, + { + name: "should return ErrLastOwnerRole when DB guard rejects concurrent demotion", + setup: func(policySvc *mocks.PolicyService, _ *mocks.RelationService, roleSvc *mocks.RoleService, orgSvc *mocks.OrgService, userSvc *mocks.UserService, _ *mocks.AuditRecordRepository) { + orgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) + userSvc.EXPECT().GetByID(ctx, userID).Return(enabledUser, nil) + roleSvc.EXPECT().Get(ctx, viewerRoleID).Return(role.Role{ID: viewerRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) + roleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) + // app-level check passes (sees 2 owners) + policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}}, nil) + // DB-level guard rejects (concurrent request already deleted the other owner) + policySvc.EXPECT().DeleteWithMinRoleGuard(ctx, "p1", ownerRoleID).Return(policy.ErrLastRoleGuard) + }, + roleID: viewerRoleID, + wantErr: membership.ErrLastOwnerRole, + }, { name: "should succeed demoting owner to viewer with multiple owners", setup: func(policySvc *mocks.PolicyService, relSvc *mocks.RelationService, roleSvc *mocks.RoleService, orgSvc *mocks.OrgService, userSvc *mocks.UserService, auditRepo *mocks.AuditRecordRepository) { @@ -414,8 +430,8 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) roleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}, {ID: "p2", RoleID: ownerRoleID}}, nil) - // replace policy - policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + // replace policy with owner guard + policySvc.EXPECT().DeleteWithMinRoleGuard(ctx, "p1", ownerRoleID).Return(nil) policySvc.EXPECT().Create(ctx, policy.Policy{ RoleID: viewerRoleID, ResourceID: orgID, ResourceType: schema.OrganizationNamespace, PrincipalID: userID, PrincipalType: schema.UserPrincipal, @@ -438,7 +454,7 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: viewerRoleID}}, nil) // promoting to owner — min-owner constraint doesn't apply roleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) - // replace policy + // existing policy is viewer (non-owner), uses plain Delete policySvc.EXPECT().Delete(ctx, "p1").Return(nil) policySvc.EXPECT().Create(ctx, policy.Policy{ RoleID: ownerRoleID, ResourceID: orgID, ResourceType: schema.OrganizationNamespace, @@ -462,7 +478,7 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) roleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}}, nil) - policySvc.EXPECT().Delete(ctx, "p1").Return(nil) + policySvc.EXPECT().DeleteWithMinRoleGuard(ctx, "p1", ownerRoleID).Return(nil) policySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) // relation delete fails with a real error — logged, no rollback relSvc.EXPECT().Delete(ctx, orgRelation(schema.OwnerRelationName)).Return(errors.New("spicedb connection error")) @@ -478,6 +494,7 @@ func TestService_SetOrganizationMemberRole(t *testing.T) { roleSvc.EXPECT().Get(ctx, viewerRoleID).Return(role.Role{ID: viewerRoleID, Scopes: []string{schema.OrganizationNamespace}}, nil) policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: managerRoleID}}, nil) roleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) + // existing policy is manager (non-owner), uses plain Delete policySvc.EXPECT().Delete(ctx, "p1").Return(nil) policySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) // both relation deletes return not-found — that's fine, should continue @@ -546,7 +563,7 @@ func TestService_SetOrganizationMemberRole_ServiceUser(t *testing.T) { mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: suID, PrincipalType: schema.ServiceUserPrincipal}).Return([]policy.Policy{{ID: "p1", RoleID: ownerRoleID}}, nil) mockRoleSvc.EXPECT().Get(ctx, schema.RoleOrganizationOwner).Return(role.Role{ID: ownerRoleID, Name: schema.RoleOrganizationOwner}, nil) mockPolicySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, RoleID: ownerRoleID}).Return([]policy.Policy{{ID: "p1"}, {ID: "p2"}}, nil) - mockPolicySvc.EXPECT().Delete(ctx, "p1").Return(nil) + mockPolicySvc.EXPECT().DeleteWithMinRoleGuard(ctx, "p1", ownerRoleID).Return(nil) mockPolicySvc.EXPECT().Create(ctx, mock.Anything).Return(policy.Policy{}, nil) mockRelSvc.EXPECT().Delete(ctx, mock.Anything).Return(relation.ErrNotExist).Times(2) mockRelSvc.EXPECT().Create(ctx, mock.Anything).Return(relation.Relation{}, nil) @@ -752,10 +769,10 @@ func TestService_RemoveOrganizationMember(t *testing.T) { }, nil) d.policySvc.EXPECT().Delete(ctx, "proj-p1").Return(errors.New("delete failed")) }, - wantErrContain: "delete project policy", + wantErrContain: "delete sub-resource policy", }, { - name: "should return error if org relation removal fails without deleting org policies", + name: "should return error if org relation removal fails after org policies deleted", setup: func(d testDeps) { d.orgSvc.EXPECT().Get(ctx, orgID).Return(enabledOrg, nil) d.policySvc.EXPECT().List(ctx, policy.Filter{OrgID: orgID, PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{{ID: "org-p1", RoleID: viewerRoleID}}, nil) @@ -763,9 +780,11 @@ func TestService_RemoveOrganizationMember(t *testing.T) { d.projSvc.EXPECT().List(ctx, project.Filter{OrgID: orgID}).Return([]project.Project{}, nil) d.grpSvc.EXPECT().List(ctx, group.Filter{OrganizationID: orgID}).Return([]group.Group{}, nil) d.policySvc.EXPECT().List(ctx, policy.Filter{PrincipalID: userID, PrincipalType: schema.UserPrincipal}).Return([]policy.Policy{ - {ID: "org-p1", ResourceType: schema.OrganizationNamespace, ResourceID: orgID}, + {ID: "org-p1", ResourceType: schema.OrganizationNamespace, ResourceID: orgID, RoleID: viewerRoleID}, }, nil) - // org policy Delete should NOT be called — relations fail first, org policies are last + // org policy deleted first (viewer, plain Delete) + d.policySvc.EXPECT().Delete(ctx, "org-p1").Return(nil) + // then relation removal fails d.relSvc.EXPECT().Delete(ctx, relation.Relation{Object: orgObj, Subject: userSub, RelationName: schema.OwnerRelationName}).Return(errors.New("spicedb down")) }, wantErrContain: "remove org relations", diff --git a/core/policy/errors.go b/core/policy/errors.go index e33649d08..236305b23 100644 --- a/core/policy/errors.go +++ b/core/policy/errors.go @@ -8,4 +8,5 @@ var ( ErrInvalidID = errors.New("policy id is invalid") ErrConflict = errors.New("policy already exist") ErrInvalidDetail = errors.New("invalid policy detail") + ErrLastRoleGuard = errors.New("cannot delete: this is the last policy with the guarded role for this resource") ) diff --git a/core/policy/mocks/repository.go b/core/policy/mocks/repository.go index fdd9172ba..7486bbebe 100644 --- a/core/policy/mocks/repository.go +++ b/core/policy/mocks/repository.go @@ -126,6 +126,54 @@ func (_c *Repository_Delete_Call) RunAndReturn(run func(context.Context, string) return _c } +// DeleteWithMinRoleGuard provides a mock function with given fields: ctx, id, guardRoleID +func (_m *Repository) DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error { + ret := _m.Called(ctx, id, guardRoleID) + + if len(ret) == 0 { + panic("no return value specified for DeleteWithMinRoleGuard") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, id, guardRoleID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_DeleteWithMinRoleGuard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteWithMinRoleGuard' +type Repository_DeleteWithMinRoleGuard_Call struct { + *mock.Call +} + +// DeleteWithMinRoleGuard is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - guardRoleID string +func (_e *Repository_Expecter) DeleteWithMinRoleGuard(ctx interface{}, id interface{}, guardRoleID interface{}) *Repository_DeleteWithMinRoleGuard_Call { + return &Repository_DeleteWithMinRoleGuard_Call{Call: _e.mock.On("DeleteWithMinRoleGuard", ctx, id, guardRoleID)} +} + +func (_c *Repository_DeleteWithMinRoleGuard_Call) Run(run func(ctx context.Context, id string, guardRoleID string)) *Repository_DeleteWithMinRoleGuard_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *Repository_DeleteWithMinRoleGuard_Call) Return(_a0 error) *Repository_DeleteWithMinRoleGuard_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_DeleteWithMinRoleGuard_Call) RunAndReturn(run func(context.Context, string, string) error) *Repository_DeleteWithMinRoleGuard_Call { + _c.Call.Return(run) + return _c +} + // Get provides a mock function with given fields: ctx, id func (_m *Repository) Get(ctx context.Context, id string) (policy.Policy, error) { ret := _m.Called(ctx, id) diff --git a/core/policy/policy.go b/core/policy/policy.go index ce2aaecf6..7287d10b6 100644 --- a/core/policy/policy.go +++ b/core/policy/policy.go @@ -13,6 +13,7 @@ type Repository interface { Count(ctx context.Context, f Filter) (int64, error) Upsert(ctx context.Context, pol Policy) (Policy, error) Delete(ctx context.Context, id string) error + DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error GroupMemberCount(ctx context.Context, IDs []string) ([]MemberCount, error) ProjectMemberCount(ctx context.Context, IDs []string) ([]MemberCount, error) OrgMemberCount(ctx context.Context, ID string) (MemberCount, error) diff --git a/core/policy/service.go b/core/policy/service.go index 93075d4c6..e1a432d31 100644 --- a/core/policy/service.go +++ b/core/policy/service.go @@ -89,6 +89,18 @@ func (s Service) Delete(ctx context.Context, id string) error { return s.repository.Delete(ctx, id) } +func (s Service) DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error { + if err := s.repository.DeleteWithMinRoleGuard(ctx, id, guardRoleID); err != nil { + return err + } + return s.relationService.Delete(ctx, relation.Relation{ + Object: relation.Object{ + ID: id, + Namespace: schema.RoleBindingNamespace, + }, + }) +} + // AssignRole Note: ideally this should be in a single transaction // read more about how user defined roles work in spicedb https://authzed.com/blog/user-defined-roles func (s Service) AssignRole(ctx context.Context, pol Policy) error { diff --git a/internal/store/postgres/policy_repository.go b/internal/store/postgres/policy_repository.go index cb834a479..aba751860 100644 --- a/internal/store/postgres/policy_repository.go +++ b/internal/store/postgres/policy_repository.go @@ -363,6 +363,84 @@ func (r PolicyRepository) Delete(ctx context.Context, id string) error { return nil } +// DeleteWithMinRoleGuard atomically deletes a policy only if at least one other +// policy with the same guarded role remains for the resource. Uses SELECT FOR UPDATE +// to serialize concurrent deletions under READ COMMITTED isolation, preventing the +// TOCTOU race where two concurrent requests both pass a count check then both delete. +// Resource ID and type are derived from the existing policy, not from caller input. +func (r PolicyRepository) DeleteWithMinRoleGuard(ctx context.Context, id string, guardRoleID string) error { + existingPolicy, err := r.Get(ctx, id) + if err != nil { + return err + } + + if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error { + return r.dbc.WithTimeout(ctx, TABLE_POLICIES, "DeleteWithMinRoleGuard", func(ctx context.Context) error { + query := `WITH locked AS ( + SELECT id FROM ` + TABLE_POLICIES + ` + WHERE resource_id = $2 + AND resource_type = $3 + AND role_id = $4 + ORDER BY id + FOR UPDATE + ) + DELETE FROM ` + TABLE_POLICIES + ` WHERE id = $1 AND ( + (SELECT role_id FROM ` + TABLE_POLICIES + ` WHERE id = $1) != $4 + OR (SELECT COUNT(*) FROM locked WHERE id != $1) > 0 + )` + result, err := tx.ExecContext(ctx, query, + id, + existingPolicy.ResourceID, + existingPolicy.ResourceType, + guardRoleID, + ) + if err != nil { + return err + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return err + } + if rowsAffected == 0 { + var existingID string + err := tx.QueryRowContext(ctx, + `SELECT id FROM `+TABLE_POLICIES+` WHERE id = $1`, id, + ).Scan(&existingID) + if errors.Is(err, sql.ErrNoRows) { + return sql.ErrNoRows + } + if err != nil { + return err + } + return policy.ErrLastRoleGuard + } + + policyDB := Policy{ + ID: existingPolicy.ID, + RoleID: existingPolicy.RoleID, + ResourceID: existingPolicy.ResourceID, + ResourceType: existingPolicy.ResourceType, + PrincipalID: existingPolicy.PrincipalID, + PrincipalType: existingPolicy.PrincipalType, + } + auditRecord := r.buildPolicyAuditRecord(ctx, tx, auditrecord.PolicyDeletedEvent, policyDB, time.Now(), nil) + return InsertAuditRecordInTx(ctx, tx, auditRecord) + }) + }); err != nil { + if errors.Is(err, policy.ErrLastRoleGuard) { + return err + } + err = checkPostgresError(err) + switch { + case errors.Is(err, sql.ErrNoRows): + return policy.ErrNotExist + default: + return err + } + } + return nil +} + func (r PolicyRepository) GroupMemberCount(ctx context.Context, groupIDs []string) ([]policy.MemberCount, error) { if len(groupIDs) == 0 { return nil, policy.ErrInvalidID