Skip to content

Commit c753276

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add metadata to memories
feat: Support metadata filtering for memory retrieval feat: Support metadata merge strategies for memory generation PiperOrigin-RevId: 854287666
1 parent b814aab commit c753276

6 files changed

Lines changed: 432 additions & 2 deletions

File tree

tests/unit/vertexai/genai/replays/test_create_agent_engine_memory.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,28 @@ def test_create_memory_with_ttl(client):
2525
assert isinstance(agent_engine, types.AgentEngine)
2626
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
2727

28+
metadata = {
29+
"my_string_key": types.MemoryMetadataValue(
30+
string_value="my_string_value"
31+
),
32+
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
33+
"my_boolean_key": types.MemoryMetadataValue(bool_value=True),
34+
"my_timestamp_key": types.MemoryMetadataValue(
35+
timestamp_value=datetime.datetime(
36+
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
37+
)
38+
),
39+
}
40+
2841
operation = client.agent_engines.memories.create(
2942
name=agent_engine.api_resource.name,
3043
fact="memory_fact",
3144
scope={"user_id": "123"},
32-
config=types.AgentEngineMemoryConfig(display_name="my_memory_fact", ttl="120s"),
45+
config=types.AgentEngineMemoryConfig(
46+
display_name="my_memory_fact",
47+
ttl="120s",
48+
metadata=metadata,
49+
),
3350
)
3451
assert isinstance(operation, types.AgentEngineMemoryOperation)
3552
assert operation.response.fact == "memory_fact"
@@ -42,6 +59,7 @@ def test_create_memory_with_ttl(client):
4259
<= operation.response.expire_time
4360
<= operation.response.create_time + datetime.timedelta(seconds=120.5)
4461
)
62+
assert operation.response.metadata == metadata
4563
# Clean up resources.
4664
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
4765

@@ -51,7 +69,7 @@ def test_create_memory_with_expire_time(client):
5169
assert isinstance(agent_engine, types.AgentEngine)
5270
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
5371
expire_time = datetime.datetime(
54-
2026, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
72+
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
5573
)
5674

