Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 38 additions & 14 deletions trpc_agent_sdk/sessions/_sql_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,25 @@ class SessionStorageBase(DeclarativeBase):
pass


def _storage_dialect_name(storage: SessionStorageBase) -> Optional[str]:
orm_session = inspect(storage).session
if orm_session is None or orm_session.bind is None:
return None
return orm_session.bind.dialect.name


def _timestamp_tz(value: datetime, dialect_name: Optional[str]) -> float:
if dialect_name == "sqlite":
return value.replace(tzinfo=timezone.utc).timestamp()
return value.timestamp()


def _expire_before(sql_session: SqlSession, ttl_seconds: int) -> datetime:
if sql_session.bind is not None and sql_session.bind.dialect.name == "sqlite":
return datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(seconds=ttl_seconds)
return datetime.now() - timedelta(seconds=ttl_seconds)


class StorageSession(SessionStorageBase):
"""Represents a session stored in the database with TTL support.

Expand Down Expand Up @@ -135,14 +154,11 @@ def __repr__(self):

@property
def _dialect_name(self) -> Optional[str]:
session = inspect(self).session
return session.bind.dialect.name if session else None # type: ignore
return _storage_dialect_name(self)

@property
def update_timestamp_tz(self) -> float:
if self._dialect_name == "sqlite":
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
return self.update_time.timestamp()
return _timestamp_tz(self.update_time, self._dialect_name)

def to_session(
self,
Expand Down Expand Up @@ -306,6 +322,10 @@ class StorageAppState(SessionStorageBase):
state: Mapped[MutableDict[str, Any]] = mapped_column(MutableDict.as_mutable(DynamicJSON), default={})
update_time: Mapped[datetime] = mapped_column(PreciseTimestamp, default=func.now(), onupdate=func.now())

@property
def update_timestamp_tz(self) -> float:
return _timestamp_tz(self.update_time, _storage_dialect_name(self))


class StorageUserState(SessionStorageBase):
"""Represents a user state stored in the database with TTL support.
Expand All @@ -319,6 +339,10 @@ class StorageUserState(SessionStorageBase):
state: Mapped[MutableDict[str, Any]] = mapped_column(MutableDict.as_mutable(DynamicJSON), default={})
update_time: Mapped[datetime] = mapped_column(PreciseTimestamp, default=func.now(), onupdate=func.now())

@property
def update_timestamp_tz(self) -> float:
return _timestamp_tz(self.update_time, _storage_dialect_name(self))


class SqlSessionService(BaseSessionService):
"""A SQL database implementation of the session service.
Expand Down Expand Up @@ -452,7 +476,7 @@ async def list_sessions(self, *, app_name: str, user_id: str) -> ListSessionsRes

sessions = []
for storage_session in results:
if self._session_config.is_expired_by_timestamp(storage_session.update_time.timestamp()):
if self._session_config.is_expired_by_timestamp(storage_session.update_timestamp_tz):
logger.debug("Cleaned up expired session: %s/%s/%s", storage_session.app_name,
storage_session.user_id, storage_session.id)
continue
Expand Down Expand Up @@ -593,7 +617,7 @@ async def _update_app_state(self, sql_session: SqlSession, app_name: str, state_
await self._sql_storage.add(sql_session, storage_app_state)
else:
storage_app_state.state = app_state # type: ignore
storage_app_state.update_time = datetime.now()
storage_app_state.update_time = func.now()

return app_state

Expand Down Expand Up @@ -621,9 +645,9 @@ async def _get_app_state(self, sql_session: SqlSession, app_name: str) -> dict[s

app_state = {}
if storage_app_state:
if not self._session_config.is_expired_by_timestamp(storage_app_state.update_time.timestamp()):
if not self._session_config.is_expired_by_timestamp(storage_app_state.update_timestamp_tz):
app_state = storage_app_state.state
storage_app_state.update_time = datetime.now()
storage_app_state.update_time = func.now()
await self._sql_storage.commit(sql_session)

return app_state
Expand All @@ -634,9 +658,9 @@ async def _get_user_state(self, sql_session: SqlSession, app_name: str, user_id:

user_state = {}
if storage_user_state:
if not self._session_config.is_expired_by_timestamp(storage_user_state.update_time.timestamp()):
if not self._session_config.is_expired_by_timestamp(storage_user_state.update_timestamp_tz):
user_state = storage_user_state.state
storage_user_state.update_time = datetime.now()
storage_user_state.update_time = func.now()
await self._sql_storage.commit(sql_session)

return user_state
Expand All @@ -648,11 +672,11 @@ async def _get_session(self, sql_session: SqlSession, app_name: str, user_id: st
if storage_session is None:
return None

if self._session_config.is_expired_by_timestamp(storage_session.update_time.timestamp()):
if self._session_config.is_expired_by_timestamp(storage_session.update_timestamp_tz):
logger.debug("Session %s is expired", session_id)
return None

storage_session.update_time = datetime.now()
storage_session.update_time = func.now()
await self._sql_storage.commit(sql_session)

return storage_session
Expand All @@ -665,7 +689,7 @@ async def _cleanup_expired_async(self) -> None:
"""
async with self._sql_storage.create_db_session() as sql_session:
# Calculate expiration threshold once in application time for cross-database compatibility.
expire_before = datetime.now() - timedelta(seconds=self._session_config.ttl.ttl_seconds)
expire_before = _expire_before(sql_session, self._session_config.ttl.ttl_seconds)
total_deleted = 0

# Batch delete expired sessions
Expand Down
Loading