diff --git a/pyproject.toml b/pyproject.toml index 737839a23..34ee58d84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "uvicorn>=0.31.1; sys_platform != 'emscripten'", "jsonschema>=4.20.0", "pywin32>=311; sys_platform == 'win32'", + "authlib>=1.4.0", "pyjwt[crypto]>=2.10.1", "typing-extensions>=4.13.0", "typing-inspection>=0.4.1", diff --git a/src/mcp/client/auth/__init__.py b/src/mcp/client/auth/__init__.py index ab3179ecb..7be58d617 100644 --- a/src/mcp/client/auth/__init__.py +++ b/src/mcp/client/auth/__init__.py @@ -3,6 +3,7 @@ Implements authorization code flow with PKCE and automatic token refresh. """ +from mcp.client.auth.authlib_adapter import AuthlibAdapterConfig, AuthlibOAuthAdapter from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError from mcp.client.auth.oauth2 import ( OAuthClientProvider, @@ -11,6 +12,8 @@ ) __all__ = [ + "AuthlibAdapterConfig", + "AuthlibOAuthAdapter", "OAuthClientProvider", "OAuthFlowError", "OAuthRegistrationError", diff --git a/src/mcp/client/auth/authlib_adapter.py b/src/mcp/client/auth/authlib_adapter.py new file mode 100644 index 000000000..f59b047e5 --- /dev/null +++ b/src/mcp/client/auth/authlib_adapter.py @@ -0,0 +1,307 @@ +"""Authlib-backed OAuth2 adapter for MCP HTTPX integration. + +Provides :class:`AuthlibOAuthAdapter`, an ``httpx.Auth`` plugin that wraps +``authlib.integrations.httpx_client.AsyncOAuth2Client`` to handle token +acquisition, automatic refresh, and Bearer-header injection. + +The adapter is a drop-in replacement for :class:`~mcp.client.auth.OAuthClientProvider` +when you already have OAuth endpoints and credentials (i.e. no MCP-specific +metadata discovery is needed). For full MCP discovery (PRM / OASM / DCR), +continue to use :class:`~mcp.client.auth.OAuthClientProvider`. + +Supported grant types in this release: +- ``client_credentials`` — fully self-contained (no browser interaction) +- ``authorization_code`` + PKCE — requires *redirect_handler* / *callback_handler* + +Example (client_credentials):: + + from mcp.client.auth import AuthlibAdapterConfig, AuthlibOAuthAdapter + + config = AuthlibAdapterConfig( + token_endpoint="https://auth.example.com/token", + client_id="my-client", + client_secret="secret", + scopes=["read", "write"], + ) + adapter = AuthlibOAuthAdapter(config=config, storage=InMemoryTokenStorage()) + async with httpx.AsyncClient(auth=adapter) as client: + resp = await client.get("https://api.example.com/resource") +""" + +from __future__ import annotations + +import logging +import secrets +import string +from collections.abc import AsyncGenerator, Awaitable, Callable +from typing import Any, Protocol + +import anyio +import httpx +from authlib.integrations.httpx_client import AsyncOAuth2Client # type: ignore[import-untyped] +from pydantic import BaseModel, Field + +from mcp.client.auth.exceptions import OAuthFlowError +from mcp.client.auth.oauth2 import TokenStorage +from mcp.shared.auth import OAuthToken + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Internal protocol — typed interface for untyped Authlib client +# --------------------------------------------------------------------------- + + +class _AsyncOAuth2ClientProtocol(Protocol): + """Minimal typed interface for authlib.integrations.httpx_client.AsyncOAuth2Client. + + Defined as a Protocol so that pyright strict mode can type-check all member + accesses on the Authlib client without requiring upstream type stubs. + """ + + token: dict[str, Any] | None + scope: str | None + code_challenge_method: str + + async def fetch_token(self, url: str, **kwargs: Any) -> dict[str, Any]: ... # pragma: no cover + + def create_authorization_url(self, url: str, **kwargs: Any) -> tuple[str, str]: ... # pragma: no cover + + async def ensure_active_token(self, token: dict[str, Any]) -> None: ... # pragma: no cover + + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +class AuthlibAdapterConfig(BaseModel): + """Configuration for :class:`AuthlibOAuthAdapter`. + + Args: + token_endpoint: URL of the OAuth 2.0 token endpoint (required). + client_id: OAuth client identifier (required). + client_secret: OAuth client secret; omit for public clients. + scopes: List of OAuth scopes to request. + token_endpoint_auth_method: How to authenticate at the token endpoint. + Accepted values: ``"client_secret_basic"`` (default), + ``"client_secret_post"``, ``"none"``. + authorization_endpoint: URL of the authorization endpoint. When set, + the adapter uses the *authorization_code + PKCE* grant on 401; when + ``None`` (default) it uses *client_credentials*. + redirect_uri: Redirect URI registered with the authorization server. + Required when *authorization_endpoint* is set. + leeway: Seconds before token expiry at which automatic refresh is + triggered (default: 60). + extra_token_params: Additional key-value pairs forwarded verbatim to + every ``fetch_token`` call (e.g. ``{"audience": "..."}``). + """ + + token_endpoint: str + client_id: str + client_secret: str | None = Field(default=None, repr=False) # excluded from repr to prevent secret leakage + scopes: list[str] | None = None + token_endpoint_auth_method: str = "client_secret_basic" + # authorization_code flow (optional) + authorization_endpoint: str | None = None + redirect_uri: str | None = None + # Authlib tuning + leeway: int = 60 + extra_token_params: dict[str, Any] | None = None + + +# --------------------------------------------------------------------------- +# Adapter +# --------------------------------------------------------------------------- + + +class AuthlibOAuthAdapter(httpx.Auth): + """Authlib-backed ``httpx.Auth`` provider. + + Wraps :class:`authlib.integrations.httpx_client.AsyncOAuth2Client` as a + drop-in ``httpx.Auth`` plugin. Token storage is delegated to the same + :class:`~mcp.client.auth.TokenStorage` protocol used by the existing + :class:`~mcp.client.auth.OAuthClientProvider`. + + Args: + config: Adapter configuration (endpoints, credentials, scopes …). + storage: Token persistence implementation. + redirect_handler: Async callback that receives the authorization URL + and opens it (browser, print, etc.). Required for + *authorization_code* flow. + callback_handler: Async callback that waits for the user to complete + authorization and returns ``(code, state)``. Required for + *authorization_code* flow. + """ + + requires_response_body = True + + def __init__( + self, + config: AuthlibAdapterConfig, + storage: TokenStorage, + redirect_handler: Callable[[str], Awaitable[None]] | None = None, + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, + ) -> None: + self.config = config + self.storage = storage + self.redirect_handler = redirect_handler + self.callback_handler = callback_handler + self._lock: anyio.Lock = anyio.Lock() + self._initialized: bool = False + + scope_str = " ".join(config.scopes) if config.scopes else None + self._client: _AsyncOAuth2ClientProtocol = AsyncOAuth2Client( # type: ignore[assignment] + client_id=config.client_id, + client_secret=config.client_secret, + scope=scope_str, + redirect_uri=config.redirect_uri, + token_endpoint_auth_method=config.token_endpoint_auth_method, + update_token=self._on_token_update, + leeway=config.leeway, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + async def _on_token_update( + self, + token: dict[str, Any], + refresh_token: str | None = None, # noqa: ARG002 (Authlib callback signature) + access_token: str | None = None, # noqa: ARG002 + ) -> None: + """Authlib ``update_token`` callback — persists refreshed tokens.""" + oauth_token = OAuthToken( + access_token=token["access_token"], + token_type=token.get("token_type", "Bearer"), + expires_in=token.get("expires_in"), + scope=token.get("scope"), + refresh_token=token.get("refresh_token"), + ) + await self.storage.set_tokens(oauth_token) + + async def _initialize(self) -> None: + """Load persisted tokens into the Authlib client on first use.""" + stored = await self.storage.get_tokens() + if stored: + token_dict: dict[str, Any] = { + "access_token": stored.access_token, + "token_type": stored.token_type, + } + if stored.refresh_token is not None: + token_dict["refresh_token"] = stored.refresh_token + if stored.scope is not None: + token_dict["scope"] = stored.scope + if stored.expires_in is not None: + token_dict["expires_in"] = stored.expires_in + self._client.token = token_dict + self._initialized = True + + def _build_token_request_params(self) -> dict[str, Any]: + """Merge base params with any extra params from config.""" + params: dict[str, Any] = {} + if self.config.extra_token_params: + params.update(self.config.extra_token_params) + return params + + async def _fetch_client_credentials_token(self) -> None: + """Acquire a token via the *client_credentials* grant.""" + params = self._build_token_request_params() + await self._client.fetch_token( + self.config.token_endpoint, + grant_type="client_credentials", + **params, + ) + if self._client.token: + await self._on_token_update(dict(self._client.token)) + + async def _perform_authorization_code_flow(self) -> None: + """Acquire a token via *authorization_code + PKCE* grant. + + Raises: + OAuthFlowError: If *redirect_handler*, *callback_handler*, + *authorization_endpoint*, or *redirect_uri* are missing. + """ + if not self.config.authorization_endpoint: + raise OAuthFlowError("authorization_endpoint is required for authorization_code flow") + if not self.config.redirect_uri: + raise OAuthFlowError("redirect_uri is required for authorization_code flow") + if self.redirect_handler is None: + raise OAuthFlowError("redirect_handler is required for authorization_code flow") + if self.callback_handler is None: + raise OAuthFlowError("callback_handler is required for authorization_code flow") + + # Generate PKCE state + build authorization URL via Authlib + state = secrets.token_urlsafe(32) + # Authlib generates code_verifier/code_challenge internally when + # code_challenge_method is set on the client. + self._client.code_challenge_method = "S256" + # Generate a random code_verifier (Authlib will compute the challenge) + code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) + + auth_url, _ = self._client.create_authorization_url( + self.config.authorization_endpoint, + state=state, + code_verifier=code_verifier, + ) + + await self.redirect_handler(auth_url) + auth_code, returned_state = await self.callback_handler() + + if returned_state is None or not secrets.compare_digest(returned_state, state): + raise OAuthFlowError(f"State mismatch: {returned_state!r} != {state!r}") + if not auth_code: + raise OAuthFlowError("No authorization code received from callback") + + params = self._build_token_request_params() + await self._client.fetch_token( + self.config.token_endpoint, + grant_type="authorization_code", + code=auth_code, + redirect_uri=self.config.redirect_uri, + code_verifier=code_verifier, + **params, + ) + if self._client.token: + await self._on_token_update(dict(self._client.token)) + + def _inject_bearer(self, request: httpx.Request) -> None: + """Add ``Authorization: Bearer `` header if a token is held.""" + token = self._client.token + if token and token.get("access_token"): + request.headers["Authorization"] = f"Bearer {token['access_token']}" + + # ------------------------------------------------------------------ + # httpx.Auth entry point + # ------------------------------------------------------------------ + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: + """HTTPX auth flow: ensure a valid token then inject it into the request. + + On a ``401`` response the adapter acquires a fresh token (via + *client_credentials* or *authorization_code*) and retries once. + """ + async with self._lock: + if not self._initialized: + await self._initialize() + + # Let Authlib auto-refresh if the token is close to expiry + if self._client.token: + await self._client.ensure_active_token(self._client.token) + + self._inject_bearer(request) + + response = yield request + + if response.status_code == 401: + async with self._lock: + # Acquire a brand-new token + if self.config.authorization_endpoint: + await self._perform_authorization_code_flow() + else: + await self._fetch_client_credentials_token() + self._inject_bearer(request) + + yield request diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 0ca36b98d..f179213f9 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -4,7 +4,7 @@ from httpx import Request, Response from pydantic import AnyUrl, ValidationError -from mcp.client.auth import OAuthRegistrationError, OAuthTokenError +from mcp.client.auth.exceptions import OAuthRegistrationError, OAuthTokenError from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( OAuthClientInformationFull, diff --git a/tests/client/auth/__init__.py b/tests/client/auth/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/client/auth/test_authlib_adapter.py b/tests/client/auth/test_authlib_adapter.py new file mode 100644 index 000000000..27ce625c6 --- /dev/null +++ b/tests/client/auth/test_authlib_adapter.py @@ -0,0 +1,684 @@ +"""Tests for AuthlibOAuthAdapter and AuthlibAdapterConfig. + +Follows codebase conventions: +- Function-based tests (no Test-prefixed classes) +- @pytest.mark.anyio for all async tests +- Mocks via unittest.mock; never a fixed sleep for async waits +- 100% branch coverage target +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from mcp.client.auth import AuthlibAdapterConfig, AuthlibOAuthAdapter +from mcp.client.auth.exceptions import OAuthFlowError +from mcp.shared.auth import OAuthToken + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +class _InMemoryStorage: + """Minimal in-memory TokenStorage for tests.""" + + def __init__(self, initial: OAuthToken | None = None) -> None: + self._token = initial + self._client_info = None + + async def get_tokens(self) -> OAuthToken | None: + return self._token + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._token = tokens + + async def get_client_info(self) -> None: + return self._client_info # pragma: no cover + + async def set_client_info(self, client_info: Any) -> None: # pragma: no cover + self._client_info = client_info + + +def _make_config(**kwargs: Any) -> AuthlibAdapterConfig: + defaults: dict[str, Any] = { + "token_endpoint": "https://auth.example.com/token", + "client_id": "test-client", + "client_secret": "test-secret", + } + defaults.update(kwargs) + return AuthlibAdapterConfig(**defaults) + + +def _make_adapter(config: AuthlibAdapterConfig | None = None, **kwargs: Any) -> AuthlibOAuthAdapter: + return AuthlibOAuthAdapter( + config=config or _make_config(), + storage=_InMemoryStorage(), + **kwargs, + ) + + +def _mock_response(status_code: int = 200) -> httpx.Response: + return httpx.Response(status_code, request=httpx.Request("GET", "https://api.example.com/")) + + +# --------------------------------------------------------------------------- +# AuthlibAdapterConfig tests +# --------------------------------------------------------------------------- + + +def test_config_defaults() -> None: + """Default values are applied correctly.""" + cfg = AuthlibAdapterConfig(token_endpoint="https://t.example.com/token", client_id="cid") + assert cfg.client_secret is None + assert cfg.scopes is None + assert cfg.token_endpoint_auth_method == "client_secret_basic" + assert cfg.authorization_endpoint is None + assert cfg.redirect_uri is None + assert cfg.leeway == 60 + assert cfg.extra_token_params is None + + +def test_config_all_fields() -> None: + """All fields are stored correctly when provided.""" + cfg = AuthlibAdapterConfig( + token_endpoint="https://t.example.com/token", + client_id="cid", + client_secret="s3cr3t", + scopes=["read", "write"], + token_endpoint_auth_method="client_secret_post", + authorization_endpoint="https://t.example.com/authorize", + redirect_uri="https://app.example.com/callback", + leeway=30, + extra_token_params={"audience": "https://api.example.com"}, + ) + assert cfg.client_secret == "s3cr3t" + assert cfg.scopes == ["read", "write"] + assert cfg.token_endpoint_auth_method == "client_secret_post" + assert cfg.authorization_endpoint == "https://t.example.com/authorize" + assert cfg.redirect_uri == "https://app.example.com/callback" + assert cfg.leeway == 30 + assert cfg.extra_token_params == {"audience": "https://api.example.com"} + + +# --------------------------------------------------------------------------- +# AuthlibOAuthAdapter construction +# --------------------------------------------------------------------------- + + +def test_adapter_construction_scope_joined() -> None: + """Scopes list is joined as a space-separated string for Authlib.""" + cfg = _make_config(scopes=["read", "write", "admin"]) + adapter = AuthlibOAuthAdapter(config=cfg, storage=_InMemoryStorage()) + assert adapter._client.scope == "read write admin" + + +def test_adapter_construction_no_scope() -> None: + """No scope param produces None scope on the Authlib client.""" + adapter = _make_adapter() + assert adapter._client.scope is None + + +def test_adapter_exported_from_package() -> None: + """AuthlibOAuthAdapter and AuthlibAdapterConfig are importable from the package root.""" + from mcp.client.auth import AuthlibAdapterConfig as Cfg + from mcp.client.auth import AuthlibOAuthAdapter as Adp + + assert Adp is AuthlibOAuthAdapter + assert Cfg is AuthlibAdapterConfig + + +# --------------------------------------------------------------------------- +# _initialize +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_initialize_loads_stored_token() -> None: + """Stored OAuthToken is converted to an Authlib token dict on init.""" + stored = OAuthToken( + access_token="at-123", + token_type="Bearer", + expires_in=3600, + scope="read", + refresh_token="rt-456", + ) + adapter = AuthlibOAuthAdapter(config=_make_config(), storage=_InMemoryStorage(stored)) + await adapter._initialize() + + tok = adapter._client.token + assert tok is not None + assert tok["access_token"] == "at-123" + assert tok["token_type"] == "Bearer" + assert tok["expires_in"] == 3600 + assert tok["scope"] == "read" + assert tok["refresh_token"] == "rt-456" + assert adapter._initialized is True + + +@pytest.mark.anyio +async def test_initialize_no_stored_token() -> None: + """With no persisted token, Authlib client token stays None.""" + adapter = AuthlibOAuthAdapter(config=_make_config(), storage=_InMemoryStorage(None)) + await adapter._initialize() + + assert adapter._client.token is None + assert adapter._initialized is True + + +@pytest.mark.anyio +async def test_initialize_token_without_optional_fields() -> None: + """Token with no refresh_token, scope, or expires_in loads correctly.""" + stored = OAuthToken(access_token="at-only", token_type="Bearer") + adapter = AuthlibOAuthAdapter(config=_make_config(), storage=_InMemoryStorage(stored)) + await adapter._initialize() + + tok = adapter._client.token + assert tok is not None + assert tok["access_token"] == "at-only" + assert "refresh_token" not in tok + assert "scope" not in tok + assert "expires_in" not in tok + + +# --------------------------------------------------------------------------- +# _on_token_update (storage callback) +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_on_token_update_persists_full_token() -> None: + """_on_token_update stores all fields via TokenStorage.set_tokens.""" + storage = _InMemoryStorage() + adapter = AuthlibOAuthAdapter(config=_make_config(), storage=storage) + + await adapter._on_token_update( + { + "access_token": "new-at", + "token_type": "bearer", + "expires_in": 1800, + "scope": "read write", + "refresh_token": "new-rt", + } + ) + + saved = await storage.get_tokens() + assert saved is not None + assert saved.access_token == "new-at" + assert saved.token_type == "Bearer" # normalized + assert saved.expires_in == 1800 + assert saved.scope == "read write" + assert saved.refresh_token == "new-rt" + + +@pytest.mark.anyio +async def test_on_token_update_missing_optional_fields() -> None: + """_on_token_update handles token dict without refresh_token / expires_in.""" + storage = _InMemoryStorage() + adapter = AuthlibOAuthAdapter(config=_make_config(), storage=storage) + + await adapter._on_token_update({"access_token": "bare-at"}) + + saved = await storage.get_tokens() + assert saved is not None + assert saved.access_token == "bare-at" + assert saved.refresh_token is None + assert saved.expires_in is None + + +# --------------------------------------------------------------------------- +# _fetch_client_credentials_token +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_fetch_client_credentials_calls_fetch_token() -> None: + """_fetch_client_credentials_token calls Authlib fetch_token with correct params.""" + adapter = _make_adapter() + adapter._initialized = True + + fake_token: dict[str, Any] = {"access_token": "cc-token", "token_type": "Bearer"} + + with patch.object(adapter._client, "fetch_token", new=AsyncMock(return_value=fake_token)): + adapter._client.token = fake_token # simulate Authlib setting it + with patch.object(adapter, "_on_token_update", new=AsyncMock()) as mock_update: + await adapter._fetch_client_credentials_token() + + mock_update.assert_awaited_once() + + +@pytest.mark.anyio +async def test_fetch_client_credentials_with_extra_params() -> None: + """extra_token_params are forwarded to fetch_token.""" + cfg = _make_config(extra_token_params={"audience": "https://api.example.com"}) + adapter = AuthlibOAuthAdapter(config=cfg, storage=_InMemoryStorage()) + adapter._initialized = True + adapter._client.token = None + + with patch.object(adapter._client, "fetch_token", new=AsyncMock()) as mock_ft: + await adapter._fetch_client_credentials_token() + + mock_ft.assert_awaited_once() + _, kwargs = mock_ft.call_args + assert kwargs.get("audience") == "https://api.example.com" + assert kwargs.get("grant_type") == "client_credentials" + + +@pytest.mark.anyio +async def test_fetch_client_credentials_no_extra_params() -> None: + """Without extra_token_params only grant_type is passed.""" + adapter = _make_adapter() + adapter._client.token = None + + with patch.object(adapter._client, "fetch_token", new=AsyncMock()) as mock_ft: + await adapter._fetch_client_credentials_token() + + _, kwargs = mock_ft.call_args + assert kwargs.get("grant_type") == "client_credentials" + assert "audience" not in kwargs + + +# --------------------------------------------------------------------------- +# _perform_authorization_code_flow — validation branches +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_auth_code_flow_missing_authorization_endpoint_raises() -> None: + """OAuthFlowError when authorization_endpoint is not set.""" + adapter = _make_adapter() + with pytest.raises(OAuthFlowError, match="authorization_endpoint"): + await adapter._perform_authorization_code_flow() + + +@pytest.mark.anyio +async def test_auth_code_flow_missing_redirect_uri_raises() -> None: + """OAuthFlowError when redirect_uri is not set.""" + cfg = _make_config(authorization_endpoint="https://auth.example.com/authorize") + adapter = AuthlibOAuthAdapter(config=cfg, storage=_InMemoryStorage()) + with pytest.raises(OAuthFlowError, match="redirect_uri"): + await adapter._perform_authorization_code_flow() + + +@pytest.mark.anyio +async def test_auth_code_flow_missing_redirect_handler_raises() -> None: + """OAuthFlowError when redirect_handler is None.""" + cfg = _make_config( + authorization_endpoint="https://auth.example.com/authorize", + redirect_uri="https://app.example.com/cb", + ) + adapter = AuthlibOAuthAdapter(config=cfg, storage=_InMemoryStorage()) + with pytest.raises(OAuthFlowError, match="redirect_handler"): + await adapter._perform_authorization_code_flow() + + +@pytest.mark.anyio +async def test_auth_code_flow_missing_callback_handler_raises() -> None: + """OAuthFlowError when callback_handler is None.""" + cfg = _make_config( + authorization_endpoint="https://auth.example.com/authorize", + redirect_uri="https://app.example.com/cb", + ) + adapter = AuthlibOAuthAdapter( + config=cfg, + storage=_InMemoryStorage(), + redirect_handler=AsyncMock(), + callback_handler=None, + ) + with pytest.raises(OAuthFlowError, match="callback_handler"): + await adapter._perform_authorization_code_flow() + + +@pytest.mark.anyio +async def test_auth_code_flow_state_mismatch_raises() -> None: + """OAuthFlowError when state returned by callback doesn't match.""" + cfg = _make_config( + authorization_endpoint="https://auth.example.com/authorize", + redirect_uri="https://app.example.com/cb", + ) + redirect_calls: list[str] = [] + + async def redirect(url: str) -> None: + redirect_calls.append(url) + + async def callback() -> tuple[str, str | None]: + return "some-code", "WRONG-STATE" + + adapter = AuthlibOAuthAdapter( + config=cfg, + storage=_InMemoryStorage(), + redirect_handler=redirect, + callback_handler=callback, + ) + + with patch.object( + adapter._client, + "create_authorization_url", + return_value=("https://auth.example.com/authorize?code_challenge=x", "correct-state"), + ): + with pytest.raises(OAuthFlowError, match="State mismatch"): + await adapter._perform_authorization_code_flow() + + +@pytest.mark.anyio +async def test_auth_code_flow_none_state_raises() -> None: + """OAuthFlowError when callback returns None as state (covers the is-None branch of the or-guard).""" + cfg = _make_config( + authorization_endpoint="https://auth.example.com/authorize", + redirect_uri="https://app.example.com/cb", + ) + + async def redirect(_url: str) -> None: + pass + + async def callback() -> tuple[str, str | None]: + return "some-code", None # state is None → first branch of the `or` fires + + adapter = AuthlibOAuthAdapter( + config=cfg, + storage=_InMemoryStorage(), + redirect_handler=redirect, + callback_handler=callback, + ) + + with pytest.raises(OAuthFlowError, match="State mismatch"): + await adapter._perform_authorization_code_flow() + + +@pytest.mark.anyio +async def test_auth_code_flow_no_token_after_fetch() -> None: + """When fetch_token leaves client.token as None, _on_token_update is NOT called.""" + cfg = _make_config( + authorization_endpoint="https://auth.example.com/authorize", + redirect_uri="https://app.example.com/cb", + ) + captured_state: list[str] = [] + + async def redirect(url: str) -> None: + from urllib.parse import parse_qs, urlparse + + qs = parse_qs(urlparse(url).query) + captured_state.extend(qs.get("state", [])) + + async def callback() -> tuple[str, str | None]: + return "auth-code-xyz", captured_state[0] if captured_state else None + + adapter = AuthlibOAuthAdapter( + config=cfg, + storage=_InMemoryStorage(), + redirect_handler=redirect, + callback_handler=callback, + ) + adapter._client.token = None # ensure no pre-existing token + + with ( + patch.object(adapter._client, "fetch_token", new=AsyncMock()), # fetch_token does NOT set token + patch.object(adapter, "_on_token_update", new=AsyncMock()) as mock_update, + ): + # Token remains None after fetch_token; _on_token_update must NOT be called + await adapter._perform_authorization_code_flow() + + mock_update.assert_not_awaited() + + +@pytest.mark.anyio +async def test_auth_code_flow_empty_code_raises() -> None: + """OAuthFlowError when callback returns empty authorization code.""" + cfg = _make_config( + authorization_endpoint="https://auth.example.com/authorize", + redirect_uri="https://app.example.com/cb", + ) + captured_state: list[str] = [] + + async def redirect(url: str) -> None: + # URL contains state= so we can extract it + from urllib.parse import parse_qs, urlparse + + qs = parse_qs(urlparse(url).query) + captured_state.extend(qs.get("state", [])) + + async def callback() -> tuple[str, str | None]: + return "", captured_state[0] if captured_state else None + + adapter = AuthlibOAuthAdapter( + config=cfg, + storage=_InMemoryStorage(), + redirect_handler=redirect, + callback_handler=callback, + ) + + with pytest.raises(OAuthFlowError, match="No authorization code"): + await adapter._perform_authorization_code_flow() + + +@pytest.mark.anyio +async def test_auth_code_flow_success() -> None: + """Happy path: redirect called, callback returns code+state, fetch_token invoked.""" + cfg = _make_config( + authorization_endpoint="https://auth.example.com/authorize", + redirect_uri="https://app.example.com/cb", + ) + captured_state: list[str] = [] + + async def redirect(url: str) -> None: + from urllib.parse import parse_qs, urlparse + + qs = parse_qs(urlparse(url).query) + captured_state.extend(qs.get("state", [])) + + async def callback() -> tuple[str, str | None]: + return "auth-code-xyz", captured_state[0] if captured_state else None + + storage = _InMemoryStorage() + adapter = AuthlibOAuthAdapter( + config=cfg, + storage=storage, + redirect_handler=redirect, + callback_handler=callback, + ) + + fake_token: dict[str, Any] = {"access_token": "code-at", "token_type": "Bearer"} + with patch.object(adapter._client, "fetch_token", new=AsyncMock()) as mock_ft: + adapter._client.token = fake_token + await adapter._perform_authorization_code_flow() + + mock_ft.assert_awaited_once() + _, kwargs = mock_ft.call_args + assert kwargs["grant_type"] == "authorization_code" + assert kwargs["code"] == "auth-code-xyz" + assert kwargs["redirect_uri"] == "https://app.example.com/cb" + assert "code_verifier" in kwargs + + saved = await storage.get_tokens() + assert saved is not None + assert saved.access_token == "code-at" + + +# --------------------------------------------------------------------------- +# async_auth_flow +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_auth_flow_injects_bearer_when_token_valid() -> None: + """On first use with a valid stored token, Bearer header is injected.""" + stored = OAuthToken(access_token="valid-at", token_type="Bearer") + adapter = AuthlibOAuthAdapter(config=_make_config(), storage=_InMemoryStorage(stored)) + + request = httpx.Request("GET", "https://api.example.com/resource") + ok_response = _mock_response(200) + + with patch.object(adapter._client, "ensure_active_token", new=AsyncMock()): + flow = adapter.async_auth_flow(request) + sent = await flow.__anext__() + assert sent.headers.get("Authorization") == "Bearer valid-at" + with pytest.raises(StopAsyncIteration): + await flow.asend(ok_response) + + +@pytest.mark.anyio +async def test_auth_flow_no_bearer_when_no_token() -> None: + """When storage is empty and no 401, no Authorization header is added.""" + adapter = _make_adapter() + + request = httpx.Request("GET", "https://api.example.com/resource") + ok_response = _mock_response(200) + + with patch.object(adapter._client, "ensure_active_token", new=AsyncMock()): + flow = adapter.async_auth_flow(request) + sent = await flow.__anext__() + assert "Authorization" not in sent.headers + with pytest.raises(StopAsyncIteration): + await flow.asend(ok_response) + + +@pytest.mark.anyio +async def test_auth_flow_client_credentials_on_401() -> None: + """On 401, client_credentials token is acquired and request is retried.""" + adapter = _make_adapter() # no authorization_endpoint → client_credentials + + request = httpx.Request("GET", "https://api.example.com/resource") + response_401 = _mock_response(401) + response_200 = _mock_response(200) + + new_token: dict[str, Any] = {"access_token": "fresh-at", "token_type": "Bearer"} + + async def fake_fetch(url: str, **kwargs: Any) -> dict[str, Any]: + adapter._client.token = new_token + return new_token + + with ( + patch.object(adapter._client, "ensure_active_token", new=AsyncMock()), + patch.object(adapter._client, "fetch_token", new=AsyncMock(side_effect=fake_fetch)), + patch.object(adapter, "_on_token_update", new=AsyncMock()), + ): + flow = adapter.async_auth_flow(request) + first_request = await flow.__anext__() + assert "Authorization" not in first_request.headers + + second_request = await flow.asend(response_401) + assert second_request.headers.get("Authorization") == "Bearer fresh-at" + + with pytest.raises(StopAsyncIteration): + await flow.asend(response_200) + + +@pytest.mark.anyio +async def test_auth_flow_authorization_code_on_401() -> None: + """On 401, authorization_code flow is triggered when endpoint is configured.""" + cfg = _make_config( + authorization_endpoint="https://auth.example.com/authorize", + redirect_uri="https://app.example.com/cb", + ) + captured_state: list[str] = [] + + async def redirect(url: str) -> None: + from urllib.parse import parse_qs, urlparse + + qs = parse_qs(urlparse(url).query) + captured_state.extend(qs.get("state", [])) + + async def callback() -> tuple[str, str | None]: + return "auth-code", captured_state[0] if captured_state else None + + storage = _InMemoryStorage() + adapter = AuthlibOAuthAdapter( + config=cfg, + storage=storage, + redirect_handler=redirect, + callback_handler=callback, + ) + + request = httpx.Request("GET", "https://api.example.com/resource") + response_401 = _mock_response(401) + response_200 = _mock_response(200) + + new_token: dict[str, Any] = {"access_token": "ac-token", "token_type": "Bearer"} + + async def fake_fetch(url: str, **kwargs: Any) -> dict[str, Any]: + adapter._client.token = new_token + return new_token + + with ( + patch.object(adapter._client, "ensure_active_token", new=AsyncMock()), + patch.object(adapter._client, "fetch_token", new=AsyncMock(side_effect=fake_fetch)), + patch.object(adapter, "_on_token_update", new=AsyncMock()), + ): + flow = adapter.async_auth_flow(request) + await flow.__anext__() # initial request + second_request = await flow.asend(response_401) # triggers auth_code flow + assert second_request.headers.get("Authorization") == "Bearer ac-token" + + with pytest.raises(StopAsyncIteration): + await flow.asend(response_200) + + +@pytest.mark.anyio +async def test_auth_flow_ensure_active_token_skipped_when_no_token() -> None: + """ensure_active_token is NOT called when the client has no token yet.""" + adapter = _make_adapter() + adapter._initialized = True + adapter._client.token = None + + request = httpx.Request("GET", "https://api.example.com/resource") + ok_response = _mock_response(200) + + with patch.object(adapter._client, "ensure_active_token", new=AsyncMock()) as mock_eat: + flow = adapter.async_auth_flow(request) + await flow.__anext__() + mock_eat.assert_not_awaited() + with pytest.raises(StopAsyncIteration): + await flow.asend(ok_response) + + +@pytest.mark.anyio +async def test_auth_flow_ensure_active_token_called_when_token_present() -> None: + """ensure_active_token IS called when a token already exists.""" + stored = OAuthToken(access_token="existing-at", token_type="Bearer") + adapter = AuthlibOAuthAdapter(config=_make_config(), storage=_InMemoryStorage(stored)) + + request = httpx.Request("GET", "https://api.example.com/resource") + ok_response = _mock_response(200) + + with patch.object(adapter._client, "ensure_active_token", new=AsyncMock()) as mock_eat: + flow = adapter.async_auth_flow(request) + await flow.__anext__() + mock_eat.assert_awaited_once() + with pytest.raises(StopAsyncIteration): + await flow.asend(ok_response) + + +# --------------------------------------------------------------------------- +# _inject_bearer edge cases +# --------------------------------------------------------------------------- + + +def test_inject_bearer_adds_header() -> None: + """_inject_bearer adds Authorization header when token is set.""" + adapter = _make_adapter() + adapter._client.token = {"access_token": "tok", "token_type": "Bearer"} + request = httpx.Request("GET", "https://api.example.com/") + adapter._inject_bearer(request) + assert request.headers["Authorization"] == "Bearer tok" + + +def test_inject_bearer_skips_when_no_access_token() -> None: + """_inject_bearer does not add header when access_token is missing.""" + adapter = _make_adapter() + adapter._client.token = {} # empty dict — no access_token key + request = httpx.Request("GET", "https://api.example.com/") + adapter._inject_bearer(request) + assert "Authorization" not in request.headers + + +def test_inject_bearer_skips_when_token_is_none() -> None: + """_inject_bearer does not add header when Authlib token is None.""" + adapter = _make_adapter() + adapter._client.token = None + request = httpx.Request("GET", "https://api.example.com/") + adapter._inject_bearer(request) + assert "Authorization" not in request.headers diff --git a/uv.lock b/uv.lock index d01d510f1..2da69f941 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ -71,6 +71,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, ] +[[package]] +name = "authlib" +version = "1.6.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6b/6c/c88eac87468c607f88bc24df1f3b31445ee6fc9ba123b09e666adf687cd9/authlib-1.6.8.tar.gz", hash = "sha256:41ae180a17cf672bc784e4a518e5c82687f1fe1e98b0cafaeda80c8e4ab2d1cb", size = 165074, upload-time = "2026-02-14T04:02:17.941Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/73/f7084bf12755113cd535ae586782ff3a6e710bfbe6a0d13d1c2f81ffbbfa/authlib-1.6.8-py2.py3-none-any.whl", hash = "sha256:97286fd7a15e6cfefc32771c8ef9c54f0ed58028f1322de6a2a7c969c3817888", size = 244116, upload-time = "2026-02-14T04:02:15.579Z" }, +] + [[package]] name = "babel" version = "2.17.0" @@ -784,6 +796,7 @@ name = "mcp" source = { editable = "." } dependencies = [ { name = "anyio" }, + { name = "authlib" }, { name = "httpx" }, { name = "httpx-sse" }, { name = "jsonschema" }, @@ -838,6 +851,7 @@ docs = [ [package.metadata] requires-dist = [ { name = "anyio", specifier = ">=4.5" }, + { name = "authlib", specifier = ">=1.4.0" }, { name = "httpx", specifier = ">=0.27.1" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" },