diff --git a/pyrit/prompt_target/azure_blob_storage_target.py b/pyrit/prompt_target/azure_blob_storage_target.py index 824c104f47..be74edab2e 100644 --- a/pyrit/prompt_target/azure_blob_storage_target.py +++ b/pyrit/prompt_target/azure_blob_storage_target.py @@ -15,6 +15,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.utils import limit_requests_per_minute logger = logging.getLogger(__name__) @@ -49,6 +50,12 @@ class AzureBlobStorageTarget(PromptTarget): AZURE_STORAGE_CONTAINER_ENVIRONMENT_VARIABLE: str = "AZURE_STORAGE_ACCOUNT_CONTAINER_URL" SAS_TOKEN_ENVIRONMENT_VARIABLE: str = "AZURE_STORAGE_ACCOUNT_SAS_TOKEN" + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + input_modalities=["text", "url"], + output_modalities=["url"], + supports_multi_message_pieces=False, + ) + def __init__( self, *, @@ -196,18 +203,3 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: ) return [response] - - def _validate_request(self, *, message: Message) -> None: - n_pieces = len(message.message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports a single message piece. Received {n_pieces} pieces") - - piece_type = message.message_pieces[0].converted_value_data_type - if piece_type not in ["text", "url"]: - raise ValueError(f"This target only supports text and url prompt input. Received: {piece_type}.") - - request = message.message_pieces[0] - messages = self._memory.get_message_pieces(conversation_id=request.conversation_id) - - if len(messages) > 0: - raise ValueError("This target only supports a single turn conversation.") diff --git a/pyrit/prompt_target/azure_ml_chat_target.py b/pyrit/prompt_target/azure_ml_chat_target.py index d058735210..3bc9b39e50 100644 --- a/pyrit/prompt_target/azure_ml_chat_target.py +++ b/pyrit/prompt_target/azure_ml_chat_target.py @@ -272,12 +272,3 @@ def _get_headers(self) -> dict[str, str]: def _validate_request(self, *, message: Message) -> None: pass - - def is_json_response_supported(self) -> bool: - """ - Check if the target supports JSON as a response format. - - Returns: - bool: True if JSON response is supported, False otherwise. - """ - return False diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index abdb66920e..3b9edce036 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import abc from typing import Optional from pyrit.identifiers import ComponentIdentifier @@ -31,7 +30,7 @@ def __init__( endpoint: str = "", model_name: str = "", underlying_model: Optional[str] = None, - capabilities: Optional[TargetCapabilities] = None, + custom_capabilities: Optional[TargetCapabilities] = None, ) -> None: """ Initialize the PromptChatTarget. @@ -43,7 +42,7 @@ def __init__( underlying_model (str, Optional): The underlying model name (e.g., "gpt-4o") for identification purposes. This is useful when the deployment name in Azure differs from the actual model. Defaults to None. - capabilities (TargetCapabilities, Optional): Override the default capabilities for + custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for this target instance. If None, uses the class-level defaults. Defaults to None. """ super().__init__( @@ -51,7 +50,7 @@ def __init__( endpoint=endpoint, model_name=model_name, underlying_model=underlying_model, - capabilities=capabilities, + custom_capabilities=custom_capabilities, ) def set_system_prompt( @@ -85,15 +84,6 @@ def set_system_prompt( ).to_message() ) - @abc.abstractmethod - def is_json_response_supported(self) -> bool: - """ - Abstract method to determine if JSON response format is supported by the target. - - Returns: - bool: True if JSON response is supported, False otherwise. - """ - def is_response_format_json(self, message_piece: MessagePiece) -> bool: """ Check if the response format is JSON and ensure the target supports it. @@ -127,7 +117,7 @@ def _get_json_response_config(self, *, message_piece: MessagePiece) -> _JsonResp """ config = _JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata) - if config.enabled and not self.is_json_response_supported(): + if config.enabled and not self.capabilities.supports_json_response: target_name = self.get_identifier().class_name raise ValueError(f"This target {target_name} does not support JSON response format.") diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 7ad4fc4ed5..b8cdd9659e 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -38,7 +38,7 @@ def __init__( endpoint: str = "", model_name: str = "", underlying_model: Optional[str] = None, - capabilities: Optional[TargetCapabilities] = None, + custom_capabilities: Optional[TargetCapabilities] = None, ) -> None: """ Initialize the PromptTarget. @@ -52,7 +52,7 @@ def __init__( identification purposes. This is useful when the deployment name in Azure differs from the actual model. If not provided, `model_name` will be used for the identifier. Defaults to None. - capabilities (TargetCapabilities, Optional): Override the default capabilities for + custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for this target instance. Useful for targets whose capabilities depend on deployment configuration (e.g., Playwright, HTTP). If None, uses the class-level ``_DEFAULT_CAPABILITIES``. Defaults to None. @@ -63,7 +63,9 @@ def __init__( self._endpoint = endpoint self._model_name = model_name self._underlying_model = underlying_model - self._capabilities = capabilities if capabilities is not None else type(self)._DEFAULT_CAPABILITIES + self._capabilities = ( + custom_capabilities if custom_capabilities is not None else type(self)._DEFAULT_CAPABILITIES + ) if self._verbose: logging.basicConfig(level=logging.INFO) @@ -78,14 +80,36 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: but some (like response target with tool calls) may return multiple messages. """ - @abc.abstractmethod def _validate_request(self, *, message: Message) -> None: """ Validate the provided message. Args: message: The message to validate. + + Raises: + ValueError: if the target does not support the provided message pieces or if the message + violates any constraints based on the target's capabilities. + """ + n_pieces = len(message.message_pieces) + if not self.capabilities.supports_multi_message_pieces and n_pieces != 1: + raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") + + for piece in message.message_pieces: + piece_type = piece.converted_value_data_type + if piece_type not in self.capabilities.input_modalities: + supported_types = ", ".join(sorted(self.capabilities.input_modalities)) + raise ValueError( + f"This target supports only the following data types: {supported_types}. Received: {piece_type}." + ) + + if not self.supports_multi_turn: + request = message.message_pieces[0] + messages = self._memory.get_message_pieces(conversation_id=request.conversation_id) + + if len(messages) > 0: + raise ValueError("This target only supports a single turn conversation.") def set_model_name(self, *, model_name: str) -> None: """ @@ -140,6 +164,15 @@ def _create_identifier( return ComponentIdentifier.of(self, params=all_params, children=children) + def is_json_response_supported(self) -> bool: + """ + Method to determine if JSON response format is supported by the target. + + Returns: + bool: True if JSON response is supported, False otherwise. + """ + return self.capabilities.supports_json_response + @property def capabilities(self) -> TargetCapabilities: """ diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index 4ca70c9eff..996168ca32 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -1,7 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import dataclass +from dataclasses import dataclass, field, fields + +from pyrit.models import PromptDataType @dataclass(frozen=True) @@ -20,3 +22,38 @@ class attribute. Users can override individual capabilities per instance # (i.e., it accepts and uses conversation history or maintains state # across turns via external mechanisms like WebSocket connections). supports_multi_turn: bool = False + + # Whether the target natively supports multiple message pieces in a single request. + supports_multi_message_pieces: bool = True + + # Whether the target natively supports JSON output (e.g., via a "json" response format). + supports_json_response: bool = False + + # The input modalities supported by the target (e.g., "text", "image"). + input_modalities: list[PromptDataType] = field(default_factory=lambda: ["text"]) + + # The output modalities supported by the target (e.g., "text", "image"). + output_modalities: list[PromptDataType] = field(default_factory=lambda: ["text"]) + + def assert_satifies(self, required_capabilities: "TargetCapabilities") -> None: + """ + Assert that the current capabilities satisfy the required capabilities. + + Args: + required_capabilities (TargetCapabilities): The required capabilities to check against. + + Raises: + ValueError: If any of the required capabilities are not satisfied. + """ + unmet = [] + for f in fields(required_capabilities): + required_value = getattr(required_capabilities, f.name) + self_value = getattr(self, f.name) + if isinstance(required_value, list): + missing = set(required_value) - set(self_value) + if missing: + unmet.append(f"{f.name}: missing {missing}") + elif required_value and not self_value: + unmet.append(f.name) + if unmet: + raise ValueError(f"Target does not satisfy the following capabilities: {', '.join(unmet)}") diff --git a/pyrit/prompt_target/crucible_target.py b/pyrit/prompt_target/crucible_target.py index f096bf926e..df466018aa 100644 --- a/pyrit/prompt_target/crucible_target.py +++ b/pyrit/prompt_target/crucible_target.py @@ -14,6 +14,7 @@ ) from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.utils import limit_requests_per_minute logger = logging.getLogger(__name__) @@ -24,6 +25,8 @@ class CrucibleTarget(PromptTarget): API_KEY_ENVIRONMENT_VARIABLE: str = "CRUCIBLE_API_KEY" + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_message_pieces=False) + def __init__( self, *, @@ -80,15 +83,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: return [response_entry] - def _validate_request(self, *, message: Message) -> None: - n_pieces = len(message.message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") - - piece_type = message.message_pieces[0].converted_value_data_type - if piece_type != "text": - raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") - @pyrit_target_retry async def _complete_text_async(self, text: str) -> str: payload: dict[str, object] = { diff --git a/pyrit/prompt_target/gandalf_target.py b/pyrit/prompt_target/gandalf_target.py index 5e3e89935d..07f434c465 100644 --- a/pyrit/prompt_target/gandalf_target.py +++ b/pyrit/prompt_target/gandalf_target.py @@ -10,6 +10,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.utils import limit_requests_per_minute logger = logging.getLogger(__name__) @@ -38,6 +39,8 @@ class GandalfLevel(enum.Enum): class GandalfTarget(PromptTarget): """A prompt target for the Gandalf security challenge.""" + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_message_pieces=False) + def __init__( self, *, @@ -93,15 +96,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: return [response_entry] - def _validate_request(self, *, message: Message) -> None: - n_pieces = len(message.message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") - - piece_type = message.message_pieces[0].converted_value_data_type - if piece_type != "text": - raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") - async def check_password(self, password: str) -> bool: """ Check if the password is correct. diff --git a/pyrit/prompt_target/http_target/http_target.py b/pyrit/prompt_target/http_target/http_target.py index f95be4fde0..e77f28aa65 100644 --- a/pyrit/prompt_target/http_target/http_target.py +++ b/pyrit/prompt_target/http_target/http_target.py @@ -5,7 +5,7 @@ import json import logging import re -from collections.abc import Callable, Sequence +from collections.abc import Callable from typing import Any, Optional import httpx @@ -17,6 +17,7 @@ construct_response_from_request, ) from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.utils import limit_requests_per_minute logger = logging.getLogger(__name__) @@ -40,6 +41,8 @@ class HTTPTarget(PromptTarget): httpx_client_kwargs: (dict): additional keyword arguments to pass to the HTTP client """ + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_message_pieces=False) + def __init__( self, http_request: str, @@ -290,10 +293,3 @@ def _infer_full_url_from_host( host = headers_dict["host"] return f"{http_protocol}{host}{path}" - - def _validate_request(self, *, message: Message) -> None: - message_pieces: Sequence[MessagePiece] = message.message_pieces - - n_pieces = len(message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py index 85da9e084c..fe3c15d5dc 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py @@ -20,6 +20,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, construct_response_from_request from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.utils import limit_requests_per_minute logger = logging.getLogger(__name__) @@ -34,6 +35,11 @@ class HuggingFaceChatTarget(PromptChatTarget): Inherits from PromptTarget to comply with the current design standards. """ + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=False, + ) + # Class-level cache for model and tokenizer _cached_model = None _cached_tokenizer = None @@ -388,25 +394,6 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any: logger.error(error_message) raise ValueError(error_message) - def _validate_request(self, *, message: Message) -> None: - """ - Validate the provided message. - - Args: - message: The message to validate. - - Raises: - ValueError: If the message does not contain exactly one text piece. - ValueError: If the message piece is not of type text. - """ - n_pieces = len(message.message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") - - piece_type = message.message_pieces[0].converted_value_data_type - if piece_type != "text": - raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") - def is_json_response_supported(self) -> bool: """ Check if the target supports JSON as a response format. diff --git a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py index a21c87365e..ef3e04514b 100644 --- a/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py +++ b/pyrit/prompt_target/hugging_face/hugging_face_endpoint_target.py @@ -150,16 +150,3 @@ def _validate_request(self, *, message: Message) -> None: n_pieces = len(message.message_pieces) if n_pieces != 1: raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") - - piece_type = message.message_pieces[0].converted_value_data_type - if piece_type != "text": - raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") - - def is_json_response_supported(self) -> bool: - """ - Check if the target supports JSON as a response format. - - Returns: - bool: True if JSON response is supported, False otherwise. - """ - return False diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index a9d631da65..70dd42f4e6 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -5,6 +5,7 @@ import json import logging from collections.abc import MutableSequence +from dataclasses import replace from typing import Any, Optional from pyrit.common import convert_local_image_to_data_url @@ -24,6 +25,7 @@ ) from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p from pyrit.prompt_target.openai.openai_chat_audio_config import OpenAIChatAudioConfig from pyrit.prompt_target.openai.openai_target import OpenAITarget @@ -63,6 +65,13 @@ class OpenAIChatTarget(OpenAITarget, PromptChatTarget): """ + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + supports_multi_turn=True, + supports_json_response=True, + input_modalities=["text", "image_path", "audio_path"], + output_modalities=["text", "audio_path"], + ) + def __init__( self, *, @@ -77,21 +86,11 @@ def __init__( is_json_supported: bool = True, audio_response_config: Optional[OpenAIChatAudioConfig] = None, extra_body_parameters: Optional[dict[str, Any]] = None, + custom_capabilities: Optional[TargetCapabilities] = None, **kwargs: Any, ) -> None: """ Args: - model_name (str, Optional): The name of the model. - If no value is provided, the OPENAI_CHAT_MODEL environment variable will be used. - endpoint (str, Optional): The target URL for the OpenAI service. - api_key (str | Callable[[], str], Optional): The API key for accessing the OpenAI service, - or a callable that returns an access token. For Azure endpoints with Entra authentication, - pass a token provider from pyrit.auth (e.g., get_azure_openai_auth(endpoint)). - Defaults to the `OPENAI_CHAT_KEY` environment variable. - headers (str, Optional): Headers of the endpoint (JSON). - max_requests_per_minute (int, Optional): Number of requests the target can handle per - minute before hitting a rate limit. The number of requests sent to the target - will be capped at the value provided. max_completion_tokens (int, Optional): An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens. @@ -118,12 +117,12 @@ def __init__( setting the response_format header. Official OpenAI models all support this, but if you are using this target with different models, is_json_supported should be set correctly to avoid issues when using adversarial infrastructure (e.g. Crescendo scorers will set this flag). + This value is now deprecated in favor of `custom_capabilities`. audio_response_config (OpenAIChatAudioConfig, Optional): Configuration for audio output from models that support it (e.g., gpt-4o-audio-preview). When provided, enables audio modality in responses. extra_body_parameters (dict, Optional): Additional parameters to be included in the request body. + custom_capabilities (TargetCapabilities, Optional): Override the default target capabilities. **kwargs: Additional keyword arguments passed to the parent OpenAITarget class. - httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the ``httpx.AsyncClient()`` - constructor. For example, to specify a 3 minute timeout: ``httpx_client_kwargs={"timeout": 180}`` Raises: PyritException: If the temperature or top_p values are out of bounds. @@ -135,7 +134,14 @@ def __init__( json.JSONDecodeError: If the response from the target is not valid JSON. Exception: If the request fails for any other reason. """ - super().__init__(**kwargs) + # initialize custom capabilities with the _DEFAULT_CAPABILITIES and the is_json_supported flag + # If custom_capabilities is provided, use it as-is (it takes precedence over the deprecated is_json_supported). + # Otherwise, apply is_json_supported to the default capabilities for backwards compatibility. + if custom_capabilities is not None: + effective_capabilities = custom_capabilities + else: + effective_capabilities = replace(type(self)._DEFAULT_CAPABILITIES, supports_json_response=is_json_supported) + super().__init__(custom_capabilities=effective_capabilities, **kwargs) # Validate temperature and top_p validate_temperature(temperature) @@ -146,7 +152,6 @@ def __init__( self._temperature = temperature self._top_p = top_p - self._is_json_supported = is_json_supported self._max_completion_tokens = max_completion_tokens self._max_tokens = max_tokens self._frequency_penalty = frequency_penalty @@ -471,15 +476,6 @@ async def _save_audio_response_async(self, *, audio_data_base64: str) -> str: return audio_serializer.value - def is_json_response_supported(self) -> bool: - """ - Check if the target supports JSON as a response format. - - Returns: - bool: True if JSON response is supported, False otherwise. - """ - return self._is_json_supported - async def _build_chat_messages_async(self, conversation: MutableSequence[Message]) -> list[dict[str, Any]]: """ Build chat messages based on message entries. @@ -652,27 +648,6 @@ async def _construct_request_body( # Filter out None values return {k: v for k, v in body_parameters.items() if v is not None} - def _validate_request(self, *, message: Message) -> None: - """ - Validate the structure and content of a message for compatibility of this target. - - Args: - message (Message): The message object. - - Raises: - ValueError: If any of the message pieces have a data type other than 'text' or 'image_path'. - """ - converted_prompt_data_types = [ - message_piece.converted_value_data_type for message_piece in message.message_pieces - ] - - # Some models may not support all of these - for prompt_data_type in converted_prompt_data_types: - if prompt_data_type not in ["text", "image_path", "audio_path"]: - raise ValueError( - f"This target only supports text, image_path, and audio_path. Received: {prompt_data_type}." - ) - def _build_response_format(self, json_config: _JsonResponseConfig) -> Optional[dict[str, Any]]: if not json_config.enabled: return None diff --git a/pyrit/prompt_target/openai/openai_completion_target.py b/pyrit/prompt_target/openai/openai_completion_target.py index e0000c148a..5017116f03 100644 --- a/pyrit/prompt_target/openai/openai_completion_target.py +++ b/pyrit/prompt_target/openai/openai_completion_target.py @@ -19,7 +19,7 @@ class OpenAICompletionTarget(OpenAITarget): """A prompt target for OpenAI completion endpoints.""" - _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=False) + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_message_pieces=False) def __init__( self, @@ -167,21 +167,3 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> extracted_response = [choice.text for choice in response.choices] return construct_response_from_request(request=request, response_text_pieces=extracted_response) - - def _validate_request(self, *, message: Message) -> None: - n_pieces = len(message.message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") - - piece_type = message.message_pieces[0].converted_value_data_type - if piece_type != "text": - raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") - - def is_json_response_supported(self) -> bool: - """ - Check if the target supports JSON as a response format. - - Returns: - bool: True if JSON response is supported, False otherwise. - """ - return False diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 8734adb776..5e926d2056 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -28,7 +28,11 @@ class OpenAIImageTarget(OpenAITarget): # Maximum number of image inputs supported by the OpenAI image API _MAX_INPUT_IMAGES = 16 - _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=False) + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + supports_multi_turn=False, + input_modalities=["text", "image_path"], + output_modalities=["image_path"], + ) def __init__( self, @@ -45,17 +49,6 @@ def __init__( Initialize the image target with specified parameters. Args: - model_name (str, Optional): The name of the model (or deployment name in Azure). - If no value is provided, the OPENAI_IMAGE_MODEL environment variable will be used. - endpoint (str, Optional): The target URL for the OpenAI service. - api_key (str | Callable[[], str], Optional): The API key for accessing the OpenAI service, - or a callable that returns an access token. For Azure endpoints with Entra authentication, - pass a token provider from pyrit.auth (e.g., get_azure_openai_auth(endpoint)). - Defaults to the `OPENAI_IMAGE_API_KEY` environment variable. - headers (str, Optional): Headers of the endpoint (JSON). - max_requests_per_minute (int, Optional): Number of requests the target can handle per - minute before hitting a rate limit. The number of requests sent to the target - will be capped at the value provided. image_size (Literal, Optional): The size of the generated image. Accepts "256x256", "512x512", "1024x1024", "1536x1024", "1024x1536", "1792x1024", or "1024x1792". @@ -297,14 +290,10 @@ async def _get_image_bytes(self, image_data: Any) -> bytes: raise EmptyResponseException(message="The image generation returned an empty response.") def _validate_request(self, *, message: Message) -> None: - n_pieces = len(message.message_pieces) - - if n_pieces < 1: - raise ValueError("The message must contain at least one piece.") + super()._validate_request(message=message) text_pieces = [p for p in message.message_pieces if p.converted_value_data_type == "text"] image_pieces = [p for p in message.message_pieces if p.converted_value_data_type == "image_path"] - other_pieces = [p for p in message.message_pieces if p.converted_value_data_type not in ("text", "image_path")] if len(text_pieces) != 1: raise ValueError(f"The message must contain exactly one text piece. Received: {len(text_pieces)}.") @@ -313,26 +302,3 @@ def _validate_request(self, *, message: Message) -> None: raise ValueError( f"The message can contain up to {self._MAX_INPUT_IMAGES} image pieces. Received: {len(image_pieces)}." ) - - if len(other_pieces) > 0: - other_types = [p.converted_value_data_type for p in other_pieces] - raise ValueError(f"The message contains unsupported piece types. Unsupported types: {other_types}.") - - request = text_pieces[0] - messages = self._memory.get_conversation(conversation_id=request.conversation_id) - - n_messages = len(messages) - if n_messages > 0: - raise ValueError( - "This target only supports a single turn conversation. " - f"Received: {n_messages} messages which indicates a prior turn." - ) - - def is_json_response_supported(self) -> bool: - """ - Check if the target supports JSON as a response format. - - Returns: - bool: True if JSON response is supported, False otherwise. - """ - return False diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index d3bd23f829..d7bd1eb42b 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -68,7 +68,12 @@ class RealtimeTarget(OpenAITarget, PromptChatTarget): and https://platform.openai.com/docs/guides/realtime-websocket """ - _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=True) + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + supports_multi_turn=True, + supports_multi_message_pieces=False, + input_modalities=["text", "audio_path"], + output_modalities=["text", "audio_path"], + ) def __init__( self, @@ -771,32 +776,3 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> This implementation exists to satisfy the abstract base class requirement. """ raise NotImplementedError("RealtimeTarget uses receive_events for message construction") - - def _validate_request(self, *, message: Message) -> None: - """ - Validate the structure and content of a message for compatibility of this target. - - Args: - message (Message): The message object. - - Raises: - ValueError: If more than two message pieces are provided. - ValueError: If any of the message pieces have a data type other than 'text' or 'audio_path'. - """ - # Check the number of message pieces - n_pieces = len(message.message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports one message piece. Received: {n_pieces} pieces.") - - piece_type = message.message_pieces[0].converted_value_data_type - if piece_type not in ["text", "audio_path"]: - raise ValueError(f"This target only supports text and audio_path prompt input. Received: {piece_type}.") - - def is_json_response_supported(self) -> bool: - """ - Check if the target supports JSON as a response format. - - Returns: - bool: True if JSON response is supported, False otherwise. - """ - return False diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 9951b6db92..03b7fde09d 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -29,6 +29,7 @@ ) from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.utils import limit_requests_per_minute, validate_temperature, validate_top_p from pyrit.prompt_target.openai.openai_error_handling import _is_content_filter_error from pyrit.prompt_target.openai.openai_target import OpenAITarget @@ -67,6 +68,19 @@ class OpenAIResponseTarget(OpenAITarget, PromptChatTarget): https://platform.openai.com/docs/api-reference/responses/create """ + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + supports_multi_turn=True, + supports_json_response=True, + input_modalities=[ + "text", + "image_path", + "function_call", + "tool_call", + "function_call_output", + "reasoning", + ], + ) + def __init__( self, *, @@ -85,12 +99,6 @@ def __init__( Args: custom_functions: Mapping of user-defined function names (e.g., "my_func"). - model_name (str, Optional): The name of the model (or deployment name in Azure). - If no value is provided, the OPENAI_RESPONSES_MODEL environment variable will be used. - endpoint (str, Optional): The target URL for the OpenAI service. - api_key (str, Optional): The API key for accessing the Azure OpenAI service. - Defaults to the OPENAI_RESPONSES_KEY environment variable. - headers (str, Optional): Headers of the endpoint (JSON). max_requests_per_minute (int, Optional): Number of requests the target can handle per minute before hitting a rate limit. The number of requests sent to the target will be capped at the value provided. @@ -108,18 +116,12 @@ def __init__( reasoning_summary (Literal["auto", "concise", "detailed"], Optional): Controls whether a summary of the model's reasoning is included in the response. Defaults to None (no summary). - is_json_supported (bool, Optional): If True, the target will support formatting responses as JSON by - setting the response_format header. Official OpenAI models all support this, but if you are using - this target with different models, is_json_supported should be set correctly to avoid issues when - using adversarial infrastructure (e.g. Crescendo scorers will set this flag). extra_body_parameters (dict, Optional): Additional parameters to be included in the request body. fail_on_missing_function: if True, raise when a function_call references an unknown function or does not output a function; if False, return a structured error so we can wrap it as function_call_output and let the model potentially recover (e.g., pick another tool or ask for clarification). **kwargs: Additional keyword arguments passed to the parent OpenAITarget class. - httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the ``httpx.AsyncClient()`` - constructor. For example, to specify a 3 minute timeout: ``httpx_client_kwargs={"timeout": 180}`` Raises: PyritException: If the temperature or top_p values are out of bounds. @@ -563,15 +565,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: # Return all responses (normalizer will persist all of them to memory) return responses_to_return - def is_json_response_supported(self) -> bool: - """ - Check if the target supports JSON as a response format. - - Returns: - bool: True if JSON response is supported, False otherwise. - """ - return True - def _parse_response_output_section( self, *, section: Any, message_piece: MessagePiece, error: Optional[PromptResponseError] ) -> MessagePiece | None: @@ -672,31 +665,6 @@ def _parse_response_output_section( response_error=error or "none", ) - def _validate_request(self, *, message: Message) -> None: - """ - Validate the structure and content of a message for compatibility of this target. - - Args: - message (Message): The message object. - - Raises: - ValueError: If any of the message pieces have a data type other than supported set. - """ - # Some models may not support all of these; we accept them at the transport layer - # so the Responses API can decide. We include reasoning and function_call_output now. - allowed_types = { - "text", - "image_path", - "function_call", - "tool_call", - "function_call_output", - "reasoning", - } - for message_piece in message.message_pieces: - if message_piece.converted_value_data_type not in allowed_types: - raise ValueError(f"Unsupported data type: {message_piece.converted_value_data_type}") - return - # Agentic helpers (module scope) def _find_last_pending_tool_call(self, reply: Message) -> Optional[dict[str, Any]]: diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index 0128991e3f..10e99d822e 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -111,7 +111,7 @@ def __init__( max_requests_per_minute: Optional[int] = None, httpx_client_kwargs: Optional[dict[str, Any]] = None, underlying_model: Optional[str] = None, - capabilities: Optional[TargetCapabilities] = None, + custom_capabilities: Optional[TargetCapabilities] = None, ) -> None: """ Initialize an instance of OpenAITarget. @@ -139,7 +139,7 @@ def __init__( from the actual model. If not provided, will attempt to fetch from environment variable. If it is not there either, the identifier "model_name" attribute will use the model_name. Defaults to None. - capabilities (TargetCapabilities, Optional): Override the default capabilities for + custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for this target instance. If None, uses the class-level defaults. Defaults to None. Raises: @@ -176,7 +176,7 @@ def __init__( endpoint=endpoint_value, model_name=self._model_name, underlying_model=underlying_model_value, - capabilities=capabilities, + custom_capabilities=custom_capabilities, ) # API key: use passed value, env var, or fall back to Entra ID for Azure endpoints @@ -708,11 +708,11 @@ def _warn_if_irregular_endpoint(self, expected_url_regex: list[str]) -> None: f"For more details and guidance, please see the .env_example file in the repository." ) - @abstractmethod def is_json_response_supported(self) -> bool: """ - Abstract method to determine if JSON response format is supported by the target. + Determine if JSON response format is supported by the target. Returns: bool: True if JSON response is supported, False otherwise. """ + return self._capabilities.supports_json_response diff --git a/pyrit/prompt_target/openai/openai_tts_target.py b/pyrit/prompt_target/openai/openai_tts_target.py index 130bf7274a..c20c3746ab 100644 --- a/pyrit/prompt_target/openai/openai_tts_target.py +++ b/pyrit/prompt_target/openai/openai_tts_target.py @@ -27,7 +27,10 @@ class OpenAITTSTarget(OpenAITarget): """A prompt target for OpenAI Text-to-Speech (TTS) endpoints.""" - _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=False) + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + supports_multi_message_pieces=False, + output_modalities=["audio_path"], + ) def __init__( self, @@ -167,31 +170,3 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> return construct_response_from_request( request=request, response_text_pieces=[str(audio_response.value)], response_type="audio_path" ) - - def _validate_request(self, *, message: Message) -> None: - n_pieces = len(message.message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") - - piece_type = message.message_pieces[0].converted_value_data_type - if piece_type != "text": - raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") - - request = message.message_pieces[0] - messages = self._memory.get_conversation(conversation_id=request.conversation_id) - - n_messages = len(messages) - if n_messages > 0: - raise ValueError( - "This target only supports a single turn conversation. " - f"Received: {n_messages} messages which indicates a prior turn." - ) - - def is_json_response_supported(self) -> bool: - """ - Check if the target supports JSON as a response format. - - Returns: - bool: True if JSON response is supported, False otherwise. - """ - return False diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index f09f5bd679..7a662b4070 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -52,7 +52,11 @@ class OpenAIVideoTarget(OpenAITarget): SUPPORTED_RESOLUTIONS: list[VideoSize] = ["720x1280", "1280x720", "1024x1792", "1792x1024"] SUPPORTED_DURATIONS: list[VideoSeconds] = ["4", "8", "12"] SUPPORTED_IMAGE_FORMATS: list[str] = ["image/jpeg", "image/png", "image/webp"] - _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=False) + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + supports_multi_turn=False, + input_modalities=["text", "image_path"], + output_modalities=["video_path"], + ) def __init__( self, @@ -460,6 +464,8 @@ def _validate_request(self, *, message: Message) -> None: Raises: ValueError: If the request is invalid. """ + super()._validate_request(message=message) + text_pieces = message.get_pieces_by_type(data_type="text") image_pieces = message.get_pieces_by_type(data_type="image_path") video_pieces = message.get_pieces_by_type(data_type="video_path") @@ -494,24 +500,6 @@ def _validate_request(self, *, message: Message) -> None: if video_pieces and image_pieces: raise ValueError("Cannot combine video_path and image_path pieces.") - messages = self._memory.get_conversation(conversation_id=text_piece.conversation_id) - - n_messages = len(messages) - if n_messages > 0: - raise ValueError( - "This target only supports a single turn conversation. " - f"Received: {n_messages} messages which indicates a prior turn." - ) - - def is_json_response_supported(self) -> bool: - """ - Check if the target supports JSON response data. - - Returns: - bool: False, as video generation doesn't return JSON content. - """ - return False - @staticmethod def _validate_video_remix_pieces(*, message: Message) -> None: """ diff --git a/pyrit/prompt_target/playwright_copilot_target.py b/pyrit/prompt_target/playwright_copilot_target.py index 520a8d4888..db606dd2bc 100644 --- a/pyrit/prompt_target/playwright_copilot_target.py +++ b/pyrit/prompt_target/playwright_copilot_target.py @@ -7,7 +7,7 @@ from contextlib import suppress from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( @@ -79,7 +79,11 @@ class PlaywrightCopilotTarget(PromptTarget): # Supported data types SUPPORTED_DATA_TYPES = {"text", "image_path"} - _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=True) + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + supports_multi_turn=True, + input_modalities=["text", "image_path"], + output_modalities=["text", "image_path"], + ) # Placeholder text constants PLACEHOLDER_GENERATING_RESPONSE: str = "generating response" @@ -109,7 +113,6 @@ def __init__( *, page: "Page", copilot_type: CopilotType = CopilotType.CONSUMER, - capabilities: Optional[TargetCapabilities] = None, ) -> None: """ Initialize the Playwright Copilot target. @@ -118,14 +121,12 @@ def __init__( page (Page): The Playwright page object for browser interaction. copilot_type (CopilotType): The type of Copilot to interact with. Defaults to CopilotType.CONSUMER. - capabilities (TargetCapabilities, Optional): Override the default capabilities for - this target instance. If None, uses the class-level defaults. Defaults to None. Raises: RuntimeError: If the Playwright page is not initialized. ValueError: If the page URL doesn't match the specified copilot_type. """ - super().__init__(capabilities=capabilities) + super().__init__() self._page = page self._type = copilot_type @@ -862,26 +863,3 @@ async def _check_login_requirement_async(self) -> None: sign_in_header_present = sign_in_header_count > 0 if sign_in_header_present: raise RuntimeError("Login required to access advanced features in Consumer Copilot.") - - def _validate_request(self, *, message: Message) -> None: - """ - Validate that the message is compatible with Copilot. - - Args: - message: The message to validate. - - Raises: - ValueError: If the message has no pieces. - ValueError: If any piece has an unsupported data type. - """ - if not message.message_pieces: - raise ValueError("This target requires at least one message piece.") - - # Validate that all pieces are supported types - for i, piece in enumerate(message.message_pieces): - piece_type = piece.converted_value_data_type - if piece_type not in self.SUPPORTED_DATA_TYPES: - supported_types = ", ".join(self.SUPPORTED_DATA_TYPES) - raise ValueError( - f"This target only supports {supported_types} prompt input. Piece {i} has type: {piece_type}." - ) diff --git a/pyrit/prompt_target/playwright_target.py b/pyrit/prompt_target/playwright_target.py index 62d488110e..c073758418 100644 --- a/pyrit/prompt_target/playwright_target.py +++ b/pyrit/prompt_target/playwright_target.py @@ -52,7 +52,10 @@ class PlaywrightTarget(PromptTarget): # Supported data types SUPPORTED_DATA_TYPES = {"text", "image_path"} - _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=True) + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + supports_multi_turn=True, + input_modalities=["text", "image_path"], + ) def __init__( self, @@ -60,7 +63,6 @@ def __init__( interaction_func: InteractionFunction, page: "Page", max_requests_per_minute: Optional[int] = None, - capabilities: Optional[TargetCapabilities] = None, ) -> None: """ Initialize the Playwright target. @@ -71,11 +73,9 @@ def __init__( max_requests_per_minute (int, Optional): Number of requests the target can handle per minute before hitting a rate limit. The number of requests sent to the target will be capped at the value provided. - capabilities (TargetCapabilities, Optional): Override the default capabilities for - this target instance. If None, uses the class-level defaults. Defaults to None. """ endpoint = page.url if page else "" - super().__init__(max_requests_per_minute=max_requests_per_minute, endpoint=endpoint, capabilities=capabilities) + super().__init__(max_requests_per_minute=max_requests_per_minute, endpoint=endpoint) self._interaction_func = interaction_func self._page = page @@ -109,16 +109,3 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: request_piece = message.message_pieces[0] response_entry = construct_response_from_request(request=request_piece, response_text_pieces=[text]) return [response_entry] - - def _validate_request(self, *, message: Message) -> None: - if not message.message_pieces: - raise ValueError("This target requires at least one message piece.") - - # Validate that all pieces are supported types - for i, piece in enumerate(message.message_pieces): - piece_type = piece.converted_value_data_type - if piece_type not in self.SUPPORTED_DATA_TYPES: - supported_types = ", ".join(self.SUPPORTED_DATA_TYPES) - raise ValueError( - f"This target only supports {supported_types} input. Piece {i} has type: {piece_type}." - ) diff --git a/pyrit/prompt_target/prompt_shield_target.py b/pyrit/prompt_target/prompt_shield_target.py index fe1d3e760f..c7cfc4794e 100644 --- a/pyrit/prompt_target/prompt_shield_target.py +++ b/pyrit/prompt_target/prompt_shield_target.py @@ -3,17 +3,17 @@ import json import logging -from collections.abc import Callable, Sequence +from collections.abc import Callable from typing import Any, Literal, Optional from pyrit.common import default_values, net_utility from pyrit.identifiers import ComponentIdentifier from pyrit.models import ( Message, - MessagePiece, construct_response_from_request, ) from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.utils import limit_requests_per_minute logger = logging.getLogger(__name__) @@ -49,6 +49,8 @@ class PromptShieldTarget(PromptTarget): ENDPOINT_URI_ENVIRONMENT_VARIABLE: str = "AZURE_CONTENT_SAFETY_API_ENDPOINT" API_KEY_ENVIRONMENT_VARIABLE: str = "AZURE_CONTENT_SAFETY_API_KEY" + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_message_pieces=False) + _endpoint: str _api_key: str | Callable[[], str] | None _api_version: str @@ -163,17 +165,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: return [response_entry] - def _validate_request(self, *, message: Message) -> None: - message_pieces: Sequence[MessagePiece] = message.message_pieces - - n_pieces = len(message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") - - piece_type = message_pieces[0].converted_value_data_type - if piece_type != "text": - raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") - def _validate_response(self, request_body: dict[str, Any], response_body: dict[str, Any]) -> None: """ Ensure that every field sent to the Prompt Shield was analyzed. diff --git a/pyrit/prompt_target/websocket_copilot_target.py b/pyrit/prompt_target/websocket_copilot_target.py index ad9ed2c641..39a80b66bc 100644 --- a/pyrit/prompt_target/websocket_copilot_target.py +++ b/pyrit/prompt_target/websocket_copilot_target.py @@ -71,10 +71,12 @@ class WebSocketCopilotTarget(PromptTarget): The free version of Copilot is not compatible. """ - SUPPORTED_DATA_TYPES = {"text", "image_path"} RESPONSE_TIMEOUT_SECONDS: int = 60 CONNECTION_TIMEOUT_SECONDS: int = 30 - _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_turn=True) + _DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities( + supports_multi_turn=True, + input_modalities=["text", "image_path"], + ) def __init__( self, @@ -84,7 +86,6 @@ def __init__( model_name: str = "copilot", response_timeout_seconds: int = RESPONSE_TIMEOUT_SECONDS, authenticator: Optional[Union[CopilotAuthenticator, ManualCopilotAuthenticator]] = None, - capabilities: Optional[TargetCapabilities] = None, ) -> None: """ Initialize the WebSocketCopilotTarget. @@ -98,8 +99,6 @@ def __init__( authenticator (Optional[Union[CopilotAuthenticator, ManualCopilotAuthenticator]]): Authenticator instance. Supports both ``CopilotAuthenticator`` and ``ManualCopilotAuthenticator``. If None, a new ``CopilotAuthenticator`` instance will be created with default settings. - capabilities (TargetCapabilities, Optional): Override the default capabilities for - this target instance. If None, uses the class-level defaults. Defaults to None. Raises: ValueError: If ``response_timeout_seconds`` is not a positive integer. @@ -122,7 +121,6 @@ def __init__( max_requests_per_minute=max_requests_per_minute, endpoint=self._websocket_base_url, model_name=model_name, - capabilities=capabilities, ) def _build_identifier(self) -> ComponentIdentifier: @@ -579,13 +577,9 @@ def _validate_request(self, *, message: Message) -> None: Raises: ValueError: If message contains unsupported data types or invalid image formats. """ + super()._validate_request(message=message) for piece in message.message_pieces: piece_type = piece.converted_value_data_type - if piece_type not in self.SUPPORTED_DATA_TYPES: - supported_types = ", ".join(sorted(self.SUPPORTED_DATA_TYPES)) - raise ValueError( - f"This target supports only the following data types: {supported_types}. Received: {piece_type}." - ) if piece_type == "image_path": mime_type = DataTypeSerializer.get_mime_type(piece.converted_value) diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index 1c997f3326..5b872eb014 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -83,6 +83,3 @@ def _validate_request(self, *, message: Message) -> None: """ Validates the provided message """ - - def is_json_response_supported(self) -> bool: - return False diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 2c0064cee7..0bfa55f609 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -167,9 +167,6 @@ def _validate_request(self, *, message: Message) -> None: Validates the provided message """ - def is_json_response_supported(self) -> bool: - return False - def get_azure_sql_memory() -> Generator[AzureSQLMemory, None, None]: # Create a test Azure SQL Server DB using in-memory SQLite diff --git a/tests/unit/registry/test_target_registry.py b/tests/unit/registry/test_target_registry.py index 0db391c6a0..503d096a38 100644 --- a/tests/unit/registry/test_target_registry.py +++ b/tests/unit/registry/test_target_registry.py @@ -54,9 +54,6 @@ async def send_prompt_async( def _validate_request(self, *, message: Message) -> None: pass - def is_json_response_supported(self) -> bool: - return False - class TestTargetRegistrySingleton: """Tests for the singleton pattern in TargetRegistry.""" diff --git a/tests/unit/target/test_azure_ml_chat_target.py b/tests/unit/target/test_azure_ml_chat_target.py index a8cc386d16..a4714f96b8 100644 --- a/tests/unit/target/test_azure_ml_chat_target.py +++ b/tests/unit/target/test_azure_ml_chat_target.py @@ -197,10 +197,6 @@ async def test_send_prompt_async_empty_response_retries(aml_online_chat: AzureML assert mock_complete_chat_async.call_count == 2 -def test_is_json_response_supported(aml_online_chat: AzureMLChatTarget): - assert aml_online_chat.is_json_response_supported() is False - - def test_invalid_temperature_too_low_raises(patch_central_database): with pytest.raises(Exception, match="temperature must be between 0 and 2"): AzureMLChatTarget( diff --git a/tests/unit/target/test_azure_openai_completion_target.py b/tests/unit/target/test_azure_openai_completion_target.py index a4bb4160eb..fe7c497c0d 100644 --- a/tests/unit/target/test_azure_openai_completion_target.py +++ b/tests/unit/target/test_azure_openai_completion_target.py @@ -60,7 +60,7 @@ async def test_azure_completion_validate_request_length(azure_completion_target: @pytest.mark.asyncio async def test_azure_completion_validate_prompt_type(azure_completion_target: OpenAICompletionTarget): request = Message(message_pieces=[get_image_message_piece()]) - with pytest.raises(ValueError, match="This target only supports text prompt input."): + with pytest.raises(ValueError, match="This target supports only the following data types"): await azure_completion_target.send_prompt_async(message=request) diff --git a/tests/unit/target/test_crucible_target.py b/tests/unit/target/test_crucible_target.py index 09962e0768..85a0f40f7b 100644 --- a/tests/unit/target/test_crucible_target.py +++ b/tests/unit/target/test_crucible_target.py @@ -40,5 +40,5 @@ async def test_crucible_validate_request_length(crucible_target: CrucibleTarget) async def test_crucible_validate_prompt_type(crucible_target: CrucibleTarget): message_piece = get_image_message_piece() request = Message(message_pieces=[message_piece]) - with pytest.raises(ValueError, match="This target only supports text prompt input."): + with pytest.raises(ValueError, match="This target supports only the following data types"): await crucible_target.send_prompt_async(message=request) diff --git a/tests/unit/target/test_gandalf_target.py b/tests/unit/target/test_gandalf_target.py index 7f699452cf..5c5bfbdbb7 100644 --- a/tests/unit/target/test_gandalf_target.py +++ b/tests/unit/target/test_gandalf_target.py @@ -39,5 +39,5 @@ async def test_gandalf_validate_request_length(gandalf_target: GandalfTarget): @pytest.mark.asyncio async def test_gandalf_validate_prompt_type(gandalf_target: GandalfTarget): request = Message(message_pieces=[get_image_message_piece()]) - with pytest.raises(ValueError, match="This target only supports text prompt input."): + with pytest.raises(ValueError, match="This target supports only the following data types"): await gandalf_target.send_prompt_async(message=request) diff --git a/tests/unit/target/test_http_target.py b/tests/unit/target/test_http_target.py index 6e977edf7c..56368dcc43 100644 --- a/tests/unit/target/test_http_target.py +++ b/tests/unit/target/test_http_target.py @@ -67,7 +67,12 @@ def test_http_target_sets_endpoint_and_rate_limit(mock_callback_function, sqlite async def test_send_prompt_async(mock_request, mock_http_target, mock_http_response): message = MagicMock() message.message_pieces = [ - MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + MagicMock( + converted_value="test_prompt", + converted_value_data_type="text", + prompt_target_identifier=None, + attack_identifier=None, + ) ] mock_request.return_value = mock_http_response response = await mock_http_target.send_prompt_async(message=message) @@ -105,7 +110,7 @@ def test_parse_raw_http_respects_url_path(patch_central_database): @pytest.mark.asyncio -async def test_send_prompt_async_client_kwargs(): +async def test_send_prompt_async_client_kwargs(patch_central_database): with patch("httpx.AsyncClient.request", new_callable=AsyncMock) as mock_request: # Create httpx_client_kwargs to test httpx_client_kwargs = {"timeout": 10, "verify": False} @@ -114,7 +119,14 @@ async def test_send_prompt_async_client_kwargs(): # Use **httpx_client_kwargs to pass them as keyword arguments http_target = HTTPTarget(http_request=sample_request, **httpx_client_kwargs) message = MagicMock() - message.message_pieces = [MagicMock(converted_value="", prompt_target_identifier=None, attack_identifier=None)] + message.message_pieces = [ + MagicMock( + converted_value="", + converted_value_data_type="text", + prompt_target_identifier=None, + attack_identifier=None, + ) + ] mock_response = MagicMock() mock_response.content = b"Response content" mock_request.return_value = mock_response @@ -150,7 +162,12 @@ async def test_send_prompt_regex_parse_async(mock_request, mock_http_target): message = MagicMock() message.message_pieces = [ - MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + MagicMock( + converted_value="test_prompt", + converted_value_data_type="text", + prompt_target_identifier=None, + attack_identifier=None, + ) ] mock_response = MagicMock() @@ -179,7 +196,12 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http # Send first prompt message = MagicMock() message.message_pieces = [ - MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + MagicMock( + converted_value="test_prompt", + converted_value_data_type="text", + prompt_target_identifier=None, + attack_identifier=None, + ) ] response = await mock_http_target.send_prompt_async(message=message) @@ -199,7 +221,12 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http # Send second prompt second_message = MagicMock() second_message.message_pieces = [ - MagicMock(converted_value="second_test_prompt", prompt_target_identifier=None, attack_identifier=None) + MagicMock( + converted_value="second_test_prompt", + converted_value_data_type="text", + prompt_target_identifier=None, + attack_identifier=None, + ) ] await mock_http_target.send_prompt_async(message=second_message) @@ -226,7 +253,7 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http @pytest.mark.asyncio -async def test_http_target_with_injected_client(): +async def test_http_target_with_injected_client(patch_central_database): custom_client = httpx.AsyncClient(timeout=30.0, verify=False, headers={"X-Custom-Header": "test_value"}) sample_request = ( @@ -249,7 +276,12 @@ async def test_http_target_with_injected_client(): message = MagicMock() message.message_pieces = [ - MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + MagicMock( + converted_value="test_prompt", + converted_value_data_type="text", + prompt_target_identifier=None, + attack_identifier=None, + ) ] response = await target.send_prompt_async(message=message) diff --git a/tests/unit/target/test_huggingface_chat_target.py b/tests/unit/target/test_huggingface_chat_target.py index 77052fb1b0..0de8ddbbf3 100644 --- a/tests/unit/target/test_huggingface_chat_target.py +++ b/tests/unit/target/test_huggingface_chat_target.py @@ -318,15 +318,6 @@ async def test_optional_kwargs_args_passed_when_loading_model(mock_transformers) assert call_args.get("attn_implementation") == "flash_attention_2" -@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") -@pytest.mark.asyncio -async def test_is_json_response_supported(): - hf_chat = HuggingFaceChatTarget(model_id="dummy", use_cuda=False, trust_remote_code=True) - # Await the background task to prevent warnings - await hf_chat.load_model_and_tokenizer_task - assert hf_chat.is_json_response_supported() is False - - @pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed") @pytest.mark.asyncio async def test_hugging_face_chat_sets_endpoint_and_rate_limit(patch_central_database): diff --git a/tests/unit/target/test_image_target.py b/tests/unit/target/test_image_target.py index 4c5056e247..294a41b8ae 100644 --- a/tests/unit/target/test_image_target.py +++ b/tests/unit/target/test_image_target.py @@ -408,15 +408,6 @@ async def test_send_prompt_async_url_response_downloads_image( os.remove(path) -def test_is_json_response_supported(patch_central_database): - mock_memory = MagicMock() - mock_memory.get_conversation.return_value = [] - mock_memory.add_message_to_memory = AsyncMock() - - mock_image_target = OpenAIImageTarget(model_name="test", endpoint="test", api_key="test") - assert mock_image_target.is_json_response_supported() is False - - @pytest.mark.asyncio async def test_validate_no_text_piece(image_target: OpenAIImageTarget): image_piece = get_image_message_piece() @@ -498,7 +489,7 @@ async def test_validate_piece_type(image_target: OpenAIImageTarget): request = Message(message_pieces=[audio_piece, text_piece]) with pytest.raises( ValueError, - match=f"The message contains unsupported piece types.", + match="This target supports only the following data types", ): await image_target.send_prompt_async(message=request) finally: @@ -513,7 +504,7 @@ async def test_validate_previous_conversations( message_piece = sample_conversations[0] mock_memory = MagicMock() - mock_memory.get_conversation.return_value = sample_conversations + mock_memory.get_message_pieces.return_value = sample_conversations mock_memory.add_message_to_memory = AsyncMock() image_target._memory = mock_memory diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 846efe3536..1a60e3257e 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -105,35 +105,6 @@ def test_init_with_no_additional_request_headers_var_raises(): OpenAIChatTarget(model_name="gpt-4", endpoint="", api_key="xxxxx", headers="") -def test_init_is_json_supported_defaults_to_true(patch_central_database): - target = OpenAIChatTarget( - model_name="gpt-4", - endpoint="https://mock.azure.com/", - api_key="mock-api-key", - ) - assert target.is_json_response_supported() is True - - -def test_init_is_json_supported_can_be_set_to_false(patch_central_database): - target = OpenAIChatTarget( - model_name="gpt-4", - endpoint="https://mock.azure.com/", - api_key="mock-api-key", - is_json_supported=False, - ) - assert target.is_json_response_supported() is False - - -def test_init_is_json_supported_can_be_set_to_true(patch_central_database): - target = OpenAIChatTarget( - model_name="gpt-4", - endpoint="https://mock.azure.com/", - api_key="mock-api-key", - is_json_supported=True, - ) - assert target.is_json_response_supported() is True - - @pytest.mark.asyncio() async def test_build_chat_messages_for_multi_modal(target: OpenAIChatTarget): image_request = get_image_message_piece() @@ -556,17 +527,13 @@ def test_validate_request_unsupported_data_types(target: OpenAIChatTarget): with pytest.raises(ValueError) as excinfo: target._validate_request(message=message) - assert "This target only supports text, image_path, and audio_path." in str(excinfo.value), ( + assert "This target supports only the following data types" in str(excinfo.value), ( "Error not raised for unsupported data types" ) os.remove(image_piece.original_value) -def test_is_json_response_supported(target: OpenAIChatTarget): - assert target.is_json_response_supported() is True - - def test_inheritance_from_prompt_chat_target(target: OpenAIChatTarget): """Test that OpenAIChatTarget properly inherits from PromptChatTarget.""" assert isinstance(target, PromptChatTarget), "OpenAIChatTarget must inherit from PromptChatTarget" diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/target/test_openai_response_target.py index 2c6fd598fc..507f8f0935 100644 --- a/tests/unit/target/test_openai_response_target.py +++ b/tests/unit/target/test_openai_response_target.py @@ -586,15 +586,13 @@ def test_validate_request_unsupported_data_types(target: OpenAIResponseTarget): with pytest.raises(ValueError) as excinfo: target._validate_request(message=message) - assert "Unsupported data type" in str(excinfo.value), "Error not raised for unsupported data types" + assert "This target supports only the following data types" in str(excinfo.value), ( + "Error not raised for unsupported data types" + ) os.remove(image_piece.original_value) -def test_is_json_response_supported(target: OpenAIResponseTarget): - assert target.is_json_response_supported() is True - - def test_inheritance_from_prompt_chat_target(target: OpenAIResponseTarget): """Test that OpenAIResponseTarget properly inherits from PromptChatTarget.""" assert isinstance(target, PromptChatTarget), "OpenAIResponseTarget must inherit from PromptChatTarget" @@ -670,11 +668,7 @@ def test_validate_request_raises_for_invalid_type(target: OpenAIResponseTarget): ) with pytest.raises(ValueError) as excinfo: target._validate_request(message=req) - assert "Unsupported data type" in str(excinfo.value) - - -def test_is_json_response_supported_returns_true(target: OpenAIResponseTarget): - assert target.is_json_response_supported() is True + assert "This target supports only the following data types" in str(excinfo.value) @pytest.mark.asyncio diff --git a/tests/unit/target/test_openai_target_auth.py b/tests/unit/target/test_openai_target_auth.py index 2045ae6e20..f3973efbcb 100644 --- a/tests/unit/target/test_openai_target_auth.py +++ b/tests/unit/target/test_openai_target_auth.py @@ -27,9 +27,6 @@ def _get_target_api_paths(self) -> list[str]: def _get_provider_examples(self) -> dict[str, str]: return {} - def is_json_response_supported(self) -> bool: - return True - async def _construct_message_from_response(self, response, request): raise NotImplementedError diff --git a/tests/unit/target/test_playwright_copilot_target.py b/tests/unit/target/test_playwright_copilot_target.py index 07ecb45226..40e113fe2c 100644 --- a/tests/unit/target/test_playwright_copilot_target.py +++ b/tests/unit/target/test_playwright_copilot_target.py @@ -150,9 +150,7 @@ def test_validate_request_unsupported_type(self, mock_page): ) request = Message(message_pieces=[unsupported_piece]) - with pytest.raises( - ValueError, match=r"This target only supports .* prompt input\. Piece 0 has type: audio_path\." - ): + with pytest.raises(ValueError, match=r"This target supports only the following data types"): target._validate_request(message=request) def test_validate_request_valid_text(self, mock_page, text_request_piece): diff --git a/tests/unit/target/test_playwright_target.py b/tests/unit/target/test_playwright_target.py index 23468a13d2..dd1f579cab 100644 --- a/tests/unit/target/test_playwright_target.py +++ b/tests/unit/target/test_playwright_target.py @@ -124,7 +124,7 @@ def test_validate_request_unsupported_type(self, mock_interaction_func, mock_pag ) request = Message(message_pieces=[unsupported_piece]) - with pytest.raises(ValueError, match=r"This target only supports .* input\. Piece 0 has type: audio_path\."): + with pytest.raises(ValueError, match=r"This target supports only the following data types"): target._validate_request(message=request) def test_validate_request_valid_text(self, mock_interaction_func, mock_page, text_message_piece): @@ -336,7 +336,7 @@ def test_validate_request_multiple_unsupported_types(self, mock_interaction_func request = Message(message_pieces=unsupported_pieces) # Should fail on the first unsupported type - with pytest.raises(ValueError, match=r"This target only supports .* input\. Piece 0 has type: audio_path\."): + with pytest.raises(ValueError, match=r"This target supports only the following data types"): target._validate_request(message=request) @pytest.mark.asyncio diff --git a/tests/unit/target/test_prompt_target_azure_blob_storage.py b/tests/unit/target/test_prompt_target_azure_blob_storage.py index 9fbe5f28b9..163a87a60e 100644 --- a/tests/unit/target/test_prompt_target_azure_blob_storage.py +++ b/tests/unit/target/test_prompt_target_azure_blob_storage.py @@ -81,7 +81,7 @@ async def test_azure_blob_storage_validate_prompt_type( ): mock_upload_async.return_value = None request = Message(message_pieces=[get_image_message_piece()]) - with pytest.raises(ValueError, match="This target only supports text and url prompt input."): + with pytest.raises(ValueError, match="This target supports only the following data types"): await azure_blob_storage_target.send_prompt_async(message=request) diff --git a/tests/unit/target/test_realtime_target.py b/tests/unit/target/test_realtime_target.py index 12c157ad40..25f68b7c92 100644 --- a/tests/unit/target/test_realtime_target.py +++ b/tests/unit/target/test_realtime_target.py @@ -181,7 +181,8 @@ async def test_send_prompt_async_invalid_request(target): with pytest.raises(ValueError) as excinfo: target._validate_request(message=message) - assert str(excinfo.value) == "This target only supports text and audio_path prompt input. Received: image_path." + assert "This target supports only the following data types" in str(excinfo.value) + assert "image_path" in str(excinfo.value) @pytest.mark.asyncio diff --git a/tests/unit/target/test_supports_multi_turn.py b/tests/unit/target/test_supports_multi_turn.py index 60f65db44f..805a2c4e17 100644 --- a/tests/unit/target/test_supports_multi_turn.py +++ b/tests/unit/target/test_supports_multi_turn.py @@ -94,7 +94,7 @@ def test_constructor_override_supports_multi_turn(self): model_name="test-model", endpoint="https://mock.azure.com/", api_key="mock-api-key", - capabilities=TargetCapabilities(supports_multi_turn=False), + custom_capabilities=TargetCapabilities(supports_multi_turn=False), ) assert target.supports_multi_turn is False @@ -113,7 +113,7 @@ def test_constructor_override_single_turn_to_multi(self): model_name="dall-e-3", endpoint="https://mock.azure.com/", api_key="mock-api-key", - capabilities=TargetCapabilities(supports_multi_turn=True), + custom_capabilities=TargetCapabilities(supports_multi_turn=True), ) assert target.supports_multi_turn is True @@ -138,7 +138,7 @@ def test_capabilities_override_via_constructor(self): model_name="test-model", endpoint="https://mock.azure.com/", api_key="mock-api-key", - capabilities=TargetCapabilities(supports_multi_turn=False), + custom_capabilities=TargetCapabilities(supports_multi_turn=False), ) caps = target.capabilities assert isinstance(caps, TargetCapabilities) diff --git a/tests/unit/target/test_target_capabilities.py b/tests/unit/target/test_target_capabilities.py new file mode 100644 index 0000000000..54bc39287d --- /dev/null +++ b/tests/unit/target/test_target_capabilities.py @@ -0,0 +1,240 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.prompt_target.common.target_capabilities import TargetCapabilities + + +@pytest.mark.usefixtures("patch_central_database") +class TestTargetCapabilitiesModalities: + """Test that each target declares the correct input/output modalities via _DEFAULT_CAPABILITIES.""" + + def test_default_capabilities_are_text_only(self): + caps = TargetCapabilities() + assert caps.input_modalities == ["text"] + assert caps.output_modalities == ["text"] + + def test_openai_chat_target_modalities(self): + from pyrit.prompt_target import OpenAIChatTarget + + target = OpenAIChatTarget( + model_name="test-model", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + assert "text" in target.capabilities.input_modalities + assert "image_path" in target.capabilities.input_modalities + assert "audio_path" in target.capabilities.input_modalities + assert "text" in target.capabilities.output_modalities + assert "audio_path" in target.capabilities.output_modalities + + def test_openai_image_target_modalities(self): + from pyrit.prompt_target import OpenAIImageTarget + + target = OpenAIImageTarget( + model_name="dall-e-3", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + assert "text" in target.capabilities.input_modalities + assert "image_path" in target.capabilities.input_modalities + assert target.capabilities.output_modalities == ["image_path"] + + def test_openai_tts_target_modalities(self): + from pyrit.prompt_target import OpenAITTSTarget + + target = OpenAITTSTarget( + model_name="tts-1", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + assert target.capabilities.input_modalities == ["text"] + assert target.capabilities.output_modalities == ["audio_path"] + + def test_openai_video_target_modalities(self): + from pyrit.prompt_target import OpenAIVideoTarget + + target = OpenAIVideoTarget( + model_name="sora-2", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + assert "text" in target.capabilities.input_modalities + assert "image_path" in target.capabilities.input_modalities + assert target.capabilities.output_modalities == ["video_path"] + + def test_openai_realtime_target_modalities(self): + from pyrit.prompt_target import RealtimeTarget + + target = RealtimeTarget( + model_name="gpt-4o-realtime", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + assert "text" in target.capabilities.input_modalities + assert "audio_path" in target.capabilities.input_modalities + assert "text" in target.capabilities.output_modalities + assert "audio_path" in target.capabilities.output_modalities + + def test_openai_response_target_modalities(self): + from pyrit.prompt_target import OpenAIResponseTarget + + target = OpenAIResponseTarget( + model_name="o1", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + assert "text" in target.capabilities.input_modalities + assert "image_path" in target.capabilities.input_modalities + assert target.capabilities.output_modalities == ["text"] + + def test_openai_completion_target_modalities(self): + from pyrit.prompt_target import OpenAICompletionTarget + + target = OpenAICompletionTarget( + model_name="test-model", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + ) + assert target.capabilities.input_modalities == ["text"] + assert target.capabilities.output_modalities == ["text"] + + def test_azure_blob_storage_target_modalities(self): + from pyrit.prompt_target import AzureBlobStorageTarget + + target = AzureBlobStorageTarget( + container_url="https://mock.blob.core.windows.net/container", + sas_token="mock-sas-token", + ) + assert "text" in target.capabilities.input_modalities + assert "url" in target.capabilities.input_modalities + assert target.capabilities.output_modalities == ["url"] + + def test_text_target_modalities(self): + from pyrit.prompt_target import TextTarget + + target = TextTarget() + assert target.capabilities.input_modalities == ["text"] + assert target.capabilities.output_modalities == ["text"] + + def test_playwright_target_modalities(self): + from unittest.mock import MagicMock + + from pyrit.prompt_target import PlaywrightTarget + + target = PlaywrightTarget( + interaction_func=MagicMock(), + page=MagicMock(), + ) + assert "text" in target.capabilities.input_modalities + assert "image_path" in target.capabilities.input_modalities + assert target.capabilities.output_modalities == ["text"] + + def test_playwright_copilot_target_modalities(self): + from unittest.mock import MagicMock + + from pyrit.prompt_target import PlaywrightCopilotTarget + + target = PlaywrightCopilotTarget(page=MagicMock()) + assert "text" in target.capabilities.input_modalities + assert "image_path" in target.capabilities.input_modalities + assert "text" in target.capabilities.output_modalities + assert "image_path" in target.capabilities.output_modalities + + def test_websocket_copilot_target_modalities(self): + from unittest.mock import MagicMock + + from pyrit.prompt_target import WebSocketCopilotTarget + + target = WebSocketCopilotTarget(authenticator=MagicMock()) + assert "text" in target.capabilities.input_modalities + assert "image_path" in target.capabilities.input_modalities + assert target.capabilities.output_modalities == ["text"] + + def test_custom_capabilities_override_modalities(self): + from pyrit.prompt_target import OpenAIChatTarget, TargetCapabilities + + custom = TargetCapabilities( + supports_multi_turn=True, + input_modalities=["text"], + output_modalities=["text"], + ) + target = OpenAIChatTarget( + model_name="test-model", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_capabilities=custom, + ) + assert target.capabilities.input_modalities == ["text"] + assert target.capabilities.output_modalities == ["text"] + + +class TestTargetCapabilitiesAssertSatisfies: + """Test the assert_satifies method including list-based modality checks.""" + + def test_assert_satisfies_passes_when_all_met(self): + caps = TargetCapabilities( + supports_multi_turn=True, + input_modalities=["text", "image_path"], + output_modalities=["text"], + ) + required = TargetCapabilities( + supports_multi_turn=True, + input_modalities=["text"], + output_modalities=["text"], + ) + caps.assert_satifies(required) # should not raise + + def test_assert_satisfies_fails_on_unmet_bool(self): + caps = TargetCapabilities(supports_multi_turn=False) + required = TargetCapabilities(supports_multi_turn=True) + with pytest.raises(ValueError, match="supports_multi_turn"): + caps.assert_satifies(required) + + def test_assert_satisfies_fails_on_missing_input_modality(self): + caps = TargetCapabilities(input_modalities=["text"]) + required = TargetCapabilities(input_modalities=["text", "image_path"]) + with pytest.raises(ValueError, match="input_modalities"): + caps.assert_satifies(required) + + def test_assert_satisfies_fails_on_missing_output_modality(self): + caps = TargetCapabilities(output_modalities=["text"]) + required = TargetCapabilities(output_modalities=["text", "audio_path"]) + with pytest.raises(ValueError, match="output_modalities"): + caps.assert_satifies(required) + + def test_assert_satisfies_passes_when_superset_of_required_modalities(self): + caps = TargetCapabilities( + input_modalities=["text", "image_path", "audio_path"], + output_modalities=["text", "audio_path"], + ) + required = TargetCapabilities( + input_modalities=["text", "image_path"], + output_modalities=["text"], + ) + caps.assert_satifies(required) # should not raise + + def test_assert_satisfies_passes_with_default_text_modalities(self): + caps = TargetCapabilities() + required = TargetCapabilities() + caps.assert_satifies(required) # should not raise + + def test_assert_satisfies_fails_multiple_unmet(self): + caps = TargetCapabilities( + supports_multi_turn=False, + input_modalities=["text"], + ) + required = TargetCapabilities( + supports_multi_turn=True, + input_modalities=["text", "image_path"], + ) + with pytest.raises(ValueError, match="supports_multi_turn") as exc_info: + caps.assert_satifies(required) + assert "input_modalities" in str(exc_info.value) + + def test_assert_satisfies_ignores_false_required_bools(self): + """When a required capability bool is False, it should not be flagged as unmet.""" + caps = TargetCapabilities(supports_multi_turn=False) + required = TargetCapabilities(supports_multi_turn=False) + caps.assert_satifies(required) # should not raise diff --git a/tests/unit/target/test_tts_target.py b/tests/unit/target/test_tts_target.py index 42ccc8a1b0..dded8f9325 100644 --- a/tests/unit/target/test_tts_target.py +++ b/tests/unit/target/test_tts_target.py @@ -72,7 +72,7 @@ async def test_tts_validate_request_length(tts_target: OpenAITTSTarget): @pytest.mark.asyncio async def test_tts_validate_prompt_type(tts_target: OpenAITTSTarget): request = Message(message_pieces=[get_image_message_piece()]) - with pytest.raises(ValueError, match="This target only supports text prompt input."): + with pytest.raises(ValueError, match="This target supports only the following data types"): await tts_target.send_prompt_async(message=request) @@ -83,7 +83,7 @@ async def test_tts_validate_previous_conversations( message_piece = sample_conversations[0] mock_memory = MagicMock() - mock_memory.get_conversation.return_value = sample_conversations + mock_memory.get_message_pieces.return_value = sample_conversations mock_memory.add_message_to_memory = AsyncMock() tts_target._memory = mock_memory @@ -184,10 +184,6 @@ async def test_tts_send_prompt_async_rate_limit_exception_retries( await tts_target.send_prompt_async(message=request) -def test_is_json_response_supported(tts_target: OpenAITTSTarget): - assert tts_target.is_json_response_supported() is False - - @pytest.mark.asyncio async def test_tts_send_prompt_with_speed_parameter( patch_central_database, diff --git a/tests/unit/target/test_video_target.py b/tests/unit/target/test_video_target.py index 53925863ec..18a754f7ce 100644 --- a/tests/unit/target/test_video_target.py +++ b/tests/unit/target/test_video_target.py @@ -76,11 +76,6 @@ def test_video_validate_prompt_type_image_only(video_target: OpenAIVideoTarget): video_target._validate_request(message=Message([msg])) -def test_is_json_response_supported(patch_central_database): - target = OpenAIVideoTarget(endpoint="test", api_key="test", model_name="test-model") - assert target.is_json_response_supported() is False - - @pytest.mark.asyncio async def test_video_send_prompt_async_success( video_target: OpenAIVideoTarget, sample_conversations: MutableSequence[MessagePiece] @@ -425,7 +420,7 @@ def test_validate_rejects_unsupported_types(self, video_target: OpenAIVideoTarge converted_value_data_type="audio_path", conversation_id=conversation_id, ) - with pytest.raises(ValueError, match="Unsupported piece types"): + with pytest.raises(ValueError, match="This target supports only the following data types"): video_target._validate_request(message=Message([msg_text, msg_audio])) def test_validate_rejects_remix_with_image(self, video_target: OpenAIVideoTarget): @@ -977,7 +972,7 @@ def test_video_validate_previous_conversations( message_piece = sample_conversations[0] mock_memory = MagicMock() - mock_memory.get_conversation.return_value = sample_conversations + mock_memory.get_message_pieces.return_value = sample_conversations mock_memory.add_message_to_memory = AsyncMock() video_target._memory = mock_memory