Skip to content

Commit 538136a

Browse files
committed
test: add streamable HTTP hosting, resumability, and client transport conformance tests
1 parent 584e098 commit 538136a

15 files changed

Lines changed: 1356 additions & 163 deletions

src/mcp/client/streamable_http.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:
210210
# Stream ended normally (server closed) - reset attempt counter
211211
attempt = 0
212212

213-
except Exception: # pragma: lax no cover
213+
except Exception:
214214
logger.debug("GET stream error", exc_info=True)
215215
attempt += 1
216216

@@ -492,17 +492,17 @@ async def handle_request_async():
492492

493493
async def terminate_session(self, client: httpx.AsyncClient) -> None:
494494
"""Terminate the session by sending a DELETE request."""
495-
if not self.session_id: # pragma: lax no cover
496-
return
495+
if not self.session_id:
496+
return # pragma: no cover
497497

498498
try:
499499
headers = self._prepare_headers()
500500
response = await client.delete(self.url, headers=headers)
501501

502-
if response.status_code == 405: # pragma: lax no cover
502+
if response.status_code == 405:
503503
logger.debug("Server does not allow session termination")
504-
elif response.status_code not in (200, 204): # pragma: lax no cover
505-
logger.warning(f"Session termination failed: {response.status_code}")
504+
elif response.status_code not in (200, 204):
505+
logger.warning(f"Session termination failed: {response.status_code}") # pragma: no cover
506506
except Exception as exc: # pragma: no cover
507507
logger.warning(f"Session termination failed: {exc}")
508508

