|
1 | | -from typing import Any |
| 1 | +from typing import Any, cast |
2 | 2 |
|
3 | 3 | import anyio |
4 | 4 | import pytest |
@@ -221,64 +221,65 @@ async def mock_client(): |
221 | 221 |
|
222 | 222 |
|
223 | 223 | class _ClosedWriteStream: |
224 | | - async def send(self, item): |
| 224 | + async def send(self, item: SessionMessage) -> None: |
225 | 225 | raise anyio.ClosedResourceError |
226 | 226 |
|
227 | 227 |
|
228 | 228 | class _OpenWriteStream: |
229 | 229 | def __init__(self): |
230 | 230 | self.items: list[SessionMessage] = [] |
231 | 231 |
|
232 | | - async def send(self, item): |
| 232 | + async def send(self, item: SessionMessage) -> None: |
233 | 233 | self.items.append(item) |
234 | 234 |
|
235 | 235 |
|
236 | 236 | class _FakeResult: |
237 | 237 | def __init__(self, payload: dict[str, Any]): |
238 | 238 | self._payload = payload |
239 | 239 |
|
240 | | - def model_dump(self, **kwargs): |
| 240 | + def model_dump(self, **kwargs: Any) -> dict[str, Any]: |
241 | 241 | return dict(self._payload) |
242 | 242 |
|
243 | 243 |
|
244 | 244 | class _FakeNotification: |
245 | 245 | def __init__(self, payload: dict[str, Any]): |
246 | 246 | self._payload = payload |
247 | 247 |
|
248 | | - def model_dump(self, **kwargs): |
| 248 | + def model_dump(self, **kwargs: Any) -> dict[str, Any]: |
249 | 249 | return dict(self._payload) |
250 | 250 |
|
251 | 251 |
|
252 | 252 | @pytest.mark.anyio |
253 | 253 | async def test_base_session_send_response_ignores_closed_write_stream(): |
254 | | - session = object.__new__(BaseSession) |
| 254 | + session = cast(Any, object.__new__(BaseSession)) |
255 | 255 | session._write_stream = _ClosedWriteStream() |
256 | 256 |
|
257 | | - await BaseSession._send_response(session, 1, _FakeResult({"ok": True})) |
| 257 | + await cast(Any, BaseSession)._send_response(session, 1, _FakeResult({"ok": True})) |
258 | 258 |
|
259 | 259 |
|
260 | 260 | @pytest.mark.anyio |
261 | 261 | async def test_base_session_send_notification_ignores_closed_write_stream(): |
262 | | - session = object.__new__(BaseSession) |
| 262 | + session = cast(Any, object.__new__(BaseSession)) |
263 | 263 | session._write_stream = _ClosedWriteStream() |
264 | 264 |
|
265 | | - await BaseSession.send_notification( |
| 265 | + await cast(Any, BaseSession).send_notification( |
266 | 266 | session, |
267 | 267 | _FakeNotification({"method": "notifications/progress", "params": {"progress": 1}}), |
268 | 268 | ) |
269 | 269 |
|
270 | 270 |
|
271 | 271 | @pytest.mark.anyio |
272 | 272 | async def test_base_session_send_notification_still_writes_when_open(): |
273 | | - session = object.__new__(BaseSession) |
274 | | - session._write_stream = _OpenWriteStream() |
| 273 | + open_stream = _OpenWriteStream() |
| 274 | + session = cast(Any, object.__new__(BaseSession)) |
| 275 | + session._write_stream = open_stream |
275 | 276 |
|
276 | | - await BaseSession.send_notification( |
| 277 | + await cast(Any, BaseSession).send_notification( |
277 | 278 | session, |
278 | 279 | _FakeNotification({"method": "notifications/progress", "params": {"progress": 1}}), |
279 | 280 | ) |
280 | 281 |
|
281 | | - assert len(session._write_stream.items) == 1 |
| 282 | + assert len(open_stream.items) == 1 |
282 | 283 |
|
283 | 284 |
|
284 | 285 | @pytest.mark.anyio |
|
0 commit comments