diff --git a/aws_lambda_powertools/event_handler/http_resolver.py b/aws_lambda_powertools/event_handler/http_resolver.py index 93e2fdc932e..168a6f44b8e 100644 --- a/aws_lambda_powertools/event_handler/http_resolver.py +++ b/aws_lambda_powertools/event_handler/http_resolver.py @@ -1,9 +1,7 @@ from __future__ import annotations -import asyncio import base64 import inspect -import threading import warnings from typing import TYPE_CHECKING, Any, Callable from urllib.parse import parse_qs @@ -15,6 +13,7 @@ Response, Route, ) +from aws_lambda_powertools.event_handler.middlewares.async_utils import wrap_middleware_async from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent @@ -320,73 +319,10 @@ async def final_handler(app): next_handler = final_handler for middleware in reversed(all_middlewares): - next_handler = self._wrap_middleware_async(middleware, next_handler) + next_handler = wrap_middleware_async(middleware, next_handler) return await next_handler(self) - def _wrap_middleware_async(self, middleware: Callable, next_handler: Callable) -> Callable: - """Wrap a middleware to work in async context. - - For sync middlewares, we split execution into pre/post phases around the - call to next(). The sync middleware runs its pre-processing (e.g. request - validation), then we intercept the next() call, await the async handler, - and resume the middleware with the real response so post-processing - (e.g. response validation) sees the actual data. - """ - - async def wrapped(app): - if inspect.iscoroutinefunction(middleware): - return await middleware(app, next_handler) - - # We use an Event to coordinate: the sync middleware runs in a thread, - # calls sync_next which signals us to resolve the async handler, - # then waits for the real response. - middleware_called_next = asyncio.Event() - next_app_holder: list = [] - real_response_holder: list = [] - middleware_result_holder: list = [] - middleware_error_holder: list = [] - - def sync_next(app): - next_app_holder.append(app) - middleware_called_next.set() - # Block this thread until the real response is available - event = threading.Event() - next_app_holder.append(event) - event.wait() - return real_response_holder[0] - - def run_middleware(): - try: - result = middleware(app, sync_next) - middleware_result_holder.append(result) - except Exception as e: - middleware_error_holder.append(e) - - thread = threading.Thread(target=run_middleware, daemon=True) - thread.start() - - # Wait for the middleware to call next() - await middleware_called_next.wait() - - # Now resolve the async next_handler - real_response = await next_handler(next_app_holder[0]) - real_response_holder.append(real_response) - - # Signal the thread that the response is ready - threading_event = next_app_holder[1] - threading_event.set() - - # Wait for the middleware thread to finish - thread.join() - - if middleware_error_holder: - raise middleware_error_holder[0] - - return middleware_result_holder[0] - - return wrapped - async def _handle_not_found_async(self) -> dict: """Handle 404 responses, using custom not_found handler if registered.""" from http import HTTPStatus diff --git a/aws_lambda_powertools/event_handler/middlewares/async_utils.py b/aws_lambda_powertools/event_handler/middlewares/async_utils.py new file mode 100644 index 00000000000..b04db33f1e8 --- /dev/null +++ b/aws_lambda_powertools/event_handler/middlewares/async_utils.py @@ -0,0 +1,107 @@ +"""Async middleware utilities for bridging sync and async middleware execution.""" + +from __future__ import annotations + +import asyncio +import inspect +import threading +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable + + from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, Response + + +def wrap_middleware_async(middleware: Callable, next_handler: Callable) -> Callable: + """Wrap a middleware to work in an async context. + + For async middlewares, delegates directly with ``await``. + + For sync middlewares, runs the middleware in a background thread and uses + ``asyncio.Event`` / ``threading.Event`` to coordinate the ``next()`` call + so the async handler can be awaited on the main event-loop while the sync + middleware blocks its own thread waiting for the result. + + Parameters + ---------- + middleware : Callable + A sync or async middleware ``(app, next_middleware) -> Response``. + next_handler : Callable + The next (async) handler in the chain. + + Returns + ------- + Callable + An async callable ``(app) -> Response`` that executes *middleware* + followed by *next_handler*. + """ + + async def wrapped(app: ApiGatewayResolver) -> Response: + if inspect.iscoroutinefunction(middleware): + return await middleware(app, next_handler) + + return await _run_sync_middleware_in_thread(middleware, next_handler, app) + + return wrapped + + +async def _run_sync_middleware_in_thread( + middleware: Callable, + next_handler: Callable, + app: Any, +) -> Any: + """Execute a **sync** middleware inside a daemon thread. + + The sync middleware calls ``sync_next(app)`` which: + + 1. Signals the async side that the middleware is ready for the next handler. + 2. Blocks the thread until the async handler has produced a response. + 3. Returns the response so the middleware can do post-processing. + + Meanwhile the async side awaits *next_handler*, feeds the response back, + and waits for the thread to finish. + """ + middleware_called_next = asyncio.Event() + next_app_holder: list = [] + real_response_holder: list = [] + middleware_result_holder: list = [] + middleware_error_holder: list = [] + + def sync_next(app: Any) -> Any: + next_app_holder.append(app) + middleware_called_next.set() + # Block this thread until the async handler resolves + event = threading.Event() + next_app_holder.append(event) + event.wait() + return real_response_holder[0] + + def run_middleware() -> None: + try: + result = middleware(app, sync_next) + middleware_result_holder.append(result) + except Exception as e: + middleware_error_holder.append(e) + + thread = threading.Thread(target=run_middleware, daemon=True) + thread.start() + + # Wait for the middleware to call next() + await middleware_called_next.wait() + + # Resolve the async next_handler on the event-loop + real_response = await next_handler(next_app_holder[0]) + real_response_holder.append(real_response) + + # Unblock the middleware thread + threading_event = next_app_holder[1] + threading_event.set() + + # Wait for the middleware thread to complete post-processing + thread.join() + + if middleware_error_holder: + raise middleware_error_holder[0] + + return middleware_result_holder[0]