From c8d207e89bda5a62baa4c844b9409a35a540bde3 Mon Sep 17 00:00:00 2001 From: Rohit Singhal Date: Thu, 18 Jun 2026 00:38:24 +0100 Subject: [PATCH] Honor token type in bearer auth policies Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../core/pipeline/policies/_authentication.py | 19 +++++++++------ .../policies/_authentication_async.py | 12 ++++++---- .../async_tests/test_authentication_async.py | 24 +++++++++++++++++++ .../azure-core/tests/test_authentication.py | 16 +++++++++++++ 4 files changed, 60 insertions(+), 11 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index f702fdbb5311..b418182289f7 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -106,13 +106,14 @@ def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> No self._refresh_jitter = 0 @staticmethod - def _update_headers(headers: MutableMapping[str, str], token: str) -> None: - """Updates the Authorization header with the bearer token. + def _update_headers(headers: MutableMapping[str, str], token: str, token_type: str = "Bearer") -> None: + """Updates the Authorization header with the access token. :param MutableMapping[str, str] headers: The HTTP Request headers :param str token: The OAuth token. + :param str token_type: The OAuth token type. """ - headers["Authorization"] = "Bearer {}".format(token) + headers["Authorization"] = "{} {}".format(token_type, token) @property def _need_new_token(self) -> bool: @@ -165,8 +166,10 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: if self._token is None or self._need_new_token: self._request_token(*self._scopes) - bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token - self._update_headers(request.http_request.headers, bearer_token) + token = cast(Union["AccessToken", "AccessTokenInfo"], self._token) + self._update_headers( + request.http_request.headers, token.token, getattr(token, "token_type", "Bearer") + ) def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. @@ -178,8 +181,10 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: :param str scopes: required scopes of authentication """ self._request_token(*scopes, **kwargs) - bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token - self._update_headers(request.http_request.headers, bearer_token) + token = cast(Union["AccessToken", "AccessTokenInfo"], self._token) + self._update_headers( + request.http_request.headers, token.token, getattr(token, "token_type", "Bearer") + ) def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]: """Authorize request with a bearer token and send it to the next policy diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index 9d238756b902..9cb06958058c 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -74,8 +74,10 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None: # double check because another coroutine may have acquired a token while we waited to acquire the lock if self._token is None or self._need_new_token(): await self._request_token(*self._scopes) - bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token - request.http_request.headers["Authorization"] = "Bearer " + bearer_token + token = cast(Union[AccessToken, AccessTokenInfo], self._token) + request.http_request.headers["Authorization"] = "{} {}".format( + getattr(token, "token_type", "Bearer"), token.token + ) async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None: """Acquire a token from the credential and authorize the request with it. @@ -89,8 +91,10 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc async with self._lock: await self._request_token(*scopes, **kwargs) - bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token - request.http_request.headers["Authorization"] = "Bearer " + bearer_token + token = cast(Union[AccessToken, AccessTokenInfo], self._token) + request.http_request.headers["Authorization"] = "{} {}".format( + getattr(token, "token_type", "Bearer"), token.token + ) async def send( self, request: PipelineRequest[HTTPRequestType] diff --git a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py index 0be2027fe1fa..13703c189445 100644 --- a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py @@ -119,6 +119,30 @@ async def get_token_info(*_, **__): assert get_token_calls == 0 +@pytest.mark.asyncio +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_adds_header_access_token_info_token_type(http_request): + """The bearer token policy should use the token type from AccessTokenInfo when available.""" + expected_token = AccessTokenInfo("expected_token", 2524608000, token_type="pop") + + async def verify_authorization_header(request): + assert request.http_request.headers["Authorization"] == "pop {}".format(expected_token.token) + return Mock() + + class MockCredential(AsyncSupportsTokenInfo): + async def get_token(self, *_, **__): + return AccessToken("other_token", 2524608000) + + async def get_token_info(self, *_, **__): + return expected_token + + fake_credential = MockCredential() + policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] + pipeline = AsyncPipeline(transport=AsyncMock(), policies=policies) + + await pipeline.run(http_request("GET", "https://spam.eggs"), context=None) + + @pytest.mark.asyncio @pytest.mark.parametrize("http_request", HTTP_REQUESTS) async def test_bearer_policy_authorize_request_access_token_info(http_request): diff --git a/sdk/core/azure-core/tests/test_authentication.py b/sdk/core/azure-core/tests/test_authentication.py index be71bd1e0764..b22834c656ce 100644 --- a/sdk/core/azure-core/tests/test_authentication.py +++ b/sdk/core/azure-core/tests/test_authentication.py @@ -111,6 +111,22 @@ def verify_authorization_header(request): assert fake_credential.get_token.call_count == 0 +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_adds_header_access_token_info_token_type(http_request): + """The bearer token policy should use the token type from AccessTokenInfo when available.""" + expected_token = AccessTokenInfo("expected_token", 2524608000, token_type="pop") + + def verify_authorization_header(request): + assert request.http_request.headers["Authorization"] == "pop {}".format(expected_token.token) + return Mock() + + fake_credential = Mock(get_token=Mock(), get_token_info=Mock(return_value=expected_token)) + policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] + + pipeline = Pipeline(transport=Mock(), policies=policies) + pipeline.run(http_request("GET", "https://spam.eggs")) + + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) def test_bearer_policy_authorize_request_access_token_info(http_request): """The authorize_request method should add a header containing a token from its credential"""