forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase_session_service.py
More file actions
153 lines (124 loc) · 4.63 KB
/
base_session_service.py
File metadata and controls
153 lines (124 loc) · 4.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
from typing import Any
from typing import Optional
from pydantic import BaseModel
from pydantic import Field
from ..events.event import Event
from .session import Session
from .state import State
class GetSessionConfig(BaseModel):
  """Options accepted by `BaseSessionService.get_session`.

  Both options default to None. How each option is applied when loading a
  session is decided by the concrete session service.
  """

  # Presumably limits the result to the N most recent events; the exact
  # interpretation is up to each session service (verify per implementation).
  num_recent_events: Optional[int] = None
  # Presumably restricts events to those after this timestamp; the unit and
  # inclusivity are up to each session service (verify per implementation).
  after_timestamp: Optional[float] = None
class ListSessionsResponse(BaseModel):
  """Result returned by `BaseSessionService.list_sessions`.

  The Session objects listed here do not have their events or state set.
  """

  # The listed sessions. `default_factory=list` gives each response its own
  # fresh empty list when no sessions are supplied.
  sessions: list[Session] = Field(default_factory=list)
class BaseSessionService(abc.ABC):
  """Base class for session services.

  Subclasses implement session management (create, get, list, delete). This
  base class provides `append_event`, which records a non-partial event on an
  in-memory `Session` and merges the event's state delta into
  `session.state`.
  """

  @abc.abstractmethod
  async def create_session(
      self,
      *,
      app_name: str,
      user_id: str,
      state: Optional[dict[str, Any]] = None,
      session_id: Optional[str] = None,
  ) -> Session:
    """Creates a new session.

    Args:
      app_name: the name of the app.
      user_id: the id of the user.
      state: the initial state of the session.
      session_id: the client-provided id of the session. If not provided, a
        generated ID will be used.

    Returns:
      session: The newly created session instance.
    """

  @abc.abstractmethod
  async def get_session(
      self,
      *,
      app_name: str,
      user_id: str,
      session_id: str,
      config: Optional[GetSessionConfig] = None,
  ) -> Optional[Session]:
    """Gets a session.

    Args:
      app_name: The name of the app.
      user_id: The ID of the user.
      session_id: The ID of the session to fetch.
      config: Optional `GetSessionConfig` options the implementation applies
        when loading the session's events.

    Returns:
      The session, or None. Presumably None means no matching session was
      found; the exact condition is defined by each implementation.
    """

  @abc.abstractmethod
  async def list_sessions(
      self, *, app_name: str, user_id: Optional[str] = None
  ) -> ListSessionsResponse:
    """Lists all the sessions for a user.

    Args:
      app_name: The name of the app.
      user_id: The ID of the user. If not provided, lists all sessions for all
        users.

    Returns:
      A ListSessionsResponse containing the sessions.
    """

  @abc.abstractmethod
  async def delete_session(
      self, *, app_name: str, user_id: str, session_id: str
  ) -> None:
    """Deletes a session.

    Args:
      app_name: The name of the app.
      user_id: The ID of the user.
      session_id: The ID of the session to delete.
    """

  async def append_event(self, session: Session, event: Event) -> Event:
    """Appends an event to a session object.

    A partial event (`event.partial` truthy) is returned immediately. It is
    not appended and its state delta is not applied. For any other event:

    1. Temp-prefixed keys in the state delta are written to `session.state`.
    2. Those keys are then removed from the event's state delta.
    3. The remaining delta is merged into `session.state`.
    4. The event is appended to `session.events`.

    Args:
      session: The in-memory session to update.
      event: The event to append. Note that it is mutated:
        `event.actions.state_delta` is replaced by a copy without
        temp-prefixed keys.

    Returns:
      The same event object, with its temp-prefixed delta keys removed.
    """
    if event.partial:
      return event
    # Apply temp-scoped state to the in-memory session BEFORE trimming the
    # event delta, so that subsequent agents within the same invocation can
    # read temp values (e.g. output_key='temp:my_key' in SequentialAgent).
    self._apply_temp_state(session, event)
    event = self._trim_temp_delta_state(event)
    self._update_session_state(session, event)
    session.events.append(event)
    return event

  def _apply_temp_state(self, session: Session, event: Event) -> None:
    """Applies temp-scoped state delta to the in-memory session state.

    Temp state is ephemeral: it lives in the session's in-memory state for
    the duration of the current invocation but is NOT persisted to storage
    (the event delta is trimmed separately by _trim_temp_delta_state).
    """
    if not event.actions or not event.actions.state_delta:
      return
    for key, value in event.actions.state_delta.items():
      if key.startswith(State.TEMP_PREFIX):
        session.state[key] = value

  def _trim_temp_delta_state(self, event: Event) -> Event:
    """Removes temporary state delta keys from the event.

    This prevents temp-scoped state from being persisted, while the
    in-memory session state (updated by _apply_temp_state) retains the
    values for the duration of the current invocation.

    Returns:
      The same event object. Its `actions.state_delta` is replaced with a new
      dict that omits keys starting with `State.TEMP_PREFIX`.
    """
    if not event.actions or not event.actions.state_delta:
      return event
    event.actions.state_delta = {
        key: value
        for key, value in event.actions.state_delta.items()
        if not key.startswith(State.TEMP_PREFIX)
    }
    return event

  def _update_session_state(self, session: Session, event: Event) -> None:
    """Merges the event's state delta into the session state.

    No-op when the event has no actions or an empty/missing state delta.
    Within `append_event`, temp-prefixed keys have already been removed from
    the delta by `_trim_temp_delta_state` before this runs.
    """
    if not event.actions or not event.actions.state_delta:
      return
    # One bulk merge: values in the delta overwrite existing session values
    # for the same key, and new keys are added.
    session.state.update(event.actions.state_delta)