diff --git a/Cargo.lock b/Cargo.lock index 69dbd04..917d680 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1731,9 +1731,11 @@ dependencies = [ "log", "prost", "rand 0.8.5", + "rcgen", "rusqlite", "serde", "tokio", + "tokio-rustls 0.26.4", "toml", ] @@ -2526,6 +2528,18 @@ dependencies = [ "cipher", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "reactor-trait" version = "1.1.0" @@ -4040,6 +4054,15 @@ dependencies = [ "time", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "yoke" version = "0.8.1" diff --git a/ldk-server-cli/src/main.rs b/ldk-server-cli/src/main.rs index 58be095..fb9559c 100644 --- a/ldk-server-cli/src/main.rs +++ b/ldk-server-cli/src/main.rs @@ -49,6 +49,14 @@ struct Cli { #[arg(short, long, required(true))] api_key: String, + #[arg( + short, + long, + required(true), + help = "Path to the server's TLS certificate file (PEM format). Found at /tls_cert.pem" + )] + tls_cert: String, + #[command(subcommand)] command: Commands, } @@ -217,7 +225,18 @@ enum Commands { #[tokio::main] async fn main() { let cli = Cli::parse(); - let client = LdkServerClient::new(cli.base_url, cli.api_key); + + // Load server certificate for TLS verification + let server_cert_pem = std::fs::read(&cli.tls_cert).unwrap_or_else(|e| { + eprintln!("Failed to read server certificate file '{}': {}", cli.tls_cert, e); + std::process::exit(1); + }); + + let client = + LdkServerClient::new(cli.base_url, cli.api_key, &server_cert_pem).unwrap_or_else(|e| { + eprintln!("Failed to create client: {e}"); + std::process::exit(1); + }); match cli.command { Commands::GetNodeInfo => { diff --git a/ldk-server-client/src/client.rs b/ldk-server-client/src/client.rs index 3c76060..060f9bd 100644 --- a/ldk-server-client/src/client.rs +++ b/ldk-server-client/src/client.rs @@ -33,12 +33,16 @@ use ldk_server_protos::endpoints::{ }; use ldk_server_protos::error::{ErrorCode, ErrorResponse}; use reqwest::header::CONTENT_TYPE; -use reqwest::Client; +use reqwest::{Certificate, Client}; use std::time::{SystemTime, UNIX_EPOCH}; const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; /// Client to access a hosted instance of LDK Server. +/// +/// The client requires the server's TLS certificate to be provided for verification. +/// This certificate can be found at `/tls_cert.pem` after the +/// server generates it on first startup. #[derive(Clone)] pub struct LdkServerClient { base_url: String, @@ -48,9 +52,21 @@ pub struct LdkServerClient { impl LdkServerClient { /// Constructs a [`LdkServerClient`] using `base_url` as the ldk-server endpoint. + /// + /// `base_url` should not include the scheme, e.g., `localhost:3000`. /// `api_key` is used for HMAC-based authentication. - pub fn new(base_url: String, api_key: String) -> Self { - Self { base_url, client: Client::new(), api_key } + /// `server_cert_pem` is the server's TLS certificate in PEM format. This can be + /// found at `/tls_cert.pem` after the server starts. + pub fn new(base_url: String, api_key: String, server_cert_pem: &[u8]) -> Result { + let cert = Certificate::from_pem(server_cert_pem) + .map_err(|e| format!("Failed to parse server certificate: {e}"))?; + + let client = Client::builder() + .add_root_certificate(cert) + .build() + .map_err(|e| format!("Failed to build HTTP client: {e}"))?; + + Ok(Self { base_url, client, api_key }) } /// Computes the HMAC-SHA256 authentication header value. @@ -75,7 +91,7 @@ impl LdkServerClient { pub async fn get_node_info( &self, request: GetNodeInfoRequest, ) -> Result { - let url = format!("http://{}/{GET_NODE_INFO_PATH}", self.base_url); + let url = format!("https://{}/{GET_NODE_INFO_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -84,7 +100,7 @@ impl LdkServerClient { pub async fn get_balances( &self, request: GetBalancesRequest, ) -> Result { - let url = format!("http://{}/{GET_BALANCES_PATH}", self.base_url); + let url = format!("https://{}/{GET_BALANCES_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -93,7 +109,7 @@ impl LdkServerClient { pub async fn onchain_receive( &self, request: OnchainReceiveRequest, ) -> Result { - let url = format!("http://{}/{ONCHAIN_RECEIVE_PATH}", self.base_url); + let url = format!("https://{}/{ONCHAIN_RECEIVE_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -102,7 +118,7 @@ impl LdkServerClient { pub async fn onchain_send( &self, request: OnchainSendRequest, ) -> Result { - let url = format!("http://{}/{ONCHAIN_SEND_PATH}", self.base_url); + let url = format!("https://{}/{ONCHAIN_SEND_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -111,7 +127,7 @@ impl LdkServerClient { pub async fn bolt11_receive( &self, request: Bolt11ReceiveRequest, ) -> Result { - let url = format!("http://{}/{BOLT11_RECEIVE_PATH}", self.base_url); + let url = format!("https://{}/{BOLT11_RECEIVE_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -120,7 +136,7 @@ impl LdkServerClient { pub async fn bolt11_send( &self, request: Bolt11SendRequest, ) -> Result { - let url = format!("http://{}/{BOLT11_SEND_PATH}", self.base_url); + let url = format!("https://{}/{BOLT11_SEND_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -129,7 +145,7 @@ impl LdkServerClient { pub async fn bolt12_receive( &self, request: Bolt12ReceiveRequest, ) -> Result { - let url = format!("http://{}/{BOLT12_RECEIVE_PATH}", self.base_url); + let url = format!("https://{}/{BOLT12_RECEIVE_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -138,7 +154,7 @@ impl LdkServerClient { pub async fn bolt12_send( &self, request: Bolt12SendRequest, ) -> Result { - let url = format!("http://{}/{BOLT12_SEND_PATH}", self.base_url); + let url = format!("https://{}/{BOLT12_SEND_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -147,7 +163,7 @@ impl LdkServerClient { pub async fn open_channel( &self, request: OpenChannelRequest, ) -> Result { - let url = format!("http://{}/{OPEN_CHANNEL_PATH}", self.base_url); + let url = format!("https://{}/{OPEN_CHANNEL_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -156,7 +172,7 @@ impl LdkServerClient { pub async fn splice_in( &self, request: SpliceInRequest, ) -> Result { - let url = format!("http://{}/{SPLICE_IN_PATH}", self.base_url); + let url = format!("https://{}/{SPLICE_IN_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -165,7 +181,7 @@ impl LdkServerClient { pub async fn splice_out( &self, request: SpliceOutRequest, ) -> Result { - let url = format!("http://{}/{SPLICE_OUT_PATH}", self.base_url); + let url = format!("https://{}/{SPLICE_OUT_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -174,7 +190,7 @@ impl LdkServerClient { pub async fn close_channel( &self, request: CloseChannelRequest, ) -> Result { - let url = format!("http://{}/{CLOSE_CHANNEL_PATH}", self.base_url); + let url = format!("https://{}/{CLOSE_CHANNEL_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -183,7 +199,7 @@ impl LdkServerClient { pub async fn force_close_channel( &self, request: ForceCloseChannelRequest, ) -> Result { - let url = format!("http://{}/{FORCE_CLOSE_CHANNEL_PATH}", self.base_url); + let url = format!("https://{}/{FORCE_CLOSE_CHANNEL_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -192,7 +208,7 @@ impl LdkServerClient { pub async fn list_channels( &self, request: ListChannelsRequest, ) -> Result { - let url = format!("http://{}/{LIST_CHANNELS_PATH}", self.base_url); + let url = format!("https://{}/{LIST_CHANNELS_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -201,7 +217,7 @@ impl LdkServerClient { pub async fn list_payments( &self, request: ListPaymentsRequest, ) -> Result { - let url = format!("http://{}/{LIST_PAYMENTS_PATH}", self.base_url); + let url = format!("https://{}/{LIST_PAYMENTS_PATH}", self.base_url); self.post_request(&request, &url).await } @@ -210,7 +226,7 @@ impl LdkServerClient { pub async fn update_channel_config( &self, request: UpdateChannelConfigRequest, ) -> Result { - let url = format!("http://{}/{UPDATE_CHANNEL_CONFIG_PATH}", self.base_url); + let url = format!("https://{}/{UPDATE_CHANNEL_CONFIG_PATH}", self.base_url); self.post_request(&request, &url).await } diff --git a/ldk-server/Cargo.toml b/ldk-server/Cargo.toml index 62f82d3..3ec19fe 100644 --- a/ldk-server/Cargo.toml +++ b/ldk-server/Cargo.toml @@ -10,6 +10,8 @@ hyper = { version = "1", default-features = false, features = ["server", "http1" http-body-util = { version = "0.1", default-features = false } hyper-util = { version = "0.1", default-features = false, features = ["server-graceful"] } tokio = { version = "1.38.0", default-features = false, features = ["time", "signal", "rt-multi-thread"] } +tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +rcgen = { version = "0.13", default-features = false, features = ["ring"] } prost = { version = "0.11.6", default-features = false, features = ["std"] } ldk-server-protos = { path = "../ldk-server-protos" } bytes = { version = "1.4.0", default-features = false } diff --git a/ldk-server/ldk-server-config.toml b/ldk-server/ldk-server-config.toml index f419aed..f4485e9 100644 --- a/ldk-server/ldk-server-config.toml +++ b/ldk-server/ldk-server-config.toml @@ -13,6 +13,11 @@ dir_path = "/tmp/ldk-server/" # Path for LDK and BDK data persis level = "Debug" # Log level (Error, Warn, Info, Debug, Trace) file_path = "/tmp/ldk-server/ldk-server.log" # Log file path +[tls] +#cert_path = "/path/to/tls.crt" # Path to TLS certificate, by default uses dir_path/tls.crt +#key_path = "/path/to/tls.key" # Path to TLS private key, by default uses dir_path/tls.key +hosts = ["example.com"] # Allowed hosts for TLS, will always include "localhost" and "127.0.0.1" + # Must set either bitcoind or esplora settings, but not both # Bitcoin Core settings diff --git a/ldk-server/src/main.rs b/ldk-server/src/main.rs index a18f0bb..4c0d74a 100644 --- a/ldk-server/src/main.rs +++ b/ldk-server/src/main.rs @@ -36,6 +36,7 @@ use crate::io::persist::{ use crate::util::config::{load_config, ChainSource}; use crate::util::logger::ServerLogger; use crate::util::proto_adapter::{forwarded_payment_to_proto, payment_to_proto}; +use crate::util::tls::get_or_generate_tls_config; use hex::DisplayHex; use ldk_node::config::Config; use ldk_node::lightning::ln::channelmanager::PaymentId; @@ -155,14 +156,15 @@ fn main() { }, }; - let paginated_store: Arc = - Arc::new(match SqliteStore::new(PathBuf::from(config_file.storage_dir_path), None, None) { + let paginated_store: Arc = Arc::new( + match SqliteStore::new(PathBuf::from(&config_file.storage_dir_path), None, None) { Ok(store) => store, Err(e) => { error!("Failed to create SqliteStore: {e:?}"); std::process::exit(-1); }, - }); + }, + ); #[cfg(not(feature = "events-rabbitmq"))] let event_publisher: Arc = @@ -213,6 +215,20 @@ fn main() { let rest_svc_listener = TcpListener::bind(config_file.rest_service_addr) .await .expect("Failed to bind listening port"); + + let server_config = match get_or_generate_tls_config( + config_file.tls_config, + &config_file.storage_dir_path, + ) { + Ok(config) => config, + Err(e) => { + error!("Failed to set up TLS: {e}"); + std::process::exit(-1); + } + }; + let tls_acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(server_config)); + info!("TLS enabled for REST service on {}", config_file.rest_service_addr); + loop { select! { event = event_node.next_event_async() => { @@ -355,11 +371,17 @@ fn main() { res = rest_svc_listener.accept() => { match res { Ok((stream, _)) => { - let io_stream = TokioIo::new(stream); let node_service = NodeService::new(Arc::clone(&node), Arc::clone(&paginated_store), config_file.api_key.clone()); + let acceptor = tls_acceptor.clone(); runtime.spawn(async move { - if let Err(err) = http1::Builder::new().serve_connection(io_stream, node_service).await { - error!("Failed to serve connection: {}", err); + match acceptor.accept(stream).await { + Ok(tls_stream) => { + let io_stream = TokioIo::new(tls_stream); + if let Err(err) = http1::Builder::new().serve_connection(io_stream, node_service).await { + error!("Failed to serve TLS connection: {err}"); + } + }, + Err(e) => error!("TLS handshake failed: {e}"), } }); }, diff --git a/ldk-server/src/util/config.rs b/ldk-server/src/util/config.rs index 6f7a006..2128f78 100644 --- a/ldk-server/src/util/config.rs +++ b/ldk-server/src/util/config.rs @@ -25,6 +25,7 @@ pub struct Config { pub alias: Option, pub network: Network, pub api_key: String, + pub tls_config: Option, pub rest_service_addr: SocketAddr, pub storage_dir_path: String, pub chain_source: ChainSource, @@ -35,6 +36,13 @@ pub struct Config { pub log_file_path: Option, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TlsConfig { + pub cert_path: Option, + pub key_path: Option, + pub hosts: Vec, +} + #[derive(Debug)] pub enum ChainSource { Rpc { rpc_address: SocketAddr, rpc_user: String, rpc_password: String }, @@ -145,6 +153,12 @@ impl TryFrom for Config { ))? .into()); + let tls_config = toml_config.tls.map(|tls| TlsConfig { + cert_path: tls.cert_path, + key_path: tls.key_path, + hosts: tls.hosts.unwrap_or_default(), + }); + Ok(Config { listening_addr, network: toml_config.node.network, @@ -158,6 +172,7 @@ impl TryFrom for Config { lsps2_service_config, log_level, log_file_path: toml_config.log.and_then(|l| l.file), + tls_config, }) } } @@ -173,6 +188,7 @@ pub struct TomlConfig { rabbitmq: Option, liquidity: Option, log: Option, + tls: Option, } #[derive(Deserialize, Serialize)] @@ -223,6 +239,13 @@ struct RabbitmqConfig { exchange_name: String, } +#[derive(Deserialize, Serialize)] +struct TomlTlsConfig { + cert_path: Option, + key_path: Option, + hosts: Option>, +} + #[derive(Deserialize, Serialize)] struct LiquidityConfig { lsps2_service: Option, @@ -309,6 +332,11 @@ mod tests { alias = "LDK Server" api_key = "test_api_key" + [tls] + cert_path = "/path/to/tls.crt" + key_path = "/path/to/tls.key" + hosts = ["example.com", "ldk-server.local"] + [storage.disk] dir_path = "/tmp" @@ -349,6 +377,11 @@ mod tests { rest_service_addr: SocketAddr::from_str("127.0.0.1:3002").unwrap(), api_key: "test_api_key".to_string(), storage_dir_path: "/tmp".to_string(), + tls_config: Some(TlsConfig { + cert_path: Some("/path/to/tls.crt".to_string()), + key_path: Some("/path/to/tls.key".to_string()), + hosts: vec!["example.com".to_string(), "ldk-server.local".to_string()], + }), chain_source: ChainSource::Esplora { server_url: String::from("https://mempool.space/api"), }, @@ -375,6 +408,7 @@ mod tests { assert_eq!(config.rest_service_addr, expected.rest_service_addr); assert_eq!(config.api_key, expected.api_key); assert_eq!(config.storage_dir_path, expected.storage_dir_path); + assert_eq!(config.tls_config, expected.tls_config); let ChainSource::Esplora { server_url } = config.chain_source else { panic!("unexpected config chain source"); }; diff --git a/ldk-server/src/util/mod.rs b/ldk-server/src/util/mod.rs index 8bcf1c1..3662b12 100644 --- a/ldk-server/src/util/mod.rs +++ b/ldk-server/src/util/mod.rs @@ -10,3 +10,4 @@ pub(crate) mod config; pub(crate) mod logger; pub(crate) mod proto_adapter; +pub(crate) mod tls; diff --git a/ldk-server/src/util/tls.rs b/ldk-server/src/util/tls.rs new file mode 100644 index 0000000..41a9dd7 --- /dev/null +++ b/ldk-server/src/util/tls.rs @@ -0,0 +1,226 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +use crate::util::config::TlsConfig; +use base64::Engine; +use rcgen::{generate_simple_self_signed, CertifiedKey}; +use std::fs; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use tokio_rustls::rustls::ServerConfig; + +// PEM markers +const PEM_CERT_BEGIN: &str = "-----BEGIN CERTIFICATE-----"; +const PEM_CERT_END: &str = "-----END CERTIFICATE-----"; +const PEM_KEY_BEGIN: &str = "-----BEGIN PRIVATE KEY-----"; +const PEM_KEY_END: &str = "-----END PRIVATE KEY-----"; + +/// Gets or generates TLS configuration. If custom paths are provided, uses those. +/// Otherwise, generates a self-signed certificate in the storage directory. +pub fn get_or_generate_tls_config( + tls_config: Option, storage_dir: &str, +) -> Result { + if let Some(config) = tls_config { + let cert_path = config.cert_path.unwrap_or(format!("{storage_dir}/tls.crt")); + let key_path = config.key_path.unwrap_or(format!("{storage_dir}/tls.key")); + if !fs::exists(&cert_path).unwrap_or(false) || !fs::exists(&key_path).unwrap_or(false) { + generate_self_signed_cert(&cert_path, &key_path, &config.hosts)?; + } + load_tls_config(&cert_path, &key_path) + } else { + // Check if we already have generated certs, if we don't, generate new ones + let cert_path = format!("{storage_dir}/tls.crt"); + let key_path = format!("{storage_dir}/tls.key"); + if !fs::exists(&cert_path).unwrap_or(false) || !fs::exists(&key_path).unwrap_or(false) { + generate_self_signed_cert(&cert_path, &key_path, &[])?; + } + + load_tls_config(&cert_path, &key_path) + } +} + +/// Parses a PEM-encoded certificate file and returns the DER-encoded certificates. +fn parse_pem_certs(pem_data: &str) -> Result>, String> { + let mut certs = Vec::new(); + + for block in pem_data.split(PEM_CERT_END) { + if let Some(start) = block.find(PEM_CERT_BEGIN) { + let base64_content: String = block[start + PEM_CERT_BEGIN.len()..] + .lines() + .filter(|line| !line.starts_with("-----") && !line.is_empty()) + .collect(); + + let der = base64::engine::general_purpose::STANDARD + .decode(&base64_content) + .map_err(|e| format!("Failed to decode certificate base64: {e}"))?; + + certs.push(CertificateDer::from(der)); + } + } + + Ok(certs) +} + +/// Parses a PEM-encoded PKCS#8 private key file and returns the DER-encoded key. +fn parse_pem_private_key(pem_data: &str) -> Result, String> { + let start = pem_data.find(PEM_KEY_BEGIN).ok_or("Missing BEGIN PRIVATE KEY marker")?; + let end = pem_data.find(PEM_KEY_END).ok_or("Missing END PRIVATE KEY marker")?; + + let base64_content: String = pem_data[start + PEM_KEY_BEGIN.len()..end] + .lines() + .filter(|line| !line.starts_with("-----") && !line.is_empty()) + .collect(); + + let der = base64::engine::general_purpose::STANDARD + .decode(&base64_content) + .map_err(|e| format!("Failed to decode private key base64: {e}"))?; + + Ok(PrivateKeyDer::Pkcs8(der.into())) +} + +/// Generates a self-signed TLS certificate and saves it to the storage directory. +/// Returns the paths to the generated cert and key files. +fn generate_self_signed_cert( + cert_path: &str, key_path: &str, configure_hosts: &[String], +) -> Result<(), String> { + let mut hosts = vec!["localhost".to_string(), "127.0.0.1".to_string()]; + hosts.extend_from_slice(configure_hosts); + + let CertifiedKey { cert, key_pair } = generate_simple_self_signed(hosts) + .map_err(|e| format!("Failed to generate self-signed certificate: {e}"))?; + + // Convert DER to PEM format + let cert_der = cert.der(); + let key_der = key_pair.serialize_der(); + + let cert_pem = format!( + "{PEM_CERT_BEGIN}\n{}\n{PEM_CERT_END}\n", + base64::engine::general_purpose::STANDARD + .encode(cert_der) + .as_bytes() + .chunks(64) + .map(|chunk| std::str::from_utf8(chunk).unwrap()) + .collect::>() + .join("\n") + ); + + let key_pem = format!( + "{PEM_KEY_BEGIN}\n{}\n{PEM_KEY_END}\n", + base64::engine::general_purpose::STANDARD + .encode(&key_der) + .as_bytes() + .chunks(64) + .map(|chunk| std::str::from_utf8(chunk).unwrap()) + .collect::>() + .join("\n") + ); + + fs::write(cert_path, &cert_pem) + .map_err(|e| format!("Failed to write TLS certificate to '{cert_path}': {e}"))?; + fs::write(key_path, &key_pem) + .map_err(|e| format!("Failed to write TLS key to '{key_path}': {e}"))?; + + Ok(()) +} + +/// Loads TLS configuration from provided paths. +fn load_tls_config(cert_path: &str, key_path: &str) -> Result { + let cert_pem = fs::read_to_string(cert_path) + .map_err(|e| format!("Failed to read TLS certificate file '{cert_path}': {e}"))?; + let key_pem = fs::read_to_string(key_path) + .map_err(|e| format!("Failed to read TLS key file '{key_path}': {e}"))?; + + let certs = parse_pem_certs(&cert_pem)?; + + if certs.is_empty() { + return Err("No certificates found in certificate file".to_string()); + } + + let key = parse_pem_private_key(&key_pem)?; + + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .map_err(|e| format!("Failed to build TLS server config: {e}")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_pem_certs() { + let pem = "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpegPjMCMA0GCSqGSIb3DQEBCwUAMBExDzANBgNVBAMMBnVu\ndXNlZDAeFw0yMzAxMDEwMDAwMDBaFw0yNDAxMDEwMDAwMDBaMBExDzANBgNVBAMM\nBnVudXNlZDBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQC7o96FCEcJsggt0c0dSfEB\nmm6vv1LdCoxXnhOSCutoJgJgmCPBjU1doFFKwAtXjfOv0eSLZ3NHLu0LRKmVvOsP\nAgMBAAGjUzBRMB0GA1UdDgQWBBQK3fc0myO0psd71FJd8v7VCmDJOzAfBgNVHSME\nGDAWgBQK3fc0myO0psd71FJd8v7VCmDJOzAPBgNVHRMBAf8EBTADAQH/MA0GCSqG\nSIb3DQEBCwUAA0EAhJg0cx2pFfVfGBfbJQNFa+A4ynJBMqKYlbUnJBfWPwg13RhC\nivLjYyhKzEbnOug0TuFfVaUBGfBYbPgaJQ4BAg==\n-----END CERTIFICATE-----\n"; + + let certs = parse_pem_certs(pem).unwrap(); + assert_eq!(certs.len(), 1); + assert!(!certs[0].is_empty()); + } + + #[test] + fn test_parse_pem_certs_multiple() { + let pem = "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpegPjMCMA0GCSqGSIb3DQEBCwUAMBExDzANBgNVBAMMBnVu\ndXNlZDAeFw0yMzAxMDEwMDAwMDBaFw0yNDAxMDEwMDAwMDBaMBExDzANBgNVBAMM\nBnVudXNlZDBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQC7o96FCEcJsggt0c0dSfEB\nmm6vv1LdCoxXnhOSCutoJgJgmCPBjU1doFFKwAtXjfOv0eSLZ3NHLu0LRKmVvOsP\nAgMBAAGjUzBRMB0GA1UdDgQWBBQK3fc0myO0psd71FJd8v7VCmDJOzAfBgNVHSME\nGDAWgBQK3fc0myO0psd71FJd8v7VCmDJOzAPBgNVHRMBAf8EBTADAQH/MA0GCSqG\nSIb3DQEBCwUAA0EAhJg0cx2pFfVfGBfbJQNFa+A4ynJBMqKYlbUnJBfWPwg13RhC\nivLjYyhKzEbnOug0TuFfVaUBGfBYbPgaJQ4BAg==\n-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAKHBfpegPjMCMA0GCSqGSIb3DQEBCwUAMBExDzANBgNVBAMMBnVu\ndXNlZDAeFw0yMzAxMDEwMDAwMDBaFw0yNDAxMDEwMDAwMDBaMBExDzANBgNVBAMM\nBnVudXNlZDBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQC7o96FCEcJsggt0c0dSfEB\nmm6vv1LdCoxXnhOSCutoJgJgmCPBjU1doFFKwAtXjfOv0eSLZ3NHLu0LRKmVvOsP\nAgMBAAGjUzBRMB0GA1UdDgQWBBQK3fc0myO0psd71FJd8v7VCmDJOzAfBgNVHSME\nGDAWgBQK3fc0myO0psd71FJd8v7VCmDJOzAPBgNVHRMBAf8EBTADAQH/MA0GCSqG\nSIb3DQEBCwUAA0EAhJg0cx2pFfVfGBfbJQNFa+A4ynJBMqKYlbUnJBfWPwg13RhC\nivLjYyhKzEbnOug0TuFfVaUBGfBYbPgaJQ4BAg==\n-----END CERTIFICATE-----\n"; + + let certs = parse_pem_certs(pem).unwrap(); + assert_eq!(certs.len(), 2); + } + + #[test] + fn test_parse_pem_certs_empty() { + let certs = parse_pem_certs("").unwrap(); + assert!(certs.is_empty()); + + let certs = parse_pem_certs("not a cert").unwrap(); + assert!(certs.is_empty()); + } + + #[test] + fn test_parse_pem_private_key_pkcs8() { + let pem = "-----BEGIN PRIVATE KEY-----\nMIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg2a2rwplBQLzHPDvn\nsaw8HKDP6WYBSF684gcz+D7zeVShRANCAAQq8R/E45tTNWMEpK8abYM7VzuJxpPS\nhJCi6bzjOPGHawEO8safLOWFaV7GqLJM0OdM3eu/qcz8HwgI3T8EVHQK\n-----END PRIVATE KEY-----\n"; + + let key = parse_pem_private_key(pem).unwrap(); + assert!(matches!(key, PrivateKeyDer::Pkcs8(_))); + } + + #[test] + fn test_parse_pem_private_key_invalid() { + let result = parse_pem_private_key(""); + assert!(result.is_err()); + + let result = parse_pem_private_key("not a key"); + assert!(result.is_err()); + } + + #[test] + fn test_generate_and_load_roundtrip() { + let temp_dir = std::env::temp_dir(); + let suffix: u64 = rand::random(); + let cert_path = temp_dir.join(format!("test_tls_cert_{suffix}.pem")); + let key_path = temp_dir.join(format!("test_tls_key_{suffix}.pem")); + + // Clean up any existing files to be safe + let _ = fs::remove_file(&cert_path); + let _ = fs::remove_file(&key_path); + + // Generate cert + generate_self_signed_cert(cert_path.to_str().unwrap(), key_path.to_str().unwrap(), &[]) + .unwrap(); + + // Verify files exist + assert!(cert_path.exists()); + assert!(key_path.exists()); + + // Load config + let res = load_tls_config(cert_path.to_str().unwrap(), key_path.to_str().unwrap()); + assert!(res.is_ok()); + + // Clean up + let _ = fs::remove_file(&cert_path); + let _ = fs::remove_file(&key_path); + } +}