From 0b0e72a50071f8b19fd7bedc400fbef9f1012b1f Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Thu, 22 Jan 2026 02:19:38 +0000 Subject: [PATCH] [Identity] Bypass cache when claims provided in Managed Identity Ensure the claims are propagated to the MSAL ManagedIdentityClient or, for those credentials not using MSAL, the cache is bypassed when claims are provided so that a new token request is made. Signed-off-by: Paul Van Eck --- sdk/identity/azure-identity/CHANGELOG.md | 4 +- .../azure/identity/_credentials/imds.py | 2 +- .../identity/_credentials/managed_identity.py | 4 +- .../_internal/managed_identity_base.py | 2 +- .../_internal/managed_identity_client.py | 6 +- .../_internal/msal_managed_identity_client.py | 4 +- .../azure/identity/aio/_credentials/imds.py | 2 +- .../aio/_credentials/managed_identity.py | 3 +- .../aio/_internal/managed_identity_base.py | 2 +- sdk/identity/azure-identity/pyproject.toml | 2 +- .../tests/test_managed_identity.py | 45 +++++++++++ .../tests/test_managed_identity_async.py | 78 +++++++++++++++++++ 12 files changed, 142 insertions(+), 12 deletions(-) diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 50dc8f42f6ee..a164938150c9 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -12,11 +12,13 @@ ### Bugs Fixed -- Fixed an issue with certain credentials not bypassing the token cache when claims are provided in `get_token` or `get_token_info` calls. ([#44552](https://github.com/Azure/azure-sdk-for-python/pull/44552)) +- Fixed an issue with certain credentials not bypassing the token cache when claims are provided in `get_token` or `get_token_info` calls. ([#44552](https://github.com/Azure/azure-sdk-for-python/pull/44552)) ([#44815](https://github.com/Azure/azure-sdk-for-python/pull/44815)) - Fixed an issue where an unhelpful TypeError was raised during Entra ID token requests that returned empty responses. Now, a ClientAuthenticationError is raised with the full response for better troubleshooting. ([#44258](https://github.com/Azure/azure-sdk-for-python/pull/44258)) ### Other Changes +- Bumped minimum dependency on `msal` to `>=1.31.0`. + ## 1.26.0b1 (2025-11-07) ### Features Added diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py index 3a5fe6addf63..b1e82bb88052 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py @@ -124,7 +124,7 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: raise CredentialUnavailableError(error_message) from ex try: - token_info = super()._request_token(*scopes) + token_info = super()._request_token(*scopes, **kwargs) except CredentialUnavailableError: # Response is not json, skip the IMDS credential raise diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index 991415c09ad7..ea314648fdc9 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -174,8 +174,8 @@ def get_token( :param str scopes: desired scope for the access token. This credential allows only one scope per request. For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. - - :keyword str claims: not used by this credential; any value provided will be ignored. + :keyword str claims: additional claims required in the token, such as those returned in a resource provider's + claims challenge following an authorization failure. :keyword str tenant_id: not used by this credential; any value provided will be ignored. :return: An access token with the desired scopes. diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py index 949ac14a844f..42b9af5f4b63 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py @@ -55,7 +55,7 @@ def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: # casting because mypy can't determine that these methods are called # only by get_token, which raises when self._client is None - return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes) + return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes, **kwargs) def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: return cast(ManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py index 010c05b43db0..44b318e6bb19 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py @@ -90,7 +90,11 @@ def _process_response(self, response: PipelineResponse, request_time: int) -> Ac return token - def get_cached_token(self, *scopes: str) -> Optional[AccessTokenInfo]: + def get_cached_token(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: + # Do not return a cached token if claims are provided. + if kwargs.get("claims") is not None: + return None + resource = _scopes_to_resource(*scopes) now = time.time() for token in self._cache.search(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource]): diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py index d9b3b317ed06..b17091c4141f 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py @@ -45,11 +45,11 @@ def get_unavailable_message(self, desc: str = "") -> str: def close(self) -> None: self.__exit__() - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: # pylint:disable=unused-argument + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: if not scopes: raise ValueError('"get_token" requires at least one scope') resource = _scopes_to_resource(*scopes) - result = self._msal_client.acquire_token_for_client(resource=resource) + result = self._msal_client.acquire_token_for_client(resource=resource, claims_challenge=kwargs.get("claims")) now = int(time.time()) if result and "access_token" in result and "expires_in" in result: refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py index 81743d44c229..0fb758c3026c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py @@ -65,7 +65,7 @@ async def close(self) -> None: await self._client.close() async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: - return self._client.get_cached_token(*scopes) + return self._client.get_cached_token(*scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index b664c9c3b8bb..f944f7261e54 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -143,7 +143,8 @@ async def get_token( :param str scopes: desired scope for the access token. This credential allows only one scope per request. For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. - :keyword str claims: not used by this credential; any value provided will be ignored. + :keyword str claims: additional claims required in the token, such as those returned in a resource provider's + claims challenge following an authorization failure. :keyword str tenant_id: not used by this credential; any value provided will be ignored. :return: An access token with the desired scopes. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py index 636fbbf9b2f7..e07403a1982c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py @@ -62,7 +62,7 @@ async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptio async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]: # casting because mypy can't determine that these methods are called # only by get_token, which raises when self._client is None - return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes) + return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: return await cast(AsyncManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/pyproject.toml b/sdk/identity/azure-identity/pyproject.toml index 409c49389483..557d2c15d1c7 100644 --- a/sdk/identity/azure-identity/pyproject.toml +++ b/sdk/identity/azure-identity/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ dependencies = [ "azure-core>=1.31.0", "cryptography>=2.5", - "msal>=1.30.0", + "msal>=1.31.0", "msal-extensions>=1.2.0", "typing-extensions>=4.0.0", ] diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py index d0ad26d1078a..beda28ec5de4 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity.py @@ -37,6 +37,18 @@ {}, # IMDS {EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # Azure ML ) +# Environments where MSAL-based managed identity clients are used +MSAL_MANAGED_IDENTITY_ENVIRON = ( + {EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IDENTITY_HEADER: "..."}, # App Service + { # Service Fabric + EnvironmentVariables.IDENTITY_ENDPOINT: "...", + EnvironmentVariables.IDENTITY_HEADER: "...", + EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: "...", + }, + {EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc + {EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # Azure ML + {}, # IMDS +) @pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) @@ -1168,3 +1180,36 @@ def test_log(caplog): with mock.patch.dict("os.environ", mock_environ, clear=True): ManagedIdentityCredential(client_id="foo") assert "workload identity with client_id: foo" in caplog.text + + +@pytest.mark.parametrize("environ,get_token_method", product(MSAL_MANAGED_IDENTITY_ENVIRON, GET_TOKEN_METHODS)) +def test_claims_propagated(environ, get_token_method): + """Test that claims passed are forwarded to MSAL's acquire_token_for_client.""" + from azure.identity import ManagedIdentityCredential + + expected_claims = '{"access_token": {"xms_cc": {"values": ["cp1"]}}}' + expected_token = "test_token" + expires_in = 3600 + + mock_msal_client = mock.Mock() + mock_msal_client.acquire_token_for_client.return_value = { + "access_token": expected_token, + "expires_in": expires_in, + "token_type": "Bearer", + } + + with mock.patch("msal.ManagedIdentityClient", return_value=mock_msal_client): + with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True): + credential = ManagedIdentityCredential() + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": {"claims": expected_claims}} + token = getattr(credential, get_token_method)("scope", **kwargs) + + # Verify the token was returned correctly + assert token.token == expected_token + + # Verify acquire_token_for_client was called with the claims + mock_msal_client.acquire_token_for_client.assert_called_once() + call_kwargs = mock_msal_client.acquire_token_for_client.call_args + assert call_kwargs.kwargs.get("claims_challenge") == expected_claims diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index 97402716807f..838a8914c6e4 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -1439,3 +1439,81 @@ def test_log(caplog): with mock.patch.dict("os.environ", mock_environ, clear=True): ManagedIdentityCredential(client_id="foo") assert "workload identity with client_id: foo" in caplog.text + + +@pytest.mark.asyncio +@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) +async def test_claims_force_token_refresh(environ): + """When claims are provided, the token cache should be bypassed and claims passed to the token request.""" + expected_token = "****" + scope = "scope" + expected_claims = '{"access_token": {"xms_cc": {"values": ["cp1"]}}}' + now = int(time.time()) + + call_count = 0 + + async def mock_send(request, **kwargs): + nonlocal call_count + call_count += 1 + + assert "claims" not in kwargs + return mock_response( + json_payload={ + "access_token": expected_token, + "expires_in": 3600, + "expires_on": now + 3600, + "resource": scope, + "token_type": "Bearer", + } + ) + + with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True): + credential = ManagedIdentityCredential(transport=mock.Mock(send=mock_send)) + + # First call without claims + token = await credential.get_token(scope) + assert token.token == expected_token + first_call_count = call_count + + # Second call with claims - should make a new request (not use cache) + token = await credential.get_token_info(scope, options={"claims": expected_claims}) + assert token.token == expected_token + + # Verify a new token request was made when claims were provided + assert call_count > first_call_count, "Expected a new token request when claims were provided" + + +@pytest.mark.asyncio +async def test_access_tokens_cached(): + """Verify that cached tokens are used on subsequent token requests for the same scope.""" + expected_token = "****" + scope = "scope" + now = int(time.time()) + + call_count = 0 + + async def mock_send(request, **kwargs): + nonlocal call_count + call_count += 1 + return mock_response( + json_payload={ + "access_token": expected_token, + "expires_in": 3600, + "expires_on": now + 3600, + "resource": scope, + "token_type": "Bearer", + } + ) + + transport = mock.Mock(send=mock_send) + credential = ManagedIdentityCredential(transport=transport) + + token = await credential.get_token(scope) + assert token.token == expected_token + first_call_count = call_count + + token_info = await credential.get_token_info(scope) + assert token_info.token == expected_token + + # Verify no new token request was made when getting token info + assert call_count == first_call_count