Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,19 +203,15 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non

@pytest.mark.anyio
async def test_sse_client_on_session_created(server: None, server_url: str) -> None:
captured_session_id: str | None = None
captured: list[str] = []

def on_session_created(session_id: str) -> None:
nonlocal captured_session_id
captured_session_id = session_id

async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams:
async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams:
async with ClientSession(*streams) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)

assert captured_session_id is not None # pragma: lax no cover
assert len(captured_session_id) > 0 # pragma: lax no cover
# Callback fires when the endpoint event arrives, before sse_client yields.
assert len(captured) == 1
assert len(captured[0]) > 0


@pytest.mark.parametrize(
Expand Down Expand Up @@ -248,8 +244,9 @@ def mock_extract(url: str) -> None:
async with ClientSession(*streams) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)

callback_mock.assert_not_called() # pragma: lax no cover
# Callback would have fired by now (endpoint event arrives before
# sse_client yields); if it hasn't, it won't.
callback_mock.assert_not_called()


@pytest.fixture
Expand Down
63 changes: 28 additions & 35 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,22 +1132,19 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert len(captured_ids) > 0
captured_session_id = captured_ids[0]
assert captured_session_id is not None
headers = {MCP_SESSION_ID_HEADER: captured_session_id}

# Make a request to confirm session is working
tools = await session.list_tools()
assert len(tools.tools) == 10

headers: dict[str, str] = {} # pragma: lax no cover
if captured_session_id: # pragma: lax no cover
headers[MCP_SESSION_ID_HEADER] = captured_session_id

async with create_mcp_http_client(headers=headers) as httpx_client2:
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as (
read_stream,
Expand Down Expand Up @@ -1196,22 +1193,19 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session:
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert len(captured_ids) > 0
captured_session_id = captured_ids[0]
assert captured_session_id is not None
headers = {MCP_SESSION_ID_HEADER: captured_session_id}

# Make a request to confirm session is working
tools = await session.list_tools()
assert len(tools.tools) == 10

headers: dict[str, str] = {} # pragma: lax no cover
if captured_session_id: # pragma: lax no cover
headers[MCP_SESSION_ID_HEADER] = captured_session_id

async with create_mcp_http_client(headers=headers) as httpx_client2:
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as (
read_stream,
Expand All @@ -1231,7 +1225,6 @@ async def test_streamable_http_client_resumption(event_server: tuple[SimpleEvent
# Variables to track the state
captured_resumption_token: str | None = None
captured_notifications: list[types.ServerNotification] = []
captured_protocol_version: str | int | None = None
first_notification_received = False

async def message_handler( # pragma: no branch
Expand All @@ -1258,15 +1251,20 @@ async def on_resumption_token_update(token: str) -> None:
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
async with ClientSession( # pragma: no branch
read_stream, write_stream, message_handler=message_handler
) as session:
# Initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert len(captured_ids) > 0
captured_session_id = captured_ids[0]
assert captured_session_id is not None
# Capture the negotiated protocol version
captured_protocol_version = result.protocol_version
# Build phase-2 headers now while both values are in scope
headers: dict[str, Any] = {
MCP_SESSION_ID_HEADER: captured_session_id,
MCP_PROTOCOL_VERSION_HEADER: result.protocol_version,
}

# Start the tool that will wait on lock in a task
async with anyio.create_task_group() as tg: # pragma: no branch
Expand All @@ -1291,25 +1289,19 @@ async def run_tool():
while not first_notification_received or not captured_resumption_token:
await anyio.sleep(0.1)

# The while loop only exits after first_notification_received=True,
# which is set by message_handler immediately after appending to
# captured_notifications. The server tool is blocked on its lock,
# so nothing else can arrive before we cancel.
assert len(captured_notifications) == 1
assert isinstance(captured_notifications[0], types.LoggingMessageNotification)
assert captured_notifications[0].params.data == "First notification before lock"
# Reset for phase 2 before cancelling
captured_notifications.clear()

# Kill the client session while tool is waiting on lock
tg.cancel_scope.cancel()

# Verify we received exactly one notification (inside ClientSession
# so coverage tracks these on Python 3.11, see PR #1897 for details)
assert len(captured_notifications) == 1 # pragma: lax no cover
assert isinstance(captured_notifications[0], types.LoggingMessageNotification) # pragma: lax no cover
assert captured_notifications[0].params.data == "First notification before lock" # pragma: lax no cover

# Clear notifications and set up headers for phase 2 (between connections,
# not tracked by coverage on Python 3.11 due to cancel scope + sys.settrace bug)
captured_notifications = [] # pragma: lax no cover
assert captured_session_id is not None # pragma: lax no cover
assert captured_protocol_version is not None # pragma: lax no cover
headers: dict[str, Any] = { # pragma: lax no cover
MCP_SESSION_ID_HEADER: captured_session_id,
MCP_PROTOCOL_VERSION_HEADER: captured_protocol_version,
}

async with create_mcp_http_client(headers=headers) as httpx_client2:
async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client2) as (
read_stream,
Expand Down Expand Up @@ -2092,11 +2084,12 @@ async def on_resumption_token(token: str) -> None:
assert isinstance(result.content[0], TextContent)
assert "Completed 3 checkpoints" in result.content[0].text

# 4 priming + 3 notifications + 1 response = 8 tokens
assert len(resumption_tokens) == 8, ( # pragma: lax no cover
f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), "
f"got {len(resumption_tokens)}: {resumption_tokens}"
)
# 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are
# captured before send_request returns, so this is safe to check here.
assert len(resumption_tokens) == 8, (
f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), "
f"got {len(resumption_tokens)}: {resumption_tokens}"
)


@pytest.mark.anyio
Expand Down