diff --git a/pyproject.toml b/pyproject.toml index 11afcd8..7878921 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ # go/keep-sorted end "orjson>=3.11.3", ] + dynamic = ["version"] [project.urls] @@ -45,6 +46,11 @@ test = [ "pytest>=8.4.2", "pytest-asyncio>=1.2.0", ] +sqlalchemy = [ + # Required for DatabaseMemoryService; aiosqlite enables SQLite async + "sqlalchemy[asyncio]>=2.0.0", + "aiosqlite>=0.19.0", +] [tool.pyink] @@ -74,6 +80,8 @@ build-backend = "flit_core.buildapi" dev = [ "pytest>=8.4.2", "pytest-asyncio>=1.2.0", + "sqlalchemy[asyncio]>=2.0.0", + "aiosqlite>=0.19.0", ] diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35..98c0e33 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -15,4 +15,5 @@ from . import memory from . import sessions from . import version + __version__ = version.__version__ diff --git a/src/google/adk_community/memory/__init__.py b/src/google/adk_community/memory/__init__.py index 1f3442c..f29e1c0 100644 --- a/src/google/adk_community/memory/__init__.py +++ b/src/google/adk_community/memory/__init__.py @@ -14,11 +14,20 @@ """Community memory services for ADK.""" +try: + from .database_memory_service import DatabaseMemoryService + from .memory_search_backend import KeywordSearchBackend + from .memory_search_backend import MemorySearchBackend +except ImportError: + pass + from .open_memory_service import OpenMemoryService from .open_memory_service import OpenMemoryServiceConfig __all__ = [ - "OpenMemoryService", - "OpenMemoryServiceConfig", + 'DatabaseMemoryService', + 'KeywordSearchBackend', + 'MemorySearchBackend', + 'OpenMemoryService', + 'OpenMemoryServiceConfig', ] - diff --git a/src/google/adk_community/memory/database_memory_service.py b/src/google/adk_community/memory/database_memory_service.py new file mode 100644 index 0000000..1847a28 --- /dev/null +++ 
b/src/google/adk_community/memory/database_memory_service.py @@ -0,0 +1,564 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SQL-backed memory service with scratchpad support for ADK agents.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +from collections.abc import Sequence +from contextlib import asynccontextmanager +from datetime import datetime +import logging +from typing import Any +from typing import AsyncIterator +from typing import Optional +from typing import TYPE_CHECKING +import uuid + +from google.adk.memory.base_memory_service import BaseMemoryService +from google.adk.memory.base_memory_service import SearchMemoryResponse +from google.adk.memory.memory_entry import MemoryEntry +from google.genai import types +from sqlalchemy import delete +from sqlalchemy import select +from sqlalchemy.engine import make_url +from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.pool import StaticPool +from typing_extensions import override + +from .memory_search_backend import KeywordSearchBackend +from .memory_search_backend import MemorySearchBackend +from .schemas.memory_schema import Base +from .schemas.memory_schema import StorageMemoryEntry +from .schemas.memory_schema 
import StorageScratchpadKV +from .schemas.memory_schema import StorageScratchpadLog + +if TYPE_CHECKING: + from google.adk.events.event import Event + from google.adk.sessions.session import Session + +logger = logging.getLogger('google_adk.' + __name__) + +_SQLITE_DIALECT = 'sqlite' + + +def _format_timestamp(timestamp: float) -> str: + return datetime.fromtimestamp(timestamp).isoformat() + + +class DatabaseMemoryService(BaseMemoryService): + """A durable, SQL-backed memory service for any SQLAlchemy-supported DB. + + Works with SQLite, PostgreSQL, MySQL, and MariaDB. Also exposes a + scratchpad (KV store + append-log) for agents to use as intermediate + working memory during task execution. + + Usage:: + + from google.adk_community.memory import DatabaseMemoryService + + # SQLite (no external DB needed): + svc = DatabaseMemoryService("sqlite+aiosqlite:///:memory:") + + # PostgreSQL: + svc = DatabaseMemoryService( + "postgresql+asyncpg://user:pass@host/dbname" + ) + """ + + def __init__( + self, + db_url: str, + search_backend: Optional[MemorySearchBackend] = None, + **kwargs: Any, + ): + """Initialises the service and creates a DB engine. + + Args: + db_url: SQLAlchemy async connection URL. + search_backend: Optional custom search backend. Defaults to + KeywordSearchBackend. + **kwargs: Extra keyword arguments forwarded to + sqlalchemy.ext.asyncio.create_async_engine. + + Raises: + ValueError: If the db_url is invalid or the required DB driver is + not installed. 
+ """ + try: + engine_kwargs = dict(kwargs) + url = make_url(db_url) + backend = url.get_backend_name() + if backend == _SQLITE_DIALECT and url.database == ':memory:': + engine_kwargs.setdefault('poolclass', StaticPool) + connect_args = dict(engine_kwargs.get('connect_args', {})) + connect_args.setdefault('check_same_thread', False) + engine_kwargs['connect_args'] = connect_args + elif backend != _SQLITE_DIALECT: + engine_kwargs.setdefault('pool_pre_ping', True) + + self.db_engine: AsyncEngine = create_async_engine(db_url, **engine_kwargs) + except ArgumentError as exc: + raise ValueError( + f"Invalid database URL format or argument '{db_url}'." + ) from exc + except ImportError as exc: + raise ValueError( + f"Database-related module not found for URL '{db_url}'." + ) from exc + + self._session_factory: async_sessionmaker[AsyncSession] = ( + async_sessionmaker(bind=self.db_engine, expire_on_commit=False) + ) + self._search_backend: MemorySearchBackend = ( + search_backend or KeywordSearchBackend() + ) + self._tables_created = False + self._table_creation_lock = asyncio.Lock() + + # --------------------------------------------------------------------------- + # Internal helpers + # --------------------------------------------------------------------------- + + @asynccontextmanager + async def _session(self) -> AsyncIterator[AsyncSession]: + """Yield an AsyncSession; roll back on exception.""" + async with self._session_factory() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + async def _prepare_tables(self) -> None: + """Lazy, double-checked table initialisation.""" + if self._tables_created: + return + async with self._table_creation_lock: + if self._tables_created: + return + async with self.db_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + self._tables_created = True + + @staticmethod + def _extract_search_text(content: types.Content) -> str: + """Join all text parts 
of a Content into a single searchable string.""" + if not content or not content.parts: + return '' + return ' '.join(part.text for part in content.parts if part.text) + + @staticmethod + def _should_skip_event(event: Event) -> bool: + """Return True if the event has no usable text content.""" + if not event.content or not event.content.parts: + return True + return not any(part.text for part in event.content.parts if part.text) + + # --------------------------------------------------------------------------- + # BaseMemoryService implementation + # --------------------------------------------------------------------------- + + @override + async def add_session_to_memory(self, session: Session) -> None: + """Idempotently ingest all events from a session. + + Deletes any existing rows for this session, then re-inserts from scratch. + + Args: + session: The session whose events should be stored in memory. + """ + await self._prepare_tables() + async with self._session() as sql: + await sql.execute( + delete(StorageMemoryEntry).where( + StorageMemoryEntry.app_name == session.app_name, + StorageMemoryEntry.user_id == session.user_id, + StorageMemoryEntry.session_id == session.id, + ) + ) + for event in session.events: + if self._should_skip_event(event): + continue + content_dict = event.content.model_dump(mode='json', exclude_none=True) + sql.add( + StorageMemoryEntry( + id=str(uuid.uuid4()), + app_name=session.app_name, + user_id=session.user_id, + session_id=session.id, + event_id=event.id, + author=event.author, + timestamp=_format_timestamp(event.timestamp), + content_json=content_dict, + search_text=self._extract_search_text(event.content), + custom_metadata={}, + ) + ) + + @override + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence[Event], + session_id: Optional[str] = None, + custom_metadata: Optional[Mapping[str, object]] = None, + ) -> None: + """Delta-insert events; skips duplicate event_id within the same 
session. + + Args: + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + events: The events to add to memory. + session_id: Optional session ID for memory scope/partitioning. + custom_metadata: Optional metadata attached to each stored entry. + """ + await self._prepare_tables() + async with self._session() as sql: + stmt = select(StorageMemoryEntry.event_id).where( + StorageMemoryEntry.app_name == app_name, + StorageMemoryEntry.user_id == user_id, + StorageMemoryEntry.session_id == session_id, + StorageMemoryEntry.event_id.isnot(None), + ) + result = await sql.execute(stmt) + existing_event_ids = {row[0] for row in result.fetchall()} + + meta = dict(custom_metadata) if custom_metadata else {} + for event in events: + if self._should_skip_event(event): + continue + if event.id and event.id in existing_event_ids: + continue + content_dict = event.content.model_dump(mode='json', exclude_none=True) + sql.add( + StorageMemoryEntry( + id=str(uuid.uuid4()), + app_name=app_name, + user_id=user_id, + session_id=session_id, + event_id=event.id, + author=event.author, + timestamp=_format_timestamp(event.timestamp), + content_json=content_dict, + search_text=self._extract_search_text(event.content), + custom_metadata=meta, + ) + ) + if event.id: + existing_event_ids.add(event.id) + + @override + async def add_memory( + self, + *, + app_name: str, + user_id: str, + memories: Sequence[MemoryEntry], + custom_metadata: Optional[Mapping[str, object]] = None, + ) -> None: + """Directly insert MemoryEntry objects (not tied to session events). + + Args: + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + memories: Explicit memory items to add. + custom_metadata: Optional metadata attached to each stored entry. 
+ """ + await self._prepare_tables() + meta = dict(custom_metadata) if custom_metadata else {} + async with self._session() as sql: + for entry in memories: + entry_id = entry.id or str(uuid.uuid4()) + content_dict = entry.content.model_dump(mode='json', exclude_none=True) + sql.add( + StorageMemoryEntry( + id=entry_id, + app_name=app_name, + user_id=user_id, + session_id=None, + event_id=None, + author=entry.author, + timestamp=entry.timestamp, + content_json=content_dict, + search_text=self._extract_search_text(entry.content), + custom_metadata={**entry.custom_metadata, **meta}, + ) + ) + + @override + async def search_memory( + self, + *, + app_name: str, + user_id: str, + query: str, + ) -> SearchMemoryResponse: + """Search stored memories using the configured search backend. + + Args: + app_name: The name of the application. + user_id: The id of the user. + query: The query to search for. + + Returns: + A SearchMemoryResponse containing the matching memories. + """ + await self._prepare_tables() + async with self._session() as sql: + rows = await self._search_backend.search( + sql_session=sql, + app_name=app_name, + user_id=user_id, + query=query, + ) + memories = [] + for row in rows: + try: + content = types.Content.model_validate(row.content_json) + except Exception: # pylint: disable=broad-except + logger.warning( + 'Skipping memory entry %s: invalid content JSON', row.id + ) + continue + memories.append( + MemoryEntry( + id=row.id, + content=content, + author=row.author, + timestamp=row.timestamp, + custom_metadata=row.custom_metadata or {}, + ) + ) + return SearchMemoryResponse(memories=memories) + + # --------------------------------------------------------------------------- + # Scratchpad KV methods + # --------------------------------------------------------------------------- + + async def set_scratchpad( + self, + *, + app_name: str, + user_id: str, + session_id: str = '', + key: str, + value: Any, + ) -> None: + """Write a key-value pair to the 
scratchpad. + + Overwrites any existing value for the same composite key. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level (non-session) KV. + key: The key to write. + value: The JSON-serialisable value to store. + """ + await self._prepare_tables() + async with self._session() as sql: + existing = await sql.get( + StorageScratchpadKV, (app_name, user_id, session_id, key) + ) + if existing is not None: + existing.value_json = value + else: + sql.add( + StorageScratchpadKV( + app_name=app_name, + user_id=user_id, + session_id=session_id, + key=key, + value_json=value, + ) + ) + + async def get_scratchpad( + self, + *, + app_name: str, + user_id: str, + session_id: str = '', + key: str, + ) -> Any | None: + """Read a value from the scratchpad. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level (non-session) KV. + key: The key to read. + + Returns: + The stored value, or None if the key does not exist. + """ + await self._prepare_tables() + async with self._session() as sql: + row = await sql.get( + StorageScratchpadKV, (app_name, user_id, session_id, key) + ) + return row.value_json if row is not None else None + + async def delete_scratchpad( + self, + *, + app_name: str, + user_id: str, + session_id: str = '', + key: str, + ) -> None: + """Delete a key-value pair from the scratchpad. No-op if not found. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level (non-session) KV. + key: The key to delete. 
+ """ + await self._prepare_tables() + async with self._session() as sql: + await sql.execute( + delete(StorageScratchpadKV).where( + StorageScratchpadKV.app_name == app_name, + StorageScratchpadKV.user_id == user_id, + StorageScratchpadKV.session_id == session_id, + StorageScratchpadKV.key == key, + ) + ) + + async def list_scratchpad_keys( + self, + *, + app_name: str, + user_id: str, + session_id: str = '', + ) -> list[str]: + """Return all keys present in the scratchpad for the given scope. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level (non-session) KV. + + Returns: + A list of key strings. + """ + await self._prepare_tables() + async with self._session() as sql: + result = await sql.execute( + select(StorageScratchpadKV.key).where( + StorageScratchpadKV.app_name == app_name, + StorageScratchpadKV.user_id == user_id, + StorageScratchpadKV.session_id == session_id, + ) + ) + return [row[0] for row in result.fetchall()] + + # --------------------------------------------------------------------------- + # Scratchpad log methods + # --------------------------------------------------------------------------- + + async def append_log( + self, + *, + app_name: str, + user_id: str, + session_id: str = '', + content: str, + tag: Optional[str] = None, + agent_name: Optional[str] = None, + extra: Optional[Any] = None, + ) -> None: + """Append an entry to the append-only scratchpad log. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level log. + content: The text content to log. + tag: Optional category label for filtering. + agent_name: Optional name of the agent appending this entry. + extra: Optional JSON-serialisable extra data. 
+ """ + await self._prepare_tables() + async with self._session() as sql: + sql.add( + StorageScratchpadLog( + app_name=app_name, + user_id=user_id, + session_id=session_id, + tag=tag, + agent_name=agent_name, + content=content, + extra_json=extra, + ) + ) + + async def get_log( + self, + *, + app_name: str, + user_id: str, + session_id: str = '', + tag: Optional[str] = None, + limit: int = 50, + ) -> list[dict]: + """Read the most recent log entries, optionally filtered by tag. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level log. + tag: Optional tag to filter results by. + limit: Maximum number of entries to return. + + Returns: + A list of dicts with keys: id, tag, agent_name, content, extra. + """ + await self._prepare_tables() + async with self._session() as sql: + stmt = ( + select(StorageScratchpadLog) + .where( + StorageScratchpadLog.app_name == app_name, + StorageScratchpadLog.user_id == user_id, + StorageScratchpadLog.session_id == session_id, + ) + .order_by(StorageScratchpadLog.id.desc()) + .limit(limit) + ) + if tag is not None: + stmt = stmt.where(StorageScratchpadLog.tag == tag) + result = await sql.execute(stmt) + rows = result.scalars().all() + return [ + { + 'id': r.id, + 'tag': r.tag, + 'agent_name': r.agent_name, + 'content': r.content, + 'extra': r.extra_json, + } + for r in reversed(rows) + ] diff --git a/src/google/adk_community/memory/memory_search_backend.py b/src/google/adk_community/memory/memory_search_backend.py new file mode 100644 index 0000000..2a98dec --- /dev/null +++ b/src/google/adk_community/memory/memory_search_backend.py @@ -0,0 +1,126 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Memory search backends for DatabaseMemoryService.""" + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod +from collections.abc import Sequence +import re +from typing import TYPE_CHECKING + +from sqlalchemy import or_ +from sqlalchemy import select + +from .schemas.memory_schema import StorageMemoryEntry + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + +_ILIKE_DIALECTS = frozenset({'postgresql', 'mysql', 'mariadb'}) + + +class MemorySearchBackend(ABC): + """Abstract base class for memory search strategies.""" + + @abstractmethod + async def search( + self, + *, + sql_session: AsyncSession, + app_name: str, + user_id: str, + query: str, + limit: int = 10, + ) -> Sequence[StorageMemoryEntry]: + """Search for memory entries matching the query. + + Args: + sql_session: The active async SQLAlchemy session. + app_name: Application name scope. + user_id: User ID scope. + query: Natural-language or keyword query string. + limit: Maximum number of results to return. + + Returns: + A sequence of matching StorageMemoryEntry rows. + """ + + +class KeywordSearchBackend(MemorySearchBackend): + """LIKE/ILIKE keyword search on the search_text column. + + Strategy: + 1. Tokenise the query into individual words. + 2. Try an AND predicate (all tokens must appear) — return if found. + 3. Fall back to OR (any token matches) if AND yields nothing. + + Uses ILIKE on PostgreSQL/MySQL/MariaDB and LIKE on SQLite + (case-insensitive by default collation). 
+ """ + + async def search( + self, + *, + sql_session: AsyncSession, + app_name: str, + user_id: str, + query: str, + limit: int = 10, + ) -> Sequence[StorageMemoryEntry]: + """Search for memory entries using LIKE/ILIKE keyword matching.""" + if not query or not query.strip(): + return [] + + tokens = [ + cleaned + for raw in query.split() + if raw.strip() + for cleaned in [re.sub(r'[^\w]', '', raw).lower()] + if cleaned + ] + if not tokens: + return [] + + dialect_name = sql_session.get_bind().dialect.name + use_ilike = dialect_name in _ILIKE_DIALECTS + + def _like_expr(token: str): + pattern = f'%{token}%' + col = StorageMemoryEntry.search_text + return col.ilike(pattern) if use_ilike else col.like(pattern) + + base_stmt = ( + select(StorageMemoryEntry) + .where( + StorageMemoryEntry.app_name == app_name, + StorageMemoryEntry.user_id == user_id, + StorageMemoryEntry.search_text.isnot(None), + ) + .limit(limit) + ) + + # AND predicate: all tokens must match. + and_stmt = base_stmt.where(*[_like_expr(t) for t in tokens]) + result = await sql_session.execute(and_stmt) + rows = result.scalars().all() + if rows: + return rows + + # OR fallback: any token matches. 
+ or_stmt = base_stmt.where(or_(*[_like_expr(t) for t in tokens])) + result = await sql_session.execute(or_stmt) + return result.scalars().all() diff --git a/src/google/adk_community/memory/open_memory_service.py b/src/google/adk_community/memory/open_memory_service.py index 92c1ae6..4e949cf 100644 --- a/src/google/adk_community/memory/open_memory_service.py +++ b/src/google/adk_community/memory/open_memory_service.py @@ -19,27 +19,27 @@ from typing import Optional from typing import TYPE_CHECKING -import httpx -from google.genai import types -from pydantic import BaseModel -from pydantic import Field -from typing_extensions import override - from google.adk.memory import _utils from google.adk.memory.base_memory_service import BaseMemoryService from google.adk.memory.base_memory_service import SearchMemoryResponse from google.adk.memory.memory_entry import MemoryEntry +from google.genai import types +import httpx +from pydantic import BaseModel +from pydantic import Field +from typing_extensions import override from .utils import extract_text_from_event if TYPE_CHECKING: from google.adk.sessions.session import Session -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) + class OpenMemoryService(BaseMemoryService): """Memory service implementation using OpenMemory. - + See https://openmemory.cavira.app/ for more information. """ @@ -55,7 +55,7 @@ def __init__( base_url: Base URL of the OpenMemory instance (default: http://localhost:3000). api_key: API key for authentication. **Required** - must be provided. config: OpenMemoryServiceConfig instance. If None, uses defaults. - + Raises: ValueError: If api_key is not provided or is empty. """ @@ -64,7 +64,7 @@ def __init__( "api_key is required for OpenMemory. " "Provide an API key when initializing OpenMemoryService." 
) - self._base_url = base_url.rstrip('/') + self._base_url = base_url.rstrip("/") self._api_key = api_key self._config = config or OpenMemoryServiceConfig() @@ -81,14 +81,12 @@ def _determine_salience(self, author: Optional[str]) -> float: else: return self._config.default_salience - def _prepare_memory_data( - self, event, content_text: str, session - ) -> dict: + def _prepare_memory_data(self, event, content_text: str, session) -> dict: """Prepare memory data structure for OpenMemory API.""" timestamp_str = None if event.timestamp: timestamp_str = _utils.format_timestamp(event.timestamp) - + # Embed author and timestamp in content for search retrieval # Format: [Author: user, Time: 2025-11-04T10:32:01] Content text enriched_content = content_text @@ -97,11 +95,11 @@ def _prepare_memory_data( metadata_parts.append(f"Author: {event.author}") if timestamp_str: metadata_parts.append(f"Time: {timestamp_str}") - + if metadata_parts: metadata_prefix = "[" + ", ".join(metadata_parts) + "] " enriched_content = metadata_prefix + content_text - + metadata = { "app_name": session.app_name, "user_id": session.user_id, @@ -110,13 +108,13 @@ def _prepare_memory_data( "invocation_id": event.invocation_id, "author": event.author, "timestamp": event.timestamp, - "source": "adk_session" + "source": "adk_session", } - + memory_data = { "content": enriched_content, "metadata": metadata, - "salience": self._determine_salience(event.author) + "salience": self._determine_salience(event.author), } if self._config.enable_metadata_tags: @@ -138,7 +136,7 @@ async def add_session_to_memory(self, session: Session): async with httpx.AsyncClient(timeout=self._config.timeout) as http_client: headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self._api_key}" + "Authorization": f"Bearer {self._api_key}", } for event in session.events: @@ -155,16 +153,14 @@ async def add_session_to_memory(self, session: Session): "tags": memory_data.get("tags", []), "metadata": 
memory_data.get("metadata", {}), "salience": memory_data.get("salience", 0.5), - "user_id": session.user_id + "user_id": session.user_id, } - + response = await http_client.post( - f"{self._base_url}/memory/add", - json=payload, - headers=headers + f"{self._base_url}/memory/add", json=payload, headers=headers ) response.raise_for_status() - + memories_added += 1 logger.debug("Added memory for event %s", event.id) except httpx.HTTPStatusError as e: @@ -176,24 +172,24 @@ async def add_session_to_memory(self, session: Session): ) except httpx.RequestError as e: logger.error( - "Failed to add memory for event %s due to request error: %s", event.id, e + "Failed to add memory for event %s due to request error: %s", + event.id, + e, ) except Exception as e: - logger.error("Failed to add memory for event %s due to unexpected error: %s", event.id, e) + logger.error( + "Failed to add memory for event %s due to unexpected error: %s", + event.id, + e, + ) - logger.info( - "Added %d memories from session %s", memories_added, session.id - ) + logger.info("Added %d memories from session %s", memories_added, session.id) def _build_search_payload( self, app_name: str, user_id: str, query: str ) -> dict: """Build search payload for OpenMemory query API.""" - payload = { - "query": query, - "k": self._config.search_top_k, - "filter": {} - } + payload = {"query": query, "k": self._config.search_top_k, "filter": {}} payload["filter"]["user_id"] = user_id @@ -204,7 +200,7 @@ def _build_search_payload( def _convert_to_memory_entry(self, result: dict) -> Optional[MemoryEntry]: """Convert OpenMemory result to MemoryEntry. 
- + Extracts author and timestamp from enriched content format: [Author: user, Time: 2025-11-04T10:32:01] Content text """ @@ -213,28 +209,24 @@ def _convert_to_memory_entry(self, result: dict) -> Optional[MemoryEntry]: author = None timestamp = None clean_content = raw_content - + # Parse enriched content format to extract metadata - match = re.match(r'^\[([^\]]+)\]\s+(.*)', raw_content, re.DOTALL) + match = re.match(r"^\[([^\]]+)\]\s+(.*)", raw_content, re.DOTALL) if match: metadata_str = match.group(1) clean_content = match.group(2) - - author_match = re.search(r'Author:\s*([^,\]]+)', metadata_str) + + author_match = re.search(r"Author:\s*([^,\]]+)", metadata_str) if author_match: author = author_match.group(1).strip() - - time_match = re.search(r'Time:\s*([^,\]]+)', metadata_str) + + time_match = re.search(r"Time:\s*([^,\]]+)", metadata_str) if time_match: timestamp = time_match.group(1).strip() - + content = types.Content(parts=[types.Part(text=clean_content)]) - return MemoryEntry( - content=content, - author=author, - timestamp=timestamp - ) + return MemoryEntry(content=content, author=author, timestamp=timestamp) except (KeyError, ValueError) as e: logger.debug("Failed to convert result to MemoryEntry: %s", e) return None @@ -247,25 +239,27 @@ async def search_memory( try: search_payload = self._build_search_payload(app_name, user_id, query) memories = [] - + async with httpx.AsyncClient(timeout=self._config.timeout) as http_client: headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self._api_key}" + "Authorization": f"Bearer {self._api_key}", } - + logger.debug("Query payload: %s", search_payload) - + response = await http_client.post( f"{self._base_url}/memory/query", json=search_payload, - headers=headers + headers=headers, ) response.raise_for_status() result = response.json() - - logger.debug("Query returned %d matches", len(result.get("matches", []))) - + + logger.debug( + "Query returned %d matches", 
len(result.get("matches", [])) + ) + for match in result.get("matches", []): memory_entry = self._convert_to_memory_entry(match) if memory_entry: diff --git a/src/google/adk_community/memory/schemas/__init__.py b/src/google/adk_community/memory/schemas/__init__.py new file mode 100644 index 0000000..0a2669d --- /dev/null +++ b/src/google/adk_community/memory/schemas/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk_community/memory/schemas/memory_schema.py b/src/google/adk_community/memory/schemas/memory_schema.py new file mode 100644 index 0000000..ebc3f9c --- /dev/null +++ b/src/google/adk_community/memory/schemas/memory_schema.py @@ -0,0 +1,196 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""SQLAlchemy ORM schema for DatabaseMemoryService tables.""" + +from __future__ import annotations + +import json +from typing import Any +from typing import Optional + +from sqlalchemy import DateTime +from sqlalchemy import Dialect +from sqlalchemy import func +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import Text +from sqlalchemy.dialects import mysql +from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.types import String +from sqlalchemy.types import TypeDecorator + +DEFAULT_MAX_KEY_LENGTH = 128 +DEFAULT_MAX_VARCHAR_LENGTH = 256 + + +class DynamicJSON(TypeDecorator): + """JSON type using JSONB on PostgreSQL and TEXT elsewhere.""" + + impl = Text + cache_ok = True + + def load_dialect_impl(self, dialect: Dialect): + if dialect.name == 'postgresql': + return dialect.type_descriptor(postgresql.JSONB) + if dialect.name == 'mysql': + return dialect.type_descriptor(mysql.LONGTEXT) + return dialect.type_descriptor(Text) + + def process_bind_param(self, value, dialect: Dialect): + if value is not None: + if dialect.name == 'postgresql': + return value + return json.dumps(value) + return value + + def process_result_value(self, value, dialect: Dialect): + if value is not None: + if dialect.name == 'postgresql': + return value + return json.loads(value) + return value + + +class PreciseTimestamp(TypeDecorator): + """Timestamp with microsecond precision.""" + + impl = DateTime + cache_ok = True + + def load_dialect_impl(self, dialect: Dialect): + if dialect.name == 'mysql': + return dialect.type_descriptor(mysql.DATETIME(fsp=6)) + return self.impl + + +class Base(DeclarativeBase): + """Declarative base for memory schema tables.""" + + pass + + +class StorageMemoryEntry(Base): + """ORM model for the adk_memory_entries table.""" + + __tablename__ = 
'adk_memory_entries' + + id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=False, index=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=False, index=True + ) + session_id: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=True + ) + event_id: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=True + ) + author: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=True + ) + timestamp: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + content_json: Mapped[Any] = mapped_column(DynamicJSON, nullable=True) + search_text: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + custom_metadata: Mapped[Any] = mapped_column( + MutableDict.as_mutable(DynamicJSON), nullable=True + ) + created_at: Mapped[Any] = mapped_column( + PreciseTimestamp, server_default=func.now() + ) + + __table_args__ = ( + Index('ix_memory_entries_app_user', 'app_name', 'user_id'), + Index('ix_memory_entries_session', 'app_name', 'user_id', 'session_id'), + ) + + +class StorageScratchpadKV(Base): + """ORM model for the adk_scratchpad_kv table. + + Composite PK: (app_name, user_id, session_id, key). + Use session_id='' as a sentinel for user-level (non-session) KV. 
+ """ + + __tablename__ = 'adk_scratchpad_kv' + + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + session_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + key: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + value_json: Mapped[Any] = mapped_column(DynamicJSON, nullable=False) + updated_at: Mapped[Any] = mapped_column( + PreciseTimestamp, + server_default=func.now(), + onupdate=func.now(), + ) + + +class StorageScratchpadLog(Base): + """ORM model for the adk_scratchpad_log table. + + Append-only. id is autoincrement int to preserve insertion order. + Use session_id='' as a sentinel for user-level (non-session) log. + """ + + __tablename__ = 'adk_scratchpad_log' + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=False + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=False + ) + session_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=False + ) + tag: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=True, index=True + ) + agent_name: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=True + ) + content: Mapped[str] = mapped_column(Text, nullable=False) + extra_json: Mapped[Optional[Any]] = mapped_column(DynamicJSON, nullable=True) + created_at: Mapped[Any] = mapped_column( + PreciseTimestamp, server_default=func.now() + ) + + __table_args__ = ( + Index( + 'ix_scratchpad_log_scope', + 'app_name', + 'user_id', + 'session_id', + ), + ) diff --git a/src/google/adk_community/memory/utils.py b/src/google/adk_community/memory/utils.py index 0b78206..48058ce 100644 --- a/src/google/adk_community/memory/utils.py +++ 
b/src/google/adk_community/memory/utils.py @@ -33,9 +33,8 @@ def extract_text_from_event(event) -> str: # Filter out thought parts and only extract text # This prevents metadata like thoughtSignature from being stored text_parts = [ - part.text - for part in event.content.parts + part.text + for part in event.content.parts if part.text and not part.thought ] return ' '.join(text_parts) - diff --git a/src/google/adk_community/sessions/redis_session_service.py b/src/google/adk_community/sessions/redis_session_service.py index bd8e289..075b342 100644 --- a/src/google/adk_community/sessions/redis_session_service.py +++ b/src/google/adk_community/sessions/redis_session_service.py @@ -17,23 +17,21 @@ import bisect import logging import time +from typing import Any +from typing import Optional import uuid -from typing import Any, Optional +from google.adk.events.event import Event +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.base_session_service import GetSessionConfig +from google.adk.sessions.base_session_service import ListSessionsResponse +from google.adk.sessions.session import Session +from google.adk.sessions.state import State import orjson import redis.asyncio as redis from redis.crc import key_slot from typing_extensions import override -from google.adk.events.event import Event -from google.adk.sessions.base_session_service import ( - BaseSessionService, - GetSessionConfig, - ListSessionsResponse, -) -from google.adk.sessions.session import Session -from google.adk.sessions.state import State - from .utils import _json_serializer logger = logging.getLogger("google_adk." 
def _session_serializer(obj: Session) -> bytes:
  """Serializes an ADK Session to JSON bytes via orjson with a fallback."""
  return orjson.dumps(obj.model_dump(), default=_json_serializer)


class RedisKeys:
  """Helper to generate Redis keys consistently."""

  @staticmethod
  def session(session_id: str) -> str:
    """Key holding one serialized session blob."""
    return f"session:{session_id}"

  @staticmethod
  def user_sessions(app_name: str, user_id: str) -> str:
    """Key of the set indexing a user's session ids.

    NOTE(review): built from APP_PREFIX with an extra colon; consistent as
    long as it is only ever produced by this helper, but a dedicated prefix
    would be clearer — confirm no other code constructs this key.
    """
    return f"{State.APP_PREFIX}:{app_name}:{user_id}"

  @staticmethod
  def app_state(app_name: str) -> str:
    """Key of the hash holding app-scoped shared state."""
    return f"{State.APP_PREFIX}{app_name}"

  @staticmethod
  def user_state(app_name: str, user_id: str) -> str:
    """Key of the hash holding user-scoped shared state."""
    return f"{State.USER_PREFIX}{app_name}:{user_id}"


class RedisSessionService(BaseSessionService):
  """A Redis-backed implementation of the session service.

  Sessions are stored as JSON blobs under ``session:<id>`` with a TTL and
  indexed per user in a Redis set. App- and user-prefixed state deltas are
  mirrored into shared hashes so every session observes them on load.
  """

  def __init__(
      self,
      host="localhost",
      port=6379,
      db=0,
      uri=None,
      cluster_uri=None,
      expire=DEFAULT_EXPIRATION,
      **kwargs,
  ):
    """Connects to a single node (host/port or uri) or a Redis Cluster.

    Args:
      host: Redis host, used when neither URI is given.
      port: Redis port, used when neither URI is given.
      db: Redis logical database, used when neither URI is given.
      uri: Optional single-node connection URI.
      cluster_uri: Optional Redis Cluster URI (takes precedence over uri).
      expire: TTL in seconds applied to session blobs and user indexes.
      **kwargs: Extra arguments forwarded to the redis client constructor.
    """
    self.expire = expire

    if cluster_uri:
      self.cache = redis.RedisCluster.from_url(cluster_uri, **kwargs)
    elif uri:
      self.cache = redis.Redis.from_url(uri, **kwargs)
    else:
      self.cache = redis.Redis(host=host, port=port, db=db, **kwargs)

  async def health_check(self) -> bool:
    """Returns True iff the Redis server answers a PING."""
    try:
      await self.cache.ping()
      return True
    except redis.RedisError:
      return False

  @override
  async def create_session(
      self,
      *,
      app_name: str,
      user_id: str,
      state: Optional[dict[str, Any]] = None,
      session_id: Optional[str] = None,
  ) -> Session:
    """Creates and persists a new session; generates an id if none given."""
    session_id = (
        session_id.strip()
        if session_id and session_id.strip()
        else str(uuid.uuid4())
    )
    session = Session(
        app_name=app_name,
        user_id=user_id,
        id=session_id,
        state=state or {},
        last_update_time=time.time(),
    )

    user_sessions_key = RedisKeys.user_sessions(app_name, user_id)
    session_key = RedisKeys.session(session_id)

    async with self.cache.pipeline(transaction=False) as pipe:
      pipe.sadd(user_sessions_key, session_id)
      pipe.expire(user_sessions_key, self.expire)
      pipe.set(
          session_key,
          _session_serializer(session),
          ex=self.expire,
      )
      await pipe.execute()

    return await self._merge_state(app_name, user_id, session)

  @override
  async def get_session(
      self,
      *,
      app_name: str,
      user_id: str,
      session_id: str,
      config: Optional[GetSessionConfig] = None,
  ) -> Optional[Session]:
    """Loads a session, applying optional event filtering from config.

    Returns None (pruning the stale index entry) when the blob has expired,
    and None when it cannot be decoded/validated.
    """
    session_key = RedisKeys.session(session_id)
    raw_session = await self.cache.get(session_key)
    if not raw_session:
      # Blob expired but the user index may still reference it; clean up.
      user_sessions_key = RedisKeys.user_sessions(app_name, user_id)
      await self.cache.srem(user_sessions_key, session_id)
      return None

    try:
      session_dict = orjson.loads(raw_session)
      session = Session.model_validate(session_dict)
    except Exception as e:
      # Broad on purpose: covers orjson.JSONDecodeError and pydantic
      # validation errors. The previous `(orjson.JSONDecodeError,
      # Exception)` tuple was redundant — the second entry subsumes the
      # first — and read as if the catch were narrow when it was not.
      logger.error("Error decoding session %s: %s", session_id, e)
      return None

    if config:
      if config.num_recent_events:
        session.events = session.events[-config.num_recent_events :]
      if config.after_timestamp:
        # Events are stored in timestamp order, so bisect locates the cutoff.
        timestamps = [e.timestamp for e in session.events]
        start_index = bisect.bisect_left(timestamps, config.after_timestamp)
        session.events = session.events[start_index:]

    return await self._merge_state(app_name, user_id, session)

  @override
  async def list_sessions(
      self, *, app_name: str, user_id: str
  ) -> ListSessionsResponse:
    """Lists a user's sessions stripped of events and state."""
    sessions = await self._load_sessions(app_name, user_id)
    sessions_without_events = []

    for session_data in sessions.values():
      session = Session.model_validate(session_data)
      session.events = []
      session.state = {}
      sessions_without_events.append(session)

    return ListSessionsResponse(sessions=sessions_without_events)

  @override
  async def delete_session(
      self, *, app_name: str, user_id: str, session_id: str
  ) -> None:
    """Removes the session blob and its entry in the user index."""
    user_sessions_key = RedisKeys.user_sessions(app_name, user_id)
    session_key = RedisKeys.session(session_id)

    async with self.cache.pipeline(transaction=False) as pipe:
      pipe.srem(user_sessions_key, session_id)
      pipe.delete(session_key)
      await pipe.execute()

  @override
  async def append_event(self, session: Session, event: Event) -> Event:
    """Appends an event, mirrors state deltas, and re-persists the session."""
    await super().append_event(session=session, event=event)
    session.last_update_time = event.timestamp

    async with self.cache.pipeline(transaction=False) as pipe:
      user_sessions_key = RedisKeys.user_sessions(
          session.app_name, session.user_id
      )
      pipe.expire(user_sessions_key, self.expire)

      if event.actions and event.actions.state_delta:
        for key, value in event.actions.state_delta.items():
          # App-/user-prefixed keys are mirrored into shared hashes so that
          # other sessions observe them via _merge_state.
          if key.startswith(State.APP_PREFIX):
            pipe.hset(
                RedisKeys.app_state(session.app_name),
                key.removeprefix(State.APP_PREFIX),
                orjson.dumps(value),
            )
          if key.startswith(State.USER_PREFIX):
            pipe.hset(
                RedisKeys.user_state(session.app_name, session.user_id),
                key.removeprefix(State.USER_PREFIX),
                orjson.dumps(value),
            )

      pipe.set(
          RedisKeys.session(session.id),
          _session_serializer(session),
          ex=self.expire,
      )
      await pipe.execute()

    return event

  async def _merge_state(
      self, app_name: str, user_id: str, session: Session
  ) -> Session:
    """Overlays shared app/user state hashes onto the session state dict."""
    app_state = await self.cache.hgetall(RedisKeys.app_state(app_name))
    for k, v in app_state.items():
      session.state[State.APP_PREFIX + k.decode()] = orjson.loads(v)

    user_state = await self.cache.hgetall(
        RedisKeys.user_state(app_name, user_id)
    )
    for k, v in user_state.items():
      session.state[State.USER_PREFIX + k.decode()] = orjson.loads(v)

    return session

  async def _load_sessions(
      self, app_name: str, user_id: str
  ) -> dict[str, dict]:
    """Fetches all raw session dicts for a user, pruning stale index ids.

    Session keys are grouped by cluster slot so each pipeline only touches
    keys on a single node (required for Redis Cluster pipelines).

    NOTE(review): uses asyncio.gather — assumes `import asyncio` exists at
    the top of this module (it is above the visible hunk); confirm.
    """
    key = RedisKeys.user_sessions(app_name, user_id)
    try:
      session_ids_bytes = await self.cache.smembers(key)
      if not session_ids_bytes:
        return {}

      session_ids = [s.decode() for s in session_ids_bytes]
      session_keys = [RedisKeys.session(sid) for sid in session_ids]

      # Group by slot for Redis Cluster
      slot_groups: dict[int, list[str]] = {}
      for k in session_keys:
        slot = key_slot(k.encode())
        slot_groups.setdefault(slot, []).append(k)

      async def fetch_group(keys: list[str]):
        async with self.cache.pipeline(transaction=False) as pipe:
          for k in keys:
            pipe.get(k)
          return await pipe.execute()

      results_per_group = await asyncio.gather(
          *(fetch_group(keys) for keys in slot_groups.values())
      )

      raw_sessions = []
      for group_keys, group_results in zip(
          slot_groups.values(), results_per_group
      ):
        raw_sessions.extend(zip(group_keys, group_results))

      sessions = {}
      sessions_to_cleanup = []
      for key_name, raw_session in raw_sessions:
        session_id = key_name.split(":", 1)[1]
        if raw_session:
          try:
            sessions[session_id] = orjson.loads(raw_session)
          except orjson.JSONDecodeError as e:
            logger.error("Error decoding session %s: %s", session_id, e)
        else:
          logger.warning(
              "Session ID %s found in user set but session data is missing."
              " Cleaning up.",
              session_id,
          )
          sessions_to_cleanup.append(session_id)

      if sessions_to_cleanup:
        await self.cache.srem(key, *sessions_to_cleanup)

      return sessions
    except redis.RedisError as e:
      logger.error("Error loading sessions for %s: %s", user_id, e)
      return {}
+# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Community tools for ADK.""" + +from .scratchpad_tool import scratchpad_append_log_tool +from .scratchpad_tool import scratchpad_get_log_tool +from .scratchpad_tool import scratchpad_get_tool +from .scratchpad_tool import scratchpad_set_tool +from .scratchpad_tool import ScratchpadAppendLogTool +from .scratchpad_tool import ScratchpadGetLogTool +from .scratchpad_tool import ScratchpadGetTool +from .scratchpad_tool import ScratchpadSetTool + +__all__ = [ + 'ScratchpadGetTool', + 'ScratchpadSetTool', + 'ScratchpadAppendLogTool', + 'ScratchpadGetLogTool', + 'scratchpad_get_tool', + 'scratchpad_set_tool', + 'scratchpad_append_log_tool', + 'scratchpad_get_log_tool', +] diff --git a/src/google/adk_community/tools/scratchpad_tool.py b/src/google/adk_community/tools/scratchpad_tool.py new file mode 100644 index 0000000..b7afebc --- /dev/null +++ b/src/google/adk_community/tools/scratchpad_tool.py @@ -0,0 +1,244 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Agent-callable tools for reading/writing the DatabaseMemoryService scratchpad.""" + +from __future__ import annotations + +from typing import Any +from typing import Optional + +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +from typing_extensions import override + + +def _get_db_memory_service(tool_context: ToolContext): + """Return the DatabaseMemoryService from the invocation context, or raise.""" + # pylint: disable=g-import-not-at-top + from google.adk_community.memory.database_memory_service import DatabaseMemoryService + + svc = tool_context._invocation_context.memory_service + if not isinstance(svc, DatabaseMemoryService): + raise ValueError( + "Scratchpad tools require the agent's memory_service to be a " + f'DatabaseMemoryService, got: {type(svc).__name__}' + ) + return svc + + +def _session_scope(tool_context: ToolContext) -> tuple[str, str, str]: + """Return (app_name, user_id, session_id) from the invocation context.""" + ic = tool_context._invocation_context + return ic.app_name, ic.session.user_id, ic.session.id + + +class ScratchpadGetTool(BaseTool): + """Read a value from the agent scratchpad KV store.""" + + def __init__(self): + super().__init__( + name='scratchpad_get', + description=( + 'Read a value stored in the scratchpad KV store by key.' + ' Returns null if the key does not exist.' 
+ ), + ) + + @override + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + 'key': types.Schema( + type=types.Type.STRING, + description='The key to read.', + ), + }, + required=['key'], + ), + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + svc = _get_db_memory_service(tool_context) + app_name, user_id, session_id = _session_scope(tool_context) + return await svc.get_scratchpad( + app_name=app_name, + user_id=user_id, + session_id=session_id, + key=args['key'], + ) + + +class ScratchpadSetTool(BaseTool): + """Write a value to the agent scratchpad KV store.""" + + def __init__(self): + super().__init__( + name='scratchpad_set', + description=( + 'Write a value to the scratchpad KV store. ' + 'Overwrites any existing value for the same key.' + ), + ) + + @override + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + 'key': types.Schema( + type=types.Type.STRING, + description='The key to write.', + ), + 'value': types.Schema( + description=( + 'The value to store (any JSON-serialisable type).' 
+ ), + ), + }, + required=['key', 'value'], + ), + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> str: + svc = _get_db_memory_service(tool_context) + app_name, user_id, session_id = _session_scope(tool_context) + await svc.set_scratchpad( + app_name=app_name, + user_id=user_id, + session_id=session_id, + key=args['key'], + value=args['value'], + ) + return 'ok' + + +class ScratchpadAppendLogTool(BaseTool): + """Append an observation or note to the agent scratchpad log.""" + + def __init__(self): + super().__init__( + name='scratchpad_append_log', + description=( + 'Append a text observation or note to the scratchpad log. ' + 'Entries are stored in insertion order and can be filtered by tag.' + ), + ) + + @override + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + 'content': types.Schema( + type=types.Type.STRING, + description='The text content to log.', + ), + 'tag': types.Schema( + type=types.Type.STRING, + description='Optional category label for filtering.', + ), + }, + required=['content'], + ), + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> str: + svc = _get_db_memory_service(tool_context) + app_name, user_id, session_id = _session_scope(tool_context) + await svc.append_log( + app_name=app_name, + user_id=user_id, + session_id=session_id, + content=args['content'], + tag=args.get('tag'), + agent_name=tool_context.agent_name, + ) + return 'ok' + + +class ScratchpadGetLogTool(BaseTool): + """Read recent entries from the agent scratchpad log.""" + + def __init__(self): + super().__init__( + name='scratchpad_get_log', + description=( + 'Read recent entries from the scratchpad log, ' + 'optionally filtered by tag.' 
+ ), + ) + + @override + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + 'tag': types.Schema( + type=types.Type.STRING, + description='Optional category label to filter by.', + ), + 'limit': types.Schema( + type=types.Type.INTEGER, + description=( + 'Maximum number of entries to return (default 50).' + ), + ), + }, + ), + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> list[dict]: + svc = _get_db_memory_service(tool_context) + app_name, user_id, session_id = _session_scope(tool_context) + return await svc.get_log( + app_name=app_name, + user_id=user_id, + session_id=session_id, + tag=args.get('tag'), + limit=int(args.get('limit', 50)), + ) + + +# Ready-to-use singleton instances +scratchpad_get_tool = ScratchpadGetTool() +scratchpad_set_tool = ScratchpadSetTool() +scratchpad_append_log_tool = ScratchpadAppendLogTool() +scratchpad_get_log_tool = ScratchpadGetLogTool() diff --git a/tests/unittests/memory/__init__.py b/tests/unittests/memory/__init__.py index 36a1e8d..0a2669d 100644 --- a/tests/unittests/memory/__init__.py +++ b/tests/unittests/memory/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - diff --git a/tests/unittests/memory/test_database_memory_service.py b/tests/unittests/memory/test_database_memory_service.py new file mode 100644 index 0000000..027908e --- /dev/null +++ b/tests/unittests/memory/test_database_memory_service.py @@ -0,0 +1,711 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for DatabaseMemoryService.""" + +from __future__ import annotations + +from collections.abc import Sequence +import time +from typing import Any +from unittest.mock import MagicMock + +from google.adk.events.event import Event +from google.adk.memory.base_memory_service import SearchMemoryResponse +from google.adk.memory.memory_entry import MemoryEntry +from google.adk.sessions.session import Session +from google.genai import types +import pytest +import pytest_asyncio + +from google.adk_community.memory.database_memory_service import DatabaseMemoryService +from google.adk_community.memory.memory_search_backend import MemorySearchBackend +from google.adk_community.memory.schemas.memory_schema import StorageMemoryEntry + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_DB_URL = 'sqlite+aiosqlite:///:memory:' +_APP = 'test_app' +_USER = 'user_1' +_SESSION = 'session_1' + + +def _make_content(text: str) -> types.Content: + return types.Content(role='user', parts=[types.Part(text=text)]) + + +def _make_event( + text: str, event_id: str = 'ev1', author: str = 'user' +) -> Event: + return Event( + id=event_id, + author=author, + content=_make_content(text), + timestamp=time.time(), + invocation_id='inv1', + ) + + +def _make_session(events: list[Event], session_id: str = _SESSION) -> Session: + return Session( + id=session_id, + app_name=_APP, + user_id=_USER, + events=events, + ) + + +@pytest.fixture +def svc() -> 
DatabaseMemoryService: + return DatabaseMemoryService(_DB_URL) + + +# --------------------------------------------------------------------------- +# 1. add_session_to_memory — filters empty events, persists content/author/ts +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_session_to_memory_persists_text_events(svc): + session = _make_session([_make_event('hello world')]) + await svc.add_session_to_memory(session) + + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query='hello') + assert len(resp.memories) == 1 + assert resp.memories[0].author == 'user' + assert resp.memories[0].timestamp is not None + + +@pytest.mark.asyncio +async def test_add_session_to_memory_skips_empty_events(svc): + empty_event = Event( + id='empty', + author='user', + content=types.Content(role='user', parts=[]), + timestamp=time.time(), + invocation_id='inv1', + ) + session = _make_session([empty_event]) + await svc.add_session_to_memory(session) + + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query='anything') + assert resp.memories == [] + + +# --------------------------------------------------------------------------- +# 2. Re-ingest same session → idempotent (no duplicates) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_session_to_memory_idempotent(svc): + session = _make_session([_make_event('idempotent test')]) + await svc.add_session_to_memory(session) + await svc.add_session_to_memory(session) + + resp = await svc.search_memory( + app_name=_APP, user_id=_USER, query='idempotent' + ) + assert len(resp.memories) == 1 + + +# --------------------------------------------------------------------------- +# 3. 
# ---------------------------------------------------------------------------
# 3. add_events_to_memory — delta, skips duplicate event_id
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_add_events_to_memory_delta(svc):
  event = _make_event('delta event', event_id='ev_delta')
  # Submitting the identical delta twice must not duplicate the memory.
  for _ in range(2):
    await svc.add_events_to_memory(
        app_name=_APP,
        user_id=_USER,
        events=[event],
        session_id=_SESSION,
    )

  response = await svc.search_memory(
      app_name=_APP, user_id=_USER, query='delta'
  )
  assert len(response.memories) == 1


