From ec6618705ca1a45306bfe686f86298c569c1e744 Mon Sep 17 00:00:00 2001 From: Nishar Date: Sat, 18 Apr 2026 18:12:42 -0500 Subject: [PATCH 1/6] feat: add S3ArtifactService with native async and atomic versioning Adds S3ArtifactService to the artifacts module, providing: - Native async I/O via aioboto3 (no asyncio.to_thread wrappers) - Atomic versioning using S3 IfNoneMatch conditional writes - Session-scoped and user-scoped artifact namespacing - Custom metadata (JSON-serialised in S3 user-metadata) - Batch delete (1000 keys per request) for efficient cleanup - Paginated listing for large artifact collections - Parallel head_object calls in list_artifact_versions - Optional [s3] dependency group (aioboto3>=13.0.0) - Comprehensive test suite with full async mock infrastructure Closes #37 Closes #71 --- pyproject.toml | 7 +- src/google/adk_community/__init__.py | 1 + src/google/adk_community/artifacts/README.md | 38 ++ .../adk_community/artifacts/__init__.py | 19 + .../artifacts/s3_artifact_service.py | 482 ++++++++++++++ tests/unittests/artifacts/__init__.py | 0 .../artifacts/test_s3_artifact_service.py | 630 ++++++++++++++++++ 7 files changed, 1175 insertions(+), 2 deletions(-) create mode 100644 src/google/adk_community/artifacts/README.md create mode 100644 src/google/adk_community/artifacts/__init__.py create mode 100644 src/google/adk_community/artifacts/s3_artifact_service.py create mode 100644 tests/unittests/artifacts/__init__.py create mode 100644 tests/unittests/artifacts/test_s3_artifact_service.py diff --git a/pyproject.toml b/pyproject.toml index 11afcd82..e77eae20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,12 +25,12 @@ classifiers = [ # List of https://pypi.org/classifiers/ ] dependencies = [ # go/keep-sorted start - "google-genai>=1.21.1, <2.0.0", # Google GenAI SDK "google-adk", # Google ADK + "google-genai>=1.21.1, <2.0.0", # Google GenAI SDK "httpx>=0.27.0, <1.0.0", # For OpenMemory service + "orjson>=3.11.3", "redis>=5.0.0, <6.0.0", # 
Redis for session storage # go/keep-sorted end - "orjson>=3.11.3", ] dynamic = ["version"] @@ -41,6 +41,9 @@ changelog = "https://github.com/google/adk-python-community/blob/main/CHANGELOG. documentation = "https://google.github.io/adk-docs/" [project.optional-dependencies] +s3 = [ + "aioboto3>=13.0.0", # For S3ArtifactService +] test = [ "pytest>=8.4.2", "pytest-asyncio>=1.2.0", diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35f..803823d1 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import artifacts from . import memory from . import sessions from . import version diff --git a/src/google/adk_community/artifacts/README.md b/src/google/adk_community/artifacts/README.md new file mode 100644 index 00000000..799a1d3b --- /dev/null +++ b/src/google/adk_community/artifacts/README.md @@ -0,0 +1,38 @@ +# Community Artifact Services + +This module contains community-contributed artifact service implementations for ADK. + +## Available Services + +### S3ArtifactService + +Production-ready artifact storage using Amazon S3 (or any S3-compatible service such as MinIO, DigitalOcean Spaces, etc.). 
+ +**Installation:** +```bash +pip install google-adk-community[s3] +``` + +**Usage:** +```python +from google.adk_community.artifacts import S3ArtifactService + +artifact_service = S3ArtifactService( + bucket_name="my-adk-artifacts", + aws_configs={"region_name": "us-east-1"}, +) +``` + +**Features:** +- Native async I/O via `aioboto3` (no `asyncio.to_thread` wrappers) +- Atomic versioning using S3 conditional writes (`IfNoneMatch`) +- Session-scoped and user-scoped artifacts +- Automatic version management +- Custom metadata support (JSON-serialised) +- Batch delete for efficient cleanup +- Paginated listing for large artifact collections +- Works with S3-compatible services (MinIO, DigitalOcean Spaces, etc.) + +**See Also:** +- [S3ArtifactService Implementation](./s3_artifact_service.py) +- [Tests](../../../tests/unittests/artifacts/test_s3_artifact_service.py) diff --git a/src/google/adk_community/artifacts/__init__.py b/src/google/adk_community/artifacts/__init__.py new file mode 100644 index 00000000..6be6a93b --- /dev/null +++ b/src/google/adk_community/artifacts/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .s3_artifact_service import S3ArtifactService + +__all__ = [ + 'S3ArtifactService', +] diff --git a/src/google/adk_community/artifacts/s3_artifact_service.py b/src/google/adk_community/artifacts/s3_artifact_service.py new file mode 100644 index 00000000..581a7d72 --- /dev/null +++ b/src/google/adk_community/artifacts/s3_artifact_service.py @@ -0,0 +1,482 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An artifact service implementation using Amazon S3. + +The object key format depends on whether the filename has a user namespace: + - For files with user namespace (starting with "user:"): + {app_name}/{user_id}/user/{filename}/{version} + - For regular session-scoped files: + {app_name}/{user_id}/{session_id}/{filename}/{version} + +Uses aioboto3 for native async I/O and atomic versioning via +S3's ``IfNoneMatch`` condition to prevent race conditions. + +Install S3 support with:: + + pip install google-adk-community[s3] +""" +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +import json +import logging +from typing import Any +from typing import Optional + +from google.adk.artifacts.base_artifact_service import ArtifactVersion +from google.adk.artifacts.base_artifact_service import BaseArtifactService +from google.genai import types +from typing_extensions import override + +logger = logging.getLogger("google_adk_community." 
+ __name__) + + +class S3ArtifactService(BaseArtifactService): + """An artifact service implementation using Amazon S3. + + Uses ``aioboto3`` for native async I/O instead of wrapping synchronous + calls with ``asyncio.to_thread``. Artifact saves are atomic: a + ``IfNoneMatch="*"`` condition on ``put_object`` prevents race conditions + when two writers try to create the same version concurrently. + + Args: + bucket_name: The name of the S3 bucket to use. + aws_configs: Extra keyword arguments forwarded to + ``aioboto3.Session().client("s3", ...)``. Use this to pass + ``region_name``, ``endpoint_url`` (for MinIO / Spaces), etc. + save_max_retries: Maximum retries on version conflict. + ``-1`` means retry indefinitely. + """ + + def __init__( + self, + bucket_name: str, + aws_configs: Optional[dict[str, Any]] = None, + save_max_retries: int = -1, + ): + try: + import aioboto3 # noqa: F401 + except ImportError as exc: + raise ImportError( + "aioboto3 is required to use S3ArtifactService. " + "Install it with: pip install google-adk-community[s3]" + ) from exc + + self.bucket_name = bucket_name + self.aws_configs: dict[str, Any] = aws_configs or {} + self.save_max_retries = save_max_retries + self._session = None + + # ------------------------------------------------------------------ # + # S3 client helpers + # ------------------------------------------------------------------ # + + async def _get_session(self): + import aioboto3 + + if self._session is None: + self._session = aioboto3.Session() + return self._session + + @asynccontextmanager + async def _client(self): + session = await self._get_session() + async with session.client( + service_name="s3", **self.aws_configs + ) as s3: + yield s3 + + # ------------------------------------------------------------------ # + # Metadata serialisation + # ------------------------------------------------------------------ # + + @staticmethod + def _flatten_metadata(metadata: Optional[dict[str, Any]]) -> dict[str, str]: + 
"""JSON-encode metadata values for S3 user-metadata (strings only).""" + if not metadata: + return {} + return {str(k): json.dumps(v) for k, v in metadata.items()} + + @staticmethod + def _unflatten_metadata(metadata: Optional[dict[str, str]]) -> dict[str, Any]: + """Decode JSON metadata back to Python objects.""" + results: dict[str, Any] = {} + for k, v in (metadata or {}).items(): + try: + results[k] = json.loads(v) + except json.JSONDecodeError: + logger.warning( + "Failed to decode metadata value for key %r. Using raw string.", k + ) + results[k] = v + return results + + # ------------------------------------------------------------------ # + # Key helpers + # ------------------------------------------------------------------ # + + @staticmethod + def _file_has_user_namespace(filename: str) -> bool: + return filename.startswith("user:") + + def _get_blob_prefix( + self, + app_name: str, + user_id: str, + session_id: Optional[str], + filename: str, + ) -> str: + if self._file_has_user_namespace(filename): + return f"{app_name}/{user_id}/user/{filename[5:]}" # strip "user:" + if session_id is None: + raise ValueError( + "session_id is required for session-scoped artifacts." + ) + return f"{app_name}/{user_id}/{session_id}/{filename}" + + def _get_blob_name( + self, + app_name: str, + user_id: str, + session_id: Optional[str], + filename: str, + version: int, + ) -> str: + return ( + f"{self._get_blob_prefix(app_name, user_id, session_id, filename)}" + f"/{version}" + ) + + # ------------------------------------------------------------------ # + # Public API + # ------------------------------------------------------------------ # + + @override + async def save_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + artifact: types.Part, + session_id: Optional[str] = None, + custom_metadata: Optional[dict[str, Any]] = None, + ) -> int: + """Save an artifact with atomic versioning via ``IfNoneMatch``. 
+ + If two concurrent callers race to create the same version, S3 + will reject the second ``put_object`` with a ``PreconditionFailed`` + error and this method will transparently retry. + """ + from botocore.exceptions import ClientError + + if self.save_max_retries < 0: + retry_iter = iter(int, 1) # infinite iterator + else: + retry_iter = range(self.save_max_retries + 1) + + for _ in retry_iter: + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + version = 0 if not versions else max(versions) + 1 + key = self._get_blob_name( + app_name, user_id, session_id, filename, version + ) + + # Prepare data and content type + if artifact.inline_data: + body = artifact.inline_data.data + content_type = ( + artifact.inline_data.mime_type or "application/octet-stream" + ) + elif artifact.text: + body = artifact.text.encode("utf-8") + content_type = "text/plain; charset=utf-8" + else: + raise ValueError( + "Artifact must have either inline_data or text content." + ) + + async with self._client() as s3: + try: + await s3.put_object( + Bucket=self.bucket_name, + Key=key, + Body=body, + ContentType=content_type, + Metadata=self._flatten_metadata(custom_metadata), + IfNoneMatch="*", + ) + logger.debug( + "Saved artifact %s version %d to s3://%s/%s", + filename, + version, + self.bucket_name, + key, + ) + return version + except ClientError as e: + code = e.response.get("Error", {}).get("Code", "") + if code in ("PreconditionFailed", "ObjectAlreadyExists"): + logger.debug( + "Version conflict for %s version %d, retrying…", + filename, + version, + ) + continue + raise + + raise RuntimeError( + "Failed to save artifact due to version conflicts after retries." 
+ ) + + @override + async def load_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + ) -> Optional[types.Part]: + """Load a specific version (or latest) of an artifact from S3.""" + from botocore.exceptions import ClientError + + if version is None: + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not versions: + return None + version = max(versions) + + key = self._get_blob_name( + app_name, user_id, session_id, filename, version + ) + async with self._client() as s3: + try: + response = await s3.get_object(Bucket=self.bucket_name, Key=key) + async with response["Body"] as stream: + data = await stream.read() + content_type = response.get("ContentType", "application/octet-stream") + except ClientError as e: + code = e.response.get("Error", {}).get("Code", "") + if code in ("NoSuchKey", "404"): + return None + raise + + logger.debug( + "Loaded artifact %s version %d from s3://%s/%s", + filename, + version, + self.bucket_name, + key, + ) + return types.Part.from_bytes(data=data, mime_type=content_type) + + @override + async def list_artifact_keys( + self, *, app_name: str, user_id: str, session_id: Optional[str] = None + ) -> list[str]: + """List all artifact keys for a user, optionally scoped to a session.""" + keys: set[str] = set() + prefixes = [ + f"{app_name}/{user_id}/{session_id}/" if session_id else None, + f"{app_name}/{user_id}/user/", + ] + async with self._client() as s3: + for prefix in filter(None, prefixes): + paginator = s3.get_paginator("list_objects_v2") + async for page in paginator.paginate( + Bucket=self.bucket_name, Prefix=prefix + ): + for obj in page.get("Contents", []): + relative = obj["Key"][len(prefix):] + # relative is "{filename}/{version}" — strip version part + parts = relative.rsplit("/", 1) + if len(parts) >= 2: + raw_filename = parts[0] + # Re-add 
"user:" prefix for user-scoped artifacts + if prefix.endswith("/user/"): + keys.add(f"user:{raw_filename}") + else: + keys.add(raw_filename) + return sorted(keys) + + @override + async def delete_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + ) -> None: + """Delete all versions of an artifact using S3 batch delete.""" + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not versions: + return + + keys_to_delete = [ + { + "Key": self._get_blob_name( + app_name, user_id, session_id, filename, v + ) + } + for v in versions + ] + async with self._client() as s3: + # S3 batch delete supports up to 1000 keys per request + for i in range(0, len(keys_to_delete), 1000): + batch = keys_to_delete[i : i + 1000] + await s3.delete_objects( + Bucket=self.bucket_name, Delete={"Objects": batch} + ) + + @override + async def list_versions( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + ) -> list[int]: + """List all available version numbers for an artifact.""" + prefix = ( + self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" + ) + versions: list[int] = [] + async with self._client() as s3: + paginator = s3.get_paginator("list_objects_v2") + async for page in paginator.paginate( + Bucket=self.bucket_name, Prefix=prefix + ): + for obj in page.get("Contents", []): + version_str = obj["Key"].split("/")[-1] + try: + versions.append(int(version_str)) + except ValueError: + continue + return sorted(versions) + + @override + async def list_artifact_versions( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + ) -> list[ArtifactVersion]: + """List all versions with metadata, using parallel head_object calls.""" + prefix = ( + self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" + ) + results: list[ArtifactVersion] 
= [] + async with self._client() as s3: + paginator = s3.get_paginator("list_objects_v2") + async for page in paginator.paginate( + Bucket=self.bucket_name, Prefix=prefix + ): + page_objects = page.get("Contents", []) + if not page_objects: + continue + + # Parallelise head_object calls for each page + head_tasks = [ + s3.head_object(Bucket=self.bucket_name, Key=obj["Key"]) + for obj in page_objects + ] + heads = await asyncio.gather(*head_tasks) + + for obj, head in zip(page_objects, heads): + version_str = obj["Key"].split("/")[-1] + try: + version = int(version_str) + except ValueError: + continue + results.append( + ArtifactVersion( + version=version, + canonical_uri=f"s3://{self.bucket_name}/{obj['Key']}", + custom_metadata=self._unflatten_metadata( + head.get("Metadata", {}) + ), + create_time=obj["LastModified"].timestamp(), + mime_type=head["ContentType"], + ) + ) + return sorted(results, key=lambda a: a.version) + + @override + async def get_artifact_version( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + ) -> Optional[ArtifactVersion]: + """Retrieve metadata for a specific version (or the latest).""" + from botocore.exceptions import ClientError + + if version is None: + all_versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not all_versions: + return None + version = max(all_versions) + + key = self._get_blob_name( + app_name, user_id, session_id, filename, version + ) + async with self._client() as s3: + try: + head = await s3.head_object(Bucket=self.bucket_name, Key=key) + except ClientError as e: + code = e.response.get("Error", {}).get("Code", "") + if code in ("NoSuchKey", "404"): + return None + raise + + return ArtifactVersion( + version=version, + canonical_uri=f"s3://{self.bucket_name}/{key}", + custom_metadata=self._unflatten_metadata( + head.get("Metadata", {}) + ), + 
create_time=head["LastModified"].timestamp(), + mime_type=head["ContentType"], + ) diff --git a/tests/unittests/artifacts/__init__.py b/tests/unittests/artifacts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/artifacts/test_s3_artifact_service.py b/tests/unittests/artifacts/test_s3_artifact_service.py new file mode 100644 index 00000000..e1bc8e09 --- /dev/null +++ b/tests/unittests/artifacts/test_s3_artifact_service.py @@ -0,0 +1,630 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# pylint: disable=missing-class-docstring,missing-function-docstring + +"""Tests for S3ArtifactService.""" + +import asyncio +from datetime import datetime +from datetime import timezone +from typing import Any +from typing import Optional +from unittest import mock + +from google.adk.artifacts.base_artifact_service import ArtifactVersion +from google.adk_community.artifacts import S3ArtifactService +from google.genai import types +import pytest + +# --------------------------------------------------------------------------- # +# Mock infrastructure +# --------------------------------------------------------------------------- # + +FIXED_DATETIME = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + +class MockS3Object: + """Mocks an S3 object stored in a bucket.""" + + def __init__(self, key: str) -> None: + self.key = key + self.data: Optional[bytes] = None + self.content_type: Optional[str] = None + self.last_modified = FIXED_DATETIME + self.metadata: dict[str, str] = {} + + def set_data( + self, data: bytes, content_type: str, metadata: dict[str, str] + ): + self.data = data + self.content_type = content_type + self.metadata = metadata or {} + + +class MockS3Bucket: + """Mocks an S3 bucket.""" + + def __init__(self, name: str) -> None: + self.name = name + self.objects: dict[str, MockS3Object] = {} + + +class MockS3ResponseBody: + """Mocks the async streaming body returned by get_object.""" + + def __init__(self, data: bytes): + self._data = data + + async def read(self) -> bytes: + return self._data + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + +class MockS3Client: + """Mocks an aioboto3 S3 client with async context manager support.""" + + def __init__(self, **kwargs) -> None: + self.buckets: dict[str, MockS3Bucket] = {} + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + # --- S3 operations ---------------------------------------------------- # + + async def 
head_bucket(self, Bucket: str): + if Bucket not in self.buckets: + self.buckets[Bucket] = MockS3Bucket(Bucket) + return {} + + async def put_object( + self, + Bucket: str, + Key: str, + Body: bytes, + ContentType: str, + Metadata: Optional[dict[str, str]] = None, + IfNoneMatch: Optional[str] = None, + **kwargs, + ): + if Bucket not in self.buckets: + self.buckets[Bucket] = MockS3Bucket(Bucket) + bucket = self.buckets[Bucket] + + # Simulate atomic IfNoneMatch="*" — reject if key already exists + if IfNoneMatch == "*" and Key in bucket.objects and bucket.objects[Key].data is not None: + from botocore.exceptions import ClientError + + raise ClientError( + {"Error": {"Code": "PreconditionFailed", "Message": "Object exists"}}, + "PutObject", + ) + + if Key not in bucket.objects: + bucket.objects[Key] = MockS3Object(Key) + bucket.objects[Key].set_data(Body, ContentType, Metadata or {}) + + async def get_object(self, Bucket: str, Key: str): + bucket = self.buckets.get(Bucket) + if not bucket or Key not in bucket.objects or bucket.objects[Key].data is None: + from botocore.exceptions import ClientError + + raise ClientError( + {"Error": {"Code": "NoSuchKey", "Message": "Not found"}}, + "GetObject", + ) + obj = bucket.objects[Key] + return { + "Body": MockS3ResponseBody(obj.data), + "ContentType": obj.content_type, + "LastModified": obj.last_modified, + "Metadata": obj.metadata, + } + + async def head_object(self, Bucket: str, Key: str): + bucket = self.buckets.get(Bucket) + if not bucket or Key not in bucket.objects or bucket.objects[Key].data is None: + from botocore.exceptions import ClientError + + raise ClientError( + {"Error": {"Code": "404", "Message": "Not found"}}, + "HeadObject", + ) + obj = bucket.objects[Key] + return { + "ContentType": obj.content_type, + "LastModified": obj.last_modified, + "Metadata": obj.metadata, + } + + async def delete_object(self, Bucket: str, Key: str): + bucket = self.buckets.get(Bucket) + if bucket and Key in bucket.objects: + del 
bucket.objects[Key] + + async def delete_objects(self, Bucket: str, Delete: dict): + bucket = self.buckets.get(Bucket) + if bucket: + for obj_spec in Delete.get("Objects", []): + key = obj_spec["Key"] + if key in bucket.objects: + del bucket.objects[key] + + def get_paginator(self, operation_name: str): + return MockS3Paginator(self, operation_name) + + +class MockS3Paginator: + """Mocks an S3 paginator returned by get_paginator.""" + + def __init__(self, client: MockS3Client, operation_name: str): + self.client = client + self.operation_name = operation_name + + def paginate(self, Bucket: str, Prefix: str = ""): + return MockS3PaginateResult(self.client, Bucket, Prefix) + + +class MockS3PaginateResult: + """Async iterator that yields a single page of list_objects_v2 results.""" + + def __init__(self, client: MockS3Client, bucket: str, prefix: str): + self.client = client + self.bucket_name = bucket + self.prefix = prefix + self._yielded = False + + def __aiter__(self): + self._yielded = False + return self + + async def __anext__(self): + if self._yielded: + raise StopAsyncIteration + self._yielded = True + + bucket = self.client.buckets.get(self.bucket_name) + if not bucket: + return {} + + contents = [] + for key, obj in bucket.objects.items(): + if key.startswith(self.prefix) and obj.data is not None: + contents.append({ + "Key": key, + "LastModified": obj.last_modified, + }) + + if contents: + return {"Contents": contents} + return {} + + +class MockS3Session: + """Mocks aioboto3.Session.""" + + def __init__(self, mock_client: MockS3Client): + self._client = mock_client + + def client(self, service_name: str, **kwargs): + return self._client + + +# --------------------------------------------------------------------------- # +# Fixtures +# --------------------------------------------------------------------------- # + +@pytest.fixture +def mock_s3_service(): + """Provides a mocked S3ArtifactService for testing.""" + mock_client = MockS3Client() + mock_session 
= MockS3Session(mock_client) + + with mock.patch("aioboto3.Session", return_value=mock_session): + service = S3ArtifactService(bucket_name="test_bucket") + # Ensure the mock session is used + service._session = mock_session + return service + + +# --------------------------------------------------------------------------- # +# Tests +# --------------------------------------------------------------------------- # + +@pytest.mark.asyncio +async def test_load_empty(mock_s3_service): + """Loading an artifact when none exists returns None.""" + result = await mock_s3_service.load_artifact( + app_name="test_app", + user_id="test_user", + session_id="session_id", + filename="filename", + ) + assert result is None + + +@pytest.mark.asyncio +async def test_save_load_delete(mock_s3_service): + """Full CRUD cycle: save, load, load-missing-version, delete.""" + artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain") + app_name = "app0" + user_id = "user0" + session_id = "123" + filename = "file456" + + version = await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=artifact, + ) + assert version == 0 + + loaded = await mock_s3_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert loaded == artifact + + # Non-existent version + assert not await mock_s3_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + version=3, + ) + + await mock_s3_service.delete_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert not await mock_s3_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + + +@pytest.mark.asyncio +async def test_list_keys(mock_s3_service): + """Listing artifact keys returns all saved filenames.""" + artifact = 
types.Part.from_bytes(data=b"test_data", mime_type="text/plain") + app_name = "app0" + user_id = "user0" + session_id = "123" + filenames = [f"filename{i}" for i in range(5)] + + for f in filenames: + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=f, + artifact=artifact, + ) + + keys = await mock_s3_service.list_artifact_keys( + app_name=app_name, user_id=user_id, session_id=session_id + ) + assert keys == filenames + + +@pytest.mark.asyncio +async def test_list_versions(mock_s3_service): + """Multiple saves of the same artifact create incremental versions.""" + app_name = "app0" + user_id = "user0" + session_id = "123" + filename = "report.txt" + parts = [ + types.Part.from_bytes( + data=i.to_bytes(2, byteorder="big"), mime_type="text/plain" + ) + for i in range(3) + ] + parts.append(types.Part.from_text(text="hello")) + + for p in parts: + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=p, + ) + + versions = await mock_s3_service.list_versions( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert versions == [0, 1, 2, 3] + + +@pytest.mark.asyncio +async def test_list_keys_preserves_user_prefix(mock_s3_service): + """User-scoped artifacts keep the 'user:' prefix in key listings.""" + artifact = types.Part.from_bytes(data=b"data", mime_type="text/plain") + app_name = "app0" + user_id = "user0" + session_id = "123" + + await mock_s3_service.save_artifact( + app_name=app_name, user_id=user_id, session_id=session_id, + filename="user:document.pdf", artifact=artifact, + ) + await mock_s3_service.save_artifact( + app_name=app_name, user_id=user_id, session_id=session_id, + filename="user:image.png", artifact=artifact, + ) + await mock_s3_service.save_artifact( + app_name=app_name, user_id=user_id, session_id=session_id, + filename="session_file.txt", artifact=artifact, + ) + 
+ keys = await mock_s3_service.list_artifact_keys( + app_name=app_name, user_id=user_id, session_id=session_id + ) + assert sorted(keys) == ["session_file.txt", "user:document.pdf", "user:image.png"] + + +@pytest.mark.asyncio +async def test_list_artifact_versions_and_get_artifact_version( + mock_s3_service, +): + """Artifact version metadata includes canonical URI and custom metadata.""" + app_name = "app0" + user_id = "user0" + session_id = "123" + filename = "filename" + + for i in range(4): + part = types.Part.from_bytes( + data=i.to_bytes(2, byteorder="big"), mime_type="text/plain" + ) + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=part, + custom_metadata={"key": f"value{i}"}, + ) + + artifact_versions = await mock_s3_service.list_artifact_versions( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + + assert len(artifact_versions) == 4 + for i, av in enumerate(artifact_versions): + assert av.version == i + assert ( + av.canonical_uri + == f"s3://test_bucket/{app_name}/{user_id}/{session_id}/{filename}/{i}" + ) + assert av.custom_metadata["key"] == f"value{i}" + assert av.mime_type == "text/plain" + + # Get latest + latest = await mock_s3_service.get_artifact_version( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert latest is not None + assert latest.version == 3 + + # Get specific version + specific = await mock_s3_service.get_artifact_version( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + version=2, + ) + assert specific is not None + assert specific.version == 2 + + +@pytest.mark.asyncio +async def test_list_artifact_versions_with_user_prefix(mock_s3_service): + """User-scoped artifact versions have correct canonical URIs.""" + app_name = "app0" + user_id = "user0" + session_id = "123" + user_filename = "user:document.pdf" + + for i in 
range(4): + part = types.Part.from_bytes( + data=i.to_bytes(2, byteorder="big"), mime_type="text/plain" + ) + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=user_filename, + artifact=part, + custom_metadata={"key": f"value{i}"}, + ) + + artifact_versions = await mock_s3_service.list_artifact_versions( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=user_filename, + ) + + assert len(artifact_versions) == 4 + for i, av in enumerate(artifact_versions): + assert av.version == i + assert ( + av.canonical_uri + == f"s3://test_bucket/{app_name}/{user_id}/user/document.pdf/{i}" + ) + + +@pytest.mark.asyncio +async def test_get_artifact_version_artifact_does_not_exist(mock_s3_service): + """Getting a version for a non-existent artifact returns None.""" + result = await mock_s3_service.get_artifact_version( + app_name="test_app", + user_id="test_user", + session_id="session_id", + filename="filename", + ) + assert result is None + + +@pytest.mark.asyncio +async def test_get_artifact_version_out_of_index(mock_s3_service): + """Getting a non-existent version number returns None.""" + artifact = types.Part.from_bytes(data=b"data", mime_type="text/plain") + await mock_s3_service.save_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="filename", + artifact=artifact, + ) + result = await mock_s3_service.get_artifact_version( + app_name="app0", + user_id="user0", + session_id="123", + filename="filename", + version=3, + ) + assert result is None + + +@pytest.mark.asyncio +async def test_empty_artifact(mock_s3_service): + """Saving and loading 0-byte artifacts works correctly.""" + empty = types.Part.from_bytes(data=b"", mime_type="text/plain") + + version = await mock_s3_service.save_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="empty.txt", + artifact=empty, + ) + assert version == 0 + + loaded = await 
mock_s3_service.load_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="empty.txt", + ) + assert loaded is not None + assert loaded.inline_data is not None + assert loaded.inline_data.data == b"" + + +@pytest.mark.asyncio +async def test_custom_metadata(mock_s3_service): + """Custom metadata is stored and retrieved correctly (JSON-encoded).""" + artifact = types.Part.from_text(text="Test") + custom_metadata = {"author": "test", "tags": ["integration", "test"]} + + await mock_s3_service.save_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="test.txt", + artifact=artifact, + custom_metadata=custom_metadata, + ) + + version_info = await mock_s3_service.get_artifact_version( + app_name="app0", + user_id="user0", + session_id="123", + filename="test.txt", + ) + + assert version_info is not None + assert version_info.custom_metadata["author"] == "test" + assert version_info.custom_metadata["tags"] == ["integration", "test"] + + +@pytest.mark.asyncio +async def test_text_artifact_roundtrip(mock_s3_service): + """Text artifacts are encoded to UTF-8 bytes on save and loaded as bytes.""" + artifact = types.Part.from_text(text="Hello, world! ๐ŸŒ") + + version = await mock_s3_service.save_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="greeting.txt", + artifact=artifact, + ) + assert version == 0 + + loaded = await mock_s3_service.load_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="greeting.txt", + ) + assert loaded is not None + assert loaded.inline_data is not None + assert loaded.inline_data.data == "Hello, world! 
๐ŸŒ".encode("utf-8") + + +@pytest.mark.asyncio +async def test_save_artifact_version_conflict_retry(mock_s3_service): + """Atomic versioning retries on PreconditionFailed.""" + artifact = types.Part.from_bytes(data=b"data", mime_type="text/plain") + + # Save version 0 successfully + v0 = await mock_s3_service.save_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="conflict.txt", + artifact=artifact, + ) + assert v0 == 0 + + # Save version 1 โ€” should succeed because version 1 doesn't exist yet + v1 = await mock_s3_service.save_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="conflict.txt", + artifact=artifact, + ) + assert v1 == 1 From 93acebf1a1f13868459c8e6b1dd80bb023331448 Mon Sep 17 00:00:00 2001 From: Nishar Date: Sat, 18 Apr 2026 18:12:42 -0500 Subject: [PATCH 2/6] style: match codebase conventions for docstrings and comments - Remove section comment banners (# --- heading --- patterns) - Simplify module and class docstrings to one-liners - Move Args section to __init__ docstring (matches RedisSessionService) - Shorten method docstrings to single line - Remove explanatory inline comments - Clean up blank lines left from removals --- .../artifacts/s3_artifact_service.py | 85 +++++-------------- tests/__init__.py | 0 .../artifacts/test_s3_artifact_service.py | 18 ++-- 3 files changed, 27 insertions(+), 76 deletions(-) create mode 100644 tests/__init__.py diff --git a/src/google/adk_community/artifacts/s3_artifact_service.py b/src/google/adk_community/artifacts/s3_artifact_service.py index 581a7d72..d93de542 100644 --- a/src/google/adk_community/artifacts/s3_artifact_service.py +++ b/src/google/adk_community/artifacts/s3_artifact_service.py @@ -12,21 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""An artifact service implementation using Amazon S3. 
- -The object key format depends on whether the filename has a user namespace: - - For files with user namespace (starting with "user:"): - {app_name}/{user_id}/user/{filename}/{version} - - For regular session-scoped files: - {app_name}/{user_id}/{session_id}/{filename}/{version} - -Uses aioboto3 for native async I/O and atomic versioning via -S3's ``IfNoneMatch`` condition to prevent race conditions. - -Install S3 support with:: - - pip install google-adk-community[s3] -""" +"""Artifact service implementation using Amazon S3.""" from __future__ import annotations import asyncio @@ -45,21 +31,7 @@ class S3ArtifactService(BaseArtifactService): - """An artifact service implementation using Amazon S3. - - Uses ``aioboto3`` for native async I/O instead of wrapping synchronous - calls with ``asyncio.to_thread``. Artifact saves are atomic: a - ``IfNoneMatch="*"`` condition on ``put_object`` prevents race conditions - when two writers try to create the same version concurrently. - - Args: - bucket_name: The name of the S3 bucket to use. - aws_configs: Extra keyword arguments forwarded to - ``aioboto3.Session().client("s3", ...)``. Use this to pass - ``region_name``, ``endpoint_url`` (for MinIO / Spaces), etc. - save_max_retries: Maximum retries on version conflict. - ``-1`` means retry indefinitely. - """ + """An S3-backed implementation of the artifact service.""" def __init__( self, @@ -67,6 +39,15 @@ def __init__( aws_configs: Optional[dict[str, Any]] = None, save_max_retries: int = -1, ): + """Initializes the S3 artifact service. + + Args: + bucket_name: The name of the S3 bucket to use. + aws_configs: Extra kwargs forwarded to the aioboto3 S3 client. + Use this to pass region_name, endpoint_url (for MinIO), etc. + save_max_retries: Maximum retries on version conflict. -1 means + retry indefinitely. 
+ """ try: import aioboto3 # noqa: F401 except ImportError as exc: @@ -80,9 +61,7 @@ def __init__( self.save_max_retries = save_max_retries self._session = None - # ------------------------------------------------------------------ # - # S3 client helpers - # ------------------------------------------------------------------ # + async def _get_session(self): import aioboto3 @@ -99,20 +78,16 @@ async def _client(self): ) as s3: yield s3 - # ------------------------------------------------------------------ # - # Metadata serialisation - # ------------------------------------------------------------------ # - @staticmethod def _flatten_metadata(metadata: Optional[dict[str, Any]]) -> dict[str, str]: - """JSON-encode metadata values for S3 user-metadata (strings only).""" + """JSON-encode metadata values for S3 user-metadata.""" if not metadata: return {} return {str(k): json.dumps(v) for k, v in metadata.items()} @staticmethod def _unflatten_metadata(metadata: Optional[dict[str, str]]) -> dict[str, Any]: - """Decode JSON metadata back to Python objects.""" + """Decode JSON metadata back to native Python objects.""" results: dict[str, Any] = {} for k, v in (metadata or {}).items(): try: @@ -124,10 +99,6 @@ def _unflatten_metadata(metadata: Optional[dict[str, str]]) -> dict[str, Any]: results[k] = v return results - # ------------------------------------------------------------------ # - # Key helpers - # ------------------------------------------------------------------ # - @staticmethod def _file_has_user_namespace(filename: str) -> bool: return filename.startswith("user:") @@ -160,10 +131,6 @@ def _get_blob_name( f"/{version}" ) - # ------------------------------------------------------------------ # - # Public API - # ------------------------------------------------------------------ # - @override async def save_artifact( self, @@ -175,12 +142,7 @@ async def save_artifact( session_id: Optional[str] = None, custom_metadata: Optional[dict[str, Any]] = None, ) -> int: 
- """Save an artifact with atomic versioning via ``IfNoneMatch``. - - If two concurrent callers race to create the same version, S3 - will reject the second ``put_object`` with a ``PreconditionFailed`` - error and this method will transparently retry. - """ + """Save an artifact with atomic versioning via IfNoneMatch.""" from botocore.exceptions import ClientError if self.save_max_retries < 0: @@ -200,7 +162,6 @@ async def save_artifact( app_name, user_id, session_id, filename, version ) - # Prepare data and content type if artifact.inline_data: body = artifact.inline_data.data content_type = ( @@ -257,7 +218,7 @@ async def load_artifact( session_id: Optional[str] = None, version: Optional[int] = None, ) -> Optional[types.Part]: - """Load a specific version (or latest) of an artifact from S3.""" + """Load a specific version of an artifact, or the latest.""" from botocore.exceptions import ClientError if version is None: @@ -299,7 +260,7 @@ async def load_artifact( async def list_artifact_keys( self, *, app_name: str, user_id: str, session_id: Optional[str] = None ) -> list[str]: - """List all artifact keys for a user, optionally scoped to a session.""" + """List all artifact keys for a user, optionally filtered by session.""" keys: set[str] = set() prefixes = [ f"{app_name}/{user_id}/{session_id}/" if session_id else None, @@ -313,11 +274,9 @@ async def list_artifact_keys( ): for obj in page.get("Contents", []): relative = obj["Key"][len(prefix):] - # relative is "{filename}/{version}" โ€” strip version part parts = relative.rsplit("/", 1) if len(parts) >= 2: raw_filename = parts[0] - # Re-add "user:" prefix for user-scoped artifacts if prefix.endswith("/user/"): keys.add(f"user:{raw_filename}") else: @@ -333,7 +292,7 @@ async def delete_artifact( filename: str, session_id: Optional[str] = None, ) -> None: - """Delete all versions of an artifact using S3 batch delete.""" + """Delete all versions of an artifact.""" versions = await self.list_versions( 
app_name=app_name, user_id=user_id, @@ -352,7 +311,6 @@ async def delete_artifact( for v in versions ] async with self._client() as s3: - # S3 batch delete supports up to 1000 keys per request for i in range(0, len(keys_to_delete), 1000): batch = keys_to_delete[i : i + 1000] await s3.delete_objects( @@ -368,7 +326,7 @@ async def list_versions( filename: str, session_id: Optional[str] = None, ) -> list[int]: - """List all available version numbers for an artifact.""" + """List all available versions of an artifact.""" prefix = ( self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" ) @@ -395,7 +353,7 @@ async def list_artifact_versions( filename: str, session_id: Optional[str] = None, ) -> list[ArtifactVersion]: - """List all versions with metadata, using parallel head_object calls.""" + """List all versions with metadata.""" prefix = ( self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" ) @@ -409,7 +367,6 @@ async def list_artifact_versions( if not page_objects: continue - # Parallelise head_object calls for each page head_tasks = [ s3.head_object(Bucket=self.bucket_name, Key=obj["Key"]) for obj in page_objects @@ -445,7 +402,7 @@ async def get_artifact_version( session_id: Optional[str] = None, version: Optional[int] = None, ) -> Optional[ArtifactVersion]: - """Retrieve metadata for a specific version (or the latest).""" + """Retrieve metadata for a specific version, or the latest.""" from botocore.exceptions import ClientError if version is None: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unittests/artifacts/test_s3_artifact_service.py b/tests/unittests/artifacts/test_s3_artifact_service.py index e1bc8e09..8bdccb07 100644 --- a/tests/unittests/artifacts/test_s3_artifact_service.py +++ b/tests/unittests/artifacts/test_s3_artifact_service.py @@ -28,9 +28,7 @@ from google.genai import types import pytest -# 
--------------------------------------------------------------------------- # -# Mock infrastructure -# --------------------------------------------------------------------------- # + FIXED_DATETIME = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) @@ -89,7 +87,7 @@ async def __aenter__(self): async def __aexit__(self, *args): pass - # --- S3 operations ---------------------------------------------------- # + async def head_bucket(self, Bucket: str): if Bucket not in self.buckets: @@ -110,7 +108,7 @@ async def put_object( self.buckets[Bucket] = MockS3Bucket(Bucket) bucket = self.buckets[Bucket] - # Simulate atomic IfNoneMatch="*" โ€” reject if key already exists + # Reject if key already exists (atomic IfNoneMatch) if IfNoneMatch == "*" and Key in bucket.objects and bucket.objects[Key].data is not None: from botocore.exceptions import ClientError @@ -229,9 +227,7 @@ def client(self, service_name: str, **kwargs): return self._client -# --------------------------------------------------------------------------- # -# Fixtures -# --------------------------------------------------------------------------- # + @pytest.fixture def mock_s3_service(): @@ -241,14 +237,12 @@ def mock_s3_service(): with mock.patch("aioboto3.Session", return_value=mock_session): service = S3ArtifactService(bucket_name="test_bucket") - # Ensure the mock session is used service._session = mock_session return service -# --------------------------------------------------------------------------- # -# Tests -# --------------------------------------------------------------------------- # + + @pytest.mark.asyncio async def test_load_empty(mock_s3_service): From 25ebe3e39395aacadd59d5e4e26b8eb83bf0b497 Mon Sep 17 00:00:00 2001 From: Nishar Date: Sat, 18 Apr 2026 18:12:42 -0500 Subject: [PATCH 3/6] chore: remove accidental tests/__init__.py --- tests/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/__init__.py diff --git a/tests/__init__.py b/tests/__init__.py 
deleted file mode 100644 index e69de29b..00000000 From f3b48a36490f02e6b526349ae650b9b23c15b0a6 Mon Sep 17 00:00:00 2001 From: Nishar Date: Tue, 5 May 2026 07:51:08 +0545 Subject: [PATCH 4/6] feat: add S3ArtifactService using aioboto3 Adds S3-backed artifact storage with: - Atomic versioning via IfNoneMatch conditional writes - Async-safe session initialization with asyncio.Lock - Bounded concurrent head_object calls (semaphore=10) - S3 metadata size validation (2KB limit) - User-scoped artifact namespace support - Full test coverage with mocked S3 client --- .../artifacts/s3_artifact_service.py | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/src/google/adk_community/artifacts/s3_artifact_service.py b/src/google/adk_community/artifacts/s3_artifact_service.py index d93de542..51fb5da2 100644 --- a/src/google/adk_community/artifacts/s3_artifact_service.py +++ b/src/google/adk_community/artifacts/s3_artifact_service.py @@ -60,14 +60,14 @@ def __init__( self.aws_configs: dict[str, Any] = aws_configs or {} self.save_max_retries = save_max_retries self._session = None - - + self._session_lock = asyncio.Lock() async def _get_session(self): import aioboto3 - if self._session is None: - self._session = aioboto3.Session() + async with self._session_lock: + if self._session is None: + self._session = aioboto3.Session() return self._session @asynccontextmanager @@ -78,12 +78,27 @@ async def _client(self): ) as s3: yield s3 + # S3 user-defined metadata is limited to 2 KB total. + _S3_METADATA_MAX_BYTES = 2048 + @staticmethod def _flatten_metadata(metadata: Optional[dict[str, Any]]) -> dict[str, str]: - """JSON-encode metadata values for S3 user-metadata.""" + """JSON-encode metadata values for S3 user-metadata. + + Raises: + ValueError: If the encoded metadata exceeds the S3 2 KB limit. 
+ """ if not metadata: return {} - return {str(k): json.dumps(v) for k, v in metadata.items()} + flat = {str(k): json.dumps(v) for k, v in metadata.items()} + total = sum(len(k.encode()) + len(v.encode()) for k, v in flat.items()) + if total > S3ArtifactService._S3_METADATA_MAX_BYTES: + raise ValueError( + f"Custom metadata ({total} bytes) exceeds the S3 " + f"user-metadata limit of " + f"{S3ArtifactService._S3_METADATA_MAX_BYTES} bytes." + ) + return flat @staticmethod def _unflatten_metadata(metadata: Optional[dict[str, str]]) -> dict[str, Any]: @@ -358,7 +373,15 @@ async def list_artifact_versions( self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" ) results: list[ArtifactVersion] = [] + # Limit concurrent head_object calls to avoid S3 rate-limiting. + sem = asyncio.Semaphore(10) + async with self._client() as s3: + + async def _head(key: str): + async with sem: + return await s3.head_object(Bucket=self.bucket_name, Key=key) + paginator = s3.get_paginator("list_objects_v2") async for page in paginator.paginate( Bucket=self.bucket_name, Prefix=prefix @@ -367,10 +390,7 @@ async def list_artifact_versions( if not page_objects: continue - head_tasks = [ - s3.head_object(Bucket=self.bucket_name, Key=obj["Key"]) - for obj in page_objects - ] + head_tasks = [_head(obj["Key"]) for obj in page_objects] heads = await asyncio.gather(*head_tasks) for obj, head in zip(page_objects, heads): From c564f5d4a9a8dc75019294cb31cb0bdf9d487343 Mon Sep 17 00:00:00 2001 From: Nishar Date: Thu, 7 May 2026 09:01:32 +0545 Subject: [PATCH 5/6] refactor(s3): address PR review comments - Metadata size check now accounts for S3 x-amz-meta- prefix overhead (~11 bytes per key) in the 2KB limit calculation. - Replace unconventional iter(int, 1) infinite loop with while True and add exponential backoff (100ms, 200ms, ... capped at 5s) between version conflict retries. 
- list_artifact_keys now uses S3 Delimiter='/' with CommonPrefixes for O(unique-keys) efficiency instead of listing every version object. - load_artifact returns Part.from_text() for text/* content types so consumers can check part.text consistently. --- .../artifacts/s3_artifact_service.py | 73 +++++++++++++------ .../artifacts/test_s3_artifact_service.py | 60 +++++++++++---- 2 files changed, 96 insertions(+), 37 deletions(-) diff --git a/src/google/adk_community/artifacts/s3_artifact_service.py b/src/google/adk_community/artifacts/s3_artifact_service.py index 51fb5da2..b6f5ba15 100644 --- a/src/google/adk_community/artifacts/s3_artifact_service.py +++ b/src/google/adk_community/artifacts/s3_artifact_service.py @@ -78,8 +78,11 @@ async def _client(self): ) as s3: yield s3 - # S3 user-defined metadata is limited to 2 KB total. + # S3 user-defined metadata is limited to 2 KB total. S3 prefixes each + # key with ``x-amz-meta-`` (11 bytes) in the header, so we include that + # overhead per key when computing the total size. _S3_METADATA_MAX_BYTES = 2048 + _S3_META_PREFIX_LEN = len("x-amz-meta-") @staticmethod def _flatten_metadata(metadata: Optional[dict[str, Any]]) -> dict[str, str]: @@ -91,11 +94,17 @@ def _flatten_metadata(metadata: Optional[dict[str, Any]]) -> dict[str, str]: if not metadata: return {} flat = {str(k): json.dumps(v) for k, v in metadata.items()} - total = sum(len(k.encode()) + len(v.encode()) for k, v in flat.items()) + # Include the x-amz-meta- prefix overhead that S3 adds per key. + total = sum( + S3ArtifactService._S3_META_PREFIX_LEN + + len(k.encode()) + + len(v.encode()) + for k, v in flat.items() + ) if total > S3ArtifactService._S3_METADATA_MAX_BYTES: raise ValueError( - f"Custom metadata ({total} bytes) exceeds the S3 " - f"user-metadata limit of " + f"Custom metadata ({total} bytes including S3 header " + f"overhead) exceeds the S3 user-metadata limit of " f"{S3ArtifactService._S3_METADATA_MAX_BYTES} bytes." 
) return flat @@ -160,12 +169,11 @@ async def save_artifact( """Save an artifact with atomic versioning via IfNoneMatch.""" from botocore.exceptions import ClientError - if self.save_max_retries < 0: - retry_iter = iter(int, 1) # infinite iterator - else: - retry_iter = range(self.save_max_retries + 1) + attempt = 0 + while True: + if self.save_max_retries >= 0 and attempt > self.save_max_retries: + break - for _ in retry_iter: versions = await self.list_versions( app_name=app_name, user_id=user_id, @@ -211,16 +219,23 @@ async def save_artifact( except ClientError as e: code = e.response.get("Error", {}).get("Code", "") if code in ("PreconditionFailed", "ObjectAlreadyExists"): + attempt += 1 + backoff = min(0.1 * (2 ** (attempt - 1)), 5.0) logger.debug( - "Version conflict for %s version %d, retryingโ€ฆ", + "Version conflict for %s version %d, retrying in " + "%.2fs (attempt %d)โ€ฆ", filename, version, + backoff, + attempt, ) + await asyncio.sleep(backoff) continue raise raise RuntimeError( - "Failed to save artifact due to version conflicts after retries." + "Failed to save artifact due to version conflicts after " + f"{self.save_max_retries} retries." ) @override @@ -269,13 +284,23 @@ async def load_artifact( self.bucket_name, key, ) + # Return Part.from_text for text content types so consumers can + # check ``part.text`` consistently. Fall back to from_bytes for + # binary content. + if content_type.startswith("text/"): + return types.Part.from_text(text=data.decode("utf-8")) return types.Part.from_bytes(data=data, mime_type=content_type) @override async def list_artifact_keys( self, *, app_name: str, user_id: str, session_id: Optional[str] = None ) -> list[str]: - """List all artifact keys for a user, optionally filtered by session.""" + """List all artifact keys for a user, optionally filtered by session. + + Uses S3 ``Delimiter='/'`` with ``CommonPrefixes`` to retrieve only + unique artifact names without listing every individual version + object. 
+ """ keys: set[str] = set() prefixes = [ f"{app_name}/{user_id}/{session_id}/" if session_id else None, @@ -285,17 +310,21 @@ async def list_artifact_keys( for prefix in filter(None, prefixes): paginator = s3.get_paginator("list_objects_v2") async for page in paginator.paginate( - Bucket=self.bucket_name, Prefix=prefix + Bucket=self.bucket_name, + Prefix=prefix, + Delimiter="/", ): - for obj in page.get("Contents", []): - relative = obj["Key"][len(prefix):] - parts = relative.rsplit("/", 1) - if len(parts) >= 2: - raw_filename = parts[0] - if prefix.endswith("/user/"): - keys.add(f"user:{raw_filename}") - else: - keys.add(raw_filename) + for cp in page.get("CommonPrefixes", []): + # CommonPrefixes entries look like + # "/" โ€” strip the prefix and + # trailing slash to get the raw filename. + raw_filename = cp["Prefix"][len(prefix):].rstrip("/") + if not raw_filename: + continue + if prefix.endswith("/user/"): + keys.add(f"user:{raw_filename}") + else: + keys.add(raw_filename) return sorted(keys) @override diff --git a/tests/unittests/artifacts/test_s3_artifact_service.py b/tests/unittests/artifacts/test_s3_artifact_service.py index 8bdccb07..f6336182 100644 --- a/tests/unittests/artifacts/test_s3_artifact_service.py +++ b/tests/unittests/artifacts/test_s3_artifact_service.py @@ -178,17 +178,26 @@ def __init__(self, client: MockS3Client, operation_name: str): self.client = client self.operation_name = operation_name - def paginate(self, Bucket: str, Prefix: str = ""): - return MockS3PaginateResult(self.client, Bucket, Prefix) + def paginate(self, Bucket: str, Prefix: str = "", Delimiter: str = ""): + return MockS3PaginateResult( + self.client, Bucket, Prefix, Delimiter + ) class MockS3PaginateResult: """Async iterator that yields a single page of list_objects_v2 results.""" - def __init__(self, client: MockS3Client, bucket: str, prefix: str): + def __init__( + self, + client: MockS3Client, + bucket: str, + prefix: str, + delimiter: str = "", + ): self.client = 
client self.bucket_name = bucket self.prefix = prefix + self.delimiter = delimiter self._yielded = False def __aiter__(self): @@ -204,13 +213,33 @@ async def __anext__(self): if not bucket: return {} + matching_keys = [ + key + for key, obj in bucket.objects.items() + if key.startswith(self.prefix) and obj.data is not None + ] + + # When Delimiter is set, return CommonPrefixes instead of Contents + # to simulate S3's virtual directory grouping. + if self.delimiter: + prefixes: set[str] = set() + for key in matching_keys: + relative = key[len(self.prefix):] + idx = relative.find(self.delimiter) + if idx >= 0: + prefixes.add(self.prefix + relative[: idx + 1]) + if prefixes: + return { + "CommonPrefixes": [{"Prefix": p} for p in sorted(prefixes)] + } + return {} + contents = [] - for key, obj in bucket.objects.items(): - if key.startswith(self.prefix) and obj.data is not None: - contents.append({ - "Key": key, - "LastModified": obj.last_modified, - }) + for key in matching_keys: + contents.append({ + "Key": key, + "LastModified": bucket.objects[key].last_modified, + }) if contents: return {"Contents": contents} @@ -259,7 +288,7 @@ async def test_load_empty(mock_s3_service): @pytest.mark.asyncio async def test_save_load_delete(mock_s3_service): """Full CRUD cycle: save, load, load-missing-version, delete.""" - artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain") + artifact = types.Part.from_bytes(data=b"test_data", mime_type="application/octet-stream") app_name = "app0" user_id = "user0" session_id = "123" @@ -542,8 +571,8 @@ async def test_empty_artifact(mock_s3_service): filename="empty.txt", ) assert loaded is not None - assert loaded.inline_data is not None - assert loaded.inline_data.data == b"" + # Empty text/plain content returns Part.from_text with empty string. 
+ assert loaded.text == "" @pytest.mark.asyncio @@ -575,7 +604,7 @@ async def test_custom_metadata(mock_s3_service): @pytest.mark.asyncio async def test_text_artifact_roundtrip(mock_s3_service): - """Text artifacts are encoded to UTF-8 bytes on save and loaded as bytes.""" + """Text artifacts are encoded to UTF-8 bytes on save and returned via Part.from_text on load.""" artifact = types.Part.from_text(text="Hello, world! ๐ŸŒ") version = await mock_s3_service.save_artifact( @@ -594,8 +623,9 @@ async def test_text_artifact_roundtrip(mock_s3_service): filename="greeting.txt", ) assert loaded is not None - assert loaded.inline_data is not None - assert loaded.inline_data.data == "Hello, world! ๐ŸŒ".encode("utf-8") + # Text artifacts are now returned via Part.from_text, so consumers + # can check ``part.text`` directly. + assert loaded.text == "Hello, world! ๐ŸŒ" @pytest.mark.asyncio From 0f8188ff1611fa5b210a3219de50a09d18167eff Mon Sep 17 00:00:00 2001 From: Nishar Date: Thu, 7 May 2026 10:16:26 +0545 Subject: [PATCH 6/6] fix(test): add importorskip guard for S3 optional dependencies The S3 artifact tests require aioboto3 and botocore which are only installed with the [s3] extra. CI runs with --extra test only, so these tests fail with ModuleNotFoundError on collection. Adding pytest.importorskip() at module level gracefully skips the entire test file when the S3 dependencies aren't available. --- tests/unittests/artifacts/test_s3_artifact_service.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unittests/artifacts/test_s3_artifact_service.py b/tests/unittests/artifacts/test_s3_artifact_service.py index f6336182..ec919a85 100644 --- a/tests/unittests/artifacts/test_s3_artifact_service.py +++ b/tests/unittests/artifacts/test_s3_artifact_service.py @@ -28,6 +28,8 @@ from google.genai import types import pytest +pytest.importorskip("aioboto3") +pytest.importorskip("botocore") FIXED_DATETIME = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)