Skip to content

Commit 811e50a

Browse files
GWealecopybara-github
authored andcommitted
feat: add generate/create modes for Vertex AI Memory Bank writes
add_events_to_memory now supports memory_write_mode to select generate (event-based extraction/consolidation) or create (direct raw fact writes via memory_facts). This now lets custom memory pipelines while keeping generate as the default path Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 869897256
1 parent d5332f4 commit 811e50a

6 files changed

Lines changed: 468 additions & 4 deletions

File tree

src/google/adk/agents/context.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,8 @@ async def add_events_to_memory(
345345
346346
Args:
347347
events: Explicit events to add to memory.
348-
custom_metadata: Optional standard metadata for memory generation.
348+
custom_metadata: Optional metadata forwarded to the configured memory
349+
service. Supported keys are implementation-specific.
349350
350351
Raises:
351352
ValueError: If memory service is not available.
@@ -362,6 +363,33 @@ async def add_events_to_memory(
362363
custom_metadata=custom_metadata,
363364
)
364365

366+
async def add_memory(
367+
self,
368+
*,
369+
memories: Sequence[str],
370+
custom_metadata: Mapping[str, object] | None = None,
371+
) -> None:
372+
"""Adds explicit memory items directly to the memory service.
373+
374+
Uses this callback's current session identifiers as memory scope.
375+
376+
Args:
377+
memories: Explicit memory items to add.
378+
custom_metadata: Optional metadata forwarded to the configured memory
379+
service. Supported keys are implementation-specific.
380+
381+
Raises:
382+
ValueError: If memory service is not available.
383+
"""
384+
if self._invocation_context.memory_service is None:
385+
raise ValueError("Cannot add memory: memory service is not available.")
386+
await self._invocation_context.memory_service.add_memory(
387+
app_name=self._invocation_context.session.app_name,
388+
user_id=self._invocation_context.session.user_id,
389+
memories=memories,
390+
custom_metadata=custom_metadata,
391+
)
392+
365393
async def search_memory(self, query: str) -> SearchMemoryResponse:
366394
"""Searches the memory of the current user.
367395

src/google/adk/memory/base_memory_service.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,40 @@ async def add_events_to_memory(
8686
session_id: Optional session ID for memory scope/partitioning.
8787
custom_metadata: Optional, portable metadata for memory generation. Prefer
8888
this for service-specific fields (e.g., TTL) that may later become
89-
first-class API parameters.
89+
first-class API parameters. Supported keys are
90+
implementation-defined by each memory service.
9091
"""
9192
raise NotImplementedError(
9293
"This memory service does not support adding event deltas. "
9394
"Call add_session_to_memory(session) to ingest the full session."
9495
)
9596

97+
async def add_memory(
98+
self,
99+
*,
100+
app_name: str,
101+
user_id: str,
102+
memories: Sequence[str],
103+
custom_metadata: Mapping[str, object] | None = None,
104+
) -> None:
105+
"""Adds explicit memory items directly to the memory service.
106+
107+
This is intended for services that support direct memory writes in addition
108+
to event-based memory generation.
109+
110+
Args:
111+
app_name: The application name for memory scope.
112+
user_id: The user ID for memory scope.
113+
memories: Explicit memory items to add.
114+
custom_metadata: Optional, portable metadata for memory writes. Supported
115+
keys are implementation-defined by each memory service.
116+
"""
117+
raise NotImplementedError(
118+
"This memory service does not support direct memory writes. "
119+
"Call add_events_to_memory(...) or add_session_to_memory(session) "
120+
"instead."
121+
)
122+
96123
@abstractmethod
97124
async def search_memory(
98125
self,

src/google/adk/memory/vertex_ai_memory_bank_service.py

Lines changed: 217 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from collections.abc import Mapping
1818
from collections.abc import Sequence
1919
from datetime import datetime
20+
from functools import lru_cache
2021
import logging
2122
from typing import Optional
2223
from typing import TYPE_CHECKING
@@ -37,7 +38,7 @@
3738

3839
logger = logging.getLogger('google_adk.' + __name__)
3940

40-
_GENERATE_MEMORIES_CONFIG_KEYS = frozenset({
41+
_GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS = frozenset({
4142
'disable_consolidation',
4243
'disable_memory_revisions',
4344
'http_options',
@@ -49,6 +50,20 @@
4950
'wait_for_completion',
5051
})
5152

53+
_CREATE_MEMORY_CONFIG_FALLBACK_KEYS = frozenset({
54+
'description',
55+
'disable_memory_revisions',
56+
'display_name',
57+
'expire_time',
58+
'http_options',
59+
'metadata',
60+
'revision_expire_time',
61+
'revision_ttl',
62+
'topics',
63+
'ttl',
64+
'wait_for_completion',
65+
})
66+
5267

5368
def _supports_generate_memories_metadata() -> bool:
5469
"""Returns whether installed Vertex SDK supports config.metadata."""
@@ -62,6 +77,61 @@ def _supports_generate_memories_metadata() -> bool:
6277
)
6378

6479

80+
def _supports_create_memory_metadata() -> bool:
81+
"""Returns whether installed Vertex SDK supports create config.metadata."""
82+
try:
83+
from vertexai._genai.types import common as vertex_common_types
84+
except ImportError:
85+
return False
86+
return 'metadata' in vertex_common_types.AgentEngineMemoryConfig.model_fields
87+
88+
89+
@lru_cache(maxsize=1)
90+
def _get_generate_memories_config_keys() -> frozenset[str]:
91+
"""Returns supported config keys for memories.generate.
92+
93+
Uses SDK runtime model fields when available and falls back to a static
94+
allowlist to preserve compatibility when introspection is unavailable.
95+
"""
96+
try:
97+
from vertexai._genai.types import common as vertex_common_types
98+
except ImportError:
99+
return _GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS
100+
101+
try:
102+
model_fields = (
103+
vertex_common_types.GenerateAgentEngineMemoriesConfig.model_fields
104+
)
105+
except AttributeError:
106+
return _GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS
107+
108+
if not isinstance(model_fields, Mapping):
109+
return _GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS
110+
return frozenset(model_fields.keys())
111+
112+
113+
@lru_cache(maxsize=1)
114+
def _get_create_memory_config_keys() -> frozenset[str]:
115+
"""Returns supported config keys for memories.create.
116+
117+
Uses SDK runtime model fields when available and falls back to a static
118+
allowlist to preserve compatibility when introspection is unavailable.
119+
"""
120+
try:
121+
from vertexai._genai.types import common as vertex_common_types
122+
except ImportError:
123+
return _CREATE_MEMORY_CONFIG_FALLBACK_KEYS
124+
125+
try:
126+
model_fields = vertex_common_types.AgentEngineMemoryConfig.model_fields
127+
except AttributeError:
128+
return _CREATE_MEMORY_CONFIG_FALLBACK_KEYS
129+
130+
if not isinstance(model_fields, Mapping):
131+
return _CREATE_MEMORY_CONFIG_FALLBACK_KEYS
132+
return frozenset(model_fields.keys())
133+
134+
65135
class VertexAiMemoryBankService(BaseMemoryService):
66136
"""Implementation of the BaseMemoryService using Vertex AI Memory Bank."""
67137

@@ -122,6 +192,15 @@ async def add_events_to_memory(
122192
session_id: str | None = None,
123193
custom_metadata: Mapping[str, object] | None = None,
124194
) -> None:
195+
"""Adds events to Vertex AI Memory Bank via memories.generate.
196+
197+
Args:
198+
app_name: The application name for memory scope.
199+
user_id: The user ID for memory scope.
200+
events: The events to process for memory generation.
201+
session_id: Optional session ID. Currently unused.
202+
custom_metadata: Optional service-specific metadata for generate config.
203+
"""
125204
_ = session_id
126205
await self._add_events_to_memory_from_events(
127206
app_name=app_name,
@@ -130,6 +209,23 @@ async def add_events_to_memory(
130209
custom_metadata=custom_metadata,
131210
)
132211

212+
@override
213+
async def add_memory(
214+
self,
215+
*,
216+
app_name: str,
217+
user_id: str,
218+
memories: Sequence[str],
219+
custom_metadata: Mapping[str, object] | None = None,
220+
) -> None:
221+
"""Adds explicit memory items via Vertex memories.create."""
222+
await self._add_memories_via_create(
223+
app_name=app_name,
224+
user_id=user_id,
225+
memories=memories,
226+
custom_metadata=custom_metadata,
227+
)
228+
133229
async def _add_events_to_memory_from_events(
134230
self,
135231
*,
@@ -166,6 +262,34 @@ async def _add_events_to_memory_from_events(
166262
else:
167263
logger.info('No events to add to memory.')
168264

265+
async def _add_memories_via_create(
266+
self,
267+
*,
268+
app_name: str,
269+
user_id: str,
270+
memories: Sequence[str],
271+
custom_metadata: Mapping[str, object] | None = None,
272+
) -> None:
273+
"""Adds direct memory items without server-side extraction."""
274+
if not self._agent_engine_id:
275+
raise ValueError('Agent Engine ID is required for Memory Bank.')
276+
277+
memory_texts = _validate_memory_texts(memories)
278+
api_client = self._get_api_client()
279+
config = _build_create_memory_config(custom_metadata)
280+
for memory_text in memory_texts:
281+
operation = await api_client.agent_engines.memories.create(
282+
name='reasoningEngines/' + self._agent_engine_id,
283+
fact=memory_text,
284+
scope={
285+
'app_name': app_name,
286+
'user_id': user_id,
287+
},
288+
config=config,
289+
)
290+
logger.info('Create memory response received.')
291+
logger.debug('Create memory response: %s', operation)
292+
169293
@override
170294
async def search_memory(self, *, app_name: str, user_id: str, query: str):
171295
if not self._agent_engine_id:
@@ -237,6 +361,7 @@ def _build_generate_memories_config(
237361
"""Builds a valid memories.generate config from caller metadata."""
238362
config: dict[str, object] = {'wait_for_completion': False}
239363
supports_metadata = _supports_generate_memories_metadata()
364+
config_keys = _get_generate_memories_config_keys()
240365
if not custom_metadata:
241366
return config
242367

@@ -267,7 +392,7 @@ def _build_generate_memories_config(
267392
' mapping.'
268393
)
269394
continue
270-
if key in _GENERATE_MEMORIES_CONFIG_KEYS:
395+
if key in config_keys:
271396
if value is None:
272397
continue
273398
config[key] = value
@@ -304,6 +429,96 @@ def _build_generate_memories_config(
304429
return config
305430

306431

432+
def _build_create_memory_config(
433+
custom_metadata: Mapping[str, object] | None,
434+
) -> dict[str, object]:
435+
"""Builds a valid memories.create config from caller metadata."""
436+
config: dict[str, object] = {'wait_for_completion': False}
437+
supports_metadata = _supports_create_memory_metadata()
438+
config_keys = _get_create_memory_config_keys()
439+
if not custom_metadata:
440+
return config
441+
442+
logger.debug('Memory creation metadata: %s', custom_metadata)
443+
444+
metadata_by_key: dict[str, object] = {}
445+
for key, value in custom_metadata.items():
446+
if key == 'metadata':
447+
if value is None:
448+
continue
449+
if not supports_metadata:
450+
logger.warning(
451+
'Ignoring metadata because installed Vertex SDK does not support'
452+
' create config.metadata.'
453+
)
454+
continue
455+
if isinstance(value, Mapping):
456+
config['metadata'] = _build_vertex_metadata(value)
457+
else:
458+
logger.warning(
459+
'Ignoring metadata because custom_metadata["metadata"] is not a'
460+
' mapping.'
461+
)
462+
continue
463+
if key in config_keys:
464+
if value is None:
465+
continue
466+
config[key] = value
467+
else:
468+
metadata_by_key[key] = value
469+
470+
if not metadata_by_key:
471+
return config
472+
473+
if not supports_metadata:
474+
logger.warning(
475+
'Ignoring custom metadata keys %s because installed Vertex SDK does '
476+
'not support create config.metadata.',
477+
sorted(metadata_by_key.keys()),
478+
)
479+
return config
480+
481+
existing_metadata = config.get('metadata')
482+
if existing_metadata is None:
483+
config['metadata'] = _build_vertex_metadata(metadata_by_key)
484+
return config
485+
486+
if isinstance(existing_metadata, Mapping):
487+
merged_metadata = dict(existing_metadata)
488+
merged_metadata.update(_build_vertex_metadata(metadata_by_key))
489+
config['metadata'] = merged_metadata
490+
return config
491+
492+
logger.warning(
493+
'Ignoring custom metadata keys %s because config.metadata is not a'
494+
' mapping.',
495+
sorted(metadata_by_key.keys()),
496+
)
497+
return config
498+
499+
500+
def _validate_memory_texts(
501+
memories: Sequence[str],
502+
) -> list[str]:
503+
"""Validates direct textual memory items passed to add_memory."""
504+
if isinstance(memories, str):
505+
raise TypeError('memories must be a sequence of strings.')
506+
if not isinstance(memories, Sequence):
507+
raise TypeError('memories must be a sequence of strings.')
508+
memory_texts: list[str] = []
509+
for index, raw_memory in enumerate(memories):
510+
if not isinstance(raw_memory, str):
511+
raise TypeError(f'memories[{index}] must be a string.')
512+
memory_text = raw_memory.strip()
513+
if not memory_text:
514+
raise ValueError(f'memories[{index}] must not be empty.')
515+
memory_texts.append(memory_text)
516+
517+
if not memory_texts:
518+
raise ValueError('memories must contain at least one entry.')
519+
return memory_texts
520+
521+
307522
def _build_vertex_metadata(
308523
metadata_by_key: Mapping[str, object],
309524
) -> dict[str, object]:

0 commit comments

Comments
 (0)