@pytest.mark.asyncio
async def test_add_events_to_memory_skips_empty(svc):
  # A model event with no parts has nothing worth persisting.
  partless = Event(
      id='empty2',
      author='agent',
      content=types.Content(role='model', parts=[]),
      timestamp=time.time(),
      invocation_id='inv1',
  )
  await svc.add_events_to_memory(
      app_name=_APP, user_id=_USER, events=[partless], session_id=_SESSION
  )
  response = await svc.search_memory(
      app_name=_APP, user_id=_USER, query='anything'
  )
  assert response.memories == []
# ---------------------------------------------------------------------------
# 4. add_memory — direct MemoryEntry persist, auto-UUID
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_add_memory_direct(svc):
  fact = MemoryEntry(
      content=_make_content('direct memory fact'),
      author='system',
  )
  await svc.add_memory(app_name=_APP, user_id=_USER, memories=[fact])

  response = await svc.search_memory(
      app_name=_APP, user_id=_USER, query='direct'
  )
  assert len(response.memories) == 1
  hit = response.memories[0]
  assert hit.author == 'system'
  # No id was supplied, so the service must have generated one.
  assert hit.id is not None


@pytest.mark.asyncio
async def test_add_memory_preserves_explicit_id(svc):
  fact = MemoryEntry(
      id='explicit-id-123',
      content=_make_content('explicit id memory'),
  )
  await svc.add_memory(app_name=_APP, user_id=_USER, memories=[fact])
  response = await svc.search_memory(
      app_name=_APP, user_id=_USER, query='explicit'
  )
  assert response.memories[0].id == 'explicit-id-123'
