Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 133 additions & 4 deletions pingora-core/src/listeners/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,24 @@ pub trait TlsAccept {

pub type TlsAcceptCallbacks = Box<dyn TlsAccept + Send + Sync>;

/// Some protocols, such as the proxy protocol, must be inspected before the TLS
/// handshake. The below trait provides access to the raw TCP stream right
/// before TLS for these situations.
#[async_trait]
pub trait InspectPreTls: Send + Sync {
/// The implementation can read bytes from the stream (e.g., PROXY protocol header)
/// before the TLS handshake takes place.
///
/// If this method returns an error, the connection will be dropped.
async fn inspect(&self, stream: &mut L4Stream) -> Result<()>;
}

struct TransportStackBuilder {
l4: ServerAddress,
tls: Option<TlsSettings>,
#[cfg(feature = "connection_filter")]
connection_filter: Option<Arc<dyn ConnectionFilter>>,
pre_tls_inspector: Option<Arc<dyn InspectPreTls>>,
}

impl TransportStackBuilder {
Expand All @@ -148,6 +161,7 @@ impl TransportStackBuilder {
Ok(TransportStack {
l4,
tls: self.tls.take().map(|tls| Arc::new(tls.build())),
pre_tls_inspector: self.pre_tls_inspector.clone(),
})
}
}
Expand All @@ -156,6 +170,7 @@ impl TransportStackBuilder {
pub(crate) struct TransportStack {
l4: ListenerEndpoint,
tls: Option<Arc<Acceptor>>,
pre_tls_inspector: Option<Arc<dyn InspectPreTls>>,
}

impl TransportStack {
Expand All @@ -168,6 +183,7 @@ impl TransportStack {
Ok(UninitializedStream {
l4: stream,
tls: self.tls.clone(),
pre_tls_inspector: self.pre_tls_inspector.clone(),
})
}

Expand All @@ -179,17 +195,27 @@ impl TransportStack {
pub(crate) struct UninitializedStream {
l4: L4Stream,
tls: Option<Arc<Acceptor>>,
pre_tls_inspector: Option<Arc<dyn InspectPreTls>>,
}

impl UninitializedStream {
pub async fn handshake(mut self) -> Result<Stream> {
self.l4.set_buffer();
if let Some(tls) = self.tls {

// Expose raw l4 stream to any registered pre-TLS inspectors before
// handshaking.
if let Some(inspector) = self.pre_tls_inspector.as_ref() {
inspector.inspect(&mut self.l4).await?;
}

let res_with_stream: Result<Stream> = if let Some(tls) = self.tls {
let tls_stream = tls.tls_handshake(self.l4).await?;
Ok(Box::new(tls_stream))
} else {
Ok(Box::new(self.l4))
}
};

res_with_stream
}

/// Get the peer address of the connection if available
Expand All @@ -205,6 +231,7 @@ pub struct Listeners {
stacks: Vec<TransportStackBuilder>,
#[cfg(feature = "connection_filter")]
connection_filter: Option<Arc<dyn ConnectionFilter>>,
pre_tls_inspector: Option<Arc<dyn InspectPreTls>>,
}

impl Listeners {
Expand All @@ -214,6 +241,7 @@ impl Listeners {
stacks: vec![],
#[cfg(feature = "connection_filter")]
connection_filter: None,
pre_tls_inspector: None,
}
}
/// Create a new [`Listeners`] with a TCP server endpoint from the given string.
Expand Down Expand Up @@ -294,13 +322,31 @@ impl Listeners {
}
}

/// Set a pre-TLS inspector for all endpoints in this listener collection.
///
/// The inspector will be invoked after TCP accept but before the TLS handshake,
/// allowing the application to read and process data such as PROXY protocol
/// headers that arrive before TLS.
pub fn set_pre_tls_inspector(&mut self, inspector: Arc<dyn InspectPreTls>) {
log::debug!("Setting pre-TLS inspector on Listeners");

// Store the inspector for future endpoints
self.pre_tls_inspector = Some(inspector.clone());

// Apply to existing stacks
for stack in &mut self.stacks {
stack.pre_tls_inspector = Some(inspector.clone());
}
}

/// Add the given [`ServerAddress`] to `self` with the given [`TlsSettings`] if provided
pub fn add_endpoint(&mut self, l4: ServerAddress, tls: Option<TlsSettings>) {
self.stacks.push(TransportStackBuilder {
l4,
tls,
#[cfg(feature = "connection_filter")]
connection_filter: self.connection_filter.clone(),
pre_tls_inspector: self.pre_tls_inspector.clone(),
})
}

Expand Down Expand Up @@ -341,8 +387,8 @@ mod test {

#[tokio::test]
async fn test_listen_tcp() {
let addr1 = "127.0.0.1:7101";
let addr2 = "127.0.0.1:7102";
let addr1 = "127.0.0.1:7107";
let addr2 = "127.0.0.1:7108";
let mut listeners = Listeners::tcp(addr1);
listeners.add_tcp(addr2);

Expand Down Expand Up @@ -460,4 +506,87 @@ mod test {
);
}
}

#[tokio::test]
#[cfg(any(feature = "openssl", feature = "boringssl"))]
async fn test_inspect_pre_tls() {
use pingora_error::{Error, Result};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncReadExt, AsyncWriteExt};

use crate::protocols::tls::SslStream;
use crate::tls::ssl;
struct HelloInspector {
stored_bytes: Arc<Mutex<Vec<u8>>>,
}

#[async_trait]
impl InspectPreTls for HelloInspector {
async fn inspect(&self, stream: &mut L4Stream) -> Result<()> {
let mut buf = [0u8; 5];
stream.read_exact(&mut buf).await.map_err(|e| {
Error::new_str("failed to read pre-TLS bytes").more_context(format!("{e}"))
})?;
self.stored_bytes.lock().unwrap().extend_from_slice(&buf);
if &buf != b"hello" {
return Err(Error::new_str("pre-TLS bytes did not match 'hello'"));
}
Ok(())
}
}

let stored = Arc::new(Mutex::new(Vec::new()));
let inspector = Arc::new(HelloInspector {
stored_bytes: stored.clone(),
});

let addr = "127.0.0.1:7109";
let cert_path = format!("{}/tests/keys/server.crt", env!("CARGO_MANIFEST_DIR"));
let key_path = format!("{}/tests/keys/key.pem", env!("CARGO_MANIFEST_DIR"));
let mut listeners = Listeners::tls(addr, &cert_path, &key_path).unwrap();

// Register HelloInspector on the listener so it fires before TLS handshaking.
listeners.set_pre_tls_inspector(inspector.clone());
let listener = listeners
.build(
#[cfg(unix)]
None,
)
.await
.unwrap()
.pop()
.unwrap();

let server_handle = tokio::spawn(async move {
// Acceptor thread should handshake, which will perform pre-TLS inspection
// and then the TLS handshake.
let stream = listener.accept().await.unwrap();
stream.handshake().await.unwrap();
});

// make sure the above starts before the lines below
sleep(Duration::from_millis(10)).await;

let client_handle = tokio::spawn(async move {
// Prepend the TLS handshake with the bytes "hello".
let mut tcp_stream = tokio::net::TcpStream::connect(addr).await.unwrap();
tcp_stream.write_all(b"hello").await.unwrap();

// Perform the TLS handshake with verification disabled because the
// certificates aren't actually valid.
let ssl_context = ssl::SslContext::builder(ssl::SslMethod::tls())
.unwrap()
.build();
let mut ssl_obj = ssl::Ssl::new(&ssl_context).unwrap();
ssl_obj.set_verify(ssl::SslVerifyMode::NONE);
let mut tls_stream = SslStream::new(ssl_obj, tcp_stream).unwrap();
Pin::new(&mut tls_stream).connect().await.unwrap();
});

server_handle.await.unwrap();
client_handle.await.unwrap();

assert_eq!(&*stored.lock().unwrap(), b"hello");
}
}
5 changes: 4 additions & 1 deletion pingora-core/src/protocols/l4/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,10 @@ impl Stream {
}

/// Put Some data back to the head of the stream to be read again
pub(crate) fn rewind(&mut self, data: &[u8]) {
/// This can be used in cases where we "peek" at data only to find
/// it doesn't match what's expected, and so it needs to be put back
/// for a different protocol to potentially use it.
pub fn rewind(&mut self, data: &[u8]) {
if !data.is_empty() {
self.rewind_read_buf.push(data.to_vec());
}
Expand Down
1 change: 1 addition & 0 deletions pingora-core/src/protocols/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ impl ALPN {
ALPN::H1 => vec![b"http/1.1".to_vec()],
ALPN::H2 => vec![b"h2".to_vec()],
ALPN::H2H1 => vec![b"h2".to_vec(), b"http/1.1".to_vec()],
ALPN::Custom(custom) => vec![custom.protocol().to_vec()],
}
}
}
Expand Down
Loading