5775
operation = client.agent_engines.memories.create(

tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,138 @@ def test_generate_memories_direct_memories_source(client):
145145
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
146146

147147

148+
def test_generate_memories_with_metadata(client):
149+
agent_engine = client.agent_engines.create()
150+
metadata = {
151+
"my_string_key": types.MemoryMetadataValue(
152+
string_value="my_string_value"
153+
),
154+
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
155+
"my_boolean_key": types.MemoryMetadataValue(bool_value=True),
156+
"my_timestamp_key": types.MemoryMetadataValue(
157+
timestamp_value=datetime.datetime(
158+
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
159+
)
160+
),
161+
}
162+
# Reuse the same content and scope for all generation requests to ensure
163+
# that the same memory is updated.
164+
direct_memories_source = types.GenerateMemoriesRequestDirectMemoriesSource(
165+
direct_memories=[
166+
types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(
167+
fact="I am a software engineer."
168+
),
169+
]
170+
)
171+
scope = {"user_id": "test-user-id"}
172+
173+
operation = client.agent_engines.memories.generate(
174+
name=agent_engine.api_resource.name,
175+
scope=scope,
176+
direct_memories_source=direct_memories_source,
177+
config=types.GenerateAgentEngineMemoriesConfig(
178+
metadata=metadata
179+
),
180+
)
181+
assert len(operation.response.generated_memories) >= 1
182+
memory = client.agent_engines.memories.get(
183+
name=operation.response.generated_memories[0].memory.name
184+
)
185+
assert memory.metadata == metadata
186+
187+
# Overwrite the metadata.
188+
overwrite_metadata = {
189+
"my_string_key": types.MemoryMetadataValue(string_value="new_value"),
190+
}
191+
operation = client.agent_engines.memories.generate(
192+
name=agent_engine.api_resource.name,
193+
scope=scope,
194+
direct_memories_source=direct_memories_source,
195+
config=types.GenerateAgentEngineMemoriesConfig(
196+
metadata=overwrite_metadata,
197+
metadata_merge_strategy=types.MemoryMetadataMergeStrategy.OVERWRITE,
198+
),
199+
)
200+
assert len(operation.response.generated_memories) >= 1
201+
assert (
202+
operation.response.generated_memories[0].action
203+
== types.GenerateMemoriesResponseGeneratedMemoryAction.UPDATED
204+
)
205+
memory = client.agent_engines.memories.get(
206+
name=operation.response.generated_memories[0].memory.name
207+
)
208+
assert memory.metadata == overwrite_metadata
209+
210+
# Merge the metadata.
211+
new_metadata = {
212+
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
213+
}
214+
operation = client.agent_engines.memories.generate(
215+
name=agent_engine.api_resource.name,
216+
scope=scope,
217+
direct_memories_source=direct_memories_source,
218+
config=types.GenerateAgentEngineMemoriesConfig(
219+
metadata=new_metadata,
220+
metadata_merge_strategy=types.MemoryMetadataMergeStrategy.MERGE,
221+
),
222+
)
223+
assert len(operation.response.generated_memories) >= 1
224+
assert (
225+
operation.response.generated_memories[0].action
226+
== types.GenerateMemoriesResponseGeneratedMemoryAction.UPDATED
227+
)
228+
memory = client.agent_engines.memories.get(
229+
name=operation.response.generated_memories[0].memory.name
230+
)
231+
assert memory.metadata == {**overwrite_metadata, **new_metadata}
232+
233+
# Restrict consolidation based on metadata values. For the first request,
234+
# there's no existing memories that match the metadata, so a new memory is
235+
# created.
236+
restricted_metadata = {
237+
"my_string_key": types.MemoryMetadataValue(string_value="new_value2"),
238+
}
239+
operation = client.agent_engines.memories.generate(
240+
name=agent_engine.api_resource.name,
241+
scope=scope,
242+
direct_memories_source=direct_memories_source,
243+
config=types.GenerateAgentEngineMemoriesConfig(
244+
metadata=restricted_metadata,
245+
metadata_merge_strategy="REQUIRE_EXACT_MATCH",
246+
),
247+
)
248+
assert len(operation.response.generated_memories) == 1
249+
# Metadata doesn't match existing memory, so a new memory is created.
250+
assert (
251+
operation.response.generated_memories[0].action
252+
== types.GenerateMemoriesResponseGeneratedMemoryAction.CREATED
253+
)
254+
memory = client.agent_engines.memories.get(
255+
name=operation.response.generated_memories[0].memory.name
256+
)
257+
assert memory.metadata == restricted_metadata
258+
259+
# Send a second request where the metadata matches only one of the existing
260+
# memories.
261+
operation = client.agent_engines.memories.generate(
262+
name=agent_engine.api_resource.name,
263+
scope=scope,
264+
direct_memories_source=direct_memories_source,
265+
config=types.GenerateAgentEngineMemoriesConfig(
266+
metadata=restricted_metadata,
267+
metadata_merge_strategy="REQUIRE_EXACT_MATCH",
268+
),
269+
)
270+
assert len(operation.response.generated_memories) == 1
271+
assert (
272+
operation.response.generated_memories[0].action
273+
== types.GenerateMemoriesResponseGeneratedMemoryAction.UPDATED
274+
)
275+
assert operation.response.generated_memories[0].memory.name == memory.name
276+
277+
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
278+
279+
148280
pytestmark = pytest_helper.setup(
149281
file=__file__,
150282
globals_for_file=globals(),

tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

17+
import datetime
1718
import pytest
1819

1920

@@ -115,6 +116,57 @@ def test_retrieve_memories_with_simple_retrieval_params(client):
115116
agent_engine.delete(force=True)
116117

117118

119+
def test_retrieve_memories_with_metadata(client):
120+
agent_engine = client.agent_engines.create()
121+
metadata = {
122+
"my_string_key": types.MemoryMetadataValue(
123+
string_value="my_string_value"
124+
),
125+
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
126+
"my_boolean_key": types.MemoryMetadataValue(bool_value=True),
127+
"my_timestamp_key": types.MemoryMetadataValue(
128+
timestamp_value=datetime.datetime(
129+
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
130+
)
131+
),
132+
}
133+
scope = {"user_id": "123"}
134+
client.agent_engines.memories.create(
135+
name=agent_engine.api_resource.name,
136+
fact="memory_fact_1",
137+
scope=scope,
138+
)
139+
operation = client.agent_engines.memories.create(
140+
name=agent_engine.api_resource.name,
141+
fact="memory_fact_2",
142+
scope=scope,
143+
config={"metadata": metadata},
144+
)
145+
memory_name2 = operation.response.name
146+
147+
results = client.agent_engines.memories.retrieve(
148+
name=agent_engine.api_resource.name,
149+
scope=scope,
150+
config={
151+
"filter_groups": [
152+
{
153+
"filters": [
154+
{
155+
"key": "my_string_key",
156+
"value": {"string_value": "my_string_value"}
157+
}
158+
]
159+
}
160+
],
161+
},
162+
)
163+
assert len(results) == 1
164+
assert results[0].memory.name == memory_name2
165+
166+
# Clean up resources.
167+
agent_engine.delete(force=True)
168+
169+
118170
pytestmark = pytest_helper.setup(
119171
file=__file__,
120172
globals_for_file=globals(),

vertexai/_genai/memories.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def _AgentEngineMemoryConfig_to_vertex(
7777
parent_object, ["topics"], [item for item in getv(from_object, ["topics"])]
7878
)
7979

80+
if getv(from_object, ["metadata"]) is not None:
81+
setv(parent_object, ["metadata"], getv(from_object, ["metadata"]))
82+
8083
return to_object
8184

8285

@@ -153,6 +156,16 @@ def _GenerateAgentEngineMemoriesConfig_to_vertex(
153156
getv(from_object, ["disable_memory_revisions"]),
154157
)
155158

159+
if getv(from_object, ["metadata"]) is not None:
160+
setv(parent_object, ["metadata"], getv(from_object, ["metadata"]))
161+
162+
if getv(from_object, ["metadata_merge_strategy"]) is not None:
163+
setv(
164+
parent_object,
165+
["metadataMergeStrategy"],
166+
getv(from_object, ["metadata_merge_strategy"]),
167+
)
168+
156169
return to_object
157170

158171

@@ -316,6 +329,13 @@ def _RetrieveAgentEngineMemoriesConfig_to_vertex(
316329
if getv(from_object, ["filter"]) is not None:
317330
setv(parent_object, ["filter"], getv(from_object, ["filter"]))
318331

332+
if getv(from_object, ["filter_groups"]) is not None:
333+
setv(
334+
parent_object,
335+
["filterGroups"],
336+
[item for item in getv(from_object, ["filter_groups"])],
337+
)
338+
319339
return to_object
320340

321341

@@ -413,6 +433,9 @@ def _UpdateAgentEngineMemoryConfig_to_vertex(
413433
parent_object, ["topics"], [item for item in getv(from_object, ["topics"])]
414434
)
415435

436+
if getv(from_object, ["metadata"]) is not None:
437+
setv(parent_object, ["metadata"], getv(from_object, ["metadata"]))
438+
416439
if getv(from_object, ["update_mask"]) is not None:
417440
setv(
418441
parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"])

vertexai/_genai/types/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,17 @@
578578
from .common import MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicOrDict
579579
from .common import MemoryBankCustomizationConfigMemoryTopicOrDict
580580
from .common import MemoryBankCustomizationConfigOrDict
581+
from .common import MemoryConjunctionFilter
582+
from .common import MemoryConjunctionFilterDict
583+
from .common import MemoryConjunctionFilterOrDict
581584
from .common import MemoryDict
585+
from .common import MemoryFilter
586+
from .common import MemoryFilterDict
587+
from .common import MemoryFilterOrDict
588+
from .common import MemoryMetadataMergeStrategy
589+
from .common import MemoryMetadataValue
590+
from .common import MemoryMetadataValueDict
591+
from .common import MemoryMetadataValueOrDict
582592
from .common import MemoryOrDict
583593
from .common import MemoryRevision
584594
from .common import MemoryRevisionDict
@@ -613,6 +623,7 @@
613623
from .common import ObservabilityEvalCase
614624
from .common import ObservabilityEvalCaseDict
615625
from .common import ObservabilityEvalCaseOrDict
626+
from .common import Operator
616627
from .common import OptimizeConfig
617628
from .common import OptimizeConfigDict
618629
from .common import OptimizeConfigOrDict
@@ -1523,6 +1534,15 @@
15231534
"RetrieveMemoriesRequestSimpleRetrievalParams",
15241535
"RetrieveMemoriesRequestSimpleRetrievalParamsDict",
15251536
"RetrieveMemoriesRequestSimpleRetrievalParamsOrDict",
1537+
"MemoryMetadataValue",
1538+
"MemoryMetadataValueDict",
1539+
"MemoryMetadataValueOrDict",
1540+
"MemoryFilter",
1541+
"MemoryFilterDict",
1542+
"MemoryFilterOrDict",
1543+
"MemoryConjunctionFilter",
1544+
"MemoryConjunctionFilterDict",
1545+
"MemoryConjunctionFilterOrDict",
15261546
"RetrieveAgentEngineMemoriesConfig",
15271547
"RetrieveAgentEngineMemoriesConfigDict",
15281548
"RetrieveAgentEngineMemoriesConfigOrDict",
@@ -1909,6 +1929,7 @@
19091929
"IdentityType",
19101930
"AgentServerMode",
19111931
"ManagedTopicEnum",
1932+
"Operator",
19121933
"Language",
19131934
"MachineConfig",
19141935
"State",
@@ -1917,6 +1938,7 @@
19171938
"RubricContentType",
19181939
"EvaluationRunState",
19191940
"OptimizeTarget",
1941+
"MemoryMetadataMergeStrategy",
19201942
"GenerateMemoriesResponseGeneratedMemoryAction",
19211943
"PromptOptimizerMethod",
19221944
"PromptData",

0 commit comments

Comments
 (0)