diff --git a/Cargo.lock b/Cargo.lock index b7466ffc..a04b609f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5561,6 +5561,7 @@ name = "pluto-dkg" version = "1.7.1" dependencies = [ "anyhow", + "bon", "clap", "either", "futures", diff --git a/crates/dkg/Cargo.toml b/crates/dkg/Cargo.toml index 1d1a19b9..34b7f791 100644 --- a/crates/dkg/Cargo.toml +++ b/crates/dkg/Cargo.toml @@ -7,12 +7,14 @@ license.workspace = true publish.workspace = true [dependencies] +bon.workspace = true prost.workspace = true prost-types.workspace = true thiserror.workspace = true libp2p.workspace = true futures.workspace = true tokio.workspace = true +tokio-util.workspace = true sha2.workspace = true tracing.workspace = true either.workspace = true @@ -20,6 +22,7 @@ k256.workspace = true pluto-k1util.workspace = true pluto-p2p.workspace = true pluto-cluster.workspace = true +pluto-core.workspace = true pluto-crypto.workspace = true pluto-eth1wrap.workspace = true pluto-eth2util.workspace = true @@ -38,11 +41,9 @@ anyhow.workspace = true clap.workspace = true hex.workspace = true pluto-cluster = { workspace = true, features = ["test-cluster"] } -pluto-core.workspace = true pluto-testutil.workspace = true pluto-tracing.workspace = true serde_json.workspace = true -tokio-util.workspace = true tempfile.workspace = true [lints] diff --git a/crates/dkg/examples/sync.rs b/crates/dkg/examples/sync.rs new file mode 100644 index 00000000..906fb11b --- /dev/null +++ b/crates/dkg/examples/sync.rs @@ -0,0 +1,735 @@ +//! Relay-based example for the DKG sync protocol. +//! +//! This example follows the same high-level shape as the `bcast` example: +//! - load a local private key from a data directory +//! - load cluster peers from `cluster-lock.json` +//! - resolve relay URLs with `bootnode::new_relays` +//! - create relay reservations and relay routing +//! - run `sync` over relay-mediated connectivity +//! +//! To try it locally: +//! +//! ```text +//! # Terminal 1: start a relay server +//! 
cargo run -p pluto-relay-server --example relay_server +//! +//! # Terminals 2-4: run three node directories from the same cluster +//! cargo run -p pluto-dkg --example sync -- \ +//! --relays http://127.0.0.1:8888 \ +//! --data-dir /path/to/node0 +//! +//! cargo run -p pluto-dkg --example sync -- \ +//! --relays http://127.0.0.1:8888 \ +//! --data-dir /path/to/node1 +//! +//! cargo run -p pluto-dkg --example sync -- \ +//! --relays http://127.0.0.1:8888 \ +//! --data-dir /path/to/node2 +//! ``` +//! +//! Assumptions: +//! - the three data directories already exist +//! - each one belongs to one node in the same cluster +//! +//! Required files in each data directory: +//! - `charon-enr-private-key` +//! - `cluster-lock.json` +//! +//! Expected flow: +//! 1. Each node loads the same cluster peer order from the lock file. +//! 2. Nodes resolve the configured relays and establish relay reservations. +//! 3. The relay router dials known cluster peers through relay circuits. +//! 4. Each node starts one sync client per remote peer. +//! 5. Once all clients are connected, the demo advances through steps 1 and 2. +//! 6. The demo keeps the sync clients running in steady state. +//! 7. Press `Ctrl+C` on any node to stop that node immediately and let the +//! other nodes observe the fault. +//! +//! Success signals: +//! - `Relay reservation accepted` +//! - `Connection established` with `peer_type="CLUSTER"` +//! - `All sync clients connected` +//! - `Sync step reached` +//! - `Sync demo is now idling until Ctrl+C` +//! - `Sync steady-state heartbeat` +//! - `Ctrl+C received, shutting down` +//! +//! Transient relay warnings can occur during startup and reconnects. The demo +//! is healthy once all cluster peers are connected and the sync steps complete. 
+#![allow(missing_docs)] + +use std::{ + collections::{HashMap, HashSet}, + path::PathBuf, + str::FromStr, + time::Duration, +}; + +use anyhow::{Context as _, Result}; +use clap::Parser; +use futures::StreamExt; +use libp2p::{ + PeerId, identify, ping, + relay::{self}, + swarm::{NetworkBehaviour, SwarmEvent}, +}; +use pluto_cluster::lock::Lock; +use pluto_core::version::VERSION; +use pluto_dkg::sync::{self, Client, Server}; +use pluto_p2p::{ + behaviours::pluto::PlutoBehaviourEvent, + bootnode, + config::P2PConfig, + gater, k1, + p2p::{Node, NodeType}, + p2p_context::P2PContext, + relay::{MutableRelayReservation, RelayRouter}, +}; +use pluto_tracing::TracingConfig; +use tokio::{fs, signal, task::JoinHandle}; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error, info}; + +#[derive(NetworkBehaviour)] +struct ExampleBehaviour { + relay: relay::client::Behaviour, + relay_reservation: MutableRelayReservation, + relay_router: RelayRouter, + sync: sync::Behaviour, +} + +#[derive(Debug, Parser)] +#[command(name = "sync-example")] +#[command(about = "Run a relay-based DKG sync demo node")] +struct Args { + /// Relay URLs or relay multiaddrs to use. + #[arg(long, value_delimiter = ',')] + relays: Vec, + + /// Data directory containing `charon-enr-private-key` and + /// `cluster-lock.json`. + #[arg(long)] + data_dir: PathBuf, + + /// Additional known peers to allow and route via relays. + #[arg(long, value_delimiter = ',')] + known_peers: Vec, + + /// Whether to filter private addresses from advertisements. + #[arg(short, long, default_value_t = false)] + filter_private_addrs: bool, + + /// The external IP address of the node. + #[arg(long)] + external_ip: Option, + + /// The external host of the node. + #[arg(long)] + external_host: Option, + + /// TCP addresses to listen on. + #[arg(long)] + tcp_addrs: Vec, + + /// UDP addresses to listen on. + #[arg(long)] + udp_addrs: Vec, + + /// Whether to disable reuse port. 
+ #[arg(long, default_value_t = false)] + disable_reuse_port: bool, +} + +#[derive(Debug, Clone)] +struct ClusterInfo { + peers: Vec, + indices: HashMap, + local_peer_id: PeerId, + local_node_number: u32, +} + +impl ClusterInfo { + fn expected_connections(&self) -> usize { + self.peers.len().saturating_sub(1) + } + + fn peer_label(&self, peer_id: &PeerId) -> String { + match self.indices.get(peer_id) { + Some(index) => format!( + "node={} peer_id={peer_id}", + index.checked_add(1).unwrap_or(*index) + ), + None => format!("peer_id={peer_id}"), + } + } + + fn peer_labels_where(&self, predicate: impl Fn(&PeerId) -> bool) -> Vec { + self.peers + .iter() + .filter(|peer_id| predicate(peer_id)) + .map(|peer_id| self.peer_label(peer_id)) + .collect() + } + + fn connected_peers(&self, connected: &HashSet) -> Vec { + self.peer_labels_where(|peer_id| connected.contains(peer_id)) + } + + fn missing_peers(&self, connected: &HashSet) -> Vec { + self.peer_labels_where(|peer_id| { + *peer_id != self.local_peer_id && !connected.contains(peer_id) + }) + } +} + +fn peer_type( + peer_id: &PeerId, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) -> &'static str { + if relay_peer_ids.contains(peer_id) { + "RELAY" + } else if cluster_info.indices.contains_key(peer_id) { + "CLUSTER" + } else { + "UNKNOWN" + } +} + +fn merge_known_peers( + cluster_peers: &[PeerId], + configured_known_peers: &[String], +) -> Result> { + let capacity = cluster_peers + .len() + .checked_add(configured_known_peers.len()) + .context("known peer capacity overflow")?; + let mut known_peers = Vec::with_capacity(capacity); + known_peers.extend(cluster_peers.iter().copied()); + let mut known_peer_ids = HashSet::with_capacity(capacity); + known_peer_ids.extend(known_peers.iter().copied()); + + for peer in configured_known_peers { + let peer_id = PeerId::from_str(peer) + .with_context(|| format!("failed to parse known peer id: {peer}"))?; + if known_peer_ids.insert(peer_id) { + known_peers.push(peer_id); 
+ } + } + + Ok(known_peers) +} + +fn local_node_number(cluster_peers: &[PeerId], local_peer_id: PeerId) -> Result { + let index = cluster_peers + .iter() + .position(|peer_id| peer_id == &local_peer_id) + .context("local peer id is not present in the cluster lock")?; + let node_number = index + .checked_add(1) + .context("cluster peer index overflow")?; + u32::try_from(node_number).context("cluster peer index does not fit in u32") +} + +fn endpoint_address(endpoint: &libp2p::core::ConnectedPoint) -> &libp2p::Multiaddr { + match endpoint { + libp2p::core::ConnectedPoint::Dialer { address, .. } => address, + libp2p::core::ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr, + } +} + +fn connection_log_fields<'a>( + peer_id: PeerId, + endpoint: &'a libp2p::core::ConnectedPoint, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) -> (String, &'static str, &'a libp2p::Multiaddr) { + ( + cluster_info.peer_label(&peer_id), + peer_type(&peer_id, relay_peer_ids, cluster_info), + endpoint_address(endpoint), + ) +} + +fn log_relay_event(relay_event: relay::client::Event, cluster_info: &ClusterInfo) { + match relay_event { + relay::client::Event::ReservationReqAccepted { + relay_peer_id, + renewal, + limit, + } => { + debug!( + relay_peer_id = %relay_peer_id, + renewal, + limit = ?limit, + "Relay reservation accepted" + ); + } + relay::client::Event::OutboundCircuitEstablished { + relay_peer_id, + limit, + } => { + debug!( + relay_peer_id = %relay_peer_id, + limit = ?limit, + "Outbound relay circuit established" + ); + } + relay::client::Event::InboundCircuitEstablished { src_peer_id, limit } => { + debug!( + src_peer_id = %src_peer_id, + peer_label = %cluster_info.peer_label(&src_peer_id), + limit = ?limit, + "Inbound relay circuit established" + ); + } + } +} + +fn log_connection_established( + peer_id: PeerId, + endpoint: &libp2p::core::ConnectedPoint, + num_established: std::num::NonZero, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) { 
+ let (peer_label, peer_type, address) = + connection_log_fields(peer_id, endpoint, relay_peer_ids, cluster_info); + debug!( + peer_id = %peer_id, + peer_label = %peer_label, + peer_type, + address = %address, + num_established = num_established.get(), + "Connection established" + ); +} + +fn log_identify_event( + peer_id: PeerId, + info: identify::Info, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) { + debug!( + peer_id = %peer_id, + peer_type = peer_type(&peer_id, relay_peer_ids, cluster_info), + agent_version = %info.agent_version, + protocol_version = %info.protocol_version, + num_addresses = info.listen_addrs.len(), + "Received identify from peer" + ); +} + +fn log_ping_event( + peer: PeerId, + result: Result, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) { + match result { + Ok(rtt) => debug!( + peer_id = %peer, + peer_type = peer_type(&peer, relay_peer_ids, cluster_info), + rtt = ?rtt, + "Received ping" + ), + Err(error) => debug!( + peer_id = %peer, + peer_type = peer_type(&peer, relay_peer_ids, cluster_info), + err = %error, + "Ping failed" + ), + } +} + +fn print_cluster_overview(cluster_info: &ClusterInfo) { + info!("Cluster peer order:"); + for (index, peer_id) in cluster_info.peers.iter().enumerate() { + let local_marker = if *peer_id == cluster_info.local_peer_id { + " (local)" + } else { + "" + }; + info!( + peer_index = index.checked_add(1).unwrap_or(index), + peer_id = %peer_id, + local = %local_marker, + "Cluster peer" + ); + } +} + +async fn run_sync( + server: Server, + clients: Vec, + cluster_info: ClusterInfo, + cancellation: CancellationToken, +) -> Result<()> { + server.start(); + info!( + local_node = cluster_info.local_node_number, + expected_clients = clients.len(), + "Started sync server" + ); + + let mut client_joins = Vec::with_capacity(clients.len()); + for client in &clients { + let client = client.clone(); + let cancellation = cancellation.child_token(); + client_joins.push(tokio::spawn(async move { 
client.run(cancellation).await })); + } + + // First wait until all local per-peer sync clients report connected. + // The shared server barrier below then confirms the whole cluster has + // observed all peer connections. + let mut previous_connected = None; + loop { + let connected = clients + .iter() + .filter(|client| client.is_connected()) + .count(); + if previous_connected != Some(connected) { + info!( + local_node = cluster_info.local_node_number, + connected, + expected = clients.len(), + "Sync client connectivity update" + ); + previous_connected = Some(connected); + } + + if connected == clients.len() { + break; + } + + tokio::select! { + _ = cancellation.cancelled() => break, + _ = tokio::time::sleep(Duration::from_millis(100)) => {} + } + } + + // Once all local sync clients are connected, wait for the shared server + // to observe the full cluster barrier and then drive the demo through a + // couple of synchronized steps. + if !cancellation.is_cancelled() { + info!( + local_node = cluster_info.local_node_number, + connected = clients.len(), + "All sync clients connected" + ); + + for client in &clients { + client.disable_reconnect(); + } + + match server.await_all_connected(cancellation.child_token()).await { + Ok(()) | Err(sync::Error::Canceled) => {} + Err(error) => return Err(anyhow::anyhow!(error.to_string())), + } + + for step in 1_i64..=2 { + for client in &clients { + client.set_step(step); + } + info!( + local_node = cluster_info.local_node_number, + step, "Waiting for sync step" + ); + match server + .await_all_at_step(step, cancellation.child_token()) + .await + { + Ok(()) => { + info!( + local_node = cluster_info.local_node_number, + step, "Sync step reached" + ); + } + Err(sync::Error::Canceled) => break, + Err(error) => return Err(anyhow::anyhow!(error.to_string())), + } + + if step < 2 { + tokio::select! 
{ + _ = cancellation.cancelled() => break, + _ = tokio::time::sleep(Duration::from_secs(3)) => {} + } + } + } + } + + if !cancellation.is_cancelled() { + info!( + local_node = cluster_info.local_node_number, + "Sync demo is now idling until Ctrl+C" + ); + + let mut heartbeat = tokio::time::interval(Duration::from_secs(5)); + loop { + tokio::select! { + _ = cancellation.cancelled() => break, + _ = heartbeat.tick() => { + let connected = clients.iter().filter(|client| client.is_connected()).count(); + info!( + local_node = cluster_info.local_node_number, + connected, + expected = clients.len(), + "Sync steady-state heartbeat" + ); + } + } + } + } + + if cancellation.is_cancelled() { + info!( + local_node = cluster_info.local_node_number, + "Cancellation received, exiting without graceful shutdown" + ); + } + + for join in client_joins { + match join.await { + Ok(Ok(())) | Ok(Err(sync::Error::Canceled)) => {} + Ok(Err(error)) => return Err(anyhow::anyhow!(error.to_string())), + Err(error) => return Err(anyhow::anyhow!(error.to_string())), + } + } + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<()> { + pluto_tracing::init(&TracingConfig::default()).expect("failed to initialize tracing"); + + let args = Args::parse(); + let key = k1::load_priv_key(&args.data_dir).expect("Failed to load private key"); + let local_peer_id = pluto_p2p::peer::peer_id_from_key(key.public_key()) + .expect("Failed to derive local peer ID"); + + let lock_path = args.data_dir.join("cluster-lock.json"); + let lock_str = fs::read_to_string(&lock_path) + .await + .expect("Failed to load lock"); + let lock: Lock = serde_json::from_str(&lock_str).expect("Failed to parse lock"); + + let cluster_peers = lock.peer_ids().expect("Failed to get lock peer IDs"); + let local_node_number = local_node_number(&cluster_peers, local_peer_id) + .expect("Failed to derive local node number"); + let mut indices = HashMap::with_capacity(cluster_peers.len()); + for (index, peer_id) in 
cluster_peers.iter().copied().enumerate() { + indices.insert(peer_id, index); + } + let cluster_info = ClusterInfo { + peers: cluster_peers.clone(), + indices, + local_peer_id, + local_node_number, + }; + + let cancellation = CancellationToken::new(); + let lock_hash_hex = hex::encode(&lock.lock_hash); + let relays = bootnode::new_relays(cancellation.child_token(), &args.relays, &lock_hash_hex) + .await + .context("failed to resolve relays")?; + let relay_peer_ids = relays + .iter() + .filter_map(|relay| relay.peer().ok().flatten().map(|peer| peer.id)) + .collect::>(); + + let known_peers = merge_known_peers(&cluster_peers, &args.known_peers)?; + + let conn_gater = gater::ConnGater::new( + gater::Config::closed() + .with_relays(relays.clone()) + .with_peer_ids(known_peers.clone()), + ); + + let p2p_config = P2PConfig { + relays: vec![], + external_ip: args.external_ip, + external_host: args.external_host, + tcp_addrs: args.tcp_addrs, + udp_addrs: args.udp_addrs, + disable_reuse_port: args.disable_reuse_port, + }; + + let version = VERSION.to_minor(); + let p2p_context = P2PContext::new(known_peers.clone()); + p2p_context.set_local_peer_id(local_peer_id); + let (sync_behaviour, server, clients) = sync::new( + cluster_peers.clone(), + p2p_context.clone(), + &key, + lock.lock_hash.clone(), + version, + )?; + + let mut node: Node = Node::new( + p2p_config, + key, + NodeType::QUIC, + args.filter_private_addrs, + known_peers, + { + let p2p_context = p2p_context.clone(); + move |builder, keypair, relay_client| { + let p2p_context = p2p_context.clone(); + let local_peer_id = keypair.public().to_peer_id(); + + builder + .with_p2p_context(p2p_context.clone()) + .with_gater(conn_gater) + .with_inner(ExampleBehaviour { + relay: relay_client, + relay_reservation: MutableRelayReservation::new(relays.clone()), + relay_router: RelayRouter::new(relays.clone(), p2p_context, local_peer_id), + sync: sync_behaviour, + }) + } + }, + )?; + + info!( + local_peer_id = %local_peer_id, + 
local_node = local_node_number, + data_dir = %args.data_dir.display(), + "Started sync example" + ); + print_cluster_overview(&cluster_info); + + let mut connected_cluster_peers = + HashSet::::with_capacity(cluster_info.expected_connections()); + let mut demo_task: JoinHandle> = tokio::spawn(run_sync( + server, + clients, + cluster_info.clone(), + cancellation.child_token(), + )); + + loop { + tokio::select! { + event = node.select_next_some() => { + match event { + SwarmEvent::Behaviour(PlutoBehaviourEvent::Inner( + ExampleBehaviourEvent::Relay(relay_event), + )) => { + log_relay_event(relay_event, &cluster_info); + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Ping(ping::Event { + peer, + result, + .. + })) => { + log_ping_event(peer, result, &relay_peer_ids, &cluster_info); + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Identify( + identify::Event::Received { peer_id, info, .. }, + )) => { + log_identify_event(peer_id, info, &relay_peer_ids, &cluster_info); + } + SwarmEvent::ConnectionEstablished { + peer_id, + endpoint, + num_established, + .. + } => { + log_connection_established( + peer_id, + &endpoint, + num_established, + &relay_peer_ids, + &cluster_info, + ); + if cluster_info.indices.contains_key(&peer_id) { + connected_cluster_peers.insert(peer_id); + debug!( + connected = connected_cluster_peers.len(), + expected = cluster_info.expected_connections(), + connected_peers = ?cluster_info.connected_peers(&connected_cluster_peers), + missing_peers = ?cluster_info.missing_peers(&connected_cluster_peers), + "Cluster connectivity update" + ); + } + } + SwarmEvent::ConnectionClosed { + peer_id, + cause, + .. 
+ } => { + if cluster_info.indices.contains_key(&peer_id) + && connected_cluster_peers.remove(&peer_id) + { + error!( + local_node = cluster_info.local_node_number, + peer_id = %peer_id, + peer_label = %cluster_info.peer_label(&peer_id), + connected = connected_cluster_peers.len(), + expected = cluster_info.expected_connections(), + missing_peers = ?cluster_info.missing_peers(&connected_cluster_peers), + cause = ?cause, + "Cluster peer disconnected" + ); + } + } + SwarmEvent::OutgoingConnectionError { + peer_id, + connection_id, + error: err, + } => { + debug!( + ?peer_id, + ?connection_id, + %err, + "Outgoing connection error" + ); + } + SwarmEvent::IncomingConnectionError { + connection_id, + local_addr, + send_back_addr, + error: err, + .. + } => { + debug!( + ?connection_id, + %local_addr, + %send_back_addr, + %err, + "Incoming connection error" + ); + } + SwarmEvent::NewListenAddr { address, .. } => { + debug!(%address, "Listening on address"); + } + _ => {} + } + } + result = &mut demo_task => { + match result { + Ok(Ok(())) => { + info!("Sync demo completed successfully"); + break; + } + Ok(Err(error)) => { + error!(err = %error, "Sync demo failed"); + break; + } + Err(error) => { + error!(err = %error, "Sync demo task failed"); + break; + } + } + } + _ = signal::ctrl_c() => { + info!("Ctrl+C received, shutting down"); + cancellation.cancel(); + } + } + } + + cancellation.cancel(); + Ok(()) +} diff --git a/crates/dkg/src/dkgpb/v1.rs b/crates/dkg/src/dkgpb/v1.rs index 2c469269..ccb5fcb2 100644 --- a/crates/dkg/src/dkgpb/v1.rs +++ b/crates/dkg/src/dkgpb/v1.rs @@ -5,4 +5,5 @@ pub mod frost; /// Nodesigs protobuf definitions. pub mod nodesigs; /// Sync protobuf definitions. 
+#[path = "v1/dkg.dkgpb.v1.rs"] pub mod sync; diff --git a/crates/dkg/src/dkgpb/v1/sync.rs b/crates/dkg/src/dkgpb/v1/dkg.dkgpb.v1.rs similarity index 80% rename from crates/dkg/src/dkgpb/v1/sync.rs rename to crates/dkg/src/dkgpb/v1/dkg.dkgpb.v1.rs index 428d864a..ed928107 100644 --- a/crates/dkg/src/dkgpb/v1/sync.rs +++ b/crates/dkg/src/dkgpb/v1/dkg.dkgpb.v1.rs @@ -19,12 +19,12 @@ pub struct MsgSync { } impl ::prost::Name for MsgSync { const NAME: &'static str = "MsgSync"; - const PACKAGE: &'static str = "sync"; + const PACKAGE: &'static str = "dkg.dkgpb.v1"; fn full_name() -> ::prost::alloc::string::String { - "sync.MsgSync".into() + "dkg.dkgpb.v1.MsgSync".into() } fn type_url() -> ::prost::alloc::string::String { - "type.googleapis.com/sync.MsgSync".into() + "type.googleapis.com/dkg.dkgpb.v1.MsgSync".into() } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -36,11 +36,11 @@ pub struct MsgSyncResponse { } impl ::prost::Name for MsgSyncResponse { const NAME: &'static str = "MsgSyncResponse"; - const PACKAGE: &'static str = "sync"; + const PACKAGE: &'static str = "dkg.dkgpb.v1"; fn full_name() -> ::prost::alloc::string::String { - "sync.MsgSyncResponse".into() + "dkg.dkgpb.v1.MsgSyncResponse".into() } fn type_url() -> ::prost::alloc::string::String { - "type.googleapis.com/sync.MsgSyncResponse".into() + "type.googleapis.com/dkg.dkgpb.v1.MsgSyncResponse".into() } } \ No newline at end of file diff --git a/crates/dkg/src/dkgpb/v1/sync.proto b/crates/dkg/src/dkgpb/v1/sync.proto index 3849c4db..8cf517c8 100644 --- a/crates/dkg/src/dkgpb/v1/sync.proto +++ b/crates/dkg/src/dkgpb/v1/sync.proto @@ -1,9 +1,11 @@ syntax = "proto3"; -package sync; +package dkg.dkgpb.v1; import "google/protobuf/timestamp.proto"; +option go_package = "github.com/obolnetwork/charon/dkg/dkgpb/v1"; + message MsgSync { google.protobuf.Timestamp timestamp = 1; bytes hash_signature = 2; diff --git a/crates/dkg/src/lib.rs b/crates/dkg/src/lib.rs index ce0e6cde..0226167b 100644 --- 
a/crates/dkg/src/lib.rs +++ b/crates/dkg/src/lib.rs @@ -19,3 +19,6 @@ pub mod dkg; /// Shares distributed to each node in the cluster. pub mod share; + +/// Step synchronization protocol for DKG peers. +pub mod sync; diff --git a/crates/dkg/src/sync/behaviour.rs b/crates/dkg/src/sync/behaviour.rs new file mode 100644 index 00000000..0df6b2c3 --- /dev/null +++ b/crates/dkg/src/sync/behaviour.rs @@ -0,0 +1,291 @@ +use std::{ + collections::{HashMap, VecDeque}, + task::{Context, Poll}, +}; + +use either::Either; +use libp2p::{ + Multiaddr, PeerId, + swarm::{ + ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, + THandlerOutEvent, ToSwarm, + dial_opts::{DialOpts, PeerCondition}, + dummy, + }, +}; +use pluto_p2p::p2p_context::P2PContext; +use tokio::sync::mpsc; + +use super::{Command, client::Client, handler::Handler, server::Server}; + +/// Event emitted by the sync behaviour. +#[derive(Debug, Clone)] +pub enum Event { + /// A peer advanced to a new sync step. + PeerStepUpdated { + /// The peer whose step was updated. + peer_id: PeerId, + /// The updated step. + step: i64, + }, + /// A peer requested graceful shutdown through sync. + PeerShutdownObserved { + /// The peer that requested shutdown. + peer_id: PeerId, + }, + /// A peer sent a sync message that failed. + SyncRejected { + /// The peer whose sync message was rejected. + peer_id: PeerId, + /// The validation error. + error: super::error::Error, + }, +} + +/// Swarm behaviour backing the DKG sync protocol. +pub struct Behaviour { + server: Server, + clients: HashMap, + p2p_context: P2PContext, + command_rx: mpsc::UnboundedReceiver, + pending_events: VecDeque>>, +} + +impl Behaviour { + /// Creates a new sync behaviour from a server and client handles. 
+ pub(crate) fn new( + server: Server, + clients: impl IntoIterator, + p2p_context: P2PContext, + command_rx: mpsc::UnboundedReceiver, + ) -> Self { + Self { + server, + clients: clients + .into_iter() + .map(|client| (client.peer_id(), client)) + .collect(), + p2p_context, + command_rx, + pending_events: VecDeque::new(), + } + } + + fn connection_handler_for_peer(&self, peer: PeerId) -> THandler { + match self.clients.get(&peer) { + Some(client) => Either::Left(Handler::new( + peer, + self.server.clone(), + Some(client.clone()), + )), + None => Either::Right(dummy::ConnectionHandler), + } + } + + fn handle_command(&mut self, command: Command) { + match command { + Command::Activate(peer_id) => { + let Some(client) = self.clients.get(&peer_id) else { + return; + }; + + if !client.should_run() || client.is_connected() || client.shutdown_requested() { + return; + } + + if !self + .p2p_context + .peer_store_lock() + .connections_to_peer(&peer_id) + .is_empty() + { + return; + } + + self.pending_events.push_back(ToSwarm::Dial { + opts: DialOpts::peer_id(peer_id) + .condition(PeerCondition::DisconnectedAndNotDialing) + .build(), + }); + } + } + } +} + +impl NetworkBehaviour for Behaviour { + type ConnectionHandler = Either; + type ToSwarm = Event; + + fn handle_established_inbound_connection( + &mut self, + _connection_id: ConnectionId, + peer: PeerId, + _local_addr: &Multiaddr, + _remote_addr: &Multiaddr, + ) -> std::result::Result, ConnectionDenied> { + Ok(self.connection_handler_for_peer(peer)) + } + + fn handle_established_outbound_connection( + &mut self, + _connection_id: ConnectionId, + peer: PeerId, + _addr: &Multiaddr, + _role_override: libp2p::core::Endpoint, + _port_use: libp2p::core::transport::PortUse, + ) -> std::result::Result, ConnectionDenied> { + Ok(self.connection_handler_for_peer(peer)) + } + + fn on_swarm_event(&mut self, event: FromSwarm) { + match event { + FromSwarm::ConnectionClosed(event) => { + if event.remaining_established > 0 { + return; + 
} + + if let Some(client) = self.clients.get(&event.peer_id) { + client.set_connected(false); + client.release_outbound(); + } + } + FromSwarm::DialFailure(event) => { + let _ = event; + } + _ => {} + } + } + + fn on_connection_handler_event( + &mut self, + _peer_id: PeerId, + _connection_id: ConnectionId, + event: THandlerOutEvent, + ) { + match event { + Either::Left(event) => self.pending_events.push_back(ToSwarm::GenerateEvent(event)), + Either::Right(unreachable) => match unreachable {}, + } + } + + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + while let Poll::Ready(Some(command)) = self.command_rx.poll_recv(cx) { + self.handle_command(command); + } + + if let Some(event) = self.pending_events.pop_front() { + return Poll::Ready(event); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use std::task::Context; + + use futures::task::noop_waker_ref; + use libp2p::{ + core::{ConnectedPoint, Endpoint, transport::PortUse}, + swarm::{ + ConnectionClosed, ConnectionId, FromSwarm, NetworkBehaviour, ToSwarm, + dial_opts::DialOpts, + }, + }; + use pluto_core::version::SemVer; + use tokio::sync::mpsc; + + use super::*; + use crate::sync::ClientConfig; + + fn test_behaviour(client: Client) -> Behaviour { + let (_unused_tx, command_rx) = mpsc::unbounded_channel(); + let version = SemVer::parse("v1.7").expect("valid version"); + let p2p_context = P2PContext::new([client.peer_id()]); + Behaviour::new( + Server::new(1, vec![1, 2, 3], version), + [client], + p2p_context, + command_rx, + ) + } + + #[test] + fn active_client_requests_dial() { + let (command_tx, command_rx) = mpsc::unbounded_channel(); + let version = SemVer::parse("v1.7").expect("valid version"); + let peer_id = PeerId::random(); + let client = Client::new( + peer_id, + vec![1, 2, 3], + version.clone(), + Default::default(), + Some(command_tx), + ); + let server = Server::new(1, vec![1, 2, 3], version); + let p2p_context = P2PContext::new([peer_id]); + let mut behaviour = 
Behaviour::new(server, [client.clone()], p2p_context, command_rx); + + client.activate().expect("activate should succeed"); + + let waker = noop_waker_ref(); + let mut cx = Context::from_waker(waker); + let poll = NetworkBehaviour::poll(&mut behaviour, &mut cx); + + let Poll::Ready(ToSwarm::Dial { opts }) = poll else { + panic!("expected dial event"); + }; + + assert_eq!(DialOpts::get_peer_id(&opts), Some(peer_id)); + } + + #[test] + fn connection_closed_keeps_client_state_until_last_connection() { + let version = SemVer::parse("v1.7").expect("valid version"); + let peer_id = PeerId::random(); + let client = Client::new( + peer_id, + vec![1, 2, 3], + version, + ClientConfig::default(), + None, + ); + client.set_connected(true); + assert!(client.try_claim_outbound()); + + let mut behaviour = test_behaviour(client.clone()); + + let address = "/ip4/127.0.0.1/tcp/9000".parse().expect("valid multiaddr"); + let endpoint = ConnectedPoint::Dialer { + address, + role_override: Endpoint::Dialer, + port_use: PortUse::New, + }; + + behaviour.on_swarm_event(FromSwarm::ConnectionClosed(ConnectionClosed { + peer_id, + connection_id: ConnectionId::new_unchecked(1), + endpoint: &endpoint, + cause: None, + remaining_established: 1, + })); + + assert!(client.is_connected()); + assert!(!client.try_claim_outbound()); + + behaviour.on_swarm_event(FromSwarm::ConnectionClosed(ConnectionClosed { + peer_id, + connection_id: ConnectionId::new_unchecked(2), + endpoint: &endpoint, + cause: None, + remaining_established: 0, + })); + + assert!(!client.is_connected()); + assert!(client.try_claim_outbound()); + } +} diff --git a/crates/dkg/src/sync/client.rs b/crates/dkg/src/sync/client.rs new file mode 100644 index 00000000..48d50be6 --- /dev/null +++ b/crates/dkg/src/sync/client.rs @@ -0,0 +1,243 @@ +use std::{ + sync::{ + Arc, + atomic::{AtomicBool, AtomicI64, Ordering}, + }, + time::Duration, +}; + +use bon::Builder; +use libp2p::PeerId; +use pluto_core::version::SemVer; +use 
tokio::sync::{mpsc, watch}; +use tokio_util::sync::CancellationToken; + +use super::{ + Command, + error::{Error, Result}, +}; + +/// Default period between sync messages. +pub const DEFAULT_PERIOD: Duration = Duration::from_millis(100); + +/// Configuration for a sync client. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Builder)] +pub struct ClientConfig { + /// Period between sync messages. + #[builder(default = DEFAULT_PERIOD)] + pub period: Duration, +} + +impl Default for ClientConfig { + fn default() -> Self { + Self::builder().build() + } +} + +#[derive(Debug)] +struct ClientInner { + peer_id: PeerId, + hash_sig: Vec, + version: SemVer, + period: Duration, + active: AtomicBool, + connected: AtomicBool, + reconnect: AtomicBool, + step: AtomicI64, + shutdown_requested: AtomicBool, + finished: AtomicBool, + outbound_claimed: AtomicBool, + done_tx: watch::Sender>>, + command_tx: Option>, +} + +/// User-facing handle for one outbound sync client. +#[derive(Debug, Clone)] +pub struct Client { + inner: Arc, +} + +impl Client { + /// Creates a new client with an explicit config. + pub(crate) fn new( + peer_id: PeerId, + hash_sig: Vec, + version: SemVer, + config: ClientConfig, + command_tx: Option>, + ) -> Self { + let (done_tx, _done_rx) = watch::channel(None); + Self { + inner: Arc::new(ClientInner { + peer_id, + hash_sig, + version, + period: config.period, + active: AtomicBool::new(false), + connected: AtomicBool::new(false), + reconnect: AtomicBool::new(true), + step: AtomicI64::new(0), + shutdown_requested: AtomicBool::new(false), + finished: AtomicBool::new(false), + outbound_claimed: AtomicBool::new(false), + done_tx, + command_tx, + }), + } + } + + /// Runs the client until shutdown, fatal error, or cancellation. + pub async fn run(&self, cancellation: CancellationToken) -> Result<()> { + self.activate()?; + self.wait_finished(cancellation, true).await + } + + /// Sets the current client step. 
    pub fn set_step(&self, step: i64) {
        self.inner.step.store(step, Ordering::SeqCst);
    }

    /// Returns whether the client currently has an active sync stream.
    pub fn is_connected(&self) -> bool {
        self.inner.connected.load(Ordering::SeqCst)
    }

    /// Requests a graceful shutdown and waits for the client to finish.
    ///
    /// Sets the shutdown flag read by the outbound stream task, so the next
    /// outbound sync message carries `shutdown = true` before the client
    /// stops; then blocks on the completion channel.
    pub async fn shutdown(&self, cancellation: CancellationToken) -> Result<()> {
        self.inner.shutdown_requested.store(true, Ordering::SeqCst);
        self.wait_finished(cancellation, false).await
    }

    /// Disables reconnecting for non-relay disconnects.
    pub fn disable_reconnect(&self) {
        self.inner.reconnect.store(false, Ordering::SeqCst);
    }

    /// Remote peer this client syncs with.
    pub(crate) fn peer_id(&self) -> PeerId {
        self.inner.peer_id
    }

    /// Local signature over the definition hash, sent with every sync message.
    pub(crate) fn hash_sig(&self) -> &[u8] {
        &self.inner.hash_sig
    }

    /// Local software version advertised to the remote peer.
    pub(crate) fn version(&self) -> &SemVer {
        &self.inner.version
    }

    /// Period between outbound sync messages.
    pub(crate) fn period(&self) -> Duration {
        self.inner.period
    }

    /// Whether the client has been activated and has not yet finished.
    pub(crate) fn should_run(&self) -> bool {
        self.inner.active.load(Ordering::SeqCst)
    }

    /// Whether the client should re-dial after a non-fatal disconnect.
    pub(crate) fn should_reconnect(&self) -> bool {
        self.inner.reconnect.load(Ordering::SeqCst)
    }

    /// Whether a graceful shutdown was requested via [`Client::shutdown`].
    pub(crate) fn shutdown_requested(&self) -> bool {
        self.inner.shutdown_requested.load(Ordering::SeqCst)
    }

    /// Current step value reported in outbound sync messages.
    pub(crate) fn step(&self) -> i64 {
        self.inner.step.load(Ordering::SeqCst)
    }

    /// Records whether an active sync stream currently exists.
    pub(crate) fn set_connected(&self, connected: bool) {
        self.inner.connected.store(connected, Ordering::SeqCst);
    }

    /// Atomically claims the single outbound-stream slot.
    ///
    /// Returns `true` only for the first caller since the last release,
    /// preventing concurrent outbound substreams for the same client.
    pub(crate) fn try_claim_outbound(&self) -> bool {
        !self.inner.outbound_claimed.swap(true, Ordering::SeqCst)
    }

    /// Releases the outbound-stream slot so a later dial may claim it.
    pub(crate) fn release_outbound(&self) {
        self.inner.outbound_claimed.store(false, Ordering::SeqCst);
    }

    /// Marks the client as finished and publishes the final result.
    ///
    /// Idempotent: the `finished` swap ensures only the first caller
    /// delivers a result on the completion (watch) channel.
    pub(crate) fn finish(&self, result: Result<()>) {
        self.inner.active.store(false, Ordering::SeqCst);
        self.inner.connected.store(false, Ordering::SeqCst);
        self.release_outbound();

        if !self.inner.finished.swap(true, Ordering::SeqCst) {
            let _ = self.inner.done_tx.send(Some(result));
        }
} + + pub(crate) fn activate(&self) -> Result<()> { + self.inner.active.store(true, Ordering::SeqCst); + + if let Some(command_tx) = &self.inner.command_tx + && command_tx + .send(Command::Activate(self.inner.peer_id)) + .is_ok() + { + return Ok(()); + } + + self.inner.active.store(false, Ordering::SeqCst); + Err(Error::ActivationChannelUnavailable) + } + + async fn wait_finished( + &self, + cancellation: CancellationToken, + clear_on_cancel: bool, + ) -> Result<()> { + let mut done_rx = self.inner.done_tx.subscribe(); + + loop { + if let Some(result) = done_rx.borrow().clone() { + return result; + } + + tokio::select! { + _ = cancellation.cancelled() => { + if clear_on_cancel { + self.inner.active.store(false, Ordering::SeqCst); + self.inner.connected.store(false, Ordering::SeqCst); + } + return Err(Error::Canceled); + } + changed = done_rx.changed() => { + if changed.is_err() { + return Err(Error::CompletionChannelClosed); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use libp2p::PeerId; + use pluto_core::version::SemVer; + + use super::*; + + #[tokio::test] + async fn run_fails_immediately_if_activation_channel_is_closed() { + let (command_tx, command_rx) = mpsc::unbounded_channel(); + drop(command_rx); + + let client = Client::new( + PeerId::random(), + vec![1, 2, 3], + SemVer::parse("v1.7").expect("version"), + ClientConfig::default(), + Some(command_tx), + ); + + let error = client + .run(CancellationToken::new()) + .await + .expect_err("closed activation channel should fail immediately"); + + assert!(matches!(error, Error::ActivationChannelUnavailable)); + assert!(!client.should_run()); + } +} diff --git a/crates/dkg/src/sync/error.rs b/crates/dkg/src/sync/error.rs new file mode 100644 index 00000000..7fffd255 --- /dev/null +++ b/crates/dkg/src/sync/error.rs @@ -0,0 +1,110 @@ +use libp2p::PeerId; + +/// Sync result type. +pub type Result = std::result::Result; + +/// Error type for the DKG sync protocol. 
+#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)] +pub enum Error { + /// The sync client was canceled. + #[error("sync client canceled")] + Canceled, + + /// The peer returned an application-level error. + #[error("peer responded with error: {0}")] + PeerRespondedWithError(String), + + /// The remote peer version did not match. + #[error("mismatching version; expect={expected}, got={got}")] + VersionMismatch { + /// The expected version string. + expected: String, + /// The received version string. + got: String, + }, + + /// The definition hash signature was invalid. + #[error("invalid definition hash signature")] + InvalidDefinitionHashSignature, + + /// The peer reported a step lower than the previous known step. + #[error("peer reported step is behind the last known step")] + PeerStepBehind, + + /// The peer reported a step too far ahead of the previous known step. + #[error("peer reported step is ahead the last known step")] + PeerStepAhead, + + /// The peer reported an invalid first step. + #[error("peer reported abnormal initial step, expected 0 or 1")] + AbnormalInitialStep, + + /// A peer was too far ahead for the awaited step. + #[error("peer step is too far ahead")] + PeerStepTooFarAhead, + + /// A checked step arithmetic operation overflowed. + #[error("step overflow")] + StepOverflow, + + /// The stream protocol could not be negotiated. + #[error("protocol negotiation failed")] + Unsupported, + + /// Failed to parse the peer version. + #[error("parse peer version: {0}")] + ParsePeerVersion(String), + + /// Failed to sign the definition hash. + #[error("sign definition hash: {0}")] + SignDefinitionHash(String), + + /// Failed to convert the local key to a libp2p keypair. + #[error("convert secret key to libp2p keypair: {0}")] + KeyConversion(String), + + /// Failed to decode a protobuf message. + #[error("protobuf decode failed: {0}")] + Decode(String), + + /// Failed to encode a protobuf message. 
+ #[error("protobuf encode failed: {0}")] + Encode(String), + + /// An I/O error occurred while reading or writing the stream. + #[error("i/o error: {0}")] + Io(String), + + /// A peer ID could not be converted to a public key. + #[error("peer error: {0}")] + Peer(String), + + /// A sync server operation was attempted before the server was started. + #[error("sync server not started")] + ServerNotStarted, + + /// The sync client completion channel was closed unexpectedly. + #[error("sync client completion channel closed")] + CompletionChannelClosed, + + /// The sync client activation channel was unavailable. + #[error("sync client activation channel unavailable")] + ActivationChannelUnavailable, + + /// An inbound sync message failed validation. + #[error("invalid sync message: peer={peer} err={error}")] + InvalidSyncMessage { + /// The peer whose message was invalid. + peer: PeerId, + /// The validation error. + error: String, + }, + + /// The local peer ID was missing from the shared P2P context. + #[error("local peer id missing from p2p context")] + LocalPeerMissing, + + /// The configured peer set did not include the local peer ID. + #[error("local peer id missing from sync peer set")] + LocalPeerNotInPeerSet, +} diff --git a/crates/dkg/src/sync/handler.rs b/crates/dkg/src/sync/handler.rs new file mode 100644 index 00000000..a50f6c70 --- /dev/null +++ b/crates/dkg/src/sync/handler.rs @@ -0,0 +1,429 @@ +//! Connection handler for the DKG sync protocol. 
+ +use std::{ + convert::Infallible, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use futures::{FutureExt, future::BoxFuture}; +use libp2p::{ + PeerId, Stream, + core::upgrade::ReadyUpgrade, + swarm::{ + ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, StreamUpgradeError, + SubstreamProtocol, + handler::{ + ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, + }, + }, +}; +use prost_types::Timestamp; +use tokio::{sync::mpsc, time::Sleep}; +use tracing::{debug, info, warn}; + +use crate::dkgpb::v1::sync::{MsgSync, MsgSyncResponse}; + +use super::{ + Event, + client::Client, + error::{Error, Result}, + protocol, + server::Server, +}; + +const INITIAL_BACKOFF: Duration = Duration::from_millis(100); +const MAX_BACKOFF: Duration = Duration::from_secs(1); + +type InboundFuture = BoxFuture<'static, Result<()>>; + +/// Protocol-level events emitted by the sync handler. +pub type OutEvent = Event; + +enum OutboundState { + Idle, + OpenStream, + Running(BoxFuture<'static, OutboundExit>), + WaitingRetry(Pin>), + Disabled, +} + +enum OutboundExit { + GracefulShutdown, + Reconnectable { error: Error, relay_reset: bool }, + Fatal(Error), +} + +/// Sync connection handler. +pub struct Handler { + peer_id: PeerId, + server: Server, + client: Option, + inbound: Option, + inbound_events_tx: mpsc::UnboundedSender, + inbound_events_rx: mpsc::UnboundedReceiver, + outbound: OutboundState, + backoff: Duration, +} + +impl Handler { + /// Creates a new handler for a single connection. 
    pub fn new(peer_id: PeerId, server: Server, client: Option<Client>) -> Self {
        let (inbound_events_tx, inbound_events_rx) = mpsc::unbounded_channel();
        Self {
            peer_id,
            server,
            client,
            inbound: None,
            inbound_events_tx,
            inbound_events_rx,
            outbound: OutboundState::Idle,
            backoff: INITIAL_BACKOFF,
        }
    }

    /// Substream protocol used for both inbound and outbound sync streams.
    fn substream_protocol(&self) -> SubstreamProtocol<ReadyUpgrade<StreamProtocol>, ()> {
        SubstreamProtocol::new(ReadyUpgrade::new(protocol::PROTOCOL_NAME), ())
    }

    /// Schedules an outbound retry after the current backoff delay, then
    /// doubles the backoff (saturating) up to `MAX_BACKOFF`.
    fn schedule_retry(&mut self) {
        let sleep = Box::pin(tokio::time::sleep(self.backoff));
        self.outbound = OutboundState::WaitingRetry(sleep);
        self.backoff = self.backoff.saturating_mul(2).min(MAX_BACKOFF);
    }

    /// Requests a new outbound substream if a client exists, the client is
    /// active, and the single outbound slot can be claimed; returns the
    /// substream-request event to hand to the swarm, or `None`.
    fn try_request_outbound(
        &mut self,
    ) -> Option<ConnectionHandlerEvent<ReadyUpgrade<StreamProtocol>, (), OutEvent>> {
        let client = self.client.as_ref()?;
        if !client.should_run() || !client.try_claim_outbound() {
            return None;
        }

        self.outbound = OutboundState::OpenStream;
        Some(ConnectionHandlerEvent::OutboundSubstreamRequest {
            protocol: self.substream_protocol(),
        })
    }

    /// Handles a failed outbound upgrade: either schedules a backoff retry
    /// or finishes the client with the error, depending on whether the
    /// failure looks like a relay reset and on the reconnect policy.
    fn on_dial_upgrade_error(
        &mut self,
        DialUpgradeError { error, ..
}: DialUpgradeError< + (), + ::OutboundProtocol, + >, + ) { + let Some(client) = self.client.as_ref() else { + self.outbound = OutboundState::Disabled; + return; + }; + + client.release_outbound(); + + let (error, relay_reset) = match error { + StreamUpgradeError::NegotiationFailed => (Error::Unsupported, false), + StreamUpgradeError::Timeout => ( + Error::Io( + std::io::Error::new( + std::io::ErrorKind::TimedOut, + "sync protocol negotiation timed out", + ) + .to_string(), + ), + false, + ), + StreamUpgradeError::Apply(never) => match never {}, + StreamUpgradeError::Io(error) => ( + Error::Io(error.to_string()), + error.kind() == std::io::ErrorKind::ConnectionReset, + ), + }; + + if relay_reset || client.should_reconnect() { + self.schedule_retry(); + } else { + client.finish(Err(error)); + self.outbound = OutboundState::Disabled; + } + } +} + +impl ConnectionHandler for Handler { + type FromBehaviour = Infallible; + type InboundOpenInfo = (); + type InboundProtocol = ReadyUpgrade; + type OutboundOpenInfo = (); + type OutboundProtocol = ReadyUpgrade; + type ToBehaviour = OutEvent; + + fn listen_protocol(&self) -> SubstreamProtocol { + self.substream_protocol() + } + + fn on_behaviour_event(&mut self, never: Self::FromBehaviour) { + match never {} + } + + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + ConnectionHandlerEvent, + > { + if let Poll::Ready(Some(event)) = self.inbound_events_rx.poll_recv(cx) { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } + + if let Some(inbound) = self.inbound.as_mut() { + match inbound.poll_unpin(cx) { + Poll::Pending => {} + Poll::Ready(Ok(())) => { + self.inbound = None; + } + Poll::Ready(Err(error)) => { + warn!(peer = %self.peer_id, err = %error, "Error serving inbound sync stream"); + self.inbound = None; + } + } + } + + match &mut self.outbound { + OutboundState::Idle => { + if let Some(event) = self.try_request_outbound() { + return Poll::Ready(event); + } + } + OutboundState::OpenStream 
=> {} + OutboundState::WaitingRetry(delay) => { + if delay.as_mut().poll(cx).is_ready() { + if let Some(event) = self.try_request_outbound() { + return Poll::Ready(event); + } + + self.outbound = OutboundState::Idle; + } + } + OutboundState::Running(fut) => match fut.poll_unpin(cx) { + Poll::Pending => {} + Poll::Ready(OutboundExit::GracefulShutdown) => { + if let Some(client) = self.client.as_ref() { + client.finish(Ok(())); + } + self.outbound = OutboundState::Disabled; + } + Poll::Ready(OutboundExit::Reconnectable { error, relay_reset }) => { + let Some(client) = self.client.as_ref() else { + self.outbound = OutboundState::Disabled; + return Poll::Pending; + }; + + client.set_connected(false); + client.release_outbound(); + + if relay_reset || client.should_reconnect() { + info!(peer = %self.peer_id, err = %error, "Disconnected from peer"); + self.outbound = OutboundState::Idle; + } else { + client.finish(Err(error)); + self.outbound = OutboundState::Disabled; + } + } + Poll::Ready(OutboundExit::Fatal(error)) => { + if let Some(client) = self.client.as_ref() { + client.finish(Err(error)); + } + self.outbound = OutboundState::Disabled; + } + }, + OutboundState::Disabled => {} + } + + Poll::Pending + } + + fn on_connection_event( + &mut self, + event: ConnectionEvent, + ) { + match event { + ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound { + protocol: mut stream, + .. + }) => { + stream.ignore_for_keep_alive(); + self.inbound = Some( + handle_inbound_stream( + self.peer_id, + self.server.clone(), + self.inbound_events_tx.clone(), + stream, + ) + .boxed(), + ); + } + ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { + protocol: mut stream, + .. 
+ }) => { + let Some(client) = self.client.clone() else { + self.outbound = OutboundState::Disabled; + return; + }; + + stream.ignore_for_keep_alive(); + self.backoff = INITIAL_BACKOFF; + self.outbound = OutboundState::Running(run_outbound_stream(client, stream).boxed()); + } + ConnectionEvent::DialUpgradeError(error) => self.on_dial_upgrade_error(error), + ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} + ConnectionEvent::ListenUpgradeError(_) => {} + _ => {} + } + } +} + +async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit { + let mut interval = tokio::time::interval(client.period()); + let hash_signature = prost::bytes::Bytes::copy_from_slice(client.hash_sig()); + let version = client.version().to_string(); + + client.set_connected(true); + + loop { + interval.tick().await; + + let shutdown = client.shutdown_requested(); + let timestamp = Timestamp::from(std::time::SystemTime::now()); + let request = MsgSync { + timestamp: Some(timestamp), + hash_signature: hash_signature.clone(), + shutdown, + version: version.clone(), + step: client.step(), + }; + + let response: std::io::Result = + match pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &request).await { + Ok(()) => pluto_p2p::proto::read_fixed_size_protobuf(&mut stream).await, + Err(error) => Err(error), + }; + + let response = match response { + Ok(response) => response, + Err(error) => { + return OutboundExit::Reconnectable { + relay_reset: error.kind() == std::io::ErrorKind::ConnectionReset, + error: Error::Io(error.to_string()), + }; + } + }; + + if !response.error.is_empty() { + return OutboundExit::Fatal(Error::PeerRespondedWithError(response.error)); + } + + if let Some(sync_timestamp) = response.sync_timestamp { + debug!( + peer = %client.peer_id(), + sync_timestamp = ?sync_timestamp, + "Received sync response" + ); + } + + if shutdown { + return OutboundExit::GracefulShutdown; + } + 
} +} + +async fn handle_inbound_stream( + peer_id: PeerId, + server: Server, + inbound_events_tx: mpsc::UnboundedSender, + mut stream: Stream, +) -> Result<()> { + let result = async { + if !server.is_started() { + return Err(Error::ServerNotStarted); + } + + let public_key = pluto_p2p::peer::peer_id_to_libp2p_pk(&peer_id) + .map_err(|error| Error::Peer(error.to_string()))?; + + loop { + let message: MsgSync = pluto_p2p::proto::read_fixed_size_protobuf(&mut stream) + .await + .map_err(|error| Error::Io(error.to_string()))?; + let mut response = MsgSyncResponse { + sync_timestamp: message.timestamp, + error: String::new(), + }; + + if let Err(error) = protocol::validate_request_with_public_key( + server.def_hash(), + server.version(), + &public_key, + &message, + ) { + let error_string = error.to_string(); + send_inbound_event( + &inbound_events_tx, + OutEvent::SyncRejected { peer_id, error }, + ); + server + .set_err(Error::InvalidSyncMessage { + peer: peer_id, + error: error_string.clone(), + }) + .await; + response.error = error_string; + } else { + let (inserted, count) = server.mark_connected(peer_id).await; + if inserted { + info!( + peer = %peer_id, + connected = count, + expected = server.expected_peer_count(), + "Connected to peer" + ); + } + } + + if server.update_step(peer_id, message.step).await? 
{ + send_inbound_event( + &inbound_events_tx, + OutEvent::PeerStepUpdated { + peer_id, + step: message.step, + }, + ); + } + + pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &response) + .await + .map_err(|error| Error::Io(error.to_string()))?; + + if message.shutdown { + send_inbound_event( + &inbound_events_tx, + OutEvent::PeerShutdownObserved { peer_id }, + ); + server.set_shutdown(peer_id).await; + return Ok(()); + } + } + } + .await; + + server.clear_connected(peer_id).await; + result +} + +fn send_inbound_event(inbound_events_tx: &mpsc::UnboundedSender, event: OutEvent) { + if let Err(error) = inbound_events_tx.send(event) { + tracing::error!(err = %error, "Failed to deliver inbound sync event"); + } +} diff --git a/crates/dkg/src/sync/mod.rs b/crates/dkg/src/sync/mod.rs new file mode 100644 index 00000000..2b98285e --- /dev/null +++ b/crates/dkg/src/sync/mod.rs @@ -0,0 +1,338 @@ +mod behaviour; +mod client; +mod error; +mod handler; +mod protocol; +mod server; + +use libp2p::PeerId; +use pluto_core::version::SemVer; +use pluto_p2p::p2p_context::P2PContext; +use tokio::sync::mpsc; + +pub use behaviour::{Behaviour, Event}; +pub use client::{Client, ClientConfig, DEFAULT_PERIOD}; +pub use error::{Error, Result}; +pub use server::Server; + +#[derive(Debug, Clone, Copy)] +pub(crate) enum Command { + Activate(PeerId), +} + +/// Creates a sync behaviour plus server/client handles for the given peer set. 
+pub fn new( + peers: Vec, + p2p_context: P2PContext, + secret: &k256::SecretKey, + def_hash: Vec, + version: SemVer, +) -> Result<(Behaviour, Server, Vec)> { + let local_peer_id = p2p_context.local_peer_id().ok_or(Error::LocalPeerMissing)?; + if !peers.contains(&local_peer_id) { + return Err(Error::LocalPeerNotInPeerSet); + } + + let hash_sig = protocol::sign_definition_hash(secret, &def_hash)?; + let remote_peers = peers + .into_iter() + .filter(|peer_id| *peer_id != local_peer_id) + .collect::>(); + let server = Server::new(remote_peers.len(), def_hash, version.clone()); + let (command_tx, command_rx) = mpsc::unbounded_channel(); + let clients = remote_peers + .into_iter() + .map(|peer_id| { + Client::new( + peer_id, + hash_sig.clone(), + version.clone(), + ClientConfig::default(), + Some(command_tx.clone()), + ) + }) + .collect::>(); + let behaviour = Behaviour::new(server.clone(), clients.clone(), p2p_context, command_rx); + Ok((behaviour, server, clients)) +} + +#[cfg(test)] +mod tests { + use std::{net::TcpListener, time::Duration}; + + use futures::StreamExt; + use libp2p::PeerId; + use pluto_core::version::SemVer; + use pluto_p2p::{ + config::P2PConfig, + p2p::{Node, NodeType}, + p2p_context::P2PContext, + peer::peer_id_from_key, + }; + use pluto_testutil::random::generate_insecure_k1_key; + use tokio::{sync::oneshot, time::timeout}; + use tokio_util::sync::CancellationToken; + + use super::*; + + struct LocalNode { + server: Server, + clients: Vec, + node: Node, + addr: libp2p::Multiaddr, + } + + struct RunningNode { + server: Server, + clients: Vec, + stop_tx: oneshot::Sender<()>, + join: tokio::task::JoinHandle>, + client_joins: Vec>>, + } + + fn available_tcp_port() -> anyhow::Result { + let listener = TcpListener::bind("127.0.0.1:0")?; + Ok(listener.local_addr()?.port()) + } + + async fn spawn_nodes(nodes: Vec) -> anyhow::Result> { + let mut nodes = nodes; + for node in &mut nodes { + node.node.listen_on(node.addr.clone())?; + } + + let dial_targets = 
(0..nodes.len()) + .map(|index| { + nodes + .iter() + .enumerate() + .filter(|(other, _)| *other > index) + .map(|(_, node)| node.addr.clone()) + .collect::>() + }) + .collect::>(); + + let mut running = Vec::with_capacity(nodes.len()); + for (local, targets) in nodes.into_iter().zip(dial_targets) { + let mut node = local.node; + let server = local.server.clone(); + server.start(); + let cancellation = CancellationToken::new(); + let client_joins = local + .clients + .iter() + .map(|client| { + let cancellation = cancellation.child_token(); + let client = client.clone(); + tokio::spawn(async move { client.run(cancellation).await }) + }) + .collect::>(); + let (stop_tx, mut stop_rx) = oneshot::channel(); + + let join = tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(200)).await; + for target in targets { + node.dial(target)?; + } + + loop { + tokio::select! { + _ = &mut stop_rx => break, + _event = node.select_next_some() => {} + } + } + + Ok(()) + }); + + running.push(RunningNode { + server: local.server, + clients: local.clients, + stop_tx, + join, + client_joins, + }); + } + + Ok(running) + } + + async fn shutdown_nodes(nodes: Vec) -> anyhow::Result<()> { + for node in nodes { + let _ = node.stop_tx.send(()); + node.join.await??; + for join in node.client_joins { + let _ = join.await; + } + } + + Ok(()) + } + + #[tokio::test] + async fn update_step_rules_match_go() { + let version = SemVer::parse("v0.1").expect("valid version"); + let server = Server::new(1, vec![0; 32], version); + let peer = PeerId::random(); + + let error = server + .update_step(peer, 100) + .await + .expect_err("wrong initial step should fail"); + assert!(matches!(error, Error::AbnormalInitialStep)); + + let peer = PeerId::random(); + server + .update_step(peer, 1) + .await + .expect("first valid step should pass"); + server + .update_step(peer, 1) + .await + .expect("same step should pass"); + server + .update_step(peer, 2) + .await + .expect("next step should pass"); + 
+ let peer = PeerId::random(); + server + .update_step(peer, 1) + .await + .expect("first step should pass"); + let error = server + .update_step(peer, 0) + .await + .expect_err("behind should fail"); + assert!(matches!(error, Error::PeerStepBehind)); + + let peer = PeerId::random(); + server + .update_step(peer, 1) + .await + .expect("first step should pass"); + let error = server + .update_step(peer, 4) + .await + .expect_err("ahead should fail"); + assert!(matches!(error, Error::PeerStepAhead)); + } + + #[test] + fn new_requires_local_peer_in_peer_set() { + let key = generate_insecure_k1_key(0); + let local_peer_id = peer_id_from_key(key.public_key()).expect("peer id"); + let remote_peer = PeerId::random(); + let p2p_context = P2PContext::new([local_peer_id, remote_peer]); + p2p_context.set_local_peer_id(local_peer_id); + + let result = new( + vec![remote_peer], + p2p_context, + &key, + vec![1, 2, 3], + SemVer::parse("v1.7").expect("version"), + ); + + assert!( + matches!(result, Err(Error::LocalPeerNotInPeerSet)), + "local peer must be part of the sync peer set" + ); + } + + #[tokio::test] + async fn sync_round_trip_matches_go_shape() -> anyhow::Result<()> { + let ports = (0..3) + .map(|_| available_tcp_port()) + .collect::>>()?; + let keys = (0_u8..3).map(generate_insecure_k1_key).collect::>(); + let peer_ids = keys + .iter() + .map(|key| peer_id_from_key(key.public_key())) + .collect::, _>>()?; + let version = SemVer::parse("v1.7")?; + let mut nodes = Vec::new(); + for (index, key) in keys.into_iter().enumerate() { + let peer_id = peer_ids[index]; + let p2p_context = P2PContext::new(peer_ids.clone()); + p2p_context.set_local_peer_id(peer_id); + let (behaviour, server, clients) = new( + peer_ids.clone(), + p2p_context.clone(), + &key, + vec![1, 2, 3], + version.clone(), + )?; + let node = Node::new_server( + P2PConfig::default(), + key, + NodeType::TCP, + false, + peer_ids.clone(), + |builder, _keypair| 
builder.with_p2p_context(p2p_context).with_inner(behaviour), + )?; + let addr = format!("/ip4/127.0.0.1/tcp/{}", ports[index]).parse()?; + nodes.push(LocalNode { + server, + clients, + node, + addr, + }); + } + + let running = spawn_nodes(nodes).await?; + let cancellation = CancellationToken::new(); + + for node in &running { + timeout( + Duration::from_secs(10), + node.server.await_all_connected(cancellation.child_token()), + ) + .await??; + } + + for step in 0_i64..5 { + for node in &running { + timeout( + Duration::from_secs(10), + node.server + .await_all_at_step(step, cancellation.child_token()), + ) + .await??; + + let future = node + .server + .await_all_at_step(step + 1, cancellation.child_token()); + let error = timeout(Duration::from_millis(10), future).await; + assert!(error.is_err(), "next step should not complete immediately"); + } + + for node in &running { + for client in &node.clients { + client.set_step(step + 1); + } + } + } + + for node in &running { + assert!(node.clients.iter().all(Client::is_connected)); + } + + for node in &running { + for client in &node.clients { + client.shutdown(cancellation.child_token()).await?; + } + } + + for node in &running { + timeout( + Duration::from_secs(10), + node.server.await_all_shutdown(cancellation.child_token()), + ) + .await??; + } + + shutdown_nodes(running).await?; + Ok(()) + } +} diff --git a/crates/dkg/src/sync/protocol.rs b/crates/dkg/src/sync/protocol.rs new file mode 100644 index 00000000..477e25a5 --- /dev/null +++ b/crates/dkg/src/sync/protocol.rs @@ -0,0 +1,49 @@ +//! Protocol helpers for the DKG sync protocol. + +use libp2p::identity::{Keypair, PublicKey}; +use pluto_core::version::SemVer; + +use crate::dkgpb::v1::sync::MsgSync; + +use super::error::{Error, Result}; + +/// The protocol identifier for DKG sync. 
+pub const PROTOCOL_NAME: libp2p::StreamProtocol = StreamProtocol::new("/charon/dkg/sync/1.0.0/"); + +use libp2p::StreamProtocol; + +/// Signs the definition hash using the same libp2p signing path as Go. +pub fn sign_definition_hash(secret: &k256::SecretKey, def_hash: &[u8]) -> Result> { + let mut der = secret + .to_sec1_der() + .map_err(|error| Error::KeyConversion(error.to_string()))?; + let keypair = Keypair::secp256k1_from_der(&mut der) + .map_err(|error| Error::KeyConversion(error.to_string()))?; + keypair + .sign(def_hash) + .map_err(|error| Error::SignDefinitionHash(error.to_string())) +} + +/// Validates a sync request for a known peer public key. +pub fn validate_request_with_public_key( + def_hash: &[u8], + expected_version: &SemVer, + public_key: &PublicKey, + msg: &MsgSync, +) -> Result<()> { + let msg_version = + SemVer::parse(&msg.version).map_err(|error| Error::ParsePeerVersion(error.to_string()))?; + + if msg_version != *expected_version { + return Err(Error::VersionMismatch { + expected: expected_version.to_string(), + got: msg.version.clone(), + }); + } + + if !public_key.verify(def_hash, &msg.hash_signature) { + return Err(Error::InvalidDefinitionHashSignature); + } + + Ok(()) +} diff --git a/crates/dkg/src/sync/server.rs b/crates/dkg/src/sync/server.rs new file mode 100644 index 00000000..141d5e9a --- /dev/null +++ b/crates/dkg/src/sync/server.rs @@ -0,0 +1,243 @@ +//! Server handle for the DKG sync protocol. 
+ +use std::{ + collections::{HashMap, HashSet}, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, +}; + +use libp2p::PeerId; +use pluto_core::version::SemVer; +use tokio::sync::{Notify, RwLock}; +use tokio_util::sync::CancellationToken; + +use super::error::{Error, Result}; + +#[derive(Debug, Default)] +struct ServerState { + shutdown: HashSet, + connected: HashSet, + steps: HashMap, + err: Option, +} + +#[derive(Debug)] +struct ServerInner { + all_count: usize, + def_hash: Vec, + version: SemVer, + started: AtomicBool, + notify: Notify, + state: RwLock, +} + +/// User-facing handle for the sync server state. +#[derive(Debug, Clone)] +pub struct Server { + inner: Arc, +} + +impl Server { + /// Creates a new server handle. + pub fn new(all_count: usize, def_hash: Vec, version: SemVer) -> Self { + Self { + inner: Arc::new(ServerInner { + all_count, + def_hash, + version, + started: AtomicBool::new(false), + notify: Notify::new(), + state: RwLock::new(ServerState::default()), + }), + } + } + + /// Starts the server side of the protocol. + pub fn start(&self) { + self.inner.started.store(true, Ordering::SeqCst); + self.inner.notify.notify_waiters(); + } + + /// Returns the current shared server error, if any. + pub async fn err(&self) -> Option { + self.inner.state.read().await.err.clone() + } + + /// Waits until all peers have connected or an error occurs. + pub async fn await_all_connected(&self, cancellation: CancellationToken) -> Result<()> { + loop { + let notified = self.inner.notify.notified(); + tokio::pin!(notified); + + if !self.inner.started.load(Ordering::SeqCst) { + tokio::select! { + _ = cancellation.cancelled() => return Err(Error::Canceled), + _ = &mut notified => {} + } + continue; + } + + { + let state = self.inner.state.read().await; + if let Some(error) = &state.err { + return Err(error.clone()); + } + if state.connected.len() == self.inner.all_count { + return Ok(()); + } + } + + tokio::select! 
{ + _ = cancellation.cancelled() => return Err(Error::Canceled), + _ = &mut notified => {} + } + } + } + + /// Waits until all peers have reported shutdown. + pub async fn await_all_shutdown(&self, cancellation: CancellationToken) -> Result<()> { + loop { + let notified = self.inner.notify.notified(); + tokio::pin!(notified); + + { + let state = self.inner.state.read().await; + if state.shutdown.len() == self.inner.all_count { + return Ok(()); + } + } + + tokio::select! { + _ = cancellation.cancelled() => return Err(Error::Canceled), + _ = &mut notified => {} + } + } + } + + /// Waits until all peers have reached the given step or the next one. + pub async fn await_all_at_step( + &self, + step: i64, + cancellation: CancellationToken, + ) -> Result<()> { + let step_plus_one = step.checked_add(1).ok_or(Error::StepOverflow)?; + let step_plus_two = step.checked_add(2).ok_or(Error::StepOverflow)?; + + loop { + let notified = self.inner.notify.notified(); + tokio::pin!(notified); + + { + let state = self.inner.state.read().await; + if let Some(error) = &state.err { + return Err(error.clone()); + } + + if state.steps.len() == self.inner.all_count { + let mut all_ok = true; + for actual in state.steps.values() { + if *actual >= step_plus_two { + return Err(Error::PeerStepTooFarAhead); + } + + if *actual != step && *actual != step_plus_one { + all_ok = false; + } + } + + if all_ok { + return Ok(()); + } + } + } + + tokio::select! 
{ + _ = cancellation.cancelled() => return Err(Error::Canceled), + _ = &mut notified => {} + } + } + } + + pub(crate) fn def_hash(&self) -> &[u8] { + &self.inner.def_hash + } + + pub(crate) fn version(&self) -> &SemVer { + &self.inner.version + } + + pub(crate) fn expected_peer_count(&self) -> usize { + self.inner.all_count + } + + pub(crate) fn is_started(&self) -> bool { + self.inner.started.load(Ordering::SeqCst) + } + + pub(crate) async fn mark_connected(&self, peer_id: PeerId) -> (bool, usize) { + let mut state = self.inner.state.write().await; + let inserted = state.connected.insert(peer_id); + let count = state.connected.len(); + if inserted { + self.inner.notify.notify_waiters(); + } + (inserted, count) + } + + pub(crate) async fn clear_connected(&self, peer_id: PeerId) { + self.mutate_state(|state| { + state.connected.remove(&peer_id); + }) + .await; + } + + pub(crate) async fn set_shutdown(&self, peer_id: PeerId) { + self.mutate_state(|state| { + state.shutdown.insert(peer_id); + }) + .await; + } + + pub(crate) async fn set_err(&self, error: Error) { + self.mutate_state(|state| { + state.err = Some(error); + }) + .await; + } + + pub(crate) async fn update_step(&self, peer_id: PeerId, step: i64) -> Result { + let mut state = self.inner.state.write().await; + match state.steps.get(&peer_id).copied() { + Some(current) => { + if step < current { + return Err(Error::PeerStepBehind); + } + + let current_plus_two = current.checked_add(2).ok_or(Error::StepOverflow)?; + if step > current_plus_two { + return Err(Error::PeerStepAhead); + } + + if step == current { + return Ok(false); + } + } + None if !(0..=1).contains(&step) => { + return Err(Error::AbnormalInitialStep); + } + None => {} + } + + state.steps.insert(peer_id, step); + self.inner.notify.notify_waiters(); + Ok(true) + } + + async fn mutate_state(&self, mutate: impl FnOnce(&mut ServerState)) { + let mut state = self.inner.state.write().await; + mutate(&mut state); + self.inner.notify.notify_waiters(); + 
} +} diff --git a/crates/p2p/src/proto.rs b/crates/p2p/src/proto.rs index 664d6eb3..9171598c 100644 --- a/crates/p2p/src/proto.rs +++ b/crates/p2p/src/proto.rs @@ -48,6 +48,60 @@ pub async fn read_length_delimited( Ok(buf) } +/// Writes a fixed-size length-delimited payload to the stream. +/// +/// Format: `[i64 little-endian length][payload bytes]` +pub async fn write_fixed_size_delimited( + stream: &mut S, + payload: &[u8], +) -> io::Result<()> { + let len = i64::try_from(payload.len()) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "payload length overflow"))?; + + stream.write_all(&len.to_le_bytes()).await?; + stream.write_all(payload).await?; + stream.flush().await +} + +/// Reads a fixed-size length-delimited payload from the stream. +pub async fn read_fixed_size_delimited( + stream: &mut S, +) -> io::Result> { + read_fixed_size_delimited_with_max(stream, MAX_MESSAGE_SIZE).await +} + +/// Reads a fixed-size length-delimited payload from the stream, rejecting +/// oversized messages. +pub async fn read_fixed_size_delimited_with_max( + stream: &mut S, + max_message_size: usize, +) -> io::Result> { + let mut len_buf = [0_u8; 8]; + stream.read_exact(&mut len_buf).await?; + + let len = i64::from_le_bytes(len_buf); + if len < 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid data length", + )); + } + + let len = usize::try_from(len) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "payload length overflow"))?; + if len > max_message_size { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("message too large: {len} bytes (max: {max_message_size})"), + )); + } + + let mut payload = vec![0_u8; len]; + stream.read_exact(&mut payload).await?; + + Ok(payload) +} + /// Encodes a protobuf message and writes it with length-delimited framing. 
pub async fn write_protobuf( stream: &mut S, @@ -60,6 +114,18 @@ pub async fn write_protobuf( write_length_delimited(stream, &buf).await } +/// Encodes a protobuf message and writes it with fixed-size framing. +pub async fn write_fixed_size_protobuf( + stream: &mut S, + msg: &M, +) -> io::Result<()> { + let mut buf = Vec::with_capacity(msg.encoded_len()); + msg.encode(&mut buf) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + + write_fixed_size_delimited(stream, &buf).await +} + /// Reads a protobuf message using the default maximum message size. pub async fn read_protobuf( stream: &mut S, @@ -75,3 +141,59 @@ pub async fn read_protobuf_with_max_size( + stream: &mut S, +) -> io::Result { + let buf = read_fixed_size_delimited(stream).await?; + M::decode(&buf[..]).map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)) +} + +#[cfg(test)] +mod tests { + use futures::io::Cursor; + + use super::*; + + #[tokio::test] + async fn fixed_size_round_trip() { + let payload = vec![1, 2, 3, 4]; + let mut cursor = Cursor::new(Vec::new()); + + write_fixed_size_delimited(&mut cursor, &payload) + .await + .expect("write should succeed"); + cursor.set_position(0); + + let decoded = read_fixed_size_delimited(&mut cursor) + .await + .expect("read should succeed"); + + assert_eq!(decoded, payload); + } + + #[tokio::test] + async fn negative_fixed_size_length_fails() { + let mut cursor = Cursor::new((-1_i64).to_le_bytes().to_vec()); + + let error = read_fixed_size_delimited(&mut cursor) + .await + .expect_err("negative sizes must fail"); + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn oversized_fixed_size_length_fails() { + let too_large = i64::try_from(MAX_MESSAGE_SIZE) + .expect("max size should fit in i64") + .checked_add(1) + .expect("test length should not overflow"); + let mut cursor = Cursor::new(too_large.to_le_bytes().to_vec()); + + let error = read_fixed_size_delimited(&mut cursor) + .await + 
.expect_err("oversized payloads must fail"); + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + } +}