Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ test-clean:
go clean -testcache

test: test-clean
go test -run=$(TEST) $(TEST_FLAGS) -json ./... | tparse --all --follow
go test -run=$(TEST) $(TEST_FLAGS) -json ./... | go run github.com/mfridman/tparse --all --follow

test-rerun: test-clean
go run github.com/goware/rerun/cmd/rerun -watch ./ -run 'make test'

test-coverage:
go test -run=$(TEST) $(TEST_FLAGS) -cover -coverprofile=coverage.out -json ./... | tparse --all --follow
go test -run=$(TEST) $(TEST_FLAGS) -cover -coverprofile=coverage.out -json ./... | go run github.com/mfridman/tparse --all --follow

test-coverage-inspect: test-coverage
go tool cover -html=coverage.out
Expand Down
33 changes: 1 addition & 32 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,37 +104,6 @@ func (c Config[any]) Verify(webrpcServices map[string][]string) error {
return errors.Join(errList...)
}

// ACL is a list of session types, encoded as a bitfield.
// SessionType(n) is represented by n=-the bit.
type ACL uint64

// NewACL returns a new ACL with the given session types.
func NewACL(sessions ...proto.SessionType) ACL {
var acl ACL
for _, v := range sessions {
acl = acl.And(v)
}
return acl
}

// And returns a new ACL with the given session types added.
func (a ACL) And(session ...proto.SessionType) ACL {
for _, v := range session {
a |= 1 << v
}
return a
}

// Includes returns true if the ACL includes the given session type.
func (t ACL) Includes(session proto.SessionType) bool {
return t&ACL(1<<session) != 0
}

// NewAuth creates a new Auth HS256 with the given secret.
func NewAuth(secret string) *Auth {
return &Auth{Algorithm: jwa.HS256, Private: []byte(secret)}
}

