diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index 62771541..19aab72d 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -11,24 +11,14 @@ def override(func): # noqa: ANN001, ANN201 return func -from google.protobuf.json_format import MessageToDict, ParseDict -from google.protobuf.message import Message as ProtoMessage -from pydantic import BaseModel - -from a2a.types.a2a_pb2 import Artifact, Message, TaskStatus - - try: - from sqlalchemy import JSON, DateTime, Dialect, Index, LargeBinary, String + from sqlalchemy import JSON, DateTime, Index, LargeBinary, String from sqlalchemy.orm import ( DeclarativeBase, Mapped, declared_attr, mapped_column, ) - from sqlalchemy.types import ( - TypeDecorator, - ) except ImportError as e: raise ImportError( 'Database models require SQLAlchemy. ' @@ -40,101 +30,6 @@ def override(func): # noqa: ANN001, ANN201 ) from e -T = TypeVar('T') - - -class PydanticType(TypeDecorator[T], Generic[T]): - """SQLAlchemy type that handles Pydantic model and Protobuf message serialization.""" - - impl = JSON - cache_ok = True - - def __init__(self, pydantic_type: type[T], **kwargs: dict[str, Any]): - """Initialize the PydanticType. - - Args: - pydantic_type: The Pydantic model or Protobuf message type to handle. - **kwargs: Additional arguments for TypeDecorator. - """ - self.pydantic_type = pydantic_type - super().__init__(**kwargs) - - def process_bind_param( - self, value: T | None, dialect: Dialect - ) -> dict[str, Any] | None: - """Convert Pydantic model or Protobuf message to a JSON-serializable dictionary for the database.""" - if value is None: - return None - if isinstance(value, ProtoMessage): - return MessageToDict(value, preserving_proto_field_name=False) - if isinstance(value, BaseModel): - return value.model_dump(mode='json') - return value # type: ignore[return-value] - - def process_result_value( - self, value: dict[str, Any] | None, dialect: Dialect - ) -> T | None: - """Convert a JSON-like dictionary from the database back to a Pydantic model or Protobuf message.""" - if value is None: - return None - # Check if it's a protobuf message class - if isinstance(self.pydantic_type, type) and issubclass( - self.pydantic_type, ProtoMessage - ): - return ParseDict(value, self.pydantic_type()) # type: ignore[return-value] - # Assume it's a Pydantic model - return self.pydantic_type.model_validate(value) # type: ignore[attr-defined] - - -class PydanticListType(TypeDecorator, Generic[T]): - """SQLAlchemy type that handles lists of Pydantic models or Protobuf messages.""" - - impl = JSON - cache_ok = True - - def __init__(self, pydantic_type: type[T], **kwargs: dict[str, Any]): - """Initialize the PydanticListType. - - Args: - pydantic_type: The Pydantic model or Protobuf message type for items in the list. - **kwargs: Additional arguments for TypeDecorator. - """ - self.pydantic_type = pydantic_type - super().__init__(**kwargs) - - def process_bind_param( - self, value: list[T] | None, dialect: Dialect - ) -> list[dict[str, Any]] | None: - """Convert a list of Pydantic models or Protobuf messages to a JSON-serializable list for the DB.""" - if value is None: - return None - result: list[dict[str, Any]] = [] - for item in value: - if isinstance(item, ProtoMessage): - result.append( - MessageToDict(item, preserving_proto_field_name=False) - ) - elif isinstance(item, BaseModel): - result.append(item.model_dump(mode='json')) - else: - result.append(item) # type: ignore[arg-type] - return result - - def process_result_value( - self, value: list[dict[str, Any]] | None, dialect: Dialect - ) -> list[T] | None: - """Convert a JSON-like list from the DB back to a list of Pydantic models or Protobuf messages.""" - if value is None: - return None - # Check if it's a protobuf message class - if isinstance(self.pydantic_type, type) and issubclass( - self.pydantic_type, ProtoMessage - ): - return [ParseDict(item, self.pydantic_type()) for item in value] # type: ignore[misc] - # Assume it's a Pydantic model - return [self.pydantic_type.model_validate(item) for item in value] # type: ignore[attr-defined] - - # Base class for all database models class Base(DeclarativeBase): """Base class for declarative models in A2A SDK.""" @@ -153,14 +48,12 @@ class TaskMixin: last_updated: Mapped[datetime | None] = mapped_column( DateTime, nullable=True ) - - # Properly typed Pydantic fields with automatic serialization - status: Mapped[TaskStatus] = mapped_column(PydanticType(TaskStatus)) - artifacts: Mapped[list[Artifact] | None] = mapped_column( - PydanticListType(Artifact), nullable=True + status: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) + artifacts: Mapped[list[dict[str, Any]] | None] = mapped_column( + JSON, nullable=True ) - history: Mapped[list[Message] | None] = mapped_column( - PydanticListType(Message), nullable=True + history: Mapped[list[dict[str, Any]] | None] = mapped_column( + JSON, nullable=True ) protocol_version: Mapped[str | None] = mapped_column( String(16), nullable=True diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index 17eeba1d..ebfdf01e 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -1,5 +1,4 @@ # ruff: noqa: PLC0415 -import json import logging from typing import TYPE_CHECKING @@ -27,6 +26,8 @@ "or 'pip install a2a-sdk[sql]'" ) from e +from a2a.compat.v0_3 import conversions +from a2a.compat.v0_3 import types as types_v03 from a2a.server.context import ServerCallContext from a2a.server.models import ( Base, @@ -163,6 +164,7 @@ def _to_orm( config_id=config.id, owner=owner, config_data=data_to_store, + protocol_version='1.0', ) def _from_orm( @@ -181,11 +183,11 @@ def _from_orm( try: decrypted_payload = self._fernet.decrypt(payload) - return Parse( + return self._parse_config( decrypted_payload.decode('utf-8'), - TaskPushNotificationConfig(), + model_instance.protocol_version, ) - except (json.JSONDecodeError, Exception) as e: + except Exception as e: if isinstance(e, InvalidToken): # Decryption failed. This could be because the data is not encrypted. # We'll log a warning and try to parse it as plain JSON as a fallback. @@ -215,7 +217,10 @@ def _from_orm( if isinstance(payload, bytes) else payload ) - return Parse(payload_str, TaskPushNotificationConfig()) + return self._parse_config( + payload_str, model_instance.protocol_version + ) + except Exception as e: if self._fernet: logger.exception( @@ -334,3 +339,22 @@ async def delete_info( owner, config_id, ) + + def _parse_config( + self, json_payload: str, protocol_version: str | None = None + ) -> TaskPushNotificationConfig: + """Parses a JSON payload into a TaskPushNotificationConfig proto. + + Uses protocol_version to decide between modern parsing and legacy fallback. + """ + if protocol_version == '1.0': + return Parse(json_payload, TaskPushNotificationConfig()) + + legacy_instance = ( + types_v03.TaskPushNotificationConfig.model_validate_json( + json_payload + ) + ) + return conversions.to_core_task_push_notification_config( + legacy_instance + ) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index 4f7b1ecd..255145f8 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -31,8 +31,10 @@ "or 'pip install a2a-sdk[sql]'" ) from e -from google.protobuf.json_format import MessageToDict +from google.protobuf.json_format import MessageToDict, ParseDict +from a2a.compat.v0_3 import conversions +from a2a.compat.v0_3 import types as types_v03 from a2a.server.context import ServerCallContext from a2a.server.models import Base, TaskModel, create_task_model from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope @@ -117,8 +119,6 @@ async def _ensure_initialized(self) -> None: def _to_orm(self, task: Task, owner: str) -> TaskModel: """Maps a Proto Task to a SQLAlchemy TaskModel instance.""" - # Pass proto objects directly - PydanticType/PydanticListType - # handle serialization via process_bind_param return self.task_model( id=task.id, context_id=task.context_id, @@ -126,36 +126,60 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel: owner=owner, last_updated=( task.status.timestamp.ToDatetime() - if task.HasField('status') and task.status.HasField('timestamp') + if task.status.HasField('timestamp') else None ), - status=task.status if task.HasField('status') else None, - artifacts=list(task.artifacts) if task.artifacts else [], - history=list(task.history) if task.history else [], + status=MessageToDict(task.status), + artifacts=[MessageToDict(artifact) for artifact in task.artifacts], + history=[MessageToDict(history) for history in task.history], task_metadata=( MessageToDict(task.metadata) if task.metadata.fields else None ), + protocol_version='1.0', ) def _from_orm(self, task_model: TaskModel) -> Task: """Maps a SQLAlchemy TaskModel to a Proto Task instance.""" - # PydanticType/PydanticListType already deserialize to proto objects - # via process_result_value, so we can construct the Task directly - task = Task( - id=task_model.id, - context_id=task_model.context_id, - ) - if task_model.status: - task.status.CopyFrom(task_model.status) - if task_model.artifacts: - task.artifacts.extend(task_model.artifacts) - if task_model.history: - task.history.extend(task_model.history) - if task_model.task_metadata: - task.metadata.update( - cast('dict[str, Any]', task_model.task_metadata) + if task_model.protocol_version == '1.0': + task = Task( + id=task_model.id, + context_id=task_model.context_id, ) - return task + if task_model.status: + ParseDict( + cast('dict[str, Any]', task_model.status), task.status + ) + if task_model.artifacts: + for art_dict in cast( + 'list[dict[str, Any]]', task_model.artifacts + ): + art = task.artifacts.add() + ParseDict(art_dict, art) + if task_model.history: + for msg_dict in cast( + 'list[dict[str, Any]]', task_model.history + ): + msg = task.history.add() + ParseDict(msg_dict, msg) + if task_model.task_metadata: + task.metadata.update( + cast('dict[str, Any]', task_model.task_metadata) + ) + return task + + # Legacy conversion + # Reconstruct legacy task using model_validate to handle dicts and resolve Pyright issues + legacy_task = types_v03.Task.model_validate( + { + 'id': task_model.id, + 'context_id': task_model.context_id, + 'status': task_model.status, + 'artifacts': task_model.artifacts, + 'history': task_model.history, + 'metadata': task_model.task_metadata, + } + ) + return conversions.to_core_task(legacy_task) async def save( self, task: Task, context: ServerCallContext | None = None diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index 6974881b..d4d08da1 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -5,6 +5,8 @@ import pytest from a2a.server.context import ServerCallContext from a2a.auth.user import User +from a2a.compat.v0_3 import types as types_v03 +from sqlalchemy import insert # Skip entire test module if SQLAlchemy is not installed @@ -719,3 +721,64 @@ async def test_owner_resource_scoping( # Cleanup remaining await config_store.delete_info('task1', context=context_user1) await config_store.delete_info('task1', context=context_user2) + + +@pytest.mark.asyncio +async def test_get_0_3_push_notification_config_detailed( + db_store_parameterized: DatabasePushNotificationConfigStore, +) -> None: + """Test retrieving a legacy v0.3 push notification config from the database. + + This test simulates a database that already contains legacy v0.3 JSON data + and verifies that the store correctly converts it to the modern Protobuf model. + """ + task_id = 'legacy-push-1' + config_id = 'config-legacy-1' + owner = 'legacy_user' + context_user = ServerCallContext(user=SampleUser(user_name=owner)) + + # 1. Create a legacy PushNotificationConfig using v0.3 models + legacy_config = types_v03.TaskPushNotificationConfig( + task_id=task_id, + push_notification_config=types_v03.PushNotificationConfig( + id=config_id, + url='https://example.com/push', + token='legacy-token', + authentication=types_v03.PushNotificationAuthenticationInfo( + schemes=['bearer'], + credentials='legacy-creds', + ), + ), + ) + + # 2. Manually insert the legacy data into the database + # For PushNotificationConfigStore, the data is stored in the config_data column. + async with db_store_parameterized.async_session_maker.begin() as session: + # Pydantic model_dump_json() produces the JSON that we'll store. + # Note: DatabasePushNotificationConfigStore normally encrypts this, but here + # we'll store it as plain JSON bytes to simulate legacy data. + legacy_json = legacy_config.model_dump_json() + + stmt = insert(db_store_parameterized.config_model).values( + task_id=task_id, + config_id=config_id, + owner=owner, + config_data=legacy_json.encode('utf-8'), + ) + await session.execute(stmt) + + # 3. Retrieve the config using the standard store.get_info() + # This will trigger the DatabasePushNotificationConfigStore._from_orm legacy conversion + retrieved_configs = await db_store_parameterized.get_info( + task_id, context_user + ) + + # 4. Verify the conversion to modern Protobuf + assert len(retrieved_configs) == 1 + retrieved = retrieved_configs[0] + assert retrieved.task_id == task_id + assert retrieved.id == config_id + assert retrieved.url == 'https://example.com/push' + assert retrieved.token == 'legacy-token' + assert retrieved.authentication.scheme == 'bearer' + assert retrieved.authentication.credentials == 'legacy-creds' diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index b71fd709..781c46c7 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -8,6 +8,8 @@ from _pytest.mark.structures import ParameterSet from a2a.types.a2a_pb2 import ListTasksRequest +from a2a.compat.v0_3 import types as types_v03 +from sqlalchemy import insert # Skip entire test module if SQLAlchemy is not installed @@ -683,4 +685,143 @@ async def test_owner_resource_scoping( await task_store.delete('u2-task1', context_user2) +@pytest.mark.asyncio +async def test_get_0_3_task_detailed( + db_store_parameterized: DatabaseTaskStore, +) -> None: + """Test retrieving a detailed legacy v0.3 task from the database. + + This test simulates a database that already contains legacy v0.3 JSON data + (string-based enums, different field names) and verifies that the store + correctly converts it to the modern Protobuf-based Task model. + """ + + task_id = 'legacy-detailed-1' + owner = 'legacy_user' + context_user = ServerCallContext(user=SampleUser(user_name=owner)) + + # 1. Create a detailed legacy Task using v0.3 models + legacy_task = types_v03.Task( + id=task_id, + context_id='legacy-ctx-1', + status=types_v03.TaskStatus( + state=types_v03.TaskState.working, + message=types_v03.Message( + message_id='msg-status', + role=types_v03.Role.agent, + parts=[ + types_v03.Part( + root=types_v03.TextPart(text='Legacy status message') + ) + ], + ), + timestamp='2023-10-27T10:00:00Z', + ), + history=[ + types_v03.Message( + message_id='msg-1', + role=types_v03.Role.user, + parts=[ + types_v03.Part(root=types_v03.TextPart(text='Hello legacy')) + ], + ), + types_v03.Message( + message_id='msg-2', + role=types_v03.Role.agent, + parts=[ + types_v03.Part( + root=types_v03.DataPart(data={'legacy_key': 'value'}) + ) + ], + ), + ], + artifacts=[ + types_v03.Artifact( + artifact_id='art-1', + name='Legacy Artifact', + parts=[ + types_v03.Part( + root=types_v03.FilePart( + file=types_v03.FileWithUri( + uri='https://example.com/legacy.txt', + mime_type='text/plain', + ) + ) + ) + ], + ) + ], + metadata={'meta_key': 'meta_val'}, + ) + + # 2. Manually insert the legacy data into the database + # We must bypass the store's save() method because it expects Protobuf objects. + async with db_store_parameterized.async_session_maker.begin() as session: + # Pydantic model_dump(mode='json') produces exactly what would be in the legacy DB + legacy_data = legacy_task.model_dump(mode='json') + + stmt = insert(db_store_parameterized.task_model).values( + id=task_id, + context_id=legacy_task.context_id, + owner=owner, + status=legacy_data['status'], + history=legacy_data['history'], + artifacts=legacy_data['artifacts'], + task_metadata=legacy_data['metadata'], + kind='task', + last_updated=None, + ) + await session.execute(stmt) + + # 3. Retrieve the task using the standard store.get() + # This will trigger conversion from legacy to 1.0 format in the _from_orm method + retrieved_task = await db_store_parameterized.get(task_id, context_user) + + # 4. Verify the conversion to modern Protobuf + assert retrieved_task is not None + assert retrieved_task.id == task_id + assert retrieved_task.context_id == 'legacy-ctx-1' + + # Check Status & State (The most critical part: string 'working' -> enum TASK_STATE_WORKING) + assert retrieved_task.status.state == TaskState.TASK_STATE_WORKING + assert retrieved_task.status.message.message_id == 'msg-status' + assert retrieved_task.status.message.role == Role.ROLE_AGENT + assert ( + retrieved_task.status.message.parts[0].text == 'Legacy status message' + ) + + # Check History + assert len(retrieved_task.history) == 2 + assert retrieved_task.history[0].message_id == 'msg-1' + assert retrieved_task.history[0].role == Role.ROLE_USER + assert retrieved_task.history[0].parts[0].text == 'Hello legacy' + + assert retrieved_task.history[1].message_id == 'msg-2' + assert retrieved_task.history[1].role == Role.ROLE_AGENT + assert ( + MessageToDict(retrieved_task.history[1].parts[0].data)['legacy_key'] + == 'value' + ) + + # Check Artifacts + assert len(retrieved_task.artifacts) == 1 + assert retrieved_task.artifacts[0].artifact_id == 'art-1' + assert retrieved_task.artifacts[0].name == 'Legacy Artifact' + assert ( + retrieved_task.artifacts[0].parts[0].url + == 'https://example.com/legacy.txt' + ) + + # Check Metadata + assert dict(retrieved_task.metadata) == {'meta_key': 'meta_val'} + + retrieved_tasks = await db_store_parameterized.list( + ListTasksRequest(), context_user + ) + assert retrieved_tasks is not None + assert retrieved_tasks.tasks == [retrieved_task] + + await db_store_parameterized.delete(task_id, context_user) + + # Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml). diff --git a/tests/server/test_models.py b/tests/server/test_models.py index 08d700ce..bfaaed9d 100644 --- a/tests/server/test_models.py +++ b/tests/server/test_models.py @@ -5,76 +5,9 @@ from sqlalchemy.orm import DeclarativeBase from a2a.server.models import ( - PydanticListType, - PydanticType, create_push_notification_config_model, create_task_model, ) -from a2a.types.a2a_pb2 import Artifact, Part, TaskState, TaskStatus - - -class TestPydanticType: - """Tests for PydanticType SQLAlchemy type decorator.""" - - def test_process_bind_param_with_pydantic_model(self): - pydantic_type = PydanticType(TaskStatus) - status = TaskStatus(state=TaskState.TASK_STATE_WORKING) - dialect = MagicMock() - - result = pydantic_type.process_bind_param(status, dialect) - assert result is not None - assert result['state'] == 'TASK_STATE_WORKING' - # message field is optional and not set - - def test_process_bind_param_with_none(self): - pydantic_type = PydanticType(TaskStatus) - dialect = MagicMock() - - result = pydantic_type.process_bind_param(None, dialect) - assert result is None - - def test_process_result_value(self): - pydantic_type = PydanticType(TaskStatus) - dialect = MagicMock() - - result = pydantic_type.process_result_value( - {'state': 'TASK_STATE_COMPLETED'}, dialect - ) - assert isinstance(result, TaskStatus) - assert result.state == TaskState.TASK_STATE_COMPLETED - - -class TestPydanticListType: - """Tests for PydanticListType SQLAlchemy type decorator.""" - - def test_process_bind_param_with_list(self): - pydantic_list_type = PydanticListType(Artifact) - artifacts = [ - Artifact(artifact_id='1', parts=[Part(text='Hello')]), - Artifact(artifact_id='2', parts=[Part(text='World')]), - ] - dialect = MagicMock() - - result = pydantic_list_type.process_bind_param(artifacts, dialect) - assert result is not None - assert len(result) == 2 - assert result[0]['artifactId'] == '1' # JSON mode uses camelCase - assert result[1]['artifactId'] == '2' - - def test_process_result_value_with_list(self): - pydantic_list_type = PydanticListType(Artifact) - dialect = MagicMock() - data = [ - {'artifactId': '1', 'parts': [{'text': 'Hello'}]}, - {'artifactId': '2', 'parts': [{'text': 'World'}]}, - ] - - result = pydantic_list_type.process_result_value(data, dialect) - assert result is not None - assert len(result) == 2 - assert all(isinstance(art, Artifact) for art in result) - assert result[0].artifact_id == '1' - assert result[1].artifact_id == '2' def test_create_task_model():