From 0ef5a75912a714f4a61205d91ffed82f7e5dd6b8 Mon Sep 17 00:00:00 2001 From: Atharva Patil <53966412+atharva1051@users.noreply.github.com> Date: Thu, 19 Feb 2026 12:15:35 +0000 Subject: [PATCH 1/7] chore: gh-host as oauth auth server --- pkg/http/server.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/http/server.go b/pkg/http/server.go index 7397e54a8..a7dd0b2c9 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -126,6 +126,7 @@ func RunHTTPServer(cfg ServerConfig) error { oauthCfg := &oauth.Config{ BaseURL: cfg.BaseURL, ResourcePath: cfg.ResourcePath, + AuthorizationServer: cfg.Host + "login/oauth", } serverOptions := []HandlerOption{} From ee63c5146eef55891b070fe79850cda2ef2fd701 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 19 Feb 2026 12:25:34 +0000 Subject: [PATCH 2/7] Initial plan From 3faac62ccff69ec9d6689d04316779a101324d41 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 19 Feb 2026 12:29:29 +0000 Subject: [PATCH 3/7] fix: only set AuthorizationServer when gh-host is explicitly configured When gh-host is not set (github.com users), the AuthorizationServer field was being set to just "login/oauth" (empty host + path), breaking OAuth metadata. Now it only overrides the default when gh-host is provided (GHES users), allowing github.com users to get the correct default of "https://github.com/login/oauth". Co-authored-by: atharva1051 <53966412+atharva1051@users.noreply.github.com> --- pkg/http/server.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/http/server.go b/pkg/http/server.go index a7dd0b2c9..25a7ae1d4 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -126,7 +126,9 @@ func RunHTTPServer(cfg ServerConfig) error { oauthCfg := &oauth.Config{ BaseURL: cfg.BaseURL, ResourcePath: cfg.ResourcePath, - AuthorizationServer: cfg.Host + "login/oauth", + } + if cfg.Host != "" { + oauthCfg.AuthorizationServer = cfg.Host + "login/oauth" } serverOptions := []HandlerOption{} From f5b8ecc600939bb72b8a689011e5b1aa8225a59f Mon Sep 17 00:00:00 2001 From: Atharva Patil <53966412+atharva1051@users.noreply.github.com> Date: Thu, 19 Feb 2026 13:53:08 +0000 Subject: [PATCH 4/7] fix: joining url using net/url --- pkg/http/server.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pkg/http/server.go b/pkg/http/server.go index 25a7ae1d4..a35ce7a60 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -6,6 +6,7 @@ import ( "io" "log/slog" "net/http" + "net/url" "os" "os/signal" "slices" @@ -128,7 +129,12 @@ func RunHTTPServer(cfg ServerConfig) error { ResourcePath: cfg.ResourcePath, } if cfg.Host != "" { - oauthCfg.AuthorizationServer = cfg.Host + "login/oauth" + u := &url.URL{ + Scheme: "https", + Host: cfg.Host, + Path: "/login/oauth", + } + oauthCfg.AuthorizationServer = u.String() } serverOptions := []HandlerOption{} From 412c6c99b3a62f5c3f4093d42ac0b20c3c4b1517 Mon Sep 17 00:00:00 2001 From: Atharva Patil <53966412+atharva1051@users.noreply.github.com> Date: Mon, 23 Feb 2026 17:11:02 +0000 Subject: [PATCH 5/7] fix: extended api host to have a oauth url --- pkg/http/server.go | 10 ++- pkg/scopes/fetcher_test.go | 3 + pkg/utils/api.go | 24 +++++++ pkg/utils/api_test.go | 140 +++++++++++++++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 6 deletions(-) create mode 100644 pkg/utils/api_test.go diff --git a/pkg/http/server.go b/pkg/http/server.go index a35ce7a60..ae0bb46d5 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -6,7 +6,6 @@ import ( "io" "log/slog" "net/http" - "net/url" "os" "os/signal" "slices" @@ -129,12 +128,11 @@ func RunHTTPServer(cfg ServerConfig) error { ResourcePath: cfg.ResourcePath, } if cfg.Host != "" { - u := &url.URL{ - Scheme: "https", - Host: cfg.Host, - Path: "/login/oauth", + oauthURL, err := apiHost.OAuthURL(ctx) + if err != nil { + return fmt.Errorf("failed to get OAuth URL: %w", err) } - oauthCfg.AuthorizationServer = u.String() + oauthCfg.AuthorizationServer = oauthURL.String() } serverOptions := []HandlerOption{} diff --git a/pkg/scopes/fetcher_test.go b/pkg/scopes/fetcher_test.go index 2d887d7a8..08520489c 100644 --- a/pkg/scopes/fetcher_test.go +++ b/pkg/scopes/fetcher_test.go @@ -28,6 +28,9 @@ func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) { func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) { return nil, nil } +func (t testAPIHostResolver) OAuthURL(_ context.Context) (*url.URL, error) { + return nil, nil +} func TestParseScopeHeader(t *testing.T) { tests := []struct { diff --git a/pkg/utils/api.go b/pkg/utils/api.go index a523917de..6e02530fb 100644 --- a/pkg/utils/api.go +++ b/pkg/utils/api.go @@ -14,6 +14,7 @@ type APIHostResolver interface { GraphqlURL(ctx context.Context) (*url.URL, error) UploadURL(ctx context.Context) (*url.URL, error) RawURL(ctx context.Context) (*url.URL, error) + OAuthURL(ctx context.Context) (*url.URL, error) } type APIHost struct { @@ -21,6 +22,7 @@ type APIHost struct { gqlURL *url.URL uploadURL *url.URL rawURL *url.URL + oauthURL *url.URL } var _ APIHostResolver = APIHost{} @@ -52,6 +54,10 @@ func (a APIHost) RawURL(_ context.Context) (*url.URL, error) { return a.rawURL, nil } +func (a APIHost) OAuthURL(_ context.Context) (*url.URL, error) { + return a.oauthURL, nil +} + func newDotcomHost() (APIHost, error) { baseRestURL, err := url.Parse("https://api.github.com/") if err != nil { @@ -73,11 +79,17 @@ func newDotcomHost() (APIHost, error) { return APIHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err) } + oauthURL, err := url.Parse("https://github.com/login/oauth") + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse dotcom OAuth URL: %w", err) + } + return APIHost{ restURL: baseRestURL, gqlURL: gqlURL, uploadURL: uploadURL, rawURL: rawURL, + oauthURL: oauthURL, }, nil } @@ -112,11 +124,17 @@ func newGHECHost(hostname string) (APIHost, error) { return APIHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err) } + oauthURL, err := url.Parse(fmt.Sprintf("https://%s/login/oauth", u.Hostname())) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHEC OAuth URL: %w", err) + } + return APIHost{ restURL: restURL, gqlURL: gqlURL, uploadURL: uploadURL, rawURL: rawURL, + oauthURL: oauthURL, }, nil } @@ -164,11 +182,17 @@ func newGHESHost(hostname string) (APIHost, error) { return APIHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err) } + oauthURL, err := url.Parse(fmt.Sprintf("%s://%s/login/oauth", u.Scheme, u.Hostname())) + if err != nil { + return APIHost{}, fmt.Errorf("failed to parse GHES OAuth URL: %w", err) + } + return APIHost{ restURL: restURL, gqlURL: gqlURL, uploadURL: uploadURL, rawURL: rawURL, + oauthURL: oauthURL, }, nil } diff --git a/pkg/utils/api_test.go b/pkg/utils/api_test.go new file mode 100644 index 000000000..8fa9279ab --- /dev/null +++ b/pkg/utils/api_test.go @@ -0,0 +1,140 @@ +package utils //nolint:revive //TODO: figure out a better name for this package + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOAuthURL(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + host string + expectedOAuth string + expectError bool + errorSubstring string + }{ + { + name: "dotcom (empty host)", + host: "", + expectedOAuth: "https://github.com/login/oauth", + }, + { + name: "dotcom (explicit github.com)", + host: "https://github.com", + expectedOAuth: "https://github.com/login/oauth", + }, + { + name: "GHEC with HTTPS", + host: "https://acme.ghe.com", + expectedOAuth: "https://acme.ghe.com/login/oauth", + }, + { + name: "GHEC with HTTP (should error)", + host: "http://acme.ghe.com", + expectError: true, + errorSubstring: "GHEC URL must be HTTPS", + }, + { + name: "GHES with HTTPS", + host: "https://ghes.example.com", + expectedOAuth: "https://ghes.example.com/login/oauth", + }, + { + name: "GHES with HTTP", + host: "http://ghes.example.com", + expectedOAuth: "http://ghes.example.com/login/oauth", + }, + { + name: "GHES with HTTP and custom port (port stripped - not supported yet)", + host: "http://ghes.local:8080", + expectedOAuth: "http://ghes.local/login/oauth", // Port is stripped ref: ln222 api.go comment + }, + { + name: "GHES with HTTPS and custom port (port stripped - not supported yet)", + host: "https://ghes.local:8443", + expectedOAuth: "https://ghes.local/login/oauth", // Port is stripped ref: ln222 api.go comment + }, + { + name: "host without scheme (should error)", + host: "ghes.example.com", + expectError: true, + errorSubstring: "host must have a scheme", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + apiHost, err := NewAPIHost(tt.host) + + if tt.expectError { + require.Error(t, err) + if tt.errorSubstring != "" { + assert.Contains(t, err.Error(), tt.errorSubstring) + } + return + } + + require.NoError(t, err) + require.NotNil(t, apiHost) + + oauthURL, err := apiHost.OAuthURL(ctx) + require.NoError(t, err) + require.NotNil(t, oauthURL) + + assert.Equal(t, tt.expectedOAuth, oauthURL.String()) + }) + } +} + +func TestAPIHost_AllURLsHaveConsistentScheme(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + host string + expectedScheme string + }{ + { + name: "GHES with HTTPS", + host: "https://ghes.example.com", + expectedScheme: "https", + }, + { + name: "GHES with HTTP", + host: "http://ghes.example.com", + expectedScheme: "http", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + apiHost, err := NewAPIHost(tt.host) + require.NoError(t, err) + + restURL, err := apiHost.BaseRESTURL(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedScheme, restURL.Scheme, "REST URL scheme should match") + + gqlURL, err := apiHost.GraphqlURL(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedScheme, gqlURL.Scheme, "GraphQL URL scheme should match") + + uploadURL, err := apiHost.UploadURL(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedScheme, uploadURL.Scheme, "Upload URL scheme should match") + + rawURL, err := apiHost.RawURL(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedScheme, rawURL.Scheme, "Raw URL scheme should match") + + oauthURL, err := apiHost.OAuthURL(ctx) + require.NoError(t, err) + assert.Equal(t, tt.expectedScheme, oauthURL.Scheme, "OAuth URL scheme should match") + }) + } +} From fbc5ba1a6b8b226116541496344484940500e4cf Mon Sep 17 00:00:00 2001 From: Atharva Patil <53966412+atharva1051@users.noreply.github.com> Date: Mon, 23 Feb 2026 17:19:32 +0000 Subject: [PATCH 6/7] chore: oauth accepts APIHostResolver as argument --- pkg/http/oauth/oauth.go | 18 ++++++++++-- pkg/http/oauth/oauth_test.go | 56 ++++++++++++++++++++++++++++++++++++ pkg/http/server.go | 8 +----- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 5da253566..6e2337978 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -3,11 +3,13 @@ package oauth import ( + "context" "fmt" "net/http" "strings" "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/oauthex" @@ -43,8 +45,13 @@ type Config struct { // This is used to construct the OAuth resource URL. BaseURL string + // APIHost is the GitHub API host resolver that provides OAuth URL. + // If set, this takes precedence over AuthorizationServer. + APIHost utils.APIHostResolver + // AuthorizationServer is the OAuth authorization server URL. // Defaults to GitHub's OAuth server if not specified. + // This field is ignored if APIHost is set. AuthorizationServer string // ResourcePath is the externally visible base path for the MCP server (e.g., "/mcp"). @@ -64,8 +71,15 @@ func NewAuthHandler(cfg *Config) (*AuthHandler, error) { cfg = &Config{} } - // Default authorization server to GitHub - if cfg.AuthorizationServer == "" { + // Resolve authorization server from APIHost if provided + if cfg.APIHost != nil { + oauthURL, err := cfg.APIHost.OAuthURL(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to get OAuth URL from API host: %w", err) + } + cfg.AuthorizationServer = oauthURL.String() + } else if cfg.AuthorizationServer == "" { + // Default authorization server to GitHub if not provided cfg.AuthorizationServer = DefaultAuthorizationServer } diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go index 9133e8331..78e156649 100644 --- a/pkg/http/oauth/oauth_test.go +++ b/pkg/http/oauth/oauth_test.go @@ -1,18 +1,49 @@ package oauth import ( + "context" "crypto/tls" "encoding/json" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/utils" "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// mockAPIHostResolver is a test implementation of utils.APIHostResolver +type mockAPIHostResolver struct { + oauthURL string +} + +func (m mockAPIHostResolver) BaseRESTURL(_ context.Context) (*url.URL, error) { + return nil, nil +} + +func (m mockAPIHostResolver) GraphqlURL(_ context.Context) (*url.URL, error) { + return nil, nil +} + +func (m mockAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) { + return nil, nil +} + +func (m mockAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) { + return nil, nil +} + +func (m mockAPIHostResolver) OAuthURL(_ context.Context) (*url.URL, error) { + return url.Parse(m.oauthURL) +} + +// Ensure mockAPIHostResolver implements utils.APIHostResolver +var _ utils.APIHostResolver = mockAPIHostResolver{} + func TestNewAuthHandler(t *testing.T) { t.Parallel() @@ -51,6 +82,31 @@ func TestNewAuthHandler(t *testing.T) { expectedAuthServer: DefaultAuthorizationServer, expectedResourcePath: "/mcp", }, + { + name: "APIHost with HTTPS GHES", + cfg: &Config{ + APIHost: mockAPIHostResolver{oauthURL: "https://ghes.example.com/login/oauth"}, + }, + expectedAuthServer: "https://ghes.example.com/login/oauth", + expectedResourcePath: "", + }, + { + name: "APIHost with HTTP GHES", + cfg: &Config{ + APIHost: mockAPIHostResolver{oauthURL: "http://ghes.local/login/oauth"}, + }, + expectedAuthServer: "http://ghes.local/login/oauth", + expectedResourcePath: "", + }, + { + name: "APIHost takes precedence over AuthorizationServer", + cfg: &Config{ + APIHost: mockAPIHostResolver{oauthURL: "https://ghes.example.com/login/oauth"}, + AuthorizationServer: "https://should-be-ignored.example.com/oauth", + }, + expectedAuthServer: "https://ghes.example.com/login/oauth", + expectedResourcePath: "", + }, } for _, tc := range tests { diff --git a/pkg/http/server.go b/pkg/http/server.go index ae0bb46d5..1de400783 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -126,13 +126,7 @@ func RunHTTPServer(cfg ServerConfig) error { oauthCfg := &oauth.Config{ BaseURL: cfg.BaseURL, ResourcePath: cfg.ResourcePath, - } - if cfg.Host != "" { - oauthURL, err := apiHost.OAuthURL(ctx) - if err != nil { - return fmt.Errorf("failed to get OAuth URL: %w", err) - } - oauthCfg.AuthorizationServer = oauthURL.String() + APIHost: apiHost, } serverOptions := []HandlerOption{} From 60717a7bf7b5c3d1b7d5952b7c549c5606ba7e19 Mon Sep 17 00:00:00 2001 From: Atharva Patil <53966412+atharva1051@users.noreply.github.com> Date: Tue, 24 Feb 2026 09:09:52 +0000 Subject: [PATCH 7/7] chore: made it similar to #2070 --- pkg/http/oauth/oauth.go | 44 +++---- pkg/http/oauth/oauth_test.go | 214 +++++++++++++++++++++++------------ pkg/http/server.go | 3 +- pkg/scopes/fetcher_test.go | 2 +- pkg/utils/api.go | 59 +++++----- pkg/utils/api_test.go | 140 ----------------------- 6 files changed, 195 insertions(+), 267 deletions(-) delete mode 100644 pkg/utils/api_test.go diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index 6e2337978..3b4d41959 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -3,7 +3,6 @@ package oauth import ( - "context" "fmt" "net/http" "strings" @@ -18,9 +17,6 @@ import ( const ( // OAuthProtectedResourcePrefix is the well-known path prefix for OAuth protected resource metadata. OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource" - - // DefaultAuthorizationServer is GitHub's OAuth authorization server. - DefaultAuthorizationServer = "https://github.com/login/oauth" ) // SupportedScopes lists all OAuth scopes that may be required by MCP tools. @@ -45,13 +41,8 @@ type Config struct { // This is used to construct the OAuth resource URL. BaseURL string - // APIHost is the GitHub API host resolver that provides OAuth URL. - // If set, this takes precedence over AuthorizationServer. - APIHost utils.APIHostResolver - // AuthorizationServer is the OAuth authorization server URL. // Defaults to GitHub's OAuth server if not specified. - // This field is ignored if APIHost is set. AuthorizationServer string // ResourcePath is the externally visible base path for the MCP server (e.g., "/mcp"). @@ -62,29 +53,27 @@ type Config struct { // AuthHandler handles OAuth-related HTTP endpoints. type AuthHandler struct { - cfg *Config + cfg *Config + apiHost utils.APIHostResolver } // NewAuthHandler creates a new OAuth auth handler. -func NewAuthHandler(cfg *Config) (*AuthHandler, error) { +func NewAuthHandler(cfg *Config, apiHost utils.APIHostResolver) (*AuthHandler, error) { if cfg == nil { cfg = &Config{} } - // Resolve authorization server from APIHost if provided - if cfg.APIHost != nil { - oauthURL, err := cfg.APIHost.OAuthURL(context.Background()) + if apiHost == nil { + var err error + apiHost, err = utils.NewAPIHost("https://api.github.com") if err != nil { - return nil, fmt.Errorf("failed to get OAuth URL from API host: %w", err) + return nil, fmt.Errorf("failed to create default API host: %w", err) } - cfg.AuthorizationServer = oauthURL.String() - } else if cfg.AuthorizationServer == "" { - // Default authorization server to GitHub if not provided - cfg.AuthorizationServer = DefaultAuthorizationServer } return &AuthHandler{ - cfg: cfg, + cfg: cfg, + apiHost: apiHost, }, nil } @@ -109,15 +98,28 @@ func (h *AuthHandler) RegisterRoutes(r chi.Router) { func (h *AuthHandler) metadataHandler() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() resourcePath := resolveResourcePath( strings.TrimPrefix(r.URL.Path, OAuthProtectedResourcePrefix), h.cfg.ResourcePath, ) resourceURL := h.buildResourceURL(r, resourcePath) + var authorizationServerURL string + if h.cfg.AuthorizationServer != "" { + authorizationServerURL = h.cfg.AuthorizationServer + } else { + authURL, err := h.apiHost.AuthorizationServerURL(ctx) + if err != nil { + http.Error(w, fmt.Sprintf("failed to resolve authorization server URL: %v", err), http.StatusInternalServerError) + return + } + authorizationServerURL = authURL.String() + } + metadata := &oauthex.ProtectedResourceMetadata{ Resource: resourceURL, - AuthorizationServers: []string{h.cfg.AuthorizationServer}, + AuthorizationServers: []string{authorizationServerURL}, ResourceName: "GitHub MCP Server", ScopesSupported: SupportedScopes, BearerMethodsSupported: []string{"header"}, diff --git a/pkg/http/oauth/oauth_test.go b/pkg/http/oauth/oauth_test.go index 78e156649..6d76b579f 100644 --- a/pkg/http/oauth/oauth_test.go +++ b/pkg/http/oauth/oauth_test.go @@ -1,12 +1,10 @@ package oauth import ( - "context" "crypto/tls" "encoding/json" "net/http" "net/http/httptest" - "net/url" "testing" "github.com/github/github-mcp-server/pkg/http/headers" @@ -16,55 +14,22 @@ import ( "github.com/stretchr/testify/require" ) -// mockAPIHostResolver is a test implementation of utils.APIHostResolver -type mockAPIHostResolver struct { - oauthURL string -} - -func (m mockAPIHostResolver) BaseRESTURL(_ context.Context) (*url.URL, error) { - return nil, nil -} - -func (m mockAPIHostResolver) GraphqlURL(_ context.Context) (*url.URL, error) { - return nil, nil -} - -func (m mockAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) { - return nil, nil -} - -func (m mockAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) { - return nil, nil -} - -func (m mockAPIHostResolver) OAuthURL(_ context.Context) (*url.URL, error) { - return url.Parse(m.oauthURL) -} - -// Ensure mockAPIHostResolver implements utils.APIHostResolver -var _ utils.APIHostResolver = mockAPIHostResolver{} +var ( + defaultAuthorizationServer = "https://github.com/login/oauth" +) func TestNewAuthHandler(t *testing.T) { t.Parallel() + dotcomHost, err := utils.NewAPIHost("https://api.github.com") + require.NoError(t, err) + tests := []struct { name string cfg *Config expectedAuthServer string expectedResourcePath string }{ - { - name: "nil config uses defaults", - cfg: nil, - expectedAuthServer: DefaultAuthorizationServer, - expectedResourcePath: "", - }, - { - name: "empty config uses defaults", - cfg: &Config{}, - expectedAuthServer: DefaultAuthorizationServer, - expectedResourcePath: "", - }, { name: "custom authorization server", cfg: &Config{ @@ -79,45 +44,21 @@ func TestNewAuthHandler(t *testing.T) { BaseURL: "https://example.com", ResourcePath: "/mcp", }, - expectedAuthServer: DefaultAuthorizationServer, + expectedAuthServer: "", expectedResourcePath: "/mcp", }, - { - name: "APIHost with HTTPS GHES", - cfg: &Config{ - APIHost: mockAPIHostResolver{oauthURL: "https://ghes.example.com/login/oauth"}, - }, - expectedAuthServer: "https://ghes.example.com/login/oauth", - expectedResourcePath: "", - }, - { - name: "APIHost with HTTP GHES", - cfg: &Config{ - APIHost: mockAPIHostResolver{oauthURL: "http://ghes.local/login/oauth"}, - }, - expectedAuthServer: "http://ghes.local/login/oauth", - expectedResourcePath: "", - }, - { - name: "APIHost takes precedence over AuthorizationServer", - cfg: &Config{ - APIHost: mockAPIHostResolver{oauthURL: "https://ghes.example.com/login/oauth"}, - AuthorizationServer: "https://should-be-ignored.example.com/oauth", - }, - expectedAuthServer: "https://ghes.example.com/login/oauth", - expectedResourcePath: "", - }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() - handler, err := NewAuthHandler(tc.cfg) + handler, err := NewAuthHandler(tc.cfg, dotcomHost) require.NoError(t, err) require.NotNil(t, handler) assert.Equal(t, tc.expectedAuthServer, handler.cfg.AuthorizationServer) + assert.Equal(t, tc.expectedResourcePath, handler.cfg.ResourcePath) }) } } @@ -428,7 +369,7 @@ func TestHandleProtectedResource(t *testing.T) { authServers, ok := body["authorization_servers"].([]any) require.True(t, ok) require.Len(t, authServers, 1) - assert.Equal(t, DefaultAuthorizationServer, authServers[0]) + assert.Equal(t, defaultAuthorizationServer, authServers[0]) }, }, { @@ -507,7 +448,10 @@ func TestHandleProtectedResource(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - handler, err := NewAuthHandler(tc.cfg) + dotcomHost, err := utils.NewAPIHost("https://api.github.com") + require.NoError(t, err) + + handler, err := NewAuthHandler(tc.cfg, dotcomHost) require.NoError(t, err) router := chi.NewRouter() @@ -549,9 +493,12 @@ func TestHandleProtectedResource(t *testing.T) { func TestRegisterRoutes(t *testing.T) { t.Parallel() + dotcomHost, err := utils.NewAPIHost("https://api.github.com") + require.NoError(t, err) + handler, err := NewAuthHandler(&Config{ BaseURL: "https://api.example.com", - }) + }, dotcomHost) require.NoError(t, err) router := chi.NewRouter() @@ -615,9 +562,12 @@ func TestSupportedScopes(t *testing.T) { func TestProtectedResourceResponseFormat(t *testing.T) { t.Parallel() + dotcomHost, err := utils.NewAPIHost("https://api.github.com") + require.NoError(t, err) + handler, err := NewAuthHandler(&Config{ BaseURL: "https://api.example.com", - }) + }, dotcomHost) require.NoError(t, err) router := chi.NewRouter() @@ -654,7 +604,7 @@ func TestProtectedResourceResponseFormat(t *testing.T) { authServers, ok := response["authorization_servers"].([]any) require.True(t, ok) assert.Len(t, authServers, 1) - assert.Equal(t, DefaultAuthorizationServer, authServers[0]) + assert.Equal(t, defaultAuthorizationServer, authServers[0]) } func TestOAuthProtectedResourcePrefix(t *testing.T) { @@ -667,5 +617,121 @@ func TestOAuthProtectedResourcePrefix(t *testing.T) { func TestDefaultAuthorizationServer(t *testing.T) { t.Parallel() - assert.Equal(t, "https://github.com/login/oauth", DefaultAuthorizationServer) + assert.Equal(t, "https://github.com/login/oauth", defaultAuthorizationServer) +} + +func TestAPIHostResolver_AuthorizationServerURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + host string + oauthConfig *Config + expectedURL string + expectedError bool + expectedStatusCode int + errorContains string + }{ + { + name: "valid host returns authorization server URL", + host: "https://github.com", + expectedURL: "https://github.com/login/oauth", + expectedStatusCode: http.StatusOK, + }, + { + name: "invalid host returns error", + host: "://invalid-url", + expectedURL: "", + expectedError: true, + errorContains: "could not parse host as URL", + }, + { + name: "host without scheme returns error", + host: "github.com", + expectedURL: "", + expectedError: true, + errorContains: "host must have a scheme", + }, + { + name: "GHEC host returns correct authorization server URL", + host: "https://test.ghe.com", + expectedURL: "https://test.ghe.com/login/oauth", + expectedStatusCode: http.StatusOK, + }, + { + name: "GHES host returns correct authorization server URL", + host: "https://ghe.example.com", + expectedURL: "https://ghe.example.com/login/oauth", + expectedStatusCode: http.StatusOK, + }, + { + name: "GHES with http scheme returns the correct authorization server URL", + host: "http://ghe.example.com", + expectedURL: "http://ghe.example.com/login/oauth", + expectedStatusCode: http.StatusOK, + }, + { + name: "custom authorization server in config takes precedence", + host: "https://github.com", + oauthConfig: &Config{ + AuthorizationServer: "https://custom.auth.example.com/oauth", + }, + expectedURL: "https://custom.auth.example.com/oauth", + expectedStatusCode: http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + apiHost, err := utils.NewAPIHost(tc.host) + if tc.expectedError { + require.Error(t, err) + if tc.errorContains != "" { + assert.Contains(t, err.Error(), tc.errorContains) + } + return + } + require.NoError(t, err) + + config := tc.oauthConfig + if config == nil { + config = &Config{} + } + config.BaseURL = tc.host + + handler, err := NewAuthHandler(config, apiHost) + require.NoError(t, err) + + router := chi.NewRouter() + handler.RegisterRoutes(router) + + req := httptest.NewRequest(http.MethodGet, OAuthProtectedResourcePrefix, nil) + req.Host = "api.example.com" + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var response map[string]any + err = json.Unmarshal(rec.Body.Bytes(), &response) + require.NoError(t, err) + + assert.Contains(t, response, "authorization_servers") + if tc.expectedStatusCode != http.StatusOK { + require.Equal(t, tc.expectedStatusCode, rec.Code) + if tc.errorContains != "" { + assert.Contains(t, rec.Body.String(), tc.errorContains) + } + return + } + + responseAuthServers, ok := response["authorization_servers"].([]any) + require.True(t, ok) + require.Len(t, responseAuthServers, 1) + assert.Equal(t, tc.expectedURL, responseAuthServers[0]) + }) + } } diff --git a/pkg/http/server.go b/pkg/http/server.go index 1de400783..872303940 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -126,7 +126,6 @@ func RunHTTPServer(cfg ServerConfig) error { oauthCfg := &oauth.Config{ BaseURL: cfg.BaseURL, ResourcePath: cfg.ResourcePath, - APIHost: apiHost, } serverOptions := []HandlerOption{} @@ -137,7 +136,7 @@ func RunHTTPServer(cfg ServerConfig) error { r := chi.NewRouter() handler := NewHTTPMcpHandler(ctx, &cfg, deps, t, logger, apiHost, append(serverOptions, WithFeatureChecker(featureChecker), WithOAuthConfig(oauthCfg))...) - oauthHandler, err := oauth.NewAuthHandler(oauthCfg) + oauthHandler, err := oauth.NewAuthHandler(oauthCfg, apiHost) if err != nil { return fmt.Errorf("failed to create OAuth handler: %w", err) } diff --git a/pkg/scopes/fetcher_test.go b/pkg/scopes/fetcher_test.go index 08520489c..7ef910a56 100644 --- a/pkg/scopes/fetcher_test.go +++ b/pkg/scopes/fetcher_test.go @@ -28,7 +28,7 @@ func (t testAPIHostResolver) UploadURL(_ context.Context) (*url.URL, error) { func (t testAPIHostResolver) RawURL(_ context.Context) (*url.URL, error) { return nil, nil } -func (t testAPIHostResolver) OAuthURL(_ context.Context) (*url.URL, error) { +func (t testAPIHostResolver) AuthorizationServerURL(_ context.Context) (*url.URL, error) { return nil, nil } diff --git a/pkg/utils/api.go b/pkg/utils/api.go index 6e02530fb..a22711b23 100644 --- a/pkg/utils/api.go +++ b/pkg/utils/api.go @@ -14,15 +14,15 @@ type APIHostResolver interface { GraphqlURL(ctx context.Context) (*url.URL, error) UploadURL(ctx context.Context) (*url.URL, error) RawURL(ctx context.Context) (*url.URL, error) - OAuthURL(ctx context.Context) (*url.URL, error) + AuthorizationServerURL(ctx context.Context) (*url.URL, error) } type APIHost struct { - restURL *url.URL - gqlURL *url.URL - uploadURL *url.URL - rawURL *url.URL - oauthURL *url.URL + restURL *url.URL + gqlURL *url.URL + uploadURL *url.URL + rawURL *url.URL + authorizationServerURL *url.URL } var _ APIHostResolver = APIHost{} @@ -54,8 +54,8 @@ func (a APIHost) RawURL(_ context.Context) (*url.URL, error) { return a.rawURL, nil } -func (a APIHost) OAuthURL(_ context.Context) (*url.URL, error) { - return a.oauthURL, nil +func (a APIHost) AuthorizationServerURL(_ context.Context) (*url.URL, error) { + return a.authorizationServerURL, nil } func newDotcomHost() (APIHost, error) { @@ -79,17 +79,18 @@ func newDotcomHost() (APIHost, error) { return APIHost{}, fmt.Errorf("failed to parse dotcom Raw URL: %w", err) } - oauthURL, err := url.Parse("https://github.com/login/oauth") + // The authorization server for GitHub.com is at github.com/login/oauth, not api.github.com + authorizationServerURL, err := url.Parse("https://github.com/login/oauth") if err != nil { - return APIHost{}, fmt.Errorf("failed to parse dotcom OAuth URL: %w", err) + return APIHost{}, fmt.Errorf("failed to parse dotcom Authorization Server URL: %w", err) } return APIHost{ - restURL: baseRestURL, - gqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - oauthURL: oauthURL, + restURL: baseRestURL, + gqlURL: gqlURL, + uploadURL: uploadURL, + rawURL: rawURL, + authorizationServerURL: authorizationServerURL, }, nil } @@ -124,17 +125,17 @@ func newGHECHost(hostname string) (APIHost, error) { return APIHost{}, fmt.Errorf("failed to parse GHEC Raw URL: %w", err) } - oauthURL, err := url.Parse(fmt.Sprintf("https://%s/login/oauth", u.Hostname())) + authorizationServerURL, err := url.Parse(fmt.Sprintf("https://%s/login/oauth", u.Hostname())) if err != nil { - return APIHost{}, fmt.Errorf("failed to parse GHEC OAuth URL: %w", err) + return APIHost{}, fmt.Errorf("failed to parse GHEC Authorization Server URL: %w", err) } return APIHost{ - restURL: restURL, - gqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - oauthURL: oauthURL, + restURL: restURL, + gqlURL: gqlURL, + uploadURL: uploadURL, + rawURL: rawURL, + authorizationServerURL: authorizationServerURL, }, nil } @@ -182,17 +183,17 @@ func newGHESHost(hostname string) (APIHost, error) { return APIHost{}, fmt.Errorf("failed to parse GHES Raw URL: %w", err) } - oauthURL, err := url.Parse(fmt.Sprintf("%s://%s/login/oauth", u.Scheme, u.Hostname())) + authorizationServerURL, err := url.Parse(fmt.Sprintf("%s://%s/login/oauth", u.Scheme, u.Hostname())) if err != nil { - return APIHost{}, fmt.Errorf("failed to parse GHES OAuth URL: %w", err) + return APIHost{}, fmt.Errorf("failed to parse GHES Authorization Server URL: %w", err) } return APIHost{ - restURL: restURL, - gqlURL: gqlURL, - uploadURL: uploadURL, - rawURL: rawURL, - oauthURL: oauthURL, + restURL: restURL, + gqlURL: gqlURL, + uploadURL: uploadURL, + rawURL: rawURL, + authorizationServerURL: authorizationServerURL, }, nil } diff --git a/pkg/utils/api_test.go b/pkg/utils/api_test.go deleted file mode 100644 index 8fa9279ab..000000000 --- a/pkg/utils/api_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package utils //nolint:revive //TODO: figure out a better name for this package - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestOAuthURL(t *testing.T) { - ctx := context.Background() - - tests := []struct { - name string - host string - expectedOAuth string - expectError bool - errorSubstring string - }{ - { - name: "dotcom (empty host)", - host: "", - expectedOAuth: "https://github.com/login/oauth", - }, - { - name: "dotcom (explicit github.com)", - host: "https://github.com", - expectedOAuth: "https://github.com/login/oauth", - }, - { - name: "GHEC with HTTPS", - host: "https://acme.ghe.com", - expectedOAuth: "https://acme.ghe.com/login/oauth", - }, - { - name: "GHEC with HTTP (should error)", - host: "http://acme.ghe.com", - expectError: true, - errorSubstring: "GHEC URL must be HTTPS", - }, - { - name: "GHES with HTTPS", - host: "https://ghes.example.com", - expectedOAuth: "https://ghes.example.com/login/oauth", - }, - { - name: "GHES with HTTP", - host: "http://ghes.example.com", - expectedOAuth: "http://ghes.example.com/login/oauth", - }, - { - name: "GHES with HTTP and custom port (port stripped - not supported yet)", - host: "http://ghes.local:8080", - expectedOAuth: "http://ghes.local/login/oauth", // Port is stripped ref: ln222 api.go comment - }, - { - name: "GHES with HTTPS and custom port (port stripped - not supported yet)", - host: "https://ghes.local:8443", - expectedOAuth: "https://ghes.local/login/oauth", // Port is stripped ref: ln222 api.go comment - }, - { - name: "host without scheme (should error)", - host: "ghes.example.com", - expectError: true, - errorSubstring: "host must have a scheme", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - apiHost, err := NewAPIHost(tt.host) - - if tt.expectError { - require.Error(t, err) - if tt.errorSubstring != "" { - assert.Contains(t, err.Error(), tt.errorSubstring) - } - return - } - - require.NoError(t, err) - require.NotNil(t, apiHost) - - oauthURL, err := apiHost.OAuthURL(ctx) - require.NoError(t, err) - require.NotNil(t, oauthURL) - - assert.Equal(t, tt.expectedOAuth, oauthURL.String()) - }) - } -} - -func TestAPIHost_AllURLsHaveConsistentScheme(t *testing.T) { - ctx := context.Background() - - tests := []struct { - name string - host string - expectedScheme string - }{ - { - name: "GHES with HTTPS", - host: "https://ghes.example.com", - expectedScheme: "https", - }, - { - name: "GHES with HTTP", - host: "http://ghes.example.com", - expectedScheme: "http", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - apiHost, err := NewAPIHost(tt.host) - require.NoError(t, err) - - restURL, err := apiHost.BaseRESTURL(ctx) - require.NoError(t, err) - assert.Equal(t, tt.expectedScheme, restURL.Scheme, "REST URL scheme should match") - - gqlURL, err := apiHost.GraphqlURL(ctx) - require.NoError(t, err) - assert.Equal(t, tt.expectedScheme, gqlURL.Scheme, "GraphQL URL scheme should match") - - uploadURL, err := apiHost.UploadURL(ctx) - require.NoError(t, err) - assert.Equal(t, tt.expectedScheme, uploadURL.Scheme, "Upload URL scheme should match") - - rawURL, err := apiHost.RawURL(ctx) - require.NoError(t, err) - assert.Equal(t, tt.expectedScheme, rawURL.Scheme, "Raw URL scheme should match") - - oauthURL, err := apiHost.OAuthURL(ctx) - require.NoError(t, err) - assert.Equal(t, tt.expectedScheme, oauthURL.Scheme, "OAuth URL scheme should match") - }) - } -}