Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 39 additions & 40 deletions src/staticware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
import hashlib
import mimetypes
import re
from pathlib import Path
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any

# ASGI protocol types — inlined so we depend on nothing.
type Scope = dict[str, Any]
type Receive = Callable[[], Awaitable[dict[str, Any]]]
type Send = Callable[[dict[str, Any]], Awaitable[None]]
type ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]
type ASGIApp = Callable[[Scope, Receive, Send], Awaitable[Any]]


class HashedStatic:
Expand Down Expand Up @@ -102,7 +102,7 @@ def _hash_files(self) -> None:

self.file_map[relative] = hashed
self._reverse[hashed] = relative
self._etags[relative] = f'"{hash_val}"'.encode('latin-1')
self._etags[relative] = f'"{hash_val}"'.encode("latin-1")

def url(self, path: str) -> str:
"""Return the cache-busted URL for a static file path.
Expand Down Expand Up @@ -165,9 +165,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if hdr_name == b"if-none-match" and hdr_value == etag:
await _send_text(send, 304, b"")
return
await _send_file(
send, file_path, extra_headers=[(b"etag", etag)]
)
await _send_file(send, file_path, extra_headers=[(b"etag", etag)])
else:
await _send_file(send, file_path)
return
Expand Down Expand Up @@ -201,10 +199,9 @@ def _replace(self, match: re.Match[str]) -> str:
return f"{self.static.prefix}/{self.static.file_map[path]}"
return match.group(0)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Any:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
return await self.app(scope, receive, send)

response_start: dict[str, Any] | None = None
body_parts: list[bytes] = []
Expand All @@ -225,9 +222,7 @@ async def send_wrapper(message: dict[str, Any]) -> None:

if message["type"] == "http.response.body":
if response_start is None:
raise RuntimeError(
"http.response.body received before http.response.start"
)
raise RuntimeError("http.response.body received before http.response.start")
if not is_html:
await send(message)
return
Expand All @@ -246,21 +241,17 @@ async def send_wrapper(message: dict[str, Any]) -> None:
pass

if response_start is None:
raise RuntimeError(
"http.response.body received before http.response.start"
)
raise RuntimeError("http.response.body received before http.response.start")
new_headers = [
(k, str(len(full_body)).encode("latin-1"))
if k == b"content-length"
else (k, v)
(k, str(len(full_body)).encode("latin-1")) if k == b"content-length" else (k, v)
for k, v in response_start.get("headers", [])
]
response_start["headers"] = new_headers
await send(response_start)
await send({"type": "http.response.body", "body": full_body})
return

await self.app(scope, receive, send_wrapper)
return await self.app(scope, receive, send_wrapper)


# ── Raw ASGI helpers ────────────────────────────────────────────────────
Expand All @@ -283,28 +274,36 @@ async def _send_file(
if extra_headers:
headers.extend(extra_headers)

await send({
"type": "http.response.start",
"status": 200,
"headers": headers,
})
await send({
"type": "http.response.body",
"body": content,
})
await send(
{
"type": "http.response.start",
"status": 200,
"headers": headers,
}
)
await send(
{
"type": "http.response.body",
"body": content,
}
)


async def _send_text(send: Send, status: int, body: bytes) -> None:
"""Send a plain-text ASGI response."""
await send({
"type": "http.response.start",
"status": status,
"headers": [
(b"content-type", b"text/plain"),
(b"content-length", str(len(body)).encode("latin-1")),
],
})
await send({
"type": "http.response.body",
"body": body,
})
await send(
{
"type": "http.response.start",
"status": status,
"headers": [
(b"content-type", b"text/plain"),
(b"content-length", str(len(body)).encode("latin-1")),
],
}
)
await send(
{
"type": "http.response.body",
"body": body,
}
)
159 changes: 106 additions & 53 deletions tests/test_staticware.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from staticware import HashedStatic, StaticRewriteMiddleware


# ── Helpers ──────────────────────────────────────────────────────────────


Expand Down Expand Up @@ -61,7 +60,6 @@ def expected_hash(content: bytes, length: int = 8) -> str:
# ── HashedStatic: hashing and url() ──────────────────────────────────────



def test_file_map_contains_all_files(static: HashedStatic, static_dir: Path) -> None:
assert "styles.css" in static.file_map
assert "images/logo.png" in static.file_map
Expand Down Expand Up @@ -216,14 +214,16 @@ def make_html_app(html: str):
body = html.encode("utf-8")

async def app(scope: dict, receive: Any, send: Any) -> None:
await send({
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(len(body)).encode("latin-1")),
],
})
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(len(body)).encode("latin-1")),
],
}
)
await send({"type": "http.response.body", "body": body})

return app
Expand All @@ -233,14 +233,16 @@ def make_json_app(data: bytes):
"""Create a dummy ASGI app that returns JSON."""

async def app(scope: dict, receive: Any, send: Any) -> None:
await send({
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"application/json"),
(b"content-length", str(len(data)).encode("latin-1")),
],
})
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"application/json"),
(b"content-length", str(len(data)).encode("latin-1")),
],
}
)
await send({"type": "http.response.body", "body": data})

return app
Expand Down Expand Up @@ -318,10 +320,12 @@ async def test_rewrite_raises_runtime_error_on_body_before_start(

async def broken_app(scope: dict, receive: Any, send: Any) -> None:
# Skip http.response.start entirely — straight to body.
await send({
"type": "http.response.body",
"body": b"<html>oops</html>",
})
await send(
{
"type": "http.response.body",
"body": b"<html>oops</html>",
}
)

app = StaticRewriteMiddleware(broken_app, static=static)
with pytest.raises(RuntimeError):
Expand All @@ -335,14 +339,16 @@ async def test_rewrite_streaming_html_response(static: HashedStatic) -> None:

async def streaming_app(scope: dict, receive: Any, send: Any) -> None:
total = len(chunk1) + len(chunk2)
await send({
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(total).encode("latin-1")),
],
})
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(total).encode("latin-1")),
],
}
)
await send({"type": "http.response.body", "body": chunk1, "more_body": True})
await send({"type": "http.response.body", "body": chunk2, "more_body": False})

Expand Down Expand Up @@ -373,14 +379,16 @@ async def test_rewrite_non_utf8_html_passes_through(static: HashedStatic) -> Non
raw_body = b"<html>\x80\x81\x82 not valid utf-8</html>"

async def bad_encoding_app(scope: dict, receive: Any, send: Any) -> None:
await send({
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(len(raw_body)).encode("latin-1")),
],
})
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(len(raw_body)).encode("latin-1")),
],
}
)
await send({"type": "http.response.body", "body": raw_body})

app = StaticRewriteMiddleware(bad_encoding_app, static=static)
Expand All @@ -397,9 +405,7 @@ def make_mount_scope(path: str, *, root_path: str = "") -> dict[str, Any]:
return {"type": "http", "path": path, "root_path": root_path, "method": "GET"}


async def test_serve_with_root_path_scope(
static: HashedStatic, static_dir: Path
) -> None:
async def test_serve_with_root_path_scope(static: HashedStatic, static_dir: Path) -> None:
"""Starlette-style mount: root_path set, path still includes the prefix.

