From e8399c207a7bf30a59298a1c2de05b3cf90af008 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 10 Mar 2026 10:37:08 +0100 Subject: [PATCH 01/13] Python: clean up kwargs across agents, chat clients, tools, and sessions (#3642) Audit and refactor public **kwargs usage across core agents, chat clients, tools, sessions, and provider packages per the migration strategy codified in CODING_STANDARD.md. Key changes: - Add explicit runtime buckets: function_invocation_kwargs and client_kwargs on RawAgent.run() and chat client get_response() layers. - Refactor FunctionTool to prefer explicit ctx: FunctionInvocationContext injection; legacy **kwargs tools still work via _forward_runtime_kwargs. - Refactor Agent.as_tool() to use direct JSON schema, always-streaming wrapper, approval_mode parameter, and UserInputRequiredException propagation (integrates PR #4568 behavior). - Remove implicit session bleeding into FunctionInvocationContext; tools that need a session must receive it via function_invocation_kwargs. - Lower chat-client layers after FunctionInvocationLayer accept only compatibility **kwargs (client_kwargs flattened, function_invocation_kwargs ignored). - Add layered docstring composition from Raw... implementations via _docstrings.py helper. - Clean up provider constructors to use explicit additional_properties. - Deprecation warnings on legacy direct kwargs paths. - Update samples, tests, and typing across all 23 packages. Resolves #3642 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- docs/decisions/0001-agent-run-response.md | 10 +- python/CODING_STANDARD.md | 5 + .../a2a/agent_framework_a2a/_agent.py | 18 +- .../ag-ui/agent_framework_ag_ui/_client.py | 3 - python/packages/ag-ui/tests/ag_ui/conftest.py | 6 +- .../ag_ui/test_agent_wrapper_comprehensive.py | 20 +- .../agent_framework_anthropic/_chat_client.py | 6 +- .../tests/test_aisearch_context_provider.py | 12 + .../agent_framework_azure_ai/_chat_client.py | 6 +- .../agent_framework_azure_ai/_client.py | 12 +- .../_embedding_client.py | 8 +- .../_history_provider.py | 17 +- .../agent_framework_bedrock/_chat_client.py | 6 +- .../_embedding_client.py | 8 +- .../claude/agent_framework_claude/_agent.py | 28 +- .../agent_framework_copilotstudio/_agent.py | 10 +- .../packages/core/agent_framework/__init__.py | 2 + .../packages/core/agent_framework/_agents.py | 336 ++++++++++++------ .../packages/core/agent_framework/_clients.py | 115 ++++-- .../core/agent_framework/_docstrings.py | 85 +++++ .../core/agent_framework/_middleware.py | 86 ++++- .../core/agent_framework/_sessions.py | 21 +- .../packages/core/agent_framework/_tools.py | 222 ++++++++++-- .../packages/core/agent_framework/_types.py | 6 +- .../agent_framework/azure/_chat_client.py | 6 +- .../core/agent_framework/exceptions.py | 28 ++ .../core/agent_framework/observability.py | 62 +++- .../agent_framework/openai/_chat_client.py | 150 +++++++- .../packages/core/tests/core/test_agents.py | 198 +++++++++-- .../core/test_as_tool_kwargs_propagation.py | 212 ++++++----- .../packages/core/tests/core/test_clients.py | 57 +++ .../core/tests/core/test_embedding_client.py | 7 + .../core/test_function_invocation_logic.py | 41 +++ .../test_kwargs_propagation_to_ai_function.py | 62 ++++ .../packages/core/tests/core/test_sessions.py | 4 +- python/packages/core/tests/core/test_tools.py | 123 +++++++ .../agent_framework_durabletask/_executors.py | 16 +- .../agent_framework_durabletask/_models.py | 53 +-- .../agent_framework_durabletask/_shim.py | 18 +- .../tests/test_agent_session_id.py | 14 +- .../packages/durabletask/tests/test_client.py | 9 - .../tests/test_orchestration_context.py | 11 - .../packages/durabletask/tests/test_shim.py | 11 - .../_foundry_local_client.py | 7 +- .../agent_framework_github_copilot/_agent.py | 11 +- .../agent_framework_ollama/_chat_client.py | 6 +- .../_embedding_client.py | 8 +- .../_history_provider.py | 19 +- .../agent_as_tool_with_session_propagation.py | 61 +++- .../tools/function_tool_with_kwargs.py | 32 +- .../function_tool_with_session_injection.py | 41 ++- 51 files changed, 1796 insertions(+), 519 deletions(-) create mode 100644 python/packages/core/agent_framework/_docstrings.py diff --git a/docs/decisions/0001-agent-run-response.md b/docs/decisions/0001-agent-run-response.md index 12724aca3a..6ffebe7e4f 100644 --- a/docs/decisions/0001-agent-run-response.md +++ b/docs/decisions/0001-agent-run-response.md @@ -4,8 +4,8 @@ status: accepted contact: westey-m date: 2025-07-10 {YYYY-MM-DD when the decision was last updated} deciders: sergeymenshykh, markwallace, rbarreto, dmytrostruk, westey-m, eavanvalkenburg, stephentoub -consulted: -informed: +consulted: +informed: --- # Agent Run Responses Design @@ -64,7 +64,7 @@ Approaches observed from the compared SDKs: | AutoGen | **Approach 1** Separates messages into Agent-Agent (maps to Primary) and Internal (maps to Secondary) and these are returned as separate properties on the agent response object. See [types of messages](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/messages.html#types-of-messages) and [Response](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.base.html#autogen_agentchat.base.Response) | **Approach 2** Returns a stream of internal events and the last item is a Response object. See [ChatAgent.on_messages_stream](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.base.html#autogen_agentchat.base.ChatAgent.on_messages_stream) | | OpenAI Agent SDK | **Approach 1** Separates new_items (Primary+Secondary) from final output (Primary) as separate properties on the [RunResult](https://github.com/openai/openai-agents-python/blob/main/src/agents/result.py#L39) | **Approach 1** Similar to non-streaming, has a way of streaming updates via a method on the response object which includes all data, and then a separate final output property on the response object which is populated only when the run is complete. See [RunResultStreaming](https://github.com/openai/openai-agents-python/blob/main/src/agents/result.py#L136) | | Google ADK | **Approach 2** [Emits events](https://google.github.io/adk-docs/runtime/#step-by-step-breakdown) with [FinalResponse](https://github.com/google/adk-java/blob/main/core/src/main/java/com/google/adk/events/Event.java#L232) true (Primary) / false (Secondary) and callers have to filter out those with false to get just the final response message | **Approach 2** Similar to non-streaming except [events](https://google.github.io/adk-docs/runtime/#streaming-vs-non-streaming-output-partialtrue) are emitted with [Partial](https://github.com/google/adk-java/blob/main/core/src/main/java/com/google/adk/events/Event.java#L133) true to indicate that they are streaming messages. A final non partial event is also emitted. | -| AWS (Strands) | **Approach 3** Returns an [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/#agentresult) (Primary) with messages and a reason for the run's completion. | **Approach 2** [Streams events](https://strandsagents.com/docs/user-guide/concepts/streaming/) (Primary+Secondary) including, response text, current_tool_use, even data from "callbacks" (strands plugins) | +| AWS (Strands) | **Approach 3** Returns an [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/) (Primary) with messages and a reason for the run's completion. | **Approach 2** [Streams events](https://strandsagents.com/docs/api/python/strands.agent.agent/) (Primary+Secondary) including, response text, current_tool_use, even data from "callbacks" (strands plugins) | | LangGraph | **Approach 2** A mixed list of all [messages](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) | **Approach 2** A mixed list of all [messages](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) | | Agno | **Combination of various approaches** Returns a [RunResponse](https://docs.agno.com/reference/agents/run-response) object with text content, messages (essentially chat history including inputs and instructions), reasoning and thinking text properties. Secondary events could potentially be extracted from messages. | **Approach 2** Returns [RunResponseEvent](https://docs.agno.com/reference/agents/run-response#runresponseevent-types-and-attributes) objects including tool call, memory update, etc, information, where the [RunResponseCompletedEvent](https://docs.agno.com/reference/agents/run-response#runresponsecompletedevent) has similar properties to RunResponse| | A2A | **Approach 3** Returns a [Task or Message](https://a2aproject.github.io/A2A/latest/specification/#71-messagesend) where the message is the final result (Primary) and task is a reference to a long running process. | **Approach 2** Returns a [stream](https://a2aproject.github.io/A2A/latest/specification/#72-messagestream) that contains task updates (Secondary) and a final message (Primary) | @@ -496,7 +496,7 @@ We need to decide what AIContent types, each agent response type will be mapped |-|-| | AutoGen | **Approach 1** Supports [configuring an agent](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/agents.html#structured-output) at agent creation. | | Google ADK | **Approach 1** Both [input and output schemas can be specified for LLM Agents](https://google.github.io/adk-docs/agents/llm-agents/#structuring-data-input_schema-output_schema-output_key) at construction time. This option is specific to this agent type and other agent types do not necessarily support | -| AWS (Strands) | **Approach 2** Supports a special invocation method called [structured_output](https://strandsagents.com/docs/user-guide/concepts/agents/structured-output/) | +| AWS (Strands) | **Approach 2** Supports a special invocation method called [structured_output](https://strandsagents.com/docs/api/python/strands.agent.agent/) | | LangGraph | **Approach 1** Supports [configuring an agent](https://langchain-ai.github.io/langgraph/agents/agents/?h=structured#6-configure-structured-output) at agent construction time, and a [structured response](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) can be retrieved as a special property on the agent response | | Agno | **Approach 1** Supports [configuring an agent](https://docs.agno.com/input-output/structured-output/agent) at agent construction time | | A2A | **Informal Approach 2** Doesn't formally support schema negotiation, but [hints can be provided via metadata](https://a2a-protocol.org/latest/specification/#97-structured-data-exchange-requesting-and-providing-json) at invocation time | @@ -508,7 +508,7 @@ We need to decide what AIContent types, each agent response type will be mapped |-|-| | AutoGen | Supports a [stop reason](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.base.html#autogen_agentchat.base.TaskResult.stop_reason) which is a freeform text string | | Google ADK | [No equivalent present](https://github.com/google/adk-python/blob/main/src/google/adk/events/event.py) | -| AWS (Strands) | Exposes a `stop_reason` property on the [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/#agentresult) class with options that are tied closely to LLM operations. | +| AWS (Strands) | Exposes a [stop_reason](https://strandsagents.com/docs/api/python/strands.types.event_loop/) property on the [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/) class with options that are tied closely to LLM operations. | | LangGraph | No equivalent present, output contains only [messages](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) | | Agno | [No equivalent present](https://docs.agno.com/reference/agents/run-response) | | A2A | No equivalent present, response only contains a [message](https://a2a-protocol.org/latest/specification/#64-message-object) or [task](https://a2a-protocol.org/latest/specification/#61-task-object). | diff --git a/python/CODING_STANDARD.md b/python/CODING_STANDARD.md index ccb8e058e3..8611592692 100644 --- a/python/CODING_STANDARD.md +++ b/python/CODING_STANDARD.md @@ -127,7 +127,12 @@ def create_agent(name: str, tool_mode: Literal['auto', 'required', 'none'] | Cha Avoid `**kwargs` unless absolutely necessary. It should only be used as an escape route, not for well-known flows of data: - **Prefer named parameters**: If there are known extra arguments being passed, use explicit named parameters instead of kwargs +- **Prefer purpose-specific buckets over generic kwargs**: If a flexible payload is still needed, use an explicit named parameter such as `additional_properties`, `function_invocation_kwargs`, or `client_kwargs` rather than a blanket `**kwargs` - **Subclassing support**: kwargs is acceptable in methods that are part of classes designed for subclassing, allowing subclass-defined kwargs to pass through without issues. In this case, clearly document that kwargs exists for subclass extensibility and not for passing arbitrary data +- **Make known flows explicit first**: For abstract hooks, move known data flows into explicit parameters before leaving `**kwargs` behind for subclass extensibility (for example, prefer `state=` explicitly instead of passing it through kwargs) +- **Prefer explicit metadata containers**: For constructors that expose metadata, prefer an explicit `additional_properties` parameter. +- **Keep SDK passthroughs narrow and documented**: A kwargs escape hatch may be acceptable for provider helper APIs that pass through to a large or unstable external SDK surface, but it should be documented as SDK passthrough and revisited regularly +- **Do not keep passthrough kwargs on wrappers that do not use them**: Convenience wrappers and session helpers should not accept generic kwargs merely to forward or ignore them - **Remove when possible**: In other cases, removing kwargs is likely better than keeping it - **Separate kwargs by purpose**: When combining kwargs for multiple purposes, use specific parameters like `client_kwargs: dict[str, Any]` instead of mixing everything in `**kwargs` - **Always document**: If kwargs must be used, always document how it's used, either by referencing external documentation or explaining its purpose diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 31fac386b3..f9f2afaeff 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -6,7 +6,7 @@ import json import re import uuid -from collections.abc import AsyncIterable, Awaitable, Sequence +from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence from typing import Any, Final, Literal, TypeAlias, overload import httpx @@ -218,6 +218,8 @@ def run( *, stream: Literal[False] = ..., session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, @@ -230,17 +232,21 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - def run( + def run( # pyright: ignore[reportIncompatibleMethodOverride] self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, @@ -253,17 +259,23 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). + function_invocation_kwargs: Present for compatibility with the shared agent interface. + A2AAgent does not use these values directly. + client_kwargs: Present for compatibility with the shared agent interface. + A2AAgent does not use these values directly. + kwargs: Additional compatibility keyword arguments. + A2AAgent does not use these values directly. continuation_token: Optional token to resume a long-running task instead of starting a new one. background: When True, in-progress task updates surface continuation tokens so the caller can poll or resubscribe later. When False (default), the agent internally waits for the task to complete. - kwargs: Additional keyword arguments. Returns: When stream=False: An Awaitable[AgentResponse]. When stream=True: A ResponseStream of AgentResponseUpdate items. """ + del function_invocation_kwargs, client_kwargs, kwargs if continuation_token is not None: a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe( TaskIdParams(id=continuation_token["task_id"]) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index 7188eb739c..d2fb59bbb6 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -220,7 +220,6 @@ def __init__( additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, - **kwargs: Any, ) -> None: """Initialize the AG-UI chat client. @@ -231,13 +230,11 @@ def __init__( additional_properties: Additional properties to store middleware: Optional middleware to apply to the client. function_invocation_configuration: Optional function invocation configuration override. - **kwargs: Additional arguments passed to BaseChatClient """ super().__init__( additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self._http_service = AGUIHttpService( endpoint=endpoint, diff --git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index b73eddb8ad..3e839358da 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -98,7 +98,11 @@ def get_response( options: OptionsCoT | ChatOptions[Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - self.last_session = kwargs.get("session") + compatibility_client_kwargs = kwargs.get("client_kwargs") + if isinstance(compatibility_client_kwargs, Mapping): + self.last_session = cast(AgentSession | None, compatibility_client_kwargs.get("session")) + else: + self.last_session = cast(AgentSession | None, kwargs.get("session")) self.last_service_session_id = self.last_session.service_session_id if self.last_session else None return cast( Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]], diff --git a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py index 75cb659633..d5e7b0ae83 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py @@ -702,14 +702,9 @@ async def test_agent_with_use_service_session_is_true(streaming_chat_client_stub """Test that when use_service_session is True, the AgentSession used to run the agent is set to the service session ID.""" from agent_framework.ag_ui import AgentFrameworkAgent - request_service_session_id: str | None = None - async def stream_fn( messages: MutableSequence[Message], chat_options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - nonlocal request_service_session_id - session = kwargs.get("session") - request_service_session_id = session.service_session_id if session else None yield ChatResponseUpdate( contents=[Content.from_text(text="Response")], response_id="resp_67890", conversation_id="conv_12345" ) @@ -719,11 +714,22 @@ async def stream_fn( input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"} + # Spy on agent.run to capture the session kwarg at call time (before streaming mutates it) + captured_service_session_id: str | None = None + original_run = agent.run + + def capturing_run(*args: Any, **kwargs: Any) -> Any: + nonlocal captured_service_session_id + session = kwargs.get("session") + captured_service_session_id = session.service_session_id if session else None + return original_run(*args, **kwargs) + + agent.run = capturing_run # type: ignore[assignment, method-assign] + events: list[Any] = [] async for event in wrapper.run(input_data): events.append(event) - request_service_session_id = agent.client.last_service_session_id - assert request_service_session_id == "conv_123456" # type: ignore[attr-defined] (service_session_id should be set) + assert captured_service_session_id == "conv_123456" async def test_function_approval_mode_executes_tool(streaming_chat_client_stub): diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index 5cda4991c8..32805b7dfe 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -228,11 +228,11 @@ def __init__( model_id: str | None = None, anthropic_client: AsyncAnthropic | None = None, additional_beta_flags: list[str] | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Anthropic Agent client. @@ -244,11 +244,11 @@ def __init__( For instance if you need to set a different base_url for testing or private deployments. additional_beta_flags: Additional beta flags to enable on the client. Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25". + additional_properties: Additional properties stored on the client instance. middleware: Optional middleware to apply to the client. function_invocation_configuration: Optional function invocation configuration override. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. code-block:: python @@ -319,9 +319,9 @@ class MyOptions(AnthropicChatOptions, total=False): # Initialize parent super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) # Initialize instance variables diff --git a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py index 3c4fb68fe8..9972f1301d 100644 --- a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py +++ b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py @@ -16,6 +16,18 @@ # -- Helpers ------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def clear_azure_search_env(monkeypatch: pytest.MonkeyPatch) -> None: + """Keep tests isolated from ambient Azure Search environment variables.""" + for key in ( + "AZURE_SEARCH_ENDPOINT", + "AZURE_SEARCH_INDEX_NAME", + "AZURE_SEARCH_KNOWLEDGE_BASE_NAME", + "AZURE_SEARCH_API_KEY", + ): + monkeypatch.delenv(key, raising=False) + + class MockSearchResults: """Async-iterable mock for Azure SearchClient.search() results.""" diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 4c0e3a56e7..407f524d52 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -444,11 +444,11 @@ def __init__( model_deployment_name: str | None = None, credential: AzureCredentialTypes | None = None, should_cleanup_agent: bool = True, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure AI Agent client. @@ -471,11 +471,11 @@ def __init__( should_cleanup_agent: Whether to cleanup (delete) agents created by this client when the client is closed or context is exited. Defaults to True. Only affects agents created by this client instance; existing agents passed via agent_id are never deleted. + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of middlewares to include. function_invocation_configuration: Optional function invocation configuration. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. code-block:: python @@ -548,9 +548,9 @@ class MyOptions(AzureAIAgentOptions, total=False): # Initialize parent super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) # Initialize instance variables diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index ba5dd8aad7..1fc6c7c1c9 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -119,9 +119,9 @@ def __init__( credential: AzureCredentialTypes | None = None, use_latest_version: bool | None = None, allow_preview: bool | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a bare Azure AI client. @@ -145,9 +145,9 @@ def __init__( use_latest_version: Boolean flag that indicates whether to use latest agent version if it exists in the service. allow_preview: Enables preview opt-in on internally-created ``AIProjectClient``. + additional_properties: Additional properties stored on the client instance. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. code-block:: python @@ -217,7 +217,7 @@ class MyOptions(ChatOptions, total=False): # Initialize parent super().__init__( - **kwargs, + additional_properties=additional_properties, ) # Initialize instance variables @@ -1243,11 +1243,11 @@ def __init__( credential: AzureCredentialTypes | None = None, use_latest_version: bool | None = None, allow_preview: bool | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure AI client with full layer support. @@ -1268,11 +1268,11 @@ def __init__( use_latest_version: Boolean flag that indicates whether to use latest agent version if it exists in the service. allow_preview: Enables preview opt-in on internally-created ``AIProjectClient`` + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of chat middlewares to include. function_invocation_configuration: Optional function invocation configuration. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. code-block:: python @@ -1319,9 +1319,9 @@ class MyOptions(ChatOptions, total=False): credential=credential, use_latest_version=use_latest_version, allow_preview=allow_preview, + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, env_file_path=env_file_path, env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py index a243f77a38..3daa678333 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py @@ -124,9 +124,9 @@ def __init__( text_client: EmbeddingsClient | None = None, image_client: ImageEmbeddingsClient | None = None, credential: AzureKeyCredential | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a raw Azure AI Inference embedding client.""" settings = load_settings( @@ -160,7 +160,7 @@ def __init__( credential=credential, # type: ignore[arg-type] ) self._endpoint = resolved_endpoint - super().__init__(**kwargs) + super().__init__(additional_properties=additional_properties) async def close(self) -> None: """Close the underlying SDK clients and release resources.""" @@ -376,9 +376,9 @@ def __init__( image_client: ImageEmbeddingsClient | None = None, credential: AzureKeyCredential | None = None, otel_provider_name: str | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure AI Inference embedding client.""" super().__init__( @@ -389,8 +389,8 @@ def __init__( text_client=text_client, image_client=image_client, credential=credential, + additional_properties=additional_properties, otel_provider_name=otel_provider_name, env_file_path=env_file_path, env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py index 35c4243c37..6d205fa378 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -124,7 +124,13 @@ def __init__( self._database_client = self._cosmos_client.get_database_client(self.database_name) - async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: + async def get_messages( + self, + session_id: str | None, + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> list[Message]: """Retrieve stored messages for this session from Azure Cosmos DB.""" await self._ensure_container_proxy() session_key = self._session_partition_key(session_id) @@ -157,7 +163,14 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess return messages - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None: + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: """Persist messages for this session to Azure Cosmos DB.""" if not messages: return diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index 5bc9735846..cbb7e2879d 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -236,11 +236,11 @@ def __init__( session_token: str | None = None, client: BaseClient | None = None, boto3_session: Boto3Session | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Create a Bedrock chat client and load AWS credentials. @@ -252,11 +252,11 @@ def __init__( session_token: Optional AWS session token for temporary credentials. client: Preconfigured Bedrock runtime client; when omitted a boto3 session is created. boto3_session: Custom boto3 session used to build the runtime client if provided. + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of middlewares to include. function_invocation_configuration: Optional function invocation configuration env_file_path: Optional .env file path used by ``BedrockSettings`` to load defaults. env_file_encoding: Encoding for the optional .env file. - kwargs: Additional arguments forwarded to ``BaseChatClient``. Examples: .. code-block:: python @@ -303,9 +303,9 @@ class MyOptions(BedrockChatOptions, total=False): ) super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self.model_id = chat_model_id self.region = region diff --git a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py index d07bdee45c..3161ed4c88 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py @@ -104,9 +104,9 @@ def __init__( session_token: str | None = None, client: BaseClient | None = None, boto3_session: Boto3Session | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a raw Bedrock embedding client.""" settings = load_settings( @@ -145,7 +145,7 @@ def __init__( self.model_id: str = settings["embedding_model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] self.region = resolved_region - super().__init__(**kwargs) + super().__init__(additional_properties=additional_properties) def service_url(self) -> str: """Get the URL of the service.""" @@ -274,9 +274,9 @@ def __init__( client: BaseClient | None = None, boto3_session: Boto3Session | None = None, otel_provider_name: str | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a Bedrock embedding client.""" super().__init__( @@ -287,8 +287,8 @@ def __init__( session_token=session_token, client=client, boto3_session=boto3_session, + additional_properties=additional_properties, otel_provider_name=otel_provider_name, env_file_path=env_file_path, env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index 127e3647ee..549c7a5046 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -5,7 +5,7 @@ import contextlib import logging import sys -from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload @@ -581,7 +581,9 @@ def run( *, stream: Literal[False] = ..., session: AgentSession | None = None, - **kwargs: Any, + options: OptionsT | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, ) -> Awaitable[AgentResponse[Any]]: ... @overload @@ -591,7 +593,9 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, - **kwargs: Any, + options: OptionsT | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( @@ -600,7 +604,9 @@ def run( *, stream: bool = False, session: AgentSession | None = None, - **kwargs: Any, + options: OptionsT | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Run the agent with the given messages. @@ -612,16 +618,19 @@ def run( returns an awaitable AgentResponse. session: The conversation session. If session has service_session_id set, the agent will resume that session. - kwargs: Additional keyword arguments including 'options' for runtime options - (model, permission_mode can be changed per-request). + options: Runtime options. Model and permission_mode can be changed per request. + function_invocation_kwargs: Present for compatibility with the shared agent interface. + ClaudeAgent does not use these values directly. + client_kwargs: Present for compatibility with the shared agent interface. + ClaudeAgent does not use these values directly. Returns: When stream=True: An ResponseStream for streaming updates. When stream=False: An Awaitable[AgentResponse] with the complete response. """ - options = kwargs.pop("options", None) + del function_invocation_kwargs, client_kwargs response = ResponseStream( - self._get_stream(messages, session=session, options=options, **kwargs), + self._get_stream(messages, session=session, options=options), finalizer=self._finalize_response, ) @@ -634,8 +643,7 @@ async def _get_stream( messages: AgentRunInputs | None = None, *, session: AgentSession | None = None, - options: OptionsT | MutableMapping[str, Any] | None = None, - **kwargs: Any, + options: OptionsT | None = None, ) -> AsyncIterable[AgentResponseUpdate]: """Internal streaming implementation.""" session = session or self.create_session() diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py index edacb614a5..fc2a35c72b 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py @@ -196,7 +196,6 @@ def run( *, stream: Literal[False] = False, session: AgentSession | None = None, - **kwargs: Any, ) -> Awaitable[AgentResponse]: ... @overload @@ -206,7 +205,6 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, - **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... def run( @@ -215,7 +213,6 @@ def run( *, stream: bool = False, session: AgentSession | None = None, - **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: """Get a response from the agent. @@ -229,22 +226,20 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). - kwargs: Additional keyword arguments. Returns: When stream=False: An Awaitable[AgentResponse]. When stream=True: A ResponseStream of AgentResponseUpdate items. """ if stream: - return self._run_stream_impl(messages=messages, session=session, **kwargs) - return self._run_impl(messages=messages, session=session, **kwargs) + return self._run_stream_impl(messages=messages, session=session) + return self._run_impl(messages=messages, session=session) async def _run_impl( self, messages: AgentRunInputs | None = None, *, session: AgentSession | None = None, - **kwargs: Any, ) -> AgentResponse: """Non-streaming implementation of run.""" if not session: @@ -269,7 +264,6 @@ def _run_stream_impl( messages: AgentRunInputs | None = None, *, session: AgentSession | None = None, - **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: """Streaming implementation of run.""" diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index ef03652898..57a438da2a 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -181,6 +181,7 @@ ) from .exceptions import ( MiddlewareException, + UserInputRequiredException, WorkflowCheckpointException, WorkflowConvergenceException, WorkflowException, @@ -291,6 +292,7 @@ "TypeCompatibilityError", "UpdateT", "UsageDetails", + "UserInputRequiredException", "ValidationTypeEnum", "Workflow", "WorkflowAgent", diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 5cf7ff78a2..efc6fd8e51 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -2,10 +2,10 @@ from __future__ import annotations -import inspect import logging import re import sys +import warnings from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack from copy import deepcopy @@ -27,11 +27,12 @@ from mcp import types from mcp.server.lowlevel import Server from mcp.shared.exceptions import McpError -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel from ._clients import BaseChatClient, SupportsChatGetResponse +from ._docstrings import apply_layered_docstring from ._mcp import LOG_LEVEL_MAPPING, MCPTool -from ._middleware import AgentMiddlewareLayer, MiddlewareTypes +from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes from ._serialization import SerializationMixin from ._sessions import ( AgentSession, @@ -57,7 +58,7 @@ map_chat_to_agent_update, normalize_messages, ) -from .exceptions import AgentInvalidResponseException +from .exceptions import AgentInvalidResponseException, UserInputRequiredException from .observability import AgentTelemetryLayer if sys.version_info >= (3, 13): @@ -177,8 +178,8 @@ class _RunContext(TypedDict): session_messages: Sequence[Message] agent_name: str chat_options: MutableMapping[str, Any] - filtered_kwargs: Mapping[str, Any] - finalize_kwargs: Mapping[str, Any] + client_kwargs: Mapping[str, Any] + function_invocation_kwargs: Mapping[str, Any] # region Agent Protocol @@ -226,15 +227,15 @@ async def _stream(): return AgentResponse(messages=[], response_id="custom-response") - def create_session(self, **kwargs): + def create_session(self, *, session_id: str | None = None): from agent_framework import AgentSession - return AgentSession(**kwargs) + return AgentSession(session_id=session_id) - def get_session(self, *, service_session_id, **kwargs): + def get_session(self, service_session_id: str, *, session_id: str | None = None): from agent_framework import AgentSession - return AgentSession(service_session_id=service_session_id, **kwargs) + return AgentSession(service_session_id=service_session_id, session_id=session_id) # Verify the instance satisfies the protocol @@ -253,6 +254,8 @@ def run( *, stream: Literal[False] = ..., session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: """Get a response from the agent (non-streaming).""" @@ -265,6 +268,8 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Get a streaming response from the agent.""" @@ -276,6 +281,8 @@ def run( *, stream: bool = False, session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Get a response from the agent. @@ -290,6 +297,8 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). + function_invocation_kwargs: Keyword arguments forwarded to tool invocation. + client_kwargs: Additional client-specific keyword arguments. kwargs: Additional keyword arguments. Returns: @@ -299,11 +308,11 @@ def run( """ ... - def create_session(self, **kwargs: Any) -> AgentSession: + def create_session(self, *, session_id: str | None = None) -> AgentSession: """Creates a new conversation session.""" ... - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: """Gets or creates a session for a service-managed session ID.""" ... @@ -386,6 +395,13 @@ def __init__( additional_properties: Additional properties set on the agent. kwargs: Additional keyword arguments (merged into additional_properties). """ + if kwargs: + warnings.warn( + "Passing additional properties as direct keyword arguments to BaseAgent is deprecated; " + "pass them via additional_properties instead.", + DeprecationWarning, + stacklevel=3, + ) if id is None: id = str(uuid4()) self.id = id @@ -400,27 +416,40 @@ def __init__( self.additional_properties: dict[str, Any] = cast(dict[str, Any], additional_properties or {}) self.additional_properties.update(kwargs) - def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> AgentSession: + def create_session(self, *, session_id: str | None = None) -> AgentSession: """Create a new lightweight session. + This will be used by a agent to hold the persisted session. + This depends on the service used, in some cases, or with store=True + this will add the `service_session_id` based on the response, + which is then fed back to the API on the next call. + + In other cases, if there is a HistoryProvider setup in the agent, + that is used and it can store state in the session. + + If there is no HistoryProvider and store=False or the default of a service is False. + Then a ``InMemoryHistoryProvider`` is added to the agent and used with the session automatically. + The ``InMemoryHistoryProvider`` stores the messages as `state` in the session by default. + Keyword Args: session_id: Optional session ID (generated if not provided). - kwargs: Additional keyword arguments. Returns: A new AgentSession instance. """ return AgentSession(session_id=session_id) - def get_session(self, *, service_session_id: str, session_id: str | None = None, **kwargs: Any) -> AgentSession: - """Get or create a session for a service-managed session ID. + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + """Get a session for a service-managed session ID. + + Only use this to create a session continuing that session id from a service. + Otherwise use ``create_session``. Args: service_session_id: The service-managed session ID. Keyword Args: session_id: Optional local session ID (generated if not provided). - kwargs: Additional keyword arguments. Returns: A new AgentSession instance with service_session_id set. @@ -460,9 +489,8 @@ def as_tool( description: str | None = None, arg_name: str = "task", arg_description: str | None = None, - stream_callback: Callable[[AgentResponseUpdate], None] - | Callable[[AgentResponseUpdate], Awaitable[None]] - | None = None, + approval_mode: Literal["always_require", "never_require"] = "never_require", + stream_callback: Callable[[AgentResponseUpdate], Awaitable[None] | None] | None = None, propagate_session: bool = False, ) -> FunctionTool: """Create a FunctionTool that wraps this agent. @@ -473,21 +501,18 @@ def as_tool( arg_name: The name of the function argument (default: "task"). arg_description: The description for the function argument. If None, defaults to "Task for {tool_name}". + approval_mode: Whether this delegated tool requires approval before execution. stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True). - propagate_session: If True, the parent agent's ``AgentSession`` is - forwarded to this sub-agent's ``run()`` call, so both agents - operate within the same logical session (sharing the same - ``session_id`` and provider-managed state, such as any stored - conversation history or metadata). Defaults to False, meaning - the sub-agent runs with a new, independent session. + propagate_session: If True, the sub-agent's ``run()`` call receives + the ``session`` value from ``FunctionInvocationContext.kwargs`` + when one is supplied explicitly (for example via + ``function_invocation_kwargs={"session": session}``). Defaults + to False, meaning the sub-agent runs with a new, independent + session. Returns: A FunctionTool that can be used as a tool by other agents. - Raises: - TypeError: If the agent does not implement SupportsAgentRun. - ValueError: If the agent tool name cannot be determined. - Examples: .. code-block:: python @@ -499,7 +524,8 @@ def as_tool( # Convert the agent to a tool (independent session) research_tool = agent.as_tool() - # Convert the agent to a tool (shared session with parent) + # Convert the agent to a tool (shared session when the caller + # passes ``function_invocation_kwargs={"session": session}``) research_tool = agent.as_tool(propagate_session=True) # Use the tool with another agent @@ -515,59 +541,54 @@ def as_tool( tool_description = description or self.description or "" argument_description = arg_description or f"Task for {tool_name}" - # Create dynamic input model with the specified argument name - field_info = Field(..., description=argument_description) - model_name = f"{name or _sanitize_agent_name(self.name) or 'agent'}_task" - input_model = create_model(model_name, **{arg_name: (str, field_info)}) # type: ignore[call-overload] - - # Check if callback is async once, outside the wrapper - is_async_callback = stream_callback is not None and inspect.iscoroutinefunction(stream_callback) - - async def agent_wrapper(**kwargs: Any) -> str: - """Wrapper function that calls the agent.""" - # Extract the input from kwargs using the specified arg_name - input_text = kwargs.get(arg_name, "") - - # Extract parent session when propagate_session is enabled - parent_session = kwargs.get("session") if propagate_session else None + input_schema = { + "type": "object", + "properties": { + arg_name: { + "type": "string", + "description": argument_description, + } + }, + "required": [arg_name], + "additionalProperties": False, + } - # Forward runtime context kwargs, excluding framework-internal keys. - forwarded_kwargs = { - k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options", "session") - } + async def _agent_wrapper(ctx: FunctionInvocationContext, **kwargs: Any) -> str: + """Wrapper function that calls the agent. - if stream_callback is None: - # Use non-streaming mode - return ( - await self.run( - input_text, - stream=False, - session=parent_session, - **forwarded_kwargs, + Args: + ctx: the function invocation context used + **kwargs: only used to dynamically load the argument that is defined for this tool. + """ + session = None + if propagate_session: + session = ctx.kwargs.get("session") + if session and not isinstance(session, AgentSession): + raise TypeError( + "The provided session is not a AgentSession object, please make sure to " + "pass it through the function_invocation_kwargs." ) - ).text - - # Use streaming mode - accumulate updates and create final response - response_updates: list[AgentResponseUpdate] = [] - async for update in self.run(input_text, stream=True, session=parent_session, **forwarded_kwargs): - response_updates.append(update) - if is_async_callback: - await stream_callback(update) # type: ignore[misc] - else: - stream_callback(update) - - # Create final text from accumulated updates - return AgentResponse.from_updates(response_updates).text - - agent_tool: FunctionTool = FunctionTool( + stream = self.run( + str(kwargs.get(arg_name, "")), + stream=True, + session=session, + function_invocation_kwargs=dict(ctx.kwargs), + ) + if stream_callback is not None: + stream.with_transform_hook(stream_callback) + final_response = await stream.get_final_response() + if final_response.user_input_requests: + raise UserInputRequiredException(contents=final_response.user_input_requests) + # TODO(Copilot): update once #4331 merges + return final_response.text + + return FunctionTool( name=tool_name, description=tool_description, - func=agent_wrapper, - input_model=input_model, # type: ignore - approval_mode="never_require", + func=_agent_wrapper, + input_model=input_schema, + approval_mode=approval_mode, ) - agent_tool._forward_runtime_kwargs = True # type: ignore - return agent_tool # region Agent @@ -799,6 +820,8 @@ def run( session: AgentSession | None = None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, options: ChatOptions[ResponseModelBoundT], + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ... @@ -811,6 +834,8 @@ def run( session: AgentSession | None = None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, options: OptionsCoT | ChatOptions[None] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -823,6 +848,8 @@ def run( session: AgentSession | None = None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, options: OptionsCoT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... @@ -834,6 +861,8 @@ def run( session: AgentSession | None = None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, options: OptionsCoT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Run the agent with the given messages and options. @@ -857,14 +886,23 @@ def run( ``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for provider-specific options including temperature, max_tokens, model_id, tool_choice, and provider-specific options like reasoning_effort. - kwargs: Additional keyword arguments for the agent. - Will only be passed to functions that are called. + function_invocation_kwargs: Keyword arguments forwarded to tool invocation. + client_kwargs: Additional client-specific keyword arguments for the chat client. + kwargs: Deprecated additional keyword arguments for the agent. + They are forwarded to both tool invocation and the chat client for compatibility. Returns: When stream=False: An Awaitable[AgentResponse] containing the agent's response. When stream=True: A ResponseStream of AgentResponseUpdate items with ``get_final_response()`` for the final AgentResponse. """ + if kwargs: + warnings.warn( + "Passing runtime keyword arguments directly to run() is deprecated; pass tool values via " + "function_invocation_kwargs and client-specific values via client_kwargs instead.", + DeprecationWarning, + stacklevel=2, + ) if not stream: async def _run_non_streaming() -> AgentResponse[Any]: @@ -873,7 +911,9 @@ async def _run_non_streaming() -> AgentResponse[Any]: session=session, tools=tools, options=options, - kwargs=kwargs, + legacy_kwargs=kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) response = cast( ChatResponse[Any], @@ -881,7 +921,8 @@ async def _run_non_streaming() -> AgentResponse[Any]: messages=ctx["session_messages"], stream=False, options=ctx["chat_options"], # type: ignore[reportArgumentType] - **ctx["filtered_kwargs"], + function_invocation_kwargs=ctx["function_invocation_kwargs"], + client_kwargs=ctx["client_kwargs"], ), ) @@ -954,14 +995,17 @@ async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]] session=session, tools=tools, options=options, - kwargs=kwargs, + legacy_kwargs=kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it return self.client.get_response( # type: ignore[call-overload, no-any-return] messages=ctx["session_messages"], stream=True, options=ctx["chat_options"], # type: ignore[reportArgumentType] - **ctx["filtered_kwargs"], + function_invocation_kwargs=ctx["function_invocation_kwargs"], + client_kwargs=ctx["client_kwargs"], ) def _propagate_conversation_id( @@ -1047,9 +1091,12 @@ async def _prepare_run_context( session: AgentSession | None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None, options: Mapping[str, Any] | None, - kwargs: dict[str, Any], + legacy_kwargs: Mapping[str, Any], + function_invocation_kwargs: Mapping[str, Any] | None, + client_kwargs: Mapping[str, Any] | None, ) -> _RunContext: opts = dict(options) if options else {} + existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {} # Get tools from options or named parameter (named param takes precedence) tools_ = tools if tools is not None else opts.pop("tools", None) @@ -1080,6 +1127,12 @@ async def _prepare_run_context( input_messages=input_messages, options=opts, ) + default_additional_args = chat_options.pop("additional_function_arguments", None) + if isinstance(default_additional_args, Mapping): + existing_additional_args = { + **dict(cast(Mapping[str, Any], default_additional_args)), + **existing_additional_args, + } # Normalize tools normalized_tools = normalize_tools(tools_) @@ -1101,13 +1154,13 @@ async def _prepare_run_context( await self._async_exit_stack.enter_async_context(mcp_server) final_tools.extend(f for f in mcp_server.functions if f.name not in existing_names) - # Merge runtime kwargs into additional_function_arguments so they're available - # in function middleware context and tool invocation. - existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {} - additional_function_arguments = {**kwargs, **existing_additional_args} - # Include session so as_tool() wrappers with propagate_session=True can access it. - if active_session is not None: - additional_function_arguments["session"] = active_session + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + # Legacy compatibility still fans out direct run kwargs into tool runtime kwargs. + effective_function_invocation_kwargs = { + **dict(legacy_kwargs), + **(dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}), + } + additional_function_arguments = {**effective_function_invocation_kwargs, **existing_additional_args} # Build options dict from run() options merged with provided options run_opts: dict[str, Any] = { @@ -1116,7 +1169,6 @@ async def _prepare_run_context( if active_session else opts.pop("conversation_id", None), "allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None), - "additional_function_arguments": additional_function_arguments or None, "frequency_penalty": opts.pop("frequency_penalty", None), "logit_bias": opts.pop("logit_bias", None), "max_tokens": opts.pop("max_tokens", None), @@ -1140,11 +1192,12 @@ async def _prepare_run_context( # Build session_messages from session context: context messages + input messages session_messages: list[Message] = session_context.get_messages(include_input=True) - # Ensure session is forwarded in kwargs for tool invocation - finalize_kwargs = dict(kwargs) - finalize_kwargs["session"] = active_session - # Filter chat_options from kwargs to prevent duplicate keyword argument - filtered_kwargs = {k: v for k, v in finalize_kwargs.items() if k != "chat_options"} + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + # Legacy compatibility still fans out direct run kwargs into client kwargs. + effective_client_kwargs = { + **dict(legacy_kwargs), + **(dict(client_kwargs) if client_kwargs is not None else {}), + } return { "session": active_session, @@ -1153,8 +1206,8 @@ async def _prepare_run_context( "session_messages": session_messages, "agent_name": agent_name, "chat_options": co, - "filtered_kwargs": filtered_kwargs, - "finalize_kwargs": finalize_kwargs, + "client_kwargs": effective_client_kwargs, + "function_invocation_kwargs": additional_function_arguments, } async def _finalize_response( @@ -1396,6 +1449,58 @@ class Agent( For a minimal implementation without these features, use :class:`RawAgent`. """ + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[False] = ..., + session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[True], + session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + middleware: Sequence[MiddlewareTypes] | None = None, + options: OptionsCoT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: + """Run the agent.""" + super_run = cast( + "Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]", + super().run, # type: ignore[misc] + ) + return super_run( # type: ignore[no-any-return] + messages=messages, + stream=stream, + session=session, + middleware=middleware, + options=options, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + **kwargs, + ) + def __init__( self, client: SupportsChatGetResponse[OptionsCoT], @@ -1423,3 +1528,34 @@ def __init__( middleware=middleware, **kwargs, ) + + +def _apply_agent_docstrings() -> None: + """Align public agent docstrings with the raw implementation.""" + apply_layered_docstring( + AgentMiddlewareLayer.run, + RawAgent.run, + extra_keyword_args={ + "middleware": """ + Optional per-run agent, chat, and function middleware. + Agent middleware wraps the run itself, while chat and function middleware are forwarded to the + underlying chat-client stack for this call. + """, + }, + ) + apply_layered_docstring(AgentTelemetryLayer.run, AgentMiddlewareLayer.run) + apply_layered_docstring( + Agent.run, + RawAgent.run, + extra_keyword_args={ + "middleware": """ + Optional per-run agent, chat, and function middleware. + Agent middleware wraps the run itself, while chat and function middleware are forwarded to the + underlying chat-client stack for this call. + """, + }, + ) + apply_layered_docstring(Agent.__init__, RawAgent.__init__) + + +_apply_agent_docstrings() diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 5dd049ecd3..b667187da2 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -4,6 +4,7 @@ import logging import sys +import warnings from abc import ABC, abstractmethod from collections.abc import ( AsyncIterable, @@ -27,6 +28,7 @@ from pydantic import BaseModel +from ._docstrings import apply_layered_docstring from ._serialization import SerializationMixin from ._tools import ( FunctionInvocationConfiguration, @@ -104,7 +106,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]): class CustomChatClient: additional_properties: dict = {} - def get_response(self, messages, *, stream=False, **kwargs): + def get_response(self, messages, *, stream=False, client_kwargs=None, **kwargs): if stream: from agent_framework import ChatResponseUpdate, ResponseStream @@ -144,6 +146,8 @@ def get_response( *, stream: Literal[False] = ..., options: OptionsContraT | ChatOptions[None] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]]: ... @@ -154,6 +158,8 @@ def get_response( *, stream: Literal[True], options: OptionsContraT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... @@ -163,6 +169,8 @@ def get_response( *, stream: bool = False, options: OptionsContraT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: """Send input and return the response. @@ -171,7 +179,9 @@ def get_response( messages: The sequence of input messages to send. stream: Whether to stream the response. Defaults to False. options: Chat options as a TypedDict. - **kwargs: Additional chat options. + function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers. + client_kwargs: Additional client-specific keyword arguments. + **kwargs: Deprecated additional client-specific keyword arguments. Returns: When stream=False: An awaitable ChatResponse from the client. @@ -276,7 +286,15 @@ def __init__( kwargs: Additional keyword arguments (merged into additional_properties). """ self.additional_properties = additional_properties or {} - super().__init__(**kwargs) + if kwargs: + warnings.warn( + "Passing additional properties as direct keyword arguments to BaseChatClient is deprecated; " + "pass them via additional_properties instead.", + DeprecationWarning, + stacklevel=3, + ) + self.additional_properties.update(kwargs) + super().__init__() def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: """Convert the instance to a dictionary. @@ -374,6 +392,8 @@ def get_response( *, stream: Literal[False] = ..., options: ChatOptions[ResponseModelBoundT], + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ... @@ -411,16 +431,31 @@ def get_response( messages: The message or messages to send to the model. stream: Whether to stream the response. Defaults to False. options: Chat options as a TypedDict. - **kwargs: Other keyword arguments, can be used to pass function specific parameters. + **kwargs: Additional compatibility keyword arguments. Lower chat-client layers do not + consume ``function_invocation_kwargs`` directly; if present, it is ignored here + because function invocation has already been handled by upper layers. If a + ``client_kwargs`` mapping is present, it is flattened into standard keyword + arguments before forwarding to ``_inner_get_response()`` so client implementations + that care about those values can still use them, while implementations that ignore + extra kwargs remain compatible. Returns: When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse. """ + compatibility_client_kwargs = kwargs.pop("client_kwargs", None) + kwargs.pop("function_invocation_kwargs", None) + merged_client_kwargs = ( + dict(cast(Mapping[str, Any], compatibility_client_kwargs)) + if isinstance(compatibility_client_kwargs, Mapping) + else {} + ) + merged_client_kwargs.update(kwargs) + return self._inner_get_response( messages=messages, stream=stream, options=options or {}, # type: ignore[arg-type] - **kwargs, + **merged_client_kwargs, ) def service_url(self) -> str: @@ -446,7 +481,7 @@ def as_agent( context_providers: Sequence[Any] | None = None, middleware: Sequence[MiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, - **kwargs: Any, + additional_properties: Mapping[str, Any] | None = None, ) -> Agent[OptionsCoT]: """Create a Agent with this client. @@ -468,7 +503,7 @@ def as_agent( context_providers: Context providers to include during agent invocation. middleware: List of middleware to intercept agent and function invocations. function_invocation_configuration: Optional function invocation configuration override. - kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``. + additional_properties: Additional properties stored on the created agent. Returns: A Agent instance configured with this chat client. @@ -493,19 +528,22 @@ def as_agent( """ from ._agents import Agent - return Agent( - client=self, - id=id, - name=name, - description=description, - instructions=instructions, - tools=tools, - default_options=cast(Any, default_options), - context_providers=context_providers, - middleware=middleware, - function_invocation_configuration=function_invocation_configuration, - **kwargs, - ) + agent_kwargs: dict[str, Any] = { + "client": self, + "id": id, + "name": name, + "description": description, + "instructions": instructions, + "tools": tools, + "default_options": cast(Any, default_options), + "context_providers": context_providers, + "middleware": middleware, + "additional_properties": dict(additional_properties) if additional_properties is not None else None, + } + if function_invocation_configuration is not None: + agent_kwargs["function_invocation_configuration"] = function_invocation_configuration + + return Agent(**agent_kwargs) # endregion @@ -768,16 +806,14 @@ def __init__( self, *, additional_properties: dict[str, Any] | None = None, - **kwargs: Any, ) -> None: """Initialize a BaseEmbeddingClient instance. Args: additional_properties: Additional properties to pass to the client. - **kwargs: Additional keyword arguments passed to parent classes (for MRO). """ self.additional_properties = additional_properties or {} - super().__init__(**kwargs) + super().__init__() @abstractmethod async def get_embeddings( @@ -799,3 +835,36 @@ async def get_embeddings( # endregion + + +def _apply_get_response_docstrings() -> None: + """Align layered chat-client docstrings with the lowest public implementation.""" + from ._middleware import ChatMiddlewareLayer + from ._tools import FunctionInvocationLayer + from .observability import ChatTelemetryLayer + + apply_layered_docstring(ChatTelemetryLayer.get_response, BaseChatClient.get_response) + apply_layered_docstring( + FunctionInvocationLayer.get_response, + ChatTelemetryLayer.get_response, + extra_keyword_args={ + "function_middleware": """ + Optional per-call function middleware. + When omitted, middleware configured on the client or forwarded from higher layers is used. + """, + }, + ) + apply_layered_docstring( + ChatMiddlewareLayer.get_response, + FunctionInvocationLayer.get_response, + extra_keyword_args={ + "middleware": """ + Optional per-call chat and function middleware. + This compatibility keyword argument is merged with any ``client_kwargs["middleware"]`` value + before the request is executed. + """, + }, + ) + + +_apply_get_response_docstrings() diff --git a/python/packages/core/agent_framework/_docstrings.py b/python/packages/core/agent_framework/_docstrings.py new file mode 100644 index 0000000000..44dd7c50a3 --- /dev/null +++ b/python/packages/core/agent_framework/_docstrings.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import inspect +from collections.abc import Callable, Mapping +from typing import Any + +_GOOGLE_SECTION_HEADERS = ( + "Args:", + "Keyword Args:", + "Returns:", + "Raises:", + "Examples:", + "Note:", + "Notes:", + "Warning:", + "Warnings:", +) + + +def _find_section_index(lines: list[str], header: str) -> int | None: + for index, line in enumerate(lines): + if line == header: + return index + return None + + +def _find_next_section_index(lines: list[str], start: int) -> int: + for index in range(start, len(lines)): + if lines[index] in _GOOGLE_SECTION_HEADERS: + return index + return len(lines) + + +def _format_keyword_arg_lines(extra_keyword_args: Mapping[str, str]) -> list[str]: + formatted_lines: list[str] = [] + for name, description in extra_keyword_args.items(): + description_lines = inspect.cleandoc(description).splitlines() + if not description_lines: + formatted_lines.append(f" {name}:") + continue + formatted_lines.append(f" {name}: {description_lines[0]}") + formatted_lines.extend(f" {line}" for line in description_lines[1:]) + return formatted_lines + + +def build_layered_docstring( + source: Callable[..., Any], + *, + extra_keyword_args: Mapping[str, str] | None = None, +) -> str | None: + """Build a Google-style docstring from a lower-layer implementation.""" + docstring = inspect.getdoc(source) + if not docstring: + return None + if not extra_keyword_args: + return docstring + + lines = docstring.splitlines() + formatted_keyword_arg_lines = _format_keyword_arg_lines(extra_keyword_args) + keyword_args_index = _find_section_index(lines, "Keyword Args:") + + if keyword_args_index is None: + args_index = _find_section_index(lines, "Args:") + if args_index is not None: + insert_index = _find_next_section_index(lines, args_index + 1) + else: + insert_index = _find_next_section_index(lines, 0) + lines[insert_index:insert_index] = ["", "Keyword Args:", *formatted_keyword_arg_lines] + return "\n".join(lines).rstrip() + + insert_index = _find_next_section_index(lines, keyword_args_index + 1) + lines[insert_index:insert_index] = formatted_keyword_arg_lines + return "\n".join(lines).rstrip() + + +def apply_layered_docstring( + target: Callable[..., Any], + source: Callable[..., Any], + *, + extra_keyword_args: Mapping[str, str] | None = None, +) -> None: + """Copy a lower-layer docstring onto a wrapper and extend it when needed.""" + target.__doc__ = build_layered_docstring(source, extra_keyword_args=extra_keyword_args) diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 7f3f3da13d..f1f3b234d0 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -106,7 +106,9 @@ class AgentContext: to see the actual execution result or can be set to override the execution result. For non-streaming: should be AgentResponse. For streaming: should be ResponseStream[AgentResponseUpdate, AgentResponse]. - kwargs: Additional keyword arguments passed to the agent run method. + kwargs: Legacy runtime keyword arguments visible to agent middleware. + client_kwargs: Client-specific keyword arguments for downstream chat clients. + function_invocation_kwargs: Keyword arguments forwarded to tool invocation. Examples: .. code-block:: python @@ -142,6 +144,8 @@ def __init__( metadata: Mapping[str, Any] | None = None, result: AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None = None, kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, stream_transform_hooks: Sequence[ Callable[[AgentResponseUpdate], AgentResponseUpdate | Awaitable[AgentResponseUpdate]] ] @@ -160,7 +164,9 @@ def __init__( stream: Whether this is a streaming invocation. metadata: Metadata dictionary for sharing data between agent middleware. result: Agent execution result. - kwargs: Additional keyword arguments passed to the agent run method. + kwargs: Legacy runtime keyword arguments visible to agent middleware. + client_kwargs: Client-specific keyword arguments for downstream chat clients. + function_invocation_kwargs: Keyword arguments forwarded to tool invocation. stream_transform_hooks: Hooks to transform streamed updates. stream_result_hooks: Hooks to process the final result after streaming. stream_cleanup_hooks: Hooks to run after streaming completes. @@ -173,6 +179,10 @@ def __init__( self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} + self.client_kwargs: dict[str, Any] = dict(client_kwargs) if client_kwargs is not None else {} + self.function_invocation_kwargs: dict[str, Any] = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) self.stream_transform_hooks = list(stream_transform_hooks or []) self.stream_result_hooks = list(stream_result_hooks or []) self.stream_cleanup_hooks = list(stream_cleanup_hooks or []) @@ -190,8 +200,7 @@ class FunctionInvocationContext: metadata: Metadata dictionary for sharing data between function middleware. result: Function execution result. Can be observed after calling ``call_next()`` to see the actual execution result or can be set to override the execution result. - - kwargs: Additional keyword arguments passed to the chat method that invoked this function. + kwargs: Additional runtime keyword arguments forwarded to the function invocation. Examples: .. code-block:: python @@ -227,7 +236,7 @@ def __init__( arguments: The validated arguments for the function. metadata: Metadata dictionary for sharing data between function middleware. result: Function execution result. - kwargs: Additional keyword arguments passed to the chat method that invoked this function. + kwargs: Additional runtime keyword arguments forwarded to the function invocation. """ self.function = function self.arguments = arguments @@ -253,6 +262,7 @@ class ChatContext: For non-streaming: should be ChatResponse. For streaming: should be ResponseStream[ChatResponseUpdate, ChatResponse]. kwargs: Additional keyword arguments passed to the chat client. + function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers. stream_transform_hooks: Hooks applied to transform each streamed update. stream_result_hooks: Hooks applied to the finalized response (after finalizer). stream_cleanup_hooks: Hooks executed after stream consumption (before finalizer). @@ -289,6 +299,7 @@ def __init__( metadata: Mapping[str, Any] | None = None, result: ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None = None, kwargs: Mapping[str, Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, stream_transform_hooks: Sequence[ Callable[[ChatResponseUpdate], ChatResponseUpdate | Awaitable[ChatResponseUpdate]] ] @@ -306,6 +317,7 @@ def __init__( metadata: Metadata dictionary for sharing data between chat middleware. result: Chat execution result. kwargs: Additional keyword arguments passed to the chat client. + function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers. stream_transform_hooks: Transform hooks to apply to each streamed update. stream_result_hooks: Result hooks to apply to the finalized streaming response. stream_cleanup_hooks: Cleanup hooks to run after streaming completes. @@ -317,6 +329,9 @@ def __init__( self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} + self.function_invocation_kwargs: dict[str, Any] = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) self.stream_transform_hooks = list(stream_transform_hooks or []) self.stream_result_hooks = list(stream_result_hooks or []) self.stream_cleanup_hooks = list(stream_cleanup_hooks or []) @@ -969,6 +984,7 @@ def get_response( *, stream: Literal[False] = ..., options: ChatOptions[ResponseModelBoundT], + function_invocation_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ... @@ -979,6 +995,8 @@ def get_response( *, stream: Literal[False] = ..., options: OptionsCoT | ChatOptions[None] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]]: ... @@ -989,6 +1007,8 @@ def get_response( *, stream: Literal[True], options: OptionsCoT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... @@ -998,14 +1018,17 @@ def get_response( *, stream: bool = False, options: OptionsCoT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: """Execute the chat pipeline if middleware is configured.""" super_get_response = super().get_response # type: ignore[misc] - call_middleware = kwargs.pop("middleware", []) + effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + call_middleware = kwargs.pop("middleware", effective_client_kwargs.pop("middleware", [])) middleware = categorize_middleware(call_middleware) - kwargs["function_middleware"] = middleware["function"] + effective_client_kwargs["function_middleware"] = middleware["function"] pipeline = ChatMiddlewarePipeline( *self.chat_middleware, @@ -1016,6 +1039,8 @@ def get_response( messages=messages, stream=stream, options=options, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=effective_client_kwargs, **kwargs, ) @@ -1024,7 +1049,8 @@ def get_response( messages=list(messages), options=options, stream=stream, - kwargs=kwargs, + kwargs={**effective_client_kwargs, **kwargs}, + function_invocation_kwargs=function_invocation_kwargs, ) async def _execute() -> ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None: @@ -1061,7 +1087,8 @@ def _middleware_handler( messages=context.messages, stream=context.stream, options=context.options or {}, - **context.kwargs, + function_invocation_kwargs=context.function_invocation_kwargs, + client_kwargs=context.kwargs, ) @@ -1091,6 +1118,8 @@ def run( session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[ResponseModelBoundT], + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ... @@ -1103,6 +1132,8 @@ def run( session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[None] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -1115,6 +1146,8 @@ def run( session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... @@ -1126,6 +1159,8 @@ def run( session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """MiddlewareTypes-enabled unified run method.""" @@ -1145,12 +1180,23 @@ def run( + run_middleware_list["function"] + run_middleware_list["chat"] ) - combined_kwargs = dict(kwargs) - combined_kwargs["middleware"] = combined_function_chat_middleware if combined_function_chat_middleware else None - + effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + if combined_function_chat_middleware: + effective_client_kwargs["middleware"] = combined_function_chat_middleware + effective_function_invocation_kwargs = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) # Execute with middleware if available if not pipeline.has_middlewares: - return super().run(messages, stream=stream, session=session, options=options, **combined_kwargs) # type: ignore[misc, no-any-return] + return super().run( # type: ignore[misc, no-any-return] + messages, + stream=stream, + session=session, + options=options, + function_invocation_kwargs=effective_function_invocation_kwargs, + client_kwargs=effective_client_kwargs, + **kwargs, + ) context = AgentContext( agent=self, # type: ignore[arg-type] @@ -1158,7 +1204,9 @@ def run( session=session, options=options, stream=stream, - kwargs=combined_kwargs, + kwargs=kwargs, + client_kwargs=effective_client_kwargs, + function_invocation_kwargs=effective_function_invocation_kwargs, ) async def _execute() -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None: @@ -1190,12 +1238,20 @@ async def _execute_stream() -> ResponseStream[AgentResponseUpdate, AgentResponse def _middleware_handler( self, context: AgentContext ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + client_kwargs = {**context.client_kwargs, **context.kwargs} + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + function_invocation_kwargs = { + **context.function_invocation_kwargs, + **{k: v for k, v in context.kwargs.items() if k != "middleware"}, + } return super().run( # type: ignore[misc, no-any-return] context.messages, stream=context.stream, session=context.session, options=context.options, - **context.kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 8c3457da26..9eca419df0 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -392,12 +392,16 @@ def __init__( self.store_outputs = store_outputs @abstractmethod - async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: + async def get_messages( + self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any + ) -> list[Message]: """Retrieve stored messages for this session. Args: session_id: The session ID to retrieve messages for. - **kwargs: Additional arguments (e.g., ``state`` for in-memory providers). + state: Optional session state for providers that persist in session state. + Not used by all providers. + **kwargs: Additional subclass-specific extensibility arguments. Returns: List of stored messages. @@ -405,13 +409,22 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess ... @abstractmethod - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None: + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: """Persist messages for this session. Args: session_id: The session ID to store messages for. messages: The messages to persist. - **kwargs: Additional arguments (e.g., ``state`` for in-memory providers). + state: Optional session state for providers that persist in session state. + Not used by all providers. + **kwargs: Additional subclass-specific extensibility arguments. """ ... diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 105738e717..67015cf061 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -7,6 +7,8 @@ import json import logging import sys +import typing +import warnings from collections.abc import ( AsyncIterable, Awaitable, @@ -37,7 +39,7 @@ from pydantic import BaseModel, Field, ValidationError, create_model from ._serialization import SerializationMixin -from .exceptions import ToolException +from .exceptions import ToolException, UserInputRequiredException from .observability import ( OPERATION_DURATION_BUCKET_BOUNDARIES, OtelAttr, @@ -60,7 +62,7 @@ if TYPE_CHECKING: from ._clients import SupportsChatGetResponse from ._mcp import MCPTool - from ._middleware import FunctionMiddlewarePipeline, FunctionMiddlewareTypes + from ._middleware import FunctionInvocationContext, FunctionMiddlewarePipeline, FunctionMiddlewareTypes from ._types import ( ChatOptions, ChatResponse, @@ -173,6 +175,16 @@ def _default_histogram() -> Histogram: ) +def _annotation_includes_function_invocation_context(annotation: Any) -> bool: + """Check whether an annotation resolves to FunctionInvocationContext.""" + from ._middleware import FunctionInvocationContext + + candidates = get_args(annotation) or (annotation,) + return any( + candidate is FunctionInvocationContext or candidate == "FunctionInvocationContext" for candidate in candidates + ) + + ClassT = TypeVar("ClassT", bound="SerializationMixin") @@ -309,6 +321,12 @@ def __init__( # FunctionTool-specific attributes self.func = func self._instance = None # Store the instance for bound methods + self._context_parameter_name: str | None = None + self._input_model_explicitly_provided = input_model is not None + # TODO(Copilot): Delete once legacy ``**kwargs`` runtime injection is removed. + self._forward_runtime_kwargs: bool = False + if self.func: + self._discover_injected_parameters() # Initialize schema cache (will be lazily populated) self._input_schema_cached: dict[str, Any] | None = None @@ -335,13 +353,37 @@ def __init__( self._invocation_duration_histogram = _default_histogram() self.type: Literal["function_tool"] = "function_tool" self.result_parser = result_parser - self._forward_runtime_kwargs: bool = False - if self.func: - sig = inspect.signature(self.func) - for param in sig.parameters.values(): - if param.kind == inspect.Parameter.VAR_KEYWORD: - self._forward_runtime_kwargs = True - break + + def _discover_injected_parameters(self) -> None: + """Inspect the wrapped function for runtime injection parameters.""" + func = self.func.func if isinstance(self.func, FunctionTool) else self.func + if func is None: + return + + signature = inspect.signature(func) + try: + type_hints = typing.get_type_hints(func) + except Exception: + type_hints = {name: param.annotation for name, param in signature.parameters.items()} + + for name, param in signature.parameters.items(): + if name in {"self", "cls"}: + continue + if param.kind == inspect.Parameter.VAR_KEYWORD: + self._forward_runtime_kwargs = True + continue + + annotation = type_hints.get(name, param.annotation) + if self._is_context_parameter(name, annotation): + if self._context_parameter_name is not None: + raise ValueError(f"Function '{self.name}' defines multiple FunctionInvocationContext parameters.") + self._context_parameter_name = name + + def _is_context_parameter(self, name: str, annotation: Any) -> bool: + """Check whether a callable parameter should receive FunctionInvocationContext injection.""" + if _annotation_includes_function_invocation_context(annotation): + return True + return self._input_model_explicitly_provided and name == "ctx" and annotation is inspect.Parameter.empty def __str__(self) -> str: """Return a string representation of the tool.""" @@ -410,6 +452,7 @@ def _resolve_input_model(self, input_model: type[BaseModel] | None) -> type[Base ) for pname, param in sig.parameters.items() if pname not in {"self", "cls"} + and pname != self._context_parameter_name and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} } return create_model(f"{self.name}_input", **fields) @@ -447,6 +490,7 @@ async def invoke( self, *, arguments: BaseModel | Mapping[str, Any] | None = None, + context: FunctionInvocationContext | None = None, **kwargs: Any, ) -> str: """Run the AI function with the provided arguments as a Pydantic model. @@ -457,7 +501,8 @@ async def invoke( Keyword Args: arguments: A mapping or model instance containing the arguments for the function. - kwargs: Keyword arguments to pass to the function, will not be used if ``arguments`` is provided. + context: Explicit function invocation context carrying runtime kwargs. + kwargs: Deprecated keyword arguments to pass to the function. Use ``context`` instead. Returns: The parsed result as a string — either plain text or serialized JSON. @@ -468,13 +513,36 @@ async def invoke( if self.declaration_only: raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.") global OBSERVABILITY_SETTINGS + from ._middleware import FunctionInvocationContext from .observability import OBSERVABILITY_SETTINGS parser = self.result_parser or FunctionTool.parse_result - original_kwargs = dict(kwargs) - tool_call_id = original_kwargs.pop("tool_call_id", None) - if arguments is not None: + parameter_names = set(self.parameters().get("properties", {}).keys()) + direct_argument_kwargs = ( + {key: value for key, value in kwargs.items() if key in parameter_names} if arguments is None else {} + ) + runtime_kwargs = dict(context.kwargs) if context is not None else {} + deprecated_runtime_kwargs = { + key: value for key, value in kwargs.items() if key not in direct_argument_kwargs and key != "tool_call_id" + } + if deprecated_runtime_kwargs: + warnings.warn( + "Passing runtime keyword arguments directly to FunctionTool.invoke() is deprecated; " + "pass them via FunctionInvocationContext instead.", + DeprecationWarning, + stacklevel=2, + ) + runtime_kwargs.update(deprecated_runtime_kwargs) + tool_call_id = kwargs.get("tool_call_id", runtime_kwargs.pop("tool_call_id", None)) + if arguments is None and direct_argument_kwargs: + arguments = direct_argument_kwargs + if arguments is None and context is not None: + arguments = context.arguments + + if arguments is None: + validated_arguments: dict[str, Any] = {} + else: try: if isinstance(arguments, Mapping): parsed_arguments = dict(arguments) @@ -496,19 +564,42 @@ async def invoke( ) except ValidationError as exc: raise TypeError(f"Invalid arguments for '{self.name}': {exc}") from exc - kwargs = _validate_arguments_against_schema( + + validated_arguments = _validate_arguments_against_schema( arguments=parsed_arguments, schema=self.parameters(), tool_name=self.name, ) - if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs: - kwargs.update(original_kwargs) - else: - kwargs = original_kwargs + + effective_context = context + if effective_context is None and self._context_parameter_name is not None: + effective_context = FunctionInvocationContext( + function=self, + arguments=validated_arguments, + kwargs=runtime_kwargs, + ) + if effective_context is not None: + effective_context.function = self + effective_context.arguments = validated_arguments + effective_context.kwargs = dict(runtime_kwargs) + + call_kwargs = dict(validated_arguments) + observable_kwargs = dict(validated_arguments) + + # Legacy runtime kwargs injection path retained for backwards compatibility with tools + # that still declare ``**kwargs``. New tools should consume runtime data via ``ctx``. + legacy_runtime_kwargs = dict(runtime_kwargs) + if self._forward_runtime_kwargs and legacy_runtime_kwargs: + call_kwargs.update(legacy_runtime_kwargs) + observable_kwargs.update(legacy_runtime_kwargs) + + if self._context_parameter_name is not None and effective_context is not None: + call_kwargs[self._context_parameter_name] = effective_context + if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined] logger.info(f"Function name: {self.name}") - logger.debug(f"Function arguments: {kwargs}") - res = self.__call__(**kwargs) + logger.debug(f"Function arguments: {observable_kwargs}") + res = self.__call__(**call_kwargs) result = await res if inspect.isawaitable(res) else res try: parsed = parser(result) @@ -523,7 +614,7 @@ async def invoke( # Filter out framework kwargs that are not JSON serializable. serializable_kwargs = { k: v - for k, v in kwargs.items() + for k, v in observable_kwargs.items() if k not in { "chat_options", @@ -549,7 +640,7 @@ async def invoke( start_time_stamp = perf_counter() end_time_stamp: float | None = None try: - res = self.__call__(**kwargs) + res = self.__call__(**call_kwargs) result = await res if inspect.isawaitable(res) else res end_time_stamp = perf_counter() except Exception as exception: @@ -1215,19 +1306,30 @@ async def _auto_invoke_function( additional_properties=function_call_content.additional_properties, ) + from ._middleware import FunctionInvocationContext + if middleware_pipeline is None or not middleware_pipeline.has_middlewares: # No middleware - execute directly try: + direct_context = None + if getattr(tool, "_forward_runtime_kwargs", False) or getattr(tool, "_context_parameter_name", None): + direct_context = FunctionInvocationContext( + function=tool, + arguments=args, + kwargs=runtime_kwargs.copy(), + ) function_result = await tool.invoke( arguments=args, + context=direct_context, tool_call_id=function_call_content.call_id, - **runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) return Content.from_function_result( call_id=function_call_content.call_id, # type: ignore[arg-type] result=function_result, additional_properties=function_call_content.additional_properties, ) + except UserInputRequiredException: + raise except Exception as exc: message = "Error: Function failed." if config.get("include_detailed_errors", False): @@ -1239,8 +1341,6 @@ async def _auto_invoke_function( additional_properties=function_call_content.additional_properties, ) # Execute through middleware pipeline if available - from ._middleware import FunctionInvocationContext - middleware_context = FunctionInvocationContext( function=tool, arguments=args, @@ -1250,8 +1350,8 @@ async def _auto_invoke_function( async def final_function_handler(context_obj: Any) -> Any: return await tool.invoke( arguments=context_obj.arguments, + context=context_obj, tool_call_id=function_call_content.call_id, - **context_obj.kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) from ._middleware import MiddlewareTermination @@ -1274,6 +1374,8 @@ async def final_function_handler(context_obj: Any) -> Any: additional_properties=function_call_content.additional_properties, ) raise + except UserInputRequiredException: + raise except Exception as exc: message = "Error: Function failed." if config.get("include_detailed_errors", False): @@ -1381,6 +1483,8 @@ async def _try_execute_function_calls( # Run all function calls concurrently, handling MiddlewareTermination from ._middleware import MiddlewareTermination + extra_user_input_contents: list[Content] = [] + async def invoke_with_termination_handling( function_call: Content, seq_idx: int, @@ -1407,6 +1511,26 @@ async def invoke_with_termination_handling( result=exc.result, ) return (result_content, True) + except UserInputRequiredException as exc: + if exc.contents: + propagated: list[Content] = [] + for item in exc.contents: + if isinstance(item, Content): + item.call_id = function_call.call_id # type: ignore[attr-defined] + if not item.id: # type: ignore[attr-defined] + item.id = function_call.call_id # type: ignore[attr-defined] + propagated.append(item) + if propagated: + extra_user_input_contents.extend(propagated[1:]) + return (propagated[0], False) + return ( + Content.from_function_result( + call_id=function_call.call_id, # type: ignore[arg-type] + result="Tool requires user input but no request details were provided.", + exception="UserInputRequiredException", + ), + False, + ) execution_results = await asyncio.gather(*[ invoke_with_termination_handling(function_call, seq_idx) for seq_idx, function_call in enumerate(function_calls) @@ -1414,6 +1538,7 @@ async def invoke_with_termination_handling( # Unpack results - each is (Content, terminate_flag) contents: list[Content] = [result[0] for result in execution_results] + contents.extend(extra_user_input_contents) # If any function requested termination, terminate the loop should_terminate = any(result[1] for result in execution_results) return (contents, should_terminate) @@ -1645,7 +1770,10 @@ def _handle_function_call_results( ) -> FunctionRequestResult: from ._types import Message - if any(fccr.type in {"function_approval_request", "function_call"} for fccr in function_call_results): + if any( + fccr.type in {"function_approval_request", "function_call"} or fccr.user_input_request + for fccr in function_call_results + ): # Only add items that aren't already in the message (e.g. function_approval_request wrappers). # Declaration-only function_call items are already present from the LLM response. new_items = [fccr for fccr in function_call_results if fccr.type != "function_call"] @@ -1811,6 +1939,8 @@ def get_response( *, stream: Literal[False] = ..., options: ChatOptions[ResponseModelBoundT], + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ... @@ -1821,6 +1951,8 @@ def get_response( *, stream: Literal[False] = ..., options: OptionsCoT | ChatOptions[None] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]]: ... @@ -1831,6 +1963,8 @@ def get_response( *, stream: Literal[True], options: OptionsCoT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... @@ -1841,6 +1975,8 @@ def get_response( stream: bool = False, options: OptionsCoT | ChatOptions[Any] | None = None, function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: from ._middleware import FunctionMiddlewarePipeline @@ -1851,24 +1987,40 @@ def get_response( ) super_get_response = super().get_response # type: ignore[misc] + if kwargs: + warnings.warn( + "Passing client-specific keyword arguments directly to get_response() is deprecated; " + "pass them via client_kwargs instead.", + DeprecationWarning, + stacklevel=2, + ) + + effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + effective_function_middleware = function_middleware + if effective_function_middleware is None: + middleware_from_client_kwargs = effective_client_kwargs.pop("function_middleware", None) + if middleware_from_client_kwargs is not None: + effective_function_middleware = cast(Sequence[Any], middleware_from_client_kwargs) # ChatMiddleware adds this kwarg function_middleware_pipeline = FunctionMiddlewarePipeline( - *(self.function_middleware), *(function_middleware or []) + *(self.function_middleware), *(effective_function_middleware or []) ) max_errors = self.function_invocation_configuration.get( "max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST ) - additional_function_arguments: dict[str, Any] = {} + additional_function_arguments = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) if options and (additional_opts := options.get("additional_function_arguments")): # type: ignore[attr-defined] - additional_function_arguments = additional_opts # type: ignore + additional_function_arguments.update(cast(Mapping[str, Any], additional_opts)) execute_function_calls = partial( _execute_function_calls, custom_args=additional_function_arguments, config=self.function_invocation_configuration, middleware_pipeline=function_middleware_pipeline, ) - filtered_kwargs = {k: v for k, v in kwargs.items() if k != "session"} + filtered_kwargs = {k: v for k, v in {**effective_client_kwargs, **kwargs}.items() if k != "session"} # Make options mutable so we can update conversation_id during function invocation loop mutable_options: dict[str, Any] = dict(options) if options else {} @@ -1918,7 +2070,7 @@ async def _get_response() -> ChatResponse[Any]: messages=prepped_messages, stream=False, options=mutable_options, - **filtered_kwargs, + client_kwargs=filtered_kwargs, ), ) @@ -1987,7 +2139,7 @@ async def _get_response() -> ChatResponse[Any]: messages=prepped_messages, stream=False, options=mutable_options, - **filtered_kwargs, + client_kwargs=filtered_kwargs, ), ) if fcc_messages: @@ -2037,7 +2189,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: messages=prepped_messages, stream=True, options=mutable_options, - **filtered_kwargs, + client_kwargs=filtered_kwargs, ), ) await inner_stream @@ -2129,7 +2281,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: messages=prepped_messages, stream=True, options=mutable_options, - **filtered_kwargs, + client_kwargs=filtered_kwargs, ), ) await final_inner_stream diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index b8d5f5c29a..6606a3b2c5 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -2631,7 +2631,7 @@ def __init__( stream: AsyncIterable[UpdateT] | Awaitable[AsyncIterable[UpdateT]], *, finalizer: Callable[[Sequence[UpdateT]], FinalT | Awaitable[FinalT]] | None = None, - transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None]] | None = None, + transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None]] | None = None, cleanup_hooks: list[Callable[[], Awaitable[None] | None]] | None = None, result_hooks: list[Callable[[FinalT], FinalT | Awaitable[FinalT | None] | None]] | None = None, ) -> None: @@ -2655,7 +2655,7 @@ def __init__( self._consumed: bool = False self._finalized: bool = False self._final_result: FinalT | None = None - self._transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None]] = ( + self._transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None]] = ( transform_hooks if transform_hooks is not None else [] ) self._result_hooks: list[Callable[[FinalT], FinalT | Awaitable[FinalT | None] | None]] = ( @@ -2928,7 +2928,7 @@ async def get_final_response(self) -> FinalT: def with_transform_hook( self, - hook: Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None], + hook: Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None], ) -> ResponseStream[UpdateT, FinalT]: """Register a transform hook executed for each update during iteration.""" self._transform_hooks.append(hook) diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index b57abd6faf..21c38f6b57 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -172,12 +172,12 @@ def __init__( credential: AzureCredentialTypes | AzureTokenProvider | None = None, default_headers: Mapping[str, str] | None = None, async_client: AsyncAzureOpenAI | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, instruction_role: str | None = None, middleware: Sequence[MiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure OpenAI Chat completion client. @@ -205,13 +205,13 @@ def __init__( default_headers: The default headers mapping of string keys to string values for HTTP requests. async_client: An existing client to use. + additional_properties: Additional properties stored on the client instance. env_file_path: Use the environment settings file as a fallback to using env vars. env_file_encoding: The encoding of the environment settings file, defaults to 'utf-8'. instruction_role: The role to use for 'instruction' messages, for example, summarization prompts could use `developer` or `system`. middleware: Optional sequence of middleware to apply to requests. function_invocation_configuration: Optional configuration for function invocation behavior. - kwargs: Other keyword parameters. Examples: .. code-block:: python @@ -283,10 +283,10 @@ class MyOptions(AzureOpenAIChatOptions, total=False): credential=credential, default_headers=default_headers, client=async_client, + additional_properties=additional_properties, instruction_role=instruction_role, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) @override diff --git a/python/packages/core/agent_framework/exceptions.py b/python/packages/core/agent_framework/exceptions.py index f38aa38590..4f56c34b5c 100644 --- a/python/packages/core/agent_framework/exceptions.py +++ b/python/packages/core/agent_framework/exceptions.py @@ -180,6 +180,34 @@ class ToolExecutionException(ToolException): pass +class UserInputRequiredException(ToolException): + """Raised when a tool wrapping a sub-agent requires user input to proceed. + + This exception carries the ``user_input_request`` Content items emitted by + the sub-agent (e.g., ``oauth_consent_request``, ``function_approval_request``) + so the tool invocation layer can propagate them to the parent agent's response + instead of swallowing them as a generic tool error. + + Args: + contents: The user-input-request Content items from the sub-agent response. + message: Human-readable description of why user input is needed. + """ + + def __init__( + self, + contents: list[Any], + message: str = "Tool requires user input to proceed.", + ) -> None: + """Create a UserInputRequiredException. + + Args: + contents: The user-input-request Content items from the sub-agent response. + message: Human-readable description of why user input is needed. + """ + super().__init__(message, log_level=None) + self.contents = contents + + # endregion # region Middleware Exceptions diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index a595582b33..71b8702c14 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1153,18 +1153,47 @@ def get_response( options: OptionsCoT | ChatOptions[Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - """Trace chat responses with OpenTelemetry spans and metrics.""" + """Trace chat responses with OpenTelemetry spans and metrics. + + Args: + messages: The message or messages to send to the model. + stream: Whether to stream the response. Defaults to False. + options: Chat options as a TypedDict. + + Keyword Args: + kwargs: Compatibility keyword arguments from higher client layers. This layer does + not consume ``function_invocation_kwargs`` directly; if present, it is ignored + because function invocation has already been processed above. If a ``client_kwargs`` + mapping is present, it is flattened into ordinary keyword arguments for tracing and + forwarding so clients that use those values continue to work while clients that + ignore extra kwargs remain compatible. + """ from ._types import ChatResponse, ChatResponseUpdate, ResponseStream # type: ignore[reportUnusedImport] global OBSERVABILITY_SETTINGS super_get_response = super().get_response # type: ignore[misc] + compatibility_client_kwargs = kwargs.pop("client_kwargs", None) + kwargs.pop("function_invocation_kwargs", None) + merged_client_kwargs = ( + dict(cast(Mapping[str, Any], compatibility_client_kwargs)) + if isinstance(compatibility_client_kwargs, Mapping) + else {} + ) + merged_client_kwargs.update(kwargs) if not OBSERVABILITY_SETTINGS.ENABLED: - return super_get_response(messages=messages, stream=stream, options=options, **kwargs) # type: ignore[no-any-return] + return super_get_response( # type: ignore[no-any-return] + messages=messages, + stream=stream, + options=options, + **merged_client_kwargs, + ) opts: dict[str, Any] = options or {} # type: ignore[assignment] provider_name = str(getattr(self, "otel_provider_name", "unknown")) - model_id = kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown" + model_id = ( + merged_client_kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown" + ) service_url_func = getattr(self, "service_url", None) service_url = str(service_url_func() if callable(service_url_func) else "unknown") attributes = _get_span_attributes( @@ -1172,13 +1201,18 @@ def get_response( provider_name=provider_name, model=model_id, service_url=service_url, - **kwargs, + **merged_client_kwargs, ) if stream: result_stream = cast( ResponseStream[ChatResponseUpdate, ChatResponse[Any]], - super_get_response(messages=messages, stream=True, options=opts, **kwargs), + super_get_response( + messages=messages, + stream=True, + options=opts, + **merged_client_kwargs, + ), ) # Create span directly without trace.use_span() context attachment. @@ -1266,7 +1300,7 @@ async def _get_response() -> ChatResponse: messages=messages, stream=False, options=opts, - **kwargs, + **merged_client_kwargs, ), ) except Exception as exception: @@ -1393,6 +1427,8 @@ def run( *, stream: Literal[False] = ..., session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -1403,6 +1439,8 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... @@ -1412,6 +1450,8 @@ def run( *, stream: bool = False, session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Trace agent runs with OpenTelemetry spans and metrics.""" @@ -1430,11 +1470,15 @@ def run( messages=messages, stream=stream, session=session, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, **kwargs, ) default_options = getattr(self, "default_options", {}) options = kwargs.get("options") + merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + merged_client_kwargs.update(kwargs) merged_options: dict[str, Any] = merge_chat_options(default_options, options or {}) attributes = _get_span_attributes( operation_name=OtelAttr.AGENT_INVOKE_OPERATION, @@ -1444,7 +1488,7 @@ def run( agent_description=getattr(self, "description", None), thread_id=session.service_session_id if session else None, all_options=merged_options, - **kwargs, + **merged_client_kwargs, ) if stream: @@ -1452,6 +1496,8 @@ def run( messages=messages, stream=True, session=session, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, **kwargs, ) if isinstance(run_result, ResponseStream): @@ -1541,6 +1587,8 @@ async def _run() -> AgentResponse: messages=messages, stream=False, session=session, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, **kwargs, ) except Exception as exception: diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index 0562e68f3e..82d47ec054 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -15,7 +15,7 @@ ) from datetime import datetime, timezone from itertools import chain -from typing import Any, Generic, Literal, cast +from typing import Any, Generic, Literal, cast, overload from openai import AsyncOpenAI, BadRequestError from openai.lib._parsing._completions import type_to_response_format_param @@ -30,7 +30,8 @@ from pydantic import BaseModel from .._clients import BaseChatClient -from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer +from .._docstrings import apply_layered_docstring +from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, FunctionMiddlewareTypes from .._settings import load_settings from .._tools import ( FunctionInvocationConfiguration, @@ -72,6 +73,7 @@ logger = logging.getLogger("agent_framework.openai") +ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None) @@ -213,6 +215,57 @@ def get_web_search_tool( # endregion + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[False] = ..., + options: ChatOptions[ResponseModelBoundT], + **kwargs: Any, + ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ... + + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[False] = ..., + options: OpenAIChatOptionsT | ChatOptions[None] | None = None, + **kwargs: Any, + ) -> Awaitable[ChatResponse[Any]]: ... + + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[True], + options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, + **kwargs: Any, + ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... + + @override + def get_response( + self, + messages: Sequence[Message], + *, + stream: bool = False, + options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, + **kwargs: Any, + ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: + """Get a response from the raw OpenAI chat client.""" + super_get_response = cast( + "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]", + super().get_response, # type: ignore[misc] + ) + return super_get_response( # type: ignore[no-any-return] + messages=messages, + stream=stream, + options=options, + **kwargs, + ) + @override def _inner_get_response( self, @@ -716,6 +769,77 @@ class OpenAIChatClient( # type: ignore[misc] ): """OpenAI Chat completion class with middleware, telemetry, and function invocation support.""" + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[False] = ..., + options: ChatOptions[ResponseModelBoundT], + function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + **kwargs: Any, + ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ... + + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[False] = ..., + options: OpenAIChatOptionsT | ChatOptions[None] | None = None, + function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + **kwargs: Any, + ) -> Awaitable[ChatResponse[Any]]: ... + + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[True], + options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, + function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + **kwargs: Any, + ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... + + @override + def get_response( + self, + messages: Sequence[Message], + *, + stream: bool = False, + options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, + function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + **kwargs: Any, + ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: + """Get a response from the OpenAI chat client with all standard layers enabled.""" + super_get_response = cast( + "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]", + super().get_response, # type: ignore[misc] + ) + return super_get_response( # type: ignore[no-any-return] + messages=messages, + stream=stream, + options=options, + function_middleware=function_middleware, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + middleware=middleware, + **kwargs, + ) + def __init__( self, *, @@ -819,3 +943,25 @@ class MyOptions(OpenAIChatOptions, total=False): middleware=middleware, function_invocation_configuration=function_invocation_configuration, ) + + +def _apply_openai_chat_client_docstrings() -> None: + """Align OpenAI chat-client docstrings with the raw implementation.""" + apply_layered_docstring(RawOpenAIChatClient.get_response, BaseChatClient.get_response) + apply_layered_docstring( + OpenAIChatClient.get_response, + RawOpenAIChatClient.get_response, + extra_keyword_args={ + "function_middleware": """ + Optional per-call function middleware. + When omitted, middleware configured on the client or forwarded from higher layers is used. + """, + "middleware": """ + Optional per-call chat and function middleware. + This is merged with any middleware configured on the client for the current request. + """, + }, + ) + + +_apply_openai_chat_client_docstrings() diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index a60e924387..ec93d6c242 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import contextlib +import inspect from collections.abc import AsyncIterable, MutableSequence from typing import Any from unittest.mock import AsyncMock, MagicMock @@ -27,6 +28,7 @@ ) from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name from agent_framework._mcp import MCPTool +from agent_framework._middleware import FunctionInvocationContext def test_agent_session_type(agent_session: AgentSession) -> None: @@ -65,6 +67,30 @@ def test_chat_client_agent_type(client: SupportsChatGetResponse) -> None: assert isinstance(chat_client_agent, SupportsAgentRun) +def test_agent_init_docstring_surfaces_raw_agent_constructor_docs() -> None: + docstring = inspect.getdoc(Agent.__init__) + + assert docstring is not None + assert "client: The chat client to use for the agent." in docstring + assert "middleware: List of middleware to intercept agent and function invocations." in docstring + + +def test_agent_run_docstring_surfaces_raw_agent_runtime_docs() -> None: + docstring = inspect.getdoc(Agent.run) + + assert docstring is not None + assert "Run the agent with the given messages and options." in docstring + assert "function_invocation_kwargs: Keyword arguments forwarded to tool invocation." in docstring + assert "middleware: Optional per-run agent, chat, and function middleware." in docstring + + +def test_agent_run_is_defined_on_agent_class() -> None: + signature = inspect.signature(Agent.run) + + assert Agent.run.__qualname__ == "Agent.run" + assert "middleware" in signature.parameters + + async def test_chat_client_agent_init(client: SupportsChatGetResponse) -> None: agent_id = str(uuid4()) agent = Agent(client=client, id=agent_id, description="Test") @@ -85,6 +111,13 @@ async def test_chat_client_agent_init_with_name( assert agent.description == "Test" +def test_agent_init_warns_for_direct_additional_properties(client: SupportsChatGetResponse) -> None: + with pytest.warns(DeprecationWarning, match="additional_properties"): + agent = Agent(client=client, legacy_key="legacy-value") + + assert agent.additional_properties["legacy_key"] == "legacy-value" + + async def test_chat_client_agent_run(client: SupportsChatGetResponse) -> None: agent = Agent(client=client) @@ -217,9 +250,36 @@ async def test_prepare_session_does_not_mutate_agent_chat_options( assert len(agent.default_options["tools"]) == 1 -async def test_chat_client_agent_run_with_session( +async def test_prepare_run_context_separates_function_invocation_kwargs_from_chat_options( chat_client_base: SupportsChatGetResponse, ) -> None: + agent = Agent(client=chat_client_base) + session = agent.create_session() + + ctx = await agent._prepare_run_context( # type: ignore[reportPrivateUsage] + messages="Hello", + session=session, + tools=None, + options={ + "temperature": 0.4, + "additional_function_arguments": {"from_options": "options-value"}, + }, + legacy_kwargs={"legacy_key": "legacy-value"}, + function_invocation_kwargs={"runtime_key": "runtime-value"}, + client_kwargs={"client_key": "client-value"}, + ) + + assert ctx["chat_options"]["temperature"] == 0.4 + assert "additional_function_arguments" not in ctx["chat_options"] + assert ctx["function_invocation_kwargs"]["from_options"] == "options-value" + assert ctx["function_invocation_kwargs"]["legacy_key"] == "legacy-value" + assert ctx["function_invocation_kwargs"]["runtime_key"] == "runtime-value" + assert "session" not in ctx["function_invocation_kwargs"] + assert ctx["client_kwargs"]["client_key"] == "client-value" + assert "session" not in ctx["client_kwargs"] + + +async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChatGetResponse) -> None: mock_response = ChatResponse( messages=[Message(role="assistant", contents=[Content.from_text("test response")])], conversation_id="123", @@ -660,8 +720,9 @@ async def test_chat_agent_as_tool_basic(client: SupportsChatGetResponse) -> None assert tool.name == "TestAgent" assert tool.description == "Test agent for as_tool" + assert tool.approval_mode == "never_require" assert hasattr(tool, "func") - assert hasattr(tool, "input_model") + assert tool.input_model is None async def test_chat_agent_as_tool_custom_parameters( @@ -675,13 +736,15 @@ async def test_chat_agent_as_tool_custom_parameters( description="Custom description", arg_name="query", arg_description="Custom input description", + approval_mode="always_require", ) assert tool.name == "CustomTool" assert tool.description == "Custom description" + assert tool.approval_mode == "always_require" # Check that the input model has the custom field name - schema = tool.input_model.model_json_schema() + schema = tool.parameters() assert "query" in schema["properties"] assert schema["properties"]["query"]["description"] == "Custom input description" @@ -700,7 +763,7 @@ async def test_chat_agent_as_tool_defaults(client: SupportsChatGetResponse) -> N assert tool.description == "" # Should default to empty string # Check default input field - schema = tool.input_model.model_json_schema() + schema = tool.parameters() assert "task" in schema["properties"] assert "Task for TestAgent" in schema["properties"]["task"]["description"] @@ -723,11 +786,11 @@ async def test_chat_agent_as_tool_function_execution( tool = agent.as_tool() # Test function execution - result = await tool.invoke(arguments=tool.input_model(task="Hello")) + result = await tool.invoke(arguments={"task": "Hello"}) - # Should return the agent's response text + # as_tool always uses streaming and finalizes the accumulated updates. assert isinstance(result, str) - assert result == "test response" # From mock chat client + assert result == "test streaming response another update" async def test_chat_agent_as_tool_with_stream_callback( @@ -745,7 +808,7 @@ def stream_callback(update: AgentResponseUpdate) -> None: tool = agent.as_tool(stream_callback=stream_callback) # Execute the tool - result = await tool.invoke(arguments=tool.input_model(task="Hello")) + result = await tool.invoke(arguments={"task": "Hello"}) # Should have collected streaming updates assert len(collected_updates) > 0 @@ -764,8 +827,8 @@ async def test_chat_agent_as_tool_with_custom_arg_name( tool = agent.as_tool(arg_name="prompt", arg_description="Custom prompt input") # Test that the custom argument name works - result = await tool.invoke(arguments=tool.input_model(prompt="Test prompt")) - assert result == "test response" + result = await tool.invoke(arguments={"prompt": "Test prompt"}) + assert result == "test streaming response another update" async def test_chat_agent_as_tool_with_async_stream_callback( @@ -783,7 +846,7 @@ async def async_stream_callback(update: AgentResponseUpdate) -> None: tool = agent.as_tool(stream_callback=async_stream_callback) # Execute the tool - result = await tool.invoke(arguments=tool.input_model(task="Hello")) + result = await tool.invoke(arguments={"task": "Hello"}) # Should have collected streaming updates assert len(collected_updates) > 0 @@ -813,10 +876,8 @@ async def test_chat_agent_as_tool_name_sanitization( assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}" -async def test_chat_agent_as_tool_propagate_session_true( - client: SupportsChatGetResponse, -) -> None: - """Test that propagate_session=True forwards the parent's session to the sub-agent.""" +async def test_chat_agent_as_tool_propagate_session_true(client: SupportsChatGetResponse) -> None: + """Test that propagate_session=True forwards an explicitly provided session to the sub-agent.""" agent = Agent(client=client, name="SubAgent", description="Sub agent") tool = agent.as_tool(propagate_session=True) @@ -834,17 +895,21 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: agent.run = capturing_run # type: ignore[assignment, method-assign] - await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session) + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + kwargs={"session": parent_session}, + ) + ) assert captured_session is parent_session assert captured_session.session_id == "parent-session-123" assert captured_session.state["shared_key"] == "shared_value" -async def test_chat_agent_as_tool_propagate_session_false_by_default( - client: SupportsChatGetResponse, -) -> None: - """Test that propagate_session defaults to False and does not forward the session.""" +async def test_chat_agent_as_tool_propagate_session_false_by_default(client: SupportsChatGetResponse) -> None: + """Test that propagate_session defaults to False and does not forward runtime sessions.""" agent = Agent(client=client, name="SubAgent", description="Sub agent") tool = agent.as_tool() # default: propagate_session=False @@ -860,15 +925,19 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: agent.run = capturing_run # type: ignore[assignment, method-assign] - await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session) + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + kwargs={"session": parent_session}, + ) + ) assert captured_session is None -async def test_chat_agent_as_tool_propagate_session_shares_state( - client: SupportsChatGetResponse, -) -> None: - """Test that shared session allows the sub-agent to read and write parent's state.""" +async def test_chat_agent_as_tool_propagate_session_shares_state(client: SupportsChatGetResponse) -> None: + """Test that an explicitly propagated session allows the sub-agent to read and write parent state.""" agent = Agent(client=client, name="SubAgent", description="Sub agent") tool = agent.as_tool(propagate_session=True) @@ -888,7 +957,13 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: agent.run = capturing_run # type: ignore[assignment, method-assign] - await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session) + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + kwargs={"session": parent_session}, + ) + ) # The parent's state should reflect the sub-agent's mutation assert parent_session.state["counter"] == 1 @@ -992,8 +1067,8 @@ async def capturing_inner( assert len(tool_names) == 3 -async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None: - """Verify tool execution receives 'session' inside **kwargs when function is called by client.""" +async def test_agent_tool_does_not_receive_agent_session_in_kwargs(chat_client_base: Any) -> None: + """Verify agent sessions are not injected into tool kwargs implicitly.""" captured: dict[str, Any] = {} @@ -1029,10 +1104,49 @@ def echo_session_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUn session=session, options={"additional_function_arguments": {"session": session}}, ) + assert result.text == "done" + assert captured.get("has_session") is False + assert captured.get("has_state") is False + + +async def test_agent_tool_receives_explicit_session_via_function_invocation_context_kwargs( + chat_client_base: Any, +) -> None: + """Verify ctx-based tools read explicit sessions from FunctionInvocationContext.kwargs.""" + + captured: dict[str, Any] = {} + + @tool(name="capture_session_context", approval_mode="never_require") + def capture_session_context(text: str, ctx: FunctionInvocationContext) -> str: + session = ctx.kwargs.get("session") + captured["session"] = session + captured["has_state"] = session.state is not None if isinstance(session, AgentSession) else False + return f"echo: {text}" + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="1", + name="capture_session_context", + arguments='{"text": "hello"}', + ) + ], + ) + ), + ChatResponse(messages=Message(role="assistant", text="done")), + ] + + agent = Agent(client=chat_client_base, tools=[capture_session_context]) + session = agent.create_session() + + result = await agent.run("hello", session=session, function_invocation_kwargs={"session": session}) assert result.text == "done" - assert captured.get("has_session") is True - assert captured.get("has_state") is True + assert captured["session"] is session + assert captured["has_state"] is True async def test_chat_agent_tool_choice_run_level_overrides_agent_level(chat_client_base: Any, tool_tool: Any) -> None: @@ -1622,4 +1736,26 @@ async def test_stores_by_default_with_store_false_in_default_options_injects_inm assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers) -# endregion +# region as_tool user_input_request propagation + + +async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetResponse) -> None: + """Test that as_tool raises when the wrapped sub-agent requests user input.""" + from agent_framework.exceptions import UserInputRequiredException + + consent_content = Content.from_oauth_consent_request( + consent_link="https://login.microsoftonline.com/consent", + ) + client.streaming_responses = [ # type: ignore[attr-defined] + [ChatResponseUpdate(contents=[consent_content], role="assistant")], + ] + + agent = Agent(client=client, name="OAuthAgent", description="Agent requiring consent") + agent_tool = agent.as_tool() + + with raises(UserInputRequiredException) as exc_info: + await agent_tool.invoke(arguments={"task": "Do something"}) + + assert len(exc_info.value.contents) == 1 + assert exc_info.value.contents[0].type == "oauth_consent_request" + assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent" diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index da8e907c40..8aa71a4582 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -6,7 +6,7 @@ from typing import Any from agent_framework import Agent, ChatResponse, Content, Message, agent_middleware -from agent_framework._middleware import AgentContext +from agent_framework._middleware import AgentContext, FunctionInvocationContext from .conftest import MockChatClient @@ -14,14 +14,28 @@ class TestAsToolKwargsPropagation: """Test cases for kwargs propagation through as_tool() delegation.""" + @staticmethod + def _build_context( + tool: Any, + *, + task: str, + runtime_kwargs: dict[str, Any] | None = None, + ) -> FunctionInvocationContext: + return FunctionInvocationContext( + function=tool, + arguments={"task": task}, + kwargs=runtime_kwargs, + ) + async def test_as_tool_forwards_runtime_kwargs(self, client: MockChatClient) -> None: - """Test that runtime kwargs are forwarded through as_tool() to sub-agent.""" + """Test that runtime kwargs are forwarded through as_tool() to sub-agent tools.""" captured_kwargs: dict[str, Any] = {} + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - # Capture kwargs passed to the sub-agent captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -39,29 +53,31 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Create tool from sub-agent tool = sub_agent.as_tool(name="delegate", arg_name="task") - # Directly invoke the tool with kwargs (simulating what happens during agent execution) + # Directly invoke the tool with explicit runtime context (simulating agent execution). _ = await tool.invoke( - arguments=tool.input_model(task="Test delegation"), - api_token="secret-xyz-123", - user_id="user-456", - session_id="session-789", + context=self._build_context( + tool, + task="Test delegation", + runtime_kwargs={ + "api_token": "secret-xyz-123", + "user_id": "user-456", + "session_id": "session-789", + }, + ), ) - # Verify kwargs were forwarded to sub-agent - assert "api_token" in captured_kwargs, f"Expected 'api_token' in {captured_kwargs}" - assert captured_kwargs["api_token"] == "secret-xyz-123" - assert "user_id" in captured_kwargs - assert captured_kwargs["user_id"] == "user-456" - assert "session_id" in captured_kwargs - assert captured_kwargs["session_id"] == "session-789" + assert captured_kwargs == {} + assert captured_function_invocation_kwargs["api_token"] == "secret-xyz-123" + assert captured_function_invocation_kwargs["user_id"] == "user-456" + assert captured_function_invocation_kwargs["session_id"] == "session-789" - async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, client: MockChatClient) -> None: - """Test that the arg_name parameter is not forwarded as a kwarg.""" - captured_kwargs: dict[str, Any] = {} + async def test_as_tool_forwards_context_kwargs_verbatim(self, client: MockChatClient) -> None: + """Test that runtime kwargs are forwarded exactly from FunctionInvocationContext.kwargs.""" + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -79,25 +95,26 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke tool with both the arg_name field and additional kwargs await tool.invoke( - arguments=tool.input_model(custom_task="Test task"), - api_token="token-123", - custom_task="should_be_excluded", # This should be filtered out + context=FunctionInvocationContext( + function=tool, + arguments={"custom_task": "Test task"}, + kwargs={ + "api_token": "token-123", + "custom_task": "should_be_excluded", + }, + ) ) - # The arg_name ("custom_task") should NOT be in the forwarded kwargs - assert "custom_task" not in captured_kwargs - # But other kwargs should be present - assert "api_token" in captured_kwargs - assert captured_kwargs["api_token"] == "token-123" + assert captured_function_invocation_kwargs["custom_task"] == "should_be_excluded" + assert captured_function_invocation_kwargs["api_token"] == "token-123" async def test_as_tool_nested_delegation_propagates_kwargs(self, client: MockChatClient) -> None: - """Test that kwargs propagate through multiple levels of delegation (A → B → C).""" - captured_kwargs_list: list[dict[str, Any]] = [] + """Test that runtime kwargs propagate through multiple levels of delegation (A -> B -> C).""" + captured_function_invocation_kwargs_list: list[dict[str, Any]] = [] @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - # Capture kwargs at each level - captured_kwargs_list.append(dict(context.kwargs)) + captured_function_invocation_kwargs_list.append(dict(context.function_invocation_kwargs)) await call_next() # Setup mock responses to trigger nested tool invocation: B calls tool C, then completes. @@ -140,24 +157,29 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke tool B with kwargs - should propagate to both B and C await tool_b.invoke( - arguments=tool_b.input_model(task="Test cascade"), - trace_id="trace-abc-123", - tenant_id="tenant-xyz", - options={"additional_function_arguments": {"trace_id": "trace-abc-123", "tenant_id": "tenant-xyz"}}, + context=self._build_context( + tool_b, + task="Test cascade", + runtime_kwargs={ + "trace_id": "trace-abc-123", + "tenant_id": "tenant-xyz", + }, + ), ) - # Verify kwargs were forwarded to the first agent invocation. - assert len(captured_kwargs_list) >= 1 - assert captured_kwargs_list[0].get("trace_id") == "trace-abc-123" - assert captured_kwargs_list[0].get("tenant_id") == "tenant-xyz" + assert len(captured_function_invocation_kwargs_list) >= 1 + assert captured_function_invocation_kwargs_list[0].get("trace_id") == "trace-abc-123" + assert captured_function_invocation_kwargs_list[0].get("tenant_id") == "tenant-xyz" async def test_as_tool_streaming_mode_forwards_kwargs(self, client: MockChatClient) -> None: - """Test that kwargs are forwarded in streaming mode.""" + """Test that runtime kwargs are forwarded in streaming mode.""" captured_kwargs: dict[str, Any] = {} + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock streaming responses @@ -182,13 +204,15 @@ async def stream_callback(update: Any) -> None: # Invoke tool with kwargs while streaming callback is active await tool.invoke( - arguments=tool.input_model(task="Test streaming"), - api_key="streaming-key-999", + context=self._build_context( + tool, + task="Test streaming", + runtime_kwargs={"api_key": "streaming-key-999"}, + ), ) - # Verify kwargs were forwarded even in streaming mode - assert "api_key" in captured_kwargs - assert captured_kwargs["api_key"] == "streaming-key-999" + assert captured_kwargs == {} + assert captured_function_invocation_kwargs["api_key"] == "streaming-key-999" assert len(captured_updates) == 1 async def test_as_tool_empty_kwargs_still_works(self, client: MockChatClient) -> None: @@ -206,18 +230,20 @@ async def test_as_tool_empty_kwargs_still_works(self, client: MockChatClient) -> tool = sub_agent.as_tool() # Invoke without any extra kwargs - should work without errors - result = await tool.invoke(arguments=tool.input_model(task="Simple task")) + result = await tool.invoke(arguments={"task": "Simple task"}) # Verify tool executed successfully assert result is not None async def test_as_tool_kwargs_with_chat_options(self, client: MockChatClient) -> None: - """Test that kwargs including chat_options are properly forwarded.""" + """Test that runtime kwargs are forwarded only via function_invocation_kwargs.""" captured_kwargs: dict[str, Any] = {} + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -235,24 +261,26 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke with various kwargs await tool.invoke( - arguments=tool.input_model(task="Test with options"), - temperature=0.8, - max_tokens=500, - custom_param="custom_value", + context=self._build_context( + tool, + task="Test with options", + runtime_kwargs={ + "temperature": 0.8, + "max_tokens": 500, + "custom_param": "custom_value", + }, + ), ) - # Verify all kwargs were forwarded - assert "temperature" in captured_kwargs - assert captured_kwargs["temperature"] == 0.8 - assert "max_tokens" in captured_kwargs - assert captured_kwargs["max_tokens"] == 500 - assert "custom_param" in captured_kwargs - assert captured_kwargs["custom_param"] == "custom_value" + assert captured_kwargs == {} + assert captured_function_invocation_kwargs["temperature"] == 0.8 + assert captured_function_invocation_kwargs["max_tokens"] == 500 + assert captured_function_invocation_kwargs["custom_param"] == "custom_value" async def test_as_tool_kwargs_isolated_per_invocation(self, client: MockChatClient) -> None: - """Test that kwargs are isolated per invocation and don't leak between calls.""" - first_call_kwargs: dict[str, Any] = {} - second_call_kwargs: dict[str, Any] = {} + """Test that runtime kwargs are isolated per invocation and don't leak between calls.""" + first_call_function_invocation_kwargs: dict[str, Any] = {} + second_call_function_invocation_kwargs: dict[str, Any] = {} call_count = 0 @agent_middleware @@ -260,9 +288,9 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai nonlocal call_count call_count += 1 if call_count == 1: - first_call_kwargs.update(context.kwargs) + first_call_function_invocation_kwargs.update(context.function_invocation_kwargs) elif call_count == 2: - second_call_kwargs.update(context.kwargs) + second_call_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock responses for both calls @@ -281,33 +309,35 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # First call with specific kwargs await tool.invoke( - arguments=tool.input_model(task="First task"), - session_id="session-1", - api_token="token-1", + context=self._build_context( + tool, + task="First task", + runtime_kwargs={"session_id": "session-1", "api_token": "token-1"}, + ), ) # Second call with different kwargs await tool.invoke( - arguments=tool.input_model(task="Second task"), - session_id="session-2", - api_token="token-2", + context=self._build_context( + tool, + task="Second task", + runtime_kwargs={"session_id": "session-2", "api_token": "token-2"}, + ), ) - # Verify first call had its own kwargs - assert first_call_kwargs.get("session_id") == "session-1" - assert first_call_kwargs.get("api_token") == "token-1" + assert first_call_function_invocation_kwargs.get("session_id") == "session-1" + assert first_call_function_invocation_kwargs.get("api_token") == "token-1" - # Verify second call had its own kwargs (not leaked from first) - assert second_call_kwargs.get("session_id") == "session-2" - assert second_call_kwargs.get("api_token") == "token-2" + assert second_call_function_invocation_kwargs.get("session_id") == "session-2" + assert second_call_function_invocation_kwargs.get("api_token") == "token-2" - async def test_as_tool_excludes_conversation_id_from_forwarded_kwargs(self, client: MockChatClient) -> None: - """Test that conversation_id is not forwarded to sub-agent.""" - captured_kwargs: dict[str, Any] = {} + async def test_as_tool_forwards_conversation_id_from_context_kwargs(self, client: MockChatClient) -> None: + """Test that conversation_id is forwarded when explicitly present in runtime context kwargs.""" + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -325,17 +355,17 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke tool with conversation_id in kwargs (simulating parent's conversation state) await tool.invoke( - arguments=tool.input_model(task="Test delegation"), - conversation_id="conv-parent-456", - api_token="secret-xyz-123", - user_id="user-456", - ) - - # Verify conversation_id was NOT forwarded to sub-agent - assert "conversation_id" not in captured_kwargs, ( - f"conversation_id should not be forwarded, but got: {captured_kwargs}" + context=self._build_context( + tool, + task="Test delegation", + runtime_kwargs={ + "conversation_id": "conv-parent-456", + "api_token": "secret-xyz-123", + "user_id": "user-456", + }, + ), ) - # Verify other kwargs were still forwarded - assert captured_kwargs.get("api_token") == "secret-xyz-123" - assert captured_kwargs.get("user_id") == "user-456" + assert captured_function_invocation_kwargs.get("conversation_id") == "conv-parent-456" + assert captured_function_invocation_kwargs.get("api_token") == "secret-xyz-123" + assert captured_function_invocation_kwargs.get("user_id") == "user-456" diff --git a/python/packages/core/tests/core/test_clients.py b/python/packages/core/tests/core/test_clients.py index a23b1d2a5f..670ff5f455 100644 --- a/python/packages/core/tests/core/test_clients.py +++ b/python/packages/core/tests/core/test_clients.py @@ -1,8 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. +import inspect from unittest.mock import patch +import pytest + from agent_framework import ( BaseChatClient, ChatResponse, @@ -37,6 +40,60 @@ def test_base_client(chat_client_base: SupportsChatGetResponse): assert isinstance(chat_client_base, SupportsChatGetResponse) +def test_base_client_warns_for_direct_additional_properties(chat_client_base: SupportsChatGetResponse) -> None: + with pytest.warns(DeprecationWarning, match="additional_properties"): + client = type(chat_client_base)(legacy_key="legacy-value") + + assert client.additional_properties["legacy_key"] == "legacy-value" + + +def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_base: SupportsChatGetResponse) -> None: + agent = chat_client_base.as_agent(additional_properties={"team": "core"}) + + assert agent.additional_properties == {"team": "core"} + + +def test_openai_chat_client_get_response_docstring_surfaces_layered_runtime_docs() -> None: + from agent_framework.openai import OpenAIChatClient + + docstring = inspect.getdoc(OpenAIChatClient.get_response) + + assert docstring is not None + assert "Get a response from a chat client." in docstring + assert "function_invocation_kwargs" in docstring + assert "function_middleware: Optional per-call function middleware." in docstring + assert "middleware: Optional per-call chat and function middleware." in docstring + + +def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None: + from agent_framework.openai import OpenAIChatClient + + signature = inspect.signature(OpenAIChatClient.get_response) + + assert OpenAIChatClient.get_response.__qualname__ == "OpenAIChatClient.get_response" + assert "function_middleware" in signature.parameters + assert "middleware" in signature.parameters + + +async def test_base_client_get_response_uses_explicit_client_kwargs(chat_client_base: SupportsChatGetResponse) -> None: + async def fake_inner_get_response(**kwargs): + assert kwargs["trace_id"] == "trace-123" + assert "function_invocation_kwargs" not in kwargs + return ChatResponse(messages=[Message(role="assistant", text="ok")]) + + with patch.object( + chat_client_base, + "_inner_get_response", + side_effect=fake_inner_get_response, + ) as mock_inner_get_response: + await chat_client_base.get_response( + [Message(role="user", text="hello")], + function_invocation_kwargs={"tool_request_id": "tool-123"}, + client_kwargs={"trace_id": "trace-123"}, + ) + mock_inner_get_response.assert_called_once() + + async def test_base_client_get_response(chat_client_base: SupportsChatGetResponse): response = await chat_client_base.get_response([Message(role="user", text="Hello")]) assert response.messages[0].role == "assistant" diff --git a/python/packages/core/tests/core/test_embedding_client.py b/python/packages/core/tests/core/test_embedding_client.py index 71d2bcfd70..1c49c1d012 100644 --- a/python/packages/core/tests/core/test_embedding_client.py +++ b/python/packages/core/tests/core/test_embedding_client.py @@ -4,6 +4,8 @@ from collections.abc import Sequence +import pytest + from agent_framework import ( BaseEmbeddingClient, Embedding, @@ -63,6 +65,11 @@ def test_base_additional_properties_custom() -> None: assert client.additional_properties == {"key": "value"} +def test_base_embedding_client_rejects_unknown_kwargs() -> None: + with pytest.raises(TypeError): + MockEmbeddingClient(legacy_key="value") # type: ignore[call-arg] + + # --- SupportsGetEmbeddings protocol tests --- diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 7f0eda62fc..50df7ccea1 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -3512,3 +3512,44 @@ def test_dict_overwrites_existing_conversation_id(self): # endregion +async def test_user_input_request_propagates_through_as_tool(chat_client_base: SupportsChatGetResponse): + """Test that user_input_request content from a sub-agent wrapped as a tool propagates to the parent response.""" + from agent_framework.exceptions import UserInputRequiredException + + @tool(name="delegate_agent", approval_mode="never_require") + def delegate_tool(task: str) -> str: + del task + raise UserInputRequiredException( + contents=[ + Content.from_oauth_consent_request( + consent_link="https://login.microsoftonline.com/consent", + ) + ] + ) + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="delegate_agent", arguments='{"task": "do it"}'), + ], + ) + ) + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="delegate this")], + options={"tool_choice": "auto", "tools": [delegate_tool]}, + ) + + user_requests = [ + content + for msg in response.messages + for content in msg.contents + if isinstance(content, Content) and content.user_input_request + ] + assert len(user_requests) == 1 + assert user_requests[0].type == "oauth_consent_request" + assert user_requests[0].consent_link == "https://login.microsoftonline.com/consent" + assert user_requests[0].user_input_request is True diff --git a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py index cecd466d86..160ea0fcc4 100644 --- a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py +++ b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py @@ -6,11 +6,13 @@ from typing import Any from agent_framework import ( + Agent, BaseChatClient, ChatMiddlewareLayer, ChatResponse, ChatResponseUpdate, Content, + FunctionInvocationContext, FunctionInvocationLayer, Message, ResponseStream, @@ -97,6 +99,7 @@ class TestKwargsPropagationToFunctionTool: async def test_kwargs_propagate_to_tool_with_kwargs(self) -> None: """Test that kwargs passed to get_response() are available in @tool **kwargs.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. captured_kwargs: dict[str, Any] = {} @tool(approval_mode="never_require") @@ -149,6 +152,7 @@ def capture_kwargs_tool(x: int, **kwargs: Any) -> str: async def test_kwargs_not_forwarded_to_tool_without_kwargs(self) -> None: """Test that kwargs are NOT forwarded to @tool that doesn't accept **kwargs.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. @tool(approval_mode="never_require") def simple_tool(x: int) -> str: @@ -185,6 +189,7 @@ def simple_tool(x: int) -> str: async def test_kwargs_isolated_between_function_calls(self) -> None: """Test that kwargs are consistent across multiple function call invocations.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. invocation_kwargs: list[dict[str, Any]] = [] @tool(approval_mode="never_require") @@ -235,6 +240,7 @@ def tracking_tool(name: str, **kwargs: Any) -> str: async def test_streaming_response_kwargs_propagation(self) -> None: """Test that kwargs propagate to @tool in streaming mode.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. captured_kwargs: dict[str, Any] = {} @tool(approval_mode="never_require") @@ -287,3 +293,59 @@ def streaming_capture_tool(value: str, **kwargs: Any) -> str: assert "streaming_session" in captured_kwargs, f"Expected 'streaming_session' in {captured_kwargs}" assert captured_kwargs["streaming_session"] == "session-xyz" assert captured_kwargs["correlation_id"] == "corr-123" + + async def test_agent_run_injects_function_invocation_context(self) -> None: + """Test that Agent.run injects FunctionInvocationContext for ctx-based tools.""" + captured_context_kwargs: dict[str, Any] = {} + captured_client_kwargs: dict[str, Any] = {} + captured_options: dict[str, Any] = {} + + @tool(approval_mode="never_require") + def capture_context_tool(x: int, ctx: FunctionInvocationContext) -> str: + captured_context_kwargs.update(ctx.kwargs) + return f"result: x={x}" + + class CapturingFunctionInvokingMockClient(FunctionInvokingMockClient): + async def _get_non_streaming_response( + self, + *, + messages: MutableSequence[Message], + options: dict[str, Any], + **kwargs: Any, + ) -> ChatResponse: + captured_options.update(options) + captured_client_kwargs.update(kwargs) + return await super()._get_non_streaming_response(messages=messages, options=options, **kwargs) + + client = CapturingFunctionInvokingMockClient() + client.run_responses = [ + ChatResponse( + messages=[ + Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_1", + name="capture_context_tool", + arguments='{"x": 42}', + ) + ], + ) + ] + ), + ChatResponse(messages=[Message(role="assistant", text="Done!")]), + ] + + agent = Agent(client=client, tools=[capture_context_tool]) + result = await agent.run( + [Message(role="user", text="Test")], + function_invocation_kwargs={"tool_request_id": "tool-123"}, + client_kwargs={"client_request_id": "client-456"}, + ) + + assert captured_context_kwargs["tool_request_id"] == "tool-123" + assert "client_request_id" not in captured_context_kwargs + assert captured_client_kwargs["client_request_id"] == "client-456" + assert "tool_request_id" not in captured_client_kwargs + assert "additional_function_arguments" not in captured_options + assert result.messages[-1].text == "Done!" diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index 4d2e603274..bd2cb8155e 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -192,10 +192,10 @@ def __init__(self, source_id: str, stored_messages: list[Message] | None = None, self.stored: list[Message] = [] self._stored_messages = stored_messages or [] - async def get_messages(self, session_id: str | None, **kwargs) -> list[Message]: + async def get_messages(self, session_id: str | None, *, state=None, **kwargs) -> list[Message]: return list(self._stored_messages) - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs) -> None: + async def save_messages(self, session_id: str | None, messages: Sequence[Message], *, state=None, **kwargs) -> None: self.stored.extend(messages) diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index f7674edc9b..fcab14ee1b 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -12,6 +12,7 @@ FunctionTool, tool, ) +from agent_framework._middleware import FunctionInvocationContext from agent_framework._tools import ( _parse_annotation, _parse_inputs, @@ -941,6 +942,128 @@ def tool_with_kwargs(x: int, **kwargs: Any) -> str: assert result_default == "x=10, user=unknown" +async def test_ai_function_with_explicit_invocation_context(): + """Test that invoke() can receive runtime kwargs via FunctionInvocationContext.""" + + @tool + def tool_with_context(x: int, ctx: FunctionInvocationContext) -> str: + """A tool that accepts runtime context injection.""" + user_id = ctx.kwargs.get("user_id", "unknown") + return f"x={x}, user={user_id}" + + assert tool_with_context.parameters() == { + "properties": {"x": {"title": "X", "type": "integer"}}, + "required": ["x"], + "title": "tool_with_context_input", + "type": "object", + } + + context = FunctionInvocationContext( + function=tool_with_context, + arguments=tool_with_context.input_model(x=7), + kwargs={"user_id": "ctx-user"}, + ) + + result = await tool_with_context.invoke(context=context) + + assert result == "x=7, user=ctx-user" + + +async def test_ai_function_with_typed_context_parameter_using_custom_name(): + """Test that typed context injection works for names other than ctx.""" + + @tool + def tool_with_runtime_context(x: int, runtime: FunctionInvocationContext) -> str: + """A tool that uses a custom context parameter name.""" + user_id = runtime.kwargs.get("user_id", "unknown") + return f"x={x}, user={user_id}" + + assert tool_with_runtime_context.parameters() == { + "properties": {"x": {"title": "X", "type": "integer"}}, + "required": ["x"], + "title": "tool_with_runtime_context_input", + "type": "object", + } + + context = FunctionInvocationContext( + function=tool_with_runtime_context, + arguments=tool_with_runtime_context.input_model(x=8), + kwargs={"user_id": "runtime-user"}, + ) + + result = await tool_with_runtime_context.invoke(context=context) + + assert result == "x=8, user=runtime-user" + + +async def test_ai_function_with_explicit_schema_and_untyped_ctx(): + """Test that explicit schemas allow an untyped ctx parameter.""" + + class ToolInput(BaseModel): + x: int + + @tool(schema=ToolInput) + def tool_with_schema(x, ctx) -> str: + """A tool with explicit schema and implicit ctx injection.""" + return f"x={x}, user={ctx.kwargs.get('user_id', 'unknown')}" + + context = FunctionInvocationContext( + function=tool_with_schema, + arguments=ToolInput(x=9), + kwargs={"user_id": "schema-user"}, + ) + + result = await tool_with_schema.invoke(context=context) + + assert result == "x=9, user=schema-user" + + +async def test_ai_function_with_explicit_schema_and_typed_ctx(): + """Test that explicit schemas also work with typed context injection.""" + + class ToolInput(BaseModel): + x: int + + @tool(schema=ToolInput) + def tool_with_schema(x: int, runtime: FunctionInvocationContext) -> str: + """A tool with explicit schema and typed context injection.""" + return f"x={x}, user={runtime.kwargs.get('user_id', 'unknown')}" + + context = FunctionInvocationContext( + function=tool_with_schema, + arguments=ToolInput(x=11), + kwargs={"user_id": "typed-schema-user"}, + ) + + result = await tool_with_schema.invoke(context=context) + + assert tool_with_schema.parameters() == ToolInput.model_json_schema() + assert result == "x=11, user=typed-schema-user" + + +def test_ai_function_with_multiple_typed_context_parameters_fails(): + """Test that tools reject multiple typed FunctionInvocationContext parameters.""" + + with pytest.raises(ValueError, match="multiple FunctionInvocationContext parameters"): + + @tool + def invalid_tool(ctx_one: FunctionInvocationContext, ctx_two: FunctionInvocationContext) -> str: + return f"{ctx_one.kwargs}-{ctx_two.kwargs}" + + +def test_ai_function_with_ctx_and_typed_context_parameter_fails(): + """Test that explicit-schema tools reject both implicit ctx and typed context parameters.""" + + class ToolInput(BaseModel): + x: int + + with pytest.raises(ValueError, match="multiple FunctionInvocationContext parameters"): + + @tool(schema=ToolInput) + def invalid_tool(x, ctx, runtime: FunctionInvocationContext) -> str: + return f"{x}-{ctx.kwargs}-{runtime.kwargs}" + + # region _parse_annotation tests diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index 0a7cf50b0b..713c1b4e69 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -124,10 +124,20 @@ def run_durable_agent( """ raise NotImplementedError - def get_new_session(self, agent_name: str, **kwargs: Any) -> DurableAgentSession: + def get_new_session( + self, + agent_name: str, + *, + session_id: str | None = None, + service_session_id: str | None = None, + ) -> DurableAgentSession: """Create a new DurableAgentSession with random session ID.""" - session_id = self._create_session_id(agent_name) - return DurableAgentSession.from_session_id(session_id, **kwargs) + durable_session_id = self._create_session_id(agent_name) + return DurableAgentSession( + durable_session_id=durable_session_id, + session_id=session_id, + service_session_id=service_session_id, + ) def _create_session_id( self, diff --git a/python/packages/durabletask/agent_framework_durabletask/_models.py b/python/packages/durabletask/agent_framework_durabletask/_models.py index 1c5484afbf..c1654c99c0 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_models.py +++ b/python/packages/durabletask/agent_framework_durabletask/_models.py @@ -284,46 +284,47 @@ def __init__( durable_session_id: AgentSessionId | None = None, session_id: str | None = None, service_session_id: str | None = None, - **kwargs: Any, ) -> None: - super().__init__(session_id=session_id, service_session_id=service_session_id, **kwargs) - self._session_id_value: AgentSessionId | None = durable_session_id + super().__init__(session_id=session_id, service_session_id=service_session_id) + self.durable_session_id: AgentSessionId | None = durable_session_id - @property - def durable_session_id(self) -> AgentSessionId | None: - return self._session_id_value - - @durable_session_id.setter - def durable_session_id(self, value: AgentSessionId | None) -> None: - self._session_id_value = value + def to_dict(self) -> dict[str, Any]: + state = super().to_dict() + if self.durable_session_id is not None: + state[self._SERIALIZED_SESSION_ID_KEY] = str(self.durable_session_id) + return state @classmethod def from_session_id( cls, - session_id: AgentSessionId, - **kwargs: Any, + durable_session_id: AgentSessionId, + *, + session_id: str | None = None, + service_session_id: str | None = None, ) -> DurableAgentSession: - return cls(durable_session_id=session_id, **kwargs) - - def to_dict(self) -> dict[str, Any]: - state = super().to_dict() - if self._session_id_value is not None: - state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id_value) - return state + """Create a DurableAgentSession from an AgentSessionId.""" + return cls( + durable_session_id=durable_session_id, + session_id=session_id, + service_session_id=service_session_id, + ) @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentSession: - state_payload = dict(data) - session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None) - session = super().from_dict(state_payload) + """Create a DurableAgentSession from a state dict.""" + session_id_value = data.pop(cls._SERIALIZED_SESSION_ID_KEY, None) + session = super().from_dict(data) + durable_session_id: AgentSessionId | None = None # We need to create a DurableAgentSession from the base AgentSession + if session_id_value is not None: + if not isinstance(session_id_value, str): + raise ValueError("durable_session_id must be a string when present in serialized state") + durable_session_id = AgentSessionId.parse(session_id_value) + durable_session = cls( + durable_session_id=durable_session_id, session_id=session.session_id, service_session_id=session.service_session_id, ) durable_session.state.update(session.state) - if session_id_value is not None: - if not isinstance(session_id_value, str): - raise ValueError("durable_session_id must be a string when present in serialized state") - durable_session._session_id_value = AgentSessionId.parse(session_id_value) return durable_session diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 5693876ad7..09cd6e6875 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -133,16 +133,26 @@ def run( # type: ignore[override] session=session, ) - def create_session(self, **kwargs: Any) -> DurableAgentSession: + def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> DurableAgentSession: """Create a new agent session via the provider.""" - return self._executor.get_new_session(self.name, **kwargs) + return self._executor.get_new_session(self.name) - def get_session(self, **kwargs: Any) -> AgentSession: + def get_session(self, service_session_id: str, *, session_id: str | None = None, **kwargs: Any) -> AgentSession: """Retrieve an existing session via the provider. For durable agents, sessions do not use `service_session_id` so this is not used. """ - return self._executor.get_new_session(self.name, **kwargs) + session = self._executor.get_new_session(self.name) + if service_session_id == session.service_session_id and session_id is None: + return session + + cloned_session = DurableAgentSession( + durable_session_id=session.durable_session_id, + session_id=session_id or session.session_id, + service_session_id=service_session_id, + ) + cloned_session.state.update(session.state) + return cloned_session def _normalize_messages(self, messages: AgentRunInputs | None) -> str: """Convert supported message inputs to a single string. diff --git a/python/packages/durabletask/tests/test_agent_session_id.py b/python/packages/durabletask/tests/test_agent_session_id.py index 571212f145..3902acd22f 100644 --- a/python/packages/durabletask/tests/test_agent_session_id.py +++ b/python/packages/durabletask/tests/test_agent_session_id.py @@ -2,6 +2,8 @@ """Unit tests for AgentSessionId and DurableAgentSession.""" +from typing import Any + import pytest from agent_framework import AgentSession @@ -153,7 +155,7 @@ def test_durable_session_id_setter(self) -> None: def test_from_session_id(self) -> None: """Test creating DurableAgentSession from session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key") - session = DurableAgentSession.from_session_id(session_id) + session = DurableAgentSession(durable_session_id=session_id) assert isinstance(session, DurableAgentSession) assert session.durable_session_id is not None @@ -161,10 +163,10 @@ def test_from_session_id(self) -> None: assert session.durable_session_id.name == "TestAgent" assert session.durable_session_id.key == "test-key" - def test_from_session_id_with_service_session_id(self) -> None: - """Test creating DurableAgentSession with service session ID.""" + def test_init_with_service_session_id(self) -> None: + """Test creating DurableAgentSession with explicit service session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key") - session = DurableAgentSession.from_session_id(session_id, service_session_id="service-123") + session = DurableAgentSession(durable_session_id=session_id, service_session_id="service-123") assert session.durable_session_id is not None assert session.durable_session_id == session_id @@ -192,7 +194,7 @@ def test_to_dict_without_durable_session_id(self) -> None: def test_from_dict_with_durable_session_id(self) -> None: """Test deserialization restores durable session ID.""" - serialized = { + serialized: dict[str, Any] = { "type": "session", "session_id": "session-123", "service_session_id": "service-123", @@ -210,7 +212,7 @@ def test_from_dict_with_durable_session_id(self) -> None: def test_from_dict_without_durable_session_id(self) -> None: """Test deserialization without durable session ID.""" - serialized = { + serialized: dict[str, Any] = { "type": "session", "session_id": "session-456", "service_session_id": "service-456", diff --git a/python/packages/durabletask/tests/test_client.py b/python/packages/durabletask/tests/test_client.py index 0acdfb2f9c..a056d4e254 100644 --- a/python/packages/durabletask/tests/test_client.py +++ b/python/packages/durabletask/tests/test_client.py @@ -88,15 +88,6 @@ def test_client_agent_can_create_sessions(self, agent_client: DurableAIAgentClie assert isinstance(session, DurableAgentSession) - def test_client_agent_session_with_parameters(self, agent_client: DurableAIAgentClient) -> None: - """Verify agent can create sessions with custom parameters.""" - agent = agent_client.get_agent("assistant") - - session = agent.create_session(service_session_id="client-session-123") - - assert isinstance(session, DurableAgentSession) - assert session.service_session_id == "client-session-123" - class TestDurableAIAgentClientPollingConfiguration: """Test polling configuration parameters for DurableAIAgentClient.""" diff --git a/python/packages/durabletask/tests/test_orchestration_context.py b/python/packages/durabletask/tests/test_orchestration_context.py index 033c274c88..9f7cde156c 100644 --- a/python/packages/durabletask/tests/test_orchestration_context.py +++ b/python/packages/durabletask/tests/test_orchestration_context.py @@ -82,17 +82,6 @@ def test_orchestration_agent_can_create_sessions(self, agent_context: DurableAIA assert isinstance(session, DurableAgentSession) - def test_orchestration_agent_session_with_parameters( - self, agent_context: DurableAIAgentOrchestrationContext - ) -> None: - """Verify agent can create sessions with custom parameters.""" - agent = agent_context.get_agent("assistant") - - session = agent.create_session(service_session_id="orch-session-456") - - assert isinstance(session, DurableAgentSession) - assert session.service_session_id == "orch-session-456" - if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py index 423f587871..f713fdded2 100644 --- a/python/packages/durabletask/tests/test_shim.py +++ b/python/packages/durabletask/tests/test_shim.py @@ -184,17 +184,6 @@ def test_create_session_delegates_to_executor(self, test_agent: DurableAIAgent[A mock_executor.get_new_session.assert_called_once_with("test_agent") assert session == mock_session - def test_create_session_forwards_kwargs(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: - """Verify create_session forwards kwargs to executor.""" - mock_session = DurableAgentSession(service_session_id="session-123") - mock_executor.get_new_session.return_value = mock_session - - test_agent.create_session(service_session_id="session-123") - - mock_executor.get_new_session.assert_called_once() - _, kwargs = mock_executor.get_new_session.call_args - assert kwargs["service_session_id"] == "session-123" - class TestDurableAgentProviderInterface: """Test that DurableAgentProvider defines the correct interface.""" diff --git a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py index 16451ae85a..4c1e64cd7c 100644 --- a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py +++ b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py @@ -146,11 +146,11 @@ def __init__( timeout: float | None = None, prepare_model: bool = True, device: DeviceType | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str = "utf-8", - **kwargs: Any, ) -> None: """Initialize a FoundryLocalClient. @@ -169,12 +169,11 @@ def __init__( The device is used to select the appropriate model variant. If not provided, the default device for your system will be used. The values are in the foundry_local.models.DeviceType enum. + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of ChatAndFunctionMiddlewareTypes to apply to requests. function_invocation_configuration: Optional configuration for function invocation support. env_file_path: If provided, the .env settings are read from this file path location. env_file_encoding: The encoding of the .env file, defaults to 'utf-8'. - kwargs: Additional keyword arguments, are passed to the RawOpenAIChatClient. - This can include middleware and additional properties. Examples: @@ -271,8 +270,8 @@ class MyOptions(FoundryLocalChatOptions, total=False): super().__init__( model_id=model_info.id, client=AsyncOpenAI(base_url=manager.endpoint, api_key=manager.api_key), + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self.manager = manager diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 7fa7d0dce4..2b7c266a37 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -303,7 +303,6 @@ def run( stream: Literal[False] = False, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> Awaitable[AgentResponse]: ... @overload @@ -314,7 +313,6 @@ def run( stream: Literal[True], session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... def run( @@ -324,7 +322,6 @@ def run( stream: bool = False, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: """Get a response from the agent. @@ -339,7 +336,6 @@ def run( stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). options: Runtime options (model, timeout, etc.). - kwargs: Additional keyword arguments. Returns: When stream=False: An Awaitable[AgentResponse]. @@ -354,10 +350,10 @@ def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: return AgentResponse.from_updates(updates) return ResponseStream( - self._stream_updates(messages=messages, session=session, options=options, **kwargs), + self._stream_updates(messages=messages, session=session, options=options), finalizer=_finalize, ) - return self._run_impl(messages=messages, session=session, options=options, **kwargs) + return self._run_impl(messages=messages, session=session, options=options) async def _run_impl( self, @@ -365,7 +361,6 @@ async def _run_impl( *, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> AgentResponse: """Non-streaming implementation of run.""" if not self._started: @@ -414,7 +409,6 @@ async def _stream_updates( *, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: """Internal method to stream updates from GitHub Copilot. @@ -424,7 +418,6 @@ async def _stream_updates( Keyword Args: session: The conversation session associated with the message(s). options: Runtime options (model, timeout, etc.). - kwargs: Additional keyword arguments. Yields: AgentResponseUpdate items. diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index e31c1971da..db0666b2d4 100644 --- a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -300,11 +300,11 @@ def __init__( host: str | None = None, client: AsyncClient | None = None, model_id: str | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Ollama Chat client. @@ -313,11 +313,11 @@ def __init__( Can be set via the OLLAMA_HOST env variable. client: An optional Ollama Client instance. If not provided, a new instance will be created. model_id: The Ollama chat model ID to use. Can be set via the OLLAMA_MODEL_ID env variable. + additional_properties: Additional properties stored on the client instance. middleware: Optional middleware to apply to the client. function_invocation_configuration: Optional function invocation configuration override. env_file_path: An optional path to a dotenv (.env) file to load environment variables from. env_file_encoding: The encoding to use when reading the dotenv (.env) file. Defaults to 'utf-8'. - **kwargs: Additional keyword arguments passed to BaseChatClient. """ ollama_settings = load_settings( OllamaSettings, @@ -336,9 +336,9 @@ def __init__( self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self.middleware = list(self.chat_middleware) diff --git a/python/packages/ollama/agent_framework_ollama/_embedding_client.py b/python/packages/ollama/agent_framework_ollama/_embedding_client.py index 5cd35fc9f3..8e0508c708 100644 --- a/python/packages/ollama/agent_framework_ollama/_embedding_client.py +++ b/python/packages/ollama/agent_framework_ollama/_embedding_client.py @@ -92,9 +92,9 @@ def __init__( model_id: str | None = None, host: str | None = None, client: AsyncClient | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a raw Ollama embedding client.""" ollama_settings = load_settings( @@ -110,7 +110,7 @@ def __init__( self.model_id = ollama_settings["embedding_model_id"] # type: ignore[assignment,reportTypedDictNotRequiredAccess] self.client = client or AsyncClient(host=ollama_settings.get("host")) self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] - super().__init__(**kwargs) + super().__init__(additional_properties=additional_properties) def service_url(self) -> str: """Get the URL of the service.""" @@ -214,17 +214,17 @@ def __init__( host: str | None = None, client: AsyncClient | None = None, otel_provider_name: str | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Ollama embedding client.""" super().__init__( model_id=model_id, host=host, client=client, + additional_properties=additional_properties, otel_provider_name=otel_provider_name, env_file_path=env_file_path, env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py index e1a20b6218..be2db098b8 100644 --- a/python/packages/redis/agent_framework_redis/_history_provider.py +++ b/python/packages/redis/agent_framework_redis/_history_provider.py @@ -107,11 +107,18 @@ def _redis_key(self, session_id: str | None) -> str: """Get the Redis key for a given session's messages.""" return f"{self.key_prefix}:{session_id or 'default'}" - async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: + async def get_messages( + self, + session_id: str | None, + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> list[Message]: """Retrieve stored messages for this session from Redis. Args: session_id: The session ID to retrieve messages for. + state: Optional session state. Unused for Redis-backed history. **kwargs: Additional arguments (unused). Returns: @@ -125,12 +132,20 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess messages.append(Message.from_dict(self._deserialize_json(serialized))) # type: ignore[union-attr] return messages - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None: + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: """Persist messages for this session to Redis. Args: session_id: The session ID to store messages for. messages: The messages to persist. + state: Optional session state. Unused for Redis-backed history. **kwargs: Additional arguments (unused). """ if not messages: diff --git a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py index 33748437e0..8199557792 100644 --- a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py +++ b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Awaitable, Callable -from agent_framework import AgentContext, AgentSession +from agent_framework import AgentContext, AgentSession, FunctionInvocationContext, tool from agent_framework.openai import OpenAIResponsesClient from dotenv import load_dotenv @@ -17,7 +17,9 @@ When session propagation is enabled, both agents share the same session object, including session_id and the mutable state dict. This allows correlated -conversation tracking and shared state across the agent hierarchy. +conversation tracking and shared state across the agent hierarchy. The session +must be passed explicitly through ``function_invocation_kwargs`` for the +delegated tool call. The middleware functions below are purely for observability — they are NOT required for session propagation to work. @@ -34,13 +36,37 @@ async def log_session( If propagation is working, both agents will show the same session_id. """ session: AgentSession | None = context.session + if not session: + print("No session found.") + await call_next() + return agent_name = context.agent.name or "unknown" - session_id = session.session_id if session else None - state = dict(session.state) if session else {} - print(f" [{agent_name}] session_id={session_id}, state={state}") + print( + f" [{agent_name}] session_id={session.session_id}, " + f"service_session_id={session.service_session_id} state={session.state}" + ) await call_next() +@tool(description="Use this tool to store the findings so that other agents can reason over them.") +def store_findings(findings: str, ctx: FunctionInvocationContext) -> None: + session = ctx.kwargs.get("session") + current_findings = session.state["findings"] + if current_findings is None: + session.state["findings"] = findings + else: + session.state["finding"] = f"{current_findings}\n{findings}" + + +@tool(description="Use this tool to gather the current findings from other agents.") +def recall_findings(ctx: FunctionInvocationContext) -> str: + session = ctx.kwargs.get("session") + current_findings = session.state["findings"] + if current_findings is None: + return "Nothing yet" + return current_findings + + async def main() -> None: print("=== Agent-as-Tool: Session Propagation ===\n") @@ -50,14 +76,15 @@ async def main() -> None: # The sub-agent has the same log_session middleware to prove it receives the session. research_agent = client.as_agent( name="ResearchAgent", - instructions="You are a research assistant. Provide concise answers.", + instructions="You are a research assistant. Provide concise answers and store your findings.", middleware=[log_session], + tools=[store_findings, recall_findings], ) - # propagate_session=True: the coordinator's session will be forwarded + # propagate_session=True forwards an explicitly supplied runtime session. research_tool = research_agent.as_tool( name="research", - description="Research a topic and return findings", + description="Research a topic and store your findings.", arg_name="query", arg_description="The research query", propagate_session=True, @@ -66,27 +93,29 @@ async def main() -> None: # --- Coordinator agent --- coordinator = client.as_agent( name="CoordinatorAgent", - instructions="You coordinate research. Use the 'research' tool to look up information.", - tools=[research_tool], + instructions="You coordinate research. Use the 'research' tool to start research and then use the recall findings tool to gather up everything. You can also start by storing some of the background directly.", + tools=[research_tool, store_findings, recall_findings], middleware=[log_session], ) # Create a shared session and put some state in it session = coordinator.create_session() - session.state["request_source"] = "demo" + session.state["findings"] = None print(f"Session ID: {session.session_id}") print(f"Session state before run: {session.state}\n") - query = "What are the latest developments in quantum computing?" + query = "What are the latest developments in quantum computing and in AI?" print(f"User: {query}\n") - result = await coordinator.run(query, session=session) + result = await coordinator.run( + query, + session=session, + function_invocation_kwargs={"session": session}, + ) print(f"\nCoordinator: {result}\n") print(f"Session state after run: {session.state}") - print( - "\nIf both agents show the same session_id above, session propagation is working." - ) + print("\nIf both agents show the same session_id above, session propagation is working.") if __name__ == "__main__": diff --git a/python/samples/02-agents/tools/function_tool_with_kwargs.py b/python/samples/02-agents/tools/function_tool_with_kwargs.py index 249ebc4a33..61db84eb17 100644 --- a/python/samples/02-agents/tools/function_tool_with_kwargs.py +++ b/python/samples/02-agents/tools/function_tool_with_kwargs.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from typing import Annotated, Any +from typing import Annotated -from agent_framework import tool +from agent_framework import FunctionInvocationContext, tool from agent_framework.openai import OpenAIResponsesClient from dotenv import load_dotenv from pydantic import Field @@ -14,27 +14,27 @@ """ AI Function with kwargs Example -This example demonstrates how to inject custom keyword arguments (kwargs) into an AI function -from the agent's run method, without exposing them to the AI model. +This example demonstrates how to inject runtime context into an AI function +from the agent's run method, without exposing it to the AI model. This is useful for passing runtime information like access tokens, user IDs, or request-specific context that the tool needs but the model shouldn't know about -or provide. +or provide. The injected context parameter can be typed as +``FunctionInvocationContext`` as shown here, or left untyped as ``ctx`` when you +prefer a lighter-weight sample setup. """ -# Define the function tool with **kwargs to accept injected arguments -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; -# see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. +# Define the function tool with explicit invocation context. +# The context parameter can also be declared as an untyped ``ctx`` parameter. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], - **kwargs: Any, + ctx: FunctionInvocationContext, ) -> str: """Get the weather for a given location.""" - # Extract the injected argument from kwargs - user_id = kwargs.get("user_id", "unknown") + # Extract the injected argument from the explicit context + user_id = ctx.kwargs.get("user_id", "unknown") # Simulate using the user_id for logging or personalization print(f"Getting weather for user: {user_id}") @@ -49,9 +49,11 @@ async def main() -> None: tools=[get_weather], ) - # Pass the injected argument when running the agent - # The 'user_id' kwarg will be passed down to the tool execution via **kwargs - response = await agent.run("What is the weather like in Amsterdam?", user_id="user_123") + # Pass the runtime context explicitly when running the agent. + response = await agent.run( + "What is the weather like in Amsterdam?", + function_invocation_kwargs={"user_id": "user_123"}, + ) print(f"Agent: {response.text}") diff --git a/python/samples/02-agents/tools/function_tool_with_session_injection.py b/python/samples/02-agents/tools/function_tool_with_session_injection.py index 2689ff5f9c..a21a5b82b4 100644 --- a/python/samples/02-agents/tools/function_tool_with_session_injection.py +++ b/python/samples/02-agents/tools/function_tool_with_session_injection.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from typing import Annotated, Any +from typing import Annotated -from agent_framework import AgentSession, tool +from agent_framework import AgentSession, FunctionInvocationContext, tool from agent_framework.openai import OpenAIResponsesClient from dotenv import load_dotenv from pydantic import Field @@ -14,23 +14,24 @@ """ AI Function with Session Injection Example -This example demonstrates the behavior when passing 'session' to agent.run() -and accessing that session in AI function. +This example demonstrates explicitly passing an ``AgentSession`` through +``function_invocation_kwargs`` and reading it from ``FunctionInvocationContext.kwargs``. +The injected context parameter can be typed as ``FunctionInvocationContext`` as +shown here, or left untyped as ``ctx`` when you want the conventional untyped form. """ -# Define the function tool with **kwargs -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; -# see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. +# Define the function tool with explicit invocation context. +# The context parameter can also be declared as an untyped parameter with the name: ``ctx``. @tool(approval_mode="never_require") async def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], - **kwargs: Any, + ctx: FunctionInvocationContext, ) -> str: """Get the weather for a given location.""" - # Get session object from kwargs - session = kwargs.get("session") + # FunctionInvocationContext does not surface agent sessions directly. + # If a tool needs session data, pass it explicitly through function_invocation_kwargs. + session = ctx.kwargs.get("session") if session and isinstance(session, AgentSession) and session.service_session_id: print(f"Session ID: {session.service_session_id}.") @@ -42,18 +43,22 @@ async def main() -> None: name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=[get_weather], - options={"store": True}, + default_options={"store": True}, ) # Create a session session = agent.create_session() - # Run the agent with the session - # Pass session via additional_function_arguments so tools can access it via **kwargs - opts = {"additional_function_arguments": {"session": session}} - print(f"Agent: {await agent.run('What is the weather in London?', session=session, options=opts)}") - print(f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session, options=opts)}") - print(f"Agent: {await agent.run('What cities did I ask about?', session=session)}") + # Pass the session explicitly through function_invocation_kwargs when the tool needs it. + print( + f"Agent: {await agent.run('What is the weather in London?', session=session, function_invocation_kwargs={'session': session})}" + ) + print( + f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session, function_invocation_kwargs={'session': session})}" + ) + print( + f"Agent: {await agent.run('What cities did I ask about?', session=session, function_invocation_kwargs={'session': session})}" + ) if __name__ == "__main__": From a446ed49297738d9e652c52ce37c5cab932db35a Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 09:48:26 +0100 Subject: [PATCH 02/13] clarified docstring --- python/packages/core/agent_framework/_agents.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index efc6fd8e51..f6276ec8f0 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -503,12 +503,10 @@ def as_tool( If None, defaults to "Task for {tool_name}". approval_mode: Whether this delegated tool requires approval before execution. stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True). - propagate_session: If True, the sub-agent's ``run()`` call receives - the ``session`` value from ``FunctionInvocationContext.kwargs`` - when one is supplied explicitly (for example via + propagate_session: If True, this agent's get's a ``session`` from the + calling agents, when one is supplied explicitly (for example via ``function_invocation_kwargs={"session": session}``). Defaults - to False, meaning the sub-agent runs with a new, independent - session. + to False, meaning this agent runs without a session. Returns: A FunctionTool that can be used as a tool by other agents. From eebec5b8ee2fd50c04fd89d7dde6fa2dbdca82f3 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 09:49:55 +0100 Subject: [PATCH 03/13] fix test --- .../packages/core/tests/core/test_agents.py | 42 ------------------- 1 file changed, 42 deletions(-) diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index ec93d6c242..021545e082 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -1067,48 +1067,6 @@ async def capturing_inner( assert len(tool_names) == 3 -async def test_agent_tool_does_not_receive_agent_session_in_kwargs(chat_client_base: Any) -> None: - """Verify agent sessions are not injected into tool kwargs implicitly.""" - - captured: dict[str, Any] = {} - - @tool(name="echo_session_info", approval_mode="never_require") - def echo_session_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnknownParameterType] - session = kwargs.get("session") - captured["has_session"] = session is not None - captured["has_state"] = session.state is not None if isinstance(session, AgentSession) else False - return f"echo: {text}" - - # Make the base client emit a function call for our tool - chat_client_base.run_responses = [ - ChatResponse( - messages=Message( - role="assistant", - contents=[ - Content.from_function_call( - call_id="1", - name="echo_session_info", - arguments='{"text": "hello"}', - ) - ], - ) - ), - ChatResponse(messages=Message(role="assistant", text="done")), - ] - - agent = Agent(client=chat_client_base, tools=[echo_session_info]) - session = agent.create_session() - - result = await agent.run( - "hello", - session=session, - options={"additional_function_arguments": {"session": session}}, - ) - assert result.text == "done" - assert captured.get("has_session") is False - assert captured.get("has_state") is False - - async def test_agent_tool_receives_explicit_session_via_function_invocation_context_kwargs( chat_client_base: Any, ) -> None: From 4133d9d08e76da3ae26c34bd81f920c06e3c22c7 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 10:13:24 +0100 Subject: [PATCH 04/13] feedback fixes --- python/packages/core/agent_framework/_agents.py | 6 +++--- python/packages/core/agent_framework/_tools.py | 7 +++++-- .../agent_framework_durabletask/_shim.py | 17 ++--------------- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index f6276ec8f0..953f205ecf 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -419,16 +419,16 @@ def __init__( def create_session(self, *, session_id: str | None = None) -> AgentSession: """Create a new lightweight session. - This will be used by a agent to hold the persisted session. + This will be used by an agent to hold the persisted session. This depends on the service used, in some cases, or with store=True - this will add the `service_session_id` based on the response, + this will add the ``service_session_id`` based on the response, which is then fed back to the API on the next call. In other cases, if there is a HistoryProvider setup in the agent, that is used and it can store state in the session. If there is no HistoryProvider and store=False or the default of a service is False. - Then a ``InMemoryHistoryProvider`` is added to the agent and used with the session automatically. + Then a ``InMemoryHistoryProvider`` instance is added to the agent and used with the session automatically. The ``InMemoryHistoryProvider`` stores the messages as `state` in the session by default. Keyword Args: diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 67015cf061..31def22f1c 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -590,8 +590,11 @@ async def invoke( # that still declare ``**kwargs``. New tools should consume runtime data via ``ctx``. legacy_runtime_kwargs = dict(runtime_kwargs) if self._forward_runtime_kwargs and legacy_runtime_kwargs: - call_kwargs.update(legacy_runtime_kwargs) - observable_kwargs.update(legacy_runtime_kwargs) + for key, value in legacy_runtime_kwargs.items(): + if key not in call_kwargs: + call_kwargs[key] = value + if key not in observable_kwargs: + observable_kwargs[key] = value if self._context_parameter_name is not None and effective_context is not None: call_kwargs[self._context_parameter_name] = effective_context diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 09cd6e6875..077aae8b28 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -138,21 +138,8 @@ def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> Dur return self._executor.get_new_session(self.name) def get_session(self, service_session_id: str, *, session_id: str | None = None, **kwargs: Any) -> AgentSession: - """Retrieve an existing session via the provider. - - For durable agents, sessions do not use `service_session_id` so this is not used. - """ - session = self._executor.get_new_session(self.name) - if service_session_id == session.service_session_id and session_id is None: - return session - - cloned_session = DurableAgentSession( - durable_session_id=session.durable_session_id, - session_id=session_id or session.session_id, - service_session_id=service_session_id, - ) - cloned_session.state.update(session.state) - return cloned_session + """Retrieve an existing session via the provider.""" + return self._executor.get_new_session(self.name, service_session_id=service_session_id, session_id=session_id) def _normalize_messages(self, messages: AgentRunInputs | None) -> str: """Convert supported message inputs to a single string. From 280278d83d6572118f439a2418a7e60f7d116e00 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 11:03:48 +0100 Subject: [PATCH 05/13] Add unit tests for _docstrings.py build/apply helpers Tests cover: no docstring source, no extra kwargs, appending to existing Keyword Args section, inserting after Args, inserting in plain docstrings, multiline descriptions, ordering, and apply_layered_docstring. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../core/tests/core/test_docstrings.py | 175 ++++++++++++++++++ .../agent_framework_durabletask/_models.py | 1 + .../agent_as_tool_with_session_propagation.py | 2 +- 3 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 python/packages/core/tests/core/test_docstrings.py diff --git a/python/packages/core/tests/core/test_docstrings.py b/python/packages/core/tests/core/test_docstrings.py new file mode 100644 index 0000000000..ab4b116422 --- /dev/null +++ b/python/packages/core/tests/core/test_docstrings.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft. All rights reserved. + +from agent_framework._docstrings import apply_layered_docstring, build_layered_docstring + +# -- Helpers: stub functions with various docstring shapes -- + + +def _source_with_full_docstring(x: int) -> int: + """Do something useful. + + Args: + x: The input value. + + Keyword Args: + timeout: Max seconds to wait. + + Returns: + The computed result. + """ + return x + + +def _source_with_args_only(x: int) -> int: + """Do something useful. + + Args: + x: The input value. + + Returns: + The computed result. + """ + return x + + +def _source_no_sections() -> None: + """A plain summary with no Google-style sections.""" + + +def _source_no_docstring() -> None: + pass + + +def _target_stub() -> None: + pass + + +# -- build_layered_docstring tests -- + + +def test_build_returns_none_when_source_has_no_docstring() -> None: + result = build_layered_docstring(_source_no_docstring) + assert result is None + + +def test_build_returns_original_when_no_extra_kwargs() -> None: + result = build_layered_docstring(_source_with_full_docstring) + assert result is not None + assert "Do something useful." in result + assert "Keyword Args:" in result + + +def test_build_returns_original_when_extra_kwargs_empty() -> None: + result = build_layered_docstring(_source_with_full_docstring, extra_keyword_args={}) + assert result is not None + assert result == build_layered_docstring(_source_with_full_docstring) + + +def test_build_appends_to_existing_keyword_args_section() -> None: + result = build_layered_docstring( + _source_with_full_docstring, + extra_keyword_args={"retries": "Number of retries."}, + ) + assert result is not None + assert "timeout: Max seconds to wait." in result + assert "retries: Number of retries." in result + # Both should be under Keyword Args + lines = result.splitlines() + kw_index = next(i for i, line in enumerate(lines) if line == "Keyword Args:") + ret_index = next(i for i, line in enumerate(lines) if line == "Returns:") + retries_index = next(i for i, line in enumerate(lines) if "retries:" in line) + assert kw_index < retries_index < ret_index + + +def test_build_inserts_keyword_args_after_args_section() -> None: + result = build_layered_docstring( + _source_with_args_only, + extra_keyword_args={"verbose": "Enable verbose output."}, + ) + assert result is not None + assert "Keyword Args:" in result + assert "verbose: Enable verbose output." in result + lines = result.splitlines() + args_index = next(i for i, line in enumerate(lines) if line == "Args:") + kw_index = next(i for i, line in enumerate(lines) if line == "Keyword Args:") + ret_index = next(i for i, line in enumerate(lines) if line == "Returns:") + assert args_index < kw_index < ret_index + + +def test_build_inserts_keyword_args_in_docstring_with_no_sections() -> None: + result = build_layered_docstring( + _source_no_sections, + extra_keyword_args={"debug": "Enable debug mode."}, + ) + assert result is not None + assert "A plain summary" in result + assert "Keyword Args:" in result + assert "debug: Enable debug mode." in result + + +def test_build_handles_multiline_descriptions() -> None: + result = build_layered_docstring( + _source_with_args_only, + extra_keyword_args={ + "config": "The configuration object.\nMust be a valid mapping.\nDefaults to empty.", + }, + ) + assert result is not None + lines = result.splitlines() + config_line = next(line for line in lines if "config:" in line) + assert "The configuration object." in config_line + # Continuation lines should be indented + config_idx = lines.index(config_line) + assert "Must be a valid mapping." in lines[config_idx + 1] + assert "Defaults to empty." in lines[config_idx + 2] + + +def test_build_preserves_multiple_extra_kwargs_order() -> None: + result = build_layered_docstring( + _source_with_args_only, + extra_keyword_args={ + "alpha": "First.", + "beta": "Second.", + "gamma": "Third.", + }, + ) + assert result is not None + lines = result.splitlines() + alpha_idx = next(i for i, line in enumerate(lines) if "alpha:" in line) + beta_idx = next(i for i, line in enumerate(lines) if "beta:" in line) + gamma_idx = next(i for i, line in enumerate(lines) if "gamma:" in line) + assert alpha_idx < beta_idx < gamma_idx + + +# -- apply_layered_docstring tests -- + + +def test_apply_sets_docstring_on_target() -> None: + def target() -> None: + pass + + apply_layered_docstring(target, _source_with_full_docstring) + assert target.__doc__ is not None + assert "Do something useful." in target.__doc__ + + +def test_apply_with_extra_kwargs() -> None: + def target() -> None: + pass + + apply_layered_docstring( + target, + _source_with_args_only, + extra_keyword_args={"flag": "A boolean flag."}, + ) + assert target.__doc__ is not None + assert "flag: A boolean flag." in target.__doc__ + assert "Keyword Args:" in target.__doc__ + + +def test_apply_sets_none_when_source_has_no_docstring() -> None: + def target() -> None: + """Original.""" + + apply_layered_docstring(target, _source_no_docstring) + assert target.__doc__ is None diff --git a/python/packages/durabletask/agent_framework_durabletask/_models.py b/python/packages/durabletask/agent_framework_durabletask/_models.py index c1654c99c0..19d5804bc2 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_models.py +++ b/python/packages/durabletask/agent_framework_durabletask/_models.py @@ -312,6 +312,7 @@ def from_session_id( @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentSession: """Create a DurableAgentSession from a state dict.""" + data = dict(data) # defensive copy — avoid mutating caller's dict session_id_value = data.pop(cls._SERIALIZED_SESSION_ID_KEY, None) session = super().from_dict(data) durable_session_id: AgentSessionId | None = None diff --git a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py index 8199557792..30763a30f4 100644 --- a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py +++ b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py @@ -55,7 +55,7 @@ def store_findings(findings: str, ctx: FunctionInvocationContext) -> None: if current_findings is None: session.state["findings"] = findings else: - session.state["finding"] = f"{current_findings}\n{findings}" + session.state["findings"] = f"{current_findings}\n{findings}" @tool(description="Use this tool to gather the current findings from other agents.") From fb76cc7059a02b5079295956d1b9cb6bc7d99e6b Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 11:05:57 +0100 Subject: [PATCH 06/13] Add test for propagate_session TypeError on non-AgentSession values Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../packages/core/tests/core/test_agents.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 021545e082..7491de4b72 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -969,6 +969,34 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: assert parent_session.state["counter"] == 1 +async def test_chat_agent_as_tool_propagate_session_rejects_non_agent_session( + client: SupportsChatGetResponse, +) -> None: + """Test that propagate_session=True raises TypeError for non-AgentSession values.""" + agent = Agent(client=client, name="SubAgent", description="Sub agent") + tool = agent.as_tool(propagate_session=True) + + # A plain dict is truthy but not an AgentSession — should raise TypeError. + with raises(TypeError, match="not a AgentSession object"): + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + kwargs={"session": {"fake": "session"}}, + ) + ) + + # A string is also truthy and not an AgentSession. + with raises(TypeError, match="not a AgentSession object"): + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + kwargs={"session": "not-a-session"}, + ) + ) + + async def test_chat_agent_as_mcp_server_basic(client: SupportsChatGetResponse) -> None: """Test basic as_mcp_server functionality.""" agent = Agent(client=client, name="TestAgent", description="Test agent for MCP") From c8a86bc48b1737542cb3fa131483bb3ffe9b3f4e Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 11:16:57 +0100 Subject: [PATCH 07/13] Add tests for multi-content and empty UserInputRequiredException propagation Cover the branching logic in _try_execute_function_calls for: - Multiple user_input_request items in a single exception (extra_user_input_contents path) - Empty contents list (fallback function_result path) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../core/test_function_invocation_logic.py | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 50df7ccea1..6a8ba2eade 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -3553,3 +3553,93 @@ def delegate_tool(task: str) -> str: assert user_requests[0].type == "oauth_consent_request" assert user_requests[0].consent_link == "https://login.microsoftonline.com/consent" assert user_requests[0].user_input_request is True + + +async def test_user_input_request_multiple_contents_propagate(chat_client_base: SupportsChatGetResponse): + """Test that multiple user_input_request items in a single exception all propagate to the parent response.""" + from agent_framework.exceptions import UserInputRequiredException + + @tool(name="multi_request_tool", approval_mode="never_require") + def multi_request(task: str) -> str: + del task + raise UserInputRequiredException( + contents=[ + Content.from_oauth_consent_request( + consent_link="https://example.com/consent1", + ), + Content.from_oauth_consent_request( + consent_link="https://example.com/consent2", + ), + Content.from_oauth_consent_request( + consent_link="https://example.com/consent3", + ), + ] + ) + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="multi_request_tool", arguments='{"task": "do it"}'), + ], + ) + ) + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="do something")], + options={"tool_choice": "auto", "tools": [multi_request]}, + ) + + user_requests = [ + content + for msg in response.messages + for content in msg.contents + if isinstance(content, Content) and content.user_input_request + ] + assert len(user_requests) == 3 + consent_links = {r.consent_link for r in user_requests} + assert consent_links == { + "https://example.com/consent1", + "https://example.com/consent2", + "https://example.com/consent3", + } + + +async def test_user_input_request_empty_contents_returns_fallback(chat_client_base: SupportsChatGetResponse): + """Test that UserInputRequiredException with empty contents produces a fallback function_result.""" + from agent_framework.exceptions import UserInputRequiredException + + @tool(name="empty_request_tool", approval_mode="never_require") + def empty_request(task: str) -> str: + del task + raise UserInputRequiredException(contents=[]) + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="empty_request_tool", arguments='{"task": "do it"}'), + ], + ) + ), + ChatResponse(messages=Message(role="assistant", text="handled")), + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="do something")], + options={"tool_choice": "auto", "tools": [empty_request]}, + ) + + # With empty contents, the handler returns a function_result with an error message + # and the loop continues to the next chat response. + function_results = [ + content + for msg in response.messages + for content in msg.contents + if content.type == "function_result" + ] + assert len(function_results) >= 1 + assert any("user input" in (fr.result or "").lower() for fr in function_results) From 5de7d656bfd8a1080f029d1797b193f90688ca31 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 11:17:51 +0100 Subject: [PATCH 08/13] Add tests for DurableAIAgent.get_session forwarding service_session_id Verifies get_session correctly forwards service_session_id and session_id to the executor's get_new_session, replacing the removed kwargs test. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../packages/durabletask/tests/test_shim.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py index f713fdded2..cfbb447e87 100644 --- a/python/packages/durabletask/tests/test_shim.py +++ b/python/packages/durabletask/tests/test_shim.py @@ -184,6 +184,34 @@ def test_create_session_delegates_to_executor(self, test_agent: DurableAIAgent[A mock_executor.get_new_session.assert_called_once_with("test_agent") assert session == mock_session + def test_get_session_forwards_service_session_id( + self, test_agent: DurableAIAgent[Any], mock_executor: Mock + ) -> None: + """Verify get_session forwards service_session_id and session_id to executor.""" + mock_session = DurableAgentSession(service_session_id="svc-123") + mock_executor.get_new_session.return_value = mock_session + + session = test_agent.get_session("svc-123", session_id="local-456") + + mock_executor.get_new_session.assert_called_once_with( + "test_agent", service_session_id="svc-123", session_id="local-456" + ) + assert session.service_session_id == "svc-123" + + def test_get_session_without_session_id( + self, test_agent: DurableAIAgent[Any], mock_executor: Mock + ) -> None: + """Verify get_session works with only service_session_id (session_id defaults to None).""" + mock_session = DurableAgentSession(service_session_id="svc-789") + mock_executor.get_new_session.return_value = mock_session + + session = test_agent.get_session("svc-789") + + mock_executor.get_new_session.assert_called_once_with( + "test_agent", service_session_id="svc-789", session_id=None + ) + assert session.service_session_id == "svc-789" + class TestDurableAgentProviderInterface: """Test that DurableAgentProvider defines the correct interface.""" From 83bd4bd2a6e6803c8f0a38f594764c78772d1e8c Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 11:19:10 +0100 Subject: [PATCH 09/13] Simplify ag-ui test stub to read session from client_kwargs only Remove dual-mode detection (client_kwargs vs raw kwargs fallback) from the test mock. Session is now read exclusively from client_kwargs, matching the settled public calling convention. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- python/packages/ag-ui/tests/ag_ui/conftest.py | 8 ++++---- .../durabletask/agent_framework_durabletask/_shim.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index 3e839358da..42a6967371 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -98,11 +98,11 @@ def get_response( options: OptionsCoT | ChatOptions[Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - compatibility_client_kwargs = kwargs.get("client_kwargs") - if isinstance(compatibility_client_kwargs, Mapping): - self.last_session = cast(AgentSession | None, compatibility_client_kwargs.get("session")) + client_kwargs = kwargs.get("client_kwargs") + if isinstance(client_kwargs, Mapping): + self.last_session = cast(AgentSession | None, client_kwargs.get("session")) else: - self.last_session = cast(AgentSession | None, kwargs.get("session")) + self.last_session = None self.last_service_session_id = self.last_session.service_session_id if self.last_session else None return cast( Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]], diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 077aae8b28..78e4224207 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -135,7 +135,7 @@ def run( # type: ignore[override] def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> DurableAgentSession: """Create a new agent session via the provider.""" - return self._executor.get_new_session(self.name) + return self._executor.get_new_session(self.name, **kwargs) def get_session(self, service_session_id: str, *, session_id: str | None = None, **kwargs: Any) -> AgentSession: """Retrieve an existing session via the provider.""" From 0fe9cffe3747554367492aac520ca38c245ea552 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 11:23:43 +0100 Subject: [PATCH 10/13] updated create and get sessions in durable --- .../durabletask/agent_framework_durabletask/_shim.py | 6 +++--- python/packages/durabletask/tests/test_shim.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 78e4224207..b21cac6831 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -133,11 +133,11 @@ def run( # type: ignore[override] session=session, ) - def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> DurableAgentSession: + def create_session(self, *, session_id: str | None = None) -> DurableAgentSession: """Create a new agent session via the provider.""" - return self._executor.get_new_session(self.name, **kwargs) + return self._executor.get_new_session(self.name) - def get_session(self, service_session_id: str, *, session_id: str | None = None, **kwargs: Any) -> AgentSession: + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: """Retrieve an existing session via the provider.""" return self._executor.get_new_session(self.name, service_session_id=service_session_id, session_id=session_id) diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py index cfbb447e87..687a0746a7 100644 --- a/python/packages/durabletask/tests/test_shim.py +++ b/python/packages/durabletask/tests/test_shim.py @@ -198,9 +198,7 @@ def test_get_session_forwards_service_session_id( ) assert session.service_session_id == "svc-123" - def test_get_session_without_session_id( - self, test_agent: DurableAIAgent[Any], mock_executor: Mock - ) -> None: + def test_get_session_without_session_id(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: """Verify get_session works with only service_session_id (session_id defaults to None).""" mock_session = DurableAgentSession(service_session_id="svc-789") mock_executor.get_new_session.return_value = mock_session From 6c498bc278b065b707b9fc70b83f40a1513ca042 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 11:28:35 +0100 Subject: [PATCH 11/13] fixed docstrings --- python/packages/core/agent_framework/_agents.py | 5 +++-- .../core/tests/core/test_function_invocation_logic.py | 5 +---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 953f205ecf..619e097803 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -503,7 +503,7 @@ def as_tool( If None, defaults to "Task for {tool_name}". approval_mode: Whether this delegated tool requires approval before execution. stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True). - propagate_session: If True, this agent's get's a ``session`` from the + propagate_session: If True, this agent gets a ``session`` object from the calling agents, when one is supplied explicitly (for example via ``function_invocation_kwargs={"session": session}``). Defaults to False, meaning this agent runs without a session. @@ -563,7 +563,8 @@ async def _agent_wrapper(ctx: FunctionInvocationContext, **kwargs: Any) -> str: session = ctx.kwargs.get("session") if session and not isinstance(session, AgentSession): raise TypeError( - "The provided session is not a AgentSession object, please make sure to " + "The provided session is not an ``AgentSession`` object, " + f"got {type(session).__name__!r}, please make sure to " "pass it through the function_invocation_kwargs." ) stream = self.run( diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 6a8ba2eade..9eed4b9299 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -3636,10 +3636,7 @@ def empty_request(task: str) -> str: # With empty contents, the handler returns a function_result with an error message # and the loop continues to the next chat response. function_results = [ - content - for msg in response.messages - for content in msg.contents - if content.type == "function_result" + content for msg in response.messages for content in msg.contents if content.type == "function_result" ] assert len(function_results) >= 1 assert any("user input" in (fr.result or "").lower() for fr in function_results) From 9d72df5aa7cf3b0eef23d695001e0e40026844c8 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 11:31:55 +0100 Subject: [PATCH 12/13] fix test --- python/packages/core/tests/core/test_agents.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 7491de4b72..63674e88ea 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -977,7 +977,7 @@ async def test_chat_agent_as_tool_propagate_session_rejects_non_agent_session( tool = agent.as_tool(propagate_session=True) # A plain dict is truthy but not an AgentSession — should raise TypeError. - with raises(TypeError, match="not a AgentSession object"): + with raises(TypeError, match="The provided session is not an ``AgentSession`` object"): await tool.invoke( context=FunctionInvocationContext( function=tool, @@ -987,7 +987,7 @@ async def test_chat_agent_as_tool_propagate_session_rejects_non_agent_session( ) # A string is also truthy and not an AgentSession. - with raises(TypeError, match="not a AgentSession object"): + with raises(TypeError, match="The provided session is not an ``AgentSession`` object"): await tool.invoke( context=FunctionInvocationContext( function=tool, From 69d40d7ca7fe2a1cba523c7478d65b0895ad57ad Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 11 Mar 2026 13:25:20 +0100 Subject: [PATCH 13/13] updated session handling --- .../packages/core/agent_framework/_agents.py | 23 +++----- .../core/agent_framework/_middleware.py | 4 ++ .../packages/core/agent_framework/_tools.py | 21 ++++++- .../packages/core/tests/core/test_agents.py | 56 ++++--------------- .../agent_as_tool_with_session_propagation.py | 44 +++++---------- .../function_tool_with_session_injection.py | 21 +++---- 6 files changed, 65 insertions(+), 104 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 619e097803..e7c7b71b80 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -503,10 +503,9 @@ def as_tool( If None, defaults to "Task for {tool_name}". approval_mode: Whether this delegated tool requires approval before execution. stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True). - propagate_session: If True, this agent gets a ``session`` object from the - calling agents, when one is supplied explicitly (for example via - ``function_invocation_kwargs={"session": session}``). Defaults - to False, meaning this agent runs without a session. + propagate_session: If True, the parent agent's session is forwarded + to this sub-agent's ``run()`` call so both agents share the + same session. Defaults to False. Returns: A FunctionTool that can be used as a tool by other agents. @@ -522,8 +521,7 @@ def as_tool( # Convert the agent to a tool (independent session) research_tool = agent.as_tool() - # Convert the agent to a tool (shared session when the caller - # passes ``function_invocation_kwargs={"session": session}``) + # Convert the agent to a tool (shared session with parent) research_tool = agent.as_tool(propagate_session=True) # Use the tool with another agent @@ -558,19 +556,10 @@ async def _agent_wrapper(ctx: FunctionInvocationContext, **kwargs: Any) -> str: ctx: the function invocation context used **kwargs: only used to dynamically load the argument that is defined for this tool. """ - session = None - if propagate_session: - session = ctx.kwargs.get("session") - if session and not isinstance(session, AgentSession): - raise TypeError( - "The provided session is not an ``AgentSession`` object, " - f"got {type(session).__name__!r}, please make sure to " - "pass it through the function_invocation_kwargs." - ) stream = self.run( str(kwargs.get(arg_name, "")), stream=True, - session=session, + session=ctx.session if propagate_session else None, function_invocation_kwargs=dict(ctx.kwargs), ) if stream_callback is not None: @@ -1197,6 +1186,8 @@ async def _prepare_run_context( **dict(legacy_kwargs), **(dict(client_kwargs) if client_kwargs is not None else {}), } + if active_session is not None: + effective_client_kwargs["session"] = active_session return { "session": active_session, diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index f1f3b234d0..f475fac7ca 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -197,6 +197,7 @@ class FunctionInvocationContext: Attributes: function: The function being invoked. arguments: The validated arguments for the function. + session: The agent session for this invocation, if any. metadata: Metadata dictionary for sharing data between function middleware. result: Function execution result. Can be observed after calling ``call_next()`` to see the actual execution result or can be set to override the execution result. @@ -225,6 +226,7 @@ def __init__( self, function: FunctionTool, arguments: BaseModel | Mapping[str, Any], + session: AgentSession | None = None, metadata: Mapping[str, Any] | None = None, result: Any = None, kwargs: Mapping[str, Any] | None = None, @@ -234,12 +236,14 @@ def __init__( Args: function: The function being invoked. arguments: The validated arguments for the function. + session: The agent session for this invocation, if any. metadata: Metadata dictionary for sharing data between function middleware. result: Function execution result. kwargs: Additional runtime keyword arguments forwarded to the function invocation. """ self.function = function self.arguments = arguments + self.session = session self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 31def22f1c..505580c207 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -63,6 +63,7 @@ from ._clients import SupportsChatGetResponse from ._mcp import MCPTool from ._middleware import FunctionInvocationContext, FunctionMiddlewarePipeline, FunctionMiddlewareTypes + from ._sessions import AgentSession from ._types import ( ChatOptions, ChatResponse, @@ -1224,9 +1225,10 @@ async def _auto_invoke_function( *, config: FunctionInvocationConfiguration, tool_map: dict[str, FunctionTool], + invocation_session: AgentSession | None = None, sequence_index: int | None = None, request_index: int | None = None, - middleware_pipeline: FunctionMiddlewarePipeline | None = None, # Optional MiddlewarePipeline + middleware_pipeline: FunctionMiddlewarePipeline | None = None, ) -> Content: """Invoke a function call requested by the agent, applying middleware that is defined. @@ -1237,6 +1239,7 @@ async def _auto_invoke_function( Keyword Args: config: The function invocation configuration. tool_map: A mapping of tool names to FunctionTool instances. + invocation_session: The agent session for this invocation, if any. sequence_index: The index of the function call in the sequence. request_index: The index of the request iteration. middleware_pipeline: Optional middleware pipeline to apply during execution. @@ -1288,6 +1291,8 @@ async def _auto_invoke_function( for key, value in (custom_args or {}).items() if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"} } + if invocation_session is not None: + runtime_kwargs["session"] = invocation_session try: if not cast(bool, getattr(tool, "_schema_supplied", False)) and tool.input_model is not None: args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True) @@ -1319,6 +1324,7 @@ async def _auto_invoke_function( direct_context = FunctionInvocationContext( function=tool, arguments=args, + session=invocation_session, kwargs=runtime_kwargs.copy(), ) function_result = await tool.invoke( @@ -1347,6 +1353,7 @@ async def _auto_invoke_function( middleware_context = FunctionInvocationContext( function=tool, arguments=args, + session=invocation_session, kwargs=runtime_kwargs.copy(), ) @@ -1407,7 +1414,8 @@ async def _try_execute_function_calls( function_calls: Sequence[Content], tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]], config: FunctionInvocationConfiguration, - middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports + invocation_session: AgentSession | None = None, + middleware_pipeline: Any = None, ) -> tuple[Sequence[Content], bool]: """Execute multiple function calls concurrently. @@ -1417,6 +1425,7 @@ async def _try_execute_function_calls( function_calls: A sequence of FunctionCallContent to execute. tools: The tools available for execution. config: Configuration for function invocation. + invocation_session: The agent session for this invocation, if any. middleware_pipeline: Optional middleware pipeline to apply during execution. Returns: @@ -1498,6 +1507,7 @@ async def invoke_with_termination_handling( function_call_content=function_call, # type: ignore[arg-type] custom_args=custom_args, tool_map=tool_map, + invocation_session=invocation_session, sequence_index=seq_idx, request_index=attempt_idx, middleware_pipeline=middleware_pipeline, @@ -1554,6 +1564,7 @@ async def _execute_function_calls( function_calls: list[Content], tool_options: dict[str, Any] | None, config: FunctionInvocationConfiguration, + invocation_session: AgentSession | None = None, middleware_pipeline: Any = None, ) -> tuple[list[Content], bool, bool]: tools = _extract_tools(tool_options) @@ -1564,6 +1575,7 @@ async def _execute_function_calls( attempt_idx=attempt_idx, function_calls=function_calls, tools=tools, # type: ignore + invocation_session=invocation_session, middleware_pipeline=middleware_pipeline, config=config, ) @@ -2017,10 +2029,15 @@ def get_response( ) if options and (additional_opts := options.get("additional_function_arguments")): # type: ignore[attr-defined] additional_function_arguments.update(cast(Mapping[str, Any], additional_opts)) + from ._sessions import AgentSession as _AgentSession + + raw_session = effective_client_kwargs.get("session") + invocation_session = raw_session if isinstance(raw_session, _AgentSession) else None execute_function_calls = partial( _execute_function_calls, custom_args=additional_function_arguments, config=self.function_invocation_configuration, + invocation_session=invocation_session, middleware_pipeline=function_middleware_pipeline, ) filtered_kwargs = {k: v for k, v in {**effective_client_kwargs, **kwargs}.items() if k != "session"} diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 63674e88ea..352f562417 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -276,7 +276,7 @@ async def test_prepare_run_context_separates_function_invocation_kwargs_from_cha assert ctx["function_invocation_kwargs"]["runtime_key"] == "runtime-value" assert "session" not in ctx["function_invocation_kwargs"] assert ctx["client_kwargs"]["client_key"] == "client-value" - assert "session" not in ctx["client_kwargs"] + assert ctx["client_kwargs"]["session"] is session async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChatGetResponse) -> None: @@ -877,14 +877,13 @@ async def test_chat_agent_as_tool_name_sanitization( async def test_chat_agent_as_tool_propagate_session_true(client: SupportsChatGetResponse) -> None: - """Test that propagate_session=True forwards an explicitly provided session to the sub-agent.""" + """Test that propagate_session=True forwards the session to the sub-agent.""" agent = Agent(client=client, name="SubAgent", description="Sub agent") tool = agent.as_tool(propagate_session=True) parent_session = AgentSession(session_id="parent-session-123") parent_session.state["shared_key"] = "shared_value" - # Spy on the agent's run method to capture the session argument original_run = agent.run captured_session = None @@ -899,7 +898,7 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: context=FunctionInvocationContext( function=tool, arguments={"task": "Hello"}, - kwargs={"session": parent_session}, + session=parent_session, ) ) @@ -909,7 +908,7 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: async def test_chat_agent_as_tool_propagate_session_false_by_default(client: SupportsChatGetResponse) -> None: - """Test that propagate_session defaults to False and does not forward runtime sessions.""" + """Test that propagate_session defaults to False and does not forward the session.""" agent = Agent(client=client, name="SubAgent", description="Sub agent") tool = agent.as_tool() # default: propagate_session=False @@ -929,7 +928,7 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: context=FunctionInvocationContext( function=tool, arguments={"task": "Hello"}, - kwargs={"session": parent_session}, + session=parent_session, ) ) @@ -937,14 +936,13 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: async def test_chat_agent_as_tool_propagate_session_shares_state(client: SupportsChatGetResponse) -> None: - """Test that an explicitly propagated session allows the sub-agent to read and write parent state.""" + """Test that a propagated session allows the sub-agent to read and write parent state.""" agent = Agent(client=client, name="SubAgent", description="Sub agent") tool = agent.as_tool(propagate_session=True) parent_session = AgentSession(session_id="shared-session") parent_session.state["counter"] = 0 - # The sub-agent receives the same session object, so mutations are shared original_run = agent.run captured_session = None @@ -961,42 +959,13 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: context=FunctionInvocationContext( function=tool, arguments={"task": "Hello"}, - kwargs={"session": parent_session}, + session=parent_session, ) ) - # The parent's state should reflect the sub-agent's mutation assert parent_session.state["counter"] == 1 -async def test_chat_agent_as_tool_propagate_session_rejects_non_agent_session( - client: SupportsChatGetResponse, -) -> None: - """Test that propagate_session=True raises TypeError for non-AgentSession values.""" - agent = Agent(client=client, name="SubAgent", description="Sub agent") - tool = agent.as_tool(propagate_session=True) - - # A plain dict is truthy but not an AgentSession — should raise TypeError. - with raises(TypeError, match="The provided session is not an ``AgentSession`` object"): - await tool.invoke( - context=FunctionInvocationContext( - function=tool, - arguments={"task": "Hello"}, - kwargs={"session": {"fake": "session"}}, - ) - ) - - # A string is also truthy and not an AgentSession. - with raises(TypeError, match="The provided session is not an ``AgentSession`` object"): - await tool.invoke( - context=FunctionInvocationContext( - function=tool, - arguments={"task": "Hello"}, - kwargs={"session": "not-a-session"}, - ) - ) - - async def test_chat_agent_as_mcp_server_basic(client: SupportsChatGetResponse) -> None: """Test basic as_mcp_server functionality.""" agent = Agent(client=client, name="TestAgent", description="Test agent for MCP") @@ -1095,18 +1064,17 @@ async def capturing_inner( assert len(tool_names) == 3 -async def test_agent_tool_receives_explicit_session_via_function_invocation_context_kwargs( +async def test_agent_tool_receives_session_via_function_invocation_context( chat_client_base: Any, ) -> None: - """Verify ctx-based tools read explicit sessions from FunctionInvocationContext.kwargs.""" + """Verify ctx-based tools receive the session via FunctionInvocationContext.session.""" captured: dict[str, Any] = {} @tool(name="capture_session_context", approval_mode="never_require") def capture_session_context(text: str, ctx: FunctionInvocationContext) -> str: - session = ctx.kwargs.get("session") - captured["session"] = session - captured["has_state"] = session.state is not None if isinstance(session, AgentSession) else False + captured["session"] = ctx.session + captured["has_state"] = ctx.session.state is not None if isinstance(ctx.session, AgentSession) else False return f"echo: {text}" chat_client_base.run_responses = [ @@ -1128,7 +1096,7 @@ def capture_session_context(text: str, ctx: FunctionInvocationContext) -> str: agent = Agent(client=chat_client_base, tools=[capture_session_context]) session = agent.create_session() - result = await agent.run("hello", session=session, function_invocation_kwargs={"session": session}) + result = await agent.run("hello", session=session) assert result.text == "done" assert captured["session"] is session diff --git a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py index 30763a30f4..fa78a9ede5 100644 --- a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py +++ b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py @@ -17,12 +17,7 @@ When session propagation is enabled, both agents share the same session object, including session_id and the mutable state dict. This allows correlated -conversation tracking and shared state across the agent hierarchy. The session -must be passed explicitly through ``function_invocation_kwargs`` for the -delegated tool call. - -The middleware functions below are purely for observability — they are NOT -required for session propagation to work. +conversation tracking and shared state across the agent hierarchy. """ @@ -30,11 +25,7 @@ async def log_session( context: AgentContext, call_next: Callable[[], Awaitable[None]], ) -> None: - """Agent middleware that logs the session received by each agent. - - NOT required for session propagation — only used to observe the flow. - If propagation is working, both agents will show the same session_id. - """ + """Agent middleware that logs the session received by each agent.""" session: AgentSession | None = context.session if not session: print("No session found.") @@ -50,18 +41,20 @@ async def log_session( @tool(description="Use this tool to store the findings so that other agents can reason over them.") def store_findings(findings: str, ctx: FunctionInvocationContext) -> None: - session = ctx.kwargs.get("session") - current_findings = session.state["findings"] + if ctx.session is None: + return + current_findings = ctx.session.state.get("findings") if current_findings is None: - session.state["findings"] = findings + ctx.session.state["findings"] = findings else: - session.state["findings"] = f"{current_findings}\n{findings}" + ctx.session.state["findings"] = f"{current_findings}\n{findings}" @tool(description="Use this tool to gather the current findings from other agents.") def recall_findings(ctx: FunctionInvocationContext) -> str: - session = ctx.kwargs.get("session") - current_findings = session.state["findings"] + if ctx.session is None: + return "No session available" + current_findings = ctx.session.state.get("findings") if current_findings is None: return "Nothing yet" return current_findings @@ -72,8 +65,6 @@ async def main() -> None: client = OpenAIResponsesClient() - # --- Sub-agent: a research specialist --- - # The sub-agent has the same log_session middleware to prove it receives the session. research_agent = client.as_agent( name="ResearchAgent", instructions="You are a research assistant. Provide concise answers and store your findings.", @@ -81,7 +72,6 @@ async def main() -> None: tools=[store_findings, recall_findings], ) - # propagate_session=True forwards an explicitly supplied runtime session. research_tool = research_agent.as_tool( name="research", description="Research a topic and store your findings.", @@ -90,15 +80,16 @@ async def main() -> None: propagate_session=True, ) - # --- Coordinator agent --- coordinator = client.as_agent( name="CoordinatorAgent", - instructions="You coordinate research. Use the 'research' tool to start research and then use the recall findings tool to gather up everything. You can also start by storing some of the background directly.", + instructions=( + "You coordinate research. Use the 'research' tool to start research " + "and then use the recall findings tool to gather up everything." + ), tools=[research_tool, store_findings, recall_findings], middleware=[log_session], ) - # Create a shared session and put some state in it session = coordinator.create_session() session.state["findings"] = None print(f"Session ID: {session.session_id}") @@ -107,15 +98,10 @@ async def main() -> None: query = "What are the latest developments in quantum computing and in AI?" print(f"User: {query}\n") - result = await coordinator.run( - query, - session=session, - function_invocation_kwargs={"session": session}, - ) + result = await coordinator.run(query, session=session) print(f"\nCoordinator: {result}\n") print(f"Session state after run: {session.state}") - print("\nIf both agents show the same session_id above, session propagation is working.") if __name__ == "__main__": diff --git a/python/samples/02-agents/tools/function_tool_with_session_injection.py b/python/samples/02-agents/tools/function_tool_with_session_injection.py index a21a5b82b4..53cc63c2c0 100644 --- a/python/samples/02-agents/tools/function_tool_with_session_injection.py +++ b/python/samples/02-agents/tools/function_tool_with_session_injection.py @@ -14,10 +14,9 @@ """ AI Function with Session Injection Example -This example demonstrates explicitly passing an ``AgentSession`` through -``function_invocation_kwargs`` and reading it from ``FunctionInvocationContext.kwargs``. -The injected context parameter can be typed as ``FunctionInvocationContext`` as -shown here, or left untyped as ``ctx`` when you want the conventional untyped form. +This example demonstrates accessing the agent session inside a tool function +via ``FunctionInvocationContext.session``. The session is automatically +available when the agent is invoked with a session. """ @@ -29,9 +28,7 @@ async def get_weather( ctx: FunctionInvocationContext, ) -> str: """Get the weather for a given location.""" - # FunctionInvocationContext does not surface agent sessions directly. - # If a tool needs session data, pass it explicitly through function_invocation_kwargs. - session = ctx.kwargs.get("session") + session = ctx.session if session and isinstance(session, AgentSession) and session.service_session_id: print(f"Session ID: {session.service_session_id}.") @@ -49,16 +46,14 @@ async def main() -> None: # Create a session session = agent.create_session() - # Pass the session explicitly through function_invocation_kwargs when the tool needs it. + # Run the agent with the session; tools receive it via ctx.session. print( - f"Agent: {await agent.run('What is the weather in London?', session=session, function_invocation_kwargs={'session': session})}" + f"Agent: {await agent.run('What is the weather in London?', session=session)}" ) print( - f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session, function_invocation_kwargs={'session': session})}" - ) - print( - f"Agent: {await agent.run('What cities did I ask about?', session=session, function_invocation_kwargs={'session': session})}" + f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session)}" ) + print(f"Agent: {await agent.run('What cities did I ask about?', session=session)}") if __name__ == "__main__":