forked from modelcontextprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_exception_handling.py
More file actions
121 lines (99 loc) · 4.2 KB
/
test_exception_handling.py
File metadata and controls
121 lines (99 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""Test exception handling in lowlevel server message processing."""
import logging
from unittest.mock import Mock
import anyio
import pytest
from mcp.server import Server
from mcp.server.lowlevel import NotificationOptions
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.message import SessionMessage
from mcp.types import ServerCapabilities
@pytest.mark.anyio
async def test_handle_message_with_exception_logging(caplog: pytest.LogCaptureFixture) -> None:
    """Test that Exception instances passed to _handle_message are properly logged.

    A ``ValueError`` is handed to ``Server._handle_message`` with
    ``raise_exceptions=False``; the test then checks that exactly one
    ERROR-level record was captured and that its message mentions both the
    generic processing error and the exception text.
    """
    server = Server("test")

    # Create in-memory streams for testing
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)

    # Enter every stream end so all four are closed when the test finishes,
    # including the two ends that are never handed to the session.
    async with server_to_client_send, server_to_client_receive, client_to_server_send, client_to_server_receive:
        # Create a server session
        session = ServerSession(
            read_stream=client_to_server_receive,
            write_stream=server_to_client_send,
            init_options=InitializationOptions(
                server_name="test",
                server_version="1.0.0",
                capabilities=ServerCapabilities(),
            ),
        )

        # Create a test exception
        test_exception = ValueError("Test exception for logging")

        # Test the _handle_message method directly with an Exception
        with caplog.at_level(logging.ERROR):
            await server._handle_message(
                message=test_exception,
                session=session,
                lifespan_context=None,
                raise_exceptions=False,
            )

        # Verify that the exception was logged
        assert len(caplog.records) == 1
        record = caplog.records[0]
        assert record.levelno == logging.ERROR
        assert "Error in message processing" in record.getMessage()
        assert "Test exception for logging" in record.getMessage()
@pytest.mark.anyio
async def test_handle_message_with_exception_raising() -> None:
    """Test that Exception instances are re-raised when raise_exceptions=True.

    A ``ValueError`` is handed to ``Server._handle_message`` and the test
    expects that same exception type and message to propagate to the caller.
    """
    server = Server("test")

    # Create in-memory streams for testing
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)

    # Enter every stream end so all four are closed when the test finishes,
    # including the two ends that are never handed to the session.
    async with server_to_client_send, server_to_client_receive, client_to_server_send, client_to_server_receive:
        # Create a server session
        session = ServerSession(
            read_stream=client_to_server_receive,
            write_stream=server_to_client_send,
            init_options=InitializationOptions(
                server_name="test",
                server_version="1.0.0",
                capabilities=ServerCapabilities(),
            ),
        )

        # Create a test exception
        test_exception = ValueError("Test exception for raising")

        # Test that the exception is re-raised when raise_exceptions=True
        with pytest.raises(ValueError, match="Test exception for raising"):
            await server._handle_message(
                message=test_exception,
                session=session,
                lifespan_context=None,
                raise_exceptions=True,
            )
@pytest.mark.anyio
async def test_handle_message_with_exception_no_raise() -> None:
    """Test that Exception instances are not re-raised when raise_exceptions=False.

    A ``RuntimeError`` is handed to ``Server._handle_message``; the test
    passes as long as the call returns without raising.
    """
    server = Server("test")

    # Create in-memory streams for testing
    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10)
    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10)

    # Enter every stream end so all four are closed when the test finishes,
    # including the two ends that are never handed to the session.
    async with server_to_client_send, server_to_client_receive, client_to_server_send, client_to_server_receive:
        # Create a server session
        session = ServerSession(
            read_stream=client_to_server_receive,
            write_stream=server_to_client_send,
            init_options=InitializationOptions(
                server_name="test",
                server_version="1.0.0",
                capabilities=ServerCapabilities(),
            ),
        )

        # Create a test exception
        test_exception = RuntimeError("Test exception for no raising")

        # Test that the exception is not re-raised when raise_exceptions=False
        # This should not raise an exception
        await server._handle_message(
            message=test_exception,
            session=session,
            lifespan_context=None,
            raise_exceptions=False,
        )
        # If we reach this point, the test passed