Skip to content
Open
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
-e .

trio==0.33.0

# Build...
build==1.2.2

# Test...
mypy==1.15.0
pytest==8.3.5
pytest-cov==6.1.1
pytest-trio==0.8.0

# Sync & Async mirroring...
unasync==0.6.0
Expand Down
6 changes: 4 additions & 2 deletions scripts/test
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ if [ -d 'venv' ] ; then
export PREFIX="venv/bin/"
fi

${PREFIX}mypy src/httpx
${PREFIX}mypy src/ahttpx
${PREFIX}pytest --cov src/httpx tests
${PREFIX}pytest tests/test_ahttpx

# ${PREFIX}mypy src/httpx
# ${PREFIX}pytest tests/test_httpx
48 changes: 48 additions & 0 deletions scripts/unasync
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,51 @@ unasync.unasync_files(
),
]
)


unasync.unasync_files(
fpath_list = [
"tests/test_ahttpx/test_client.py",
"tests/test_ahttpx/test_content.py",
"tests/test_ahttpx/test_headers.py",
"tests/test_ahttpx/test_network.py",
"tests/test_ahttpx/test_parsers.py",
"tests/test_ahttpx/test_pool.py",
"tests/test_ahttpx/test_quickstart.py",
"tests/test_ahttpx/test_request.py",
"tests/test_ahttpx/test_response.py",
"tests/test_ahttpx/test_streams.py",
"tests/test_ahttpx/test_urlencode.py",
"tests/test_ahttpx/test_urls.py",
],
rules = [
unasync.Rule(
"tests/test_ahttpx/",
"tests/test_httpx/",
additional_replacements={"ahttpx": "httpx"}
),
]
)


for path in [
"tests/test_httpx/test_client.py",
"tests/test_httpx/test_content.py",
"tests/test_httpx/test_headers.py",
"tests/test_httpx/test_network.py",
"tests/test_httpx/test_parsers.py",
"tests/test_httpx/test_pool.py",
"tests/test_httpx/test_quickstart.py",
"tests/test_httpx/test_request.py",
"tests/test_httpx/test_response.py",
"tests/test_httpx/test_streams.py",
"tests/test_httpx/test_urlencode.py",
"tests/test_httpx/test_urls.py",
]:
with open(path, "r") as fin:
lines = fin.readlines()

lines = [line for line in lines if line != "@pytest.mark.trio\n"]

with open(path, "w") as fout:
fout.writelines(lines)
79 changes: 47 additions & 32 deletions src/ahttpx/_network.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import ssl
import types
import typing

import trio
import certifi

from ._streams import Stream
Expand All @@ -13,39 +13,37 @@

class NetworkStream(Stream):
def __init__(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, address: str = ''
self, trio_stream: trio.abc.Stream, address: str = ''
) -> None:
self._reader = reader
self._writer = writer
self._trio_stream = trio_stream
self._address = address
self._tls = False
self._closed = False

async def read(self, size: int = -1) -> bytes:
if size < 0:
size = 64 * 1024
return await self._reader.read(size)
return await self._trio_stream.receive_some(size)

async def write(self, buffer: bytes) -> None:
self._writer.write(buffer)
await self._writer.drain()
await self._trio_stream.send_all(buffer)

async def close(self) -> None:
if not self._closed:
self._writer.close()
await self._writer.wait_closed()
# Close the NetworkStream.
# If the stream is already closed this is a checkpointed no-op.
try:
await self._trio_stream.aclose()
finally:
self._closed = True

def __repr__(self):
description = ""
description += " TLS" if self._tls else ""
description += " CLOSED" if self._closed else ""
return f"<NetworkStream [{self._address!r}{description}]>"
return f"<NetworkStream [{self._address}{description}]>"

def __del__(self):
if not self._closed:
import warnings
warnings.warn("NetworkStream was garbage collected without being closed.")
warnings.warn(f"{self!r} was garbage collected without being closed.")

