forked from modelcontextprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path test_session.py
More file actions
191 lines (153 loc) · 6.34 KB
/
test_session.py
File metadata and controls
191 lines (153 loc) · 6.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
from collections.abc import AsyncGenerator
from datetime import timedelta
import anyio
import pytest
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import McpError
from mcp.shared.memory import (
create_connected_server_and_client_session,
)
from mcp.types import (
EmptyResult,
)
@pytest.fixture
def mcp_server() -> Server:
    """Provide a low-level ``Server`` named "test server" with no handlers registered."""
    server = Server(name="test server")
    return server
@pytest.fixture
async def client_connected_to_server(
    mcp_server: Server,
) -> AsyncGenerator[ClientSession, None]:
    """Yield a ``ClientSession`` connected to the ``mcp_server`` fixture.

    The connection comes from ``create_connected_server_and_client_session``.
    It stays open for the duration of the test and is closed when the fixture
    is finalized.
    """
    connection = create_connected_server_and_client_session(mcp_server)
    async with connection as session:
        yield session
@pytest.mark.anyio
async def test_in_flight_requests_cleared_after_completion(
    client_connected_to_server: ClientSession,
):
    """Verify that _in_flight is empty after all requests complete."""
    session = client_connected_to_server

    # A ping is the simplest request/response round-trip available.
    result = await session.send_ping()
    assert isinstance(result, EmptyResult)

    # Once the response has been received, nothing should still be tracked.
    in_flight_count = len(session._in_flight)
    assert in_flight_count == 0
@pytest.mark.anyio
async def test_request_cancellation():
    """Test that requests can be cancelled while in-flight.

    The client calls ``slow_tool`` from a child task and then cancels the task
    group running it. The test then checks three things:

    * the client-side call raised ``McpError`` mentioning "Request cancelled";
    * the server's cancel-notification handler was invoked;
    * the server-side tool handler observed the cancellation.
    """
    ev_tool_called = anyio.Event()
    ev_tool_cancelled = anyio.Event()
    ev_cancelled = anyio.Event()
    ev_cancel_notified = anyio.Event()

    def make_server() -> Server:
        server = Server(name="TestSessionServer")

        # Register the tool handler
        @server.call_tool()
        async def handle_call_tool(name: str, arguments: dict | None) -> list:
            if name == "slow_tool":
                ev_tool_called.set()
                try:
                    await anyio.sleep(10)  # Long enough to ensure we can cancel
                    return []
                except anyio.get_cancelled_exc_class():
                    ev_tool_cancelled.set()
                    # Cancellation exceptions must always be re-raised.
                    raise
            raise ValueError(f"Unknown tool: {name}")

        @server.cancel_notification()
        async def handle_cancel(requestId: str | int, reason: str | None):
            ev_cancel_notified.set()

        # Register the tool so it shows up in list_tools
        @server.list_tools()
        async def handle_list_tools() -> list[types.Tool]:
            return [
                types.Tool(
                    name="slow_tool",
                    description="A slow tool that takes 10 seconds to complete",
                    inputSchema={},
                )
            ]

        return server

    async def make_request(client_session: ClientSession):
        try:
            await client_session.call_tool("slow_tool")
        except McpError as e:
            # Expected - request was cancelled
            assert "Request cancelled" in str(e)
            ev_cancelled.set()
        else:
            pytest.fail("Request should have been cancelled")

    async with create_connected_server_and_client_session(make_server()) as client_session:
        # Start the request in a separate task so we can cancel it
        async with anyio.create_task_group() as tg:
            tg.start_soon(make_request, client_session)

            # Wait for the request to be in-flight
            with anyio.fail_after(1):  # Timeout after 1 second
                await ev_tool_called.wait()

            # Cancel the task via task group
            tg.cancel_scope.cancel()

        # These checks must run OUTSIDE the task group. Once its cancel scope
        # has been cancelled, any further await inside the `async with` body
        # is cancelled too. The task group absorbs that cancellation, so the
        # checks would be skipped and the test would pass without asserting
        # anything.

        # Client side: the call surfaced the cancellation as McpError.
        with anyio.fail_after(1):
            await ev_cancelled.wait()

        # Check server cancel notification received
        with anyio.fail_after(1):
            await ev_cancel_notified.wait()

        # Server side: the tool handler itself was cancelled.
        with anyio.fail_after(1):
            await ev_tool_cancelled.wait()
@pytest.mark.anyio
async def test_request_cancellation_uncancellable():
    """Test that a call made with ``cancellable=False`` is not cancelled on
    the server when the client's enclosing cancel scope is cancelled."""
    ev_tool_called = anyio.Event()
    ev_tool_complete = anyio.Event()
    ev_cancelled = anyio.Event()

    def make_server() -> Server:
        server = Server(name="TestSessionServer")

        # Register the tool handler
        @server.call_tool()
        async def handle_call_tool(name: str, arguments: dict | None) -> list:
            if name == "slow_tool":
                ev_tool_called.set()
                # Hold the request open until the client has cancelled its
                # scope. fail_after(10) only bounds the wait so that a broken
                # test cannot hang forever.
                with anyio.fail_after(10):
                    await ev_cancelled.wait()
                ev_tool_complete.set()
                return []
            raise ValueError(f"Unknown tool: {name}")

        # Register the tool so it shows up in list_tools
        @server.list_tools()
        async def handle_list_tools() -> list[types.Tool]:
            return [
                types.Tool(
                    name="slow_tool",
                    description="A slow tool that takes 10 seconds to complete",
                    inputSchema={},
                )
            ]

        return server

    async def make_request(client_session: ClientSession):
        try:
            await client_session.call_tool(
                "slow_tool",
                cancellable=False,
                read_timeout_seconds=timedelta(seconds=10),
            )
        except McpError:
            pytest.fail("Request should not have been cancelled")

    async with create_connected_server_and_client_session(make_server()) as client_session:
        # Start the request in a separate task so we can cancel it
        async with anyio.create_task_group() as tg:
            tg.start_soon(make_request, client_session)

            # Wait for the request to be in-flight
            with anyio.fail_after(1):  # Timeout after 1 second
                await ev_tool_called.wait()

            # Cancel the task via task group
            tg.cancel_scope.cancel()
            # No await between the cancel and this set(), so the server
            # handler is released only after the client-side cancellation has
            # been requested.
            ev_cancelled.set()

        # This check must run OUTSIDE the task group. Inside the cancelled
        # task-group body the wait would be cancelled immediately, the task
        # group would absorb that cancellation, and the check would be
        # skipped.
        # Check server completed regardless
        with anyio.fail_after(1):
            await ev_tool_complete.wait()