From c6ecf79532fc01fcef0f2e7bc0455efb1a88d354 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 6 Mar 2026 16:42:44 +0000 Subject: [PATCH 01/11] feat(server): add v0.3 legacy compatibility for database models --- src/a2a/server/models.py | 62 +++++++- .../server/tasks/test_database_task_store.py | 135 ++++++++++++++++++ 2 files changed, 193 insertions(+), 4 deletions(-) diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index bba12e90..6949b9c3 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from datetime import datetime from typing import TYPE_CHECKING, Any, Generic, TypeVar @@ -11,10 +12,12 @@ def override(func): # noqa: ANN001, ANN201 return func -from google.protobuf.json_format import MessageToDict, ParseDict +from google.protobuf.json_format import MessageToDict, ParseDict, ParseError from google.protobuf.message import Message as ProtoMessage -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError +from a2a.compat.v0_3 import conversions +from a2a.compat.v0_3 import types as types_v03 from a2a.types.a2a_pb2 import Artifact, Message, TaskStatus @@ -81,7 +84,19 @@ def process_result_value( if isinstance(self.pydantic_type, type) and issubclass( self.pydantic_type, ProtoMessage ): - return ParseDict(value, self.pydantic_type()) # type: ignore[return-value] + try: + return ParseDict(value, self.pydantic_type()) # type: ignore[return-value] + except (ParseError, ValueError): + # Try legacy conversion + legacy_map = _get_legacy_conversions() + if self.pydantic_type in legacy_map: + legacy_type, convert_func = legacy_map[self.pydantic_type] + try: + legacy_instance = legacy_type.model_validate(value) + return convert_func(legacy_instance) + except ValidationError: + pass + raise # Assume it's a Pydantic model return self.pydantic_type.model_validate(value) # type: ignore[attr-defined] @@ -130,7 +145,24 @@ def process_result_value( if isinstance(self.pydantic_type, type) and issubclass( self.pydantic_type, ProtoMessage ): - return [ParseDict(item, self.pydantic_type()) for item in value] # type: ignore[misc] + result = [] + legacy_map = _get_legacy_conversions() + legacy_info = legacy_map.get(self.pydantic_type) + + for item in value: + try: + result.append(ParseDict(item, self.pydantic_type())) + except (ParseError, ValueError): # noqa: PERF203 + if legacy_info: + legacy_type, convert_func = legacy_info + try: + legacy_instance = legacy_type.model_validate(item) + result.append(convert_func(legacy_instance)) + continue + except ValidationError: + pass + raise + return result # type: ignore[return-value] # Assume it's a Pydantic model return [self.pydantic_type.model_validate(item) for item in value] # type: ignore[attr-defined] @@ -292,3 +324,25 @@ class PushNotificationConfigModel(PushNotificationConfigMixin, Base): """Default push notification config model with standard table name.""" __tablename__ = 'push_notification_configs' + + +_LEGACY_CONVERSIONS: dict[type, tuple[type[BaseModel], Callable]] | None = None + + +def _get_legacy_conversions() -> dict[type, tuple[type[BaseModel], Callable]]: + """Lazily load and return legacy conversion mapping.""" + global _LEGACY_CONVERSIONS # noqa: PLW0603 + if _LEGACY_CONVERSIONS is None: + try: + # Lazy imports to avoid circular dependencies and unnecessary overhead + _LEGACY_CONVERSIONS = { + TaskStatus: ( + types_v03.TaskStatus, + conversions.to_core_task_status, + ), + Message: (types_v03.Message, conversions.to_core_message), + Artifact: (types_v03.Artifact, conversions.to_core_artifact), + } + except ImportError: + _LEGACY_CONVERSIONS = {} + return _LEGACY_CONVERSIONS diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index b71fd709..98942996 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -683,4 +683,139 @@ async def test_owner_resource_scoping( await task_store.delete('u2-task1', context_user2) +@pytest.mark.asyncio +async def test_get_0_3_task_detailed( + db_store_parameterized: DatabaseTaskStore, +) -> None: + """Test retrieving a detailed legacy v0.3 task from the database. + + This test simulates a database that already contains legacy v0.3 JSON data + (string-based enums, different field names) and verifies that the store + correctly converts it to the modern Protobuf-based Task model. + """ + from a2a.compat.v0_3 import types as types_v03 + from sqlalchemy import insert + + task_id = 'legacy-detailed-1' + owner = 'legacy_user' + context_user = ServerCallContext(user=SampleUser(user_name=owner)) + + # 1. Create a detailed legacy Task using v0.3 models + legacy_task = types_v03.Task( + id=task_id, + context_id='legacy-ctx-1', + status=types_v03.TaskStatus( + state=types_v03.TaskState.working, + message=types_v03.Message( + message_id='msg-status', + role=types_v03.Role.agent, + parts=[ + types_v03.Part( + root=types_v03.TextPart(text='Legacy status message') + ) + ], + ), + timestamp='2023-10-27T10:00:00Z', + ), + history=[ + types_v03.Message( + message_id='msg-1', + role=types_v03.Role.user, + parts=[ + types_v03.Part(root=types_v03.TextPart(text='Hello legacy')) + ], + ), + types_v03.Message( + message_id='msg-2', + role=types_v03.Role.agent, + parts=[ + types_v03.Part( + root=types_v03.DataPart(data={'legacy_key': 'value'}) + ) + ], + ), + ], + artifacts=[ + types_v03.Artifact( + artifact_id='art-1', + name='Legacy Artifact', + parts=[ + types_v03.Part( + root=types_v03.FilePart( + file=types_v03.FileWithUri( + uri='https://example.com/legacy.txt', + mime_type='text/plain', + ) + ) + ) + ], + ) + ], + metadata={'meta_key': 'meta_val'}, + ) + + # 2. Manually insert the legacy data into the database + # We must bypass the store's save() method because it expects Protobuf objects. + async with db_store_parameterized.async_session_maker.begin() as session: + # Pydantic model_dump(mode='json') produces exactly what would be in the legacy DB + legacy_data = legacy_task.model_dump(mode='json') + + stmt = insert(db_store_parameterized.task_model).values( + id=task_id, + context_id=legacy_task.context_id, + owner=owner, + status=legacy_data['status'], + history=legacy_data['history'], + artifacts=legacy_data['artifacts'], + task_metadata=legacy_data['metadata'], + kind='task', + last_updated=datetime.now(timezone.utc), + ) + await session.execute(stmt) + + # 3. Retrieve the task using the standard store.get() + # This will trigger the PydanticType/PydanticListType legacy fallback + retrieved_task = await db_store_parameterized.get(task_id, context_user) + + # 4. Verify the conversion to modern Protobuf + assert retrieved_task is not None + assert retrieved_task.id == task_id + assert retrieved_task.context_id == 'legacy-ctx-1' + + # Check Status & State (The most critical part: string 'working' -> enum TASK_STATE_WORKING) + assert retrieved_task.status.state == TaskState.TASK_STATE_WORKING + assert retrieved_task.status.message.message_id == 'msg-status' + assert retrieved_task.status.message.role == Role.ROLE_AGENT + assert ( + retrieved_task.status.message.parts[0].text == 'Legacy status message' + ) + + # Check History + assert len(retrieved_task.history) == 2 + assert retrieved_task.history[0].message_id == 'msg-1' + assert retrieved_task.history[0].role == Role.ROLE_USER + assert retrieved_task.history[0].parts[0].text == 'Hello legacy' + + assert retrieved_task.history[1].message_id == 'msg-2' + assert retrieved_task.history[1].role == Role.ROLE_AGENT + assert ( + MessageToDict(retrieved_task.history[1].parts[0].data)['legacy_key'] + == 'value' + ) + + # Check Artifacts + assert len(retrieved_task.artifacts) == 1 + assert retrieved_task.artifacts[0].artifact_id == 'art-1' + assert retrieved_task.artifacts[0].name == 'Legacy Artifact' + assert ( + retrieved_task.artifacts[0].parts[0].url + == 'https://example.com/legacy.txt' + ) + + # Check Metadata + assert dict(retrieved_task.metadata) == {'meta_key': 'meta_val'} + + await db_store_parameterized.delete(task_id, context_user) + + # Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml). From 1c7e552a39ed28b93161751d65c19e8b4f3880f5 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 6 Mar 2026 16:47:00 +0000 Subject: [PATCH 02/11] refactor: remove PLW --- src/a2a/server/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index 6949b9c3..f64fb642 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -331,7 +331,7 @@ class PushNotificationConfigModel(PushNotificationConfigMixin, Base): def _get_legacy_conversions() -> dict[type, tuple[type[BaseModel], Callable]]: """Lazily load and return legacy conversion mapping.""" - global _LEGACY_CONVERSIONS # noqa: PLW0603 + global _LEGACY_CONVERSIONS if _LEGACY_CONVERSIONS is None: try: # Lazy imports to avoid circular dependencies and unnecessary overhead From b52c8d41dbd749d998a0fa6724cf27ebd1b74cb4 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 6 Mar 2026 17:01:05 +0000 Subject: [PATCH 03/11] refactor: remove global variable --- src/a2a/server/models.py | 31 ++++++------------- .../server/tasks/test_database_task_store.py | 2 +- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index f64fb642..eb2604e7 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -1,4 +1,3 @@ -from collections.abc import Callable from datetime import datetime from typing import TYPE_CHECKING, Any, Generic, TypeVar @@ -326,23 +325,13 @@ class PushNotificationConfigModel(PushNotificationConfigMixin, Base): __tablename__ = 'push_notification_configs' -_LEGACY_CONVERSIONS: dict[type, tuple[type[BaseModel], Callable]] | None = None - - -def _get_legacy_conversions() -> dict[type, tuple[type[BaseModel], Callable]]: - """Lazily load and return legacy conversion mapping.""" - global _LEGACY_CONVERSIONS - if _LEGACY_CONVERSIONS is None: - try: - # Lazy imports to avoid circular dependencies and unnecessary overhead - _LEGACY_CONVERSIONS = { - TaskStatus: ( - types_v03.TaskStatus, - conversions.to_core_task_status, - ), - Message: (types_v03.Message, conversions.to_core_message), - Artifact: (types_v03.Artifact, conversions.to_core_artifact), - } - except ImportError: - _LEGACY_CONVERSIONS = {} - return _LEGACY_CONVERSIONS +def _get_legacy_conversions() -> dict[type, tuple[type, Any]]: + """Get the mapping of current types to their legacy counterparts and conversion functions.""" + return { + TaskStatus: ( + types_v03.TaskStatus, + conversions.to_core_task_status, + ), + Message: (types_v03.Message, conversions.to_core_message), + Artifact: (types_v03.Artifact, conversions.to_core_artifact), + } diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index 98942996..9514a07a 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -769,7 +769,7 @@ async def test_get_0_3_task_detailed( artifacts=legacy_data['artifacts'], task_metadata=legacy_data['metadata'], kind='task', - last_updated=datetime.now(timezone.utc), + last_updated=None, ) await session.execute(stmt) From 3d983034c92a59e71dbdddf54c38cb8973fa1b08 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Mon, 9 Mar 2026 11:48:37 +0000 Subject: [PATCH 04/11] WIP --- ...database_push_notification_config_store.py | 32 ++++++++-- ...database_push_notification_config_store.py | 60 +++++++++++++++++++ 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index be8f1612..2952a48c 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -1,5 +1,4 @@ # ruff: noqa: PLC0415 -import json import logging from typing import TYPE_CHECKING @@ -27,6 +26,8 @@ "or 'pip install a2a-sdk[sql]'" ) from e +from a2a.compat.v0_3 import conversions +from a2a.compat.v0_3 import types as types_v03 from a2a.server.context import ServerCallContext from a2a.server.models import ( Base, @@ -163,8 +164,25 @@ def _to_orm( config_id=config.id, owner=owner, config_data=data_to_store, + protocol_version='1.0', ) + def _parse_config( + self, json_payload: str, protocol_version: str | None = None + ) -> PushNotificationConfig: + """Parses a JSON payload into a PushNotificationConfig proto. + + Uses protocol_version to decide between modern parsing and legacy fallback. + """ + if protocol_version == '1.0': + return Parse(json_payload, PushNotificationConfig()) + + # Legacy case: no version or older + legacy_instance = types_v03.PushNotificationConfig.model_validate_json( + json_payload + ) + return conversions.to_core_push_notification_config(legacy_instance) + def _from_orm( self, model_instance: PushNotificationConfigModel ) -> PushNotificationConfig: @@ -181,10 +199,11 @@ def _from_orm( try: decrypted_payload = self._fernet.decrypt(payload) - return Parse( - decrypted_payload.decode('utf-8'), PushNotificationConfig() + return self._parse_config( + decrypted_payload.decode('utf-8'), + model_instance.protocol_version, ) - except (json.JSONDecodeError, Exception) as e: + except Exception as e: if isinstance(e, InvalidToken): # Decryption failed. This could be because the data is not encrypted. # We'll log a warning and try to parse it as plain JSON as a fallback. @@ -214,7 +233,10 @@ def _from_orm( if isinstance(payload, bytes) else payload ) - return Parse(payload_str, PushNotificationConfig()) + return self._parse_config( + payload_str, model_instance.protocol_version + ) + except Exception as e: if self._fernet: logger.exception( diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index 042ff800..b71d11a3 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -707,3 +707,63 @@ async def test_owner_resource_scoping( # Cleanup remaining await config_store.delete_info('task1', context=context_user1) await config_store.delete_info('task1', context=context_user2) + + +@pytest.mark.asyncio +async def test_get_0_3_push_notification_config_detailed( + db_store_parameterized: DatabasePushNotificationConfigStore, +) -> None: + """Test retrieving a legacy v0.3 push notification config from the database. + + This test simulates a database that already contains legacy v0.3 JSON data + and verifies that the store correctly converts it to the modern Protobuf model. + """ + from a2a.compat.v0_3 import types as types_v03 + from sqlalchemy import insert + + task_id = 'legacy-push-1' + config_id = 'config-legacy-1' + owner = 'legacy_user' + context_user = ServerCallContext(user=SampleUser(user_name=owner)) + + # 1. Create a legacy PushNotificationConfig using v0.3 models + legacy_config = types_v03.PushNotificationConfig( + id=config_id, + url='https://example.com/push', + token='legacy-token', + authentication=types_v03.PushNotificationAuthenticationInfo( + schemes=['bearer'], + credentials='legacy-creds', + ), + ) + + # 2. Manually insert the legacy data into the database + # For PushNotificationConfigStore, the data is stored in the config_data column. + async with db_store_parameterized.async_session_maker.begin() as session: + # Pydantic model_dump_json() produces the JSON that we'll store. + # Note: DatabasePushNotificationConfigStore normally encrypts this, but here + # we'll store it as plain JSON bytes to simulate legacy data. + legacy_json = legacy_config.model_dump_json() + + stmt = insert(db_store_parameterized.config_model).values( + task_id=task_id, + config_id=config_id, + owner=owner, + config_data=legacy_json.encode('utf-8'), + ) + await session.execute(stmt) + + # 3. Retrieve the config using the standard store.get_info() + # This will trigger the DatabasePushNotificationConfigStore._from_orm legacy fallback + retrieved_configs = await db_store_parameterized.get_info( + task_id, context_user + ) + + # 4. Verify the conversion to modern Protobuf + assert len(retrieved_configs) == 1 + retrieved = retrieved_configs[0] + assert retrieved.id == config_id + assert retrieved.url == 'https://example.com/push' + assert retrieved.token == 'legacy-token' + assert retrieved.authentication.scheme == 'bearer' + assert retrieved.authentication.credentials == 'legacy-creds' From 4e03b9d3c75e0514dab9bfd1b754e33b27092dbb Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 10 Mar 2026 13:57:01 +0000 Subject: [PATCH 05/11] refactor: replace PydanticType/PydanticListType with explicit JSON serialization/deserialization for protobuf models and add protocol versioning to task storage. --- src/a2a/server/models.py | 174 +----------------- ...database_push_notification_config_store.py | 35 ++-- src/a2a/server/tasks/database_task_store.py | 64 ++++--- ...database_push_notification_config_store.py | 23 ++- .../server/tasks/test_database_task_store.py | 8 +- tests/server/test_models.py | 67 ------- 6 files changed, 90 insertions(+), 281 deletions(-) diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index 52d4044e..0cad562d 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -11,26 +11,14 @@ def override(func): # noqa: ANN001, ANN201 return func -from google.protobuf.json_format import MessageToDict, ParseDict, ParseError -from google.protobuf.message import Message as ProtoMessage -from pydantic import BaseModel, ValidationError - -from a2a.compat.v0_3 import conversions -from a2a.compat.v0_3 import types as types_v03 -from a2a.types.a2a_pb2 import Artifact, Message, TaskStatus - - try: - from sqlalchemy import JSON, DateTime, Dialect, Index, LargeBinary, String + from sqlalchemy import JSON, DateTime, Index, LargeBinary, String from sqlalchemy.orm import ( DeclarativeBase, Mapped, declared_attr, mapped_column, ) - from sqlalchemy.types import ( - TypeDecorator, - ) except ImportError as e: raise ImportError( 'Database models require SQLAlchemy. ' @@ -42,130 +30,6 @@ def override(func): # noqa: ANN001, ANN201 ) from e -T = TypeVar('T') - - -class PydanticType(TypeDecorator[T], Generic[T]): - """SQLAlchemy type that handles Pydantic model and Protobuf message serialization.""" - - impl = JSON - cache_ok = True - - def __init__(self, pydantic_type: type[T], **kwargs: dict[str, Any]): - """Initialize the PydanticType. - - Args: - pydantic_type: The Pydantic model or Protobuf message type to handle. - **kwargs: Additional arguments for TypeDecorator. - """ - self.pydantic_type = pydantic_type - super().__init__(**kwargs) - - def process_bind_param( - self, value: T | None, dialect: Dialect - ) -> dict[str, Any] | None: - """Convert Pydantic model or Protobuf message to a JSON-serializable dictionary for the database.""" - if value is None: - return None - if isinstance(value, ProtoMessage): - return MessageToDict(value, preserving_proto_field_name=False) - if isinstance(value, BaseModel): - return value.model_dump(mode='json') - return value # type: ignore[return-value] - - def process_result_value( - self, value: dict[str, Any] | None, dialect: Dialect - ) -> T | None: - """Convert a JSON-like dictionary from the database back to a Pydantic model or Protobuf message.""" - if value is None: - return None - # Check if it's a protobuf message class - if isinstance(self.pydantic_type, type) and issubclass( - self.pydantic_type, ProtoMessage - ): - try: - return ParseDict(value, self.pydantic_type()) # type: ignore[return-value] - except (ParseError, ValueError): - # Try legacy conversion - legacy_map = _get_legacy_conversions() - if self.pydantic_type in legacy_map: - legacy_type, convert_func = legacy_map[self.pydantic_type] - try: - legacy_instance = legacy_type.model_validate(value) - return convert_func(legacy_instance) - except ValidationError: - pass - raise - # Assume it's a Pydantic model - return self.pydantic_type.model_validate(value) # type: ignore[attr-defined] - - -class PydanticListType(TypeDecorator, Generic[T]): - """SQLAlchemy type that handles lists of Pydantic models or Protobuf messages.""" - - impl = JSON - cache_ok = True - - def __init__(self, pydantic_type: type[T], **kwargs: dict[str, Any]): - """Initialize the PydanticListType. - - Args: - pydantic_type: The Pydantic model or Protobuf message type for items in the list. - **kwargs: Additional arguments for TypeDecorator. - """ - self.pydantic_type = pydantic_type - super().__init__(**kwargs) - - def process_bind_param( - self, value: list[T] | None, dialect: Dialect - ) -> list[dict[str, Any]] | None: - """Convert a list of Pydantic models or Protobuf messages to a JSON-serializable list for the DB.""" - if value is None: - return None - result: list[dict[str, Any]] = [] - for item in value: - if isinstance(item, ProtoMessage): - result.append( - MessageToDict(item, preserving_proto_field_name=False) - ) - elif isinstance(item, BaseModel): - result.append(item.model_dump(mode='json')) - else: - result.append(item) # type: ignore[arg-type] - return result - - def process_result_value( - self, value: list[dict[str, Any]] | None, dialect: Dialect - ) -> list[T] | None: - """Convert a JSON-like list from the DB back to a list of Pydantic models or Protobuf messages.""" - if value is None: - return None - # Check if it's a protobuf message class - if isinstance(self.pydantic_type, type) and issubclass( - self.pydantic_type, ProtoMessage - ): - result = [] - legacy_map = _get_legacy_conversions() - legacy_info = legacy_map.get(self.pydantic_type) - - for item in value: - try: - result.append(ParseDict(item, self.pydantic_type())) - except (ParseError, ValueError): # noqa: PERF203 - if legacy_info: - legacy_type, convert_func = legacy_info - try: - legacy_instance = legacy_type.model_validate(item) - result.append(convert_func(legacy_instance)) - continue - except ValidationError: - pass - raise - return result # type: ignore[return-value] - # Assume it's a Pydantic model - return [self.pydantic_type.model_validate(item) for item in value] # type: ignore[attr-defined] - - # Base class for all database models class Base(DeclarativeBase): """Base class for declarative models in A2A SDK.""" @@ -184,25 +48,17 @@ class TaskMixin: last_updated: Mapped[datetime | None] = mapped_column( DateTime, nullable=True ) - - # Properly typed Pydantic fields with automatic serialization - status: Mapped[TaskStatus] = mapped_column(PydanticType(TaskStatus)) - artifacts: Mapped[list[Artifact] | None] = mapped_column( - PydanticListType(Artifact), nullable=True - ) - history: Mapped[list[Message] | None] = mapped_column( - PydanticListType(Message), nullable=True - ) + status: Mapped[Any] = mapped_column(JSON) + artifacts: Mapped[list[Any] | None] = mapped_column(JSON, nullable=True) + history: Mapped[list[Any] | None] = mapped_column(JSON, nullable=True) protocol_version: Mapped[str | None] = mapped_column( String(16), nullable=True ) - # Using declared_attr to avoid conflict with Pydantic's metadata - @declared_attr - @classmethod - def task_metadata(cls) -> Mapped[dict[str, Any] | None]: - """Define the 'metadata' column, avoiding name conflicts with Pydantic.""" - return mapped_column(JSON, nullable=True, name='metadata') + # Using 'task_metadata' to avoid conflict with SQLAlchemy's 'Base.metadata' + task_metadata: Mapped[dict[str, Any] | None] = mapped_column( + JSON, nullable=True, name='metadata' + ) @override def __repr__(self) -> str: @@ -329,15 +185,3 @@ class PushNotificationConfigModel(PushNotificationConfigMixin, Base): """Default push notification config model with standard table name.""" __tablename__ = 'push_notification_configs' - - -def _get_legacy_conversions() -> dict[type, tuple[type, Any]]: - """Get the mapping of current types to their legacy counterparts and conversion functions.""" - return { - TaskStatus: ( - types_v03.TaskStatus, - conversions.to_core_task_status, - ), - Message: (types_v03.Message, conversions.to_core_message), - Artifact: (types_v03.Artifact, conversions.to_core_artifact), - } diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index 8cb6f929..ebfdf01e 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -167,22 +167,6 @@ def _to_orm( protocol_version='1.0', ) - def _parse_config( - self, json_payload: str, protocol_version: str | None = None - ) -> PushNotificationConfig: - """Parses a JSON payload into a PushNotificationConfig proto. - - Uses protocol_version to decide between modern parsing and legacy fallback. - """ - if protocol_version == '1.0': - return Parse(json_payload, PushNotificationConfig()) - - # Legacy case: no version or older - legacy_instance = types_v03.PushNotificationConfig.model_validate_json( - json_payload - ) - return conversions.to_core_push_notification_config(legacy_instance) - def _from_orm( self, model_instance: PushNotificationConfigModel ) -> TaskPushNotificationConfig: @@ -355,3 +339,22 @@ async def delete_info( owner, config_id, ) + + def _parse_config( + self, json_payload: str, protocol_version: str | None = None + ) -> TaskPushNotificationConfig: + """Parses a JSON payload into a TaskPushNotificationConfig proto. + + Uses protocol_version to decide between modern parsing and legacy fallback. + """ + if protocol_version == '1.0': + return Parse(json_payload, TaskPushNotificationConfig()) + + legacy_instance = ( + types_v03.TaskPushNotificationConfig.model_validate_json( + json_payload + ) + ) + return conversions.to_core_task_push_notification_config( + legacy_instance + ) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index 4f7b1ecd..1cb859c7 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -31,8 +31,10 @@ "or 'pip install a2a-sdk[sql]'" ) from e -from google.protobuf.json_format import MessageToDict +from google.protobuf.json_format import MessageToDict, ParseDict +from a2a.compat.v0_3 import conversions +from a2a.compat.v0_3 import types as types_v03 from a2a.server.context import ServerCallContext from a2a.server.models import Base, TaskModel, create_task_model from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope @@ -117,8 +119,7 @@ async def _ensure_initialized(self) -> None: def _to_orm(self, task: Task, owner: str) -> TaskModel: """Maps a Proto Task to a SQLAlchemy TaskModel instance.""" - # Pass proto objects directly - PydanticType/PydanticListType - # handle serialization via process_bind_param + task_dict = MessageToDict(task) return self.task_model( id=task.id, context_id=task.context_id, @@ -129,33 +130,52 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel: if task.HasField('status') and task.status.HasField('timestamp') else None ), - status=task.status if task.HasField('status') else None, - artifacts=list(task.artifacts) if task.artifacts else [], - history=list(task.history) if task.history else [], - task_metadata=( - MessageToDict(task.metadata) if task.metadata.fields else None - ), + status=task_dict.get('status'), + artifacts=task_dict.get('artifacts', []), + history=task_dict.get('history', []), + task_metadata=task_dict.get('metadata'), + protocol_version='1.0', ) def _from_orm(self, task_model: TaskModel) -> Task: """Maps a SQLAlchemy TaskModel to a Proto Task instance.""" - # PydanticType/PydanticListType already deserialize to proto objects - # via process_result_value, so we can construct the Task directly + # Data is stored as raw JSON (dicts/lists), so we parse it manually task = Task( id=task_model.id, context_id=task_model.context_id, ) - if task_model.status: - task.status.CopyFrom(task_model.status) - if task_model.artifacts: - task.artifacts.extend(task_model.artifacts) - if task_model.history: - task.history.extend(task_model.history) - if task_model.task_metadata: - task.metadata.update( - cast('dict[str, Any]', task_model.task_metadata) - ) - return task + if task_model.protocol_version == '1.0': + if task_model.status: + ParseDict( + cast('dict[str, Any]', task_model.status), task.status + ) + if task_model.artifacts: + for art_dict in cast( + 'list[dict[str, Any]]', task_model.artifacts + ): + art = task.artifacts.add() + ParseDict(art_dict, art) + if task_model.history: + for msg_dict in cast( + 'list[dict[str, Any]]', task_model.history + ): + msg = task.history.add() + ParseDict(msg_dict, msg) + if task_model.task_metadata: + task.metadata.update( + cast('dict[str, Any]', task_model.task_metadata) + ) + return task + # Reconstruct legacy task from raw columns (which are dicts/lists here) + legacy_task = types_v03.Task( + id=task_model.id, + context_id=task_model.context_id, + status=cast('dict[str, Any]', task_model.status), + artifacts=cast('list[dict[str, Any]]', task_model.artifacts), + history=cast('list[dict[str, Any]]', task_model.history), + metadata=cast('dict[str, Any]', task_model.task_metadata), + ) + return conversions.to_core_task(legacy_task) async def save( self, task: Task, context: ServerCallContext | None = None diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index df1ac8b3..8236112e 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -5,6 +5,8 @@ import pytest from a2a.server.context import ServerCallContext from a2a.auth.user import User +from a2a.compat.v0_3 import types as types_v03 +from sqlalchemy import insert # Skip entire test module if SQLAlchemy is not installed @@ -730,22 +732,22 @@ async def test_get_0_3_push_notification_config_detailed( This test simulates a database that already contains legacy v0.3 JSON data and verifies that the store correctly converts it to the modern Protobuf model. """ - from a2a.compat.v0_3 import types as types_v03 - from sqlalchemy import insert - task_id = 'legacy-push-1' config_id = 'config-legacy-1' owner = 'legacy_user' context_user = ServerCallContext(user=SampleUser(user_name=owner)) # 1. Create a legacy PushNotificationConfig using v0.3 models - legacy_config = types_v03.PushNotificationConfig( - id=config_id, - url='https://example.com/push', - token='legacy-token', - authentication=types_v03.PushNotificationAuthenticationInfo( - schemes=['bearer'], - credentials='legacy-creds', + legacy_config = types_v03.TaskPushNotificationConfig( + task_id=task_id, + push_notification_config=types_v03.PushNotificationConfig( + id=config_id, + url='https://example.com/push', + token='legacy-token', + authentication=types_v03.PushNotificationAuthenticationInfo( + schemes=['bearer'], + credentials='legacy-creds', + ), ), ) @@ -774,6 +776,7 @@ async def test_get_0_3_push_notification_config_detailed( # 4. Verify the conversion to modern Protobuf assert len(retrieved_configs) == 1 retrieved = retrieved_configs[0] + assert retrieved.task_id == task_id assert retrieved.id == config_id assert retrieved.url == 'https://example.com/push' assert retrieved.token == 'legacy-token' diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index 9514a07a..dbc6bbd6 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -774,7 +774,7 @@ async def test_get_0_3_task_detailed( await session.execute(stmt) # 3. Retrieve the task using the standard store.get() - # This will trigger the PydanticType/PydanticListType legacy fallback + # This will trigger conversion from legacy to 1.0 format in _from_orm method retrieved_task = await db_store_parameterized.get(task_id, context_user) # 4. Verify the conversion to modern Protobuf @@ -815,6 +815,12 @@ async def test_get_0_3_task_detailed( # Check Metadata assert dict(retrieved_task.metadata) == {'meta_key': 'meta_val'} + retrieved_tasks = await db_store_parameterized.list( + ListTasksRequest(), context_user + ) + assert retrieved_tasks is not None + assert retrieved_tasks.tasks == [retrieved_task] + await db_store_parameterized.delete(task_id, context_user) diff --git a/tests/server/test_models.py b/tests/server/test_models.py index 08d700ce..bfaaed9d 100644 --- a/tests/server/test_models.py +++ b/tests/server/test_models.py @@ -5,76 +5,9 @@ from sqlalchemy.orm import DeclarativeBase from a2a.server.models import ( - PydanticListType, - PydanticType, create_push_notification_config_model, create_task_model, ) -from a2a.types.a2a_pb2 import Artifact, Part, TaskState, TaskStatus - - -class TestPydanticType: - """Tests for PydanticType SQLAlchemy type decorator.""" - - def test_process_bind_param_with_pydantic_model(self): - pydantic_type = PydanticType(TaskStatus) - status = TaskStatus(state=TaskState.TASK_STATE_WORKING) - dialect = MagicMock() - - result = pydantic_type.process_bind_param(status, dialect) - assert result is not None - assert result['state'] == 'TASK_STATE_WORKING' - # message field is optional and not set - - def test_process_bind_param_with_none(self): - pydantic_type = PydanticType(TaskStatus) - dialect = MagicMock() - - result = pydantic_type.process_bind_param(None, dialect) - assert result is None - - def test_process_result_value(self): - pydantic_type = PydanticType(TaskStatus) - dialect = MagicMock() - - result = pydantic_type.process_result_value( - {'state': 'TASK_STATE_COMPLETED'}, dialect - ) - assert isinstance(result, TaskStatus) - assert result.state == TaskState.TASK_STATE_COMPLETED - - -class TestPydanticListType: - """Tests for PydanticListType SQLAlchemy type decorator.""" - - def test_process_bind_param_with_list(self): - pydantic_list_type = PydanticListType(Artifact) - artifacts = [ - Artifact(artifact_id='1', parts=[Part(text='Hello')]), - Artifact(artifact_id='2', parts=[Part(text='World')]), - ] - dialect = MagicMock() - - result = pydantic_list_type.process_bind_param(artifacts, dialect) - assert result is not None - assert len(result) == 2 - assert result[0]['artifactId'] == '1' # JSON mode uses camelCase - assert result[1]['artifactId'] == '2' - - def test_process_result_value_with_list(self): - pydantic_list_type = PydanticListType(Artifact) - dialect = MagicMock() - data = [ - {'artifactId': '1', 'parts': [{'text': 'Hello'}]}, - {'artifactId': '2', 'parts': [{'text': 'World'}]}, - ] - - result = pydantic_list_type.process_result_value(data, dialect) - assert result is not None - assert len(result) == 2 - assert all(isinstance(art, Artifact) for art in result) - assert result[0].artifact_id == '1' - assert result[1].artifact_id == '2' def test_create_task_model(): From 11598fa2076f0c3e3eec81f3b5572c1c53b0cd41 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 10 Mar 2026 14:41:24 +0000 Subject: [PATCH 06/11] refactor: simplify task deserialization in `_from_orm` by using `ParseDict` and `model_validate` for mapping ORM models to task types. --- src/a2a/server/tasks/database_task_store.py | 45 +++++---------------- 1 file changed, 10 insertions(+), 35 deletions(-) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index 1cb859c7..06fd03c1 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -139,42 +139,17 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel: def _from_orm(self, task_model: TaskModel) -> Task: """Maps a SQLAlchemy TaskModel to a Proto Task instance.""" - # Data is stored as raw JSON (dicts/lists), so we parse it manually - task = Task( - id=task_model.id, - context_id=task_model.context_id, - ) + task_dict = { + 'id': task_model.id, + 'context_id': task_model.context_id, + 'status': task_model.status, + 'artifacts': task_model.artifacts, + 'history': task_model.history, + 'metadata': task_model.task_metadata, + } if task_model.protocol_version == '1.0': - if task_model.status: - ParseDict( - cast('dict[str, Any]', task_model.status), task.status - ) - if task_model.artifacts: - for art_dict in cast( - 'list[dict[str, Any]]', task_model.artifacts - ): - art = task.artifacts.add() - ParseDict(art_dict, art) - if task_model.history: - for msg_dict in cast( - 'list[dict[str, Any]]', task_model.history - ): - msg = task.history.add() - ParseDict(msg_dict, msg) - if task_model.task_metadata: - task.metadata.update( - cast('dict[str, Any]', task_model.task_metadata) - ) - return task - # Reconstruct legacy task from raw columns (which are dicts/lists here) - legacy_task = types_v03.Task( - id=task_model.id, - context_id=task_model.context_id, - status=cast('dict[str, Any]', task_model.status), - artifacts=cast('list[dict[str, Any]]', task_model.artifacts), - history=cast('list[dict[str, Any]]', task_model.history), - metadata=cast('dict[str, Any]', task_model.task_metadata), - ) + return ParseDict(task_dict, Task()) + legacy_task = types_v03.Task.model_validate(task_dict) return conversions.to_core_task(legacy_task) async def save( From 350539b5333e35b246264389930aceb948ad8eaa Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 10 Mar 2026 14:44:00 +0000 Subject: [PATCH 07/11] refactor: remove unused Any and cast imports from database_task_store.py --- src/a2a/server/tasks/database_task_store.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index 06fd03c1..bbb7457b 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -1,7 +1,6 @@ import logging from datetime import datetime, timezone -from typing import Any, cast try: From b8504cfd8dba2162ec7d1f722497a97e0d7b93fb Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 10 Mar 2026 15:12:12 +0000 Subject: [PATCH 08/11] Refactor: Move `types_v03` and `sqlalchemy.insert` imports to module level and add new database files. --- tests/server/tasks/test_database_task_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index dbc6bbd6..3e2a91e2 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -8,6 +8,8 @@ from _pytest.mark.structures import ParameterSet from a2a.types.a2a_pb2 import ListTasksRequest +from a2a.compat.v0_3 import types as types_v03 +from sqlalchemy import insert # Skip entire test module if SQLAlchemy is not installed @@ -693,8 +695,6 @@ async def test_get_0_3_task_detailed( (string-based enums, different field names) and verifies that the store correctly converts it to the modern Protobuf-based Task model. """ - from a2a.compat.v0_3 import types as types_v03 - from sqlalchemy import insert task_id = 'legacy-detailed-1' owner = 'legacy_user' From 335ea4159bf4e5210bd5719765ec6c62e7e0fbe6 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 10 Mar 2026 15:20:14 +0000 Subject: [PATCH 09/11] test: clarify legacy data conversion comments in database store tests and add new test database files --- .../tasks/test_database_push_notification_config_store.py | 2 +- tests/server/tasks/test_database_task_store.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index 8236112e..d4d08da1 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -768,7 +768,7 @@ async def test_get_0_3_push_notification_config_detailed( await session.execute(stmt) # 3. Retrieve the config using the standard store.get_info() - # This will trigger the DatabasePushNotificationConfigStore._from_orm legacy fallback + # This will trigger the DatabasePushNotificationConfigStore._from_orm legacy conversion retrieved_configs = await db_store_parameterized.get_info( task_id, context_user ) diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index 3e2a91e2..781c46c7 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -774,7 +774,7 @@ async def test_get_0_3_task_detailed( await session.execute(stmt) # 3. Retrieve the task using the standard store.get() - # This will trigger conversion from legacy to 1.0 format in _from_orm method + # This will trigger conversion from legacy to 1.0 format in the _from_orm method retrieved_task = await db_store_parameterized.get(task_id, context_user) # 4. Verify the conversion to modern Protobuf From 42aada9e6c3af91527618041cfd0a32361417bc9 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 10 Mar 2026 17:05:39 +0000 Subject: [PATCH 10/11] refactor: improve SQLAlchemy ORM to Protobuf message mapping for JSON fields and refine model column definitions. --- src/a2a/server/models.py | 20 ++++--- src/a2a/server/tasks/database_task_store.py | 59 +++++++++++++++------ 2 files changed, 56 insertions(+), 23 deletions(-) diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index 0cad562d..19aab72d 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -48,17 +48,23 @@ class TaskMixin: last_updated: Mapped[datetime | None] = mapped_column( DateTime, nullable=True ) - status: Mapped[Any] = mapped_column(JSON) - artifacts: Mapped[list[Any] | None] = mapped_column(JSON, nullable=True) - history: Mapped[list[Any] | None] = mapped_column(JSON, nullable=True) + status: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) + artifacts: Mapped[list[dict[str, Any]] | None] = mapped_column( + JSON, nullable=True + ) + history: Mapped[list[dict[str, Any]] | None] = mapped_column( + JSON, nullable=True + ) protocol_version: Mapped[str | None] = mapped_column( String(16), nullable=True ) - # Using 'task_metadata' to avoid conflict with SQLAlchemy's 'Base.metadata' - task_metadata: Mapped[dict[str, Any] | None] = mapped_column( - JSON, nullable=True, name='metadata' - ) + # Using declared_attr to avoid conflict with Pydantic's metadata + @declared_attr + @classmethod + def task_metadata(cls) -> Mapped[dict[str, Any] | None]: + """Define the 'metadata' column, avoiding name conflicts with Pydantic.""" + return mapped_column(JSON, nullable=True, name='metadata') @override def __repr__(self) -> str: diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index bbb7457b..582a8c25 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -1,6 +1,7 @@ import logging from datetime import datetime, timezone +from typing import Any, cast try: @@ -118,7 +119,6 @@ async def _ensure_initialized(self) -> None: def _to_orm(self, task: Task, owner: str) -> TaskModel: """Maps a Proto Task to a SQLAlchemy TaskModel instance.""" - task_dict = MessageToDict(task) return self.task_model( id=task.id, context_id=task.context_id, @@ -126,29 +126,56 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel: owner=owner, last_updated=( task.status.timestamp.ToDatetime() - if task.HasField('status') and task.status.HasField('timestamp') + if task.status.HasField('timestamp') else None ), - status=task_dict.get('status'), - artifacts=task_dict.get('artifacts', []), - history=task_dict.get('history', []), - task_metadata=task_dict.get('metadata'), + status=MessageToDict(task.status), + artifacts=[MessageToDict(artifact) for artifact in task.artifacts], + history=[MessageToDict(history) for history in task.history], + task_metadata=( + MessageToDict(task.metadata) if task.metadata.fields else None + ), protocol_version='1.0', ) def _from_orm(self, task_model: TaskModel) -> Task: """Maps a SQLAlchemy TaskModel to a Proto Task instance.""" - task_dict = { - 'id': task_model.id, - 'context_id': task_model.context_id, - 'status': task_model.status, - 'artifacts': task_model.artifacts, - 'history': task_model.history, - 'metadata': task_model.task_metadata, - } if task_model.protocol_version == '1.0': - return ParseDict(task_dict, Task()) - legacy_task = types_v03.Task.model_validate(task_dict) + task = Task( + id=task_model.id, + context_id=task_model.context_id, + ) + if task_model.status: + ParseDict( + cast('dict[str, Any]', task_model.status), task.status + ) + if task_model.artifacts: + for art_dict in cast( + 'list[dict[str, Any]]', task_model.artifacts + ): + art = task.artifacts.add() + ParseDict(art_dict, art) + if task_model.history: + for msg_dict in cast( + 'list[dict[str, Any]]', task_model.history + ): + msg = task.history.add() + ParseDict(msg_dict, msg) + if task_model.task_metadata: + task.metadata.update( + cast('dict[str, Any]', task_model.task_metadata) + ) + return task + + # Legacy conversion + legacy_task = types_v03.Task( + id=task_model.id, + context_id=task_model.context_id, + status=task_model.status, + artifacts=task_model.artifacts or [], + history=task_model.history or [], + metadata=task_model.task_metadata or {}, + ) return conversions.to_core_task(legacy_task) async def save( From 615f608385be190343cfabd445ce302f008fcca7 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Tue, 10 Mar 2026 17:26:34 +0000 Subject: [PATCH 11/11] refactor: use `model_validate` for legacy task conversion to improve type handling and Pyright compatibility. --- src/a2a/server/tasks/database_task_store.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index 582a8c25..255145f8 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -168,13 +168,16 @@ def _from_orm(self, task_model: TaskModel) -> Task: return task # Legacy conversion - legacy_task = types_v03.Task( - id=task_model.id, - context_id=task_model.context_id, - status=task_model.status, - artifacts=task_model.artifacts or [], - history=task_model.history or [], - metadata=task_model.task_metadata or {}, + # Reconstruct legacy task using model_validate to handle dicts and resolve Pyright issues + legacy_task = types_v03.Task.model_validate( + { + 'id': task_model.id, + 'context_id': task_model.context_id, + 'status': task_model.status, + 'artifacts': task_model.artifacts, + 'history': task_model.history, + 'metadata': task_model.task_metadata, + } ) return conversions.to_core_task(legacy_task)