diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index a502bfb6..7e800c95 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -2,6 +2,7 @@ from a2a.utils import proto_utils from a2a.utils.artifact import ( + ArtifactStreamer, get_artifact_text, new_artifact, new_data_artifact, @@ -38,6 +39,7 @@ __all__ = [ 'AGENT_CARD_WELL_KNOWN_PATH', 'DEFAULT_RPC_URL', + 'ArtifactStreamer', 'TransportProtocol', 'append_artifact_to_task', 'are_modalities_compatible', diff --git a/src/a2a/utils/artifact.py b/src/a2a/utils/artifact.py index ac14087d..107f94c5 100644 --- a/src/a2a/utils/artifact.py +++ b/src/a2a/utils/artifact.py @@ -6,7 +6,7 @@ from google.protobuf.struct_pb2 import Struct, Value -from a2a.types.a2a_pb2 import Artifact, Part +from a2a.types.a2a_pb2 import Artifact, Part, TaskArtifactUpdateEvent from a2a.utils.parts import get_text_parts @@ -90,3 +90,82 @@ def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str: A single string containing all text content, or an empty string if no text parts are found. """ return delimiter.join(get_text_parts(artifact.parts)) + + +class ArtifactStreamer: + """Helper for streaming text into a single artifact across multiple events. + + Creates a stable artifact ID on construction so all chunks reference + the same artifact, enabling proper append semantics per the A2A spec. + + Example:: + + streamer = ArtifactStreamer(context_id, task_id, name='response') + + async for chunk in model.stream(prompt): + await event_queue.enqueue_event(streamer.append(chunk)) + + await event_queue.enqueue_event(streamer.finalize()) + + Args: + context_id: The context ID associated with the task. + task_id: The task ID associated with the streaming session. + name: A human-readable name for the artifact. + artifact_id: An explicit artifact ID. If omitted a UUID is generated. + """ + + def __init__( + self, + context_id: str, + task_id: str, + name: str = 'response', + artifact_id: str | None = None, + ) -> None: + self._context_id = context_id + self._task_id = task_id + self._name = name + self._artifact_id = artifact_id or str(uuid.uuid4()) + + def append(self, text: str) -> TaskArtifactUpdateEvent: + """Emit a chunk to be appended to the streaming artifact. + + Args: + text: The incremental text content for this chunk. + + Returns: + A ``TaskArtifactUpdateEvent`` with ``append=True`` and + ``last_chunk=False``. + """ + return TaskArtifactUpdateEvent( + context_id=self._context_id, + task_id=self._task_id, + append=True, + last_chunk=False, + artifact=Artifact( + artifact_id=self._artifact_id, + name=self._name, + parts=[Part(text=text)], + ), + ) + + def finalize(self) -> TaskArtifactUpdateEvent: + """Signal that the artifact stream is complete. + + Returns: + A ``TaskArtifactUpdateEvent`` with ``append=True`` and + ``last_chunk=True``. + """ + return TaskArtifactUpdateEvent( + context_id=self._context_id, + task_id=self._task_id, + append=True, + last_chunk=True, + artifact=Artifact( + artifact_id=self._artifact_id, + name=self._name, + # Spec requires >= 1 part; use empty-text sentinel since + # finalize carries no new content. + # https://github.com/a2aproject/A2A/issues/1231 + parts=[Part(text='')], + ), + ) diff --git a/tests/utils/test_artifact.py b/tests/utils/test_artifact.py index cbe8e9c9..742ccdf9 100644 --- a/tests/utils/test_artifact.py +++ b/tests/utils/test_artifact.py @@ -8,8 +8,10 @@ from a2a.types.a2a_pb2 import ( Artifact, Part, + TaskArtifactUpdateEvent, ) from a2a.utils.artifact import ( + ArtifactStreamer, get_artifact_text, new_artifact, new_data_artifact, @@ -157,5 +159,106 @@ def test_get_artifact_text_empty_parts(self): assert result == '' +class TestArtifactStreamer(unittest.TestCase): + def setUp(self): + self.context_id = 'ctx-123' + self.task_id = 'task-456' + + def test_generates_stable_artifact_id(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + e1 = streamer.append('hello ') + e2 = streamer.append('world') + self.assertEqual(e1.artifact.artifact_id, e2.artifact.artifact_id) + + def test_uses_explicit_artifact_id(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, artifact_id='my-fixed-id' + ) + event = streamer.append('chunk') + self.assertEqual(event.artifact.artifact_id, 'my-fixed-id') + + @patch('a2a.utils.artifact.uuid.uuid4') + def test_generated_id_comes_from_uuid4(self, mock_uuid4): + mock_uuid = uuid.UUID('abcdef12-1234-5678-1234-567812345678') + mock_uuid4.return_value = mock_uuid + streamer = ArtifactStreamer(self.context_id, self.task_id) + self.assertEqual(streamer._artifact_id, str(mock_uuid)) + + def test_default_name_is_response(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + event = streamer.append('text') + self.assertEqual(event.artifact.name, 'response') + + def test_custom_name(self): + streamer = ArtifactStreamer( + self.context_id, self.task_id, name='summary' + ) + event = streamer.append('text') + self.assertEqual(event.artifact.name, 'summary') + + def test_append_returns_task_artifact_update_event(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + event = streamer.append('chunk') + self.assertIsInstance(event, TaskArtifactUpdateEvent) + + def test_append_sets_correct_context_and_task(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + event = streamer.append('chunk') + self.assertEqual(event.context_id, self.context_id) + self.assertEqual(event.task_id, self.task_id) + + def test_append_sets_append_true_last_chunk_false(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + event = streamer.append('chunk') + self.assertTrue(event.append) + self.assertFalse(event.last_chunk) + + def test_append_creates_single_text_part(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + event = streamer.append('hello') + self.assertEqual(len(event.artifact.parts), 1) + self.assertTrue(event.artifact.parts[0].HasField('text')) + self.assertEqual(event.artifact.parts[0].text, 'hello') + + def test_finalize_returns_task_artifact_update_event(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + event = streamer.finalize() + self.assertIsInstance(event, TaskArtifactUpdateEvent) + + def test_finalize_sets_append_true_last_chunk_true(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + event = streamer.finalize() + self.assertTrue(event.append) + self.assertTrue(event.last_chunk) + + def test_finalize_has_sentinel_empty_text_part(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + event = streamer.finalize() + self.assertEqual(len(event.artifact.parts), 1) + self.assertEqual(event.artifact.parts[0].text, '') + + def test_finalize_uses_same_artifact_id_as_append(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + append_event = streamer.append('text') + finalize_event = streamer.finalize() + self.assertEqual( + append_event.artifact.artifact_id, + finalize_event.artifact.artifact_id, + ) + + def test_multiple_appends_all_share_artifact_id(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + events = [streamer.append(f'chunk-{i}') for i in range(5)] + ids = {e.artifact.artifact_id for e in events} + self.assertEqual(len(ids), 1) + + def test_multiple_appends_carry_distinct_text(self): + streamer = ArtifactStreamer(self.context_id, self.task_id) + texts = ['Hello, ', 'world', '!'] + events = [streamer.append(t) for t in texts] + result_texts = [e.artifact.parts[0].text for e in events] + self.assertEqual(result_texts, texts) + + if __name__ == '__main__': unittest.main()