diff --git a/sdk/identity/azure-identity/CHANGELOG.md b/sdk/identity/azure-identity/CHANGELOG.md index 50dc8f42f6ee..a164938150c9 100644 --- a/sdk/identity/azure-identity/CHANGELOG.md +++ b/sdk/identity/azure-identity/CHANGELOG.md @@ -12,11 +12,13 @@ ### Bugs Fixed -- Fixed an issue with certain credentials not bypassing the token cache when claims are provided in `get_token` or `get_token_info` calls. ([#44552](https://github.com/Azure/azure-sdk-for-python/pull/44552)) +- Fixed an issue with certain credentials not bypassing the token cache when claims are provided in `get_token` or `get_token_info` calls. ([#44552](https://github.com/Azure/azure-sdk-for-python/pull/44552)) ([#44815](https://github.com/Azure/azure-sdk-for-python/pull/44815)) - Fixed an issue where an unhelpful TypeError was raised during Entra ID token requests that returned empty responses. Now, a ClientAuthenticationError is raised with the full response for better troubleshooting. ([#44258](https://github.com/Azure/azure-sdk-for-python/pull/44258)) ### Other Changes +- Bumped minimum dependency on `msal` to `>=1.31.0`. + ## 1.26.0b1 (2025-11-07) ### Features Added diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py index 3a5fe6addf63..b1e82bb88052 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/imds.py @@ -124,7 +124,7 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: raise CredentialUnavailableError(error_message) from ex try: - token_info = super()._request_token(*scopes) + token_info = super()._request_token(*scopes, **kwargs) except CredentialUnavailableError: # Response is not json, skip the IMDS credential raise diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index 991415c09ad7..ea314648fdc9 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -174,8 +174,8 @@ def get_token( :param str scopes: desired scope for the access token. This credential allows only one scope per request. For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. - - :keyword str claims: not used by this credential; any value provided will be ignored. + :keyword str claims: additional claims required in the token, such as those returned in a resource provider's + claims challenge following an authorization failure. :keyword str tenant_id: not used by this credential; any value provided will be ignored. :return: An access token with the desired scopes. diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py index 949ac14a844f..42b9af5f4b63 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_base.py @@ -55,7 +55,7 @@ def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: # casting because mypy can't determine that these methods are called # only by get_token, which raises when self._client is None - return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes) + return cast(ManagedIdentityClient, self._client).get_cached_token(*scopes, **kwargs) def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: return cast(ManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py index 010c05b43db0..44b318e6bb19 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/managed_identity_client.py @@ -90,7 +90,11 @@ def _process_response(self, response: PipelineResponse, request_time: int) -> Ac return token - def get_cached_token(self, *scopes: str) -> Optional[AccessTokenInfo]: + def get_cached_token(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: + # Do not return a cached token if claims are provided. + if kwargs.get("claims") is not None: + return None + resource = _scopes_to_resource(*scopes) now = time.time() for token in self._cache.search(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource]): diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py index d9b3b317ed06..b17091c4141f 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_managed_identity_client.py @@ -45,11 +45,11 @@ def get_unavailable_message(self, desc: str = "") -> str: def close(self) -> None: self.__exit__() - def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: # pylint:disable=unused-argument + def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: if not scopes: raise ValueError('"get_token" requires at least one scope') resource = _scopes_to_resource(*scopes) - result = self._msal_client.acquire_token_for_client(resource=resource) + result = self._msal_client.acquire_token_for_client(resource=resource, claims_challenge=kwargs.get("claims")) now = int(time.time()) if result and "access_token" in result and "expires_in" in result: refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py index 81743d44c229..0fb758c3026c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py @@ -65,7 +65,7 @@ async def close(self) -> None: await self._client.close() async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: - return self._client.get_cached_token(*scopes) + return self._client.get_cached_token(*scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index b664c9c3b8bb..f944f7261e54 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -143,7 +143,8 @@ async def get_token( :param str scopes: desired scope for the access token. This credential allows only one scope per request. For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc. - :keyword str claims: not used by this credential; any value provided will be ignored. + :keyword str claims: additional claims required in the token, such as those returned in a resource provider's + claims challenge following an authorization failure. :keyword str tenant_id: not used by this credential; any value provided will be ignored. :return: An access token with the desired scopes. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py index 636fbbf9b2f7..e07403a1982c 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py @@ -62,7 +62,7 @@ async def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptio async def _acquire_token_silently(self, *scopes: str, **kwargs) -> Optional[AccessTokenInfo]: # casting because mypy can't determine that these methods are called # only by get_token, which raises when self._client is None - return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes) + return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes, **kwargs) async def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: return await cast(AsyncManagedIdentityClient, self._client).request_token(*scopes, **kwargs) diff --git a/sdk/identity/azure-identity/pyproject.toml b/sdk/identity/azure-identity/pyproject.toml index 409c49389483..557d2c15d1c7 100644 --- a/sdk/identity/azure-identity/pyproject.toml +++ b/sdk/identity/azure-identity/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ dependencies = [ "azure-core>=1.31.0", "cryptography>=2.5", - "msal>=1.30.0", + "msal>=1.31.0", "msal-extensions>=1.2.0", "typing-extensions>=4.0.0", ] diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py index d0ad26d1078a..beda28ec5de4 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity.py @@ -37,6 +37,18 @@ {}, # IMDS {EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # Azure ML ) +# Environments where MSAL-based managed identity clients are used +MSAL_MANAGED_IDENTITY_ENVIRON = ( + {EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IDENTITY_HEADER: "..."}, # App Service + { # Service Fabric + EnvironmentVariables.IDENTITY_ENDPOINT: "...", + EnvironmentVariables.IDENTITY_HEADER: "...", + EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: "...", + }, + {EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc + {EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # Azure ML + {}, # IMDS +) @pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) @@ -1168,3 +1180,36 @@ def test_log(caplog): with mock.patch.dict("os.environ", mock_environ, clear=True): ManagedIdentityCredential(client_id="foo") assert "workload identity with client_id: foo" in caplog.text + + +@pytest.mark.parametrize("environ,get_token_method", product(MSAL_MANAGED_IDENTITY_ENVIRON, GET_TOKEN_METHODS)) +def test_claims_propagated(environ, get_token_method): + """Test that claims passed are forwarded to MSAL's acquire_token_for_client.""" + from azure.identity import ManagedIdentityCredential + + expected_claims = '{"access_token": {"xms_cc": {"values": ["cp1"]}}}' + expected_token = "test_token" + expires_in = 3600 + + mock_msal_client = mock.Mock() + mock_msal_client.acquire_token_for_client.return_value = { + "access_token": expected_token, + "expires_in": expires_in, + "token_type": "Bearer", + } + + with mock.patch("msal.ManagedIdentityClient", return_value=mock_msal_client): + with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True): + credential = ManagedIdentityCredential() + kwargs = {"claims": expected_claims} + if get_token_method == "get_token_info": + kwargs = {"options": {"claims": expected_claims}} + token = getattr(credential, get_token_method)("scope", **kwargs) + + # Verify the token was returned correctly + assert token.token == expected_token + + # Verify acquire_token_for_client was called with the claims + mock_msal_client.acquire_token_for_client.assert_called_once() + call_kwargs = mock_msal_client.acquire_token_for_client.call_args + assert call_kwargs.kwargs.get("claims_challenge") == expected_claims diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index 97402716807f..838a8914c6e4 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -1439,3 +1439,81 @@ def test_log(caplog): with mock.patch.dict("os.environ", mock_environ, clear=True): ManagedIdentityCredential(client_id="foo") assert "workload identity with client_id: foo" in caplog.text + + +@pytest.mark.asyncio +@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) +async def test_claims_force_token_refresh(environ): + """When claims are provided, the token cache should be bypassed and claims passed to the token request.""" + expected_token = "****" + scope = "scope" + expected_claims = '{"access_token": {"xms_cc": {"values": ["cp1"]}}}' + now = int(time.time()) + + call_count = 0 + + async def mock_send(request, **kwargs): + nonlocal call_count + call_count += 1 + + assert "claims" not in kwargs + return mock_response( + json_payload={ + "access_token": expected_token, + "expires_in": 3600, + "expires_on": now + 3600, + "resource": scope, + "token_type": "Bearer", + } + ) + + with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True): + credential = ManagedIdentityCredential(transport=mock.Mock(send=mock_send)) + + # First call without claims + token = await credential.get_token(scope) + assert token.token == expected_token + first_call_count = call_count + + # Second call with claims - should make a new request (not use cache) + token = await credential.get_token_info(scope, options={"claims": expected_claims}) + assert token.token == expected_token + + # Verify a new token request was made when claims were provided + assert call_count > first_call_count, "Expected a new token request when claims were provided" + + +@pytest.mark.asyncio +async def test_access_tokens_cached(): + """Verify that cached tokens are used on subsequent token requests for the same scope.""" + expected_token = "****" + scope = "scope" + now = int(time.time()) + + call_count = 0 + + async def mock_send(request, **kwargs): + nonlocal call_count + call_count += 1 + return mock_response( + json_payload={ + "access_token": expected_token, + "expires_in": 3600, + "expires_on": now + 3600, + "resource": scope, + "token_type": "Bearer", + } + ) + + transport = mock.Mock(send=mock_send) + credential = ManagedIdentityCredential(transport=transport) + + token = await credential.get_token(scope) + assert token.token == expected_token + first_call_count = call_count + + token_info = await credential.get_token_info(scope) + assert token_info.token == expected_token + + # Verify no new token request was made when getting token info + assert call_count == first_call_count