Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions custom_components/pyscript/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,29 +258,50 @@ async def _call(self, data: DispatchData) -> None:
handlers = self.get_decorators(CallHandlerDecorator)
result_handlers = self.get_decorators(CallResultHandlerDecorator)

for handler_dec in handlers:
if await handler_dec.handle_call(data) is False:
self.logger.debug("Calling canceled by %s", handler_dec)
# notify handlers with "None"
for result_handler_dec in result_handlers:
await result_handler_dec.handle_call_result(data, None)
return
# Fire an event indicating that pyscript is running
# Note: the event must have an entity_id for logbook to work correctly.
ev_name = self.name.replace(".", "_")
ev_entity_id = f"pyscript.{ev_name}"
async def notify_result(result: Any) -> None:
for result_handler_dec in result_handlers:
await result_handler_dec.handle_call_result(data, result)

event_data = {"name": ev_name, "entity_id": ev_entity_id, "func_args": data.func_args}
self.hass.bus.async_fire("pyscript_running", event_data, context=data.hass_context)
# Store HASS Context for this Task
Function.store_hass_context(data.hass_context)
async def notify_exception(exc: BaseException) -> None:
for result_handler_dec in result_handlers:
await result_handler_dec.handle_call_exception(data, exc)

# Track whether result handlers have already been told the outcome so the
# cancellation guard in ``finally`` never notifies them twice.
notified = False
try:
for handler_dec in handlers:
if await handler_dec.handle_call(data) is False:
self.logger.debug("Calling canceled by %s", handler_dec)
# notify handlers with "None"
notified = True
await notify_result(None)
return
# Fire an event indicating that pyscript is running
# Note: the event must have an entity_id for logbook to work correctly.
ev_name = self.name.replace(".", "_")
ev_entity_id = f"pyscript.{ev_name}"

event_data = {"name": ev_name, "entity_id": ev_entity_id, "func_args": data.func_args}
self.hass.bus.async_fire("pyscript_running", event_data, context=data.hass_context)
# Store HASS Context for this Task
Function.store_hass_context(data.hass_context)

result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args)
for result_handler_dec in result_handlers:
await result_handler_dec.handle_call_result(data, result)
notified = True
await notify_result(result)
except Exception as e:
notified = True
await notify_exception(e)
await self.handle_exception(e)
finally:
if not notified:
# The action task was cancelled (CancelledError is a BaseException, so
# it is not caught above) or torn down before producing a result. Release
# anything awaiting one -- e.g. a @webhook_handler's HTTP response future --
# so it fails fast instead of hanging; the original exception keeps
# propagating once this block returns.
await notify_exception(asyncio.CancelledError())

async def dispatch(self, data: DispatchData) -> None:
"""Handle a trigger dispatch: run guards, create a context, and invoke the function."""
Expand All @@ -290,6 +311,8 @@ async def dispatch(self, data: DispatchData) -> None:
for dec in decorators:
if await dec.handle_dispatch(data) is False:
self.logger.debug("Trigger not active due to %s", dec)
for result_handler_dec in self.get_decorators(CallResultHandlerDecorator):
await result_handler_dec.handle_call_result(data, None)
return

