Skip to content

Commit 967a2b7

Browse files
authored
Merge pull request #151 from flashbots/peg/default-to-http2
Default to http2 for proxy-client to proxy-server connections
2 parents 0655858 + 9540908 commit 967a2b7

4 files changed

Lines changed: 164 additions & 32 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ serde = "1.0.228"
3232
reqwest = { version = "0.12.24", default-features = false, features = [
3333
"rustls-tls-webpki-roots-no-provider",
3434
] }
35+
webpki-roots = "1.0.4"
3536
tracing = "0.1.41"
3637
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
3738
axum = "0.8.6"

src/http_version.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,25 @@ impl HttpVersion {
2121
pub fn from_negotiated_protocol_server<IO>(tls: &tokio_rustls::server::TlsStream<IO>) -> Self {
2222
let (_io, conn) = tls.get_ref();
2323

24-
let chosen_protocol = Self::from_alpn_bytes(conn.alpn_protocol());
25-
tracing::debug!("[server] Chosen protocol {chosen_protocol:?}",);
24+
let negotiated_alpn = conn.alpn_protocol();
25+
let chosen_protocol = Self::from_alpn_bytes(negotiated_alpn);
26+
tracing::debug!(
27+
"[server] Negotiated ALPN {:?}, chosen protocol {chosen_protocol:?}",
28+
negotiated_alpn.map(String::from_utf8_lossy)
29+
);
2630
chosen_protocol
2731
}
2832

2933
/// Given a client TLS stream, choose an HTTP version to use
3034
pub fn from_negotiated_protocol_client<IO>(tls: &tokio_rustls::client::TlsStream<IO>) -> Self {
3135
let (_io, conn) = tls.get_ref();
3236

33-
let chosen_protocol = Self::from_alpn_bytes(conn.alpn_protocol());
34-
tracing::debug!("[client] Chosen protocol {chosen_protocol:?}",);
37+
let negotiated_alpn = conn.alpn_protocol();
38+
let chosen_protocol = Self::from_alpn_bytes(negotiated_alpn);
39+
tracing::debug!(
40+
"[client] Negotiated ALPN {:?}, chosen protocol {chosen_protocol:?}",
41+
negotiated_alpn.map(String::from_utf8_lossy)
42+
);
3543
chosen_protocol
3644
}
3745

src/lib.rs

Lines changed: 150 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ use thiserror::Error;
2424
use tokio::io;
2525
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
2626
use tokio::sync::{mpsc, oneshot};
27-
use tokio_rustls::rustls::server::VerifierBuilderError;
28-
use tokio_rustls::rustls::{ClientConfig, ServerConfig, pki_types::CertificateDer};
27+
use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier};
28+
use tokio_rustls::rustls::{
29+
self, ClientConfig, RootCertStore, ServerConfig, pki_types::CertificateDer,
30+
};
2931
use tracing::{debug, error, warn};
3032

3133
use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion};
@@ -59,6 +61,17 @@ type RequestWithResponseSender = (
5961
oneshot::Sender<Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, hyper::Error>>,
6062
);
6163

