Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ serde = "1.0.228"
reqwest = { version = "0.12.24", default-features = false, features = [
"rustls-tls-webpki-roots-no-provider",
] }
webpki-roots = "1.0.4"
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
axum = "0.8.6"
Expand Down
16 changes: 12 additions & 4 deletions src/http_version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,25 @@ impl HttpVersion {
pub fn from_negotiated_protocol_server<IO>(tls: &tokio_rustls::server::TlsStream<IO>) -> Self {
let (_io, conn) = tls.get_ref();

let chosen_protocol = Self::from_alpn_bytes(conn.alpn_protocol());
tracing::debug!("[server] Chosen protocol {chosen_protocol:?}",);
let negotiated_alpn = conn.alpn_protocol();
let chosen_protocol = Self::from_alpn_bytes(negotiated_alpn);
tracing::debug!(
"[server] Negotiated ALPN {:?}, chosen protocol {chosen_protocol:?}",
negotiated_alpn.map(String::from_utf8_lossy)
);
chosen_protocol
}

/// Given a client TLS stream, choose an HTTP version to use
pub fn from_negotiated_protocol_client<IO>(tls: &tokio_rustls::client::TlsStream<IO>) -> Self {
let (_io, conn) = tls.get_ref();

let chosen_protocol = Self::from_alpn_bytes(conn.alpn_protocol());
tracing::debug!("[client] Chosen protocol {chosen_protocol:?}",);
let negotiated_alpn = conn.alpn_protocol();
let chosen_protocol = Self::from_alpn_bytes(negotiated_alpn);
tracing::debug!(
"[client] Negotiated ALPN {:?}, chosen protocol {chosen_protocol:?}",
negotiated_alpn.map(String::from_utf8_lossy)
);
chosen_protocol
}

Expand Down
178 changes: 150 additions & 28 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ use thiserror::Error;
use tokio::io;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::sync::{mpsc, oneshot};
use tokio_rustls::rustls::server::VerifierBuilderError;
use tokio_rustls::rustls::{ClientConfig, ServerConfig, pki_types::CertificateDer};
use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier};
use tokio_rustls::rustls::{
self, ClientConfig, RootCertStore, ServerConfig, pki_types::CertificateDer,
};
use tracing::{debug, error, warn};

use crate::http_version::{ALPN_H2, ALPN_HTTP11, HttpConnection, HttpSender, HttpVersion};
Expand Down Expand Up @@ -59,6 +61,17 @@ type RequestWithResponseSender = (
oneshot::Sender<Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, hyper::Error>>,
);

/// Adds HTTP 1 and 2 to the list of allowed protocols
fn ensure_proxy_alpn_protocols(alpn_protocols: &mut Vec<Vec<u8>>) {
for protocol in [ALPN_H2, ALPN_HTTP11] {
let already_present = alpn_protocols.iter().any(|p| p.as_slice() == protocol);

if !already_present {
alpn_protocols.push(protocol.to_vec());
}
}
}

