55import contextlib
66import multiprocessing
77import queue
8- import ssl
98import sys
109import time
1110import traceback
12- from tempfile import NamedTemporaryFile
1311from typing import TYPE_CHECKING , Literal , TypeVar , get_args
1412
15- import httpx
1613from _util import create_standard_streams
1714from gen .connectrpc .conformance .v1 .client_compat_pb2 import (
1815 ClientCompatRequest ,
4037 UnimplementedRequest ,
4138)
4239from google .protobuf .message import Message
43- from pyqwest import HTTPTransport , SyncHTTPTransport
40+ from pyqwest import Client , HTTPTransport , SyncClient , SyncHTTPTransport
4441from pyqwest import HTTPVersion as PyQwestHTTPVersion
45- from pyqwest .httpx import AsyncPyqwestTransport , PyqwestTransport
4642
4743from connectrpc .client import ResponseMetadata
4844from connectrpc .code import Code
@@ -118,41 +114,8 @@ def _unpack_request(message: Any, request: T) -> T:
118114 return request
119115
120116
121- async def httpx_client_kwargs (test_request : ClientCompatRequest ) -> dict :
122- kwargs = {}
123- match test_request .http_version :
124- case HTTPVersion .HTTP_VERSION_1 :
125- kwargs ["http1" ] = True
126- kwargs ["http2" ] = False
127- case HTTPVersion .HTTP_VERSION_2 :
128- kwargs ["http1" ] = False
129- kwargs ["http2" ] = True
130- if test_request .server_tls_cert :
131- ctx = ssl .create_default_context (
132- purpose = ssl .Purpose .SERVER_AUTH ,
133- cadata = test_request .server_tls_cert .decode (),
134- )
135- if test_request .HasField ("client_tls_creds" ):
136-
137- def load_certs () -> None :
138- with (
139- NamedTemporaryFile () as cert_file ,
140- NamedTemporaryFile () as key_file ,
141- ):
142- cert_file .write (test_request .client_tls_creds .cert )
143- cert_file .flush ()
144- key_file .write (test_request .client_tls_creds .key )
145- key_file .flush ()
146- ctx .load_cert_chain (certfile = cert_file .name , keyfile = key_file .name )
147-
148- await asyncio .to_thread (load_certs )
149- kwargs ["verify" ] = ctx
150-
151- return kwargs
152-
153-
154117def pyqwest_client_kwargs (test_request : ClientCompatRequest ) -> dict :
155- kwargs : dict = {"enable_gzip" : True , "enable_brotli" : True , "enable_zstd" : True }
118+ kwargs : dict = {}
156119 match test_request .http_version :
157120 case HTTPVersion .HTTP_VERSION_1 :
158121 kwargs ["http_version" ] = PyQwestHTTPVersion .HTTP1
@@ -169,28 +132,26 @@ def pyqwest_client_kwargs(test_request: ClientCompatRequest) -> dict:
169132
170133@contextlib .asynccontextmanager
171134async def client_sync (
172- test_request : ClientCompatRequest , client_type : Client
135+ test_request : ClientCompatRequest ,
173136) -> AsyncIterator [ConformanceServiceClientSync ]:
174137 read_max_bytes = None
175138 if test_request .message_receive_limit :
176139 read_max_bytes = test_request .message_receive_limit
177140 scheme = "https" if test_request .server_tls_cert else "http"
141+ args = pyqwest_client_kwargs (test_request )
142+
178143 cleanup = contextlib .ExitStack ()
179- match client_type :
180- case "httpx" :
181- args = await httpx_client_kwargs (test_request )
182- session = cleanup .enter_context (httpx .Client (** args ))
183- case "pyqwest" :
184- args = pyqwest_client_kwargs (test_request )
185- http_transport = cleanup .enter_context (SyncHTTPTransport (** args ))
186- transport = cleanup .enter_context (PyqwestTransport (http_transport ))
187- session = cleanup .enter_context (httpx .Client (transport = transport ))
144+ if args :
145+ transport = cleanup .enter_context (SyncHTTPTransport (** args ))
146+ http_client = SyncClient (transport )
147+ else :
148+ http_client = None
188149
189150 with (
190151 cleanup ,
191152 ConformanceServiceClientSync (
192153 f"{ scheme } ://{ test_request .host } :{ test_request .port } " ,
193- session = session ,
154+ http_client = http_client ,
194155 send_compression = _convert_compression (test_request .compression ),
195156 proto_json = test_request .codec == Codec .CODEC_JSON ,
196157 grpc = test_request .protocol == Protocol .PROTOCOL_GRPC ,
@@ -202,32 +163,29 @@ async def client_sync(
202163
203164@contextlib .asynccontextmanager
204165async def client_async (
205- test_request : ClientCompatRequest , client_type : Client
166+ test_request : ClientCompatRequest ,
206167) -> AsyncIterator [ConformanceServiceClient ]:
207168 read_max_bytes = None
208169 if test_request .message_receive_limit :
209170 read_max_bytes = test_request .message_receive_limit
210171 scheme = "https" if test_request .server_tls_cert else "http"
172+ args = pyqwest_client_kwargs (test_request )
173+
211174 cleanup = contextlib .AsyncExitStack ()
212- match client_type :
213- case "httpx" :
214- args = await httpx_client_kwargs (test_request )
215- session = await cleanup .enter_async_context (httpx .AsyncClient (** args ))
216- case "pyqwest" :
217- args = pyqwest_client_kwargs (test_request )
218- http_transport = await cleanup .enter_async_context (HTTPTransport (** args ))
219- transport = await cleanup .enter_async_context (
220- AsyncPyqwestTransport (http_transport )
221- )
222- session = await cleanup .enter_async_context (
223- httpx .AsyncClient (transport = transport )
224- )
175+ if args :
176+ transport = HTTPTransport (** args )
177+ # Type parameter for enter_async_context requires coroutine even though
178+ # implementation doesn't. We can directly push aexit to work around it.
179+ cleanup .push_async_exit (transport .__aexit__ )
180+ http_client = Client (transport )
181+ else :
182+ http_client = None
225183
226184 async with (
227185 cleanup ,
228186 ConformanceServiceClient (
229187 f"{ scheme } ://{ test_request .host } :{ test_request .port } " ,
230- session = session ,
188+ http_client = http_client ,
231189 send_compression = _convert_compression (test_request .compression ),
232190 proto_json = test_request .codec == Codec .CODEC_JSON ,
233191 grpc = test_request .protocol == Protocol .PROTOCOL_GRPC ,
@@ -238,7 +196,7 @@ async def client_async(
238196
239197
240198async def _run_test (
241- mode : Mode , test_request : ClientCompatRequest , client_type : Client
199+ mode : Mode , test_request : ClientCompatRequest
242200) -> ClientCompatResponse :
243201 test_response = ClientCompatResponse ()
244202 test_response .test_name = test_request .test_name
@@ -260,7 +218,7 @@ async def _run_test(
260218 request_closed = asyncio .Event ()
261219 match mode :
262220 case "sync" :
263- async with client_sync (test_request , client_type ) as client :
221+ async with client_sync (test_request ) as client :
264222 match test_request .method :
265223 case "BidiStream" :
266224 request_queue = queue .Queue ()
@@ -468,7 +426,7 @@ def send_unary_request_sync(
468426 task .cancel ()
469427 await task
470428 case "async" :
471- async with client_async (test_request , client_type ) as client :
429+ async with client_async (test_request ) as client :
472430 match test_request .method :
473431 case "BidiStream" :
474432 request_queue = asyncio .Queue ()
@@ -691,20 +649,17 @@ async def send_unary_request(
691649
692650
693651Mode = Literal ["sync" , "async" ]
694- Client = Literal ["httpx" , "pyqwest" ]
695652
696653
697654class Args (argparse .Namespace ):
698655 mode : Mode
699- client : Client
700656 parallel : int
701657
702658
703659async def main () -> None :
704660 parser = argparse .ArgumentParser (description = "Conformance client" )
705661 parser .add_argument ("--mode" , choices = get_args (Mode ))
706662 parser .add_argument ("--parallel" , type = int , default = multiprocessing .cpu_count () * 4 )
707- parser .add_argument ("--client" , choices = get_args (Client ))
708663 args = parser .parse_args (namespace = Args ())
709664
710665 stdin , stdout = await create_standard_streams ()
@@ -724,7 +679,7 @@ async def main() -> None:
724679
725680 async def task (request : ClientCompatRequest ) -> None :
726681 async with sema :
727- response = await _run_test (args .mode , request , args . client )
682+ response = await _run_test (args .mode , request )
728683
729684 response_buf = response .SerializeToString ()
730685 size_buf = len (response_buf ).to_bytes (4 , byteorder = "big" )
0 commit comments