-
Notifications
You must be signed in to change notification settings - Fork 745
Expand file tree
/
Copy pathmemory_interface.py
More file actions
1736 lines (1456 loc) · 74.2 KB
/
memory_interface.py
File metadata and controls
1736 lines (1456 loc) · 74.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
import atexit
import logging
import uuid
import warnings
import weakref
from collections.abc import MutableSequence, Sequence
from contextlib import closing
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from sqlalchemy import MetaData, and_, or_
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.attributes import InstrumentedAttribute
from pyrit.common.path import DB_DATA_PATH
from pyrit.memory.memory_embedding import (
MemoryEmbedding,
default_memory_embedding_factory,
)
from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.memory.memory_models import (
AttackResultEntry,
Base,
EmbeddingDataEntry,
PromptMemoryEntry,
ScenarioResultEntry,
ScoreEntry,
SeedEntry,
)
from pyrit.models import (
AttackResult,
ConversationStats,
DataTypeSerializer,
Message,
MessagePiece,
ScenarioResult,
Score,
Seed,
SeedDataset,
SeedGroup,
SeedType,
StorageIO,
data_serializer_factory,
group_conversation_message_pieces_by_sequence,
sort_message_pieces,
)
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
logger = logging.getLogger(__name__)
Model = TypeVar("Model")
class MemoryInterface(abc.ABC):
"""
Abstract interface for conversation memory storage systems.
This interface defines the contract for storing and retrieving chat messages
and conversation history. Implementations can use different storage backends
such as files, databases, or cloud storage services.
"""
memory_embedding: MemoryEmbedding = None
results_storage_io: StorageIO = None
results_path: str = None
engine: Engine = None
def __init__(self, embedding_model: Optional[Any] = None) -> None:
    """
    Create a new memory interface.

    Args:
        embedding_model: When provided, embeddings are generated for stored
            entries. Embeddings are extremely useful for comparing chat
            messages and similarities, at the cost of extra overhead.
    """
    self.memory_embedding = embedding_model
    # Exporter used for writing stored conversations out of memory.
    self.exporter = MemoryExporter()
    self._init_storage_io()
    # Register engine-disposal hooks (atexit + weakref finalizer).
    self.cleanup()
def enable_embedding(self, embedding_model: Optional[Any] = None) -> None:
    """
    Turn on embedding generation for the memory interface.

    Args:
        embedding_model: Embedding model to use. When omitted, a default
            model is built from environment variables.

    Raises:
        ValueError: If no embedding model is provided and the required
            environment variables are not set.
    """
    self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model)
def disable_embedding(self) -> None:
"""
Disable embedding functionality for the memory interface.
Sets the memory_embedding attribute to None, disabling any embedding operations.
"""
self.memory_embedding = None
@abc.abstractmethod
def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]:
    """
    Load all EmbeddingData from the memory storage handler.

    Returns:
        Sequence[EmbeddingDataEntry]: Every embedding entry currently stored.
    """
@abc.abstractmethod
def _init_storage_io(self) -> None:
    """
    Initialize the storage IO handler ``results_storage_io``.

    Called once from ``__init__``; implementations pick the concrete
    StorageIO backend.
    """
@abc.abstractmethod
def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str, str]) -> list[Any]:
    """
    Return a list of conditions for filtering memory entries based on memory labels.

    The returned conditions are combined with AND by the caller
    (``get_message_pieces``), so all labels must match.

    Args:
        memory_labels (dict[str, str]): A free-form dictionary for tagging prompts with custom labels.
            These labels can be used to track all prompts sent as part of an operation, score prompts based on
            the operation ID (op_id), and tag each prompt with the relevant Responsible AI (RAI) harm category.
            Users can define any key-value pairs according to their needs.

    Returns:
        list: A list of conditions for filtering memory entries based on memory labels.
    """
@abc.abstractmethod
def _get_message_pieces_prompt_metadata_conditions(
    self, *, prompt_metadata: dict[str, Union[str, int]]
) -> list[Any]:
    """
    Return a list of conditions for filtering memory entries based on prompt metadata.

    The returned conditions are combined with AND by the caller
    (``get_message_pieces``).

    Args:
        prompt_metadata (dict[str, str | int]): A free-form dictionary for tagging prompts with custom metadata.
            This includes information that is useful for the specific target you're probing, such as encoding data.

    Returns:
        list: A list of conditions for filtering memory entries based on prompt metadata.
    """
