diff --git a/streamdeck/manager.py b/streamdeck/manager.py index ac31335..0690078 100644 --- a/streamdeck/manager.py +++ b/streamdeck/manager.py @@ -100,7 +100,7 @@ def run(self) -> None: command_sender.send_action_registration(register_event=self._register_event, plugin_registration_uuid=self._registration_uuid) - for message in client.listen_forever(): + for message in client.listen(): data: EventBase = event_adapter.validate_json(message) if not is_valid_event_name(data.event): diff --git a/streamdeck/websocket.py b/streamdeck/websocket.py index 1ba9a4d..46927dd 100644 --- a/streamdeck/websocket.py +++ b/streamdeck/websocket.py @@ -4,7 +4,7 @@ from logging import getLogger from typing import TYPE_CHECKING -from websockets.exceptions import ConnectionClosed, ConnectionClosedError, ConnectionClosedOK +from websockets.exceptions import ConnectionClosed, ConnectionClosedOK from websockets.sync.client import ClientConnection, connect @@ -12,7 +12,7 @@ from collections.abc import Generator from typing import Any - from typing_extensions import Self + from typing_extensions import Self # noqa: UP035 logger = getLogger("streamdeck.websocket") @@ -38,10 +38,14 @@ def send_event(self, data: dict[str, Any]) -> None: Args: data (dict[str, Any]): The event data to send. """ + if self._client is None: + msg = "WebSocket connection not established yet." + raise ValueError(msg) + data_str = json.dumps(data) self._client.send(message=data_str) - def listen_forever(self) -> Generator[str | bytes, Any, None]: + def listen(self) -> Generator[str | bytes, Any, None]: """Listen for messages from the WebSocket server indefinitely. TODO: implement more concise error-handling. @@ -55,8 +59,23 @@ def listen_forever(self) -> Generator[str | bytes, Any, None]: message: str | bytes = self._client.recv() yield message + except ConnectionClosedOK: + logger.debug("Connection was closed normally, stopping the client.") + + except ConnectionClosed: + logger.exception("Connection was closed with an error.") + except Exception: - logger.exception("Failed to receive messages from websocket server.") + logger.exception("Failed to receive messages from websocket server due to unexpected error.") + + def start(self) -> None: + """Start the connection to the websocket server.""" + self._client = connect(uri=f"ws://localhost:{self._port}") + + def stop(self) -> None: + """Close the WebSocket connection, if open.""" + if self._client is not None: + self._client.close() def __enter__(self) -> Self: """Start the connection to the websocket server. @@ -64,11 +83,10 @@ def __enter__(self) -> Self: Returns: Self: The WebSocketClient instance after connecting to the WebSocket server. """ - self._client = connect(uri=f"ws://localhost:{self._port}") + self.start() return self def __exit__(self, *args, **kwargs) -> None: """Close the WebSocket connection, if open.""" - if self._client is not None: - self._client.close() + self.stop() diff --git a/tests/plugin_manager/conftest.py b/tests/plugin_manager/conftest.py index b85d149..ff140a8 100644 --- a/tests/plugin_manager/conftest.py +++ b/tests/plugin_manager/conftest.py @@ -32,8 +32,8 @@ def plugin_manager(port_number: int, plugin_registration_uuid: str) -> PluginMan def patch_websocket_client(monkeypatch: pytest.MonkeyPatch) -> Mock: """Fixture that uses pytest's MonkeyPatch to mock WebSocketClient for the PluginManager run method. - The mocked WebSocketClient can be given fake event messages to yield when listen_forever() is called: - ```patch_websocket_client.listen_forever.return_value = [fake_event_json1, fake_event_json2, ...]``` + The mocked WebSocketClient can be given fake event messages to yield when listen() is called: + ```patch_websocket_client.listen.return_value = [fake_event_json1, fake_event_json2, ...]``` Args: monkeypatch: pytest's monkeypatch fixture. diff --git a/tests/plugin_manager/test_command_sender_binding.py b/tests/plugin_manager/test_command_sender_binding.py index a388664..f251a08 100644 --- a/tests/plugin_manager/test_command_sender_binding.py +++ b/tests/plugin_manager/test_command_sender_binding.py @@ -64,12 +64,12 @@ def mock_websocket_client_with_fake_events(patch_websocket_client: Mock) -> tupl Returns: tuple: Mocked instance of WebSocketClient, and a list of fake event messages. """ - # Create a list of fake event messages, and convert them to json strings to be passed back by the client.listen_forever() method. + # Create a list of fake event messages, and convert them to json strings to be passed back by the client.listen() method. fake_event_messages: list[events.EventBase] = [ KeyDownEventFactory.build(action="my-fake-action-uuid"), ] - patch_websocket_client.listen_forever.return_value = [event.model_dump_json() for event in fake_event_messages] + patch_websocket_client.listen.return_value = [event.model_dump_json() for event in fake_event_messages] return patch_websocket_client, fake_event_messages diff --git a/tests/plugin_manager/test_plugin_manager.py b/tests/plugin_manager/test_plugin_manager.py index 0062396..edd1444 100644 --- a/tests/plugin_manager/test_plugin_manager.py +++ b/tests/plugin_manager/test_plugin_manager.py @@ -20,9 +20,9 @@ def mock_websocket_client_with_event(patch_websocket_client: Mock) -> tuple[Mock Returns: tuple: Mocked instance of WebSocketClient, and a fake DialRotateEvent. """ - # Create a fake event message, and convert it to a json string to be passed back by the client.listen_forever() method. + # Create a fake event message, and convert it to a json string to be passed back by the client.listen() method. fake_event_message: DialRotate = DialRotateEventFactory.build() - patch_websocket_client.listen_forever.return_value = [fake_event_message.model_dump_json()] + patch_websocket_client.listen.return_value = [fake_event_message.model_dump_json()] return patch_websocket_client, fake_event_message @@ -91,11 +91,11 @@ def test_plugin_manager_process_event( plugin_manager.run() - # First check that the WebSocketClient's listen_forever() method was called. + # First check that the WebSocketClient's listen() method was called. # This has been stubbed to return the fake_event_message's json string. - mock_websocket_client.listen_forever.assert_called_once() + mock_websocket_client.listen.assert_called_once() - # Check that the event_adapter.validate_json method was called with the stub json string returned by listen_forever(). + # Check that the event_adapter.validate_json method was called with the stub json string returned by listen(). spied_event_adapter_validate_json = cast(Mock, event_adapter.validate_json) spied_event_adapter_validate_json.assert_called_once_with(fake_event_message.model_dump_json()) # Check that the validate_json method returns the same event type model as the fake_event_message. diff --git a/tests/test_websocket.py b/tests/test_websocket.py index b74cf87..42befd3 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -47,12 +47,12 @@ def test_send_event_serializes_and_sends(mock_connection: Mock, port_number: int @pytest.mark.usefixtures("patched_connect") -def test_listen_forever_yields_messages(mock_connection: Mock, port_number: int): - """Test that listen_forever yields messages from the WebSocket connection.""" +def test_listen_yields_messages(mock_connection: Mock, port_number: int): + """Test that listen yields messages from the WebSocket connection.""" # Set up the mocked connection to return messages until closing mock_connection.recv.side_effect = ["message1", b"message2", WebSocketException()] with WebSocketClient(port=port_number) as client: - messages = list(client.listen_forever()) + messages = list(client.listen()) - assert messages == ["message1", b"message2"] \ No newline at end of file + assert messages == ["message1", b"message2"]