forked from modelcontextprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsse.py
More file actions
158 lines (138 loc) · 7.57 KB
/
sse.py
File metadata and controls
158 lines (138 loc) · 7.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import logging
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from httpx_sse._exceptions import SSEError
import mcp.types as types
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
def remove_request_params(url: str) -> str:
    """Return ``url`` with its path parameters, query string and fragment removed.

    The result is what gets written to the debug log when connecting, so that
    request parameters are not logged.

    The components are cleared on the parsed URL directly instead of re-joining
    the path onto the original with ``urljoin``. ``urljoin(url, "")`` returns
    ``url`` unchanged, so a URL with an empty path (``http://host?token=x``)
    would keep its query string, and a path beginning with ``//`` would be
    re-interpreted as a network location and replace the host.

    Args:
        url: The URL to sanitise.

    Returns:
        The URL reduced to scheme, network location and path.
    """
    return urlparse(url)._replace(params="", query="", fragment="").geturl()
@asynccontextmanager
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
):
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
auth: Optional HTTPX authentication handler.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx_client_factory(
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
) as client:
async with aconnect_sse(
client,
"GET",
url,
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
async for sse in event_source.aiter_sse():
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.debug(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)
task_status.started(endpoint_url)
case "message":
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
await read_stream_writer.send(exc)
continue
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
except SSEError as sse_exc:
logger.exception("Encountered SSE exception")
raise sse_exc
except Exception as exc:
logger.exception("Error in sse_reader")
try:
await read_stream_writer.send(exc)
except anyio.ClosedResourceError:
# Stream already closed during shutdown, which is expected
logger.debug("Stream closed during error handling - ignoring")
finally:
await read_stream_writer.aclose()
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
logger.debug(f"Sending client message: {session_message}")
response = await client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except Exception:
logger.exception("Error in post_writer")
finally:
await write_stream.aclose()
endpoint_url = await tg.start(sse_reader)
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
tg.start_soon(post_writer, endpoint_url)
try:
yield read_stream, write_stream
finally:
# FIX: Removed manual tg.cancel_scope.cancel() that violated anyio lifecycle
# The original code called tg.cancel_scope.cancel() here, but this caused:
# "RuntimeError: Attempted to exit cancel scope in a different task than it was entered in"
#
# The task group's __aexit__ will properly cancel all child tasks when
# this context manager exits, so no manual cancellation is needed.
pass
finally:
await read_stream_writer.aclose()
await write_stream.aclose()