From fd5dd25299191f5092deea808dd97af8af790dc9 Mon Sep 17 00:00:00 2001 From: Dmitrii Amelin Date: Wed, 20 May 2026 00:04:39 +0200 Subject: [PATCH] fix state trigger hold on unrelated state updates --- .../pyscript/decorators/state.py | 35 +++++++++++-------- tests/test_function.py | 23 ++++++++++++ 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/custom_components/pyscript/decorators/state.py b/custom_components/pyscript/decorators/state.py index 4c1b44a..889d296 100644 --- a/custom_components/pyscript/decorators/state.py +++ b/custom_components/pyscript/decorators/state.py @@ -149,11 +149,13 @@ def _diff(self, dt: float, now: float) -> str: return "None" return f"{(now - dt):g} ago" - async def _check_new_state(self, trig_ok: bool) -> None: + async def _check_new_state( + self, trig_ok: bool, new_vars: dict[str, Any], func_args: dict[str, Any] + ) -> None: now = asyncio.get_running_loop().time() if _LOGGER.isEnabledFor(logging.DEBUG): msg = f"check_new_state: {self}" - msg += f"\ntrig_ok: {trig_ok} now {now} func_args: {self.last_func_args} new_vars: {self.last_new_vars}" + msg += f"\ntrig_ok: {trig_ok} now {now} func_args: {func_args} new_vars: {new_vars}" if self.true_entered_at: msg += f"\ntrue_entered_at: {self.true_entered_at}({(now - self.true_entered_at):g} ago)\n" if self.false_entered_at: @@ -179,6 +181,8 @@ async def _check_new_state(self, trig_ok: bool) -> None: if state_hold_false_passed: if self.state_hold is None: + self.last_new_vars = new_vars + self.last_func_args = func_args state_hold_true_passed = True else: if self.true_entered_at: @@ -192,9 +196,10 @@ async def _check_new_state(self, trig_ok: bool) -> None: else: _LOGGER.debug("state_hold started, %s", self) self.true_entered_at = now + self.last_new_vars = new_vars + self.last_func_args = func_args if state_hold_true_passed: - self.true_entered_at = None await self.dispatch( DispatchData(self.last_func_args, trigger_context={"new_vars": self.last_new_vars}) ) @@ -231,14 +236,14 @@ async def _cycle(self) -> None: check_state_expr_on_start = self.state_check_now or self.state_hold_false is not None if check_state_expr_on_start: - self.last_new_vars = State.notify_var_get(self.state_trig_ident, {}) - trig_ok = await self._is_trig_ok() + new_vars = State.notify_var_get(self.state_trig_ident, {}) + trig_ok = await self._is_trig_ok(new_vars) if self.in_wait_until_function and trig_ok and self.state_check_now is True: self.state_hold_false = None if self.state_check_now and self.has_expression(): - await self._check_new_state(trig_ok) + await self._check_new_state(trig_ok, new_vars, self.last_func_args) else: if not trig_ok and self.state_hold_false is not None: self.false_entered_at = loop.time() @@ -273,22 +278,22 @@ async def _cycle(self) -> None: notify_type, notify_info = await asyncio.wait_for(self.notify_q.get(), effective_timeout) if notify_type != "state": raise RuntimeError(f"Invalid notify_type {notify_type}, {self}") - self.last_new_vars = notify_info[0] - self.last_func_args = notify_info[1] + new_vars = notify_info[0] + func_args = notify_info[1] - if ident_any_values_changed(self.last_func_args, self.state_trig_ident_any): + if ident_any_values_changed(func_args, self.state_trig_ident_any): trig_ok = True - elif ident_values_changed(self.last_func_args, self.state_trig_ident): - trig_ok = await self._is_trig_ok() + elif ident_values_changed(func_args, self.state_trig_ident): + trig_ok = await self._is_trig_ok(new_vars) else: - trig_ok = False - await self._check_new_state(trig_ok) + continue + await self._check_new_state(trig_ok, new_vars, func_args) except TimeoutError: await self._check_state_hold() - async def _is_trig_ok(self) -> bool: + async def _is_trig_ok(self, new_vars: dict[str, Any]) -> bool: if self.has_expression(): - return await self.check_expression_vars(self.last_new_vars) + return await self.check_expression_vars(new_vars) return True def _on_task_done(self, task: asyncio.Task) -> None: diff --git a/tests/test_function.py b/tests/test_function.py index f87fbe1..8c54428 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -440,6 +440,13 @@ def func10d(var_name=None, value=None, trigger_type=None, context=None, old_valu log.info(f"func10d var = {var_name}, value = {value}, kwargs = {kwargs}") pyscript.done = [seq_num, var_name, kwargs] +@state_trigger("pyscript.f11var1 == 'playing'", state_hold=1e-6) +def func11(var_name=None, value=None, old_value=None): + global seq_num + + seq_num += 1 + pyscript.done = [seq_num, var_name, old_value, value, value.position, pyscript.f11var1.position] + """, ) # initialize the trigger and active variables @@ -672,6 +679,22 @@ def func10d(var_name=None, value=None, trigger_type=None, context=None, old_valu hass.states.async_set("pyscript.f8bvar1", 30) hass.states.async_set("pyscript.f8bvar1", 31) + # + # check that state_hold isn't cancelled by unrelated attribute-only updates + # + seq_num += 1 + hass.states.async_set("pyscript.f11var1", "stop") + hass.states.async_set("pyscript.f11var1", "playing", {"position": 1}) + hass.states.async_set("pyscript.f11var1", "playing", {"position": 2}) + assert literal_eval(await wait_until_done(notify_q)) == [ + seq_num, + "pyscript.f11var1", + "stop", + "playing", + 1, + 2, + ] + # # check that state_var.old is None first time #