diff --git a/src/a2a/server/owner_resolver.py b/src/a2a/server/owner_resolver.py index 798eb8c9b..4fca42d24 100644 --- a/src/a2a/server/owner_resolver.py +++ b/src/a2a/server/owner_resolver.py @@ -4,13 +4,10 @@ # Definition -OwnerResolver = Callable[[ServerCallContext | None], str] +OwnerResolver = Callable[[ServerCallContext], str] # Example Default Implementation -def resolve_user_scope(context: ServerCallContext | None) -> str: +def resolve_user_scope(context: ServerCallContext) -> str: """Resolves the owner scope based on the user in the context.""" - if not context: - return 'unknown' - # Example: Basic user name. Adapt as needed for your user model. return context.user.user_name diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index ac1cf947b..cfa007e56 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -169,9 +169,7 @@ def _from_orm(self, task_model: TaskModel) -> Task: # Legacy conversion return compat_task_model_to_core(task_model) - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the database for the resolved owner.""" await self._ensure_initialized() owner = self.owner_resolver(context) @@ -185,7 +183,7 @@ async def save( ) async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the database by ID, for the given owner.""" await self._ensure_initialized() @@ -216,7 +214,7 @@ async def get( async def list( self, params: a2a_pb2.ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> a2a_pb2.ListTasksResponse: """Retrieves tasks from the database based on provided parameters, for the given owner.""" await self._ensure_initialized() @@ -315,9 +313,7 @@ async def list( page_size=page_size, ) - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the database by ID, for the given owner.""" await self._ensure_initialized() owner = self.owner_resolver(context) diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index eb596ca4b..6634554df 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -34,9 +34,7 @@ def __init__( def _get_owner_tasks(self, owner: str) -> dict[str, Task]: return self.tasks.get(owner, {}) - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the in-memory store for the resolved owner.""" owner = self.owner_resolver(context) if owner not in self.tasks: @@ -49,7 +47,7 @@ async def save( ) async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the in-memory store by ID, for the given owner.""" owner = self.owner_resolver(context) @@ -76,7 +74,7 @@ async def get( async def list( self, params: a2a_pb2.ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> a2a_pb2.ListTasksResponse: """Retrieves a list of tasks from the store, for the given owner.""" owner = self.owner_resolver(context) @@ -155,9 +153,7 @@ async def list( page_size=page_size, ) - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the in-memory store by ID, for the given owner.""" owner = self.owner_resolver(context) async with self.lock: diff --git a/src/a2a/server/tasks/task_store.py b/src/a2a/server/tasks/task_store.py index a4d3308c0..25e4838d1 100644 --- a/src/a2a/server/tasks/task_store.py +++ b/src/a2a/server/tasks/task_store.py @@ -11,14 +11,12 @@ class TaskStore(ABC): """ @abstractmethod - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the store.""" @abstractmethod async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the store by ID.""" @@ -26,12 +24,10 @@ async def get( async def list( self, params: ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> ListTasksResponse: """Retrieves a list of tasks from the store.""" @abstractmethod - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the store by ID.""" diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index 445a45a37..8c9b7d07d 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -56,6 +56,9 @@ def user_name(self) -> str: return self._user_name +TEST_CONTEXT = ServerCallContext(user=SampleUser('test_user')) + + # DSNs for different databases SQLITE_TEST_DSN = ( 'sqlite+aiosqlite:///file:testdb?mode=memory&cache=shared&uri=true' @@ -170,13 +173,17 @@ async def test_save_task(db_store_parameterized: DatabaseTaskStore) -> None: task_to_save.id = ( f'save-task-{db_store_parameterized.engine.url.drivername}' ) - await db_store_parameterized.save(task_to_save) + await db_store_parameterized.save(task_to_save, TEST_CONTEXT) - retrieved_task = await db_store_parameterized.get(task_to_save.id) + retrieved_task = await db_store_parameterized.get( + task_to_save.id, TEST_CONTEXT + ) assert retrieved_task is not None assert retrieved_task.id == task_to_save.id assert MessageToDict(retrieved_task) == MessageToDict(task_to_save) - await db_store_parameterized.delete(task_to_save.id) # Cleanup + await db_store_parameterized.delete( + task_to_save.id, TEST_CONTEXT + ) # Cleanup @pytest.mark.asyncio @@ -186,14 +193,18 @@ async def test_get_task(db_store_parameterized: DatabaseTaskStore) -> None: task_to_save = Task() task_to_save.CopyFrom(MINIMAL_TASK_OBJ) task_to_save.id = task_id - await db_store_parameterized.save(task_to_save) + await db_store_parameterized.save(task_to_save, TEST_CONTEXT) - retrieved_task = await db_store_parameterized.get(task_to_save.id) + retrieved_task = await db_store_parameterized.get( + task_to_save.id, TEST_CONTEXT + ) assert retrieved_task is not None assert retrieved_task.id == task_to_save.id assert retrieved_task.context_id == task_to_save.context_id assert retrieved_task.status.state == TaskState.TASK_STATE_SUBMITTED - await db_store_parameterized.delete(task_to_save.id) # Cleanup + await db_store_parameterized.delete( + task_to_save.id, TEST_CONTEXT + ) # Cleanup @pytest.mark.asyncio @@ -321,9 +332,9 @@ async def test_list_tasks( ), ] for task in tasks_to_create: - await db_store_parameterized.save(task) + await db_store_parameterized.save(task, TEST_CONTEXT) - page = await db_store_parameterized.list(params) + page = await db_store_parameterized.list(params, TEST_CONTEXT) retrieved_ids = [task.id for task in page.tasks] assert retrieved_ids == expected_ids @@ -333,7 +344,7 @@ async def test_list_tasks( # Cleanup for task in tasks_to_create: - await db_store_parameterized.delete(task.id) + await db_store_parameterized.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -381,16 +392,16 @@ async def test_list_tasks_fails( ), ] for task in tasks_to_create: - await db_store_parameterized.save(task) + await db_store_parameterized.save(task, TEST_CONTEXT) with pytest.raises(InvalidParamsError) as excinfo: - await db_store_parameterized.list(params) + await db_store_parameterized.list(params, TEST_CONTEXT) assert expected_error_message in str(excinfo.value) # Cleanup for task in tasks_to_create: - await db_store_parameterized.delete(task.id) + await db_store_parameterized.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -398,7 +409,9 @@ async def test_get_nonexistent_task( db_store_parameterized: DatabaseTaskStore, ) -> None: """Test retrieving a nonexistent task.""" - retrieved_task = await db_store_parameterized.get('nonexistent-task-id') + retrieved_task = await db_store_parameterized.get( + 'nonexistent-task-id', TEST_CONTEXT + ) assert retrieved_task is None @@ -409,13 +422,23 @@ async def test_delete_task(db_store_parameterized: DatabaseTaskStore) -> None: task_to_save_and_delete = Task() task_to_save_and_delete.CopyFrom(MINIMAL_TASK_OBJ) task_to_save_and_delete.id = task_id - await db_store_parameterized.save(task_to_save_and_delete) + await db_store_parameterized.save(task_to_save_and_delete, TEST_CONTEXT) assert ( - await db_store_parameterized.get(task_to_save_and_delete.id) is not None + await db_store_parameterized.get( + task_to_save_and_delete.id, TEST_CONTEXT + ) + is not None + ) + await db_store_parameterized.delete( + task_to_save_and_delete.id, TEST_CONTEXT + ) + assert ( + await db_store_parameterized.get( + task_to_save_and_delete.id, TEST_CONTEXT + ) + is None ) - await db_store_parameterized.delete(task_to_save_and_delete.id) - assert await db_store_parameterized.get(task_to_save_and_delete.id) is None @pytest.mark.asyncio @@ -423,7 +446,9 @@ async def test_delete_nonexistent_task( db_store_parameterized: DatabaseTaskStore, ) -> None: """Test deleting a nonexistent task. Should not error.""" - await db_store_parameterized.delete('nonexistent-delete-task-id') + await db_store_parameterized.delete( + 'nonexistent-delete-task-id', TEST_CONTEXT + ) @pytest.mark.asyncio @@ -455,8 +480,10 @@ async def test_save_and_get_detailed_task( ], ) - await db_store_parameterized.save(test_task) - retrieved_task = await db_store_parameterized.get(test_task.id) + await db_store_parameterized.save(test_task, TEST_CONTEXT) + retrieved_task = await db_store_parameterized.get( + test_task.id, TEST_CONTEXT + ) assert retrieved_task is not None assert retrieved_task.id == test_task.id @@ -479,8 +506,8 @@ async def test_save_and_get_detailed_task( == MessageToDict(test_task)['history'] ) - await db_store_parameterized.delete(test_task.id) - assert await db_store_parameterized.get(test_task.id) is None + await db_store_parameterized.delete(test_task.id, TEST_CONTEXT) + assert await db_store_parameterized.get(test_task.id, TEST_CONTEXT) is None @pytest.mark.asyncio @@ -498,9 +525,11 @@ async def test_update_task(db_store_parameterized: DatabaseTaskStore) -> None: artifacts=[], history=[], ) - await db_store_parameterized.save(original_task) + await db_store_parameterized.save(original_task, TEST_CONTEXT) - retrieved_before_update = await db_store_parameterized.get(task_id) + retrieved_before_update = await db_store_parameterized.get( + task_id, TEST_CONTEXT + ) assert retrieved_before_update is not None assert ( retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED @@ -516,16 +545,18 @@ async def test_update_task(db_store_parameterized: DatabaseTaskStore) -> None: updated_task.status.timestamp.FromDatetime(updated_timestamp) updated_task.metadata['update_key'] = 'update_value' - await db_store_parameterized.save(updated_task) + await db_store_parameterized.save(updated_task, TEST_CONTEXT) - retrieved_after_update = await db_store_parameterized.get(task_id) + retrieved_after_update = await db_store_parameterized.get( + task_id, TEST_CONTEXT + ) assert retrieved_after_update is not None assert retrieved_after_update.status.state == TaskState.TASK_STATE_COMPLETED assert dict(retrieved_after_update.metadata) == { 'update_key': 'update_value' } - await db_store_parameterized.delete(task_id) + await db_store_parameterized.delete(task_id, TEST_CONTEXT) @pytest.mark.asyncio @@ -547,9 +578,9 @@ async def test_metadata_field_mapping( context_id='session-meta-1', status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), ) - await db_store_parameterized.save(task_no_metadata) + await db_store_parameterized.save(task_no_metadata, TEST_CONTEXT) retrieved_no_metadata = await db_store_parameterized.get( - 'task-metadata-test-1' + 'task-metadata-test-1', TEST_CONTEXT ) assert retrieved_no_metadata is not None # Proto Struct is empty, not None @@ -563,8 +594,10 @@ async def test_metadata_field_mapping( status=TaskStatus(state=TaskState.TASK_STATE_WORKING), metadata=simple_metadata, ) - await db_store_parameterized.save(task_simple_metadata) - retrieved_simple = await db_store_parameterized.get('task-metadata-test-2') + await db_store_parameterized.save(task_simple_metadata, TEST_CONTEXT) + retrieved_simple = await db_store_parameterized.get( + 'task-metadata-test-2', TEST_CONTEXT + ) assert retrieved_simple is not None assert dict(retrieved_simple.metadata) == simple_metadata @@ -586,8 +619,10 @@ async def test_metadata_field_mapping( status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), metadata=complex_metadata, ) - await db_store_parameterized.save(task_complex_metadata) - retrieved_complex = await db_store_parameterized.get('task-metadata-test-3') + await db_store_parameterized.save(task_complex_metadata, TEST_CONTEXT) + retrieved_complex = await db_store_parameterized.get( + 'task-metadata-test-3', TEST_CONTEXT + ) assert retrieved_complex is not None # Convert proto Struct to dict for comparison retrieved_meta = MessageToDict(retrieved_complex.metadata) @@ -599,14 +634,16 @@ async def test_metadata_field_mapping( context_id='session-meta-4', status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), ) - await db_store_parameterized.save(task_update_metadata) + await db_store_parameterized.save(task_update_metadata, TEST_CONTEXT) # Update metadata task_update_metadata.metadata['updated'] = True task_update_metadata.metadata['timestamp'] = '2024-01-01' - await db_store_parameterized.save(task_update_metadata) + await db_store_parameterized.save(task_update_metadata, TEST_CONTEXT) - retrieved_updated = await db_store_parameterized.get('task-metadata-test-4') + retrieved_updated = await db_store_parameterized.get( + 'task-metadata-test-4', TEST_CONTEXT + ) assert retrieved_updated is not None assert dict(retrieved_updated.metadata) == { 'updated': True, @@ -615,17 +652,19 @@ async def test_metadata_field_mapping( # Test 5: Clear metadata (set to empty) task_update_metadata.metadata.Clear() - await db_store_parameterized.save(task_update_metadata) + await db_store_parameterized.save(task_update_metadata, TEST_CONTEXT) - retrieved_none = await db_store_parameterized.get('task-metadata-test-4') + retrieved_none = await db_store_parameterized.get( + 'task-metadata-test-4', TEST_CONTEXT + ) assert retrieved_none is not None assert len(retrieved_none.metadata) == 0 # Cleanup - await db_store_parameterized.delete('task-metadata-test-1') - await db_store_parameterized.delete('task-metadata-test-2') - await db_store_parameterized.delete('task-metadata-test-3') - await db_store_parameterized.delete('task-metadata-test-4') + await db_store_parameterized.delete('task-metadata-test-1', TEST_CONTEXT) + await db_store_parameterized.delete('task-metadata-test-2', TEST_CONTEXT) + await db_store_parameterized.delete('task-metadata-test-3', TEST_CONTEXT) + await db_store_parameterized.delete('task-metadata-test-4', TEST_CONTEXT) @pytest.mark.asyncio @@ -874,7 +913,7 @@ async def test_core_to_0_3_model_conversion( ) # 1. Save the task (will use core_to_compat_task_model) - await store.save(original_task) + await store.save(original_task, TEST_CONTEXT) # 2. Verify it's stored in v0.3 format directly in DB async with store.async_session_maker() as session: @@ -882,17 +921,18 @@ async def test_core_to_0_3_model_conversion( assert db_task is not None assert db_task.protocol_version == '0.3' # v0.3 status JSON uses string for state + assert isinstance(db_task.status, dict) assert db_task.status['state'] == 'working' # 3. Retrieve the task (will use compat_task_model_to_core) - retrieved_task = await store.get(task_id) + retrieved_task = await store.get(task_id, context=TEST_CONTEXT) assert retrieved_task is not None assert retrieved_task.id == original_task.id assert retrieved_task.status.state == TaskState.TASK_STATE_WORKING assert dict(retrieved_task.metadata) == {'key': 'value'} # Reset conversion attributes store.core_to_model_conversion = None - await store.delete('v03-persistence-task') + await store.delete('v03-persistence-task', TEST_CONTEXT) # Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml). diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index 2184c2116..40ddc4175 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -25,6 +25,9 @@ def user_name(self) -> str: return self._user_name +TEST_CONTEXT = ServerCallContext(user=SampleUser('test_user')) + + def create_minimal_task( task_id: str = 'task-abc', context_id: str = 'session-xyz' ) -> Task: @@ -41,8 +44,8 @@ async def test_in_memory_task_store_save_and_get() -> None: """Test saving and retrieving a task from the in-memory store.""" store = InMemoryTaskStore() task = create_minimal_task() - await store.save(task) - retrieved_task = await store.get('task-abc') + await store.save(task, TEST_CONTEXT) + retrieved_task = await store.get('task-abc', TEST_CONTEXT) assert retrieved_task == task @@ -50,7 +53,7 @@ async def test_in_memory_task_store_save_and_get() -> None: async def test_in_memory_task_store_get_nonexistent() -> None: """Test retrieving a nonexistent task.""" store = InMemoryTaskStore() - retrieved_task = await store.get('nonexistent') + retrieved_task = await store.get('nonexistent', TEST_CONTEXT) assert retrieved_task is None @@ -179,9 +182,9 @@ async def test_list_tasks( ), ] for task in tasks_to_create: - await store.save(task) + await store.save(task, TEST_CONTEXT) - page = await store.list(params) + page = await store.list(params, TEST_CONTEXT) retrieved_ids = [task.id for task in page.tasks] assert retrieved_ids == expected_ids @@ -191,7 +194,7 @@ async def test_list_tasks( # Cleanup for task in tasks_to_create: - await store.delete(task.id) + await store.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -238,16 +241,16 @@ async def test_list_tasks_fails( ), ] for task in tasks_to_create: - await store.save(task) + await store.save(task, TEST_CONTEXT) with pytest.raises(InvalidParamsError) as excinfo: - await store.list(params) + await store.list(params, TEST_CONTEXT) assert expected_error_message in str(excinfo.value) # Cleanup for task in tasks_to_create: - await store.delete(task.id) + await store.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -255,9 +258,9 @@ async def test_in_memory_task_store_delete() -> None: """Test deleting a task from the store.""" store = InMemoryTaskStore() task = create_minimal_task() - await store.save(task) - await store.delete('task-abc') - retrieved_task = await store.get('task-abc') + await store.save(task, TEST_CONTEXT) + await store.delete('task-abc', TEST_CONTEXT) + retrieved_task = await store.get('task-abc', TEST_CONTEXT) assert retrieved_task is None @@ -265,7 +268,7 @@ async def test_in_memory_task_store_delete() -> None: async def test_in_memory_task_store_delete_nonexistent() -> None: """Test deleting a nonexistent task.""" store = InMemoryTaskStore() - await store.delete('nonexistent') + await store.delete('nonexistent', TEST_CONTEXT) @pytest.mark.asyncio diff --git a/tests/server/test_owner_resolver.py b/tests/server/test_owner_resolver.py index 5bac5c605..bb7b91012 100644 --- a/tests/server/test_owner_resolver.py +++ b/tests/server/test_owner_resolver.py @@ -19,13 +19,13 @@ def user_name(self) -> str: return self._user_name -def test_resolve_user_scope_valid_user(): - """Test resolve_user_scope with a valid user in the context.""" +def test_resolve_user(): + """Test resolve_user_scope.""" user = SampleUser(user_name='SampleUser') context = ServerCallContext(user=user) assert resolve_user_scope(context) == 'SampleUser' -def test_resolve_user_scope_no_context(): - """Test resolve_user_scope when the context is None.""" - assert resolve_user_scope(None) == 'unknown' +def test_resolve_user_default_context(): + """Test resolve_user_scope with default context.""" + assert resolve_user_scope(ServerCallContext()) == ''