Starlette sets scope["root_path"] = "/static" and leaves
Expand All @@ -414,9 +420,7 @@ async def test_serve_with_root_path_scope(
assert resp.text == "body { color: red; }"


async def test_serve_with_stripped_path(
static: HashedStatic, static_dir: Path
) -> None:
async def test_serve_with_stripped_path(static: HashedStatic, static_dir: Path) -> None:
"""Litestar-style mount: framework strips the prefix from scope["path"].

The sub-app sees scope["root_path"] = "/static" and
Expand Down Expand Up @@ -466,18 +470,14 @@ async def test_serve_with_mismatched_mount_and_prefix(static_dir: Path) -> None:
# ── HashedStatic: ETag and conditional requests ──────────────────────


def make_scope_with_headers(
path: str, headers: list[tuple[bytes, bytes]] | None = None
) -> dict[str, Any]:
def make_scope_with_headers(path: str, headers: list[tuple[bytes, bytes]] | None = None) -> dict[str, Any]:
scope: dict[str, Any] = {"type": "http", "path": path, "method": "GET"}
if headers:
scope["headers"] = headers
return scope


async def test_etag_on_unhashed_response(
static: HashedStatic, static_dir: Path
) -> None:
async def test_etag_on_unhashed_response(static: HashedStatic, static_dir: Path) -> None:
"""Original filename response includes an ETag header with the content hash."""
resp = ResponseCollector()
await static(make_scope("/static/styles.css"), receive, resp)
Expand All @@ -489,9 +489,7 @@ async def test_etag_on_unhashed_response(
assert resp.headers[b"etag"] == f'"{h}"'.encode("latin-1")


async def test_conditional_request_returns_304(
static: HashedStatic, static_dir: Path
) -> None:
async def test_conditional_request_returns_304(static: HashedStatic, static_dir: Path) -> None:
"""If-None-Match with matching ETag returns 304 and empty body."""
css_content = (static_dir / "styles.css").read_bytes()
h = expected_hash(css_content)
Expand Down Expand Up @@ -528,3 +526,58 @@ async def test_hashed_url_no_etag(static: HashedStatic) -> None:
await static(make_scope(f"/static/{hashed_name}"), receive, resp)
assert resp.status == 200
assert b"etag" not in resp.headers, "Hashed URL should not include an etag header"


# ── StaticRewriteMiddleware: return value propagation ──────────────


async def test_rewrite_middleware_returns_inner_app_result(
static: HashedStatic,
) -> None:
"""Middleware should propagate the inner app's return value on the HTTP path.

ASGI apps normally return None, but the spec does not forbid return values.
Frameworks like Starlette rely on ``return await self.app(...)`` so that
return values propagate through the middleware chain. A bare ``await``
without ``return`` silently discards the result.
"""
sentinel = "app_result"

async def inner_app(scope: dict, receive: Any, send: Any) -> str:
body = b"<html>hello</html>"
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html; charset=utf-8"),
(b"content-length", str(len(body)).encode("latin-1")),
],
}
)
await send({"type": "http.response.body", "body": body})
return sentinel

app = StaticRewriteMiddleware(inner_app, static=static)
resp = ResponseCollector()
result = await app(make_scope("/"), receive, resp)
assert result == sentinel


async def test_rewrite_middleware_returns_inner_app_result_non_http(
static: HashedStatic,
) -> None:
"""Middleware should propagate the inner app's return value for non-HTTP scopes.

When the scope type is not "http", the middleware forwards directly to the
inner app. It must ``return await self.app(...)`` so the return value is
not silently discarded.
"""
sentinel = "ws_result"

async def inner_app(scope: dict, receive: Any, send: Any) -> str:
return sentinel

app = StaticRewriteMiddleware(inner_app, static=static)
result = await app({"type": "websocket", "path": "/"}, receive, ResponseCollector())
assert result == sentinel