# ---------------------------------------------------------------------------
# 5. search_memory — AND match, OR fallback, no results for empty query
# ---------------------------------------------------------------------------


async def _seed_memory(svc: DatabaseMemoryService, text: str) -> None:
  """Store a single memory containing *text* for the default app/user."""
  await svc.add_memory(
      app_name=_APP,
      user_id=_USER,
      memories=[MemoryEntry(content=_make_content(text))],
  )


@pytest.mark.asyncio
async def test_search_and_match(svc):
  await _seed_memory(svc, 'cats and dogs')
  # Both terms present → AND match succeeds.
  response = await svc.search_memory(
      app_name=_APP, user_id=_USER, query='cats dogs'
  )
  assert len(response.memories) == 1


@pytest.mark.asyncio
async def test_search_or_fallback(svc):
  await _seed_memory(svc, 'cats are great')
  # Only one of the two terms matches; the OR fallback should still hit.
  response = await svc.search_memory(
      app_name=_APP, user_id=_USER, query='cats fish'
  )
  assert len(response.memories) == 1


@pytest.mark.asyncio
async def test_search_empty_query_returns_empty(svc):
  await _seed_memory(svc, 'something')
  response = await svc.search_memory(app_name=_APP, user_id=_USER, query='')
  assert response.memories == []


