FEAT expand TargetCapabilities #1464
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,6 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT license. | ||
|
|
||
| import abc | ||
| from typing import Optional | ||
|
|
||
| from pyrit.identifiers import ComponentIdentifier | ||
|
|
@@ -31,7 +30,7 @@ def __init__( | |
| endpoint: str = "", | ||
| model_name: str = "", | ||
| underlying_model: Optional[str] = None, | ||
| capabilities: Optional[TargetCapabilities] = None, | ||
| custom_capabilities: Optional[TargetCapabilities] = None, | ||
| ) -> None: | ||
| """ | ||
| Initialize the PromptChatTarget. | ||
|
|
@@ -43,15 +42,15 @@ def __init__( | |
| underlying_model (str, Optional): The underlying model name (e.g., "gpt-4o") for | ||
| identification purposes. This is useful when the deployment name in Azure differs | ||
| from the actual model. Defaults to None. | ||
| capabilities (TargetCapabilities, Optional): Override the default capabilities for | ||
| custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for | ||
| this target instance. If None, uses the class-level defaults. Defaults to None. | ||
| """ | ||
| super().__init__( | ||
| max_requests_per_minute=max_requests_per_minute, | ||
| endpoint=endpoint, | ||
| model_name=model_name, | ||
| underlying_model=underlying_model, | ||
| capabilities=capabilities, | ||
| custom_capabilities=custom_capabilities, | ||
| ) | ||
|
|
||
| def set_system_prompt( | ||
|
Contributor
We should probably get rid of this also.
Contributor
Note this may be part of a future PR. We also need to think through how we set the system prompt for things like |
||
|
|
@@ -85,15 +84,6 @@ def set_system_prompt( | |
| ).to_message() | ||
| ) | ||
|
|
||
| @abc.abstractmethod | ||
| def is_json_response_supported(self) -> bool: | ||
| """ | ||
| Abstract method to determine if JSON response format is supported by the target. | ||
|
|
||
| Returns: | ||
| bool: True if JSON response is supported, False otherwise. | ||
| """ | ||
|
|
||
| def is_response_format_json(self, message_piece: MessagePiece) -> bool: | ||
| """ | ||
| Check if the response format is JSON and ensure the target supports it. | ||
|
|
@@ -127,7 +117,7 @@ def _get_json_response_config(self, *, message_piece: MessagePiece) -> _JsonResp | |
| """ | ||
| config = _JsonResponseConfig.from_metadata(metadata=message_piece.prompt_metadata) | ||
|
|
||
| if config.enabled and not self.is_json_response_supported(): | ||
| if config.enabled and not self.capabilities.supports_json_response: | ||
| target_name = self.get_identifier().class_name | ||
| raise ValueError(f"This target {target_name} does not support JSON response format.") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,7 +38,7 @@ def __init__( | |
| endpoint: str = "", | ||
| model_name: str = "", | ||
| underlying_model: Optional[str] = None, | ||
| capabilities: Optional[TargetCapabilities] = None, | ||
| custom_capabilities: Optional[TargetCapabilities] = None, | ||
| ) -> None: | ||
| """ | ||
| Initialize the PromptTarget. | ||
|
|
@@ -52,7 +52,7 @@ def __init__( | |
| identification purposes. This is useful when the deployment name in Azure differs | ||
| from the actual model. If not provided, `model_name` will be used for the identifier. | ||
| Defaults to None. | ||
| capabilities (TargetCapabilities, Optional): Override the default capabilities for | ||
| custom_capabilities (TargetCapabilities, Optional): Override the default capabilities for | ||
|
Contributor
I wonder if we could have a mapping of known default capabilities here, or potentially retrieve them from a list in the target_capabilities class.
Contributor
We also talked about a method to discover these, which I think will be useful. But the defaults we know could go a long way early on :) |
||
| this target instance. Useful for targets whose capabilities depend on deployment | ||
| configuration (e.g., Playwright, HTTP). If None, uses the class-level | ||
| ``_DEFAULT_CAPABILITIES``. Defaults to None. | ||
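On the idea of a mapping of known default capabilities raised in the review comments above, here is a minimal sketch of one way the lookup order could work. Everything in it is hypothetical: `Capabilities` is a simplified stand-in for `TargetCapabilities`, and `KNOWN_DEFAULTS`, `resolve_capabilities`, and the model names are made up for illustration and are not part of this PR.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Capabilities:
    """Simplified stand-in for TargetCapabilities (not the PyRIT class)."""

    supports_multi_turn: bool = False
    supports_json_response: bool = False


# Hypothetical table of known defaults, keyed by a made-up model name.
KNOWN_DEFAULTS: dict[str, Capabilities] = {
    "example-chat-model": Capabilities(supports_multi_turn=True, supports_json_response=True),
    "example-completion-model": Capabilities(),
}


def resolve_capabilities(
    model_name: str,
    custom_capabilities: Optional[Capabilities] = None,
    class_default: Capabilities = Capabilities(),
) -> Capabilities:
    """Pick capabilities: explicit override, then known defaults, then the class default."""
    if custom_capabilities is not None:
        return custom_capabilities
    return KNOWN_DEFAULTS.get(model_name, class_default)
```

With this ordering an explicit `custom_capabilities` always wins, which matches the override behaviour in the diff; the table only fills in when the caller passes nothing.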
|
|
@@ -63,7 +63,9 @@ def __init__( | |
| self._endpoint = endpoint | ||
| self._model_name = model_name | ||
| self._underlying_model = underlying_model | ||
| self._capabilities = capabilities if capabilities is not None else type(self)._DEFAULT_CAPABILITIES | ||
| self._capabilities = ( | ||
|
Contributor
self._capabilities is defined here but the later changes reference self.capabilities? |
||
| custom_capabilities if custom_capabilities is not None else type(self)._DEFAULT_CAPABILITIES | ||
|
Contributor
Good, I like this. |
||
| ) | ||
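Regarding the question above about `self._capabilities` versus `self.capabilities`: the diff later shows a `capabilities` property on the same class, although its body is not visible in this excerpt. The sketch below assumes the property simply returns the private attribute. `Capabilities`, `Target`, and `JsonTarget` are simplified stand-ins, not the PyRIT classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Capabilities:
    """Simplified stand-in for TargetCapabilities."""

    supports_json_response: bool = False


class Target:
    # Class-level default; subclasses override this attribute.
    _DEFAULT_CAPABILITIES = Capabilities()

    def __init__(self, custom_capabilities: Optional[Capabilities] = None) -> None:
        # Same selection logic as the diff: instance override, else class default.
        self._capabilities = (
            custom_capabilities if custom_capabilities is not None else type(self)._DEFAULT_CAPABILITIES
        )

    @property
    def capabilities(self) -> Capabilities:
        # Assumed body: a read-only view of the private attribute.
        return self._capabilities


class JsonTarget(Target):
    _DEFAULT_CAPABILITIES = Capabilities(supports_json_response=True)
```

Because `type(self)` is used rather than a hard-coded class, a subclass such as `JsonTarget` picks up its own `_DEFAULT_CAPABILITIES` without overriding `__init__`.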
|
|
||
| if self._verbose: | ||
| logging.basicConfig(level=logging.INFO) | ||
|
|
@@ -78,14 +80,36 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: | |
| but some (like response target with tool calls) may return multiple messages. | ||
| """ | ||
|
|
||
| @abc.abstractmethod | ||
| def _validate_request(self, *, message: Message) -> None: | ||
| """ | ||
| Validate the provided message. | ||
|
|
||
| Args: | ||
| message: The message to validate. | ||
|
|
||
| Raises: | ||
| ValueError: if the target does not support the provided message pieces or if the message | ||
| violates any constraints based on the target's capabilities. | ||
|
|
||
| """ | ||
| n_pieces = len(message.message_pieces) | ||
| if not self.capabilities.supports_multi_message_pieces and n_pieces != 1: | ||
| raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") | ||
|
|
||
| for piece in message.message_pieces: | ||
| piece_type = piece.converted_value_data_type | ||
| if piece_type not in self.capabilities.input_modalities: | ||
| supported_types = ", ".join(sorted(self.capabilities.input_modalities)) | ||
| raise ValueError( | ||
| f"This target supports only the following data types: {supported_types}. Received: {piece_type}." | ||
| ) | ||
|
|
||
| if not self.supports_multi_turn: | ||
|
|
||
| request = message.message_pieces[0] | ||
| messages = self._memory.get_message_pieces(conversation_id=request.conversation_id) | ||
|
|
||
| if len(messages) > 0: | ||
| raise ValueError("This target only supports a single turn conversation.") | ||
|
|
||
| def set_model_name(self, *, model_name: str) -> None: | ||
| """ | ||
|
|
@@ -140,6 +164,15 @@ def _create_identifier( | |
|
|
||
| return ComponentIdentifier.of(self, params=all_params, children=children) | ||
|
|
||
| def is_json_response_supported(self) -> bool: | ||
|
Contributor
I think we should remove the methods like this (and |
||
| """ | ||
| Method to determine if JSON response format is supported by the target. | ||
|
|
||
| Returns: | ||
| bool: True if JSON response is supported, False otherwise. | ||
| """ | ||
| return self.capabilities.supports_json_response | ||
|
|
||
| @property | ||
| def capabilities(self) -> TargetCapabilities: | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,9 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT license. | ||
|
|
||
| from dataclasses import dataclass | ||
| from dataclasses import dataclass, field, fields | ||
|
|
||
| from pyrit.models import PromptDataType | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
|
|
@@ -20,3 +22,38 @@ class attribute. Users can override individual capabilities per instance | |
| # (i.e., it accepts and uses conversation history or maintains state | ||
| # across turns via external mechanisms like WebSocket connections). | ||
| supports_multi_turn: bool = False | ||
|
|
||
| # Whether the target natively supports multiple message pieces in a single request. | ||
| supports_multi_message_pieces: bool = True | ||
|
Contributor
I might default to the least capable; |
||
|
|
||
| # Whether the target natively supports JSON output (e.g., via a "json" response format). | ||
| supports_json_response: bool = False | ||
|
Contributor
You likely want to split this to |
||
|
|
||
| # The input modalities supported by the target (e.g., "text", "image"). | ||
|
Contributor
We also likely want |
||
| input_modalities: list[PromptDataType] = field(default_factory=lambda: ["text"]) | ||
|
|
||
| # The output modalities supported by the target (e.g., "text", "image"). | ||
| output_modalities: list[PromptDataType] = field(default_factory=lambda: ["text"]) | ||
|
Comment on lines +33 to +36
Contributor
Nit: What do you think about making these into sets? It would be functionally identical, but I'm curious whether we want to preserve the ordering.
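To make the list-versus-set question concrete, here is a small sketch comparing the two field types on a frozen dataclass. `ListCaps` and `SetCaps` are hypothetical stand-ins and the modality strings are illustrative. It shows two differences beyond ordering: a list field stays mutable even when the dataclass is frozen, and it makes instances unhashable.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ListCaps:
    """Stand-in using a list field, as in the diff."""

    input_modalities: list[str] = field(default_factory=lambda: ["text"])


@dataclass(frozen=True)
class SetCaps:
    """Stand-in using a frozenset field instead."""

    input_modalities: frozenset[str] = frozenset({"text"})


def is_hashable(obj: object) -> bool:
    """Return True if hash(obj) succeeds."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


list_caps = ListCaps()
# frozen=True only blocks attribute reassignment; the list itself can still change.
list_caps.input_modalities.append("image")

set_caps = SetCaps(frozenset({"text", "image"}))
```

Note that the `assert_satifies` method in this diff branches on `isinstance(required_value, list)`, so switching the fields to sets would require updating that check as well.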
Comment on lines +32 to +36
Contributor
I worry about this a little. You have multi-piece and input modalities, but just because an endpoint supports text and image and video and multiple pieces doesn't mean it supports all combinations of those, right? It might just accept text+image and text+video, but not all three together, or not image+video.
A solution to this would be to drop the multi-piece support field, and instead make input and output modalities of type
Wdyt? Did you consider this and decide it's not good for some reason? I might be missing some cases because I haven't spent as much time on it. |
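The type the reviewer proposed is not visible in this excerpt, so the following is not a reconstruction of it. Purely to illustrate the concern about combinations, here is one hypothetical encoding in which the capability lists every accepted combination of data types explicitly. `ComboCaps`, `input_combinations`, and `accepts` are invented names, and the modality strings are illustrative.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ComboCaps:
    """Hypothetical stand-in: each inner frozenset is one accepted combination of data types."""

    input_combinations: frozenset[frozenset[str]] = frozenset({frozenset({"text"})})

    def accepts(self, piece_types: list[str]) -> bool:
        # Order and duplicates are ignored; only the set of types in the request matters.
        return frozenset(piece_types) in self.input_combinations


caps = ComboCaps(
    input_combinations=frozenset(
        {
            frozenset({"text"}),
            frozenset({"text", "image"}),
            frozenset({"text", "video"}),
        }
    )
)
```

Under this encoding a separate multi-piece flag becomes redundant: a target that only accepts single pieces would list only single-element combinations. The trade-off is that targets accepting many modalities would need to enumerate every allowed combination.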
||
|
|
||
| def assert_satifies(self, required_capabilities: "TargetCapabilities") -> None: | ||
| """ | ||
| Assert that the current capabilities satisfy the required capabilities. | ||
|
|
||
| Args: | ||
| required_capabilities (TargetCapabilities): The required capabilities to check against. | ||
|
|
||
| Raises: | ||
| ValueError: If any of the required capabilities are not satisfied. | ||
| """ | ||
| unmet = [] | ||
| for f in fields(required_capabilities): | ||
| required_value = getattr(required_capabilities, f.name) | ||
| self_value = getattr(self, f.name) | ||
| if isinstance(required_value, list): | ||
| missing = set(required_value) - set(self_value) | ||
| if missing: | ||
| unmet.append(f"{f.name}: missing {missing}") | ||
| elif required_value and not self_value: | ||
| unmet.append(f.name) | ||
| if unmet: | ||
| raise ValueError(f"Target does not satisfy the following capabilities: {', '.join(unmet)}") | ||
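A self-contained sketch of how the check above behaves, using a hypothetical `Caps` stand-in rather than the real `TargetCapabilities`. The method keeps the spelling `assert_satifies` from the diff (likely intended as `assert_satisfies`). One deliberate deviation: missing modalities are sorted before formatting so the error message is deterministic.

```python
from dataclasses import dataclass, field, fields


@dataclass(frozen=True)
class Caps:
    """Simplified stand-in for TargetCapabilities."""

    supports_multi_turn: bool = False
    supports_json_response: bool = False
    input_modalities: list[str] = field(default_factory=lambda: ["text"])

    def assert_satifies(self, required: "Caps") -> None:
        unmet = []
        for f in fields(required):
            required_value = getattr(required, f.name)
            self_value = getattr(self, f.name)
            if isinstance(required_value, list):
                # List fields: every required entry must be present on self.
                missing = set(required_value) - set(self_value)
                if missing:
                    unmet.append(f"{f.name}: missing {sorted(missing)}")
            elif required_value and not self_value:
                # Boolean fields: a required True must be matched by True on self.
                unmet.append(f.name)
        if unmet:
            raise ValueError(f"Target does not satisfy the following capabilities: {', '.join(unmet)}")


target = Caps(supports_multi_turn=True, input_modalities=["text", "image"])
target.assert_satifies(Caps(supports_multi_turn=True))  # satisfied, returns None

error_message = ""
try:
    target.assert_satifies(Caps(supports_json_response=True, input_modalities=["text", "audio"]))
except ValueError as exc:
    error_message = str(exc)
```

One consequence worth noting: a requirement object built with defaults also requires whatever the defaults are. Since `supports_multi_message_pieces` defaults to `True` in the diff, a requirement that only sets `supports_json_response=True` would implicitly demand multi-piece support too, which connects to the review suggestion of defaulting to the least capable value.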
We can probably deprecate `PromptChatTarget` and remove it from the inheritance chain, since we should check via prompt capabilities now.
Note this may be part of a future PR.