diff --git a/src/galaxy/api/jsonrpc.py b/src/galaxy/api/jsonrpc.py index 69c781f..cfef38c 100644 --- a/src/galaxy/api/jsonrpc.py +++ b/src/galaxy/api/jsonrpc.py @@ -96,6 +96,27 @@ def anonymise_sensitive_params(params, sensitive_params): return params + +def _iter_error_types(error_type): + yield error_type + for subclass in error_type.__subclasses__(): + yield from _iter_error_types(subclass) + + +def build_error(error): + data = error.setdefault("data", None) + code = error.setdefault("code", 0) + message = error.setdefault("message", "") + + if isinstance(data, Mapping): + internal_type = data.get("internal_type") + if internal_type: + for error_type in _iter_error_types(JsonRpcError): + if error_type.__name__ == internal_type: + return error_type(message, data) + + return JsonRpcError(code, message, data) + class Connection(): def __init__(self, reader, writer, encoder=json.JSONEncoder()): self._active = True @@ -192,6 +213,10 @@ def close(self): if self._active: logger.info("Closing JSON-RPC server - not more messages will be read") self._active = False + for request_id, (future, _) in self._requests_futures.items(): + if not future.done(): + future.set_exception(Aborted(data={"request_id": request_id})) + self._requests_futures.clear() async def wait_closed(self): await self._task_manager.wait() @@ -216,7 +241,13 @@ def _handle_input(self, data): self._handle_response(message) def _handle_response(self, response): - request_future = self._requests_futures.get(int(response.id)) + try: + request_id = int(response.id) + except (TypeError, ValueError): + logger.warning("Received response with invalid request id: %s", response.id) + return + + request_future = self._requests_futures.pop(request_id, None) if request_future is None: response_type = "response" if response.result is not None else "error" logger.warning("Received %s for unknown request: %s", response_type, response.id) @@ -225,11 +256,7 @@ def _handle_response(self, response): future, sensitive_params = request_future if response.error: - error = JsonRpcError( - response.error.setdefault("code", 0), - response.error.setdefault("message", ""), - response.error.setdefault("data", None) - ) + error = build_error(response.error) self._log_error(response, error, sensitive_params) future.set_exception(error) return @@ -249,7 +276,8 @@ def _handle_notification(self, request): try: bound_args = signature.bind(**request.params) except TypeError: - self._send_error(request.id, InvalidParams()) + logger.error("Received notification with invalid params: %s", request.method) + return if immediate: callback(*bound_args.args, **bound_args.kwargs) @@ -273,6 +301,7 @@ def _handle_request(self, request): bound_args = signature.bind(**request.params) except TypeError: self._send_error(request.id, InvalidParams()) + return if immediate: response = callback(*bound_args.args, **bound_args.kwargs) diff --git a/tests/test_internal.py b/tests/test_internal.py index 7547723..caa0a69 100644 --- a/tests/test_internal.py +++ b/tests/test_internal.py @@ -2,6 +2,7 @@ from galaxy.api.plugin import Plugin from galaxy.api.consts import Platform +from galaxy.api.jsonrpc import Connection, InvalidParams, Request from galaxy.unittest.mock import delayed_return_value_iterable from tests import create_message, get_messages @@ -97,3 +98,38 @@ async def test_tick_after_handshake(plugin, read): await plugin.run() await plugin.wait_closed() plugin.tick.assert_called_with() + + +@pytest.mark.asyncio +async def test_notification_invalid_params_does_not_crash(plugin, read, write, reader, writer): + called = False + + def callback(game_id): + nonlocal called + called = True + + connection = Connection(reader, writer) + connection.register_notification("install_game", callback, immediate=True) + connection._handle_notification(Request(method="install_game", params={}, id=None)) + + assert called is False + assert get_messages(write) == [] + + +@pytest.mark.asyncio +async def test_request_invalid_params_returns_error(plugin, read, write, reader, writer): + async def callback(game_ids): + return game_ids + + connection = Connection(reader, writer) + connection.register_method("start_achievements_import", callback, immediate=False) + connection._handle_request(Request(method="start_achievements_import", params={}, id="8")) + await connection.wait_closed() + + assert get_messages(write) == [ + { + "jsonrpc": "2.0", + "id": "8", + "error": InvalidParams().json() + } + ] diff --git a/tests/test_refresh_credentials.py b/tests/test_refresh_credentials.py index b3fd27f..cada1df 100644 --- a/tests/test_refresh_credentials.py +++ b/tests/test_refresh_credentials.py @@ -5,7 +5,7 @@ from galaxy.api.errors import ( BackendNotAvailable, BackendTimeout, BackendError, InvalidCredentials, NetworkError, AccessDenied, UnknownError ) -from galaxy.api.jsonrpc import JsonRpcError +from galaxy.api.jsonrpc import Aborted @pytest.mark.asyncio @@ -37,6 +37,7 @@ async def test_refresh_credentials_success(plugin, read, write): ] assert result == refreshed_credentials + assert plugin._connection._requests_futures == {} await run_task @pytest.mark.asyncio @@ -56,11 +57,10 @@ async def test_refresh_credentials_failure(exception, plugin, read, write): # 2 loop iterations delay is to force sending response after request has been sent read.side_effect = [create_message(response), b""] - with pytest.raises(JsonRpcError) as e: + with pytest.raises(exception) as e: await plugin.refresh_credentials({}, False) - # Go back to comparing error == e.value, after fixing current always raising JsonRpcError when handling a response with an error - assert error.code == e.value.code + assert error == e.value assert get_messages(write) == [ { "jsonrpc": "2.0", @@ -70,5 +70,29 @@ async def test_refresh_credentials_failure(exception, plugin, read, write): "id": "1" } ] + assert plugin._connection._requests_futures == {} await run_task + + +@pytest.mark.asyncio +async def test_refresh_credentials_aborted_on_close(plugin, write): + refresh_task = asyncio.create_task(plugin.refresh_credentials({}, False)) + + await asyncio.sleep(0) + plugin.close() + + with pytest.raises(Aborted): + await refresh_task + + assert get_messages(write) == [ + { + "jsonrpc": "2.0", + "method": "refresh_credentials", + "params": { + }, + "id": "1" + } + ] + assert plugin._connection._requests_futures == {} + await plugin.wait_closed()