Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions pyrit/prompt_target/azure_blob_storage_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import Message, construct_response_from_request
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
from pyrit.prompt_target.common.utils import limit_requests_per_minute

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -49,6 +50,12 @@ class AzureBlobStorageTarget(PromptTarget):
AZURE_STORAGE_CONTAINER_ENVIRONMENT_VARIABLE: str = "AZURE_STORAGE_ACCOUNT_CONTAINER_URL"
SAS_TOKEN_ENVIRONMENT_VARIABLE: str = "AZURE_STORAGE_ACCOUNT_SAS_TOKEN"

_DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(
input_modalities=["text", "url"],
output_modalities=["url"],
supports_multi_message_pieces=False,
)

def __init__(
self,
*,
Expand Down Expand Up @@ -196,18 +203,3 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
)

return [response]

def _validate_request(self, *, message: Message) -> None:
n_pieces = len(message.message_pieces)
if n_pieces != 1:
raise ValueError(f"This target only supports a single message piece. Received {n_pieces} pieces")

piece_type = message.message_pieces[0].converted_value_data_type
if piece_type not in ["text", "url"]:
raise ValueError(f"This target only supports text and url prompt input. Received: {piece_type}.")

request = message.message_pieces[0]
messages = self._memory.get_message_pieces(conversation_id=request.conversation_id)

if len(messages) > 0:
raise ValueError("This target only supports a single turn conversation.")
9 changes: 0 additions & 9 deletions pyrit/prompt_target/azure_ml_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,3 @@ def _get_headers(self) -> dict[str, str]:

def _validate_request(self, *, message: Message) -> None:
pass

def is_json_response_supported(self) -> bool:
    """
    Report whether this target can produce JSON-formatted responses.

    Returns:
        bool: Always False; this target does not support a JSON response format.
    """
    return False
18 changes: 4 additions & 14 deletions pyrit/prompt_target/common/prompt_chat_target.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import abc
from typing import Optional

from pyrit.identifiers import ComponentIdentifier
Copy link
Contributor

@rlundeen2 rlundeen2 Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably deprecate PromptChatTarget and remove it from the inheritance chain, since we should check via prompt capabilities now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note this may be part of a future PR

Expand Down Expand Up @@ -31,7 +30,7 @@ def __init__(
endpoint: str = "",
model_name: str = "",
underlying_model: Optional[str] = None,
capabilities: Optional[TargetCapabilities] = None,
custom_capabilities: Optional[TargetCapabilities] = None,
) -> None:
"""
Initialize the PromptChatTarget.
Expand All @@ -43,15 +42,15 @@ def __init__(
underlying_model (str, Optional): The underlying model name (e.g., "gpt-4o") for
identification purposes. This is useful when the deployment name in Azure differs
from the actual model. Defaults to None.
capabilities (TargetCapabilities, Optional): Override the default capabilities for
custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for
this target instance. If None, uses the class-level defaults. Defaults to None.
"""
super().__init__(
max_requests_per_minute=max_requests_per_minute,
endpoint=endpoint,
model_name=model_name,
underlying_model=underlying_model,
capabilities=capabilities,
custom_capabilities=custom_capabilities,
)

def set_system_prompt(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably get rid of this also. set_system_prompt is error prone and you can always do it with prepended_conversation

Copy link
Contributor

@rlundeen2 rlundeen2 Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note this may be part of a future PR. We also need to think through how we set the system prompt for things like RealTimeTarget since they are multi-turn, you can set the system prompt, but you can't edit the history.

Expand Down Expand Up @@ -85,15 +84,6 @@ def set_system_prompt(
).to_message()
)

@abc.abstractmethod
def is_json_response_supported(self) -> bool:
    """
    Abstract method to determine if JSON response format is supported by the target.

    Subclasses must override this to report their JSON response capability.

    Returns:
        bool: True if JSON response is supported, False otherwise.
    """
