diff --git a/roborock/devices/traits/v1/__init__.py b/roborock/devices/traits/v1/__init__.py index 1ccf3c42..cb17e352 100644 --- a/roborock/devices/traits/v1/__init__.py +++ b/roborock/devices/traits/v1/__init__.py @@ -199,7 +199,7 @@ def __init__( self.device_features = DeviceFeaturesTrait(product, self._device_cache) self.status = StatusTrait(self.device_features, region=self._region) self.consumables = ConsumableTrait() - self.rooms = RoomsTrait(home_data, web_api) + self.rooms = RoomsTrait(home_data, device_uid, web_api) self.maps = MapsTrait(self.status) self.map_content = MapContentTrait(map_parser_config) self.home = HomeTrait(self.status, self.maps, self.map_content, self.rooms, self._device_cache) diff --git a/roborock/devices/traits/v1/rooms.py b/roborock/devices/traits/v1/rooms.py index c3cf2562..0b838c74 100644 --- a/roborock/devices/traits/v1/rooms.py +++ b/roborock/devices/traits/v1/rooms.py @@ -84,12 +84,21 @@ class RoomsTrait(Rooms, common.V1TraitMixin): command = RoborockCommand.GET_ROOM_MAPPING converter = RoomsConverter() - def __init__(self, home_data: HomeData, web_api: UserWebApiClient) -> None: + def __init__(self, home_data: HomeData, device_uid: str, web_api: UserWebApiClient) -> None: """Initialize the RoomsTrait.""" super().__init__() self._home_data = home_data + self._device_uid = device_uid + self._shared_device_uid = next( + (device.duid for device in home_data.received_devices if device.duid == device_uid), None + ) self._web_api = web_api self._discovered_iot_ids: set[str] = set() + self._room_names: dict[str, str] = dict(home_data.rooms_name_map) + + @property + def _room_name_map(self) -> dict[str, str]: + return self._room_names async def refresh(self) -> None: """Refresh room mappings and backfill unknown room names from the web API.""" @@ -104,12 +113,14 @@ async def refresh(self) -> None: segment_map = RoomsConverter.extract_segment_map(response) # Track all iot ids seen before. Refresh the room list when new ids are found. - new_iot_ids = set(segment_map.values()) - set(self._home_data.rooms_map.keys()) + new_iot_ids = set(segment_map.values()) - set(self._room_name_map.keys()) if new_iot_ids - self._discovered_iot_ids: _LOGGER.debug("Refreshing room list to discover new room names") if updated_rooms := await self._refresh_rooms(): _LOGGER.debug("Updating rooms: %s", list(updated_rooms)) - self._home_data.rooms = updated_rooms + self._room_names = {room.iot_id: room.name for room in updated_rooms} + if self._shared_device_uid is None: + self._home_data.rooms = updated_rooms self._discovered_iot_ids.update(new_iot_ids) try: rooms = self.converter.convert(response) @@ -121,12 +132,17 @@ async def refresh(self) -> None: inner_error=err, ) from err - rooms = rooms.with_room_names(self._home_data.rooms_name_map) + rooms = rooms.with_room_names(self._room_name_map) common.merge_trait_values(self, rooms) async def _refresh_rooms(self) -> list[HomeDataRoom]: """Fetch the latest rooms from the web API.""" try: + if self._shared_device_uid is not None: + rooms_by_id = {room.iot_id: room for room in self._home_data.rooms} + shared_rooms = await self._web_api.get_shared_device_rooms(self._shared_device_uid) + rooms_by_id.update({room.iot_id: room for room in shared_rooms}) + return list(rooms_by_id.values()) return await self._web_api.get_rooms() except Exception: _LOGGER.debug("Failed to fetch rooms from web API", exc_info=True) diff --git a/roborock/web_api.py b/roborock/web_api.py index a76d14c5..4700c601 100644 --- a/roborock/web_api.py +++ b/roborock/web_api.py @@ -557,6 +557,33 @@ async def get_rooms(self, user_data: UserData, home_id: int | None = None) -> li else: raise RoborockException("home_response result was an unexpected type") + async def get_shared_device_rooms(self, user_data: UserData, device_id: str) -> list[HomeDataRoom]: + """Fetch room names for a shared (received) device.""" + rriot = user_data.rriot + if rriot is None: + raise RoborockException("rriot is none") + if rriot.r.a is None: + raise RoborockException("Missing field 'a' in rriot reference") + path = f"/user/deviceshare/query/{device_id}/rooms" + room_request = PreparedRequest( + rriot.r.a, + self.session, + {"Authorization": _get_hawk_authentication(rriot, path)}, + ) + room_response = await room_request.request("get", path) + if not room_response.get("success"): + raise RoborockException(room_response) + rooms = room_response.get("result") + if isinstance(rooms, list): + output_list = [] + for room in rooms: + normalized_room = room + if isinstance(room, dict) and "id" not in room and "roomId" in room: + normalized_room = {**room, "id": room["roomId"]} + output_list.append(HomeDataRoom.from_dict(normalized_room)) + return output_list + raise RoborockException("get_shared_device_rooms result was an unexpected type") + async def get_scenes(self, user_data: UserData, device_id: str) -> list[HomeDataScene]: rriot = user_data.rriot if rriot is None: @@ -754,6 +781,10 @@ async def get_rooms(self) -> list[HomeDataRoom]: """Fetch rooms using the API client.""" return await self._web_api.get_rooms(self._user_data) + async def get_shared_device_rooms(self, device_id: str) -> list[HomeDataRoom]: + """Fetch shared-device rooms using the API client.""" + return await self._web_api.get_shared_device_rooms(self._user_data, device_id) + async def execute_routine(self, scene_id: int) -> None: """Execute a specific routine (scene) by its ID.""" await self._web_api.execute_scene(self._user_data, scene_id) diff --git a/tests/devices/traits/v1/fixtures.py b/tests/devices/traits/v1/fixtures.py index bf42d151..fba0b86a 100644 --- a/tests/devices/traits/v1/fixtures.py +++ b/tests/devices/traits/v1/fixtures.py @@ -46,6 +46,12 @@ def web_api_client_fixture() -> AsyncMock: return AsyncMock() +@pytest.fixture(name="trait_home_data") +def trait_home_data_fixture(request: pytest.FixtureRequest) -> HomeData: + """Fixture to provide HomeData, optionally overridden via indirect parametrization.""" + return deepcopy(getattr(request, "param", HOME_DATA)) + + @pytest.fixture(autouse=True, name="roborock_cache") def roborock_cache_fixture() -> Cache: """Fixture to provide a NoCache instance for tests.""" @@ -53,15 +59,15 @@ def roborock_cache_fixture() -> Cache: @pytest.fixture(autouse=True, name="device_cache") -def device_cache_fixture(roborock_cache: Cache) -> DeviceCache: +def device_cache_fixture(roborock_cache: Cache, trait_home_data: HomeData) -> DeviceCache: """Fixture to provide a DeviceCache instance for tests.""" - return DeviceCache(HOME_DATA.devices[0].duid, roborock_cache) + return DeviceCache(trait_home_data.get_all_devices()[0].duid, roborock_cache) @pytest.fixture(name="device_info") -def device_info_fixture() -> HomeDataDevice: +def device_info_fixture(trait_home_data: HomeData) -> HomeDataDevice: """Fixture to provide a DeviceInfo instance for tests.""" - return HOME_DATA.devices[0] + return trait_home_data.get_all_devices()[0] @pytest.fixture(name="products") @@ -79,6 +85,7 @@ def device_fixture( web_api_client: AsyncMock, device_cache: DeviceCache, device_info: HomeDataDevice, + trait_home_data: HomeData, products: list[HomeDataProduct], ) -> RoborockDevice: """Fixture to set up the device for tests.""" @@ -90,7 +97,7 @@ def device_fixture( trait=v1.create( device_info.duid, product, - deepcopy(HOME_DATA), + trait_home_data, mock_rpc_channel, mock_mqtt_rpc_channel, mock_map_rpc_channel, diff --git a/tests/devices/traits/v1/test_rooms.py b/tests/devices/traits/v1/test_rooms.py index 70dfb080..9e6d720f 100644 --- a/tests/devices/traits/v1/test_rooms.py +++ b/tests/devices/traits/v1/test_rooms.py @@ -1,15 +1,17 @@ """Tests for the RoomMapping related functionality.""" +from copy import deepcopy from typing import Any from unittest.mock import AsyncMock import pytest -from roborock.data.containers import HomeDataRoom, NamedRoomMapping +from roborock.data.containers import HomeData, HomeDataRoom, NamedRoomMapping from roborock.devices.device import RoborockDevice from roborock.devices.traits.v1.rooms import RoomsTrait from roborock.devices.traits.v1.status import StatusTrait from roborock.roborock_typing import RoborockCommand +from tests.devices.traits.v1.fixtures import HOME_DATA @pytest.fixture @@ -79,29 +81,25 @@ async def test_refresh_unknown_room_names_overwrites_home_data( mock_rpc_channel: AsyncMock, ) -> None: """Test web rooms are used to refresh home data for missing iot ids.""" - original_rooms = list(rooms_trait._home_data.rooms or ()) - try: - web_api_client.get_rooms.return_value = [ - HomeDataRoom(id=2362048, name="Living Room"), - HomeDataRoom(id=2362044, name="Example room 2"), - HomeDataRoom(id=2362041, name="Example room 3"), - HomeDataRoom(id=9999999, name="Office"), - ] - - room_mapping_data = [[16, "2362048"], [17, "9999999"]] - mock_rpc_channel.send_command.side_effect = [room_mapping_data] + web_api_client.get_rooms.return_value = [ + HomeDataRoom(id=2362048, name="Living Room"), + HomeDataRoom(id=2362044, name="Example room 2"), + HomeDataRoom(id=2362041, name="Example room 3"), + HomeDataRoom(id=9999999, name="Office"), + ] + + room_mapping_data = [[16, "2362048"], [17, "9999999"]] + mock_rpc_channel.send_command.side_effect = [room_mapping_data] - await rooms_trait.refresh() + await rooms_trait.refresh() - assert rooms_trait.rooms - assert rooms_trait.rooms[0] == NamedRoomMapping(segment_id=16, iot_id="2362048", raw_name="Living Room") - assert rooms_trait.rooms[1] == NamedRoomMapping(segment_id=17, iot_id="9999999", raw_name="Office") + assert rooms_trait.rooms + assert rooms_trait.rooms[0] == NamedRoomMapping(segment_id=16, iot_id="2362048", raw_name="Living Room") + assert rooms_trait.rooms[1] == NamedRoomMapping(segment_id=17, iot_id="9999999", raw_name="Office") - home_data_rooms = {str(room.id): room.name for room in rooms_trait._home_data.rooms or ()} - assert home_data_rooms["2362048"] == "Living Room" - assert home_data_rooms["9999999"] == "Office" - finally: - rooms_trait._home_data.rooms = original_rooms + home_data_rooms = {str(room.id): room.name for room in rooms_trait._home_data.rooms or ()} + assert home_data_rooms["2362048"] == "Living Room" + assert home_data_rooms["9999999"] == "Office" async def test_refresh_unknown_room_names_web_api_called_once( @@ -110,25 +108,21 @@ async def test_refresh_unknown_room_names_web_api_called_once( mock_rpc_channel: AsyncMock, ) -> None: """Test unknown room IDs trigger one web lookup per iot_id.""" - original_rooms = list(rooms_trait._home_data.rooms or ()) - try: - web_api_client.get_rooms.return_value = [ - HomeDataRoom(id=9999911, name="Living Room"), - ] + web_api_client.get_rooms.return_value = [ + HomeDataRoom(id=9999911, name="Living Room"), + ] - room_mapping_data = [[16, "9999911"]] - mock_rpc_channel.send_command.side_effect = [room_mapping_data, room_mapping_data] + room_mapping_data = [[16, "9999911"]] + mock_rpc_channel.send_command.side_effect = [room_mapping_data, room_mapping_data] - await rooms_trait.refresh() - assert rooms_trait.rooms - assert rooms_trait.rooms[0].name == "Living Room" + await rooms_trait.refresh() + assert rooms_trait.rooms + assert rooms_trait.rooms[0].name == "Living Room" - await rooms_trait.refresh() - assert rooms_trait.rooms - assert rooms_trait.rooms[0].name == "Living Room" - web_api_client.get_rooms.assert_called_once() - finally: - rooms_trait._home_data.rooms = original_rooms + await rooms_trait.refresh() + assert rooms_trait.rooms + assert rooms_trait.rooms[0].name == "Living Room" + web_api_client.get_rooms.assert_called_once() async def test_refresh_unknown_room_names_unresolved_uses_room_fallback( @@ -207,3 +201,36 @@ async def test_refresh_unknown_room_names_failure_falls_back_to_room_segment_id( assert rooms_trait.rooms[0] == NamedRoomMapping(segment_id=16, iot_id="9999401") assert rooms_trait.rooms[0].name == "Room 16" web_api_client.get_rooms.assert_called_once() + + +def _build_shared_home_data() -> HomeData: + home_data = deepcopy(HOME_DATA) + home_data.received_devices = [home_data.devices.pop(0)] + return home_data + + +@pytest.mark.parametrize("trait_home_data", [_build_shared_home_data()], indirect=True) +async def test_refresh_shared_room_names_use_shared_device_rooms( + rooms_trait: RoomsTrait, + trait_home_data: HomeData, + web_api_client: AsyncMock, + mock_rpc_channel: AsyncMock, +) -> None: + """Test shared devices resolve room names via the shared-device room list.""" + assert trait_home_data.received_devices + assert not rooms_trait.rooms + + web_api_client.get_shared_device_rooms.return_value = [ + HomeDataRoom(id=9999999, name="Office"), + ] + room_mapping_data = [[16, "2362048"], [17, "9999999"]] + mock_rpc_channel.send_command.side_effect = [room_mapping_data] + + await rooms_trait.refresh() + + assert rooms_trait.rooms == [ + NamedRoomMapping(segment_id=16, iot_id="2362048", raw_name="Example room 1"), + NamedRoomMapping(segment_id=17, iot_id="9999999", raw_name="Office"), + ] + web_api_client.get_shared_device_rooms.assert_called_once_with(rooms_trait._device_uid) + web_api_client.get_rooms.assert_not_called() diff --git a/tests/test_web_api.py b/tests/test_web_api.py index 429783fc..1135c977 100644 --- a/tests/test_web_api.py +++ b/tests/test_web_api.py @@ -5,7 +5,7 @@ import pytest from aioresponses.compat import normalize_url -from roborock import HomeData, HomeDataScene, UserData +from roborock import HomeData, HomeDataRoom, HomeDataScene, UserData from roborock.exceptions import RoborockAccountDoesNotExist from roborock.web_api import IotLoginInfo, RoborockApiClient from tests.mock_data import HOME_DATA_RAW, USER_DATA @@ -374,3 +374,43 @@ async def test_get_schedules(mock_rest) -> None: assert schedule.cron == "03 13 15 12 ?" assert schedule.repeated is False assert schedule.enabled is True + + +@pytest.mark.parametrize( + "result_payload", + [ + # roomId field (deviceshare endpoint convention) + [ + {"roomId": "2362048", "name": "Living Room"}, + {"roomId": 2362044, "name": "Kitchen"}, + ], + # id field (matches /user/homes/{id}/rooms — defensive in case the API normalizes) + [ + {"id": 2362048, "name": "Living Room"}, + {"id": 2362044, "name": "Kitchen"}, + ], + ], +) +async def test_get_shared_device_rooms(mock_rest, result_payload) -> None: + """Test that shared-device rooms are fetched from the deviceshare query path.""" + api = RoborockApiClient(username="test_user@gmail.com") + ud = await api.pass_login("password") + + mock_rest.get( + "https://api-us.roborock.com/user/deviceshare/query/device-id-q7/rooms", + status=200, + payload={ + "api": None, + "code": 200, + "result": result_payload, + "status": "ok", + "success": True, + }, + ) + + rooms = await api.get_shared_device_rooms(ud, "device-id-q7") + + assert rooms == [ + HomeDataRoom(id=2362048, name="Living Room"), + HomeDataRoom(id=2362044, name="Kitchen"), + ]