From 37303b047bd527c99eed9a216cd21882b314319b Mon Sep 17 00:00:00 2001 From: junanchen Date: Mon, 19 Jan 2026 13:38:48 -0800 Subject: [PATCH 01/29] remove unused code in af --- .../agentframework/_agent_framework.py | 51 ++----------------- 1 file changed, 5 insertions(+), 46 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py index 8d22cd9ff263..40e1e72b70a4 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -5,31 +5,28 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List +from typing import Any, AsyncGenerator, Optional, TYPE_CHECKING, Union -from agent_framework import AgentProtocol, AIFunction, CheckpointStorage, InMemoryCheckpointStorage, WorkflowCheckpoint -from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module +from agent_framework import AgentProtocol, CheckpointStorage, WorkflowCheckpoint from agent_framework._workflows import get_checkpoint_summary +from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module from opentelemetry import trace -from azure.ai.agentserver.core.tools import OAuthConsentRequiredError from azure.ai.agentserver.core import AgentRunContext, FoundryCBAgent from azure.ai.agentserver.core.constants import Constants as AdapterConstants from azure.ai.agentserver.core.logger import APPINSIGHT_CONNSTR_ENV_NAME, get_logger from azure.ai.agentserver.core.models import ( - CreateResponse, Response as OpenAIResponse, ResponseStreamEvent, ) from azure.ai.agentserver.core.models.projects import ResponseErrorEvent, ResponseFailedEvent - +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError from .models.agent_framework_input_converters import AgentFrameworkInputConverter from .models.agent_framework_output_non_streaming_converter import ( AgentFrameworkOutputNonStreamingConverter, ) from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter from .models.human_in_the_loop_helper import HumanInTheLoopHelper -from .models.constants import Constants from .persistence import AgentThreadRepository, CheckpointRepository if TYPE_CHECKING: @@ -38,24 +35,6 @@ logger = get_logger() -class AgentFactory(Protocol): - """Protocol for agent factory functions. - - An agent factory is a callable that takes a list of tools and returns - an AgentProtocol, either synchronously or asynchronously. - """ - - def __call__(self, tools: List[AIFunction]) -> Union[AgentProtocol, Awaitable[AgentProtocol]]: - """Create an AgentProtocol using the provided tools. - - :param tools: The list of AIFunction tools available to the agent. - :type tools: List[AIFunction] - :return: An Agent Framework agent, or an awaitable that resolves to one. - :rtype: Union[AgentProtocol, Awaitable[AgentProtocol]] - """ - ... - - class AgentFrameworkCBAgent(FoundryCBAgent): """ Adapter class for integrating Agent Framework agents with the FoundryCB agent interface. @@ -84,7 +63,7 @@ def __init__(self, agent: AgentProtocol, :param agent: The Agent Framework agent to adapt, or a callable that takes ToolClient and returns AgentProtocol (sync or async). 
- :type agent: Union[AgentProtocol, AgentFactory] + :type agent: AgentProtocol :param credentials: Azure credentials for authentication. :type credentials: Optional[AsyncTokenCredential] :param thread_repository: An optional AgentThreadRepository instance for managing thread messages. @@ -105,26 +84,6 @@ def agent(self) -> "AgentProtocol": """ return self._agent - def _resolve_stream_timeout(self, request_body: CreateResponse) -> float: - """Resolve idle timeout for streaming updates. - - Order of precedence: - 1) request_body.stream_timeout_s (if provided) - 2) env var Constants.AGENTS_ADAPTER_STREAM_TIMEOUT_S - 3) Constants.DEFAULT_STREAM_TIMEOUT_S - - :param request_body: The CreateResponse request body. - :type request_body: CreateResponse - - :return: The resolved stream timeout in seconds. - :rtype: float - """ - override = request_body.get("stream_timeout_s", None) - if override is not None: - return float(override) - env_val = os.getenv(Constants.AGENTS_ADAPTER_STREAM_TIMEOUT_S) - return float(env_val) if env_val is not None else float(Constants.DEFAULT_STREAM_TIMEOUT_S) - def init_tracing(self): try: otel_exporter_endpoint = os.environ.get(AdapterConstants.OTEL_EXPORTER_ENDPOINT) From f0e463064d825ecb8035de65cc69ec0e094e4c19 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Mon, 19 Jan 2026 17:40:16 -0800 Subject: [PATCH 02/29] created subclasses for agent-framework AIAgent and WorkflowAgent --- .../ai/agentserver/agentframework/__init__.py | 14 +- .../agentframework/_agent_framework.py | 159 +------------- .../agentframework/_ai_agent_adapter.py | 151 +++++++++++++ .../agentframework/_workflow_agent_adapter.py | 202 ++++++++++++++++++ ...ramework_output_non_streaming_converter.py | 2 +- ...nt_framework_output_streaming_converter.py | 10 +- .../persistence/agent_thread_repository.py | 21 +- .../README.md | 2 +- .../human_in_the_loop_workflow_agent/main.py | 15 +- .../workflow_agent_simple.py | 8 +- 10 files changed, 400 insertions(+), 184 deletions(-) create mode 100644 sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py create mode 100644 sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py index 20a41df7ef73..b257ff132717 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py @@ -3,10 +3,14 @@ # --------------------------------------------------------- __path__ = __import__("pkgutil").extend_path(__path__, __name__) -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union + +from agent_framework import AgentProtocol, WorkflowBuilder from azure.ai.agentserver.agentframework._version import VERSION from azure.ai.agentserver.agentframework._agent_framework import AgentFrameworkCBAgent +from azure.ai.agentserver.agentframework._ai_agent_adapter import AgentFrameworkAIAgentAdapter +from azure.ai.agentserver.agentframework._workflow_agent_adapter import AgentFrameworkWorkflowAdapter from azure.ai.agentserver.agentframework._foundry_tools import FoundryToolsChatMiddleware from azure.ai.agentserver.core.application import PackageMetadata, set_current_app @@ 
-15,12 +19,16 @@ def from_agent_framework( - agent, + agent: Union[AgentProtocol, WorkflowBuilder], credentials: Optional["AsyncTokenCredential"] = None, **kwargs: Any, ) -> "AgentFrameworkCBAgent": - return AgentFrameworkCBAgent(agent, credentials=credentials, **kwargs) + if isinstance(agent, WorkflowBuilder): + return AgentFrameworkWorkflowAdapter(workflow_builder=agent, credentials=credentials, **kwargs) + if isinstance(agent, AgentProtocol): + return AgentFrameworkAIAgentAdapter(agent, credentials=credentials, **kwargs) + raise TypeError("agent must be an instance of AgentProtocol or WorkflowBuilder") __all__ = [ diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py index 8d22cd9ff263..7407fc96131c 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -7,9 +7,8 @@ import os from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List -from agent_framework import AgentProtocol, AIFunction, CheckpointStorage, InMemoryCheckpointStorage, WorkflowCheckpoint +from agent_framework import AgentProtocol, AIFunction from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module -from agent_framework._workflows import get_checkpoint_summary from opentelemetry import trace from azure.ai.agentserver.core.tools import OAuthConsentRequiredError @@ -75,9 +74,6 @@ class AgentFrameworkCBAgent(FoundryCBAgent): def __init__(self, agent: AgentProtocol, credentials: "Optional[AsyncTokenCredential]" = None, - *, - thread_repository: AgentThreadRepository = None, - checkpoint_repository: CheckpointRepository = None, **kwargs: Any, ): """Initialize the AgentFrameworkCBAgent with an AgentProtocol or a factory function. @@ -93,8 +89,6 @@ def __init__(self, agent: AgentProtocol, super().__init__(credentials=credentials, **kwargs) # pylint: disable=unexpected-keyword-arg self._agent: AgentProtocol = agent self._hitl_helper = HumanInTheLoopHelper() - self._checkpoint_repository = checkpoint_repository - self._thread_repository = thread_repository @property def agent(self) -> "AgentProtocol": @@ -229,153 +223,4 @@ async def agent_run( # pylint: disable=too-many-statements OpenAIResponse, AsyncGenerator[ResponseStreamEvent, Any], ]: - try: - logger.info(f"Starting agent_run with stream={context.stream}") - request_input = context.request.get("input") - - agent_thread = None - checkpoint_storage = None - last_checkpoint = None - if self._thread_repository: - agent_thread = await self._thread_repository.get(context.conversation_id) - if agent_thread: - logger.info(f"Loaded agent thread for conversation: {context.conversation_id}") - else: - agent_thread = self.agent.get_new_thread() - - if self._checkpoint_repository: - checkpoint_storage = await self._checkpoint_repository.get_or_create(context.conversation_id) - last_checkpoint = await self._get_latest_checkpoint(checkpoint_storage) - if last_checkpoint: - summary = get_checkpoint_summary(last_checkpoint) - if summary.status == "completed": - logger.warning("Last checkpoint is completed. 
Will not resume from it.") - last_checkpoint = None # Do not resume from completed checkpoints - if last_checkpoint: - await self._load_checkpoint(self.agent, last_checkpoint, checkpoint_storage) - logger.info(f"Loaded checkpoint with ID: {last_checkpoint.checkpoint_id}") - - input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper) - message = await input_converter.transform_input( - request_input, - agent_thread=agent_thread, - checkpoint=last_checkpoint) - logger.debug(f"Transformed input message type: {type(message)}") - - # Use split converters - if context.stream: - logger.info("Running agent in streaming mode") - streaming_converter = AgentFrameworkOutputStreamingConverter(context, hitl_helper=self._hitl_helper) - - async def stream_updates(): - try: - update_count = 0 - try: - updates = self.agent.run_stream( - message, - thread=agent_thread, - checkpoint_storage=checkpoint_storage, - ) - async for event in streaming_converter.convert(updates): - update_count += 1 - yield event - - if agent_thread and self._thread_repository: - await self._thread_repository.set(context.conversation_id, agent_thread, checkpoint_storage) - logger.info(f"Saved agent thread for conversation: {context.conversation_id}") - - logger.info("Streaming completed with %d updates", update_count) - except OAuthConsentRequiredError as e: - logger.info("OAuth consent required during streaming updates") - if update_count == 0: - async for event in self.respond_with_oauth_consent_astream(context, e): - yield event - else: - # If we've already emitted events, we cannot safely restart a new - # OAuth-consent stream (it would reset sequence numbers). - yield ResponseErrorEvent( - sequence_number=streaming_converter.next_sequence(), - code="server_error", - message=f"OAuth consent required: {e.consent_url}", - param="agent_run", - ) - yield ResponseFailedEvent( - sequence_number=streaming_converter.next_sequence(), - response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access - ) - except Exception as e: # pylint: disable=broad-exception-caught - logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True) - - # Emit well-formed error events instead of terminating the stream. - yield ResponseErrorEvent( - sequence_number=streaming_converter.next_sequence(), - code="server_error", - message=str(e), - param="agent_run", - ) - yield ResponseFailedEvent( - sequence_number=streaming_converter.next_sequence(), - response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access - ) - finally: - # No request-scoped resources to clean up here today. - # Keep this block as a hook for future request-scoped cleanup. 
-                        pass
-
-            return stream_updates()
-
-        # Non-streaming path
-        logger.info("Running agent in non-streaming mode")
-        non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context)
-        result = await self.agent.run(
-            message,
-            thread=agent_thread,
-            checkpoint_storage=checkpoint_storage)
-        logger.debug(f"Agent run completed, result type: {type(result)}")
-
-        if agent_thread and self._thread_repository:
-            await self._thread_repository.set(context.conversation_id, agent_thread)
-            logger.info(f"Saved agent thread for conversation: {context.conversation_id}")
-
-        transformed_result = non_streaming_converter.transform_output_for_response(result)
-        logger.info("Agent run and transformation completed successfully")
-        return transformed_result
-    except OAuthConsentRequiredError as e:
-        logger.info("OAuth consent required during agent run")
-        if context.stream:
-            # Yield OAuth consent response events
-            # Capture e in the closure by passing it as a default argument
-            async def oauth_consent_stream(error=e):
-                async for event in self.respond_with_oauth_consent_astream(context, error):
-                    yield event
-            return oauth_consent_stream()
-        return await self.respond_with_oauth_consent(context, e)
-    finally:
-        pass
-
-    async def _get_latest_checkpoint(self,
-                                     checkpoint_storage: CheckpointStorage) -> Optional[Any]:
-        """Load the latest checkpoint from the given storage.
-
-        :param checkpoint_storage: The checkpoint storage to load from.
-        :type checkpoint_storage: CheckpointStorage
-
-        :return: The latest checkpoint if available, None otherwise.
-        :rtype: Optional[Any]
-        """
-        checkpoints = await checkpoint_storage.list_checkpoints()
-        if checkpoints:
-            latest_checkpoint = max(checkpoints, key=lambda cp: cp.timestamp)
-            return latest_checkpoint
-        return None
-
-    async def _load_checkpoint(self, agent: AgentProtocol,
-                               checkpoint: WorkflowCheckpoint,
-                               checkpoint_storage: CheckpointStorage) -> None:
-        """Load the checkpoint data from the given WorkflowCheckpoint.
-
-        :param checkpoint: The WorkflowCheckpoint to load data from.
-        :type checkpoint: WorkflowCheckpoint
-        """
-        await agent.run(checkpoint_id=checkpoint.checkpoint_id,
-                        checkpoint_storage=checkpoint_storage)
\ No newline at end of file
+        raise NotImplementedError("agent_run must be implemented by a subclass.")
\ No newline at end of file
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py
new file mode 100644
index 000000000000..5f20367f5a0f
--- /dev/null
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py
@@ -0,0 +1,151 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# --------------------------------------------------------- +# pylint: disable=logging-fstring-interpolation,no-name-in-module,no-member,do-not-import-asyncio +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Union + +from agent_framework import AgentProtocol + +from azure.ai.agentserver.core import AgentRunContext +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError +from azure.ai.agentserver.core.logger import get_logger +from azure.ai.agentserver.core.models import ( + Response as OpenAIResponse, + ResponseStreamEvent, +) +from azure.ai.agentserver.core.models.projects import ResponseErrorEvent, ResponseFailedEvent + +from .models.agent_framework_input_converters import AgentFrameworkInputConverter +from .models.agent_framework_output_non_streaming_converter import ( + AgentFrameworkOutputNonStreamingConverter, +) +from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter +from ._agent_framework import AgentFrameworkCBAgent +from .persistence import AgentThreadRepository + +logger = get_logger() + +class AgentFrameworkAIAgentAdapter(AgentFrameworkCBAgent): + def __init__(self, agent: AgentProtocol, + *, + thread_repository: Optional[AgentThreadRepository]=None, + **kwargs) -> None: + super().__init__(agent=agent, **kwargs) + self._agent = agent + self._thread_repository = thread_repository + + async def agent_run( # pylint: disable=too-many-statements + self, context: AgentRunContext + ) -> Union[ + OpenAIResponse, + AsyncGenerator[ResponseStreamEvent, Any], + ]: + try: + logger.info(f"Starting agent_run with stream={context.stream}") + request_input = context.request.get("input") + + agent_thread = None + if self._thread_repository: + agent_thread = await self._thread_repository.get(context.conversation_id) + if agent_thread: + logger.info(f"Loaded agent thread for conversation: {context.conversation_id}") + else: + agent_thread = self.agent.get_new_thread() + + input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper) + message = await input_converter.transform_input( + request_input, + agent_thread=agent_thread) + logger.debug(f"Transformed input message type: {type(message)}") + + # Use split converters + if context.stream: + logger.info("Running agent in streaming mode") + streaming_converter = AgentFrameworkOutputStreamingConverter(context, hitl_helper=self._hitl_helper) + + async def stream_updates(): + try: + update_count = 0 + try: + updates = self.agent.run_stream( + message, + thread=agent_thread, + ) + async for event in streaming_converter.convert(updates): + update_count += 1 + yield event + + if agent_thread and self._thread_repository: + await self._thread_repository.set(context.conversation_id, agent_thread) + logger.info(f"Saved agent thread for conversation: {context.conversation_id}") + + logger.info("Streaming completed with %d updates", update_count) + except OAuthConsentRequiredError as e: + logger.info("OAuth consent required during streaming updates") + if update_count == 0: + async for event in self.respond_with_oauth_consent_astream(context, e): + yield event + else: + # If we've already emitted events, we cannot safely restart a new + # OAuth-consent stream (it would reset sequence numbers). 
+ yield ResponseErrorEvent( + sequence_number=streaming_converter.next_sequence(), + code="server_error", + message=f"OAuth consent required: {e.consent_url}", + param="agent_run", + ) + yield ResponseFailedEvent( + sequence_number=streaming_converter.next_sequence(), + response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True) + + # Emit well-formed error events instead of terminating the stream. + yield ResponseErrorEvent( + sequence_number=streaming_converter.next_sequence(), + code="server_error", + message=str(e), + param="agent_run", + ) + yield ResponseFailedEvent( + sequence_number=streaming_converter.next_sequence(), + response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access + ) + finally: + # No request-scoped resources to clean up here today. + # Keep this block as a hook for future request-scoped cleanup. + pass + + return stream_updates() + + # Non-streaming path + logger.info("Running agent in non-streaming mode") + non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper) + result = await self.agent.run( + message, + thread=agent_thread) + logger.debug(f"Agent run completed, result type: {type(result)}") + + if agent_thread and self._thread_repository: + await self._thread_repository.set(context.conversation_id, agent_thread) + logger.info(f"Saved agent thread for conversation: {context.conversation_id}") + + transformed_result = non_streaming_converter.transform_output_for_response(result) + logger.info("Agent run and transformation completed successfully") + return transformed_result + except OAuthConsentRequiredError as e: + logger.info("OAuth consent required during agent run") + if context.stream: + # Yield OAuth consent response events + # Capture e in the closure by passing it as a default argument + async def oauth_consent_stream(error=e): + async for event in self.respond_with_oauth_consent_astream(context, error): + yield event + return oauth_consent_stream() + return await self.respond_with_oauth_consent(context, e) + finally: + pass diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py new file mode 100644 index 000000000000..45089f88c012 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py @@ -0,0 +1,202 @@ + +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List + +from agent_framework import WorkflowBuilder, CheckpointStorage, WorkflowAgent, WorkflowCheckpoint +from agent_framework._workflows import get_checkpoint_summary + +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError +from azure.ai.agentserver.core import AgentRunContext +from azure.ai.agentserver.core.logger import get_logger +from azure.ai.agentserver.core.models import ( + CreateResponse, + Response as OpenAIResponse, + ResponseStreamEvent, +) +from azure.ai.agentserver.core.models.projects import ResponseErrorEvent, ResponseFailedEvent + +from ._agent_framework import AgentFrameworkCBAgent +from .models.agent_framework_input_converters import AgentFrameworkInputConverter +from 
.models.agent_framework_output_non_streaming_converter import ( + AgentFrameworkOutputNonStreamingConverter, +) +from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter +from .persistence.agent_thread_repository import AgentThreadRepository +from .persistence.checkpoint_repository import CheckpointRepository + +logger = get_logger() + +class AgentFrameworkWorkflowAdapter(AgentFrameworkCBAgent): + """Adapter to run WorkflowBuilder agents within the Agent Framework CBAgent structure.""" + def __init__(self, + workflow_builder: WorkflowBuilder, + *, + thread_repository: Optional[AgentThreadRepository] = None, + checkpoint_repository: Optional[CheckpointRepository] = None, + **kwargs: Any) -> None: + super().__init__(agent=workflow_builder, **kwargs) + self._workflow_builder = workflow_builder + self._thread_repository = thread_repository + self._checkpoint_repository = checkpoint_repository + + async def agent_run( # pylint: disable=too-many-statements + self, context: AgentRunContext + ) -> Union[ + OpenAIResponse, + AsyncGenerator[ResponseStreamEvent, Any], + ]: + try: + agent = self._build_agent() + + logger.info(f"Starting agent_run with stream={context.stream}") + request_input = context.request.get("input") + + agent_thread = None + checkpoint_storage = None + last_checkpoint = None + if self._thread_repository: + agent_thread = await self._thread_repository.get(context.conversation_id, agent=agent) + if agent_thread: + logger.info(f"Loaded agent thread for conversation: {context.conversation_id}") + else: + agent_thread = agent.get_new_thread() + + if self._checkpoint_repository: + checkpoint_storage = await self._checkpoint_repository.get_or_create(context.conversation_id) + last_checkpoint = await self._get_latest_checkpoint(checkpoint_storage) + if last_checkpoint: + summary = get_checkpoint_summary(last_checkpoint) + if summary.status == "completed": + logger.warning("Last checkpoint is completed. 
Will not resume from it.") + last_checkpoint = None # Do not resume from completed checkpoints + if last_checkpoint: + await self._load_checkpoint(agent, last_checkpoint, checkpoint_storage) + logger.info(f"Loaded checkpoint with ID: {last_checkpoint.checkpoint_id}") + + input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper) + message = await input_converter.transform_input( + request_input, + agent_thread=agent_thread, + checkpoint=last_checkpoint) + logger.debug(f"Transformed input message type: {type(message)}") + + # Use split converters + if context.stream: + logger.info("Running agent in streaming mode") + streaming_converter = AgentFrameworkOutputStreamingConverter(context, hitl_helper=self._hitl_helper) + + async def stream_updates(): + try: + update_count = 0 + try: + updates = agent.run_stream( + message, + thread=agent_thread, + checkpoint_storage=checkpoint_storage, + ) + async for event in streaming_converter.convert(updates): + update_count += 1 + yield event + + if agent_thread and self._thread_repository: + await self._thread_repository.set(context.conversation_id, agent_thread) + logger.info(f"Saved agent thread for conversation: {context.conversation_id}") + + logger.info("Streaming completed with %d updates", update_count) + except OAuthConsentRequiredError as e: + logger.info("OAuth consent required during streaming updates") + if update_count == 0: + async for event in self.respond_with_oauth_consent_astream(context, e): + yield event + else: + # If we've already emitted events, we cannot safely restart a new + # OAuth-consent stream (it would reset sequence numbers). + yield ResponseErrorEvent( + sequence_number=streaming_converter.next_sequence(), + code="server_error", + message=f"OAuth consent required: {e.consent_url}", + param="agent_run", + ) + yield ResponseFailedEvent( + sequence_number=streaming_converter.next_sequence(), + response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True) + + # Emit well-formed error events instead of terminating the stream. + yield ResponseErrorEvent( + sequence_number=streaming_converter.next_sequence(), + code="server_error", + message=str(e), + param="agent_run", + ) + yield ResponseFailedEvent( + sequence_number=streaming_converter.next_sequence(), + response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access + ) + finally: + # No request-scoped resources to clean up here today. + # Keep this block as a hook for future request-scoped cleanup. 
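+                        # (Illustrative assumption, not current behavior: future cleanup
+                        # might close per-request checkpoint handles or flush buffered
+                        # telemetry here before the generator is finalized.)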
+ pass + + return stream_updates() + + # Non-streaming path + logger.info("Running agent in non-streaming mode") + non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper) + result = await agent.run( + message, + thread=agent_thread, + checkpoint_storage=checkpoint_storage) + logger.debug(f"Agent run completed, result type: {type(result)}") + + if agent_thread and self._thread_repository: + await self._thread_repository.set(context.conversation_id, agent_thread) + logger.info(f"Saved agent thread for conversation: {context.conversation_id}") + + transformed_result = non_streaming_converter.transform_output_for_response(result) + logger.info("Agent run and transformation completed successfully") + return transformed_result + except OAuthConsentRequiredError as e: + logger.info("OAuth consent required during agent run") + if context.stream: + # Yield OAuth consent response events + # Capture e in the closure by passing it as a default argument + async def oauth_consent_stream(error=e): + async for event in self.respond_with_oauth_consent_astream(context, error): + yield event + return oauth_consent_stream() + return await self.respond_with_oauth_consent(context, e) + finally: + pass + + def _build_agent(self) -> WorkflowAgent: + return self._workflow_builder.build().as_agent() + + async def _get_latest_checkpoint(self, + checkpoint_storage: CheckpointStorage) -> Optional[Any]: + """Load the latest checkpoint from the given storage. + + :param checkpoint_storage: The checkpoint storage to load from. + :type checkpoint_storage: CheckpointStorage + + :return: The latest checkpoint if available, None otherwise. + :rtype: Optional[Any] + """ + checkpoints = await checkpoint_storage.list_checkpoints() + if checkpoints: + latest_checkpoint = max(checkpoints, key=lambda cp: cp.timestamp) + return latest_checkpoint + return None + + async def _load_checkpoint(self, agent: WorkflowAgent, + checkpoint: WorkflowCheckpoint, + checkpoint_storage: CheckpointStorage) -> None: + """Load the checkpoint data from the given WorkflowCheckpoint. + + :param checkpoint: The WorkflowCheckpoint to load data from. 
+        :type checkpoint: WorkflowCheckpoint
+        """
+        logger.info(f"Loading checkpoint ID: {checkpoint.checkpoint_id} into agent.")
+        await agent.run(checkpoint_id=checkpoint.checkpoint_id,
+                        checkpoint_storage=checkpoint_storage)
\ No newline at end of file
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py
index 08db24adfae0..95c7bb7acc6b 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py
@@ -36,7 +36,7 @@
 class AgentFrameworkOutputNonStreamingConverter:
     # pylint: disable=name-too-long
     """Non-streaming converter: AgentRunResponse -> OpenAIResponse."""

-    def __init__(self, context: AgentRunContext, *, hitl_helper: HumanInTheLoopHelper):
+    def __init__(self, context: AgentRunContext, *, hitl_helper: HumanInTheLoopHelper=None):
         self._context = context
         self._response_id = None
         self._response_created_at = None
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py
index 23d8702e38ec..be3280b91502 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py
@@ -23,6 +23,7 @@
     Response as OpenAIResponse,
     ResponseStreamEvent,
 )
+from azure.ai.agentserver.core.logger import get_logger
 from azure.ai.agentserver.core.models.projects import (
     AgentId,
     CreatedBy,
@@ -48,6 +49,7 @@
 from .human_in_the_loop_helper import HumanInTheLoopHelper
 from .utils.async_iter import chunk_on_change, peek

+logger = get_logger()

 class _BaseStreamingState:
     """Base interface for streaming state handlers."""
@@ -356,7 +358,10 @@ async def convert(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> Async
             lambda a, b: a is not None \
                 and b is not None \
                 and (a.message_id != b.message_id \
-                    or type(a.content[0]) != type(b.content[0]))  # pylint: disable=unnecessary-lambda-assignment
+                    or (
+                        a.contents and b.contents \
+                        and type(a.contents[0]) != type(b.contents[0]))
+                    )  # pylint: disable=unnecessary-lambda-assignment
         )

         async for group in chunk_on_change(updates, is_changed):
@@ -405,7 +410,8 @@ def _build_created_by(self, author_name: str) -> dict:

     async def _read_updates(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> AsyncIterable[tuple[BaseContent, str]]:
         async for update in updates:
-            if not update.contents:
+            logger.info(f"Processing update: {update.to_dict()}")
+            if not hasattr(update, "contents") or not update.contents:
                 continue

             # Extract author_name from each update
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py
index ea3de29385e1..294d7d0948fc 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py
@@ -1,20 +1,22 @@
 from abc import ABC, abstractmethod
 import json
 import os
-from typing import Any, Optional
+from typing import Any, Optional, Union

-from agent_framework import AgentThread, AgentProtocol
+from agent_framework import AgentThread, AgentProtocol, WorkflowAgent


 class AgentThreadRepository(ABC):
     """AgentThread repository to manage saved thread messages of agent threads and workflows."""

     @abstractmethod
-    async def get(self, conversation_id: str) -> Optional[AgentThread]:
+    async def get(self, conversation_id: str, agent: Optional[Union[AgentProtocol, WorkflowAgent]]=None) -> Optional[AgentThread]:
         """Retrieve the saved thread for a given conversation ID.

         :param conversation_id: The conversation ID.
         :type conversation_id: str
+        :param agent: The agent instance. If provided, it can be used to deserialize the thread.
+        :type agent: Optional[Union[AgentProtocol, WorkflowAgent]]

         :return: The saved AgentThread if available, None otherwise.
         :rtype: Optional[AgentThread]
@@ -36,12 +38,13 @@ class InMemoryAgentThreadRepository(AgentThreadRepository):
     def __init__(self) -> None:
         self._inventory: dict[str, AgentThread] = {}

-    async def get(self, conversation_id: str) -> Optional[AgentThread]:
+    async def get(self, conversation_id: str, agent: Optional[Union[AgentProtocol, WorkflowAgent]]=None) -> Optional[AgentThread]:
         """Retrieve the saved thread for a given conversation ID.

         :param conversation_id: The conversation ID.
         :type conversation_id: str
-
+        :param agent: The agent instance. It is unused by the in-memory repository and accepted only for interface consistency.
+        :type agent: Optional[Union[AgentProtocol, WorkflowAgent]]
         :return: The saved AgentThread if available, None otherwise.
         :rtype: Optional[AgentThread]
         """
@@ -72,18 +75,22 @@ def __init__(self, agent: AgentProtocol) -> None:
         """
         self._agent = agent

-    async def get(self, conversation_id: str) -> Optional[AgentThread]:
+    async def get(self, conversation_id: str, agent: Optional[Union[AgentProtocol, WorkflowAgent]]=None) -> Optional[AgentThread]:
         """Retrieve the saved thread for a given conversation ID.

         :param conversation_id: The conversation ID.
         :type conversation_id: str
+        :param agent: The agent instance. If provided, it can be used to deserialize the thread.
+            Otherwise, the repository's agent will be used.
+        :type agent: Optional[Union[AgentProtocol, WorkflowAgent]]

         :return: The saved AgentThread if available, None otherwise.
:rtype: Optional[AgentThread] """ serialized_thread = await self.read_from_storage(conversation_id) if serialized_thread: - thread = await self._agent.deserialize_thread(serialized_thread) + agent_to_use = agent or self._agent + thread = await agent_to_use.deserialize_thread(serialized_thread) return thread return None diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/README.md b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/README.md index aed6deee122a..172422f87c7f 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/README.md +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/README.md @@ -132,7 +132,7 @@ Respond by sending a `CreateResponse` request with `function_call_output` messag "input": [ { "call_id": "", - "output": "{\"request_id\":\"\",\"approved\":true}", + "output": "{\"request_id\":\"\",\"approved\":true,\"feedback\":\"approve\"}", "type": "function_call_output", } ] diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py index b5deef145920..cc89c941e65e 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py @@ -23,7 +23,7 @@ ) from azure.ai.agentserver.agentframework import from_agent_framework -from azure.ai.agentserver.agentframework.persistence import InMemoryCheckpointRepository +from azure.ai.agentserver.agentframework.persistence import FileCheckpointRepository load_dotenv() @@ -84,10 +84,10 @@ async def accept_human_review( print("Reviewer: Forwarding human review back to worker...") await ctx.send_message(response, target_id=self._worker_id) -def build_agent(tools): +def create_builder(): # Build a workflow with bidirectional communication between Worker and Reviewer, # and escalation paths for human review. 
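+    # Returning a builder (rather than a pre-built agent) lets the adapter build a
+    # fresh WorkflowAgent per request via builder.build().as_agent()
+    # (see AgentFrameworkWorkflowAdapter._build_agent).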
- agent = ( + builder = ( WorkflowBuilder() .register_executor( lambda: Worker( @@ -103,17 +103,16 @@ def build_agent(tools): .add_edge("worker", "reviewer") # Worker sends requests to Reviewer .add_edge("reviewer", "worker") # Reviewer sends feedback to Worker .set_start_executor("worker") - .build() - .as_agent() # Convert workflow into an agent interface ) - return agent + return builder async def run_agent() -> None: """Run the workflow inside the agent server adapter.""" + builder = create_builder() await from_agent_framework( - build_agent, # pass WorkflowAgent factory to adapter, build a new instance per request - checkpoint_repository=InMemoryCheckpointRepository(), # for checkpoint storage + builder, # pass workflow builder to adapter + checkpoint_repository=FileCheckpointRepository(storage_path="./checkpoints"), # for checkpoint storage ).run_async() if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py index 4d2569c38932..afbc0b48667b 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py @@ -264,7 +264,7 @@ async def handle_review_response( self._pending_requests[new_request.request_id] = (new_request, messages) -def build_agent(chat_client: BaseChatClient): +def create_builder(chat_client: BaseChatClient): reviewer = Reviewer(chat_client=chat_client) worker = Worker(chat_client=chat_client) return ( @@ -276,16 +276,14 @@ def build_agent(chat_client: BaseChatClient): reviewer, worker ) # <--- This edge allows the reviewer to send feedback back to the worker .set_start_executor(worker) - .build() - .as_agent() # Convert the workflow to an agent. 
) async def main() -> None: async with DefaultAzureCredential() as credential: async with AzureAIAgentClient(async_credential=credential) as chat_client: - agent = build_agent(chat_client) - await from_agent_framework(agent).run_async() + builder = create_builder(chat_client) + await from_agent_framework(builder).run_async() if __name__ == "__main__": From 2196a8136276d7a7d4d52f2604efb6c6be12d96e Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Mon, 19 Jan 2026 18:15:23 -0800 Subject: [PATCH 03/29] remove unused code --- .../models/agent_framework_output_streaming_converter.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py index be3280b91502..503e1c29bfbd 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py @@ -23,7 +23,6 @@ Response as OpenAIResponse, ResponseStreamEvent, ) -from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.models.projects import ( AgentId, CreatedBy, @@ -49,7 +48,6 @@ from .human_in_the_loop_helper import HumanInTheLoopHelper from .utils.async_iter import chunk_on_change, peek -logger = get_logger() class _BaseStreamingState: """Base interface for streaming state handlers.""" @@ -410,8 +408,7 @@ def _build_created_by(self, author_name: str) -> dict: async def _read_updates(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> AsyncIterable[tuple[BaseContent, str]]: async for update in updates: - logger.info(f"Processing update: {update.to_dict()}") - if not hasattr(update, "contents") or not update.contents: + if not update.contents: continue # Extract author_name from each update From 4ae465b1ec1c57e8d73980564b714726ee2cf943 Mon Sep 17 00:00:00 2001 From: junanchen Date: Mon, 19 Jan 2026 19:53:31 -0800 Subject: [PATCH 04/29] validate core with tox --- .../agentserver/core/application/__init__.py | 2 +- .../core/application/_configuration.py | 4 +- .../agentserver/core/application/_options.py | 5 +- .../core/application/_package_metadata.py | 4 +- .../azure/ai/agentserver/core/server/base.py | 52 +- .../core/server/common/constants.py | 2 +- .../ai/agentserver/core/tools/__init__.py | 82 +- .../ai/agentserver/core/tools/_exceptions.py | 2 - .../agentserver/core/tools/client/_client.py | 95 +- .../core/tools/client/_configuration.py | 40 +- .../agentserver/core/tools/client/_models.py | 894 +++++++++--------- .../core/tools/client/operations/_base.py | 4 +- .../operations/_foundry_hosted_mcp_tools.py | 6 +- .../core/tools/runtime/_catalog.py | 11 +- .../agentserver/core/tools/runtime/_facade.py | 5 +- .../core/tools/runtime/_resolver.py | 17 +- .../core/tools/runtime/_runtime.py | 10 +- .../core/tools/runtime/_starlette.py | 2 +- .../agentserver/core/tools/runtime/_user.py | 12 +- .../agentserver/core/tools/utils/__init__.py | 6 +- .../core/tools/utils/_name_resolver.py | 4 +- .../ai/agentserver/core/utils/_credential.py | 33 +- 22 files changed, 726 insertions(+), 566 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py 
b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py index ccf4062cce31..6e70a718531c 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py @@ -9,4 +9,4 @@ "set_current_app" ] -from ._package_metadata import PackageMetadata, set_current_app \ No newline at end of file +from ._package_metadata import PackageMetadata, set_current_app diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py index fe05dae18a67..1f8a01d57639 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py @@ -31,12 +31,12 @@ class ToolsConfiguration: catalog_cache_max_size: int = 1024 -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class AgentServerConfiguration: """Resolved configuration for the Agent Server application.""" - agent_name: str = "$default" project_endpoint: str credential: AsyncTokenCredential + agent_name: str = "$default" http: HttpServerConfiguration = field(default_factory=HttpServerConfiguration) tools: ToolsConfiguration = field(default_factory=ToolsConfiguration) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py index cb4e8bde0bfd..dc80c1538327 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py @@ -1,7 +1,9 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- -from typing import Literal, NotRequired, TypedDict, Union +from typing import Literal, TypedDict, Union + +from typing_extensions import NotRequired from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential @@ -41,4 +43,3 @@ class ToolsOptions(TypedDict): """ catalog_cache_ttl: NotRequired[int] catalog_cache_max_size: NotRequired[int] - diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_package_metadata.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_package_metadata.py index 36ff9313a6a2..5701110e5c7f 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_package_metadata.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_package_metadata.py @@ -41,10 +41,10 @@ def as_user_agent(self, component: str | None = None) -> str: def set_current_app(app: PackageMetadata) -> None: - global _app + global _app # pylint: disable=W0603 _app = app def get_current_app() -> PackageMetadata: - global _app + global _app # pylint: disable=W0602 return _app diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/base.py index a5f69664cf66..2afbed6e99a8 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/base.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/base.py @@ -3,7 +3,7 @@ # --------------------------------------------------------- # pylint: disable=broad-exception-caught,unused-argument,logging-fstring-interpolation,too-many-statements,too-many-return-statements # mypy: ignore-errors -import asyncio +import asyncio # pylint: disable=C4763 import inspect import json import os @@ -12,8 +12,6 @@ from typing import Any, AsyncGenerator, Generator, Optional, Union import uvicorn -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential from opentelemetry import context as otel_context, trace from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from starlette.applications import Starlette @@ -25,10 +23,12 @@ from starlette.routing import Route from starlette.types import ASGIApp +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential from azure.identity.aio import DefaultAzureCredential as AsyncDefaultTokenCredential from ._context import AgentServerContext -from ..models import projects as project_models +from ..models import projects as project_models from ..constants import Constants from ..logger import APPINSIGHT_CONNSTR_ENV_NAME, get_logger, get_project_endpoint, request_context from ..models import ( @@ -37,8 +37,7 @@ ) from .common.agent_run_context import AgentRunContext -from ..tools import DefaultFoundryToolRuntime, FoundryTool, FoundryToolClient, FoundryToolRuntime, UserInfo, \ - UserInfoContextMiddleware +from ..tools import DefaultFoundryToolRuntime, UserInfoContextMiddleware from ..utils._credential import AsyncTokenCredentialAdapter logger = get_logger() @@ -405,24 +404,6 @@ def setup_otlp_exporter(self, endpoint, provider): provider.add_span_processor(processor) logger.info(f"Tracing setup with OTLP exporter: {endpoint}") - def get_tool_client( - self, tools: Optional[list[FoundryTool]], 
user_info: Optional[UserInfo] - ) -> FoundryToolClient: - # TODO: remove this method - logger.debug("Creating AzureAIToolClient with tools: %s", tools) - if not self.credentials: - raise ValueError("Credentials are required to create Tool Client.") - - tools_endpoint, agent_name = self._configure_endpoint() - - return FoundryToolClient( - endpoint=tools_endpoint, - credential=self.credentials, - tools=tools, - user=user_info, - agent_name=agent_name, - ) - def _event_to_sse_chunk(event: ResponseStreamEvent) -> str: event_data = json.dumps(event.as_dict()) @@ -432,7 +413,11 @@ def _event_to_sse_chunk(event: ResponseStreamEvent) -> str: def _keep_alive_comment() -> str: - """Generate a keep-alive SSE comment to maintain connection.""" + """Generate a keep-alive SSE comment to maintain connection. + + :return: The keep-alive comment string. + :rtype: str + """ return ": keep-alive\n\n" @@ -440,15 +425,20 @@ async def _iter_with_keep_alive( it: AsyncGenerator[ResponseStreamEvent, None] ) -> AsyncGenerator[Optional[ResponseStreamEvent], None]: """Wrap an async iterator with keep-alive mechanism. - + If no event is received within KEEP_ALIVE_INTERVAL seconds, yields None as a signal to send a keep-alive comment. The original iterator is protected with asyncio.shield to ensure it continues running even when timeout occurs. + + :param it: The async generator to wrap. + :type it: AsyncGenerator[ResponseStreamEvent, None] + :return: An async generator that yields events or None for keep-alive. + :rtype: AsyncGenerator[Optional[ResponseStreamEvent], None] """ it_anext = it.__anext__ pending_task: Optional[asyncio.Task] = None - + while True: try: # If there's a pending task from previous timeout, wait for it first @@ -457,14 +447,14 @@ async def _iter_with_keep_alive( pending_task = None yield event continue - + # Create a task for the next event next_event_task = asyncio.create_task(it_anext()) - + try: # Shield the task and wait with timeout event = await asyncio.wait_for( - asyncio.shield(next_event_task), + asyncio.shield(next_event_task), timeout=KEEP_ALIVE_INTERVAL ) yield event @@ -473,7 +463,7 @@ async def _iter_with_keep_alive( # Save task to check in next iteration pending_task = next_event_task yield None - + except StopAsyncIteration: # Iterator exhausted break diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/constants.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/constants.py index 7d21ee7a31ff..6d4fb628a7f2 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/constants.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/constants.py @@ -3,4 +3,4 @@ # --------------------------------------------------------- # Reserved function name for HITL. 
-HUMAN_IN_THE_LOOP_FUNCTION_NAME = "__hosted_agent_adapter_hitl__" \ No newline at end of file +HUMAN_IN_THE_LOOP_FUNCTION_NAME = "__hosted_agent_adapter_hitl__" diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py index f158cd370990..5b356f38c825 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py @@ -5,13 +5,75 @@ __path__ = __import__('pkgutil').extend_path(__path__, __name__) from .client._client import FoundryToolClient -from ._exceptions import * -from .client._models import FoundryConnectedTool, FoundryHostedMcpTool, FoundryTool, FoundryToolProtocol, \ - FoundryToolSource, ResolvedFoundryTool, SchemaDefinition, SchemaProperty, SchemaType, UserInfo -from .runtime._catalog import * -from .runtime._facade import * -from .runtime._invoker import * -from .runtime._resolver import * -from .runtime._runtime import * -from .runtime._starlette import * -from .runtime._user import * \ No newline at end of file +from ._exceptions import ( + ToolInvocationError, + OAuthConsentRequiredError, + UnableToResolveToolInvocationError, + InvalidToolFacadeError, +) +from .client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryTool, + FoundryToolProtocol, + FoundryToolSource, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, + UserInfo, +) +from .runtime._catalog import ( + FoundryToolCatalog, + CachedFoundryToolCatalog, + DefaultFoundryToolCatalog, +) +from .runtime._facade import FoundryToolFacade, FoundryToolLike, ensure_foundry_tool +from .runtime._invoker import FoundryToolInvoker, DefaultFoundryToolInvoker +from .runtime._resolver import FoundryToolInvocationResolver, DefaultFoundryToolInvocationResolver +from .runtime._runtime import FoundryToolRuntime, DefaultFoundryToolRuntime +from .runtime._starlette import UserInfoContextMiddleware +from .runtime._user import UserProvider, ContextVarUserProvider + +__all__ = [ + # Client + "FoundryToolClient", + # Exceptions + "ToolInvocationError", + "OAuthConsentRequiredError", + "UnableToResolveToolInvocationError", + "InvalidToolFacadeError", + # Models + "FoundryConnectedTool", + "FoundryHostedMcpTool", + "FoundryTool", + "FoundryToolProtocol", + "FoundryToolSource", + "ResolvedFoundryTool", + "SchemaDefinition", + "SchemaProperty", + "SchemaType", + "UserInfo", + # Catalog + "FoundryToolCatalog", + "CachedFoundryToolCatalog", + "DefaultFoundryToolCatalog", + # Facade + "FoundryToolFacade", + "FoundryToolLike", + "ensure_foundry_tool", + # Invoker + "FoundryToolInvoker", + "DefaultFoundryToolInvoker", + # Resolver + "FoundryToolInvocationResolver", + "DefaultFoundryToolInvocationResolver", + # Runtime + "FoundryToolRuntime", + "DefaultFoundryToolRuntime", + # Starlette + "UserInfoContextMiddleware", + # User + "UserProvider", + "ContextVarUserProvider", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py index b91c1f71c7a3..a5fe7726e9f1 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py @@ -72,5 +72,3 @@ class InvalidToolFacadeError(RuntimeError): This 
 exception is raised when a tool facade does not conform to the expected structure
     or contains invalid data.
     """
-    pass
-
diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py
index cbd0dbba6aa6..d51f73d8737b 100644
--- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py
+++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py
@@ -1,26 +1,43 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
-import asyncio
+import asyncio  # pylint: disable=C4763
 import itertools
 from collections import defaultdict
-from typing import Any, AsyncContextManager, AsyncIterable, Awaitable, Callable, Collection, Coroutine, DefaultDict, Dict, \
-    Iterable, List, \
-    Mapping, Optional, \
-    Tuple
+from typing import (
+    Any,
+    AsyncContextManager,
+    Awaitable, Collection,
+    DefaultDict,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    cast,
+)
 
 from azure.core import AsyncPipelineClient
 from azure.core.credentials_async import AsyncTokenCredential
 from azure.core.tracing.decorator_async import distributed_trace_async
 
 from ._configuration import FoundryToolClientConfiguration
-from ._models import FoundryTool, FoundryToolDetails, FoundryToolSource, ResolvedFoundryTool, UserInfo
+from ._models import (
+    FoundryConnectedTool,
+    FoundryHostedMcpTool,
+    FoundryTool,
+    FoundryToolDetails,
+    FoundryToolSource,
+    ResolvedFoundryTool,
+    UserInfo,
+)
 from .operations._foundry_connected_tools import FoundryConnectedToolsOperations
 from .operations._foundry_hosted_mcp_tools import FoundryMcpToolsOperations
 from .._exceptions import ToolInvocationError
 
 
-class FoundryToolClient(AsyncContextManager["FoundryToolClient"]):
+class FoundryToolClient(AsyncContextManager["FoundryToolClient"]):  # pylint: disable=C4748
     """Asynchronous client for aggregating tools from Azure AI MCP and Tools APIs.
 
     This client provides access to tools from both MCP (Model Context Protocol) servers
@@ -33,15 +50,19 @@ class FoundryToolClient(AsyncContextManager["FoundryToolClient"]):
         Credential for authenticating requests to the service.
         Use credentials from azure-identity like DefaultAzureCredential.
     :type credential: ~azure.core.credentials.TokenCredential
     """
 
-    def __init__(self, endpoint: str, credential: "AsyncTokenCredential"):
+    def __init__(  # pylint: disable=C4718
+        self,
+        endpoint: str,
+        credential: "AsyncTokenCredential",
+    ) -> None:
         """Initialize the asynchronous Azure AI Tool Client.
 
         :param endpoint: The service endpoint URL.
         :type endpoint: str
         :param credential: Credentials for authenticating requests.
         :type credential: ~azure.core.credentials.TokenCredential
         """
         # noinspection PyTypeChecker
         config = FoundryToolClientConfiguration(credential)
@@ -51,10 +76,13 @@ def __init__(self, endpoint: str, credential: "AsyncTokenCredential"):
         self._connected_tools = FoundryConnectedToolsOperations(self._client)
 
     @distributed_trace_async
-    async def list_tools(self,
-                         tools: Collection[FoundryTool],
-                         agent_name,
-                         user: Optional[UserInfo] = None) -> List[ResolvedFoundryTool]:
+    async def list_tools(
+        self,
+        tools: Collection[FoundryTool],
+        agent_name: str,
+        user: Optional[UserInfo] = None,
+        **kwargs: Any
+    ) -> List[ResolvedFoundryTool]:
         """List all available tools from configured sources.
 
         Retrieves tools from both MCP servers and Azure AI Tools API endpoints,
@@ -72,6 +100,7 @@ async def list_tools(self,
         :raises ~azure.core.exceptions.HttpResponseError: Raised for HTTP communication failures.
         """
+        _ = kwargs  # Reserved for future use
         resolved_tools: List[ResolvedFoundryTool] = []
         results = await self._list_tools_details_internal(tools, agent_name, user)
         for definition, details in results:
@@ -79,10 +108,13 @@ async def list_tools(self,
         return resolved_tools
 
     @distributed_trace_async
-    async def list_tools_details(self,
-                                 tools: Collection[FoundryTool],
-                                 agent_name,
-                                 user: Optional[UserInfo] = None) -> Mapping[str, List[FoundryToolDetails]]:
+    async def list_tools_details(
+        self,
+        tools: Collection[FoundryTool],
+        agent_name: str,
+        user: Optional[UserInfo] = None,
+        **kwargs: Any
+    ) -> Mapping[str, List[FoundryToolDetails]]:
         """List all available tools from configured sources.
 
         Retrieves tools from both MCP servers and Azure AI Tools API endpoints,
@@ -100,6 +132,7 @@ async def list_tools_details(self,
         :raises ~azure.core.exceptions.HttpResponseError: Raised for HTTP communication failures.
""" + _ = kwargs # Reserved for future use resolved_tools: Dict[str, List[FoundryToolDetails]] = defaultdict(list) results = await self._list_tools_details_internal(tools, agent_name, user) for definition, details in results: @@ -109,33 +142,32 @@ async def list_tools_details(self, async def _list_tools_details_internal( self, tools: Collection[FoundryTool], - agent_name, + agent_name: str, user: Optional[UserInfo] = None, ) -> Iterable[Tuple[FoundryTool, FoundryToolDetails]]: tools_by_source: DefaultDict[FoundryToolSource, List[FoundryTool]] = defaultdict(list) for t in tools: tools_by_source[t.source].append(t) - listing_tools = [] + listing_tools: List[Awaitable[Iterable[Tuple[FoundryTool, FoundryToolDetails]]]] = [] if FoundryToolSource.HOSTED_MCP in tools_by_source: - # noinspection PyTypeChecker - listing_tools.append(asyncio.create_task( - self._hosted_mcp_tools.list_tools(tools_by_source[FoundryToolSource.HOSTED_MCP]) - )) + hosted_mcp_tools = cast(List[FoundryHostedMcpTool], tools_by_source[FoundryToolSource.HOSTED_MCP]) + listing_tools.append(self._hosted_mcp_tools.list_tools(hosted_mcp_tools)) if FoundryToolSource.CONNECTED in tools_by_source: - # noinspection PyTypeChecker - listing_tools.append(asyncio.create_task( - self._connected_tools.list_tools(tools_by_source[FoundryToolSource.CONNECTED], user, agent_name) - )) + connected_tools = cast(List[FoundryConnectedTool], tools_by_source[FoundryToolSource.CONNECTED]) + listing_tools.append(self._connected_tools.list_tools(connected_tools, user, agent_name)) iters = await asyncio.gather(*listing_tools) return itertools.chain.from_iterable(iters) @distributed_trace_async - async def invoke_tool(self, - tool: ResolvedFoundryTool, - arguments: Dict[str, Any], - agent_name: str, - user: Optional[UserInfo] = None) -> Any: + async def invoke_tool( + self, + tool: ResolvedFoundryTool, + arguments: Dict[str, Any], + agent_name: str, + user: Optional[UserInfo] = None, + **kwargs: Any + ) -> Any: """Invoke a tool by instance, name, or descriptor. :param tool: Tool to invoke, specified as an AzureAITool instance, @@ -156,6 +188,7 @@ async def invoke_tool(self, :raises ~ToolInvocationError: Raised when the tool invocation fails or source is not supported. """ + _ = kwargs # Reserved for future use if tool.source is FoundryToolSource.HOSTED_MCP: return await self._hosted_mcp_tools.invoke_tool(tool, arguments) if tool.source is FoundryToolSource.CONNECTED: diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py index 5c3f19a61d55..c496ef563216 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py @@ -9,27 +9,27 @@ class FoundryToolClientConfiguration(Configuration): # pylint: disable=too-many-instance-attributes - """Configuration for Azure AI Tool Client. + """Configuration for Azure AI Tool Client. - Manages authentication, endpoint configuration, and policy settings for the - Azure AI Tool Client. This class is used internally by the client and should - not typically be instantiated directly. + Manages authentication, endpoint configuration, and policy settings for the + Azure AI Tool Client. This class is used internally by the client and should + not typically be instantiated directly. 
- :param credential: - Azure TokenCredential for authentication. - :type credential: ~azure.core.credentials.TokenCredential - """ + :param credential: + Azure TokenCredential for authentication. + :type credential: ~azure.core.credentials.TokenCredential + """ - def __init__(self, credential: "AsyncTokenCredential"): - super().__init__() + def __init__(self, credential: "AsyncTokenCredential"): + super().__init__() - self.retry_policy = policies.AsyncRetryPolicy() - self.logging_policy = policies.NetworkTraceLoggingPolicy() - self.request_id_policy = policies.RequestIdPolicy() - self.http_logging_policy = policies.HttpLoggingPolicy() - self.user_agent_policy = policies.UserAgentPolicy( - base_user_agent=get_current_app().as_user_agent("FoundryToolClient")) - self.authentication_policy = policies.AsyncBearerTokenCredentialPolicy( - credential, "https://ai.azure.com/.default" - ) - self.redirect_policy = policies.AsyncRedirectPolicy() + self.retry_policy = policies.AsyncRetryPolicy() + self.logging_policy = policies.NetworkTraceLoggingPolicy() + self.request_id_policy = policies.RequestIdPolicy() + self.http_logging_policy = policies.HttpLoggingPolicy() + self.user_agent_policy = policies.UserAgentPolicy( + base_user_agent=get_current_app().as_user_agent("FoundryToolClient")) + self.authentication_policy = policies.AsyncBearerTokenCredentialPolicy( + credential, "https://ai.azure.com/.default" + ) + self.redirect_policy = policies.AsyncRedirectPolicy() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py index 8664e23c7c8b..c4d5c4d96a28 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py @@ -1,552 +1,598 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -import asyncio -import inspect from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from typing import Annotated, Any, Awaitable, Callable, ClassVar, Dict, Iterable, List, Literal, Mapping, Optional, Set, Type, Union +from typing import ( + Annotated, + Any, + Awaitable, + Callable, + ClassVar, + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Set, + Type, + Union, +) + +from pydantic import ( + AliasChoices, + AliasPath, + BaseModel, + Discriminator, + Field, + ModelWrapValidatorHandler, + Tag, + TypeAdapter, + model_validator, +) from azure.core import CaseInsensitiveEnumMeta -from pydantic import AliasChoices, AliasPath, BaseModel, Discriminator, Field, ModelWrapValidatorHandler, Tag, \ - TypeAdapter, model_validator from .._exceptions import OAuthConsentRequiredError class FoundryToolSource(str, Enum, metaclass=CaseInsensitiveEnumMeta): - """Identifies the origin of a tool. + """Identifies the origin of a tool. - Specifies whether a tool comes from an MCP (Model Context Protocol) server - or from the Azure AI Tools API (remote tools). - """ + Specifies whether a tool comes from an MCP (Model Context Protocol) server + or from the Azure AI Tools API (remote tools). 
+ """ - HOSTED_MCP = "hosted_mcp" - CONNECTED = "connected" + HOSTED_MCP = "hosted_mcp" + CONNECTED = "connected" class FoundryToolProtocol(str, Enum, metaclass=CaseInsensitiveEnumMeta): - """Identifies the protocol used by a connected tool.""" + """Identifies the protocol used by a connected tool.""" - MCP = "mcp" - A2A = "a2a" + MCP = "mcp" + A2A = "a2a" -@dataclass(frozen=True, kw_only=True, eq=False) +@dataclass(frozen=True, eq=False) class FoundryTool(ABC): - """Definition of a foundry tool including its parameters.""" - source: FoundryToolSource = field(init=False) + """Definition of a foundry tool including its parameters.""" + source: FoundryToolSource = field(init=False) - @property - @abstractmethod - def id(self) -> str: - """Unique identifier for the tool.""" - raise NotImplementedError + @property + @abstractmethod + def id(self) -> str: + """Unique identifier for the tool. - def __str__(self): - return self.id + :rtype: str + """ + raise NotImplementedError + def __str__(self): + return self.id -@dataclass(frozen=True, kw_only=True, eq=False) + +@dataclass(frozen=True, eq=False) class FoundryHostedMcpTool(FoundryTool): - """Foundry MCP tool definition. + """Foundry MCP tool definition. + + :ivar str name: Name of MCP tool. + :ivar Mapping[str, Any] configuration: Tools configuration. + """ + source: Literal[FoundryToolSource.HOSTED_MCP] = field(init=False, default=FoundryToolSource.HOSTED_MCP) + name: str + configuration: Optional[Mapping[str, Any]] = None - :ivar str name: Name of MCP tool. - :ivar Mapping[str, Any] configuration: Tools configuration. - """ - source: Literal[FoundryToolSource.HOSTED_MCP] = field(init=False, default=FoundryToolSource.HOSTED_MCP) - name: str - configuration: Optional[Mapping[str, Any]] = None + @property + def id(self) -> str: + """Unique identifier for the tool. - @property - def id(self) -> str: - """Unique identifier for the tool.""" - return f"{self.source}:{self.name}" + :rtype: str + """ + return f"{self.source}:{self.name}" -@dataclass(frozen=True, kw_only=True, eq=False) +@dataclass(frozen=True, eq=False) class FoundryConnectedTool(FoundryTool): - """Foundry connected tool definition. + """Foundry connected tool definition. - :ivar str project_connection_id: connection name of foundry tool. - """ - source: Literal[FoundryToolSource.CONNECTED] = field(init=False, default=FoundryToolSource.CONNECTED) - protocol: str - project_connection_id: str + :ivar str project_connection_id: connection name of foundry tool. + """ + source: Literal[FoundryToolSource.CONNECTED] = field(init=False, default=FoundryToolSource.CONNECTED) + protocol: str + project_connection_id: str - @property - def id(self) -> str: - return f"{self.source}:{self.protocol}:{self.project_connection_id}" + @property + def id(self) -> str: + return f"{self.source}:{self.protocol}:{self.project_connection_id}" @dataclass(frozen=True) class FoundryToolDetails: - """Details about a Foundry tool. + """Details about a Foundry tool. - :ivar str name: Name of the tool. - :ivar str description: Description of the tool. - :ivar SchemaDefinition input_schema: Input schema for the tool parameters. - :ivar Optional[SchemaDefinition] metadata: Optional metadata schema for the tool. - """ - name: str - description: str - input_schema: "SchemaDefinition" - metadata: Optional["SchemaDefinition"] = None + :ivar str name: Name of the tool. + :ivar str description: Description of the tool. + :ivar SchemaDefinition input_schema: Input schema for the tool parameters. 
+ :ivar Optional[SchemaDefinition] metadata: Optional metadata schema for the tool. + """ + name: str + description: str + input_schema: "SchemaDefinition" + metadata: Optional["SchemaDefinition"] = None @dataclass(frozen=True) class ResolvedFoundryTool: - """Resolved Foundry tool with definition and details. + """Resolved Foundry tool with definition and details. + + :ivar ToolDefinition definition: + Optional tool definition object, or None. + :ivar FoundryToolDetails details: + Details about the tool, including name, description, and input schema. + """ + + definition: FoundryTool + details: FoundryToolDetails + + @property + def id(self) -> str: + return f"{self.definition.id}:{self.details.name}" + + @property + def source(self) -> FoundryToolSource: + """Origin of the tool. - :ivar ToolDefinition definition: - Optional tool definition object, or None. - :ivar FoundryToolDetails details: - Details about the tool, including name, description, and input schema. - """ + :rtype: FoundryToolSource + """ + return self.definition.source - definition: FoundryTool - details: FoundryToolDetails - invoker: Optional[Callable[..., Awaitable[Any]]] = None # TODO: deprecated + @property + def name(self) -> str: + """Name of the tool. - @property - def id(self) -> str: - return f"{self.definition.id}:{self.details.name}" + :rtype: str + """ + return self.details.name - @property - def source(self) -> FoundryToolSource: - """Origin of the tool.""" - return self.definition.source + @property + def description(self) -> str: + """Description of the tool. - @property - def name(self) -> str: - """Name of the tool.""" - return self.details.name + :rtype: str + """ + return self.details.description - @property - def description(self) -> str: - """Description of the tool.""" - return self.details.description + @property + def input_schema(self) -> "SchemaDefinition": + """Input schema of the tool. - @property - def input_schema(self) -> "SchemaDefinition": - """Input schema of the tool.""" - return self.details.input_schema + :rtype: SchemaDefinition + """ + return self.details.input_schema - @property - def metadata(self) -> Optional["SchemaDefinition"]: - """Metadata schema of the tool, if any.""" - return self.details.metadata + @property + def metadata(self) -> Optional["SchemaDefinition"]: + """Metadata schema of the tool, if any.""" + return self.details.metadata @dataclass(frozen=True) class UserInfo: - """Represents user information. + """Represents user information. - :ivar str object_id: User's object identifier. - :ivar str tenant_id: Tenant identifier. - """ + :ivar str object_id: User's object identifier. + :ivar str tenant_id: Tenant identifier. + """ - object_id: str - tenant_id: str + object_id: str + tenant_id: str -class SchemaType(str, Enum): - """ - Enumeration of possible schema types. +class SchemaType(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """ + Enumeration of possible schema types. - :ivar py_type: The corresponding Python runtime type for this schema type - (e.g., ``SchemaType.STRING.py_type is str``). - """ + :ivar py_type: The corresponding Python runtime type for this schema type + (e.g., ``SchemaType.STRING.py_type is str``). 
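+
+    Example (illustrative)::
+
+        SchemaType("string").py_type          # str
+        SchemaType.from_python_type(int)      # SchemaType.INTEGER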
+ """ - py_type: Type[Any] - """The corresponding Python runtime type for this schema type.""" + py_type: Type[Any] + """The corresponding Python runtime type for this schema type.""" - STRING = ("string", str) - """Schema type for string values (maps to ``str``).""" + STRING = ("string", str) + """Schema type for string values (maps to ``str``).""" - NUMBER = ("number", float) - """Schema type for numeric values with decimals (maps to ``float``).""" + NUMBER = ("number", float) + """Schema type for numeric values with decimals (maps to ``float``).""" - INTEGER = ("integer", int) - """Schema type for integer values (maps to ``int``).""" + INTEGER = ("integer", int) + """Schema type for integer values (maps to ``int``).""" - BOOLEAN = ("boolean", bool) - """Schema type for boolean values (maps to ``bool``).""" + BOOLEAN = ("boolean", bool) + """Schema type for boolean values (maps to ``bool``).""" - ARRAY = ("array", list) - """Schema type for array values (maps to ``list``).""" + ARRAY = ("array", list) + """Schema type for array values (maps to ``list``).""" - OBJECT = ("object", dict) - """Schema type for object/dictionary values (maps to ``dict``).""" + OBJECT = ("object", dict) + """Schema type for object/dictionary values (maps to ``dict``).""" - def __new__(cls, value: str, py_type: Type[Any]): - """ - Create an enum member whose value is the schema type string, while also - attaching the mapped Python type. + def __new__(cls, value: str, py_type: Type[Any]): + """ + Create an enum member whose value is the schema type string, while also + attaching the mapped Python type. - :param value: The serialized schema type string (e.g. ``"string"``). - :param py_type: The mapped Python runtime type (e.g. ``str``). - """ - obj = str.__new__(cls, value) - obj._value_ = value - obj.py_type = py_type - return obj + :param value: The serialized schema type string (e.g. ``"string"``). + :type value: str + :param py_type: The mapped Python runtime type (e.g. ``str``). + :type py_type: Type[Any] + :return: The created enum member. + :rtype: SchemaType + """ + obj = str.__new__(cls, value) + obj._value_ = value + obj.py_type = py_type + return obj - @classmethod - def from_python_type(cls, t: Type[Any]) -> "SchemaType": - """ - Get the matching :class:`SchemaType` for a given Python runtime type. + @classmethod + def from_python_type(cls, t: Type[Any]) -> "SchemaType": + """ + Get the matching :class:`SchemaType` for a given Python runtime type. - :param t: A Python runtime type (e.g. ``str``, ``int``, ``float``). - :returns: The corresponding :class:`SchemaType`. - :raises ValueError: If ``t`` is not supported by this enumeration. - """ - for member in cls: - if member.py_type is t: - return member - raise ValueError(f"Unsupported python type: {t!r}") + :param t: A Python runtime type (e.g. ``str``, ``int``, ``float``). + :type t: Type[Any] + :returns: The corresponding :class:`SchemaType`. + :rtype: SchemaType + :raises ValueError: If ``t`` is not supported by this enumeration. + """ + for member in cls: + if member.py_type is t: + return member + raise ValueError(f"Unsupported python type: {t!r}") class SchemaProperty(BaseModel): - """ - A JSON Schema-like description of a single property (field) or nested schema node. - - This model is intended to be recursively nestable via :attr:`items` (for arrays) - and :attr:`properties` (for objects). - - :ivar type: The schema node type (e.g., ``string``, ``object``, ``array``). - :ivar description: Optional human-readable description of the property. 
- :ivar items: The item schema for an ``array`` type. Typically set when - :attr:`type` is :data:`~SchemaType.ARRAY`. - :ivar properties: Nested properties for an ``object`` type. Typically set when - :attr:`type` is :data:`~SchemaType.OBJECT`. Keys are property names, values - are their respective schemas. - :ivar default: Optional default value for the property. - :ivar required: For an ``object`` schema node, the set of required property - names within :attr:`properties`. (This mirrors JSON Schema’s ``required`` - keyword; it is *not* “this property is required in a parent object”.) - """ - - type: SchemaType - description: Optional[str] = None - items: Optional["SchemaProperty"] = None - properties: Optional[Mapping[str, "SchemaProperty"]] = None - default: Any = None - required: Optional[Set[str]] = None - - def has_default(self) -> bool: - """ - Check if the property has a default value defined. - - :return: True if a default value is set, False otherwise. - :rtype: bool - """ - return "default" in self.model_fields_set + """ + A JSON Schema-like description of a single property (field) or nested schema node. + + This model is intended to be recursively nestable via :attr:`items` (for arrays) + and :attr:`properties` (for objects). + + :ivar type: The schema node type (e.g., ``string``, ``object``, ``array``). + :ivar description: Optional human-readable description of the property. + :ivar items: The item schema for an ``array`` type. Typically set when + :attr:`type` is :data:`~SchemaType.ARRAY`. + :ivar properties: Nested properties for an ``object`` type. Typically set when + :attr:`type` is :data:`~SchemaType.OBJECT`. Keys are property names, values + are their respective schemas. + :ivar default: Optional default value for the property. + :ivar required: For an ``object`` schema node, the set of required property + names within :attr:`properties`. (This mirrors JSON Schema’s ``required`` + keyword; it is *not* “this property is required in a parent object”.) + """ + + type: SchemaType + description: Optional[str] = None + items: Optional["SchemaProperty"] = None + properties: Optional[Mapping[str, "SchemaProperty"]] = None + default: Any = None + required: Optional[Set[str]] = None + + def has_default(self) -> bool: + """ + Check if the property has a default value defined. + + :return: True if a default value is set, False otherwise. + :rtype: bool + """ + return "default" in self.model_fields_set class SchemaDefinition(BaseModel): - """ - A top-level JSON Schema-like definition for an object. - - :ivar type: The schema type of the root. Typically :data:`~SchemaType.OBJECT`. - :ivar properties: Mapping of top-level property names to their schemas. - :ivar required: Set of required top-level property names within - :attr:`properties`. 
- """ - - type: SchemaType = SchemaType.OBJECT - properties: Mapping[str, SchemaProperty] - required: Optional[Set[str]] = None - - def extract_from(self, - datasource: Mapping[str, Any], - property_alias: Optional[Dict[str, List[str]]] = None) -> Dict[str, Any]: - return self._extract(datasource, self.properties, self.required, property_alias) - - @classmethod - def _extract(cls, - datasource: Mapping[str, Any], - properties: Mapping[str, SchemaProperty], - required: Optional[Set[str]] = None, - property_alias: Optional[Dict[str, List[str]]] = None) -> Dict[str, Any]: - result: Dict[str, Any] = {} - - for property_name, schema in properties.items(): - # Determine the keys to look for in the datasource - keys_to_check = [property_name] - if property_alias and property_name in property_alias: - keys_to_check.extend(property_alias[property_name]) - - # Find the first matching key in the datasource - value_found = False - for key in keys_to_check: - if key in datasource: - value = datasource[key] - value_found = True - break - - if not value_found and schema.has_default(): - value = schema.default - value_found = True - - if not value_found: - # If the property is required but not found, raise an error - if required and property_name in required: - raise KeyError(f"Required property '{property_name}' not found in datasource.") - # If not found and not required, skip to next property - continue - - # Process the value based on its schema type - if schema.type == SchemaType.OBJECT and schema.properties: - if isinstance(value, Mapping): - nested_value = cls._extract( - value, - schema.properties, - schema.required, - property_alias - ) - result[property_name] = nested_value - elif schema.type == SchemaType.ARRAY and schema.items: - if isinstance(value, Iterable): - nested_list = [] - for item in value: - if schema.items.type == SchemaType.OBJECT and schema.items.properties: - if isinstance(item, dict): - nested_item = SchemaDefinition._extract( - item, - schema.items.properties, - schema.items.required, - property_alias - ) - nested_list.append(nested_item) - else: - nested_list.append(item) - result[property_name] = nested_list - else: - result[property_name] = value - - return result + """ + A top-level JSON Schema-like definition for an object. + + :ivar type: The schema type of the root. Typically :data:`~SchemaType.OBJECT`. + :ivar properties: Mapping of top-level property names to their schemas. + :ivar required: Set of required top-level property names within + :attr:`properties`. 
+ """ + + type: SchemaType = SchemaType.OBJECT + properties: Mapping[str, SchemaProperty] = field(default_factory=dict) + required: Optional[Set[str]] = None + + def extract_from(self, + datasource: Mapping[str, Any], + property_alias: Optional[Dict[str, List[str]]] = None) -> Dict[str, Any]: + return self._extract(datasource, self.properties, self.required, property_alias) + + @classmethod + def _extract(cls, + datasource: Mapping[str, Any], + properties: Mapping[str, SchemaProperty], + required: Optional[Set[str]] = None, + property_alias: Optional[Dict[str, List[str]]] = None) -> Dict[str, Any]: + result: Dict[str, Any] = {} + + for property_name, schema in properties.items(): + # Determine the keys to look for in the datasource + keys_to_check = [property_name] + if property_alias and property_name in property_alias: + keys_to_check.extend(property_alias[property_name]) + + # Find the first matching key in the datasource + value_found = False + for key in keys_to_check: + if key in datasource: + value = datasource[key] + value_found = True + break + + if not value_found and schema.has_default(): + value = schema.default + value_found = True + + if not value_found: + # If the property is required but not found, raise an error + if required and property_name in required: + raise KeyError(f"Required property '{property_name}' not found in datasource.") + # If not found and not required, skip to next property + continue + + # Process the value based on its schema type + if schema.type == SchemaType.OBJECT and schema.properties: + if isinstance(value, Mapping): + nested_value = cls._extract( + value, + schema.properties, + schema.required, + property_alias + ) + result[property_name] = nested_value + elif schema.type == SchemaType.ARRAY and schema.items: + if isinstance(value, Iterable): + nested_list = [] + for item in value: + if schema.items.type == SchemaType.OBJECT and schema.items.properties: + nested_item = SchemaDefinition._extract( + item, + schema.items.properties, + schema.items.required, + property_alias + ) + nested_list.append(nested_item) + else: + nested_list.append(item) + result[property_name] = nested_list + else: + result[property_name] = value + + return result class RawFoundryHostedMcpTool(BaseModel): - """Pydantic model for a single MCP tool. + """Pydantic model for a single MCP tool. - :ivar str name: Unique name identifier of the tool. - :ivar Optional[str] title: Display title of the tool, defaults to name if not provided. - :ivar str description: Human-readable description of the tool. - :ivar SchemaDefinition input_schema: JSON schema for tool input parameters. - :ivar Optional[SchemaDefinition] meta: Optional metadata for the tool. - """ + :ivar str name: Unique name identifier of the tool. + :ivar Optional[str] title: Display title of the tool, defaults to name if not provided. + :ivar str description: Human-readable description of the tool. + :ivar SchemaDefinition input_schema: JSON schema for tool input parameters. + :ivar Optional[SchemaDefinition] meta: Optional metadata for the tool. 
+ """ - name: str - title: Optional[str] = None - description: str = "" - input_schema: SchemaDefinition = Field( - default_factory=SchemaDefinition, - validation_alias="inputSchema" - ) - meta: Optional[SchemaDefinition] = Field(default=None, validation_alias="_meta") + name: str + title: Optional[str] = None + description: str = "" + input_schema: SchemaDefinition = Field( + default_factory=SchemaDefinition, + validation_alias="inputSchema" + ) + meta: Optional[SchemaDefinition] = Field(default=None, validation_alias="_meta") - def model_post_init(self, __context: Any) -> None: - if self.title is None: - self.title = self.name + def model_post_init(self, __context: Any) -> None: + if self.title is None: + self.title = self.name class RawFoundryHostedMcpTools(BaseModel): - """Pydantic model for the result containing list of tools. + """Pydantic model for the result containing list of tools. - :ivar List[RawFoundryHostedMcpTool] tools: List of MCP tool definitions. - """ + :ivar List[RawFoundryHostedMcpTool] tools: List of MCP tool definitions. + """ - tools: List[RawFoundryHostedMcpTool] = Field(default_factory=list) + tools: List[RawFoundryHostedMcpTool] = Field(default_factory=list) class ListFoundryHostedMcpToolsResponse(BaseModel): - """Pydantic model for the complete MCP tools/list JSON-RPC response. + """Pydantic model for the complete MCP tools/list JSON-RPC response. - :ivar str jsonrpc: JSON-RPC version, defaults to "2.0". - :ivar int id: Request identifier, defaults to 0. - :ivar RawFoundryHostedMcpTools result: Result containing the list of tools. - """ + :ivar str jsonrpc: JSON-RPC version, defaults to "2.0". + :ivar int id: Request identifier, defaults to 0. + :ivar RawFoundryHostedMcpTools result: Result containing the list of tools. + """ - jsonrpc: str = "2.0" - id: int = 0 - result: RawFoundryHostedMcpTools = Field( - default_factory=RawFoundryHostedMcpTools - ) + jsonrpc: str = "2.0" + id: int = 0 + result: RawFoundryHostedMcpTools = Field( + default_factory=RawFoundryHostedMcpTools + ) class BaseConnectedToolsErrorResult(BaseModel, ABC): - """Base model for connected tools error responses.""" + """Base model for connected tools error responses.""" - @abstractmethod - def as_exception(self) -> Exception: - """Convert the error result to an appropriate exception. + @abstractmethod + def as_exception(self) -> Exception: + """Convert the error result to an appropriate exception. - :return: An exception representing the error. - :rtype: Exception - """ - raise NotImplementedError + :return: An exception representing the error. + :rtype: Exception + """ + raise NotImplementedError class OAuthConsentRequiredErrorResult(BaseConnectedToolsErrorResult): - """Model for OAuth consent required error responses. + """Model for OAuth consent required error responses. - :ivar Literal["OAuthConsentRequired"] type: Error type identifier. - :ivar Optional[str] consent_url: URL for user consent, if available. - :ivar Optional[str] message: Human-readable error message. - :ivar Optional[str] project_connection_id: Project connection ID related to the error. - """ + :ivar Literal["OAuthConsentRequired"] type: Error type identifier. + :ivar Optional[str] consent_url: URL for user consent, if available. + :ivar Optional[str] message: Human-readable error message. + :ivar Optional[str] project_connection_id: Project connection ID related to the error. 
+ """ - type: Literal["OAuthConsentRequired"] - consent_url: str = Field( - validation_alias=AliasChoices( + type: Literal["OAuthConsentRequired"] + consent_url: str = Field( + validation_alias=AliasChoices( AliasPath("toolResult", "consentUrl"), AliasPath("toolResult", "message"), ), - ) - message: str = Field( - validation_alias=AliasPath("toolResult", "message"), - ) - project_connection_id: str = Field( - validation_alias=AliasPath("toolResult", "projectConnectionId"), - ) + ) + message: str = Field( + validation_alias=AliasPath("toolResult", "message"), + ) + project_connection_id: str = Field( + validation_alias=AliasPath("toolResult", "projectConnectionId"), + ) - def as_exception(self) -> Exception: - return OAuthConsentRequiredError(self.message, self.consent_url, self.project_connection_id) + def as_exception(self) -> Exception: + return OAuthConsentRequiredError(self.message, self.consent_url, self.project_connection_id) class RawFoundryConnectedTool(BaseModel): - """Pydantic model for a single connected tool. + """Pydantic model for a single connected tool. - :ivar str name: Name of the tool. - :ivar str description: Description of the tool. - :ivar Optional[SchemaDefinition] input_schema: Input schema for the tool parameters. - """ - name: str - description: str - input_schema: SchemaDefinition = Field( - default=SchemaDefinition, - validation_alias="parameters", - ) + :ivar str name: Name of the tool. + :ivar str description: Description of the tool. + :ivar Optional[SchemaDefinition] input_schema: Input schema for the tool parameters. + """ + name: str + description: str + input_schema: SchemaDefinition = Field( + default_factory=SchemaDefinition, + validation_alias="parameters", + ) class RawFoundryConnectedRemoteServer(BaseModel): - """Pydantic model for a connected remote server. - - :ivar str protocol: Protocol used by the remote server. - :ivar str project_connection_id: Project connection ID of the remote server. - :ivar List[RawFoundryConnectedTool] tools: List of connected tools from this server. - """ - protocol: str = Field( - validation_alias=AliasPath("remoteServer", "protocol"), - ) - project_connection_id: str = Field( - validation_alias=AliasPath("remoteServer", "projectConnectionId"), - ) - tools: List[RawFoundryConnectedTool] = Field( - default_factory=list, - validation_alias="manifest", - ) + """Pydantic model for a connected remote server. + + :ivar str protocol: Protocol used by the remote server. + :ivar str project_connection_id: Project connection ID of the remote server. + :ivar List[RawFoundryConnectedTool] tools: List of connected tools from this server. + """ + protocol: str = Field( + validation_alias=AliasPath("remoteServer", "protocol"), + ) + project_connection_id: str = Field( + validation_alias=AliasPath("remoteServer", "projectConnectionId"), + ) + tools: List[RawFoundryConnectedTool] = Field( + default_factory=list, + validation_alias="manifest", + ) class ListConnectedToolsResult(BaseModel): - """Pydantic model for the result of listing connected tools. + """Pydantic model for the result of listing connected tools. - :ivar List[ConnectedRemoteServer] servers: List of connected remote servers. - """ - servers: List[RawFoundryConnectedRemoteServer] = Field( - default_factory=list, - validation_alias="tools", - ) + :ivar List[ConnectedRemoteServer] servers: List of connected remote servers. 
+ """ + servers: List[RawFoundryConnectedRemoteServer] = Field( + default_factory=list, + validation_alias="tools", + ) class ListFoundryConnectedToolsResponse(BaseModel): - """Pydantic model for the response of listing the connected tools. - - :ivar Optional[ConnectedToolsResult] result: Result containing connected tool servers. - :ivar Optional[BaseConnectedToolsErrorResult] error: Error result, if any. - """ - - result: Optional[ListConnectedToolsResult] = None - error: Optional[BaseConnectedToolsErrorResult] = None - - # noinspection DuplicatedCode - _TYPE_ADAPTER: ClassVar[TypeAdapter] = TypeAdapter( - Annotated[ - Union[ - Annotated[ - Annotated[ - Union[OAuthConsentRequiredErrorResult], - Field(discriminator="type") - ], - Tag("ErrorType") - ], - Annotated[ListConnectedToolsResult, Tag("ResultType")], - ], - Discriminator( - lambda payload: "ErrorType" if isinstance(payload, dict) and "type" in payload else "ResultType" - ), - ]) - - @model_validator(mode="wrap") - @classmethod - def _validator(cls, data: Any, handler: ModelWrapValidatorHandler) -> "ListFoundryConnectedToolsResponse": - parsed = cls._TYPE_ADAPTER.validate_python(data) - normalized = {} - if isinstance(parsed, ListConnectedToolsResult): - normalized["result"] = parsed - elif isinstance(parsed, BaseConnectedToolsErrorResult): - normalized["error"] = parsed - return handler(normalized) + """Pydantic model for the response of listing the connected tools. + + :ivar Optional[ConnectedToolsResult] result: Result containing connected tool servers. + :ivar Optional[BaseConnectedToolsErrorResult] error: Error result, if any. + """ + + result: Optional[ListConnectedToolsResult] = None + error: Optional[BaseConnectedToolsErrorResult] = None + + # noinspection DuplicatedCode + _TYPE_ADAPTER: ClassVar[TypeAdapter] = TypeAdapter( + Annotated[ + Union[ + Annotated[ + Annotated[ + Union[OAuthConsentRequiredErrorResult], + Field(discriminator="type") + ], + Tag("ErrorType") + ], + Annotated[ListConnectedToolsResult, Tag("ResultType")], + ], + Discriminator( + lambda payload: "ErrorType" if isinstance(payload, dict) and "type" in payload else "ResultType" + ), + ]) + + @model_validator(mode="wrap") + @classmethod + def _validator(cls, data: Any, handler: ModelWrapValidatorHandler) -> "ListFoundryConnectedToolsResponse": + parsed = cls._TYPE_ADAPTER.validate_python(data) + normalized = {} + if isinstance(parsed, ListConnectedToolsResult): + normalized["result"] = parsed + elif isinstance(parsed, BaseConnectedToolsErrorResult): + normalized["error"] = parsed # type: ignore[assignment] + return handler(normalized) class InvokeConnectedToolsResult(BaseModel): - """Pydantic model for the result of invoking a connected tool. + """Pydantic model for the result of invoking a connected tool. - :ivar Any value: The result value from the tool invocation. - """ - value: Any = Field(validation_alias="toolResult") + :ivar Any value: The result value from the tool invocation. + """ + value: Any = Field(validation_alias="toolResult") class InvokeFoundryConnectedToolsResponse(BaseModel): - """Pydantic model for the response of invoking a connected tool. - - :ivar Optional[InvokeConnectedToolsResult] result: Result of the tool invocation. - :ivar Optional[BaseConnectedToolsErrorResult] error: Error result, if any. 
- """ - result: Optional[InvokeConnectedToolsResult] = None - error: Optional[BaseConnectedToolsErrorResult] = None - - # noinspection DuplicatedCode - _TYPE_ADAPTER: ClassVar[TypeAdapter] = TypeAdapter( - Annotated[ - Union[ - Annotated[ - Annotated[ - Union[OAuthConsentRequiredErrorResult], - Field(discriminator="type") - ], - Tag("ErrorType") - ], - Annotated[InvokeConnectedToolsResult, Tag("ResultType")], - ], - Discriminator( - lambda payload: "ErrorType" if isinstance(payload, dict) and - # handle other error types in the future - payload.get("type") == "OAuthConsentRequired" - else "ResultType" - ), - ]) - - @model_validator(mode="wrap") - @classmethod - def _validator(cls, data: Any, handler: ModelWrapValidatorHandler) -> "InvokeFoundryConnectedToolsResponse": - parsed = cls._TYPE_ADAPTER.validate_python(data) - normalized = {} - if isinstance(parsed, InvokeConnectedToolsResult): - normalized["result"] = parsed - elif isinstance(parsed, BaseConnectedToolsErrorResult): - normalized["error"] = parsed - return handler(normalized) + """Pydantic model for the response of invoking a connected tool. + + :ivar Optional[InvokeConnectedToolsResult] result: Result of the tool invocation. + :ivar Optional[BaseConnectedToolsErrorResult] error: Error result, if any. + """ + result: Optional[InvokeConnectedToolsResult] = None + error: Optional[BaseConnectedToolsErrorResult] = None + + # noinspection DuplicatedCode + _TYPE_ADAPTER: ClassVar[TypeAdapter] = TypeAdapter( + Annotated[ + Union[ + Annotated[ + Annotated[ + Union[OAuthConsentRequiredErrorResult], + Field(discriminator="type") + ], + Tag("ErrorType") + ], + Annotated[InvokeConnectedToolsResult, Tag("ResultType")], + ], + Discriminator( + lambda payload: "ErrorType" if isinstance(payload, dict) and + # handle other error types in the future + payload.get("type") == "OAuthConsentRequired" + else "ResultType" + ), + ]) + + @model_validator(mode="wrap") + @classmethod + def _validator(cls, data: Any, handler: ModelWrapValidatorHandler) -> "InvokeFoundryConnectedToolsResponse": + parsed: Union[InvokeConnectedToolsResult, BaseConnectedToolsErrorResult] = (cls._TYPE_ADAPTER + .validate_python(data)) + normalized: Dict[str, Any] = {} + if isinstance(parsed, InvokeConnectedToolsResult): + normalized["result"] = parsed + elif isinstance(parsed, BaseConnectedToolsErrorResult): + normalized["error"] = parsed + return handler(normalized) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py index 5248ab7aa7fa..a3c552fe2575 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py @@ -10,7 +10,7 @@ from azure.core import AsyncPipelineClient from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceExistsError, \ ResourceNotFoundError, ResourceNotModifiedError, map_error -from azure.core.rest import AsyncHttpResponse, HttpRequest +from azure.core.pipeline.transport import AsyncHttpResponse, HttpRequest ErrorMapping = MutableMapping[int, Type[HttpResponseError]] @@ -67,7 +67,7 @@ def _extract_response_json(self, response: AsyncHttpResponse) -> Any: try: payload_text = response.text() payload_json = json.loads(payload_text) if payload_text else {} - except AttributeError as e: + except AttributeError: 
payload_bytes = response.body() payload_json = json.loads(payload_bytes.decode("utf-8")) if payload_bytes else {} return payload_json \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_hosted_mcp_tools.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_hosted_mcp_tools.py index 0c01164a6809..08587e274096 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_hosted_mcp_tools.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_hosted_mcp_tools.py @@ -2,9 +2,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- from abc import ABC -from typing import Any, AsyncIterable, ClassVar, Dict, Iterable, List, Mapping, TYPE_CHECKING, Tuple, cast +from typing import Any, AsyncIterable, ClassVar, Dict, Iterable, List, Tuple, cast -from azure.core.rest import HttpRequest +from azure.core.pipeline.transport import HttpRequest from azure.core.tracing.decorator_async import distributed_trace_async from ._base import BaseOperations @@ -88,7 +88,7 @@ def _convert_listed_tools( def _build_invoke_tool_request(self, tool: ResolvedFoundryTool, arguments: Dict[str, Any]) -> HttpRequest: if tool.definition.source != FoundryToolSource.HOSTED_MCP: raise ToolInvocationError(f"Tool {tool.name} is not a Foundry-hosted MCP tool.", tool=tool) - definition = cast(FoundryHostedMcpTool, tool.definition) if TYPE_CHECKING else tool.definition + definition = cast(FoundryHostedMcpTool, tool.definition) payload = dict(self._INVOKE_TOOL_REQUEST_BODY_TEMPLATE) payload["params"] = { diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_catalog.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_catalog.py index 17eb8c2eec48..2d50089fef8f 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_catalog.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_catalog.py @@ -1,14 +1,11 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- -import asyncio -import threading +import asyncio # pylint: disable=C4763 from abc import ABC, abstractmethod -from collections import defaultdict -from concurrent.futures import Future -from typing import Any, Awaitable, Collection, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import Any, Awaitable, Collection, List, Mapping, MutableMapping, Optional, Union -from cachetools import TTLCache +from cachetools import TTLCache # type: ignore[import-untyped] from ._facade import FoundryToolLike, ensure_foundry_tool from ._user import UserProvider @@ -93,7 +90,7 @@ async def list(self, tools: List[FoundryToolLike]) -> List[ResolvedFoundryTool]: await asyncio.gather(*fetching_tasks) except: # exception can only be caused by fetching tasks, remove them from cache - for k in tools_to_fetch.keys(): + for k, _ in tools_to_fetch.items(): if k in self._cache: del self._cache[k] raise diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py index ebaca87cf1a7..f12d3f0db7b5 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py @@ -13,7 +13,8 @@ # Required: # - "type": str Discriminator, e.g. "mcp" | "a2a" | "code_interpreter" | ... # Optional: -# - "project_connection_id": str Project connection id of Foundry connected tools, required with "type" is "mcp" or a2a. +# - "project_connection_id": str Project connection id of Foundry connected tools, +# required when "type" is "mcp" or "a2a". # # Custom keys: # - Allowed, but MUST NOT shadow reserved keys. @@ -45,5 +46,5 @@ def ensure_foundry_tool(tool: FoundryToolLike) -> FoundryTool: raise InvalidToolFacadeError(f"project_connection_id is required for tool protocol {protocol}.") return FoundryConnectedTool(protocol=protocol, project_connection_id=project_connection_id) - except: + except ValueError: return FoundryHostedMcpTool(name=tool_type, configuration=tool) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py index 2764558b06bb..24eb0fabbb21 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- from abc import ABC, abstractmethod -from typing import Awaitable, Union, overload +from typing import Optional, Union from ._catalog import FoundryToolCatalog from ._facade import FoundryToolLike, ensure_foundry_tool @@ -49,9 +49,12 @@ async def resolve(self, tool: Union[FoundryToolLike, ResolvedFoundryTool]) -> Fo :return: The resolved Foundry tool invoker. 
:rtype: FoundryToolInvoker """ - resolved_tool = (tool - if isinstance(tool, ResolvedFoundryTool) - else await self._catalog.get(ensure_foundry_tool(tool))) - if not resolved_tool: - raise UnableToResolveToolInvocationError(f"Unable to resolve tool {tool} from catalog", tool) - return DefaultFoundryToolInvoker(resolved_tool, self._client, self._user_provider, self._agent_name) \ No newline at end of file + if isinstance(tool, ResolvedFoundryTool): + resolved_tool = tool + else: + foundry_tool = ensure_foundry_tool(tool) + resolved_tool = await self._catalog.get(foundry_tool) # type: ignore[assignment] + if not resolved_tool: + raise UnableToResolveToolInvocationError(f"Unable to resolve tool {foundry_tool} from catalog", + foundry_tool) + return DefaultFoundryToolInvoker(resolved_tool, self._client, self._user_provider, self._agent_name) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py index 8ff723a6f7dc..8bc77759ecd6 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py @@ -71,12 +71,18 @@ def __init__(self, @property def catalog(self) -> FoundryToolCatalog: - """The tool catalog.""" + """The tool catalog. + + :rtype: FoundryToolCatalog + """ return self._catalog @property def invocation(self) -> FoundryToolInvocationResolver: - """The tool invocation resolver.""" + """The tool invocation resolver. + + :rtype: FoundryToolInvocationResolver + """ return self._invocation async def __aenter__(self) -> "DefaultFoundryToolRuntime": diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py index 17b25095a953..f60fb63f2cdc 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py @@ -39,7 +39,7 @@ def install(cls, If not provided, a default resolver will be used. :type user_resolver: Optional[Callable[[Request], Awaitable[Optional[UserInfo]]]] """ - app.add_middleware(UserInfoContextMiddleware, + app.add_middleware(UserInfoContextMiddleware, # type: ignore[arg-type] user_info_var=user_context or ContextVarUserProvider.default_user_info_context, user_resolver=user_resolver or cls._default_user_resolver) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py index 14d8aad2690a..f72b30c0d3d3 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py @@ -13,7 +13,11 @@ class UserProvider(ABC): @abstractmethod async def get_user(self) -> Optional[UserInfo]: - """Get the user information.""" + """Get the user information. + + :return: The user information or None if not found. 
+ :rtype: Optional[UserInfo] + """ raise NotImplementedError @@ -25,7 +29,11 @@ def __init__(self, context: Optional[ContextVar[UserInfo]] = None): self.context = context or self.default_user_info_context async def get_user(self) -> Optional[UserInfo]: - """Get the user information from the context variable.""" + """Get the user information from the context variable. + + :return: The user information or None if not found. + :rtype: Optional[UserInfo] + """ return self.context.get(None) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py index 41fc7e00dd6d..037fb1dc04de 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py @@ -4,4 +4,8 @@ __path__ = __import__('pkgutil').extend_path(__path__, __name__) -from ._name_resolver import * +from ._name_resolver import ToolNameResolver + +__all__ = [ + "ToolNameResolver", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py index ab9c87fd113c..9f1b7874f52c 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py @@ -8,8 +8,8 @@ class ToolNameResolver: """Utility class for resolving tool names to be registered to model.""" def __init__(self): - self._count_by_name = dict() - self._stable_names = dict() + self._count_by_name = {} + self._stable_names = {} def resolve(self, tool: ResolvedFoundryTool) -> str: """Resolve a stable name for the given tool. diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py index 24de2e1345a4..398a8c46fd5d 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py @@ -3,17 +3,27 @@ # --------------------------------------------------------- from __future__ import annotations -import asyncio +import asyncio # pylint: disable=C4763 import inspect from types import TracebackType -from typing import Any, Optional, Sequence, Type, Union +from typing import Any, Type, cast from azure.core.credentials import AccessToken, TokenCredential from azure.core.credentials_async import AsyncTokenCredential -async def _to_thread(func, *args, **kwargs): - """Compatibility wrapper for asyncio.to_thread (Python 3.8+).""" +async def _to_thread(func, *args, **kwargs): # pylint: disable=C4743 + """Compatibility wrapper for asyncio.to_thread (Python 3.8+). + + :param func: The function to run in a thread. + :type func: Callable + :param args: Positional arguments to pass to the function. + :type args: Any + :param kwargs: Keyword arguments to pass to the function. + :type kwargs: Any + :return: The result of the function call. 
+ :rtype: Any + """ if hasattr(asyncio, "to_thread"): return await asyncio.to_thread(func, *args, **kwargs) # py>=3.9 loop = asyncio.get_running_loop() @@ -27,7 +37,7 @@ class AsyncTokenCredentialAdapter(AsyncTokenCredential): - azure.core.credentials_async.AsyncTokenCredential (async) """ - def __init__(self, credential: TokenCredential |AsyncTokenCredential) -> None: + def __init__(self, credential: TokenCredential | AsyncTokenCredential) -> None: if not hasattr(credential, "get_token"): raise TypeError("credential must have a get_token method") self._credential = credential @@ -44,11 +54,12 @@ async def get_token( **kwargs: Any, ) -> AccessToken: if self._is_async: - return await self._credential.get_token(*scopes, - claims=claims, - tenant_id=tenant_id, - enable_cae=enable_cae, - **kwargs) + cred = cast(AsyncTokenCredential, self._credential) + return await cred.get_token(*scopes, + claims=claims, + tenant_id=tenant_id, + enable_cae=enable_cae, + **kwargs) return await _to_thread(self._credential.get_token, *scopes, claims=claims, @@ -86,4 +97,4 @@ async def __aexit__( aexit = getattr(self._credential, "__aexit__", None) if aexit is not None and inspect.iscoroutinefunction(aexit): return await aexit(exc_type, exc_value, traceback) - await self.close() \ No newline at end of file + await self.close() From 06f568ea0ba051e6510e836a1162fc7a71a4211f Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Mon, 19 Jan 2026 19:59:59 -0800 Subject: [PATCH 05/29] refining agent framework adapters --- .../agentframework/_agent_framework.py | 35 +++++++++++++++++-- .../agentframework/_ai_agent_adapter.py | 21 +++-------- .../agentframework/_workflow_agent_adapter.py | 26 +++++--------- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py index 7407fc96131c..a8f1a8f7eb83 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -7,7 +7,7 @@ import os from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List -from agent_framework import AgentProtocol, AIFunction +from agent_framework import AgentProtocol, AIFunction, AgentThread, WorkflowAgent from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module from opentelemetry import trace @@ -223,4 +223,35 @@ async def agent_run( # pylint: disable=too-many-statements OpenAIResponse, AsyncGenerator[ResponseStreamEvent, Any], ]: - raise NotImplementedError("This method is implemented in the base class.") \ No newline at end of file + raise NotImplementedError("This method is implemented in the base class.") + + async def _load_agent_thread(self, context: AgentRunContext, agent: Union[AgentProtocol, WorkflowAgent]) -> Optional[AgentThread]: + """Load the agent thread for a given conversation ID. + + :param context: The agent run context. + :type context: AgentRunContext + :param agent: The agent instance. + :type agent: AgentProtocol | WorkflowAgent + + :return: The loaded AgentThread if available, None otherwise. 
+ :rtype: Optional[AgentThread] + """ + if self._thread_repository: + agent_thread = await self._thread_repository.get(context.conversation_id) + if agent_thread: + logger.info(f"Loaded agent thread for conversation: {context.conversation_id}") + return agent_thread + return agent.get_new_thread() + return None + + async def _save_agent_thread(self, context: AgentRunContext, agent_thread: AgentThread) -> None: + """Save the agent thread for a given conversation ID. + + :param context: The agent run context. + :type context: AgentRunContext + :param agent_thread: The agent thread to save. + :type agent_thread: AgentThread + """ + if agent_thread and self._thread_repository: + await self._thread_repository.set(context.conversation_id, agent_thread) + logger.info(f"Saved agent thread for conversation: {context.conversation_id}") \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py index 5f20367f5a0f..eb15ff4f6b64 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py @@ -47,14 +47,8 @@ async def agent_run( # pylint: disable=too-many-statements logger.info(f"Starting agent_run with stream={context.stream}") request_input = context.request.get("input") - agent_thread = None - if self._thread_repository: - agent_thread = await self._thread_repository.get(context.conversation_id) - if agent_thread: - logger.info(f"Loaded agent thread for conversation: {context.conversation_id}") - else: - agent_thread = self.agent.get_new_thread() - + agent_thread = self._load_agent_thread(context, self._agent) + input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper) message = await input_converter.transform_input( request_input, @@ -78,9 +72,7 @@ async def stream_updates(): update_count += 1 yield event - if agent_thread and self._thread_repository: - await self._thread_repository.set(context.conversation_id, agent_thread) - logger.info(f"Saved agent thread for conversation: {context.conversation_id}") + await self._save_agent_thread(context, agent_thread) logger.info("Streaming completed with %d updates", update_count) except OAuthConsentRequiredError as e: @@ -124,16 +116,13 @@ async def stream_updates(): # Non-streaming path logger.info("Running agent in non-streaming mode") - non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper) result = await self.agent.run( message, thread=agent_thread) logger.debug(f"Agent run completed, result type: {type(result)}") + await self._save_agent_thread(context, agent_thread) - if agent_thread and self._thread_repository: - await self._thread_repository.set(context.conversation_id, agent_thread) - logger.info(f"Saved agent thread for conversation: {context.conversation_id}") - + non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper) transformed_result = non_streaming_converter.transform_output_for_response(result) logger.info("Agent run and transformation completed successfully") return transformed_result diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py 
b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py
index 45089f88c012..c3ec658bda7c 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py
@@ -50,23 +50,17 @@ async def agent_run(  # pylint: disable=too-many-statements
             logger.info(f"Starting agent_run with stream={context.stream}")
             request_input = context.request.get("input")
 
-            agent_thread = None
+            agent_thread = self._load_agent_thread(context, agent)
+
             checkpoint_storage = None
             last_checkpoint = None
-            if self._thread_repository:
-                agent_thread = await self._thread_repository.get(context.conversation_id, agent=agent)
-                if agent_thread:
-                    logger.info(f"Loaded agent thread for conversation: {context.conversation_id}")
-                else:
-                    agent_thread = agent.get_new_thread()
-
             if self._checkpoint_repository:
                 checkpoint_storage = await self._checkpoint_repository.get_or_create(context.conversation_id)
                 last_checkpoint = await self._get_latest_checkpoint(checkpoint_storage)
                 if last_checkpoint:
                     summary = get_checkpoint_summary(last_checkpoint)
                     if summary.status == "completed":
-                        logger.warning("Last checkpoint is completed. Will not resume from it.")
+                        logger.warning(f"Latest checkpoint {last_checkpoint.checkpoint_id} is completed. Will not resume from it.")
                         last_checkpoint = None  # Do not resume from completed checkpoints
                 if last_checkpoint:
                     await self._load_checkpoint(agent, last_checkpoint, checkpoint_storage)
@@ -97,9 +91,7 @@ async def stream_updates():
                         update_count += 1
                         yield event
 
-                    if agent_thread and self._thread_repository:
-                        await self._thread_repository.set(context.conversation_id, agent_thread)
-                        logger.info(f"Saved agent thread for conversation: {context.conversation_id}")
+                    await self._save_agent_thread(context, agent_thread)
 
                     logger.info("Streaming completed with %d updates", update_count)
                 except OAuthConsentRequiredError as e:
@@ -142,18 +134,16 @@ async def stream_updates():
             return stream_updates()
 
             # Non-streaming path
-            logger.info("Running agent in non-streaming mode")
-            non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper)
+            logger.info("Running WorkflowAgent in non-streaming mode")
             result = await agent.run(
                 message,
                 thread=agent_thread,
                 checkpoint_storage=checkpoint_storage)
-            logger.debug(f"Agent run completed, result type: {type(result)}")
+            logger.debug(f"WorkflowAgent run completed, result type: {type(result)}")
 
-            if agent_thread and self._thread_repository:
-                await self._thread_repository.set(context.conversation_id, agent_thread)
-                logger.info(f"Saved agent thread for conversation: {context.conversation_id}")
+            await self._save_agent_thread(context, agent_thread)
 
+            non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper)
             transformed_result = non_streaming_converter.transform_output_for_response(result)
             logger.info("Agent run and transformation completed successfully")
             return transformed_result
From 2a94c4082736663f6eb0fefdba755d4784e90e7c Mon Sep 17 00:00:00 2001
From: Lu Sun
Date: Tue, 20 Jan 2026 09:46:44 -0800
Subject: [PATCH 06/29] refining adapters

---
 .../agentframework/_agent_framework.py        | 60 ++++++++++++++-
 .../agentframework/_ai_agent_adapter.py       | 73 +++----------------
 .../agentframework/_workflow_agent_adapter.py | 70 +++---------
...nt_framework_output_streaming_converter.py | 23 ++++-- 4 files changed, 96 insertions(+), 130 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py index a8f1a8f7eb83..2bae98b31af9 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -5,7 +5,7 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List, Callable from agent_framework import AgentProtocol, AIFunction, AgentThread, WorkflowAgent from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module @@ -254,4 +254,60 @@ async def _save_agent_thread(self, context: AgentRunContext, agent_thread: Agent """ if agent_thread and self._thread_repository: await self._thread_repository.set(context.conversation_id, agent_thread) - logger.info(f"Saved agent thread for conversation: {context.conversation_id}") \ No newline at end of file + logger.info(f"Saved agent thread for conversation: {context.conversation_id}") + + def _run_streaming_updates( + self, + *, + context: AgentRunContext, + run_stream: Callable[[], AsyncGenerator[Any, None]], + agent_thread: Optional[AgentThread] = None, + ) -> AsyncGenerator[ResponseStreamEvent, Any]: + """Execute a streaming run with shared OAuth/error handling.""" + logger.info("Running agent in streaming mode") + streaming_converter = AgentFrameworkOutputStreamingConverter(context, hitl_helper=self._hitl_helper) + + async def stream_updates(): + try: + update_count = 0 + try: + updates = run_stream() + async for event in streaming_converter.convert(updates): + update_count += 1 + yield event + + await self._save_agent_thread(context, agent_thread) + logger.info("Streaming completed with %d updates", update_count) + except OAuthConsentRequiredError as e: + logger.info("OAuth consent required during streaming updates") + if update_count == 0: + async for event in self.respond_with_oauth_consent_astream(context, e): + yield event + else: + yield ResponseErrorEvent( + sequence_number=streaming_converter.next_sequence(), + code="server_error", + message=f"OAuth consent required: {e.consent_url}", + param="agent_run", + ) + yield ResponseFailedEvent( + sequence_number=streaming_converter.next_sequence(), + response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True) + yield ResponseErrorEvent( + sequence_number=streaming_converter.next_sequence(), + code="server_error", + message=str(e), + param="agent_run", + ) + yield ResponseFailedEvent( + sequence_number=streaming_converter.next_sequence(), + response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access + ) + finally: + # No request-scoped resources to clean up today, but keep hook for future use. 
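+                # Illustrative, hypothetical hook usage: if run_stream() ever owns
+                # per-request resources, releasing them here (for example an
+                # `await updates.aclose()` guarded by an existence check) would
+                # prevent leaks when a client abandons the stream early.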
+ pass + + return stream_updates() \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py index eb15ff4f6b64..3b2da66a939e 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py @@ -4,8 +4,7 @@ # pylint: disable=logging-fstring-interpolation,no-name-in-module,no-member,do-not-import-asyncio from __future__ import annotations -import os -from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Union +from typing import Any, AsyncGenerator, Optional, Union from agent_framework import AgentProtocol @@ -16,13 +15,11 @@ Response as OpenAIResponse, ResponseStreamEvent, ) -from azure.ai.agentserver.core.models.projects import ResponseErrorEvent, ResponseFailedEvent from .models.agent_framework_input_converters import AgentFrameworkInputConverter from .models.agent_framework_output_non_streaming_converter import ( AgentFrameworkOutputNonStreamingConverter, ) -from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter from ._agent_framework import AgentFrameworkCBAgent from .persistence import AgentThreadRepository @@ -47,8 +44,8 @@ async def agent_run( # pylint: disable=too-many-statements logger.info(f"Starting agent_run with stream={context.stream}") request_input = context.request.get("input") - agent_thread = self._load_agent_thread(context, self._agent) - + agent_thread = await self._load_agent_thread(context, self._agent) + input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper) message = await input_converter.transform_input( request_input, @@ -57,62 +54,14 @@ async def agent_run( # pylint: disable=too-many-statements # Use split converters if context.stream: - logger.info("Running agent in streaming mode") - streaming_converter = AgentFrameworkOutputStreamingConverter(context, hitl_helper=self._hitl_helper) - - async def stream_updates(): - try: - update_count = 0 - try: - updates = self.agent.run_stream( - message, - thread=agent_thread, - ) - async for event in streaming_converter.convert(updates): - update_count += 1 - yield event - - await self._save_agent_thread(context, agent_thread) - - logger.info("Streaming completed with %d updates", update_count) - except OAuthConsentRequiredError as e: - logger.info("OAuth consent required during streaming updates") - if update_count == 0: - async for event in self.respond_with_oauth_consent_astream(context, e): - yield event - else: - # If we've already emitted events, we cannot safely restart a new - # OAuth-consent stream (it would reset sequence numbers). - yield ResponseErrorEvent( - sequence_number=streaming_converter.next_sequence(), - code="server_error", - message=f"OAuth consent required: {e.consent_url}", - param="agent_run", - ) - yield ResponseFailedEvent( - sequence_number=streaming_converter.next_sequence(), - response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access - ) - except Exception as e: # pylint: disable=broad-exception-caught - logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True) - - # Emit well-formed error events instead of terminating the stream. 
- yield ResponseErrorEvent( - sequence_number=streaming_converter.next_sequence(), - code="server_error", - message=str(e), - param="agent_run", - ) - yield ResponseFailedEvent( - sequence_number=streaming_converter.next_sequence(), - response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access - ) - finally: - # No request-scoped resources to clean up here today. - # Keep this block as a hook for future request-scoped cleanup. - pass - - return stream_updates() + return self._run_streaming_updates( + context=context, + run_stream=lambda: self.agent.run_stream( + message, + thread=agent_thread, + ), + agent_thread=agent_thread, + ) # Non-streaming path logger.info("Running agent in non-streaming mode") diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py index c3ec658bda7c..064acfc8fd22 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py @@ -12,14 +12,12 @@ Response as OpenAIResponse, ResponseStreamEvent, ) -from azure.ai.agentserver.core.models.projects import ResponseErrorEvent, ResponseFailedEvent from ._agent_framework import AgentFrameworkCBAgent from .models.agent_framework_input_converters import AgentFrameworkInputConverter from .models.agent_framework_output_non_streaming_converter import ( AgentFrameworkOutputNonStreamingConverter, ) -from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter from .persistence.agent_thread_repository import AgentThreadRepository from .persistence.checkpoint_repository import CheckpointRepository @@ -50,7 +48,7 @@ async def agent_run( # pylint: disable=too-many-statements logger.info(f"Starting agent_run with stream={context.stream}") request_input = context.request.get("input") - agent_thread = self._load_agent_thread(context, agent) + agent_thread = await self._load_agent_thread(context, agent) checkpoint_storage = None last_checkpoint = None @@ -75,63 +73,15 @@ async def agent_run( # pylint: disable=too-many-statements # Use split converters if context.stream: - logger.info("Running agent in streaming mode") - streaming_converter = AgentFrameworkOutputStreamingConverter(context, hitl_helper=self._hitl_helper) - - async def stream_updates(): - try: - update_count = 0 - try: - updates = agent.run_stream( - message, - thread=agent_thread, - checkpoint_storage=checkpoint_storage, - ) - async for event in streaming_converter.convert(updates): - update_count += 1 - yield event - - await self._save_agent_thread(context, agent_thread) - - logger.info("Streaming completed with %d updates", update_count) - except OAuthConsentRequiredError as e: - logger.info("OAuth consent required during streaming updates") - if update_count == 0: - async for event in self.respond_with_oauth_consent_astream(context, e): - yield event - else: - # If we've already emitted events, we cannot safely restart a new - # OAuth-consent stream (it would reset sequence numbers). 
- yield ResponseErrorEvent( - sequence_number=streaming_converter.next_sequence(), - code="server_error", - message=f"OAuth consent required: {e.consent_url}", - param="agent_run", - ) - yield ResponseFailedEvent( - sequence_number=streaming_converter.next_sequence(), - response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access - ) - except Exception as e: # pylint: disable=broad-exception-caught - logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True) - - # Emit well-formed error events instead of terminating the stream. - yield ResponseErrorEvent( - sequence_number=streaming_converter.next_sequence(), - code="server_error", - message=str(e), - param="agent_run", - ) - yield ResponseFailedEvent( - sequence_number=streaming_converter.next_sequence(), - response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access - ) - finally: - # No request-scoped resources to clean up here today. - # Keep this block as a hook for future request-scoped cleanup. - pass - - return stream_updates() + return self._run_streaming_updates( + context=context, + run_stream=lambda: agent.run_stream( + message, + thread=agent_thread, + checkpoint_storage=checkpoint_storage, + ), + agent_thread=agent_thread, + ) # Non-streaming path logger.info("Running WorkflowAgent in non-streaming mode") diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py index 503e1c29bfbd..f089fd672cd8 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py @@ -48,6 +48,9 @@ from .human_in_the_loop_helper import HumanInTheLoopHelper from .utils.async_iter import chunk_on_change, peek +from azure.ai.agentserver.core.logger import get_logger +logger = get_logger() + class _BaseStreamingState: """Base interface for streaming state handlers.""" @@ -147,6 +150,7 @@ async def convert_contents( hitl_contents = [] async for content in contents: + logger.info("Processing content %s: %s", type(content).__name__, content.to_dict()) if isinstance(content, FunctionCallContent): if content.call_id not in content_by_call_id: item_id = self._parent.context.id_generator.generate_function_call_id() @@ -186,7 +190,7 @@ async def convert_contents( for call_id, content in content_by_call_id.items(): item_id, output_index = ids_by_call_id[call_id] - args = content.arguments if isinstance(content.arguments, str) else json.dumps(content.arguments) + args = self._serialize_arguments(content.arguments) yield ResponseFunctionCallArgumentsDoneEvent( sequence_number=self._parent.next_sequence(), item_id=item_id, @@ -255,6 +259,17 @@ async def convert_contents( ) self._parent.add_completed_output_item(item) + def _serialize_arguments(self, arguments: Any) -> str: + if isinstance(arguments, str): + return arguments + if hasattr(arguments, "to_dict"): + arguments = arguments.to_dict() + try: + return json.dumps(arguments) + except Exception as e: + logger.error(f"Failed to serialize function call arguments: {e}") + return str(arguments) + class 
_FunctionCallOutputStreamingState(_BaseStreamingState):
     """Handles function_call_output items streaming (non-chunked simple output)."""
 
@@ -355,11 +370,7 @@ async def convert(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> Async
         is_changed = (
             lambda a, b: a is not None \
                 and b is not None \
-                and (a.message_id != b.message_id \
-                    or (
-                        a.contents and b.contents \
-                        and type(a.contents[0]) != type(b.contents[0]))
-                )  # pylint: disable=unnecessary-lambda-assignment
+                and a.message_id != b.message_id  # pylint: disable=unnecessary-lambda-assignment
         )
 
         async for group in chunk_on_change(updates, is_changed):
From 18acd9f922dd0af6010be266bc19874c81add73f Mon Sep 17 00:00:00 2001
From: Lu Sun
Date: Tue, 20 Jan 2026 10:08:55 -0800
Subject: [PATCH 07/29] minor updates

---
 .../agentframework/_agent_framework.py        |  2 +-
 .../agentframework/_ai_agent_adapter.py       |  7 ++--
 .../agentframework/_workflow_agent_adapter.py | 34 +++++++++----------
 ...nt_framework_output_streaming_converter.py |  7 ----
 4 files changed, 21 insertions(+), 29 deletions(-)

diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py
index 2bae98b31af9..6393a7e9b42f 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py
@@ -55,7 +55,7 @@ def __call__(self, tools: List[AIFunction]) -> Union[AgentProtocol, Awaitable[Ag
     ...
 
 
-class AgentFrameworkCBAgent(FoundryCBAgent):
+class AgentFrameworkAgent(FoundryCBAgent):
     """
     Adapter class for integrating Agent Framework agents with the FoundryCB agent interface.
 
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py
index 3b2da66a939e..6105470dbdc9 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py
@@ -1,7 +1,6 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
# --------------------------------------------------------- -# pylint: disable=logging-fstring-interpolation,no-name-in-module,no-member,do-not-import-asyncio from __future__ import annotations from typing import Any, AsyncGenerator, Optional, Union @@ -20,12 +19,12 @@ from .models.agent_framework_output_non_streaming_converter import ( AgentFrameworkOutputNonStreamingConverter, ) -from ._agent_framework import AgentFrameworkCBAgent +from ._agent_framework import AgentFrameworkAgent from .persistence import AgentThreadRepository logger = get_logger() -class AgentFrameworkAIAgentAdapter(AgentFrameworkCBAgent): +class AgentFrameworkAIAgentAdapter(AgentFrameworkAgent): def __init__(self, agent: AgentProtocol, *, thread_repository: Optional[AgentThreadRepository]=None, @@ -41,7 +40,7 @@ async def agent_run( # pylint: disable=too-many-statements AsyncGenerator[ResponseStreamEvent, Any], ]: try: - logger.info(f"Starting agent_run with stream={context.stream}") + logger.info(f"Starting AIAgent agent_run with stream={context.stream}") request_input = context.request.get("input") agent_thread = await self._load_agent_thread(context, self._agent) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py index 064acfc8fd22..fc917a3ea9bf 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py @@ -1,4 +1,6 @@ - +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List
 
 from agent_framework import WorkflowBuilder, CheckpointStorage, WorkflowAgent, WorkflowCheckpoint
@@ -8,12 +10,11 @@
 from azure.ai.agentserver.core import AgentRunContext
 from azure.ai.agentserver.core.logger import get_logger
 from azure.ai.agentserver.core.models import (
-    CreateResponse,
     Response as OpenAIResponse,
     ResponseStreamEvent,
 )
 
-from ._agent_framework import AgentFrameworkCBAgent
+from ._agent_framework import AgentFrameworkAgent
 from .models.agent_framework_input_converters import AgentFrameworkInputConverter
 from .models.agent_framework_output_non_streaming_converter import (
     AgentFrameworkOutputNonStreamingConverter,
@@ -23,7 +24,7 @@
 
 logger = get_logger()
 
-class AgentFrameworkWorkflowAdapter(AgentFrameworkCBAgent):
+class AgentFrameworkWorkflowAdapter(AgentFrameworkAgent):
     """Adapter to run WorkflowBuilder agents within the Agent Framework CBAgent structure."""
 
     def __init__(self, workflow_builder: WorkflowBuilder,
@@ -45,30 +46,30 @@ async def agent_run(  # pylint: disable=too-many-statements
         try:
             agent = self._build_agent()
 
-            logger.info(f"Starting agent_run with stream={context.stream}")
+            logger.info(f"Starting WorkflowAgent agent_run with stream={context.stream}")
             request_input = context.request.get("input")
 
             agent_thread = await self._load_agent_thread(context, agent)
 
             checkpoint_storage = None
-            last_checkpoint = None
+            selected_checkpoint = None
             if self._checkpoint_repository:
                 checkpoint_storage = await self._checkpoint_repository.get_or_create(context.conversation_id)
-                last_checkpoint = await self._get_latest_checkpoint(checkpoint_storage)
-                if last_checkpoint:
-                    summary = get_checkpoint_summary(last_checkpoint)
-                    if summary.status == "completed":
-                        logger.warning(f"Latest checkpoint {last_checkpoint.checkpoint_id} is completed. Will not resume from it.")
-                        last_checkpoint = None  # Do not resume from completed checkpoints
-                if last_checkpoint:
-                    await self._load_checkpoint(agent, last_checkpoint, checkpoint_storage)
-                    logger.info(f"Loaded checkpoint with ID: {last_checkpoint.checkpoint_id}")
+                selected_checkpoint = await self._get_latest_checkpoint(checkpoint_storage)
+                if selected_checkpoint:
+                    summary = get_checkpoint_summary(selected_checkpoint)
+                    if summary.status == "completed":
+                        logger.warning(f"Selected checkpoint {selected_checkpoint.checkpoint_id} is completed. Will not resume from it.")
+                        selected_checkpoint = None  # Do not resume from completed checkpoints
+                    else:
+                        await self._load_checkpoint(agent, selected_checkpoint, checkpoint_storage)
+                        logger.info(f"Loaded checkpoint with ID: {selected_checkpoint.checkpoint_id}")
 
             input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper)
             message = await input_converter.transform_input(
                 request_input,
                 agent_thread=agent_thread,
-                checkpoint=last_checkpoint)
+                checkpoint=selected_checkpoint)
             logger.debug(f"Transformed input message type: {type(message)}")
 
             # Use split converters
@@ -137,6 +138,5 @@ async def _load_checkpoint(self, agent: WorkflowAgent,
         :param checkpoint: The WorkflowCheckpoint to load data from.
:type checkpoint: WorkflowCheckpoint """ - logger.info(f"Loading checkpoint ID: {checkpoint.to_dict()} into agent.") await agent.run(checkpoint_id=checkpoint.checkpoint_id, checkpoint_storage=checkpoint_storage) \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py index f089fd672cd8..253b0fc7aa9e 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py @@ -48,9 +48,6 @@ from .human_in_the_loop_helper import HumanInTheLoopHelper from .utils.async_iter import chunk_on_change, peek -from azure.ai.agentserver.core.logger import get_logger -logger = get_logger() - class _BaseStreamingState: """Base interface for streaming state handlers.""" @@ -150,7 +147,6 @@ async def convert_contents( hitl_contents = [] async for content in contents: - logger.info("Processing content %s: %s", type(content).__name__, content.to_dict()) if isinstance(content, FunctionCallContent): if content.call_id not in content_by_call_id: item_id = self._parent.context.id_generator.generate_function_call_id() @@ -262,12 +258,9 @@ async def convert_contents( def _serialize_arguments(self, arguments: Any) -> str: if isinstance(arguments, str): return arguments - if hasattr(arguments, "to_dict"): - arguments = arguments.to_dict() try: return json.dumps(arguments) except Exception as e: - logger.error(f"Failed to serialize function call arguments: {e}") return str(arguments) From 8fde0d9a5e4ed088851ffc587f51f22c706dd460 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Tue, 20 Jan 2026 16:43:48 -0800 Subject: [PATCH 08/29] update from_agent_framework --- .../ai/agentserver/agentframework/__init__.py | 95 +++++- .../agentframework/_agent_framework.py | 15 +- .../agentframework/_workflow_agent_adapter.py | 12 +- .../samples/workflow_agent_simple/README.md | 286 ++-------------- .../workflow_agent_simple.py | 308 ++---------------- 5 files changed, 167 insertions(+), 549 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py index ed655ea7595d..1a276a14ff9e 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py @@ -3,9 +3,9 @@ # --------------------------------------------------------- __path__ = __import__("pkgutil").extend_path(__path__, __name__) -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload -from agent_framework import AgentProtocol, WorkflowBuilder +from agent_framework import AgentProtocol, BaseAgent, Workflow, WorkflowBuilder from azure.ai.agentserver.agentframework._version import VERSION from azure.ai.agentserver.agentframework._agent_framework import AgentFrameworkAgent @@ -18,17 +18,98 @@ from azure.core.credentials_async import AsyncTokenCredential +@overload def from_agent_framework( - agent: 
Union[AgentProtocol, WorkflowBuilder],
+    *,
+    agent: Union[BaseAgent, AgentProtocol],
+    credentials: Optional["AsyncTokenCredential"] = None,
+    **kwargs: Any,
+    ) -> "AgentFrameworkAIAgentAdapter":
+    """
+    Create an Agent Framework AI Agent Adapter from an AgentProtocol or BaseAgent.
+
+    :param agent: The agent to adapt.
+    :type agent: Union[BaseAgent, AgentProtocol]
+    :param credentials: Optional asynchronous token credential for authentication.
+    :type credentials: Optional[AsyncTokenCredential]
+    :param kwargs: Additional keyword arguments to pass to the adapter.
+    :type kwargs: Any
+
+    :return: An instance of AgentFrameworkAIAgentAdapter.
+    :rtype: AgentFrameworkAIAgentAdapter
+    """
+    ...
+
+@overload
+def from_agent_framework(
+    *,
+    workflow: Union[WorkflowBuilder, Callable[[], Workflow]],
+    credentials: Optional["AsyncTokenCredential"] = None,
+    **kwargs: Any,
+    ) -> "AgentFrameworkWorkflowAdapter":
+    """
+    Create an Agent Framework Workflow Adapter.
+    The argument `workflow` can be either a WorkflowBuilder or a factory function
+    that returns a Workflow.
+    It will be called to create a new Workflow instance and `.as_agent()` will be
+    called as well for each incoming CreateResponse request. Please ensure that the
+    workflow definition can be converted to a WorkflowAgent. For more information,
+    see the agent-framework samples and documentation.
+
+    :param workflow: The workflow builder or factory function to adapt.
+    :type workflow: Union[WorkflowBuilder, Callable[[], Workflow]]
+    :param credentials: Optional asynchronous token credential for authentication.
+    :type credentials: Optional[AsyncTokenCredential]
+    :param kwargs: Additional keyword arguments to pass to the adapter.
+    :type kwargs: Any
+    :return: An instance of AgentFrameworkWorkflowAdapter.
+    :rtype: AgentFrameworkWorkflowAdapter
+    """
+    ...
+
+def from_agent_framework(
+    *,
+    agent: Optional[Union[BaseAgent, AgentProtocol]] = None,
+    workflow: Optional[Union[WorkflowBuilder, Callable[[], Workflow]]] = None,
     credentials: Optional["AsyncTokenCredential"] = None,
     **kwargs: Any,
     ) -> "AgentFrameworkAgent":
+    """
+    Create an Agent Framework Adapter from either an AgentProtocol/BaseAgent or a
+    Workflow.
+    One of agent or workflow must be provided.
+
+    :param agent: The agent to adapt.
+    :type agent: Optional[Union[BaseAgent, AgentProtocol]]
+    :param workflow: The workflow builder or factory function to adapt.
+    :type workflow: Optional[Union[WorkflowBuilder, Callable[[], Workflow]]]
+    :param credentials: Optional asynchronous token credential for authentication.
+    :type credentials: Optional[AsyncTokenCredential]
+    :param kwargs: Additional keyword arguments to pass to the adapter.
+    :type kwargs: Any
+    :return: An instance of AgentFrameworkAgent.
+    :rtype: AgentFrameworkAgent
+    :raises TypeError: If neither or both of agent and workflow are provided, or if
+        the provided types are incorrect.
+ """ + + provided = sum(value is not None for value in (agent, workflow)) + if provided != 1: + raise TypeError("from_agent_framework requires exactly one of 'agent' or 'workflow' keyword arguments") + + if workflow is not None: + if isinstance(workflow, WorkflowBuilder): + def workflow_factory() -> Workflow: + return workflow.build() + + return AgentFrameworkWorkflowAdapter(workflow_factory=workflow_factory, credentials=credentials, **kwargs) + if isinstance(workflow, Callable): + return AgentFrameworkWorkflowAdapter(workflow_factory=workflow, credentials=credentials, **kwargs) + raise TypeError("workflow must be a WorkflowBuilder or callable returning a Workflow") - if isinstance(agent, WorkflowBuilder): - return AgentFrameworkWorkflowAdapter(workflow_builder=agent, credentials=credentials, **kwargs) - if isinstance(agent, AgentProtocol): + if isinstance(agent, AgentProtocol) or isinstance(agent, BaseAgent): return AgentFrameworkAIAgentAdapter(agent, credentials=credentials, **kwargs) - raise TypeError("agent must be an instance of AgentProtocol or WorkflowBuilder") + raise TypeError("agent must be an instance of AgentProtocol or BaseAgent") __all__ = [ diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py index 6f38501ddd7b..ece0198285bf 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -20,13 +20,10 @@ ) from azure.ai.agentserver.core.models.projects import ResponseErrorEvent, ResponseFailedEvent from azure.ai.agentserver.core.tools import OAuthConsentRequiredError -from .models.agent_framework_input_converters import AgentFrameworkInputConverter -from .models.agent_framework_output_non_streaming_converter import ( - AgentFrameworkOutputNonStreamingConverter, -) + from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter from .models.human_in_the_loop_helper import HumanInTheLoopHelper -from .persistence import AgentThreadRepository, CheckpointRepository +from .persistence import AgentThreadRepository if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -53,12 +50,13 @@ class AgentFrameworkAgent(FoundryCBAgent): def __init__(self, agent: AgentProtocol, credentials: "Optional[AsyncTokenCredential]" = None, + *, + thread_repository: Optional[AgentThreadRepository] = None, **kwargs: Any, ): - """Initialize the AgentFrameworkCBAgent with an AgentProtocol or a factory function. + """Initialize the AgentFrameworkAgent with an AgentProtocol. - :param agent: The Agent Framework agent to adapt, or a callable that takes ToolClient - and returns AgentProtocol (sync or async). + :param agent: The Agent Framework agent to adapt. :type agent: AgentProtocol :param credentials: Azure credentials for authentication. 
:type credentials: Optional[AsyncTokenCredential] @@ -67,6 +65,7 @@ def __init__(self, agent: AgentProtocol, """ super().__init__(credentials=credentials, **kwargs) # pylint: disable=unexpected-keyword-arg self._agent: AgentProtocol = agent + self._thread_repository = thread_repository self._hitl_helper = HumanInTheLoopHelper() @property diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py index fc917a3ea9bf..92667eb7cbcc 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py @@ -1,9 +1,9 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Optional, Protocol, Union, List +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Optional, Protocol, Union, List -from agent_framework import WorkflowBuilder, CheckpointStorage, WorkflowAgent, WorkflowCheckpoint +from agent_framework import Workflow, CheckpointStorage, WorkflowAgent, WorkflowCheckpoint from agent_framework._workflows import get_checkpoint_summary from azure.ai.agentserver.core.tools import OAuthConsentRequiredError @@ -27,13 +27,13 @@ class AgentFrameworkWorkflowAdapter(AgentFrameworkAgent): """Adapter to run WorkflowBuilder agents within the Agent Framework CBAgent structure.""" def __init__(self, - workflow_builder: WorkflowBuilder, + workflow_factory: Callable[[], Workflow], *, thread_repository: Optional[AgentThreadRepository] = None, checkpoint_repository: Optional[CheckpointRepository] = None, **kwargs: Any) -> None: - super().__init__(agent=workflow_builder, **kwargs) - self._workflow_builder = workflow_builder + super().__init__(agent=None, **kwargs) + self._workflow_factory = workflow_factory self._thread_repository = thread_repository self._checkpoint_repository = checkpoint_repository @@ -112,7 +112,7 @@ async def oauth_consent_stream(error=e): pass def _build_agent(self) -> WorkflowAgent: - return self._workflow_builder.build().as_agent() + return self._workflow_factory().as_agent() async def _get_latest_checkpoint(self, checkpoint_storage: CheckpointStorage) -> Optional[Any]: diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/README.md b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/README.md index 59bb6b9f19ec..cc82d1f19171 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/README.md +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/README.md @@ -1,287 +1,63 @@ -## Workflow Agent Reflection Sample (Python) +## Workflow Agent Simple Sample (Python) -This sample demonstrates how to wrap an Agent Framework workflow (with iterative review + improvement) as an agent using the Container Agents Adapter. It implements a "reflection" pattern consisting of two executors: +This sample hosts a two-step Agent Framework workflow—`Writer` followed by `Reviewer`—through the Azure AI Agent Server Adapter. 
The writer creates content, the reviewer provides the final response, and the adapter exposes the workflow through the same HTTP surface as any hosted agent.
+
+### What `workflow_agent_simple.py` does
+- Builds a workflow with `WorkflowBuilder`
+- Passes the builder to `from_agent_framework(...).run_async()` so the adapter spins up an HTTP server (defaults to `0.0.0.0:8088`).
+- An example of passing a `Workflow` factory instead is shown in the comments.
 
-The workflow cycles until the Reviewer approves the response. Only approved content is emitted externally (streamed the same way as a normal agent response). This pattern is useful for quality‑controlled assistance, gated tool use, evaluative chains, or iterative refinement.
-
-### Key Concepts Shown
-- `WorkflowBuilder` + `.as_agent()` to expose a workflow as a standard agent
-- Bidirectional edges enabling cyclical review (Worker ↔ Reviewer)
-- Structured output parsing (Pydantic model) for review feedback
-- Emitting `AgentRunUpdateEvent` to stream only approved messages
-- Managing pending requests and re‑submission with incorporated feedback
-
-File: `workflow_agent_simple.py`
+Please note that the `WorkflowBuilder` or `Workflow` factory will be called for each incoming request. The `Workflow` will be converted to a `WorkflowAgent` by `.as_agent()`. The workflow definition needs to be valid for `WorkflowAgent`.
 
 ---
 
 ## Prerequisites
-
-> **Azure sign-in:** Run `az login` before starting the sample so `DefaultAzureCredential` can acquire a CLI token.
-
-Dependencies used by `workflow_agent_simple.py`:
-- agent-framework-azure-ai (published package with workflow abstractions)
-- agents_adapter
-- azure-identity (for `DefaultAzureCredential`)
-- python-dotenv (loads `.env` for local credentials)
-- pydantic (pulled transitively; listed for clarity)
-
-Install from PyPI (from the repo root: `container_agents/`):
-```bash
-pip install agent-framework-azure-ai azure-identity python-dotenv
-
-pip install -e src/adapter/python
-```
+- Python 3.10+
+- Azure CLI authenticated with `az login` (required for `AzureCliCredential`).
+- An Azure AI project that already hosts a chat model deployment supported by the Agent Framework Azure client.
 
 ---
 
-## Additional Requirements
-
-1. Azure AI project with a model deployment (supports Microsoft hosted, Azure OpenAI, or custom models exposed via Azure AI Foundry).
-
----
-
-## Configuration
-
-Copy `.envtemplate` to `.env` and fill in real values:
-```
-AZURE_AI_PROJECT_ENDPOINT=
-AZURE_AI_MODEL_DEPLOYMENT_NAME=
-AGENT_PROJECT_NAME=
-```
-`AGENT_PROJECT_NAME` lets you override the default Azure AI agent project for this workflow; omit it to fall back to the SDK default.
+## Setup
+1. Copy `.envtemplate` to `.env` and fill in your project details:
+   ```
+   AZURE_AI_PROJECT_ENDPOINT=
+   AZURE_AI_MODEL_DEPLOYMENT_NAME=
+   ```
+2. Install the sample dependencies:
+   ```bash
+   pip install -r requirements.txt
+   ```
 
 ---
 
 ## Run the Workflow Agent
-
 From this folder:
 ```bash
 python workflow_agent_simple.py
 ```
-The server (via the adapter) will start on `0.0.0.0:8088` by default.
- ---- - -## Send a Non‑Streaming Request - -```bash -curl -sS \ - -H "Content-Type: application/json" \ - -X POST http://localhost:8088/runs \ - -d '{"input":"Explain the concept of reflection in this workflow sample.","stream":false}' -``` - -Sample output (non‑streaming): - -``` -Processing 1 million files in parallel and writing their contents into a sorted output file can be a computationally and resource-intensive task. To handle it effectively, you can use Python with libraries like `concurrent.futures` for parallelism and `heapq` for the sorting and merging. - -Below is an example implementation: - -import os -from concurrent.futures import ThreadPoolExecutor -import heapq - -def read_file(file_path): - """Read the content of a single file and return it as a list of lines.""" - with open(file_path, 'r') as file: - return file.readlines() - -def parallel_read_files(file_paths, max_workers=8): - """ - Read files in parallel and return all the lines in memory. - :param file_paths: List of file paths to read. - :param max_workers: Number of worker threads to use for parallelism. - """ - all_lines = [] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit tasks to read each file in parallel - results = executor.map(read_file, file_paths) - # Collect the results - for lines in results: - all_lines.extend(lines) - return all_lines - -def write_sorted_output(lines, output_file_path): - """ - Write sorted lines to the output file. - :param lines: List of strings to be sorted and written. - :param output_file_path: File path to write the sorted result. - """ - sorted_lines = sorted(lines) - with open(output_file_path, 'w') as output_file: - output_file.writelines(sorted_lines) - -def main(directory_path, output_file_path): - """ - Main function to read files in parallel and write sorted output. - :param directory_path: Path to the directory containing input files. - :param output_file_path: File path to write the sorted output. - """ - # Get a list of all the file paths in the given directory - file_paths = [os.path.join(directory_path, f) for f in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, f))] - - print(f"Found {len(file_paths)} files. Reading files in parallel...") - - # Read all lines from the files in parallel - all_lines = parallel_read_files(file_paths) - - print(f"Total lines read: {len(all_lines)}. Sorting and writing to output file...") - - # Write the sorted lines to the output file - write_sorted_output(all_lines, output_file_path) - - print(f"Sorted output written to: {output_file_path}") - -if __name__ == "__main__": - # Replace these paths with the appropriate input directory and output file path - input_directory = "path/to/input/directory" # Directory containing 1 million files - output_file = "path/to/output/sorted_output.txt" # Output file path - - main(input_directory, output_file) - -### Key Features and Steps: - -1. **Parallel Reading with `ThreadPoolExecutor`**: - - Files are read in parallel using threads to improve I/O performance since reading many files is mostly I/O-bound. - -2. **Sorting and Writing**: - - Once all lines are aggregated into memory, they are sorted using Python's `sorted()` function and written to the output file in one go. - -3. **Handles Large Number of Files**: - - The program uses threads to manage the potentially massive number of files in parallel, saving time instead of processing them serially. - -### Considerations: -- **Memory Usage**: This script reads all file contents into memory. 
If the total size of the files is too large, you may encounter memory issues. In such cases, consider processing the files in smaller chunks. -- **Sorting**: For extremely large data, consider using an external/merge sort technique to handle sorting in smaller chunks. -- **I/O Performance**: Ensure that your I/O subsystem and disk can handle the load. - -Let me know if you'd like an optimized version to handle larger datasets with limited memory! - -Usage (if provided): None -``` - ---- - -## Send a Streaming Request (Server-Sent Events) - -```bash -curl -N \ - -H "Content-Type: application/json" \ - -X POST http://localhost:8088/runs \ - -d '{"input":"How does the reviewer decide to approve?","stream":true}' -``` - -Sample output (streaming): - -``` -Here is a Python script that demonstrates parallel reading of 1 million files using `concurrent.futures` for parallelism and `heapq` to write the outputs to a sorted file. This approach ensures efficiency when dealing with such a large number of files. - - -import os -import heapq -from concurrent.futures import ThreadPoolExecutor - -def read_file(file_path): - """ - Read the content of a single file and return it as a list of lines. - """ - with open(file_path, 'r') as file: - return file.readlines() - -def parallel_read_files(file_paths, max_workers=4): - """ - Read multiple files in parallel. - """ - all_lines = [] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit reading tasks to the thread pool - futures = [executor.submit(read_file, file_path) for file_path in file_paths] - - # Gather results as they are completed - for future in futures: - all_lines.extend(future.result()) - - return all_lines - -def write_sorted_output(lines, output_file): - """ - Write sorted lines to an output file. - """ - sorted_lines = sorted(lines) - with open(output_file, 'w') as file: - file.writelines(sorted_lines) - -if __name__ == "__main__": - # Set the directory containing your input files - input_directory = 'path_to_your_folder_with_files' - - # Get the list of all input files - file_paths = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if os.path.isfile(os.path.join(input_directory, f))] - - # Specify the number of threads for parallel processing - max_threads = 8 # Adjust according to your system's capabilities - - # Step 1: Read all files in parallel - print("Reading files in parallel...") - all_lines = parallel_read_files(file_paths, max_workers=max_threads) - - # Step 2: Write the sorted data to the output file - output_file = 'sorted_output.txt' - print(f"Writing sorted output to {output_file}...") - write_sorted_output(all_lines, output_file) - - print("Operation complete.") - -[comment]: # ( cspell:ignore pysort ) - -### Key Points: -1. **Parallel Read**: The reading of files is handled using `concurrent.futures.ThreadPoolExecutor`, allowing multiple files to be processed simultaneously. - -2. **Sorted Output**: After collecting all lines from the files, the `sorted()` function is used to sort the content in memory. This ensures that the final output file will have all data in sorted order. - -3. **Adjustable Parallelism**: The `max_threads` parameter can be modified to control the number of threads used for file reading. The value should match your system's capabilities for optimal performance. - -4. **Large Data Handling**: If the data from 1 million files is too large to fit into memory, consider using an external merge sort algorithm or a library like `pysort` for efficient external sorting. 
- -Let me know if you'd like improvements or adjustments for more specific scenarios! -Final usage (if provided): None -``` - -> Only the final approved assistant content is emitted as normal output deltas; intermediate review feedback stays internal. - ---- -## How the Reflection Loop Works -1. User query enters the workflow (Worker start executor) -2. Worker produces an answer with model call -3. Reviewer evaluates using a structured schema (`feedback`, `approved`) -4. If not approved: Worker augments context with feedback + regeneration instruction, then re‑answers -5. Loop continues until `approved=True` -6. Approved content is emitted as `AgentRunResponseUpdate` (streamed externally) +The adapter starts the server on `http://0.0.0.0:8088` by default. --- -## Troubleshooting -| Issue | Resolution | -|-------|------------| -| `DefaultAzureCredential` errors | Run `az login` or configure a service principal. | -| Empty / no streaming | Confirm `stream` flag in request JSON and that the event loop is healthy. | -| Model 404 / deployment error | Verify `AZURE_AI_MODEL_DEPLOYMENT_NAME` exists in the Azure AI project configured by `AZURE_AI_PROJECT_ENDPOINT`. | -| `.env` not loading | Ensure `.env` sits beside the script (or set `dotenv_path`) and that `python-dotenv` is installed. | +## Send Requests +- **Non-streaming:** + ```bash + curl -sS \ + -H "Content-Type: application/json" \ + -X POST http://localhost:8088/runs \ + -d '{"input":"Create a slogan for a new electric SUV that is affordable and fun to drive","stream":false}' + ``` --- ## Related Resources - Agent Framework repo: https://github.com/microsoft/agent-framework -- Basic simple sample README (same folder structure) for installation reference +- Adapter package docs: `azure.ai.agentserver.agentframework` in this SDK --- ## License & Support -This sample follows the repository's LICENSE. For questions about unreleased Agent Framework features, contact the Agent Framework team via its GitHub repository. +This sample follows the repository LICENSE. For questions about the Agent Framework itself, open an issue in the Agent Framework GitHub repository. diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py index afbc0b48667b..eac282110d65 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py @@ -1,289 +1,51 @@ -# Copyright (c) Microsoft. All rights reserved. - import asyncio -from dataclasses import dataclass -from uuid import uuid4 - -from agent_framework import ( - AgentRunResponseUpdate, - AgentRunUpdateEvent, - BaseChatClient, - ChatMessage, - Contents, - Executor, - Role as ChatRole, - WorkflowBuilder, - WorkflowContext, - handler, -) -from agent_framework_azure_ai import AzureAIAgentClient -from azure.identity.aio import DefaultAzureCredential from dotenv import load_dotenv -from pydantic import BaseModel - -from azure.ai.agentserver.agentframework import from_agent_framework - -""" -The following sample demonstrates how to wrap a workflow as an agent using WorkflowAgent. - -This sample shows how to: -1. Create a workflow with a reflection pattern (Worker + Reviewer executors) -2. Wrap the workflow as an agent using the .as_agent() method -3. 
Stream responses from the workflow agent like a regular agent -4. Implement a review-retry mechanism where responses are iteratively improved - -The example implements a quality-controlled AI assistant where: -- Worker executor generates responses to user queries -- Reviewer executor evaluates the responses and provides feedback -- If not approved, the Worker incorporates feedback and regenerates the response -- The cycle continues until the response is approved -- Only approved responses are emitted to the external consumer - -Key concepts demonstrated: -- WorkflowAgent: Wraps a workflow to make it behave as an agent -- Bidirectional workflow with cycles (Worker ↔ Reviewer) -- AgentRunUpdateEvent: How workflows communicate with external consumers -- Structured output parsing for review feedback -- State management with pending requests tracking -""" - - -@dataclass -class ReviewRequest: - request_id: str - user_messages: list[ChatMessage] - agent_messages: list[ChatMessage] - - -@dataclass -class ReviewResponse: - request_id: str - feedback: str - approved: bool - load_dotenv() +from agent_framework import ChatAgent, WorkflowBuilder +from agent_framework.azure import AzureAIAgentClient +from azure.identity.aio import AzureCliCredential -class Reviewer(Executor): - """An executor that reviews messages and provides feedback.""" - - def __init__(self, chat_client: BaseChatClient) -> None: - super().__init__(id="reviewer") - self._chat_client = chat_client - - @handler - async def review( - self, request: ReviewRequest, ctx: WorkflowContext[ReviewResponse] - ) -> None: - print( - f"🔍 Reviewer: Evaluating response for request {request.request_id[:8]}..." - ) - - # Use the chat client to review the message and use structured output. - # NOTE: this can be modified to use an evaluation framework. - - class _Response(BaseModel): - feedback: str - approved: bool - - # Define the system prompt. - messages = [ - ChatMessage( - role=ChatRole.SYSTEM, - text="You are a reviewer for an AI agent, please provide feedback on the " - "following exchange between a user and the AI agent, " - "and indicate if the agent's responses are approved or not.\n" - "Use the following criteria for your evaluation:\n" - "- Relevance: Does the response address the user's query?\n" - "- Accuracy: Is the information provided correct?\n" - "- Clarity: Is the response easy to understand?\n" - "- Completeness: Does the response cover all aspects of the query?\n" - "Be critical in your evaluation and provide constructive feedback.\n" - "Do not approve until all criteria are met.", - ) - ] - - # Add user and agent messages to the chat history. - messages.extend(request.user_messages) - - # Add agent messages to the chat history. - messages.extend(request.agent_messages) - - # Add add one more instruction for the assistant to follow. - messages.append( - ChatMessage( - role=ChatRole.USER, - text="Please provide a review of the agent's responses to the user.", - ) - ) - - print("🔍 Reviewer: Sending review request to LLM...") - # Get the response from the chat client. - response = await self._chat_client.get_response( - messages=messages, response_format=_Response - ) - - # Parse the response. - parsed = _Response.model_validate_json(response.messages[-1].text) - - print(f"🔍 Reviewer: Review complete - Approved: {parsed.approved}") - print(f"🔍 Reviewer: Feedback: {parsed.feedback}") - - # Send the review response. 
- await ctx.send_message( - ReviewResponse( - request_id=request.request_id, - feedback=parsed.feedback, - approved=parsed.approved, - ) - ) - - -class Worker(Executor): - """An executor that performs tasks for the user.""" - - def __init__(self, chat_client: BaseChatClient) -> None: - super().__init__(id="worker") - self._chat_client = chat_client - self._pending_requests: dict[str, tuple[ReviewRequest, list[ChatMessage]]] = {} - - @handler - async def handle_user_messages( - self, user_messages: list[ChatMessage], ctx: WorkflowContext[ReviewRequest] - ) -> None: - print("🔧 Worker: Received user messages, generating response...") - - # Handle user messages and prepare a review request for the reviewer. - # Define the system prompt. - messages = [ - ChatMessage(role=ChatRole.SYSTEM, text="You are a helpful assistant.") - ] - - # Add user messages. - messages.extend(user_messages) - - print("🔧 Worker: Calling LLM to generate response...") - # Get the response from the chat client. - response = await self._chat_client.get_response(messages=messages) - print(f"🔧 Worker: Response generated: {response.messages[-1].text}") - - # Add agent messages. - messages.extend(response.messages) - - # Create the review request. - request = ReviewRequest( - request_id=str(uuid4()), - user_messages=user_messages, - agent_messages=response.messages, - ) - - print( - f"🔧 Worker: Generated response, sending to reviewer (ID: {request.request_id[:8]})" - ) - # Send the review request. - await ctx.send_message(request) - - # Add to pending requests. - self._pending_requests[request.request_id] = (request, messages) - - @handler - async def handle_review_response( - self, review: ReviewResponse, ctx: WorkflowContext[ReviewRequest] - ) -> None: - print( - f"🔧 Worker: Received review for request {review.request_id[:8]} - Approved: {review.approved}" - ) - - # Handle the review response. Depending on the approval status, - # either emit the approved response as AgentRunUpdateEvent, or - # retry given the feedback. - if review.request_id not in self._pending_requests: - raise ValueError( - f"Received review response for unknown request ID: {review.request_id}" - ) - # Remove the request from pending requests. - request, messages = self._pending_requests.pop(review.request_id) - - if review.approved: - print("✅ Worker: Response approved! Emitting to external consumer...") - # If approved, emit the agent run response update to the workflow's - # external consumer. - contents: list[Contents] = [] - for message in request.agent_messages: - contents.extend(message.contents) - # Emitting an AgentRunUpdateEvent in a workflow wrapped by a WorkflowAgent - # will send the AgentRunResponseUpdate to the WorkflowAgent's - # event stream. - await ctx.add_event( - AgentRunUpdateEvent( - self.id, - data=AgentRunResponseUpdate( - contents=contents, role=ChatRole.ASSISTANT, author_name=self.id - ), - ) - ) - return - - print(f"❌ Worker: Response not approved. Feedback: {review.feedback}") - print("🔧 Worker: Incorporating feedback and regenerating response...") - - # Construct new messages with feedback. - messages.append(ChatMessage(role=ChatRole.SYSTEM, text=review.feedback)) - - # Add additional instruction to address the feedback. - messages.append( - ChatMessage( - role=ChatRole.SYSTEM, - text="Please incorporate the feedback above, and provide a response to user's next message.", - ) - ) - messages.extend(request.user_messages) - - # Get the new response from the chat client. 
- response = await self._chat_client.get_response(messages=messages) - print( - f"🔧 Worker: New response generated after feedback: {response.messages[-1].text}" - ) - - # Process the response. - messages.extend(response.messages) - - print( - f"🔧 Worker: Generated improved response, sending for re-review (ID: {review.request_id[:8]})" - ) - # Send an updated review request. - new_request = ReviewRequest( - request_id=review.request_id, - user_messages=request.user_messages, - agent_messages=response.messages, - ) - await ctx.send_message(new_request) +from azure.ai.agentserver.agentframework import from_agent_framework - # Add to pending requests. - self._pending_requests[new_request.request_id] = (new_request, messages) +def create_writer_agent(client: AzureAIAgentClient) -> ChatAgent: + return client.create_agent( + name="Writer", + instructions=( + "You are an excellent content writer. You create new content and edit contents based on the feedback." + ), + ) -def create_builder(chat_client: BaseChatClient): - reviewer = Reviewer(chat_client=chat_client) - worker = Worker(chat_client=chat_client) - return ( - WorkflowBuilder() - .add_edge( - worker, reviewer - ) # <--- This edge allows the worker to send requests to the reviewer - .add_edge( - reviewer, worker - ) # <--- This edge allows the reviewer to send feedback back to the worker - .set_start_executor(worker) +def create_reviewer_agent(client: AzureAIAgentClient) -> ChatAgent: + return client.create_agent( + name="Reviewer", + instructions=( + "You are an excellent content reviewer. " + "Provide actionable feedback to the writer about the provided content. " + "Provide the feedback in the most concise manner possible." + ), ) async def main() -> None: - async with DefaultAzureCredential() as credential: - async with AzureAIAgentClient(async_credential=credential) as chat_client: - builder = create_builder(chat_client) - await from_agent_framework(builder).run_async() + async with AzureCliCredential() as cred, AzureAIAgentClient(credential=cred) as client: + builder = ( + WorkflowBuilder() + .register_agent(lambda: create_writer_agent(client), name="writer") + .register_agent(lambda: create_reviewer_agent(client), name="reviewer", output_response=True) + .set_start_executor("writer") + .add_edge("writer", "reviewer") + ) + + # Pass the WorkflowBuilder to the adapter and run it + await from_agent_framework(workflow=builder).run_async() + + # Or create a factory function for the workflow pass the workflow factory to the adapter + # def workflow_factory() -> Workflow: + # return builder.build() + # await from_agent_framework(workflow=workflow_factory).run_async() if __name__ == "__main__": From 0dca736a229d9d19dfce4c043bc1109d30d802ad Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Wed, 21 Jan 2026 09:26:51 -0800 Subject: [PATCH 09/29] resolve trailing whitespace --- .../ai/agentserver/agentframework/__init__.py | 12 ++++++------ .../agentserver/agentframework/_agent_framework.py | 8 ++++---- .../agentserver/agentframework/_foundry_tools.py | 2 +- .../models/agent_framework_input_converters.py | 8 ++++---- ...ent_framework_output_non_streaming_converter.py | 6 +++--- .../agent_framework_output_streaming_converter.py | 10 +++++----- .../models/human_in_the_loop_helper.py | 14 +++++++------- .../persistence/agent_thread_repository.py | 10 +++++----- .../persistence/checkpoint_repository.py | 4 ++-- 9 files changed, 37 insertions(+), 37 deletions(-) diff --git 
a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py
index 1a276a14ff9e..a213d06fb00b 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py
@@ -27,7 +27,7 @@ def from_agent_framework(
 ) -> "AgentFrameworkAIAgentAdapter":
     """
     Create an Agent Framework AI Agent Adapter from an AgentProtocol or BaseAgent.
-    
+
     :param agent: The agent to adapt.
     :type agent: Union[BaseAgent, AgentProtocol]
     :param credentials: Optional asynchronous token credential for authentication.
     :type credentials: Optional[AsyncTokenCredential]
@@ -49,9 +49,9 @@ def from_agent_framework(
 ) -> "AgentFrameworkWorkflowAdapter":
     """
     Create an Agent Framework Workflow Adapter.
-    The arugument `workflow` can be either a WorkflowBuilder or a factory function
+    The argument `workflow` can be either a WorkflowBuilder or a factory function
     that returns a Workflow.
-    It will be called to create a new Workflow instance and `.as_agent()` will be
+    It will be called to create a new Workflow instance and `.as_agent()` will be
     called as well for each incoming CreateResponse request. Please ensure that the
     workflow definition can be converted to a WorkflowAgent. For more information,
     see the agent-framework samples and documentation.
@@ -61,7 +61,7 @@ def from_agent_framework(
     :param credentials: Optional asynchronous token credential for authentication.
     :type credentials: Optional[AsyncTokenCredential]
     :param kwargs: Additional keyword arguments to pass to the adapter.
-    :type kwargs: Any 
+    :type kwargs: Any
     :return: An instance of AgentFrameworkWorkflowAdapter.
     :rtype: AgentFrameworkWorkflowAdapter
     """
@@ -75,7 +75,7 @@ def from_agent_framework(
     **kwargs: Any,
 ) -> "AgentFrameworkAgent":
     """
-    Create an Agent Framework Adapter from either an AgentProtocol/BaseAgent or a 
+    Create an Agent Framework Adapter from either an AgentProtocol/BaseAgent or a
     WorkflowAgent. One of agent or workflow must be provided.
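For reference, the two accepted forms of `workflow` look like this in practice; a minimal sketch modeled on the workflow_agent_simple sample updated later in this series (the builder's executors and edges are elided, so this is illustrative rather than runnable end to end):

    import asyncio

    from agent_framework import Workflow, WorkflowBuilder
    from azure.ai.agentserver.agentframework import from_agent_framework

    async def main() -> None:
        builder = WorkflowBuilder()  # executors and edges added as in the samples

        # Form 1: pass the WorkflowBuilder itself to the adapter.
        await from_agent_framework(workflow=builder).run_async()

        # Form 2: pass a factory that returns a freshly built Workflow.
        def workflow_factory() -> Workflow:
            return builder.build()

        # await from_agent_framework(workflow=workflow_factory).run_async()

    if __name__ == "__main__":
        asyncio.run(main())

Per the docstring above, whichever form is given, the adapter obtains a new Workflow and calls `.as_agent()` on it for each incoming CreateResponse request.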
@@ -118,4 +118,4 @@ def workflow_factory() -> Workflow: ] __version__ = VERSION -set_current_app(PackageMetadata.from_dist("azure-ai-agentserver-agentframework")) \ No newline at end of file +set_current_app(PackageMetadata.from_dist("azure-ai-agentserver-agentframework")) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py index ece0198285bf..b9460bf7dfdf 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -114,9 +114,9 @@ def _create_application_insights_exporter(self, connection_string): def _create_otlp_exporter(self, endpoint, protocol=None): try: if protocol and protocol.lower() in ("http", "http/protobuf", "http/json"): - from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter - return OTLPSpanExporter(endpoint=endpoint) + return OTLPSpanExporter(endpoint=endpoint) else: from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter @@ -182,7 +182,7 @@ async def agent_run( # pylint: disable=too-many-statements AsyncGenerator[ResponseStreamEvent, Any], ]: raise NotImplementedError("This method is implemented in the base class.") - + async def _load_agent_thread(self, context: AgentRunContext, agent: Union[AgentProtocol, WorkflowAgent]) -> Optional[AgentThread]: """Load the agent thread for a given conversation ID. @@ -268,4 +268,4 @@ async def stream_updates(): # No request-scoped resources to clean up today, but keep hook for future use. 
pass - return stream_updates() \ No newline at end of file + return stream_updates() diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py index 875c1de24e8c..10e22ac128d5 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py @@ -50,7 +50,7 @@ async def list_tools(self) -> List[AIFunction]: foundry_tool_catalog = server_context.tools.catalog resolved_tools = await foundry_tool_catalog.list(self._allowed_tools) return [self._to_aifunction(tool) for tool in resolved_tools] - + def _to_aifunction(self, foundry_tool: "ResolvedFoundryTool") -> AIFunction: """Convert an FoundryTool to an Agent Framework AI Function diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py index a6eefb2b1740..2ad8ab3cee11 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py @@ -43,7 +43,7 @@ async def transform_input( if isinstance(input, str): return input - + if self._hitl_helper: # load pending requests from checkpoint and thread messages if available thread_messages = [] @@ -57,7 +57,7 @@ async def transform_input( logger.info(f"HitL response validation result: {[m.to_dict() for m in hitl_response]}") if hitl_response: return hitl_response - + return self._transform_input_internal(input) def _transform_input_internal( @@ -163,7 +163,7 @@ def _validate_and_convert_hitl_response( if not isinstance(input, list) or len(input) != 1: logger.warning("Expected single-item list input for HitL response validation.") return None - + item = input[0] if item.get("type") != "function_call_output": logger.warning("Expected function_call_output type for HitL response validation.") @@ -178,5 +178,5 @@ def _validate_and_convert_hitl_response( if not isinstance(request_info, RequestInfoEvent): logger.warning("No valid pending request info found for call_id: %s", call_id) return None - + return self._hitl_helper.convert_response(request_info, item) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py index 95c7bb7acc6b..c272373b2291 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py @@ -53,13 +53,13 @@ def _build_item_content_output_text(self, text: str) -> ItemContentOutputText: def _build_created_by(self, author_name: str) -> dict: self._ensure_response_started() - + agent_dict = { "type": "agent_id", "name": author_name or "", "version": "", # 
Default to empty string } - + return { "agent": agent_dict, "response_id": self._response_id, @@ -216,7 +216,7 @@ def _append_function_result_content(self, content: FunctionResultContent, sink: call_id, len(result), ) - + def _append_user_input_request_contents(self, content: UserInputRequestContents, sink: List[dict], author_name: str) -> None: item_id = self._context.id_generator.generate_function_call_id() content = self._hitl_helper.convert_user_input_request_content(content) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py index 253b0fc7aa9e..e216dbb767b6 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py @@ -116,8 +116,8 @@ async def convert_contents(self, contents: AsyncIterable[TextContent], author_na ) item = ResponsesAssistantMessageItemResource( - id=item_id, - status="completed", + id=item_id, + status="completed", content=[content_part], created_by=self._parent._build_created_by(author_name), ) @@ -209,7 +209,7 @@ async def convert_contents( ) self._parent.add_completed_output_item(item) # pylint: disable=protected-access - + # process HITL contents after function calls for content in hitl_contents: item_id = self._parent.context.id_generator.generate_function_call_id() @@ -389,7 +389,7 @@ async def convert(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> Async async def extract_contents(): async for content, _ in contents_with_author: yield content - + async for content in state.convert_contents(extract_contents(), author_name): yield content @@ -404,7 +404,7 @@ def _build_created_by(self, author_name: str) -> dict: "name": author_name or "", "version": "", } - + return { "agent": agent_dict, "response_id": self._response_id, diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py index 30bb3aa8d9c5..a40eac7f7a7d 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py @@ -30,10 +30,10 @@ def get_pending_hitl_request(self, request_obj = RequestInfoEvent.from_dict(request) res[call_id] = request_obj return res - + if not thread_messages: return res - + # if no checkpoint (Agent), find user input request and pair the feedbacks for message in thread_messages: for content in message.contents: @@ -59,7 +59,7 @@ def get_pending_hitl_request(self, if call_id and call_id in res: res.pop(call_id) return res - + def convert_user_input_request_content(self, content: UserInputRequestContents) -> dict: function_call = content.function_call call_id = getattr(function_call, "call_id", "") @@ -69,7 +69,7 @@ def convert_user_input_request_content(self, content: UserInputRequestContents) "name": HUMAN_IN_THE_LOOP_FUNCTION_NAME, "arguments": arguments or "", } - + def 
convert_request_arguments(self, arguments: Any) -> str: # convert data to payload if possible if isinstance(arguments, dict): @@ -85,7 +85,7 @@ def convert_request_arguments(self, arguments: Any) -> str: except Exception: # pragma: no cover - fallback # pylint: disable=broad-exception-caught arguments = str(arguments) return arguments - + def validate_and_convert_hitl_response(self, input: str | List[Dict] | None, pending_requests: Dict[str, RequestInfoEvent], @@ -94,7 +94,7 @@ def validate_and_convert_hitl_response(self, if input is None or isinstance(input, str): logger.warning("Expected list input for HitL response validation, got str.") return None - + res = [] for item in input: if item.get("type") != "function_call_output": @@ -116,4 +116,4 @@ def convert_response(self, hitl_request: RequestInfoEvent, input: Dict) -> ChatM call_id=hitl_request.request_id, result=response_result, ) - return ChatMessage(role="tool", contents=[response_content]) \ No newline at end of file + return ChatMessage(role="tool", contents=[response_content]) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py index 294d7d0948fc..293c8ec88805 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py @@ -22,7 +22,7 @@ async def get(self, conversation_id: str, agent: Optional[Union[AgentProtocol, W :rtype: Optional[AgentThread] """ - @abstractmethod + @abstractmethod async def set(self, conversation_id: str, thread: AgentThread) -> None: """Save the thread for a given conversation ID. @@ -69,7 +69,7 @@ class SerializedAgentThreadRepository(AgentThreadRepository): def __init__(self, agent: AgentProtocol) -> None: """ Initialize the repository with the given agent. - + :param agent: The agent instance. :type agent: AgentProtocol """ @@ -115,7 +115,7 @@ async def read_from_storage(self, conversation_id: str) -> Optional[Any]: :rtype: Optional[Any] """ raise NotImplementedError("read_from_storage is not implemented.") - + async def write_to_storage(self, conversation_id: str, serialized_thread: Any) -> None: """Write the serialized thread to storage. 
@@ -125,7 +125,7 @@ async def write_to_storage(self, conversation_id: str, serialized_thread: Any) - :type serialized_thread: Any """ raise NotImplementedError("write_to_storage is not implemented.") - + class JsonLocalFileAgentThreadRepository(SerializedAgentThreadRepository): """Json based implementation of AgentThreadRepository using local file storage.""" @@ -150,4 +150,4 @@ async def write_to_storage(self, conversation_id: str, serialized_thread: Any) - f.write(serialized_str) def _get_file_path(self, conversation_id: str) -> str: - return os.path.join(self._storage_path, f"{conversation_id}.json") \ No newline at end of file + return os.path.join(self._storage_path, f"{conversation_id}.json") diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py index 0bc89a4b5377..c4a1473c5f6b 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py @@ -57,6 +57,6 @@ async def get_or_create(self, conversation_id: str) -> Optional[CheckpointStorag if conversation_id not in self._inventory: self._inventory[conversation_id] = FileCheckpointStorage(self._get_dir_path(conversation_id)) return self._inventory[conversation_id] - + def _get_dir_path(self, conversation_id: str) -> str: - return os.path.join(self._storage_path, conversation_id) \ No newline at end of file + return os.path.join(self._storage_path, conversation_id) From bd1b95e3463aecb8f1d8dc6465b02ac8af3f4655 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Wed, 21 Jan 2026 09:28:01 -0800 Subject: [PATCH 10/29] fix samples --- .../samples/basic_simple/minimal_example.py | 4 ++-- .../chat_client_with_foundry_tool.py | 2 +- .../samples/human_in_the_loop_ai_function/main.py | 2 +- .../samples/human_in_the_loop_workflow_agent/main.py | 2 +- .../samples/mcp_apikey/mcp_apikey.py | 2 +- .../samples/mcp_simple/mcp_simple.py | 2 +- .../samples/simple_async/minimal_async_example.py | 2 +- .../workflow_agent_simple/workflow_agent_simple.py | 10 +++++----- 8 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py index 15afa52f42b8..1d5aab07ae8a 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py @@ -6,10 +6,10 @@ from agent_framework.azure import AzureOpenAIChatClient from azure.identity import DefaultAzureCredential from dotenv import load_dotenv +load_dotenv() from azure.ai.agentserver.agentframework import from_agent_framework -load_dotenv() def get_weather( @@ -26,7 +26,7 @@ def main() -> None: tools=get_weather, ) - from_agent_framework(agent).run() + from_agent_framework(agent=agent).run() if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py index 
cb9c3cd2c9c6..08e7e8bdffc7 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py @@ -28,7 +28,7 @@ def main(): instructions="You are a helpful assistant with access to various tools.", ) - from_agent_framework(agent).run() + from_agent_framework(agent=agent).run() if __name__ == "__main__": main() diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py index 56dc5fca8860..db81c1091336 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py @@ -78,7 +78,7 @@ def build_agent(): async def main() -> None: agent = build_agent() thread_repository = JsonLocalFileAgentThreadRepository(agent=agent, storage_path="./thread_storage") - await from_agent_framework(agent, thread_repository=thread_repository).run_async() + await from_agent_framework(agent=agent, thread_repository=thread_repository).run_async() if __name__ == "__main__": asyncio.run(main()) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py index cc89c941e65e..e749a4a62fc6 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py @@ -111,7 +111,7 @@ async def run_agent() -> None: """Run the workflow inside the agent server adapter.""" builder = create_builder() await from_agent_framework( - builder, # pass workflow builder to adapter + workflow=builder, # pass workflow builder to adapter checkpoint_repository=FileCheckpointRepository(storage_path="./checkpoints"), # for checkpoint storage ).run_async() diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_apikey/mcp_apikey.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_apikey/mcp_apikey.py index 985d7fd01e0c..2a1058e7f468 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_apikey/mcp_apikey.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_apikey/mcp_apikey.py @@ -35,7 +35,7 @@ async def main() -> None: ) async with agent: - await from_agent_framework(agent).run_async() + await from_agent_framework(agent=agent).run_async() if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_simple/mcp_simple.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_simple/mcp_simple.py index 6b59771fe0da..ce5bb37eea4f 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_simple/mcp_simple.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_simple/mcp_simple.py @@ -22,7 +22,7 @@ async def main() -> None: ) async with agent: - await from_agent_framework(agent).run_async() + await from_agent_framework(agent=agent).run_async() if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/simple_async/minimal_async_example.py 
b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/simple_async/minimal_async_example.py index 4c69c8afa84d..ac781d4d39ab 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/simple_async/minimal_async_example.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/simple_async/minimal_async_example.py @@ -28,7 +28,7 @@ async def main() -> None: ) async with agent: - await from_agent_framework(agent).run_async() + await from_agent_framework(agent=agent).run_async() if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py index eac282110d65..a79e24f1a3fb 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py @@ -3,7 +3,7 @@ load_dotenv() -from agent_framework import ChatAgent, WorkflowBuilder +from agent_framework import ChatAgent, Workflow, WorkflowBuilder from agent_framework.azure import AzureAIAgentClient from azure.identity.aio import AzureCliCredential @@ -40,12 +40,12 @@ async def main() -> None: ) # Pass the WorkflowBuilder to the adapter and run it - await from_agent_framework(workflow=builder).run_async() + # await from_agent_framework(workflow=builder).run_async() # Or create a factory function for the workflow pass the workflow factory to the adapter - # def workflow_factory() -> Workflow: - # return builder.build() - # await from_agent_framework(workflow=workflow_factory).run_async() + def workflow_factory() -> Workflow: + return builder.build() + await from_agent_framework(workflow=workflow_factory).run_async() if __name__ == "__main__": From 47bfb319e1e10129de1deb38d407da003dcdc447 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Wed, 21 Jan 2026 09:45:53 -0800 Subject: [PATCH 11/29] resolve unused import and long lines --- .../agentframework/_agent_framework.py | 19 ++++++++--- .../agentframework/_workflow_agent_adapter.py | 16 +++++++-- ...ramework_output_non_streaming_converter.py | 30 +++++++++++------ ...nt_framework_output_streaming_converter.py | 33 ++++++++++++++----- .../persistence/agent_thread_repository.py | 12 +++++-- .../persistence/checkpoint_repository.py | 2 +- 6 files changed, 85 insertions(+), 27 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py index b9460bf7dfdf..749c084cf5ff 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -183,7 +183,11 @@ async def agent_run( # pylint: disable=too-many-statements ]: raise NotImplementedError("This method is implemented in the base class.") - async def _load_agent_thread(self, context: AgentRunContext, agent: Union[AgentProtocol, WorkflowAgent]) -> Optional[AgentThread]: + async def _load_agent_thread( + self, + context: AgentRunContext, + agent: Union[AgentProtocol, WorkflowAgent], + ) -> Optional[AgentThread]: """Load the agent thread for a given conversation ID. :param context: The agent run context. 
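The thread repository consulted by _load_agent_thread is supplied when the adapter is constructed; a minimal sketch of that wiring, following the human_in_the_loop_ai_function sample fixed earlier in this series (build_agent is that sample's helper, and the import path for the repository is assumed here):

    import asyncio

    from azure.ai.agentserver.agentframework import from_agent_framework
    # Import path assumed; the class lives in persistence/agent_thread_repository.py.
    from azure.ai.agentserver.agentframework.persistence import (
        JsonLocalFileAgentThreadRepository,
    )

    async def main() -> None:
        agent = build_agent()  # helper from the sample; any AgentProtocol works
        # Threads are serialized to ./thread_storage/<conversation_id>.json.
        repo = JsonLocalFileAgentThreadRepository(agent=agent, storage_path="./thread_storage")
        await from_agent_framework(agent=agent, thread_repository=repo).run_async()

    if __name__ == "__main__":
        asyncio.run(main())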
@@ -223,7 +227,10 @@ def _run_streaming_updates( ) -> AsyncGenerator[ResponseStreamEvent, Any]: """Execute a streaming run with shared OAuth/error handling.""" logger.info("Running agent in streaming mode") - streaming_converter = AgentFrameworkOutputStreamingConverter(context, hitl_helper=self._hitl_helper) + streaming_converter = AgentFrameworkOutputStreamingConverter( + context, + hitl_helper=self._hitl_helper, + ) async def stream_updates(): try: @@ -250,7 +257,9 @@ async def stream_updates(): ) yield ResponseFailedEvent( sequence_number=streaming_converter.next_sequence(), - response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access + response=streaming_converter._build_response( + status="failed" + ), # pylint: disable=protected-access ) except Exception as e: # pylint: disable=broad-exception-caught logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True) @@ -262,7 +271,9 @@ async def stream_updates(): ) yield ResponseFailedEvent( sequence_number=streaming_converter.next_sequence(), - response=streaming_converter._build_response(status="failed"), # pylint: disable=protected-access + response=streaming_converter._build_response( + status="failed" + ), # pylint: disable=protected-access ) finally: # No request-scoped resources to clean up today, but keep hook for future use. diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py index 92667eb7cbcc..29280428dbba 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py @@ -1,7 +1,16 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Optional, Protocol, Union, List +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Callable, + List, + Optional, + Protocol, + Union, +) from agent_framework import Workflow, CheckpointStorage, WorkflowAgent, WorkflowCheckpoint from agent_framework._workflows import get_checkpoint_summary @@ -59,7 +68,10 @@ async def agent_run( # pylint: disable=too-many-statements if selected_checkpoint: summary = get_checkpoint_summary(selected_checkpoint) if summary.status == "completed": - logger.warning(f"Selected checkpoint {selected_checkpoint.checkpoint_id} is completed. Will not resume from it.") + logger.warning( + "Selected checkpoint %s is completed. 
Will not resume from it.", + selected_checkpoint.checkpoint_id, + ) selected_checkpoint = None # Do not resume from completed checkpoints else: await self._load_checkpoint(agent, selected_checkpoint, checkpoint_storage) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py index c272373b2291..4984b2fc0423 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py @@ -19,12 +19,7 @@ from azure.ai.agentserver.core import AgentRunContext from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.models import Response as OpenAIResponse -from azure.ai.agentserver.core.models.projects import ( - AgentId, - CreatedBy, - ItemContentOutputText, - ResponsesAssistantMessageItemResource, -) +from azure.ai.agentserver.core.models.projects import ItemContentOutputText from .agent_id_generator import AgentIdGenerator from .constants import Constants @@ -189,7 +184,12 @@ def _append_function_call_content(self, content: FunctionCallContent, sink: List len(arguments or ""), ) - def _append_function_result_content(self, content: FunctionResultContent, sink: List[dict], author_name: str) -> None: + def _append_function_result_content( + self, + content: FunctionResultContent, + sink: List[dict], + author_name: str, + ) -> None: # Coerce the function result into a simple display string. 
result = [] raw = getattr(content, "result", None) @@ -211,13 +211,19 @@ def _append_function_result_content(self, content: FunctionResultContent, sink: } ) logger.debug( - "added function_call_output item id=%s call_id=%s output_len=%d", + "added function_call_output item id=%s call_id=%s " + "output_len=%d", func_out_id, call_id, len(result), ) - def _append_user_input_request_contents(self, content: UserInputRequestContents, sink: List[dict], author_name: str) -> None: + def _append_user_input_request_contents( + self, + content: UserInputRequestContents, + sink: List[dict], + author_name: str, + ) -> None: item_id = self._context.id_generator.generate_function_call_id() content = self._hitl_helper.convert_user_input_request_content(content) sink.append( @@ -231,7 +237,11 @@ def _append_user_input_request_contents(self, content: UserInputRequestContents, "created_by": self._build_created_by(author_name), } ) - logger.debug(" added user_input_request item id=%s call_id=%s", item_id, content["call_id"]) + logger.debug( + " added user_input_request item id=%s call_id=%s", + item_id, + content["call_id"], + ) # ------------- simple normalization helper ------------------------- def _coerce_result_text(self, value: Any) -> str | dict: diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py index e216dbb767b6..4b281530f74e 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py @@ -5,12 +5,15 @@ # mypy: disable-error-code="call-overload,assignment,arg-type,override" from __future__ import annotations -from ast import arguments import datetime import json from typing import Any, AsyncIterable, List, Union -from agent_framework import AgentRunResponseUpdate, BaseContent, FunctionApprovalRequestContent, FunctionResultContent +from agent_framework import ( + AgentRunResponseUpdate, + BaseContent, + FunctionResultContent, +) from agent_framework._types import ( ErrorContent, FunctionCallContent, @@ -24,8 +27,6 @@ ResponseStreamEvent, ) from azure.ai.agentserver.core.models.projects import ( - AgentId, - CreatedBy, FunctionToolCallItemResource, FunctionToolCallOutputItemResource, ItemContentOutputText, @@ -52,7 +53,12 @@ class _BaseStreamingState: """Base interface for streaming state handlers.""" - async def convert_contents(self, contents: AsyncIterable[BaseContent], author_name: str) -> AsyncIterable[ResponseStreamEvent]: # pylint: disable=unused-argument + async def convert_contents( + self, + contents: AsyncIterable[BaseContent], + author_name: str, + ) -> AsyncIterable[ResponseStreamEvent]: + # pylint: disable=unused-argument raise NotImplementedError @@ -62,7 +68,11 @@ class _TextContentStreamingState(_BaseStreamingState): def __init__(self, parent: AgentFrameworkOutputStreamingConverter): self._parent = parent - async def convert_contents(self, contents: AsyncIterable[TextContent], author_name: str) -> AsyncIterable[ResponseStreamEvent]: + async def convert_contents( + self, + contents: AsyncIterable[TextContent], + author_name: str, + ) -> AsyncIterable[ResponseStreamEvent]: item_id = 
self._parent.context.id_generator.generate_message_id() output_index = self._parent.next_output_index() @@ -381,7 +391,11 @@ async def convert(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> Async elif isinstance(first, FunctionResultContent): state = _FunctionCallOutputStreamingState(self) elif isinstance(first, ErrorContent): - raise ValueError(f"ErrorContent received: code={first.error_code}, message={first.message}") + error_msg = ( + f"ErrorContent received: code={first.error_code}, " + f"message={first.message}" + ) + raise ValueError(error_msg) if not state: continue @@ -410,7 +424,10 @@ def _build_created_by(self, author_name: str) -> dict: "response_id": self._response_id, } - async def _read_updates(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> AsyncIterable[tuple[BaseContent, str]]: + async def _read_updates( + self, + updates: AsyncIterable[AgentRunResponseUpdate], + ) -> AsyncIterable[tuple[BaseContent, str]]: async for update in updates: if not update.contents: continue diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py index 293c8ec88805..41f5b6a1f5cd 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py @@ -38,7 +38,11 @@ class InMemoryAgentThreadRepository(AgentThreadRepository): def __init__(self) -> None: self._inventory: dict[str, AgentThread] = {} - async def get(self, conversation_id: str, agent: Optional[Union[AgentProtocol, WorkflowAgent]]=None) -> Optional[AgentThread]: + async def get( + self, + conversation_id: str, + agent: Optional[Union[AgentProtocol, WorkflowAgent]] = None, + ) -> Optional[AgentThread]: """Retrieve the saved thread for a given conversation ID. :param conversation_id: The conversation ID. @@ -75,7 +79,11 @@ def __init__(self, agent: AgentProtocol) -> None: """ self._agent = agent - async def get(self, conversation_id: str, agent: Optional[Union[AgentProtocol, WorkflowAgent]]=None) -> Optional[AgentThread]: + async def get( + self, + conversation_id: str, + agent: Optional[Union[AgentProtocol, WorkflowAgent]] = None, + ) -> Optional[AgentThread]: """Retrieve the saved thread for a given conversation ID. :param conversation_id: The conversation ID. 
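Before the checkpoint repository diff below, the corresponding wiring for workflows; a sketch taken almost verbatim from the human_in_the_loop_workflow_agent sample in this series (create_builder is that sample's helper, and the repository import path is assumed):

    import asyncio

    from azure.ai.agentserver.agentframework import from_agent_framework
    from azure.ai.agentserver.agentframework.persistence import FileCheckpointRepository

    async def run_agent() -> None:
        builder = create_builder()  # helper from the sample, returns a WorkflowBuilder
        await from_agent_framework(
            workflow=builder,  # pass workflow builder to adapter
            checkpoint_repository=FileCheckpointRepository(storage_path="./checkpoints"),
        ).run_async()

    if __name__ == "__main__":
        asyncio.run(run_agent())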
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py index c4a1473c5f6b..e32784850838 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod import os -from typing import Any, Optional +from typing import Optional from agent_framework import ( CheckpointStorage, From fcd73b9d16c1fffa294c21d42e73692a5edb9a5f Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Wed, 21 Jan 2026 11:09:18 -0800 Subject: [PATCH 12/29] fixed pylint --- .../ai/agentserver/agentframework/__init__.py | 24 +++++------ .../agentframework/_agent_framework.py | 40 +++++++++++++------ .../agentframework/_ai_agent_adapter.py | 8 ++-- .../agentframework/_foundry_tools.py | 4 +- .../agentframework/_workflow_agent_adapter.py | 26 +++++++----- .../agent_framework_input_converters.py | 4 +- ...nt_framework_output_streaming_converter.py | 8 ++-- .../models/human_in_the_loop_helper.py | 7 +++- .../agentframework/persistence/__init__.py | 2 +- .../persistence/agent_thread_repository.py | 11 ++++- .../persistence/checkpoint_repository.py | 3 ++ 11 files changed, 84 insertions(+), 53 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py index a213d06fb00b..04a8459c53d9 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py @@ -12,7 +12,7 @@ from azure.ai.agentserver.agentframework._ai_agent_adapter import AgentFrameworkAIAgentAdapter from azure.ai.agentserver.agentframework._workflow_agent_adapter import AgentFrameworkWorkflowAdapter from azure.ai.agentserver.agentframework._foundry_tools import FoundryToolsChatMiddleware -from azure.ai.agentserver.core.application import PackageMetadata, set_current_app +from azure.ai.agentserver.core.application import PackageMetadata, set_current_app # pylint: disable=import-error,no-name-in-module if TYPE_CHECKING: # pragma: no cover from azure.core.credentials_async import AsyncTokenCredential @@ -28,12 +28,10 @@ def from_agent_framework( """ Create an Agent Framework AI Agent Adapter from an AgentProtocol or BaseAgent. - :param agent: The agent to adapt. + :keyword agent: The agent to adapt. :type agent: Union[BaseAgent, AgentProtocol] - :param credentials: Optional asynchronous token credential for authentication. + :keyword credentials: Optional asynchronous token credential for authentication. :type credentials: Optional[AsyncTokenCredential] - :param kwargs: Additional keyword arguments to pass to the adapter. - :type kwargs: Any :return: An instance of AgentFrameworkAIAgentAdapter. :rtype: AgentFrameworkAIAgentAdapter @@ -56,12 +54,10 @@ def from_agent_framework( workflow definition can be converted to a WorkflowAgent. For more information, see the agent-framework samples and documentation. - :param workflow: The workflow builder or factory function to adapt. 
+ :keyword workflow: The workflow builder or factory function to adapt. :type workflow: Union[WorkflowBuilder, Callable[[], Workflow]] - :param credentials: Optional asynchronous token credential for authentication. + :keyword credentials: Optional asynchronous token credential for authentication. :type credentials: Optional[AsyncTokenCredential] - :param kwargs: Additional keyword arguments to pass to the adapter. - :type kwargs: Any :return: An instance of AgentFrameworkWorkflowAdapter. :rtype: AgentFrameworkWorkflowAdapter """ @@ -79,14 +75,12 @@ def from_agent_framework( WorkflowAgent. One of agent or workflow must be provided. - :param agent: The agent to adapt. + :keyword agent: The agent to adapt. :type agent: Optional[Union[BaseAgent, AgentProtocol]] - :param workflow: The workflow builder or factory function to adapt. + :keyword workflow: The workflow builder or factory function to adapt. :type workflow: Optional[Union[WorkflowBuilder, Callable[[], Workflow]]] - :param credentials: Optional asynchronous token credential for authentication. + :keyword credentials: Optional asynchronous token credential for authentication. :type credentials: Optional[AsyncTokenCredential] - :param kwargs: Additional keyword arguments to pass to the adapter. - :type kwargs: Any :return: An instance of AgentFrameworkAgent. :rtype: AgentFrameworkAgent :raises TypeError: If neither or both of agent and workflow are provided, or if @@ -107,7 +101,7 @@ def workflow_factory() -> Workflow: return AgentFrameworkWorkflowAdapter(workflow_factory=workflow, credentials=credentials, **kwargs) raise TypeError("workflow must be a WorkflowBuilder or callable returning a Workflow") - if isinstance(agent, AgentProtocol) or isinstance(agent, BaseAgent): + if isinstance(agent, (AgentProtocol, BaseAgent)): return AgentFrameworkAIAgentAdapter(agent, credentials=credentials, **kwargs) raise TypeError("agent must be an instance of AgentProtocol or BaseAgent") diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py index 749c084cf5ff..732b70095028 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -19,7 +19,7 @@ ResponseStreamEvent, ) from azure.ai.agentserver.core.models.projects import ResponseErrorEvent, ResponseFailedEvent -from azure.ai.agentserver.core.tools import OAuthConsentRequiredError +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError # pylint: disable=import-error from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter from .models.human_in_the_loop_helper import HumanInTheLoopHelper @@ -98,7 +98,7 @@ def init_tracing(self): logger.info("Observability setup completed with provided exporters.") elif project_endpoint: self._setup_tracing_with_azure_ai_client(project_endpoint) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught logger.warning(f"Failed to initialize tracing: {e}", exc_info=True) self.tracer = trace.get_tracer(__name__) @@ -107,7 +107,7 @@ def _create_application_insights_exporter(self, connection_string): from azure.monitor.opentelemetry.exporter import AzureMonitorTraceExporter return 
AzureMonitorTraceExporter.from_connection_string(connection_string) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught logger.error(f"Failed to create Application Insights exporter: {e}", exc_info=True) return None @@ -117,11 +117,11 @@ def _create_otlp_exporter(self, endpoint, protocol=None): from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter return OTLPSpanExporter(endpoint=endpoint) - else: - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter - return OTLPSpanExporter(endpoint=endpoint) - except Exception as e: + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + + return OTLPSpanExporter(endpoint=endpoint) + except Exception as e: # pylint: disable=broad-exception-caught logger.error(f"Failed to create OTLP exporter: {e}", exc_info=True) return None @@ -213,6 +213,9 @@ async def _save_agent_thread(self, context: AgentRunContext, agent_thread: Agent :type context: AgentRunContext :param agent_thread: The agent thread to save. :type agent_thread: AgentThread + + :return: None + :rtype: None """ if agent_thread and self._thread_repository: await self._thread_repository.set(context.conversation_id, agent_thread) @@ -220,12 +223,23 @@ async def _save_agent_thread(self, context: AgentRunContext, agent_thread: Agent def _run_streaming_updates( self, - *, context: AgentRunContext, run_stream: Callable[[], AsyncGenerator[Any, None]], agent_thread: Optional[AgentThread] = None, ) -> AsyncGenerator[ResponseStreamEvent, Any]: - """Execute a streaming run with shared OAuth/error handling.""" + """ + Execute a streaming run with shared OAuth/error handling. + + :param context: The agent run context. + :type context: AgentRunContext + :param run_stream: A callable that invokes the agent in stream mode + :type run_stream: Callable[[], AsyncGenerator[Any, None]] + :param agent_thread: The agent thread to use during streaming updates. + :type agent_thread: Optional[AgentThread] + + :return: An async generator yielding streaming events. + :rtype: AsyncGenerator[ResponseStreamEvent, Any] + """ logger.info("Running agent in streaming mode") streaming_converter = AgentFrameworkOutputStreamingConverter( context, @@ -257,9 +271,9 @@ async def stream_updates(): ) yield ResponseFailedEvent( sequence_number=streaming_converter.next_sequence(), - response=streaming_converter._build_response( + response=streaming_converter._build_response( # pylint: disable=protected-access status="failed" - ), # pylint: disable=protected-access + ), ) except Exception as e: # pylint: disable=broad-exception-caught logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True) @@ -271,9 +285,9 @@ async def stream_updates(): ) yield ResponseFailedEvent( sequence_number=streaming_converter.next_sequence(), - response=streaming_converter._build_response( + response=streaming_converter._build_response( # pylint: disable=protected-access status="failed" - ), # pylint: disable=protected-access + ), ) finally: # No request-scoped resources to clean up today, but keep hook for future use. 
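For orientation, the calling pattern the adapters below follow when streaming; a condensed, hypothetical sketch of an agent_run body (`message` stands in for the converted request input, and the OAuth/failure handling shown above is inherited from the helper):

    # Hypothetical subclass body; names mirror the adapters in this patch series.
    async def agent_run(self, context: AgentRunContext):
        agent_thread = await self._load_agent_thread(context, self._agent)
        # `message` would come from the input converter (assumed here).
        if context.stream:
            # The helper wraps the stream with the shared error handling and
            # emits ResponseFailedEvent when the run fails.
            return self._run_streaming_updates(
                context=context,
                run_stream=lambda: self._agent.run_stream(message, thread=agent_thread),
                agent_thread=agent_thread,
            )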
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py index 6105470dbdc9..622fb2762e7b 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py @@ -1,6 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +# pylint: disable=no-name-in-module,import-error from __future__ import annotations from typing import Any, AsyncGenerator, Optional, Union @@ -40,7 +41,7 @@ async def agent_run( # pylint: disable=too-many-statements AsyncGenerator[ResponseStreamEvent, Any], ]: try: - logger.info(f"Starting AIAgent agent_run with stream={context.stream}") + logger.info("Starting AIAgent agent_run with stream=%s", context.stream) request_input = context.request.get("input") agent_thread = await self._load_agent_thread(context, self._agent) @@ -49,8 +50,7 @@ async def agent_run( # pylint: disable=too-many-statements message = await input_converter.transform_input( request_input, agent_thread=agent_thread) - logger.debug(f"Transformed input message type: {type(message)}") - + logger.debug("Transformed input message type: %s", type(message)) # Use split converters if context.stream: return self._run_streaming_updates( @@ -67,7 +67,7 @@ async def agent_run( # pylint: disable=too-many-statements result = await self.agent.run( message, thread=agent_thread) - logger.debug(f"Agent run completed, result type: {type(result)}") + logger.debug("Agent run completed, result type: %s", type(result)) await self._save_agent_thread(context, agent_thread) non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py index 10e22ac128d5..78d8108ed96c 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py @@ -1,6 +1,8 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- +# pylint: disable=client-accepts-api-version-keyword,missing-client-constructor-parameter-credential,missing-client-constructor-parameter-kwargs +# pylint: disable=no-name-in-module,import-error from __future__ import annotations import inspect @@ -24,7 +26,7 @@ def _attach_signature_from_pydantic_model(func, input_model) -> None: ann = field.annotation or Any annotations[name] = ann - default = inspect._empty if field.is_required() else field.default + default = inspect._empty if field.is_required() else field.default # pylint: disable=protected-access params.append( inspect.Parameter( name=name, diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py index 29280428dbba..fb40cb453124 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py @@ -1,27 +1,25 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +# pylint: disable=no-name-in-module,import-error from typing import ( - TYPE_CHECKING, Any, AsyncGenerator, Callable, - List, Optional, - Protocol, Union, ) from agent_framework import Workflow, CheckpointStorage, WorkflowAgent, WorkflowCheckpoint from agent_framework._workflows import get_checkpoint_summary -from azure.ai.agentserver.core.tools import OAuthConsentRequiredError from azure.ai.agentserver.core import AgentRunContext from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.models import ( Response as OpenAIResponse, ResponseStreamEvent, ) +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError from ._agent_framework import AgentFrameworkAgent from .models.agent_framework_input_converters import AgentFrameworkInputConverter @@ -55,7 +53,7 @@ async def agent_run( # pylint: disable=too-many-statements try: agent = self._build_agent() - logger.info(f"Starting WorkflowAgent agent_run with stream={context.stream}") + logger.info("Starting WorkflowAgent agent_run with stream=%s", context.stream) request_input = context.request.get("input") agent_thread = await self._load_agent_thread(context, agent) @@ -75,14 +73,14 @@ async def agent_run( # pylint: disable=too-many-statements selected_checkpoint = None # Do not resume from completed checkpoints else: await self._load_checkpoint(agent, selected_checkpoint, checkpoint_storage) - logger.info(f"Loaded checkpoint with ID: {selected_checkpoint.checkpoint_id}") + logger.info("Loaded checkpoint with ID: %s", selected_checkpoint.checkpoint_id) input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper) message = await input_converter.transform_input( request_input, agent_thread=agent_thread, checkpoint=selected_checkpoint) - logger.debug(f"Transformed input message type: {type(message)}") + logger.debug("Transformed input message type: %s", type(message)) # Use split converters if context.stream: @@ -102,7 +100,7 @@ async def agent_run( # pylint: disable=too-many-statements message, thread=agent_thread, checkpoint_storage=checkpoint_storage) - logger.debug(f"WorkflowAgent run completed, result type: 
{type(result)}") + logger.debug("WorkflowAgent run completed, result type: %s", type(result)) await self._save_agent_thread(context, agent_thread) @@ -142,13 +140,21 @@ async def _get_latest_checkpoint(self, return latest_checkpoint return None - async def _load_checkpoint(self, agent: WorkflowAgent, + async def _load_checkpoint(self, + agent: WorkflowAgent, checkpoint: WorkflowCheckpoint, checkpoint_storage: CheckpointStorage) -> None: """Load the checkpoint data from the given WorkflowCheckpoint. + :param agent: The WorkflowAgent to load the checkpoint into. + :type agent: WorkflowAgent :param checkpoint: The WorkflowCheckpoint to load data from. :type checkpoint: WorkflowCheckpoint + :param checkpoint_storage: The storage to load the checkpoint from. + :type checkpoint_storage: CheckpointStorage + + :return: None + :rtype: None """ await agent.run(checkpoint_id=checkpoint.checkpoint_id, - checkpoint_storage=checkpoint_storage) \ No newline at end of file + checkpoint_storage=checkpoint_storage) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py index 2ad8ab3cee11..4e3feb52a29f 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py @@ -50,11 +50,11 @@ async def transform_input( if agent_thread and agent_thread.message_store: thread_messages = await agent_thread.message_store.list_messages() pending_hitl_requests = self._hitl_helper.get_pending_hitl_request(thread_messages, checkpoint) - logger.info(f"Pending HitL requests: {list(pending_hitl_requests.keys())}") + logger.info("Pending HitL requests: %s", list(pending_hitl_requests.keys())) hitl_response = self._hitl_helper.validate_and_convert_hitl_response( input, pending_requests=pending_hitl_requests) - logger.info(f"HitL response validation result: {[m.to_dict() for m in hitl_response]}") + logger.info("HitL response validation result: %s", [m.to_dict() for m in hitl_response]) if hitl_response: return hitl_response diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py index 4b281530f74e..805f3fc79ead 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py @@ -1,7 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- -# pylint: disable=attribute-defined-outside-init,protected-access +# pylint: disable=attribute-defined-outside-init,protected-access,unnecessary-lambda-assignment # mypy: disable-error-code="call-overload,assignment,arg-type,override" from __future__ import annotations @@ -270,7 +270,7 @@ def _serialize_arguments(self, arguments: Any) -> str: return arguments try: return json.dumps(arguments) - except Exception as e: + except Exception: # pylint: disable=broad-exception-caught return str(arguments) @@ -373,7 +373,7 @@ async def convert(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> Async is_changed = ( lambda a, b: a is not None \ and b is not None \ - and a.message_id != b.message_id # pylint: disable=unnecessary-lambda-assignment + and a.message_id != b.message_id ) async for group in chunk_on_change(updates, is_changed): @@ -401,7 +401,7 @@ async def convert(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> Async # Extract just the content from (content, author_name) tuples using async generator async def extract_contents(): - async for content, _ in contents_with_author: + async for content, _ in contents_with_author: # pylint: disable=cell-var-from-loop yield content async for content in state.convert_contents(extract_contents(), author_name): diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py index a40eac7f7a7d..4b3dce2c1bdb 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py @@ -1,3 +1,6 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# ---------------------------------------------------------
 from typing import Any, List, Dict, Optional, Union
 import json
@@ -87,9 +90,9 @@ def convert_request_arguments(self, arguments: Any) -> str:
         return arguments
 
     def validate_and_convert_hitl_response(self,
-                                           input: str | List[Dict] | None,
+                                           input: Union[str, List[Dict], None],
                                            pending_requests: Dict[str, RequestInfoEvent],
-                                           ) -> List[ChatMessage] | None:
+                                           ) -> Optional[List[ChatMessage]]:
 
         if input is None or isinstance(input, str):
             logger.warning("Expected list input for HitL response validation, got str.")
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py
index 40ce839556bd..cf07cb449d00 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py
@@ -18,4 +18,4 @@
     "CheckpointRepository",
     "InMemoryCheckpointRepository",
     "FileCheckpointRepository",
-]
\ No newline at end of file
+]
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py
index 41f5b6a1f5cd..66528ff96213 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py
@@ -1,3 +1,6 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
 from abc import ABC, abstractmethod
 import json
 import os
@@ -10,7 +13,11 @@ class AgentThreadRepository(ABC):
     """AgentThread repository to manage saved thread messages of agent threads and workflows."""
 
     @abstractmethod
-    async def get(self, conversation_id: str, agent: Optional[Union[AgentProtocol, WorkflowAgent]]=None) -> Optional[AgentThread]:
+    async def get(
+        self,
+        conversation_id: str,
+        agent: Optional[Union[AgentProtocol, WorkflowAgent]] = None,
+    ) -> Optional[AgentThread]:
         """Retrieve the saved thread for a given conversation ID.
 
         :param conversation_id: The conversation ID.
@@ -131,6 +138,8 @@ async def write_to_storage(self, conversation_id: str, serialized_thread: Any) -
         :type conversation_id: str
         :param serialized_thread: The serialized thread to save.
:type serialized_thread: Any + :return: None + :rtype: None """ raise NotImplementedError("write_to_storage is not implemented.") diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py index e32784850838..471d3a2f7f84 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py @@ -1,3 +1,6 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- from abc import ABC, abstractmethod import os from typing import Optional From 009f1457496aca3bd8a63f5691a7274681ab6436 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Wed, 21 Jan 2026 13:26:32 -0800 Subject: [PATCH 13/29] fix unittest in langgraph --- .../unit_tests/test_langgraph_request_converter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py index 84a8c8784d8b..b1894f7350d5 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py @@ -3,7 +3,7 @@ from azure.ai.agentserver.core import models from azure.ai.agentserver.core.models import projects as project_models -from azure.ai.agentserver.langgraph import models as langgraph_models +from azure.ai.agentserver.langgraph.models.response_api_request_converter import ResponseAPIMessageRequestConverter @pytest.mark.unit @@ -16,7 +16,7 @@ def test_convert_implicit_user_message(): input=[implicit_user_message], ) - converter = langgraph_models.LangGraphRequestConverter(create_response) + converter = ResponseAPIMessageRequestConverter(create_response) res = converter.convert() assert "messages" in res @@ -34,7 +34,7 @@ def test_convert_implicit_user_message_with_contents(): ] create_response = models.CreateResponse(input=[{"content": input_data}]) - converter = langgraph_models.LangGraphRequestConverter(create_response) + converter = ResponseAPIMessageRequestConverter(create_response) res = converter.convert() assert "messages" in res @@ -61,7 +61,7 @@ def test_convert_item_param_message(): create_response = models.CreateResponse( input=input_data, ) - converter = langgraph_models.LangGraphRequestConverter(create_response) + converter = ResponseAPIMessageRequestConverter(create_response) res = converter.convert() assert "messages" in res @@ -103,7 +103,7 @@ def test_convert_item_param_function_call_and_function_call_output(): create_response = models.CreateResponse( input=input_data, ) - converter = langgraph_models.LangGraphRequestConverter(create_response) + converter = ResponseAPIMessageRequestConverter(create_response) res = converter.convert() assert "messages" in res assert len(res["messages"]) == len(input_data) From 8db2c1ab1153a685befdd02749ba22b3f32dcace Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Wed, 21 Jan 2026 16:48:03 -0800 Subject: [PATCH 14/29] fix sphinx for core 
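The docs build parses the split :param:/:type: field form more reliably than
the inline ":param str name:" form, and wrapped field bodies need consistent
indentation. A minimal sketch of the target docstring style (hypothetical
function, for illustration only):

    async def fetch(endpoint, credential):
        """Fetch data from the service.

        :param endpoint: The fully qualified service endpoint.
        :type endpoint: str
        :param credential: Credential used to authenticate requests.
        :type credential: ~azure.core.credentials_async.AsyncTokenCredential
        :return: The parsed response payload.
        :rtype: dict
        """

The new .rst files below register the previously undocumented subpackages
(application, models, tools, utils) with autodoc.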
--- .../core/server/common/agent_run_context.py | 3 +++ .../agentserver/core/tools/client/_client.py | 26 ++++++++++++------- .../core/tools/runtime/_starlette.py | 7 ++--- .../azure.ai.agentserver.core.application.rst | 7 +++++ ...zure.ai.agentserver.core.models.openai.rst | 8 ++++++ ...re.ai.agentserver.core.models.projects.rst | 8 ++++++ .../doc/azure.ai.agentserver.core.models.rst | 17 ++++++++++++ .../doc/azure.ai.agentserver.core.rst | 4 +++ ...zure.ai.agentserver.core.server.common.rst | 9 ++++++- ...azure.ai.agentserver.core.tools.client.rst | 7 +++++ .../doc/azure.ai.agentserver.core.tools.rst | 18 +++++++++++++ ...zure.ai.agentserver.core.tools.runtime.rst | 7 +++++ .../azure.ai.agentserver.core.tools.utils.rst | 7 +++++ .../doc/azure.ai.agentserver.core.utils.rst | 7 +++++ 14 files changed, 122 insertions(+), 13 deletions(-) create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.openai.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.projects.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py index 53eb15af3550..2464179a119f 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py @@ -13,6 +13,9 @@ class AgentRunContext: + """ + :meta private: + """ def __init__(self, payload: dict, **kwargs: Any) -> None: self._raw_payload = payload self._request = _deserialize_create_response(payload) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py index 030fbe26b5e7..a998de7f9597 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py @@ -43,12 +43,13 @@ class FoundryToolClient(AsyncContextManager["FoundryToolClient"]): # pylint: di This client provides access to tools from both MCP (Model Context Protocol) servers and Azure AI Tools API endpoints, enabling unified tool discovery and invocation. - :param str endpoint: - The fully qualified endpoint for the Azure AI Agents service. - Example: "https://.api.azureml.ms" + :param endpoint: + The fully qualified endpoint for the Azure AI Agents service. + Example: "https://.api.azureml.ms" + :type endpoint: str :param credential: - Credential for authenticating requests to the service. 
- Use credentials from azure-identity like DefaultAzureCredential. + Credential for authenticating requests to the service. + Use credentials from azure-identity like DefaultAzureCredential. :type credential: ~azure.core.credentials.TokenCredential :param api_version: The API version to use for this operation. :type api_version: str or None @@ -87,18 +88,21 @@ async def list_tools( Retrieves tools from both MCP servers and Azure AI Tools API endpoints, returning them as ResolvedFoundryTool instances ready for invocation. + :param tools: Collection of FoundryTool instances to resolve. :type tools: Collection[~FoundryTool] :param user: Information about the user requesting the tools. :type user: Optional[UserInfo] :param agent_name: Name of the agent requesting the tools. :type agent_name: str + :return: List of resolved Foundry tools. :rtype: List[ResolvedFoundryTool] :raises ~azure.ai.agentserver.core.tools._exceptions.OAuthConsentRequiredError: - Raised when the service requires user OAuth consent. + Raised when the service requires user OAuth consent. :raises ~azure.core.exceptions.HttpResponseError: - Raised for HTTP communication failures. + Raised for HTTP communication failures. + """ _ = kwargs # Reserved for future use resolved_tools: List[ResolvedFoundryTool] = [] @@ -119,18 +123,21 @@ async def list_tools_details( Retrieves tools from both MCP servers and Azure AI Tools API endpoints, returning them as ResolvedFoundryTool instances ready for invocation. + :param tools: Collection of FoundryTool instances to resolve. :type tools: Collection[~FoundryTool] :param user: Information about the user requesting the tools. :type user: Optional[UserInfo] :param agent_name: Name of the agent requesting the tools. :type agent_name: str + :return: Mapping of tool IDs to lists of FoundryToolDetails. :rtype: Mapping[str, List[FoundryToolDetails]] :raises ~azure.ai.agentserver.core.tools._exceptions.OAuthConsentRequiredError: - Raised when the service requires user OAuth consent. + Raised when the service requires user OAuth consent. :raises ~azure.core.exceptions.HttpResponseError: - Raised for HTTP communication failures. + Raised for HTTP communication failures. + """ _ = kwargs # Reserved for future use resolved_tools: Dict[str, List[FoundryToolDetails]] = defaultdict(list) @@ -187,6 +194,7 @@ async def invoke_tool( Raised for HTTP communication failures. :raises ~ToolInvocationError: Raised when the tool invocation fails or source is not supported. + """ _ = kwargs # Reserved for future use if tool.source is FoundryToolSource.HOSTED_MCP: diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py index f60fb63f2cdc..80b25d78b20e 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py @@ -33,11 +33,12 @@ def install(cls, :param app: The Starlette application to install the middleware into. :type app: Starlette :param user_context: Optional context variable to use for storing user info. - If not provided, a default context variable will be used. - :type user_context: Optional[ContextVar[Optional[UserInfo]]] + If not provided, a default context variable will be used. 
+ :type user_context: Optional[ContextVar[Optional[UserInfo]]] :param user_resolver: Optional function to resolve user info from the request. - If not provided, a default resolver will be used. + If not provided, a default resolver will be used. :type user_resolver: Optional[Callable[[Request], Awaitable[Optional[UserInfo]]]] + """ app.add_middleware(UserInfoContextMiddleware, # type: ignore[arg-type] user_info_var=user_context or ContextVarUserProvider.default_user_info_context, diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst new file mode 100644 index 000000000000..415b7d3b2538 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.application package +============================================= + +.. automodule:: azure.ai.agentserver.core.application + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.openai.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.openai.rst new file mode 100644 index 000000000000..dd1cce6eecca --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.openai.rst @@ -0,0 +1,8 @@ +azure.ai.agentserver.core.models.openai package +=============================================== + +.. automodule:: azure.ai.agentserver.core.models.openai + :inherited-members: + :members: + :undoc-members: + :ignore-module-all: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.projects.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.projects.rst new file mode 100644 index 000000000000..38e0be4f331b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.projects.rst @@ -0,0 +1,8 @@ +azure.ai.agentserver.core.models.projects package +================================================= + +.. automodule:: azure.ai.agentserver.core.models.projects + :inherited-members: + :members: + :undoc-members: + :ignore-module-all: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst new file mode 100644 index 000000000000..008b280c64de --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst @@ -0,0 +1,17 @@ +azure.ai.agentserver.core.models package +======================================== + +.. automodule:: azure.ai.agentserver.core.models + :inherited-members: + :members: + :undoc-members: + :ignore-module-all: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure.ai.agentserver.core.models.openai + azure.ai.agentserver.core.models.projects diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst index da01b083b0b3..b8f1dadf3a73 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst @@ -12,7 +12,11 @@ Subpackages .. 
toctree:: :maxdepth: 4 + azure.ai.agentserver.core.application + azure.ai.agentserver.core.models azure.ai.agentserver.core.server + azure.ai.agentserver.core.tools + azure.ai.agentserver.core.utils Submodules ---------- diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst index 01e54afab103..8fb5b52e4465 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst @@ -24,4 +24,11 @@ azure.ai.agentserver.core.server.common.agent\_run\_context module :inherited-members: :members: :undoc-members: - :no-index: + +azure.ai.agentserver.core.server.common.constants module +-------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.core.server.common.constants + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst new file mode 100644 index 000000000000..8182914f69f9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.tools.client package +============================================== + +.. automodule:: azure.ai.agentserver.core.tools.client + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst new file mode 100644 index 000000000000..c112ec2beabd --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst @@ -0,0 +1,18 @@ +azure.ai.agentserver.core.tools package +======================================= + +.. automodule:: azure.ai.agentserver.core.tools + :inherited-members: + :members: + :undoc-members: + :exclude-members: BaseModel,model_json_schema + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure.ai.agentserver.core.tools.client + azure.ai.agentserver.core.tools.runtime + azure.ai.agentserver.core.tools.utils diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst new file mode 100644 index 000000000000..c502d56b42f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.tools.runtime package +=============================================== + +.. automodule:: azure.ai.agentserver.core.tools.runtime + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst new file mode 100644 index 000000000000..94d3f310e112 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.tools.utils package +============================================= + +.. 
automodule:: azure.ai.agentserver.core.tools.utils
+   :inherited-members:
+   :members:
+   :undoc-members:
diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst
new file mode 100644
index 000000000000..5250167cf7e6
--- /dev/null
+++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst
@@ -0,0 +1,7 @@
+azure.ai.agentserver.core.utils package
+=======================================
+
+.. automodule:: azure.ai.agentserver.core.utils
+   :inherited-members:
+   :members:
+   :undoc-members:

From 4a0276560fd32d49e2e6be68cae7c60b87479714 Mon Sep 17 00:00:00 2001
From: Lu Sun
Date: Wed, 21 Jan 2026 18:31:33 -0800
Subject: [PATCH 15/29] fix minors

---
 .../ai/agentserver/agentframework/__init__.py | 44 -------------------
 .../custom_mock_agent_test.py                 |  2 +-
 2 files changed, 1 insertion(+), 45 deletions(-)

diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py
index 19373d6d5f4f..04a8459c53d9 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py
@@ -27,21 +27,11 @@ def from_agent_framework(
 ) -> "AgentFrameworkAIAgentAdapter":
     """
     Create an Agent Framework AI Agent Adapter from an AgentProtocol or BaseAgent.
-<<<<<<< HEAD

    :keyword agent: The agent to adapt.
    :type agent: Union[BaseAgent, AgentProtocol]
    :keyword credentials: Optional asynchronous token credential for authentication.
    :type credentials: Optional[AsyncTokenCredential]
-=======
-
-    :param agent: The agent to adapt.
-    :type agent: Union[BaseAgent, AgentProtocol]
-    :param credentials: Optional asynchronous token credential for authentication.
-    :type credentials: Optional[AsyncTokenCredential]
-    :param kwargs: Additional keyword arguments to pass to the adapter.
-    :type kwargs: Any
->>>>>>> lusu/agentserver-1110

    :return: An instance of AgentFrameworkAIAgentAdapter.
    :rtype: AgentFrameworkAIAgentAdapter
@@ -57,32 +47,17 @@ def from_agent_framework(
 ) -> "AgentFrameworkWorkflowAdapter":
     """
     Create an Agent Framework Workflow Adapter.
-<<<<<<< HEAD
    The argument `workflow` can be either a WorkflowBuilder or a factory function that
    returns a Workflow.
    It will be called to create a new Workflow instance and `.as_agent()` will be
-=======
-    The arugument `workflow` can be either a WorkflowBuilder or a factory function
-    that returns a Workflow.
-    It will be called to create a new Workflow instance and `.as_agent()` will be
->>>>>>> lusu/agentserver-1110
    called as well for each incoming CreateResponse request.
    Please ensure that the workflow definition can be converted to a WorkflowAgent.
    For more information, see the agent-framework samples and documentation.
-<<<<<<< HEAD

    :keyword workflow: The workflow builder or factory function to adapt.
    :type workflow: Union[WorkflowBuilder, Callable[[], Workflow]]
    :keyword credentials: Optional asynchronous token credential for authentication.
    :type credentials: Optional[AsyncTokenCredential]
-=======
-    :param workflow: The workflow builder or factory function to adapt.
-    :type workflow: Union[WorkflowBuilder, Callable[[], Workflow]]
-    :param credentials: Optional asynchronous token credential for authentication.
- :type credentials: Optional[AsyncTokenCredential] - :param kwargs: Additional keyword arguments to pass to the adapter. - :type kwargs: Any ->>>>>>> lusu/agentserver-1110 :return: An instance of AgentFrameworkWorkflowAdapter. :rtype: AgentFrameworkWorkflowAdapter """ @@ -96,7 +71,6 @@ def from_agent_framework( **kwargs: Any, ) -> "AgentFrameworkAgent": """ -<<<<<<< HEAD Create an Agent Framework Adapter from either an AgentProtocol/BaseAgent or a WorkflowAgent. One of agent or workflow must be provided. @@ -107,20 +81,6 @@ def from_agent_framework( :type workflow: Optional[Union[WorkflowBuilder, Callable[[], Workflow]]] :keyword credentials: Optional asynchronous token credential for authentication. :type credentials: Optional[AsyncTokenCredential] -======= - Create an Agent Framework Adapter from either an AgentProtocol/BaseAgent or a - WorkflowAgent. - One of agent or workflow must be provided. - - :param agent: The agent to adapt. - :type agent: Optional[Union[BaseAgent, AgentProtocol]] - :param workflow: The workflow builder or factory function to adapt. - :type workflow: Optional[Union[WorkflowBuilder, Callable[[], Workflow]]] - :param credentials: Optional asynchronous token credential for authentication. - :type credentials: Optional[AsyncTokenCredential] - :param kwargs: Additional keyword arguments to pass to the adapter. - :type kwargs: Any ->>>>>>> lusu/agentserver-1110 :return: An instance of AgentFrameworkAgent. :rtype: AgentFrameworkAgent :raises TypeError: If neither or both of agent and workflow are provided, or if @@ -141,11 +101,7 @@ def workflow_factory() -> Workflow: return AgentFrameworkWorkflowAdapter(workflow_factory=workflow, credentials=credentials, **kwargs) raise TypeError("workflow must be a WorkflowBuilder or callable returning a Workflow") -<<<<<<< HEAD if isinstance(agent, (AgentProtocol, BaseAgent)): -======= - if isinstance(agent, AgentProtocol) or isinstance(agent, BaseAgent): ->>>>>>> lusu/agentserver-1110 return AgentFrameworkAIAgentAdapter(agent, credentials=credentials, **kwargs) raise TypeError("agent must be an instance of AgentProtocol or BaseAgent") diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py b/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py index 3d4187a188f2..f6d2c08bb0b9 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py @@ -97,7 +97,7 @@ async def agent_run(context: AgentRunContext): return response -my_agent = FoundryCBAgent() +my_agent = FoundryCBAgent(project_endpoint="mock-endpoint") my_agent.agent_run = agent_run if __name__ == "__main__": From a58b6a5478e0d456b3debc068d8077034ec20dcf Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Wed, 21 Jan 2026 21:24:17 -0800 Subject: [PATCH 16/29] fix pylint --- .../azure/ai/agentserver/agentframework/__init__.py | 11 ++++++----- .../azure/ai/agentserver/langgraph/__init__.py | 1 + .../ai/agentserver/langgraph/tools/_chat_model.py | 7 ++++--- .../ai/agentserver/langgraph/tools/_tool_node.py | 6 +++--- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py index 04a8459c53d9..32cf57200a49 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py @@ -7,13 +7,14 @@ from agent_framework import AgentProtocol, BaseAgent, Workflow, WorkflowBuilder -from azure.ai.agentserver.agentframework._version import VERSION -from azure.ai.agentserver.agentframework._agent_framework import AgentFrameworkAgent -from azure.ai.agentserver.agentframework._ai_agent_adapter import AgentFrameworkAIAgentAdapter -from azure.ai.agentserver.agentframework._workflow_agent_adapter import AgentFrameworkWorkflowAdapter -from azure.ai.agentserver.agentframework._foundry_tools import FoundryToolsChatMiddleware from azure.ai.agentserver.core.application import PackageMetadata, set_current_app # pylint: disable=import-error,no-name-in-module +from ._version import VERSION +from ._agent_framework import AgentFrameworkAgent +from ._ai_agent_adapter import AgentFrameworkAIAgentAdapter +from ._workflow_agent_adapter import AgentFrameworkWorkflowAdapter +from ._foundry_tools import FoundryToolsChatMiddleware + if TYPE_CHECKING: # pragma: no cover from azure.core.credentials_async import AsyncTokenCredential diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py index 7fe934ae81c6..7fefa1b486d5 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py @@ -6,6 +6,7 @@ from typing import Optional, TYPE_CHECKING from azure.ai.agentserver.core.application import PackageMetadata, set_current_app + from ._context import LanggraphRunContext from ._version import VERSION from .langgraph import LangGraphAdapter diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py index e511f5bbf915..c221910218f4 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py @@ -52,12 +52,13 @@ def tool_node_wrapper(self) -> FoundryToolNodeWrappers: """Get the Foundry tool call wrappers for this chat model. Example:: - >>> from langgraph.prebuilt import ToolNode - >>> foundry_tool_bound_chat_model = FoundryToolLateBindingChatModel(...) - >>> ToolNode([...], **foundry_tool_bound_chat_model.as_wrappers()) + >>> from langgraph.prebuilt import ToolNode + >>> foundry_tool_bound_chat_model = FoundryToolLateBindingChatModel(...) + >>> ToolNode([...], **foundry_tool_bound_chat_model.as_wrappers()) :return: The Foundry tool call wrappers. 
:rtype: FoundryToolNodeWrappers + """ return FoundryToolCallWrapper(self._foundry_tools_to_bind).as_wrappers() diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py index e66e1c554ba1..5f3c6326836b 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py @@ -18,9 +18,9 @@ class FoundryToolNodeWrappers(TypedDict): """A TypedDict for Foundry tool node wrappers. Example:: - >>> from langgraph.prebuilt import ToolNode - >>> call_wrapper = FoundryToolCallWrapper(...) - >>> ToolNode([...], **call_wrapper.as_wrappers()) + >>> from langgraph.prebuilt import ToolNode + >>> call_wrapper = FoundryToolCallWrapper(...) + >>> ToolNode([...], **call_wrapper.as_wrappers()) :param wrap_tool_call: The synchronous tool call wrapper. :type wrap_tool_call: ToolCallWrapper From 2e40950baff2a8d822d4980dfd358e3c55018a03 Mon Sep 17 00:00:00 2001 From: junanchen Date: Wed, 21 Jan 2026 22:07:03 -0800 Subject: [PATCH 17/29] tools ut in core package --- .../tests/__init__.py | 1 - .../unit_tests/agent_framework/__init__.py | 5 + .../agent_framework}/conftest.py | 0 .../test_agent_framework_input_converter.py | 0 .../tests/unit_tests/__init__.py | 4 + .../tests/unit_tests/core/__init__.py | 5 + .../tests/unit_tests/core/tools/__init__.py | 4 + .../tests/unit_tests/core/tools/conftest.py | 128 +++++ .../core/tools/operations/__init__.py | 4 + .../test_foundry_connected_tools.py | 479 +++++++++++++++++ .../test_foundry_hosted_mcp_tools.py | 309 +++++++++++ .../unit_tests/core/tools/test_client.py | 485 ++++++++++++++++++ .../core/tools/test_configuration.py | 25 + .../tests/__init__.py | 1 - .../tests/unit_tests/langgraph/__init__.py | 5 + .../{ => unit_tests/langgraph}/conftest.py | 0 .../test_langgraph_request_converter.py | 0 17 files changed, 1453 insertions(+), 2 deletions(-) delete mode 100644 sdk/agentserver/azure-ai-agentserver-agentframework/tests/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/__init__.py rename sdk/agentserver/azure-ai-agentserver-agentframework/tests/{ => unit_tests/agent_framework}/conftest.py (100%) rename sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/{ => agent_framework}/test_agent_framework_input_converter.py (100%) create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_connected_tools.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_hosted_mcp_tools.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_client.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_configuration.py delete mode 
100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/__init__.py rename sdk/agentserver/azure-ai-agentserver-langgraph/tests/{ => unit_tests/langgraph}/conftest.py (100%) rename sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/{ => langgraph}/test_langgraph_request_converter.py (100%) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/__init__.py deleted file mode 100644 index 4a5d26360bce..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Unit tests package diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/conftest.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-agentframework/tests/conftest.py rename to sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/conftest.py diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/test_agent_framework_input_converter.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py rename to sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/test_agent_framework_input_converter.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py new file mode 100644 index 000000000000..d02a9af6c5f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/__init__.py new file mode 100644 index 000000000000..d02a9af6c5f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py new file mode 100644 index 000000000000..bf94b48b9699 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py @@ -0,0 +1,128 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for unit tests.""" +import json +from typing import Any, Dict, Optional +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, + UserInfo, +) + + +@pytest.fixture +def mock_credential(): + """Create a mock async token credential.""" + credential = AsyncMock() + credential.get_token = AsyncMock(return_value=MagicMock(token="test-token")) + return credential + + +@pytest.fixture +def sample_user_info(): + """Create a sample UserInfo instance.""" + return UserInfo(object_id="test-object-id", tenant_id="test-tenant-id") + + +@pytest.fixture +def sample_hosted_mcp_tool(): + """Create a sample FoundryHostedMcpTool.""" + return FoundryHostedMcpTool( + name="test_mcp_tool", + configuration={"model_deployment_name": "gpt-4"} + ) + + +@pytest.fixture +def sample_connected_tool(): + """Create a sample FoundryConnectedTool.""" + return FoundryConnectedTool( + protocol="mcp", + project_connection_id="test-connection-id" + ) + + +@pytest.fixture +def sample_schema_definition(): + """Create a sample SchemaDefinition.""" + return SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input parameter") + }, + required={"input"} + ) + + +@pytest.fixture +def sample_tool_details(sample_schema_definition): + """Create a sample FoundryToolDetails.""" + return FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=sample_schema_definition + ) + + +@pytest.fixture +def sample_resolved_mcp_tool(sample_hosted_mcp_tool, sample_tool_details): + """Create a sample ResolvedFoundryTool for MCP.""" + return ResolvedFoundryTool( + definition=sample_hosted_mcp_tool, + details=sample_tool_details + ) + + +@pytest.fixture +def sample_resolved_connected_tool(sample_connected_tool, sample_tool_details): + """Create a sample ResolvedFoundryTool for connected tools.""" + return ResolvedFoundryTool( + definition=sample_connected_tool, + details=sample_tool_details + ) + + +def create_mock_http_response( + status_code: int = 200, + json_data: Optional[Dict[str, Any]] = None +) -> AsyncMock: + """Create a mock HTTP response 
that simulates real Azure SDK response behavior. + + This mock matches the behavior expected by BaseOperations._extract_response_json, + where response.text() and response.body() are synchronous methods that return + the actual string/bytes values directly. + + :param status_code: HTTP status code. + :param json_data: JSON data to return. + :return: Mock response object. + """ + response = AsyncMock() + response.status_code = status_code + + if json_data is not None: + json_str = json.dumps(json_data) + json_bytes = json_str.encode("utf-8") + # text() and body() are synchronous methods in AsyncHttpResponse + # They must be MagicMock (not AsyncMock) to return values directly when called + response.text = MagicMock(return_value=json_str) + response.body = MagicMock(return_value=json_bytes) + else: + response.text = MagicMock(return_value="") + response.body = MagicMock(return_value=b"") + + # Support async context manager + response.__aenter__ = AsyncMock(return_value=response) + response.__aexit__ = AsyncMock(return_value=None) + + return response + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/__init__.py new file mode 100644 index 000000000000..d02a9af6c5f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_connected_tools.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_connected_tools.py new file mode 100644 index 000000000000..7c453ba2fa2c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_connected_tools.py @@ -0,0 +1,479 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryConnectedToolsOperations - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryToolDetails, +) +from azure.ai.agentserver.core.tools.client.operations._foundry_connected_tools import ( + FoundryConnectedToolsOperations, +) +from azure.ai.agentserver.core.tools._exceptions import OAuthConsentRequiredError, ToolInvocationError + +from ..conftest import create_mock_http_response + + +class TestFoundryConnectedToolsOperationsListTools: + """Tests for FoundryConnectedToolsOperations.list_tools public method.""" + + @pytest.mark.asyncio + async def test_list_tools_with_empty_list_returns_empty(self): + """Test list_tools returns empty when tools list is empty.""" + mock_client = AsyncMock() + ops = FoundryConnectedToolsOperations(mock_client) + + result = await ops.list_tools([], None, "test-agent") + + assert result == [] + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_list_tools_returns_tools_from_server( + self, + sample_connected_tool, + sample_user_info + ): + """Test list_tools returns tools from server response.""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "remote_tool", + "description": "A remote connected tool", + "parameters": { + "type": "object", + "properties": { + "input": {"type": "string"} + } + } + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([sample_connected_tool], sample_user_info, "test-agent")) + + assert len(result) == 1 + definition, details = result[0] + assert definition == sample_connected_tool + assert isinstance(details, FoundryToolDetails) + assert details.name == "remote_tool" + assert details.description == "A remote connected tool" + + @pytest.mark.asyncio + async def test_list_tools_without_user_info(self, sample_connected_tool): + """Test list_tools works without user info (local execution).""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "tool_no_user", + "description": "Tool without user", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([sample_connected_tool], None, "test-agent")) + + assert len(result) == 1 + assert result[0][1].name == "tool_no_user" + + @pytest.mark.asyncio + async def test_list_tools_with_multiple_connections(self, sample_user_info): + """Test list_tools with multiple connected tool definitions.""" + mock_client = AsyncMock() + + tool1 = FoundryConnectedTool(protocol="mcp", project_connection_id="conn-1") + tool2 = FoundryConnectedTool(protocol="a2a", 
project_connection_id="conn-2") + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": "mcp", + "projectConnectionId": "conn-1" + }, + "manifest": [ + { + "name": "tool_from_conn1", + "description": "From connection 1", + "parameters": {"type": "object", "properties": {}} + } + ] + }, + { + "remoteServer": { + "protocol": "a2a", + "projectConnectionId": "conn-2" + }, + "manifest": [ + { + "name": "tool_from_conn2", + "description": "From connection 2", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([tool1, tool2], sample_user_info, "test-agent")) + + assert len(result) == 2 + names = {r[1].name for r in result} + assert names == {"tool_from_conn1", "tool_from_conn2"} + + @pytest.mark.asyncio + async def test_list_tools_filters_by_connection_id(self, sample_user_info): + """Test list_tools only returns tools from requested connections.""" + mock_client = AsyncMock() + + requested_tool = FoundryConnectedTool(protocol="mcp", project_connection_id="requested-conn") + + # Server returns tools from multiple connections, but we only requested one + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": "mcp", + "projectConnectionId": "requested-conn" + }, + "manifest": [ + { + "name": "requested_tool", + "description": "Requested", + "parameters": {"type": "object", "properties": {}} + } + ] + }, + { + "remoteServer": { + "protocol": "mcp", + "projectConnectionId": "unrequested-conn" + }, + "manifest": [ + { + "name": "unrequested_tool", + "description": "Not requested", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([requested_tool], sample_user_info, "test-agent")) + + # Should only return tools from requested connection + assert len(result) == 1 + assert result[0][1].name == "requested_tool" + + @pytest.mark.asyncio + async def test_list_tools_multiple_tools_per_connection( + self, + sample_connected_tool, + sample_user_info + ): + """Test list_tools returns multiple tools from same connection.""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "tool_one", + "description": "First tool", + "parameters": {"type": "object", "properties": {}} + }, + { + "name": "tool_two", + "description": "Second tool", + "parameters": {"type": "object", "properties": {}} + }, + { + "name": "tool_three", + "description": "Third tool", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([sample_connected_tool], sample_user_info, "test-agent")) + + assert len(result) == 3 + names = {r[1].name for r in result} + assert names == {"tool_one", "tool_two", "tool_three"} + 
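+    # The consent tests below assume the service reports missing user consent in
+    # a 200 response whose body has type "OAuthConsentRequired" and a toolResult
+    # carrying consentUrl / message / projectConnectionId; the client is expected
+    # to raise OAuthConsentRequiredError instead of returning a tool listing.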
+ @pytest.mark.asyncio + async def test_list_tools_raises_oauth_consent_error( + self, + sample_connected_tool, + sample_user_info + ): + """Test list_tools raises OAuthConsentRequiredError when consent needed.""" + mock_client = AsyncMock() + + response_data = { + "type": "OAuthConsentRequired", + "toolResult": { + "consentUrl": "https://login.microsoftonline.com/consent", + "message": "User consent is required to access this resource", + "projectConnectionId": sample_connected_tool.project_connection_id + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(OAuthConsentRequiredError) as exc_info: + list(await ops.list_tools([sample_connected_tool], sample_user_info, "test-agent")) + + assert exc_info.value.consent_url == "https://login.microsoftonline.com/consent" + assert "consent" in exc_info.value.message.lower() + + +class TestFoundryConnectedToolsOperationsInvokeTool: + """Tests for FoundryConnectedToolsOperations.invoke_tool public method.""" + + @pytest.mark.asyncio + async def test_invoke_tool_returns_result_value( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool returns the result value from server.""" + mock_client = AsyncMock() + + expected_result = {"data": "some output", "status": "success"} + response_data = {"toolResult": expected_result} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {"input": "test"}, + sample_user_info, + "test-agent" + ) + + assert result == expected_result + + @pytest.mark.asyncio + async def test_invoke_tool_without_user_info(self, sample_resolved_connected_tool): + """Test invoke_tool works without user info (local execution).""" + mock_client = AsyncMock() + + response_data = {"toolResult": "local result"} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + None, # No user info + "test-agent" + ) + + assert result == "local result" + + @pytest.mark.asyncio + async def test_invoke_tool_with_complex_arguments( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool handles complex nested arguments.""" + mock_client = AsyncMock() + + response_data = {"toolResult": "processed"} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + complex_args = { + "query": "search term", + "filters": { + "date_range": {"start": "2025-01-01", "end": "2025-12-31"}, + "categories": ["A", "B", "C"] + }, + "limit": 50 + } + + result = await ops.invoke_tool( + sample_resolved_connected_tool, + complex_args, + sample_user_info, + "test-agent" + ) + + assert result == "processed" + mock_client.send_request.assert_called_once() + + @pytest.mark.asyncio + async def test_invoke_tool_returns_none_for_empty_result( + self, + sample_resolved_connected_tool, + 
sample_user_info + ): + """Test invoke_tool returns None when server returns no result.""" + mock_client = AsyncMock() + + # Server returns empty response (no toolResult) + response_data = { + "toolResult": None + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + sample_user_info, + "test-agent" + ) + + assert result is None + + @pytest.mark.asyncio + async def test_invoke_tool_with_mcp_tool_raises_error( + self, + sample_resolved_mcp_tool, + sample_user_info + ): + """Test invoke_tool raises ToolInvocationError for non-connected tool.""" + mock_client = AsyncMock() + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(ToolInvocationError) as exc_info: + await ops.invoke_tool( + sample_resolved_mcp_tool, + {}, + sample_user_info, + "test-agent" + ) + + assert "not a Foundry connected tool" in str(exc_info.value) + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_invoke_tool_raises_oauth_consent_error( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool raises OAuthConsentRequiredError when consent needed.""" + mock_client = AsyncMock() + + response_data = { + "type": "OAuthConsentRequired", + "toolResult": { + "consentUrl": "https://login.microsoftonline.com/oauth/consent", + "message": "Please provide consent to continue", + "projectConnectionId": "test-connection-id" + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(OAuthConsentRequiredError) as exc_info: + await ops.invoke_tool( + sample_resolved_connected_tool, + {"input": "test"}, + sample_user_info, + "test-agent" + ) + + assert "https://login.microsoftonline.com/oauth/consent" in exc_info.value.consent_url + + @pytest.mark.asyncio + async def test_invoke_tool_with_different_agent_names( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool uses correct agent name in request.""" + mock_client = AsyncMock() + + response_data = {"toolResult": "result"} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + # Invoke with different agent names + for agent_name in ["agent-1", "my-custom-agent", "production-agent"]: + await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + sample_user_info, + agent_name + ) + + # Verify the correct path was used + call_args = mock_client.post.call_args + assert agent_name in call_args[0][0] + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_hosted_mcp_tools.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_hosted_mcp_tools.py new file mode 100644 index 000000000000..9897cbf168ed --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_hosted_mcp_tools.py @@ -0,0 +1,309 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft 
Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryMcpToolsOperations - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.core.tools.client.operations._foundry_hosted_mcp_tools import ( + FoundryMcpToolsOperations, +) +from azure.ai.agentserver.core.tools._exceptions import ToolInvocationError + +from ..conftest import create_mock_http_response + + +class TestFoundryMcpToolsOperationsListTools: + """Tests for FoundryMcpToolsOperations.list_tools public method.""" + + @pytest.mark.asyncio + async def test_list_tools_with_empty_list_returns_empty(self): + """Test list_tools returns empty when allowed_tools is empty.""" + mock_client = AsyncMock() + ops = FoundryMcpToolsOperations(mock_client) + + result = await ops.list_tools([]) + + assert result == [] + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_list_tools_returns_matching_tools(self, sample_hosted_mcp_tool): + """Test list_tools returns tools that match the allowed list.""" + mock_client = AsyncMock() + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Test MCP tool", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + } + } + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([sample_hosted_mcp_tool])) + + assert len(result) == 1 + definition, details = result[0] + assert definition == sample_hosted_mcp_tool + assert isinstance(details, FoundryToolDetails) + assert details.name == sample_hosted_mcp_tool.name + assert details.description == "Test MCP tool" + + @pytest.mark.asyncio + async def test_list_tools_filters_out_non_allowed_tools(self, sample_hosted_mcp_tool): + """Test list_tools only returns tools in the allowed list.""" + mock_client = AsyncMock() + + # Server returns multiple tools but only one is allowed + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Allowed tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "other_tool_not_in_list", + "description": "Not allowed tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "another_unlisted_tool", + "description": "Also not allowed", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([sample_hosted_mcp_tool])) + + assert len(result) == 1 + assert result[0][1].name == sample_hosted_mcp_tool.name + + @pytest.mark.asyncio + async def test_list_tools_with_multiple_allowed_tools(self): + """Test list_tools with multiple tools in allowed list.""" + mock_client = AsyncMock() + + tool1 = 
FoundryHostedMcpTool(name="tool_one") + tool2 = FoundryHostedMcpTool(name="tool_two") + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "tool_one", + "description": "First tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "tool_two", + "description": "Second tool", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([tool1, tool2])) + + assert len(result) == 2 + names = {r[1].name for r in result} + assert names == {"tool_one", "tool_two"} + + @pytest.mark.asyncio + async def test_list_tools_preserves_tool_metadata(self): + """Test list_tools preserves metadata from server response.""" + mock_client = AsyncMock() + + tool = FoundryHostedMcpTool(name="tool_with_meta") + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "tool_with_meta", + "description": "Tool with metadata", + "inputSchema": { + "type": "object", + "properties": { + "param1": {"type": "string"} + }, + "required": ["param1"] + }, + "_meta": { + "type": "object", + "properties": { + "model": {"type": "string"} + } + } + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([tool])) + + assert len(result) == 1 + details = result[0][1] + assert details.metadata is not None + + +class TestFoundryMcpToolsOperationsInvokeTool: + """Tests for FoundryMcpToolsOperations.invoke_tool public method.""" + + @pytest.mark.asyncio + async def test_invoke_tool_returns_server_response(self, sample_resolved_mcp_tool): + """Test invoke_tool returns the response from server.""" + mock_client = AsyncMock() + + expected_response = { + "jsonrpc": "2.0", + "id": 2, + "result": { + "content": [{"type": "text", "text": "Hello World"}] + } + } + mock_response = create_mock_http_response(200, expected_response) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(sample_resolved_mcp_tool, {"query": "test"}) + + assert result == expected_response + + @pytest.mark.asyncio + async def test_invoke_tool_with_empty_arguments(self, sample_resolved_mcp_tool): + """Test invoke_tool works with empty arguments.""" + mock_client = AsyncMock() + + expected_response = {"result": "success"} + mock_response = create_mock_http_response(200, expected_response) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(sample_resolved_mcp_tool, {}) + + assert result == expected_response + + @pytest.mark.asyncio + async def test_invoke_tool_with_complex_arguments(self, sample_resolved_mcp_tool): + """Test invoke_tool handles complex nested arguments.""" + mock_client = AsyncMock() + + mock_response = create_mock_http_response(200, {"result": "ok"}) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + complex_args = { + "text": "sample 
text", + "options": { + "temperature": 0.7, + "max_tokens": 100 + }, + "tags": ["tag1", "tag2"] + } + + result = await ops.invoke_tool(sample_resolved_mcp_tool, complex_args) + + assert result == {"result": "ok"} + mock_client.send_request.assert_called_once() + + @pytest.mark.asyncio + async def test_invoke_tool_with_connected_tool_raises_error( + self, + sample_resolved_connected_tool + ): + """Test invoke_tool raises ToolInvocationError for non-MCP tool.""" + mock_client = AsyncMock() + ops = FoundryMcpToolsOperations(mock_client) + + with pytest.raises(ToolInvocationError) as exc_info: + await ops.invoke_tool(sample_resolved_connected_tool, {}) + + assert "not a Foundry-hosted MCP tool" in str(exc_info.value) + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_invoke_tool_with_configuration_and_metadata(self): + """Test invoke_tool handles tool with configuration and metadata.""" + mock_client = AsyncMock() + + # Create tool with configuration + tool_def = FoundryHostedMcpTool( + name="image_generation", + configuration={"model_deployment_name": "dall-e-3"} + ) + + # Create tool details with metadata schema + meta_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "model": SchemaProperty(type=SchemaType.STRING) + } + ) + details = FoundryToolDetails( + name="image_generation", + description="Generate images", + input_schema=SchemaDefinition(type=SchemaType.OBJECT, properties={}), + metadata=meta_schema + ) + resolved_tool = ResolvedFoundryTool(definition=tool_def, details=details) + + mock_response = create_mock_http_response(200, {"result": "image_url"}) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(resolved_tool, {"prompt": "a cat"}) + + assert result == {"result": "image_url"} + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_client.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_client.py new file mode 100644 index 000000000000..c99de80a87f9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_client.py @@ -0,0 +1,485 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryToolClient - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from azure.ai.agentserver.core.tools.client._client import FoundryToolClient +from azure.ai.agentserver.core.tools.client._models import ( + FoundryToolDetails, + FoundryToolSource, + ResolvedFoundryTool, +) +from azure.ai.agentserver.core.tools._exceptions import ToolInvocationError + +from .conftest import create_mock_http_response + + +class TestFoundryToolClientInit: + """Tests for FoundryToolClient.__init__ public method.""" + + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + def test_init_with_valid_endpoint_and_credential(self, mock_pipeline_client_class, mock_credential): + """Test client can be initialized with valid endpoint and credential.""" + endpoint = "https://test.api.azureml.ms" + + client = FoundryToolClient(endpoint, mock_credential) + + # Verify client was created with correct base_url + call_kwargs = mock_pipeline_client_class.call_args + assert call_kwargs[1]["base_url"] == endpoint + assert client is not None + + +class TestFoundryToolClientListTools: + """Tests for FoundryToolClient.list_tools public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_empty_collection_returns_empty_list( + self, + mock_pipeline_client_class, + mock_credential + ): + """Test list_tools returns empty list when given empty collection.""" + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + + result = await client.list_tools([], agent_name="test-agent") + + assert result == [] + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_with_single_mcp_tool_returns_resolved_tools( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools with a single MCP tool returns resolved tools.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Mock HTTP response for MCP tools listing + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Test MCP tool description", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + result = await client.list_tools([sample_hosted_mcp_tool], agent_name="test-agent") + + assert len(result) == 1 + assert isinstance(result[0], ResolvedFoundryTool) + assert result[0].name == sample_hosted_mcp_tool.name + assert result[0].source == FoundryToolSource.HOSTED_MCP + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_with_single_connected_tool_returns_resolved_tools( + self, + mock_pipeline_client_class, + mock_credential, + sample_connected_tool, + sample_user_info + ): + """Test list_tools with a single connected tool returns resolved tools.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Mock HTTP response for connected tools listing + 
response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "connected_test_tool", + "description": "Test connected tool", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + result = await client.list_tools( + [sample_connected_tool], + agent_name="test-agent", + user=sample_user_info + ) + + assert len(result) == 1 + assert isinstance(result[0], ResolvedFoundryTool) + assert result[0].name == "connected_test_tool" + assert result[0].source == FoundryToolSource.CONNECTED + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_with_mixed_tool_types_returns_all_resolved( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool, + sample_connected_tool, + sample_user_info + ): + """Test list_tools with both MCP and connected tools returns all resolved tools.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # We need to return different responses based on the request + mcp_response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "MCP tool", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + connected_response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "connected_tool", + "description": "Connected tool", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + + # Mock to return different responses for different requests + mock_client_instance.send_request.side_effect = [ + create_mock_http_response(200, mcp_response_data), + create_mock_http_response(200, connected_response_data) + ] + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + result = await client.list_tools( + [sample_hosted_mcp_tool, sample_connected_tool], + agent_name="test-agent", + user=sample_user_info + ) + + assert len(result) == 2 + sources = {tool.source for tool in result} + assert FoundryToolSource.HOSTED_MCP in sources + assert FoundryToolSource.CONNECTED in sources + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_filters_unlisted_mcp_tools( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools only returns tools that are in the allowed list.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Server returns more tools than requested + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Requested tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "unrequested_tool", + "description": "This tool was not requested", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + 
mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + result = await client.list_tools([sample_hosted_mcp_tool], agent_name="test-agent") + + # Should only return the requested tool + assert len(result) == 1 + assert result[0].name == sample_hosted_mcp_tool.name + + +class TestFoundryToolClientListToolsDetails: + """Tests for FoundryToolClient.list_tools_details public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_details_returns_mapping_structure( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools_details returns correct mapping structure.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Test tool", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + result = await client.list_tools_details([sample_hosted_mcp_tool], agent_name="test-agent") + + assert isinstance(result, dict) + assert sample_hosted_mcp_tool.id in result + assert len(result[sample_hosted_mcp_tool.id]) == 1 + assert isinstance(result[sample_hosted_mcp_tool.id][0], FoundryToolDetails) + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_details_groups_multiple_tools_by_definition( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools_details groups multiple tools from same source by definition ID.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Server returns multiple tools for the same MCP source + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Tool variant 1", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + result = await client.list_tools_details([sample_hosted_mcp_tool], agent_name="test-agent") + + # All tools should be grouped under the same definition ID + assert sample_hosted_mcp_tool.id in result + + +class TestFoundryToolClientInvokeTool: + """Tests for FoundryToolClient.invoke_tool public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_mcp_tool_returns_result( + self, + mock_pipeline_client_class, + mock_credential, + sample_resolved_mcp_tool + ): + """Test invoke_tool with MCP tool returns the invocation result.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + expected_result 
= {"result": {"content": [{"text": "Hello World"}]}} + mock_response = create_mock_http_response(200, expected_result) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + result = await client.invoke_tool( + sample_resolved_mcp_tool, + arguments={"input": "test"}, + agent_name="test-agent" + ) + + assert result == expected_result + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_connected_tool_returns_result( + self, + mock_pipeline_client_class, + mock_credential, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool with connected tool returns the invocation result.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + expected_value = {"output": "Connected tool result"} + response_data = {"toolResult": expected_value} + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + result = await client.invoke_tool( + sample_resolved_connected_tool, + arguments={"input": "test"}, + agent_name="test-agent", + user=sample_user_info + ) + + assert result == expected_value + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_tool_with_complex_arguments( + self, + mock_pipeline_client_class, + mock_credential, + sample_resolved_mcp_tool + ): + """Test invoke_tool correctly passes complex arguments.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + mock_response = create_mock_http_response(200, {"result": "success"}) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + complex_args = { + "string_param": "value", + "number_param": 42, + "bool_param": True, + "list_param": [1, 2, 3], + "nested_param": {"key": "value"} + } + + result = await client.invoke_tool( + sample_resolved_mcp_tool, + arguments=complex_args, + agent_name="test-agent" + ) + + # Verify request was made + mock_client_instance.send_request.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_tool_with_unsupported_source_raises_error( + self, + mock_pipeline_client_class, + mock_credential, + sample_tool_details + ): + """Test invoke_tool raises ToolInvocationError for unsupported tool source.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Create a mock tool with unsupported source + mock_definition = MagicMock() + mock_definition.source = "unsupported_source" + mock_tool = MagicMock(spec=ResolvedFoundryTool) + mock_tool.definition = mock_definition + mock_tool.source = "unsupported_source" + mock_tool.details = sample_tool_details + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + + with pytest.raises(ToolInvocationError) as exc_info: + await client.invoke_tool( + mock_tool, + arguments={"input": "test"}, + agent_name="test-agent" + ) + + assert "Unsupported tool 
source" in str(exc_info.value) + + +class TestFoundryToolClientClose: + """Tests for FoundryToolClient.close public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_close_closes_underlying_client( + self, + mock_pipeline_client_class, + mock_credential + ): + """Test close() properly closes the underlying HTTP client.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + await client.close() + + mock_client_instance.close.assert_called_once() + + +class TestFoundryToolClientContextManager: + """Tests for FoundryToolClient async context manager protocol.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_async_context_manager_enters_and_exits( + self, + mock_pipeline_client_class, + mock_credential + ): + """Test client can be used as async context manager.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + async with FoundryToolClient("https://test.api.azureml.ms", mock_credential) as client: + assert client is not None + mock_client_instance.__aenter__.assert_called_once() + + mock_client_instance.__aexit__.assert_called_once() + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_configuration.py new file mode 100644 index 000000000000..2f3c2710a3fc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_configuration.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryToolClientConfiguration.""" + +from azure.core.pipeline import policies + +from azure.ai.agentserver.core.tools.client._configuration import FoundryToolClientConfiguration + + +class TestFoundryToolClientConfiguration: + """Tests for FoundryToolClientConfiguration class.""" + + def test_init_creates_all_required_policies(self, mock_credential): + """Test that initialization creates all required pipeline policies.""" + config = FoundryToolClientConfiguration(mock_credential) + + assert isinstance(config.retry_policy, policies.AsyncRetryPolicy) + assert isinstance(config.logging_policy, policies.NetworkTraceLoggingPolicy) + assert isinstance(config.request_id_policy, policies.RequestIdPolicy) + assert isinstance(config.http_logging_policy, policies.HttpLoggingPolicy) + assert isinstance(config.user_agent_policy, policies.UserAgentPolicy) + assert isinstance(config.authentication_policy, policies.AsyncBearerTokenCredentialPolicy) + assert isinstance(config.redirect_policy, policies.AsyncRedirectPolicy) + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/__init__.py deleted file mode 100644 index 4a5d26360bce..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Unit tests package diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/conftest.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-langgraph/tests/conftest.py rename to sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/conftest.py diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/test_langgraph_request_converter.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py rename to sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/test_langgraph_request_converter.py From ed539e5d8dc563234c0c00d9bbc8b90f4b7957ac Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Wed, 21 Jan 2026 23:47:38 -0800 Subject: [PATCH 18/29] add cachetools min version --- sdk/agentserver/azure-ai-agentserver-core/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml index e53b8f5474b7..afb5e6797396 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "starlette>=0.45.0", "uvicorn>=0.31.0", "aiohttp>=3.13.0", # used by azure-identity aio - "cachetools" + "cachetools>=6.0.0" ] [build-system] From 1a031f30d060f6f7bbf41d92308b7c9ac6dc119d Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Thu, 22 Jan 2026 14:05:01 -0800 Subject: [PATCH 19/29] fix bugs for non stream converter --- .../ai/agentserver/langgraph/langgraph.py | 12 +++--- .../models/response_api_converter.py | 4 +- .../models/response_api_default_converter.py | 2 +- ...ponse_api_non_stream_response_converter.py | 39 ++++++++++--------- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py index 6827cced1902..e8e524764db2 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py @@ -12,8 +12,8 @@ from azure.ai.agentserver.core.constants import Constants from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.server.base import FoundryCBAgent -from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext -from azure.ai.agentserver.core.tools import OAuthConsentRequiredError +from azure.ai.agentserver.core import AgentRunContext +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError # pylint:disable=import-error,no-name-in-module from ._context import LanggraphRunContext from .models.response_api_converter import GraphInputArguments, ResponseAPIConverter from .models.response_api_default_converter import ResponseAPIDefaultConverter @@ -68,7 +68,7 @@ async def agent_run(self, context: AgentRunContext): try: lg_run_context = await self.setup_lg_run_context(context) input_arguments = await self.converter.convert_request(lg_run_context) - 
self.ensure_runnable_config(input_arguments, lg_run_context)
+            self.ensure_runnable_config(input_arguments)
 
             if not context.stream:
                 response = await self.agent_run_non_stream(input_arguments)
@@ -156,19 +156,17 @@ async def agent_run_astream(self,
             logger.error(f"Error during streaming agent run: {e}", exc_info=True)
             raise e
 
-    def ensure_runnable_config(self, input_arguments: GraphInputArguments, context: LanggraphRunContext):
+    def ensure_runnable_config(self, input_arguments: GraphInputArguments):
         """
         Ensure the RunnableConfig is set in the input arguments.
 
         :param input_arguments: The input arguments for the agent run.
         :type input_arguments: GraphInputArguments
-        :param context: The Langgraph run context.
-        :type context: LanggraphRunContext
         """
         config = input_arguments.get("config", {})
         configurable = config.get("configurable", {})
-        config["configurable"] = configurable
         configurable["thread_id"] = input_arguments["context"].agent_run.conversation_id
+        config["configurable"] = configurable
 
         callbacks = config.get("callbacks", [])
         if self.azure_ai_tracer and self.azure_ai_tracer not in callbacks:
diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py
index d1c5531993a1..3cb6314ecc12 100644
--- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py
+++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py
@@ -61,14 +61,14 @@ async def convert_request(self, context: LanggraphRunContext) -> GraphInputArgum
         """
 
     @abstractmethod
-    async def convert_response_non_stream(self, output: Any, context: LanggraphRunContext) -> Response:
+    async def convert_response_non_stream(self, output: Union[dict[str, Any], Any], context: LanggraphRunContext) -> Response:
         """Convert the completed LangGraph state into a final non-streaming Response object.
 
         This is a convenience wrapper around state_to_response that retrieves
         the current state snapshot asynchronously.
 
         :param output: The LangGraph output to convert.
-        :type output: Any
+        :type output: Union[dict[str, Any], Any]
        :param context: The context for the agent run. 
:type context: LanggraphRunContext diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py index 9bc237c87cf1..5a22134372c7 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py @@ -56,7 +56,7 @@ async def convert_request(self, context: LanggraphRunContext) -> GraphInputArgum context=context, ) - async def convert_response_non_stream(self, output: Any, context: LanggraphRunContext) -> Response: + async def convert_response_non_stream(self, output: Union[dict[str, Any], Any], context: LanggraphRunContext) -> Response: agent_run_context = context.agent_run converter = self._create_non_stream_response_converter(context) converted_output = converter.convert(output) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py index c776fad3dcad..7fa3c51e3223 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py @@ -5,7 +5,7 @@ # mypy: disable-error-code="valid-type,call-overload,attr-defined" import copy from abc import ABC, abstractmethod -from typing import Any, Collection, Iterable, List +from typing import Any, Collection, Iterable, List, Union from langchain_core import messages from langchain_core.messages import AnyMessage @@ -51,31 +51,34 @@ def __init__(self, self.context = context self.hitl_helper = hitl_helper - def convert(self, output: dict[str, Any]) -> list[project_models.ItemResource]: + def convert(self, output: Union[dict[str, Any], Any]) -> list[project_models.ItemResource]: res: list[project_models.ItemResource] = [] - for node_name, node_output in output.items(): - node_results = self._convert_node_output(node_name, node_output) - res.extend(node_results) + for step in output: + for node_name, node_output in step.items(): + node_results = self._convert_node_output(node_name, node_output) + res.extend(node_results) return res def _convert_node_output( self, node_name: str, node_output: Any ) -> Iterable[project_models.ItemResource]: + logger.info(f"Converting output for node: {node_name} with output: {node_output}") if node_name == INTERRUPT_NODE_NAME: yield from self.hitl_helper.convert_interrupts(node_output) - - message_arr = node_output.get("messages") - if not message_arr or not isinstance(message_arr, Collection): - logger.warning(f"No messages found in node {node_name} output: {node_output}") - return - - for message in message_arr: - try: - converted = self.convert_output_message(message) - if converted: - yield converted - except Exception as e: - logger.error(f"Error converting message {message}: {e}") + + else: + message_arr = node_output.get("messages") + if not message_arr or not isinstance(message_arr, Collection): + logger.warning(f"No messages found in node {node_name} output: {node_output}") + return + + for message in message_arr: 
+ try: + converted = self.convert_output_message(message) + if converted: + yield converted + except Exception as e: + logger.error(f"Error converting message {message}: {e}") def convert_output_message(self, output_message: AnyMessage): # pylint: disable=inconsistent-return-statements # Implement the conversion logic for inner inputs From c68ee16e57dedf9a77a89c365b3f1e7b84d3313f Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Thu, 22 Jan 2026 14:26:05 -0800 Subject: [PATCH 20/29] fix pylint --- .../ai/agentserver/langgraph/models/response_api_converter.py | 4 +++- .../langgraph/models/response_api_default_converter.py | 3 ++- .../models/response_api_non_stream_response_converter.py | 1 - 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py index 3cb6314ecc12..32cbf93a4bfb 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py @@ -61,7 +61,9 @@ async def convert_request(self, context: LanggraphRunContext) -> GraphInputArgum """ @abstractmethod - async def convert_response_non_stream(self, output: Union[dict[str, Any], Any], context: LanggraphRunContext) -> Response: + async def convert_response_non_stream( + self, output: Union[dict[str, Any], Any], context: LanggraphRunContext + ) -> Response: """Convert the completed LangGraph state into a final non-streaming Response object. This is a convenience wrapper around state_to_response that retrieves diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py index 5a22134372c7..cfe5229e3634 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py @@ -56,7 +56,8 @@ async def convert_request(self, context: LanggraphRunContext) -> GraphInputArgum context=context, ) - async def convert_response_non_stream(self, output: Union[dict[str, Any], Any], context: LanggraphRunContext) -> Response: + async def convert_response_non_stream( + self, output: Union[dict[str, Any], Any], context: LanggraphRunContext) -> Response: agent_run_context = context.agent_run converter = self._create_non_stream_response_converter(context) converted_output = converter.convert(output) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py index 7fa3c51e3223..b5f770ebff34 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py @@ -65,7 +65,6 @@ def _convert_node_output( logger.info(f"Converting output for node: {node_name} with output: 
{node_output}") if node_name == INTERRUPT_NODE_NAME: yield from self.hitl_helper.convert_interrupts(node_output) - else: message_arr = node_output.get("messages") if not message_arr or not isinstance(message_arr, Collection): From 7a1324e9136d69287c5c61a6928733f4f834c3f0 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Thu, 22 Jan 2026 15:58:04 -0800 Subject: [PATCH 21/29] add output type check --- .../models/response_api_non_stream_response_converter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py index b5f770ebff34..7ec8bdf14f1a 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py @@ -53,6 +53,9 @@ def __init__(self, def convert(self, output: Union[dict[str, Any], Any]) -> list[project_models.ItemResource]: res: list[project_models.ItemResource] = [] + if not isinstance(output, list): + logger.error(f"Expected output to be a list, got {type(output)}: {output}") + raise ValueError(f"Invalid output format. Expected a list, got {type(output)}.") for step in output: for node_name, node_output in step.items(): node_results = self._convert_node_output(node_name, node_output) @@ -62,7 +65,6 @@ def convert(self, output: Union[dict[str, Any], Any]) -> list[project_models.Ite def _convert_node_output( self, node_name: str, node_output: Any ) -> Iterable[project_models.ItemResource]: - logger.info(f"Converting output for node: {node_name} with output: {node_output}") if node_name == INTERRUPT_NODE_NAME: yield from self.hitl_helper.convert_interrupts(node_output) else: From 7adbb6639b68d08719dbe9f65ad90dbcf0e3fb80 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Thu, 22 Jan 2026 16:08:43 -0800 Subject: [PATCH 22/29] fix minors --- .../models/agent_framework_input_converters.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py index 4e3feb52a29f..9ba678e7bebf 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py @@ -50,13 +50,13 @@ async def transform_input( if agent_thread and agent_thread.message_store: thread_messages = await agent_thread.message_store.list_messages() pending_hitl_requests = self._hitl_helper.get_pending_hitl_request(thread_messages, checkpoint) - logger.info("Pending HitL requests: %s", list(pending_hitl_requests.keys())) - hitl_response = self._hitl_helper.validate_and_convert_hitl_response( - input, - pending_requests=pending_hitl_requests) - logger.info("HitL response validation result: %s", [m.to_dict() for m in hitl_response]) - if hitl_response: - return hitl_response + if pending_hitl_requests: + logger.info("Pending HitL requests: %s", 
list(pending_hitl_requests.keys()))
+                hitl_response = self._hitl_helper.validate_and_convert_hitl_response(
+                    input,
+                    pending_requests=pending_hitl_requests)
+                if hitl_response:
+                    return hitl_response
 
         return self._transform_input_internal(input)
 

From 61e8c4c99604a8edc4ab4559fd73b97903faed4c Mon Sep 17 00:00:00 2001
From: Lu Sun
Date: Thu, 22 Jan 2026 16:36:17 -0800
Subject: [PATCH 23/29] updated version and changelog

---
 .../CHANGELOG.md                              | 19 +++++++++++++++++++
 .../ai/agentserver/agentframework/_version.py |  2 +-
 .../azure-ai-agentserver-core/CHANGELOG.md    | 12 ++++++++++++
 .../azure/ai/agentserver/core/_version.py     |  2 +-
 .../CHANGELOG.md                              | 16 ++++++++++++++++
 .../ai/agentserver/langgraph/_version.py      |  2 +-
 6 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md
index 84c4a76a27e5..29bae6795995 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md
@@ -1,6 +1,25 @@
 # Release History
 
+## 1.0.0b9 (2026-01-23)
+
+- Integrated with Foundry Tools
+- Added persistence for agent threads and checkpoints
+- Fixed WorkflowAgent concurrency issue
+- Added support for Human-in-the-Loop
+
+
+## 1.0.0b8 (2026-01-21)
+
+### Features Added
+
+- Added keep-alive support for long-running streaming responses.
+
+### Bugs Fixed
+
+- Fixed AgentFramework breaking change and pinned version to >=1.0.0b251112,<=1.0.0b260107
+
+
 ## 1.0.0b7 (2025-12-05)
 
 ### Features Added
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_version.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_version.py
index 84058978c521..b1c2836b6921 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_version.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_version.py
@@ -6,4 +6,4 @@
 # Changes may cause incorrect behavior and will be lost if the code is regenerated.
 # --------------------------------------------------------------------------
 
-VERSION = "1.0.0b7"
+VERSION = "1.0.0b9"
diff --git a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md
index 84c4a76a27e5..b05d70708716 100644
--- a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md
+++ b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md
@@ -1,6 +1,18 @@
 # Release History
 
+## 1.0.0b9 (2026-01-23)
+
+- Integrated with Foundry Tools
+
+
+## 1.0.0b8 (2026-01-21)
+
+### Features Added
+
+- Added keep-alive support for long-running streaming responses.
+
+
 ## 1.0.0b7 (2025-12-05)
 
 ### Features Added
diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py
index 84058978c521..b1c2836b6921 100644
--- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py
+++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py
@@ -6,4 +6,4 @@
 # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# --------------------------------------------------------------------------
 
-VERSION = "1.0.0b7"
+VERSION = "1.0.0b9"
diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md
index abea93ee106a..43641a3de515 100644
--- a/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md
+++ b/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md
@@ -1,6 +1,22 @@
 # Release History
 
+## 1.0.0b9 (2026-01-23)
+
+- Integrated with Foundry Tools
+
+- Added support for Human-in-the-Loop
+
+- Added Response API converters for request conversion and orchestration.
+
+
+## 1.0.0b8 (2026-01-21)
+
+### Features Added
+
+- Added keep-alive support for long-running streaming responses.
+
+
 ## 1.0.0b7 (2025-12-05)
 
 ### Features Added
diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_version.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_version.py
index 84058978c521..b1c2836b6921 100644
--- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_version.py
+++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_version.py
@@ -6,4 +6,4 @@
 # Changes may cause incorrect behavior and will be lost if the code is regenerated.
 # --------------------------------------------------------------------------
 
-VERSION = "1.0.0b7"
+VERSION = "1.0.0b9"

From b8a543a487c6d49fb5a8112d1948dae479a9a701 Mon Sep 17 00:00:00 2001
From: Lu Sun
Date: Thu, 22 Jan 2026 16:40:18 -0800
Subject: [PATCH 24/29] update required core package

---
 .../azure-ai-agentserver-agentframework/pyproject.toml        | 2 +-
 sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml
index 47ffdce2c23e..7cbbfd0edf6a 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [
 keywords = ["azure", "azure sdk"]
 
 dependencies = [
-    "azure-ai-agentserver-core>=1.0.0b7",
+    "azure-ai-agentserver-core==1.0.0b9",
     "agent-framework-azure-ai>=1.0.0b251112,<=1.0.0b260107",
     "agent-framework-core>=1.0.0b251112,<=1.0.0b260107",
     "opentelemetry-exporter-otlp-proto-grpc>=1.36.0",
diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml
index b970062738ee..0e8a32a3afa2 100644
--- a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml
+++ b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml
@@ -19,7 +19,7 @@ classifiers = [
 keywords = ["azure", "azure sdk"]
 
 dependencies = [
-    "azure-ai-agentserver-core>=1.0.0b7",
+    "azure-ai-agentserver-core==1.0.0b9",
     "langchain>0.3.20",
     "langchain-openai>0.3.10",
     "langchain-azure-ai[opentelemetry]>=0.1.8",

From 41b9b5362df01d2db056dd23f842a85ca36453c7 Mon Sep 17 00:00:00 2001
From: Lu Sun
Date: Thu, 22 Jan 2026 16:52:11 -0800
Subject: [PATCH 25/29] fix langgraph sphinx

---
 ...graph.models.response_event_generators.rst | 74 +++++++++++++++++
 .../azure.ai.agentserver.langgraph.models.rst | 82 +++++++++++++++++++
 .../doc/azure.ai.agentserver.langgraph.rst    | 27 ++++++
 .../azure.ai.agentserver.langgraph.tools.rst  |  6 ++
 4 files changed, 189 insertions(+)
 create mode 100644 
sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.response_event_generators.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.tools.rst diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.response_event_generators.rst b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.response_event_generators.rst new file mode 100644 index 000000000000..af7cc69bd859 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.response_event_generators.rst @@ -0,0 +1,74 @@ +azure.ai.agentserver.langgraph.models.response\_event\_generators package +========================================================================= + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators + :inherited-members: + :members: + :undoc-members: + +Submodules +---------- + +azure.ai.agentserver.langgraph.models.response\_event\_generators.item\_content\_helpers module +----------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.item_content_helpers + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.item\_resource\_helpers module +------------------------------------------------------------------------------------------------ + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.item_resource_helpers + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_content\_part\_event\_generator module +------------------------------------------------------------------------------------------------------------------ + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_content_part_event_generator + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_event\_generator module +--------------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_event_generator + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_function\_call\_argument\_event\_generator module +----------------------------------------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_function_call_argument_event_generator + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_output\_item\_event\_generator module +----------------------------------------------------------------------------------------------------------------- + +.. 
automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_output_item_event_generator + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_output\_text\_event\_generator module +----------------------------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_output_text_event_generator + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_stream\_event\_generator module +----------------------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_stream_event_generator + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.rst b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.rst new file mode 100644 index 000000000000..aba857c3b64a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.rst @@ -0,0 +1,82 @@ +azure.ai.agentserver.langgraph.models package +============================================= + +.. automodule:: azure.ai.agentserver.langgraph.models + :inherited-members: + :members: + :undoc-members: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure.ai.agentserver.langgraph.models.response_event_generators + +Submodules +---------- + +azure.ai.agentserver.langgraph.models.human\_in\_the\_loop\_helper module +------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.human_in_the_loop_helper + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.human\_in\_the\_loop\_json\_helper module +------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.human_in_the_loop_json_helper + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_api\_converter module +--------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_api_converter + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_api\_default\_converter module +------------------------------------------------------------------------------ + +.. automodule:: azure.ai.agentserver.langgraph.models.response_api_default_converter + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_api\_non\_stream\_response\_converter module +-------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_api_non_stream_response_converter + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_api\_request\_converter module +------------------------------------------------------------------------------ + +.. 
automodule:: azure.ai.agentserver.langgraph.models.response_api_request_converter + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_api\_stream\_response\_converter module +--------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_api_stream_response_converter + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.utils module +-------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.utils + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.rst b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.rst new file mode 100644 index 000000000000..deefeb67fa96 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.rst @@ -0,0 +1,27 @@ +azure.ai.agentserver.langgraph package +====================================== + +.. automodule:: azure.ai.agentserver.langgraph + :inherited-members: + :members: + :undoc-members: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure.ai.agentserver.langgraph.models + azure.ai.agentserver.langgraph.tools + +Submodules +---------- + +azure.ai.agentserver.langgraph.langgraph module +----------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.langgraph + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.tools.rst b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.tools.rst new file mode 100644 index 000000000000..17f7ef6d2ab7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.tools.rst @@ -0,0 +1,6 @@ +azure.ai.agentserver.langgraph.tools package +============================================ + +.. automodule:: azure.ai.agentserver.langgraph.tools + :members: + :undoc-members: From f192f0bcb455a23e07a77fbca3b35d1cd72bddf7 Mon Sep 17 00:00:00 2001 From: junanchen Date: Thu, 22 Jan 2026 20:28:31 -0800 Subject: [PATCH 26/29] Put LG run context to runtime & config --- .../ai/agentserver/langgraph/_context.py | 40 ++++++- .../ai/agentserver/langgraph/langgraph.py | 7 +- .../agentserver/langgraph/tools/_builder.py | 2 +- .../langgraph/tools/_chat_model.py | 19 ++-- .../agentserver/langgraph/tools/_tool_node.py | 2 +- .../tool_client_example/graph_agent_tool.py | 104 ++++++++++++++++++ 6 files changed, 162 insertions(+), 12 deletions(-) create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/samples/tool_client_example/graph_agent_tool.py diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py index 81f0e0f0b545..89be24921f54 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py @@ -1,11 +1,13 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
 # ---------------------------------------------------------
+import sys
 from dataclasses import dataclass
 from typing import Optional, Union
 
+from langchain_core.runnables import RunnableConfig
 from langgraph.prebuilt import ToolRuntime
-from langgraph.runtime import Runtime
+from langgraph.runtime import Runtime, get_runtime
 
 from azure.ai.agentserver.core import AgentRunContext
 
 from .tools._context import FoundryToolContext
@@ -17,6 +19,42 @@ class LanggraphRunContext:
     tools: FoundryToolContext
 
+    def attach_to_config(self, config: RunnableConfig):
+        config.setdefault("configurable", {})["__foundry_hosted_agent_langgraph_run_context__"] = self
+
+    @classmethod
+    def resolve(cls,
+                config: Optional[RunnableConfig] = None,
+                runtime: Optional[Union[Runtime, ToolRuntime]] = None) -> Optional["LanggraphRunContext"]:
+        """Resolve the LanggraphRunContext from either a RunnableConfig or a Runtime.
+
+        :param config: Optional RunnableConfig to extract the context from.
+        :param runtime: Optional Runtime or ToolRuntime to extract the context from.
+        :return: An instance of LanggraphRunContext if found, otherwise None.
+        """
+        context: Optional["LanggraphRunContext"] = None
+        if config:
+            context = cls.from_config(config)
+        if not context and (r := cls._resolve_runtime(runtime)):
+            context = cls.from_runtime(r)
+        return context
+
+    @staticmethod
+    def _resolve_runtime(
+            runtime: Optional[Union[Runtime, ToolRuntime]] = None) -> Optional[Union[Runtime, ToolRuntime]]:
+        if runtime:
+            return runtime
+        if sys.version_info >= (3, 11):
+            return get_runtime(LanggraphRunContext)
+        return None
+
+    @staticmethod
+    def from_config(config: RunnableConfig) -> Optional["LanggraphRunContext"]:
+        context = config.get("configurable", {}).get("__foundry_hosted_agent_langgraph_run_context__")
+        if isinstance(context, LanggraphRunContext):
+            return context
+        return None
+
     @staticmethod
     def from_runtime(runtime: Union[Runtime, ToolRuntime]) -> Optional["LanggraphRunContext"]:
         context = runtime.context
diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py
index e8e524764db2..aae3bc32ee35 100644
--- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py
+++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py
@@ -68,7 +68,7 @@ async def agent_run(self, context: AgentRunContext):
         try:
             lg_run_context = await self.setup_lg_run_context(context)
             input_arguments = await self.converter.convert_request(lg_run_context)
-            self.ensure_runnable_config(input_arguments)
+            self.ensure_runnable_config(input_arguments, lg_run_context)
 
             if not context.stream:
                 response = await self.agent_run_non_stream(input_arguments)
@@ -156,17 +156,20 @@ async def agent_run_astream(self,
             logger.error(f"Error during streaming agent run: {e}", exc_info=True)
             raise e
 
-    def ensure_runnable_config(self, input_arguments: GraphInputArguments):
+    def ensure_runnable_config(self, input_arguments: GraphInputArguments, context: LanggraphRunContext):
         """
         Ensure the RunnableConfig is set in the input arguments.
 
         :param input_arguments: The input arguments for the agent run.
        :type input_arguments: GraphInputArguments
+        :param context: The LangGraph run context to attach to the config.
+        :type context: LanggraphRunContext
         """
         config = input_arguments.get("config", {})
         configurable = config.get("configurable", {})
         configurable["thread_id"] = input_arguments["context"].agent_run.conversation_id
         config["configurable"] = configurable
+        context.attach_to_config(config)
 
         callbacks = config.get("callbacks", [])
         if self.azure_ai_tracer and self.azure_ai_tracer not in callbacks:
diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py
index 828a8b42ae45..ccd79dff8fc6 100644
--- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py
+++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py
@@ -55,6 +55,6 @@ def use_foundry_tools(  # pylint: disable=C4743
         if tools is None:
             raise ValueError("Tools must be provided when a model is given.")
         get_registry().extend(tools)
-        return FoundryToolLateBindingChatModel(model_or_tools, foundry_tools=tools)
+        return FoundryToolLateBindingChatModel(model_or_tools, runtime=None, foundry_tools=tools)
     get_registry().extend(model_or_tools)
     return FoundryToolBindingMiddleware(model_or_tools)
diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py
index c221910218f4..4ca422b88c41 100644
--- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py
+++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py
@@ -30,7 +30,7 @@ class FoundryToolLateBindingChatModel(BaseChatModel):
     :type foundry_tools: List[FoundryToolLike]
     """
 
-    def __init__(self, delegate: BaseChatModel, runtime: Runtime, foundry_tools: List[FoundryToolLike]):
+    def __init__(self, delegate: BaseChatModel, runtime: Optional[Runtime], foundry_tools: List[FoundryToolLike]):
         super().__init__()
         self._delegate = delegate
         self._runtime = runtime
@@ -88,12 +88,17 @@ def bind_tools(self,  # pylint: disable=C4758
 
         return self
 
-    def _bound_delegate_for_call(self) -> Runnable[LanguageModelInput, AIMessage]:
+    def _bound_delegate_for_call(self, config: Optional[RunnableConfig]) -> Runnable[LanguageModelInput, AIMessage]:
         from .._context import LanggraphRunContext
 
         foundry_tools: Iterable[BaseTool] = []
-        if (context := LanggraphRunContext.from_runtime(self._runtime)) is not None:
+        if context := LanggraphRunContext.resolve(config, self._runtime):
             foundry_tools = context.tools.resolved_tools.get(self._foundry_tools_to_bind)
+        elif self._foundry_tools_to_bind:
+            raise RuntimeError("Unable to resolve foundry tools from context. "
+                               "If you are running on Python < 3.11, "
+                               "make sure to pass a RunnableConfig when calling the model.")
+
         all_tools = self._bound_tools.copy()
         all_tools.extend(foundry_tools)
 
@@ -104,16 +109,16 @@ def _bound_delegate_for_call(self) -> Runnable[LanguageModelInput, AIMessage]:
         return self._delegate.bind_tools(all_tools, **bound_kwargs)
 
     def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
-        return self._bound_delegate_for_call().invoke(input, config=config, **kwargs)
+        return self._bound_delegate_for_call(config).invoke(input, config=config, **kwargs)
 
     async def ainvoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
-        return await self._bound_delegate_for_call().ainvoke(input, config=config, **kwargs)
+        return await self._bound_delegate_for_call(config).ainvoke(input, config=config, **kwargs)
 
     def stream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any):
-        yield from self._bound_delegate_for_call().stream(input, config=config, **kwargs)
+        yield from self._bound_delegate_for_call(config).stream(input, config=config, **kwargs)
 
     async def astream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any):
-        async for x in self._bound_delegate_for_call().astream(input, config=config, **kwargs):
+        async for x in self._bound_delegate_for_call(config).astream(input, config=config, **kwargs):
             yield x
 
     @property
diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py
index 5f3c6326836b..1bfef8c39f81 100644
--- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py
+++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py
@@ -78,7 +78,7 @@ def _maybe_calling_foundry_tool(self, request: ToolCallRequest) -> ToolCallReque
 
         if (request.tool
                 or not self._allowed_foundry_tools
-                or (context := LanggraphRunContext.from_runtime(request.runtime)) is None):
+                or not (context := LanggraphRunContext.resolve(runtime=request.runtime))):
             # tool is already resolved
             return request
 
diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/tool_client_example/graph_agent_tool.py b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/tool_client_example/graph_agent_tool.py
new file mode 100644
index 000000000000..c4992ba71f46
--- /dev/null
+++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/tool_client_example/graph_agent_tool.py
@@ -0,0 +1,104 @@
+import os
+
+from dotenv import load_dotenv
+from langchain.chat_models import init_chat_model
+from langchain_core.messages import SystemMessage, ToolMessage
+from langchain_core.runnables import RunnableConfig
+from langchain_core.tools import tool
+from langgraph.graph import (
+    END,
+    START,
+    MessagesState,
+    StateGraph,
+)
+from typing_extensions import Literal
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+
+from azure.ai.agentserver.langgraph import from_langgraph
+from azure.ai.agentserver.langgraph.tools import use_foundry_tools
+
+load_dotenv()
+
+deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", "gpt-4o")
+credential = DefaultAzureCredential()
+token_provider = get_bearer_token_provider(
+    credential, "https://cognitiveservices.azure.com/.default"
+)
+llm = init_chat_model(
+    f"azure_openai:{deployment_name}",
+    azure_ad_token_provider=token_provider,
+)
+llm_with_foundry_tools = use_foundry_tools(llm, [
+    {
+        # Use the python tool to calculate what is 4 * 3.82, then find its square root, and then the square root of that result.
+        "type": "code_interpreter"
+    },
+    {
+        # Give me the Azure CLI commands to create an Azure Container App with a managed identity; search Microsoft Learn.
+        "type": "mcp",
+        "project_connection_id": "MicrosoftLearn"
+    },
+    # {
+    #     "type": "mcp",
+    #     "project_connection_id": "FoundryMCPServerpreview"
+    # }
+])
+
+
+# Nodes
+async def llm_call(state: MessagesState, config: RunnableConfig):
+    """LLM decides whether to call a tool or not"""
+
+    return {
+        "messages": [
+            await llm_with_foundry_tools.ainvoke(
+                [
+                    SystemMessage(
+                        content="You are a helpful assistant tasked with performing arithmetic on a set of inputs."
+                    )
+                ]
+                + state["messages"],
+                config=config,
+            )
+        ]
+    }
+
+
+# Conditional edge function to route to the tool node or end based upon whether the LLM made a tool call
+def should_continue(state: MessagesState) -> Literal["Action", END]:
+    """Decide if we should continue the loop or stop based upon whether the LLM made a tool call"""
+
+    messages = state["messages"]
+    last_message = messages[-1]
+    # If the LLM makes a tool call, then perform an action
+    if last_message.tool_calls:
+        return "Action"
+    # Otherwise, we stop (reply to the user)
+    return END
+
+
+# Build workflow
+agent_builder = StateGraph(MessagesState)
+
+# Add nodes
+agent_builder.add_node("llm_call", llm_call)
+agent_builder.add_node("environment", llm_with_foundry_tools.tool_node)
+
+# Add edges to connect nodes
+agent_builder.add_edge(START, "llm_call")
+agent_builder.add_conditional_edges(
+    "llm_call",
+    should_continue,
+    {
+        "Action": "environment",
+        END: END,
+    },
+)
+agent_builder.add_edge("environment", "llm_call")
+
+# Compile the agent
+agent = agent_builder.compile()
+
+if __name__ == "__main__":
+    adapter = from_langgraph(agent)
+    adapter.run()

From 29864961a1083a989466740461041c4cc63b3ff7 Mon Sep 17 00:00:00 2001
From: junanchen
Date: Thu, 22 Jan 2026 21:12:53 -0800
Subject: [PATCH 27/29] langgraph ut

---
 .../ai/agentserver/core/tools/__init__.py      |   2 +
 .../unit_tests/langgraph/tools/__init__.py     |   5 +
 .../unit_tests/langgraph/tools/conftest.py     | 271 ++++++++++
 .../langgraph/tools/test_agent_integration.py  | 404 ++++++++++++++
 .../langgraph/tools/test_builder.py            | 109 ++++
 .../langgraph/tools/test_chat_model.py         | 277 ++++++++++
 .../langgraph/tools/test_context.py            |  36 ++
 .../langgraph/tools/test_middleware.py         | 197 +++++++
 .../langgraph/tools/test_resolver.py           | 502 ++++++++++++++++++
 .../langgraph/tools/test_tool_node.py          | 179 +++++++
 10 files changed, 1982 insertions(+)
 create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/__init__.py
 create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/conftest.py
 create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_agent_integration.py
 create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_builder.py
 create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_chat_model.py
 create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_context.py
 create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_middleware.py
 create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_resolver.py
 create mode 100644 sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_tool_node.py

diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py
b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py index 5b356f38c825..34c58d65cfd6 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py @@ -15,6 +15,7 @@ FoundryConnectedTool, FoundryHostedMcpTool, FoundryTool, + FoundryToolDetails, FoundryToolProtocol, FoundryToolSource, ResolvedFoundryTool, @@ -47,6 +48,7 @@ "FoundryConnectedTool", "FoundryHostedMcpTool", "FoundryTool", + "FoundryToolDetails", "FoundryToolProtocol", "FoundryToolSource", "ResolvedFoundryTool", diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/conftest.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/conftest.py new file mode 100644 index 000000000000..7efc298559c1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/conftest.py @@ -0,0 +1,271 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for langgraph tools unit tests.""" +from typing import Any, Dict, List, Optional + +import pytest +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool, tool + +from azure.ai.agentserver.core.tools import ( + FoundryHostedMcpTool, + FoundryConnectedTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext +from azure.ai.agentserver.langgraph._context import LanggraphRunContext +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools + + +class FakeChatModel(BaseChatModel): + """A fake chat model for testing purposes that returns pre-configured responses.""" + + responses: List[AIMessage] = [] + tool_calls_list: List[List[Dict[str, Any]]] = [] + _call_count: int = 0 + _bound_tools: List[Any] = [] + _bound_kwargs: Dict[str, Any] = {} + + def __init__( + self, + responses: Optional[List[AIMessage]] = None, + tool_calls: Optional[List[List[Dict[str, Any]]]] = None, + **kwargs: Any, + ): + """Initialize the fake chat model. + + :param responses: List of AIMessage responses to return in sequence. + :param tool_calls: List of tool_calls lists corresponding to each response. 
+ """ + super().__init__(**kwargs) + self.responses = responses or [] + self.tool_calls_list = tool_calls or [] + self._call_count = 0 + self._bound_tools = [] + self._bound_kwargs = {} + + @property + def _llm_type(self) -> str: + return "fake_chat_model" + + def _generate( + self, + messages: List[Any], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Generate a response.""" + response = self._get_next_response() + return ChatResult(generations=[ChatGeneration(message=response)]) + + def bind_tools( + self, + tools: List[Any], + **kwargs: Any, + ) -> "FakeChatModel": + """Bind tools to this model.""" + self._bound_tools = list(tools) + self._bound_kwargs.update(kwargs) + return self + + def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> AIMessage: + """Synchronously invoke the model.""" + return self._get_next_response() + + async def ainvoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> AIMessage: + """Asynchronously invoke the model.""" + return self._get_next_response() + + def stream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any): + """Stream the response.""" + yield self._get_next_response() + + async def astream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any): + """Async stream the response.""" + yield self._get_next_response() + + def _get_next_response(self) -> AIMessage: + """Get the next response in sequence.""" + if self._call_count < len(self.responses): + response = self.responses[self._call_count] + else: + # Default response if no more configured + response = AIMessage(content="Default response") + + # Apply tool calls if configured + if self._call_count < len(self.tool_calls_list): + response = AIMessage( + content=response.content, + tool_calls=self.tool_calls_list[self._call_count], + ) + + self._call_count += 1 + return response + + +@pytest.fixture +def sample_schema_definition() -> SchemaDefinition: + """Create a sample SchemaDefinition.""" + return SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + }, + required={"query"}, + ) + + +@pytest.fixture +def sample_code_interpreter_tool() -> FoundryHostedMcpTool: + """Create a sample code interpreter tool definition.""" + return FoundryHostedMcpTool( + name="code_interpreter", + configuration={}, + ) + + +@pytest.fixture +def sample_mcp_connected_tool() -> FoundryConnectedTool: + """Create a sample MCP connected tool definition.""" + return FoundryConnectedTool( + protocol="mcp", + project_connection_id="MicrosoftLearn", + ) + + +@pytest.fixture +def sample_tool_details(sample_schema_definition: SchemaDefinition) -> FoundryToolDetails: + """Create a sample FoundryToolDetails.""" + return FoundryToolDetails( + name="search", + description="Search for documents", + input_schema=sample_schema_definition, + ) + + +@pytest.fixture +def sample_resolved_tool( + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_tool_details: FoundryToolDetails, +) -> ResolvedFoundryTool: + """Create a sample resolved foundry tool.""" + return ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=sample_tool_details, + ) + + +@pytest.fixture +def mock_langchain_tool() -> BaseTool: + """Create a mock LangChain BaseTool.""" + @tool + def mock_tool(query: str) -> str: + """Mock tool for testing. 
+ + :param query: The search query. + :return: Mock result. + """ + return f"Mock result for: {query}" + + return mock_tool + + +@pytest.fixture +def mock_async_langchain_tool() -> BaseTool: + """Create a mock async LangChain BaseTool.""" + @tool + async def mock_async_tool(query: str) -> str: + """Mock async tool for testing. + + :param query: The search query. + :return: Mock result. + """ + return f"Async mock result for: {query}" + + return mock_async_tool + + +@pytest.fixture +def sample_resolved_tools( + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, +) -> ResolvedTools: + """Create a sample ResolvedTools instance.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="mock_tool", + description="Mock tool for testing", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Query"), + }, + required={"query"}, + ), + ), + ) + return ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + +@pytest.fixture +def mock_agent_run_context() -> AgentRunContext: + """Create a mock AgentRunContext.""" + payload = { + "input": [{"role": "user", "content": "Hello"}], + "stream": False, + } + return AgentRunContext(payload=payload) + + +@pytest.fixture +def mock_foundry_tool_context(sample_resolved_tools: ResolvedTools) -> FoundryToolContext: + """Create a mock FoundryToolContext.""" + return FoundryToolContext(resolved_tools=sample_resolved_tools) + + +@pytest.fixture +def mock_langgraph_run_context( + mock_agent_run_context: AgentRunContext, + mock_foundry_tool_context: FoundryToolContext, +) -> LanggraphRunContext: + """Create a mock LanggraphRunContext.""" + return LanggraphRunContext( + agent_run=mock_agent_run_context, + tools=mock_foundry_tool_context, + ) + + +@pytest.fixture +def fake_chat_model_simple() -> FakeChatModel: + """Create a simple fake chat model.""" + return FakeChatModel( + responses=[AIMessage(content="Hello! How can I help you?")], + ) + + +@pytest.fixture +def fake_chat_model_with_tool_call() -> FakeChatModel: + """Create a fake chat model that makes a tool call.""" + return FakeChatModel( + responses=[ + AIMessage(content=""), # First response: tool call + AIMessage(content="The answer is 42."), # Second response: final answer + ], + tool_calls=[ + [{"id": "call_1", "name": "mock_tool", "args": {"query": "test query"}}], + [], # No tool calls in final response + ], + ) + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_agent_integration.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_agent_integration.py new file mode 100644 index 000000000000..fab1955ef415 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_agent_integration.py @@ -0,0 +1,404 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Integration-style unit tests for langgraph agents with foundry tools. + +These tests demonstrate the usage patterns similar to the tool_client_example samples, +but use mocked models and tools to avoid calling real services. 
+""" +import pytest + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool, tool +from langgraph.graph import END, START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode +from typing_extensions import Literal + +from azure.ai.agentserver.core.tools import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext +from azure.ai.agentserver.langgraph._context import LanggraphRunContext +from azure.ai.agentserver.langgraph.tools import use_foundry_tools +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._chat_model import FoundryToolLateBindingChatModel +from azure.ai.agentserver.langgraph.tools._middleware import FoundryToolBindingMiddleware +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools, get_registry + +from .conftest import FakeChatModel + + +@pytest.fixture(autouse=True) +def clear_registry(): + """Clear the global registry before and after each test.""" + registry = get_registry() + registry.clear() + yield + registry.clear() + + +@pytest.mark.unit +class TestGraphAgentWithFoundryTools: + """Tests demonstrating graph agent usage patterns similar to graph_agent_tool.py sample.""" + + def _create_mock_langgraph_context( + self, + foundry_tool: FoundryHostedMcpTool, + langchain_tool: BaseTool, + ) -> LanggraphRunContext: + """Create a mock LanggraphRunContext with resolved tools.""" + # Create resolved foundry tool + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name=langchain_tool.name, + description=langchain_tool.description or "Mock tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Query"), + }, + required={"query"}, + ), + ), + ) + + # Create resolved tools + resolved_tools = ResolvedTools(tools=[(resolved_foundry_tool, langchain_tool)]) + + # Create context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext(resolved_tools=resolved_tools) + + return LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + @pytest.mark.asyncio + async def test_graph_agent_with_foundry_tools_no_tool_call(self): + """Test a graph agent that uses foundry tools but doesn't make a tool call.""" + # Create a mock tool + @tool + def calculate(expression: str) -> str: + """Calculate a mathematical expression. + + :param expression: The expression to calculate. + :return: The result. 
+ """ + return "42" + + # Create foundry tool definition + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + foundry_tools = [{"type": "code_interpreter"}] + + # Create mock model that returns simple response (no tool call) + mock_model = FakeChatModel( + responses=[AIMessage(content="The answer is 42.")], + ) + + # Create the foundry tool binding chat model + llm_with_foundry_tools = FoundryToolLateBindingChatModel( + delegate=mock_model, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Create context and attach + context = self._create_mock_langgraph_context(foundry_tool, calculate) + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Define the LLM call node + async def llm_call(state: MessagesState, config: RunnableConfig): + return { + "messages": [ + await llm_with_foundry_tools.ainvoke( + [SystemMessage(content="You are a helpful assistant.")] + + state["messages"], + config=config, + ) + ] + } + + # Define routing function + def should_continue(state: MessagesState) -> Literal["tools", "__end__"]: + messages = state["messages"] + last_message = messages[-1] + if hasattr(last_message, 'tool_calls') and last_message.tool_calls: + return "tools" + return END + + # Build the graph + builder = StateGraph(MessagesState) + builder.add_node("llm_call", llm_call) + builder.add_node("tools", llm_with_foundry_tools.tool_node) + builder.add_edge(START, "llm_call") + builder.add_conditional_edges("llm_call", should_continue, {"tools": "tools", END: END}) + builder.add_edge("tools", "llm_call") + + graph = builder.compile() + + # Run the graph + result = await graph.ainvoke( + {"messages": [HumanMessage(content="What is 6 * 7?")]}, + config=config, + ) + + # Verify result + assert len(result["messages"]) == 2 # HumanMessage + AIMessage + assert result["messages"][-1].content == "The answer is 42." + + @pytest.mark.asyncio + async def test_graph_agent_with_tool_call(self): + """Test a graph agent that makes a tool call.""" + # Create a mock tool + @tool + def calculate(expression: str) -> str: + """Calculate a mathematical expression. + + :param expression: The expression to calculate. + :return: The result. 
+ """ + return "42" + + # Create foundry tool definition + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + foundry_tools = [{"type": "code_interpreter"}] + + # Create mock model that makes a tool call, then returns final answer + mock_model = FakeChatModel( + responses=[ + AIMessage( + content="", + tool_calls=[{"id": "call_1", "name": "calculate", "args": {"expression": "6 * 7"}}], + ), + AIMessage(content="The answer is 42."), + ], + ) + + # Create the foundry tool binding chat model + llm_with_foundry_tools = FoundryToolLateBindingChatModel( + delegate=mock_model, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Create context with the calculate tool + context = self._create_mock_langgraph_context(foundry_tool, calculate) + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Define the LLM call node + async def llm_call(state: MessagesState, config: RunnableConfig): + return { + "messages": [ + await llm_with_foundry_tools.ainvoke( + [SystemMessage(content="You are a helpful assistant.")] + + state["messages"], + config=config, + ) + ] + } + + # Define routing function + def should_continue(state: MessagesState) -> Literal["tools", "__end__"]: + messages = state["messages"] + last_message = messages[-1] + if hasattr(last_message, 'tool_calls') and last_message.tool_calls: + return "tools" + return END + + # Build the graph with a regular ToolNode (using the local tool directly for testing) + builder = StateGraph(MessagesState) + builder.add_node("llm_call", llm_call) + builder.add_node("tools", ToolNode([calculate])) + builder.add_edge(START, "llm_call") + builder.add_conditional_edges("llm_call", should_continue, {"tools": "tools", END: END}) + builder.add_edge("tools", "llm_call") + + graph = builder.compile() + + # Run the graph + result = await graph.ainvoke( + {"messages": [HumanMessage(content="What is 6 * 7?")]}, + config=config, + ) + + # Verify result - should have: HumanMessage, AIMessage (with tool call), ToolMessage, AIMessage (final) + assert len(result["messages"]) == 4 + assert result["messages"][-1].content == "The answer is 42." 
+ + # Verify tool was called + tool_message = result["messages"][2] + assert isinstance(tool_message, ToolMessage) + assert tool_message.content == "42" + + +@pytest.mark.unit +class TestReactAgentWithFoundryTools: + """Tests demonstrating react agent usage patterns similar to react_agent_tool.py sample.""" + + @pytest.mark.asyncio + async def test_middleware_integration_with_foundry_tools(self): + """Test that FoundryToolBindingMiddleware correctly integrates with agents.""" + # Define foundry tools configuration + foundry_tools_config = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "MicrosoftLearn"}, + ] + + # Create middleware using use_foundry_tools + middleware = use_foundry_tools(foundry_tools_config) + + # Verify middleware is created correctly + assert isinstance(middleware, FoundryToolBindingMiddleware) + + # Verify dummy tool is created for the agent + assert len(middleware.tools) == 1 + assert middleware.tools[0].name == "__dummy_tool_by_foundry_middleware__" + + # Verify foundry tools are recorded + assert len(middleware._foundry_tools_to_bind) == 2 + + def test_use_foundry_tools_with_model(self): + """Test use_foundry_tools when used with a model directly.""" + foundry_tools = [{"type": "code_interpreter"}] + mock_model = FakeChatModel() + + result = use_foundry_tools(mock_model, foundry_tools) # type: ignore + + assert isinstance(result, FoundryToolLateBindingChatModel) + assert result._foundry_tools_to_bind == foundry_tools + + +@pytest.mark.unit +class TestLanggraphRunContextIntegration: + """Tests for LanggraphRunContext integration with langgraph.""" + + def test_context_attachment_to_config(self): + """Test that context is correctly attached to RunnableConfig.""" + # Create a mock context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext() + + context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + # Create config and attach context + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Verify context is attached + assert "__foundry_hosted_agent_langgraph_run_context__" in config["configurable"] + assert config["configurable"]["__foundry_hosted_agent_langgraph_run_context__"] is context + + def test_context_resolution_from_config(self): + """Test that context can be resolved from RunnableConfig.""" + # Create and attach context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext() + + context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Resolve context + resolved = LanggraphRunContext.resolve(config=config) + + assert resolved is context + + def test_context_resolution_returns_none_when_not_attached(self): + """Test that context resolution returns None when not attached.""" + config: RunnableConfig = {"configurable": {}} + + resolved = LanggraphRunContext.resolve(config=config) + + assert resolved is None + + def test_from_config_returns_context(self): + """Test LanggraphRunContext.from_config method.""" + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext() + + context = LanggraphRunContext(agent_run=agent_run_context, 
tools=tool_context) + + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + result = LanggraphRunContext.from_config(config) + + assert result is context + + def test_from_config_returns_none_for_non_context_value(self): + """Test that from_config returns None when value is not LanggraphRunContext.""" + config: RunnableConfig = { + "configurable": { + "__foundry_hosted_agent_langgraph_run_context__": "not a context" + } + } + + result = LanggraphRunContext.from_config(config) + + assert result is None + + +@pytest.mark.unit +class TestToolsResolutionInGraph: + """Tests for tool resolution within langgraph execution.""" + + @pytest.mark.asyncio + async def test_foundry_tools_resolved_from_context_in_graph_node(self): + """Test that foundry tools are correctly resolved from context during graph execution.""" + # Create mock tool + @tool + def search(query: str) -> str: + """Search for information. + + :param query: The search query. + :return: Search results. + """ + return f"Results for: {query}" + + # Create foundry tool and context + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="search", + description="Search tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + + resolved_tools = ResolvedTools(tools=[(resolved_foundry_tool, search)]) + + # Create context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext(resolved_tools=resolved_tools) + lg_context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + # Create config and attach context + config: RunnableConfig = {"configurable": {}} + lg_context.attach_to_config(config) + + # Verify tools can be resolved + resolved = LanggraphRunContext.resolve(config=config) + assert resolved is not None + + tools = list(resolved.tools.resolved_tools.get(foundry_tool)) + assert len(tools) == 1 + assert tools[0].name == "search" + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_builder.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_builder.py new file mode 100644 index 000000000000..1a2a5af167be --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_builder.py @@ -0,0 +1,109 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for use_foundry_tools builder function.""" +import pytest +from typing import List + + +from azure.ai.agentserver.langgraph.tools._builder import use_foundry_tools +from azure.ai.agentserver.langgraph.tools._chat_model import FoundryToolLateBindingChatModel +from azure.ai.agentserver.langgraph.tools._middleware import FoundryToolBindingMiddleware +from azure.ai.agentserver.langgraph.tools._resolver import get_registry + +from .conftest import FakeChatModel + + +@pytest.fixture(autouse=True) +def clear_registry(): + """Clear the global registry before and after each test.""" + registry = get_registry() + registry.clear() + yield + registry.clear() + + +@pytest.mark.unit +class TestUseFoundryTools: + """Tests for use_foundry_tools function.""" + + def test_use_foundry_tools_with_tools_only_returns_middleware(self): + """Test that passing only tools returns FoundryToolBindingMiddleware.""" + tools = [{"type": "code_interpreter"}] + + result = use_foundry_tools(tools) + + assert isinstance(result, FoundryToolBindingMiddleware) + + def test_use_foundry_tools_with_model_and_tools_returns_chat_model(self): + """Test that passing model and tools returns FoundryToolLateBindingChatModel.""" + model = FakeChatModel() + tools = [{"type": "code_interpreter"}] + + result = use_foundry_tools(model, tools) # type: ignore + + assert isinstance(result, FoundryToolLateBindingChatModel) + + def test_use_foundry_tools_with_model_but_no_tools_raises_error(self): + """Test that passing model without tools raises ValueError.""" + model = FakeChatModel() + + with pytest.raises(ValueError, match="Tools must be provided"): + use_foundry_tools(model, None) # type: ignore + + def test_use_foundry_tools_registers_tools_in_global_registry(self): + """Test that tools are registered in the global registry.""" + tools = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "test"}, + ] + + use_foundry_tools(tools) + + registry = get_registry() + assert len(registry) == 2 + + def test_use_foundry_tools_with_model_registers_tools(self): + """Test that tools are registered when using with model.""" + model = FakeChatModel() + tools = [{"type": "code_interpreter"}] + + use_foundry_tools(model, tools) # type: ignore + + registry = get_registry() + assert len(registry) == 1 + + def test_use_foundry_tools_with_empty_tools_list(self): + """Test using with empty tools list.""" + tools: List = [] + + result = use_foundry_tools(tools) + + assert isinstance(result, FoundryToolBindingMiddleware) + assert len(get_registry()) == 0 + + def test_use_foundry_tools_with_mcp_tools(self): + """Test using with MCP connected tools.""" + tools = [ + { + "type": "mcp", + "project_connection_id": "MicrosoftLearn", + }, + ] + + result = use_foundry_tools(tools) + + assert isinstance(result, FoundryToolBindingMiddleware) + + def test_use_foundry_tools_with_mixed_tool_types(self): + """Test using with a mix of different tool types.""" + tools = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "MicrosoftLearn"}, + ] + + result = use_foundry_tools(tools) + + assert isinstance(result, FoundryToolBindingMiddleware) + assert len(get_registry()) == 2 + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_chat_model.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_chat_model.py new file mode 100644 index 000000000000..085495a4b91e --- /dev/null +++ 
b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_chat_model.py @@ -0,0 +1,277 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolLateBindingChatModel.""" +import pytest +from typing import Any, List, Optional +from unittest.mock import MagicMock, patch + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool, tool + +from azure.ai.agentserver.core.tools import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.langgraph._context import LanggraphRunContext +from azure.ai.agentserver.langgraph.tools._chat_model import FoundryToolLateBindingChatModel +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools + +from .conftest import FakeChatModel + + +@pytest.mark.unit +class TestFoundryToolLateBindingChatModel: + """Tests for FoundryToolLateBindingChatModel class.""" + + def test_llm_type_property(self): + """Test the _llm_type property returns correct value.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + assert "foundry_tool_binding_model" in model._llm_type + assert "fake_chat_model" in model._llm_type + + def test_bind_tools_records_tools(self): + """Test that bind_tools records tools for later use.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + @tool + def my_tool(x: str) -> str: + """My tool.""" + return x + + result = model.bind_tools([my_tool], tool_choice="auto") + + # Should return self for chaining + assert result is model + # Tools should be recorded + assert len(model._bound_tools) == 1 + assert model._bound_kwargs.get("tool_choice") == "auto" + + def test_bind_tools_multiple_times(self): + """Test binding tools multiple times accumulates them.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + @tool + def tool1(x: str) -> str: + """Tool 1.""" + return x + + @tool + def tool2(x: str) -> str: + """Tool 2.""" + return x + + model.bind_tools([tool1]) + model.bind_tools([tool2]) + + assert len(model._bound_tools) == 2 + + def test_tool_node_property(self): + """Test that tool_node property returns a ToolNode.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + tool_node = model.tool_node + + # Should return a ToolNode + assert tool_node is not None + + def test_tool_node_wrapper_property(self): + """Test that tool_node_wrapper returns correct wrappers.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + 
runtime=None, + foundry_tools=foundry_tools, + ) + + wrappers = model.tool_node_wrapper + + assert "wrap_tool_call" in wrappers + assert "awrap_tool_call" in wrappers + assert callable(wrappers["wrap_tool_call"]) + assert callable(wrappers["awrap_tool_call"]) + + def test_invoke_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + sample_code_interpreter_tool: FoundryHostedMcpTool, + ): + """Test invoking model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Hello from model!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + result = model.invoke(input_messages, config=config) + + assert result.content == "Hello from model!" + + @pytest.mark.asyncio + async def test_ainvoke_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + ): + """Test async invoking model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Async hello from model!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + result = await model.ainvoke(input_messages, config=config) + + assert result.content == "Async hello from model!" + + def test_invoke_without_context_and_no_foundry_tools(self): + """Test invoking model without context and no foundry tools.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Hello!")], + ) + # No foundry tools + foundry_tools: List[Any] = [] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + config: RunnableConfig = {"configurable": {}} + input_messages = [HumanMessage(content="Hello")] + result = model.invoke(input_messages, config=config) + + # Should work since no foundry tools need resolution + assert result.content == "Hello!" 
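+        # Tool resolution is only attempted when foundry tools are configured, so no run context is required here.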
+ + def test_invoke_without_context_raises_error_when_foundry_tools_present(self): + """Test that invoking without context raises error when foundry tools are set.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Hello!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + config: RunnableConfig = {"configurable": {}} + input_messages = [HumanMessage(content="Hello")] + + with pytest.raises(RuntimeError, match="Unable to resolve foundry tools from context"): + model.invoke(input_messages, config=config) + + def test_stream_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + ): + """Test streaming model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Streamed response!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + results = list(model.stream(input_messages, config=config)) + + assert len(results) == 1 + assert results[0].content == "Streamed response!" + + @pytest.mark.asyncio + async def test_astream_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + ): + """Test async streaming model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Async streamed response!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + results = [] + async for chunk in model.astream(input_messages, config=config): + results.append(chunk) + + assert len(results) == 1 + assert results[0].content == "Async streamed response!" + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_context.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_context.py new file mode 100644 index 000000000000..577d4e6e4e6f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_context.py @@ -0,0 +1,36 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryToolContext.""" +import pytest + +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools + + +@pytest.mark.unit +class TestFoundryToolContext: + """Tests for FoundryToolContext class.""" + + def test_create_with_resolved_tools(self, sample_resolved_tools: ResolvedTools): + """Test creating FoundryToolContext with resolved tools.""" + context = FoundryToolContext(resolved_tools=sample_resolved_tools) + + assert context.resolved_tools is sample_resolved_tools + + def test_create_with_default_resolved_tools(self): + """Test creating FoundryToolContext with default empty resolved tools.""" + context = FoundryToolContext() + + # Default should be empty ResolvedTools + assert context.resolved_tools is not None + tools_list = list(context.resolved_tools) + assert len(tools_list) == 0 + + def test_resolved_tools_is_iterable(self, sample_resolved_tools: ResolvedTools): + """Test that resolved_tools can be iterated.""" + context = FoundryToolContext(resolved_tools=sample_resolved_tools) + + tools_list = list(context.resolved_tools) + assert len(tools_list) == 1 + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_middleware.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_middleware.py new file mode 100644 index 000000000000..89290a58f97c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_middleware.py @@ -0,0 +1,197 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryToolBindingMiddleware.""" +import pytest +from typing import Any, List +from unittest.mock import AsyncMock, MagicMock + +from langchain.agents.middleware.types import ModelRequest +from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.tools import tool +from langgraph.prebuilt.tool_node import ToolCallRequest + +from azure.ai.agentserver.langgraph.tools._middleware import FoundryToolBindingMiddleware + +from .conftest import FakeChatModel + + +@pytest.mark.unit +class TestFoundryToolBindingMiddleware: + """Tests for FoundryToolBindingMiddleware class.""" + + def test_init_with_foundry_tools_creates_dummy_tool(self): + """Test that initialization with foundry tools creates a dummy tool.""" + foundry_tools = [{"type": "code_interpreter"}] + + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Should have one dummy tool + assert len(middleware.tools) == 1 + assert middleware.tools[0].name == "__dummy_tool_by_foundry_middleware__" + + def test_init_without_foundry_tools_no_dummy_tool(self): + """Test that initialization without foundry tools creates no dummy tool.""" + foundry_tools: List[Any] = [] + + middleware = FoundryToolBindingMiddleware(foundry_tools) + + assert len(middleware.tools) == 0 + + def test_wrap_model_call_wraps_model_with_foundry_binding(self): + """Test that wrap_model_call wraps the model correctly.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock model and request + mock_model = FakeChatModel() + mock_runtime = MagicMock() + mock_request = MagicMock(spec=ModelRequest) + mock_request.model = mock_model + mock_request.runtime = mock_runtime + mock_request.tools = [] + + # Create a modified request to return + modified_request = MagicMock(spec=ModelRequest) + mock_request.override = MagicMock(return_value=modified_request) + + # Mock handler + expected_result = AIMessage(content="Result") + mock_handler = MagicMock(return_value=expected_result) + + result = middleware.wrap_model_call(mock_request, mock_handler) + + # Handler should be called with modified request + mock_handler.assert_called_once() + assert result == expected_result + + @pytest.mark.asyncio + async def test_awrap_model_call_wraps_model_async(self): + """Test that awrap_model_call wraps the model correctly in async.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock model and request + mock_model = FakeChatModel() + mock_runtime = MagicMock() + mock_request = MagicMock(spec=ModelRequest) + mock_request.model = mock_model + mock_request.runtime = mock_runtime + mock_request.tools = [] + + # Create a modified request to return + modified_request = MagicMock(spec=ModelRequest) + mock_request.override = MagicMock(return_value=modified_request) + + # Mock async handler + expected_result = AIMessage(content="Async Result") + mock_handler = AsyncMock(return_value=expected_result) + + result = await middleware.awrap_model_call(mock_request, mock_handler) + + # Handler should be called + mock_handler.assert_awaited_once() + assert result == expected_result + + def test_wrap_model_without_foundry_tools_returns_unchanged(self): + """Test that wrap_model returns unchanged request when no foundry tools.""" + foundry_tools: List[Any] = [] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + mock_model = FakeChatModel() + mock_request 
= MagicMock(spec=ModelRequest) + mock_request.model = mock_model + mock_request.tools = [] + + # Should not call override + mock_request.override = MagicMock() + + mock_handler = MagicMock(return_value=AIMessage(content="Result")) + + middleware.wrap_model_call(mock_request, mock_handler) + + # Handler should be called with original request + mock_handler.assert_called_once_with(mock_request) + + def test_remove_dummy_tool_from_request(self): + """Test that dummy tool is removed from the request tools.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create request with dummy tool + dummy = middleware._dummy_tool() + + @tool + def real_tool(x: str) -> str: + """Real tool.""" + return x + + mock_request = MagicMock(spec=ModelRequest) + mock_request.tools = [dummy, real_tool] + + # Call internal method + result = middleware._remove_dummy_tool(mock_request) + + # Should only have real_tool + assert len(result) == 1 + assert result[0] is real_tool + + def test_wrap_tool_call_delegates_to_wrapper(self): + """Test that wrap_tool_call delegates to FoundryToolCallWrapper.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock tool call request + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "test_tool", "id": "call_1"} + mock_request.state = {} + mock_request.runtime = None + + # Mock handler + expected_result = ToolMessage(content="Result", tool_call_id="call_1") + mock_handler = MagicMock(return_value=expected_result) + + result = middleware.wrap_tool_call(mock_request, mock_handler) + + # Handler should be called + mock_handler.assert_called_once() + assert result == expected_result + + @pytest.mark.asyncio + async def test_awrap_tool_call_delegates_to_wrapper_async(self): + """Test that awrap_tool_call delegates to FoundryToolCallWrapper async.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock tool call request + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "test_tool", "id": "call_1"} + mock_request.state = {} + mock_request.runtime = None + + # Mock async handler + expected_result = ToolMessage(content="Async Result", tool_call_id="call_1") + mock_handler = AsyncMock(return_value=expected_result) + + result = await middleware.awrap_tool_call(mock_request, mock_handler) + + # Handler should be awaited + mock_handler.assert_awaited_once() + assert result == expected_result + + def test_middleware_with_multiple_foundry_tools(self): + """Test middleware initialization with multiple foundry tools.""" + foundry_tools = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "test"}, + ] + + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Should still only have one dummy tool + assert len(middleware.tools) == 1 + # But should have all foundry tools registered + assert len(middleware._foundry_tools_to_bind) == 2 + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_resolver.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_resolver.py new file mode 100644 index 000000000000..985ed4caec49 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_resolver.py @@ -0,0 +1,502 @@ +# 
--------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for ResolvedTools and FoundryLangChainToolResolver.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from langchain_core.tools import BaseTool, StructuredTool, tool +from pydantic import BaseModel + +from azure.ai.agentserver.core.tools import (FoundryConnectedTool, FoundryHostedMcpTool, FoundryToolDetails, + ResolvedFoundryTool, SchemaDefinition, SchemaProperty, SchemaType) +from azure.ai.agentserver.langgraph.tools._resolver import ( + ResolvedTools, + FoundryLangChainToolResolver, + get_registry, +) + + +@pytest.mark.unit +class TestResolvedTools: + """Tests for ResolvedTools class.""" + + def test_create_empty_resolved_tools(self): + """Test creating an empty ResolvedTools.""" + resolved = ResolvedTools(tools=[]) + + tools_list = list(resolved) + assert len(tools_list) == 0 + + def test_create_with_single_tool( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test creating ResolvedTools with a single tool.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + tools_list = list(resolved) + assert len(tools_list) == 1 + assert tools_list[0] is mock_langchain_tool + + def test_create_with_multiple_tools( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_mcp_connected_tool: FoundryConnectedTool, + ): + """Test creating ResolvedTools with multiple tools.""" + @tool + def tool1(query: str) -> str: + """Tool 1.""" + return "result1" + + @tool + def tool2(query: str) -> str: + """Tool 2.""" + return "result2" + + resolved_tool1 = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="tool1", + description="Tool 1", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + resolved_tool2 = ResolvedFoundryTool( + definition=sample_mcp_connected_tool, + details=FoundryToolDetails( + name="tool2", + description="Tool 2", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + + resolved = ResolvedTools(tools=[ + (resolved_tool1, tool1), + (resolved_tool2, tool2), + ]) + + tools_list = list(resolved) + assert len(tools_list) == 2 + + def test_get_tool_by_foundry_tool_like( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test getting tools by FoundryToolLike.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, 
mock_langchain_tool)]) + + # Get by the original foundry tool definition + tools = list(resolved.get(sample_code_interpreter_tool)) + assert len(tools) == 1 + assert tools[0] is mock_langchain_tool + + def test_get_tools_by_list_of_foundry_tools( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_mcp_connected_tool: FoundryConnectedTool, + ): + """Test getting tools by a list of FoundryToolLike.""" + @tool + def tool1(query: str) -> str: + """Tool 1.""" + return "result1" + + @tool + def tool2(query: str) -> str: + """Tool 2.""" + return "result2" + + resolved_tool1 = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="tool1", + description="Tool 1", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + resolved_tool2 = ResolvedFoundryTool( + definition=sample_mcp_connected_tool, + details=FoundryToolDetails( + name="tool2", + description="Tool 2", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + + resolved = ResolvedTools(tools=[ + (resolved_tool1, tool1), + (resolved_tool2, tool2), + ]) + + # Get by list of foundry tools + tools = list(resolved.get([sample_code_interpreter_tool, sample_mcp_connected_tool])) + assert len(tools) == 2 + + def test_get_all_tools_when_no_filter( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test getting all tools when no filter is provided.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Get all tools (no filter) + tools = list(resolved.get()) + assert len(tools) == 1 + + def test_get_returns_empty_for_unknown_tool( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_mcp_connected_tool: FoundryConnectedTool, + mock_langchain_tool: BaseTool, + ): + """Test that get returns empty when requesting unknown tool.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Get by a different foundry tool (not in resolved) + tools = list(resolved.get(sample_mcp_connected_tool)) + assert len(tools) == 0 + + def test_iteration_over_resolved_tools( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test iterating over ResolvedTools.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + 
}, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Iterate using for loop + count = 0 + for t in resolved: + assert t is mock_langchain_tool + count += 1 + assert count == 1 + + +@pytest.mark.unit +class TestFoundryLangChainToolResolver: + """Tests for FoundryLangChainToolResolver class.""" + + def test_init_with_default_name_resolver(self): + """Test initialization with default name resolver.""" + resolver = FoundryLangChainToolResolver() + + assert resolver._name_resolver is not None + + def test_init_with_custom_name_resolver(self): + """Test initialization with custom name resolver.""" + from azure.ai.agentserver.core.tools.utils import ToolNameResolver + + custom_resolver = ToolNameResolver() + resolver = FoundryLangChainToolResolver(name_resolver=custom_resolver) + + assert resolver._name_resolver is custom_resolver + + def test_create_pydantic_model_with_required_fields(self): + """Test creating a Pydantic model with required fields.""" + input_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + "limit": SchemaProperty(type=SchemaType.INTEGER, description="Max results"), + }, + required={"query"}, + ) + + model = FoundryLangChainToolResolver._create_pydantic_model("test_tool", input_schema) + + assert issubclass(model, BaseModel) + # Check that the model has the expected fields + assert "query" in model.model_fields + assert "limit" in model.model_fields + + def test_create_pydantic_model_with_no_required_fields(self): + """Test creating a Pydantic model with no required fields.""" + input_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + }, + required=set(), + ) + + model = FoundryLangChainToolResolver._create_pydantic_model("optional_tool", input_schema) + + assert issubclass(model, BaseModel) + assert "query" in model.model_fields + # Optional field should have None as default + assert model.model_fields["query"].default is None + + def test_create_pydantic_model_with_special_characters_in_name(self): + """Test creating a Pydantic model with special characters in tool name.""" + input_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ) + + model = FoundryLangChainToolResolver._create_pydantic_model("my-tool name", input_schema) + + assert issubclass(model, BaseModel) + # Name should be sanitized + assert "-Input" in model.__name__ or "Input" in model.__name__ + + def test_create_structured_tool(self): + """Test creating a StructuredTool from a resolved foundry tool.""" + resolver = FoundryLangChainToolResolver() + + foundry_tool = FoundryHostedMcpTool(name="test_tool", configuration={}) + resolved_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="search", + description="Search for documents", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + }, + required={"query"}, + ), + ), + ) + + structured_tool = resolver._create_structured_tool(resolved_tool) + + assert isinstance(structured_tool, StructuredTool) + assert structured_tool.description == "Search for documents" + assert structured_tool.coroutine is not None # Should have async function + + 
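# A short usage sketch for the StructuredTool built above (illustrative only; it + # assumes the generated coroutine forwards the validated arguments to the + # Foundry tool invoker, so the standard async Runnable entry point applies): + # + # result = await structured_tool.ainvoke({"query": "sample query"}) + +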
@pytest.mark.asyncio + async def test_resolve_from_registry(self): + """Test resolving tools from the global registry.""" + resolver = FoundryLangChainToolResolver() + + # Mock the AgentServerContext + mock_context = MagicMock() + mock_catalog = AsyncMock() + mock_context.tools.catalog.list = mock_catalog + + foundry_tool = FoundryHostedMcpTool(name="test_tool", configuration={}) + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="search", + description="Search tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Query"), + }, + required={"query"}, + ), + ), + ) + mock_catalog.return_value = [resolved_foundry_tool] + + # Add tool to registry + registry = get_registry() + registry.clear() + registry.append({"type": "code_interpreter"}) + + with patch("azure.ai.agentserver.langgraph.tools._resolver.AgentServerContext.get", return_value=mock_context): + result = await resolver.resolve_from_registry() + + assert isinstance(result, ResolvedTools) + mock_catalog.assert_called_once() + + # Clean up registry + registry.clear() + + @pytest.mark.asyncio + async def test_resolve_with_foundry_tools_list(self): + """Test resolving a list of foundry tools.""" + resolver = FoundryLangChainToolResolver() + + # Mock the AgentServerContext + mock_context = MagicMock() + mock_catalog = AsyncMock() + mock_context.tools.catalog.list = mock_catalog + + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="execute_code", + description="Execute code", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "code": SchemaProperty(type=SchemaType.STRING, description="Code to execute"), + }, + required={"code"}, + ), + ), + ) + mock_catalog.return_value = [resolved_foundry_tool] + + foundry_tools = [{"type": "code_interpreter"}] + + with patch("azure.ai.agentserver.langgraph.tools._resolver.AgentServerContext.get", return_value=mock_context): + result = await resolver.resolve(foundry_tools) + + assert isinstance(result, ResolvedTools) + tools_list = list(result) + assert len(tools_list) == 1 + assert isinstance(tools_list[0], StructuredTool) + + @pytest.mark.asyncio + async def test_resolve_empty_list(self): + """Test resolving an empty list of foundry tools.""" + resolver = FoundryLangChainToolResolver() + + # Mock the AgentServerContext + mock_context = MagicMock() + mock_catalog = AsyncMock() + mock_context.tools.catalog.list = mock_catalog + mock_catalog.return_value = [] + + with patch("azure.ai.agentserver.langgraph.tools._resolver.AgentServerContext.get", return_value=mock_context): + result = await resolver.resolve([]) + + assert isinstance(result, ResolvedTools) + tools_list = list(result) + assert len(tools_list) == 0 + + +@pytest.mark.unit +class TestGetRegistry: + """Tests for the get_registry function.""" + + def test_get_registry_returns_list(self): + """Test that get_registry returns a list.""" + registry = get_registry() + + assert isinstance(registry, list) + + def test_registry_is_singleton(self): + """Test that get_registry returns the same list instance.""" + registry1 = get_registry() + registry2 = get_registry() + + assert registry1 is registry2 + + def test_registry_can_be_modified(self): + """Test that the registry can be modified.""" + registry = get_registry() + original_length = 
len(registry) + + registry.append({"type": "test_tool"}) + + assert len(registry) == original_length + 1 + + # Clean up + registry.pop() + + def test_registry_extend(self): + """Test extending the registry with multiple tools.""" + registry = get_registry() + registry.clear() + + tools = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "test"}, + ] + registry.extend(tools) + + assert len(registry) == 2 + + # Clean up + registry.clear() diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_tool_node.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_tool_node.py new file mode 100644 index 000000000000..1c46e58785bc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_tool_node.py @@ -0,0 +1,179 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolCallWrapper and FoundryToolNodeWrappers.""" +import pytest +from typing import Any, List +from unittest.mock import AsyncMock, MagicMock + +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool +from langgraph.prebuilt.tool_node import ToolCallRequest +from langgraph.types import Command + +from azure.ai.agentserver.langgraph.tools._tool_node import ( + FoundryToolCallWrapper, + FoundryToolNodeWrappers, +) + + +@pytest.mark.unit +class TestFoundryToolCallWrapper: + """Tests for FoundryToolCallWrapper class.""" + + def test_as_wrappers_returns_typed_dict(self): + """Test that as_wrappers returns a FoundryToolNodeWrappers TypedDict.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + result = wrapper.as_wrappers() + + assert isinstance(result, dict) + assert "wrap_tool_call" in result + assert "awrap_tool_call" in result + assert callable(result["wrap_tool_call"]) + assert callable(result["awrap_tool_call"]) + + def test_call_tool_with_already_resolved_tool(self): + """Test that call_tool passes through when tool is already resolved.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + # Create request with tool already set + @tool + def existing_tool(x: str) -> str: + """Existing tool.""" + return f"Result: {x}" + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = existing_tool + mock_request.tool_call = {"name": "existing_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Result: test", tool_call_id="call_1") + mock_invocation = MagicMock(return_value=expected_result) + + result = wrapper.call_tool(mock_request, mock_invocation) + + # Should pass through original request + mock_invocation.assert_called_once_with(mock_request) + assert result == expected_result + + def test_call_tool_with_no_foundry_tools(self): + """Test that call_tool passes through when no foundry tools configured.""" + foundry_tools: List[Any] = [] + wrapper = FoundryToolCallWrapper(foundry_tools) + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "some_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Result", tool_call_id="call_1") + mock_invocation = MagicMock(return_value=expected_result) + + result = wrapper.call_tool(mock_request, mock_invocation) + + mock_invocation.assert_called_once_with(mock_request) 
+ assert result == expected_result + + @pytest.mark.asyncio + async def test_call_tool_async_with_already_resolved_tool(self): + """Test that call_tool_async passes through when tool is already resolved.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + @tool + def existing_tool(x: str) -> str: + """Existing tool.""" + return f"Result: {x}" + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = existing_tool + mock_request.tool_call = {"name": "existing_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Async Result", tool_call_id="call_1") + mock_invocation = AsyncMock(return_value=expected_result) + + result = await wrapper.call_tool_async(mock_request, mock_invocation) + + mock_invocation.assert_awaited_once_with(mock_request) + assert result == expected_result + + @pytest.mark.asyncio + async def test_call_tool_async_with_no_foundry_tools(self): + """Test that call_tool_async passes through when no foundry tools configured.""" + foundry_tools: List[Any] = [] + wrapper = FoundryToolCallWrapper(foundry_tools) + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "some_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Result", tool_call_id="call_1") + mock_invocation = AsyncMock(return_value=expected_result) + + result = await wrapper.call_tool_async(mock_request, mock_invocation) + + mock_invocation.assert_awaited_once_with(mock_request) + assert result == expected_result + + def test_call_tool_returns_command_result(self): + """Test that call_tool can return Command objects.""" + foundry_tools: List[Any] = [] + wrapper = FoundryToolCallWrapper(foundry_tools) + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "some_tool", "id": "call_1"} + + # Return a Command instead of ToolMessage + expected_result = Command(goto="next_node") + mock_invocation = MagicMock(return_value=expected_result) + + result = wrapper.call_tool(mock_request, mock_invocation) + + assert result == expected_result + assert isinstance(result, Command) + + +@pytest.mark.unit +class TestFoundryToolNodeWrappers: + """Tests for FoundryToolNodeWrappers TypedDict.""" + + def test_foundry_tool_node_wrappers_structure(self): + """Test that FoundryToolNodeWrappers has the expected structure.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + wrappers: FoundryToolNodeWrappers = wrapper.as_wrappers() + + # Should have both sync and async wrappers + assert "wrap_tool_call" in wrappers + assert "awrap_tool_call" in wrappers + + # Should be the wrapper methods + assert wrappers["wrap_tool_call"] == wrapper.call_tool + assert wrappers["awrap_tool_call"] == wrapper.call_tool_async + + def test_wrappers_can_be_unpacked_to_tool_node(self): + """Test that wrappers can be unpacked as kwargs to ToolNode.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + wrappers = wrapper.as_wrappers() + + # Should be usable as kwargs + assert len(wrappers) == 2 + + # This pattern is used: ToolNode([], **wrappers) + def mock_tool_node_init(tools, wrap_tool_call=None, awrap_tool_call=None): + return { + "tools": tools, + "wrap_tool_call": wrap_tool_call, + "awrap_tool_call": awrap_tool_call, + } + + result = mock_tool_node_init([], **wrappers) + + assert result["wrap_tool_call"] is not None + assert result["awrap_tool_call"] 
is not None + From bd8da8e7d77a9ce8957ea96be6d079f4b57189ea Mon Sep 17 00:00:00 2001 From: junanchen Date: Thu, 22 Jan 2026 21:14:45 -0800 Subject: [PATCH 28/29] optimize af code --- .../agentframework/_foundry_tools.py | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py index 78d8108ed96c..068ad960857d 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py @@ -71,7 +71,7 @@ def _to_aifunction(self, foundry_tool: "ResolvedFoundryTool") -> AIFunction: # Build field definitions for the Pydantic model field_definitions: Dict[str, Any] = {} for field_name, field_info in properties.items(): - field_type = self._json_schema_type_to_python(field_info.type or "string") + field_type = field_info.type.py_type field_description = field_info.description or "" is_required = field_name in required_fields @@ -107,24 +107,6 @@ async def tool_func(**kwargs: Any) -> Any: input_model=input_model ) - def _json_schema_type_to_python(self, json_type: str) -> type: - """Convert JSON schema type to Python type. - - :param json_type: The JSON schema type string. - :type json_type: str - :return: The corresponding Python type. - :rtype: type - """ - type_map = { - "string": str, - "number": float, - "integer": int, - "boolean": bool, - "array": list, - "object": dict, - } - return type_map.get(json_type, str) - class FoundryToolsChatMiddleware(ChatMiddleware): """Chat middleware to inject Foundry tools into ChatOptions on each call.""" From 2f2ca63bec273ded546ddc1797dae5e26e51bf7c Mon Sep 17 00:00:00 2001 From: junanchen Date: Thu, 22 Jan 2026 21:55:02 -0800 Subject: [PATCH 29/29] add more uts. 
support resource id format of project connection id --- .../agentframework/_foundry_tools.py | 4 +- .../agentserver/core/tools/runtime/_facade.py | 45 ++- .../unit_tests/core/tools/client/__init__.py | 5 + .../tools/client/operations}/__init__.py | 0 .../test_foundry_connected_tools.py | 2 +- .../test_foundry_hosted_mcp_tools.py | 2 +- .../core/tools/{ => client}/test_client.py | 30 +- .../tools/{ => client}/test_configuration.py | 0 .../tests/unit_tests/core/tools/conftest.py | 3 +- .../unit_tests/core/tools/runtime/__init__.py | 4 + .../unit_tests/core/tools/runtime/conftest.py | 39 ++ .../core/tools/runtime/test_catalog.py | 349 ++++++++++++++++++ .../core/tools/runtime/test_facade.py | 180 +++++++++ .../core/tools/runtime/test_invoker.py | 198 ++++++++++ .../core/tools/runtime/test_resolver.py | 202 ++++++++++ .../core/tools/runtime/test_runtime.py | 283 ++++++++++++++ .../core/tools/runtime/test_starlette.py | 261 +++++++++++++ .../core/tools/runtime/test_user.py | 210 +++++++++++ .../tools/{operations => utils}/__init__.py | 2 +- .../unit_tests/core/tools/utils/conftest.py | 56 +++ .../core/tools/utils/test_name_resolver.py | 260 +++++++++++++ .../agentserver/langgraph/tools/_builder.py | 13 +- 22 files changed, 2120 insertions(+), 28 deletions(-) create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/__init__.py rename sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/{ => core/tools/client/operations}/__init__.py (100%) rename sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/{ => client}/operations/test_foundry_connected_tools.py (99%) rename sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/{ => client}/operations/test_foundry_hosted_mcp_tools.py (99%) rename sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/{ => client}/test_client.py (93%) rename sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/{ => client}/test_configuration.py (100%) create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/conftest.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_catalog.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_facade.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_invoker.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_resolver.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_runtime.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_starlette.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_user.py rename sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/{operations => utils}/__init__.py (84%) create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/conftest.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/test_name_resolver.py diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py 
b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py
index 068ad960857d..64120308a872 100644
--- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py
+++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py
@@ -13,7 +13,7 @@
 from azure.ai.agentserver.core import AgentServerContext
 from azure.ai.agentserver.core.logger import get_logger
-from azure.ai.agentserver.core.tools import FoundryToolLike, ResolvedFoundryTool
+from azure.ai.agentserver.core.tools import FoundryToolLike, ResolvedFoundryTool, ensure_foundry_tool
 
 logger = get_logger()
 
@@ -45,7 +45,7 @@ def __init__(
         self,
         tools: Sequence[FoundryToolLike],
     ) -> None:
-        self._allowed_tools: List[FoundryToolLike] = list(tools)
+        self._allowed_tools: List[FoundryToolLike] = [ensure_foundry_tool(tool) for tool in tools]
 
     async def list_tools(self) -> List[AIFunction]:
         server_context = AgentServerContext.get()
diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py
index f12d3f0db7b5..bfc4a08d9a63 100644
--- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py
+++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py
@@ -1,6 +1,7 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
+import re
 from typing import Any, Dict, Union
 
 from .. import FoundryConnectedTool, FoundryHostedMcpTool
@@ -45,6 +46,48 @@ def ensure_foundry_tool(tool: FoundryToolLike) -> FoundryTool:
         if not isinstance(project_connection_id, str) or not project_connection_id:
             raise InvalidToolFacadeError(f"project_connection_id is required for tool protocol {protocol}.")
-        return FoundryConnectedTool(protocol=protocol, project_connection_id=project_connection_id)
+        # Parse the connection identifier to extract the connection name
+        connection_name = _parse_connection_id(project_connection_id)
+        return FoundryConnectedTool(protocol=protocol, project_connection_id=connection_name)
     except ValueError:
         return FoundryHostedMcpTool(name=tool_type, configuration=tool)
+
+
+# Pattern for Azure resource ID format:
+# /subscriptions/<subscription-id>/resourceGroups/<resource-group>/providers/Microsoft.CognitiveServices/accounts/<account-name>/projects/<project-name>/connections/<connection-name>
+_RESOURCE_ID_PATTERN = re.compile(
+    r"^/subscriptions/[^/]+/resourceGroups/[^/]+/providers/Microsoft\.CognitiveServices/"
+    r"accounts/[^/]+/projects/[^/]+/connections/(?P<name>[^/]+)$",
+    re.IGNORECASE,
+)
+
+
+def _parse_connection_id(connection_id: str) -> str:
+    """Parse the connection identifier and extract the connection name.
+
+    Supports two formats:
+    1. Simple name: "my-connection-name"
+    2. Resource ID: "/subscriptions/<subscription-id>/resourceGroups/<resource-group>/providers/Microsoft.CognitiveServices/accounts/<account-name>/projects/<project-name>/connections/<connection-name>"
+
+    :param connection_id: The connection identifier, either a simple name or a full resource ID.
+    :type connection_id: str
+    :return: The connection name extracted from the identifier.
+    :rtype: str
+    :raises InvalidToolFacadeError: If the connection_id format is invalid.
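+
+    A quick sketch of both accepted inputs (the IDs below are illustrative
+    placeholders)::
+
+        _parse_connection_id("my-conn")  # -> "my-conn"
+        _parse_connection_id(
+            "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-rg/"
+            "providers/Microsoft.CognitiveServices/accounts/my-account/projects/my-project/"
+            "connections/my-conn"
+        )  # -> "my-conn"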
+ """ + if not connection_id: + raise InvalidToolFacadeError("Connection identifier cannot be empty.") + + # Check if it's a resource ID format (starts with /) + if connection_id.startswith("/"): + match = _RESOURCE_ID_PATTERN.match(connection_id) + if not match: + raise InvalidToolFacadeError( + f"Invalid resource ID format for connection: '{connection_id}'. " + "Expected format: /subscriptions//resourceGroups//providers/" + "Microsoft.CognitiveServices/accounts//projects//connections/" + ) + return match.group("name") + + # Otherwise, treat it as a simple connection name + return connection_id diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/__init__.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py rename to sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/__init__.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_connected_tools.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_connected_tools.py similarity index 99% rename from sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_connected_tools.py rename to sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_connected_tools.py index 7c453ba2fa2c..e7273f37a7e7 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_connected_tools.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_connected_tools.py @@ -14,7 +14,7 @@ ) from azure.ai.agentserver.core.tools._exceptions import OAuthConsentRequiredError, ToolInvocationError -from ..conftest import create_mock_http_response +from ...conftest import create_mock_http_response class TestFoundryConnectedToolsOperationsListTools: diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_hosted_mcp_tools.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_hosted_mcp_tools.py similarity index 99% rename from sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_hosted_mcp_tools.py rename to sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_hosted_mcp_tools.py index 9897cbf168ed..473b27cc8768 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/test_foundry_hosted_mcp_tools.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_hosted_mcp_tools.py @@ -18,7 +18,7 @@ ) from 
azure.ai.agentserver.core.tools._exceptions import ToolInvocationError -from ..conftest import create_mock_http_response +from ...conftest import create_mock_http_response class TestFoundryMcpToolsOperationsListTools: diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_client.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_client.py similarity index 93% rename from sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_client.py rename to sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_client.py index c99de80a87f9..de60f545e089 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_client.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_client.py @@ -13,7 +13,7 @@ ) from azure.ai.agentserver.core.tools._exceptions import ToolInvocationError -from .conftest import create_mock_http_response +from ..conftest import create_mock_http_response class TestFoundryToolClientInit: @@ -22,7 +22,7 @@ class TestFoundryToolClientInit: @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") def test_init_with_valid_endpoint_and_credential(self, mock_pipeline_client_class, mock_credential): """Test client can be initialized with valid endpoint and credential.""" - endpoint = "https://test.api.azureml.ms" + endpoint = "https://fake-project-endpoint.site" client = FoundryToolClient(endpoint, mock_credential) @@ -43,7 +43,7 @@ async def test_list_tools_empty_collection_returns_empty_list( mock_credential ): """Test list_tools returns empty list when given empty collection.""" - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) result = await client.list_tools([], agent_name="test-agent") @@ -79,7 +79,7 @@ async def test_list_tools_with_single_mcp_tool_returns_resolved_tools( mock_client_instance.send_request.return_value = mock_response mock_client_instance.post.return_value = MagicMock() - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) result = await client.list_tools([sample_hosted_mcp_tool], agent_name="test-agent") assert len(result) == 1 @@ -122,7 +122,7 @@ async def test_list_tools_with_single_connected_tool_returns_resolved_tools( mock_client_instance.send_request.return_value = mock_response mock_client_instance.post.return_value = MagicMock() - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) result = await client.list_tools( [sample_connected_tool], agent_name="test-agent", @@ -187,7 +187,7 @@ async def test_list_tools_with_mixed_tool_types_returns_all_resolved( ] mock_client_instance.post.return_value = MagicMock() - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) result = await client.list_tools( [sample_hosted_mcp_tool, sample_connected_tool], agent_name="test-agent", @@ -234,7 +234,7 @@ async def test_list_tools_filters_unlisted_mcp_tools( mock_client_instance.send_request.return_value = mock_response mock_client_instance.post.return_value = MagicMock() - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = 
FoundryToolClient("https://fake-project-endpoint.site", mock_credential) result = await client.list_tools([sample_hosted_mcp_tool], agent_name="test-agent") # Should only return the requested tool @@ -274,7 +274,7 @@ async def test_list_tools_details_returns_mapping_structure( mock_client_instance.send_request.return_value = mock_response mock_client_instance.post.return_value = MagicMock() - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) result = await client.list_tools_details([sample_hosted_mcp_tool], agent_name="test-agent") assert isinstance(result, dict) @@ -312,7 +312,7 @@ async def test_list_tools_details_groups_multiple_tools_by_definition( mock_client_instance.send_request.return_value = mock_response mock_client_instance.post.return_value = MagicMock() - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) result = await client.list_tools_details([sample_hosted_mcp_tool], agent_name="test-agent") # All tools should be grouped under the same definition ID @@ -339,7 +339,7 @@ async def test_invoke_mcp_tool_returns_result( mock_client_instance.send_request.return_value = mock_response mock_client_instance.post.return_value = MagicMock() - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) result = await client.invoke_tool( sample_resolved_mcp_tool, arguments={"input": "test"}, @@ -367,7 +367,7 @@ async def test_invoke_connected_tool_returns_result( mock_client_instance.send_request.return_value = mock_response mock_client_instance.post.return_value = MagicMock() - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) result = await client.invoke_tool( sample_resolved_connected_tool, arguments={"input": "test"}, @@ -393,7 +393,7 @@ async def test_invoke_tool_with_complex_arguments( mock_client_instance.send_request.return_value = mock_response mock_client_instance.post.return_value = MagicMock() - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) complex_args = { "string_param": "value", "number_param": 42, @@ -431,7 +431,7 @@ async def test_invoke_tool_with_unsupported_source_raises_error( mock_tool.source = "unsupported_source" mock_tool.details = sample_tool_details - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) with pytest.raises(ToolInvocationError) as exc_info: await client.invoke_tool( @@ -457,7 +457,7 @@ async def test_close_closes_underlying_client( mock_client_instance = AsyncMock() mock_pipeline_client_class.return_value = mock_client_instance - client = FoundryToolClient("https://test.api.azureml.ms", mock_credential) + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) await client.close() mock_client_instance.close.assert_called_once() @@ -477,7 +477,7 @@ async def test_async_context_manager_enters_and_exits( mock_client_instance = AsyncMock() mock_pipeline_client_class.return_value = mock_client_instance - async with FoundryToolClient("https://test.api.azureml.ms", mock_credential) as client: + 
async with FoundryToolClient("https://fake-project-endpoint.site", mock_credential) as client: assert client is not None mock_client_instance.__aenter__.assert_called_once() diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_configuration.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/test_configuration.py rename to sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_configuration.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py index bf94b48b9699..8849ce8aafbf 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py @@ -1,7 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -"""Shared fixtures for unit tests.""" +"""Shared fixtures for tools unit tests.""" import json from typing import Any, Dict, Optional from unittest.mock import AsyncMock, MagicMock @@ -125,4 +125,3 @@ def create_mock_http_response( response.__aexit__ = AsyncMock(return_value=None) return response - diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/__init__.py new file mode 100644 index 000000000000..964fac9d8a55 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Runtime unit tests package.""" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/conftest.py new file mode 100644 index 000000000000..52a371bdc958 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/conftest.py @@ -0,0 +1,39 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for runtime unit tests. + +Common fixtures are inherited from the parent conftest.py automatically by pytest. 
+""" +from unittest.mock import AsyncMock + +import pytest + + +@pytest.fixture +def mock_foundry_tool_client(): + """Create a mock FoundryToolClient.""" + client = AsyncMock() + client.list_tools = AsyncMock(return_value=[]) + client.list_tools_details = AsyncMock(return_value={}) + client.invoke_tool = AsyncMock(return_value={"result": "success"}) + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=None) + return client + + +@pytest.fixture +def mock_user_provider(sample_user_info): + """Create a mock UserProvider.""" + provider = AsyncMock() + provider.get_user = AsyncMock(return_value=sample_user_info) + return provider + + +@pytest.fixture +def mock_user_provider_none(): + """Create a mock UserProvider that returns None.""" + provider = AsyncMock() + provider.get_user = AsyncMock(return_value=None) + return provider + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_catalog.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_catalog.py new file mode 100644 index 000000000000..45b03f0530a2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_catalog.py @@ -0,0 +1,349 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _catalog.py - testing public methods of DefaultFoundryToolCatalog.""" +import asyncio +import pytest +from unittest.mock import AsyncMock + +from azure.ai.agentserver.core.tools.runtime._catalog import ( + DefaultFoundryToolCatalog, +) +from azure.ai.agentserver.core.tools.client._models import ( + FoundryToolDetails, + ResolvedFoundryTool, + UserInfo, +) + + +class TestFoundryToolCatalogGet: + """Tests for FoundryToolCatalog.get method.""" + + @pytest.mark.asyncio + async def test_get_returns_resolved_tool_when_found( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details, + sample_user_info + ): + """Test get returns a resolved tool when the tool is found.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.get(sample_hosted_mcp_tool) + + assert result is not None + assert isinstance(result, ResolvedFoundryTool) + assert result.details == sample_tool_details + + @pytest.mark.asyncio + async def test_get_returns_none_when_not_found( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool + ): + """Test get returns None when the tool is not found.""" + mock_foundry_tool_client.list_tools_details = AsyncMock(return_value={}) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.get(sample_hosted_mcp_tool) + + assert result is None + + +class TestDefaultFoundryToolCatalogList: + """Tests for DefaultFoundryToolCatalog.list method.""" + + @pytest.mark.asyncio + async def test_list_returns_empty_list_when_no_tools( + self, + mock_foundry_tool_client, + mock_user_provider + ): + """Test list returns empty list when no tools are provided.""" + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + 
user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([]) + + assert result == [] + + @pytest.mark.asyncio + async def test_list_returns_resolved_tools( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details + ): + """Test list returns resolved tools.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([sample_hosted_mcp_tool]) + + assert len(result) == 1 + assert isinstance(result[0], ResolvedFoundryTool) + assert result[0].definition == sample_hosted_mcp_tool + assert result[0].details == sample_tool_details + + @pytest.mark.asyncio + async def test_list_multiple_tools_with_multiple_details( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_connected_tool, + sample_schema_definition + ): + """Test list returns all resolved tools when tools have multiple details.""" + details1 = FoundryToolDetails( + name="tool1", + description="First tool", + input_schema=sample_schema_definition + ) + details2 = FoundryToolDetails( + name="tool2", + description="Second tool", + input_schema=sample_schema_definition + ) + + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={ + sample_hosted_mcp_tool.id: [details1], + sample_connected_tool.id: [details2] + } + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([sample_hosted_mcp_tool, sample_connected_tool]) + + assert len(result) == 2 + names = {r.details.name for r in result} + assert names == {"tool1", "tool2"} + + @pytest.mark.asyncio + async def test_list_caches_results_for_hosted_mcp_tools( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details + ): + """Test that list caches results for hosted MCP tools.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + # First call + result1 = await catalog.list([sample_hosted_mcp_tool]) + # Second call should use cache + result2 = await catalog.list([sample_hosted_mcp_tool]) + + # Client should only be called once + assert mock_foundry_tool_client.list_tools_details.call_count == 1 + assert len(result1) == len(result2) == 1 + + @pytest.mark.asyncio + async def test_list_with_facade_dict( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_tool_details + ): + """Test list works with facade dictionaries.""" + facade = {"type": "custom_tool", "config": "value"} + expected_id = "hosted_mcp:custom_tool" + + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={expected_id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([facade]) + + assert len(result) == 1 + assert result[0].details == sample_tool_details + + @pytest.mark.asyncio + async def test_list_returns_multiple_details_per_tool( + self, + mock_foundry_tool_client, + 
mock_user_provider, + sample_hosted_mcp_tool, + sample_schema_definition + ): + """Test list returns multiple resolved tools when a tool has multiple details.""" + details1 = FoundryToolDetails( + name="function1", + description="First function", + input_schema=sample_schema_definition + ) + details2 = FoundryToolDetails( + name="function2", + description="Second function", + input_schema=sample_schema_definition + ) + + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [details1, details2]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([sample_hosted_mcp_tool]) + + assert len(result) == 2 + names = {r.details.name for r in result} + assert names == {"function1", "function2"} + + @pytest.mark.asyncio + async def test_list_handles_exception_from_client( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool + ): + """Test list propagates exception from client and clears cache.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + side_effect=RuntimeError("Network error") + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(RuntimeError, match="Network error"): + await catalog.list([sample_hosted_mcp_tool]) + + @pytest.mark.asyncio + async def test_list_connected_tool_cache_key_includes_user( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_connected_tool, + sample_tool_details, + sample_user_info + ): + """Test that connected tool cache key includes user info.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_connected_tool.id: [sample_tool_details]} + ) + + # Create a new user provider returning a different user + other_user = UserInfo(object_id="other-oid", tenant_id="other-tid") + mock_user_provider2 = AsyncMock() + mock_user_provider2.get_user = AsyncMock(return_value=other_user) + + catalog1 = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + catalog2 = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider2, + agent_name="test-agent" + ) + + # Both catalogs should be able to list tools + result1 = await catalog1.list([sample_connected_tool]) + result2 = await catalog2.list([sample_connected_tool]) + + assert len(result1) == 1 + assert len(result2) == 1 + + +class TestCachedFoundryToolCatalogConcurrency: + """Tests for CachedFoundryToolCatalog concurrency handling.""" + + @pytest.mark.asyncio + async def test_concurrent_requests_share_single_fetch( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details + ): + """Test that concurrent requests for the same tool share a single fetch.""" + call_count = 0 + fetch_event = asyncio.Event() + + async def slow_fetch(*args, **kwargs): + nonlocal call_count + call_count += 1 + await fetch_event.wait() + return {sample_hosted_mcp_tool.id: [sample_tool_details]} + + mock_foundry_tool_client.list_tools_details = slow_fetch + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + # Start two concurrent requests + task1 = asyncio.create_task(catalog.list([sample_hosted_mcp_tool])) + task2 = 
asyncio.create_task(catalog.list([sample_hosted_mcp_tool])) + + # Allow tasks to start + await asyncio.sleep(0.01) + + # Release the fetch + fetch_event.set() + + results = await asyncio.gather(task1, task2) + + # Both should get results, but fetch should only be called once + assert len(results[0]) == 1 + assert len(results[1]) == 1 + assert call_count == 1 diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_facade.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_facade.py new file mode 100644 index 000000000000..c5377dc339a4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_facade.py @@ -0,0 +1,180 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _facade.py - testing public function ensure_foundry_tool.""" +import pytest + +from azure.ai.agentserver.core.tools.runtime._facade import ensure_foundry_tool +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolProtocol, + FoundryToolSource, +) +from azure.ai.agentserver.core.tools._exceptions import InvalidToolFacadeError + + +class TestEnsureFoundryTool: + """Tests for ensure_foundry_tool public function.""" + + def test_returns_same_instance_when_given_foundry_tool(self, sample_hosted_mcp_tool): + """Test that passing a FoundryTool returns the same instance.""" + result = ensure_foundry_tool(sample_hosted_mcp_tool) + + assert result is sample_hosted_mcp_tool + + def test_returns_same_instance_for_connected_tool(self, sample_connected_tool): + """Test that passing a FoundryConnectedTool returns the same instance.""" + result = ensure_foundry_tool(sample_connected_tool) + + assert result is sample_connected_tool + + def test_converts_facade_with_mcp_protocol_to_connected_tool(self): + """Test that a facade with 'mcp' protocol is converted to FoundryConnectedTool.""" + facade = { + "type": "mcp", + "project_connection_id": "my-connection" + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.protocol == FoundryToolProtocol.MCP + assert result.project_connection_id == "my-connection" + assert result.source == FoundryToolSource.CONNECTED + + def test_converts_facade_with_a2a_protocol_to_connected_tool(self): + """Test that a facade with 'a2a' protocol is converted to FoundryConnectedTool.""" + facade = { + "type": "a2a", + "project_connection_id": "my-a2a-connection" + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.protocol == FoundryToolProtocol.A2A + assert result.project_connection_id == "my-a2a-connection" + + def test_converts_facade_with_unknown_type_to_hosted_mcp_tool(self): + """Test that a facade with unknown type is converted to FoundryHostedMcpTool.""" + facade = { + "type": "my_custom_tool", + "some_config": "value123", + "another_config": True + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryHostedMcpTool) + assert result.name == "my_custom_tool" + assert result.configuration == {"some_config": "value123", "another_config": True} + assert result.source == FoundryToolSource.HOSTED_MCP + + def test_raises_error_when_type_is_missing(self): + """Test that InvalidToolFacadeError is raised when 'type' is missing.""" + facade = 
{"project_connection_id": "my-connection"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "type" in str(exc_info.value).lower() + + def test_raises_error_when_type_is_empty_string(self): + """Test that InvalidToolFacadeError is raised when 'type' is empty string.""" + facade = {"type": "", "project_connection_id": "my-connection"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "type" in str(exc_info.value).lower() + + def test_raises_error_when_type_is_not_string(self): + """Test that InvalidToolFacadeError is raised when 'type' is not a string.""" + facade = {"type": 123, "project_connection_id": "my-connection"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "type" in str(exc_info.value).lower() + + def test_raises_error_when_mcp_protocol_missing_connection_id(self): + """Test that InvalidToolFacadeError is raised when mcp protocol is missing project_connection_id.""" + facade = {"type": "mcp"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "project_connection_id" in str(exc_info.value) + + def test_raises_error_when_a2a_protocol_has_empty_connection_id(self): + """Test that InvalidToolFacadeError is raised when a2a protocol has empty project_connection_id.""" + facade = {"type": "a2a", "project_connection_id": ""} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "project_connection_id" in str(exc_info.value) + + def test_parses_resource_id_format_connection_id(self): + """Test that resource ID format project_connection_id is parsed correctly.""" + resource_id = ( + "/subscriptions/sub-123/resourceGroups/rg-test/providers/" + "Microsoft.CognitiveServices/accounts/acc-test/projects/proj-test/connections/my-conn-name" + ) + facade = { + "type": "mcp", + "project_connection_id": resource_id + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.project_connection_id == "my-conn-name" + + def test_raises_error_for_invalid_resource_id_format(self): + """Test that InvalidToolFacadeError is raised for invalid resource ID format.""" + invalid_resource_id = "/subscriptions/sub-123/invalid/path" + facade = { + "type": "mcp", + "project_connection_id": invalid_resource_id + } + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "Invalid resource ID format" in str(exc_info.value) + + def test_uses_simple_connection_name_as_is(self): + """Test that simple connection name is used as-is without parsing.""" + facade = { + "type": "mcp", + "project_connection_id": "simple-connection-name" + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.project_connection_id == "simple-connection-name" + + def test_original_facade_not_modified(self): + """Test that the original facade dictionary is not modified.""" + facade = { + "type": "my_tool", + "config_key": "config_value" + } + original_facade = facade.copy() + + ensure_foundry_tool(facade) + + assert facade == original_facade + + def test_hosted_mcp_tool_with_no_extra_configuration(self): + """Test that hosted MCP tool works with no extra configuration.""" + facade = {"type": "simple_tool"} + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryHostedMcpTool) + assert result.name == "simple_tool" + assert 
result.configuration == {} diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_invoker.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_invoker.py new file mode 100644 index 000000000000..b2a222c09d6e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_invoker.py @@ -0,0 +1,198 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _invoker.py - testing public methods of DefaultFoundryToolInvoker.""" +import pytest +from unittest.mock import AsyncMock + +from azure.ai.agentserver.core.tools.runtime._invoker import DefaultFoundryToolInvoker + + +class TestDefaultFoundryToolInvokerResolvedTool: + """Tests for DefaultFoundryToolInvoker.resolved_tool property.""" + + def test_resolved_tool_returns_tool_passed_at_init( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test resolved_tool property returns the tool passed during initialization.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + assert invoker.resolved_tool is sample_resolved_mcp_tool + + def test_resolved_tool_returns_connected_tool( + self, + sample_resolved_connected_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test resolved_tool property returns connected tool.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_connected_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + assert invoker.resolved_tool is sample_resolved_connected_tool + + +class TestDefaultFoundryToolInvokerInvoke: + """Tests for DefaultFoundryToolInvoker.invoke method.""" + + @pytest.mark.asyncio + async def test_invoke_calls_client_with_correct_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke calls client.invoke_tool with correct arguments.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + arguments = {"input": "test value", "count": 5} + + await invoker.invoke(arguments) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + arguments, + "test-agent", + sample_user_info + ) + + @pytest.mark.asyncio + async def test_invoke_returns_result_from_client( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test invoke returns the result from client.invoke_tool.""" + expected_result = {"output": "test result", "status": "completed"} + mock_foundry_tool_client.invoke_tool = AsyncMock(return_value=expected_result) + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await invoker.invoke({"input": "test"}) + + assert result == expected_result + + @pytest.mark.asyncio + async def test_invoke_with_empty_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke works with empty arguments 
dictionary.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + await invoker.invoke({}) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + {}, + "test-agent", + sample_user_info + ) + + @pytest.mark.asyncio + async def test_invoke_with_none_user( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider_none + ): + """Test invoke works when user provider returns None.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider_none, + agent_name="test-agent" + ) + + await invoker.invoke({"input": "test"}) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + {"input": "test"}, + "test-agent", + None + ) + + @pytest.mark.asyncio + async def test_invoke_propagates_client_exception( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test invoke propagates exceptions from client.invoke_tool.""" + mock_foundry_tool_client.invoke_tool = AsyncMock( + side_effect=RuntimeError("Client error") + ) + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(RuntimeError, match="Client error"): + await invoker.invoke({"input": "test"}) + + @pytest.mark.asyncio + async def test_invoke_with_complex_nested_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke with complex nested argument structure.""" + complex_args = { + "nested": {"key1": "value1", "key2": 123}, + "list": [1, 2, 3], + "mixed": [{"a": 1}, {"b": 2}] + } + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + await invoker.invoke(complex_args) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + complex_args, + "test-agent", + sample_user_info + ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_resolver.py new file mode 100644 index 000000000000..7bdaa8f957a9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_resolver.py @@ -0,0 +1,202 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for _resolver.py - testing public methods of DefaultFoundryToolInvocationResolver.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.runtime._resolver import DefaultFoundryToolInvocationResolver +from azure.ai.agentserver.core.tools.runtime._invoker import DefaultFoundryToolInvoker +from azure.ai.agentserver.core.tools._exceptions import UnableToResolveToolInvocationError +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, +) + + +class TestDefaultFoundryToolInvocationResolverResolve: + """Tests for DefaultFoundryToolInvocationResolver.resolve method.""" + + @pytest.fixture + def mock_catalog(self, sample_resolved_mcp_tool): + """Create a mock FoundryToolCatalog.""" + catalog = AsyncMock() + catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + catalog.list = AsyncMock(return_value=[sample_resolved_mcp_tool]) + return catalog + + @pytest.mark.asyncio + async def test_resolve_with_resolved_tool_returns_invoker_directly( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve returns invoker directly when given ResolvedFoundryTool.""" + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_resolved_mcp_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + assert invoker.resolved_tool is sample_resolved_mcp_tool + # Catalog should not be called when ResolvedFoundryTool is passed + mock_catalog.get.assert_not_called() + + @pytest.mark.asyncio + async def test_resolve_with_foundry_tool_uses_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_resolved_mcp_tool + ): + """Test resolve uses catalog to resolve FoundryTool.""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_hosted_mcp_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once_with(sample_hosted_mcp_tool) + + @pytest.mark.asyncio + async def test_resolve_with_facade_dict_uses_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_connected_tool + ): + """Test resolve converts facade dict and uses catalog.""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_connected_tool) + facade = { + "type": "mcp", + "project_connection_id": "test-connection" + } + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(facade) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once() + # Verify the facade was converted to FoundryConnectedTool + call_arg = mock_catalog.get.call_args[0][0] + assert isinstance(call_arg, FoundryConnectedTool) + + @pytest.mark.asyncio + async def test_resolve_raises_error_when_tool_not_found_in_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + 
sample_hosted_mcp_tool + ): + """Test resolve raises UnableToResolveToolInvocationError when catalog returns None.""" + mock_catalog.get = AsyncMock(return_value=None) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(UnableToResolveToolInvocationError) as exc_info: + await resolver.resolve(sample_hosted_mcp_tool) + + assert exc_info.value.tool is sample_hosted_mcp_tool + assert "Unable to resolve tool" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_resolve_with_hosted_mcp_facade( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve with hosted MCP facade (unknown type becomes FoundryHostedMcpTool).""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + facade = { + "type": "custom_mcp_tool", + "config_key": "config_value" + } + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(facade) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + # Verify the facade was converted to FoundryHostedMcpTool + call_arg = mock_catalog.get.call_args[0][0] + assert isinstance(call_arg, FoundryHostedMcpTool) + assert call_arg.name == "custom_mcp_tool" + + @pytest.mark.asyncio + async def test_resolve_returns_invoker_with_correct_agent_name( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve creates invoker with the correct agent name.""" + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="custom-agent-name" + ) + + invoker = await resolver.resolve(sample_resolved_mcp_tool) + + # Verify invoker was created with correct agent name by checking internal state + assert invoker._agent_name == "custom-agent-name" + + @pytest.mark.asyncio + async def test_resolve_with_connected_tool_directly( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_connected_tool, + sample_resolved_connected_tool + ): + """Test resolve with FoundryConnectedTool directly.""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_connected_tool) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_connected_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once_with(sample_connected_tool) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_runtime.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_runtime.py new file mode 100644 index 000000000000..e42fc29a76cd --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_runtime.py @@ -0,0 +1,283 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for _runtime.py - testing public methods of DefaultFoundryToolRuntime.""" +import os +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from azure.ai.agentserver.core.tools.runtime._runtime import DefaultFoundryToolRuntime +from azure.ai.agentserver.core.tools.runtime._catalog import DefaultFoundryToolCatalog +from azure.ai.agentserver.core.tools.runtime._resolver import DefaultFoundryToolInvocationResolver +from azure.ai.agentserver.core.tools.runtime._user import ContextVarUserProvider + + +class TestDefaultFoundryToolRuntimeInit: + """Tests for DefaultFoundryToolRuntime initialization.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_creates_client_with_endpoint_and_credential( + self, + mock_client_class, + mock_credential + ): + """Test initialization creates client with correct endpoint and credential.""" + endpoint = "https://test-project.azure.com" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint=endpoint, + credential=mock_credential + ) + + mock_client_class.assert_called_once_with( + endpoint=endpoint, + credential=mock_credential + ) + assert runtime is not None + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_default_user_provider_when_none_provided( + self, + mock_client_class, + mock_credential + ): + """Test initialization uses ContextVarUserProvider when user_provider is None.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime._user_provider, ContextVarUserProvider) + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_custom_user_provider( + self, + mock_client_class, + mock_credential, + mock_user_provider + ): + """Test initialization uses custom user provider when provided.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential, + user_provider=mock_user_provider + ) + + assert runtime._user_provider is mock_user_provider + + @patch.dict(os.environ, {"AGENT_NAME": "custom-agent"}) + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_reads_agent_name_from_environment( + self, + mock_client_class, + mock_credential + ): + """Test initialization reads agent name from environment variable.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert runtime._agent_name == "custom-agent" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_default_agent_name_when_env_not_set( + self, + mock_client_class, + mock_credential + ): + """Test initialization uses default agent name when env var is not set.""" + mock_client_class.return_value = MagicMock() + + # Ensure AGENT_NAME is not set + env_copy = os.environ.copy() + if "AGENT_NAME" in env_copy: + del env_copy["AGENT_NAME"] + + with patch.dict(os.environ, env_copy, clear=True): + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert runtime._agent_name == "$default" + + +class 
TestDefaultFoundryToolRuntimeCatalog: + """Tests for DefaultFoundryToolRuntime.catalog property.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_catalog_returns_default_catalog( + self, + mock_client_class, + mock_credential + ): + """Test catalog property returns DefaultFoundryToolCatalog.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime.catalog, DefaultFoundryToolCatalog) + + +class TestDefaultFoundryToolRuntimeInvocation: + """Tests for DefaultFoundryToolRuntime.invocation property.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_invocation_returns_default_resolver( + self, + mock_client_class, + mock_credential + ): + """Test invocation property returns DefaultFoundryToolInvocationResolver.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime.invocation, DefaultFoundryToolInvocationResolver) + + +class TestDefaultFoundryToolRuntimeInvoke: + """Tests for DefaultFoundryToolRuntime.invoke method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_invoke_resolves_and_invokes_tool( + self, + mock_client_class, + mock_credential, + sample_resolved_mcp_tool + ): + """Test invoke resolves the tool and calls the invoker.""" + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + # Mock the invocation resolver + mock_invoker = AsyncMock() + mock_invoker.invoke = AsyncMock(return_value={"result": "success"}) + runtime._invocation.resolve = AsyncMock(return_value=mock_invoker) + + result = await runtime.invoke(sample_resolved_mcp_tool, {"input": "test"}) + + assert result == {"result": "success"} + runtime._invocation.resolve.assert_called_once_with(sample_resolved_mcp_tool) + mock_invoker.invoke.assert_called_once_with({"input": "test"}) + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_invoke_with_facade_dict( + self, + mock_client_class, + mock_credential + ): + """Test invoke works with facade dictionary.""" + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + facade = {"type": "custom_tool", "config": "value"} + + # Mock the invocation resolver + mock_invoker = AsyncMock() + mock_invoker.invoke = AsyncMock(return_value={"output": "done"}) + runtime._invocation.resolve = AsyncMock(return_value=mock_invoker) + + result = await runtime.invoke(facade, {"param": "value"}) + + assert result == {"output": "done"} + runtime._invocation.resolve.assert_called_once_with(facade) + + +class TestDefaultFoundryToolRuntimeContextManager: + """Tests for DefaultFoundryToolRuntime async context manager.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aenter_returns_runtime_and_enters_client( + self, + mock_client_class, + mock_credential + ): + """Test __aenter__ enters client and returns 
runtime.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + async with runtime as r: + assert r is runtime + mock_client_instance.__aenter__.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aexit_exits_client( + self, + mock_client_class, + mock_credential + ): + """Test __aexit__ exits client properly.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + async with runtime: + pass + + mock_client_instance.__aexit__.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aexit_called_on_exception( + self, + mock_client_class, + mock_credential + ): + """Test __aexit__ is called even when exception occurs.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + with pytest.raises(ValueError): + async with runtime: + raise ValueError("Test error") + + mock_client_instance.__aexit__.assert_called_once() diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_starlette.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_starlette.py new file mode 100644 index 000000000000..d1d72004d011 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_starlette.py @@ -0,0 +1,261 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for _starlette.py - testing public methods of UserInfoContextMiddleware.""" +import pytest +from contextvars import ContextVar +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import UserInfo + + +class TestUserInfoContextMiddlewareInstall: + """Tests for UserInfoContextMiddleware.install class method.""" + + def test_install_adds_middleware_to_starlette_app(self): + """Test install adds middleware to Starlette application.""" + # Import here to avoid requiring starlette when not needed + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + + UserInfoContextMiddleware.install(mock_app) + + mock_app.add_middleware.assert_called_once() + call_args = mock_app.add_middleware.call_args + assert call_args[0][0] == UserInfoContextMiddleware + + def test_install_uses_default_context_when_none_provided(self): + """Test install uses default user context when none is provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + from azure.ai.agentserver.core.tools.runtime._user import ContextVarUserProvider + + mock_app = MagicMock() + + UserInfoContextMiddleware.install(mock_app) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_info_var"] is ContextVarUserProvider.default_user_info_context + + def test_install_uses_custom_context(self): + """Test install uses custom user context when provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + custom_context = ContextVar("custom_context") + + UserInfoContextMiddleware.install(mock_app, user_context=custom_context) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_info_var"] is custom_context + + def test_install_uses_custom_resolver(self): + """Test install uses custom user resolver when provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + + async def custom_resolver(request): + return UserInfo(object_id="custom-oid", tenant_id="custom-tid") + + UserInfoContextMiddleware.install(mock_app, user_resolver=custom_resolver) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_resolver"] is custom_resolver + + +class TestUserInfoContextMiddlewareDispatch: + """Tests for UserInfoContextMiddleware.dispatch method.""" + + @pytest.mark.asyncio + async def test_dispatch_sets_user_in_context(self): + """Test dispatch sets user info in context variable.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + user_info = UserInfo(object_id="test-oid", tenant_id="test-tid") + + async def mock_resolver(request): + return user_info + + # Create a simple mock app + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + captured_user = None + + async def call_next(request): + nonlocal captured_user + captured_user = user_context.get(None) + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_user is user_info + + @pytest.mark.asyncio + async def test_dispatch_resets_context_after_request(self): + """Test dispatch resets context variable after 
request completes.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + original_user = UserInfo(object_id="original-oid", tenant_id="original-tid") + user_context.set(original_user) + + new_user = UserInfo(object_id="new-oid", tenant_id="new-tid") + + async def mock_resolver(request): + return new_user + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + + async def call_next(request): + # During request, should have new_user + assert user_context.get(None) is new_user + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + # After request, context should be reset to original value + assert user_context.get(None) is original_user + + @pytest.mark.asyncio + async def test_dispatch_resets_context_on_exception(self): + """Test dispatch resets context even when call_next raises exception.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + original_user = UserInfo(object_id="original-oid", tenant_id="original-tid") + user_context.set(original_user) + + new_user = UserInfo(object_id="new-oid", tenant_id="new-tid") + + async def mock_resolver(request): + return new_user + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + + async def call_next(request): + raise RuntimeError("Request failed") + + with pytest.raises(RuntimeError, match="Request failed"): + await middleware.dispatch(mock_request, call_next) + + # Context should still be reset to original + assert user_context.get(None) is original_user + + @pytest.mark.asyncio + async def test_dispatch_handles_none_user(self): + """Test dispatch handles None user from resolver.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + + async def mock_resolver(request): + return None + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + captured_user = "not_set" + + async def call_next(request): + nonlocal captured_user + captured_user = user_context.get("default") + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_user is None + + @pytest.mark.asyncio + async def test_dispatch_calls_resolver_with_request(self): + """Test dispatch calls user resolver with the request object.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + captured_request = None + + async def mock_resolver(request): + nonlocal captured_request + captured_request = request + return UserInfo(object_id="oid", tenant_id="tid") + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + mock_request.url = "https://test.com/api" + + async def call_next(request): + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_request is mock_request + + +class TestUserInfoContextMiddlewareDefaultResolver: + """Tests for 
UserInfoContextMiddleware default resolver.""" + + @pytest.mark.asyncio + async def test_default_resolver_extracts_user_from_headers(self): + """Test default resolver extracts user info from request headers.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_request = MagicMock() + mock_request.headers = { + "x-aml-oid": "header-object-id", + "x-aml-tid": "header-tenant-id" + } + + result = await UserInfoContextMiddleware._default_user_resolver(mock_request) + + assert result is not None + assert result.object_id == "header-object-id" + assert result.tenant_id == "header-tenant-id" + + @pytest.mark.asyncio + async def test_default_resolver_returns_none_when_headers_missing(self): + """Test default resolver returns None when required headers are missing.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_request = MagicMock() + mock_request.headers = {} + + result = await UserInfoContextMiddleware._default_user_resolver(mock_request) + + assert result is None diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_user.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_user.py new file mode 100644 index 000000000000..a909d9e5948a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_user.py @@ -0,0 +1,210 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _user.py - testing public methods of ContextVarUserProvider and resolve_user_from_headers.""" +import pytest +from contextvars import ContextVar + +from azure.ai.agentserver.core.tools.runtime._user import ( + ContextVarUserProvider, + resolve_user_from_headers, +) +from azure.ai.agentserver.core.tools.client._models import UserInfo + + +class TestContextVarUserProvider: + """Tests for ContextVarUserProvider public methods.""" + + @pytest.mark.asyncio + async def test_get_user_returns_none_when_context_not_set(self): + """Test get_user returns None when context variable is not set.""" + custom_context = ContextVar("test_user_context") + provider = ContextVarUserProvider(context=custom_context) + + result = await provider.get_user() + + assert result is None + + @pytest.mark.asyncio + async def test_get_user_returns_user_when_context_is_set(self, sample_user_info): + """Test get_user returns UserInfo when context variable is set.""" + custom_context = ContextVar("test_user_context") + custom_context.set(sample_user_info) + provider = ContextVarUserProvider(context=custom_context) + + result = await provider.get_user() + + assert result is sample_user_info + assert result.object_id == "test-object-id" + assert result.tenant_id == "test-tenant-id" + + @pytest.mark.asyncio + async def test_uses_default_context_when_none_provided(self, sample_user_info): + """Test that default context is used when no context is provided.""" + # Set value in default context + ContextVarUserProvider.default_user_info_context.set(sample_user_info) + provider = ContextVarUserProvider() + + result = await provider.get_user() + + assert result is sample_user_info + + @pytest.mark.asyncio + async def test_different_providers_share_same_default_context(self, sample_user_info): + """Test that different providers using default context share the same value.""" + 
ContextVarUserProvider.default_user_info_context.set(sample_user_info) + provider1 = ContextVarUserProvider() + provider2 = ContextVarUserProvider() + + result1 = await provider1.get_user() + result2 = await provider2.get_user() + + assert result1 is result2 is sample_user_info + + @pytest.mark.asyncio + async def test_custom_context_isolation(self, sample_user_info): + """Test that custom contexts are isolated from each other.""" + context1 = ContextVar("context1") + context2 = ContextVar("context2") + user2 = UserInfo(object_id="other-oid", tenant_id="other-tid") + + context1.set(sample_user_info) + context2.set(user2) + + provider1 = ContextVarUserProvider(context=context1) + provider2 = ContextVarUserProvider(context=context2) + + result1 = await provider1.get_user() + result2 = await provider2.get_user() + + assert result1 is sample_user_info + assert result2 is user2 + assert result1 is not result2 + + +class TestResolveUserFromHeaders: + """Tests for resolve_user_from_headers public function.""" + + def test_returns_user_info_when_both_headers_present(self): + """Test returns UserInfo when both object_id and tenant_id headers are present.""" + headers = { + "x-aml-oid": "user-object-id", + "x-aml-tid": "user-tenant-id" + } + + result = resolve_user_from_headers(headers) + + assert result is not None + assert isinstance(result, UserInfo) + assert result.object_id == "user-object-id" + assert result.tenant_id == "user-tenant-id" + + def test_returns_none_when_object_id_missing(self): + """Test returns None when object_id header is missing.""" + headers = {"x-aml-tid": "user-tenant-id"} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_tenant_id_missing(self): + """Test returns None when tenant_id header is missing.""" + headers = {"x-aml-oid": "user-object-id"} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_both_headers_missing(self): + """Test returns None when both headers are missing.""" + headers = {} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_object_id_is_empty(self): + """Test returns None when object_id is empty string.""" + headers = { + "x-aml-oid": "", + "x-aml-tid": "user-tenant-id" + } + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_tenant_id_is_empty(self): + """Test returns None when tenant_id is empty string.""" + headers = { + "x-aml-oid": "user-object-id", + "x-aml-tid": "" + } + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_custom_header_names(self): + """Test using custom header names for object_id and tenant_id.""" + headers = { + "custom-oid-header": "custom-object-id", + "custom-tid-header": "custom-tenant-id" + } + + result = resolve_user_from_headers( + headers, + object_id_header="custom-oid-header", + tenant_id_header="custom-tid-header" + ) + + assert result is not None + assert result.object_id == "custom-object-id" + assert result.tenant_id == "custom-tenant-id" + + def test_default_headers_not_matched_with_custom_headers(self): + """Test that default headers are not matched when custom headers are specified.""" + headers = { + "x-aml-oid": "default-object-id", + "x-aml-tid": "default-tenant-id" + } + + result = resolve_user_from_headers( + headers, + object_id_header="custom-oid", + tenant_id_header="custom-tid" + ) + + assert result is None + + def 
test_case_sensitive_header_matching(self): + """Test that header matching is case-sensitive.""" + headers = { + "X-AML-OID": "user-object-id", + "X-AML-TID": "user-tenant-id" + } + + # Default headers are lowercase, so these should not match + result = resolve_user_from_headers(headers) + + assert result is None + + def test_with_mapping_like_object(self): + """Test with a mapping-like object that supports .get().""" + class HeadersMapping: + def __init__(self, data): + self._data = data + + def get(self, key, default=""): + return self._data.get(key, default) + + headers = HeadersMapping({ + "x-aml-oid": "mapping-object-id", + "x-aml-tid": "mapping-tenant-id" + }) + + result = resolve_user_from_headers(headers) + + assert result is not None + assert result.object_id == "mapping-object-id" + assert result.tenant_id == "mapping-tenant-id" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/__init__.py similarity index 84% rename from sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/__init__.py rename to sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/__init__.py index d02a9af6c5f6..2d7503de198d 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/operations/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/__init__.py @@ -1,4 +1,4 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- - +"""Utils unit tests package.""" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/conftest.py new file mode 100644 index 000000000000..abd2f5145c29 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/conftest.py @@ -0,0 +1,56 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for utils unit tests. + +Common fixtures are inherited from the parent conftest.py automatically by pytest. +""" +from typing import Optional + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaType, +) + + +def create_resolved_tool_with_name( + name: str, + tool_type: str = "mcp", + connection_id: Optional[str] = None +) -> ResolvedFoundryTool: + """Helper to create a ResolvedFoundryTool with a specific name. + + :param name: The name for the tool details. + :param tool_type: Either "mcp" or "connected". + :param connection_id: Connection ID for connected tools. If provided with tool_type="mcp", + will automatically use "connected" type to ensure unique tool IDs. + :return: A ResolvedFoundryTool instance. 
+ """ + schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={}, + required=set() + ) + details = FoundryToolDetails( + name=name, + description=f"Tool named {name}", + input_schema=schema + ) + + # If connection_id is provided, use connected tool to ensure unique IDs + if connection_id is not None or tool_type == "connected": + definition = FoundryConnectedTool( + protocol="mcp", + project_connection_id=connection_id or f"conn-{name}" + ) + else: + definition = FoundryHostedMcpTool( + name=f"mcp-{name}", + configuration={} + ) + + return ResolvedFoundryTool(definition=definition, details=details) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/test_name_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/test_name_resolver.py new file mode 100644 index 000000000000..14340799253b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/test_name_resolver.py @@ -0,0 +1,260 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _name_resolver.py - testing public methods of ToolNameResolver.""" +from azure.ai.agentserver.core.tools.utils import ToolNameResolver +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, +) + +from .conftest import create_resolved_tool_with_name + + +class TestToolNameResolverResolve: + """Tests for ToolNameResolver.resolve method.""" + + def test_resolve_returns_tool_name_for_first_occurrence( + self, + sample_resolved_mcp_tool + ): + """Test resolve returns the original tool name for first occurrence.""" + resolver = ToolNameResolver() + + result = resolver.resolve(sample_resolved_mcp_tool) + + assert result == sample_resolved_mcp_tool.details.name + + def test_resolve_returns_same_name_for_same_tool( + self, + sample_resolved_mcp_tool + ): + """Test resolve returns the same name when called multiple times for same tool.""" + resolver = ToolNameResolver() + + result1 = resolver.resolve(sample_resolved_mcp_tool) + result2 = resolver.resolve(sample_resolved_mcp_tool) + result3 = resolver.resolve(sample_resolved_mcp_tool) + + assert result1 == result2 == result3 + assert result1 == sample_resolved_mcp_tool.details.name + + def test_resolve_appends_count_for_duplicate_names(self): + """Test resolve appends count for tools with duplicate names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my_tool", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my_tool", connection_id="conn-2") + tool3 = create_resolved_tool_with_name("my_tool", connection_id="conn-3") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + result3 = resolver.resolve(tool3) + + assert result1 == "my_tool" + assert result2 == "my_tool_1" + assert result3 == "my_tool_2" + + def test_resolve_handles_multiple_unique_names(self): + """Test resolve handles multiple tools with unique names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("tool_alpha") + tool2 = create_resolved_tool_with_name("tool_beta") + tool3 = create_resolved_tool_with_name("tool_gamma") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + result3 = resolver.resolve(tool3) + + assert result1 == "tool_alpha" + assert result2 == "tool_beta" + 
assert result3 == "tool_gamma" + + def test_resolve_mixed_unique_and_duplicate_names(self): + """Test resolve handles a mix of unique and duplicate names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("shared_name", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("unique_name") + tool3 = create_resolved_tool_with_name("shared_name", connection_id="conn-2") + tool4 = create_resolved_tool_with_name("another_unique") + tool5 = create_resolved_tool_with_name("shared_name", connection_id="conn-3") + + assert resolver.resolve(tool1) == "shared_name" + assert resolver.resolve(tool2) == "unique_name" + assert resolver.resolve(tool3) == "shared_name_1" + assert resolver.resolve(tool4) == "another_unique" + assert resolver.resolve(tool5) == "shared_name_2" + + def test_resolve_returns_cached_name_after_duplicate_added(self): + """Test that resolving a tool again returns cached name even after duplicates are added.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my_tool", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my_tool", connection_id="conn-2") + + # First resolution + first_result = resolver.resolve(tool1) + assert first_result == "my_tool" + + # Add duplicate + dup_result = resolver.resolve(tool2) + assert dup_result == "my_tool_1" + + # Resolve original again - should return cached value + second_result = resolver.resolve(tool1) + assert second_result == "my_tool" + + def test_resolve_with_connected_tool( + self, + sample_resolved_connected_tool + ): + """Test resolve works with connected tools.""" + resolver = ToolNameResolver() + + result = resolver.resolve(sample_resolved_connected_tool) + + assert result == sample_resolved_connected_tool.details.name + + def test_resolve_different_tools_same_details_name(self, sample_schema_definition): + """Test resolve handles different tool definitions with same details name.""" + resolver = ToolNameResolver() + + details = FoundryToolDetails( + name="shared_function", + description="A shared function", + input_schema=sample_schema_definition + ) + + mcp_def = FoundryHostedMcpTool(name="mcp_server", configuration={}) + connected_def = FoundryConnectedTool(protocol="mcp", project_connection_id="my-conn") + + tool1 = ResolvedFoundryTool(definition=mcp_def, details=details) + tool2 = ResolvedFoundryTool(definition=connected_def, details=details) + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + assert result1 == "shared_function" + assert result2 == "shared_function_1" + + def test_resolve_empty_name(self): + """Test resolve handles tools with empty name.""" + resolver = ToolNameResolver() + + tool = create_resolved_tool_with_name("") + + result = resolver.resolve(tool) + + assert result == "" + + def test_resolve_special_characters_in_name(self): + """Test resolve handles tools with special characters in name.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my-tool_v1.0", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my-tool_v1.0", connection_id="conn-2") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + assert result1 == "my-tool_v1.0" + assert result2 == "my-tool_v1.0_1" + + def test_independent_resolver_instances(self): + """Test that different resolver instances maintain independent state.""" + resolver1 = ToolNameResolver() + resolver2 = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("tool_name", connection_id="conn-1") + tool2 = 
create_resolved_tool_with_name("tool_name", connection_id="conn-2") + + # Both resolvers resolve tool1 first + assert resolver1.resolve(tool1) == "tool_name" + assert resolver2.resolve(tool1) == "tool_name" + + # resolver1 resolves tool2 as duplicate + assert resolver1.resolve(tool2) == "tool_name_1" + + # resolver2 has not seen tool2 yet in its context + # but tool2 has same name, so it should be duplicate + assert resolver2.resolve(tool2) == "tool_name_1" + + def test_resolve_many_duplicates(self): + """Test resolve handles many tools with the same name.""" + resolver = ToolNameResolver() + + tools = [ + create_resolved_tool_with_name("common_name", connection_id=f"conn-{i}") + for i in range(10) + ] + + results = [resolver.resolve(tool) for tool in tools] + + expected = ["common_name"] + [f"common_name_{i}" for i in range(1, 10)] + assert results == expected + + def test_resolve_uses_tool_id_for_caching(self, sample_schema_definition): + """Test that resolve uses tool.id for caching, not just name.""" + resolver = ToolNameResolver() + + # Create two tools with same definition but different details names + definition = FoundryHostedMcpTool(name="same_definition", configuration={}) + + details1 = FoundryToolDetails( + name="function_a", + description="Function A", + input_schema=sample_schema_definition + ) + details2 = FoundryToolDetails( + name="function_b", + description="Function B", + input_schema=sample_schema_definition + ) + + tool1 = ResolvedFoundryTool(definition=definition, details=details1) + tool2 = ResolvedFoundryTool(definition=definition, details=details2) + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + # Both should get their respective names since they have different tool.id + assert result1 == "function_a" + assert result2 == "function_b" + + def test_resolve_idempotent_for_same_tool_id(self, sample_schema_definition): + """Test that resolve is idempotent for the same tool id.""" + resolver = ToolNameResolver() + + definition = FoundryHostedMcpTool(name="my_mcp", configuration={}) + details = FoundryToolDetails( + name="my_function", + description="My function", + input_schema=sample_schema_definition + ) + tool = ResolvedFoundryTool(definition=definition, details=details) + + # Call resolve many times + results = [resolver.resolve(tool) for _ in range(5)] + + # All should return the same name + assert all(r == "my_function" for r in results) + + def test_resolve_interleaved_tool_resolutions(self): + """Test resolve with interleaved resolutions of different tools.""" + resolver = ToolNameResolver() + + toolA_1 = create_resolved_tool_with_name("A", connection_id="A-1") + toolA_2 = create_resolved_tool_with_name("A", connection_id="A-2") + toolB_1 = create_resolved_tool_with_name("B", connection_id="B-1") + toolA_3 = create_resolved_tool_with_name("A", connection_id="A-3") + toolB_2 = create_resolved_tool_with_name("B", connection_id="B-2") + + assert resolver.resolve(toolA_1) == "A" + assert resolver.resolve(toolB_1) == "B" + assert resolver.resolve(toolA_2) == "A_1" + assert resolver.resolve(toolA_3) == "A_2" + assert resolver.resolve(toolB_2) == "B_1" diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py index ccd79dff8fc6..0ea9a2da80f2 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py +++ 
b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py @@ -5,7 +5,7 @@ from langchain_core.language_models import BaseChatModel -from azure.ai.agentserver.core.tools import FoundryToolLike +from azure.ai.agentserver.core.tools import FoundryToolLike, ensure_foundry_tool from ._chat_model import FoundryToolLateBindingChatModel from ._middleware import FoundryToolBindingMiddleware from ._resolver import get_registry @@ -54,7 +54,10 @@ def use_foundry_tools( # pylint: disable=C4743 if isinstance(model_or_tools, BaseChatModel): if tools is None: raise ValueError("Tools must be provided when a model is given.") - get_registry().extend(tools) - return FoundryToolLateBindingChatModel(model_or_tools, runtime=None, foundry_tools=tools) - get_registry().extend(model_or_tools) - return FoundryToolBindingMiddleware(model_or_tools) + foundry_tools = [ensure_foundry_tool(tool) for tool in tools] + get_registry().extend(foundry_tools) + return FoundryToolLateBindingChatModel(model_or_tools, runtime=None, foundry_tools=foundry_tools) + + foundry_tools = [ensure_foundry_tool(tool) for tool in model_or_tools] + get_registry().extend(foundry_tools) + return FoundryToolBindingMiddleware(foundry_tools)
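
Reviewer note (not part of the patch): the minimal sketch below shows how the pieces exercised by these tests are expected to fit together at runtime, based only on the call patterns visible in the diffs above. The endpoint and connection name are placeholders, the use of DefaultAzureCredential is an assumption (the tests use a mock credential), and the private `runtime._runtime` module path simply mirrors the test imports.

import asyncio

from azure.identity.aio import DefaultAzureCredential  # assumed credential type; tests use a mock

from azure.ai.agentserver.core.tools import ensure_foundry_tool
from azure.ai.agentserver.core.tools.runtime._runtime import DefaultFoundryToolRuntime


async def main() -> None:
    # A facade dict with a known protocol ("mcp"/"a2a") normalizes to a
    # FoundryConnectedTool; any other "type" becomes a hosted MCP tool.
    tool = ensure_foundry_tool({"type": "mcp", "project_connection_id": "my-connection"})

    # The runtime is an async context manager that owns the FoundryToolClient.
    async with DefaultFoundryToolRuntime(
        project_endpoint="https://example-project.azure.com",  # placeholder endpoint
        credential=DefaultAzureCredential(),
    ) as runtime:
        # catalog.list resolves tool details; results are cached per tool id
        # (and per user for connected tools), with concurrent callers sharing one fetch.
        resolved = await runtime.catalog.list([tool])

        # invocation.resolve returns an invoker bound to the resolved tool and agent name;
        # passing an already-resolved tool skips the catalog lookup.
        invoker = await runtime.invocation.resolve(resolved[0])
        result = await invoker.invoke({"input": "test"})

        # runtime.invoke is the one-shot equivalent of resolve + invoke.
        result = await runtime.invoke(tool, {"input": "test"})


if __name__ == "__main__":
    asyncio.run(main())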