Skip to content

Commit 647d3b1

Browse files
xuanyang15copybara-github
authored andcommitted
chore: Unify CallbackContext and ToolContext into the Context class
Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 869400758
1 parent 4f71b45 commit 647d3b1

4 files changed

Lines changed: 23 additions & 320 deletions

File tree

src/google/adk/agents/callback_context.py

Lines changed: 4 additions & 235 deletions
Original file line numberDiff line numberDiff line change
@@ -14,240 +14,9 @@
1414

1515
from __future__ import annotations
1616

17-
from collections.abc import Mapping
18-
from collections.abc import Sequence
19-
from typing import Any
20-
from typing import Optional
21-
from typing import TYPE_CHECKING
22-
23-
from typing_extensions import override
24-
17+
from .context import Context
18+
# Keep ReadonlyContext for backward compatibility
2519
from .readonly_context import ReadonlyContext
2620

27-
if TYPE_CHECKING:
28-
from google.genai import types
29-
30-
from ..artifacts.base_artifact_service import ArtifactVersion
31-
from ..auth.auth_credential import AuthCredential
32-
from ..auth.auth_tool import AuthConfig
33-
from ..events.event import Event
34-
from ..events.event_actions import EventActions
35-
from ..sessions.state import State
36-
from .invocation_context import InvocationContext
37-
38-
39-
class CallbackContext(ReadonlyContext):
40-
"""The context of various callbacks within an agent run."""
41-
42-
def __init__(
43-
self,
44-
invocation_context: InvocationContext,
45-
*,
46-
event_actions: Optional[EventActions] = None,
47-
) -> None:
48-
super().__init__(invocation_context)
49-
50-
from ..events.event_actions import EventActions
51-
from ..sessions.state import State
52-
53-
self._event_actions = event_actions or EventActions()
54-
self._state = State(
55-
value=invocation_context.session.state,
56-
delta=self._event_actions.state_delta,
57-
)
58-
59-
@property
60-
@override
61-
def state(self) -> State:
62-
"""The delta-aware state of the current session.
63-
64-
For any state change, you can mutate this object directly,
65-
e.g. `ctx.state['foo'] = 'bar'`
66-
"""
67-
return self._state
68-
69-
async def load_artifact(
70-
self, filename: str, version: Optional[int] = None
71-
) -> Optional[types.Part]:
72-
"""Loads an artifact attached to the current session.
73-
74-
Args:
75-
filename: The filename of the artifact.
76-
version: The version of the artifact. If None, the latest version will be
77-
returned.
78-
79-
Returns:
80-
The artifact.
81-
"""
82-
if self._invocation_context.artifact_service is None:
83-
raise ValueError("Artifact service is not initialized.")
84-
return await self._invocation_context.artifact_service.load_artifact(
85-
app_name=self._invocation_context.app_name,
86-
user_id=self._invocation_context.user_id,
87-
session_id=self._invocation_context.session.id,
88-
filename=filename,
89-
version=version,
90-
)
91-
92-
async def save_artifact(
93-
self,
94-
filename: str,
95-
artifact: types.Part,
96-
custom_metadata: Optional[dict[str, Any]] = None,
97-
) -> int:
98-
"""Saves an artifact and records it as delta for the current session.
99-
100-
Args:
101-
filename: The filename of the artifact.
102-
artifact: The artifact to save.
103-
custom_metadata: Custom metadata to associate with the artifact.
104-
105-
Returns:
106-
The version of the artifact.
107-
"""
108-
if self._invocation_context.artifact_service is None:
109-
raise ValueError("Artifact service is not initialized.")
110-
version = await self._invocation_context.artifact_service.save_artifact(
111-
app_name=self._invocation_context.app_name,
112-
user_id=self._invocation_context.user_id,
113-
session_id=self._invocation_context.session.id,
114-
filename=filename,
115-
artifact=artifact,
116-
custom_metadata=custom_metadata,
117-
)
118-
self._event_actions.artifact_delta[filename] = version
119-
return version
120-
121-
async def get_artifact_version(
122-
self, filename: str, version: Optional[int] = None
123-
) -> Optional[ArtifactVersion]:
124-
"""Gets artifact version info.
125-
126-
Args:
127-
filename: The filename of the artifact.
128-
version: The version of the artifact. If None, the latest version will be
129-
returned.
130-
131-
Returns:
132-
The artifact version info.
133-
"""
134-
if self._invocation_context.artifact_service is None:
135-
raise ValueError("Artifact service is not initialized.")
136-
return await self._invocation_context.artifact_service.get_artifact_version(
137-
app_name=self._invocation_context.app_name,
138-
user_id=self._invocation_context.user_id,
139-
session_id=self._invocation_context.session.id,
140-
filename=filename,
141-
version=version,
142-
)
143-
144-
async def list_artifacts(self) -> list[str]:
145-
"""Lists the filenames of the artifacts attached to the current session."""
146-
if self._invocation_context.artifact_service is None:
147-
raise ValueError("Artifact service is not initialized.")
148-
return await self._invocation_context.artifact_service.list_artifact_keys(
149-
app_name=self._invocation_context.app_name,
150-
user_id=self._invocation_context.user_id,
151-
session_id=self._invocation_context.session.id,
152-
)
153-
154-
async def save_credential(self, auth_config: AuthConfig) -> None:
155-
"""Saves a credential to the credential service.
156-
157-
Args:
158-
auth_config: The authentication configuration containing the credential.
159-
"""
160-
if self._invocation_context.credential_service is None:
161-
raise ValueError("Credential service is not initialized.")
162-
await self._invocation_context.credential_service.save_credential(
163-
auth_config, self
164-
)
165-
166-
async def load_credential(
167-
self, auth_config: AuthConfig
168-
) -> Optional[AuthCredential]:
169-
"""Loads a credential from the credential service.
170-
171-
Args:
172-
auth_config: The authentication configuration for the credential.
173-
174-
Returns:
175-
The loaded credential, or None if not found.
176-
"""
177-
if self._invocation_context.credential_service is None:
178-
raise ValueError("Credential service is not initialized.")
179-
return await self._invocation_context.credential_service.load_credential(
180-
auth_config, self
181-
)
182-
183-
def get_auth_response(
184-
self, auth_config: AuthConfig
185-
) -> Optional[AuthCredential]:
186-
"""Gets the auth response credential from session state.
187-
188-
This method retrieves an authentication credential that was previously
189-
stored in session state after a user completed an OAuth flow or other
190-
authentication process.
191-
192-
Args:
193-
auth_config: The authentication configuration for the credential.
194-
195-
Returns:
196-
The auth credential from the auth response, or None if not found.
197-
"""
198-
from ..auth.auth_handler import AuthHandler
199-
200-
return AuthHandler(auth_config).get_auth_response(self.state)
201-
202-
async def add_session_to_memory(self) -> None:
203-
"""Triggers memory generation for the current session.
204-
205-
This method saves the current session's events to the memory service,
206-
enabling the agent to recall information from past interactions.
207-
208-
Raises:
209-
ValueError: If memory service is not available.
210-
211-
Example:
212-
```python
213-
async def my_after_agent_callback(callback_context: CallbackContext):
214-
# Save conversation to memory at the end of each interaction
215-
await callback_context.add_session_to_memory()
216-
```
217-
"""
218-
if self._invocation_context.memory_service is None:
219-
raise ValueError(
220-
"Cannot add session to memory: memory service is not available."
221-
)
222-
await self._invocation_context.memory_service.add_session_to_memory(
223-
self._invocation_context.session
224-
)
225-
226-
async def add_events_to_memory(
227-
self,
228-
*,
229-
events: Sequence[Event],
230-
custom_metadata: Mapping[str, object] | None = None,
231-
) -> None:
232-
"""Adds an explicit list of events to the memory service.
233-
234-
Uses this callback's current session identifiers as memory scope.
235-
236-
Args:
237-
events: Explicit events to add to memory.
238-
custom_metadata: Optional standard metadata for memory generation.
239-
240-
Raises:
241-
ValueError: If memory service is not available.
242-
"""
243-
if self._invocation_context.memory_service is None:
244-
raise ValueError(
245-
"Cannot add events to memory: memory service is not available."
246-
)
247-
await self._invocation_context.memory_service.add_events_to_memory(
248-
app_name=self._invocation_context.session.app_name,
249-
user_id=self._invocation_context.session.user_id,
250-
session_id=self._invocation_context.session.id,
251-
events=events,
252-
custom_metadata=custom_metadata,
253-
)
21+
# CallbackContext is unified into Context
22+
CallbackContext = Context