@abc.abstractmethod
def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any:
    """
    Return a condition to retrieve message pieces based on attack ID.

    Args:
        attack_id (str): The attack identifier to match.

    Returns:
        Any: A SQLAlchemy condition for the implementation's backend.
    """
@abc.abstractmethod
def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any:
    """
    Return a condition for filtering seed prompt entries based on prompt metadata.

    Args:
        metadata (dict[str, str | int]): A free-form dictionary for tagging prompts with custom metadata.
            This includes information that is useful for the specific target you're probing, such as encoding data.

    Returns:
        Any: A SQLAlchemy condition for filtering memory entries based on prompt metadata.
    """
@abc.abstractmethod
def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None:
    """
    Insert a list of message pieces into the memory storage.

    Args:
        message_pieces (Sequence[MessagePiece]): The pieces to persist.
    """
@abc.abstractmethod
def _add_embeddings_to_memory(self, *, embedding_data: Sequence[EmbeddingDataEntry]) -> None:
    """
    Insert embedding data into memory storage.

    Args:
        embedding_data (Sequence[EmbeddingDataEntry]): The embedding entries to persist.
    """
@abc.abstractmethod
def _query_entries(
    self,
    model_class: type[Model],
    *,
    conditions: Optional[Any] = None,
    distinct: bool = False,
    join_scores: bool = False,
) -> MutableSequence[Model]:
    """
    Fetch data from the specified table model with optional conditions.

    Args:
        model_class: The SQLAlchemy model class corresponding to the table you want to query.
        conditions: SQLAlchemy filter conditions (Optional).
        distinct: Whether to return distinct rows only. Defaults to False.
        join_scores: Whether to join the scores table. Defaults to False.

    Returns:
        MutableSequence[Model]: Model instances representing the rows fetched from the table.
    """
@abc.abstractmethod
def _insert_entry(self, entry: Base) -> None:
    """
    Insert a single entry into its table.

    Args:
        entry: An instance of a SQLAlchemy model to be added to the Table.
    """
@abc.abstractmethod
def _insert_entries(self, *, entries: Sequence[Base]) -> None:
    """
    Insert multiple entries into the database.

    Args:
        entries (Sequence[Base]): SQLAlchemy model instances to persist.
    """
@abc.abstractmethod
def get_session(self) -> Any:
    """
    Provide a SQLAlchemy session for transactional operations.

    Callers are responsible for closing the session (see ``_update_entry``).

    Returns:
        Session: A SQLAlchemy session bound to the engine.
    """
def _update_entry(self, entry: Base) -> None:
"""
Update an existing entry in the Table using merge.
This method uses SQLAlchemy's merge operation which will:
- Update the existing record if the primary key matches
- Insert a new record if the primary key doesn't exist
Args:
entry: An instance of a SQLAlchemy model to be updated in the Table.
Raises:
SQLAlchemyError: If there's an error during the database operation.
"""
with closing(self.get_session()) as session:
try:
session.merge(entry)
session.commit()
except SQLAlchemyError as e:
session.rollback()
logger.exception(f"Error updating entry in the table: {e}")
raise
@abc.abstractmethod
def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict[str, Any]) -> bool:
    """
    Update the given entries with the specified field values.

    Args:
        entries (MutableSequence[Base]): A list of SQLAlchemy model instances to be updated.
        update_fields (dict): A dictionary of field names and their new values.

    Returns:
        bool: True if the update succeeded, False otherwise.
    """
@abc.abstractmethod
def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
    """
    Return a database-specific condition for filtering AttackResults by targeted harm categories
    in the associated PromptMemoryEntry records.

    Args:
        targeted_harm_categories: List of harm categories that must ALL be present (AND logic).

    Returns:
        Database-specific SQLAlchemy condition.
    """
@abc.abstractmethod
def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
    """
    Return a database-specific condition for filtering AttackResults by labels
    in the associated PromptMemoryEntry records.

    Args:
        labels: Dictionary of labels that must ALL be present (AND logic).

    Returns:
        Database-specific SQLAlchemy condition.
    """
@abc.abstractmethod
def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any:
    """
    Return a database-specific condition for filtering AttackResults by attack class
    (class_name in the attack_identifier JSON column).

    Args:
        attack_class: Exact attack class name to match.

    Returns:
        Database-specific SQLAlchemy condition.
    """
