diff --git a/Cargo.lock b/Cargo.lock index 2275275a..d752c4d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,9 +104,9 @@ dependencies = [ [[package]] name = "alloy-chains" -version = "0.2.31" +version = "0.2.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d9d22005bf31b018f31ef9ecadb5d2c39cf4f6acc8db0456f72c815f3d7f757" +checksum = "9247f0a399ef71aeb68f497b2b8fb348014f742b50d3b83b1e00dfe1b7d64b3d" dependencies = [ "alloy-primitives", "num_enum", @@ -202,7 +202,7 @@ dependencies = [ "itoa", "serde", "serde_json", - "winnow", + "winnow 0.7.15", ] [[package]] @@ -615,7 +615,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6df77fea9d6a2a75c0ef8d2acbdfd92286cc599983d3175ccdc170d3433d249" dependencies = [ "serde", - "winnow", + "winnow 0.7.15", ] [[package]] @@ -2375,6 +2375,12 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "dyn-eq" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c2d035d21af5cde1a6f5c7b444a5bf963520a9f142e5d06931178433d7d5388" + [[package]] name = "ecdsa" version = "0.16.9" @@ -5513,16 +5519,28 @@ name = "pluto-core" version = "1.7.1" dependencies = [ "alloy", + "anyhow", + "async-trait", "base64 0.22.1", "built", "cancellation", "chrono", + "clap", "crossbeam", + "dyn-clone", + "dyn-eq", + "futures", + "futures-timer", "hex", + "k256", "libp2p", "pluto-build-proto", + "pluto-cluster", "pluto-eth2api", "pluto-eth2util", + "pluto-p2p", + "pluto-testutil", + "pluto-tracing", "prost 0.14.3", "prost-types 0.14.3", "rand 0.8.5", @@ -5533,7 +5551,11 @@ dependencies = [ "test-case", "thiserror 2.0.18", "tokio", + "tokio-util", + "tracing", "tree_hash", + "unsigned-varint 0.8.0", + "vise", ] [[package]] @@ -5666,6 +5688,7 @@ dependencies = [ "pluto-k1util", "pluto-testutil", "pluto-tracing", + "prost 
0.14.3", "rand 0.8.5", "reqwest 0.13.2", "serde", @@ -5675,11 +5698,36 @@ dependencies = [ "tokio", "tokio-util", "tracing", + "unsigned-varint 0.8.0", "url", "vise", "vise-exporter", ] +[[package]] +name = "pluto-parsigex" +version = "1.7.1" +dependencies = [ + "anyhow", + "clap", + "either", + "futures", + "futures-timer", + "hex", + "libp2p", + "pluto-cluster", + "pluto-core", + "pluto-p2p", + "pluto-tracing", + "prost 0.14.3", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tokio-util", + "tracing", + "unsigned-varint 0.8.0", +] + [[package]] name = "pluto-peerinfo" version = "1.7.1" @@ -5741,6 +5789,7 @@ dependencies = [ "hex", "k256", "pluto-crypto", + "pluto-eth2api", "rand 0.8.5", "rand_core 0.6.4", "thiserror 2.0.18", @@ -5862,7 +5911,7 @@ version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" dependencies = [ - "toml_edit 0.25.4+spec-1.1.0", + "toml_edit 0.25.5+spec-1.1.0", ] [[package]] @@ -7669,9 +7718,9 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "1.0.0+spec-1.1.0" +version = "1.0.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +checksum = "9b320e741db58cac564e26c607d3cc1fdc4a88fd36c879568c07856ed83ff3e9" dependencies = [ "serde_core", ] @@ -7687,28 +7736,28 @@ dependencies = [ "serde_spanned", "toml_datetime 0.6.11", "toml_write", - "winnow", + "winnow 0.7.15", ] [[package]] name = "toml_edit" -version = "0.25.4+spec-1.1.0" +version = "0.25.5+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" +checksum = "8ca1a40644a28bce036923f6a431df0b34236949d111cc07cb6dca830c9ef2e1" dependencies = [ "indexmap 2.13.0", - "toml_datetime 1.0.0+spec-1.1.0", + "toml_datetime 1.0.1+spec-1.1.0", "toml_parser", - "winnow", + 
"winnow 1.0.0", ] [[package]] name = "toml_parser" -version = "1.0.9+spec-1.1.0" +version = "1.0.10+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +checksum = "7df25b4befd31c4816df190124375d5a20c6b6921e2cad937316de3fccd63420" dependencies = [ - "winnow", + "winnow 1.0.0", ] [[package]] @@ -8931,6 +8980,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "winnow" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a90e88e4667264a994d34e6d1ab2d26d398dcdca8b7f52bec8668957517fc7d8" +dependencies = [ + "memchr", +] + [[package]] name = "winreg" version = "0.50.0" diff --git a/Cargo.toml b/Cargo.toml index 93eddca0..c0c3f4f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "crates/app", + "crates/parsigex", "crates/build-proto", "crates/cli", "crates/cluster", @@ -27,6 +28,7 @@ license = "BUSL-1.1" publish = false [workspace.dependencies] +async-trait = "0.1.89" alloy = { version = "1.3", features = ["essentials"] } built = { version = "0.8.0", features = ["git2", "chrono", "cargo-lock"] } blst = "0.3" @@ -36,6 +38,8 @@ cancellation = "0.1.0" chrono = { version = "0.4", features = ["serde"] } clap = { version = "4.5.53", features = ["derive", "env", "cargo"] } crossbeam = "0.8.4" +dyn-clone = "1.0" +dyn-eq = "0.1.3" either = "1.13" futures = "0.3" futures-timer = "3.0" @@ -95,6 +99,7 @@ wiremock = "0.6" # Crates in the workspace pluto-app = { path = "crates/app" } +pluto-parsigex = { path = "crates/parsigex" } pluto-build-proto = { path = "crates/build-proto" } pluto-cli = { path = "crates/cli" } pluto-cluster = { path = "crates/cluster" } diff --git a/crates/app/src/deadline/mod.rs b/crates/app/src/deadline/mod.rs deleted file mode 100644 index a8c39809..00000000 --- a/crates/app/src/deadline/mod.rs +++ /dev/null @@ -1,53 +0,0 @@ -use pluto_core::types::{Duty, DutyType}; -use 
pluto_eth2api::{EthBeaconNodeApiClient, EthBeaconNodeApiClientError}; - -/// Defines the fraction of the slot duration to use as a margin. -/// This is to consider network delays and other factors that may affect the -/// timing. -pub const MARGIN_FACTOR: u32 = 12; - -/// A function that returns the deadline for a duty. -pub type DeadlineFunc = Box Option> + Send + Sync>; - -/// Error type for deadline-related operations. -#[derive(Debug, thiserror::Error)] -pub enum DeadlineError { - /// Beacon client API error. - #[error("Beacon client error: {0}")] - BeaconClientError(#[from] EthBeaconNodeApiClientError), -} - -type Result = std::result::Result; - -/// Create a function that provides duty deadline or [`None`] if the duty never -/// deadlines. -pub async fn new_duty_deadline_func(eth2_cl: &EthBeaconNodeApiClient) -> Result { - let genesis_time = eth2_cl.fetch_genesis_time().await?; - let (slot_duration, _) = eth2_cl.fetch_slots_config().await?; - - #[allow( - clippy::arithmetic_side_effects, - reason = "Matches original implementation" - )] - Ok(Box::new(move |duty: Duty| match duty.duty_type { - DutyType::Exit | DutyType::BuilderRegistration => None, - _ => { - #[allow( - clippy::cast_possible_truncation, - reason = "TODO: unsupported operation in u64" - )] - let start = genesis_time + (slot_duration * (u64::from(duty.slot)) as u32); - let margin = slot_duration / MARGIN_FACTOR; - - let duration = match duty.duty_type { - DutyType::Proposer | DutyType::Randao => slot_duration / 3, - DutyType::SyncMessage => 2 * slot_duration / 3, - DutyType::Attester | DutyType::Aggregator | DutyType::PrepareAggregator => { - 2 * slot_duration - } - _ => slot_duration, - }; - Some(start + duration + margin) - } - })) -} diff --git a/crates/app/src/lib.rs b/crates/app/src/lib.rs index e0634ce2..d6cf9287 100644 --- a/crates/app/src/lib.rs +++ b/crates/app/src/lib.rs @@ -13,9 +13,6 @@ pub mod log; /// until the deadline has elapsed. 
pub mod retry; -/// Deadline -pub mod deadline; - /// Featureset defines a set of global features and their rollout status. pub mod featureset; diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index ea060e56..3559cafd 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -7,33 +7,50 @@ license.workspace = true publish.workspace = true [dependencies] +async-trait.workspace = true cancellation.workspace = true chrono.workspace = true crossbeam.workspace = true +futures.workspace = true +futures-timer.workspace = true +dyn-clone.workspace = true +dyn-eq.workspace = true hex.workspace = true +libp2p.workspace = true +vise.workspace = true +pluto-eth2api.workspace = true +prost.workspace = true +prost-types.workspace = true +regex.workspace = true serde.workspace = true serde_json.workspace = true serde_with.workspace = true base64.workspace = true thiserror.workspace = true tokio.workspace = true -libp2p.workspace = true -regex.workspace = true -prost.workspace = true -prost-types.workspace = true -pluto-eth2api.workspace = true +tokio-util.workspace = true +tracing.workspace = true pluto-eth2util.workspace = true tree_hash.workspace = true +unsigned-varint.workspace = true [dev-dependencies] +anyhow.workspace = true alloy.workspace = true +clap.workspace = true rand.workspace = true libp2p.workspace = true +k256.workspace = true prost.workspace = true prost-types.workspace = true hex.workspace = true chrono.workspace = true test-case.workspace = true +pluto-eth2util.workspace = true +pluto-cluster.workspace = true +pluto-p2p.workspace = true +pluto-testutil.workspace = true +pluto-tracing.workspace = true [build-dependencies] pluto-build-proto.workspace = true diff --git a/crates/core/src/deadline.rs b/crates/core/src/deadline.rs new file mode 100644 index 00000000..47cd7fd4 --- /dev/null +++ b/crates/core/src/deadline.rs @@ -0,0 +1,925 @@ +//! Duty deadline tracking and notification functionality. +//! +//! 
This module provides the [`Deadliner`] trait for tracking duty deadlines +//! and notifying when duties expire. It implements a background task that +//! manages timers for multiple duties and sends expired duties to a channel. +//! +//! # Example +//! +//! ```no_run +//! use chrono::{DateTime, Utc}; +//! use pluto_core::{ +//! deadline::{DeadlineFunc, new_deadliner}, +//! types::{Duty, DutyType, SlotNumber}, +//! }; +//! use std::sync::Arc; +//! use tokio_util::sync::CancellationToken; +//! +//! # async fn example() { +//! let cancel_token = CancellationToken::new(); +//! +//! // Define a deadline function +//! let deadline_func: DeadlineFunc = Arc::new(|_duty| { +//! let deadline = DateTime::from_timestamp(1000, 0).unwrap(); +//! Ok(Some(deadline)) +//! }); +//! +//! let deadliner = new_deadliner(cancel_token, "example", deadline_func); +//! +//! // Add a duty +//! let duty = Duty::new_attester_duty(SlotNumber::new(1)); +//! let added = deadliner.add(duty).await; +//! +//! // Receive expired duties +//! if let Some(mut rx) = deadliner.c() { +//! while let Some(expired_duty) = rx.recv().await { +//! println!("Duty expired: {}", expired_duty); +//! } +//! } +//! # } +//! ``` +use crate::types::{Duty, DutyType, SlotNumber}; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use futures::future::{BoxFuture, FutureExt}; +use pluto_eth2api::{EthBeaconNodeApiClient, EthBeaconNodeApiClientError}; +use std::{ + collections::HashSet, + sync::{Arc, Mutex}, +}; +use tokio_util::sync::CancellationToken; + +/// Fraction of slot duration to use as a margin for network delays. +const MARGIN_FACTOR: i32 = 12; + +/// Type alias for the deadline function. +/// +/// Takes a duty and returns an optional deadline. +/// Returns `Ok(Some(deadline))` if the duty expires at the given time. +/// Returns `Ok(None)` if the duty never expires. +pub type DeadlineFunc = Arc Result>> + Send + Sync>; + +/// Error types for deadline operations. 
+#[derive(Debug, thiserror::Error)] +pub enum DeadlineError { + /// Failed to fetch genesis time from beacon node. + #[error("Failed to fetch genesis time: {0}")] + FetchGenesisTime(#[from] EthBeaconNodeApiClientError), + + /// Deadliner has been shut down. + #[error("Deadliner has been shut down")] + Shutdown, + + /// Arithmetic overflow in deadline calculation. + #[error("Arithmetic overflow in deadline calculation")] + ArithmeticOverflow, + + /// Duration conversion failed. + #[error("Duration conversion failed")] + DurationConversion, + + /// DateTime calculation failed. + #[error("DateTime calculation failed")] + DateTimeCalculation, +} + +/// Result type for deadline operations. +pub type Result = std::result::Result; + +/// Converts a `std::time::Duration` to `chrono::Duration`. +fn to_chrono_duration(duration: std::time::Duration) -> Result { + chrono::Duration::from_std(duration).map_err(|_| DeadlineError::DurationConversion) +} + +/// Converts seconds (u64) to `chrono::Duration`. +fn secs_to_chrono(secs: u64) -> Result { + let secs_i64 = i64::try_from(secs).map_err(|_| DeadlineError::ArithmeticOverflow)?; + chrono::Duration::try_seconds(secs_i64).ok_or(DeadlineError::DurationConversion) +} + +/// Deadliner provides duty deadline functionality. +/// +/// The `c()` method returns a channel for receiving expired duties. +/// It may only be called once and the returned channel should be used +/// by a single task. Multiple instances are required for different +/// components and use cases. +pub trait Deadliner: Send + Sync { + /// Adds a duty for deadline scheduling. + /// + /// Returns `true` if the duty was added for future deadline scheduling. + /// This method is idempotent and returns `true` if the duty was previously + /// added and still awaits deadline scheduling. 
+ /// + /// Returns `false` if: + /// - The duty has already expired and cannot be scheduled + /// - The duty never expires (e.g., Exit, BuilderRegistration) + fn add(&self, duty: Duty) -> BoxFuture<'_, bool>; + + /// Returns the channel for receiving deadlined duties. + /// + /// This method may only be called once and returns `None` on subsequent + /// calls. The returned channel should only be used by a single task. + fn c(&self) -> Option>; +} + +/// Trait for beacon clients that can provide genesis time and slot +/// configuration. +/// +/// This trait abstracts the necessary beacon node API calls for deadline +/// calculation. +#[async_trait] +pub trait BeaconClientForDeadline { + /// Fetches the genesis time from the beacon node. + async fn fetch_genesis_time(&self) -> Result>; + + /// Fetches the slot duration and slots per epoch from the beacon node. + async fn fetch_slots_config(&self) -> Result<(std::time::Duration, u64)>; +} + +#[async_trait] +impl BeaconClientForDeadline for EthBeaconNodeApiClient { + async fn fetch_genesis_time(&self) -> Result> { + self.fetch_genesis_time() + .await + .map_err(DeadlineError::FetchGenesisTime) + } + + async fn fetch_slots_config(&self) -> Result<(std::time::Duration, u64)> { + self.fetch_slots_config() + .await + .map_err(DeadlineError::FetchGenesisTime) + } +} + +/// Creates a deadline function from the Ethereum 2.0 beacon node configuration. +/// +/// Fetches genesis time and slot duration from the beacon node and returns +/// a function that calculates deadlines for each duty type. +/// +/// # Errors +/// +/// Returns an error if fetching genesis time or slots config fails. 
+pub async fn new_duty_deadline_func( + client: &C, +) -> Result { + let genesis_time = client.fetch_genesis_time().await?; + let (slot_duration, _slots_per_epoch) = client.fetch_slots_config().await?; + + // Convert std::time::Duration to chrono::Duration for slot_duration + let slot_duration = to_chrono_duration(slot_duration)?; + + Ok(Arc::new(move |duty: Duty| { + // Exit and BuilderRegistration duties never expire + match duty.duty_type { + DutyType::Exit | DutyType::BuilderRegistration => { + return Ok(None); + } + _ => {} + } + + // Calculate slot start time + // start = genesis_time + (slot * slot_duration) + let slot_secs = duty + .slot + .inner() + .checked_mul( + u64::try_from(slot_duration.num_seconds()) + .map_err(|_| DeadlineError::ArithmeticOverflow)?, + ) + .ok_or(DeadlineError::ArithmeticOverflow)?; + let slot_offset = secs_to_chrono(slot_secs)?; + + let start: DateTime = genesis_time + .checked_add_signed(slot_offset) + .ok_or(DeadlineError::DateTimeCalculation)?; + + // Calculate margin: slot_duration / MARGIN_FACTOR + let margin = slot_duration + .checked_div(MARGIN_FACTOR) + .ok_or(DeadlineError::ArithmeticOverflow)?; + + // Calculate duty-specific duration + let duration = match duty.duty_type { + DutyType::Proposer | DutyType::Randao => { + // duration = slot_duration / 3 + slot_duration + .checked_div(3) + .ok_or(DeadlineError::ArithmeticOverflow)? + } + DutyType::SyncMessage => { + // duration = 2 * slot_duration / 3 + slot_duration + .checked_mul(2) + .and_then(|s| s.checked_div(3)) + .ok_or(DeadlineError::ArithmeticOverflow)? + } + DutyType::Attester | DutyType::Aggregator | DutyType::PrepareAggregator => { + // duration = 2 * slot_duration + // Even though attestations and aggregations are acceptable after 2 slots, + // the rewards are heavily diminished. + slot_duration + .checked_mul(2) + .ok_or(DeadlineError::ArithmeticOverflow)? 
+ } + _ => { + // Default: duration = slot_duration + slot_duration + } + }; + + // Calculate final deadline: start + duration + margin + let deadline = start + .checked_add_signed(duration) + .and_then(|t| t.checked_add_signed(margin)) + .ok_or(DeadlineError::DateTimeCalculation)?; + + Ok(Some(deadline)) + })) +} + +/// Gets the duty with the earliest deadline from the duties map. +/// +/// Returns a tuple of (duty, deadline). If no duties are available, +/// returns a sentinel far-future date (9999-01-01). +fn get_curr_duty(duties: &HashSet, deadline_func: &DeadlineFunc) -> (Duty, DateTime) { + let mut curr_duty = Duty::new(SlotNumber::new(0), DutyType::Unknown); + + // Use far-future sentinel date (9999-01-01) matching Go implementation + // This timestamp is a known constant and will never fail + let mut curr_deadline = + DateTime::from_timestamp(253402300799, 0).unwrap_or(DateTime::::MAX_UTC); + + for duty in duties.iter() { + let Ok(deadline_opt) = deadline_func(duty.clone()) else { + continue; + }; + + // Ignore duties that never expire + let Some(duty_deadline) = deadline_opt else { + continue; + }; + + // Update if this duty has an earlier deadline + if duty_deadline < curr_deadline { + curr_duty = duty.clone(); + curr_deadline = duty_deadline; + } + } + + (curr_duty, curr_deadline) +} + +/// Internal message type for adding duties to the deadliner. +struct DeadlineInput { + duty: Duty, + response_tx: tokio::sync::oneshot::Sender, +} + +/// Implementation of the Deadliner trait. 
+struct DeadlinerImpl { + cancel_token: CancellationToken, + input_tx: tokio::sync::mpsc::UnboundedSender, + output_rx: Arc>>>, +} + +impl Deadliner for DeadlinerImpl { + fn add(&self, duty: Duty) -> BoxFuture<'_, bool> { + Box::pin(async move { + // Check if shut down + if self.cancel_token.is_cancelled() { + return false; + } + + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let input = DeadlineInput { duty, response_tx }; + + // Send the duty to the background task + if self.input_tx.send(input).is_err() { + return false; + } + + // Wait for response + response_rx.await.unwrap_or(false) + }) + } + + fn c(&self) -> Option> { + self.output_rx + .lock() + .ok() + .and_then(|mut guard| guard.take()) + } +} + +/// Clock trait for abstracting time operations. +trait Clock: Send + Sync { + /// Returns the current time. + fn now(&self) -> DateTime; + + /// Creates a sleep future that completes after the given duration. + fn sleep(&self, duration: std::time::Duration) -> BoxFuture<'static, ()>; +} + +/// Real clock implementation using tokio::time. +struct RealClock; + +impl Clock for RealClock { + fn now(&self) -> DateTime { + Utc::now() + } + + fn sleep(&self, duration: std::time::Duration) -> BoxFuture<'static, ()> { + tokio::time::sleep(duration).boxed() + } +} + +impl DeadlinerImpl { + /// Background task that manages duty deadlines. + /// + /// This is an associated function (not a method) because the DeadlinerImpl + /// is immediately wrapped in Arc, preventing mutable access. 
+ async fn run_task( + cancel_token: CancellationToken, + label: String, + deadline_func: DeadlineFunc, + clock: Arc, + mut input_rx: tokio::sync::mpsc::UnboundedReceiver, + output_tx: tokio::sync::mpsc::Sender, + ) { + let mut duties: HashSet = HashSet::new(); + let (mut curr_duty, mut curr_deadline) = get_curr_duty(&duties, &deadline_func); + + // Create initial timer + let now = clock.now(); + let initial_duration = curr_deadline + .signed_duration_since(now) + .to_std() + .unwrap_or(std::time::Duration::ZERO); + let mut timer = clock.sleep(initial_duration); + + loop { + tokio::select! { + biased; + + _ = cancel_token.cancelled() => { + return; + } + + Some(input) = input_rx.recv() => { + let duty = input.duty; + let Ok(deadline_opt) = deadline_func(duty.clone()) else { + let _ = input.response_tx.send(false); + continue; + }; + + // Drop duties that never expire + let Some(deadline) = deadline_opt else { + let _ = input.response_tx.send(false); + continue; + }; + + let now = clock.now(); + let expired = deadline < now; + + let _ = input.response_tx.send(!expired); + + // Ignore expired duties + if expired { + continue; + } + + // Add duty to the map (idempotent) + duties.insert(duty); + + // Update timer if this deadline is earlier + if deadline < curr_deadline { + let (new_duty, new_deadline) = get_curr_duty(&duties, &deadline_func); + curr_duty = new_duty; + curr_deadline = new_deadline; + + let duration = curr_deadline + .signed_duration_since(clock.now()) + .to_std() + .unwrap_or(std::time::Duration::ZERO); + timer = clock.sleep(duration); + } + } + + _ = &mut timer => { + // Deadline expired - send duty to output channel + match output_tx.try_send(curr_duty.clone()) { + Ok(()) => {} + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + tracing::warn!( + label = %label, + duty = %curr_duty, + "Deadliner output channel full" + ); + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + return; + } + } + + // Remove duty from map + 
duties.remove(&curr_duty); + + // Update to next duty + let (new_duty, new_deadline) = get_curr_duty(&duties, &deadline_func); + curr_duty = new_duty; + curr_deadline = new_deadline; + + let duration = curr_deadline + .signed_duration_since(clock.now()) + .to_std() + .unwrap_or(std::time::Duration::ZERO); + timer = clock.sleep(duration); + } + } + } + } + + /// Internal constructor for creating a deadliner with a specific clock. + fn new_internal( + cancel_token: CancellationToken, + label: impl Into, + deadline_func: DeadlineFunc, + clock: Arc, + ) -> Arc { + const OUTPUT_BUFFER: usize = 10; + + let label = label.into(); + let (input_tx, input_rx) = tokio::sync::mpsc::unbounded_channel(); + let (output_tx, output_rx) = tokio::sync::mpsc::channel(OUTPUT_BUFFER); + + let impl_instance: Arc = Arc::new(DeadlinerImpl { + cancel_token: cancel_token.clone(), + input_tx, + output_rx: Arc::new(Mutex::new(Some(output_rx))), + }); + + // Spawn background task + tokio::spawn(Self::run_task( + cancel_token, + label, + deadline_func, + clock, + input_rx, + output_tx, + )); + + impl_instance + } +} + +/// Creates a new Deadliner instance. +/// +/// Starts a background task that manages duty deadlines and sends expired +/// duties to a channel. The background task runs until the cancellation token +/// is cancelled. +/// +/// # Arguments +/// +/// * `cancel_token` - Token to cancel the background task +/// * `label` - Label for logging purposes +/// * `deadline_func` - Function that calculates deadlines for duties +/// +/// # Returns +/// +/// An Arc-wrapped Deadliner trait object +pub fn new_deadliner( + cancel_token: CancellationToken, + label: impl Into, + deadline_func: DeadlineFunc, +) -> Arc { + DeadlinerImpl::new_internal(cancel_token, label, deadline_func, Arc::new(RealClock)) +} + +/// Creates a new Deadliner instance for testing with a fake clock. +/// +/// This constructor is intended for use in tests where you need to control +/// time progression. 
+/// +/// # Arguments +/// +/// * `cancel_token` - Token to cancel the background task +/// * `label` - Label for logging purposes +/// * `deadline_func` - Function that calculates deadlines for duties +/// * `clock` - Test clock for controlling time in tests +/// +/// # Returns +/// +/// An Arc-wrapped Deadliner trait object +#[cfg(test)] +fn new_deadliner_for_test( + cancel_token: CancellationToken, + label: impl Into, + deadline_func: DeadlineFunc, + clock: Arc, +) -> Arc { + DeadlinerImpl::new_internal(cancel_token, label, deadline_func, clock) +} + +/// Fake clock implementation for testing. +#[cfg(test)] +type WakerList = Vec<(DateTime, std::task::Waker)>; + +#[cfg(test)] +struct TestClock { + start: std::sync::Arc>>, + wakers: std::sync::Arc>, +} + +#[cfg(test)] +impl TestClock { + fn new(start: DateTime) -> Self { + Self { + start: std::sync::Arc::new(std::sync::Mutex::new(start)), + wakers: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())), + } + } + + fn advance(&self, duration: std::time::Duration) { + let new_time = { + let mut start = self.start.lock().unwrap(); + let chrono_duration = chrono::Duration::from_std(duration).unwrap(); + *start = start.checked_add_signed(chrono_duration).unwrap(); + *start + }; + + // Wake all timers that have expired + let mut wakers = self.wakers.lock().unwrap(); + let (expired, pending): (Vec<_>, Vec<_>) = wakers + .drain(..) 
+ .partition(|(deadline, _)| *deadline <= new_time); + *wakers = pending; + + // Wake expired futures + for (_, waker) in expired { + waker.wake(); + } + } +} + +#[cfg(test)] +impl Clock for TestClock { + fn now(&self) -> DateTime { + *self.start.lock().unwrap() + } + + fn sleep(&self, duration: std::time::Duration) -> BoxFuture<'static, ()> { + let deadline = self + .now() + .checked_add_signed(chrono::Duration::from_std(duration).unwrap()) + .unwrap(); + let wakers = Arc::clone(&self.wakers); + let start = Arc::clone(&self.start); + + Box::pin(std::future::poll_fn(move |cx| { + let now = *start.lock().unwrap(); + if now >= deadline { + std::task::Poll::Ready(()) + } else { + // Register waker + let mut wakers = wakers.lock().unwrap(); + // Check if this waker is already registered for this deadline + if !wakers.iter().any(|(d, _)| *d == deadline) { + wakers.push((deadline, cx.waker().clone())); + } + std::task::Poll::Pending + } + })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::SlotNumber; + use test_case::test_case; + + /// Helper function to create expired duties, non-expired duties, and + /// voluntary exits. + fn setup_data() -> (Vec, Vec, Vec) { + let expired_duties = vec![ + Duty::new_attester_duty(SlotNumber::new(1)), + Duty::new_proposer_duty(SlotNumber::new(2)), + Duty::new_randao_duty(SlotNumber::new(3)), + ]; + + let non_expired_duties = vec![ + Duty::new_proposer_duty(SlotNumber::new(1)), + Duty::new_attester_duty(SlotNumber::new(2)), + ]; + + let voluntary_exits = vec![ + Duty::new_voluntary_exit_duty(SlotNumber::new(2)), + Duty::new_voluntary_exit_duty(SlotNumber::new(4)), + ]; + + (expired_duties, non_expired_duties, voluntary_exits) + } + + /// Helper function to add duties to the deadliner and send results to a + /// channel. 
+ async fn add_duties( + duties: Vec, + deadliner: Arc, + result_tx: tokio::sync::mpsc::Sender, + ) { + for duty in duties { + let added = deadliner.add(duty).await; + let _ = result_tx.send(added).await; + } + } + + #[tokio::test] + async fn test_deadliner() { + let (expired_duties, non_expired_duties, voluntary_exits) = setup_data(); + + let start_time = DateTime::from_timestamp(1000, 0).unwrap(); + let clock = Arc::new(TestClock::new(start_time)); + + // Create a deadline function provider + let expired_set: std::collections::HashSet<_> = expired_duties.iter().cloned().collect(); + let deadline_func: DeadlineFunc = { + Arc::new(move |duty: Duty| { + if duty.duty_type == DutyType::Exit { + // Voluntary exits expire after 1 hour + let deadline = start_time + .checked_add_signed(chrono::Duration::try_hours(1).unwrap()) + .ok_or(DeadlineError::DateTimeCalculation)?; + return Ok(Some(deadline)); + } + + if expired_set.contains(&duty) { + // Expired duties have deadline 1 hour in the past + let deadline = start_time + .checked_sub_signed(chrono::Duration::try_hours(1).unwrap()) + .ok_or(DeadlineError::DateTimeCalculation)?; + return Ok(Some(deadline)); + } + + // Non-expired duties expire after duty.slot seconds from start + let deadline = start_time + .checked_add_signed( + chrono::Duration::try_seconds(i64::try_from(duty.slot.inner()).unwrap()) + .unwrap(), + ) + .ok_or(DeadlineError::DateTimeCalculation)?; + Ok(Some(deadline)) + }) + }; + + let cancel_token = CancellationToken::new(); + let deadliner = new_deadliner_for_test( + cancel_token.clone(), + "test", + deadline_func, + Arc::clone(&clock), + ); + + // Get the output receiver + let mut output_rx = deadliner.c().expect("should get receiver"); + + // Separate channels for expired and non-expired results + let (expired_tx, mut expired_rx) = tokio::sync::mpsc::channel(100); + let (non_expired_tx, mut non_expired_rx) = tokio::sync::mpsc::channel(100); + + // Add all duties + let expired_len = 
expired_duties.len(); + let non_expired_len = non_expired_duties.len(); + let voluntary_exits_len = voluntary_exits.len(); + + let handler_expired = tokio::spawn(add_duties( + expired_duties, + Arc::clone(&deadliner), + expired_tx, + )); + let handler_non_expired = tokio::spawn(add_duties( + non_expired_duties.clone(), + Arc::clone(&deadliner), + non_expired_tx.clone(), + )); + let handler_voluntary_exits = tokio::spawn(add_duties( + voluntary_exits, + Arc::clone(&deadliner), + non_expired_tx, + )); + + // Wait for all handlers to complete + let (result_expired, result_non_expired, result_voluntary_exits) = tokio::join!( + handler_expired, + handler_non_expired, + handler_voluntary_exits + ); + result_expired.unwrap(); + result_non_expired.unwrap(); + result_voluntary_exits.unwrap(); + + for _ in 0..expired_len { + let result = expired_rx.recv().await.expect("should receive result"); + assert!(!result, "expired duties should return false"); + } + + for _ in 0..(non_expired_len.checked_add(voluntary_exits_len).unwrap()) { + let result = non_expired_rx.recv().await.expect("should receive result"); + assert!(result, "non-expired duties should return true"); + } + + // Find max slot from non-expired duties + let max_slot = non_expired_duties + .iter() + .map(|d| d.slot.inner()) + .max() + .unwrap(); + + // Advance clock to trigger deadline of all non-expired duties + clock.advance(std::time::Duration::from_secs(max_slot)); + + // Give the deadliner task time to wake up and process + // We need to yield multiple times to ensure the background task runs + for _ in 0..10 { + tokio::task::yield_now().await; + } + + // Collect expired duties from output channel + let mut actual_duties = Vec::new(); + for _ in 0..non_expired_len { + let duty = tokio::time::timeout(std::time::Duration::from_secs(1), output_rx.recv()) + .await + .expect("should receive within timeout") + .expect("should receive duty"); + actual_duties.push(duty); + } + + // Sort both for comparison + 
actual_duties.sort_by_key(|d| d.slot.inner()); + let mut expected_duties = non_expired_duties; + expected_duties.sort_by_key(|d| d.slot.inner()); + + assert_eq!(expected_duties, actual_duties); + + cancel_token.cancel(); + } + + #[test_case(DutyType::Exit ; "exit")] + #[test_case(DutyType::BuilderRegistration ; "builder_registration")] + #[tokio::test] + async fn test_never_expire_duties(duty_type: DutyType) { + let mock_client = create_mock_client(); + + let deadline_func = new_duty_deadline_func(&mock_client) + .await + .expect("should create deadline func"); + + let duty = Duty::new(SlotNumber::new(100), duty_type); + let result = deadline_func(duty).expect("should compute deadline"); + + assert_eq!(result, None, "duty should never expire"); + } + + // todo: uses hardcode beacon client for testing, should be refactored to use a + // real beacon client (testutils/beaconmock) + #[test_case(DutyType::Proposer ; "proposer")] + #[test_case(DutyType::Attester ; "attester")] + #[test_case(DutyType::Aggregator ; "aggregator")] + #[test_case(DutyType::PrepareAggregator ; "prepare_aggregator")] + #[test_case(DutyType::SyncMessage ; "sync_message")] + #[test_case(DutyType::SyncContribution ; "sync_contribution")] + #[test_case(DutyType::Randao ; "randao")] + #[test_case(DutyType::InfoSync ; "info_sync")] + #[test_case(DutyType::PrepareSyncContribution ; "prepare_sync_contribution")] + #[tokio::test] + async fn test_duty_deadline_durations(duty_type: DutyType) { + let mock_client = create_mock_client(); + + let genesis_time = mock_client.fetch_genesis_time().await.unwrap(); + let (slot_duration, _) = mock_client.fetch_slots_config().await.unwrap(); + + let margin = slot_duration + .checked_div(12) + .expect("margin calculation should not fail"); + + let time_since_genesis = Utc::now().signed_duration_since(genesis_time); + let slot_duration_chrono = to_chrono_duration(slot_duration).unwrap(); + let current_slot = u64::try_from( + time_since_genesis + .num_seconds() + 
.checked_div(slot_duration_chrono.num_seconds()) + .expect("slot duration should not be zero"), + ) + .expect("current slot should be positive"); + + let slot_start = { + let offset_secs = current_slot + .checked_mul(slot_duration.as_secs()) + .expect("slot offset should not overflow"); + let offset = chrono::Duration::try_seconds( + i64::try_from(offset_secs).expect("offset should fit in i64"), + ) + .expect("offset should be valid duration"); + genesis_time + .checked_add_signed(offset) + .expect("slot start should not overflow") + }; + + let deadline_func = new_duty_deadline_func(&mock_client) + .await + .expect("should create deadline func"); + + let expected_duration = match duty_type { + DutyType::Proposer | DutyType::Randao => { + // slotDuration/3 + margin + slot_duration + .checked_div(3) + .and_then(|d| d.checked_add(margin)) + .expect("duration calculation should not fail") + } + DutyType::Attester | DutyType::Aggregator | DutyType::PrepareAggregator => { + // 2*slotDuration + margin + slot_duration + .checked_mul(2) + .and_then(|d| d.checked_add(margin)) + .expect("duration calculation should not fail") + } + DutyType::SyncMessage => { + // 2*slotDuration/3 + margin + slot_duration + .checked_mul(2) + .and_then(|d| d.checked_div(3)) + .and_then(|d| d.checked_add(margin)) + .expect("duration calculation should not fail") + } + DutyType::SyncContribution | DutyType::InfoSync | DutyType::PrepareSyncContribution => { + // slotDuration + margin + slot_duration + .checked_add(margin) + .expect("duration calculation should not fail") + } + _ => panic!("unexpected duty type: {:?}", duty_type), + }; + + let duty = Duty::new(SlotNumber::new(current_slot), duty_type.clone()); + + let now_before_deadline = slot_start + .checked_add_signed(to_chrono_duration(expected_duration).unwrap()) + .and_then(|t| t.checked_sub_signed(chrono::Duration::try_milliseconds(1).unwrap())) + .expect("time calculation should not fail"); + + let deadline_opt = 
deadline_func(duty.clone()).expect("should compute deadline"); + + assert!( + deadline_opt.is_some(), + "duty {:?} should have a deadline", + duty_type + ); + + let deadline = deadline_opt.unwrap(); + + assert!( + now_before_deadline < deadline, + "duty {:?}: now ({}) should be before deadline ({})", + duty_type, + now_before_deadline, + deadline + ); + + let time_until_deadline = deadline.signed_duration_since(now_before_deadline); + assert_eq!( + time_until_deadline, + chrono::Duration::try_milliseconds(1).unwrap(), + "duty {:?}: deadline should be exactly 1ms after now (actual: {}ms)", + duty_type, + time_until_deadline.num_milliseconds() + ); + } + + /// Creates a mock EthBeaconNodeApiClient for testing. + fn create_mock_client() -> MockBeaconClient { + MockBeaconClient { + genesis_time: DateTime::from_timestamp(1646092800, 0).unwrap(), /* 2022-03-01 + * 00:00:00 UTC */ + slot_duration: std::time::Duration::from_secs(12), + slots_per_epoch: 16, + } + } + + /// Mock beacon client for testing. + struct MockBeaconClient { + genesis_time: DateTime, + slot_duration: std::time::Duration, + slots_per_epoch: u64, + } + + #[async_trait] + impl BeaconClientForDeadline for MockBeaconClient { + async fn fetch_genesis_time(&self) -> Result> { + Ok(self.genesis_time) + } + + async fn fetch_slots_config(&self) -> Result<(std::time::Duration, u64)> { + Ok((self.slot_duration, self.slots_per_epoch)) + } + } +} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index d696f161..2e356963 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -19,3 +19,17 @@ pub mod corepb; /// Semver version parsing utilities. pub mod version; + +/// Duty deadline tracking and notification. +pub mod deadline; + +/// parsigdb +pub mod parsigdb; + +mod parsigex_codec; + +pub use parsigex_codec::ParSigExCodecError; + +/// Test utilities. 
+#[cfg(test)] +pub mod testutils; diff --git a/crates/core/src/parasigdb/memory.rs b/crates/core/src/parsigdb/memory.rs similarity index 87% rename from crates/core/src/parasigdb/memory.rs rename to crates/core/src/parsigdb/memory.rs index 1a20e14e..e4d025e5 100644 --- a/crates/core/src/parasigdb/memory.rs +++ b/crates/core/src/parsigdb/memory.rs @@ -5,7 +5,8 @@ use tracing::{debug, warn}; use crate::{ deadline::Deadliner, - parasigdb::metrics::PARASIG_DB_METRICS, + parsigdb::metrics::PARSIG_DB_METRICS, + signeddata::SignedDataError, types::{Duty, DutyType, ParSignedData, ParSignedDataSet, PubKey}, }; use chrono::{DateTime, Utc}; @@ -55,9 +56,9 @@ pub type ThreshSub = Arc< /// Helper to create an internal subscriber from a closure. /// -/// The closure receives owned copies of the duty and data set. Since the closure -/// is `Fn` (can be called multiple times), you need to clone any captured Arc values -/// before the `async move` block. +/// The closure receives owned copies of the duty and data set. Since the +/// closure is `Fn` (can be called multiple times), you need to clone any +/// captured Arc values before the `async move` block. /// /// # Example /// ```ignore @@ -88,8 +89,8 @@ where /// Helper to create a threshold subscriber from a closure. /// /// The closure receives owned copies of the duty and data. Since the closure -/// is `Fn` (can be called multiple times), you need to clone any captured Arc values -/// before the `async move` block. +/// is `Fn` (can be called multiple times), you need to clone any captured Arc +/// values before the `async move` block. /// /// # Example /// ```ignore @@ -128,6 +129,10 @@ pub enum MemDBError { /// Share index of the mismatched signature share_idx: u64, }, + + /// Signed data error. 
+ #[error("signed data error: {0}")] + SignedDataError(#[from] SignedDataError), } type Result = std::result::Result; @@ -186,8 +191,8 @@ impl MemDB { impl MemDB { /// Registers a subscriber for internally generated partial signed data. /// - /// The subscriber will be called when the node generates partial signed data - /// that needs to be exchanged with peers. + /// The subscriber will be called when the node generates partial signed + /// data that needs to be exchanged with peers. pub async fn subscribe_internal(&self, sub: InternalSub) -> Result<()> { let mut inner = self.inner.lock().await; inner.internal_subs.push(sub); @@ -204,11 +209,13 @@ impl MemDB { Ok(()) } - /// Stores internally generated partial signed data and notifies subscribers. + /// Stores internally generated partial signed data and notifies + /// subscribers. /// - /// This is called when the node generates partial signed data that needs to be - /// stored and exchanged with peers. It first stores the data (via `store_external`), - /// then calls all internal subscribers to trigger peer exchange. + /// This is called when the node generates partial signed data that needs to + /// be stored and exchanged with peers. It first stores the data (via + /// `store_external`), then calls all internal subscribers to trigger + /// peer exchange. pub async fn store_internal(&self, duty: &Duty, signed_set: &ParSignedDataSet) -> Result<()> { self.store_external(duty, signed_set).await?; @@ -226,9 +233,10 @@ impl MemDB { /// Stores externally received partial signed data and checks for threshold. /// - /// This is called when the node receives partial signed data from peers. It stores - /// the data, checks if enough matching signatures have been collected to meet the - /// threshold, and calls threshold subscribers when the threshold is reached. + /// This is called when the node receives partial signed data from peers. 
It + /// stores the data, checks if enough matching signatures have been + /// collected to meet the threshold, and calls threshold subscribers + /// when the threshold is reached. pub async fn store_external(&self, duty: &Duty, signed_data: &ParSignedDataSet) -> Result<()> { let _ = self.deadliner.add(duty.clone()).await; @@ -239,7 +247,7 @@ impl MemDB { .store( Key { duty: duty.clone(), - pub_key: pub_key.clone(), + pub_key: *pub_key, }, par_signed.clone(), ) @@ -257,7 +265,7 @@ impl MemDB { continue; }; - output.insert(pub_key.clone(), psigs); + output.insert(*pub_key, psigs); } if output.is_empty() { @@ -278,17 +286,15 @@ impl MemDB { /// Trims expired duties from the database. /// - /// This method runs in a loop, listening for expired duties from the deadliner - /// and removing their associated data from the database. It should be spawned - /// as a background task and will run until the cancellation token is triggered. + /// This method runs in a loop, listening for expired duties from the + /// deadliner and removing their associated data from the database. It + /// should be spawned as a background task and will run until the + /// cancellation token is triggered. pub async fn trim(&self) { - let deadliner_rx = self.deadliner.c(); - if deadliner_rx.is_none() { + let Some(mut deadliner_rx) = self.deadliner.c() else { warn!("Deadliner channel is not available"); return; - } - - let mut deadliner_rx = deadliner_rx.unwrap(); + }; loop { tokio::select! 
{ @@ -345,14 +351,10 @@ impl MemDB { .push(k.clone()); if k.duty.duty_type == DutyType::Exit { - PARASIG_DB_METRICS.exit_total[&k.pub_key.to_string()].inc(); + PARSIG_DB_METRICS.exit_total[&k.pub_key.to_string()].inc(); } - let result = inner - .entries - .get(&k) - .map(|entries| entries.clone()) - .unwrap_or_default(); + let result = inner.entries.get(&k).cloned().unwrap_or_default(); Ok(Some(result)) } @@ -381,11 +383,11 @@ async fn get_threshold_matching( let mut sigs_by_msg_root: HashMap<[u8; 32], Vec> = HashMap::new(); for sig in sigs { - let root = sig.signed_data.message_root(); - sigs_by_msg_root - .entry(root) - .or_insert_with(Vec::new) - .push(sig.clone()); + let root = sig + .signed_data + .message_root() + .map_err(MemDBError::SignedDataError)?; + sigs_by_msg_root.entry(root).or_default().push(sig.clone()); } // Return the first set that has exactly threshold number of signatures diff --git a/crates/core/src/parsigdb/memory_internal_test.rs b/crates/core/src/parsigdb/memory_internal_test.rs new file mode 100644 index 00000000..58f4c474 --- /dev/null +++ b/crates/core/src/parsigdb/memory_internal_test.rs @@ -0,0 +1,215 @@ +use std::{ + sync::{Arc, Mutex as StdMutex}, + time::Duration, +}; + +use futures::future::{BoxFuture, FutureExt}; +use pluto_eth2api::{spec::altair, v1}; +use pluto_testutil as testutil; +use test_case::test_case; +use tokio::sync::{Mutex, mpsc}; +use tokio_util::sync::CancellationToken; + +use super::{MemDB, get_threshold_matching, threshold_subscriber}; +use crate::{ + deadline::Deadliner, + signeddata::{BeaconCommitteeSelection, SignedSyncMessage, VersionedAttestation}, + testutils::random_core_pub_key, + types::{Duty, DutyType, ParSignedData, ParSignedDataSet, SlotNumber}, +}; + +fn threshold(nodes: usize) -> u64 { + (2_u64 + .checked_mul(u64::try_from(nodes).expect("nodes overflow")) + .expect("nodes overflow")) + .div_ceil(3) +} + +#[test_case(Vec::new(), Vec::new() ; "empty")] +#[test_case(vec![0, 0, 0], vec![0, 1, 2] ; "all 
identical exact threshold")] +#[test_case(vec![0, 0, 0, 0], Vec::new() ; "all identical above threshold")] +#[test_case(vec![0, 0, 1, 0], vec![0, 1, 3] ; "one odd")] +#[test_case(vec![0, 0, 1, 1], Vec::new() ; "two odd")] +#[tokio::test] +async fn test_get_threshold_matching(input: Vec, output: Vec) { + const N: usize = 4; + + let slot = testutil::random_slot(); + let validator_index = testutil::random_v_idx(); + let roots = [testutil::random_root_bytes(), testutil::random_root_bytes()]; + let threshold = threshold(N); + + type Providers<'a> = [(&'a str, Box ParSignedData + 'a>); 2]; + + let providers: Providers<'_> = [ + ( + "sync_committee_message", + Box::new(|i| { + let message = altair::SyncCommitteeMessage { + slot, + beacon_block_root: roots[input[i]], + validator_index, + signature: testutil::random_eth2_signature_bytes(), + }; + + SignedSyncMessage::new_partial(message, u64::try_from(i.wrapping_add(1)).unwrap()) + }), + ), + ( + "selection", + Box::new(|i| { + let selection = v1::BeaconCommitteeSelection { + validator_index, + slot: u64::try_from(input[i]).unwrap(), + selection_proof: testutil::random_eth2_signature_bytes(), + }; + + BeaconCommitteeSelection::new_partial( + selection, + u64::try_from(i.wrapping_add(1)).unwrap(), + ) + }), + ), + ]; + + for (name, provider) in providers { + let mut data = Vec::new(); + for i in 0..input.len() { + data.push(provider(i)); + } + + let out = get_threshold_matching(&DutyType::SyncMessage, &data, threshold) + .await + .expect("threshold matching should succeed"); + let expect: Vec<_> = output.iter().map(|idx| data[*idx].clone()).collect(); + let expected_out = if expect.is_empty() { + None + } else { + Some(expect.clone()) + }; + + assert_eq!(expected_out, out, "{name}/output mismatch"); + assert_eq!( + out.as_ref() + .map(|matches| u64::try_from(matches.len()).unwrap() == threshold) + .unwrap_or(false), + expect.len() as u64 == threshold, + "{name}/ok mismatch" + ); + } +} + +#[tokio::test] +async fn 
test_memdb_threshold() { + const THRESHOLD: u64 = 7; + const N: usize = 10; + + let deadliner = Arc::new(TestDeadliner::new()); + let cancel = CancellationToken::new(); + let db = Arc::new(MemDB::new(cancel.clone(), THRESHOLD, deadliner.clone())); + + let trim_handle = tokio::spawn({ + let db = db.clone(); + async move { + db.trim().await; + } + }); + + let times_called = Arc::new(Mutex::new(0usize)); + db.subscribe_threshold(threshold_subscriber({ + let times_called = times_called.clone(); + move |_duty, _data| { + let times_called = times_called.clone(); + async move { + *times_called.lock().await += 1; + Ok(()) + } + } + })) + .await + .expect("subscription should succeed"); + + let pubkey = random_core_pub_key(); + let attestation = testutil::random_deneb_versioned_attestation(); + let duty = Duty::new_attester_duty(SlotNumber::new(123)); + + let enqueue_n = || async { + for i in 0..N { + let partial = VersionedAttestation::new_partial( + attestation.clone(), + u64::try_from(i + 1).unwrap(), + ) + .expect("versioned attestation should be valid"); + + let mut set = ParSignedDataSet::new(); + set.insert(pubkey, partial); + + db.store_external(&duty, &set) + .await + .expect("store_external should succeed"); + } + }; + + enqueue_n().await; + assert_eq!(1, *times_called.lock().await); + + deadliner.expire().await; + tokio::time::sleep(Duration::from_millis(20)).await; + + enqueue_n().await; + assert_eq!(2, *times_called.lock().await); + + cancel.cancel(); + trim_handle + .await + .expect("trim task should shut down cleanly"); +} + +struct TestDeadliner { + added: StdMutex>, + tx: mpsc::Sender, + rx: StdMutex>>, +} + +impl TestDeadliner { + fn new() -> Self { + let (tx, rx) = mpsc::channel(32); + Self { + added: StdMutex::new(Vec::new()), + tx, + rx: StdMutex::new(Some(rx)), + } + } + + async fn expire(&self) -> bool { + let duties = { + let mut added = self.added.lock().expect("test deadliner lock poisoned"); + std::mem::take(&mut *added) + }; + + for duty in 
duties { + if self.tx.send(duty).await.is_err() { + return false; + } + } + + true + } +} + +impl Deadliner for TestDeadliner { + fn add(&self, duty: Duty) -> BoxFuture<'_, bool> { + async move { + self.added + .lock() + .expect("test deadliner lock poisoned") + .push(duty); + true + } + .boxed() + } + + fn c(&self) -> Option> { + self.rx.lock().expect("test deadliner lock poisoned").take() + } +} diff --git a/crates/core/src/parsigdb/metrics.rs b/crates/core/src/parsigdb/metrics.rs new file mode 100644 index 00000000..24a05fd8 --- /dev/null +++ b/crates/core/src/parsigdb/metrics.rs @@ -0,0 +1,12 @@ +use vise::*; + +/// Metrics for the ParSigDB. +#[derive(Debug, Clone, Metrics)] +pub struct ParsigDBMetrics { + /// Total number of partially signed voluntary exits per public key + #[metrics(labels = ["pubkey"])] + pub exit_total: LabeledFamily, +} + +/// Global metrics for the ParSigDB. +pub static PARSIG_DB_METRICS: Global = Global::new(); diff --git a/crates/core/src/parsigdb/mod.rs b/crates/core/src/parsigdb/mod.rs new file mode 100644 index 00000000..fd01b279 --- /dev/null +++ b/crates/core/src/parsigdb/mod.rs @@ -0,0 +1,5 @@ +/// Memory implementation of the ParSigDB. +pub mod memory; + +/// Metrics for the ParSigDB. +pub mod metrics; diff --git a/crates/core/src/parsigex_codec.rs b/crates/core/src/parsigex_codec.rs new file mode 100644 index 00000000..1859be72 --- /dev/null +++ b/crates/core/src/parsigex_codec.rs @@ -0,0 +1,116 @@ +//! Partial signature exchange codec helpers used by core types. + +use std::any::Any; + +use crate::{ + signeddata::{ + Attestation, BeaconCommitteeSelection, SignedAggregateAndProof, SignedRandao, + SignedSyncContributionAndProof, SignedSyncMessage, SignedVoluntaryExit, + SyncCommitteeSelection, VersionedAttestation, VersionedSignedAggregateAndProof, + VersionedSignedProposal, VersionedSignedValidatorRegistration, + }, + types::{DutyType, Signature, SignedData}, +}; + +/// Error type for partial signature exchange codec operations. 
+#[derive(Debug, thiserror::Error)] +pub enum ParSigExCodecError { + /// Missing duty or data set fields. + #[error("invalid parsigex msg fields")] + InvalidMessageFields, + + /// Invalid partial signed data set proto. + #[error("invalid partial signed data set proto fields")] + InvalidParSignedDataSetFields, + + /// Invalid partial signed proto. + #[error("invalid partial signed proto")] + InvalidParSignedProto, + + /// Invalid duty type. + #[error("invalid duty")] + InvalidDuty, + + /// Unsupported duty type. + #[error("unsupported duty type")] + UnsupportedDutyType, + + /// Deprecated builder proposer duty. + #[error("deprecated duty builder proposer")] + DeprecatedBuilderProposer, + + /// Failed to parse a public key. + #[error("invalid public key: {0}")] + InvalidPubKey(String), + + /// Invalid share index. + #[error("invalid share index")] + InvalidShareIndex, + + /// Serialization failed. + #[error("marshal signed data: {0}")] + Serialize(#[from] serde_json::Error), +} + +pub(crate) fn serialize_signed_data(data: &dyn SignedData) -> Result, ParSigExCodecError> { + let any = data as &dyn Any; + + macro_rules! serialize_as { + ($ty:ty) => { + if let Some(value) = any.downcast_ref::<$ty>() { + return Ok(serde_json::to_vec(value)?); + } + }; + } + + serialize_as!(Attestation); + serialize_as!(VersionedAttestation); + serialize_as!(VersionedSignedProposal); + serialize_as!(VersionedSignedValidatorRegistration); + serialize_as!(SignedVoluntaryExit); + serialize_as!(SignedRandao); + serialize_as!(Signature); + serialize_as!(BeaconCommitteeSelection); + serialize_as!(SignedAggregateAndProof); + serialize_as!(VersionedSignedAggregateAndProof); + serialize_as!(SignedSyncMessage); + serialize_as!(SyncCommitteeSelection); + serialize_as!(SignedSyncContributionAndProof); + + Err(ParSigExCodecError::UnsupportedDutyType) +} + +pub(crate) fn deserialize_signed_data( + duty_type: &DutyType, + bytes: &[u8], +) -> Result, ParSigExCodecError> { + macro_rules! 
deserialize_json { + ($ty:ty) => { + serde_json::from_slice::<$ty>(bytes) + .map(|value| Box::new(value) as Box) + .map_err(ParSigExCodecError::from) + }; + } + + match duty_type { + DutyType::Attester => deserialize_json!(VersionedAttestation) + .or_else(|_| deserialize_json!(Attestation)) + .map_err(|_| ParSigExCodecError::UnsupportedDutyType), + DutyType::Proposer => deserialize_json!(VersionedSignedProposal), + DutyType::BuilderProposer => Err(ParSigExCodecError::DeprecatedBuilderProposer), + DutyType::BuilderRegistration => deserialize_json!(VersionedSignedValidatorRegistration), + DutyType::Exit => deserialize_json!(SignedVoluntaryExit), + DutyType::Randao => deserialize_json!(SignedRandao), + DutyType::Signature => deserialize_json!(Signature), + DutyType::PrepareAggregator => deserialize_json!(BeaconCommitteeSelection), + DutyType::Aggregator => deserialize_json!(VersionedSignedAggregateAndProof) + .or_else(|_| deserialize_json!(SignedAggregateAndProof)) + .map_err(|_| ParSigExCodecError::UnsupportedDutyType), + DutyType::SyncMessage => deserialize_json!(SignedSyncMessage), + DutyType::PrepareSyncContribution => deserialize_json!(SyncCommitteeSelection), + DutyType::SyncContribution => deserialize_json!(SignedSyncContributionAndProof), + DutyType::Unknown | DutyType::InfoSync | DutyType::DutySentinel(_) => { + Err(ParSigExCodecError::UnsupportedDutyType) + } + } +} diff --git a/crates/core/src/signeddata.rs b/crates/core/src/signeddata.rs index 09865daf..4244e8ab 100644 --- a/crates/core/src/signeddata.rs +++ b/crates/core/src/signeddata.rs @@ -48,6 +48,9 @@ pub enum SignedDataError { /// Invalid attestation wrapper JSON. #[error("unmarshal attestation")] AttestationJson, + /// Custom error. + #[error("{0}")] + Custom(Box), } fn hash_root(value: &T) -> [u8; 32] { @@ -127,23 +130,21 @@ impl Signature { } /// Creates a partially signed signature wrapper. 
- pub fn new_partial(sig: Self, share_idx: u64) -> ParSignedData { + pub fn new_partial(sig: Self, share_idx: u64) -> ParSignedData { ParSignedData::new(sig, share_idx) } } impl SignedData for Signature { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(self.clone()) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { Ok(signature) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { Err(SignedDataError::UnsupportedSignatureMessageRoot) } } @@ -179,7 +180,7 @@ impl VersionedSignedProposal { pub fn new_partial( proposal: versioned::VersionedSignedProposal, share_idx: u64, - ) -> Result, SignedDataError> { + ) -> Result { Ok(ParSignedData::new(Self::new(proposal)?, share_idx)) } @@ -222,7 +223,7 @@ impl VersionedSignedProposal { pub fn new_partial_from_blinded_proposal( proposal: versioned::VersionedSignedBlindedProposal, share_idx: u64, - ) -> Result, SignedDataError> { + ) -> Result { Ok(ParSignedData::new( Self::from_blinded_proposal(proposal)?, share_idx, @@ -231,9 +232,7 @@ impl VersionedSignedProposal { } impl SignedData for VersionedSignedProposal { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { let proposal = &self.0; if proposal.version == versioned::DataVersion::Unknown { return Err(SignedDataError::UnknownVersion); @@ -241,7 +240,7 @@ impl SignedData for VersionedSignedProposal { Ok(sig_from_eth2(proposal.block.signature())) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); let proposal = &mut out.0; if proposal.version == versioned::DataVersion::Unknown { @@ -253,7 +252,7 @@ impl SignedData for VersionedSignedProposal { Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn 
message_root(&self) -> Result<[u8; 32], SignedDataError> { let proposal = &self.0; if proposal.version == versioned::DataVersion::Unknown { return Err(SignedDataError::UnknownVersion); @@ -378,25 +377,23 @@ impl Attestation { } /// Creates a partial signed attestation wrapper. - pub fn new_partial(attestation: phase0::Attestation, share_idx: u64) -> ParSignedData { + pub fn new_partial(attestation: phase0::Attestation, share_idx: u64) -> ParSignedData { ParSignedData::new(Self::new(attestation), share_idx) } } impl SignedData for Attestation { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(sig_from_eth2(self.0.signature)) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); out.0.signature = sig_to_eth2(&signature); Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { Ok(hash_root(&self.0.data)) } } @@ -427,7 +424,7 @@ impl VersionedAttestation { pub fn new_partial( attestation: versioned::VersionedAttestation, share_idx: u64, - ) -> Result, SignedDataError> { + ) -> Result { Ok(ParSignedData::new(Self::new(attestation)?, share_idx)) } @@ -447,9 +444,7 @@ impl VersionedAttestation { } impl SignedData for VersionedAttestation { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { let version = self.0.version; if version == versioned::DataVersion::Unknown { return Err(SignedDataError::UnknownVersion); @@ -461,7 +456,7 @@ impl SignedData for VersionedAttestation { .ok_or(SignedDataError::MissingAttestation(version)) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); let version = out.0.version; if version == versioned::DataVersion::Unknown { @@ -476,7 +471,7 @@ impl SignedData for 
VersionedAttestation { Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { let version = self.0.version; if version == versioned::DataVersion::Unknown { return Err(SignedDataError::UnknownVersion); @@ -584,19 +579,17 @@ pub struct SignedVoluntaryExit( ); impl SignedData for SignedVoluntaryExit { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(sig_from_eth2(self.0.signature)) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); out.0.signature = sig_to_eth2(&signature); Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { Ok(hash_root(&self.0.message)) } } @@ -608,7 +601,7 @@ impl SignedVoluntaryExit { } /// Creates a partially signed voluntary exit wrapper. - pub fn new_partial(exit: phase0::SignedVoluntaryExit, share_idx: u64) -> ParSignedData { + pub fn new_partial(exit: phase0::SignedVoluntaryExit, share_idx: u64) -> ParSignedData { ParSignedData::new(Self::new(exit), share_idx) } } @@ -643,15 +636,13 @@ impl VersionedSignedValidatorRegistration { pub fn new_partial( registration: versioned::VersionedSignedValidatorRegistration, share_idx: u64, - ) -> Result, SignedDataError> { + ) -> Result { Ok(ParSignedData::new(Self::new(registration)?, share_idx)) } } impl SignedData for VersionedSignedValidatorRegistration { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { match self.0.version { versioned::BuilderVersion::V1 => self .0 @@ -663,7 +654,7 @@ impl SignedData for VersionedSignedValidatorRegistration { } } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); match out.0.version { 
versioned::BuilderVersion::V1 => { @@ -680,7 +671,7 @@ impl SignedData for VersionedSignedValidatorRegistration { Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { match self.0.version { versioned::BuilderVersion::V1 => { let Some(v1) = self.0.v1.as_ref() else { @@ -748,19 +739,17 @@ pub struct SignedRandao( ); impl SignedData for SignedRandao { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(sig_from_eth2(self.0.signature)) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); out.0.signature = sig_to_eth2(&signature); Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { Ok(hash_root(&self.0)) } } @@ -779,7 +768,7 @@ impl SignedRandao { epoch: phase0::Epoch, randao: phase0::BLSSignature, share_idx: u64, - ) -> ParSignedData { + ) -> ParSignedData { ParSignedData::new(Self::new(epoch, randao), share_idx) } } @@ -793,19 +782,17 @@ pub struct BeaconCommitteeSelection( ); impl SignedData for BeaconCommitteeSelection { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(sig_from_eth2(self.0.selection_proof)) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); out.0.selection_proof = sig_to_eth2(&signature); Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { Ok(hash_root(&self.0.slot)) } } @@ -817,10 +804,7 @@ impl BeaconCommitteeSelection { } /// Creates a partial beacon committee selection wrapper. 
- pub fn new_partial( - selection: v1::BeaconCommitteeSelection, - share_idx: u64, - ) -> ParSignedData { + pub fn new_partial(selection: v1::BeaconCommitteeSelection, share_idx: u64) -> ParSignedData { ParSignedData::new(Self::new(selection), share_idx) } } @@ -834,19 +818,17 @@ pub struct SyncCommitteeSelection( ); impl SignedData for SyncCommitteeSelection { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(sig_from_eth2(self.0.selection_proof)) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); out.0.selection_proof = sig_to_eth2(&signature); Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { let data = altair::SyncAggregatorSelectionData { slot: self.0.slot, subcommittee_index: self.0.subcommittee_index, @@ -863,10 +845,7 @@ impl SyncCommitteeSelection { } /// Creates a partial sync committee selection wrapper. 
- pub fn new_partial( - selection: v1::SyncCommitteeSelection, - share_idx: u64, - ) -> ParSignedData { + pub fn new_partial(selection: v1::SyncCommitteeSelection, share_idx: u64) -> ParSignedData { ParSignedData::new(Self::new(selection), share_idx) } } @@ -880,19 +859,17 @@ pub struct SignedAggregateAndProof( ); impl SignedData for SignedAggregateAndProof { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(sig_from_eth2(self.0.signature)) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); out.0.signature = sig_to_eth2(&signature); Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { Ok(hash_root(&self.0.message)) } } @@ -904,10 +881,7 @@ impl SignedAggregateAndProof { } /// Creates a partial signed aggregate-and-proof wrapper. - pub fn new_partial( - data: phase0::SignedAggregateAndProof, - share_idx: u64, - ) -> ParSignedData { + pub fn new_partial(data: phase0::SignedAggregateAndProof, share_idx: u64) -> ParSignedData { ParSignedData::new(Self::new(data), share_idx) } } @@ -947,15 +921,13 @@ impl VersionedSignedAggregateAndProof { pub fn new_partial( data: versioned::VersionedSignedAggregateAndProof, share_idx: u64, - ) -> ParSignedData { + ) -> ParSignedData { ParSignedData::new(Self::new(data), share_idx) } } impl SignedData for VersionedSignedAggregateAndProof { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { let version = self.0.version; if version == versioned::DataVersion::Unknown { return Err(SignedDataError::UnknownVersion); @@ -964,7 +936,7 @@ impl SignedData for VersionedSignedAggregateAndProof { Ok(sig_from_eth2(self.0.aggregate_and_proof.signature())) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> 
Result { let mut out = self.clone(); let version = out.0.version; if version == versioned::DataVersion::Unknown { @@ -977,7 +949,7 @@ impl SignedData for VersionedSignedAggregateAndProof { Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { let version = self.0.version; if version == versioned::DataVersion::Unknown { return Err(SignedDataError::UnknownVersion); @@ -1068,19 +1040,17 @@ pub struct SignedSyncMessage( ); impl SignedData for SignedSyncMessage { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(sig_from_eth2(self.0.signature)) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); out.0.signature = sig_to_eth2(&signature); Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { Ok(self.0.beacon_block_root) } } @@ -1092,7 +1062,7 @@ impl SignedSyncMessage { } /// Creates a partial signed sync committee message wrapper. 
- pub fn new_partial(data: altair::SyncCommitteeMessage, share_idx: u64) -> ParSignedData { + pub fn new_partial(data: altair::SyncCommitteeMessage, share_idx: u64) -> ParSignedData { ParSignedData::new(Self::new(data), share_idx) } } @@ -1106,19 +1076,17 @@ pub struct SyncContributionAndProof( ); impl SignedData for SyncContributionAndProof { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(sig_from_eth2(self.0.selection_proof)) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); out.0.selection_proof = sig_to_eth2(&signature); Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { let data = altair::SyncAggregatorSelectionData { slot: self.0.contribution.slot, subcommittee_index: self.0.contribution.subcommittee_index, @@ -1135,7 +1103,7 @@ impl SyncContributionAndProof { } /// Creates a partial sync contribution-and-proof wrapper. 
- pub fn new_partial(proof: altair::ContributionAndProof, share_idx: u64) -> ParSignedData { + pub fn new_partial(proof: altair::ContributionAndProof, share_idx: u64) -> ParSignedData { ParSignedData::new(Self::new(proof), share_idx) } } @@ -1149,19 +1117,17 @@ pub struct SignedSyncContributionAndProof( ); impl SignedData for SignedSyncContributionAndProof { - type Error = SignedDataError; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(sig_from_eth2(self.0.signature)) } - fn set_signature(&self, signature: Signature) -> Result { + fn set_signature(&self, signature: Signature) -> Result { let mut out = self.clone(); out.0.signature = sig_to_eth2(&signature); Ok(out) } - fn message_root(&self) -> Result<[u8; 32], Self::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { Ok(hash_root(&self.0.message)) } } @@ -1173,10 +1139,7 @@ impl SignedSyncContributionAndProof { } /// Creates a partial signed sync contribution-and-proof wrapper. - pub fn new_partial( - proof: altair::SignedContributionAndProof, - share_idx: u64, - ) -> ParSignedData { + pub fn new_partial(proof: altair::SignedContributionAndProof, share_idx: u64) -> ParSignedData { ParSignedData::new(Self::new(proof), share_idx) } } @@ -2052,7 +2015,7 @@ mod tests { fn assert_set_signature(data: T) where - T: SignedData + std::fmt::Debug + PartialEq, + T: SignedData + std::fmt::Debug + PartialEq, { let clone = data.set_signature(sample_signature(0xAB)).unwrap(); let clone_sig = clone.signature().unwrap(); diff --git a/crates/core/src/testutils.rs b/crates/core/src/testutils.rs new file mode 100644 index 00000000..ae6b4f47 --- /dev/null +++ b/crates/core/src/testutils.rs @@ -0,0 +1,143 @@ +//! Test utilities for the Charon core. + +use rand::{Rng, SeedableRng}; + +use crate::types::PubKey; + +/// The size of a BLS public key in bytes. +const PK_LEN: usize = 48; + +/// Creates a new seeded random number generator. 
+/// +/// Returns a new random number generator seeded with a random value. +/// This matches the Go implementation: +/// `rand.New(rand.NewSource(rand.Int63()))`. +pub fn new_seed_rand() -> impl Rng { + let seed = rand::random::(); + rand::rngs::StdRng::seed_from_u64(seed) +} + +/// Returns a random core workflow pubkey. +/// +/// This is a convenience wrapper around `random_core_pub_key_seed` that creates +/// a new random seed for each call. +pub fn random_core_pub_key() -> PubKey { + random_core_pub_key_seed(new_seed_rand()) +} + +/// Returns a random core workflow pubkey using a provided random source. +/// +/// # Arguments +/// +/// * `rng` - A random number generator to use for generating the pubkey. +/// +/// # Panics +/// +/// Panics if the generated bytes cannot be converted to a valid PubKey. +/// This should never happen in practice as we generate exactly 48 bytes. +pub fn random_core_pub_key_seed(mut rng: R) -> PubKey { + let pubkey = deterministic_pub_key_seed(&mut rng); + PubKey::try_from(&pubkey[..]).expect("valid pubkey length") +} + +/// Generates a deterministic pubkey from a seeded RNG. +/// +/// This function creates a new RNG seeded from the input RNG, then fills +/// a 48-byte array with random data. This matches the Go implementation: +/// +/// ```go +/// random := rand.New(rand.NewSource(r.Int63())) +/// var key tbls.PublicKey +/// _, err := random.Read(key[:]) +/// ``` +/// +/// # Arguments +/// +/// * `rng` - A mutable reference to a random number generator. +/// +/// # Returns +/// +/// A 48-byte array containing random data suitable for use as a public key. 
+fn deterministic_pub_key_seed(rng: &mut R) -> [u8; PK_LEN] { + // Create a new RNG seeded from the input RNG (matching Go's + // rand.New(rand.NewSource(r.Int63()))) + let seed: u64 = rng.r#gen(); + let mut seeded_rng = rand::rngs::StdRng::seed_from_u64(seed); + + let mut key = [0u8; PK_LEN]; + // Fill the key with random bytes + for byte in &mut key { + *byte = seeded_rng.r#gen(); + } + + key +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new_seed_rand_produces_different_values() { + let mut rng1 = new_seed_rand(); + let mut rng2 = new_seed_rand(); + + let val1: u64 = rng1.r#gen(); + let val2: u64 = rng2.r#gen(); + + // These should be different with very high probability + assert_ne!(val1, val2); + } + + #[test] + fn test_random_core_pub_key_generates_valid_keys() { + let pk1 = random_core_pub_key(); + let pk2 = random_core_pub_key(); + + // Keys should be different + assert_ne!(pk1, pk2); + + // Keys should have the correct length when serialized + assert_eq!(pk1.to_string().len(), 98); // 0x + 96 hex chars + assert_eq!(pk2.to_string().len(), 98); + } + + #[test] + fn test_random_core_pub_key_seed_is_deterministic() { + let seed = 12345u64; + let mut rng1 = rand::rngs::StdRng::seed_from_u64(seed); + let mut rng2 = rand::rngs::StdRng::seed_from_u64(seed); + + let pk1 = random_core_pub_key_seed(&mut rng1); + let pk2 = random_core_pub_key_seed(&mut rng2); + + // Same seed should produce same key + assert_eq!(pk1, pk2); + } + + #[test] + fn test_deterministic_pub_key_seed() { + let seed = 42u64; + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + + let key = deterministic_pub_key_seed(&mut rng); + + // Check that we got 48 bytes + assert_eq!(key.len(), PK_LEN); + + // Check that the key is not all zeros (very unlikely with a proper RNG) + assert!(key.iter().any(|&b| b != 0)); + } + + #[test] + fn test_random_core_pub_key_seed_different_rngs() { + let mut rng1 = rand::rngs::StdRng::seed_from_u64(1); + let mut rng2 = 
rand::rngs::StdRng::seed_from_u64(2); + + let pk1 = random_core_pub_key_seed(&mut rng1); + let pk2 = random_core_pub_key_seed(&mut rng2); + + // Different seeds should produce different keys + assert_ne!(pk1, pk2); + } +} diff --git a/crates/core/src/types.rs b/crates/core/src/types.rs index 2d0f3b37..05a73f03 100644 --- a/crates/core/src/types.rs +++ b/crates/core/src/types.rs @@ -1,11 +1,20 @@ //! Types for the Charon core. -use std::{collections::HashMap, fmt::Display, iter}; +use std::{any::Any, collections::HashMap, fmt::Display, iter}; use chrono::{DateTime, Duration, Utc}; +use dyn_clone::DynClone; +use dyn_eq::DynEq; use serde::{Deserialize, Serialize}; use std::fmt::Debug as StdDebug; +use crate::{ + ParSigExCodecError, + corepb::v1::core as pbcore, + parsigex_codec::{deserialize_signed_data, serialize_signed_data}, + signeddata::SignedDataError, +}; + /// The type of duty. #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] @@ -62,6 +71,52 @@ impl DutyType { } } +impl From<&DutyType> for i32 { + fn from(duty_type: &DutyType) -> Self { + match duty_type { + DutyType::Unknown => 0, + DutyType::Proposer => 1, + DutyType::Attester => 2, + DutyType::Signature => 3, + DutyType::Exit => 4, + DutyType::BuilderProposer => 5, + DutyType::BuilderRegistration => 6, + DutyType::Randao => 7, + DutyType::PrepareAggregator => 8, + DutyType::Aggregator => 9, + DutyType::SyncMessage => 10, + DutyType::PrepareSyncContribution => 11, + DutyType::SyncContribution => 12, + DutyType::InfoSync => 13, + DutyType::DutySentinel(_) => 14, + } + } +} + +impl TryFrom for DutyType { + type Error = ParSigExCodecError; + + fn try_from(value: i32) -> Result { + match value { + 0 => Ok(DutyType::Unknown), + 1 => Ok(DutyType::Proposer), + 2 => Ok(DutyType::Attester), + 3 => Ok(DutyType::Signature), + 4 => Ok(DutyType::Exit), + 5 => Ok(DutyType::BuilderProposer), + 6 => Ok(DutyType::BuilderRegistration), + 7 => Ok(DutyType::Randao), + 8 => 
Ok(DutyType::PrepareAggregator), + 9 => Ok(DutyType::Aggregator), + 10 => Ok(DutyType::SyncMessage), + 11 => Ok(DutyType::PrepareSyncContribution), + 12 => Ok(DutyType::SyncContribution), + 13 => Ok(DutyType::InfoSync), + _ => Err(ParSigExCodecError::InvalidDuty), + } + } +} + /// SlotNumber struct #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct SlotNumber(u64); @@ -188,6 +243,28 @@ impl Duty { } } +impl From<&Duty> for pbcore::Duty { + fn from(duty: &Duty) -> Self { + Self { + slot: duty.slot.inner(), + r#type: i32::from(&duty.duty_type), + } + } +} + +impl TryFrom<&pbcore::Duty> for Duty { + type Error = ParSigExCodecError; + + fn try_from(duty: &pbcore::Duty) -> Result { + let duty_type = DutyType::try_from(duty.r#type)?; + if !duty_type.is_valid() { + return Err(ParSigExCodecError::InvalidDuty); + } + + Ok(Self::new(duty.slot.into(), duty_type)) + } +} + /// The type of proposal. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] @@ -448,42 +525,64 @@ impl AsRef<[u8; SIG_LEN]> for Signature { } /// Signed data type -pub trait SignedData: Clone + Serialize + StdDebug { - /// The error type - type Error: std::error::Error; - +pub trait SignedData: Any + DynClone + DynEq + StdDebug + Send + Sync { /// signature returns the signed duty data's signature. - fn signature(&self) -> Result; + fn signature(&self) -> Result; /// Returns a copy of signed duty data with the signature replaced. - fn set_signature(&self, signature: Signature) -> Result + fn set_signature(&self, signature: Signature) -> Result where Self: Sized; /// message_root returns the message root for the unsigned data. 
- fn message_root(&self) -> Result<[u8; 32], Self::Error>; + fn message_root(&self) -> Result<[u8; 32], SignedDataError>; } +dyn_eq::eq_trait_object!(SignedData); +dyn_clone::clone_trait_object!(SignedData); + // todo: add Eth2SignedData type // https://github.com/ObolNetwork/charon/blob/b3008103c5429b031b63518195f4c49db4e9a68d/core/types.go#L396 /// ParSignedData is a partially signed duty data only signed by a single /// threshold BLS share. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ParSignedData { +#[derive(Debug)] +pub struct ParSignedData { /// Partially signed duty data. - pub signed_data: T, + pub signed_data: Box, /// Threshold BLS share index. pub share_idx: u64, } -impl ParSignedData -where - T: SignedData, -{ +impl Clone for ParSignedData { + fn clone(&self) -> Self { + Self { + signed_data: self.signed_data.clone(), + share_idx: self.share_idx, + } + } +} + +impl PartialEq for ParSignedData { + fn eq(&self, other: &Self) -> bool { + self.share_idx == other.share_idx && self.signed_data == other.signed_data + } +} + +impl Eq for ParSignedData {} + +impl ParSignedData { /// Create a new partially signed data. - pub fn new(partially_signed_data: T, share_idx: u64) -> Self { + pub fn new(partially_signed_data: T, share_idx: u64) -> Self { + Self { + signed_data: Box::new(partially_signed_data), + share_idx, + } + } + + /// Create a new partially signed data from a boxed signed data. + pub fn new_boxed(partially_signed_data: Box, share_idx: u64) -> Self { Self { signed_data: partially_signed_data, share_idx, @@ -491,55 +590,109 @@ where } } -/// ParSignedDataSet is a set of partially signed duty data only signed by a -/// single threshold BLS share. 
-#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ParSignedDataSet(HashMap>); +impl TryFrom<&ParSignedData> for pbcore::ParSignedData { + type Error = ParSigExCodecError; + + fn try_from(data: &ParSignedData) -> Result { + let encoded = serialize_signed_data(data.signed_data.as_ref())?; + let share_idx = + i32::try_from(data.share_idx).map_err(|_| ParSigExCodecError::InvalidShareIndex)?; + let signature = data.signed_data.signature().map_err(|err| { + ParSigExCodecError::Serialize(serde_json::Error::io(std::io::Error::other( + err.to_string(), + ))) + })?; + + Ok(Self { + data: encoded.into(), + signature: signature.as_ref().to_vec().into(), + share_idx, + }) + } +} -impl Default for ParSignedDataSet -where - T: SignedData, -{ - fn default() -> Self { - Self(HashMap::default()) +impl TryFrom<(&DutyType, &pbcore::ParSignedData)> for ParSignedData { + type Error = ParSigExCodecError; + + fn try_from(value: (&DutyType, &pbcore::ParSignedData)) -> Result { + let (duty_type, data) = value; + let share_idx = + u64::try_from(data.share_idx).map_err(|_| ParSigExCodecError::InvalidShareIndex)?; + let signed_data = deserialize_signed_data(duty_type, &data.data)?; + Ok(Self::new_boxed(signed_data, share_idx)) } } -impl ParSignedDataSet -where - T: SignedData, -{ +/// ParSignedDataSet is a set of partially signed duty data only signed by a +/// single threshold BLS share. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ParSignedDataSet(HashMap); + +impl ParSignedDataSet { /// Create a new partially signed data set. pub fn new() -> Self { Self::default() } /// Get a partially signed data by public key. - pub fn get(&self, pub_key: &PubKey) -> Option<&ParSignedData> { + pub fn get(&self, pub_key: &PubKey) -> Option<&ParSignedData> { self.inner().get(pub_key) } /// Insert a partially signed data. 
- pub fn insert(&mut self, pub_key: PubKey, partially_signed_data: ParSignedData) { + pub fn insert(&mut self, pub_key: PubKey, partially_signed_data: ParSignedData) { self.inner_mut().insert(pub_key, partially_signed_data); } /// Remove a partially signed data by public key. - pub fn remove(&mut self, pub_key: &PubKey) -> Option> { + pub fn remove(&mut self, pub_key: &PubKey) -> Option { self.inner_mut().remove(pub_key) } /// Inner partially signed data set. - pub fn inner(&self) -> &HashMap> { + pub fn inner(&self) -> &HashMap { &self.0 } /// Inner partially signed data set. - pub fn inner_mut(&mut self) -> &mut HashMap> { + pub fn inner_mut(&mut self) -> &mut HashMap { &mut self.0 } } +impl TryFrom<&ParSignedDataSet> for pbcore::ParSignedDataSet { + type Error = ParSigExCodecError; + + fn try_from(set: &ParSignedDataSet) -> Result { + let mut out = std::collections::BTreeMap::new(); + for (pub_key, value) in set.inner() { + out.insert(pub_key.to_string(), pbcore::ParSignedData::try_from(value)?); + } + + Ok(Self { set: out }) + } +} + +impl TryFrom<(&DutyType, &pbcore::ParSignedDataSet)> for ParSignedDataSet { + type Error = ParSigExCodecError; + + fn try_from(value: (&DutyType, &pbcore::ParSignedDataSet)) -> Result { + let (duty_type, set) = value; + if set.set.is_empty() { + return Err(ParSigExCodecError::InvalidParSignedDataSetFields); + } + + let mut out = Self::new(); + for (pub_key, value) in &set.set { + let pub_key = PubKey::try_from(pub_key.as_str()) + .map_err(|_| ParSigExCodecError::InvalidPubKey(pub_key.clone()))?; + out.insert(pub_key, ParSignedData::try_from((duty_type, value))?); + } + + Ok(out) + } +} + /// SignedDataSet is a set of signed duty data. 
#[derive(Debug, Clone, PartialEq, Eq)] pub struct SignedDataSet(HashMap); @@ -856,17 +1009,15 @@ mod tests { struct MockSignedData; impl SignedData for MockSignedData { - type Error = std::io::Error; - - fn signature(&self) -> Result { + fn signature(&self) -> Result { Ok(Signature::new([42u8; SIG_LEN])) } - fn set_signature(&self, _signature: Signature) -> Result { + fn set_signature(&self, _signature: Signature) -> Result { Ok(self.clone()) } - fn message_root(&self) -> Result<[u8; 32], std::io::Error> { + fn message_root(&self) -> Result<[u8; 32], SignedDataError> { Ok([42u8; 32]) } } @@ -874,13 +1025,15 @@ mod tests { #[test] fn test_partially_signed_data_set() { let mut partially_signed_data_set = ParSignedDataSet::new(); - partially_signed_data_set.insert( - PubKey::new([42u8; PK_LEN]), - ParSignedData::new(MockSignedData, 0), - ); + let par_signed = ParSignedData::new(MockSignedData, 0); + partially_signed_data_set.insert(PubKey::new([42u8; PK_LEN]), par_signed.clone()); + let retrieved = partially_signed_data_set.get(&PubKey::new([42u8; PK_LEN])); + assert!(retrieved.is_some()); + let retrieved = retrieved.unwrap(); + assert_eq!(retrieved.share_idx, 0); assert_eq!( - partially_signed_data_set.get(&PubKey::new([42u8; PK_LEN])), - Some(&ParSignedData::new(MockSignedData, 0)) + retrieved.signed_data.signature().unwrap(), + Signature::new([42u8; SIG_LEN]) ); } diff --git a/crates/p2p/Cargo.toml b/crates/p2p/Cargo.toml index 750609b1..935dbaf7 100644 --- a/crates/p2p/Cargo.toml +++ b/crates/p2p/Cargo.toml @@ -16,6 +16,7 @@ thiserror.workspace = true k256.workspace = true pluto-eth2util.workspace = true pluto-k1util.workspace = true +prost.workspace = true vise.workspace = true tokio.workspace = true tokio-util.workspace = true @@ -29,6 +30,7 @@ pluto-core.workspace = true backon.workspace = true reqwest.workspace = true url.workspace = true +unsigned-varint.workspace = true [dev-dependencies] pluto-testutil.workspace = true diff --git a/crates/p2p/src/lib.rs 
b/crates/p2p/src/lib.rs index 0c06608b..5e85afc4 100644 --- a/crates/p2p/src/lib.rs +++ b/crates/p2p/src/lib.rs @@ -51,3 +51,6 @@ pub mod relay; /// Force direct connection behaviour. pub mod force_direct; + +/// Protobuf utilities. +pub mod proto; diff --git a/crates/p2p/src/p2p.rs b/crates/p2p/src/p2p.rs index a35a2775..33ba5ab1 100644 --- a/crates/p2p/src/p2p.rs +++ b/crates/p2p/src/p2p.rs @@ -110,6 +110,14 @@ use crate::{ utils, }; +const YAMUX_MAX_NUM_STREAMS: usize = 2_048; + +fn yamux_config() -> yamux::Config { + let mut config = yamux::Config::default(); + config.set_max_num_streams(YAMUX_MAX_NUM_STREAMS); + config +} + /// P2P error. #[derive(Debug, thiserror::Error)] pub enum P2PError { @@ -323,20 +331,17 @@ impl Node { { let swarm = SwarmBuilder::with_existing_identity(keypair) .with_tokio() - .with_tcp( - tcp::Config::default(), - noise::Config::new, - yamux::Config::default, - ) + .with_tcp(tcp::Config::default(), noise::Config::new, yamux_config) .map_err(P2PError::failed_to_build_swarm)? .with_quic() .with_dns() .map_err(P2PError::failed_to_build_swarm)? - .with_relay_client(noise::Config::new, yamux::Config::default) + .with_relay_client(noise::Config::new, yamux_config) .map_err(P2PError::failed_to_build_swarm)? .with_behaviour(|key, relay_client| { - let builder = - PlutoBehaviourBuilder::default().with_p2p_context(p2p_context.clone()); + let builder = PlutoBehaviourBuilder::default() + .with_p2p_context(p2p_context.clone()) + .with_quic_enabled(true); behaviour_fn(builder, key, relay_client).build(key) }) .map_err(P2PError::failed_to_build_swarm)? @@ -364,15 +369,11 @@ impl Node { { let swarm = SwarmBuilder::with_existing_identity(keypair) .with_tokio() - .with_tcp( - tcp::Config::default(), - noise::Config::new, - yamux::Config::default, - ) + .with_tcp(tcp::Config::default(), noise::Config::new, yamux_config) .map_err(P2PError::failed_to_build_swarm)? .with_dns() .map_err(P2PError::failed_to_build_swarm)? 
- .with_relay_client(noise::Config::new, yamux::Config::default) + .with_relay_client(noise::Config::new, yamux_config) .map_err(P2PError::failed_to_build_swarm)? .with_behaviour(|key, relay_client| { let builder = @@ -400,11 +401,7 @@ impl Node { { let swarm = SwarmBuilder::with_existing_identity(keypair) .with_tokio() - .with_tcp( - tcp::Config::default(), - noise::Config::new, - yamux::Config::default, - ) + .with_tcp(tcp::Config::default(), noise::Config::new, yamux_config) .map_err(P2PError::failed_to_build_swarm)? .with_quic() .with_dns() @@ -435,11 +432,7 @@ impl Node { { let swarm = SwarmBuilder::with_existing_identity(keypair) .with_tokio() - .with_tcp( - tcp::Config::default(), - noise::Config::new, - yamux::Config::default, - ) + .with_tcp(tcp::Config::default(), noise::Config::new, yamux_config) .map_err(P2PError::failed_to_build_swarm)? .with_quic() .with_dns() diff --git a/crates/p2p/src/proto.rs b/crates/p2p/src/proto.rs new file mode 100644 index 00000000..57825b4a --- /dev/null +++ b/crates/p2p/src/proto.rs @@ -0,0 +1,86 @@ +use std::io; + +use futures::prelude::*; +use prost::Message; +use unsigned_varint::aio::read_usize; + +/// Default maximum message size (64KB should be plenty for peer info). +pub const MAX_MESSAGE_SIZE: usize = 64 * 1024; + +/// Writes a length-delimited payload with an unsigned varint length prefix. +/// +/// Wire format: `[uvarint length][payload bytes]` +pub async fn write_length_delimited( + stream: &mut S, + payload: &[u8], +) -> io::Result<()> { + let mut len_buf = unsigned_varint::encode::usize_buffer(); + let encoded_len = unsigned_varint::encode::usize(payload.len(), &mut len_buf); + stream.write_all(encoded_len).await?; + stream.write_all(payload).await?; + stream.flush().await +} + +/// Reads a length-delimited payload with an unsigned varint length prefix. +/// +/// Wire format: `[uvarint length][payload bytes]` +/// +/// Returns an error if the payload exceeds `max_message_size`. 
+pub async fn read_length_delimited( + stream: &mut S, + max_message_size: usize, +) -> io::Result> { + let msg_len = read_usize(&mut *stream).await.map_err(|e| match e { + unsigned_varint::io::ReadError::Io(io_err) => io_err, + other => io::Error::new(io::ErrorKind::InvalidData, other), + })?; + + if msg_len > max_message_size { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("message too large: {msg_len} bytes (max: {max_message_size})"), + )); + } + + let mut buf = vec![0u8; msg_len]; + stream.read_exact(&mut buf).await?; + Ok(buf) +} + +/// Writes a protobuf message with unsigned varint length prefix to the stream. +/// +/// Wire format: `[uvarint length][protobuf bytes]` +pub async fn write_protobuf( + stream: &mut S, + msg: &M, +) -> io::Result<()> { + let mut buf = Vec::with_capacity(msg.encoded_len()); + msg.encode(&mut buf) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + write_length_delimited(stream, &buf).await +} + +/// Reads a protobuf message with unsigned varint length prefix from the stream. +/// +/// Wire format: `[uvarint length][protobuf bytes]` +/// +/// Returns an error if the message exceeds `MAX_MESSAGE_SIZE`. +pub async fn read_protobuf( + stream: &mut S, +) -> io::Result { + read_protobuf_with_max_size(stream, MAX_MESSAGE_SIZE).await +} + +/// Reads a protobuf message with unsigned varint length prefix from the stream. +/// +/// Wire format: `[uvarint length][protobuf bytes]` +/// +/// Returns an error if the message exceeds `max_message_size`. 
+pub async fn read_protobuf_with_max_size( + stream: &mut S, + max_message_size: usize, +) -> io::Result { + let buf = read_length_delimited(stream, max_message_size).await?; + + M::decode(&buf[..]).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) +} diff --git a/crates/p2p/src/relay.rs b/crates/p2p/src/relay.rs index c334670f..946785e1 100644 --- a/crates/p2p/src/relay.rs +++ b/crates/p2p/src/relay.rs @@ -31,9 +31,10 @@ use libp2p::{ ToSwarm, dial_opts::DialOpts, dummy, }, }; -use tokio::time::Interval; +use tokio::time::{Instant, Interval}; const RELAY_ROUTER_INTERVAL: Duration = Duration::from_secs(60); +const RELAY_ROUTER_INITIAL_DELAY: Duration = Duration::from_secs(10); /// Mutable relay reservation behaviour. /// @@ -246,7 +247,12 @@ pub struct RelayRouter { impl RelayRouter { /// Creates a new relay router. pub fn new(relays: Vec, p2p_context: P2PContext, local_peer_id: PeerId) -> Self { - let mut interval = tokio::time::interval(RELAY_ROUTER_INTERVAL); + let mut interval = tokio::time::interval_at( + Instant::now() + .checked_add(RELAY_ROUTER_INITIAL_DELAY) + .expect("should not overflow"), + RELAY_ROUTER_INTERVAL, + ); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); Self { diff --git a/crates/parsigex/Cargo.toml b/crates/parsigex/Cargo.toml new file mode 100644 index 00000000..c89d7fbe --- /dev/null +++ b/crates/parsigex/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "pluto-parsigex" +version.workspace = true +edition.workspace = true +repository.workspace = true +license.workspace = true +publish.workspace = true + +[dependencies] +either.workspace = true +futures.workspace = true +futures-timer.workspace = true +libp2p.workspace = true +prost.workspace = true +thiserror.workspace = true +tokio.workspace = true +tracing.workspace = true +unsigned-varint.workspace = true +pluto-core.workspace = true +pluto-p2p.workspace = true + +[dev-dependencies] +anyhow.workspace = true +clap.workspace = true +hex.workspace = true 
+pluto-cluster.workspace = true +pluto-tracing.workspace = true +tokio-util.workspace = true +serde_json.workspace = true + +[lints] +workspace = true diff --git a/crates/parsigex/examples/parsigex.rs b/crates/parsigex/examples/parsigex.rs new file mode 100644 index 00000000..af7d4287 --- /dev/null +++ b/crates/parsigex/examples/parsigex.rs @@ -0,0 +1,512 @@ +#![allow(missing_docs)] + +use std::{ + collections::{HashMap, HashSet}, + path::PathBuf, + time::Duration, +}; + +use anyhow::{Context, Result, anyhow}; +use clap::Parser; +use futures::StreamExt; +use libp2p::{ + identify, ping, + relay::{self}, + swarm::{NetworkBehaviour, SwarmEvent}, +}; +use pluto_cluster::lock::Lock; +use pluto_core::{ + signeddata::SignedRandao, + types::{Duty, DutyType, ParSignedDataSet, PubKey, SlotNumber}, +}; +use pluto_p2p::{ + behaviours::pluto::PlutoBehaviourEvent, + bootnode, + config::P2PConfig, + gater, k1, + p2p::{Node, NodeType}, + peer::peer_id_from_key, + relay::{MutableRelayReservation, RelayRouter}, +}; +use pluto_parsigex::{self as parsigex, DutyGater, Event, Handle, Verifier}; +use pluto_tracing::TracingConfig; +use tokio::fs; +use tokio_util::sync::CancellationToken; +use tracing::{info, warn}; + +#[derive(NetworkBehaviour)] +#[behaviour(to_swarm = "CombinedBehaviourEvent")] +struct CombinedBehaviour { + relay: relay::client::Behaviour, + relay_reservation: MutableRelayReservation, + relay_router: RelayRouter, + parsigex: parsigex::Behaviour, +} + +#[derive(Debug)] +enum CombinedBehaviourEvent { + ParSigEx(Event), + Relay(relay::client::Event), +} + +impl From for CombinedBehaviourEvent { + fn from(event: Event) -> Self { + Self::ParSigEx(event) + } +} + +impl From for CombinedBehaviourEvent { + fn from(event: relay::client::Event) -> Self { + Self::Relay(event) + } +} + +impl From for CombinedBehaviourEvent { + fn from(value: std::convert::Infallible) -> Self { + match value {} + } +} + +#[derive(Debug, Parser)] +#[command(name = "parsigex-example")] +#[command(about 
= "Demonstrates partial signature exchange over the bootnode/relay P2P path")] +struct Args { + /// Relay URLs or multiaddrs. + #[arg(long, value_delimiter = ',')] + relays: Vec, + + /// Directory holding the p2p private key and cluster lock. + #[arg(long)] + data_dir: PathBuf, + + /// TCP listen addresses. + #[arg(long, value_delimiter = ',', default_value = "0.0.0.0:0")] + tcp_addrs: Vec, + + /// UDP listen addresses used for QUIC. + #[arg(long, value_delimiter = ',', default_value = "0.0.0.0:0")] + udp_addrs: Vec, + + /// Whether to filter private addresses from advertisements. + #[arg(long, default_value_t = false)] + filter_private_addrs: bool, + + /// External IP address to advertise. + #[arg(long)] + external_ip: Option, + + /// External hostname to advertise. + #[arg(long)] + external_host: Option, + + /// Whether to disable socket reuse-port. + #[arg(long, default_value_t = false)] + disable_reuse_port: bool, + + /// Emit a sample partial signature every N seconds. + #[arg(long, default_value_t = 10)] + broadcast_every: u64, + + /// Share index to use in the sample partial signature. + #[arg(long, default_value_t = 1)] + share_idx: u64, + + /// Log level. 
+ #[arg(long, default_value = "info")] + log_level: String, +} + +fn make_sample_set(slot: u64, share_idx: u64) -> ParSignedDataSet { + let share_byte = u8::try_from(share_idx % 255).unwrap_or(1); + let pub_key = PubKey::new([share_byte; 48]); + + let mut set = ParSignedDataSet::new(); + set.insert( + pub_key, + SignedRandao::new_partial(slot / 32, [share_byte; 96], share_idx), + ); + set +} + +fn log_received(duty: &Duty, set: &ParSignedDataSet, peer: &libp2p::PeerId) { + let entries = set + .inner() + .iter() + .map(|(pub_key, data)| format!("{pub_key}:share_idx={}", data.share_idx)) + .collect::>() + .join(", "); + + info!(peer = %peer, duty = %duty, entries = %entries, "received partial signature set"); +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = Args::parse(); + + pluto_tracing::init( + &TracingConfig::builder() + .with_default_console() + .override_env_filter(&args.log_level) + .build(), + )?; + + let key = k1::load_priv_key(&args.data_dir).with_context(|| { + format!( + "failed to load private key from {}", + args.data_dir.display() + ) + })?; + let local_peer_id = peer_id_from_key(key.public_key()) + .context("failed to derive local peer ID from private key")?; + + let lock_path = args.data_dir.join("cluster-lock.json"); + let lock_str = fs::read_to_string(&lock_path) + .await + .with_context(|| format!("failed to read {}", lock_path.display()))?; + let lock: Lock = serde_json::from_str(&lock_str) + .with_context(|| format!("failed to parse {}", lock_path.display()))?; + + let cancel = CancellationToken::new(); + let lock_hash_hex = hex::encode(&lock.lock_hash); + let relays = bootnode::new_relays(cancel.child_token(), &args.relays, &lock_hash_hex) + .await + .context("failed to resolve relays")?; + + let known_peers = lock + .peer_ids() + .context("failed to derive peer IDs from lock")?; + if !known_peers.contains(&local_peer_id) { + return Err(anyhow!( + "local peer ID {local_peer_id} not found in cluster lock" + )); + } + let 
conn_gater = gater::ConnGater::new( + gater::Config::closed() + .with_relays(relays.clone()) + .with_peer_ids(known_peers.clone()), + ); + + let verifier: Verifier = + std::sync::Arc::new(|_duty, _pubkey, _data| Box::pin(async { Ok(()) })); + let duty_gater: DutyGater = std::sync::Arc::new(|duty| duty.duty_type != DutyType::Unknown); + let handle_slot = std::sync::Arc::new(tokio::sync::Mutex::new(1_u64)); + + let p2p_config = P2PConfig { + relays: vec![], + external_ip: args.external_ip.clone(), + external_host: args.external_host.clone(), + tcp_addrs: args.tcp_addrs.clone(), + udp_addrs: args.udp_addrs.clone(), + disable_reuse_port: args.disable_reuse_port, + }; + + let relay_peer_ids: HashSet<_> = relays + .iter() + .filter_map(|relay| relay.peer().ok().flatten().map(|peer| peer.id)) + .collect(); + + let mut parsigex_handle: Option = None; + let mut node: Node = Node::new( + p2p_config, + key, + NodeType::QUIC, + args.filter_private_addrs, + known_peers.clone(), + |builder, keypair, relay_client| { + let p2p_context = builder.p2p_context(); + let local_peer_id = keypair.public().to_peer_id(); + let config = parsigex::Config::new( + local_peer_id, + p2p_context.clone(), + verifier.clone(), + duty_gater.clone(), + ) + .with_timeout(Duration::from_secs(10)); + let (parsigex, handle) = parsigex::Behaviour::new(config, local_peer_id); + parsigex_handle = Some(handle); + + builder + .with_gater(conn_gater) + .with_inner(CombinedBehaviour { + parsigex, + relay: relay_client, + relay_reservation: MutableRelayReservation::new(relays.clone()), + relay_router: RelayRouter::new(relays.clone(), p2p_context, local_peer_id), + }) + }, + )?; + + let parsigex_handle = + parsigex_handle.ok_or_else(|| anyhow!("parsigex handle should be created"))?; + + info!( + peer_id = %node.local_peer_id(), + data_dir = %args.data_dir.display(), + known_peers = ?known_peers, + relays = ?args.relays, + "parsigex example started" + ); + + let mut ticker = 
tokio::time::interval(Duration::from_secs(args.broadcast_every)); + let mut pending_broadcasts: HashMap = HashMap::new(); + + loop { + tokio::select! { + _ = tokio::signal::ctrl_c() => { + info!("ctrl+c received, shutting down"); + break; + } + _ = ticker.tick() => { + info!("broadcasting sample partial signature set"); + let mut slot = handle_slot.lock().await; + let duty = Duty::new(SlotNumber::new(*slot), DutyType::Randao); + let data_set = make_sample_set(*slot, args.share_idx); + + match parsigex_handle.broadcast(duty.clone(), data_set.clone()).await { + Ok(request_id) => { + pending_broadcasts.insert(request_id, (duty.clone(), args.share_idx)); + info!( + request_id, + duty = %duty, + share_idx = args.share_idx, + "queued sample partial signature set for broadcast" + ); + *slot = slot.saturating_add(1); + } + Err(error) => { + warn!(%error, "broadcast failed"); + } + } + } + event = node.select_next_some() => { + let peer_type = |peer_id: &libp2p::PeerId| { + if relay_peer_ids.contains(peer_id) { + "RELAY" + } else if known_peers.contains(peer_id) { + "PEER" + } else { + "UNKNOWN" + } + }; + + match event { + SwarmEvent::Behaviour(PlutoBehaviourEvent::Inner( + CombinedBehaviourEvent::Relay(relay::client::Event::ReservationReqAccepted { + relay_peer_id, + renewal, + limit, + }), + )) => { + info!( + relay_peer_id = %relay_peer_id, + peer_type = peer_type(&relay_peer_id), + renewal, + limit = ?limit, + "relay reservation accepted" + ); + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Inner( + CombinedBehaviourEvent::Relay(relay::client::Event::OutboundCircuitEstablished { + relay_peer_id, + limit, + }), + )) => { + info!( + relay_peer_id = %relay_peer_id, + peer_type = peer_type(&relay_peer_id), + limit = ?limit, + "outbound relay circuit established" + ); + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Inner( + CombinedBehaviourEvent::Relay(relay::client::Event::InboundCircuitEstablished { + src_peer_id, + limit, + }), + )) => { + info!( + src_peer_id = 
%src_peer_id, + peer_type = peer_type(&src_peer_id), + limit = ?limit, + "inbound relay circuit established" + ); + } + SwarmEvent::ConnectionEstablished { + peer_id, + endpoint, + num_established, + .. + } => { + let address = match &endpoint { + libp2p::core::ConnectedPoint::Dialer { address, .. } => address, + libp2p::core::ConnectedPoint::Listener { send_back_addr, .. } => { + send_back_addr + } + }; + info!( + peer_id = %peer_id, + peer_type = peer_type(&peer_id), + address = %address, + num_established, + "connection established" + ); + } + SwarmEvent::ConnectionClosed { + peer_id, + endpoint, + num_established, + cause, + .. + } => { + let address = match &endpoint { + libp2p::core::ConnectedPoint::Dialer { address, .. } => address, + libp2p::core::ConnectedPoint::Listener { send_back_addr, .. } => { + send_back_addr + } + }; + info!( + peer_id = %peer_id, + peer_type = peer_type(&peer_id), + address = %address, + num_established, + cause = ?cause, + "connection closed" + ); + } + SwarmEvent::OutgoingConnectionError { + peer_id, + error, + connection_id, + } => { + warn!( + peer_id = ?peer_id, + connection_id = ?connection_id, + error = %error, + "outgoing connection failed" + ); + } + SwarmEvent::IncomingConnectionError { + connection_id, + local_addr, + send_back_addr, + error, + .. + } => { + warn!( + connection_id = ?connection_id, + local_addr = %local_addr, + send_back_addr = %send_back_addr, + error = %error, + "incoming connection failed" + ); + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Identify( + identify::Event::Received { peer_id, info, .. }, + )) => { + info!( + peer_id = %peer_id, + peer_type = peer_type(&peer_id), + agent_version = %info.agent_version, + protocol_version = %info.protocol_version, + listen_addrs = ?info.listen_addrs, + "identify received" + ); + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Ping(ping::Event { + peer, + result, + .. 
+ })) => match result { + Ok(rtt) => { + info!(peer_id = %peer, peer_type = peer_type(&peer), rtt = ?rtt, "ping succeeded"); + } + Err(error) => { + warn!(peer_id = %peer, peer_type = peer_type(&peer), error = %error, "ping failed"); + } + }, + SwarmEvent::Behaviour(PlutoBehaviourEvent::Inner( + CombinedBehaviourEvent::ParSigEx(Event::Received { + peer, + duty, + data_set, + .. + }), + )) => { + log_received(&duty, &data_set, &peer); + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Inner( + CombinedBehaviourEvent::ParSigEx(Event::Error { peer, error, .. }), + )) => { + warn!(peer = %peer, error = %error, "parsigex protocol error"); + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Inner( + CombinedBehaviourEvent::ParSigEx(Event::BroadcastError { + request_id, + peer, + error, + }), + )) => { + match pending_broadcasts.get(&request_id) { + Some((duty, share_idx)) => { + warn!( + request_id, + duty = %duty, + share_idx, + peer = ?peer, + error = %error, + "sample partial signature broadcast failed" + ); + } + None => { + warn!( + request_id, + peer = ?peer, + error = %error, + "partial signature broadcast failed" + ); + } + } + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Inner( + CombinedBehaviourEvent::ParSigEx(Event::BroadcastComplete { + request_id, + }), + )) => { + if let Some((duty, share_idx)) = pending_broadcasts.remove(&request_id) { + info!( + request_id, + duty = %duty, + share_idx, + "broadcasted sample partial signature set" + ); + } else { + info!(request_id, "partial signature broadcast completed"); + } + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Inner( + CombinedBehaviourEvent::ParSigEx(Event::BroadcastFinished { + request_id, + }), + )) => { + if let Some((duty, share_idx)) = pending_broadcasts.remove(&request_id) { + warn!( + request_id, + duty = %duty, + share_idx, + "sample partial signature broadcast finished with failures" + ); + } else { + warn!(request_id, "partial signature broadcast finished with failures"); + } + } + 
SwarmEvent::NewListenAddr { address, .. } => { + info!(address = %address, "listening"); + } + _ => {} + } + } + } + } + + Ok(()) +} diff --git a/crates/parsigex/src/behaviour.rs b/crates/parsigex/src/behaviour.rs new file mode 100644 index 00000000..4f9d6c5e --- /dev/null +++ b/crates/parsigex/src/behaviour.rs @@ -0,0 +1,444 @@ +//! Network behaviour and control handle for partial signature exchange. + +use std::{ + collections::{HashMap, VecDeque}, + future::Future, + pin::Pin, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, + task::{Context, Poll}, + time::Duration, +}; + +use either::Either; +use libp2p::{ + Multiaddr, PeerId, + swarm::{ + ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, NotifyHandler, THandler, + THandlerInEvent, THandlerOutEvent, ToSwarm, dummy, + }, +}; +use tokio::sync::mpsc; + +use pluto_core::types::{Duty, ParSignedData, ParSignedDataSet, PubKey}; +use pluto_p2p::p2p_context::P2PContext; + +use super::{Error as CodecError, Handler, encode_message}; +use crate::handler::{Failure as HandlerFailure, FromHandler, ToHandler}; + +/// Future returned by verifier callbacks. +pub type VerifyFuture = + Pin> + Send + 'static>>; + +/// Verifier callback type. +pub type Verifier = + Arc VerifyFuture + Send + Sync + 'static>; + +/// Duty gate callback type. +pub type DutyGater = Arc bool + Send + Sync + 'static>; + +/// Error type for signature verification callbacks. +#[derive(Debug, thiserror::Error)] +pub enum VerifyError { + /// Unknown validator public key. + #[error("unknown pubkey, not part of cluster lock")] + UnknownPubKey, + + /// Invalid share index for the validator. + #[error("invalid shareIdx")] + InvalidShareIndex, + + /// Invalid signed-data family for the duty. + #[error("invalid eth2 signed data")] + InvalidSignedDataFamily, + + /// Generic verification error. + #[error("{0}")] + Other(String), +} + +/// Error type for behaviour operations. 
+#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Message conversion failed. + #[error(transparent)] + Codec(#[from] CodecError), + + /// Channel closed. + #[error("parsigex handle closed")] + Closed, + + /// Broadcast failed for a peer. + #[error("broadcast to peer {peer} failed: {source}")] + BroadcastPeer { + /// Peer for which the broadcast failed. + peer: PeerId, + /// Source error. + #[source] + source: HandlerFailure, + }, + + /// Peer is not currently connected. + #[error("peer {0} is not connected")] + PeerNotConnected(PeerId), +} + +/// Result type for partial signature exchange behaviour operations. +pub type Result = std::result::Result; + +/// Event emitted by the partial signature exchange behaviour. +#[derive(Debug)] +pub enum Event { + /// A verified partial signature set was received from a peer. + Received { + /// The remote peer. + peer: PeerId, + /// Connection on which it was received. + connection: ConnectionId, + /// Duty associated with the data set. + duty: Duty, + /// Partial signature set. + data_set: ParSignedDataSet, + }, + /// A peer sent invalid data or verification failed. + Error { + /// The remote peer. + peer: PeerId, + /// Connection on which the error occurred. + connection: ConnectionId, + /// Failure reason. + error: HandlerFailure, + }, + /// Broadcast failed. + BroadcastError { + /// Request identifier. + request_id: u64, + /// Peer for which the broadcast failed. + peer: Option, + /// Failure reason. + error: HandlerFailure, + }, + /// Broadcast completed successfully for all targeted peers. + BroadcastComplete { + /// Request identifier. + request_id: u64, + }, + /// Broadcast finished after one or more peer failures. + BroadcastFinished { + /// Request identifier. 
+ request_id: u64, + }, +} + +#[derive(Debug)] +struct PendingBroadcast { + remaining: usize, + failed: bool, +} + +#[derive(Debug)] +enum Command { + Broadcast { + request_id: u64, + duty: Duty, + data_set: ParSignedDataSet, + }, +} + +/// Async handle for outbound partial signature broadcasts. +#[derive(Debug, Clone)] +pub struct Handle { + tx: mpsc::UnboundedSender, + next_request_id: Arc, +} + +impl Handle { + /// Broadcasts a partial signature set to all peers except self. + pub async fn broadcast(&self, duty: Duty, data_set: ParSignedDataSet) -> Result { + let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); + self.tx + .send(Command::Broadcast { + request_id, + duty, + data_set, + }) + .map_err(|_| Error::Closed)?; + + Ok(request_id) + } +} + +/// Configuration for the partial signature exchange behaviour. +#[derive(Clone)] +pub struct Config { + peer_id: PeerId, + p2p_context: P2PContext, + verifier: Verifier, + duty_gater: DutyGater, + timeout: Duration, +} + +impl Config { + /// Creates a new configuration. + pub fn new( + peer_id: PeerId, + p2p_context: P2PContext, + verifier: Verifier, + duty_gater: DutyGater, + ) -> Self { + Self { + peer_id, + p2p_context, + verifier, + duty_gater, + timeout: Duration::from_secs(20), + } + } + + /// Sets the send/receive timeout. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } +} + +/// Behaviour for partial signature exchange. +pub struct Behaviour { + config: Config, + rx: mpsc::UnboundedReceiver, + pending_actions: VecDeque>>, + events: VecDeque, + pending_broadcasts: HashMap, +} + +impl Behaviour { + /// Creates a behaviour and a clonable broadcast handle. 
+ pub fn new(config: Config, peer_id: PeerId) -> (Self, Handle) { + debug_assert_eq!(config.peer_id, peer_id); + let (tx, rx) = mpsc::unbounded_channel(); + let handle = Handle { + tx, + next_request_id: Arc::new(AtomicU64::new(0)), + }; + + ( + Self { + config, + rx, + pending_actions: VecDeque::new(), + events: VecDeque::new(), + pending_broadcasts: HashMap::new(), + }, + handle, + ) + } + + fn handle_command(&mut self, command: Command) { + match command { + Command::Broadcast { + request_id, + duty, + data_set, + } => { + let message = match encode_message(&duty, &data_set) { + Ok(message) => message, + Err(err) => { + self.broadcast_error(request_id, None, HandlerFailure::Codec(err)); + return; + } + }; + + let peers: Vec<_> = self + .config + .p2p_context + .known_peers() + .iter() + .copied() + .collect(); + let mut targeted = 0usize; + for peer in peers { + if peer == self.config.peer_id { + continue; + } + + if self + .config + .p2p_context + .peer_store_lock() + .connections_to_peer(&peer) + .is_empty() + { + self.broadcast_error( + request_id, + Some(peer), + HandlerFailure::Io(format!("peer {peer} is not connected")), + ); + continue; + } + + self.pending_actions.push_back(ToSwarm::NotifyHandler { + peer_id: peer, + handler: NotifyHandler::Any, + event: Either::Left(ToHandler::Send { + request_id, + payload: message.clone(), + }), + }); + targeted = targeted.saturating_add(1); + } + + if targeted == 0 { + return; + } + + self.pending_broadcasts.insert( + request_id, + PendingBroadcast { + remaining: targeted, + failed: false, + }, + ); + } + } + } + + fn finish_broadcast_result(&mut self, request_id: u64, failed: bool) { + let Some(entry) = self.pending_broadcasts.get_mut(&request_id) else { + return; + }; + + entry.failed |= failed; + entry.remaining = entry.remaining.saturating_sub(1); + if entry.remaining == 0 { + let failed = self + .pending_broadcasts + .remove(&request_id) + .map(|entry| entry.failed) + .unwrap_or(failed); + if failed { + 
self.events + .push_back(Event::BroadcastFinished { request_id }); + } else { + self.events + .push_back(Event::BroadcastComplete { request_id }); + } + } + } + + fn broadcast_error(&mut self, request_id: u64, peer: Option, error: HandlerFailure) { + self.events.push_back(Event::BroadcastError { + request_id, + peer, + error, + }); + } +} + +impl NetworkBehaviour for Behaviour { + type ConnectionHandler = Either; + type ToSwarm = Event; + + fn handle_established_inbound_connection( + &mut self, + _connection_id: ConnectionId, + peer: PeerId, + _local_addr: &Multiaddr, + _remote_addr: &Multiaddr, + ) -> std::result::Result, ConnectionDenied> { + if !self.config.p2p_context.is_known_peer(&peer) { + return Ok(Either::Right(dummy::ConnectionHandler)); + } + + tracing::trace!("establishing inbound connection to peer: {:?}", peer); + Ok(Either::Left(Handler::new( + self.config.timeout, + self.config.verifier.clone(), + self.config.duty_gater.clone(), + peer, + ))) + } + + fn handle_established_outbound_connection( + &mut self, + _connection_id: ConnectionId, + peer: PeerId, + _addr: &Multiaddr, + _role_override: libp2p::core::Endpoint, + _port_use: libp2p::core::transport::PortUse, + ) -> std::result::Result, ConnectionDenied> { + if !self.config.p2p_context.is_known_peer(&peer) { + return Ok(Either::Right(dummy::ConnectionHandler)); + } + + tracing::trace!("establishing outbound connection to peer: {:?}", peer); + Ok(Either::Left(Handler::new( + self.config.timeout, + self.config.verifier.clone(), + self.config.duty_gater.clone(), + peer, + ))) + } + + fn on_swarm_event(&mut self, _event: FromSwarm) {} + + fn on_connection_handler_event( + &mut self, + peer_id: PeerId, + connection_id: ConnectionId, + event: THandlerOutEvent, + ) { + let event = match event { + Either::Left(event) => event, + Either::Right(value) => match value {}, + }; + + tracing::trace!("received connection handler event: {:?}", event); + match event { + FromHandler::Received { duty, data_set } => { 
+ self.events.push_back(Event::Received { + peer: peer_id, + connection: connection_id, + duty, + data_set, + }); + } + FromHandler::InboundError(error) => { + self.events.push_back(Event::Error { + peer: peer_id, + connection: connection_id, + error, + }); + } + FromHandler::OutboundSuccess { request_id } => { + self.finish_broadcast_result(request_id, false); + } + FromHandler::OutboundError { request_id, error } => { + self.finish_broadcast_result(request_id, true); + self.broadcast_error(request_id, Some(peer_id), error); + } + } + } + + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + tracing::trace!("polling parsigex behaviour"); + + if let Some(event) = self.events.pop_front() { + return Poll::Ready(ToSwarm::GenerateEvent(event)); + } + + if let Poll::Ready(Some(command)) = self.rx.poll_recv(cx) { + self.handle_command(command); + } + + if let Some(action) = self.pending_actions.pop_front() { + return Poll::Ready(action); + } + + Poll::Pending + } +} diff --git a/crates/parsigex/src/handler.rs b/crates/parsigex/src/handler.rs new file mode 100644 index 00000000..e1925aad --- /dev/null +++ b/crates/parsigex/src/handler.rs @@ -0,0 +1,403 @@ +//! Connection handler for the partial signature exchange protocol. + +use std::{ + collections::VecDeque, + task::{Context, Poll}, + time::Duration, +}; + +use futures::{future::BoxFuture, prelude::*}; +use futures_timer::Delay; +use libp2p::{ + PeerId, + core::upgrade::ReadyUpgrade, + swarm::{ + ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, StreamUpgradeError, + SubstreamProtocol, + handler::{ + ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, + }, + }, +}; + +use pluto_core::types::{Duty, ParSignedDataSet}; + +use super::{DutyGater, PROTOCOL_NAME, Verifier, protocol}; +use crate::Error as CodecError; + +/// Failure type for the partial signature exchange handler. +#[derive(Debug, thiserror::Error)] +pub enum Failure { + /// Stream negotiation timed out. 
+ #[error("parsigex protocol negotiation timed out")] + Timeout, + /// Invalid payload. + #[error("invalid parsigex payload")] + InvalidPayload, + /// Duty not accepted by the gater. + #[error("invalid duty")] + InvalidDuty, + /// Signature verification failed. + #[error("invalid partial signature")] + InvalidPartialSignature, + /// I/O error. + #[error("{0}")] + Io(String), + /// Codec error. + #[error("codec error: {0}")] + Codec(CodecError), +} + +impl Failure { + fn io(error: impl std::fmt::Display) -> Self { + Self::Io(error.to_string()) + } +} + +/// Command sent from the behaviour to a handler. +#[derive(Debug)] +pub enum ToHandler { + /// Send the encoded payload to the remote peer. + Send { + /// Request identifier used to correlate broadcast completions. + request_id: u64, + /// Encoded protobuf payload. + payload: Vec, + }, +} + +/// Event sent from the handler back to the behaviour. +#[derive(Debug)] +pub enum FromHandler { + /// A verified message was received. + Received { + /// Duty from the message. + duty: Duty, + /// Verified partial signature set. + data_set: ParSignedDataSet, + }, + /// An inbound message failed decoding, gating, or verification. + InboundError(Failure), + /// Outbound send completed successfully. + OutboundSuccess { + /// Request identifier. + request_id: u64, + }, + /// Outbound send failed. + OutboundError { + /// Request identifier. + request_id: u64, + /// Failure reason. + error: Failure, + }, +} + +type SendFuture = BoxFuture<'static, Result<(), Failure>>; +type RecvFuture = BoxFuture<'static, Result<(Duty, ParSignedDataSet), Failure>>; + +enum OutboundState { + IdleStream { stream: libp2p::swarm::Stream }, + RequestOpenStream { request_id: u64, payload: Vec }, + Sending { request_id: u64, future: SendFuture }, +} + +impl std::fmt::Debug for OutboundState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OutboundState::IdleStream { .. 
} => { + write!(f, "IdleStream {{ stream: }}") + } + OutboundState::RequestOpenStream { + request_id, + payload, + } => write!( + f, + "RequestOpenStream {{ request_id: {}, payload: {:?} }}", + request_id, payload + ), + OutboundState::Sending { request_id, .. } => write!( + f, + "Sending {{ request_id: {}, future: }}", + request_id + ), + } + } +} + +fn recv_message( + mut stream: libp2p::swarm::Stream, + verifier: Verifier, + duty_gater: DutyGater, + timeout: Duration, +) -> RecvFuture { + async move { + let recv = async { + let bytes = protocol::recv_message(&mut stream) + .await + .map_err(Failure::io)?; + let (duty, data_set) = + protocol::decode_message(&bytes).map_err(|_| Failure::InvalidPayload)?; + if !(duty_gater)(&duty) { + return Err(Failure::InvalidDuty); + } + + for (pub_key, par_sig) in data_set.inner() { + verifier(duty.clone(), *pub_key, par_sig.clone()) + .await + .map_err(|_| Failure::InvalidPartialSignature)?; + } + + Ok((duty, data_set)) + }; + + futures::pin_mut!(recv); + match futures::future::select(recv, Delay::new(timeout)).await { + futures::future::Either::Left((result, _)) => result, + futures::future::Either::Right(((), _)) => Err(Failure::Timeout), + } + } + .boxed() +} + +fn send_message( + mut stream: libp2p::swarm::Stream, + payload: Vec, + timeout: Duration, +) -> SendFuture { + async move { + let send = + protocol::send_message(&mut stream, &payload).map(|result| result.map_err(Failure::io)); + futures::pin_mut!(send); + match futures::future::select(send, Delay::new(timeout)).await { + futures::future::Either::Left((result, _)) => result, + futures::future::Either::Right(((), _)) => Err(Failure::Timeout), + } + } + .boxed() +} + +/// Connection handler for parsigex. +pub struct Handler { + timeout: Duration, + verifier: Verifier, + duty_gater: DutyGater, + outbound_queue: VecDeque<(u64, Vec)>, + outbound: Option, + inbound: Option, + pending_events: VecDeque, +} + +impl Handler { + /// Creates a new handler for one connection. 
+ pub fn new( + timeout: Duration, + verifier: Verifier, + duty_gater: DutyGater, + _peer: PeerId, + ) -> Self { + Self { + timeout, + verifier, + duty_gater, + outbound_queue: VecDeque::new(), + outbound: None, + inbound: None, + pending_events: VecDeque::new(), + } + } + + fn on_dial_upgrade_error( + &mut self, + error: DialUpgradeError<(), ::OutboundProtocol>, + ) { + let Some(OutboundState::RequestOpenStream { request_id, .. }) = self.outbound.take() else { + return; + }; + + let failure = match error.error { + StreamUpgradeError::Timeout => Failure::Timeout, + StreamUpgradeError::NegotiationFailed => Failure::io("protocol negotiation failed"), + StreamUpgradeError::Apply(e) => libp2p::core::util::unreachable(e), + StreamUpgradeError::Io(e) => Failure::io(e), + }; + + self.pending_events.push_back(FromHandler::OutboundError { + request_id, + error: failure, + }); + } +} + +impl ConnectionHandler for Handler { + type FromBehaviour = ToHandler; + type InboundOpenInfo = (); + type InboundProtocol = ReadyUpgrade; + type OutboundOpenInfo = (); + type OutboundProtocol = ReadyUpgrade; + type ToBehaviour = FromHandler; + + fn listen_protocol(&self) -> SubstreamProtocol { + SubstreamProtocol::new(ReadyUpgrade::new(PROTOCOL_NAME), ()) + } + + fn on_behaviour_event(&mut self, event: Self::FromBehaviour) { + match event { + ToHandler::Send { + request_id, + payload, + } => self.outbound_queue.push_back((request_id, payload)), + } + } + + #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))] + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + ConnectionHandlerEvent, + > { + if let Some(event) = self.pending_events.pop_front() { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } + + if let Some(fut) = self.inbound.as_mut() { + match fut.poll_unpin(cx) { + Poll::Pending => {} + Poll::Ready(Ok((duty, data_set))) => { + self.inbound = None; + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + 
FromHandler::Received { duty, data_set }, + )); + } + Poll::Ready(Err(error)) => { + self.inbound = None; + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + FromHandler::InboundError(error), + )); + } + } + } + + if let Some(outbound) = self.outbound.take() { + match outbound { + OutboundState::IdleStream { stream } => { + if let Some((request_id, payload)) = self.outbound_queue.pop_front() { + self.outbound = Some(OutboundState::Sending { + request_id, + future: send_message(stream, payload, self.timeout), + }); + } else { + self.outbound = Some(OutboundState::IdleStream { stream }); + } + } + OutboundState::RequestOpenStream { + request_id, + payload, + } => { + // Waiting for stream negotiation - put state back and return pending. + // The OutboundSubstreamRequest was already emitted when first entering this + // state. Returning it again would cause libp2p to panic + // with "cannot extract twice". + self.outbound = Some(OutboundState::RequestOpenStream { + request_id, + payload, + }); + } + OutboundState::Sending { + request_id, + mut future, + } => match future.poll_unpin(cx) { + Poll::Pending => { + self.outbound = Some(OutboundState::Sending { request_id, future }); + } + Poll::Ready(Ok(())) => { + self.outbound = None; + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + FromHandler::OutboundSuccess { request_id }, + )); + } + Poll::Ready(Err(error)) => { + self.outbound = None; + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + FromHandler::OutboundError { request_id, error }, + )); + } + }, + } + } + + // Only start a new outbound operation if none is in progress. + // This prevents overwriting RequestOpenStream or Sending states. 
+ if self.outbound.is_none() + && let Some((request_id, payload)) = self.outbound_queue.pop_front() + { + self.outbound = Some(OutboundState::RequestOpenStream { + request_id, + payload, + }); + return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(ReadyUpgrade::new(PROTOCOL_NAME), ()), + }); + } + + Poll::Pending + } + + fn on_connection_event( + &mut self, + event: ConnectionEvent, + ) { + match event { + ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound { + protocol: mut stream, + .. + }) => { + stream.ignore_for_keep_alive(); + self.inbound = Some(recv_message( + stream, + self.verifier.clone(), + self.duty_gater.clone(), + self.timeout, + )); + } + ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { + protocol: mut stream, + .. + }) => { + stream.ignore_for_keep_alive(); + match self.outbound.take() { + Some(OutboundState::RequestOpenStream { + request_id, + payload, + }) => { + self.outbound = Some(OutboundState::Sending { + request_id, + future: send_message(stream, payload, self.timeout), + }); + } + Some(OutboundState::Sending { request_id, future }) => { + self.outbound = Some(OutboundState::Sending { request_id, future }); + tracing::debug!( + "dropping unexpected outbound parsigex stream while a send is already in progress" + ); + } + Some(OutboundState::IdleStream { + stream: idle_stream, + }) => { + self.outbound = Some(OutboundState::IdleStream { + stream: idle_stream, + }); + tracing::debug!( + "dropping unexpected outbound parsigex stream while an idle stream is already cached" + ); + } + None => { + self.outbound = Some(OutboundState::IdleStream { stream }); + } + } + } + ConnectionEvent::DialUpgradeError(error) => self.on_dial_upgrade_error(error), + _ => {} + } + } +} diff --git a/crates/parsigex/src/lib.rs b/crates/parsigex/src/lib.rs new file mode 100644 index 00000000..ca967afc --- /dev/null +++ b/crates/parsigex/src/lib.rs @@ -0,0 +1,41 @@ +//! 
Partial signature exchange protocol. + +pub mod behaviour; +mod handler; +mod protocol; + +pub use behaviour::{ + Behaviour, Config, DutyGater, Error as BehaviourError, Event, Handle, Verifier, VerifyError, +}; +pub use handler::Handler; +pub use protocol::{decode_message, encode_message}; + +use libp2p::PeerId; +use pluto_core::ParSigExCodecError; + +/// The protocol name for partial signature exchange (version 2.0.0). +pub const PROTOCOL_NAME: libp2p::swarm::StreamProtocol = + libp2p::swarm::StreamProtocol::new("/charon/parsigex/2.0.0"); + +/// Returns the supported protocols in precedence order. +pub fn protocols() -> Vec { + vec![PROTOCOL_NAME] +} + +/// Error type for proto and conversion operations. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Core codec error. + #[error(transparent)] + Codec(#[from] ParSigExCodecError), + + /// Broadcast failed for a peer. + #[error("broadcast to peer {peer} failed")] + BroadcastPeer { + /// Peer for which the broadcast failed. + peer: PeerId, + }, +} + +/// Result type for partial signature exchange operations. +pub type Result = std::result::Result; diff --git a/crates/parsigex/src/protocol.rs b/crates/parsigex/src/protocol.rs new file mode 100644 index 00000000..bfaccaf9 --- /dev/null +++ b/crates/parsigex/src/protocol.rs @@ -0,0 +1,68 @@ +//! Wire protocol helpers for partial signature exchange. + +use std::io; + +use libp2p::swarm::Stream; +use prost::Message; + +use pluto_core::{ + corepb::v1::{core as pbcore, parsigex as pbparsigex}, + types::{Duty, ParSignedDataSet}, +}; +use pluto_p2p::proto; + +use super::{Error, Result as ParasigexResult}; + +/// Maximum accepted message size. +const MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024; + +/// Encodes a protobuf message to bytes. 
+pub fn encode_protobuf(message: &M) -> Vec { + let mut buf = Vec::with_capacity(message.encoded_len()); + message + .encode(&mut buf) + .expect("vec-backed protobuf encoding cannot fail"); + buf +} + +/// Decodes a protobuf message from bytes. +pub fn decode_protobuf( + bytes: &[u8], +) -> std::result::Result { + M::decode(bytes) +} + +/// Encodes a partial signature exchange message. +pub fn encode_message(duty: &Duty, data_set: &ParSignedDataSet) -> ParasigexResult> { + let pb = pbparsigex::ParSigExMsg { + duty: Some(pbcore::Duty::from(duty)), + data_set: Some(pbcore::ParSignedDataSet::try_from(data_set)?), + }; + + Ok(encode_protobuf(&pb)) +} + +/// Decodes a partial signature exchange message. +pub fn decode_message(bytes: &[u8]) -> ParasigexResult<(Duty, ParSignedDataSet)> { + let pb: pbparsigex::ParSigExMsg = decode_protobuf(bytes) + .map_err(|_| Error::from(pluto_core::ParSigExCodecError::InvalidMessageFields))?; + let duty_pb = pb + .duty + .ok_or(pluto_core::ParSigExCodecError::InvalidMessageFields)?; + let data_set_pb = pb + .data_set + .ok_or(pluto_core::ParSigExCodecError::InvalidMessageFields)?; + let duty = Duty::try_from(&duty_pb)?; + let data_set = ParSignedDataSet::try_from((&duty.duty_type, &data_set_pb))?; + Ok((duty, data_set)) +} + +/// Sends one protobuf message on the stream. +pub async fn send_message(stream: &mut Stream, payload: &[u8]) -> io::Result<()> { + proto::write_length_delimited(stream, payload).await +} + +/// Receives one protobuf payload from the stream. 
+pub async fn recv_message(stream: &mut Stream) -> io::Result> { + proto::read_length_delimited(stream, MAX_MESSAGE_SIZE).await +} diff --git a/crates/peerinfo/src/protocol.rs b/crates/peerinfo/src/protocol.rs index 2fe6ad70..5b284521 100644 --- a/crates/peerinfo/src/protocol.rs +++ b/crates/peerinfo/src/protocol.rs @@ -17,14 +17,12 @@ use std::{ }; use chrono::{DateTime, Utc}; -use futures::prelude::*; use libp2p::{PeerId, swarm::Stream}; use pluto_core::version::{self, SemVer, SemVerError}; -use prost::Message; +use pluto_p2p::proto; use regex::Regex; use tokio::sync::Mutex; use tracing::{info, warn}; -use unsigned_varint::aio::read_usize; use crate::{ LocalPeerInfo, @@ -32,9 +30,6 @@ use crate::{ peerinfopb::v1::peerinfo::PeerInfo, }; -/// Maximum message size (64KB should be plenty for peer info). -const MAX_MESSAGE_SIZE: usize = 64 * 1024; - static GIT_HASH_RE: LazyLock = LazyLock::new(|| Regex::new(r"^[0-9a-f]{7}$").expect("invalid regex")); @@ -51,57 +46,6 @@ pub struct ProtocolState { local_info: LocalPeerInfo, } -/// Writes a protobuf message with unsigned varint length prefix to the stream. -/// -/// Wire format: `[uvarint length][protobuf bytes]` -async fn write_protobuf( - stream: &mut S, - msg: &M, -) -> io::Result<()> { - // Encode message to protobuf bytes - let mut buf = Vec::with_capacity(msg.encoded_len()); - msg.encode(&mut buf) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - - // Write unsigned varint length prefix - let mut len_buf = unsigned_varint::encode::usize_buffer(); - let encoded_len = unsigned_varint::encode::usize(buf.len(), &mut len_buf); - stream.write_all(encoded_len).await?; - - // Write protobuf bytes - stream.write_all(&buf).await?; - stream.flush().await -} - -/// Reads a protobuf message with unsigned varint length prefix from the stream. -/// -/// Wire format: `[uvarint length][protobuf bytes]` -/// -/// Returns an error if the message exceeds `MAX_MESSAGE_SIZE`. 
-async fn read_protobuf( - stream: &mut S, -) -> io::Result { - // Read unsigned varint length prefix - let msg_len = read_usize(&mut *stream).await.map_err(|e| match e { - unsigned_varint::io::ReadError::Io(io_err) => io_err, - other => io::Error::new(io::ErrorKind::InvalidData, other), - })?; - - if msg_len > MAX_MESSAGE_SIZE { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("message too large: {msg_len} bytes (max: {MAX_MESSAGE_SIZE})"), - )); - } - - // Read exactly `msg_len` protobuf bytes - let mut buf = vec![0u8; msg_len]; - stream.read_exact(&mut buf).await?; - - // Unmarshal protobuf - M::decode(&buf[..]).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) -} - /// Errors that can occur during the protocol. #[derive(Debug, thiserror::Error)] pub enum ProtocolError { @@ -317,8 +261,8 @@ impl ProtocolState { request: &PeerInfo, ) -> io::Result<(Stream, PeerInfo)> { let start = Instant::now(); - write_protobuf(&mut stream, request).await?; - let response = read_protobuf(&mut stream).await?; + proto::write_protobuf(&mut stream, request).await?; + let response = proto::read_protobuf(&mut stream).await?; let rtt = start.elapsed(); self.validate_peer_info(&response, rtt).await; @@ -334,8 +278,8 @@ impl ProtocolState { mut stream: Stream, local_info: &PeerInfo, ) -> io::Result<(Stream, PeerInfo)> { - let request = read_protobuf(&mut stream).await?; - write_protobuf(&mut stream, local_info).await?; + let request = proto::read_protobuf(&mut stream).await?; + proto::write_protobuf(&mut stream, local_info).await?; Ok((stream, request)) } } @@ -344,6 +288,7 @@ impl ProtocolState { mod tests { use super::*; use hex_literal::hex; + use prost::Message; // Test case: minimal // CharonVersion: "v1.0.0" @@ -571,7 +516,7 @@ mod tests { // Write to a cursor let mut buf = Vec::new(); - write_protobuf(&mut buf, &original).await.unwrap(); + proto::write_protobuf(&mut buf, &original).await.unwrap(); // The wire format should be: [varint length][protobuf 
bytes] // Minimal message is 14 bytes, so length prefix is just 1 byte (14 < 128) @@ -580,7 +525,7 @@ mod tests { // Read it back let mut cursor = futures::io::Cursor::new(&buf[..]); - let decoded: PeerInfo = read_protobuf(&mut cursor).await.unwrap(); + let decoded: PeerInfo = proto::read_protobuf(&mut cursor).await.unwrap(); assert_eq!(original, decoded); } @@ -589,11 +534,11 @@ mod tests { let original = make_full_peerinfo(); let mut buf = Vec::new(); - write_protobuf(&mut buf, &original).await.unwrap(); + proto::write_protobuf(&mut buf, &original).await.unwrap(); // Read it back let mut cursor = futures::io::Cursor::new(&buf[..]); - let decoded: PeerInfo = read_protobuf(&mut cursor).await.unwrap(); + let decoded: PeerInfo = proto::read_protobuf(&mut cursor).await.unwrap(); assert_eq!(original, decoded); } @@ -609,10 +554,10 @@ mod tests { for original in variants { let mut buf = Vec::new(); - write_protobuf(&mut buf, &original).await.unwrap(); + proto::write_protobuf(&mut buf, &original).await.unwrap(); let mut cursor = futures::io::Cursor::new(&buf[..]); - let decoded: PeerInfo = read_protobuf(&mut cursor).await.unwrap(); + let decoded: PeerInfo = proto::read_protobuf(&mut cursor).await.unwrap(); assert_eq!(original, decoded); } } @@ -621,13 +566,13 @@ mod tests { async fn test_read_protobuf_message_too_large() { // Create a buffer with a length prefix that exceeds MAX_MESSAGE_SIZE let mut buf = Vec::new(); - let large_len = MAX_MESSAGE_SIZE + 1; + let large_len = proto::MAX_MESSAGE_SIZE + 1; let mut len_buf = unsigned_varint::encode::usize_buffer(); let encoded_len = unsigned_varint::encode::usize(large_len, &mut len_buf); buf.extend_from_slice(encoded_len); let mut cursor = futures::io::Cursor::new(&buf[..]); - let result: io::Result = read_protobuf(&mut cursor).await; + let result: io::Result = proto::read_protobuf(&mut cursor).await; assert!(result.is_err()); let err = result.unwrap_err(); @@ -641,7 +586,7 @@ mod tests { let invalid_data = [0x05, 0xff, 
0xff, 0xff, 0xff, 0xff]; // length 5, then garbage let mut cursor = futures::io::Cursor::new(&invalid_data[..]); - let result: io::Result = read_protobuf(&mut cursor).await; + let result: io::Result = proto::read_protobuf(&mut cursor).await; assert!(result.is_err()); assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData); @@ -653,7 +598,7 @@ mod tests { let truncated = [0x10]; // claims 16 bytes but has none let mut cursor = futures::io::Cursor::new(&truncated[..]); - let result: io::Result = read_protobuf(&mut cursor).await; + let result: io::Result = proto::read_protobuf(&mut cursor).await; assert!(result.is_err()); assert_eq!(result.unwrap_err().kind(), io::ErrorKind::UnexpectedEof); @@ -667,15 +612,15 @@ mod tests { // Write multiple messages to the same buffer let mut buf = Vec::new(); - write_protobuf(&mut buf, &msg1).await.unwrap(); - write_protobuf(&mut buf, &msg2).await.unwrap(); - write_protobuf(&mut buf, &msg3).await.unwrap(); + proto::write_protobuf(&mut buf, &msg1).await.unwrap(); + proto::write_protobuf(&mut buf, &msg2).await.unwrap(); + proto::write_protobuf(&mut buf, &msg3).await.unwrap(); // Read them back in order let mut cursor = futures::io::Cursor::new(&buf[..]); - let decoded1: PeerInfo = read_protobuf(&mut cursor).await.unwrap(); - let decoded2: PeerInfo = read_protobuf(&mut cursor).await.unwrap(); - let decoded3: PeerInfo = read_protobuf(&mut cursor).await.unwrap(); + let decoded1: PeerInfo = proto::read_protobuf(&mut cursor).await.unwrap(); + let decoded2: PeerInfo = proto::read_protobuf(&mut cursor).await.unwrap(); + let decoded3: PeerInfo = proto::read_protobuf(&mut cursor).await.unwrap(); assert_eq!(msg1, decoded1); assert_eq!(msg2, decoded2); diff --git a/crates/testutil/Cargo.toml b/crates/testutil/Cargo.toml index 0a973d1f..720bcc86 100644 --- a/crates/testutil/Cargo.toml +++ b/crates/testutil/Cargo.toml @@ -9,6 +9,7 @@ publish.workspace = true [dependencies] k256.workspace = true pluto-crypto.workspace = true 
+pluto-eth2api.workspace = true rand.workspace = true rand_core.workspace = true thiserror.workspace = true diff --git a/crates/testutil/src/lib.rs b/crates/testutil/src/lib.rs index abc00e7a..686c8c7a 100644 --- a/crates/testutil/src/lib.rs +++ b/crates/testutil/src/lib.rs @@ -6,3 +6,8 @@ /// Random utilities. pub mod random; + +pub use random::{ + random_deneb_versioned_attestation, random_eth2_signature, random_eth2_signature_bytes, + random_root, random_root_bytes, random_slot, random_v_idx, +}; diff --git a/crates/testutil/src/random.rs b/crates/testutil/src/random.rs index 65be74cc..8e4a8eeb 100644 --- a/crates/testutil/src/random.rs +++ b/crates/testutil/src/random.rs @@ -7,6 +7,14 @@ use k256::{ elliptic_curve::rand_core::{CryptoRng, Error, RngCore}, }; use pluto_crypto::{blst_impl::BlstImpl, tbls::Tbls, types::PrivateKey}; +use pluto_eth2api::{ + spec::phase0, + types::{ + AltairBeaconStateCurrentJustifiedCheckpoint, Data, + GetBlockAttestationsV2ResponseResponseDataArray2, + }, + versioned::{self, AttestationPayload}, +}; use rand::{Rng, SeedableRng, rngs::StdRng}; /// A deterministic RNG that always returns the same byte value. @@ -67,6 +75,26 @@ pub fn generate_test_bls_key(seed: u64) -> PrivateKey { .expect("deterministic key generation should not fail") } +/// Generates a random BLS signature as a hex string for testing. +/// +/// Returns a 96-byte (192 hex characters) BLS signature encoded as a hex string +/// with "0x" prefix. +pub fn random_eth2_signature() -> String { + let mut bytes = [0u8; 96]; + let mut rng = rand::thread_rng(); + for byte in &mut bytes { + *byte = rng.r#gen(); + } + format!("0x{}", hex::encode(bytes)) +} + +/// Generates a random Ethereum consensus signature for testing. +pub fn random_eth2_signature_bytes() -> phase0::BLSSignature { + let mut signature = [0u8; 96]; + rand::thread_rng().fill(&mut signature[..]); + signature +} + /// Generate random Ethereum address for testing. 
pub fn random_eth_address(rand: &mut impl Rng) -> [u8; 20] { let mut bytes = [0u8; 20]; @@ -74,6 +102,134 @@ pub fn random_eth_address(rand: &mut impl Rng) -> [u8; 20] { bytes } +/// Generates a random 32-byte root as a hex string for testing. +/// +/// Returns a 32-byte (64 hex characters) root encoded as a hex string with "0x" +/// prefix. +pub fn random_root() -> String { + let mut bytes = [0u8; 32]; + let mut rng = rand::thread_rng(); + for byte in &mut bytes { + *byte = rng.r#gen(); + } + format!("0x{}", hex::encode(bytes)) +} + +/// Generates a random Ethereum consensus root for testing. +pub fn random_root_bytes() -> phase0::Root { + let mut root = [0u8; 32]; + rand::thread_rng().fill(&mut root); + root +} + +/// Generates a random slot for testing. +pub fn random_slot() -> phase0::Slot { + rand::thread_rng().r#gen() +} + +/// Generates a random validator index for testing. +pub fn random_v_idx() -> phase0::ValidatorIndex { + rand::thread_rng().r#gen() +} + +/// Generates a random bitlist as a hex string for testing. +/// +/// # Arguments +/// +/// * `length` - The number of bits to set in the bitlist +/// +/// Returns a hex-encoded bitlist string with "0x" prefix. +pub fn random_bit_list(length: usize) -> String { + // Create a byte array large enough to hold the bits + // For simplicity, use 32 bytes (256 bits) + let mut bytes = [0u8; 32]; + let mut rng = rand::thread_rng(); + + // Set 'length' random bits + for _ in 0..length { + let bit_idx = rng.r#gen::() % 256; + let byte_idx = bit_idx / 8; + let bit_offset = bit_idx % 8; + bytes[byte_idx] |= 1 << bit_offset; + } + + format!("0x{}", hex::encode(bytes)) +} + +/// Generates a random checkpoint for testing. +fn random_checkpoint() -> AltairBeaconStateCurrentJustifiedCheckpoint { + let mut rng = rand::thread_rng(); + AltairBeaconStateCurrentJustifiedCheckpoint { + epoch: rng.r#gen::().to_string(), + root: random_root(), + } +} + +/// Generates random attestation data for Phase 0. 
+fn random_attestation_data_phase0() -> Data { + let mut rng = rand::thread_rng(); + Data { + slot: rng.r#gen::().to_string(), + index: rng.r#gen::().to_string(), + beacon_block_root: random_root(), + source: random_checkpoint(), + target: random_checkpoint(), + } +} + +/// Generates a random Phase 0 attestation. +/// +/// Returns an attestation with random aggregation bits, attestation data, and +/// signature. +pub fn random_phase0_attestation() -> GetBlockAttestationsV2ResponseResponseDataArray2 { + GetBlockAttestationsV2ResponseResponseDataArray2 { + aggregation_bits: random_bit_list(1), + data: random_attestation_data_phase0(), + signature: random_eth2_signature(), + } +} + +/// Generates a random Deneb versioned attestation. +/// +/// Returns a versioned attestation containing a Phase 0 attestation with the +/// Deneb version tag. This matches the Go implementation: +/// +/// ```go +/// func RandomDenebVersionedAttestation() *eth2spec.VersionedAttestation { +/// return ð2spec.VersionedAttestation{ +/// Version: eth2spec.DataVersionDeneb, +/// Deneb: RandomPhase0Attestation(), +/// } +/// } +/// ``` +pub fn random_deneb_versioned_attestation() -> versioned::VersionedAttestation { + let mut rng = rand::thread_rng(); + + let attestation = phase0::Attestation { + aggregation_bits: phase0::BitList::default(), + data: phase0::AttestationData { + slot: rng.r#gen(), + index: rng.r#gen(), + beacon_block_root: random_root_bytes(), + source: phase0::Checkpoint { + epoch: rng.r#gen(), + root: random_root_bytes(), + }, + target: phase0::Checkpoint { + epoch: rng.r#gen(), + root: random_root_bytes(), + }, + }, + signature: random_eth2_signature_bytes(), + }; + + versioned::VersionedAttestation { + version: versioned::DataVersion::Deneb, + validator_index: Some(rng.r#gen()), + attestation: Some(AttestationPayload::Deneb(attestation)), + } +} + #[cfg(test)] mod tests { use super::*; @@ -150,4 +306,92 @@ mod tests { "Different seeds should produce different BLS keys" ); } + + 
#[test] + fn test_random_eth2_signature() { + let sig1 = random_eth2_signature(); + let sig2 = random_eth2_signature(); + + // Check format + assert!(sig1.starts_with("0x")); + // 96 bytes = 192 hex chars + "0x" prefix = 194 total + assert_eq!(sig1.len(), 194); + + // Different calls should produce different signatures + assert_ne!(sig1, sig2); + } + + #[test] + fn test_random_root() { + let root1 = random_root(); + let root2 = random_root(); + + // Check format + assert!(root1.starts_with("0x")); + // 32 bytes = 64 hex chars + "0x" prefix = 66 total + assert_eq!(root1.len(), 66); + + // Different calls should produce different roots + assert_ne!(root1, root2); + } + + #[test] + fn test_random_bit_list() { + let bitlist = random_bit_list(5); + + // Check format + assert!(bitlist.starts_with("0x")); + // 32 bytes = 64 hex chars + "0x" prefix = 66 total + assert_eq!(bitlist.len(), 66); + } + + #[test] + fn test_random_phase0_attestation() { + let att = random_phase0_attestation(); + + // Check that all fields are populated + assert!(att.aggregation_bits.starts_with("0x")); + assert!(att.signature.starts_with("0x")); + assert!(att.data.beacon_block_root.starts_with("0x")); + assert!(!att.data.slot.is_empty()); + assert!(!att.data.index.is_empty()); + } + + #[test] + fn test_random_deneb_versioned_attestation() { + let versioned_att = random_deneb_versioned_attestation(); + + // Check version is Deneb + assert!(matches!( + versioned_att.version, + versioned::DataVersion::Deneb + )); + + // Check that data is populated + match versioned_att.attestation { + Some(AttestationPayload::Deneb(att)) => { + assert_eq!(att.signature.len(), 96); + } + _ => panic!("Expected Deneb attestation"), + } + } + + #[test] + fn test_random_deneb_versioned_attestation_different() { + let att1 = random_deneb_versioned_attestation(); + let att2 = random_deneb_versioned_attestation(); + + // Different calls should produce different attestations + // Check signatures are different + let sig1 = 
match &att1.attestation { + Some(AttestationPayload::Deneb(a)) => &a.signature, + _ => panic!("Expected Deneb attestation"), + }; + let sig2 = match &att2.attestation { + Some(AttestationPayload::Deneb(a)) => &a.signature, + _ => panic!("Expected Deneb attestation"), + }; + + assert_ne!(sig1, sig2); + } }