From cc6b946b22185ce7ce6af05c34a62cf5ad4c3a0f Mon Sep 17 00:00:00 2001 From: weimch Date: Mon, 1 Jun 2026 19:09:59 +0800 Subject: [PATCH] =?UTF-8?q?Bugfix:=20=E4=BF=AE=E5=A4=8DSqlSessionService?= =?UTF-8?q?=E5=9C=A8sqlite=E4=B8=8B=E5=9B=A0=E6=97=B6=E5=8C=BA=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE=E4=B8=8D=E5=AF=B9=E5=AF=BC=E8=87=B4=E9=A2=91=E7=B9=81?= =?UTF-8?q?warn?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 问题:当前在SqlSessionService的实现中,默认创建DB的表中,update_time使用了sqlalchemy的func.now,但在更新时间时,使用了datetime.now,在sqlite实现里,func.now默认使用了utc的时间,而datetime.now不是utc的时间,导致append_event时,因时区不同,导致diff失败,warn告警 - 解决方案:总是使用func.now来更新时间 --- .../sessions/_sql_session_service.py | 52 ++++++++++++++----- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/trpc_agent_sdk/sessions/_sql_session_service.py b/trpc_agent_sdk/sessions/_sql_session_service.py index dc07e2b..82dd41a 100644 --- a/trpc_agent_sdk/sessions/_sql_session_service.py +++ b/trpc_agent_sdk/sessions/_sql_session_service.py @@ -103,6 +103,25 @@ class SessionStorageBase(DeclarativeBase): pass +def _storage_dialect_name(storage: SessionStorageBase) -> Optional[str]: + orm_session = inspect(storage).session + if orm_session is None or orm_session.bind is None: + return None + return orm_session.bind.dialect.name + + +def _timestamp_tz(value: datetime, dialect_name: Optional[str]) -> float: + if dialect_name == "sqlite": + return value.replace(tzinfo=timezone.utc).timestamp() + return value.timestamp() + + +def _expire_before(sql_session: SqlSession, ttl_seconds: int) -> datetime: + if sql_session.bind is not None and sql_session.bind.dialect.name == "sqlite": + return datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(seconds=ttl_seconds) + return datetime.now() - timedelta(seconds=ttl_seconds) + + class StorageSession(SessionStorageBase): """Represents a session stored in the database with TTL support. @@ -135,14 +154,11 @@ def __repr__(self): @property def _dialect_name(self) -> Optional[str]: - session = inspect(self).session - return session.bind.dialect.name if session else None # type: ignore + return _storage_dialect_name(self) @property def update_timestamp_tz(self) -> float: - if self._dialect_name == "sqlite": - return self.update_time.replace(tzinfo=timezone.utc).timestamp() - return self.update_time.timestamp() + return _timestamp_tz(self.update_time, self._dialect_name) def to_session( self, @@ -306,6 +322,10 @@ class StorageAppState(SessionStorageBase): state: Mapped[MutableDict[str, Any]] = mapped_column(MutableDict.as_mutable(DynamicJSON), default={}) update_time: Mapped[datetime] = mapped_column(PreciseTimestamp, default=func.now(), onupdate=func.now()) + @property + def update_timestamp_tz(self) -> float: + return _timestamp_tz(self.update_time, _storage_dialect_name(self)) + class StorageUserState(SessionStorageBase): """Represents a user state stored in the database with TTL support. @@ -319,6 +339,10 @@ class StorageUserState(SessionStorageBase): state: Mapped[MutableDict[str, Any]] = mapped_column(MutableDict.as_mutable(DynamicJSON), default={}) update_time: Mapped[datetime] = mapped_column(PreciseTimestamp, default=func.now(), onupdate=func.now()) + @property + def update_timestamp_tz(self) -> float: + return _timestamp_tz(self.update_time, _storage_dialect_name(self)) + class SqlSessionService(BaseSessionService): """A SQL database implementation of the session service. @@ -452,7 +476,7 @@ async def list_sessions(self, *, app_name: str, user_id: str) -> ListSessionsRes sessions = [] for storage_session in results: - if self._session_config.is_expired_by_timestamp(storage_session.update_time.timestamp()): + if self._session_config.is_expired_by_timestamp(storage_session.update_timestamp_tz): logger.debug("Cleaned up expired session: %s/%s/%s", storage_session.app_name, storage_session.user_id, storage_session.id) continue @@ -593,7 +617,7 @@ async def _update_app_state(self, sql_session: SqlSession, app_name: str, state_ await self._sql_storage.add(sql_session, storage_app_state) else: storage_app_state.state = app_state # type: ignore - storage_app_state.update_time = datetime.now() + storage_app_state.update_time = func.now() return app_state @@ -621,9 +645,9 @@ async def _get_app_state(self, sql_session: SqlSession, app_name: str) -> dict[s app_state = {} if storage_app_state: - if not self._session_config.is_expired_by_timestamp(storage_app_state.update_time.timestamp()): + if not self._session_config.is_expired_by_timestamp(storage_app_state.update_timestamp_tz): app_state = storage_app_state.state - storage_app_state.update_time = datetime.now() + storage_app_state.update_time = func.now() await self._sql_storage.commit(sql_session) return app_state @@ -634,9 +658,9 @@ async def _get_user_state(self, sql_session: SqlSession, app_name: str, user_id: user_state = {} if storage_user_state: - if not self._session_config.is_expired_by_timestamp(storage_user_state.update_time.timestamp()): + if not self._session_config.is_expired_by_timestamp(storage_user_state.update_timestamp_tz): user_state = storage_user_state.state - storage_user_state.update_time = datetime.now() + storage_user_state.update_time = func.now() await self._sql_storage.commit(sql_session) return user_state @@ -648,11 +672,11 @@ async def _get_session(self, sql_session: SqlSession, app_name: str, user_id: st if storage_session is None: return None - if self._session_config.is_expired_by_timestamp(storage_session.update_time.timestamp()): + if self._session_config.is_expired_by_timestamp(storage_session.update_timestamp_tz): logger.debug("Session %s is expired", session_id) return None - storage_session.update_time = datetime.now() + storage_session.update_time = func.now() await self._sql_storage.commit(sql_session) return storage_session @@ -665,7 +689,7 @@ async def _cleanup_expired_async(self) -> None: """ async with self._sql_storage.create_db_session() as sql_session: # Calculate expiration threshold once in application time for cross-database compatibility. - expire_before = datetime.now() - timedelta(seconds=self._session_config.ttl.ttl_seconds) + expire_before = _expire_before(sql_session, self._session_config.ttl.ttl_seconds) total_deleted = 0 # Batch delete expired sessions