@abc.abstractmethod
def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any:
    """
    Return a database-specific condition for filtering AttackResults by converter classes
    in the request_converter_identifiers array within the attack_identifier JSON column.

    This method is only called when converter filtering is requested (converter_classes
    is not None). The caller handles the None-vs-list distinction:

    - ``len(converter_classes) == 0``: return a condition matching attacks with NO converters.
    - ``len(converter_classes) > 0``: return a condition requiring ALL specified converter
      class names to be present (AND logic, case-insensitive).

    Args:
        converter_classes: Converter class names to require. An empty sequence means
            "match only attacks that have no converters".

    Returns:
        Database-specific SQLAlchemy condition.
    """
@abc.abstractmethod
def get_unique_attack_class_names(self) -> list[str]:
    """
    Return sorted unique attack class names from all stored attack results.

    Extracts class_name from the attack_identifier JSON column via a
    database-level DISTINCT query.

    Returns:
        list[str]: Sorted list of unique attack class name strings.
    """
@abc.abstractmethod
def get_unique_converter_class_names(self) -> list[str]:
    """
    Return sorted unique converter class names used across all attack results.

    Extracts class_name values from the request_converter_identifiers array
    within the attack_identifier JSON column via a database-level query.

    Returns:
        list[str]: Sorted list of unique converter class name strings.
    """
@abc.abstractmethod
def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, "ConversationStats"]:
    """
    Return lightweight aggregate statistics for one or more conversations.

    Computes per-conversation message count (distinct sequence numbers),
    a truncated last-message preview, the first non-empty labels dict,
    and the earliest message timestamp using efficient SQL aggregation
    instead of loading full pieces.

    Args:
        conversation_ids: The conversation IDs to query.

    Returns:
        Mapping from conversation_id to ConversationStats.
        Conversations with no pieces are omitted from the result.
    """
@abc.abstractmethod
def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any:
    """
    Return a database-specific condition for filtering ScenarioResults by labels.

    Args:
        labels: Dictionary of labels that must ALL be present (AND logic).

    Returns:
        Database-specific SQLAlchemy condition.
    """
@abc.abstractmethod
def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any:
    """
    Return a database-specific condition for filtering ScenarioResults by target endpoint.

    Args:
        endpoint: Endpoint substring to search for (case-insensitive).

    Returns:
        Database-specific SQLAlchemy condition.
    """
@abc.abstractmethod
def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any:
    """
    Return a database-specific condition for filtering ScenarioResults by target model name.

    Args:
        model_name: Model name substring to search for (case-insensitive).

    Returns:
        Database-specific SQLAlchemy condition.
    """
