diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 7b2bc0a13..af7cbcf08 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -203,19 +203,15 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.mark.anyio async def test_sse_client_on_session_created(server: None, server_url: str) -> None: - captured_session_id: str | None = None + captured: list[str] = [] - def on_session_created(session_id: str) -> None: - nonlocal captured_session_id - captured_session_id = session_id - - async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: + async with sse_client(server_url + "/sse", on_session_created=captured.append) as streams: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - assert captured_session_id is not None # pragma: lax no cover - assert len(captured_session_id) > 0 # pragma: lax no cover + # Callback fires when the endpoint event arrives, before sse_client yields. + assert len(captured) == 1 + assert len(captured[0]) > 0 @pytest.mark.parametrize( @@ -248,8 +244,9 @@ def mock_extract(url: str) -> None: async with ClientSession(*streams) as session: result = await session.initialize() assert isinstance(result, InitializeResult) - - callback_mock.assert_not_called() # pragma: lax no cover + # Callback would have fired by now (endpoint event arrives before + # sse_client yields); if it hasn't, it won't. + callback_mock.assert_not_called() @pytest.fixture diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 42b1a3698..61ba4a2e5 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1132,22 +1132,19 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream) as session: + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} # Make a request to confirm session is working tools = await session.list_tools() assert len(tools.tools) == 10 - headers: dict[str, str] = {} # pragma: lax no cover - if captured_session_id: # pragma: lax no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -1196,22 +1193,19 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream) as session: + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None + headers = {MCP_SESSION_ID_HEADER: captured_session_id} # Make a request to confirm session is working tools = await session.list_tools() assert len(tools.tools) == 10 - headers: dict[str, str] = {} # pragma: lax no cover - if captured_session_id: # pragma: lax no cover - headers[MCP_SESSION_ID_HEADER] = captured_session_id - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -1231,7 +1225,6 @@ async def test_streamable_http_client_resumption(event_server: tuple[SimpleEvent # Variables to track the state captured_resumption_token: str | None = None captured_notifications: list[types.ServerNotification] = [] - captured_protocol_version: str | int | None = None first_notification_received = False async def message_handler( # pragma: no branch @@ -1258,15 +1251,20 @@ async def on_resumption_token_update(token: str) -> None: read_stream, write_stream, ): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: + async with ClientSession( # pragma: no branch + read_stream, write_stream, message_handler=message_handler + ) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) assert len(captured_ids) > 0 captured_session_id = captured_ids[0] assert captured_session_id is not None - # Capture the negotiated protocol version - captured_protocol_version = result.protocol_version + # Build phase-2 headers now while both values are in scope + headers: dict[str, Any] = { + MCP_SESSION_ID_HEADER: captured_session_id, + MCP_PROTOCOL_VERSION_HEADER: result.protocol_version, + } # Start the tool that will wait on lock in a task async with anyio.create_task_group() as tg: # pragma: no branch @@ -1291,25 +1289,19 @@ async def run_tool(): while not first_notification_received or not captured_resumption_token: await anyio.sleep(0.1) + # The while loop only exits after first_notification_received=True, + # which is set by message_handler immediately after appending to + # captured_notifications. The server tool is blocked on its lock, + # so nothing else can arrive before we cancel. + assert len(captured_notifications) == 1 + assert isinstance(captured_notifications[0], types.LoggingMessageNotification) + assert captured_notifications[0].params.data == "First notification before lock" + # Reset for phase 2 before cancelling + captured_notifications.clear() + # Kill the client session while tool is waiting on lock tg.cancel_scope.cancel() - # Verify we received exactly one notification (inside ClientSession - # so coverage tracks these on Python 3.11, see PR #1897 for details) - assert len(captured_notifications) == 1 # pragma: lax no cover - assert isinstance(captured_notifications[0], types.LoggingMessageNotification) # pragma: lax no cover - assert captured_notifications[0].params.data == "First notification before lock" # pragma: lax no cover - - # Clear notifications and set up headers for phase 2 (between connections, - # not tracked by coverage on Python 3.11 due to cancel scope + sys.settrace bug) - captured_notifications = [] # pragma: lax no cover - assert captured_session_id is not None # pragma: lax no cover - assert captured_protocol_version is not None # pragma: lax no cover - headers: dict[str, Any] = { # pragma: lax no cover - MCP_SESSION_ID_HEADER: captured_session_id, - MCP_PROTOCOL_VERSION_HEADER: captured_protocol_version, - } - async with create_mcp_http_client(headers=headers) as httpx_client2: async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client2) as ( read_stream, @@ -2092,11 +2084,12 @@ async def on_resumption_token(token: str) -> None: assert isinstance(result.content[0], TextContent) assert "Completed 3 checkpoints" in result.content[0].text - # 4 priming + 3 notifications + 1 response = 8 tokens - assert len(resumption_tokens) == 8, ( # pragma: lax no cover - f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " - f"got {len(resumption_tokens)}: {resumption_tokens}" - ) + # 4 priming + 3 notifications + 1 response = 8 tokens. All tokens are + # captured before send_request returns, so this is safe to check here. + assert len(resumption_tokens) == 8, ( + f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), " + f"got {len(resumption_tokens)}: {resumption_tokens}" + ) @pytest.mark.anyio