22
33import logging
44from dataclasses import dataclass
5+ from functools import cached_property
56
67from roborock .data import HomeData , HomeDataRoom , NamedRoomMapping , RoborockBase
78from roborock .devices .traits .v1 import common
@@ -84,12 +85,22 @@ class RoomsTrait(Rooms, common.V1TraitMixin):
8485 command = RoborockCommand .GET_ROOM_MAPPING
8586 converter = RoomsConverter ()
8687
87- def __init__ (self , home_data : HomeData , web_api : UserWebApiClient ) -> None :
88+ def __init__ (self , home_data : HomeData , device_uid : str , web_api : UserWebApiClient ) -> None :
8889 """Initialize the RoomsTrait."""
8990 super ().__init__ ()
9091 self ._home_data = home_data
92+ self ._device_uid = device_uid
9193 self ._web_api = web_api
9294 self ._discovered_iot_ids : set [str ] = set ()
95+ self ._shared_room_names : dict [str , str ] = {}
96+
97+ @cached_property
98+ def _is_shared (self ) -> bool :
99+ return any (d .duid == self ._device_uid for d in self ._home_data .received_devices )
100+
101+ @property
102+ def _room_name_map (self ) -> dict [str , str ]:
103+ return {** self ._home_data .rooms_name_map , ** self ._shared_room_names }
93104
94105 async def refresh (self ) -> None :
95106 """Refresh room mappings and backfill unknown room names from the web API."""
@@ -104,12 +115,15 @@ async def refresh(self) -> None:
104115
105116 segment_map = RoomsConverter .extract_segment_map (response )
106117 # Track all iot ids seen before. Refresh the room list when new ids are found.
107- new_iot_ids = set (segment_map .values ()) - set (self ._home_data . rooms_map .keys ())
118+ new_iot_ids = set (segment_map .values ()) - set (self ._room_name_map .keys ())
108119 if new_iot_ids - self ._discovered_iot_ids :
109120 _LOGGER .debug ("Refreshing room list to discover new room names" )
110121 if updated_rooms := await self ._refresh_rooms ():
111122 _LOGGER .debug ("Updating rooms: %s" , list (updated_rooms ))
112- self ._home_data .rooms = updated_rooms
123+ if self ._is_shared :
124+ self ._shared_room_names = {room .iot_id : room .name for room in updated_rooms }
125+ else :
126+ self ._home_data .rooms = updated_rooms
113127 self ._discovered_iot_ids .update (new_iot_ids )
114128 try :
115129 rooms = self .converter .convert (response )
@@ -121,12 +135,14 @@ async def refresh(self) -> None:
121135 inner_error = err ,
122136 ) from err
123137
124- rooms = rooms .with_room_names (self ._home_data . rooms_name_map )
138+ rooms = rooms .with_room_names (self ._room_name_map )
125139 common .merge_trait_values (self , rooms )
126140
127141 async def _refresh_rooms (self ) -> list [HomeDataRoom ]:
128142 """Fetch the latest rooms from the web API."""
129143 try :
144+ if self ._is_shared :
145+ return await self ._web_api .get_shared_device_rooms (self ._device_uid )
130146 return await self ._web_api .get_rooms ()
131147 except Exception :
132148 _LOGGER .debug ("Failed to fetch rooms from web API" , exc_info = True )
0 commit comments