Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyrit/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from pyrit.auth.authenticator import Authenticator
from pyrit.auth.azure_auth import (
AsyncTokenProviderCredential,
AzureAuth,
TokenProviderCredential,
get_azure_async_token_provider,
Expand All @@ -19,6 +20,7 @@
from pyrit.auth.manual_copilot_authenticator import ManualCopilotAuthenticator

__all__ = [
"AsyncTokenProviderCredential",
"Authenticator",
"AzureAuth",
"AzureStorageAuth",
Expand Down
57 changes: 56 additions & 1 deletion pyrit/auth/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import inspect
import logging
import time
from typing import TYPE_CHECKING, Any, Union, cast
Expand All @@ -23,7 +24,7 @@
)

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Awaitable, Callable

import azure.cognitiveservices.speech as speechsdk

Expand Down Expand Up @@ -67,6 +68,60 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
return AccessToken(str(token), expires_on)


class AsyncTokenProviderCredential:
    """
    Adapter that exposes a token provider callable as an Azure AsyncTokenCredential.

    Azure SDK async clients expect a credential object with an ``async def get_token`` method.
    This class supplies that interface on top of a plain token provider function, which may be
    either synchronous or asynchronous.
    """

    def __init__(self, token_provider: Callable[[], Union[str, Awaitable[str]]]) -> None:
        """
        Store the token provider used to answer token requests.

        Args:
            token_provider: A zero-argument callable producing either a token string directly
                (sync provider) or an awaitable resolving to a token string (async provider).
        """
        self._token_provider = token_provider

    async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
        """
        Fetch a token from the wrapped provider and package it as an AccessToken.

        Args:
            scopes: Requested scopes. Not used here; the provider is called without arguments.
            kwargs: Extra keyword arguments. Not used.

        Returns:
            AccessToken: The stringified token paired with an expiry timestamp.
        """
        provided = self._token_provider()
        # A sync provider hands back the token itself; an async provider hands back an awaitable.
        token_value = await provided if inspect.isawaitable(provided) else provided
        # The expiry is not read from the token: it is stamped as 3600 seconds from now.
        valid_until = int(time.time()) + 3600
        return AccessToken(str(token_value), valid_until)

    async def close(self) -> None:
        """Do nothing; present so the credential can be closed like other async credentials."""
        return None

    async def __aenter__(self) -> AsyncTokenProviderCredential:
        """
        Enter the async context manager.

        Returns:
            AsyncTokenProviderCredential: This same credential instance.
        """
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Leave the async context manager by delegating to ``close``."""
        await self.close()


class AzureAuth(Authenticator):
"""
Azure CLI Authentication.
Expand Down
94 changes: 62 additions & 32 deletions pyrit/score/float_scale/azure_content_filter_scorer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import base64
import inspect
from collections.abc import Callable
import logging
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Optional

from azure.ai.contentsafety import ContentSafetyClient
from azure.ai.contentsafety.aio import ContentSafetyClient
from azure.ai.contentsafety.models import (
AnalyzeImageOptions,
AnalyzeImageResult,
Expand All @@ -18,7 +18,7 @@
)
from azure.core.credentials import AzureKeyCredential

from pyrit.auth import TokenProviderCredential, get_azure_token_provider
from pyrit.auth import AsyncTokenProviderCredential, get_azure_async_token_provider
from pyrit.common import default_values
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import (
Expand All @@ -38,6 +38,48 @@
from pyrit.score.scorer_evaluation.scorer_evaluator import ScorerEvalDatasetFiles
from pyrit.score.scorer_evaluation.scorer_metrics import ScorerMetrics

logger = logging.getLogger(__name__)


def _ensure_async_token_provider(
api_key: str | Callable[[], str | Awaitable[str]] | None,
) -> str | Callable[[], Awaitable[str]] | None:
"""
Ensure the api_key is either a string or an async callable.

If a synchronous callable token provider is provided, it's automatically wrapped
in an async function to make it compatible with the async ContentSafetyClient.

Args:
api_key: Either a string API key or a callable that returns a token (sync or async).

Returns:
Either a string API key or an async callable that returns a token.
"""
if api_key is None or isinstance(api_key, str) or not callable(api_key):
return api_key

# Check if the callable is already async
if inspect.iscoroutinefunction(api_key):
return api_key

# Wrap synchronous token provider in async function
logger.info(
"Detected synchronous token provider."
" Automatically wrapping in async function for compatibility with async ContentSafetyClient."
)

async def async_token_provider() -> str:
"""
Async wrapper for synchronous token provider.

Returns:
str: The token string from the synchronous provider.
"""
return api_key() # type: ignore[return-value]

return async_token_provider


class AzureContentFilterScorer(FloatScaleScorer):
"""
Expand Down Expand Up @@ -94,7 +136,7 @@ def __init__(
self,
*,
endpoint: Optional[str | None] = None,
api_key: Optional[str | Callable[[], str] | None] = None,
api_key: Optional[str | Callable[[], str | Awaitable[str]] | None] = None,
harm_categories: Optional[list[TextCategory]] = None,
validator: Optional[ScorerPromptValidator] = None,
) -> None:
Expand All @@ -104,13 +146,12 @@ def __init__(
Args:
endpoint (Optional[str | None]): The endpoint URL for the Azure Content Safety service.
Defaults to the `ENDPOINT_URI_ENVIRONMENT_VARIABLE` environment variable.
api_key (Optional[str | Callable[[], str] | None]):
api_key (Optional[str | Callable[[], str | Awaitable[str]] | None]):
The API key for accessing the Azure Content Safety service,
or a synchronous callable that returns an access token. Async token providers
are not supported. If not provided (via parameter
or environment variable), Entra ID authentication is used automatically.
You can also explicitly pass a token provider from pyrit.auth
(e.g., get_azure_token_provider('https://cognitiveservices.azure.com/.default')).
or a callable that returns an access token. Both synchronous and asynchronous
token providers are supported. Sync providers are automatically wrapped for
async compatibility. If not provided (via parameter or environment variable),
Entra ID authentication is used automatically.
Defaults to the `API_KEY_ENVIRONMENT_VARIABLE` environment variable.
harm_categories (Optional[list[TextCategory]]): The harm categories you want to query for as
defined in azure.ai.contentsafety.models.TextCategory. If not provided, defaults to all categories.
Expand All @@ -129,36 +170,25 @@ def __init__(
)

# API key: use passed value, env var, or fall back to Entra ID for Azure endpoints
resolved_api_key: str | Callable[[], str]
resolved_api_key: str | Callable[[], str | Awaitable[str]]
if api_key is not None and callable(api_key):
if asyncio.iscoroutinefunction(api_key):
raise ValueError(
"Async token providers are not supported by AzureContentFilterScorer. "
"Use a synchronous token provider (e.g., get_azure_token_provider) instead."
)
# Guard against sync callables that return coroutines/awaitables (e.g., lambda: async_fn())
test_result = api_key()
if inspect.isawaitable(test_result):
if hasattr(test_result, "close"):
test_result.close() # prevent "coroutine was never awaited" warning
raise ValueError(
"The provided token provider returns a coroutine/awaitable, which is not supported "
"by AzureContentFilterScorer. Use a synchronous token provider instead."
)
resolved_api_key = api_key
else:
api_key_value = default_values.get_non_required_value(
env_var_name=self.API_KEY_ENVIRONMENT_VARIABLE, passed_value=api_key
)
resolved_api_key = api_key_value or get_azure_token_provider("https://cognitiveservices.azure.com/.default")
resolved_api_key = api_key_value or get_azure_async_token_provider(
"https://cognitiveservices.azure.com/.default"
)

self._api_key = resolved_api_key
# Ensure api_key is async-compatible (wrap sync token providers if needed)
self._api_key = _ensure_async_token_provider(resolved_api_key)

# Create ContentSafetyClient with appropriate credential
if self._endpoint is not None:
if callable(self._api_key):
# Token provider - create a TokenCredential wrapper
credential = TokenProviderCredential(self._api_key)
# Token provider - create an AsyncTokenCredential wrapper
credential = AsyncTokenProviderCredential(self._api_key)
self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential)
else:
# String API key
Expand Down Expand Up @@ -291,7 +321,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op
categories=self._category_values,
output_type="EightSeverityLevels",
)
text_result = self._azure_cf_client.analyze_text(text_request_options)
text_result = await self._azure_cf_client.analyze_text(text_request_options)
filter_results.append(text_result)

elif message_piece.converted_value_data_type == "image_path":
Expand All @@ -301,7 +331,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op
image_request_options = AnalyzeImageOptions(
image=image_data, categories=self._category_values, output_type="FourSeverityLevels"
)
image_result = self._azure_cf_client.analyze_image(image_request_options)
image_result = await self._azure_cf_client.analyze_image(image_request_options)
filter_results.append(image_result)

# Collect all scores from all chunks/images
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from sqlalchemy import inspect

from pyrit.identifiers import AttackIdentifier
from pyrit.identifiers import ComponentIdentifier
from pyrit.memory import MemoryInterface, SQLiteMemory
from pyrit.models import Message, MessagePiece
from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute
Expand Down Expand Up @@ -49,7 +49,7 @@ def set_system_prompt(
*,
system_prompt: str,
conversation_id: str,
attack_identifier: Optional[AttackIdentifier] = None,
attack_identifier: Optional[ComponentIdentifier] = None,
labels: Optional[dict[str, str]] = None,
) -> None:
self.system_prompt = system_prompt
Expand Down
19 changes: 5 additions & 14 deletions tests/integration/score/test_azure_content_filter_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.score import AzureContentFilterScorer

# Module-level pytest marker: skip the tests in this module when the Content Safety
# endpoint environment variable is unset or empty. Only the endpoint is checked here.
pytestmark = pytest.mark.skipif(
not os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT"),
reason="AZURE_CONTENT_SAFETY_API_ENDPOINT not configured",
)


@pytest.fixture
def memory() -> Generator[MemoryInterface, None, None]:
Expand All @@ -27,13 +32,6 @@ async def test_azure_content_filter_scorer_image_integration(memory) -> None:
environment variables to be set. Uses a sample image from the assets folder.
"""
with patch.object(CentralMemory, "get_memory_instance", return_value=memory):
# Verify required environment variables are set
api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY")
endpoint = os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT")

if not api_key or not endpoint:
pytest.skip("Azure Content Safety credentials not configured")

scorer = AzureContentFilterScorer()

image_path = HOME_PATH / "assets" / "architecture_components.png"
Expand Down Expand Up @@ -62,13 +60,6 @@ async def test_azure_content_filter_scorer_long_text_chunking_integration(memory
This verifies that the chunking and aggregation logic works correctly with the real API.
"""
with patch.object(CentralMemory, "get_memory_instance", return_value=memory):
# Verify required environment variables are set
api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY")
endpoint = os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT")

if not api_key or not endpoint:
pytest.skip("Azure Content Safety credentials not configured")

scorer = AzureContentFilterScorer()

# This should be greater than the rate limit
Expand Down
Loading
Loading