def is_response_format_json(self, message_piece: MessagePiece) -> bool:
"""
Check if the response format is JSON and ensure the target supports it.
Expand Down Expand Up @@ -127,7 +117,7 @@ def _get_json_response_config(self, *, message_piece: MessagePiece) -> _JsonResp
"""
config = _JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata)

if config.enabled and not self.is_json_response_supported():
if config.enabled and not self.capabilities.supports_json_response:
target_name = self.get_identifier().class_name
raise ValueError(f"This target {target_name} does not support JSON response format.")

Expand Down
41 changes: 37 additions & 4 deletions pyrit/prompt_target/common/prompt_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
endpoint: str = "",
model_name: str = "",
underlying_model: Optional[str] = None,
capabilities: Optional[TargetCapabilities] = None,
custom_capabilities: Optional[TargetCapabilities] = None,
) -> None:
"""
Initialize the PromptTarget.
Expand All @@ -52,7 +52,7 @@ def __init__(
identification purposes. This is useful when the deployment name in Azure differs
from the actual model. If not provided, `model_name` will be used for the identifier.
Defaults to None.
capabilities (TargetCapabilities, Optional): Override the default capabilities for
custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for
Copy link
Contributor

@rlundeen2 rlundeen2 Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we could have a mapping of known default capabilities here, or potentially retrieve them from a list in the target_capabilities class.

if underlying_model == "gpt-5.1":
  _default = X
elif ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also talked about a method to discover these which I think will be useful. But the defaults we know could go a long way early on :)

this target instance. Useful for targets whose capabilities depend on deployment
configuration (e.g., Playwright, HTTP). If None, uses the class-level
``_DEFAULT_CAPABILITIES``. Defaults to None.
Expand All @@ -63,7 +63,9 @@ def __init__(
self._endpoint = endpoint
self._model_name = model_name
self._underlying_model = underlying_model
self._capabilities = capabilities if capabilities is not None else type(self)._DEFAULT_CAPABILITIES
self._capabilities = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._capabilities is defined here but the later changes reference self.capabilities?

custom_capabilities if custom_capabilities is not None else type(self)._DEFAULT_CAPABILITIES
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good, I like this

)

if self._verbose:
logging.basicConfig(level=logging.INFO)
Expand All @@ -78,14 +80,36 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:
but some (like response target with tool calls) may return multiple messages.
"""

@abc.abstractmethod
def _validate_request(self, *, message: Message) -> None:
"""
Validate the provided message.

Args:
message: The message to validate.

Raises:
ValueError: if the target does not support the provided message pieces or if the message
violates any constraints based on the target's capabilities.

"""
n_pieces = len(message.message_pieces)
if not self.capabilities.supports_multi_message_pieces and n_pieces != 1:
raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")

for piece in message.message_pieces:
piece_type = piece.converted_value_data_type
if piece_type not in self.capabilities.input_modalities:
supported_types = ", ".join(sorted(self.capabilities.input_modalities))
raise ValueError(
f"This target supports only the following data types: {supported_types}. Received: {piece_type}."
)

if not self.supports_multi_turn:
Copy link
Contributor

@jsong468 jsong468 Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.capabilities.supports_multi_turn? Right now, we have a convenience property for this but not some of the others. Since we have a bunch now, we could just get rid of it?

request = message.message_pieces[0]
messages = self._memory.get_message_pieces(conversation_id=request.conversation_id)

if len(messages) > 0:
raise ValueError("This target only supports a single turn conversation.")

def set_model_name(self, *, model_name: str) -> None:
"""
Expand Down Expand Up @@ -140,6 +164,15 @@ def _create_identifier(

return ComponentIdentifier.of(self, params=all_params, children=children)

def is_json_response_supported(self) -> bool:
Copy link
Contributor

@rlundeen2 rlundeen2 Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should remove the methods like this (and supports_multi_turn, etc, and just have people validate using the capabilities property.

"""
Method to determine if JSON response format is supported by the target.

