diff --git a/src/ahttpx/__init__.py b/src/ahttpx/__init__.py index 9e589ab..aafb928 100644 --- a/src/ahttpx/__init__.py +++ b/src/ahttpx/__init__.py @@ -3,12 +3,12 @@ from ._content import * # Content, File, Files, Form, HTML, JSON, MultiPart, Text from ._headers import * # Headers from ._network import * # NetworkBackend, NetworkStream, timeout -from ._parsers import * # HTTPParser, ProtocolError +from ._parsers import * # HTTPParser, HTTPStream, ProtocolError from ._pool import * # Connection, ConnectionPool, Transport from ._quickstart import * # get, post, put, patch, delete from ._response import * # Response from ._request import * # Request -from ._streams import * # ByteStream, DuplexStream, FileStream, HTTPStream, Stream +from ._streams import * # ByteStream, DuplexStream, FileStream, Stream from ._server import * # serve_http, run from ._urlencode import * # quote, unquote, urldecode, urlencode from ._urls import * # QueryParams, URL diff --git a/src/ahttpx/_parsers.py b/src/ahttpx/_parsers.py index 440f810..6ac2c33 100644 --- a/src/ahttpx/_parsers.py +++ b/src/ahttpx/_parsers.py @@ -1,8 +1,10 @@ import enum +import io +import typing from ._streams import Stream -__all__ = ['HTTPParser', 'Mode', 'ProtocolError'] +__all__ = ['HTTPParser', 'HTTPStream', 'Mode', 'ProtocolError'] # TODO... @@ -436,6 +438,50 @@ def __repr__(self) -> str: return f'' +class HTTPStream(Stream): + def __init__(self, parser: HTTPParser, callback: typing.Callable | None = None): + self._parser = parser + self._buffer = io.BytesIO() + self._callback = callback + + async def read(self, size=-1) -> bytes: + sections = [] + length = 0 + + # If we have any data in the buffer read that and clear the buffer. + buffered = self._buffer.read() + if buffered: + sections.append(buffered) + length += len(buffered) + self._buffer.seek(0) + self._buffer.truncate(0) + + # Read each chunk in turn. + while (size < 0) or (length < size): + section = await self._parser.recv_body() + sections.append(section) + length += len(section) + if section == b'': + break + + # If we've more data than requested, then push some back into the buffer. + output = b''.join(sections) + if size > -1 and len(output) > size: + output, remainder = output[:size], output[size:] + self._buffer.write(remainder) + self._buffer.seek(0) + + return output + + async def close(self) -> None: + try: + self._buffer.close() + await self._parser.complete() + finally: + if self._callback is not None: + await self._callback() + + class ReadAheadParser: """ A buffered I/O stream, with methods for read-ahead parsing. diff --git a/src/ahttpx/_pool.py b/src/ahttpx/_pool.py index f712cfa..ff12246 100644 --- a/src/ahttpx/_pool.py +++ b/src/ahttpx/_pool.py @@ -5,10 +5,10 @@ from ._content import Content from ._headers import Headers from ._network import Lock, NetworkBackend, Semaphore -from ._parsers import HTTPParser +from ._parsers import HTTPParser, HTTPStream from ._response import Response from ._request import Request -from ._streams import HTTPStream, Stream +from ._streams import Stream from ._urls import URL @@ -170,7 +170,7 @@ async def send(self, request: Request) -> Response: await self._send_head(request) await self._send_body(request) code, headers = await self._recv_head() - stream = HTTPStream(self._recv_body, self._complete) + stream = HTTPStream(self._parser, callback=self._complete) # TODO... return Response(code, headers=headers, content=stream) # finally: @@ -237,7 +237,6 @@ async def _recv_body(self) -> bytes: # Request/response cycle complete... async def _complete(self) -> None: - await self._parser.complete() self._idle_expiry = time.monotonic() + self._keepalive_duration async def _close(self) -> None: diff --git a/src/ahttpx/_server.py b/src/ahttpx/_server.py index de9179a..973da54 100644 --- a/src/ahttpx/_server.py +++ b/src/ahttpx/_server.py @@ -3,11 +3,10 @@ import time from ._content import Text -from ._parsers import HTTPParser +from ._parsers import HTTPParser, HTTPStream from ._request import Request from ._response import Response from ._network import NetworkBackend, sleep -from ._streams import HTTPStream __all__ = [ "serve_http", "run" @@ -33,7 +32,7 @@ async def handle_requests(self): try: while not await self._parser.recv_close(): method, url, headers = await self._recv_head() - stream = HTTPStream(self._recv_body, self._complete) + stream = HTTPStream(self._parser, callback=self._complete) # TODO: Handle endpoint exceptions async with Request(method, url, headers=headers, content=stream) as request: try: @@ -89,7 +88,6 @@ async def _send_body(self, response: Response): # Start it all over again... async def _complete(self): - await self._parser.complete() self._idle_expiry = time.monotonic() + self._keepalive_duration diff --git a/src/ahttpx/_streams.py b/src/ahttpx/_streams.py index d5e5ad0..ae1739d 100644 --- a/src/ahttpx/_streams.py +++ b/src/ahttpx/_streams.py @@ -3,6 +3,9 @@ import os +__all__ = ['Stream', 'ByteStream', 'DuplexStream', 'FileStream', 'MultiPartStream'] + + class Stream: async def read(self, size: int=-1) -> bytes: raise NotImplementedError() @@ -103,47 +106,6 @@ async def __aenter__(self): return self -class HTTPStream(Stream): - def __init__(self, next_chunk, complete): - self._next_chunk = next_chunk - self._complete = complete - self._buffer = io.BytesIO() - - async def read(self, size=-1) -> bytes: - sections = [] - length = 0 - - # If we have any data in the buffer read that and clear the buffer. - buffered = self._buffer.read() - if buffered: - sections.append(buffered) - length += len(buffered) - self._buffer.seek(0) - self._buffer.truncate(0) - - # Read each chunk in turn. - while (size < 0) or (length < size): - section = await self._next_chunk() - sections.append(section) - length += len(section) - if section == b'': - break - - # If we've more data than requested, then push some back into the buffer. - output = b''.join(sections) - if size > -1 and len(output) > size: - output, remainder = output[:size], output[size:] - self._buffer.write(remainder) - self._buffer.seek(0) - - return output - - async def close(self) -> None: - self._buffer.close() - if self._complete is not None: - await self._complete() - - class MultiPartStream(Stream): def __init__(self, form: list[tuple[str, str]], files: list[tuple[str, str]], boundary=''): self._form = list(form) diff --git a/src/httpx/__init__.py b/src/httpx/__init__.py index 9e589ab..aafb928 100644 --- a/src/httpx/__init__.py +++ b/src/httpx/__init__.py @@ -3,12 +3,12 @@ from ._content import * # Content, File, Files, Form, HTML, JSON, MultiPart, Text from ._headers import * # Headers from ._network import * # NetworkBackend, NetworkStream, timeout -from ._parsers import * # HTTPParser, ProtocolError +from ._parsers import * # HTTPParser, HTTPStream, ProtocolError from ._pool import * # Connection, ConnectionPool, Transport from ._quickstart import * # get, post, put, patch, delete from ._response import * # Response from ._request import * # Request -from ._streams import * # ByteStream, DuplexStream, FileStream, HTTPStream, Stream +from ._streams import * # ByteStream, DuplexStream, FileStream, Stream from ._server import * # serve_http, run from ._urlencode import * # quote, unquote, urldecode, urlencode from ._urls import * # QueryParams, URL diff --git a/src/httpx/_parsers.py b/src/httpx/_parsers.py index b8be024..415bfef 100644 --- a/src/httpx/_parsers.py +++ b/src/httpx/_parsers.py @@ -1,8 +1,10 @@ import enum +import io +import typing from ._streams import Stream -__all__ = ['HTTPParser', 'Mode', 'ProtocolError'] +__all__ = ['HTTPParser', 'HTTPStream', 'Mode', 'ProtocolError'] # TODO... @@ -436,6 +438,50 @@ def __repr__(self) -> str: return f'' +class HTTPStream(Stream): + def __init__(self, parser: HTTPParser, callback: typing.Callable | None = None): + self._parser = parser + self._buffer = io.BytesIO() + self._callback = callback + + def read(self, size=-1) -> bytes: + sections = [] + length = 0 + + # If we have any data in the buffer read that and clear the buffer. + buffered = self._buffer.read() + if buffered: + sections.append(buffered) + length += len(buffered) + self._buffer.seek(0) + self._buffer.truncate(0) + + # Read each chunk in turn. + while (size < 0) or (length < size): + section = self._parser.recv_body() + sections.append(section) + length += len(section) + if section == b'': + break + + # If we've more data than requested, then push some back into the buffer. + output = b''.join(sections) + if size > -1 and len(output) > size: + output, remainder = output[:size], output[size:] + self._buffer.write(remainder) + self._buffer.seek(0) + + return output + + def close(self) -> None: + try: + self._buffer.close() + self._parser.complete() + finally: + if self._callback is not None: + self._callback() + + class ReadAheadParser: """ A buffered I/O stream, with methods for read-ahead parsing. diff --git a/src/httpx/_pool.py b/src/httpx/_pool.py index 7193f8d..71cb942 100644 --- a/src/httpx/_pool.py +++ b/src/httpx/_pool.py @@ -5,10 +5,10 @@ from ._content import Content from ._headers import Headers from ._network import Lock, NetworkBackend, Semaphore -from ._parsers import HTTPParser +from ._parsers import HTTPParser, HTTPStream from ._response import Response from ._request import Request -from ._streams import HTTPStream, Stream +from ._streams import Stream from ._urls import URL @@ -170,7 +170,7 @@ def send(self, request: Request) -> Response: self._send_head(request) self._send_body(request) code, headers = self._recv_head() - stream = HTTPStream(self._recv_body, self._complete) + stream = HTTPStream(self._parser, callback=self._complete) # TODO... return Response(code, headers=headers, content=stream) # finally: @@ -237,7 +237,6 @@ def _recv_body(self) -> bytes: # Request/response cycle complete... def _complete(self) -> None: - self._parser.complete() self._idle_expiry = time.monotonic() + self._keepalive_duration def _close(self) -> None: diff --git a/src/httpx/_server.py b/src/httpx/_server.py index 14d23fa..31bec5d 100644 --- a/src/httpx/_server.py +++ b/src/httpx/_server.py @@ -3,11 +3,10 @@ import time from ._content import Text -from ._parsers import HTTPParser +from ._parsers import HTTPParser, HTTPStream from ._request import Request from ._response import Response from ._network import NetworkBackend, sleep -from ._streams import HTTPStream __all__ = [ "serve_http", "run" @@ -33,7 +32,7 @@ def handle_requests(self): try: while not self._parser.recv_close(): method, url, headers = self._recv_head() - stream = HTTPStream(self._recv_body, self._complete) + stream = HTTPStream(self._parser, callback=self._complete) # TODO: Handle endpoint exceptions with Request(method, url, headers=headers, content=stream) as request: try: @@ -89,7 +88,6 @@ def _send_body(self, response: Response): # Start it all over again... def _complete(self): - self._parser.complete() self._idle_expiry = time.monotonic() + self._keepalive_duration diff --git a/src/httpx/_streams.py b/src/httpx/_streams.py index 1fc6cde..66312b3 100644 --- a/src/httpx/_streams.py +++ b/src/httpx/_streams.py @@ -3,6 +3,9 @@ import os +__all__ = ['Stream', 'ByteStream', 'DuplexStream', 'FileStream', 'MultiPartStream'] + + class Stream: def read(self, size: int=-1) -> bytes: raise NotImplementedError() @@ -103,47 +106,6 @@ def __enter__(self): return self -class HTTPStream(Stream): - def __init__(self, next_chunk, complete): - self._next_chunk = next_chunk - self._complete = complete - self._buffer = io.BytesIO() - - def read(self, size=-1) -> bytes: - sections = [] - length = 0 - - # If we have any data in the buffer read that and clear the buffer. - buffered = self._buffer.read() - if buffered: - sections.append(buffered) - length += len(buffered) - self._buffer.seek(0) - self._buffer.truncate(0) - - # Read each chunk in turn. - while (size < 0) or (length < size): - section = self._next_chunk() - sections.append(section) - length += len(section) - if section == b'': - break - - # If we've more data than requested, then push some back into the buffer. - output = b''.join(sections) - if size > -1 and len(output) > size: - output, remainder = output[:size], output[size:] - self._buffer.write(remainder) - self._buffer.seek(0) - - return output - - def close(self) -> None: - self._buffer.close() - if self._complete is not None: - self._complete() - - class MultiPartStream(Stream): def __init__(self, form: list[tuple[str, str]], files: list[tuple[str, str]], boundary=''): self._form = list(form) diff --git a/tests/test_ahttpx/test_request.py b/tests/test_ahttpx/test_request.py index 6bc9f3d..441ae08 100644 --- a/tests/test_ahttpx/test_request.py +++ b/tests/test_ahttpx/test_request.py @@ -2,16 +2,6 @@ import pytest -class ByteIterator: - def __init__(self, buffer=b""): - self._buffer = buffer - - async def next(self) -> bytes: - buffer = self._buffer - self._buffer = b'' - return buffer - - @pytest.mark.trio async def test_request(): r = ahttpx.Request("GET", "https://example.com") @@ -41,8 +31,7 @@ async def test_request_bytes(): @pytest.mark.trio async def test_request_stream(): - i = ByteIterator(b"Hello, world") - stream = ahttpx.HTTPStream(i.next, None) + stream = ahttpx.ByteStream(b"Hello, world") r = ahttpx.Request("POST", "https://example.com", content=stream) assert repr(r) == "" @@ -50,7 +39,7 @@ async def test_request_stream(): assert r.url == "https://example.com" assert r.headers == { "Host": "example.com", - "Transfer-Encoding": "chunked", + "Content-Length": "12", } assert await r.read() == b"Hello, world" diff --git a/tests/test_ahttpx/test_response.py b/tests/test_ahttpx/test_response.py index 3b2f4bf..18d282d 100644 --- a/tests/test_ahttpx/test_response.py +++ b/tests/test_ahttpx/test_response.py @@ -2,16 +2,6 @@ import pytest -class ByteIterator: - def __init__(self, buffer=b""): - self._buffer = buffer - - async def next(self) -> bytes: - buffer = self._buffer - self._buffer = b'' - return buffer - - @pytest.mark.trio async def test_response(): r = ahttpx.Response(200) @@ -46,13 +36,12 @@ async def test_response_bytes(): @pytest.mark.trio async def test_response_stream(): - i = ByteIterator(b"Hello, world") - stream = ahttpx.HTTPStream(i.next, None) + stream = ahttpx.ByteStream(b"Hello, world") r = ahttpx.Response(200, content=stream) assert repr(r) == "" assert r.headers == { - "Transfer-Encoding": "chunked", + "Content-Length": "12", } assert await r.read() == b"Hello, world" diff --git a/tests/test_httpx/test_request.py b/tests/test_httpx/test_request.py index 47e5c4d..e5a8d3a 100644 --- a/tests/test_httpx/test_request.py +++ b/tests/test_httpx/test_request.py @@ -2,16 +2,6 @@ import pytest -class ByteIterator: - def __init__(self, buffer=b""): - self._buffer = buffer - - def next(self) -> bytes: - buffer = self._buffer - self._buffer = b'' - return buffer - - def test_request(): r = httpx.Request("GET", "https://example.com") @@ -38,8 +28,7 @@ def test_request_bytes(): def test_request_stream(): - i = ByteIterator(b"Hello, world") - stream = httpx.HTTPStream(i.next, None) + stream = httpx.ByteStream(b"Hello, world") r = httpx.Request("POST", "https://example.com", content=stream) assert repr(r) == "" @@ -47,7 +36,7 @@ def test_request_stream(): assert r.url == "https://example.com" assert r.headers == { "Host": "example.com", - "Transfer-Encoding": "chunked", + "Content-Length": "12", } assert r.read() == b"Hello, world" diff --git a/tests/test_httpx/test_response.py b/tests/test_httpx/test_response.py index 94efdce..5d326ed 100644 --- a/tests/test_httpx/test_response.py +++ b/tests/test_httpx/test_response.py @@ -2,16 +2,6 @@ import pytest -class ByteIterator: - def __init__(self, buffer=b""): - self._buffer = buffer - - def next(self) -> bytes: - buffer = self._buffer - self._buffer = b'' - return buffer - - def test_response(): r = httpx.Response(200) @@ -42,13 +32,12 @@ def test_response_bytes(): def test_response_stream(): - i = ByteIterator(b"Hello, world") - stream = httpx.HTTPStream(i.next, None) + stream = httpx.ByteStream(b"Hello, world") r = httpx.Response(200, content=stream) assert repr(r) == "" assert r.headers == { - "Transfer-Encoding": "chunked", + "Content-Length": "12", } assert r.read() == b"Hello, world"