src/mcp/server/mcpserver/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ async def close_sse_stream(self) -> None:
237237
This is a no-op if not using StreamableHTTP transport with event_store.
238238
The callback is only available when event_store is configured.
239239
"""
240-
if self._request_context and self._request_context.close_sse_stream: # pragma: no cover
240+
if self._request_context and self._request_context.close_sse_stream: # pragma: no branch
241241
await self._request_context.close_sse_stream()
242242

243243
async def close_standalone_sse_stream(self) -> None:

src/mcp/server/streamable_http.py

Lines changed: 47 additions & 47 deletions
Large diffs are not rendered by default.

src/mcp/server/streamable_http_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
173173
self.app.create_initialization_options(),
174174
stateless=True,
175175
)
176-
except Exception: # pragma: no cover
176+
except Exception: # pragma: lax no cover
177177
logger.exception("Stateless session crashed")
178178

179179
# Assert task group is not None for type checking

src/mcp/server/transport_security.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,19 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
4040
# If not specified, disable DNS rebinding protection by default for backwards compatibility
4141
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)
4242

43-
def _validate_host(self, host: str | None) -> bool: # pragma: no cover
43+
def _validate_host(self, host: str | None) -> bool:
4444
"""Validate the Host header against allowed values."""
45-
if not host:
45+
if not host: # pragma: no cover
4646
logger.warning("Missing Host header in request")
4747
return False
4848

4949
# Check exact match first
50-
if host in self.settings.allowed_hosts:
50+
if host in self.settings.allowed_hosts: # pragma: no cover
5151
return True
5252

5353
# Check wildcard port patterns
5454
for allowed in self.settings.allowed_hosts:
55-
if allowed.endswith(":*"):
55+
if allowed.endswith(":*"): # pragma: no branch
5656
# Extract base host from pattern
5757
base_host = allowed[:-2]
5858
# Check if the actual host starts with base host and has a port
@@ -62,19 +62,19 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
6262
logger.warning(f"Invalid Host header: {host}")
6363
return False
6464

65-
def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
65+
def _validate_origin(self, origin: str | None) -> bool:
6666
"""Validate the Origin header against allowed values."""
6767
# Origin can be absent for same-origin requests
68-
if not origin:
68+
if not origin: # pragma: no cover
6969
return True
7070

7171
# Check exact match first
72-
if origin in self.settings.allowed_origins:
72+
if origin in self.settings.allowed_origins: # pragma: no cover
7373
return True
7474

7575
# Check wildcard port patterns
7676
for allowed in self.settings.allowed_origins:
77-
if allowed.endswith(":*"):
77+
if allowed.endswith(":*"): # pragma: no branch
7878
# Extract base origin from pattern
7979
base_origin = allowed[:-2]
8080
# Check if the actual origin starts with base origin and has a port
@@ -103,14 +103,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
103103
if not self.settings.enable_dns_rebinding_protection:
104104
return None
105105

106-
# Validate Host header # pragma: no cover
107-
host = request.headers.get("host") # pragma: no cover
108-
if not self._validate_host(host): # pragma: no cover
109-
return Response("Invalid Host header", status_code=421) # pragma: no cover
106+
# Validate Host header
107+
host = request.headers.get("host")
108+
if not self._validate_host(host):
109+
return Response("Invalid Host header", status_code=421)
110110

111-
# Validate Origin header # pragma: no cover
112-
origin = request.headers.get("origin") # pragma: no cover
113-
if not self._validate_origin(origin): # pragma: no cover
114-
return Response("Invalid Origin header", status_code=403) # pragma: no cover
111+
# Validate Origin header
112+
origin = request.headers.get("origin")
113+
if not self._validate_origin(origin):
114+
return Response("Invalid Origin header", status_code=403)
115115

116-
return None # pragma: no cover
116+
return None

tests/interaction/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ stream pair), the bare-`ClientSession` lifecycle tests, the real-clock timeout t
6262
machinery is transport-independent and must not race transport latency), and everything under
6363
`transports/`, which pins behaviour only observable on that transport.
6464

65+
A transport conformance test in `transports/` speaks raw `httpx` against the mounted ASGI app
66+
**only** when its assertion is about HTTP semantics that `Client` cannot observe — status codes,
67+
response headers, SSE event fields, which stream a message travels on. Any other behaviour is
68+
asserted through a `Client`, connected to the mounted app via `client_via_http(http)` so several
69+
clients can share one session manager.
70+
6571
## The requirements manifest
6672

6773
`_requirements.py` maps every behaviour the suite covers to the reason it must hold:

tests/interaction/_connect.py

Lines changed: 165 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99

1010
import gc
1111
import warnings
12-
from collections.abc import AsyncIterator
12+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
1313
from contextlib import AbstractAsyncContextManager, asynccontextmanager
1414
from typing import Protocol
1515

1616
import httpx
17+
from httpx_sse import ServerSentEvent, aconnect_sse
1718
from starlette.applications import Starlette
1819
from starlette.requests import Request
1920
from starlette.responses import Response
@@ -26,12 +27,29 @@
2627
from mcp.server import Server
2728
from mcp.server.mcpserver import MCPServer
2829
from mcp.server.sse import SseServerTransport
30+
from mcp.server.streamable_http import EventStore
31+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
2932
from mcp.server.transport_security import TransportSecuritySettings
30-
from mcp.types import Implementation
33+
from mcp.types import (
34+
LATEST_PROTOCOL_VERSION,
35+
ClientCapabilities,
36+
Implementation,
37+
InitializeRequestParams,
38+
JSONRPCMessage,
39+
JSONRPCRequest,
40+
JSONRPCResponse,
41+
jsonrpc_message_adapter,
42+
)
3143
from tests.interaction.transports._bridge import StreamingASGITransport
3244

3345
# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here.
34-
_BASE_URL = "http://127.0.0.1:8000"
46+
BASE_URL = "http://127.0.0.1:8000"
47+
48+
# DNS-rebinding protection validates Host/Origin headers against a real network attack that cannot
49+
# exist for an in-process ASGI app, so the in-process factories disable it; tests that exercise the
50+
# protection itself pass explicit settings (or transport_security=None to get the localhost
51+
# auto-enable behaviour).
52+
NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings(enable_dns_rebinding_protection=False)
3553

3654

3755
class Connect(Protocol):
@@ -86,6 +104,8 @@ async def connect_over_streamable_http(
86104
*,
87105
stateless_http: bool = False,
88106
json_response: bool = False,
107+
event_store: EventStore | None = None,
108+
retry_interval: int | None = None,
89109
read_timeout_seconds: float | None = None,
90110
sampling_callback: SamplingFnT | None = None,
91111
list_roots_callback: ListRootsFnT | None = None,
@@ -98,19 +118,19 @@ async def connect_over_streamable_http(
98118
99119
With the defaults this is the matrix leg (stateful sessions, SSE responses); the
100120
transport-specific tests pass `stateless_http` or `json_response` to select the other
101-
server modes.
121+
server modes, and the resumability tests pass an `event_store` (with `retry_interval=0` so
122+
the client's reconnection wait is a no-op).
102123
"""
103-
# DNS-rebinding protection validates Host/Origin headers against a real network attack that
104-
# cannot exist for an in-process ASGI app; leaving it on would also pull the origin-validation
105-
# branch (deliberately uncovered in src) into coverage.
106124
app = server.streamable_http_app(
107125
stateless_http=stateless_http,
108126
json_response=json_response,
109-
transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False),
127+
event_store=event_store,
128+
retry_interval=retry_interval,
129+
transport_security=NO_DNS_REBINDING_PROTECTION,
110130
)
111131
async with server.session_manager.run():
112-
async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http_client:
113-
transport = streamable_http_client(f"{_BASE_URL}/mcp", http_client=http_client)
132+
async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client:
133+
transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)
114134
async with Client(
115135
transport,
116136
read_timeout_seconds=read_timeout_seconds,
@@ -124,6 +144,139 @@ async def connect_over_streamable_http(
124144
yield client
125145

126146

147+
@asynccontextmanager
148+
async def mounted_app(
149+
server: Server | MCPServer,
150+
*,
151+
stateless_http: bool = False,
152+
event_store: EventStore | None = None,
153+
retry_interval: int | None = None,
154+
transport_security: TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION,
155+
on_request: Callable[[httpx.Request], Awaitable[None]] | None = None,
156+
headers: dict[str, str] | None = None,
157+
) -> AsyncIterator[tuple[httpx.AsyncClient, StreamableHTTPSessionManager]]:
158+
"""Mount the server's streamable HTTP app on the in-process bridge and yield an httpx client.
159+
160+
Yields the httpx client (rooted at the in-process origin) and the live session manager. Tests
161+
use this in two ways: for raw-httpx assertions (status codes, headers, SSE bytes) the test
162+
speaks HTTP through the yielded client directly; for client-driven assertions the test wraps
163+
that client in `client_via_http(http)`, which lets several `Client`s share the one mounted
164+
session manager. `on_request` records every outgoing HTTP request before it leaves the
165+
yielded client.
166+
167+
DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the
168+
localhost auto-enable behaviour) to test the protection itself.
169+
"""
170+
app = server.streamable_http_app(
171+
stateless_http=stateless_http,
172+
event_store=event_store,
173+
retry_interval=retry_interval,
174+
transport_security=transport_security,
175+
)
176+
event_hooks = {"request": [on_request]} if on_request is not None else None
177+
async with server.session_manager.run():
178+
async with httpx.AsyncClient(
179+
transport=StreamingASGITransport(app), base_url=BASE_URL, event_hooks=event_hooks, headers=headers
180+
) as http_client:
181+
yield http_client, server.session_manager
182+
183+
184+
@asynccontextmanager
185+
async def client_via_http(
186+
http_client: httpx.AsyncClient,
187+
*,
188+
logging_callback: LoggingFnT | None = None,
189+
message_handler: MessageHandlerFnT | None = None,
190+
elicitation_callback: ElicitationFnT | None = None,
191+
) -> AsyncIterator[Client]:
192+
"""Connect a `Client` over an already-mounted streamable HTTP app.
193+
194+
Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a
195+
client-driven assertion can sit alongside raw-httpx assertions in the same test. The
196+
underlying `httpx.AsyncClient` is left open when the `Client` exits.
197+
"""
198+
transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)
199+
async with Client(
200+
transport,
201+
logging_callback=logging_callback,
202+
message_handler=message_handler,
203+
elicitation_callback=elicitation_callback,
204+
) as client:
205+
yield client
206+
207+
208+
def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]:
209+
"""Decode SSE events into JSON-RPC messages, skipping priming events that carry no data."""
210+
return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data]
211+
212+
213+
async def post_jsonrpc(
214+
http: httpx.AsyncClient, body: dict[str, object], *, session_id: str | None = None
215+
) -> tuple[httpx.Response, list[JSONRPCMessage]]:
216+
"""POST a JSON-RPC body and read its SSE response stream to completion.
217+
218+
Returns the HTTP response (for header/status assertions) and the parsed JSON-RPC messages
219+
that arrived on the response's SSE stream. Only meaningful for requests the server answers
220+
with `text/event-stream`; for error responses or 202 notification acknowledgements, use
221+
`httpx.AsyncClient.post` directly and assert on the response.
222+
"""
223+
async with aconnect_sse(http, "POST", "/mcp", json=body, headers=base_headers(session_id=session_id)) as source:
224+
events = [event async for event in source.aiter_sse()]
225+
return source.response, parse_sse_messages(events)
226+
227+
228+
def base_headers(*, session_id: str | None = None) -> dict[str, str]:
229+
"""Standard request headers for raw-httpx streamable-HTTP tests.
230+
231+
Every well-formed request carries these (Accept covering both response representations,
232+
Content-Type for POST bodies, MCP-Protocol-Version at the latest revision, and the session
233+
ID once one exists), so a test that wants to assert a specific rejection only varies the one
234+
header under test.
235+
"""
236+
headers = {
237+
"accept": "application/json, text/event-stream",
238+
"content-type": "application/json",
239+
"mcp-protocol-version": LATEST_PROTOCOL_VERSION,
240+
}
241+
if session_id is not None:
242+
headers["mcp-session-id"] = session_id
243+
return headers
244+
245+
246+
def initialize_body(request_id: int = 1) -> dict[str, object]:
247+
"""A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it."""
248+
params = InitializeRequestParams(
249+
protocol_version=LATEST_PROTOCOL_VERSION,
250+
capabilities=ClientCapabilities(),
251+
client_info=Implementation(name="raw", version="0.0.0"),
252+
)
253+
return JSONRPCRequest(
254+
jsonrpc="2.0", id=request_id, method="initialize", params=params.model_dump(by_alias=True, exclude_none=True)
255+
).model_dump(by_alias=True, exclude_none=True)
256+
257+
258+
async def initialize_via_http(http: httpx.AsyncClient) -> str:
259+
"""Perform the initialize handshake over a raw `httpx.AsyncClient` and return the session ID.
260+
261+
Validates the SSE response and sends the `notifications/initialized` follow-up, so the server
262+
is fully ready for subsequent feature requests when this returns.
263+
"""
264+
async with aconnect_sse(http, "POST", "/mcp", json=initialize_body(), headers=base_headers()) as source:
265+
assert source.response.status_code == 200
266+
# An event-store-backed server opens the stream with a priming event (empty data); skip it.
267+
events = [event async for event in source.aiter_sse() if event.data]
268+
assert len(events) == 1
269+
assert JSONRPCResponse.model_validate_json(events[0].data).id == 1
270+
session_id = source.response.headers["mcp-session-id"]
271+
initialized = await http.post(
272+
"/mcp",
273+
json={"jsonrpc": "2.0", "method": "notifications/initialized"},
274+
headers=base_headers(session_id=session_id),
275+
)
276+
assert initialized.status_code == 202
277+
return session_id
278+
279+
127280
def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]:
128281
"""Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/.
129282
@@ -175,13 +328,13 @@ def httpx_client_factory(
175328
# bridge must let the application drain rather than cancelling at close.
176329
return httpx.AsyncClient(
177330
transport=StreamingASGITransport(app, cancel_on_close=False),
178-
base_url=_BASE_URL,
331+
base_url=BASE_URL,
179332
headers=headers,
180333
timeout=timeout,
181334
auth=auth,
182335
)
183336

184-
transport = sse_client(f"{_BASE_URL}/sse", httpx_client_factory=httpx_client_factory)
337+
transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory)
185338
try:
186339
async with Client(
187340
transport,

0 commit comments

Comments
 (0)