Returns:
bool: True if JSON response is supported, False otherwise.
"""
return self.capabilities.supports_json_response

@property
def capabilities(self) -> TargetCapabilities:
"""
Expand Down
39 changes: 38 additions & 1 deletion pyrit/prompt_target/common/target_capabilities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from dataclasses import dataclass, field, fields

from pyrit.models import PromptDataType


@dataclass(frozen=True)
Expand All @@ -20,3 +22,38 @@ class attribute. Users can override individual capabilities per instance
# (i.e., it accepts and uses conversation history or maintains state
# across turns via external mechanisms like WebSocket connections).
supports_multi_turn: bool = False

# Whether the target natively supports multiple message pieces in a single request.
supports_multi_message_pieces: bool = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might default to the least capable; supports_multi_message_pieces = False


# Whether the target natively supports JSON output (e.g., via a "json" response format).
supports_json_response: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You likely want to split this to json_schema_support and json_output_support because they're different params. In one you include the schema, in the other you put json=True


# The input modalities supported by the target (e.g., "text", "image").
Copy link
Contributor

@rlundeen2 rlundeen2 Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also likely want editable_history; this lets us set system history entries that are not actually responses from the target. We use PromptChatTarget for this currently.

input_modalities: list[PromptDataType] = field(default_factory=lambda: ["text"])

# The output modalities supported by the target (e.g., "text", "image").
output_modalities: list[PromptDataType] = field(default_factory=lambda: ["text"])
Comment on lines +33 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: What do you think about making these into sets? It would be functionally identical but I'm curious if we want to preserve the ordering

Comment on lines +32 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worry about this a little. You have multi-piece and input modalities, but just because an endpoint supports text and image and video and multiple pieces doesn't mean it supports all combinations of those, right?

It might just accept text+image and text+video, but not all three together, or not image+video.

A solution to this would be to drop the multi-piece support field, and instead make input and output modalities of type set[set[PromptDataType]] (set because the order is irrelevant and we can check membership in constant time (although complexity is hardly an issue here). I think @fitzpr's PR #1383 has this as sets of sets if you want an idea.

Wdyt? Did you consider this and decide it's not good for some reason? I might be missing some cases because I haven't spent as much time on it.


def assert_satisfies(self, required_capabilities: "TargetCapabilities") -> None:
    """
    Assert that the current capabilities satisfy the required capabilities.

    A list-valued capability (e.g. ``input_modalities``) is satisfied when this
    instance contains every required entry; a boolean capability is satisfied
    when it is True whenever the requirement is True.

    Args:
        required_capabilities (TargetCapabilities): The required capabilities to check against.

    Raises:
        ValueError: If any of the required capabilities are not satisfied.
    """
    unmet = []
    for f in fields(required_capabilities):
        required_value = getattr(required_capabilities, f.name)
        self_value = getattr(self, f.name)
        if isinstance(required_value, list):
            # Extras on self are fine; only required-but-absent entries are unmet.
            missing = set(required_value) - set(self_value)
            if missing:
                unmet.append(f"{f.name}: missing {missing}")
        elif required_value and not self_value:
            unmet.append(f.name)
    if unmet:
        raise ValueError(f"Target does not satisfy the following capabilities: {', '.join(unmet)}")

def assert_satifies(self, required_capabilities: "TargetCapabilities") -> None:
    """Deprecated misspelled alias of :meth:`assert_satisfies`; kept for backward compatibility."""
    self.assert_satisfies(required_capabilities)
12 changes: 3 additions & 9 deletions pyrit/prompt_target/crucible_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from pyrit.models import Message, construct_response_from_request
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
from pyrit.prompt_target.common.utils import limit_requests_per_minute

logger = logging.getLogger(__name__)
Expand All @@ -24,6 +25,8 @@ class CrucibleTarget(PromptTarget):

API_KEY_ENVIRONMENT_VARIABLE: str = "CRUCIBLE_API_KEY"

_DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_message_pieces=False)

