diff --git a/services/proxy/pkg/staticroutes/backchannellogout.go b/services/proxy/pkg/staticroutes/backchannellogout.go index 53375d63bb..28ad2f767e 100644 --- a/services/proxy/pkg/staticroutes/backchannellogout.go +++ b/services/proxy/pkg/staticroutes/backchannellogout.go @@ -107,9 +107,11 @@ func (s *StaticRouteHandler) backchannelLogout(w http.ResponseWriter, r *http.Re continue } - if err := s.publishBackchannelLogoutEvent(r.Context(), session, value); err != nil { - s.Logger.Warn().Err(err).Msgf("failed to publish backchannel logout event for: %s", key) - continue + if requestSubjectAndSession.Mode() == bcl.LogoutModeSession { + if err := s.publishBackchannelLogoutEvent(r.Context(), session, value); err != nil { + s.Logger.Warn().Err(err).Msgf("failed to publish backchannel logout event for: %s", key) + continue + } } err = s.UserInfoCache.Delete(value) diff --git a/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go index fa29c24fb3..1d906233cd 100644 --- a/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go +++ b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout.go @@ -31,6 +31,18 @@ func NewKey(subject, session string) (string, error) { return subjectSession, nil } +// LogoutMode defines the mode of backchannel logout, either by session or by subject +type LogoutMode int + +const ( + // LogoutModeUndefined is used when the logout mode cannot be determined + LogoutModeUndefined LogoutMode = iota + // LogoutModeSubject is used when the logout mode is determined by the subject + LogoutModeSubject + // LogoutModeSession is used when the logout mode is determined by the session id + LogoutModeSession +) + // ErrDecoding is returned when decoding fails var ErrDecoding = errors.New("failed to decode") @@ -62,6 +74,18 @@ func (suse SuSe) Session() (string, error) { return string(subject), nil } +// Mode determines the backchannel logout mode based on the presence of subject and session +func (suse SuSe) Mode() LogoutMode { + switch { + case suse.encodedSession == "" && suse.encodedSubject != "": + return LogoutModeSubject + case suse.encodedSession != "": + return LogoutModeSession + default: + return LogoutModeUndefined + } +} + // ErrInvalidSubjectOrSession is returned when the provided key does not match the expected key format var ErrInvalidSubjectOrSession = errors.New("invalid subject or session") @@ -91,31 +115,11 @@ func NewSuSe(key string) (SuSe, error) { return suse, errors.Join(ErrInvalidSubjectOrSession, err) } - return suse, nil -} - -// logoutMode defines the mode of backchannel logout, either by session or by subject -type logoutMode int - -const ( - // logoutModeUndefined is used when the logout mode cannot be determined - logoutModeUndefined logoutMode = iota - // logoutModeSubject is used when the logout mode is determined by the subject - logoutModeSubject - // logoutModeSession is used when the logout mode is determined by the session id - logoutModeSession -) - -// getLogoutMode determines the backchannel logout mode based on the presence of subject and session in the SuSe struct -func getLogoutMode(suse SuSe) logoutMode { - switch { - case suse.encodedSession == "" && suse.encodedSubject != "": - return logoutModeSubject - case suse.encodedSession != "": - return logoutModeSession - default: - return logoutModeUndefined + if mode := suse.Mode(); mode == LogoutModeUndefined { + return suse, ErrInvalidSubjectOrSession } + + return suse, nil } // ErrSuspiciousCacheResult is returned when the cache result is suspicious @@ -126,18 +130,15 @@ var ErrSuspiciousCacheResult = errors.New("suspicious cache result") // it uses a seperator to prevent sufix and prefix exploration in the cache and checks // if the retrieved records match the requested subject and or session id as well, to prevent false positives. func GetLogoutRecords(suse SuSe, store microstore.Store) ([]*microstore.Record, error) { - // get subject.session mode - mode := getLogoutMode(suse) - var key string var opts []microstore.ReadOption switch { - case mode == logoutModeSubject && suse.encodedSubject != "": + case suse.Mode() == LogoutModeSubject && suse.encodedSubject != "": // the dot at the end prevents prefix exploration in the cache, // so only keys that start with 'subject.*' will be returned, but not 'sub*'. key = suse.encodedSubject + "." opts = append(opts, microstore.ReadPrefix()) - case mode == logoutModeSession && suse.encodedSession != "": + case suse.Mode() == LogoutModeSession && suse.encodedSession != "": // the dot at the beginning prevents sufix exploration in the cache, // so only keys that end with '*.session' will be returned, but not '*sion'. key = "." + suse.encodedSession @@ -156,7 +157,7 @@ func GetLogoutRecords(suse SuSe, store microstore.Store) ([]*microstore.Record, return nil, microstore.ErrNotFound } - if mode == logoutModeSession && len(records) > 1 { + if suse.Mode() == LogoutModeSession && len(records) > 1 { return nil, errors.Join(errors.New("multiple session records found"), ErrSuspiciousCacheResult) } @@ -171,10 +172,10 @@ func GetLogoutRecords(suse SuSe, store microstore.Store) ([]*microstore.Record, switch { // in subject mode, the subject must match, but the session id can be different - case mode == logoutModeSubject && suse.encodedSubject == recordSuSe.encodedSubject: + case suse.Mode() == LogoutModeSubject && suse.encodedSubject == recordSuSe.encodedSubject: continue // in session mode, the session id must match, but the subject can be different - case mode == logoutModeSession && suse.encodedSession == recordSuSe.encodedSession: + case suse.Mode() == LogoutModeSession && suse.encodedSession == recordSuSe.encodedSession: continue } diff --git a/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go index 617bd6d9e0..40f99f6a4c 100644 --- a/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go +++ b/services/proxy/pkg/staticroutes/internal/backchannellogout/backchannellogout_test.go @@ -69,28 +69,33 @@ func TestNewSuSe(t *testing.T) { key string wantSubject string wantSession string + wantMode LogoutMode wantErr error }{ { name: "key variation: '.session'", key: mustNewKey(t, "", "session"), wantSession: "session", + wantMode: LogoutModeSession, }, { name: "key variation: 'session'", key: mustNewKey(t, "", "session"), wantSession: "session", + wantMode: LogoutModeSession, }, { name: "key variation: 'subject.'", key: mustNewKey(t, "subject", ""), wantSubject: "subject", + wantMode: LogoutModeSubject, }, { name: "key variation: 'subject.session'", key: mustNewKey(t, "subject", "session"), wantSubject: "subject", wantSession: "session", + wantMode: LogoutModeSession, }, { name: "key variation: 'dot'", @@ -103,19 +108,22 @@ func TestNewSuSe(t *testing.T) { wantErr: ErrInvalidSubjectOrSession, }, { - name: "key variation: string('subject.session')", - key: "subject.session", - wantErr: ErrInvalidSubjectOrSession, + name: "key variation: string('subject.session')", + key: "subject.session", + wantErr: ErrInvalidSubjectOrSession, + wantMode: LogoutModeSession, }, { - name: "key variation: string('subject.')", - key: "subject.", - wantErr: ErrInvalidSubjectOrSession, + name: "key variation: string('subject.')", + key: "subject.", + wantErr: ErrInvalidSubjectOrSession, + wantMode: LogoutModeSubject, }, { - name: "key variation: string('.session')", - key: ".session", - wantErr: ErrInvalidSubjectOrSession, + name: "key variation: string('.session')", + key: ".session", + wantErr: ErrInvalidSubjectOrSession, + wantMode: LogoutModeSession, }, } @@ -124,6 +132,9 @@ func TestNewSuSe(t *testing.T) { suSe, err := NewSuSe(tt.key) require.ErrorIs(t, err, tt.wantErr) + mode := suSe.Mode() + require.Equal(t, tt.wantMode, mode) + subject, _ := suSe.Subject() require.Equal(t, tt.wantSubject, subject) @@ -133,42 +144,6 @@ func TestNewSuSe(t *testing.T) { } } -func TestGetLogoutMode(t *testing.T) { - tests := []struct { - name string - suSe SuSe - want logoutMode - }{ - { - name: "key variation: '.session'", - suSe: mustNewSuSe(t, "", "session"), - want: logoutModeSession, - }, - { - name: "key variation: 'subject.session'", - suSe: mustNewSuSe(t, "subject", "session"), - want: logoutModeSession, - }, - { - name: "key variation: 'subject.'", - suSe: mustNewSuSe(t, "subject", ""), - want: logoutModeSubject, - }, - { - name: "key variation: 'empty'", - suSe: SuSe{}, - want: logoutModeUndefined, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mode := getLogoutMode(tt.suSe) - require.Equal(t, tt.want, mode) - }) - } -} - func TestGetLogoutRecords(t *testing.T) { sessionStore := store.NewMemoryStore()