Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions python/packages/claude/agent_framework_claude/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
normalize_messages,
normalize_tools,
)
from agent_framework._middleware import AgentMiddlewareLayer
from agent_framework.exceptions import AgentException
from agent_framework.observability import AgentTelemetryLayer
from claude_agent_sdk import (
Expand Down Expand Up @@ -242,7 +243,7 @@ def __init__(
name: Name of the agent.
description: Description of the agent.
context_providers: Context providers for the agent.
middleware: List of middleware.
middleware: List of AgentMiddleware.
tools: Tools for the agent. Can be:
- Strings for built-in tools (e.g., "Read", "Write", "Bash", "Glob")
- Functions for custom tools
Expand Down Expand Up @@ -581,6 +582,7 @@ def run(
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Declaring middleware explicitly in RawClaudeAgent's overloads leaks AgentMiddlewareLayer's routing concern into the raw class. The parameter is accepted and immediately discarded (never forwarded to _run_stream_impl). Since it arrives through **kwargs anyway when AgentMiddlewareLayer calls super().run(), it can simply be absorbed there without an explicit declaration.

Suggested change
session: AgentSession | None = None,
**kwargs: Any,

middleware: Sequence[AgentMiddlewareTypes] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...

Expand All @@ -591,6 +593,7 @@ def run(
*,
stream: Literal[True],
session: AgentSession | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same leaky abstraction as the non-streaming overload: middleware is declared but silently dropped. Remove it from this overload as well.

middleware: Sequence[AgentMiddlewareTypes] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

Expand All @@ -600,6 +603,7 @@ def run(
*,
stream: bool = False,
session: AgentSession | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue on the base overload. Dropping the explicit middleware declaration from all three overloads keeps the raw class unaware of the middleware routing mechanism.

middleware: Sequence[AgentMiddlewareTypes] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Run the agent with the given messages.
Expand All @@ -612,6 +616,7 @@ def run(
returns an awaitable AgentResponse.
session: The conversation session. If session has service_session_id set,
the agent will resume that session.
middleware: Optional per-run AgentMiddleware applied on top of constructor middleware.
kwargs: Additional keyword arguments including 'options' for runtime options
(model, permission_mode can be changed per-request).
Comment thread
dmytrostruk marked this conversation as resolved.

Expand All @@ -620,15 +625,24 @@ def run(
When stream=False: An Awaitable[AgentResponse] with the complete response.
"""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

middleware is accepted as a named parameter in run() but never forwarded here. AgentMiddlewareLayer passes combined function/chat middleware via kwargs["middleware"], which is now intercepted by the named param and silently dropped instead of reaching _get_stream.

Suggested change
"""
response = self._run_stream_impl(messages=messages, session=session, options=options, middleware=middleware, **kwargs)

options = kwargs.pop("options", None)
response = ResponseStream(
self._get_stream(messages, session=session, options=options, **kwargs),
finalizer=self._finalize_response,
)

response = self._run_stream_impl(messages=messages, session=session, options=options, **kwargs)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change needed? We are already wrapping a call to another function here, so there is no real need to do that twice (_get_stream and _run_stream_impl).

if stream:
return response
return response.get_final_response()
Comment thread
dmytrostruk marked this conversation as resolved.

def _run_stream_impl(
    self,
    messages: AgentRunInputs | None = None,
    *,
    session: AgentSession | None = None,
    options: OptionsT | MutableMapping[str, Any] | None = None,
    **kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
    """Build the response stream for a single run.

    Args:
        messages: Input forwarded unchanged to ``_get_stream``.
        session: Conversation session forwarded to ``_get_stream``.
        options: Runtime options forwarded to ``_get_stream``.
        **kwargs: Extra keyword arguments forwarded to ``_get_stream``.

    Returns:
        A ``ResponseStream`` over whatever ``_get_stream`` produces, using
        ``self._finalize_response`` as its finalizer.
    """
    update_source = self._get_stream(messages, session=session, options=options, **kwargs)
    return ResponseStream(update_source, finalizer=self._finalize_response)

async def _get_stream(
self,
messages: AgentRunInputs | None = None,
Expand Down Expand Up @@ -716,7 +730,12 @@ async def _get_stream(
self._structured_output = structured_output


class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[OptionsT]):
class ClaudeAgent(
AgentTelemetryLayer,
AgentMiddlewareLayer,
RawClaudeAgent[OptionsT],
Generic[OptionsT],
):
"""Claude Agent with OpenTelemetry instrumentation.

This is the recommended agent class for most use cases. It includes
Expand All @@ -736,3 +755,5 @@ class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[Options
response = await agent.run("Hello!")
print(response.text)
"""

pass
173 changes: 172 additions & 1 deletion python/packages/claude/tests/test_claude_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from agent_framework import AgentResponseUpdate, AgentSession, Content, Message, tool
from agent_framework import (
AgentContext,
AgentMiddleware,
AgentResponse,
AgentResponseUpdate,
AgentSession,
Content,
Message,
tool,
)
from agent_framework._settings import load_settings

from agent_framework_claude import ClaudeAgent, ClaudeAgentOptions, ClaudeAgentSettings
Expand Down Expand Up @@ -434,6 +443,168 @@ async def test_run_stream_raises_on_result_message_error(self) -> None:
assert "Model 'claude-sonnet-4.5' not found" in str(exc_info.value)


class TestClaudeAgentMiddleware:
"""Tests for ClaudeAgent AgentMiddleware support."""

@staticmethod
async def _create_async_generator(items: list[Any]) -> Any:
    """Yield every element of ``items``, in order, from an async generator."""
    for element in items:
        yield element

def _create_mock_client(self, messages: list[Any]) -> MagicMock:
    """Create a mock ClaudeSDKClient whose response stream yields ``messages``.

    Args:
        messages: Items that each ``receive_response()`` call yields in order.

    Returns:
        A ``MagicMock`` whose ``connect``, ``disconnect``, ``query``,
        ``set_model`` and ``set_permission_mode`` attributes are ``AsyncMock``
        instances and whose ``receive_response`` returns an async generator.
    """
    mock_client = MagicMock()
    mock_client.connect = AsyncMock()
    mock_client.disconnect = AsyncMock()
    mock_client.query = AsyncMock()
    mock_client.set_model = AsyncMock()
    mock_client.set_permission_mode = AsyncMock()
    # An async generator can be consumed only once. Build a fresh one per call
    # (side_effect) rather than sharing a single instance (return_value), so a
    # second receive_response() call does not see an already exhausted stream.
    mock_client.receive_response = MagicMock(
        side_effect=lambda *args, **kwargs: self._create_async_generator(messages)
    )
    return mock_client

async def test_run_executes_agent_middleware(self) -> None:
    """Verify that a class-based AgentMiddleware runs around a non-streaming run."""
    from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
    from claude_agent_sdk.types import StreamEvent

    # Scripted SDK output: one text delta, the assistant message, then the result.
    text_delta = StreamEvent(
        event={
            "type": "content_block_delta",
            "delta": {"type": "text_delta", "text": "Hello!"},
        },
        uuid="event-1",
        session_id="session-123",
    )
    assistant_reply = AssistantMessage(content=[TextBlock(text="Hello!")], model="claude-sonnet")
    run_result = ResultMessage(
        subtype="success",
        duration_ms=100,
        duration_api_ms=50,
        is_error=False,
        num_turns=1,
        session_id="session-123",
    )
    sdk_client = self._create_mock_client([text_delta, assistant_reply, run_result])
    recorded: list[str] = []

    class TrackingMiddleware(AgentMiddleware):
        async def process(self, context: AgentContext, call_next: Any) -> None:
            # Record a marker on each side of the downstream call.
            recorded.append("before")
            await call_next()
            recorded.append("after")

    with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=sdk_client):
        agent = ClaudeAgent(middleware=[TrackingMiddleware()])
        response = await agent.run("Hello")

    assert isinstance(response, AgentResponse)
    assert response.text == "Hello!"
    assert recorded == ["before", "after"]

async def test_run_stream_applies_agent_middleware_transform(self) -> None:
"""Test that middleware stream transform hooks are applied."""
from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
from claude_agent_sdk.types import StreamEvent

messages = [
StreamEvent(
event={
"type": "content_block_delta",
"delta": {"type": "text_delta", "text": "Streaming"},
},
uuid="event-1",
session_id="stream-session",
),
AssistantMessage(
content=[TextBlock(text="Streaming")],
model="claude-sonnet",
),
ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="stream-session",
),
]
mock_client = self._create_mock_client(messages)

class PrefixDeltaMiddleware(AgentMiddleware):
async def process(self, context: AgentContext, call_next: Any) -> None:
if context.stream:

def add_prefix(update: AgentResponseUpdate) -> AgentResponseUpdate:
return AgentResponseUpdate(
role=update.role,
contents=[Content.from_text(f"mw:{update.text}")],
response_id=update.response_id,
message_id=update.message_id,
raw_representation=update.raw_representation,
)

context.stream_transform_hooks.append(add_prefix)
await call_next()

with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
agent = ClaudeAgent(middleware=[PrefixDeltaMiddleware()])
updates: list[AgentResponseUpdate] = []
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No test covers passing middleware to run() at call time. The constructor-level middleware path is tested, but the per-run path added to the run() signature is dead code from a testing perspective.

Suggested change
updates: list[AgentResponseUpdate] = []
async def test_run_supports_per_run_middleware(self) -> None:
"""Test that middleware passed to run() at call time is executed."""
from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
from claude_agent_sdk.types import StreamEvent
messages = [
StreamEvent(
event={"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hi"}},
uuid="e1", session_id="s1",
),
AssistantMessage(content=[TextBlock(text="Hi")], model="claude-sonnet"),
ResultMessage(subtype="success", duration_ms=100, duration_api_ms=50,
is_error=False, num_turns=1, session_id="s1"),
]
mock_client = self._create_mock_client(messages)
calls: list[str] = []
async def run_middleware(context: AgentContext, call_next: Any) -> None:
calls.append("run_before")
await call_next()
calls.append("run_after")
with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
agent = ClaudeAgent()
response = await agent.run("Hello", middleware=[run_middleware])
assert isinstance(response, AgentResponse)
assert calls == ["run_before", "run_after"]

async for update in agent.run("Hello", stream=True):
updates.append(update)

assert len(updates) == 1
assert updates[0].text == "mw:Streaming"

async def test_run_supports_agent_middleware_callable(self) -> None:
    """Verify that a plain async function is accepted as agent middleware."""
    from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
    from claude_agent_sdk.types import StreamEvent

    # Scripted SDK output: one text delta, the assistant message, then the result.
    text_delta = StreamEvent(
        event={
            "type": "content_block_delta",
            "delta": {"type": "text_delta", "text": "Hello!"},
        },
        uuid="event-1",
        session_id="session-123",
    )
    assistant_reply = AssistantMessage(content=[TextBlock(text="Hello!")], model="claude-sonnet")
    run_result = ResultMessage(
        subtype="success",
        duration_ms=100,
        duration_api_ms=50,
        is_error=False,
        num_turns=1,
        session_id="session-123",
    )
    sdk_client = self._create_mock_client([text_delta, assistant_reply, run_result])
    recorded: list[str] = []

    async def tracking_middleware(context: AgentContext, call_next: Any) -> None:
        # Record a marker on each side of the downstream call.
        recorded.append("before")
        await call_next()
        recorded.append("after")

    with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=sdk_client):
        agent = ClaudeAgent(middleware=[tracking_middleware])
        response = await agent.run("Hello")

    assert isinstance(response, AgentResponse)
    assert response.text == "Hello!"
    assert recorded == ["before", "after"]


# region Test ClaudeAgent Session Management


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ResponseStream,
normalize_messages,
)
from agent_framework._middleware import AgentMiddlewareLayer
from agent_framework._settings import load_settings
from agent_framework._tools import FunctionTool, ToolTypes
from agent_framework._types import AgentRunInputs, normalize_tools
Expand Down Expand Up @@ -121,7 +122,7 @@ class GitHubCopilotOptions(TypedDict, total=False):
)


class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
"""A GitHub Copilot Agent.

This agent wraps the GitHub Copilot SDK to provide Copilot agentic capabilities
Expand Down Expand Up @@ -242,7 +243,7 @@ def __init__(
self._default_options = opts
self._started = False

async def __aenter__(self) -> GitHubCopilotAgent[OptionsT]:
async def __aenter__(self) -> RawGitHubCopilotAgent[OptionsT]:
"""Start the agent when entering async context."""
await self.start()
return self
Comment thread
dmytrostruk marked this conversation as resolved.
Outdated
Expand Down Expand Up @@ -302,7 +303,6 @@ def run(
*,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

options was previously an explicitly typed parameter in this overload and is now completely absent from the signature. AgentMiddlewareLayer._middleware_handler still passes options=context.options explicitly, which works at runtime (falls into **kwargs), but this is a type-breaking change for any typed callers of GitHubCopilotAgent.run().

stream: Literal[False] = False,
session: AgentSession | None = None,
options: OptionsT | None = None,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this shouldn't be removed, should it?

**kwargs: Any,
) -> Awaitable[AgentResponse]: ...

Expand All @@ -313,7 +313,6 @@ def run(
*,
stream: Literal[True],
session: AgentSession | None = None,
options: OptionsT | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ...

Expand All @@ -323,7 +322,6 @@ def run(
*,
stream: bool = False,
session: AgentSession | None = None,
options: OptionsT | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
"""Get a response from the agent.
Expand All @@ -338,8 +336,8 @@ def run(
Keyword Args:
stream: Whether to stream the response. Defaults to False.
session: The conversation session associated with the message(s).
options: Runtime options (model, timeout, etc.).
kwargs: Additional keyword arguments.
kwargs: Additional keyword arguments, including ``options`` for runtime options
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not reintroduce kwargs; we are trying to get rid of them. Use options instead.

(model, timeout, etc.).

Returns:
When stream=False: An Awaitable[AgentResponse].
Expand All @@ -348,16 +346,26 @@ def run(
Raises:
AgentException: If the request fails.
"""
options = cast(OptionsT | None, kwargs.pop("options", None))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cast() silently coerces whatever is in kwargs to OptionsT | None with no runtime check. If a caller passes options with a wrong type, the error will surface much later (or not at all). Consider validating the extracted value before use.

Suggested change
options = cast(OptionsT | None, kwargs.pop("options", None))
options = kwargs.pop("options", None)
if options is not None and not isinstance(options, (dict, type(None))):
raise TypeError(f"Expected options to be a mapping or None, got {type(options).__name__}")
options = cast(OptionsT | None, options)

if stream:
return self._run_stream_impl(messages=messages, session=session, options=options, **kwargs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue: middleware (containing combined function/chat middleware from AgentMiddlewareLayer) is captured by the named param in run() and never forwarded to _run_stream_impl, silently discarding any function/chat middleware.

Suggested change
return self._run_stream_impl(messages=messages, session=session, options=options, **kwargs)
return self._run_stream_impl(messages=messages, session=session, options=options, middleware=middleware, **kwargs)

return self._run_impl(messages=messages, session=session, options=options, **kwargs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

middleware not forwarded to _run_impl either — same silent drop for the non-streaming path.

Suggested change
return self._run_impl(messages=messages, session=session, options=options, **kwargs)
return self._run_impl(messages=messages, session=session, options=options, middleware=middleware, **kwargs)


def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
return AgentResponse.from_updates(updates)
def _run_stream_impl(
self,
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
options: OptionsT | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
return AgentResponse.from_updates(updates)

return ResponseStream(
self._stream_updates(messages=messages, session=session, options=options, **kwargs),
finalizer=_finalize,
)
return self._run_impl(messages=messages, session=session, options=options, **kwargs)
return ResponseStream(
self._stream_updates(messages=messages, session=session, options=options, **kwargs),
finalizer=_finalize,
)

async def _run_impl(
self,
Expand Down Expand Up @@ -640,3 +648,11 @@ async def _resume_session(self, session_id: str, streaming: bool) -> CopilotSess
config["mcp_servers"] = self._mcp_servers

return await self._client.resume_session(session_id, config)


class GitHubCopilotAgent(
AgentMiddlewareLayer,
RawGitHubCopilotAgent[OptionsT],
Comment thread
dmytrostruk marked this conversation as resolved.
Generic[OptionsT],
):
"""GitHub Copilot agent with AgentMiddleware support."""
Loading
Loading