diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py index 78d8108ed96c..64120308a872 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py @@ -13,7 +13,7 @@ from azure.ai.agentserver.core import AgentServerContext from azure.ai.agentserver.core.logger import get_logger -from azure.ai.agentserver.core.tools import FoundryToolLike, ResolvedFoundryTool +from azure.ai.agentserver.core.tools import FoundryToolLike, ResolvedFoundryTool, ensure_foundry_tool logger = get_logger() @@ -45,7 +45,7 @@ def __init__( self, tools: Sequence[FoundryToolLike], ) -> None: - self._allowed_tools: List[FoundryToolLike] = list(tools) + self._allowed_tools: List[FoundryToolLike] = [ensure_foundry_tool(tool) for tool in tools] async def list_tools(self) -> List[AIFunction]: server_context = AgentServerContext.get() @@ -71,7 +71,7 @@ def _to_aifunction(self, foundry_tool: "ResolvedFoundryTool") -> AIFunction: # Build field definitions for the Pydantic model field_definitions: Dict[str, Any] = {} for field_name, field_info in properties.items(): - field_type = self._json_schema_type_to_python(field_info.type or "string") + field_type = field_info.type.py_type field_description = field_info.description or "" is_required = field_name in required_fields @@ -107,24 +107,6 @@ async def tool_func(**kwargs: Any) -> Any: input_model=input_model ) - def _json_schema_type_to_python(self, json_type: str) -> type: - """Convert JSON schema type to Python type. - - :param json_type: The JSON schema type string. - :type json_type: str - :return: The corresponding Python type. 
- :rtype: type - """ - type_map = { - "string": str, - "number": float, - "integer": int, - "boolean": bool, - "array": list, - "object": dict, - } - return type_map.get(json_type, str) - class FoundryToolsChatMiddleware(ChatMiddleware): """Chat middleware to inject Foundry tools into ChatOptions on each call.""" diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/__init__.py deleted file mode 100644 index 4a5d26360bce..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Unit tests package diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/conftest.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-agentframework/tests/conftest.py rename to sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/conftest.py diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/test_agent_framework_input_converter.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py rename to sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/agent_framework/test_agent_framework_input_converter.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py index 5b356f38c825..34c58d65cfd6 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py @@ -15,6 +15,7 @@ FoundryConnectedTool, FoundryHostedMcpTool, FoundryTool, + FoundryToolDetails, FoundryToolProtocol, FoundryToolSource, ResolvedFoundryTool, @@ -47,6 +48,7 @@ "FoundryConnectedTool", "FoundryHostedMcpTool", "FoundryTool", + "FoundryToolDetails", "FoundryToolProtocol", "FoundryToolSource", "ResolvedFoundryTool", diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py index 
f12d3f0db7b5..bfc4a08d9a63 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py @@ -1,6 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +import re from typing import Any, Dict, Union from .. import FoundryConnectedTool, FoundryHostedMcpTool @@ -45,6 +46,48 @@ def ensure_foundry_tool(tool: FoundryToolLike) -> FoundryTool: if not isinstance(project_connection_id, str) or not project_connection_id: raise InvalidToolFacadeError(f"project_connection_id is required for tool protocol {protocol}.") - return FoundryConnectedTool(protocol=protocol, project_connection_id=project_connection_id) + # Parse the connection identifier to extract the connection name + connection_name = _parse_connection_id(project_connection_id) + return FoundryConnectedTool(protocol=protocol, project_connection_id=connection_name) except ValueError: return FoundryHostedMcpTool(name=tool_type, configuration=tool) + + +# Pattern for Azure resource ID format: +# /subscriptions/<subscription-id>/resourceGroups/<resource-group>/providers/Microsoft.CognitiveServices/accounts/<account-name>/projects/<project-name>/connections/<connection-name> +_RESOURCE_ID_PATTERN = re.compile( + r"^/subscriptions/[^/]+/resourceGroups/[^/]+/providers/Microsoft\.CognitiveServices/" + r"accounts/[^/]+/projects/[^/]+/connections/(?P<name>[^/]+)$", + re.IGNORECASE, +) + + +def _parse_connection_id(connection_id: str) -> str: + """Parse the connection identifier and extract the connection name. + + Supports two formats: + 1. Simple name: "my-connection-name" + 2. Resource ID: "/subscriptions/<subscription-id>/resourceGroups/<resource-group>/providers/Microsoft.CognitiveServices/accounts/<account-name>/projects/<project-name>/connections/<connection-name>" + + :param connection_id: The connection identifier, either a simple name or a full resource ID. 
+ :type connection_id: str + :return: The connection name extracted from the identifier. + :rtype: str + :raises InvalidToolFacadeError: If the connection_id format is invalid. + """ + if not connection_id: + raise InvalidToolFacadeError("Connection identifier cannot be empty.") + + # Check if it's a resource ID format (starts with /) + if connection_id.startswith("/"): + match = _RESOURCE_ID_PATTERN.match(connection_id) + if not match: + raise InvalidToolFacadeError( + f"Invalid resource ID format for connection: '{connection_id}'. " + "Expected format: /subscriptions/<subscription-id>/resourceGroups/<resource-group>/providers/" + "Microsoft.CognitiveServices/accounts/<account-name>/projects/<project-name>/connections/<connection-name>" + ) + return match.group("name") + + # Otherwise, treat it as a simple connection name + return connection_id diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/__init__.py new file mode 100644 index 000000000000..d02a9af6c5f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/__init__.py new file mode 100644 index 000000000000..d02a9af6c5f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_connected_tools.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_connected_tools.py new file mode 100644 index 000000000000..e7273f37a7e7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_connected_tools.py @@ -0,0 +1,479 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryConnectedToolsOperations - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryToolDetails, +) +from azure.ai.agentserver.core.tools.client.operations._foundry_connected_tools import ( + FoundryConnectedToolsOperations, +) +from azure.ai.agentserver.core.tools._exceptions import OAuthConsentRequiredError, ToolInvocationError + +from ...conftest import create_mock_http_response + + +class TestFoundryConnectedToolsOperationsListTools: + """Tests for FoundryConnectedToolsOperations.list_tools public method.""" + + @pytest.mark.asyncio + async def test_list_tools_with_empty_list_returns_empty(self): + """Test list_tools returns empty when tools list is empty.""" + mock_client = AsyncMock() + ops = FoundryConnectedToolsOperations(mock_client) + + result = await ops.list_tools([], None, "test-agent") + + assert result == [] + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_list_tools_returns_tools_from_server( + self, + sample_connected_tool, + sample_user_info + ): + """Test list_tools returns tools from server response.""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "remote_tool", + "description": "A remote connected tool", + "parameters": { + "type": "object", + "properties": { + "input": {"type": "string"} + } + } + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await 
ops.list_tools([sample_connected_tool], sample_user_info, "test-agent")) + + assert len(result) == 1 + definition, details = result[0] + assert definition == sample_connected_tool + assert isinstance(details, FoundryToolDetails) + assert details.name == "remote_tool" + assert details.description == "A remote connected tool" + + @pytest.mark.asyncio + async def test_list_tools_without_user_info(self, sample_connected_tool): + """Test list_tools works without user info (local execution).""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "tool_no_user", + "description": "Tool without user", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([sample_connected_tool], None, "test-agent")) + + assert len(result) == 1 + assert result[0][1].name == "tool_no_user" + + @pytest.mark.asyncio + async def test_list_tools_with_multiple_connections(self, sample_user_info): + """Test list_tools with multiple connected tool definitions.""" + mock_client = AsyncMock() + + tool1 = FoundryConnectedTool(protocol="mcp", project_connection_id="conn-1") + tool2 = FoundryConnectedTool(protocol="a2a", project_connection_id="conn-2") + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": "mcp", + "projectConnectionId": "conn-1" + }, + "manifest": [ + { + "name": "tool_from_conn1", + "description": "From connection 1", + "parameters": {"type": "object", "properties": {}} + } + ] + }, + { + "remoteServer": { + "protocol": "a2a", + "projectConnectionId": "conn-2" + }, + "manifest": [ + { + "name": "tool_from_conn2", + 
"description": "From connection 2", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([tool1, tool2], sample_user_info, "test-agent")) + + assert len(result) == 2 + names = {r[1].name for r in result} + assert names == {"tool_from_conn1", "tool_from_conn2"} + + @pytest.mark.asyncio + async def test_list_tools_filters_by_connection_id(self, sample_user_info): + """Test list_tools only returns tools from requested connections.""" + mock_client = AsyncMock() + + requested_tool = FoundryConnectedTool(protocol="mcp", project_connection_id="requested-conn") + + # Server returns tools from multiple connections, but we only requested one + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": "mcp", + "projectConnectionId": "requested-conn" + }, + "manifest": [ + { + "name": "requested_tool", + "description": "Requested", + "parameters": {"type": "object", "properties": {}} + } + ] + }, + { + "remoteServer": { + "protocol": "mcp", + "projectConnectionId": "unrequested-conn" + }, + "manifest": [ + { + "name": "unrequested_tool", + "description": "Not requested", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([requested_tool], sample_user_info, "test-agent")) + + # Should only return tools from requested connection + assert len(result) == 1 + assert result[0][1].name == "requested_tool" + + @pytest.mark.asyncio + async def test_list_tools_multiple_tools_per_connection( + self, + sample_connected_tool, + 
sample_user_info + ): + """Test list_tools returns multiple tools from same connection.""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "tool_one", + "description": "First tool", + "parameters": {"type": "object", "properties": {}} + }, + { + "name": "tool_two", + "description": "Second tool", + "parameters": {"type": "object", "properties": {}} + }, + { + "name": "tool_three", + "description": "Third tool", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([sample_connected_tool], sample_user_info, "test-agent")) + + assert len(result) == 3 + names = {r[1].name for r in result} + assert names == {"tool_one", "tool_two", "tool_three"} + + @pytest.mark.asyncio + async def test_list_tools_raises_oauth_consent_error( + self, + sample_connected_tool, + sample_user_info + ): + """Test list_tools raises OAuthConsentRequiredError when consent needed.""" + mock_client = AsyncMock() + + response_data = { + "type": "OAuthConsentRequired", + "toolResult": { + "consentUrl": "https://login.microsoftonline.com/consent", + "message": "User consent is required to access this resource", + "projectConnectionId": sample_connected_tool.project_connection_id + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(OAuthConsentRequiredError) as exc_info: + list(await ops.list_tools([sample_connected_tool], sample_user_info, 
"test-agent")) + + assert exc_info.value.consent_url == "https://login.microsoftonline.com/consent" + assert "consent" in exc_info.value.message.lower() + + +class TestFoundryConnectedToolsOperationsInvokeTool: + """Tests for FoundryConnectedToolsOperations.invoke_tool public method.""" + + @pytest.mark.asyncio + async def test_invoke_tool_returns_result_value( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool returns the result value from server.""" + mock_client = AsyncMock() + + expected_result = {"data": "some output", "status": "success"} + response_data = {"toolResult": expected_result} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {"input": "test"}, + sample_user_info, + "test-agent" + ) + + assert result == expected_result + + @pytest.mark.asyncio + async def test_invoke_tool_without_user_info(self, sample_resolved_connected_tool): + """Test invoke_tool works without user info (local execution).""" + mock_client = AsyncMock() + + response_data = {"toolResult": "local result"} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + None, # No user info + "test-agent" + ) + + assert result == "local result" + + @pytest.mark.asyncio + async def test_invoke_tool_with_complex_arguments( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool handles complex nested arguments.""" + mock_client = AsyncMock() + + response_data = {"toolResult": "processed"} + mock_response = create_mock_http_response(200, response_data) + 
mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + complex_args = { + "query": "search term", + "filters": { + "date_range": {"start": "2025-01-01", "end": "2025-12-31"}, + "categories": ["A", "B", "C"] + }, + "limit": 50 + } + + result = await ops.invoke_tool( + sample_resolved_connected_tool, + complex_args, + sample_user_info, + "test-agent" + ) + + assert result == "processed" + mock_client.send_request.assert_called_once() + + @pytest.mark.asyncio + async def test_invoke_tool_returns_none_for_empty_result( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool returns None when server returns no result.""" + mock_client = AsyncMock() + + # Server returns empty response (no toolResult) + response_data = { + "toolResult": None + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + sample_user_info, + "test-agent" + ) + + assert result is None + + @pytest.mark.asyncio + async def test_invoke_tool_with_mcp_tool_raises_error( + self, + sample_resolved_mcp_tool, + sample_user_info + ): + """Test invoke_tool raises ToolInvocationError for non-connected tool.""" + mock_client = AsyncMock() + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(ToolInvocationError) as exc_info: + await ops.invoke_tool( + sample_resolved_mcp_tool, + {}, + sample_user_info, + "test-agent" + ) + + assert "not a Foundry connected tool" in str(exc_info.value) + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_invoke_tool_raises_oauth_consent_error( + self, + sample_resolved_connected_tool, + sample_user_info + ): + 
"""Test invoke_tool raises OAuthConsentRequiredError when consent needed.""" + mock_client = AsyncMock() + + response_data = { + "type": "OAuthConsentRequired", + "toolResult": { + "consentUrl": "https://login.microsoftonline.com/oauth/consent", + "message": "Please provide consent to continue", + "projectConnectionId": "test-connection-id" + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(OAuthConsentRequiredError) as exc_info: + await ops.invoke_tool( + sample_resolved_connected_tool, + {"input": "test"}, + sample_user_info, + "test-agent" + ) + + assert "https://login.microsoftonline.com/oauth/consent" in exc_info.value.consent_url + + @pytest.mark.asyncio + async def test_invoke_tool_with_different_agent_names( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool uses correct agent name in request.""" + mock_client = AsyncMock() + + response_data = {"toolResult": "result"} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + # Invoke with different agent names + for agent_name in ["agent-1", "my-custom-agent", "production-agent"]: + await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + sample_user_info, + agent_name + ) + + # Verify the correct path was used + call_args = mock_client.post.call_args + assert agent_name in call_args[0][0] + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_hosted_mcp_tools.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_hosted_mcp_tools.py new file mode 100644 index 000000000000..473b27cc8768 --- 
/dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/operations/test_foundry_hosted_mcp_tools.py @@ -0,0 +1,309 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryMcpToolsOperations - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.core.tools.client.operations._foundry_hosted_mcp_tools import ( + FoundryMcpToolsOperations, +) +from azure.ai.agentserver.core.tools._exceptions import ToolInvocationError + +from ...conftest import create_mock_http_response + + +class TestFoundryMcpToolsOperationsListTools: + """Tests for FoundryMcpToolsOperations.list_tools public method.""" + + @pytest.mark.asyncio + async def test_list_tools_with_empty_list_returns_empty(self): + """Test list_tools returns empty when allowed_tools is empty.""" + mock_client = AsyncMock() + ops = FoundryMcpToolsOperations(mock_client) + + result = await ops.list_tools([]) + + assert result == [] + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_list_tools_returns_matching_tools(self, sample_hosted_mcp_tool): + """Test list_tools returns tools that match the allowed list.""" + mock_client = AsyncMock() + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Test MCP tool", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + } + } + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + 
mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([sample_hosted_mcp_tool])) + + assert len(result) == 1 + definition, details = result[0] + assert definition == sample_hosted_mcp_tool + assert isinstance(details, FoundryToolDetails) + assert details.name == sample_hosted_mcp_tool.name + assert details.description == "Test MCP tool" + + @pytest.mark.asyncio + async def test_list_tools_filters_out_non_allowed_tools(self, sample_hosted_mcp_tool): + """Test list_tools only returns tools in the allowed list.""" + mock_client = AsyncMock() + + # Server returns multiple tools but only one is allowed + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Allowed tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "other_tool_not_in_list", + "description": "Not allowed tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "another_unlisted_tool", + "description": "Also not allowed", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([sample_hosted_mcp_tool])) + + assert len(result) == 1 + assert result[0][1].name == sample_hosted_mcp_tool.name + + @pytest.mark.asyncio + async def test_list_tools_with_multiple_allowed_tools(self): + """Test list_tools with multiple tools in allowed list.""" + mock_client = AsyncMock() + + tool1 = FoundryHostedMcpTool(name="tool_one") + tool2 = FoundryHostedMcpTool(name="tool_two") + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "tool_one", + "description": 
"First tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "tool_two", + "description": "Second tool", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([tool1, tool2])) + + assert len(result) == 2 + names = {r[1].name for r in result} + assert names == {"tool_one", "tool_two"} + + @pytest.mark.asyncio + async def test_list_tools_preserves_tool_metadata(self): + """Test list_tools preserves metadata from server response.""" + mock_client = AsyncMock() + + tool = FoundryHostedMcpTool(name="tool_with_meta") + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "tool_with_meta", + "description": "Tool with metadata", + "inputSchema": { + "type": "object", + "properties": { + "param1": {"type": "string"} + }, + "required": ["param1"] + }, + "_meta": { + "type": "object", + "properties": { + "model": {"type": "string"} + } + } + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([tool])) + + assert len(result) == 1 + details = result[0][1] + assert details.metadata is not None + + +class TestFoundryMcpToolsOperationsInvokeTool: + """Tests for FoundryMcpToolsOperations.invoke_tool public method.""" + + @pytest.mark.asyncio + async def test_invoke_tool_returns_server_response(self, sample_resolved_mcp_tool): + """Test invoke_tool returns the response from server.""" + mock_client = AsyncMock() + + expected_response = { + "jsonrpc": "2.0", + "id": 2, + "result": { + "content": [{"type": "text", "text": "Hello World"}] + } + } + 
mock_response = create_mock_http_response(200, expected_response) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(sample_resolved_mcp_tool, {"query": "test"}) + + assert result == expected_response + + @pytest.mark.asyncio + async def test_invoke_tool_with_empty_arguments(self, sample_resolved_mcp_tool): + """Test invoke_tool works with empty arguments.""" + mock_client = AsyncMock() + + expected_response = {"result": "success"} + mock_response = create_mock_http_response(200, expected_response) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(sample_resolved_mcp_tool, {}) + + assert result == expected_response + + @pytest.mark.asyncio + async def test_invoke_tool_with_complex_arguments(self, sample_resolved_mcp_tool): + """Test invoke_tool handles complex nested arguments.""" + mock_client = AsyncMock() + + mock_response = create_mock_http_response(200, {"result": "ok"}) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + complex_args = { + "text": "sample text", + "options": { + "temperature": 0.7, + "max_tokens": 100 + }, + "tags": ["tag1", "tag2"] + } + + result = await ops.invoke_tool(sample_resolved_mcp_tool, complex_args) + + assert result == {"result": "ok"} + mock_client.send_request.assert_called_once() + + @pytest.mark.asyncio + async def test_invoke_tool_with_connected_tool_raises_error( + self, + sample_resolved_connected_tool + ): + """Test invoke_tool raises ToolInvocationError for non-MCP tool.""" + mock_client = AsyncMock() + ops = FoundryMcpToolsOperations(mock_client) + + with pytest.raises(ToolInvocationError) as exc_info: + await ops.invoke_tool(sample_resolved_connected_tool, 
{}) + + assert "not a Foundry-hosted MCP tool" in str(exc_info.value) + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_invoke_tool_with_configuration_and_metadata(self): + """Test invoke_tool handles tool with configuration and metadata.""" + mock_client = AsyncMock() + + # Create tool with configuration + tool_def = FoundryHostedMcpTool( + name="image_generation", + configuration={"model_deployment_name": "dall-e-3"} + ) + + # Create tool details with metadata schema + meta_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "model": SchemaProperty(type=SchemaType.STRING) + } + ) + details = FoundryToolDetails( + name="image_generation", + description="Generate images", + input_schema=SchemaDefinition(type=SchemaType.OBJECT, properties={}), + metadata=meta_schema + ) + resolved_tool = ResolvedFoundryTool(definition=tool_def, details=details) + + mock_response = create_mock_http_response(200, {"result": "image_url"}) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(resolved_tool, {"prompt": "a cat"}) + + assert result == {"result": "image_url"} + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_client.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_client.py new file mode 100644 index 000000000000..de60f545e089 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_client.py @@ -0,0 +1,485 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryToolClient - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from azure.ai.agentserver.core.tools.client._client import FoundryToolClient +from azure.ai.agentserver.core.tools.client._models import ( + FoundryToolDetails, + FoundryToolSource, + ResolvedFoundryTool, +) +from azure.ai.agentserver.core.tools._exceptions import ToolInvocationError + +from ..conftest import create_mock_http_response + + +class TestFoundryToolClientInit: + """Tests for FoundryToolClient.__init__ public method.""" + + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + def test_init_with_valid_endpoint_and_credential(self, mock_pipeline_client_class, mock_credential): + """Test client can be initialized with valid endpoint and credential.""" + endpoint = "https://fake-project-endpoint.site" + + client = FoundryToolClient(endpoint, mock_credential) + + # Verify client was created with correct base_url + call_kwargs = mock_pipeline_client_class.call_args + assert call_kwargs[1]["base_url"] == endpoint + assert client is not None + + +class TestFoundryToolClientListTools: + """Tests for FoundryToolClient.list_tools public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_empty_collection_returns_empty_list( + self, + mock_pipeline_client_class, + mock_credential + ): + """Test list_tools returns empty list when given empty collection.""" + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + + result = await client.list_tools([], agent_name="test-agent") + + assert result == [] + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_with_single_mcp_tool_returns_resolved_tools( + self, + mock_pipeline_client_class, + 
mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools with a single MCP tool returns resolved tools.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Mock HTTP response for MCP tools listing + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Test MCP tool description", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.list_tools([sample_hosted_mcp_tool], agent_name="test-agent") + + assert len(result) == 1 + assert isinstance(result[0], ResolvedFoundryTool) + assert result[0].name == sample_hosted_mcp_tool.name + assert result[0].source == FoundryToolSource.HOSTED_MCP + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_with_single_connected_tool_returns_resolved_tools( + self, + mock_pipeline_client_class, + mock_credential, + sample_connected_tool, + sample_user_info + ): + """Test list_tools with a single connected tool returns resolved tools.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Mock HTTP response for connected tools listing + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "connected_test_tool", + "description": "Test connected tool", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + 
mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.list_tools( + [sample_connected_tool], + agent_name="test-agent", + user=sample_user_info + ) + + assert len(result) == 1 + assert isinstance(result[0], ResolvedFoundryTool) + assert result[0].name == "connected_test_tool" + assert result[0].source == FoundryToolSource.CONNECTED + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_with_mixed_tool_types_returns_all_resolved( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool, + sample_connected_tool, + sample_user_info + ): + """Test list_tools with both MCP and connected tools returns all resolved tools.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # We need to return different responses based on the request + mcp_response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "MCP tool", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + connected_response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "connected_tool", + "description": "Connected tool", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + + # Mock to return different responses for different requests + mock_client_instance.send_request.side_effect = [ + create_mock_http_response(200, mcp_response_data), + create_mock_http_response(200, connected_response_data) + ] + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", 
mock_credential) + result = await client.list_tools( + [sample_hosted_mcp_tool, sample_connected_tool], + agent_name="test-agent", + user=sample_user_info + ) + + assert len(result) == 2 + sources = {tool.source for tool in result} + assert FoundryToolSource.HOSTED_MCP in sources + assert FoundryToolSource.CONNECTED in sources + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_filters_unlisted_mcp_tools( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools only returns tools that are in the allowed list.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Server returns more tools than requested + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Requested tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "unrequested_tool", + "description": "This tool was not requested", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.list_tools([sample_hosted_mcp_tool], agent_name="test-agent") + + # Should only return the requested tool + assert len(result) == 1 + assert result[0].name == sample_hosted_mcp_tool.name + + +class TestFoundryToolClientListToolsDetails: + """Tests for FoundryToolClient.list_tools_details public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_details_returns_mapping_structure( + self, + mock_pipeline_client_class, + mock_credential, + 
sample_hosted_mcp_tool + ): + """Test list_tools_details returns correct mapping structure.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Test tool", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.list_tools_details([sample_hosted_mcp_tool], agent_name="test-agent") + + assert isinstance(result, dict) + assert sample_hosted_mcp_tool.id in result + assert len(result[sample_hosted_mcp_tool.id]) == 1 + assert isinstance(result[sample_hosted_mcp_tool.id][0], FoundryToolDetails) + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_details_groups_multiple_tools_by_definition( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools_details groups multiple tools from same source by definition ID.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Server returns multiple tools for the same MCP source + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Tool variant 1", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", 
mock_credential) + result = await client.list_tools_details([sample_hosted_mcp_tool], agent_name="test-agent") + + # All tools should be grouped under the same definition ID + assert sample_hosted_mcp_tool.id in result + + +class TestFoundryToolClientInvokeTool: + """Tests for FoundryToolClient.invoke_tool public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_mcp_tool_returns_result( + self, + mock_pipeline_client_class, + mock_credential, + sample_resolved_mcp_tool + ): + """Test invoke_tool with MCP tool returns the invocation result.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + expected_result = {"result": {"content": [{"text": "Hello World"}]}} + mock_response = create_mock_http_response(200, expected_result) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.invoke_tool( + sample_resolved_mcp_tool, + arguments={"input": "test"}, + agent_name="test-agent" + ) + + assert result == expected_result + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_connected_tool_returns_result( + self, + mock_pipeline_client_class, + mock_credential, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool with connected tool returns the invocation result.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + expected_value = {"output": "Connected tool result"} + response_data = {"toolResult": expected_value} + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + 
client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.invoke_tool( + sample_resolved_connected_tool, + arguments={"input": "test"}, + agent_name="test-agent", + user=sample_user_info + ) + + assert result == expected_value + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_tool_with_complex_arguments( + self, + mock_pipeline_client_class, + mock_credential, + sample_resolved_mcp_tool + ): + """Test invoke_tool correctly passes complex arguments.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + mock_response = create_mock_http_response(200, {"result": "success"}) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + complex_args = { + "string_param": "value", + "number_param": 42, + "bool_param": True, + "list_param": [1, 2, 3], + "nested_param": {"key": "value"} + } + + result = await client.invoke_tool( + sample_resolved_mcp_tool, + arguments=complex_args, + agent_name="test-agent" + ) + + # Verify request was made + mock_client_instance.send_request.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_tool_with_unsupported_source_raises_error( + self, + mock_pipeline_client_class, + mock_credential, + sample_tool_details + ): + """Test invoke_tool raises ToolInvocationError for unsupported tool source.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Create a mock tool with unsupported source + mock_definition = MagicMock() + mock_definition.source = "unsupported_source" + mock_tool = MagicMock(spec=ResolvedFoundryTool) + mock_tool.definition = mock_definition + 
mock_tool.source = "unsupported_source" + mock_tool.details = sample_tool_details + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + + with pytest.raises(ToolInvocationError) as exc_info: + await client.invoke_tool( + mock_tool, + arguments={"input": "test"}, + agent_name="test-agent" + ) + + assert "Unsupported tool source" in str(exc_info.value) + + +class TestFoundryToolClientClose: + """Tests for FoundryToolClient.close public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_close_closes_underlying_client( + self, + mock_pipeline_client_class, + mock_credential + ): + """Test close() properly closes the underlying HTTP client.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + await client.close() + + mock_client_instance.close.assert_called_once() + + +class TestFoundryToolClientContextManager: + """Tests for FoundryToolClient async context manager protocol.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_async_context_manager_enters_and_exits( + self, + mock_pipeline_client_class, + mock_credential + ): + """Test client can be used as async context manager.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + async with FoundryToolClient("https://fake-project-endpoint.site", mock_credential) as client: + assert client is not None + mock_client_instance.__aenter__.assert_called_once() + + mock_client_instance.__aexit__.assert_called_once() + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_configuration.py new file mode 100644 
index 000000000000..2f3c2710a3fc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/client/test_configuration.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolClientConfiguration.""" + +from azure.core.pipeline import policies + +from azure.ai.agentserver.core.tools.client._configuration import FoundryToolClientConfiguration + + +class TestFoundryToolClientConfiguration: + """Tests for FoundryToolClientConfiguration class.""" + + def test_init_creates_all_required_policies(self, mock_credential): + """Test that initialization creates all required pipeline policies.""" + config = FoundryToolClientConfiguration(mock_credential) + + assert isinstance(config.retry_policy, policies.AsyncRetryPolicy) + assert isinstance(config.logging_policy, policies.NetworkTraceLoggingPolicy) + assert isinstance(config.request_id_policy, policies.RequestIdPolicy) + assert isinstance(config.http_logging_policy, policies.HttpLoggingPolicy) + assert isinstance(config.user_agent_policy, policies.UserAgentPolicy) + assert isinstance(config.authentication_policy, policies.AsyncBearerTokenCredentialPolicy) + assert isinstance(config.redirect_policy, policies.AsyncRedirectPolicy) + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py new file mode 100644 index 000000000000..8849ce8aafbf --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/conftest.py @@ -0,0 +1,127 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Shared fixtures for tools unit tests.""" +import json +from typing import Any, Dict, Optional +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, + UserInfo, +) + + +@pytest.fixture +def mock_credential(): + """Create a mock async token credential.""" + credential = AsyncMock() + credential.get_token = AsyncMock(return_value=MagicMock(token="test-token")) + return credential + + +@pytest.fixture +def sample_user_info(): + """Create a sample UserInfo instance.""" + return UserInfo(object_id="test-object-id", tenant_id="test-tenant-id") + + +@pytest.fixture +def sample_hosted_mcp_tool(): + """Create a sample FoundryHostedMcpTool.""" + return FoundryHostedMcpTool( + name="test_mcp_tool", + configuration={"model_deployment_name": "gpt-4"} + ) + + +@pytest.fixture +def sample_connected_tool(): + """Create a sample FoundryConnectedTool.""" + return FoundryConnectedTool( + protocol="mcp", + project_connection_id="test-connection-id" + ) + + +@pytest.fixture +def sample_schema_definition(): + """Create a sample SchemaDefinition.""" + return SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input parameter") + }, + required={"input"} + ) + + +@pytest.fixture +def sample_tool_details(sample_schema_definition): + """Create a sample FoundryToolDetails.""" + return FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=sample_schema_definition + ) + + +@pytest.fixture +def sample_resolved_mcp_tool(sample_hosted_mcp_tool, sample_tool_details): + """Create a sample ResolvedFoundryTool for MCP.""" + return ResolvedFoundryTool( + definition=sample_hosted_mcp_tool, + details=sample_tool_details + ) + + 
+@pytest.fixture +def sample_resolved_connected_tool(sample_connected_tool, sample_tool_details): + """Create a sample ResolvedFoundryTool for connected tools.""" + return ResolvedFoundryTool( + definition=sample_connected_tool, + details=sample_tool_details + ) + + +def create_mock_http_response( + status_code: int = 200, + json_data: Optional[Dict[str, Any]] = None +) -> AsyncMock: + """Create a mock HTTP response that simulates real Azure SDK response behavior. + + This mock matches the behavior expected by BaseOperations._extract_response_json, + where response.text() and response.body() are synchronous methods that return + the actual string/bytes values directly. + + :param status_code: HTTP status code. + :param json_data: JSON data to return. + :return: Mock response object. + """ + response = AsyncMock() + response.status_code = status_code + + if json_data is not None: + json_str = json.dumps(json_data) + json_bytes = json_str.encode("utf-8") + # text() and body() are synchronous methods in AsyncHttpResponse + # They must be MagicMock (not AsyncMock) to return values directly when called + response.text = MagicMock(return_value=json_str) + response.body = MagicMock(return_value=json_bytes) + else: + response.text = MagicMock(return_value="") + response.body = MagicMock(return_value=b"") + + # Support async context manager + response.__aenter__ = AsyncMock(return_value=response) + response.__aexit__ = AsyncMock(return_value=None) + + return response diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/__init__.py new file mode 100644 index 000000000000..964fac9d8a55 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Runtime unit tests package.""" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/conftest.py new file mode 100644 index 000000000000..52a371bdc958 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/conftest.py @@ -0,0 +1,39 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for runtime unit tests. + +Common fixtures are inherited from the parent conftest.py automatically by pytest. +""" +from unittest.mock import AsyncMock + +import pytest + + +@pytest.fixture +def mock_foundry_tool_client(): + """Create a mock FoundryToolClient.""" + client = AsyncMock() + client.list_tools = AsyncMock(return_value=[]) + client.list_tools_details = AsyncMock(return_value={}) + client.invoke_tool = AsyncMock(return_value={"result": "success"}) + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=None) + return client + + +@pytest.fixture +def mock_user_provider(sample_user_info): + """Create a mock UserProvider.""" + provider = AsyncMock() + provider.get_user = AsyncMock(return_value=sample_user_info) + return provider + + +@pytest.fixture +def mock_user_provider_none(): + """Create a mock UserProvider that returns None.""" + provider = AsyncMock() + provider.get_user = AsyncMock(return_value=None) + return provider + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_catalog.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_catalog.py new file mode 100644 index 000000000000..45b03f0530a2 --- /dev/null +++ 
b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_catalog.py @@ -0,0 +1,349 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _catalog.py - testing public methods of DefaultFoundryToolCatalog.""" +import asyncio +import pytest +from unittest.mock import AsyncMock + +from azure.ai.agentserver.core.tools.runtime._catalog import ( + DefaultFoundryToolCatalog, +) +from azure.ai.agentserver.core.tools.client._models import ( + FoundryToolDetails, + ResolvedFoundryTool, + UserInfo, +) + + +class TestFoundryToolCatalogGet: + """Tests for DefaultFoundryToolCatalog.get method.""" + + @pytest.mark.asyncio + async def test_get_returns_resolved_tool_when_found( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details, + sample_user_info + ): + """Test get returns a resolved tool when the tool is found.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.get(sample_hosted_mcp_tool) + + assert result is not None + assert isinstance(result, ResolvedFoundryTool) + assert result.details == sample_tool_details + + @pytest.mark.asyncio + async def test_get_returns_none_when_not_found( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool + ): + """Test get returns None when the tool is not found.""" + mock_foundry_tool_client.list_tools_details = AsyncMock(return_value={}) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.get(sample_hosted_mcp_tool) + + assert result 
is None + + +class TestDefaultFoundryToolCatalogList: + """Tests for DefaultFoundryToolCatalog.list method.""" + + @pytest.mark.asyncio + async def test_list_returns_empty_list_when_no_tools( + self, + mock_foundry_tool_client, + mock_user_provider + ): + """Test list returns empty list when no tools are provided.""" + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([]) + + assert result == [] + + @pytest.mark.asyncio + async def test_list_returns_resolved_tools( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details + ): + """Test list returns resolved tools.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([sample_hosted_mcp_tool]) + + assert len(result) == 1 + assert isinstance(result[0], ResolvedFoundryTool) + assert result[0].definition == sample_hosted_mcp_tool + assert result[0].details == sample_tool_details + + @pytest.mark.asyncio + async def test_list_multiple_tools_with_multiple_details( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_connected_tool, + sample_schema_definition + ): + """Test list returns all resolved tools when tools have multiple details.""" + details1 = FoundryToolDetails( + name="tool1", + description="First tool", + input_schema=sample_schema_definition + ) + details2 = FoundryToolDetails( + name="tool2", + description="Second tool", + input_schema=sample_schema_definition + ) + + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={ + sample_hosted_mcp_tool.id: [details1], + sample_connected_tool.id: [details2] + } + ) + + catalog = 
DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([sample_hosted_mcp_tool, sample_connected_tool]) + + assert len(result) == 2 + names = {r.details.name for r in result} + assert names == {"tool1", "tool2"} + + @pytest.mark.asyncio + async def test_list_caches_results_for_hosted_mcp_tools( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details + ): + """Test that list caches results for hosted MCP tools.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + # First call + result1 = await catalog.list([sample_hosted_mcp_tool]) + # Second call should use cache + result2 = await catalog.list([sample_hosted_mcp_tool]) + + # Client should only be called once + assert mock_foundry_tool_client.list_tools_details.call_count == 1 + assert len(result1) == len(result2) == 1 + + @pytest.mark.asyncio + async def test_list_with_facade_dict( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_tool_details + ): + """Test list works with facade dictionaries.""" + facade = {"type": "custom_tool", "config": "value"} + expected_id = "hosted_mcp:custom_tool" + + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={expected_id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([facade]) + + assert len(result) == 1 + assert result[0].details == sample_tool_details + + @pytest.mark.asyncio + async def test_list_returns_multiple_details_per_tool( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, 
+ sample_schema_definition + ): + """Test list returns multiple resolved tools when a tool has multiple details.""" + details1 = FoundryToolDetails( + name="function1", + description="First function", + input_schema=sample_schema_definition + ) + details2 = FoundryToolDetails( + name="function2", + description="Second function", + input_schema=sample_schema_definition + ) + + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [details1, details2]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([sample_hosted_mcp_tool]) + + assert len(result) == 2 + names = {r.details.name for r in result} + assert names == {"function1", "function2"} + + @pytest.mark.asyncio + async def test_list_handles_exception_from_client( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool + ): + """Test list propagates exception from client and clears cache.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + side_effect=RuntimeError("Network error") + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(RuntimeError, match="Network error"): + await catalog.list([sample_hosted_mcp_tool]) + + @pytest.mark.asyncio + async def test_list_connected_tool_cache_key_includes_user( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_connected_tool, + sample_tool_details, + sample_user_info + ): + """Test that connected tool cache key includes user info.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_connected_tool.id: [sample_tool_details]} + ) + + # Create a new user provider returning a different user + other_user = UserInfo(object_id="other-oid", tenant_id="other-tid") + mock_user_provider2 = AsyncMock() + 
mock_user_provider2.get_user = AsyncMock(return_value=other_user) + + catalog1 = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + catalog2 = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider2, + agent_name="test-agent" + ) + + # Both catalogs should be able to list tools + result1 = await catalog1.list([sample_connected_tool]) + result2 = await catalog2.list([sample_connected_tool]) + + assert len(result1) == 1 + assert len(result2) == 1 + + +class TestCachedFoundryToolCatalogConcurrency: + """Tests for CachedFoundryToolCatalog concurrency handling.""" + + @pytest.mark.asyncio + async def test_concurrent_requests_share_single_fetch( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details + ): + """Test that concurrent requests for the same tool share a single fetch.""" + call_count = 0 + fetch_event = asyncio.Event() + + async def slow_fetch(*args, **kwargs): + nonlocal call_count + call_count += 1 + await fetch_event.wait() + return {sample_hosted_mcp_tool.id: [sample_tool_details]} + + mock_foundry_tool_client.list_tools_details = slow_fetch + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + # Start two concurrent requests + task1 = asyncio.create_task(catalog.list([sample_hosted_mcp_tool])) + task2 = asyncio.create_task(catalog.list([sample_hosted_mcp_tool])) + + # Allow tasks to start + await asyncio.sleep(0.01) + + # Release the fetch + fetch_event.set() + + results = await asyncio.gather(task1, task2) + + # Both should get results, but fetch should only be called once + assert len(results[0]) == 1 + assert len(results[1]) == 1 + assert call_count == 1 diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_facade.py 
b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_facade.py new file mode 100644 index 000000000000..c5377dc339a4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_facade.py @@ -0,0 +1,180 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _facade.py - testing public function ensure_foundry_tool.""" +import pytest + +from azure.ai.agentserver.core.tools.runtime._facade import ensure_foundry_tool +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolProtocol, + FoundryToolSource, +) +from azure.ai.agentserver.core.tools._exceptions import InvalidToolFacadeError + + +class TestEnsureFoundryTool: + """Tests for ensure_foundry_tool public function.""" + + def test_returns_same_instance_when_given_foundry_tool(self, sample_hosted_mcp_tool): + """Test that passing a FoundryTool returns the same instance.""" + result = ensure_foundry_tool(sample_hosted_mcp_tool) + + assert result is sample_hosted_mcp_tool + + def test_returns_same_instance_for_connected_tool(self, sample_connected_tool): + """Test that passing a FoundryConnectedTool returns the same instance.""" + result = ensure_foundry_tool(sample_connected_tool) + + assert result is sample_connected_tool + + def test_converts_facade_with_mcp_protocol_to_connected_tool(self): + """Test that a facade with 'mcp' protocol is converted to FoundryConnectedTool.""" + facade = { + "type": "mcp", + "project_connection_id": "my-connection" + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.protocol == FoundryToolProtocol.MCP + assert result.project_connection_id == "my-connection" + assert result.source == FoundryToolSource.CONNECTED + + def 
test_converts_facade_with_a2a_protocol_to_connected_tool(self): + """Test that a facade with 'a2a' protocol is converted to FoundryConnectedTool.""" + facade = { + "type": "a2a", + "project_connection_id": "my-a2a-connection" + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.protocol == FoundryToolProtocol.A2A + assert result.project_connection_id == "my-a2a-connection" + + def test_converts_facade_with_unknown_type_to_hosted_mcp_tool(self): + """Test that a facade with unknown type is converted to FoundryHostedMcpTool.""" + facade = { + "type": "my_custom_tool", + "some_config": "value123", + "another_config": True + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryHostedMcpTool) + assert result.name == "my_custom_tool" + assert result.configuration == {"some_config": "value123", "another_config": True} + assert result.source == FoundryToolSource.HOSTED_MCP + + def test_raises_error_when_type_is_missing(self): + """Test that InvalidToolFacadeError is raised when 'type' is missing.""" + facade = {"project_connection_id": "my-connection"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "type" in str(exc_info.value).lower() + + def test_raises_error_when_type_is_empty_string(self): + """Test that InvalidToolFacadeError is raised when 'type' is empty string.""" + facade = {"type": "", "project_connection_id": "my-connection"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "type" in str(exc_info.value).lower() + + def test_raises_error_when_type_is_not_string(self): + """Test that InvalidToolFacadeError is raised when 'type' is not a string.""" + facade = {"type": 123, "project_connection_id": "my-connection"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "type" in str(exc_info.value).lower() + + def 
test_raises_error_when_mcp_protocol_missing_connection_id(self): + """Test that InvalidToolFacadeError is raised when mcp protocol is missing project_connection_id.""" + facade = {"type": "mcp"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "project_connection_id" in str(exc_info.value) + + def test_raises_error_when_a2a_protocol_has_empty_connection_id(self): + """Test that InvalidToolFacadeError is raised when a2a protocol has empty project_connection_id.""" + facade = {"type": "a2a", "project_connection_id": ""} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "project_connection_id" in str(exc_info.value) + + def test_parses_resource_id_format_connection_id(self): + """Test that resource ID format project_connection_id is parsed correctly.""" + resource_id = ( + "/subscriptions/sub-123/resourceGroups/rg-test/providers/" + "Microsoft.CognitiveServices/accounts/acc-test/projects/proj-test/connections/my-conn-name" + ) + facade = { + "type": "mcp", + "project_connection_id": resource_id + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.project_connection_id == "my-conn-name" + + def test_raises_error_for_invalid_resource_id_format(self): + """Test that InvalidToolFacadeError is raised for invalid resource ID format.""" + invalid_resource_id = "/subscriptions/sub-123/invalid/path" + facade = { + "type": "mcp", + "project_connection_id": invalid_resource_id + } + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "Invalid resource ID format" in str(exc_info.value) + + def test_uses_simple_connection_name_as_is(self): + """Test that simple connection name is used as-is without parsing.""" + facade = { + "type": "mcp", + "project_connection_id": "simple-connection-name" + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, 
FoundryConnectedTool) + assert result.project_connection_id == "simple-connection-name" + + def test_original_facade_not_modified(self): + """Test that the original facade dictionary is not modified.""" + facade = { + "type": "my_tool", + "config_key": "config_value" + } + original_facade = facade.copy() + + ensure_foundry_tool(facade) + + assert facade == original_facade + + def test_hosted_mcp_tool_with_no_extra_configuration(self): + """Test that hosted MCP tool works with no extra configuration.""" + facade = {"type": "simple_tool"} + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryHostedMcpTool) + assert result.name == "simple_tool" + assert result.configuration == {} diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_invoker.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_invoker.py new file mode 100644 index 000000000000..b2a222c09d6e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_invoker.py @@ -0,0 +1,198 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for _invoker.py - testing public methods of DefaultFoundryToolInvoker.""" +import pytest +from unittest.mock import AsyncMock + +from azure.ai.agentserver.core.tools.runtime._invoker import DefaultFoundryToolInvoker + + +class TestDefaultFoundryToolInvokerResolvedTool: + """Tests for DefaultFoundryToolInvoker.resolved_tool property.""" + + def test_resolved_tool_returns_tool_passed_at_init( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test resolved_tool property returns the tool passed during initialization.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + assert invoker.resolved_tool is sample_resolved_mcp_tool + + def test_resolved_tool_returns_connected_tool( + self, + sample_resolved_connected_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test resolved_tool property returns connected tool.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_connected_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + assert invoker.resolved_tool is sample_resolved_connected_tool + + +class TestDefaultFoundryToolInvokerInvoke: + """Tests for DefaultFoundryToolInvoker.invoke method.""" + + @pytest.mark.asyncio + async def test_invoke_calls_client_with_correct_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke calls client.invoke_tool with correct arguments.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + arguments = {"input": "test value", "count": 5} + + await invoker.invoke(arguments) + + 
mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + arguments, + "test-agent", + sample_user_info + ) + + @pytest.mark.asyncio + async def test_invoke_returns_result_from_client( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test invoke returns the result from client.invoke_tool.""" + expected_result = {"output": "test result", "status": "completed"} + mock_foundry_tool_client.invoke_tool = AsyncMock(return_value=expected_result) + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await invoker.invoke({"input": "test"}) + + assert result == expected_result + + @pytest.mark.asyncio + async def test_invoke_with_empty_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke works with empty arguments dictionary.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + await invoker.invoke({}) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + {}, + "test-agent", + sample_user_info + ) + + @pytest.mark.asyncio + async def test_invoke_with_none_user( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider_none + ): + """Test invoke works when user provider returns None.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider_none, + agent_name="test-agent" + ) + + await invoker.invoke({"input": "test"}) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + {"input": "test"}, + "test-agent", + None + ) + + 
@pytest.mark.asyncio + async def test_invoke_propagates_client_exception( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test invoke propagates exceptions from client.invoke_tool.""" + mock_foundry_tool_client.invoke_tool = AsyncMock( + side_effect=RuntimeError("Client error") + ) + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(RuntimeError, match="Client error"): + await invoker.invoke({"input": "test"}) + + @pytest.mark.asyncio + async def test_invoke_with_complex_nested_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke with complex nested argument structure.""" + complex_args = { + "nested": {"key1": "value1", "key2": 123}, + "list": [1, 2, 3], + "mixed": [{"a": 1}, {"b": 2}] + } + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + await invoker.invoke(complex_args) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + complex_args, + "test-agent", + sample_user_info + ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_resolver.py new file mode 100644 index 000000000000..7bdaa8f957a9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_resolver.py @@ -0,0 +1,202 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for _resolver.py - testing public methods of DefaultFoundryToolInvocationResolver.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.runtime._resolver import DefaultFoundryToolInvocationResolver +from azure.ai.agentserver.core.tools.runtime._invoker import DefaultFoundryToolInvoker +from azure.ai.agentserver.core.tools._exceptions import UnableToResolveToolInvocationError +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, +) + + +class TestDefaultFoundryToolInvocationResolverResolve: + """Tests for DefaultFoundryToolInvocationResolver.resolve method.""" + + @pytest.fixture + def mock_catalog(self, sample_resolved_mcp_tool): + """Create a mock FoundryToolCatalog.""" + catalog = AsyncMock() + catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + catalog.list = AsyncMock(return_value=[sample_resolved_mcp_tool]) + return catalog + + @pytest.mark.asyncio + async def test_resolve_with_resolved_tool_returns_invoker_directly( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve returns invoker directly when given ResolvedFoundryTool.""" + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_resolved_mcp_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + assert invoker.resolved_tool is sample_resolved_mcp_tool + # Catalog should not be called when ResolvedFoundryTool is passed + mock_catalog.get.assert_not_called() + + @pytest.mark.asyncio + async def test_resolve_with_foundry_tool_uses_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_resolved_mcp_tool + ): + 
"""Test resolve uses catalog to resolve FoundryTool.""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_hosted_mcp_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once_with(sample_hosted_mcp_tool) + + @pytest.mark.asyncio + async def test_resolve_with_facade_dict_uses_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_connected_tool + ): + """Test resolve converts facade dict and uses catalog.""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_connected_tool) + facade = { + "type": "mcp", + "project_connection_id": "test-connection" + } + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(facade) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once() + # Verify the facade was converted to FoundryConnectedTool + call_arg = mock_catalog.get.call_args[0][0] + assert isinstance(call_arg, FoundryConnectedTool) + + @pytest.mark.asyncio + async def test_resolve_raises_error_when_tool_not_found_in_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool + ): + """Test resolve raises UnableToResolveToolInvocationError when catalog returns None.""" + mock_catalog.get = AsyncMock(return_value=None) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(UnableToResolveToolInvocationError) as exc_info: + await 
resolver.resolve(sample_hosted_mcp_tool) + + assert exc_info.value.tool is sample_hosted_mcp_tool + assert "Unable to resolve tool" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_resolve_with_hosted_mcp_facade( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve with hosted MCP facade (unknown type becomes FoundryHostedMcpTool).""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + facade = { + "type": "custom_mcp_tool", + "config_key": "config_value" + } + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(facade) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + # Verify the facade was converted to FoundryHostedMcpTool + call_arg = mock_catalog.get.call_args[0][0] + assert isinstance(call_arg, FoundryHostedMcpTool) + assert call_arg.name == "custom_mcp_tool" + + @pytest.mark.asyncio + async def test_resolve_returns_invoker_with_correct_agent_name( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve creates invoker with the correct agent name.""" + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="custom-agent-name" + ) + + invoker = await resolver.resolve(sample_resolved_mcp_tool) + + # Verify invoker was created with correct agent name by checking internal state + assert invoker._agent_name == "custom-agent-name" + + @pytest.mark.asyncio + async def test_resolve_with_connected_tool_directly( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_connected_tool, + sample_resolved_connected_tool + ): + """Test resolve with FoundryConnectedTool directly.""" + mock_catalog.get = 
AsyncMock(return_value=sample_resolved_connected_tool) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_connected_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once_with(sample_connected_tool) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_runtime.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_runtime.py new file mode 100644 index 000000000000..e42fc29a76cd --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_runtime.py @@ -0,0 +1,283 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _runtime.py - testing public methods of DefaultFoundryToolRuntime.""" +import os +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from azure.ai.agentserver.core.tools.runtime._runtime import DefaultFoundryToolRuntime +from azure.ai.agentserver.core.tools.runtime._catalog import DefaultFoundryToolCatalog +from azure.ai.agentserver.core.tools.runtime._resolver import DefaultFoundryToolInvocationResolver +from azure.ai.agentserver.core.tools.runtime._user import ContextVarUserProvider + + +class TestDefaultFoundryToolRuntimeInit: + """Tests for DefaultFoundryToolRuntime initialization.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_creates_client_with_endpoint_and_credential( + self, + mock_client_class, + mock_credential + ): + """Test initialization creates client with correct endpoint and credential.""" + endpoint = "https://test-project.azure.com" + mock_client_class.return_value = MagicMock() + + 
runtime = DefaultFoundryToolRuntime( + project_endpoint=endpoint, + credential=mock_credential + ) + + mock_client_class.assert_called_once_with( + endpoint=endpoint, + credential=mock_credential + ) + assert runtime is not None + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_default_user_provider_when_none_provided( + self, + mock_client_class, + mock_credential + ): + """Test initialization uses ContextVarUserProvider when user_provider is None.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime._user_provider, ContextVarUserProvider) + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_custom_user_provider( + self, + mock_client_class, + mock_credential, + mock_user_provider + ): + """Test initialization uses custom user provider when provided.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential, + user_provider=mock_user_provider + ) + + assert runtime._user_provider is mock_user_provider + + @patch.dict(os.environ, {"AGENT_NAME": "custom-agent"}) + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_reads_agent_name_from_environment( + self, + mock_client_class, + mock_credential + ): + """Test initialization reads agent name from environment variable.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert runtime._agent_name == "custom-agent" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_default_agent_name_when_env_not_set( + self, + mock_client_class, + mock_credential + ): + """Test 
initialization uses default agent name when env var is not set.""" + mock_client_class.return_value = MagicMock() + + # Ensure AGENT_NAME is not set + env_copy = os.environ.copy() + if "AGENT_NAME" in env_copy: + del env_copy["AGENT_NAME"] + + with patch.dict(os.environ, env_copy, clear=True): + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert runtime._agent_name == "$default" + + +class TestDefaultFoundryToolRuntimeCatalog: + """Tests for DefaultFoundryToolRuntime.catalog property.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_catalog_returns_default_catalog( + self, + mock_client_class, + mock_credential + ): + """Test catalog property returns DefaultFoundryToolCatalog.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime.catalog, DefaultFoundryToolCatalog) + + +class TestDefaultFoundryToolRuntimeInvocation: + """Tests for DefaultFoundryToolRuntime.invocation property.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_invocation_returns_default_resolver( + self, + mock_client_class, + mock_credential + ): + """Test invocation property returns DefaultFoundryToolInvocationResolver.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime.invocation, DefaultFoundryToolInvocationResolver) + + +class TestDefaultFoundryToolRuntimeInvoke: + """Tests for DefaultFoundryToolRuntime.invoke method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_invoke_resolves_and_invokes_tool( + self, + mock_client_class, + mock_credential, + 
sample_resolved_mcp_tool + ): + """Test invoke resolves the tool and calls the invoker.""" + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + # Mock the invocation resolver + mock_invoker = AsyncMock() + mock_invoker.invoke = AsyncMock(return_value={"result": "success"}) + runtime._invocation.resolve = AsyncMock(return_value=mock_invoker) + + result = await runtime.invoke(sample_resolved_mcp_tool, {"input": "test"}) + + assert result == {"result": "success"} + runtime._invocation.resolve.assert_called_once_with(sample_resolved_mcp_tool) + mock_invoker.invoke.assert_called_once_with({"input": "test"}) + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_invoke_with_facade_dict( + self, + mock_client_class, + mock_credential + ): + """Test invoke works with facade dictionary.""" + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + facade = {"type": "custom_tool", "config": "value"} + + # Mock the invocation resolver + mock_invoker = AsyncMock() + mock_invoker.invoke = AsyncMock(return_value={"output": "done"}) + runtime._invocation.resolve = AsyncMock(return_value=mock_invoker) + + result = await runtime.invoke(facade, {"param": "value"}) + + assert result == {"output": "done"} + runtime._invocation.resolve.assert_called_once_with(facade) + + +class TestDefaultFoundryToolRuntimeContextManager: + """Tests for DefaultFoundryToolRuntime async context manager.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aenter_returns_runtime_and_enters_client( + self, + mock_client_class, + mock_credential + ): + 
"""Test __aenter__ enters client and returns runtime.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + async with runtime as r: + assert r is runtime + mock_client_instance.__aenter__.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aexit_exits_client( + self, + mock_client_class, + mock_credential + ): + """Test __aexit__ exits client properly.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + async with runtime: + pass + + mock_client_instance.__aexit__.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aexit_called_on_exception( + self, + mock_client_class, + mock_credential + ): + """Test __aexit__ is called even when exception occurs.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + with pytest.raises(ValueError): + async with runtime: + raise ValueError("Test error") + + mock_client_instance.__aexit__.assert_called_once() diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_starlette.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_starlette.py new file mode 100644 index 000000000000..d1d72004d011 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_starlette.py @@ -0,0 +1,261 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _starlette.py - testing public methods of UserInfoContextMiddleware.""" +import pytest +from contextvars import ContextVar +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import UserInfo + + +class TestUserInfoContextMiddlewareInstall: + """Tests for UserInfoContextMiddleware.install class method.""" + + def test_install_adds_middleware_to_starlette_app(self): + """Test install adds middleware to Starlette application.""" + # Import here to avoid requiring starlette when not needed + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + + UserInfoContextMiddleware.install(mock_app) + + mock_app.add_middleware.assert_called_once() + call_args = mock_app.add_middleware.call_args + assert call_args[0][0] == UserInfoContextMiddleware + + def test_install_uses_default_context_when_none_provided(self): + """Test install uses default user context when none is provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + from azure.ai.agentserver.core.tools.runtime._user import ContextVarUserProvider + + mock_app = MagicMock() + + UserInfoContextMiddleware.install(mock_app) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_info_var"] is ContextVarUserProvider.default_user_info_context + + def 
test_install_uses_custom_context(self): + """Test install uses custom user context when provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + custom_context = ContextVar("custom_context") + + UserInfoContextMiddleware.install(mock_app, user_context=custom_context) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_info_var"] is custom_context + + def test_install_uses_custom_resolver(self): + """Test install uses custom user resolver when provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + + async def custom_resolver(request): + return UserInfo(object_id="custom-oid", tenant_id="custom-tid") + + UserInfoContextMiddleware.install(mock_app, user_resolver=custom_resolver) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_resolver"] is custom_resolver + + +class TestUserInfoContextMiddlewareDispatch: + """Tests for UserInfoContextMiddleware.dispatch method.""" + + @pytest.mark.asyncio + async def test_dispatch_sets_user_in_context(self): + """Test dispatch sets user info in context variable.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + user_info = UserInfo(object_id="test-oid", tenant_id="test-tid") + + async def mock_resolver(request): + return user_info + + # Create a simple mock app + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + captured_user = None + + async def call_next(request): + nonlocal captured_user + captured_user = user_context.get(None) + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_user is user_info + + @pytest.mark.asyncio + async def 
test_dispatch_resets_context_after_request(self): + """Test dispatch resets context variable after request completes.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + original_user = UserInfo(object_id="original-oid", tenant_id="original-tid") + user_context.set(original_user) + + new_user = UserInfo(object_id="new-oid", tenant_id="new-tid") + + async def mock_resolver(request): + return new_user + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + + async def call_next(request): + # During request, should have new_user + assert user_context.get(None) is new_user + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + # After request, context should be reset to original value + assert user_context.get(None) is original_user + + @pytest.mark.asyncio + async def test_dispatch_resets_context_on_exception(self): + """Test dispatch resets context even when call_next raises exception.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + original_user = UserInfo(object_id="original-oid", tenant_id="original-tid") + user_context.set(original_user) + + new_user = UserInfo(object_id="new-oid", tenant_id="new-tid") + + async def mock_resolver(request): + return new_user + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + + async def call_next(request): + raise RuntimeError("Request failed") + + with pytest.raises(RuntimeError, match="Request failed"): + await middleware.dispatch(mock_request, call_next) + + # Context should still be reset to original + assert user_context.get(None) is original_user + + 
@pytest.mark.asyncio + async def test_dispatch_handles_none_user(self): + """Test dispatch handles None user from resolver.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + + async def mock_resolver(request): + return None + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + captured_user = "not_set" + + async def call_next(request): + nonlocal captured_user + captured_user = user_context.get("default") + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_user is None + + @pytest.mark.asyncio + async def test_dispatch_calls_resolver_with_request(self): + """Test dispatch calls user resolver with the request object.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + captured_request = None + + async def mock_resolver(request): + nonlocal captured_request + captured_request = request + return UserInfo(object_id="oid", tenant_id="tid") + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + mock_request.url = "https://test.com/api" + + async def call_next(request): + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_request is mock_request + + +class TestUserInfoContextMiddlewareDefaultResolver: + """Tests for UserInfoContextMiddleware default resolver.""" + + @pytest.mark.asyncio + async def test_default_resolver_extracts_user_from_headers(self): + """Test default resolver extracts user info from request headers.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_request = MagicMock() + 
mock_request.headers = { + "x-aml-oid": "header-object-id", + "x-aml-tid": "header-tenant-id" + } + + result = await UserInfoContextMiddleware._default_user_resolver(mock_request) + + assert result is not None + assert result.object_id == "header-object-id" + assert result.tenant_id == "header-tenant-id" + + @pytest.mark.asyncio + async def test_default_resolver_returns_none_when_headers_missing(self): + """Test default resolver returns None when required headers are missing.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_request = MagicMock() + mock_request.headers = {} + + result = await UserInfoContextMiddleware._default_user_resolver(mock_request) + + assert result is None diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_user.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_user.py new file mode 100644 index 000000000000..a909d9e5948a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/runtime/test_user.py @@ -0,0 +1,210 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for _user.py - testing public methods of ContextVarUserProvider and resolve_user_from_headers.""" +import pytest +from contextvars import ContextVar + +from azure.ai.agentserver.core.tools.runtime._user import ( + ContextVarUserProvider, + resolve_user_from_headers, +) +from azure.ai.agentserver.core.tools.client._models import UserInfo + + +class TestContextVarUserProvider: + """Tests for ContextVarUserProvider public methods.""" + + @pytest.mark.asyncio + async def test_get_user_returns_none_when_context_not_set(self): + """Test get_user returns None when context variable is not set.""" + custom_context = ContextVar("test_user_context") + provider = ContextVarUserProvider(context=custom_context) + + result = await provider.get_user() + + assert result is None + + @pytest.mark.asyncio + async def test_get_user_returns_user_when_context_is_set(self, sample_user_info): + """Test get_user returns UserInfo when context variable is set.""" + custom_context = ContextVar("test_user_context") + custom_context.set(sample_user_info) + provider = ContextVarUserProvider(context=custom_context) + + result = await provider.get_user() + + assert result is sample_user_info + assert result.object_id == "test-object-id" + assert result.tenant_id == "test-tenant-id" + + @pytest.mark.asyncio + async def test_uses_default_context_when_none_provided(self, sample_user_info): + """Test that default context is used when no context is provided.""" + # Set value in default context + ContextVarUserProvider.default_user_info_context.set(sample_user_info) + provider = ContextVarUserProvider() + + result = await provider.get_user() + + assert result is sample_user_info + + @pytest.mark.asyncio + async def test_different_providers_share_same_default_context(self, sample_user_info): + """Test that different providers using default context share the same value.""" + 
ContextVarUserProvider.default_user_info_context.set(sample_user_info) + provider1 = ContextVarUserProvider() + provider2 = ContextVarUserProvider() + + result1 = await provider1.get_user() + result2 = await provider2.get_user() + + assert result1 is result2 is sample_user_info + + @pytest.mark.asyncio + async def test_custom_context_isolation(self, sample_user_info): + """Test that custom contexts are isolated from each other.""" + context1 = ContextVar("context1") + context2 = ContextVar("context2") + user2 = UserInfo(object_id="other-oid", tenant_id="other-tid") + + context1.set(sample_user_info) + context2.set(user2) + + provider1 = ContextVarUserProvider(context=context1) + provider2 = ContextVarUserProvider(context=context2) + + result1 = await provider1.get_user() + result2 = await provider2.get_user() + + assert result1 is sample_user_info + assert result2 is user2 + assert result1 is not result2 + + +class TestResolveUserFromHeaders: + """Tests for resolve_user_from_headers public function.""" + + def test_returns_user_info_when_both_headers_present(self): + """Test returns UserInfo when both object_id and tenant_id headers are present.""" + headers = { + "x-aml-oid": "user-object-id", + "x-aml-tid": "user-tenant-id" + } + + result = resolve_user_from_headers(headers) + + assert result is not None + assert isinstance(result, UserInfo) + assert result.object_id == "user-object-id" + assert result.tenant_id == "user-tenant-id" + + def test_returns_none_when_object_id_missing(self): + """Test returns None when object_id header is missing.""" + headers = {"x-aml-tid": "user-tenant-id"} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_tenant_id_missing(self): + """Test returns None when tenant_id header is missing.""" + headers = {"x-aml-oid": "user-object-id"} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_both_headers_missing(self): + """Test 
returns None when both headers are missing.""" + headers = {} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_object_id_is_empty(self): + """Test returns None when object_id is empty string.""" + headers = { + "x-aml-oid": "", + "x-aml-tid": "user-tenant-id" + } + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_tenant_id_is_empty(self): + """Test returns None when tenant_id is empty string.""" + headers = { + "x-aml-oid": "user-object-id", + "x-aml-tid": "" + } + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_custom_header_names(self): + """Test using custom header names for object_id and tenant_id.""" + headers = { + "custom-oid-header": "custom-object-id", + "custom-tid-header": "custom-tenant-id" + } + + result = resolve_user_from_headers( + headers, + object_id_header="custom-oid-header", + tenant_id_header="custom-tid-header" + ) + + assert result is not None + assert result.object_id == "custom-object-id" + assert result.tenant_id == "custom-tenant-id" + + def test_default_headers_not_matched_with_custom_headers(self): + """Test that default headers are not matched when custom headers are specified.""" + headers = { + "x-aml-oid": "default-object-id", + "x-aml-tid": "default-tenant-id" + } + + result = resolve_user_from_headers( + headers, + object_id_header="custom-oid", + tenant_id_header="custom-tid" + ) + + assert result is None + + def test_case_sensitive_header_matching(self): + """Test that header matching is case-sensitive.""" + headers = { + "X-AML-OID": "user-object-id", + "X-AML-TID": "user-tenant-id" + } + + # Default headers are lowercase, so these should not match + result = resolve_user_from_headers(headers) + + assert result is None + + def test_with_mapping_like_object(self): + """Test with a mapping-like object that supports .get().""" + class HeadersMapping: + def __init__(self, data): + 
self._data = data + + def get(self, key, default=""): + return self._data.get(key, default) + + headers = HeadersMapping({ + "x-aml-oid": "mapping-object-id", + "x-aml-tid": "mapping-tenant-id" + }) + + result = resolve_user_from_headers(headers) + + assert result is not None + assert result.object_id == "mapping-object-id" + assert result.tenant_id == "mapping-tenant-id" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/__init__.py new file mode 100644 index 000000000000..2d7503de198d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Utils unit tests package.""" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/conftest.py new file mode 100644 index 000000000000..abd2f5145c29 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/conftest.py @@ -0,0 +1,56 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for utils unit tests. + +Common fixtures are inherited from the parent conftest.py automatically by pytest. 
+""" +from typing import Optional + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaType, +) + + +def create_resolved_tool_with_name( + name: str, + tool_type: str = "mcp", + connection_id: Optional[str] = None +) -> ResolvedFoundryTool: + """Helper to create a ResolvedFoundryTool with a specific name. + + :param name: The name for the tool details. + :param tool_type: Either "mcp" or "connected". + :param connection_id: Connection ID for connected tools. If provided with tool_type="mcp", + will automatically use "connected" type to ensure unique tool IDs. + :return: A ResolvedFoundryTool instance. + """ + schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={}, + required=set() + ) + details = FoundryToolDetails( + name=name, + description=f"Tool named {name}", + input_schema=schema + ) + + # If connection_id is provided, use connected tool to ensure unique IDs + if connection_id is not None or tool_type == "connected": + definition = FoundryConnectedTool( + protocol="mcp", + project_connection_id=connection_id or f"conn-{name}" + ) + else: + definition = FoundryHostedMcpTool( + name=f"mcp-{name}", + configuration={} + ) + + return ResolvedFoundryTool(definition=definition, details=details) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/test_name_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/test_name_resolver.py new file mode 100644 index 000000000000..14340799253b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/core/tools/utils/test_name_resolver.py @@ -0,0 +1,260 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for _name_resolver.py - testing public methods of ToolNameResolver.""" +from azure.ai.agentserver.core.tools.utils import ToolNameResolver +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, +) + +from .conftest import create_resolved_tool_with_name + + +class TestToolNameResolverResolve: + """Tests for ToolNameResolver.resolve method.""" + + def test_resolve_returns_tool_name_for_first_occurrence( + self, + sample_resolved_mcp_tool + ): + """Test resolve returns the original tool name for first occurrence.""" + resolver = ToolNameResolver() + + result = resolver.resolve(sample_resolved_mcp_tool) + + assert result == sample_resolved_mcp_tool.details.name + + def test_resolve_returns_same_name_for_same_tool( + self, + sample_resolved_mcp_tool + ): + """Test resolve returns the same name when called multiple times for same tool.""" + resolver = ToolNameResolver() + + result1 = resolver.resolve(sample_resolved_mcp_tool) + result2 = resolver.resolve(sample_resolved_mcp_tool) + result3 = resolver.resolve(sample_resolved_mcp_tool) + + assert result1 == result2 == result3 + assert result1 == sample_resolved_mcp_tool.details.name + + def test_resolve_appends_count_for_duplicate_names(self): + """Test resolve appends count for tools with duplicate names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my_tool", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my_tool", connection_id="conn-2") + tool3 = create_resolved_tool_with_name("my_tool", connection_id="conn-3") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + result3 = resolver.resolve(tool3) + + assert result1 == "my_tool" + assert result2 == "my_tool_1" + assert result3 == "my_tool_2" + + def test_resolve_handles_multiple_unique_names(self): + """Test resolve handles 
multiple tools with unique names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("tool_alpha") + tool2 = create_resolved_tool_with_name("tool_beta") + tool3 = create_resolved_tool_with_name("tool_gamma") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + result3 = resolver.resolve(tool3) + + assert result1 == "tool_alpha" + assert result2 == "tool_beta" + assert result3 == "tool_gamma" + + def test_resolve_mixed_unique_and_duplicate_names(self): + """Test resolve handles a mix of unique and duplicate names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("shared_name", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("unique_name") + tool3 = create_resolved_tool_with_name("shared_name", connection_id="conn-2") + tool4 = create_resolved_tool_with_name("another_unique") + tool5 = create_resolved_tool_with_name("shared_name", connection_id="conn-3") + + assert resolver.resolve(tool1) == "shared_name" + assert resolver.resolve(tool2) == "unique_name" + assert resolver.resolve(tool3) == "shared_name_1" + assert resolver.resolve(tool4) == "another_unique" + assert resolver.resolve(tool5) == "shared_name_2" + + def test_resolve_returns_cached_name_after_duplicate_added(self): + """Test that resolving a tool again returns cached name even after duplicates are added.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my_tool", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my_tool", connection_id="conn-2") + + # First resolution + first_result = resolver.resolve(tool1) + assert first_result == "my_tool" + + # Add duplicate + dup_result = resolver.resolve(tool2) + assert dup_result == "my_tool_1" + + # Resolve original again - should return cached value + second_result = resolver.resolve(tool1) + assert second_result == "my_tool" + + def test_resolve_with_connected_tool( + self, + sample_resolved_connected_tool + ): + """Test 
resolve works with connected tools.""" + resolver = ToolNameResolver() + + result = resolver.resolve(sample_resolved_connected_tool) + + assert result == sample_resolved_connected_tool.details.name + + def test_resolve_different_tools_same_details_name(self, sample_schema_definition): + """Test resolve handles different tool definitions with same details name.""" + resolver = ToolNameResolver() + + details = FoundryToolDetails( + name="shared_function", + description="A shared function", + input_schema=sample_schema_definition + ) + + mcp_def = FoundryHostedMcpTool(name="mcp_server", configuration={}) + connected_def = FoundryConnectedTool(protocol="mcp", project_connection_id="my-conn") + + tool1 = ResolvedFoundryTool(definition=mcp_def, details=details) + tool2 = ResolvedFoundryTool(definition=connected_def, details=details) + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + assert result1 == "shared_function" + assert result2 == "shared_function_1" + + def test_resolve_empty_name(self): + """Test resolve handles tools with empty name.""" + resolver = ToolNameResolver() + + tool = create_resolved_tool_with_name("") + + result = resolver.resolve(tool) + + assert result == "" + + def test_resolve_special_characters_in_name(self): + """Test resolve handles tools with special characters in name.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my-tool_v1.0", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my-tool_v1.0", connection_id="conn-2") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + assert result1 == "my-tool_v1.0" + assert result2 == "my-tool_v1.0_1" + + def test_independent_resolver_instances(self): + """Test that different resolver instances maintain independent state.""" + resolver1 = ToolNameResolver() + resolver2 = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("tool_name", connection_id="conn-1") + tool2 = 
create_resolved_tool_with_name("tool_name", connection_id="conn-2") + + # Both resolvers resolve tool1 first + assert resolver1.resolve(tool1) == "tool_name" + assert resolver2.resolve(tool1) == "tool_name" + + # resolver1 resolves tool2 as duplicate + assert resolver1.resolve(tool2) == "tool_name_1" + + # resolver2 has not seen tool2 yet in its context + # but tool2 has same name, so it should be duplicate + assert resolver2.resolve(tool2) == "tool_name_1" + + def test_resolve_many_duplicates(self): + """Test resolve handles many tools with the same name.""" + resolver = ToolNameResolver() + + tools = [ + create_resolved_tool_with_name("common_name", connection_id=f"conn-{i}") + for i in range(10) + ] + + results = [resolver.resolve(tool) for tool in tools] + + expected = ["common_name"] + [f"common_name_{i}" for i in range(1, 10)] + assert results == expected + + def test_resolve_uses_tool_id_for_caching(self, sample_schema_definition): + """Test that resolve uses tool.id for caching, not just name.""" + resolver = ToolNameResolver() + + # Create two tools with same definition but different details names + definition = FoundryHostedMcpTool(name="same_definition", configuration={}) + + details1 = FoundryToolDetails( + name="function_a", + description="Function A", + input_schema=sample_schema_definition + ) + details2 = FoundryToolDetails( + name="function_b", + description="Function B", + input_schema=sample_schema_definition + ) + + tool1 = ResolvedFoundryTool(definition=definition, details=details1) + tool2 = ResolvedFoundryTool(definition=definition, details=details2) + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + # Both should get their respective names since they have different tool.id + assert result1 == "function_a" + assert result2 == "function_b" + + def test_resolve_idempotent_for_same_tool_id(self, sample_schema_definition): + """Test that resolve is idempotent for the same tool id.""" + resolver = ToolNameResolver() + + 
definition = FoundryHostedMcpTool(name="my_mcp", configuration={}) + details = FoundryToolDetails( + name="my_function", + description="My function", + input_schema=sample_schema_definition + ) + tool = ResolvedFoundryTool(definition=definition, details=details) + + # Call resolve many times + results = [resolver.resolve(tool) for _ in range(5)] + + # All should return the same name + assert all(r == "my_function" for r in results) + + def test_resolve_interleaved_tool_resolutions(self): + """Test resolve with interleaved resolutions of different tools.""" + resolver = ToolNameResolver() + + toolA_1 = create_resolved_tool_with_name("A", connection_id="A-1") + toolA_2 = create_resolved_tool_with_name("A", connection_id="A-2") + toolB_1 = create_resolved_tool_with_name("B", connection_id="B-1") + toolA_3 = create_resolved_tool_with_name("A", connection_id="A-3") + toolB_2 = create_resolved_tool_with_name("B", connection_id="B-2") + + assert resolver.resolve(toolA_1) == "A" + assert resolver.resolve(toolB_1) == "B" + assert resolver.resolve(toolA_2) == "A_1" + assert resolver.resolve(toolA_3) == "A_2" + assert resolver.resolve(toolB_2) == "B_1" diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py index 81f0e0f0b545..89be24921f54 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py @@ -1,11 +1,13 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- +import sys from dataclasses import dataclass from typing import Optional, Union +from langchain_core.runnables import RunnableConfig from langgraph.prebuilt import ToolRuntime -from langgraph.runtime import Runtime +from langgraph.runtime import Runtime, get_runtime from azure.ai.agentserver.core import AgentRunContext from .tools._context import FoundryToolContext @@ -17,6 +19,42 @@ class LanggraphRunContext: tools: FoundryToolContext + def attach_to_config(self, config: RunnableConfig): + config["configurable"]["__foundry_hosted_agent_langgraph_run_context__"] = self + + @classmethod + def resolve(cls, + config: Optional[RunnableConfig] = None, + runtime: Optional[Union[Runtime, ToolRuntime]] = None) -> Optional["LanggraphRunContext"]: + """Resolve the LanggraphRunContext from either a RunnableConfig or a Runtime. + + :param config: Optional RunnableConfig to extract the context from. + :param runtime: Optional Runtime or ToolRuntime to extract the context from. + :return: An instance of LanggraphRunContext if found, otherwise None. 
+ """ + context: Optional["LanggraphRunContext"] = None + if config: + context = cls.from_config(config) + if not context and (r := cls._resolve_runtime(runtime)): + context = cls.from_runtime(r) + return context + + @staticmethod + def _resolve_runtime( + runtime: Optional[Union[Runtime, ToolRuntime]] = None) -> Optional[Union[Runtime, ToolRuntime]]: + if runtime: + return runtime + if sys.version_info >= (3, 11): + return get_runtime(LanggraphRunContext) + return None + + @staticmethod + def from_config(config: RunnableConfig) -> Optional["LanggraphRunContext"]: + context = config["configurable"].get("__foundry_hosted_agent_langgraph_run_context__") + if isinstance(context, LanggraphRunContext): + return context + return None + @staticmethod def from_runtime(runtime: Union[Runtime, ToolRuntime]) -> Optional["LanggraphRunContext"]: context = runtime.context diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py index e8e524764db2..aae3bc32ee35 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py @@ -68,7 +68,7 @@ async def agent_run(self, context: AgentRunContext): try: lg_run_context = await self.setup_lg_run_context(context) input_arguments = await self.converter.convert_request(lg_run_context) - self.ensure_runnable_config(input_arguments) + self.ensure_runnable_config(input_arguments, lg_run_context) if not context.stream: response = await self.agent_run_non_stream(input_arguments) @@ -156,17 +156,20 @@ async def agent_run_astream(self, logger.error(f"Error during streaming agent run: {e}", exc_info=True) raise e - def ensure_runnable_config(self, input_arguments: GraphInputArguments): + def ensure_runnable_config(self, input_arguments: GraphInputArguments, context: 
LanggraphRunContext): """ Ensure the RunnableConfig is set in the input arguments. :param input_arguments: The input arguments for the agent run. :type input_arguments: GraphInputArguments + :param context: The Langgraph run context. + :type context: LanggraphRunContext """ config = input_arguments.get("config", {}) configurable = config.get("configurable", {}) configurable["thread_id"] = input_arguments["context"].agent_run.conversation_id config["configurable"] = configurable + context.attach_to_config(config) callbacks = config.get("callbacks", []) if self.azure_ai_tracer and self.azure_ai_tracer not in callbacks: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py index 828a8b42ae45..0ea9a2da80f2 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py @@ -5,7 +5,7 @@ from langchain_core.language_models import BaseChatModel -from azure.ai.agentserver.core.tools import FoundryToolLike +from azure.ai.agentserver.core.tools import FoundryToolLike, ensure_foundry_tool from ._chat_model import FoundryToolLateBindingChatModel from ._middleware import FoundryToolBindingMiddleware from ._resolver import get_registry @@ -54,7 +54,10 @@ def use_foundry_tools( # pylint: disable=C4743 if isinstance(model_or_tools, BaseChatModel): if tools is None: raise ValueError("Tools must be provided when a model is given.") - get_registry().extend(tools) - return FoundryToolLateBindingChatModel(model_or_tools, foundry_tools=tools) - get_registry().extend(model_or_tools) - return FoundryToolBindingMiddleware(model_or_tools) + foundry_tools = [ensure_foundry_tool(tool) for tool in tools] + get_registry().extend(foundry_tools) + return FoundryToolLateBindingChatModel(model_or_tools, 
runtime=None, foundry_tools=foundry_tools) + + foundry_tools = [ensure_foundry_tool(tool) for tool in model_or_tools] + get_registry().extend(foundry_tools) + return FoundryToolBindingMiddleware(foundry_tools) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py index c221910218f4..4ca422b88c41 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py @@ -30,7 +30,7 @@ class FoundryToolLateBindingChatModel(BaseChatModel): :type foundry_tools: List[FoundryToolLike] """ - def __init__(self, delegate: BaseChatModel, runtime: Runtime, foundry_tools: List[FoundryToolLike]): + def __init__(self, delegate: BaseChatModel, runtime: Optional[Runtime], foundry_tools: List[FoundryToolLike]): super().__init__() self._delegate = delegate self._runtime = runtime @@ -88,12 +88,17 @@ def bind_tools(self, # pylint: disable=C4758 return self - def _bound_delegate_for_call(self) -> Runnable[LanguageModelInput, AIMessage]: + def _bound_delegate_for_call(self, config: Optional[RunnableConfig]) -> Runnable[LanguageModelInput, AIMessage]: from .._context import LanggraphRunContext foundry_tools: Iterable[BaseTool] = [] - if (context := LanggraphRunContext.from_runtime(self._runtime)) is not None: + if context := LanggraphRunContext.resolve(config, self._runtime): foundry_tools = context.tools.resolved_tools.get(self._foundry_tools_to_bind) + elif self._foundry_tools_to_bind: + raise RuntimeError("Unable to resolve foundry tools from context, " + "if you are running in python < 3.11, " + "make sure you are passing RunnableConfig when calling model.") + all_tools = self._bound_tools.copy() all_tools.extend(foundry_tools) @@ -104,16 +109,16 @@ def _bound_delegate_for_call(self) 
    def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
        """Synchronously invoke the delegate with foundry tools bound for this call.

        The ``config`` is passed to ``_bound_delegate_for_call`` so the run
        context (and hence the resolved foundry tools) can be recovered from
        it — per the error raised there, this is required on python < 3.11.
        """
        return self._bound_delegate_for_call(config).invoke(input, config=config, **kwargs)

    async def ainvoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
        """Asynchronously invoke the delegate with foundry tools bound for this call."""
        return await self._bound_delegate_for_call(config).ainvoke(input, config=config, **kwargs)

    def stream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any):
        """Stream from the delegate with foundry tools bound for this call."""
        yield from self._bound_delegate_for_call(config).stream(input, config=config, **kwargs)

    async def astream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any):
        """Async-stream from the delegate with foundry tools bound for this call."""
        # Re-yield each chunk so the late-bound delegate's stream is transparent to callers.
        async for x in self._bound_delegate_for_call(config).astream(input, config=config, **kwargs):
            yield x
import os

from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langgraph.graph import (
    END,
    START,
    MessagesState,
    StateGraph,
)
from typing_extensions import Literal
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

from azure.ai.agentserver.langgraph import from_langgraph
from azure.ai.agentserver.langgraph.tools import use_foundry_tools

load_dotenv()

deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", "gpt-4o")
credential = DefaultAzureCredential()
token_provider = get_bearer_token_provider(
    credential, "https://cognitiveservices.azure.com/.default"
)
llm = init_chat_model(
    f"azure_openai:{deployment_name}",
    azure_ad_token_provider=token_provider,
)
llm_with_foundry_tools = use_foundry_tools(llm, [
    {
        # use the python tool to calculate what is 4 * 3.82. and then find its square root and then find the square root of that result
        "type": "code_interpreter"
    },
    {
        # Give me the Azure CLI commands to create an Azure Container App with a managed identity. search Microsoft Learn
        "type": "mcp",
        "project_connection_id": "MicrosoftLearn"
    },
    # {
    #     "type": "mcp",
    #     "project_connection_id": "FoundryMCPServerpreview"
    # }
])


# Nodes
async def llm_call(state: MessagesState, config: RunnableConfig):
    """LLM decides whether to call a tool or not.

    The RunnableConfig is forwarded explicitly so the foundry run context can
    be resolved from it by the late-binding chat model.
    """

    return {
        "messages": [
            await llm_with_foundry_tools.ainvoke(
                [
                    SystemMessage(
                        content="You are a helpful assistant tasked with performing arithmetic on a set of inputs."
                    )
                ]
                + state["messages"],
                config=config,
            )
        ]
    }


# Conditional edge function to route to the tool node or end based upon whether the LLM made a tool call
# BUGFIX: the annotation previously advertised "environment", but the function
# returns the path-map label "Action" (mapped to the "environment" node below).
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Decide if we should continue the loop or stop based upon whether the LLM made a tool call"""

    messages = state["messages"]
    last_message = messages[-1]
    # If the LLM makes a tool call, then perform an action
    if last_message.tool_calls:
        return "Action"
    # Otherwise, we stop (reply to the user)
    return END


# Build workflow
agent_builder = StateGraph(MessagesState)

# Add nodes
agent_builder.add_node("llm_call", llm_call)
agent_builder.add_node("environment", llm_with_foundry_tools.tool_node)

# Add edges to connect nodes
agent_builder.add_edge(START, "llm_call")
agent_builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {
        "Action": "environment",
        END: END,
    },
)
agent_builder.add_edge("environment", "llm_call")

# Compile the agent
agent = agent_builder.compile()

if __name__ == "__main__":
    adapter = from_langgraph(agent)
    adapter.run()
a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/conftest.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-langgraph/tests/conftest.py rename to sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/conftest.py diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/test_langgraph_request_converter.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py rename to sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/test_langgraph_request_converter.py diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
"""Shared fixtures for langgraph tools unit tests."""
from typing import Any, Dict, List, Optional

import pytest
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import BaseTool, tool

from azure.ai.agentserver.core.tools import (
    FoundryHostedMcpTool,
    FoundryConnectedTool,
    FoundryToolDetails,
    ResolvedFoundryTool,
    SchemaDefinition,
    SchemaProperty,
    SchemaType,
)
from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext
from azure.ai.agentserver.langgraph._context import LanggraphRunContext
from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext
from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools


class FakeChatModel(BaseChatModel):
    """A fake chat model for testing purposes that returns pre-configured responses."""

    # Scripted replies, returned in order; re-initialized per instance in __init__.
    responses: List[AIMessage] = []
    # Per-response tool_calls to splice onto the matching reply.
    tool_calls_list: List[List[Dict[str, Any]]] = []
    _call_count: int = 0
    _bound_tools: List[Any] = []
    _bound_kwargs: Dict[str, Any] = {}

    def __init__(
        self,
        responses: Optional[List[AIMessage]] = None,
        tool_calls: Optional[List[List[Dict[str, Any]]]] = None,
        **kwargs: Any,
    ):
        """Initialize the fake chat model.

        :param responses: List of AIMessage responses to return in sequence.
        :param tool_calls: List of tool_calls lists corresponding to each response.
        """
        super().__init__(**kwargs)
        self.responses = responses or []
        self.tool_calls_list = tool_calls or []
        self._call_count = 0
        self._bound_tools = []
        self._bound_kwargs = {}

    @property
    def _llm_type(self) -> str:
        return "fake_chat_model"

    def _generate(
        self,
        messages: List[Any],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response."""
        response = self._get_next_response()
        return ChatResult(generations=[ChatGeneration(message=response)])

    def bind_tools(
        self,
        tools: List[Any],
        **kwargs: Any,
    ) -> "FakeChatModel":
        """Bind tools to this model."""
        # Recorded so tests can assert what was bound; returns self for chaining.
        self._bound_tools = list(tools)
        self._bound_kwargs.update(kwargs)
        return self

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> AIMessage:
        """Synchronously invoke the model."""
        return self._get_next_response()

    async def ainvoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> AIMessage:
        """Asynchronously invoke the model."""
        return self._get_next_response()

    def stream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any):
        """Stream the response."""
        yield self._get_next_response()

    async def astream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any):
        """Async stream the response."""
        yield self._get_next_response()

    def _get_next_response(self) -> AIMessage:
        """Get the next response in sequence."""
        if self._call_count < len(self.responses):
            response = self.responses[self._call_count]
        else:
            # Default response if no more configured
            response = AIMessage(content="Default response")

        # Apply tool calls if configured (rebuilds the message so the scripted
        # content and the scripted tool_calls stay in sync)
        if self._call_count < len(self.tool_calls_list):
            response = AIMessage(
                content=response.content,
                tool_calls=self.tool_calls_list[self._call_count],
            )

        self._call_count += 1
        return response


