Skip to content

Commit 7937e9e

Browse files
yinghsienwucopybara-github
authored andcommitted
fix: Use async with for API client in Vertex session management to ensure proper resource management
PiperOrigin-RevId: 830485639
1 parent c485889 commit 7937e9e

2 files changed

Lines changed: 87 additions & 81 deletions

File tree

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 70 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -107,19 +107,18 @@ async def create_session(
107107
)
108108

109109
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
110-
api_client = self._get_api_client()
111110

112111
config = {'session_state': state} if state else {}
113112
config.update(kwargs)
114-
115-
api_response = await api_client.aio.agent_engines.sessions.create(
116-
name=f'reasoningEngines/{reasoning_engine_id}',
117-
user_id=user_id,
118-
config=config,
119-
)
120-
logger.debug('Create session response: %s', api_response)
121-
get_session_response = api_response.response
122-
session_id = get_session_response.name.split('/')[-1]
113+
async with self._get_api_client() as api_client:
114+
api_response = await api_client.agent_engines.sessions.create(
115+
name=f'reasoningEngines/{reasoning_engine_id}',
116+
user_id=user_id,
117+
config=config,
118+
)
119+
logger.debug('Create session response: %s', api_response)
120+
get_session_response = api_response.response
121+
session_id = get_session_response.name.split('/')[-1]
123122

124123
session = Session(
125124
app_name=app_name,
@@ -140,30 +139,29 @@ async def get_session(
140139
config: Optional[GetSessionConfig] = None,
141140
) -> Optional[Session]:
142141
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
143-
api_client = self._get_api_client()
144142
session_resource_name = (
145143
f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
146144
)
147-
148-
# Get session resource and events in parallel.
149-
list_events_kwargs = {}
150-
if config and not config.num_recent_events and config.after_timestamp:
151-
# Filter events based on timestamp.
152-
list_events_kwargs['config'] = {
153-
'filter': 'timestamp>="{}"'.format(
154-
datetime.datetime.fromtimestamp(
155-
config.after_timestamp, tz=datetime.timezone.utc
156-
).isoformat()
157-
)
158-
}
159-
160-
get_session_response, events_iterator = await asyncio.gather(
161-
api_client.aio.agent_engines.sessions.get(name=session_resource_name),
162-
api_client.aio.agent_engines.sessions.events.list(
163-
name=session_resource_name,
164-
**list_events_kwargs,
165-
),
166-
)
145+
async with self._get_api_client() as api_client:
146+
# Get session resource and events in parallel.
147+
list_events_kwargs = {}
148+
if config and not config.num_recent_events and config.after_timestamp:
149+
# Filter events based on timestamp.
150+
list_events_kwargs['config'] = {
151+
'filter': 'timestamp>="{}"'.format(
152+
datetime.datetime.fromtimestamp(
153+
config.after_timestamp, tz=datetime.timezone.utc
154+
).isoformat()
155+
)
156+
}
157+
158+
get_session_response, events_iterator = await asyncio.gather(
159+
api_client.agent_engines.sessions.get(name=session_resource_name),
160+
api_client.agent_engines.sessions.events.list(
161+
name=session_resource_name,
162+
**list_events_kwargs,
163+
),
164+
)
167165

168166
if get_session_response.user_id != user_id:
169167
raise ValueError(
@@ -196,51 +194,52 @@ async def list_sessions(
196194
self, *, app_name: str, user_id: Optional[str] = None
197195
) -> ListSessionsResponse:
198196
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
199-
api_client = self._get_api_client()
200197

201-
sessions = []
202-
config = {}
203-
if user_id is not None:
204-
config['filter'] = f'user_id="{user_id}"'
205-
sessions_iterator = await api_client.aio.agent_engines.sessions.list(
206-
name=f'reasoningEngines/{reasoning_engine_id}',
207-
config=config,
208-
)
209-
210-
for api_session in sessions_iterator:
211-
sessions.append(
212-
Session(
213-
app_name=app_name,
214-
user_id=api_session.user_id,
215-
id=api_session.name.split('/')[-1],
216-
state=getattr(api_session, 'session_state', None) or {},
217-
last_update_time=api_session.update_time.timestamp(),
218-
)
198+
async with self._get_api_client() as api_client:
199+
sessions = []
200+
config = {}
201+
if user_id is not None:
202+
config['filter'] = f'user_id="{user_id}"'
203+
sessions_iterator = await api_client.agent_engines.sessions.list(
204+
name=f'reasoningEngines/{reasoning_engine_id}',
205+
config=config,
219206
)
220207

208+
for api_session in sessions_iterator:
209+
sessions.append(
210+
Session(
211+
app_name=app_name,
212+
user_id=api_session.user_id,
213+
id=api_session.name.split('/')[-1],
214+
state=getattr(api_session, 'session_state', None) or {},
215+
last_update_time=api_session.update_time.timestamp(),
216+
)
217+
)
218+
221219
return ListSessionsResponse(sessions=sessions)
222220

223221
async def delete_session(
224222
self, *, app_name: str, user_id: str, session_id: str
225223
) -> None:
226224
reasoning_engine_id = self._get_reasoning_engine_id(app_name)
227-
api_client = self._get_api_client()
228225

229-
try:
230-
await api_client.aio.agent_engines.sessions.delete(
231-
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}',
232-
)
233-
except Exception as e:
234-
logger.error('Error deleting session %s: %s', session_id, e)
235-
raise e
226+
async with self._get_api_client() as api_client:
227+
try:
228+
await api_client.agent_engines.sessions.delete(
229+
name=(
230+
f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}'
231+
),
232+
)
233+
except Exception as e:
234+
logger.error('Error deleting session %s: %s', session_id, e)
235+
raise
236236

237237
@override
238238
async def append_event(self, session: Session, event: Event) -> Event:
239239
# Update the in-memory session.
240240
await super().append_event(session=session, event=event)
241241

242242
reasoning_engine_id = self._get_reasoning_engine_id(session.app_name)
243-
api_client = self._get_api_client()
244243

245244
config = {}
246245
if event.content:
@@ -284,15 +283,16 @@ async def append_event(self, session: Session, event: Event) -> Event:
284283
)
285284
config['event_metadata'] = metadata_dict
286285

