From 0fbd75d61246ee57ff797047c92488acf532a990 Mon Sep 17 00:00:00 2001 From: Jay Newstrom Date: Fri, 5 Jun 2026 07:48:19 -0600 Subject: [PATCH] Add new webhook handler annotation. --- custom_components/pyscript/decorator.py | 57 ++- custom_components/pyscript/decorator_abc.py | 10 + .../pyscript/decorators/__init__.py | 2 + custom_components/pyscript/decorators/base.py | 16 +- .../pyscript/decorators/webhook.py | 58 +-- .../pyscript/decorators/webhook_base.py | 80 ++++ .../pyscript/decorators/webhook_handler.py | 132 ++++++ custom_components/pyscript/eval.py | 10 +- .../pyscript/stubs/pyscript_builtins.py | 28 ++ docs/reference.rst | 36 ++ tests/test_decorator_errors.py | 2 +- tests/test_decorator_manager.py | 57 ++- tests/test_decorators.py | 383 +++++++++++++++++- 13 files changed, 794 insertions(+), 77 deletions(-) create mode 100644 custom_components/pyscript/decorators/webhook_base.py create mode 100644 custom_components/pyscript/decorators/webhook_handler.py diff --git a/custom_components/pyscript/decorator.py b/custom_components/pyscript/decorator.py index 33af0d3..1185798 100644 --- a/custom_components/pyscript/decorator.py +++ b/custom_components/pyscript/decorator.py @@ -258,29 +258,50 @@ async def _call(self, data: DispatchData) -> None: handlers = self.get_decorators(CallHandlerDecorator) result_handlers = self.get_decorators(CallResultHandlerDecorator) - for handler_dec in handlers: - if await handler_dec.handle_call(data) is False: - self.logger.debug("Calling canceled by %s", handler_dec) - # notify handlers with "None" - for result_handler_dec in result_handlers: - await result_handler_dec.handle_call_result(data, None) - return - # Fire an event indicating that pyscript is running - # Note: the event must have an entity_id for logbook to work correctly. - ev_name = self.name.replace(".", "_") - ev_entity_id = f"pyscript.{ev_name}" + async def notify_result(result: Any) -> None: + for result_handler_dec in result_handlers: + await result_handler_dec.handle_call_result(data, result) - event_data = {"name": ev_name, "entity_id": ev_entity_id, "func_args": data.func_args} - self.hass.bus.async_fire("pyscript_running", event_data, context=data.hass_context) - # Store HASS Context for this Task - Function.store_hass_context(data.hass_context) + async def notify_exception(exc: BaseException) -> None: + for result_handler_dec in result_handlers: + await result_handler_dec.handle_call_exception(data, exc) + # Track whether result handlers have already been told the outcome so the + # cancellation guard in ``finally`` never notifies them twice. + notified = False try: + for handler_dec in handlers: + if await handler_dec.handle_call(data) is False: + self.logger.debug("Calling canceled by %s", handler_dec) + # notify handlers with "None" + notified = True + await notify_result(None) + return + # Fire an event indicating that pyscript is running + # Note: the event must have an entity_id for logbook to work correctly. + ev_name = self.name.replace(".", "_") + ev_entity_id = f"pyscript.{ev_name}" + + event_data = {"name": ev_name, "entity_id": ev_entity_id, "func_args": data.func_args} + self.hass.bus.async_fire("pyscript_running", event_data, context=data.hass_context) + # Store HASS Context for this Task + Function.store_hass_context(data.hass_context) + result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args) - for result_handler_dec in result_handlers: - await result_handler_dec.handle_call_result(data, result) + notified = True + await notify_result(result) except Exception as e: + notified = True + await notify_exception(e) await self.handle_exception(e) + finally: + if not notified: + # The action task was cancelled (CancelledError is a BaseException, so + # it is not caught above) or torn down before producing a result. Release + # anything awaiting one -- e.g. a @webhook_handler's HTTP response future -- + # so it fails fast instead of hanging; the original exception keeps + # propagating once this block returns. + await notify_exception(asyncio.CancelledError()) async def dispatch(self, data: DispatchData) -> None: """Handle a trigger dispatch: run guards, create a context, and invoke the function.""" @@ -290,6 +311,8 @@ async def dispatch(self, data: DispatchData) -> None: for dec in decorators: if await dec.handle_dispatch(data) is False: self.logger.debug("Trigger not active due to %s", dec) + for result_handler_dec in self.get_decorators(CallResultHandlerDecorator): + await result_handler_dec.handle_call_result(data, None) return action_ast_ctx = AstEval( diff --git a/custom_components/pyscript/decorator_abc.py b/custom_components/pyscript/decorator_abc.py index 4775317..cf4643e 100644 --- a/custom_components/pyscript/decorator_abc.py +++ b/custom_components/pyscript/decorator_abc.py @@ -256,6 +256,7 @@ async def validate(self) -> None: "state_trigger", "time_trigger", "webhook_trigger", + "webhook_handler", } raise ValueError( f"{self.dm.func_name} defined in {self.dm.ast_ctx.get_global_ctx_name()}: " @@ -281,3 +282,12 @@ class CallResultHandlerDecorator(Decorator, ABC): @abstractmethod async def handle_call_result(self, data: DispatchData, result: Any) -> None: """Handle an action call result.""" + + async def handle_call_exception(self, data: DispatchData, exc: Exception) -> None: + """ + Handle an unhandled exception raised by the action call. + + Defaults to treating the exception as a ``None`` result. Subclasses may + override to react to the failure (e.g. produce an error response). + """ + await self.handle_call_result(data, None) diff --git a/custom_components/pyscript/decorators/__init__.py b/custom_components/pyscript/decorators/__init__.py index f21b9a9..6094ceb 100644 --- a/custom_components/pyscript/decorators/__init__.py +++ b/custom_components/pyscript/decorators/__init__.py @@ -7,6 +7,7 @@ from .task import TaskUniqueDecorator from .timing import TimeActiveDecorator, TimeTriggerDecorator from .webhook import WebhookTriggerDecorator +from .webhook_handler import WebhookHandlerDecorator DECORATORS = [ StateTriggerDecorator, @@ -17,5 +18,6 @@ EventTriggerDecorator, MQTTTriggerDecorator, WebhookTriggerDecorator, + WebhookHandlerDecorator, ServiceDecorator, ] diff --git a/custom_components/pyscript/decorators/base.py b/custom_components/pyscript/decorators/base.py index 213a67e..2ccb134 100644 --- a/custom_components/pyscript/decorators/base.py +++ b/custom_components/pyscript/decorators/base.py @@ -1,6 +1,7 @@ """Base mixins for pyscript decorators.""" from abc import ABC +import inspect import logging from typing import Any @@ -19,10 +20,21 @@ class AutoKwargsDecorator(Decorator, ABC): async def validate(self) -> None: """Run base validation and materialize annotated kwargs as attributes.""" await super().validate() - for k in self.__class__.kwargs_schema.schema: + # Collect annotations declared anywhere in the class hierarchy so kwargs + # handling keeps working when attributes are declared on a shared base + # class (a class's ``__annotations__`` only exposes its own annotations). + # ``Decorator`` is skipped because its ``args``/``kwargs`` annotations would + # otherwise clobber the validated values. + annotations = { + name + for klass in type(self).__mro__ + if klass is not Decorator + for name in inspect.get_annotations(klass) + } + for k in type(self).kwargs_schema.schema: if isinstance(k, vol.Marker): k = k.schema - if k in self.__class__.__annotations__: + if k in annotations: setattr(self, k, self.kwargs.get(k, None)) diff --git a/custom_components/pyscript/decorators/webhook.py b/custom_components/pyscript/decorators/webhook.py index 3db0a09..82814f3 100644 --- a/custom_components/pyscript/decorators/webhook.py +++ b/custom_components/pyscript/decorators/webhook.py @@ -5,64 +5,24 @@ import logging from typing import ClassVar -from aiohttp import hdrs -import voluptuous as vol - from homeassistant.components import webhook -from homeassistant.components.webhook import SUPPORTED_METHODS -from homeassistant.helpers import config_validation as cv -from ..decorator_abc import DispatchData, TriggerDecorator -from .base import AutoKwargsDecorator, ExpressionDecorator +from ..decorator_abc import DispatchData +from .webhook_base import WebhookBaseDecorator _LOGGER = logging.getLogger(__name__) -class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): +class WebhookTriggerDecorator(WebhookBaseDecorator): """Implementation for @webhook_trigger.""" name = "webhook_trigger" - args_schema = vol.Schema( - vol.All( - [vol.Coerce(str)], - vol.Length(min=1, max=2, msg="needs at least one argument"), - ) - ) - kwargs_schema = vol.Schema( - { - vol.Optional("local_only", default=True): cv.boolean, - vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]), - } - ) - - webhook_id: str - local_only: bool - methods: set[str] webhook_id2triggers: ClassVar[dict[str, set[WebhookTriggerDecorator]]] = {} - async def validate(self): - """Validate the webhook trigger configuration.""" - await super().validate() - self.webhook_id = self.args[0] - - if len(self.args) == 2: - self.create_expression(self.args[1]) - @staticmethod async def _handler(_hass, webhook_id, request): - func_args = { - "trigger_type": "webhook", - "webhook_id": webhook_id, - "request": request, - } - - if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""): - func_args["payload"] = await request.json() - else: - # Could potentially return multiples of a key - only take the first - payload_multidict = await request.post() - func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()} + func_args = await WebhookTriggerDecorator.build_func_args(webhook_id, request) for trigger in WebhookTriggerDecorator.webhook_id2triggers.get(webhook_id, set()).copy(): trigger_args = func_args.copy() @@ -75,15 +35,7 @@ async def _handler(_hass, webhook_id, request): def _add_trigger(trigger: WebhookTriggerDecorator) -> None: webhook_id = trigger.webhook_id if webhook_id not in WebhookTriggerDecorator.webhook_id2triggers: - webhook.async_register( - trigger.dm.hass, - "pyscript", # DOMAIN - "pyscript", # NAME - webhook_id, - WebhookTriggerDecorator._handler, - local_only=trigger.local_only, - allowed_methods=trigger.methods, - ) + WebhookTriggerDecorator.register_webhook(trigger, WebhookTriggerDecorator._handler) WebhookTriggerDecorator.webhook_id2triggers[webhook_id] = set() WebhookTriggerDecorator.webhook_id2triggers[webhook_id].add(trigger) diff --git a/custom_components/pyscript/decorators/webhook_base.py b/custom_components/pyscript/decorators/webhook_base.py new file mode 100644 index 0000000..38d2da9 --- /dev/null +++ b/custom_components/pyscript/decorators/webhook_base.py @@ -0,0 +1,80 @@ +"""Shared base for webhook decorators.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +from aiohttp import hdrs, web +import voluptuous as vol + +from homeassistant.components import webhook +from homeassistant.components.webhook import SUPPORTED_METHODS +from homeassistant.core import HomeAssistant +from homeassistant.helpers import config_validation as cv + +from ..decorator_abc import TriggerDecorator +from .base import AutoKwargsDecorator, ExpressionDecorator + + +class WebhookBaseDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): + """Shared argument handling and request parsing for webhook decorators.""" + + args_schema = vol.Schema( + vol.All( + [vol.Coerce(str)], + vol.Length(min=1, max=2, msg="needs at least one argument"), + ) + ) + kwargs_schema = vol.Schema( + { + vol.Optional("local_only", default=True): cv.boolean, + vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]), + } + ) + + webhook_id: str + local_only: bool + methods: list[str] | None + + async def validate(self): + """Validate the webhook configuration.""" + await super().validate() + self.webhook_id = self.args[0] + + if len(self.args) == 2: + self.create_expression(self.args[1]) + + @staticmethod + async def build_func_args(webhook_id: str, request: web.Request) -> dict[str, Any]: + """Build the function arguments from an incoming webhook request.""" + func_args = { + "trigger_type": "webhook", + "webhook_id": webhook_id, + "request": request, + } + + if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""): + func_args["payload"] = await request.json() + else: + # Could potentially return multiples of a key - only take the first + payload_multidict = await request.post() + func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()} + + return func_args + + @staticmethod + def register_webhook( + decorator: WebhookBaseDecorator, + handler: Callable[[HomeAssistant, str, web.Request], Awaitable[web.Response | None]], + ) -> None: + """Register a webhook decorator's id with Home Assistant.""" + webhook.async_register( + decorator.dm.hass, + "pyscript", # DOMAIN + "pyscript", # NAME + decorator.webhook_id, + handler, + local_only=decorator.local_only, + allowed_methods=decorator.methods, + ) diff --git a/custom_components/pyscript/decorators/webhook_handler.py b/custom_components/pyscript/decorators/webhook_handler.py new file mode 100644 index 0000000..5db4e74 --- /dev/null +++ b/custom_components/pyscript/decorators/webhook_handler.py @@ -0,0 +1,132 @@ +"""Webhook handler decorator.""" + +from __future__ import annotations + +import asyncio +from http import HTTPStatus +import logging +from typing import Any, ClassVar + +from aiohttp import web + +from homeassistant.components import webhook + +from ..decorator_abc import CallResultHandlerDecorator, DispatchData +from .webhook_base import WebhookBaseDecorator + +_LOGGER = logging.getLogger(__name__) + +# Key under which the per-request response future is stored on DispatchData.trigger_context. +# Per-request state must live on DispatchData (one per request), not on the shared decorator +# instance, so concurrent requests to the same webhook_id don't race over a single future. +_RESPONSE_FUTURE_KEY = "webhook_response_future" + + +class WebhookHandlerDecorator(WebhookBaseDecorator, CallResultHandlerDecorator): + """Implementation for @webhook_handler (one handler per id; return value drives the response).""" + + name = "webhook_handler" + + # Exactly one handler per webhook_id (unlike webhook_trigger which allows many). + webhook_id2handler: ClassVar[dict[str, WebhookHandlerDecorator]] = {} + + @staticmethod + async def _handler(hass, webhook_id, request): + handler = WebhookHandlerDecorator.webhook_id2handler.get(webhook_id) + if handler is None: + return None + + try: + func_args = await WebhookHandlerDecorator.build_func_args(webhook_id, request) + except ValueError: + # The body could not be parsed (e.g. malformed JSON). Unlike @webhook_trigger, + # which silently drops the event, a handler should tell the caller their request + # was bad rather than returning a 200. + _LOGGER.debug("webhook %s received an unparseable request body", webhook_id) + return web.Response(status=HTTPStatus.BAD_REQUEST) + + if handler.has_expression(): + if not await handler.check_expression_vars(func_args): + return None + + response_future: asyncio.Future[Any] = hass.loop.create_future() + data = DispatchData(func_args, trigger_context={_RESPONSE_FUTURE_KEY: response_future}) + await handler.dispatch(data) + + result = await response_future + return WebhookHandlerDecorator.coerce_response(result) + + async def handle_call_result(self, data: DispatchData, result: Any) -> None: + """Resolve the per-request response future with the function's return value.""" + if data.trigger is not self: + return + response_future = data.trigger_context.get(_RESPONSE_FUTURE_KEY) + if response_future is not None and not response_future.done(): + response_future.set_result(result) + + async def handle_call_exception(self, data: DispatchData, exc: Exception) -> None: + """ + Resolve the per-request response future with a 500 on an unhandled exception. + + The exception is also logged via the manager's handle_exception; here we only + ensure the awaiting request gets a 500 instead of falling back to a 200. We + resolve with a Response (not set_exception) so the error does not propagate out + of the aiohttp handler, where Home Assistant would turn it back into a 200. + """ + if data.trigger is not self: + return + response_future = data.trigger_context.get(_RESPONSE_FUTURE_KEY) + if response_future is not None and not response_future.done(): + response_future.set_result(web.Response(status=HTTPStatus.INTERNAL_SERVER_ERROR)) + + @staticmethod + def coerce_response(value: Any) -> web.Response | None: + """Convert a webhook handler return value to an aiohttp Response.""" + if value is None: + return None + if isinstance(value, web.Response): + return value + # bool is a subclass of int; reject it so True/False don't become 1/0 status codes. + if isinstance(value, int) and not isinstance(value, bool): + if 100 <= value <= 599: + return web.Response(status=value) + _LOGGER.warning( + "@webhook_handler function returned %s, which is not a valid HTTP status code (100-599)", + value, + ) + return None + _LOGGER.warning( + "@webhook_handler function returned unsupported type %s; " + "expected int status code or aiohttp.web.Response", + type(value).__name__, + ) + return None + + @staticmethod + def _add_handler(handler: WebhookHandlerDecorator) -> None: + # Home Assistant's webhook.async_register raises if the webhook_id is already + # registered (by another @webhook_handler, a @webhook_trigger, or any other + # integration), so duplicates are rejected here without an extra pyscript check. + webhook_id = handler.webhook_id + WebhookHandlerDecorator.register_webhook(handler, WebhookHandlerDecorator._handler) + WebhookHandlerDecorator.webhook_id2handler[webhook_id] = handler + + @staticmethod + def _remove_handler(handler: WebhookHandlerDecorator) -> None: + webhook_id = handler.webhook_id + if WebhookHandlerDecorator.webhook_id2handler.get(webhook_id) is not handler: + return + webhook.async_unregister(handler.dm.hass, webhook_id) + del WebhookHandlerDecorator.webhook_id2handler[webhook_id] + + async def start(self): + """Start the webhook handler.""" + await super().start() + self._add_handler(self) + + _LOGGER.debug("webhook handler %s listening on id %s", self.dm.name, self.webhook_id) + + async def stop(self): + """Stop the webhook handler.""" + await super().stop() + self._remove_handler(self) diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index 8a9df38..da50a81 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -59,6 +59,7 @@ "event_trigger", "mqtt_trigger", "webhook_trigger", + "webhook_handler", "state_active", "time_active", "task_unique", @@ -377,6 +378,7 @@ async def trigger_init(self, trig_ctx, func_name): "state_trigger", "time_trigger", "webhook_trigger", + "webhook_handler", } arg_check = { "event_trigger": {"arg_cnt": {1, 2, 3}, "rep_ok": True}, @@ -388,6 +390,7 @@ async def trigger_init(self, trig_ctx, func_name): "time_active": {"arg_cnt": {"*"}}, "time_trigger": {"arg_cnt": {0, "*"}, "rep_ok": True}, "webhook_trigger": {"arg_cnt": {1, 2}, "rep_ok": True}, + "webhook_handler": {"arg_cnt": {1, 2}, "rep_ok": True}, } kwarg_check = { "event_trigger": {"kwargs": {dict}}, @@ -411,6 +414,11 @@ async def trigger_init(self, trig_ctx, func_name): "local_only": {bool}, "methods": {list, set}, }, + "webhook_handler": { + "kwargs": {dict}, + "local_only": {bool}, + "methods": {list, set}, + }, } for dec in self.decorators: @@ -541,7 +549,7 @@ async def do_service_call(func, ast_ctx, data): self.trigger_service.add(srv_name) continue - if dec_name == "webhook_trigger" and "methods" in dec_kwargs: + if dec_name in ("webhook_trigger", "webhook_handler") and "methods" in dec_kwargs: if len(bad := set(dec_kwargs["methods"]).difference(WEBHOOK_METHODS)) > 0: raise TypeError(f"{exc_mesg}: {bad} aren't valid {dec_name} methods") diff --git a/custom_components/pyscript/stubs/pyscript_builtins.py b/custom_components/pyscript/stubs/pyscript_builtins.py index ea75580..501a8f4 100644 --- a/custom_components/pyscript/stubs/pyscript_builtins.py +++ b/custom_components/pyscript/stubs/pyscript_builtins.py @@ -143,6 +143,34 @@ def webhook_trigger( ... +def webhook_handler( + webhook_id: str, + str_expr: str | None = None, + local_only: bool = True, + methods: set[SUPPORTED_METHODS] | list[SUPPORTED_METHODS] = {"POST", "PUT"}, + kwargs: dict | None = None, +) -> Callable[..., Any]: + """Handle a request to a webhook endpoint and control its HTTP response. + + Like ``webhook_trigger``, but only one ``webhook_handler`` is allowed per ``webhook_id`` and + the decorated function's return value drives the HTTP response: ``None`` produces a ``200 OK``, + an ``int`` sends that status code, and an ``aiohttp.web.Response`` gives full control over the + body and headers. The handler waits for the function to return before responding, so use + ``task.create()`` for any long-running work. + + Args: + webhook_id: Webhook id to listen to. Must be unique; it cannot be shared with another + ``webhook_handler`` or ``webhook_trigger``. + str_expr: Optional expression evaluated against ``trigger_type``, ``webhook_id``, ``request``, and ``payload``. + local_only: If False, allow requests from anywhere on the internet. + methods: HTTP methods to allow. + kwargs: Extra keyword arguments merged into each invocation. + + Handler kwargs include ``trigger_type="webhook"``, ``webhook_id``, the parsed payload fields, and ``request`` (the underlying ``aiohttp.web.Request``). + """ + ... + + def pyscript_compile() -> Callable[..., Any]: """Compile the wrapped function into native (synchronous) Python. diff --git a/docs/reference.rst b/docs/reference.rst index 3b7c587..a508cc9 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -917,6 +917,42 @@ To validate an HMAC signature on incoming requests, declare ``request`` in the f NOTE: A webhook_id can only be used by either a built-in Home Assistant automation or pyscript, but not both. Trying to use the same webhook_id in both will result in an error. +@webhook_handler +^^^^^^^^^^^^^^^^ + +.. code:: python + + @webhook_handler(webhook_id, str_expr=None, local_only=True, methods={"POST", "PUT"}, kwargs=None) + +``@webhook_handler`` is like ``@webhook_trigger``, with two differences: only **one** handler is allowed per ``webhook_id``, and the decorated function's **return value controls the HTTP response** sent back to the caller. Use it when the caller needs a meaningful status code or response body; use ``@webhook_trigger`` for fire-and-forget notifications where the response is always ``200 OK``. + +The ``local_only``, ``methods``, ``str_expr`` and ``kwargs`` options, and the ``trigger_type``, ``webhook_id``, ``payload`` and ``request`` variables, behave exactly as for ``@webhook_trigger`` (see above). + +The return value is mapped to the response as follows: + +- ``None`` (or no ``return``) produces a ``200 OK``. +- an ``int`` in the range ``100``-``599`` sends back a response with that status code, e.g. ``return 404``. An out-of-range int is ignored (a warning is logged) and a ``200 OK`` is sent. +- an ``aiohttp.web.Response`` is returned as-is, giving full control over the status, body and headers. Constructing one requires ``from aiohttp import web``, which needs ``allow_all_imports: true`` in your pyscript configuration. +- any other type is ignored (a warning is logged) and a ``200 OK`` is sent. + +If the function raises an unhandled exception, the exception is logged and a ``500 Internal Server Error`` is sent. Catch the exception in your function if you want to return a different status. + +If the request body cannot be parsed (for example a body declared as ``application/json`` that is not valid JSON), the function is not called and a ``400 Bad Request`` is sent. + +For example: + +.. code:: python + + @webhook_handler("myid") + def webhook_check(payload): + if "token" not in payload: + return 401 + return 204 + +The handler waits for the decorated function to finish before sending the response, so use ``task.create()`` to fire-and-forget any long-running work and return promptly. + +NOTE: The ``webhook_id`` used by a ``@webhook_handler`` must be unique - it cannot be shared with another ``@webhook_handler`` or with a ``@webhook_trigger``. Attempting to reuse a ``webhook_id`` will result in an error when the script is loaded. + @state_active ^^^^^^^^^^^^^ diff --git a/tests/test_decorator_errors.py b/tests/test_decorator_errors.py index b03d8cc..6e2a0f1 100644 --- a/tests/test_decorator_errors.py +++ b/tests/test_decorator_errors.py @@ -208,7 +208,7 @@ def func4(): """, ) assert ( - "func4 defined in file.hello: needs at least one trigger decorator (ie: event_trigger, mqtt_trigger, state_trigger, time_trigger, webhook_trigger)" + "func4 defined in file.hello: needs at least one trigger decorator (ie: event_trigger, mqtt_trigger, state_trigger, time_trigger, webhook_handler, webhook_trigger)" in caplog.text ) diff --git a/tests/test_decorator_manager.py b/tests/test_decorator_manager.py index 45c6f89..1408938 100644 --- a/tests/test_decorator_manager.py +++ b/tests/test_decorator_manager.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging from typing import ClassVar from unittest.mock import patch @@ -271,7 +272,7 @@ def get_name(self) -> str: class DummyCallAstCtx: """Minimal action AstEval stub for manager call tests.""" - def __init__(self, result: object = None, exc: Exception | None = None) -> None: + def __init__(self, result: object = None, exc: BaseException | None = None) -> None: """Initialize the dummy action context.""" self.result = result self.exc = exc @@ -599,6 +600,60 @@ async def test_function_decorator_manager_logs_call_exception(hass): assert str(ast_ctx.logged_exceptions[0]) == "decorated call failed" +@pytest.mark.asyncio +async def test_function_decorator_manager_exception_calls_result_handlers(hass): + """When the decorated function raises, result handlers should be notified with None.""" + DecoratorManager.hass = hass + ast_ctx = DummyAstCtx() + manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar()) + result_handler = make_recording_result_handler() + manager.add(result_handler) + call_ast_ctx = DummyCallAstCtx(exc=RuntimeError("boom")) + + with patch.object(Function, "store_hass_context"): + await call_function_manager( + manager, + make_dispatch_data( + {"arg1": 1}, + call_ast_ctx=call_ast_ctx, + hass_context=Context(id="call-parent"), + ), + ) + + assert result_handler.results == [None] + assert len(ast_ctx.logged_exceptions) == 1 + + +@pytest.mark.asyncio +async def test_function_decorator_manager_cancellation_notifies_result_handlers(hass): + """A cancelled action must still release result handlers, then re-raise the cancellation.""" + DecoratorManager.hass = hass + ast_ctx = DummyAstCtx() + manager = FunctionDecoratorManager(ast_ctx, DummyEvalFuncVar()) + result_handler = make_recording_result_handler() + manager.add(result_handler) + call_ast_ctx = DummyCallAstCtx(exc=asyncio.CancelledError()) + + with ( + patch.object(Function, "store_hass_context"), + pytest.raises(asyncio.CancelledError), + ): + await call_function_manager( + manager, + make_dispatch_data( + {"arg1": 1}, + call_ast_ctx=call_ast_ctx, + hass_context=Context(id="call-parent"), + ), + ) + + # The cancelled call still notifies result handlers (so e.g. a @webhook_handler + # response future is resolved rather than orphaned)... + assert result_handler.results == [None] + # ...but a cancellation is not logged as an error the way a normal exception is. + assert not ast_ctx.logged_exceptions + + def test_decorator_registry_register_requires_name(): """Registry should reject decorators without a declared name.""" diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 12224d4..ff95da2 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -3,12 +3,20 @@ from ast import literal_eval import asyncio from datetime import datetime as dt -from unittest.mock import mock_open, patch +from http import HTTPStatus +from unittest.mock import Mock, mock_open, patch +from aiohttp import web import pytest from custom_components.pyscript import trigger from custom_components.pyscript.const import DOMAIN +from custom_components.pyscript.decorator_abc import DispatchData +from custom_components.pyscript.decorators.webhook import WebhookTriggerDecorator +from custom_components.pyscript.decorators.webhook_handler import ( + _RESPONSE_FUTURE_KEY, + WebhookHandlerDecorator, +) from custom_components.pyscript.function import Function from homeassistant.components import webhook from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_STATE_CHANGED @@ -72,7 +80,7 @@ async def wait_until_done(notify_q): @pytest.mark.asyncio -async def test_decorator_errors(hass, caplog): +async def test_decorator_errors(hass): """Test decorator syntax and run-time errors.""" notify_q = asyncio.Queue(0) await setup_script( @@ -256,3 +264,374 @@ def webhook_test(payload, request): await webhook.async_handle_webhook(hass, "test_req_hook", request) assert literal_eval(await wait_until_done(notify_q)) == ["abc123", "POST", {"hello": "world"}] + + +def _post_webhook_request(content: bytes = b"") -> MockRequest: + """Build a MockRequest representing a webhook POST with form data.""" + return MockRequest( + content=content, + headers={}, + method="POST", + query_string="", + mock_source="test", + remote="127.0.0.1", + ) + + +def test_webhook_handler_coerce_response_none(): + """A None return should fall through to the HA default response.""" + assert WebhookHandlerDecorator.coerce_response(None) is None + + +def test_webhook_handler_coerce_response_int(): + """Int returns should produce an aiohttp Response with that status.""" + response = WebhookHandlerDecorator.coerce_response(HTTPStatus.CREATED.value) + assert isinstance(response, web.Response) + assert response.status == HTTPStatus.CREATED + + +def test_webhook_handler_coerce_response_passthrough(): + """An aiohttp Response should be returned unchanged.""" + custom = web.Response(status=HTTPStatus.ACCEPTED, body=b"queued") + assert WebhookHandlerDecorator.coerce_response(custom) is custom + + +def test_webhook_handler_coerce_response_bool_warns(caplog): + """Bool returns should be rejected so True/False don't masquerade as 1/0.""" + assert WebhookHandlerDecorator.coerce_response(True) is None + assert "unsupported type bool" in caplog.text + + +def test_webhook_handler_coerce_response_unsupported_warns(caplog): + """Other return types should warn and fall through.""" + assert WebhookHandlerDecorator.coerce_response("ok") is None + assert "unsupported type str" in caplog.text + + +@pytest.mark.asyncio +async def test_webhook_handler_returns_status_code(hass): + """A webhook handler returning an int should set the HTTP status.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_handler("status_hook") +def func_status(payload): + return 201 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "status_hook", _post_webhook_request()) + await hass.async_block_till_done() + assert response.status == HTTPStatus.CREATED + + +@pytest.mark.asyncio +async def test_webhook_handler_default_response(hass): + """A webhook handler returning None should produce a 200 OK.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_handler("default_hook") +def func_default(payload): + pass +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "default_hook", _post_webhook_request()) + await hass.async_block_till_done() + assert response.status == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_webhook_handler_exception_returns_500(hass, caplog): + """A handler that raises should not hang and should return a 500.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_handler("boom_hook") +def func_boom(payload): + raise ValueError("boom") +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await asyncio.wait_for( + webhook.async_handle_webhook(hass, "boom_hook", _post_webhook_request()), timeout=4 + ) + await hass.async_block_till_done() + assert response.status == HTTPStatus.INTERNAL_SERVER_ERROR + # The underlying exception is still surfaced to the user's log. + assert "boom" in caplog.text + + +@pytest.mark.asyncio +async def test_webhook_handler_exception_for_other_trigger_ignored(): + """An exception whose trigger is a different decorator must not resolve this future.""" + handler = Mock(spec=WebhookHandlerDecorator) + future = asyncio.get_running_loop().create_future() + data = DispatchData({}, trigger_context={_RESPONSE_FUTURE_KEY: future}) + data.trigger = object() # a different trigger instance, not `handler` + await WebhookHandlerDecorator.handle_call_exception(handler, data, ValueError("boom")) + assert not future.done() + + +@pytest.mark.asyncio +async def test_webhook_handler_str_expr_no_match(hass): + """When str_expr does not match, the function is not called and a 200 OK is sent.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_handler("expr_hook", "payload['ok'] == 'yes'") +def func_expr(payload): + return 418 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "expr_hook", _post_webhook_request(b"ok=no")) + await hass.async_block_till_done() + assert response.status == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_webhook_handler_str_expr_match(hass): + """When str_expr matches, the function is called and its return value drives the response.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_handler("expr_match_hook", "payload['ok'] == 'yes'") +def func_expr(payload): + return 418 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + response = await webhook.async_handle_webhook(hass, "expr_match_hook", _post_webhook_request(b"ok=yes")) + await hass.async_block_till_done() + assert response.status == 418 # the function returned 418 (HTTP "I'm a teapot") + + +@pytest.mark.asyncio +async def test_webhook_handler_json_payload(hass): + """A JSON request body is parsed into the payload passed to the handler.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_handler("json_hook") +def func_json(payload): + return payload["status"] +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + request = MockRequest( + content=b'{"status": 503}', + mock_source="test", + method="POST", + headers={"Content-Type": "application/json"}, + remote="127.0.0.1", + ) + response = await webhook.async_handle_webhook(hass, "json_hook", request) + await hass.async_block_till_done() + assert response.status == HTTPStatus.SERVICE_UNAVAILABLE + + +@pytest.mark.asyncio +async def test_webhook_handler_concurrent_requests(hass): + """Concurrent requests to the same id must each get their own response (no shared future).""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_handler("concurrent_hook") +def func_concurrent(payload): + task.sleep(float(payload["delay"])) + return int(payload["code"]) +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + # Request A starts first but sleeps longer, so it finishes after request B. Under the old + # shared-instance `self.future` design A would hang or receive B's response. + response_a, response_b = await asyncio.gather( + webhook.async_handle_webhook(hass, "concurrent_hook", _post_webhook_request(b"delay=0.3&code=201")), + webhook.async_handle_webhook(hass, "concurrent_hook", _post_webhook_request(b"delay=0.05&code=202")), + ) + await hass.async_block_till_done() + assert response_a.status == HTTPStatus.CREATED + assert response_b.status == HTTPStatus.ACCEPTED + + +@pytest.mark.asyncio +async def test_webhook_handler_cancelled_action_still_responds(hass): + """If @task_unique cancels an in-flight action, that request must get a 500, not hang.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_handler("cancel_hook") +@task_unique("cancel_unique") +def func_cancel(payload): + task.sleep(float(payload["delay"])) + return int(payload["code"]) +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + + async def fire(delay, code): + return await asyncio.wait_for( + webhook.async_handle_webhook( + hass, "cancel_hook", _post_webhook_request(f"delay={delay}&code={code}".encode()) + ), + timeout=4, + ) + + # Both requests register the same @task_unique name while sleeping, so whichever runs + # second cancels the other mid-flight. Before the fix the cancelled request's response + # future was orphaned and the request hung; now it resolves to a 500. + results = await asyncio.gather(fire(0.3, 201), fire(0.05, 202), return_exceptions=True) + await hass.async_block_till_done() + + assert not any(isinstance(r, Exception) for r in results), f"a request hung/failed: {results}" + statuses = sorted(r.status for r in results) + assert statuses[0] in (HTTPStatus.CREATED, HTTPStatus.ACCEPTED), statuses # survivor's own code + assert statuses[1] == HTTPStatus.INTERNAL_SERVER_ERROR, statuses # cancelled request + + +@pytest.mark.asyncio +async def test_webhook_handler_malformed_json_returns_400(hass): + """A body that claims to be JSON but isn't should yield a 400, not a masked 200.""" + await setup_script( + hass, + None, + dt(2020, 7, 1, 11, 59, 59, 999999), + """ +@webhook_handler("bad_json_hook") +def func_bad_json(payload): + return 200 +""", + ) + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + request = MockRequest( + content=b"{not valid json", + mock_source="test", + method="POST", + headers={"Content-Type": "application/json"}, + remote="127.0.0.1", + ) + response = await webhook.async_handle_webhook(hass, "bad_json_hook", request) + await hass.async_block_till_done() + assert response.status == HTTPStatus.BAD_REQUEST + + +def test_webhook_handler_coerce_response_out_of_range_warns(caplog): + """An int outside the valid HTTP status range should warn and fall through to a 200.""" + assert WebhookHandlerDecorator.coerce_response(1000) is None + assert WebhookHandlerDecorator.coerce_response(0) is None + assert "not a valid HTTP status code" in caplog.text + + +def test_webhook_handler_duplicate_id_fails(): + """A second @webhook_handler for the same webhook_id is rejected by HA's webhook registry.""" + hass = Mock(data={}) + first = Mock(webhook_id="dup_hook", local_only=True, methods=None, dm=Mock(hass=hass)) + second = Mock(webhook_id="dup_hook", local_only=True, methods=None, dm=Mock(hass=hass)) + + WebhookHandlerDecorator.webhook_id2handler.pop("dup_hook", None) + try: + WebhookHandlerDecorator._add_handler(first) # pylint: disable=protected-access + with pytest.raises(ValueError, match="Handler is already defined"): + WebhookHandlerDecorator._add_handler(second) # pylint: disable=protected-access + finally: + WebhookHandlerDecorator.webhook_id2handler.pop("dup_hook", None) + + +def test_webhook_handler_collides_with_webhook_trigger(): + """A @webhook_handler whose id is already used by a @webhook_trigger is rejected by HA.""" + hass = Mock(data={}) + trigger_mock = Mock(webhook_id="shared_hook", local_only=True, methods=None, dm=Mock(hass=hass)) + handler = Mock(webhook_id="shared_hook", local_only=True, methods=None, dm=Mock(hass=hass)) + + WebhookTriggerDecorator.webhook_id2triggers.pop("shared_hook", None) + WebhookHandlerDecorator.webhook_id2handler.pop("shared_hook", None) + try: + WebhookTriggerDecorator._add_trigger(trigger_mock) # pylint: disable=protected-access + with pytest.raises(ValueError, match="Handler is already defined"): + WebhookHandlerDecorator._add_handler(handler) # pylint: disable=protected-access + finally: + WebhookTriggerDecorator.webhook_id2triggers.pop("shared_hook", None) + WebhookHandlerDecorator.webhook_id2handler.pop("shared_hook", None) + + +def test_webhook_handler_remove_unregisters_and_frees_id(): + """Removing a handler unregisters the webhook and frees the id for reuse.""" + handler = Mock(webhook_id="remove_hook", local_only=True, methods=None) + + WebhookHandlerDecorator.webhook_id2handler.pop("remove_hook", None) + with ( + patch("custom_components.pyscript.decorators.webhook_handler.webhook.async_register") as register, + patch( + "custom_components.pyscript.decorators.webhook_handler.webhook.async_unregister" + ) as unregister, + ): + try: + WebhookHandlerDecorator._add_handler(handler) # pylint: disable=protected-access + assert register.called + assert WebhookHandlerDecorator.webhook_id2handler["remove_hook"] is handler + + WebhookHandlerDecorator._remove_handler(handler) # pylint: disable=protected-access + assert unregister.called + assert "remove_hook" not in WebhookHandlerDecorator.webhook_id2handler + finally: + WebhookHandlerDecorator.webhook_id2handler.pop("remove_hook", None) + + +@pytest.mark.asyncio +async def test_webhook_handler_unknown_id_returns_none(): + """The shared handler returns None for a webhook_id with no registered handler.""" + WebhookHandlerDecorator.webhook_id2handler.pop("ghost_hook", None) + result = await WebhookHandlerDecorator._handler( # pylint: disable=protected-access + Mock(), "ghost_hook", Mock() + ) + assert result is None + + +def test_webhook_handler_remove_is_noop_for_unregistered(): + """Removing a handler that isn't the registered one must not unregister anything.""" + handler = Mock(webhook_id="absent_hook") + WebhookHandlerDecorator.webhook_id2handler.pop("absent_hook", None) + with patch( + "custom_components.pyscript.decorators.webhook_handler.webhook.async_unregister" + ) as unregister: + WebhookHandlerDecorator._remove_handler(handler) # pylint: disable=protected-access + assert not unregister.called + + +@pytest.mark.asyncio +async def test_webhook_handler_result_for_other_trigger_ignored(): + """A result whose trigger is a different decorator must not resolve this handler's future.""" + handler = Mock(spec=WebhookHandlerDecorator) + future = asyncio.get_running_loop().create_future() + data = DispatchData({}, trigger_context={_RESPONSE_FUTURE_KEY: future}) + data.trigger = object() # a different trigger instance, not `handler` + await WebhookHandlerDecorator.handle_call_result(handler, data, 201) + assert not future.done()