64+
/// Adds HTTP 1 and 2 to the list of allowed protocols
65+
fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec<Vec<u8>>) {
66+
for protocol in [ALPN_H2, ALPN_HTTP11] {
67+
let already_present = alpn_protocols.iter().any(|p| p.as_slice() == protocol);
68+
69+
if !already_present {
70+
alpn_protocols.push(protocol.to_vec());
71+
}
72+
}
73+
}
74+
6275
/// Retrieve the attested remote TLS certificate.
6376
pub async fn get_tls_cert(
6477
server_name: String,
@@ -101,11 +114,32 @@ impl ProxyServer {
101114
attestation_verifier: AttestationVerifier,
102115
client_auth: bool,
103116
) -> Result<Self, ProxyError> {
104-
let attested_tls_server = AttestedTlsServer::new(
105-
cert_and_key,
117+
let mut server_config = if client_auth {
118+
let root_store =
119+
RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
120+
let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?;
121+
122+
ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
123+
.with_client_cert_verifier(verifier)
124+
.with_single_cert(
125+
cert_and_key.cert_chain.clone(),
126+
cert_and_key.key.clone_key(),
127+
)?
128+
} else {
129+
ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
130+
.with_no_client_auth()
131+
.with_single_cert(
132+
cert_and_key.cert_chain.clone(),
133+
cert_and_key.key.clone_key(),
134+
)?
135+
};
136+
ensure_proxy_alpn_protocols(&mut server_config.alpn_protocols);
137+
138+
let attested_tls_server = AttestedTlsServer::new_with_tls_config(
139+
cert_and_key.cert_chain,
140+
server_config,
106141
attestation_generator,
107142
attestation_verifier,
108-
client_auth,
109143
)?;
110144

111145
let listener = TcpListener::bind(local).await?;
@@ -126,16 +160,7 @@ impl ProxyServer {
126160
attestation_generator: AttestationGenerator,
127161
attestation_verifier: AttestationVerifier,
128162
) -> Result<Self, ProxyError> {
129-
for protocol in [ALPN_H2, ALPN_HTTP11] {
130-
let already_present = server_config
131-
.alpn_protocols
132-
.iter()
133-
.any(|p| p.as_slice() == protocol);
134-
135-
if !already_present {
136-
server_config.alpn_protocols.push(protocol.to_vec());
137-
}
138-
}
163+
ensure_proxy_alpn_protocols(&mut server_config.alpn_protocols);
139164

140165
let attested_tls_server = AttestedTlsServer::new_with_tls_config(
141166
cert_chain,
@@ -347,11 +372,34 @@ impl ProxyClient {
347372
attestation_verifier: AttestationVerifier,
348373
remote_certificate: Option<CertificateDer<'static>>,
349374
) -> Result<Self, ProxyError> {
350-
let attested_tls_client = AttestedTlsClient::new(
351-
cert_and_key,
375+
let root_store = match remote_certificate {
376+
Some(remote_certificate) => {
377+
let mut root_store = RootCertStore::empty();
378+
root_store.add(remote_certificate)?;
379+
root_store
380+
}
381+
None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()),
382+
};
383+
384+
let mut client_config = if let Some(ref cert_and_key) = cert_and_key {
385+
ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
386+
.with_root_certificates(root_store)
387+
.with_client_auth_cert(
388+
cert_and_key.cert_chain.clone(),
389+
cert_and_key.key.clone_key(),
390+
)?
391+
} else {
392+
ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
393+
.with_root_certificates(root_store)
394+
.with_no_client_auth()
395+
};
396+
ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols);
397+
398+
let attested_tls_client = AttestedTlsClient::new_with_tls_config(
399+
client_config,
352400
attestation_generator,
353401
attestation_verifier,
354-
remote_certificate,
402+
cert_and_key.map(|c| c.cert_chain),
355403
)?;
356404

357405
Self::new_with_inner(address, attested_tls_client, &server_name).await
@@ -366,16 +414,7 @@ impl ProxyClient {
366414
attestation_verifier: AttestationVerifier,
367415
cert_chain: Option<Vec<CertificateDer<'static>>>,
368416
) -> Result<Self, ProxyError> {
369-
for protocol in [ALPN_H2, ALPN_HTTP11] {
370-
let already_present = client_config
371-
.alpn_protocols
372-
.iter()
373-
.any(|p| p.as_slice() == protocol);
374-
375-
if !already_present {
376-
client_config.alpn_protocols.push(protocol.to_vec());
377-
}
378-
}
417+
ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols);
379418

380419
let attested_tls_client = AttestedTlsClient::new_with_tls_config(
381420
client_config,
@@ -763,6 +802,89 @@ mod tests {
763802
generate_tls_config_with_client_auth, init_tracing, mock_dcap_measurements,
764803
};
765804

805+
#[test]
806+
fn proxy_alpn_protocols_prefer_http2() {
807+
let mut protocols = Vec::new();
808+
ensure_proxy_alpn_protocols(&mut protocols);
809+
810+
assert_eq!(protocols, vec![ALPN_H2.to_vec(), ALPN_HTTP11.to_vec()]);
811+
}
812+
813+
#[test]
814+
fn proxy_alpn_protocols_preserve_existing_order_without_duplicates() {
815+
let mut protocols = vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()];
816+
ensure_proxy_alpn_protocols(&mut protocols);
817+
818+
assert_eq!(protocols, vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()]);
819+
}
820+
821+
#[tokio::test]
822+
async fn http_proxy_default_constructors_work() {
823+
let target_addr = example_http_service().await;
824+
825+
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
826+
let server_cert = cert_chain[0].clone();
827+
828+
let proxy_server = ProxyServer::new(
829+
TlsCertAndKey {
830+
cert_chain,
831+
key: private_key,
832+
},
833+
"127.0.0.1:0",
834+
target_addr.to_string(),
835+
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
836+
AttestationVerifier::expect_none(),
837+
false,
838+
)
839+
.await
840+
.unwrap();
841+
842+
let proxy_addr = proxy_server.local_addr().unwrap();
843+
844+
tokio::spawn(async move {
845+
proxy_server.accept().await.unwrap();
846+
});
847+
848+
let proxy_client = ProxyClient::new(
849+
None,
850+
"127.0.0.1:0".to_string(),
851+
proxy_addr.to_string(),
852+
AttestationGenerator::with_no_attestation(),
853+
AttestationVerifier::mock(),
854+
Some(server_cert),
855+
)
856+
.await
857+
.unwrap();
858+
859+
let proxy_client_addr = proxy_client.local_addr().unwrap();
860+
861+
tokio::spawn(async move {
862+
proxy_client.accept().await.unwrap();
863+
});
864+
865+
let res = reqwest::get(format!("http://{}", proxy_client_addr))
866+
.await
867+
.unwrap();
868+
869+
let headers = res.headers();
870+
871+
let attestation_type = headers
872+
.get(ATTESTATION_TYPE_HEADER)
873+
.unwrap()
874+
.to_str()
875+
.unwrap();
876+
assert_eq!(attestation_type, AttestationType::DcapTdx.as_str());
877+
878+
let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap();
879+
let measurements =
880+
MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx)
881+
.unwrap();
882+
assert_eq!(measurements, mock_dcap_measurements());
883+
884+
let res_body = res.text().await.unwrap();
885+
assert_eq!(res_body, "No measurements");
886+
}
887+
766888
// Server has mock DCAP, client has no attestation and no client auth
767889
#[tokio::test]
768890
async fn http_proxy_with_server_attestation() {

0 commit comments

Comments
 (0)