Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,14 @@ def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> No
self._refresh_jitter = 0

@staticmethod
def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
"""Updates the Authorization header with the bearer token.
def _update_headers(headers: MutableMapping[str, str], token: str, token_type: str = "Bearer") -> None:
"""Updates the Authorization header with the access token.

:param MutableMapping[str, str] headers: The HTTP Request headers
:param str token: The OAuth token.
:param str token_type: The OAuth token type.
"""
headers["Authorization"] = "Bearer {}".format(token)
headers["Authorization"] = "{} {}".format(token_type, token)

@property
def _need_new_token(self) -> bool:
Expand Down Expand Up @@ -165,8 +166,10 @@ def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:

if self._token is None or self._need_new_token:
self._request_token(*self._scopes)
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
self._update_headers(request.http_request.headers, bearer_token)
token = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
self._update_headers(
request.http_request.headers, token.token, getattr(token, "token_type", "Bearer")
)

def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
"""Acquire a token from the credential and authorize the request with it.
Expand All @@ -178,8 +181,10 @@ def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes:
:param str scopes: required scopes of authentication
"""
self._request_token(*scopes, **kwargs)
bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token
self._update_headers(request.http_request.headers, bearer_token)
token = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
self._update_headers(
request.http_request.headers, token.token, getattr(token, "token_type", "Bearer")
)

def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
"""Authorize request with a bearer token and send it to the next policy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ async def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
# double check because another coroutine may have acquired a token while we waited to acquire the lock
if self._token is None or self._need_new_token():
await self._request_token(*self._scopes)
bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
request.http_request.headers["Authorization"] = "Bearer " + bearer_token
token = cast(Union[AccessToken, AccessTokenInfo], self._token)
request.http_request.headers["Authorization"] = "{} {}".format(
getattr(token, "token_type", "Bearer"), token.token
)
Comment on lines +78 to +80

async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
"""Acquire a token from the credential and authorize the request with it.
Expand All @@ -89,8 +91,10 @@ async def authorize_request(self, request: PipelineRequest[HTTPRequestType], *sc

async with self._lock:
await self._request_token(*scopes, **kwargs)
bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token
request.http_request.headers["Authorization"] = "Bearer " + bearer_token
token = cast(Union[AccessToken, AccessTokenInfo], self._token)
request.http_request.headers["Authorization"] = "{} {}".format(
getattr(token, "token_type", "Bearer"), token.token
)
Comment on lines +95 to +97

async def send(
self, request: PipelineRequest[HTTPRequestType]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,30 @@ async def get_token_info(*_, **__):
assert get_token_calls == 0


@pytest.mark.asyncio
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
async def test_bearer_policy_adds_header_access_token_info_token_type(http_request):
"""The bearer token policy should use the token type from AccessTokenInfo when available."""
expected_token = AccessTokenInfo("expected_token", 2524608000, token_type="pop")

async def verify_authorization_header(request):
assert request.http_request.headers["Authorization"] == "pop {}".format(expected_token.token)
return Mock()

class MockCredential(AsyncSupportsTokenInfo):
async def get_token(self, *_, **__):
return AccessToken("other_token", 2524608000)

async def get_token_info(self, *_, **__):
return expected_token

fake_credential = MockCredential()
policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)]
pipeline = AsyncPipeline(transport=AsyncMock(), policies=policies)

await pipeline.run(http_request("GET", "https://spam.eggs"), context=None)


@pytest.mark.asyncio
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
async def test_bearer_policy_authorize_request_access_token_info(http_request):
Expand Down
16 changes: 16 additions & 0 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,22 @@ def verify_authorization_header(request):
assert fake_credential.get_token.call_count == 0


@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
def test_bearer_policy_adds_header_access_token_info_token_type(http_request):
"""The bearer token policy should use the token type from AccessTokenInfo when available."""
expected_token = AccessTokenInfo("expected_token", 2524608000, token_type="pop")

def verify_authorization_header(request):
assert request.http_request.headers["Authorization"] == "pop {}".format(expected_token.token)
return Mock()

fake_credential = Mock(get_token=Mock(), get_token_info=Mock(return_value=expected_token))
policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)]

pipeline = Pipeline(transport=Mock(), policies=policies)
pipeline.run(http_request("GET", "https://spam.eggs"))


@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
def test_bearer_policy_authorize_request_access_token_info(http_request):
"""The authorize_request method should add a header containing a token from its credential"""
Expand Down
Loading