@pytest.mark.asyncio
async def test_search_no_match(svc):
  await _seed_memory(svc, 'hello world')
  response = await svc.search_memory(
      app_name=_APP, user_id=_USER, query='zzznomatch'
  )
  assert response.memories == []
# ---------------------------------------------------------------------------
# 6. Scratchpad KV: set/get/overwrite/delete/list
# ---------------------------------------------------------------------------


async def _kv_set(svc: DatabaseMemoryService, key: str, value: Any) -> None:
  """Write *key* → *value* in the default session's scratchpad."""
  await svc.set_scratchpad(
      app_name=_APP, user_id=_USER, session_id=_SESSION, key=key, value=value
  )


async def _kv_get(svc: DatabaseMemoryService, key: str) -> Any:
  """Read *key* from the default session's scratchpad."""
  return await svc.get_scratchpad(
      app_name=_APP, user_id=_USER, session_id=_SESSION, key=key
  )


@pytest.mark.asyncio
async def test_scratchpad_kv_set_get(svc):
  await _kv_set(svc, 'k1', 'v1')
  assert await _kv_get(svc, 'k1') == 'v1'


@pytest.mark.asyncio
async def test_scratchpad_kv_overwrite(svc):
  await _kv_set(svc, 'k2', 'old')
  await _kv_set(svc, 'k2', 'new')
  assert await _kv_get(svc, 'k2') == 'new'


@pytest.mark.asyncio
async def test_scratchpad_kv_missing_returns_none(svc):
  assert await _kv_get(svc, 'nonexistent') is None


@pytest.mark.asyncio
async def test_scratchpad_kv_delete(svc):
  await _kv_set(svc, 'k3', 'v3')
  await svc.delete_scratchpad(
      app_name=_APP, user_id=_USER, session_id=_SESSION, key='k3'
  )
  assert await _kv_get(svc, 'k3') is None


@pytest.mark.asyncio
async def test_scratchpad_kv_list_keys(svc):
  for key in ('a', 'b', 'c'):
    await _kv_set(svc, key, key)
  keys = await svc.list_scratchpad_keys(
      app_name=_APP, user_id=_USER, session_id=_SESSION
  )
  assert set(keys) == {'a', 'b', 'c'}


@pytest.mark.asyncio
async def test_scratchpad_kv_json_types(svc):
  # Round-trip a nested JSON-serializable structure, not just strings.
  payload = {'nested': [1, 2, 3], 'flag': True}
  await _kv_set(svc, 'json_key', payload)
  assert await _kv_get(svc, 'json_key') == payload


# ---------------------------------------------------------------------------
# 7. Scratchpad log: append/get, filter by tag, limit
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_scratchpad_log_append_get(svc):
  for text in ('entry 1', 'entry 2'):
    await svc.append_log(
        app_name=_APP, user_id=_USER, session_id=_SESSION, content=text
    )
  entries = await svc.get_log(
      app_name=_APP, user_id=_USER, session_id=_SESSION
  )
  # Entries come back in append order.
  assert [e['content'] for e in entries] == ['entry 1', 'entry 2']


@pytest.mark.asyncio
async def test_scratchpad_log_filter_by_tag(svc):
  await svc.append_log(
      app_name=_APP,
      user_id=_USER,
      session_id=_SESSION,
      content='tagged',
      tag='mytag',
  )
  await svc.append_log(
      app_name=_APP, user_id=_USER, session_id=_SESSION, content='untagged'
  )
  tagged = await svc.get_log(
      app_name=_APP, user_id=_USER, session_id=_SESSION, tag='mytag'
  )
  assert [e['content'] for e in tagged] == ['tagged']


@pytest.mark.asyncio
async def test_scratchpad_log_limit(svc):
  for i in range(10):
    await svc.append_log(
        app_name=_APP,
        user_id=_USER,
        session_id=_SESSION,
        content=f'msg {i}',
    )
  entries = await svc.get_log(
      app_name=_APP, user_id=_USER, session_id=_SESSION, limit=3
  )
  assert len(entries) == 3
# ---------------------------------------------------------------------------
# 8. Custom search backend
# ---------------------------------------------------------------------------


class _AlwaysReturnOneBackend(MemorySearchBackend):
  """Stub backend that always returns a single hard-coded row."""

  async def search(
      self,
      *,
      sql_session,
      app_name,
      user_id,
      query,
      limit=10,
  ) -> Sequence[StorageMemoryEntry]:
    # Ignore the query entirely and hand back one canned storage row.
    return [
        StorageMemoryEntry(
            id='stub-id',
            app_name=app_name,
            user_id=user_id,
            content_json={'role': 'user', 'parts': [{'text': 'stub result'}]},
            author='stub',
            timestamp=None,
            custom_metadata={},
        )
    ]


@pytest.mark.asyncio
async def test_custom_search_backend():
  service = DatabaseMemoryService(
      _DB_URL, search_backend=_AlwaysReturnOneBackend()
  )
  response = await service.search_memory(
      app_name=_APP, user_id=_USER, query='anything'
  )
  assert len(response.memories) == 1
  assert response.memories[0].id == 'stub-id'
  assert response.memories[0].author == 'stub'


# ---------------------------------------------------------------------------
# 9. Engine construction errors raise ValueError
# ---------------------------------------------------------------------------


def test_bad_url_raises_value_error():
  with pytest.raises(ValueError, match='Invalid database URL'):
    DatabaseMemoryService('not_a_valid_url://')


def test_missing_driver_raises_value_error():
  with pytest.raises(ValueError):
    DatabaseMemoryService('sqlite+nonexistentdriver:///:memory:')
# ---------------------------------------------------------------------------
# 10. Multi-user isolation — user A results must not leak to user B
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_search_user_isolation(svc):
  await svc.add_memory(
      app_name=_APP,
      user_id='user_a',
      memories=[MemoryEntry(content=_make_content('secret data alpha'))],
  )
  response = await svc.search_memory(
      app_name=_APP, user_id='user_b', query='secret'
  )
  assert response.memories == [], "User B should not see user A's memories"


@pytest.mark.asyncio
async def test_add_session_user_isolation(svc):
  other_users_session = Session(
      id='sess_a',
      app_name=_APP,
      user_id='user_a',
      events=[_make_event('shared keyword')],
  )
  await svc.add_session_to_memory(other_users_session)

  response = await svc.search_memory(
      app_name=_APP, user_id='user_b', query='shared'
  )
  assert response.memories == []


# ---------------------------------------------------------------------------
# 11. Scratchpad KV scoping
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_scratchpad_kv_session_scoping(svc):
  await svc.set_scratchpad(
      app_name=_APP, user_id=_USER, session_id='s1', key='scoped', value='yes'
  )

  from_other_session = await svc.get_scratchpad(
      app_name=_APP, user_id=_USER, session_id='s2', key='scoped'
  )
  assert from_other_session is None, 'Key from s1 must not appear in s2'

  from_user_scope = await svc.get_scratchpad(
      app_name=_APP, user_id=_USER, session_id='', key='scoped'
  )
  assert (
      from_user_scope is None
  ), 'Key from s1 must not appear in user-level scope'


@pytest.mark.asyncio
async def test_scratchpad_log_session_scoping(svc):
  await svc.append_log(
      app_name=_APP,
      user_id=_USER,
      session_id='s1',
      content='session-one log',
  )
  entries = await svc.get_log(app_name=_APP, user_id=_USER, session_id='s2')
  assert entries == [], 'Log from s1 must not appear in s2'


# ---------------------------------------------------------------------------
# 12. add_memory with custom_metadata — verify merge
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_add_memory_custom_metadata_merge(svc):
  entry = MemoryEntry(
      content=_make_content('metadata test'),
      author='agent',
      custom_metadata={'entry_key': 'entry_val'},
  )
  await svc.add_memory(
      app_name=_APP,
      user_id=_USER,
      memories=[entry],
      custom_metadata={'call_key': 'call_val'},
  )
  response = await svc.search_memory(
      app_name=_APP, user_id=_USER, query='metadata'
  )
  assert len(response.memories) == 1
  merged = response.memories[0].custom_metadata
  # Both the per-entry and the per-call metadata must survive the merge.
  assert merged.get('entry_key') == 'entry_val'
  assert merged.get('call_key') == 'call_val'


# ---------------------------------------------------------------------------
# 13. delete_scratchpad no-op
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_scratchpad_delete_noop(svc):
  # Deleting a key that never existed must succeed silently.
  await svc.delete_scratchpad(
      app_name=_APP, user_id=_USER, session_id=_SESSION, key='ghost'
  )
  value = await svc.get_scratchpad(
      app_name=_APP, user_id=_USER, session_id=_SESSION, key='ghost'
  )
  assert value is None


# ---------------------------------------------------------------------------
# 14. list_scratchpad_keys on empty scope
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_scratchpad_list_keys_empty_scope(svc):
  keys = await svc.list_scratchpad_keys(
      app_name=_APP, user_id=_USER, session_id='brand_new_session'
  )
  assert keys == []


# ---------------------------------------------------------------------------
# 15. Scratchpad tool tests — all 4 BaseTool subclasses
# ---------------------------------------------------------------------------


def _make_tool_context(svc: DatabaseMemoryService, session_id: str = _SESSION):
  """Mock ToolContext wired to *svc* through a fake invocation context."""
  fake_session = MagicMock()
  fake_session.user_id = _USER
  fake_session.id = session_id

  fake_invocation = MagicMock()
  fake_invocation.app_name = _APP
  fake_invocation.session = fake_session
  fake_invocation.memory_service = svc

  tool_context = MagicMock()
  tool_context._invocation_context = fake_invocation
  tool_context.agent_name = 'test_agent'
  return tool_context


@pytest.mark.asyncio
async def test_scratchpad_set_tool_happy_path(svc):
  from google.adk_community.tools.scratchpad_tool import ScratchpadSetTool

  result = await ScratchpadSetTool().run_async(
      args={'key': 'tool_key', 'value': 'tool_value'},
      tool_context=_make_tool_context(svc),
  )
  assert result == 'ok'
  stored = await svc.get_scratchpad(
      app_name=_APP, user_id=_USER, session_id=_SESSION, key='tool_key'
  )
  assert stored == 'tool_value'


@pytest.mark.asyncio
async def test_scratchpad_get_tool_happy_path(svc):
  from google.adk_community.tools.scratchpad_tool import ScratchpadGetTool

  await svc.set_scratchpad(
      app_name=_APP, user_id=_USER, session_id=_SESSION, key='gt_key', value=42
  )
  value = await ScratchpadGetTool().run_async(
      args={'key': 'gt_key'}, tool_context=_make_tool_context(svc)
  )
  assert value == 42


@pytest.mark.asyncio
async def test_scratchpad_append_log_tool_happy_path(svc):
  from google.adk_community.tools.scratchpad_tool import ScratchpadAppendLogTool

  result = await ScratchpadAppendLogTool().run_async(
      args={'content': 'observation logged', 'tag': 'obs'},
      tool_context=_make_tool_context(svc),
  )
  assert result == 'ok'
  entries = await svc.get_log(
      app_name=_APP, user_id=_USER, session_id=_SESSION, tag='obs'
  )
  assert len(entries) == 1
  assert entries[0]['content'] == 'observation logged'
  # The tool records which agent wrote the entry.
  assert entries[0]['agent_name'] == 'test_agent'


@pytest.mark.asyncio
async def test_scratchpad_get_log_tool_happy_path(svc):
  from google.adk_community.tools.scratchpad_tool import ScratchpadGetLogTool

  for i in range(5):
    await svc.append_log(
        app_name=_APP,
        user_id=_USER,
        session_id=_SESSION,
        content=f'log {i}',
    )
  entries = await ScratchpadGetLogTool().run_async(
      args={'limit': 3}, tool_context=_make_tool_context(svc)
  )
  assert len(entries) == 3