def add_scores_to_memory(self, *, scores: Sequence[Score]) -> None:
"""
Insert a list of scores into the memory storage.
"""
for score in scores:
if score.message_piece_id:
message_piece_id = score.message_piece_id
pieces = self.get_message_pieces(prompt_ids=[str(message_piece_id)])
if not pieces:
logger.error(f"MessagePiece with ID {message_piece_id} not found in memory.")
continue
# auto-link score to the original prompt id if the prompt is a duplicate
if pieces[0].original_prompt_id != pieces[0].id:
score.message_piece_id = pieces[0].original_prompt_id
self._insert_entries(entries=[ScoreEntry(entry=score) for score in scores])
def get_scores(
self,
*,
score_ids: Optional[Sequence[str]] = None,
score_type: Optional[str] = None,
score_category: Optional[str] = None,
sent_after: Optional[datetime] = None,
sent_before: Optional[datetime] = None,
) -> Sequence[Score]:
"""
Retrieve a list of Score objects based on the specified filters.
Args:
score_ids (Optional[Sequence[str]]): A list of score IDs to filter by.
score_type (Optional[str]): The type of the score to filter by.
score_category (Optional[str]): The category of the score to filter by.
sent_after (Optional[datetime]): Filter for scores sent after this datetime.
sent_before (Optional[datetime]): Filter for scores sent before this datetime.
Returns:
Sequence[Score]: A list of Score objects that match the specified filters.
"""
conditions: list[Any] = []
if score_ids:
conditions.append(ScoreEntry.id.in_(score_ids))
if score_type:
conditions.append(ScoreEntry.score_type == score_type)
if score_category:
conditions.append(ScoreEntry.score_category == score_category)
if sent_after:
conditions.append(ScoreEntry.timestamp >= sent_after)
if sent_before:
conditions.append(ScoreEntry.timestamp <= sent_before)
if not conditions:
return []
entries: Sequence[ScoreEntry] = self._query_entries(ScoreEntry, conditions=and_(*conditions))
return [entry.get_score() for entry in entries]
def get_prompt_scores(
self,
*,
attack_id: Optional[str | uuid.UUID] = None,
role: Optional[str] = None,
conversation_id: Optional[str | uuid.UUID] = None,
prompt_ids: Optional[Sequence[str | uuid.UUID]] = None,
labels: Optional[dict[str, str]] = None,
prompt_metadata: Optional[dict[str, Union[str, int]]] = None,
sent_after: Optional[datetime] = None,
sent_before: Optional[datetime] = None,
original_values: Optional[Sequence[str]] = None,
converted_values: Optional[Sequence[str]] = None,
data_type: Optional[str] = None,
not_data_type: Optional[str] = None,
converted_value_sha256: Optional[Sequence[str]] = None,
) -> Sequence[Score]:
"""
Retrieve scores attached to message pieces based on the specified filters.
Args:
attack_id (Optional[str | uuid.UUID], optional): The ID of the attack. Defaults to None.
role (Optional[str], optional): The role of the prompt. Defaults to None.
conversation_id (Optional[str | uuid.UUID], optional): The ID of the conversation. Defaults to None.
prompt_ids (Optional[Sequence[str] | Sequence[uuid.UUID]], optional): A list of prompt IDs.
Defaults to None.
labels (Optional[dict[str, str]], optional): A dictionary of labels. Defaults to None.
prompt_metadata (Optional[dict[str, Union[str, int]]], optional): The metadata associated with the prompt.
Defaults to None.
sent_after (Optional[datetime], optional): Filter for prompts sent after this datetime. Defaults to None.
sent_before (Optional[datetime], optional): Filter for prompts sent before this datetime. Defaults to None.
original_values (Optional[Sequence[str]], optional): A list of original values. Defaults to None.
converted_values (Optional[Sequence[str]], optional): A list of converted values. Defaults to None.
data_type (Optional[str], optional): The data type to filter by. Defaults to None.
not_data_type (Optional[str], optional): The data type to exclude. Defaults to None.
converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values.
Defaults to None.
Returns:
Sequence[Score]: A list of scores extracted from the message pieces.
"""
message_pieces = self.get_message_pieces(
attack_id=attack_id,
role=role,
conversation_id=conversation_id,
prompt_ids=prompt_ids,
labels=labels,
prompt_metadata=prompt_metadata,
sent_after=sent_after,
sent_before=sent_before,
original_values=original_values,
converted_values=converted_values,
data_type=data_type,
not_data_type=not_data_type,
converted_value_sha256=converted_value_sha256,
)
# Deduplicate message pieces by original_prompt_id to avoid duplicate scores
# since duplicated pieces share scores with their originals
seen_original_ids = set()
unique_pieces = []
for piece in message_pieces:
if piece.original_prompt_id not in seen_original_ids:
seen_original_ids.add(piece.original_prompt_id)
unique_pieces.append(piece)
scores = []
for piece in unique_pieces:
if piece.scores:
scores.extend(piece.scores)
return list(scores)
def get_conversation(self, *, conversation_id: str) -> MutableSequence[Message]:
    """
    Retrieve all Message objects belonging to a conversation, grouped by sequence.

    Args:
        conversation_id (str): The conversation ID to match.

    Returns:
        MutableSequence[Message]: Messages reassembled from the stored pieces
        for the specified conversation ID.
    """
    pieces = self.get_message_pieces(conversation_id=conversation_id)
    return group_conversation_message_pieces_by_sequence(message_pieces=pieces)
