Skip to content

Commit e6bd00c

Browse files
author
Henry Lee
committed
feat: allow overriding SSE messages endpoint
1 parent 3d7b311 commit e6bd00c

2 files changed

Lines changed: 113 additions & 3 deletions

File tree

src/mcp/client/sse.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
2727
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]
2828

2929

30+
def _resolve_endpoint_url(sse_url: str, endpoint_data: str, messages_url: str | None = None) -> str:
31+
if messages_url is None:
32+
return urljoin(sse_url, endpoint_data)
33+
34+
endpoint_url = urljoin(sse_url, messages_url)
35+
endpoint_query = urlparse(endpoint_data).query
36+
if endpoint_query:
37+
endpoint_parsed = urlparse(endpoint_url)
38+
query = "&".join(filter(None, [endpoint_parsed.query, endpoint_query]))
39+
endpoint_url = endpoint_parsed._replace(query=query).geturl()
40+
41+
return endpoint_url
42+
43+
3044
@asynccontextmanager
3145
async def sse_client(
3246
url: str,
@@ -36,6 +50,7 @@ async def sse_client(
3650
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
3751
auth: httpx.Auth | None = None,
3852
on_session_created: Callable[[str], None] | None = None,
53+
messages_url: str | None = None,
3954
):
4055
"""Client transport for SSE.
4156
@@ -50,6 +65,9 @@ async def sse_client(
5065
httpx_client_factory: Factory function for creating the HTTPX client.
5166
auth: Optional HTTPX authentication handler.
5267
on_session_created: Optional callback invoked with the session ID when received.
68+
messages_url: Optional message endpoint URL to use instead of deriving it
69+
from the SSE endpoint event. Relative URLs are resolved against `url`,
70+
and any session query parameters from the endpoint event are preserved.
5371
"""
5472
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
5573
async with httpx_client_factory(
@@ -68,7 +86,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
6886
logger.debug(f"Received SSE event: {sse.event}")
6987
match sse.event:
7088
case "endpoint":
71-
endpoint_url = urljoin(url, sse.data)
89+
endpoint_url = _resolve_endpoint_url(url, sse.data, messages_url)
7290
logger.debug(f"Received endpoint URL: {endpoint_url}")
7391

7492
url_parsed = urlparse(url)

tests/shared/test_sse.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import multiprocessing
33
import socket
44
from collections.abc import AsyncGenerator, Generator
5-
from typing import Any
5+
from typing import Any, cast
66
from unittest.mock import AsyncMock, MagicMock, Mock, patch
77
from urllib.parse import urlparse
88

@@ -20,7 +20,7 @@
2020
import mcp.client.sse
2121
from mcp import types
2222
from mcp.client.session import ClientSession
23-
from mcp.client.sse import _extract_session_id_from_endpoint, sse_client
23+
from mcp.client.sse import _extract_session_id_from_endpoint, _resolve_endpoint_url, sse_client
2424
from mcp.server import Server, ServerRequestContext
2525
from mcp.server.sse import SseServerTransport
2626
from mcp.server.transport_security import TransportSecuritySettings
@@ -229,6 +229,44 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non
229229
assert _extract_session_id_from_endpoint(endpoint_url) == expected
230230

231231

232+
@pytest.mark.parametrize(
233+
("sse_url", "endpoint_data", "messages_url", "expected"),
234+
[
235+
(
236+
"https://example.com/api/v1/sse",
237+
"/v1/messages/?session_id=abc123",
238+
None,
239+
"https://example.com/v1/messages/?session_id=abc123",
240+
),
241+
(
242+
"https://example.com/api/v1/sse",
243+
"/v1/messages/?session_id=abc123",
244+
"https://example.com/api/v1/messages/",
245+
"https://example.com/api/v1/messages/?session_id=abc123",
246+
),
247+
(
248+
"https://example.com/api/v1/sse",
249+
"/v1/messages/?session_id=abc123",
250+
"/api/v1/messages/",
251+
"https://example.com/api/v1/messages/?session_id=abc123",
252+
),
253+
(
254+
"https://example.com/api/v1/sse",
255+
"/v1/messages/?session_id=abc123",
256+
"https://example.com/api/v1/messages/?tenant=blue",
257+
"https://example.com/api/v1/messages/?tenant=blue&session_id=abc123",
258+
),
259+
],
260+
)
261+
def test_resolve_endpoint_url_with_messages_url_override(
262+
sse_url: str,
263+
endpoint_data: str,
264+
messages_url: str | None,
265+
expected: str,
266+
) -> None:
267+
assert _resolve_endpoint_url(sse_url, endpoint_data, messages_url) == expected
268+
269+
232270
@pytest.mark.anyio
233271
async def test_sse_client_on_session_created_not_called_when_no_session_id(
234272
server: None, server_url: str, monkeypatch: pytest.MonkeyPatch
@@ -249,6 +287,60 @@ def mock_extract(url: str) -> None:
249287
callback_mock.assert_not_called()
250288

251289

290+
@pytest.mark.anyio
291+
async def test_sse_client_uses_messages_url_override() -> None:
292+
init_result = InitializeResult(
293+
protocol_version="2024-11-05",
294+
capabilities=ServerCapabilities(),
295+
server_info=Implementation(name="test", version="1.0"),
296+
)
297+
response = JSONRPCResponse(
298+
jsonrpc="2.0",
299+
id=0,
300+
result=init_result.model_dump(by_alias=True, exclude_none=True),
301+
)
302+
response_json = response.model_dump_json(by_alias=True, exclude_none=True)
303+
304+
async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]:
305+
yield ServerSentEvent(event="endpoint", data="/v1/messages/?session_id=abc123")
306+
yield ServerSentEvent(event="message", data=response_json)
307+
await anyio.sleep_forever()
308+
309+
mock_event_source = MagicMock()
310+
mock_event_source.aiter_sse.return_value = mock_aiter_sse()
311+
mock_event_source.response = MagicMock()
312+
mock_event_source.response.raise_for_status = MagicMock()
313+
314+
mock_aconnect_sse = MagicMock()
315+
mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source)
316+
mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None)
317+
318+
mock_client = MagicMock()
319+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
320+
mock_client.__aexit__ = AsyncMock(return_value=None)
321+
mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock()))
322+
323+
def mock_httpx_client_factory(
324+
headers: dict[str, str] | None = None,
325+
timeout: httpx.Timeout | None = None,
326+
auth: httpx.Auth | None = None,
327+
) -> httpx.AsyncClient:
328+
_ = (headers, timeout, auth)
329+
return cast(httpx.AsyncClient, mock_client)
330+
331+
with patch("mcp.client.sse.aconnect_sse", return_value=mock_aconnect_sse):
332+
async with sse_client(
333+
"https://example.com/api/v1/sse",
334+
httpx_client_factory=mock_httpx_client_factory,
335+
messages_url="https://example.com/api/v1/messages/",
336+
) as streams:
337+
async with ClientSession(*streams) as session:
338+
await session.initialize()
339+
340+
mock_client.post.assert_awaited()
341+
assert mock_client.post.await_args.args[0] == "https://example.com/api/v1/messages/?session_id=abc123"
342+
343+
252344
@pytest.fixture
253345
async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
254346
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:

0 commit comments

Comments
 (0)