# ---------------------------------------------------------------------------
# 15b. Wrong-service-type error for all 4 tools
# ---------------------------------------------------------------------------


def _make_wrong_service_context():
  """Mock ToolContext whose memory service is NOT a DatabaseMemoryService."""
  from google.adk.memory.in_memory_memory_service import InMemoryMemoryService

  fake_session = MagicMock()
  fake_session.user_id = _USER
  fake_session.id = _SESSION

  fake_invocation = MagicMock()
  fake_invocation.app_name = _APP
  fake_invocation.session = fake_session
  fake_invocation.memory_service = InMemoryMemoryService()

  tool_context = MagicMock()
  tool_context._invocation_context = fake_invocation
  tool_context.agent_name = 'test_agent'
  return tool_context


@pytest.mark.asyncio
async def test_scratchpad_get_tool_wrong_service():
  from google.adk_community.tools.scratchpad_tool import ScratchpadGetTool

  with pytest.raises(ValueError, match='DatabaseMemoryService'):
    await ScratchpadGetTool().run_async(
        args={'key': 'x'}, tool_context=_make_wrong_service_context()
    )


@pytest.mark.asyncio
async def test_scratchpad_set_tool_wrong_service():
  from google.adk_community.tools.scratchpad_tool import ScratchpadSetTool

  with pytest.raises(ValueError, match='DatabaseMemoryService'):
    await ScratchpadSetTool().run_async(
        args={'key': 'x', 'value': 1},
        tool_context=_make_wrong_service_context(),
    )


@pytest.mark.asyncio
async def test_scratchpad_append_log_tool_wrong_service():
  from google.adk_community.tools.scratchpad_tool import ScratchpadAppendLogTool

  with pytest.raises(ValueError, match='DatabaseMemoryService'):
    await ScratchpadAppendLogTool().run_async(
        args={'content': 'x'}, tool_context=_make_wrong_service_context()
    )
