diff --git a/apps/vscode-e2e/src/suite/mcp-oauth.test.ts b/apps/vscode-e2e/src/suite/mcp-oauth.test.ts new file mode 100644 index 00000000000..0a1219ae888 --- /dev/null +++ b/apps/vscode-e2e/src/suite/mcp-oauth.test.ts @@ -0,0 +1,368 @@ +import * as assert from "assert" +import * as fs from "fs/promises" +import * as path from "path" +import * as os from "os" +import * as http from "http" +import * as vscode from "vscode" + +import { waitFor, sleep } from "./utils" +import { setDefaultSuiteTimeout } from "./test-utils" + +/** + * Minimal MCP-protocol-aware request handler. + * + * The SDK's StreamableHTTPClientTransport uses: + * - GET /mcp → SSE stream (we return 405 to indicate not supported) + * - POST /mcp → JSON-RPC messages (initialize, tools/list, etc.) + */ +function handleMcpRequest(req: http.IncomingMessage, res: http.ServerResponse, endpointsHit: Set): void { + if (req.method === "GET") { + // Signal that we don't support the SSE push channel. + // The SDK treats 405 as "SSE not supported, POST-only mode". + endpointsHit.add("mcp-authed-get") + res.writeHead(405) + res.end() + return + } + + // POST — read body, parse JSON-RPC, dispatch + let body = "" + req.on("data", (chunk) => (body += chunk)) + req.on("end", () => { + endpointsHit.add("mcp-authed") + + let message: { id?: number; method?: string } + try { + message = JSON.parse(body) + } catch { + res.writeHead(400) + res.end() + return + } + + // Notifications (no id) → 202 Accepted + if (message.id === undefined) { + res.writeHead(202) + res.end() + return + } + + let result: unknown + switch (message.method) { + case "initialize": + result = { + protocolVersion: "2024-11-05", + capabilities: {}, + serverInfo: { name: "test-oauth-server", version: "1.0.0" }, + } + break + case "tools/list": + result = { tools: [] } + break + case "resources/list": + result = { resources: [] } + break + case "resources/templates/list": + result = { resourceTemplates: [] } + break + default: + result = {} + } + + res.writeHead(200, { "Content-Type": "application/json" }) + res.end(JSON.stringify({ jsonrpc: "2.0", id: message.id, result })) + }) +} + +suite("Roo Code MCP OAuth", function () { + setDefaultSuiteTimeout(this) + + let tempDir: string + let testFiles: { mcpConfig: string } + let mockServer: http.Server + let mockServerPort: number + + // Track which OAuth / MCP endpoints were hit + const endpointsHit: Set = new Set() + + suiteSetup(async () => { + // Enable test mode so the OAuth callback server resolves immediately + // without needing a real browser redirect. + process.env.MCP_OAUTH_TEST_MODE = "true" + + tempDir = await fs.mkdtemp(path.join(os.tmpdir(), "roo-test-mcp-oauth-")) + + mockServer = http.createServer((req, res) => { + const url = req.url || "" + console.log(`[MOCK SERVER] ${req.method} ${url}`) + + // ── MCP endpoint ───────────────────────────────────────────── + if (url === "/mcp" || url.startsWith("/mcp?") || url.startsWith("/mcp/")) { + const authHeader = req.headers.authorization + if (!authHeader || !authHeader.startsWith("Bearer ")) { + endpointsHit.add("mcp-401") + res.writeHead(401, { + "WWW-Authenticate": `Bearer resource_metadata="http://localhost:${mockServerPort}/.well-known/oauth-protected-resource"`, + }) + res.end() + } else { + // Authenticated — handle as MCP protocol + handleMcpRequest(req, res, endpointsHit) + } + return + } + + // ── OAuth discovery / registration / token endpoints ───────── + + if (url === "/.well-known/oauth-protected-resource") { + endpointsHit.add("resource-metadata") + res.writeHead(200, { "Content-Type": "application/json" }) + res.end( + JSON.stringify({ + resource: `http://localhost:${mockServerPort}/mcp`, + authorization_servers: [`http://localhost:${mockServerPort}/auth`], + }), + ) + return + } + + // SDK constructs: new URL("/.well-known/oauth-authorization-server", "http://host/auth") + // which resolves to http://host/.well-known/oauth-authorization-server (origin-relative) + // Our custom fetchOAuthAuthServerMetadata constructs the RFC 8414 URL with issuer path: + // /.well-known/oauth-authorization-server/auth (with issuer path) + // Handle BOTH forms so our provider gets _authServerMeta. + if ( + url === "/.well-known/oauth-authorization-server" || + url === "/.well-known/oauth-authorization-server/auth" + ) { + endpointsHit.add("auth-metadata") + res.writeHead(200, { "Content-Type": "application/json" }) + res.end( + JSON.stringify({ + issuer: `http://localhost:${mockServerPort}/auth`, + authorization_endpoint: `http://localhost:${mockServerPort}/auth/authorize`, + token_endpoint: `http://localhost:${mockServerPort}/auth/token`, + registration_endpoint: `http://localhost:${mockServerPort}/auth/register`, + code_challenge_methods_supported: ["S256"], + response_types_supported: ["code"], + }), + ) + return + } + + if (url === "/auth/register" && req.method === "POST") { + endpointsHit.add("register") + res.writeHead(201, { "Content-Type": "application/json" }) + res.end( + JSON.stringify({ + client_id: "test-client-id", + redirect_uris: ["http://localhost:3000/callback"], + }), + ) + return + } + + if (url === "/auth/token" && req.method === "POST") { + endpointsHit.add("token") + res.writeHead(200, { "Content-Type": "application/json" }) + res.end( + JSON.stringify({ + access_token: "test-access-token", + token_type: "Bearer", + expires_in: 3600, + }), + ) + return + } + + // Capture authorize hits (only reachable if a real browser is present) + if (url.startsWith("/auth/authorize")) { + endpointsHit.add("authorize") + res.writeHead(200, { "Content-Type": "text/plain" }) + res.end("Authorization endpoint reached") + return + } + + res.writeHead(404) + res.end() + }) + + // Find an available port + mockServerPort = await new Promise((resolve, reject) => { + mockServer.listen(0, "127.0.0.1", () => { + const addr = mockServer.address() + if (!addr || typeof addr === "string") return reject(new Error("Failed to get address")) + resolve(addr.port) + }) + mockServer.on("error", reject) + }) + + const workspaceDir = vscode.workspace.workspaceFolders?.[0]?.uri.fsPath || tempDir + const rooDir = path.join(workspaceDir, ".roo") + await fs.mkdir(rooDir, { recursive: true }) + + const mcpConfig = { + mcpServers: { + "test-oauth-server": { + type: "streamable-http", + url: `http://localhost:${mockServerPort}/mcp`, + }, + }, + } + + testFiles = { mcpConfig: path.join(rooDir, "mcp.json") } + await fs.writeFile(testFiles.mcpConfig, JSON.stringify(mcpConfig, null, 2)) + + console.log("[TEST] Mock server port:", mockServerPort) + console.log("[TEST] MCP config:", testFiles.mcpConfig) + }) + + suiteTeardown(async () => { + delete process.env.MCP_OAUTH_TEST_MODE + + try { + await globalThis.api.cancelCurrentTask() + } catch { + // Task might not be running + } + + if (mockServer) { + await new Promise((resolve) => mockServer.close(() => resolve())) + } + + for (const filePath of Object.values(testFiles)) { + try { + await fs.unlink(filePath) + } catch { + // ignore + } + } + + const workspaceDir = vscode.workspace.workspaceFolders?.[0]?.uri.fsPath || tempDir + try { + await fs.rm(path.join(workspaceDir, ".roo"), { recursive: true, force: true }) + } catch { + // ignore + } + + await fs.rm(tempDir, { recursive: true, force: true }) + }) + + setup(async () => { + try { + await globalThis.api.cancelCurrentTask() + } catch { + // ignore + } + endpointsHit.clear() + await sleep(100) + }) + + teardown(async () => { + try { + await globalThis.api.cancelCurrentTask() + } catch { + // ignore + } + await sleep(100) + }) + + test("Should complete the full OAuth flow when connecting to an OAuth-protected MCP server", async function () { + // Re-write the config to trigger the file watcher and force a reconnect. + const workspaceDir = vscode.workspace.workspaceFolders?.[0]?.uri.fsPath || tempDir + const mcpConfigPath = path.join(workspaceDir, ".roo", "mcp.json") + + await fs.writeFile( + mcpConfigPath, + JSON.stringify( + { + mcpServers: { + "test-oauth-server": { + type: "streamable-http", + url: `http://localhost:${mockServerPort}/mcp`, + }, + }, + }, + null, + 2, + ), + ) + + // Step 1: Initial connection attempt gets 401 → triggers OAuth discovery + await waitFor(() => endpointsHit.has("mcp-401"), { timeout: 30_000 }) + console.log("[TEST] Got initial 401, OAuth flow started") + + // Step 2: SDK discovers OAuth metadata + await waitFor(() => endpointsHit.has("resource-metadata"), { timeout: 15_000 }) + console.log("[TEST] Resource metadata fetched") + + await waitFor(() => endpointsHit.has("auth-metadata"), { timeout: 15_000 }) + console.log("[TEST] Auth server metadata fetched") + + // Step 3: Dynamic client registration + await waitFor(() => endpointsHit.has("register"), { timeout: 15_000 }) + console.log("[TEST] Client registered") + + // Step 4: In MCP_OAUTH_TEST_MODE the callback server resolves immediately with + // a test auth code (no real browser needed). The SDK exchanges it for a token. + await waitFor(() => endpointsHit.has("token"), { timeout: 15_000 }) + console.log("[TEST] Access token obtained") + + // Step 5: The background _completeOAuthFlow task retries client.connect() with + // the bearer token. Verify the MCP server receives an authenticated request. + await waitFor(() => endpointsHit.has("mcp-authed"), { timeout: 15_000 }) + console.log("[TEST] MCP server connected with valid Bearer token") + + // Assert the complete OAuth flow ran + assert.ok(endpointsHit.has("mcp-401"), "MCP server should return 401 to trigger OAuth") + assert.ok(endpointsHit.has("resource-metadata"), "Resource metadata discovery should run") + assert.ok(endpointsHit.has("auth-metadata"), "Auth server metadata discovery should run") + assert.ok(endpointsHit.has("register"), "Dynamic client registration should run") + assert.ok(endpointsHit.has("token"), "Token exchange should succeed") + assert.ok(endpointsHit.has("mcp-authed"), "Retry connection should succeed with Bearer token") + + console.log("[TEST] MCP OAuth flow completed successfully. Endpoints hit:", [...endpointsHit]) + }) + + test("Should reuse stored token on reconnect without re-running the full OAuth flow", async function () { + // This test runs after the previous one, so a token is already stored in SecretStorage. + // Trigger another reconnect — the SDK should inject the cached token directly and skip the + // browser-based auth flow (no new register or token endpoints should be hit). + + // Clear only mcp-related hit tracking (token endpoint should NOT be re-hit) + endpointsHit.clear() + + const workspaceDir = vscode.workspace.workspaceFolders?.[0]?.uri.fsPath || tempDir + const mcpConfigPath = path.join(workspaceDir, ".roo", "mcp.json") + + // Slightly modify the config to force a reconnect + await fs.writeFile( + mcpConfigPath, + JSON.stringify( + { + mcpServers: { + "test-oauth-server": { + type: "streamable-http", + url: `http://localhost:${mockServerPort}/mcp`, + // A different but valid timeout value triggers config-change detection + timeout: 30, + }, + }, + }, + null, + 2, + ), + ) + + // Wait for the MCP server to receive an authenticated request + await waitFor(() => endpointsHit.has("mcp-authed"), { timeout: 30_000 }) + console.log("[TEST] Token reuse: MCP server got authenticated request") + + // The full OAuth flow should NOT have re-run (token was cached in SecretStorage) + assert.ok(endpointsHit.has("mcp-authed"), "Reconnect should use cached token") + assert.ok(!endpointsHit.has("mcp-401"), "Should not get 401 when token is cached") + assert.ok(!endpointsHit.has("register"), "Should not re-register client when token is cached") + + console.log("[TEST] Token reuse test passed. Endpoints hit:", [...endpointsHit]) + }) +}) diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index ea38ee02d6d..aa1dc3abec9 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -33,8 +33,11 @@ import { t } from "../../i18n" import { ClineProvider } from "../../core/webview/ClineProvider" import { GlobalFileNames } from "../../shared/globalFileNames" +import { UnauthorizedError } from "@modelcontextprotocol/sdk/client/auth.js" import { fileExistsAtPath } from "../../utils/fs" +import { SecretStorageService } from "./SecretStorageService" +import { McpOAuthClientProvider } from "./McpOAuthClientProvider" import { arePathsEqual, getWorkspacePath } from "../../utils/path" import { injectVariables } from "../../utils/config" import { safeWriteJson } from "../../utils/safeWriteJson" @@ -162,6 +165,7 @@ export class McpHub { private flagResetTimer?: NodeJS.Timeout private sanitizedNameRegistry: Map = new Map() private initializationPromise: Promise + private secretStorage?: SecretStorageService constructor(provider: ClineProvider) { this.providerRef = new WeakRef(provider) @@ -181,6 +185,10 @@ export class McpHub { async waitUntilReady(): Promise { await this.initializationPromise } + + public setSecretStorage(secretStorage: SecretStorageService): void { + this.secretStorage = secretStorage + } /** * Registers a client (e.g., ClineProvider) using this hub. * Increments the reference count. @@ -696,6 +704,7 @@ export class McpHub { ) let transport: StdioClientTransport | SSEClientTransport | StreamableHTTPClientTransport + let streamableHttpAuthProvider: McpOAuthClientProvider | undefined // Inject variables to the config (environment, magic variables,...) const configInjected = (await injectVariables(config, { @@ -779,11 +788,33 @@ export class McpHub { console.error(`No stderr stream for ${name}`) } } else if (configInjected.type === "streamable-http") { - // Streamable HTTP connection + if (!this.secretStorage) { + throw new Error("SecretStorageService not initialized — call setSecretStorage() before connecting") + } + + // Create an OAuth provider for this server. + // + // McpOAuthClientProvider.create() performs OAuth discovery (RFC 9728 + + // RFC 8414) once and starts the local callback server so the redirect + // URI port is stable before any connect attempt. + // + // If the server already has a stored token the SDK will use it + // transparently; the browser is only opened when a 401 forces a new + // authorization flow. + const authProvider = await McpOAuthClientProvider.create(configInjected.url, this.secretStorage, name) + + // Pre-register the OAuth client so the SDK can skip its own + // registration step (broken for path-prefixed issuers — see + // utils/oauth.ts for upstream issue links). + try { + await authProvider.registerClientIfNeeded() + } catch { + // Registration may not be supported — the SDK will attempt its own. + } + transport = new StreamableHTTPClientTransport(new URL(configInjected.url), { - requestInit: { - headers: configInjected.headers, - }, + authProvider, + requestInit: { headers: configInjected.headers }, }) // Set up Streamable HTTP specific error handling @@ -804,6 +835,9 @@ export class McpHub { } await this.notifyWebviewOfServerChanges() } + + // Keep a reference so the UnauthorizedError handler can use it. + streamableHttpAuthProvider = authProvider } else if (configInjected.type === "sse") { // SSE connection const sseOptions = { @@ -875,7 +909,34 @@ export class McpHub { this.connections.push(connection) // Connect (this will automatically start the transport) - await client.connect(transport) + try { + await client.connect(transport) + } catch (connectError) { + if (connectError instanceof UnauthorizedError && streamableHttpAuthProvider) { + // The server requires OAuth. The SDK has already called + // authProvider.redirectToAuthorization() which started the local callback + // server (lazily) and opened the user's browser. + // + // We fire-and-forget the rest of the flow so the extension (chat window, + // other servers) is not blocked waiting for the user's browser session. + connection.server.status = "connecting" + void this._completeOAuthFlow( + streamableHttpAuthProvider, + transport as StreamableHTTPClientTransport, + connection, + name, + source, + ) + return + } + // Non-OAuth error — let the outer catch handle it. + await streamableHttpAuthProvider?.close() + throw connectError + } + + // Successful connection — close callback server if it was started. + await streamableHttpAuthProvider?.close() + connection.server.status = "connected" connection.server.error = "" connection.server.instructions = client.getInstructions() @@ -895,6 +956,63 @@ export class McpHub { } } + /** + * Background task: waits for the user to complete the OAuth browser flow, + * exchanges the auth code for tokens, then reconnects from scratch. + * + * After the SDK throws UnauthorizedError the transport is left in a + * "started" state (_abortController is set), so we cannot simply call + * client.connect() on it again — the SDK would throw "already started". + * The clean solution is to delete the broken connection and let + * connectToServer() create fresh client/transport objects. The new + * provider will find the token in SecretStorage and connect without + * triggering another OAuth round-trip. + * + * This runs detached from the initialization path so `waitUntilReady()` + * and the rest of the extension are not blocked by the user's browser session. + */ + private async _completeOAuthFlow( + authProvider: McpOAuthClientProvider, + transport: StreamableHTTPClientTransport, + connection: ConnectedMcpConnection, + name: string, + source: "global" | "project", + ): Promise { + try { + const code = await authProvider.waitForAuthCode() + // Exchange auth code for tokens using the pre-fetched token_endpoint + // directly. The SDK's transport.finishAuth() re-runs discovery internally + // and hits the same broken URL for path-prefixed issuers (see + // utils/oauth.ts for upstream issue links). + await authProvider.exchangeCodeForTokens(code) + authProvider.close().catch(console.error) + + // Recover the validated server config stored on the connection so we + // can pass it directly to connectToServer without re-reading the file. + const parsedConfig = JSON.parse(connection.server.config) + const validatedConfig = this.validateServerConfig(parsedConfig, name) + + // Remove the broken connection (closes the old transport/client), + // then reconnect. The new McpOAuthClientProvider will find the token + // in SecretStorage and connect without another OAuth round-trip. + await this.deleteConnection(name, source) + await this.connectToServer(name, validatedConfig, source) + + await this.notifyWebviewOfServerChanges() + void vscode.window.showInformationMessage( + `MCP server "${name}" connected successfully after OAuth authentication.`, + ) + } catch (error) { + await authProvider.close() + const conn = this.findConnection(name, source) + if (conn) { + conn.server.status = "disconnected" + this.appendErrorMessage(conn, error instanceof Error ? error.message : `${error}`) + } + await this.notifyWebviewOfServerChanges() + } + } + private appendErrorMessage(connection: McpConnection, error: string, level: "error" | "warn" | "info" = "error") { const MAX_ERROR_LENGTH = 1000 const truncatedError = @@ -1278,6 +1396,11 @@ export class McpHub { // Validate the config const validatedConfig = this.validateServerConfig(parsedConfig, serverName) + // Clear OAuth tokens for streamable-http servers on restart + if (validatedConfig.type === "streamable-http" && this.secretStorage) { + await this.secretStorage.deleteOAuthData(validatedConfig.url) + } + // Try to connect again using validated config await this.connectToServer(serverName, validatedConfig, connection.server.source || "global") vscode.window.showInformationMessage(t("mcp:info.server_connected", { serverName })) diff --git a/src/services/mcp/McpOAuthClientProvider.ts b/src/services/mcp/McpOAuthClientProvider.ts new file mode 100644 index 00000000000..eb38b62f11e --- /dev/null +++ b/src/services/mcp/McpOAuthClientProvider.ts @@ -0,0 +1,337 @@ +import * as http from "http" + +import * as vscode from "vscode" +import type { OAuthClientProvider } from "@modelcontextprotocol/sdk/client/auth.js" +import type { + OAuthClientInformation, + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthTokens, +} from "@modelcontextprotocol/sdk/shared/auth.js" + +import { SecretStorageService } from "./SecretStorageService" +import { startCallbackServer, stopCallbackServer } from "./utils/callbackServer" +import { fetchOAuthAuthServerMetadata } from "./utils/oauth" + +/** + * Implements the MCP SDK's OAuthClientProvider interface for VS Code. + * + * Responsibilities: + * - Stores/loads OAuth tokens via VS Code SecretStorage + * - Runs a local HTTP callback server to receive the authorization code + * - Opens the browser for the authorization redirect + * - Provides PKCE code verifier round-trip storage + * + * Usage pattern in McpHub: + * 1. `const authProvider = await McpOAuthClientProvider.create(url, secretStorage)` + * 2. Pass `authProvider` to `StreamableHTTPClientTransport({ authProvider })` + * 3. `await client.connect(transport)` — may throw `UnauthorizedError` + * 4. On `UnauthorizedError`: `code = await authProvider.waitForAuthCode()` + * 5. `await transport.finishAuth(code)` then retry `client.connect(transport)` + * 6. `await authProvider.close()` when done (success or permanent failure) + */ +export class McpOAuthClientProvider implements OAuthClientProvider { + private _codeVerifier?: string + // Client info is kept in-memory only (not persisted) to avoid stale registrations + // when the redirect URI port changes between sessions. + private _clientInfo?: OAuthClientInformationFull + private _closed = false + + private constructor( + private readonly _serverUrl: string, + private readonly _secretStorage: SecretStorageService, + private readonly _server: http.Server, + private readonly _port: number, + private readonly _authCodePromise: Promise, + private readonly _tokenEndpointAuthMethod: string, + private readonly _grantTypes: string[], + private readonly _scopes: string[], + private readonly _state: string, + private readonly _authServerMeta: Record | null, + private readonly _resourceIndicator: string | null, + private readonly _clientName: string, + ) {} + + /** + * Factory — discovers OAuth Authorization Server metadata once (RFC 9728 + + * RFC 8414), starts the local callback server, and returns a ready provider. + * + * Discovery and callback-server startup both happen here so that: + * - `redirectUrl` (used by the SDK to build the authorization URL) is + * stable before any connect attempt. + * - The same metadata object is reused for client registration without a + * second network round-trip. + */ + static async create( + serverUrl: string, + secretStorage: SecretStorageService, + serverName?: string, + ): Promise { + // Fetch auth server metadata once. Reused for: + // - selecting token_endpoint_auth_method / grant_types / scopes + // - pre-registering the client (registration_endpoint) + // - RFC 8707 resource indicator (injected into authorization URL) + const discovery = await fetchOAuthAuthServerMetadata(serverUrl) + const authServerMeta = discovery?.authServerMeta ?? null + const resourceIndicator = discovery?.resourceIndicator ?? null + + // Extract auth-method preferences. + // Prefer "none" → first supported → "client_secret_post" + const authMethods: string[] = authServerMeta?.token_endpoint_auth_methods_supported ?? [] + const tokenEndpointAuthMethod = authMethods.includes("none") ? "none" : (authMethods[0] ?? "client_secret_post") + const grantTypes: string[] = authServerMeta?.grant_types_supported ?? ["authorization_code", "refresh_token"] + const scopes: string[] = authServerMeta?.scopes_supported ?? ["openid"] + + // Generate a CSRF state token for the OAuth flow. + const state = Array.from(crypto.getRandomValues(new Uint8Array(8))) + .map((b) => b.toString(16).padStart(2, "0")) + .join("") + + // Start the callback server now so the port is known and stable. + // The SDK reads `redirectUrl` synchronously when building the authorization + // URL, so the port must be available before any connect attempt. + const { server, port, result } = await startCallbackServer(undefined, state) + + const authCodePromise = result.then((r) => { + if (r.error) throw new Error(`OAuth authorization failed: ${r.error}`) + if (!r.code) throw new Error("No authorization code received in callback") + return r.code + }) + + return new McpOAuthClientProvider( + serverUrl, + secretStorage, + server, + port, + authCodePromise, + tokenEndpointAuthMethod, + grantTypes, + scopes, + state, + authServerMeta, + resourceIndicator, + serverName || "Roo Code", + ) + } + + // ── OAuthClientProvider interface ──────────────────────────────────────── + + get redirectUrl(): string { + return `http://localhost:${this._port}/callback` + } + + state(): string { + return this._state + } + + get clientMetadata(): OAuthClientMetadata { + return { + client_name: this._clientName, + redirect_uris: [this.redirectUrl], + grant_types: this._grantTypes, + response_types: ["code"], + token_endpoint_auth_method: this._tokenEndpointAuthMethod, + } + } + + async clientInformation(): Promise { + return this._clientInfo + } + + async saveClientInformation(info: OAuthClientInformationFull): Promise { + this._clientInfo = info + } + + /** + * Registers this client with the authorization server if a + * `registration_endpoint` is present in the pre-fetched auth server + * metadata. No-ops if already registered or if the server doesn't + * support dynamic client registration. + * + * Called by McpHub before the first `client.connect()` attempt so that + * `clientInformation()` returns a valid client_id and the SDK skips its + * own registration step — which fails for issuers with path components + * due to the same metadata discovery bug (see utils/oauth.ts for + * upstream issue links). + */ + async registerClientIfNeeded(): Promise { + if (this._clientInfo) return // already registered + + // Check if we have a cached client_id from previous registration + const cachedData = await this._secretStorage.getOAuthData(this._serverUrl) + if (cachedData?.client_id && cachedData.redirect_uri === this.redirectUrl) { + this._clientInfo = { + client_id: cachedData.client_id, + redirect_uris: [this.redirectUrl], + client_name: this._clientName, + grant_types: this._grantTypes, + response_types: ["code"], + token_endpoint_auth_method: this._tokenEndpointAuthMethod, + } + return + } + + if (!this._authServerMeta?.registration_endpoint) return // DCR not supported + + const response = await fetch(this._authServerMeta.registration_endpoint as string, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json", + }, + body: JSON.stringify(this.clientMetadata), + }) + + if (!response.ok) { + throw new Error(`Dynamic client registration failed: HTTP ${response.status}`) + } + + this._clientInfo = (await response.json()) as OAuthClientInformationFull + } + + async tokens(): Promise { + const data = await this._secretStorage.getOAuthData(this._serverUrl) + if (!data) return undefined + // Return undefined 5 minutes before expiry so the SDK triggers re-auth + // before the server actually rejects requests. + if (Date.now() >= data.expires_at - 5 * 60 * 1000) return undefined + return data.tokens + } + + async saveTokens(tokens: OAuthTokens): Promise { + const expires_at = tokens.expires_in ? Date.now() + tokens.expires_in * 1000 : Date.now() + 3600 * 1000 // default 1 hour when server omits expires_in + await this._secretStorage.saveOAuthData(this._serverUrl, { + tokens, + expires_at, + client_id: this._clientInfo?.client_id, + redirect_uri: this.redirectUrl, + }) + } + + async redirectToAuthorization(authorizationUrl: URL): Promise { + // Workaround for SDK metadata discovery bug (see utils/oauth.ts for issue links). + // The SDK's discoverOAuthMetadata() builds a wrong well-known URL for issuers + // with path components, causing it to fall back to a default "/authorize" path. + // We correct the URL using our pre-fetched metadata: + // 1. Replace the origin+pathname with the real authorization_endpoint. + // 2. Preserve all SDK-generated query params (client_id, code_challenge, etc.) + // 3. Add `scope` when the server advertises scopes but the SDK omitted it. + // 4. Add RFC 8707 `resource` parameter when the protected resource metadata + // advertised a resource indicator. + let correctedUrl = authorizationUrl + if (this._authServerMeta?.authorization_endpoint) { + try { + const fixed = new URL(this._authServerMeta.authorization_endpoint as string) + // Copy all query params generated by the SDK + authorizationUrl.searchParams.forEach((value, key) => { + fixed.searchParams.set(key, value) + }) + // Ensure the scope param is present — the SDK sometimes omits it + if (!fixed.searchParams.has("scope") && this._scopes.length > 0) { + fixed.searchParams.set("scope", this._scopes.join(" ")) + } + // RFC 8707: inject the resource indicator so the auth server can + // scope the issued access token to this specific resource server. + if (this._resourceIndicator && !fixed.searchParams.has("resource")) { + fixed.searchParams.set("resource", this._resourceIndicator) + } + correctedUrl = fixed + } catch { + // Fall through and use the original URL if correction fails + } + } + + void vscode.window.showInformationMessage("MCP server requires authentication. Opening browser for OAuth…") + try { + await vscode.env.openExternal(vscode.Uri.parse(correctedUrl.toString())) + } catch { + void vscode.window.showInformationMessage( + `Please open this URL in your browser to authenticate: ${correctedUrl}`, + ) + } + } + + async saveCodeVerifier(codeVerifier: string): Promise { + this._codeVerifier = codeVerifier + } + + async codeVerifier(): Promise { + if (!this._codeVerifier) throw new Error("No PKCE code verifier saved") + return this._codeVerifier + } + + // ── Extra helpers for McpHub ───────────────────────────────────────────── + + /** + * Resolves with the authorization code once the user completes the OAuth + * browser flow and the local callback server receives the redirect. + * Rejects on error or 5-minute timeout. + */ + waitForAuthCode(): Promise { + return this._authCodePromise + } + + /** + * Exchanges an authorization code for tokens by POSTing directly to the + * `token_endpoint` from our pre-fetched metadata. + * + * This bypasses the SDK's `transport.finishAuth()` which internally re-runs + * `discoverOAuthMetadata()` and hits the same broken URL construction for + * issuers with path components (see utils/oauth.ts for upstream issue links). + * + * After a successful exchange the tokens are persisted via `saveTokens()` + * so the next `client.connect()` call finds them in SecretStorage and + * connects without another OAuth round-trip. + * + * @param authorizationCode The code received in the OAuth callback redirect. + * @throws When the token endpoint is unknown or the exchange request fails. + */ + async exchangeCodeForTokens(authorizationCode: string): Promise { + if (!this._authServerMeta?.token_endpoint) { + throw new Error("No token_endpoint in auth server metadata — cannot exchange code") + } + if (!this._clientInfo) { + throw new Error("No client information — registerClientIfNeeded() must be called first") + } + + const codeVerifier = await this.codeVerifier() + + // Build the token request body per RFC 6749 §4.1.3 + RFC 7636 §4.5. + const params: Record = { + grant_type: "authorization_code", + code: authorizationCode, + redirect_uri: this.redirectUrl, + client_id: this._clientInfo.client_id, + code_verifier: codeVerifier, + } + + // Include client_secret in the body when the auth method is client_secret_post. + if (this._tokenEndpointAuthMethod === "client_secret_post" && this._clientInfo.client_secret) { + params.client_secret = this._clientInfo.client_secret + } + + const response = await fetch(this._authServerMeta.token_endpoint as string, { + method: "POST", + headers: { + "Content-Type": "application/x-www-form-urlencoded", + Accept: "application/json", + }, + body: new URLSearchParams(params).toString(), + }) + + if (!response.ok) { + throw new Error(`Token exchange failed: HTTP ${response.status}`) + } + + const tokens = (await response.json()) as OAuthTokens + await this.saveTokens(tokens) + } + + /** Close the local callback server. Always call this when done. */ + async close(): Promise { + if (!this._closed) { + this._closed = true + await stopCallbackServer(this._server).catch(() => {}) + } + } +} diff --git a/src/services/mcp/McpServerManager.ts b/src/services/mcp/McpServerManager.ts index 3fd7146d9f9..cc456d7e895 100644 --- a/src/services/mcp/McpServerManager.ts +++ b/src/services/mcp/McpServerManager.ts @@ -1,5 +1,6 @@ import * as vscode from "vscode" import { McpHub } from "./McpHub" +import { SecretStorageService } from "./SecretStorageService" import { ClineProvider } from "../../core/webview/ClineProvider" /** @@ -37,6 +38,9 @@ export class McpServerManager { // Double-check instance in case it was created while we were waiting if (!this.instance) { const hub = new McpHub(provider) + // Set the secret storage service for OAuth operations + const secretStorage = new SecretStorageService(context) + hub.setSecretStorage(secretStorage) // Wait for all MCP servers to finish connecting (or timing out) await hub.waitUntilReady() this.instance = hub diff --git a/src/services/mcp/SecretStorageService.ts b/src/services/mcp/SecretStorageService.ts new file mode 100644 index 00000000000..e320b2ba6d9 --- /dev/null +++ b/src/services/mcp/SecretStorageService.ts @@ -0,0 +1,48 @@ +import * as vscode from "vscode" +import type { OAuthTokens } from "@modelcontextprotocol/sdk/shared/auth.js" + +export interface StoredMcpOAuthData { + tokens: OAuthTokens + /** Unix ms timestamp after which the access token should be considered expired. */ + expires_at: number + /** The client_id used to obtain these tokens (for token reuse without re-registration). */ + client_id?: string + /** The redirect_uri used during client registration (to detect port changes). */ + redirect_uri?: string +} + +/** + * Thin wrapper around VS Code SecretStorage for persisting MCP OAuth tokens. + * Tokens are stored per-server (keyed by host) so different servers on the + * same host share credentials, which is the common case for multi-path APIs. + */ +export class SecretStorageService { + private readonly _storage: vscode.SecretStorage + private readonly _namespace = "mcp.oauth." + + constructor(context: vscode.ExtensionContext) { + this._storage = context.secrets + } + + private _key(serverUrl: string): string { + return `${this._namespace}${new URL(serverUrl).host}.data` + } + + async getOAuthData(serverUrl: string): Promise { + const raw = await this._storage.get(this._key(serverUrl)) + if (!raw) return undefined + try { + return JSON.parse(raw) as StoredMcpOAuthData + } catch { + return undefined + } + } + + async saveOAuthData(serverUrl: string, data: StoredMcpOAuthData): Promise { + await this._storage.store(this._key(serverUrl), JSON.stringify(data)) + } + + async deleteOAuthData(serverUrl: string): Promise { + await this._storage.delete(this._key(serverUrl)) + } +} diff --git a/src/services/mcp/__tests__/McpOAuthClientProvider.spec.ts b/src/services/mcp/__tests__/McpOAuthClientProvider.spec.ts new file mode 100644 index 00000000000..b5863b3dd0a --- /dev/null +++ b/src/services/mcp/__tests__/McpOAuthClientProvider.spec.ts @@ -0,0 +1,615 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" + +// Mock vscode +vi.mock("vscode", () => ({ + window: { + showInformationMessage: vi.fn(), + }, + env: { + openExternal: vi.fn().mockResolvedValue(true), + }, + Uri: { + parse: vi.fn((url: string) => ({ toString: () => url })), + }, +})) + +// Mock callbackServer +vi.mock("../utils/callbackServer", () => ({ + startCallbackServer: vi.fn(), + stopCallbackServer: vi.fn().mockResolvedValue(undefined), +})) + +// Mock fetch for auth discovery so tests don't make real network calls +const mockFetch = vi.fn() +global.fetch = mockFetch + +// Mock SDK auth discovery functions +vi.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ + discoverOAuthProtectedResourceMetadata: vi.fn().mockResolvedValue({ + resource: "https://example.com/", + authorization_servers: ["https://auth.example.com"], + }), +})) + +// Set up fetch mock to return auth metadata with "none" auth method +mockFetch.mockResolvedValue({ + ok: true, + json: () => + Promise.resolve({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + registration_endpoint: "https://auth.example.com/register", + response_types_supported: ["code"], + token_endpoint_auth_methods_supported: ["none"], + grant_types_supported: ["authorization_code", "refresh_token"], + }), +}) + +import { McpOAuthClientProvider } from "../McpOAuthClientProvider" +import { SecretStorageService } from "../SecretStorageService" +import { startCallbackServer, stopCallbackServer } from "../utils/callbackServer" +import { discoverOAuthProtectedResourceMetadata } from "@modelcontextprotocol/sdk/client/auth.js" +import * as vscode from "vscode" + +function createMockSecretStorage(): SecretStorageService { + const store = new Map() + return { + getOAuthData: vi.fn(async (url: string) => { + const raw = store.get(url) + return raw ? JSON.parse(raw) : undefined + }), + saveOAuthData: vi.fn(async (url: string, data: any) => { + store.set(url, JSON.stringify(data)) + }), + deleteOAuthData: vi.fn(async (url: string) => { + store.delete(url) + }), + } as unknown as SecretStorageService +} + +function setupCallbackServerMock(code = "test-auth-code", state?: string) { + const mockServer = { close: vi.fn((cb: () => void) => cb()) } + const resultPromise = Promise.resolve({ code, state }) + ;(startCallbackServer as any).mockResolvedValue({ + server: mockServer, + port: 12345, + result: resultPromise, + }) + return { mockServer, resultPromise } +} + +describe("McpOAuthClientProvider", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe("create", () => { + it("should start a callback server and return a provider", async () => { + setupCallbackServerMock() + + const secretStorage = createMockSecretStorage() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", secretStorage) + + expect(startCallbackServer).toHaveBeenCalledWith(undefined, expect.any(String)) + expect(provider.redirectUrl).toBe("http://localhost:12345/callback") + await provider.close() + }) + }) + + describe("clientMetadata", () => { + it("should return correct metadata with redirect URI", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + const metadata = provider.clientMetadata + + expect(metadata.client_name).toBe("Roo Code") + expect(metadata.redirect_uris).toEqual(["http://localhost:12345/callback"]) + expect(metadata.grant_types).toContain("authorization_code") + expect(metadata.response_types).toContain("code") + expect(metadata.token_endpoint_auth_method).toBe("none") + await provider.close() + }) + + it("should use server name as client_name when provided", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create( + "https://example.com/mcp", + createMockSecretStorage(), + "figma", + ) + + expect(provider.clientMetadata.client_name).toBe("figma") + await provider.close() + }) + }) + + describe("clientInformation / saveClientInformation", () => { + it("should return undefined initially", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + expect(await provider.clientInformation()).toBeUndefined() + await provider.close() + }) + + it("should return saved client info", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + const info = { + client_id: "test-id", + client_secret: "test-secret", + redirect_uris: ["http://localhost:12345/callback"], + } + await provider.saveClientInformation(info as any) + + const result = await provider.clientInformation() + expect(result).toEqual(info) + await provider.close() + }) + }) + + describe("tokens / saveTokens", () => { + it("should return undefined when no tokens stored", async () => { + setupCallbackServerMock() + const secretStorage = createMockSecretStorage() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", secretStorage) + + expect(await provider.tokens()).toBeUndefined() + await provider.close() + }) + + it("should store and return tokens", async () => { + setupCallbackServerMock() + const secretStorage = createMockSecretStorage() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", secretStorage) + + const tokens = { + access_token: "test-token", + token_type: "Bearer", + expires_in: 3600, + } + await provider.saveTokens(tokens) + + const result = await provider.tokens() + expect(result).toEqual(tokens) + await provider.close() + }) + + it("should return undefined for expired tokens", async () => { + setupCallbackServerMock() + const secretStorage = createMockSecretStorage() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", secretStorage) + + // Directly store data with an expires_at in the past so tokens() returns undefined + await secretStorage.saveOAuthData("https://example.com/mcp", { + tokens: { access_token: "expired", token_type: "Bearer" }, + expires_at: Date.now() - 1000, // already expired + }) + + expect(await provider.tokens()).toBeUndefined() + await provider.close() + }) + }) + + describe("codeVerifier / saveCodeVerifier", () => { + it("should throw if no verifier saved", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + await expect(provider.codeVerifier()).rejects.toThrow("No PKCE code verifier saved") + await provider.close() + }) + + it("should round-trip code verifier", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + await provider.saveCodeVerifier("test-verifier-123") + expect(await provider.codeVerifier()).toBe("test-verifier-123") + await provider.close() + }) + }) + + describe("redirectToAuthorization", () => { + it("should open browser with the authorization URL", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + const authUrl = new URL("https://auth.example.com/authorize?client_id=test") + await provider.redirectToAuthorization(authUrl) + + expect(vscode.env.openExternal).toHaveBeenCalled() + expect(vscode.window.showInformationMessage).toHaveBeenCalledWith( + expect.stringContaining("Opening browser for OAuth"), + ) + await provider.close() + }) + + it("should show URL as fallback if browser open fails", async () => { + setupCallbackServerMock() + ;(vscode.env.openExternal as any).mockRejectedValueOnce(new Error("no browser")) + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + const authUrl = new URL("https://auth.example.com/authorize?client_id=test") + await provider.redirectToAuthorization(authUrl) + + expect(vscode.window.showInformationMessage).toHaveBeenCalledWith( + expect.stringContaining("Please open this URL"), + ) + await provider.close() + }) + + it("should correct a wrong authorization URL using pre-fetched metadata", async () => { + // Mock discovery to return an issuer with a path component. + // The SDK's discoverOAuthMetadata builds the wrong URL for such issuers, + // so it typically falls back to a bare /authorize path. + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValueOnce({ + resource: "https://mcp.kapa.ai/", + authorization_servers: ["https://mcp.kapa.ai/auth/public"], + }) + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + issuer: "https://mcp.kapa.ai/auth/public", + authorization_endpoint: "https://mcp.kapa.ai/auth/public/authorize", + token_endpoint: "https://mcp.kapa.ai/auth/public/token", + registration_endpoint: "https://mcp.kapa.ai/auth/public/register", + token_endpoint_auth_methods_supported: ["client_secret_post"], + grant_types_supported: ["authorization_code", "refresh_token"], + scopes_supported: ["openid"], + }), + }) + + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://mcp.kapa.ai/mcp", createMockSecretStorage()) + + // Simulate the SDK building the wrong base URL (using bare /authorize) and omitting scope + const sdkWrongUrl = new URL("https://mcp.kapa.ai/authorize?client_id=abc&code_challenge=xyz&state=123") + await provider.redirectToAuthorization(sdkWrongUrl) + + // The provider should have corrected the URL to use the real authorization_endpoint + const openedUri = (vscode.env.openExternal as any).mock.calls[0][0].toString() + expect(openedUri).toContain("https://mcp.kapa.ai/auth/public/authorize") + expect(openedUri).toContain("client_id=abc") + expect(openedUri).toContain("code_challenge=xyz") + expect(openedUri).toContain("state=123") + // scope should be injected from metadata + expect(openedUri).toContain("scope=openid") + // RFC 8707: resource indicator from protected resource metadata should be injected + expect(openedUri).toContain("resource=") + expect(decodeURIComponent(openedUri)).toContain("resource=https://mcp.kapa.ai/") + await provider.close() + }) + + it("should inject RFC 8707 resource indicator from protected resource metadata", async () => { + // Mock discovery returning a resource indicator (RFC 9728 `resource` field) + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValueOnce({ + resource: "https://temporal.mcp.kapa.ai/", + authorization_servers: ["https://mcp.kapa.ai/auth/public"], + }) + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + issuer: "https://mcp.kapa.ai/auth/public", + authorization_endpoint: "https://mcp.kapa.ai/auth/public/authorize", + token_endpoint: "https://mcp.kapa.ai/auth/public/token", + token_endpoint_auth_methods_supported: ["none"], + grant_types_supported: ["authorization_code"], + scopes_supported: ["openid"], + }), + }) + + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create( + "https://temporal.mcp.kapa.ai/mcp", + createMockSecretStorage(), + ) + + const sdkUrl = new URL("https://mcp.kapa.ai/authorize?client_id=abc&state=123") + await provider.redirectToAuthorization(sdkUrl) + + const openedUri = (vscode.env.openExternal as any).mock.calls[0][0].toString() + // The resource indicator from the protected resource metadata must appear + // as the `resource` query parameter (RFC 8707) + expect(decodeURIComponent(openedUri)).toContain("resource=https://temporal.mcp.kapa.ai/") + await provider.close() + }) + + it("should not duplicate resource if the SDK already included it", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + // SDK URL already contains a resource param + const sdkUrl = new URL( + "https://auth.example.com/authorize?client_id=abc&resource=https%3A%2F%2Fexample.com%2F&state=123", + ) + await provider.redirectToAuthorization(sdkUrl) + + const openedUri = (vscode.env.openExternal as any).mock.calls[0][0].toString() + const resourceMatches = (openedUri.match(/resource=/g) || []).length + expect(resourceMatches).toBe(1) + await provider.close() + }) + + it("should not duplicate scope if the SDK already included it", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + // SDK URL already includes scope=openid + const sdkUrl = new URL("https://auth.example.com/authorize?client_id=abc&scope=openid&state=123") + await provider.redirectToAuthorization(sdkUrl) + + // scope should appear exactly once + const openedUri = (vscode.env.openExternal as any).mock.calls[0][0].toString() + const scopeMatches = (openedUri.match(/scope=/g) || []).length + expect(scopeMatches).toBe(1) + await provider.close() + }) + }) + + describe("exchangeCodeForTokens", () => { + it("should POST to the token_endpoint and save tokens", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValueOnce({ + resource: "https://mcp.kapa.ai/", + authorization_servers: ["https://mcp.kapa.ai/auth/public"], + }) + const tokenResponse = { + access_token: "access-token-xyz", + token_type: "Bearer", + expires_in: 3600, + } + mockFetch + .mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + issuer: "https://mcp.kapa.ai/auth/public", + authorization_endpoint: "https://mcp.kapa.ai/auth/public/authorize", + token_endpoint: "https://mcp.kapa.ai/auth/public/token", + registration_endpoint: "https://mcp.kapa.ai/auth/public/register", + token_endpoint_auth_methods_supported: ["client_secret_post"], + grant_types_supported: ["authorization_code", "refresh_token"], + scopes_supported: ["openid"], + }), + }) + .mockResolvedValueOnce({ + ok: true, + json: () => Promise.resolve(tokenResponse), + }) + + setupCallbackServerMock() + const secretStorage = createMockSecretStorage() + const provider = await McpOAuthClientProvider.create("https://mcp.kapa.ai/mcp", secretStorage) + + // Set up client info and code verifier + await provider.saveClientInformation({ + client_id: "client-id-123", + client_secret: "client-secret-abc", + redirect_uris: ["http://localhost:12345/callback"], + } as any) + await provider.saveCodeVerifier("pkce-verifier-123") + + await provider.exchangeCodeForTokens("auth-code-abc") + + // Verify the token endpoint was called with correct params + const tokenCall = mockFetch.mock.calls[mockFetch.mock.calls.length - 1] + expect(tokenCall[0]).toBe("https://mcp.kapa.ai/auth/public/token") + expect(tokenCall[1].method).toBe("POST") + const body = new URLSearchParams(tokenCall[1].body) + expect(body.get("grant_type")).toBe("authorization_code") + expect(body.get("code")).toBe("auth-code-abc") + expect(body.get("client_id")).toBe("client-id-123") + expect(body.get("client_secret")).toBe("client-secret-abc") + expect(body.get("code_verifier")).toBe("pkce-verifier-123") + expect(body.get("redirect_uri")).toBe("http://localhost:12345/callback") + + // Verify tokens were saved + const saved = await provider.tokens() + expect(saved).toEqual(tokenResponse) + + await provider.close() + }) + + it("should throw when no token_endpoint is available", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValueOnce({ + authorization_servers: ["https://auth.example.com"], + }) + // Return metadata without token_endpoint + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + // no token_endpoint + }), + }) + + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + await provider.saveClientInformation({ client_id: "id", redirect_uris: [] } as any) + await provider.saveCodeVerifier("verifier") + + await expect(provider.exchangeCodeForTokens("code")).rejects.toThrow("No token_endpoint") + await provider.close() + }) + + it("should throw when no client information is available", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + await provider.saveCodeVerifier("verifier") + + // No saveClientInformation called — should throw + await expect(provider.exchangeCodeForTokens("code")).rejects.toThrow("No client information") + await provider.close() + }) + + it("should throw when the token endpoint returns a non-OK response", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValueOnce({ + resource: "https://mcp.kapa.ai/", + authorization_servers: ["https://mcp.kapa.ai/auth/public"], + }) + mockFetch + .mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + issuer: "https://mcp.kapa.ai/auth/public", + authorization_endpoint: "https://mcp.kapa.ai/auth/public/authorize", + token_endpoint: "https://mcp.kapa.ai/auth/public/token", + token_endpoint_auth_methods_supported: ["client_secret_post"], + grant_types_supported: ["authorization_code"], + }), + }) + .mockResolvedValueOnce({ + ok: false, + status: 400, + text: () => Promise.resolve('{"error":"invalid_grant"}'), + }) + + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://mcp.kapa.ai/mcp", createMockSecretStorage()) + + await provider.saveClientInformation({ client_id: "id", redirect_uris: [] } as any) + await provider.saveCodeVerifier("verifier") + + await expect(provider.exchangeCodeForTokens("bad-code")).rejects.toThrow("Token exchange failed: HTTP 400") + await provider.close() + }) + }) + + describe("waitForAuthCode", () => { + it("should resolve with auth code from callback server", async () => { + setupCallbackServerMock("my-code") + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + const code = await provider.waitForAuthCode() + expect(code).toBe("my-code") + await provider.close() + }) + + it("should reject if callback returns error", async () => { + const mockServer = { close: vi.fn((cb: () => void) => cb()) } + ;(startCallbackServer as any).mockResolvedValue({ + server: mockServer, + port: 12345, + result: Promise.resolve({ error: "access_denied" }), + }) + + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + await expect(provider.waitForAuthCode()).rejects.toThrow("OAuth authorization failed: access_denied") + await provider.close() + }) + + it("should reject if callback returns no code", async () => { + const mockServer = { close: vi.fn((cb: () => void) => cb()) } + ;(startCallbackServer as any).mockResolvedValue({ + server: mockServer, + port: 12345, + result: Promise.resolve({}), + }) + + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + await expect(provider.waitForAuthCode()).rejects.toThrow("No authorization code received") + await provider.close() + }) + }) + + describe("close", () => { + it("should stop the callback server", async () => { + const { mockServer } = setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + await provider.close() + + expect(stopCallbackServer).toHaveBeenCalledWith(mockServer) + }) + + it("should be idempotent", async () => { + setupCallbackServerMock() + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", createMockSecretStorage()) + + await provider.close() + await provider.close() + + expect(stopCallbackServer).toHaveBeenCalledTimes(1) + }) + }) + + describe("registerClientIfNeeded", () => { + it("should reuse cached client_id when redirect_uri matches", async () => { + setupCallbackServerMock() + const secretStorage = createMockSecretStorage() + + // Pre-populate storage with cached data + await secretStorage.saveOAuthData("https://example.com/mcp", { + tokens: { access_token: "cached-token", token_type: "Bearer" }, + expires_at: Date.now() + 3600000, + client_id: "cached-client-id", + redirect_uri: "http://localhost:12345/callback", + }) + + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", secretStorage) + await provider.registerClientIfNeeded() + + expect((await provider.clientInformation())?.client_id).toBe("cached-client-id") + await provider.close() + }) + + it("should not reuse cached client_id when redirect_uri does not match", async () => { + setupCallbackServerMock() + const secretStorage = createMockSecretStorage() + + // Clear previous mocks and set up for this test + mockFetch.mockClear() + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + registration_endpoint: "https://auth.example.com/register", + response_types_supported: ["code"], + token_endpoint_auth_methods_supported: ["none"], + grant_types_supported: ["authorization_code", "refresh_token"], + }), + }) + mockFetch.mockResolvedValueOnce({ + ok: true, + json: () => + Promise.resolve({ + client_id: "new-client-id", + redirect_uris: ["http://localhost:12345/callback"], + client_name: "Roo Code", + grant_types: ["authorization_code", "refresh_token"], + response_types: ["code"], + token_endpoint_auth_method: "none", + }), + }) + + // Pre-populate storage with cached data with different redirect_uri + await secretStorage.saveOAuthData("https://example.com/mcp", { + tokens: { access_token: "cached-token", token_type: "Bearer" }, + expires_at: Date.now() + 3600000, + client_id: "cached-client-id", + redirect_uri: "http://localhost:99999/callback", // different port + }) + + const provider = await McpOAuthClientProvider.create("https://example.com/mcp", secretStorage) + await provider.registerClientIfNeeded() + + expect((await provider.clientInformation())?.client_id).toBe("new-client-id") + await provider.close() + }) + }) +}) diff --git a/src/services/mcp/__tests__/SecretStorageService.spec.ts b/src/services/mcp/__tests__/SecretStorageService.spec.ts new file mode 100644 index 00000000000..12709347d0c --- /dev/null +++ b/src/services/mcp/__tests__/SecretStorageService.spec.ts @@ -0,0 +1,102 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" + +vi.mock("vscode", () => ({})) + +import { SecretStorageService, StoredMcpOAuthData } from "../SecretStorageService" + +function createMockContext() { + const store = new Map() + return { + secrets: { + get: vi.fn(async (key: string) => store.get(key)), + store: vi.fn(async (key: string, value: string) => { + store.set(key, value) + }), + delete: vi.fn(async (key: string) => { + store.delete(key) + }), + }, + } as any +} + +describe("SecretStorageService", () => { + let service: SecretStorageService + let context: ReturnType + + beforeEach(() => { + context = createMockContext() + service = new SecretStorageService(context) + }) + + describe("getOAuthData", () => { + it("should return undefined when no data stored", async () => { + const result = await service.getOAuthData("https://example.com/mcp") + expect(result).toBeUndefined() + }) + + it("should return stored data", async () => { + const data: StoredMcpOAuthData = { + tokens: { access_token: "tok", token_type: "Bearer" }, + expires_at: Date.now() + 3600_000, + } + await service.saveOAuthData("https://example.com/mcp", data) + + const result = await service.getOAuthData("https://example.com/mcp") + expect(result).toEqual(data) + }) + + it("should return undefined for malformed JSON", async () => { + // Manually store garbage via the underlying mock + context.secrets.store("mcp.oauth.example.com.data", "not-json") + + const result = await service.getOAuthData("https://example.com/mcp") + expect(result).toBeUndefined() + }) + }) + + describe("saveOAuthData", () => { + it("should persist data under host-based key", async () => { + const data: StoredMcpOAuthData = { + tokens: { access_token: "abc", token_type: "Bearer" }, + expires_at: 12345, + } + await service.saveOAuthData("https://example.com/mcp", data) + + expect(context.secrets.store).toHaveBeenCalledWith("mcp.oauth.example.com.data", JSON.stringify(data)) + }) + }) + + describe("deleteOAuthData", () => { + it("should delete stored data", async () => { + const data: StoredMcpOAuthData = { + tokens: { access_token: "tok", token_type: "Bearer" }, + expires_at: Date.now() + 3600_000, + } + await service.saveOAuthData("https://example.com/mcp", data) + + await service.deleteOAuthData("https://example.com/mcp") + + expect(context.secrets.delete).toHaveBeenCalledWith("mcp.oauth.example.com.data") + const result = await service.getOAuthData("https://example.com/mcp") + expect(result).toBeUndefined() + }) + }) + + describe("key isolation", () => { + it("should isolate data by host", async () => { + const data1: StoredMcpOAuthData = { + tokens: { access_token: "a", token_type: "Bearer" }, + expires_at: 1, + } + const data2: StoredMcpOAuthData = { + tokens: { access_token: "b", token_type: "Bearer" }, + expires_at: 2, + } + await service.saveOAuthData("https://host1.com/mcp", data1) + await service.saveOAuthData("https://host2.com/mcp", data2) + + expect((await service.getOAuthData("https://host1.com/mcp"))?.tokens.access_token).toBe("a") + expect((await service.getOAuthData("https://host2.com/mcp"))?.tokens.access_token).toBe("b") + }) + }) +}) diff --git a/src/services/mcp/utils/__tests__/callbackServer.spec.ts b/src/services/mcp/utils/__tests__/callbackServer.spec.ts new file mode 100644 index 00000000000..3159ffbae22 --- /dev/null +++ b/src/services/mcp/utils/__tests__/callbackServer.spec.ts @@ -0,0 +1,101 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" +import { startCallbackServer, stopCallbackServer } from "../callbackServer" +import * as http from "http" + +vi.mock("http", () => ({ + createServer: vi.fn(), +})) + +describe("startCallbackServer", () => { + beforeEach(() => { + vi.restoreAllMocks() + }) + + it("should start server and resolve with callback result", async () => { + const mockServer = { + listen: vi.fn((port, host, callback) => { + callback() + return mockServer + }), + address: vi.fn(() => ({ port: 3000 })), + on: vi.fn(), + close: vi.fn(), + } + + ;(http.createServer as any).mockReturnValue(mockServer) + + const promise = startCallbackServer() + const { server, port, result } = await promise + + expect(port).toBe(3000) + expect(server).toBe(mockServer) + + // Simulate callback request + const requestCall = mockServer.on.mock.calls.find((call) => call[0] === "request") + const requestHandler = requestCall ? requestCall[1] : vi.fn() + const mockReq = { + url: "/callback?code=test-code&state=test-state", + method: "GET", + } + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + on: vi.fn((event, cb) => { + if (event === "finish") setImmediate(cb) + }), + } + + requestHandler(mockReq, mockRes) + + const callbackResult = await result + expect(callbackResult.code).toBe("test-code") + expect(callbackResult.state).toBe("test-state") + }) + + it("should reject invalid state", async () => { + const mockServer = { + listen: vi.fn((port, host, callback) => { + callback() + return mockServer + }), + address: vi.fn(() => ({ port: 3000 })), + on: vi.fn(), + close: vi.fn(), + } + + ;(http.createServer as any).mockReturnValue(mockServer) + + const promise = startCallbackServer(undefined, "expected-state") + const { result } = await promise + + // Simulate callback request with wrong state + const requestCall = mockServer.on.mock.calls.find((call) => call[0] === "request") + const requestHandler = requestCall ? requestCall[1] : vi.fn() + const mockReq = { + url: "/callback?code=test-code&state=wrong-state", + method: "GET", + } + const mockRes = { + writeHead: vi.fn(), + end: vi.fn(), + on: vi.fn((event, cb) => { + if (event === "finish") setImmediate(cb) + }), + } + + requestHandler(mockReq, mockRes) + + await expect(result).rejects.toThrow("Invalid state parameter") + }) +}) + +describe("stopCallbackServer", () => { + it("should close the server", async () => { + const mockServer = { + close: vi.fn((callback) => callback()), + } + + await stopCallbackServer(mockServer as any) + expect(mockServer.close).toHaveBeenCalled() + }) +}) diff --git a/src/services/mcp/utils/__tests__/oauth.spec.ts b/src/services/mcp/utils/__tests__/oauth.spec.ts new file mode 100644 index 00000000000..e2242c8ef87 --- /dev/null +++ b/src/services/mcp/utils/__tests__/oauth.spec.ts @@ -0,0 +1,148 @@ +import { describe, it, expect, vi, beforeEach } from "vitest" + +// Mock the SDK's discoverOAuthProtectedResourceMetadata +vi.mock("@modelcontextprotocol/sdk/client/auth.js", () => ({ + discoverOAuthProtectedResourceMetadata: vi.fn(), +})) + +import { discoverOAuthProtectedResourceMetadata } from "@modelcontextprotocol/sdk/client/auth.js" +import { fetchOAuthAuthServerMetadata } from "../oauth" + +const mockFetch = vi.fn() +global.fetch = mockFetch + +describe("fetchOAuthAuthServerMetadata", () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it("returns null when resource metadata has no authorization_servers", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValue({ + resource: "https://example.com/", + authorization_servers: [], + }) + + const result = await fetchOAuthAuthServerMetadata("https://example.com/mcp") + expect(result).toBeNull() + }) + + it("returns null when discoverOAuthProtectedResourceMetadata throws", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockRejectedValue(new Error("network error")) + + const result = await fetchOAuthAuthServerMetadata("https://example.com/mcp") + expect(result).toBeNull() + }) + + it("constructs the RFC 8414 discovery URL correctly for an issuer with a path", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValue({ + resource: "https://mcp.kapa.ai/", + authorization_servers: ["https://mcp.kapa.ai/auth/public"], + }) + + const mockMeta = { + issuer: "https://mcp.kapa.ai/auth/public", + registration_endpoint: "https://mcp.kapa.ai/auth/public/register", + } + mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve(mockMeta) }) + + const result = await fetchOAuthAuthServerMetadata("https://mcp.kapa.ai/mcp") + + // Verify the RFC 8414 §3.1 URL: well-known inserted between host and path + expect(mockFetch).toHaveBeenCalledWith( + "https://mcp.kapa.ai/.well-known/oauth-authorization-server/auth/public", + expect.objectContaining({ headers: { Accept: "application/json" } }), + ) + expect(result).toEqual({ authServerMeta: mockMeta, resourceIndicator: "https://mcp.kapa.ai/" }) + }) + + it("constructs the RFC 8414 discovery URL correctly for an issuer without a path", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValue({ + resource: "https://auth.example.com/", + authorization_servers: ["https://auth.example.com"], + }) + + const mockMeta = { issuer: "https://auth.example.com" } + mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve(mockMeta) }) + + await fetchOAuthAuthServerMetadata("https://auth.example.com/mcp") + + expect(mockFetch).toHaveBeenCalledWith( + "https://auth.example.com/.well-known/oauth-authorization-server", + expect.any(Object), + ) + }) + + it("strips trailing slash from issuer path before inserting well-known", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValue({ + resource: "https://example.com/", + authorization_servers: ["https://example.com/issuer/"], + }) + + mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve({}) }) + + await fetchOAuthAuthServerMetadata("https://example.com/mcp") + + expect(mockFetch).toHaveBeenCalledWith( + "https://example.com/.well-known/oauth-authorization-server/issuer", + expect.any(Object), + ) + }) + + it("returns null when the discovery endpoint returns a non-OK response", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValue({ + resource: "https://example.com/", + authorization_servers: ["https://auth.example.com"], + }) + + mockFetch.mockResolvedValueOnce({ ok: false, status: 404 }) + + const result = await fetchOAuthAuthServerMetadata("https://example.com/mcp") + expect(result).toBeNull() + }) + + it("returns null when fetch throws", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValue({ + resource: "https://example.com/", + authorization_servers: ["https://auth.example.com"], + }) + + mockFetch.mockRejectedValueOnce(new Error("connection refused")) + + const result = await fetchOAuthAuthServerMetadata("https://example.com/mcp") + expect(result).toBeNull() + }) + + it("returns the parsed metadata and resource indicator on success", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValue({ + resource: "https://example.com/", + authorization_servers: ["https://auth.example.com/oauth2"], + }) + + const meta = { + issuer: "https://auth.example.com/oauth2", + authorization_endpoint: "https://auth.example.com/oauth2/authorize", + token_endpoint: "https://auth.example.com/oauth2/token", + registration_endpoint: "https://auth.example.com/oauth2/register", + token_endpoint_auth_methods_supported: ["client_secret_post", "client_secret_basic"], + grant_types_supported: ["authorization_code", "refresh_token"], + scopes_supported: ["openid"], + } + mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve(meta) }) + + const result = await fetchOAuthAuthServerMetadata("https://example.com/mcp") + expect(result).toEqual({ authServerMeta: meta, resourceIndicator: "https://example.com/" }) + }) + + it("returns null resourceIndicator when protected resource metadata has no resource field", async () => { + ;(discoverOAuthProtectedResourceMetadata as any).mockResolvedValue({ + authorization_servers: ["https://auth.example.com"], + // no `resource` field + }) + + const meta = { issuer: "https://auth.example.com" } + mockFetch.mockResolvedValueOnce({ ok: true, json: () => Promise.resolve(meta) }) + + const result = await fetchOAuthAuthServerMetadata("https://example.com/mcp") + expect(result).toEqual({ authServerMeta: meta, resourceIndicator: null }) + }) +}) diff --git a/src/services/mcp/utils/callbackServer.ts b/src/services/mcp/utils/callbackServer.ts new file mode 100644 index 00000000000..ecc266f3946 --- /dev/null +++ b/src/services/mcp/utils/callbackServer.ts @@ -0,0 +1,197 @@ +import * as http from "http" + +export interface CallbackResult { + code?: string + error?: string + error_description?: string + state?: string +} + +/** + * Starts a local HTTP server to handle OAuth callback. + * @param port Optional port to use (defaults to random available port) + * @param expectedState Optional expected state for CSRF protection + * @returns Promise<{server: http.Server, port: number, result: Promise}> + */ +export function startCallbackServer( + port?: number, + expectedState?: string, +): Promise<{ + server: http.Server + port: number + result: Promise +}> { + // In test mode, immediately resolve with mock data + if (process.env.MCP_OAUTH_TEST_MODE === "true") { + return new Promise((resolve) => { + const mockServer = http.createServer() + resolve({ + server: mockServer, + port: 3000, + result: Promise.resolve({ code: "test-auth-code", state: expectedState }), + }) + }) + } + + return new Promise((resolve, reject) => { + const server = http.createServer() + + server.listen(port || 0, "127.0.0.1", () => { + const address = server.address() + if (!address || typeof address === "string") { + reject(new Error("Failed to get server address")) + return + } + + const actualPort = address.port + + const resultPromise = new Promise((resolveResult, rejectResult) => { + let resolved = false + + const timeout = setTimeout( + () => { + if (!resolved) { + resolved = true + rejectResult(new Error("Callback timeout")) + server.close() + } + }, + 5 * 60 * 1000, + ) // 5 minutes + + server.on("request", (req: any, res: any) => { + if (resolved) return + + const url = new URL(req.url || "", `http://localhost:${actualPort}`) + const pathname = url.pathname + + if (pathname === "/callback") { + resolved = true + clearTimeout(timeout) + + const code = url.searchParams.get("code") + const error = url.searchParams.get("error") + const errorDescription = url.searchParams.get("error_description") + const state = url.searchParams.get("state") + + // Verify state for CSRF protection + if (expectedState && state !== expectedState) { + res.writeHead(400, { "Content-Type": "text/html" }) + res.end(` + + + + OAuth Callback + + +

OAuth Authentication Failed

+

Error: Invalid state parameter

+ + + `) + rejectResult(new Error("Invalid state parameter")) + return + } + + // Send HTML response + res.writeHead(200, { "Content-Type": "text/html" }) + res.end(` + + + + OAuth Callback - Roo Code + + + + +

${error ? "Failed" : "Success!"}

+
+

+ ${ + error + ? "Authentication failed. Please check the MCP server logs." + : "MCP server authenticated successfully. You can now close this browser tab." + } +

+
The server connection is complete.
+ + + + `) + + resolveResult({ + code: code || undefined, + error: error || undefined, + error_description: errorDescription || undefined, + state: state || undefined, + }) + + // Close server immediately after response drains + res.on("finish", () => { + server.close() + }) + } else { + res.writeHead(404) + res.end("Not found") + } + }) + + server.on("error", (error: any) => { + if (!resolved) { + resolved = true + clearTimeout(timeout) + rejectResult(error) + } + }) + }) + + resolve({ + server, + port: actualPort, + result: resultPromise, + }) + }) + + server.on("error", reject) + }) +} + +/** + * Stops the callback server. + * @param server The HTTP server to stop + */ +export function stopCallbackServer(server: http.Server): Promise { + return new Promise((resolve) => { + server.close(() => resolve()) + }) +} diff --git a/src/services/mcp/utils/oauth.ts b/src/services/mcp/utils/oauth.ts new file mode 100644 index 00000000000..8bdf2452ce6 --- /dev/null +++ b/src/services/mcp/utils/oauth.ts @@ -0,0 +1,78 @@ +import { discoverOAuthProtectedResourceMetadata } from "@modelcontextprotocol/sdk/client/auth.js" + +/** + * Result of a successful OAuth discovery for an MCP server. + */ +export interface OAuthDiscoveryResult { + /** The raw OAuth Authorization Server metadata (RFC 8414). */ + authServerMeta: Record + /** + * The RFC 8707 resource indicator — the `resource` field from the Protected + * Resource Metadata (RFC 9728). `null` when the server didn't advertise one. + * + * Must be sent as the `resource` query parameter in authorization requests so + * the auth server can scope the issued tokens to this specific resource server. + */ + resourceIndicator: string | null +} + +/** + * Fetches the raw OAuth Authorization Server metadata for an MCP server URL. + * + * This replaces the SDK's built-in `discoverOAuthMetadata()` because it + * constructs the RFC 8414 well-known URL incorrectly for auth servers with + * path components — a known bug tracked in multiple upstream issues: + * + * - https://github.com/modelcontextprotocol/typescript-sdk/issues/545 + * (URL constructor discards base path with leading-slash well-known) + * - https://github.com/modelcontextprotocol/typescript-sdk/issues/762 + * (uses MCP server URL instead of authorization server URL) + * - https://github.com/modelcontextprotocol/typescript-sdk/issues/744 + * (doesn't respect provided authorization server URL) + * - https://github.com/modelcontextprotocol/typescript-sdk/issues/822 + * (general RFC 8414 compliance — affects Keycloak, Okta, Azure Entra) + * + * Performs two discovery steps: + * 1. RFC 9728 – fetches the Protected Resource Metadata to find the issuer URL + * and the RFC 8707 resource indicator. + * 2. RFC 8414 §3.1 – constructs the well-known discovery URL by inserting + * `/.well-known/oauth-authorization-server` *between* the host and the issuer + * path (not appended after the path). + * + * Correct: https://example.com/.well-known/oauth-authorization-server/auth/public + * SDK wrong: https://example.com/auth/public/.well-known/oauth-authorization-server + * + * Returns an {@link OAuthDiscoveryResult} on success, or `null` if any step fails. + */ +export async function fetchOAuthAuthServerMetadata(serverUrl: string): Promise { + try { + // Step 1 – RFC 9728: resolve the authorization server issuer URL and + // capture the resource indicator for RFC 8707. + const resourceMeta = await discoverOAuthProtectedResourceMetadata(serverUrl) + const authServers = resourceMeta.authorization_servers + if (!authServers?.length) return null + + // RFC 8707: the `resource` field from the protected resource metadata is + // used as the `resource` parameter in the authorization request so the auth + // server can issue tokens scoped to this specific resource server. + const resourceIndicator: string | null = + typeof resourceMeta.resource === "string" ? resourceMeta.resource : null + + // Step 2 – RFC 8414 §3.1: build the well-known URL. + // For issuer "https://example.com/auth/public" + // → "https://example.com/.well-known/oauth-authorization-server/auth/public" + const parsed = new URL(authServers[0]) + const base = `${parsed.protocol}//${parsed.host}` + const issuePath = parsed.pathname.replace(/\/$/, "") || "" + const discoveryUrl = `${base}/.well-known/oauth-authorization-server${issuePath}` + + const response = await fetch(discoveryUrl, { + headers: { Accept: "application/json" }, + }) + if (!response.ok) return null + const authServerMeta = (await response.json()) as Record + return { authServerMeta, resourceIndicator } + } catch { + return null + } +}