Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions services/proxy/pkg/staticroutes/backchannellogout.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@ func (s *StaticRouteHandler) backchannelLogout(w http.ResponseWriter, r *http.Re
continue
}

if err := s.publishBackchannelLogoutEvent(r.Context(), session, value); err != nil {
s.Logger.Warn().Err(err).Msgf("failed to publish backchannel logout event for: %s", key)
continue
if requestSubjectAndSession.Mode() == bcl.LogoutModeSession {
if err := s.publishBackchannelLogoutEvent(r.Context(), session, value); err != nil {
s.Logger.Warn().Err(err).Msgf("failed to publish backchannel logout event for: %s", key)
continue
}
}

err = s.UserInfoCache.Delete(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ func NewKey(subject, session string) (string, error) {
return subjectSession, nil
}

// LogoutMode defines the mode of backchannel logout, either by session or by subject
type LogoutMode int

const (
// LogoutModeUndefined is used when the logout mode cannot be determined
LogoutModeUndefined LogoutMode = iota
// LogoutModeSubject is used when the logout mode is determined by the subject
LogoutModeSubject
// LogoutModeSession is used when the logout mode is determined by the session id
LogoutModeSession
)

// ErrDecoding is returned when decoding fails
var ErrDecoding = errors.New("failed to decode")

Expand Down Expand Up @@ -62,6 +74,18 @@ func (suse SuSe) Session() (string, error) {
return string(subject), nil
}

// Mode determines the backchannel logout mode based on the presence of subject and session
func (suse SuSe) Mode() LogoutMode {
switch {
case suse.encodedSession == "" && suse.encodedSubject != "":
return LogoutModeSubject
case suse.encodedSession != "":
return LogoutModeSession
default:
return LogoutModeUndefined
}
}

// ErrInvalidSubjectOrSession is returned when the provided key does not match the expected key format
var ErrInvalidSubjectOrSession = errors.New("invalid subject or session")

Expand Down Expand Up @@ -91,31 +115,11 @@ func NewSuSe(key string) (SuSe, error) {
return suse, errors.Join(ErrInvalidSubjectOrSession, err)
}

return suse, nil
}

// logoutMode defines the mode of backchannel logout, either by session or by subject
type logoutMode int

const (
// logoutModeUndefined is used when the logout mode cannot be determined
logoutModeUndefined logoutMode = iota
// logoutModeSubject is used when the logout mode is determined by the subject
logoutModeSubject
// logoutModeSession is used when the logout mode is determined by the session id
logoutModeSession
)

// getLogoutMode determines the backchannel logout mode based on the presence of subject and session in the SuSe struct
func getLogoutMode(suse SuSe) logoutMode {
switch {
case suse.encodedSession == "" && suse.encodedSubject != "":
return logoutModeSubject
case suse.encodedSession != "":
return logoutModeSession
default:
return logoutModeUndefined
if mode := suse.Mode(); mode == LogoutModeUndefined {
return suse, ErrInvalidSubjectOrSession
}

return suse, nil
}

// ErrSuspiciousCacheResult is returned when the cache result is suspicious
Expand All @@ -126,18 +130,15 @@ var ErrSuspiciousCacheResult = errors.New("suspicious cache result")
// it uses a seperator to prevent sufix and prefix exploration in the cache and checks
// if the retrieved records match the requested subject and or session id as well, to prevent false positives.
func GetLogoutRecords(suse SuSe, store microstore.Store) ([]*microstore.Record, error) {
// get subject.session mode
mode := getLogoutMode(suse)

var key string
var opts []microstore.ReadOption
switch {
case mode == logoutModeSubject && suse.encodedSubject != "":
case suse.Mode() == LogoutModeSubject && suse.encodedSubject != "":
// the dot at the end prevents prefix exploration in the cache,
// so only keys that start with 'subject.*' will be returned, but not 'sub*'.
key = suse.encodedSubject + "."
opts = append(opts, microstore.ReadPrefix())
case mode == logoutModeSession && suse.encodedSession != "":
case suse.Mode() == LogoutModeSession && suse.encodedSession != "":
// the dot at the beginning prevents sufix exploration in the cache,
// so only keys that end with '*.session' will be returned, but not '*sion'.
key = "." + suse.encodedSession
Expand All @@ -156,7 +157,7 @@ func GetLogoutRecords(suse SuSe, store microstore.Store) ([]*microstore.Record,
return nil, microstore.ErrNotFound
}

if mode == logoutModeSession && len(records) > 1 {
if suse.Mode() == LogoutModeSession && len(records) > 1 {
return nil, errors.Join(errors.New("multiple session records found"), ErrSuspiciousCacheResult)
}

Expand All @@ -171,10 +172,10 @@ func GetLogoutRecords(suse SuSe, store microstore.Store) ([]*microstore.Record,

switch {
// in subject mode, the subject must match, but the session id can be different
case mode == logoutModeSubject && suse.encodedSubject == recordSuSe.encodedSubject:
case suse.Mode() == LogoutModeSubject && suse.encodedSubject == recordSuSe.encodedSubject:
continue
// in session mode, the session id must match, but the subject can be different
case mode == logoutModeSession && suse.encodedSession == recordSuSe.encodedSession:
case suse.Mode() == LogoutModeSession && suse.encodedSession == recordSuSe.encodedSession:
continue
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,33 @@ func TestNewSuSe(t *testing.T) {
key string
wantSubject string
wantSession string
wantMode LogoutMode
wantErr error
}{
{
name: "key variation: '.session'",
key: mustNewKey(t, "", "session"),
wantSession: "session",
wantMode: LogoutModeSession,
},
{
name: "key variation: 'session'",
key: mustNewKey(t, "", "session"),
wantSession: "session",
wantMode: LogoutModeSession,
},
{
name: "key variation: 'subject.'",
key: mustNewKey(t, "subject", ""),
wantSubject: "subject",
wantMode: LogoutModeSubject,
},
{
name: "key variation: 'subject.session'",
key: mustNewKey(t, "subject", "session"),
wantSubject: "subject",
wantSession: "session",
wantMode: LogoutModeSession,
},
{
name: "key variation: 'dot'",
Expand All @@ -103,19 +108,22 @@ func TestNewSuSe(t *testing.T) {
wantErr: ErrInvalidSubjectOrSession,
},
{
name: "key variation: string('subject.session')",
key: "subject.session",
wantErr: ErrInvalidSubjectOrSession,
name: "key variation: string('subject.session')",
key: "subject.session",
wantErr: ErrInvalidSubjectOrSession,
wantMode: LogoutModeSession,
},
{
name: "key variation: string('subject.')",
key: "subject.",
wantErr: ErrInvalidSubjectOrSession,
name: "key variation: string('subject.')",
key: "subject.",
wantErr: ErrInvalidSubjectOrSession,
wantMode: LogoutModeSubject,
},
{
name: "key variation: string('.session')",
key: ".session",
wantErr: ErrInvalidSubjectOrSession,
name: "key variation: string('.session')",
key: ".session",
wantErr: ErrInvalidSubjectOrSession,
wantMode: LogoutModeSession,
},
}

Expand All @@ -124,6 +132,9 @@ func TestNewSuSe(t *testing.T) {
suSe, err := NewSuSe(tt.key)
require.ErrorIs(t, err, tt.wantErr)

mode := suSe.Mode()
require.Equal(t, tt.wantMode, mode)

subject, _ := suSe.Subject()
require.Equal(t, tt.wantSubject, subject)

Expand All @@ -133,42 +144,6 @@ func TestNewSuSe(t *testing.T) {
}
}

func TestGetLogoutMode(t *testing.T) {
tests := []struct {
name string
suSe SuSe
want logoutMode
}{
{
name: "key variation: '.session'",
suSe: mustNewSuSe(t, "", "session"),
want: logoutModeSession,
},
{
name: "key variation: 'subject.session'",
suSe: mustNewSuSe(t, "subject", "session"),
want: logoutModeSession,
},
{
name: "key variation: 'subject.'",
suSe: mustNewSuSe(t, "subject", ""),
want: logoutModeSubject,
},
{
name: "key variation: 'empty'",
suSe: SuSe{},
want: logoutModeUndefined,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mode := getLogoutMode(tt.suSe)
require.Equal(t, tt.want, mode)
})
}
}

func TestGetLogoutRecords(t *testing.T) {
sessionStore := store.NewMemoryStore()

Expand Down