"""
A modified version of httpx.ASGITransport that supports streaming responses.
This transport runs the ASGI app as a separate anyio task, allowing it to
handle streaming responses like SSE where the app doesn't terminate until
the connection is closed.
This is only intended for writing tests for the SSE transport.
"""
import typing
from collections.abc import Awaitable, Callable
from typing import Any, cast
import anyio
import anyio.abc
import anyio.streams.memory
from httpx._models import Request, Response
from httpx._transports.base import AsyncBaseTransport
from httpx._types import AsyncByteStream
from starlette.types import ASGIApp, Receive, Scope, Send
class StreamingASGITransport(AsyncBaseTransport):
    """
    A custom AsyncTransport that handles sending requests directly to an ASGI app
    and supports streaming responses like SSE.

    Unlike the standard ASGITransport, this transport runs the ASGI app in a
    separate anyio task, allowing it to handle responses from apps that don't
    terminate immediately (like SSE endpoints).

    Arguments:

    * `app` - The ASGI application.
    * `task_group` - The anyio task group used to run the ASGI app (and the
      message pump) in the background while the response body is streamed.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
      should be raised. Default to `True`. Can be set to `False` for use cases
      such as testing the content of a client 500 response.
    * `root_path` - The root path on which the ASGI application should be mounted.
    * `client` - A two-tuple indicating the client IP and port of incoming requests.

    TODO: https://github.com/encode/httpx/pull/3059 is adding something similar to
    upstream httpx. When that merges, we should delete this & switch back to the
    upstream implementation.
    """

    def __init__(
        self,
        app: ASGIApp,
        task_group: anyio.abc.TaskGroup,
        raise_app_exceptions: bool = True,
        root_path: str = "",
        client: tuple[str, int] = ("127.0.0.1", 123),
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.root_path = root_path
        self.client = client
        self.task_group = task_group

    async def handle_async_request(
        self,
        request: Request,
    ) -> Response:
        """
        Dispatch `request` to the ASGI app and return a streaming `Response`.

        The app runs as a background task on `self.task_group`; this method
        returns as soon as the `http.response.start` message has been seen,
        while body chunks continue to flow through a memory channel into the
        returned response's stream.
        """
        assert isinstance(request.stream, AsyncByteStream)

        # Set when the response stream is closed; makes `receive` report
        # `http.disconnect` so long-running apps (e.g. SSE) can shut down.
        disconnect_event = anyio.Event()

        # ASGI scope.
        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": request.method,
            "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
            "scheme": request.url.scheme,
            "path": request.url.path,
            "raw_path": request.url.raw_path.split(b"?")[0],
            "query_string": request.url.query,
            "server": (request.url.host, request.url.port),
            "client": self.client,
            "root_path": self.root_path,
        }

        # Request body
        request_body_chunks = request.stream.__aiter__()
        request_complete = False

        # Response state, filled in by `process_messages`.
        status_code = 499
        response_headers = None
        response_started = False
        response_complete = anyio.Event()
        initial_response_ready = anyio.Event()

        # Synchronization for streaming response: the app's ASGI messages and
        # the response body chunks each travel over a bounded memory channel.
        asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[dict[str, Any]](100)
        content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100)

        # ASGI callables.
        async def send_disconnect() -> None:
            # Called by the response stream's `aclose`; unblocks `receive`.
            disconnect_event.set()

        async def receive() -> dict[str, Any]:
            nonlocal request_complete

            if disconnect_event.is_set():
                return {"type": "http.disconnect"}

            if request_complete:
                # Body fully delivered: park until the client disconnects.
                await disconnect_event.wait()
                return {"type": "http.disconnect"}

            try:
                body = await request_body_chunks.__anext__()
            except StopAsyncIteration:
                request_complete = True
                return {"type": "http.request", "body": b"", "more_body": False}
            return {"type": "http.request", "body": body, "more_body": True}

        async def send(message: dict[str, Any]) -> None:
            # Just forward the message; `process_messages` interprets it.
            await asgi_send_channel.send(message)

        # Run the ASGI application in a background task.
        async def run_app() -> None:
            try:
                # Cast the receive and send functions to the ASGI types
                await self.app(cast(Scope, scope), cast(Receive, receive), cast(Send, send))
            except Exception:
                if self.raise_app_exceptions:
                    raise

                # Best-effort 500 if the app died before starting a response.
                if not response_started:
                    await asgi_send_channel.send({"type": "http.response.start", "status": 500, "headers": []})

                await asgi_send_channel.send({"type": "http.response.body", "body": b"", "more_body": False})
            finally:
                # Closing the channel ends the `process_messages` loop.
                await asgi_send_channel.aclose()

        # Pump messages from the ASGI app into response state / body chunks.
        async def process_messages() -> None:
            nonlocal status_code, response_headers, response_started

            try:
                async with asgi_receive_channel:
                    async for message in asgi_receive_channel:
                        if message["type"] == "http.response.start":
                            assert not response_started
                            status_code = message["status"]
                            response_headers = message.get("headers", [])
                            response_started = True

                            # As soon as we have headers, we can return a response
                            initial_response_ready.set()
                        elif message["type"] == "http.response.body":
                            body = message.get("body", b"")
                            more_body = message.get("more_body", False)

                            if body and request.method != "HEAD":
                                await content_send_channel.send(body)

                            if not more_body:
                                response_complete.set()
                                await content_send_channel.aclose()
                                break
            finally:
                # Ensure events are set even if there's an error, so the
                # caller waiting below can never hang.
                initial_response_ready.set()
                response_complete.set()
                await content_send_channel.aclose()

        # Create tasks for running the app and processing messages
        self.task_group.start_soon(run_app)
        self.task_group.start_soon(process_messages)

        # Wait for the response headers (or the app's early termination).
        await initial_response_ready.wait()

        # Create a streaming response
        return Response(
            status_code,
            headers=response_headers,
            stream=StreamingASGIResponseStream(content_receive_channel, send_disconnect),
        )
class StreamingASGIResponseStream(AsyncByteStream):
    """
    Response byte stream backed by an anyio memory channel.

    Body chunks that the ASGI app keeps producing after the headers have
    already been returned are read off ``receive_channel`` and yielded to
    the consumer, which is what makes long-lived streaming bodies (SSE)
    possible. Closing the stream closes the channel and then signals a
    client disconnect to the app via ``send_disconnect``.
    """

    def __init__(
        self,
        receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes],
        send_disconnect: Callable[[], Awaitable[None]],
    ) -> None:
        self.receive_channel = receive_channel
        self.send_disconnect = send_disconnect

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        try:
            # Pull chunks until the producer side closes the channel.
            while True:
                try:
                    part = await self.receive_channel.receive()
                except anyio.EndOfStream:
                    break
                yield part
        finally:
            # Tear down even if the consumer abandons iteration early.
            await self.aclose()

    async def aclose(self) -> None:
        # Close our end of the channel first, then tell the app the
        # client has gone away.
        await self.receive_channel.aclose()
        await self.send_disconnect()