Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions crates/stackable-webhook/src/tls/cert_resolver.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{sync::Arc, time::SystemTime};

use arc_swap::ArcSwap;
use snafu::{OptionExt, ResultExt, Snafu};
Expand Down Expand Up @@ -57,6 +57,9 @@ pub struct CertificateResolver {
/// Using a [`ArcSwap`] (over e.g. [`tokio::sync::RwLock`]), so that we can easily
/// (and performant) bridge between async write and sync read.
current_certified_key: ArcSwap<CertifiedKey>,
/// The wall-clock expiry time (`not_after`) of the current certificate.
/// Used to detect clock drift between monotonic and wall-clock time.
current_not_after: ArcSwap<SystemTime>,
subject_alterative_dns_names: Arc<Vec<String>>,

certificate_tx: mpsc::Sender<Certificate>,
Expand All @@ -68,7 +71,7 @@ impl CertificateResolver {
certificate_tx: mpsc::Sender<Certificate>,
) -> Result<Self> {
let subject_alterative_dns_names = Arc::new(subject_alterative_dns_names);
let certified_key = Self::generate_new_certificate_inner(
let (certified_key, not_after) = Self::generate_new_certificate_inner(
subject_alterative_dns_names.clone(),
&certificate_tx,
)
Expand All @@ -77,20 +80,37 @@ impl CertificateResolver {
Ok(Self {
subject_alterative_dns_names,
current_certified_key: ArcSwap::new(certified_key),
current_not_after: ArcSwap::new(Arc::new(not_after)),
certificate_tx,
})
}

pub async fn rotate_certificate(&self) -> Result<()> {
let certified_key = self.generate_new_certificate().await?;
let (certified_key, not_after) = self.generate_new_certificate().await?;

// TODO: Sign the new cert somehow with the old cert. See https://github.com/stackabletech/decisions/issues/56
self.current_certified_key.store(certified_key);
self.current_not_after.store(Arc::new(not_after));

Ok(())
}

async fn generate_new_certificate(&self) -> Result<Arc<CertifiedKey>> {
/// Returns `true` if the current certificate is expired or will expire
/// within the given `buffer` duration according to wall-clock time.
///
/// This catches cases where the monotonic timer (used by `tokio::time`)
/// has drifted from wall-clock time, e.g. due to system hibernation.
pub fn needs_rotation(&self, buffer: std::time::Duration) -> bool {
let not_after = **self.current_not_after.load();
// If subtraction underflows (buffer > time since epoch), fall back to
// UNIX_EPOCH so that the comparison always triggers rotation.
let deadline = not_after
.checked_sub(buffer)
.unwrap_or(SystemTime::UNIX_EPOCH);
SystemTime::now() >= deadline
}

async fn generate_new_certificate(&self) -> Result<(Arc<CertifiedKey>, SystemTime)> {
let subject_alterative_dns_names = self.subject_alterative_dns_names.clone();
Self::generate_new_certificate_inner(subject_alterative_dns_names, &self.certificate_tx)
.await
Expand All @@ -106,7 +126,7 @@ impl CertificateResolver {
async fn generate_new_certificate_inner(
subject_alterative_dns_names: Arc<Vec<String>>,
certificate_tx: &mpsc::Sender<Certificate>,
) -> Result<Arc<CertifiedKey>> {
) -> Result<(Arc<CertifiedKey>, SystemTime)> {
// The certificate generations can take a while, so we use `spawn_blocking`
let (cert, certified_key) = tokio::task::spawn_blocking(move || {
let tls_provider =
Expand Down Expand Up @@ -144,12 +164,14 @@ impl CertificateResolver {
.await
.context(TokioSpawnBlockingSnafu)??;

let not_after = cert.tbs_certificate.validity.not_after.to_system_time();

certificate_tx
.send(cert)
.await
.map_err(|_err| CertificateResolverError::SendCertificateToChannel)?;

Ok(certified_key)
Ok((certified_key, not_after))
}
}

Expand Down
51 changes: 37 additions & 14 deletions crates/stackable-webhook/src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,25 @@ use crate::{
mod cert_resolver;

pub const WEBHOOK_CA_LIFETIME: Duration = Duration::from_hours_unchecked(24);
pub const WEBHOOK_CERTIFICATE_LIFETIME: Duration = Duration::from_hours_unchecked(24);
pub const WEBHOOK_CERTIFICATE_ROTATION_INTERVAL: Duration = Duration::from_hours_unchecked(20);

/// The wall-clock lifetime of generated webhook certificates. If this is ever
/// reduced, ensure it stays well above [`CERTIFICATE_ROTATION_CHECK_INTERVAL`]
/// (currently 5 minutes), otherwise the certificate could expire between checks.
const WEBHOOK_CERTIFICATE_LIFETIME_HOURS: u64 = 24;
pub const WEBHOOK_CERTIFICATE_LIFETIME: Duration =
Duration::from_hours_unchecked(WEBHOOK_CERTIFICATE_LIFETIME_HOURS);

/// How often to check whether the certificate needs rotation. This is
/// intentionally independent of the certificate lifetime — it controls how
/// quickly we detect wall-clock drift (from hibernation, VM migration, etc.),
/// not how long the certificate lives.
const CERTIFICATE_ROTATION_CHECK_INTERVAL: Duration = Duration::from_minutes_unchecked(5);

/// Rotate the certificate when less than 1/6 of its lifetime remains
/// (4 hours for the current 24h lifetime). Derived from
/// [`WEBHOOK_CERTIFICATE_LIFETIME`] so it scales if the lifetime changes.
const CERTIFICATE_EXPIRY_BUFFER: Duration =
Duration::from_minutes_unchecked(WEBHOOK_CERTIFICATE_LIFETIME_HOURS * 60 / 6);

pub type Result<T, E = TlsServerError> = std::result::Result<T, E>;

Expand Down Expand Up @@ -153,8 +170,12 @@ impl TlsServer {
router,
} = self;

let start = tokio::time::Instant::now() + *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL;
let mut interval = tokio::time::interval_at(start, *WEBHOOK_CERTIFICATE_ROTATION_INTERVAL);
// Periodically check whether the certificate needs rotation based on
// wall-clock time. This avoids the monotonic vs wall-clock drift problem
// that can occur during hibernation, VM migration, or cgroup freezing.
let check_start = tokio::time::Instant::now() + *CERTIFICATE_ROTATION_CHECK_INTERVAL;
let mut rotation_check =
tokio::time::interval_at(check_start, *CERTIFICATE_ROTATION_CHECK_INTERVAL);

let tls_acceptor = TlsAcceptor::from(Arc::new(config));
let tcp_listener = TcpListener::bind(socket_addr)
Expand Down Expand Up @@ -183,11 +204,10 @@ impl TlsServer {
loop {
let tls_acceptor = tls_acceptor.clone();

// Wait for either a new TCP connection or the certificate rotation interval tick
tokio::select! {
// We opt for a biased execution of arms to make sure we always check if a
// shutdown signal was received or the certificate needs rotation based on the
// interval. This ensures, we always use a valid certificate for the TLS connection.
// shutdown signal was received or the certificate needs rotation before
// accepting new connections.
biased;

// Once a shutdown signal is received (this future becomes `Poll::Ready`), break out
Expand All @@ -198,13 +218,16 @@ impl TlsServer {
break;
}

// This is cancellation-safe. If this branch is cancelled, the tick is NOT consumed.
// As such, we will not miss rotating the certificate.
_ = interval.tick() => {
cert_resolver
.rotate_certificate()
.await
.context(RotateCertificateSnafu)?
// Check wall-clock time to decide if the certificate needs rotation.
// This is cancellation-safe: if cancelled, the tick is NOT consumed.
_ = rotation_check.tick() => {
if cert_resolver.needs_rotation(*CERTIFICATE_EXPIRY_BUFFER) {
tracing::info!("certificate approaching expiry, rotating");
cert_resolver
.rotate_certificate()
.await
.context(RotateCertificateSnafu)?;
}
}

// This is cancellation-safe. If cancelled, no new connections are accepted.
Expand Down
Loading