Skip to content

Commit cdf6f3b

Browse files
committed
chore: move logic to be owned by v1 instead of device
1 parent 4a23a4b commit cdf6f3b

File tree

3 files changed

+108
-77
lines changed

3 files changed

+108
-77
lines changed

roborock/devices/device.py

Lines changed: 15 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import asyncio
88
import datetime
9-
import json
109
import logging
1110
from abc import ABC
1211
from collections.abc import Callable
@@ -16,12 +15,6 @@
1615
from roborock.data import HomeDataDevice, HomeDataProduct
1716
from roborock.diagnostics import redact_device_data
1817
from roborock.exceptions import RoborockException
19-
from roborock.roborock_message import (
20-
ROBOROCK_DATA_STATUS_PROTOCOL,
21-
RoborockDataProtocol,
22-
RoborockMessage,
23-
RoborockMessageProtocol,
24-
)
2518
from roborock.util import RoborockLoggerAdapter
2619

2720
from .traits import Trait
@@ -81,6 +74,7 @@ def __init__(
8174
self._channel = channel
8275
self._connect_task: asyncio.Task[None] | None = None
8376
self._unsub: Callable[[], None] | None = None
77+
self._v1_unsub: Callable[[], None] | None = None
8478
self._ready_callbacks = CallbackList["RoborockDevice"]()
8579
self._has_connected = False
8680

@@ -202,15 +196,23 @@ async def connect(self) -> None:
202196
"""Connect to the device using the appropriate protocol channel."""
203197
if self._unsub:
204198
raise ValueError("Already connected to the device")
205-
unsub = await self._channel.subscribe(self._on_message)
199+
206200
if self.v1_properties is not None:
207201
try:
202+
# V1 layer subscribes to the channel and handles protocol updates.
203+
# Note: V1Channel only allows one subscription, so the V1 layer
204+
# is the sole subscriber for V1 devices.
205+
self._v1_unsub = await self.v1_properties.subscribe_async(self._channel)
208206
await self.v1_properties.discover_features()
209207
except RoborockException:
210-
unsub()
208+
if self._v1_unsub:
209+
self._v1_unsub()
211210
raise
211+
else:
212+
# Non-V1 devices subscribe directly (no protocol update handling needed)
213+
self._unsub = await self._channel.subscribe(lambda msg: None)
214+
212215
self._logger.info("Connected to device")
213-
self._unsub = unsub
214216

215217
async def close(self) -> None:
216218
"""Close all connections to the device."""
@@ -220,70 +222,13 @@ async def close(self) -> None:
220222
await self._connect_task
221223
except asyncio.CancelledError:
222224
pass
225+
if self._v1_unsub:
226+
self._v1_unsub()
227+
self._v1_unsub = None
223228
if self._unsub:
224229
self._unsub()
225230
self._unsub = None
226231

227-
def _on_message(self, message: RoborockMessage) -> None:
228-
"""Handle incoming messages from the device.
229-
230-
Note: Protocol updates (data points) are only sent via cloud/MQTT, not local connection.
231-
"""
232-
self._logger.debug("Received message from device: %s", message)
233-
if self.v1_properties is None:
234-
# Ensure we are only doing below logic for set-up V1 devices.
235-
return
236-
237-
# Only process messages that can contain protocol updates
238-
# RPC_RESPONSE (102), and GENERAL_RESPONSE (5)
239-
if message.protocol not in {
240-
RoborockMessageProtocol.RPC_RESPONSE,
241-
RoborockMessageProtocol.GENERAL_RESPONSE,
242-
}:
243-
return
244-
245-
if not message.payload:
246-
return
247-
248-
try:
249-
payload = json.loads(message.payload.decode("utf-8"))
250-
dps = payload.get("dps", {})
251-
252-
if not dps:
253-
return
254-
255-
# Process each data point in the message
256-
for data_point_number, data_point in dps.items():
257-
# Skip RPC responses (102) as they're handled by the RPC channel
258-
if data_point_number == "102":
259-
continue
260-
261-
try:
262-
data_protocol = RoborockDataProtocol(int(data_point_number))
263-
self._logger.debug("Got device update for %s: %s", data_protocol.name, data_point)
264-
self._handle_protocol_update(data_protocol, data_point)
265-
except ValueError:
266-
# Unknown protocol number
267-
self._logger.debug(
268-
f"Got unknown data protocol {data_point_number}, data: {data_point}. "
269-
f"This may allow for faster updates in the future."
270-
)
271-
except (json.JSONDecodeError, UnicodeDecodeError, KeyError) as ex:
272-
self._logger.debug("Failed to parse protocol message: %s", ex)
273-
274-
def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> None:
275-
"""Handle a protocol update for a specific data protocol.
276-
277-
Args:
278-
protocol: The data protocol number.
279-
data_point: The data value for this protocol.
280-
"""
281-
# Handle status protocol updates
282-
if protocol in ROBOROCK_DATA_STATUS_PROTOCOL and self.v1_properties and self.v1_properties.status:
283-
if self.v1_properties.status.handle_protocol_update(protocol, data_point):
284-
self._logger.debug("Updated status.%s to %s", protocol.name.lower(), data_point)
285-
self.v1_properties.status.notify_update()
286-
287232
def diagnostic_data(self) -> dict[str, Any]:
288233
"""Return diagnostics information about the device."""
289234
extra: dict[str, Any] = {}

roborock/devices/traits/v1/__init__.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@
5252
code in HomeDataProduct Schema that is required for the field to be supported.
5353
"""
5454

55+
import json
5556
import logging
57+
from collections.abc import Callable
5658
from dataclasses import dataclass, field, fields
5759
from functools import cache
5860
from typing import Any, get_args
@@ -61,8 +63,15 @@
6163
from roborock.data.v1.v1_code_mappings import RoborockDockTypeCode
6264
from roborock.devices.cache import DeviceCache
6365
from roborock.devices.traits import Trait
66+
from roborock.devices.transport.channel import Channel
6467
from roborock.map.map_parser import MapParserConfig
6568
from roborock.protocols.v1_protocol import V1RpcChannel
69+
from roborock.roborock_message import (
70+
ROBOROCK_DATA_STATUS_PROTOCOL,
71+
RoborockDataProtocol,
72+
RoborockMessage,
73+
RoborockMessageProtocol,
74+
)
6675
from roborock.web_api import UserWebApiClient
6776

6877
from . import (
@@ -313,6 +322,79 @@ def as_dict(self) -> dict[str, Any]:
313322
result[item.name] = data
314323
return result
315324

325+
async def subscribe_async(self, channel: Channel) -> Callable[[], None]:
326+
"""Subscribe to protocol updates from the channel.
327+
328+
This handles MQTT protocol updates for V1 devices, routing data point
329+
updates to the appropriate traits.
330+
331+
Args:
332+
channel: The channel to subscribe to for updates.
333+
334+
Returns:
335+
A callable that can be used to unsubscribe from updates.
336+
"""
337+
338+
def on_message(message: RoborockMessage) -> None:
339+
self._handle_message(message)
340+
341+
return await channel.subscribe(on_message)
342+
343+
def _handle_message(self, message: RoborockMessage) -> None:
344+
"""Handle incoming messages from the device.
345+
346+
Parses protocol updates and routes them to the appropriate traits.
347+
"""
348+
# Only process messages that can contain protocol updates
349+
# RPC_RESPONSE (102), and GENERAL_RESPONSE (5)
350+
if message.protocol not in {
351+
RoborockMessageProtocol.RPC_RESPONSE,
352+
RoborockMessageProtocol.GENERAL_RESPONSE,
353+
}:
354+
return
355+
356+
if not message.payload:
357+
return
358+
359+
try:
360+
payload = json.loads(message.payload.decode("utf-8"))
361+
dps = payload.get("dps", {})
362+
363+
if not dps:
364+
return
365+
366+
# Process each data point in the message
367+
for data_point_number, data_point in dps.items():
368+
# Skip RPC responses (102) as they're handled by the RPC channel
369+
if data_point_number == "102":
370+
continue
371+
372+
try:
373+
data_protocol = RoborockDataProtocol(int(data_point_number))
374+
_LOGGER.debug("Got device update for %s: %s", data_protocol.name, data_point)
375+
self._handle_protocol_update(data_protocol, data_point)
376+
except ValueError:
377+
# Unknown protocol number
378+
_LOGGER.debug(
379+
f"Got unknown data protocol {data_point_number}, data: {data_point}. "
380+
f"This may allow for faster updates in the future."
381+
)
382+
except (json.JSONDecodeError, UnicodeDecodeError, KeyError) as ex:
383+
_LOGGER.debug("Failed to parse protocol message: %s", ex)
384+
385+
def _handle_protocol_update(self, protocol: RoborockDataProtocol, data_point: Any) -> None:
386+
"""Handle a protocol update for a specific data protocol.
387+
388+
Args:
389+
protocol: The data protocol number.
390+
data_point: The data value for this protocol.
391+
"""
392+
# Handle status protocol updates
393+
if protocol in ROBOROCK_DATA_STATUS_PROTOCOL and self.status:
394+
if self.status.handle_protocol_update(protocol, data_point):
395+
_LOGGER.debug("Updated status.%s to %s", protocol.name.lower(), data_point)
396+
self.status.notify_update()
397+
316398

317399
def create(
318400
device_uid: str,

roborock/devices/traits/v1/common.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,14 @@
1111
from dataclasses import dataclass, fields
1212
from typing import ClassVar, Self
1313

14-
from roborock.callbacks import CallbackList
1514
from roborock.data import RoborockBase
1615
from roborock.protocols.v1_protocol import V1RpcChannel
1716
from roborock.roborock_typing import RoborockCommand
1817

1918
_LOGGER = logging.getLogger(__name__)
2019

2120
V1ResponseData = dict | list | int | str
22-
V1TraitUpdateCallback = Callable[["V1TraitMixin"], None]
21+
V1TraitUpdateCallback = Callable[[], None]
2322

2423

2524
@dataclass
@@ -79,7 +78,7 @@ def __post_init__(self) -> None:
7978
device setup code.
8079
"""
8180
self._rpc_channel = None
82-
self._update_callbacks: CallbackList[V1TraitMixin] = CallbackList()
81+
self._update_callbacks: list[V1TraitUpdateCallback] = []
8382

8483
@property
8584
def rpc_channel(self) -> V1RpcChannel:
@@ -106,17 +105,22 @@ def _update_trait_values(self, new_data: RoborockBase) -> None:
106105
def add_update_callback(self, callback: V1TraitUpdateCallback) -> Callable[[], None]:
107106
"""Add a callback to be notified when the trait is updated.
108107
109-
The callback will be called with the updated trait instance whenever
110-
a protocol message updates the trait.
108+
The callback will be called whenever a protocol message updates the trait.
109+
Callers should track which trait they subscribed to if needed.
111110
112111
Returns:
113112
A callable that can be used to remove the callback.
114113
"""
115-
return self._update_callbacks.add_callback(callback)
114+
self._update_callbacks.append(callback)
115+
return lambda: self._update_callbacks.remove(callback)
116116

117117
def notify_update(self) -> None:
118118
"""Notify all registered callbacks that the trait has been updated."""
119-
self._update_callbacks(self)
119+
for callback in self._update_callbacks:
120+
try:
121+
callback()
122+
except Exception: # noqa: BLE001
123+
_LOGGER.exception("Error in trait update callback")
120124

121125

122126
def _get_value_field(clazz: type[V1TraitMixin]) -> str:

0 commit comments

Comments
 (0)