Skip to content
49 changes: 15 additions & 34 deletions roborock/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from pyshark.packet.packet import Packet # type: ignore

from roborock import RoborockException
from roborock.containers import DeviceData, HomeDataProduct, LoginData
from roborock.mqtt.roborock_session import create_mqtt_session
from roborock.protocol import MessageParser, create_mqtt_params
from roborock.containers import DeviceData, HomeData, HomeDataProduct, LoginData
from roborock.devices.device_manager import create_device_manager, create_home_data_api
from roborock.protocol import MessageParser
from roborock.util import run_sync
from roborock.version_1_apis.roborock_local_client_v1 import RoborockLocalClientV1
from roborock.version_1_apis.roborock_mqtt_client_v1 import RoborockMqttClientV1
Expand Down Expand Up @@ -101,44 +101,25 @@ async def session(ctx, duration: int):
context: RoborockContext = ctx.obj
login_data = context.login_data()

# Discovery devices if not already available
if not login_data.home_data:
await _discover(ctx)
login_data = context.login_data()
if not login_data.home_data or not login_data.home_data.devices:
raise RoborockException("Unable to discover devices")

all_devices = login_data.home_data.devices + login_data.home_data.received_devices
click.echo(f"Discovered devices: {', '.join([device.name for device in all_devices])}")

rriot = login_data.user_data.rriot
params = create_mqtt_params(rriot)

mqtt_session = await create_mqtt_session(params)
click.echo("Starting MQTT session...")
if not mqtt_session.connected:
raise RoborockException("Failed to connect to MQTT broker")
home_data_api = create_home_data_api(login_data.email, login_data.user_data)

def on_message(bytes: bytes):
"""Callback function to handle incoming MQTT messages."""
# Decode the first 20 bytes of the message for display
bytes = bytes[:20]
async def home_data_cache() -> HomeData:
if login_data.home_data is None:
login_data.home_data = await home_data_api()
context.update(login_data)
return login_data.home_data

click.echo(f"Received message: {bytes}...")
# Create device manager
device_manager = await create_device_manager(login_data.user_data, home_data_cache)

unsubs = []
for device in all_devices:
device_topic = f"rr/m/o/{rriot.u}/{params.username}/{device.duid}"
unsub = await mqtt_session.subscribe(device_topic, on_message)
unsubs.append(unsub)
devices = await device_manager.get_devices()
click.echo(f"Discovered devices: {', '.join([device.name for device in devices])}")

click.echo("MQTT session started. Listening for messages...")
await asyncio.sleep(duration)

click.echo("Stopping MQTT session...")
for unsub in unsubs:
unsub()
await mqtt_session.close()
# Close the device manager (this will close all devices and MQTT session)
await device_manager.close()


async def _discover(ctx):
Expand Down
46 changes: 44 additions & 2 deletions roborock/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@

import enum
import logging
from collections.abc import Callable
from functools import cached_property

from roborock.containers import HomeDataDevice, HomeDataProduct, UserData

from .mqtt_channel import MqttChannel

_LOGGER = logging.getLogger(__name__)

__all__ = [
Expand All @@ -29,11 +32,25 @@ class DeviceVersion(enum.StrEnum):
class RoborockDevice:
"""Unified Roborock device class with automatic connection setup."""

def __init__(self, user_data: UserData, device_info: HomeDataDevice, product_info: HomeDataProduct) -> None:
"""Initialize the RoborockDevice with device info, user data, and capabilities."""
def __init__(
self,
user_data: UserData,
device_info: HomeDataDevice,
product_info: HomeDataProduct,
mqtt_channel: MqttChannel,
) -> None:
"""Initialize the RoborockDevice.

The device takes ownership of the MQTT channel for communication with the device.
Use `connect()` to establish the connection, which will set up the MQTT channel
for receiving messages from the device. Use `close()` to unsubscribe from the MQTT
channel.
"""
self._user_data = user_data
self._device_info = device_info
self._product_info = product_info
self._mqtt_channel = mqtt_channel
self._unsub: Callable[[], None] | None = None

@property
def duid(self) -> str:
Expand Down Expand Up @@ -63,3 +80,28 @@ def device_version(self) -> str:
self._device_info.name,
)
return DeviceVersion.UNKNOWN

