diff --git a/requirements.txt b/requirements.txt index 8fd226d..f441484 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ -e . +trio==0.33.0 + # Build... build==1.2.2 @@ -7,6 +9,7 @@ build==1.2.2 mypy==1.15.0 pytest==8.3.5 pytest-cov==6.1.1 +pytest-trio==0.8.0 # Sync & Async mirroring... unasync==0.6.0 diff --git a/scripts/test b/scripts/test index 1e0812c..491c9c1 100755 --- a/scripts/test +++ b/scripts/test @@ -5,6 +5,8 @@ if [ -d 'venv' ] ; then export PREFIX="venv/bin/" fi -${PREFIX}mypy src/httpx ${PREFIX}mypy src/ahttpx -${PREFIX}pytest --cov src/httpx tests +${PREFIX}pytest tests/test_ahttpx + +# ${PREFIX}mypy src/httpx +# ${PREFIX}pytest tests/test_httpx diff --git a/scripts/unasync b/scripts/unasync index 67d66b5..b0aca05 100755 --- a/scripts/unasync +++ b/scripts/unasync @@ -27,3 +27,51 @@ unasync.unasync_files( ), ] ) + + +unasync.unasync_files( + fpath_list = [ + "tests/test_ahttpx/test_client.py", + "tests/test_ahttpx/test_content.py", + "tests/test_ahttpx/test_headers.py", + "tests/test_ahttpx/test_network.py", + "tests/test_ahttpx/test_parsers.py", + "tests/test_ahttpx/test_pool.py", + "tests/test_ahttpx/test_quickstart.py", + "tests/test_ahttpx/test_request.py", + "tests/test_ahttpx/test_response.py", + "tests/test_ahttpx/test_streams.py", + "tests/test_ahttpx/test_urlencode.py", + "tests/test_ahttpx/test_urls.py", + ], + rules = [ + unasync.Rule( + "tests/test_ahttpx/", + "tests/test_httpx/", + additional_replacements={"ahttpx": "httpx"} + ), + ] +) + + +for path in [ + "tests/test_httpx/test_client.py", + "tests/test_httpx/test_content.py", + "tests/test_httpx/test_headers.py", + "tests/test_httpx/test_network.py", + "tests/test_httpx/test_parsers.py", + "tests/test_httpx/test_pool.py", + "tests/test_httpx/test_quickstart.py", + "tests/test_httpx/test_request.py", + "tests/test_httpx/test_response.py", + "tests/test_httpx/test_streams.py", + "tests/test_httpx/test_urlencode.py", + "tests/test_httpx/test_urls.py", +]: + with open(path, "r") as fin: + lines = fin.readlines() + + lines = [line for line in lines if line != "@pytest.mark.trio\n"] + + with open(path, "w") as fout: + fout.writelines(lines) \ No newline at end of file diff --git a/src/ahttpx/_network.py b/src/ahttpx/_network.py index 957e036..63d7c8f 100644 --- a/src/ahttpx/_network.py +++ b/src/ahttpx/_network.py @@ -1,8 +1,8 @@ -import asyncio import ssl import types import typing +import trio import certifi from ._streams import Stream @@ -13,39 +13,37 @@ class NetworkStream(Stream): def __init__( - self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, address: str = '' + self, trio_stream: trio.abc.Stream, address: str = '' ) -> None: - self._reader = reader - self._writer = writer + self._trio_stream = trio_stream self._address = address - self._tls = False self._closed = False async def read(self, size: int = -1) -> bytes: if size < 0: size = 64 * 1024 - return await self._reader.read(size) + return await self._trio_stream.receive_some(size) async def write(self, buffer: bytes) -> None: - self._writer.write(buffer) - await self._writer.drain() + await self._trio_stream.send_all(buffer) async def close(self) -> None: - if not self._closed: - self._writer.close() - await self._writer.wait_closed() + # Close the NetworkStream. + # If the stream is already closed this is a checkpointed no-op. + try: + await self._trio_stream.aclose() + finally: self._closed = True def __repr__(self): description = "" - description += " TLS" if self._tls else "" description += " CLOSED" if self._closed else "" - return f"" + return f"" def __del__(self): if not self._closed: import warnings - warnings.warn("NetworkStream was garbage collected without being closed.") + warnings.warn(f"{self!r} was garbage collected without being closed.") # Context managed usage... async def __aenter__(self) -> "NetworkStream": @@ -61,13 +59,17 @@ async def __aexit__( class NetworkServer: - def __init__(self, host: str, port: int, server: asyncio.Server): + def __init__(self, host: str, port: int, handler, listeners: list[trio.SocketListener]): self.host = host self.port = port - self._server = server + self._handler = handler + self._listeners = listeners # Context managed usage... async def __aenter__(self) -> "NetworkServer": + self._nursery_manager = trio.open_nursery() + self._nursery = await self._nursery_manager.__aenter__() + self._nursery.start_soon(trio.serve_listeners, self._handler, self._listeners) return self async def __aexit__( @@ -76,8 +78,8 @@ async def __aexit__( exc_value: BaseException | None = None, traceback: types.TracebackType | None = None, ): - self._server.close() - await self._server.wait_closed() + self._nursery.cancel_scope.cancel() + await self._nursery_manager.__aexit__(exc_type, exc_value, traceback) class NetworkBackend: @@ -92,29 +94,42 @@ async def connect(self, host: str, port: int) -> NetworkStream: """ Connect to the given address, returning a Stream instance. """ + # Create the TCP stream address = f"{host}:{port}" - reader, writer = await asyncio.open_connection(host, port) - return NetworkStream(reader, writer, address=address) + trio_stream = await trio.open_tcp_stream(host, port) + return NetworkStream(trio_stream, address=address) async def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream: """ Connect to the given address, returning a Stream instance. """ + # Create the TCP stream address = f"{host}:{port}" - reader, writer = await asyncio.open_connection(host, port) - await writer.start_tls(self._ssl_ctx, server_hostname=hostname) - return NetworkStream(reader, writer, address=address) + trio_stream = await trio.open_tcp_stream(host, port) + + # Establish SSL over TCP + hostname = hostname or host + ssl_stream = trio.SSLStream(trio_stream, ssl_context=self._ssl_ctx, server_hostname=hostname) + await ssl_stream.do_handshake() + + return NetworkStream(ssl_stream, address=address) async def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer: - async def callback(reader, writer): - stream = NetworkStream(reader, writer) - await handler(stream) + async def callback(trio_stream): + stream = NetworkStream(trio_stream, address=f"{host}:{port}") + try: + await handler(stream) + finally: + await stream.close() - server = await asyncio.start_server(callback, host, port) - return NetworkServer(host, port, server) + listeners = await trio.open_tcp_listeners(port=port, host=host) + return NetworkServer(host, port, callback, listeners) + + def __repr__(self): + return f"" -Semaphore = asyncio.Semaphore -Lock = asyncio.Lock -timeout = asyncio.timeout -sleep = asyncio.sleep +Semaphore = trio.Semaphore +Lock = trio.Lock +timeout = trio.move_on_after +sleep = trio.sleep diff --git a/src/ahttpx/_parsers.py b/src/ahttpx/_parsers.py index 8a52a56..440f810 100644 --- a/src/ahttpx/_parsers.py +++ b/src/ahttpx/_parsers.py @@ -224,6 +224,16 @@ async def send_body(self, body: bytes) -> None: # Handle body close self.send_state = State.DONE + async def recv_close(self) -> bool: + # ... + if self.is_closed(): + return True + + if await self.parser.read_eof(): + await self.close() + return True + return False + async def recv_method_line(self) -> tuple[bytes, bytes, bytes]: """ Receive the initial request method line: @@ -463,6 +473,18 @@ async def read(self, size: int) -> bytes: self._push_back(bytes(push_back)) return bytes(buffer) + async def read_eof(self) -> bool: + """ + Attempt to read the closing EOF. + Return True if the stream is EOF, or False otherwise. + """ + if not self._buffer: + chunk = await self._read_some() + if not chunk: + return True + self._push_back(chunk) + return False + async def read_until(self, marker: bytes, max_size: int, exc_text: str) -> bytes: """ Read and return bytes from the stream, delimited by marker. diff --git a/src/ahttpx/_server.py b/src/ahttpx/_server.py index a9103cc..de9179a 100644 --- a/src/ahttpx/_server.py +++ b/src/ahttpx/_server.py @@ -31,7 +31,7 @@ def __init__(self, stream, endpoint): # API entry points... async def handle_requests(self): try: - while not self._parser.is_closed(): + while not await self._parser.recv_close(): method, url, headers = await self._recv_head() stream = HTTPStream(self._recv_body, self._complete) # TODO: Handle endpoint exceptions @@ -43,13 +43,13 @@ async def handle_requests(self): except Exception: logger.error("Internal Server Error", exc_info=True) content = Text("Internal Server Error") - err = Response(code=500, content=content) + err = Response(500, content=content) await self._send_head(err) await self._send_body(err) else: await self._send_head(response) await self._send_body(response) - except Exception: + except BaseException: logger.error("Internal Server Error", exc_info=True) async def close(self): @@ -89,7 +89,7 @@ async def _send_body(self, response: Response): # Start it all over again... async def _complete(self): - await self._parser.complete + await self._parser.complete() self._idle_expiry = time.monotonic() + self._keepalive_duration diff --git a/src/httpx/_network.py b/src/httpx/_network.py index 5ea9bb5..8410f61 100644 --- a/src/httpx/_network.py +++ b/src/httpx/_network.py @@ -160,7 +160,7 @@ def __init__(self, listener: NetworkListener, handler: typing.Callable[[NetworkS self._max_workers = 5 self._executor = None self._thread = None - self._streams = list[NetworkStream] + self._streams: list[NetworkStream] = [] @property def host(self): @@ -176,6 +176,8 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): + for stream in self._streams: + stream.close() self.listener.close() self._executor.shutdown(wait=True) @@ -185,9 +187,11 @@ def _serve(self): def _handler(self, stream): try: + self._streams.append(stream) self.handler(stream) finally: stream.close() + self._streams.remove(stream) class NetworkBackend: diff --git a/src/httpx/_parsers.py b/src/httpx/_parsers.py index 830fccd..b8be024 100644 --- a/src/httpx/_parsers.py +++ b/src/httpx/_parsers.py @@ -224,6 +224,16 @@ def send_body(self, body: bytes) -> None: # Handle body close self.send_state = State.DONE + def recv_close(self) -> bool: + # ... + if self.is_closed(): + return True + + if self.parser.read_eof(): + self.close() + return True + return False + def recv_method_line(self) -> tuple[bytes, bytes, bytes]: """ Receive the initial request method line: @@ -463,6 +473,18 @@ def read(self, size: int) -> bytes: self._push_back(bytes(push_back)) return bytes(buffer) + def read_eof(self) -> bool: + """ + Attempt to read the closing EOF. + Return True if the stream is EOF, or False otherwise. + """ + if not self._buffer: + chunk = self._read_some() + if not chunk: + return True + self._push_back(chunk) + return False + def read_until(self, marker: bytes, max_size: int, exc_text: str) -> bytes: """ Read and return bytes from the stream, delimited by marker. diff --git a/src/httpx/_server.py b/src/httpx/_server.py index 95226d9..14d23fa 100644 --- a/src/httpx/_server.py +++ b/src/httpx/_server.py @@ -31,7 +31,7 @@ def __init__(self, stream, endpoint): # API entry points... def handle_requests(self): try: - while not self._parser.is_closed(): + while not self._parser.recv_close(): method, url, headers = self._recv_head() stream = HTTPStream(self._recv_body, self._complete) # TODO: Handle endpoint exceptions @@ -43,13 +43,13 @@ def handle_requests(self): except Exception: logger.error("Internal Server Error", exc_info=True) content = Text("Internal Server Error") - err = Response(code=500, content=content) + err = Response(500, content=content) self._send_head(err) self._send_body(err) else: self._send_head(response) self._send_body(response) - except Exception: + except BaseException: logger.error("Internal Server Error", exc_info=True) def close(self): @@ -89,7 +89,7 @@ def _send_body(self, response: Response): # Start it all over again... def _complete(self): - self._parser.complete + self._parser.complete() self._idle_expiry = time.monotonic() + self._keepalive_duration diff --git a/tests/__init__.py b/tests/test_ahttpx/__init__.py similarity index 100% rename from tests/__init__.py rename to tests/test_ahttpx/__init__.py diff --git a/tests/test_ahttpx/test_client.py b/tests/test_ahttpx/test_client.py new file mode 100644 index 0000000..f7be2c2 --- /dev/null +++ b/tests/test_ahttpx/test_client.py @@ -0,0 +1,123 @@ +import json +import ahttpx +import pytest + + +async def echo(request): + await request.read() + response = ahttpx.Response(200, content=ahttpx.JSON({ + 'method': request.method, + 'query-params': dict(request.url.params.items()), + 'content-type': request.headers.get('Content-Type'), + 'json': json.loads(request.body) if request.body else None, + })) + return response + + +@pytest.fixture +async def client(): + async with ahttpx.Client() as client: + yield client + + +@pytest.fixture +async def server(): + async with ahttpx.serve_http(echo) as server: + yield server + + +@pytest.mark.trio +async def test_client(client): + assert repr(client) == "" + + +@pytest.mark.trio +async def test_get(): + async with ahttpx.serve_http(echo) as server: + async with ahttpx.Client() as client: + r = await client.get(server.url) + assert r.status_code == 200 + assert r.body == b'{"method":"GET","query-params":{},"content-type":null,"json":null}' + assert r.text == '{"method":"GET","query-params":{},"content-type":null,"json":null}' + + +@pytest.mark.trio +async def test_post(client, server): + data = ahttpx.JSON({"data": 123}) + r = await client.post(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'POST', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +@pytest.mark.trio +async def test_put(client, server): + data = ahttpx.JSON({"data": 123}) + r = await client.put(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'PUT', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +@pytest.mark.trio +async def test_patch(client, server): + data = ahttpx.JSON({"data": 123}) + r = await client.patch(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'PATCH', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +@pytest.mark.trio +async def test_delete(client, server): + r = await client.delete(server.url) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'DELETE', + 'query-params': {}, + 'content-type': None, + 'json': None, + } + + +@pytest.mark.trio +async def test_request(client, server): + r = await client.request("GET", server.url) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'GET', + 'query-params': {}, + 'content-type': None, + 'json': None, + } + + +@pytest.mark.trio +async def test_stream(client, server): + async with await client.stream("GET", server.url) as r: + assert r.status_code == 200 + await r.read() + assert json.loads(r.body) == { + 'method': 'GET', + 'query-params': {}, + 'content-type': None, + 'json': None, + } + + +@pytest.mark.trio +async def test_get_with_invalid_scheme(client): + with pytest.raises(ValueError): + await client.get("nope://www.example.com") diff --git a/tests/test_ahttpx/test_content.py b/tests/test_ahttpx/test_content.py new file mode 100644 index 0000000..f8a1500 --- /dev/null +++ b/tests/test_ahttpx/test_content.py @@ -0,0 +1,290 @@ +import ahttpx +import os +import tempfile +import pytest + + +# HTML + +@pytest.mark.trio +async def test_html(): + html = ahttpx.HTML("Hello, world") + + stream = html.encode() + content_type = html.content_type() + + assert await stream.read() == b'Hello, world' + assert content_type == "text/html; charset='utf-8'" + + +# Text + +@pytest.mark.trio +async def test_text(): + text = ahttpx.Text("Hello, world") + + stream = text.encode() + content_type = text.content_type() + + assert await stream.read() == b'Hello, world' + assert content_type == "text/plain; charset='utf-8'" + + +# JSON + +@pytest.mark.trio +async def test_json(): + data = ahttpx.JSON({'data': 123}) + + stream = data.encode() + content_type = data.content_type() + + assert await stream.read() == b'{"data":123}' + assert content_type == "application/json" + + +# Form + +def test_form(): + f = ahttpx.Form("a=123&a=456&b=789") + assert str(f) == "a=123&a=456&b=789" + assert repr(f) == "
" + assert f.multi_dict() == { + "a": ["123", "456"], + "b": ["789"] + } + + +def test_form_from_dict(): + f = ahttpx.Form({ + "a": ["123", "456"], + "b": "789" + }) + assert str(f) == "a=123&a=456&b=789" + assert repr(f) == "" + assert f.multi_dict() == { + "a": ["123", "456"], + "b": ["789"] + } + + +def test_form_from_list(): + f = ahttpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert str(f) == "a=123&a=456&b=789" + assert repr(f) == "" + assert f.multi_dict() == { + "a": ["123", "456"], + "b": ["789"] + } + + +def test_empty_form(): + f = ahttpx.Form() + assert str(f) == '' + assert repr(f) == "" + assert f.multi_dict() == {} + + +def test_form_accessors(): + f = ahttpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert "a" in f + assert "A" not in f + assert "c" not in f + assert f["a"] == "123" + assert f.get("a") == "123" + assert f.get("nope", default=None) is None + + +def test_form_dict(): + f = ahttpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert list(f.keys()) == ["a", "b"] + assert list(f.values()) == ["123", "789"] + assert list(f.items()) == [("a", "123"), ("b", "789")] + assert list(f) == ["a", "b"] + assert dict(f) == {"a": "123", "b": "789"} + + +def test_form_multidict(): + f = ahttpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert f.get_list("a") == ["123", "456"] + assert f.multi_items() == [("a", "123"), ("a", "456"), ("b", "789")] + assert f.multi_dict() == {"a": ["123", "456"], "b": ["789"]} + + +def test_form_builtins(): + f = ahttpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert len(f) == 2 + assert bool(f) + assert hash(f) + assert f == ahttpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + + +def test_form_copy_operations(): + f = ahttpx.Form([("a", "123"), ("a", "456"), ("b", "789")]) + assert f.copy_set("a", "abc") == ahttpx.Form([("a", "abc"), ("b", "789")]) + assert f.copy_append("a", "abc") == ahttpx.Form([("a", "123"), ("a", "456"), ("a", "abc"), ("b", "789")]) + assert f.copy_remove("a") == ahttpx.Form([("b", "789")]) + + +@pytest.mark.trio +async def test_form_encode(): + form = ahttpx.Form({'email': 'address@example.com'}) + assert form['email'] == "address@example.com" + + stream = form.encode() + content_type = form.content_type() + + assert await stream.read() == b"email=address%40example.com" + assert content_type == "application/x-www-form-urlencoded" + + +# Files + +def test_files(): + f = ahttpx.Files() + assert f.multi_dict() == {} + assert repr(f) == "" + + +def test_files_from_dict(): + f = ahttpx.Files({ + "a": [ + ahttpx.File("123.json"), + ahttpx.File("456.json"), + ], + "b": ahttpx.File("789.json") + }) + assert f.multi_dict() == { + "a": [ + ahttpx.File("123.json"), + ahttpx.File("456.json"), + ], + "b": [ + ahttpx.File("789.json"), + ] + } + assert repr(f) == ( + "), ('a', ), ('b', )]>" + ) + + + +def test_files_from_list(): + f = ahttpx.Files([ + ("a", ahttpx.File("123.json")), + ("a", ahttpx.File("456.json")), + ("b", ahttpx.File("789.json")) + ]) + assert f.multi_dict() == { + "a": [ + ahttpx.File("123.json"), + ahttpx.File("456.json"), + ], + "b": [ + ahttpx.File("789.json"), + ] + } + assert repr(f) == ( + "), ('a', ), ('b', )]>" + ) + + +def test_files_accessors(): + f = ahttpx.Files([ + ("a", ahttpx.File("123.json")), + ("a", ahttpx.File("456.json")), + ("b", ahttpx.File("789.json")) + ]) + assert "a" in f + assert "A" not in f + assert "c" not in f + assert f["a"] == ahttpx.File("123.json") + assert f.get("a") == ahttpx.File("123.json") + assert f.get("nope", default=None) is None + + +def test_files_dict(): + f = ahttpx.Files([ + ("a", ahttpx.File("123.json")), + ("a", ahttpx.File("456.json")), + ("b", ahttpx.File("789.json")) + ]) + assert list(f.keys()) == ["a", "b"] + assert list(f.values()) == [ahttpx.File("123.json"), ahttpx.File("789.json")] + assert list(f.items()) == [("a", ahttpx.File("123.json")), ("b", ahttpx.File("789.json"))] + assert list(f) == ["a", "b"] + assert dict(f) == {"a": ahttpx.File("123.json"), "b": ahttpx.File("789.json")} + + +def test_files_multidict(): + f = ahttpx.Files([ + ("a", ahttpx.File("123.json")), + ("a", ahttpx.File("456.json")), + ("b", ahttpx.File("789.json")) + ]) + assert f.get_list("a") == [ + ahttpx.File("123.json"), + ahttpx.File("456.json"), + ] + assert f.multi_items() == [ + ("a", ahttpx.File("123.json")), + ("a", ahttpx.File("456.json")), + ("b", ahttpx.File("789.json")), + ] + assert f.multi_dict() == { + "a": [ + ahttpx.File("123.json"), + ahttpx.File("456.json"), + ], + "b": [ + ahttpx.File("789.json"), + ] + } + + +def test_files_builtins(): + f = ahttpx.Files([ + ("a", ahttpx.File("123.json")), + ("a", ahttpx.File("456.json")), + ("b", ahttpx.File("789.json")) + ]) + assert len(f) == 2 + assert bool(f) + assert f == ahttpx.Files([ + ("a", ahttpx.File("123.json")), + ("a", ahttpx.File("456.json")), + ("b", ahttpx.File("789.json")), + ]) + + +@pytest.mark.trio +async def test_multipart(): + with tempfile.NamedTemporaryFile() as f: + f.write(b"Hello, world") + f.seek(0) + + multipart = ahttpx.MultiPart( + form={'email': 'me@example.com'}, + files={'upload': ahttpx.File(f.name)}, + boundary='BOUNDARY', + ) + assert multipart.form['email'] == "me@example.com" + assert multipart.files['upload'] == ahttpx.File(f.name) + + fname = os.path.basename(f.name).encode('utf-8') + stream = multipart.encode() + content_type = multipart.content_type() + + content_type == "multipart/form-data; boundary=BOUNDARY" + content = await stream.read() + assert content == ( + b'--BOUNDARY\r\n' + b'Content-Disposition: form-data; name="email"\r\n' + b'\r\n' + b'me@example.com\r\n' + b'--BOUNDARY\r\n' + b'Content-Disposition: form-data; name="upload"; filename="' + fname + b'"\r\n' + b'\r\n' + b'Hello, world\r\n' + b'--BOUNDARY--\r\n' + ) diff --git a/tests/test_ahttpx/test_headers.py b/tests/test_ahttpx/test_headers.py new file mode 100644 index 0000000..ce991e0 --- /dev/null +++ b/tests/test_ahttpx/test_headers.py @@ -0,0 +1,109 @@ +import ahttpx +import pytest + + +def test_headers_from_dict(): + headers = ahttpx.Headers({ + 'Content-Length': '1024', + 'Content-Type': 'text/plain; charset=utf-8', + }) + assert headers['Content-Length'] == '1024' + assert headers['Content-Type'] == 'text/plain; charset=utf-8' + + +def test_headers_from_list(): + headers = ahttpx.Headers([ + ('Location', 'https://www.example.com'), + ('Set-Cookie', 'session_id=3498jj489jhb98jn'), + ]) + assert headers['Location'] == 'https://www.example.com' + assert headers['Set-Cookie'] == 'session_id=3498jj489jhb98jn' + + +def test_header_keys(): + h = ahttpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert list(h.keys()) == ["Accept", "User-Agent"] + + +def test_header_values(): + h = ahttpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert list(h.values()) == ["*/*", "python/httpx"] + + +def test_header_items(): + h = ahttpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert list(h.items()) == [("Accept", "*/*"), ("User-Agent", "python/httpx")] + + +def test_header_get(): + h = ahttpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert h.get("User-Agent") == "python/httpx" + assert h.get("user-agent") == "python/httpx" + assert h.get("missing") is None + + +def test_header_copy_set(): + h = ahttpx.Headers({"Expires": "0"}) + h = h.copy_set("Expires", "Wed, 21 Oct 2015 07:28:00 GMT") + assert h == ahttpx.Headers({"Expires": "Wed, 21 Oct 2015 07:28:00 GMT"}) + + h = ahttpx.Headers({"Expires": "0"}) + h = h.copy_set("expires", "Wed, 21 Oct 2015 07:28:00 GMT") + assert h == ahttpx.Headers({"Expires": "Wed, 21 Oct 2015 07:28:00 GMT"}) + + +def test_header_copy_remove(): + h = ahttpx.Headers({"Accept": "*/*"}) + h = h.copy_remove("Accept") + assert h == ahttpx.Headers({}) + + h = ahttpx.Headers({"Accept": "*/*"}) + h = h.copy_remove("accept") + assert h == ahttpx.Headers({}) + + +def test_header_getitem(): + h = ahttpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert h["User-Agent"] == "python/httpx" + assert h["user-agent"] == "python/httpx" + with pytest.raises(KeyError): + h["missing"] + + +def test_header_contains(): + h = ahttpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert "User-Agent" in h + assert "user-agent" in h + assert "missing" not in h + + +def test_header_bool(): + h = ahttpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert bool(h) + h = ahttpx.Headers() + assert not bool(h) + + +def test_header_iter(): + h = ahttpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert [k for k in h] == ["Accept", "User-Agent"] + + +def test_header_len(): + h = ahttpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert len(h) == 2 + + +def test_header_repr(): + h = ahttpx.Headers({"Accept": "*/*", "User-Agent": "python/httpx"}) + assert repr(h) == "" + + +def test_header_invalid_name(): + with pytest.raises(ValueError): + ahttpx.Headers({"Accept\n": "*/*"}) + + +def test_header_invalid_value(): + with pytest.raises(ValueError): + ahttpx.Headers({"Accept": "*/*\n"}) diff --git a/tests/test_ahttpx/test_network.py b/tests/test_ahttpx/test_network.py new file mode 100644 index 0000000..9c34ac8 --- /dev/null +++ b/tests/test_ahttpx/test_network.py @@ -0,0 +1,104 @@ +import ahttpx +import pytest + + +async def echo(stream): + while buffer := await stream.read(): + await stream.write(buffer) + + +@pytest.fixture +async def server(): + net = ahttpx.NetworkBackend() + async with await net.serve("127.0.0.1", 8080, echo) as server: + yield server + + +def test_network_backend(): + net = ahttpx.NetworkBackend() + assert repr(net) in ["", ""] + + +@pytest.mark.trio +async def test_network_backend_connect(server): + net = ahttpx.NetworkBackend() + stream = await net.connect(server.host, server.port) + try: + assert repr(stream) == f"" + await stream.write(b"Hello, world.") + content = await stream.read() + assert content == b"Hello, world." + finally: + await stream.close() + + +@pytest.mark.trio +async def test_network_backend_context_managed(server): + net = ahttpx.NetworkBackend() + async with await net.connect(server.host, server.port) as stream: + await stream.write(b"Hello, world.") + content = await stream.read() + assert content == b"Hello, world." + assert repr(stream) == f"" + + +@pytest.mark.trio +async def test_network_backend_timeout(server): + net = ahttpx.NetworkBackend() + with ahttpx.timeout(0.0): + with pytest.raises(TimeoutError): + async with await net.connect(server.host, server.port) as stream: + pass + + with ahttpx.timeout(10.0): + async with await net.connect(server.host, server.port) as stream: + pass + + +# >>> net = httpx.NetworkBackend() +# >>> stream = net.connect("dev.encode.io", 80) +# >>> try: +# >>> ... +# >>> finally: +# >>> stream.close() +# >>> stream +# + +# import httpx +# import ssl +# import truststore + +# net = httpx.NetworkBackend() +# ctx = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +# req = b'\r\n'.join([ +# b'GET / HTTP/1.1', +# b'Host: www.example.com', +# b'User-Agent: python/dev', +# b'Connection: close', +# b'', +# ]) + +# # Use a 10 second overall timeout for the entire request/response. +# with timeout(10.0): +# # Use a 3 second timeout for the initial connection. +# with timeout(3.0) as t: +# # Open the connection & establish SSL. +# with net.open_stream("www.example.com", 443) as stream: +# stream.start_tls(ctx, hostname="www.example.com") +# t.cancel() +# # Send the request & read the response. +# stream.write(req) +# buffer = [] +# while part := stream.read(): +# buffer.append(part) +# resp = b''.join(buffer) + + +# def test_fixture(tcp_echo_server): +# host, port = (tcp_echo_server.host, tcp_echo_server.port) + +# net = httpx.NetworkBackend() +# with net.connect(host, port) as stream: +# stream.write(b"123") +# buffer = stream.read() +# assert buffer == b"123" diff --git a/tests/test_ahttpx/test_parsers.py b/tests/test_ahttpx/test_parsers.py new file mode 100644 index 0000000..b2796ee --- /dev/null +++ b/tests/test_ahttpx/test_parsers.py @@ -0,0 +1,771 @@ +import ahttpx +import pytest + + +class TrickleIO(ahttpx.Stream): + def __init__(self, stream: ahttpx.Stream): + self._stream = stream + + async def read(self, size) -> bytes: + return await self._stream.read(1) + + async def write(self, data: bytes) -> None: + await self._stream.write(data) + + async def close(self) -> None: + await self._stream.close() + + +@pytest.mark.trio +async def test_parser(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"POST", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Content-Length", b"23"), + ]) + await p.send_body(b'{"msg": "hello, world"}') + await p.send_body(b'') + + assert stream.input_bytes() == ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + assert stream.output_bytes() == ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 23\r\n" + b"\r\n" + b'{"msg": "hello, world"}' + ) + + protocol, code, reason_phase = await p.recv_status_line() + headers = await p.recv_headers() + body = await p.recv_body() + terminator = await p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b'OK' + assert headers == [ + (b'Content-Length', b'12'), + (b'Content-Type', b'text/plain'), + ] + assert body == b'hello, world' + assert terminator == b'' + + assert not p.is_idle() + await p.complete() + assert p.is_idle() + + +@pytest.mark.trio +async def test_parser_server(): + stream = ahttpx.DuplexStream( + b"GET / HTTP/1.1\r\n" + b"Host: www.example.com\r\n" + b"\r\n" + ) + + p = ahttpx.HTTPParser(stream, mode='SERVER') + method, target, protocol = await p.recv_method_line() + headers = await p.recv_headers() + body = await p.recv_body() + + assert method == b'GET' + assert target == b'/' + assert protocol == b'HTTP/1.1' + assert headers == [ + (b'Host', b'www.example.com'), + ] + assert body == b'' + + await p.send_status_line(b"HTTP/1.1", 200, b"OK") + await p.send_headers([ + (b"Content-Type", b"application/json"), + (b"Content-Length", b"23"), + ]) + await p.send_body(b'{"msg": "hello, world"}') + await p.send_body(b'') + + assert stream.input_bytes() == ( + b"GET / HTTP/1.1\r\n" + b"Host: www.example.com\r\n" + b"\r\n" + ) + assert stream.output_bytes() == ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 23\r\n" + b"\r\n" + b'{"msg": "hello, world"}' + ) + + assert not p.is_idle() + await p.complete() + assert p.is_idle() + + +@pytest.mark.trio +async def test_parser_trickle(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = ahttpx.HTTPParser(TrickleIO(stream), mode='CLIENT') + await p.send_method_line(b"POST", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Content-Length", b"23"), + ]) + await p.send_body(b'{"msg": "hello, world"}') + await p.send_body(b'') + + assert stream.input_bytes() == ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + assert stream.output_bytes() == ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 23\r\n" + b"\r\n" + b'{"msg": "hello, world"}' + ) + + protocol, code, reason_phase = await p.recv_status_line() + headers = await p.recv_headers() + body = await p.recv_body() + terminator = await p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b'OK' + assert headers == [ + (b'Content-Length', b'12'), + (b'Content-Type', b'text/plain'), + ] + assert body == b'hello, world' + assert terminator == b'' + + +@pytest.mark.trio +async def test_parser_transfer_encoding_chunked(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"c\r\n" + b"hello, world\r\n" + b"0\r\n\r\n" + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"POST", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Transfer-Encoding", b"chunked"), + ]) + await p.send_body(b'{"msg": "hello, world"}') + await p.send_body(b'') + + assert stream.input_bytes() == ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"c\r\n" + b"hello, world\r\n" + b"0\r\n\r\n" + ) + assert stream.output_bytes() == ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Type: application/json\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b'17\r\n' + b'{"msg": "hello, world"}\r\n' + b'0\r\n\r\n' + ) + + protocol, code, reason_phase = await p.recv_status_line() + headers = await p.recv_headers() + body = await p.recv_body() + terminator = await p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b'OK' + assert headers == [ + (b'Content-Type', b'text/plain'), + (b'Transfer-Encoding', b'chunked'), + ] + assert body == b'hello, world' + assert terminator == b'' + + +@pytest.mark.trio +async def test_parser_transfer_encoding_chunked_trickle(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"c\r\n" + b"hello, world\r\n" + b"0\r\n\r\n" + ) + + p = ahttpx.HTTPParser(TrickleIO(stream), mode='CLIENT') + await p.send_method_line(b"POST", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Transfer-Encoding", b"chunked"), + ]) + await p.send_body(b'{"msg": "hello, world"}') + await p.send_body(b'') + + assert stream.input_bytes() == ( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: text/plain\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"c\r\n" + b"hello, world\r\n" + b"0\r\n\r\n" + ) + assert stream.output_bytes() == ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Type: application/json\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b'17\r\n' + b'{"msg": "hello, world"}\r\n' + b'0\r\n\r\n' + ) + + protocol, code, reason_phase = await p.recv_status_line() + headers = await p.recv_headers() + body = await p.recv_body() + terminator = await p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b'OK' + assert headers == [ + (b'Content-Type', b'text/plain'), + (b'Transfer-Encoding', b'chunked'), + ] + assert body == b'hello, world' + assert terminator == b'' + + +@pytest.mark.trio +async def test_parser_repr(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Type: application/json\r\n" + b"Content-Length: 23\r\n" + b"\r\n" + b'{"msg": "hello, world"}' + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + assert repr(p) == "" + + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + assert repr(p) == "" + + await p.send_headers([(b"Host", b"example.com")]) + assert repr(p) == "" + + await p.send_body(b'') + assert repr(p) == "" + + await p.recv_status_line() + assert repr(p) == "" + + await p.recv_headers() + assert repr(p) == "" + + await p.recv_body() + assert repr(p) == "" + + await p.recv_body() + assert repr(p) == "" + + await p.complete() + assert repr(p) == "" + + +@pytest.mark.trio +async def test_parser_invalid_transitions(): + stream = ahttpx.DuplexStream() + + with pytest.raises(ahttpx.ProtocolError): + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b'GET', b'/', b'HTTP/1.1') + await p.send_method_line(b'GET', b'/', b'HTTP/1.1') + + with pytest.raises(ahttpx.ProtocolError): + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_headers([]) + + with pytest.raises(ahttpx.ProtocolError): + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_body(b'') + + with pytest.raises(ahttpx.ProtocolError): + reader = ahttpx.ByteStream(b'HTTP/1.1 200 OK\r\n') + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.recv_status_line() + + with pytest.raises(ahttpx.ProtocolError): + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.recv_headers() + + with pytest.raises(ahttpx.ProtocolError): + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.recv_body() + + +@pytest.mark.trio +async def test_parser_invalid_status_line(): + # ... + stream = ahttpx.DuplexStream(b'...') + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([(b"Host", b"example.com")]) + await p.send_body(b'') + + msg = 'Stream closed early reading response status line' + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.recv_status_line() + + # ... + stream = ahttpx.DuplexStream(b'HTTP/1.1' + b'x' * 5000) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([(b"Host", b"example.com")]) + await p.send_body(b'') + + msg = 'Exceeded maximum size reading response status line' + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.recv_status_line() + + # ... + stream = ahttpx.DuplexStream(b'HTTP/1.1' + b'x' * 5000 + b'\r\n') + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([(b"Host", b"example.com")]) + await p.send_body(b'') + + msg = 'Exceeded maximum size reading response status line' + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.recv_status_line() + + +@pytest.mark.trio +async def test_parser_sent_unsupported_protocol(): + # Currently only HTTP/1.1 is supported. + stream = ahttpx.DuplexStream() + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + msg = 'Sent unsupported protocol version' + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.send_method_line(b"GET", b"/", b"HTTP/1.0") + + +@pytest.mark.trio +async def test_parser_recv_unsupported_protocol(): + # Currently only HTTP/1.1 is supported. + stream = ahttpx.DuplexStream(b"HTTP/1.0 200 OK\r\n") + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + msg = 'Received unsupported protocol version' + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.recv_status_line() + + +@pytest.mark.trio +async def test_parser_large_body(): + body = b"x" * 6988 + + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 6988\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + body + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([(b"Host", b"example.com")]) + await p.send_body(b'') + + # Checkout our buffer sizes. + await p.recv_status_line() + await p.recv_headers() + assert len(await p.recv_body()) == 4096 + assert len(await p.recv_body()) == 2892 + assert len(await p.recv_body()) == 0 + +@pytest.mark.trio +async def test_parser_stream_large_body(): + body = b"x" * 6956 + + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"1b2c\r\n" + body + b'\r\n0\r\n\r\n' + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([(b"Host", b"example.com")]) + await p.send_body(b'') + + # Checkout our buffer sizes. + await p.recv_status_line() + await p.recv_headers() + # assert len(p.recv_body()) == 4096 + # assert len(p.recv_body()) == 2860 + assert len(await p.recv_body()) == 6956 + assert len(await p.recv_body()) == 0 + + +@pytest.mark.trio +async def test_parser_not_enough_data_received(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 188\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"truncated" + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([(b"Host", b"example.com")]) + await p.send_body(b'') + + # Checkout our buffer sizes. + await p.recv_status_line() + await p.recv_headers() + await p.recv_body() + msg = 'Not enough data received for declared Content-Length' + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.recv_body() + + +@pytest.mark.trio +async def test_parser_not_enough_data_sent(): + stream = ahttpx.DuplexStream() + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"POST", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Content-Length", b"23"), + ]) + await p.send_body(b'{"msg": "too smol"}') + msg = 'Not enough data sent for declared Content-Length' + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.send_body(b'') + + +@pytest.mark.trio +async def test_parser_too_much_data_sent(): + stream = ahttpx.DuplexStream() + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"POST", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Type", b"application/json"), + (b"Content-Length", b"19"), + ]) + msg = 'Too much data sent for declared Content-Length' + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.send_body(b'{"msg": "too chonky"}') + + +@pytest.mark.trio +async def test_parser_missing_host_header(): + stream = ahttpx.DuplexStream() + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + msg = "Request missing 'Host' header" + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.send_headers([]) + + +@pytest.mark.trio +async def test_client_connection_close(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Connection", b"close"), + ]) + await p.send_body(b'') + + protocol, code, reason_phase = await p.recv_status_line() + headers = await p.recv_headers() + body = await p.recv_body() + terminator = await p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b"OK" + assert headers == [ + (b'Content-Length', b'12'), + (b'Content-Type', b'text/plain'), + ] + assert body == b"hello, world" + assert terminator == b"" + + assert repr(p) == "" + + await p.complete() + assert repr(p) == "" + assert p.is_closed() + + +@pytest.mark.trio +async def test_server_connection_close(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"Connection: close\r\n" + b"\r\n" + b"hello, world" + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([(b"Host", b"example.com")]) + await p.send_body(b'') + + protocol, code, reason_phase = await p.recv_status_line() + headers = await p.recv_headers() + body = await p.recv_body() + terminator = await p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b"OK" + assert headers == [ + (b'Content-Length', b'12'), + (b'Content-Type', b'text/plain'), + (b'Connection', b'close'), + ] + assert body == b"hello, world" + assert terminator == b"" + + assert repr(p) == "" + await p.complete() + assert repr(p) == "" + + +@pytest.mark.trio +async def test_invalid_status_code(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 99 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Connection", b"close"), + ]) + await p.send_body(b'') + + msg = "Received invalid status code" + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.recv_status_line() + + +@pytest.mark.trio +async def test_1xx_status_code(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 103 Early Hints\r\n" + b"Link: ; rel=preload; as=style\r\n" + b"Link: ; rel=preload; as=script\r\n" + b"\r\n" + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: 12\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([(b"Host", b"example.com")]) + await p.send_body(b'') + + protocol, code, reason_phase = await p.recv_status_line() + headers = await p.recv_headers() + + assert protocol == b'HTTP/1.1' + assert code == 103 + assert reason_phase == b'Early Hints' + assert headers == [ + (b'Link', b'; rel=preload; as=style'), + (b'Link', b'; rel=preload; as=script'), + ] + + protocol, code, reason_phase = await p.recv_status_line() + headers = await p.recv_headers() + body = await p.recv_body() + terminator = await p.recv_body() + + assert protocol == b'HTTP/1.1' + assert code == 200 + assert reason_phase == b"OK" + assert headers == [ + (b'Content-Length', b'12'), + (b'Content-Type', b'text/plain'), + ] + assert body == b"hello, world" + assert terminator == b"" + + +@pytest.mark.trio +async def test_received_invalid_content_length(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Content-Length: -999\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"hello, world" + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Connection", b"close"), + ]) + await p.send_body(b'') + + await p.recv_status_line() + msg = "Received invalid Content-Length" + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.recv_headers() + + +@pytest.mark.trio +async def test_sent_invalid_content_length(): + stream = ahttpx.DuplexStream() + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + msg = "Sent invalid Content-Length" + with pytest.raises(ahttpx.ProtocolError, match=msg): + # Limited to 20 digits. + # 100 million terabytes should be enough for anyone. + await p.send_headers([ + (b"Host", b"example.com"), + (b"Content-Length", b"100000000000000000000"), + ]) + + +@pytest.mark.trio +async def test_received_invalid_characters_in_chunk_size(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"0xFF\r\n..." + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Connection", b"close"), + ]) + await p.send_body(b'') + + await p.recv_status_line() + await p.recv_headers() + msg = "Received invalid chunk size" + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.recv_body() + + +@pytest.mark.trio +async def test_received_oversized_chunk(): + stream = ahttpx.DuplexStream( + b"HTTP/1.1 200 OK\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + b"FFFFFFFFFF\r\n..." + ) + + p = ahttpx.HTTPParser(stream, mode='CLIENT') + await p.send_method_line(b"GET", b"/", b"HTTP/1.1") + await p.send_headers([ + (b"Host", b"example.com"), + (b"Connection", b"close"), + ]) + await p.send_body(b'') + + await p.recv_status_line() + await p.recv_headers() + msg = "Received invalid chunk size" + with pytest.raises(ahttpx.ProtocolError, match=msg): + await p.recv_body() diff --git a/tests/test_ahttpx/test_pool.py b/tests/test_ahttpx/test_pool.py new file mode 100644 index 0000000..072da67 --- /dev/null +++ b/tests/test_ahttpx/test_pool.py @@ -0,0 +1,133 @@ +import ahttpx +import pytest + + +async def hello_world(request): + content = ahttpx.Text('Hello, world.') + return ahttpx.Response(200, content=content) + + +@pytest.fixture +async def server(): + async with ahttpx.serve_http(hello_world) as server: + yield server + + +@pytest.mark.trio +async def test_connection_pool_request(server): + async with ahttpx.ConnectionPool() as pool: + assert repr(pool) == "" + assert len(pool.connections) == 0 + + r = await pool.request("GET", server.url) + + assert r.status_code == 200 + assert repr(pool) == "" + assert len(pool.connections) == 1 + + +@pytest.mark.trio +async def test_connection_pool_connection_close(server): + async with ahttpx.ConnectionPool() as pool: + assert repr(pool) == "" + assert len(pool.connections) == 0 + + r = await pool.request("GET", server.url, headers={"Connection": "close"}) + + # TODO: Really we want closed connections proactively removed from the pool, + assert r.status_code == 200 + assert repr(pool) == "" + assert len(pool.connections) == 1 + + +@pytest.mark.trio +async def test_connection_pool_stream(server): + async with ahttpx.ConnectionPool() as pool: + assert repr(pool) == "" + assert len(pool.connections) == 0 + + async with await pool.stream("GET", server.url) as r: + assert r.status_code == 200 + assert repr(pool) == "" + assert len(pool.connections) == 1 + await r.read() + + assert repr(pool) == "" + assert len(pool.connections) == 1 + + +@pytest.mark.trio +async def test_connection_pool_cannot_request_after_closed(server): + async with ahttpx.ConnectionPool() as pool: + pool + + with pytest.raises(RuntimeError): + await pool.request("GET", server.url) + + +@pytest.mark.trio +async def test_connection_pool_should_have_managed_lifespan(server): + pool = ahttpx.ConnectionPool() + with pytest.warns(UserWarning): + del pool + + +@pytest.mark.trio +async def test_connection_request(server): + async with await ahttpx.open_connection(server.url) as conn: + assert repr(conn) == f"" + + r = await conn.request("GET", "/") + + assert r.status_code == 200 + assert repr(conn) == f"" + + +@pytest.mark.trio +async def test_connection_stream(server): + async with await ahttpx.open_connection(server.url) as conn: + assert repr(conn) == f"" + async with await conn.stream("GET", "/") as r: + assert r.status_code == 200 + assert repr(conn) == f"" + await r.read() + assert repr(conn) == f"" + + +# # with httpx.open_connection("https://www.example.com/") as conn: +# # r = conn.request("GET", "/") + +# # >>> pool = httpx.ConnectionPool() +# # >>> pool +# # + +# # >>> with httpx.open_connection_pool() as pool: +# # >>> res = pool.request("GET", "https://www.example.com") +# # >>> res, pool +# # , + +# # >>> with httpx.open_connection_pool() as pool: +# # >>> with pool.stream("GET", "https://www.example.com") as res: +# # >>> res, pool +# # , + +# # >>> with httpx.open_connection_pool() as pool: +# # >>> req = httpx.Request("GET", "https://www.example.com") +# # >>> with pool.send(req) as res: +# # >>> res.body() +# # >>> res, pool +# # , + +# # >>> with httpx.open_connection_pool() as pool: +# # >>> pool.close() +# # + +# # with httpx.open_connection("https://www.example.com/") as conn: +# # with conn.upgrade("GET", "/feed", {"Upgrade": "WebSocket") as stream: +# # ... + +# # with httpx.open_connection("http://127.0.0.1:8080") as conn: +# # with conn.upgrade("CONNECT", "www.encode.io:443") as stream: +# # stream.start_tls(ctx, hostname="www.encode.io") +# # ... + diff --git a/tests/test_ahttpx/test_quickstart.py b/tests/test_ahttpx/test_quickstart.py new file mode 100644 index 0000000..ef3963c --- /dev/null +++ b/tests/test_ahttpx/test_quickstart.py @@ -0,0 +1,83 @@ +import json +import ahttpx +import pytest + + +async def echo(request): + await request.read() + response = ahttpx.Response(200, content=ahttpx.JSON({ + 'method': request.method, + 'query-params': dict(request.url.params.items()), + 'content-type': request.headers.get('Content-Type'), + 'json': json.loads(request.body) if request.body else None, + })) + return response + + +@pytest.fixture +async def server(): + async with ahttpx.serve_http(echo) as server: + yield server + + +@pytest.mark.trio +async def test_get(server): + r = await ahttpx.get(server.url) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'GET', + 'query-params': {}, + 'content-type': None, + 'json': None, + } + + +@pytest.mark.trio +async def test_post(server): + data = ahttpx.JSON({"data": 123}) + r = await ahttpx.post(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'POST', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +@pytest.mark.trio +async def test_put(server): + data = ahttpx.JSON({"data": 123}) + r = await ahttpx.put(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'PUT', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +@pytest.mark.trio +async def test_patch(server): + data = ahttpx.JSON({"data": 123}) + r = await ahttpx.patch(server.url, content=data) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'PATCH', + 'query-params': {}, + 'content-type': 'application/json', + 'json': {"data": 123}, + } + + +@pytest.mark.trio +async def test_delete(server): + r = await ahttpx.delete(server.url) + assert r.status_code == 200 + assert json.loads(r.body) == { + 'method': 'DELETE', + 'query-params': {}, + 'content-type': None, + 'json': None, + } diff --git a/tests/test_ahttpx/test_request.py b/tests/test_ahttpx/test_request.py new file mode 100644 index 0000000..6bc9f3d --- /dev/null +++ b/tests/test_ahttpx/test_request.py @@ -0,0 +1,85 @@ +import ahttpx +import pytest + + +class ByteIterator: + def __init__(self, buffer=b""): + self._buffer = buffer + + async def next(self) -> bytes: + buffer = self._buffer + self._buffer = b'' + return buffer + + +@pytest.mark.trio +async def test_request(): + r = ahttpx.Request("GET", "https://example.com") + + assert repr(r) == "" + assert r.method == "GET" + assert r.url == "https://example.com" + assert r.headers == { + "Host": "example.com" + } + assert await r.read() == b"" + +@pytest.mark.trio +async def test_request_bytes(): + content = b"Hello, world" + r = ahttpx.Request("POST", "https://example.com", content=content) + + assert repr(r) == "" + assert r.method == "POST" + assert r.url == "https://example.com" + assert r.headers == { + "Host": "example.com", + "Content-Length": "12", + } + assert await r.read() == b"Hello, world" + + +@pytest.mark.trio +async def test_request_stream(): + i = ByteIterator(b"Hello, world") + stream = ahttpx.HTTPStream(i.next, None) + r = ahttpx.Request("POST", "https://example.com", content=stream) + + assert repr(r) == "" + assert r.method == "POST" + assert r.url == "https://example.com" + assert r.headers == { + "Host": "example.com", + "Transfer-Encoding": "chunked", + } + assert await r.read() == b"Hello, world" + + +@pytest.mark.trio +async def test_request_json(): + data = ahttpx.JSON({"msg": "Hello, world"}) + r = ahttpx.Request("POST", "https://example.com", content=data) + + assert repr(r) == "" + assert r.method == "POST" + assert r.url == "https://example.com" + assert r.headers == { + "Host": "example.com", + "Content-Length": "22", + "Content-Type": "application/json", + } + assert await r.read() == b'{"msg":"Hello, world"}' + + +@pytest.mark.trio +async def test_request_empty_post(): + r = ahttpx.Request("POST", "https://example.com") + + assert repr(r) == "" + assert r.method == "POST" + assert r.url == "https://example.com" + assert r.headers == { + "Host": "example.com", + "Content-Length": "0", + } + assert await r.read() == b'' diff --git a/tests/test_ahttpx/test_response.py b/tests/test_ahttpx/test_response.py new file mode 100644 index 0000000..3b2f4bf --- /dev/null +++ b/tests/test_ahttpx/test_response.py @@ -0,0 +1,70 @@ +import ahttpx +import pytest + + +class ByteIterator: + def __init__(self, buffer=b""): + self._buffer = buffer + + async def next(self) -> bytes: + buffer = self._buffer + self._buffer = b'' + return buffer + + +@pytest.mark.trio +async def test_response(): + r = ahttpx.Response(200) + + assert repr(r) == "" + assert r.status_code == 200 + assert r.headers == {'Content-Length': '0'} + assert await r.read() == b"" + + +@pytest.mark.trio +async def test_response_204(): + r = ahttpx.Response(204) + + assert repr(r) == "" + assert r.status_code == 204 + assert r.headers == {} + assert await r.read() == b"" + + +@pytest.mark.trio +async def test_response_bytes(): + content = b"Hello, world" + r = ahttpx.Response(200, content=content) + + assert repr(r) == "" + assert r.headers == { + "Content-Length": "12", + } + assert await r.read() == b"Hello, world" + + +@pytest.mark.trio +async def test_response_stream(): + i = ByteIterator(b"Hello, world") + stream = ahttpx.HTTPStream(i.next, None) + r = ahttpx.Response(200, content=stream) + + assert repr(r) == "" + assert r.headers == { + "Transfer-Encoding": "chunked", + } + assert await r.read() == b"Hello, world" + + +@pytest.mark.trio +async def test_response_json(): + data = ahttpx.JSON({"msg": "Hello, world"}) + r = ahttpx.Response(200, content=data) + + assert repr(r) == "" + assert r.headers == { + "Content-Length": "22", + "Content-Type": "application/json", + } + assert await r.read() == b'{"msg":"Hello, world"}' diff --git a/tests/test_ahttpx/test_streams.py b/tests/test_ahttpx/test_streams.py new file mode 100644 index 0000000..b898ac5 --- /dev/null +++ b/tests/test_ahttpx/test_streams.py @@ -0,0 +1,85 @@ +import pytest +import ahttpx + + +@pytest.mark.trio +async def test_stream(): + i = ahttpx.Stream() + with pytest.raises(NotImplementedError): + await i.read() + + with pytest.raises(NotImplementedError): + await i.close() + + i.size == None + + +@pytest.mark.trio +async def test_bytestream(): + data = b'abc' + s = ahttpx.ByteStream(data) + assert s.size == 3 + assert await s.read() == b'abc' + + s = ahttpx.ByteStream(data) + assert await s.read(1) == b'a' + assert await s.read(1) == b'b' + assert await s.read(1) == b'c' + assert await s.read(1) == b'' + + +@pytest.mark.trio +async def test_filestream(tmp_path): + path = tmp_path / "example.txt" + path.write_bytes(b"hello world") + + async with ahttpx.FileStream(path) as s: + assert s.size == 11 + assert await s.read() == b'hello world' + + async with ahttpx.FileStream(path) as s: + assert await s.read(5) == b'hello' + assert await s.read(5) == b' worl' + assert await s.read(5) == b'd' + assert await s.read(5) == b'' + + async with ahttpx.FileStream(path) as s: + assert await s.read(5) == b'hello' + + +@pytest.mark.trio +async def test_multipartstream(tmp_path): + path = tmp_path / 'example.txt' + path.write_bytes(b'hello world' + b'x' * 50) + + expected = b''.join([ + b'--boundary\r\n', + b'Content-Disposition: form-data; name="email"\r\n', + b'\r\n', + b'heya@example.com\r\n', + b'--boundary\r\n', + b'Content-Disposition: form-data; name="upload"; filename="example.txt"\r\n', + b'\r\n', + b'hello world' + ( b'x' * 50) + b'\r\n', + b'--boundary--\r\n', + ]) + + form = [('email', 'heya@example.com')] + files = [('upload', str(path))] + async with ahttpx.MultiPartStream(form, files, boundary='boundary') as s: + assert s.size is None + assert await s.read() == expected + + async with ahttpx.MultiPartStream(form, files, boundary='boundary') as s: + assert await s.read(50) == expected[:50] + assert await s.read(50) == expected[50:100] + assert await s.read(50) == expected[100:150] + assert await s.read(50) == expected[150:200] + assert await s.read(50) == expected[200:250] + + async with ahttpx.MultiPartStream(form, files, boundary='boundary') as s: + assert await s.read(50) == expected[:50] + assert await s.read(50) == expected[50:100] + assert await s.read(50) == expected[100:150] + assert await s.read(50) == expected[150:200] + await s.close() # test close during open file diff --git a/tests/test_ahttpx/test_urlencode.py b/tests/test_ahttpx/test_urlencode.py new file mode 100644 index 0000000..1c6afbd --- /dev/null +++ b/tests/test_ahttpx/test_urlencode.py @@ -0,0 +1,33 @@ +import ahttpx + + +def test_urlencode(): + qs = "a=name%40example.com&a=456&b=7+8+9&c" + d = ahttpx.urldecode(qs) + assert d == { + "a": ["name@example.com", "456"], + "b": ["7 8 9"], + "c": [""] + } + + +def test_urldecode(): + d = { + "a": ["name@example.com", "456"], + "b": ["7 8 9"], + "c": [""] + } + qs = ahttpx.urlencode(d) + assert qs == "a=name%40example.com&a=456&b=7+8+9&c=" + + +def test_urlencode_empty(): + qs = "" + d = ahttpx.urldecode(qs) + assert d == {} + + +def test_urldecode_empty(): + d = {} + qs = ahttpx.urlencode(d) + assert qs == "" diff --git a/tests/test_ahttpx/test_urls.py b/tests/test_ahttpx/test_urls.py new file mode 100644 index 0000000..354ec3c --- /dev/null +++ b/tests/test_ahttpx/test_urls.py @@ -0,0 +1,164 @@ +import ahttpx +import pytest + + +def test_url(): + url = ahttpx.URL('https://www.example.com/') + assert str(url) == "https://www.example.com/" + + +def test_url_repr(): + url = ahttpx.URL('https://www.example.com/') + assert repr(url) == "" + + +def test_url_params(): + url = ahttpx.URL('https://www.example.com/', params={"a": "b", "c": "d"}) + assert str(url) == "https://www.example.com/?a=b&c=d" + + +def test_url_normalisation(): + url = ahttpx.URL('https://www.EXAMPLE.com:443/path/../main') + assert str(url) == 'https://www.example.com/main' + + +def test_url_relative(): + url = ahttpx.URL('/README.md') + assert str(url) == '/README.md' + + +def test_url_escaping(): + url = ahttpx.URL('https://example.com/path to here?search=🦋') + assert str(url) == 'https://example.com/path%20to%20here?search=%F0%9F%A6%8B' + + +def test_url_components(): + url = ahttpx.URL(scheme="https", host="example.com", path="/") + assert str(url) == 'https://example.com/' + + +# QueryParams + +def test_queryparams(): + params = ahttpx.QueryParams({"color": "black", "size": "medium"}) + assert str(params) == 'color=black&size=medium' + + +def test_queryparams_repr(): + params = ahttpx.QueryParams({"color": "black", "size": "medium"}) + assert repr(params) == "" + + +def test_queryparams_list_of_values(): + params = ahttpx.QueryParams({"filter": ["60GHz", "75GHz", "100GHz"]}) + assert str(params) == 'filter=60GHz&filter=75GHz&filter=100GHz' + + +def test_queryparams_from_str(): + params = ahttpx.QueryParams("color=black&size=medium") + assert str(params) == 'color=black&size=medium' + + +def test_queryparams_access(): + params = ahttpx.QueryParams("sort_by=published&author=natalie") + assert params["sort_by"] == 'published' + + +def test_queryparams_escaping(): + params = ahttpx.QueryParams({"email": "user@example.com", "search": "How HTTP works!"}) + assert str(params) == 'email=user%40example.com&search=How+HTTP+works%21' + + +def test_queryparams_empty(): + q = ahttpx.QueryParams({"a": ""}) + assert str(q) == "a=" + + q = ahttpx.QueryParams("a=") + assert str(q) == "a=" + + q = ahttpx.QueryParams("a") + assert str(q) == "a=" + + +def test_queryparams_set(): + q = ahttpx.QueryParams("a=123") + q = q.copy_set("a", "456") + assert q == ahttpx.QueryParams("a=456") + + +def test_queryparams_append(): + q = ahttpx.QueryParams("a=123") + q = q.copy_append("a", "456") + assert q == ahttpx.QueryParams("a=123&a=456") + + +def test_queryparams_remove(): + q = ahttpx.QueryParams("a=123") + q = q.copy_remove("a") + assert q == ahttpx.QueryParams("") + + +def test_queryparams_merge(): + q = ahttpx.QueryParams("a=123") + q = q.copy_update({"b": "456"}) + assert q == ahttpx.QueryParams("a=123&b=456") + q = q.copy_update({"a": "000", "c": "789"}) + assert q == ahttpx.QueryParams("a=000&b=456&c=789") + + +def test_queryparams_are_hashable(): + params = ( + ahttpx.QueryParams("a=123"), + ahttpx.QueryParams({"a": "123"}), + ahttpx.QueryParams("b=456"), + ahttpx.QueryParams({"b": "456"}), + ) + + assert len(set(params)) == 2 + + +@pytest.mark.parametrize( + "source", + [ + "a=123&a=456&b=789", + {"a": ["123", "456"], "b": "789"}, + {"a": ("123", "456"), "b": "789"}, + [("a", "123"), ("a", "456"), ("b", "789")], + (("a", "123"), ("a", "456"), ("b", "789")), + ], +) +def test_queryparams_misc(source): + q = ahttpx.QueryParams(source) + assert "a" in q + assert "A" not in q + assert "c" not in q + assert q["a"] == "123" + assert q.get("a") == "123" + assert q.get("nope", default=None) is None + assert q.get_list("a") == ["123", "456"] + assert bool(q) + + assert list(q.keys()) == ["a", "b"] + assert list(q.values()) == ["123", "789"] + assert list(q.items()) == [("a", "123"), ("b", "789")] + assert len(q) == 2 + assert list(q) == ["a", "b"] + assert dict(q) == {"a": "123", "b": "789"} + assert str(q) == "a=123&a=456&b=789" + assert ahttpx.QueryParams({"a": "123", "b": "456"}) == ahttpx.QueryParams( + [("a", "123"), ("b", "456")] + ) + assert ahttpx.QueryParams({"a": "123", "b": "456"}) == ahttpx.QueryParams( + "a=123&b=456" + ) + assert ahttpx.QueryParams({"a": "123", "b": "456"}) == ahttpx.QueryParams( + {"b": "456", "a": "123"} + ) + assert ahttpx.QueryParams() == ahttpx.QueryParams({}) + assert ahttpx.QueryParams([("a", "123"), ("a", "456")]) == ahttpx.QueryParams( + "a=123&a=456" + ) + assert ahttpx.QueryParams({"a": "123", "b": "456"}) != "invalid" + + q = ahttpx.QueryParams([("a", "123"), ("a", "456")]) + assert ahttpx.QueryParams(q) == q diff --git a/tests/test_httpx/__init__.py b/tests/test_httpx/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_client.py b/tests/test_httpx/test_client.py similarity index 87% rename from tests/test_client.py rename to tests/test_httpx/test_client.py index c26f6ba..6aa76f5 100644 --- a/tests/test_client.py +++ b/tests/test_httpx/test_client.py @@ -30,11 +30,13 @@ def test_client(client): assert repr(client) == "" -def test_get(client, server): - r = client.get(server.url) - assert r.status_code == 200 - assert r.body == b'{"method":"GET","query-params":{},"content-type":null,"json":null}' - assert r.text == '{"method":"GET","query-params":{},"content-type":null,"json":null}' +def test_get(): + with httpx.serve_http(echo) as server: + with httpx.Client() as client: + r = client.get(server.url) + assert r.status_code == 200 + assert r.body == b'{"method":"GET","query-params":{},"content-type":null,"json":null}' + assert r.text == '{"method":"GET","query-params":{},"content-type":null,"json":null}' def test_post(client, server): diff --git a/tests/test_content.py b/tests/test_httpx/test_content.py similarity index 99% rename from tests/test_content.py rename to tests/test_httpx/test_content.py index ae3158e..d5fc3de 100644 --- a/tests/test_content.py +++ b/tests/test_httpx/test_content.py @@ -1,7 +1,7 @@ import httpx import os import tempfile - +import pytest # HTML diff --git a/tests/test_headers.py b/tests/test_httpx/test_headers.py similarity index 100% rename from tests/test_headers.py rename to tests/test_httpx/test_headers.py diff --git a/tests/test_network.py b/tests/test_httpx/test_network.py similarity index 97% rename from tests/test_network.py rename to tests/test_httpx/test_network.py index e6ce925..59b66f5 100644 --- a/tests/test_network.py +++ b/tests/test_httpx/test_network.py @@ -16,7 +16,7 @@ def server(): def test_network_backend(): net = httpx.NetworkBackend() - assert repr(net) == "" + assert repr(net) in ["", ""] def test_network_backend_connect(server): diff --git a/tests/test_parsers.py b/tests/test_httpx/test_parsers.py similarity index 99% rename from tests/test_parsers.py rename to tests/test_httpx/test_parsers.py index e2a321e..62260d9 100644 --- a/tests/test_parsers.py +++ b/tests/test_httpx/test_parsers.py @@ -430,7 +430,6 @@ def test_parser_large_body(): assert len(p.recv_body()) == 2892 assert len(p.recv_body()) == 0 - def test_parser_stream_large_body(): body = b"x" * 6956 diff --git a/tests/test_pool.py b/tests/test_httpx/test_pool.py similarity index 100% rename from tests/test_pool.py rename to tests/test_httpx/test_pool.py diff --git a/tests/test_quickstart.py b/tests/test_httpx/test_quickstart.py similarity index 100% rename from tests/test_quickstart.py rename to tests/test_httpx/test_quickstart.py diff --git a/tests/test_request.py b/tests/test_httpx/test_request.py similarity index 99% rename from tests/test_request.py rename to tests/test_httpx/test_request.py index a69e1d1..47e5c4d 100644 --- a/tests/test_request.py +++ b/tests/test_httpx/test_request.py @@ -1,10 +1,11 @@ import httpx +import pytest class ByteIterator: def __init__(self, buffer=b""): self._buffer = buffer - + def next(self) -> bytes: buffer = self._buffer self._buffer = b'' diff --git a/tests/test_response.py b/tests/test_httpx/test_response.py similarity index 99% rename from tests/test_response.py rename to tests/test_httpx/test_response.py index d25ebeb..94efdce 100644 --- a/tests/test_response.py +++ b/tests/test_httpx/test_response.py @@ -1,4 +1,5 @@ import httpx +import pytest class ByteIterator: diff --git a/tests/test_streams.py b/tests/test_httpx/test_streams.py similarity index 99% rename from tests/test_streams.py rename to tests/test_httpx/test_streams.py index 8053761..41ae812 100644 --- a/tests/test_streams.py +++ b/tests/test_httpx/test_streams.py @@ -44,7 +44,6 @@ def test_filestream(tmp_path): assert s.read(5) == b'hello' - def test_multipartstream(tmp_path): path = tmp_path / 'example.txt' path.write_bytes(b'hello world' + b'x' * 50) diff --git a/tests/test_urlencode.py b/tests/test_httpx/test_urlencode.py similarity index 100% rename from tests/test_urlencode.py rename to tests/test_httpx/test_urlencode.py diff --git a/tests/test_urls.py b/tests/test_httpx/test_urls.py similarity index 100% rename from tests/test_urls.py rename to tests/test_httpx/test_urls.py