Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ async def _exchange_token_authorization_code(

async def _handle_token_response(self, response: httpx.Response) -> None:
"""Handle token exchange response."""
if response.status_code != 200:
if response.status_code not in {200, 201}:
body = await response.aread()
body = body.decode("utf-8")
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}")
Expand Down
110 changes: 110 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,116 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth(
# Verify exactly one request was yielded (no double-sending)
assert request_yields == 1, f"Expected 1 request yield, got {request_yields}"

@pytest.mark.anyio
async def test_token_exchange_accepts_201_status(
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage
):
"""Test that token exchange accepts both 200 and 201 status codes."""
# Ensure no tokens are stored
oauth_provider.context.current_tokens = None
oauth_provider.context.token_expiry_time = None
oauth_provider._initialized = True

# Create a test request
test_request = httpx.Request("GET", "https://api.example.com/mcp")

# Mock the auth flow
auth_flow = oauth_provider.async_auth_flow(test_request)

# First request should be the original request without auth header
request = await auth_flow.__anext__()
assert "Authorization" not in request.headers

# Send a 401 response to trigger the OAuth flow
response = httpx.Response(
401,
headers={
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
},
request=test_request,
)

# Next request should be to discover protected resource metadata
discovery_request = await auth_flow.asend(response)
assert discovery_request.method == "GET"
assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"

# Send a successful discovery response with minimal protected resource metadata
discovery_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
request=discovery_request,
)

# Next request should be to discover OAuth metadata
oauth_metadata_request = await auth_flow.asend(discovery_response)
assert oauth_metadata_request.method == "GET"
assert str(oauth_metadata_request.url).startswith("https://auth.example.com/")
assert "mcp-protocol-version" in oauth_metadata_request.headers

# Send a successful OAuth metadata response
oauth_metadata_response = httpx.Response(
200,
content=(
b'{"issuer": "https://auth.example.com", '
b'"authorization_endpoint": "https://auth.example.com/authorize", '
b'"token_endpoint": "https://auth.example.com/token", '
b'"registration_endpoint": "https://auth.example.com/register"}'
),
request=oauth_metadata_request,
)

# Next request should be to register client
registration_request = await auth_flow.asend(oauth_metadata_response)
assert registration_request.method == "POST"
assert str(registration_request.url) == "https://auth.example.com/register"

# Send a successful registration response with 201 status
registration_response = httpx.Response(
201,
content=b'{"client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_uris": ["http://localhost:3030/callback"]}',
request=registration_request,
)

# Mock the authorization process
oauth_provider._perform_authorization_code_grant = mock.AsyncMock(
return_value=("test_auth_code", "test_code_verifier")
)

# Next request should be to exchange token
token_request = await auth_flow.asend(registration_response)
assert token_request.method == "POST"
assert str(token_request.url) == "https://auth.example.com/token"
assert "code=test_auth_code" in token_request.content.decode()

# Send a successful token response with 201 status code (test both 200 and 201 are accepted)
token_response = httpx.Response(
201,
content=(
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
b'"refresh_token": "new_refresh_token"}'
),
request=token_request,
)

# Final request should be the original request with auth header
final_request = await auth_flow.asend(token_response)
assert final_request.headers["Authorization"] == "Bearer new_access_token"
assert final_request.method == "GET"
assert str(final_request.url) == "https://api.example.com/mcp"

# Send final success response to properly close the generator
final_response = httpx.Response(200, request=final_request)
try:
await auth_flow.asend(final_response)
except StopAsyncIteration:
pass # Expected - generator should complete

# Verify tokens were stored
assert oauth_provider.context.current_tokens is not None
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
assert oauth_provider.context.token_expiry_time is not None

@pytest.mark.anyio
async def test_403_insufficient_scope_updates_scope_from_header(
self,
Expand Down