Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions examples/mcp/elicitations/blocking_elicitation_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""
MCP Server for Testing Elicitation Blocking Scenarios

This server provides tools to test the concurrent POST fix for elicitation.
It demonstrates both:
1. Associated elicitation (via POST response SSE stream)
2. Dissociated elicitation (via GET stream)

The blocking issue occurs when:
- Client sends tools/call POST
- Server sends elicitation request to client
- Client tries to POST elicitation response
- Client's response is blocked by HTTP client connection pool

Without the concurrent POST fix, the elicitation response cannot be sent
until the original tools/call POST completes, causing timeouts.

Run with: python examples/mcp/elicitations/blocking_elicitation_server.py
Connect to: http://127.0.0.1:8000/mcp
"""

import logging
import sys
from typing import TYPE_CHECKING, Any

from mcp.server.elicitation import (
AcceptedElicitation,
CancelledElicitation,
DeclinedElicitation,
)
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel, Field

if TYPE_CHECKING:
from mcp.types import ElicitResult

# Configure detailed logging.  Timestamp + logger name + level lets interleaved
# client/server output be correlated when debugging blocking scenarios.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    stream=sys.stderr,  # log to stderr rather than stdout
)
logger = logging.getLogger("blocking_elicitation_server")

# Create MCP server (host/port are configured here, not in run()).
# log_level="DEBUG" controls FastMCP's own logging, separate from basicConfig above.
mcp = FastMCP(
    "Blocking Elicitation Test Server",
    log_level="DEBUG",
    host="127.0.0.1",
    port=8000,
)


class DeploymentConfig(BaseModel):
    """Schema for deployment confirmation elicitation.

    Used as the ``schema=`` argument to ``ctx.elicit()`` in
    ``deploy_associated``; pydantic derives the elicitation JSON schema from it.
    """

    # NOTE(review): json_schema_extra only augments the emitted JSON schema —
    # the plain `str` field is NOT validated against the enum server-side, so
    # enforcement is up to the elicitation client.
    environment: str = Field(
        description="Target environment for deployment",
        json_schema_extra={"enum": ["development", "staging", "production"]},
    )
    # Explicit yes/no confirmation gate for the deployment.
    confirm: bool = Field(description="Confirm the deployment?")


@mcp.tool()
async def deploy_associated() -> str:
    """
    Test elicitation via the POST response SSE stream (request-associated).

    FastMCP's ``ctx.elicit()`` sets ``related_request_id`` automatically, so the
    elicitation travels on the POST response stream of the originating
    tools/call.

    Expected behavior:
    - Without the concurrent-POST fix: the client blocks and the elicitation
      may time out.
    - With the fix: the elicitation completes immediately.
    """
    context = mcp.get_context()
    logger.info("deploy_associated: Sending elicitation via POST response SSE")

    outcome = await context.elicit(
        "Confirm deployment configuration (associated - via POST SSE)",
        schema=DeploymentConfig,
    )

    # Dispatch on the concrete result type returned by ctx.elicit().
    if isinstance(outcome, AcceptedElicitation):
        data = outcome.data
        logger.info(f"Elicitation accepted: {data}")
        return f"Deployed to {data.environment} (confirm={data.confirm})"
    if isinstance(outcome, DeclinedElicitation):
        logger.info("Elicitation declined")
        return "Deployment declined by user"
    if isinstance(outcome, CancelledElicitation):
        logger.info("Elicitation cancelled")
        return "Deployment cancelled"
    result = outcome
    return f"Unexpected result: {result}"


@mcp.tool()
async def deploy_dissociated() -> str:
    """
    Test elicitation via the GET stream (dissociated from the request).

    Bypasses FastMCP and calls ``session.elicit_form()`` directly WITHOUT a
    ``related_request_id``, which routes the elicitation over the GET stream
    instead of the POST response stream — matching the scenario from the user's
    logs where the elicitation response is blocked.

    Expected behavior:
    - Without the concurrent-POST fix: the client blocks and the elicitation
      times out after ~20s.
    - With the fix: the elicitation completes immediately.
    """
    context = mcp.get_context()
    session = context.request_context.session

    logger.info("deploy_dissociated: Sending elicitation via GET stream (no related_request_id)")

    # Hand-built JSON schema (no pydantic model): session.elicit_form is
    # called directly so no related_request_id gets attached.
    requested_schema: dict[str, Any] = {
        "type": "object",
        "properties": {
            "environment": {
                "type": "string",
                "title": "Environment",
                "enum": ["development", "staging", "production"],
            },
            "confirm": {
                "type": "boolean",
                "title": "Confirm deployment?",
            },
        },
        "required": ["environment", "confirm"],
    }

    result: ElicitResult = await session.elicit_form(
        message="Confirm deployment configuration (dissociated - via GET stream)",
        requestedSchema=requested_schema,
        related_request_id=None,  # <-- KEY: No related_request_id = routes to GET stream
    )

    logger.info(f"Elicitation result: {result}")

    action = result.action
    if action == "accept":
        payload = result.content or {}
        env = payload.get("environment", "unknown")
        confirm = payload.get("confirm", False)
        return f"Deployed to {env} (confirm={confirm})"
    if action == "decline":
        return "Deployment declined by user"
    if action == "cancel":
        return "Deployment cancelled"
    return f"Unexpected action: {action}"


@mcp.tool()
async def ping() -> str:
    """Basic connectivity check: always answers with a constant reply."""
    reply = "pong"
    return reply


if __name__ == "__main__":
    # Announce startup and the available test tools, then serve over
    # streamable HTTP (host/port were fixed at FastMCP construction time).
    logger.info("Starting blocking elicitation test server...")
    logger.info("Connect to: http://127.0.0.1:8000/mcp")
    logger.info("")
    logger.info("Available tools:")
    for tool_line in (
        " - deploy_associated: Elicitation via POST response SSE",
        " - deploy_dissociated: Elicitation via GET stream",
        " - ping: Basic connectivity test",
    ):
        logger.info(tool_line)
    mcp.run(transport="streamable-http")
16 changes: 10 additions & 6 deletions src/fast_agent/mcp/mcp_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@
get_default_environment,
)
from mcp.client.streamable_http import GetSessionIdCallback
from mcp.shared._httpx_utils import (
MCP_DEFAULT_SSE_READ_TIMEOUT,
MCP_DEFAULT_TIMEOUT,
create_mcp_http_client,
)
from mcp.shared._httpx_utils import MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT
from mcp.types import Implementation, JSONRPCMessage, ServerCapabilities

from fast_agent.config import MCPServerSettings
Expand Down Expand Up @@ -696,10 +692,18 @@ def channel_hook(event):
read=config.http_read_timeout_seconds or MCP_DEFAULT_SSE_READ_TIMEOUT,
)

http_client = create_mcp_http_client(
# Use an HTTP client that allows concurrent POSTs so elicitation
# responses are not blocked behind long-running requests.
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=0, # force new connection per request to avoid blocking
)
http_client = httpx.AsyncClient(
headers=headers,
auth=oauth_auth,
timeout=timeout,
limits=limits,
http2=False, # avoid h2 dependency; still allow concurrent HTTP/1.1 connections
)
return tracking_streamablehttp_client(
config.url,
Expand Down
77 changes: 74 additions & 3 deletions src/fast_agent/mcp/streamable_http_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
StreamWriter,
)
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.message import SessionMessage
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse

from fast_agent.mcp.transport_tracking import ChannelEvent, ChannelName

if TYPE_CHECKING:

from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.abc import ObjectReceiveStream, ObjectSendStream, TaskGroup

logger = logging.getLogger(__name__)

Expand All @@ -44,6 +44,78 @@ def __init__(
super().__init__(url)
self._channel_hook = channel_hook

async def post_writer(  # type: ignore[override]
    self,
    client: httpx.AsyncClient,
    write_stream_reader: "ObjectReceiveStream[SessionMessage]",
    read_stream_writer: StreamWriter,
    write_stream: "ObjectSendStream[SessionMessage]",
    start_get_stream: Callable[[], None],
    tg: "TaskGroup",
) -> None:
    """
    Override to dispatch all outbound messages asynchronously.

    The base transport awaits non-request messages, which can block elicitation
    responses behind a long-lived tools/call POST. Running everything in the
    task group ensures elicitation/create replies are sent immediately.
    """
    try:
        async with write_stream_reader:
            async for session_message in write_stream_reader:
                message = session_message.message
                metadata = (
                    session_message.metadata
                    if isinstance(session_message.metadata, ClientMessageMetadata)
                    else None
                )
                root = message.root if isinstance(message, JSONRPCMessage) else None

                # For responses/errors, use a short-lived client to avoid blocking behind long POSTs.
                response_client: httpx.AsyncClient | None = None
                if isinstance(root, (JSONRPCResponse, JSONRPCError)):
                    response_client = httpx.AsyncClient(
                        headers=client.headers,
                        timeout=client.timeout,
                        limits=httpx.Limits(
                            max_connections=10,
                            max_keepalive_connections=0,
                        ),
                        http2=False,
                    )

                is_resumption = bool(metadata and metadata.resumption_token)

                # Handle initialized notification
                if self._is_initialized_notification(message):
                    start_get_stream()

                ctx = RequestContext(
                    client=response_client or client,
                    session_id=self.session_id,
                    session_message=session_message,
                    metadata=metadata,
                    read_stream_writer=read_stream_writer,
                )

                # Bind the loop's per-message state as default arguments: the
                # spawned task may only start after the loop has advanced, and a
                # late-binding closure would then handle the WRONG message
                # (classic loop-closure pitfall with start_soon in a loop).
                async def handle_request_async(
                    ctx: RequestContext = ctx,
                    is_resumption: bool = is_resumption,
                    response_client: httpx.AsyncClient | None = response_client,
                ) -> None:
                    try:
                        if is_resumption:
                            await self._handle_resumption_request(ctx)
                        else:
                            await self._handle_post_request(ctx)
                    finally:
                        # Close the per-response client even when the handler
                        # raises, so connections are never leaked.
                        if response_client:
                            await response_client.aclose()

                # Always dispatch asynchronously so responses are not gated by
                # any in-flight POST (e.g., tools/call).
                tg.start_soon(handle_request_async)

    except Exception:
        logger.exception("Error in post_writer")  # pragma: no cover
    finally:
        await read_stream_writer.aclose()
        await write_stream.aclose()

def _emit_channel_event(
self,
channel: ChannelName,
Expand Down Expand Up @@ -375,7 +447,6 @@ async def tracking_streamablehttp_client(

client_provided = http_client is not None
client = http_client or create_mcp_http_client()

async with anyio.create_task_group() as tg:
try:
async with AsyncExitStack() as stack:
Expand Down
Loading