def get_request_from_response(self, *, response: Message) -> Message:
"""
Retrieve the request that produced the given response.
Args:
response (Message): The response message object to match.
Returns:
Message: The corresponding message object.
Raises:
ValueError: If the response is not from an assistant role or has no preceding request.
"""
if response.api_role != "assistant":
raise ValueError("The provided request is not a response (role must be 'assistant').")
if response.sequence < 1:
raise ValueError("The provided request does not have a preceding request (sequence < 1).")
conversation = self.get_conversation(conversation_id=response.conversation_id)
return conversation[response.sequence - 1]
def get_message_pieces(
    self,
    *,
    attack_id: Optional[str | uuid.UUID] = None,
    role: Optional[str] = None,
    conversation_id: Optional[str | uuid.UUID] = None,
    prompt_ids: Optional[Sequence[str | uuid.UUID]] = None,
    labels: Optional[dict[str, str]] = None,
    prompt_metadata: Optional[dict[str, Union[str, int]]] = None,
    sent_after: Optional[datetime] = None,
    sent_before: Optional[datetime] = None,
    original_values: Optional[Sequence[str]] = None,
    converted_values: Optional[Sequence[str]] = None,
    data_type: Optional[str] = None,
    not_data_type: Optional[str] = None,
    converted_value_sha256: Optional[Sequence[str]] = None,
) -> Sequence[MessagePiece]:
    """
    Retrieve a list of MessagePiece objects based on the specified filters.

    All supplied filters are combined with AND.

    Args:
        attack_id (Optional[str | uuid.UUID], optional): The ID of the attack. Defaults to None.
        role (Optional[str], optional): The role of the prompt. Defaults to None.
        conversation_id (Optional[str | uuid.UUID], optional): The ID of the conversation. Defaults to None.
        prompt_ids (Optional[Sequence[str] | Sequence[uuid.UUID]], optional): A list of prompt IDs.
            Defaults to None. An explicitly empty list short-circuits to an empty result.
        labels (Optional[dict[str, str]], optional): A dictionary of labels. Defaults to None.
        prompt_metadata (Optional[dict[str, Union[str, int]]], optional): The metadata associated with the prompt.
            Defaults to None.
        sent_after (Optional[datetime], optional): Filter for prompts sent after this datetime. Defaults to None.
        sent_before (Optional[datetime], optional): Filter for prompts sent before this datetime. Defaults to None.
        original_values (Optional[Sequence[str]], optional): A list of original values. Defaults to None.
        converted_values (Optional[Sequence[str]], optional): A list of converted values. Defaults to None.
        data_type (Optional[str], optional): The data type to filter by. Defaults to None.
        not_data_type (Optional[str], optional): The data type to exclude. Defaults to None.
        converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values.
            Defaults to None.

    Returns:
        Sequence[MessagePiece]: A sorted list of MessagePiece objects that match the specified filters.

    Raises:
        Exception: If there is an error retrieving the prompts, the exception
            is logged and re-raised.
    """
    # An explicit empty ID list can never match anything; skip the query.
    if prompt_ids is not None and len(prompt_ids) == 0:
        return []
    conditions = []
    if attack_id:
        conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id)))
    if role:
        conditions.append(PromptMemoryEntry.role == role)
    if conversation_id:
        conditions.append(PromptMemoryEntry.conversation_id == str(conversation_id))
    if prompt_ids:
        # Normalize UUIDs to strings before building the IN clause.
        prompt_ids = [str(pi) for pi in prompt_ids]
        conditions.append(PromptMemoryEntry.id.in_(prompt_ids))
    if labels:
        # Backend-specific: each label pair becomes its own condition.
        conditions.extend(self._get_message_pieces_memory_label_conditions(memory_labels=labels))
    if prompt_metadata:
        conditions.extend(self._get_message_pieces_prompt_metadata_conditions(prompt_metadata=prompt_metadata))
    if sent_after:
        conditions.append(PromptMemoryEntry.timestamp >= sent_after)
    if sent_before:
        conditions.append(PromptMemoryEntry.timestamp <= sent_before)
    if original_values:
        conditions.append(PromptMemoryEntry.original_value.in_(original_values))
    if converted_values:
        conditions.append(PromptMemoryEntry.converted_value.in_(converted_values))
    if data_type:
        conditions.append(PromptMemoryEntry.converted_value_data_type == data_type)
    if not_data_type:
        conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type)
    if converted_value_sha256:
        conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256))
    try:
        # No conditions means no WHERE clause: return every stored piece.
        memory_entries: Sequence[PromptMemoryEntry] = self._query_entries(
            PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True
        )
        message_pieces = [memory_entry.get_message_piece() for memory_entry in memory_entries]
        return sort_message_pieces(message_pieces=message_pieces)
    except Exception as e:
        logger.exception(f"Failed to retrieve prompts with error {e}")
        raise
