Skip to content
48 changes: 34 additions & 14 deletions src/google/adk/sessions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_session_service import BaseSessionService
from .in_memory_session_service import InMemorySessionService
from .session import Session
from .state import State
from .vertex_ai_session_service import VertexAiSessionService

try:
from .database_session_service import DatabaseSessionService
except ImportError:
# This handles the case where optional dependencies (like sqlalchemy)
# are not installed. A placeholder class ensures the symbol is always
# available for documentation tools and static analysis.
class DatabaseSessionService(BaseSessionService):
"""Placeholder for DatabaseSessionService when dependencies are not installed."""

_ERROR_MESSAGE = (
'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it is'
' installed correctly.'
)

def __init__(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)
Comment thread
Akshat8510 marked this conversation as resolved.

async def create_session(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)

async def get_session(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)

async def list_sessions(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)

async def delete_session(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)

async def append_event(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)
Comment thread
Akshat8510 marked this conversation as resolved.
Comment thread
Akshat8510 marked this conversation as resolved.
Comment thread
Akshat8510 marked this conversation as resolved.
Comment thread
Akshat8510 marked this conversation as resolved.
Comment thread
Akshat8510 marked this conversation as resolved.


__all__ = [
'BaseSessionService',
'DatabaseSessionService',
Expand All @@ -25,17 +59,3 @@
'State',
'VertexAiSessionService',
]


def __getattr__(name: str):
if name == 'DatabaseSessionService':
try:
from .database_session_service import DatabaseSessionService

return DatabaseSessionService
except ImportError as e:
raise ImportError(
'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it is'
' installed correctly.'
) from e
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')