From f4760f131b4779a4328fb0c4e5e73ebfc7a593cc Mon Sep 17 00:00:00 2001 From: Akshay Kumar BM <90820098+akshaykumarbedre@users.noreply.github.com> Date: Sun, 18 Jan 2026 02:16:07 +0530 Subject: [PATCH] fix: Make context logging functions spec-compliant by accepting Any type for data --- src/mcp/server/fastmcp/server.py | 1347 ------------------------ tests/server/fastmcp/test_server.py | 1502 --------------------------- 2 files changed, 2849 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 75f2d2237f..e69de29bb2 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -1,1347 +0,0 @@ -"""FastMCP - A more ergonomic interface for MCP servers.""" - -from __future__ import annotations - -import inspect -import re -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence -from contextlib import AbstractAsyncContextManager, asynccontextmanager -from typing import Any, Generic, Literal, overload - -import anyio -import pydantic_core -from pydantic import BaseModel -from pydantic.networks import AnyUrl -from pydantic_settings import BaseSettings, SettingsConfigDict -from starlette.applications import Starlette -from starlette.middleware import Middleware -from starlette.middleware.authentication import AuthenticationMiddleware -from starlette.requests import Request -from starlette.responses import Response -from starlette.routing import Mount, Route -from starlette.types import Receive, Scope, Send - -from mcp.server.auth.middleware.auth_context import AuthContextMiddleware -from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware -from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier -from mcp.server.auth.settings import AuthSettings -from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, UrlElicitationResult, elicit_with_validation -from mcp.server.elicitation import elicit_url as _elicit_url -from mcp.server.fastmcp.exceptions import ResourceError -from mcp.server.fastmcp.prompts import Prompt, PromptManager -from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager -from mcp.server.fastmcp.tools import Tool, ToolManager -from mcp.server.fastmcp.utilities.context_injection import find_context_parameter -from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger -from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import LifespanResultT -from mcp.server.lowlevel.server import Server as MCPServer -from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.session import ServerSession, ServerSessionT -from mcp.server.sse import SseServerTransport -from mcp.server.stdio import stdio_server -from mcp.server.streamable_http import EventStore -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.context import LifespanContextT, RequestContext, RequestT -from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations -from mcp.types import Prompt as MCPPrompt -from mcp.types import PromptArgument as MCPPromptArgument -from mcp.types import Resource as MCPResource -from mcp.types import ResourceTemplate as MCPResourceTemplate -from mcp.types import Tool as MCPTool - -logger = get_logger(__name__) - - -class Settings(BaseSettings, Generic[LifespanResultT]): - """FastMCP server settings. - - All settings can be configured via environment variables with the prefix FASTMCP_. - For example, FASTMCP_DEBUG=true will set debug=True. - """ - - model_config = SettingsConfigDict( - env_prefix="FASTMCP_", - env_file=".env", - env_nested_delimiter="__", - nested_model_default_partial_update=True, - extra="ignore", - ) - - # Server settings - debug: bool - log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - - # resource settings - warn_on_duplicate_resources: bool - - # tool settings - warn_on_duplicate_tools: bool - - # prompt settings - warn_on_duplicate_prompts: bool - - lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None - """A async context manager that will be called when the server is started.""" - - auth: AuthSettings | None - - -def lifespan_wrapper( - app: FastMCP[LifespanResultT], - lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: - @asynccontextmanager - async def wrap( - _: MCPServer[LifespanResultT, Request], - ) -> AsyncIterator[LifespanResultT]: - async with lifespan(app) as context: - yield context - - return wrap - - -class FastMCP(Generic[LifespanResultT]): - def __init__( - self, - name: str | None = None, - title: str | None = None, - description: str | None = None, - instructions: str | None = None, - website_url: str | None = None, - icons: list[Icon] | None = None, - version: str | None = None, - auth_server_provider: (OAuthAuthorizationServerProvider[Any, Any, Any] | None) = None, - token_verifier: TokenVerifier | None = None, - *, - tools: list[Tool] | None = None, - debug: bool = False, - log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", - warn_on_duplicate_resources: bool = True, - warn_on_duplicate_tools: bool = True, - warn_on_duplicate_prompts: bool = True, - lifespan: (Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None) = None, - auth: AuthSettings | None = None, - ): - self.settings = Settings( - debug=debug, - log_level=log_level, - warn_on_duplicate_resources=warn_on_duplicate_resources, - warn_on_duplicate_tools=warn_on_duplicate_tools, - warn_on_duplicate_prompts=warn_on_duplicate_prompts, - lifespan=lifespan, - auth=auth, - ) - - self._mcp_server = MCPServer( - name=name or "FastMCP", - title=title, - description=description, - instructions=instructions, - website_url=website_url, - icons=icons, - version=version, - # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server. - # We need to create a Lifespan type that is a generic on the server type, like Starlette does. - lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore - ) - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) - self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) - # Validate auth configuration - if self.settings.auth is not None: - if auth_server_provider and token_verifier: # pragma: no cover - raise ValueError("Cannot specify both auth_server_provider and token_verifier") - if not auth_server_provider and not token_verifier: # pragma: no cover - raise ValueError("Must specify either auth_server_provider or token_verifier when auth is enabled") - elif auth_server_provider or token_verifier: # pragma: no cover - raise ValueError("Cannot specify auth_server_provider or token_verifier without auth settings") - - self._auth_server_provider = auth_server_provider - self._token_verifier = token_verifier - - # Create token verifier from provider if needed (backwards compatibility) - if auth_server_provider and not token_verifier: # pragma: no cover - self._token_verifier = ProviderTokenVerifier(auth_server_provider) - self._custom_starlette_routes: list[Route] = [] - self._session_manager: StreamableHTTPSessionManager | None = None - - # Set up MCP protocol handlers - self._setup_handlers() - - # Configure logging - configure_logging(self.settings.log_level) - - @property - def name(self) -> str: - return self._mcp_server.name - - @property - def title(self) -> str | None: - return self._mcp_server.title - - @property - def description(self) -> str | None: - return self._mcp_server.description - - @property - def instructions(self) -> str | None: - return self._mcp_server.instructions - - @property - def website_url(self) -> str | None: - return self._mcp_server.website_url - - @property - def icons(self) -> list[Icon] | None: - return self._mcp_server.icons - - @property - def version(self) -> str | None: - return self._mcp_server.version - - @property - def session_manager(self) -> StreamableHTTPSessionManager: - """Get the StreamableHTTP session manager. - - This is exposed to enable advanced use cases like mounting multiple - FastMCP servers in a single FastAPI application. - - Raises: - RuntimeError: If called before streamable_http_app() has been called. - """ - if self._session_manager is None: # pragma: no cover - raise RuntimeError( - "Session manager can only be accessed after" - "calling streamable_http_app()." - "The session manager is created lazily" - "to avoid unnecessary initialization." - ) - return self._session_manager # pragma: no cover - - @overload - def run(self, transport: Literal["stdio"] = ...) -> None: ... - - @overload - def run( - self, - transport: Literal["sse"], - *, - host: str = ..., - port: int = ..., - sse_path: str = ..., - message_path: str = ..., - transport_security: TransportSecuritySettings | None = ..., - ) -> None: ... - - @overload - def run( - self, - transport: Literal["streamable-http"], - *, - host: str = ..., - port: int = ..., - streamable_http_path: str = ..., - json_response: bool = ..., - stateless_http: bool = ..., - event_store: EventStore | None = ..., - retry_interval: int | None = ..., - transport_security: TransportSecuritySettings | None = ..., - ) -> None: ... - - def run( - self, - transport: Literal["stdio", "sse", "streamable-http"] = "stdio", - **kwargs: Any, - ) -> None: - """Run the FastMCP server. Note this is a synchronous function. - - Args: - transport: Transport protocol to use ("stdio", "sse", or "streamable-http") - **kwargs: Transport-specific options (see overloads for details) - """ - TRANSPORTS = Literal["stdio", "sse", "streamable-http"] - if transport not in TRANSPORTS.__args__: # type: ignore # pragma: no cover - raise ValueError(f"Unknown transport: {transport}") - - match transport: - case "stdio": - anyio.run(self.run_stdio_async) - case "sse": # pragma: no cover - anyio.run(lambda: self.run_sse_async(**kwargs)) - case "streamable-http": # pragma: no cover - anyio.run(lambda: self.run_streamable_http_async(**kwargs)) - - def _setup_handlers(self) -> None: - """Set up core MCP protocol handlers.""" - self._mcp_server.list_tools()(self.list_tools) - # Note: we disable the lowlevel server's input validation. - # FastMCP does ad hoc conversion of incoming data before validating - - # for now we preserve this for backwards compatibility. - self._mcp_server.call_tool(validate_input=False)(self.call_tool) - self._mcp_server.list_resources()(self.list_resources) - self._mcp_server.read_resource()(self.read_resource) - self._mcp_server.list_prompts()(self.list_prompts) - self._mcp_server.get_prompt()(self.get_prompt) - self._mcp_server.list_resource_templates()(self.list_resource_templates) - - async def list_tools(self) -> list[MCPTool]: - """List all available tools.""" - tools = self._tool_manager.list_tools() - return [ - MCPTool( - name=info.name, - title=info.title, - description=info.description, - input_schema=info.parameters, - output_schema=info.output_schema, - annotations=info.annotations, - icons=info.icons, - _meta=info.meta, - ) - for info in tools - ] - - def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: - """Returns a Context object. Note that the context will only be valid - during a request; outside a request, most methods will error. - """ - try: - request_context = self._mcp_server.request_context - except LookupError: - request_context = None - return Context(request_context=request_context, fastmcp=self) - - async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock] | dict[str, Any]: - """Call a tool by name with arguments.""" - context = self.get_context() - return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) - - async def list_resources(self) -> list[MCPResource]: - """List all available resources.""" - - resources = self._resource_manager.list_resources() - return [ - MCPResource( - uri=resource.uri, - name=resource.name or "", - title=resource.title, - description=resource.description, - mime_type=resource.mime_type, - icons=resource.icons, - annotations=resource.annotations, - _meta=resource.meta, - ) - for resource in resources - ] - - async def list_resource_templates(self) -> list[MCPResourceTemplate]: - templates = self._resource_manager.list_templates() - return [ - MCPResourceTemplate( - uri_template=template.uri_template, - name=template.name, - title=template.title, - description=template.description, - mime_type=template.mime_type, - icons=template.icons, - annotations=template.annotations, - _meta=template.meta, - ) - for template in templates - ] - - async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]: - """Read a resource by URI.""" - - context = self.get_context() - resource = await self._resource_manager.get_resource(uri, context=context) - if not resource: # pragma: no cover - raise ResourceError(f"Unknown resource: {uri}") - - try: - content = await resource.read() - return [ReadResourceContents(content=content, mime_type=resource.mime_type, meta=resource.meta)] - except Exception as e: # pragma: no cover - logger.exception(f"Error reading resource {uri}") - raise ResourceError(str(e)) - - def add_tool( - self, - fn: AnyFunction, - name: str | None = None, - title: str | None = None, - description: str | None = None, - annotations: ToolAnnotations | None = None, - icons: list[Icon] | None = None, - meta: dict[str, Any] | None = None, - structured_output: bool | None = None, - ) -> None: - """Add a tool to the server. - - The tool function can optionally request a Context object by adding a parameter - with the Context type annotation. See the @tool decorator for examples. - - Args: - fn: The function to register as a tool - name: Optional name for the tool (defaults to function name) - title: Optional human-readable title for the tool - description: Optional description of what the tool does - annotations: Optional ToolAnnotations providing additional tool information - structured_output: Controls whether the tool's output is structured or unstructured - - If None, auto-detects based on the function's return type annotation - - If True, creates a structured tool (return type annotation permitting) - - If False, unconditionally creates an unstructured tool - """ - self._tool_manager.add_tool( - fn, - name=name, - title=title, - description=description, - annotations=annotations, - icons=icons, - meta=meta, - structured_output=structured_output, - ) - - def remove_tool(self, name: str) -> None: - """Remove a tool from the server by name. - - Args: - name: The name of the tool to remove - - Raises: - ToolError: If the tool does not exist - """ - self._tool_manager.remove_tool(name) - - def tool( - self, - name: str | None = None, - title: str | None = None, - description: str | None = None, - annotations: ToolAnnotations | None = None, - icons: list[Icon] | None = None, - meta: dict[str, Any] | None = None, - structured_output: bool | None = None, - ) -> Callable[[AnyFunction], AnyFunction]: - """Decorator to register a tool. - - Tools can optionally request a Context object by adding a parameter with the - Context type annotation. The context provides access to MCP capabilities like - logging, progress reporting, and resource access. - - Args: - name: Optional name for the tool (defaults to function name) - title: Optional human-readable title for the tool - description: Optional description of what the tool does - annotations: Optional ToolAnnotations providing additional tool information - structured_output: Controls whether the tool's output is structured or unstructured - - If None, auto-detects based on the function's return type annotation - - If True, creates a structured tool (return type annotation permitting) - - If False, unconditionally creates an unstructured tool - - Example: - @server.tool() - def my_tool(x: int) -> str: - return str(x) - - @server.tool() - def tool_with_context(x: int, ctx: Context) -> str: - ctx.info(f"Processing {x}") - return str(x) - - @server.tool() - async def async_tool(x: int, context: Context) -> str: - await context.report_progress(50, 100) - return str(x) - """ - # Check if user passed function directly instead of calling decorator - if callable(name): - raise TypeError( - "The @tool decorator was used incorrectly. Did you forget to call it? Use @tool() instead of @tool" - ) - - def decorator(fn: AnyFunction) -> AnyFunction: - self.add_tool( - fn, - name=name, - title=title, - description=description, - annotations=annotations, - icons=icons, - meta=meta, - structured_output=structured_output, - ) - return fn - - return decorator - - def completion(self): - """Decorator to register a completion handler. - - The completion handler receives: - - ref: PromptReference or ResourceTemplateReference - - argument: CompletionArgument with name and partial value - - context: Optional CompletionContext with previously resolved arguments - - Example: - @mcp.completion() - async def handle_completion(ref, argument, context): - if isinstance(ref, ResourceTemplateReference): - # Return completions based on ref, argument, and context - return Completion(values=["option1", "option2"]) - return None - """ - return self._mcp_server.completion() - - def add_resource(self, resource: Resource) -> None: - """Add a resource to the server. - - Args: - resource: A Resource instance to add - """ - self._resource_manager.add_resource(resource) - - def resource( - self, - uri: str, - *, - name: str | None = None, - title: str | None = None, - description: str | None = None, - mime_type: str | None = None, - icons: list[Icon] | None = None, - annotations: Annotations | None = None, - meta: dict[str, Any] | None = None, - ) -> Callable[[AnyFunction], AnyFunction]: - """Decorator to register a function as a resource. - - The function will be called when the resource is read to generate its content. - The function can return: - - str for text content - - bytes for binary content - - other types will be converted to JSON - - If the URI contains parameters (e.g. "resource://{param}") or the function - has parameters, it will be registered as a template resource. - - Args: - uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}") - name: Optional name for the resource - title: Optional human-readable title for the resource - description: Optional description of the resource - mime_type: Optional MIME type for the resource - meta: Optional metadata dictionary for the resource - - Example: - @server.resource("resource://my-resource") - def get_data() -> str: - return "Hello, world!" - - @server.resource("resource://my-resource") - async get_data() -> str: - data = await fetch_data() - return f"Hello, world! {data}" - - @server.resource("resource://{city}/weather") - def get_weather(city: str) -> str: - return f"Weather for {city}" - - @server.resource("resource://{city}/weather") - async def get_weather(city: str) -> str: - data = await fetch_weather(city) - return f"Weather for {city}: {data}" - """ - # Check if user passed function directly instead of calling decorator - if callable(uri): - raise TypeError( - "The @resource decorator was used incorrectly. " - "Did you forget to call it? Use @resource('uri') instead of @resource" - ) - - def decorator(fn: AnyFunction) -> AnyFunction: - # Check if this should be a template - sig = inspect.signature(fn) - has_uri_params = "{" in uri and "}" in uri - has_func_params = bool(sig.parameters) - - if has_uri_params or has_func_params: - # Check for Context parameter to exclude from validation - context_param = find_context_parameter(fn) - - # Validate that URI params match function params (excluding context) - uri_params = set(re.findall(r"{(\w+)}", uri)) - # We need to remove the context_param from the resource function if - # there is any. - func_params = {p for p in sig.parameters.keys() if p != context_param} - - if uri_params != func_params: - raise ValueError( - f"Mismatch between URI parameters {uri_params} and function parameters {func_params}" - ) - - # Register as template - self._resource_manager.add_template( - fn=fn, - uri_template=uri, - name=name, - title=title, - description=description, - mime_type=mime_type, - icons=icons, - annotations=annotations, - meta=meta, - ) - else: - # Register as regular resource - resource = FunctionResource.from_function( - fn=fn, - uri=uri, - name=name, - title=title, - description=description, - mime_type=mime_type, - icons=icons, - annotations=annotations, - meta=meta, - ) - self.add_resource(resource) - return fn - - return decorator - - def add_prompt(self, prompt: Prompt) -> None: - """Add a prompt to the server. - - Args: - prompt: A Prompt instance to add - """ - self._prompt_manager.add_prompt(prompt) - - def prompt( - self, - name: str | None = None, - title: str | None = None, - description: str | None = None, - icons: list[Icon] | None = None, - ) -> Callable[[AnyFunction], AnyFunction]: - """Decorator to register a prompt. - - Args: - name: Optional name for the prompt (defaults to function name) - title: Optional human-readable title for the prompt - description: Optional description of what the prompt does - - Example: - @server.prompt() - def analyze_table(table_name: str) -> list[Message]: - schema = read_table_schema(table_name) - return [ - { - "role": "user", - "content": f"Analyze this schema:\n{schema}" - } - ] - - @server.prompt() - async def analyze_file(path: str) -> list[Message]: - content = await read_file(path) - return [ - { - "role": "user", - "content": { - "type": "resource", - "resource": { - "uri": f"file://{path}", - "text": content - } - } - } - ] - """ - # Check if user passed function directly instead of calling decorator - if callable(name): - raise TypeError( - "The @prompt decorator was used incorrectly. " - "Did you forget to call it? Use @prompt() instead of @prompt" - ) - - def decorator(func: AnyFunction) -> AnyFunction: - prompt = Prompt.from_function(func, name=name, title=title, description=description, icons=icons) - self.add_prompt(prompt) - return func - - return decorator - - def custom_route( - self, - path: str, - methods: list[str], - name: str | None = None, - include_in_schema: bool = True, - ): - """Decorator to register a custom HTTP route on the FastMCP server. - - Allows adding arbitrary HTTP endpoints outside the standard MCP protocol, - which can be useful for OAuth callbacks, health checks, or admin APIs. - The handler function must be an async function that accepts a Starlette - Request and returns a Response. - - Routes using this decorator will not require authorization. It is intended - for uses that are either a part of authorization flows or intended to be - public such as health check endpoints. - - Args: - path: URL path for the route (e.g., "/oauth/callback") - methods: List of HTTP methods to support (e.g., ["GET", "POST"]) - name: Optional name for the route (to reference this route with - Starlette's reverse URL lookup feature) - include_in_schema: Whether to include in OpenAPI schema, defaults to True - - Example: - @server.custom_route("/health", methods=["GET"]) - async def health_check(request: Request) -> Response: - return JSONResponse({"status": "ok"}) - """ - - def decorator( # pragma: no cover - func: Callable[[Request], Awaitable[Response]], - ) -> Callable[[Request], Awaitable[Response]]: - self._custom_starlette_routes.append( - Route( - path, - endpoint=func, - methods=methods, - name=name, - include_in_schema=include_in_schema, - ) - ) - return func - - return decorator # pragma: no cover - - async def run_stdio_async(self) -> None: - """Run the server using stdio transport.""" - async with stdio_server() as (read_stream, write_stream): - await self._mcp_server.run( - read_stream, - write_stream, - self._mcp_server.create_initialization_options(), - ) - - async def run_sse_async( # pragma: no cover - self, - *, - host: str = "127.0.0.1", - port: int = 8000, - sse_path: str = "/sse", - message_path: str = "/messages/", - transport_security: TransportSecuritySettings | None = None, - ) -> None: - """Run the server using SSE transport.""" - import uvicorn - - starlette_app = self.sse_app( - sse_path=sse_path, - message_path=message_path, - transport_security=transport_security, - host=host, - ) - - config = uvicorn.Config( - starlette_app, - host=host, - port=port, - log_level=self.settings.log_level.lower(), - ) - server = uvicorn.Server(config) - await server.serve() - - async def run_streamable_http_async( # pragma: no cover - self, - *, - host: str = "127.0.0.1", - port: int = 8000, - streamable_http_path: str = "/mcp", - json_response: bool = False, - stateless_http: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, - transport_security: TransportSecuritySettings | None = None, - ) -> None: - """Run the server using StreamableHTTP transport.""" - import uvicorn - - starlette_app = self.streamable_http_app( - streamable_http_path=streamable_http_path, - json_response=json_response, - stateless_http=stateless_http, - event_store=event_store, - retry_interval=retry_interval, - transport_security=transport_security, - host=host, - ) - - config = uvicorn.Config( - starlette_app, - host=host, - port=port, - log_level=self.settings.log_level.lower(), - ) - server = uvicorn.Server(config) - await server.serve() - - def sse_app( - self, - *, - sse_path: str = "/sse", - message_path: str = "/messages/", - transport_security: TransportSecuritySettings | None = None, - host: str = "127.0.0.1", - ) -> Starlette: - """Return an instance of the SSE server app.""" - # Auto-enable DNS rebinding protection for localhost (IPv4 and IPv6) - if transport_security is None and host in ("127.0.0.1", "localhost", "::1"): - transport_security = TransportSecuritySettings( - enable_dns_rebinding_protection=True, - allowed_hosts=["127.0.0.1:*", "localhost:*", "[::1]:*"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*", "http://[::1]:*"], - ) - - sse = SseServerTransport(message_path, security_settings=transport_security) - - async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no cover - # Add client ID from auth context into request context if available - - async with sse.connect_sse(scope, receive, send) as streams: - await self._mcp_server.run(streams[0], streams[1], self._mcp_server.create_initialization_options()) - return Response() - - # Create routes - routes: list[Route | Mount] = [] - middleware: list[Middleware] = [] - required_scopes: list[str] = [] - - # Set up auth if configured - if self.settings.auth: # pragma: no cover - required_scopes = self.settings.auth.required_scopes or [] - - # Add auth middleware if token verifier is available - if self._token_verifier: - middleware = [ - # extract auth info from request (but do not require it) - Middleware( - AuthenticationMiddleware, - backend=BearerAuthBackend(self._token_verifier), - ), - # Add the auth context middleware to store - # authenticated user in a contextvar - Middleware(AuthContextMiddleware), - ] - - # Add auth endpoints if auth server provider is configured - if self._auth_server_provider: - from mcp.server.auth.routes import create_auth_routes - - routes.extend( - create_auth_routes( - provider=self._auth_server_provider, - issuer_url=self.settings.auth.issuer_url, - service_documentation_url=self.settings.auth.service_documentation_url, - client_registration_options=self.settings.auth.client_registration_options, - revocation_options=self.settings.auth.revocation_options, - ) - ) - - # When auth is configured, require authentication - if self._token_verifier: # pragma: no cover - # Determine resource metadata URL - resource_metadata_url = None - if self.settings.auth and self.settings.auth.resource_server_url: - from mcp.server.auth.routes import build_resource_metadata_url - - # Build compliant metadata URL for WWW-Authenticate header - resource_metadata_url = build_resource_metadata_url(self.settings.auth.resource_server_url) - - # Auth is enabled, wrap the endpoints with RequireAuthMiddleware - routes.append( - Route( - sse_path, - endpoint=RequireAuthMiddleware(handle_sse, required_scopes, resource_metadata_url), - methods=["GET"], - ) - ) - routes.append( - Mount( - message_path, - app=RequireAuthMiddleware(sse.handle_post_message, required_scopes, resource_metadata_url), - ) - ) - else: # pragma: no cover - # Auth is disabled, no need for RequireAuthMiddleware - # Since handle_sse is an ASGI app, we need to create a compatible endpoint - async def sse_endpoint(request: Request) -> Response: - # Convert the Starlette request to ASGI parameters - return await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage] - - routes.append( - Route( - sse_path, - endpoint=sse_endpoint, - methods=["GET"], - ) - ) - routes.append( - Mount( - message_path, - app=sse.handle_post_message, - ) - ) - # Add protected resource metadata endpoint if configured as RS - if self.settings.auth and self.settings.auth.resource_server_url: # pragma: no cover - from mcp.server.auth.routes import create_protected_resource_routes - - routes.extend( - create_protected_resource_routes( - resource_url=self.settings.auth.resource_server_url, - authorization_servers=[self.settings.auth.issuer_url], - scopes_supported=self.settings.auth.required_scopes, - ) - ) - - # mount these routes last, so they have the lowest route matching precedence - routes.extend(self._custom_starlette_routes) - - # Create Starlette app with routes and middleware - return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware) - - def streamable_http_app( - self, - *, - streamable_http_path: str = "/mcp", - json_response: bool = False, - stateless_http: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, - transport_security: TransportSecuritySettings | None = None, - host: str = "127.0.0.1", - ) -> Starlette: - """Return an instance of the StreamableHTTP server app.""" - from starlette.middleware import Middleware - - # Auto-enable DNS rebinding protection for localhost (IPv4 and IPv6) - if transport_security is None and host in ("127.0.0.1", "localhost", "::1"): - transport_security = TransportSecuritySettings( - enable_dns_rebinding_protection=True, - allowed_hosts=["127.0.0.1:*", "localhost:*", "[::1]:*"], - allowed_origins=["http://127.0.0.1:*", "http://localhost:*", "http://[::1]:*"], - ) - - # Create session manager on first call (lazy initialization) - if self._session_manager is None: # pragma: no branch - self._session_manager = StreamableHTTPSessionManager( - app=self._mcp_server, - event_store=event_store, - retry_interval=retry_interval, - json_response=json_response, - stateless=stateless_http, - security_settings=transport_security, - ) - - # Create the ASGI handler - streamable_http_app = StreamableHTTPASGIApp(self._session_manager) - - # Create routes - routes: list[Route | Mount] = [] - middleware: list[Middleware] = [] - required_scopes: list[str] = [] - - # Set up auth if configured - if self.settings.auth: # pragma: no cover - required_scopes = self.settings.auth.required_scopes or [] - - # Add auth middleware if token verifier is available - if self._token_verifier: - middleware = [ - Middleware( - AuthenticationMiddleware, - backend=BearerAuthBackend(self._token_verifier), - ), - Middleware(AuthContextMiddleware), - ] - - # Add auth endpoints if auth server provider is configured - if self._auth_server_provider: - from mcp.server.auth.routes import create_auth_routes - - routes.extend( - create_auth_routes( - provider=self._auth_server_provider, - issuer_url=self.settings.auth.issuer_url, - service_documentation_url=self.settings.auth.service_documentation_url, - client_registration_options=self.settings.auth.client_registration_options, - revocation_options=self.settings.auth.revocation_options, - ) - ) - - # Set up routes with or without auth - if self._token_verifier: # pragma: no cover - # Determine resource metadata URL - resource_metadata_url = None - if self.settings.auth and self.settings.auth.resource_server_url: - from mcp.server.auth.routes import build_resource_metadata_url - - # Build compliant metadata URL for WWW-Authenticate header - resource_metadata_url = build_resource_metadata_url(self.settings.auth.resource_server_url) - - routes.append( - Route( - streamable_http_path, - endpoint=RequireAuthMiddleware(streamable_http_app, required_scopes, resource_metadata_url), - ) - ) - else: - # Auth is disabled, no wrapper needed - routes.append( - Route( - streamable_http_path, - endpoint=streamable_http_app, - ) - ) - - # Add protected resource metadata endpoint if configured as RS - if self.settings.auth and self.settings.auth.resource_server_url: # pragma: no cover - from mcp.server.auth.routes import create_protected_resource_routes - - routes.extend( - create_protected_resource_routes( - resource_url=self.settings.auth.resource_server_url, - authorization_servers=[self.settings.auth.issuer_url], - scopes_supported=self.settings.auth.required_scopes, - ) - ) - - routes.extend(self._custom_starlette_routes) - - return Starlette( - debug=self.settings.debug, - routes=routes, - middleware=middleware, - lifespan=lambda app: self.session_manager.run(), - ) - - async def list_prompts(self) -> list[MCPPrompt]: - """List all available prompts.""" - prompts = self._prompt_manager.list_prompts() - return [ - MCPPrompt( - name=prompt.name, - title=prompt.title, - description=prompt.description, - arguments=[ - MCPPromptArgument( - name=arg.name, - description=arg.description, - required=arg.required, - ) - for arg in (prompt.arguments or []) - ], - icons=prompt.icons, - ) - for prompt in prompts - ] - - async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: - """Get a prompt by name with arguments.""" - try: - prompt = self._prompt_manager.get_prompt(name) - if not prompt: - raise ValueError(f"Unknown prompt: {name}") - - messages = await prompt.render(arguments, context=self.get_context()) - - return GetPromptResult( - description=prompt.description, - messages=pydantic_core.to_jsonable_python(messages), - ) - except Exception as e: - logger.exception(f"Error getting prompt {name}") - raise ValueError(str(e)) - - -class StreamableHTTPASGIApp: - """ASGI application for Streamable HTTP server transport.""" - - def __init__(self, session_manager: StreamableHTTPSessionManager): - self.session_manager = session_manager - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover - await self.session_manager.handle_request(scope, receive, send) - - -class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]): - """Context object providing access to MCP capabilities. - - This provides a cleaner interface to MCP's RequestContext functionality. - It gets injected into tool and resource functions that request it via type hints. - - To use context in a tool function, add a parameter with the Context type annotation: - - ```python - @server.tool() - def my_tool(x: int, ctx: Context) -> str: - # Log messages to the client - ctx.info(f"Processing {x}") - ctx.debug("Debug info") - ctx.warning("Warning message") - ctx.error("Error message") - - # Report progress - ctx.report_progress(50, 100) - - # Access resources - data = ctx.read_resource("resource://data") - - # Get request info - request_id = ctx.request_id - client_id = ctx.client_id - - return str(x) - ``` - - The context parameter name can be anything as long as it's annotated with Context. - The context is optional - tools that don't need it can omit the parameter. - """ - - _request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None - _fastmcp: FastMCP | None - - def __init__( - self, - *, - request_context: (RequestContext[ServerSessionT, LifespanContextT, RequestT] | None) = None, - fastmcp: FastMCP | None = None, - **kwargs: Any, - ): - super().__init__(**kwargs) - self._request_context = request_context - self._fastmcp = fastmcp - - @property - def fastmcp(self) -> FastMCP: - """Access to the FastMCP server.""" - if self._fastmcp is None: # pragma: no cover - raise ValueError("Context is not available outside of a request") - return self._fastmcp # pragma: no cover - - @property - def request_context( - self, - ) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]: - """Access to the underlying request context.""" - if self._request_context is None: # pragma: no cover - raise ValueError("Context is not available outside of a request") - return self._request_context - - async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - """Report progress for the current operation. - - Args: - progress: Current progress value e.g. 24 - total: Optional total value e.g. 100 - message: Optional message e.g. Starting render... - """ - progress_token = self.request_context.meta.progress_token if self.request_context.meta else None - - if progress_token is None: # pragma: no cover - return - - await self.request_context.session.send_progress_notification( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - ) - - async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: - """Read a resource by URI. - - Args: - uri: Resource URI to read - - Returns: - The resource content as either text or bytes - """ - assert self._fastmcp is not None, "Context is not available outside of a request" - return await self._fastmcp.read_resource(uri) - - async def elicit( - self, - message: str, - schema: type[ElicitSchemaModelT], - ) -> ElicitationResult[ElicitSchemaModelT]: - """Elicit information from the client/user. - - This method can be used to interactively ask for additional information from the - client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. Or in case a - client is an agent, it might decide how to handle the elicitation -- either by asking - the user or automatically generating a response. - - Args: - schema: A Pydantic model class defining the expected response structure, according to the specification, - only primive types are allowed. - message: Optional message to present to the user. If not provided, will use - a default message based on the schema - - Returns: - An ElicitationResult containing the action taken and the data if accepted - - Note: - Check the result.action to determine if the user accepted, declined, or cancelled. - The result.data will only be populated if action is "accept" and validation succeeded. - """ - - return await elicit_with_validation( - session=self.request_context.session, - message=message, - schema=schema, - related_request_id=self.request_id, - ) - - async def elicit_url( - self, - message: str, - url: str, - elicitation_id: str, - ) -> UrlElicitationResult: - """Request URL mode elicitation from the client. - - This directs the user to an external URL for out-of-band interactions - that must not pass through the MCP client. Use this for: - - Collecting sensitive credentials (API keys, passwords) - - OAuth authorization flows with third-party services - - Payment and subscription flows - - Any interaction where data should not pass through the LLM context - - The response indicates whether the user consented to navigate to the URL. - The actual interaction happens out-of-band. When the elicitation completes, - call `self.session.send_elicit_complete(elicitation_id)` to notify the client. - - Args: - message: Human-readable explanation of why the interaction is needed - url: The URL the user should navigate to - elicitation_id: Unique identifier for tracking this elicitation - - Returns: - UrlElicitationResult indicating accept, decline, or cancel - """ - return await _elicit_url( - session=self.request_context.session, - message=message, - url=url, - elicitation_id=elicitation_id, - related_request_id=self.request_id, - ) - - async def log( - self, - level: Literal["debug", "info", "warning", "error"], - message: str, - *, - logger_name: str | None = None, - extra: dict[str, Any] | None = None, - ) -> None: - """Send a log message to the client. - - Args: - level: Log level (debug, info, warning, error) - message: Log message - logger_name: Optional logger name - extra: Optional dictionary with additional structured data to include - """ - - if extra: - log_data = { - "message": message, - **extra, - } - else: - log_data = message - - await self.request_context.session.send_log_message( - level=level, - data=log_data, - logger=logger_name, - related_request_id=self.request_id, - ) - - @property - def client_id(self) -> str | None: - """Get the client ID if available.""" - return ( - getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None - ) # pragma: no cover - - @property - def request_id(self) -> str: - """Get the unique ID for this request.""" - return str(self.request_context.request_id) - - @property - def session(self): - """Access to the underlying session for advanced usage.""" - return self.request_context.session - - async def close_sse_stream(self) -> None: - """Close the SSE stream to trigger client reconnection. - - This method closes the HTTP connection for the current request, triggering - client reconnection. Events continue to be stored in the event store and will - be replayed when the client reconnects with Last-Event-ID. - - Use this to implement polling behavior during long-running operations - - client will reconnect after the retry interval specified in the priming event. - - Note: - This is a no-op if not using StreamableHTTP transport with event_store. - The callback is only available when event_store is configured. - """ - if self._request_context and self._request_context.close_sse_stream: # pragma: no cover - await self._request_context.close_sse_stream() - - async def close_standalone_sse_stream(self) -> None: - """Close the standalone GET SSE stream to trigger client reconnection. - - This method closes the HTTP connection for the standalone GET stream used - for unsolicited server-to-client notifications. The client SHOULD reconnect - with Last-Event-ID to resume receiving notifications. - - Note: - This is a no-op if not using StreamableHTTP transport with event_store. - Currently, client reconnection for standalone GET streams is NOT - implemented - this is a known gap. - """ - if self._request_context and self._request_context.close_standalone_sse_stream: # pragma: no cover - await self._request_context.close_standalone_sse_stream() - - # Convenience methods for common log levels - async def debug(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: - """Send a debug log message.""" - await self.log("debug", message, logger_name=logger_name, extra=extra) - - async def info(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: - """Send an info log message.""" - await self.log("info", message, logger_name=logger_name, extra=extra) - - async def warning( - self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None - ) -> None: - """Send a warning log message.""" - await self.log("warning", message, logger_name=logger_name, extra=extra) - - async def error(self, message: str, *, logger_name: str | None = None, extra: dict[str, Any] | None = None) -> None: - """Send an error log message.""" - await self.log("error", message, logger_name=logger_name, extra=extra) diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index 6d1cee58ef..e69de29bb2 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -1,1502 +0,0 @@ -import base64 -from pathlib import Path -from typing import Any -from unittest.mock import patch - -import pytest -from pydantic import BaseModel -from starlette.applications import Starlette -from starlette.routing import Mount, Route - -from mcp.client import Client -from mcp.server.fastmcp import Context, FastMCP -from mcp.server.fastmcp.exceptions import ToolError -from mcp.server.fastmcp.prompts.base import Message, UserMessage -from mcp.server.fastmcp.resources import FileResource, FunctionResource -from mcp.server.fastmcp.utilities.types import Audio, Image -from mcp.server.session import ServerSession -from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.exceptions import McpError -from mcp.types import ( - AudioContent, - BlobResourceContents, - ContentBlock, - EmbeddedResource, - Icon, - ImageContent, - TextContent, - TextResourceContents, -) - - -class TestServer: - @pytest.mark.anyio - async def test_create_server(self): - mcp = FastMCP( - title="FastMCP Server", - description="Server description", - instructions="Server instructions", - website_url="https://example.com/mcp_server", - version="1.0", - icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48", "96x96"])], - ) - assert mcp.name == "FastMCP" - assert mcp.title == "FastMCP Server" - assert mcp.description == "Server description" - assert mcp.instructions == "Server instructions" - assert mcp.website_url == "https://example.com/mcp_server" - assert mcp.version == "1.0" - assert isinstance(mcp.icons, list) - assert len(mcp.icons) == 1 - assert mcp.icons[0].src == "https://example.com/icon.png" - - @pytest.mark.anyio - async def test_sse_app_returns_starlette_app(self): - """Test that sse_app returns a Starlette application with correct routes.""" - mcp = FastMCP("test") - # Use host="0.0.0.0" to avoid auto DNS protection - app = mcp.sse_app(host="0.0.0.0") - - assert isinstance(app, Starlette) - - # Verify routes exist - sse_routes = [r for r in app.routes if isinstance(r, Route)] - mount_routes = [r for r in app.routes if isinstance(r, Mount)] - - assert len(sse_routes) == 1, "Should have one SSE route" - assert len(mount_routes) == 1, "Should have one mount route" - assert sse_routes[0].path == "/sse" - assert mount_routes[0].path == "/messages" - - @pytest.mark.anyio - async def test_non_ascii_description(self): - """Test that FastMCP handles non-ASCII characters in descriptions correctly""" - mcp = FastMCP() - - @mcp.tool(description=("🌟 This tool uses emojis and UTF-8 characters: á é í ó ú ñ 漢字 🎉")) - def hello_world(name: str = "世界") -> str: - return f"¡Hola, {name}! 👋" - - async with Client(mcp) as client: - tools = await client.list_tools() - assert len(tools.tools) == 1 - tool = tools.tools[0] - assert tool.description is not None - assert "🌟" in tool.description - assert "漢字" in tool.description - assert "🎉" in tool.description - - result = await client.call_tool("hello_world", {}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, TextContent) - assert "¡Hola, 世界! 👋" == content.text - - @pytest.mark.anyio - async def test_add_tool_decorator(self): - mcp = FastMCP() - - @mcp.tool() - def sum(x: int, y: int) -> int: # pragma: no cover - return x + y - - assert len(mcp._tool_manager.list_tools()) == 1 - - @pytest.mark.anyio - async def test_add_tool_decorator_incorrect_usage(self): - mcp = FastMCP() - - with pytest.raises(TypeError, match="The @tool decorator was used incorrectly"): - - @mcp.tool # Missing parentheses #type: ignore - def sum(x: int, y: int) -> int: # pragma: no cover - return x + y - - @pytest.mark.anyio - async def test_add_resource_decorator(self): - mcp = FastMCP() - - @mcp.resource("r://{x}") - def get_data(x: str) -> str: # pragma: no cover - return f"Data: {x}" - - assert len(mcp._resource_manager._templates) == 1 - - @pytest.mark.anyio - async def test_add_resource_decorator_incorrect_usage(self): - mcp = FastMCP() - - with pytest.raises(TypeError, match="The @resource decorator was used incorrectly"): - - @mcp.resource # Missing parentheses #type: ignore - def get_data(x: str) -> str: # pragma: no cover - return f"Data: {x}" - - -class TestDnsRebindingProtection: - """Tests for automatic DNS rebinding protection on localhost. - - DNS rebinding protection is now configured in sse_app() and streamable_http_app() - based on the host parameter passed to those methods. - """ - - def test_auto_enabled_for_127_0_0_1_sse(self): - """DNS rebinding protection should auto-enable for host=127.0.0.1 in SSE app.""" - mcp = FastMCP() - # Call sse_app with host=127.0.0.1 to trigger auto-config - # We can't directly inspect the transport_security, but we can verify - # the app is created without error - app = mcp.sse_app(host="127.0.0.1") - assert app is not None - - def test_auto_enabled_for_127_0_0_1_streamable_http(self): - """DNS rebinding protection should auto-enable for host=127.0.0.1 in StreamableHTTP app.""" - mcp = FastMCP() - app = mcp.streamable_http_app(host="127.0.0.1") - assert app is not None - - def test_auto_enabled_for_localhost_sse(self): - """DNS rebinding protection should auto-enable for host=localhost in SSE app.""" - mcp = FastMCP() - app = mcp.sse_app(host="localhost") - assert app is not None - - def test_auto_enabled_for_ipv6_localhost_sse(self): - """DNS rebinding protection should auto-enable for host=::1 (IPv6 localhost) in SSE app.""" - mcp = FastMCP() - app = mcp.sse_app(host="::1") - assert app is not None - - def test_not_auto_enabled_for_other_hosts_sse(self): - """DNS rebinding protection should NOT auto-enable for other hosts in SSE app.""" - mcp = FastMCP() - app = mcp.sse_app(host="0.0.0.0") - assert app is not None - - def test_explicit_settings_not_overridden_sse(self): - """Explicit transport_security settings should not be overridden in SSE app.""" - custom_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=False, - ) - mcp = FastMCP() - # Explicit transport_security passed to sse_app should be used as-is - app = mcp.sse_app(host="127.0.0.1", transport_security=custom_settings) - assert app is not None - - def test_explicit_settings_not_overridden_streamable_http(self): - """Explicit transport_security settings should not be overridden in StreamableHTTP app.""" - custom_settings = TransportSecuritySettings( - enable_dns_rebinding_protection=False, - ) - mcp = FastMCP() - # Explicit transport_security passed to streamable_http_app should be used as-is - app = mcp.streamable_http_app(host="127.0.0.1", transport_security=custom_settings) - assert app is not None - - -def tool_fn(x: int, y: int) -> int: - return x + y - - -def error_tool_fn() -> None: - raise ValueError("Test error") - - -def image_tool_fn(path: str) -> Image: - return Image(path) - - -def audio_tool_fn(path: str) -> Audio: - return Audio(path) - - -def mixed_content_tool_fn() -> list[ContentBlock]: - return [ - TextContent(type="text", text="Hello"), - ImageContent(type="image", data="abc", mime_type="image/png"), - AudioContent(type="audio", data="def", mime_type="audio/wav"), - ] - - -class TestServerTools: - @pytest.mark.anyio - async def test_add_tool(self): - mcp = FastMCP() - mcp.add_tool(tool_fn) - mcp.add_tool(tool_fn) - assert len(mcp._tool_manager.list_tools()) == 1 - - @pytest.mark.anyio - async def test_list_tools(self): - mcp = FastMCP() - mcp.add_tool(tool_fn) - async with Client(mcp) as client: - tools = await client.list_tools() - assert len(tools.tools) == 1 - - @pytest.mark.anyio - async def test_call_tool(self): - mcp = FastMCP() - mcp.add_tool(tool_fn) - async with Client(mcp) as client: - result = await client.call_tool("my_tool", {"arg1": "value"}) - assert not hasattr(result, "error") - assert len(result.content) > 0 - - @pytest.mark.anyio - async def test_tool_exception_handling(self): - mcp = FastMCP() - mcp.add_tool(error_tool_fn) - async with Client(mcp) as client: - result = await client.call_tool("error_tool_fn", {}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, TextContent) - assert "Test error" in content.text - assert result.is_error is True - - @pytest.mark.anyio - async def test_tool_error_handling(self): - mcp = FastMCP() - mcp.add_tool(error_tool_fn) - async with Client(mcp) as client: - result = await client.call_tool("error_tool_fn", {}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, TextContent) - assert "Test error" in content.text - assert result.is_error is True - - @pytest.mark.anyio - async def test_tool_error_details(self): - """Test that exception details are properly formatted in the response""" - mcp = FastMCP() - mcp.add_tool(error_tool_fn) - async with Client(mcp) as client: - result = await client.call_tool("error_tool_fn", {}) - content = result.content[0] - assert isinstance(content, TextContent) - assert isinstance(content.text, str) - assert "Test error" in content.text - assert result.is_error is True - - @pytest.mark.anyio - async def test_tool_return_value_conversion(self): - mcp = FastMCP() - mcp.add_tool(tool_fn) - async with Client(mcp) as client: - result = await client.call_tool("tool_fn", {"x": 1, "y": 2}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, TextContent) - assert content.text == "3" - # Check structured content - int return type should have structured output - assert result.structured_content is not None - assert result.structured_content == {"result": 3} - - @pytest.mark.anyio - async def test_tool_image_helper(self, tmp_path: Path): - # Create a test image - image_path = tmp_path / "test.png" - image_path.write_bytes(b"fake png data") - - mcp = FastMCP() - mcp.add_tool(image_tool_fn) - async with Client(mcp) as client: - result = await client.call_tool("image_tool_fn", {"path": str(image_path)}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, ImageContent) - assert content.type == "image" - assert content.mime_type == "image/png" - # Verify base64 encoding - decoded = base64.b64decode(content.data) - assert decoded == b"fake png data" - # Check structured content - Image return type should NOT have structured output - assert result.structured_content is None - - @pytest.mark.anyio - async def test_tool_audio_helper(self, tmp_path: Path): - # Create a test audio - audio_path = tmp_path / "test.wav" - audio_path.write_bytes(b"fake wav data") - - mcp = FastMCP() - mcp.add_tool(audio_tool_fn) - async with Client(mcp) as client: - result = await client.call_tool("audio_tool_fn", {"path": str(audio_path)}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, AudioContent) - assert content.type == "audio" - assert content.mime_type == "audio/wav" - # Verify base64 encoding - decoded = base64.b64decode(content.data) - assert decoded == b"fake wav data" - # Check structured content - Image return type should NOT have structured output - assert result.structured_content is None - - @pytest.mark.parametrize( - "filename,expected_mime_type", - [ - ("test.wav", "audio/wav"), - ("test.mp3", "audio/mpeg"), - ("test.ogg", "audio/ogg"), - ("test.flac", "audio/flac"), - ("test.aac", "audio/aac"), - ("test.m4a", "audio/mp4"), - ("test.unknown", "application/octet-stream"), # Unknown extension fallback - ], - ) - @pytest.mark.anyio - async def test_tool_audio_suffix_detection(self, tmp_path: Path, filename: str, expected_mime_type: str): - """Test that Audio helper correctly detects MIME types from file suffixes""" - mcp = FastMCP() - mcp.add_tool(audio_tool_fn) - - # Create a test audio file with the specific extension - audio_path = tmp_path / filename - audio_path.write_bytes(b"fake audio data") - - async with Client(mcp) as client: - result = await client.call_tool("audio_tool_fn", {"path": str(audio_path)}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, AudioContent) - assert content.type == "audio" - assert content.mime_type == expected_mime_type - # Verify base64 encoding - decoded = base64.b64decode(content.data) - assert decoded == b"fake audio data" - - @pytest.mark.anyio - async def test_tool_mixed_content(self): - mcp = FastMCP() - mcp.add_tool(mixed_content_tool_fn) - async with Client(mcp) as client: - result = await client.call_tool("mixed_content_tool_fn", {}) - assert len(result.content) == 3 - content1, content2, content3 = result.content - assert isinstance(content1, TextContent) - assert content1.text == "Hello" - assert isinstance(content2, ImageContent) - assert content2.mime_type == "image/png" - assert content2.data == "abc" - assert isinstance(content3, AudioContent) - assert content3.mime_type == "audio/wav" - assert content3.data == "def" - assert result.structured_content is not None - assert "result" in result.structured_content - structured_result = result.structured_content["result"] - assert len(structured_result) == 3 - - expected_content = [ - {"type": "text", "text": "Hello"}, - {"type": "image", "data": "abc", "mimeType": "image/png"}, - {"type": "audio", "data": "def", "mimeType": "audio/wav"}, - ] - - for i, expected in enumerate(expected_content): - for key, value in expected.items(): - assert structured_result[i][key] == value - - @pytest.mark.anyio - async def test_tool_mixed_list_with_audio_and_image(self, tmp_path: Path): - """Test that lists containing Image objects and other types are handled - correctly""" - # Create a test image - image_path = tmp_path / "test.png" - image_path.write_bytes(b"test image data") - - # Create a test audio - audio_path = tmp_path / "test.wav" - audio_path.write_bytes(b"test audio data") - - # TODO(Marcelo): It seems if we add the proper type hint, it generates an invalid JSON schema. - # We need to fix this. - def mixed_list_fn() -> list: # type: ignore - return [ # type: ignore - "text message", - Image(image_path), - Audio(audio_path), - {"key": "value"}, - TextContent(type="text", text="direct content"), - ] - - mcp = FastMCP() - mcp.add_tool(mixed_list_fn) # type: ignore - async with Client(mcp) as client: - result = await client.call_tool("mixed_list_fn", {}) - assert len(result.content) == 5 - # Check text conversion - content1 = result.content[0] - assert isinstance(content1, TextContent) - assert content1.text == "text message" - # Check image conversion - content2 = result.content[1] - assert isinstance(content2, ImageContent) - assert content2.mime_type == "image/png" - assert base64.b64decode(content2.data) == b"test image data" - # Check audio conversion - content3 = result.content[2] - assert isinstance(content3, AudioContent) - assert content3.mime_type == "audio/wav" - assert base64.b64decode(content3.data) == b"test audio data" - # Check dict conversion - content4 = result.content[3] - assert isinstance(content4, TextContent) - assert '"key": "value"' in content4.text - # Check direct TextContent - content5 = result.content[4] - assert isinstance(content5, TextContent) - assert content5.text == "direct content" - # Check structured content - untyped list with Image objects should NOT have structured output - assert result.structured_content is None - - @pytest.mark.anyio - async def test_tool_structured_output_basemodel(self): - """Test tool with structured output returning BaseModel""" - - class UserOutput(BaseModel): - name: str - age: int - active: bool = True - - def get_user(user_id: int) -> UserOutput: - """Get user by ID""" - return UserOutput(name="John Doe", age=30) - - mcp = FastMCP() - mcp.add_tool(get_user) - - async with Client(mcp) as client: - # Check that the tool has outputSchema - tools = await client.list_tools() - tool = next(t for t in tools.tools if t.name == "get_user") - assert tool.output_schema is not None - assert tool.output_schema["type"] == "object" - assert "name" in tool.output_schema["properties"] - assert "age" in tool.output_schema["properties"] - - # Call the tool and check structured output - result = await client.call_tool("get_user", {"user_id": 123}) - assert result.is_error is False - assert result.structured_content is not None - assert result.structured_content == {"name": "John Doe", "age": 30, "active": True} - # Content should be JSON serialized version - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert '"name": "John Doe"' in result.content[0].text - - @pytest.mark.anyio - async def test_tool_structured_output_primitive(self): - """Test tool with structured output returning primitive type""" - - def calculate_sum(a: int, b: int) -> int: - """Add two numbers""" - return a + b - - mcp = FastMCP() - mcp.add_tool(calculate_sum) - - async with Client(mcp) as client: - # Check that the tool has outputSchema - tools = await client.list_tools() - tool = next(t for t in tools.tools if t.name == "calculate_sum") - assert tool.output_schema is not None - # Primitive types are wrapped - assert tool.output_schema["type"] == "object" - assert "result" in tool.output_schema["properties"] - assert tool.output_schema["properties"]["result"]["type"] == "integer" - - # Call the tool - result = await client.call_tool("calculate_sum", {"a": 5, "b": 7}) - assert result.is_error is False - assert result.structured_content is not None - assert result.structured_content == {"result": 12} - - @pytest.mark.anyio - async def test_tool_structured_output_list(self): - """Test tool with structured output returning list""" - - def get_numbers() -> list[int]: - """Get a list of numbers""" - return [1, 2, 3, 4, 5] - - mcp = FastMCP() - mcp.add_tool(get_numbers) - - async with Client(mcp) as client: - result = await client.call_tool("get_numbers", {}) - assert result.is_error is False - assert result.structured_content is not None - assert result.structured_content == {"result": [1, 2, 3, 4, 5]} - - @pytest.mark.anyio - async def test_tool_structured_output_server_side_validation_error(self): - """Test that server-side validation errors are handled properly""" - - def get_numbers() -> list[int]: - return [1, 2, 3, 4, [5]] # type: ignore - - mcp = FastMCP() - mcp.add_tool(get_numbers) - - async with Client(mcp) as client: - result = await client.call_tool("get_numbers", {}) - assert result.is_error is True - assert result.structured_content is None - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - - @pytest.mark.anyio - async def test_tool_structured_output_dict_str_any(self): - """Test tool with dict[str, Any] structured output""" - - def get_metadata() -> dict[str, Any]: - """Get metadata dictionary""" - return { - "version": "1.0.0", - "enabled": True, - "count": 42, - "tags": ["production", "stable"], - "config": {"nested": {"value": 123}}, - } - - mcp = FastMCP() - mcp.add_tool(get_metadata) - - async with Client(mcp) as client: - # Check schema - tools = await client.list_tools() - tool = next(t for t in tools.tools if t.name == "get_metadata") - assert tool.output_schema is not None - assert tool.output_schema["type"] == "object" - # dict[str, Any] should have minimal schema - assert ( - "additionalProperties" not in tool.output_schema - or tool.output_schema.get("additionalProperties") is True - ) - - # Call tool - result = await client.call_tool("get_metadata", {}) - assert result.is_error is False - assert result.structured_content is not None - expected = { - "version": "1.0.0", - "enabled": True, - "count": 42, - "tags": ["production", "stable"], - "config": {"nested": {"value": 123}}, - } - assert result.structured_content == expected - - @pytest.mark.anyio - async def test_tool_structured_output_dict_str_typed(self): - """Test tool with dict[str, T] structured output for specific T""" - - def get_settings() -> dict[str, str]: - """Get settings as string dictionary""" - return {"theme": "dark", "language": "en", "timezone": "UTC"} - - mcp = FastMCP() - mcp.add_tool(get_settings) - - async with Client(mcp) as client: - # Check schema - tools = await client.list_tools() - tool = next(t for t in tools.tools if t.name == "get_settings") - assert tool.output_schema is not None - assert tool.output_schema["type"] == "object" - assert tool.output_schema["additionalProperties"]["type"] == "string" - - # Call tool - result = await client.call_tool("get_settings", {}) - assert result.is_error is False - assert result.structured_content == {"theme": "dark", "language": "en", "timezone": "UTC"} - - @pytest.mark.anyio - async def test_remove_tool(self): - """Test removing a tool from the server.""" - mcp = FastMCP() - mcp.add_tool(tool_fn) - - # Verify tool exists - assert len(mcp._tool_manager.list_tools()) == 1 - - # Remove the tool - mcp.remove_tool("tool_fn") - - # Verify tool is removed - assert len(mcp._tool_manager.list_tools()) == 0 - - @pytest.mark.anyio - async def test_remove_nonexistent_tool(self): - """Test that removing a non-existent tool raises ToolError.""" - mcp = FastMCP() - - with pytest.raises(ToolError, match="Unknown tool: nonexistent"): - mcp.remove_tool("nonexistent") - - @pytest.mark.anyio - async def test_remove_tool_and_list(self): - """Test that a removed tool doesn't appear in list_tools.""" - mcp = FastMCP() - mcp.add_tool(tool_fn) - mcp.add_tool(error_tool_fn) - - # Verify both tools exist - async with Client(mcp) as client: - tools = await client.list_tools() - assert len(tools.tools) == 2 - tool_names = [t.name for t in tools.tools] - assert "tool_fn" in tool_names - assert "error_tool_fn" in tool_names - - # Remove one tool - mcp.remove_tool("tool_fn") - - # Verify only one tool remains - async with Client(mcp) as client: - tools = await client.list_tools() - assert len(tools.tools) == 1 - assert tools.tools[0].name == "error_tool_fn" - - @pytest.mark.anyio - async def test_remove_tool_and_call(self): - """Test that calling a removed tool fails appropriately.""" - mcp = FastMCP() - mcp.add_tool(tool_fn) - - # Verify tool works before removal - async with Client(mcp) as client: - result = await client.call_tool("tool_fn", {"x": 1, "y": 2}) - assert not result.is_error - content = result.content[0] - assert isinstance(content, TextContent) - assert content.text == "3" - - # Remove the tool - mcp.remove_tool("tool_fn") - - # Verify calling removed tool returns an error - async with Client(mcp) as client: - result = await client.call_tool("tool_fn", {"x": 1, "y": 2}) - assert result.is_error - content = result.content[0] - assert isinstance(content, TextContent) - assert "Unknown tool" in content.text - - -class TestServerResources: - @pytest.mark.anyio - async def test_text_resource(self): - mcp = FastMCP() - - def get_text(): - return "Hello, world!" - - resource = FunctionResource(uri="resource://test", name="test", fn=get_text) - mcp.add_resource(resource) - - async with Client(mcp) as client: - result = await client.read_resource("resource://test") - - async with Client(mcp) as client: - result = await client.read_resource("resource://test") - - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Hello, world!" - - @pytest.mark.anyio - async def test_binary_resource(self): - mcp = FastMCP() - - def get_binary(): - return b"Binary data" - - resource = FunctionResource( - uri="resource://binary", - name="binary", - fn=get_binary, - mime_type="application/octet-stream", - ) - mcp.add_resource(resource) - - async with Client(mcp) as client: - result = await client.read_resource("resource://binary") - - async with Client(mcp) as client: - result = await client.read_resource("resource://binary") - - assert isinstance(result.contents[0], BlobResourceContents) - assert result.contents[0].blob == base64.b64encode(b"Binary data").decode() - - @pytest.mark.anyio - async def test_file_resource_text(self, tmp_path: Path): - mcp = FastMCP() - - # Create a text file - text_file = tmp_path / "test.txt" - text_file.write_text("Hello from file!") - - resource = FileResource(uri="file://test.txt", name="test.txt", path=text_file) - mcp.add_resource(resource) - - async with Client(mcp) as client: - result = await client.read_resource("file://test.txt") - - async with Client(mcp) as client: - result = await client.read_resource("file://test.txt") - - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Hello from file!" - - @pytest.mark.anyio - async def test_file_resource_binary(self, tmp_path: Path): - mcp = FastMCP() - - # Create a binary file - binary_file = tmp_path / "test.bin" - binary_file.write_bytes(b"Binary file data") - - resource = FileResource( - uri="file://test.bin", - name="test.bin", - path=binary_file, - mime_type="application/octet-stream", - ) - mcp.add_resource(resource) - - async with Client(mcp) as client: - result = await client.read_resource("file://test.bin") - - async with Client(mcp) as client: - result = await client.read_resource("file://test.bin") - - assert isinstance(result.contents[0], BlobResourceContents) - assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode() - - @pytest.mark.anyio - async def test_function_resource(self): - mcp = FastMCP() - - @mcp.resource("function://test", name="test_get_data") - def get_data() -> str: # pragma: no cover - """get_data returns a string""" - return "Hello, world!" - - async with Client(mcp) as client: - resources = await client.list_resources() - assert len(resources.resources) == 1 - resource = resources.resources[0] - assert resource.description == "get_data returns a string" - assert resource.uri == "function://test" - assert resource.name == "test_get_data" - assert resource.mime_type == "text/plain" - - -class TestServerResourceTemplates: - @pytest.mark.anyio - async def test_resource_with_params(self): - """Test that a resource with function parameters raises an error if the URI - parameters don't match""" - mcp = FastMCP() - - with pytest.raises(ValueError, match="Mismatch between URI parameters"): - - @mcp.resource("resource://data") - def get_data_fn(param: str) -> str: # pragma: no cover - return f"Data: {param}" - - @pytest.mark.anyio - async def test_resource_with_uri_params(self): - """Test that a resource with URI parameters is automatically a template""" - mcp = FastMCP() - - with pytest.raises(ValueError, match="Mismatch between URI parameters"): - - @mcp.resource("resource://{param}") - def get_data() -> str: # pragma: no cover - return "Data" - - @pytest.mark.anyio - async def test_resource_with_untyped_params(self): - """Test that a resource with untyped parameters raises an error""" - mcp = FastMCP() - - @mcp.resource("resource://{param}") - def get_data(param) -> str: # type: ignore # pragma: no cover - return "Data" - - @pytest.mark.anyio - async def test_resource_matching_params(self): - """Test that a resource with matching URI and function parameters works""" - mcp = FastMCP() - - @mcp.resource("resource://{name}/data") - def get_data(name: str) -> str: - return f"Data for {name}" - - async with Client(mcp) as client: - result = await client.read_resource("resource://test/data") - - async with Client(mcp) as client: - result = await client.read_resource("resource://test/data") - - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Data for test" - - @pytest.mark.anyio - async def test_resource_mismatched_params(self): - """Test that mismatched parameters raise an error""" - mcp = FastMCP() - - with pytest.raises(ValueError, match="Mismatch between URI parameters"): - - @mcp.resource("resource://{name}/data") - def get_data(user: str) -> str: # pragma: no cover - return f"Data for {user}" - - @pytest.mark.anyio - async def test_resource_multiple_params(self): - """Test that multiple parameters work correctly""" - mcp = FastMCP() - - @mcp.resource("resource://{org}/{repo}/data") - def get_data(org: str, repo: str) -> str: - return f"Data for {org}/{repo}" - - async with Client(mcp) as client: - result = await client.read_resource("resource://cursor/fastmcp/data") - - async with Client(mcp) as client: - result = await client.read_resource("resource://cursor/fastmcp/data") - - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Data for cursor/fastmcp" - - @pytest.mark.anyio - async def test_resource_multiple_mismatched_params(self): - """Test that mismatched parameters raise an error""" - mcp = FastMCP() - - with pytest.raises(ValueError, match="Mismatch between URI parameters"): - - @mcp.resource("resource://{org}/{repo}/data") - def get_data_mismatched(org: str, repo_2: str) -> str: # pragma: no cover - return f"Data for {org}" - - """Test that a resource with no parameters works as a regular resource""" # pragma: no cover - mcp = FastMCP() - - @mcp.resource("resource://static") - def get_static_data() -> str: - return "Static data" - - async with Client(mcp) as client: - result = await client.read_resource("resource://static") - - async with Client(mcp) as client: - result = await client.read_resource("resource://static") - - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Static data" - - @pytest.mark.anyio - async def test_template_to_resource_conversion(self): - """Test that templates are properly converted to resources when accessed""" - mcp = FastMCP() - - @mcp.resource("resource://{name}/data") - def get_data(name: str) -> str: - return f"Data for {name}" - - # Should be registered as a template - assert len(mcp._resource_manager._templates) == 1 - assert len(await mcp.list_resources()) == 0 - - # When accessed, should create a concrete resource - resource = await mcp._resource_manager.get_resource("resource://test/data") - assert isinstance(resource, FunctionResource) - result = await resource.read() - assert result == "Data for test" - - @pytest.mark.anyio - async def test_resource_template_includes_mime_type(self): - """Test that list resource templates includes the correct mimeType.""" - mcp = FastMCP() - - @mcp.resource("resource://{user}/csv", mime_type="text/csv") - def get_csv(user: str) -> str: - return f"csv for {user}" - - templates = await mcp.list_resource_templates() - assert len(templates) == 1 - template = templates[0] - - assert hasattr(template, "mime_type") - assert template.mime_type == "text/csv" - - async with Client(mcp) as client: - result = await client.read_resource("resource://bob/csv") - - async with Client(mcp) as client: - result = await client.read_resource("resource://bob/csv") - - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "csv for bob" - - -class TestServerResourceMetadata: - """Test FastMCP @resource decorator meta parameter for list operations. - - Meta flows: @resource decorator -> resource/template storage -> list_resources/list_resource_templates. - Note: read_resource does NOT pass meta to protocol response (lowlevel/server.py only extracts content/mime_type). - """ - - @pytest.mark.anyio - async def test_resource_decorator_with_metadata(self): - """Test that @resource decorator accepts and passes meta parameter.""" - # Tests static resource flow: decorator -> FunctionResource -> list_resources (server.py:544,635,361) - mcp = FastMCP() - - metadata = {"ui": {"component": "file-viewer"}, "priority": "high"} - - @mcp.resource("resource://config", meta=metadata) - def get_config() -> str: # pragma: no cover - return '{"debug": false}' - - resources = await mcp.list_resources() - assert len(resources) == 1 - assert resources[0].meta is not None - assert resources[0].meta == metadata - assert resources[0].meta["ui"]["component"] == "file-viewer" - assert resources[0].meta["priority"] == "high" - - @pytest.mark.anyio - async def test_resource_template_decorator_with_metadata(self): - """Test that @resource decorator passes meta to templates.""" - # Tests template resource flow: decorator -> add_template() -> list_resource_templates (server.py:544,622,377) - mcp = FastMCP() - - metadata = {"api_version": "v2", "deprecated": False} - - @mcp.resource("resource://{city}/weather", meta=metadata) - def get_weather(city: str) -> str: # pragma: no cover - return f"Weather for {city}" - - templates = await mcp.list_resource_templates() - assert len(templates) == 1 - assert templates[0].meta is not None - assert templates[0].meta == metadata - assert templates[0].meta["api_version"] == "v2" - - @pytest.mark.anyio - async def test_read_resource_returns_meta(self): - """Test that read_resource includes meta in response.""" - # Tests end-to-end: Resource.meta -> ReadResourceContents.meta -> protocol _meta (lowlevel/server.py:341,371) - mcp = FastMCP() - - metadata = {"version": "1.0", "category": "config"} - - @mcp.resource("resource://data", meta=metadata) - def get_data() -> str: - return "test data" - - async with Client(mcp) as client: - result = await client.read_resource("resource://data") - - async with Client(mcp) as client: - result = await client.read_resource("resource://data") - - # Verify content and metadata in protocol response - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "test data" - assert result.contents[0].meta is not None - assert result.contents[0].meta == metadata - assert result.contents[0].meta["version"] == "1.0" - assert result.contents[0].meta["category"] == "config" - - -class TestContextInjection: - """Test context injection in tools, resources, and prompts.""" - - @pytest.mark.anyio - async def test_context_detection(self): - """Test that context parameters are properly detected.""" - mcp = FastMCP() - - def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: # pragma: no cover - return f"Request {ctx.request_id}: {x}" - - tool = mcp._tool_manager.add_tool(tool_with_context) - assert tool.context_kwarg == "ctx" - - @pytest.mark.anyio - async def test_context_injection(self): - """Test that context is properly injected into tool calls.""" - mcp = FastMCP() - - def tool_with_context(x: int, ctx: Context[ServerSession, None]) -> str: - assert ctx.request_id is not None - return f"Request {ctx.request_id}: {x}" - - mcp.add_tool(tool_with_context) - async with Client(mcp) as client: - result = await client.call_tool("tool_with_context", {"x": 42}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, TextContent) - assert "Request" in content.text - assert "42" in content.text - - @pytest.mark.anyio - async def test_async_context(self): - """Test that context works in async functions.""" - mcp = FastMCP() - - async def async_tool(x: int, ctx: Context[ServerSession, None]) -> str: - assert ctx.request_id is not None - return f"Async request {ctx.request_id}: {x}" - - mcp.add_tool(async_tool) - async with Client(mcp) as client: - result = await client.call_tool("async_tool", {"x": 42}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, TextContent) - assert "Async request" in content.text - assert "42" in content.text - - @pytest.mark.anyio - async def test_context_logging(self): - """Test that context logging methods work.""" - mcp = FastMCP() - - async def logging_tool(msg: str, ctx: Context[ServerSession, None]) -> str: - await ctx.debug("Debug message") - await ctx.info("Info message") - await ctx.warning("Warning message") - await ctx.error("Error message") - return f"Logged messages for {msg}" - - mcp.add_tool(logging_tool) - - with patch("mcp.server.session.ServerSession.send_log_message") as mock_log: - async with Client(mcp) as client: - result = await client.call_tool("logging_tool", {"msg": "test"}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, TextContent) - assert "Logged messages for test" in content.text - - assert mock_log.call_count == 4 - mock_log.assert_any_call( - level="debug", - data="Debug message", - logger=None, - related_request_id="1", - ) - mock_log.assert_any_call( - level="info", - data="Info message", - logger=None, - related_request_id="1", - ) - mock_log.assert_any_call( - level="warning", - data="Warning message", - logger=None, - related_request_id="1", - ) - mock_log.assert_any_call( - level="error", - data="Error message", - logger=None, - related_request_id="1", - ) - - @pytest.mark.anyio - async def test_optional_context(self): - """Test that context is optional.""" - mcp = FastMCP() - - def no_context(x: int) -> int: - return x * 2 - - mcp.add_tool(no_context) - async with Client(mcp) as client: - result = await client.call_tool("no_context", {"x": 21}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, TextContent) - assert content.text == "42" - - @pytest.mark.anyio - async def test_context_resource_access(self): - """Test that context can access resources.""" - mcp = FastMCP() - - @mcp.resource("test://data") - def test_resource() -> str: - return "resource data" - - @mcp.tool() - async def tool_with_resource(ctx: Context[ServerSession, None]) -> str: - r_iter = await ctx.read_resource("test://data") - r_list = list(r_iter) - assert len(r_list) == 1 - r = r_list[0] - return f"Read resource: {r.content} with mime type {r.mime_type}" - - async with Client(mcp) as client: - result = await client.call_tool("tool_with_resource", {}) - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, TextContent) - assert "Read resource: resource data" in content.text - - @pytest.mark.anyio - async def test_resource_with_context(self): - """Test that resources can receive context parameter.""" - mcp = FastMCP() - - @mcp.resource("resource://context/{name}") - def resource_with_context(name: str, ctx: Context[ServerSession, None]) -> str: - """Resource that receives context.""" - assert ctx is not None - return f"Resource {name} - context injected" - - # Verify template has context_kwarg set - templates = mcp._resource_manager.list_templates() - assert len(templates) == 1 - template = templates[0] - assert hasattr(template, "context_kwarg") - assert template.context_kwarg == "ctx" - - # Test via client - - async with Client(mcp) as client: - result = await client.read_resource("resource://context/test") - - async with Client(mcp) as client: - result = await client.read_resource("resource://context/test") - - assert len(result.contents) == 1 - content = result.contents[0] - assert isinstance(content, TextResourceContents) - # Should have either request_id or indication that context was injected - assert "Resource test - context injected" == content.text - - @pytest.mark.anyio - async def test_resource_without_context(self): - """Test that resources without context work normally.""" - mcp = FastMCP() - - @mcp.resource("resource://nocontext/{name}") - def resource_no_context(name: str) -> str: - """Resource without context.""" - return f"Resource {name} works" - - # Verify template has no context_kwarg - templates = mcp._resource_manager.list_templates() - assert len(templates) == 1 - template = templates[0] - assert template.context_kwarg is None - - # Test via client - - async with Client(mcp) as client: - result = await client.read_resource("resource://nocontext/test") - - async with Client(mcp) as client: - result = await client.read_resource("resource://nocontext/test") - - assert len(result.contents) == 1 - content = result.contents[0] - assert isinstance(content, TextResourceContents) - assert content.text == "Resource test works" - - @pytest.mark.anyio - async def test_resource_context_custom_name(self): - """Test resource context with custom parameter name.""" - mcp = FastMCP() - - @mcp.resource("resource://custom/{id}") - def resource_custom_ctx(id: str, my_ctx: Context[ServerSession, None]) -> str: - """Resource with custom context parameter name.""" - assert my_ctx is not None - return f"Resource {id} with context" - - # Verify template detects custom context parameter - templates = mcp._resource_manager.list_templates() - assert len(templates) == 1 - template = templates[0] - assert template.context_kwarg == "my_ctx" - - # Test via client - - async with Client(mcp) as client: - result = await client.read_resource("resource://custom/123") - - async with Client(mcp) as client: - result = await client.read_resource("resource://custom/123") - - assert len(result.contents) == 1 - content = result.contents[0] - assert isinstance(content, TextResourceContents) - assert "Resource 123 with context" in content.text - - @pytest.mark.anyio - async def test_prompt_with_context(self): - """Test that prompts can receive context parameter.""" - mcp = FastMCP() - - @mcp.prompt("prompt_with_ctx") - def prompt_with_context(text: str, ctx: Context[ServerSession, None]) -> str: - """Prompt that expects context.""" - assert ctx is not None - return f"Prompt '{text}' - context injected" - - # Check if prompt has context parameter detection - prompts = mcp._prompt_manager.list_prompts() - assert len(prompts) == 1 - - # Test via client - async with Client(mcp) as client: - # Try calling without passing ctx explicitly - result = await client.get_prompt("prompt_with_ctx", {"text": "test"}) - # If this succeeds, check if context was injected - assert len(result.messages) == 1 - content = result.messages[0].content - assert isinstance(content, TextContent) - assert "Prompt 'test' - context injected" in content.text - - @pytest.mark.anyio - async def test_prompt_without_context(self): - """Test that prompts without context work normally.""" - mcp = FastMCP() - - @mcp.prompt("prompt_no_ctx") - def prompt_no_context(text: str) -> str: - """Prompt without context.""" - return f"Prompt '{text}' works" - - # Test via client - async with Client(mcp) as client: - result = await client.get_prompt("prompt_no_ctx", {"text": "test"}) - assert len(result.messages) == 1 - message = result.messages[0] - content = message.content - assert isinstance(content, TextContent) - assert content.text == "Prompt 'test' works" - - -class TestServerPrompts: - """Test prompt functionality in FastMCP server.""" - - @pytest.mark.anyio - async def test_prompt_decorator(self): - """Test that the prompt decorator registers prompts correctly.""" - mcp = FastMCP() - - @mcp.prompt() - def fn() -> str: - return "Hello, world!" - - prompts = mcp._prompt_manager.list_prompts() - assert len(prompts) == 1 - assert prompts[0].name == "fn" - # Don't compare functions directly since validate_call wraps them - content = await prompts[0].render() - assert isinstance(content[0].content, TextContent) - assert content[0].content.text == "Hello, world!" - - @pytest.mark.anyio - async def test_prompt_decorator_with_name(self): - """Test prompt decorator with custom name.""" - mcp = FastMCP() - - @mcp.prompt(name="custom_name") - def fn() -> str: - return "Hello, world!" - - prompts = mcp._prompt_manager.list_prompts() - assert len(prompts) == 1 - assert prompts[0].name == "custom_name" - content = await prompts[0].render() - assert isinstance(content[0].content, TextContent) - assert content[0].content.text == "Hello, world!" - - @pytest.mark.anyio - async def test_prompt_decorator_with_description(self): - """Test prompt decorator with custom description.""" - mcp = FastMCP() - - @mcp.prompt(description="A custom description") - def fn() -> str: - return "Hello, world!" - - prompts = mcp._prompt_manager.list_prompts() - assert len(prompts) == 1 - assert prompts[0].description == "A custom description" - content = await prompts[0].render() - assert isinstance(content[0].content, TextContent) - assert content[0].content.text == "Hello, world!" - - def test_prompt_decorator_error(self): - """Test error when decorator is used incorrectly.""" - mcp = FastMCP() - with pytest.raises(TypeError, match="decorator was used incorrectly"): - - @mcp.prompt # type: ignore - def fn() -> str: # pragma: no cover - return "Hello, world!" - - @pytest.mark.anyio - async def test_list_prompts(self): - """Test listing prompts through MCP protocol.""" - mcp = FastMCP() - - @mcp.prompt() - def fn(name: str, optional: str = "default") -> str: # pragma: no cover - return f"Hello, {name}!" - - async with Client(mcp) as client: - result = await client.list_prompts() - assert result.prompts is not None - assert len(result.prompts) == 1 - prompt = result.prompts[0] - assert prompt.name == "fn" - assert prompt.arguments is not None - assert len(prompt.arguments) == 2 - assert prompt.arguments[0].name == "name" - assert prompt.arguments[0].required is True - assert prompt.arguments[1].name == "optional" - assert prompt.arguments[1].required is False - - @pytest.mark.anyio - async def test_get_prompt(self): - """Test getting a prompt through MCP protocol.""" - mcp = FastMCP() - - @mcp.prompt() - def fn(name: str) -> str: - return f"Hello, {name}!" - - async with Client(mcp) as client: - result = await client.get_prompt("fn", {"name": "World"}) - assert len(result.messages) == 1 - message = result.messages[0] - assert message.role == "user" - content = message.content - assert isinstance(content, TextContent) - assert content.text == "Hello, World!" - - @pytest.mark.anyio - async def test_get_prompt_with_description(self): - """Test getting a prompt through MCP protocol.""" - mcp = FastMCP() - - @mcp.prompt(description="Test prompt description") - def fn(name: str) -> str: - return f"Hello, {name}!" - - async with Client(mcp) as client: - result = await client.get_prompt("fn", {"name": "World"}) - assert result.description == "Test prompt description" - - @pytest.mark.anyio - async def test_get_prompt_without_description(self): - """Test getting a prompt without description returns empty string.""" - mcp = FastMCP() - - @mcp.prompt() - def fn(name: str) -> str: - return f"Hello, {name}!" - - async with Client(mcp) as client: - result = await client.get_prompt("fn", {"name": "World"}) - assert result.description == "" - - @pytest.mark.anyio - async def test_get_prompt_with_docstring_description(self): - """Test prompt uses docstring as description when not explicitly provided.""" - mcp = FastMCP() - - @mcp.prompt() - def fn(name: str) -> str: - """This is the function docstring.""" - return f"Hello, {name}!" - - async with Client(mcp) as client: - result = await client.get_prompt("fn", {"name": "World"}) - assert result.description == "This is the function docstring." - - @pytest.mark.anyio - async def test_get_prompt_with_resource(self): - """Test getting a prompt that returns resource content.""" - mcp = FastMCP() - - @mcp.prompt() - def fn() -> Message: - return UserMessage( - content=EmbeddedResource( - type="resource", - resource=TextResourceContents( - uri="file://file.txt", - text="File contents", - mime_type="text/plain", - ), - ) - ) - - async with Client(mcp) as client: - result = await client.get_prompt("fn") - assert len(result.messages) == 1 - message = result.messages[0] - assert message.role == "user" - content = message.content - assert isinstance(content, EmbeddedResource) - resource = content.resource - assert isinstance(resource, TextResourceContents) - assert resource.text == "File contents" - assert resource.mime_type == "text/plain" - - @pytest.mark.anyio - async def test_get_unknown_prompt(self): - """Test error when getting unknown prompt.""" - mcp = FastMCP() - async with Client(mcp) as client: - with pytest.raises(McpError, match="Unknown prompt"): - await client.get_prompt("unknown") - - @pytest.mark.anyio - async def test_get_prompt_missing_args(self): - """Test error when required arguments are missing.""" - mcp = FastMCP() - - @mcp.prompt() - def prompt_fn(name: str) -> str: # pragma: no cover - return f"Hello, {name}!" - - async with Client(mcp) as client: - with pytest.raises(McpError, match="Missing required arguments"): - await client.get_prompt("prompt_fn") - - -def test_streamable_http_no_redirect() -> None: - """Test that streamable HTTP routes are correctly configured.""" - mcp = FastMCP() - # streamable_http_path defaults to "/mcp" - app = mcp.streamable_http_app() - - # Find routes by type - streamable_http_app creates Route objects, not Mount objects - streamable_routes = [r for r in app.routes if isinstance(r, Route) and hasattr(r, "path") and r.path == "/mcp"] - - # Verify routes exist - assert len(streamable_routes) == 1, "Should have one streamable route" - - # Verify path values - assert streamable_routes[0].path == "/mcp", "Streamable route path should be /mcp"