99
1010import gc
1111import warnings
12- from collections .abc import AsyncIterator
12+ from collections .abc import AsyncIterator , Awaitable , Callable , Iterable
1313from contextlib import AbstractAsyncContextManager , asynccontextmanager
1414from typing import Protocol
1515
1616import httpx
17+ from httpx_sse import ServerSentEvent , aconnect_sse
1718from starlette .applications import Starlette
1819from starlette .requests import Request
1920from starlette .responses import Response
2627from mcp .server import Server
2728from mcp .server .mcpserver import MCPServer
2829from mcp .server .sse import SseServerTransport
30+ from mcp .server .streamable_http import EventStore
31+ from mcp .server .streamable_http_manager import StreamableHTTPSessionManager
2932from mcp .server .transport_security import TransportSecuritySettings
30- from mcp .types import Implementation
33+ from mcp .types import (
34+ LATEST_PROTOCOL_VERSION ,
35+ ClientCapabilities ,
36+ Implementation ,
37+ InitializeRequestParams ,
38+ JSONRPCMessage ,
39+ JSONRPCRequest ,
40+ JSONRPCResponse ,
41+ jsonrpc_message_adapter ,
42+ )
3143from tests .interaction .transports ._bridge import StreamingASGITransport
3244
3345# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here.
34- _BASE_URL = "http://127.0.0.1:8000"
46+ BASE_URL = "http://127.0.0.1:8000"
47+
48+ # DNS-rebinding protection validates Host/Origin headers against a real network attack that cannot
49+ # exist for an in-process ASGI app, so the in-process factories disable it; tests that exercise the
50+ # protection itself pass explicit settings (or transport_security=None to get the localhost
51+ # auto-enable behaviour).
52+ NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings (enable_dns_rebinding_protection = False )
3553
3654
3755class Connect (Protocol ):
@@ -86,6 +104,8 @@ async def connect_over_streamable_http(
86104 * ,
87105 stateless_http : bool = False ,
88106 json_response : bool = False ,
107+ event_store : EventStore | None = None ,
108+ retry_interval : int | None = None ,
89109 read_timeout_seconds : float | None = None ,
90110 sampling_callback : SamplingFnT | None = None ,
91111 list_roots_callback : ListRootsFnT | None = None ,
@@ -98,19 +118,19 @@ async def connect_over_streamable_http(
98118
99119 With the defaults this is the matrix leg (stateful sessions, SSE responses); the
100120 transport-specific tests pass `stateless_http` or `json_response` to select the other
101- server modes.
121+ server modes, and the resumability tests pass an `event_store` (with `retry_interval=0` so
122+ the client's reconnection wait is a no-op).
102123 """
103- # DNS-rebinding protection validates Host/Origin headers against a real network attack that
104- # cannot exist for an in-process ASGI app; leaving it on would also pull the origin-validation
105- # branch (deliberately uncovered in src) into coverage.
106124 app = server .streamable_http_app (
107125 stateless_http = stateless_http ,
108126 json_response = json_response ,
109- transport_security = TransportSecuritySettings (enable_dns_rebinding_protection = False ),
127+ event_store = event_store ,
128+ retry_interval = retry_interval ,
129+ transport_security = NO_DNS_REBINDING_PROTECTION ,
110130 )
111131 async with server .session_manager .run ():
112- async with httpx .AsyncClient (transport = StreamingASGITransport (app ), base_url = _BASE_URL ) as http_client :
113- transport = streamable_http_client (f"{ _BASE_URL } /mcp" , http_client = http_client )
132+ async with httpx .AsyncClient (transport = StreamingASGITransport (app ), base_url = BASE_URL ) as http_client :
133+ transport = streamable_http_client (f"{ BASE_URL } /mcp" , http_client = http_client )
114134 async with Client (
115135 transport ,
116136 read_timeout_seconds = read_timeout_seconds ,
@@ -124,6 +144,139 @@ async def connect_over_streamable_http(
124144 yield client
125145
126146
147+ @asynccontextmanager
148+ async def mounted_app (
149+ server : Server | MCPServer ,
150+ * ,
151+ stateless_http : bool = False ,
152+ event_store : EventStore | None = None ,
153+ retry_interval : int | None = None ,
154+ transport_security : TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION ,
155+ on_request : Callable [[httpx .Request ], Awaitable [None ]] | None = None ,
156+ headers : dict [str , str ] | None = None ,
157+ ) -> AsyncIterator [tuple [httpx .AsyncClient , StreamableHTTPSessionManager ]]:
158+ """Mount the server's streamable HTTP app on the in-process bridge and yield an httpx client.
159+
160+ Yields the httpx client (rooted at the in-process origin) and the live session manager. Tests
161+ use this in two ways: for raw-httpx assertions (status codes, headers, SSE bytes) the test
162+ speaks HTTP through the yielded client directly; for client-driven assertions the test wraps
163+ that client in `client_via_http(http)`, which lets several `Client`s share the one mounted
164+ session manager. `on_request` records every outgoing HTTP request before it leaves the
165+ yielded client.
166+
167+ DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the
168+ localhost auto-enable behaviour) to test the protection itself.
169+ """
170+ app = server .streamable_http_app (
171+ stateless_http = stateless_http ,
172+ event_store = event_store ,
173+ retry_interval = retry_interval ,
174+ transport_security = transport_security ,
175+ )
176+ event_hooks = {"request" : [on_request ]} if on_request is not None else None
177+ async with server .session_manager .run ():
178+ async with httpx .AsyncClient (
179+ transport = StreamingASGITransport (app ), base_url = BASE_URL , event_hooks = event_hooks , headers = headers
180+ ) as http_client :
181+ yield http_client , server .session_manager
182+
183+
184+ @asynccontextmanager
185+ async def client_via_http (
186+ http_client : httpx .AsyncClient ,
187+ * ,
188+ logging_callback : LoggingFnT | None = None ,
189+ message_handler : MessageHandlerFnT | None = None ,
190+ elicitation_callback : ElicitationFnT | None = None ,
191+ ) -> AsyncIterator [Client ]:
192+ """Connect a `Client` over an already-mounted streamable HTTP app.
193+
194+ Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a
195+ client-driven assertion can sit alongside raw-httpx assertions in the same test. The
196+ underlying `httpx.AsyncClient` is left open when the `Client` exits.
197+ """
198+ transport = streamable_http_client (f"{ BASE_URL } /mcp" , http_client = http_client )
199+ async with Client (
200+ transport ,
201+ logging_callback = logging_callback ,
202+ message_handler = message_handler ,
203+ elicitation_callback = elicitation_callback ,
204+ ) as client :
205+ yield client
206+
207+
208+ def parse_sse_messages (events : Iterable [ServerSentEvent ]) -> list [JSONRPCMessage ]:
209+ """Decode SSE events into JSON-RPC messages, skipping priming events that carry no data."""
210+ return [jsonrpc_message_adapter .validate_json (event .data ) for event in events if event .data ]
211+
212+
213+ async def post_jsonrpc (
214+ http : httpx .AsyncClient , body : dict [str , object ], * , session_id : str | None = None
215+ ) -> tuple [httpx .Response , list [JSONRPCMessage ]]:
216+ """POST a JSON-RPC body and read its SSE response stream to completion.
217+
218+ Returns the HTTP response (for header/status assertions) and the parsed JSON-RPC messages
219+ that arrived on the response's SSE stream. Only meaningful for requests the server answers
220+ with `text/event-stream`; for error responses or 202 notification acknowledgements, use
221+ `httpx.AsyncClient.post` directly and assert on the response.
222+ """
223+ async with aconnect_sse (http , "POST" , "/mcp" , json = body , headers = base_headers (session_id = session_id )) as source :
224+ events = [event async for event in source .aiter_sse ()]
225+ return source .response , parse_sse_messages (events )
226+
227+
228+ def base_headers (* , session_id : str | None = None ) -> dict [str , str ]:
229+ """Standard request headers for raw-httpx streamable-HTTP tests.
230+
231+ Every well-formed request carries these (Accept covering both response representations,
232+ Content-Type for POST bodies, MCP-Protocol-Version at the latest revision, and the session
233+ ID once one exists), so a test that wants to assert a specific rejection only varies the one
234+ header under test.
235+ """
236+ headers = {
237+ "accept" : "application/json, text/event-stream" ,
238+ "content-type" : "application/json" ,
239+ "mcp-protocol-version" : LATEST_PROTOCOL_VERSION ,
240+ }
241+ if session_id is not None :
242+ headers ["mcp-session-id" ] = session_id
243+ return headers
244+
245+
246+ def initialize_body (request_id : int = 1 ) -> dict [str , object ]:
247+ """A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it."""
248+ params = InitializeRequestParams (
249+ protocol_version = LATEST_PROTOCOL_VERSION ,
250+ capabilities = ClientCapabilities (),
251+ client_info = Implementation (name = "raw" , version = "0.0.0" ),
252+ )
253+ return JSONRPCRequest (
254+ jsonrpc = "2.0" , id = request_id , method = "initialize" , params = params .model_dump (by_alias = True , exclude_none = True )
255+ ).model_dump (by_alias = True , exclude_none = True )
256+
257+
258+ async def initialize_via_http (http : httpx .AsyncClient ) -> str :
259+ """Perform the initialize handshake over a raw `httpx.AsyncClient` and return the session ID.
260+
261+ Validates the SSE response and sends the `notifications/initialized` follow-up, so the server
262+ is fully ready for subsequent feature requests when this returns.
263+ """
264+ async with aconnect_sse (http , "POST" , "/mcp" , json = initialize_body (), headers = base_headers ()) as source :
265+ assert source .response .status_code == 200
266+ # An event-store-backed server opens the stream with a priming event (empty data); skip it.
267+ events = [event async for event in source .aiter_sse () if event .data ]
268+ assert len (events ) == 1
269+ assert JSONRPCResponse .model_validate_json (events [0 ].data ).id == 1
270+ session_id = source .response .headers ["mcp-session-id" ]
271+ initialized = await http .post (
272+ "/mcp" ,
273+ json = {"jsonrpc" : "2.0" , "method" : "notifications/initialized" },
274+ headers = base_headers (session_id = session_id ),
275+ )
276+ assert initialized .status_code == 202
277+ return session_id
278+
279+
127280def build_sse_app (server : Server | MCPServer ) -> tuple [Starlette , SseServerTransport ]:
128281 """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/.
129282
@@ -175,13 +328,13 @@ def httpx_client_factory(
175328 # bridge must let the application drain rather than cancelling at close.
176329 return httpx .AsyncClient (
177330 transport = StreamingASGITransport (app , cancel_on_close = False ),
178- base_url = _BASE_URL ,
331+ base_url = BASE_URL ,
179332 headers = headers ,
180333 timeout = timeout ,
181334 auth = auth ,
182335 )
183336
184- transport = sse_client (f"{ _BASE_URL } /sse" , httpx_client_factory = httpx_client_factory )
337+ transport = sse_client (f"{ BASE_URL } /sse" , httpx_client_factory = httpx_client_factory )
185338 try :
186339 async with Client (
187340 transport ,
0 commit comments