@pytest.fixture
def sample_schema_definition() -> SchemaDefinition:
    """Create a sample SchemaDefinition."""
    return SchemaDefinition(
        type=SchemaType.OBJECT,
        properties={
            "query": SchemaProperty(type=SchemaType.STRING, description="Search query"),
        },
        required={"query"},
    )


@pytest.fixture
def sample_code_interpreter_tool() -> FoundryHostedMcpTool:
    """Create a sample code interpreter tool definition."""
    return FoundryHostedMcpTool(
        name="code_interpreter",
        configuration={},
    )


@pytest.fixture
def sample_mcp_connected_tool() -> FoundryConnectedTool:
    """Create a sample MCP connected tool definition."""
    return FoundryConnectedTool(
        protocol="mcp",
        project_connection_id="MicrosoftLearn",
    )


@pytest.fixture
def sample_tool_details(sample_schema_definition: SchemaDefinition) -> FoundryToolDetails:
    """Create a sample FoundryToolDetails."""
    return FoundryToolDetails(
        name="search",
        description="Search for documents",
        input_schema=sample_schema_definition,
    )


@pytest.fixture
def sample_resolved_tool(
    sample_code_interpreter_tool: FoundryHostedMcpTool,
    sample_tool_details: FoundryToolDetails,
) -> ResolvedFoundryTool:
    """Create a sample resolved foundry tool."""
    return ResolvedFoundryTool(
        definition=sample_code_interpreter_tool,
        details=sample_tool_details,
    )