def __init__(
self,
*,
Expand Down Expand Up @@ -80,15 +83,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:

return [response_entry]

def _validate_request(self, *, message: Message) -> None:
n_pieces = len(message.message_pieces)
if n_pieces != 1:
raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")

piece_type = message.message_pieces[0].converted_value_data_type
if piece_type != "text":
raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")

@pyrit_target_retry
async def _complete_text_async(self, text: str) -> str:
payload: dict[str, object] = {
Expand Down
12 changes: 3 additions & 9 deletions pyrit/prompt_target/gandalf_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import Message, construct_response_from_request
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
from pyrit.prompt_target.common.utils import limit_requests_per_minute

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -38,6 +39,8 @@ class GandalfLevel(enum.Enum):
class GandalfTarget(PromptTarget):
"""A prompt target for the Gandalf security challenge."""

_DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_message_pieces=False)

def __init__(
self,
*,
Expand Down Expand Up @@ -93,15 +96,6 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]:

return [response_entry]

def _validate_request(self, *, message: Message) -> None:
n_pieces = len(message.message_pieces)
if n_pieces != 1:
raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")

piece_type = message.message_pieces[0].converted_value_data_type
if piece_type != "text":
raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")

async def check_password(self, password: str) -> bool:
"""
Check if the password is correct.
Expand Down
12 changes: 4 additions & 8 deletions pyrit/prompt_target/http_target/http_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import logging
import re
from collections.abc import Callable, Sequence
from collections.abc import Callable
from typing import Any, Optional

import httpx
Expand All @@ -17,6 +17,7 @@
construct_response_from_request,
)
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
from pyrit.prompt_target.common.utils import limit_requests_per_minute

logger = logging.getLogger(__name__)
Expand All @@ -40,6 +41,8 @@ class HTTPTarget(PromptTarget):
httpx_client_kwargs: (dict): additional keyword arguments to pass to the HTTP client
"""

_DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(supports_multi_message_pieces=False)

def __init__(
self,
http_request: str,
Expand Down Expand Up @@ -290,10 +293,3 @@ def _infer_full_url_from_host(

host = headers_dict["host"]
return f"{http_protocol}{host}{path}"

def _validate_request(self, *, message: Message) -> None:
message_pieces: Sequence[MessagePiece] = message.message_pieces

n_pieces = len(message_pieces)
if n_pieces != 1:
raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")
25 changes: 6 additions & 19 deletions pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pyrit.identifiers import ComponentIdentifier
from pyrit.models import Message, construct_response_from_request
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
from pyrit.prompt_target.common.utils import limit_requests_per_minute

logger = logging.getLogger(__name__)
Expand All @@ -34,6 +35,11 @@ class HuggingFaceChatTarget(PromptChatTarget):
Inherits from PromptTarget to comply with the current design standards.
"""

_DEFAULT_CAPABILITIES: TargetCapabilities = TargetCapabilities(
supports_multi_turn=True,
supports_multi_message_pieces=False,
)

# Class-level cache for model and tokenizer
_cached_model = None
_cached_tokenizer = None
Expand Down Expand Up @@ -388,25 +394,6 @@ def _apply_chat_template(self, messages: list[dict[str, str]]) -> Any:
logger.error(error_message)
raise ValueError(error_message)

def _validate_request(self, *, message: Message) -> None:
"""
Validate the provided message.

Args:
message: The message to validate.

Raises:
ValueError: If the message does not contain exactly one text piece.
ValueError: If the message piece is not of type text.
"""
n_pieces = len(message.message_pieces)
if n_pieces != 1:
raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.")

piece_type = message.message_pieces[0].converted_value_data_type
if piece_type != "text":
raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.")

def is_json_response_supported(self) -> bool:
"""
Check if the target supports JSON as a response format.
Expand Down
Loading
Loading