Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/mcp/server/mcpserver/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ class Prompt(BaseModel):
title: str | None = Field(None, description="Human-readable title of the prompt")
description: str | None = Field(None, description="Description of what the prompt does")
arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
fn: Callable[..., PromptResult] = Field(exclude=True)
icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this prompt")
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True)

@classmethod
def from_function(
cls,
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
fn: Callable[..., PromptResult],
name: str | None = None,
title: str | None = None,
description: str | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/mcpserver/resources/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def from_function(
context_kwarg=context_kwarg,
)

def matches(self, uri: str) -> dict[str, Any] | None:
def matches(self, uri: str) -> dict[str, str] | None:
"""Check if URI matches template and extract parameters.

Extracted parameters are URL-decoded to handle percent-encoded characters.
Expand Down
21 changes: 8 additions & 13 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

import base64
import inspect
import json
import re
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any, Generic, Literal, TypeVar, overload
from typing import Any, Generic, Literal, TypeVar, cast, overload

import anyio
import pydantic_core
Expand Down Expand Up @@ -36,6 +35,7 @@
from mcp.server.mcpserver.resources import FunctionResource, Resource, ResourceManager
from mcp.server.mcpserver.tools import Tool, ToolManager
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
from mcp.server.mcpserver.utilities.func_metadata import ConvertedToolResult
from mcp.server.mcpserver.utilities.logging import configure_logging, get_logger
from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server
Expand Down Expand Up @@ -308,18 +308,13 @@ async def _handle_call_tool(
if isinstance(result, CallToolResult):
return result
if isinstance(result, tuple) and len(result) == 2:
unstructured_content, structured_content = result
return CallToolResult(
content=list(unstructured_content), # type: ignore[arg-type]
structured_content=structured_content, # type: ignore[arg-type]
unstructured_content, structured_content = cast(
tuple[Sequence[ContentBlock], dict[str, Any]],
result,
)
if isinstance(result, dict): # pragma: no cover
# TODO: this code path is unreachable — convert_result never returns a raw dict.
# The call_tool return type (Sequence[ContentBlock] | dict[str, Any]) is wrong
# and needs to be cleaned up.
return CallToolResult(
content=[TextContent(type="text", text=json.dumps(result, indent=2))],
structured_content=result,
content=list(unstructured_content),
structured_content=structured_content,
)
return CallToolResult(content=list(result))

Expand Down Expand Up @@ -390,7 +385,7 @@ async def list_tools(self) -> list[MCPTool]:

async def call_tool(
self, name: str, arguments: dict[str, Any], context: Context[LifespanResultT, Any] | None = None
) -> Sequence[ContentBlock] | dict[str, Any]:
) -> ConvertedToolResult:
"""Call a tool by name with arguments."""
if context is None:
context = Context(mcp_server=self)
Expand Down
20 changes: 18 additions & 2 deletions src/mcp/server/mcpserver/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import inspect
from collections.abc import Callable
from functools import cached_property
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal, overload

from pydantic import BaseModel, Field

from mcp.server.mcpserver.exceptions import ToolError
from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata
from mcp.server.mcpserver.utilities.func_metadata import ConvertedToolResult, FuncMetadata, func_metadata
from mcp.shared.exceptions import UrlElicitationRequiredError
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
from mcp.types import Icon, ToolAnnotations
Expand Down Expand Up @@ -89,6 +89,22 @@ def from_function(
meta=meta,
)

@overload
async def run(
self,
arguments: dict[str, Any],
context: Context[LifespanContextT, RequestT],
convert_result: Literal[True],
) -> ConvertedToolResult: ...

@overload
async def run(
self,
arguments: dict[str, Any],
context: Context[LifespanContextT, RequestT],
convert_result: Literal[False] = False,
) -> Any: ...

async def run(
self,
arguments: dict[str, Any],
Expand Down
21 changes: 20 additions & 1 deletion src/mcp/server/mcpserver/tools/tool_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal, overload

from mcp.server.mcpserver.exceptions import ToolError
from mcp.server.mcpserver.tools.base import Tool
from mcp.server.mcpserver.utilities.func_metadata import ConvertedToolResult
from mcp.server.mcpserver.utilities.logging import get_logger
from mcp.types import Icon, ToolAnnotations

Expand Down Expand Up @@ -77,6 +78,24 @@ def remove_tool(self, name: str) -> None:
raise ToolError(f"Unknown tool: {name}")
del self._tools[name]

@overload
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
context: Context[LifespanContextT, RequestT],
convert_result: Literal[True],
) -> ConvertedToolResult: ...

@overload
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
context: Context[LifespanContextT, RequestT],
convert_result: Literal[False] = False,
) -> Any: ...

async def call_tool(
self,
name: str,
Expand Down
4 changes: 3 additions & 1 deletion src/mcp/server/mcpserver/utilities/func_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Awaitable, Callable, Sequence
from itertools import chain
from types import GenericAlias
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
from typing import Annotated, Any, TypeAlias, cast, get_args, get_origin, get_type_hints

import anyio
import anyio.to_thread
Expand All @@ -28,6 +28,8 @@

logger = get_logger(__name__)

ConvertedToolResult: TypeAlias = CallToolResult | Sequence[ContentBlock] | tuple[Sequence[ContentBlock], dict[str, Any]]


class StrictJsonSchema(GenerateJsonSchema):
"""A JSON schema generator that raises exceptions instead of emitting warnings.
Expand Down
38 changes: 32 additions & 6 deletions tests/server/mcpserver/test_tool_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import logging
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, TypedDict
from typing import Any, TypedDict, cast

import pytest
from pydantic import BaseModel
Expand All @@ -12,7 +13,12 @@
from mcp.server.mcpserver.tools import Tool, ToolManager
from mcp.server.mcpserver.utilities.func_metadata import ArgModelBase, FuncMetadata
from mcp.server.session import ServerSessionT
from mcp.types import TextContent, ToolAnnotations
from mcp.types import ContentBlock, TextContent, ToolAnnotations


def _text_contents(unstructured_content: Sequence[ContentBlock]) -> list[TextContent]:
assert all(isinstance(item, TextContent) for item in unstructured_content)
return [cast(TextContent, item) for item in unstructured_content]


class TestAddTools:
Expand Down Expand Up @@ -456,7 +462,12 @@ def get_user(user_id: int) -> UserOutput:
manager.add_tool(get_user)
result = await manager.call_tool("get_user", {"user_id": 1}, Context(), convert_result=True)
# don't test unstructured output here, just the structured conversion
assert len(result) == 2 and result[1] == {"name": "John", "age": 30}
assert isinstance(result, tuple)
assert len(result) == 2
unstructured_content, structured_content = cast(tuple[Sequence[ContentBlock], dict[str, Any]], result)
text_items = _text_contents(unstructured_content)
assert structured_content == {"name": "John", "age": 30}
assert json.loads(text_items[0].text) == structured_content

@pytest.mark.anyio
async def test_tool_with_primitive_output(self):
Expand All @@ -471,7 +482,12 @@ def double_number(n: int) -> int:
result = await manager.call_tool("double_number", {"n": 5}, Context())
assert result == 10
result = await manager.call_tool("double_number", {"n": 5}, Context(), convert_result=True)
assert isinstance(result[0][0], TextContent) and result[1] == {"result": 10}
assert isinstance(result, tuple)
assert len(result) == 2
unstructured_content, structured_content = cast(tuple[Sequence[ContentBlock], dict[str, Any]], result)
text_items = _text_contents(unstructured_content)
assert text_items[0].text == "10"
assert structured_content == {"result": 10}

@pytest.mark.anyio
async def test_tool_with_typeddict_output(self):
Expand Down Expand Up @@ -511,7 +527,12 @@ def get_person() -> Person:
manager.add_tool(get_person)
result = await manager.call_tool("get_person", {}, Context(), convert_result=True)
# don't test unstructured output here, just the structured conversion
assert len(result) == 2 and result[1] == expected_output
assert isinstance(result, tuple)
assert len(result) == 2
unstructured_content, structured_content = cast(tuple[Sequence[ContentBlock], dict[str, Any]], result)
text_items = _text_contents(unstructured_content)
assert structured_content == expected_output
assert json.loads(text_items[0].text) == structured_content

@pytest.mark.anyio
async def test_tool_with_list_output(self):
Expand All @@ -529,7 +550,12 @@ def get_numbers() -> list[int]:
result = await manager.call_tool("get_numbers", {}, Context())
assert result == expected_list
result = await manager.call_tool("get_numbers", {}, Context(), convert_result=True)
assert isinstance(result[0][0], TextContent) and result[1] == expected_output
assert isinstance(result, tuple)
assert len(result) == 2
unstructured_content, structured_content = cast(tuple[Sequence[ContentBlock], dict[str, Any]], result)
text_items = _text_contents(unstructured_content)
assert [item.text for item in text_items] == ["1", "2", "3", "4", "5"]
assert structured_content == expected_output

@pytest.mark.anyio
async def test_tool_without_structured_output(self):
Expand Down