# Context managed usage...
async def __aenter__(self) -> "NetworkStream":
Expand All @@ -61,13 +59,17 @@ async def __aexit__(


class NetworkServer:
def __init__(self, host: str, port: int, server: asyncio.Server):
def __init__(self, host: str, port: int, handler, listeners: list[trio.SocketListener]):
self.host = host
self.port = port
self._server = server
self._handler = handler
self._listeners = listeners

# Context managed usage...
async def __aenter__(self) -> "NetworkServer":
self._nursery_manager = trio.open_nursery()
self._nursery = await self._nursery_manager.__aenter__()
self._nursery.start_soon(trio.serve_listeners, self._handler, self._listeners)
return self

async def __aexit__(
Expand All @@ -76,8 +78,8 @@ async def __aexit__(
exc_value: BaseException | None = None,
traceback: types.TracebackType | None = None,
):
self._server.close()
await self._server.wait_closed()
self._nursery.cancel_scope.cancel()
await self._nursery_manager.__aexit__(exc_type, exc_value, traceback)


class NetworkBackend:
Expand All @@ -92,29 +94,42 @@ async def connect(self, host: str, port: int) -> NetworkStream:
"""
Connect to the given address, returning a Stream instance.
"""
# Create the TCP stream
address = f"{host}:{port}"
reader, writer = await asyncio.open_connection(host, port)
return NetworkStream(reader, writer, address=address)
trio_stream = await trio.open_tcp_stream(host, port)
return NetworkStream(trio_stream, address=address)

async def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream:
"""
Connect to the given address, returning a Stream instance.
"""
# Create the TCP stream
address = f"{host}:{port}"
reader, writer = await asyncio.open_connection(host, port)
await writer.start_tls(self._ssl_ctx, server_hostname=hostname)
return NetworkStream(reader, writer, address=address)
trio_stream = await trio.open_tcp_stream(host, port)

# Establish SSL over TCP
hostname = hostname or host
ssl_stream = trio.SSLStream(trio_stream, ssl_context=self._ssl_ctx, server_hostname=hostname)
await ssl_stream.do_handshake()

return NetworkStream(ssl_stream, address=address)

async def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer:
async def callback(reader, writer):
stream = NetworkStream(reader, writer)
await handler(stream)
async def callback(trio_stream):
stream = NetworkStream(trio_stream, address=f"{host}:{port}")
try:
await handler(stream)
finally:
await stream.close()

server = await asyncio.start_server(callback, host, port)
return NetworkServer(host, port, server)
listeners = await trio.open_tcp_listeners(port=port, host=host)
return NetworkServer(host, port, callback, listeners)

def __repr__(self):
return f"<NetworkBackend [trio]>"


Semaphore = asyncio.Semaphore
Lock = asyncio.Lock
timeout = asyncio.timeout
sleep = asyncio.sleep
Semaphore = trio.Semaphore
Lock = trio.Lock
timeout = trio.move_on_after
sleep = trio.sleep
22 changes: 22 additions & 0 deletions src/ahttpx/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ async def send_body(self, body: bytes) -> None:
# Handle body close
self.send_state = State.DONE

async def recv_close(self) -> bool:
# ...
if self.is_closed():
return True

if await self.parser.read_eof():
await self.close()
return True
return False

async def recv_method_line(self) -> tuple[bytes, bytes, bytes]:
"""
Receive the initial request method line:
Expand Down Expand Up @@ -463,6 +473,18 @@ async def read(self, size: int) -> bytes:
self._push_back(bytes(push_back))
return bytes(buffer)

async def read_eof(self) -> bool:
"""
Attempt to read the closing EOF.
Return True if the stream is EOF, or False otherwise.
"""
if not self._buffer:
chunk = await self._read_some()
if not chunk:
return True
self._push_back(chunk)
return False

async def read_until(self, marker: bytes, max_size: int, exc_text: str) -> bytes:
"""
Read and return bytes from the stream, delimited by marker.
Expand Down
8 changes: 4 additions & 4 deletions src/ahttpx/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, stream, endpoint):
# API entry points...
async def handle_requests(self):
try:
while not self._parser.is_closed():
while not await self._parser.recv_close():
method, url, headers = await self._recv_head()
stream = HTTPStream(self._recv_body, self._complete)
# TODO: Handle endpoint exceptions
Expand All @@ -43,13 +43,13 @@ async def handle_requests(self):
except Exception:
logger.error("Internal Server Error", exc_info=True)
content = Text("Internal Server Error")
err = Response(code=500, content=content)
err = Response(500, content=content)
await self._send_head(err)
await self._send_body(err)
else:
await self._send_head(response)
await self._send_body(response)
except Exception:
except BaseException:
logger.error("Internal Server Error", exc_info=True)

async def close(self):
Expand Down Expand Up @@ -89,7 +89,7 @@ async def _send_body(self, response: Response):

# Start it all over again...
async def _complete(self):
await self._parser.complete
await self._parser.complete()
self._idle_expiry = time.monotonic() + self._keepalive_duration


Expand Down
6 changes: 5 additions & 1 deletion src/httpx/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(self, listener: NetworkListener, handler: typing.Callable[[NetworkS
self._max_workers = 5
self._executor = None
self._thread = None
self._streams = list[NetworkStream]
self._streams: list[NetworkStream] = []

@property
def host(self):
Expand All @@ -176,6 +176,8 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
for stream in self._streams:
stream.close()
self.listener.close()
self._executor.shutdown(wait=True)

Expand All @@ -185,9 +187,11 @@ def _serve(self):

def _handler(self, stream):
try:
self._streams.append(stream)
self.handler(stream)
finally:
stream.close()
self._streams.remove(stream)


class NetworkBackend:
Expand Down
22 changes: 22 additions & 0 deletions src/httpx/_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,16 @@ def send_body(self, body: bytes) -> None:
# Handle body close
self.send_state = State.DONE

def recv_close(self) -> bool:
# ...
if self.is_closed():
return True

if self.parser.read_eof():
self.close()
return True
return False

def recv_method_line(self) -> tuple[bytes, bytes, bytes]:
"""
Receive the initial request method line:
Expand Down Expand Up @@ -463,6 +473,18 @@ def read(self, size: int) -> bytes:
self._push_back(bytes(push_back))
return bytes(buffer)

def read_eof(self) -> bool:
"""
Attempt to read the closing EOF.
Return True if the stream is EOF, or False otherwise.
"""
if not self._buffer:
chunk = self._read_some()
if not chunk:
return True
self._push_back(chunk)
return False

def read_until(self, marker: bytes, max_size: int, exc_text: str) -> bytes:
"""
Read and return bytes from the stream, delimited by marker.
Expand Down
8 changes: 4 additions & 4 deletions src/httpx/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, stream, endpoint):
# API entry points...
def handle_requests(self):
try:
while not self._parser.is_closed():
while not self._parser.recv_close():
method, url, headers = self._recv_head()
stream = HTTPStream(self._recv_body, self._complete)
# TODO: Handle endpoint exceptions
Expand All @@ -43,13 +43,13 @@ def handle_requests(self):
except Exception:
logger.error("Internal Server Error", exc_info=True)
content = Text("Internal Server Error")
err = Response(code=500, content=content)
err = Response(500, content=content)
self._send_head(err)
self._send_body(err)
else:
self._send_head(response)
self._send_body(response)
except Exception:
except BaseException:
logger.error("Internal Server Error", exc_info=True)

def close(self):
Expand Down Expand Up @@ -89,7 +89,7 @@ def _send_body(self, response: Response):

# Start it all over again...
def _complete(self):
self._parser.complete
self._parser.complete()
self._idle_expiry = time.monotonic() + self._keepalive_duration


Expand Down
File renamed without changes.
Loading