diff --git a/pyrit/auth/__init__.py b/pyrit/auth/__init__.py index 4074809e51..02cd90b1dd 100644 --- a/pyrit/auth/__init__.py +++ b/pyrit/auth/__init__.py @@ -7,6 +7,7 @@ from pyrit.auth.authenticator import Authenticator from pyrit.auth.azure_auth import ( + AsyncTokenProviderCredential, AzureAuth, TokenProviderCredential, get_azure_async_token_provider, @@ -19,6 +20,7 @@ from pyrit.auth.manual_copilot_authenticator import ManualCopilotAuthenticator __all__ = [ + "AsyncTokenProviderCredential", "Authenticator", "AzureAuth", "AzureStorageAuth", diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index b606189636..2cb471c6c7 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -3,6 +3,7 @@ from __future__ import annotations +import inspect import logging import time from typing import TYPE_CHECKING, Any, Union, cast @@ -23,7 +24,7 @@ ) if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Awaitable, Callable import azure.cognitiveservices.speech as speechsdk @@ -67,6 +68,60 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: return AccessToken(str(token), expires_on) +class AsyncTokenProviderCredential: + """ + Async wrapper to convert a token provider callable into an Azure AsyncTokenCredential. + + This class bridges the gap between token provider functions (sync or async) and Azure SDK + async clients that require an AsyncTokenCredential object (with async def get_token). + """ + + def __init__(self, token_provider: Callable[[], Union[str, Awaitable[str]]]) -> None: + """ + Initialize AsyncTokenProviderCredential. + + Args: + token_provider: A callable that returns a token string (sync) or an awaitable that + returns a token string (async). Both are supported transparently. + """ + self._token_provider = token_provider + + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: + """ + Get an access token asynchronously. + + Args: + scopes: Token scopes (ignored as the scope is already configured in the token provider). + kwargs: Additional arguments (ignored). + + Returns: + AccessToken: The access token with expiration time. + """ + result = self._token_provider() + if inspect.isawaitable(result): + token = await result + else: + token = result + expires_on = int(time.time()) + 3600 + return AccessToken(str(token), expires_on) + + async def close(self) -> None: + """No-op close for protocol compliance. The callable provider does not hold resources.""" + + async def __aenter__(self) -> AsyncTokenProviderCredential: + """ + Enter the async context manager. + + Returns: + AsyncTokenProviderCredential: This credential instance. + """ + return self + + async def __aexit__(self, *args: Any) -> None: + """Exit the async context manager.""" + await self.close() + + class AzureAuth(Authenticator): """ Azure CLI Authentication. diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index 6b587d4f30..303577416e 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -1,13 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import asyncio import base64 import inspect -from collections.abc import Callable +import logging +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Optional -from azure.ai.contentsafety import ContentSafetyClient +from azure.ai.contentsafety.aio import ContentSafetyClient from azure.ai.contentsafety.models import ( AnalyzeImageOptions, AnalyzeImageResult, @@ -18,7 +18,7 @@ ) from azure.core.credentials import AzureKeyCredential -from pyrit.auth import TokenProviderCredential, get_azure_token_provider +from pyrit.auth import AsyncTokenProviderCredential, get_azure_async_token_provider from pyrit.common import default_values from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( @@ -38,6 +38,48 @@ from pyrit.score.scorer_evaluation.scorer_evaluator import ScorerEvalDatasetFiles from pyrit.score.scorer_evaluation.scorer_metrics import ScorerMetrics +logger = logging.getLogger(__name__) + + +def _ensure_async_token_provider( + api_key: str | Callable[[], str | Awaitable[str]] | None, +) -> str | Callable[[], Awaitable[str]] | None: + """ + Ensure the api_key is either a string or an async callable. + + If a synchronous callable token provider is provided, it's automatically wrapped + in an async function to make it compatible with the async ContentSafetyClient. + + Args: + api_key: Either a string API key or a callable that returns a token (sync or async). + + Returns: + Either a string API key or an async callable that returns a token. + """ + if api_key is None or isinstance(api_key, str) or not callable(api_key): + return api_key + + # Check if the callable is already async + if inspect.iscoroutinefunction(api_key): + return api_key + + # Wrap synchronous token provider in async function + logger.info( + "Detected synchronous token provider." + " Automatically wrapping in async function for compatibility with async ContentSafetyClient." + ) + + async def async_token_provider() -> str: + """ + Async wrapper for synchronous token provider. + + Returns: + str: The token string from the synchronous provider. + """ + return api_key() # type: ignore[return-value] + + return async_token_provider + class AzureContentFilterScorer(FloatScaleScorer): """ @@ -94,7 +136,7 @@ def __init__( self, *, endpoint: Optional[str | None] = None, - api_key: Optional[str | Callable[[], str] | None] = None, + api_key: Optional[str | Callable[[], str | Awaitable[str]] | None] = None, harm_categories: Optional[list[TextCategory]] = None, validator: Optional[ScorerPromptValidator] = None, ) -> None: @@ -104,13 +146,12 @@ def __init__( Args: endpoint (Optional[str | None]): The endpoint URL for the Azure Content Safety service. Defaults to the `ENDPOINT_URI_ENVIRONMENT_VARIABLE` environment variable. - api_key (Optional[str | Callable[[], str] | None]): + api_key (Optional[str | Callable[[], str | Awaitable[str]] | None]): The API key for accessing the Azure Content Safety service, - or a synchronous callable that returns an access token. Async token providers - are not supported. If not provided (via parameter - or environment variable), Entra ID authentication is used automatically. - You can also explicitly pass a token provider from pyrit.auth - (e.g., get_azure_token_provider('https://cognitiveservices.azure.com/.default')). + or a callable that returns an access token. Both synchronous and asynchronous + token providers are supported. Sync providers are automatically wrapped for + async compatibility. If not provided (via parameter or environment variable), + Entra ID authentication is used automatically. Defaults to the `API_KEY_ENVIRONMENT_VARIABLE` environment variable. harm_categories (Optional[list[TextCategory]]): The harm categories you want to query for as defined in azure.ai.contentsafety.models.TextCategory. If not provided, defaults to all categories. @@ -129,36 +170,25 @@ def __init__( ) # API key: use passed value, env var, or fall back to Entra ID for Azure endpoints - resolved_api_key: str | Callable[[], str] + resolved_api_key: str | Callable[[], str | Awaitable[str]] if api_key is not None and callable(api_key): - if asyncio.iscoroutinefunction(api_key): - raise ValueError( - "Async token providers are not supported by AzureContentFilterScorer. " - "Use a synchronous token provider (e.g., get_azure_token_provider) instead." - ) - # Guard against sync callables that return coroutines/awaitables (e.g., lambda: async_fn()) - test_result = api_key() - if inspect.isawaitable(test_result): - if hasattr(test_result, "close"): - test_result.close() # prevent "coroutine was never awaited" warning - raise ValueError( - "The provided token provider returns a coroutine/awaitable, which is not supported " - "by AzureContentFilterScorer. Use a synchronous token provider instead." - ) resolved_api_key = api_key else: api_key_value = default_values.get_non_required_value( env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key ) - resolved_api_key = api_key_value or get_azure_token_provider("https://cognitiveservices.azure.com/.default") + resolved_api_key = api_key_value or get_azure_async_token_provider( + "https://cognitiveservices.azure.com/.default" + ) - self._api_key = resolved_api_key + # Ensure api_key is async-compatible (wrap sync token providers if needed) + self._api_key = _ensure_async_token_provider(resolved_api_key) # Create ContentSafetyClient with appropriate credential if self._endpoint is not None: if callable(self._api_key): - # Token provider - create a TokenCredential wrapper - credential = TokenProviderCredential(self._api_key) + # Token provider - create an AsyncTokenCredential wrapper + credential = AsyncTokenProviderCredential(self._api_key) self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) else: # String API key @@ -291,7 +321,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op categories=self._category_values, output_type="EightSeverityLevels", ) - text_result = self._azure_cf_client.analyze_text(text_request_options) + text_result = await self._azure_cf_client.analyze_text(text_request_options) filter_results.append(text_result) elif message_piece.converted_value_data_type == "image_path": @@ -301,7 +331,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op image_request_options = AnalyzeImageOptions( image=image_data, categories=self._category_values, output_type="FourSeverityLevels" ) - image_result = self._azure_cf_client.analyze_image(image_request_options) + image_result = await self._azure_cf_client.analyze_image(image_request_options) filter_results.append(image_result) # Collect all scores from all chunks/images diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index fedeb95e12..1c997f3326 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -6,7 +6,7 @@ from sqlalchemy import inspect -from pyrit.identifiers import AttackIdentifier +from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, SQLiteMemory from pyrit.models import Message, MessagePiece from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute @@ -49,7 +49,7 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[AttackIdentifier] = None, + attack_identifier: Optional[ComponentIdentifier] = None, labels: Optional[dict[str, str]] = None, ) -> None: self.system_prompt = system_prompt diff --git a/tests/integration/score/test_azure_content_filter_integration.py b/tests/integration/score/test_azure_content_filter_integration.py index 9f7fdef20b..53d40ff320 100644 --- a/tests/integration/score/test_azure_content_filter_integration.py +++ b/tests/integration/score/test_azure_content_filter_integration.py @@ -12,6 +12,11 @@ from pyrit.memory import CentralMemory, MemoryInterface from pyrit.score import AzureContentFilterScorer +pytestmark = pytest.mark.skipif( + not os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT"), + reason="AZURE_CONTENT_SAFETY_API_ENDPOINT not configured", +) + @pytest.fixture def memory() -> Generator[MemoryInterface, None, None]: @@ -27,13 +32,6 @@ async def test_azure_content_filter_scorer_image_integration(memory) -> None: environment variables to be set. Uses a sample image from the assets folder. """ with patch.object(CentralMemory, "get_memory_instance", return_value=memory): - # Verify required environment variables are set - api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY") - endpoint = os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT") - - if not api_key or not endpoint: - pytest.skip("Azure Content Safety credentials not configured") - scorer = AzureContentFilterScorer() image_path = HOME_PATH / "assets" / "architecture_components.png" @@ -62,13 +60,6 @@ async def test_azure_content_filter_scorer_long_text_chunking_integration(memory This verifies that the chunking and aggregation logic works correctly with the real API. """ with patch.object(CentralMemory, "get_memory_instance", return_value=memory): - # Verify required environment variables are set - api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY") - endpoint = os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT") - - if not api_key or not endpoint: - pytest.skip("Azure Content Safety credentials not configured") - scorer = AzureContentFilterScorer() # This should be greater than the rate limit diff --git a/tests/unit/score/test_azure_content_filter.py b/tests/unit/score/test_azure_content_filter.py index 27c0cc298a..bd5d97b26e 100644 --- a/tests/unit/score/test_azure_content_filter.py +++ b/tests/unit/score/test_azure_content_filter.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. +import inspect import os from unittest.mock import AsyncMock, MagicMock, patch @@ -55,7 +56,7 @@ async def test_score_async_unsupported_data_type_returns_empty_list( @pytest.mark.asyncio async def test_score_piece_async_text(patch_central_database, text_message_piece: MessagePiece): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "2", "category": "Hate"}]} scorer._azure_cf_client = mock_client scores = await scorer._score_piece_async(text_message_piece) @@ -72,7 +73,7 @@ async def test_score_piece_async_text(patch_central_database, text_message_piece @pytest.mark.asyncio async def test_score_piece_async_image(patch_central_database, image_message_piece: MessagePiece): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_image.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} scorer._azure_cf_client = mock_client # Patch _get_base64_image_data to avoid actual file IO @@ -102,25 +103,34 @@ def test_explicit_category(): assert len(scorer._harm_categories) == 1 -def test_async_callable_api_key_raises(): +def test_async_callable_api_key_accepted(): async def async_provider(): return "token" - with pytest.raises(ValueError, match="Async token providers are not supported"): - AzureContentFilterScorer(api_key=async_provider, endpoint="bar") + scorer = AzureContentFilterScorer(api_key=async_provider, endpoint="bar") + # Async callable should be passed through as-is + assert callable(scorer._api_key) + assert inspect.iscoroutinefunction(scorer._api_key) -def test_sync_callable_returning_coroutine_raises(): +def test_sync_callable_returning_coroutine_accepted(): async def async_fn(): return "token" - with pytest.raises(ValueError, match="returns a coroutine/awaitable"): - AzureContentFilterScorer(api_key=lambda: async_fn(), endpoint="bar") + sync_lambda = lambda: async_fn() # noqa: E731 + # Confirm the lambda itself is NOT a coroutine function (it's sync) + assert not inspect.iscoroutinefunction(sync_lambda) + + scorer = AzureContentFilterScorer(api_key=sync_lambda, endpoint="bar") + # After init, the sync callable should be wrapped in an async function + assert callable(scorer._api_key) + assert inspect.iscoroutinefunction(scorer._api_key) def test_sync_callable_api_key_accepted(): scorer = AzureContentFilterScorer(api_key=lambda: "token", endpoint="bar") assert callable(scorer._api_key) + assert inspect.iscoroutinefunction(scorer._api_key) @pytest.mark.asyncio @@ -129,7 +139,7 @@ async def test_azure_content_filter_scorer_adds_to_memory(): with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "2", "category": "Hate"}]} scorer._azure_cf_client = mock_client @@ -143,7 +153,7 @@ async def test_azure_content_filter_scorer_adds_to_memory(): async def test_azure_content_filter_scorer_score(patch_central_database): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "2", "category": "Hate"}]} scorer._azure_cf_client = mock_client @@ -181,7 +191,7 @@ async def test_azure_content_filter_scorer_chunks_long_text(patch_central_databa with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() # Mock returns for two chunks mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} scorer._azure_cf_client = mock_client @@ -205,7 +215,7 @@ async def test_azure_content_filter_scorer_accepts_short_text(patch_central_data with patch.object(CentralMemory, "get_memory_instance", return_value=memory): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) - mock_client = MagicMock() + mock_client = AsyncMock() mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} scorer._azure_cf_client = mock_client