From 71dbeb927f262ad6b1b8fed7f9b6e514a901c810 Mon Sep 17 00:00:00 2001 From: Luke Date: Sat, 2 May 2026 20:54:44 -0400 Subject: [PATCH 1/2] pull out bug fixes from ha_app branch --- .../https_server/routes/user/scene/service.py | 7 +- .../mqtt_tls_proxy_server/server.py | 151 ++++++++++- .../bundled_backend/shared/protocol_auth.py | 33 ++- .../bundled_backend/shared/routine_runner.py | 142 +++++++++-- .../shared/runtime_credentials.py | 37 +++ .../bundled_backend/shared/runtime_state.py | 30 +++ src/roborock_local_server/server.py | 1 + tests/contracts/test_ios_app_init_contract.py | 6 +- tests/test_admin_api.py | 134 +++++++++- tests/test_mqtt_tls_proxy.py | 241 ++++++++++++++++++ tests/test_routine_runner.py | 166 ++++++++++++ tests/test_runtime_state.py | 69 +++++ 12 files changed, 970 insertions(+), 47 deletions(-) diff --git a/src/roborock_local_server/bundled_backend/https_server/routes/user/scene/service.py b/src/roborock_local_server/bundled_backend/https_server/routes/user/scene/service.py index fd4e0ef..0928a20 100644 --- a/src/roborock_local_server/bundled_backend/https_server/routes/user/scene/service.py +++ b/src/roborock_local_server/bundled_backend/https_server/routes/user/scene/service.py @@ -563,7 +563,9 @@ def _routine_runner_for_context(ctx: ServerContext) -> RoutineRunner: def list_scenes_for_device(ctx: ServerContext, device_id: str) -> list[dict[str, Any]]: - scenes = _scene_state(ctx)["scenes"] + state = _scene_state(ctx) + scenes = state["scenes"] + home_id = state["home_id"] filtered: list[dict[str, Any]] = [] for scene in scenes: if not isinstance(scene, dict): @@ -571,7 +573,7 @@ def list_scenes_for_device(ctx: ServerContext, device_id: str) -> list[dict[str, scene_device = get_value(scene, "device_id", "deviceId", "duid") if scene_device and str(scene_device) != str(device_id): continue - filtered.append(build_scene_payload(scene, home_id=None, include_device_context=False)) + filtered.append(build_scene_payload(scene, home_id=home_id, include_device_context=True)) return filtered @@ -693,4 +695,3 @@ def apply_update(updated_scene: dict[str, Any], inventory: dict[str, Any]) -> No updated_scene, home_id = _replace_inventory_scene(ctx, scene_id=scene_id, scene_updater=apply_update) return build_scene_payload(updated_scene, home_id=home_id, include_device_context=True) - diff --git a/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py b/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py index 5d4de2b..531f8c5 100644 --- a/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py +++ b/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py @@ -64,6 +64,8 @@ def __init__( self._counter = 0 self._lock = threading.Lock() self._conn_protocol_levels: dict[str, int] = {} + self._conn_endpoints: dict[str, tuple[socket.socket, socket.socket]] = {} + self._pending_onboarding_auth: dict[str, dict[str, str]] = {} self._trace_queue: queue.Queue[tuple[str, str, bytes] | None] = queue.Queue() self._trace_thread: threading.Thread | None = None self._protocol_auth = ( @@ -83,6 +85,38 @@ def _next_conn(self) -> str: self._counter += 1 return str(self._counter) + def _register_conn_endpoints(self, conn_id: str, client_conn: socket.socket, backend_conn: socket.socket) -> None: + with self._lock: + self._conn_endpoints[conn_id] = (client_conn, backend_conn) + + def _pop_conn_endpoints(self, conn_id: str) -> tuple[socket.socket, socket.socket] | None: + with self._lock: + return self._conn_endpoints.pop(conn_id, None) + + def _close_conn_endpoints(self, conn_id: str) -> None: + endpoints = self._pop_conn_endpoints(conn_id) + if endpoints is None: + return + for endpoint in endpoints: + try: + endpoint.close() + except OSError: + pass + + def _set_pending_onboarding_auth(self, conn_id: str, candidate: dict[str, str]) -> None: + with self._lock: + self._pending_onboarding_auth[conn_id] = dict(candidate) + + def _get_pending_onboarding_auth(self, conn_id: str) -> dict[str, str] | None: + with self._lock: + candidate = self._pending_onboarding_auth.get(conn_id) + return dict(candidate) if candidate is not None else None + + def _pop_pending_onboarding_auth(self, conn_id: str) -> dict[str, str] | None: + with self._lock: + candidate = self._pending_onboarding_auth.pop(conn_id, None) + return dict(candidate) if candidate is not None else None + @staticmethod def _decode_remaining_length(data: bytes, start: int) -> tuple[int | None, int]: multiplier = 1 @@ -187,28 +221,37 @@ def _expected_bootstrap_credentials(self) -> tuple[str, str, str] | None: return username, password, client_id def _authorize_connect_packet(self, packet: bytes) -> tuple[bool, str, dict[str, Any] | None]: + authorized, reason, info, _candidate = self._authorize_connect_packet_for_client(packet, client_ip="") + return authorized, reason, info + + def _authorize_connect_packet_for_client( + self, + packet: bytes, + *, + client_ip: str, + ) -> tuple[bool, str, dict[str, Any] | None, dict[str, str] | None]: info = parse_mqtt_connect_packet(packet) if info is None: - return False, "invalid_connect_packet", None + return False, "invalid_connect_packet", None, None username = str(info.get("username") or "").strip() password = str(info.get("password") or "").strip() client_id = str(info.get("client_id") or "").strip() if not username or not password: - return False, "missing_mqtt_credentials", info + return False, "missing_mqtt_credentials", info, None if self._protocol_auth is not None and self._protocol_auth_enabled(): authorized, auth_reason, _matched_user = self._protocol_auth.verify_user_mqtt_credentials(username, password) if authorized: - return True, auth_reason, info + return True, auth_reason, info, None bootstrap_credentials = self._expected_bootstrap_credentials() if bootstrap_credentials is not None: expected_username, expected_password, expected_client_id = bootstrap_credentials if username == expected_username and password == expected_password: if expected_client_id and client_id and client_id != expected_client_id: - return False, "invalid_bootstrap_client_id", info - return True, "bootstrap", info + return False, "invalid_bootstrap_client_id", info, None + return True, "bootstrap", info, None if self.runtime_credentials is not None: authorized, auth_reason, _matched_device = self.runtime_credentials.verify_device_mqtt_credentials( @@ -216,16 +259,96 @@ def _authorize_connect_packet(self, packet: bytes) -> tuple[bool, str, dict[str, password=password, ) if authorized: - return True, auth_reason, info + return True, auth_reason, info, None if auth_reason == "device_mqtt_password_missing": recovered_device = self.runtime_credentials.recover_device_mqtt_password( username=username, password=password, ) if recovered_device is not None: - return True, "device_mqtt_recovered", info + return True, "device_mqtt_recovered", info, None + if auth_reason == "unknown_device_mqtt_username": + candidate = self._resolve_onboarding_device_mqtt_candidate( + client_ip=client_ip, + username=username, + password=password, + ) + if candidate is not None: + return True, "device_mqtt_onboarding_pending", info, candidate + + return False, "invalid_mqtt_credentials", info, None - return False, "invalid_mqtt_credentials", info + def _resolve_onboarding_device_mqtt_candidate( + self, + *, + client_ip: str, + username: str, + password: str, + ) -> dict[str, str] | None: + if self.runtime_state is None or self.runtime_credentials is None: + return None + candidate = self.runtime_state.onboarding_device_mqtt_candidate(client_ip=client_ip) + if candidate is None: + return None + device = self.runtime_credentials.resolve_device( + did=str(candidate.get("did") or ""), + duid=str(candidate.get("duid") or ""), + ) + if device is None: + return None + existing_username = str(device.get("device_mqtt_usr") or "").strip() + if existing_username: + return None + return { + "did": str(device.get("did") or "").strip(), + "duid": str(device.get("duid") or "").strip(), + "name": str(device.get("name") or candidate.get("name") or "").strip(), + "username": username.strip(), + "password": password.strip(), + "client_ip": client_ip.strip(), + } + + def _confirm_pending_onboarding_auth(self, conn_id: str, *, direction: str, topic: str) -> bool: + if direction != "c2b" or self.runtime_credentials is None: + return True + candidate = self._get_pending_onboarding_auth(conn_id) + if candidate is None: + return True + expected_topic = f"rr/d/i/{candidate['did']}/{candidate['username']}" + if topic != expected_topic: + self.logger.warning( + "[conn %s] rejected provisional onboarding MQTT session expected_topic=%s got=%s", + conn_id, + expected_topic, + topic, + ) + self._pop_pending_onboarding_auth(conn_id) + self._close_conn_endpoints(conn_id) + return False + learned = self.runtime_credentials.confirm_device_mqtt_credentials( + did=candidate.get("did", ""), + duid=candidate.get("duid", ""), + username=candidate["username"], + password=candidate["password"], + ) + self._pop_pending_onboarding_auth(conn_id) + if learned is None: + self.logger.warning( + "[conn %s] failed to persist confirmed onboarding MQTT credentials did=%s duid=%s", + conn_id, + candidate.get("did", ""), + candidate.get("duid", ""), + ) + self._close_conn_endpoints(conn_id) + return False + self.logger.info( + "[conn %s] learned onboarding MQTT credentials did=%s duid=%s username=%s", + conn_id, + learned.get("did", ""), + learned.get("duid", ""), + candidate["username"], + ) + return True @classmethod def _extract_publish(cls, packet: bytes, protocol_level: int | None = None) -> tuple[str | None, bytes | None]: @@ -408,6 +531,8 @@ def _trace_packet(self, conn_id: str, direction: str, packet: bytes) -> None: topic, payload = self._extract_publish(packet, self._get_conn_protocol_level(conn_id)) if topic is None or payload is None: return + if not self._confirm_pending_onboarding_auth(conn_id, direction=direction, topic=topic): + return if self.runtime_state is not None: self.runtime_state.record_mqtt_message( conn_id=conn_id, @@ -609,7 +734,10 @@ def _handle_client(self, tls_conn: ssl.SSLSocket, addr: tuple[str, int]) -> None self.logger.warning("[conn %s] client closed before MQTT CONNECT", conn_id) return connect_packet, initial_remainder = first_packet - authorized, auth_reason, connect_info = self._authorize_connect_packet(connect_packet) + authorized, auth_reason, connect_info, onboarding_candidate = self._authorize_connect_packet_for_client( + connect_packet, + client_ip=addr[0], + ) if connect_info is not None: protocol_level = connect_info.get("protocol_level") if isinstance(protocol_level, int): @@ -637,6 +765,9 @@ def _handle_client(self, tls_conn: ssl.SSLSocket, addr: tuple[str, int]) -> None backend = socket.socket(socket.AF_INET, socket.SOCK_STREAM) backend.connect((self.backend_host, self.backend_port)) + self._register_conn_endpoints(conn_id, tls_conn, backend) + if onboarding_candidate is not None: + self._set_pending_onboarding_auth(conn_id, onboarding_candidate) c2b_frame_buf = bytearray(initial_remainder) for packet in self._extract_packets(c2b_frame_buf): self._queue_trace_packet(conn_id, "c2b", packet) @@ -655,6 +786,8 @@ def _handle_client(self, tls_conn: ssl.SSLSocket, addr: tuple[str, int]) -> None except Exception as exc: self.logger.error("[conn %s] connection error: %s", conn_id, exc) finally: + self._pop_pending_onboarding_auth(conn_id) + self._pop_conn_endpoints(conn_id) if not relay_started: for endpoint in (tls_conn, backend): if endpoint is None: diff --git a/src/roborock_local_server/bundled_backend/shared/protocol_auth.py b/src/roborock_local_server/bundled_backend/shared/protocol_auth.py index 23902e4..c876583 100644 --- a/src/roborock_local_server/bundled_backend/shared/protocol_auth.py +++ b/src/roborock_local_server/bundled_backend/shared/protocol_auth.py @@ -23,6 +23,10 @@ def _md5hex(value: str) -> str: return hashlib.md5(value.encode("utf-8")).hexdigest() +def _md5hex_bytes(value: bytes) -> str: + return hashlib.md5(value).hexdigest() + + def _parse_json_body_param_map(body_params: dict[str, list[str]]) -> dict[str, Any]: for raw in body_params.get("__json") or []: try: @@ -63,9 +67,17 @@ def _build_hawk_mac( path: str, query_values: dict[str, Any] | None, form_values: dict[str, Any] | None, + json_body: bytes | str | None = None, timestamp: int, nonce: str, ) -> str: + query_hash = _process_extra_hawk_values(query_values) + if json_body is None: + body_hash = _process_extra_hawk_values(form_values) + else: + if isinstance(json_body, str): + json_body = json_body.encode("utf-8") + body_hash = _md5hex_bytes(json_body) prestr = ":".join( [ hawk_id, @@ -73,8 +85,8 @@ def _build_hawk_mac( nonce, str(timestamp), _md5hex(path), - _process_extra_hawk_values(query_values), - _process_extra_hawk_values(form_values), + query_hash, + body_hash, ] ) return base64.b64encode(hmac.new(hawk_key.encode(), prestr.encode(), hashlib.sha256).digest()).decode() @@ -162,11 +174,16 @@ def build_hawk_authorization( path: str, query_values: dict[str, Any] | None = None, form_values: dict[str, Any] | None = None, + json_body: bytes | str | None = None, timestamp: int | None = None, nonce: str | None = None, ) -> str: ts = int(time.time() if timestamp is None else timestamp) normalized_nonce = _clean_str(nonce) or secrets.token_urlsafe(6) + if json_body is None and isinstance(form_values, Mapping): + raw_json = form_values.get("__json") + if isinstance(raw_json, (str, bytes)): + json_body = raw_json mac = _build_hawk_mac( hawk_id=user.hawk_id, hawk_session=user.hawk_session, @@ -174,6 +191,7 @@ def build_hawk_authorization( path=path, query_values=query_values, form_values=form_values, + json_body=json_body, timestamp=ts, nonce=normalized_nonce, ) @@ -533,6 +551,7 @@ def verify_hawk( query_params: dict[str, list[str]], body_params: dict[str, list[str]], headers: Mapping[str, str], + raw_body: bytes | None = None, now_ts: float | None = None, ) -> tuple[bool, str]: availability = self.availability() @@ -565,13 +584,21 @@ def verify_hawk( if not nonce: return False, "missing_hawk_nonce" + json_body: bytes | None = None + if body_params.get("__json"): + if raw_body is not None: + json_body = raw_body + else: + json_raw = next((value for value in body_params.get("__json", []) if isinstance(value, str)), "") + json_body = json_raw.encode("utf-8") expected_mac = _build_hawk_mac( hawk_id=user.hawk_id, hawk_session=user.hawk_session, hawk_key=user.hawk_key, path=path, query_values=_normalize_param_values(query_params), - form_values=_normalize_param_values(body_params, include_json=True), + form_values=None if json_body is not None else _normalize_param_values(body_params, include_json=True), + json_body=json_body, timestamp=timestamp, nonce=nonce, ) diff --git a/src/roborock_local_server/bundled_backend/shared/routine_runner.py b/src/roborock_local_server/bundled_backend/shared/routine_runner.py index 8f0b934..569b5f5 100644 --- a/src/roborock_local_server/bundled_backend/shared/routine_runner.py +++ b/src/roborock_local_server/bundled_backend/shared/routine_runner.py @@ -50,16 +50,25 @@ def _ensure_local_python_roborock_on_path() -> None: _STEP_COMPLETE_TIMEOUT_SECONDS = 4 * 60 * 60 _STEP_START_POLL_INTERVAL_SECONDS = 0.5 _STATUS_POLL_INTERVAL_SECONDS = 5.0 +_STEP_COMPLETE_CONFIRM_SECONDS = 30.0 _ROUTINE_READY_STATES = {3, 8, 100} _RESUME_BATTERY_THRESHOLD = 80 _POST_STEP_SETTLE_SECONDS = 15.0 _POST_STEP_SETTLE_TIMEOUT_SECONDS = 10 * 60 +_ACTION_LOCK_RETRY_DELAY_SECONDS = 5.0 +_ACTION_LOCK_RETRY_ATTEMPTS = 6 from .inventory_io import WEB_API_INVENTORY_FILE _SUPPORTED_METHODS = { "do_scenes_app_start", "do_scenes_segments", "do_scenes_zones", } +_STEP_START_COMMANDS = { + RoborockCommand.APP_START, + RoborockCommand.APP_SEGMENT_CLEAN, + RoborockCommand.APP_ZONED_CLEAN, +} +_ACTION_LOCK_ERROR_CODES = {-10003, -10007} _RESUME_COMMAND_BY_IN_CLEANING: dict[int, RoborockCommand] = { RoborockInCleaning.global_clean_not_complete.value: RoborockCommand.APP_START, RoborockInCleaning.zone_clean_not_complete.value: RoborockCommand.RESUME_ZONED_CLEAN, @@ -368,6 +377,37 @@ def _is_optional_unsupported_command(command: RoborockCommand, exc: Exception) - return command == RoborockCommand.SET_MOP_TEMPLATE_ID and isinstance(exc, RoborockUnsupportedFeature) +def _exception_error_payload(exc: Exception) -> dict[str, Any] | None: + args = getattr(exc, "args", ()) + if not args: + return None + payload = args[0] + return payload if isinstance(payload, dict) else None + + +def _exception_error_code(exc: Exception) -> int | None: + payload = _exception_error_payload(exc) + code = payload.get("code") if payload is not None else getattr(exc, "code", None) + return code if isinstance(code, int) else None + + +def _exception_error_message(exc: Exception) -> str: + payload = _exception_error_payload(exc) + if payload is not None: + message = payload.get("message") + if isinstance(message, str): + return message + return str(exc) + + +def _is_retryable_action_locked_error(exc: Exception) -> bool: + code = _exception_error_code(exc) + if code in _ACTION_LOCK_ERROR_CODES: + return True + message = _exception_error_message(exc).strip().lower() + return "action locked" in message or "invalid status" in message + + def _response_dps(message: RoborockMessage) -> dict[str, Any] | None: if message.payload is None: return None @@ -532,6 +572,7 @@ async def wait_for_step_complete(self) -> None: saw_activity = False saw_cleaning = False sent_resume = False + completion_candidate_since: float | None = None try: while True: @@ -555,8 +596,9 @@ async def wait_for_step_complete(self) -> None: if in_cleaning != RoborockInCleaning.complete.value: saw_cleaning = True + is_non_cleaning = in_cleaning == RoborockInCleaning.complete.value is_ready = ( - in_cleaning == RoborockInCleaning.complete.value + is_non_cleaning and state in _ROUTINE_READY_STATES ) @@ -588,8 +630,17 @@ async def wait_for_step_complete(self) -> None: saw_activity = True if sent_resume and state not in (8, 23, 26): sent_resume = False - elif saw_cleaning: - return + if saw_cleaning: + if is_non_cleaning: + if completion_candidate_since is None: + completion_candidate_since = loop.time() + elif ( + loop.time() - completion_candidate_since + >= _STEP_COMPLETE_CONFIRM_SECONDS + ): + return + else: + completion_candidate_since = None elif saw_activity: self._logger.info( "Routine wait: dock activity cycle ended (no cleaning observed), resetting" @@ -611,7 +662,7 @@ async def wait_for_step_complete(self) -> None: ) from exc - async def wait_for_dock_settle(self) -> None: + async def wait_for_dock_settle(self, *, initial_delay_seconds: float = _POST_STEP_SETTLE_SECONDS) -> None: """Wait for automatic dock activities (e.g. bin emptying) to finish before sending the next step. @@ -626,21 +677,21 @@ async def wait_for_dock_settle(self) -> None: loop = asyncio.get_running_loop() deadline = loop.time() + _POST_STEP_SETTLE_TIMEOUT_SECONDS - self._logger.info( - "Post-step settle: waiting %.0fs before checking dock activity", - _POST_STEP_SETTLE_SECONDS, - ) - await asyncio.sleep(_POST_STEP_SETTLE_SECONDS) + if initial_delay_seconds > 0: + self._logger.info( + "Post-step settle: waiting %.0fs before checking dock activity", + initial_delay_seconds, + ) + await asyncio.sleep(initial_delay_seconds) last_observed = None while True: remaining = deadline - loop.time() if remaining <= 0: - self._logger.warning( - "Post-step settle: timed out after %.0fs waiting for dock activity to finish", - _POST_STEP_SETTLE_TIMEOUT_SECONDS, + raise RoutineExecutionError( + "Timed out waiting for dock activity to finish " + f"after {_POST_STEP_SETTLE_TIMEOUT_SECONDS}s" ) - break status = await asyncio.wait_for(self.get_status(), timeout=remaining) state = _enum_or_int_value(status.state) @@ -862,6 +913,53 @@ async def _stop_scene(self, *, device_id: str, scene_id: int, scene_name: str) - finally: await client.close() + async def _send_step_command( + self, + *, + client: _RoutineMqttClient, + logger: logging.LoggerAdapter, + step: RoutineStep, + routine_command: RoutineCommand, + ) -> None: + attempts = 0 + while True: + attempts += 1 + try: + await client.send_command(routine_command.command, routine_command.params) + return + except RoborockUnsupportedFeature as exc: + if not _is_optional_unsupported_command(routine_command.command, exc): + raise + logger.warning( + "Skipping unsupported routine command step=%s method=%s command=%s: %s", + step.step_id, + step.method, + routine_command.command.value, + exc, + ) + return + except Exception as exc: # noqa: BLE001 + if ( + routine_command.command in _STEP_START_COMMANDS + and attempts < _ACTION_LOCK_RETRY_ATTEMPTS + and _is_retryable_action_locked_error(exc) + ): + logger.warning( + "Routine step=%s method=%s command=%s rejected as action locked; " + "waiting %.0fs and retrying (%s/%s)", + step.step_id, + step.method, + routine_command.command.value, + _ACTION_LOCK_RETRY_DELAY_SECONDS, + attempts, + _ACTION_LOCK_RETRY_ATTEMPTS, + ) + await client.wait_for_dock_settle( + initial_delay_seconds=_ACTION_LOCK_RETRY_DELAY_SECONDS, + ) + continue + raise + async def _run_scene(self, *, scene: dict[str, Any], steps: list[RoutineStep]) -> None: device_id = scene_device_id(scene) device = self._device_record(device_id) @@ -896,18 +994,12 @@ async def _run_scene(self, *, scene: dict[str, Any], steps: list[RoutineStep]) - routine_command.command.value, routine_command.params, ) - try: - await client.send_command(routine_command.command, routine_command.params) - except RoborockUnsupportedFeature as exc: - if not _is_optional_unsupported_command(routine_command.command, exc): - raise - logger.warning( - "Skipping unsupported routine command step=%s method=%s command=%s: %s", - step.step_id, - step.method, - routine_command.command.value, - exc, - ) + await self._send_step_command( + client=client, + logger=logger, + step=step, + routine_command=routine_command, + ) if waits_for_step_complete: logger.info("Waiting for ready state step=%s scene=%s", step.step_id, _scene_name(scene)) await client.wait_for_step_complete() diff --git a/src/roborock_local_server/bundled_backend/shared/runtime_credentials.py b/src/roborock_local_server/bundled_backend/shared/runtime_credentials.py index 6850f92..0977e49 100644 --- a/src/roborock_local_server/bundled_backend/shared/runtime_credentials.py +++ b/src/roborock_local_server/bundled_backend/shared/runtime_credentials.py @@ -653,6 +653,43 @@ def recover_device_mqtt_password(self, *, username: str, password: str) -> dict[ return dict(device) return None + def confirm_device_mqtt_credentials( + self, + *, + did: str = "", + duid: str = "", + username: str, + password: str, + ) -> dict[str, str] | None: + normalized_did = _clean_str(did) + normalized_duid = _clean_str(duid) + normalized_username = _clean_str(username) + normalized_password = _clean_str(password) + if not normalized_username or not normalized_password: + return None + with self._lock: + index = self._find_index_locked(did=normalized_did, duid=normalized_duid, model="") + if index is None: + return None + device = self._devices[index] + existing_username = _clean_str(device.get("device_mqtt_usr")) + existing_password = _clean_str(device.get("device_mqtt_pass")) + if existing_username and existing_username != normalized_username: + return None + if existing_password and existing_password != normalized_password: + return None + changed = False + if existing_username != normalized_username: + device["device_mqtt_usr"] = normalized_username + changed = True + if existing_password != normalized_password: + device["device_mqtt_pass"] = normalized_password + changed = True + if changed: + device["updated_at"] = utcnow_iso() + self._save_locked() + return dict(device) + def recovery_pending_devices(self) -> list[dict[str, str]]: with self._lock: pending = [ diff --git a/src/roborock_local_server/bundled_backend/shared/runtime_state.py b/src/roborock_local_server/bundled_backend/shared/runtime_state.py index 60c5bd1..f50f83a 100644 --- a/src/roborock_local_server/bundled_backend/shared/runtime_state.py +++ b/src/roborock_local_server/bundled_backend/shared/runtime_state.py @@ -399,6 +399,36 @@ def onboarding_session_snapshot(self) -> dict[str, Any]: def pairing_snapshot(self) -> dict[str, Any]: return self.onboarding_session_snapshot() + def onboarding_device_mqtt_candidate(self, *, client_ip: str) -> dict[str, str] | None: + normalized_ip = client_ip.strip() + if not normalized_ip: + return None + with self._lock: + session = self._pairing_session + if session is None or not session.get("active"): + return None + self._refresh_pairing_session_locked(session) + if str(session.get("identity_conflict") or "").strip(): + return None + if not str(session.get("region_at") or "").strip() or not str(session.get("nc_at") or "").strip(): + return None + target_ip = str(session.get("target_ip") or "").strip() + if not target_ip or target_ip != normalized_ip: + return None + target_did = str(session.get("target_did") or "").strip() + target_duid = str(session.get("target_duid") or "").strip() + if not target_did: + return None + key_state = self._session_key_state_locked(target_did, target_duid) + if not bool(key_state.get("has_modulus")): + return None + return { + "did": target_did, + "duid": target_duid, + "name": str(session.get("selected_name") or "").strip(), + "target_ip": target_ip, + } + def recent_events(self, *, limit: int = 200) -> list[dict[str, Any]]: with self._lock: if limit <= 0: diff --git a/src/roborock_local_server/server.py b/src/roborock_local_server/server.py index ef5f0dd..466e537 100644 --- a/src/roborock_local_server/server.py +++ b/src/roborock_local_server/server.py @@ -1031,6 +1031,7 @@ async def _handle_roborock_request(self, request: Request) -> Response: query_params=query_params, body_params=body_params, headers=request.headers, + raw_body=raw_body, ) if not authenticated: route_name = f"{required_auth}_auth_failed" diff --git a/tests/contracts/test_ios_app_init_contract.py b/tests/contracts/test_ios_app_init_contract.py index c8fc9af..d4037f5 100644 --- a/tests/contracts/test_ios_app_init_contract.py +++ b/tests/contracts/test_ios_app_init_contract.py @@ -103,12 +103,14 @@ def test_ios_app_init_contract_from_anonymized_capture(tmp_path: Path, monkeypat for index, request in enumerate(fixture["requests"]): headers = dict(default_headers) headers.update(request.get("headers", {})) + json_body = json.dumps(request["json"], separators=(",", ":")) if request.get("json") is not None else None if request["path"].startswith(("/user/", "/v2/user/", "/v3/user/")): headers["authorization"] = build_hawk_authorization( user=user, path=request["path"], query_values=request.get("query"), - form_values=request.get("form") or request.get("json"), + form_values=request.get("form"), + json_body=json_body, timestamp=fixture["frozen_time"], nonce=f"contract-{index}", ) @@ -117,7 +119,7 @@ def test_ios_app_init_contract_from_anonymized_capture(tmp_path: Path, monkeypat url=request["path"], headers=headers, params=request.get("query"), - json=request.get("json"), + content=json_body if json_body is not None else None, data=request.get("form"), ) assert response.status_code == 200, request["name"] diff --git a/tests/test_admin_api.py b/tests/test_admin_api.py index 6c722b9..215b59f 100644 --- a/tests/test_admin_api.py +++ b/tests/test_admin_api.py @@ -176,14 +176,24 @@ def _seed_protocol_snapshot(path: Path) -> None: ) -def _hawk_headers(snapshot_path: Path, path: str, *, form_values: dict[str, object] | None = None, json_values: dict[str, object] | None = None) -> dict[str, str]: +def _hawk_headers( + snapshot_path: Path, + path: str, + *, + form_values: dict[str, object] | None = None, + json_values: dict[str, object] | None = None, + json_body: str | None = None, +) -> dict[str, str]: user = ProtocolAuthStore(snapshot_path).availability().user assert user is not None + if json_body is None and json_values is not None: + json_body = json.dumps(json_values, separators=(",", ":")) return { "Authorization": build_hawk_authorization( user=user, path=path, - form_values=form_values or json_values, + form_values=form_values, + json_body=json_body, nonce=f"nonce-{path.replace('/', '-')}", ) } @@ -737,14 +747,18 @@ def test_scene_update_routes_persist_name_and_zone_ranges(tmp_path: Path) -> Non assert rename_response.json()["data"]["name"] == "After dinner" update_payload = _after_dinner_param_payload(device_id, include_ranges=False) + update_body = json.dumps(update_payload, separators=(",", ":")) update_response = client.put( "/user/scene/4491073/param", - json=update_payload, - headers=_hawk_headers( + content=update_body, + headers={ + "content-type": "application/json", + **_hawk_headers( paths.cloud_snapshot_path, "/user/scene/4491073/param", - json_values=update_payload, + json_body=update_body, ), + }, ) assert update_response.status_code == 200 @@ -761,6 +775,116 @@ def test_scene_update_routes_persist_name_and_zone_ranges(tmp_path: Path) -> Non assert second_step["params"]["data"][0]["zones"][0]["range"] == [32550, 22650, 34550, 25200] +def test_get_scenes_for_device_includes_edit_context(tmp_path: Path) -> None: + config_file = write_release_config(tmp_path) + config = load_config(config_file) + paths = resolve_paths(config_file, config) + device_id = "6HL2zfniaoYYV01CkVuhkO" + + paths.inventory_path.parent.mkdir(parents=True, exist_ok=True) + paths.inventory_path.write_text( + json.dumps( + { + "home": {"id": 1233716, "name": "My Home"}, + "devices": [ + { + "duid": device_id, + "name": "Qrevo MaxV", + "model": "roborock.vacuum.a87", + } + ], + "scenes": [ + { + "id": 4491073, + "name": "After dinner", + "device_id": device_id, + "device_name": "Qrevo MaxV", + "enabled": True, + "type": "WORKFLOW", + "param": json.dumps(_after_dinner_param_payload(device_id, include_ranges=True), separators=(",", ":")), + } + ], + } + ) + + "\n", + encoding="utf-8", + ) + _seed_protocol_snapshot(paths.cloud_snapshot_path) + + supervisor = ReleaseSupervisor(config=config, paths=paths) + client = TestClient(supervisor.app) + + response = client.get( + f"/user/scene/device/{device_id}", + headers=_hawk_headers(paths.cloud_snapshot_path, f"/user/scene/device/{device_id}"), + ) + assert response.status_code == 200 + + scenes = response.json()["data"] + assert len(scenes) == 1 + assert scenes[0]["homeId"] == 1233716 + assert scenes[0]["deviceId"] == device_id + assert scenes[0]["deviceName"] == "Qrevo MaxV" + + +def test_post_scene_create_accepts_hawk_json_body_signature(tmp_path: Path) -> None: + config_file = write_release_config(tmp_path) + config = load_config(config_file) + paths = resolve_paths(config_file, config) + device_id = "6HL2zfniaoYYV01CkVuhkO" + + paths.inventory_path.parent.mkdir(parents=True, exist_ok=True) + paths.inventory_path.write_text( + json.dumps( + { + "home": {"id": 1233716, "name": "My Home"}, + "devices": [ + { + "duid": device_id, + "name": "Qrevo MaxV", + "model": "roborock.vacuum.a87", + } + ], + "scenes": [], + } + ) + + "\n", + encoding="utf-8", + ) + _seed_protocol_snapshot(paths.cloud_snapshot_path) + _write_scene_zone_trace(paths.mqtt_jsonl_path) + + supervisor = ReleaseSupervisor(config=config, paths=paths) + client = TestClient(supervisor.app) + + create_payload = { + "name": "Party prep", + "homeId": "1233716", + "param": { + **_after_dinner_param_payload(device_id, include_ranges=False), + "tagId": 1002, + }, + } + create_body = json.dumps(create_payload, separators=(",", ":")) + response = client.post( + "/v2/user/scene", + content=create_body, + headers={ + "content-type": "application/json", + **_hawk_headers( + paths.cloud_snapshot_path, + "/v2/user/scene", + json_body=create_body, + ), + }, + ) + assert response.status_code == 200 + assert response.json()["data"]["name"] == "Party prep" + + stored_inventory = json.loads(paths.inventory_path.read_text(encoding="utf-8")) + assert any(scene["name"] == "Party prep" for scene in stored_inventory["scenes"]) + + def test_execute_scene_hydrates_missing_zone_ranges_from_mqtt(tmp_path: Path) -> None: config_file = write_release_config(tmp_path) config = load_config(config_file) diff --git a/tests/test_mqtt_tls_proxy.py b/tests/test_mqtt_tls_proxy.py index 0e54cda..2faaa20 100644 --- a/tests/test_mqtt_tls_proxy.py +++ b/tests/test_mqtt_tls_proxy.py @@ -3,11 +3,14 @@ import socket import threading import time +from datetime import datetime, timezone from pathlib import Path import pytest from roborock_local_server.backend import MqttTlsProxy +from roborock_local_server.bundled_backend.shared.runtime_credentials import RuntimeCredentialsStore +from roborock_local_server.bundled_backend.shared.runtime_state import RuntimeState class _FakeSourceSocket: @@ -110,6 +113,19 @@ def _seed_protocol_sessions(path: Path) -> None: ) +def _seed_key_state(path: Path, *, did: str) -> None: + _write_json( + path, + { + "devices": { + did: { + "modulus_hex": "ab", + } + } + }, + ) + + def _build_connect_packet(*, client_id: str, username: str, password: str, protocol_level: int = 4) -> bytes: protocol_name = b"MQTT" variable_header = ( @@ -132,6 +148,12 @@ def _build_connect_packet(*, client_id: str, username: str, password: str, proto return bytes([0x10, len(remaining)]) + remaining +def _build_publish_packet(*, topic: str, payload: bytes = b"{}") -> bytes: + topic_bytes = topic.encode() + remaining = len(topic_bytes).to_bytes(2, "big") + topic_bytes + payload + return bytes([0x30, len(remaining)]) + remaining + + def test_relay_forwards_chunk_before_slow_packet_tracing_finishes(tmp_path, monkeypatch) -> None: cloud_snapshot_path = tmp_path / "cloud_snapshot.json" _seed_cloud_snapshot(cloud_snapshot_path) @@ -412,6 +434,225 @@ def test_authorize_connect_recovers_missing_known_device_mqtt_password(tmp_path) assert reject_reason == "invalid_mqtt_credentials" +def test_authorize_connect_accepts_unknown_device_credentials_only_for_matching_onboarding_session(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + key_state_path = tmp_path / "device_key_state.json" + _seed_key_state(key_state_path, did="1103821560705") + runtime_credentials_path = tmp_path / "runtime_credentials.json" + _write_json( + runtime_credentials_path, + { + "schema_version": 2, + "mqtt_usr": "bootstrap-user", + "mqtt_passwd": "bootstrap-pass", + "mqtt_clientid": "bootstrap-client", + "devices": [ + { + "did": "1103821560705", + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "model": "roborock.vacuum.a87", + "product_id": "5gUei3OIJIXVD3eD85Balg", + "localkey": "xPd5Dr8CGGqtdDlH", + "local_key_source": "inventory", + "device_mqtt_usr": "", + "device_mqtt_pass": "", + "updated_at": "2026-04-17T17:00:00+00:00", + "last_nc_at": "", + "last_mqtt_seen_at": "", + } + ], + }, + ) + runtime_credentials = RuntimeCredentialsStore(runtime_credentials_path) + runtime_state = RuntimeState(log_dir=tmp_path, key_state_file=key_state_path, runtime_credentials=runtime_credentials) + runtime_state.upsert_vacuum("6HL2zfniaoYYV01CkVuhkO", name="Roborock Qrevo MaxV 2", id_kind="duid") + runtime_state.start_onboarding_session(target_duid="6HL2zfniaoYYV01CkVuhkO", target_name="Roborock Qrevo MaxV 2") + event_time = datetime.now(timezone.utc).isoformat() + for route_name, path_name in (("region", "/region"), ("nc_prepare", "/nc")): + runtime_state.record_http_event( + event_time=event_time, + route_name=route_name, + clean_path=path_name, + raw_path=path_name, + method="GET", + host="api-roborock.example.com", + remote="192.168.8.10:54321", + did="1103821560705", + ) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + runtime_state=runtime_state, + runtime_credentials=runtime_credentials, + ) + + packet = _build_connect_packet( + client_id="a012391cb5f8bc97", + username="c25b14ceac358d2a", + password="ff8922d24a9a9af81f18f35dcee9a5a5", + ) + authorized, reason, info, candidate = proxy._authorize_connect_packet_for_client( + packet, + client_ip="192.168.8.10", + ) + + assert authorized is True + assert reason == "device_mqtt_onboarding_pending" + assert info is not None + assert candidate is not None + assert candidate["did"] == "1103821560705" + persisted = runtime_credentials.resolve_device(did="1103821560705") + assert persisted is not None + assert persisted["device_mqtt_usr"] == "" + assert persisted["device_mqtt_pass"] == "" + + rejected, reject_reason, _info, rejected_candidate = proxy._authorize_connect_packet_for_client( + packet, + client_ip="192.168.8.11", + ) + assert rejected is False + assert reject_reason == "invalid_mqtt_credentials" + assert rejected_candidate is None + + +def test_trace_packet_persists_confirmed_onboarding_device_mqtt_credentials(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + runtime_credentials_path = tmp_path / "runtime_credentials.json" + _write_json( + runtime_credentials_path, + { + "schema_version": 2, + "devices": [ + { + "did": "1103821560705", + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "model": "roborock.vacuum.a87", + "product_id": "5gUei3OIJIXVD3eD85Balg", + "localkey": "xPd5Dr8CGGqtdDlH", + "local_key_source": "inventory", + "device_mqtt_usr": "", + "device_mqtt_pass": "", + "updated_at": "2026-04-17T17:00:00+00:00", + "last_nc_at": "", + "last_mqtt_seen_at": "", + } + ], + }, + ) + runtime_credentials = RuntimeCredentialsStore(runtime_credentials_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + runtime_credentials=runtime_credentials, + ) + proxy._set_pending_onboarding_auth( + "1", + { + "did": "1103821560705", + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "username": "c25b14ceac358d2a", + "password": "ff8922d24a9a9af81f18f35dcee9a5a5", + "client_ip": "192.168.8.10", + }, + ) + proxy._register_conn_endpoints("1", _FakeSourceSocket(), _FakeBackendSocket()) + + proxy._trace_packet("1", "c2b", _build_publish_packet(topic="rr/d/i/1103821560705/c25b14ceac358d2a")) + + persisted = runtime_credentials.resolve_device(did="1103821560705") + assert persisted is not None + assert persisted["device_mqtt_usr"] == "c25b14ceac358d2a" + assert persisted["device_mqtt_pass"] == "ff8922d24a9a9af81f18f35dcee9a5a5" + assert proxy._get_pending_onboarding_auth("1") is None + + +def test_trace_packet_closes_provisional_onboarding_session_when_first_publish_topic_mismatches(tmp_path) -> None: + cloud_snapshot_path = tmp_path / "cloud_snapshot.json" + _seed_cloud_snapshot(cloud_snapshot_path) + runtime_credentials_path = tmp_path / "runtime_credentials.json" + _write_json( + runtime_credentials_path, + { + "schema_version": 2, + "devices": [ + { + "did": "1103821560705", + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "model": "roborock.vacuum.a87", + "product_id": "5gUei3OIJIXVD3eD85Balg", + "localkey": "xPd5Dr8CGGqtdDlH", + "local_key_source": "inventory", + "device_mqtt_usr": "", + "device_mqtt_pass": "", + "updated_at": "2026-04-17T17:00:00+00:00", + "last_nc_at": "", + "last_mqtt_seen_at": "", + } + ], + }, + ) + runtime_credentials = RuntimeCredentialsStore(runtime_credentials_path) + proxy = MqttTlsProxy( + cert_file=tmp_path / "fullchain.pem", + key_file=tmp_path / "privkey.pem", + listen_host="127.0.0.1", + listen_port=8883, + backend_host="127.0.0.1", + backend_port=1883, + localkey="test-local-key", + logger=logging.getLogger("test.mqtt_tls_proxy"), + decoded_jsonl=tmp_path / "decoded.jsonl", + cloud_snapshot_path=cloud_snapshot_path, + runtime_credentials=runtime_credentials, + ) + client_sock = _FakeSourceSocket() + backend_sock = _FakeBackendSocket() + proxy._set_pending_onboarding_auth( + "1", + { + "did": "1103821560705", + "duid": "6HL2zfniaoYYV01CkVuhkO", + "name": "Roborock Qrevo MaxV 2", + "username": "c25b14ceac358d2a", + "password": "ff8922d24a9a9af81f18f35dcee9a5a5", + "client_ip": "192.168.8.10", + }, + ) + proxy._register_conn_endpoints("1", client_sock, backend_sock) + + proxy._trace_packet("1", "c2b", _build_publish_packet(topic="rr/d/i/9999999999999/c25b14ceac358d2a")) + + persisted = runtime_credentials.resolve_device(did="1103821560705") + assert persisted is not None + assert persisted["device_mqtt_usr"] == "" + assert persisted["device_mqtt_pass"] == "" + assert client_sock.closed is True + assert backend_sock.closed is True + assert proxy._get_pending_onboarding_auth("1") is None + + def test_authorize_connect_accepts_persisted_synced_user_hash_credentials(tmp_path) -> None: cloud_snapshot_path = tmp_path / "cloud_snapshot.json" _seed_cloud_snapshot(cloud_snapshot_path) diff --git a/tests/test_routine_runner.py b/tests/test_routine_runner.py index cbe2910..0add843 100644 --- a/tests/test_routine_runner.py +++ b/tests/test_routine_runner.py @@ -6,6 +6,7 @@ import pytest from roborock.data import StatusV2 +from roborock.exceptions import RoborockException from roborock_local_server.bundled_backend.shared.context import ServerContext import roborock_local_server.bundled_backend.shared.routine_runner as routine_runner_module from roborock_local_server.bundled_backend.shared.routine_runner import RoutineRunner, parse_scene_steps @@ -262,6 +263,118 @@ async def wait_for_step_complete(self) -> None: asyncio.run(exercise()) +def test_run_scene_retries_step_start_when_device_is_action_locked(tmp_path: Path, monkeypatch) -> None: + async def exercise() -> None: + device_id = "6HL2zfniaoYYV01CkVuhkO" + scene = { + "id": 4491073, + "name": "Night clean", + "device_id": device_id, + "param": json.dumps( + { + "action": { + "items": [ + { + "id": 1, + "type": "CMD", + "name": "Step 1", + "finishDpIds": [130], + "param": json.dumps( + { + "method": "do_scenes_segments", + "params": { + "data": [ + { + "tid": "1755507280460", + "segs": [{"sid": 18}], + "fan_power": 108, + "repeat": 1, + } + ] + }, + }, + separators=(",", ":"), + ), + }, + { + "id": 2, + "type": "CMD", + "name": "Step 2", + "finishDpIds": [130], + "param": json.dumps( + { + "method": "do_scenes_segments", + "params": { + "data": [ + { + "tid": "1755507296636", + "segs": [{"sid": 19}], + "fan_power": 103, + "repeat": 1, + } + ] + }, + }, + separators=(",", ":"), + ), + }, + ] + } + }, + separators=(",", ":"), + ), + } + + sent_commands: list[tuple[RoborockCommand, object]] = [] + settle_calls: list[float] = [] + segment_attempts = 0 + + class FakeRoutineClient: + def __init__(self, context, device, logger) -> None: + _ = context, device, logger + + async def connect(self) -> None: + return None + + async def close(self) -> None: + return None + + async def send_command(self, command, params=None): + nonlocal segment_attempts + sent_commands.append((command, params)) + if command == RoborockCommand.APP_SEGMENT_CLEAN: + segment_attempts += 1 + if segment_attempts == 2: + raise RoborockException({"code": -10003, "message": "action locked"}) + return ["ok"] + + async def wait_for_step_complete(self) -> None: + return None + + async def wait_for_dock_settle(self, *, initial_delay_seconds=15.0) -> None: + settle_calls.append(float(initial_delay_seconds)) + + monkeypatch.setattr(routine_runner_module, "_RoutineMqttClient", FakeRoutineClient) + + runner = RoutineRunner(_test_context(tmp_path)) + await runner._run_scene(scene=scene, steps=parse_scene_steps(scene)) + + assert sent_commands == [ + ( + RoborockCommand.REUNION_SCENES, + {"data": [{"tid": "1755507280460"}, {"tid": "1755507296636"}]}, + ), + (RoborockCommand.SET_CUSTOM_MODE, [108]), + (RoborockCommand.APP_SEGMENT_CLEAN, [{"segments": [18], "repeat": 1}]), + (RoborockCommand.SET_CUSTOM_MODE, [103]), + (RoborockCommand.APP_SEGMENT_CLEAN, [{"segments": [19], "repeat": 1}]), + (RoborockCommand.APP_SEGMENT_CLEAN, [{"segments": [19], "repeat": 1}]), + ] + assert settle_calls == [15.0, 5.0] + + asyncio.run(exercise()) + + # --------------------------------------------------------------------------- # wait_for_step_complete tests # --------------------------------------------------------------------------- @@ -290,10 +403,18 @@ async def send_command(self, command: RoborockCommand, params=None) -> None: _ScriptedStatusClient.wait_for_step_complete = ( routine_runner_module._RoutineMqttClient.wait_for_step_complete ) +_ScriptedStatusClient.wait_for_dock_settle = ( + routine_runner_module._RoutineMqttClient.wait_for_dock_settle +) + + +def _set_fast_wait_constants(monkeypatch, *, confirm_seconds: float = 0.0) -> None: + monkeypatch.setattr(routine_runner_module, "_STEP_COMPLETE_CONFIRM_SECONDS", confirm_seconds) def test_wait_for_step_complete_dock_activity_does_not_end_step(monkeypatch) -> None: """Dock activity (emptying bin) followed by ready must not declare step complete.""" + _set_fast_wait_constants(monkeypatch) monkeypatch.setattr(routine_runner_module, "_STEP_START_TIMEOUT_SECONDS", 0.1) monkeypatch.setattr(routine_runner_module, "_STEP_START_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_STATUS_POLL_INTERVAL_SECONDS", 0.0) @@ -317,6 +438,7 @@ async def exercise() -> None: def test_wait_for_step_complete_actual_cleaning_completes(monkeypatch) -> None: """Step completes when in_cleaning becomes non-zero then robot returns to ready.""" + _set_fast_wait_constants(monkeypatch) monkeypatch.setattr(routine_runner_module, "_STEP_START_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_STATUS_POLL_INTERVAL_SECONDS", 0.0) @@ -334,6 +456,7 @@ async def exercise() -> None: def test_wait_for_step_complete_dock_then_cleaning_completes(monkeypatch) -> None: """Dock activity followed by actual cleaning should complete after cleaning finishes.""" + _set_fast_wait_constants(monkeypatch) monkeypatch.setattr(routine_runner_module, "_STEP_START_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_STATUS_POLL_INTERVAL_SECONDS", 0.0) @@ -354,6 +477,7 @@ async def exercise() -> None: def test_wait_for_step_complete_start_timeout(monkeypatch) -> None: """Raises RoutineExecutionError when robot stays in ready state past start deadline.""" + _set_fast_wait_constants(monkeypatch) monkeypatch.setattr(routine_runner_module, "_STEP_START_TIMEOUT_SECONDS", 0.1) monkeypatch.setattr(routine_runner_module, "_STEP_START_POLL_INTERVAL_SECONDS", 0.0) @@ -374,6 +498,7 @@ async def exercise() -> None: def test_wait_for_step_complete_resume_after_mid_clean_charge(monkeypatch) -> None: """Robot returns to dock mid-clean with low battery, charges, gets resumed, completes.""" + _set_fast_wait_constants(monkeypatch) monkeypatch.setattr(routine_runner_module, "_STEP_START_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_STATUS_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_RESUME_BATTERY_THRESHOLD", 80) @@ -400,6 +525,7 @@ async def exercise() -> None: def test_wait_for_step_complete_resume_zoned_clean(monkeypatch) -> None: """Resume uses correct command for zone cleaning.""" + _set_fast_wait_constants(monkeypatch) monkeypatch.setattr(routine_runner_module, "_STEP_START_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_STATUS_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_RESUME_BATTERY_THRESHOLD", 80) @@ -420,6 +546,7 @@ async def exercise() -> None: def test_wait_for_step_complete_no_resume_when_battery_low(monkeypatch) -> None: """No resume sent while battery is below threshold.""" + _set_fast_wait_constants(monkeypatch) monkeypatch.setattr(routine_runner_module, "_STEP_START_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_STATUS_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_STEP_COMPLETE_TIMEOUT_SECONDS", 0.1) @@ -441,6 +568,7 @@ async def exercise() -> None: def test_wait_for_step_complete_resume_only_sent_once(monkeypatch) -> None: """Resume command is only sent once even if robot returns to dock again.""" + _set_fast_wait_constants(monkeypatch) monkeypatch.setattr(routine_runner_module, "_STEP_START_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_STATUS_POLL_INTERVAL_SECONDS", 0.0) monkeypatch.setattr(routine_runner_module, "_RESUME_BATTERY_THRESHOLD", 80) @@ -458,3 +586,41 @@ async def exercise() -> None: assert len(client.sent_commands) == 2 asyncio.run(exercise()) + + +def test_wait_for_step_complete_requires_stable_non_cleaning_state(monkeypatch) -> None: + """A brief non-cleaning status must not end the step if cleaning resumes.""" + monkeypatch.setattr(routine_runner_module, "_STEP_COMPLETE_CONFIRM_SECONDS", 0.02) + monkeypatch.setattr(routine_runner_module, "_STEP_START_POLL_INTERVAL_SECONDS", 0.0) + monkeypatch.setattr(routine_runner_module, "_STATUS_POLL_INTERVAL_SECONDS", 0.01) + + async def exercise() -> None: + client = _ScriptedStatusClient([ + {"state": 18, "in_cleaning": 3}, + {"state": 8, "in_cleaning": 0}, + {"state": 18, "in_cleaning": 3}, + {"state": 18, "in_cleaning": 3}, + {"state": 8, "in_cleaning": 0}, + {"state": 8, "in_cleaning": 0}, + {"state": 8, "in_cleaning": 0}, + ]) + await client.wait_for_step_complete() + + asyncio.run(exercise()) + + +def test_wait_for_dock_settle_timeout_raises(monkeypatch) -> None: + _set_fast_wait_constants(monkeypatch) + monkeypatch.setattr(routine_runner_module, "_POST_STEP_SETTLE_TIMEOUT_SECONDS", 0.05) + monkeypatch.setattr(routine_runner_module, "_STATUS_POLL_INTERVAL_SECONDS", 0.0) + + async def exercise() -> None: + client = _ScriptedStatusClient([ + {"state": 18, "in_cleaning": 3}, + {"state": 18, "in_cleaning": 3}, + {"state": 18, "in_cleaning": 3}, + ]) + with pytest.raises(routine_runner_module.RoutineExecutionError, match="dock activity"): + await client.wait_for_dock_settle(initial_delay_seconds=0.0) + + asyncio.run(exercise()) diff --git a/tests/test_runtime_state.py b/tests/test_runtime_state.py index 9640990..20afe86 100644 --- a/tests/test_runtime_state.py +++ b/tests/test_runtime_state.py @@ -132,3 +132,72 @@ def test_runtime_state_onboarding_session_reports_identity_conflict(tmp_path: Pa assert linked["duid"] == "cloud-q7-b" assert "already linked to DUID cloud-q7-b" in snapshot["identity_conflict"] assert snapshot["status"] == "conflict" + + +def test_runtime_state_onboarding_device_mqtt_candidate_requires_matching_ip_and_public_key(tmp_path: Path) -> None: + credentials_path = tmp_path / "runtime_credentials.json" + credentials_path.write_text( + json.dumps( + { + "schema_version": 2, + "devices": [ + { + "did": "", + "duid": "cloud-q7-a", + "name": "Q7 Upstairs", + "model": "roborock.vacuum.sc05", + "product_id": "product-q7-a", + "localkey": "local-key-a", + } + ], + } + ) + + "\n", + encoding="utf-8", + ) + key_state_path = tmp_path / "device_key_state.json" + key_state_path.write_text( + json.dumps( + { + "devices": { + "1103821560705": { + "modulus_hex": "ab", + } + } + } + ) + + "\n", + encoding="utf-8", + ) + credentials = RuntimeCredentialsStore(credentials_path) + state = RuntimeState(log_dir=tmp_path, key_state_file=key_state_path, runtime_credentials=credentials) + state.upsert_vacuum("cloud-q7-a", name="Q7 Upstairs", id_kind="duid") + state.start_onboarding_session(target_duid="cloud-q7-a", target_name="Q7 Upstairs") + + event_time = datetime.now(timezone.utc).isoformat() + state.record_http_event( + event_time=event_time, + route_name="region", + clean_path="/region", + raw_path="/region", + method="GET", + host="api-roborock.example.com", + remote="192.168.8.10:54321", + did="1103821560705", + ) + state.record_http_event( + event_time=event_time, + route_name="nc_prepare", + clean_path="/nc", + raw_path="/nc", + method="GET", + host="api-roborock.example.com", + remote="192.168.8.10:54321", + did="1103821560705", + ) + + candidate = state.onboarding_device_mqtt_candidate(client_ip="192.168.8.10") + assert candidate is not None + assert candidate["did"] == "1103821560705" + assert candidate["duid"] == "cloud-q7-a" + assert state.onboarding_device_mqtt_candidate(client_ip="192.168.8.11") is None From cc2fd621c7105fc59be3175bf6874ef1ae8a44c2 Mon Sep 17 00:00:00 2001 From: Luke Date: Sat, 2 May 2026 21:08:55 -0400 Subject: [PATCH 2/2] fix copilot comment --- .gitignore | 3 ++- .../bundled_backend/mqtt_tls_proxy_server/server.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 66be241..323723a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ data/ site/ secrets/ config.toml -mitm_logs \ No newline at end of file +mitm_logs +dist/ \ No newline at end of file diff --git a/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py b/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py index 531f8c5..c18540a 100644 --- a/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py +++ b/src/roborock_local_server/bundled_backend/mqtt_tls_proxy_server/server.py @@ -297,7 +297,8 @@ def _resolve_onboarding_device_mqtt_candidate( if device is None: return None existing_username = str(device.get("device_mqtt_usr") or "").strip() - if existing_username: + existing_password = str(device.get("device_mqtt_pass") or "").strip() + if existing_username or existing_password: return None return { "did": str(device.get("did") or "").strip(),