Skip to content

Commit 5a6cdc2

Browse files
committed
fix: handle Web API unauthorized errors
1 parent 86b839e commit 5a6cdc2

4 files changed

Lines changed: 81 additions & 11 deletions

File tree

roborock/devices/device_manager.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,15 @@ def create_web_api_wrapper(
173173
*,
174174
cache: Cache | None = None,
175175
session: aiohttp.ClientSession | None = None,
176+
unauthorized_hook: SessionUnauthorizedHook | None = None,
176177
) -> UserWebApiClient:
177178
"""Create a home data API wrapper from an existing API client."""
178179

179180
# Note: This will auto discover the API base URL. This can be improved
180181
# by caching this next to `UserData` if needed to avoid unnecessary API calls.
181182
client = RoborockApiClient(username=user_params.username, base_url=user_params.base_url, session=session)
182183

183-
return UserWebApiClient(client, user_params.user_data)
184+
return UserWebApiClient(client, user_params.user_data, unauthorized_hook=unauthorized_hook)
184185

185186

186187
async def create_device_manager(
@@ -212,7 +213,9 @@ async def create_device_manager(
212213
if cache is None:
213214
cache = NoCache()
214215

215-
web_api = create_web_api_wrapper(user_params, session=session, cache=cache)
216+
web_api = create_web_api_wrapper(
217+
user_params, session=session, cache=cache, unauthorized_hook=mqtt_session_unauthorized_hook
218+
)
216219
user_data = user_params.user_data
217220

218221
diagnostics = Diagnostics()
@@ -264,6 +267,12 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat
264267
dev.add_ready_callback(ready_callback)
265268
return dev
266269

267-
manager = DeviceManager(web_api, device_creator, mqtt_session=mqtt_session, cache=cache, diagnostics=diagnostics)
270+
manager = DeviceManager(
271+
web_api,
272+
device_creator,
273+
mqtt_session=mqtt_session,
274+
cache=cache,
275+
diagnostics=diagnostics,
276+
)
268277
await manager.discover_devices(prefer_cache)
269278
return manager

roborock/web_api.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import secrets
77
import string
88
import time
9+
from collections.abc import Callable
910
from dataclasses import dataclass
1011

1112
import aiohttp
@@ -737,23 +738,46 @@ class UserWebApiClient:
737738
to avoid needing to pass UserData around and mock out the web API.
738739
"""
739740

740-
def __init__(self, web_api: RoborockApiClient, user_data: UserData) -> None:
741+
def __init__(
742+
self, web_api: RoborockApiClient, user_data: UserData, unauthorized_hook: Callable[[], None] | None = None
743+
) -> None:
741744
"""Initialize the wrapper with the API client and user data."""
742745
self._web_api = web_api
743746
self._user_data = user_data
747+
self._unauthorized_hook = unauthorized_hook
744748

745749
async def get_home_data(self) -> HomeData:
746750
"""Fetch home data using the API client."""
747-
return await self._web_api.get_home_data_v3(self._user_data)
751+
try:
752+
return await self._web_api.get_home_data_v3(self._user_data)
753+
except RoborockInvalidCredentials:
754+
if self._unauthorized_hook:
755+
self._unauthorized_hook()
756+
raise
748757

749758
async def get_routines(self, device_id: str) -> list[HomeDataScene]:
750759
"""Fetch routines (scenes) for a specific device."""
751-
return await self._web_api.get_scenes(self._user_data, device_id)
760+
try:
761+
return await self._web_api.get_scenes(self._user_data, device_id)
762+
except RoborockInvalidCredentials:
763+
if self._unauthorized_hook:
764+
self._unauthorized_hook()
765+
raise
752766

753767
async def get_rooms(self) -> list[HomeDataRoom]:
754768
"""Fetch rooms using the API client."""
755-
return await self._web_api.get_rooms(self._user_data)
769+
try:
770+
return await self._web_api.get_rooms(self._user_data)
771+
except RoborockInvalidCredentials:
772+
if self._unauthorized_hook:
773+
self._unauthorized_hook()
774+
raise
756775

757776
async def execute_routine(self, scene_id: int) -> None:
758777
"""Execute a specific routine (scene) by its ID."""
759-
await self._web_api.execute_scene(self._user_data, scene_id)
778+
try:
779+
await self._web_api.execute_scene(self._user_data, scene_id)
780+
except RoborockInvalidCredentials:
781+
if self._unauthorized_hook:
782+
self._unauthorized_hook()
783+
raise

tests/devices/test_device_manager.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from roborock.devices.cache import InMemoryCache
1414
from roborock.devices.device import RoborockDevice
1515
from roborock.devices.device_manager import UserParams, create_device_manager, create_web_api_wrapper
16-
from roborock.exceptions import RoborockException
16+
from roborock.exceptions import RoborockException, RoborockInvalidCredentials
1717
from tests import mock_data
1818

1919
USER_DATA = UserData.from_dict(mock_data.USER_DATA)
@@ -150,6 +150,19 @@ async def test_create_home_data_api_exception() -> None:
150150
await api.get_home_data()
151151

152152

153+
async def test_device_manager_unauthorized_hook() -> None:
154+
"""Test that unauthorized hook is called when RoborockInvalidCredentials is raised."""
155+
mock_hook = Mock()
156+
with patch(
157+
"roborock.devices.device_manager.RoborockApiClient.get_home_data_v3",
158+
side_effect=RoborockInvalidCredentials("Unauthorized"),
159+
):
160+
with pytest.raises(RoborockInvalidCredentials, match="Unauthorized"):
161+
await create_device_manager(USER_PARAMS, mqtt_session_unauthorized_hook=mock_hook, prefer_cache=False)
162+
163+
mock_hook.assert_called_once()
164+
165+
153166
@pytest.mark.parametrize(("prefer_cache", "expected_call_count"), [(True, 1), (False, 2)])
154167
async def test_cache_logic(prefer_cache: bool, expected_call_count: int) -> None:
155168
"""Test that the cache logic works correctly."""

tests/test_web_api.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import re
22
from typing import Any
3+
from unittest.mock import AsyncMock, Mock
34

45
import aiohttp
56
import pytest
67
from aioresponses.compat import normalize_url
78

89
from roborock import HomeData, HomeDataScene, UserData
9-
from roborock.exceptions import RoborockAccountDoesNotExist
10-
from roborock.web_api import IotLoginInfo, RoborockApiClient
10+
from roborock.exceptions import RoborockAccountDoesNotExist, RoborockInvalidCredentials
11+
from roborock.web_api import IotLoginInfo, RoborockApiClient, UserWebApiClient
1112
from tests.mock_data import HOME_DATA_RAW, USER_DATA
1213

1314
pytest_plugins = [
@@ -374,3 +375,26 @@ async def test_get_schedules(mock_rest) -> None:
374375
assert schedule.cron == "03 13 15 12 ?"
375376
assert schedule.repeated is False
376377
assert schedule.enabled is True
378+
379+
380+
async def test_user_web_api_client_unauthorized_hook() -> None:
381+
"""Test that UserWebApiClient triggers unauthorized hook on RoborockInvalidCredentials."""
382+
mock_hook = Mock()
383+
mock_api = AsyncMock(spec=RoborockApiClient)
384+
385+
# Setup mock to raise RoborockInvalidCredentials
386+
mock_api.get_home_data_v3.side_effect = RoborockInvalidCredentials("Unauthorized")
387+
388+
client = UserWebApiClient(mock_api, UserData.from_dict(USER_DATA), unauthorized_hook=mock_hook)
389+
390+
with pytest.raises(RoborockInvalidCredentials):
391+
await client.get_home_data()
392+
393+
mock_hook.assert_called_once()
394+
395+
# Test another method
396+
mock_hook.reset_mock()
397+
mock_api.get_rooms.side_effect = RoborockInvalidCredentials("Unauthorized")
398+
with pytest.raises(RoborockInvalidCredentials):
399+
await client.get_rooms()
400+
mock_hook.assert_called_once()

0 commit comments

Comments
 (0)