def duplicate_messages(self, *, messages: Sequence[Message]) -> tuple[str, Sequence[MessagePiece]]:
"""
Duplicate messages with a new conversation ID.
Each duplicated piece gets a fresh ``id`` and ``timestamp`` while
preserving ``original_prompt_id`` for tracking lineage.
Args:
messages: The messages to duplicate.
Returns:
Tuple of (new_conversation_id, duplicated_message_pieces).
"""
new_conversation_id = str(uuid.uuid4())
all_pieces: list[MessagePiece] = []
for message in messages:
duplicated_message = message.duplicate_message()
for piece in duplicated_message.message_pieces:
piece.conversation_id = new_conversation_id
all_pieces.extend(duplicated_message.message_pieces)
return new_conversation_id, all_pieces
def duplicate_conversation(self, *, conversation_id: str) -> str:
"""
Duplicate a conversation for reuse.
This can be useful when an attack strategy requires branching out from a particular point in the conversation.
One cannot continue both branches with the same conversation ID since that would corrupt
the memory. Instead, one needs to duplicate the conversation and continue with the new conversation ID.
Args:
conversation_id (str): The conversation ID with existing conversations.
Returns:
The uuid for the new conversation.
"""
messages = self.get_conversation(conversation_id=conversation_id)
new_conversation_id, all_pieces = self.duplicate_messages(messages=messages)
self.add_message_pieces_to_memory(message_pieces=all_pieces)
return new_conversation_id
def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> str:
"""
Duplicate a conversation, excluding the last turn. In this case, last turn is defined as before the last
user request (e.g. if there is half a turn, it just removes that half).
This can be useful when an attack strategy requires back tracking the last prompt/response pair.
Args:
conversation_id (str): The conversation ID with existing conversations.
Returns:
The uuid for the new conversation.
"""
messages = self.get_conversation(conversation_id=conversation_id)
# remove the final turn from the conversation
if len(messages) == 0:
return str(uuid.uuid4())
last_message = messages[-1]
length_of_sequence_to_remove = 0
length_of_sequence_to_remove = 1 if last_message.api_role == "system" or last_message.api_role == "user" else 2
messages_to_duplicate = [
message for message in messages if message.sequence <= last_message.sequence - length_of_sequence_to_remove
]
new_conversation_id, all_pieces = self.duplicate_messages(messages=messages_to_duplicate)
self.add_message_pieces_to_memory(message_pieces=all_pieces)
return new_conversation_id
def add_message_to_memory(self, *, request: Message) -> None:
"""
Insert a list of message pieces into the memory storage.
Automatically updates the sequence to be the next number in the conversation.
If necessary, generates embedding data for applicable entries
Args:
request (MessagePiece): The message piece to add to the memory.
"""
request.validate()
embedding_entries = []
message_pieces = request.message_pieces
self._update_sequence(message_pieces=message_pieces)
self.add_message_pieces_to_memory(message_pieces=message_pieces)
if self.memory_embedding:
for piece in message_pieces:
embedding_entry = self.memory_embedding.generate_embedding_memory_data(message_piece=piece)
embedding_entries.append(embedding_entry)
self._add_embeddings_to_memory(embedding_data=embedding_entries)
def _update_sequence(self, *, message_pieces: Sequence[MessagePiece]) -> None:
"""
Update the sequence number of the message pieces in the conversation.
Args:
message_pieces (Sequence[MessagePiece]): The list of message pieces to update.
"""
prev_conversations = self.get_message_pieces(conversation_id=message_pieces[0].conversation_id)
sequence = 0
if len(prev_conversations) > 0:
sequence = max(prev_conversations, key=lambda item: item.sequence).sequence + 1
for piece in message_pieces:
piece.sequence = sequence
def update_prompt_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict[str, Any]) -> bool:
"""
Update prompt entries for a given conversation ID with the specified field values.
Args:
conversation_id (str): The conversation ID of the entries to be updated.
update_fields (dict): A dictionary of field names and their new values (ex. {"labels": {"test": "value"}})
Returns:
bool: True if the update was successful, False otherwise.
Raises:
ValueError: If update_fields is empty or not provided.
"""
if not update_fields:
raise ValueError("update_fields must be provided to update prompt entries.")
# Fetch the relevant entries using query_entries
entries_to_update: MutableSequence[Base] = self._query_entries(
PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id
)
# Check if there are entries to update
if not entries_to_update:
logger.info(f"No entries found with conversation_id {conversation_id} to update.")
return False
# Use the utility function to update the entries
success = self._update_entries(entries=entries_to_update, update_fields=update_fields)
if success:
logger.info(f"Updated {len(entries_to_update)} entries with conversation_id {conversation_id}.")
else:
logger.error(f"Failed to update entries with conversation_id {conversation_id}.")
return success
def update_labels_by_conversation_id(self, *, conversation_id: str, labels: dict[str, Any]) -> bool:
"""
Update the labels of prompt entries in memory for a given conversation ID.
Args:
conversation_id (str): The conversation ID of the entries to be updated.
labels (dict): New dictionary of labels.
Returns:
bool: True if the update was successful, False otherwise.
"""
return self.update_prompt_entries_by_conversation_id(
conversation_id=conversation_id, update_fields={"labels": labels}
)
def update_prompt_metadata_by_conversation_id(
self, *, conversation_id: str, prompt_metadata: dict[str, Union[str, int]]
) -> bool:
"""
Update the metadata of prompt entries in memory for a given conversation ID.
Args:
conversation_id (str): The conversation ID of the entries to be updated.
prompt_metadata (dict[str, str | int]): New metadata.
Returns:
bool: True if the update was successful, False otherwise.
"""
return self.update_prompt_entries_by_conversation_id(
conversation_id=conversation_id, update_fields={"prompt_metadata": prompt_metadata}
)
@abc.abstractmethod
def dispose_engine(self) -> None:
    """
    Dispose the engine and clean up resources.

    Concrete subclasses must implement this to release their underlying
    storage resources. It is registered by cleanup() to run at process
    exit (via atexit and a weakref finalizer), so implementations should
    be safe to call during interpreter shutdown.
    """