ScratchpadAppendLogTool + + tool = ScratchpadAppendLogTool() + with pytest.raises(ValueError, match='DatabaseMemoryService'): + await tool.run_async( + args={'content': 'x'}, tool_context=_make_wrong_service_context() + ) + + +@pytest.mark.asyncio +async def test_scratchpad_get_log_tool_wrong_service(): + from google.adk_community.tools.scratchpad_tool import ScratchpadGetLogTool + + tool = ScratchpadGetLogTool() + with pytest.raises(ValueError, match='DatabaseMemoryService'): + await tool.run_async(args={}, tool_context=_make_wrong_service_context()) diff --git a/tests/unittests/memory/test_open_memory_service.py b/tests/unittests/memory/test_open_memory_service.py index 74cb05e..718f6c6 100644 --- a/tests/unittests/memory/test_open_memory_service.py +++ b/tests/unittests/memory/test_open_memory_service.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock +from unittest.mock import MagicMock from unittest.mock import patch from google.adk.events.event import Event -from google.adk_community.memory.open_memory_service import ( - OpenMemoryService, - OpenMemoryServiceConfig, -) from google.adk.sessions.session import Session from google.genai import types import pytest +from google.adk_community.memory.open_memory_service import OpenMemoryService +from google.adk_community.memory.open_memory_service import OpenMemoryServiceConfig + MOCK_APP_NAME = 'test-app' MOCK_USER_ID = 'test-user' MOCK_SESSION_ID = 'session-1' @@ -39,7 +39,9 @@ invocation_id='inv-1', author='user', timestamp=12345, - content=types.Content(parts=[types.Part(text='Hello, I like Python.')]), + content=types.Content( + parts=[types.Part(text='Hello, I like Python.')] + ), ), Event( id='event-2', @@ -47,7 +49,9 @@ author='model', timestamp=12346, content=types.Content( - parts=[types.Part(text='Python is a great programming language.')] + 
parts=[ + types.Part(text='Python is a great programming language.') + ] ), ), # Empty event, should be ignored @@ -85,10 +89,12 @@ @pytest.fixture def mock_httpx_client(): """Mock httpx.AsyncClient for testing.""" - with patch('google.adk_community.memory.open_memory_service.httpx.AsyncClient') as mock_client_class: + with patch( + 'google.adk_community.memory.open_memory_service.httpx.AsyncClient' + ) as mock_client_class: mock_client = MagicMock() mock_response = MagicMock() - mock_response.json.return_value = {"matches": []} + mock_response.json.return_value = {'matches': []} mock_response.raise_for_status = MagicMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) @@ -107,14 +113,10 @@ def memory_service(mock_httpx_client): def memory_service_with_config(mock_httpx_client): """Create OpenMemoryService with custom config.""" config = OpenMemoryServiceConfig( - search_top_k=5, - user_content_salience=0.9, - model_content_salience=0.6 + search_top_k=5, user_content_salience=0.9, model_content_salience=0.6 ) return OpenMemoryService( - base_url='http://localhost:3000', - api_key='test-key', - config=config + base_url='http://localhost:3000', api_key='test-key', config=config ) @@ -139,7 +141,7 @@ def test_custom_config(self): user_content_salience=0.9, model_content_salience=0.75, default_salience=0.5, - enable_metadata_tags=False + enable_metadata_tags=False, ) assert config.search_top_k == 20 assert config.timeout == 10.0 @@ -158,18 +160,20 @@ def test_config_validation_search_top_k(self): def test_api_key_required(self): """Test that API key is required.""" - with pytest.raises(ValueError, match="api_key is required"): - OpenMemoryService(base_url="http://localhost:3000", api_key="") - - with pytest.raises(ValueError, match="api_key is required"): - OpenMemoryService(base_url="http://localhost:3000") + with pytest.raises(ValueError, match='api_key is required'): + 
OpenMemoryService(base_url='http://localhost:3000', api_key='') + + with pytest.raises(ValueError, match='api_key is required'): + OpenMemoryService(base_url='http://localhost:3000') class TestOpenMemoryService: """Tests for OpenMemoryService.""" @pytest.mark.asyncio - async def test_add_session_to_memory_success(self, memory_service, mock_httpx_client): + async def test_add_session_to_memory_success( + self, memory_service, mock_httpx_client + ): """Test successful addition of session memories.""" await memory_service.add_session_to_memory(MOCK_SESSION) @@ -220,9 +224,7 @@ async def test_add_session_uses_config_salience( assert request_data['salience'] == 0.6 # Custom model salience @pytest.mark.asyncio - async def test_add_session_without_metadata_tags( - self, mock_httpx_client - ): + async def test_add_session_without_metadata_tags(self, mock_httpx_client): """Test adding memories without metadata tags.""" config = OpenMemoryServiceConfig(enable_metadata_tags=False) memory_service = OpenMemoryService( @@ -236,7 +238,9 @@ async def test_add_session_without_metadata_tags( assert request_data.get('tags', []) == [] @pytest.mark.asyncio - async def test_add_session_error_handling(self, memory_service, mock_httpx_client): + async def test_add_session_error_handling( + self, memory_service, mock_httpx_client + ): """Test error handling during memory addition.""" mock_httpx_client.post.side_effect = Exception('API Error') @@ -254,20 +258,23 @@ async def test_search_memory_success(self, memory_service, mock_httpx_client): mock_response.json.return_value = { 'matches': [ { - 'content': '[Author: user, Time: 2025-01-01T00:00:00] Python is great', + 'content': ( + '[Author: user, Time: 2025-01-01T00:00:00] Python is great' + ), }, { - 'content': '[Author: model, Time: 2025-01-01T00:01:00] I like programming', - } + 'content': ( + '[Author: model, Time: 2025-01-01T00:01:00] I like' + ' programming' + ), + }, ] } mock_response.raise_for_status = MagicMock() 
mock_httpx_client.post = AsyncMock(return_value=mock_response) result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='Python programming' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='Python programming' ) # Verify API call @@ -276,7 +283,7 @@ async def test_search_memory_success(self, memory_service, mock_httpx_client): assert request_data['query'] == 'Python programming' assert request_data['k'] == 10 assert request_data['filter']['user_id'] == MOCK_USER_ID - assert f"app:{MOCK_APP_NAME}" in request_data['filter']['tags'] + assert f'app:{MOCK_APP_NAME}' in request_data['filter']['tags'] # Verify results (content should be cleaned of metadata prefix) assert len(result.memories) == 2 @@ -293,26 +300,24 @@ async def test_search_memory_applies_filters( # Mock response - server-side filtering ensures only matching results mock_response = MagicMock() mock_response.json.return_value = { - 'matches': [ - { - 'content': '[Author: model, Time: 2025-01-01T00:01:00] I like programming', - } - ] + 'matches': [{ + 'content': ( + '[Author: model, Time: 2025-01-01T00:01:00] I like programming' + ), + }] } mock_response.raise_for_status = MagicMock() mock_httpx_client.post = AsyncMock(return_value=mock_response) result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='test query' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='test query' ) # Verify filters were passed correctly call_args = mock_httpx_client.post.call_args request_data = call_args.kwargs['json'] assert request_data['filter']['user_id'] == MOCK_USER_ID - assert f"app:{MOCK_APP_NAME}" in request_data['filter']['tags'] + assert f'app:{MOCK_APP_NAME}' in request_data['filter']['tags'] # Should return filtered results assert len(result.memories) == 1 @@ -329,9 +334,7 @@ async def test_search_memory_respects_top_k( mock_httpx_client.post = AsyncMock(return_value=mock_response) await 
memory_service_with_config.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='test query' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='test query' ) call_args = mock_httpx_client.post.call_args @@ -346,11 +349,8 @@ async def test_search_memory_error_handling( mock_httpx_client.post.side_effect = Exception('API Error') result = await memory_service.search_memory( - app_name=MOCK_APP_NAME, - user_id=MOCK_USER_ID, - query='test query' + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='test query' ) # Should return empty results on error assert len(result.memories) == 0 - diff --git a/tests/unittests/sessions/test_redis_session_service.py b/tests/unittests/sessions/test_redis_session_service.py index dad0867..2354561 100644 --- a/tests/unittests/sessions/test_redis_session_service.py +++ b/tests/unittests/sessions/test_redis_session_service.py @@ -12,549 +12,568 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import orjson -from datetime import datetime, timezone -import pytest -import pytest_asyncio -from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime +from datetime import timezone +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.sessions.base_session_service import GetSessionConfig -from google.adk_community.sessions.redis_session_service import RedisSessionService from google.genai import types +import orjson +import pytest +import pytest_asyncio +from google.adk_community.sessions.redis_session_service import RedisSessionService -class TestRedisSessionService: - """Test cases for RedisSessionService.""" - - @pytest_asyncio.fixture - async def redis_service(self): - """Create a Redis session service for testing.""" - with patch("redis.asyncio.Redis") as mock_redis: - mock_client = AsyncMock() - mock_redis.return_value = mock_client - service = RedisSessionService() - service.cache = mock_client - yield service - - @pytest_asyncio.fixture - async def redis_cluster_service(self): - """Create a Redis cluster session service for testing.""" - with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: - mock_client = AsyncMock() - mock_redis_cluster.return_value = mock_client - cluster_uri = "redis://redis-node1:6379" - service = RedisSessionService(cluster_uri=cluster_uri) - service.cache = mock_client - yield service - - @pytest_asyncio.fixture - async def redis_cluster_uri_service(self): - """Create a Redis cluster session service using URI for testing.""" - with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: - mock_client = AsyncMock() - mock_redis_cluster.return_value = mock_client - cluster_uri = "redis://node1:6379,node2:6379" - service = RedisSessionService(cluster_uri=cluster_uri) - service.cache = mock_client - yield service 
- - def _setup_redis_mocks(self, redis_service, sessions_data=None): - """Helper to set up Redis mocks for the new storage strategy.""" - if sessions_data is None: - sessions_data = {} - - session_ids = list(sessions_data.keys()) - redis_service.cache.smembers = AsyncMock( - return_value={sid.encode() for sid in session_ids} - ) - - # Mock the new cluster-aware pipeline approach - session_values = [ - orjson.dumps(sessions_data[sid]) if sid in sessions_data else None - for sid in session_ids - ] - # For backward compatibility with mget approach (still used in some tests) - redis_service.cache.mget = AsyncMock(return_value=session_values) - - # Mock pipeline for the new cluster approach - if session_ids: - # Group sessions as the actual implementation does - results_per_group = [] - for i in range(len(session_ids)): - results_per_group.append([session_values[i]]) - - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock(side_effect=results_per_group) - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) - else: - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock(return_value=[]) - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) - - redis_service.cache.srem = AsyncMock() - redis_service.cache.get = AsyncMock(return_value=None) # Default to no session - - # Additional pipeline operations for create/update operations - if not session_ids: - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.set = MagicMock(return_value=mock_pipe) # Allow chaining - 
mock_pipe.sadd = MagicMock(return_value=mock_pipe) - mock_pipe.expire = MagicMock(return_value=mock_pipe) - mock_pipe.delete = MagicMock(return_value=mock_pipe) - mock_pipe.srem = MagicMock(return_value=mock_pipe) - mock_pipe.hset = MagicMock(return_value=mock_pipe) - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock(return_value=[]) - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_service.cache.pipeline = MagicMock(return_value=mock_context_manager) - - redis_service.cache.hgetall = AsyncMock(return_value={}) - redis_service.cache.hset = AsyncMock() - - @pytest.mark.asyncio - async def test_get_empty_session(self, redis_service): - """Test getting a non-existent session.""" - self._setup_redis_mocks(redis_service) - - session = await redis_service.get_session( - app_name="test_app", user_id="test_user", session_id="nonexistent" - ) - - assert session is None - - @pytest.mark.asyncio - async def test_create_get_session(self, redis_service): - """Test session creation and retrieval.""" - app_name = "test_app" - user_id = "test_user" - state = {"key": "value"} - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session( - app_name=app_name, user_id=user_id, state=state - ) - - assert session.app_name == app_name - assert session.user_id == user_id - assert session.id is not None - assert session.state == state - - # Allow tiny float/clock rounding differences (~1ms) - assert ( - session.last_update_time - <= datetime.now().astimezone(timezone.utc).timestamp() + 0.001 - ) - - # Mock individual session retrieval - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - got_session = await redis_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - - assert got_session.app_name == session.app_name - assert got_session.user_id == session.user_id 
- assert got_session.id == session.id - assert got_session.state == session.state - - @pytest.mark.asyncio - async def test_create_and_list_sessions(self, redis_service): - """Test creating multiple sessions and listing them. - - list_sessions() is expected to return lightweight session summaries, - i.e., with events and state stripped for performance. - """ - app_name = "test_app" - user_id = "test_user" - - self._setup_redis_mocks(redis_service) - - session_ids = ["session" + str(i) for i in range(3)] - sessions_data = {} - - for i, session_id in enumerate(session_ids): - session = await redis_service.create_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - state={"key": "value" + session_id}, - ) - # Add at least one event to ensure list_sessions actually strips them. - session.events.append(Event(author="user", timestamp=float(i + 1))) - sessions_data[session_id] = session.model_dump() - - # Now mock Redis to return those sessions (with events present in storage) - self._setup_redis_mocks(redis_service, sessions_data) - - list_sessions_response = await redis_service.list_sessions( - app_name=app_name, user_id=user_id - ) - sessions = list_sessions_response.sessions - - assert len(sessions) == len(session_ids) - returned_session_ids = {s.id for s in sessions} - assert returned_session_ids == set(session_ids) - - for s in sessions: - # list_sessions returns summaries: events and state removed for perf. 
- assert len(s.events) == 0 - assert s.state == {} - - @pytest.mark.asyncio - async def test_session_state_management(self, redis_service): - """Test session state management with app, user, and temp state.""" - app_name = "test_app" - user_id = "test_user" - session_id = "test_session" - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - state={"initial_key": "initial_value"}, - ) - - event = Event( - invocation_id="invocation", - author="user", - content=types.Content(role="user", parts=[types.Part(text="text")]), - actions=EventActions( - state_delta={ - "app:key": "app_value", - "user:key1": "user_value", - "temp:key": "temp_value", - "initial_key": "updated_value", - } +class TestRedisSessionService: + """Test cases for RedisSessionService.""" + + @pytest_asyncio.fixture + async def redis_service(self): + """Create a Redis session service for testing.""" + with patch("redis.asyncio.Redis") as mock_redis: + mock_client = AsyncMock() + mock_redis.return_value = mock_client + service = RedisSessionService() + service.cache = mock_client + yield service + + @pytest_asyncio.fixture + async def redis_cluster_service(self): + """Create a Redis cluster session service for testing.""" + with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + cluster_uri = "redis://redis-node1:6379" + service = RedisSessionService(cluster_uri=cluster_uri) + service.cache = mock_client + yield service + + @pytest_asyncio.fixture + async def redis_cluster_uri_service(self): + """Create a Redis cluster session service using URI for testing.""" + with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + cluster_uri = "redis://node1:6379,node2:6379" + service = 
RedisSessionService(cluster_uri=cluster_uri) + service.cache = mock_client + yield service + + def _setup_redis_mocks(self, redis_service, sessions_data=None): + """Helper to set up Redis mocks for the new storage strategy.""" + if sessions_data is None: + sessions_data = {} + + session_ids = list(sessions_data.keys()) + redis_service.cache.smembers = AsyncMock( + return_value={sid.encode() for sid in session_ids} + ) + + # Mock the new cluster-aware pipeline approach + session_values = [ + orjson.dumps(sessions_data[sid]) if sid in sessions_data else None + for sid in session_ids + ] + + # For backward compatibility with mget approach (still used in some tests) + redis_service.cache.mget = AsyncMock(return_value=session_values) + + # Mock pipeline for the new cluster approach + if session_ids: + # Group sessions as the actual implementation does + results_per_group = [] + for i in range(len(session_ids)): + results_per_group.append([session_values[i]]) + + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(side_effect=results_per_group) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + else: + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(return_value=[]) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + + redis_service.cache.srem = AsyncMock() + redis_service.cache.get = AsyncMock( + return_value=None + ) # Default to no session + + # Additional pipeline operations for create/update operations + if not session_ids: + mock_context_manager = 
MagicMock() + mock_pipe = MagicMock() + mock_pipe.set = MagicMock(return_value=mock_pipe) # Allow chaining + mock_pipe.sadd = MagicMock(return_value=mock_pipe) + mock_pipe.expire = MagicMock(return_value=mock_pipe) + mock_pipe.delete = MagicMock(return_value=mock_pipe) + mock_pipe.srem = MagicMock(return_value=mock_pipe) + mock_pipe.hset = MagicMock(return_value=mock_pipe) + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock(return_value=[]) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + + redis_service.cache.hgetall = AsyncMock(return_value={}) + redis_service.cache.hset = AsyncMock() + + @pytest.mark.asyncio + async def test_get_empty_session(self, redis_service): + """Test getting a non-existent session.""" + self._setup_redis_mocks(redis_service) + + session = await redis_service.get_session( + app_name="test_app", user_id="test_user", session_id="nonexistent" + ) + + assert session is None + + @pytest.mark.asyncio + async def test_create_get_session(self, redis_service): + """Test session creation and retrieval.""" + app_name = "test_app" + user_id = "test_user" + state = {"key": "value"} + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, user_id=user_id, state=state + ) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id is not None + assert session.state == state + + # Allow tiny float/clock rounding differences (~1ms) + assert ( + session.last_update_time + <= datetime.now().astimezone(timezone.utc).timestamp() + 0.001 + ) + + # Mock individual session retrieval + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + got_session = await redis_service.get_session( + app_name=app_name, user_id=user_id, 
session_id=session.id + ) + + assert got_session.app_name == session.app_name + assert got_session.user_id == session.user_id + assert got_session.id == session.id + assert got_session.state == session.state + + @pytest.mark.asyncio + async def test_create_and_list_sessions(self, redis_service): + """Test creating multiple sessions and listing them. + + list_sessions() is expected to return lightweight session summaries, + i.e., with events and state stripped for performance. + """ + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session_ids = ["session" + str(i) for i in range(3)] + sessions_data = {} + + for i, session_id in enumerate(session_ids): + session = await redis_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state={"key": "value" + session_id}, + ) + # Add at least one event to ensure list_sessions actually strips them. + session.events.append(Event(author="user", timestamp=float(i + 1))) + sessions_data[session_id] = session.model_dump() + + # Now mock Redis to return those sessions (with events present in storage) + self._setup_redis_mocks(redis_service, sessions_data) + + list_sessions_response = await redis_service.list_sessions( + app_name=app_name, user_id=user_id + ) + sessions = list_sessions_response.sessions + + assert len(sessions) == len(session_ids) + returned_session_ids = {s.id for s in sessions} + assert returned_session_ids == set(session_ids) + + for s in sessions: + # list_sessions returns summaries: events and state removed for perf. 
+ assert len(s.events) == 0 + assert s.state == {} + + @pytest.mark.asyncio + async def test_session_state_management(self, redis_service): + """Test session state management with app, user, and temp state.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state={"initial_key": "initial_value"}, + ) + + event = Event( + invocation_id="invocation", + author="user", + content=types.Content(role="user", parts=[types.Part(text="text")]), + actions=EventActions( + state_delta={ + "app:key": "app_value", + "user:key1": "user_value", + "temp:key": "temp_value", + "initial_key": "updated_value", + } + ), + ) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + await redis_service.append_event(session=session, event=event) + + assert session.state.get("app:key") == "app_value" + assert session.state.get("user:key1") == "user_value" + assert session.state.get("initial_key") == "updated_value" + assert session.state.get("temp:key") is None # Temp state filtered + + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await pipeline_mock.__aenter__() + pipe_mock.hset.assert_any_call( + "app:test_app", "key", orjson.dumps("app_value") + ) + pipe_mock.hset.assert_any_call( + "user:test_app:test_user", "key1", orjson.dumps("user_value") + ) + + @pytest.mark.asyncio + async def test_append_event_with_bytes(self, redis_service): + """Test appending events with binary content and serialization roundtrip.""" + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, user_id=user_id + ) + + test_content = types.Content( + role="user", + parts=[ + types.Part.from_bytes( + data=b"test_image_data", mime_type="image/png" ), - ) - - 
redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - await redis_service.append_event(session=session, event=event) - - assert session.state.get("app:key") == "app_value" - assert session.state.get("user:key1") == "user_value" - assert session.state.get("initial_key") == "updated_value" - assert session.state.get("temp:key") is None # Temp state filtered - - pipeline_mock = redis_service.cache.pipeline.return_value - pipe_mock = await pipeline_mock.__aenter__() - pipe_mock.hset.assert_any_call("app:test_app", "key", orjson.dumps("app_value")) - pipe_mock.hset.assert_any_call( - "user:test_app:test_user", "key1", orjson.dumps("user_value") - ) - - @pytest.mark.asyncio - async def test_append_event_with_bytes(self, redis_service): - """Test appending events with binary content and serialization roundtrip.""" - app_name = "test_app" - user_id = "test_user" - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session(app_name=app_name, user_id=user_id) - - test_content = types.Content( - role="user", - parts=[ - types.Part.from_bytes( - data=b"test_image_data", mime_type="image/png" - ), - ], - ) - test_grounding_metadata = types.GroundingMetadata( - search_entry_point=types.SearchEntryPoint(sdk_blob=b"test_sdk_blob") - ) - event = Event( - invocation_id="invocation", - author="user", - content=test_content, - grounding_metadata=test_grounding_metadata, - ) - - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - await redis_service.append_event(session=session, event=event) - - # Verify the event was appended to in-memory session - assert len(session.events) == 1 - assert session.events[0].content == test_content - assert session.events[0].grounding_metadata == test_grounding_metadata - - # Test serialization/deserialization roundtrip to ensure binary data is preserved - # Simulate what happens when session is stored and retrieved from Redis - 
serialized_session = session.model_dump_json() - - redis_service.cache.get = AsyncMock(return_value=serialized_session.encode()) - - retrieved_session = await redis_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - - assert retrieved_session is not None - assert len(retrieved_session.events) == 1 - - # Verify the binary content was preserved through serialization - retrieved_event = retrieved_session.events[0] - assert retrieved_event.content.parts[0].inline_data.data == b"test_image_data" - assert ( - retrieved_event.content.parts[0].inline_data.mime_type - == "image/png" - ) - assert ( - retrieved_event.grounding_metadata.search_entry_point.sdk_blob - == b"test_sdk_blob" - ) - - @pytest.mark.asyncio - async def test_get_session_with_config(self, redis_service): - """Test getting session with configuration filters.""" - app_name = "test_app" - user_id = "test_user" - - self._setup_redis_mocks(redis_service) - - session = await redis_service.create_session(app_name=app_name, user_id=user_id) - - # Add multiple events with different timestamps - num_test_events = 5 - for i in range(1, num_test_events + 1): - event = Event(author="user", timestamp=float(i)) - session.events.append(event) - - redis_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - # Test num_recent_events filter - config = GetSessionConfig(num_recent_events=3) - filtered_session = await redis_service.get_session( - app_name=app_name, - user_id=user_id, - session_id=session.id, - config=config, - ) - - assert len(filtered_session.events) == 3 - assert filtered_session.events[0].timestamp == 3.0 # Last 3 events - - # Test after_timestamp filter - config = GetSessionConfig(after_timestamp=3.0) - filtered_session = await redis_service.get_session( - app_name=app_name, - user_id=user_id, - session_id=session.id, - config=config, - ) - - assert len(filtered_session.events) == 3 # Events 3, 4, 5 - assert 
filtered_session.events[0].timestamp == 3.0 - - @pytest.mark.asyncio - async def test_delete_session(self, redis_service): - """Test session deletion.""" - app_name = "test_app" - user_id = "test_user" - session_id = "test_session" - - self._setup_redis_mocks(redis_service) # Empty sessions - await redis_service.delete_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - ) - pipeline_mock = redis_service.cache.pipeline.return_value - pipe_mock = await pipeline_mock.__aenter__() - pipe_mock.execute.assert_called() - - redis_service.cache.pipeline.reset_mock() - self._setup_redis_mocks(redis_service) - - await redis_service.delete_session( - app_name=app_name, - user_id=user_id, - session_id=session_id, - ) - - pipeline_mock = redis_service.cache.pipeline.return_value - pipe_mock = await pipeline_mock.__aenter__() - pipe_mock.execute.assert_called() - - @pytest.mark.asyncio - async def test_cluster_health_check(self, redis_cluster_service): - """Test health check for Redis cluster.""" - redis_cluster_service.cache.ping = AsyncMock(return_value=True) - - result = await redis_cluster_service.health_check() - assert result is True - redis_cluster_service.cache.ping.assert_called_once() - - @pytest.mark.asyncio - async def test_cluster_health_check_failure(self, redis_cluster_service): - """Test health check failure for Redis cluster.""" - from redis import RedisError - - redis_cluster_service.cache.ping = AsyncMock( - side_effect=RedisError("Connection failed") - ) - - result = await redis_cluster_service.health_check() - assert result is False - - @pytest.mark.asyncio - async def test_cluster_create_and_get_session(self, redis_cluster_service): - """Test session creation and retrieval in cluster mode.""" - app_name = "cluster_test_app" - user_id = "cluster_test_user" - state = {"cluster_key": "cluster_value"} - - self._setup_redis_mocks(redis_cluster_service) - - session = await redis_cluster_service.create_session( - app_name=app_name, 
user_id=user_id, state=state - ) - - assert session.app_name == app_name - assert session.user_id == user_id - assert session.id is not None - assert session.state == state - - # Mock individual session retrieval - redis_cluster_service.cache.get = AsyncMock( - return_value=session.model_dump_json().encode() - ) - - got_session = await redis_cluster_service.get_session( - app_name=app_name, user_id=user_id, session_id=session.id - ) - - assert got_session.app_name == session.app_name - assert got_session.user_id == session.user_id - assert got_session.id == session.id - assert got_session.state == session.state - - @pytest.mark.asyncio - async def test_cluster_uri_initialization(self, redis_cluster_uri_service): - """Test Redis cluster initialization with URI.""" - assert redis_cluster_uri_service.cache is not None - - @pytest.mark.asyncio - async def test_cluster_error_handling(self, redis_cluster_service): - """Test error handling in cluster operations.""" - from redis import RedisError - - app_name = "test_app" - user_id = "test_user" - - # Mock Redis error during session loading - redis_cluster_service.cache.smembers = AsyncMock( - side_effect=RedisError("Cluster error") - ) - - sessions_response = await redis_cluster_service.list_sessions( - app_name=app_name, user_id=user_id - ) - - assert len(sessions_response.sessions) == 0 - - @pytest.mark.asyncio - async def test_cluster_connection_validation(self): - """Test cluster connection validation during initialization.""" - cluster_uri = "redis://redis-node1:6379" - - with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: - mock_client = AsyncMock() - mock_redis_cluster.return_value = mock_client - - service = RedisSessionService(cluster_uri=cluster_uri) - assert service.cache is not None - mock_redis_cluster.assert_called_once() - - @pytest.mark.asyncio - async def test_cluster_session_cleanup_on_error(self, redis_cluster_service): - """Test session cleanup when corrupted data is found in 
cluster.""" - app_name = "test_app" - user_id = "test_user" - - # Setup mock with corrupted session data - valid_session_data = { - "app_name": "test_app", - "user_id": "test_user", - "id": "session1", - "state": {}, - "events": [], - "last_update_time": 1234567890, - } - redis_cluster_service.cache.smembers = AsyncMock( - return_value={b"session1", b"session2"} - ) - - # Mock the pipeline for cluster approach - mock_context_manager = MagicMock() - mock_pipe = MagicMock() - mock_pipe.get = MagicMock(return_value=mock_pipe) - mock_pipe.execute = AsyncMock( - side_effect=[ - [orjson.dumps(valid_session_data)], # session1 result - [None], # session2 result (missing) - ] - ) - mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) - mock_context_manager.__aexit__ = AsyncMock(return_value=None) - redis_cluster_service.cache.pipeline = MagicMock( - return_value=mock_context_manager - ) - redis_cluster_service.cache.srem = AsyncMock() - redis_cluster_service.cache.hgetall = AsyncMock(return_value={}) - - sessions_response = await redis_cluster_service.list_sessions( - app_name=app_name, user_id=user_id - ) - - redis_cluster_service.cache.srem.assert_called() - assert len(sessions_response.sessions) == 1 - - @pytest.mark.asyncio - async def test_decode_responses_handling(self, redis_service): - """Test proper handling of decode_responses setting.""" - app_name = "test_app" - user_id = "test_user" - session_id = "test_session" - - # Test with bytes response (decode_responses=False) - session_data = ( - '{"app_name": "test_app", "user_id": "test_user", "id": "test_session", ' - '"state": {}, "events": [], "last_update_time": 1234567890}' - ) - redis_service.cache.get = AsyncMock(return_value=session_data.encode()) - redis_service.cache.hgetall = AsyncMock(return_value={}) - - session = await redis_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - - assert session is not None - assert session.app_name == app_name - assert 
session.user_id == user_id + ], + ) + test_grounding_metadata = types.GroundingMetadata( + search_entry_point=types.SearchEntryPoint(sdk_blob=b"test_sdk_blob") + ) + event = Event( + invocation_id="invocation", + author="user", + content=test_content, + grounding_metadata=test_grounding_metadata, + ) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + await redis_service.append_event(session=session, event=event) + + # Verify the event was appended to in-memory session + assert len(session.events) == 1 + assert session.events[0].content == test_content + assert session.events[0].grounding_metadata == test_grounding_metadata + + # Test serialization/deserialization roundtrip to ensure binary data is preserved + # Simulate what happens when session is stored and retrieved from Redis + serialized_session = session.model_dump_json() + + redis_service.cache.get = AsyncMock( + return_value=serialized_session.encode() + ) + + retrieved_session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert retrieved_session is not None + assert len(retrieved_session.events) == 1 + + # Verify the binary content was preserved through serialization + retrieved_event = retrieved_session.events[0] + assert ( + retrieved_event.content.parts[0].inline_data.data == b"test_image_data" + ) + assert retrieved_event.content.parts[0].inline_data.mime_type == "image/png" + assert ( + retrieved_event.grounding_metadata.search_entry_point.sdk_blob + == b"test_sdk_blob" + ) + + @pytest.mark.asyncio + async def test_get_session_with_config(self, redis_service): + """Test getting session with configuration filters.""" + app_name = "test_app" + user_id = "test_user" + + self._setup_redis_mocks(redis_service) + + session = await redis_service.create_session( + app_name=app_name, user_id=user_id + ) + + # Add multiple events with different timestamps + num_test_events = 5 + for i in range(1, 
num_test_events + 1): + event = Event(author="user", timestamp=float(i)) + session.events.append(event) + + redis_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + # Test num_recent_events filter + config = GetSessionConfig(num_recent_events=3) + filtered_session = await redis_service.get_session( + app_name=app_name, + user_id=user_id, + session_id=session.id, + config=config, + ) + + assert len(filtered_session.events) == 3 + assert filtered_session.events[0].timestamp == 3.0 # Last 3 events + + # Test after_timestamp filter + config = GetSessionConfig(after_timestamp=3.0) + filtered_session = await redis_service.get_session( + app_name=app_name, + user_id=user_id, + session_id=session.id, + config=config, + ) + + assert len(filtered_session.events) == 3 # Events 3, 4, 5 + assert filtered_session.events[0].timestamp == 3.0 + + @pytest.mark.asyncio + async def test_delete_session(self, redis_service): + """Test session deletion.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + self._setup_redis_mocks(redis_service) # Empty sessions + await redis_service.delete_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + ) + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await pipeline_mock.__aenter__() + pipe_mock.execute.assert_called() + + redis_service.cache.pipeline.reset_mock() + self._setup_redis_mocks(redis_service) + + await redis_service.delete_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + ) + + pipeline_mock = redis_service.cache.pipeline.return_value + pipe_mock = await pipeline_mock.__aenter__() + pipe_mock.execute.assert_called() + + @pytest.mark.asyncio + async def test_cluster_health_check(self, redis_cluster_service): + """Test health check for Redis cluster.""" + redis_cluster_service.cache.ping = AsyncMock(return_value=True) + + result = await redis_cluster_service.health_check() + assert result is True 
+ redis_cluster_service.cache.ping.assert_called_once() + + @pytest.mark.asyncio + async def test_cluster_health_check_failure(self, redis_cluster_service): + """Test health check failure for Redis cluster.""" + from redis import RedisError + + redis_cluster_service.cache.ping = AsyncMock( + side_effect=RedisError("Connection failed") + ) + + result = await redis_cluster_service.health_check() + assert result is False + + @pytest.mark.asyncio + async def test_cluster_create_and_get_session(self, redis_cluster_service): + """Test session creation and retrieval in cluster mode.""" + app_name = "cluster_test_app" + user_id = "cluster_test_user" + state = {"cluster_key": "cluster_value"} + + self._setup_redis_mocks(redis_cluster_service) + + session = await redis_cluster_service.create_session( + app_name=app_name, user_id=user_id, state=state + ) + + assert session.app_name == app_name + assert session.user_id == user_id + assert session.id is not None + assert session.state == state + + # Mock individual session retrieval + redis_cluster_service.cache.get = AsyncMock( + return_value=session.model_dump_json().encode() + ) + + got_session = await redis_cluster_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + + assert got_session.app_name == session.app_name + assert got_session.user_id == session.user_id + assert got_session.id == session.id + assert got_session.state == session.state + + @pytest.mark.asyncio + async def test_cluster_uri_initialization(self, redis_cluster_uri_service): + """Test Redis cluster initialization with URI.""" + assert redis_cluster_uri_service.cache is not None + + @pytest.mark.asyncio + async def test_cluster_error_handling(self, redis_cluster_service): + """Test error handling in cluster operations.""" + from redis import RedisError + + app_name = "test_app" + user_id = "test_user" + + # Mock Redis error during session loading + redis_cluster_service.cache.smembers = AsyncMock( + 
side_effect=RedisError("Cluster error") + ) + + sessions_response = await redis_cluster_service.list_sessions( + app_name=app_name, user_id=user_id + ) + + assert len(sessions_response.sessions) == 0 + + @pytest.mark.asyncio + async def test_cluster_connection_validation(self): + """Test cluster connection validation during initialization.""" + cluster_uri = "redis://redis-node1:6379" + + with patch("redis.asyncio.RedisCluster.from_url") as mock_redis_cluster: + mock_client = AsyncMock() + mock_redis_cluster.return_value = mock_client + + service = RedisSessionService(cluster_uri=cluster_uri) + assert service.cache is not None + mock_redis_cluster.assert_called_once() + + @pytest.mark.asyncio + async def test_cluster_session_cleanup_on_error(self, redis_cluster_service): + """Test session cleanup when corrupted data is found in cluster.""" + app_name = "test_app" + user_id = "test_user" + + # Setup mock with corrupted session data + valid_session_data = { + "app_name": "test_app", + "user_id": "test_user", + "id": "session1", + "state": {}, + "events": [], + "last_update_time": 1234567890, + } + redis_cluster_service.cache.smembers = AsyncMock( + return_value={b"session1", b"session2"} + ) + + # Mock the pipeline for cluster approach + mock_context_manager = MagicMock() + mock_pipe = MagicMock() + mock_pipe.get = MagicMock(return_value=mock_pipe) + mock_pipe.execute = AsyncMock( + side_effect=[ + [orjson.dumps(valid_session_data)], # session1 result + [None], # session2 result (missing) + ] + ) + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + redis_cluster_service.cache.pipeline = MagicMock( + return_value=mock_context_manager + ) + redis_cluster_service.cache.srem = AsyncMock() + redis_cluster_service.cache.hgetall = AsyncMock(return_value={}) + + sessions_response = await redis_cluster_service.list_sessions( + app_name=app_name, user_id=user_id + ) + + 
redis_cluster_service.cache.srem.assert_called() + assert len(sessions_response.sessions) == 1 + + @pytest.mark.asyncio + async def test_decode_responses_handling(self, redis_service): + """Test proper handling of decode_responses setting.""" + app_name = "test_app" + user_id = "test_user" + session_id = "test_session" + + # Test with bytes response (decode_responses=False) + session_data = ( + '{"app_name": "test_app", "user_id": "test_user", "id": "test_session",' + ' "state": {}, "events": [], "last_update_time": 1234567890}' + ) + redis_service.cache.get = AsyncMock(return_value=session_data.encode()) + redis_service.cache.hgetall = AsyncMock(return_value={}) + + session = await redis_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + assert session is not None + assert session.app_name == app_name + assert session.user_id == user_id