async def connect(self) -> None:
"""Connect to the device using MQTT.

This method will set up the MQTT channel for communication with the device.
"""
if self._unsub:
raise ValueError("Already connected to the device")
self._unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)

async def close(self) -> None:
"""Close the MQTT connection to the device.

This method will unsubscribe from the MQTT channel and clean up resources.
"""
if self._unsub:
self._unsub()
self._unsub = None

def _on_mqtt_message(self, message: bytes) -> None:
"""Handle incoming MQTT messages from the device.

This method should be overridden in subclasses to handle specific device messages.
"""
_LOGGER.debug("Received message from device %s: %s", self.duid, message[:50]) # Log first 50 bytes for brevity
44 changes: 37 additions & 7 deletions roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@
UserData,
)
from roborock.devices.device import RoborockDevice
from roborock.mqtt.roborock_session import create_mqtt_session
from roborock.mqtt.session import MqttSession
from roborock.protocol import create_mqtt_params
from roborock.web_api import RoborockApiClient

from .mqtt_channel import MqttChannel

_LOGGER = logging.getLogger(__name__)

__all__ = [
Expand All @@ -34,21 +39,32 @@ def __init__(
self,
home_data_api: HomeDataApi,
device_creator: DeviceCreator,
mqtt_session: MqttSession,
) -> None:
"""Initialize the DeviceManager with user data and optional cache storage."""
"""Initialize the DeviceManager with user data and optional cache storage.

This takes ownership of the MQTT session and will close it when the manager is closed.
"""
self._home_data_api = home_data_api
self._device_creator = device_creator
self._devices: dict[str, RoborockDevice] = {}
self._mqtt_session = mqtt_session

async def discover_devices(self) -> list[RoborockDevice]:
"""Discover all devices for the logged-in user."""
home_data = await self._home_data_api()
device_products = home_data.device_products
_LOGGER.debug("Discovered %d devices %s", len(device_products), home_data)

self._devices = {
duid: self._device_creator(device, product) for duid, (device, product) in device_products.items()
}
new_devices = {}
for duid, (device, product) in device_products.items():
if duid in self._devices:
continue
new_device = self._device_creator(device, product)
await new_device.connect()
Comment thread
allenporter marked this conversation as resolved.
new_devices[duid] = new_device

self._devices.update(new_devices)
return list(self._devices.values())

async def get_device(self, duid: str) -> RoborockDevice | None:
Expand All @@ -59,6 +75,14 @@ async def get_devices(self) -> list[RoborockDevice]:
"""Get all discovered devices."""
return list(self._devices.values())

async def close(self) -> None:
"""Close all MQTT connections and clean up resources."""
for device in self._devices.values():
await device.close()
Comment thread
allenporter marked this conversation as resolved.
Outdated
self._devices.clear()
if self._mqtt_session:
await self._mqtt_session.close()


def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi:
"""Create a home data API wrapper.
Expand All @@ -67,7 +91,9 @@ def create_home_data_api(email: str, user_data: UserData) -> HomeDataApi:
home data for the user.
"""

client = RoborockApiClient(email, user_data)
# Note: This will auto discover the API base URL. This can be improved
# by caching this next to `UserData` if needed to avoid unnecessary API calls.
client = RoborockApiClient(email)

async def home_data_api() -> HomeData:
return await client.get_home_data(user_data)
Expand All @@ -83,9 +109,13 @@ async def create_device_manager(user_data: UserData, home_data_api: HomeDataApi)
include caching or other optimizations.
"""

mqtt_params = create_mqtt_params(user_data.rriot)
mqtt_session = await create_mqtt_session(mqtt_params)

def device_creator(device: HomeDataDevice, product: HomeDataProduct) -> RoborockDevice:
return RoborockDevice(user_data, device, product)
mqtt_channel = MqttChannel(mqtt_session, device.duid, user_data.rriot, mqtt_params)
return RoborockDevice(user_data, device, product, mqtt_channel)

manager = DeviceManager(home_data_api, device_creator)
manager = DeviceManager(home_data_api, device_creator, mqtt_session=mqtt_session)
await manager.discover_devices()
return manager
44 changes: 44 additions & 0 deletions roborock/devices/mqtt_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging
from collections.abc import Callable

from roborock.containers import RRiot
from roborock.mqtt.session import MqttParams, MqttSession

_LOGGER = logging.getLogger(__name__)


class MqttChannel:
"""RPC-style channel for communicating with a specific device over MQTT.

This currently only supports listening to messages and does not yet
support RPC functionality.
"""

def __init__(self, mqtt_session: MqttSession, duid: str, rriot: RRiot, mqtt_params: MqttParams):
self._mqtt_session = mqtt_session
self._duid = duid
self._rriot = rriot
self._mqtt_params = mqtt_params
self._unsub: Callable[[], None] | None = None
Comment thread
allenporter marked this conversation as resolved.
Outdated

@property
def _publish_topic(self) -> str:
"""Topic to send commands to the device."""
return f"rr/m/i/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"

@property
def _subscribe_topic(self) -> str:
"""Topic to receive responses from the device."""
return f"rr/m/o/{self._rriot.u}/{self._mqtt_params.username}/{self._duid}"

async def subscribe(self, callback: Callable[[bytes], None]) -> Callable[[], None]:
"""Subscribe to the device's response topic.

The callback will be called with the message payload when a message is received.
If already subscribed, raises ValueError.
Comment thread
allenporter marked this conversation as resolved.
Outdated

Returns a callable that can be used to unsubscribe from the topic.
"""
if self._unsub:
raise ValueError("Already subscribed to the response topic")
Comment thread
allenporter marked this conversation as resolved.
Outdated
return await self._mqtt_session.subscribe(self._subscribe_topic, callback)
Comment thread
allenporter marked this conversation as resolved.
Outdated
40 changes: 40 additions & 0 deletions tests/devices/test_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Tests for the Device class."""

from unittest.mock import AsyncMock, Mock

from roborock.containers import HomeData, UserData
from roborock.devices.device import DeviceVersion, RoborockDevice

from .. import mock_data

USER_DATA = UserData.from_dict(mock_data.USER_DATA)
HOME_DATA = HomeData.from_dict(mock_data.HOME_DATA_RAW)


async def test_device_connection() -> None:
"""Test the Device connection setup."""

unsub = Mock()
subscribe = AsyncMock()
subscribe.return_value = unsub
mqtt_channel = AsyncMock()
mqtt_channel.subscribe = subscribe

device = RoborockDevice(
USER_DATA,
device_info=HOME_DATA.devices[0],
product_info=HOME_DATA.products[0],
mqtt_channel=mqtt_channel,
)
assert device.duid == "abc123"
assert device.name == "Roborock S7 MaxV"
assert device.device_version == DeviceVersion.V1

assert not subscribe.called

await device.connect()
assert subscribe.called
assert not unsub.called

await device.close()
assert unsub.called
11 changes: 11 additions & 0 deletions tests/devices/test_device_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for the DeviceManager class."""

from collections.abc import Generator
from unittest.mock import patch

import pytest
Expand All @@ -14,6 +15,13 @@
USER_DATA = UserData.from_dict(mock_data.USER_DATA)


@pytest.fixture(autouse=True)
def setup_mqtt_session() -> Generator[None, None, None]:
"""Fixture to set up the MQTT session for the tests."""
with patch("roborock.devices.device_manager.create_mqtt_session"):
yield


async def home_home_data_no_devices() -> HomeData:
"""Mock home data API that returns no devices."""
return HomeData(
Expand Down Expand Up @@ -52,12 +60,15 @@ async def test_with_device() -> None:
assert device.name == "Roborock S7 MaxV"
assert device.device_version == DeviceVersion.V1

await device_manager.close()


async def test_get_non_existent_device() -> None:
"""Test getting a non-existent device."""
device_manager = await create_device_manager(USER_DATA, mock_home_data)
device = await device_manager.get_device("non_existent_duid")
assert device is None
await device_manager.close()


async def test_home_data_api_exception() -> None:
Expand Down
Loading