Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions src/galaxy/api/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,27 @@ def anonymise_sensitive_params(params, sensitive_params):

return params


def _iter_error_types(error_type):
yield error_type
for subclass in error_type.__subclasses__():
yield from _iter_error_types(subclass)


def build_error(error):
data = error.setdefault("data", None)
code = error.setdefault("code", 0)
message = error.setdefault("message", "")

if isinstance(data, Mapping):
internal_type = data.get("internal_type")
if internal_type:
for error_type in _iter_error_types(JsonRpcError):
if error_type.__name__ == internal_type:
return error_type(message, data)

return JsonRpcError(code, message, data)

class Connection():
def __init__(self, reader, writer, encoder=json.JSONEncoder()):
self._active = True
Expand Down Expand Up @@ -192,6 +213,10 @@ def close(self):
if self._active:
logger.info("Closing JSON-RPC server - not more messages will be read")
self._active = False
for request_id, (future, _) in self._requests_futures.items():
if not future.done():
future.set_exception(Aborted(data={"request_id": request_id}))
self._requests_futures.clear()

async def wait_closed(self):
await self._task_manager.wait()
Expand All @@ -216,7 +241,13 @@ def _handle_input(self, data):
self._handle_response(message)

def _handle_response(self, response):
request_future = self._requests_futures.get(int(response.id))
try:
request_id = int(response.id)
except (TypeError, ValueError):
logger.warning("Received response with invalid request id: %s", response.id)
return

request_future = self._requests_futures.pop(request_id, None)
if request_future is None:
response_type = "response" if response.result is not None else "error"
logger.warning("Received %s for unknown request: %s", response_type, response.id)
Expand All @@ -225,11 +256,7 @@ def _handle_response(self, response):
future, sensitive_params = request_future

if response.error:
error = JsonRpcError(
response.error.setdefault("code", 0),
response.error.setdefault("message", ""),
response.error.setdefault("data", None)
)
error = build_error(response.error)
self._log_error(response, error, sensitive_params)
future.set_exception(error)
return
Expand All @@ -249,7 +276,8 @@ def _handle_notification(self, request):
try:
bound_args = signature.bind(**request.params)
except TypeError:
self._send_error(request.id, InvalidParams())
logger.error("Received notification with invalid params: %s", request.method)
return

if immediate:
callback(*bound_args.args, **bound_args.kwargs)
Expand All @@ -273,6 +301,7 @@ def _handle_request(self, request):
bound_args = signature.bind(**request.params)
except TypeError:
self._send_error(request.id, InvalidParams())
return

if immediate:
response = callback(*bound_args.args, **bound_args.kwargs)
Expand Down
36 changes: 36 additions & 0 deletions tests/test_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from galaxy.api.plugin import Plugin
from galaxy.api.consts import Platform
from galaxy.api.jsonrpc import Connection, InvalidParams, Request
from galaxy.unittest.mock import delayed_return_value_iterable

from tests import create_message, get_messages
Expand Down Expand Up @@ -97,3 +98,38 @@ async def test_tick_after_handshake(plugin, read):
await plugin.run()
await plugin.wait_closed()
plugin.tick.assert_called_with()


@pytest.mark.asyncio
async def test_notification_invalid_params_does_not_crash(plugin, read, write, reader, writer):
called = False

def callback(game_id):
nonlocal called
called = True

connection = Connection(reader, writer)
connection.register_notification("install_game", callback, immediate=True)
connection._handle_notification(Request(method="install_game", params={}, id=None))

assert called is False
assert get_messages(write) == []


@pytest.mark.asyncio
async def test_request_invalid_params_returns_error(plugin, read, write, reader, writer):
async def callback(game_ids):
return game_ids

connection = Connection(reader, writer)
connection.register_method("start_achievements_import", callback, immediate=False)
connection._handle_request(Request(method="start_achievements_import", params={}, id="8"))
await connection.wait_closed()

assert get_messages(write) == [
{
"jsonrpc": "2.0",
"id": "8",
"error": InvalidParams().json()
}
]
32 changes: 28 additions & 4 deletions tests/test_refresh_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from galaxy.api.errors import (
BackendNotAvailable, BackendTimeout, BackendError, InvalidCredentials, NetworkError, AccessDenied, UnknownError
)
from galaxy.api.jsonrpc import JsonRpcError
from galaxy.api.jsonrpc import Aborted


@pytest.mark.asyncio
Expand Down Expand Up @@ -37,6 +37,7 @@ async def test_refresh_credentials_success(plugin, read, write):
]

assert result == refreshed_credentials
assert plugin._connection._requests_futures == {}
await run_task

@pytest.mark.asyncio
Expand All @@ -56,11 +57,10 @@ async def test_refresh_credentials_failure(exception, plugin, read, write):
# 2 loop iterations delay is to force sending response after request has been sent
read.side_effect = [create_message(response), b""]

with pytest.raises(JsonRpcError) as e:
with pytest.raises(exception) as e:
await plugin.refresh_credentials({}, False)

# Go back to comparing error == e.value, after fixing current always raising JsonRpcError when handling a response with an error
assert error.code == e.value.code
assert error == e.value
assert get_messages(write) == [
{
"jsonrpc": "2.0",
Expand All @@ -70,5 +70,29 @@ async def test_refresh_credentials_failure(exception, plugin, read, write):
"id": "1"
}
]
assert plugin._connection._requests_futures == {}

await run_task


@pytest.mark.asyncio
async def test_refresh_credentials_aborted_on_close(plugin, write):
refresh_task = asyncio.create_task(plugin.refresh_credentials({}, False))

await asyncio.sleep(0)
plugin.close()

with pytest.raises(Aborted):
await refresh_task

assert get_messages(write) == [
{
"jsonrpc": "2.0",
"method": "refresh_credentials",
"params": {
},
"id": "1"
}
]
assert plugin._connection._requests_futures == {}
await plugin.wait_closed()