From 740ae9cbe480c835fd5c63570aad82e0bde7a7ac Mon Sep 17 00:00:00 2001 From: aman Date: Mon, 9 Mar 2026 08:18:30 +0530 Subject: [PATCH 1/5] feat: PAT authentication chain, token validation, and error handling --- cmd/serve.go | 3 +- core/authenticate/authenticate.go | 9 +- core/authenticate/authenticators.go | 230 ++++++++++++++++++ core/authenticate/mocks/user_pat_service.go | 94 +++++++ core/authenticate/service.go | 169 ++----------- core/authenticate/service_test.go | 26 +- core/relation/errors.go | 1 + core/relation/service.go | 10 + core/userpat/{ => errors}/errors.go | 5 +- core/userpat/mocks/repository.go | 131 +++++++++- core/userpat/models/pat.go | 20 ++ core/userpat/service.go | 36 +-- core/userpat/service_test.go | 116 ++++----- core/userpat/userpat.go | 19 +- core/userpat/validator.go | 74 ++++++ core/userpat/validator_test.go | 136 +++++++++++ internal/api/v1beta1connect/authenticate.go | 6 + internal/api/v1beta1connect/interfaces.go | 3 +- .../v1beta1connect/mocks/user_pat_service.go | 18 +- internal/api/v1beta1connect/organization.go | 6 + internal/api/v1beta1connect/user_pat.go | 22 +- internal/api/v1beta1connect/user_pat_test.go | 50 ++-- internal/store/postgres/userpat.go | 8 +- internal/store/postgres/userpat_repository.go | 58 ++++- .../store/postgres/userpat_repository_test.go | 39 +-- pkg/server/connect_interceptors/session.go | 11 +- pkg/server/server.go | 2 +- 27 files changed, 962 insertions(+), 340 deletions(-) create mode 100644 core/authenticate/authenticators.go create mode 100644 core/authenticate/mocks/user_pat_service.go rename core/userpat/{ => errors}/errors.go (84%) create mode 100644 core/userpat/models/pat.go create mode 100644 core/userpat/validator.go create mode 100644 core/userpat/validator_test.go diff --git a/cmd/serve.go b/cmd/serve.go index 07ed23a13..5c0dee646 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -412,8 +412,9 @@ func buildAPIDependencies( roleService := role.NewService(roleRepository, relationService, permissionService, auditRecordRepository, cfg.App.PAT.DeniedPermissionsSet()) policyService := policy.NewService(policyPGRepository, relationService, roleService) userService := user.NewService(userRepository, relationService, policyService, roleService) + patValidator := userpat.NewValidator(logger, userPATRepo, cfg.App.PAT) authnService := authenticate.NewService(logger, cfg.App.Authentication, - postgres.NewFlowRepository(logger, dbc), mailDialer, tokenService, sessionService, userService, serviceUserService, webAuthConfig) + postgres.NewFlowRepository(logger, dbc), mailDialer, tokenService, sessionService, userService, serviceUserService, webAuthConfig, patValidator) groupService := group.NewService(groupRepository, relationService, authnService, policyService) organizationService := organization.NewService(organizationRepository, relationService, userService, authnService, policyService, preferenceService, auditRecordRepository) diff --git a/core/authenticate/authenticate.go b/core/authenticate/authenticate.go index 553cb6288..f84fb5f79 100644 --- a/core/authenticate/authenticate.go +++ b/core/authenticate/authenticate.go @@ -7,6 +7,7 @@ import ( "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" + pat "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/pkg/metadata" @@ -42,6 +43,8 @@ const ( // ClientCredentialsClientAssertion is used to authenticate using client_id and client_secret // that provides access token for the client ClientCredentialsClientAssertion ClientAssertion = "client_credentials" + // PATClientAssertion is used to authenticate using Personal Access Token + PATClientAssertion ClientAssertion = "pat" // PassthroughHeaderClientAssertion is used to authenticate using headers passed by the client // this is non secure way of authenticating client in test environments PassthroughHeaderClientAssertion ClientAssertion = "passthrough_header" @@ -53,9 +56,10 @@ func (a ClientAssertion) String() string { var APIAssertions = []ClientAssertion{ SessionClientAssertion, + PATClientAssertion, AccessTokenClientAssertion, - OpaqueTokenClientAssertion, JWTGrantClientAssertion, + OpaqueTokenClientAssertion, // ClientCredentialsClientAssertion should be removed in future to avoid DDOS attacks on CPU // and should only be allowed to be used get access token for the client ClientCredentialsClientAssertion, @@ -131,9 +135,10 @@ type Principal struct { // ID is the unique identifier of principal ID string // Type is the namespace of principal - // E.g. app/user, app/serviceuser + // E.g. app/user, app/serviceuser, app/pat Type string User *user.User ServiceUser *serviceuser.ServiceUser + PAT *pat.PAT } diff --git a/core/authenticate/authenticators.go b/core/authenticate/authenticators.go new file mode 100644 index 000000000..27e6855b0 --- /dev/null +++ b/core/authenticate/authenticators.go @@ -0,0 +1,230 @@ +package authenticate + +import ( + "context" + "encoding/base64" + "fmt" + "strings" + + "github.com/lestrrat-go/jwx/v2/jwt" + frontiersession "github.com/raystack/frontier/core/authenticate/session" + "github.com/raystack/frontier/core/authenticate/token" + patErrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/internal/bootstrap/schema" + "github.com/raystack/frontier/pkg/errors" + "github.com/raystack/frontier/pkg/utils" +) + +// AuthenticatorFunc attempts to authenticate a request. +// Returns (Principal, nil) on success, errSkip if not applicable (try next), +// or any other error for a terminal authentication failure. +type AuthenticatorFunc func(ctx context.Context, s *Service) (Principal, error) + +// errSkip signals that this authenticator doesn't apply to the request. +var errSkip = errors.New("skip authenticator") + +// authenticators maps each ClientAssertion to its authentication function. +var authenticators = map[ClientAssertion]AuthenticatorFunc{ + SessionClientAssertion: authenticateWithSession, + PATClientAssertion: authenticateWithPAT, + AccessTokenClientAssertion: authenticateWithAccessToken, + JWTGrantClientAssertion: authenticateWithJWTGrant, + ClientCredentialsClientAssertion: authenticateWithClientCredentials, + OpaqueTokenClientAssertion: authenticateWithClientCredentials, + PassthroughHeaderClientAssertion: authenticateWithPassthroughHeader, +} + +// authenticateWithSession extracts user from session cookie. +// Copied from original GetPrincipal session block. +func authenticateWithSession(ctx context.Context, s *Service) (Principal, error) { + session, err := s.sessionService.ExtractFromContext(ctx) + if err == nil && session.IsValid(s.Now()) && utils.IsValidUUID(session.UserID) { + // userID is a valid uuid + currentUser, err := s.userService.GetByID(ctx, session.UserID) + if err != nil { + s.log.Debug(fmt.Sprintf("unable to get session user by id: %v", err)) + return Principal{}, err + } + return Principal{ + ID: currentUser.ID, + Type: schema.UserPrincipal, + User: ¤tUser, + }, nil + } + if err != nil && !errors.Is(err, frontiersession.ErrNoSession) { + s.log.Debug(fmt.Sprintf("unable to extract session from context: %v", err)) + return Principal{}, err + } + return Principal{}, errSkip +} + +// authenticateWithPAT validates a personal access token. +func authenticateWithPAT(ctx context.Context, s *Service) (Principal, error) { + value, ok := GetTokenFromContext(ctx) + if !ok { + return Principal{}, errSkip + } + + pat, err := s.userPATService.Validate(ctx, value) + if err != nil { + if errors.Is(err, patErrors.ErrInvalidPAT) || errors.Is(err, patErrors.ErrDisabled) { + return Principal{}, errSkip + } + s.log.Debug("PAT validation failed", "err", err) + return Principal{}, err + } + + // resolve the owning user so downstream handlers can access principal.User + currentUser, err := s.userService.GetByID(ctx, pat.UserID) + if err != nil { + s.log.Debug("failed to get PAT owner", "err", err) + return Principal{}, err + } + + return Principal{ + ID: pat.ID, + Type: schema.PATPrincipal, + PAT: &pat, + User: ¤tUser, + }, nil +} + +// authenticateWithAccessToken validates a Frontier-issued JWT access token. +// Copied from original GetPrincipal access token block. +func authenticateWithAccessToken(ctx context.Context, s *Service) (Principal, error) { + userToken, ok := GetTokenFromContext(ctx) + if !ok { + return Principal{}, errSkip + } + + insecureJWT, err := jwt.ParseInsecure([]byte(userToken)) + if err != nil { + // NOTE: in the original code, AccessToken and JWTGrant were in the same if-block, + // so JWT parse failure fell through to GetByJWT. With separate authenticators, + // errSkip is required to preserve that behavior. + s.log.Debug(fmt.Sprintf("unable to parse token: %v", err)) + return Principal{}, errSkip + } + + // check type of jwt + if genClaim, ok := insecureJWT.Get(token.GeneratedClaimKey); ok { + // jwt generated by frontier using public key + claimVal, ok := genClaim.(string) + if !ok || claimVal != token.GeneratedClaimValue { + s.log.Debug("generated claim value mismatch") + return Principal{}, errors.ErrUnauthenticated + } + + // extract user from token if present as its created by frontier + userID, claims, err := s.internalTokenService.Parse(ctx, []byte(userToken)) + if err != nil || !utils.IsValidUUID(userID) { + s.log.Debug("failed to parse as internal token ", "err", err) + return Principal{}, errors.ErrUnauthenticated + } + + // userID is a valid uuid + if claims[token.SubTypeClaimsKey] == schema.ServiceUserPrincipal { + currentUser, err := s.serviceUserService.Get(ctx, userID) + if err != nil { + s.log.Debug("failed to get service user", "err", err) + return Principal{}, err + } + return Principal{ + ID: currentUser.ID, + Type: schema.ServiceUserPrincipal, + ServiceUser: ¤tUser, + }, nil + } + + currentUser, err := s.userService.GetByID(ctx, userID) + if err != nil { + s.log.Debug("failed to get user", "err", err) + return Principal{}, err + } + return Principal{ + ID: currentUser.ID, + Type: schema.UserPrincipal, + User: ¤tUser, + }, nil + } + + // NOTE: in the original code, a valid JWT without GeneratedClaimKey fell through to + // GetByJWT within the same if-block. errSkip preserves that behavior. + return Principal{}, errSkip +} + +// authenticateWithJWTGrant validates a service user JWT grant token. +// Copied from original GetPrincipal jwt grant block. +func authenticateWithJWTGrant(ctx context.Context, s *Service) (Principal, error) { + userToken, ok := GetTokenFromContext(ctx) + if !ok { + return Principal{}, errSkip + } + + serviceUser, err := s.serviceUserService.GetByJWT(ctx, userToken) + if err == nil { + return Principal{ + ID: serviceUser.ID, + Type: schema.ServiceUserPrincipal, + ServiceUser: &serviceUser, + }, nil + } + s.log.Debug("failed to parse as user token ", "err", err) + return Principal{}, errors.ErrUnauthenticated +} + +// authenticateWithClientCredentials validates client_id:client_secret credentials. +// Copied from original GetPrincipal client credentials block. +func authenticateWithClientCredentials(ctx context.Context, s *Service) (Principal, error) { + userSecretRaw, ok := GetSecretFromContext(ctx) + if !ok { + return Principal{}, errSkip + } + + // verify client secret + userSecret, err := base64.StdEncoding.DecodeString(userSecretRaw) + if err != nil { + s.log.Debug("failed to decode user secret", "err", err) + return Principal{}, errors.ErrUnauthenticated + } + userSecretParts := strings.Split(string(userSecret), ":") + if len(userSecretParts) != 2 { + s.log.Debug("failed to parse user secret") + return Principal{}, errors.ErrUnauthenticated + } + clientID, clientSecret := userSecretParts[0], userSecretParts[1] + + // extract user from secret if it's a service user + serviceUser, err := s.serviceUserService.GetBySecret(ctx, clientID, clientSecret) + if err == nil { + return Principal{ + ID: serviceUser.ID, + Type: schema.ServiceUserPrincipal, + ServiceUser: &serviceUser, + }, nil + } + s.log.Debug("failed to authenticate with client credentials", "err", err) + return Principal{}, errors.ErrUnauthenticated +} + +// authenticateWithPassthroughHeader extracts user from email header. +// Copied from original GetPrincipal passthrough block. +func authenticateWithPassthroughHeader(ctx context.Context, s *Service) (Principal, error) { + // check if header with user email is set + // TODO(kushsharma): this should ideally be deprecated + val, ok := GetEmailFromContext(ctx) + if !ok || len(val) == 0 { + return Principal{}, errSkip + } + + currentUser, err := s.getOrCreateUser(ctx, strings.TrimSpace(val), strings.Split(val, "@")[0]) + if err != nil { + s.log.Debug("failed to get user", "err", err) + return Principal{}, err + } + return Principal{ + ID: currentUser.ID, + Type: schema.UserPrincipal, + User: ¤tUser, + }, nil +} diff --git a/core/authenticate/mocks/user_pat_service.go b/core/authenticate/mocks/user_pat_service.go new file mode 100644 index 000000000..f4da52e89 --- /dev/null +++ b/core/authenticate/mocks/user_pat_service.go @@ -0,0 +1,94 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + models "github.com/raystack/frontier/core/userpat/models" + mock "github.com/stretchr/testify/mock" +) + +// UserPATService is an autogenerated mock type for the UserPATService type +type UserPATService struct { + mock.Mock +} + +type UserPATService_Expecter struct { + mock *mock.Mock +} + +func (_m *UserPATService) EXPECT() *UserPATService_Expecter { + return &UserPATService_Expecter{mock: &_m.Mock} +} + +// Validate provides a mock function with given fields: ctx, value +func (_m *UserPATService) Validate(ctx context.Context, value string) (models.PAT, error) { + ret := _m.Called(ctx, value) + + if len(ret) == 0 { + panic("no return value specified for Validate") + } + + var r0 models.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (models.PAT, error)); ok { + return rf(ctx, value) + } + if rf, ok := ret.Get(0).(func(context.Context, string) models.PAT); ok { + r0 = rf(ctx, value) + } else { + r0 = ret.Get(0).(models.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, value) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UserPATService_Validate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Validate' +type UserPATService_Validate_Call struct { + *mock.Call +} + +// Validate is a helper method to define mock.On call +// - ctx context.Context +// - value string +func (_e *UserPATService_Expecter) Validate(ctx interface{}, value interface{}) *UserPATService_Validate_Call { + return &UserPATService_Validate_Call{Call: _e.mock.On("Validate", ctx, value)} +} + +func (_c *UserPATService_Validate_Call) Run(run func(ctx context.Context, value string)) *UserPATService_Validate_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *UserPATService_Validate_Call) Return(_a0 models.PAT, _a1 error) *UserPATService_Validate_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *UserPATService_Validate_Call) RunAndReturn(run func(context.Context, string) (models.PAT, error)) *UserPATService_Validate_Call { + _c.Call.Return(run) + return _c +} + +// NewUserPATService creates a new instance of UserPATService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewUserPATService(t interface { + mock.TestingT + Cleanup(func()) +}) *UserPATService { + mock := &UserPATService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/authenticate/service.go b/core/authenticate/service.go index abdacf861..45a7754be 100644 --- a/core/authenticate/service.go +++ b/core/authenticate/service.go @@ -18,10 +18,9 @@ import ( "golang.org/x/exp/slices" - "github.com/lestrrat-go/jwx/v2/jwt" - frontiersession "github.com/raystack/frontier/core/authenticate/session" "github.com/raystack/frontier/core/serviceuser" + patModels "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/internal/metrics" "github.com/raystack/frontier/pkg/errors" @@ -89,6 +88,10 @@ type TokenService interface { Parse(ctx context.Context, userToken []byte) (string, map[string]any, error) } +type UserPATService interface { + Validate(ctx context.Context, value string) (patModels.PAT, error) +} + type Service struct { log log.Logger cron *cron.Cron @@ -100,12 +103,14 @@ type Service struct { internalTokenService TokenService sessionService SessionService serviceUserService ServiceUserService + userPATService UserPATService webAuth *webauthn.WebAuthn } func NewService(logger log.Logger, config Config, flowRepo FlowRepository, mailDialer mailer.Dialer, tokenService TokenService, sessionService SessionService, - userService UserService, serviceUserService ServiceUserService, webAuthConfig *webauthn.WebAuthn) *Service { + userService UserService, serviceUserService ServiceUserService, webAuthConfig *webauthn.WebAuthn, + userPATService UserPATService) *Service { r := &Service{ log: logger, cron: cron.New(cron.WithChain( @@ -122,6 +127,7 @@ func NewService(logger log.Logger, config Config, flowRepo FlowRepository, internalTokenService: tokenService, sessionService: sessionService, serviceUserService: serviceUserService, + userPATService: userPATService, webAuth: webAuthConfig, } return r @@ -746,157 +752,26 @@ func (s Service) GetPrincipal(ctx context.Context, assertions ...ClientAssertion defer promCollect() } - var currentPrincipal Principal - if len(assertions) == 0 { - // check all assertions - assertions = APIAssertions - } - - // check if already enriched by auth middleware if val, ok := GetPrincipalFromContext(ctx); ok { - currentPrincipal = *val - return currentPrincipal, nil + return *val, nil } - // extract user from session if present - if slices.Contains[[]ClientAssertion](assertions, SessionClientAssertion) { - session, err := s.sessionService.ExtractFromContext(ctx) - if err == nil && session.IsValid(s.Now()) && utils.IsValidUUID(session.UserID) { - // userID is a valid uuid - currentUser, err := s.userService.GetByID(ctx, session.UserID) - if err != nil { - s.log.Debug(fmt.Sprintf("unable to get session user by id: %v", err)) - return Principal{}, err - } - return Principal{ - ID: currentUser.ID, - Type: schema.UserPrincipal, - User: ¤tUser, - }, nil - } - if err != nil && !errors.Is(err, frontiersession.ErrNoSession) { - s.log.Debug(fmt.Sprintf("unable to extract session from context: %v", err)) - return Principal{}, err - } + if len(assertions) == 0 { + // check all assertions + assertions = APIAssertions } - // check for token - userToken, tokenOK := GetTokenFromContext(ctx) - if tokenOK { - if slices.Contains[[]ClientAssertion](assertions, AccessTokenClientAssertion) { - insecureJWT, err := jwt.ParseInsecure([]byte(userToken)) - if err != nil { - s.log.Debug(fmt.Sprintf("unable to parse token: %v", err)) - return Principal{}, errors.ErrUnauthenticated - } - // check type of jwt - if genClaim, ok := insecureJWT.Get(token.GeneratedClaimKey); ok { - // jwt generated by frontier using public key - claimVal, ok := genClaim.(string) - if !ok || claimVal != token.GeneratedClaimValue { - s.log.Debug("generated claim value mismatch") - return Principal{}, errors.ErrUnauthenticated - } - - // extract user from token if present as its created by frontier - userID, claims, err := s.internalTokenService.Parse(ctx, []byte(userToken)) - if err != nil || !utils.IsValidUUID(userID) { - s.log.Debug("failed to parse as internal token ", "err", err) - return Principal{}, errors.ErrUnauthenticated - } - - // userID is a valid uuid - if claims[token.SubTypeClaimsKey] == schema.ServiceUserPrincipal { - currentUser, err := s.serviceUserService.Get(ctx, userID) - if err != nil { - s.log.Debug("failed to get service user", "err", err) - return Principal{}, err - } - return Principal{ - ID: currentUser.ID, - Type: schema.ServiceUserPrincipal, - ServiceUser: ¤tUser, - }, nil - } - - currentUser, err := s.userService.GetByID(ctx, userID) - if err != nil { - s.log.Debug("failed to get user", "err", err) - return Principal{}, err - } - return Principal{ - ID: currentUser.ID, - Type: schema.UserPrincipal, - User: ¤tUser, - }, nil - } - } - - // extract user from token if it's a service user - if slices.Contains[[]ClientAssertion](assertions, JWTGrantClientAssertion) { - serviceUser, err := s.serviceUserService.GetByJWT(ctx, userToken) - if err == nil { - return Principal{ - ID: serviceUser.ID, - Type: schema.ServiceUserPrincipal, - ServiceUser: &serviceUser, - }, nil - } - if err != nil { - s.log.Debug("failed to parse as user token ", "err", err) - return Principal{}, errors.ErrUnauthenticated - } + for _, assertion := range assertions { + authenticator, exists := authenticators[assertion] + if !exists { + continue } - } - - // check for client secret - if slices.Contains[[]ClientAssertion](assertions, ClientCredentialsClientAssertion) || - slices.Contains[[]ClientAssertion](assertions, OpaqueTokenClientAssertion) { - userSecretRaw, secretOK := GetSecretFromContext(ctx) - if secretOK { - // verify client secret - userSecret, err := base64.StdEncoding.DecodeString(userSecretRaw) - if err != nil { - s.log.Debug("failed to decode user secret", "err", err) - return Principal{}, errors.ErrUnauthenticated - } - userSecretParts := strings.Split(string(userSecret), ":") - if len(userSecretParts) != 2 { - s.log.Debug("failed to parse user secret", "err", err) - return Principal{}, errors.ErrUnauthenticated - } - clientID, clientSecret := userSecretParts[0], userSecretParts[1] - - // extract user from secret if it's a service user - serviceUser, err := s.serviceUserService.GetBySecret(ctx, clientID, clientSecret) - if err == nil { - return Principal{ - ID: serviceUser.ID, - Type: schema.ServiceUserPrincipal, - ServiceUser: &serviceUser, - }, nil - } - if err != nil { - s.log.Debug("failed to parse as user token ", "err", err) - return Principal{}, errors.ErrUnauthenticated - } + principal, err := authenticator(ctx, &s) + if err == nil { + return principal, nil } - } - - if slices.Contains[[]ClientAssertion](assertions, PassthroughHeaderClientAssertion) { - // check if header with user email is set - // TODO(kushsharma): this should ideally be deprecated - if val, ok := GetEmailFromContext(ctx); ok && len(val) > 0 { - currentUser, err := s.getOrCreateUser(ctx, strings.TrimSpace(val), strings.Split(val, "@")[0]) - if err != nil { - s.log.Debug("failed to get user", "err", err) - return Principal{}, err - } - return Principal{ - ID: currentUser.ID, - Type: schema.UserPrincipal, - User: ¤tUser, - }, nil + if !errors.Is(err, errSkip) { + return Principal{}, err } } diff --git a/core/authenticate/service_test.go b/core/authenticate/service_test.go index 7ce41de71..17d0069d9 100644 --- a/core/authenticate/service_test.go +++ b/core/authenticate/service_test.go @@ -76,7 +76,7 @@ func TestService_GetPrincipal(t *testing.T) { }, wantErr: false, setup: func() *authenticate.Service { - return authenticate.NewService(nil, authenticate.Config{}, nil, nil, nil, nil, nil, nil, nil) + return authenticate.NewService(nil, authenticate.Config{}, nil, nil, nil, nil, nil, nil, nil, nil) }, }, { @@ -111,7 +111,7 @@ func TestService_GetPrincipal(t *testing.T) { }, nil) return authenticate.NewService(nil, authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -135,7 +135,7 @@ func TestService_GetPrincipal(t *testing.T) { mockSessionService.EXPECT().ExtractFromContext(mock.Anything).Return(mockSess, nil) return authenticate.NewService(log.NewLogrus(), authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -163,7 +163,7 @@ func TestService_GetPrincipal(t *testing.T) { }, nil) return authenticate.NewService(nil, authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -181,7 +181,7 @@ func TestService_GetPrincipal(t *testing.T) { mockTokenService.EXPECT().Parse(mock.Anything, tokenBytes).Return("", map[string]interface{}{}, errors.New("invalid token")) return authenticate.NewService(log.NewLogrus(), authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -208,7 +208,7 @@ func TestService_GetPrincipal(t *testing.T) { }, nil) return authenticate.NewService(nil, authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -226,7 +226,7 @@ func TestService_GetPrincipal(t *testing.T) { mockServiceUserService.EXPECT().GetByJWT(mock.Anything, string(tokenBytes)).Return(serviceuser.ServiceUser{}, errors.New("invalid")) return authenticate.NewService(log.NewLogrus(), authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -253,7 +253,7 @@ func TestService_GetPrincipal(t *testing.T) { }, nil) return authenticate.NewService(nil, authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, { @@ -280,7 +280,7 @@ func TestService_GetPrincipal(t *testing.T) { }, nil) return authenticate.NewService(nil, authenticate.Config{}, - mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil) + mockFlow, nil, mockTokenService, mockSessionService, mockUserService, mockServiceUserService, nil, nil) }, }, } @@ -339,7 +339,7 @@ func TestService_StartFlow(t *testing.T) { wantErr: authenticate.ErrUnsupportedMethod, setup: func() *authenticate.Service { return authenticate.NewService(nil, authenticate.Config{}, nil, nil, - nil, nil, nil, nil, nil) + nil, nil, nil, nil, nil, nil) }, }, { @@ -370,7 +370,7 @@ func TestService_StartFlow(t *testing.T) { TestUsers: testusers.Config{Enabled: true, OTP: "111111", Domain: "example.com"}, }, mockFlowRepo, mockDialer, nil, nil, - nil, nil, nil) + nil, nil, nil, nil) srv.Now = func() time.Time { return timeNow } @@ -402,7 +402,7 @@ func TestService_StartFlow(t *testing.T) { TestUsers: testusers.Config{Enabled: true, OTP: "111111", Domain: "example.com"}, }, mockFlowRepo, mockDialer, nil, nil, - nil, nil, nil) + nil, nil, nil, nil) srv.Now = func() time.Time { return timeNow } @@ -433,7 +433,7 @@ func TestService_StartFlow(t *testing.T) { MailOTP: authenticate.MailOTPConfig{}, }, mockFlowRepo, mockDialer, nil, nil, - nil, nil, nil) + nil, nil, nil, nil) srv.Now = func() time.Time { return timeNow } diff --git a/core/relation/errors.go b/core/relation/errors.go index 9d09538d1..cd87d886b 100644 --- a/core/relation/errors.go +++ b/core/relation/errors.go @@ -11,4 +11,5 @@ var ( ErrCreatingRelationInStore = errors.New("error while creating relation") ErrCreatingRelationInAuthzEngine = errors.New("error while creating relation in authz engine") ErrFetchingUser = errors.New("error while fetching user") + ErrSubjectNotAllowed = errors.New("subject type is not allowed on this relation") ) diff --git a/core/relation/service.go b/core/relation/service.go index 0549da742..25341e217 100644 --- a/core/relation/service.go +++ b/core/relation/service.go @@ -5,6 +5,10 @@ import ( "errors" "fmt" "regexp" + + "github.com/raystack/frontier/internal/bootstrap/schema" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type Service struct { @@ -35,6 +39,12 @@ func (s Service) Create(ctx context.Context, rel Relation) (Relation, error) { err = s.authzRepository.Add(ctx, createdRelation) if err != nil { + // PAT subjects may be rejected by the authz schema for relations they are not allowed on + if createdRelation.Subject.Namespace == schema.PATPrincipal { + if st, ok := status.FromError(err); ok && st.Code() == codes.InvalidArgument { + return Relation{}, fmt.Errorf("%w: %s", ErrSubjectNotAllowed, st.Message()) + } + } return Relation{}, fmt.Errorf("%w: %s", ErrCreatingRelationInAuthzEngine, err.Error()) } diff --git a/core/userpat/errors.go b/core/userpat/errors/errors.go similarity index 84% rename from core/userpat/errors.go rename to core/userpat/errors/errors.go index 297069d6b..ea4c1063d 100644 --- a/core/userpat/errors.go +++ b/core/userpat/errors/errors.go @@ -1,4 +1,4 @@ -package userpat +package errors import "errors" @@ -6,7 +6,8 @@ var ( ErrNotFound = errors.New("personal access token not found") ErrConflict = errors.New("personal access token with this name already exists") ErrExpired = errors.New("personal access token has expired") - ErrInvalidToken = errors.New("personal access token is invalid") + ErrInvalidPAT = errors.New("not a personal access token") + ErrMalformedPAT = errors.New("personal access token is malformed") ErrLimitExceeded = errors.New("maximum number of personal access tokens reached") ErrDisabled = errors.New("personal access tokens are not enabled") ErrExpiryExceeded = errors.New("expiry exceeds maximum allowed lifetime") diff --git a/core/userpat/mocks/repository.go b/core/userpat/mocks/repository.go index aae660f04..b4a022075 100644 --- a/core/userpat/mocks/repository.go +++ b/core/userpat/mocks/repository.go @@ -5,8 +5,10 @@ package mocks import ( context "context" - userpat "github.com/raystack/frontier/core/userpat" + models "github.com/raystack/frontier/core/userpat/models" mock "github.com/stretchr/testify/mock" + + time "time" ) // Repository is an autogenerated mock type for the Repository type @@ -81,25 +83,25 @@ func (_c *Repository_CountActive_Call) RunAndReturn(run func(context.Context, st } // Create provides a mock function with given fields: ctx, pat -func (_m *Repository) Create(ctx context.Context, pat userpat.PAT) (userpat.PAT, error) { +func (_m *Repository) Create(ctx context.Context, pat models.PAT) (models.PAT, error) { ret := _m.Called(ctx, pat) if len(ret) == 0 { panic("no return value specified for Create") } - var r0 userpat.PAT + var r0 models.PAT var r1 error - if rf, ok := ret.Get(0).(func(context.Context, userpat.PAT) (userpat.PAT, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, models.PAT) (models.PAT, error)); ok { return rf(ctx, pat) } - if rf, ok := ret.Get(0).(func(context.Context, userpat.PAT) userpat.PAT); ok { + if rf, ok := ret.Get(0).(func(context.Context, models.PAT) models.PAT); ok { r0 = rf(ctx, pat) } else { - r0 = ret.Get(0).(userpat.PAT) + r0 = ret.Get(0).(models.PAT) } - if rf, ok := ret.Get(1).(func(context.Context, userpat.PAT) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, models.PAT) error); ok { r1 = rf(ctx, pat) } else { r1 = ret.Error(1) @@ -115,24 +117,129 @@ type Repository_Create_Call struct { // Create is a helper method to define mock.On call // - ctx context.Context -// - pat userpat.PAT +// - pat models.PAT func (_e *Repository_Expecter) Create(ctx interface{}, pat interface{}) *Repository_Create_Call { return &Repository_Create_Call{Call: _e.mock.On("Create", ctx, pat)} } -func (_c *Repository_Create_Call) Run(run func(ctx context.Context, pat userpat.PAT)) *Repository_Create_Call { +func (_c *Repository_Create_Call) Run(run func(ctx context.Context, pat models.PAT)) *Repository_Create_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(userpat.PAT)) + run(args[0].(context.Context), args[1].(models.PAT)) }) return _c } -func (_c *Repository_Create_Call) Return(_a0 userpat.PAT, _a1 error) *Repository_Create_Call { +func (_c *Repository_Create_Call) Return(_a0 models.PAT, _a1 error) *Repository_Create_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *Repository_Create_Call) RunAndReturn(run func(context.Context, userpat.PAT) (userpat.PAT, error)) *Repository_Create_Call { +func (_c *Repository_Create_Call) RunAndReturn(run func(context.Context, models.PAT) (models.PAT, error)) *Repository_Create_Call { + _c.Call.Return(run) + return _c +} + +// GetBySecretHash provides a mock function with given fields: ctx, secretHash +func (_m *Repository) GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) { + ret := _m.Called(ctx, secretHash) + + if len(ret) == 0 { + panic("no return value specified for GetBySecretHash") + } + + var r0 models.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (models.PAT, error)); ok { + return rf(ctx, secretHash) + } + if rf, ok := ret.Get(0).(func(context.Context, string) models.PAT); ok { + r0 = rf(ctx, secretHash) + } else { + r0 = ret.Get(0).(models.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, secretHash) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetBySecretHash_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetBySecretHash' +type Repository_GetBySecretHash_Call struct { + *mock.Call +} + +// GetBySecretHash is a helper method to define mock.On call +// - ctx context.Context +// - secretHash string +func (_e *Repository_Expecter) GetBySecretHash(ctx interface{}, secretHash interface{}) *Repository_GetBySecretHash_Call { + return &Repository_GetBySecretHash_Call{Call: _e.mock.On("GetBySecretHash", ctx, secretHash)} +} + +func (_c *Repository_GetBySecretHash_Call) Run(run func(ctx context.Context, secretHash string)) *Repository_GetBySecretHash_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_GetBySecretHash_Call) Return(_a0 models.PAT, _a1 error) *Repository_GetBySecretHash_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetBySecretHash_Call) RunAndReturn(run func(context.Context, string) (models.PAT, error)) *Repository_GetBySecretHash_Call { + _c.Call.Return(run) + return _c +} + +// UpdateLastUsedAt provides a mock function with given fields: ctx, id, at +func (_m *Repository) UpdateLastUsedAt(ctx context.Context, id string, at time.Time) error { + ret := _m.Called(ctx, id, at) + + if len(ret) == 0 { + panic("no return value specified for UpdateLastUsedAt") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, time.Time) error); ok { + r0 = rf(ctx, id, at) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_UpdateLastUsedAt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateLastUsedAt' +type Repository_UpdateLastUsedAt_Call struct { + *mock.Call +} + +// UpdateLastUsedAt is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - at time.Time +func (_e *Repository_Expecter) UpdateLastUsedAt(ctx interface{}, id interface{}, at interface{}) *Repository_UpdateLastUsedAt_Call { + return &Repository_UpdateLastUsedAt_Call{Call: _e.mock.On("UpdateLastUsedAt", ctx, id, at)} +} + +func (_c *Repository_UpdateLastUsedAt_Call) Run(run func(ctx context.Context, id string, at time.Time)) *Repository_UpdateLastUsedAt_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(time.Time)) + }) + return _c +} + +func (_c *Repository_UpdateLastUsedAt_Call) Return(_a0 error) *Repository_UpdateLastUsedAt_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_UpdateLastUsedAt_Call) RunAndReturn(run func(context.Context, string, time.Time) error) *Repository_UpdateLastUsedAt_Call { _c.Call.Return(run) return _c } diff --git a/core/userpat/models/pat.go b/core/userpat/models/pat.go new file mode 100644 index 000000000..18caea4dd --- /dev/null +++ b/core/userpat/models/pat.go @@ -0,0 +1,20 @@ +package models + +import ( + "time" + + "github.com/raystack/frontier/pkg/metadata" +) + +type PAT struct { + ID string `rql:"name=id,type=string"` + UserID string `rql:"name=user_id,type=string"` + OrgID string `rql:"name=org_id,type=string"` + Title string `rql:"name=title,type=string"` + SecretHash string `json:"-"` + Metadata metadata.Metadata + LastUsedAt *time.Time `rql:"name=last_used_at,type=datetime"` + ExpiresAt time.Time `rql:"name=expires_at,type=datetime"` + CreatedAt time.Time `rql:"name=created_at,type=datetime"` + UpdatedAt time.Time `rql:"name=updated_at,type=datetime"` +} diff --git a/core/userpat/service.go b/core/userpat/service.go index 78ce57e50..810fcb77b 100644 --- a/core/userpat/service.go +++ b/core/userpat/service.go @@ -15,6 +15,8 @@ import ( "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/policy" "github.com/raystack/frontier/core/role" + paterrors "github.com/raystack/frontier/core/userpat/errors" + patmodels "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" pkgAuditRecord "github.com/raystack/frontier/pkg/auditrecord" "github.com/raystack/salt/log" @@ -77,19 +79,19 @@ type CreateRequest struct { // the configured maximum PAT lifetime. func (s *Service) ValidateExpiry(expiresAt time.Time) error { if !expiresAt.After(time.Now()) { - return ErrExpiryInPast + return paterrors.ErrExpiryInPast } if expiresAt.After(time.Now().Add(s.config.MaxExpiry())) { - return ErrExpiryExceeded + return paterrors.ErrExpiryExceeded } return nil } // Create generates a new PAT and returns it with the plaintext value. // The plaintext value is only available at creation time. -func (s *Service) Create(ctx context.Context, req CreateRequest) (PAT, string, error) { +func (s *Service) Create(ctx context.Context, req CreateRequest) (patmodels.PAT, string, error) { if !s.config.Enabled { - return PAT{}, "", ErrDisabled + return patmodels.PAT{}, "", paterrors.ErrDisabled } // NOTE: CountActive + Create is not atomic (TOCTOU race). Two concurrent requests @@ -98,23 +100,23 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (PAT, string, e // use an atomic INSERT ... SELECT with a count subquery in the WHERE clause. count, err := s.repo.CountActive(ctx, req.UserID, req.OrgID) if err != nil { - return PAT{}, "", fmt.Errorf("counting active PATs: %w", err) + return patmodels.PAT{}, "", fmt.Errorf("counting active PATs: %w", err) } if count >= s.config.MaxPerUserPerOrg { - return PAT{}, "", ErrLimitExceeded + return patmodels.PAT{}, "", paterrors.ErrLimitExceeded } roles, err := s.resolveAndValidateRoles(ctx, req.RoleIDs) if err != nil { - return PAT{}, "", err + return patmodels.PAT{}, "", err } patValue, secretHash, err := s.generatePAT() if err != nil { - return PAT{}, "", err + return patmodels.PAT{}, "", err } - pat := PAT{ + pat := patmodels.PAT{ UserID: req.UserID, OrgID: req.OrgID, Title: req.Title, @@ -125,11 +127,11 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (PAT, string, e created, err := s.repo.Create(ctx, pat) if err != nil { - return PAT{}, "", err + return patmodels.PAT{}, "", err } if err := s.createPolicies(ctx, created.ID, req.OrgID, roles, req.ProjectIDs); err != nil { - return PAT{}, "", fmt.Errorf("creating policies: %w", err) + return patmodels.PAT{}, "", fmt.Errorf("creating policies: %w", err) } // TODO: move audit record creation into the same transaction as PAT creation to avoid partial state where PAT exists but audit record doesn't. @@ -144,7 +146,7 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (PAT, string, e } // createAuditRecord logs a PAT lifecycle event with org context and PAT metadata. -func (s *Service) createAuditRecord(ctx context.Context, event pkgAuditRecord.Event, pat PAT, occurredAt time.Time, targetMetadata map[string]any) error { +func (s *Service) createAuditRecord(ctx context.Context, event pkgAuditRecord.Event, pat patmodels.PAT, occurredAt time.Time, targetMetadata map[string]any) error { orgName := "" if org, err := s.orgService.GetRaw(ctx, pat.OrgID); err == nil { orgName = org.Title @@ -194,7 +196,7 @@ func (s *Service) resolveAndValidateRoles(ctx context.Context, roleIDs []string) missing = append(missing, id) } } - return nil, fmt.Errorf("role IDs not found: %v: %w", missing, ErrRoleNotFound) + return nil, fmt.Errorf("role IDs not found: %v: %w", missing, paterrors.ErrRoleNotFound) } if err := s.validateRolePermissions(roles); err != nil { @@ -203,11 +205,11 @@ func (s *Service) resolveAndValidateRoles(ctx context.Context, roleIDs []string) for _, r := range roles { if len(r.Scopes) == 0 { - return nil, fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, ErrUnsupportedScope) + return nil, fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, paterrors.ErrUnsupportedScope) } for _, scope := range r.Scopes { if scope != schema.ProjectNamespace && scope != schema.OrganizationNamespace { - return nil, fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, ErrUnsupportedScope) + return nil, fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, paterrors.ErrUnsupportedScope) } } } @@ -229,7 +231,7 @@ func (s *Service) createPolicies(ctx context.Context, patID, orgID string, roles case slices.Contains(r.Scopes, schema.OrganizationNamespace): err = s.createOrgScopedPolicy(ctx, patID, orgID, r) default: - err = fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, ErrUnsupportedScope) + err = fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, paterrors.ErrUnsupportedScope) } if err != nil { return err @@ -243,7 +245,7 @@ func (s *Service) validateRolePermissions(roles []role.Role) error { for _, r := range roles { for _, perm := range r.Permissions { if _, denied := s.deniedPerms[perm]; denied { - return fmt.Errorf("role %s has denied permission %s: %w", r.Name, perm, ErrDeniedRole) + return fmt.Errorf("role %s has denied permission %s: %w", r.Name, perm, paterrors.ErrDeniedRole) } } } diff --git a/core/userpat/service_test.go b/core/userpat/service_test.go index 33b2f2074..dd1629472 100644 --- a/core/userpat/service_test.go +++ b/core/userpat/service_test.go @@ -10,12 +10,14 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/raystack/frontier/core/auditrecord/models" + auditmodels "github.com/raystack/frontier/core/auditrecord/models" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/policy" "github.com/raystack/frontier/core/role" "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" "github.com/raystack/frontier/core/userpat/mocks" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/salt/log" "github.com/stretchr/testify/mock" @@ -46,7 +48,7 @@ func newSuccessMocks(t *testing.T) (*mocks.OrganizationService, *mocks.RoleServi Return(policy.Policy{}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) auditRepo.On("Create", mock.Anything, mock.Anything). - Return(models.AuditRecord{}, nil).Maybe() + Return(auditmodels.AuditRecord{}, nil).Maybe() return orgSvc, roleSvc, policySvc, auditRepo } @@ -60,7 +62,7 @@ func TestService_Create(t *testing.T) { wantErr bool wantErrIs error wantErrMsg string - validateFunc func(t *testing.T, got userpat.PAT, tokenValue string) + validateFunc func(t *testing.T, got models.PAT, tokenValue string) }{ { name: "should return ErrDisabled when PAT feature is disabled", @@ -72,7 +74,7 @@ func TestService_Create(t *testing.T) { ExpiresAt: time.Now().Add(24 * time.Hour), }, wantErr: true, - wantErrIs: userpat.ErrDisabled, + wantErrIs: paterrors.ErrDisabled, setup: func() *userpat.Service { repo := mocks.NewRepository(t) orgSvc := mocks.NewOrganizationService(t) @@ -112,7 +114,7 @@ func TestService_Create(t *testing.T) { ExpiresAt: time.Now().Add(24 * time.Hour), }, wantErr: true, - wantErrIs: userpat.ErrLimitExceeded, + wantErrIs: paterrors.ErrLimitExceeded, setup: func() *userpat.Service { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). @@ -132,7 +134,7 @@ func TestService_Create(t *testing.T) { ExpiresAt: time.Now().Add(24 * time.Hour), }, wantErr: true, - wantErrIs: userpat.ErrLimitExceeded, + wantErrIs: paterrors.ErrLimitExceeded, setup: func() *userpat.Service { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). @@ -157,8 +159,8 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{}, errors.New("insert failed")) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{}, errors.New("insert failed")) orgSvc := mocks.NewOrganizationService(t) auditRepo := mocks.NewAuditRecordRepository(t) roleSvc := mocks.NewRoleService(t) @@ -178,13 +180,13 @@ func TestService_Create(t *testing.T) { ExpiresAt: time.Now().Add(24 * time.Hour), }, wantErr: true, - wantErrIs: userpat.ErrConflict, + wantErrIs: paterrors.ErrConflict, setup: func() *userpat.Service { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{}, userpat.ErrConflict) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{}, paterrors.ErrConflict) orgSvc := mocks.NewOrganizationService(t) auditRepo := mocks.NewAuditRecordRepository(t) roleSvc := mocks.NewRoleService(t) @@ -210,8 +212,8 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Run(func(ctx context.Context, pat userpat.PAT) { + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Run(func(ctx context.Context, pat models.PAT) { if pat.UserID != "user-1" { t.Errorf("Create() UserID = %v, want %v", pat.UserID, "user-1") } @@ -231,7 +233,7 @@ func TestService_Create(t *testing.T) { t.Errorf("Create() ExpiresAt = %v, want %v", pat.ExpiresAt, futureExpiry) } }). - Return(userpat.PAT{ + Return(models.PAT{ ID: "pat-id-1", UserID: "user-1", OrgID: "org-1", @@ -243,7 +245,7 @@ func TestService_Create(t *testing.T) { orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) }, - validateFunc: func(t *testing.T, got userpat.PAT, tokenValue string) { + validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() if got.ID != "pat-id-1" { t.Errorf("Create() ID = %v, want %v", got.ID, "pat-id-1") @@ -270,12 +272,12 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) }, - validateFunc: func(t *testing.T, got userpat.PAT, tokenValue string) { + validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() if !strings.HasPrefix(tokenValue, "fpt_") { t.Errorf("token should start with prefix fpt_, got %v", tokenValue) @@ -307,12 +309,12 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) }, - validateFunc: func(t *testing.T, got userpat.PAT, tokenValue string) { + validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() parts := strings.SplitN(tokenValue, "_", 2) if len(parts) != 2 { @@ -343,8 +345,8 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, userpat.Config{ Enabled: true, @@ -353,7 +355,7 @@ func TestService_Create(t *testing.T) { MaxLifetime: "8760h", }, orgSvc, roleSvc, policySvc, auditRepo) }, - validateFunc: func(t *testing.T, got userpat.PAT, tokenValue string) { + validateFunc: func(t *testing.T, got models.PAT, tokenValue string) { t.Helper() if !strings.HasPrefix(tokenValue, "custom_") { t.Errorf("token should start with custom_, got %v", tokenValue) @@ -374,8 +376,8 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(49), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) }, @@ -394,8 +396,8 @@ func TestService_Create(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) }, @@ -428,8 +430,8 @@ func TestService_Create_UniquePATs(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil).Times(2) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil).Times(2) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil).Times(2) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) svc := userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) @@ -460,11 +462,11 @@ func TestService_Create_HashVerification(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1"). Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Run(func(ctx context.Context, pat userpat.PAT) { + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Run(func(ctx context.Context, pat models.PAT) { capturedHash = pat.SecretHash }). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1"}, nil) + Return(models.PAT{ID: "pat-1", OrgID: "org-1"}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) svc := userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) @@ -499,14 +501,14 @@ func TestService_Create_HashVerification(t *testing.T) { func TestService_CreatePolicies_OrgScopedRole(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) - auditRepo.On("Create", mock.Anything, mock.Anything).Return(models.AuditRecord{}, nil).Maybe() + auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"org-role-1"}}).Return([]role.Role{{ @@ -541,14 +543,14 @@ func TestService_CreatePolicies_OrgScopedRole(t *testing.T) { func TestService_CreatePolicies_ProjectScopedAllProjects(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) - auditRepo.On("Create", mock.Anything, mock.Anything).Return(models.AuditRecord{}, nil).Maybe() + auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"proj-role-1"}}).Return([]role.Role{{ @@ -584,14 +586,14 @@ func TestService_CreatePolicies_ProjectScopedAllProjects(t *testing.T) { func TestService_CreatePolicies_ProjectScopedSpecificProjects(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) - auditRepo.On("Create", mock.Anything, mock.Anything).Return(models.AuditRecord{}, nil).Maybe() + auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() roleSvc := mocks.NewRoleService(t) roleSvc.EXPECT().List(mock.Anything, role.Filter{IDs: []string{"proj-role-1"}}).Return([]role.Role{{ @@ -663,7 +665,7 @@ func TestService_CreatePolicies_DeniedPermission(t *testing.T) { if err == nil { t.Fatal("Create() expected error for denied permission, got nil") } - if !errors.Is(err, userpat.ErrDeniedRole) { + if !errors.Is(err, paterrors.ErrDeniedRole) { t.Errorf("Create() error = %v, want ErrDeniedRole", err) } } @@ -727,7 +729,7 @@ func TestService_CreatePolicies_UnsupportedScope(t *testing.T) { if err == nil { t.Fatal("Create() expected error for unsupported scope, got nil") } - if !errors.Is(err, userpat.ErrUnsupportedScope) { + if !errors.Is(err, paterrors.ErrUnsupportedScope) { t.Errorf("Create() error = %v, want ErrUnsupportedScope", err) } } @@ -761,7 +763,7 @@ func TestService_CreatePolicies_MissingRoleID(t *testing.T) { if err == nil { t.Fatal("Create() expected error for missing role, got nil") } - if !errors.Is(err, userpat.ErrRoleNotFound) { + if !errors.Is(err, paterrors.ErrRoleNotFound) { t.Errorf("Create() error = %v, want ErrRoleNotFound", err) } if !strings.Contains(err.Error(), "role-b") { @@ -772,8 +774,8 @@ func TestService_CreatePolicies_MissingRoleID(t *testing.T) { func TestService_CreatePolicies_NoRoles(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) orgSvc, roleSvc, policySvc, auditRepo := newSuccessMocks(t) svc := userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, roleSvc, policySvc, auditRepo) @@ -1005,7 +1007,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { }, want: nil, // no policies should be created wantErr: true, - wantErrIs: userpat.ErrDeniedRole, + wantErrIs: paterrors.ErrDeniedRole, }, { name: "unsupported scope rejects before any policy creation", @@ -1017,7 +1019,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { }, want: nil, // scope validation happens upfront — no token or policies created wantErr: true, - wantErrIs: userpat.ErrUnsupportedScope, + wantErrIs: paterrors.ErrUnsupportedScope, }, { name: "role with mixed supported and unsupported scopes is rejected", @@ -1028,7 +1030,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { }, want: nil, wantErr: true, - wantErrIs: userpat.ErrUnsupportedScope, + wantErrIs: paterrors.ErrUnsupportedScope, }, { name: "role with empty scopes is unsupported", @@ -1039,7 +1041,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { }, want: nil, wantErr: true, - wantErrIs: userpat.ErrUnsupportedScope, + wantErrIs: paterrors.ErrUnsupportedScope, }, { name: "role count mismatch: requested 2 but found 1", @@ -1050,7 +1052,7 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { }, want: nil, wantErr: true, - wantErrIs: userpat.ErrRoleNotFound, + wantErrIs: paterrors.ErrRoleNotFound, }, } @@ -1065,15 +1067,15 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) // Only mock repo.Create for success cases — validation errors fail before token creation if !tt.wantErr { - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) } orgSvc := mocks.NewOrganizationService(t) orgSvc.On("GetRaw", mock.Anything, mock.Anything). Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe() auditRepo := mocks.NewAuditRecordRepository(t) - auditRepo.On("Create", mock.Anything, mock.Anything).Return(models.AuditRecord{}, nil).Maybe() + auditRepo.On("Create", mock.Anything, mock.Anything).Return(auditmodels.AuditRecord{}, nil).Maybe() // --- roleService: return the test's roles roleSvc := mocks.NewRoleService(t) @@ -1172,8 +1174,8 @@ func TestService_CreatePolicies_ScopeMatrix(t *testing.T) { func TestService_CreatePolicies_PolicyCreateFailure(t *testing.T) { repo := mocks.NewRepository(t) repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").Return(int64(0), nil) - repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.PAT")). - Return(userpat.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) + repo.EXPECT().Create(mock.Anything, mock.AnythingOfType("models.PAT")). + Return(models.PAT{ID: "pat-1", OrgID: "org-1", CreatedAt: time.Now()}, nil) orgSvc := mocks.NewOrganizationService(t) auditRepo := mocks.NewAuditRecordRepository(t) diff --git a/core/userpat/userpat.go b/core/userpat/userpat.go index 6e2965813..0a3dcc3f7 100644 --- a/core/userpat/userpat.go +++ b/core/userpat/userpat.go @@ -4,23 +4,12 @@ import ( "context" "time" - "github.com/raystack/frontier/pkg/metadata" + "github.com/raystack/frontier/core/userpat/models" ) -type PAT struct { - ID string `rql:"name=id,type=string"` - UserID string `rql:"name=user_id,type=string"` - OrgID string `rql:"name=org_id,type=string"` - Title string `rql:"name=title,type=string"` - SecretHash string `json:"-"` - Metadata metadata.Metadata - LastUsedAt *time.Time `rql:"name=last_used_at,type=datetime"` - ExpiresAt time.Time `rql:"name=expires_at,type=datetime"` - CreatedAt time.Time `rql:"name=created_at,type=datetime"` - UpdatedAt time.Time `rql:"name=updated_at,type=datetime"` -} - type Repository interface { - Create(ctx context.Context, pat PAT) (PAT, error) + Create(ctx context.Context, pat models.PAT) (models.PAT, error) CountActive(ctx context.Context, userID, orgID string) (int64, error) + GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) + UpdateLastUsedAt(ctx context.Context, id string, at time.Time) error } diff --git a/core/userpat/validator.go b/core/userpat/validator.go new file mode 100644 index 000000000..ac766ec4f --- /dev/null +++ b/core/userpat/validator.go @@ -0,0 +1,74 @@ +package userpat + +import ( + "context" + "encoding/base64" + "encoding/hex" + "fmt" + "strings" + "time" + + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/models" + "github.com/raystack/salt/log" + "golang.org/x/crypto/sha3" +) + +// Validator validates PAT values during authentication. +// It is separate from Service to avoid circular dependencies in wiring +// (authenticate.Service → userpat.Service → orgService → authenticate.Service). +type Validator struct { + repo Repository + config Config + logger log.Logger +} + +func NewValidator(logger log.Logger, repo Repository, config Config) *Validator { + return &Validator{ + repo: repo, + config: config, + logger: logger, + } +} + +// Validate checks a PAT value and returns the corresponding PAT. +// Returns ErrInvalidPAT if the value doesn't match the configured prefix (allowing +// the auth chain to fall through to the next authenticator). +// Returns ErrMalformedPAT, ErrExpired, ErrNotFound, or ErrDisabled for terminal auth failures. +func (v *Validator) Validate(ctx context.Context, value string) (models.PAT, error) { + if !v.config.Enabled { + return models.PAT{}, paterrors.ErrDisabled + } + + prefix := v.config.Prefix + "_" + if !strings.HasPrefix(value, prefix) { + return models.PAT{}, paterrors.ErrInvalidPAT + } + + encoded := value[len(prefix):] + secretBytes, err := base64.RawURLEncoding.DecodeString(encoded) + if err != nil { + return models.PAT{}, fmt.Errorf("%w: invalid encoding", paterrors.ErrMalformedPAT) + } + + hash := sha3.Sum256(secretBytes) + secretHash := hex.EncodeToString(hash[:]) + + pat, err := v.repo.GetBySecretHash(ctx, secretHash) + if err != nil { + return models.PAT{}, err + } + + if pat.ExpiresAt.Before(time.Now()) { + return models.PAT{}, paterrors.ErrExpired + } + + // async last_used_at update — don't block the auth path + go func() { + if err := v.repo.UpdateLastUsedAt(context.Background(), pat.ID, time.Now()); err != nil { + v.logger.Error("failed to update PAT last_used_at", "pat_id", pat.ID, "error", err) + } + }() + + return pat, nil +} diff --git a/core/userpat/validator_test.go b/core/userpat/validator_test.go new file mode 100644 index 000000000..8f56c5af5 --- /dev/null +++ b/core/userpat/validator_test.go @@ -0,0 +1,136 @@ +package userpat_test + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "errors" + "testing" + "time" + + "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/mocks" + "github.com/raystack/frontier/core/userpat/models" + "github.com/raystack/salt/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/sha3" +) + +func validPATValue(t *testing.T, prefix string) (value string, secretHash string) { + t.Helper() + secretBytes := make([]byte, 32) + _, err := rand.Read(secretBytes) + require.NoError(t, err) + value = prefix + "_" + base64.RawURLEncoding.EncodeToString(secretBytes) + hash := sha3.Sum256(secretBytes) + secretHash = hex.EncodeToString(hash[:]) + return value, secretHash +} + +func TestValidator_Validate(t *testing.T) { + const prefix = "fpt" + cfg := userpat.Config{ + Enabled: true, + Prefix: prefix, + } + + t.Run("disabled feature returns ErrDisabled", func(t *testing.T) { + v := userpat.NewValidator(log.NewNoop(), nil, userpat.Config{Enabled: false}) + _, err := v.Validate(context.Background(), "fpt_anything") + assert.ErrorIs(t, err, paterrors.ErrDisabled) + }) + + t.Run("wrong prefix returns ErrInvalidPAT", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + _, err := v.Validate(context.Background(), "ghp_sometoken") + assert.ErrorIs(t, err, paterrors.ErrInvalidPAT) + }) + + t.Run("no prefix separator returns ErrInvalidPAT", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + _, err := v.Validate(context.Background(), "randomstring") + assert.ErrorIs(t, err, paterrors.ErrInvalidPAT) + }) + + t.Run("malformed base64 returns ErrMalformedPAT", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + _, err := v.Validate(context.Background(), "fpt_!!!not-base64!!!") + assert.ErrorIs(t, err, paterrors.ErrMalformedPAT) + assert.NotErrorIs(t, err, paterrors.ErrInvalidPAT) + }) + + t.Run("unknown hash returns ErrNotFound", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + value, secretHash := validPATValue(t, prefix) + repo.EXPECT().GetBySecretHash(mock.Anything, secretHash).Return(models.PAT{}, paterrors.ErrNotFound) + + _, err := v.Validate(context.Background(), value) + assert.ErrorIs(t, err, paterrors.ErrNotFound) + }) + + t.Run("expired PAT returns ErrExpired", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + value, secretHash := validPATValue(t, prefix) + repo.EXPECT().GetBySecretHash(mock.Anything, secretHash).Return(models.PAT{ + ID: "pat-1", + ExpiresAt: time.Now().Add(-time.Hour), + }, nil) + + _, err := v.Validate(context.Background(), value) + assert.ErrorIs(t, err, paterrors.ErrExpired) + }) + + t.Run("db error propagates as-is", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + value, secretHash := validPATValue(t, prefix) + dbErr := errors.New("connection refused") + repo.EXPECT().GetBySecretHash(mock.Anything, secretHash).Return(models.PAT{}, dbErr) + + _, err := v.Validate(context.Background(), value) + assert.ErrorIs(t, err, dbErr) + assert.NotErrorIs(t, err, paterrors.ErrInvalidPAT) + }) + + t.Run("valid PAT returns PAT and updates last_used_at", func(t *testing.T) { + repo := mocks.NewRepository(t) + v := userpat.NewValidator(log.NewNoop(), repo, cfg) + + value, secretHash := validPATValue(t, prefix) + expectedPAT := models.PAT{ + ID: "pat-1", + UserID: "user-1", + OrgID: "org-1", + Title: "my-pat", + ExpiresAt: time.Now().Add(time.Hour), + } + repo.EXPECT().GetBySecretHash(mock.Anything, secretHash).Return(expectedPAT, nil) + repo.EXPECT().UpdateLastUsedAt(mock.Anything, "pat-1", mock.AnythingOfType("time.Time")).Return(nil) + + pat, err := v.Validate(context.Background(), value) + require.NoError(t, err) + assert.Equal(t, expectedPAT.ID, pat.ID) + assert.Equal(t, expectedPAT.UserID, pat.UserID) + assert.Equal(t, expectedPAT.OrgID, pat.OrgID) + assert.Equal(t, expectedPAT.Title, pat.Title) + + // wait briefly for the async goroutine to complete + time.Sleep(50 * time.Millisecond) + repo.AssertCalled(t, "UpdateLastUsedAt", mock.Anything, "pat-1", mock.AnythingOfType("time.Time")) + }) +} diff --git a/internal/api/v1beta1connect/authenticate.go b/internal/api/v1beta1connect/authenticate.go index 32dbd3cce..bb4aa4818 100644 --- a/internal/api/v1beta1connect/authenticate.go +++ b/internal/api/v1beta1connect/authenticate.go @@ -14,6 +14,7 @@ import ( "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/user" + patErrors "github.com/raystack/frontier/core/userpat/errors" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/server/consts" sessionutils "github.com/raystack/frontier/pkg/session" @@ -302,6 +303,11 @@ func (h *ConnectHandler) GetLoggedInPrincipal(ctx context.Context, via ...authen return principal, connect.NewError(connect.CodeNotFound, ErrUserNotExist) case errors.Is(err, errors.ErrUnauthenticated): return principal, connect.NewError(connect.CodeUnauthenticated, ErrUnauthenticated) + case errors.Is(err, patErrors.ErrMalformedPAT), + errors.Is(err, patErrors.ErrNotFound), + errors.Is(err, patErrors.ErrExpired), + errors.Is(err, patErrors.ErrDisabled): + return principal, connect.NewError(connect.CodeUnauthenticated, ErrUnauthenticated) default: return principal, connect.NewError(connect.CodeInternal, err) } diff --git a/internal/api/v1beta1connect/interfaces.go b/internal/api/v1beta1connect/interfaces.go index 235477aa3..3738f39ae 100644 --- a/internal/api/v1beta1connect/interfaces.go +++ b/internal/api/v1beta1connect/interfaces.go @@ -47,6 +47,7 @@ import ( "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/core/userpat" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/core/webhook" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/metadata" @@ -400,5 +401,5 @@ type AuditRecordService interface { type UserPATService interface { ValidateExpiry(expiresAt time.Time) error - Create(ctx context.Context, req userpat.CreateRequest) (userpat.PAT, string, error) + Create(ctx context.Context, req userpat.CreateRequest) (models.PAT, string, error) } diff --git a/internal/api/v1beta1connect/mocks/user_pat_service.go b/internal/api/v1beta1connect/mocks/user_pat_service.go index ecbdb4362..16e89e76e 100644 --- a/internal/api/v1beta1connect/mocks/user_pat_service.go +++ b/internal/api/v1beta1connect/mocks/user_pat_service.go @@ -4,10 +4,12 @@ package mocks import ( context "context" - time "time" + models "github.com/raystack/frontier/core/userpat/models" mock "github.com/stretchr/testify/mock" + time "time" + userpat "github.com/raystack/frontier/core/userpat" ) @@ -25,23 +27,23 @@ func (_m *UserPATService) EXPECT() *UserPATService_Expecter { } // Create provides a mock function with given fields: ctx, req -func (_m *UserPATService) Create(ctx context.Context, req userpat.CreateRequest) (userpat.PAT, string, error) { +func (_m *UserPATService) Create(ctx context.Context, req userpat.CreateRequest) (models.PAT, string, error) { ret := _m.Called(ctx, req) if len(ret) == 0 { panic("no return value specified for Create") } - var r0 userpat.PAT + var r0 models.PAT var r1 string var r2 error - if rf, ok := ret.Get(0).(func(context.Context, userpat.CreateRequest) (userpat.PAT, string, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, userpat.CreateRequest) (models.PAT, string, error)); ok { return rf(ctx, req) } - if rf, ok := ret.Get(0).(func(context.Context, userpat.CreateRequest) userpat.PAT); ok { + if rf, ok := ret.Get(0).(func(context.Context, userpat.CreateRequest) models.PAT); ok { r0 = rf(ctx, req) } else { - r0 = ret.Get(0).(userpat.PAT) + r0 = ret.Get(0).(models.PAT) } if rf, ok := ret.Get(1).(func(context.Context, userpat.CreateRequest) string); ok { @@ -78,12 +80,12 @@ func (_c *UserPATService_Create_Call) Run(run func(ctx context.Context, req user return _c } -func (_c *UserPATService_Create_Call) Return(_a0 userpat.PAT, _a1 string, _a2 error) *UserPATService_Create_Call { +func (_c *UserPATService_Create_Call) Return(_a0 models.PAT, _a1 string, _a2 error) *UserPATService_Create_Call { _c.Call.Return(_a0, _a1, _a2) return _c } -func (_c *UserPATService_Create_Call) RunAndReturn(run func(context.Context, userpat.CreateRequest) (userpat.PAT, string, error)) *UserPATService_Create_Call { +func (_c *UserPATService_Create_Call) RunAndReturn(run func(context.Context, userpat.CreateRequest) (models.PAT, string, error)) *UserPATService_Create_Call { _c.Call.Return(run) return _c } diff --git a/internal/api/v1beta1connect/organization.go b/internal/api/v1beta1connect/organization.go index 6e80d9a43..30ac95f60 100644 --- a/internal/api/v1beta1connect/organization.go +++ b/internal/api/v1beta1connect/organization.go @@ -9,6 +9,7 @@ import ( "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/policy" "github.com/raystack/frontier/core/project" + "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/role" "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" @@ -144,6 +145,11 @@ func (h *ConnectHandler) CreateOrganization(ctx context.Context, request *connec return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) case errors.Is(err, organization.ErrConflict): return nil, connect.NewError(connect.CodeAlreadyExists, ErrConflictRequest) + case errors.Is(err, relation.ErrSubjectNotAllowed): + errorLogger.LogServiceError(ctx, request, "CreateOrganization.Create", err, + zap.String("org_name", request.Msg.GetBody().GetName()), + zap.String("org_title", request.Msg.GetBody().GetTitle())) + return nil, connect.NewError(connect.CodePermissionDenied, ErrUnauthorized) default: errorLogger.LogServiceError(ctx, request, "CreateOrganization.Create", err, zap.String("org_name", request.Msg.GetBody().GetName()), diff --git a/internal/api/v1beta1connect/user_pat.go b/internal/api/v1beta1connect/user_pat.go index 35a54918f..35bf6dd89 100644 --- a/internal/api/v1beta1connect/user_pat.go +++ b/internal/api/v1beta1connect/user_pat.go @@ -6,6 +6,8 @@ import ( "connectrpc.com/connect" "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/metadata" frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1" @@ -47,18 +49,18 @@ func (h *ConnectHandler) CreateCurrentUserPAT(ctx context.Context, request *conn zap.String("org_id", request.Msg.GetOrgId())) switch { - case errors.Is(err, userpat.ErrDisabled): + case errors.Is(err, paterrors.ErrDisabled): return nil, connect.NewError(connect.CodeFailedPrecondition, err) - case errors.Is(err, userpat.ErrConflict): + case errors.Is(err, paterrors.ErrConflict): return nil, connect.NewError(connect.CodeAlreadyExists, err) - case errors.Is(err, userpat.ErrLimitExceeded): + case errors.Is(err, paterrors.ErrLimitExceeded): return nil, connect.NewError(connect.CodeResourceExhausted, err) - case errors.Is(err, userpat.ErrRoleNotFound): - return nil, connect.NewError(connect.CodeInvalidArgument, userpat.ErrRoleNotFound) - case errors.Is(err, userpat.ErrDeniedRole): - return nil, connect.NewError(connect.CodeInvalidArgument, userpat.ErrDeniedRole) - case errors.Is(err, userpat.ErrUnsupportedScope): - return nil, connect.NewError(connect.CodeInvalidArgument, userpat.ErrUnsupportedScope) + case errors.Is(err, paterrors.ErrRoleNotFound): + return nil, connect.NewError(connect.CodeInvalidArgument, paterrors.ErrRoleNotFound) + case errors.Is(err, paterrors.ErrDeniedRole): + return nil, connect.NewError(connect.CodeInvalidArgument, paterrors.ErrDeniedRole) + case errors.Is(err, paterrors.ErrUnsupportedScope): + return nil, connect.NewError(connect.CodeInvalidArgument, paterrors.ErrUnsupportedScope) default: return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) } @@ -69,7 +71,7 @@ func (h *ConnectHandler) CreateCurrentUserPAT(ctx context.Context, request *conn }), nil } -func transformPATToPB(pat userpat.PAT, patValue string) *frontierv1beta1.PAT { +func transformPATToPB(pat models.PAT, patValue string) *frontierv1beta1.PAT { pbPAT := &frontierv1beta1.PAT{ Id: pat.ID, Title: pat.Title, diff --git a/internal/api/v1beta1connect/user_pat_test.go b/internal/api/v1beta1connect/user_pat_test.go index 2ccd9191c..f42e028e8 100644 --- a/internal/api/v1beta1connect/user_pat_test.go +++ b/internal/api/v1beta1connect/user_pat_test.go @@ -10,6 +10,8 @@ import ( "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/user" "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/api/v1beta1connect/mocks" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/errors" @@ -74,7 +76,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { Type: schema.UserPrincipal, User: &user.User{ID: testUserID}, }, nil) - ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(userpat.ErrExpiryInPast) + ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(paterrors.ErrExpiryInPast) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -83,7 +85,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)), }), want: nil, - wantErr: connect.NewError(connect.CodeInvalidArgument, userpat.ErrExpiryInPast), + wantErr: connect.NewError(connect.CodeInvalidArgument, paterrors.ErrExpiryInPast), }, { name: "should return invalid argument when expiry exceeds max lifetime", @@ -93,7 +95,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { Type: schema.UserPrincipal, User: &user.User{ID: testUserID}, }, nil) - ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(userpat.ErrExpiryExceeded) + ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(paterrors.ErrExpiryExceeded) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -102,7 +104,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(time.Now().Add(48 * time.Hour)), }), want: nil, - wantErr: connect.NewError(connect.CodeInvalidArgument, userpat.ErrExpiryExceeded), + wantErr: connect.NewError(connect.CodeInvalidArgument, paterrors.ErrExpiryExceeded), }, { name: "should return failed precondition when PAT is disabled", @@ -114,7 +116,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", userpat.ErrDisabled) + Return(models.PAT{}, "", paterrors.ErrDisabled) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -123,7 +125,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeFailedPrecondition, userpat.ErrDisabled), + wantErr: connect.NewError(connect.CodeFailedPrecondition, paterrors.ErrDisabled), }, { name: "should return already exists when title conflicts", @@ -135,7 +137,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", userpat.ErrConflict) + Return(models.PAT{}, "", paterrors.ErrConflict) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -144,7 +146,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeAlreadyExists, userpat.ErrConflict), + wantErr: connect.NewError(connect.CodeAlreadyExists, paterrors.ErrConflict), }, { name: "should return resource exhausted when limit exceeded", @@ -156,7 +158,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", userpat.ErrLimitExceeded) + Return(models.PAT{}, "", paterrors.ErrLimitExceeded) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -165,7 +167,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeResourceExhausted, userpat.ErrLimitExceeded), + wantErr: connect.NewError(connect.CodeResourceExhausted, paterrors.ErrLimitExceeded), }, { name: "should return invalid argument when role is not found", @@ -177,7 +179,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", fmt.Errorf("fetching roles: %w", userpat.ErrRoleNotFound)) + Return(models.PAT{}, "", fmt.Errorf("fetching roles: %w", paterrors.ErrRoleNotFound)) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -186,7 +188,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeInvalidArgument, userpat.ErrRoleNotFound), + wantErr: connect.NewError(connect.CodeInvalidArgument, paterrors.ErrRoleNotFound), }, { name: "should return invalid argument when role is denied", @@ -198,7 +200,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", fmt.Errorf("creating policies: %w", userpat.ErrDeniedRole)) + Return(models.PAT{}, "", fmt.Errorf("creating policies: %w", paterrors.ErrDeniedRole)) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -207,7 +209,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeInvalidArgument, userpat.ErrDeniedRole), + wantErr: connect.NewError(connect.CodeInvalidArgument, paterrors.ErrDeniedRole), }, { name: "should return invalid argument when role scope is unsupported", @@ -219,7 +221,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", fmt.Errorf("creating policies: %w", userpat.ErrUnsupportedScope)) + Return(models.PAT{}, "", fmt.Errorf("creating policies: %w", paterrors.ErrUnsupportedScope)) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -228,7 +230,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { ExpiresAt: timestamppb.New(testTime), }), want: nil, - wantErr: connect.NewError(connect.CodeInvalidArgument, userpat.ErrUnsupportedScope), + wantErr: connect.NewError(connect.CodeInvalidArgument, paterrors.ErrUnsupportedScope), }, { name: "should return internal error for unknown service failure", @@ -240,7 +242,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{}, "", errors.New("unexpected error")) + Return(models.PAT{}, "", errors.New("unexpected error")) }, request: connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ Title: "my-token", @@ -265,7 +267,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { req.OrgID == testOrgID && req.Title == "my-token" && len(req.RoleIDs) == 1 && req.RoleIDs[0] == testRoleID - })).Return(userpat.PAT{ + })).Return(models.PAT{ ID: "pat-1", UserID: testUserID, OrgID: testOrgID, @@ -305,7 +307,7 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { }, nil) ps.EXPECT().ValidateExpiry(mock.AnythingOfType("time.Time")).Return(nil) ps.EXPECT().Create(mock.Anything, mock.AnythingOfType("userpat.CreateRequest")). - Return(userpat.PAT{ + Return(models.PAT{ ID: "pat-1", UserID: testUserID, OrgID: testOrgID, @@ -390,13 +392,13 @@ func TestTransformPATToPB(t *testing.T) { tests := []struct { name string - pat userpat.PAT + pat models.PAT patValue string want *frontierv1beta1.PAT }{ { name: "should transform minimal PAT", - pat: userpat.PAT{ + pat: models.PAT{ ID: "pat-1", UserID: "user-1", OrgID: "org-1", @@ -418,7 +420,7 @@ func TestTransformPATToPB(t *testing.T) { }, { name: "should include token value when provided", - pat: userpat.PAT{ + pat: models.PAT{ ID: "pat-1", UserID: "user-1", OrgID: "org-1", @@ -441,7 +443,7 @@ func TestTransformPATToPB(t *testing.T) { }, { name: "should include last_used_at when set", - pat: userpat.PAT{ + pat: models.PAT{ ID: "pat-1", UserID: "user-1", OrgID: "org-1", @@ -465,7 +467,7 @@ func TestTransformPATToPB(t *testing.T) { }, { name: "should include metadata when set", - pat: userpat.PAT{ + pat: models.PAT{ ID: "pat-1", UserID: "user-1", OrgID: "org-1", diff --git a/internal/store/postgres/userpat.go b/internal/store/postgres/userpat.go index 67d2c03e8..c629d93d1 100644 --- a/internal/store/postgres/userpat.go +++ b/internal/store/postgres/userpat.go @@ -4,7 +4,7 @@ import ( "encoding/json" "time" - "github.com/raystack/frontier/core/userpat" + "github.com/raystack/frontier/core/userpat/models" ) type UserPAT struct { @@ -21,14 +21,14 @@ type UserPAT struct { DeletedAt *time.Time `db:"deleted_at"` } -func (t UserPAT) transform() (userpat.PAT, error) { +func (t UserPAT) transform() (models.PAT, error) { var unmarshalledMetadata map[string]any if len(t.Metadata) > 0 { if err := json.Unmarshal(t.Metadata, &unmarshalledMetadata); err != nil { - return userpat.PAT{}, err + return models.PAT{}, err } } - return userpat.PAT{ + return models.PAT{ ID: t.ID, UserID: t.UserID, OrgID: t.OrgID, diff --git a/internal/store/postgres/userpat_repository.go b/internal/store/postgres/userpat_repository.go index 63bf49c51..d53f9bf6c 100644 --- a/internal/store/postgres/userpat_repository.go +++ b/internal/store/postgres/userpat_repository.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -11,7 +12,8 @@ import ( "github.com/doug-martin/goqu/v9" "github.com/google/uuid" - "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/pkg/db" ) @@ -25,14 +27,14 @@ func NewUserPATRepository(dbc *db.Client) *UserPATRepository { } } -func (r UserPATRepository) Create(ctx context.Context, pat userpat.PAT) (userpat.PAT, error) { +func (r UserPATRepository) Create(ctx context.Context, pat models.PAT) (models.PAT, error) { if strings.TrimSpace(pat.ID) == "" { pat.ID = uuid.New().String() } marshaledMetadata, err := json.Marshal(pat.Metadata) if err != nil { - return userpat.PAT{}, fmt.Errorf("%w: %w", parseErr, err) + return models.PAT{}, fmt.Errorf("%w: %w", parseErr, err) } var model UserPAT @@ -47,7 +49,7 @@ func (r UserPATRepository) Create(ctx context.Context, pat userpat.PAT) (userpat "expires_at": pat.ExpiresAt, }).Returning(&UserPAT{}).ToSQL() if err != nil { - return userpat.PAT{}, fmt.Errorf("%w: %w", queryErr, err) + return models.PAT{}, fmt.Errorf("%w: %w", queryErr, err) } if err = r.dbc.WithTimeout(ctx, TABLE_USER_PATS, "Create", func(ctx context.Context) error { @@ -55,9 +57,9 @@ func (r UserPATRepository) Create(ctx context.Context, pat userpat.PAT) (userpat }); err != nil { err = checkPostgresError(err) if errors.Is(err, ErrDuplicateKey) { - return userpat.PAT{}, userpat.ErrConflict + return models.PAT{}, paterrors.ErrConflict } - return userpat.PAT{}, fmt.Errorf("%w: %w", dbErr, err) + return models.PAT{}, fmt.Errorf("%w: %w", dbErr, err) } return model.transform() @@ -84,3 +86,47 @@ func (r UserPATRepository) CountActive(ctx context.Context, userID, orgID string return count, nil } + +func (r UserPATRepository) GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) { + query, params, err := dialect.From(TABLE_USER_PATS). + Select(&UserPAT{}). + Where( + goqu.Ex{"secret_hash": secretHash}, + goqu.Ex{"deleted_at": nil}, + ).Limit(1).ToSQL() + if err != nil { + return models.PAT{}, fmt.Errorf("%w: %w", queryErr, err) + } + + var model UserPAT + if err = r.dbc.WithTimeout(ctx, TABLE_USER_PATS, "GetBySecretHash", func(ctx context.Context) error { + return r.dbc.GetContext(ctx, &model, query, params...) + }); err != nil { + err = checkPostgresError(err) + if errors.Is(err, sql.ErrNoRows) { + return models.PAT{}, paterrors.ErrNotFound + } + return models.PAT{}, fmt.Errorf("%w: %w", dbErr, err) + } + + return model.transform() +} + +func (r UserPATRepository) UpdateLastUsedAt(ctx context.Context, id string, at time.Time) error { + query, params, err := dialect.Update(TABLE_USER_PATS). + Set(goqu.Record{"last_used_at": at}). + Where(goqu.Ex{"id": id}). + ToSQL() + if err != nil { + return fmt.Errorf("%w: %w", queryErr, err) + } + + if err = r.dbc.WithTimeout(ctx, TABLE_USER_PATS, "UpdateLastUsedAt", func(ctx context.Context) error { + _, err := r.dbc.ExecContext(ctx, query, params...) + return err + }); err != nil { + return fmt.Errorf("%w: %w", dbErr, err) + } + + return nil +} diff --git a/internal/store/postgres/userpat_repository_test.go b/internal/store/postgres/userpat_repository_test.go index d01bb7d65..a124ab61c 100644 --- a/internal/store/postgres/userpat_repository_test.go +++ b/internal/store/postgres/userpat_repository_test.go @@ -10,7 +10,8 @@ import ( "github.com/ory/dockertest" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/user" - "github.com/raystack/frontier/core/userpat" + paterrors "github.com/raystack/frontier/core/userpat/errors" + "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/store/postgres" "github.com/raystack/frontier/pkg/db" "github.com/raystack/salt/log" @@ -74,7 +75,7 @@ func (s *UserPATRepositoryTestSuite) cleanup() error { func (s *UserPATRepositoryTestSuite) TestCreate() { s.Run("should create a token and return it with generated ID", func() { - pat := userpat.PAT{ + pat := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "test-token", @@ -96,7 +97,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { s.Run("should use provided ID if set", func() { customID := uuid.New().String() - pat := userpat.PAT{ + pat := models.PAT{ ID: customID, UserID: s.users[0].ID, OrgID: s.orgs[0].ID, @@ -111,7 +112,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { }) s.Run("should store and return metadata", func() { - pat := userpat.PAT{ + pat := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "token-with-meta", @@ -127,7 +128,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { }) s.Run("should return ErrConflict for duplicate title per user per org", func() { - pat := userpat.PAT{ + pat := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "duplicate-title", @@ -141,11 +142,11 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { pat.ID = "" pat.SecretHash = "hashB" _, err = s.repository.Create(s.ctx, pat) - s.ErrorIs(err, userpat.ErrConflict) + s.ErrorIs(err, paterrors.ErrConflict) }) s.Run("should return ErrConflict for duplicate secret hash", func() { - pat1 := userpat.PAT{ + pat1 := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "token-unique-hash-1", @@ -156,7 +157,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { _, err := s.repository.Create(s.ctx, pat1) s.Require().NoError(err) - pat2 := userpat.PAT{ + pat2 := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "token-unique-hash-2", @@ -164,11 +165,11 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { ExpiresAt: time.Now().Add(24 * time.Hour), } _, err = s.repository.Create(s.ctx, pat2) - s.ErrorIs(err, userpat.ErrConflict) + s.ErrorIs(err, paterrors.ErrConflict) }) s.Run("should allow same title for different users in same org", func() { - pat1 := userpat.PAT{ + pat1 := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "shared-title", @@ -178,7 +179,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { _, err := s.repository.Create(s.ctx, pat1) s.Require().NoError(err) - pat2 := userpat.PAT{ + pat2 := models.PAT{ UserID: s.users[1].ID, OrgID: s.orgs[0].ID, Title: "shared-title", @@ -190,7 +191,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { }) s.Run("should allow same title for same user in different orgs", func() { - pat1 := userpat.PAT{ + pat1 := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "cross-org-title", @@ -200,7 +201,7 @@ func (s *UserPATRepositoryTestSuite) TestCreate() { _, err := s.repository.Create(s.ctx, pat1) s.Require().NoError(err) - pat2 := userpat.PAT{ + pat2 := models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[1].ID, Title: "cross-org-title", @@ -229,7 +230,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_ExcludesExpired() { s.truncateTokens() // create an active token - _, err := s.repository.Create(s.ctx, userpat.PAT{ + _, err := s.repository.Create(s.ctx, models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "active-token", @@ -239,7 +240,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_ExcludesExpired() { s.Require().NoError(err) // create an expired token - _, err = s.repository.Create(s.ctx, userpat.PAT{ + _, err = s.repository.Create(s.ctx, models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "expired-token", @@ -257,7 +258,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_FiltersByUserAndOrg() { s.truncateTokens() // token for user[0] in org[0] - _, err := s.repository.Create(s.ctx, userpat.PAT{ + _, err := s.repository.Create(s.ctx, models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: "user0-org0", @@ -267,7 +268,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_FiltersByUserAndOrg() { s.Require().NoError(err) // token for user[1] in org[0] - _, err = s.repository.Create(s.ctx, userpat.PAT{ + _, err = s.repository.Create(s.ctx, models.PAT{ UserID: s.users[1].ID, OrgID: s.orgs[0].ID, Title: "user1-org0", @@ -277,7 +278,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_FiltersByUserAndOrg() { s.Require().NoError(err) // token for user[0] in org[1] - _, err = s.repository.Create(s.ctx, userpat.PAT{ + _, err = s.repository.Create(s.ctx, models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[1].ID, Title: "user0-org1", @@ -295,7 +296,7 @@ func (s *UserPATRepositoryTestSuite) TestCountActive_MultipleTokens() { s.truncateTokens() for i := 0; i < 3; i++ { - _, err := s.repository.Create(s.ctx, userpat.PAT{ + _, err := s.repository.Create(s.ctx, models.PAT{ UserID: s.users[0].ID, OrgID: s.orgs[0].ID, Title: fmt.Sprintf("multi-token-%d", i), diff --git a/pkg/server/connect_interceptors/session.go b/pkg/server/connect_interceptors/session.go index 4a44f7378..4c782f7c0 100644 --- a/pkg/server/connect_interceptors/session.go +++ b/pkg/server/connect_interceptors/session.go @@ -8,6 +8,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" "github.com/raystack/frontier/core/authenticate" + "github.com/raystack/frontier/core/userpat" "github.com/raystack/frontier/internal/api/v1beta1connect" "github.com/raystack/frontier/pkg/server/consts" @@ -22,6 +23,7 @@ type SessionInterceptor struct { // use secure cookie EncodeMulti/DecodeMulti cookieCodec securecookie.Codec conf authenticate.SessionConfig + patConf userpat.Config h *v1beta1connect.ConnectHandler } @@ -66,6 +68,8 @@ func (s *SessionInterceptor) WrapStreamingHandler(next connect.StreamingHandlerF if token.JwtID() != "" && token.Expiration().After(time.Now().UTC()) { incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } + } else if s.patConf.Prefix != "" && strings.HasPrefix(tokenVal, s.patConf.Prefix+"_") { + incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } secretVal := strings.TrimSpace(strings.TrimPrefix(authHeader[0], "Basic ")) if len(secretVal) > 0 { @@ -112,6 +116,8 @@ func (s *SessionInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc if token.JwtID() != "" && token.Expiration().After(time.Now().UTC()) { incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } + } else if s.patConf.Prefix != "" && strings.HasPrefix(tokenVal, s.patConf.Prefix+"_") { + incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } secretVal := strings.TrimSpace(strings.TrimPrefix(authHeader[0], "Basic ")) if len(secretVal) > 0 { @@ -197,12 +203,13 @@ func (s *SessionInterceptor) UnaryConnectResponseInterceptor() connect.UnaryInte return connect.UnaryInterceptorFunc(interceptor) } -func NewSessionInterceptor(cookieCutter securecookie.Codec, conf authenticate.SessionConfig, h *v1beta1connect.ConnectHandler) *SessionInterceptor { +func NewSessionInterceptor(cookieCutter securecookie.Codec, conf authenticate.SessionConfig, h *v1beta1connect.ConnectHandler, patConf userpat.Config) *SessionInterceptor { return &SessionInterceptor{ // could be nil if not configured by user cookieCodec: cookieCutter, conf: conf, h: h, + patConf: patConf, } } @@ -256,6 +263,8 @@ func (s *SessionInterceptor) UnaryConnectRequestHeadersAnnotator() connect.Unary if token.JwtID() != "" && token.Expiration().After(time.Now().UTC()) { incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } + } else if s.patConf.Prefix != "" && strings.HasPrefix(tokenVal, s.patConf.Prefix+"_") { + incomingMD.Set(consts.UserTokenGatewayKey, tokenVal) } secretVal := strings.TrimSpace(strings.TrimPrefix(authHeader[0], "Basic ")) if len(secretVal) > 0 { diff --git a/pkg/server/server.go b/pkg/server/server.go index bb9a766f0..802f1ee6f 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -146,7 +146,7 @@ func ServeConnect(ctx context.Context, logger log.Logger, cfg Config, deps api.D authNInterceptor := connectinterceptors.NewAuthenticationInterceptor(frontierService, cfg.Authentication.Session.Headers) authZInterceptor := connectinterceptors.NewAuthorizationInterceptor(frontierService) - sessionInterceptor := connectinterceptors.NewSessionInterceptor(sessionCookieCutter, cfg.Authentication.Session, frontierService) + sessionInterceptor := connectinterceptors.NewSessionInterceptor(sessionCookieCutter, cfg.Authentication.Session, frontierService, cfg.PAT) auditInterceptor := connectinterceptors.NewAuditInterceptor(deps.AuditService) interceptors := connect.WithInterceptors( From 250929a5a06f68dfb61b19ca3fdee857eacd2445 Mon Sep 17 00:00:00 2001 From: aman Date: Tue, 10 Mar 2026 11:59:53 +0530 Subject: [PATCH 2/5] refactor: relocate `errSkip`, improve PAT validation test, and clarify error message --- core/authenticate/authenticators.go | 3 --- core/authenticate/errors.go | 3 +++ core/userpat/service.go | 2 +- core/userpat/validator.go | 2 -- core/userpat/validator_test.go | 14 ++++++++++---- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/core/authenticate/authenticators.go b/core/authenticate/authenticators.go index 27e6855b0..5dd361d36 100644 --- a/core/authenticate/authenticators.go +++ b/core/authenticate/authenticators.go @@ -20,9 +20,6 @@ import ( // or any other error for a terminal authentication failure. type AuthenticatorFunc func(ctx context.Context, s *Service) (Principal, error) -// errSkip signals that this authenticator doesn't apply to the request. -var errSkip = errors.New("skip authenticator") - // authenticators maps each ClientAssertion to its authentication function. var authenticators = map[ClientAssertion]AuthenticatorFunc{ SessionClientAssertion: authenticateWithSession, diff --git a/core/authenticate/errors.go b/core/authenticate/errors.go index d2e4172bc..94002e7a5 100644 --- a/core/authenticate/errors.go +++ b/core/authenticate/errors.go @@ -4,4 +4,7 @@ import "errors" var ( ErrInvalidID = errors.New("user id is invalid") + + // errSkip signals that this authenticator doesn't apply to the request. + errSkip = errors.New("skip authenticator") ) diff --git a/core/userpat/service.go b/core/userpat/service.go index 810fcb77b..da9ca16e4 100644 --- a/core/userpat/service.go +++ b/core/userpat/service.go @@ -205,7 +205,7 @@ func (s *Service) resolveAndValidateRoles(ctx context.Context, roleIDs []string) for _, r := range roles { if len(r.Scopes) == 0 { - return nil, fmt.Errorf("role %s has scopes %v: %w", r.Name, r.Scopes, paterrors.ErrUnsupportedScope) + return nil, fmt.Errorf("role %s has no scopes defined: %w", r.Name, paterrors.ErrUnsupportedScope) } for _, scope := range r.Scopes { if scope != schema.ProjectNamespace && scope != schema.OrganizationNamespace { diff --git a/core/userpat/validator.go b/core/userpat/validator.go index ac766ec4f..4a85f13c0 100644 --- a/core/userpat/validator.go +++ b/core/userpat/validator.go @@ -15,8 +15,6 @@ import ( ) // Validator validates PAT values during authentication. -// It is separate from Service to avoid circular dependencies in wiring -// (authenticate.Service → userpat.Service → orgService → authenticate.Service). type Validator struct { repo Repository config Config diff --git a/core/userpat/validator_test.go b/core/userpat/validator_test.go index 8f56c5af5..e593b903e 100644 --- a/core/userpat/validator_test.go +++ b/core/userpat/validator_test.go @@ -110,6 +110,7 @@ func TestValidator_Validate(t *testing.T) { t.Run("valid PAT returns PAT and updates last_used_at", func(t *testing.T) { repo := mocks.NewRepository(t) v := userpat.NewValidator(log.NewNoop(), repo, cfg) + done := make(chan struct{}) value, secretHash := validPATValue(t, prefix) expectedPAT := models.PAT{ @@ -120,7 +121,10 @@ func TestValidator_Validate(t *testing.T) { ExpiresAt: time.Now().Add(time.Hour), } repo.EXPECT().GetBySecretHash(mock.Anything, secretHash).Return(expectedPAT, nil) - repo.EXPECT().UpdateLastUsedAt(mock.Anything, "pat-1", mock.AnythingOfType("time.Time")).Return(nil) + repo.EXPECT(). + UpdateLastUsedAt(mock.Anything, "pat-1", mock.AnythingOfType("time.Time")). + Run(func(_ context.Context, _ string, _ time.Time) { close(done) }). + Return(nil) pat, err := v.Validate(context.Background(), value) require.NoError(t, err) @@ -129,8 +133,10 @@ func TestValidator_Validate(t *testing.T) { assert.Equal(t, expectedPAT.OrgID, pat.OrgID) assert.Equal(t, expectedPAT.Title, pat.Title) - // wait briefly for the async goroutine to complete - time.Sleep(50 * time.Millisecond) - repo.AssertCalled(t, "UpdateLastUsedAt", mock.Anything, "pat-1", mock.AnythingOfType("time.Time")) + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("UpdateLastUsedAt was not called") + } }) } From 5af6e2d2ac5e42acf8ec2feeaa0a6dd6e4a068eb Mon Sep 17 00:00:00 2001 From: aman Date: Tue, 10 Mar 2026 14:56:12 +0530 Subject: [PATCH 3/5] chore: add clarification comment for nullable `last_used_at` field in PAT model --- core/userpat/models/pat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/userpat/models/pat.go b/core/userpat/models/pat.go index 18caea4dd..618843d07 100644 --- a/core/userpat/models/pat.go +++ b/core/userpat/models/pat.go @@ -13,7 +13,7 @@ type PAT struct { Title string `rql:"name=title,type=string"` SecretHash string `json:"-"` Metadata metadata.Metadata - LastUsedAt *time.Time `rql:"name=last_used_at,type=datetime"` + LastUsedAt *time.Time `rql:"name=last_used_at,type=datetime"` // last_used_at can be null ExpiresAt time.Time `rql:"name=expires_at,type=datetime"` CreatedAt time.Time `rql:"name=created_at,type=datetime"` UpdatedAt time.Time `rql:"name=updated_at,type=datetime"` From f50bfe80ebe45e77b89d6a6970486e819936446a Mon Sep 17 00:00:00 2001 From: aman Date: Wed, 11 Mar 2026 13:59:47 +0530 Subject: [PATCH 4/5] feat: enforce PAT scope on permission checks --- cmd/serve.go | 1 + core/authenticate/mocks/authenticator_func.go | 95 +++++ core/resource/mocks/authn_service.go | 109 +++++ core/resource/mocks/config_repository.go | 95 +++++ core/resource/mocks/org_service.go | 94 +++++ core/resource/mocks/pat_service.go | 94 +++++ core/resource/mocks/project_service.go | 94 +++++ core/resource/mocks/relation_service.go | 257 ++++++++++++ core/resource/mocks/repository.go | 371 ++++++++++++++++++ core/resource/service.go | 162 +++++++- core/resource/service_test.go | 307 +++++++++++++++ core/userpat/mocks/repository.go | 57 +++ core/userpat/service.go | 4 + core/userpat/userpat.go | 1 + internal/store/postgres/userpat_repository.go | 25 ++ test/e2e/regression/pat_test.go | 365 +++++++++++++++++ 16 files changed, 2121 insertions(+), 10 deletions(-) create mode 100644 core/authenticate/mocks/authenticator_func.go create mode 100644 core/resource/mocks/authn_service.go create mode 100644 core/resource/mocks/config_repository.go create mode 100644 core/resource/mocks/org_service.go create mode 100644 core/resource/mocks/pat_service.go create mode 100644 core/resource/mocks/project_service.go create mode 100644 core/resource/mocks/relation_service.go create mode 100644 core/resource/mocks/repository.go create mode 100644 core/resource/service_test.go create mode 100644 test/e2e/regression/pat_test.go diff --git a/cmd/serve.go b/cmd/serve.go index 5c0dee646..d86137756 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -473,6 +473,7 @@ func buildAPIDependencies( authnService, projectService, organizationService, + userPATService, ) invitationService := invitation.NewService(mailDialer, postgres.NewInvitationRepository(logger, dbc), diff --git a/core/authenticate/mocks/authenticator_func.go b/core/authenticate/mocks/authenticator_func.go new file mode 100644 index 000000000..837935b78 --- /dev/null +++ b/core/authenticate/mocks/authenticator_func.go @@ -0,0 +1,95 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + authenticate "github.com/raystack/frontier/core/authenticate" + + mock "github.com/stretchr/testify/mock" +) + +// AuthenticatorFunc is an autogenerated mock type for the AuthenticatorFunc type +type AuthenticatorFunc struct { + mock.Mock +} + +type AuthenticatorFunc_Expecter struct { + mock *mock.Mock +} + +func (_m *AuthenticatorFunc) EXPECT() *AuthenticatorFunc_Expecter { + return &AuthenticatorFunc_Expecter{mock: &_m.Mock} +} + +// Execute provides a mock function with given fields: ctx, s +func (_m *AuthenticatorFunc) Execute(ctx context.Context, s *authenticate.Service) (authenticate.Principal, error) { + ret := _m.Called(ctx, s) + + if len(ret) == 0 { + panic("no return value specified for Execute") + } + + var r0 authenticate.Principal + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *authenticate.Service) (authenticate.Principal, error)); ok { + return rf(ctx, s) + } + if rf, ok := ret.Get(0).(func(context.Context, *authenticate.Service) authenticate.Principal); ok { + r0 = rf(ctx, s) + } else { + r0 = ret.Get(0).(authenticate.Principal) + } + + if rf, ok := ret.Get(1).(func(context.Context, *authenticate.Service) error); ok { + r1 = rf(ctx, s) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// AuthenticatorFunc_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type AuthenticatorFunc_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - ctx context.Context +// - s *authenticate.Service +func (_e *AuthenticatorFunc_Expecter) Execute(ctx interface{}, s interface{}) *AuthenticatorFunc_Execute_Call { + return &AuthenticatorFunc_Execute_Call{Call: _e.mock.On("Execute", ctx, s)} +} + +func (_c *AuthenticatorFunc_Execute_Call) Run(run func(ctx context.Context, s *authenticate.Service)) *AuthenticatorFunc_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*authenticate.Service)) + }) + return _c +} + +func (_c *AuthenticatorFunc_Execute_Call) Return(_a0 authenticate.Principal, _a1 error) *AuthenticatorFunc_Execute_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *AuthenticatorFunc_Execute_Call) RunAndReturn(run func(context.Context, *authenticate.Service) (authenticate.Principal, error)) *AuthenticatorFunc_Execute_Call { + _c.Call.Return(run) + return _c +} + +// NewAuthenticatorFunc creates a new instance of AuthenticatorFunc. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAuthenticatorFunc(t interface { + mock.TestingT + Cleanup(func()) +}) *AuthenticatorFunc { + mock := &AuthenticatorFunc{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/resource/mocks/authn_service.go b/core/resource/mocks/authn_service.go new file mode 100644 index 000000000..7d0536f98 --- /dev/null +++ b/core/resource/mocks/authn_service.go @@ -0,0 +1,109 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + authenticate "github.com/raystack/frontier/core/authenticate" + + mock "github.com/stretchr/testify/mock" +) + +// AuthnService is an autogenerated mock type for the AuthnService type +type AuthnService struct { + mock.Mock +} + +type AuthnService_Expecter struct { + mock *mock.Mock +} + +func (_m *AuthnService) EXPECT() *AuthnService_Expecter { + return &AuthnService_Expecter{mock: &_m.Mock} +} + +// GetPrincipal provides a mock function with given fields: ctx, via +func (_m *AuthnService) GetPrincipal(ctx context.Context, via ...authenticate.ClientAssertion) (authenticate.Principal, error) { + _va := make([]interface{}, len(via)) + for _i := range via { + _va[_i] = via[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetPrincipal") + } + + var r0 authenticate.Principal + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ...authenticate.ClientAssertion) (authenticate.Principal, error)); ok { + return rf(ctx, via...) + } + if rf, ok := ret.Get(0).(func(context.Context, ...authenticate.ClientAssertion) authenticate.Principal); ok { + r0 = rf(ctx, via...) + } else { + r0 = ret.Get(0).(authenticate.Principal) + } + + if rf, ok := ret.Get(1).(func(context.Context, ...authenticate.ClientAssertion) error); ok { + r1 = rf(ctx, via...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// AuthnService_GetPrincipal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrincipal' +type AuthnService_GetPrincipal_Call struct { + *mock.Call +} + +// GetPrincipal is a helper method to define mock.On call +// - ctx context.Context +// - via ...authenticate.ClientAssertion +func (_e *AuthnService_Expecter) GetPrincipal(ctx interface{}, via ...interface{}) *AuthnService_GetPrincipal_Call { + return &AuthnService_GetPrincipal_Call{Call: _e.mock.On("GetPrincipal", + append([]interface{}{ctx}, via...)...)} +} + +func (_c *AuthnService_GetPrincipal_Call) Run(run func(ctx context.Context, via ...authenticate.ClientAssertion)) *AuthnService_GetPrincipal_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]authenticate.ClientAssertion, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(authenticate.ClientAssertion) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *AuthnService_GetPrincipal_Call) Return(_a0 authenticate.Principal, _a1 error) *AuthnService_GetPrincipal_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *AuthnService_GetPrincipal_Call) RunAndReturn(run func(context.Context, ...authenticate.ClientAssertion) (authenticate.Principal, error)) *AuthnService_GetPrincipal_Call { + _c.Call.Return(run) + return _c +} + +// NewAuthnService creates a new instance of AuthnService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAuthnService(t interface { + mock.TestingT + Cleanup(func()) +}) *AuthnService { + mock := &AuthnService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/resource/mocks/config_repository.go b/core/resource/mocks/config_repository.go new file mode 100644 index 000000000..d7c2162f8 --- /dev/null +++ b/core/resource/mocks/config_repository.go @@ -0,0 +1,95 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + resource "github.com/raystack/frontier/core/resource" + mock "github.com/stretchr/testify/mock" +) + +// ConfigRepository is an autogenerated mock type for the ConfigRepository type +type ConfigRepository struct { + mock.Mock +} + +type ConfigRepository_Expecter struct { + mock *mock.Mock +} + +func (_m *ConfigRepository) EXPECT() *ConfigRepository_Expecter { + return &ConfigRepository_Expecter{mock: &_m.Mock} +} + +// GetAll provides a mock function with given fields: ctx +func (_m *ConfigRepository) GetAll(ctx context.Context) ([]resource.YAML, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetAll") + } + + var r0 []resource.YAML + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]resource.YAML, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []resource.YAML); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]resource.YAML) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ConfigRepository_GetAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAll' +type ConfigRepository_GetAll_Call struct { + *mock.Call +} + +// GetAll is a helper method to define mock.On call +// - ctx context.Context +func (_e *ConfigRepository_Expecter) GetAll(ctx interface{}) *ConfigRepository_GetAll_Call { + return &ConfigRepository_GetAll_Call{Call: _e.mock.On("GetAll", ctx)} +} + +func (_c *ConfigRepository_GetAll_Call) Run(run func(ctx context.Context)) *ConfigRepository_GetAll_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *ConfigRepository_GetAll_Call) Return(_a0 []resource.YAML, _a1 error) *ConfigRepository_GetAll_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ConfigRepository_GetAll_Call) RunAndReturn(run func(context.Context) ([]resource.YAML, error)) *ConfigRepository_GetAll_Call { + _c.Call.Return(run) + return _c +} + +// NewConfigRepository creates a new instance of ConfigRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewConfigRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *ConfigRepository { + mock := &ConfigRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/resource/mocks/org_service.go b/core/resource/mocks/org_service.go new file mode 100644 index 000000000..ab3213749 --- /dev/null +++ b/core/resource/mocks/org_service.go @@ -0,0 +1,94 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + organization "github.com/raystack/frontier/core/organization" + mock "github.com/stretchr/testify/mock" +) + +// OrgService is an autogenerated mock type for the OrgService type +type OrgService struct { + mock.Mock +} + +type OrgService_Expecter struct { + mock *mock.Mock +} + +func (_m *OrgService) EXPECT() *OrgService_Expecter { + return &OrgService_Expecter{mock: &_m.Mock} +} + +// Get provides a mock function with given fields: ctx, idOrName +func (_m *OrgService) Get(ctx context.Context, idOrName string) (organization.Organization, error) { + ret := _m.Called(ctx, idOrName) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 organization.Organization + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (organization.Organization, error)); ok { + return rf(ctx, idOrName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) organization.Organization); ok { + r0 = rf(ctx, idOrName) + } else { + r0 = ret.Get(0).(organization.Organization) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, idOrName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// OrgService_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type OrgService_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - idOrName string +func (_e *OrgService_Expecter) Get(ctx interface{}, idOrName interface{}) *OrgService_Get_Call { + return &OrgService_Get_Call{Call: _e.mock.On("Get", ctx, idOrName)} +} + +func (_c *OrgService_Get_Call) Run(run func(ctx context.Context, idOrName string)) *OrgService_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *OrgService_Get_Call) Return(_a0 organization.Organization, _a1 error) *OrgService_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *OrgService_Get_Call) RunAndReturn(run func(context.Context, string) (organization.Organization, error)) *OrgService_Get_Call { + _c.Call.Return(run) + return _c +} + +// NewOrgService creates a new instance of OrgService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewOrgService(t interface { + mock.TestingT + Cleanup(func()) +}) *OrgService { + mock := &OrgService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/resource/mocks/pat_service.go b/core/resource/mocks/pat_service.go new file mode 100644 index 000000000..72d530bf6 --- /dev/null +++ b/core/resource/mocks/pat_service.go @@ -0,0 +1,94 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + models "github.com/raystack/frontier/core/userpat/models" + mock "github.com/stretchr/testify/mock" +) + +// PATService is an autogenerated mock type for the PATService type +type PATService struct { + mock.Mock +} + +type PATService_Expecter struct { + mock *mock.Mock +} + +func (_m *PATService) EXPECT() *PATService_Expecter { + return &PATService_Expecter{mock: &_m.Mock} +} + +// GetByID provides a mock function with given fields: ctx, id +func (_m *PATService) GetByID(ctx context.Context, id string) (models.PAT, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 models.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (models.PAT, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) models.PAT); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(models.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PATService_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type PATService_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *PATService_Expecter) GetByID(ctx interface{}, id interface{}) *PATService_GetByID_Call { + return &PATService_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *PATService_GetByID_Call) Run(run func(ctx context.Context, id string)) *PATService_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *PATService_GetByID_Call) Return(_a0 models.PAT, _a1 error) *PATService_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PATService_GetByID_Call) RunAndReturn(run func(context.Context, string) (models.PAT, error)) *PATService_GetByID_Call { + _c.Call.Return(run) + return _c +} + +// NewPATService creates a new instance of PATService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPATService(t interface { + mock.TestingT + Cleanup(func()) +}) *PATService { + mock := &PATService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/resource/mocks/project_service.go b/core/resource/mocks/project_service.go new file mode 100644 index 000000000..9fd776116 --- /dev/null +++ b/core/resource/mocks/project_service.go @@ -0,0 +1,94 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + project "github.com/raystack/frontier/core/project" + mock "github.com/stretchr/testify/mock" +) + +// ProjectService is an autogenerated mock type for the ProjectService type +type ProjectService struct { + mock.Mock +} + +type ProjectService_Expecter struct { + mock *mock.Mock +} + +func (_m *ProjectService) EXPECT() *ProjectService_Expecter { + return &ProjectService_Expecter{mock: &_m.Mock} +} + +// Get provides a mock function with given fields: ctx, idOrName +func (_m *ProjectService) Get(ctx context.Context, idOrName string) (project.Project, error) { + ret := _m.Called(ctx, idOrName) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 project.Project + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (project.Project, error)); ok { + return rf(ctx, idOrName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) project.Project); ok { + r0 = rf(ctx, idOrName) + } else { + r0 = ret.Get(0).(project.Project) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, idOrName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ProjectService_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type ProjectService_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - idOrName string +func (_e *ProjectService_Expecter) Get(ctx interface{}, idOrName interface{}) *ProjectService_Get_Call { + return &ProjectService_Get_Call{Call: _e.mock.On("Get", ctx, idOrName)} +} + +func (_c *ProjectService_Get_Call) Run(run func(ctx context.Context, idOrName string)) *ProjectService_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *ProjectService_Get_Call) Return(_a0 project.Project, _a1 error) *ProjectService_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ProjectService_Get_Call) RunAndReturn(run func(context.Context, string) (project.Project, error)) *ProjectService_Get_Call { + _c.Call.Return(run) + return _c +} + +// NewProjectService creates a new instance of ProjectService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewProjectService(t interface { + mock.TestingT + Cleanup(func()) +}) *ProjectService { + mock := &ProjectService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/resource/mocks/relation_service.go b/core/resource/mocks/relation_service.go new file mode 100644 index 000000000..507e8e644 --- /dev/null +++ b/core/resource/mocks/relation_service.go @@ -0,0 +1,257 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + relation "github.com/raystack/frontier/core/relation" + mock "github.com/stretchr/testify/mock" +) + +// RelationService is an autogenerated mock type for the RelationService type +type RelationService struct { + mock.Mock +} + +type RelationService_Expecter struct { + mock *mock.Mock +} + +func (_m *RelationService) EXPECT() *RelationService_Expecter { + return &RelationService_Expecter{mock: &_m.Mock} +} + +// BatchCheckPermission provides a mock function with given fields: ctx, relations +func (_m *RelationService) BatchCheckPermission(ctx context.Context, relations []relation.Relation) ([]relation.CheckPair, error) { + ret := _m.Called(ctx, relations) + + if len(ret) == 0 { + panic("no return value specified for BatchCheckPermission") + } + + var r0 []relation.CheckPair + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []relation.Relation) ([]relation.CheckPair, error)); ok { + return rf(ctx, relations) + } + if rf, ok := ret.Get(0).(func(context.Context, []relation.Relation) []relation.CheckPair); ok { + r0 = rf(ctx, relations) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]relation.CheckPair) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []relation.Relation) error); ok { + r1 = rf(ctx, relations) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RelationService_BatchCheckPermission_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BatchCheckPermission' +type RelationService_BatchCheckPermission_Call struct { + *mock.Call +} + +// BatchCheckPermission is a helper method to define mock.On call +// - ctx context.Context +// - relations []relation.Relation +func (_e *RelationService_Expecter) BatchCheckPermission(ctx interface{}, relations interface{}) *RelationService_BatchCheckPermission_Call { + return &RelationService_BatchCheckPermission_Call{Call: _e.mock.On("BatchCheckPermission", ctx, relations)} +} + +func (_c *RelationService_BatchCheckPermission_Call) Run(run func(ctx context.Context, relations []relation.Relation)) *RelationService_BatchCheckPermission_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]relation.Relation)) + }) + return _c +} + +func (_c *RelationService_BatchCheckPermission_Call) Return(_a0 []relation.CheckPair, _a1 error) *RelationService_BatchCheckPermission_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *RelationService_BatchCheckPermission_Call) RunAndReturn(run func(context.Context, []relation.Relation) ([]relation.CheckPair, error)) *RelationService_BatchCheckPermission_Call { + _c.Call.Return(run) + return _c +} + +// CheckPermission provides a mock function with given fields: ctx, rel +func (_m *RelationService) CheckPermission(ctx context.Context, rel relation.Relation) (bool, error) { + ret := _m.Called(ctx, rel) + + if len(ret) == 0 { + panic("no return value specified for CheckPermission") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) (bool, error)); ok { + return rf(ctx, rel) + } + if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) bool); ok { + r0 = rf(ctx, rel) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, relation.Relation) error); ok { + r1 = rf(ctx, rel) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RelationService_CheckPermission_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckPermission' +type RelationService_CheckPermission_Call struct { + *mock.Call +} + +// CheckPermission is a helper method to define mock.On call +// - ctx context.Context +// - rel relation.Relation +func (_e *RelationService_Expecter) CheckPermission(ctx interface{}, rel interface{}) *RelationService_CheckPermission_Call { + return &RelationService_CheckPermission_Call{Call: _e.mock.On("CheckPermission", ctx, rel)} +} + +func (_c *RelationService_CheckPermission_Call) Run(run func(ctx context.Context, rel relation.Relation)) *RelationService_CheckPermission_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(relation.Relation)) + }) + return _c +} + +func (_c *RelationService_CheckPermission_Call) Return(_a0 bool, _a1 error) *RelationService_CheckPermission_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *RelationService_CheckPermission_Call) RunAndReturn(run func(context.Context, relation.Relation) (bool, error)) *RelationService_CheckPermission_Call { + _c.Call.Return(run) + return _c +} + +// Create provides a mock function with given fields: ctx, rel +func (_m *RelationService) Create(ctx context.Context, rel relation.Relation) (relation.Relation, error) { + ret := _m.Called(ctx, rel) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 relation.Relation + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) (relation.Relation, error)); ok { + return rf(ctx, rel) + } + if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) relation.Relation); ok { + r0 = rf(ctx, rel) + } else { + r0 = ret.Get(0).(relation.Relation) + } + + if rf, ok := ret.Get(1).(func(context.Context, relation.Relation) error); ok { + r1 = rf(ctx, rel) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RelationService_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type RelationService_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - rel relation.Relation +func (_e *RelationService_Expecter) Create(ctx interface{}, rel interface{}) *RelationService_Create_Call { + return &RelationService_Create_Call{Call: _e.mock.On("Create", ctx, rel)} +} + +func (_c *RelationService_Create_Call) Run(run func(ctx context.Context, rel relation.Relation)) *RelationService_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(relation.Relation)) + }) + return _c +} + +func (_c *RelationService_Create_Call) Return(_a0 relation.Relation, _a1 error) *RelationService_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *RelationService_Create_Call) RunAndReturn(run func(context.Context, relation.Relation) (relation.Relation, error)) *RelationService_Create_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, rel +func (_m *RelationService) Delete(ctx context.Context, rel relation.Relation) error { + ret := _m.Called(ctx, rel) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, relation.Relation) error); ok { + r0 = rf(ctx, rel) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RelationService_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type RelationService_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - rel relation.Relation +func (_e *RelationService_Expecter) Delete(ctx interface{}, rel interface{}) *RelationService_Delete_Call { + return &RelationService_Delete_Call{Call: _e.mock.On("Delete", ctx, rel)} +} + +func (_c *RelationService_Delete_Call) Run(run func(ctx context.Context, rel relation.Relation)) *RelationService_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(relation.Relation)) + }) + return _c +} + +func (_c *RelationService_Delete_Call) Return(_a0 error) *RelationService_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *RelationService_Delete_Call) RunAndReturn(run func(context.Context, relation.Relation) error) *RelationService_Delete_Call { + _c.Call.Return(run) + return _c +} + +// NewRelationService creates a new instance of RelationService. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRelationService(t interface { + mock.TestingT + Cleanup(func()) +}) *RelationService { + mock := &RelationService{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/resource/mocks/repository.go b/core/resource/mocks/repository.go new file mode 100644 index 000000000..32cd3b036 --- /dev/null +++ b/core/resource/mocks/repository.go @@ -0,0 +1,371 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. + +package mocks + +import ( + context "context" + + resource "github.com/raystack/frontier/core/resource" + mock "github.com/stretchr/testify/mock" +) + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// Create provides a mock function with given fields: ctx, _a1 +func (_m *Repository) Create(ctx context.Context, _a1 resource.Resource) (resource.Resource, error) { + ret := _m.Called(ctx, _a1) + + if len(ret) == 0 { + panic("no return value specified for Create") + } + + var r0 resource.Resource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, resource.Resource) (resource.Resource, error)); ok { + return rf(ctx, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, resource.Resource) resource.Resource); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Get(0).(resource.Resource) + } + + if rf, ok := ret.Get(1).(func(context.Context, resource.Resource) error); ok { + r1 = rf(ctx, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_Create_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Create' +type Repository_Create_Call struct { + *mock.Call +} + +// Create is a helper method to define mock.On call +// - ctx context.Context +// - _a1 resource.Resource +func (_e *Repository_Expecter) Create(ctx interface{}, _a1 interface{}) *Repository_Create_Call { + return &Repository_Create_Call{Call: _e.mock.On("Create", ctx, _a1)} +} + +func (_c *Repository_Create_Call) Run(run func(ctx context.Context, _a1 resource.Resource)) *Repository_Create_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(resource.Resource)) + }) + return _c +} + +func (_c *Repository_Create_Call) Return(_a0 resource.Resource, _a1 error) *Repository_Create_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_Create_Call) RunAndReturn(run func(context.Context, resource.Resource) (resource.Resource, error)) *Repository_Create_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, id +func (_m *Repository) Delete(ctx context.Context, id string) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type Repository_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) Delete(ctx interface{}, id interface{}) *Repository_Delete_Call { + return &Repository_Delete_Call{Call: _e.mock.On("Delete", ctx, id)} +} + +func (_c *Repository_Delete_Call) Run(run func(ctx context.Context, id string)) *Repository_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_Delete_Call) Return(_a0 error) *Repository_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_Delete_Call) RunAndReturn(run func(context.Context, string) error) *Repository_Delete_Call { + _c.Call.Return(run) + return _c +} + +// GetByID provides a mock function with given fields: ctx, id +func (_m *Repository) GetByID(ctx context.Context, id string) (resource.Resource, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 resource.Resource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (resource.Resource, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) resource.Resource); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(resource.Resource) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type Repository_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) GetByID(ctx interface{}, id interface{}) *Repository_GetByID_Call { + return &Repository_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *Repository_GetByID_Call) Run(run func(ctx context.Context, id string)) *Repository_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_GetByID_Call) Return(_a0 resource.Resource, _a1 error) *Repository_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetByID_Call) RunAndReturn(run func(context.Context, string) (resource.Resource, error)) *Repository_GetByID_Call { + _c.Call.Return(run) + return _c +} + +// GetByURN provides a mock function with given fields: ctx, urn +func (_m *Repository) GetByURN(ctx context.Context, urn string) (resource.Resource, error) { + ret := _m.Called(ctx, urn) + + if len(ret) == 0 { + panic("no return value specified for GetByURN") + } + + var r0 resource.Resource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (resource.Resource, error)); ok { + return rf(ctx, urn) + } + if rf, ok := ret.Get(0).(func(context.Context, string) resource.Resource); ok { + r0 = rf(ctx, urn) + } else { + r0 = ret.Get(0).(resource.Resource) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, urn) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetByURN_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByURN' +type Repository_GetByURN_Call struct { + *mock.Call +} + +// GetByURN is a helper method to define mock.On call +// - ctx context.Context +// - urn string +func (_e *Repository_Expecter) GetByURN(ctx interface{}, urn interface{}) *Repository_GetByURN_Call { + return &Repository_GetByURN_Call{Call: _e.mock.On("GetByURN", ctx, urn)} +} + +func (_c *Repository_GetByURN_Call) Run(run func(ctx context.Context, urn string)) *Repository_GetByURN_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_GetByURN_Call) Return(_a0 resource.Resource, _a1 error) *Repository_GetByURN_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetByURN_Call) RunAndReturn(run func(context.Context, string) (resource.Resource, error)) *Repository_GetByURN_Call { + _c.Call.Return(run) + return _c +} + +// List provides a mock function with given fields: ctx, flt +func (_m *Repository) List(ctx context.Context, flt resource.Filter) ([]resource.Resource, error) { + ret := _m.Called(ctx, flt) + + if len(ret) == 0 { + panic("no return value specified for List") + } + + var r0 []resource.Resource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, resource.Filter) ([]resource.Resource, error)); ok { + return rf(ctx, flt) + } + if rf, ok := ret.Get(0).(func(context.Context, resource.Filter) []resource.Resource); ok { + r0 = rf(ctx, flt) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]resource.Resource) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, resource.Filter) error); ok { + r1 = rf(ctx, flt) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type Repository_List_Call struct { + *mock.Call +} + +// List is a helper method to define mock.On call +// - ctx context.Context +// - flt resource.Filter +func (_e *Repository_Expecter) List(ctx interface{}, flt interface{}) *Repository_List_Call { + return &Repository_List_Call{Call: _e.mock.On("List", ctx, flt)} +} + +func (_c *Repository_List_Call) Run(run func(ctx context.Context, flt resource.Filter)) *Repository_List_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(resource.Filter)) + }) + return _c +} + +func (_c *Repository_List_Call) Return(_a0 []resource.Resource, _a1 error) *Repository_List_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_List_Call) RunAndReturn(run func(context.Context, resource.Filter) ([]resource.Resource, error)) *Repository_List_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function with given fields: ctx, _a1 +func (_m *Repository) Update(ctx context.Context, _a1 resource.Resource) (resource.Resource, error) { + ret := _m.Called(ctx, _a1) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 resource.Resource + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, resource.Resource) (resource.Resource, error)); ok { + return rf(ctx, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, resource.Resource) resource.Resource); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Get(0).(resource.Resource) + } + + if rf, ok := ret.Get(1).(func(context.Context, resource.Resource) error); ok { + r1 = rf(ctx, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type Repository_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - _a1 resource.Resource +func (_e *Repository_Expecter) Update(ctx interface{}, _a1 interface{}) *Repository_Update_Call { + return &Repository_Update_Call{Call: _e.mock.On("Update", ctx, _a1)} +} + +func (_c *Repository_Update_Call) Run(run func(ctx context.Context, _a1 resource.Resource)) *Repository_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(resource.Resource)) + }) + return _c +} + +func (_c *Repository_Update_Call) Return(_a0 resource.Resource, _a1 error) *Repository_Update_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_Update_Call) RunAndReturn(run func(context.Context, resource.Resource) (resource.Resource, error)) *Repository_Update_Call { + _c.Call.Return(run) + return _c +} + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/resource/service.go b/core/resource/service.go index 7adf1e825..0f46cd33e 100644 --- a/core/resource/service.go +++ b/core/resource/service.go @@ -11,6 +11,7 @@ import ( "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/project" + patmodels "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/pkg/utils" "github.com/raystack/frontier/core/relation" @@ -36,6 +37,10 @@ type OrgService interface { Get(ctx context.Context, idOrName string) (organization.Organization, error) } +type PATService interface { + GetByID(ctx context.Context, id string) (patmodels.PAT, error) +} + type Service struct { repository Repository configRepository ConfigRepository @@ -43,11 +48,13 @@ type Service struct { authnService AuthnService projectService ProjectService orgService OrgService + patService PATService } func NewService(repository Repository, configRepository ConfigRepository, relationService RelationService, authnService AuthnService, - projectService ProjectService, orgService OrgService) *Service { + projectService ProjectService, orgService OrgService, + patService PATService) *Service { return &Service{ repository: repository, configRepository: configRepository, @@ -55,6 +62,7 @@ func NewService(repository Repository, configRepository ConfigRepository, authnService: authnService, projectService: projectService, orgService: orgService, + patService: patService, } } @@ -163,12 +171,17 @@ func (s Service) AddResourceOwner(ctx context.Context, res Resource) error { } func (s Service) CheckAuthz(ctx context.Context, check Check) (bool, error) { - relSubject, err := s.buildRelationSubject(ctx, check.Subject) + relObject, err := s.buildRelationObject(ctx, check.Object) if err != nil { return false, err } - relObject, err := s.buildRelationObject(ctx, check.Object) + // PAT scope — early exit if denied + if allowed, err := s.checkPATScope(ctx, check.Subject, relObject, check.Permission); err != nil || !allowed { + return false, err + } + + relSubject, err := s.buildRelationSubject(ctx, check.Subject) if err != nil { return false, err } @@ -183,6 +196,10 @@ func (s Service) CheckAuthz(ctx context.Context, check Check) (bool, error) { func (s Service) buildRelationSubject(ctx context.Context, sub relation.Subject) (relation.Subject, error) { // use existing if passed in request if sub.ID != "" && sub.Namespace != "" { + // PAT subject → resolve to underlying user for authorization + if sub.Namespace == schema.PATPrincipal { + return s.resolvePATUser(ctx, sub.ID) + } return sub, nil } @@ -190,6 +207,10 @@ func (s Service) buildRelationSubject(ctx context.Context, sub relation.Subject) if err != nil { return relation.Subject{}, err } + // PAT principal → use underlying user for authorization + if principal.PAT != nil { + return relation.Subject{ID: principal.PAT.UserID, Namespace: schema.UserPrincipal}, nil + } return relation.Subject{ ID: principal.ID, Namespace: principal.Type, @@ -229,26 +250,147 @@ func (s Service) buildRelationObject(ctx context.Context, obj relation.Object) ( return obj, nil } +// resolvePATUser resolves a PAT ID to its owning user subject. +// Tries context first(cached), falls back to DB (for federated checks with explicit subject). +func (s Service) resolvePATUser(ctx context.Context, patID string) (relation.Subject, error) { + principal, err := s.authnService.GetPrincipal(ctx) + if err == nil && principal.PAT != nil && principal.PAT.ID == patID { + return relation.Subject{ID: principal.PAT.UserID, Namespace: schema.UserPrincipal}, nil + } + + pat, err := s.patService.GetByID(ctx, patID) + if err != nil { + return relation.Subject{}, err + } + return relation.Subject{ID: pat.UserID, Namespace: schema.UserPrincipal}, nil +} + +// resolvePATID returns the PAT ID to scope-check, if any. +// Explicit app/pat subject takes precedence (federated check by admin), +// otherwise falls back to the authenticated principal's PAT. +func (s Service) resolvePATID(ctx context.Context, subject relation.Subject) string { + if subject.Namespace == schema.PATPrincipal && subject.ID != "" { + return subject.ID + } + principal, _ := s.authnService.GetPrincipal(ctx) + if principal.PAT != nil { + return principal.PAT.ID + } + return "" +} + +// checkPATScope checks if the PAT has scope for the given permission on the object. +// Returns (true, nil) if no PAT is involved. +func (s Service) checkPATScope(ctx context.Context, subject relation.Subject, object relation.Object, permission string) (bool, error) { + patID := s.resolvePATID(ctx, subject) + if patID == "" { + return true, nil + } + return s.relationService.CheckPermission(ctx, relation.Relation{ + Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, + Object: object, + RelationName: permission, + }) +} + +// BatchCheck checks permissions for multiple resource checks. +// For PAT requests, it first batch-checks PAT scope, then only runs user permission +// checks for scope-allowed items. Scope-denied items return false directly. func (s Service) BatchCheck(ctx context.Context, checks []Check) ([]relation.CheckPair, error) { - relations := make([]relation.Relation, 0, len(checks)) - for _, check := range checks { - // we can parallelize this to speed up the process + relations, patScopeRelations, patScopeIdx, err := s.buildBatchRelations(ctx, checks) + if err != nil { + return nil, err + } + + // no PAT involved — straight to user permission check + if len(patScopeRelations) == 0 { + return s.relationService.BatchCheckPermission(ctx, relations) + } + + // PAT scope gate — check which items the PAT has scope for + scopeDenied, err := s.batchCheckPATScope(ctx, patScopeRelations, patScopeIdx) + if err != nil { + return nil, err + } + + // run user permission checks only for scope-allowed items, merge results + return s.batchCheckWithScopeFilter(ctx, relations, scopeDenied) +} + +// buildBatchRelations resolves objects/subjects and builds parallel PAT scope relations. +// Returns user relations, PAT scope relations, and index mapping from scope back to user relations. +func (s Service) buildBatchRelations(ctx context.Context, checks []Check) ( + relations, patScopeRelations []relation.Relation, patScopeIdx []int, err error, +) { + relations = make([]relation.Relation, 0, len(checks)) + for i, check := range checks { relObject, err := s.buildRelationObject(ctx, check.Object) if err != nil { - return nil, err + return nil, nil, nil, err } - relSubject, err := s.buildRelationSubject(ctx, check.Subject) if err != nil { - return nil, err + return nil, nil, nil, err } relations = append(relations, relation.Relation{ Subject: relSubject, Object: relObject, RelationName: check.Permission, }) + + if patID := s.resolvePATID(ctx, check.Subject); patID != "" { + patScopeRelations = append(patScopeRelations, relation.Relation{ + Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, + Object: relObject, + RelationName: check.Permission, + }) + patScopeIdx = append(patScopeIdx, i) + } + } + return relations, patScopeRelations, patScopeIdx, nil +} + +// batchCheckPATScope runs a batch scope check and returns the set of denied relation indices. +func (s Service) batchCheckPATScope(ctx context.Context, patScopeRelations []relation.Relation, patScopeIdx []int) (map[int]bool, error) { + scopeResults, err := s.relationService.BatchCheckPermission(ctx, patScopeRelations) + if err != nil { + return nil, err + } + denied := make(map[int]bool, len(scopeResults)) + for j, sr := range scopeResults { + if !sr.Status { + denied[patScopeIdx[j]] = true + } + } + return denied, nil +} + +// batchCheckWithScopeFilter runs user permission checks for scope-allowed items +// and returns merged results where scope-denied items are false. +func (s Service) batchCheckWithScopeFilter(ctx context.Context, relations []relation.Relation, scopeDenied map[int]bool) ([]relation.CheckPair, error) { + var allowedRelations []relation.Relation + var allowedIdx []int + for i, rel := range relations { + if !scopeDenied[i] { + allowedRelations = append(allowedRelations, rel) + allowedIdx = append(allowedIdx, i) + } + } + + results := make([]relation.CheckPair, len(relations)) + for i := range results { + results[i] = relation.CheckPair{Relation: relations[i], Status: false} + } + if len(allowedRelations) > 0 { + userResults, err := s.relationService.BatchCheckPermission(ctx, allowedRelations) + if err != nil { + return nil, err + } + for j, idx := range allowedIdx { + results[idx] = userResults[j] + } } - return s.relationService.BatchCheckPermission(ctx, relations) + return results, nil } func (s Service) Delete(ctx context.Context, namespaceID, id string) error { diff --git a/core/resource/service_test.go b/core/resource/service_test.go new file mode 100644 index 000000000..109d15a29 --- /dev/null +++ b/core/resource/service_test.go @@ -0,0 +1,307 @@ +package resource_test + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/raystack/frontier/core/authenticate" + "github.com/raystack/frontier/core/relation" + "github.com/raystack/frontier/core/resource" + "github.com/raystack/frontier/core/resource/mocks" + patmodels "github.com/raystack/frontier/core/userpat/models" + "github.com/raystack/frontier/internal/bootstrap/schema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func newTestService(t *testing.T) (*mocks.Repository, *mocks.ConfigRepository, *mocks.RelationService, *mocks.AuthnService, *mocks.ProjectService, *mocks.OrgService, *mocks.PATService, *resource.Service) { + t.Helper() + repo := mocks.NewRepository(t) + configRepo := mocks.NewConfigRepository(t) + relationSvc := mocks.NewRelationService(t) + authnSvc := mocks.NewAuthnService(t) + projectSvc := mocks.NewProjectService(t) + orgSvc := mocks.NewOrgService(t) + patSvc := mocks.NewPATService(t) + svc := resource.NewService(repo, configRepo, relationSvc, authnSvc, projectSvc, orgSvc, patSvc) + return repo, configRepo, relationSvc, authnSvc, projectSvc, orgSvc, patSvc, svc +} + +func TestCheckAuthz_NonPAT(t *testing.T) { + ctx := context.Background() + _, _, relationSvc, authnSvc, _, _, _, svc := newTestService(t) + + userID := uuid.New().String() + orgID := uuid.New().String() + + authnSvc.EXPECT().GetPrincipal(ctx, mock.Anything).Return(authenticate.Principal{ + ID: userID, + Type: schema.UserPrincipal, + }, nil).Maybe() + + relationSvc.EXPECT().CheckPermission(ctx, relation.Relation{ + Subject: relation.Subject{ID: userID, Namespace: schema.UserPrincipal}, + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.GetPermission, + }).Return(true, nil) + + result, err := svc.CheckAuthz(ctx, resource.Check{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Permission: schema.GetPermission, + }) + assert.NoError(t, err) + assert.True(t, result) +} + +func TestCheckAuthz_PATScopeAllowed(t *testing.T) { + ctx := context.Background() + _, _, relationSvc, authnSvc, _, _, _, svc := newTestService(t) + + patID := uuid.New().String() + userID := uuid.New().String() + orgID := uuid.New().String() + + authnSvc.EXPECT().GetPrincipal(ctx, mock.Anything).Return(authenticate.Principal{ + ID: patID, + Type: schema.PATPrincipal, + PAT: &patmodels.PAT{ID: patID, UserID: userID}, + }, nil).Maybe() + + // PAT scope check — allowed + relationSvc.EXPECT().CheckPermission(ctx, relation.Relation{ + Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.GetPermission, + }).Return(true, nil) + + // User permission check — allowed + relationSvc.EXPECT().CheckPermission(ctx, relation.Relation{ + Subject: relation.Subject{ID: userID, Namespace: schema.UserPrincipal}, + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.GetPermission, + }).Return(true, nil) + + result, err := svc.CheckAuthz(ctx, resource.Check{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Permission: schema.GetPermission, + }) + assert.NoError(t, err) + assert.True(t, result) +} + +func TestCheckAuthz_PATScopeDenied(t *testing.T) { + ctx := context.Background() + _, _, relationSvc, authnSvc, _, _, _, svc := newTestService(t) + + patID := uuid.New().String() + userID := uuid.New().String() + orgID := uuid.New().String() + + authnSvc.EXPECT().GetPrincipal(ctx, mock.Anything).Return(authenticate.Principal{ + ID: patID, + Type: schema.PATPrincipal, + PAT: &patmodels.PAT{ID: patID, UserID: userID}, + }, nil).Maybe() + + // PAT scope check — denied + relationSvc.EXPECT().CheckPermission(ctx, relation.Relation{ + Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.UpdatePermission, + }).Return(false, nil) + + // User check should NOT be called (early exit) + result, err := svc.CheckAuthz(ctx, resource.Check{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Permission: schema.UpdatePermission, + }) + assert.NoError(t, err) + assert.False(t, result) +} + +func TestCheckAuthz_PATScopeAllowed_UserDenied(t *testing.T) { + ctx := context.Background() + _, _, relationSvc, authnSvc, _, _, _, svc := newTestService(t) + + patID := uuid.New().String() + userID := uuid.New().String() + orgID := uuid.New().String() + + authnSvc.EXPECT().GetPrincipal(ctx, mock.Anything).Return(authenticate.Principal{ + ID: patID, + Type: schema.PATPrincipal, + PAT: &patmodels.PAT{ID: patID, UserID: userID}, + }, nil).Maybe() + + // PAT scope check — allowed + relationSvc.EXPECT().CheckPermission(ctx, relation.Relation{ + Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.DeletePermission, + }).Return(true, nil) + + // User permission check — denied + relationSvc.EXPECT().CheckPermission(ctx, relation.Relation{ + Subject: relation.Subject{ID: userID, Namespace: schema.UserPrincipal}, + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.DeletePermission, + }).Return(false, nil) + + result, err := svc.CheckAuthz(ctx, resource.Check{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Permission: schema.DeletePermission, + }) + assert.NoError(t, err) + assert.False(t, result) +} + +func TestCheckAuthz_ExplicitPATSubject_ScopeAllowed(t *testing.T) { + ctx := context.Background() + _, _, relationSvc, authnSvc, _, _, patSvc, svc := newTestService(t) + + patID := uuid.New().String() + userID := uuid.New().String() + orgID := uuid.New().String() + + // Principal is NOT a PAT (e.g., superuser making federated check with PAT subject) + authnSvc.EXPECT().GetPrincipal(ctx, mock.Anything).Return(authenticate.Principal{ + ID: uuid.New().String(), + Type: schema.UserPrincipal, + }, nil).Maybe() + + // PAT scope check for explicit subject — allowed + relationSvc.EXPECT().CheckPermission(ctx, relation.Relation{ + Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.GetPermission, + }).Return(true, nil) + + // Federated check passes explicit app/pat subject — needs DB lookup + patSvc.EXPECT().GetByID(ctx, patID).Return(patmodels.PAT{ + ID: patID, + UserID: userID, + }, nil) + + // User permission check (resolved from PAT) + relationSvc.EXPECT().CheckPermission(ctx, relation.Relation{ + Subject: relation.Subject{ID: userID, Namespace: schema.UserPrincipal}, + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.GetPermission, + }).Return(true, nil) + + result, err := svc.CheckAuthz(ctx, resource.Check{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, + Permission: schema.GetPermission, + }) + assert.NoError(t, err) + assert.True(t, result) +} + +func TestCheckAuthz_ExplicitPATSubject_ScopeDenied(t *testing.T) { + ctx := context.Background() + _, _, relationSvc, authnSvc, _, _, _, svc := newTestService(t) + + patID := uuid.New().String() + orgID := uuid.New().String() + + // Principal is NOT a PAT (e.g., superuser making federated check with PAT subject) + authnSvc.EXPECT().GetPrincipal(ctx, mock.Anything).Return(authenticate.Principal{ + ID: uuid.New().String(), + Type: schema.UserPrincipal, + }, nil).Maybe() + + // PAT scope check for explicit subject — denied + relationSvc.EXPECT().CheckPermission(ctx, relation.Relation{ + Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + RelationName: schema.UpdatePermission, + }).Return(false, nil) + + // User check should NOT be called — PAT scope denied + result, err := svc.CheckAuthz(ctx, resource.Check{ + Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, + Permission: schema.UpdatePermission, + }) + assert.NoError(t, err) + assert.False(t, result) +} + +func TestBatchCheck_PATScopeAllowed(t *testing.T) { + ctx := context.Background() + _, _, relationSvc, authnSvc, _, _, _, svc := newTestService(t) + + patID := uuid.New().String() + userID := uuid.New().String() + orgID := uuid.New().String() + projID := uuid.New().String() + + authnSvc.EXPECT().GetPrincipal(ctx, mock.Anything).Return(authenticate.Principal{ + ID: patID, + Type: schema.PATPrincipal, + PAT: &patmodels.PAT{ID: patID, UserID: userID}, + }, nil).Maybe() + + checks := []resource.Check{ + {Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, Permission: schema.GetPermission}, + {Object: relation.Object{ID: projID, Namespace: schema.ProjectNamespace}, Permission: schema.GetPermission}, + } + + // PAT scope batch — all allowed + relationSvc.EXPECT().BatchCheckPermission(ctx, []relation.Relation{ + {Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, RelationName: schema.GetPermission}, + {Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, Object: relation.Object{ID: projID, Namespace: schema.ProjectNamespace}, RelationName: schema.GetPermission}, + }).Return([]relation.CheckPair{ + {Status: true}, + {Status: true}, + }, nil) + + // User check batch + relationSvc.EXPECT().BatchCheckPermission(ctx, []relation.Relation{ + {Subject: relation.Subject{ID: userID, Namespace: schema.UserPrincipal}, Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, RelationName: schema.GetPermission}, + {Subject: relation.Subject{ID: userID, Namespace: schema.UserPrincipal}, Object: relation.Object{ID: projID, Namespace: schema.ProjectNamespace}, RelationName: schema.GetPermission}, + }).Return([]relation.CheckPair{ + {Status: true}, + {Status: true}, + }, nil) + + results, err := svc.BatchCheck(ctx, checks) + assert.NoError(t, err) + assert.Len(t, results, 2) + assert.True(t, results[0].Status) + assert.True(t, results[1].Status) +} + +func TestBatchCheck_PATScopeDenied(t *testing.T) { + ctx := context.Background() + _, _, relationSvc, authnSvc, _, _, _, svc := newTestService(t) + + patID := uuid.New().String() + userID := uuid.New().String() + orgID := uuid.New().String() + + authnSvc.EXPECT().GetPrincipal(ctx, mock.Anything).Return(authenticate.Principal{ + ID: patID, + Type: schema.PATPrincipal, + PAT: &patmodels.PAT{ID: patID, UserID: userID}, + }, nil).Maybe() + + checks := []resource.Check{ + {Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, Permission: schema.UpdatePermission}, + } + + // PAT scope batch — denied + relationSvc.EXPECT().BatchCheckPermission(ctx, []relation.Relation{ + {Subject: relation.Subject{ID: patID, Namespace: schema.PATPrincipal}, Object: relation.Object{ID: orgID, Namespace: schema.OrganizationNamespace}, RelationName: schema.UpdatePermission}, + }).Return([]relation.CheckPair{ + {Status: false}, + }, nil) + + // User check should NOT be called — scope-denied items return false directly + results, err := svc.BatchCheck(ctx, checks) + assert.NoError(t, err) + assert.Len(t, results, 1) + assert.False(t, results[0].Status) +} diff --git a/core/userpat/mocks/repository.go b/core/userpat/mocks/repository.go index b4a022075..0e0bfb811 100644 --- a/core/userpat/mocks/repository.go +++ b/core/userpat/mocks/repository.go @@ -139,6 +139,63 @@ func (_c *Repository_Create_Call) RunAndReturn(run func(context.Context, models. return _c } +// GetByID provides a mock function with given fields: ctx, id +func (_m *Repository) GetByID(ctx context.Context, id string) (models.PAT, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 models.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (models.PAT, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) models.PAT); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(models.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type Repository_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) GetByID(ctx interface{}, id interface{}) *Repository_GetByID_Call { + return &Repository_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *Repository_GetByID_Call) Run(run func(ctx context.Context, id string)) *Repository_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_GetByID_Call) Return(_a0 models.PAT, _a1 error) *Repository_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetByID_Call) RunAndReturn(run func(context.Context, string) (models.PAT, error)) *Repository_GetByID_Call { + _c.Call.Return(run) + return _c +} + // GetBySecretHash provides a mock function with given fields: ctx, secretHash func (_m *Repository) GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) { ret := _m.Called(ctx, secretHash) diff --git a/core/userpat/service.go b/core/userpat/service.go index da9ca16e4..4d082e63e 100644 --- a/core/userpat/service.go +++ b/core/userpat/service.go @@ -87,6 +87,10 @@ func (s *Service) ValidateExpiry(expiresAt time.Time) error { return nil } +func (s *Service) GetByID(ctx context.Context, id string) (patmodels.PAT, error) { + return s.repo.GetByID(ctx, id) +} + // Create generates a new PAT and returns it with the plaintext value. // The plaintext value is only available at creation time. func (s *Service) Create(ctx context.Context, req CreateRequest) (patmodels.PAT, string, error) { diff --git a/core/userpat/userpat.go b/core/userpat/userpat.go index 0a3dcc3f7..65befac1b 100644 --- a/core/userpat/userpat.go +++ b/core/userpat/userpat.go @@ -10,6 +10,7 @@ import ( type Repository interface { Create(ctx context.Context, pat models.PAT) (models.PAT, error) CountActive(ctx context.Context, userID, orgID string) (int64, error) + GetByID(ctx context.Context, id string) (models.PAT, error) GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) UpdateLastUsedAt(ctx context.Context, id string, at time.Time) error } diff --git a/internal/store/postgres/userpat_repository.go b/internal/store/postgres/userpat_repository.go index d53f9bf6c..e4efdb222 100644 --- a/internal/store/postgres/userpat_repository.go +++ b/internal/store/postgres/userpat_repository.go @@ -87,6 +87,31 @@ func (r UserPATRepository) CountActive(ctx context.Context, userID, orgID string return count, nil } +func (r UserPATRepository) GetByID(ctx context.Context, id string) (models.PAT, error) { + query, params, err := dialect.From(TABLE_USER_PATS). + Select(&UserPAT{}). + Where( + goqu.Ex{"id": id}, + goqu.Ex{"deleted_at": nil}, + ).Limit(1).ToSQL() + if err != nil { + return models.PAT{}, fmt.Errorf("%w: %w", queryErr, err) + } + + var model UserPAT + if err = r.dbc.WithTimeout(ctx, TABLE_USER_PATS, "GetByID", func(ctx context.Context) error { + return r.dbc.GetContext(ctx, &model, query, params...) + }); err != nil { + err = checkPostgresError(err) + if errors.Is(err, sql.ErrNoRows) { + return models.PAT{}, paterrors.ErrNotFound + } + return models.PAT{}, fmt.Errorf("%w: %w", dbErr, err) + } + + return model.transform() +} + func (r UserPATRepository) GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) { query, params, err := dialect.From(TABLE_USER_PATS). Select(&UserPAT{}). diff --git a/test/e2e/regression/pat_test.go b/test/e2e/regression/pat_test.go new file mode 100644 index 000000000..a4fd698ae --- /dev/null +++ b/test/e2e/regression/pat_test.go @@ -0,0 +1,365 @@ +//go:build !race + +package e2e_test + +import ( + "context" + "os" + "path" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/raystack/frontier/core/authenticate" + testusers "github.com/raystack/frontier/core/authenticate/test_users" + "github.com/raystack/frontier/core/userpat" + "github.com/raystack/frontier/internal/bootstrap/schema" + "github.com/raystack/frontier/pkg/server" + + "github.com/raystack/frontier/config" + "github.com/raystack/frontier/pkg/logger" + frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1" + "github.com/raystack/frontier/test/e2e/testbench" + "github.com/stretchr/testify/suite" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type PATRegressionTestSuite struct { + suite.Suite + testBench *testbench.TestBench + adminCookie string + roleIDs map[string]string // role name -> UUID +} + +func (s *PATRegressionTestSuite) SetupSuite() { + wd, err := os.Getwd() + s.Require().Nil(err) + testDataPath := path.Join("file://", wd, fixturesDir) + + connectPort, err := testbench.GetFreePort() + s.Require().NoError(err) + + appConfig := &config.Frontier{ + Log: logger.Config{ + Level: "error", + }, + App: server.Config{ + Host: "localhost", + Connect: server.ConnectConfig{Port: connectPort}, + ResourcesConfigPath: path.Join(testDataPath, "resource"), + Authentication: authenticate.Config{ + Session: authenticate.SessionConfig{ + HashSecretKey: "hash-secret-should-be-32-chars--", + BlockSecretKey: "hash-secret-should-be-32-chars--", + Validity: time.Hour, + }, + Token: authenticate.TokenConfig{ + RSAPath: "testdata/jwks.json", + Issuer: "frontier", + }, + MailOTP: authenticate.MailOTPConfig{ + Subject: "{{.Otp}}", + Body: "{{.Otp}}", + Validity: 10 * time.Minute, + }, + TestUsers: testusers.Config{Enabled: true, Domain: "raystack.org", OTP: testbench.TestOTP}, + }, + PAT: userpat.Config{Enabled: true, Prefix: "fpt", MaxPerUserPerOrg: 50, MaxLifetime: "8760h"}, + }, + } + + s.testBench, err = testbench.Init(appConfig) + s.Require().NoError(err) + + ctx := context.Background() + + adminCookie, err := testbench.AuthenticateUser(ctx, s.testBench.Client, testbench.OrgAdminEmail) + s.Require().NoError(err) + s.adminCookie = adminCookie + + s.Require().NoError(testbench.BootstrapUsers(ctx, s.testBench.Client, adminCookie)) + s.Require().NoError(testbench.BootstrapOrganizations(ctx, s.testBench.Client, adminCookie)) + s.Require().NoError(testbench.BootstrapProject(ctx, s.testBench.Client, adminCookie)) + s.Require().NoError(testbench.BootstrapGroup(ctx, s.testBench.Client, adminCookie)) + + // build role name → UUID map for PAT creation (requires UUIDs) + ctxAdmin := testbench.ContextWithAuth(ctx, adminCookie) + rolesResp, err := s.testBench.Client.ListRoles(ctxAdmin, connect.NewRequest(&frontierv1beta1.ListRolesRequest{})) + s.Require().NoError(err) + s.roleIDs = make(map[string]string, len(rolesResp.Msg.GetRoles())) + for _, r := range rolesResp.Msg.GetRoles() { + s.roleIDs[r.GetName()] = r.GetId() + } +} + +func (s *PATRegressionTestSuite) TearDownSuite() { + err := s.testBench.Close() + s.Require().NoError(err) +} + +func (s *PATRegressionTestSuite) roleID(name string) string { + id, ok := s.roleIDs[name] + s.Require().True(ok, "role %q not found in platform roles", name) + return id +} + +func getPATCtx(token string) context.Context { + return testbench.ContextWithHeaders(context.Background(), map[string]string{ + "Authorization": "Bearer " + token, + }) +} + +func (s *PATRegressionTestSuite) createOrgAndProjects(ctxAdmin context.Context, orgName, proj1Name, proj2Name string) (string, string, string) { + createOrgResp, err := s.testBench.Client.CreateOrganization(ctxAdmin, connect.NewRequest(&frontierv1beta1.CreateOrganizationRequest{ + Body: &frontierv1beta1.OrganizationRequestBody{ + Name: orgName, + }, + })) + s.Require().NoError(err) + orgID := createOrgResp.Msg.GetOrganization().GetId() + + proj1Resp, err := s.testBench.Client.CreateProject(ctxAdmin, connect.NewRequest(&frontierv1beta1.CreateProjectRequest{ + Body: &frontierv1beta1.ProjectRequestBody{ + Name: proj1Name, + OrgId: orgID, + }, + })) + s.Require().NoError(err) + proj1ID := proj1Resp.Msg.GetProject().GetId() + + var proj2ID string + if proj2Name != "" { + proj2Resp, err := s.testBench.Client.CreateProject(ctxAdmin, connect.NewRequest(&frontierv1beta1.CreateProjectRequest{ + Body: &frontierv1beta1.ProjectRequestBody{ + Name: proj2Name, + OrgId: orgID, + }, + })) + s.Require().NoError(err) + proj2ID = proj2Resp.Msg.GetProject().GetId() + } + + return orgID, proj1ID, proj2ID +} + +func (s *PATRegressionTestSuite) createPAT(ctxAdmin context.Context, orgID, title string, roleIDs, projectIDs []string) (string, string) { + patResp, err := s.testBench.Client.CreateCurrentUserPAT(ctxAdmin, connect.NewRequest(&frontierv1beta1.CreateCurrentUserPATRequest{ + Title: title, + OrgId: orgID, + RoleIds: roleIDs, + ProjectIds: projectIDs, + ExpiresAt: timestamppb.New(time.Now().Add(24 * time.Hour)), + })) + s.Require().NoError(err) + s.Require().NotEmpty(patResp.Msg.GetPat().GetToken()) + return patResp.Msg.GetPat().GetId(), patResp.Msg.GetPat().GetToken() +} + +func (s *PATRegressionTestSuite) checkPermission(ctx context.Context, namespace, id, permission string) bool { + resp, err := s.testBench.Client.CheckResourcePermission(ctx, connect.NewRequest(&frontierv1beta1.CheckResourcePermissionRequest{ + Resource: schema.JoinNamespaceAndResourceID(namespace, id), + Permission: permission, + })) + s.Require().NoError(err) + return resp.Msg.GetStatus() +} + +func (s *PATRegressionTestSuite) TestPATScope_OrgViewer_ProjectViewer() { + ctxAdmin := testbench.ContextWithAuth(context.Background(), s.adminCookie) + orgID, proj1ID, proj2ID := s.createOrgAndProjects(ctxAdmin, "org-pat-ov-pv", "pat-ov-pv-p1", "pat-ov-pv-p2") + + _, patToken := s.createPAT(ctxAdmin, orgID, "pat-ov-pv", + []string{s.roleID(schema.RoleOrganizationViewer), s.roleID(schema.RoleProjectViewer)}, + []string{proj1ID}, + ) + patCtx := getPATCtx(patToken) + + s.Run("org get allowed", func() { + s.Assert().True(s.checkPermission(patCtx, schema.OrganizationNamespace, orgID, schema.GetPermission)) + }) + s.Run("org update denied", func() { + s.Assert().False(s.checkPermission(patCtx, schema.OrganizationNamespace, orgID, schema.UpdatePermission)) + }) + s.Run("scoped project get allowed", func() { + s.Assert().True(s.checkPermission(patCtx, schema.ProjectNamespace, proj1ID, schema.GetPermission)) + }) + s.Run("scoped project update denied", func() { + s.Assert().False(s.checkPermission(patCtx, schema.ProjectNamespace, proj1ID, schema.UpdatePermission)) + }) + s.Run("unscoped project get denied", func() { + s.Assert().False(s.checkPermission(patCtx, schema.ProjectNamespace, proj2ID, schema.GetPermission)) + }) + s.Run("batch check mixed results", func() { + batchResp, err := s.testBench.Client.BatchCheckPermission(patCtx, connect.NewRequest(&frontierv1beta1.BatchCheckPermissionRequest{ + Bodies: []*frontierv1beta1.BatchCheckPermissionBody{ + { + Resource: schema.JoinNamespaceAndResourceID(schema.OrganizationNamespace, orgID), + Permission: schema.GetPermission, + }, + { + Resource: schema.JoinNamespaceAndResourceID(schema.OrganizationNamespace, orgID), + Permission: schema.UpdatePermission, + }, + { + Resource: schema.JoinNamespaceAndResourceID(schema.ProjectNamespace, proj1ID), + Permission: schema.GetPermission, + }, + }, + })) + s.Require().NoError(err) + pairs := batchResp.Msg.GetPairs() + s.Require().Len(pairs, 3) + s.Assert().True(pairs[0].GetStatus(), "org:get should be true") + s.Assert().False(pairs[1].GetStatus(), "org:update should be false") + s.Assert().True(pairs[2].GetStatus(), "proj1:get should be true") + }) +} + +func (s *PATRegressionTestSuite) TestPATScope_OrgOwner() { + ctxAdmin := testbench.ContextWithAuth(context.Background(), s.adminCookie) + orgID, proj1ID, _ := s.createOrgAndProjects(ctxAdmin, "org-pat-oo", "pat-oo-p1", "") + + _, patToken := s.createPAT(ctxAdmin, orgID, "pat-oo", + []string{s.roleID(schema.RoleOrganizationOwner)}, + nil, + ) + patCtx := getPATCtx(patToken) + + s.Run("org get allowed", func() { + s.Assert().True(s.checkPermission(patCtx, schema.OrganizationNamespace, orgID, schema.GetPermission)) + }) + s.Run("org update allowed", func() { + s.Assert().True(s.checkPermission(patCtx, schema.OrganizationNamespace, orgID, schema.UpdatePermission)) + }) + s.Run("project get inherited from org owner", func() { + s.Assert().True(s.checkPermission(patCtx, schema.ProjectNamespace, proj1ID, schema.GetPermission)) + }) + s.Run("project update inherited from org owner", func() { + s.Assert().True(s.checkPermission(patCtx, schema.ProjectNamespace, proj1ID, schema.UpdatePermission)) + }) +} + +func (s *PATRegressionTestSuite) TestPATScope_OrgViewer_AllProjects() { + ctxAdmin := testbench.ContextWithAuth(context.Background(), s.adminCookie) + orgID, proj1ID, proj2ID := s.createOrgAndProjects(ctxAdmin, "org-pat-ov-ap", "pat-ov-ap-p1", "pat-ov-ap-p2") + + _, patToken := s.createPAT(ctxAdmin, orgID, "pat-ov-ap", + []string{s.roleID(schema.RoleOrganizationViewer), s.roleID(schema.RoleProjectOwner)}, + nil, // empty = all projects + ) + patCtx := getPATCtx(patToken) + + s.Run("org get allowed", func() { + s.Assert().True(s.checkPermission(patCtx, schema.OrganizationNamespace, orgID, schema.GetPermission)) + }) + s.Run("org update denied", func() { + s.Assert().False(s.checkPermission(patCtx, schema.OrganizationNamespace, orgID, schema.UpdatePermission)) + }) + s.Run("proj1 update allowed", func() { + s.Assert().True(s.checkPermission(patCtx, schema.ProjectNamespace, proj1ID, schema.UpdatePermission)) + }) + s.Run("proj2 update allowed", func() { + s.Assert().True(s.checkPermission(patCtx, schema.ProjectNamespace, proj2ID, schema.UpdatePermission)) + }) +} + +func (s *PATRegressionTestSuite) TestPATScope_BillingManager() { + ctxAdmin := testbench.ContextWithAuth(context.Background(), s.adminCookie) + orgID, proj1ID, _ := s.createOrgAndProjects(ctxAdmin, "org-pat-bm", "pat-bm-p1", "") + + _, patToken := s.createPAT(ctxAdmin, orgID, "pat-bm", + []string{s.roleID("app_billing_manager")}, + nil, + ) + patCtx := getPATCtx(patToken) + + s.Run("org billingview allowed", func() { + s.Assert().True(s.checkPermission(patCtx, schema.OrganizationNamespace, orgID, schema.BillingViewPermission)) + }) + s.Run("org billingmanage allowed", func() { + s.Assert().True(s.checkPermission(patCtx, schema.OrganizationNamespace, orgID, schema.BillingManagePermission)) + }) + s.Run("org get denied", func() { + s.Assert().False(s.checkPermission(patCtx, schema.OrganizationNamespace, orgID, schema.GetPermission)) + }) + s.Run("org update denied", func() { + s.Assert().False(s.checkPermission(patCtx, schema.OrganizationNamespace, orgID, schema.UpdatePermission)) + }) + s.Run("project get denied", func() { + s.Assert().False(s.checkPermission(patCtx, schema.ProjectNamespace, proj1ID, schema.GetPermission)) + }) +} + +func (s *PATRegressionTestSuite) TestPATScope_Interceptor() { + ctxAdmin := testbench.ContextWithAuth(context.Background(), s.adminCookie) + + createOrgResp, err := s.testBench.Client.CreateOrganization(ctxAdmin, connect.NewRequest(&frontierv1beta1.CreateOrganizationRequest{ + Body: &frontierv1beta1.OrganizationRequestBody{ + Name: "org-pat-interceptor", + }, + })) + s.Require().NoError(err) + orgID := createOrgResp.Msg.GetOrganization().GetId() + + _, patToken := s.createPAT(ctxAdmin, orgID, "pat-interceptor", + []string{s.roleID(schema.RoleOrganizationViewer)}, + nil, + ) + patCtx := getPATCtx(patToken) + + // UpdateOrganization requires update permission — PAT only has viewer scope + _, err = s.testBench.Client.UpdateOrganization(patCtx, connect.NewRequest(&frontierv1beta1.UpdateOrganizationRequest{ + Id: orgID, + Body: &frontierv1beta1.OrganizationRequestBody{ + Name: "org-pat-interceptor", + Title: "updated title", + }, + })) + s.Assert().Error(err) + s.Assert().Equal(connect.CodePermissionDenied, connect.CodeOf(err)) +} + +func (s *PATRegressionTestSuite) TestPATScope_FederatedCheck() { + ctxAdmin := testbench.ContextWithAuth(context.Background(), s.adminCookie) + + createOrgResp, err := s.testBench.Client.CreateOrganization(ctxAdmin, connect.NewRequest(&frontierv1beta1.CreateOrganizationRequest{ + Body: &frontierv1beta1.OrganizationRequestBody{ + Name: "org-pat-federated", + }, + })) + s.Require().NoError(err) + orgID := createOrgResp.Msg.GetOrganization().GetId() + + patID, _ := s.createPAT(ctxAdmin, orgID, "pat-federated", + []string{s.roleID(schema.RoleOrganizationViewer)}, + nil, + ) + + patSubject := schema.JoinNamespaceAndResourceID(schema.PATPrincipal, patID) + orgResource := schema.JoinNamespaceAndResourceID(schema.OrganizationNamespace, orgID) + + s.Run("federated check get allowed", func() { + resp, err := s.testBench.AdminClient.CheckFederatedResourcePermission(ctxAdmin, + connect.NewRequest(&frontierv1beta1.CheckFederatedResourcePermissionRequest{ + Subject: patSubject, + Resource: orgResource, + Permission: schema.GetPermission, + })) + s.Require().NoError(err) + s.Assert().True(resp.Msg.GetStatus()) + }) + s.Run("federated check update denied", func() { + resp, err := s.testBench.AdminClient.CheckFederatedResourcePermission(ctxAdmin, + connect.NewRequest(&frontierv1beta1.CheckFederatedResourcePermissionRequest{ + Subject: patSubject, + Resource: orgResource, + Permission: schema.UpdatePermission, + })) + s.Require().NoError(err) + s.Assert().False(resp.Msg.GetStatus()) + }) +} + +func TestEndToEndPATRegressionTestSuite(t *testing.T) { + suite.Run(t, new(PATRegressionTestSuite)) +} From ab997c66f981916e6d04005ed70b94de8a66492c Mon Sep 17 00:00:00 2001 From: aman Date: Wed, 11 Mar 2026 15:53:05 +0530 Subject: [PATCH 5/5] fix: merge conflict --- core/userpat/mocks/repository.go | 57 ++++++++++++++++++++++++++++++++ core/userpat/userpat.go | 1 + 2 files changed, 58 insertions(+) diff --git a/core/userpat/mocks/repository.go b/core/userpat/mocks/repository.go index b4a022075..0e0bfb811 100644 --- a/core/userpat/mocks/repository.go +++ b/core/userpat/mocks/repository.go @@ -139,6 +139,63 @@ func (_c *Repository_Create_Call) RunAndReturn(run func(context.Context, models. return _c } +// GetByID provides a mock function with given fields: ctx, id +func (_m *Repository) GetByID(ctx context.Context, id string) (models.PAT, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for GetByID") + } + + var r0 models.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (models.PAT, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string) models.PAT); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(models.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_GetByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetByID' +type Repository_GetByID_Call struct { + *mock.Call +} + +// GetByID is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) GetByID(ctx interface{}, id interface{}) *Repository_GetByID_Call { + return &Repository_GetByID_Call{Call: _e.mock.On("GetByID", ctx, id)} +} + +func (_c *Repository_GetByID_Call) Run(run func(ctx context.Context, id string)) *Repository_GetByID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_GetByID_Call) Return(_a0 models.PAT, _a1 error) *Repository_GetByID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_GetByID_Call) RunAndReturn(run func(context.Context, string) (models.PAT, error)) *Repository_GetByID_Call { + _c.Call.Return(run) + return _c +} + // GetBySecretHash provides a mock function with given fields: ctx, secretHash func (_m *Repository) GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) { ret := _m.Called(ctx, secretHash) diff --git a/core/userpat/userpat.go b/core/userpat/userpat.go index 0a3dcc3f7..65befac1b 100644 --- a/core/userpat/userpat.go +++ b/core/userpat/userpat.go @@ -10,6 +10,7 @@ import ( type Repository interface { Create(ctx context.Context, pat models.PAT) (models.PAT, error) CountActive(ctx context.Context, userID, orgID string) (int64, error) + GetByID(ctx context.Context, id string) (models.PAT, error) GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error) UpdateLastUsedAt(ctx context.Context, id string, at time.Time) error }