Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@

### Bugs Fixed

- Fixed an issue with certain credentials not bypassing the token cache when claims are provided in `get_token` or `get_token_info` calls. ([#44552](https://github.com/Azure/azure-sdk-for-python/pull/44552))
- Fixed an issue with certain credentials not bypassing the token cache when claims are provided in `get_token` or `get_token_info` calls. ([#44552](https://github.com/Azure/azure-sdk-for-python/pull/44552)) ([#44815](https://github.com/Azure/azure-sdk-for-python/pull/44815))
- Fixed an issue where an unhelpful TypeError was raised during Entra ID token requests that returned empty responses. Now, a ClientAuthenticationError is raised with the full response for better troubleshooting. ([#44258](https://github.com/Azure/azure-sdk-for-python/pull/44258))

### Other Changes

- Bumped minimum dependency on `msal` to `>=1.31.0`.

## 1.26.0b1 (2025-11-07)

### Features Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
raise CredentialUnavailableError(error_message) from ex

try:
token_info = super()._request_token(*scopes)
token_info = super()._request_token(*scopes, **kwargs)
except CredentialUnavailableError:
# Response is not json, skip the IMDS credential
raise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def get_token(
:param str scopes: desired scope for the access token. This credential allows only one scope per request.
For more information about scopes, see
https://learn.microsoft.com/entra/identity-platform/scopes-oidc.

:keyword str claims: not used by this credential; any value provided will be ignored.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure.
:keyword str tenant_id: not used by this credential; any value provided will be ignored.

:return: An access token with the desired scopes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] =
def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
# casting because mypy can't determine that these methods are called
# only by get_token, which raises when self._client is None
return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes)
return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes, **kwargs)

def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
return cast(ManagedIdentityClient, self._client).request_token(*scopes, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ def _process_response(self, response: PipelineResponse, request_time: int) -> Ac

return token

def get_cached_token(self, *scopes: str) -> Optional[AccessTokenInfo]:
def get_cached_token(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
# Do not return a cached token if claims are provided.
if kwargs.get("claims") is not None:
return None

resource = _scopes_to_resource(*scopes)
now = time.time()
for token in self._cache.search(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def get_unavailable_message(self, desc: str = "") -> str:
def close(self) -> None:
self.__exit__()

def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: # pylint:disable=unused-argument
def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
if not scopes:
raise ValueError('"get_token" requires at least one scope')
resource = _scopes_to_resource(*scopes)
result = self._msal_client.acquire_token_for_client(resource=resource)
result = self._msal_client.acquire_token_for_client(resource=resource, claims_challenge=kwargs.get("claims"))
now = int(time.time())
if result and "access_token" in result and "expires_in" in result:
refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def close(self) -> None:
await self._client.close()

async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
return self._client.get_cached_token(*scopes)
return self._client.get_cached_token(*scopes, **kwargs)

async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ async def get_token(
:param str scopes: desired scope for the access token. This credential allows only one scope per request.
For more information about scopes, see
https://learn.microsoft.com/entra/identity-platform/scopes-oidc.
:keyword str claims: not used by this credential; any value provided will be ignored.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure.
:keyword str tenant_id: not used by this credential; any value provided will be ignored.

:return: An access token with the desired scopes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptio
async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]:
# casting because mypy can't determine that these methods are called
# only by get_token, which raises when self._client is None
return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes)
return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes, **kwargs)

async def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo:
return await cast(AsyncManagedIdentityClient, self._client).request_token(*scopes, **kwargs)
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
dependencies = [
"azure-core>=1.31.0",
"cryptography>=2.5",
"msal>=1.30.0",
"msal>=1.31.0",
"msal-extensions>=1.2.0",
"typing-extensions>=4.0.0",
]
Expand Down
45 changes: 45 additions & 0 deletions sdk/identity/azure-identity/tests/test_managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@
{}, # IMDS
{EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # Azure ML
)
# Environments where MSAL-based managed identity clients are used
MSAL_MANAGED_IDENTITY_ENVIRON = (
{EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IDENTITY_HEADER: "..."}, # App Service
{ # Service Fabric
EnvironmentVariables.IDENTITY_ENDPOINT: "...",
EnvironmentVariables.IDENTITY_HEADER: "...",
EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: "...",
},
{EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc
{EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # Azure ML
{}, # IMDS
)


@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS)
Expand Down Expand Up @@ -1168,3 +1180,36 @@ def test_log(caplog):
with mock.patch.dict("os.environ", mock_environ, clear=True):
ManagedIdentityCredential(client_id="foo")
assert "workload identity with client_id: foo" in caplog.text


@pytest.mark.parametrize("environ,get_token_method", product(MSAL_MANAGED_IDENTITY_ENVIRON, GET_TOKEN_METHODS))
def test_claims_propagated(environ, get_token_method):
"""Test that claims passed are forwarded to MSAL's acquire_token_for_client."""
from azure.identity import ManagedIdentityCredential

expected_claims = '{"access_token": {"xms_cc": {"values": ["cp1"]}}}'
expected_token = "test_token"
expires_in = 3600

mock_msal_client = mock.Mock()
mock_msal_client.acquire_token_for_client.return_value = {
"access_token": expected_token,
"expires_in": expires_in,
"token_type": "Bearer",
}

with mock.patch("msal.ManagedIdentityClient", return_value=mock_msal_client):
with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
credential = ManagedIdentityCredential()
kwargs = {"claims": expected_claims}
if get_token_method == "get_token_info":
kwargs = {"options": {"claims": expected_claims}}
token = getattr(credential, get_token_method)("scope", **kwargs)

# Verify the token was returned correctly
assert token.token == expected_token

# Verify acquire_token_for_client was called with the claims
mock_msal_client.acquire_token_for_client.assert_called_once()
call_kwargs = mock_msal_client.acquire_token_for_client.call_args
assert call_kwargs.kwargs.get("claims_challenge") == expected_claims
Original file line number Diff line number Diff line change
Expand Up @@ -1439,3 +1439,81 @@ def test_log(caplog):
with mock.patch.dict("os.environ", mock_environ, clear=True):
ManagedIdentityCredential(client_id="foo")
assert "workload identity with client_id: foo" in caplog.text


@pytest.mark.asyncio
@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS)
async def test_claims_force_token_refresh(environ):
"""When claims are provided, the token cache should be bypassed and claims passed to the token request."""
expected_token = "****"
scope = "scope"
expected_claims = '{"access_token": {"xms_cc": {"values": ["cp1"]}}}'
now = int(time.time())

call_count = 0

async def mock_send(request, **kwargs):
nonlocal call_count
call_count += 1

assert "claims" not in kwargs
return mock_response(
json_payload={
"access_token": expected_token,
"expires_in": 3600,
"expires_on": now + 3600,
"resource": scope,
"token_type": "Bearer",
}
)

with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
credential = ManagedIdentityCredential(transport=mock.Mock(send=mock_send))

# First call without claims
token = await credential.get_token(scope)
assert token.token == expected_token
first_call_count = call_count

# Second call with claims - should make a new request (not use cache)
token = await credential.get_token_info(scope, options={"claims": expected_claims})
assert token.token == expected_token

# Verify a new token request was made when claims were provided
assert call_count > first_call_count, "Expected a new token request when claims were provided"


@pytest.mark.asyncio
async def test_access_tokens_cached():
"""Verify that cached tokens are used on subsequent token requests for the same scope."""
expected_token = "****"
scope = "scope"
now = int(time.time())

call_count = 0

async def mock_send(request, **kwargs):
nonlocal call_count
call_count += 1
return mock_response(
json_payload={
"access_token": expected_token,
"expires_in": 3600,
"expires_on": now + 3600,
"resource": scope,
"token_type": "Bearer",
}
)

transport = mock.Mock(send=mock_send)
credential = ManagedIdentityCredential(transport=transport)

token = await credential.get_token(scope)
assert token.token == expected_token
first_call_count = call_count

token_info = await credential.get_token_info(scope)
assert token_info.token == expected_token

# Verify no new token request was made when getting token info
assert call_count == first_call_count
Loading