From 160a55fcd3e3081dbd6c77df4a5e510c6e643887 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Mon, 24 Nov 2025 11:51:36 +0530 Subject: [PATCH 01/19] feat: add token exchange flow - Adds the Token Exchange (RFC 8693) for Enterprise-Managed Authorization --- oauthex/token_exchange.go | 267 +++++++++++++++++++++++++++++++++ oauthex/token_exchange_test.go | 220 +++++++++++++++++++++++++++ 2 files changed, 487 insertions(+) create mode 100644 oauthex/token_exchange.go create mode 100644 oauthex/token_exchange_test.go diff --git a/oauthex/token_exchange.go b/oauthex/token_exchange.go new file mode 100644 index 00000000..fb162d0b --- /dev/null +++ b/oauthex/token_exchange.go @@ -0,0 +1,267 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements Token Exchange (RFC 8693) for Enterprise Managed Authorization. +// See https://datatracker.ietf.org/doc/html/rfc8693 + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +// Token type identifiers defined by RFC 8693 and SEP-990. +const ( + // TokenTypeIDToken is the URN for OpenID Connect ID Tokens. + TokenTypeIDToken = "urn:ietf:params:oauth:token-type:id_token" + + // TokenTypeSAML2 is the URN for SAML 2.0 assertions. + TokenTypeSAML2 = "urn:ietf:params:oauth:token-type:saml2" + + // TokenTypeIDJAG is the URN for Identity Assertion JWT Authorization Grants. + // This is the token type returned by IdP during token exchange for SEP-990. + TokenTypeIDJAG = "urn:ietf:params:oauth:token-type:id-jag" + + // GrantTypeTokenExchange is the grant type for RFC 8693 token exchange. + GrantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" +) + +// TokenExchangeRequest represents a Token Exchange request per RFC 8693. +// This is used for Enterprise Managed Authorization (SEP-990) where an MCP Client +// exchanges an ID Token from an enterprise IdP for an ID-JAG that can be used +// to obtain an access token from an MCP Server's authorization server. +type TokenExchangeRequest struct { + // RequestedTokenType indicates the type of security token being requested. + // For SEP-990, this MUST be TokenTypeIDJAG. + RequestedTokenType string + + // Audience is the logical name of the target service where the client + // intends to use the requested token. For SEP-990, this MUST be the + // Issuer URL of the MCP Server's authorization server. + Audience string + + // Resource is the physical location or identifier of the target resource. + // For SEP-990, this MUST be the RFC9728 Resource Identifier of the MCP Server. + Resource string + + // Scope is a list of space-separated scopes for the requested token. + // This is OPTIONAL per RFC 8693 but commonly used in SEP-990. + Scope []string + + // SubjectToken is the security token that represents the identity of the + // party on behalf of whom the request is being made. For SEP-990, this is + // typically an OpenID Connect ID Token. + SubjectToken string + + // SubjectTokenType is the type of the security token in SubjectToken. + // For SEP-990 with OIDC, this MUST be TokenTypeIDToken. + SubjectTokenType string +} + +// TokenExchangeResponse represents the response from a token exchange request +// per RFC 8693 Section 2.2. +type TokenExchangeResponse struct { + // IssuedTokenType is the type of the security token in AccessToken. + // For SEP-990, this MUST be TokenTypeIDJAG. + IssuedTokenType string `json:"issued_token_type"` + + // AccessToken is the security token issued by the authorization server. + // Despite the name "access_token" (required by RFC 8693), for SEP-990 + // this contains an ID-JAG JWT, not an OAuth access token. + AccessToken string `json:"access_token"` + + // TokenType indicates the type of token returned. For SEP-990, this is "N_A" + // because the issued token is not an OAuth access token. + TokenType string `json:"token_type"` + + // Scope is the scope of the issued token, if the issued token scope is + // different from the requested scope. Per RFC 8693, this SHOULD be included + // if the scope differs from the request. + Scope string `json:"scope,omitempty"` + + // ExpiresIn is the lifetime in seconds of the issued token. + ExpiresIn int `json:"expires_in,omitempty"` +} + +// TokenExchangeError represents an error response from a token exchange request. +type TokenExchangeError struct { + // Error is the error code as defined in RFC 6749 Section 5.2. + ErrorCode string `json:"error"` + + // ErrorDescription is a human-readable description of the error. + ErrorDescription string `json:"error_description,omitempty"` + + // ErrorURI is a URI identifying a human-readable web page with information + // about the error. + ErrorURI string `json:"error_uri,omitempty"` +} + +func (e *TokenExchangeError) Error() string { + if e.ErrorDescription != "" { + return fmt.Sprintf("token exchange failed: %s (%s)", e.ErrorCode, e.ErrorDescription) + } + return fmt.Sprintf("token exchange failed: %s", e.ErrorCode) +} + +// ExchangeToken performs a token exchange request per RFC 8693 for Enterprise +// Managed Authorization (SEP-990). It exchanges an identity assertion (typically +// an ID Token) for an Identity Assertion JWT Authorization Grant (ID-JAG) that +// can be used to obtain an access token from an MCP Server. +// +// The tokenEndpoint parameter should be the IdP's token endpoint (typically +// obtained from the IdP's authorization server metadata). +// +// Client authentication must be performed by the caller by including appropriate +// credentials in the request (e.g., using Basic auth via the Authorization header, +// or including client_id and client_secret in the form data). +// +// Example: +// +// req := &TokenExchangeRequest{ +// RequestedTokenType: TokenTypeIDJAG, +// Audience: "https://auth.mcpserver.example/", +// Resource: "https://mcp.mcpserver.example/", +// Scope: []string{"read", "write"}, +// SubjectToken: idToken, +// SubjectTokenType: TokenTypeIDToken, +// } +// +// resp, err := ExchangeToken(ctx, idpTokenEndpoint, req, clientID, clientSecret, nil) +func ExchangeToken( + ctx context.Context, + tokenEndpoint string, + req *TokenExchangeRequest, + clientID string, + clientSecret string, + httpClient *http.Client, +) (*TokenExchangeResponse, error) { + if tokenEndpoint == "" { + return nil, fmt.Errorf("token endpoint is required") + } + if req == nil { + return nil, fmt.Errorf("token exchange request is required") + } + + // Validate required fields per SEP-990 Section 4 + if req.RequestedTokenType == "" { + return nil, fmt.Errorf("requested_token_type is required") + } + if req.Audience == "" { + return nil, fmt.Errorf("audience is required") + } + if req.Resource == "" { + return nil, fmt.Errorf("resource is required") + } + if req.SubjectToken == "" { + return nil, fmt.Errorf("subject_token is required") + } + if req.SubjectTokenType == "" { + return nil, fmt.Errorf("subject_token_type is required") + } + + // Validate URL schemes to prevent XSS attacks (see #526) + if err := checkURLScheme(tokenEndpoint); err != nil { + return nil, fmt.Errorf("invalid token endpoint: %w", err) + } + if err := checkURLScheme(req.Audience); err != nil { + return nil, fmt.Errorf("invalid audience: %w", err) + } + if err := checkURLScheme(req.Resource); err != nil { + return nil, fmt.Errorf("invalid resource: %w", err) + } + + // Build the token exchange request body per RFC 8693 + formData := url.Values{} + formData.Set("grant_type", GrantTypeTokenExchange) + formData.Set("requested_token_type", req.RequestedTokenType) + formData.Set("audience", req.Audience) + formData.Set("resource", req.Resource) + formData.Set("subject_token", req.SubjectToken) + formData.Set("subject_token_type", req.SubjectTokenType) + + if len(req.Scope) > 0 { + formData.Set("scope", strings.Join(req.Scope, " ")) + } + + // Add client authentication (following OAuth 2.0 client_secret_post method) + if clientID != "" { + formData.Set("client_id", clientID) + } + if clientSecret != "" { + formData.Set("client_secret", clientSecret) + } + + // Create HTTP request + httpReq, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + tokenEndpoint, + strings.NewReader(formData.Encode()), + ) + if err != nil { + return nil, fmt.Errorf("failed to create token exchange request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + httpReq.Header.Set("Accept", "application/json") + + // Use provided client or default + if httpClient == nil { + httpClient = http.DefaultClient + } + + // Execute the request + httpResp, err := httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("token exchange request failed: %w", err) + } + defer httpResp.Body.Close() + + // Read response body (limit to 1MB for safety, following SDK pattern) + body, err := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read token exchange response: %w", err) + } + + // Handle success response (200 OK per RFC 8693) + if httpResp.StatusCode == http.StatusOK { + var resp TokenExchangeResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse token exchange response: %w (body: %s)", err, string(body)) + } + + // Validate response per SEP-990 Section 4.2 + if resp.IssuedTokenType == "" { + return nil, fmt.Errorf("response missing required field: issued_token_type") + } + if resp.AccessToken == "" { + return nil, fmt.Errorf("response missing required field: access_token") + } + if resp.TokenType == "" { + return nil, fmt.Errorf("response missing required field: token_type") + } + + return &resp, nil + } + + // Handle error response (400 Bad Request per RFC 6749) + if httpResp.StatusCode == http.StatusBadRequest { + var errResp TokenExchangeError + if err := json.Unmarshal(body, &errResp); err != nil { + return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) + } + return nil, &errResp + } + + // Handle unexpected status codes + return nil, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body)) +} diff --git a/oauthex/token_exchange_test.go b/oauthex/token_exchange_test.go new file mode 100644 index 00000000..316fa8e3 --- /dev/null +++ b/oauthex/token_exchange_test.go @@ -0,0 +1,220 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +// TestExchangeToken tests the basic token exchange flow. +func TestExchangeToken(t *testing.T) { + // Create a test IdP server that implements token exchange + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and content type + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + contentType := r.Header.Get("Content-Type") + if contentType != "application/x-www-form-urlencoded" { + t.Errorf("expected application/x-www-form-urlencoded, got %s", contentType) + http.Error(w, "invalid content type", http.StatusBadRequest) + return + } + + // Parse form data + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + + // Verify required parameters per SEP-990 Section 4 + grantType := r.FormValue("grant_type") + if grantType != GrantTypeTokenExchange { + t.Errorf("expected grant_type %s, got %s", GrantTypeTokenExchange, grantType) + writeErrorResponse(w, "invalid_grant", "invalid grant_type") + return + } + + requestedTokenType := r.FormValue("requested_token_type") + if requestedTokenType != TokenTypeIDJAG { + t.Errorf("expected requested_token_type %s, got %s", TokenTypeIDJAG, requestedTokenType) + writeErrorResponse(w, "invalid_request", "invalid requested_token_type") + return + } + + audience := r.FormValue("audience") + if audience == "" { + t.Error("audience is required") + writeErrorResponse(w, "invalid_request", "missing audience") + return + } + + resource := r.FormValue("resource") + if resource == "" { + t.Error("resource is required") + writeErrorResponse(w, "invalid_request", "missing resource") + return + } + + subjectToken := r.FormValue("subject_token") + if subjectToken == "" { + t.Error("subject_token is required") + writeErrorResponse(w, "invalid_request", "missing subject_token") + return + } + + subjectTokenType := r.FormValue("subject_token_type") + if subjectTokenType != TokenTypeIDToken { + t.Errorf("expected subject_token_type %s, got %s", TokenTypeIDToken, subjectTokenType) + writeErrorResponse(w, "invalid_request", "invalid subject_token_type") + return + } + + // Verify client authentication + clientID := r.FormValue("client_id") + clientSecret := r.FormValue("client_secret") + if clientID == "" || clientSecret == "" { + t.Error("client authentication required") + writeErrorResponse(w, "invalid_client", "client authentication failed") + return + } + + if clientID != "test-client-id" || clientSecret != "test-client-secret" { + t.Error("invalid client credentials") + writeErrorResponse(w, "invalid_client", "invalid credentials") + return + } + + // Return successful token exchange response per SEP-990 Section 4.2 + resp := TokenExchangeResponse{ + IssuedTokenType: TokenTypeIDJAG, + AccessToken: "fake-id-jag-token", + TokenType: "N_A", + Scope: r.FormValue("scope"), + ExpiresIn: 300, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + // Test successful token exchange + t.Run("successful exchange", func(t *testing.T) { + req := &TokenExchangeRequest{ + RequestedTokenType: TokenTypeIDJAG, + Audience: "https://auth.mcpserver.example/", + Resource: "https://mcp.mcpserver.example/", + Scope: []string{"read", "write"}, + SubjectToken: "fake-id-token", + SubjectTokenType: TokenTypeIDToken, + } + + resp, err := ExchangeToken( + context.Background(), + server.URL, + req, + "test-client-id", + "test-client-secret", + server.Client(), + ) + + if err != nil { + t.Fatalf("ExchangeToken failed: %v", err) + } + + if resp.IssuedTokenType != TokenTypeIDJAG { + t.Errorf("expected issued_token_type %s, got %s", TokenTypeIDJAG, resp.IssuedTokenType) + } + + if resp.AccessToken != "fake-id-jag-token" { + t.Errorf("expected access_token 'fake-id-jag-token', got %s", resp.AccessToken) + } + + if resp.TokenType != "N_A" { + t.Errorf("expected token_type 'N_A', got %s", resp.TokenType) + } + + if resp.Scope != "read write" { + t.Errorf("expected scope 'read write', got %s", resp.Scope) + } + + if resp.ExpiresIn != 300 { + t.Errorf("expected expires_in 300, got %d", resp.ExpiresIn) + } + }) + + // Test missing required fields + t.Run("missing audience", func(t *testing.T) { + req := &TokenExchangeRequest{ + RequestedTokenType: TokenTypeIDJAG, + Resource: "https://mcp.mcpserver.example/", + SubjectToken: "fake-id-token", + SubjectTokenType: TokenTypeIDToken, + } + + _, err := ExchangeToken( + context.Background(), + server.URL, + req, + "test-client-id", + "test-client-secret", + server.Client(), + ) + + if err == nil { + t.Error("expected error for missing audience, got nil") + } + }) + + // Test invalid URL schemes + t.Run("invalid audience URL scheme", func(t *testing.T) { + req := &TokenExchangeRequest{ + RequestedTokenType: TokenTypeIDJAG, + Audience: "javascript:alert(1)", + Resource: "https://mcp.mcpserver.example/", + SubjectToken: "fake-id-token", + SubjectTokenType: TokenTypeIDToken, + } + + _, err := ExchangeToken( + context.Background(), + server.URL, + req, + "test-client-id", + "test-client-secret", + server.Client(), + ) + + if err == nil { + t.Error("expected error for invalid audience URL scheme, got nil") + } + }) +} + +// writeErrorResponse writes an OAuth 2.0 error response per RFC 6749 Section 5.2. +func writeErrorResponse(w http.ResponseWriter, errorCode, errorDescription string) { + errResp := TokenExchangeError{ + ErrorCode: errorCode, + ErrorDescription: errorDescription, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(errResp) +} From 6b57384417188fdd71abcbe67b5c9494cc0b4aa1 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Mon, 24 Nov 2025 15:40:55 +0530 Subject: [PATCH 02/19] feat: add jwt bearer flow with id-jag --- oauthex/jwt_bearer.go | 193 +++++++++++++++++++++++++++++++++++++ oauthex/jwt_bearer_test.go | 154 +++++++++++++++++++++++++++++ 2 files changed, 347 insertions(+) create mode 100644 oauthex/jwt_bearer.go create mode 100644 oauthex/jwt_bearer_test.go diff --git a/oauthex/jwt_bearer.go b/oauthex/jwt_bearer.go new file mode 100644 index 00000000..03a0df5d --- /dev/null +++ b/oauthex/jwt_bearer.go @@ -0,0 +1,193 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements JWT Bearer Authorization Grant (RFC 7523) for Enterprise Managed Authorization. +// See https://datatracker.ietf.org/doc/html/rfc7523 + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "golang.org/x/oauth2" +) + +// GrantTypeJWTBearer is the grant type for RFC 7523 JWT Bearer authorization grant. +// This is used in SEP-990 to exchange an ID-JAG for an access token at the MCP Server. +const GrantTypeJWTBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer" + +// JWTBearerResponse represents the response from a JWT Bearer grant request +// per RFC 7523. This uses the standard OAuth 2.0 token response format. +type JWTBearerResponse struct { + // AccessToken is the OAuth access token issued by the MCP Server's + // authorization server. + AccessToken string `json:"access_token"` + // TokenType is the type of token issued. This is typically "Bearer". + TokenType string `json:"token_type"` + // ExpiresIn is the lifetime in seconds of the access token. + ExpiresIn int `json:"expires_in,omitempty"` + // RefreshToken is the refresh token, which can be used to obtain new + // access tokens using the same authorization grant. + RefreshToken string `json:"refresh_token,omitempty"` + // Scope is the scope of the access token as described by RFC 6749 Section 3.3. + Scope string `json:"scope,omitempty"` +} + +// JWTBearerError represents an error response from a JWT Bearer grant request. +type JWTBearerError struct { + // ErrorCode is the error code as defined in RFC 6749 Section 5.2. + // The JSON field name is "error" per the RFC specification. + ErrorCode string `json:"error"` + // ErrorDescription is a human-readable description of the error. + ErrorDescription string `json:"error_description,omitempty"` + // ErrorURI is a URI identifying a human-readable web page with information + // about the error. + ErrorURI string `json:"error_uri,omitempty"` +} + +func (e *JWTBearerError) Error() string { + if e.ErrorDescription != "" { + return fmt.Sprintf("JWT bearer grant failed: %s (%s)", e.ErrorCode, e.ErrorDescription) + } + return fmt.Sprintf("JWT bearer grant failed: %s", e.ErrorCode) +} + +// ExchangeJWTBearer exchanges an Identity Assertion JWT Authorization Grant (ID-JAG) +// for an access token using JWT Bearer Grant per RFC 7523. This is the second step +// in Enterprise Managed Authorization (SEP-990) after obtaining the ID-JAG from the +// IdP via token exchange. +// +// The tokenEndpoint parameter should be the MCP Server's token endpoint (typically +// obtained from the MCP Server's authorization server metadata). +// +// The assertion parameter should be the ID-JAG JWT obtained from the token exchange +// step with the enterprise IdP. +// +// Client authentication must be performed by the caller by including appropriate +// credentials in the request (e.g., using Basic auth via the Authorization header, +// or including client_id and client_secret in the form data). +// +// Example: +// +// // First, get ID-JAG via token exchange +// idJAG := tokenExchangeResp.AccessToken +// +// // Then exchange ID-JAG for access token +// token, err := ExchangeJWTBearer( +// ctx, +// "https://auth.mcpserver.example/oauth2/token", +// idJAG, +// "mcp-client-id", +// "mcp-client-secret", +// nil, +// ) +func ExchangeJWTBearer( + ctx context.Context, + tokenEndpoint string, + assertion string, + clientID string, + clientSecret string, + httpClient *http.Client, +) (*oauth2.Token, error) { + if tokenEndpoint == "" { + return nil, fmt.Errorf("token endpoint is required") + } + if assertion == "" { + return nil, fmt.Errorf("assertion is required") + } + // Validate URL scheme to prevent XSS attacks (see #526) + if err := checkURLScheme(tokenEndpoint); err != nil { + return nil, fmt.Errorf("invalid token endpoint: %w", err) + } + // Build the JWT Bearer grant request per RFC 7523 Section 2.1 + formData := url.Values{} + formData.Set("grant_type", GrantTypeJWTBearer) + formData.Set("assertion", assertion) + // Add client authentication (following OAuth 2.0 client_secret_post method) + // Note: Per SEP-990 Section 5.1, the client_id in the assertion must match + // the authenticated client + if clientID != "" { + formData.Set("client_id", clientID) + } + if clientSecret != "" { + formData.Set("client_secret", clientSecret) + } + // Create HTTP request + httpReq, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + tokenEndpoint, + strings.NewReader(formData.Encode()), + ) + if err != nil { + return nil, fmt.Errorf("failed to create JWT bearer grant request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + httpReq.Header.Set("Accept", "application/json") + // Use provided client or default + if httpClient == nil { + httpClient = http.DefaultClient + } + // Execute the request + httpResp, err := httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("JWT bearer grant request failed: %w", err) + } + defer httpResp.Body.Close() + // Read response body (limit to 1MB for safety, following SDK pattern) + body, err := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read JWT bearer grant response: %w", err) + } + // Handle success response (200 OK per OAuth 2.0) + if httpResp.StatusCode == http.StatusOK { + var resp JWTBearerResponse + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to parse JWT bearer grant response: %w (body: %s)", err, string(body)) + } + // Validate response per OAuth 2.0 + if resp.AccessToken == "" { + return nil, fmt.Errorf("response missing required field: access_token") + } + if resp.TokenType == "" { + return nil, fmt.Errorf("response missing required field: token_type") + } + // Convert to golang.org/x/oauth2.Token + token := &oauth2.Token{ + AccessToken: resp.AccessToken, + TokenType: resp.TokenType, + RefreshToken: resp.RefreshToken, + } + // Set expiration if provided + if resp.ExpiresIn > 0 { + token.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second) + } + // Add scope to extra data if provided + if resp.Scope != "" { + token = token.WithExtra(map[string]interface{}{ + "scope": resp.Scope, + }) + } + return token, nil + } + // Handle error response (400 Bad Request per RFC 6749) + if httpResp.StatusCode == http.StatusBadRequest { + var errResp JWTBearerError + if err := json.Unmarshal(body, &errResp); err != nil { + return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) + } + return nil, &errResp + } + // Handle unexpected status codes + return nil, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body)) +} diff --git a/oauthex/jwt_bearer_test.go b/oauthex/jwt_bearer_test.go new file mode 100644 index 00000000..3145d0bf --- /dev/null +++ b/oauthex/jwt_bearer_test.go @@ -0,0 +1,154 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +// TestExchangeJWTBearer tests the JWT Bearer grant flow. +func TestExchangeJWTBearer(t *testing.T) { + // Create a test MCP Server auth server that accepts JWT Bearer grants + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request method and content type + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + contentType := r.Header.Get("Content-Type") + if contentType != "application/x-www-form-urlencoded" { + t.Errorf("expected application/x-www-form-urlencoded, got %s", contentType) + http.Error(w, "invalid content type", http.StatusBadRequest) + return + } + // Parse form data + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + // Verify grant type per RFC 7523 + grantType := r.FormValue("grant_type") + if grantType != GrantTypeJWTBearer { + t.Errorf("expected grant_type %s, got %s", GrantTypeJWTBearer, grantType) + writeJWTBearerErrorResponse(w, "unsupported_grant_type", "grant type not supported") + return + } + // Verify assertion is provided + assertion := r.FormValue("assertion") + if assertion == "" { + t.Error("assertion is required") + writeJWTBearerErrorResponse(w, "invalid_request", "missing assertion") + return + } + // Verify client authentication + clientID := r.FormValue("client_id") + clientSecret := r.FormValue("client_secret") + if clientID == "" || clientSecret == "" { + t.Error("client authentication required") + writeJWTBearerErrorResponse(w, "invalid_client", "client authentication failed") + return + } + if clientID != "mcp-client-id" || clientSecret != "mcp-client-secret" { + t.Error("invalid client credentials") + writeJWTBearerErrorResponse(w, "invalid_client", "invalid credentials") + return + } + // Return successful OAuth token response + resp := JWTBearerResponse{ + AccessToken: "mcp-access-token-123", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "read write", + RefreshToken: "mcp-refresh-token-456", + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + // Test successful JWT Bearer grant + t.Run("successful exchange", func(t *testing.T) { + token, err := ExchangeJWTBearer( + context.Background(), + server.URL, + "fake-id-jag-jwt", + "mcp-client-id", + "mcp-client-secret", + server.Client(), + ) + if err != nil { + t.Fatalf("ExchangeJWTBearer failed: %v", err) + } + if token.AccessToken != "mcp-access-token-123" { + t.Errorf("expected access_token 'mcp-access-token-123', got %s", token.AccessToken) + } + if token.TokenType != "Bearer" { + t.Errorf("expected token_type 'Bearer', got %s", token.TokenType) + } + if token.RefreshToken != "mcp-refresh-token-456" { + t.Errorf("expected refresh_token 'mcp-refresh-token-456', got %s", token.RefreshToken) + } + // Check expiration (should be ~1 hour from now) + expectedExpiry := time.Now().Add(3600 * time.Second) + if token.Expiry.Before(time.Now()) || token.Expiry.After(expectedExpiry.Add(5*time.Second)) { + t.Errorf("unexpected expiry time: %v", token.Expiry) + } + // Check scope in extra data + scope, ok := token.Extra("scope").(string) + if !ok || scope != "read write" { + t.Errorf("expected scope 'read write', got %v", token.Extra("scope")) + } + }) + // Test missing assertion + t.Run("missing assertion", func(t *testing.T) { + _, err := ExchangeJWTBearer( + context.Background(), + server.URL, + "", // empty assertion + "mcp-client-id", + "mcp-client-secret", + server.Client(), + ) + if err == nil { + t.Error("expected error for missing assertion, got nil") + } + }) + // Test invalid URL scheme + t.Run("invalid token endpoint URL", func(t *testing.T) { + _, err := ExchangeJWTBearer( + context.Background(), + "javascript:alert(1)", + "fake-id-jag-jwt", + "mcp-client-id", + "mcp-client-secret", + server.Client(), + ) + if err == nil { + t.Error("expected error for invalid URL scheme, got nil") + } + }) +} + +// writeJWTBearerErrorResponse writes an OAuth 2.0 error response per RFC 6749 Section 5.2. +func writeJWTBearerErrorResponse(w http.ResponseWriter, errorCode, errorDescription string) { + errResp := JWTBearerError{ + ErrorCode: errorCode, + ErrorDescription: errorDescription, + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(errResp) +} From adaaecb209b3d3614a8f34ea6fe8233cc2c64b69 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Mon, 24 Nov 2025 17:45:45 +0530 Subject: [PATCH 03/19] feat: add id-jag+jwt parsing and enterprise auth --- auth/enterprise_auth.go | 137 ++++++++++++++++++++++ auth/enterprise_auth_test.go | 218 +++++++++++++++++++++++++++++++++++ oauthex/id_jag.go | 138 ++++++++++++++++++++++ oauthex/id_jag_test.go | 176 ++++++++++++++++++++++++++++ 4 files changed, 669 insertions(+) create mode 100644 auth/enterprise_auth.go create mode 100644 auth/enterprise_auth_test.go create mode 100644 oauthex/id_jag.go create mode 100644 oauthex/id_jag_test.go diff --git a/auth/enterprise_auth.go b/auth/enterprise_auth.go new file mode 100644 index 00000000..1fd471dc --- /dev/null +++ b/auth/enterprise_auth.go @@ -0,0 +1,137 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements the client-side Enterprise Managed Authorization flow +// for MCP as specified in SEP-990. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "fmt" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// EnterpriseAuthConfig contains configuration for Enterprise Managed Authorization +// (SEP-990). This configures both the IdP (for token exchange) and the MCP Server +// (for JWT Bearer grant). +type EnterpriseAuthConfig struct { + // IdP configuration (where the user authenticates) + IdPIssuerURL string // e.g., "https://acme.okta.com" + IdPClientID string // MCP Client's ID at the IdP + IdPClientSecret string // MCP Client's secret at the IdP + + // MCP Server configuration (the resource being accessed) + MCPAuthServerURL string // MCP Server's auth server issuer URL + MCPResourceURL string // MCP Server's resource identifier + MCPClientID string // MCP Client's ID at the MCP Server + MCPClientSecret string // MCP Client's secret at the MCP Server + MCPScopes []string // Requested scopes at the MCP Server + + // Optional HTTP client for customization + HTTPClient *http.Client +} + +// EnterpriseAuthFlow performs the complete Enterprise Managed Authorization flow: +// 1. Token Exchange: ID Token → ID-JAG at IdP +// 2. JWT Bearer: ID-JAG → Access Token at MCP Server +// +// This function takes an ID Token that was obtained via SSO (e.g., OIDC login) +// and exchanges it for an access token that can be used to call the MCP Server. +// +// Example: +// +// config := &EnterpriseAuthConfig{ +// IdPIssuerURL: "https://acme.okta.com", +// IdPClientID: "client-id-at-idp", +// IdPClientSecret: "secret-at-idp", +// MCPAuthServerURL: "https://auth.mcpserver.example", +// MCPResourceURL: "https://mcp.mcpserver.example", +// MCPClientID: "client-id-at-mcp", +// MCPClientSecret: "secret-at-mcp", +// MCPScopes: []string{"read", "write"}, +// } +// +// // After user logs in via OIDC, you have an ID Token +// accessToken, err := EnterpriseAuthFlow(ctx, config, idToken) +// if err != nil { +// log.Fatal(err) +// } +// +// // Use accessToken to call MCP Server APIs +func EnterpriseAuthFlow( + ctx context.Context, + config *EnterpriseAuthConfig, + idToken string, +) (*oauth2.Token, error) { + if config == nil { + return nil, fmt.Errorf("config is required") + } + if idToken == "" { + return nil, fmt.Errorf("idToken is required") + } + // Validate configuration + if config.IdPIssuerURL == "" { + return nil, fmt.Errorf("IdPIssuerURL is required") + } + if config.MCPAuthServerURL == "" { + return nil, fmt.Errorf("MCPAuthServerURL is required") + } + if config.MCPResourceURL == "" { + return nil, fmt.Errorf("MCPResourceURL is required") + } + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + + // Step 1: Token Exchange (ID Token → ID-JAG) + tokenExchangeReq := &oauthex.TokenExchangeRequest{ + RequestedTokenType: oauthex.TokenTypeIDJAG, + Audience: config.MCPAuthServerURL, + Resource: config.MCPResourceURL, + Scope: config.MCPScopes, + SubjectToken: idToken, + SubjectTokenType: oauthex.TokenTypeIDToken, + } + // Construct IdP token endpoint (assuming standard path) + idpTokenEndpoint := config.IdPIssuerURL + "/oauth2/v1/token" + if config.IdPIssuerURL[len(config.IdPIssuerURL)-1] == '/' { + idpTokenEndpoint = config.IdPIssuerURL + "oauth2/v1/token" + } + tokenExchangeResp, err := oauthex.ExchangeToken( + ctx, + idpTokenEndpoint, + tokenExchangeReq, + config.IdPClientID, + config.IdPClientSecret, + httpClient, + ) + if err != nil { + return nil, fmt.Errorf("token exchange failed: %w", err) + } + + // Step 2: JWT Bearer Grant (ID-JAG → Access Token) + mcpTokenEndpoint := config.MCPAuthServerURL + "/v1/token" + if config.MCPAuthServerURL[len(config.MCPAuthServerURL)-1] == '/' { + mcpTokenEndpoint = config.MCPAuthServerURL + "v1/token" + } + accessToken, err := oauthex.ExchangeJWTBearer( + ctx, + mcpTokenEndpoint, + tokenExchangeResp.AccessToken, // The ID-JAG + config.MCPClientID, + config.MCPClientSecret, + httpClient, + ) + if err != nil { + return nil, fmt.Errorf("JWT bearer grant failed: %w", err) + } + return accessToken, nil +} diff --git a/auth/enterprise_auth_test.go b/auth/enterprise_auth_test.go new file mode 100644 index 00000000..db2e5ffd --- /dev/null +++ b/auth/enterprise_auth_test.go @@ -0,0 +1,218 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// TestEnterpriseAuthFlow tests the complete enterprise auth flow. +func TestEnterpriseAuthFlow(t *testing.T) { + // Create test servers for IdP and MCP Server + idpServer := createMockIdPServer(t) + defer idpServer.Close() + mcpServer := createMockMCPServer(t) + defer mcpServer.Close() + // Create a test ID Token + idToken := createTestIDToken() + // Configure enterprise auth + config := &EnterpriseAuthConfig{ + IdPIssuerURL: idpServer.URL, + IdPClientID: "test-idp-client", + IdPClientSecret: "test-idp-secret", + MCPAuthServerURL: mcpServer.URL, + MCPResourceURL: "https://mcp.example.com", + MCPClientID: "test-mcp-client", + MCPClientSecret: "test-mcp-secret", + MCPScopes: []string{"read", "write"}, + HTTPClient: idpServer.Client(), + } + // Test successful flow + t.Run("successful flow", func(t *testing.T) { + token, err := EnterpriseAuthFlow(context.Background(), config, idToken) + if err != nil { + t.Fatalf("EnterpriseAuthFlow failed: %v", err) + } + if token.AccessToken != "mcp-access-token" { + t.Errorf("expected access token 'mcp-access-token', got '%s'", token.AccessToken) + } + if token.TokenType != "Bearer" { + t.Errorf("expected token type 'Bearer', got '%s'", token.TokenType) + } + }) + // Test missing config + t.Run("nil config", func(t *testing.T) { + _, err := EnterpriseAuthFlow(context.Background(), nil, idToken) + if err == nil { + t.Error("expected error for nil config, got nil") + } + }) + // Test missing ID token + t.Run("empty ID token", func(t *testing.T) { + _, err := EnterpriseAuthFlow(context.Background(), config, "") + if err == nil { + t.Error("expected error for empty ID token, got nil") + } + }) + // Test missing IdP issuer + t.Run("missing IdP issuer", func(t *testing.T) { + badConfig := *config + badConfig.IdPIssuerURL = "" + _, err := EnterpriseAuthFlow(context.Background(), &badConfig, idToken) + if err == nil { + t.Error("expected error for missing IdP issuer, got nil") + } + }) +} + +// createMockIdPServer creates a mock IdP server for testing. +func createMockIdPServer(t *testing.T) *httptest.Server { + var serverURL string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle OIDC discovery endpoint + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": serverURL, // Use actual server URL + "token_endpoint": serverURL + "/oauth2/v1/token", + "jwks_uri": serverURL + "/.well-known/jwks.json", + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": []string{ + "authorization_code", + "urn:ietf:params:oauth:grant-type:token-exchange", + }, + "response_types_supported": []string{"code"}, + }) + return + } + + // Handle token exchange endpoint + if r.URL.Path != "/oauth2/v1/token" { + http.NotFound(w, r) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + grantType := r.FormValue("grant_type") + if grantType != oauthex.GrantTypeTokenExchange { + http.Error(w, "invalid grant type", http.StatusBadRequest) + return + } + + // Return a mock ID-JAG + now := time.Now().Unix() + header := map[string]string{"typ": "oauth-id-jag+jwt", "alg": "RS256"} + claims := map[string]interface{}{ + "iss": "https://test.okta.com", + "sub": "test-user", + "aud": r.FormValue("audience"), + "resource": r.FormValue("resource"), + "client_id": r.FormValue("client_id"), + "jti": "test-jti", + "exp": now + 300, + "iat": now, + "scope": r.FormValue("scope"), + } + headerJSON, _ := json.Marshal(header) + claimsJSON, _ := json.Marshal(claims) + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) + mockIDJAG := fmt.Sprintf("%s.%s.mock-signature", headerB64, claimsB64) + + resp := oauthex.TokenExchangeResponse{ + IssuedTokenType: oauthex.TokenTypeIDJAG, + AccessToken: mockIDJAG, + TokenType: "N_A", + ExpiresIn: 300, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + serverURL = server.URL // Capture server URL for discovery response + return server +} + +// createMockMCPServer creates a mock MCP Server for testing. +func createMockMCPServer(t *testing.T) *httptest.Server { + var serverURL string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle OIDC discovery endpoint + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": serverURL, // Use actual server URL + "token_endpoint": serverURL + "/v1/token", + "jwks_uri": serverURL + "/.well-known/jwks.json", + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": []string{ + "urn:ietf:params:oauth:grant-type:jwt-bearer", + }, + }) + return + } + + // Handle JWT Bearer endpoint + if r.URL.Path != "/v1/token" { + http.NotFound(w, r) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + grantType := r.FormValue("grant_type") + if grantType != oauthex.GrantTypeJWTBearer { + http.Error(w, "invalid grant type", http.StatusBadRequest) + return + } + + resp := oauthex.JWTBearerResponse{ + AccessToken: "mcp-access-token", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "read write", + RefreshToken: "mcp-refresh-token", + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + serverURL = server.URL // Capture server URL for discovery response + return server +} + +// createTestIDToken creates a mock ID Token for testing. +func createTestIDToken() string { + now := time.Now().Unix() + header := map[string]string{"typ": "JWT", "alg": "RS256"} + claims := map[string]interface{}{ + "iss": "https://test.okta.com", + "sub": "test-user", + "aud": "test-client", + "exp": now + 3600, + "iat": now, + "email": "test@example.com", + } + headerJSON, _ := json.Marshal(header) + claimsJSON, _ := json.Marshal(claims) + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) + + return fmt.Sprintf("%s.%s.mock-signature", headerB64, claimsB64) +} diff --git a/oauthex/id_jag.go b/oauthex/id_jag.go new file mode 100644 index 00000000..860a36ea --- /dev/null +++ b/oauthex/id_jag.go @@ -0,0 +1,138 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements ID-JAG (Identity Assertion JWT Authorization Grant) parsing +// for Enterprise Managed Authorization (SEP-990). +// See https://github.com/modelcontextprotocol/ext-auth/blob/main/specification/draft/enterprise-managed-authorization.mdx + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "time" +) + +// IDJAGClaims represents the claims in an Identity Assertion JWT Authorization Grant +// per SEP-990 Section 4.3. The ID-JAG is issued by the IdP during token exchange +// and describes the authorization grant for accessing an MCP Server. +type IDJAGClaims struct { + // Issuer is the IdP's issuer URL. + Issuer string `json:"iss"` + // Subject is the user identifier at the MCP Server. + Subject string `json:"sub"` + // Audience is the Issuer URL of the MCP Server's authorization server. + Audience string `json:"aud"` + // Resource is the Resource Identifier of the MCP Server. + Resource string `json:"resource"` + // ClientID is the identifier of the MCP Client that this JWT was issued to. + ClientID string `json:"client_id"` + // JTI is the unique identifier of this JWT. + JTI string `json:"jti"` + // ExpiresAt is the expiration time of this JWT (Unix timestamp). + ExpiresAt int64 `json:"exp"` + // IssuedAt is the time this JWT was issued (Unix timestamp). + IssuedAt int64 `json:"iat"` + // Scope is a space-separated list of scopes associated with the token. + Scope string `json:"scope,omitempty"` +} + +// Expiry returns the expiration time as a time.Time. +func (c *IDJAGClaims) Expiry() time.Time { + return time.Unix(c.ExpiresAt, 0) +} + +// IssuedTime returns the issued-at time as a time.Time. +func (c *IDJAGClaims) IssuedTime() time.Time { + return time.Unix(c.IssuedAt, 0) +} + +// IsExpired checks if the ID-JAG has expired. +func (c *IDJAGClaims) IsExpired() bool { + return time.Now().After(c.Expiry()) +} + +// ParseIDJAG parses an ID-JAG JWT and extracts its claims without validating +// the signature. This is useful for inspecting the contents of an ID-JAG during +// development or debugging. +// +// For production use on the server-side, use ValidateIDJAG instead, which +// performs full signature validation and claim verification. +// +// The JWT must have a "typ" header of "oauth-id-jag+jwt" per SEP-990 Section 4.3. +// +// Example: +// +// claims, err := ParseIDJAG(idJAG) +// if err != nil { +// log.Fatalf("Failed to parse ID-JAG: %v", err) +// } +// fmt.Printf("Subject: %s\n", claims.Subject) +// fmt.Printf("Expires: %v\n", claims.Expiry()) +func ParseIDJAG(jwt string) (*IDJAGClaims, error) { + if jwt == "" { + return nil, fmt.Errorf("JWT is empty") + } + // Split JWT into parts (header.payload.signature) + parts := strings.Split(jwt, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) + } + // Decode header to check typ claim + headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return nil, fmt.Errorf("failed to decode JWT header: %w", err) + } + var header struct { + Type string `json:"typ"` + Alg string `json:"alg"` + } + if err := json.Unmarshal(headerJSON, &header); err != nil { + return nil, fmt.Errorf("failed to parse JWT header: %w", err) + } + // Verify typ claim per SEP-990 Section 4.3 + if header.Type != "oauth-id-jag+jwt" { + return nil, fmt.Errorf("invalid JWT type: expected 'oauth-id-jag+jwt', got '%s'", header.Type) + } + // Decode payload + payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to decode JWT payload: %w", err) + } + // Parse claims + var claims IDJAGClaims + if err := json.Unmarshal(payloadJSON, &claims); err != nil { + return nil, fmt.Errorf("failed to parse JWT claims: %w", err) + } + // Validate required claims are present per SEP-990 Section 4.3 + if claims.Issuer == "" { + return nil, fmt.Errorf("missing required claim: iss") + } + if claims.Subject == "" { + return nil, fmt.Errorf("missing required claim: sub") + } + if claims.Audience == "" { + return nil, fmt.Errorf("missing required claim: aud") + } + if claims.Resource == "" { + return nil, fmt.Errorf("missing required claim: resource") + } + if claims.ClientID == "" { + return nil, fmt.Errorf("missing required claim: client_id") + } + if claims.JTI == "" { + return nil, fmt.Errorf("missing required claim: jti") + } + if claims.ExpiresAt == 0 { + return nil, fmt.Errorf("missing required claim: exp") + } + if claims.IssuedAt == 0 { + return nil, fmt.Errorf("missing required claim: iat") + } + return &claims, nil +} diff --git a/oauthex/id_jag_test.go b/oauthex/id_jag_test.go new file mode 100644 index 00000000..ff710fcc --- /dev/null +++ b/oauthex/id_jag_test.go @@ -0,0 +1,176 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package oauthex + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "testing" + "time" +) + +// TestParseIDJAG tests parsing of ID-JAG tokens. +func TestParseIDJAG(t *testing.T) { + // Create a test ID-JAG JWT + now := time.Now().Unix() + + header := map[string]string{ + "typ": "oauth-id-jag+jwt", + "alg": "RS256", + } + + claims := map[string]interface{}{ + "iss": "https://acme.okta.com", + "sub": "alice@acme.com", + "aud": "https://auth.mcpserver.example", + "resource": "https://mcp.mcpserver.example", + "client_id": "xyz789", + "jti": "unique-id-123", + "exp": now + 300, + "iat": now, + "scope": "read write", + } + // Encode header and payload + headerJSON, _ := json.Marshal(header) + claimsJSON, _ := json.Marshal(claims) + + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) + + // Create fake JWT (header.payload.signature) + fakeJWT := fmt.Sprintf("%s.%s.fake-signature", headerB64, claimsB64) + // Test successful parsing + t.Run("successful parse", func(t *testing.T) { + parsed, err := ParseIDJAG(fakeJWT) + if err != nil { + t.Fatalf("ParseIDJAG failed: %v", err) + } + if parsed.Issuer != "https://acme.okta.com" { + t.Errorf("expected issuer 'https://acme.okta.com', got '%s'", parsed.Issuer) + } + if parsed.Subject != "alice@acme.com" { + t.Errorf("expected subject 'alice@acme.com', got '%s'", parsed.Subject) + } + if parsed.Audience != "https://auth.mcpserver.example" { + t.Errorf("expected audience 'https://auth.mcpserver.example', got '%s'", parsed.Audience) + } + if parsed.Resource != "https://mcp.mcpserver.example" { + t.Errorf("expected resource 'https://mcp.mcpserver.example', got '%s'", parsed.Resource) + } + if parsed.ClientID != "xyz789" { + t.Errorf("expected client_id 'xyz789', got '%s'", parsed.ClientID) + } + if parsed.JTI != "unique-id-123" { + t.Errorf("expected jti 'unique-id-123', got '%s'", parsed.JTI) + } + if parsed.Scope != "read write" { + t.Errorf("expected scope 'read write', got '%s'", parsed.Scope) + } + if parsed.IsExpired() { + t.Error("expected ID-JAG not to be expired") + } + }) + // Test empty JWT + t.Run("empty JWT", func(t *testing.T) { + _, err := ParseIDJAG("") + if err == nil { + t.Error("expected error for empty JWT, got nil") + } + }) + // Test invalid format + t.Run("invalid format", func(t *testing.T) { + _, err := ParseIDJAG("invalid.jwt") + if err == nil { + t.Error("expected error for invalid JWT format, got nil") + } + }) + // Test wrong typ header + t.Run("wrong typ header", func(t *testing.T) { + wrongHeader := map[string]string{ + "typ": "JWT", // Should be "oauth-id-jag+jwt" + "alg": "RS256", + } + wrongHeaderJSON, _ := json.Marshal(wrongHeader) + wrongHeaderB64 := base64.RawURLEncoding.EncodeToString(wrongHeaderJSON) + wrongJWT := fmt.Sprintf("%s.%s.fake-signature", wrongHeaderB64, claimsB64) + _, err := ParseIDJAG(wrongJWT) + if err == nil { + t.Error("expected error for wrong typ header, got nil") + } + if err != nil && !strings.Contains(err.Error(), "invalid JWT type") { + t.Errorf("expected 'invalid JWT type' error, got: %v", err) + } + }) + // Test missing required claims + t.Run("missing required claims", func(t *testing.T) { + incompleteClaims := map[string]interface{}{ + "iss": "https://acme.okta.com", + // Missing other required claims + } + incompleteJSON, _ := json.Marshal(incompleteClaims) + incompleteB64 := base64.RawURLEncoding.EncodeToString(incompleteJSON) + incompleteJWT := fmt.Sprintf("%s.%s.fake-signature", headerB64, incompleteB64) + _, err := ParseIDJAG(incompleteJWT) + if err == nil { + t.Error("expected error for missing claims, got nil") + } + }) + // Test expired ID-JAG + t.Run("expired ID-JAG", func(t *testing.T) { + expiredClaims := map[string]interface{}{ + "iss": "https://acme.okta.com", + "sub": "alice@acme.com", + "aud": "https://auth.mcpserver.example", + "resource": "https://mcp.mcpserver.example", + "client_id": "xyz789", + "jti": "unique-id-123", + "exp": now - 300, // Expired 5 minutes ago + "iat": now - 600, + "scope": "read write", + } + expiredJSON, _ := json.Marshal(expiredClaims) + expiredB64 := base64.RawURLEncoding.EncodeToString(expiredJSON) + expiredJWT := fmt.Sprintf("%s.%s.fake-signature", headerB64, expiredB64) + parsed, err := ParseIDJAG(expiredJWT) + if err != nil { + t.Fatalf("ParseIDJAG failed: %v", err) + } + if !parsed.IsExpired() { + t.Error("expected ID-JAG to be expired") + } + }) +} + +// TestIDJAGClaimsMethods tests the helper methods on IDJAGClaims. +func TestIDJAGClaimsMethods(t *testing.T) { + now := time.Now() + claims := &IDJAGClaims{ + ExpiresAt: now.Add(1 * time.Hour).Unix(), + IssuedAt: now.Unix(), + } + // Test Expiry + expiry := claims.Expiry() + if expiry.Before(now) { + t.Error("expected expiry to be in the future") + } + // Test IssuedTime + issued := claims.IssuedTime() + if issued.After(now.Add(1 * time.Second)) { + t.Error("expected issued time to be in the past") + } + // Test IsExpired (should not be expired) + if claims.IsExpired() { + t.Error("expected claims not to be expired") + } + // Test IsExpired (should be expired) + claims.ExpiresAt = now.Add(-1 * time.Hour).Unix() + if !claims.IsExpired() { + t.Error("expected claims to be expired") + } +} From 5a52c5e386746d7576f68b9f0782848bddab1d53 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Mon, 24 Nov 2025 19:59:31 +0530 Subject: [PATCH 04/19] chore: use auth server metadata for token endpoints # Conflicts: # oauthex/oauth2.go --- auth/enterprise_auth.go | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/auth/enterprise_auth.go b/auth/enterprise_auth.go index 1fd471dc..5478e0e4 100644 --- a/auth/enterprise_auth.go +++ b/auth/enterprise_auth.go @@ -91,7 +91,13 @@ func EnterpriseAuthFlow( httpClient = http.DefaultClient } - // Step 1: Token Exchange (ID Token → ID-JAG) + // Step 1: Discover IdP token endpoint via OIDC discovery + idpMeta, err := oauthex.GetAuthServerMeta(ctx, config.IdPIssuerURL, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to discover IdP metadata: %w", err) + } + + // Step 2: Token Exchange (ID Token → ID-JAG) tokenExchangeReq := &oauthex.TokenExchangeRequest{ RequestedTokenType: oauthex.TokenTypeIDJAG, Audience: config.MCPAuthServerURL, @@ -100,14 +106,10 @@ func EnterpriseAuthFlow( SubjectToken: idToken, SubjectTokenType: oauthex.TokenTypeIDToken, } - // Construct IdP token endpoint (assuming standard path) - idpTokenEndpoint := config.IdPIssuerURL + "/oauth2/v1/token" - if config.IdPIssuerURL[len(config.IdPIssuerURL)-1] == '/' { - idpTokenEndpoint = config.IdPIssuerURL + "oauth2/v1/token" - } + tokenExchangeResp, err := oauthex.ExchangeToken( ctx, - idpTokenEndpoint, + idpMeta.TokenEndpoint, tokenExchangeReq, config.IdPClientID, config.IdPClientSecret, @@ -117,15 +119,16 @@ func EnterpriseAuthFlow( return nil, fmt.Errorf("token exchange failed: %w", err) } - // Step 2: JWT Bearer Grant (ID-JAG → Access Token) - mcpTokenEndpoint := config.MCPAuthServerURL + "/v1/token" - if config.MCPAuthServerURL[len(config.MCPAuthServerURL)-1] == '/' { - mcpTokenEndpoint = config.MCPAuthServerURL + "v1/token" + // Step 3: JWT Bearer Grant (ID-JAG → Access Token) + mcpMeta, err := oauthex.GetAuthServerMeta(ctx, config.MCPAuthServerURL, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to discover MCP auth server metadata: %w", err) } + accessToken, err := oauthex.ExchangeJWTBearer( ctx, - mcpTokenEndpoint, - tokenExchangeResp.AccessToken, // The ID-JAG + mcpMeta.TokenEndpoint, + tokenExchangeResp.AccessToken, config.MCPClientID, config.MCPClientSecret, httpClient, From e070a87e552591c2f92a0f728f76296d61210a0d Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Mon, 24 Nov 2025 21:11:04 +0530 Subject: [PATCH 05/19] feat: fetch jwks for jwt signature verification --- auth/jwks_cache.go | 150 ++++++++++++++++++++++++++++++ auth/jwks_cache_test.go | 199 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 349 insertions(+) create mode 100644 auth/jwks_cache.go create mode 100644 auth/jwks_cache_test.go diff --git a/auth/jwks_cache.go b/auth/jwks_cache.go new file mode 100644 index 00000000..70efe89f --- /dev/null +++ b/auth/jwks_cache.go @@ -0,0 +1,150 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements JWKS (JSON Web Key Set) fetching and caching for +// JWT signature verification in Enterprise Managed Authorization (SEP-990). + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "time" +) + +// JWK represents a JSON Web Key per RFC 7517. +type JWK struct { + // KeyType is the key type (e.g., "RSA", "EC"). + KeyType string `json:"kty"` + // Use indicates the intended use of the key (e.g., "sig" for signature). + Use string `json:"use,omitempty"` + // KeyID is the key identifier. + KeyID string `json:"kid"` + // Algorithm is the algorithm intended for use with the key. + Algorithm string `json:"alg,omitempty"` + // N is the RSA modulus (base64url encoded). + N string `json:"n,omitempty"` + // E is the RSA public exponent (base64url encoded). + E string `json:"e,omitempty"` + // X is the X coordinate for elliptic curve keys (base64url encoded). + X string `json:"x,omitempty"` + // Y is the Y coordinate for elliptic curve keys (base64url encoded). + Y string `json:"y,omitempty"` + // Curve is the elliptic curve name (e.g., "P-256"). + Curve string `json:"crv,omitempty"` +} + +// JWKS represents a JSON Web Key Set per RFC 7517. +type JWKS struct { + Keys []JWK `json:"keys"` +} + +// FindKey finds a key by its key ID (kid). +func (j *JWKS) FindKey(kid string) (*JWK, error) { + for i := range j.Keys { + if j.Keys[i].KeyID == kid { + return &j.Keys[i], nil + } + } + return nil, fmt.Errorf("key with kid %q not found", kid) +} + +// JWKSCache caches JWKS responses to reduce network requests. +type JWKSCache struct { + mu sync.RWMutex + entries map[string]*jwksCacheEntry + client *http.Client +} +type jwksCacheEntry struct { + jwks *JWKS + expiresAt time.Time +} + +// NewJWKSCache creates a new JWKS cache with the given HTTP client. +// If client is nil, http.DefaultClient is used. +func NewJWKSCache(client *http.Client) *JWKSCache { + if client == nil { + client = http.DefaultClient + } + return &JWKSCache{ + entries: make(map[string]*jwksCacheEntry), + client: client, + } +} + +// Get fetches JWKS from the given URL, using cache if available and not expired. +// The cache duration is 1 hour per best practices for JWKS caching. +func (c *JWKSCache) Get(ctx context.Context, jwksURL string) (*JWKS, error) { + // Check cache first + c.mu.RLock() + entry, ok := c.entries[jwksURL] + c.mu.RUnlock() + if ok && time.Now().Before(entry.expiresAt) { + return entry.jwks, nil + } + // Fetch from network + jwks, err := c.fetch(ctx, jwksURL) + if err != nil { + return nil, err + } + // Update cache + c.mu.Lock() + c.entries[jwksURL] = &jwksCacheEntry{ + jwks: jwks, + expiresAt: time.Now().Add(1 * time.Hour), + } + c.mu.Unlock() + return jwks, nil +} + +// fetch retrieves JWKS from the given URL. +func (c *JWKSCache) fetch(ctx context.Context, jwksURL string) (*JWKS, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create JWKS request: %w", err) + } + req.Header.Set("Accept", "application/json") + resp, err := c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to fetch JWKS: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("JWKS endpoint returned status %d", resp.StatusCode) + } + // Read response body (limit to 1MB for safety) + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read JWKS response: %w", err) + } + // Parse JWKS + var jwks JWKS + if err := json.Unmarshal(body, &jwks); err != nil { + return nil, fmt.Errorf("failed to parse JWKS: %w", err) + } + if len(jwks.Keys) == 0 { + return nil, fmt.Errorf("JWKS contains no keys") + } + return &jwks, nil +} + +// Invalidate removes a JWKS entry from the cache, forcing a fresh fetch on next Get. +func (c *JWKSCache) Invalidate(jwksURL string) { + c.mu.Lock() + delete(c.entries, jwksURL) + c.mu.Unlock() +} + +// Clear removes all entries from the cache. +func (c *JWKSCache) Clear() { + c.mu.Lock() + c.entries = make(map[string]*jwksCacheEntry) + c.mu.Unlock() +} diff --git a/auth/jwks_cache_test.go b/auth/jwks_cache_test.go new file mode 100644 index 00000000..4ec87ca7 --- /dev/null +++ b/auth/jwks_cache_test.go @@ -0,0 +1,199 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +// TestJWKSCache tests JWKS fetching and caching. +func TestJWKSCache(t *testing.T) { + // Create test JWKS + testJWKS := &JWKS{ + Keys: []JWK{ + { + KeyType: "RSA", + Use: "sig", + KeyID: "test-key-1", + Algorithm: "RS256", + N: "test-modulus", + E: "AQAB", + }, + { + KeyType: "RSA", + Use: "sig", + KeyID: "test-key-2", + Algorithm: "RS256", + N: "test-modulus-2", + E: "AQAB", + }, + }, + } + // Create test server + var requestCount int + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(testJWKS) + })) + defer server.Close() + cache := NewJWKSCache(server.Client()) + // Test first fetch + t.Run("first fetch", func(t *testing.T) { + jwks, err := cache.Get(context.Background(), server.URL) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if len(jwks.Keys) != 2 { + t.Errorf("expected 2 keys, got %d", len(jwks.Keys)) + } + if jwks.Keys[0].KeyID != "test-key-1" { + t.Errorf("expected key ID 'test-key-1', got '%s'", jwks.Keys[0].KeyID) + } + if requestCount != 1 { + t.Errorf("expected 1 request, got %d", requestCount) + } + }) + // Test cache hit + t.Run("cache hit", func(t *testing.T) { + jwks, err := cache.Get(context.Background(), server.URL) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if len(jwks.Keys) != 2 { + t.Errorf("expected 2 keys from cache, got %d", len(jwks.Keys)) + } + // Should still be 1 request (served from cache) + if requestCount != 1 { + t.Errorf("expected 1 request (cached), got %d", requestCount) + } + }) + // Test FindKey + t.Run("find key", func(t *testing.T) { + jwks, _ := cache.Get(context.Background(), server.URL) + key, err := jwks.FindKey("test-key-2") + if err != nil { + t.Fatalf("FindKey failed: %v", err) + } + if key.KeyID != "test-key-2" { + t.Errorf("expected key ID 'test-key-2', got '%s'", key.KeyID) + } + if key.N != "test-modulus-2" { + t.Errorf("expected modulus 'test-modulus-2', got '%s'", key.N) + } + }) + // Test key not found + t.Run("key not found", func(t *testing.T) { + jwks, _ := cache.Get(context.Background(), server.URL) + _, err := jwks.FindKey("nonexistent") + if err == nil { + t.Error("expected error for nonexistent key, got nil") + } + }) + // Test Invalidate + t.Run("invalidate", func(t *testing.T) { + cache.Invalidate(server.URL) + // Next fetch should hit the server again + _, err := cache.Get(context.Background(), server.URL) + if err != nil { + t.Fatalf("Get after invalidate failed: %v", err) + } + if requestCount != 2 { + t.Errorf("expected 2 requests after invalidate, got %d", requestCount) + } + }) + // Test Clear + t.Run("clear", func(t *testing.T) { + cache.Clear() + // Next fetch should hit the server again + _, err := cache.Get(context.Background(), server.URL) + if err != nil { + t.Fatalf("Get after clear failed: %v", err) + } + if requestCount != 3 { + t.Errorf("expected 3 requests after clear, got %d", requestCount) + } + }) + // Test error handling + t.Run("server error", func(t *testing.T) { + errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "internal error", http.StatusInternalServerError) + })) + defer errorServer.Close() + _, err := cache.Get(context.Background(), errorServer.URL) + if err == nil { + t.Error("expected error for server error, got nil") + } + }) + // Test invalid JSON + t.Run("invalid json", func(t *testing.T) { + invalidServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("invalid json")) + })) + defer invalidServer.Close() + _, err := cache.Get(context.Background(), invalidServer.URL) + if err == nil { + t.Error("expected error for invalid JSON, got nil") + } + }) + // Test empty keys + t.Run("empty keys", func(t *testing.T) { + emptyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(&JWKS{Keys: []JWK{}}) + })) + defer emptyServer.Close() + _, err := cache.Get(context.Background(), emptyServer.URL) + if err == nil { + t.Error("expected error for empty keys, got nil") + } + }) +} + +// TestJWKSCacheExpiration tests cache expiration. +func TestJWKSCacheExpiration(t *testing.T) { + testJWKS := &JWKS{ + Keys: []JWK{{KeyID: "test", KeyType: "RSA", N: "test", E: "AQAB"}}, + } + var requestCount int + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(testJWKS) + })) + defer server.Close() + cache := NewJWKSCache(server.Client()) + // First fetch + _, err := cache.Get(context.Background(), server.URL) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + // Manually expire the cache entry + cache.mu.Lock() + if entry, ok := cache.entries[server.URL]; ok { + entry.expiresAt = time.Now().Add(-1 * time.Hour) + } + cache.mu.Unlock() + // Next fetch should hit server again + _, err = cache.Get(context.Background(), server.URL) + if err != nil { + t.Fatalf("Get after expiration failed: %v", err) + } + if requestCount != 2 { + t.Errorf("expected 2 requests after expiration, got %d", requestCount) + } +} From a0e9031b64db62e993655255971e4149823d9a18 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Mon, 24 Nov 2025 21:21:19 +0530 Subject: [PATCH 06/19] feat: add id-jag validation for mcp servers --- auth/id_jag_verifier.go | 250 +++++++++++++++++++++++++++++++++++ auth/id_jag_verifier_test.go | 191 ++++++++++++++++++++++++++ 2 files changed, 441 insertions(+) create mode 100644 auth/id_jag_verifier.go create mode 100644 auth/id_jag_verifier_test.go diff --git a/auth/id_jag_verifier.go b/auth/id_jag_verifier.go new file mode 100644 index 00000000..45cb710d --- /dev/null +++ b/auth/id_jag_verifier.go @@ -0,0 +1,250 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements ID-JAG (Identity Assertion JWT Authorization Grant) validation +// for MCP Servers in Enterprise Managed Authorization (SEP-990). + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "net/http" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// TrustedIdPConfig contains configuration for a trusted Identity Provider. +type TrustedIdPConfig struct { + // IssuerURL is the IdP's issuer URL (must match the iss claim). + IssuerURL string + // JWKSURL is the URL to fetch the IdP's JSON Web Key Set. + JWKSURL string +} + +// IDJAGVerifierConfig configures ID-JAG validation for an MCP Server. +type IDJAGVerifierConfig struct { + // AuthServerIssuerURL is this MCP Server's authorization server issuer URL. + // This must match the aud claim in the ID-JAG. + AuthServerIssuerURL string + // TrustedIdPs is a map of trusted Identity Providers. + // The key is a friendly name, the value is the IdP configuration. + TrustedIdPs map[string]*TrustedIdPConfig + // JWKSCache is the cache for JWKS responses. If nil, a new cache is created. + JWKSCache *JWKSCache + // HTTPClient is the HTTP client for fetching JWKS. If nil, http.DefaultClient is used. + HTTPClient *http.Client + // AllowedClockSkew is the allowed clock skew for exp/iat validation. + // Default is 5 minutes. + AllowedClockSkew time.Duration +} + +// IDJAGVerifier validates ID-JAG tokens for MCP Servers. +type IDJAGVerifier struct { + config *IDJAGVerifierConfig + jwksCache *JWKSCache + usedJTIs map[string]time.Time // Replay attack prevention + usedJTIMu sync.RWMutex +} + +// NewIDJAGVerifier creates a new ID-JAG verifier with the given configuration. +// This returns a TokenVerifier that can be used with RequireBearerToken middleware. +// +// Example: +// +// config := &IDJAGVerifierConfig{ +// AuthServerIssuerURL: "https://auth.mcpserver.example", +// TrustedIdPs: map[string]*TrustedIdPConfig{ +// "acme-okta": { +// IssuerURL: "https://acme.okta.com", +// JWKSURL: "https://acme.okta.com/.well-known/jwks.json", +// }, +// }, +// } +// +// verifier := NewIDJAGVerifier(config) +// middleware := RequireBearerToken(verifier, &RequireBearerTokenOptions{ +// Scopes: []string{"read"}, +// }) +func NewIDJAGVerifier(config *IDJAGVerifierConfig) TokenVerifier { + if config.JWKSCache == nil { + config.JWKSCache = NewJWKSCache(config.HTTPClient) + } + if config.AllowedClockSkew == 0 { + config.AllowedClockSkew = 5 * time.Minute + } + verifier := &IDJAGVerifier{ + config: config, + jwksCache: config.JWKSCache, + usedJTIs: make(map[string]time.Time), + } + // Start cleanup goroutine for JTI tracking + go verifier.cleanupExpiredJTIs() + return verifier.Verify +} + +// Verify validates an ID-JAG token and returns TokenInfo. +// This implements the TokenVerifier interface. +func (v *IDJAGVerifier) Verify(ctx context.Context, token string, req *http.Request) (*TokenInfo, error) { + // Step 1: Parse the ID-JAG (without signature verification yet) + claims, err := oauthex.ParseIDJAG(token) + if err != nil { + return nil, fmt.Errorf("%w: failed to parse ID-JAG: %v", ErrInvalidToken, err) + } + // Step 2: Check if expired (with clock skew) + expiryTime := time.Unix(claims.ExpiresAt, 0) + if time.Now().After(expiryTime.Add(v.config.AllowedClockSkew)) { + return nil, fmt.Errorf("%w: ID-JAG expired at %v", ErrInvalidToken, expiryTime) + } + // Step 3: Validate aud claim per SEP-990 Section 5.1 + if claims.Audience != v.config.AuthServerIssuerURL { + return nil, fmt.Errorf("%w: invalid audience: expected %q, got %q", + ErrInvalidToken, v.config.AuthServerIssuerURL, claims.Audience) + } + // Step 4: Find trusted IdP + var trustedIdP *TrustedIdPConfig + for _, idp := range v.config.TrustedIdPs { + if idp.IssuerURL == claims.Issuer { + trustedIdP = idp + break + } + } + if trustedIdP == nil { + return nil, fmt.Errorf("%w: untrusted issuer: %q", ErrInvalidToken, claims.Issuer) + } + // Step 5: Verify JWT signature using IdP's JWKS + if err := v.verifySignature(ctx, token, trustedIdP.JWKSURL); err != nil { + return nil, fmt.Errorf("%w: signature verification failed: %v", ErrInvalidToken, err) + } + // Step 6: Replay attack prevention (check JTI) + if err := v.checkJTI(claims.JTI, expiryTime); err != nil { + return nil, fmt.Errorf("%w: %v", ErrInvalidToken, err) + } + // Step 7: Return TokenInfo + scopes := []string{} + if claims.Scope != "" { + scopes = strings.Split(claims.Scope, " ") + } + return &TokenInfo{ + Scopes: scopes, + Expiration: expiryTime, + Extra: map[string]any{ + "sub": claims.Subject, + "client_id": claims.ClientID, + "resource": claims.Resource, + "iss": claims.Issuer, + }, + }, nil +} + +// verifySignature verifies the JWT signature using the IdP's JWKS. +func (v *IDJAGVerifier) verifySignature(ctx context.Context, tokenString, jwksURL string) error { + // Parse JWT to get header + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return fmt.Errorf("invalid JWT format") + } + // Decode header to get kid + headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return fmt.Errorf("failed to decode JWT header: %w", err) + } + var header struct { + Kid string `json:"kid"` + Alg string `json:"alg"` + } + if err := json.Unmarshal(headerJSON, &header); err != nil { + return fmt.Errorf("failed to parse JWT header: %w", err) + } + // Fetch JWKS + jwks, err := v.jwksCache.Get(ctx, jwksURL) + if err != nil { + return fmt.Errorf("failed to fetch JWKS: %w", err) + } + // Find the key + jwk, err := jwks.FindKey(header.Kid) + if err != nil { + return fmt.Errorf("key not found in JWKS: %w", err) + } + // Parse JWT with verification + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + // Verify algorithm + if token.Method.Alg() != header.Alg { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + // Convert JWK to public key + return jwkToPublicKey(jwk) + }) + if err != nil { + return fmt.Errorf("JWT verification failed: %w", err) + } + if !token.Valid { + return fmt.Errorf("JWT is invalid") + } + return nil +} + +// checkJTI checks if the JTI has been used before (replay attack prevention). +func (v *IDJAGVerifier) checkJTI(jti string, expiresAt time.Time) error { + v.usedJTIMu.Lock() + defer v.usedJTIMu.Unlock() + if _, used := v.usedJTIs[jti]; used { + return fmt.Errorf("JTI %q already used (replay attack)", jti) + } + // Mark as used + v.usedJTIs[jti] = expiresAt + return nil +} + +// cleanupExpiredJTIs periodically removes expired JTIs from the tracking map. +func (v *IDJAGVerifier) cleanupExpiredJTIs() { + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + for range ticker.C { + v.usedJTIMu.Lock() + now := time.Now() + for jti, expiresAt := range v.usedJTIs { + if now.After(expiresAt) { + delete(v.usedJTIs, jti) + } + } + v.usedJTIMu.Unlock() + } +} + +// jwkToPublicKey converts a JWK to a public key for signature verification. +func jwkToPublicKey(jwk *JWK) (interface{}, error) { + switch jwk.KeyType { + case "RSA": + // Decode modulus + nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return nil, fmt.Errorf("failed to decode modulus: %w", err) + } + // Decode exponent + eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E) + if err != nil { + return nil, fmt.Errorf("failed to decode exponent: %w", err) + } + // Convert to big.Int + n := new(big.Int).SetBytes(nBytes) + e := new(big.Int).SetBytes(eBytes) + return &rsa.PublicKey{ + N: n, + E: int(e.Int64()), + }, nil + default: + return nil, fmt.Errorf("unsupported key type: %s", jwk.KeyType) + } +} diff --git a/auth/id_jag_verifier_test.go b/auth/id_jag_verifier_test.go new file mode 100644 index 00000000..544aef6b --- /dev/null +++ b/auth/id_jag_verifier_test.go @@ -0,0 +1,191 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// TestIDJAGVerifier tests ID-JAG validation. +func TestIDJAGVerifier(t *testing.T) { + // Generate RSA key pair for testing + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + publicKey := &privateKey.PublicKey + // Create mock JWKS server + jwksServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jwks := &JWKS{ + Keys: []JWK{ + { + KeyType: "RSA", + Use: "sig", + KeyID: "test-key", + Algorithm: "RS256", + N: base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(publicKey.E)).Bytes()), + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(jwks) + })) + defer jwksServer.Close() + // Configure verifier + config := &IDJAGVerifierConfig{ + AuthServerIssuerURL: "https://auth.mcpserver.example", + TrustedIdPs: map[string]*TrustedIdPConfig{ + "test-idp": { + IssuerURL: "https://test.okta.com", + JWKSURL: jwksServer.URL, + }, + }, + HTTPClient: jwksServer.Client(), + } + verifier := NewIDJAGVerifier(config) + // Test valid ID-JAG + t.Run("valid ID-JAG", func(t *testing.T) { + idJAG := createTestIDJAG(t, privateKey, map[string]interface{}{ + "iss": "https://test.okta.com", + "sub": "user123", + "aud": "https://auth.mcpserver.example", + "resource": "https://mcp.mcpserver.example", + "client_id": "client123", + "jti": "jti-" + fmt.Sprint(time.Now().UnixNano()), + "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), + "scope": "read write", + }) + tokenInfo, err := verifier(context.Background(), idJAG, nil) + if err != nil { + t.Fatalf("Verify failed: %v", err) + } + if len(tokenInfo.Scopes) != 2 { + t.Errorf("expected 2 scopes, got %d", len(tokenInfo.Scopes)) + } + if tokenInfo.Extra["sub"] != "user123" { + t.Errorf("expected sub 'user123', got %v", tokenInfo.Extra["sub"]) + } + if tokenInfo.Extra["client_id"] != "client123" { + t.Errorf("expected client_id 'client123', got %v", tokenInfo.Extra["client_id"]) + } + }) + // Test expired ID-JAG + t.Run("expired ID-JAG", func(t *testing.T) { + idJAG := createTestIDJAG(t, privateKey, map[string]interface{}{ + "iss": "https://test.okta.com", + "sub": "user123", + "aud": "https://auth.mcpserver.example", + "resource": "https://mcp.mcpserver.example", + "client_id": "client123", + "jti": "jti-expired", + "exp": time.Now().Add(-1 * time.Hour).Unix(), + "iat": time.Now().Add(-2 * time.Hour).Unix(), + "scope": "read write", + }) + _, err := verifier(context.Background(), idJAG, nil) + if err == nil { + t.Error("expected error for expired ID-JAG, got nil") + } + }) + // Test wrong audience + t.Run("wrong audience", func(t *testing.T) { + idJAG := createTestIDJAG(t, privateKey, map[string]interface{}{ + "iss": "https://test.okta.com", + "sub": "user123", + "aud": "https://wrong.audience.com", + "resource": "https://mcp.mcpserver.example", + "client_id": "client123", + "jti": "jti-wrong-aud", + "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), + "scope": "read write", + }) + _, err := verifier(context.Background(), idJAG, nil) + if err == nil { + t.Error("expected error for wrong audience, got nil") + } + if !strings.Contains(err.Error(), "invalid audience") { + t.Errorf("expected 'invalid audience' error, got: %v", err) + } + }) + // Test untrusted issuer + t.Run("untrusted issuer", func(t *testing.T) { + idJAG := createTestIDJAG(t, privateKey, map[string]interface{}{ + "iss": "https://untrusted.idp.com", + "sub": "user123", + "aud": "https://auth.mcpserver.example", + "resource": "https://mcp.mcpserver.example", + "client_id": "client123", + "jti": "jti-untrusted", + "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), + "scope": "read write", + }) + _, err := verifier(context.Background(), idJAG, nil) + if err == nil { + t.Error("expected error for untrusted issuer, got nil") + } + if !strings.Contains(err.Error(), "untrusted issuer") { + t.Errorf("expected 'untrusted issuer' error, got: %v", err) + } + }) + // Test replay attack + t.Run("replay attack", func(t *testing.T) { + jti := "jti-replay-" + fmt.Sprint(time.Now().UnixNano()) + idJAG := createTestIDJAG(t, privateKey, map[string]interface{}{ + "iss": "https://test.okta.com", + "sub": "user123", + "aud": "https://auth.mcpserver.example", + "resource": "https://mcp.mcpserver.example", + "client_id": "client123", + "jti": jti, + "exp": time.Now().Add(1 * time.Hour).Unix(), + "iat": time.Now().Unix(), + "scope": "read write", + }) + // First use should succeed + _, err := verifier(context.Background(), idJAG, nil) + if err != nil { + t.Fatalf("First verify failed: %v", err) + } + // Second use (replay) should fail + _, err = verifier(context.Background(), idJAG, nil) + if err == nil { + t.Error("expected error for replay attack, got nil") + } + if !strings.Contains(err.Error(), "already used") { + t.Errorf("expected 'already used' error, got: %v", err) + } + }) +} + +// createTestIDJAG creates a test ID-JAG JWT signed with the given private key. +func createTestIDJAG(t *testing.T, privateKey *rsa.PrivateKey, claims map[string]interface{}) string { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(claims)) + token.Header["typ"] = "oauth-id-jag+jwt" + token.Header["kid"] = "test-key" + signedToken, err := token.SignedString(privateKey) + if err != nil { + t.Fatalf("Failed to sign token: %v", err) + } + return signedToken +} From 35c0565fe43c4bdf2bb7e80ad0f4678803810400 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Mon, 24 Nov 2025 23:16:30 +0530 Subject: [PATCH 07/19] feat: add oidc login for sso --- auth/enterprise_auth.go | 47 ++++- auth/oidc_login.go | 454 ++++++++++++++++++++++++++++++++++++++++ auth/oidc_login_test.go | 384 +++++++++++++++++++++++++++++++++ oauthex/oauth2.go | 6 + 4 files changed, 888 insertions(+), 3 deletions(-) create mode 100644 auth/oidc_login.go create mode 100644 auth/oidc_login_test.go diff --git a/auth/enterprise_auth.go b/auth/enterprise_auth.go index 5478e0e4..2d903ca2 100644 --- a/auth/enterprise_auth.go +++ b/auth/enterprise_auth.go @@ -45,7 +45,48 @@ type EnterpriseAuthConfig struct { // This function takes an ID Token that was obtained via SSO (e.g., OIDC login) // and exchanges it for an access token that can be used to call the MCP Server. // -// Example: +// There are two ways to obtain an ID Token for use with this function: +// +// Option 1: Use the OIDC login helper functions (full flow with SSO): +// +// // Step 1: Initiate OIDC login +// oidcConfig := &OIDCLoginConfig{ +// IssuerURL: "https://acme.okta.com", +// ClientID: "client-id", +// RedirectURL: "http://localhost:8080/callback", +// Scopes: []string{"openid", "profile", "email"}, +// } +// authReq, err := InitiateOIDCLogin(ctx, oidcConfig) +// if err != nil { +// log.Fatal(err) +// } +// +// // Step 2: Direct user to authReq.AuthURL for authentication +// fmt.Printf("Visit: %s\n", authReq.AuthURL) +// +// // Step 3: After redirect, complete login with authorization code +// tokens, err := CompleteOIDCLogin(ctx, oidcConfig, authCode, authReq.CodeVerifier) +// if err != nil { +// log.Fatal(err) +// } +// +// // Step 4: Use ID token for enterprise auth +// enterpriseConfig := &EnterpriseAuthConfig{ +// IdPIssuerURL: "https://acme.okta.com", +// IdPClientID: "client-id-at-idp", +// IdPClientSecret: "secret-at-idp", +// MCPAuthServerURL: "https://auth.mcpserver.example", +// MCPResourceURL: "https://mcp.mcpserver.example", +// MCPClientID: "client-id-at-mcp", +// MCPClientSecret: "secret-at-mcp", +// MCPScopes: []string{"read", "write"}, +// } +// accessToken, err := EnterpriseAuthFlow(ctx, enterpriseConfig, tokens.IDToken) +// if err != nil { +// log.Fatal(err) +// } +// +// Option 2: Bring your own ID Token (if you already have one): // // config := &EnterpriseAuthConfig{ // IdPIssuerURL: "https://acme.okta.com", @@ -58,8 +99,8 @@ type EnterpriseAuthConfig struct { // MCPScopes: []string{"read", "write"}, // } // -// // After user logs in via OIDC, you have an ID Token -// accessToken, err := EnterpriseAuthFlow(ctx, config, idToken) +// // If you already obtained an ID token through your own means +// accessToken, err := EnterpriseAuthFlow(ctx, config, myIDToken) // if err != nil { // log.Fatal(err) // } diff --git a/auth/oidc_login.go b/auth/oidc_login.go new file mode 100644 index 00000000..5c1d9106 --- /dev/null +++ b/auth/oidc_login.go @@ -0,0 +1,454 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file implements OIDC Authorization Code flow for obtaining ID tokens +// as part of Enterprise Managed Authorization (SEP-990). +// See https://openid.net/specs/openid-connect-core-1_0.html + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// OIDCLoginConfig configures the OIDC Authorization Code flow for obtaining +// an ID Token. This is an OPTIONAL step before calling EnterpriseAuthFlow. +// Users can alternatively obtain ID tokens through their own methods. +type OIDCLoginConfig struct { + // IssuerURL is the IdP's issuer URL (e.g., "https://acme.okta.com"). + IssuerURL string + // ClientID is the MCP Client's ID registered at the IdP. + ClientID string + // ClientSecret is the MCP Client's secret at the IdP. + // This is OPTIONAL and only used if the client is confidential. + ClientSecret string + // RedirectURL is the OAuth2 redirect URI registered with the IdP. + // This must match exactly what was registered with the IdP. + RedirectURL string + // Scopes are the OAuth2/OIDC scopes to request. + // "openid" is REQUIRED for OIDC. Common values: ["openid", "profile", "email"] + Scopes []string + // LoginHint is an OPTIONAL hint to the IdP about the user's identity. + // Some IdPs may require this (e.g., as an email address for routing to SSO providers). + // Example: "user@example.com" + LoginHint string + // HTTPClient is the HTTP client for making requests. + // If nil, http.DefaultClient is used. + HTTPClient *http.Client +} + +// OIDCAuthorizationRequest represents the result of initiating an OIDC +// authorization code flow. Users must direct the end-user to AuthURL +// to complete authentication. +type OIDCAuthorizationRequest struct { + // AuthURL is the URL the user should visit to authenticate. + // This URL includes the authorization request parameters. + AuthURL string + // State is the OAuth2 state parameter for CSRF protection. + // Users MUST validate that the state returned from the IdP matches this value. + State string + // CodeVerifier is the PKCE code verifier for secure authorization code exchange. + // This must be provided to CompleteOIDCLogin along with the authorization code. + CodeVerifier string +} + +// OIDCTokenResponse contains the tokens returned from a successful OIDC login. +type OIDCTokenResponse struct { + // IDToken is the OpenID Connect ID Token (JWT). + // This can be passed to EnterpriseAuthFlow for token exchange. + IDToken string + // AccessToken is the OAuth2 access token (if issued by IdP). + // This is typically not needed for SEP-990, but may be useful for other IdP APIs. + AccessToken string + // RefreshToken is the OAuth2 refresh token (if issued by IdP). + RefreshToken string + // TokenType is the token type (typically "Bearer"). + TokenType string + // ExpiresAt is when the ID token expires. + ExpiresAt int64 +} + +// InitiateOIDCLogin initiates an OIDC Authorization Code flow with PKCE. +// This is the first step for users who want to use SSO to obtain an ID token. +// +// The returned AuthURL should be presented to the user (e.g., opened in a browser). +// After the user authenticates, the IdP will redirect to the RedirectURL with +// an authorization code and state parameter. +// +// Example: +// +// config := &OIDCLoginConfig{ +// IssuerURL: "https://acme.okta.com", +// ClientID: "client-id", +// RedirectURL: "http://localhost:8080/callback", +// Scopes: []string{"openid", "profile", "email"}, +// } +// +// authReq, err := InitiateOIDCLogin(ctx, config) +// if err != nil { +// log.Fatal(err) +// } +// +// // Direct user to authReq.AuthURL +// fmt.Printf("Visit this URL to login: %s\n", authReq.AuthURL) +// +// // After user completes login, IdP redirects to RedirectURL with code & state +// // Extract code and state from the redirect, then call CompleteOIDCLogin +func InitiateOIDCLogin( + ctx context.Context, + config *OIDCLoginConfig, +) (*OIDCAuthorizationRequest, error) { + if config == nil { + return nil, fmt.Errorf("config is required") + } + // Validate required fields + if config.IssuerURL == "" { + return nil, fmt.Errorf("IssuerURL is required") + } + if config.ClientID == "" { + return nil, fmt.Errorf("ClientID is required") + } + if config.RedirectURL == "" { + return nil, fmt.Errorf("RedirectURL is required") + } + if len(config.Scopes) == 0 { + return nil, fmt.Errorf("Scopes is required (must include 'openid')") + } + // Validate that "openid" scope is present (required for OIDC) + hasOpenID := false + for _, scope := range config.Scopes { + if scope == "openid" { + hasOpenID = true + break + } + } + if !hasOpenID { + return nil, fmt.Errorf("Scopes must include 'openid' for OIDC") + } + // Validate URL schemes to prevent XSS attacks + if err := oauthex.CheckURLScheme(config.IssuerURL); err != nil { + return nil, fmt.Errorf("invalid IssuerURL: %w", err) + } + if err := oauthex.CheckURLScheme(config.RedirectURL); err != nil { + return nil, fmt.Errorf("invalid RedirectURL: %w", err) + } + // Discover OIDC endpoints via .well-known + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + meta, err := oauthex.GetAuthServerMeta(ctx, config.IssuerURL, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) + } + if meta.AuthorizationEndpoint == "" { + return nil, fmt.Errorf("authorization_endpoint not found in OIDC metadata") + } + // Generate PKCE code verifier and challenge (RFC 7636) + codeVerifier, err := generateCodeVerifier() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE verifier: %w", err) + } + codeChallenge := generateCodeChallenge(codeVerifier) + // Generate state for CSRF protection (RFC 6749 Section 10.12) + state, err := generateState() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + // Build authorization URL per OIDC Core Section 3.1.2.1 + authURL, err := buildAuthorizationURL( + meta.AuthorizationEndpoint, + config.ClientID, + config.RedirectURL, + config.Scopes, + state, + codeChallenge, + config.LoginHint, + ) + if err != nil { + return nil, fmt.Errorf("failed to build authorization URL: %w", err) + } + return &OIDCAuthorizationRequest{ + AuthURL: authURL, + State: state, + CodeVerifier: codeVerifier, + }, nil +} + +// CompleteOIDCLogin completes the OIDC Authorization Code flow by exchanging +// the authorization code for tokens. This is the second step after the user +// has authenticated and been redirected back to the application. +// +// The authCode and returnedState parameters should come from the redirect URL +// query parameters. The state MUST match the state from InitiateOIDCLogin +// for CSRF protection. +// +// Example: +// +// // In your redirect handler (e.g., http://localhost:8080/callback) +// authCode := r.URL.Query().Get("code") +// returnedState := r.URL.Query().Get("state") +// +// // Validate state matches what we sent +// if returnedState != authReq.State { +// log.Fatal("State mismatch - possible CSRF attack") +// } +// +// // Exchange code for tokens +// tokens, err := CompleteOIDCLogin(ctx, config, authCode, authReq.CodeVerifier) +// if err != nil { +// log.Fatal(err) +// } +// +// // Now use tokens.IDToken with EnterpriseAuthFlow +// accessToken, err := EnterpriseAuthFlow(ctx, enterpriseConfig, tokens.IDToken) +func CompleteOIDCLogin( + ctx context.Context, + config *OIDCLoginConfig, + authCode string, + codeVerifier string, +) (*OIDCTokenResponse, error) { + if config == nil { + return nil, fmt.Errorf("config is required") + } + if authCode == "" { + return nil, fmt.Errorf("authCode is required") + } + if codeVerifier == "" { + return nil, fmt.Errorf("codeVerifier is required") + } + // Validate required fields + if config.IssuerURL == "" { + return nil, fmt.Errorf("IssuerURL is required") + } + if config.ClientID == "" { + return nil, fmt.Errorf("ClientID is required") + } + if config.RedirectURL == "" { + return nil, fmt.Errorf("RedirectURL is required") + } + // Discover token endpoint + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + meta, err := oauthex.GetAuthServerMeta(ctx, config.IssuerURL, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) + } + if meta.TokenEndpoint == "" { + return nil, fmt.Errorf("token_endpoint not found in OIDC metadata") + } + // Build token request per OIDC Core Section 3.1.3.1 + formData := url.Values{} + formData.Set("grant_type", "authorization_code") + formData.Set("code", authCode) + formData.Set("redirect_uri", config.RedirectURL) + formData.Set("client_id", config.ClientID) + formData.Set("code_verifier", codeVerifier) + // Add client_secret if provided (confidential client) + if config.ClientSecret != "" { + formData.Set("client_secret", config.ClientSecret) + } + // Exchange authorization code for tokens + oauth2Token, err := exchangeAuthorizationCode( + ctx, + meta.TokenEndpoint, + formData, + httpClient, + ) + if err != nil { + return nil, fmt.Errorf("token exchange failed: %w", err) + } + // Extract ID Token from response + idToken, ok := oauth2Token.Extra("id_token").(string) + if !ok || idToken == "" { + return nil, fmt.Errorf("id_token not found in token response") + } + return &OIDCTokenResponse{ + IDToken: idToken, + AccessToken: oauth2Token.AccessToken, + RefreshToken: oauth2Token.RefreshToken, + TokenType: oauth2Token.TokenType, + ExpiresAt: oauth2Token.Expiry.Unix(), + }, nil +} + +// generateCodeVerifier generates a cryptographically random code verifier +// for PKCE per RFC 7636 Section 4.1. +func generateCodeVerifier() (string, error) { + // Per RFC 7636: code verifier is 43-128 characters from [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~" + // We use 32 random bytes (256 bits) base64url-encoded = 43 characters + randomBytes := make([]byte, 32) + if _, err := rand.Read(randomBytes); err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + return base64.RawURLEncoding.EncodeToString(randomBytes), nil +} + +// generateCodeChallenge generates the PKCE code challenge from the verifier +// using SHA256 per RFC 7636 Section 4.2. +func generateCodeChallenge(verifier string) string { + hash := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + +// generateState generates a cryptographically random state parameter +// for CSRF protection per RFC 6749 Section 10.12. +func generateState() (string, error) { + randomBytes := make([]byte, 32) + if _, err := rand.Read(randomBytes); err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + return base64.RawURLEncoding.EncodeToString(randomBytes), nil +} + +// buildAuthorizationURL constructs the OIDC authorization URL. +func buildAuthorizationURL( + authEndpoint string, + clientID string, + redirectURL string, + scopes []string, + state string, + codeChallenge string, + loginHint string, +) (string, error) { + u, err := url.Parse(authEndpoint) + if err != nil { + return "", fmt.Errorf("invalid authorization endpoint: %w", err) + } + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", clientID) + q.Set("redirect_uri", redirectURL) + q.Set("scope", strings.Join(scopes, " ")) + q.Set("state", state) + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + // Add login_hint if provided (optional per OIDC spec, but some IdPs may require it) + if loginHint != "" { + q.Set("login_hint", loginHint) + } + u.RawQuery = q.Encode() + return u.String(), nil +} + +// exchangeAuthorizationCode exchanges the authorization code for tokens. +func exchangeAuthorizationCode( + ctx context.Context, + tokenEndpoint string, + formData url.Values, + httpClient *http.Client, +) (*oauth2.Token, error) { + // Create HTTP request + httpReq, err := http.NewRequestWithContext( + ctx, + http.MethodPost, + tokenEndpoint, + strings.NewReader(formData.Encode()), + ) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + httpReq.Header.Set("Accept", "application/json") + + // Execute request + httpResp, err := httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer httpResp.Body.Close() + + // Read response body (limit to 1MB for safety) + body, err := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) + if err != nil { + return nil, fmt.Errorf("failed to read token response: %w", err) + } + + // Handle success response (200 OK) + if httpResp.StatusCode == http.StatusOK { + // Parse token response manually (following jwt_bearer.go pattern) + var tokenResp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + IDToken string `json:"id_token,omitempty"` + Scope string `json:"scope,omitempty"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w (body: %s)", err, string(body)) + } + + // Validate required fields + if tokenResp.AccessToken == "" { + return nil, fmt.Errorf("response missing required field: access_token") + } + if tokenResp.TokenType == "" { + return nil, fmt.Errorf("response missing required field: token_type") + } + + // Convert to oauth2.Token + token := &oauth2.Token{ + AccessToken: tokenResp.AccessToken, + TokenType: tokenResp.TokenType, + RefreshToken: tokenResp.RefreshToken, + } + + // Set expiration if provided + if tokenResp.ExpiresIn > 0 { + token.Expiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + // Add extra fields (id_token, scope) + extra := make(map[string]interface{}) + if tokenResp.IDToken != "" { + extra["id_token"] = tokenResp.IDToken + } + if tokenResp.Scope != "" { + extra["scope"] = tokenResp.Scope + } + if len(extra) > 0 { + token = token.WithExtra(extra) + } + + return token, nil + } + + // Handle error response (400 Bad Request) + if httpResp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + ErrorURI string `json:"error_uri,omitempty"` + } + if err := json.Unmarshal(body, &errResp); err != nil { + return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) + } + if errResp.ErrorDescription != "" { + return nil, fmt.Errorf("token request failed: %s (%s)", errResp.Error, errResp.ErrorDescription) + } + return nil, fmt.Errorf("token request failed: %s", errResp.Error) + } + + // Handle unexpected status codes + return nil, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body)) +} diff --git a/auth/oidc_login_test.go b/auth/oidc_login_test.go new file mode 100644 index 00000000..ca8c3609 --- /dev/null +++ b/auth/oidc_login_test.go @@ -0,0 +1,384 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +// TestInitiateOIDCLogin tests the OIDC authorization request generation. +func TestInitiateOIDCLogin(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServer(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + ClientID: "test-client", + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + t.Run("successful initiation", func(t *testing.T) { + authReq, err := InitiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("InitiateOIDCLogin failed: %v", err) + } + // Validate AuthURL + if authReq.AuthURL == "" { + t.Error("AuthURL is empty") + } + // Parse and validate URL parameters + u, err := url.Parse(authReq.AuthURL) + if err != nil { + t.Fatalf("Failed to parse AuthURL: %v", err) + } + q := u.Query() + if q.Get("response_type") != "code" { + t.Errorf("expected response_type 'code', got '%s'", q.Get("response_type")) + } + if q.Get("client_id") != "test-client" { + t.Errorf("expected client_id 'test-client', got '%s'", q.Get("client_id")) + } + if q.Get("redirect_uri") != "http://localhost:8080/callback" { + t.Errorf("expected redirect_uri 'http://localhost:8080/callback', got '%s'", q.Get("redirect_uri")) + } + if q.Get("scope") != "openid profile email" { + t.Errorf("expected scope 'openid profile email', got '%s'", q.Get("scope")) + } + if q.Get("code_challenge_method") != "S256" { + t.Errorf("expected code_challenge_method 'S256', got '%s'", q.Get("code_challenge_method")) + } + // Validate state is generated + if authReq.State == "" { + t.Error("State is empty") + } + if q.Get("state") != authReq.State { + t.Errorf("state in URL doesn't match returned state") + } + // Validate PKCE parameters + if authReq.CodeVerifier == "" { + t.Error("CodeVerifier is empty") + } + if q.Get("code_challenge") == "" { + t.Error("code_challenge is empty") + } + }) + t.Run("with login_hint", func(t *testing.T) { + configWithHint := *config + configWithHint.LoginHint = "user@example.com" + authReq, err := InitiateOIDCLogin(context.Background(), &configWithHint) + if err != nil { + t.Fatalf("InitiateOIDCLogin failed: %v", err) + } + u, err := url.Parse(authReq.AuthURL) + if err != nil { + t.Fatalf("Failed to parse AuthURL: %v", err) + } + q := u.Query() + if q.Get("login_hint") != "user@example.com" { + t.Errorf("expected login_hint 'user@example.com', got '%s'", q.Get("login_hint")) + } + }) + t.Run("without login_hint", func(t *testing.T) { + authReq, err := InitiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("InitiateOIDCLogin failed: %v", err) + } + u, err := url.Parse(authReq.AuthURL) + if err != nil { + t.Fatalf("Failed to parse AuthURL: %v", err) + } + q := u.Query() + if q.Has("login_hint") { + t.Errorf("expected no login_hint parameter, but got '%s'", q.Get("login_hint")) + } + }) + t.Run("nil config", func(t *testing.T) { + _, err := InitiateOIDCLogin(context.Background(), nil) + if err == nil { + t.Error("expected error for nil config, got nil") + } + }) + t.Run("missing openid scope", func(t *testing.T) { + badConfig := *config + badConfig.Scopes = []string{"profile", "email"} // Missing "openid" + _, err := InitiateOIDCLogin(context.Background(), &badConfig) + if err == nil { + t.Error("expected error for missing openid scope, got nil") + } + if !strings.Contains(err.Error(), "openid") { + t.Errorf("expected error about missing 'openid', got: %v", err) + } + }) + t.Run("missing required fields", func(t *testing.T) { + tests := []struct { + name string + mutate func(*OIDCLoginConfig) + expectErr string + }{ + { + name: "missing IssuerURL", + mutate: func(c *OIDCLoginConfig) { c.IssuerURL = "" }, + expectErr: "IssuerURL is required", + }, + { + name: "missing ClientID", + mutate: func(c *OIDCLoginConfig) { c.ClientID = "" }, + expectErr: "ClientID is required", + }, + { + name: "missing RedirectURL", + mutate: func(c *OIDCLoginConfig) { c.RedirectURL = "" }, + expectErr: "RedirectURL is required", + }, + { + name: "missing Scopes", + mutate: func(c *OIDCLoginConfig) { c.Scopes = nil }, + expectErr: "Scopes is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + badConfig := *config + tt.mutate(&badConfig) + _, err := InitiateOIDCLogin(context.Background(), &badConfig) + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.expectErr) { + t.Errorf("expected error containing '%s', got: %v", tt.expectErr, err) + } + }) + } + }) +} + +// TestCompleteOIDCLogin tests the authorization code exchange. +func TestCompleteOIDCLogin(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServerWithToken(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + t.Run("successful code exchange", func(t *testing.T) { + tokens, err := CompleteOIDCLogin( + context.Background(), + config, + "test-auth-code", + "test-code-verifier", + ) + if err != nil { + t.Fatalf("CompleteOIDCLogin failed: %v", err) + } + // Validate tokens + if tokens.IDToken == "" { + t.Error("IDToken is empty") + } + if tokens.AccessToken == "" { + t.Error("AccessToken is empty") + } + if tokens.TokenType != "Bearer" { + t.Errorf("expected TokenType 'Bearer', got '%s'", tokens.TokenType) + } + if tokens.ExpiresAt == 0 { + t.Error("ExpiresAt is zero") + } + }) + t.Run("nil config", func(t *testing.T) { + _, err := CompleteOIDCLogin( + context.Background(), + nil, + "test-auth-code", + "test-code-verifier", + ) + if err == nil { + t.Error("expected error for nil config, got nil") + } + }) + t.Run("missing parameters", func(t *testing.T) { + tests := []struct { + name string + authCode string + codeVerifier string + expectErr string + }{ + { + name: "missing authCode", + authCode: "", + codeVerifier: "test-verifier", + expectErr: "authCode is required", + }, + { + name: "missing codeVerifier", + authCode: "test-code", + codeVerifier: "", + expectErr: "codeVerifier is required", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := CompleteOIDCLogin( + context.Background(), + config, + tt.authCode, + tt.codeVerifier, + ) + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), tt.expectErr) { + t.Errorf("expected error containing '%s', got: %v", tt.expectErr, err) + } + }) + } + }) +} + +// TestOIDCLoginE2E tests the complete OIDC login flow end-to-end. +func TestOIDCLoginE2E(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServerWithToken(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + // Step 1: Initiate login + authReq, err := InitiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("InitiateOIDCLogin failed: %v", err) + } + // Step 2: Simulate user authentication and redirect + // (In real flow, user would visit authReq.AuthURL and IdP would redirect back) + // Here we just use a mock authorization code + mockAuthCode := "mock-authorization-code" + // Step 3: Complete login with authorization code + tokens, err := CompleteOIDCLogin( + context.Background(), + config, + mockAuthCode, + authReq.CodeVerifier, + ) + if err != nil { + t.Fatalf("CompleteOIDCLogin failed: %v", err) + } + // Validate we got an ID token + if tokens.IDToken == "" { + t.Error("Expected ID token, got empty string") + } + // Validate ID token is a JWT (has 3 parts) + parts := strings.Split(tokens.IDToken, ".") + if len(parts) != 3 { + t.Errorf("Expected JWT with 3 parts, got %d parts", len(parts)) + } +} + +// createMockOIDCServer creates a mock OIDC server for testing InitiateOIDCLogin. +func createMockOIDCServer(t *testing.T) *httptest.Server { + var serverURL string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle OIDC discovery + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": serverURL, + "authorization_endpoint": serverURL + "/authorize", + "token_endpoint": serverURL + "/token", + "jwks_uri": serverURL + "/.well-known/jwks.json", + "response_types_supported": []string{"code"}, + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": []string{"authorization_code"}, + }) + return + } + http.NotFound(w, r) + })) + serverURL = server.URL + return server +} + +// createMockOIDCServerWithToken creates a mock OIDC server that also handles token exchange. +func createMockOIDCServerWithToken(t *testing.T) *httptest.Server { + var serverURL string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle OIDC discovery + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": serverURL, + "authorization_endpoint": serverURL + "/authorize", + "token_endpoint": serverURL + "/token", + "jwks_uri": serverURL + "/.well-known/jwks.json", + "response_types_supported": []string{"code"}, + "code_challenge_methods_supported": []string{"S256"}, + "grant_types_supported": []string{"authorization_code"}, + }) + return + } + // Handle token endpoint + if r.URL.Path == "/token" { + if err := r.ParseForm(); err != nil { + http.Error(w, "failed to parse form", http.StatusBadRequest) + return + } + // Validate grant type + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + // Create mock ID token (JWT) + now := time.Now().Unix() + idToken := fmt.Sprintf("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.%s.mock-signature", + base64EncodeClaims(map[string]interface{}{ + "iss": serverURL, + "sub": "test-user", + "aud": "test-client", + "exp": now + 3600, + "iat": now, + "email": "test@example.com", + })) + // Return token response + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "mock-access-token", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "mock-refresh-token", + "id_token": idToken, + }) + return + } + http.NotFound(w, r) + })) + serverURL = server.URL + return server +} + +// base64EncodeClaims encodes JWT claims for testing. +func base64EncodeClaims(claims map[string]interface{}) string { + claimsJSON, _ := json.Marshal(claims) + return base64.RawURLEncoding.EncodeToString(claimsJSON) +} diff --git a/oauthex/oauth2.go b/oauthex/oauth2.go index d8aeb3c2..fc3b45ee 100644 --- a/oauthex/oauth2.go +++ b/oauthex/oauth2.go @@ -94,3 +94,9 @@ func checkHTTPSOrLoopback(addr string) error { } return nil } + +// CheckURLScheme validates a URL scheme for security. +// This is exported for use by the auth package. +func CheckURLScheme(u string) error { + return checkURLScheme(u) +} From 8b7ba7a94d28a0e42d1d4e9e170735fb87e56ace Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Wed, 26 Nov 2025 18:50:32 +0530 Subject: [PATCH 08/19] docs: add docs for enterprise authorization --- docs/protocol.md | 63 +++++++++++++++++++++++++++++++++++ internal/docs/protocol.src.md | 63 +++++++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+) diff --git a/docs/protocol.md b/docs/protocol.md index f316e3f8..c013463e 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -305,6 +305,37 @@ For more sophisticated CORS policies, wrap the handler with a CORS middleware li The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/server/auth-middleware) shows how to implement authorization for both JWT tokens and API keys. +#### Enterprise Managed Authorization (SEP-990) + +For enterprise environments with centralized identity providers, the SDK supports ID-JAG (Identity Assertion JWT Authorization Grant) token validation using [`NewIDJAGVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#NewIDJAGVerifier). + +This verifier validates ID-JAG tokens issued by trusted identity providers through: +- Signature verification using the IdP's JWKS +- Audience validation (ensures token was issued for this MCP server) +- Expiration and clock skew handling +- Replay attack prevention via JTI tracking + +Example configuration: + +```go +config := &IDJAGVerifierConfig{ + AuthServerIssuerURL: "https://auth.mcpserver.example", + TrustedIdPs: map[string]*TrustedIdPConfig{ + "acme-okta": { + IssuerURL: "https://acme.okta.com", + JWKSURL: "https://acme.okta.com/.well-known/jwks.json", + }, + }, +} + +verifier := NewIDJAGVerifier(config) +middleware := RequireBearerToken(verifier, &RequireBearerTokenOptions{ + Scopes: []string{"read"}, +}) +``` + +Tool handlers can access ID-JAG claims from `req.Extra.TokenInfo.Extra`, which includes the subject, client ID, resource, and issuer from the validated token. + ### Client > [!IMPORTANT] @@ -357,6 +388,38 @@ session, err := client.Connect(ctx, transport, nil) The `auth.AuthorizationCodeHandler` automatically manages token refreshing and step-up authentication (when the server returns `insufficient_scope` error). +#### Enterprise Authentication Flow (SEP-990) + +For enterprise SSO scenarios, the SDK provides an [`EnterpriseAuthFlow`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#EnterpriseAuthFlow) function that implements the complete token exchange flow: + +1. **Token Exchange** at IdP: ID Token -> ID-JAG +2. **JWT Bearer Grant** at MCP Server: ID-JAG -> Access Token + +This flow is typically used after obtaining an ID Token via OIDC login: + +```go +// Step 1: Obtain ID token via OIDC (see auth.InitiateOIDCLogin and auth.CompleteOIDCLogin) +idToken := "..." // from OIDC login + +// Step 2: Exchange for MCP access token +enterpriseConfig := &EnterpriseAuthConfig{ + IdPIssuerURL: "https://acme.okta.com", + IdPClientID: "client-id-at-idp", + IdPClientSecret: "secret-at-idp", + MCPAuthServerURL: "https://auth.mcpserver.example", + MCPResourceURL: "https://mcp.mcpserver.example", + MCPClientID: "client-id-at-mcp", + MCPClientSecret: "secret-at-mcp", + MCPScopes: []string{"read", "write"}, +} +accessToken, err := EnterpriseAuthFlow(ctx, enterpriseConfig, tokens.IDToken) +// Use accessToken with MCP client +``` + +Helper functions are provided for OIDC login: +- [`InitiateOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#InitiateOIDCLogin) - Generate authorization URL with PKCE +- [`CompleteOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#CompleteOIDCLogin) - Exchange authorization code for tokens + ## Security Here we discuss the mitigations described under diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 98078619..2efe2110 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -230,6 +230,37 @@ For more sophisticated CORS policies, wrap the handler with a CORS middleware li The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/server/auth-middleware) shows how to implement authorization for both JWT tokens and API keys. +#### Enterprise Managed Authorization (SEP-990) + +For enterprise environments with centralized identity providers, the SDK supports ID-JAG (Identity Assertion JWT Authorization Grant) token validation using [`NewIDJAGVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#NewIDJAGVerifier). + +This verifier validates ID-JAG tokens issued by trusted identity providers through: +- Signature verification using the IdP's JWKS +- Audience validation (ensures token was issued for this MCP server) +- Expiration and clock skew handling +- Replay attack prevention via JTI tracking + +Example configuration: + +```go +config := &IDJAGVerifierConfig{ + AuthServerIssuerURL: "https://auth.mcpserver.example", + TrustedIdPs: map[string]*TrustedIdPConfig{ + "acme-okta": { + IssuerURL: "https://acme.okta.com", + JWKSURL: "https://acme.okta.com/.well-known/jwks.json", + }, + }, +} + +verifier := NewIDJAGVerifier(config) +middleware := RequireBearerToken(verifier, &RequireBearerTokenOptions{ + Scopes: []string{"read"}, +}) +``` + +Tool handlers can access ID-JAG claims from `req.Extra.TokenInfo.Extra`, which includes the subject, client ID, resource, and issuer from the validated token. + ### Client > [!IMPORTANT] @@ -282,6 +313,38 @@ session, err := client.Connect(ctx, transport, nil) The `auth.AuthorizationCodeHandler` automatically manages token refreshing and step-up authentication (when the server returns `insufficient_scope` error). +#### Enterprise Authentication Flow (SEP-990) + +For enterprise SSO scenarios, the SDK provides an [`EnterpriseAuthFlow`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#EnterpriseAuthFlow) function that implements the complete token exchange flow: + +1. **Token Exchange** at IdP: ID Token -> ID-JAG +2. **JWT Bearer Grant** at MCP Server: ID-JAG -> Access Token + +This flow is typically used after obtaining an ID Token via OIDC login: + +```go +// Step 1: Obtain ID token via OIDC (see auth.InitiateOIDCLogin and auth.CompleteOIDCLogin) +idToken := "..." // from OIDC login + +// Step 2: Exchange for MCP access token +enterpriseConfig := &EnterpriseAuthConfig{ + IdPIssuerURL: "https://acme.okta.com", + IdPClientID: "client-id-at-idp", + IdPClientSecret: "secret-at-idp", + MCPAuthServerURL: "https://auth.mcpserver.example", + MCPResourceURL: "https://mcp.mcpserver.example", + MCPClientID: "client-id-at-mcp", + MCPClientSecret: "secret-at-mcp", + MCPScopes: []string{"read", "write"}, +} +accessToken, err := EnterpriseAuthFlow(ctx, enterpriseConfig, tokens.IDToken) +// Use accessToken with MCP client +``` + +Helper functions are provided for OIDC login: +- [`InitiateOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#InitiateOIDCLogin) - Generate authorization URL with PKCE +- [`CompleteOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#CompleteOIDCLogin) - Exchange authorization code for tokens + ## Security Here we discuss the mitigations described under From b7433c3d40a9082d90383fa20dcf39e64c7dbaa8 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Thu, 27 Nov 2025 17:38:13 +0530 Subject: [PATCH 09/19] chore: remove id jag verification --- auth/id_jag_verifier.go | 250 ---------------------------------- auth/id_jag_verifier_test.go | 191 -------------------------- auth/jwks_cache.go | 150 -------------------- auth/jwks_cache_test.go | 199 --------------------------- docs/protocol.md | 61 +++------ internal/docs/protocol.src.md | 61 +++------ 6 files changed, 34 insertions(+), 878 deletions(-) delete mode 100644 auth/id_jag_verifier.go delete mode 100644 auth/id_jag_verifier_test.go delete mode 100644 auth/jwks_cache.go delete mode 100644 auth/jwks_cache_test.go diff --git a/auth/id_jag_verifier.go b/auth/id_jag_verifier.go deleted file mode 100644 index 45cb710d..00000000 --- a/auth/id_jag_verifier.go +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file implements ID-JAG (Identity Assertion JWT Authorization Grant) validation -// for MCP Servers in Enterprise Managed Authorization (SEP-990). - -//go:build mcp_go_client_oauth - -package auth - -import ( - "context" - "crypto/rsa" - "encoding/base64" - "encoding/json" - "fmt" - "math/big" - "net/http" - "strings" - "sync" - "time" - - "github.com/golang-jwt/jwt/v5" - "github.com/modelcontextprotocol/go-sdk/oauthex" -) - -// TrustedIdPConfig contains configuration for a trusted Identity Provider. -type TrustedIdPConfig struct { - // IssuerURL is the IdP's issuer URL (must match the iss claim). - IssuerURL string - // JWKSURL is the URL to fetch the IdP's JSON Web Key Set. - JWKSURL string -} - -// IDJAGVerifierConfig configures ID-JAG validation for an MCP Server. -type IDJAGVerifierConfig struct { - // AuthServerIssuerURL is this MCP Server's authorization server issuer URL. - // This must match the aud claim in the ID-JAG. - AuthServerIssuerURL string - // TrustedIdPs is a map of trusted Identity Providers. - // The key is a friendly name, the value is the IdP configuration. - TrustedIdPs map[string]*TrustedIdPConfig - // JWKSCache is the cache for JWKS responses. If nil, a new cache is created. - JWKSCache *JWKSCache - // HTTPClient is the HTTP client for fetching JWKS. If nil, http.DefaultClient is used. - HTTPClient *http.Client - // AllowedClockSkew is the allowed clock skew for exp/iat validation. - // Default is 5 minutes. - AllowedClockSkew time.Duration -} - -// IDJAGVerifier validates ID-JAG tokens for MCP Servers. -type IDJAGVerifier struct { - config *IDJAGVerifierConfig - jwksCache *JWKSCache - usedJTIs map[string]time.Time // Replay attack prevention - usedJTIMu sync.RWMutex -} - -// NewIDJAGVerifier creates a new ID-JAG verifier with the given configuration. -// This returns a TokenVerifier that can be used with RequireBearerToken middleware. -// -// Example: -// -// config := &IDJAGVerifierConfig{ -// AuthServerIssuerURL: "https://auth.mcpserver.example", -// TrustedIdPs: map[string]*TrustedIdPConfig{ -// "acme-okta": { -// IssuerURL: "https://acme.okta.com", -// JWKSURL: "https://acme.okta.com/.well-known/jwks.json", -// }, -// }, -// } -// -// verifier := NewIDJAGVerifier(config) -// middleware := RequireBearerToken(verifier, &RequireBearerTokenOptions{ -// Scopes: []string{"read"}, -// }) -func NewIDJAGVerifier(config *IDJAGVerifierConfig) TokenVerifier { - if config.JWKSCache == nil { - config.JWKSCache = NewJWKSCache(config.HTTPClient) - } - if config.AllowedClockSkew == 0 { - config.AllowedClockSkew = 5 * time.Minute - } - verifier := &IDJAGVerifier{ - config: config, - jwksCache: config.JWKSCache, - usedJTIs: make(map[string]time.Time), - } - // Start cleanup goroutine for JTI tracking - go verifier.cleanupExpiredJTIs() - return verifier.Verify -} - -// Verify validates an ID-JAG token and returns TokenInfo. -// This implements the TokenVerifier interface. -func (v *IDJAGVerifier) Verify(ctx context.Context, token string, req *http.Request) (*TokenInfo, error) { - // Step 1: Parse the ID-JAG (without signature verification yet) - claims, err := oauthex.ParseIDJAG(token) - if err != nil { - return nil, fmt.Errorf("%w: failed to parse ID-JAG: %v", ErrInvalidToken, err) - } - // Step 2: Check if expired (with clock skew) - expiryTime := time.Unix(claims.ExpiresAt, 0) - if time.Now().After(expiryTime.Add(v.config.AllowedClockSkew)) { - return nil, fmt.Errorf("%w: ID-JAG expired at %v", ErrInvalidToken, expiryTime) - } - // Step 3: Validate aud claim per SEP-990 Section 5.1 - if claims.Audience != v.config.AuthServerIssuerURL { - return nil, fmt.Errorf("%w: invalid audience: expected %q, got %q", - ErrInvalidToken, v.config.AuthServerIssuerURL, claims.Audience) - } - // Step 4: Find trusted IdP - var trustedIdP *TrustedIdPConfig - for _, idp := range v.config.TrustedIdPs { - if idp.IssuerURL == claims.Issuer { - trustedIdP = idp - break - } - } - if trustedIdP == nil { - return nil, fmt.Errorf("%w: untrusted issuer: %q", ErrInvalidToken, claims.Issuer) - } - // Step 5: Verify JWT signature using IdP's JWKS - if err := v.verifySignature(ctx, token, trustedIdP.JWKSURL); err != nil { - return nil, fmt.Errorf("%w: signature verification failed: %v", ErrInvalidToken, err) - } - // Step 6: Replay attack prevention (check JTI) - if err := v.checkJTI(claims.JTI, expiryTime); err != nil { - return nil, fmt.Errorf("%w: %v", ErrInvalidToken, err) - } - // Step 7: Return TokenInfo - scopes := []string{} - if claims.Scope != "" { - scopes = strings.Split(claims.Scope, " ") - } - return &TokenInfo{ - Scopes: scopes, - Expiration: expiryTime, - Extra: map[string]any{ - "sub": claims.Subject, - "client_id": claims.ClientID, - "resource": claims.Resource, - "iss": claims.Issuer, - }, - }, nil -} - -// verifySignature verifies the JWT signature using the IdP's JWKS. -func (v *IDJAGVerifier) verifySignature(ctx context.Context, tokenString, jwksURL string) error { - // Parse JWT to get header - parts := strings.Split(tokenString, ".") - if len(parts) != 3 { - return fmt.Errorf("invalid JWT format") - } - // Decode header to get kid - headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) - if err != nil { - return fmt.Errorf("failed to decode JWT header: %w", err) - } - var header struct { - Kid string `json:"kid"` - Alg string `json:"alg"` - } - if err := json.Unmarshal(headerJSON, &header); err != nil { - return fmt.Errorf("failed to parse JWT header: %w", err) - } - // Fetch JWKS - jwks, err := v.jwksCache.Get(ctx, jwksURL) - if err != nil { - return fmt.Errorf("failed to fetch JWKS: %w", err) - } - // Find the key - jwk, err := jwks.FindKey(header.Kid) - if err != nil { - return fmt.Errorf("key not found in JWKS: %w", err) - } - // Parse JWT with verification - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - // Verify algorithm - if token.Method.Alg() != header.Alg { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - // Convert JWK to public key - return jwkToPublicKey(jwk) - }) - if err != nil { - return fmt.Errorf("JWT verification failed: %w", err) - } - if !token.Valid { - return fmt.Errorf("JWT is invalid") - } - return nil -} - -// checkJTI checks if the JTI has been used before (replay attack prevention). -func (v *IDJAGVerifier) checkJTI(jti string, expiresAt time.Time) error { - v.usedJTIMu.Lock() - defer v.usedJTIMu.Unlock() - if _, used := v.usedJTIs[jti]; used { - return fmt.Errorf("JTI %q already used (replay attack)", jti) - } - // Mark as used - v.usedJTIs[jti] = expiresAt - return nil -} - -// cleanupExpiredJTIs periodically removes expired JTIs from the tracking map. -func (v *IDJAGVerifier) cleanupExpiredJTIs() { - ticker := time.NewTicker(10 * time.Minute) - defer ticker.Stop() - for range ticker.C { - v.usedJTIMu.Lock() - now := time.Now() - for jti, expiresAt := range v.usedJTIs { - if now.After(expiresAt) { - delete(v.usedJTIs, jti) - } - } - v.usedJTIMu.Unlock() - } -} - -// jwkToPublicKey converts a JWK to a public key for signature verification. -func jwkToPublicKey(jwk *JWK) (interface{}, error) { - switch jwk.KeyType { - case "RSA": - // Decode modulus - nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N) - if err != nil { - return nil, fmt.Errorf("failed to decode modulus: %w", err) - } - // Decode exponent - eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E) - if err != nil { - return nil, fmt.Errorf("failed to decode exponent: %w", err) - } - // Convert to big.Int - n := new(big.Int).SetBytes(nBytes) - e := new(big.Int).SetBytes(eBytes) - return &rsa.PublicKey{ - N: n, - E: int(e.Int64()), - }, nil - default: - return nil, fmt.Errorf("unsupported key type: %s", jwk.KeyType) - } -} diff --git a/auth/id_jag_verifier_test.go b/auth/id_jag_verifier_test.go deleted file mode 100644 index 544aef6b..00000000 --- a/auth/id_jag_verifier_test.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -//go:build mcp_go_client_oauth - -package auth - -import ( - "context" - "crypto/rand" - "crypto/rsa" - "encoding/base64" - "encoding/json" - "fmt" - "math/big" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/golang-jwt/jwt/v5" -) - -// TestIDJAGVerifier tests ID-JAG validation. -func TestIDJAGVerifier(t *testing.T) { - // Generate RSA key pair for testing - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - t.Fatalf("Failed to generate key: %v", err) - } - publicKey := &privateKey.PublicKey - // Create mock JWKS server - jwksServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jwks := &JWKS{ - Keys: []JWK{ - { - KeyType: "RSA", - Use: "sig", - KeyID: "test-key", - Algorithm: "RS256", - N: base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes()), - E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(publicKey.E)).Bytes()), - }, - }, - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(jwks) - })) - defer jwksServer.Close() - // Configure verifier - config := &IDJAGVerifierConfig{ - AuthServerIssuerURL: "https://auth.mcpserver.example", - TrustedIdPs: map[string]*TrustedIdPConfig{ - "test-idp": { - IssuerURL: "https://test.okta.com", - JWKSURL: jwksServer.URL, - }, - }, - HTTPClient: jwksServer.Client(), - } - verifier := NewIDJAGVerifier(config) - // Test valid ID-JAG - t.Run("valid ID-JAG", func(t *testing.T) { - idJAG := createTestIDJAG(t, privateKey, map[string]interface{}{ - "iss": "https://test.okta.com", - "sub": "user123", - "aud": "https://auth.mcpserver.example", - "resource": "https://mcp.mcpserver.example", - "client_id": "client123", - "jti": "jti-" + fmt.Sprint(time.Now().UnixNano()), - "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), - "scope": "read write", - }) - tokenInfo, err := verifier(context.Background(), idJAG, nil) - if err != nil { - t.Fatalf("Verify failed: %v", err) - } - if len(tokenInfo.Scopes) != 2 { - t.Errorf("expected 2 scopes, got %d", len(tokenInfo.Scopes)) - } - if tokenInfo.Extra["sub"] != "user123" { - t.Errorf("expected sub 'user123', got %v", tokenInfo.Extra["sub"]) - } - if tokenInfo.Extra["client_id"] != "client123" { - t.Errorf("expected client_id 'client123', got %v", tokenInfo.Extra["client_id"]) - } - }) - // Test expired ID-JAG - t.Run("expired ID-JAG", func(t *testing.T) { - idJAG := createTestIDJAG(t, privateKey, map[string]interface{}{ - "iss": "https://test.okta.com", - "sub": "user123", - "aud": "https://auth.mcpserver.example", - "resource": "https://mcp.mcpserver.example", - "client_id": "client123", - "jti": "jti-expired", - "exp": time.Now().Add(-1 * time.Hour).Unix(), - "iat": time.Now().Add(-2 * time.Hour).Unix(), - "scope": "read write", - }) - _, err := verifier(context.Background(), idJAG, nil) - if err == nil { - t.Error("expected error for expired ID-JAG, got nil") - } - }) - // Test wrong audience - t.Run("wrong audience", func(t *testing.T) { - idJAG := createTestIDJAG(t, privateKey, map[string]interface{}{ - "iss": "https://test.okta.com", - "sub": "user123", - "aud": "https://wrong.audience.com", - "resource": "https://mcp.mcpserver.example", - "client_id": "client123", - "jti": "jti-wrong-aud", - "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), - "scope": "read write", - }) - _, err := verifier(context.Background(), idJAG, nil) - if err == nil { - t.Error("expected error for wrong audience, got nil") - } - if !strings.Contains(err.Error(), "invalid audience") { - t.Errorf("expected 'invalid audience' error, got: %v", err) - } - }) - // Test untrusted issuer - t.Run("untrusted issuer", func(t *testing.T) { - idJAG := createTestIDJAG(t, privateKey, map[string]interface{}{ - "iss": "https://untrusted.idp.com", - "sub": "user123", - "aud": "https://auth.mcpserver.example", - "resource": "https://mcp.mcpserver.example", - "client_id": "client123", - "jti": "jti-untrusted", - "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), - "scope": "read write", - }) - _, err := verifier(context.Background(), idJAG, nil) - if err == nil { - t.Error("expected error for untrusted issuer, got nil") - } - if !strings.Contains(err.Error(), "untrusted issuer") { - t.Errorf("expected 'untrusted issuer' error, got: %v", err) - } - }) - // Test replay attack - t.Run("replay attack", func(t *testing.T) { - jti := "jti-replay-" + fmt.Sprint(time.Now().UnixNano()) - idJAG := createTestIDJAG(t, privateKey, map[string]interface{}{ - "iss": "https://test.okta.com", - "sub": "user123", - "aud": "https://auth.mcpserver.example", - "resource": "https://mcp.mcpserver.example", - "client_id": "client123", - "jti": jti, - "exp": time.Now().Add(1 * time.Hour).Unix(), - "iat": time.Now().Unix(), - "scope": "read write", - }) - // First use should succeed - _, err := verifier(context.Background(), idJAG, nil) - if err != nil { - t.Fatalf("First verify failed: %v", err) - } - // Second use (replay) should fail - _, err = verifier(context.Background(), idJAG, nil) - if err == nil { - t.Error("expected error for replay attack, got nil") - } - if !strings.Contains(err.Error(), "already used") { - t.Errorf("expected 'already used' error, got: %v", err) - } - }) -} - -// createTestIDJAG creates a test ID-JAG JWT signed with the given private key. -func createTestIDJAG(t *testing.T, privateKey *rsa.PrivateKey, claims map[string]interface{}) string { - token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims(claims)) - token.Header["typ"] = "oauth-id-jag+jwt" - token.Header["kid"] = "test-key" - signedToken, err := token.SignedString(privateKey) - if err != nil { - t.Fatalf("Failed to sign token: %v", err) - } - return signedToken -} diff --git a/auth/jwks_cache.go b/auth/jwks_cache.go deleted file mode 100644 index 70efe89f..00000000 --- a/auth/jwks_cache.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file implements JWKS (JSON Web Key Set) fetching and caching for -// JWT signature verification in Enterprise Managed Authorization (SEP-990). - -//go:build mcp_go_client_oauth - -package auth - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "sync" - "time" -) - -// JWK represents a JSON Web Key per RFC 7517. -type JWK struct { - // KeyType is the key type (e.g., "RSA", "EC"). - KeyType string `json:"kty"` - // Use indicates the intended use of the key (e.g., "sig" for signature). - Use string `json:"use,omitempty"` - // KeyID is the key identifier. - KeyID string `json:"kid"` - // Algorithm is the algorithm intended for use with the key. - Algorithm string `json:"alg,omitempty"` - // N is the RSA modulus (base64url encoded). - N string `json:"n,omitempty"` - // E is the RSA public exponent (base64url encoded). - E string `json:"e,omitempty"` - // X is the X coordinate for elliptic curve keys (base64url encoded). - X string `json:"x,omitempty"` - // Y is the Y coordinate for elliptic curve keys (base64url encoded). - Y string `json:"y,omitempty"` - // Curve is the elliptic curve name (e.g., "P-256"). - Curve string `json:"crv,omitempty"` -} - -// JWKS represents a JSON Web Key Set per RFC 7517. -type JWKS struct { - Keys []JWK `json:"keys"` -} - -// FindKey finds a key by its key ID (kid). -func (j *JWKS) FindKey(kid string) (*JWK, error) { - for i := range j.Keys { - if j.Keys[i].KeyID == kid { - return &j.Keys[i], nil - } - } - return nil, fmt.Errorf("key with kid %q not found", kid) -} - -// JWKSCache caches JWKS responses to reduce network requests. -type JWKSCache struct { - mu sync.RWMutex - entries map[string]*jwksCacheEntry - client *http.Client -} -type jwksCacheEntry struct { - jwks *JWKS - expiresAt time.Time -} - -// NewJWKSCache creates a new JWKS cache with the given HTTP client. -// If client is nil, http.DefaultClient is used. -func NewJWKSCache(client *http.Client) *JWKSCache { - if client == nil { - client = http.DefaultClient - } - return &JWKSCache{ - entries: make(map[string]*jwksCacheEntry), - client: client, - } -} - -// Get fetches JWKS from the given URL, using cache if available and not expired. -// The cache duration is 1 hour per best practices for JWKS caching. -func (c *JWKSCache) Get(ctx context.Context, jwksURL string) (*JWKS, error) { - // Check cache first - c.mu.RLock() - entry, ok := c.entries[jwksURL] - c.mu.RUnlock() - if ok && time.Now().Before(entry.expiresAt) { - return entry.jwks, nil - } - // Fetch from network - jwks, err := c.fetch(ctx, jwksURL) - if err != nil { - return nil, err - } - // Update cache - c.mu.Lock() - c.entries[jwksURL] = &jwksCacheEntry{ - jwks: jwks, - expiresAt: time.Now().Add(1 * time.Hour), - } - c.mu.Unlock() - return jwks, nil -} - -// fetch retrieves JWKS from the given URL. -func (c *JWKSCache) fetch(ctx context.Context, jwksURL string) (*JWKS, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create JWKS request: %w", err) - } - req.Header.Set("Accept", "application/json") - resp, err := c.client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to fetch JWKS: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("JWKS endpoint returned status %d", resp.StatusCode) - } - // Read response body (limit to 1MB for safety) - body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if err != nil { - return nil, fmt.Errorf("failed to read JWKS response: %w", err) - } - // Parse JWKS - var jwks JWKS - if err := json.Unmarshal(body, &jwks); err != nil { - return nil, fmt.Errorf("failed to parse JWKS: %w", err) - } - if len(jwks.Keys) == 0 { - return nil, fmt.Errorf("JWKS contains no keys") - } - return &jwks, nil -} - -// Invalidate removes a JWKS entry from the cache, forcing a fresh fetch on next Get. -func (c *JWKSCache) Invalidate(jwksURL string) { - c.mu.Lock() - delete(c.entries, jwksURL) - c.mu.Unlock() -} - -// Clear removes all entries from the cache. -func (c *JWKSCache) Clear() { - c.mu.Lock() - c.entries = make(map[string]*jwksCacheEntry) - c.mu.Unlock() -} diff --git a/auth/jwks_cache_test.go b/auth/jwks_cache_test.go deleted file mode 100644 index 4ec87ca7..00000000 --- a/auth/jwks_cache_test.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -//go:build mcp_go_client_oauth - -package auth - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" -) - -// TestJWKSCache tests JWKS fetching and caching. -func TestJWKSCache(t *testing.T) { - // Create test JWKS - testJWKS := &JWKS{ - Keys: []JWK{ - { - KeyType: "RSA", - Use: "sig", - KeyID: "test-key-1", - Algorithm: "RS256", - N: "test-modulus", - E: "AQAB", - }, - { - KeyType: "RSA", - Use: "sig", - KeyID: "test-key-2", - Algorithm: "RS256", - N: "test-modulus-2", - E: "AQAB", - }, - }, - } - // Create test server - var requestCount int - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount++ - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(testJWKS) - })) - defer server.Close() - cache := NewJWKSCache(server.Client()) - // Test first fetch - t.Run("first fetch", func(t *testing.T) { - jwks, err := cache.Get(context.Background(), server.URL) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if len(jwks.Keys) != 2 { - t.Errorf("expected 2 keys, got %d", len(jwks.Keys)) - } - if jwks.Keys[0].KeyID != "test-key-1" { - t.Errorf("expected key ID 'test-key-1', got '%s'", jwks.Keys[0].KeyID) - } - if requestCount != 1 { - t.Errorf("expected 1 request, got %d", requestCount) - } - }) - // Test cache hit - t.Run("cache hit", func(t *testing.T) { - jwks, err := cache.Get(context.Background(), server.URL) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - if len(jwks.Keys) != 2 { - t.Errorf("expected 2 keys from cache, got %d", len(jwks.Keys)) - } - // Should still be 1 request (served from cache) - if requestCount != 1 { - t.Errorf("expected 1 request (cached), got %d", requestCount) - } - }) - // Test FindKey - t.Run("find key", func(t *testing.T) { - jwks, _ := cache.Get(context.Background(), server.URL) - key, err := jwks.FindKey("test-key-2") - if err != nil { - t.Fatalf("FindKey failed: %v", err) - } - if key.KeyID != "test-key-2" { - t.Errorf("expected key ID 'test-key-2', got '%s'", key.KeyID) - } - if key.N != "test-modulus-2" { - t.Errorf("expected modulus 'test-modulus-2', got '%s'", key.N) - } - }) - // Test key not found - t.Run("key not found", func(t *testing.T) { - jwks, _ := cache.Get(context.Background(), server.URL) - _, err := jwks.FindKey("nonexistent") - if err == nil { - t.Error("expected error for nonexistent key, got nil") - } - }) - // Test Invalidate - t.Run("invalidate", func(t *testing.T) { - cache.Invalidate(server.URL) - // Next fetch should hit the server again - _, err := cache.Get(context.Background(), server.URL) - if err != nil { - t.Fatalf("Get after invalidate failed: %v", err) - } - if requestCount != 2 { - t.Errorf("expected 2 requests after invalidate, got %d", requestCount) - } - }) - // Test Clear - t.Run("clear", func(t *testing.T) { - cache.Clear() - // Next fetch should hit the server again - _, err := cache.Get(context.Background(), server.URL) - if err != nil { - t.Fatalf("Get after clear failed: %v", err) - } - if requestCount != 3 { - t.Errorf("expected 3 requests after clear, got %d", requestCount) - } - }) - // Test error handling - t.Run("server error", func(t *testing.T) { - errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "internal error", http.StatusInternalServerError) - })) - defer errorServer.Close() - _, err := cache.Get(context.Background(), errorServer.URL) - if err == nil { - t.Error("expected error for server error, got nil") - } - }) - // Test invalid JSON - t.Run("invalid json", func(t *testing.T) { - invalidServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Write([]byte("invalid json")) - })) - defer invalidServer.Close() - _, err := cache.Get(context.Background(), invalidServer.URL) - if err == nil { - t.Error("expected error for invalid JSON, got nil") - } - }) - // Test empty keys - t.Run("empty keys", func(t *testing.T) { - emptyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(&JWKS{Keys: []JWK{}}) - })) - defer emptyServer.Close() - _, err := cache.Get(context.Background(), emptyServer.URL) - if err == nil { - t.Error("expected error for empty keys, got nil") - } - }) -} - -// TestJWKSCacheExpiration tests cache expiration. -func TestJWKSCacheExpiration(t *testing.T) { - testJWKS := &JWKS{ - Keys: []JWK{{KeyID: "test", KeyType: "RSA", N: "test", E: "AQAB"}}, - } - var requestCount int - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount++ - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(testJWKS) - })) - defer server.Close() - cache := NewJWKSCache(server.Client()) - // First fetch - _, err := cache.Get(context.Background(), server.URL) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - // Manually expire the cache entry - cache.mu.Lock() - if entry, ok := cache.entries[server.URL]; ok { - entry.expiresAt = time.Now().Add(-1 * time.Hour) - } - cache.mu.Unlock() - // Next fetch should hit server again - _, err = cache.Get(context.Background(), server.URL) - if err != nil { - t.Fatalf("Get after expiration failed: %v", err) - } - if requestCount != 2 { - t.Errorf("expected 2 requests after expiration, got %d", requestCount) - } -} diff --git a/docs/protocol.md b/docs/protocol.md index c013463e..a857bf21 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -305,37 +305,6 @@ For more sophisticated CORS policies, wrap the handler with a CORS middleware li The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/server/auth-middleware) shows how to implement authorization for both JWT tokens and API keys. -#### Enterprise Managed Authorization (SEP-990) - -For enterprise environments with centralized identity providers, the SDK supports ID-JAG (Identity Assertion JWT Authorization Grant) token validation using [`NewIDJAGVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#NewIDJAGVerifier). - -This verifier validates ID-JAG tokens issued by trusted identity providers through: -- Signature verification using the IdP's JWKS -- Audience validation (ensures token was issued for this MCP server) -- Expiration and clock skew handling -- Replay attack prevention via JTI tracking - -Example configuration: - -```go -config := &IDJAGVerifierConfig{ - AuthServerIssuerURL: "https://auth.mcpserver.example", - TrustedIdPs: map[string]*TrustedIdPConfig{ - "acme-okta": { - IssuerURL: "https://acme.okta.com", - JWKSURL: "https://acme.okta.com/.well-known/jwks.json", - }, - }, -} - -verifier := NewIDJAGVerifier(config) -middleware := RequireBearerToken(verifier, &RequireBearerTokenOptions{ - Scopes: []string{"read"}, -}) -``` - -Tool handlers can access ID-JAG claims from `req.Extra.TokenInfo.Extra`, which includes the subject, client ID, resource, and issuer from the validated token. - ### Client > [!IMPORTANT] @@ -390,10 +359,12 @@ and step-up authentication (when the server returns `insufficient_scope` error). #### Enterprise Authentication Flow (SEP-990) -For enterprise SSO scenarios, the SDK provides an [`EnterpriseAuthFlow`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#EnterpriseAuthFlow) function that implements the complete token exchange flow: +For enterprise SSO scenarios, the SDK provides an +[`EnterpriseAuthFlow`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#EnterpriseAuthFlow) +function that implements the complete token exchange flow: -1. **Token Exchange** at IdP: ID Token -> ID-JAG -2. **JWT Bearer Grant** at MCP Server: ID-JAG -> Access Token +1. **Token Exchange** at IdP: ID Token → ID-JAG +2. **JWT Bearer Grant** at MCP Server: ID-JAG → Access Token This flow is typically used after obtaining an ID Token via OIDC login: @@ -402,17 +373,18 @@ This flow is typically used after obtaining an ID Token via OIDC login: idToken := "..." // from OIDC login // Step 2: Exchange for MCP access token -enterpriseConfig := &EnterpriseAuthConfig{ - IdPIssuerURL: "https://acme.okta.com", - IdPClientID: "client-id-at-idp", - IdPClientSecret: "secret-at-idp", - MCPAuthServerURL: "https://auth.mcpserver.example", - MCPResourceURL: "https://mcp.mcpserver.example", - MCPClientID: "client-id-at-mcp", - MCPClientSecret: "secret-at-mcp", - MCPScopes: []string{"read", "write"}, +config := &auth.EnterpriseAuthConfig{ + IdPIssuerURL: "https://company.okta.com", + IdPClientID: "client-id-at-idp", + IdPClientSecret: "secret-at-idp", + MCPAuthServerURL: "https://auth.mcpserver.example", + MCPResourceURL: "https://mcp.mcpserver.example", + MCPClientID: "client-id-at-mcp", + MCPClientSecret: "secret-at-mcp", + MCPScopes: []string{"read", "write"}, } -accessToken, err := EnterpriseAuthFlow(ctx, enterpriseConfig, tokens.IDToken) + +accessToken, err := auth.EnterpriseAuthFlow(ctx, config, idToken) // Use accessToken with MCP client ``` @@ -638,3 +610,4 @@ func Example_progress() { // frobbing widgets 2/2 } ``` + diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 2efe2110..8475a324 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -230,37 +230,6 @@ For more sophisticated CORS policies, wrap the handler with a CORS middleware li The [_auth middleware example_](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/server/auth-middleware) shows how to implement authorization for both JWT tokens and API keys. -#### Enterprise Managed Authorization (SEP-990) - -For enterprise environments with centralized identity providers, the SDK supports ID-JAG (Identity Assertion JWT Authorization Grant) token validation using [`NewIDJAGVerifier`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#NewIDJAGVerifier). - -This verifier validates ID-JAG tokens issued by trusted identity providers through: -- Signature verification using the IdP's JWKS -- Audience validation (ensures token was issued for this MCP server) -- Expiration and clock skew handling -- Replay attack prevention via JTI tracking - -Example configuration: - -```go -config := &IDJAGVerifierConfig{ - AuthServerIssuerURL: "https://auth.mcpserver.example", - TrustedIdPs: map[string]*TrustedIdPConfig{ - "acme-okta": { - IssuerURL: "https://acme.okta.com", - JWKSURL: "https://acme.okta.com/.well-known/jwks.json", - }, - }, -} - -verifier := NewIDJAGVerifier(config) -middleware := RequireBearerToken(verifier, &RequireBearerTokenOptions{ - Scopes: []string{"read"}, -}) -``` - -Tool handlers can access ID-JAG claims from `req.Extra.TokenInfo.Extra`, which includes the subject, client ID, resource, and issuer from the validated token. - ### Client > [!IMPORTANT] @@ -315,10 +284,12 @@ and step-up authentication (when the server returns `insufficient_scope` error). #### Enterprise Authentication Flow (SEP-990) -For enterprise SSO scenarios, the SDK provides an [`EnterpriseAuthFlow`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#EnterpriseAuthFlow) function that implements the complete token exchange flow: +For enterprise SSO scenarios, the SDK provides an +[`EnterpriseAuthFlow`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#EnterpriseAuthFlow) +function that implements the complete token exchange flow: -1. **Token Exchange** at IdP: ID Token -> ID-JAG -2. **JWT Bearer Grant** at MCP Server: ID-JAG -> Access Token +1. **Token Exchange** at IdP: ID Token → ID-JAG +2. **JWT Bearer Grant** at MCP Server: ID-JAG → Access Token This flow is typically used after obtaining an ID Token via OIDC login: @@ -327,17 +298,18 @@ This flow is typically used after obtaining an ID Token via OIDC login: idToken := "..." // from OIDC login // Step 2: Exchange for MCP access token -enterpriseConfig := &EnterpriseAuthConfig{ - IdPIssuerURL: "https://acme.okta.com", - IdPClientID: "client-id-at-idp", - IdPClientSecret: "secret-at-idp", - MCPAuthServerURL: "https://auth.mcpserver.example", - MCPResourceURL: "https://mcp.mcpserver.example", - MCPClientID: "client-id-at-mcp", - MCPClientSecret: "secret-at-mcp", - MCPScopes: []string{"read", "write"}, +config := &auth.EnterpriseAuthConfig{ + IdPIssuerURL: "https://company.okta.com", + IdPClientID: "client-id-at-idp", + IdPClientSecret: "secret-at-idp", + MCPAuthServerURL: "https://auth.mcpserver.example", + MCPResourceURL: "https://mcp.mcpserver.example", + MCPClientID: "client-id-at-mcp", + MCPClientSecret: "secret-at-mcp", + MCPScopes: []string{"read", "write"}, } -accessToken, err := EnterpriseAuthFlow(ctx, enterpriseConfig, tokens.IDToken) + +accessToken, err := auth.EnterpriseAuthFlow(ctx, config, idToken) // Use accessToken with MCP client ``` @@ -461,3 +433,4 @@ or Issue #460 discusses some potential ergonomic improvements to this API. %include ../../mcp/mcp_example_test.go progress - + From 6b8bfdc7101a7c8b9ccea3399ae3d8d94f2bf447 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Thu, 4 Dec 2025 20:41:09 +0530 Subject: [PATCH 10/19] chore: rename MCPResourceURL to MCPResourceURI --- auth/enterprise_auth.go | 12 ++++++------ auth/enterprise_auth_test.go | 2 +- docs/protocol.md | 2 +- internal/docs/protocol.src.md | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/auth/enterprise_auth.go b/auth/enterprise_auth.go index 2d903ca2..54535c40 100644 --- a/auth/enterprise_auth.go +++ b/auth/enterprise_auth.go @@ -29,7 +29,7 @@ type EnterpriseAuthConfig struct { // MCP Server configuration (the resource being accessed) MCPAuthServerURL string // MCP Server's auth server issuer URL - MCPResourceURL string // MCP Server's resource identifier + MCPResourceURI string // MCP Server's resource identifier MCPClientID string // MCP Client's ID at the MCP Server MCPClientSecret string // MCP Client's secret at the MCP Server MCPScopes []string // Requested scopes at the MCP Server @@ -76,7 +76,7 @@ type EnterpriseAuthConfig struct { // IdPClientID: "client-id-at-idp", // IdPClientSecret: "secret-at-idp", // MCPAuthServerURL: "https://auth.mcpserver.example", -// MCPResourceURL: "https://mcp.mcpserver.example", +// MCPResourceURI: "https://mcp.mcpserver.example", // MCPClientID: "client-id-at-mcp", // MCPClientSecret: "secret-at-mcp", // MCPScopes: []string{"read", "write"}, @@ -93,7 +93,7 @@ type EnterpriseAuthConfig struct { // IdPClientID: "client-id-at-idp", // IdPClientSecret: "secret-at-idp", // MCPAuthServerURL: "https://auth.mcpserver.example", -// MCPResourceURL: "https://mcp.mcpserver.example", +// MCPResourceURI: "https://mcp.mcpserver.example", // MCPClientID: "client-id-at-mcp", // MCPClientSecret: "secret-at-mcp", // MCPScopes: []string{"read", "write"}, @@ -124,8 +124,8 @@ func EnterpriseAuthFlow( if config.MCPAuthServerURL == "" { return nil, fmt.Errorf("MCPAuthServerURL is required") } - if config.MCPResourceURL == "" { - return nil, fmt.Errorf("MCPResourceURL is required") + if config.MCPResourceURI == "" { + return nil, fmt.Errorf("MCPResourceURI is required") } httpClient := config.HTTPClient if httpClient == nil { @@ -142,7 +142,7 @@ func EnterpriseAuthFlow( tokenExchangeReq := &oauthex.TokenExchangeRequest{ RequestedTokenType: oauthex.TokenTypeIDJAG, Audience: config.MCPAuthServerURL, - Resource: config.MCPResourceURL, + Resource: config.MCPResourceURI, Scope: config.MCPScopes, SubjectToken: idToken, SubjectTokenType: oauthex.TokenTypeIDToken, diff --git a/auth/enterprise_auth_test.go b/auth/enterprise_auth_test.go index db2e5ffd..c44e4233 100644 --- a/auth/enterprise_auth_test.go +++ b/auth/enterprise_auth_test.go @@ -34,7 +34,7 @@ func TestEnterpriseAuthFlow(t *testing.T) { IdPClientID: "test-idp-client", IdPClientSecret: "test-idp-secret", MCPAuthServerURL: mcpServer.URL, - MCPResourceURL: "https://mcp.example.com", + MCPResourceURI: "https://mcp.example.com", MCPClientID: "test-mcp-client", MCPClientSecret: "test-mcp-secret", MCPScopes: []string{"read", "write"}, diff --git a/docs/protocol.md b/docs/protocol.md index a857bf21..cc487ccb 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -378,7 +378,7 @@ config := &auth.EnterpriseAuthConfig{ IdPClientID: "client-id-at-idp", IdPClientSecret: "secret-at-idp", MCPAuthServerURL: "https://auth.mcpserver.example", - MCPResourceURL: "https://mcp.mcpserver.example", + MCPResourceURI: "https://mcp.mcpserver.example", MCPClientID: "client-id-at-mcp", MCPClientSecret: "secret-at-mcp", MCPScopes: []string{"read", "write"}, diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 8475a324..1511a304 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -303,7 +303,7 @@ config := &auth.EnterpriseAuthConfig{ IdPClientID: "client-id-at-idp", IdPClientSecret: "secret-at-idp", MCPAuthServerURL: "https://auth.mcpserver.example", - MCPResourceURL: "https://mcp.mcpserver.example", + MCPResourceURI: "https://mcp.mcpserver.example", MCPClientID: "client-id-at-mcp", MCPClientSecret: "secret-at-mcp", MCPScopes: []string{"read", "write"}, From 897455b4a52a46d07fa3feab7fd6ca93b4e21862 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Sun, 22 Feb 2026 14:52:25 +0530 Subject: [PATCH 11/19] feat: use oauth2 for token structs --- oauthex/jwt_bearer.go | 63 +++++++++++++--------------------- oauthex/jwt_bearer_test.go | 13 +++++-- oauthex/token_exchange.go | 41 ++++++++++------------ oauthex/token_exchange_test.go | 5 ++- 4 files changed, 56 insertions(+), 66 deletions(-) diff --git a/oauthex/jwt_bearer.go b/oauthex/jwt_bearer.go index 03a0df5d..0d8576fe 100644 --- a/oauthex/jwt_bearer.go +++ b/oauthex/jwt_bearer.go @@ -26,42 +26,6 @@ import ( // This is used in SEP-990 to exchange an ID-JAG for an access token at the MCP Server. const GrantTypeJWTBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer" -// JWTBearerResponse represents the response from a JWT Bearer grant request -// per RFC 7523. This uses the standard OAuth 2.0 token response format. -type JWTBearerResponse struct { - // AccessToken is the OAuth access token issued by the MCP Server's - // authorization server. - AccessToken string `json:"access_token"` - // TokenType is the type of token issued. This is typically "Bearer". - TokenType string `json:"token_type"` - // ExpiresIn is the lifetime in seconds of the access token. - ExpiresIn int `json:"expires_in,omitempty"` - // RefreshToken is the refresh token, which can be used to obtain new - // access tokens using the same authorization grant. - RefreshToken string `json:"refresh_token,omitempty"` - // Scope is the scope of the access token as described by RFC 6749 Section 3.3. - Scope string `json:"scope,omitempty"` -} - -// JWTBearerError represents an error response from a JWT Bearer grant request. -type JWTBearerError struct { - // ErrorCode is the error code as defined in RFC 6749 Section 5.2. - // The JSON field name is "error" per the RFC specification. - ErrorCode string `json:"error"` - // ErrorDescription is a human-readable description of the error. - ErrorDescription string `json:"error_description,omitempty"` - // ErrorURI is a URI identifying a human-readable web page with information - // about the error. - ErrorURI string `json:"error_uri,omitempty"` -} - -func (e *JWTBearerError) Error() string { - if e.ErrorDescription != "" { - return fmt.Sprintf("JWT bearer grant failed: %s (%s)", e.ErrorCode, e.ErrorDescription) - } - return fmt.Sprintf("JWT bearer grant failed: %s", e.ErrorCode) -} - // ExchangeJWTBearer exchanges an Identity Assertion JWT Authorization Grant (ID-JAG) // for an access token using JWT Bearer Grant per RFC 7523. This is the second step // in Enterprise Managed Authorization (SEP-990) after obtaining the ID-JAG from the @@ -151,7 +115,13 @@ func ExchangeJWTBearer( } // Handle success response (200 OK per OAuth 2.0) if httpResp.StatusCode == http.StatusOK { - var resp JWTBearerResponse + var resp struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` + } if err := json.Unmarshal(body, &resp); err != nil { return nil, fmt.Errorf("failed to parse JWT bearer grant response: %w (body: %s)", err, string(body)) } @@ -182,12 +152,25 @@ func ExchangeJWTBearer( } // Handle error response (400 Bad Request per RFC 6749) if httpResp.StatusCode == http.StatusBadRequest { - var errResp JWTBearerError + var errResp struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + ErrorURI string `json:"error_uri,omitempty"` + } if err := json.Unmarshal(body, &errResp); err != nil { return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) } - return nil, &errResp + return nil, &oauth2.RetrieveError{ + Response: httpResp, + Body: body, + ErrorCode: errResp.Error, + ErrorDescription: errResp.ErrorDescription, + ErrorURI: errResp.ErrorURI, + } } // Handle unexpected status codes - return nil, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body)) + return nil, &oauth2.RetrieveError{ + Response: httpResp, + Body: body, + } } diff --git a/oauthex/jwt_bearer_test.go b/oauthex/jwt_bearer_test.go index 3145d0bf..33f0ed8f 100644 --- a/oauthex/jwt_bearer_test.go +++ b/oauthex/jwt_bearer_test.go @@ -64,7 +64,13 @@ func TestExchangeJWTBearer(t *testing.T) { return } // Return successful OAuth token response - resp := JWTBearerResponse{ + resp := struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + Scope string `json:"scope,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + }{ AccessToken: "mcp-access-token-123", TokenType: "Bearer", ExpiresIn: 3600, @@ -143,7 +149,10 @@ func TestExchangeJWTBearer(t *testing.T) { // writeJWTBearerErrorResponse writes an OAuth 2.0 error response per RFC 6749 Section 5.2. func writeJWTBearerErrorResponse(w http.ResponseWriter, errorCode, errorDescription string) { - errResp := JWTBearerError{ + errResp := struct { + Error string `json:"error"` + ErrorCode string `json:"error_description,omitempty"` + }{ ErrorCode: errorCode, ErrorDescription: errorDescription, } diff --git a/oauthex/token_exchange.go b/oauthex/token_exchange.go index fb162d0b..8ec13b62 100644 --- a/oauthex/token_exchange.go +++ b/oauthex/token_exchange.go @@ -17,6 +17,8 @@ import ( "net/http" "net/url" "strings" + + "golang.org/x/oauth2" ) // Token type identifiers defined by RFC 8693 and SEP-990. @@ -92,26 +94,6 @@ type TokenExchangeResponse struct { ExpiresIn int `json:"expires_in,omitempty"` } -// TokenExchangeError represents an error response from a token exchange request. -type TokenExchangeError struct { - // Error is the error code as defined in RFC 6749 Section 5.2. - ErrorCode string `json:"error"` - - // ErrorDescription is a human-readable description of the error. - ErrorDescription string `json:"error_description,omitempty"` - - // ErrorURI is a URI identifying a human-readable web page with information - // about the error. - ErrorURI string `json:"error_uri,omitempty"` -} - -func (e *TokenExchangeError) Error() string { - if e.ErrorDescription != "" { - return fmt.Sprintf("token exchange failed: %s (%s)", e.ErrorCode, e.ErrorDescription) - } - return fmt.Sprintf("token exchange failed: %s", e.ErrorCode) -} - // ExchangeToken performs a token exchange request per RFC 8693 for Enterprise // Managed Authorization (SEP-990). It exchanges an identity assertion (typically // an ID Token) for an Identity Assertion JWT Authorization Grant (ID-JAG) that @@ -255,13 +237,26 @@ func ExchangeToken( // Handle error response (400 Bad Request per RFC 6749) if httpResp.StatusCode == http.StatusBadRequest { - var errResp TokenExchangeError + var errResp struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + ErrorURI string `json:"error_uri,omitempty"` + } if err := json.Unmarshal(body, &errResp); err != nil { return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) } - return nil, &errResp + return nil, &oauth2.RetrieveError{ + Response: httpResp, + Body: body, + ErrorCode: errResp.Error, + ErrorDescription: errResp.ErrorDescription, + ErrorURI: errResp.ErrorURI, + } } // Handle unexpected status codes - return nil, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body)) + return nil, &oauth2.RetrieveError{ + Response: httpResp, + Body: body, + } } diff --git a/oauthex/token_exchange_test.go b/oauthex/token_exchange_test.go index 316fa8e3..82b88d61 100644 --- a/oauthex/token_exchange_test.go +++ b/oauthex/token_exchange_test.go @@ -208,7 +208,10 @@ func TestExchangeToken(t *testing.T) { // writeErrorResponse writes an OAuth 2.0 error response per RFC 6749 Section 5.2. func writeErrorResponse(w http.ResponseWriter, errorCode, errorDescription string) { - errResp := TokenExchangeError{ + errResp := struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` + }{ ErrorCode: errorCode, ErrorDescription: errorDescription, } From b6161afaf2a4b2c5e1aa958c8d0cb11b3fd3284d Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Wed, 4 Mar 2026 17:13:04 +0530 Subject: [PATCH 12/19] chore: update auth ext package --- auth/client.go | 7 ++ auth/enterprise_auth.go | 19 ++- auth/enterprise_auth_test.go | 8 +- auth/extauth/enterprise_handler.go | 188 +++++++++++++++++++++++++++++ auth/oidc_login.go | 4 +- oauthex/jwt_bearer.go | 2 +- oauthex/jwt_bearer_test.go | 6 +- oauthex/token_exchange_test.go | 2 +- 8 files changed, 226 insertions(+), 10 deletions(-) create mode 100644 auth/extauth/enterprise_handler.go diff --git a/auth/client.go b/auth/client.go index 0af6963f..778a2cd0 100644 --- a/auth/client.go +++ b/auth/client.go @@ -40,3 +40,10 @@ type OAuthHandler interface { // The function is responsible for closing the response body. Authorize(context.Context, *http.Request, *http.Response) error } + +// OAuthHandlerBase is an embeddable type that satisfies the private method +// requirement of [OAuthHandler]. Extension packages should embed this type +// in their handler structs to implement OAuthHandler. +type OAuthHandlerBase struct{} + +func (OAuthHandlerBase) isOAuthHandler() {} diff --git a/auth/enterprise_auth.go b/auth/enterprise_auth.go index 54535c40..50e8a9a7 100644 --- a/auth/enterprise_auth.go +++ b/auth/enterprise_auth.go @@ -133,7 +133,7 @@ func EnterpriseAuthFlow( } // Step 1: Discover IdP token endpoint via OIDC discovery - idpMeta, err := oauthex.GetAuthServerMeta(ctx, config.IdPIssuerURL, httpClient) + idpMeta, err := GetAuthServerMetadatForIssuer(ctx, config.IdPIssuerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover IdP metadata: %w", err) } @@ -161,7 +161,7 @@ func EnterpriseAuthFlow( } // Step 3: JWT Bearer Grant (ID-JAG → Access Token) - mcpMeta, err := oauthex.GetAuthServerMeta(ctx, config.MCPAuthServerURL, httpClient) + mcpMeta, err := GetAuthServerMetadatForIssuer(ctx, config.MCPAuthServerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover MCP auth server metadata: %w", err) } @@ -179,3 +179,18 @@ func EnterpriseAuthFlow( } return accessToken, nil } + +// GetAuthServerMetadatForIssuer fetches authorization server metadata for the given issuer URL. +// It tries standard well-known endpoints (OAuth 2.0 and OIDC) and returns the first successful result. +func GetAuthServerMetadatForIssuer(ctx context.Context, IssuerURL string, httpClient *httpClient) (*oauthex.AuthServerMeta, error) { + for _, metadataURL := range authorizationServerMetadataURLs(issuerURL) { + asm, err := oauthex.GetAuthServerMeta(ctx, metadataURL, issuerURL, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) + } + if asm != nil { + return asm, nil + } + } + return nil, fmt.Errorf("no authorization server metadata found for %s", issuerURL) +} diff --git a/auth/enterprise_auth_test.go b/auth/enterprise_auth_test.go index c44e4233..53975295 100644 --- a/auth/enterprise_auth_test.go +++ b/auth/enterprise_auth_test.go @@ -183,7 +183,13 @@ func createMockMCPServer(t *testing.T) *httptest.Server { return } - resp := oauthex.JWTBearerResponse{ + resp := struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in,omitempty"` + Scope string `json:"scope,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + }{ AccessToken: "mcp-access-token", TokenType: "Bearer", ExpiresIn: 3600, diff --git a/auth/extauth/enterprise_handler.go b/auth/extauth/enterprise_handler.go new file mode 100644 index 00000000..76cb66aa --- /dev/null +++ b/auth/extauth/enterprise_handler.go @@ -0,0 +1,188 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package extauth provides OAuth handler implementations for MCP authorization extensions. +// This package implements Enterprise Managed Authorization as defined in SEP-990. + +//go:build mcp_go_client_oauth + +package extauth + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/oauthex" + "golang.org/x/oauth2" +) + +// IDTokenFetcher is called to obtain an ID Token from the enterprise IdP. +// This is typically done via OIDC login flow where the user authenticates +// with their enterprise identity provider. +type IDTokenFetcher func(ctx context.Context) (string, error) + +// EnterpriseHandlerConfig is the configuration for [EnterpriseHandler]. +type EnterpriseHandlerConfig struct { + // IdP configuration (where the user authenticates) + + // IdPIssuerURL is the enterprise IdP's issuer URL (e.g., "https://acme.okta.com"). + // Used for OIDC discovery to find the token endpoint. + IdPIssuerURL string + + // IdPClientID is the MCP Client's ID registered at the IdP. + IdPClientID string + + // IdPClientSecret is the MCP Client's secret registered at the IdP. + IdPClientSecret string + + // MCP Server configuration (the resource being accessed) + + // MCPAuthServerURL is the MCP Server's authorization server issuer URL. + // Used as the audience for token exchange and for metadata discovery. + MCPAuthServerURL string + + // MCPResourceURI is the MCP Server's resource identifier (RFC 9728). + // Used as the resource parameter in token exchange. + MCPResourceURI string + + // MCPClientID is the MCP Client's ID registered at the MCP Server. + MCPClientID string + + // MCPClientSecret is the MCP Client's secret registered at the MCP Server. + MCPClientSecret string + + // MCPScopes is the list of scopes to request at the MCP Server. + MCPScopes []string + + // IDTokenFetcher is called to obtain an ID Token when authorization is needed. + // The implementation should handle the OIDC login flow (e.g., browser redirect, + // callback handling) and return the ID token. + IDTokenFetcher IDTokenFetcher + + // HTTPClient is an optional HTTP client for customization. + // If nil, http.DefaultClient is used. + HTTPClient *http.Client +} + +// EnterpriseHandler is an implementation of [auth.OAuthHandler] that uses +// Enterprise Managed Authorization (SEP-990) to obtain access tokens. +// +// The flow consists of: +// 1. OIDC Login: User authenticates with enterprise IdP → ID Token +// 2. Token Exchange (RFC 8693): ID Token → ID-JAG at IdP +// 3. JWT Bearer Grant (RFC 7523): ID-JAG → Access Token at MCP Server +type EnterpriseHandler struct { + auth.OAuthHandlerBase + config *EnterpriseHandlerConfig + + // tokenSource is the token source obtained after authorization. + tokenSource oauth2.TokenSource +} + +// Compile-time check that EnterpriseHandler implements auth.OAuthHandler. +var _ auth.OAuthHandler = (*EnterpriseHandler)(nil) + +// NewEnterpriseHandler creates a new EnterpriseHandler. +// It performs validation of the configuration and returns an error if invalid. +func NewEnterpriseHandler(config *EnterpriseHandlerConfig) (*EnterpriseHandler, error) { + if config == nil { + return nil, errors.New("config must be provided") + } + if config.IdPIssuerURL == "" { + return nil, errors.New("IdPIssuerURL is required") + } + if config.MCPAuthServerURL == "" { + return nil, errors.New("MCPAuthServerURL is required") + } + if config.MCPResourceURI == "" { + return nil, errors.New("MCPResourceURI is required") + } + if config.IDTokenFetcher == nil { + return nil, errors.New("IDTokenFetcher is required") + } + return &EnterpriseHandler{config: config}, nil +} + +// TokenSource returns the token source for outgoing requests. +// Returns nil if authorization has not been performed yet. +func (h *EnterpriseHandler) TokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return h.tokenSource, nil +} + +// Authorize performs the Enterprise Managed Authorization flow. +// It is called when a request fails with 401 or 403. +func (h *EnterpriseHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { + defer func() { + if resp != nil && resp.Body != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + }() + + httpClient := h.config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + + // Step 1: Get ID Token via the configured fetcher (e.g., OIDC login) + idToken, err := h.config.IDTokenFetcher(ctx) + if err != nil { + return fmt.Errorf("failed to obtain ID token: %w", err) + } + + // Step 2: Discover IdP token endpoint via OIDC discovery + idpMeta, err := auth.GetAuthServerMetadataForIssuer(ctx, h.config.IdPIssuerURL, httpClient) + if err != nil { + return fmt.Errorf("failed to discover IdP metadata: %w", err) + } + + // Step 3: Token Exchange (ID Token → ID-JAG) + tokenExchangeReq := &oauthex.TokenExchangeRequest{ + RequestedTokenType: oauthex.TokenTypeIDJAG, + Audience: h.config.MCPAuthServerURL, + Resource: h.config.MCPResourceURI, + Scope: h.config.MCPScopes, + SubjectToken: idToken, + SubjectTokenType: oauthex.TokenTypeIDToken, + } + + tokenExchangeResp, err := oauthex.ExchangeToken( + ctx, + idpMeta.TokenEndpoint, + tokenExchangeReq, + h.config.IdPClientID, + h.config.IdPClientSecret, + httpClient, + ) + if err != nil { + return fmt.Errorf("token exchange failed: %w", err) + } + + // Step 4: Discover MCP Server token endpoint + mcpMeta, err := auth.GetAuthServerMetadataForIssuer(ctx, h.config.MCPAuthServerURL, httpClient) + if err != nil { + return fmt.Errorf("failed to discover MCP auth server metadata: %w", err) + } + + // Step 5: JWT Bearer Grant (ID-JAG → Access Token) + accessToken, err := oauthex.ExchangeJWTBearer( + ctx, + mcpMeta.TokenEndpoint, + tokenExchangeResp.AccessToken, + h.config.MCPClientID, + h.config.MCPClientSecret, + httpClient, + ) + if err != nil { + return fmt.Errorf("JWT bearer grant failed: %w", err) + } + + // Store the token source for subsequent requests + h.tokenSource = oauth2.StaticTokenSource(accessToken) + return nil +} diff --git a/auth/oidc_login.go b/auth/oidc_login.go index 5c1d9106..0787149b 100644 --- a/auth/oidc_login.go +++ b/auth/oidc_login.go @@ -153,7 +153,7 @@ func InitiateOIDCLogin( if httpClient == nil { httpClient = http.DefaultClient } - meta, err := oauthex.GetAuthServerMeta(ctx, config.IssuerURL, httpClient) + meta, err := GetAuthServerMetadataForIssuer(ctx, config.IssuerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) } @@ -248,7 +248,7 @@ func CompleteOIDCLogin( if httpClient == nil { httpClient = http.DefaultClient } - meta, err := oauthex.GetAuthServerMeta(ctx, config.IssuerURL, httpClient) + meta, err := GetAuthServerMetadataForIssuer(ctx, config.IssuerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) } diff --git a/oauthex/jwt_bearer.go b/oauthex/jwt_bearer.go index 0d8576fe..e79ef638 100644 --- a/oauthex/jwt_bearer.go +++ b/oauthex/jwt_bearer.go @@ -118,7 +118,7 @@ func ExchangeJWTBearer( var resp struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in,omitempty` + ExpiresIn int `json:"expires_in,omitempty"` RefreshToken string `json:"refresh_token,omitempty"` Scope string `json:"scope,omitempty"` } diff --git a/oauthex/jwt_bearer_test.go b/oauthex/jwt_bearer_test.go index 33f0ed8f..9eb08544 100644 --- a/oauthex/jwt_bearer_test.go +++ b/oauthex/jwt_bearer_test.go @@ -150,10 +150,10 @@ func TestExchangeJWTBearer(t *testing.T) { // writeJWTBearerErrorResponse writes an OAuth 2.0 error response per RFC 6749 Section 5.2. func writeJWTBearerErrorResponse(w http.ResponseWriter, errorCode, errorDescription string) { errResp := struct { - Error string `json:"error"` - ErrorCode string `json:"error_description,omitempty"` + Error string `json:"error"` + ErrorDescription string `json:"error_description,omitempty"` }{ - ErrorCode: errorCode, + Error: errorCode, ErrorDescription: errorDescription, } w.Header().Set("Content-Type", "application/json") diff --git a/oauthex/token_exchange_test.go b/oauthex/token_exchange_test.go index 82b88d61..b709b900 100644 --- a/oauthex/token_exchange_test.go +++ b/oauthex/token_exchange_test.go @@ -212,7 +212,7 @@ func writeErrorResponse(w http.ResponseWriter, errorCode, errorDescription strin Error string `json:"error"` ErrorDescription string `json:"error_description,omitempty"` }{ - ErrorCode: errorCode, + Error: errorCode, ErrorDescription: errorDescription, } From 0649bcd0e044298e2a89caa74ec69f2c5872b875 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Wed, 4 Mar 2026 22:07:45 +0530 Subject: [PATCH 13/19] chore: update header files with current year --- auth/enterprise_auth.go | 2 +- auth/enterprise_auth_test.go | 2 +- auth/oidc_login.go | 2 +- auth/oidc_login_test.go | 2 +- oauthex/id_jag.go | 2 +- oauthex/id_jag_test.go | 2 +- oauthex/jwt_bearer.go | 2 +- oauthex/jwt_bearer_test.go | 2 +- oauthex/token_exchange.go | 2 +- oauthex/token_exchange_test.go | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/auth/enterprise_auth.go b/auth/enterprise_auth.go index 50e8a9a7..089e4a11 100644 --- a/auth/enterprise_auth.go +++ b/auth/enterprise_auth.go @@ -1,4 +1,4 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. diff --git a/auth/enterprise_auth_test.go b/auth/enterprise_auth_test.go index 53975295..61274c0c 100644 --- a/auth/enterprise_auth_test.go +++ b/auth/enterprise_auth_test.go @@ -1,4 +1,4 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. diff --git a/auth/oidc_login.go b/auth/oidc_login.go index 0787149b..12d1c672 100644 --- a/auth/oidc_login.go +++ b/auth/oidc_login.go @@ -1,4 +1,4 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. diff --git a/auth/oidc_login_test.go b/auth/oidc_login_test.go index ca8c3609..830b6907 100644 --- a/auth/oidc_login_test.go +++ b/auth/oidc_login_test.go @@ -1,4 +1,4 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. diff --git a/oauthex/id_jag.go b/oauthex/id_jag.go index 860a36ea..2f02ccfe 100644 --- a/oauthex/id_jag.go +++ b/oauthex/id_jag.go @@ -1,4 +1,4 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. diff --git a/oauthex/id_jag_test.go b/oauthex/id_jag_test.go index ff710fcc..41eeb1ec 100644 --- a/oauthex/id_jag_test.go +++ b/oauthex/id_jag_test.go @@ -1,4 +1,4 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. diff --git a/oauthex/jwt_bearer.go b/oauthex/jwt_bearer.go index e79ef638..9de26a40 100644 --- a/oauthex/jwt_bearer.go +++ b/oauthex/jwt_bearer.go @@ -1,4 +1,4 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. diff --git a/oauthex/jwt_bearer_test.go b/oauthex/jwt_bearer_test.go index 9eb08544..8a04732a 100644 --- a/oauthex/jwt_bearer_test.go +++ b/oauthex/jwt_bearer_test.go @@ -1,4 +1,4 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. diff --git a/oauthex/token_exchange.go b/oauthex/token_exchange.go index 8ec13b62..01405b8a 100644 --- a/oauthex/token_exchange.go +++ b/oauthex/token_exchange.go @@ -1,4 +1,4 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. diff --git a/oauthex/token_exchange_test.go b/oauthex/token_exchange_test.go index b709b900..a0403fdc 100644 --- a/oauthex/token_exchange_test.go +++ b/oauthex/token_exchange_test.go @@ -1,4 +1,4 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. From d5b2a5d42133905ca601b64c28670a4d0eb0b1d5 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Sun, 15 Mar 2026 21:24:39 +0530 Subject: [PATCH 14/19] refactor: use oauth2 package and remove unused functions --- auth/authorization_code.go | 36 +-- auth/enterprise_auth.go | 19 +- auth/extauth/enterprise_handler.go | 18 +- auth/{ => extauth}/oidc_login.go | 337 +++++++++++--------------- auth/{ => extauth}/oidc_login_test.go | 94 ++++++- auth/shared.go | 69 ++++++ oauthex/id_jag.go | 138 ----------- oauthex/id_jag_test.go | 176 -------------- oauthex/jwt_bearer.go | 124 ++-------- oauthex/token_exchange.go | 135 ++++------- 10 files changed, 397 insertions(+), 749 deletions(-) rename auth/{ => extauth}/oidc_login.go (54%) rename auth/{ => extauth}/oidc_login_test.go (80%) create mode 100644 auth/shared.go delete mode 100644 oauthex/id_jag.go delete mode 100644 oauthex/id_jag_test.go diff --git a/auth/authorization_code.go b/auth/authorization_code.go index ac51ea12..bc80a859 100644 --- a/auth/authorization_code.go +++ b/auth/authorization_code.go @@ -11,6 +11,7 @@ import ( "crypto/rand" "errors" "fmt" + "io" "net/http" "net/url" "slices" @@ -198,6 +199,7 @@ func isNonRootHTTPSURL(u string) bool { // On success, [AuthorizationCodeHandler.TokenSource] will return a token source with the fetched token. func (h *AuthorizationCodeHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { defer resp.Body.Close() + defer io.Copy(io.Discard, resp.Body) wwwChallenges, err := oauthex.ParseWWWAuthenticate(resp.Header[http.CanonicalHeaderKey("WWW-Authenticate")]) if err != nil { @@ -395,40 +397,6 @@ func (h *AuthorizationCodeHandler) getAuthServerMetadata(ctx context.Context, pr return asm, nil } -// authorizationServerMetadataURLs returns a list of URLs to try when looking for -// authorization server metadata as mandated by the MCP specification: -// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. -func authorizationServerMetadataURLs(issuerURL string) []string { - var urls []string - - baseURL, err := url.Parse(issuerURL) - if err != nil { - return nil - } - - if baseURL.Path == "" { - // "OAuth 2.0 Authorization Server Metadata". - baseURL.Path = "/.well-known/oauth-authorization-server" - urls = append(urls, baseURL.String()) - // "OpenID Connect Discovery 1.0". - baseURL.Path = "/.well-known/openid-configuration" - urls = append(urls, baseURL.String()) - return urls - } - - originalPath := baseURL.Path - // "OAuth 2.0 Authorization Server Metadata with path insertion". - baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") - urls = append(urls, baseURL.String()) - // "OpenID Connect Discovery 1.0 with path insertion". - baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") - urls = append(urls, baseURL.String()) - // "OpenID Connect Discovery 1.0 with path appending". - baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" - urls = append(urls, baseURL.String()) - return urls -} - type registrationType int const ( diff --git a/auth/enterprise_auth.go b/auth/enterprise_auth.go index 089e4a11..21005702 100644 --- a/auth/enterprise_auth.go +++ b/auth/enterprise_auth.go @@ -133,7 +133,7 @@ func EnterpriseAuthFlow( } // Step 1: Discover IdP token endpoint via OIDC discovery - idpMeta, err := GetAuthServerMetadatForIssuer(ctx, config.IdPIssuerURL, httpClient) + idpMeta, err := GetAuthServerMetadataForIssuer(ctx, config.IdPIssuerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover IdP metadata: %w", err) } @@ -161,7 +161,7 @@ func EnterpriseAuthFlow( } // Step 3: JWT Bearer Grant (ID-JAG → Access Token) - mcpMeta, err := GetAuthServerMetadatForIssuer(ctx, config.MCPAuthServerURL, httpClient) + mcpMeta, err := GetAuthServerMetadataForIssuer(ctx, config.MCPAuthServerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover MCP auth server metadata: %w", err) } @@ -179,18 +179,3 @@ func EnterpriseAuthFlow( } return accessToken, nil } - -// GetAuthServerMetadatForIssuer fetches authorization server metadata for the given issuer URL. -// It tries standard well-known endpoints (OAuth 2.0 and OIDC) and returns the first successful result. -func GetAuthServerMetadatForIssuer(ctx context.Context, IssuerURL string, httpClient *httpClient) (*oauthex.AuthServerMeta, error) { - for _, metadataURL := range authorizationServerMetadataURLs(issuerURL) { - asm, err := oauthex.GetAuthServerMeta(ctx, metadataURL, issuerURL, httpClient) - if err != nil { - return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) - } - if asm != nil { - return asm, nil - } - } - return nil, fmt.Errorf("no authorization server metadata found for %s", issuerURL) -} diff --git a/auth/extauth/enterprise_handler.go b/auth/extauth/enterprise_handler.go index 76cb66aa..f1382ab0 100644 --- a/auth/extauth/enterprise_handler.go +++ b/auth/extauth/enterprise_handler.go @@ -32,40 +32,50 @@ type EnterpriseHandlerConfig struct { // IdPIssuerURL is the enterprise IdP's issuer URL (e.g., "https://acme.okta.com"). // Used for OIDC discovery to find the token endpoint. + // REQUIRED. IdPIssuerURL string // IdPClientID is the MCP Client's ID registered at the IdP. + // OPTIONAL. Required if the IdP requires client authentication for token exchange. IdPClientID string // IdPClientSecret is the MCP Client's secret registered at the IdP. + // OPTIONAL. Required if the IdP requires client authentication for token exchange. IdPClientSecret string // MCP Server configuration (the resource being accessed) // MCPAuthServerURL is the MCP Server's authorization server issuer URL. // Used as the audience for token exchange and for metadata discovery. + // REQUIRED. MCPAuthServerURL string // MCPResourceURI is the MCP Server's resource identifier (RFC 9728). // Used as the resource parameter in token exchange. + // REQUIRED. MCPResourceURI string // MCPClientID is the MCP Client's ID registered at the MCP Server. + // OPTIONAL. Required if the MCP Server requires client authentication. MCPClientID string // MCPClientSecret is the MCP Client's secret registered at the MCP Server. + // OPTIONAL. Required if the MCP Server requires client authentication. MCPClientSecret string // MCPScopes is the list of scopes to request at the MCP Server. + // OPTIONAL. MCPScopes []string // IDTokenFetcher is called to obtain an ID Token when authorization is needed. // The implementation should handle the OIDC login flow (e.g., browser redirect, // callback handling) and return the ID token. + // REQUIRED. IDTokenFetcher IDTokenFetcher // HTTPClient is an optional HTTP client for customization. // If nil, http.DefaultClient is used. + // OPTIONAL. HTTPClient *http.Client } @@ -117,12 +127,8 @@ func (h *EnterpriseHandler) TokenSource(ctx context.Context) (oauth2.TokenSource // Authorize performs the Enterprise Managed Authorization flow. // It is called when a request fails with 401 or 403. func (h *EnterpriseHandler) Authorize(ctx context.Context, req *http.Request, resp *http.Response) error { - defer func() { - if resp != nil && resp.Body != nil { - io.Copy(io.Discard, resp.Body) - resp.Body.Close() - } - }() + defer resp.Body.Close() + defer io.Copy(io.Discard, resp.Body) httpClient := h.config.HTTPClient if httpClient == nil { diff --git a/auth/oidc_login.go b/auth/extauth/oidc_login.go similarity index 54% rename from auth/oidc_login.go rename to auth/extauth/oidc_login.go index 12d1c672..f71ecd57 100644 --- a/auth/oidc_login.go +++ b/auth/extauth/oidc_login.go @@ -8,21 +8,15 @@ //go:build mcp_go_client_oauth -package auth +package extauth import ( "context" "crypto/rand" - "crypto/sha256" - "encoding/base64" - "encoding/json" "fmt" - "io" "net/http" - "net/url" - "strings" - "time" + "github.com/modelcontextprotocol/go-sdk/auth" "github.com/modelcontextprotocol/go-sdk/oauthex" "golang.org/x/oauth2" ) @@ -32,24 +26,30 @@ import ( // Users can alternatively obtain ID tokens through their own methods. type OIDCLoginConfig struct { // IssuerURL is the IdP's issuer URL (e.g., "https://acme.okta.com"). + // REQUIRED. IssuerURL string // ClientID is the MCP Client's ID registered at the IdP. + // REQUIRED. ClientID string // ClientSecret is the MCP Client's secret at the IdP. - // This is OPTIONAL and only used if the client is confidential. + // OPTIONAL. Only required if the client is confidential. ClientSecret string // RedirectURL is the OAuth2 redirect URI registered with the IdP. // This must match exactly what was registered with the IdP. + // REQUIRED. RedirectURL string // Scopes are the OAuth2/OIDC scopes to request. // "openid" is REQUIRED for OIDC. Common values: ["openid", "profile", "email"] + // REQUIRED. Scopes []string - // LoginHint is an OPTIONAL hint to the IdP about the user's identity. + // LoginHint is a hint to the IdP about the user's identity. // Some IdPs may require this (e.g., as an email address for routing to SSO providers). // Example: "user@example.com" + // OPTIONAL. LoginHint string // HTTPClient is the HTTP client for making requests. // If nil, http.DefaultClient is used. + // OPTIONAL. HTTPClient *http.Client } @@ -153,37 +153,42 @@ func InitiateOIDCLogin( if httpClient == nil { httpClient = http.DefaultClient } - meta, err := GetAuthServerMetadataForIssuer(ctx, config.IssuerURL, httpClient) + meta, err := auth.GetAuthServerMetadataForIssuer(ctx, config.IssuerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) } if meta.AuthorizationEndpoint == "" { return nil, fmt.Errorf("authorization_endpoint not found in OIDC metadata") } - // Generate PKCE code verifier and challenge (RFC 7636) - codeVerifier, err := generateCodeVerifier() - if err != nil { - return nil, fmt.Errorf("failed to generate PKCE verifier: %w", err) - } - codeChallenge := generateCodeChallenge(codeVerifier) + + // Generate PKCE code verifier (RFC 7636) + codeVerifier := oauth2.GenerateVerifier() + // Generate state for CSRF protection (RFC 6749 Section 10.12) - state, err := generateState() - if err != nil { - return nil, fmt.Errorf("failed to generate state: %w", err) - } - // Build authorization URL per OIDC Core Section 3.1.2.1 - authURL, err := buildAuthorizationURL( - meta.AuthorizationEndpoint, - config.ClientID, - config.RedirectURL, - config.Scopes, - state, - codeChallenge, - config.LoginHint, - ) - if err != nil { - return nil, fmt.Errorf("failed to build authorization URL: %w", err) + state := rand.Text() + + // Build oauth2.Config to use standard library's AuthCodeURL. + oauth2Config := &oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: meta.AuthorizationEndpoint, + TokenURL: meta.TokenEndpoint, + }, } + + // Build authorization URL using oauth2.Config.AuthCodeURL with PKCE. + // S256ChallengeOption automatically computes the S256 challenge from the verifier. + authURLOpts := []oauth2.AuthCodeOption{ + oauth2.S256ChallengeOption(codeVerifier), + } + if config.LoginHint != "" { + authURLOpts = append(authURLOpts, oauth2.SetAuthURLParam("login_hint", config.LoginHint)) + } + authURL := oauth2Config.AuthCodeURL(state, authURLOpts...) + return &OIDCAuthorizationRequest{ AuthURL: authURL, State: state, @@ -248,35 +253,42 @@ func CompleteOIDCLogin( if httpClient == nil { httpClient = http.DefaultClient } - meta, err := GetAuthServerMetadataForIssuer(ctx, config.IssuerURL, httpClient) + meta, err := auth.GetAuthServerMetadataForIssuer(ctx, config.IssuerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) } if meta.TokenEndpoint == "" { return nil, fmt.Errorf("token_endpoint not found in OIDC metadata") } - // Build token request per OIDC Core Section 3.1.3.1 - formData := url.Values{} - formData.Set("grant_type", "authorization_code") - formData.Set("code", authCode) - formData.Set("redirect_uri", config.RedirectURL) - formData.Set("client_id", config.ClientID) - formData.Set("code_verifier", codeVerifier) - // Add client_secret if provided (confidential client) - if config.ClientSecret != "" { - formData.Set("client_secret", config.ClientSecret) - } - // Exchange authorization code for tokens - oauth2Token, err := exchangeAuthorizationCode( - ctx, - meta.TokenEndpoint, - formData, - httpClient, + + // Build oauth2.Config for token exchange. + oauth2Config := &oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: meta.AuthorizationEndpoint, + TokenURL: meta.TokenEndpoint, + }, + } + + // Use custom HTTP client if provided + ctxWithClient := context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + // Exchange authorization code for tokens using oauth2.Config.Exchange. + // VerifierOption provides the PKCE code_verifier for the token request. + oauth2Token, err := oauth2Config.Exchange( + ctxWithClient, + authCode, + oauth2.VerifierOption(codeVerifier), ) if err != nil { return nil, fmt.Errorf("token exchange failed: %w", err) } - // Extract ID Token from response + + // Extract ID Token from response. + // oauth2.Token.Extra() provides access to additional fields like id_token. idToken, ok := oauth2Token.Extra("id_token").(string) if !ok || idToken == "" { return nil, fmt.Errorf("id_token not found in token response") @@ -290,165 +302,96 @@ func CompleteOIDCLogin( }, nil } -// generateCodeVerifier generates a cryptographically random code verifier -// for PKCE per RFC 7636 Section 4.1. -func generateCodeVerifier() (string, error) { - // Per RFC 7636: code verifier is 43-128 characters from [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~" - // We use 32 random bytes (256 bits) base64url-encoded = 43 characters - randomBytes := make([]byte, 32) - if _, err := rand.Read(randomBytes); err != nil { - return "", fmt.Errorf("failed to generate random bytes: %w", err) - } - return base64.RawURLEncoding.EncodeToString(randomBytes), nil -} - -// generateCodeChallenge generates the PKCE code challenge from the verifier -// using SHA256 per RFC 7636 Section 4.2. -func generateCodeChallenge(verifier string) string { - hash := sha256.Sum256([]byte(verifier)) - return base64.RawURLEncoding.EncodeToString(hash[:]) -} - -// generateState generates a cryptographically random state parameter -// for CSRF protection per RFC 6749 Section 10.12. -func generateState() (string, error) { - randomBytes := make([]byte, 32) - if _, err := rand.Read(randomBytes); err != nil { - return "", fmt.Errorf("failed to generate random bytes: %w", err) - } - return base64.RawURLEncoding.EncodeToString(randomBytes), nil +// OIDCAuthorizationResult contains the authorization code and state returned +// from the IdP after user authentication. +type OIDCAuthorizationResult struct { + // Code is the authorization code returned by the IdP. + Code string + // State is the state parameter returned by the IdP. + // This MUST match the state sent in the authorization request. + State string } -// buildAuthorizationURL constructs the OIDC authorization URL. -func buildAuthorizationURL( - authEndpoint string, - clientID string, - redirectURL string, - scopes []string, - state string, - codeChallenge string, - loginHint string, -) (string, error) { - u, err := url.Parse(authEndpoint) - if err != nil { - return "", fmt.Errorf("invalid authorization endpoint: %w", err) - } - q := u.Query() - q.Set("response_type", "code") - q.Set("client_id", clientID) - q.Set("redirect_uri", redirectURL) - q.Set("scope", strings.Join(scopes, " ")) - q.Set("state", state) - q.Set("code_challenge", codeChallenge) - q.Set("code_challenge_method", "S256") - // Add login_hint if provided (optional per OIDC spec, but some IdPs may require it) - if loginHint != "" { - q.Set("login_hint", loginHint) - } - u.RawQuery = q.Encode() - return u.String(), nil -} +// AuthorizationCodeFetcher is a callback function that handles directing the user +// to the authorization URL and returning the authorization result. +// +// Implementations should: +// 1. Direct the user to authURL (e.g., open in browser) +// 2. Wait for the IdP redirect to the configured RedirectURL +// 3. Extract the code and state from the redirect query parameters +// 4. Return them in OIDCAuthorizationResult +// +// The expectedState parameter is provided for CSRF validation. Implementations +// MUST verify that the returned state matches expectedState. +type AuthorizationCodeFetcher func(ctx context.Context, authURL string, expectedState string) (*OIDCAuthorizationResult, error) -// exchangeAuthorizationCode exchanges the authorization code for tokens. -func exchangeAuthorizationCode( +// PerformOIDCLogin performs the complete OIDC Authorization Code flow with PKCE +// in a single function call. This is the recommended approach for most use cases. +// +// The authCodeFetcher callback handles the user interaction: +// - Directing the user to the IdP login page +// - Waiting for the redirect with the authorization code +// - Validating CSRF state and returning the result +// +// Example: +// +// config := &OIDCLoginConfig{ +// IssuerURL: "https://acme.okta.com", +// ClientID: "client-id", +// RedirectURL: "http://localhost:8080/callback", +// Scopes: []string{"openid", "profile", "email"}, +// } +// +// tokens, err := PerformOIDCLogin(ctx, config, func(ctx context.Context, authURL, expectedState string) (*OIDCAuthorizationResult, error) { +// // Open browser for user +// fmt.Printf("Please visit: %s\n", authURL) +// +// // Start local server to receive callback +// code, state := waitForCallback(ctx) +// +// // Validate state for CSRF protection +// if state != expectedState { +// return nil, fmt.Errorf("state mismatch") +// } +// +// return &OIDCAuthorizationResult{Code: code, State: state}, nil +// }) +// if err != nil { +// log.Fatal(err) +// } +// +// // Use tokens.IDToken with EnterpriseHandler +func PerformOIDCLogin( ctx context.Context, - tokenEndpoint string, - formData url.Values, - httpClient *http.Client, -) (*oauth2.Token, error) { - // Create HTTP request - httpReq, err := http.NewRequestWithContext( - ctx, - http.MethodPost, - tokenEndpoint, - strings.NewReader(formData.Encode()), - ) - if err != nil { - return nil, fmt.Errorf("failed to create token request: %w", err) + config *OIDCLoginConfig, + authCodeFetcher AuthorizationCodeFetcher, +) (*OIDCTokenResponse, error) { + if authCodeFetcher == nil { + return nil, fmt.Errorf("authCodeFetcher is required") } - httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - httpReq.Header.Set("Accept", "application/json") - - // Execute request - httpResp, err := httpClient.Do(httpReq) + // Step 1: Initiate the OIDC flow to get the authorization URL + authReq, err := InitiateOIDCLogin(ctx, config) if err != nil { - return nil, fmt.Errorf("token request failed: %w", err) + return nil, fmt.Errorf("failed to initiate OIDC login: %w", err) } - defer httpResp.Body.Close() - // Read response body (limit to 1MB for safety) - body, err := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) + // Step 2: Use callback to get authorization code from user interaction + authResult, err := authCodeFetcher(ctx, authReq.AuthURL, authReq.State) if err != nil { - return nil, fmt.Errorf("failed to read token response: %w", err) + return nil, fmt.Errorf("failed to fetch authorization code: %w", err) } - // Handle success response (200 OK) - if httpResp.StatusCode == http.StatusOK { - // Parse token response manually (following jwt_bearer.go pattern) - var tokenResp struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in,omitempty"` - RefreshToken string `json:"refresh_token,omitempty"` - IDToken string `json:"id_token,omitempty"` - Scope string `json:"scope,omitempty"` - } - if err := json.Unmarshal(body, &tokenResp); err != nil { - return nil, fmt.Errorf("failed to parse token response: %w (body: %s)", err, string(body)) - } - - // Validate required fields - if tokenResp.AccessToken == "" { - return nil, fmt.Errorf("response missing required field: access_token") - } - if tokenResp.TokenType == "" { - return nil, fmt.Errorf("response missing required field: token_type") - } - - // Convert to oauth2.Token - token := &oauth2.Token{ - AccessToken: tokenResp.AccessToken, - TokenType: tokenResp.TokenType, - RefreshToken: tokenResp.RefreshToken, - } - - // Set expiration if provided - if tokenResp.ExpiresIn > 0 { - token.Expiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) - } - - // Add extra fields (id_token, scope) - extra := make(map[string]interface{}) - if tokenResp.IDToken != "" { - extra["id_token"] = tokenResp.IDToken - } - if tokenResp.Scope != "" { - extra["scope"] = tokenResp.Scope - } - if len(extra) > 0 { - token = token.WithExtra(extra) - } - - return token, nil + // Step 3: Validate state for CSRF protection + if authResult.State != authReq.State { + return nil, fmt.Errorf("state mismatch: expected %q, got %q", authReq.State, authResult.State) } - // Handle error response (400 Bad Request) - if httpResp.StatusCode == http.StatusBadRequest { - var errResp struct { - Error string `json:"error"` - ErrorDescription string `json:"error_description,omitempty"` - ErrorURI string `json:"error_uri,omitempty"` - } - if err := json.Unmarshal(body, &errResp); err != nil { - return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) - } - if errResp.ErrorDescription != "" { - return nil, fmt.Errorf("token request failed: %s (%s)", errResp.Error, errResp.ErrorDescription) - } - return nil, fmt.Errorf("token request failed: %s", errResp.Error) + // Step 4: Exchange authorization code for tokens + tokens, err := CompleteOIDCLogin(ctx, config, authResult.Code, authReq.CodeVerifier) + if err != nil { + return nil, fmt.Errorf("failed to complete OIDC login: %w", err) } - // Handle unexpected status codes - return nil, fmt.Errorf("unexpected status code %d: %s", httpResp.StatusCode, string(body)) + return tokens, nil } diff --git a/auth/oidc_login_test.go b/auth/extauth/oidc_login_test.go similarity index 80% rename from auth/oidc_login_test.go rename to auth/extauth/oidc_login_test.go index 830b6907..a82a8456 100644 --- a/auth/oidc_login_test.go +++ b/auth/extauth/oidc_login_test.go @@ -4,7 +4,7 @@ //go:build mcp_go_client_oauth -package auth +package extauth import ( "context" @@ -382,3 +382,95 @@ func base64EncodeClaims(claims map[string]interface{}) string { claimsJSON, _ := json.Marshal(claims) return base64.RawURLEncoding.EncodeToString(claimsJSON) } + +// TestPerformOIDCLogin tests the combined OIDC login flow with callback. +func TestPerformOIDCLogin(t *testing.T) { + // Create mock IdP server + idpServer := createMockOIDCServerWithToken(t) + defer idpServer.Close() + config := &OIDCLoginConfig{ + IssuerURL: idpServer.URL, + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "http://localhost:8080/callback", + Scopes: []string{"openid", "profile", "email"}, + HTTPClient: idpServer.Client(), + } + + t.Run("successful flow", func(t *testing.T) { + tokens, err := PerformOIDCLogin(context.Background(), config, + func(ctx context.Context, authURL, expectedState string) (*OIDCAuthorizationResult, error) { + // Validate authURL has required parameters + u, err := url.Parse(authURL) + if err != nil { + return nil, fmt.Errorf("invalid authURL: %w", err) + } + q := u.Query() + if q.Get("response_type") != "code" { + return nil, fmt.Errorf("missing response_type") + } + if q.Get("state") == "" { + return nil, fmt.Errorf("missing state") + } + + // Simulate successful user authentication + return &OIDCAuthorizationResult{ + Code: "mock-auth-code", + State: expectedState, // Return the expected state + }, nil + }) + + if err != nil { + t.Fatalf("PerformOIDCLogin failed: %v", err) + } + + if tokens.IDToken == "" { + t.Error("IDToken is empty") + } + if tokens.AccessToken == "" { + t.Error("AccessToken is empty") + } + }) + + t.Run("state mismatch", func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), config, + func(ctx context.Context, authURL, expectedState string) (*OIDCAuthorizationResult, error) { + // Return wrong state to simulate CSRF attack + return &OIDCAuthorizationResult{ + Code: "mock-auth-code", + State: "wrong-state", + }, nil + }) + + if err == nil { + t.Error("expected error for state mismatch, got nil") + } + if !strings.Contains(err.Error(), "state mismatch") { + t.Errorf("expected state mismatch error, got: %v", err) + } + }) + + t.Run("fetcher error", func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), config, + func(ctx context.Context, authURL, expectedState string) (*OIDCAuthorizationResult, error) { + return nil, fmt.Errorf("user cancelled") + }) + + if err == nil { + t.Error("expected error, got nil") + } + if !strings.Contains(err.Error(), "user cancelled") { + t.Errorf("expected 'user cancelled' error, got: %v", err) + } + }) + + t.Run("nil fetcher", func(t *testing.T) { + _, err := PerformOIDCLogin(context.Background(), config, nil) + if err == nil { + t.Error("expected error for nil fetcher, got nil") + } + if !strings.Contains(err.Error(), "authCodeFetcher is required") { + t.Errorf("expected 'authCodeFetcher is required' error, got: %v", err) + } + }) +} diff --git a/auth/shared.go b/auth/shared.go new file mode 100644 index 00000000..c0b6e68b --- /dev/null +++ b/auth/shared.go @@ -0,0 +1,69 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains shared utilities for OAuth handlers. + +//go:build mcp_go_client_oauth + +package auth + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/modelcontextprotocol/go-sdk/oauthex" +) + +// GetAuthServerMetadataForIssuer fetches authorization server metadata for the given issuer URL. +// It tries standard well-known endpoints (OAuth 2.0 and OIDC) and returns the first successful result. +func GetAuthServerMetadataForIssuer(ctx context.Context, issuerURL string, httpClient *http.Client) (*oauthex.AuthServerMeta, error) { + for _, metadataURL := range authorizationServerMetadataURLs(issuerURL) { + asm, err := oauthex.GetAuthServerMeta(ctx, metadataURL, issuerURL, httpClient) + if err != nil { + return nil, fmt.Errorf("failed to get authorization server metadata: %w", err) + } + if asm != nil { + return asm, nil + } + } + return nil, fmt.Errorf("no authorization server metadata found for %s", issuerURL) +} + +// authorizationServerMetadataURLs returns a list of URLs to try when looking for +// authorization server metadata as mandated by the MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#authorization-server-metadata-discovery. +func authorizationServerMetadataURLs(issuerURL string) []string { + var urls []string + + baseURL, err := url.Parse(issuerURL) + if err != nil { + return nil + } + + if baseURL.Path == "" { + // "OAuth 2.0 Authorization Server Metadata". + baseURL.Path = "/.well-known/oauth-authorization-server" + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0". + baseURL.Path = "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + return urls + } + + originalPath := baseURL.Path + // "OAuth 2.0 Authorization Server Metadata with path insertion". + baseURL.Path = "/.well-known/oauth-authorization-server/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path insertion". + baseURL.Path = "/.well-known/openid-configuration/" + strings.TrimLeft(originalPath, "/") + urls = append(urls, baseURL.String()) + // "OpenID Connect Discovery 1.0 with path appending". + baseURL.Path = "/" + strings.Trim(originalPath, "/") + "/.well-known/openid-configuration" + urls = append(urls, baseURL.String()) + + return urls +} diff --git a/oauthex/id_jag.go b/oauthex/id_jag.go deleted file mode 100644 index 2f02ccfe..00000000 --- a/oauthex/id_jag.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2026 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file implements ID-JAG (Identity Assertion JWT Authorization Grant) parsing -// for Enterprise Managed Authorization (SEP-990). -// See https://github.com/modelcontextprotocol/ext-auth/blob/main/specification/draft/enterprise-managed-authorization.mdx - -//go:build mcp_go_client_oauth - -package oauthex - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "strings" - "time" -) - -// IDJAGClaims represents the claims in an Identity Assertion JWT Authorization Grant -// per SEP-990 Section 4.3. The ID-JAG is issued by the IdP during token exchange -// and describes the authorization grant for accessing an MCP Server. -type IDJAGClaims struct { - // Issuer is the IdP's issuer URL. - Issuer string `json:"iss"` - // Subject is the user identifier at the MCP Server. - Subject string `json:"sub"` - // Audience is the Issuer URL of the MCP Server's authorization server. - Audience string `json:"aud"` - // Resource is the Resource Identifier of the MCP Server. - Resource string `json:"resource"` - // ClientID is the identifier of the MCP Client that this JWT was issued to. - ClientID string `json:"client_id"` - // JTI is the unique identifier of this JWT. - JTI string `json:"jti"` - // ExpiresAt is the expiration time of this JWT (Unix timestamp). - ExpiresAt int64 `json:"exp"` - // IssuedAt is the time this JWT was issued (Unix timestamp). - IssuedAt int64 `json:"iat"` - // Scope is a space-separated list of scopes associated with the token. - Scope string `json:"scope,omitempty"` -} - -// Expiry returns the expiration time as a time.Time. -func (c *IDJAGClaims) Expiry() time.Time { - return time.Unix(c.ExpiresAt, 0) -} - -// IssuedTime returns the issued-at time as a time.Time. -func (c *IDJAGClaims) IssuedTime() time.Time { - return time.Unix(c.IssuedAt, 0) -} - -// IsExpired checks if the ID-JAG has expired. -func (c *IDJAGClaims) IsExpired() bool { - return time.Now().After(c.Expiry()) -} - -// ParseIDJAG parses an ID-JAG JWT and extracts its claims without validating -// the signature. This is useful for inspecting the contents of an ID-JAG during -// development or debugging. -// -// For production use on the server-side, use ValidateIDJAG instead, which -// performs full signature validation and claim verification. -// -// The JWT must have a "typ" header of "oauth-id-jag+jwt" per SEP-990 Section 4.3. -// -// Example: -// -// claims, err := ParseIDJAG(idJAG) -// if err != nil { -// log.Fatalf("Failed to parse ID-JAG: %v", err) -// } -// fmt.Printf("Subject: %s\n", claims.Subject) -// fmt.Printf("Expires: %v\n", claims.Expiry()) -func ParseIDJAG(jwt string) (*IDJAGClaims, error) { - if jwt == "" { - return nil, fmt.Errorf("JWT is empty") - } - // Split JWT into parts (header.payload.signature) - parts := strings.Split(jwt, ".") - if len(parts) != 3 { - return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts)) - } - // Decode header to check typ claim - headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) - if err != nil { - return nil, fmt.Errorf("failed to decode JWT header: %w", err) - } - var header struct { - Type string `json:"typ"` - Alg string `json:"alg"` - } - if err := json.Unmarshal(headerJSON, &header); err != nil { - return nil, fmt.Errorf("failed to parse JWT header: %w", err) - } - // Verify typ claim per SEP-990 Section 4.3 - if header.Type != "oauth-id-jag+jwt" { - return nil, fmt.Errorf("invalid JWT type: expected 'oauth-id-jag+jwt', got '%s'", header.Type) - } - // Decode payload - payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return nil, fmt.Errorf("failed to decode JWT payload: %w", err) - } - // Parse claims - var claims IDJAGClaims - if err := json.Unmarshal(payloadJSON, &claims); err != nil { - return nil, fmt.Errorf("failed to parse JWT claims: %w", err) - } - // Validate required claims are present per SEP-990 Section 4.3 - if claims.Issuer == "" { - return nil, fmt.Errorf("missing required claim: iss") - } - if claims.Subject == "" { - return nil, fmt.Errorf("missing required claim: sub") - } - if claims.Audience == "" { - return nil, fmt.Errorf("missing required claim: aud") - } - if claims.Resource == "" { - return nil, fmt.Errorf("missing required claim: resource") - } - if claims.ClientID == "" { - return nil, fmt.Errorf("missing required claim: client_id") - } - if claims.JTI == "" { - return nil, fmt.Errorf("missing required claim: jti") - } - if claims.ExpiresAt == 0 { - return nil, fmt.Errorf("missing required claim: exp") - } - if claims.IssuedAt == 0 { - return nil, fmt.Errorf("missing required claim: iat") - } - return &claims, nil -} diff --git a/oauthex/id_jag_test.go b/oauthex/id_jag_test.go deleted file mode 100644 index 41eeb1ec..00000000 --- a/oauthex/id_jag_test.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright 2026 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -//go:build mcp_go_client_oauth - -package oauthex - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "strings" - "testing" - "time" -) - -// TestParseIDJAG tests parsing of ID-JAG tokens. -func TestParseIDJAG(t *testing.T) { - // Create a test ID-JAG JWT - now := time.Now().Unix() - - header := map[string]string{ - "typ": "oauth-id-jag+jwt", - "alg": "RS256", - } - - claims := map[string]interface{}{ - "iss": "https://acme.okta.com", - "sub": "alice@acme.com", - "aud": "https://auth.mcpserver.example", - "resource": "https://mcp.mcpserver.example", - "client_id": "xyz789", - "jti": "unique-id-123", - "exp": now + 300, - "iat": now, - "scope": "read write", - } - // Encode header and payload - headerJSON, _ := json.Marshal(header) - claimsJSON, _ := json.Marshal(claims) - - headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) - claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) - - // Create fake JWT (header.payload.signature) - fakeJWT := fmt.Sprintf("%s.%s.fake-signature", headerB64, claimsB64) - // Test successful parsing - t.Run("successful parse", func(t *testing.T) { - parsed, err := ParseIDJAG(fakeJWT) - if err != nil { - t.Fatalf("ParseIDJAG failed: %v", err) - } - if parsed.Issuer != "https://acme.okta.com" { - t.Errorf("expected issuer 'https://acme.okta.com', got '%s'", parsed.Issuer) - } - if parsed.Subject != "alice@acme.com" { - t.Errorf("expected subject 'alice@acme.com', got '%s'", parsed.Subject) - } - if parsed.Audience != "https://auth.mcpserver.example" { - t.Errorf("expected audience 'https://auth.mcpserver.example', got '%s'", parsed.Audience) - } - if parsed.Resource != "https://mcp.mcpserver.example" { - t.Errorf("expected resource 'https://mcp.mcpserver.example', got '%s'", parsed.Resource) - } - if parsed.ClientID != "xyz789" { - t.Errorf("expected client_id 'xyz789', got '%s'", parsed.ClientID) - } - if parsed.JTI != "unique-id-123" { - t.Errorf("expected jti 'unique-id-123', got '%s'", parsed.JTI) - } - if parsed.Scope != "read write" { - t.Errorf("expected scope 'read write', got '%s'", parsed.Scope) - } - if parsed.IsExpired() { - t.Error("expected ID-JAG not to be expired") - } - }) - // Test empty JWT - t.Run("empty JWT", func(t *testing.T) { - _, err := ParseIDJAG("") - if err == nil { - t.Error("expected error for empty JWT, got nil") - } - }) - // Test invalid format - t.Run("invalid format", func(t *testing.T) { - _, err := ParseIDJAG("invalid.jwt") - if err == nil { - t.Error("expected error for invalid JWT format, got nil") - } - }) - // Test wrong typ header - t.Run("wrong typ header", func(t *testing.T) { - wrongHeader := map[string]string{ - "typ": "JWT", // Should be "oauth-id-jag+jwt" - "alg": "RS256", - } - wrongHeaderJSON, _ := json.Marshal(wrongHeader) - wrongHeaderB64 := base64.RawURLEncoding.EncodeToString(wrongHeaderJSON) - wrongJWT := fmt.Sprintf("%s.%s.fake-signature", wrongHeaderB64, claimsB64) - _, err := ParseIDJAG(wrongJWT) - if err == nil { - t.Error("expected error for wrong typ header, got nil") - } - if err != nil && !strings.Contains(err.Error(), "invalid JWT type") { - t.Errorf("expected 'invalid JWT type' error, got: %v", err) - } - }) - // Test missing required claims - t.Run("missing required claims", func(t *testing.T) { - incompleteClaims := map[string]interface{}{ - "iss": "https://acme.okta.com", - // Missing other required claims - } - incompleteJSON, _ := json.Marshal(incompleteClaims) - incompleteB64 := base64.RawURLEncoding.EncodeToString(incompleteJSON) - incompleteJWT := fmt.Sprintf("%s.%s.fake-signature", headerB64, incompleteB64) - _, err := ParseIDJAG(incompleteJWT) - if err == nil { - t.Error("expected error for missing claims, got nil") - } - }) - // Test expired ID-JAG - t.Run("expired ID-JAG", func(t *testing.T) { - expiredClaims := map[string]interface{}{ - "iss": "https://acme.okta.com", - "sub": "alice@acme.com", - "aud": "https://auth.mcpserver.example", - "resource": "https://mcp.mcpserver.example", - "client_id": "xyz789", - "jti": "unique-id-123", - "exp": now - 300, // Expired 5 minutes ago - "iat": now - 600, - "scope": "read write", - } - expiredJSON, _ := json.Marshal(expiredClaims) - expiredB64 := base64.RawURLEncoding.EncodeToString(expiredJSON) - expiredJWT := fmt.Sprintf("%s.%s.fake-signature", headerB64, expiredB64) - parsed, err := ParseIDJAG(expiredJWT) - if err != nil { - t.Fatalf("ParseIDJAG failed: %v", err) - } - if !parsed.IsExpired() { - t.Error("expected ID-JAG to be expired") - } - }) -} - -// TestIDJAGClaimsMethods tests the helper methods on IDJAGClaims. -func TestIDJAGClaimsMethods(t *testing.T) { - now := time.Now() - claims := &IDJAGClaims{ - ExpiresAt: now.Add(1 * time.Hour).Unix(), - IssuedAt: now.Unix(), - } - // Test Expiry - expiry := claims.Expiry() - if expiry.Before(now) { - t.Error("expected expiry to be in the future") - } - // Test IssuedTime - issued := claims.IssuedTime() - if issued.After(now.Add(1 * time.Second)) { - t.Error("expected issued time to be in the past") - } - // Test IsExpired (should not be expired) - if claims.IsExpired() { - t.Error("expected claims not to be expired") - } - // Test IsExpired (should be expired) - claims.ExpiresAt = now.Add(-1 * time.Hour).Unix() - if !claims.IsExpired() { - t.Error("expected claims to be expired") - } -} diff --git a/oauthex/jwt_bearer.go b/oauthex/jwt_bearer.go index 9de26a40..7b71a0d1 100644 --- a/oauthex/jwt_bearer.go +++ b/oauthex/jwt_bearer.go @@ -11,13 +11,8 @@ package oauthex import ( "context" - "encoding/json" "fmt" - "io" "net/http" - "net/url" - "strings" - "time" "golang.org/x/oauth2" ) @@ -73,104 +68,37 @@ func ExchangeJWTBearer( if err := checkURLScheme(tokenEndpoint); err != nil { return nil, fmt.Errorf("invalid token endpoint: %w", err) } - // Build the JWT Bearer grant request per RFC 7523 Section 2.1 - formData := url.Values{} - formData.Set("grant_type", GrantTypeJWTBearer) - formData.Set("assertion", assertion) - // Add client authentication (following OAuth 2.0 client_secret_post method) - // Note: Per SEP-990 Section 5.1, the client_id in the assertion must match - // the authenticated client - if clientID != "" { - formData.Set("client_id", clientID) - } - if clientSecret != "" { - formData.Set("client_secret", clientSecret) - } - // Create HTTP request - httpReq, err := http.NewRequestWithContext( - ctx, - http.MethodPost, - tokenEndpoint, - strings.NewReader(formData.Encode()), - ) - if err != nil { - return nil, fmt.Errorf("failed to create JWT bearer grant request: %w", err) + + // Per RFC 6749 Section 3.2, parameters sent without a value (like the empty + // "code" parameter) MUST be treated as if they were omitted from the request. + // The oauth2 library's Exchange method sends an empty code, but compliant + // servers should ignore it. + cfg := &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: tokenEndpoint, + AuthStyle: oauth2.AuthStyleInParams, // Use POST body auth per SEP-990 + }, } - httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - httpReq.Header.Set("Accept", "application/json") - // Use provided client or default + + // Use custom HTTP client if provided if httpClient == nil { httpClient = http.DefaultClient } - // Execute the request - httpResp, err := httpClient.Do(httpReq) + ctxWithClient := context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + // Exchange with JWT Bearer grant type and assertion. + // SetAuthURLParam overrides the default grant_type and adds the assertion parameter. + token, err := cfg.Exchange( + ctxWithClient, + "", // empty code - per RFC 6749 Section 3.2, empty params should be ignored + oauth2.SetAuthURLParam("grant_type", GrantTypeJWTBearer), + oauth2.SetAuthURLParam("assertion", assertion), + ) if err != nil { return nil, fmt.Errorf("JWT bearer grant request failed: %w", err) } - defer httpResp.Body.Close() - // Read response body (limit to 1MB for safety, following SDK pattern) - body, err := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) - if err != nil { - return nil, fmt.Errorf("failed to read JWT bearer grant response: %w", err) - } - // Handle success response (200 OK per OAuth 2.0) - if httpResp.StatusCode == http.StatusOK { - var resp struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in,omitempty"` - RefreshToken string `json:"refresh_token,omitempty"` - Scope string `json:"scope,omitempty"` - } - if err := json.Unmarshal(body, &resp); err != nil { - return nil, fmt.Errorf("failed to parse JWT bearer grant response: %w (body: %s)", err, string(body)) - } - // Validate response per OAuth 2.0 - if resp.AccessToken == "" { - return nil, fmt.Errorf("response missing required field: access_token") - } - if resp.TokenType == "" { - return nil, fmt.Errorf("response missing required field: token_type") - } - // Convert to golang.org/x/oauth2.Token - token := &oauth2.Token{ - AccessToken: resp.AccessToken, - TokenType: resp.TokenType, - RefreshToken: resp.RefreshToken, - } - // Set expiration if provided - if resp.ExpiresIn > 0 { - token.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second) - } - // Add scope to extra data if provided - if resp.Scope != "" { - token = token.WithExtra(map[string]interface{}{ - "scope": resp.Scope, - }) - } - return token, nil - } - // Handle error response (400 Bad Request per RFC 6749) - if httpResp.StatusCode == http.StatusBadRequest { - var errResp struct { - Error string `json:"error"` - ErrorDescription string `json:"error_description,omitempty"` - ErrorURI string `json:"error_uri,omitempty"` - } - if err := json.Unmarshal(body, &errResp); err != nil { - return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) - } - return nil, &oauth2.RetrieveError{ - Response: httpResp, - Body: body, - ErrorCode: errResp.Error, - ErrorDescription: errResp.ErrorDescription, - ErrorURI: errResp.ErrorURI, - } - } - // Handle unexpected status codes - return nil, &oauth2.RetrieveError{ - Response: httpResp, - Body: body, - } + + return token, nil } diff --git a/oauthex/token_exchange.go b/oauthex/token_exchange.go index 01405b8a..de723d12 100644 --- a/oauthex/token_exchange.go +++ b/oauthex/token_exchange.go @@ -11,11 +11,8 @@ package oauthex import ( "context" - "encoding/json" "fmt" - "io" "net/http" - "net/url" "strings" "golang.org/x/oauth2" @@ -161,102 +158,76 @@ func ExchangeToken( return nil, fmt.Errorf("invalid resource: %w", err) } - // Build the token exchange request body per RFC 8693 - formData := url.Values{} - formData.Set("grant_type", GrantTypeTokenExchange) - formData.Set("requested_token_type", req.RequestedTokenType) - formData.Set("audience", req.Audience) - formData.Set("resource", req.Resource) - formData.Set("subject_token", req.SubjectToken) - formData.Set("subject_token_type", req.SubjectTokenType) - - if len(req.Scope) > 0 { - formData.Set("scope", strings.Join(req.Scope, " ")) + // Per RFC 6749 Section 3.2, parameters sent without a value (like the empty + // "code" parameter) MUST be treated as if they were omitted from the request. + // The oauth2 library's Exchange method sends an empty code, but compliant + // servers should ignore it. + cfg := &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: tokenEndpoint, + AuthStyle: oauth2.AuthStyleInParams, // Use POST body auth per SEP-990 + }, } - // Add client authentication (following OAuth 2.0 client_secret_post method) - if clientID != "" { - formData.Set("client_id", clientID) - } - if clientSecret != "" { - formData.Set("client_secret", clientSecret) + // Use custom HTTP client if provided + if httpClient == nil { + httpClient = http.DefaultClient } + ctxWithClient := context.WithValue(ctx, oauth2.HTTPClient, httpClient) - // Create HTTP request - httpReq, err := http.NewRequestWithContext( - ctx, - http.MethodPost, - tokenEndpoint, - strings.NewReader(formData.Encode()), - ) - if err != nil { - return nil, fmt.Errorf("failed to create token exchange request: %w", err) + // Build token exchange parameters per RFC 8693 + opts := []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("grant_type", GrantTypeTokenExchange), + oauth2.SetAuthURLParam("requested_token_type", req.RequestedTokenType), + oauth2.SetAuthURLParam("audience", req.Audience), + oauth2.SetAuthURLParam("resource", req.Resource), + oauth2.SetAuthURLParam("subject_token", req.SubjectToken), + oauth2.SetAuthURLParam("subject_token_type", req.SubjectTokenType), } - - httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") - httpReq.Header.Set("Accept", "application/json") - - // Use provided client or default - if httpClient == nil { - httpClient = http.DefaultClient + if len(req.Scope) > 0 { + opts = append(opts, oauth2.SetAuthURLParam("scope", strings.Join(req.Scope, " "))) } - // Execute the request - httpResp, err := httpClient.Do(httpReq) + // Exchange with token exchange grant type. + // SetAuthURLParam overrides the default grant_type and adds all required parameters. + token, err := cfg.Exchange( + ctxWithClient, + "", // empty code - per RFC 6749 Section 3.2, empty params should be ignored + opts..., + ) if err != nil { return nil, fmt.Errorf("token exchange request failed: %w", err) } - defer httpResp.Body.Close() - // Read response body (limit to 1MB for safety, following SDK pattern) - body, err := io.ReadAll(io.LimitReader(httpResp.Body, 1<<20)) - if err != nil { - return nil, fmt.Errorf("failed to read token exchange response: %w", err) + // Extract issued_token_type from Token.Extra(). + // The oauth2 library stores additional response fields in Extra. + issuedTokenType, _ := token.Extra("issued_token_type").(string) + if issuedTokenType == "" { + return nil, fmt.Errorf("response missing required field: issued_token_type") } - // Handle success response (200 OK per RFC 8693) - if httpResp.StatusCode == http.StatusOK { - var resp TokenExchangeResponse - if err := json.Unmarshal(body, &resp); err != nil { - return nil, fmt.Errorf("failed to parse token exchange response: %w (body: %s)", err, string(body)) - } - - // Validate response per SEP-990 Section 4.2 - if resp.IssuedTokenType == "" { - return nil, fmt.Errorf("response missing required field: issued_token_type") - } - if resp.AccessToken == "" { - return nil, fmt.Errorf("response missing required field: access_token") - } - if resp.TokenType == "" { - return nil, fmt.Errorf("response missing required field: token_type") - } + // Build TokenExchangeResponse from oauth2.Token + resp := &TokenExchangeResponse{ + IssuedTokenType: issuedTokenType, + AccessToken: token.AccessToken, + TokenType: token.TokenType, + } - return &resp, nil + // Extract optional fields from Extra + if scope, ok := token.Extra("scope").(string); ok { + resp.Scope = scope } - // Handle error response (400 Bad Request per RFC 6749) - if httpResp.StatusCode == http.StatusBadRequest { - var errResp struct { - Error string `json:"error"` - ErrorDescription string `json:"error_description,omitempty"` - ErrorURI string `json:"error_uri,omitempty"` - } - if err := json.Unmarshal(body, &errResp); err != nil { - return nil, fmt.Errorf("failed to parse error response: %w (body: %s)", err, string(body)) - } - return nil, &oauth2.RetrieveError{ - Response: httpResp, - Body: body, - ErrorCode: errResp.Error, - ErrorDescription: errResp.ErrorDescription, - ErrorURI: errResp.ErrorURI, + // Calculate expires_in from token.Expiry if available + if !token.Expiry.IsZero() { + resp.ExpiresIn = int(token.Expiry.Sub(token.Expiry).Seconds()) // This would be 0 + // Actually get the raw expires_in if available + if expiresIn, ok := token.Extra("expires_in").(float64); ok { + resp.ExpiresIn = int(expiresIn) } } - // Handle unexpected status codes - return nil, &oauth2.RetrieveError{ - Response: httpResp, - Body: body, - } + return resp, nil } From 4fa29721c343619cd655ee2502caeeb5ff2474b5 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Wed, 18 Mar 2026 20:57:10 +0530 Subject: [PATCH 15/19] chore: rename GetAuthServerMetadata remove unused exports --- auth/enterprise_auth.go | 4 ++-- auth/extauth/enterprise_handler.go | 4 ++-- auth/extauth/oidc_login.go | 4 ++-- auth/shared.go | 4 ++-- oauthex/oauth2.go | 10 ++-------- 5 files changed, 10 insertions(+), 16 deletions(-) diff --git a/auth/enterprise_auth.go b/auth/enterprise_auth.go index 21005702..ef6227d1 100644 --- a/auth/enterprise_auth.go +++ b/auth/enterprise_auth.go @@ -133,7 +133,7 @@ func EnterpriseAuthFlow( } // Step 1: Discover IdP token endpoint via OIDC discovery - idpMeta, err := GetAuthServerMetadataForIssuer(ctx, config.IdPIssuerURL, httpClient) + idpMeta, err := GetAuthServerMetadata(ctx, config.IdPIssuerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover IdP metadata: %w", err) } @@ -161,7 +161,7 @@ func EnterpriseAuthFlow( } // Step 3: JWT Bearer Grant (ID-JAG → Access Token) - mcpMeta, err := GetAuthServerMetadataForIssuer(ctx, config.MCPAuthServerURL, httpClient) + mcpMeta, err := GetAuthServerMetadata(ctx, config.MCPAuthServerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover MCP auth server metadata: %w", err) } diff --git a/auth/extauth/enterprise_handler.go b/auth/extauth/enterprise_handler.go index f1382ab0..3cc5bc03 100644 --- a/auth/extauth/enterprise_handler.go +++ b/auth/extauth/enterprise_handler.go @@ -142,7 +142,7 @@ func (h *EnterpriseHandler) Authorize(ctx context.Context, req *http.Request, re } // Step 2: Discover IdP token endpoint via OIDC discovery - idpMeta, err := auth.GetAuthServerMetadataForIssuer(ctx, h.config.IdPIssuerURL, httpClient) + idpMeta, err := auth.GetAuthServerMetadata(ctx, h.config.IdPIssuerURL, httpClient) if err != nil { return fmt.Errorf("failed to discover IdP metadata: %w", err) } @@ -170,7 +170,7 @@ func (h *EnterpriseHandler) Authorize(ctx context.Context, req *http.Request, re } // Step 4: Discover MCP Server token endpoint - mcpMeta, err := auth.GetAuthServerMetadataForIssuer(ctx, h.config.MCPAuthServerURL, httpClient) + mcpMeta, err := auth.GetAuthServerMetadata(ctx, h.config.MCPAuthServerURL, httpClient) if err != nil { return fmt.Errorf("failed to discover MCP auth server metadata: %w", err) } diff --git a/auth/extauth/oidc_login.go b/auth/extauth/oidc_login.go index f71ecd57..e42cd704 100644 --- a/auth/extauth/oidc_login.go +++ b/auth/extauth/oidc_login.go @@ -153,7 +153,7 @@ func InitiateOIDCLogin( if httpClient == nil { httpClient = http.DefaultClient } - meta, err := auth.GetAuthServerMetadataForIssuer(ctx, config.IssuerURL, httpClient) + meta, err := auth.GetAuthServerMetadata(ctx, config.IssuerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) } @@ -253,7 +253,7 @@ func CompleteOIDCLogin( if httpClient == nil { httpClient = http.DefaultClient } - meta, err := auth.GetAuthServerMetadataForIssuer(ctx, config.IssuerURL, httpClient) + meta, err := auth.GetAuthServerMetadata(ctx, config.IssuerURL, httpClient) if err != nil { return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) } diff --git a/auth/shared.go b/auth/shared.go index c0b6e68b..fc0a482f 100644 --- a/auth/shared.go +++ b/auth/shared.go @@ -18,9 +18,9 @@ import ( "github.com/modelcontextprotocol/go-sdk/oauthex" ) -// GetAuthServerMetadataForIssuer fetches authorization server metadata for the given issuer URL. +// GetAuthServerMetadata fetches authorization server metadata for the given issuer URL. // It tries standard well-known endpoints (OAuth 2.0 and OIDC) and returns the first successful result. -func GetAuthServerMetadataForIssuer(ctx context.Context, issuerURL string, httpClient *http.Client) (*oauthex.AuthServerMeta, error) { +func GetAuthServerMetadata(ctx context.Context, issuerURL string, httpClient *http.Client) (*oauthex.AuthServerMeta, error) { for _, metadataURL := range authorizationServerMetadataURLs(issuerURL) { asm, err := oauthex.GetAuthServerMeta(ctx, metadataURL, issuerURL, httpClient) if err != nil { diff --git a/oauthex/oauth2.go b/oauthex/oauth2.go index fc3b45ee..39a91f35 100644 --- a/oauthex/oauth2.go +++ b/oauthex/oauth2.go @@ -63,10 +63,10 @@ func getJSON[T any](ctx context.Context, c *http.Client, url string, limit int64 return &t, nil } -// checkURLScheme ensures that its argument is a valid URL with a scheme +// CheckURLScheme ensures that its argument is a valid URL with a scheme // that prevents XSS attacks. // See #526. -func checkURLScheme(u string) error { +func CheckURLScheme(u string) error { if u == "" { return nil } @@ -94,9 +94,3 @@ func checkHTTPSOrLoopback(addr string) error { } return nil } - -// CheckURLScheme validates a URL scheme for security. -// This is exported for use by the auth package. -func CheckURLScheme(u string) error { - return checkURLScheme(u) -} From 72159959a71e188f0085e1b723bf207f5c138b2c Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Wed, 18 Mar 2026 21:12:05 +0530 Subject: [PATCH 16/19] refactor: use struct for client credentials --- oauthex/token_exchange.go | 28 +++++++++++++++++++--------- oauthex/token_exchange_test.go | 18 ++++++++++++------ 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/oauthex/token_exchange.go b/oauthex/token_exchange.go index de723d12..e34eb369 100644 --- a/oauthex/token_exchange.go +++ b/oauthex/token_exchange.go @@ -34,6 +34,17 @@ const ( GrantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange" ) +// ClientCredentials holds client authentication credentials for OAuth token requests. +type ClientCredentials struct { + // ClientID is the OAuth2 client identifier. + // REQUIRED. + ClientID string + + // ClientSecret is the OAuth2 client secret for confidential clients. + // OPTIONAL. Not required for public clients. + ClientSecret string +} + // TokenExchangeRequest represents a Token Exchange request per RFC 8693. // This is used for Enterprise Managed Authorization (SEP-990) where an MCP Client // exchanges an ID Token from an enterprise IdP for an ID-JAG that can be used @@ -99,10 +110,6 @@ type TokenExchangeResponse struct { // The tokenEndpoint parameter should be the IdP's token endpoint (typically // obtained from the IdP's authorization server metadata). // -// Client authentication must be performed by the caller by including appropriate -// credentials in the request (e.g., using Basic auth via the Authorization header, -// or including client_id and client_secret in the form data). -// // Example: // // req := &TokenExchangeRequest{ @@ -113,14 +120,14 @@ type TokenExchangeResponse struct { // SubjectToken: idToken, // SubjectTokenType: TokenTypeIDToken, // } +// clientCreds := &ClientCredentials{ClientID: "my-client", ClientSecret: "secret"} // -// resp, err := ExchangeToken(ctx, idpTokenEndpoint, req, clientID, clientSecret, nil) +// resp, err := ExchangeToken(ctx, idpTokenEndpoint, req, clientCreds, nil) func ExchangeToken( ctx context.Context, tokenEndpoint string, req *TokenExchangeRequest, - clientID string, - clientSecret string, + clientCreds *ClientCredentials, httpClient *http.Client, ) (*TokenExchangeResponse, error) { if tokenEndpoint == "" { @@ -129,6 +136,9 @@ func ExchangeToken( if req == nil { return nil, fmt.Errorf("token exchange request is required") } + if clientCreds == nil { + return nil, fmt.Errorf("client credentials are required") + } // Validate required fields per SEP-990 Section 4 if req.RequestedTokenType == "" { @@ -163,8 +173,8 @@ func ExchangeToken( // The oauth2 library's Exchange method sends an empty code, but compliant // servers should ignore it. cfg := &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, + ClientID: clientCreds.ClientID, + ClientSecret: clientCreds.ClientSecret, Endpoint: oauth2.Endpoint{ TokenURL: tokenEndpoint, AuthStyle: oauth2.AuthStyleInParams, // Use POST body auth per SEP-990 diff --git a/oauthex/token_exchange_test.go b/oauthex/token_exchange_test.go index a0403fdc..8c646cad 100644 --- a/oauthex/token_exchange_test.go +++ b/oauthex/token_exchange_test.go @@ -128,8 +128,10 @@ func TestExchangeToken(t *testing.T) { context.Background(), server.URL, req, - "test-client-id", - "test-client-secret", + &ClientCredentials{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }, server.Client(), ) @@ -171,8 +173,10 @@ func TestExchangeToken(t *testing.T) { context.Background(), server.URL, req, - "test-client-id", - "test-client-secret", + &ClientCredentials{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }, server.Client(), ) @@ -195,8 +199,10 @@ func TestExchangeToken(t *testing.T) { context.Background(), server.URL, req, - "test-client-id", - "test-client-secret", + &ClientCredentials{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }, server.Client(), ) From 45a996483be071e6db4e8c0cfc4b9473a0a182b3 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Wed, 18 Mar 2026 21:44:18 +0530 Subject: [PATCH 17/19] feat: use exported check url scheme method --- oauthex/auth_meta.go | 2 +- oauthex/dcr.go | 4 ++-- oauthex/jwt_bearer.go | 2 +- oauthex/resource_meta.go | 2 +- oauthex/token_exchange.go | 6 +++--- oauthex/url_scheme_test.go | 6 +++--- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/oauthex/auth_meta.go b/oauthex/auth_meta.go index 36210576..da21ca54 100644 --- a/oauthex/auth_meta.go +++ b/oauthex/auth_meta.go @@ -185,7 +185,7 @@ func validateAuthServerMetaURLs(asm *AuthServerMeta) error { } for _, u := range urls { - if err := checkURLScheme(u.value); err != nil { + if err := CheckURLScheme(u.value); err != nil { return fmt.Errorf("%s: %w", u.name, err) } } diff --git a/oauthex/dcr.go b/oauthex/dcr.go index 6db30255..3159a07f 100644 --- a/oauthex/dcr.go +++ b/oauthex/dcr.go @@ -237,7 +237,7 @@ func RegisterClient(ctx context.Context, registrationEndpoint string, clientMeta func validateClientRegistrationURLs(meta *ClientRegistrationMetadata) error { // Validate redirect URIs for i, uri := range meta.RedirectURIs { - if err := checkURLScheme(uri); err != nil { + if err := CheckURLScheme(uri); err != nil { return fmt.Errorf("redirect_uris[%d]: %w", i, err) } } @@ -255,7 +255,7 @@ func validateClientRegistrationURLs(meta *ClientRegistrationMetadata) error { } for _, u := range urls { - if err := checkURLScheme(u.value); err != nil { + if err := CheckURLScheme(u.value); err != nil { return fmt.Errorf("%s: %w", u.name, err) } } diff --git a/oauthex/jwt_bearer.go b/oauthex/jwt_bearer.go index 7b71a0d1..4856cc28 100644 --- a/oauthex/jwt_bearer.go +++ b/oauthex/jwt_bearer.go @@ -65,7 +65,7 @@ func ExchangeJWTBearer( return nil, fmt.Errorf("assertion is required") } // Validate URL scheme to prevent XSS attacks (see #526) - if err := checkURLScheme(tokenEndpoint); err != nil { + if err := CheckURLScheme(tokenEndpoint); err != nil { return nil, fmt.Errorf("invalid token endpoint: %w", err) } diff --git a/oauthex/resource_meta.go b/oauthex/resource_meta.go index 4680c153..43557101 100644 --- a/oauthex/resource_meta.go +++ b/oauthex/resource_meta.go @@ -112,7 +112,7 @@ func GetProtectedResourceMetadata(ctx context.Context, metadataURL, resourceURL } // Validate the authorization server URLs to prevent XSS attacks (see #526). for i, u := range prm.AuthorizationServers { - if err := checkURLScheme(u); err != nil { + if err := CheckURLScheme(u); err != nil { return nil, fmt.Errorf("authorization_servers[%d]: %v", i, err) } if err := checkHTTPSOrLoopback(u); err != nil { diff --git a/oauthex/token_exchange.go b/oauthex/token_exchange.go index e34eb369..aa58c386 100644 --- a/oauthex/token_exchange.go +++ b/oauthex/token_exchange.go @@ -158,13 +158,13 @@ func ExchangeToken( } // Validate URL schemes to prevent XSS attacks (see #526) - if err := checkURLScheme(tokenEndpoint); err != nil { + if err := CheckURLScheme(tokenEndpoint); err != nil { return nil, fmt.Errorf("invalid token endpoint: %w", err) } - if err := checkURLScheme(req.Audience); err != nil { + if err := CheckURLScheme(req.Audience); err != nil { return nil, fmt.Errorf("invalid audience: %w", err) } - if err := checkURLScheme(req.Resource); err != nil { + if err := CheckURLScheme(req.Resource); err != nil { return nil, fmt.Errorf("invalid resource: %w", err) } diff --git a/oauthex/url_scheme_test.go b/oauthex/url_scheme_test.go index 83eeb5e1..c5221a62 100644 --- a/oauthex/url_scheme_test.go +++ b/oauthex/url_scheme_test.go @@ -15,7 +15,7 @@ import ( "testing" ) -// TestCheckURLScheme tests the checkURLScheme function directly. +// TestCheckURLScheme tests the CheckURLScheme function directly. func TestCheckURLScheme(t *testing.T) { tests := []struct { name string @@ -40,9 +40,9 @@ func TestCheckURLScheme(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := checkURLScheme(tt.url) + err := CheckURLScheme(tt.url) if (err != nil) != tt.wantErr { - t.Errorf("checkURLScheme(%q): got err %v, want err %v", tt.url, err != nil, tt.wantErr) + t.Errorf("CheckURLScheme(%q): got err %v, want err %v", tt.url, err != nil, tt.wantErr) } }) } From ef22b6f9d5fed3ede0b89b900b3737dc85cd5c92 Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Wed, 18 Mar 2026 22:52:17 +0530 Subject: [PATCH 18/19] refactor: oidc login and make jwt bearer private method --- auth/enterprise_auth.go | 181 ----------------- auth/enterprise_auth_test.go | 224 -------------------- auth/extauth/enterprise_handler.go | 65 +++++- auth/extauth/oidc_login.go | 314 ++++++++--------------------- auth/extauth/oidc_login_test.go | 104 +++++----- oauthex/jwt_bearer.go | 104 ---------- oauthex/jwt_bearer_test.go | 163 --------------- 7 files changed, 195 insertions(+), 960 deletions(-) delete mode 100644 auth/enterprise_auth.go delete mode 100644 auth/enterprise_auth_test.go delete mode 100644 oauthex/jwt_bearer.go delete mode 100644 oauthex/jwt_bearer_test.go diff --git a/auth/enterprise_auth.go b/auth/enterprise_auth.go deleted file mode 100644 index ef6227d1..00000000 --- a/auth/enterprise_auth.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2026 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file implements the client-side Enterprise Managed Authorization flow -// for MCP as specified in SEP-990. - -//go:build mcp_go_client_oauth - -package auth - -import ( - "context" - "fmt" - "net/http" - - "github.com/modelcontextprotocol/go-sdk/oauthex" - "golang.org/x/oauth2" -) - -// EnterpriseAuthConfig contains configuration for Enterprise Managed Authorization -// (SEP-990). This configures both the IdP (for token exchange) and the MCP Server -// (for JWT Bearer grant). -type EnterpriseAuthConfig struct { - // IdP configuration (where the user authenticates) - IdPIssuerURL string // e.g., "https://acme.okta.com" - IdPClientID string // MCP Client's ID at the IdP - IdPClientSecret string // MCP Client's secret at the IdP - - // MCP Server configuration (the resource being accessed) - MCPAuthServerURL string // MCP Server's auth server issuer URL - MCPResourceURI string // MCP Server's resource identifier - MCPClientID string // MCP Client's ID at the MCP Server - MCPClientSecret string // MCP Client's secret at the MCP Server - MCPScopes []string // Requested scopes at the MCP Server - - // Optional HTTP client for customization - HTTPClient *http.Client -} - -// EnterpriseAuthFlow performs the complete Enterprise Managed Authorization flow: -// 1. Token Exchange: ID Token → ID-JAG at IdP -// 2. JWT Bearer: ID-JAG → Access Token at MCP Server -// -// This function takes an ID Token that was obtained via SSO (e.g., OIDC login) -// and exchanges it for an access token that can be used to call the MCP Server. -// -// There are two ways to obtain an ID Token for use with this function: -// -// Option 1: Use the OIDC login helper functions (full flow with SSO): -// -// // Step 1: Initiate OIDC login -// oidcConfig := &OIDCLoginConfig{ -// IssuerURL: "https://acme.okta.com", -// ClientID: "client-id", -// RedirectURL: "http://localhost:8080/callback", -// Scopes: []string{"openid", "profile", "email"}, -// } -// authReq, err := InitiateOIDCLogin(ctx, oidcConfig) -// if err != nil { -// log.Fatal(err) -// } -// -// // Step 2: Direct user to authReq.AuthURL for authentication -// fmt.Printf("Visit: %s\n", authReq.AuthURL) -// -// // Step 3: After redirect, complete login with authorization code -// tokens, err := CompleteOIDCLogin(ctx, oidcConfig, authCode, authReq.CodeVerifier) -// if err != nil { -// log.Fatal(err) -// } -// -// // Step 4: Use ID token for enterprise auth -// enterpriseConfig := &EnterpriseAuthConfig{ -// IdPIssuerURL: "https://acme.okta.com", -// IdPClientID: "client-id-at-idp", -// IdPClientSecret: "secret-at-idp", -// MCPAuthServerURL: "https://auth.mcpserver.example", -// MCPResourceURI: "https://mcp.mcpserver.example", -// MCPClientID: "client-id-at-mcp", -// MCPClientSecret: "secret-at-mcp", -// MCPScopes: []string{"read", "write"}, -// } -// accessToken, err := EnterpriseAuthFlow(ctx, enterpriseConfig, tokens.IDToken) -// if err != nil { -// log.Fatal(err) -// } -// -// Option 2: Bring your own ID Token (if you already have one): -// -// config := &EnterpriseAuthConfig{ -// IdPIssuerURL: "https://acme.okta.com", -// IdPClientID: "client-id-at-idp", -// IdPClientSecret: "secret-at-idp", -// MCPAuthServerURL: "https://auth.mcpserver.example", -// MCPResourceURI: "https://mcp.mcpserver.example", -// MCPClientID: "client-id-at-mcp", -// MCPClientSecret: "secret-at-mcp", -// MCPScopes: []string{"read", "write"}, -// } -// -// // If you already obtained an ID token through your own means -// accessToken, err := EnterpriseAuthFlow(ctx, config, myIDToken) -// if err != nil { -// log.Fatal(err) -// } -// -// // Use accessToken to call MCP Server APIs -func EnterpriseAuthFlow( - ctx context.Context, - config *EnterpriseAuthConfig, - idToken string, -) (*oauth2.Token, error) { - if config == nil { - return nil, fmt.Errorf("config is required") - } - if idToken == "" { - return nil, fmt.Errorf("idToken is required") - } - // Validate configuration - if config.IdPIssuerURL == "" { - return nil, fmt.Errorf("IdPIssuerURL is required") - } - if config.MCPAuthServerURL == "" { - return nil, fmt.Errorf("MCPAuthServerURL is required") - } - if config.MCPResourceURI == "" { - return nil, fmt.Errorf("MCPResourceURI is required") - } - httpClient := config.HTTPClient - if httpClient == nil { - httpClient = http.DefaultClient - } - - // Step 1: Discover IdP token endpoint via OIDC discovery - idpMeta, err := GetAuthServerMetadata(ctx, config.IdPIssuerURL, httpClient) - if err != nil { - return nil, fmt.Errorf("failed to discover IdP metadata: %w", err) - } - - // Step 2: Token Exchange (ID Token → ID-JAG) - tokenExchangeReq := &oauthex.TokenExchangeRequest{ - RequestedTokenType: oauthex.TokenTypeIDJAG, - Audience: config.MCPAuthServerURL, - Resource: config.MCPResourceURI, - Scope: config.MCPScopes, - SubjectToken: idToken, - SubjectTokenType: oauthex.TokenTypeIDToken, - } - - tokenExchangeResp, err := oauthex.ExchangeToken( - ctx, - idpMeta.TokenEndpoint, - tokenExchangeReq, - config.IdPClientID, - config.IdPClientSecret, - httpClient, - ) - if err != nil { - return nil, fmt.Errorf("token exchange failed: %w", err) - } - - // Step 3: JWT Bearer Grant (ID-JAG → Access Token) - mcpMeta, err := GetAuthServerMetadata(ctx, config.MCPAuthServerURL, httpClient) - if err != nil { - return nil, fmt.Errorf("failed to discover MCP auth server metadata: %w", err) - } - - accessToken, err := oauthex.ExchangeJWTBearer( - ctx, - mcpMeta.TokenEndpoint, - tokenExchangeResp.AccessToken, - config.MCPClientID, - config.MCPClientSecret, - httpClient, - ) - if err != nil { - return nil, fmt.Errorf("JWT bearer grant failed: %w", err) - } - return accessToken, nil -} diff --git a/auth/enterprise_auth_test.go b/auth/enterprise_auth_test.go deleted file mode 100644 index 61274c0c..00000000 --- a/auth/enterprise_auth_test.go +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright 2026 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -//go:build mcp_go_client_oauth - -package auth - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/modelcontextprotocol/go-sdk/oauthex" -) - -// TestEnterpriseAuthFlow tests the complete enterprise auth flow. -func TestEnterpriseAuthFlow(t *testing.T) { - // Create test servers for IdP and MCP Server - idpServer := createMockIdPServer(t) - defer idpServer.Close() - mcpServer := createMockMCPServer(t) - defer mcpServer.Close() - // Create a test ID Token - idToken := createTestIDToken() - // Configure enterprise auth - config := &EnterpriseAuthConfig{ - IdPIssuerURL: idpServer.URL, - IdPClientID: "test-idp-client", - IdPClientSecret: "test-idp-secret", - MCPAuthServerURL: mcpServer.URL, - MCPResourceURI: "https://mcp.example.com", - MCPClientID: "test-mcp-client", - MCPClientSecret: "test-mcp-secret", - MCPScopes: []string{"read", "write"}, - HTTPClient: idpServer.Client(), - } - // Test successful flow - t.Run("successful flow", func(t *testing.T) { - token, err := EnterpriseAuthFlow(context.Background(), config, idToken) - if err != nil { - t.Fatalf("EnterpriseAuthFlow failed: %v", err) - } - if token.AccessToken != "mcp-access-token" { - t.Errorf("expected access token 'mcp-access-token', got '%s'", token.AccessToken) - } - if token.TokenType != "Bearer" { - t.Errorf("expected token type 'Bearer', got '%s'", token.TokenType) - } - }) - // Test missing config - t.Run("nil config", func(t *testing.T) { - _, err := EnterpriseAuthFlow(context.Background(), nil, idToken) - if err == nil { - t.Error("expected error for nil config, got nil") - } - }) - // Test missing ID token - t.Run("empty ID token", func(t *testing.T) { - _, err := EnterpriseAuthFlow(context.Background(), config, "") - if err == nil { - t.Error("expected error for empty ID token, got nil") - } - }) - // Test missing IdP issuer - t.Run("missing IdP issuer", func(t *testing.T) { - badConfig := *config - badConfig.IdPIssuerURL = "" - _, err := EnterpriseAuthFlow(context.Background(), &badConfig, idToken) - if err == nil { - t.Error("expected error for missing IdP issuer, got nil") - } - }) -} - -// createMockIdPServer creates a mock IdP server for testing. -func createMockIdPServer(t *testing.T) *httptest.Server { - var serverURL string - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Handle OIDC discovery endpoint - if r.URL.Path == "/.well-known/openid-configuration" { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "issuer": serverURL, // Use actual server URL - "token_endpoint": serverURL + "/oauth2/v1/token", - "jwks_uri": serverURL + "/.well-known/jwks.json", - "code_challenge_methods_supported": []string{"S256"}, - "grant_types_supported": []string{ - "authorization_code", - "urn:ietf:params:oauth:grant-type:token-exchange", - }, - "response_types_supported": []string{"code"}, - }) - return - } - - // Handle token exchange endpoint - if r.URL.Path != "/oauth2/v1/token" { - http.NotFound(w, r) - return - } - - if err := r.ParseForm(); err != nil { - http.Error(w, "failed to parse form", http.StatusBadRequest) - return - } - grantType := r.FormValue("grant_type") - if grantType != oauthex.GrantTypeTokenExchange { - http.Error(w, "invalid grant type", http.StatusBadRequest) - return - } - - // Return a mock ID-JAG - now := time.Now().Unix() - header := map[string]string{"typ": "oauth-id-jag+jwt", "alg": "RS256"} - claims := map[string]interface{}{ - "iss": "https://test.okta.com", - "sub": "test-user", - "aud": r.FormValue("audience"), - "resource": r.FormValue("resource"), - "client_id": r.FormValue("client_id"), - "jti": "test-jti", - "exp": now + 300, - "iat": now, - "scope": r.FormValue("scope"), - } - headerJSON, _ := json.Marshal(header) - claimsJSON, _ := json.Marshal(claims) - headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) - claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) - mockIDJAG := fmt.Sprintf("%s.%s.mock-signature", headerB64, claimsB64) - - resp := oauthex.TokenExchangeResponse{ - IssuedTokenType: oauthex.TokenTypeIDJAG, - AccessToken: mockIDJAG, - TokenType: "N_A", - ExpiresIn: 300, - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) - })) - serverURL = server.URL // Capture server URL for discovery response - return server -} - -// createMockMCPServer creates a mock MCP Server for testing. -func createMockMCPServer(t *testing.T) *httptest.Server { - var serverURL string - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Handle OIDC discovery endpoint - if r.URL.Path == "/.well-known/openid-configuration" { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "issuer": serverURL, // Use actual server URL - "token_endpoint": serverURL + "/v1/token", - "jwks_uri": serverURL + "/.well-known/jwks.json", - "code_challenge_methods_supported": []string{"S256"}, - "grant_types_supported": []string{ - "urn:ietf:params:oauth:grant-type:jwt-bearer", - }, - }) - return - } - - // Handle JWT Bearer endpoint - if r.URL.Path != "/v1/token" { - http.NotFound(w, r) - return - } - - if err := r.ParseForm(); err != nil { - http.Error(w, "failed to parse form", http.StatusBadRequest) - return - } - grantType := r.FormValue("grant_type") - if grantType != oauthex.GrantTypeJWTBearer { - http.Error(w, "invalid grant type", http.StatusBadRequest) - return - } - - resp := struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in,omitempty"` - Scope string `json:"scope,omitempty"` - RefreshToken string `json:"refresh_token,omitempty"` - }{ - AccessToken: "mcp-access-token", - TokenType: "Bearer", - ExpiresIn: 3600, - Scope: "read write", - RefreshToken: "mcp-refresh-token", - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) - })) - serverURL = server.URL // Capture server URL for discovery response - return server -} - -// createTestIDToken creates a mock ID Token for testing. -func createTestIDToken() string { - now := time.Now().Unix() - header := map[string]string{"typ": "JWT", "alg": "RS256"} - claims := map[string]interface{}{ - "iss": "https://test.okta.com", - "sub": "test-user", - "aud": "test-client", - "exp": now + 3600, - "iat": now, - "email": "test@example.com", - } - headerJSON, _ := json.Marshal(header) - claimsJSON, _ := json.Marshal(claims) - headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) - claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON) - - return fmt.Sprintf("%s.%s.mock-signature", headerB64, claimsB64) -} diff --git a/auth/extauth/enterprise_handler.go b/auth/extauth/enterprise_handler.go index 3cc5bc03..11c7c8e0 100644 --- a/auth/extauth/enterprise_handler.go +++ b/auth/extauth/enterprise_handler.go @@ -21,10 +21,19 @@ import ( "golang.org/x/oauth2" ) +// grantTypeJWTBearer is the grant type for RFC 7523 JWT Bearer authorization grant. +const grantTypeJWTBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer" + +// IDTokenResult contains the ID token obtained from OIDC login. +type IDTokenResult struct { + // Token is the OpenID Connect ID Token (JWT). + Token string +} + // IDTokenFetcher is called to obtain an ID Token from the enterprise IdP. // This is typically done via OIDC login flow where the user authenticates // with their enterprise identity provider. -type IDTokenFetcher func(ctx context.Context) (string, error) +type IDTokenFetcher func(ctx context.Context) (*IDTokenResult, error) // EnterpriseHandlerConfig is the configuration for [EnterpriseHandler]. type EnterpriseHandlerConfig struct { @@ -136,7 +145,7 @@ func (h *EnterpriseHandler) Authorize(ctx context.Context, req *http.Request, re } // Step 1: Get ID Token via the configured fetcher (e.g., OIDC login) - idToken, err := h.config.IDTokenFetcher(ctx) + idTokenResult, err := h.config.IDTokenFetcher(ctx) if err != nil { return fmt.Errorf("failed to obtain ID token: %w", err) } @@ -153,7 +162,7 @@ func (h *EnterpriseHandler) Authorize(ctx context.Context, req *http.Request, re Audience: h.config.MCPAuthServerURL, Resource: h.config.MCPResourceURI, Scope: h.config.MCPScopes, - SubjectToken: idToken, + SubjectToken: idTokenResult.Token, SubjectTokenType: oauthex.TokenTypeIDToken, } @@ -161,8 +170,10 @@ func (h *EnterpriseHandler) Authorize(ctx context.Context, req *http.Request, re ctx, idpMeta.TokenEndpoint, tokenExchangeReq, - h.config.IdPClientID, - h.config.IdPClientSecret, + &oauthex.ClientCredentials{ + ClientID: h.config.IdPClientID, + ClientSecret: h.config.IdPClientSecret, + }, httpClient, ) if err != nil { @@ -176,12 +187,14 @@ func (h *EnterpriseHandler) Authorize(ctx context.Context, req *http.Request, re } // Step 5: JWT Bearer Grant (ID-JAG → Access Token) - accessToken, err := oauthex.ExchangeJWTBearer( + accessToken, err := exchangeJWTBearer( ctx, mcpMeta.TokenEndpoint, tokenExchangeResp.AccessToken, - h.config.MCPClientID, - h.config.MCPClientSecret, + &oauthex.ClientCredentials{ + ClientID: h.config.MCPClientID, + ClientSecret: h.config.MCPClientSecret, + }, httpClient, ) if err != nil { @@ -192,3 +205,39 @@ func (h *EnterpriseHandler) Authorize(ctx context.Context, req *http.Request, re h.tokenSource = oauth2.StaticTokenSource(accessToken) return nil } + +// exchangeJWTBearer exchanges an Identity Assertion JWT Authorization Grant (ID-JAG) +// for an access token using JWT Bearer Grant per RFC 7523. +func exchangeJWTBearer( + ctx context.Context, + tokenEndpoint string, + assertion string, + clientCreds *oauthex.ClientCredentials, + httpClient *http.Client, +) (*oauth2.Token, error) { + cfg := &oauth2.Config{ + ClientID: clientCreds.ClientID, + ClientSecret: clientCreds.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: tokenEndpoint, + AuthStyle: oauth2.AuthStyleInParams, + }, + } + + if httpClient == nil { + httpClient = http.DefaultClient + } + ctxWithClient := context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + token, err := cfg.Exchange( + ctxWithClient, + "", + oauth2.SetAuthURLParam("grant_type", grantTypeJWTBearer), + oauth2.SetAuthURLParam("assertion", assertion), + ) + if err != nil { + return nil, fmt.Errorf("JWT bearer grant request failed: %w", err) + } + + return token, nil +} diff --git a/auth/extauth/oidc_login.go b/auth/extauth/oidc_login.go index e42cd704..e738de40 100644 --- a/auth/extauth/oidc_login.go +++ b/auth/extauth/oidc_login.go @@ -22,8 +22,8 @@ import ( ) // OIDCLoginConfig configures the OIDC Authorization Code flow for obtaining -// an ID Token. This is an OPTIONAL step before calling EnterpriseAuthFlow. -// Users can alternatively obtain ID tokens through their own methods. +// an ID Token. This is used with [PerformOIDCLogin] to authenticate users +// with an enterprise IdP before calling the Enterprise Managed Authorization flow. type OIDCLoginConfig struct { // IssuerURL is the IdP's issuer URL (e.g., "https://acme.okta.com"). // REQUIRED. @@ -53,25 +53,10 @@ type OIDCLoginConfig struct { HTTPClient *http.Client } -// OIDCAuthorizationRequest represents the result of initiating an OIDC -// authorization code flow. Users must direct the end-user to AuthURL -// to complete authentication. -type OIDCAuthorizationRequest struct { - // AuthURL is the URL the user should visit to authenticate. - // This URL includes the authorization request parameters. - AuthURL string - // State is the OAuth2 state parameter for CSRF protection. - // Users MUST validate that the state returned from the IdP matches this value. - State string - // CodeVerifier is the PKCE code verifier for secure authorization code exchange. - // This must be provided to CompleteOIDCLogin along with the authorization code. - CodeVerifier string -} - // OIDCTokenResponse contains the tokens returned from a successful OIDC login. type OIDCTokenResponse struct { // IDToken is the OpenID Connect ID Token (JWT). - // This can be passed to EnterpriseAuthFlow for token exchange. + // This can be passed to EnterpriseHandler's IDTokenFetcher. IDToken string // AccessToken is the OAuth2 access token (if issued by IdP). // This is typically not needed for SEP-990, but may be useful for other IdP APIs. @@ -84,53 +69,86 @@ type OIDCTokenResponse struct { ExpiresAt int64 } -// InitiateOIDCLogin initiates an OIDC Authorization Code flow with PKCE. -// This is the first step for users who want to use SSO to obtain an ID token. -// -// The returned AuthURL should be presented to the user (e.g., opened in a browser). -// After the user authenticates, the IdP will redirect to the RedirectURL with -// an authorization code and state parameter. -// -// Example: -// -// config := &OIDCLoginConfig{ -// IssuerURL: "https://acme.okta.com", -// ClientID: "client-id", -// RedirectURL: "http://localhost:8080/callback", -// Scopes: []string{"openid", "profile", "email"}, -// } +// AuthorizationCodeFetcher is a callback function that handles directing the user +// to the authorization URL and returning the authorization result. // -// authReq, err := InitiateOIDCLogin(ctx, config) -// if err != nil { -// log.Fatal(err) -// } +// Implementations should: +// 1. Direct the user to args.URL (e.g., open in browser) +// 2. Wait for the IdP redirect to the configured RedirectURL +// 3. Extract the code and state from the redirect query parameters +// 4. Return them in auth.AuthorizationResult // -// // Direct user to authReq.AuthURL -// fmt.Printf("Visit this URL to login: %s\n", authReq.AuthURL) +// The implementation MUST verify that the returned state matches the state +// sent in the authorization request (available via parsing args.URL). +type AuthorizationCodeFetcher func(ctx context.Context, args auth.AuthorizationArgs) (*auth.AuthorizationResult, error) + +// PerformOIDCLogin performs the complete OIDC Authorization Code flow with PKCE +// in a single function call. This is the recommended approach for obtaining an +// ID Token for use with [EnterpriseHandler]. // -// // After user completes login, IdP redirects to RedirectURL with code & state -// // Extract code and state from the redirect, then call CompleteOIDCLogin -func InitiateOIDCLogin( +// The authCodeFetcher callback handles the user interaction: +// - Directing the user to the IdP login page +// - Waiting for the redirect with the authorization code +// - Validating CSRF state and returning the result +func PerformOIDCLogin( + ctx context.Context, + config *OIDCLoginConfig, + authCodeFetcher AuthorizationCodeFetcher, +) (*OIDCTokenResponse, error) { + if authCodeFetcher == nil { + return nil, fmt.Errorf("authCodeFetcher is required") + } + + authReq, oauth2Config, err := initiateOIDCLogin(ctx, config) + if err != nil { + return nil, fmt.Errorf("failed to initiate OIDC login: %w", err) + } + + authResult, err := authCodeFetcher(ctx, auth.AuthorizationArgs{URL: authReq.authURL}) + if err != nil { + return nil, fmt.Errorf("failed to fetch authorization code: %w", err) + } + + if authResult.State != authReq.state { + return nil, fmt.Errorf("state mismatch: expected %q, got %q", authReq.state, authResult.State) + } + + tokens, err := completeOIDCLogin(ctx, config, oauth2Config, authResult.Code, authReq.codeVerifier) + if err != nil { + return nil, fmt.Errorf("failed to complete OIDC login: %w", err) + } + + return tokens, nil +} + +// oidcAuthorizationRequest holds internal state for OIDC authorization. +type oidcAuthorizationRequest struct { + authURL string + state string + codeVerifier string +} + +// initiateOIDCLogin initiates an OIDC Authorization Code flow with PKCE. +func initiateOIDCLogin( ctx context.Context, config *OIDCLoginConfig, -) (*OIDCAuthorizationRequest, error) { +) (*oidcAuthorizationRequest, *oauth2.Config, error) { if config == nil { - return nil, fmt.Errorf("config is required") + return nil, nil, fmt.Errorf("config is required") } - // Validate required fields if config.IssuerURL == "" { - return nil, fmt.Errorf("IssuerURL is required") + return nil, nil, fmt.Errorf("IssuerURL is required") } if config.ClientID == "" { - return nil, fmt.Errorf("ClientID is required") + return nil, nil, fmt.Errorf("ClientID is required") } if config.RedirectURL == "" { - return nil, fmt.Errorf("RedirectURL is required") + return nil, nil, fmt.Errorf("RedirectURL is required") } if len(config.Scopes) == 0 { - return nil, fmt.Errorf("Scopes is required (must include 'openid')") + return nil, nil, fmt.Errorf("Scopes is required (must include 'openid')") } - // Validate that "openid" scope is present (required for OIDC) + hasOpenID := false for _, scope := range config.Scopes { if scope == "openid" { @@ -139,35 +157,31 @@ func InitiateOIDCLogin( } } if !hasOpenID { - return nil, fmt.Errorf("Scopes must include 'openid' for OIDC") + return nil, nil, fmt.Errorf("Scopes must include 'openid' for OIDC") } - // Validate URL schemes to prevent XSS attacks + if err := oauthex.CheckURLScheme(config.IssuerURL); err != nil { - return nil, fmt.Errorf("invalid IssuerURL: %w", err) + return nil, nil, fmt.Errorf("invalid IssuerURL: %w", err) } if err := oauthex.CheckURLScheme(config.RedirectURL); err != nil { - return nil, fmt.Errorf("invalid RedirectURL: %w", err) + return nil, nil, fmt.Errorf("invalid RedirectURL: %w", err) } - // Discover OIDC endpoints via .well-known + httpClient := config.HTTPClient if httpClient == nil { httpClient = http.DefaultClient } meta, err := auth.GetAuthServerMetadata(ctx, config.IssuerURL, httpClient) if err != nil { - return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) + return nil, nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) } if meta.AuthorizationEndpoint == "" { - return nil, fmt.Errorf("authorization_endpoint not found in OIDC metadata") + return nil, nil, fmt.Errorf("authorization_endpoint not found in OIDC metadata") } - // Generate PKCE code verifier (RFC 7636) codeVerifier := oauth2.GenerateVerifier() - - // Generate state for CSRF protection (RFC 6749 Section 10.12) state := rand.Text() - // Build oauth2.Config to use standard library's AuthCodeURL. oauth2Config := &oauth2.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, @@ -179,8 +193,6 @@ func InitiateOIDCLogin( }, } - // Build authorization URL using oauth2.Config.AuthCodeURL with PKCE. - // S256ChallengeOption automatically computes the S256 challenge from the verifier. authURLOpts := []oauth2.AuthCodeOption{ oauth2.S256ChallengeOption(codeVerifier), } @@ -189,95 +201,35 @@ func InitiateOIDCLogin( } authURL := oauth2Config.AuthCodeURL(state, authURLOpts...) - return &OIDCAuthorizationRequest{ - AuthURL: authURL, - State: state, - CodeVerifier: codeVerifier, - }, nil + return &oidcAuthorizationRequest{ + authURL: authURL, + state: state, + codeVerifier: codeVerifier, + }, oauth2Config, nil } -// CompleteOIDCLogin completes the OIDC Authorization Code flow by exchanging -// the authorization code for tokens. This is the second step after the user -// has authenticated and been redirected back to the application. -// -// The authCode and returnedState parameters should come from the redirect URL -// query parameters. The state MUST match the state from InitiateOIDCLogin -// for CSRF protection. -// -// Example: -// -// // In your redirect handler (e.g., http://localhost:8080/callback) -// authCode := r.URL.Query().Get("code") -// returnedState := r.URL.Query().Get("state") -// -// // Validate state matches what we sent -// if returnedState != authReq.State { -// log.Fatal("State mismatch - possible CSRF attack") -// } -// -// // Exchange code for tokens -// tokens, err := CompleteOIDCLogin(ctx, config, authCode, authReq.CodeVerifier) -// if err != nil { -// log.Fatal(err) -// } -// -// // Now use tokens.IDToken with EnterpriseAuthFlow -// accessToken, err := EnterpriseAuthFlow(ctx, enterpriseConfig, tokens.IDToken) -func CompleteOIDCLogin( +// completeOIDCLogin completes the OIDC Authorization Code flow by exchanging +// the authorization code for tokens. +func completeOIDCLogin( ctx context.Context, config *OIDCLoginConfig, + oauth2Config *oauth2.Config, authCode string, codeVerifier string, ) (*OIDCTokenResponse, error) { - if config == nil { - return nil, fmt.Errorf("config is required") - } if authCode == "" { return nil, fmt.Errorf("authCode is required") } if codeVerifier == "" { return nil, fmt.Errorf("codeVerifier is required") } - // Validate required fields - if config.IssuerURL == "" { - return nil, fmt.Errorf("IssuerURL is required") - } - if config.ClientID == "" { - return nil, fmt.Errorf("ClientID is required") - } - if config.RedirectURL == "" { - return nil, fmt.Errorf("RedirectURL is required") - } - // Discover token endpoint + httpClient := config.HTTPClient if httpClient == nil { httpClient = http.DefaultClient } - meta, err := auth.GetAuthServerMetadata(ctx, config.IssuerURL, httpClient) - if err != nil { - return nil, fmt.Errorf("failed to discover OIDC metadata: %w", err) - } - if meta.TokenEndpoint == "" { - return nil, fmt.Errorf("token_endpoint not found in OIDC metadata") - } - - // Build oauth2.Config for token exchange. - oauth2Config := &oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - RedirectURL: config.RedirectURL, - Scopes: config.Scopes, - Endpoint: oauth2.Endpoint{ - AuthURL: meta.AuthorizationEndpoint, - TokenURL: meta.TokenEndpoint, - }, - } - - // Use custom HTTP client if provided ctxWithClient := context.WithValue(ctx, oauth2.HTTPClient, httpClient) - // Exchange authorization code for tokens using oauth2.Config.Exchange. - // VerifierOption provides the PKCE code_verifier for the token request. oauth2Token, err := oauth2Config.Exchange( ctxWithClient, authCode, @@ -287,8 +239,6 @@ func CompleteOIDCLogin( return nil, fmt.Errorf("token exchange failed: %w", err) } - // Extract ID Token from response. - // oauth2.Token.Extra() provides access to additional fields like id_token. idToken, ok := oauth2Token.Extra("id_token").(string) if !ok || idToken == "" { return nil, fmt.Errorf("id_token not found in token response") @@ -301,97 +251,3 @@ func CompleteOIDCLogin( ExpiresAt: oauth2Token.Expiry.Unix(), }, nil } - -// OIDCAuthorizationResult contains the authorization code and state returned -// from the IdP after user authentication. -type OIDCAuthorizationResult struct { - // Code is the authorization code returned by the IdP. - Code string - // State is the state parameter returned by the IdP. - // This MUST match the state sent in the authorization request. - State string -} - -// AuthorizationCodeFetcher is a callback function that handles directing the user -// to the authorization URL and returning the authorization result. -// -// Implementations should: -// 1. Direct the user to authURL (e.g., open in browser) -// 2. Wait for the IdP redirect to the configured RedirectURL -// 3. Extract the code and state from the redirect query parameters -// 4. Return them in OIDCAuthorizationResult -// -// The expectedState parameter is provided for CSRF validation. Implementations -// MUST verify that the returned state matches expectedState. -type AuthorizationCodeFetcher func(ctx context.Context, authURL string, expectedState string) (*OIDCAuthorizationResult, error) - -// PerformOIDCLogin performs the complete OIDC Authorization Code flow with PKCE -// in a single function call. This is the recommended approach for most use cases. -// -// The authCodeFetcher callback handles the user interaction: -// - Directing the user to the IdP login page -// - Waiting for the redirect with the authorization code -// - Validating CSRF state and returning the result -// -// Example: -// -// config := &OIDCLoginConfig{ -// IssuerURL: "https://acme.okta.com", -// ClientID: "client-id", -// RedirectURL: "http://localhost:8080/callback", -// Scopes: []string{"openid", "profile", "email"}, -// } -// -// tokens, err := PerformOIDCLogin(ctx, config, func(ctx context.Context, authURL, expectedState string) (*OIDCAuthorizationResult, error) { -// // Open browser for user -// fmt.Printf("Please visit: %s\n", authURL) -// -// // Start local server to receive callback -// code, state := waitForCallback(ctx) -// -// // Validate state for CSRF protection -// if state != expectedState { -// return nil, fmt.Errorf("state mismatch") -// } -// -// return &OIDCAuthorizationResult{Code: code, State: state}, nil -// }) -// if err != nil { -// log.Fatal(err) -// } -// -// // Use tokens.IDToken with EnterpriseHandler -func PerformOIDCLogin( - ctx context.Context, - config *OIDCLoginConfig, - authCodeFetcher AuthorizationCodeFetcher, -) (*OIDCTokenResponse, error) { - if authCodeFetcher == nil { - return nil, fmt.Errorf("authCodeFetcher is required") - } - - // Step 1: Initiate the OIDC flow to get the authorization URL - authReq, err := InitiateOIDCLogin(ctx, config) - if err != nil { - return nil, fmt.Errorf("failed to initiate OIDC login: %w", err) - } - - // Step 2: Use callback to get authorization code from user interaction - authResult, err := authCodeFetcher(ctx, authReq.AuthURL, authReq.State) - if err != nil { - return nil, fmt.Errorf("failed to fetch authorization code: %w", err) - } - - // Step 3: Validate state for CSRF protection - if authResult.State != authReq.State { - return nil, fmt.Errorf("state mismatch: expected %q, got %q", authReq.State, authResult.State) - } - - // Step 4: Exchange authorization code for tokens - tokens, err := CompleteOIDCLogin(ctx, config, authResult.Code, authReq.CodeVerifier) - if err != nil { - return nil, fmt.Errorf("failed to complete OIDC login: %w", err) - } - - return tokens, nil -} diff --git a/auth/extauth/oidc_login_test.go b/auth/extauth/oidc_login_test.go index a82a8456..5645b457 100644 --- a/auth/extauth/oidc_login_test.go +++ b/auth/extauth/oidc_login_test.go @@ -17,6 +17,8 @@ import ( "strings" "testing" "time" + + "github.com/modelcontextprotocol/go-sdk/auth" ) // TestInitiateOIDCLogin tests the OIDC authorization request generation. @@ -32,18 +34,18 @@ func TestInitiateOIDCLogin(t *testing.T) { HTTPClient: idpServer.Client(), } t.Run("successful initiation", func(t *testing.T) { - authReq, err := InitiateOIDCLogin(context.Background(), config) + authReq, _, err := initiateOIDCLogin(context.Background(), config) if err != nil { - t.Fatalf("InitiateOIDCLogin failed: %v", err) + t.Fatalf("initiateOIDCLogin failed: %v", err) } - // Validate AuthURL - if authReq.AuthURL == "" { - t.Error("AuthURL is empty") + // Validate authURL + if authReq.authURL == "" { + t.Error("authURL is empty") } // Parse and validate URL parameters - u, err := url.Parse(authReq.AuthURL) + u, err := url.Parse(authReq.authURL) if err != nil { - t.Fatalf("Failed to parse AuthURL: %v", err) + t.Fatalf("Failed to parse authURL: %v", err) } q := u.Query() if q.Get("response_type") != "code" { @@ -62,15 +64,15 @@ func TestInitiateOIDCLogin(t *testing.T) { t.Errorf("expected code_challenge_method 'S256', got '%s'", q.Get("code_challenge_method")) } // Validate state is generated - if authReq.State == "" { - t.Error("State is empty") + if authReq.state == "" { + t.Error("state is empty") } - if q.Get("state") != authReq.State { + if q.Get("state") != authReq.state { t.Errorf("state in URL doesn't match returned state") } // Validate PKCE parameters - if authReq.CodeVerifier == "" { - t.Error("CodeVerifier is empty") + if authReq.codeVerifier == "" { + t.Error("codeVerifier is empty") } if q.Get("code_challenge") == "" { t.Error("code_challenge is empty") @@ -79,13 +81,13 @@ func TestInitiateOIDCLogin(t *testing.T) { t.Run("with login_hint", func(t *testing.T) { configWithHint := *config configWithHint.LoginHint = "user@example.com" - authReq, err := InitiateOIDCLogin(context.Background(), &configWithHint) + authReq, _, err := initiateOIDCLogin(context.Background(), &configWithHint) if err != nil { - t.Fatalf("InitiateOIDCLogin failed: %v", err) + t.Fatalf("initiateOIDCLogin failed: %v", err) } - u, err := url.Parse(authReq.AuthURL) + u, err := url.Parse(authReq.authURL) if err != nil { - t.Fatalf("Failed to parse AuthURL: %v", err) + t.Fatalf("Failed to parse authURL: %v", err) } q := u.Query() if q.Get("login_hint") != "user@example.com" { @@ -93,13 +95,13 @@ func TestInitiateOIDCLogin(t *testing.T) { } }) t.Run("without login_hint", func(t *testing.T) { - authReq, err := InitiateOIDCLogin(context.Background(), config) + authReq, _, err := initiateOIDCLogin(context.Background(), config) if err != nil { - t.Fatalf("InitiateOIDCLogin failed: %v", err) + t.Fatalf("initiateOIDCLogin failed: %v", err) } - u, err := url.Parse(authReq.AuthURL) + u, err := url.Parse(authReq.authURL) if err != nil { - t.Fatalf("Failed to parse AuthURL: %v", err) + t.Fatalf("Failed to parse authURL: %v", err) } q := u.Query() if q.Has("login_hint") { @@ -107,7 +109,7 @@ func TestInitiateOIDCLogin(t *testing.T) { } }) t.Run("nil config", func(t *testing.T) { - _, err := InitiateOIDCLogin(context.Background(), nil) + _, _, err := initiateOIDCLogin(context.Background(), nil) if err == nil { t.Error("expected error for nil config, got nil") } @@ -115,7 +117,7 @@ func TestInitiateOIDCLogin(t *testing.T) { t.Run("missing openid scope", func(t *testing.T) { badConfig := *config badConfig.Scopes = []string{"profile", "email"} // Missing "openid" - _, err := InitiateOIDCLogin(context.Background(), &badConfig) + _, _, err := initiateOIDCLogin(context.Background(), &badConfig) if err == nil { t.Error("expected error for missing openid scope, got nil") } @@ -154,7 +156,7 @@ func TestInitiateOIDCLogin(t *testing.T) { t.Run(tt.name, func(t *testing.T) { badConfig := *config tt.mutate(&badConfig) - _, err := InitiateOIDCLogin(context.Background(), &badConfig) + _, _, err := initiateOIDCLogin(context.Background(), &badConfig) if err == nil { t.Error("expected error, got nil") } @@ -180,14 +182,21 @@ func TestCompleteOIDCLogin(t *testing.T) { HTTPClient: idpServer.Client(), } t.Run("successful code exchange", func(t *testing.T) { - tokens, err := CompleteOIDCLogin( + // First initiate to get oauth2Config + _, oauth2Config, err := initiateOIDCLogin(context.Background(), config) + if err != nil { + t.Fatalf("initiateOIDCLogin failed: %v", err) + } + + tokens, err := completeOIDCLogin( context.Background(), config, + oauth2Config, "test-auth-code", "test-code-verifier", ) if err != nil { - t.Fatalf("CompleteOIDCLogin failed: %v", err) + t.Fatalf("completeOIDCLogin failed: %v", err) } // Validate tokens if tokens.IDToken == "" { @@ -203,18 +212,9 @@ func TestCompleteOIDCLogin(t *testing.T) { t.Error("ExpiresAt is zero") } }) - t.Run("nil config", func(t *testing.T) { - _, err := CompleteOIDCLogin( - context.Background(), - nil, - "test-auth-code", - "test-code-verifier", - ) - if err == nil { - t.Error("expected error for nil config, got nil") - } - }) t.Run("missing parameters", func(t *testing.T) { + _, oauth2Config, _ := initiateOIDCLogin(context.Background(), config) + tests := []struct { name string authCode string @@ -236,9 +236,10 @@ func TestCompleteOIDCLogin(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := CompleteOIDCLogin( + _, err := completeOIDCLogin( context.Background(), config, + oauth2Config, tt.authCode, tt.codeVerifier, ) @@ -267,23 +268,24 @@ func TestOIDCLoginE2E(t *testing.T) { HTTPClient: idpServer.Client(), } // Step 1: Initiate login - authReq, err := InitiateOIDCLogin(context.Background(), config) + authReq, oauth2Config, err := initiateOIDCLogin(context.Background(), config) if err != nil { - t.Fatalf("InitiateOIDCLogin failed: %v", err) + t.Fatalf("initiateOIDCLogin failed: %v", err) } // Step 2: Simulate user authentication and redirect - // (In real flow, user would visit authReq.AuthURL and IdP would redirect back) + // (In real flow, user would visit authReq.authURL and IdP would redirect back) // Here we just use a mock authorization code mockAuthCode := "mock-authorization-code" // Step 3: Complete login with authorization code - tokens, err := CompleteOIDCLogin( + tokens, err := completeOIDCLogin( context.Background(), config, + oauth2Config, mockAuthCode, - authReq.CodeVerifier, + authReq.codeVerifier, ) if err != nil { - t.Fatalf("CompleteOIDCLogin failed: %v", err) + t.Fatalf("completeOIDCLogin failed: %v", err) } // Validate we got an ID token if tokens.IDToken == "" { @@ -296,7 +298,7 @@ func TestOIDCLoginE2E(t *testing.T) { } } -// createMockOIDCServer creates a mock OIDC server for testing InitiateOIDCLogin. +// createMockOIDCServer creates a mock OIDC server for testing initiateOIDCLogin. func createMockOIDCServer(t *testing.T) *httptest.Server { var serverURL string server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -399,9 +401,9 @@ func TestPerformOIDCLogin(t *testing.T) { t.Run("successful flow", func(t *testing.T) { tokens, err := PerformOIDCLogin(context.Background(), config, - func(ctx context.Context, authURL, expectedState string) (*OIDCAuthorizationResult, error) { + func(ctx context.Context, args auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { // Validate authURL has required parameters - u, err := url.Parse(authURL) + u, err := url.Parse(args.URL) if err != nil { return nil, fmt.Errorf("invalid authURL: %w", err) } @@ -414,9 +416,9 @@ func TestPerformOIDCLogin(t *testing.T) { } // Simulate successful user authentication - return &OIDCAuthorizationResult{ + return &auth.AuthorizationResult{ Code: "mock-auth-code", - State: expectedState, // Return the expected state + State: q.Get("state"), // Return the expected state from URL }, nil }) @@ -434,9 +436,9 @@ func TestPerformOIDCLogin(t *testing.T) { t.Run("state mismatch", func(t *testing.T) { _, err := PerformOIDCLogin(context.Background(), config, - func(ctx context.Context, authURL, expectedState string) (*OIDCAuthorizationResult, error) { + func(ctx context.Context, args auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { // Return wrong state to simulate CSRF attack - return &OIDCAuthorizationResult{ + return &auth.AuthorizationResult{ Code: "mock-auth-code", State: "wrong-state", }, nil @@ -452,7 +454,7 @@ func TestPerformOIDCLogin(t *testing.T) { t.Run("fetcher error", func(t *testing.T) { _, err := PerformOIDCLogin(context.Background(), config, - func(ctx context.Context, authURL, expectedState string) (*OIDCAuthorizationResult, error) { + func(ctx context.Context, args auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { return nil, fmt.Errorf("user cancelled") }) diff --git a/oauthex/jwt_bearer.go b/oauthex/jwt_bearer.go deleted file mode 100644 index 4856cc28..00000000 --- a/oauthex/jwt_bearer.go +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2026 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -// This file implements JWT Bearer Authorization Grant (RFC 7523) for Enterprise Managed Authorization. -// See https://datatracker.ietf.org/doc/html/rfc7523 - -//go:build mcp_go_client_oauth - -package oauthex - -import ( - "context" - "fmt" - "net/http" - - "golang.org/x/oauth2" -) - -// GrantTypeJWTBearer is the grant type for RFC 7523 JWT Bearer authorization grant. -// This is used in SEP-990 to exchange an ID-JAG for an access token at the MCP Server. -const GrantTypeJWTBearer = "urn:ietf:params:oauth:grant-type:jwt-bearer" - -// ExchangeJWTBearer exchanges an Identity Assertion JWT Authorization Grant (ID-JAG) -// for an access token using JWT Bearer Grant per RFC 7523. This is the second step -// in Enterprise Managed Authorization (SEP-990) after obtaining the ID-JAG from the -// IdP via token exchange. -// -// The tokenEndpoint parameter should be the MCP Server's token endpoint (typically -// obtained from the MCP Server's authorization server metadata). -// -// The assertion parameter should be the ID-JAG JWT obtained from the token exchange -// step with the enterprise IdP. -// -// Client authentication must be performed by the caller by including appropriate -// credentials in the request (e.g., using Basic auth via the Authorization header, -// or including client_id and client_secret in the form data). -// -// Example: -// -// // First, get ID-JAG via token exchange -// idJAG := tokenExchangeResp.AccessToken -// -// // Then exchange ID-JAG for access token -// token, err := ExchangeJWTBearer( -// ctx, -// "https://auth.mcpserver.example/oauth2/token", -// idJAG, -// "mcp-client-id", -// "mcp-client-secret", -// nil, -// ) -func ExchangeJWTBearer( - ctx context.Context, - tokenEndpoint string, - assertion string, - clientID string, - clientSecret string, - httpClient *http.Client, -) (*oauth2.Token, error) { - if tokenEndpoint == "" { - return nil, fmt.Errorf("token endpoint is required") - } - if assertion == "" { - return nil, fmt.Errorf("assertion is required") - } - // Validate URL scheme to prevent XSS attacks (see #526) - if err := CheckURLScheme(tokenEndpoint); err != nil { - return nil, fmt.Errorf("invalid token endpoint: %w", err) - } - - // Per RFC 6749 Section 3.2, parameters sent without a value (like the empty - // "code" parameter) MUST be treated as if they were omitted from the request. - // The oauth2 library's Exchange method sends an empty code, but compliant - // servers should ignore it. - cfg := &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - Endpoint: oauth2.Endpoint{ - TokenURL: tokenEndpoint, - AuthStyle: oauth2.AuthStyleInParams, // Use POST body auth per SEP-990 - }, - } - - // Use custom HTTP client if provided - if httpClient == nil { - httpClient = http.DefaultClient - } - ctxWithClient := context.WithValue(ctx, oauth2.HTTPClient, httpClient) - - // Exchange with JWT Bearer grant type and assertion. - // SetAuthURLParam overrides the default grant_type and adds the assertion parameter. - token, err := cfg.Exchange( - ctxWithClient, - "", // empty code - per RFC 6749 Section 3.2, empty params should be ignored - oauth2.SetAuthURLParam("grant_type", GrantTypeJWTBearer), - oauth2.SetAuthURLParam("assertion", assertion), - ) - if err != nil { - return nil, fmt.Errorf("JWT bearer grant request failed: %w", err) - } - - return token, nil -} diff --git a/oauthex/jwt_bearer_test.go b/oauthex/jwt_bearer_test.go deleted file mode 100644 index 8a04732a..00000000 --- a/oauthex/jwt_bearer_test.go +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2026 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. - -//go:build mcp_go_client_oauth - -package oauthex - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" -) - -// TestExchangeJWTBearer tests the JWT Bearer grant flow. -func TestExchangeJWTBearer(t *testing.T) { - // Create a test MCP Server auth server that accepts JWT Bearer grants - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify request method and content type - if r.Method != http.MethodPost { - t.Errorf("expected POST request, got %s", r.Method) - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - contentType := r.Header.Get("Content-Type") - if contentType != "application/x-www-form-urlencoded" { - t.Errorf("expected application/x-www-form-urlencoded, got %s", contentType) - http.Error(w, "invalid content type", http.StatusBadRequest) - return - } - // Parse form data - if err := r.ParseForm(); err != nil { - http.Error(w, "failed to parse form", http.StatusBadRequest) - return - } - // Verify grant type per RFC 7523 - grantType := r.FormValue("grant_type") - if grantType != GrantTypeJWTBearer { - t.Errorf("expected grant_type %s, got %s", GrantTypeJWTBearer, grantType) - writeJWTBearerErrorResponse(w, "unsupported_grant_type", "grant type not supported") - return - } - // Verify assertion is provided - assertion := r.FormValue("assertion") - if assertion == "" { - t.Error("assertion is required") - writeJWTBearerErrorResponse(w, "invalid_request", "missing assertion") - return - } - // Verify client authentication - clientID := r.FormValue("client_id") - clientSecret := r.FormValue("client_secret") - if clientID == "" || clientSecret == "" { - t.Error("client authentication required") - writeJWTBearerErrorResponse(w, "invalid_client", "client authentication failed") - return - } - if clientID != "mcp-client-id" || clientSecret != "mcp-client-secret" { - t.Error("invalid client credentials") - writeJWTBearerErrorResponse(w, "invalid_client", "invalid credentials") - return - } - // Return successful OAuth token response - resp := struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in,omitempty"` - Scope string `json:"scope,omitempty"` - RefreshToken string `json:"refresh_token,omitempty"` - }{ - AccessToken: "mcp-access-token-123", - TokenType: "Bearer", - ExpiresIn: 3600, - Scope: "read write", - RefreshToken: "mcp-refresh-token-456", - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Cache-Control", "no-store") - w.Header().Set("Pragma", "no-cache") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(resp) - })) - defer server.Close() - // Test successful JWT Bearer grant - t.Run("successful exchange", func(t *testing.T) { - token, err := ExchangeJWTBearer( - context.Background(), - server.URL, - "fake-id-jag-jwt", - "mcp-client-id", - "mcp-client-secret", - server.Client(), - ) - if err != nil { - t.Fatalf("ExchangeJWTBearer failed: %v", err) - } - if token.AccessToken != "mcp-access-token-123" { - t.Errorf("expected access_token 'mcp-access-token-123', got %s", token.AccessToken) - } - if token.TokenType != "Bearer" { - t.Errorf("expected token_type 'Bearer', got %s", token.TokenType) - } - if token.RefreshToken != "mcp-refresh-token-456" { - t.Errorf("expected refresh_token 'mcp-refresh-token-456', got %s", token.RefreshToken) - } - // Check expiration (should be ~1 hour from now) - expectedExpiry := time.Now().Add(3600 * time.Second) - if token.Expiry.Before(time.Now()) || token.Expiry.After(expectedExpiry.Add(5*time.Second)) { - t.Errorf("unexpected expiry time: %v", token.Expiry) - } - // Check scope in extra data - scope, ok := token.Extra("scope").(string) - if !ok || scope != "read write" { - t.Errorf("expected scope 'read write', got %v", token.Extra("scope")) - } - }) - // Test missing assertion - t.Run("missing assertion", func(t *testing.T) { - _, err := ExchangeJWTBearer( - context.Background(), - server.URL, - "", // empty assertion - "mcp-client-id", - "mcp-client-secret", - server.Client(), - ) - if err == nil { - t.Error("expected error for missing assertion, got nil") - } - }) - // Test invalid URL scheme - t.Run("invalid token endpoint URL", func(t *testing.T) { - _, err := ExchangeJWTBearer( - context.Background(), - "javascript:alert(1)", - "fake-id-jag-jwt", - "mcp-client-id", - "mcp-client-secret", - server.Client(), - ) - if err == nil { - t.Error("expected error for invalid URL scheme, got nil") - } - }) -} - -// writeJWTBearerErrorResponse writes an OAuth 2.0 error response per RFC 6749 Section 5.2. -func writeJWTBearerErrorResponse(w http.ResponseWriter, errorCode, errorDescription string) { - errResp := struct { - Error string `json:"error"` - ErrorDescription string `json:"error_description,omitempty"` - }{ - Error: errorCode, - ErrorDescription: errorDescription, - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Cache-Control", "no-store") - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(errResp) -} From a7963d330b909d7edf0f9752cc53db3192a674da Mon Sep 17 00:00:00 2001 From: Pranav RK Date: Fri, 20 Mar 2026 23:06:11 +0530 Subject: [PATCH 19/19] docs: update docs for enterprise authorization --- docs/protocol.md | 65 +++++--- examples/auth/enterprise-client/main.go | 197 ++++++++++++++++++++++++ internal/docs/protocol.src.md | 65 +++++--- oauthex/token_exchange.go | 14 -- 4 files changed, 285 insertions(+), 56 deletions(-) create mode 100644 examples/auth/enterprise-client/main.go diff --git a/docs/protocol.md b/docs/protocol.md index cc487ccb..47790c54 100644 --- a/docs/protocol.md +++ b/docs/protocol.md @@ -357,40 +357,63 @@ session, err := client.Connect(ctx, transport, nil) The `auth.AuthorizationCodeHandler` automatically manages token refreshing and step-up authentication (when the server returns `insufficient_scope` error). -#### Enterprise Authentication Flow (SEP-990) +#### Enterprise Managed Authorization (SEP-990) -For enterprise SSO scenarios, the SDK provides an -[`EnterpriseAuthFlow`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#EnterpriseAuthFlow) -function that implements the complete token exchange flow: +For enterprise SSO scenarios where users authenticate with an enterprise Identity Provider (IdP), +the SDK provides +[`extauth.EnterpriseHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth/extauth#EnterpriseHandler), +an implementation of `OAuthHandler` that automates the Enterprise Managed Authorization flow: -1. **Token Exchange** at IdP: ID Token → ID-JAG -2. **JWT Bearer Grant** at MCP Server: ID-JAG → Access Token +1. **OIDC Login**: User authenticates with enterprise IdP → ID Token +2. **Token Exchange** (RFC 8693): ID Token → ID-JAG at IdP +3. **JWT Bearer Grant** (RFC 7523): ID-JAG → Access Token at MCP Server -This flow is typically used after obtaining an ID Token via OIDC login: +To use enterprise managed authorization, create an `EnterpriseHandler` and assign it to your transport: ```go -// Step 1: Obtain ID token via OIDC (see auth.InitiateOIDCLogin and auth.CompleteOIDCLogin) -idToken := "..." // from OIDC login +// Create ID token fetcher using OIDC login +idTokenFetcher := func(ctx context.Context) (*extauth.IDTokenResult, error) { + oidcConfig := &extauth.OIDCLoginConfig{ + IssuerURL: "https://company.okta.com", + ClientID: "idp-client-id", + ClientSecret: "idp-client-secret", + RedirectURL: "http://localhost:3142", + Scopes: []string{"openid", "profile", "email"}, + } + + tokens, err := extauth.PerformOIDCLogin(ctx, oidcConfig, authCodeFetcher) + if err != nil { + return nil, err + } + + return &extauth.IDTokenResult{Token: tokens.IDToken}, nil +} -// Step 2: Exchange for MCP access token -config := &auth.EnterpriseAuthConfig{ +// Create Enterprise Handler +enterpriseHandler, err := extauth.NewEnterpriseHandler(&extauth.EnterpriseHandlerConfig{ IdPIssuerURL: "https://company.okta.com", - IdPClientID: "client-id-at-idp", - IdPClientSecret: "secret-at-idp", + IdPClientID: "idp-client-id", + IdPClientSecret: "idp-client-secret", MCPAuthServerURL: "https://auth.mcpserver.example", MCPResourceURI: "https://mcp.mcpserver.example", - MCPClientID: "client-id-at-mcp", - MCPClientSecret: "secret-at-mcp", + MCPClientID: "mcp-client-id", + MCPClientSecret: "mcp-client-secret", MCPScopes: []string{"read", "write"}, -} + IDTokenFetcher: idTokenFetcher, +}) -accessToken, err := auth.EnterpriseAuthFlow(ctx, config, idToken) -// Use accessToken with MCP client +// Use with transport +transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + OAuthHandler: enterpriseHandler, +} +client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) +session, err := client.Connect(ctx, transport, nil) ``` -Helper functions are provided for OIDC login: -- [`InitiateOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#InitiateOIDCLogin) - Generate authorization URL with PKCE -- [`CompleteOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#CompleteOIDCLogin) - Exchange authorization code for tokens +The `EnterpriseHandler` automatically manages the token exchange flow and token refreshing. + +For a complete working example, see [examples/auth/enterprise-client](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/auth/enterprise-client). ## Security diff --git a/examples/auth/enterprise-client/main.go b/examples/auth/enterprise-client/main.go new file mode 100644 index 00000000..654590c0 --- /dev/null +++ b/examples/auth/enterprise-client/main.go @@ -0,0 +1,197 @@ +// Copyright 2026 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +//go:build mcp_go_client_oauth + +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net" + "net/http" + + "github.com/modelcontextprotocol/go-sdk/auth" + "github.com/modelcontextprotocol/go-sdk/auth/extauth" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + // IdP (Identity Provider) configuration + idpIssuerURL = flag.String("idp_issuer", "https://your-idp.okta.com", "IdP issuer URL (e.g., https://your-company.okta.com)") + idpClientID = flag.String("idp_client_id", "", "Client ID registered at the IdP") + idpClientSecret = flag.String("idp_client_secret", "", "Client secret at the IdP (optional for public clients)") + + // MCP Server configuration + mcpServerURL = flag.String("mcp_server", "http://localhost:8000/mcp", "URL of the MCP server") + mcpAuthServerURL = flag.String("mcp_auth_server", "https://auth.mcpserver.example", "MCP server's authorization server URL") + mcpResourceURI = flag.String("mcp_resource_uri", "https://mcp.mcpserver.example", "MCP server's resource identifier (RFC 9728)") + mcpClientID = flag.String("mcp_client_id", "", "Client ID at the MCP server (optional)") + mcpClientSecret = flag.String("mcp_client_secret", "", "Client secret at the MCP server (optional)") + + // OAuth callback configuration + callbackPort = flag.Int("callback_port", 3142, "Port for the local HTTP server that will receive the OAuth callback") +) + +// codeReceiver handles the OAuth callback from the IdP's authorization endpoint. +// It starts a local HTTP server to receive the authorization code after the user +// authenticates with their enterprise IdP. +type codeReceiver struct { + authChan chan *auth.AuthorizationResult + errChan chan error + listener net.Listener + server *http.Server +} + +// serveRedirectHandler starts an HTTP server to handle the OAuth redirect callback. +func (r *codeReceiver) serveRedirectHandler(listener net.Listener) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + // Extract the authorization code and state from the callback URL + r.authChan <- &auth.AuthorizationResult{ + Code: req.URL.Query().Get("code"), + State: req.URL.Query().Get("state"), + } + fmt.Fprint(w, "Authentication successful. You can close this window.") + }) + + r.server = &http.Server{ + Addr: fmt.Sprintf("localhost:%d", *callbackPort), + Handler: mux, + } + if err := r.server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { + r.errChan <- err + } +} + +// getAuthorizationCode implements the AuthorizationCodeFetcher interface. +// It displays the authorization URL to the user and waits for the callback. +func (r *codeReceiver) getAuthorizationCode(ctx context.Context, args auth.AuthorizationArgs) (*auth.AuthorizationResult, error) { + fmt.Printf("\nPlease open the following URL in your browser to authenticate:\n%s\n\n", args.URL) + select { + case authRes := <-r.authChan: + return authRes, nil + case err := <-r.errChan: + return nil, err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// close shuts down the HTTP server. +func (r *codeReceiver) close() { + if r.server != nil { + r.server.Close() + } +} + +func main() { + flag.Parse() + + // Validate required configuration + if *idpClientID == "" { + log.Fatal("--idp_client_id is required") + } + + // Set up the OAuth callback receiver + receiver := &codeReceiver{ + authChan: make(chan *auth.AuthorizationResult), + errChan: make(chan error), + } + listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *callbackPort)) + if err != nil { + log.Fatalf("failed to listen on port %d: %v", *callbackPort, err) + } + go receiver.serveRedirectHandler(listener) + defer receiver.close() + + log.Printf("OAuth callback server listening on http://localhost:%d", *callbackPort) + + // Create an ID Token fetcher that performs OIDC login with the enterprise IdP + idTokenFetcher := func(ctx context.Context) (*extauth.IDTokenResult, error) { + log.Println("Starting OIDC login flow...") + + oidcConfig := &extauth.OIDCLoginConfig{ + IssuerURL: *idpIssuerURL, + ClientID: *idpClientID, + ClientSecret: *idpClientSecret, + RedirectURL: fmt.Sprintf("http://localhost:%d", *callbackPort), + Scopes: []string{"openid", "profile", "email"}, + } + + // PerformOIDCLogin handles the complete OIDC Authorization Code flow with PKCE + tokens, err := extauth.PerformOIDCLogin(ctx, oidcConfig, receiver.getAuthorizationCode) + if err != nil { + return nil, fmt.Errorf("OIDC login failed: %w", err) + } + + log.Println("OIDC login successful, obtained ID token") + return &extauth.IDTokenResult{Token: tokens.IDToken}, nil + } + + // Create the Enterprise Handler + // This handler implements the complete Enterprise Managed Authorization flow: + // 1. OIDC Login: User authenticates with enterprise IdP → ID Token (via idTokenFetcher) + // 2. Token Exchange (RFC 8693): ID Token → ID-JAG at IdP + // 3. JWT Bearer Grant (RFC 7523): ID-JAG → Access Token at MCP Server + log.Println("Creating enterprise authorization handler...") + enterpriseHandler, err := extauth.NewEnterpriseHandler(&extauth.EnterpriseHandlerConfig{ + // IdP configuration (where the user authenticates) + IdPIssuerURL: *idpIssuerURL, + IdPClientID: *idpClientID, + IdPClientSecret: *idpClientSecret, + + // MCP Server configuration (the resource being accessed) + MCPAuthServerURL: *mcpAuthServerURL, + MCPResourceURI: *mcpResourceURI, + MCPClientID: *mcpClientID, + MCPClientSecret: *mcpClientSecret, + MCPScopes: []string{"read", "write"}, + + // ID Token fetcher (performs OIDC login when needed) + IDTokenFetcher: idTokenFetcher, + }) + if err != nil { + log.Fatalf("failed to create enterprise handler: %v", err) + } + + // Create the MCP client transport with the enterprise handler + transport := &mcp.StreamableClientTransport{ + Endpoint: *mcpServerURL, + OAuthHandler: enterpriseHandler, + } + + ctx := context.Background() + client := mcp.NewClient(&mcp.Implementation{ + Name: "enterprise-client-example", + Version: "1.0.0", + }, nil) + + log.Printf("Connecting to MCP server at %s...", *mcpServerURL) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + log.Fatalf("failed to connect to MCP server: %v", err) + } + defer session.Close() + + log.Println("Successfully connected to MCP server!") + + // List available tools as a demonstration + tools, err := session.ListTools(ctx, nil) + if err != nil { + log.Fatalf("failed to list tools: %v", err) + } + + log.Println("\nAvailable tools:") + if len(tools.Tools) == 0 { + log.Println(" (no tools available)") + } else { + for _, tool := range tools.Tools { + log.Printf(" - %q: %s", tool.Name, tool.Description) + } + } +} diff --git a/internal/docs/protocol.src.md b/internal/docs/protocol.src.md index 1511a304..997d7852 100644 --- a/internal/docs/protocol.src.md +++ b/internal/docs/protocol.src.md @@ -282,40 +282,63 @@ session, err := client.Connect(ctx, transport, nil) The `auth.AuthorizationCodeHandler` automatically manages token refreshing and step-up authentication (when the server returns `insufficient_scope` error). -#### Enterprise Authentication Flow (SEP-990) +#### Enterprise Managed Authorization (SEP-990) -For enterprise SSO scenarios, the SDK provides an -[`EnterpriseAuthFlow`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#EnterpriseAuthFlow) -function that implements the complete token exchange flow: +For enterprise SSO scenarios where users authenticate with an enterprise Identity Provider (IdP), +the SDK provides +[`extauth.EnterpriseHandler`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth/extauth#EnterpriseHandler), +an implementation of `OAuthHandler` that automates the Enterprise Managed Authorization flow: -1. **Token Exchange** at IdP: ID Token → ID-JAG -2. **JWT Bearer Grant** at MCP Server: ID-JAG → Access Token +1. **OIDC Login**: User authenticates with enterprise IdP → ID Token +2. **Token Exchange** (RFC 8693): ID Token → ID-JAG at IdP +3. **JWT Bearer Grant** (RFC 7523): ID-JAG → Access Token at MCP Server -This flow is typically used after obtaining an ID Token via OIDC login: +To use enterprise managed authorization, create an `EnterpriseHandler` and assign it to your transport: ```go -// Step 1: Obtain ID token via OIDC (see auth.InitiateOIDCLogin and auth.CompleteOIDCLogin) -idToken := "..." // from OIDC login +// Create ID token fetcher using OIDC login +idTokenFetcher := func(ctx context.Context) (*extauth.IDTokenResult, error) { + oidcConfig := &extauth.OIDCLoginConfig{ + IssuerURL: "https://company.okta.com", + ClientID: "idp-client-id", + ClientSecret: "idp-client-secret", + RedirectURL: "http://localhost:3142", + Scopes: []string{"openid", "profile", "email"}, + } + + tokens, err := extauth.PerformOIDCLogin(ctx, oidcConfig, authCodeFetcher) + if err != nil { + return nil, err + } + + return &extauth.IDTokenResult{Token: tokens.IDToken}, nil +} -// Step 2: Exchange for MCP access token -config := &auth.EnterpriseAuthConfig{ +// Create Enterprise Handler +enterpriseHandler, err := extauth.NewEnterpriseHandler(&extauth.EnterpriseHandlerConfig{ IdPIssuerURL: "https://company.okta.com", - IdPClientID: "client-id-at-idp", - IdPClientSecret: "secret-at-idp", + IdPClientID: "idp-client-id", + IdPClientSecret: "idp-client-secret", MCPAuthServerURL: "https://auth.mcpserver.example", MCPResourceURI: "https://mcp.mcpserver.example", - MCPClientID: "client-id-at-mcp", - MCPClientSecret: "secret-at-mcp", + MCPClientID: "mcp-client-id", + MCPClientSecret: "mcp-client-secret", MCPScopes: []string{"read", "write"}, -} + IDTokenFetcher: idTokenFetcher, +}) -accessToken, err := auth.EnterpriseAuthFlow(ctx, config, idToken) -// Use accessToken with MCP client +// Use with transport +transport := &mcp.StreamableClientTransport{ + Endpoint: "https://example.com/mcp", + OAuthHandler: enterpriseHandler, +} +client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) +session, err := client.Connect(ctx, transport, nil) ``` -Helper functions are provided for OIDC login: -- [`InitiateOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#InitiateOIDCLogin) - Generate authorization URL with PKCE -- [`CompleteOIDCLogin`](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk/auth#CompleteOIDCLogin) - Exchange authorization code for tokens +The `EnterpriseHandler` automatically manages the token exchange flow and token refreshing. + +For a complete working example, see [examples/auth/enterprise-client](https://github.com/modelcontextprotocol/go-sdk/tree/main/examples/auth/enterprise-client). ## Security diff --git a/oauthex/token_exchange.go b/oauthex/token_exchange.go index aa58c386..47d614ef 100644 --- a/oauthex/token_exchange.go +++ b/oauthex/token_exchange.go @@ -109,20 +109,6 @@ type TokenExchangeResponse struct { // // The tokenEndpoint parameter should be the IdP's token endpoint (typically // obtained from the IdP's authorization server metadata). -// -// Example: -// -// req := &TokenExchangeRequest{ -// RequestedTokenType: TokenTypeIDJAG, -// Audience: "https://auth.mcpserver.example/", -// Resource: "https://mcp.mcpserver.example/", -// Scope: []string{"read", "write"}, -// SubjectToken: idToken, -// SubjectTokenType: TokenTypeIDToken, -// } -// clientCreds := &ClientCredentials{ClientID: "my-client", ClientSecret: "secret"} -// -// resp, err := ExchangeToken(ctx, idpTokenEndpoint, req, clientCreds, nil) func ExchangeToken( ctx context.Context, tokenEndpoint string,