action_ast_ctx = AstEval(
Expand Down
10 changes: 10 additions & 0 deletions custom_components/pyscript/decorator_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ async def validate(self) -> None:
"state_trigger",
"time_trigger",
"webhook_trigger",
"webhook_handler",
}
raise ValueError(
f"{self.dm.func_name} defined in {self.dm.ast_ctx.get_global_ctx_name()}: "
Expand All @@ -281,3 +282,12 @@ class CallResultHandlerDecorator(Decorator, ABC):
@abstractmethod
async def handle_call_result(self, data: DispatchData, result: Any) -> None:
"""Handle an action call result."""

async def handle_call_exception(self, data: DispatchData, exc: Exception) -> None:
"""
Handle an unhandled exception raised by the action call.

Defaults to treating the exception as a ``None`` result. Subclasses may
override to react to the failure (e.g. produce an error response).
"""
await self.handle_call_result(data, None)
2 changes: 2 additions & 0 deletions custom_components/pyscript/decorators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .task import TaskUniqueDecorator
from .timing import TimeActiveDecorator, TimeTriggerDecorator
from .webhook import WebhookTriggerDecorator
from .webhook_handler import WebhookHandlerDecorator

DECORATORS = [
StateTriggerDecorator,
Expand All @@ -17,5 +18,6 @@
EventTriggerDecorator,
MQTTTriggerDecorator,
WebhookTriggerDecorator,
WebhookHandlerDecorator,
ServiceDecorator,
]
16 changes: 14 additions & 2 deletions custom_components/pyscript/decorators/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base mixins for pyscript decorators."""

from abc import ABC
import inspect
import logging
from typing import Any

Expand All @@ -19,10 +20,21 @@ class AutoKwargsDecorator(Decorator, ABC):
async def validate(self) -> None:
"""Run base validation and materialize annotated kwargs as attributes."""
await super().validate()
for k in self.__class__.kwargs_schema.schema:
# Collect annotations declared anywhere in the class hierarchy so kwargs
# handling keeps working when attributes are declared on a shared base
# class (a class's ``__annotations__`` only exposes its own annotations).
# ``Decorator`` is skipped because its ``args``/``kwargs`` annotations would
# otherwise clobber the validated values.
annotations = {
name
for klass in type(self).__mro__
if klass is not Decorator
for name in inspect.get_annotations(klass)
}
for k in type(self).kwargs_schema.schema:
if isinstance(k, vol.Marker):
k = k.schema
if k in self.__class__.__annotations__:
if k in annotations:
setattr(self, k, self.kwargs.get(k, None))


Expand Down
58 changes: 5 additions & 53 deletions custom_components/pyscript/decorators/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,64 +5,24 @@
import logging
from typing import ClassVar

from aiohttp import hdrs
import voluptuous as vol

from homeassistant.components import webhook
from homeassistant.components.webhook import SUPPORTED_METHODS
from homeassistant.helpers import config_validation as cv

from ..decorator_abc import DispatchData, TriggerDecorator
from .base import AutoKwargsDecorator, ExpressionDecorator
from ..decorator_abc import DispatchData
from .webhook_base import WebhookBaseDecorator

_LOGGER = logging.getLogger(__name__)


class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator):
class WebhookTriggerDecorator(WebhookBaseDecorator):
"""Implementation for @webhook_trigger."""

name = "webhook_trigger"
args_schema = vol.Schema(
vol.All(
[vol.Coerce(str)],
vol.Length(min=1, max=2, msg="needs at least one argument"),
)
)
kwargs_schema = vol.Schema(
{
vol.Optional("local_only", default=True): cv.boolean,
vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]),
}
)

webhook_id: str
local_only: bool
methods: set[str]

webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {}

async def validate(self):
"""Validate the webhook trigger configuration."""
await super().validate()
self.webhook_id = self.args[0]

if len(self.args) == 2:
self.create_expression(self.args[1])

@staticmethod
async def _handler(_hass, webhook_id, request):
func_args = {
"trigger_type": "webhook",
"webhook_id": webhook_id,
"request": request,
}

if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""):
func_args["payload"] = await request.json()
else:
# Could potentially return multiples of a key - only take the first
payload_multidict = await request.post()
func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()}
func_args = await WebhookTriggerDecorator.build_func_args(webhook_id, request)

for trigger in WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id, set()).copy():
trigger_args = func_args.copy()
Expand All @@ -75,15 +35,7 @@ async def _handler(_hass, webhook_id, request):
def _add_trigger(trigger: WebhookTriggerDecorator) -> None:
webhook_id = trigger.webhook_id
if webhook_id not in WebhookTriggerDecorator.webhook_id2triggers:
webhook.async_register(
trigger.dm.hass,
"pyscript", # DOMAIN
"pyscript", # NAME
webhook_id,
WebhookTriggerDecorator._handler,
local_only=trigger.local_only,
allowed_methods=trigger.methods,
)
WebhookTriggerDecorator.register_webhook(trigger, WebhookTriggerDecorator._handler)
WebhookTriggerDecorator.webhook_id2triggers[webhook_id] = set()

WebhookTriggerDecorator.webhook_id2triggers[webhook_id].add(trigger)
Expand Down
80 changes: 80 additions & 0 deletions custom_components/pyscript/decorators/webhook_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Shared base for webhook decorators."""

from __future__ import annotations

from collections.abc import Awaitable, Callable
from typing import Any

from aiohttp import hdrs, web
import voluptuous as vol

from homeassistant.components import webhook
from homeassistant.components.webhook import SUPPORTED_METHODS
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv

from ..decorator_abc import TriggerDecorator
from .base import AutoKwargsDecorator, ExpressionDecorator


class WebhookBaseDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator):
"""Shared argument handling and request parsing for webhook decorators."""

args_schema = vol.Schema(
vol.All(
[vol.Coerce(str)],
vol.Length(min=1, max=2, msg="needs at least one argument"),
)
)
kwargs_schema = vol.Schema(
{
vol.Optional("local_only", default=True): cv.boolean,
vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]),
}
)

webhook_id: str
local_only: bool
methods: list[str] | None

async def validate(self):
"""Validate the webhook configuration."""
await super().validate()
self.webhook_id = self.args[0]

if len(self.args) == 2:
self.create_expression(self.args[1])

@staticmethod
async def build_func_args(webhook_id: str, request: web.Request) -> dict[str, Any]:
"""Build the function arguments from an incoming webhook request."""
func_args = {
"trigger_type": "webhook",
"webhook_id": webhook_id,
"request": request,
}

if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""):
func_args["payload"] = await request.json()
else:
# Could potentially return multiples of a key - only take the first
payload_multidict = await request.post()
func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()}

return func_args

@staticmethod
def register_webhook(
decorator: WebhookBaseDecorator,
handler: Callable[[HomeAssistant, str, web.Request], Awaitable[web.Response | None]],
) -> None:
"""Register a webhook decorator's id with Home Assistant."""
webhook.async_register(
decorator.dm.hass,
"pyscript", # DOMAIN
"pyscript", # NAME
decorator.webhook_id,
handler,
local_only=decorator.local_only,
allowed_methods=decorator.methods,
)
Loading
Loading