From eeb04ec0cbcdf6b0d53f10aa7add4c8f2f7a5eae Mon Sep 17 00:00:00 2001 From: "Audrey M. Roy Greenfeld" Date: Fri, 13 Mar 2026 09:05:16 +0800 Subject: [PATCH 1/2] Propagate inner app return values through the middleware chain StaticRewriteMiddleware.__call__ now uses `return await self.app()` on both code paths. Frameworks like Starlette, FastAPI, and Air rely on return values propagating through the middleware stack when using app.add_middleware(). A bare `await` without `return` silently discards the result. Co-authored-by: Daniel Roy Greenfeld --- src/staticware/middleware.py | 5 ++-- tests/test_staticware.py | 55 ++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/src/staticware/middleware.py b/src/staticware/middleware.py index 1921e87..5eb7179 100644 --- a/src/staticware/middleware.py +++ b/src/staticware/middleware.py @@ -203,8 +203,7 @@ def _replace(self, match: re.Match[str]) -> str: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": - await self.app(scope, receive, send) - return + return await self.app(scope, receive, send) response_start: dict[str, Any] | None = None body_parts: list[bytes] = [] @@ -260,7 +259,7 @@ async def send_wrapper(message: dict[str, Any]) -> None: await send({"type": "http.response.body", "body": full_body}) return - await self.app(scope, receive, send_wrapper) + return await self.app(scope, receive, send_wrapper) # ── Raw ASGI helpers ──────────────────────────────────────────────────── diff --git a/tests/test_staticware.py b/tests/test_staticware.py index ed28dfc..c10a9de 100644 --- a/tests/test_staticware.py +++ b/tests/test_staticware.py @@ -528,3 +528,58 @@ async def test_hashed_url_no_etag(static: HashedStatic) -> None: await static(make_scope(f"/static/{hashed_name}"), receive, resp) assert resp.status == 200 assert b"etag" not in resp.headers, "Hashed URL should not include an etag header" + + +# ── StaticRewriteMiddleware: return value propagation ────────────── + + +async def test_rewrite_middleware_returns_inner_app_result( + static: HashedStatic, +) -> None: + """Middleware should propagate the inner app's return value on the HTTP path. + + ASGI apps normally return None, but the spec does not forbid return values. + Frameworks like Starlette rely on ``return await self.app(...)`` so that + return values propagate through the middleware chain. A bare ``await`` + without ``return`` silently discards the result. + """ + sentinel = "app_result" + + async def inner_app(scope: dict, receive: Any, send: Any) -> str: + body = b"hello" + await send({ + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/html; charset=utf-8"), + (b"content-length", str(len(body)).encode("latin-1")), + ], + }) + await send({"type": "http.response.body", "body": body}) + return sentinel + + app = StaticRewriteMiddleware(inner_app, static=static) + resp = ResponseCollector() + result = await app(make_scope("/"), receive, resp) + assert result == sentinel + + +async def test_rewrite_middleware_returns_inner_app_result_non_http( + static: HashedStatic, +) -> None: + """Middleware should propagate the inner app's return value for non-HTTP scopes. + + When the scope type is not "http", the middleware forwards directly to the + inner app. It must ``return await self.app(...)`` so the return value is + not silently discarded. + """ + sentinel = "ws_result" + + async def inner_app(scope: dict, receive: Any, send: Any) -> str: + return sentinel + + app = StaticRewriteMiddleware(inner_app, static=static) + result = await app( + {"type": "websocket", "path": "/"}, receive, ResponseCollector() + ) + assert result == sentinel From 5bd404aff4d965bea656935a6ac89c2445a5dbef Mon Sep 17 00:00:00 2001 From: "Audrey M. Roy Greenfeld" Date: Fri, 13 Mar 2026 09:36:00 +0800 Subject: [PATCH 2/2] Match Starlette's Awaitable[Any] convention for ASGIApp The ASGI spec is silent on return type, but Starlette and other frameworks define ASGIApp as Awaitable[Any] because middleware chains propagate return values via `return await self.app(...)`. Widening from Awaitable[None] lets ty accept inner apps that return non-None values, which is what the return-value propagation tests exercise. Also applies ruff formatting (import ordering, dict literal style) to both files. --- src/staticware/middleware.py | 74 ++++++++++---------- tests/test_staticware.py | 126 +++++++++++++++++------------------ 2 files changed, 99 insertions(+), 101 deletions(-) diff --git a/src/staticware/middleware.py b/src/staticware/middleware.py index 5eb7179..68cdf48 100644 --- a/src/staticware/middleware.py +++ b/src/staticware/middleware.py @@ -19,15 +19,15 @@ import hashlib import mimetypes import re -from pathlib import Path from collections.abc import Awaitable, Callable +from pathlib import Path from typing import Any # ASGI protocol types — inlined so we depend on nothing. type Scope = dict[str, Any] type Receive = Callable[[], Awaitable[dict[str, Any]]] type Send = Callable[[dict[str, Any]], Awaitable[None]] -type ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]] +type ASGIApp = Callable[[Scope, Receive, Send], Awaitable[Any]] class HashedStatic: @@ -102,7 +102,7 @@ def _hash_files(self) -> None: self.file_map[relative] = hashed self._reverse[hashed] = relative - self._etags[relative] = f'"{hash_val}"'.encode('latin-1') + self._etags[relative] = f'"{hash_val}"'.encode("latin-1") def url(self, path: str) -> str: """Return the cache-busted URL for a static file path. @@ -165,9 +165,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if hdr_name == b"if-none-match" and hdr_value == etag: await _send_text(send, 304, b"") return - await _send_file( - send, file_path, extra_headers=[(b"etag", etag)] - ) + await _send_file(send, file_path, extra_headers=[(b"etag", etag)]) else: await _send_file(send, file_path) return @@ -201,7 +199,7 @@ def _replace(self, match: re.Match[str]) -> str: return f"{self.static.prefix}/{self.static.file_map[path]}" return match.group(0) - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Any: if scope["type"] != "http": return await self.app(scope, receive, send) @@ -224,9 +222,7 @@ async def send_wrapper(message: dict[str, Any]) -> None: if message["type"] == "http.response.body": if response_start is None: - raise RuntimeError( - "http.response.body received before http.response.start" - ) + raise RuntimeError("http.response.body received before http.response.start") if not is_html: await send(message) return @@ -245,13 +241,9 @@ async def send_wrapper(message: dict[str, Any]) -> None: pass if response_start is None: - raise RuntimeError( - "http.response.body received before http.response.start" - ) + raise RuntimeError("http.response.body received before http.response.start") new_headers = [ - (k, str(len(full_body)).encode("latin-1")) - if k == b"content-length" - else (k, v) + (k, str(len(full_body)).encode("latin-1")) if k == b"content-length" else (k, v) for k, v in response_start.get("headers", []) ] response_start["headers"] = new_headers @@ -282,28 +274,36 @@ async def _send_file( if extra_headers: headers.extend(extra_headers) - await send({ - "type": "http.response.start", - "status": 200, - "headers": headers, - }) - await send({ - "type": "http.response.body", - "body": content, - }) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": headers, + } + ) + await send( + { + "type": "http.response.body", + "body": content, + } + ) async def _send_text(send: Send, status: int, body: bytes) -> None: """Send a plain-text ASGI response.""" - await send({ - "type": "http.response.start", - "status": status, - "headers": [ - (b"content-type", b"text/plain"), - (b"content-length", str(len(body)).encode("latin-1")), - ], - }) - await send({ - "type": "http.response.body", - "body": body, - }) + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", b"text/plain"), + (b"content-length", str(len(body)).encode("latin-1")), + ], + } + ) + await send( + { + "type": "http.response.body", + "body": body, + } + ) diff --git a/tests/test_staticware.py b/tests/test_staticware.py index c10a9de..09bdc2f 100644 --- a/tests/test_staticware.py +++ b/tests/test_staticware.py @@ -22,7 +22,6 @@ from staticware import HashedStatic, StaticRewriteMiddleware - # ── Helpers ────────────────────────────────────────────────────────────── @@ -61,7 +60,6 @@ def expected_hash(content: bytes, length: int = 8) -> str: # ── HashedStatic: hashing and url() ────────────────────────────────────── - def test_file_map_contains_all_files(static: HashedStatic, static_dir: Path) -> None: assert "styles.css" in static.file_map assert "images/logo.png" in static.file_map @@ -216,14 +214,16 @@ def make_html_app(html: str): body = html.encode("utf-8") async def app(scope: dict, receive: Any, send: Any) -> None: - await send({ - "type": "http.response.start", - "status": 200, - "headers": [ - (b"content-type", b"text/html; charset=utf-8"), - (b"content-length", str(len(body)).encode("latin-1")), - ], - }) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/html; charset=utf-8"), + (b"content-length", str(len(body)).encode("latin-1")), + ], + } + ) await send({"type": "http.response.body", "body": body}) return app @@ -233,14 +233,16 @@ def make_json_app(data: bytes): """Create a dummy ASGI app that returns JSON.""" async def app(scope: dict, receive: Any, send: Any) -> None: - await send({ - "type": "http.response.start", - "status": 200, - "headers": [ - (b"content-type", b"application/json"), - (b"content-length", str(len(data)).encode("latin-1")), - ], - }) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(data)).encode("latin-1")), + ], + } + ) await send({"type": "http.response.body", "body": data}) return app @@ -318,10 +320,12 @@ async def test_rewrite_raises_runtime_error_on_body_before_start( async def broken_app(scope: dict, receive: Any, send: Any) -> None: # Skip http.response.start entirely — straight to body. - await send({ - "type": "http.response.body", - "body": b"oops", - }) + await send( + { + "type": "http.response.body", + "body": b"oops", + } + ) app = StaticRewriteMiddleware(broken_app, static=static) with pytest.raises(RuntimeError): @@ -335,14 +339,16 @@ async def test_rewrite_streaming_html_response(static: HashedStatic) -> None: async def streaming_app(scope: dict, receive: Any, send: Any) -> None: total = len(chunk1) + len(chunk2) - await send({ - "type": "http.response.start", - "status": 200, - "headers": [ - (b"content-type", b"text/html; charset=utf-8"), - (b"content-length", str(total).encode("latin-1")), - ], - }) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/html; charset=utf-8"), + (b"content-length", str(total).encode("latin-1")), + ], + } + ) await send({"type": "http.response.body", "body": chunk1, "more_body": True}) await send({"type": "http.response.body", "body": chunk2, "more_body": False}) @@ -373,14 +379,16 @@ async def test_rewrite_non_utf8_html_passes_through(static: HashedStatic) -> Non raw_body = b"\x80\x81\x82 not valid utf-8" async def bad_encoding_app(scope: dict, receive: Any, send: Any) -> None: - await send({ - "type": "http.response.start", - "status": 200, - "headers": [ - (b"content-type", b"text/html; charset=utf-8"), - (b"content-length", str(len(raw_body)).encode("latin-1")), - ], - }) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/html; charset=utf-8"), + (b"content-length", str(len(raw_body)).encode("latin-1")), + ], + } + ) await send({"type": "http.response.body", "body": raw_body}) app = StaticRewriteMiddleware(bad_encoding_app, static=static) @@ -397,9 +405,7 @@ def make_mount_scope(path: str, *, root_path: str = "") -> dict[str, Any]: return {"type": "http", "path": path, "root_path": root_path, "method": "GET"} -async def test_serve_with_root_path_scope( - static: HashedStatic, static_dir: Path -) -> None: +async def test_serve_with_root_path_scope(static: HashedStatic, static_dir: Path) -> None: """Starlette-style mount: root_path set, path still includes the prefix. Starlette sets scope["root_path"] = "/static" and leaves @@ -414,9 +420,7 @@ async def test_serve_with_root_path_scope( assert resp.text == "body { color: red; }" -async def test_serve_with_stripped_path( - static: HashedStatic, static_dir: Path -) -> None: +async def test_serve_with_stripped_path(static: HashedStatic, static_dir: Path) -> None: """Litestar-style mount: framework strips the prefix from scope["path"]. The sub-app sees scope["root_path"] = "/static" and @@ -466,18 +470,14 @@ async def test_serve_with_mismatched_mount_and_prefix(static_dir: Path) -> None: # ── HashedStatic: ETag and conditional requests ────────────────────── -def make_scope_with_headers( - path: str, headers: list[tuple[bytes, bytes]] | None = None -) -> dict[str, Any]: +def make_scope_with_headers(path: str, headers: list[tuple[bytes, bytes]] | None = None) -> dict[str, Any]: scope: dict[str, Any] = {"type": "http", "path": path, "method": "GET"} if headers: scope["headers"] = headers return scope -async def test_etag_on_unhashed_response( - static: HashedStatic, static_dir: Path -) -> None: +async def test_etag_on_unhashed_response(static: HashedStatic, static_dir: Path) -> None: """Original filename response includes an ETag header with the content hash.""" resp = ResponseCollector() await static(make_scope("/static/styles.css"), receive, resp) @@ -489,9 +489,7 @@ async def test_etag_on_unhashed_response( assert resp.headers[b"etag"] == f'"{h}"'.encode("latin-1") -async def test_conditional_request_returns_304( - static: HashedStatic, static_dir: Path -) -> None: +async def test_conditional_request_returns_304(static: HashedStatic, static_dir: Path) -> None: """If-None-Match with matching ETag returns 304 and empty body.""" css_content = (static_dir / "styles.css").read_bytes() h = expected_hash(css_content) @@ -547,14 +545,16 @@ async def test_rewrite_middleware_returns_inner_app_result( async def inner_app(scope: dict, receive: Any, send: Any) -> str: body = b"hello" - await send({ - "type": "http.response.start", - "status": 200, - "headers": [ - (b"content-type", b"text/html; charset=utf-8"), - (b"content-length", str(len(body)).encode("latin-1")), - ], - }) + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/html; charset=utf-8"), + (b"content-length", str(len(body)).encode("latin-1")), + ], + } + ) await send({"type": "http.response.body", "body": body}) return sentinel @@ -579,7 +579,5 @@ async def inner_app(scope: dict, receive: Any, send: Any) -> str: return sentinel app = StaticRewriteMiddleware(inner_app, static=static) - result = await app( - {"type": "websocket", "path": "/"}, receive, ResponseCollector() - ) + result = await app({"type": "websocket", "path": "/"}, receive, ResponseCollector()) assert result == sentinel