-
Notifications
You must be signed in to change notification settings - Fork 746
Expand file tree
/
Copy pathgeneric_system_squash.py
More file actions
78 lines (62 loc) · 3.02 KB
/
generic_system_squash.py
File metadata and controls
78 lines (62 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
from pyrit.message_normalizer.message_normalizer import MessageListNormalizer
from pyrit.models import Message, MessagePiece
class GenericSystemSquashNormalizer(MessageListNormalizer[Message]):
    """
    Normalizer that folds a leading system message into the first user message.

    The system text is wrapped in generic instruction tags and placed in front
    of the user text, so the resulting conversation carries no system message.
    """

    async def normalize_async(self, messages: list[Message]) -> list[Message]:
        """
        Return a new list in which the leading system message has been squashed.

        When the first message is a system message and a later user message
        exists, the first piece of that user message receives this text:

            ### Instructions ###

            {system_content}

            ######

            {user_content}

        Messages located between the system message and that user message, as
        well as every message after it, are carried over in their original order.

        Args:
            messages: The list of messages to normalize.

        Returns:
            A list of messages. If the first message is not a system message,
            this is a shallow copy of the input. If no user message follows the
            system message (including the single-message case), the system text
            is re-issued as a user message built with ``Message.from_prompt``,
            followed by the remaining messages. Otherwise the system message is
            dropped and a deep copy of the first user message carries the
            combined text.

        Raises:
            ValueError: If the messages list is empty.
        """
        if not messages:
            raise ValueError("Messages list cannot be empty")

        leading_piece = messages[0].get_piece()
        if leading_piece.api_role != "system":
            # Nothing to squash; hand back a copy of the list as-is.
            return list(messages)

        instructions = leading_piece.converted_value

        # Look for the first user message after the system message. The ``else``
        # clause runs only when the loop finishes without ``break``; that also
        # covers a lone system message, where the slice below is empty.
        for user_idx, candidate in enumerate(messages[1:], start=1):
            if candidate.api_role == "user":
                break
        else:
            # No user message to merge into: keep the instruction content as a
            # user message and leave the non-user messages untouched.
            return [Message.from_prompt(prompt=instructions, role="user"), *messages[1:]]

        user_piece = messages[user_idx].get_piece()
        # NOTE(review): the combined text uses the piece's converted_value even
        # when its data type is not "text" -- confirm that is intended.
        merged_text = f"### Instructions ###\n\n{instructions}\n\n######\n\n{user_piece.converted_value}"

        # Work on a deep copy so the caller's message objects are not modified.
        merged_message = copy.deepcopy(messages[user_idx])
        head_piece = merged_message.message_pieces[0]
        if head_piece.converted_value_data_type != "text":
            # The first piece is not text, so put a new text piece in front of
            # it instead of overwriting its value.
            merged_message.message_pieces.insert(
                0,
                MessagePiece(
                    role="user",
                    original_value=merged_text,
                    conversation_id=user_piece.conversation_id,
                    sequence=user_piece.sequence,
                ),
            )
        else:
            head_piece.original_value = merged_text
            head_piece.converted_value = merged_text

        # Drop the system message (index 0) and swap the merged message in at
        # the position of the original user message.
        return [*messages[1:user_idx], merged_message, *messages[user_idx + 1 :]]