Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions custom_components/pyscript/decorators/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,13 @@ def _diff(self, dt: float, now: float) -> str:
return "None"
return f"{(now - dt):g} ago"

async def _check_new_state(self, trig_ok: bool) -> None:
async def _check_new_state(
self, trig_ok: bool, new_vars: dict[str, Any], func_args: dict[str, Any]
) -> None:
now = asyncio.get_running_loop().time()
if _LOGGER.isEnabledFor(logging.DEBUG):
msg = f"check_new_state: {self}"
msg += f"\ntrig_ok: {trig_ok} now {now} func_args: {self.last_func_args} new_vars: {self.last_new_vars}"
msg += f"\ntrig_ok: {trig_ok} now {now} func_args: {func_args} new_vars: {new_vars}"
if self.true_entered_at:
msg += f"\ntrue_entered_at: {self.true_entered_at}({(now - self.true_entered_at):g} ago)\n"
if self.false_entered_at:
Expand All @@ -179,6 +181,8 @@ async def _check_new_state(self, trig_ok: bool) -> None:

if state_hold_false_passed:
if self.state_hold is None:
self.last_new_vars = new_vars
self.last_func_args = func_args
state_hold_true_passed = True
else:
if self.true_entered_at:
Expand All @@ -192,9 +196,10 @@ async def _check_new_state(self, trig_ok: bool) -> None:
else:
_LOGGER.debug("state_hold started, %s", self)
self.true_entered_at = now
self.last_new_vars = new_vars
self.last_func_args = func_args

if state_hold_true_passed:
self.true_entered_at = None
await self.dispatch(
DispatchData(self.last_func_args, trigger_context={"new_vars": self.last_new_vars})
)
Expand Down Expand Up @@ -231,14 +236,14 @@ async def _cycle(self) -> None:
check_state_expr_on_start = self.state_check_now or self.state_hold_false is not None

if check_state_expr_on_start:
self.last_new_vars = State.notify_var_get(self.state_trig_ident, {})
trig_ok = await self._is_trig_ok()
new_vars = State.notify_var_get(self.state_trig_ident, {})
trig_ok = await self._is_trig_ok(new_vars)

if self.in_wait_until_function and trig_ok and self.state_check_now is True:
self.state_hold_false = None

if self.state_check_now and self.has_expression():
await self._check_new_state(trig_ok)
await self._check_new_state(trig_ok, new_vars, self.last_func_args)
else:
if not trig_ok and self.state_hold_false is not None:
self.false_entered_at = loop.time()
Expand Down Expand Up @@ -273,22 +278,22 @@ async def _cycle(self) -> None:
notify_type, notify_info = await asyncio.wait_for(self.notify_q.get(), effective_timeout)
if notify_type != "state":
raise RuntimeError(f"Invalid notify_type {notify_type}, {self}")
self.last_new_vars = notify_info[0]
self.last_func_args = notify_info[1]
new_vars = notify_info[0]
func_args = notify_info[1]

if ident_any_values_changed(self.last_func_args, self.state_trig_ident_any):
if ident_any_values_changed(func_args, self.state_trig_ident_any):
trig_ok = True
elif ident_values_changed(self.last_func_args, self.state_trig_ident):
trig_ok = await self._is_trig_ok()
elif ident_values_changed(func_args, self.state_trig_ident):
trig_ok = await self._is_trig_ok(new_vars)
else:
trig_ok = False
await self._check_new_state(trig_ok)
continue
await self._check_new_state(trig_ok, new_vars, func_args)
except TimeoutError:
await self._check_state_hold()

async def _is_trig_ok(self) -> bool:
async def _is_trig_ok(self, new_vars: dict[str, Any]) -> bool:
if self.has_expression():
return await self.check_expression_vars(self.last_new_vars)
return await self.check_expression_vars(new_vars)
return True

def _on_task_done(self, task: asyncio.Task) -> None:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,13 @@ def func10d(var_name=None, value=None, trigger_type=None, context=None, old_valu
log.info(f"func10d var = {var_name}, value = {value}, kwargs = {kwargs}")
pyscript.done = [seq_num, var_name, kwargs]

@state_trigger("pyscript.f11var1 == 'playing'", state_hold=1e-6)
def func11(var_name=None, value=None, old_value=None):
global seq_num

seq_num += 1
pyscript.done = [seq_num, var_name, old_value, value, value.position, pyscript.f11var1.position]

""",
)
# initialize the trigger and active variables
Expand Down Expand Up @@ -672,6 +679,22 @@ def func10d(var_name=None, value=None, trigger_type=None, context=None, old_valu
hass.states.async_set("pyscript.f8bvar1", 30)
hass.states.async_set("pyscript.f8bvar1", 31)

#
# check that state_hold isn't cancelled by unrelated attribute-only updates
#
seq_num += 1
hass.states.async_set("pyscript.f11var1", "stop")
hass.states.async_set("pyscript.f11var1", "playing", {"position": 1})
hass.states.async_set("pyscript.f11var1", "playing", {"position": 2})
assert literal_eval(await wait_until_done(notify_q)) == [
seq_num,
"pyscript.f11var1",
"stop",
"playing",
1,
2,
]

#
# check that state_var.old is None first time
#
Expand Down
Loading