Skip to content

Commit 74fbcc5

Browse files
fix(event_handler): sync middleware receives real response in async ASGI context (#8089)
fix: resolve ASGI
1 parent 2b95d7f commit 74fbcc5

File tree

2 files changed

+140
-26
lines changed

2 files changed

+140
-26
lines changed

aws_lambda_powertools/event_handler/http_resolver.py

Lines changed: 55 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import base64
55
import inspect
6+
import threading
67
import warnings
78
from typing import TYPE_CHECKING, Any, Callable
89
from urllib.parse import parse_qs
@@ -324,36 +325,65 @@ async def final_handler(app):
324325
return await next_handler(self)
325326

326327
def _wrap_middleware_async(self, middleware: Callable, next_handler: Callable) -> Callable:
327-
"""Wrap a middleware to work in async context."""
328+
"""Wrap a middleware to work in async context.
329+
330+
For sync middlewares, we split execution into pre/post phases around the
331+
call to next(). The sync middleware runs its pre-processing (e.g. request
332+
validation), then we intercept the next() call, await the async handler,
333+
and resume the middleware with the real response so post-processing
334+
(e.g. response validation) sees the actual data.
335+
"""
328336

329337
async def wrapped(app):
330-
# Create a next_middleware that the sync middleware can call
331-
def sync_next(app):
332-
# This will be called by sync middleware
333-
# We need to run the async next_handler
334-
loop = asyncio.get_event_loop()
335-
if loop.is_running():
336-
# We're in an async context, create a task
337-
future = asyncio.ensure_future(next_handler(app))
338-
# Store for later await
339-
app.context["_async_next_result"] = future
340-
return Response(status_code=200, body="") # Placeholder
341-
else: # pragma: no cover
342-
return loop.run_until_complete(next_handler(app))
343-
344-
# Check if middleware is async
345338
if inspect.iscoroutinefunction(middleware):
346-
result = await middleware(app, next_handler)
347-
else:
348-
# Sync middleware - need special handling
349-
result = middleware(app, sync_next)
339+
return await middleware(app, next_handler)
350340

351-
# Check if we stored an async result
352-
if "_async_next_result" in app.context:
353-
future = app.context.pop("_async_next_result")
354-
result = await future
341+
# We use an Event to coordinate: the sync middleware runs in a thread,
342+
# calls sync_next which signals us to resolve the async handler,
343+
# then waits for the real response.
344+
middleware_called_next = asyncio.Event()
345+
next_app_holder: list = []
346+
real_response_holder: list = []
347+
middleware_result_holder: list = []
348+
middleware_error_holder: list = []
355349

356-
return result
350+
def sync_next(app):
351+
next_app_holder.append(app)
352+
middleware_called_next.set()
353+
# Block this thread until the real response is available
354+
event = threading.Event()
355+
next_app_holder.append(event)
356+
event.wait()
357+
return real_response_holder[0]
358+
359+
def run_middleware():
360+
try:
361+
result = middleware(app, sync_next)
362+
middleware_result_holder.append(result)
363+
except Exception as e:
364+
middleware_error_holder.append(e)
365+
366+
thread = threading.Thread(target=run_middleware, daemon=True)
367+
thread.start()
368+
369+
# Wait for the middleware to call next()
370+
await middleware_called_next.wait()
371+
372+
# Now resolve the async next_handler
373+
real_response = await next_handler(next_app_holder[0])
374+
real_response_holder.append(real_response)
375+
376+
# Signal the thread that the response is ready
377+
threading_event = next_app_holder[1]
378+
threading_event.set()
379+
380+
# Wait for the middleware thread to finish
381+
thread.join()
382+
383+
if middleware_error_holder:
384+
raise middleware_error_holder[0]
385+
386+
return middleware_result_holder[0]
357387

358388
return wrapped
359389

tests/functional/event_handler/_pydantic/test_http_resolver_pydantic.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ def search(
209209
# =============================================================================
210210

211211

212-
@pytest.mark.skip("Due to issue #7981.")
213212
@pytest.mark.asyncio
214213
async def test_async_handler_with_validation():
215214
# GIVEN an app with async handler and validation
@@ -241,6 +240,91 @@ async def create_user(user: UserModel) -> UserResponse:
241240
assert body["user"]["name"] == "AsyncUser"
242241

243242

243+
@pytest.mark.asyncio
244+
async def test_async_handler_invalid_response_returns_422():
245+
# GIVEN an app with async handler and validation
246+
app = HttpResolverLocal(enable_validation=True)
247+
248+
@app.get("/user")
249+
async def get_user() -> UserResponse:
250+
await asyncio.sleep(0.001)
251+
return {"name": "John"} # type: ignore # Missing required fields
252+
253+
scope = {
254+
"type": "http",
255+
"method": "GET",
256+
"path": "/user",
257+
"query_string": b"",
258+
"headers": [(b"content-type", b"application/json")],
259+
}
260+
261+
receive = make_asgi_receive()
262+
send, captured = make_asgi_send()
263+
264+
# WHEN called via ASGI interface
265+
await app(scope, receive, send)
266+
267+
# THEN it returns 422 for invalid response
268+
assert captured["status_code"] == 422
269+
270+
271+
@pytest.mark.asyncio
272+
async def test_sync_handler_with_validation_via_asgi():
273+
# GIVEN an app with a sync handler and validation, called via ASGI
274+
app = HttpResolverLocal(enable_validation=True)
275+
276+
@app.post("/users")
277+
def create_user(user: UserModel) -> UserResponse:
278+
return UserResponse(id="sync-123", user=user)
279+
280+
scope = {
281+
"type": "http",
282+
"method": "POST",
283+
"path": "/users",
284+
"query_string": b"",
285+
"headers": [(b"content-type", b"application/json")],
286+
}
287+
288+
receive = make_asgi_receive(b'{"name": "SyncUser", "age": 30}')
289+
send, captured = make_asgi_send()
290+
291+
# WHEN called via ASGI interface
292+
await app(scope, receive, send)
293+
294+
# THEN validation works with sync handler
295+
assert captured["status_code"] == 200
296+
body = json.loads(captured["body"])
297+
assert body["id"] == "sync-123"
298+
assert body["user"]["name"] == "SyncUser"
299+
300+
301+
@pytest.mark.asyncio
302+
async def test_sync_handler_invalid_response_returns_422_via_asgi():
303+
# GIVEN an app with a sync handler and validation, called via ASGI
304+
app = HttpResolverLocal(enable_validation=True)
305+
306+
@app.get("/user")
307+
def get_user() -> UserResponse:
308+
return {"name": "John"} # type: ignore # Missing required fields
309+
310+
scope = {
311+
"type": "http",
312+
"method": "GET",
313+
"path": "/user",
314+
"query_string": b"",
315+
"headers": [(b"content-type", b"application/json")],
316+
}
317+
318+
receive = make_asgi_receive()
319+
send, captured = make_asgi_send()
320+
321+
# WHEN called via ASGI interface
322+
await app(scope, receive, send)
323+
324+
# THEN it returns 422 for invalid response
325+
assert captured["status_code"] == 422
326+
327+
244328
# =============================================================================
245329
# OpenAPI Tests
246330
# =============================================================================

0 commit comments

Comments
 (0)