// Auth is a struct that holds the private and public keys for JWT signing and verification.
type Auth struct {
Algorithm jwa.SignatureAlgorithm
Expand All @@ -145,7 +114,7 @@ type Auth struct {
// GetVerifier returns a JWTAuth using the private secret when available, otherwise the public key
func (a Auth) GetVerifier(options ...jwt.ValidateOption) (*jwtauth.JWTAuth, error) {
if a.Algorithm == "" {
return nil, fmt.Errorf("missing algorithm")
a.Algorithm = jwa.HS256
}

if a.Private != nil {
Expand Down
26 changes: 13 additions & 13 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,29 +76,29 @@ func TestVerify(t *testing.T) {
}

// Valid ACL config
acl := authcontrol.Config[any]{
acl := authcontrol.Config[proto.SessionTypes]{
"Service1": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method2": authcontrol.NewACL(proto.SessionType_User),
"Method3": authcontrol.NewACL(proto.SessionType_User),
"Method1": proto.NewSessionTypes(proto.SessionType_User),
"Method2": proto.NewSessionTypes(proto.SessionType_User),
"Method3": proto.NewSessionTypes(proto.SessionType_User),
},
"Service2": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method1": proto.NewSessionTypes(proto.SessionType_User),
},
}

err := acl.Verify(services)
assert.NoError(t, err)

// Wrong Service
acl = authcontrol.Config[any]{
acl = authcontrol.Config[proto.SessionTypes]{
"WrongService1": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method2": authcontrol.NewACL(proto.SessionType_User),
"Method3": authcontrol.NewACL(proto.SessionType_User),
"Method1": proto.NewSessionTypes(proto.SessionType_User),
"Method2": proto.NewSessionTypes(proto.SessionType_User),
"Method3": proto.NewSessionTypes(proto.SessionType_User),
},
"Service2": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method1": proto.NewSessionTypes(proto.SessionType_User),
},
}

Expand All @@ -113,12 +113,12 @@ func TestVerify(t *testing.T) {
assert.Equal(t, errors.Join(expectedErrors...).Error(), err.Error())

// Wrong Methods
acl = authcontrol.Config[any]{
acl = authcontrol.Config[proto.SessionTypes]{
"Service1": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method1": proto.NewSessionTypes(proto.SessionType_User),
},
"Service2": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method1": proto.NewSessionTypes(proto.SessionType_User),
},
}

Expand Down
30 changes: 30 additions & 0 deletions go.work.sum
Original file line number Diff line number Diff line change
@@ -1,17 +1,43 @@
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/alecthomas/kingpin/v2 v2.4.0/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE=
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/charmbracelet/colorprofile v0.3.1 h1:k8dTHMd7fgw4bnFd7jXTLZrSU/CQrKnL3m+AxCzDz40=
github.com/charmbracelet/colorprofile v0.3.1/go.mod h1:/GkGusxNs8VB/RSOh3fu0TJmQ4ICMMPApIIVn0KszZ0=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.9.2 h1:92AGsQmNTRMzuzHEYfCdjQeUzTrgE1vfO5/7fEVoXdY=
github.com/charmbracelet/x/ansi v0.9.2/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13 h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k=
github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mfridman/tparse v0.18.0 h1:wh6dzOKaIwkUGyKgOntDW4liXSo37qg5AXbIhkMV3vE=
github.com/mfridman/tparse v0.18.0/go.mod h1:gEvqZTuCgEhPbYk/2lS3Kcxg1GmTxxU7kTC8DvP0i/A=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/webrpc/gen-typescript v0.16.1/go.mod h1:xQzYnVaSMfcygDXA5SuW8eYyCLHBHkj15wCF7gcJF5Y=
github.com/webrpc/webrpc v0.22.0/go.mod h1:eeABnLz9BC4F9GGw6UKebVPkzkFYLrZRlcOvh6o8n10=
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
Expand All @@ -21,9 +47,13 @@ golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457/go.mod h1:pRgIJT+bRLFKnoM1ldnzKoxTIn14Yxz928LQRYYgIN0=
golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
28 changes: 21 additions & 7 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,18 @@ type Options struct {
// It is used to validate the `scope` claim for admin sessions.
ServiceName string

// JWTsecret is required, and it is used for the JWT verification.
// If a Project Store is also provided and the request has a project claim,
// it could be replaced by the a specific verifier.
// JWTSecret is used to create the default Auth (HS256) when Auth is not provided.
//
// Deprecated: Use Auth: Auth{Private: []byte(Secret)} instead.
JWTSecret string

// ProjectStore is a pluggable backends that verifies if the project from the claim exists.
// When provived, it checks the Project from the JWT, and can override the JWT Auth.
// Auth is the JWT verifier. If not provided, it is created from JWTSecret.
// If a ProjectStore is also provided and the request has a project claim,
// it can be overridden by a project-specific Auth.
Auth *Auth

// ProjectStore is a pluggable backend that verifies if the project from the claim exists.
// When provided, it checks the Project from the JWT, and can override the JWT Auth.
ProjectStore ProjectStore

// AccessKeyFuncs are used to extract the access key from the request.
Expand Down Expand Up @@ -62,6 +67,11 @@ func (o *Options) ApplyDefaults() {
if o.ErrHandler == nil {
o.ErrHandler = errHandler
}

// Create default Auth from JWTSecret if not provided
if o.Auth == nil {
o.Auth = &Auth{Private: []byte(o.JWTSecret)}
}
}

func VerifyToken(cfg Options) func(next http.Handler) http.Handler {
Expand All @@ -74,7 +84,7 @@ func VerifyToken(cfg Options) func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

auth := NewAuth(cfg.JWTSecret)
auth := cfg.Auth

if cfg.ProjectStore != nil {
projectID, err := findProjectClaim(r)
Expand Down Expand Up @@ -311,9 +321,13 @@ func Session(cfg Options) func(next http.Handler) http.Handler {
}
}

type ACL interface {
Get(ctx context.Context, path string) (proto.SessionTypes, bool)
}

// AccessControl middleware that checks if the session type is allowed to access the endpoint.
// It also sets the compute units on the context if the endpoint requires it.
func AccessControl(acl Config[ACL], cfg Options) func(next http.Handler) http.Handler {
func AccessControl(acl ACL, cfg Options) func(next http.Handler) http.Handler {
cfg.ApplyDefaults()

return func(next http.Handler) http.Handler {
Expand Down
Loading