diff --git a/docs/decisions/0001-agent-run-response.md b/docs/decisions/0001-agent-run-response.md index 12724aca3a..6ffebe7e4f 100644 --- a/docs/decisions/0001-agent-run-response.md +++ b/docs/decisions/0001-agent-run-response.md @@ -4,8 +4,8 @@ status: accepted contact: westey-m date: 2025-07-10 {YYYY-MM-DD when the decision was last updated} deciders: sergeymenshykh, markwallace, rbarreto, dmytrostruk, westey-m, eavanvalkenburg, stephentoub -consulted: -informed: +consulted: +informed: --- # Agent Run Responses Design @@ -64,7 +64,7 @@ Approaches observed from the compared SDKs: | AutoGen | **Approach 1** Separates messages into Agent-Agent (maps to Primary) and Internal (maps to Secondary) and these are returned as separate properties on the agent response object. See [types of messages](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/messages.html#types-of-messages) and [Response](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.base.html#autogen_agentchat.base.Response) | **Approach 2** Returns a stream of internal events and the last item is a Response object. See [ChatAgent.on_messages_stream](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.base.html#autogen_agentchat.base.ChatAgent.on_messages_stream) | | OpenAI Agent SDK | **Approach 1** Separates new_items (Primary+Secondary) from final output (Primary) as separate properties on the [RunResult](https://github.com/openai/openai-agents-python/blob/main/src/agents/result.py#L39) | **Approach 1** Similar to non-streaming, has a way of streaming updates via a method on the response object which includes all data, and then a separate final output property on the response object which is populated only when the run is complete. 
See [RunResultStreaming](https://github.com/openai/openai-agents-python/blob/main/src/agents/result.py#L136) | | Google ADK | **Approach 2** [Emits events](https://google.github.io/adk-docs/runtime/#step-by-step-breakdown) with [FinalResponse](https://github.com/google/adk-java/blob/main/core/src/main/java/com/google/adk/events/Event.java#L232) true (Primary) / false (Secondary) and callers have to filter out those with false to get just the final response message | **Approach 2** Similar to non-streaming except [events](https://google.github.io/adk-docs/runtime/#streaming-vs-non-streaming-output-partialtrue) are emitted with [Partial](https://github.com/google/adk-java/blob/main/core/src/main/java/com/google/adk/events/Event.java#L133) true to indicate that they are streaming messages. A final non partial event is also emitted. | -| AWS (Strands) | **Approach 3** Returns an [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/#agentresult) (Primary) with messages and a reason for the run's completion. | **Approach 2** [Streams events](https://strandsagents.com/docs/user-guide/concepts/streaming/) (Primary+Secondary) including, response text, current_tool_use, even data from "callbacks" (strands plugins) | +| AWS (Strands) | **Approach 3** Returns an [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/) (Primary) with messages and a reason for the run's completion. 
| **Approach 2** [Streams events](https://strandsagents.com/docs/api/python/strands.agent.agent/) (Primary+Secondary) including response text, current_tool_use, and even data from "callbacks" (strands plugins) | | LangGraph | **Approach 2** A mixed list of all [messages](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) | **Approach 2** A mixed list of all [messages](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) | | Agno | **Combination of various approaches** Returns a [RunResponse](https://docs.agno.com/reference/agents/run-response) object with text content, messages (essentially chat history including inputs and instructions), reasoning and thinking text properties. Secondary events could potentially be extracted from messages. | **Approach 2** Returns [RunResponseEvent](https://docs.agno.com/reference/agents/run-response#runresponseevent-types-and-attributes) objects including tool call, memory update, etc, information, where the [RunResponseCompletedEvent](https://docs.agno.com/reference/agents/run-response#runresponsecompletedevent) has similar properties to RunResponse| | A2A | **Approach 3** Returns a [Task or Message](https://a2aproject.github.io/A2A/latest/specification/#71-messagesend) where the message is the final result (Primary) and task is a reference to a long running process. | **Approach 2** Returns a [stream](https://a2aproject.github.io/A2A/latest/specification/#72-messagestream) that contains task updates (Secondary) and a final message (Primary) | @@ -496,7 +496,7 @@ We need to decide what AIContent types, each agent response type will be mapped |-|-| | AutoGen | **Approach 1** Supports [configuring an agent](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/agents.html#structured-output) at agent creation.
| | Google ADK | **Approach 1** Both [input and output schemas can be specified for LLM Agents](https://google.github.io/adk-docs/agents/llm-agents/#structuring-data-input_schema-output_schema-output_key) at construction time. This option is specific to this agent type and other agent types do not necessarily support | -| AWS (Strands) | **Approach 2** Supports a special invocation method called [structured_output](https://strandsagents.com/docs/user-guide/concepts/agents/structured-output/) | +| AWS (Strands) | **Approach 2** Supports a special invocation method called [structured_output](https://strandsagents.com/docs/api/python/strands.agent.agent/) | | LangGraph | **Approach 1** Supports [configuring an agent](https://langchain-ai.github.io/langgraph/agents/agents/?h=structured#6-configure-structured-output) at agent construction time, and a [structured response](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) can be retrieved as a special property on the agent response | | Agno | **Approach 1** Supports [configuring an agent](https://docs.agno.com/input-output/structured-output/agent) at agent construction time | | A2A | **Informal Approach 2** Doesn't formally support schema negotiation, but [hints can be provided via metadata](https://a2a-protocol.org/latest/specification/#97-structured-data-exchange-requesting-and-providing-json) at invocation time | @@ -508,7 +508,7 @@ We need to decide what AIContent types, each agent response type will be mapped |-|-| | AutoGen | Supports a [stop reason](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.base.html#autogen_agentchat.base.TaskResult.stop_reason) which is a freeform text string | | Google ADK | [No equivalent present](https://github.com/google/adk-python/blob/main/src/google/adk/events/event.py) | -| AWS (Strands) | Exposes a `stop_reason` property on the [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/#agentresult) 
class with options that are tied closely to LLM operations. | +| AWS (Strands) | Exposes a [stop_reason](https://strandsagents.com/docs/api/python/strands.types.event_loop/) property on the [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/) class with options that are tied closely to LLM operations. | | LangGraph | No equivalent present, output contains only [messages](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) | | Agno | [No equivalent present](https://docs.agno.com/reference/agents/run-response) | | A2A | No equivalent present, response only contains a [message](https://a2a-protocol.org/latest/specification/#64-message-object) or [task](https://a2a-protocol.org/latest/specification/#61-task-object). | diff --git a/python/CODING_STANDARD.md b/python/CODING_STANDARD.md index ccb8e058e3..8611592692 100644 --- a/python/CODING_STANDARD.md +++ b/python/CODING_STANDARD.md @@ -127,7 +127,12 @@ def create_agent(name: str, tool_mode: Literal['auto', 'required', 'none'] | Cha Avoid `**kwargs` unless absolutely necessary. It should only be used as an escape route, not for well-known flows of data: - **Prefer named parameters**: If there are known extra arguments being passed, use explicit named parameters instead of kwargs +- **Prefer purpose-specific buckets over generic kwargs**: If a flexible payload is still needed, use an explicit named parameter such as `additional_properties`, `function_invocation_kwargs`, or `client_kwargs` rather than a blanket `**kwargs` - **Subclassing support**: kwargs is acceptable in methods that are part of classes designed for subclassing, allowing subclass-defined kwargs to pass through without issues. 
In this case, clearly document that kwargs exists for subclass extensibility and not for passing arbitrary data +- **Make known flows explicit first**: For abstract hooks, move known data flows into explicit parameters before leaving `**kwargs` behind for subclass extensibility (for example, prefer `state=` explicitly instead of passing it through kwargs) +- **Prefer explicit metadata containers**: For constructors that expose metadata, prefer an explicit `additional_properties` parameter +- **Keep SDK passthroughs narrow and documented**: A kwargs escape hatch may be acceptable for provider helper APIs that pass through to a large or unstable external SDK surface, but it should be documented as an SDK passthrough and revisited regularly +- **Do not keep passthrough kwargs on wrappers that do not use them**: Convenience wrappers and session helpers should not accept generic kwargs merely to forward or ignore them - **Remove when possible**: In other cases, removing kwargs is likely better than keeping it - **Separate kwargs by purpose**: When combining kwargs for multiple purposes, use specific parameters like `client_kwargs: dict[str, Any]` instead of mixing everything in `**kwargs` - **Always document**: If kwargs must be used, always document how it's used, either by referencing external documentation or explaining its purpose diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 31fac386b3..f9f2afaeff 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -6,7 +6,7 @@ import json import re import uuid -from collections.abc import AsyncIterable, Awaitable, Sequence +from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence from typing import Any, Final, Literal, TypeAlias, overload import httpx @@ -218,6 +218,8 @@ def run( *, stream: Literal[False] = ..., session: AgentSession | None = None, +
function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, @@ -230,17 +232,21 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - def run( + def run( # pyright: ignore[reportIncompatibleMethodOverride] self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, @@ -253,17 +259,23 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). + function_invocation_kwargs: Present for compatibility with the shared agent interface. + A2AAgent does not use these values directly. + client_kwargs: Present for compatibility with the shared agent interface. + A2AAgent does not use these values directly. + kwargs: Additional compatibility keyword arguments. + A2AAgent does not use these values directly. continuation_token: Optional token to resume a long-running task instead of starting a new one. background: When True, in-progress task updates surface continuation tokens so the caller can poll or resubscribe later. When False (default), the agent internally waits for the task to complete. - kwargs: Additional keyword arguments. Returns: When stream=False: An Awaitable[AgentResponse]. When stream=True: A ResponseStream of AgentResponseUpdate items. 
""" + del function_invocation_kwargs, client_kwargs, kwargs if continuation_token is not None: a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe( TaskIdParams(id=continuation_token["task_id"]) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index 7188eb739c..d2fb59bbb6 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -220,7 +220,6 @@ def __init__( additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, - **kwargs: Any, ) -> None: """Initialize the AG-UI chat client. @@ -231,13 +230,11 @@ def __init__( additional_properties: Additional properties to store middleware: Optional middleware to apply to the client. function_invocation_configuration: Optional function invocation configuration override. 
- **kwargs: Additional arguments passed to BaseChatClient """ super().__init__( additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self._http_service = AGUIHttpService( endpoint=endpoint, diff --git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index b73eddb8ad..42a6967371 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -98,7 +98,11 @@ def get_response( options: OptionsCoT | ChatOptions[Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - self.last_session = kwargs.get("session") + client_kwargs = kwargs.get("client_kwargs") + if isinstance(client_kwargs, Mapping): + self.last_session = cast(AgentSession | None, client_kwargs.get("session")) + else: + self.last_session = None self.last_service_session_id = self.last_session.service_session_id if self.last_session else None return cast( Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]], diff --git a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py index 75cb659633..d5e7b0ae83 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py @@ -702,14 +702,9 @@ async def test_agent_with_use_service_session_is_true(streaming_chat_client_stub """Test that when use_service_session is True, the AgentSession used to run the agent is set to the service session ID.""" from agent_framework.ag_ui import AgentFrameworkAgent - request_service_session_id: str | None = None - async def stream_fn( messages: MutableSequence[Message], chat_options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - nonlocal 
request_service_session_id - session = kwargs.get("session") - request_service_session_id = session.service_session_id if session else None yield ChatResponseUpdate( contents=[Content.from_text(text="Response")], response_id="resp_67890", conversation_id="conv_12345" ) @@ -719,11 +714,22 @@ async def stream_fn( input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"} + # Spy on agent.run to capture the session kwarg at call time (before streaming mutates it) + captured_service_session_id: str | None = None + original_run = agent.run + + def capturing_run(*args: Any, **kwargs: Any) -> Any: + nonlocal captured_service_session_id + session = kwargs.get("session") + captured_service_session_id = session.service_session_id if session else None + return original_run(*args, **kwargs) + + agent.run = capturing_run # type: ignore[assignment, method-assign] + events: list[Any] = [] async for event in wrapper.run(input_data): events.append(event) - request_service_session_id = agent.client.last_service_session_id - assert request_service_session_id == "conv_123456" # type: ignore[attr-defined] (service_session_id should be set) + assert captured_service_session_id == "conv_123456" async def test_function_approval_mode_executes_tool(streaming_chat_client_stub): diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index 5cda4991c8..32805b7dfe 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -228,11 +228,11 @@ def __init__( model_id: str | None = None, anthropic_client: AsyncAnthropic | None = None, additional_beta_flags: list[str] | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = 
None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Anthropic Agent client. @@ -244,11 +244,11 @@ def __init__( For instance if you need to set a different base_url for testing or private deployments. additional_beta_flags: Additional beta flags to enable on the client. Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25". + additional_properties: Additional properties stored on the client instance. middleware: Optional middleware to apply to the client. function_invocation_configuration: Optional function invocation configuration override. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. code-block:: python @@ -319,9 +319,9 @@ class MyOptions(AnthropicChatOptions, total=False): # Initialize parent super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) # Initialize instance variables diff --git a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py index 3c4fb68fe8..9972f1301d 100644 --- a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py +++ b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py @@ -16,6 +16,18 @@ # -- Helpers ------------------------------------------------------------------- +@pytest.fixture(autouse=True) +def clear_azure_search_env(monkeypatch: pytest.MonkeyPatch) -> None: + """Keep tests isolated from ambient Azure Search environment variables.""" + for key in ( + "AZURE_SEARCH_ENDPOINT", + "AZURE_SEARCH_INDEX_NAME", + "AZURE_SEARCH_KNOWLEDGE_BASE_NAME", + "AZURE_SEARCH_API_KEY", + ): + monkeypatch.delenv(key, raising=False) + + class MockSearchResults: 
"""Async-iterable mock for Azure SearchClient.search() results.""" diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 4c0e3a56e7..407f524d52 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -444,11 +444,11 @@ def __init__( model_deployment_name: str | None = None, credential: AzureCredentialTypes | None = None, should_cleanup_agent: bool = True, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure AI Agent client. @@ -471,11 +471,11 @@ def __init__( should_cleanup_agent: Whether to cleanup (delete) agents created by this client when the client is closed or context is exited. Defaults to True. Only affects agents created by this client instance; existing agents passed via agent_id are never deleted. + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of middlewares to include. function_invocation_configuration: Optional function invocation configuration. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. 
code-block:: python @@ -548,9 +548,9 @@ class MyOptions(AzureAIAgentOptions, total=False): # Initialize parent super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) # Initialize instance variables diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index ba5dd8aad7..1fc6c7c1c9 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -119,9 +119,9 @@ def __init__( credential: AzureCredentialTypes | None = None, use_latest_version: bool | None = None, allow_preview: bool | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a bare Azure AI client. @@ -145,9 +145,9 @@ def __init__( use_latest_version: Boolean flag that indicates whether to use latest agent version if it exists in the service. allow_preview: Enables preview opt-in on internally-created ``AIProjectClient``. + additional_properties: Additional properties stored on the client instance. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. 
code-block:: python @@ -217,7 +217,7 @@ class MyOptions(ChatOptions, total=False): # Initialize parent super().__init__( - **kwargs, + additional_properties=additional_properties, ) # Initialize instance variables @@ -1243,11 +1243,11 @@ def __init__( credential: AzureCredentialTypes | None = None, use_latest_version: bool | None = None, allow_preview: bool | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure AI client with full layer support. @@ -1268,11 +1268,11 @@ def __init__( use_latest_version: Boolean flag that indicates whether to use latest agent version if it exists in the service. allow_preview: Enables preview opt-in on internally-created ``AIProjectClient`` + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of chat middlewares to include. function_invocation_configuration: Optional function invocation configuration. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. 
code-block:: python @@ -1319,9 +1319,9 @@ class MyOptions(ChatOptions, total=False): credential=credential, use_latest_version=use_latest_version, allow_preview=allow_preview, + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, env_file_path=env_file_path, env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py index a243f77a38..3daa678333 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py @@ -124,9 +124,9 @@ def __init__( text_client: EmbeddingsClient | None = None, image_client: ImageEmbeddingsClient | None = None, credential: AzureKeyCredential | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a raw Azure AI Inference embedding client.""" settings = load_settings( @@ -160,7 +160,7 @@ def __init__( credential=credential, # type: ignore[arg-type] ) self._endpoint = resolved_endpoint - super().__init__(**kwargs) + super().__init__(additional_properties=additional_properties) async def close(self) -> None: """Close the underlying SDK clients and release resources.""" @@ -376,9 +376,9 @@ def __init__( image_client: ImageEmbeddingsClient | None = None, credential: AzureKeyCredential | None = None, otel_provider_name: str | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure AI Inference embedding client.""" super().__init__( @@ -389,8 +389,8 @@ def __init__( text_client=text_client, image_client=image_client, credential=credential, + 
additional_properties=additional_properties, otel_provider_name=otel_provider_name, env_file_path=env_file_path, env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py index 35c4243c37..6d205fa378 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -124,7 +124,13 @@ def __init__( self._database_client = self._cosmos_client.get_database_client(self.database_name) - async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: + async def get_messages( + self, + session_id: str | None, + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> list[Message]: """Retrieve stored messages for this session from Azure Cosmos DB.""" await self._ensure_container_proxy() session_key = self._session_partition_key(session_id) @@ -157,7 +163,14 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess return messages - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None: + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: """Persist messages for this session to Azure Cosmos DB.""" if not messages: return diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index 5bc9735846..cbb7e2879d 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -236,11 +236,11 @@ def __init__( session_token: str | None = None, client: BaseClient | None = None, boto3_session: Boto3Session | None = None, + 
additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Create a Bedrock chat client and load AWS credentials. @@ -252,11 +252,11 @@ def __init__( session_token: Optional AWS session token for temporary credentials. client: Preconfigured Bedrock runtime client; when omitted a boto3 session is created. boto3_session: Custom boto3 session used to build the runtime client if provided. + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of middlewares to include. function_invocation_configuration: Optional function invocation configuration env_file_path: Optional .env file path used by ``BedrockSettings`` to load defaults. env_file_encoding: Encoding for the optional .env file. - kwargs: Additional arguments forwarded to ``BaseChatClient``. Examples: .. 
code-block:: python @@ -303,9 +303,9 @@ class MyOptions(BedrockChatOptions, total=False): ) super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self.model_id = chat_model_id self.region = region diff --git a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py index d07bdee45c..3161ed4c88 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py @@ -104,9 +104,9 @@ def __init__( session_token: str | None = None, client: BaseClient | None = None, boto3_session: Boto3Session | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a raw Bedrock embedding client.""" settings = load_settings( @@ -145,7 +145,7 @@ def __init__( self.model_id: str = settings["embedding_model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] self.region = resolved_region - super().__init__(**kwargs) + super().__init__(additional_properties=additional_properties) def service_url(self) -> str: """Get the URL of the service.""" @@ -274,9 +274,9 @@ def __init__( client: BaseClient | None = None, boto3_session: Boto3Session | None = None, otel_provider_name: str | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a Bedrock embedding client.""" super().__init__( @@ -287,8 +287,8 @@ def __init__( session_token=session_token, client=client, boto3_session=boto3_session, + additional_properties=additional_properties, otel_provider_name=otel_provider_name, env_file_path=env_file_path, 
env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index 127e3647ee..549c7a5046 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -5,7 +5,7 @@ import contextlib import logging import sys -from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence +from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload @@ -581,7 +581,9 @@ def run( *, stream: Literal[False] = ..., session: AgentSession | None = None, - **kwargs: Any, + options: OptionsT | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, ) -> Awaitable[AgentResponse[Any]]: ... @overload @@ -591,7 +593,9 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, - **kwargs: Any, + options: OptionsT | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def run( @@ -600,7 +604,9 @@ def run( *, stream: bool = False, session: AgentSession | None = None, - **kwargs: Any, + options: OptionsT | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Run the agent with the given messages. @@ -612,16 +618,19 @@ def run( returns an awaitable AgentResponse. session: The conversation session. If session has service_session_id set, the agent will resume that session. 
-            kwargs: Additional keyword arguments including 'options' for runtime options
-                (model, permission_mode can be changed per-request).
+            options: Runtime options. Model and permission_mode can be changed per request.
+            function_invocation_kwargs: Present for compatibility with the shared agent interface.
+                ClaudeAgent does not use these values directly.
+            client_kwargs: Present for compatibility with the shared agent interface.
+                ClaudeAgent does not use these values directly.

         Returns:
             When stream=True: An ResponseStream for streaming updates.
             When stream=False: An Awaitable[AgentResponse] with the complete response.
         """
-        options = kwargs.pop("options", None)
+        del function_invocation_kwargs, client_kwargs
         response = ResponseStream(
-            self._get_stream(messages, session=session, options=options, **kwargs),
+            self._get_stream(messages, session=session, options=options),
             finalizer=self._finalize_response,
         )
@@ -634,8 +643,7 @@ async def _get_stream(
         messages: AgentRunInputs | None = None,
         *,
         session: AgentSession | None = None,
-        options: OptionsT | MutableMapping[str, Any] | None = None,
-        **kwargs: Any,
+        options: OptionsT | None = None,
     ) -> AsyncIterable[AgentResponseUpdate]:
         """Internal streaming implementation."""
         session = session or self.create_session()
diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py
index edacb614a5..fc2a35c72b 100644
--- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py
+++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py
@@ -196,7 +196,6 @@ def run(
         *,
         stream: Literal[False] = False,
         session: AgentSession | None = None,
-        **kwargs: Any,
     ) -> Awaitable[AgentResponse]: ...

     @overload
@@ -206,7 +205,6 @@ def run(
         *,
         stream: Literal[True],
         session: AgentSession | None = None,
-        **kwargs: Any,
     ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ...
     def run(
@@ -215,7 +213,6 @@ def run(
         *,
         stream: bool = False,
         session: AgentSession | None = None,
-        **kwargs: Any,
     ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
         """Get a response from the agent.
@@ -229,22 +226,20 @@ def run(
         Keyword Args:
             stream: Whether to stream the response. Defaults to False.
             session: The conversation session associated with the message(s).
-            kwargs: Additional keyword arguments.

         Returns:
             When stream=False: An Awaitable[AgentResponse].
             When stream=True: A ResponseStream of AgentResponseUpdate items.
         """
         if stream:
-            return self._run_stream_impl(messages=messages, session=session, **kwargs)
-        return self._run_impl(messages=messages, session=session, **kwargs)
+            return self._run_stream_impl(messages=messages, session=session)
+        return self._run_impl(messages=messages, session=session)

     async def _run_impl(
         self,
         messages: AgentRunInputs | None = None,
         *,
         session: AgentSession | None = None,
-        **kwargs: Any,
     ) -> AgentResponse:
         """Non-streaming implementation of run."""
         if not session:
@@ -269,7 +264,6 @@ def _run_stream_impl(
         messages: AgentRunInputs | None = None,
         *,
         session: AgentSession | None = None,
-        **kwargs: Any,
     ) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
         """Streaming implementation of run."""
diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py
index ef03652898..57a438da2a 100644
--- a/python/packages/core/agent_framework/__init__.py
+++ b/python/packages/core/agent_framework/__init__.py
@@ -181,6 +181,7 @@
 )
 from .exceptions import (
     MiddlewareException,
+    UserInputRequiredException,
     WorkflowCheckpointException,
     WorkflowConvergenceException,
     WorkflowException,
@@ -291,6 +292,7 @@
     "TypeCompatibilityError",
     "UpdateT",
     "UsageDetails",
+    "UserInputRequiredException",
     "ValidationTypeEnum",
     "Workflow",
     "WorkflowAgent",
diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py
index 5cf7ff78a2..e7c7b71b80 100644
--- a/python/packages/core/agent_framework/_agents.py
+++ b/python/packages/core/agent_framework/_agents.py
@@ -2,10 +2,10 @@

 from __future__ import annotations

-import inspect
 import logging
 import re
 import sys
+import warnings
 from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack
 from copy import deepcopy
@@ -27,11 +27,12 @@
 from mcp import types
 from mcp.server.lowlevel import Server
 from mcp.shared.exceptions import McpError
-from pydantic import BaseModel, Field, create_model
+from pydantic import BaseModel

 from ._clients import BaseChatClient, SupportsChatGetResponse
+from ._docstrings import apply_layered_docstring
 from ._mcp import LOG_LEVEL_MAPPING, MCPTool
-from ._middleware import AgentMiddlewareLayer, MiddlewareTypes
+from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes
 from ._serialization import SerializationMixin
 from ._sessions import (
     AgentSession,
@@ -57,7 +58,7 @@
     map_chat_to_agent_update,
     normalize_messages,
 )
-from .exceptions import AgentInvalidResponseException
+from .exceptions import AgentInvalidResponseException, UserInputRequiredException
 from .observability import AgentTelemetryLayer

 if sys.version_info >= (3, 13):
@@ -177,8 +178,8 @@ class _RunContext(TypedDict):
     session_messages: Sequence[Message]
     agent_name: str
     chat_options: MutableMapping[str, Any]
-    filtered_kwargs: Mapping[str, Any]
-    finalize_kwargs: Mapping[str, Any]
+    client_kwargs: Mapping[str, Any]
+    function_invocation_kwargs: Mapping[str, Any]

 # region Agent Protocol
@@ -226,15 +227,15 @@ async def _stream():
                     return AgentResponse(messages=[], response_id="custom-response")

-                def create_session(self, **kwargs):
+                def create_session(self, *, session_id: str | None = None):
                     from agent_framework import AgentSession

-                    return AgentSession(**kwargs)
+                    return AgentSession(session_id=session_id)

-                def get_session(self, *, service_session_id, **kwargs):
+                def get_session(self, service_session_id: str, *, session_id: str | None = None):
                     from agent_framework import AgentSession

-                    return AgentSession(service_session_id=service_session_id, **kwargs)
+                    return AgentSession(service_session_id=service_session_id, session_id=session_id)

             # Verify the instance satisfies the protocol
@@ -253,6 +254,8 @@ def run(
         *,
         stream: Literal[False] = ...,
         session: AgentSession | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
         **kwargs: Any,
     ) -> Awaitable[AgentResponse[Any]]:
         """Get a response from the agent (non-streaming)."""
@@ -265,6 +268,8 @@ def run(
         *,
         stream: Literal[True],
         session: AgentSession | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
         **kwargs: Any,
     ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
         """Get a streaming response from the agent."""
@@ -276,6 +281,8 @@ def run(
         *,
         stream: bool = False,
         session: AgentSession | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
         **kwargs: Any,
     ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
         """Get a response from the agent.
@@ -290,6 +297,8 @@ def run(
         Keyword Args:
             stream: Whether to stream the response. Defaults to False.
             session: The conversation session associated with the message(s).
+            function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
+            client_kwargs: Additional client-specific keyword arguments.
             kwargs: Additional keyword arguments.

         Returns:
@@ -299,11 +308,11 @@ def run(
         """
         ...

-    def create_session(self, **kwargs: Any) -> AgentSession:
+    def create_session(self, *, session_id: str | None = None) -> AgentSession:
         """Creates a new conversation session."""
         ...
-    def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
+    def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
         """Gets or creates a session for a service-managed session ID."""
         ...
@@ -386,6 +395,13 @@ def __init__(
             additional_properties: Additional properties set on the agent.
             kwargs: Additional keyword arguments (merged into additional_properties).
         """
+        if kwargs:
+            warnings.warn(
+                "Passing additional properties as direct keyword arguments to BaseAgent is deprecated; "
+                "pass them via additional_properties instead.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
         if id is None:
             id = str(uuid4())
         self.id = id
@@ -400,27 +416,40 @@ def __init__(
         self.additional_properties: dict[str, Any] = cast(dict[str, Any], additional_properties or {})
         self.additional_properties.update(kwargs)

-    def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> AgentSession:
+    def create_session(self, *, session_id: str | None = None) -> AgentSession:
         """Create a new lightweight session.

+        The agent uses this session to hold persisted conversation state.
+        What is stored depends on the service: for some services, or when
+        ``store=True``, the ``service_session_id`` from the response is saved
+        on the session and fed back to the API on the next call.
+
+        Otherwise, if a HistoryProvider is configured on the agent, it is
+        used and can store state in the session.
+
+        If there is no HistoryProvider and ``store=False`` (or the service
+        default is False), an ``InMemoryHistoryProvider`` instance is added to
+        the agent and used with the session automatically; it stores the
+        messages as ``state`` in the session by default.
+
         Keyword Args:
             session_id: Optional session ID (generated if not provided).
-            kwargs: Additional keyword arguments.

         Returns:
             A new AgentSession instance.
""" return AgentSession(session_id=session_id) - def get_session(self, *, service_session_id: str, session_id: str | None = None, **kwargs: Any) -> AgentSession: - """Get or create a session for a service-managed session ID. + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + """Get a session for a service-managed session ID. + + Only use this to create a session continuing that session id from a service. + Otherwise use ``create_session``. Args: service_session_id: The service-managed session ID. Keyword Args: session_id: Optional local session ID (generated if not provided). - kwargs: Additional keyword arguments. Returns: A new AgentSession instance with service_session_id set. @@ -460,9 +489,8 @@ def as_tool( description: str | None = None, arg_name: str = "task", arg_description: str | None = None, - stream_callback: Callable[[AgentResponseUpdate], None] - | Callable[[AgentResponseUpdate], Awaitable[None]] - | None = None, + approval_mode: Literal["always_require", "never_require"] = "never_require", + stream_callback: Callable[[AgentResponseUpdate], Awaitable[None] | None] | None = None, propagate_session: bool = False, ) -> FunctionTool: """Create a FunctionTool that wraps this agent. @@ -473,21 +501,15 @@ def as_tool( arg_name: The name of the function argument (default: "task"). arg_description: The description for the function argument. If None, defaults to "Task for {tool_name}". + approval_mode: Whether this delegated tool requires approval before execution. stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True). - propagate_session: If True, the parent agent's ``AgentSession`` is - forwarded to this sub-agent's ``run()`` call, so both agents - operate within the same logical session (sharing the same - ``session_id`` and provider-managed state, such as any stored - conversation history or metadata). 
Defaults to False, meaning - the sub-agent runs with a new, independent session. + propagate_session: If True, the parent agent's session is forwarded + to this sub-agent's ``run()`` call so both agents share the + same session. Defaults to False. Returns: A FunctionTool that can be used as a tool by other agents. - Raises: - TypeError: If the agent does not implement SupportsAgentRun. - ValueError: If the agent tool name cannot be determined. - Examples: .. code-block:: python @@ -515,59 +537,46 @@ def as_tool( tool_description = description or self.description or "" argument_description = arg_description or f"Task for {tool_name}" - # Create dynamic input model with the specified argument name - field_info = Field(..., description=argument_description) - model_name = f"{name or _sanitize_agent_name(self.name) or 'agent'}_task" - input_model = create_model(model_name, **{arg_name: (str, field_info)}) # type: ignore[call-overload] - - # Check if callback is async once, outside the wrapper - is_async_callback = stream_callback is not None and inspect.iscoroutinefunction(stream_callback) - - async def agent_wrapper(**kwargs: Any) -> str: - """Wrapper function that calls the agent.""" - # Extract the input from kwargs using the specified arg_name - input_text = kwargs.get(arg_name, "") - - # Extract parent session when propagate_session is enabled - parent_session = kwargs.get("session") if propagate_session else None - - # Forward runtime context kwargs, excluding framework-internal keys. 
-            forwarded_kwargs = {
-                k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options", "session")
-            }
-
-            if stream_callback is None:
-                # Use non-streaming mode
-                return (
-                    await self.run(
-                        input_text,
-                        stream=False,
-                        session=parent_session,
-                        **forwarded_kwargs,
-                    )
-                ).text
-
-            # Use streaming mode - accumulate updates and create final response
-            response_updates: list[AgentResponseUpdate] = []
-            async for update in self.run(input_text, stream=True, session=parent_session, **forwarded_kwargs):
-                response_updates.append(update)
-                if is_async_callback:
-                    await stream_callback(update)  # type: ignore[misc]
-                else:
-                    stream_callback(update)
+        input_schema = {
+            "type": "object",
+            "properties": {
+                arg_name: {
+                    "type": "string",
+                    "description": argument_description,
+                }
+            },
+            "required": [arg_name],
+            "additionalProperties": False,
+        }

-            # Create final text from accumulated updates
-            return AgentResponse.from_updates(response_updates).text
+        async def _agent_wrapper(ctx: FunctionInvocationContext, **kwargs: Any) -> str:
+            """Wrapper function that calls the agent.

-        agent_tool: FunctionTool = FunctionTool(
+            Args:
+                ctx: The function invocation context for this call.
+                **kwargs: Only used to dynamically load the argument defined for this tool.
+ """ + stream = self.run( + str(kwargs.get(arg_name, "")), + stream=True, + session=ctx.session if propagate_session else None, + function_invocation_kwargs=dict(ctx.kwargs), + ) + if stream_callback is not None: + stream.with_transform_hook(stream_callback) + final_response = await stream.get_final_response() + if final_response.user_input_requests: + raise UserInputRequiredException(contents=final_response.user_input_requests) + # TODO(Copilot): update once #4331 merges + return final_response.text + + return FunctionTool( name=tool_name, description=tool_description, - func=agent_wrapper, - input_model=input_model, # type: ignore - approval_mode="never_require", + func=_agent_wrapper, + input_model=input_schema, + approval_mode=approval_mode, ) - agent_tool._forward_runtime_kwargs = True # type: ignore - return agent_tool # region Agent @@ -799,6 +808,8 @@ def run( session: AgentSession | None = None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, options: ChatOptions[ResponseModelBoundT], + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ... @@ -811,6 +822,8 @@ def run( session: AgentSession | None = None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None, options: OptionsCoT | ChatOptions[None] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... 
@@ -823,6 +836,8 @@ def run(
         session: AgentSession | None = None,
         tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
         options: OptionsCoT | ChatOptions[Any] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
         **kwargs: Any,
     ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
@@ -834,6 +849,8 @@ def run(
         session: AgentSession | None = None,
         tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
         options: OptionsCoT | ChatOptions[Any] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
         **kwargs: Any,
     ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
         """Run the agent with the given messages and options.
@@ -857,14 +874,23 @@ def run(
                 ``Agent[OpenAIChatOptions]``, this enables IDE autocomplete
                 for provider-specific options including temperature, max_tokens,
                 model_id, tool_choice, and provider-specific options like reasoning_effort.
-            kwargs: Additional keyword arguments for the agent.
-                Will only be passed to functions that are called.
+            function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
+            client_kwargs: Additional client-specific keyword arguments for the chat client.
+            kwargs: Deprecated additional keyword arguments for the agent.
+                They are forwarded to both tool invocation and the chat client for compatibility.

         Returns:
             When stream=False: An Awaitable[AgentResponse] containing the agent's response.
             When stream=True: A ResponseStream of AgentResponseUpdate items
                 with ``get_final_response()`` for the final AgentResponse.
""" + if kwargs: + warnings.warn( + "Passing runtime keyword arguments directly to run() is deprecated; pass tool values via " + "function_invocation_kwargs and client-specific values via client_kwargs instead.", + DeprecationWarning, + stacklevel=2, + ) if not stream: async def _run_non_streaming() -> AgentResponse[Any]: @@ -873,7 +899,9 @@ async def _run_non_streaming() -> AgentResponse[Any]: session=session, tools=tools, options=options, - kwargs=kwargs, + legacy_kwargs=kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) response = cast( ChatResponse[Any], @@ -881,7 +909,8 @@ async def _run_non_streaming() -> AgentResponse[Any]: messages=ctx["session_messages"], stream=False, options=ctx["chat_options"], # type: ignore[reportArgumentType] - **ctx["filtered_kwargs"], + function_invocation_kwargs=ctx["function_invocation_kwargs"], + client_kwargs=ctx["client_kwargs"], ), ) @@ -954,14 +983,17 @@ async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]] session=session, tools=tools, options=options, - kwargs=kwargs, + legacy_kwargs=kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it return self.client.get_response( # type: ignore[call-overload, no-any-return] messages=ctx["session_messages"], stream=True, options=ctx["chat_options"], # type: ignore[reportArgumentType] - **ctx["filtered_kwargs"], + function_invocation_kwargs=ctx["function_invocation_kwargs"], + client_kwargs=ctx["client_kwargs"], ) def _propagate_conversation_id( @@ -1047,9 +1079,12 @@ async def _prepare_run_context( session: AgentSession | None, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None, options: Mapping[str, Any] | None, - kwargs: dict[str, Any], + legacy_kwargs: Mapping[str, Any], + function_invocation_kwargs: Mapping[str, Any] | None, + 
client_kwargs: Mapping[str, Any] | None, ) -> _RunContext: opts = dict(options) if options else {} + existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {} # Get tools from options or named parameter (named param takes precedence) tools_ = tools if tools is not None else opts.pop("tools", None) @@ -1080,6 +1115,12 @@ async def _prepare_run_context( input_messages=input_messages, options=opts, ) + default_additional_args = chat_options.pop("additional_function_arguments", None) + if isinstance(default_additional_args, Mapping): + existing_additional_args = { + **dict(cast(Mapping[str, Any], default_additional_args)), + **existing_additional_args, + } # Normalize tools normalized_tools = normalize_tools(tools_) @@ -1101,13 +1142,13 @@ async def _prepare_run_context( await self._async_exit_stack.enter_async_context(mcp_server) final_tools.extend(f for f in mcp_server.functions if f.name not in existing_names) - # Merge runtime kwargs into additional_function_arguments so they're available - # in function middleware context and tool invocation. - existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {} - additional_function_arguments = {**kwargs, **existing_additional_args} - # Include session so as_tool() wrappers with propagate_session=True can access it. - if active_session is not None: - additional_function_arguments["session"] = active_session + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + # Legacy compatibility still fans out direct run kwargs into tool runtime kwargs. 
+        effective_function_invocation_kwargs = {
+            **dict(legacy_kwargs),
+            **(dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}),
+        }
+        additional_function_arguments = {**effective_function_invocation_kwargs, **existing_additional_args}

         # Build options dict from run() options merged with provided options
         run_opts: dict[str, Any] = {
@@ -1116,7 +1157,6 @@
             ...
             if active_session
             else opts.pop("conversation_id", None),
             "allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None),
-            "additional_function_arguments": additional_function_arguments or None,
             "frequency_penalty": opts.pop("frequency_penalty", None),
             "logit_bias": opts.pop("logit_bias", None),
             "max_tokens": opts.pop("max_tokens", None),
@@ -1140,11 +1180,14 @@ async def _prepare_run_context(
         # Build session_messages from session context: context messages + input messages
         session_messages: list[Message] = session_context.get_messages(include_input=True)

-        # Ensure session is forwarded in kwargs for tool invocation
-        finalize_kwargs = dict(kwargs)
-        finalize_kwargs["session"] = active_session
-        # Filter chat_options from kwargs to prevent duplicate keyword argument
-        filtered_kwargs = {k: v for k, v in finalize_kwargs.items() if k != "chat_options"}
+        # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
+        # Legacy compatibility still fans out direct run kwargs into client kwargs.
+        effective_client_kwargs = {
+            **dict(legacy_kwargs),
+            **(dict(client_kwargs) if client_kwargs is not None else {}),
+        }
+        if active_session is not None:
+            effective_client_kwargs["session"] = active_session

         return {
             "session": active_session,
@@ -1153,8 +1196,8 @@ async def _prepare_run_context(
             "session_messages": session_messages,
             "agent_name": agent_name,
             "chat_options": co,
-            "filtered_kwargs": filtered_kwargs,
-            "finalize_kwargs": finalize_kwargs,
+            "client_kwargs": effective_client_kwargs,
+            "function_invocation_kwargs": additional_function_arguments,
         }

     async def _finalize_response(
@@ -1396,6 +1439,58 @@ class Agent(
     For a minimal implementation without these features, use :class:`RawAgent`.
     """

+    @overload
+    def run(
+        self,
+        messages: AgentRunInputs | None = None,
+        *,
+        stream: Literal[False] = ...,
+        session: AgentSession | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> Awaitable[AgentResponse[Any]]: ...
+
+    @overload
+    def run(
+        self,
+        messages: AgentRunInputs | None = None,
+        *,
+        stream: Literal[True],
+        session: AgentSession | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
+
+    def run(
+        self,
+        messages: AgentRunInputs | None = None,
+        *,
+        stream: bool = False,
+        session: AgentSession | None = None,
+        middleware: Sequence[MiddlewareTypes] | None = None,
+        options: OptionsCoT | ChatOptions[Any] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
+        """Run the agent."""
+        super_run = cast(
+            "Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]",
+            super().run,  # type: ignore[misc]
+        )
+        return super_run(  # type: ignore[no-any-return]
+            messages=messages,
+            stream=stream,
+            session=session,
+            middleware=middleware,
+            options=options,
+            function_invocation_kwargs=function_invocation_kwargs,
+            client_kwargs=client_kwargs,
+            **kwargs,
+        )
+
     def __init__(
         self,
         client: SupportsChatGetResponse[OptionsCoT],
@@ -1423,3 +1518,34 @@ def __init__(
             middleware=middleware,
             **kwargs,
         )
+
+
+def _apply_agent_docstrings() -> None:
+    """Align public agent docstrings with the raw implementation."""
+    apply_layered_docstring(
+        AgentMiddlewareLayer.run,
+        RawAgent.run,
+        extra_keyword_args={
+            "middleware": """
+                Optional per-run agent, chat, and function middleware.
+                Agent middleware wraps the run itself, while chat and function middleware are forwarded to the
+                underlying chat-client stack for this call.
+            """,
+        },
+    )
+    apply_layered_docstring(AgentTelemetryLayer.run, AgentMiddlewareLayer.run)
+    apply_layered_docstring(
+        Agent.run,
+        RawAgent.run,
+        extra_keyword_args={
+            "middleware": """
+                Optional per-run agent, chat, and function middleware.
+                Agent middleware wraps the run itself, while chat and function middleware are forwarded to the
+                underlying chat-client stack for this call.
+ """, + }, + ) + apply_layered_docstring(Agent.__init__, RawAgent.__init__) + + +_apply_agent_docstrings() diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 5dd049ecd3..b667187da2 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -4,6 +4,7 @@ import logging import sys +import warnings from abc import ABC, abstractmethod from collections.abc import ( AsyncIterable, @@ -27,6 +28,7 @@ from pydantic import BaseModel +from ._docstrings import apply_layered_docstring from ._serialization import SerializationMixin from ._tools import ( FunctionInvocationConfiguration, @@ -104,7 +106,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]): class CustomChatClient: additional_properties: dict = {} - def get_response(self, messages, *, stream=False, **kwargs): + def get_response(self, messages, *, stream=False, client_kwargs=None, **kwargs): if stream: from agent_framework import ChatResponseUpdate, ResponseStream @@ -144,6 +146,8 @@ def get_response( *, stream: Literal[False] = ..., options: OptionsContraT | ChatOptions[None] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]]: ... @@ -154,6 +158,8 @@ def get_response( *, stream: Literal[True], options: OptionsContraT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... 
@@ -163,6 +169,8 @@ def get_response(
         *,
         stream: bool = False,
         options: OptionsContraT | ChatOptions[Any] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
         **kwargs: Any,
     ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
         """Send input and return the response.
@@ -171,7 +179,9 @@ def get_response(
             messages: The sequence of input messages to send.
             stream: Whether to stream the response. Defaults to False.
             options: Chat options as a TypedDict.
-            **kwargs: Additional chat options.
+            function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
+            client_kwargs: Additional client-specific keyword arguments.
+            **kwargs: Deprecated additional client-specific keyword arguments.

         Returns:
             When stream=False: An awaitable ChatResponse from the client.
@@ -276,7 +286,15 @@ def __init__(
             kwargs: Additional keyword arguments (merged into additional_properties).
         """
         self.additional_properties = additional_properties or {}
-        super().__init__(**kwargs)
+        if kwargs:
+            warnings.warn(
+                "Passing additional properties as direct keyword arguments to BaseChatClient is deprecated; "
+                "pass them via additional_properties instead.",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+        self.additional_properties.update(kwargs)
+        super().__init__()

     def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
         """Convert the instance to a dictionary.
@@ -374,6 +392,8 @@ def get_response(
         *,
         stream: Literal[False] = ...,
         options: ChatOptions[ResponseModelBoundT],
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
         **kwargs: Any,
     ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@@ -411,16 +431,31 @@ def get_response(
             messages: The message or messages to send to the model.
             stream: Whether to stream the response. Defaults to False.
             options: Chat options as a TypedDict.
-            **kwargs: Other keyword arguments, can be used to pass function specific parameters.
+            **kwargs: Additional compatibility keyword arguments. Lower chat-client layers do not
+                consume ``function_invocation_kwargs`` directly; if present, it is ignored here
+                because function invocation has already been handled by upper layers. If a
+                ``client_kwargs`` mapping is present, it is flattened into standard keyword
+                arguments before forwarding to ``_inner_get_response()`` so client implementations
+                that care about those values can still use them, while implementations that ignore
+                extra kwargs remain compatible.

         Returns:
             When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
         """
+        compatibility_client_kwargs = kwargs.pop("client_kwargs", None)
+        kwargs.pop("function_invocation_kwargs", None)
+        merged_client_kwargs = (
+            dict(cast(Mapping[str, Any], compatibility_client_kwargs))
+            if isinstance(compatibility_client_kwargs, Mapping)
+            else {}
+        )
+        merged_client_kwargs.update(kwargs)
+
         return self._inner_get_response(
             messages=messages,
             stream=stream,
             options=options or {},  # type: ignore[arg-type]
-            **kwargs,
+            **merged_client_kwargs,
         )

     def service_url(self) -> str:
@@ -446,7 +481,7 @@ def as_agent(
         context_providers: Sequence[Any] | None = None,
         middleware: Sequence[MiddlewareTypes] | None = None,
         function_invocation_configuration: FunctionInvocationConfiguration | None = None,
-        **kwargs: Any,
+        additional_properties: Mapping[str, Any] | None = None,
     ) -> Agent[OptionsCoT]:
         """Create a Agent with this client.
@@ -468,7 +503,7 @@ def as_agent(
             context_providers: Context providers to include during agent invocation.
             middleware: List of middleware to intercept agent and function invocations.
             function_invocation_configuration: Optional function invocation configuration override.
-            kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
+            additional_properties: Additional properties stored on the created agent.

         Returns:
             A Agent instance configured with this chat client.
@@ -493,19 +528,22 @@ def as_agent(
         """
         from ._agents import Agent

-        return Agent(
-            client=self,
-            id=id,
-            name=name,
-            description=description,
-            instructions=instructions,
-            tools=tools,
-            default_options=cast(Any, default_options),
-            context_providers=context_providers,
-            middleware=middleware,
-            function_invocation_configuration=function_invocation_configuration,
-            **kwargs,
-        )
+        agent_kwargs: dict[str, Any] = {
+            "client": self,
+            "id": id,
+            "name": name,
+            "description": description,
+            "instructions": instructions,
+            "tools": tools,
+            "default_options": cast(Any, default_options),
+            "context_providers": context_providers,
+            "middleware": middleware,
+            "additional_properties": dict(additional_properties) if additional_properties is not None else None,
+        }
+        if function_invocation_configuration is not None:
+            agent_kwargs["function_invocation_configuration"] = function_invocation_configuration
+
+        return Agent(**agent_kwargs)

 # endregion
@@ -768,16 +806,14 @@ def __init__(
         self,
         *,
         additional_properties: dict[str, Any] | None = None,
-        **kwargs: Any,
     ) -> None:
         """Initialize a BaseEmbeddingClient instance.

         Args:
             additional_properties: Additional properties to pass to the client.
-            **kwargs: Additional keyword arguments passed to parent classes (for MRO).
""" self.additional_properties = additional_properties or {} - super().__init__(**kwargs) + super().__init__() @abstractmethod async def get_embeddings( @@ -799,3 +835,36 @@ async def get_embeddings( # endregion + + +def _apply_get_response_docstrings() -> None: + """Align layered chat-client docstrings with the lowest public implementation.""" + from ._middleware import ChatMiddlewareLayer + from ._tools import FunctionInvocationLayer + from .observability import ChatTelemetryLayer + + apply_layered_docstring(ChatTelemetryLayer.get_response, BaseChatClient.get_response) + apply_layered_docstring( + FunctionInvocationLayer.get_response, + ChatTelemetryLayer.get_response, + extra_keyword_args={ + "function_middleware": """ + Optional per-call function middleware. + When omitted, middleware configured on the client or forwarded from higher layers is used. + """, + }, + ) + apply_layered_docstring( + ChatMiddlewareLayer.get_response, + FunctionInvocationLayer.get_response, + extra_keyword_args={ + "middleware": """ + Optional per-call chat and function middleware. + This compatibility keyword argument is merged with any ``client_kwargs["middleware"]`` value + before the request is executed. + """, + }, + ) + + +_apply_get_response_docstrings() diff --git a/python/packages/core/agent_framework/_docstrings.py b/python/packages/core/agent_framework/_docstrings.py new file mode 100644 index 0000000000..44dd7c50a3 --- /dev/null +++ b/python/packages/core/agent_framework/_docstrings.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +from __future__ import annotations + +import inspect +from collections.abc import Callable, Mapping +from typing import Any + +_GOOGLE_SECTION_HEADERS = ( + "Args:", + "Keyword Args:", + "Returns:", + "Raises:", + "Examples:", + "Note:", + "Notes:", + "Warning:", + "Warnings:", +) + + +def _find_section_index(lines: list[str], header: str) -> int | None: + for index, line in enumerate(lines): + if line == header: + return index + return None + + +def _find_next_section_index(lines: list[str], start: int) -> int: + for index in range(start, len(lines)): + if lines[index] in _GOOGLE_SECTION_HEADERS: + return index + return len(lines) + + +def _format_keyword_arg_lines(extra_keyword_args: Mapping[str, str]) -> list[str]: + formatted_lines: list[str] = [] + for name, description in extra_keyword_args.items(): + description_lines = inspect.cleandoc(description).splitlines() + if not description_lines: + formatted_lines.append(f" {name}:") + continue + formatted_lines.append(f" {name}: {description_lines[0]}") + formatted_lines.extend(f" {line}" for line in description_lines[1:]) + return formatted_lines + + +def build_layered_docstring( + source: Callable[..., Any], + *, + extra_keyword_args: Mapping[str, str] | None = None, +) -> str | None: + """Build a Google-style docstring from a lower-layer implementation.""" + docstring = inspect.getdoc(source) + if not docstring: + return None + if not extra_keyword_args: + return docstring + + lines = docstring.splitlines() + formatted_keyword_arg_lines = _format_keyword_arg_lines(extra_keyword_args) + keyword_args_index = _find_section_index(lines, "Keyword Args:") + + if keyword_args_index is None: + args_index = _find_section_index(lines, "Args:") + if args_index is not None: + insert_index = _find_next_section_index(lines, args_index + 1) + else: + insert_index = _find_next_section_index(lines, 0) + lines[insert_index:insert_index] = ["", "Keyword Args:", *formatted_keyword_arg_lines] + return "\n".join(lines).rstrip() + 
+ insert_index = _find_next_section_index(lines, keyword_args_index + 1) + lines[insert_index:insert_index] = formatted_keyword_arg_lines + return "\n".join(lines).rstrip() + + +def apply_layered_docstring( + target: Callable[..., Any], + source: Callable[..., Any], + *, + extra_keyword_args: Mapping[str, str] | None = None, +) -> None: + """Copy a lower-layer docstring onto a wrapper and extend it when needed.""" + target.__doc__ = build_layered_docstring(source, extra_keyword_args=extra_keyword_args) diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 7f3f3da13d..f475fac7ca 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -106,7 +106,9 @@ class AgentContext: to see the actual execution result or can be set to override the execution result. For non-streaming: should be AgentResponse. For streaming: should be ResponseStream[AgentResponseUpdate, AgentResponse]. - kwargs: Additional keyword arguments passed to the agent run method. + kwargs: Legacy runtime keyword arguments visible to agent middleware. + client_kwargs: Client-specific keyword arguments for downstream chat clients. + function_invocation_kwargs: Keyword arguments forwarded to tool invocation. Examples: .. code-block:: python @@ -142,6 +144,8 @@ def __init__( metadata: Mapping[str, Any] | None = None, result: AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None = None, kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, stream_transform_hooks: Sequence[ Callable[[AgentResponseUpdate], AgentResponseUpdate | Awaitable[AgentResponseUpdate]] ] @@ -160,7 +164,9 @@ def __init__( stream: Whether this is a streaming invocation. metadata: Metadata dictionary for sharing data between agent middleware. result: Agent execution result. 
- kwargs: Additional keyword arguments passed to the agent run method. + kwargs: Legacy runtime keyword arguments visible to agent middleware. + client_kwargs: Client-specific keyword arguments for downstream chat clients. + function_invocation_kwargs: Keyword arguments forwarded to tool invocation. stream_transform_hooks: Hooks to transform streamed updates. stream_result_hooks: Hooks to process the final result after streaming. stream_cleanup_hooks: Hooks to run after streaming completes. @@ -173,6 +179,10 @@ def __init__( self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} + self.client_kwargs: dict[str, Any] = dict(client_kwargs) if client_kwargs is not None else {} + self.function_invocation_kwargs: dict[str, Any] = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) self.stream_transform_hooks = list(stream_transform_hooks or []) self.stream_result_hooks = list(stream_result_hooks or []) self.stream_cleanup_hooks = list(stream_cleanup_hooks or []) @@ -187,11 +197,11 @@ class FunctionInvocationContext: Attributes: function: The function being invoked. arguments: The validated arguments for the function. + session: The agent session for this invocation, if any. metadata: Metadata dictionary for sharing data between function middleware. result: Function execution result. Can be observed after calling ``call_next()`` to see the actual execution result or can be set to override the execution result. - - kwargs: Additional keyword arguments passed to the chat method that invoked this function. + kwargs: Additional runtime keyword arguments forwarded to the function invocation. Examples: .. 
code-block:: python @@ -216,6 +226,7 @@ def __init__( self, function: FunctionTool, arguments: BaseModel | Mapping[str, Any], + session: AgentSession | None = None, metadata: Mapping[str, Any] | None = None, result: Any = None, kwargs: Mapping[str, Any] | None = None, @@ -225,12 +236,14 @@ def __init__( Args: function: The function being invoked. arguments: The validated arguments for the function. + session: The agent session for this invocation, if any. metadata: Metadata dictionary for sharing data between function middleware. result: Function execution result. - kwargs: Additional keyword arguments passed to the chat method that invoked this function. + kwargs: Additional runtime keyword arguments forwarded to the function invocation. """ self.function = function self.arguments = arguments + self.session = session self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} @@ -253,6 +266,7 @@ class ChatContext: For non-streaming: should be ChatResponse. For streaming: should be ResponseStream[ChatResponseUpdate, ChatResponse]. kwargs: Additional keyword arguments passed to the chat client. + function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers. stream_transform_hooks: Hooks applied to transform each streamed update. stream_result_hooks: Hooks applied to the finalized response (after finalizer). stream_cleanup_hooks: Hooks executed after stream consumption (before finalizer). 
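The `_docstrings.py` helpers added earlier in this diff splice extra entries into a Google-style ``Keyword Args:`` section of an inherited docstring. A minimal standalone sketch of that splice (the `add_keyword_args` name and the sample docstring are illustrative, not part of the framework):

```python
import inspect

SECTION_HEADERS = ("Args:", "Keyword Args:", "Returns:", "Raises:")


def add_keyword_args(docstring: str, extra: dict[str, str]) -> str:
    """Insert a 'Keyword Args:' section after 'Args:' (or at the end)."""
    lines = docstring.splitlines()
    try:
        start = lines.index("Args:") + 1
    except ValueError:
        start = 0
    # The new section goes where the next known section header begins.
    insert_at = next(
        (i for i in range(start, len(lines)) if lines[i] in SECTION_HEADERS),
        len(lines),
    )
    entries = [f"    {name}: {desc}" for name, desc in extra.items()]
    lines[insert_at:insert_at] = ["Keyword Args:", *entries, ""]
    return "\n".join(lines)


doc = inspect.cleandoc(
    """
    Run the thing.

    Args:
        messages: Input messages.

    Returns:
        A response.
    """
)
print(add_keyword_args(doc, {"middleware": "Optional per-call middleware."}))
```

The real implementation also handles an existing ``Keyword Args:`` section and multi-line descriptions; this sketch shows only the insertion-point logic.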
@@ -289,6 +303,7 @@ def __init__( metadata: Mapping[str, Any] | None = None, result: ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None = None, kwargs: Mapping[str, Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, stream_transform_hooks: Sequence[ Callable[[ChatResponseUpdate], ChatResponseUpdate | Awaitable[ChatResponseUpdate]] ] @@ -306,6 +321,7 @@ def __init__( metadata: Metadata dictionary for sharing data between chat middleware. result: Chat execution result. kwargs: Additional keyword arguments passed to the chat client. + function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers. stream_transform_hooks: Transform hooks to apply to each streamed update. stream_result_hooks: Result hooks to apply to the finalized streaming response. stream_cleanup_hooks: Cleanup hooks to run after streaming completes. @@ -317,6 +333,9 @@ def __init__( self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} + self.function_invocation_kwargs: dict[str, Any] = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) self.stream_transform_hooks = list(stream_transform_hooks or []) self.stream_result_hooks = list(stream_result_hooks or []) self.stream_cleanup_hooks = list(stream_cleanup_hooks or []) @@ -969,6 +988,7 @@ def get_response( *, stream: Literal[False] = ..., options: ChatOptions[ResponseModelBoundT], + function_invocation_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ... @@ -979,6 +999,8 @@ def get_response( *, stream: Literal[False] = ..., options: OptionsCoT | ChatOptions[None] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]]: ... 
@@ -989,6 +1011,8 @@ def get_response( *, stream: Literal[True], options: OptionsCoT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... @@ -998,14 +1022,17 @@ def get_response( *, stream: bool = False, options: OptionsCoT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: """Execute the chat pipeline if middleware is configured.""" super_get_response = super().get_response # type: ignore[misc] - call_middleware = kwargs.pop("middleware", []) + effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + call_middleware = kwargs.pop("middleware", effective_client_kwargs.pop("middleware", [])) middleware = categorize_middleware(call_middleware) - kwargs["function_middleware"] = middleware["function"] + effective_client_kwargs["function_middleware"] = middleware["function"] pipeline = ChatMiddlewarePipeline( *self.chat_middleware, @@ -1016,6 +1043,8 @@ def get_response( messages=messages, stream=stream, options=options, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=effective_client_kwargs, **kwargs, ) @@ -1024,7 +1053,8 @@ def get_response( messages=list(messages), options=options, stream=stream, - kwargs=kwargs, + kwargs={**effective_client_kwargs, **kwargs}, + function_invocation_kwargs=function_invocation_kwargs, ) async def _execute() -> ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None: @@ -1061,7 +1091,8 @@ def _middleware_handler( messages=context.messages, stream=context.stream, options=context.options or {}, - **context.kwargs, + function_invocation_kwargs=context.function_invocation_kwargs, + 
client_kwargs=context.kwargs, ) @@ -1091,6 +1122,8 @@ def run( session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[ResponseModelBoundT], + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ... @@ -1103,6 +1136,8 @@ def run( session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[None] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -1115,6 +1150,8 @@ def run( session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... 
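A subtle point in the middleware layer above: `kwargs.pop("middleware", effective_client_kwargs.pop("middleware", []))` evaluates the fallback `pop` eagerly, so the `client_kwargs` entry is consumed even when the direct kwarg wins. A quick demonstration (the `seed` key is illustrative):

```python
# Direct per-call kwargs vs. values routed through client_kwargs.
kwargs = {"middleware": ["call_level"]}
client_kwargs = {"middleware": ["client_level"], "seed": 7}

# Python evaluates the default argument of the outer pop eagerly, so the
# inner pop always runs: the client_kwargs entry is removed even when the
# direct kwarg takes precedence.
middleware = kwargs.pop("middleware", client_kwargs.pop("middleware", []))

print(middleware)     # → ['call_level']
print(client_kwargs)  # → {'seed': 7}
```

This eager consumption is what prevents a stale ``middleware`` entry from being forwarded to lower layers twice.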
@@ -1126,6 +1163,8 @@ def run( session: AgentSession | None = None, middleware: Sequence[MiddlewareTypes] | None = None, options: ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """MiddlewareTypes-enabled unified run method.""" @@ -1145,12 +1184,23 @@ def run( + run_middleware_list["function"] + run_middleware_list["chat"] ) - combined_kwargs = dict(kwargs) - combined_kwargs["middleware"] = combined_function_chat_middleware if combined_function_chat_middleware else None - + effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + if combined_function_chat_middleware: + effective_client_kwargs["middleware"] = combined_function_chat_middleware + effective_function_invocation_kwargs = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) # Execute with middleware if available if not pipeline.has_middlewares: - return super().run(messages, stream=stream, session=session, options=options, **combined_kwargs) # type: ignore[misc, no-any-return] + return super().run( # type: ignore[misc, no-any-return] + messages, + stream=stream, + session=session, + options=options, + function_invocation_kwargs=effective_function_invocation_kwargs, + client_kwargs=effective_client_kwargs, + **kwargs, + ) context = AgentContext( agent=self, # type: ignore[arg-type] @@ -1158,7 +1208,9 @@ def run( session=session, options=options, stream=stream, - kwargs=combined_kwargs, + kwargs=kwargs, + client_kwargs=effective_client_kwargs, + function_invocation_kwargs=effective_function_invocation_kwargs, ) async def _execute() -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None: @@ -1190,12 +1242,20 @@ async def _execute_stream() -> ResponseStream[AgentResponseUpdate, AgentResponse def _middleware_handler( 
self, context: AgentContext ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + client_kwargs = {**context.client_kwargs, **context.kwargs} + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + function_invocation_kwargs = { + **context.function_invocation_kwargs, + **{k: v for k, v in context.kwargs.items() if k != "middleware"}, + } return super().run( # type: ignore[misc, no-any-return] context.messages, stream=context.stream, session=context.session, options=context.options, - **context.kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 8c3457da26..9eca419df0 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -392,12 +392,16 @@ def __init__( self.store_outputs = store_outputs @abstractmethod - async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: + async def get_messages( + self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any + ) -> list[Message]: """Retrieve stored messages for this session. Args: session_id: The session ID to retrieve messages for. - **kwargs: Additional arguments (e.g., ``state`` for in-memory providers). + state: Optional session state, used by providers that persist messages in the session state. + Not used by all providers. + **kwargs: Additional subclass-specific extensibility arguments. Returns: List of stored messages. @@ -405,13 +409,22 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
@abstractmethod - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None: + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: """Persist messages for this session. Args: session_id: The session ID to store messages for. messages: The messages to persist. - **kwargs: Additional arguments (e.g., ``state`` for in-memory providers). + state: Optional session state, used by providers that persist messages in the session state. + Not used by all providers. + **kwargs: Additional subclass-specific extensibility arguments. """ ... diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 105738e717..505580c207 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -7,6 +7,8 @@ import json import logging import sys +import typing +import warnings from collections.abc import ( AsyncIterable, Awaitable, @@ -37,7 +39,7 @@ from pydantic import BaseModel, Field, ValidationError, create_model from ._serialization import SerializationMixin -from .exceptions import ToolException +from .exceptions import ToolException, UserInputRequiredException from .observability import ( OPERATION_DURATION_BUCKET_BOUNDARIES, OtelAttr, @@ -60,7 +62,8 @@ if TYPE_CHECKING: from ._clients import SupportsChatGetResponse from ._mcp import MCPTool - from ._middleware import FunctionMiddlewarePipeline, FunctionMiddlewareTypes + from ._middleware import FunctionInvocationContext, FunctionMiddlewarePipeline, FunctionMiddlewareTypes + from ._sessions import AgentSession from ._types import ( ChatOptions, ChatResponse, @@ -173,6 +176,16 @@ def _default_histogram() -> Histogram: ) + + +def _annotation_includes_function_invocation_context(annotation: Any) -> bool: + """Check whether an annotation resolves to FunctionInvocationContext.""" + from
._middleware import FunctionInvocationContext + + candidates = get_args(annotation) or (annotation,) + return any( + candidate is FunctionInvocationContext or candidate == "FunctionInvocationContext" for candidate in candidates + ) + + ClassT = TypeVar("ClassT", bound="SerializationMixin") @@ -309,6 +322,12 @@ def __init__( # FunctionTool-specific attributes self.func = func self._instance = None # Store the instance for bound methods + self._context_parameter_name: str | None = None + self._input_model_explicitly_provided = input_model is not None + # TODO(Copilot): Delete once legacy ``**kwargs`` runtime injection is removed. + self._forward_runtime_kwargs: bool = False + if self.func: + self._discover_injected_parameters() # Initialize schema cache (will be lazily populated) self._input_schema_cached: dict[str, Any] | None = None @@ -335,13 +354,37 @@ def __init__( self._invocation_duration_histogram = _default_histogram() self.type: Literal["function_tool"] = "function_tool" self.result_parser = result_parser - self._forward_runtime_kwargs: bool = False - if self.func: - sig = inspect.signature(self.func) - for param in sig.parameters.values(): - if param.kind == inspect.Parameter.VAR_KEYWORD: - self._forward_runtime_kwargs = True - break + + def _discover_injected_parameters(self) -> None: + """Inspect the wrapped function for runtime injection parameters.""" + func = self.func.func if isinstance(self.func, FunctionTool) else self.func + if func is None: + return + + signature = inspect.signature(func) + try: + type_hints = typing.get_type_hints(func) + except Exception: + type_hints = {name: param.annotation for name, param in signature.parameters.items()} + + for name, param in signature.parameters.items(): + if name in {"self", "cls"}: + continue + if param.kind == inspect.Parameter.VAR_KEYWORD: + self._forward_runtime_kwargs = True + continue + + annotation = type_hints.get(name, param.annotation) + if self._is_context_parameter(name, annotation): + if 
self._context_parameter_name is not None: + raise ValueError(f"Function '{self.name}' defines multiple FunctionInvocationContext parameters.") + self._context_parameter_name = name + + def _is_context_parameter(self, name: str, annotation: Any) -> bool: + """Check whether a callable parameter should receive FunctionInvocationContext injection.""" + if _annotation_includes_function_invocation_context(annotation): + return True + return self._input_model_explicitly_provided and name == "ctx" and annotation is inspect.Parameter.empty def __str__(self) -> str: """Return a string representation of the tool.""" @@ -410,6 +453,7 @@ def _resolve_input_model(self, input_model: type[BaseModel] | None) -> type[Base ) for pname, param in sig.parameters.items() if pname not in {"self", "cls"} + and pname != self._context_parameter_name and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} } return create_model(f"{self.name}_input", **fields) @@ -447,6 +491,7 @@ async def invoke( self, *, arguments: BaseModel | Mapping[str, Any] | None = None, + context: FunctionInvocationContext | None = None, **kwargs: Any, ) -> str: """Run the AI function with the provided arguments as a Pydantic model. @@ -457,7 +502,8 @@ async def invoke( Keyword Args: arguments: A mapping or model instance containing the arguments for the function. - kwargs: Keyword arguments to pass to the function, will not be used if ``arguments`` is provided. + context: Explicit function invocation context carrying runtime kwargs. + kwargs: Deprecated keyword arguments to pass to the function. Use ``context`` instead. Returns: The parsed result as a string — either plain text or serialized JSON. 
@@ -468,13 +514,36 @@ async def invoke( if self.declaration_only: raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.") global OBSERVABILITY_SETTINGS + from ._middleware import FunctionInvocationContext from .observability import OBSERVABILITY_SETTINGS parser = self.result_parser or FunctionTool.parse_result - original_kwargs = dict(kwargs) - tool_call_id = original_kwargs.pop("tool_call_id", None) - if arguments is not None: + parameter_names = set(self.parameters().get("properties", {}).keys()) + direct_argument_kwargs = ( + {key: value for key, value in kwargs.items() if key in parameter_names} if arguments is None else {} + ) + runtime_kwargs = dict(context.kwargs) if context is not None else {} + deprecated_runtime_kwargs = { + key: value for key, value in kwargs.items() if key not in direct_argument_kwargs and key != "tool_call_id" + } + if deprecated_runtime_kwargs: + warnings.warn( + "Passing runtime keyword arguments directly to FunctionTool.invoke() is deprecated; " + "pass them via FunctionInvocationContext instead.", + DeprecationWarning, + stacklevel=2, + ) + runtime_kwargs.update(deprecated_runtime_kwargs) + tool_call_id = kwargs.get("tool_call_id", runtime_kwargs.pop("tool_call_id", None)) + if arguments is None and direct_argument_kwargs: + arguments = direct_argument_kwargs + if arguments is None and context is not None: + arguments = context.arguments + + if arguments is None: + validated_arguments: dict[str, Any] = {} + else: try: if isinstance(arguments, Mapping): parsed_arguments = dict(arguments) @@ -496,19 +565,45 @@ async def invoke( ) except ValidationError as exc: raise TypeError(f"Invalid arguments for '{self.name}': {exc}") from exc - kwargs = _validate_arguments_against_schema( + + validated_arguments = _validate_arguments_against_schema( arguments=parsed_arguments, schema=self.parameters(), tool_name=self.name, ) - if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs: - 
kwargs.update(original_kwargs) - else: - kwargs = original_kwargs + + effective_context = context + if effective_context is None and self._context_parameter_name is not None: + effective_context = FunctionInvocationContext( + function=self, + arguments=validated_arguments, + kwargs=runtime_kwargs, + ) + if effective_context is not None: + effective_context.function = self + effective_context.arguments = validated_arguments + effective_context.kwargs = dict(runtime_kwargs) + + call_kwargs = dict(validated_arguments) + observable_kwargs = dict(validated_arguments) + + # Legacy runtime kwargs injection path retained for backwards compatibility with tools + # that still declare ``**kwargs``. New tools should consume runtime data via ``ctx``. + legacy_runtime_kwargs = dict(runtime_kwargs) + if self._forward_runtime_kwargs and legacy_runtime_kwargs: + for key, value in legacy_runtime_kwargs.items(): + if key not in call_kwargs: + call_kwargs[key] = value + if key not in observable_kwargs: + observable_kwargs[key] = value + + if self._context_parameter_name is not None and effective_context is not None: + call_kwargs[self._context_parameter_name] = effective_context + if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined] logger.info(f"Function name: {self.name}") - logger.debug(f"Function arguments: {kwargs}") - res = self.__call__(**kwargs) + logger.debug(f"Function arguments: {observable_kwargs}") + res = self.__call__(**call_kwargs) result = await res if inspect.isawaitable(res) else res try: parsed = parser(result) @@ -523,7 +618,7 @@ async def invoke( # Filter out framework kwargs that are not JSON serializable. 
serializable_kwargs = { k: v - for k, v in kwargs.items() + for k, v in observable_kwargs.items() if k not in { "chat_options", @@ -549,7 +644,7 @@ async def invoke( start_time_stamp = perf_counter() end_time_stamp: float | None = None try: - res = self.__call__(**kwargs) + res = self.__call__(**call_kwargs) result = await res if inspect.isawaitable(res) else res end_time_stamp = perf_counter() except Exception as exception: @@ -1130,9 +1225,10 @@ async def _auto_invoke_function( *, config: FunctionInvocationConfiguration, tool_map: dict[str, FunctionTool], + invocation_session: AgentSession | None = None, sequence_index: int | None = None, request_index: int | None = None, - middleware_pipeline: FunctionMiddlewarePipeline | None = None, # Optional MiddlewarePipeline + middleware_pipeline: FunctionMiddlewarePipeline | None = None, ) -> Content: """Invoke a function call requested by the agent, applying middleware that is defined. @@ -1143,6 +1239,7 @@ async def _auto_invoke_function( Keyword Args: config: The function invocation configuration. tool_map: A mapping of tool names to FunctionTool instances. + invocation_session: The agent session for this invocation, if any. sequence_index: The index of the function call in the sequence. request_index: The index of the request iteration. middleware_pipeline: Optional middleware pipeline to apply during execution. 
@@ -1194,6 +1291,8 @@ async def _auto_invoke_function( for key, value in (custom_args or {}).items() if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"} } + if invocation_session is not None: + runtime_kwargs["session"] = invocation_session try: if not cast(bool, getattr(tool, "_schema_supplied", False)) and tool.input_model is not None: args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True) @@ -1215,19 +1314,31 @@ async def _auto_invoke_function( additional_properties=function_call_content.additional_properties, ) + from ._middleware import FunctionInvocationContext + if middleware_pipeline is None or not middleware_pipeline.has_middlewares: # No middleware - execute directly try: + direct_context = None + if getattr(tool, "_forward_runtime_kwargs", False) or getattr(tool, "_context_parameter_name", None): + direct_context = FunctionInvocationContext( + function=tool, + arguments=args, + session=invocation_session, + kwargs=runtime_kwargs.copy(), + ) function_result = await tool.invoke( arguments=args, + context=direct_context, tool_call_id=function_call_content.call_id, - **runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) return Content.from_function_result( call_id=function_call_content.call_id, # type: ignore[arg-type] result=function_result, additional_properties=function_call_content.additional_properties, ) + except UserInputRequiredException: + raise except Exception as exc: message = "Error: Function failed." 
if config.get("include_detailed_errors", False): @@ -1239,19 +1350,18 @@ async def _auto_invoke_function( additional_properties=function_call_content.additional_properties, ) # Execute through middleware pipeline if available - from ._middleware import FunctionInvocationContext - middleware_context = FunctionInvocationContext( function=tool, arguments=args, + session=invocation_session, kwargs=runtime_kwargs.copy(), ) async def final_function_handler(context_obj: Any) -> Any: return await tool.invoke( arguments=context_obj.arguments, + context=context_obj, tool_call_id=function_call_content.call_id, - **context_obj.kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) from ._middleware import MiddlewareTermination @@ -1274,6 +1384,8 @@ async def final_function_handler(context_obj: Any) -> Any: additional_properties=function_call_content.additional_properties, ) raise + except UserInputRequiredException: + raise except Exception as exc: message = "Error: Function failed." if config.get("include_detailed_errors", False): @@ -1302,7 +1414,8 @@ async def _try_execute_function_calls( function_calls: Sequence[Content], tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]], config: FunctionInvocationConfiguration, - middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports + invocation_session: AgentSession | None = None, + middleware_pipeline: Any = None, ) -> tuple[Sequence[Content], bool]: """Execute multiple function calls concurrently. @@ -1312,6 +1425,7 @@ async def _try_execute_function_calls( function_calls: A sequence of FunctionCallContent to execute. tools: The tools available for execution. config: Configuration for function invocation. + invocation_session: The agent session for this invocation, if any. middleware_pipeline: Optional middleware pipeline to apply during execution. 
    Returns:
@@ -1381,6 +1495,8 @@ async def _try_execute_function_calls(
    # Run all function calls concurrently, handling MiddlewareTermination
    from ._middleware import MiddlewareTermination

+    extra_user_input_contents: list[Content] = []
+
    async def invoke_with_termination_handling(
        function_call: Content,
        seq_idx: int,
@@ -1391,6 +1507,7 @@ async def invoke_with_termination_handling(
                function_call_content=function_call,  # type: ignore[arg-type]
                custom_args=custom_args,
                tool_map=tool_map,
+                invocation_session=invocation_session,
                sequence_index=seq_idx,
                request_index=attempt_idx,
                middleware_pipeline=middleware_pipeline,
@@ -1407,6 +1524,26 @@ async def invoke_with_termination_handling(
                result=exc.result,
            )
            return (result_content, True)
+        except UserInputRequiredException as exc:
+            if exc.contents:
+                propagated: list[Content] = []
+                for item in exc.contents:
+                    if isinstance(item, Content):
+                        item.call_id = function_call.call_id  # type: ignore[attr-defined]
+                        if not item.id:  # type: ignore[attr-defined]
+                            item.id = function_call.call_id  # type: ignore[attr-defined]
+                        propagated.append(item)
+                if propagated:
+                    extra_user_input_contents.extend(propagated[1:])
+                    return (propagated[0], False)
+            return (
+                Content.from_function_result(
+                    call_id=function_call.call_id,  # type: ignore[arg-type]
+                    result="Tool requires user input but no request details were provided.",
+                    exception="UserInputRequiredException",
+                ),
+                False,
+            )

    execution_results = await asyncio.gather(*[
        invoke_with_termination_handling(function_call, seq_idx)
        for seq_idx, function_call in enumerate(function_calls)
@@ -1414,6 +1551,7 @@ async def invoke_with_termination_handling(
    # Unpack results - each is (Content, terminate_flag)
    contents: list[Content] = [result[0] for result in execution_results]
+    contents.extend(extra_user_input_contents)
    # If any function requested termination, terminate the loop
    should_terminate = any(result[1] for result in execution_results)
    return (contents, should_terminate)
@@ -1426,6 +1564,7 @@ async def _execute_function_calls(
    function_calls: list[Content],
    tool_options: dict[str, Any] | None,
    config: FunctionInvocationConfiguration,
+    invocation_session: AgentSession | None = None,
    middleware_pipeline: Any = None,
 ) -> tuple[list[Content], bool, bool]:
    tools = _extract_tools(tool_options)
@@ -1436,6 +1575,7 @@ async def _execute_function_calls(
            attempt_idx=attempt_idx,
            function_calls=function_calls,
            tools=tools,  # type: ignore
+            invocation_session=invocation_session,
            middleware_pipeline=middleware_pipeline,
            config=config,
        )
@@ -1645,7 +1785,10 @@ def _handle_function_call_results(
 ) -> FunctionRequestResult:
    from ._types import Message

-    if any(fccr.type in {"function_approval_request", "function_call"} for fccr in function_call_results):
+    if any(
+        fccr.type in {"function_approval_request", "function_call"} or fccr.user_input_request
+        for fccr in function_call_results
+    ):
        # Only add items that aren't already in the message (e.g. function_approval_request wrappers).
        # Declaration-only function_call items are already present from the LLM response.
        new_items = [fccr for fccr in function_call_results if fccr.type != "function_call"]
@@ -1811,6 +1954,8 @@ def get_response(
        *,
        stream: Literal[False] = ...,
        options: ChatOptions[ResponseModelBoundT],
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...

@@ -1821,6 +1966,8 @@ def get_response(
        *,
        stream: Literal[False] = ...,
        options: OptionsCoT | ChatOptions[None] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]]: ...
@@ -1831,6 +1978,8 @@ def get_response(
        *,
        stream: Literal[True],
        options: OptionsCoT | ChatOptions[Any] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...

@@ -1841,6 +1990,8 @@ def get_response(
        stream: bool = False,
        options: OptionsCoT | ChatOptions[Any] | None = None,
        function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
        from ._middleware import FunctionMiddlewarePipeline
@@ -1851,24 +2002,45 @@ def get_response(
        )
        super_get_response = super().get_response  # type: ignore[misc]

+        if kwargs:
+            warnings.warn(
+                "Passing client-specific keyword arguments directly to get_response() is deprecated; "
+                "pass them via client_kwargs instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
+        effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
+        effective_function_middleware = function_middleware
+        if effective_function_middleware is None:
+            middleware_from_client_kwargs = effective_client_kwargs.pop("function_middleware", None)
+            if middleware_from_client_kwargs is not None:
+                effective_function_middleware = cast(Sequence[Any], middleware_from_client_kwargs)  # ChatMiddleware adds this kwarg
        function_middleware_pipeline = FunctionMiddlewarePipeline(
-            *(self.function_middleware), *(function_middleware or [])
+            *(self.function_middleware), *(effective_function_middleware or [])
        )
        max_errors = self.function_invocation_configuration.get(
            "max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST
        )
-        additional_function_arguments: dict[str, Any] = {}
+        additional_function_arguments = (
+            dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
+        )
        if options and (additional_opts := options.get("additional_function_arguments")):  # type: ignore[attr-defined]
-            additional_function_arguments = additional_opts  # type: ignore
+            additional_function_arguments.update(cast(Mapping[str, Any], additional_opts))
+
+        from ._sessions import AgentSession as _AgentSession
+
+        raw_session = effective_client_kwargs.get("session")
+        invocation_session = raw_session if isinstance(raw_session, _AgentSession) else None
        execute_function_calls = partial(
            _execute_function_calls,
            custom_args=additional_function_arguments,
            config=self.function_invocation_configuration,
+            invocation_session=invocation_session,
            middleware_pipeline=function_middleware_pipeline,
        )
-        filtered_kwargs = {k: v for k, v in kwargs.items() if k != "session"}
+        filtered_kwargs = {k: v for k, v in {**effective_client_kwargs, **kwargs}.items() if k != "session"}
        # Make options mutable so we can update conversation_id during function invocation loop
        mutable_options: dict[str, Any] = dict(options) if options else {}
@@ -1918,7 +2090,7 @@ async def _get_response() -> ChatResponse[Any]:
                    messages=prepped_messages,
                    stream=False,
                    options=mutable_options,
-                    **filtered_kwargs,
+                    client_kwargs=filtered_kwargs,
                ),
            )
@@ -1987,7 +2159,7 @@ async def _get_response() -> ChatResponse[Any]:
                    messages=prepped_messages,
                    stream=False,
                    options=mutable_options,
-                    **filtered_kwargs,
+                    client_kwargs=filtered_kwargs,
                ),
            )
            if fcc_messages:
@@ -2037,7 +2209,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                    messages=prepped_messages,
                    stream=True,
                    options=mutable_options,
-                    **filtered_kwargs,
+                    client_kwargs=filtered_kwargs,
                ),
            )
            await inner_stream
@@ -2129,7 +2301,7 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]:
                    messages=prepped_messages,
                    stream=True,
                    options=mutable_options,
-                    **filtered_kwargs,
+                    client_kwargs=filtered_kwargs,
                ),
            )
            await final_inner_stream
diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py
index b8d5f5c29a..6606a3b2c5 100644
--- a/python/packages/core/agent_framework/_types.py
+++ b/python/packages/core/agent_framework/_types.py
@@ -2631,7 +2631,7 @@ def __init__(
        stream: AsyncIterable[UpdateT] | Awaitable[AsyncIterable[UpdateT]],
        *,
        finalizer: Callable[[Sequence[UpdateT]], FinalT | Awaitable[FinalT]] | None = None,
-        transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None]] | None = None,
+        transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None]] | None = None,
        cleanup_hooks: list[Callable[[], Awaitable[None] | None]] | None = None,
        result_hooks: list[Callable[[FinalT], FinalT | Awaitable[FinalT | None] | None]] | None = None,
    ) -> None:
@@ -2655,7 +2655,7 @@ def __init__(
        self._consumed: bool = False
        self._finalized: bool = False
        self._final_result: FinalT | None = None
-        self._transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None]] = (
+        self._transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None]] = (
            transform_hooks if transform_hooks is not None else []
        )
        self._result_hooks: list[Callable[[FinalT], FinalT | Awaitable[FinalT | None] | None]] = (
@@ -2928,7 +2928,7 @@ async def get_final_response(self) -> FinalT:
    def with_transform_hook(
        self,
-        hook: Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None],
+        hook: Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None],
    ) -> ResponseStream[UpdateT, FinalT]:
        """Register a transform hook executed for each update during iteration."""
        self._transform_hooks.append(hook)
diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py
index b57abd6faf..21c38f6b57 100644
--- a/python/packages/core/agent_framework/azure/_chat_client.py
+++ b/python/packages/core/agent_framework/azure/_chat_client.py
@@ -172,12 +172,12 @@ def __init__(
        credential: AzureCredentialTypes | AzureTokenProvider | None = None,
        default_headers: Mapping[str, str] | None = None,
        async_client: AsyncAzureOpenAI | None = None,
+        additional_properties: dict[str, Any] | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        instruction_role: str | None = None,
        middleware: Sequence[MiddlewareTypes] | None = None,
        function_invocation_configuration: FunctionInvocationConfiguration | None = None,
-        **kwargs: Any,
    ) -> None:
        """Initialize an Azure OpenAI Chat completion client.
@@ -205,13 +205,13 @@ def __init__(
            default_headers: The default headers mapping of string keys to string values for HTTP requests.
            async_client: An existing client to use.
+            additional_properties: Additional properties stored on the client instance.
            env_file_path: Use the environment settings file as a fallback to using env vars.
            env_file_encoding: The encoding of the environment settings file, defaults to 'utf-8'.
            instruction_role: The role to use for 'instruction' messages, for example,
                summarization prompts could use `developer` or `system`.
            middleware: Optional sequence of middleware to apply to requests.
            function_invocation_configuration: Optional configuration for function invocation behavior.
-            kwargs: Other keyword parameters.

        Examples:
            .. code-block:: python
@@ -283,10 +283,10 @@ class MyOptions(AzureOpenAIChatOptions, total=False):
            credential=credential,
            default_headers=default_headers,
            client=async_client,
+            additional_properties=additional_properties,
            instruction_role=instruction_role,
            middleware=middleware,
            function_invocation_configuration=function_invocation_configuration,
-            **kwargs,
        )

    @override
diff --git a/python/packages/core/agent_framework/exceptions.py b/python/packages/core/agent_framework/exceptions.py
index f38aa38590..4f56c34b5c 100644
--- a/python/packages/core/agent_framework/exceptions.py
+++ b/python/packages/core/agent_framework/exceptions.py
@@ -180,6 +180,34 @@ class ToolExecutionException(ToolException):
    pass


+class UserInputRequiredException(ToolException):
+    """Raised when a tool wrapping a sub-agent requires user input to proceed.
+
+    This exception carries the ``user_input_request`` Content items emitted by
+    the sub-agent (e.g., ``oauth_consent_request``, ``function_approval_request``)
+    so the tool invocation layer can propagate them to the parent agent's response
+    instead of swallowing them as a generic tool error.
+
+    Args:
+        contents: The user-input-request Content items from the sub-agent response.
+        message: Human-readable description of why user input is needed.
+    """
+
+    def __init__(
+        self,
+        contents: list[Any],
+        message: str = "Tool requires user input to proceed.",
+    ) -> None:
+        """Create a UserInputRequiredException.
+
+        Args:
+            contents: The user-input-request Content items from the sub-agent response.
+            message: Human-readable description of why user input is needed.
+        """
+        super().__init__(message, log_level=None)
+        self.contents = contents
+
+
 # endregion

 # region Middleware Exceptions
diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py
index a595582b33..71b8702c14 100644
--- a/python/packages/core/agent_framework/observability.py
+++ b/python/packages/core/agent_framework/observability.py
@@ -1153,18 +1153,47 @@ def get_response(
        options: OptionsCoT | ChatOptions[Any] | None = None,
        **kwargs: Any,
    ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
-        """Trace chat responses with OpenTelemetry spans and metrics."""
+        """Trace chat responses with OpenTelemetry spans and metrics.
+
+        Args:
+            messages: The message or messages to send to the model.
+            stream: Whether to stream the response. Defaults to False.
+            options: Chat options as a TypedDict.
+
+        Keyword Args:
+            kwargs: Compatibility keyword arguments from higher client layers. This layer does
+                not consume ``function_invocation_kwargs`` directly; if present, it is ignored
+                because function invocation has already been processed above. If a ``client_kwargs``
+                mapping is present, it is flattened into ordinary keyword arguments for tracing and
+                forwarding so clients that use those values continue to work while clients that
+                ignore extra kwargs remain compatible.
+        """
        from ._types import ChatResponse, ChatResponseUpdate, ResponseStream  # type: ignore[reportUnusedImport]

        global OBSERVABILITY_SETTINGS
        super_get_response = super().get_response  # type: ignore[misc]
+        compatibility_client_kwargs = kwargs.pop("client_kwargs", None)
+        kwargs.pop("function_invocation_kwargs", None)
+        merged_client_kwargs = (
+            dict(cast(Mapping[str, Any], compatibility_client_kwargs))
+            if isinstance(compatibility_client_kwargs, Mapping)
+            else {}
+        )
+        merged_client_kwargs.update(kwargs)
        if not OBSERVABILITY_SETTINGS.ENABLED:
-            return super_get_response(messages=messages, stream=stream, options=options, **kwargs)  # type: ignore[no-any-return]
+            return super_get_response(  # type: ignore[no-any-return]
+                messages=messages,
+                stream=stream,
+                options=options,
+                **merged_client_kwargs,
+            )
        opts: dict[str, Any] = options or {}  # type: ignore[assignment]
        provider_name = str(getattr(self, "otel_provider_name", "unknown"))
-        model_id = kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown"
+        model_id = (
+            merged_client_kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown"
+        )
        service_url_func = getattr(self, "service_url", None)
        service_url = str(service_url_func() if callable(service_url_func) else "unknown")
        attributes = _get_span_attributes(
@@ -1172,13 +1201,18 @@ def get_response(
            provider_name=provider_name,
            model=model_id,
            service_url=service_url,
-            **kwargs,
+            **merged_client_kwargs,
        )
        if stream:
            result_stream = cast(
                ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
-                super_get_response(messages=messages, stream=True, options=opts, **kwargs),
+                super_get_response(
+                    messages=messages,
+                    stream=True,
+                    options=opts,
+                    **merged_client_kwargs,
+                ),
            )
            # Create span directly without trace.use_span() context attachment.
@@ -1266,7 +1300,7 @@ async def _get_response() -> ChatResponse:
                    messages=messages,
                    stream=False,
                    options=opts,
-                    **kwargs,
+                    **merged_client_kwargs,
                ),
            )
        except Exception as exception:
@@ -1393,6 +1427,8 @@ def run(
        *,
        stream: Literal[False] = ...,
        session: AgentSession | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]]: ...

@@ -1403,6 +1439,8 @@ def run(
        *,
        stream: Literal[True],
        session: AgentSession | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

@@ -1412,6 +1450,8 @@ def run(
        *,
        stream: bool = False,
        session: AgentSession | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
        """Trace agent runs with OpenTelemetry spans and metrics."""
@@ -1430,11 +1470,15 @@ def run(
                messages=messages,
                stream=stream,
                session=session,
+                function_invocation_kwargs=function_invocation_kwargs,
+                client_kwargs=client_kwargs,
                **kwargs,
            )

        default_options = getattr(self, "default_options", {})
        options = kwargs.get("options")
+        merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
+        merged_client_kwargs.update(kwargs)
        merged_options: dict[str, Any] = merge_chat_options(default_options, options or {})
        attributes = _get_span_attributes(
            operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
@@ -1444,7 +1488,7 @@ def run(
            agent_description=getattr(self, "description", None),
            thread_id=session.service_session_id if session else None,
            all_options=merged_options,
-            **kwargs,
+            **merged_client_kwargs,
        )

        if stream:
@@ -1452,6 +1496,8 @@ def run(
                messages=messages,
                stream=True,
                session=session,
+                function_invocation_kwargs=function_invocation_kwargs,
+                client_kwargs=client_kwargs,
                **kwargs,
            )
            if isinstance(run_result, ResponseStream):
@@ -1541,6 +1587,8 @@ async def _run() -> AgentResponse:
                    messages=messages,
                    stream=False,
                    session=session,
+                    function_invocation_kwargs=function_invocation_kwargs,
+                    client_kwargs=client_kwargs,
                    **kwargs,
                )
            except Exception as exception:
diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py
index 0562e68f3e..82d47ec054 100644
--- a/python/packages/core/agent_framework/openai/_chat_client.py
+++ b/python/packages/core/agent_framework/openai/_chat_client.py
@@ -15,7 +15,7 @@ )
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Any, Generic, Literal, cast
+from typing import Any, Generic, Literal, cast, overload

 from openai import AsyncOpenAI, BadRequestError
 from openai.lib._parsing._completions import type_to_response_format_param
@@ -30,7 +30,8 @@ from pydantic import BaseModel

 from .._clients import BaseChatClient
-from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
+from .._docstrings import apply_layered_docstring
+from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, FunctionMiddlewareTypes
 from .._settings import load_settings
 from .._tools import (
     FunctionInvocationConfiguration,
@@ -72,6 +73,7 @@ logger = logging.getLogger("agent_framework.openai")

+ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
 ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
@@ -213,6 +215,57 @@ def get_web_search_tool(
     # endregion

+    @overload
+    def get_response(
+        self,
+        messages: Sequence[Message],
+        *,
+        stream: Literal[False] = ...,
+        options: ChatOptions[ResponseModelBoundT],
+        **kwargs: Any,
+    ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
+
+    @overload
+    def get_response(
+        self,
+        messages: Sequence[Message],
+        *,
+        stream: Literal[False] = ...,
+        options: OpenAIChatOptionsT | ChatOptions[None] | None = None,
+        **kwargs: Any,
+    ) -> Awaitable[ChatResponse[Any]]: ...
+
+    @overload
+    def get_response(
+        self,
+        messages: Sequence[Message],
+        *,
+        stream: Literal[True],
+        options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
+        **kwargs: Any,
+    ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
+
+    @override
+    def get_response(
+        self,
+        messages: Sequence[Message],
+        *,
+        stream: bool = False,
+        options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
+        **kwargs: Any,
+    ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
+        """Get a response from the raw OpenAI chat client."""
+        super_get_response = cast(
+            "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]",
+            super().get_response,  # type: ignore[misc]
+        )
+        return super_get_response(  # type: ignore[no-any-return]
+            messages=messages,
+            stream=stream,
+            options=options,
+            **kwargs,
+        )
+
     @override
     def _inner_get_response(
         self,
@@ -716,6 +769,77 @@ class OpenAIChatClient(  # type: ignore[misc]
 ):
     """OpenAI Chat completion class with middleware, telemetry, and function invocation support."""

+    @overload
+    def get_response(
+        self,
+        messages: Sequence[Message],
+        *,
+        stream: Literal[False] = ...,
+        options: ChatOptions[ResponseModelBoundT],
+        function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
+        middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
+        **kwargs: Any,
+    ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
+
+    @overload
+    def get_response(
+        self,
+        messages: Sequence[Message],
+        *,
+        stream: Literal[False] = ...,
+        options: OpenAIChatOptionsT | ChatOptions[None] | None = None,
+        function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
+        middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
+        **kwargs: Any,
+    ) -> Awaitable[ChatResponse[Any]]: ...
+
+    @overload
+    def get_response(
+        self,
+        messages: Sequence[Message],
+        *,
+        stream: Literal[True],
+        options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
+        function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
+        middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
+        **kwargs: Any,
+    ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
+
+    @override
+    def get_response(
+        self,
+        messages: Sequence[Message],
+        *,
+        stream: bool = False,
+        options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
+        function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
+        function_invocation_kwargs: Mapping[str, Any] | None = None,
+        client_kwargs: Mapping[str, Any] | None = None,
+        middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
+        **kwargs: Any,
+    ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
+        """Get a response from the OpenAI chat client with all standard layers enabled."""
+        super_get_response = cast(
+            "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]",
+            super().get_response,  # type: ignore[misc]
+        )
+        return super_get_response(  # type: ignore[no-any-return]
+            messages=messages,
+            stream=stream,
+            options=options,
+            function_middleware=function_middleware,
+            function_invocation_kwargs=function_invocation_kwargs,
+            client_kwargs=client_kwargs,
+            middleware=middleware,
+            **kwargs,
+        )
+
     def __init__(
         self,
         *,
@@ -819,3 +943,25 @@ class MyOptions(OpenAIChatOptions, total=False):
             middleware=middleware,
             function_invocation_configuration=function_invocation_configuration,
         )
+
+
+def _apply_openai_chat_client_docstrings() -> None:
+    """Align OpenAI chat-client docstrings with the raw implementation."""
+    apply_layered_docstring(RawOpenAIChatClient.get_response, BaseChatClient.get_response)
+    apply_layered_docstring(
+        OpenAIChatClient.get_response,
+        RawOpenAIChatClient.get_response,
+        extra_keyword_args={
+            "function_middleware": """
+                Optional per-call function middleware.
+                When omitted, middleware configured on the client or forwarded from higher layers is used.
+            """,
+            "middleware": """
+                Optional per-call chat and function middleware.
+                This is merged with any middleware configured on the client for the current request.
+            """,
+        },
+    )
+
+
+_apply_openai_chat_client_docstrings()
diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py
index a60e924387..352f562417 100644
--- a/python/packages/core/tests/core/test_agents.py
+++ b/python/packages/core/tests/core/test_agents.py
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft. All rights reserved.
 import contextlib
+import inspect
 from collections.abc import AsyncIterable, MutableSequence
 from typing import Any
 from unittest.mock import AsyncMock, MagicMock
@@ -27,6 +28,7 @@ )
 from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name
 from agent_framework._mcp import MCPTool
+from agent_framework._middleware import FunctionInvocationContext


 def test_agent_session_type(agent_session: AgentSession) -> None:
@@ -65,6 +67,30 @@ def test_chat_client_agent_type(client: SupportsChatGetResponse) -> None:
     assert isinstance(chat_client_agent, SupportsAgentRun)


+def test_agent_init_docstring_surfaces_raw_agent_constructor_docs() -> None:
+    docstring = inspect.getdoc(Agent.__init__)
+
+    assert docstring is not None
+    assert "client: The chat client to use for the agent." in docstring
+    assert "middleware: List of middleware to intercept agent and function invocations." in docstring
+
+
+def test_agent_run_docstring_surfaces_raw_agent_runtime_docs() -> None:
+    docstring = inspect.getdoc(Agent.run)
+
+    assert docstring is not None
+    assert "Run the agent with the given messages and options." in docstring
+    assert "function_invocation_kwargs: Keyword arguments forwarded to tool invocation." in docstring
+    assert "middleware: Optional per-run agent, chat, and function middleware." in docstring
+
+
+def test_agent_run_is_defined_on_agent_class() -> None:
+    signature = inspect.signature(Agent.run)
+
+    assert Agent.run.__qualname__ == "Agent.run"
+    assert "middleware" in signature.parameters
+
+
 async def test_chat_client_agent_init(client: SupportsChatGetResponse) -> None:
     agent_id = str(uuid4())
     agent = Agent(client=client, id=agent_id, description="Test")
@@ -85,6 +111,13 @@ async def test_chat_client_agent_init_with_name(
     assert agent.description == "Test"


+def test_agent_init_warns_for_direct_additional_properties(client: SupportsChatGetResponse) -> None:
+    with pytest.warns(DeprecationWarning, match="additional_properties"):
+        agent = Agent(client=client, legacy_key="legacy-value")
+
+    assert agent.additional_properties["legacy_key"] == "legacy-value"
+
+
 async def test_chat_client_agent_run(client: SupportsChatGetResponse) -> None:
     agent = Agent(client=client)
@@ -217,9 +250,36 @@ async def test_prepare_session_does_not_mutate_agent_chat_options(
     assert len(agent.default_options["tools"]) == 1


-async def test_chat_client_agent_run_with_session(
+async def test_prepare_run_context_separates_function_invocation_kwargs_from_chat_options(
     chat_client_base: SupportsChatGetResponse,
 ) -> None:
+    agent = Agent(client=chat_client_base)
+    session = agent.create_session()
+
+    ctx = await agent._prepare_run_context(  # type: ignore[reportPrivateUsage]
+        messages="Hello",
+        session=session,
+        tools=None,
+        options={
+            "temperature": 0.4,
+            "additional_function_arguments": {"from_options": "options-value"},
+        },
+        legacy_kwargs={"legacy_key": "legacy-value"},
+        function_invocation_kwargs={"runtime_key": "runtime-value"},
+        client_kwargs={"client_key": "client-value"},
+    )
+
+    assert ctx["chat_options"]["temperature"] == 0.4
+    assert "additional_function_arguments" not in ctx["chat_options"]
+    assert ctx["function_invocation_kwargs"]["from_options"] == "options-value"
+    assert ctx["function_invocation_kwargs"]["legacy_key"] == "legacy-value"
+    assert ctx["function_invocation_kwargs"]["runtime_key"] == "runtime-value"
+    assert "session" not in ctx["function_invocation_kwargs"]
+    assert ctx["client_kwargs"]["client_key"] == "client-value"
+    assert ctx["client_kwargs"]["session"] is session
+
+
+async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChatGetResponse) -> None:
     mock_response = ChatResponse(
         messages=[Message(role="assistant", contents=[Content.from_text("test response")])],
         conversation_id="123",
@@ -660,8 +720,9 @@ async def test_chat_agent_as_tool_basic(client: SupportsChatGetResponse) -> None
     assert tool.name == "TestAgent"
     assert tool.description == "Test agent for as_tool"
+    assert tool.approval_mode == "never_require"
     assert hasattr(tool, "func")
-    assert hasattr(tool, "input_model")
+    assert tool.input_model is None


 async def test_chat_agent_as_tool_custom_parameters(
@@ -675,13 +736,15 @@ async def test_chat_agent_as_tool_custom_parameters(
         description="Custom description",
         arg_name="query",
         arg_description="Custom input description",
+        approval_mode="always_require",
     )

     assert tool.name == "CustomTool"
     assert tool.description == "Custom description"
+    assert tool.approval_mode == "always_require"

     # Check that the input model has the custom field name
-    schema = tool.input_model.model_json_schema()
+    schema = tool.parameters()
     assert "query" in schema["properties"]
     assert schema["properties"]["query"]["description"] == "Custom input description"
@@ -700,7 +763,7 @@ async def test_chat_agent_as_tool_defaults(client: SupportsChatGetResponse) -> N
     assert tool.description == ""  # Should default to empty string

     # Check default input field
-    schema = tool.input_model.model_json_schema()
+    schema = tool.parameters()
     assert "task" in schema["properties"]
     assert "Task for TestAgent" in schema["properties"]["task"]["description"]
@@ -723,11 +786,11 @@ async def test_chat_agent_as_tool_function_execution(
     tool = agent.as_tool()

     # Test function execution
-    result = await tool.invoke(arguments=tool.input_model(task="Hello"))
+    result = await tool.invoke(arguments={"task": "Hello"})

-    # Should return the agent's response text
+    # as_tool always uses streaming and finalizes the accumulated updates.
     assert isinstance(result, str)
-    assert result == "test response"  # From mock chat client
+    assert result == "test streaming response another update"


 async def test_chat_agent_as_tool_with_stream_callback(
@@ -745,7 +808,7 @@ def stream_callback(update: AgentResponseUpdate) -> None:
     tool = agent.as_tool(stream_callback=stream_callback)

     # Execute the tool
-    result = await tool.invoke(arguments=tool.input_model(task="Hello"))
+    result = await tool.invoke(arguments={"task": "Hello"})

     # Should have collected streaming updates
     assert len(collected_updates) > 0
@@ -764,8 +827,8 @@ async def test_chat_agent_as_tool_with_custom_arg_name(
     tool = agent.as_tool(arg_name="prompt", arg_description="Custom prompt input")

     # Test that the custom argument name works
-    result = await tool.invoke(arguments=tool.input_model(prompt="Test prompt"))
-    assert result == "test response"
+    result = await tool.invoke(arguments={"prompt": "Test prompt"})
+    assert result == "test streaming response another update"


 async def test_chat_agent_as_tool_with_async_stream_callback(
@@ -783,7 +846,7 @@ async def async_stream_callback(update: AgentResponseUpdate) -> None:
     tool = agent.as_tool(stream_callback=async_stream_callback)

     # Execute the tool
-    result = await tool.invoke(arguments=tool.input_model(task="Hello"))
+    result = await tool.invoke(arguments={"task": "Hello"})

     # Should have collected streaming updates
     assert len(collected_updates) > 0
@@ -813,17 +876,14 @@ async def test_chat_agent_as_tool_name_sanitization(
     assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}"


-async def test_chat_agent_as_tool_propagate_session_true(
-    client: SupportsChatGetResponse,
-) -> None:
-    """Test that propagate_session=True forwards the parent's session to the sub-agent."""
+async def test_chat_agent_as_tool_propagate_session_true(client: SupportsChatGetResponse) -> None:
+    """Test that propagate_session=True forwards the session to the sub-agent."""
     agent = Agent(client=client, name="SubAgent", description="Sub agent")
     tool = agent.as_tool(propagate_session=True)

     parent_session = AgentSession(session_id="parent-session-123")
     parent_session.state["shared_key"] = "shared_value"

-    # Spy on the agent's run method to capture the session argument
     original_run = agent.run
     captured_session = None
@@ -834,16 +894,20 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any:

     agent.run = capturing_run  # type: ignore[assignment, method-assign]

-    await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
+    await tool.invoke(
+        context=FunctionInvocationContext(
+            function=tool,
+            arguments={"task": "Hello"},
+            session=parent_session,
+        )
+    )

     assert captured_session is parent_session
     assert captured_session.session_id == "parent-session-123"
     assert captured_session.state["shared_key"] == "shared_value"


-async def test_chat_agent_as_tool_propagate_session_false_by_default(
-    client: SupportsChatGetResponse,
-) -> None:
+async def test_chat_agent_as_tool_propagate_session_false_by_default(client: SupportsChatGetResponse) -> None:
     """Test that propagate_session defaults to False and does not forward the session."""
     agent = Agent(client=client, name="SubAgent", description="Sub agent")
     tool = agent.as_tool()  # default: propagate_session=False
@@ -860,22 +924,25 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any:

     agent.run = capturing_run  # type: ignore[assignment, method-assign]

-    await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
+    await tool.invoke(
+        context=FunctionInvocationContext(
+            function=tool,
+            arguments={"task": "Hello"},
+            session=parent_session,
+        )
+    )

     assert captured_session is None


-async def test_chat_agent_as_tool_propagate_session_shares_state(
-    client: SupportsChatGetResponse,
-) -> None:
-    """Test that shared session allows the sub-agent to read and write parent's state."""
+async def test_chat_agent_as_tool_propagate_session_shares_state(client: SupportsChatGetResponse) -> None:
+    """Test that a propagated session allows the sub-agent to read and write parent state."""
     agent = Agent(client=client, name="SubAgent", description="Sub agent")
     tool = agent.as_tool(propagate_session=True)

     parent_session = AgentSession(session_id="shared-session")
     parent_session.state["counter"] = 0

-    # The sub-agent receives the same session object, so mutations are shared
     original_run = agent.run
     captured_session = None
@@ -888,9 +955,14 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any:

     agent.run = capturing_run  # type: ignore[assignment, method-assign]

-    await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
+    await tool.invoke(
+        context=FunctionInvocationContext(
+            function=tool,
+            arguments={"task": "Hello"},
+            session=parent_session,
+        )
+    )

-    # The parent's state should reflect the sub-agent's mutation
     assert parent_session.state["counter"] == 1
@@ -992,19 +1064,19 @@ async def capturing_inner(
     assert len(tool_names) == 3


-async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None:
-    """Verify tool execution receives 'session' inside **kwargs when function is called by client."""
+async def test_agent_tool_receives_session_via_function_invocation_context(
+    chat_client_base: Any,
+) -> None:
+    """Verify ctx-based tools receive the session via FunctionInvocationContext.session."""
     captured: dict[str, Any] = {}

-    @tool(name="echo_session_info", approval_mode="never_require")
-    def echo_session_info(text: str, **kwargs: Any) -> str:  # type: ignore[reportUnknownParameterType]
-        session = kwargs.get("session")
-        captured["has_session"] = session is not None
-        captured["has_state"] = session.state is not None
if isinstance(session, AgentSession) else False + @tool(name="capture_session_context", approval_mode="never_require") + def capture_session_context(text: str, ctx: FunctionInvocationContext) -> str: + captured["session"] = ctx.session + captured["has_state"] = ctx.session.state is not None if isinstance(ctx.session, AgentSession) else False return f"echo: {text}" - # Make the base client emit a function call for our tool chat_client_base.run_responses = [ ChatResponse( messages=Message( @@ -1012,7 +1084,7 @@ def echo_session_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUn contents=[ Content.from_function_call( call_id="1", - name="echo_session_info", + name="capture_session_context", arguments='{"text": "hello"}', ) ], @@ -1021,18 +1093,14 @@ def echo_session_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUn ChatResponse(messages=Message(role="assistant", text="done")), ] - agent = Agent(client=chat_client_base, tools=[echo_session_info]) + agent = Agent(client=chat_client_base, tools=[capture_session_context]) session = agent.create_session() - result = await agent.run( - "hello", - session=session, - options={"additional_function_arguments": {"session": session}}, - ) + result = await agent.run("hello", session=session) assert result.text == "done" - assert captured.get("has_session") is True - assert captured.get("has_state") is True + assert captured["session"] is session + assert captured["has_state"] is True async def test_chat_agent_tool_choice_run_level_overrides_agent_level(chat_client_base: Any, tool_tool: Any) -> None: @@ -1622,4 +1690,26 @@ async def test_stores_by_default_with_store_false_in_default_options_injects_inm assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers) -# endregion +# region as_tool user_input_request propagation + + +async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetResponse) -> None: + """Test that as_tool raises when the wrapped sub-agent requests 
user input.""" + from agent_framework.exceptions import UserInputRequiredException + + consent_content = Content.from_oauth_consent_request( + consent_link="https://login.microsoftonline.com/consent", + ) + client.streaming_responses = [ # type: ignore[attr-defined] + [ChatResponseUpdate(contents=[consent_content], role="assistant")], + ] + + agent = Agent(client=client, name="OAuthAgent", description="Agent requiring consent") + agent_tool = agent.as_tool() + + with raises(UserInputRequiredException) as exc_info: + await agent_tool.invoke(arguments={"task": "Do something"}) + + assert len(exc_info.value.contents) == 1 + assert exc_info.value.contents[0].type == "oauth_consent_request" + assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent" diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index da8e907c40..8aa71a4582 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -6,7 +6,7 @@ from typing import Any from agent_framework import Agent, ChatResponse, Content, Message, agent_middleware -from agent_framework._middleware import AgentContext +from agent_framework._middleware import AgentContext, FunctionInvocationContext from .conftest import MockChatClient @@ -14,14 +14,28 @@ class TestAsToolKwargsPropagation: """Test cases for kwargs propagation through as_tool() delegation.""" + @staticmethod + def _build_context( + tool: Any, + *, + task: str, + runtime_kwargs: dict[str, Any] | None = None, + ) -> FunctionInvocationContext: + return FunctionInvocationContext( + function=tool, + arguments={"task": task}, + kwargs=runtime_kwargs, + ) + async def test_as_tool_forwards_runtime_kwargs(self, client: MockChatClient) -> None: - """Test that runtime kwargs are forwarded through as_tool() to sub-agent.""" + """Test that runtime kwargs 
are forwarded through as_tool() to sub-agent tools.""" captured_kwargs: dict[str, Any] = {} + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - # Capture kwargs passed to the sub-agent captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -39,29 +53,31 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Create tool from sub-agent tool = sub_agent.as_tool(name="delegate", arg_name="task") - # Directly invoke the tool with kwargs (simulating what happens during agent execution) + # Directly invoke the tool with explicit runtime context (simulating agent execution). _ = await tool.invoke( - arguments=tool.input_model(task="Test delegation"), - api_token="secret-xyz-123", - user_id="user-456", - session_id="session-789", + context=self._build_context( + tool, + task="Test delegation", + runtime_kwargs={ + "api_token": "secret-xyz-123", + "user_id": "user-456", + "session_id": "session-789", + }, + ), ) - # Verify kwargs were forwarded to sub-agent - assert "api_token" in captured_kwargs, f"Expected 'api_token' in {captured_kwargs}" - assert captured_kwargs["api_token"] == "secret-xyz-123" - assert "user_id" in captured_kwargs - assert captured_kwargs["user_id"] == "user-456" - assert "session_id" in captured_kwargs - assert captured_kwargs["session_id"] == "session-789" + assert captured_kwargs == {} + assert captured_function_invocation_kwargs["api_token"] == "secret-xyz-123" + assert captured_function_invocation_kwargs["user_id"] == "user-456" + assert captured_function_invocation_kwargs["session_id"] == "session-789" - async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, client: MockChatClient) -> None: - """Test that the arg_name parameter is not forwarded as a kwarg.""" - 
captured_kwargs: dict[str, Any] = {} + async def test_as_tool_forwards_context_kwargs_verbatim(self, client: MockChatClient) -> None: + """Test that runtime kwargs are forwarded exactly from FunctionInvocationContext.kwargs.""" + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -79,25 +95,26 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke tool with both the arg_name field and additional kwargs await tool.invoke( - arguments=tool.input_model(custom_task="Test task"), - api_token="token-123", - custom_task="should_be_excluded", # This should be filtered out + context=FunctionInvocationContext( + function=tool, + arguments={"custom_task": "Test task"}, + kwargs={ + "api_token": "token-123", + "custom_task": "should_be_excluded", + }, + ) ) - # The arg_name ("custom_task") should NOT be in the forwarded kwargs - assert "custom_task" not in captured_kwargs - # But other kwargs should be present - assert "api_token" in captured_kwargs - assert captured_kwargs["api_token"] == "token-123" + assert captured_function_invocation_kwargs["custom_task"] == "should_be_excluded" + assert captured_function_invocation_kwargs["api_token"] == "token-123" async def test_as_tool_nested_delegation_propagates_kwargs(self, client: MockChatClient) -> None: - """Test that kwargs propagate through multiple levels of delegation (A → B → C).""" - captured_kwargs_list: list[dict[str, Any]] = [] + """Test that runtime kwargs propagate through multiple levels of delegation (A -> B -> C).""" + captured_function_invocation_kwargs_list: list[dict[str, Any]] = [] @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], 
Awaitable[None]]) -> None: - # Capture kwargs at each level - captured_kwargs_list.append(dict(context.kwargs)) + captured_function_invocation_kwargs_list.append(dict(context.function_invocation_kwargs)) await call_next() # Setup mock responses to trigger nested tool invocation: B calls tool C, then completes. @@ -140,24 +157,29 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke tool B with kwargs - should propagate to both B and C await tool_b.invoke( - arguments=tool_b.input_model(task="Test cascade"), - trace_id="trace-abc-123", - tenant_id="tenant-xyz", - options={"additional_function_arguments": {"trace_id": "trace-abc-123", "tenant_id": "tenant-xyz"}}, + context=self._build_context( + tool_b, + task="Test cascade", + runtime_kwargs={ + "trace_id": "trace-abc-123", + "tenant_id": "tenant-xyz", + }, + ), ) - # Verify kwargs were forwarded to the first agent invocation. - assert len(captured_kwargs_list) >= 1 - assert captured_kwargs_list[0].get("trace_id") == "trace-abc-123" - assert captured_kwargs_list[0].get("tenant_id") == "tenant-xyz" + assert len(captured_function_invocation_kwargs_list) >= 1 + assert captured_function_invocation_kwargs_list[0].get("trace_id") == "trace-abc-123" + assert captured_function_invocation_kwargs_list[0].get("tenant_id") == "tenant-xyz" async def test_as_tool_streaming_mode_forwards_kwargs(self, client: MockChatClient) -> None: - """Test that kwargs are forwarded in streaming mode.""" + """Test that runtime kwargs are forwarded in streaming mode.""" captured_kwargs: dict[str, Any] = {} + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock streaming responses @@ -182,13 +204,15 @@ async def stream_callback(update: 
Any) -> None: # Invoke tool with kwargs while streaming callback is active await tool.invoke( - arguments=tool.input_model(task="Test streaming"), - api_key="streaming-key-999", + context=self._build_context( + tool, + task="Test streaming", + runtime_kwargs={"api_key": "streaming-key-999"}, + ), ) - # Verify kwargs were forwarded even in streaming mode - assert "api_key" in captured_kwargs - assert captured_kwargs["api_key"] == "streaming-key-999" + assert captured_kwargs == {} + assert captured_function_invocation_kwargs["api_key"] == "streaming-key-999" assert len(captured_updates) == 1 async def test_as_tool_empty_kwargs_still_works(self, client: MockChatClient) -> None: @@ -206,18 +230,20 @@ async def test_as_tool_empty_kwargs_still_works(self, client: MockChatClient) -> tool = sub_agent.as_tool() # Invoke without any extra kwargs - should work without errors - result = await tool.invoke(arguments=tool.input_model(task="Simple task")) + result = await tool.invoke(arguments={"task": "Simple task"}) # Verify tool executed successfully assert result is not None async def test_as_tool_kwargs_with_chat_options(self, client: MockChatClient) -> None: - """Test that kwargs including chat_options are properly forwarded.""" + """Test that runtime kwargs are forwarded only via function_invocation_kwargs.""" captured_kwargs: dict[str, Any] = {} + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -235,24 +261,26 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke with various kwargs await tool.invoke( - arguments=tool.input_model(task="Test with options"), - temperature=0.8, - max_tokens=500, - custom_param="custom_value", + 
context=self._build_context( + tool, + task="Test with options", + runtime_kwargs={ + "temperature": 0.8, + "max_tokens": 500, + "custom_param": "custom_value", + }, + ), ) - # Verify all kwargs were forwarded - assert "temperature" in captured_kwargs - assert captured_kwargs["temperature"] == 0.8 - assert "max_tokens" in captured_kwargs - assert captured_kwargs["max_tokens"] == 500 - assert "custom_param" in captured_kwargs - assert captured_kwargs["custom_param"] == "custom_value" + assert captured_kwargs == {} + assert captured_function_invocation_kwargs["temperature"] == 0.8 + assert captured_function_invocation_kwargs["max_tokens"] == 500 + assert captured_function_invocation_kwargs["custom_param"] == "custom_value" async def test_as_tool_kwargs_isolated_per_invocation(self, client: MockChatClient) -> None: - """Test that kwargs are isolated per invocation and don't leak between calls.""" - first_call_kwargs: dict[str, Any] = {} - second_call_kwargs: dict[str, Any] = {} + """Test that runtime kwargs are isolated per invocation and don't leak between calls.""" + first_call_function_invocation_kwargs: dict[str, Any] = {} + second_call_function_invocation_kwargs: dict[str, Any] = {} call_count = 0 @agent_middleware @@ -260,9 +288,9 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai nonlocal call_count call_count += 1 if call_count == 1: - first_call_kwargs.update(context.kwargs) + first_call_function_invocation_kwargs.update(context.function_invocation_kwargs) elif call_count == 2: - second_call_kwargs.update(context.kwargs) + second_call_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock responses for both calls @@ -281,33 +309,35 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # First call with specific kwargs await tool.invoke( - arguments=tool.input_model(task="First task"), - session_id="session-1", - api_token="token-1", + 
context=self._build_context( + tool, + task="First task", + runtime_kwargs={"session_id": "session-1", "api_token": "token-1"}, + ), ) # Second call with different kwargs await tool.invoke( - arguments=tool.input_model(task="Second task"), - session_id="session-2", - api_token="token-2", + context=self._build_context( + tool, + task="Second task", + runtime_kwargs={"session_id": "session-2", "api_token": "token-2"}, + ), ) - # Verify first call had its own kwargs - assert first_call_kwargs.get("session_id") == "session-1" - assert first_call_kwargs.get("api_token") == "token-1" + assert first_call_function_invocation_kwargs.get("session_id") == "session-1" + assert first_call_function_invocation_kwargs.get("api_token") == "token-1" - # Verify second call had its own kwargs (not leaked from first) - assert second_call_kwargs.get("session_id") == "session-2" - assert second_call_kwargs.get("api_token") == "token-2" + assert second_call_function_invocation_kwargs.get("session_id") == "session-2" + assert second_call_function_invocation_kwargs.get("api_token") == "token-2" - async def test_as_tool_excludes_conversation_id_from_forwarded_kwargs(self, client: MockChatClient) -> None: - """Test that conversation_id is not forwarded to sub-agent.""" - captured_kwargs: dict[str, Any] = {} + async def test_as_tool_forwards_conversation_id_from_context_kwargs(self, client: MockChatClient) -> None: + """Test that conversation_id is forwarded when explicitly present in runtime context kwargs.""" + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -325,17 +355,17 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke tool with conversation_id in kwargs 
(simulating parent's conversation state) await tool.invoke( - arguments=tool.input_model(task="Test delegation"), - conversation_id="conv-parent-456", - api_token="secret-xyz-123", - user_id="user-456", - ) - - # Verify conversation_id was NOT forwarded to sub-agent - assert "conversation_id" not in captured_kwargs, ( - f"conversation_id should not be forwarded, but got: {captured_kwargs}" + context=self._build_context( + tool, + task="Test delegation", + runtime_kwargs={ + "conversation_id": "conv-parent-456", + "api_token": "secret-xyz-123", + "user_id": "user-456", + }, + ), ) - # Verify other kwargs were still forwarded - assert captured_kwargs.get("api_token") == "secret-xyz-123" - assert captured_kwargs.get("user_id") == "user-456" + assert captured_function_invocation_kwargs.get("conversation_id") == "conv-parent-456" + assert captured_function_invocation_kwargs.get("api_token") == "secret-xyz-123" + assert captured_function_invocation_kwargs.get("user_id") == "user-456" diff --git a/python/packages/core/tests/core/test_clients.py b/python/packages/core/tests/core/test_clients.py index a23b1d2a5f..670ff5f455 100644 --- a/python/packages/core/tests/core/test_clients.py +++ b/python/packages/core/tests/core/test_clients.py @@ -1,8 +1,11 @@ # Copyright (c) Microsoft. All rights reserved. 
+import inspect from unittest.mock import patch +import pytest + from agent_framework import ( BaseChatClient, ChatResponse, @@ -37,6 +40,60 @@ def test_base_client(chat_client_base: SupportsChatGetResponse): assert isinstance(chat_client_base, SupportsChatGetResponse) +def test_base_client_warns_for_direct_additional_properties(chat_client_base: SupportsChatGetResponse) -> None: + with pytest.warns(DeprecationWarning, match="additional_properties"): + client = type(chat_client_base)(legacy_key="legacy-value") + + assert client.additional_properties["legacy_key"] == "legacy-value" + + +def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_base: SupportsChatGetResponse) -> None: + agent = chat_client_base.as_agent(additional_properties={"team": "core"}) + + assert agent.additional_properties == {"team": "core"} + + +def test_openai_chat_client_get_response_docstring_surfaces_layered_runtime_docs() -> None: + from agent_framework.openai import OpenAIChatClient + + docstring = inspect.getdoc(OpenAIChatClient.get_response) + + assert docstring is not None + assert "Get a response from a chat client." in docstring + assert "function_invocation_kwargs" in docstring + assert "function_middleware: Optional per-call function middleware." in docstring + assert "middleware: Optional per-call chat and function middleware." 
in docstring + + +def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None: + from agent_framework.openai import OpenAIChatClient + + signature = inspect.signature(OpenAIChatClient.get_response) + + assert OpenAIChatClient.get_response.__qualname__ == "OpenAIChatClient.get_response" + assert "function_middleware" in signature.parameters + assert "middleware" in signature.parameters + + +async def test_base_client_get_response_uses_explicit_client_kwargs(chat_client_base: SupportsChatGetResponse) -> None: + async def fake_inner_get_response(**kwargs): + assert kwargs["trace_id"] == "trace-123" + assert "function_invocation_kwargs" not in kwargs + return ChatResponse(messages=[Message(role="assistant", text="ok")]) + + with patch.object( + chat_client_base, + "_inner_get_response", + side_effect=fake_inner_get_response, + ) as mock_inner_get_response: + await chat_client_base.get_response( + [Message(role="user", text="hello")], + function_invocation_kwargs={"tool_request_id": "tool-123"}, + client_kwargs={"trace_id": "trace-123"}, + ) + mock_inner_get_response.assert_called_once() + + async def test_base_client_get_response(chat_client_base: SupportsChatGetResponse): response = await chat_client_base.get_response([Message(role="user", text="Hello")]) assert response.messages[0].role == "assistant" diff --git a/python/packages/core/tests/core/test_docstrings.py b/python/packages/core/tests/core/test_docstrings.py new file mode 100644 index 0000000000..ab4b116422 --- /dev/null +++ b/python/packages/core/tests/core/test_docstrings.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft. All rights reserved. + +from agent_framework._docstrings import apply_layered_docstring, build_layered_docstring + +# -- Helpers: stub functions with various docstring shapes -- + + +def _source_with_full_docstring(x: int) -> int: + """Do something useful. + + Args: + x: The input value. + + Keyword Args: + timeout: Max seconds to wait. + + Returns: + The computed result. 
+ """ + return x + + +def _source_with_args_only(x: int) -> int: + """Do something useful. + + Args: + x: The input value. + + Returns: + The computed result. + """ + return x + + +def _source_no_sections() -> None: + """A plain summary with no Google-style sections.""" + + +def _source_no_docstring() -> None: + pass + + +def _target_stub() -> None: + pass + + +# -- build_layered_docstring tests -- + + +def test_build_returns_none_when_source_has_no_docstring() -> None: + result = build_layered_docstring(_source_no_docstring) + assert result is None + + +def test_build_returns_original_when_no_extra_kwargs() -> None: + result = build_layered_docstring(_source_with_full_docstring) + assert result is not None + assert "Do something useful." in result + assert "Keyword Args:" in result + + +def test_build_returns_original_when_extra_kwargs_empty() -> None: + result = build_layered_docstring(_source_with_full_docstring, extra_keyword_args={}) + assert result is not None + assert result == build_layered_docstring(_source_with_full_docstring) + + +def test_build_appends_to_existing_keyword_args_section() -> None: + result = build_layered_docstring( + _source_with_full_docstring, + extra_keyword_args={"retries": "Number of retries."}, + ) + assert result is not None + assert "timeout: Max seconds to wait." in result + assert "retries: Number of retries." 
in result + # Both should be under Keyword Args + lines = result.splitlines() + kw_index = next(i for i, line in enumerate(lines) if line == "Keyword Args:") + ret_index = next(i for i, line in enumerate(lines) if line == "Returns:") + retries_index = next(i for i, line in enumerate(lines) if "retries:" in line) + assert kw_index < retries_index < ret_index + + +def test_build_inserts_keyword_args_after_args_section() -> None: + result = build_layered_docstring( + _source_with_args_only, + extra_keyword_args={"verbose": "Enable verbose output."}, + ) + assert result is not None + assert "Keyword Args:" in result + assert "verbose: Enable verbose output." in result + lines = result.splitlines() + args_index = next(i for i, line in enumerate(lines) if line == "Args:") + kw_index = next(i for i, line in enumerate(lines) if line == "Keyword Args:") + ret_index = next(i for i, line in enumerate(lines) if line == "Returns:") + assert args_index < kw_index < ret_index + + +def test_build_inserts_keyword_args_in_docstring_with_no_sections() -> None: + result = build_layered_docstring( + _source_no_sections, + extra_keyword_args={"debug": "Enable debug mode."}, + ) + assert result is not None + assert "A plain summary" in result + assert "Keyword Args:" in result + assert "debug: Enable debug mode." in result + + +def test_build_handles_multiline_descriptions() -> None: + result = build_layered_docstring( + _source_with_args_only, + extra_keyword_args={ + "config": "The configuration object.\nMust be a valid mapping.\nDefaults to empty.", + }, + ) + assert result is not None + lines = result.splitlines() + config_line = next(line for line in lines if "config:" in line) + assert "The configuration object." in config_line + # Continuation lines should be indented + config_idx = lines.index(config_line) + assert "Must be a valid mapping." in lines[config_idx + 1] + assert "Defaults to empty." 
in lines[config_idx + 2] + + +def test_build_preserves_multiple_extra_kwargs_order() -> None: + result = build_layered_docstring( + _source_with_args_only, + extra_keyword_args={ + "alpha": "First.", + "beta": "Second.", + "gamma": "Third.", + }, + ) + assert result is not None + lines = result.splitlines() + alpha_idx = next(i for i, line in enumerate(lines) if "alpha:" in line) + beta_idx = next(i for i, line in enumerate(lines) if "beta:" in line) + gamma_idx = next(i for i, line in enumerate(lines) if "gamma:" in line) + assert alpha_idx < beta_idx < gamma_idx + + +# -- apply_layered_docstring tests -- + + +def test_apply_sets_docstring_on_target() -> None: + def target() -> None: + pass + + apply_layered_docstring(target, _source_with_full_docstring) + assert target.__doc__ is not None + assert "Do something useful." in target.__doc__ + + +def test_apply_with_extra_kwargs() -> None: + def target() -> None: + pass + + apply_layered_docstring( + target, + _source_with_args_only, + extra_keyword_args={"flag": "A boolean flag."}, + ) + assert target.__doc__ is not None + assert "flag: A boolean flag." 
in target.__doc__ + assert "Keyword Args:" in target.__doc__ + + +def test_apply_sets_none_when_source_has_no_docstring() -> None: + def target() -> None: + """Original.""" + + apply_layered_docstring(target, _source_no_docstring) + assert target.__doc__ is None diff --git a/python/packages/core/tests/core/test_embedding_client.py b/python/packages/core/tests/core/test_embedding_client.py index 71d2bcfd70..1c49c1d012 100644 --- a/python/packages/core/tests/core/test_embedding_client.py +++ b/python/packages/core/tests/core/test_embedding_client.py @@ -4,6 +4,8 @@ from collections.abc import Sequence +import pytest + from agent_framework import ( BaseEmbeddingClient, Embedding, @@ -63,6 +65,11 @@ def test_base_additional_properties_custom() -> None: assert client.additional_properties == {"key": "value"} +def test_base_embedding_client_rejects_unknown_kwargs() -> None: + with pytest.raises(TypeError): + MockEmbeddingClient(legacy_key="value") # type: ignore[call-arg] + + # --- SupportsGetEmbeddings protocol tests --- diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 7f0eda62fc..9eed4b9299 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -3512,3 +3512,131 @@ def test_dict_overwrites_existing_conversation_id(self): # endregion +async def test_user_input_request_propagates_through_as_tool(chat_client_base: SupportsChatGetResponse): + """Test that user_input_request content from a sub-agent wrapped as a tool propagates to the parent response.""" + from agent_framework.exceptions import UserInputRequiredException + + @tool(name="delegate_agent", approval_mode="never_require") + def delegate_tool(task: str) -> str: + del task + raise UserInputRequiredException( + contents=[ + Content.from_oauth_consent_request( + 
consent_link="https://login.microsoftonline.com/consent", + ) + ] + ) + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="delegate_agent", arguments='{"task": "do it"}'), + ], + ) + ) + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="delegate this")], + options={"tool_choice": "auto", "tools": [delegate_tool]}, + ) + + user_requests = [ + content + for msg in response.messages + for content in msg.contents + if isinstance(content, Content) and content.user_input_request + ] + assert len(user_requests) == 1 + assert user_requests[0].type == "oauth_consent_request" + assert user_requests[0].consent_link == "https://login.microsoftonline.com/consent" + assert user_requests[0].user_input_request is True + + +async def test_user_input_request_multiple_contents_propagate(chat_client_base: SupportsChatGetResponse): + """Test that multiple user_input_request items in a single exception all propagate to the parent response.""" + from agent_framework.exceptions import UserInputRequiredException + + @tool(name="multi_request_tool", approval_mode="never_require") + def multi_request(task: str) -> str: + del task + raise UserInputRequiredException( + contents=[ + Content.from_oauth_consent_request( + consent_link="https://example.com/consent1", + ), + Content.from_oauth_consent_request( + consent_link="https://example.com/consent2", + ), + Content.from_oauth_consent_request( + consent_link="https://example.com/consent3", + ), + ] + ) + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="multi_request_tool", arguments='{"task": "do it"}'), + ], + ) + ) + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="do something")], + options={"tool_choice": "auto", "tools": [multi_request]}, + ) + + user_requests = [ + 
content + for msg in response.messages + for content in msg.contents + if isinstance(content, Content) and content.user_input_request + ] + assert len(user_requests) == 3 + consent_links = {r.consent_link for r in user_requests} + assert consent_links == { + "https://example.com/consent1", + "https://example.com/consent2", + "https://example.com/consent3", + } + + +async def test_user_input_request_empty_contents_returns_fallback(chat_client_base: SupportsChatGetResponse): + """Test that UserInputRequiredException with empty contents produces a fallback function_result.""" + from agent_framework.exceptions import UserInputRequiredException + + @tool(name="empty_request_tool", approval_mode="never_require") + def empty_request(task: str) -> str: + del task + raise UserInputRequiredException(contents=[]) + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="empty_request_tool", arguments='{"task": "do it"}'), + ], + ) + ), + ChatResponse(messages=Message(role="assistant", text="handled")), + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="do something")], + options={"tool_choice": "auto", "tools": [empty_request]}, + ) + + # With empty contents, the handler returns a function_result with an error message + # and the loop continues to the next chat response. 
+ function_results = [ + content for msg in response.messages for content in msg.contents if content.type == "function_result" + ] + assert len(function_results) >= 1 + assert any("user input" in (fr.result or "").lower() for fr in function_results) diff --git a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py index cecd466d86..160ea0fcc4 100644 --- a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py +++ b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py @@ -6,11 +6,13 @@ from typing import Any from agent_framework import ( + Agent, BaseChatClient, ChatMiddlewareLayer, ChatResponse, ChatResponseUpdate, Content, + FunctionInvocationContext, FunctionInvocationLayer, Message, ResponseStream, @@ -97,6 +99,7 @@ class TestKwargsPropagationToFunctionTool: async def test_kwargs_propagate_to_tool_with_kwargs(self) -> None: """Test that kwargs passed to get_response() are available in @tool **kwargs.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. captured_kwargs: dict[str, Any] = {} @tool(approval_mode="never_require") @@ -149,6 +152,7 @@ def capture_kwargs_tool(x: int, **kwargs: Any) -> str: async def test_kwargs_not_forwarded_to_tool_without_kwargs(self) -> None: """Test that kwargs are NOT forwarded to @tool that doesn't accept **kwargs.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. @tool(approval_mode="never_require") def simple_tool(x: int) -> str: @@ -185,6 +189,7 @@ def simple_tool(x: int) -> str: async def test_kwargs_isolated_between_function_calls(self) -> None: """Test that kwargs are consistent across multiple function call invocations.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. 
invocation_kwargs: list[dict[str, Any]] = [] @tool(approval_mode="never_require") @@ -235,6 +240,7 @@ def tracking_tool(name: str, **kwargs: Any) -> str: async def test_streaming_response_kwargs_propagation(self) -> None: """Test that kwargs propagate to @tool in streaming mode.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. captured_kwargs: dict[str, Any] = {} @tool(approval_mode="never_require") @@ -287,3 +293,59 @@ def streaming_capture_tool(value: str, **kwargs: Any) -> str: assert "streaming_session" in captured_kwargs, f"Expected 'streaming_session' in {captured_kwargs}" assert captured_kwargs["streaming_session"] == "session-xyz" assert captured_kwargs["correlation_id"] == "corr-123" + + async def test_agent_run_injects_function_invocation_context(self) -> None: + """Test that Agent.run injects FunctionInvocationContext for ctx-based tools.""" + captured_context_kwargs: dict[str, Any] = {} + captured_client_kwargs: dict[str, Any] = {} + captured_options: dict[str, Any] = {} + + @tool(approval_mode="never_require") + def capture_context_tool(x: int, ctx: FunctionInvocationContext) -> str: + captured_context_kwargs.update(ctx.kwargs) + return f"result: x={x}" + + class CapturingFunctionInvokingMockClient(FunctionInvokingMockClient): + async def _get_non_streaming_response( + self, + *, + messages: MutableSequence[Message], + options: dict[str, Any], + **kwargs: Any, + ) -> ChatResponse: + captured_options.update(options) + captured_client_kwargs.update(kwargs) + return await super()._get_non_streaming_response(messages=messages, options=options, **kwargs) + + client = CapturingFunctionInvokingMockClient() + client.run_responses = [ + ChatResponse( + messages=[ + Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_1", + name="capture_context_tool", + arguments='{"x": 42}', + ) + ], + ) + ] + ), + ChatResponse(messages=[Message(role="assistant", text="Done!")]), + ] + + 
agent = Agent(client=client, tools=[capture_context_tool]) + result = await agent.run( + [Message(role="user", text="Test")], + function_invocation_kwargs={"tool_request_id": "tool-123"}, + client_kwargs={"client_request_id": "client-456"}, + ) + + assert captured_context_kwargs["tool_request_id"] == "tool-123" + assert "client_request_id" not in captured_context_kwargs + assert captured_client_kwargs["client_request_id"] == "client-456" + assert "tool_request_id" not in captured_client_kwargs + assert "additional_function_arguments" not in captured_options + assert result.messages[-1].text == "Done!" diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index 4d2e603274..bd2cb8155e 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -192,10 +192,10 @@ def __init__(self, source_id: str, stored_messages: list[Message] | None = None, self.stored: list[Message] = [] self._stored_messages = stored_messages or [] - async def get_messages(self, session_id: str | None, **kwargs) -> list[Message]: + async def get_messages(self, session_id: str | None, *, state=None, **kwargs) -> list[Message]: return list(self._stored_messages) - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs) -> None: + async def save_messages(self, session_id: str | None, messages: Sequence[Message], *, state=None, **kwargs) -> None: self.stored.extend(messages) diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index f7674edc9b..fcab14ee1b 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -12,6 +12,7 @@ FunctionTool, tool, ) +from agent_framework._middleware import FunctionInvocationContext from agent_framework._tools import ( _parse_annotation, _parse_inputs, @@ -941,6 +942,128 @@ def tool_with_kwargs(x: int, 
**kwargs: Any) -> str: assert result_default == "x=10, user=unknown" +async def test_ai_function_with_explicit_invocation_context(): + """Test that invoke() can receive runtime kwargs via FunctionInvocationContext.""" + + @tool + def tool_with_context(x: int, ctx: FunctionInvocationContext) -> str: + """A tool that accepts runtime context injection.""" + user_id = ctx.kwargs.get("user_id", "unknown") + return f"x={x}, user={user_id}" + + assert tool_with_context.parameters() == { + "properties": {"x": {"title": "X", "type": "integer"}}, + "required": ["x"], + "title": "tool_with_context_input", + "type": "object", + } + + context = FunctionInvocationContext( + function=tool_with_context, + arguments=tool_with_context.input_model(x=7), + kwargs={"user_id": "ctx-user"}, + ) + + result = await tool_with_context.invoke(context=context) + + assert result == "x=7, user=ctx-user" + + +async def test_ai_function_with_typed_context_parameter_using_custom_name(): + """Test that typed context injection works for names other than ctx.""" + + @tool + def tool_with_runtime_context(x: int, runtime: FunctionInvocationContext) -> str: + """A tool that uses a custom context parameter name.""" + user_id = runtime.kwargs.get("user_id", "unknown") + return f"x={x}, user={user_id}" + + assert tool_with_runtime_context.parameters() == { + "properties": {"x": {"title": "X", "type": "integer"}}, + "required": ["x"], + "title": "tool_with_runtime_context_input", + "type": "object", + } + + context = FunctionInvocationContext( + function=tool_with_runtime_context, + arguments=tool_with_runtime_context.input_model(x=8), + kwargs={"user_id": "runtime-user"}, + ) + + result = await tool_with_runtime_context.invoke(context=context) + + assert result == "x=8, user=runtime-user" + + +async def test_ai_function_with_explicit_schema_and_untyped_ctx(): + """Test that explicit schemas allow an untyped ctx parameter.""" + + class ToolInput(BaseModel): + x: int + + @tool(schema=ToolInput) + def 
tool_with_schema(x, ctx) -> str: + """A tool with explicit schema and implicit ctx injection.""" + return f"x={x}, user={ctx.kwargs.get('user_id', 'unknown')}" + + context = FunctionInvocationContext( + function=tool_with_schema, + arguments=ToolInput(x=9), + kwargs={"user_id": "schema-user"}, + ) + + result = await tool_with_schema.invoke(context=context) + + assert result == "x=9, user=schema-user" + + +async def test_ai_function_with_explicit_schema_and_typed_ctx(): + """Test that explicit schemas also work with typed context injection.""" + + class ToolInput(BaseModel): + x: int + + @tool(schema=ToolInput) + def tool_with_schema(x: int, runtime: FunctionInvocationContext) -> str: + """A tool with explicit schema and typed context injection.""" + return f"x={x}, user={runtime.kwargs.get('user_id', 'unknown')}" + + context = FunctionInvocationContext( + function=tool_with_schema, + arguments=ToolInput(x=11), + kwargs={"user_id": "typed-schema-user"}, + ) + + result = await tool_with_schema.invoke(context=context) + + assert tool_with_schema.parameters() == ToolInput.model_json_schema() + assert result == "x=11, user=typed-schema-user" + + +def test_ai_function_with_multiple_typed_context_parameters_fails(): + """Test that tools reject multiple typed FunctionInvocationContext parameters.""" + + with pytest.raises(ValueError, match="multiple FunctionInvocationContext parameters"): + + @tool + def invalid_tool(ctx_one: FunctionInvocationContext, ctx_two: FunctionInvocationContext) -> str: + return f"{ctx_one.kwargs}-{ctx_two.kwargs}" + + +def test_ai_function_with_ctx_and_typed_context_parameter_fails(): + """Test that explicit-schema tools reject both implicit ctx and typed context parameters.""" + + class ToolInput(BaseModel): + x: int + + with pytest.raises(ValueError, match="multiple FunctionInvocationContext parameters"): + + @tool(schema=ToolInput) + def invalid_tool(x, ctx, runtime: FunctionInvocationContext) -> str: + return 
f"{x}-{ctx.kwargs}-{runtime.kwargs}" + + # region _parse_annotation tests diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index 0a7cf50b0b..713c1b4e69 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -124,10 +124,20 @@ def run_durable_agent( """ raise NotImplementedError - def get_new_session(self, agent_name: str, **kwargs: Any) -> DurableAgentSession: + def get_new_session( + self, + agent_name: str, + *, + session_id: str | None = None, + service_session_id: str | None = None, + ) -> DurableAgentSession: """Create a new DurableAgentSession with random session ID.""" - session_id = self._create_session_id(agent_name) - return DurableAgentSession.from_session_id(session_id, **kwargs) + durable_session_id = self._create_session_id(agent_name) + return DurableAgentSession( + durable_session_id=durable_session_id, + session_id=session_id, + service_session_id=service_session_id, + ) def _create_session_id( self, diff --git a/python/packages/durabletask/agent_framework_durabletask/_models.py b/python/packages/durabletask/agent_framework_durabletask/_models.py index 1c5484afbf..19d5804bc2 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_models.py +++ b/python/packages/durabletask/agent_framework_durabletask/_models.py @@ -284,46 +284,48 @@ def __init__( durable_session_id: AgentSessionId | None = None, session_id: str | None = None, service_session_id: str | None = None, - **kwargs: Any, ) -> None: - super().__init__(session_id=session_id, service_session_id=service_session_id, **kwargs) - self._session_id_value: AgentSessionId | None = durable_session_id + super().__init__(session_id=session_id, service_session_id=service_session_id) + self.durable_session_id: AgentSessionId | None = durable_session_id - @property - def 
durable_session_id(self) -> AgentSessionId | None: - return self._session_id_value - - @durable_session_id.setter - def durable_session_id(self, value: AgentSessionId | None) -> None: - self._session_id_value = value + def to_dict(self) -> dict[str, Any]: + state = super().to_dict() + if self.durable_session_id is not None: + state[self._SERIALIZED_SESSION_ID_KEY] = str(self.durable_session_id) + return state @classmethod def from_session_id( cls, - session_id: AgentSessionId, - **kwargs: Any, + durable_session_id: AgentSessionId, + *, + session_id: str | None = None, + service_session_id: str | None = None, ) -> DurableAgentSession: - return cls(durable_session_id=session_id, **kwargs) - - def to_dict(self) -> dict[str, Any]: - state = super().to_dict() - if self._session_id_value is not None: - state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id_value) - return state + """Create a DurableAgentSession from an AgentSessionId.""" + return cls( + durable_session_id=durable_session_id, + session_id=session_id, + service_session_id=service_session_id, + ) @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentSession: - state_payload = dict(data) - session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None) - session = super().from_dict(state_payload) + """Create a DurableAgentSession from a state dict.""" + data = dict(data) # defensive copy — avoid mutating caller's dict + session_id_value = data.pop(cls._SERIALIZED_SESSION_ID_KEY, None) + session = super().from_dict(data) + durable_session_id: AgentSessionId | None = None # We need to create a DurableAgentSession from the base AgentSession + if session_id_value is not None: + if not isinstance(session_id_value, str): + raise ValueError("durable_session_id must be a string when present in serialized state") + durable_session_id = AgentSessionId.parse(session_id_value) + durable_session = cls( + durable_session_id=durable_session_id, session_id=session.session_id, 
service_session_id=session.service_session_id, ) durable_session.state.update(session.state) - if session_id_value is not None: - if not isinstance(session_id_value, str): - raise ValueError("durable_session_id must be a string when present in serialized state") - durable_session._session_id_value = AgentSessionId.parse(session_id_value) return durable_session diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 5693876ad7..b21cac6831 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -133,16 +133,13 @@ def run( # type: ignore[override] session=session, ) - def create_session(self, **kwargs: Any) -> DurableAgentSession: + def create_session(self, *, session_id: str | None = None) -> DurableAgentSession: """Create a new agent session via the provider.""" - return self._executor.get_new_session(self.name, **kwargs) + return self._executor.get_new_session(self.name) - def get_session(self, **kwargs: Any) -> AgentSession: - """Retrieve an existing session via the provider. - - For durable agents, sessions do not use `service_session_id` so this is not used. - """ - return self._executor.get_new_session(self.name, **kwargs) + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + """Retrieve an existing session via the provider.""" + return self._executor.get_new_session(self.name, service_session_id=service_session_id, session_id=session_id) def _normalize_messages(self, messages: AgentRunInputs | None) -> str: """Convert supported message inputs to a single string. 
diff --git a/python/packages/durabletask/tests/test_agent_session_id.py b/python/packages/durabletask/tests/test_agent_session_id.py index 571212f145..3902acd22f 100644 --- a/python/packages/durabletask/tests/test_agent_session_id.py +++ b/python/packages/durabletask/tests/test_agent_session_id.py @@ -2,6 +2,8 @@ """Unit tests for AgentSessionId and DurableAgentSession.""" +from typing import Any + import pytest from agent_framework import AgentSession @@ -153,7 +155,7 @@ def test_durable_session_id_setter(self) -> None: def test_from_session_id(self) -> None: """Test creating DurableAgentSession from session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key") - session = DurableAgentSession.from_session_id(session_id) + session = DurableAgentSession(durable_session_id=session_id) assert isinstance(session, DurableAgentSession) assert session.durable_session_id is not None @@ -161,10 +163,10 @@ def test_from_session_id(self) -> None: assert session.durable_session_id.name == "TestAgent" assert session.durable_session_id.key == "test-key" - def test_from_session_id_with_service_session_id(self) -> None: - """Test creating DurableAgentSession with service session ID.""" + def test_init_with_service_session_id(self) -> None: + """Test creating DurableAgentSession with explicit service session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key") - session = DurableAgentSession.from_session_id(session_id, service_session_id="service-123") + session = DurableAgentSession(durable_session_id=session_id, service_session_id="service-123") assert session.durable_session_id is not None assert session.durable_session_id == session_id @@ -192,7 +194,7 @@ def test_to_dict_without_durable_session_id(self) -> None: def test_from_dict_with_durable_session_id(self) -> None: """Test deserialization restores durable session ID.""" - serialized = { + serialized: dict[str, Any] = { "type": "session", "session_id": "session-123", "service_session_id": 
"service-123", @@ -210,7 +212,7 @@ def test_from_dict_with_durable_session_id(self) -> None: def test_from_dict_without_durable_session_id(self) -> None: """Test deserialization without durable session ID.""" - serialized = { + serialized: dict[str, Any] = { "type": "session", "session_id": "session-456", "service_session_id": "service-456", diff --git a/python/packages/durabletask/tests/test_client.py b/python/packages/durabletask/tests/test_client.py index 0acdfb2f9c..a056d4e254 100644 --- a/python/packages/durabletask/tests/test_client.py +++ b/python/packages/durabletask/tests/test_client.py @@ -88,15 +88,6 @@ def test_client_agent_can_create_sessions(self, agent_client: DurableAIAgentClie assert isinstance(session, DurableAgentSession) - def test_client_agent_session_with_parameters(self, agent_client: DurableAIAgentClient) -> None: - """Verify agent can create sessions with custom parameters.""" - agent = agent_client.get_agent("assistant") - - session = agent.create_session(service_session_id="client-session-123") - - assert isinstance(session, DurableAgentSession) - assert session.service_session_id == "client-session-123" - class TestDurableAIAgentClientPollingConfiguration: """Test polling configuration parameters for DurableAIAgentClient.""" diff --git a/python/packages/durabletask/tests/test_orchestration_context.py b/python/packages/durabletask/tests/test_orchestration_context.py index 033c274c88..9f7cde156c 100644 --- a/python/packages/durabletask/tests/test_orchestration_context.py +++ b/python/packages/durabletask/tests/test_orchestration_context.py @@ -82,17 +82,6 @@ def test_orchestration_agent_can_create_sessions(self, agent_context: DurableAIA assert isinstance(session, DurableAgentSession) - def test_orchestration_agent_session_with_parameters( - self, agent_context: DurableAIAgentOrchestrationContext - ) -> None: - """Verify agent can create sessions with custom parameters.""" - agent = agent_context.get_agent("assistant") - - session = 
agent.create_session(service_session_id="orch-session-456") - - assert isinstance(session, DurableAgentSession) - assert session.service_session_id == "orch-session-456" - if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py index 423f587871..687a0746a7 100644 --- a/python/packages/durabletask/tests/test_shim.py +++ b/python/packages/durabletask/tests/test_shim.py @@ -184,16 +184,31 @@ def test_create_session_delegates_to_executor(self, test_agent: DurableAIAgent[A mock_executor.get_new_session.assert_called_once_with("test_agent") assert session == mock_session - def test_create_session_forwards_kwargs(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: - """Verify create_session forwards kwargs to executor.""" - mock_session = DurableAgentSession(service_session_id="session-123") + def test_get_session_forwards_service_session_id( + self, test_agent: DurableAIAgent[Any], mock_executor: Mock + ) -> None: + """Verify get_session forwards service_session_id and session_id to executor.""" + mock_session = DurableAgentSession(service_session_id="svc-123") mock_executor.get_new_session.return_value = mock_session - test_agent.create_session(service_session_id="session-123") + session = test_agent.get_session("svc-123", session_id="local-456") - mock_executor.get_new_session.assert_called_once() - _, kwargs = mock_executor.get_new_session.call_args - assert kwargs["service_session_id"] == "session-123" + mock_executor.get_new_session.assert_called_once_with( + "test_agent", service_session_id="svc-123", session_id="local-456" + ) + assert session.service_session_id == "svc-123" + + def test_get_session_without_session_id(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify get_session works with only service_session_id (session_id defaults to None).""" + mock_session = 
DurableAgentSession(service_session_id="svc-789") + mock_executor.get_new_session.return_value = mock_session + + session = test_agent.get_session("svc-789") + + mock_executor.get_new_session.assert_called_once_with( + "test_agent", service_session_id="svc-789", session_id=None + ) + assert session.service_session_id == "svc-789" class TestDurableAgentProviderInterface: diff --git a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py index 16451ae85a..4c1e64cd7c 100644 --- a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py +++ b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py @@ -146,11 +146,11 @@ def __init__( timeout: float | None = None, prepare_model: bool = True, device: DeviceType | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str = "utf-8", - **kwargs: Any, ) -> None: """Initialize a FoundryLocalClient. @@ -169,12 +169,11 @@ def __init__( The device is used to select the appropriate model variant. If not provided, the default device for your system will be used. The values are in the foundry_local.models.DeviceType enum. + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of ChatAndFunctionMiddlewareTypes to apply to requests. function_invocation_configuration: Optional configuration for function invocation support. env_file_path: If provided, the .env settings are read from this file path location. env_file_encoding: The encoding of the .env file, defaults to 'utf-8'. - kwargs: Additional keyword arguments, are passed to the RawOpenAIChatClient. 
- This can include middleware and additional properties. Examples: @@ -271,8 +270,8 @@ class MyOptions(FoundryLocalChatOptions, total=False): super().__init__( model_id=model_info.id, client=AsyncOpenAI(base_url=manager.endpoint, api_key=manager.api_key), + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self.manager = manager diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 7fa7d0dce4..2b7c266a37 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -303,7 +303,6 @@ def run( stream: Literal[False] = False, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> Awaitable[AgentResponse]: ... @overload @@ -314,7 +313,6 @@ def run( stream: Literal[True], session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... def run( @@ -324,7 +322,6 @@ def run( stream: bool = False, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: """Get a response from the agent. @@ -339,7 +336,6 @@ def run( stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). options: Runtime options (model, timeout, etc.). - kwargs: Additional keyword arguments. Returns: When stream=False: An Awaitable[AgentResponse]. 
@@ -354,10 +350,10 @@ def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: return AgentResponse.from_updates(updates) return ResponseStream( - self._stream_updates(messages=messages, session=session, options=options, **kwargs), + self._stream_updates(messages=messages, session=session, options=options), finalizer=_finalize, ) - return self._run_impl(messages=messages, session=session, options=options, **kwargs) + return self._run_impl(messages=messages, session=session, options=options) async def _run_impl( self, @@ -365,7 +361,6 @@ async def _run_impl( *, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> AgentResponse: """Non-streaming implementation of run.""" if not self._started: @@ -414,7 +409,6 @@ async def _stream_updates( *, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: """Internal method to stream updates from GitHub Copilot. @@ -424,7 +418,6 @@ async def _stream_updates( Keyword Args: session: The conversation session associated with the message(s). options: Runtime options (model, timeout, etc.). - kwargs: Additional keyword arguments. Yields: AgentResponseUpdate items. 
diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index e31c1971da..db0666b2d4 100644 --- a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -300,11 +300,11 @@ def __init__( host: str | None = None, client: AsyncClient | None = None, model_id: str | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Ollama Chat client. @@ -313,11 +313,11 @@ def __init__( Can be set via the OLLAMA_HOST env variable. client: An optional Ollama Client instance. If not provided, a new instance will be created. model_id: The Ollama chat model ID to use. Can be set via the OLLAMA_MODEL_ID env variable. + additional_properties: Additional properties stored on the client instance. middleware: Optional middleware to apply to the client. function_invocation_configuration: Optional function invocation configuration override. env_file_path: An optional path to a dotenv (.env) file to load environment variables from. env_file_encoding: The encoding to use when reading the dotenv (.env) file. Defaults to 'utf-8'. - **kwargs: Additional keyword arguments passed to BaseChatClient. 
""" ollama_settings = load_settings( OllamaSettings, @@ -336,9 +336,9 @@ def __init__( self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self.middleware = list(self.chat_middleware) diff --git a/python/packages/ollama/agent_framework_ollama/_embedding_client.py b/python/packages/ollama/agent_framework_ollama/_embedding_client.py index 5cd35fc9f3..8e0508c708 100644 --- a/python/packages/ollama/agent_framework_ollama/_embedding_client.py +++ b/python/packages/ollama/agent_framework_ollama/_embedding_client.py @@ -92,9 +92,9 @@ def __init__( model_id: str | None = None, host: str | None = None, client: AsyncClient | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a raw Ollama embedding client.""" ollama_settings = load_settings( @@ -110,7 +110,7 @@ def __init__( self.model_id = ollama_settings["embedding_model_id"] # type: ignore[assignment,reportTypedDictNotRequiredAccess] self.client = client or AsyncClient(host=ollama_settings.get("host")) self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] - super().__init__(**kwargs) + super().__init__(additional_properties=additional_properties) def service_url(self) -> str: """Get the URL of the service.""" @@ -214,17 +214,17 @@ def __init__( host: str | None = None, client: AsyncClient | None = None, otel_provider_name: str | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Ollama embedding client.""" super().__init__( model_id=model_id, 
             host=host,
             client=client,
+            additional_properties=additional_properties,
             otel_provider_name=otel_provider_name,
             env_file_path=env_file_path,
             env_file_encoding=env_file_encoding,
-            **kwargs,
         )
diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py
index e1a20b6218..be2db098b8 100644
--- a/python/packages/redis/agent_framework_redis/_history_provider.py
+++ b/python/packages/redis/agent_framework_redis/_history_provider.py
@@ -107,11 +107,18 @@ def _redis_key(self, session_id: str | None) -> str:
         """Get the Redis key for a given session's messages."""
         return f"{self.key_prefix}:{session_id or 'default'}"
 
-    async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
+    async def get_messages(
+        self,
+        session_id: str | None,
+        *,
+        state: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> list[Message]:
         """Retrieve stored messages for this session from Redis.
 
         Args:
             session_id: The session ID to retrieve messages for.
+            state: Optional session state. Unused for Redis-backed history.
             **kwargs: Additional arguments (unused).
 
         Returns:
@@ -125,12 +132,20 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess
             messages.append(Message.from_dict(self._deserialize_json(serialized)))  # type: ignore[union-attr]
         return messages
 
-    async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None:
+    async def save_messages(
+        self,
+        session_id: str | None,
+        messages: Sequence[Message],
+        *,
+        state: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> None:
         """Persist messages for this session to Redis.
 
         Args:
             session_id: The session ID to store messages for.
             messages: The messages to persist.
+            state: Optional session state. Unused for Redis-backed history.
             **kwargs: Additional arguments (unused).
         """
         if not messages:
diff --git a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py
index 33748437e0..fa78a9ede5 100644
--- a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py
+++ b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py
@@ -3,7 +3,7 @@
 import asyncio
 from collections.abc import Awaitable, Callable
 
-from agent_framework import AgentContext, AgentSession
+from agent_framework import AgentContext, AgentSession, FunctionInvocationContext, tool
 from agent_framework.openai import OpenAIResponsesClient
 from dotenv import load_dotenv
 
@@ -18,9 +18,6 @@
 When session propagation is enabled, both agents share the same session object,
 including session_id and the mutable state dict. This allows correlated
 conversation tracking and shared state across the agent hierarchy.
-
-The middleware functions below are purely for observability — they are NOT
-required for session propagation to work.
 """
 
 
@@ -28,65 +25,83 @@ async def log_session(
     context: AgentContext,
     call_next: Callable[[], Awaitable[None]],
 ) -> None:
-    """Agent middleware that logs the session received by each agent.
-
-    NOT required for session propagation — only used to observe the flow.
-    If propagation is working, both agents will show the same session_id.
-    """
+    """Agent middleware that logs the session received by each agent."""
     session: AgentSession | None = context.session
+    if not session:
+        print("No session found.")
+        await call_next()
+        return
     agent_name = context.agent.name or "unknown"
-    session_id = session.session_id if session else None
-    state = dict(session.state) if session else {}
-    print(f" [{agent_name}] session_id={session_id}, state={state}")
+    print(
+        f" [{agent_name}] session_id={session.session_id}, "
+        f"service_session_id={session.service_session_id} state={session.state}"
+    )
     await call_next()
 
 
+@tool(description="Use this tool to store the findings so that other agents can reason over them.")
+def store_findings(findings: str, ctx: FunctionInvocationContext) -> None:
+    if ctx.session is None:
+        return
+    current_findings = ctx.session.state.get("findings")
+    if current_findings is None:
+        ctx.session.state["findings"] = findings
+    else:
+        ctx.session.state["findings"] = f"{current_findings}\n{findings}"
+
+
+@tool(description="Use this tool to gather the current findings from other agents.")
+def recall_findings(ctx: FunctionInvocationContext) -> str:
+    if ctx.session is None:
+        return "No session available"
+    current_findings = ctx.session.state.get("findings")
+    if current_findings is None:
+        return "Nothing yet"
+    return current_findings
+
+
 async def main() -> None:
     print("=== Agent-as-Tool: Session Propagation ===\n")
 
     client = OpenAIResponsesClient()
 
-    # --- Sub-agent: a research specialist ---
-    # The sub-agent has the same log_session middleware to prove it receives the session.
     research_agent = client.as_agent(
         name="ResearchAgent",
-        instructions="You are a research assistant. Provide concise answers.",
+        instructions="You are a research assistant. Provide concise answers and store your findings.",
         middleware=[log_session],
+        tools=[store_findings, recall_findings],
     )
 
-    # propagate_session=True: the coordinator's session will be forwarded
     research_tool = research_agent.as_tool(
         name="research",
-        description="Research a topic and return findings",
+        description="Research a topic and store your findings.",
         arg_name="query",
         arg_description="The research query",
         propagate_session=True,
     )
 
-    # --- Coordinator agent ---
     coordinator = client.as_agent(
         name="CoordinatorAgent",
-        instructions="You coordinate research. Use the 'research' tool to look up information.",
-        tools=[research_tool],
+        instructions=(
+            "You coordinate research. Use the 'research' tool to start research "
+            "and then use the recall findings tool to gather up everything."
+        ),
+        tools=[research_tool, store_findings, recall_findings],
         middleware=[log_session],
     )
 
-    # Create a shared session and put some state in it
     session = coordinator.create_session()
-    session.state["request_source"] = "demo"
+    session.state["findings"] = None
     print(f"Session ID: {session.session_id}")
     print(f"Session state before run: {session.state}\n")
 
-    query = "What are the latest developments in quantum computing?"
+    query = "What are the latest developments in quantum computing and in AI?"
    print(f"User: {query}\n")
 
     result = await coordinator.run(query, session=session)
     print(f"\nCoordinator: {result}\n")
 
     print(f"Session state after run: {session.state}")
-    print(
-        "\nIf both agents show the same session_id above, session propagation is working."
-    )
 
 
 if __name__ == "__main__":
diff --git a/python/samples/02-agents/tools/function_tool_with_kwargs.py b/python/samples/02-agents/tools/function_tool_with_kwargs.py
index 249ebc4a33..61db84eb17 100644
--- a/python/samples/02-agents/tools/function_tool_with_kwargs.py
+++ b/python/samples/02-agents/tools/function_tool_with_kwargs.py
@@ -1,9 +1,9 @@
 # Copyright (c) Microsoft. All rights reserved.
 
 import asyncio
-from typing import Annotated, Any
+from typing import Annotated
 
-from agent_framework import tool
+from agent_framework import FunctionInvocationContext, tool
 from agent_framework.openai import OpenAIResponsesClient
 from dotenv import load_dotenv
 from pydantic import Field
@@ -14,27 +14,27 @@
 """
 AI Function with kwargs Example
 
-This example demonstrates how to inject custom keyword arguments (kwargs) into an AI function
-from the agent's run method, without exposing them to the AI model.
+This example demonstrates how to inject runtime context into an AI function
+from the agent's run method, without exposing it to the AI model.
 
 This is useful for passing runtime information like access tokens, user IDs,
 or request-specific context that the tool needs but the model shouldn't know about
-or provide.
+or provide. The injected context parameter can be typed as
+``FunctionInvocationContext`` as shown here, or left untyped as ``ctx`` when you
+prefer a lighter-weight sample setup.
 """
 
 
-# Define the function tool with **kwargs to accept injected arguments
-# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
-# see samples/02-agents/tools/function_tool_with_approval.py
-# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
+# Define the function tool with explicit invocation context.
+# The context parameter can also be declared as an untyped ``ctx`` parameter.
 @tool(approval_mode="never_require")
 def get_weather(
     location: Annotated[str, Field(description="The location to get the weather for.")],
-    **kwargs: Any,
+    ctx: FunctionInvocationContext,
 ) -> str:
     """Get the weather for a given location."""
-    # Extract the injected argument from kwargs
-    user_id = kwargs.get("user_id", "unknown")
+    # Extract the injected argument from the explicit context
+    user_id = ctx.kwargs.get("user_id", "unknown")
 
     # Simulate using the user_id for logging or personalization
     print(f"Getting weather for user: {user_id}")
@@ -49,9 +49,11 @@ async def main() -> None:
         tools=[get_weather],
     )
 
-    # Pass the injected argument when running the agent
-    # The 'user_id' kwarg will be passed down to the tool execution via **kwargs
-    response = await agent.run("What is the weather like in Amsterdam?", user_id="user_123")
+    # Pass the runtime context explicitly when running the agent.
+    response = await agent.run(
+        "What is the weather like in Amsterdam?",
+        function_invocation_kwargs={"user_id": "user_123"},
+    )
 
     print(f"Agent: {response.text}")
 
diff --git a/python/samples/02-agents/tools/function_tool_with_session_injection.py b/python/samples/02-agents/tools/function_tool_with_session_injection.py
index 2689ff5f9c..53cc63c2c0 100644
--- a/python/samples/02-agents/tools/function_tool_with_session_injection.py
+++ b/python/samples/02-agents/tools/function_tool_with_session_injection.py
@@ -1,9 +1,9 @@
 # Copyright (c) Microsoft. All rights reserved.
 
 import asyncio
-from typing import Annotated, Any
+from typing import Annotated
 
-from agent_framework import AgentSession, tool
+from agent_framework import AgentSession, FunctionInvocationContext, tool
 from agent_framework.openai import OpenAIResponsesClient
 from dotenv import load_dotenv
 from pydantic import Field
@@ -14,23 +14,21 @@
 """
 AI Function with Session Injection Example
 
-This example demonstrates the behavior when passing 'session' to agent.run()
-and accessing that session in AI function.
+This example demonstrates accessing the agent session inside a tool function
+via ``FunctionInvocationContext.session``. The session is automatically
+available when the agent is invoked with a session.
 """
 
 
-# Define the function tool with **kwargs
-# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
-# see samples/02-agents/tools/function_tool_with_approval.py
-# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
+# Define the function tool with explicit invocation context.
+# The context parameter can also be declared as an untyped parameter with the name: ``ctx``.
 @tool(approval_mode="never_require")
 async def get_weather(
     location: Annotated[str, Field(description="The location to get the weather for.")],
-    **kwargs: Any,
+    ctx: FunctionInvocationContext,
 ) -> str:
     """Get the weather for a given location."""
-    # Get session object from kwargs
-    session = kwargs.get("session")
+    session = ctx.session
 
     if session and isinstance(session, AgentSession) and session.service_session_id:
         print(f"Session ID: {session.service_session_id}.")
@@ -42,17 +40,19 @@ async def main() -> None:
         name="WeatherAgent",
         instructions="You are a helpful weather assistant.",
         tools=[get_weather],
-        options={"store": True},
+        default_options={"store": True},
     )
 
     # Create a session
     session = agent.create_session()
 
-    # Run the agent with the session
-    # Pass session via additional_function_arguments so tools can access it via **kwargs
-    opts = {"additional_function_arguments": {"session": session}}
-    print(f"Agent: {await agent.run('What is the weather in London?', session=session, options=opts)}")
-    print(f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session, options=opts)}")
+    # Run the agent with the session; tools receive it via ctx.session.
+    print(
+        f"Agent: {await agent.run('What is the weather in London?', session=session)}"
+    )
+    print(
+        f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session)}"
+    )
     print(f"Agent: {await agent.run('What cities did I ask about?', session=session)}")
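The core mechanic this PR introduces, a `ctx` parameter that is injected into tool functions at call time and hidden from the model, can be sketched in plain Python. The snippet below is illustrative only: `InvocationContext` and `invoke_tool` are hypothetical stand-ins for the framework's `FunctionInvocationContext` and its internal dispatcher, not the real implementation.

```python
import inspect
from dataclasses import dataclass, field
from typing import Any


@dataclass
class InvocationContext:
    # Hypothetical stand-in for FunctionInvocationContext: carries the
    # runtime kwargs passed to agent.run() and mutable session state.
    kwargs: dict[str, Any] = field(default_factory=dict)
    state: dict[str, Any] = field(default_factory=dict)


def invoke_tool(fn, model_args: dict[str, Any], ctx: InvocationContext) -> Any:
    """Call a tool, injecting ctx only if the function declares a 'ctx' parameter.

    Only declared value parameters are filled from model_args, so 'ctx'
    never appears in the schema the model sees.
    """
    params = inspect.signature(fn).parameters
    call_args = {name: model_args[name] for name in params if name in model_args}
    if "ctx" in params:
        call_args["ctx"] = ctx
    return fn(**call_args)


def get_weather(location: str, ctx: InvocationContext) -> str:
    # Mirrors the sample above: runtime data comes from ctx.kwargs,
    # while 'location' comes from the model.
    user_id = ctx.kwargs.get("user_id", "unknown")
    return f"{location} weather for {user_id}: sunny, 22C"


ctx = InvocationContext(kwargs={"user_id": "user_123"})
print(invoke_tool(get_weather, {"location": "Amsterdam"}, ctx))
```

The same dispatch shape explains the session-injection and findings-store samples: the dispatcher fills `ctx` from the ambient run, so tools can read shared state without the model ever providing it.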