Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/a2a/utils/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ def apply_history_length(task: Task, history_length: int | None) -> Task:
A new task object with limited history
"""
# Apply historyLength parameter if specified
if history_length is not None and history_length > 0 and task.history:
if history_length is not None and history_length >= 0:
# Limit history to the most recent N messages
limited_history = task.history[-history_length:]
if task.history and history_length > 0:
limited_history = task.history[-history_length:]
else:
limited_history = []
# Create a new task instance with limited history
return task.model_copy(update={'history': limited_history})

Expand Down
41 changes: 40 additions & 1 deletion tests/utils/test_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import unittest
import uuid

Expand All @@ -6,7 +6,7 @@
import pytest

from a2a.types import Artifact, Message, Part, Role, TextPart
from a2a.utils.task import completed_task, new_task
from a2a.utils.task import apply_history_length, completed_task, new_task


class TestTask(unittest.TestCase):
Expand Down Expand Up @@ -188,6 +188,45 @@
history=[],
)

def test_apply_history_length_cases(self):
# Setup task with 3 messages
history = [
Message(role=Role.user, parts=[Part(root=TextPart(text='1'))], message_id='1'),
Message(role=Role.agent, parts=[Part(root=TextPart(text='2'))], message_id='2'),
Message(role=Role.user, parts=[Part(root=TextPart(text='3'))], message_id='3'),
]
task_id = str(uuid.uuid4())
context_id = str(uuid.uuid4())
task = completed_task(
task_id=task_id,
context_id=context_id,
artifacts=[Artifact(artifact_id='a', parts=[Part(root=TextPart(text='a'))])],
history=history
)

# historyLength = 0 -> empty
t0 = apply_history_length(task, 0)
self.assertEqual(len(t0.history), 0)

# historyLength = 1 -> last one
t1 = apply_history_length(task, 1)
self.assertEqual(len(t1.history), 1)
self.assertEqual(t1.history[0].message_id, '3')

# historyLength = 2 -> last two
t2 = apply_history_length(task, 2)
self.assertEqual(len(t2.history), 2)
self.assertEqual(t2.history[0].message_id, '2')
self.assertEqual(t2.history[1].message_id, '3')

# historyLength = None -> all
tn = apply_history_length(task, None)
self.assertEqual(len(tn.history), 3)

# historyLength = 10 -> all
t10 = apply_history_length(task, 10)
self.assertEqual(len(t10.history), 3)


if __name__ == '__main__':
unittest.main()
Loading