src/google/adk/agents/context.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,21 @@ def function_call_id(self) -> str | None:
7676
"""The function call id of the current tool call."""
7777
return self._function_call_id
7878

79+
@function_call_id.setter
80+
def function_call_id(self, value: str | None) -> None:
81+
"""Sets the function call id of the current tool call."""
82+
self._function_call_id = value
83+
7984
@property
8085
def tool_confirmation(self) -> ToolConfirmation | None:
8186
"""The tool confirmation of the current tool call."""
8287
return self._tool_confirmation
8388

89+
@tool_confirmation.setter
90+
def tool_confirmation(self, value: ToolConfirmation | None) -> None:
91+
"""Sets the tool confirmation of the current tool call."""
92+
self._tool_confirmation = value
93+
8494
@property
8595
@override
8696
def state(self) -> State:

src/google/adk/tools/tool_context.py

Lines changed: 8 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -14,93 +14,17 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any
18-
from typing import Optional
19-
from typing import TYPE_CHECKING
20-
17+
# Keep CallbackContext for backward compatibility
2118
from ..agents.callback_context import CallbackContext
19+
from ..agents.context import Context
20+
# Keep AuthCredential for backward compatibility
2221
from ..auth.auth_credential import AuthCredential
22+
# Keep AuthHandler for backward compatibility
2323
from ..auth.auth_handler import AuthHandler
24+
# Keep AuthConfig for backward compatibility
2425
from ..auth.auth_tool import AuthConfig
26+
# Keep ToolConfirmation for backward compatibility
2527
from .tool_confirmation import ToolConfirmation
2628