/// Retrieve the attested remote TLS certificate.
pub async fn get_tls_cert(
server_name: String,
Expand Down Expand Up @@ -101,11 +114,32 @@ impl ProxyServer {
attestation_verifier: AttestationVerifier,
client_auth: bool,
) -> Result<Self, ProxyError> {
let attested_tls_server = AttestedTlsServer::new(
cert_and_key,
let mut server_config = if client_auth {
let root_store =
RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?;

ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
.with_client_cert_verifier(verifier)
.with_single_cert(
cert_and_key.cert_chain.clone(),
cert_and_key.key.clone_key(),
)?
} else {
ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
.with_no_client_auth()
.with_single_cert(
cert_and_key.cert_chain.clone(),
cert_and_key.key.clone_key(),
)?
};
ensure_proxy_alpn_protocols(&mut server_config.alpn_protocols);

let attested_tls_server = AttestedTlsServer::new_with_tls_config(
cert_and_key.cert_chain,
server_config,
attestation_generator,
attestation_verifier,
client_auth,
)?;

let listener = TcpListener::bind(local).await?;
Expand All @@ -126,16 +160,7 @@ impl ProxyServer {
attestation_generator: AttestationGenerator,
attestation_verifier: AttestationVerifier,
) -> Result<Self, ProxyError> {
for protocol in [ALPN_H2, ALPN_HTTP11] {
let already_present = server_config
.alpn_protocols
.iter()
.any(|p| p.as_slice() == protocol);

if !already_present {
server_config.alpn_protocols.push(protocol.to_vec());
}
}
ensure_proxy_alpn_protocols(&mut server_config.alpn_protocols);

let attested_tls_server = AttestedTlsServer::new_with_tls_config(
cert_chain,
Expand Down Expand Up @@ -347,11 +372,34 @@ impl ProxyClient {
attestation_verifier: AttestationVerifier,
remote_certificate: Option<CertificateDer<'static>>,
) -> Result<Self, ProxyError> {
let attested_tls_client = AttestedTlsClient::new(
cert_and_key,
let root_store = match remote_certificate {
Some(remote_certificate) => {
let mut root_store = RootCertStore::empty();
root_store.add(remote_certificate)?;
root_store
}
None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()),
};

let mut client_config = if let Some(ref cert_and_key) = cert_and_key {
ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
.with_root_certificates(root_store)
.with_client_auth_cert(
cert_and_key.cert_chain.clone(),
cert_and_key.key.clone_key(),
)?
} else {
ClientConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
.with_root_certificates(root_store)
.with_no_client_auth()
};
ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols);

let attested_tls_client = AttestedTlsClient::new_with_tls_config(
client_config,
attestation_generator,
attestation_verifier,
remote_certificate,
cert_and_key.map(|c| c.cert_chain),
)?;

Self::new_with_inner(address, attested_tls_client, &server_name).await
Expand All @@ -366,16 +414,7 @@ impl ProxyClient {
attestation_verifier: AttestationVerifier,
cert_chain: Option<Vec<CertificateDer<'static>>>,
) -> Result<Self, ProxyError> {
for protocol in [ALPN_H2, ALPN_HTTP11] {
let already_present = client_config
.alpn_protocols
.iter()
.any(|p| p.as_slice() == protocol);

if !already_present {
client_config.alpn_protocols.push(protocol.to_vec());
}
}
ensure_proxy_alpn_protocols(&mut client_config.alpn_protocols);

let attested_tls_client = AttestedTlsClient::new_with_tls_config(
client_config,
Expand Down Expand Up @@ -763,6 +802,89 @@ mod tests {
generate_tls_config_with_client_auth, init_tracing, mock_dcap_measurements,
};

#[test]
fn proxy_alpn_protocols_prefer_http2() {
let mut protocols = Vec::new();
ensure_proxy_alpn_protocols(&mut protocols);

assert_eq!(protocols, vec![ALPN_H2.to_vec(), ALPN_HTTP11.to_vec()]);
}

#[test]
fn proxy_alpn_protocols_preserve_existing_order_without_duplicates() {
let mut protocols = vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()];
ensure_proxy_alpn_protocols(&mut protocols);

assert_eq!(protocols, vec![ALPN_HTTP11.to_vec(), ALPN_H2.to_vec()]);
}

#[tokio::test]
async fn http_proxy_default_constructors_work() {
let target_addr = example_http_service().await;

let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
let server_cert = cert_chain[0].clone();

let proxy_server = ProxyServer::new(
TlsCertAndKey {
cert_chain,
key: private_key,
},
"127.0.0.1:0",
target_addr.to_string(),
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
AttestationVerifier::expect_none(),
false,
)
.await
.unwrap();

let proxy_addr = proxy_server.local_addr().unwrap();

tokio::spawn(async move {
proxy_server.accept().await.unwrap();
});

let proxy_client = ProxyClient::new(
None,
"127.0.0.1:0".to_string(),
proxy_addr.to_string(),
AttestationGenerator::with_no_attestation(),
AttestationVerifier::mock(),
Some(server_cert),
)
.await
.unwrap();

let proxy_client_addr = proxy_client.local_addr().unwrap();

tokio::spawn(async move {
proxy_client.accept().await.unwrap();
});

let res = reqwest::get(format!("http://{}", proxy_client_addr))
.await
.unwrap();

let headers = res.headers();

let attestation_type = headers
.get(ATTESTATION_TYPE_HEADER)
.unwrap()
.to_str()
.unwrap();
assert_eq!(attestation_type, AttestationType::DcapTdx.as_str());

let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap();
let measurements =
MultiMeasurements::from_header_format(measurements_json, AttestationType::DcapTdx)
.unwrap();
assert_eq!(measurements, mock_dcap_measurements());

let res_body = res.text().await.unwrap();
assert_eq!(res_body, "No measurements");
}

// Server has mock DCAP, client has no attestation and no client auth
#[tokio::test]
async fn http_proxy_with_server_attestation() {
Expand Down