From 4f00d5795f7cd8d292b14462ef3cce021896714c Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Tue, 10 Mar 2026 12:00:17 -0700 Subject: [PATCH 1/3] Added AgentMiddleware to Copilot and Claude agents --- .../claude/agent_framework_claude/_agent.py | 35 +++- .../claude/tests/test_claude_agent.py | 173 +++++++++++++++++- .../agent_framework_github_copilot/_agent.py | 44 +++-- .../tests/test_github_copilot_agent.py | 93 ++++++++++ 4 files changed, 323 insertions(+), 22 deletions(-) diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index 127e3647ee..f9c11b87bb 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -26,6 +26,7 @@ normalize_messages, normalize_tools, ) +from agent_framework._middleware import AgentMiddlewareLayer from agent_framework.exceptions import AgentException from agent_framework.observability import AgentTelemetryLayer from claude_agent_sdk import ( @@ -242,7 +243,7 @@ def __init__( name: Name of the agent. description: Description of the agent. context_providers: Context providers for the agent. - middleware: List of middleware. + middleware: List of AgentMiddleware. tools: Tools for the agent. Can be: - Strings for built-in tools (e.g., "Read", "Write", "Bash", "Glob") - Functions for custom tools @@ -581,6 +582,7 @@ def run( *, stream: Literal[False] = ..., session: AgentSession | None = None, + middleware: Sequence[AgentMiddlewareTypes] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -591,6 +593,7 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, + middleware: Sequence[AgentMiddlewareTypes] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... @@ -600,6 +603,7 @@ def run( *, stream: bool = False, session: AgentSession | None = None, + middleware: Sequence[AgentMiddlewareTypes] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Run the agent with the given messages. @@ -612,6 +616,7 @@ def run( returns an awaitable AgentResponse. session: The conversation session. If session has service_session_id set, the agent will resume that session. + middleware: Optional per-run AgentMiddleware applied on top of constructor middleware. kwargs: Additional keyword arguments including 'options' for runtime options (model, permission_mode can be changed per-request). @@ -620,15 +625,24 @@ def run( When stream=False: An Awaitable[AgentResponse] with the complete response. """ options = kwargs.pop("options", None) - response = ResponseStream( - self._get_stream(messages, session=session, options=options, **kwargs), - finalizer=self._finalize_response, - ) - + response = self._run_stream_impl(messages=messages, session=session, options=options, **kwargs) if stream: return response return response.get_final_response() + def _run_stream_impl( + self, + messages: AgentRunInputs | None = None, + *, + session: AgentSession | None = None, + options: OptionsT | MutableMapping[str, Any] | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: + return ResponseStream( + self._get_stream(messages, session=session, options=options, **kwargs), + finalizer=self._finalize_response, + ) + async def _get_stream( self, messages: AgentRunInputs | None = None, @@ -716,7 +730,12 @@ async def _get_stream( self._structured_output = structured_output -class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[OptionsT]): +class ClaudeAgent( + AgentTelemetryLayer, + AgentMiddlewareLayer, + RawClaudeAgent[OptionsT], + Generic[OptionsT], +): """Claude Agent with OpenTelemetry instrumentation. This is the recommended agent class for most use cases. It includes @@ -736,3 +755,5 @@ class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[Options response = await agent.run("Hello!") print(response.text) """ + + pass diff --git a/python/packages/claude/tests/test_claude_agent.py b/python/packages/claude/tests/test_claude_agent.py index e48a3b05d9..3175c68d5b 100644 --- a/python/packages/claude/tests/test_claude_agent.py +++ b/python/packages/claude/tests/test_claude_agent.py @@ -4,7 +4,16 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agent_framework import AgentResponseUpdate, AgentSession, Content, Message, tool +from agent_framework import ( + AgentContext, + AgentMiddleware, + AgentResponse, + AgentResponseUpdate, + AgentSession, + Content, + Message, + tool, +) from agent_framework._settings import load_settings from agent_framework_claude import ClaudeAgent, ClaudeAgentOptions, ClaudeAgentSettings @@ -434,6 +443,168 @@ async def test_run_stream_raises_on_result_message_error(self) -> None: assert "Model 'claude-sonnet-4.5' not found" in str(exc_info.value) +class TestClaudeAgentMiddleware: + """Tests for ClaudeAgent AgentMiddleware support.""" + + @staticmethod + async def _create_async_generator(items: list[Any]) -> Any: + """Helper to create async generator from list.""" + for item in items: + yield item + + def _create_mock_client(self, messages: list[Any]) -> MagicMock: + """Create a mock ClaudeSDKClient that yields given messages.""" + mock_client = MagicMock() + mock_client.connect = AsyncMock() + mock_client.disconnect = AsyncMock() + mock_client.query = AsyncMock() + mock_client.set_model = AsyncMock() + mock_client.set_permission_mode = AsyncMock() + mock_client.receive_response = MagicMock(return_value=self._create_async_generator(messages)) + return mock_client + + async def test_run_executes_agent_middleware(self) -> None: + """Test that non-streaming run executes AgentMiddleware.""" + from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock + from claude_agent_sdk.types import StreamEvent + + messages = [ + StreamEvent( + event={ + "type": "content_block_delta", + "delta": {"type": "text_delta", "text": "Hello!"}, + }, + uuid="event-1", + session_id="session-123", + ), + AssistantMessage( + content=[TextBlock(text="Hello!")], + model="claude-sonnet", + ), + ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=50, + is_error=False, + num_turns=1, + session_id="session-123", + ), + ] + mock_client = self._create_mock_client(messages) + calls: list[str] = [] + + class TrackingMiddleware(AgentMiddleware): + async def process(self, context: AgentContext, call_next: Any) -> None: + calls.append("before") + await call_next() + calls.append("after") + + with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client): + agent = ClaudeAgent(middleware=[TrackingMiddleware()]) + response = await agent.run("Hello") + + assert isinstance(response, AgentResponse) + assert response.text == "Hello!" + assert calls == ["before", "after"] + + async def test_run_stream_applies_agent_middleware_transform(self) -> None: + """Test that middleware stream transform hooks are applied.""" + from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock + from claude_agent_sdk.types import StreamEvent + + messages = [ + StreamEvent( + event={ + "type": "content_block_delta", + "delta": {"type": "text_delta", "text": "Streaming"}, + }, + uuid="event-1", + session_id="stream-session", + ), + AssistantMessage( + content=[TextBlock(text="Streaming")], + model="claude-sonnet", + ), + ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=50, + is_error=False, + num_turns=1, + session_id="stream-session", + ), + ] + mock_client = self._create_mock_client(messages) + + class PrefixDeltaMiddleware(AgentMiddleware): + async def process(self, context: AgentContext, call_next: Any) -> None: + if context.stream: + + def add_prefix(update: AgentResponseUpdate) -> AgentResponseUpdate: + return AgentResponseUpdate( + role=update.role, + contents=[Content.from_text(f"mw:{update.text}")], + response_id=update.response_id, + message_id=update.message_id, + raw_representation=update.raw_representation, + ) + + context.stream_transform_hooks.append(add_prefix) + await call_next() + + with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client): + agent = ClaudeAgent(middleware=[PrefixDeltaMiddleware()]) + updates: list[AgentResponseUpdate] = [] + async for update in agent.run("Hello", stream=True): + updates.append(update) + + assert len(updates) == 1 + assert updates[0].text == "mw:Streaming" + + async def test_run_supports_agent_middleware_callable(self) -> None: + """Test that function-based AgentMiddleware callables are supported.""" + from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock + from claude_agent_sdk.types import StreamEvent + + messages = [ + StreamEvent( + event={ + "type": "content_block_delta", + "delta": {"type": "text_delta", "text": "Hello!"}, + }, + uuid="event-1", + session_id="session-123", + ), + AssistantMessage( + content=[TextBlock(text="Hello!")], + model="claude-sonnet", + ), + ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=50, + is_error=False, + num_turns=1, + session_id="session-123", + ), + ] + mock_client = self._create_mock_client(messages) + calls: list[str] = [] + + async def tracking_middleware(context: AgentContext, call_next: Any) -> None: + calls.append("before") + await call_next() + calls.append("after") + + with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client): + agent = ClaudeAgent(middleware=[tracking_middleware]) + response = await agent.run("Hello") + + assert isinstance(response, AgentResponse) + assert response.text == "Hello!" + assert calls == ["before", "after"] + + # region Test ClaudeAgent Session Management diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 7fa7d0dce4..9ba26a90f8 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -21,6 +21,7 @@ ResponseStream, normalize_messages, ) +from agent_framework._middleware import AgentMiddlewareLayer from agent_framework._settings import load_settings from agent_framework._tools import FunctionTool, ToolTypes from agent_framework._types import AgentRunInputs, normalize_tools @@ -121,7 +122,7 @@ class GitHubCopilotOptions(TypedDict, total=False): ) -class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]): +class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]): """A GitHub Copilot Agent. This agent wraps the GitHub Copilot SDK to provide Copilot agentic capabilities @@ -242,7 +243,7 @@ def __init__( self._default_options = opts self._started = False - async def __aenter__(self) -> GitHubCopilotAgent[OptionsT]: + async def __aenter__(self) -> RawGitHubCopilotAgent[OptionsT]: """Start the agent when entering async context.""" await self.start() return self @@ -302,7 +303,6 @@ def run( *, stream: Literal[False] = False, session: AgentSession | None = None, - options: OptionsT | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse]: ... @@ -313,7 +313,6 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, - options: OptionsT | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... @@ -323,7 +322,6 @@ def run( *, stream: bool = False, session: AgentSession | None = None, - options: OptionsT | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: """Get a response from the agent. @@ -338,8 +336,8 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). - options: Runtime options (model, timeout, etc.). - kwargs: Additional keyword arguments. + kwargs: Additional keyword arguments, including ``options`` for runtime options + (model, timeout, etc.). Returns: When stream=False: An Awaitable[AgentResponse]. @@ -348,16 +346,26 @@ def run( Raises: AgentException: If the request fails. """ + options = cast(OptionsT | None, kwargs.pop("options", None)) if stream: + return self._run_stream_impl(messages=messages, session=session, options=options, **kwargs) + return self._run_impl(messages=messages, session=session, options=options, **kwargs) - def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: - return AgentResponse.from_updates(updates) + def _run_stream_impl( + self, + messages: AgentRunInputs | None = None, + *, + session: AgentSession | None = None, + options: OptionsT | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: + return AgentResponse.from_updates(updates) - return ResponseStream( - self._stream_updates(messages=messages, session=session, options=options, **kwargs), - finalizer=_finalize, - ) - return self._run_impl(messages=messages, session=session, options=options, **kwargs) + return ResponseStream( + self._stream_updates(messages=messages, session=session, options=options, **kwargs), + finalizer=_finalize, + ) async def _run_impl( self, @@ -640,3 +648,11 @@ async def _resume_session(self, session_id: str, streaming: bool) -> CopilotSess config["mcp_servers"] = self._mcp_servers return await self._client.resume_session(session_id, config) + + +class GitHubCopilotAgent( + AgentMiddlewareLayer, + RawGitHubCopilotAgent[OptionsT], + Generic[OptionsT], +): + """GitHub Copilot agent with AgentMiddleware support.""" diff --git a/python/packages/github_copilot/tests/test_github_copilot_agent.py b/python/packages/github_copilot/tests/test_github_copilot_agent.py index ed8c089fa3..bda43676d6 100644 --- a/python/packages/github_copilot/tests/test_github_copilot_agent.py +++ b/python/packages/github_copilot/tests/test_github_copilot_agent.py @@ -8,6 +8,8 @@ import pytest from agent_framework import ( + AgentContext, + AgentMiddleware, AgentResponse, AgentResponseUpdate, AgentSession, @@ -459,6 +461,97 @@ def mock_on(handler: Any) -> Any: mock_client.start.assert_called_once() +class TestGitHubCopilotAgentMiddleware: + """Test cases for AgentMiddleware behavior.""" + + async def test_run_executes_agent_middleware( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that non-streaming run executes AgentMiddleware.""" + mock_session.send_and_wait.return_value = assistant_message_event + calls: list[str] = [] + + class TrackingMiddleware(AgentMiddleware): + async def process(self, context: AgentContext, call_next: Any) -> None: + calls.append("before") + await call_next() + calls.append("after") + + agent = GitHubCopilotAgent(client=mock_client, middleware=[TrackingMiddleware()]) + response = await agent.run("Hello") + + assert isinstance(response, AgentResponse) + assert response.text == "Test response" + assert calls == ["before", "after"] + + async def test_run_stream_applies_agent_middleware_transform( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_delta_event: SessionEvent, + session_idle_event: SessionEvent, + ) -> None: + """Test that streaming middleware transform hooks are applied.""" + events = [assistant_delta_event, session_idle_event] + + def mock_on(handler: Any) -> Any: + for event in events: + handler(event) + return lambda: None + + mock_session.on = mock_on + + class PrefixDeltaMiddleware(AgentMiddleware): + async def process(self, context: AgentContext, call_next: Any) -> None: + if context.stream: + + def add_prefix(update: AgentResponseUpdate) -> AgentResponseUpdate: + return AgentResponseUpdate( + role=update.role, + contents=[Content.from_text(f"mw:{update.text}")], + response_id=update.response_id, + message_id=update.message_id, + raw_representation=update.raw_representation, + ) + + context.stream_transform_hooks.append(add_prefix) + await call_next() + + agent = GitHubCopilotAgent(client=mock_client, middleware=[PrefixDeltaMiddleware()]) + + responses: list[AgentResponseUpdate] = [] + async for update in agent.run("Hello", stream=True): + responses.append(update) + + assert len(responses) == 1 + assert responses[0].text == "mw:Hello" + + async def test_run_supports_agent_middleware_callable( + self, + mock_client: MagicMock, + mock_session: MagicMock, + assistant_message_event: SessionEvent, + ) -> None: + """Test that function-based AgentMiddleware callables are supported.""" + mock_session.send_and_wait.return_value = assistant_message_event + calls: list[str] = [] + + async def tracking_middleware(context: AgentContext, call_next: Any) -> None: + calls.append("before") + await call_next() + calls.append("after") + + agent = GitHubCopilotAgent(client=mock_client, middleware=[tracking_middleware]) + response = await agent.run("Hello") + + assert isinstance(response, AgentResponse) + assert response.text == "Test response" + assert calls == ["before", "after"] + + class TestGitHubCopilotAgentSessionManagement: """Test cases for session management.""" From 5a6c021b9dd9726fc6d48021a2bad68ae960fa99 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Tue, 10 Mar 2026 12:15:19 -0700 Subject: [PATCH 2/3] Fixed CI --- .../github_copilot/agent_framework_github_copilot/_agent.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 9ba26a90f8..f1a2cef369 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -303,6 +303,7 @@ def run( *, stream: Literal[False] = False, session: AgentSession | None = None, + middleware: Sequence[AgentMiddlewareTypes] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse]: ... @@ -313,6 +314,7 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, + middleware: Sequence[AgentMiddlewareTypes] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... @@ -322,6 +324,7 @@ def run( *, stream: bool = False, session: AgentSession | None = None, + middleware: Sequence[AgentMiddlewareTypes] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: """Get a response from the agent. @@ -336,6 +339,7 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). + middleware: Runtime middleware parameter accepted for compatibility with middleware layer routing. kwargs: Additional keyword arguments, including ``options`` for runtime options (model, timeout, etc.). From 8d41f39328a8fa1a08aab8582a46ad9e6b3795ef Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Tue, 10 Mar 2026 12:24:40 -0700 Subject: [PATCH 3/3] Resolved comments --- .../agent_framework_github_copilot/_agent.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index f1a2cef369..f0ed43cd54 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -26,6 +26,7 @@ from agent_framework._tools import FunctionTool, ToolTypes from agent_framework._types import AgentRunInputs, normalize_tools from agent_framework.exceptions import AgentException +from agent_framework.observability import AgentTelemetryLayer from copilot import CopilotClient, CopilotSession from copilot.generated.session_events import PermissionRequest, SessionEvent, SessionEventType from copilot.types import ( @@ -42,9 +43,9 @@ from copilot.types import Tool as CopilotTool if sys.version_info >= (3, 13): - from typing import TypeVar + from typing import Self, TypeVar else: - from typing_extensions import TypeVar + from typing_extensions import Self, TypeVar DEFAULT_TIMEOUT_SECONDS: float = 60.0 @@ -243,7 +244,7 @@ def __init__( self._default_options = opts self._started = False - async def __aenter__(self) -> RawGitHubCopilotAgent[OptionsT]: + async def __aenter__(self) -> Self: """Start the agent when entering async context.""" await self.start() return self @@ -655,8 +656,9 @@ async def _resume_session(self, session_id: str, streaming: bool) -> CopilotSess class GitHubCopilotAgent( + AgentTelemetryLayer, AgentMiddlewareLayer, RawGitHubCopilotAgent[OptionsT], Generic[OptionsT], ): - """GitHub Copilot agent with AgentMiddleware support.""" + """GitHub Copilot agent with middleware and telemetry support."""