@pytest.fixture
def mock_langchain_tool() -> BaseTool:
    """Create a mock LangChain BaseTool."""
    @tool
    def mock_tool(query: str) -> str:
        """Mock tool for testing.

        :param query: The search query.
        :return: Mock result.
        """
        return f"Mock result for: {query}"

    return mock_tool


@pytest.fixture
def mock_async_langchain_tool() -> BaseTool:
    """Create a mock async LangChain BaseTool."""
    @tool
    async def mock_async_tool(query: str) -> str:
        """Mock async tool for testing.

        :param query: The search query.
        :return: Mock result.
        """
        return f"Async mock result for: {query}"

    return mock_async_tool


@pytest.fixture
def sample_resolved_tools(
    sample_code_interpreter_tool: FoundryHostedMcpTool,
    mock_langchain_tool: BaseTool,
) -> ResolvedTools:
    """Create a sample ResolvedTools instance."""
    resolved_foundry_tool = ResolvedFoundryTool(
        definition=sample_code_interpreter_tool,
        details=FoundryToolDetails(
            name="mock_tool",
            description="Mock tool for testing",
            input_schema=SchemaDefinition(
                type=SchemaType.OBJECT,
                properties={
                    "query": SchemaProperty(type=SchemaType.STRING, description="Query"),
                },
                required={"query"},
            ),
        ),
    )
    return ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)])


