diff --git a/CHANGES/12744.feature.rst b/CHANGES/12744.feature.rst new file mode 100644 index 00000000000..9f2f116f71f --- /dev/null +++ b/CHANGES/12744.feature.rst @@ -0,0 +1,3 @@ +Added ``aiofastnet`` package to ``speedups`` extra. aiofastnet provides faster alternatives to the standard loop functions, which are used to run server or establish connections. If you experience any issues that you think might be related to this change, you can try to disable ``aiofastnet`` by uninstalling aiofastnet package. + +-- by :user:`tarasko`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 2aaa0a02403..837b8456f9d 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -367,6 +367,7 @@ Sunit Deshpande Sviatoslav Bulbakha Sviatoslav Sydorenko Taha Jahangir +Taras Kozlov Taras Voinarovskyi Terence Honles Thanos Lefteris diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 6e70b3a28a2..fdec1d50fe4 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -1,5 +1,6 @@ import asyncio import functools +import importlib import random import socket import sys @@ -18,6 +19,13 @@ from aiohappyeyeballs import AddrInfoType, SocketFactoryType from multidict import CIMultiDict +aiofastnet: Any +try: + aiofastnet = importlib.import_module("aiofastnet") +except ImportError: + aiofastnet = None + + from . import hdrs, helpers from .abc import AbstractResolver, ResolveResult from .client_exceptions import ( @@ -96,6 +104,24 @@ from .tracing import Trace +async def create_connection( + loop: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any, +) -> tuple[asyncio.Transport, ResponseHandler]: + if aiofastnet is not None: + return await aiofastnet.create_connection(loop, *args, **kwargs) # type: ignore[no-any-return] + else: + return await loop.create_connection(*args, **kwargs) + + +async def start_tls( + loop: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any +) -> asyncio.BaseTransport | None: + if aiofastnet is not None: + return await aiofastnet.start_tls(loop, *args, **kwargs) # type: ignore[no-any-return] + else: + return await loop.start_tls(*args, **kwargs) + + class Connection: """Represents a single connection.""" @@ -1259,7 +1285,7 @@ async def _wrap_create_connection( and sys.version_info >= (3, 11) ): kwargs["ssl_shutdown_timeout"] = self._ssl_shutdown_timeout - return await self._loop.create_connection(*args, **kwargs, sock=sock) + return await create_connection(self._loop, *args, **kwargs, sock=sock) except cert_errors as exc: raise ClientConnectorCertificateError(req.connection_key, exc) from exc except ssl_errors as exc: @@ -1340,7 +1366,8 @@ async def _start_tls_connection( try: # ssl_shutdown_timeout is only available in Python 3.11+ if sys.version_info >= (3, 11) and self._ssl_shutdown_timeout: - tls_transport = await self._loop.start_tls( + tls_transport = await start_tls( + self._loop, underlying_transport, tls_proto, sslcontext, @@ -1349,7 +1376,8 @@ async def _start_tls_connection( ssl_shutdown_timeout=self._ssl_shutdown_timeout, ) else: - tls_transport = await self._loop.start_tls( + tls_transport = await start_tls( + self._loop, underlying_transport, tls_proto, sslcontext, diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index d45afc0dd7d..addd420aa47 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -1,4 +1,5 @@ import asyncio +import importlib import io import os import pathlib @@ -11,6 +12,13 @@ from types import MappingProxyType from typing import IO, TYPE_CHECKING, Any, Final, Optional +aiofastnet: Any +try: + aiofastnet = importlib.import_module("aiofastnet") +except ImportError: + aiofastnet = None + + from . import hdrs from .abc import AbstractStreamWriter from .helpers import DEFAULT_CHUNK_SIZE, ETAG_ANY, ETag, must_be_empty_body @@ -34,6 +42,15 @@ _T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]] +async def sendfile( + loop: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any +) -> int: + if aiofastnet is not None: + return await aiofastnet.sendfile(loop, *args, **kwargs) # type: ignore[no-any-return] + else: + return await loop.sendfile(*args, **kwargs) + + NOSENDFILE: Final[bool] = bool(os.environ.get("AIOHTTP_NOSENDFILE")) CONTENT_TYPES: Final[MimeTypes] = MimeTypes() @@ -132,7 +149,7 @@ async def _sendfile( raise ConnectionResetError("Connection lost") try: - await loop.sendfile(transport, fobj, offset, count) + await sendfile(loop, transport, fobj, offset, count) except NotImplementedError: return await self._sendfile_fallback(writer, fobj, offset, count) diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 82c3bd277f8..e192fe565ff 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -1,4 +1,5 @@ import asyncio +import importlib import signal import socket from abc import ABC, abstractmethod @@ -6,6 +7,12 @@ from yarl import URL +aiofastnet: Any +try: + aiofastnet = importlib.import_module("aiofastnet") +except ImportError: + aiofastnet = None + from .abc import AbstractAccessLogger, AbstractStreamWriter from .http_parser import RawRequestMessage from .streams import StreamReader @@ -21,6 +28,16 @@ except ImportError: # pragma: no cover SSLContext = object # type: ignore[misc,assignment] + +async def create_server( + loop: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any +) -> asyncio.Server: + if aiofastnet is not None: + return await aiofastnet.create_server(loop, *args, **kwargs) # type: ignore[no-any-return] + else: + return await loop.create_server(*args, **kwargs) + + __all__ = ( "BaseSite", "TCPSite", @@ -130,7 +147,8 @@ async def start(self) -> None: loop = asyncio.get_running_loop() server = self._runner.server assert server is not None - self._server = await loop.create_server( + self._server = await create_server( + loop, server, self._host, self._port, @@ -244,8 +262,8 @@ async def start(self) -> None: loop = asyncio.get_running_loop() server = self._runner.server assert server is not None - self._server = await loop.create_server( - server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog + self._server = await create_server( + loop, server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog ) diff --git a/docs/faq.rst b/docs/faq.rst index 3f50b855588..b4276cf1843 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -263,6 +263,52 @@ enable compression in NGINX (you are deploying aiohttp behind reverse proxy, right?). +How do I enable Kernel TLS, and should I do it? +----------------------------------------------- + +Kernel TLS (KTLS) allows aiohttp to move encryption and decryption of +TLS traffic from user space to the kernel. It was added to the Linux kernel in +4.13, but full support for TLS 1.3 and modern ciphers is available only +since 5.19. + +KTLS will be beneficial if you run an HTTPS server that often returns +:class:`~aiohttp.web.FileResponse` objects or you have a high-end NIC that can +offload TLS encryption. For ordinary +dynamic responses, small files, or deployments behind a TLS-terminating reverse +proxy, it is unlikely to help and may actually slightly degrade performance. + +KTLS is supported through the ``aiofastnet`` package, which is installed as +part of the ``speedups`` extra. + +To enable KTLS, you have to do and check the following: + +* Make sure the Linux ``tls`` kernel module is loaded:: + + sudo modprobe tls + +* Make sure the ``ssl.OP_ENABLE_KTLS`` option is enabled in ``SSLContext`` + (available since Python 3.12):: + + sslcontext.options |= ssl.OP_ENABLE_KTLS + +* Make sure Python is using OpenSSL 3.0 or newer. OpenSSL should have been + built on a machine whose Linux headers are new enough. OpenSSL needs Linux + headers at least 4.13.0 to build the transmit path; older headers make it + skip KTLS support. Typically, Python is using the system OpenSSL on Linux, + but some times distributions ship their own OpenSSL. The following commands + will help identify the OpenSSL version and which ``libssl`` and ``libcrypto`` + are being used by the ``ssl`` module:: + + python -c "import ssl; print(ssl.OPENSSL_VERSION)" + ldd "$(python -c 'import _ssl; print(_ssl.__file__)')" + + +If ``ssl.OP_ENABLE_KTLS`` was requested in ``sslcontext``, but ``aiofastnet`` +could not enable KTLS, it will log a warning suggesting the possible reason. + +After enabling it, run your own benchmarks and verify that KTLS actually +speeds things up in your case. + How do I manage a ClientSession within a web server? ---------------------------------------------------- diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 75d3d0c8323..14b6da20607 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -4,6 +4,7 @@ ABI addons aiodns aioes +aiofastnet aiohttp aiohttpdemo aiohttp’s @@ -183,6 +184,7 @@ keepalive keepalived keepalives keepaliving +KTLS kib KiB kwarg @@ -227,6 +229,7 @@ namedtuple nameservers namespace netrc +NIC nginx Nginx Nikolay @@ -236,6 +239,7 @@ nowait OAuth Online optimizations +OpenSSL orjson os outcoming diff --git a/pyproject.toml b/pyproject.toml index 0c27cc88bb5..1d96cf4206d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dynamic = [ [project.optional-dependencies] speedups = [ "aiodns >= 3.3.0; sys_platform != 'android' and sys_platform != 'ios'", + "aiofastnet >= 0.9.0; platform_python_implementation == 'CPython' and (platform_machine == 'x86_64' or platform_machine == 'AMD64' or platform_machine == 'aarch64')", "Brotli >= 1.2; platform_python_implementation == 'CPython' and sys_platform != 'android' and sys_platform != 'ios'", "brotlicffi >= 1.2; platform_python_implementation != 'CPython'", "backports.zstd; platform_python_implementation == 'CPython' and python_version < '3.14' and sys_platform != 'android' and sys_platform != 'ios'", diff --git a/requirements/lint.in b/requirements/lint.in index c0a86f2435f..e5c3f7f4533 100644 --- a/requirements/lint.in +++ b/requirements/lint.in @@ -1,4 +1,5 @@ aiodns +aiofastnet backports.zstd; implementation_name == "cpython" and python_version < "3.14" blockbuster freezegun diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index d70fc5a9dbc..03974b7b217 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -1,6 +1,7 @@ # Extracted from `pyproject.toml` via `make sync-direct-runtime-deps` aiodns >= 3.3.0; sys_platform != 'android' and sys_platform != 'ios' +aiofastnet >= 0.9.0; platform_python_implementation == 'CPython' and (platform_machine == 'x86_64' or platform_machine == 'AMD64' or platform_machine == 'aarch64') aiohappyeyeballs >= 2.5.0 aiosignal >= 1.4.0 async-timeout >= 4.0, < 6.0 ; python_version < '3.11' diff --git a/tests/conftest.py b/tests/conftest.py index 3869d93794e..3e1a6b2e9fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -109,6 +109,10 @@ def blockbuster(request: pytest.FixtureRequest) -> Iterator[None]: # synchronization in async code. # Allow lock.acquire calls to prevent these false positives bb.functions["threading.Lock.acquire"].deactivate() + + # aiofastnet is using sendfile on a non-blocking socket. + # blockbuster triggers anyway. Seems like a false positive + bb.functions["os.sendfile"].deactivate() yield diff --git a/tests/test_connector.py b/tests/test_connector.py index 0b1cbcff03e..cbc1f561031 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -671,7 +671,7 @@ async def test_tcp_connector_certificate_error( conn = aiohttp.TCPConnector() with mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -694,7 +694,7 @@ async def test_tcp_connector_server_hostname_default( conn = aiohttp.TCPConnector() with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -712,7 +712,7 @@ async def test_tcp_connector_server_hostname_override( conn = aiohttp.TCPConnector() with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -869,7 +869,7 @@ def get_extra_info(param: str) -> object: side_effect=_resolve_host, ), mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -970,7 +970,7 @@ async def create_connection( side_effect=sock_connect, ): with mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -1063,7 +1063,7 @@ async def create_connection( side_effect=_resolve_host, ), mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -1145,7 +1145,7 @@ async def create_connection( side_effect=sock_connect, ): with mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -1258,7 +1258,7 @@ async def create_connection( side_effect=_resolve_host, ), mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, @@ -2219,7 +2219,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( conn = aiohttp.TCPConnector(ssl_shutdown_timeout=2.5) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2237,7 +2237,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( conn = aiohttp.TCPConnector(ssl_shutdown_timeout=None) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2256,7 +2256,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( conn = aiohttp.TCPConnector(ssl_shutdown_timeout=2.5) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2284,7 +2284,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_not_passed_pre_311( assert any(issubclass(warn.category, RuntimeWarning) for warn in w) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2442,7 +2442,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_zero_not_passed( conn = aiohttp.TCPConnector(ssl_shutdown_timeout=0) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2474,7 +2474,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_nonzero_passed( conn = aiohttp.TCPConnector(ssl_shutdown_timeout=5.0) with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True + connector_module, "create_connection", autospec=True, spec_set=True ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() @@ -2543,7 +2543,9 @@ async def test_start_tls_exception_with_ssl_shutdown_timeout_zero() -> None: mock.patch.object( conn, "_get_ssl_context", return_value=ssl.create_default_context() ), - mock.patch.object(conn._loop, "start_tls", side_effect=OSError("TLS failed")), + mock.patch.object( + connector_module, "start_tls", side_effect=OSError("TLS failed") + ), ): with pytest.raises(OSError): await conn._start_tls_connection(underlying_transport, req, ClientTimeout()) @@ -2575,7 +2577,9 @@ async def test_start_tls_exception_with_ssl_shutdown_timeout_nonzero() -> None: mock.patch.object( conn, "_get_ssl_context", return_value=ssl.create_default_context() ), - mock.patch.object(conn._loop, "start_tls", side_effect=OSError("TLS failed")), + mock.patch.object( + connector_module, "start_tls", side_effect=OSError("TLS failed") + ), ): with pytest.raises(OSError): await conn._start_tls_connection(underlying_transport, req, ClientTimeout()) @@ -2610,7 +2614,9 @@ async def test_start_tls_exception_with_ssl_shutdown_timeout_nonzero_pre_311() - mock.patch.object( conn, "_get_ssl_context", return_value=ssl.create_default_context() ), - mock.patch.object(conn._loop, "start_tls", side_effect=OSError("TLS failed")), + mock.patch.object( + connector_module, "start_tls", side_effect=OSError("TLS failed") + ), ): with pytest.raises(OSError): await conn._start_tls_connection(underlying_transport, req, ClientTimeout()) @@ -4072,7 +4078,7 @@ async def _resolve_host( first_conn = next(iter(conn._conns.values()))[0][0] assert first_conn.transport is not None - _sslcontext = first_conn.transport._ssl_protocol._sslcontext # type: ignore[attr-defined] + _sslcontext = first_conn.transport.get_extra_info("sslcontext") assert _sslcontext is client_ssl_ctx r.close() @@ -4510,19 +4516,15 @@ async def allow_connection_and_add_dummy_waiter() -> None: def test_connector_multiple_event_loop(make_client_request: _RequestMaker) -> None: """Test the connector with multiple event loops.""" + async def create_connection(*args: object, **kwargs: object) -> NoReturn: + raise ssl.CertificateError + async def async_connect() -> Literal[True]: conn = aiohttp.TCPConnector() loop = asyncio.get_running_loop() req = make_client_request("GET", URL("https://127.0.0.1"), loop=loop) with suppress(aiohttp.ClientConnectorError): - with mock.patch.object( - conn._loop, - "create_connection", - autospec=True, - spec_set=True, - side_effect=ssl.CertificateError, - ): - await conn.connect(req, [], ClientTimeout()) + await conn.connect(req, [], ClientTimeout()) return True def test_connect() -> Literal[True]: @@ -4532,9 +4534,10 @@ def test_connect() -> Literal[True]: finally: loop.close() - with futures.ThreadPoolExecutor() as executor: - res_list = [executor.submit(test_connect) for _ in range(2)] - raw_response_list = [res.result() for res in futures.as_completed(res_list)] + with mock.patch.object(connector_module, "create_connection", create_connection): + with futures.ThreadPoolExecutor() as executor: + res_list = [executor.submit(test_connect) for _ in range(2)] + raw_response_list = [res.result() for res in futures.as_completed(res_list)] assert raw_response_list == [True, True] @@ -4559,7 +4562,7 @@ async def test_tcp_connector_socket_factory( ) with mock.patch.object( - conn._loop, + connector_module, "create_connection", autospec=True, spec_set=True, diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 9cd8b3f1d6a..ad47a608115 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -10,7 +10,7 @@ from yarl import URL import aiohttp -from aiohttp import hdrs +from aiohttp import connector as connector_module, hdrs from aiohttp.abc import AbstractStreamWriter from aiohttp.client_reqrep import ( ClientRequest, @@ -70,7 +70,7 @@ async def test_connect( # type: ignore[misc] } ) with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(proto.transport, proto), @@ -131,7 +131,7 @@ async def test_proxy_headers( # type: ignore[misc] } ) with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(proto.transport, proto), @@ -204,7 +204,7 @@ async def test_proxy_connection_error( # type: ignore[misc] } with mock.patch.object(connector, "_resolve_host", autospec=True, return_value=[r]): with mock.patch.object( - connector._loop, + connector_module, "create_connection", autospec=True, side_effect=OSError("dont take it serious"), @@ -274,13 +274,13 @@ async def test_proxy_server_hostname_default( # type: ignore[misc] ): tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, return_value=mock.Mock(), @@ -360,13 +360,13 @@ async def test_proxy_server_hostname_override( # type: ignore[misc] ): tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, return_value=mock.Mock(), @@ -482,14 +482,14 @@ def close(self) -> None: return_value=fingerprint_mock, ), mock.patch.object( # Called on connection to http://proxy.example.com - event_loop, + connector_module, "create_connection", autospec=True, spec_set=True, return_value=(mock.Mock(), mock.Mock()), ), mock.patch.object( # Called on connection to https://www.python.org - event_loop, + connector_module, "start_tls", autospec=True, spec_set=True, @@ -561,13 +561,13 @@ async def test_https_connect( # type: ignore[misc] ): tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, return_value=mock.Mock(), @@ -647,14 +647,14 @@ async def test_https_connect_certificate_error( # type: ignore[misc] tr, proto = mock.Mock(), mock.Mock() # Called on connection to http://proxy.example.com with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): # Called on connection to https://www.python.org with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, side_effect=ssl.CertificateError, @@ -728,14 +728,14 @@ async def test_https_connect_ssl_error( # type: ignore[misc] tr, proto = mock.Mock(), mock.Mock() # Called on connection to http://proxy.example.com with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): # Called on connection to https://www.python.org with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, side_effect=ssl.SSLError, @@ -811,7 +811,7 @@ async def test_https_connect_http_proxy_error( # type: ignore[misc] tr.get_extra_info.return_value = None # Called on connection to http://proxy.example.com with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), @@ -891,7 +891,7 @@ async def test_https_connect_resp_start_error( # type: ignore[misc] tr.get_extra_info.return_value = None # Called on connection to http://proxy.example.com with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), @@ -940,7 +940,7 @@ async def test_request_port( # type: ignore[misc] tr.get_extra_info.return_value = None # Called on connection to http://proxy.example.com with mock.patch.object( - event_loop, "create_connection", autospec=True, return_value=(tr, proto) + connector_module, "create_connection", autospec=True, return_value=(tr, proto) ): req = make_client_request( "GET", @@ -1008,13 +1008,13 @@ async def test_https_connect_pass_ssl_context( # type: ignore[misc] ): tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, return_value=mock.Mock(), @@ -1031,6 +1031,7 @@ async def test_https_connect_pass_ssl_context( # type: ignore[misc] # ssl_shutdown_timeout=0 is not passed to start_tls tls_m.assert_called_with( + event_loop, mock.ANY, mock.ANY, _SSL_CONTEXT_VERIFIED, @@ -1103,13 +1104,13 @@ async def test_https_auth( # type: ignore[misc] ) as host_m: tr, proto = mock.Mock(), mock.Mock() with mock.patch.object( - event_loop, + connector_module, "create_connection", autospec=True, return_value=(tr, proto), ): with mock.patch.object( - event_loop, + connector_module, "start_tls", autospec=True, return_value=mock.Mock(), diff --git a/tests/test_run_app.py b/tests/test_run_app.py index a1cf5dd0f92..e8acc2ccae9 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -26,6 +26,7 @@ WSCloseCode, web, ) +from aiohttp import web_runner as web_runner_module from aiohttp.log import access_logger from aiohttp.web_protocol import RequestHandler from aiohttp.web_runner import BaseRunner @@ -69,6 +70,17 @@ def skip_if_on_windows() -> None: pytest.skip("the test is not valid for Windows") +@pytest.fixture +def create_server_mock() -> Iterator[mock.AsyncMock]: + server = mock.create_autospec(asyncio.Server, spec_set=True, instance=True) + server.wait_closed.return_value = None + server.sockets = [] + create_server_mock = mock.AsyncMock(return_value=server) + + with mock.patch.object(web_runner_module, "create_server", create_server_mock): + yield create_server_mock + + @pytest.fixture def patched_loop( event_loop: asyncio.AbstractEventLoop, @@ -103,7 +115,7 @@ def f(*args: object) -> None: return f -def test_run_app_http(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_http(patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock) -> None: app = web.Application() startup_handler = mock.AsyncMock() app.on_startup.append(startup_handler) @@ -112,19 +124,21 @@ def test_run_app_http(patched_loop: asyncio.AbstractEventLoop) -> None: web.run_app(app, print=stopper(patched_loop), loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None + create_server_mock.assert_called_with( + patched_loop, mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None, ) startup_handler.assert_called_once_with(app) cleanup_handler.assert_called_once_with(app) -def test_run_app_close_loop(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_close_loop( + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock +) -> None: app = web.Application() web.run_app(app, print=stopper(patched_loop), loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None + create_server_mock.assert_called_with( + patched_loop, mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None, ) assert patched_loop.is_closed() @@ -160,6 +174,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: ] mock_server_single = [ mock.call( + mock.ANY, mock.ANY, "127.0.0.1", 8080, @@ -171,6 +186,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: ] mock_server_multi = [ mock.call( + mock.ANY, mock.ANY, "127.0.0.1", 8080, @@ -180,6 +196,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: reuse_port=None, ), mock.call( + mock.ANY, mock.ANY, "192.168.1.1", 8080, @@ -191,7 +208,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: ] mock_server_default_8989 = [ mock.call( - mock.ANY, None, 8989, ssl=None, backlog=128, reuse_address=None, reuse_port=None + mock.ANY, mock.ANY, None, 8989, ssl=None, backlog=128, reuse_address=None, reuse_port=None ) ] mock_socket = mock.Mock(getsockname=lambda: ("mock-socket", 123)) @@ -203,6 +220,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {}, [ mock.call( + mock.ANY, mock.ANY, None, 8080, @@ -261,6 +279,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: }, [ mock.call( + mock.ANY, mock.ANY, "127.0.0.1", 8000, @@ -270,6 +289,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: reuse_port=None, ), mock.call( + mock.ANY, mock.ANY, "192.168.1.1", 8000, @@ -284,7 +304,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: ( "Only socket", {"sock": [mock_socket]}, - [mock.call(mock.ANY, ssl=None, sock=mock_socket, backlog=128)], + [mock.call(mock.ANY, mock.ANY, ssl=None, sock=mock_socket, backlog=128)], [], ), ( @@ -292,6 +312,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"sock": [mock_socket], "port": 8765}, [ mock.call( + mock.ANY, mock.ANY, None, 8765, @@ -300,7 +321,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: reuse_address=None, reuse_port=None, ), - mock.call(mock.ANY, sock=mock_socket, ssl=None, backlog=128), + mock.call(mock.ANY, mock.ANY, sock=mock_socket, ssl=None, backlog=128), ], [], ), @@ -309,6 +330,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"sock": [mock_socket], "host": "localhost"}, [ mock.call( + mock.ANY, mock.ANY, "localhost", 8080, @@ -317,7 +339,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: reuse_address=None, reuse_port=None, ), - mock.call(mock.ANY, sock=mock_socket, ssl=None, backlog=128), + mock.call(mock.ANY, mock.ANY, sock=mock_socket, ssl=None, backlog=128), ], [], ), @@ -326,6 +348,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"reuse_port": True}, [ mock.call( + mock.ANY, mock.ANY, None, 8080, @@ -342,6 +365,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"reuse_address": False}, [ mock.call( + mock.ANY, mock.ANY, None, 8080, @@ -358,6 +382,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"reuse_address": True, "reuse_port": True}, [ mock.call( + mock.ANY, mock.ANY, None, 8080, @@ -374,6 +399,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"port": 8989, "reuse_port": True}, [ mock.call( + mock.ANY, mock.ANY, None, 8989, @@ -390,6 +416,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: {"host": ("127.0.0.1", "192.168.1.1"), "reuse_port": True}, [ mock.call( + mock.ANY, mock.ANY, "127.0.0.1", 8080, @@ -399,6 +426,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: reuse_port=True, ), mock.call( + mock.ANY, mock.ANY, "192.168.1.1", 8080, @@ -419,6 +447,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: }, [ mock.call( + mock.ANY, mock.ANY, None, 8989, @@ -440,6 +469,7 @@ async def failing_ctx(_app: web.Application) -> AsyncIterator[None]: }, [ mock.call( + mock.ANY, mock.ANY, "127.0.0.1", 8080, @@ -466,15 +496,16 @@ def test_run_app_mixed_bindings( # type: ignore[misc] expected_server_calls: list[mock._Call], expected_unix_server_calls: list[mock._Call], patched_loop: asyncio.AbstractEventLoop, + create_server_mock: mock.AsyncMock, ) -> None: app = web.Application() web.run_app(app, print=stopper(patched_loop), **run_app_kwargs, loop=patched_loop) assert patched_loop.create_unix_server.mock_calls == expected_unix_server_calls # type: ignore[attr-defined] - assert patched_loop.create_server.mock_calls == expected_server_calls # type: ignore[attr-defined] + assert create_server_mock.mock_calls == expected_server_calls -def test_run_app_https(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_https(patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock) -> None: app = web.Application() ssl_context = ssl.create_default_context() @@ -482,7 +513,8 @@ def test_run_app_https(patched_loop: asyncio.AbstractEventLoop) -> None: app, ssl_context=ssl_context, print=stopper(patched_loop), loop=patched_loop ) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] + create_server_mock.assert_called_with( + patched_loop, mock.ANY, None, 8443, @@ -494,23 +526,32 @@ def test_run_app_https(patched_loop: asyncio.AbstractEventLoop) -> None: def test_run_app_nondefault_host_port( - patched_loop: asyncio.AbstractEventLoop, unused_port_socket: socket.socket + patched_loop: asyncio.AbstractEventLoop, + unused_port_socket: socket.socket, + create_server_mock: mock.AsyncMock, ) -> None: port = unused_port_socket.getsockname()[1] host = "127.0.0.1" app = web.Application() - web.run_app( - app, host=host, port=port, print=stopper(patched_loop), loop=patched_loop - ) + web.run_app(app, host=host, port=port, print=stopper(patched_loop), loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, host, port, ssl=None, backlog=128, reuse_address=None, reuse_port=None + create_server_mock.assert_called_with( + patched_loop, + mock.ANY, + host, + port, + ssl=None, + backlog=128, + reuse_address=None, + reuse_port=None, ) def test_run_app_with_sock( - patched_loop: asyncio.AbstractEventLoop, unused_port_socket: socket.socket + patched_loop: asyncio.AbstractEventLoop, + unused_port_socket: socket.socket, + create_server_mock: mock.AsyncMock, ) -> None: sock = unused_port_socket app = web.Application() @@ -521,12 +562,14 @@ def test_run_app_with_sock( loop=patched_loop, ) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, sock=sock, ssl=None, backlog=128 + create_server_mock.assert_called_with( + patched_loop, mock.ANY, sock=sock, ssl=None, backlog=128 ) -def test_run_app_multiple_hosts(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_multiple_hosts( + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock +) -> None: hosts = ("127.0.0.1", "127.0.0.2") app = web.Application() @@ -534,6 +577,7 @@ def test_run_app_multiple_hosts(patched_loop: asyncio.AbstractEventLoop) -> None calls = map( lambda h: mock.call( + patched_loop, mock.ANY, h, 8080, @@ -544,15 +588,17 @@ def test_run_app_multiple_hosts(patched_loop: asyncio.AbstractEventLoop) -> None ), hosts, ) - patched_loop.create_server.assert_has_calls(calls) # type: ignore[attr-defined] + create_server_mock.assert_has_calls(list(calls)) -def test_run_app_custom_backlog(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_custom_backlog( + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock +) -> None: app = web.Application() web.run_app(app, backlog=10, print=stopper(patched_loop), loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, None, 8080, ssl=None, backlog=10, reuse_address=None, reuse_port=None + create_server_mock.assert_called_with( + patched_loop, mock.ANY, None, 8080, ssl=None, backlog=10, reuse_address=None, reuse_port=None, ) @@ -630,7 +676,9 @@ def test_run_app_abstract_linux_socket( def test_run_app_preexisting_inet_socket( - patched_loop: asyncio.AbstractEventLoop, mocker: MockerFixture + patched_loop: asyncio.AbstractEventLoop, + mocker: MockerFixture, + create_server_mock: mock.AsyncMock, ) -> None: app = web.Application() @@ -642,15 +690,15 @@ def test_run_app_preexisting_inet_socket( printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app(app, sock=sock, print=printer, loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, sock=sock, backlog=128, ssl=None + create_server_mock.assert_called_with( + patched_loop, mock.ANY, sock=sock, backlog=128, ssl=None ) assert f"http://127.0.0.1:{port}" in printer.call_args[0][0] @pytest.mark.skipif(not HAS_IPV6, reason="IPv6 is not available") def test_run_app_preexisting_inet6_socket( - patched_loop: asyncio.AbstractEventLoop, + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock ) -> None: app = web.Application() @@ -662,15 +710,18 @@ def test_run_app_preexisting_inet6_socket( printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app(app, sock=sock, print=printer, loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, sock=sock, backlog=128, ssl=None + create_server_mock.assert_called_with( + patched_loop, mock.ANY, sock=sock, backlog=128, ssl=None ) assert f"http://[::1]:{port}" in printer.call_args[0][0] @skip_if_no_unix_socks def test_run_app_preexisting_unix_socket( - patched_loop: asyncio.AbstractEventLoop, unix_sockname: str, mocker: MockerFixture + patched_loop: asyncio.AbstractEventLoop, + unix_sockname: str, + mocker: MockerFixture, + create_server_mock: mock.AsyncMock, ) -> None: app = web.Application() @@ -682,14 +733,14 @@ def test_run_app_preexisting_unix_socket( printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app(app, sock=sock, print=printer, loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, sock=sock, backlog=128, ssl=None + create_server_mock.assert_called_with( + patched_loop, mock.ANY, sock=sock, backlog=128, ssl=None ) assert f"http://unix:{unix_sockname}:" in printer.call_args[0][0] def test_run_app_multiple_preexisting_sockets( - patched_loop: asyncio.AbstractEventLoop, + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock ) -> None: app = web.Application() @@ -704,10 +755,10 @@ def test_run_app_multiple_preexisting_sockets( printer = mock.Mock(wraps=stopper(patched_loop)) web.run_app(app, sock=(sock1, sock2), print=printer, loop=patched_loop) - patched_loop.create_server.assert_has_calls( # type: ignore[attr-defined] + create_server_mock.assert_has_calls( [ - mock.call(mock.ANY, sock=sock1, backlog=128, ssl=None), - mock.call(mock.ANY, sock=sock2, backlog=128, ssl=None), + mock.call(patched_loop, mock.ANY, sock=sock1, backlog=128, ssl=None), + mock.call(patched_loop, mock.ANY, sock=sock2, backlog=128, ssl=None), ] ) assert f"http://127.0.0.1:{port1}" in printer.call_args[0][0] @@ -753,9 +804,9 @@ def test_sigterm() -> None: def test_startup_cleanup_signals_even_on_failure( - patched_loop: asyncio.AbstractEventLoop, + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock ) -> None: - patched_loop.create_server.side_effect = RuntimeError() # type: ignore[attr-defined] + create_server_mock.side_effect = RuntimeError() app = web.Application() startup_handler = mock.AsyncMock() @@ -770,7 +821,9 @@ def test_startup_cleanup_signals_even_on_failure( cleanup_handler.assert_called_once_with(app) -def test_run_app_coro(patched_loop: asyncio.AbstractEventLoop) -> None: +def test_run_app_coro( + patched_loop: asyncio.AbstractEventLoop, create_server_mock: mock.AsyncMock +) -> None: startup_handler = cleanup_handler = None async def make_app() -> web.Application: @@ -784,8 +837,8 @@ async def make_app() -> web.Application: web.run_app(make_app(), print=stopper(patched_loop), loop=patched_loop) - patched_loop.create_server.assert_called_with( # type: ignore[attr-defined] - mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None + create_server_mock.assert_called_with( + patched_loop, mock.ANY, None, 8080, ssl=None, backlog=128, reuse_address=None, reuse_port=None, ) assert startup_handler is not None assert cleanup_handler is not None @@ -911,9 +964,7 @@ async def on_startup(app: web.Application) -> None: assert task.cancelled() -def test_run_app_cancels_done_tasks( - patched_loop: asyncio.AbstractEventLoop, -) -> None: +def test_run_app_cancels_done_tasks(patched_loop: asyncio.AbstractEventLoop) -> None: app = web.Application() task = None @@ -932,9 +983,7 @@ async def on_startup(app: web.Application) -> None: assert task.done() -def test_run_app_cancels_failed_tasks( - patched_loop: asyncio.AbstractEventLoop, -) -> None: +def test_run_app_cancels_failed_tasks(patched_loop: asyncio.AbstractEventLoop) -> None: app = web.Application() task = None @@ -1031,9 +1080,7 @@ async def init() -> web.Application: assert count == 3 -def test_run_app_raises_exception( - patched_loop: asyncio.AbstractEventLoop, -) -> None: +def test_run_app_raises_exception(patched_loop: asyncio.AbstractEventLoop) -> None: async def context(app: web.Application) -> AsyncIterator[None]: raise RuntimeError("foo") yield # type: ignore[unreachable] # pragma: no cover diff --git a/tests/test_web_runner.py b/tests/test_web_runner.py index c4b7b19e8b7..0798b785647 100644 --- a/tests/test_web_runner.py +++ b/tests/test_web_runner.py @@ -9,6 +9,7 @@ import pytest from aiohttp import web +from aiohttp import web_runner as web_runner_module from aiohttp.abc import AbstractAccessLogger from aiohttp.test_utils import REUSE_ADDRESS from aiohttp.web_log import AccessLogger @@ -265,16 +266,16 @@ async def test_tcpsite_default_host(make_runner: _RunnerMaker) -> None: site = web.TCPSite(runner) assert site.name == "http://0.0.0.0:8080" - m = mock.create_autospec(asyncio.AbstractEventLoop, spec_set=True, instance=True) - m.create_server.return_value = mock.create_autospec(asyncio.Server, spec_set=True) - with mock.patch( - "asyncio.get_running_loop", autospec=True, spec_set=True, return_value=m - ): + create_server = mock.AsyncMock( + return_value=mock.create_autospec(asyncio.Server, spec_set=True) + ) + + with mock.patch.object(web_runner_module, "create_server", create_server): await site.start() - m.create_server.assert_called_once() - args, kwargs = m.create_server.call_args - assert args == (runner.server, None, 8080) + create_server.assert_called_once() + args, kwargs = create_server.call_args + assert args == (asyncio.get_running_loop(), runner.server, None, 8080) async def test_tcpsite_empty_str_host(make_runner: _RunnerMaker) -> None: diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index e4daf828fcd..7ecb2899c7e 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -14,6 +14,7 @@ import aiohttp from aiohttp import web +from aiohttp import web_fileresponse as web_fileresponse_module from aiohttp.compression_utils import ZLibBackend from aiohttp.typedefs import PathLike from aiohttp.web_fileresponse import NOSENDFILE @@ -74,14 +75,13 @@ async def sender(request: SubRequest) -> AsyncIterator[_Sender]: def maker(path: PathLike, chunk_size: int = 256 * 1024) -> web.FileResponse: ret = web.FileResponse(path, chunk_size=chunk_size) - rloop = asyncio.get_running_loop() - is_patched = rloop.sendfile is sendfile_mock + is_patched = web_fileresponse_module.sendfile is sendfile_mock assert is_patched if request.param == "no_sendfile" else not is_patched return ret if request.param == "no_sendfile": with mock.patch.object( - asyncio.get_running_loop(), + web_fileresponse_module, "sendfile", autospec=True, spec_set=True,