@@ -24,8 +24,10 @@ use thiserror::Error;
2424use tokio:: io;
2525use tokio:: net:: { TcpListener , TcpStream , ToSocketAddrs } ;
2626use tokio:: sync:: { mpsc, oneshot} ;
27- use tokio_rustls:: rustls:: server:: VerifierBuilderError ;
28- use tokio_rustls:: rustls:: { ClientConfig , ServerConfig , pki_types:: CertificateDer } ;
27+ use tokio_rustls:: rustls:: server:: { VerifierBuilderError , WebPkiClientVerifier } ;
28+ use tokio_rustls:: rustls:: {
29+ self , ClientConfig , RootCertStore , ServerConfig , pki_types:: CertificateDer ,
30+ } ;
2931use tracing:: { debug, error, warn} ;
3032
3133use crate :: http_version:: { ALPN_H2 , ALPN_HTTP11 , HttpConnection , HttpSender , HttpVersion } ;
@@ -59,6 +61,17 @@ type RequestWithResponseSender = (
5961 oneshot:: Sender < Result < Response < BoxBody < bytes:: Bytes , hyper:: Error > > , hyper:: Error > > ,
6062) ;
6163
64+ /// Adds HTTP 1 and 2 to the list of allowed protocols
65+ fn ensure_proxy_alpn_protocols ( alpn_protocols : & mut Vec < Vec < u8 > > ) {
66+ for protocol in [ ALPN_H2 , ALPN_HTTP11 ] {
67+ let already_present = alpn_protocols. iter ( ) . any ( |p| p. as_slice ( ) == protocol) ;
68+
69+ if !already_present {
70+ alpn_protocols. push ( protocol. to_vec ( ) ) ;
71+ }
72+ }
73+ }
74+
6275/// Retrieve the attested remote TLS certificate.
6376pub async fn get_tls_cert (
6477 server_name : String ,
@@ -101,11 +114,32 @@ impl ProxyServer {
101114 attestation_verifier : AttestationVerifier ,
102115 client_auth : bool ,
103116 ) -> Result < Self , ProxyError > {
104- let attested_tls_server = AttestedTlsServer :: new (
105- cert_and_key,
117+ let mut server_config = if client_auth {
118+ let root_store =
119+ RootCertStore :: from_iter ( webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) ) ;
120+ let verifier = WebPkiClientVerifier :: builder ( Arc :: new ( root_store) ) . build ( ) ?;
121+
122+ ServerConfig :: builder_with_protocol_versions ( & [ & rustls:: version:: TLS13 ] )
123+ . with_client_cert_verifier ( verifier)
124+ . with_single_cert (
125+ cert_and_key. cert_chain . clone ( ) ,
126+ cert_and_key. key . clone_key ( ) ,
127+ ) ?
128+ } else {
129+ ServerConfig :: builder_with_protocol_versions ( & [ & rustls:: version:: TLS13 ] )
130+ . with_no_client_auth ( )
131+ . with_single_cert (
132+ cert_and_key. cert_chain . clone ( ) ,
133+ cert_and_key. key . clone_key ( ) ,
134+ ) ?
135+ } ;
136+ ensure_proxy_alpn_protocols ( & mut server_config. alpn_protocols ) ;
137+
138+ let attested_tls_server = AttestedTlsServer :: new_with_tls_config (
139+ cert_and_key. cert_chain ,
140+ server_config,
106141 attestation_generator,
107142 attestation_verifier,
108- client_auth,
109143 ) ?;
110144
111145 let listener = TcpListener :: bind ( local) . await ?;
@@ -126,16 +160,7 @@ impl ProxyServer {
126160 attestation_generator : AttestationGenerator ,
127161 attestation_verifier : AttestationVerifier ,
128162 ) -> Result < Self , ProxyError > {
129- for protocol in [ ALPN_H2 , ALPN_HTTP11 ] {
130- let already_present = server_config
131- . alpn_protocols
132- . iter ( )
133- . any ( |p| p. as_slice ( ) == protocol) ;
134-
135- if !already_present {
136- server_config. alpn_protocols . push ( protocol. to_vec ( ) ) ;
137- }
138- }
163+ ensure_proxy_alpn_protocols ( & mut server_config. alpn_protocols ) ;
139164
140165 let attested_tls_server = AttestedTlsServer :: new_with_tls_config (
141166 cert_chain,
@@ -347,11 +372,34 @@ impl ProxyClient {
347372 attestation_verifier : AttestationVerifier ,
348373 remote_certificate : Option < CertificateDer < ' static > > ,
349374 ) -> Result < Self , ProxyError > {
350- let attested_tls_client = AttestedTlsClient :: new (
351- cert_and_key,
375+ let root_store = match remote_certificate {
376+ Some ( remote_certificate) => {
377+ let mut root_store = RootCertStore :: empty ( ) ;
378+ root_store. add ( remote_certificate) ?;
379+ root_store
380+ }
381+ None => RootCertStore :: from_iter ( webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) ) ,
382+ } ;
383+
384+ let mut client_config = if let Some ( ref cert_and_key) = cert_and_key {
385+ ClientConfig :: builder_with_protocol_versions ( & [ & rustls:: version:: TLS13 ] )
386+ . with_root_certificates ( root_store)
387+ . with_client_auth_cert (
388+ cert_and_key. cert_chain . clone ( ) ,
389+ cert_and_key. key . clone_key ( ) ,
390+ ) ?
391+ } else {
392+ ClientConfig :: builder_with_protocol_versions ( & [ & rustls:: version:: TLS13 ] )
393+ . with_root_certificates ( root_store)
394+ . with_no_client_auth ( )
395+ } ;
396+ ensure_proxy_alpn_protocols ( & mut client_config. alpn_protocols ) ;
397+
398+ let attested_tls_client = AttestedTlsClient :: new_with_tls_config (
399+ client_config,
352400 attestation_generator,
353401 attestation_verifier,
354- remote_certificate ,
402+ cert_and_key . map ( |c| c . cert_chain ) ,
355403 ) ?;
356404
357405 Self :: new_with_inner ( address, attested_tls_client, & server_name) . await
@@ -366,16 +414,7 @@ impl ProxyClient {
366414 attestation_verifier : AttestationVerifier ,
367415 cert_chain : Option < Vec < CertificateDer < ' static > > > ,
368416 ) -> Result < Self , ProxyError > {
369- for protocol in [ ALPN_H2 , ALPN_HTTP11 ] {
370- let already_present = client_config
371- . alpn_protocols
372- . iter ( )
373- . any ( |p| p. as_slice ( ) == protocol) ;
374-
375- if !already_present {
376- client_config. alpn_protocols . push ( protocol. to_vec ( ) ) ;
377- }
378- }
417+ ensure_proxy_alpn_protocols ( & mut client_config. alpn_protocols ) ;
379418
380419 let attested_tls_client = AttestedTlsClient :: new_with_tls_config (
381420 client_config,
@@ -763,6 +802,89 @@ mod tests {
763802 generate_tls_config_with_client_auth, init_tracing, mock_dcap_measurements,
764803 } ;
765804
805+ #[ test]
806+ fn proxy_alpn_protocols_prefer_http2 ( ) {
807+ let mut protocols = Vec :: new ( ) ;
808+ ensure_proxy_alpn_protocols ( & mut protocols) ;
809+
810+ assert_eq ! ( protocols, vec![ ALPN_H2 . to_vec( ) , ALPN_HTTP11 . to_vec( ) ] ) ;
811+ }
812+
813+ #[ test]
814+ fn proxy_alpn_protocols_preserve_existing_order_without_duplicates ( ) {
815+ let mut protocols = vec ! [ ALPN_HTTP11 . to_vec( ) , ALPN_H2 . to_vec( ) ] ;
816+ ensure_proxy_alpn_protocols ( & mut protocols) ;
817+
818+ assert_eq ! ( protocols, vec![ ALPN_HTTP11 . to_vec( ) , ALPN_H2 . to_vec( ) ] ) ;
819+ }
820+
821+ #[ tokio:: test]
822+ async fn http_proxy_default_constructors_work ( ) {
823+ let target_addr = example_http_service ( ) . await ;
824+
825+ let ( cert_chain, private_key) = generate_certificate_chain ( "127.0.0.1" . parse ( ) . unwrap ( ) ) ;
826+ let server_cert = cert_chain[ 0 ] . clone ( ) ;
827+
828+ let proxy_server = ProxyServer :: new (
829+ TlsCertAndKey {
830+ cert_chain,
831+ key : private_key,
832+ } ,
833+ "127.0.0.1:0" ,
834+ target_addr. to_string ( ) ,
835+ AttestationGenerator :: new ( AttestationType :: DcapTdx , None ) . unwrap ( ) ,
836+ AttestationVerifier :: expect_none ( ) ,
837+ false ,
838+ )
839+ . await
840+ . unwrap ( ) ;
841+
842+ let proxy_addr = proxy_server. local_addr ( ) . unwrap ( ) ;
843+
844+ tokio:: spawn ( async move {
845+ proxy_server. accept ( ) . await . unwrap ( ) ;
846+ } ) ;
847+
848+ let proxy_client = ProxyClient :: new (
849+ None ,
850+ "127.0.0.1:0" . to_string ( ) ,
851+ proxy_addr. to_string ( ) ,
852+ AttestationGenerator :: with_no_attestation ( ) ,
853+ AttestationVerifier :: mock ( ) ,
854+ Some ( server_cert) ,
855+ )
856+ . await
857+ . unwrap ( ) ;
858+
859+ let proxy_client_addr = proxy_client. local_addr ( ) . unwrap ( ) ;
860+
861+ tokio:: spawn ( async move {
862+ proxy_client. accept ( ) . await . unwrap ( ) ;
863+ } ) ;
864+
865+ let res = reqwest:: get ( format ! ( "http://{}" , proxy_client_addr) )
866+ . await
867+ . unwrap ( ) ;
868+
869+ let headers = res. headers ( ) ;
870+
871+ let attestation_type = headers
872+ . get ( ATTESTATION_TYPE_HEADER )
873+ . unwrap ( )
874+ . to_str ( )
875+ . unwrap ( ) ;
876+ assert_eq ! ( attestation_type, AttestationType :: DcapTdx . as_str( ) ) ;
877+
878+ let measurements_json = headers. get ( MEASUREMENT_HEADER ) . unwrap ( ) . to_str ( ) . unwrap ( ) ;
879+ let measurements =
880+ MultiMeasurements :: from_header_format ( measurements_json, AttestationType :: DcapTdx )
881+ . unwrap ( ) ;
882+ assert_eq ! ( measurements, mock_dcap_measurements( ) ) ;
883+
884+ let res_body = res. text ( ) . await . unwrap ( ) ;
885+ assert_eq ! ( res_body, "No measurements" ) ;
886+ }
887+
766888 // Server has mock DCAP, client has no attestation and no client auth
767889 #[ tokio:: test]
768890 async fn http_proxy_with_server_attestation ( ) {
0 commit comments