22import multiprocessing
33import socket
44from collections .abc import AsyncGenerator , Generator
5- from typing import Any
5+ from typing import Any , cast
66from unittest .mock import AsyncMock , MagicMock , Mock , patch
77from urllib .parse import urlparse
88
2020import mcp .client .sse
2121from mcp import types
2222from mcp .client .session import ClientSession
23- from mcp .client .sse import _extract_session_id_from_endpoint , sse_client
23+ from mcp .client .sse import _extract_session_id_from_endpoint , _resolve_endpoint_url , sse_client
2424from mcp .server import Server , ServerRequestContext
2525from mcp .server .sse import SseServerTransport
2626from mcp .server .transport_security import TransportSecuritySettings
@@ -229,6 +229,44 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non
229229 assert _extract_session_id_from_endpoint (endpoint_url ) == expected
230230
231231
232+ @pytest .mark .parametrize (
233+ ("sse_url" , "endpoint_data" , "messages_url" , "expected" ),
234+ [
235+ (
236+ "https://example.com/api/v1/sse" ,
237+ "/v1/messages/?session_id=abc123" ,
238+ None ,
239+ "https://example.com/v1/messages/?session_id=abc123" ,
240+ ),
241+ (
242+ "https://example.com/api/v1/sse" ,
243+ "/v1/messages/?session_id=abc123" ,
244+ "https://example.com/api/v1/messages/" ,
245+ "https://example.com/api/v1/messages/?session_id=abc123" ,
246+ ),
247+ (
248+ "https://example.com/api/v1/sse" ,
249+ "/v1/messages/?session_id=abc123" ,
250+ "/api/v1/messages/" ,
251+ "https://example.com/api/v1/messages/?session_id=abc123" ,
252+ ),
253+ (
254+ "https://example.com/api/v1/sse" ,
255+ "/v1/messages/?session_id=abc123" ,
256+ "https://example.com/api/v1/messages/?tenant=blue" ,
257+ "https://example.com/api/v1/messages/?tenant=blue&session_id=abc123" ,
258+ ),
259+ ],
260+ )
261+ def test_resolve_endpoint_url_with_messages_url_override (
262+ sse_url : str ,
263+ endpoint_data : str ,
264+ messages_url : str | None ,
265+ expected : str ,
266+ ) -> None :
267+ assert _resolve_endpoint_url (sse_url , endpoint_data , messages_url ) == expected
268+
269+
232270@pytest .mark .anyio
233271async def test_sse_client_on_session_created_not_called_when_no_session_id (
234272 server : None , server_url : str , monkeypatch : pytest .MonkeyPatch
@@ -249,6 +287,60 @@ def mock_extract(url: str) -> None:
249287 callback_mock .assert_not_called ()
250288
251289
290+ @pytest .mark .anyio
291+ async def test_sse_client_uses_messages_url_override () -> None :
292+ init_result = InitializeResult (
293+ protocol_version = "2024-11-05" ,
294+ capabilities = ServerCapabilities (),
295+ server_info = Implementation (name = "test" , version = "1.0" ),
296+ )
297+ response = JSONRPCResponse (
298+ jsonrpc = "2.0" ,
299+ id = 0 ,
300+ result = init_result .model_dump (by_alias = True , exclude_none = True ),
301+ )
302+ response_json = response .model_dump_json (by_alias = True , exclude_none = True )
303+
304+ async def mock_aiter_sse () -> AsyncGenerator [ServerSentEvent , None ]:
305+ yield ServerSentEvent (event = "endpoint" , data = "/v1/messages/?session_id=abc123" )
306+ yield ServerSentEvent (event = "message" , data = response_json )
307+ await anyio .sleep_forever ()
308+
309+ mock_event_source = MagicMock ()
310+ mock_event_source .aiter_sse .return_value = mock_aiter_sse ()
311+ mock_event_source .response = MagicMock ()
312+ mock_event_source .response .raise_for_status = MagicMock ()
313+
314+ mock_aconnect_sse = MagicMock ()
315+ mock_aconnect_sse .__aenter__ = AsyncMock (return_value = mock_event_source )
316+ mock_aconnect_sse .__aexit__ = AsyncMock (return_value = None )
317+
318+ mock_client = MagicMock ()
319+ mock_client .__aenter__ = AsyncMock (return_value = mock_client )
320+ mock_client .__aexit__ = AsyncMock (return_value = None )
321+ mock_client .post = AsyncMock (return_value = MagicMock (status_code = 200 , raise_for_status = MagicMock ()))
322+
323+ def mock_httpx_client_factory (
324+ headers : dict [str , str ] | None = None ,
325+ timeout : httpx .Timeout | None = None ,
326+ auth : httpx .Auth | None = None ,
327+ ) -> httpx .AsyncClient :
328+ _ = (headers , timeout , auth )
329+ return cast (httpx .AsyncClient , mock_client )
330+
331+ with patch ("mcp.client.sse.aconnect_sse" , return_value = mock_aconnect_sse ):
332+ async with sse_client (
333+ "https://example.com/api/v1/sse" ,
334+ httpx_client_factory = mock_httpx_client_factory ,
335+ messages_url = "https://example.com/api/v1/messages/" ,
336+ ) as streams :
337+ async with ClientSession (* streams ) as session :
338+ await session .initialize ()
339+
340+ mock_client .post .assert_awaited ()
341+ assert mock_client .post .await_args .args [0 ] == "https://example.com/api/v1/messages/?session_id=abc123"
342+
343+
252344@pytest .fixture
253345async def initialized_sse_client_session (server : None , server_url : str ) -> AsyncGenerator [ClientSession , None ]:
254346 async with sse_client (server_url + "/sse" , sse_read_timeout = 0.5 ) as streams :
0 commit comments