def cleanup(self) -> None:
    """
    Register dispose_engine to run automatically at process exit.

    Registration is idempotent: repeated calls on the same instance no longer
    stack duplicate atexit/finalizer callbacks, so dispose_engine is only
    registered once per instance.
    """
    # Guard: without this, every call appended another atexit handler and
    # another finalizer for the same instance.
    if getattr(self, "_cleanup_registered", False):
        return
    # Ensure cleanup at process exit.
    atexit.register(self.dispose_engine)
    # Ensure cleanup happens even if the object is garbage collected before
    # the process exits.
    # NOTE(review): the bound-method callback holds a strong reference to
    # self, so this finalizer keeps the instance alive and in practice fires
    # at interpreter shutdown rather than on garbage collection — confirm
    # whether earlier disposal is intended.
    weakref.finalize(self, self.dispose_engine)
    self._cleanup_registered = True
def get_seeds(
    self,
    *,
    value: Optional[str] = None,
    value_sha256: Optional[Sequence[str]] = None,
    dataset_name: Optional[str] = None,
    dataset_name_pattern: Optional[str] = None,
    data_types: Optional[Sequence[str]] = None,
    harm_categories: Optional[Sequence[str]] = None,
    added_by: Optional[str] = None,
    authors: Optional[Sequence[str]] = None,
    groups: Optional[Sequence[str]] = None,
    source: Optional[str] = None,
    seed_type: Optional[SeedType] = None,
    is_objective: Optional[bool] = None,  # Deprecated in 0.13.0: Use seed_type instead
    parameters: Optional[Sequence[str]] = None,
    metadata: Optional[dict[str, Union[str, int]]] = None,
    prompt_group_ids: Optional[Sequence[uuid.UUID]] = None,
) -> Sequence[Seed]:
    """
    Retrieve a list of seed prompts based on the specified filters.

    Args:
        value (str): The value to match by substring. If None, all values are returned.
        value_sha256 (Sequence[str]): SHA256 hashes of values to match exactly. If None, all values
            are returned.
        dataset_name (str): The dataset name to match exactly. If None, all dataset names are considered.
        dataset_name_pattern (str): A pattern to match dataset names using SQL LIKE syntax.
            Supports wildcards: % (any characters) and _ (single character).
            Examples: "harm%" matches names starting with "harm", "%test%" matches names containing "test".
            If both dataset_name and dataset_name_pattern are provided, dataset_name takes precedence.
        data_types (Sequence[str]): List of data types to filter seed prompts by
            (e.g., text, image_path). If None, all data types are considered.
        harm_categories (Sequence[str]): A list of harm categories to filter by. If None,
            all harm categories are considered.
            Specifying multiple harm categories returns only prompts that are marked with all harm categories.
        added_by (str): The user who added the prompts.
        authors (Sequence[str]): A list of authors to filter by.
            Note that this filters by substring, so a query for "Adam Jones" may not return results if the record
            is "A. Jones", "Jones, Adam", etc. If None, all authors are considered.
        groups (Sequence[str]): A list of groups to filter by. If None, all groups are considered.
        source (str): The source to filter by. If None, all sources are considered.
        seed_type (SeedType): The type of seed to filter by ("prompt", "objective", or
            "simulated_conversation").
        is_objective (bool): Deprecated in 0.13.0. Use seed_type="objective" instead.
        parameters (Sequence[str]): A list of parameters to filter by. Specifying parameters effectively returns
            prompt templates instead of prompts.
        metadata (dict[str, str | int]): A free-form dictionary for tagging prompts with custom metadata.
        prompt_group_ids (Sequence[uuid.UUID]): A list of prompt group IDs to filter by.

    Returns:
        Sequence[Seed]: A list of seeds matching the criteria.

    Raises:
        ValueError: If both 'seed_type' and deprecated 'is_objective' parameters are specified.
    """
    # Handle deprecated is_objective parameter.
    if is_objective is not None:
        if seed_type is not None:
            raise ValueError(
                "Cannot specify both 'seed_type' and 'is_objective'. "
                "is_objective is deprecated since 0.13.0. Use seed_type='objective' instead."
            )
        warnings.warn(
            "is_objective parameter is deprecated since 0.13.0. Use seed_type='objective' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Convert is_objective to seed_type.
        seed_type = "objective" if is_objective else "prompt"
    conditions = []
    # Apply filters for non-list fields.
    if value:
        conditions.append(SeedEntry.value.contains(value))
    if value_sha256:
        conditions.append(SeedEntry.value_sha256.in_(value_sha256))
    # Exact dataset_name match takes precedence over the LIKE pattern.
    if dataset_name:
        conditions.append(SeedEntry.dataset_name == dataset_name)
    elif dataset_name_pattern:
        conditions.append(SeedEntry.dataset_name.like(dataset_name_pattern))
    if prompt_group_ids:
        conditions.append(SeedEntry.prompt_group_id.in_(prompt_group_ids))
    if data_types:
        data_type_conditions = SeedEntry.data_type.in_(data_types)
        conditions.append(data_type_conditions)
    if added_by:
        conditions.append(SeedEntry.added_by == added_by)
    if source:
        conditions.append(SeedEntry.source == source)
    # Handle seed_type filtering with backward compatibility for is_objective.
    if seed_type == "objective":
        # Match either seed_type="objective" OR legacy is_objective=True
        conditions.append(or_(SeedEntry.seed_type == "objective", SeedEntry.is_objective == True))  # noqa: E712
    elif seed_type is not None:
        conditions.append(SeedEntry.seed_type == seed_type)
    self._add_list_conditions(field=SeedEntry.harm_categories, values=harm_categories, conditions=conditions)
    self._add_list_conditions(field=SeedEntry.authors, values=authors, conditions=conditions)
    self._add_list_conditions(field=SeedEntry.groups, values=groups, conditions=conditions)
    # _add_list_conditions is a no-op for empty/None values, so no outer guard is needed.
    self._add_list_conditions(field=SeedEntry.parameters, values=parameters, conditions=conditions)
    if metadata:
        conditions.append(self._get_seed_metadata_conditions(metadata=metadata))
    try:
        memory_entries: Sequence[SeedEntry] = self._query_entries(
            SeedEntry,
            conditions=and_(*conditions) if conditions else None,
        )
        return [memory_entry.get_seed() for memory_entry in memory_entries]
    except Exception as e:
        logger.exception(f"Failed to retrieve prompts with dataset name {dataset_name} with error {e}")
        raise
def _add_list_conditions(
self, field: InstrumentedAttribute[Any], conditions: list[Any], values: Optional[Sequence[str]] = None
) -> None:
if values:
conditions.extend(field.contains(value) for value in values)
async def _serialize_seed_value(self, prompt: Seed) -> str:
"""
Serialize the value of a seed prompt based on its data type.
Args:
prompt (Seed): The seed prompt to serialize. Must have a valid `data_type`.
Returns:
str: The serialized value for the prompt.
Raises:
ValueError: If the `data_type` of the prompt is unsupported.
"""
extension = DataTypeSerializer.get_extension(prompt.value)
if extension:
extension = extension.lstrip(".")
serializer = data_serializer_factory(
category="seed-prompt-entries", data_type=prompt.data_type, value=prompt.value, extension=extension