forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathadk_web_server_client.py
More file actions
303 lines (259 loc) · 9.39 KB
/
adk_web_server_client.py
File metadata and controls
303 lines (259 loc) · 9.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
"""HTTP client for interacting with the ADK web server."""
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from contextlib import asynccontextmanager
import json
import logging
from typing import Any
from typing import AsyncGenerator
from typing import Dict
from typing import Literal
from typing import Optional
import httpx
from ...artifacts.base_artifact_service import ArtifactVersion
from ...events.event import Event
from ...sessions.session import Session
from ..adk_web_server import RunAgentRequest
logger = logging.getLogger("google_adk." + __name__)
class AdkWebServerClient:
"""HTTP client for interacting with the ADK web server for conformance tests.
Usage patterns:
# Pattern 1: Manual lifecycle management
client = AdkWebServerClient()
session = await client.create_session(app_name="app", user_id="user")
async for event in client.run_agent(request):
# Process events...
await client.close() # Optional explicit cleanup
# Pattern 2: Automatic cleanup with context manager (recommended)
async with AdkWebServerClient() as client:
session = await client.create_session(app_name="app", user_id="user")
async for event in client.run_agent(request):
# Process events...
# Client automatically closed here
"""
def __init__(
self, base_url: str = "http://127.0.0.1:8000", timeout: float = 30.0
):
"""Initialize the ADK web server client for conformance testing.
Args:
base_url: Base URL of the ADK web server (default: http://127.0.0.1:8000)
timeout: Request timeout in seconds (default: 30.0)
"""
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self._client: Optional[httpx.AsyncClient] = None
@asynccontextmanager
async def _get_client(self) -> AsyncGenerator[httpx.AsyncClient, None]:
"""Get or create an HTTP client with proper lifecycle management.
Returns:
AsyncGenerator yielding the HTTP client instance.
"""
if self._client is None:
self._client = httpx.AsyncClient(
base_url=self.base_url,
timeout=httpx.Timeout(self.timeout),
)
try:
yield self._client
finally:
pass # Keep client alive for reuse
async def close(self) -> None:
"""Close the HTTP client and clean up resources."""
if self._client:
await self._client.aclose()
self._client = None
async def __aenter__(self) -> "AdkWebServerClient":
"""Async context manager entry.
Returns:
The client instance for use in the async context.
"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: # pylint: disable=unused-argument
"""Async context manager exit that closes the HTTP client."""
await self.close()
async def get_session(
self, *, app_name: str, user_id: str, session_id: str
) -> Session:
"""Retrieve a specific session from the ADK web server.
Args:
app_name: Name of the application
user_id: User identifier
session_id: Session identifier
Returns:
The requested Session object
Raises:
httpx.HTTPStatusError: If the request fails or session not found
"""
async with self._get_client() as client:
response = await client.get(
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}"
)
response.raise_for_status()
return Session.model_validate(response.json())
async def create_session(
self,
*,
app_name: str,
user_id: str,
state: Optional[Dict[str, Any]] = None,
) -> Session:
"""Create a new session in the ADK web server.
Args:
app_name: Name of the application
user_id: User identifier
state: Optional initial state for the session
Returns:
The newly created Session object
Raises:
httpx.HTTPStatusError: If the request fails
"""
async with self._get_client() as client:
payload = {}
if state is not None:
payload["state"] = state
response = await client.post(
f"/apps/{app_name}/users/{user_id}/sessions",
json=payload,
)
response.raise_for_status()
return Session.model_validate(response.json())
async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
"""Delete a session from the ADK web server.
Args:
app_name: Name of the application
user_id: User identifier
session_id: Session identifier to delete
Raises:
httpx.HTTPStatusError: If the request fails or session not found
"""
async with self._get_client() as client:
response = await client.delete(
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}"
)
response.raise_for_status()
async def update_session(
self,
*,
app_name: str,
user_id: str,
session_id: str,
state_delta: Dict[str, Any],
) -> Session:
"""Update session state without running the agent.
Args:
app_name: Name of the application
user_id: User identifier
session_id: Session identifier to update
state_delta: The state changes to apply to the session
Returns:
The updated Session object
Raises:
httpx.HTTPStatusError: If the request fails or session not found
"""
async with self._get_client() as client:
response = await client.patch(
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
json={"state_delta": state_delta},
)
response.raise_for_status()
return Session.model_validate(response.json())
async def run_agent(
self,
request: RunAgentRequest,
mode: Optional[Literal["record", "replay"]] = None,
test_case_dir: Optional[str] = None,
user_message_index: Optional[int] = None,
) -> AsyncGenerator[Event, None]:
"""Run an agent with streaming Server-Sent Events response.
Args:
request: The RunAgentRequest containing agent execution parameters
mode: Optional conformance mode ("record" or "replay") to trigger recording
test_case_dir: Optional test case directory path for conformance recording
user_message_index: Optional user message index for conformance recording
Yields:
Event objects streamed from the agent execution
Raises:
ValueError: If mode is provided but test_case_dir or user_message_index is None
httpx.HTTPStatusError: If the request fails
json.JSONDecodeError: If event data cannot be parsed
"""
# Add recording parameters to state_delta for conformance tests
if mode:
if test_case_dir is None or user_message_index is None:
raise ValueError(
"test_case_dir and user_message_index must be provided when mode is"
" specified"
)
# Modify request state_delta in place
if request.state_delta is None:
request.state_delta = {}
if mode == "replay":
request.state_delta["_adk_replay_config"] = {
"dir": str(test_case_dir),
"user_message_index": user_message_index,
}
else: # record mode
request.state_delta["_adk_recordings_config"] = {
"dir": str(test_case_dir),
"user_message_index": user_message_index,
}
async with self._get_client() as client:
async with client.stream(
"POST",
"/run_sse",
json=request.model_dump(by_alias=True, exclude_none=True),
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if line.startswith("data:") and (data := line[5:].strip()):
event_data = json.loads(data)
yield Event.model_validate(event_data)
else:
logger.debug("Non data line received: %s", line)
async def get_artifact_version_metadata(
self,
*,
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
version: int,
) -> ArtifactVersion:
"""Retrieve metadata for a specific artifact version."""
async with self._get_client() as client:
response = await client.get((
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}"
f"/artifacts/{artifact_name}/versions/{version}/metadata"
))
response.raise_for_status()
return ArtifactVersion.model_validate(response.json())
async def list_artifact_versions_metadata(
self,
*,
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
) -> list[ArtifactVersion]:
"""List metadata for all versions of an artifact."""
async with self._get_client() as client:
response = await client.get((
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}"
f"/artifacts/{artifact_name}/versions/metadata"
))
response.raise_for_status()
return [ArtifactVersion.model_validate(item) for item in response.json()]