@pytest.fixture
def mock_agent_run_context() -> AgentRunContext:
    """Create a mock AgentRunContext."""
    payload = {
        "input": [{"role": "user", "content": "Hello"}],
        "stream": False,
    }
    return AgentRunContext(payload=payload)


@pytest.fixture
def mock_foundry_tool_context(sample_resolved_tools: ResolvedTools) -> FoundryToolContext:
    """Create a mock FoundryToolContext."""
    return FoundryToolContext(resolved_tools=sample_resolved_tools)


@pytest.fixture
def mock_langgraph_run_context(
    mock_agent_run_context: AgentRunContext,
    mock_foundry_tool_context: FoundryToolContext,
) -> LanggraphRunContext:
    """Create a mock LanggraphRunContext."""
    return LanggraphRunContext(
        agent_run=mock_agent_run_context,
        tools=mock_foundry_tool_context,
    )


@pytest.fixture
def fake_chat_model_simple() -> FakeChatModel:
    """Create a simple fake chat model."""
    return FakeChatModel(
        responses=[AIMessage(content="Hello! How can I help you?")],
    )


@pytest.fixture
def fake_chat_model_with_tool_call() -> FakeChatModel:
    """Create a fake chat model that makes a tool call."""
    return FakeChatModel(
        responses=[
            AIMessage(content=""),  # First response: tool call
            AIMessage(content="The answer is 42."),  # Second response: final answer
        ],
        tool_calls=[
            [{"id": "call_1", "name": "mock_tool", "args": {"query": "test query"}}],
            [],  # No tool calls in final response
        ],
    )
