Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions internal/api/oauthserver/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er
}

baseURL := s.buildAuthorizationURL(config.SiteURL, config.OAuthServer.AuthorizationPath)
redirectURL := fmt.Sprintf("%s?authorization_id=%s", baseURL, authorization.AuthorizationID)
redirectURL := s.buildAuthorizationRedirectURL(baseURL, authorization.AuthorizationID)

http.Redirect(w, r, redirectURL, http.StatusFound)
return nil
Expand Down Expand Up @@ -622,8 +622,13 @@ func (s *Server) buildErrorRedirectURL(redirectURI, errorCode, errorDescription,
return u.String()
}

// buildAuthorizationURL safely joins a base URL with a path, handling slashes correctly
// buildAuthorizationURL safely joins a base URL with a path, handling slashes correctly.
// If pathToJoin is an absolute URL, it is returned as-is.
func (s *Server) buildAuthorizationURL(baseURL, pathToJoin string) string {
if parsed, err := url.Parse(pathToJoin); err == nil && parsed.IsAbs() {
return parsed.String()
}

// Trim trailing slash from baseURL
baseURL = strings.TrimRight(baseURL, "/")

Expand All @@ -634,3 +639,16 @@ func (s *Server) buildAuthorizationURL(baseURL, pathToJoin string) string {

return baseURL + pathToJoin
}

// buildAuthorizationRedirectURL appends authorization_id while preserving existing query params/fragments.
func (s *Server) buildAuthorizationRedirectURL(baseURL, authorizationID string) string {
u, err := url.Parse(baseURL)
if err != nil {
return fmt.Sprintf("%s?authorization_id=%s", baseURL, authorizationID)
}

q := u.Query()
q.Set("authorization_id", authorizationID)
u.RawQuery = q.Encode()
return u.String()
}
76 changes: 76 additions & 0 deletions internal/api/oauthserver/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,82 @@ func TestValidateRequestOrigin(t *testing.T) {
}
}

func TestBuildAuthorizationURL(t *testing.T) {
server := &Server{}

tests := []struct {
name string
baseURL string
pathToJoin string
expected string
}{
{
name: "joins relative path with leading slash",
baseURL: "https://example.com",
pathToJoin: "/oauth/consent",
expected: "https://example.com/oauth/consent",
},
{
name: "joins relative path without leading slash",
baseURL: "https://example.com/",
pathToJoin: "oauth/consent",
expected: "https://example.com/oauth/consent",
},
{
name: "returns absolute path unchanged",
baseURL: "https://example.com",
pathToJoin: "https://app.example.com/custom-consent",
expected: "https://app.example.com/custom-consent",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := server.buildAuthorizationURL(tt.baseURL, tt.pathToJoin)
assert.Equal(t, tt.expected, actual)
})
}
}



func TestBuildAuthorizationRedirectURL(t *testing.T) {
server := &Server{}

tests := []struct {
name string
baseURL string
authorizationID string
expected string
}{
{
name: "adds authorization_id to URL without query",
baseURL: "https://app.example.com/custom-consent",
authorizationID: "auth-123",
expected: "https://app.example.com/custom-consent?authorization_id=auth-123",
},
{
name: "preserves existing query parameters",
baseURL: "https://app.example.com/custom-consent?foo=bar",
authorizationID: "auth-123",
expected: "https://app.example.com/custom-consent?authorization_id=auth-123&foo=bar",
},
{
name: "preserves fragment",
baseURL: "https://app.example.com/custom-consent?foo=bar#frag",
authorizationID: "auth-123",
expected: "https://app.example.com/custom-consent?authorization_id=auth-123&foo=bar#frag",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := server.buildAuthorizationRedirectURL(tt.baseURL, tt.authorizationID)
assert.Equal(t, tt.expected, actual)
})
}
}

func TestValidateRequestOriginEdgeCases(t *testing.T) {
globalConfig, err := conf.LoadGlobal(oauthServerTestConfig)
require.NoError(t, err)
Expand Down