287-
await api_client.aio.agent_engines.sessions.events.append(
288-
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}',
289-
author=event.author,
290-
invocation_id=event.invocation_id,
291-
timestamp=datetime.datetime.fromtimestamp(
292-
event.timestamp, tz=datetime.timezone.utc
293-
),
294-
config=config,
295-
)
286+
async with self._get_api_client() as api_client:
287+
await api_client.agent_engines.sessions.events.append(
288+
name=f'reasoningEngines/{reasoning_engine_id}/sessions/{session.id}',
289+
author=event.author,
290+
invocation_id=event.invocation_id,
291+
timestamp=datetime.datetime.fromtimestamp(
292+
event.timestamp, tz=datetime.timezone.utc
293+
),
294+
config=config,
295+
)
296296
return event
297297

298298
def _get_reasoning_engine_id(self, app_name: str):
@@ -318,7 +318,7 @@ def _api_client_http_options_override(
318318
) -> Optional[Union[types.HttpOptions, types.HttpOptionsDict]]:
319319
return None
320320

321-
def _get_api_client(self) -> vertexai.Client:
321+
def _get_api_client(self) -> vertexai.AsyncClient:
322322
"""Instantiates an API client for the given project and location.
323323
324324
Returns:
@@ -331,7 +331,7 @@ def _get_api_client(self) -> vertexai.Client:
331331
location=self._location,
332332
http_options=self._api_client_http_options_override(),
333333
api_key=self._express_mode_api_key,
334-
)
334+
).aio
335335

336336

337337
def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -228,24 +228,30 @@ def _convert_to_object(data):
228228
return data
229229

230230

231-
class MockApiClient:
231+
class MockAsyncClient:
232232
"""Mocks the API Client."""
233233

234234
def __init__(self) -> None:
235235
"""Initializes MockClient."""
236236
self.session_dict: dict[str, Any] = {}
237237
self.event_dict: dict[str, Tuple[List[Any], Optional[str]]] = {}
238-
self.aio = mock.Mock()
239-
self.aio.agent_engines.sessions.get.side_effect = self._get_session
240-
self.aio.agent_engines.sessions.list.side_effect = self._list_sessions
241-
self.aio.agent_engines.sessions.delete.side_effect = self._delete_session
242-
self.aio.agent_engines.sessions.create.side_effect = self._create_session
243-
self.aio.agent_engines.sessions.events.list.side_effect = self._list_events
244-
self.aio.agent_engines.sessions.events.append.side_effect = (
245-
self._append_event
246-
)
238+
self.agent_engines = mock.AsyncMock()
239+
self.agent_engines.sessions.get.side_effect = self._get_session
240+
self.agent_engines.sessions.list.side_effect = self._list_sessions
241+
self.agent_engines.sessions.delete.side_effect = self._delete_session
242+
self.agent_engines.sessions.create.side_effect = self._create_session
243+
self.agent_engines.sessions.events.list.side_effect = self._list_events
244+
self.agent_engines.sessions.events.append.side_effect = self._append_event
247245
self.last_create_session_config: dict[str, Any] = {}
248246

247+
async def __aenter__(self):
248+
"""Enters the asynchronous context."""
249+
return self
250+
251+
async def __aexit__(self, exc_type, exc_val, exc_tb):
252+
"""Exits the asynchronous context."""
253+
pass
254+
249255
async def _get_session(self, name: str):
250256
session_id = name.split('/')[-1]
251257
if session_id in self.session_dict:
@@ -374,7 +380,7 @@ def mock_vertex_ai_session_service(
374380
@pytest.fixture
375381
def mock_api_client_instance():
376382
"""Creates a mock API client instance for testing."""
377-
api_client = MockApiClient()
383+
api_client = MockAsyncClient()
378384
api_client.session_dict = {
379385
'1': MOCK_SESSION_JSON_1,
380386
'2': MOCK_SESSION_JSON_2,

0 commit comments

Comments
 (0)