forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinstructions_utils.py
More file actions
155 lines (132 loc) · 4.63 KB
/
instructions_utils.py
File metadata and controls
155 lines (132 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import re
from ..agents.readonly_context import ReadonlyContext
from ..sessions.state import State
# Public API of this module.
__all__ = ['inject_session_state']

# Module logger, namespaced under 'google_adk.' like the rest of the package.
logger = logging.getLogger(f'google_adk.{__name__}')
async def inject_session_state(
    template: str,
    readonly_context: ReadonlyContext,
) -> str:
  """Populates values in the instruction template, e.g. state, artifact, etc.

  This method is intended to be used in InstructionProvider based instruction
  and global_instruction which are called with readonly_context.

  Every span matching ``{+[^{}]*}+`` (one or more opening braces, text with no
  further braces, one or more closing braces) is treated as a placeholder:

  * ``{{text}}`` (starts with ``{{`` and ends with ``}}``) is an escape: one
    outer pair of braces is dropped and nothing is looked up, so ``{{name}}``
    renders as ``{name}``.
  * ``{artifact.<file_name>}`` is replaced with ``str()`` of the artifact
    loaded from the invocation context's artifact service for the current
    session.
  * ``{<state_name>}`` is replaced with ``str()`` of the session state value,
    or ``''`` if that value is ``None``. Text that is not a valid state name
    (per ``_is_valid_state_name``) is left in the output unchanged.
  * A trailing ``?`` (e.g. ``{name?}``) marks the placeholder optional: a
    missing state value or artifact becomes ``''`` instead of an error.

  Leading/trailing whitespace inside the braces is stripped before parsing.
  Placeholders are resolved one at a time, left to right.

  e.g.
  ```
  ...
  from google.adk.utils.instructions_utils import inject_session_state

  async def build_instruction(
      readonly_context: ReadonlyContext,
  ) -> str:
    return await inject_session_state(
        'You can inject a state variable like {var_name} or an artifact '
        '{artifact.file_name} into the instruction template.',
        readonly_context,
    )

  agent = Agent(
      model="gemini-2.0-flash",
      name="agent",
      instruction=build_instruction,
  )
  ```

  Args:
    template: The instruction template.
    readonly_context: The read-only context.

  Returns:
    The instruction template with values populated.

  Raises:
    ValueError: If an artifact placeholder is used but the invocation context
      has no artifact service (even when the placeholder is optional).
    KeyError: If a non-optional state variable or artifact is not found.
  """
  invocation_context = readonly_context._invocation_context

  async def _render_artifact(file_name, optional):
    """Loads the named artifact for the current session and stringifies it."""
    artifact_service = invocation_context.artifact_service
    if artifact_service is None:
      raise ValueError('Artifact service is not initialized.')
    session = invocation_context.session
    artifact = await artifact_service.load_artifact(
        app_name=session.app_name,
        user_id=session.user_id,
        session_id=session.id,
        filename=file_name,
    )
    if artifact is not None:
      return str(artifact)
    if not optional:
      raise KeyError(f'Artifact {file_name} not found.')
    logger.debug(
        'Artifact %s not found, replacing with empty string', file_name
    )
    return ''

  def _render_state(state_name, optional, placeholder):
    """Looks up a state value; returns `placeholder` for invalid names."""
    if not _is_valid_state_name(state_name):
      return placeholder
    state = invocation_context.session.state
    if state_name not in state:
      if not optional:
        raise KeyError(f'Context variable not found: `{state_name}`.')
      logger.debug(
          'Context variable %s not found, replacing with empty string',
          state_name,
      )
      return ''
    value = state[state_name]
    return '' if value is None else str(value)

  async def _render(placeholder):
    """Returns the replacement text for one matched placeholder."""
    if placeholder.startswith('{{') and placeholder.endswith('}}'):
      # Escape sequence: drop one outer pair of braces, no lookup.
      return placeholder[1:-1]
    # Slice exactly one brace off each side (not lstrip/rstrip) so malformed
    # input such as '{var}}}' keeps its extra braces.
    name = placeholder[1:-1].strip()
    optional = name.endswith('?')
    if optional:
      name = name[:-1]
    if name.startswith('artifact.'):
      return await _render_artifact(name[len('artifact.'):], optional)
    return _render_state(name, optional, placeholder)

  # Rebuild the template: untouched text between matches, plus each
  # placeholder's replacement, awaited sequentially in document order.
  pieces = []
  cursor = 0
  for match in re.finditer(r'{+[^{}]*}+', template):
    pieces.append(template[cursor:match.start()])
    pieces.append(await _render(match.group()))
    cursor = match.end()
  pieces.append(template[cursor:])
  return ''.join(pieces)
def _is_valid_state_name(var_name):
  """Checks if the variable name is a valid state name.

  A valid state name is either:
    - a valid Python identifier, or
    - ``<prefix>:<identifier>``, where ``<prefix>:`` equals one of
      ``State.APP_PREFIX``, ``State.USER_PREFIX`` or ``State.TEMP_PREFIX``.

  Anything else, including a name with more than one ``:``, is rejected.

  Args:
    var_name: The variable name to check.

  Returns:
    True if the variable name is a valid state name, False otherwise.
  """
  head, colon, tail = var_name.partition(':')
  if not colon:
    # No scope prefix: the whole name must be a plain identifier.
    return var_name.isidentifier()
  if ':' in tail:
    # More than one separator is never a valid state name.
    return False
  allowed_prefixes = (State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX)
  return head + ':' in allowed_prefixes and tail.isidentifier()