Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions python/packages/claude/agent_framework_claude/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
normalize_messages,
normalize_tools,
)
from agent_framework._middleware import AgentMiddlewareLayer
from agent_framework.exceptions import AgentException
from agent_framework.observability import AgentTelemetryLayer
from claude_agent_sdk import (
Expand Down Expand Up @@ -242,7 +243,7 @@ def __init__(
name: Name of the agent.
description: Description of the agent.
context_providers: Context providers for the agent.
middleware: List of middleware.
middleware: List of AgentMiddleware.
tools: Tools for the agent. Can be:
- Strings for built-in tools (e.g., "Read", "Write", "Bash", "Glob")
- Functions for custom tools
Expand Down Expand Up @@ -581,6 +582,7 @@ def run(
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...

Expand All @@ -591,6 +593,7 @@ def run(
*,
stream: Literal[True],
session: AgentSession | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

Expand All @@ -600,6 +603,7 @@ def run(
*,
stream: bool = False,
session: AgentSession | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Run the agent with the given messages.
Expand All @@ -612,6 +616,7 @@ def run(
returns an awaitable AgentResponse.
session: The conversation session. If session has service_session_id set,
the agent will resume that session.
middleware: Optional per-run AgentMiddleware applied on top of constructor middleware.
kwargs: Additional keyword arguments including 'options' for runtime options
(model, permission_mode can be changed per-request).

Expand All @@ -620,15 +625,24 @@ def run(
When stream=False: An Awaitable[AgentResponse] with the complete response.
"""
options = kwargs.pop("options", None)
response = ResponseStream(
self._get_stream(messages, session=session, options=options, **kwargs),
finalizer=self._finalize_response,
)

response = self._run_stream_impl(messages=messages, session=session, options=options, **kwargs)
if stream:
return response
return response.get_final_response()

def _run_stream_impl(
    self,
    messages: AgentRunInputs | None = None,
    *,
    session: AgentSession | None = None,
    options: OptionsT | MutableMapping[str, Any] | None = None,
    **kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
    """Build the response stream for a single run.

    Args:
        messages: The input messages, forwarded unchanged to ``_get_stream``.
        session: The conversation session, forwarded to ``_get_stream``.
        options: Runtime options, forwarded to ``_get_stream``.
        kwargs: Any further keyword arguments, forwarded to ``_get_stream``.

    Returns:
        A ResponseStream wrapping the updates produced by ``_get_stream``, with
        ``_finalize_response`` registered as its finalizer.
    """
    # Calling the async generator function only creates the generator object;
    # nothing is executed until the returned stream is consumed.
    update_source = self._get_stream(messages, session=session, options=options, **kwargs)
    return ResponseStream(update_source, finalizer=self._finalize_response)

async def _get_stream(
self,
messages: AgentRunInputs | None = None,
Expand Down Expand Up @@ -716,7 +730,12 @@ async def _get_stream(
self._structured_output = structured_output


class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[OptionsT]):
class ClaudeAgent(
AgentTelemetryLayer,
AgentMiddlewareLayer,
RawClaudeAgent[OptionsT],
Generic[OptionsT],
):
"""Claude Agent with OpenTelemetry instrumentation.

This is the recommended agent class for most use cases. It includes
Expand All @@ -736,3 +755,5 @@ class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[Options
response = await agent.run("Hello!")
print(response.text)
"""

pass
173 changes: 172 additions & 1 deletion python/packages/claude/tests/test_claude_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from agent_framework import AgentResponseUpdate, AgentSession, Content, Message, tool
from agent_framework import (
AgentContext,
AgentMiddleware,
AgentResponse,
AgentResponseUpdate,
AgentSession,
Content,
Message,
tool,
)
from agent_framework._settings import load_settings

from agent_framework_claude import ClaudeAgent, ClaudeAgentOptions, ClaudeAgentSettings
Expand Down Expand Up @@ -434,6 +443,168 @@ async def test_run_stream_raises_on_result_message_error(self) -> None:
assert "Model 'claude-sonnet-4.5' not found" in str(exc_info.value)


class TestClaudeAgentMiddleware:
    """Tests for ClaudeAgent AgentMiddleware support."""

    @staticmethod
    async def _create_async_generator(items: list[Any]) -> Any:
        """Helper to create async generator from list."""
        for item in items:
            yield item

    @staticmethod
    def _create_text_messages(text: str, session_id: str) -> list[Any]:
        """Build the SDK message sequence shared by the middleware tests.

        Args:
            text: Text used for both the streamed delta and the assistant message.
            session_id: Session id set on the stream event and the result message.

        Returns:
            A list with a text-delta StreamEvent, an AssistantMessage and a
            successful ResultMessage, in that order.
        """
        # Imported locally, as the tests did before, so the SDK is only needed when a test runs.
        from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
        from claude_agent_sdk.types import StreamEvent

        return [
            StreamEvent(
                event={
                    "type": "content_block_delta",
                    "delta": {"type": "text_delta", "text": text},
                },
                uuid="event-1",
                session_id=session_id,
            ),
            AssistantMessage(
                content=[TextBlock(text=text)],
                model="claude-sonnet",
            ),
            ResultMessage(
                subtype="success",
                duration_ms=100,
                duration_api_ms=50,
                is_error=False,
                num_turns=1,
                session_id=session_id,
            ),
        ]

    def _create_mock_client(self, messages: list[Any]) -> MagicMock:
        """Create a mock ClaudeSDKClient that yields given messages.

        Args:
            messages: Items the mocked ``receive_response()`` should yield.

        Returns:
            A MagicMock whose connect/disconnect/query/set_model/set_permission_mode
            are AsyncMocks and whose ``receive_response()`` returns an async generator
            over ``messages``.
        """
        mock_client = MagicMock()
        mock_client.connect = AsyncMock()
        mock_client.disconnect = AsyncMock()
        mock_client.query = AsyncMock()
        mock_client.set_model = AsyncMock()
        mock_client.set_permission_mode = AsyncMock()
        # An async generator can be consumed only once. Using side_effect hands out a
        # fresh generator per call, so a second receive_response() does not silently
        # yield nothing the way a single shared return_value would.
        mock_client.receive_response = MagicMock(
            side_effect=lambda *args, **kwargs: self._create_async_generator(messages)
        )
        return mock_client

    async def test_run_executes_agent_middleware(self) -> None:
        """Test that non-streaming run executes AgentMiddleware."""
        mock_client = self._create_mock_client(self._create_text_messages("Hello!", "session-123"))
        calls: list[str] = []

        class TrackingMiddleware(AgentMiddleware):
            async def process(self, context: AgentContext, call_next: Any) -> None:
                # Record entry and exit so the test can verify the agent ran in between.
                calls.append("before")
                await call_next()
                calls.append("after")

        with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
            agent = ClaudeAgent(middleware=[TrackingMiddleware()])
            response = await agent.run("Hello")

        assert isinstance(response, AgentResponse)
        assert response.text == "Hello!"
        assert calls == ["before", "after"]

    async def test_run_stream_applies_agent_middleware_transform(self) -> None:
        """Test that middleware stream transform hooks are applied."""
        mock_client = self._create_mock_client(self._create_text_messages("Streaming", "stream-session"))

        class PrefixDeltaMiddleware(AgentMiddleware):
            async def process(self, context: AgentContext, call_next: Any) -> None:
                if context.stream:

                    def add_prefix(update: AgentResponseUpdate) -> AgentResponseUpdate:
                        # Rebuild the update with prefixed text, carrying over the other fields.
                        return AgentResponseUpdate(
                            role=update.role,
                            contents=[Content.from_text(f"mw:{update.text}")],
                            response_id=update.response_id,
                            message_id=update.message_id,
                            raw_representation=update.raw_representation,
                        )

                    context.stream_transform_hooks.append(add_prefix)
                await call_next()

        with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
            agent = ClaudeAgent(middleware=[PrefixDeltaMiddleware()])
            updates: list[AgentResponseUpdate] = []
            async for update in agent.run("Hello", stream=True):
                updates.append(update)

        # The test expects exactly one streamed update, and it must carry the prefix.
        assert len(updates) == 1
        assert updates[0].text == "mw:Streaming"

    async def test_run_supports_agent_middleware_callable(self) -> None:
        """Test that function-based AgentMiddleware callables are supported."""
        mock_client = self._create_mock_client(self._create_text_messages("Hello!", "session-123"))
        calls: list[str] = []

        async def tracking_middleware(context: AgentContext, call_next: Any) -> None:
            # Same tracking as the class-based test, expressed as a plain coroutine function.
            calls.append("before")
            await call_next()
            calls.append("after")

        with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
            agent = ClaudeAgent(middleware=[tracking_middleware])
            response = await agent.run("Hello")

        assert isinstance(response, AgentResponse)
        assert response.text == "Hello!"
        assert calls == ["before", "after"]


# region Test ClaudeAgent Session Management


Expand Down
Loading