+""" +import pytest + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool, tool +from langgraph.graph import END, START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode +from typing_extensions import Literal + +from azure.ai.agentserver.core.tools import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext +from azure.ai.agentserver.langgraph._context import LanggraphRunContext +from azure.ai.agentserver.langgraph.tools import use_foundry_tools +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._chat_model import FoundryToolLateBindingChatModel +from azure.ai.agentserver.langgraph.tools._middleware import FoundryToolBindingMiddleware +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools, get_registry + +from .conftest import FakeChatModel + + +@pytest.fixture(autouse=True) +def clear_registry(): + """Clear the global registry before and after each test.""" + registry = get_registry() + registry.clear() + yield + registry.clear() + + +@pytest.mark.unit +class TestGraphAgentWithFoundryTools: + """Tests demonstrating graph agent usage patterns similar to graph_agent_tool.py sample.""" + + def _create_mock_langgraph_context( + self, + foundry_tool: FoundryHostedMcpTool, + langchain_tool: BaseTool, + ) -> LanggraphRunContext: + """Create a mock LanggraphRunContext with resolved tools.""" + # Create resolved foundry tool + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name=langchain_tool.name, + description=langchain_tool.description or "Mock tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + 
properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Query"), + }, + required={"query"}, + ), + ), + ) + + # Create resolved tools + resolved_tools = ResolvedTools(tools=[(resolved_foundry_tool, langchain_tool)]) + + # Create context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext(resolved_tools=resolved_tools) + + return LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + @pytest.mark.asyncio + async def test_graph_agent_with_foundry_tools_no_tool_call(self): + """Test a graph agent that uses foundry tools but doesn't make a tool call.""" + # Create a mock tool + @tool + def calculate(expression: str) -> str: + """Calculate a mathematical expression. + + :param expression: The expression to calculate. + :return: The result. + """ + return "42" + + # Create foundry tool definition + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + foundry_tools = [{"type": "code_interpreter"}] + + # Create mock model that returns simple response (no tool call) + mock_model = FakeChatModel( + responses=[AIMessage(content="The answer is 42.")], + ) + + # Create the foundry tool binding chat model + llm_with_foundry_tools = FoundryToolLateBindingChatModel( + delegate=mock_model, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Create context and attach + context = self._create_mock_langgraph_context(foundry_tool, calculate) + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Define the LLM call node + async def llm_call(state: MessagesState, config: RunnableConfig): + return { + "messages": [ + await llm_with_foundry_tools.ainvoke( + [SystemMessage(content="You are a helpful assistant.")] + + state["messages"], + config=config, + ) + ] + } + + # Define routing function + def should_continue(state: MessagesState) -> Literal["tools", 
"__end__"]: + messages = state["messages"] + last_message = messages[-1] + if hasattr(last_message, 'tool_calls') and last_message.tool_calls: + return "tools" + return END + + # Build the graph + builder = StateGraph(MessagesState) + builder.add_node("llm_call", llm_call) + builder.add_node("tools", llm_with_foundry_tools.tool_node) + builder.add_edge(START, "llm_call") + builder.add_conditional_edges("llm_call", should_continue, {"tools": "tools", END: END}) + builder.add_edge("tools", "llm_call") + + graph = builder.compile() + + # Run the graph + result = await graph.ainvoke( + {"messages": [HumanMessage(content="What is 6 * 7?")]}, + config=config, + ) + + # Verify result + assert len(result["messages"]) == 2 # HumanMessage + AIMessage + assert result["messages"][-1].content == "The answer is 42." + + @pytest.mark.asyncio + async def test_graph_agent_with_tool_call(self): + """Test a graph agent that makes a tool call.""" + # Create a mock tool + @tool + def calculate(expression: str) -> str: + """Calculate a mathematical expression. + + :param expression: The expression to calculate. + :return: The result. 
+ """ + return "42" + + # Create foundry tool definition + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + foundry_tools = [{"type": "code_interpreter"}] + + # Create mock model that makes a tool call, then returns final answer + mock_model = FakeChatModel( + responses=[ + AIMessage( + content="", + tool_calls=[{"id": "call_1", "name": "calculate", "args": {"expression": "6 * 7"}}], + ), + AIMessage(content="The answer is 42."), + ], + ) + + # Create the foundry tool binding chat model + llm_with_foundry_tools = FoundryToolLateBindingChatModel( + delegate=mock_model, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Create context with the calculate tool + context = self._create_mock_langgraph_context(foundry_tool, calculate) + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Define the LLM call node + async def llm_call(state: MessagesState, config: RunnableConfig): + return { + "messages": [ + await llm_with_foundry_tools.ainvoke( + [SystemMessage(content="You are a helpful assistant.")] + + state["messages"], + config=config, + ) + ] + } + + # Define routing function + def should_continue(state: MessagesState) -> Literal["tools", "__end__"]: + messages = state["messages"] + last_message = messages[-1] + if hasattr(last_message, 'tool_calls') and last_message.tool_calls: + return "tools" + return END + + # Build the graph with a regular ToolNode (using the local tool directly for testing) + builder = StateGraph(MessagesState) + builder.add_node("llm_call", llm_call) + builder.add_node("tools", ToolNode([calculate])) + builder.add_edge(START, "llm_call") + builder.add_conditional_edges("llm_call", should_continue, {"tools": "tools", END: END}) + builder.add_edge("tools", "llm_call") + + graph = builder.compile() + + # Run the graph + result = await graph.ainvoke( + {"messages": [HumanMessage(content="What is 6 * 7?")]}, + config=config, + ) + + # Verify result - should 
have: HumanMessage, AIMessage (with tool call), ToolMessage, AIMessage (final) + assert len(result["messages"]) == 4 + assert result["messages"][-1].content == "The answer is 42." + + # Verify tool was called + tool_message = result["messages"][2] + assert isinstance(tool_message, ToolMessage) + assert tool_message.content == "42" + + +@pytest.mark.unit +class TestReactAgentWithFoundryTools: + """Tests demonstrating react agent usage patterns similar to react_agent_tool.py sample.""" + + @pytest.mark.asyncio + async def test_middleware_integration_with_foundry_tools(self): + """Test that FoundryToolBindingMiddleware correctly integrates with agents.""" + # Define foundry tools configuration + foundry_tools_config = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "MicrosoftLearn"}, + ] + + # Create middleware using use_foundry_tools + middleware = use_foundry_tools(foundry_tools_config) + + # Verify middleware is created correctly + assert isinstance(middleware, FoundryToolBindingMiddleware) + + # Verify dummy tool is created for the agent + assert len(middleware.tools) == 1 + assert middleware.tools[0].name == "__dummy_tool_by_foundry_middleware__" + + # Verify foundry tools are recorded + assert len(middleware._foundry_tools_to_bind) == 2 + + def test_use_foundry_tools_with_model(self): + """Test use_foundry_tools when used with a model directly.""" + foundry_tools = [{"type": "code_interpreter"}] + mock_model = FakeChatModel() + + result = use_foundry_tools(mock_model, foundry_tools) # type: ignore + + assert isinstance(result, FoundryToolLateBindingChatModel) + assert result._foundry_tools_to_bind == foundry_tools + + +@pytest.mark.unit +class TestLanggraphRunContextIntegration: + """Tests for LanggraphRunContext integration with langgraph.""" + + def test_context_attachment_to_config(self): + """Test that context is correctly attached to RunnableConfig.""" + # Create a mock context + payload = {"input": [{"role": "user", "content": 
"test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext() + + context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + # Create config and attach context + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Verify context is attached + assert "__foundry_hosted_agent_langgraph_run_context__" in config["configurable"] + assert config["configurable"]["__foundry_hosted_agent_langgraph_run_context__"] is context + + def test_context_resolution_from_config(self): + """Test that context can be resolved from RunnableConfig.""" + # Create and attach context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext() + + context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Resolve context + resolved = LanggraphRunContext.resolve(config=config) + + assert resolved is context + + def test_context_resolution_returns_none_when_not_attached(self): + """Test that context resolution returns None when not attached.""" + config: RunnableConfig = {"configurable": {}} + + resolved = LanggraphRunContext.resolve(config=config) + + assert resolved is None + + def test_from_config_returns_context(self): + """Test LanggraphRunContext.from_config method.""" + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext() + + context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + result = LanggraphRunContext.from_config(config) + + assert result is context + + def test_from_config_returns_none_for_non_context_value(self): + """Test 
that from_config returns None when value is not LanggraphRunContext.""" + config: RunnableConfig = { + "configurable": { + "__foundry_hosted_agent_langgraph_run_context__": "not a context" + } + } + + result = LanggraphRunContext.from_config(config) + + assert result is None + + +@pytest.mark.unit +class TestToolsResolutionInGraph: + """Tests for tool resolution within langgraph execution.""" + + @pytest.mark.asyncio + async def test_foundry_tools_resolved_from_context_in_graph_node(self): + """Test that foundry tools are correctly resolved from context during graph execution.""" + # Create mock tool + @tool + def search(query: str) -> str: + """Search for information. + + :param query: The search query. + :return: Search results. + """ + return f"Results for: {query}" + + # Create foundry tool and context + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="search", + description="Search tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + + resolved_tools = ResolvedTools(tools=[(resolved_foundry_tool, search)]) + + # Create context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext(resolved_tools=resolved_tools) + lg_context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + # Create config and attach context + config: RunnableConfig = {"configurable": {}} + lg_context.attach_to_config(config) + + # Verify tools can be resolved + resolved = LanggraphRunContext.resolve(config=config) + assert resolved is not None + + tools = list(resolved.tools.resolved_tools.get(foundry_tool)) + assert len(tools) == 1 + assert tools[0].name == "search" + diff --git 
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
"""Unit tests for use_foundry_tools builder function."""
import pytest
from typing import List


from azure.ai.agentserver.langgraph.tools._builder import use_foundry_tools
from azure.ai.agentserver.langgraph.tools._chat_model import FoundryToolLateBindingChatModel
from azure.ai.agentserver.langgraph.tools._middleware import FoundryToolBindingMiddleware
from azure.ai.agentserver.langgraph.tools._resolver import get_registry

from .conftest import FakeChatModel


@pytest.fixture(autouse=True)
def clear_registry():
    """Clear the global registry before and after each test."""
    # use_foundry_tools writes into a module-global registry; reset for isolation.
    registry = get_registry()
    registry.clear()
    yield
    registry.clear()


@pytest.mark.unit
class TestUseFoundryTools:
    """Tests for use_foundry_tools function."""

    def test_use_foundry_tools_with_tools_only_returns_middleware(self):
        """Test that passing only tools returns FoundryToolBindingMiddleware."""
        tools = [{"type": "code_interpreter"}]

        result = use_foundry_tools(tools)

        assert isinstance(result, FoundryToolBindingMiddleware)

    def test_use_foundry_tools_with_model_and_tools_returns_chat_model(self):
        """Test that passing model and tools returns FoundryToolLateBindingChatModel."""
        model = FakeChatModel()
        tools = [{"type": "code_interpreter"}]

        result = use_foundry_tools(model, tools)  # type: ignore

        assert isinstance(result, FoundryToolLateBindingChatModel)

    def test_use_foundry_tools_with_model_but_no_tools_raises_error(self):
        """Test that passing model without tools raises ValueError."""
        model = FakeChatModel()

        with pytest.raises(ValueError, match="Tools must be provided"):
            use_foundry_tools(model, None)  # type: ignore

    def test_use_foundry_tools_registers_tools_in_global_registry(self):
        """Test that tools are registered in the global registry."""
        tools = [
            {"type": "code_interpreter"},
            {"type": "mcp", "project_connection_id": "test"},
        ]

        use_foundry_tools(tools)

        registry = get_registry()
        assert len(registry) == 2

    def test_use_foundry_tools_with_model_registers_tools(self):
        """Test that tools are registered when using with model."""
        model = FakeChatModel()
        tools = [{"type": "code_interpreter"}]

        use_foundry_tools(model, tools)  # type: ignore

        registry = get_registry()
        assert len(registry) == 1

    def test_use_foundry_tools_with_empty_tools_list(self):
        """Test using with empty tools list."""
        tools: List = []

        result = use_foundry_tools(tools)

        assert isinstance(result, FoundryToolBindingMiddleware)
        assert len(get_registry()) == 0

    def test_use_foundry_tools_with_mcp_tools(self):
        """Test using with MCP connected tools."""
        tools = [
            {
                "type": "mcp",
                "project_connection_id": "MicrosoftLearn",
            },
        ]

        result = use_foundry_tools(tools)

        assert isinstance(result, FoundryToolBindingMiddleware)

    def test_use_foundry_tools_with_mixed_tool_types(self):
        """Test using with a mix of different tool types."""
        tools = [
            {"type": "code_interpreter"},
            {"type": "mcp", "project_connection_id": "MicrosoftLearn"},
        ]

        result = use_foundry_tools(tools)

        assert isinstance(result, FoundryToolBindingMiddleware)
        assert len(get_registry()) == 2
b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_chat_model.py new file mode 100644 index 000000000000..085495a4b91e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_chat_model.py @@ -0,0 +1,277 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolLateBindingChatModel.""" +import pytest +from typing import Any, List, Optional +from unittest.mock import MagicMock, patch + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool, tool + +from azure.ai.agentserver.core.tools import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.langgraph._context import LanggraphRunContext +from azure.ai.agentserver.langgraph.tools._chat_model import FoundryToolLateBindingChatModel +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools + +from .conftest import FakeChatModel + + +@pytest.mark.unit +class TestFoundryToolLateBindingChatModel: + """Tests for FoundryToolLateBindingChatModel class.""" + + def test_llm_type_property(self): + """Test the _llm_type property returns correct value.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + assert "foundry_tool_binding_model" in model._llm_type + assert "fake_chat_model" in model._llm_type + + def test_bind_tools_records_tools(self): + """Test that bind_tools records tools for later use.""" + delegate = 
FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + @tool + def my_tool(x: str) -> str: + """My tool.""" + return x + + result = model.bind_tools([my_tool], tool_choice="auto") + + # Should return self for chaining + assert result is model + # Tools should be recorded + assert len(model._bound_tools) == 1 + assert model._bound_kwargs.get("tool_choice") == "auto" + + def test_bind_tools_multiple_times(self): + """Test binding tools multiple times accumulates them.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + @tool + def tool1(x: str) -> str: + """Tool 1.""" + return x + + @tool + def tool2(x: str) -> str: + """Tool 2.""" + return x + + model.bind_tools([tool1]) + model.bind_tools([tool2]) + + assert len(model._bound_tools) == 2 + + def test_tool_node_property(self): + """Test that tool_node property returns a ToolNode.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + tool_node = model.tool_node + + # Should return a ToolNode + assert tool_node is not None + + def test_tool_node_wrapper_property(self): + """Test that tool_node_wrapper returns correct wrappers.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + wrappers = model.tool_node_wrapper + + assert "wrap_tool_call" in wrappers + assert "awrap_tool_call" in wrappers + assert callable(wrappers["wrap_tool_call"]) + assert callable(wrappers["awrap_tool_call"]) + + def 
test_invoke_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + sample_code_interpreter_tool: FoundryHostedMcpTool, + ): + """Test invoking model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Hello from model!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + result = model.invoke(input_messages, config=config) + + assert result.content == "Hello from model!" + + @pytest.mark.asyncio + async def test_ainvoke_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + ): + """Test async invoking model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Async hello from model!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + result = await model.ainvoke(input_messages, config=config) + + assert result.content == "Async hello from model!" 
+ + def test_invoke_without_context_and_no_foundry_tools(self): + """Test invoking model without context and no foundry tools.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Hello!")], + ) + # No foundry tools + foundry_tools: List[Any] = [] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + config: RunnableConfig = {"configurable": {}} + input_messages = [HumanMessage(content="Hello")] + result = model.invoke(input_messages, config=config) + + # Should work since no foundry tools need resolution + assert result.content == "Hello!" + + def test_invoke_without_context_raises_error_when_foundry_tools_present(self): + """Test that invoking without context raises error when foundry tools are set.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Hello!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + config: RunnableConfig = {"configurable": {}} + input_messages = [HumanMessage(content="Hello")] + + with pytest.raises(RuntimeError, match="Unable to resolve foundry tools from context"): + model.invoke(input_messages, config=config) + + def test_stream_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + ): + """Test streaming model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Streamed response!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + results = list(model.stream(input_messages, config=config)) + + assert len(results) == 1 + 
assert results[0].content == "Streamed response!" + + @pytest.mark.asyncio + async def test_astream_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + ): + """Test async streaming model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Async streamed response!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + results = [] + async for chunk in model.astream(input_messages, config=config): + results.append(chunk) + + assert len(results) == 1 + assert results[0].content == "Async streamed response!" + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_context.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_context.py new file mode 100644 index 000000000000..577d4e6e4e6f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_context.py @@ -0,0 +1,36 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryToolContext.""" +import pytest + +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools + + +@pytest.mark.unit +class TestFoundryToolContext: + """Tests for FoundryToolContext class.""" + + def test_create_with_resolved_tools(self, sample_resolved_tools: ResolvedTools): + """Test creating FoundryToolContext with resolved tools.""" + context = FoundryToolContext(resolved_tools=sample_resolved_tools) + + assert context.resolved_tools is sample_resolved_tools + + def test_create_with_default_resolved_tools(self): + """Test creating FoundryToolContext with default empty resolved tools.""" + context = FoundryToolContext() + + # Default should be empty ResolvedTools + assert context.resolved_tools is not None + tools_list = list(context.resolved_tools) + assert len(tools_list) == 0 + + def test_resolved_tools_is_iterable(self, sample_resolved_tools: ResolvedTools): + """Test that resolved_tools can be iterated.""" + context = FoundryToolContext(resolved_tools=sample_resolved_tools) + + tools_list = list(context.resolved_tools) + assert len(tools_list) == 1 + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_middleware.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_middleware.py new file mode 100644 index 000000000000..89290a58f97c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_middleware.py @@ -0,0 +1,197 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryToolBindingMiddleware.""" +import pytest +from typing import Any, List +from unittest.mock import AsyncMock, MagicMock + +from langchain.agents.middleware.types import ModelRequest +from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.tools import tool +from langgraph.prebuilt.tool_node import ToolCallRequest + +from azure.ai.agentserver.langgraph.tools._middleware import FoundryToolBindingMiddleware + +from .conftest import FakeChatModel + + +@pytest.mark.unit +class TestFoundryToolBindingMiddleware: + """Tests for FoundryToolBindingMiddleware class.""" + + def test_init_with_foundry_tools_creates_dummy_tool(self): + """Test that initialization with foundry tools creates a dummy tool.""" + foundry_tools = [{"type": "code_interpreter"}] + + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Should have one dummy tool + assert len(middleware.tools) == 1 + assert middleware.tools[0].name == "__dummy_tool_by_foundry_middleware__" + + def test_init_without_foundry_tools_no_dummy_tool(self): + """Test that initialization without foundry tools creates no dummy tool.""" + foundry_tools: List[Any] = [] + + middleware = FoundryToolBindingMiddleware(foundry_tools) + + assert len(middleware.tools) == 0 + + def test_wrap_model_call_wraps_model_with_foundry_binding(self): + """Test that wrap_model_call wraps the model correctly.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock model and request + mock_model = FakeChatModel() + mock_runtime = MagicMock() + mock_request = MagicMock(spec=ModelRequest) + mock_request.model = mock_model + mock_request.runtime = mock_runtime + mock_request.tools = [] + + # Create a modified request to return + modified_request = MagicMock(spec=ModelRequest) + mock_request.override = MagicMock(return_value=modified_request) + + # Mock handler + 
expected_result = AIMessage(content="Result") + mock_handler = MagicMock(return_value=expected_result) + + result = middleware.wrap_model_call(mock_request, mock_handler) + + # Handler should be called with modified request + mock_handler.assert_called_once() + assert result == expected_result + + @pytest.mark.asyncio + async def test_awrap_model_call_wraps_model_async(self): + """Test that awrap_model_call wraps the model correctly in async.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock model and request + mock_model = FakeChatModel() + mock_runtime = MagicMock() + mock_request = MagicMock(spec=ModelRequest) + mock_request.model = mock_model + mock_request.runtime = mock_runtime + mock_request.tools = [] + + # Create a modified request to return + modified_request = MagicMock(spec=ModelRequest) + mock_request.override = MagicMock(return_value=modified_request) + + # Mock async handler + expected_result = AIMessage(content="Async Result") + mock_handler = AsyncMock(return_value=expected_result) + + result = await middleware.awrap_model_call(mock_request, mock_handler) + + # Handler should be called + mock_handler.assert_awaited_once() + assert result == expected_result + + def test_wrap_model_without_foundry_tools_returns_unchanged(self): + """Test that wrap_model returns unchanged request when no foundry tools.""" + foundry_tools: List[Any] = [] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + mock_model = FakeChatModel() + mock_request = MagicMock(spec=ModelRequest) + mock_request.model = mock_model + mock_request.tools = [] + + # Should not call override + mock_request.override = MagicMock() + + mock_handler = MagicMock(return_value=AIMessage(content="Result")) + + middleware.wrap_model_call(mock_request, mock_handler) + + # Handler should be called with original request + mock_handler.assert_called_once_with(mock_request) + + def 
test_remove_dummy_tool_from_request(self): + """Test that dummy tool is removed from the request tools.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create request with dummy tool + dummy = middleware._dummy_tool() + + @tool + def real_tool(x: str) -> str: + """Real tool.""" + return x + + mock_request = MagicMock(spec=ModelRequest) + mock_request.tools = [dummy, real_tool] + + # Call internal method + result = middleware._remove_dummy_tool(mock_request) + + # Should only have real_tool + assert len(result) == 1 + assert result[0] is real_tool + + def test_wrap_tool_call_delegates_to_wrapper(self): + """Test that wrap_tool_call delegates to FoundryToolCallWrapper.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock tool call request + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "test_tool", "id": "call_1"} + mock_request.state = {} + mock_request.runtime = None + + # Mock handler + expected_result = ToolMessage(content="Result", tool_call_id="call_1") + mock_handler = MagicMock(return_value=expected_result) + + result = middleware.wrap_tool_call(mock_request, mock_handler) + + # Handler should be called + mock_handler.assert_called_once() + assert result == expected_result + + @pytest.mark.asyncio + async def test_awrap_tool_call_delegates_to_wrapper_async(self): + """Test that awrap_tool_call delegates to FoundryToolCallWrapper async.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock tool call request + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "test_tool", "id": "call_1"} + mock_request.state = {} + mock_request.runtime = None + + # Mock async handler + expected_result = ToolMessage(content="Async Result", 
tool_call_id="call_1") + mock_handler = AsyncMock(return_value=expected_result) + + result = await middleware.awrap_tool_call(mock_request, mock_handler) + + # Handler should be awaited + mock_handler.assert_awaited_once() + assert result == expected_result + + def test_middleware_with_multiple_foundry_tools(self): + """Test middleware initialization with multiple foundry tools.""" + foundry_tools = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "test"}, + ] + + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Should still only have one dummy tool + assert len(middleware.tools) == 1 + # But should have all foundry tools registered + assert len(middleware._foundry_tools_to_bind) == 2 + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_resolver.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_resolver.py new file mode 100644 index 000000000000..985ed4caec49 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_resolver.py @@ -0,0 +1,502 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for ResolvedTools and FoundryLangChainToolResolver.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from langchain_core.tools import BaseTool, StructuredTool, tool +from pydantic import BaseModel + +from azure.ai.agentserver.core.tools import (FoundryConnectedTool, FoundryHostedMcpTool, FoundryToolDetails, + ResolvedFoundryTool, SchemaDefinition, SchemaProperty, SchemaType) +from azure.ai.agentserver.langgraph.tools._resolver import ( + ResolvedTools, + FoundryLangChainToolResolver, + get_registry, +) + + +@pytest.mark.unit +class TestResolvedTools: + """Tests for ResolvedTools class.""" + + def test_create_empty_resolved_tools(self): + """Test creating an empty ResolvedTools.""" + resolved = ResolvedTools(tools=[]) + + tools_list = list(resolved) + assert len(tools_list) == 0 + + def test_create_with_single_tool( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test creating ResolvedTools with a single tool.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + tools_list = list(resolved) + assert len(tools_list) == 1 + assert tools_list[0] is mock_langchain_tool + + def test_create_with_multiple_tools( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_mcp_connected_tool: FoundryConnectedTool, + ): + """Test creating ResolvedTools with multiple tools.""" + @tool + def tool1(query: str) -> str: + """Tool 1.""" + return "result1" + + @tool + def tool2(query: str) -> str: + """Tool 2.""" + return "result2" 
+ + resolved_tool1 = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="tool1", + description="Tool 1", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + resolved_tool2 = ResolvedFoundryTool( + definition=sample_mcp_connected_tool, + details=FoundryToolDetails( + name="tool2", + description="Tool 2", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + + resolved = ResolvedTools(tools=[ + (resolved_tool1, tool1), + (resolved_tool2, tool2), + ]) + + tools_list = list(resolved) + assert len(tools_list) == 2 + + def test_get_tool_by_foundry_tool_like( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test getting tools by FoundryToolLike.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Get by the original foundry tool definition + tools = list(resolved.get(sample_code_interpreter_tool)) + assert len(tools) == 1 + assert tools[0] is mock_langchain_tool + + def test_get_tools_by_list_of_foundry_tools( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_mcp_connected_tool: FoundryConnectedTool, + ): + """Test getting tools by a list of FoundryToolLike.""" + @tool + def tool1(query: str) -> str: + """Tool 1.""" + return "result1" + + @tool + def tool2(query: str) -> str: + """Tool 2.""" + return 
"result2" + + resolved_tool1 = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="tool1", + description="Tool 1", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + resolved_tool2 = ResolvedFoundryTool( + definition=sample_mcp_connected_tool, + details=FoundryToolDetails( + name="tool2", + description="Tool 2", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + + resolved = ResolvedTools(tools=[ + (resolved_tool1, tool1), + (resolved_tool2, tool2), + ]) + + # Get by list of foundry tools + tools = list(resolved.get([sample_code_interpreter_tool, sample_mcp_connected_tool])) + assert len(tools) == 2 + + def test_get_all_tools_when_no_filter( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test getting all tools when no filter is provided.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Get all tools (no filter) + tools = list(resolved.get()) + assert len(tools) == 1 + + def test_get_returns_empty_for_unknown_tool( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_mcp_connected_tool: FoundryConnectedTool, + mock_langchain_tool: BaseTool, + ): + """Test that get returns empty when requesting unknown tool.""" + resolved_foundry_tool = ResolvedFoundryTool( + 
definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Get by a different foundry tool (not in resolved) + tools = list(resolved.get(sample_mcp_connected_tool)) + assert len(tools) == 0 + + def test_iteration_over_resolved_tools( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test iterating over ResolvedTools.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Iterate using for loop + count = 0 + for t in resolved: + assert t is mock_langchain_tool + count += 1 + assert count == 1 + + +@pytest.mark.unit +class TestFoundryLangChainToolResolver: + """Tests for FoundryLangChainToolResolver class.""" + + def test_init_with_default_name_resolver(self): + """Test initialization with default name resolver.""" + resolver = FoundryLangChainToolResolver() + + assert resolver._name_resolver is not None + + def test_init_with_custom_name_resolver(self): + """Test initialization with custom name resolver.""" + from azure.ai.agentserver.core.tools.utils import ToolNameResolver + + custom_resolver = ToolNameResolver() + resolver = FoundryLangChainToolResolver(name_resolver=custom_resolver) + + assert resolver._name_resolver is custom_resolver + + def 
test_create_pydantic_model_with_required_fields(self): + """Test creating a Pydantic model with required fields.""" + input_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + "limit": SchemaProperty(type=SchemaType.INTEGER, description="Max results"), + }, + required={"query"}, + ) + + model = FoundryLangChainToolResolver._create_pydantic_model("test_tool", input_schema) + + assert issubclass(model, BaseModel) + # Check that the model has the expected fields + assert "query" in model.model_fields + assert "limit" in model.model_fields + + def test_create_pydantic_model_with_no_required_fields(self): + """Test creating a Pydantic model with no required fields.""" + input_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + }, + required=set(), + ) + + model = FoundryLangChainToolResolver._create_pydantic_model("optional_tool", input_schema) + + assert issubclass(model, BaseModel) + assert "query" in model.model_fields + # Optional field should have None as default + assert model.model_fields["query"].default is None + + def test_create_pydantic_model_with_special_characters_in_name(self): + """Test creating a Pydantic model with special characters in tool name.""" + input_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ) + + model = FoundryLangChainToolResolver._create_pydantic_model("my-tool name", input_schema) + + assert issubclass(model, BaseModel) + # Name should be sanitized + assert "-Input" in model.__name__ or "Input" in model.__name__ + + def test_create_structured_tool(self): + """Test creating a StructuredTool from a resolved foundry tool.""" + resolver = FoundryLangChainToolResolver() + + foundry_tool = FoundryHostedMcpTool(name="test_tool", 
configuration={}) + resolved_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="search", + description="Search for documents", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + }, + required={"query"}, + ), + ), + ) + + structured_tool = resolver._create_structured_tool(resolved_tool) + + assert isinstance(structured_tool, StructuredTool) + assert structured_tool.description == "Search for documents" + assert structured_tool.coroutine is not None # Should have async function + + @pytest.mark.asyncio + async def test_resolve_from_registry(self): + """Test resolving tools from the global registry.""" + resolver = FoundryLangChainToolResolver() + + # Mock the AgentServerContext + mock_context = MagicMock() + mock_catalog = AsyncMock() + mock_context.tools.catalog.list = mock_catalog + + foundry_tool = FoundryHostedMcpTool(name="test_tool", configuration={}) + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="search", + description="Search tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Query"), + }, + required={"query"}, + ), + ), + ) + mock_catalog.return_value = [resolved_foundry_tool] + + # Add tool to registry + registry = get_registry() + registry.clear() + registry.append({"type": "code_interpreter"}) + + with patch("azure.ai.agentserver.langgraph.tools._resolver.AgentServerContext.get", return_value=mock_context): + result = await resolver.resolve_from_registry() + + assert isinstance(result, ResolvedTools) + mock_catalog.assert_called_once() + + # Clean up registry + registry.clear() + + @pytest.mark.asyncio + async def test_resolve_with_foundry_tools_list(self): + """Test resolving a list of foundry tools.""" + resolver = FoundryLangChainToolResolver() + + # 
Mock the AgentServerContext + mock_context = MagicMock() + mock_catalog = AsyncMock() + mock_context.tools.catalog.list = mock_catalog + + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="execute_code", + description="Execute code", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "code": SchemaProperty(type=SchemaType.STRING, description="Code to execute"), + }, + required={"code"}, + ), + ), + ) + mock_catalog.return_value = [resolved_foundry_tool] + + foundry_tools = [{"type": "code_interpreter"}] + + with patch("azure.ai.agentserver.langgraph.tools._resolver.AgentServerContext.get", return_value=mock_context): + result = await resolver.resolve(foundry_tools) + + assert isinstance(result, ResolvedTools) + tools_list = list(result) + assert len(tools_list) == 1 + assert isinstance(tools_list[0], StructuredTool) + + @pytest.mark.asyncio + async def test_resolve_empty_list(self): + """Test resolving an empty list of foundry tools.""" + resolver = FoundryLangChainToolResolver() + + # Mock the AgentServerContext + mock_context = MagicMock() + mock_catalog = AsyncMock() + mock_context.tools.catalog.list = mock_catalog + mock_catalog.return_value = [] + + with patch("azure.ai.agentserver.langgraph.tools._resolver.AgentServerContext.get", return_value=mock_context): + result = await resolver.resolve([]) + + assert isinstance(result, ResolvedTools) + tools_list = list(result) + assert len(tools_list) == 0 + + +@pytest.mark.unit +class TestGetRegistry: + """Tests for the get_registry function.""" + + def test_get_registry_returns_list(self): + """Test that get_registry returns a list.""" + registry = get_registry() + + assert isinstance(registry, list) + + def test_registry_is_singleton(self): + """Test that get_registry returns the same list instance.""" + registry1 = get_registry() + registry2 = 
get_registry() + + assert registry1 is registry2 + + def test_registry_can_be_modified(self): + """Test that the registry can be modified.""" + registry = get_registry() + original_length = len(registry) + + registry.append({"type": "test_tool"}) + + assert len(registry) == original_length + 1 + + # Clean up + registry.pop() + + def test_registry_extend(self): + """Test extending the registry with multiple tools.""" + registry = get_registry() + registry.clear() + + tools = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "test"}, + ] + registry.extend(tools) + + assert len(registry) == 2 + + # Clean up + registry.clear() diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_tool_node.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_tool_node.py new file mode 100644 index 000000000000..1c46e58785bc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/langgraph/tools/test_tool_node.py @@ -0,0 +1,179 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryToolCallWrapper and FoundryToolNodeWrappers.""" +import pytest +from typing import Any, List +from unittest.mock import AsyncMock, MagicMock + +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool +from langgraph.prebuilt.tool_node import ToolCallRequest +from langgraph.types import Command + +from azure.ai.agentserver.langgraph.tools._tool_node import ( + FoundryToolCallWrapper, + FoundryToolNodeWrappers, +) + + +@pytest.mark.unit +class TestFoundryToolCallWrapper: + """Tests for FoundryToolCallWrapper class.""" + + def test_as_wrappers_returns_typed_dict(self): + """Test that as_wrappers returns a FoundryToolNodeWrappers TypedDict.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + result = wrapper.as_wrappers() + + assert isinstance(result, dict) + assert "wrap_tool_call" in result + assert "awrap_tool_call" in result + assert callable(result["wrap_tool_call"]) + assert callable(result["awrap_tool_call"]) + + def test_call_tool_with_already_resolved_tool(self): + """Test that call_tool passes through when tool is already resolved.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + # Create request with tool already set + @tool + def existing_tool(x: str) -> str: + """Existing tool.""" + return f"Result: {x}" + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = existing_tool + mock_request.tool_call = {"name": "existing_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Result: test", tool_call_id="call_1") + mock_invocation = MagicMock(return_value=expected_result) + + result = wrapper.call_tool(mock_request, mock_invocation) + + # Should pass through original request + mock_invocation.assert_called_once_with(mock_request) + assert result == expected_result + + def 
test_call_tool_with_no_foundry_tools(self): + """Test that call_tool passes through when no foundry tools configured.""" + foundry_tools: List[Any] = [] + wrapper = FoundryToolCallWrapper(foundry_tools) + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "some_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Result", tool_call_id="call_1") + mock_invocation = MagicMock(return_value=expected_result) + + result = wrapper.call_tool(mock_request, mock_invocation) + + mock_invocation.assert_called_once_with(mock_request) + assert result == expected_result + + @pytest.mark.asyncio + async def test_call_tool_async_with_already_resolved_tool(self): + """Test that call_tool_async passes through when tool is already resolved.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + @tool + def existing_tool(x: str) -> str: + """Existing tool.""" + return f"Result: {x}" + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = existing_tool + mock_request.tool_call = {"name": "existing_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Async Result", tool_call_id="call_1") + mock_invocation = AsyncMock(return_value=expected_result) + + result = await wrapper.call_tool_async(mock_request, mock_invocation) + + mock_invocation.assert_awaited_once_with(mock_request) + assert result == expected_result + + @pytest.mark.asyncio + async def test_call_tool_async_with_no_foundry_tools(self): + """Test that call_tool_async passes through when no foundry tools configured.""" + foundry_tools: List[Any] = [] + wrapper = FoundryToolCallWrapper(foundry_tools) + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "some_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Result", tool_call_id="call_1") + mock_invocation = AsyncMock(return_value=expected_result) + + 
result = await wrapper.call_tool_async(mock_request, mock_invocation) + + mock_invocation.assert_awaited_once_with(mock_request) + assert result == expected_result + + def test_call_tool_returns_command_result(self): + """Test that call_tool can return Command objects.""" + foundry_tools: List[Any] = [] + wrapper = FoundryToolCallWrapper(foundry_tools) + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "some_tool", "id": "call_1"} + + # Return a Command instead of ToolMessage + expected_result = Command(goto="next_node") + mock_invocation = MagicMock(return_value=expected_result) + + result = wrapper.call_tool(mock_request, mock_invocation) + + assert result == expected_result + assert isinstance(result, Command) + + +@pytest.mark.unit +class TestFoundryToolNodeWrappers: + """Tests for FoundryToolNodeWrappers TypedDict.""" + + def test_foundry_tool_node_wrappers_structure(self): + """Test that FoundryToolNodeWrappers has the expected structure.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + wrappers: FoundryToolNodeWrappers = wrapper.as_wrappers() + + # Should have both sync and async wrappers + assert "wrap_tool_call" in wrappers + assert "awrap_tool_call" in wrappers + + # Should be the wrapper methods + assert wrappers["wrap_tool_call"] == wrapper.call_tool + assert wrappers["awrap_tool_call"] == wrapper.call_tool_async + + def test_wrappers_can_be_unpacked_to_tool_node(self): + """Test that wrappers can be unpacked as kwargs to ToolNode.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + wrappers = wrapper.as_wrappers() + + # Should be usable as kwargs + assert len(wrappers) == 2 + + # This pattern is used: ToolNode([], **wrappers) + def mock_tool_node_init(tools, wrap_tool_call=None, awrap_tool_call=None): + return { + "tools": tools, + "wrap_tool_call": wrap_tool_call, + 
"awrap_tool_call": awrap_tool_call, + } + + result = mock_tool_node_init([], **wrappers) + + assert result["wrap_tool_call"] is not None + assert result["awrap_tool_call"] is not None +