-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Python: Added AgentMiddleware to Copilot and Claude agents #4601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
4f00d57
1da740b
5a6c021
60a8550
8d41f39
5b10bd4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -26,6 +26,7 @@ | |||||
| normalize_messages, | ||||||
| normalize_tools, | ||||||
| ) | ||||||
| from agent_framework._middleware import AgentMiddlewareLayer | ||||||
| from agent_framework.exceptions import AgentException | ||||||
| from agent_framework.observability import AgentTelemetryLayer | ||||||
| from claude_agent_sdk import ( | ||||||
|
|
@@ -242,7 +243,7 @@ def __init__( | |||||
| name: Name of the agent. | ||||||
| description: Description of the agent. | ||||||
| context_providers: Context providers for the agent. | ||||||
| middleware: List of middleware. | ||||||
| middleware: List of AgentMiddleware. | ||||||
| tools: Tools for the agent. Can be: | ||||||
| - Strings for built-in tools (e.g., "Read", "Write", "Bash", "Glob") | ||||||
| - Functions for custom tools | ||||||
|
|
@@ -581,6 +582,7 @@ def run( | |||||
| *, | ||||||
| stream: Literal[False] = ..., | ||||||
| session: AgentSession | None = None, | ||||||
| middleware: Sequence[AgentMiddlewareTypes] | None = None, | ||||||
| **kwargs: Any, | ||||||
| ) -> Awaitable[AgentResponse[Any]]: ... | ||||||
|
|
||||||
|
|
@@ -591,6 +593,7 @@ def run( | |||||
| *, | ||||||
| stream: Literal[True], | ||||||
| session: AgentSession | None = None, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same leaky abstraction as the non-streaming overload: |
||||||
| middleware: Sequence[AgentMiddlewareTypes] | None = None, | ||||||
| **kwargs: Any, | ||||||
| ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... | ||||||
|
|
||||||
|
|
@@ -600,6 +603,7 @@ def run( | |||||
| *, | ||||||
| stream: bool = False, | ||||||
| session: AgentSession | None = None, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same issue on the base overload. Dropping the explicit |
||||||
| middleware: Sequence[AgentMiddlewareTypes] | None = None, | ||||||
| **kwargs: Any, | ||||||
| ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: | ||||||
| """Run the agent with the given messages. | ||||||
|
|
@@ -612,6 +616,7 @@ def run( | |||||
| returns an awaitable AgentResponse. | ||||||
| session: The conversation session. If session has service_session_id set, | ||||||
| the agent will resume that session. | ||||||
| middleware: Optional per-run AgentMiddleware applied on top of constructor middleware. | ||||||
| kwargs: Additional keyword arguments including 'options' for runtime options | ||||||
| (model, permission_mode can be changed per-request). | ||||||
|
dmytrostruk marked this conversation as resolved.
|
||||||
|
|
||||||
|
|
@@ -620,15 +625,24 @@ def run( | |||||
| When stream=False: An Awaitable[AgentResponse] with the complete response. | ||||||
| """ | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| options = kwargs.pop("options", None) | ||||||
| response = ResponseStream( | ||||||
| self._get_stream(messages, session=session, options=options, **kwargs), | ||||||
| finalizer=self._finalize_response, | ||||||
| ) | ||||||
|
|
||||||
| response = self._run_stream_impl(messages=messages, session=session, options=options, **kwargs) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this change needed? we are already wrapping a call to another function, no real need to do that twice (_get_stream and _run_stream_impl) |
||||||
| if stream: | ||||||
| return response | ||||||
| return response.get_final_response() | ||||||
|
dmytrostruk marked this conversation as resolved.
|
||||||
|
|
||||||
| def _run_stream_impl( | ||||||
| self, | ||||||
| messages: AgentRunInputs | None = None, | ||||||
| *, | ||||||
| session: AgentSession | None = None, | ||||||
| options: OptionsT | MutableMapping[str, Any] | None = None, | ||||||
| **kwargs: Any, | ||||||
| ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: | ||||||
| return ResponseStream( | ||||||
| self._get_stream(messages, session=session, options=options, **kwargs), | ||||||
| finalizer=self._finalize_response, | ||||||
| ) | ||||||
|
|
||||||
| async def _get_stream( | ||||||
| self, | ||||||
| messages: AgentRunInputs | None = None, | ||||||
|
|
@@ -716,7 +730,12 @@ async def _get_stream( | |||||
| self._structured_output = structured_output | ||||||
|
|
||||||
|
|
||||||
| class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[OptionsT]): | ||||||
| class ClaudeAgent( | ||||||
| AgentTelemetryLayer, | ||||||
| AgentMiddlewareLayer, | ||||||
| RawClaudeAgent[OptionsT], | ||||||
| Generic[OptionsT], | ||||||
| ): | ||||||
| """Claude Agent with OpenTelemetry instrumentation. | ||||||
|
|
||||||
| This is the recommended agent class for most use cases. It includes | ||||||
|
|
@@ -736,3 +755,5 @@ class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[Options | |||||
| response = await agent.run("Hello!") | ||||||
| print(response.text) | ||||||
| """ | ||||||
|
|
||||||
| pass | ||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -4,7 +4,16 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from unittest.mock import AsyncMock, MagicMock, patch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from agent_framework import AgentResponseUpdate, AgentSession, Content, Message, tool | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from agent_framework import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AgentContext, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AgentMiddleware, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AgentResponse, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AgentResponseUpdate, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AgentSession, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Content, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Message, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tool, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from agent_framework._settings import load_settings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from agent_framework_claude import ClaudeAgent, ClaudeAgentOptions, ClaudeAgentSettings | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -434,6 +443,168 @@ async def test_run_stream_raises_on_result_message_error(self) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert "Model 'claude-sonnet-4.5' not found" in str(exc_info.value) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class TestClaudeAgentMiddleware: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Tests for ClaudeAgent AgentMiddleware support.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def _create_async_generator(items: list[Any]) -> Any: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Helper to create async generator from list.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for item in items: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| yield item | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _create_mock_client(self, messages: list[Any]) -> MagicMock: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Create a mock ClaudeSDKClient that yields given messages.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_client = MagicMock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_client.connect = AsyncMock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_client.disconnect = AsyncMock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_client.query = AsyncMock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_client.set_model = AsyncMock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_client.set_permission_mode = AsyncMock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_client.receive_response = MagicMock(return_value=self._create_async_generator(messages)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return mock_client | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def test_run_executes_agent_middleware(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test that non-streaming run executes AgentMiddleware.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claude_agent_sdk.types import StreamEvent | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| messages = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| StreamEvent( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| event={ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "type": "content_block_delta", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "delta": {"type": "text_delta", "text": "Hello!"}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uuid="event-1", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| session_id="session-123", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AssistantMessage( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| content=[TextBlock(text="Hello!")], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model="claude-sonnet", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ResultMessage( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| subtype="success", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| duration_ms=100, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| duration_api_ms=50, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_error=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_turns=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| session_id="session-123", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_client = self._create_mock_client(messages) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| calls: list[str] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class TrackingMiddleware(AgentMiddleware): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def process(self, context: AgentContext, call_next: Any) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| calls.append("before") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await call_next() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| calls.append("after") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| agent = ClaudeAgent(middleware=[TrackingMiddleware()]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response = await agent.run("Hello") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert isinstance(response, AgentResponse) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert response.text == "Hello!" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert calls == ["before", "after"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def test_run_stream_applies_agent_middleware_transform(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test that middleware stream transform hooks are applied.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claude_agent_sdk.types import StreamEvent | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| messages = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| StreamEvent( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| event={ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "type": "content_block_delta", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "delta": {"type": "text_delta", "text": "Streaming"}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uuid="event-1", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| session_id="stream-session", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AssistantMessage( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| content=[TextBlock(text="Streaming")], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model="claude-sonnet", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ResultMessage( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| subtype="success", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| duration_ms=100, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| duration_api_ms=50, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_error=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_turns=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| session_id="stream-session", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_client = self._create_mock_client(messages) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class PrefixDeltaMiddleware(AgentMiddleware): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def process(self, context: AgentContext, call_next: Any) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if context.stream: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def add_prefix(update: AgentResponseUpdate) -> AgentResponseUpdate: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return AgentResponseUpdate( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| role=update.role, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| contents=[Content.from_text(f"mw:{update.text}")], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response_id=update.response_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| message_id=update.message_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raw_representation=update.raw_representation, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| context.stream_transform_hooks.append(add_prefix) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await call_next() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| agent = ClaudeAgent(middleware=[PrefixDeltaMiddleware()]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| updates: list[AgentResponseUpdate] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No test covers passing
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async for update in agent.run("Hello", stream=True): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| updates.append(update) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert len(updates) == 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert updates[0].text == "mw:Streaming" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def test_run_supports_agent_middleware_callable(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test that function-based AgentMiddleware callables are supported.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from claude_agent_sdk.types import StreamEvent | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| messages = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| StreamEvent( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| event={ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "type": "content_block_delta", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "delta": {"type": "text_delta", "text": "Hello!"}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uuid="event-1", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| session_id="session-123", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| AssistantMessage( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| content=[TextBlock(text="Hello!")], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model="claude-sonnet", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ResultMessage( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| subtype="success", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| duration_ms=100, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| duration_api_ms=50, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_error=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_turns=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| session_id="session-123", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_client = self._create_mock_client(messages) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| calls: list[str] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def tracking_middleware(context: AgentContext, call_next: Any) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| calls.append("before") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await call_next() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| calls.append("after") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| agent = ClaudeAgent(middleware=[tracking_middleware]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| response = await agent.run("Hello") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert isinstance(response, AgentResponse) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert response.text == "Hello!" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert calls == ["before", "after"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # region Test ClaudeAgent Session Management | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -21,6 +21,7 @@ | |||||||||||
| ResponseStream, | ||||||||||||
| normalize_messages, | ||||||||||||
| ) | ||||||||||||
| from agent_framework._middleware import AgentMiddlewareLayer | ||||||||||||
| from agent_framework._settings import load_settings | ||||||||||||
| from agent_framework._tools import FunctionTool, ToolTypes | ||||||||||||
| from agent_framework._types import AgentRunInputs, normalize_tools | ||||||||||||
|
|
@@ -121,7 +122,7 @@ class GitHubCopilotOptions(TypedDict, total=False): | |||||||||||
| ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]): | ||||||||||||
| class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]): | ||||||||||||
| """A GitHub Copilot Agent. | ||||||||||||
|
|
||||||||||||
| This agent wraps the GitHub Copilot SDK to provide Copilot agentic capabilities | ||||||||||||
|
|
@@ -242,7 +243,7 @@ def __init__( | |||||||||||
| self._default_options = opts | ||||||||||||
| self._started = False | ||||||||||||
|
|
||||||||||||
| async def __aenter__(self) -> GitHubCopilotAgent[OptionsT]: | ||||||||||||
| async def __aenter__(self) -> RawGitHubCopilotAgent[OptionsT]: | ||||||||||||
| """Start the agent when entering async context.""" | ||||||||||||
| await self.start() | ||||||||||||
| return self | ||||||||||||
|
dmytrostruk marked this conversation as resolved.
Outdated
|
||||||||||||
|
|
@@ -302,7 +303,6 @@ def run( | |||||||||||
| *, | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||
| stream: Literal[False] = False, | ||||||||||||
| session: AgentSession | None = None, | ||||||||||||
| options: OptionsT | None = None, | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this shouldn't be removed, should it? |
||||||||||||
| **kwargs: Any, | ||||||||||||
| ) -> Awaitable[AgentResponse]: ... | ||||||||||||
|
|
||||||||||||
|
|
@@ -313,7 +313,6 @@ def run( | |||||||||||
| *, | ||||||||||||
| stream: Literal[True], | ||||||||||||
| session: AgentSession | None = None, | ||||||||||||
| options: OptionsT | None = None, | ||||||||||||
| **kwargs: Any, | ||||||||||||
| ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... | ||||||||||||
|
|
||||||||||||
|
|
@@ -323,7 +322,6 @@ def run( | |||||||||||
| *, | ||||||||||||
| stream: bool = False, | ||||||||||||
| session: AgentSession | None = None, | ||||||||||||
| options: OptionsT | None = None, | ||||||||||||
| **kwargs: Any, | ||||||||||||
| ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: | ||||||||||||
| """Get a response from the agent. | ||||||||||||
|
|
@@ -338,8 +336,8 @@ def run( | |||||||||||
| Keyword Args: | ||||||||||||
| stream: Whether to stream the response. Defaults to False. | ||||||||||||
| session: The conversation session associated with the message(s). | ||||||||||||
| options: Runtime options (model, timeout, etc.). | ||||||||||||
| kwargs: Additional keyword arguments. | ||||||||||||
| kwargs: Additional keyword arguments, including ``options`` for runtime options | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should not reintroduce kwargs, we are trying to get rid of them, use |
||||||||||||
| (model, timeout, etc.). | ||||||||||||
|
|
||||||||||||
| Returns: | ||||||||||||
| When stream=False: An Awaitable[AgentResponse]. | ||||||||||||
|
|
@@ -348,16 +346,26 @@ def run( | |||||||||||
| Raises: | ||||||||||||
| AgentException: If the request fails. | ||||||||||||
| """ | ||||||||||||
| options = cast(OptionsT | None, kwargs.pop("options", None)) | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
| if stream: | ||||||||||||
| return self._run_stream_impl(messages=messages, session=session, options=options, **kwargs) | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same issue:
Suggested change
|
||||||||||||
| return self._run_impl(messages=messages, session=session, options=options, **kwargs) | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
|
||||||||||||
| def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: | ||||||||||||
| return AgentResponse.from_updates(updates) | ||||||||||||
| def _run_stream_impl( | ||||||||||||
| self, | ||||||||||||
| messages: AgentRunInputs | None = None, | ||||||||||||
| *, | ||||||||||||
| session: AgentSession | None = None, | ||||||||||||
| options: OptionsT | None = None, | ||||||||||||
| **kwargs: Any, | ||||||||||||
| ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: | ||||||||||||
| def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: | ||||||||||||
| return AgentResponse.from_updates(updates) | ||||||||||||
|
|
||||||||||||
| return ResponseStream( | ||||||||||||
| self._stream_updates(messages=messages, session=session, options=options, **kwargs), | ||||||||||||
| finalizer=_finalize, | ||||||||||||
| ) | ||||||||||||
| return self._run_impl(messages=messages, session=session, options=options, **kwargs) | ||||||||||||
| return ResponseStream( | ||||||||||||
| self._stream_updates(messages=messages, session=session, options=options, **kwargs), | ||||||||||||
| finalizer=_finalize, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| async def _run_impl( | ||||||||||||
| self, | ||||||||||||
|
|
@@ -640,3 +648,11 @@ async def _resume_session(self, session_id: str, streaming: bool) -> CopilotSess | |||||||||||
| config["mcp_servers"] = self._mcp_servers | ||||||||||||
|
|
||||||||||||
| return await self._client.resume_session(session_id, config) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class GitHubCopilotAgent( | ||||||||||||
| AgentMiddlewareLayer, | ||||||||||||
| RawGitHubCopilotAgent[OptionsT], | ||||||||||||
|
dmytrostruk marked this conversation as resolved.
|
||||||||||||
| Generic[OptionsT], | ||||||||||||
| ): | ||||||||||||
| """GitHub Copilot agent with AgentMiddleware support.""" | ||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Declaring
middlewareexplicitly in RawClaudeAgent's overloads leaks AgentMiddlewareLayer's routing concern into the raw class. The parameter is accepted and immediately discarded (never forwarded to _run_stream_impl). Since it arrives through **kwargs anyway when AgentMiddlewareLayer calls super().run(), it can simply be absorbed there without an explicit declaration.