27-
if TYPE_CHECKING:
28-
from ..agents.invocation_context import InvocationContext
29-
from ..events.event_actions import EventActions
30-
from ..memory.base_memory_service import SearchMemoryResponse
31-
32-
33-
class ToolContext(CallbackContext):
34-
"""The context of the tool.
35-
36-
This class provides the context for a tool invocation, including access to
37-
the invocation context, function call ID, event actions, and authentication
38-
response. It also provides methods for requesting credentials, retrieving
39-
authentication responses, listing artifacts, and searching memory.
40-
41-
Attributes:
42-
invocation_context: The invocation context of the tool.
43-
function_call_id: The function call id of the current tool call. This id was
44-
returned in the function call event from LLM to identify a function call.
45-
If LLM didn't return this id, ADK will assign one to it. This id is used
46-
to map function call response to the original function call.
47-
event_actions: The event actions of the current tool call.
48-
tool_confirmation: The tool confirmation of the current tool call.
49-
"""
50-
51-
def __init__(
52-
self,
53-
invocation_context: InvocationContext,
54-
*,
55-
function_call_id: Optional[str] = None,
56-
event_actions: Optional[EventActions] = None,
57-
tool_confirmation: Optional[ToolConfirmation] = None,
58-
):
59-
super().__init__(invocation_context, event_actions=event_actions)
60-
self.function_call_id = function_call_id
61-
self.tool_confirmation = tool_confirmation
62-
63-
@property
64-
def actions(self) -> EventActions:
65-
return self._event_actions
66-
67-
def request_credential(self, auth_config: AuthConfig) -> None:
68-
if not self.function_call_id:
69-
raise ValueError('function_call_id is not set.')
70-
self._event_actions.requested_auth_configs[self.function_call_id] = (
71-
AuthHandler(auth_config).generate_auth_request()
72-
)
73-
74-
def get_auth_response(self, auth_config: AuthConfig) -> AuthCredential:
75-
return AuthHandler(auth_config).get_auth_response(self.state)
76-
77-
def request_confirmation(
78-
self,
79-
*,
80-
hint: Optional[str] = None,
81-
payload: Optional[Any] = None,
82-
) -> None:
83-
"""Requests confirmation for the given function call.
84-
85-
Args:
86-
hint: A hint to the user on how to confirm the tool call.
87-
payload: The payload used to confirm the tool call.
88-
"""
89-
if not self.function_call_id:
90-
raise ValueError('function_call_id is not set.')
91-
self._event_actions.requested_tool_confirmations[self.function_call_id] = (
92-
ToolConfirmation(
93-
hint=hint,
94-
payload=payload,
95-
)
96-
)
97-
98-
async def search_memory(self, query: str) -> SearchMemoryResponse:
99-
"""Searches the memory of the current user."""
100-
if self._invocation_context.memory_service is None:
101-
raise ValueError('Memory service is not available.')
102-
return await self._invocation_context.memory_service.search_memory(
103-
app_name=self._invocation_context.app_name,
104-
user_id=self._invocation_context.user_id,
105-
query=query,
106-
)
29+
# ToolContext is unified into Context
30+
ToolContext = Context

tests/unittests/tools/openapi_tool/openapi_spec_parser/test_tool_auth_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ async def test_openid_connect_with_auth_response(
172172
oauth2=OAuth2Auth(auth_response_uri='test_auth_response_uri'),
173173
)
174174
mock_auth_handler.get_auth_response.return_value = returned_credential
175-
mock_auth_handler_path = 'google.adk.tools.tool_context.AuthHandler'
175+
mock_auth_handler_path = 'google.adk.auth.auth_handler.AuthHandler'
176176
monkeypatch.setattr(
177177
mock_auth_handler_path, lambda *args, **kwargs: mock_auth_handler
178178
)

0 commit comments

Comments
 (0)