From e4889b75d44d9c521a5137b048cbc3221ef1ef82 Mon Sep 17 00:00:00 2001 From: Lovish Arora <46993225+lavish0000@users.noreply.github.com> Date: Fri, 6 Mar 2026 05:59:31 +0100 Subject: [PATCH 1/2] fix(types): type converted MCPServer handler results --- src/mcp/server/mcpserver/prompts/base.py | 4 ++-- .../server/mcpserver/resources/templates.py | 2 +- src/mcp/server/mcpserver/server.py | 21 +++++++------------ src/mcp/server/mcpserver/tools/base.py | 20 ++++++++++++++++-- .../server/mcpserver/tools/tool_manager.py | 21 ++++++++++++++++++- .../mcpserver/utilities/func_metadata.py | 4 +++- 6 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index 0c319d53c..eb5ee7796 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -69,14 +69,14 @@ class Prompt(BaseModel): title: str | None = Field(None, description="Human-readable title of the prompt") description: str | None = Field(None, description="Description of what the prompt does") arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt") - fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True) + fn: Callable[..., PromptResult] = Field(exclude=True) icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this prompt") context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True) @classmethod def from_function( cls, - fn: Callable[..., PromptResult | Awaitable[PromptResult]], + fn: Callable[..., PromptResult], name: str | None = None, title: str | None = None, description: str | None = None, diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index 2d612657c..21f1172f1 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -82,7 +82,7 @@ def from_function( context_kwarg=context_kwarg, ) - def matches(self, uri: str) -> dict[str, Any] | None: + def matches(self, uri: str) -> dict[str, str] | None: """Check if URI matches template and extract parameters. Extracted parameters are URL-decoded to handle percent-encoded characters. diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 2a7a58117..cdee88c76 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -4,11 +4,10 @@ import base64 import inspect -import json import re from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import AbstractAsyncContextManager, asynccontextmanager -from typing import Any, Generic, Literal, TypeVar, overload +from typing import Any, Generic, Literal, TypeVar, cast, overload import anyio import pydantic_core @@ -36,6 +35,7 @@ from mcp.server.mcpserver.resources import FunctionResource, Resource, ResourceManager from mcp.server.mcpserver.tools import Tool, ToolManager from mcp.server.mcpserver.utilities.context_injection import find_context_parameter +from mcp.server.mcpserver.utilities.func_metadata import ConvertedToolResult from mcp.server.mcpserver.utilities.logging import configure_logging, get_logger from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server @@ -308,18 +308,13 @@ async def _handle_call_tool( if isinstance(result, CallToolResult): return result if isinstance(result, tuple) and len(result) == 2: - unstructured_content, structured_content = result - return CallToolResult( - content=list(unstructured_content), # type: ignore[arg-type] - structured_content=structured_content, # type: ignore[arg-type] + unstructured_content, structured_content = cast( + tuple[Sequence[ContentBlock], dict[str, Any]], + result, ) - if isinstance(result, dict): # pragma: no cover - # TODO: this code path is unreachable — convert_result never returns a raw dict. - # The call_tool return type (Sequence[ContentBlock] | dict[str, Any]) is wrong - # and needs to be cleaned up. return CallToolResult( - content=[TextContent(type="text", text=json.dumps(result, indent=2))], - structured_content=result, + content=list(unstructured_content), + structured_content=structured_content, ) return CallToolResult(content=list(result)) @@ -390,7 +385,7 @@ async def list_tools(self) -> list[MCPTool]: async def call_tool( self, name: str, arguments: dict[str, Any], context: Context[LifespanResultT, Any] | None = None - ) -> Sequence[ContentBlock] | dict[str, Any]: + ) -> ConvertedToolResult: """Call a tool by name with arguments.""" if context is None: context = Context(mcp_server=self) diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index dc65be988..0d38fe5c0 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -4,13 +4,13 @@ import inspect from collections.abc import Callable from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, overload from pydantic import BaseModel, Field from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.utilities.context_injection import find_context_parameter -from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata +from mcp.server.mcpserver.utilities.func_metadata import ConvertedToolResult, FuncMetadata, func_metadata from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations @@ -89,6 +89,22 @@ def from_function( meta=meta, ) + @overload + async def run( + self, + arguments: dict[str, Any], + context: Context[LifespanContextT, RequestT], + convert_result: Literal[True], + ) -> ConvertedToolResult: ... + + @overload + async def run( + self, + arguments: dict[str, Any], + context: Context[LifespanContextT, RequestT], + convert_result: Literal[False] = False, + ) -> Any: ... + async def run( self, arguments: dict[str, Any], diff --git a/src/mcp/server/mcpserver/tools/tool_manager.py b/src/mcp/server/mcpserver/tools/tool_manager.py index 32ed54797..64ed5e580 100644 --- a/src/mcp/server/mcpserver/tools/tool_manager.py +++ b/src/mcp/server/mcpserver/tools/tool_manager.py @@ -1,10 +1,11 @@ from __future__ import annotations from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, overload from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.tools.base import Tool +from mcp.server.mcpserver.utilities.func_metadata import ConvertedToolResult from mcp.server.mcpserver.utilities.logging import get_logger from mcp.types import Icon, ToolAnnotations @@ -77,6 +78,24 @@ def remove_tool(self, name: str) -> None: raise ToolError(f"Unknown tool: {name}") del self._tools[name] + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any], + context: Context[LifespanContextT, RequestT], + convert_result: Literal[True], + ) -> ConvertedToolResult: ... + + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any], + context: Context[LifespanContextT, RequestT], + convert_result: Literal[False] = False, + ) -> Any: ... + async def call_tool( self, name: str, diff --git a/src/mcp/server/mcpserver/utilities/func_metadata.py b/src/mcp/server/mcpserver/utilities/func_metadata.py index 062b47d0f..f0123ef88 100644 --- a/src/mcp/server/mcpserver/utilities/func_metadata.py +++ b/src/mcp/server/mcpserver/utilities/func_metadata.py @@ -4,7 +4,7 @@ from collections.abc import Awaitable, Callable, Sequence from itertools import chain from types import GenericAlias -from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints +from typing import Annotated, Any, TypeAlias, cast, get_args, get_origin, get_type_hints import anyio import anyio.to_thread @@ -28,6 +28,8 @@ logger = get_logger(__name__) +ConvertedToolResult: TypeAlias = CallToolResult | Sequence[ContentBlock] | tuple[Sequence[ContentBlock], dict[str, Any]] + class StrictJsonSchema(GenerateJsonSchema): """A JSON schema generator that raises exceptions instead of emitting warnings. From 23b75b8980dbdde375e210f604df960d5c18f541 Mon Sep 17 00:00:00 2001 From: Lovish Arora <46993225+lavish0000@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:38:31 +0100 Subject: [PATCH 2/2] test: align structured tool result assertions with types --- tests/server/mcpserver/test_tool_manager.py | 38 +++++++++++++++++---- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index f990ec47b..645e1d79d 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -1,7 +1,8 @@ import json import logging +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, TypedDict +from typing import Any, TypedDict, cast import pytest from pydantic import BaseModel @@ -12,7 +13,12 @@ from mcp.server.mcpserver.tools import Tool, ToolManager from mcp.server.mcpserver.utilities.func_metadata import ArgModelBase, FuncMetadata from mcp.server.session import ServerSessionT -from mcp.types import TextContent, ToolAnnotations +from mcp.types import ContentBlock, TextContent, ToolAnnotations + + +def _text_contents(unstructured_content: Sequence[ContentBlock]) -> list[TextContent]: + assert all(isinstance(item, TextContent) for item in unstructured_content) + return [cast(TextContent, item) for item in unstructured_content] class TestAddTools: @@ -456,7 +462,12 @@ def get_user(user_id: int) -> UserOutput: manager.add_tool(get_user) result = await manager.call_tool("get_user", {"user_id": 1}, Context(), convert_result=True) # don't test unstructured output here, just the structured conversion - assert len(result) == 2 and result[1] == {"name": "John", "age": 30} + assert isinstance(result, tuple) + assert len(result) == 2 + unstructured_content, structured_content = cast(tuple[Sequence[ContentBlock], dict[str, Any]], result) + text_items = _text_contents(unstructured_content) + assert structured_content == {"name": "John", "age": 30} + assert json.loads(text_items[0].text) == structured_content @pytest.mark.anyio async def test_tool_with_primitive_output(self): @@ -471,7 +482,12 @@ def double_number(n: int) -> int: result = await manager.call_tool("double_number", {"n": 5}, Context()) assert result == 10 result = await manager.call_tool("double_number", {"n": 5}, Context(), convert_result=True) - assert isinstance(result[0][0], TextContent) and result[1] == {"result": 10} + assert isinstance(result, tuple) + assert len(result) == 2 + unstructured_content, structured_content = cast(tuple[Sequence[ContentBlock], dict[str, Any]], result) + text_items = _text_contents(unstructured_content) + assert text_items[0].text == "10" + assert structured_content == {"result": 10} @pytest.mark.anyio async def test_tool_with_typeddict_output(self): @@ -511,7 +527,12 @@ def get_person() -> Person: manager.add_tool(get_person) result = await manager.call_tool("get_person", {}, Context(), convert_result=True) # don't test unstructured output here, just the structured conversion - assert len(result) == 2 and result[1] == expected_output + assert isinstance(result, tuple) + assert len(result) == 2 + unstructured_content, structured_content = cast(tuple[Sequence[ContentBlock], dict[str, Any]], result) + text_items = _text_contents(unstructured_content) + assert structured_content == expected_output + assert json.loads(text_items[0].text) == structured_content @pytest.mark.anyio async def test_tool_with_list_output(self): @@ -529,7 +550,12 @@ def get_numbers() -> list[int]: result = await manager.call_tool("get_numbers", {}, Context()) assert result == expected_list result = await manager.call_tool("get_numbers", {}, Context(), convert_result=True) - assert isinstance(result[0][0], TextContent) and result[1] == expected_output + assert isinstance(result, tuple) + assert len(result) == 2 + unstructured_content, structured_content = cast(tuple[Sequence[ContentBlock], dict[str, Any]], result) + text_items = _text_contents(unstructured_content) + assert [item.text for item in text_items] == ["1", "2", "3", "4", "5"] + assert structured_content == expected_output @pytest.mark.anyio async def test_tool_without_structured_output(self):