diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index ef03652898..57a438da2a 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -181,6 +181,7 @@ ) from .exceptions import ( MiddlewareException, + UserInputRequiredException, WorkflowCheckpointException, WorkflowConvergenceException, WorkflowException, @@ -291,6 +292,7 @@ "TypeCompatibilityError", "UpdateT", "UsageDetails", + "UserInputRequiredException", "ValidationTypeEnum", "Workflow", "WorkflowAgent", diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 5cf7ff78a2..381997d567 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -57,7 +57,7 @@ map_chat_to_agent_update, normalize_messages, ) -from .exceptions import AgentInvalidResponseException +from .exceptions import AgentInvalidResponseException, UserInputRequiredException from .observability import AgentTelemetryLayer if sys.version_info >= (3, 13): @@ -538,14 +538,16 @@ async def agent_wrapper(**kwargs: Any) -> str: if stream_callback is None: # Use non-streaming mode - return ( - await self.run( - input_text, - stream=False, - session=parent_session, - **forwarded_kwargs, - ) - ).text + response = await self.run( + input_text, + stream=False, + session=parent_session, + **forwarded_kwargs, + ) + + if response.user_input_requests: + raise UserInputRequiredException(contents=response.user_input_requests) + return response.text # Use streaming mode - accumulate updates and create final response response_updates: list[AgentResponseUpdate] = [] @@ -557,7 +559,10 @@ async def agent_wrapper(**kwargs: Any) -> str: stream_callback(update) # Create final text from accumulated updates - return AgentResponse.from_updates(response_updates).text + final_response = AgentResponse.from_updates(response_updates) + if final_response.user_input_requests: + raise UserInputRequiredException(contents=final_response.user_input_requests) + return final_response.text agent_tool: FunctionTool = FunctionTool( name=tool_name, diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 105738e717..f595cd20a0 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -37,7 +37,7 @@ from pydantic import BaseModel, Field, ValidationError, create_model from ._serialization import SerializationMixin -from .exceptions import ToolException +from .exceptions import ToolException, UserInputRequiredException from .observability import ( OPERATION_DURATION_BUCKET_BOUNDARIES, OtelAttr, @@ -1228,6 +1228,8 @@ async def _auto_invoke_function( result=function_result, additional_properties=function_call_content.additional_properties, ) + except UserInputRequiredException: + raise except Exception as exc: message = "Error: Function failed." if config.get("include_detailed_errors", False): @@ -1274,6 +1276,8 @@ async def final_function_handler(context_obj: Any) -> Any: additional_properties=function_call_content.additional_properties, ) raise + except UserInputRequiredException: + raise except Exception as exc: message = "Error: Function failed." if config.get("include_detailed_errors", False): @@ -1384,8 +1388,8 @@ async def _try_execute_function_calls( async def invoke_with_termination_handling( function_call: Content, seq_idx: int, - ) -> tuple[Content, bool]: - """Invoke function and catch MiddlewareTermination, returning (result, should_terminate).""" + ) -> tuple[list[Content], bool]: + """Invoke function and catch MiddlewareTermination, returning (results, should_terminate).""" try: result = await _auto_invoke_function( function_call_content=function_call, # type: ignore[arg-type] @@ -1396,24 +1400,48 @@ async def invoke_with_termination_handling( middleware_pipeline=middleware_pipeline, config=config, ) - return (result, False) + return ([result], False) except MiddlewareTermination as exc: # Middleware requested termination - return result as Content # exc.result may already be a Content (set by _auto_invoke_function) or raw value if isinstance(exc.result, Content): - return (exc.result, True) + return ([exc.result], True) result_content = Content.from_function_result( call_id=function_call.call_id, # type: ignore[arg-type] result=exc.result, ) - return (result_content, True) + return ([result_content], True) + except UserInputRequiredException as exc: + # Sub-agent requires user input — propagate the Content items so + # _handle_function_call_results can surface them to the parent response. + if exc.contents: + propagated: list[Content] = [] + for idx, item in enumerate(exc.contents): + item.call_id = function_call.call_id # type: ignore[attr-defined] + if not item.id: # type: ignore[attr-defined] + item.id = f"{function_call.call_id}:{idx}" # type: ignore[attr-defined] + propagated.append(item) + if propagated: + return (propagated, False) + return ( + [ + Content.from_function_result( + call_id=function_call.call_id, # type: ignore[arg-type] + result="Tool requires user input but no request details were provided.", + exception="UserInputRequiredException", + ) + ], + False, + ) execution_results = await asyncio.gather(*[ invoke_with_termination_handling(function_call, seq_idx) for seq_idx, function_call in enumerate(function_calls) ]) - # Unpack results - each is (Content, terminate_flag) - contents: list[Content] = [result[0] for result in execution_results] + # Flatten results in original function_calls order — each task returns (list[Content], terminate_flag) + contents: list[Content] = [] + for result_contents, _ in execution_results: + contents.extend(result_contents) # If any function requested termination, terminate the loop should_terminate = any(result[1] for result in execution_results) return (contents, should_terminate) @@ -1645,7 +1673,10 @@ def _handle_function_call_results( ) -> FunctionRequestResult: from ._types import Message - if any(fccr.type in {"function_approval_request", "function_call"} for fccr in function_call_results): + if any( + fccr.type in {"function_approval_request", "function_call"} or fccr.user_input_request + for fccr in function_call_results + ): # Only add items that aren't already in the message (e.g. function_approval_request wrappers). # Declaration-only function_call items are already present from the LLM response. new_items = [fccr for fccr in function_call_results if fccr.type != "function_call"] diff --git a/python/packages/core/agent_framework/exceptions.py b/python/packages/core/agent_framework/exceptions.py index f38aa38590..24d38fd1dc 100644 --- a/python/packages/core/agent_framework/exceptions.py +++ b/python/packages/core/agent_framework/exceptions.py @@ -6,8 +6,13 @@ and guidance on choosing the correct exception class. """ +from __future__ import annotations + import logging -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal + +if TYPE_CHECKING: + from ._types import Content logger = logging.getLogger("agent_framework") @@ -180,6 +185,34 @@ class ToolExecutionException(ToolException): pass +class UserInputRequiredException(ToolException): + """Raised when a tool wrapping a sub-agent requires user input to proceed. + + This exception carries the ``user_input_request`` Content items emitted by + the sub-agent (e.g., ``oauth_consent_request``, ``function_approval_request``) + so the tool invocation layer can propagate them to the parent agent's response + instead of swallowing them as a generic tool error. + + Args: + contents: The user-input-request Content items from the sub-agent response. + message: Human-readable description of why user input is needed. + """ + + def __init__( + self, + contents: list[Content], + message: str = "Tool requires user input to proceed.", + ) -> None: + """Create a UserInputRequiredException. + + Args: + contents: The user-input-request Content items from the sub-agent response. + message: Human-readable description of why user input is needed. + """ + super().__init__(message, log_level=None) + self.contents: list[Content] = contents + + # endregion # region Middleware Exceptions diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index a60e924387..ef1ac6815b 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -1622,4 +1622,72 @@ async def test_stores_by_default_with_store_false_in_default_options_injects_inm assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers) +# region as_tool user_input_request propagation + + +async def test_as_tool_raises_on_user_input_request_non_streaming(client: SupportsChatGetResponse) -> None: + """Test that as_tool raises UserInputRequiredException when the sub-agent response has user_input_requests.""" + from agent_framework.exceptions import UserInputRequiredException + + # Configure mock client to return a response with oauth_consent_request content + consent_content = Content.from_oauth_consent_request( + consent_link="https://login.microsoftonline.com/consent", + ) + client.responses = [ # type: ignore[attr-defined] + ChatResponse(messages=Message(role="assistant", contents=[consent_content])), + ] + + agent = Agent(client=client, name="OAuthAgent", description="Agent requiring consent") + agent_tool = agent.as_tool() + + with raises(UserInputRequiredException) as exc_info: + await agent_tool.invoke(arguments=agent_tool.input_model(task="Do something")) + + assert len(exc_info.value.contents) == 1 + assert exc_info.value.contents[0].type == "oauth_consent_request" + assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent" + + +async def test_as_tool_raises_on_user_input_request_streaming(client: SupportsChatGetResponse) -> None: + """Test that as_tool raises UserInputRequiredException in streaming mode.""" + from agent_framework.exceptions import UserInputRequiredException + + consent_content = Content.from_oauth_consent_request( + consent_link="https://login.microsoftonline.com/consent", + ) + client.streaming_responses = [ # type: ignore[attr-defined] + [ChatResponseUpdate(contents=[consent_content], role="assistant")], + ] + + collected_updates: list[AgentResponseUpdate] = [] + + def stream_callback(update: AgentResponseUpdate) -> None: + collected_updates.append(update) + + agent = Agent(client=client, name="OAuthAgent", description="Agent requiring consent") + agent_tool = agent.as_tool(stream_callback=stream_callback) + + with raises(UserInputRequiredException) as exc_info: + await agent_tool.invoke(arguments=agent_tool.input_model(task="Do something")) + + assert len(exc_info.value.contents) == 1 + assert exc_info.value.contents[0].type == "oauth_consent_request" + # Stream callback should still have received the update before the exception + assert len(collected_updates) > 0 + + +async def test_as_tool_returns_text_when_no_user_input_request(client: SupportsChatGetResponse) -> None: + """Test that as_tool returns text normally when there are no user_input_requests.""" + client.responses = [ # type: ignore[attr-defined] + ChatResponse(messages=Message(role="assistant", text="Here is the result")), + ] + + agent = Agent(client=client, name="NormalAgent", description="Normal agent") + agent_tool = agent.as_tool() + + result = await agent_tool.invoke(arguments=agent_tool.input_model(task="Do something")) + + assert result == "Here is the result" + + # endregion diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 7f0eda62fc..a460f2ea7f 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -3511,4 +3511,61 @@ def test_dict_overwrites_existing_conversation_id(self): assert kwargs["chat_options"]["conversation_id"] == "new_id" +# region UserInputRequiredException propagation through tool invocation + + +async def test_user_input_request_propagates_through_as_tool(chat_client_base: SupportsChatGetResponse): + """Test that user_input_request content from a sub-agent wrapped as a tool propagates to the parent response. + + This is an end-to-end test: sub-agent returns oauth_consent_request → + as_tool raises UserInputRequiredException → invoke_with_termination_handling catches it → + _handle_function_call_results returns "action": "return" → Content ends up in parent response. + """ + from agent_framework.exceptions import UserInputRequiredException + + # Create a mock tool that simulates what as_tool does when the sub-agent + # returns user_input_request content + @tool(name="delegate_agent", approval_mode="never_require") + def delegate_tool(task: str) -> str: + raise UserInputRequiredException( + contents=[ + Content.from_oauth_consent_request( + consent_link="https://login.microsoftonline.com/consent", + ) + ] + ) + + # Parent agent calls the tool, which raises UserInputRequiredException + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="delegate_agent", arguments='{"task": "do it"}'), + ], + ) + ), + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="delegate this")], + options={"tool_choice": "auto", "tools": [delegate_tool]}, + ) + + # The oauth_consent_request Content should be in the parent response's assistant message + user_requests = [ + content + for msg in response.messages + for content in msg.contents + if isinstance(content, Content) and content.user_input_request + ] + assert len(user_requests) == 1 + assert user_requests[0].type == "oauth_consent_request" + assert user_requests[0].consent_link == "https://login.microsoftonline.com/consent" + assert user_requests[0].user_input_request is True + + +# endregion + + # endregion