From 0e51cb5debce0b810bc6ebc7b5dabd119bdd4e5c Mon Sep 17 00:00:00 2001 From: Quang Le Date: Tue, 31 Mar 2026 15:14:41 +0700 Subject: [PATCH 01/14] fix: correct sync proto --- crates/dkg/src/dkgpb/v1.rs | 1 + crates/dkg/src/dkgpb/v1/{sync.rs => dkg.dkgpb.v1.rs} | 12 ++++++------ crates/dkg/src/dkgpb/v1/sync.proto | 4 +++- 3 files changed, 10 insertions(+), 7 deletions(-) rename crates/dkg/src/dkgpb/v1/{sync.rs => dkg.dkgpb.v1.rs} (80%) diff --git a/crates/dkg/src/dkgpb/v1.rs b/crates/dkg/src/dkgpb/v1.rs index 2c469269..ccb5fcb2 100644 --- a/crates/dkg/src/dkgpb/v1.rs +++ b/crates/dkg/src/dkgpb/v1.rs @@ -5,4 +5,5 @@ pub mod frost; /// Nodesigs protobuf definitions. pub mod nodesigs; /// Sync protobuf definitions. +#[path = "v1/dkg.dkgpb.v1.rs"] pub mod sync; diff --git a/crates/dkg/src/dkgpb/v1/sync.rs b/crates/dkg/src/dkgpb/v1/dkg.dkgpb.v1.rs similarity index 80% rename from crates/dkg/src/dkgpb/v1/sync.rs rename to crates/dkg/src/dkgpb/v1/dkg.dkgpb.v1.rs index 428d864a..ed928107 100644 --- a/crates/dkg/src/dkgpb/v1/sync.rs +++ b/crates/dkg/src/dkgpb/v1/dkg.dkgpb.v1.rs @@ -19,12 +19,12 @@ pub struct MsgSync { } impl ::prost::Name for MsgSync { const NAME: &'static str = "MsgSync"; - const PACKAGE: &'static str = "sync"; + const PACKAGE: &'static str = "dkg.dkgpb.v1"; fn full_name() -> ::prost::alloc::string::String { - "sync.MsgSync".into() + "dkg.dkgpb.v1.MsgSync".into() } fn type_url() -> ::prost::alloc::string::String { - "type.googleapis.com/sync.MsgSync".into() + "type.googleapis.com/dkg.dkgpb.v1.MsgSync".into() } } #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] @@ -36,11 +36,11 @@ pub struct MsgSyncResponse { } impl ::prost::Name for MsgSyncResponse { const NAME: &'static str = "MsgSyncResponse"; - const PACKAGE: &'static str = "sync"; + const PACKAGE: &'static str = "dkg.dkgpb.v1"; fn full_name() -> ::prost::alloc::string::String { - "sync.MsgSyncResponse".into() + "dkg.dkgpb.v1.MsgSyncResponse".into() } fn type_url() -> ::prost::alloc::string::String { - "type.googleapis.com/sync.MsgSyncResponse".into() + "type.googleapis.com/dkg.dkgpb.v1.MsgSyncResponse".into() } } \ No newline at end of file diff --git a/crates/dkg/src/dkgpb/v1/sync.proto b/crates/dkg/src/dkgpb/v1/sync.proto index 3849c4db..8cf517c8 100644 --- a/crates/dkg/src/dkgpb/v1/sync.proto +++ b/crates/dkg/src/dkgpb/v1/sync.proto @@ -1,9 +1,11 @@ syntax = "proto3"; -package sync; +package dkg.dkgpb.v1; import "google/protobuf/timestamp.proto"; +option go_package = "github.com/obolnetwork/charon/dkg/dkgpb/v1"; + message MsgSync { google.protobuf.Timestamp timestamp = 1; bytes hash_signature = 2; From 0526b92916c6b549720993802c2e6328f1a7457a Mon Sep 17 00:00:00 2001 From: Quang Le Date: Tue, 31 Mar 2026 15:15:14 +0700 Subject: [PATCH 02/14] feat: add read/write fixed size proto --- crates/p2p/src/proto.rs | 81 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/crates/p2p/src/proto.rs b/crates/p2p/src/proto.rs index 664d6eb3..d49a3fdb 100644 --- a/crates/p2p/src/proto.rs +++ b/crates/p2p/src/proto.rs @@ -7,6 +7,11 @@ use unsigned_varint::aio::read_usize; /// Default maximum protobuf message size pub const MAX_MESSAGE_SIZE: usize = 128 << 20; +/// Error returned when a fixed-size frame uses a negative length prefix. +#[derive(Debug, thiserror::Error)] +#[error("invalid fixed-size frame length: {0}")] +pub struct InvalidFixedSizeLength(pub i64); + /// Writes a length-delimited payload to the stream. /// /// Format: `[unsigned varint length][payload bytes]` @@ -48,6 +53,43 @@ pub async fn read_length_delimited( Ok(buf) } +/// Writes a fixed-size length-delimited payload to the stream. +/// +/// Format: `[i64 little-endian length][payload bytes]` +pub async fn write_fixed_size_delimited( + stream: &mut S, + payload: &[u8], +) -> io::Result<()> { + let len = i64::try_from(payload.len()) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "payload length overflow"))?; + + stream.write_all(&len.to_le_bytes()).await?; + stream.write_all(payload).await +} + +/// Reads a fixed-size length-delimited payload from the stream. +pub async fn read_fixed_size_delimited( + stream: &mut S, +) -> io::Result> { + let mut len_buf = [0_u8; 8]; + stream.read_exact(&mut len_buf).await?; + + let len = i64::from_le_bytes(len_buf); + if len < 0 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + InvalidFixedSizeLength(len), + )); + } + + let len = usize::try_from(len) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "payload length overflow"))?; + let mut payload = vec![0_u8; len]; + stream.read_exact(&mut payload).await?; + + Ok(payload) +} + /// Encodes a protobuf message and writes it with length-delimited framing. pub async fn write_protobuf( stream: &mut S, @@ -75,3 +117,42 @@ pub async fn read_protobuf_with_max_size()) + .map(|error| error.0); + + assert_eq!(size, Some(-1)); + } +} From 9693e9ae850114199a270e33c2ef31cef916253e Mon Sep 17 00:00:00 2001 From: Quang Le Date: Wed, 1 Apr 2026 14:51:43 +0700 Subject: [PATCH 03/14] feat: initial implementation of sync --- crates/dkg/src/lib.rs | 3 + crates/dkg/src/sync/behaviour.rs | 269 +++++++++++++++++++++ crates/dkg/src/sync/client.rs | 229 ++++++++++++++++++ crates/dkg/src/sync/error.rs | 138 +++++++++++ crates/dkg/src/sync/handler.rs | 388 +++++++++++++++++++++++++++++++ crates/dkg/src/sync/mod.rs | 316 +++++++++++++++++++++++++ crates/dkg/src/sync/protocol.rs | 155 ++++++++++++ crates/dkg/src/sync/server.rs | 257 ++++++++++++++++++++ 8 files changed, 1755 insertions(+) create mode 100644 crates/dkg/src/sync/behaviour.rs create mode 100644 crates/dkg/src/sync/client.rs create mode 100644 crates/dkg/src/sync/error.rs create mode 100644 crates/dkg/src/sync/handler.rs create mode 100644 crates/dkg/src/sync/mod.rs create mode 100644 crates/dkg/src/sync/protocol.rs create mode 100644 crates/dkg/src/sync/server.rs diff --git a/crates/dkg/src/lib.rs b/crates/dkg/src/lib.rs index ce0e6cde..0226167b 100644 --- a/crates/dkg/src/lib.rs +++ b/crates/dkg/src/lib.rs @@ -19,3 +19,6 @@ pub mod dkg; /// Shares distributed to each node in the cluster. pub mod share; + +/// Step synchronization protocol for DKG peers. +pub mod sync; diff --git a/crates/dkg/src/sync/behaviour.rs b/crates/dkg/src/sync/behaviour.rs new file mode 100644 index 00000000..ac0e6e2d --- /dev/null +++ b/crates/dkg/src/sync/behaviour.rs @@ -0,0 +1,269 @@ +use std::{ + collections::{HashMap, HashSet, VecDeque}, + task::{Context, Poll}, +}; + +use libp2p::{ + Multiaddr, PeerId, + swarm::{ + ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, + THandlerOutEvent, ToSwarm, + dial_opts::{DialOpts, PeerCondition}, + }, +}; +use pluto_p2p::p2p_context::P2PContext; +use tokio::sync::mpsc; + +use super::{Command, client::Client, handler::Handler, server::Server}; + +/// Event emitted by the sync behaviour. +#[derive(Debug, Clone)] +pub enum Event {} + +/// Swarm behaviour backing the DKG sync protocol. +pub struct Behaviour { + server: Server, + clients: HashMap, + p2p_context: P2PContext, + command_rx: mpsc::UnboundedReceiver, + pending_dials: HashSet, + pending_events: VecDeque>>, +} + +impl Behaviour { + /// Creates a new sync behaviour from a server and client handles. + pub(crate) fn new( + server: Server, + clients: impl IntoIterator, + p2p_context: P2PContext, + command_rx: mpsc::UnboundedReceiver, + ) -> Self { + Self { + server, + clients: clients + .into_iter() + .map(|client| (client.peer_id(), client)) + .collect(), + p2p_context, + command_rx, + pending_dials: HashSet::new(), + pending_events: VecDeque::new(), + } + } + + fn new_handler(&self, peer: PeerId) -> Handler { + Handler::new(peer, self.server.clone(), self.clients.get(&peer).cloned()) + } + + fn is_connected(&self, peer_id: &PeerId) -> bool { + !self + .p2p_context + .peer_store_lock() + .connections_to_peer(peer_id) + .is_empty() + } + + fn queue_dial(&mut self, peer_id: PeerId) { + if self.is_connected(&peer_id) || !self.pending_dials.insert(peer_id) { + return; + } + + self.pending_events.push_back(ToSwarm::Dial { + opts: DialOpts::peer_id(peer_id) + .condition(PeerCondition::DisconnectedAndNotDialing) + .build(), + }); + } + + fn handle_command(&mut self, command: Command) { + match command { + Command::Activate(peer_id) => { + let Some(client) = self.clients.get(&peer_id) else { + return; + }; + + if client.should_run() && !client.is_connected() { + self.queue_dial(peer_id); + } + } + } + } +} + +impl NetworkBehaviour for Behaviour { + type ConnectionHandler = Handler; + type ToSwarm = Event; + + fn handle_established_inbound_connection( + &mut self, + _connection_id: ConnectionId, + peer: PeerId, + _local_addr: &Multiaddr, + _remote_addr: &Multiaddr, + ) -> std::result::Result, ConnectionDenied> { + Ok(self.new_handler(peer)) + } + + fn handle_established_outbound_connection( + &mut self, + _connection_id: ConnectionId, + peer: PeerId, + _addr: &Multiaddr, + _role_override: libp2p::core::Endpoint, + _port_use: libp2p::core::transport::PortUse, + ) -> std::result::Result, ConnectionDenied> { + Ok(self.new_handler(peer)) + } + + fn on_swarm_event(&mut self, event: FromSwarm) { + match event { + FromSwarm::ConnectionEstablished(event) => { + self.pending_dials.remove(&event.peer_id); + } + FromSwarm::ConnectionClosed(event) => { + if event.remaining_established > 0 { + return; + } + + // TODO: Go retries sync client connections until reconnect is disabled. + // Re-queue active clients here (and on DialFailure below) so peers that + // restart before initial cluster sync can be dialed again. + self.pending_dials.remove(&event.peer_id); + + if let Some(client) = self.clients.get(&event.peer_id) { + client.set_connected(false); + client.release_outbound(); + } + } + FromSwarm::DialFailure(event) => { + if let Some(peer_id) = event.peer_id { + self.pending_dials.remove(&peer_id); + } + } + _ => {} + } + } + + fn on_connection_handler_event( + &mut self, + _peer_id: PeerId, + _connection_id: ConnectionId, + event: THandlerOutEvent, + ) { + match event {} + } + + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + while let Poll::Ready(Some(command)) = self.command_rx.poll_recv(cx) { + self.handle_command(command); + } + + if let Some(event) = self.pending_events.pop_front() { + return Poll::Ready(event); + } + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use std::task::Context; + + use futures::task::noop_waker_ref; + use libp2p::{ + core::{ConnectedPoint, Endpoint, transport::PortUse}, + swarm::{ + ConnectionClosed, ConnectionId, FromSwarm, NetworkBehaviour, ToSwarm, + dial_opts::DialOpts, + }, + }; + use pluto_core::version::SemVer; + use tokio::sync::mpsc; + + use super::*; + + fn test_behaviour(client: Client) -> Behaviour { + let (_unused_tx, command_rx) = mpsc::unbounded_channel(); + let version = SemVer::parse("v1.7").expect("valid version"); + let p2p_context = P2PContext::new([client.peer_id()]); + Behaviour::new( + Server::new(1, vec![1, 2, 3], version), + [client], + p2p_context, + command_rx, + ) + } + + #[test] + fn active_client_requests_dial() { + let (command_tx, command_rx) = mpsc::unbounded_channel(); + let version = SemVer::parse("v1.7").expect("valid version"); + let peer_id = PeerId::random(); + let client = Client::new_wired( + peer_id, + vec![1, 2, 3], + version.clone(), + Default::default(), + command_tx, + ); + let server = Server::new(1, vec![1, 2, 3], version); + let p2p_context = P2PContext::new([peer_id]); + let mut behaviour = Behaviour::new(server, [client.clone()], p2p_context, command_rx); + + client.activate(); + + let waker = noop_waker_ref(); + let mut cx = Context::from_waker(waker); + let poll = NetworkBehaviour::poll(&mut behaviour, &mut cx); + + let Poll::Ready(ToSwarm::Dial { opts }) = poll else { + panic!("expected dial event"); + }; + + assert_eq!(DialOpts::get_peer_id(&opts), Some(peer_id)); + } + + #[test] + fn connection_closed_keeps_client_state_until_last_connection() { + let version = SemVer::parse("v1.7").expect("valid version"); + let peer_id = PeerId::random(); + let client = Client::new(peer_id, vec![1, 2, 3], version); + client.set_connected(true); + assert!(client.try_claim_outbound()); + + let mut behaviour = test_behaviour(client.clone()); + + let address = "/ip4/127.0.0.1/tcp/9000".parse().expect("valid multiaddr"); + let endpoint = ConnectedPoint::Dialer { + address, + role_override: Endpoint::Dialer, + port_use: PortUse::New, + }; + + behaviour.on_swarm_event(FromSwarm::ConnectionClosed(ConnectionClosed { + peer_id, + connection_id: ConnectionId::new_unchecked(1), + endpoint: &endpoint, + cause: None, + remaining_established: 1, + })); + + assert!(client.is_connected()); + assert!(!client.try_claim_outbound()); + + behaviour.on_swarm_event(FromSwarm::ConnectionClosed(ConnectionClosed { + peer_id, + connection_id: ConnectionId::new_unchecked(2), + endpoint: &endpoint, + cause: None, + remaining_established: 0, + })); + + assert!(!client.is_connected()); + assert!(client.try_claim_outbound()); + } +} diff --git a/crates/dkg/src/sync/client.rs b/crates/dkg/src/sync/client.rs new file mode 100644 index 00000000..6c3264e8 --- /dev/null +++ b/crates/dkg/src/sync/client.rs @@ -0,0 +1,229 @@ +use std::{ + sync::{ + Arc, + atomic::{AtomicBool, AtomicI64, Ordering}, + }, + time::Duration, +}; + +use libp2p::PeerId; +use pluto_core::version::SemVer; +use tokio::sync::{mpsc, watch}; +use tokio_util::sync::CancellationToken; + +use super::Command; +use super::error::{Error, Result}; + +/// Default period between sync messages. +pub const DEFAULT_PERIOD: Duration = Duration::from_millis(100); + +/// Configuration for a sync client. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ClientConfig { + /// Period between sync messages. + pub period: Duration, +} + +impl Default for ClientConfig { + fn default() -> Self { + Self { + period: DEFAULT_PERIOD, + } + } +} + +#[derive(Debug)] +struct ClientInner { + peer_id: PeerId, + hash_sig: Vec, + version: SemVer, + period: Duration, + active: AtomicBool, + connected: AtomicBool, + reconnect: AtomicBool, + shutdown_requested: AtomicBool, + finished: AtomicBool, + outbound_claimed: AtomicBool, + step: AtomicI64, + done_tx: watch::Sender>>, + command_tx: Option>, +} + +/// User-facing handle for one outbound sync client. +#[derive(Debug, Clone)] +pub struct Client { + inner: Arc, +} + +impl Client { + /// Creates a new client with the default sync period. + pub fn new(peer_id: PeerId, hash_sig: Vec, version: SemVer) -> Self { + Self::new_with_config(peer_id, hash_sig, version, ClientConfig::default()) + } + + /// Creates a new client with an explicit config. + pub fn new_with_config( + peer_id: PeerId, + hash_sig: Vec, + version: SemVer, + config: ClientConfig, + ) -> Self { + Self::new_with_command(peer_id, hash_sig, version, config, None) + } + + pub(crate) fn new_wired( + peer_id: PeerId, + hash_sig: Vec, + version: SemVer, + config: ClientConfig, + command_tx: mpsc::UnboundedSender, + ) -> Self { + Self::new_with_command(peer_id, hash_sig, version, config, Some(command_tx)) + } + + fn new_with_command( + peer_id: PeerId, + hash_sig: Vec, + version: SemVer, + config: ClientConfig, + command_tx: Option>, + ) -> Self { + let (done_tx, _done_rx) = watch::channel(None); + Self { + inner: Arc::new(ClientInner { + peer_id, + hash_sig, + version, + period: config.period, + active: AtomicBool::new(false), + connected: AtomicBool::new(false), + reconnect: AtomicBool::new(true), + shutdown_requested: AtomicBool::new(false), + finished: AtomicBool::new(false), + outbound_claimed: AtomicBool::new(false), + step: AtomicI64::new(0), + done_tx, + command_tx, + }), + } + } + + /// Returns the target peer for this client. + pub fn peer_id(&self) -> PeerId { + self.inner.peer_id + } + + /// Runs the client until shutdown, fatal error, or cancellation. + pub async fn run(&self, cancellation: CancellationToken) -> Result<()> { + self.activate(); + self.wait_finished(cancellation, true).await + } + + /// Sets the current client step. + pub fn set_step(&self, step: i64) { + self.inner.step.store(step, Ordering::SeqCst); + } + + /// Returns whether the client currently has an active sync stream. + pub fn is_connected(&self) -> bool { + self.inner.connected.load(Ordering::SeqCst) + } + + /// Requests a graceful shutdown and waits for the client to finish. + pub async fn shutdown(&self, cancellation: CancellationToken) -> Result<()> { + self.inner.shutdown_requested.store(true, Ordering::SeqCst); + self.wait_finished(cancellation, false).await + } + + /// Disables reconnecting for non-relay disconnects. + pub fn disable_reconnect(&self) { + self.inner.reconnect.store(false, Ordering::SeqCst); + } + + pub(crate) fn version(&self) -> &SemVer { + &self.inner.version + } + + pub(crate) fn hash_sig(&self) -> &[u8] { + &self.inner.hash_sig + } + + pub(crate) fn period(&self) -> Duration { + self.inner.period + } + + pub(crate) fn should_run(&self) -> bool { + self.inner.active.load(Ordering::SeqCst) + } + + pub(crate) fn should_reconnect(&self) -> bool { + self.inner.reconnect.load(Ordering::SeqCst) + } + + pub(crate) fn shutdown_requested(&self) -> bool { + self.inner.shutdown_requested.load(Ordering::SeqCst) + } + + pub(crate) fn step(&self) -> i64 { + self.inner.step.load(Ordering::SeqCst) + } + + pub(crate) fn set_connected(&self, connected: bool) { + self.inner.connected.store(connected, Ordering::SeqCst); + } + + pub(crate) fn try_claim_outbound(&self) -> bool { + !self.inner.outbound_claimed.swap(true, Ordering::SeqCst) + } + + pub(crate) fn release_outbound(&self) { + self.inner.outbound_claimed.store(false, Ordering::SeqCst); + } + + pub(crate) fn finish(&self, result: Result<()>) { + self.inner.active.store(false, Ordering::SeqCst); + self.inner.connected.store(false, Ordering::SeqCst); + self.release_outbound(); + + if !self.inner.finished.swap(true, Ordering::SeqCst) { + let _ = self.inner.done_tx.send(Some(result)); + } + } + + pub(crate) fn activate(&self) { + self.inner.active.store(true, Ordering::SeqCst); + + if let Some(command_tx) = &self.inner.command_tx { + let _ = command_tx.send(Command::Activate(self.inner.peer_id)); + } + } + + async fn wait_finished( + &self, + cancellation: CancellationToken, + clear_on_cancel: bool, + ) -> Result<()> { + let mut done_rx = self.inner.done_tx.subscribe(); + + loop { + if let Some(result) = done_rx.borrow().clone() { + return result; + } + + tokio::select! { + _ = cancellation.cancelled() => { + if clear_on_cancel { + self.inner.active.store(false, Ordering::SeqCst); + self.inner.connected.store(false, Ordering::SeqCst); + } + return Err(Error::Canceled); + } + changed = done_rx.changed() => { + if changed.is_err() { + return Err(Error::message("sync client completion channel closed")); + } + } + } + } + } +} diff --git a/crates/dkg/src/sync/error.rs b/crates/dkg/src/sync/error.rs new file mode 100644 index 00000000..bb487cee --- /dev/null +++ b/crates/dkg/src/sync/error.rs @@ -0,0 +1,138 @@ +use pluto_core::version::SemVer; + +/// Sync result type. +pub type Result = std::result::Result; + +/// Error type for the DKG sync protocol. +#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)] +pub enum Error { + /// Generic message. + #[error("{0}")] + Message(String), + + /// The sync client was canceled. + #[error("sync client canceled")] + Canceled, + + /// The peer returned an application-level error. + #[error("peer responded with error: {0}")] + PeerRespondedWithError(String), + + /// The remote peer version did not match. + #[error("mismatching charon version; expect={expected}, got={got}")] + VersionMismatch { + /// The expected version string. + expected: String, + /// The received version string. + got: String, + }, + + /// The definition hash signature was invalid. + #[error("invalid definition hash signature")] + InvalidDefinitionHashSignature, + + /// The peer reported a step lower than the previous known step. + #[error("peer reported step is behind the last known step")] + PeerStepBehind, + + /// The peer reported a step too far ahead of the previous known step. + #[error("peer reported step is ahead the last known step")] + PeerStepAhead, + + /// The peer reported an invalid first step. + #[error("peer reported abnormal initial step, expected 0 or 1")] + AbnormalInitialStep, + + /// A peer was too far ahead for the awaited step. + #[error("peer step is too far ahead")] + PeerStepTooFarAhead, + + /// The stream protocol could not be negotiated. + #[error("protocol negotiation failed")] + Unsupported, + + /// A message length prefix was invalid. + #[error("invalid sized protobuf length: {0}")] + InvalidMessageLength(i64), + + /// Failed to parse the peer version. + #[error("parse peer version: {0}")] + ParsePeerVersion(String), + + /// Failed to sign the definition hash. + #[error("sign definition hash: {0}")] + SignDefinitionHash(String), + + /// Failed to convert the local key to a libp2p keypair. + #[error("convert secret key to libp2p keypair: {0}")] + KeyConversion(String), + + /// Failed to decode a protobuf message. + #[error("protobuf decode failed: {0}")] + Decode(String), + + /// Failed to encode a protobuf message. + #[error("protobuf encode failed: {0}")] + Encode(String), + + /// An I/O error occurred while reading or writing the stream. + #[error("i/o error: {0}")] + Io(String), + + /// A peer ID could not be converted to a public key. + #[error("peer error: {0}")] + Peer(String), + + /// A sync server operation was attempted before the server was started. + #[error("sync server not started")] + ServerNotStarted, + + /// The local peer ID was missing from the shared P2P context. + #[error("local peer id missing from p2p context")] + LocalPeerMissing, +} + +impl Error { + /// Creates a new generic message error. + pub fn message(message: impl Into) -> Self { + Self::Message(message.into()) + } + + /// Creates an I/O error from the given source. + pub fn io(error: impl std::fmt::Display) -> Self { + Self::Io(error.to_string()) + } + + /// Creates a protobuf decode error from the given source. + pub fn decode(error: impl std::fmt::Display) -> Self { + Self::Decode(error.to_string()) + } + + /// Creates a protobuf encode error from the given source. + pub fn encode(error: impl std::fmt::Display) -> Self { + Self::Encode(error.to_string()) + } + + /// Creates a peer conversion error from the given source. + pub fn peer(error: impl std::fmt::Display) -> Self { + Self::Peer(error.to_string()) + } + + /// Creates a version mismatch error matching Go's wire string. + pub fn version_mismatch(expected: &SemVer, got: &str) -> Self { + Self::VersionMismatch { + expected: expected.to_string(), + got: got.to_string(), + } + } + + /// Returns true if the error should be treated like Go's relay reset path. + pub fn is_relay_error(&self) -> bool { + matches!(self, Self::Io(message) if { + let lowercase = message.to_ascii_lowercase(); + lowercase.contains("connection reset") + || lowercase.contains("resource scope closed") + || lowercase.contains("broken pipe") + }) + } +} diff --git a/crates/dkg/src/sync/handler.rs b/crates/dkg/src/sync/handler.rs new file mode 100644 index 00000000..67e117d5 --- /dev/null +++ b/crates/dkg/src/sync/handler.rs @@ -0,0 +1,388 @@ +//! Connection handler for the DKG sync protocol. + +use std::{ + convert::Infallible, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use futures::{FutureExt, future::BoxFuture}; +use libp2p::{ + PeerId, Stream, + core::upgrade::ReadyUpgrade, + swarm::{ + ConnectionHandler, ConnectionHandlerEvent, StreamProtocol, StreamUpgradeError, + SubstreamProtocol, + handler::{ + ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, + }, + }, +}; +use prost_types::Timestamp; +use tokio::time::Sleep; +use tracing::{debug, info, warn}; + +use crate::dkgpb::v1::sync::{MsgSync, MsgSyncResponse}; + +use super::{ + client::Client, + error::{Error, Result}, + protocol, + server::Server, +}; + +const INITIAL_BACKOFF: Duration = Duration::from_millis(100); +const MAX_BACKOFF: Duration = Duration::from_secs(1); + +type InboundFuture = BoxFuture<'static, Result<()>>; + +enum OutboundState { + Idle, + OpenStream, + Running(BoxFuture<'static, OutboundExit>), + WaitingRetry(Pin>), + Disabled, +} + +enum OutboundExit { + GracefulShutdown, + Reconnectable { error: Error, relay: bool }, + Fatal(Error), +} + +/// Sync connection handler. +pub struct Handler { + peer_id: PeerId, + server: Server, + client: Option, + inbound: Option, + outbound: OutboundState, + backoff: Duration, +} + +impl Handler { + /// Creates a new handler for a single connection. + pub fn new(peer_id: PeerId, server: Server, client: Option) -> Self { + Self { + peer_id, + server, + client, + inbound: None, + outbound: OutboundState::Idle, + backoff: INITIAL_BACKOFF, + } + } + + fn substream_protocol(&self) -> SubstreamProtocol> { + SubstreamProtocol::new(ReadyUpgrade::new(protocol::PROTOCOL_NAME), ()) + } + + fn schedule_retry(&mut self) { + let sleep = Box::pin(tokio::time::sleep(self.backoff)); + self.outbound = OutboundState::WaitingRetry(sleep); + self.backoff = self.backoff.saturating_mul(2).min(MAX_BACKOFF); + } + + fn reset_backoff(&mut self) { + self.backoff = INITIAL_BACKOFF; + } + + fn wants_outbound(&self) -> bool { + self.client + .as_ref() + .is_some_and(|client| client.should_run()) + } + + fn try_request_outbound( + &mut self, + ) -> Option, (), Infallible>> { + let client = self.client.as_ref()?; + if !self.wants_outbound() || !client.try_claim_outbound() { + return None; + } + + self.outbound = OutboundState::OpenStream; + Some(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: self.substream_protocol(), + }) + } + + fn on_dial_upgrade_error( + &mut self, + DialUpgradeError { error, .. }: DialUpgradeError< + (), + ::OutboundProtocol, + >, + ) { + let Some(client) = self.client.as_ref() else { + self.outbound = OutboundState::Disabled; + return; + }; + + client.release_outbound(); + + let error = match error { + StreamUpgradeError::NegotiationFailed => Error::Unsupported, + StreamUpgradeError::Timeout => Error::io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "sync protocol negotiation timed out", + )), + StreamUpgradeError::Apply(never) => match never {}, + StreamUpgradeError::Io(error) => Error::io(error), + }; + + if client.should_reconnect() || error.is_relay_error() { + self.schedule_retry(); + } else { + client.finish(Err(error)); + self.outbound = OutboundState::Disabled; + } + } +} + +impl ConnectionHandler for Handler { + type FromBehaviour = Infallible; + type InboundOpenInfo = (); + type InboundProtocol = ReadyUpgrade; + type OutboundOpenInfo = (); + type OutboundProtocol = ReadyUpgrade; + type ToBehaviour = Infallible; + + fn listen_protocol(&self) -> SubstreamProtocol { + self.substream_protocol() + } + + fn on_behaviour_event(&mut self, never: Self::FromBehaviour) { + match never {} + } + + fn poll( + &mut self, + cx: &mut Context<'_>, + ) -> Poll< + ConnectionHandlerEvent, + > { + if let Some(inbound) = self.inbound.as_mut() { + match inbound.poll_unpin(cx) { + Poll::Pending => {} + Poll::Ready(Ok(())) => { + self.inbound = None; + } + Poll::Ready(Err(error)) => { + warn!(peer = %self.peer_id, err = %error, "Error serving inbound sync stream"); + self.inbound = None; + } + } + } + + match &mut self.outbound { + OutboundState::Idle => { + if let Some(event) = self.try_request_outbound() { + return Poll::Ready(event); + } + } + OutboundState::OpenStream => {} + OutboundState::WaitingRetry(delay) => { + if delay.as_mut().poll(cx).is_ready() { + if let Some(event) = self.try_request_outbound() { + return Poll::Ready(event); + } + + self.outbound = OutboundState::Idle; + } + } + OutboundState::Running(fut) => match fut.poll_unpin(cx) { + Poll::Pending => {} + Poll::Ready(OutboundExit::GracefulShutdown) => { + if let Some(client) = self.client.as_ref() { + client.finish(Ok(())); + } + self.outbound = OutboundState::Disabled; + } + Poll::Ready(OutboundExit::Reconnectable { error, relay }) => { + let Some(client) = self.client.as_ref() else { + self.outbound = OutboundState::Disabled; + return Poll::Pending; + }; + + client.set_connected(false); + client.release_outbound(); + + if relay || client.should_reconnect() { + if relay { + debug!(peer = %self.peer_id, err = %error, "Relay connection dropped, reconnecting sync client"); + } else { + info!(peer = %self.peer_id, err = %error, "Disconnected from peer"); + } + self.outbound = OutboundState::Idle; + } else { + client.finish(Err(error)); + self.outbound = OutboundState::Disabled; + } + } + Poll::Ready(OutboundExit::Fatal(error)) => { + if let Some(client) = self.client.as_ref() { + client.finish(Err(error)); + } + self.outbound = OutboundState::Disabled; + } + }, + OutboundState::Disabled => {} + } + + Poll::Pending + } + + fn on_connection_event( + &mut self, + event: ConnectionEvent, + ) { + match event { + ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound { + protocol: mut stream, + .. + }) => { + stream.ignore_for_keep_alive(); + self.inbound = + Some(handle_inbound_stream(self.peer_id, self.server.clone(), stream).boxed()); + } + ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { + protocol: mut stream, + .. + }) => { + let Some(client) = self.client.clone() else { + self.outbound = OutboundState::Disabled; + return; + }; + + stream.ignore_for_keep_alive(); + self.reset_backoff(); + self.outbound = OutboundState::Running(run_outbound_stream(client, stream).boxed()); + } + ConnectionEvent::DialUpgradeError(error) => self.on_dial_upgrade_error(error), + ConnectionEvent::AddressChange(_) + | ConnectionEvent::LocalProtocolsChange(_) + | ConnectionEvent::RemoteProtocolsChange(_) => {} + ConnectionEvent::ListenUpgradeError(_) => {} + _ => {} + } + } +} + +async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit { + let mut first = true; + let mut interval = tokio::time::interval(client.period()); + let hash_signature = prost::bytes::Bytes::from(client.hash_sig().to_vec()); + let version = client.version().to_string(); + + client.set_connected(true); + + loop { + if first { + first = false; + } else { + interval.tick().await; + } + + let shutdown = client.shutdown_requested(); + let timestamp = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { + Ok(timestamp) => timestamp, + Err(error) => return OutboundExit::Fatal(Error::io(error)), + }; + let nanos = timestamp.subsec_nanos(); + let timestamp = Timestamp { + seconds: i64::try_from(timestamp.as_secs()).unwrap_or(i64::MAX), + nanos: i32::try_from(nanos).unwrap_or(i32::MAX), + }; + let request = MsgSync { + timestamp: Some(timestamp), + hash_signature: hash_signature.clone(), + shutdown, + version: version.clone(), + step: client.step(), + }; + + let response = async { + protocol::write_sync_request(&mut stream, &request).await?; + protocol::read_sync_response(&mut stream).await + } + .await; + + let response = match response { + Ok(response) => response, + Err(error) => { + return OutboundExit::Reconnectable { + relay: error.is_relay_error(), + error, + }; + } + }; + + if !response.error.is_empty() { + return OutboundExit::Fatal(Error::PeerRespondedWithError(response.error)); + } + + if let Some(sync_timestamp) = response.sync_timestamp { + debug!( + peer = %client.peer_id(), + sync_timestamp = ?sync_timestamp, + "Received sync response" + ); + } + + if shutdown { + return OutboundExit::GracefulShutdown; + } + } +} + +async fn handle_inbound_stream(peer_id: PeerId, server: Server, mut stream: Stream) -> Result<()> { + if !server.is_started() { + return Err(Error::ServerNotStarted); + } + + let public_key = pluto_p2p::peer::peer_id_to_libp2p_pk(&peer_id).map_err(Error::peer)?; + + loop { + let message = protocol::read_sync_request(&mut stream).await?; + let mut response = MsgSyncResponse { + sync_timestamp: message.timestamp, + error: String::new(), + }; + + if let Err(error) = protocol::validate_request_with_public_key( + server.def_hash(), + server.version(), + &public_key, + &message, + ) { + server + .set_err(Error::message(format!( + "invalid sync message: peer={peer_id} err={error}" + ))) + .await; + response.error = error.to_string(); + } else { + let (inserted, count) = server.mark_connected(peer_id).await; + if inserted { + info!( + peer = %peer_id, + connected = count, + expected = server.expected_peer_count(), + "Connected to peer" + ); + } + } + + server.update_step(peer_id, message.step).await?; + + protocol::write_sync_response(&mut stream, &response).await?; + + if message.shutdown { + server.set_shutdown(peer_id).await; + server.clear_connected(peer_id).await; + return Ok(()); + } + } +} diff --git a/crates/dkg/src/sync/mod.rs b/crates/dkg/src/sync/mod.rs new file mode 100644 index 00000000..6540d49c --- /dev/null +++ b/crates/dkg/src/sync/mod.rs @@ -0,0 +1,316 @@ +//! DKG peer step synchronization protocol. +//! +//! This module ports Go `charon/dkg/sync` into Rust while keeping the public +//! API split into a per-peer [`Client`] handle and a shared [`Server`] handle. +//! Internally, Pluto drives the protocol through a libp2p behaviour and +//! connection handler because protocol streams are owned by the swarm. + +mod behaviour; +mod client; +mod error; +mod handler; +mod protocol; +mod server; + +use libp2p::PeerId; +use pluto_core::version::SemVer; +use pluto_p2p::p2p_context::P2PContext; +use tokio::sync::mpsc; + +pub use behaviour::{Behaviour, Event}; +pub use client::{Client, ClientConfig, DEFAULT_PERIOD}; +pub use error::{Error, Result}; +pub use server::Server; + +#[derive(Debug, Clone, Copy)] +pub(crate) enum Command { + Activate(PeerId), +} + +/// Creates a sync behaviour plus server/client handles for the given peer set. +pub fn new( + peers: Vec, + p2p_context: P2PContext, + secret: &k256::SecretKey, + def_hash: Vec, + version: SemVer, +) -> Result<(Behaviour, Server, Vec)> { + let local_peer_id = p2p_context.local_peer_id().ok_or(Error::LocalPeerMissing)?; + let hash_sig = protocol::sign_definition_hash(secret, &def_hash)?; + let server = Server::new(peers.len().saturating_sub(1), def_hash, version.clone()); + let (command_tx, command_rx) = mpsc::unbounded_channel(); + let clients = peers + .into_iter() + .filter(|peer_id| *peer_id != local_peer_id) + .map(|peer_id| { + Client::new_wired( + peer_id, + hash_sig.clone(), + version.clone(), + ClientConfig::default(), + command_tx.clone(), + ) + }) + .collect::>(); + let behaviour = Behaviour::new(server.clone(), clients.clone(), p2p_context, command_rx); + Ok((behaviour, server, clients)) +} + +#[cfg(test)] +mod tests { + use std::{net::TcpListener, time::Duration}; + + use futures::StreamExt; + use libp2p::PeerId; + use pluto_core::version::SemVer; + use pluto_p2p::{ + config::P2PConfig, + p2p::{Node, NodeType}, + p2p_context::P2PContext, + peer::peer_id_from_key, + }; + use pluto_testutil::random::generate_insecure_k1_key; + use tokio::{sync::oneshot, time::timeout}; + use tokio_util::sync::CancellationToken; + + use super::*; + + struct LocalNode { + server: Server, + clients: Vec, + node: Node, + addr: libp2p::Multiaddr, + } + + struct RunningNode { + server: Server, + clients: Vec, + stop_tx: oneshot::Sender<()>, + join: tokio::task::JoinHandle>, + client_joins: Vec>>, + } + + fn available_tcp_port() -> anyhow::Result { + let listener = TcpListener::bind("127.0.0.1:0")?; + Ok(listener.local_addr()?.port()) + } + + async fn spawn_nodes(nodes: Vec) -> anyhow::Result> { + let mut nodes = nodes; + for node in &mut nodes { + node.node.listen_on(node.addr.clone())?; + } + + let dial_targets = (0..nodes.len()) + .map(|index| { + nodes + .iter() + .enumerate() + .filter(|(other, _)| *other > index) + .map(|(_, node)| node.addr.clone()) + .collect::>() + }) + .collect::>(); + + let mut running = Vec::with_capacity(nodes.len()); + for (local, targets) in nodes.into_iter().zip(dial_targets) { + let mut node = local.node; + let server = local.server.clone(); + server.start(); + let cancellation = CancellationToken::new(); + let client_joins = local + .clients + .iter() + .map(|client| { + let cancellation = cancellation.child_token(); + let client = client.clone(); + tokio::spawn(async move { client.run(cancellation).await }) + }) + .collect::>(); + let (stop_tx, mut stop_rx) = oneshot::channel(); + + let join = tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(200)).await; + for target in targets { + node.dial(target)?; + } + + loop { + tokio::select! { + _ = &mut stop_rx => break, + _event = node.select_next_some() => {} + } + } + + Ok(()) + }); + + running.push(RunningNode { + server: local.server, + clients: local.clients, + stop_tx, + join, + client_joins, + }); + } + + Ok(running) + } + + async fn shutdown_nodes(nodes: Vec) -> anyhow::Result<()> { + for node in nodes { + let _ = node.stop_tx.send(()); + node.join.await??; + for join in node.client_joins { + let _ = join.await; + } + } + + Ok(()) + } + + #[tokio::test] + async fn update_step_rules_match_go() { + let version = SemVer::parse("v0.1").expect("valid version"); + let server = Server::new(1, vec![0; 32], version); + let peer = PeerId::random(); + + let error = server + .update_step(peer, 100) + .await + .expect_err("wrong initial step should fail"); + assert!(matches!(error, Error::AbnormalInitialStep)); + + let peer = PeerId::random(); + server + .update_step(peer, 1) + .await + .expect("first valid step should pass"); + server + .update_step(peer, 1) + .await + .expect("same step should pass"); + server + .update_step(peer, 2) + .await + .expect("next step should pass"); + + let peer = PeerId::random(); + server + .update_step(peer, 1) + .await + .expect("first step should pass"); + let error = server + .update_step(peer, 0) + .await + .expect_err("behind should fail"); + assert!(matches!(error, Error::PeerStepBehind)); + + let peer = PeerId::random(); + server + .update_step(peer, 1) + .await + .expect("first step should pass"); + let error = server + .update_step(peer, 4) + .await + .expect_err("ahead should fail"); + assert!(matches!(error, Error::PeerStepAhead)); + } + + #[tokio::test] + async fn sync_round_trip_matches_go_shape() -> anyhow::Result<()> { + let ports = (0..3) + .map(|_| available_tcp_port()) + .collect::>>()?; + let keys = (0_u8..3).map(generate_insecure_k1_key).collect::>(); + let peer_ids = keys + .iter() + .map(|key| peer_id_from_key(key.public_key())) + .collect::, _>>()?; + let version = SemVer::parse("v1.7")?; + let mut nodes = Vec::new(); + for (index, key) in keys.into_iter().enumerate() { + let peer_id = peer_ids[index]; + let p2p_context = P2PContext::new(peer_ids.clone()); + p2p_context.set_local_peer_id(peer_id); + let (behaviour, server, clients) = new( + peer_ids.clone(), + p2p_context.clone(), + &key, + vec![1, 2, 3], + version.clone(), + )?; + let node = Node::new_server( + P2PConfig::default(), + key, + NodeType::TCP, + false, + peer_ids.clone(), + |builder, _keypair| builder.with_p2p_context(p2p_context).with_inner(behaviour), + )?; + let addr = format!("/ip4/127.0.0.1/tcp/{}", ports[index]).parse()?; + nodes.push(LocalNode { + server, + clients, + node, + addr, + }); + } + + let running = spawn_nodes(nodes).await?; + let cancellation = CancellationToken::new(); + + for node in &running { + timeout( + Duration::from_secs(10), + node.server.await_all_connected(cancellation.child_token()), + ) + .await??; + } + + for step in 0_i64..5 { + for node in &running { + timeout( + Duration::from_secs(10), + node.server + .await_all_at_step(step, cancellation.child_token()), + ) + .await??; + + let future = node + .server + .await_all_at_step(step + 1, cancellation.child_token()); + let error = timeout(Duration::from_millis(10), future).await; + assert!(error.is_err(), "next step should not complete immediately"); + } + + for node in &running { + for client in &node.clients { + client.set_step(step + 1); + } + } + } + + for node in &running { + assert!(node.clients.iter().all(Client::is_connected)); + } + + for node in &running { + for client in &node.clients { + client.shutdown(cancellation.child_token()).await?; + } + } + + for node in &running { + timeout( + Duration::from_secs(10), + node.server.await_all_shutdown(cancellation.child_token()), + ) + .await??; + } + + shutdown_nodes(running).await?; + Ok(()) + } +} diff --git a/crates/dkg/src/sync/protocol.rs b/crates/dkg/src/sync/protocol.rs new file mode 100644 index 00000000..18f05de0 --- /dev/null +++ b/crates/dkg/src/sync/protocol.rs @@ -0,0 +1,155 @@ +//! Wire protocol helpers for the DKG sync protocol. + +use std::io; + +use futures::{AsyncRead, AsyncWrite}; +use libp2p::{ + Stream, + identity::{Keypair, PublicKey}, +}; +use pluto_core::version::SemVer; +use pluto_p2p::proto::{self, InvalidFixedSizeLength}; +use prost::Message; + +use crate::dkgpb::v1::sync::{MsgSync, MsgSyncResponse}; + +use super::error::{Error, Result}; + +/// The protocol identifier for DKG sync. +pub const PROTOCOL_NAME: libp2p::StreamProtocol = StreamProtocol::new("/charon/dkg/sync/1.0.0/"); + +use libp2p::StreamProtocol; + +/// Signs the definition hash using the same libp2p signing path as Go. +pub fn sign_definition_hash(secret: &k256::SecretKey, def_hash: &[u8]) -> Result> { + let mut der = secret + .to_sec1_der() + .map_err(|error| Error::KeyConversion(error.to_string()))?; + let keypair = Keypair::secp256k1_from_der(&mut der) + .map_err(|error| Error::KeyConversion(error.to_string()))?; + keypair + .sign(def_hash) + .map_err(|error| Error::SignDefinitionHash(error.to_string())) +} + +/// Writes a size-prefixed protobuf to the stream. +pub async fn write_sized_protobuf(writer: &mut W, msg: &M) -> Result<()> +where + M: Message, + W: AsyncWrite + Unpin, +{ + let mut buf = Vec::new(); + msg.encode(&mut buf).map_err(Error::encode)?; + proto::write_fixed_size_delimited(writer, &buf) + .await + .map_err(Error::io) +} + +/// Reads a size-prefixed protobuf from the stream. +pub async fn read_sized_protobuf(reader: &mut R) -> Result +where + M: Message + Default, + R: AsyncRead + Unpin, +{ + let buf = proto::read_fixed_size_delimited(reader) + .await + .map_err(map_fixed_size_read_error)?; + M::decode(buf.as_slice()).map_err(Error::decode) +} + +/// Reads a sync request from the stream. +pub async fn read_sync_request(stream: &mut Stream) -> Result { + read_sized_protobuf(stream).await +} + +/// Writes a sync request to the stream. +pub async fn write_sync_request(stream: &mut Stream, message: &MsgSync) -> Result<()> { + write_sized_protobuf(stream, message).await +} + +/// Reads a sync response from the stream. +pub async fn read_sync_response(stream: &mut Stream) -> Result { + read_sized_protobuf(stream).await +} + +/// Writes a sync response to the stream. +pub async fn write_sync_response(stream: &mut Stream, message: &MsgSyncResponse) -> Result<()> { + write_sized_protobuf(stream, message).await +} + +/// Validates a sync request for a known peer public key. +pub fn validate_request_with_public_key( + def_hash: &[u8], + expected_version: &SemVer, + public_key: &PublicKey, + msg: &MsgSync, +) -> Result<()> { + let msg_version = + SemVer::parse(&msg.version).map_err(|error| Error::ParsePeerVersion(error.to_string()))?; + + if msg_version != *expected_version { + return Err(Error::version_mismatch(expected_version, &msg.version)); + } + + if !public_key.verify(def_hash, &msg.hash_signature) { + return Err(Error::InvalidDefinitionHashSignature); + } + + Ok(()) +} + +fn map_fixed_size_read_error(error: io::Error) -> Error { + if let Some(source) = error.get_ref() + && let Some(length) = source.downcast_ref::() + { + return Error::InvalidMessageLength(length.0); + } + + Error::io(error) +} + +#[cfg(test)] +mod tests { + use futures::{AsyncWriteExt, io::Cursor}; + + use super::*; + + #[tokio::test] + async fn sized_proto_round_trip() { + let message = MsgSync { + timestamp: Some(prost_types::Timestamp { + seconds: 1, + nanos: 2, + }), + hash_signature: vec![1, 2, 3].into(), + shutdown: true, + version: "v1.7".to_string(), + step: 3, + }; + let mut cursor = Cursor::new(Vec::new()); + write_sized_protobuf(&mut cursor, &message) + .await + .expect("writer should succeed"); + cursor.set_position(0); + let decoded = read_sized_protobuf::(&mut cursor) + .await + .expect("decode should succeed"); + + assert_eq!(decoded, message); + } + + #[tokio::test] + async fn negative_message_length_fails() { + let mut cursor = Cursor::new(Vec::new()); + cursor + .write_all(&(-1_i64).to_le_bytes()) + .await + .expect("writer should succeed"); + cursor.set_position(0); + + let error = read_sized_protobuf::(&mut cursor) + .await + .expect_err("negative sizes must fail"); + assert!(matches!(error, Error::InvalidMessageLength(-1))); + } +} diff --git a/crates/dkg/src/sync/server.rs b/crates/dkg/src/sync/server.rs new file mode 100644 index 00000000..1e2ff3ed --- /dev/null +++ b/crates/dkg/src/sync/server.rs @@ -0,0 +1,257 @@ +//! Server handle for the DKG sync protocol. + +use std::{ + collections::{HashMap, HashSet}, + sync::{ + Arc, + atomic::{AtomicBool, Ordering}, + }, +}; + +use libp2p::PeerId; +use pluto_core::version::SemVer; +use tokio::sync::{Notify, RwLock}; +use tokio_util::sync::CancellationToken; + +use super::error::{Error, Result}; + +#[derive(Debug, Default)] +struct ServerState { + shutdown: HashSet, + connected: HashSet, + steps: HashMap, + err: Option, +} + +#[derive(Debug)] +struct ServerInner { + all_count: usize, + def_hash: Vec, + version: SemVer, + started: AtomicBool, + notify: Notify, + state: RwLock, +} + +/// User-facing handle for the sync server state. +#[derive(Debug, Clone)] +pub struct Server { + inner: Arc, +} + +impl Server { + /// Creates a new server handle. + pub fn new(all_count: usize, def_hash: Vec, version: SemVer) -> Self { + Self { + inner: Arc::new(ServerInner { + all_count, + def_hash, + version, + started: AtomicBool::new(false), + notify: Notify::new(), + state: RwLock::new(ServerState::default()), + }), + } + } + + /// Starts the server side of the protocol. + pub fn start(&self) { + self.inner.started.store(true, Ordering::SeqCst); + self.inner.notify.notify_waiters(); + } + + /// Returns the current shared server error, if any. + pub async fn err(&self) -> Option { + self.inner.state.read().await.err.clone() + } + + /// Waits until all peers have connected or an error occurs. + pub async fn await_all_connected(&self, cancellation: CancellationToken) -> Result<()> { + loop { + let notified = self.inner.notify.notified(); + tokio::pin!(notified); + + if !self.inner.started.load(Ordering::SeqCst) { + tokio::select! { + _ = cancellation.cancelled() => return Err(Error::Canceled), + _ = &mut notified => {} + } + continue; + } + + { + let state = self.inner.state.read().await; + if let Some(error) = &state.err { + return Err(error.clone()); + } + if state.connected.len() == self.inner.all_count { + return Ok(()); + } + } + + tokio::select! { + _ = cancellation.cancelled() => return Err(Error::Canceled), + _ = &mut notified => {} + } + } + } + + /// Waits until all peers have reported shutdown. + pub async fn await_all_shutdown(&self, cancellation: CancellationToken) -> Result<()> { + loop { + let notified = self.inner.notify.notified(); + tokio::pin!(notified); + + { + let state = self.inner.state.read().await; + if state.shutdown.len() == self.inner.all_count { + return Ok(()); + } + } + + tokio::select! { + _ = cancellation.cancelled() => return Err(Error::Canceled), + _ = &mut notified => {} + } + } + } + + /// Waits until all peers have reached the given step or the next one. + pub async fn await_all_at_step( + &self, + step: i64, + cancellation: CancellationToken, + ) -> Result<()> { + let step_plus_one = step + .checked_add(1) + .ok_or_else(|| Error::message("step overflow"))?; + let step_plus_two = step + .checked_add(2) + .ok_or_else(|| Error::message("step overflow"))?; + + loop { + let notified = self.inner.notify.notified(); + tokio::pin!(notified); + + { + let state = self.inner.state.read().await; + if let Some(error) = &state.err { + return Err(error.clone()); + } + + if state.steps.len() == self.inner.all_count { + let mut all_ok = true; + for actual in state.steps.values() { + if *actual >= step_plus_two { + return Err(Error::PeerStepTooFarAhead); + } + + if *actual != step && *actual != step_plus_one { + all_ok = false; + } + } + + if all_ok { + return Ok(()); + } + } + } + + tokio::select! { + _ = cancellation.cancelled() => return Err(Error::Canceled), + _ = &mut notified => {} + } + } + } + + pub(crate) fn def_hash(&self) -> &[u8] { + &self.inner.def_hash + } + + pub(crate) fn version(&self) -> &SemVer { + &self.inner.version + } + + pub(crate) fn expected_peer_count(&self) -> usize { + self.inner.all_count + } + + pub(crate) fn is_started(&self) -> bool { + self.inner.started.load(Ordering::SeqCst) + } + + pub(crate) async fn mark_connected(&self, peer_id: PeerId) -> (bool, usize) { + self.mutate_state(|state| { + let inserted = state.connected.insert(peer_id); + (inserted, state.connected.len()) + }) + .await + } + + pub(crate) async fn clear_connected(&self, peer_id: PeerId) { + self.mutate_state(|state| { + state.connected.remove(&peer_id); + }) + .await; + } + + pub(crate) async fn set_shutdown(&self, peer_id: PeerId) { + self.mutate_state(|state| { + state.shutdown.insert(peer_id); + }) + .await; + } + + pub(crate) async fn set_err(&self, error: Error) { + self.mutate_state(|state| { + state.err = Some(error); + }) + .await; + } + + pub(crate) async fn update_step(&self, peer_id: PeerId, step: i64) -> Result<()> { + let mut state = self.inner.state.write().await; + let current = state.steps.get(&peer_id).copied(); + + if let Some(current) = current + && step < current + { + return Err(Error::PeerStepBehind); + } + + let current_plus_two = current + .map(|current| { + current + .checked_add(2) + .ok_or_else(|| Error::message("step overflow")) + }) + .transpose()?; + + if let Some(current_plus_two) = current_plus_two + && step > current_plus_two + { + return Err(Error::PeerStepAhead); + } + + if current.is_none() && !(0..=1).contains(&step) { + return Err(Error::AbnormalInitialStep); + } + + if current == Some(step) { + return Ok(()); + } + + state.steps.insert(peer_id, step); + drop(state); + self.inner.notify.notify_waiters(); + Ok(()) + } + + async fn mutate_state(&self, mutate: impl FnOnce(&mut ServerState) -> T) -> T { + let mut state = self.inner.state.write().await; + let result = mutate(&mut state); + drop(state); + self.inner.notify.notify_waiters(); + result + } +} From dfcb872e525fbaaf559e1a3821fb552b19add6b1 Mon Sep 17 00:00:00 2001 From: Quang Le Date: Wed, 1 Apr 2026 14:57:46 +0700 Subject: [PATCH 04/14] fix: skip relay connection handler --- crates/dkg/Cargo.toml | 4 ++-- crates/dkg/src/sync/behaviour.rs | 22 +++++++++++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/crates/dkg/Cargo.toml b/crates/dkg/Cargo.toml index 1d1a19b9..6c2045aa 100644 --- a/crates/dkg/Cargo.toml +++ b/crates/dkg/Cargo.toml @@ -13,6 +13,7 @@ thiserror.workspace = true libp2p.workspace = true futures.workspace = true tokio.workspace = true +tokio-util.workspace = true sha2.workspace = true tracing.workspace = true either.workspace = true @@ -20,6 +21,7 @@ k256.workspace = true pluto-k1util.workspace = true pluto-p2p.workspace = true pluto-cluster.workspace = true +pluto-core.workspace = true pluto-crypto.workspace = true pluto-eth1wrap.workspace = true pluto-eth2util.workspace = true @@ -38,11 +40,9 @@ anyhow.workspace = true clap.workspace = true hex.workspace = true pluto-cluster = { workspace = true, features = ["test-cluster"] } -pluto-core.workspace = true pluto-testutil.workspace = true pluto-tracing.workspace = true serde_json.workspace = true -tokio-util.workspace = true tempfile.workspace = true [lints] diff --git a/crates/dkg/src/sync/behaviour.rs b/crates/dkg/src/sync/behaviour.rs index ac0e6e2d..d0cd97bc 100644 --- a/crates/dkg/src/sync/behaviour.rs +++ b/crates/dkg/src/sync/behaviour.rs @@ -3,11 +3,12 @@ use std::{ task::{Context, Poll}, }; +use either::Either; use libp2p::{ Multiaddr, PeerId, swarm::{ ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, - THandlerOutEvent, ToSwarm, + THandlerOutEvent, ToSwarm, dummy, dial_opts::{DialOpts, PeerCondition}, }, }; @@ -55,6 +56,14 @@ impl Behaviour { Handler::new(peer, self.server.clone(), self.clients.get(&peer).cloned()) } + fn connection_handler_for_peer(&self, peer: PeerId) -> THandler { + if self.clients.contains_key(&peer) { + Either::Left(self.new_handler(peer)) + } else { + Either::Right(dummy::ConnectionHandler) + } + } + fn is_connected(&self, peer_id: &PeerId) -> bool { !self .p2p_context @@ -91,7 +100,7 @@ impl Behaviour { } impl NetworkBehaviour for Behaviour { - type ConnectionHandler = Handler; + type ConnectionHandler = Either; type ToSwarm = Event; fn handle_established_inbound_connection( @@ -101,7 +110,7 @@ impl NetworkBehaviour for Behaviour { _local_addr: &Multiaddr, _remote_addr: &Multiaddr, ) -> std::result::Result, ConnectionDenied> { - Ok(self.new_handler(peer)) + Ok(self.connection_handler_for_peer(peer)) } fn handle_established_outbound_connection( @@ -112,7 +121,7 @@ impl NetworkBehaviour for Behaviour { _role_override: libp2p::core::Endpoint, _port_use: libp2p::core::transport::PortUse, ) -> std::result::Result, ConnectionDenied> { - Ok(self.new_handler(peer)) + Ok(self.connection_handler_for_peer(peer)) } fn on_swarm_event(&mut self, event: FromSwarm) { @@ -150,7 +159,10 @@ impl NetworkBehaviour for Behaviour { _connection_id: ConnectionId, event: THandlerOutEvent, ) { - match event {} + match event { + Either::Left(event) => match event {}, + Either::Right(unreachable) => match unreachable {}, + } } fn poll( From 638d13452656e31721e9dd6011ddd6a4ed4ddf91 Mon Sep 17 00:00:00 2001 From: Quang Le Date: Wed, 1 Apr 2026 15:44:36 +0700 Subject: [PATCH 05/14] refactor: simplify sync client --- Cargo.lock | 1 + crates/dkg/Cargo.toml | 1 + crates/dkg/src/sync/behaviour.rs | 13 ++++++-- crates/dkg/src/sync/client.rs | 54 +++++++++----------------------- crates/dkg/src/sync/mod.rs | 4 +-- 5 files changed, 29 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b7466ffc..a04b609f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5561,6 +5561,7 @@ name = "pluto-dkg" version = "1.7.1" dependencies = [ "anyhow", + "bon", "clap", "either", "futures", diff --git a/crates/dkg/Cargo.toml b/crates/dkg/Cargo.toml index 6c2045aa..34b7f791 100644 --- a/crates/dkg/Cargo.toml +++ b/crates/dkg/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true publish.workspace = true [dependencies] +bon.workspace = true prost.workspace = true prost-types.workspace = true thiserror.workspace = true diff --git a/crates/dkg/src/sync/behaviour.rs b/crates/dkg/src/sync/behaviour.rs index d0cd97bc..93cb0865 100644 --- a/crates/dkg/src/sync/behaviour.rs +++ b/crates/dkg/src/sync/behaviour.rs @@ -197,6 +197,7 @@ mod tests { use tokio::sync::mpsc; use super::*; + use crate::sync::ClientConfig; fn test_behaviour(client: Client) -> Behaviour { let (_unused_tx, command_rx) = mpsc::unbounded_channel(); @@ -215,12 +216,12 @@ mod tests { let (command_tx, command_rx) = mpsc::unbounded_channel(); let version = SemVer::parse("v1.7").expect("valid version"); let peer_id = PeerId::random(); - let client = Client::new_wired( + let client = Client::new( peer_id, vec![1, 2, 3], version.clone(), Default::default(), - command_tx, + Some(command_tx), ); let server = Server::new(1, vec![1, 2, 3], version); let p2p_context = P2PContext::new([peer_id]); @@ -243,7 +244,13 @@ mod tests { fn connection_closed_keeps_client_state_until_last_connection() { let version = SemVer::parse("v1.7").expect("valid version"); let peer_id = PeerId::random(); - let client = Client::new(peer_id, vec![1, 2, 3], version); + let client = Client::new( + peer_id, + vec![1, 2, 3], + version, + ClientConfig::default(), + None, + ); client.set_connected(true); assert!(client.try_claim_outbound()); diff --git a/crates/dkg/src/sync/client.rs b/crates/dkg/src/sync/client.rs index 6c3264e8..096f8624 100644 --- a/crates/dkg/src/sync/client.rs +++ b/crates/dkg/src/sync/client.rs @@ -6,6 +6,7 @@ use std::{ time::Duration, }; +use bon::Builder; use libp2p::PeerId; use pluto_core::version::SemVer; use tokio::sync::{mpsc, watch}; @@ -18,34 +19,33 @@ use super::error::{Error, Result}; pub const DEFAULT_PERIOD: Duration = Duration::from_millis(100); /// Configuration for a sync client. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Builder)] pub struct ClientConfig { /// Period between sync messages. + #[builder(default = DEFAULT_PERIOD)] pub period: Duration, } impl Default for ClientConfig { fn default() -> Self { - Self { - period: DEFAULT_PERIOD, - } + Self::builder().build() } } #[derive(Debug)] struct ClientInner { - peer_id: PeerId, - hash_sig: Vec, - version: SemVer, - period: Duration, active: AtomicBool, connected: AtomicBool, reconnect: AtomicBool, + step: AtomicI64, shutdown_requested: AtomicBool, finished: AtomicBool, outbound_claimed: AtomicBool, - step: AtomicI64, done_tx: watch::Sender>>, + peer_id: PeerId, + hash_sig: Vec, + version: SemVer, + period: Duration, command_tx: Option>, } @@ -56,32 +56,8 @@ pub struct Client { } impl Client { - /// Creates a new client with the default sync period. - pub fn new(peer_id: PeerId, hash_sig: Vec, version: SemVer) -> Self { - Self::new_with_config(peer_id, hash_sig, version, ClientConfig::default()) - } - /// Creates a new client with an explicit config. - pub fn new_with_config( - peer_id: PeerId, - hash_sig: Vec, - version: SemVer, - config: ClientConfig, - ) -> Self { - Self::new_with_command(peer_id, hash_sig, version, config, None) - } - - pub(crate) fn new_wired( - peer_id: PeerId, - hash_sig: Vec, - version: SemVer, - config: ClientConfig, - command_tx: mpsc::UnboundedSender, - ) -> Self { - Self::new_with_command(peer_id, hash_sig, version, config, Some(command_tx)) - } - - fn new_with_command( + pub(crate) fn new( peer_id: PeerId, hash_sig: Vec, version: SemVer, @@ -91,18 +67,18 @@ impl Client { let (done_tx, _done_rx) = watch::channel(None); Self { inner: Arc::new(ClientInner { - peer_id, - hash_sig, - version, - period: config.period, active: AtomicBool::new(false), connected: AtomicBool::new(false), reconnect: AtomicBool::new(true), + step: AtomicI64::new(0), shutdown_requested: AtomicBool::new(false), finished: AtomicBool::new(false), outbound_claimed: AtomicBool::new(false), - step: AtomicI64::new(0), done_tx, + peer_id, + hash_sig, + version, + period: config.period, command_tx, }), } diff --git a/crates/dkg/src/sync/mod.rs b/crates/dkg/src/sync/mod.rs index 6540d49c..a8c4e63f 100644 --- a/crates/dkg/src/sync/mod.rs +++ b/crates/dkg/src/sync/mod.rs @@ -43,12 +43,12 @@ pub fn new( .into_iter() .filter(|peer_id| *peer_id != local_peer_id) .map(|peer_id| { - Client::new_wired( + Client::new( peer_id, hash_sig.clone(), version.clone(), ClientConfig::default(), - command_tx.clone(), + Some(command_tx.clone()), ) }) .collect::>(); From 522ab554720819435d51f521265f2f4885fd888c Mon Sep 17 00:00:00 2001 From: Quang Le Date: Wed, 1 Apr 2026 16:38:08 +0700 Subject: [PATCH 06/14] refactor: inline functions --- crates/dkg/src/sync/behaviour.rs | 46 +++++++++++------------- crates/dkg/src/sync/client.rs | 29 ++++++++------- crates/dkg/src/sync/handler.rs | 14 ++------ crates/dkg/src/sync/protocol.rs | 23 +++++------- crates/dkg/src/sync/server.rs | 60 +++++++++++++------------------- 5 files changed, 70 insertions(+), 102 deletions(-) diff --git a/crates/dkg/src/sync/behaviour.rs b/crates/dkg/src/sync/behaviour.rs index 93cb0865..2d00c532 100644 --- a/crates/dkg/src/sync/behaviour.rs +++ b/crates/dkg/src/sync/behaviour.rs @@ -52,38 +52,18 @@ impl Behaviour { } } - fn new_handler(&self, peer: PeerId) -> Handler { - Handler::new(peer, self.server.clone(), self.clients.get(&peer).cloned()) - } - fn connection_handler_for_peer(&self, peer: PeerId) -> THandler { if self.clients.contains_key(&peer) { - Either::Left(self.new_handler(peer)) + Either::Left(Handler::new( + peer, + self.server.clone(), + self.clients.get(&peer).cloned(), + )) } else { Either::Right(dummy::ConnectionHandler) } } - fn is_connected(&self, peer_id: &PeerId) -> bool { - !self - .p2p_context - .peer_store_lock() - .connections_to_peer(peer_id) - .is_empty() - } - - fn queue_dial(&mut self, peer_id: PeerId) { - if self.is_connected(&peer_id) || !self.pending_dials.insert(peer_id) { - return; - } - - self.pending_events.push_back(ToSwarm::Dial { - opts: DialOpts::peer_id(peer_id) - .condition(PeerCondition::DisconnectedAndNotDialing) - .build(), - }); - } - fn handle_command(&mut self, command: Command) { match command { Command::Activate(peer_id) => { @@ -92,7 +72,21 @@ impl Behaviour { }; if client.should_run() && !client.is_connected() { - self.queue_dial(peer_id); + if !self + .p2p_context + .peer_store_lock() + .connections_to_peer(&peer_id) + .is_empty() + || !self.pending_dials.insert(peer_id) + { + return; + } + + self.pending_events.push_back(ToSwarm::Dial { + opts: DialOpts::peer_id(peer_id) + .condition(PeerCondition::DisconnectedAndNotDialing) + .build(), + }); } } } diff --git a/crates/dkg/src/sync/client.rs b/crates/dkg/src/sync/client.rs index 096f8624..1e47475e 100644 --- a/crates/dkg/src/sync/client.rs +++ b/crates/dkg/src/sync/client.rs @@ -34,6 +34,10 @@ impl Default for ClientConfig { #[derive(Debug)] struct ClientInner { + peer_id: PeerId, + hash_sig: Vec, + version: SemVer, + period: Duration, active: AtomicBool, connected: AtomicBool, reconnect: AtomicBool, @@ -42,10 +46,6 @@ struct ClientInner { finished: AtomicBool, outbound_claimed: AtomicBool, done_tx: watch::Sender>>, - peer_id: PeerId, - hash_sig: Vec, - version: SemVer, - period: Duration, command_tx: Option>, } @@ -67,6 +67,10 @@ impl Client { let (done_tx, _done_rx) = watch::channel(None); Self { inner: Arc::new(ClientInner { + peer_id, + hash_sig, + version, + period: config.period, active: AtomicBool::new(false), connected: AtomicBool::new(false), reconnect: AtomicBool::new(true), @@ -75,20 +79,11 @@ impl Client { finished: AtomicBool::new(false), outbound_claimed: AtomicBool::new(false), done_tx, - peer_id, - hash_sig, - version, - period: config.period, command_tx, }), } } - /// Returns the target peer for this client. - pub fn peer_id(&self) -> PeerId { - self.inner.peer_id - } - /// Runs the client until shutdown, fatal error, or cancellation. pub async fn run(&self, cancellation: CancellationToken) -> Result<()> { self.activate(); @@ -116,14 +111,18 @@ impl Client { self.inner.reconnect.store(false, Ordering::SeqCst); } - pub(crate) fn version(&self) -> &SemVer { - &self.inner.version + pub(crate) fn peer_id(&self) -> PeerId { + self.inner.peer_id } pub(crate) fn hash_sig(&self) -> &[u8] { &self.inner.hash_sig } + pub(crate) fn version(&self) -> &SemVer { + &self.inner.version + } + pub(crate) fn period(&self) -> Duration { self.inner.period } diff --git a/crates/dkg/src/sync/handler.rs b/crates/dkg/src/sync/handler.rs index 67e117d5..7b74f8eb 100644 --- a/crates/dkg/src/sync/handler.rs +++ b/crates/dkg/src/sync/handler.rs @@ -84,21 +84,11 @@ impl Handler { self.backoff = self.backoff.saturating_mul(2).min(MAX_BACKOFF); } - fn reset_backoff(&mut self) { - self.backoff = INITIAL_BACKOFF; - } - - fn wants_outbound(&self) -> bool { - self.client - .as_ref() - .is_some_and(|client| client.should_run()) - } - fn try_request_outbound( &mut self, ) -> Option, (), Infallible>> { let client = self.client.as_ref()?; - if !self.wants_outbound() || !client.try_claim_outbound() { + if !client.should_run() || !client.try_claim_outbound() { return None; } @@ -257,7 +247,7 @@ impl ConnectionHandler for Handler { }; stream.ignore_for_keep_alive(); - self.reset_backoff(); + self.backoff = INITIAL_BACKOFF; self.outbound = OutboundState::Running(run_outbound_stream(client, stream).boxed()); } ConnectionEvent::DialUpgradeError(error) => self.on_dial_upgrade_error(error), diff --git a/crates/dkg/src/sync/protocol.rs b/crates/dkg/src/sync/protocol.rs index 18f05de0..713116d4 100644 --- a/crates/dkg/src/sync/protocol.rs +++ b/crates/dkg/src/sync/protocol.rs @@ -1,7 +1,5 @@ //! Wire protocol helpers for the DKG sync protocol. -use std::io; - use futures::{AsyncRead, AsyncWrite}; use libp2p::{ Stream, @@ -53,7 +51,15 @@ where { let buf = proto::read_fixed_size_delimited(reader) .await - .map_err(map_fixed_size_read_error)?; + .map_err(|error| { + if let Some(source) = error.get_ref() + && let Some(length) = source.downcast_ref::() + { + return Error::InvalidMessageLength(length.0); + } + + Error::io(error) + })?; M::decode(buf.as_slice()).map_err(Error::decode) } @@ -97,17 +103,6 @@ pub fn validate_request_with_public_key( Ok(()) } - -fn map_fixed_size_read_error(error: io::Error) -> Error { - if let Some(source) = error.get_ref() - && let Some(length) = source.downcast_ref::() - { - return Error::InvalidMessageLength(length.0); - } - - Error::io(error) -} - #[cfg(test)] mod tests { use futures::{AsyncWriteExt, io::Cursor}; diff --git a/crates/dkg/src/sync/server.rs b/crates/dkg/src/sync/server.rs index 1e2ff3ed..02bea40b 100644 --- a/crates/dkg/src/sync/server.rs +++ b/crates/dkg/src/sync/server.rs @@ -181,11 +181,11 @@ impl Server { } pub(crate) async fn mark_connected(&self, peer_id: PeerId) -> (bool, usize) { - self.mutate_state(|state| { - let inserted = state.connected.insert(peer_id); - (inserted, state.connected.len()) - }) - .await + let mut state = self.inner.state.write().await; + let inserted = state.connected.insert(peer_id); + let count = state.connected.len(); + self.inner.notify.notify_waiters(); + (inserted, count) } pub(crate) async fn clear_connected(&self, peer_id: PeerId) { @@ -211,47 +211,37 @@ impl Server { pub(crate) async fn update_step(&self, peer_id: PeerId, step: i64) -> Result<()> { let mut state = self.inner.state.write().await; - let current = state.steps.get(&peer_id).copied(); - - if let Some(current) = current - && step < current - { - return Err(Error::PeerStepBehind); - } + match state.steps.get(&peer_id).copied() { + Some(current) => { + if step < current { + return Err(Error::PeerStepBehind); + } - let current_plus_two = current - .map(|current| { - current + let current_plus_two = current .checked_add(2) - .ok_or_else(|| Error::message("step overflow")) - }) - .transpose()?; - - if let Some(current_plus_two) = current_plus_two - && step > current_plus_two - { - return Err(Error::PeerStepAhead); - } - - if current.is_none() && !(0..=1).contains(&step) { - return Err(Error::AbnormalInitialStep); - } + .ok_or_else(|| Error::message("step overflow"))?; + if step > current_plus_two { + return Err(Error::PeerStepAhead); + } - if current == Some(step) { - return Ok(()); + if step == current { + return Ok(()); + } + } + None if !(0..=1).contains(&step) => { + return Err(Error::AbnormalInitialStep); + } + None => {} } state.steps.insert(peer_id, step); - drop(state); self.inner.notify.notify_waiters(); Ok(()) } - async fn mutate_state(&self, mutate: impl FnOnce(&mut ServerState) -> T) -> T { + async fn mutate_state(&self, mutate: impl FnOnce(&mut ServerState)) { let mut state = self.inner.state.write().await; - let result = mutate(&mut state); - drop(state); + mutate(&mut state); self.inner.notify.notify_waiters(); - result } } From 55356758d7b0a771fcfda234c80cee2b17ac5db9 Mon Sep 17 00:00:00 2001 From: Quang Le Date: Wed, 1 Apr 2026 17:13:20 +0700 Subject: [PATCH 07/14] fix: don't need to self-manage pending dial --- crates/dkg/src/sync/behaviour.rs | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/crates/dkg/src/sync/behaviour.rs b/crates/dkg/src/sync/behaviour.rs index 2d00c532..f70e1847 100644 --- a/crates/dkg/src/sync/behaviour.rs +++ b/crates/dkg/src/sync/behaviour.rs @@ -1,5 +1,5 @@ use std::{ - collections::{HashMap, HashSet, VecDeque}, + collections::{HashMap, VecDeque}, task::{Context, Poll}, }; @@ -27,7 +27,6 @@ pub struct Behaviour { clients: HashMap, p2p_context: P2PContext, command_rx: mpsc::UnboundedReceiver, - pending_dials: HashSet, pending_events: VecDeque>>, } @@ -47,7 +46,6 @@ impl Behaviour { .collect(), p2p_context, command_rx, - pending_dials: HashSet::new(), pending_events: VecDeque::new(), } } @@ -77,7 +75,6 @@ impl Behaviour { .peer_store_lock() .connections_to_peer(&peer_id) .is_empty() - || !self.pending_dials.insert(peer_id) { return; } @@ -120,9 +117,6 @@ impl NetworkBehaviour for Behaviour { fn on_swarm_event(&mut self, event: FromSwarm) { match event { - FromSwarm::ConnectionEstablished(event) => { - self.pending_dials.remove(&event.peer_id); - } FromSwarm::ConnectionClosed(event) => { if event.remaining_established > 0 { return; @@ -131,18 +125,11 @@ impl NetworkBehaviour for Behaviour { // TODO: Go retries sync client connections until reconnect is disabled. // Re-queue active clients here (and on DialFailure below) so peers that // restart before initial cluster sync can be dialed again. - self.pending_dials.remove(&event.peer_id); - if let Some(client) = self.clients.get(&event.peer_id) { client.set_connected(false); client.release_outbound(); } } - FromSwarm::DialFailure(event) => { - if let Some(peer_id) = event.peer_id { - self.pending_dials.remove(&peer_id); - } - } _ => {} } } From 481a77613aa2dbf2652ff9019816b6a345b2784d Mon Sep 17 00:00:00 2001 From: Quang Le Date: Wed, 1 Apr 2026 17:38:30 +0700 Subject: [PATCH 08/14] fix: simplify code write/read proto --- crates/dkg/src/sync/error.rs | 4 -- crates/dkg/src/sync/handler.rs | 30 +++++---- crates/dkg/src/sync/protocol.rs | 110 +------------------------------- crates/p2p/src/proto.rs | 34 ++++++---- 4 files changed, 42 insertions(+), 136 deletions(-) diff --git a/crates/dkg/src/sync/error.rs b/crates/dkg/src/sync/error.rs index bb487cee..75e3d718 100644 --- a/crates/dkg/src/sync/error.rs +++ b/crates/dkg/src/sync/error.rs @@ -51,10 +51,6 @@ pub enum Error { #[error("protocol negotiation failed")] Unsupported, - /// A message length prefix was invalid. - #[error("invalid sized protobuf length: {0}")] - InvalidMessageLength(i64), - /// Failed to parse the peer version. #[error("parse peer version: {0}")] ParsePeerVersion(String), diff --git a/crates/dkg/src/sync/handler.rs b/crates/dkg/src/sync/handler.rs index 7b74f8eb..64c90c9f 100644 --- a/crates/dkg/src/sync/handler.rs +++ b/crates/dkg/src/sync/handler.rs @@ -261,7 +261,6 @@ impl ConnectionHandler for Handler { } async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit { - let mut first = true; let mut interval = tokio::time::interval(client.period()); let hash_signature = prost::bytes::Bytes::from(client.hash_sig().to_vec()); let version = client.version().to_string(); @@ -269,11 +268,7 @@ async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit client.set_connected(true); loop { - if first { - first = false; - } else { - interval.tick().await; - } + interval.tick().await; let shutdown = client.shutdown_requested(); let timestamp = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { @@ -293,11 +288,16 @@ async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit step: client.step(), }; - let response = async { - protocol::write_sync_request(&mut stream, &request).await?; - protocol::read_sync_response(&mut stream).await - } - .await; + let response: Result = + match pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &request) + .await + .map_err(Error::io) + { + Ok(()) => pluto_p2p::proto::read_fixed_size_protobuf(&mut stream) + .await + .map_err(Error::io), + Err(error) => Err(error), + }; let response = match response { Ok(response) => response, @@ -335,7 +335,9 @@ async fn handle_inbound_stream(peer_id: PeerId, server: Server, mut stream: Stre let public_key = pluto_p2p::peer::peer_id_to_libp2p_pk(&peer_id).map_err(Error::peer)?; loop { - let message = protocol::read_sync_request(&mut stream).await?; + let message: MsgSync = pluto_p2p::proto::read_fixed_size_protobuf(&mut stream) + .await + .map_err(Error::io)?; let mut response = MsgSyncResponse { sync_timestamp: message.timestamp, error: String::new(), @@ -367,7 +369,9 @@ async fn handle_inbound_stream(peer_id: PeerId, server: Server, mut stream: Stre server.update_step(peer_id, message.step).await?; - protocol::write_sync_response(&mut stream, &response).await?; + pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &response) + .await + .map_err(Error::io)?; if message.shutdown { server.set_shutdown(peer_id).await; diff --git a/crates/dkg/src/sync/protocol.rs b/crates/dkg/src/sync/protocol.rs index 713116d4..7aa0bf9b 100644 --- a/crates/dkg/src/sync/protocol.rs +++ b/crates/dkg/src/sync/protocol.rs @@ -1,15 +1,9 @@ -//! Wire protocol helpers for the DKG sync protocol. +//! Protocol helpers for the DKG sync protocol. -use futures::{AsyncRead, AsyncWrite}; -use libp2p::{ - Stream, - identity::{Keypair, PublicKey}, -}; +use libp2p::identity::{Keypair, PublicKey}; use pluto_core::version::SemVer; -use pluto_p2p::proto::{self, InvalidFixedSizeLength}; -use prost::Message; -use crate::dkgpb::v1::sync::{MsgSync, MsgSyncResponse}; +use crate::dkgpb::v1::sync::MsgSync; use super::error::{Error, Result}; @@ -30,59 +24,6 @@ pub fn sign_definition_hash(secret: &k256::SecretKey, def_hash: &[u8]) -> Result .map_err(|error| Error::SignDefinitionHash(error.to_string())) } -/// Writes a size-prefixed protobuf to the stream. -pub async fn write_sized_protobuf(writer: &mut W, msg: &M) -> Result<()> -where - M: Message, - W: AsyncWrite + Unpin, -{ - let mut buf = Vec::new(); - msg.encode(&mut buf).map_err(Error::encode)?; - proto::write_fixed_size_delimited(writer, &buf) - .await - .map_err(Error::io) -} - -/// Reads a size-prefixed protobuf from the stream. -pub async fn read_sized_protobuf(reader: &mut R) -> Result -where - M: Message + Default, - R: AsyncRead + Unpin, -{ - let buf = proto::read_fixed_size_delimited(reader) - .await - .map_err(|error| { - if let Some(source) = error.get_ref() - && let Some(length) = source.downcast_ref::() - { - return Error::InvalidMessageLength(length.0); - } - - Error::io(error) - })?; - M::decode(buf.as_slice()).map_err(Error::decode) -} - -/// Reads a sync request from the stream. -pub async fn read_sync_request(stream: &mut Stream) -> Result { - read_sized_protobuf(stream).await -} - -/// Writes a sync request to the stream. -pub async fn write_sync_request(stream: &mut Stream, message: &MsgSync) -> Result<()> { - write_sized_protobuf(stream, message).await -} - -/// Reads a sync response from the stream. -pub async fn read_sync_response(stream: &mut Stream) -> Result { - read_sized_protobuf(stream).await -} - -/// Writes a sync response to the stream. -pub async fn write_sync_response(stream: &mut Stream, message: &MsgSyncResponse) -> Result<()> { - write_sized_protobuf(stream, message).await -} - /// Validates a sync request for a known peer public key. pub fn validate_request_with_public_key( def_hash: &[u8], @@ -103,48 +44,3 @@ pub fn validate_request_with_public_key( Ok(()) } -#[cfg(test)] -mod tests { - use futures::{AsyncWriteExt, io::Cursor}; - - use super::*; - - #[tokio::test] - async fn sized_proto_round_trip() { - let message = MsgSync { - timestamp: Some(prost_types::Timestamp { - seconds: 1, - nanos: 2, - }), - hash_signature: vec![1, 2, 3].into(), - shutdown: true, - version: "v1.7".to_string(), - step: 3, - }; - let mut cursor = Cursor::new(Vec::new()); - write_sized_protobuf(&mut cursor, &message) - .await - .expect("writer should succeed"); - cursor.set_position(0); - let decoded = read_sized_protobuf::(&mut cursor) - .await - .expect("decode should succeed"); - - assert_eq!(decoded, message); - } - - #[tokio::test] - async fn negative_message_length_fails() { - let mut cursor = Cursor::new(Vec::new()); - cursor - .write_all(&(-1_i64).to_le_bytes()) - .await - .expect("writer should succeed"); - cursor.set_position(0); - - let error = read_sized_protobuf::(&mut cursor) - .await - .expect_err("negative sizes must fail"); - assert!(matches!(error, Error::InvalidMessageLength(-1))); - } -} diff --git a/crates/p2p/src/proto.rs b/crates/p2p/src/proto.rs index d49a3fdb..9d16cd0c 100644 --- a/crates/p2p/src/proto.rs +++ b/crates/p2p/src/proto.rs @@ -7,11 +7,6 @@ use unsigned_varint::aio::read_usize; /// Default maximum protobuf message size pub const MAX_MESSAGE_SIZE: usize = 128 << 20; -/// Error returned when a fixed-size frame uses a negative length prefix. -#[derive(Debug, thiserror::Error)] -#[error("invalid fixed-size frame length: {0}")] -pub struct InvalidFixedSizeLength(pub i64); - /// Writes a length-delimited payload to the stream. /// /// Format: `[unsigned varint length][payload bytes]` @@ -78,7 +73,7 @@ pub async fn read_fixed_size_delimited( if len < 0 { return Err(io::Error::new( io::ErrorKind::InvalidData, - InvalidFixedSizeLength(len), + format!("invalid data length"), )); } @@ -102,6 +97,18 @@ pub async fn write_protobuf( write_length_delimited(stream, &buf).await } +/// Encodes a protobuf message and writes it with fixed-size framing. +pub async fn write_fixed_size_protobuf( + stream: &mut S, + msg: &M, +) -> io::Result<()> { + let mut buf = Vec::with_capacity(msg.encoded_len()); + msg.encode(&mut buf) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + + write_fixed_size_delimited(stream, &buf).await +} + /// Reads a protobuf message using the default maximum message size. pub async fn read_protobuf( stream: &mut S, @@ -118,6 +125,14 @@ pub async fn read_protobuf_with_max_size( + stream: &mut S, +) -> io::Result { + let buf = read_fixed_size_delimited(stream).await?; + M::decode(&buf[..]).map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)) +} + #[cfg(test)] mod tests { use futures::io::Cursor; @@ -148,11 +163,6 @@ mod tests { let error = read_fixed_size_delimited(&mut cursor) .await .expect_err("negative sizes must fail"); - let size = error - .get_ref() - .and_then(|source| source.downcast_ref::()) - .map(|error| error.0); - - assert_eq!(size, Some(-1)); + assert_eq!(error.kind(), io::ErrorKind::InvalidData); } } From 25e2294cacf702aba65cabdbe713a867a50d184d Mon Sep 17 00:00:00 2001 From: Quang Le Date: Fri, 3 Apr 2026 14:21:41 +0700 Subject: [PATCH 09/14] feat: add events --- crates/dkg/src/sync/behaviour.rs | 24 ++++++++++++-- crates/dkg/src/sync/handler.rs | 54 ++++++++++++++++++++++++++++---- crates/dkg/src/sync/server.rs | 6 ++-- 3 files changed, 73 insertions(+), 11 deletions(-) diff --git a/crates/dkg/src/sync/behaviour.rs b/crates/dkg/src/sync/behaviour.rs index f70e1847..6a8588b6 100644 --- a/crates/dkg/src/sync/behaviour.rs +++ b/crates/dkg/src/sync/behaviour.rs @@ -19,7 +19,27 @@ use super::{Command, client::Client, handler::Handler, server::Server}; /// Event emitted by the sync behaviour. #[derive(Debug, Clone)] -pub enum Event {} +pub enum Event { + /// A peer advanced to a new sync step. + PeerStepUpdated { + /// The peer whose step was updated. + peer_id: PeerId, + /// The updated step. + step: i64, + }, + /// A peer requested graceful shutdown through sync. + PeerShutdownObserved { + /// The peer that requested shutdown. + peer_id: PeerId, + }, + /// A peer sent a sync message that failed. + SyncRejected { + /// The peer whose sync message was rejected. + peer_id: PeerId, + /// The validation error. + error: super::error::Error, + }, +} /// Swarm behaviour backing the DKG sync protocol. pub struct Behaviour { @@ -141,7 +161,7 @@ impl NetworkBehaviour for Behaviour { event: THandlerOutEvent, ) { match event { - Either::Left(event) => match event {}, + Either::Left(event) => self.pending_events.push_back(ToSwarm::GenerateEvent(event)), Either::Right(unreachable) => match unreachable {}, } } diff --git a/crates/dkg/src/sync/handler.rs b/crates/dkg/src/sync/handler.rs index 64c90c9f..fc2b89ae 100644 --- a/crates/dkg/src/sync/handler.rs +++ b/crates/dkg/src/sync/handler.rs @@ -20,6 +20,7 @@ use libp2p::{ }, }; use prost_types::Timestamp; +use tokio::sync::mpsc; use tokio::time::Sleep; use tracing::{debug, info, warn}; @@ -30,6 +31,7 @@ use super::{ error::{Error, Result}, protocol, server::Server, + Event, }; const INITIAL_BACKOFF: Duration = Duration::from_millis(100); @@ -37,6 +39,9 @@ const MAX_BACKOFF: Duration = Duration::from_secs(1); type InboundFuture = BoxFuture<'static, Result<()>>; +/// Protocol-level events emitted by the sync handler. +pub type OutEvent = Event; + enum OutboundState { Idle, OpenStream, @@ -57,6 +62,8 @@ pub struct Handler { server: Server, client: Option, inbound: Option, + inbound_events_tx: mpsc::UnboundedSender, + inbound_events_rx: mpsc::UnboundedReceiver, outbound: OutboundState, backoff: Duration, } @@ -64,11 +71,14 @@ pub struct Handler { impl Handler { /// Creates a new handler for a single connection. pub fn new(peer_id: PeerId, server: Server, client: Option) -> Self { + let (inbound_events_tx, inbound_events_rx) = mpsc::unbounded_channel(); Self { peer_id, server, client, inbound: None, + inbound_events_tx, + inbound_events_rx, outbound: OutboundState::Idle, backoff: INITIAL_BACKOFF, } @@ -86,7 +96,7 @@ impl Handler { fn try_request_outbound( &mut self, - ) -> Option, (), Infallible>> { + ) -> Option, (), OutEvent>> { let client = self.client.as_ref()?; if !client.should_run() || !client.try_claim_outbound() { return None; @@ -137,7 +147,7 @@ impl ConnectionHandler for Handler { type InboundProtocol = ReadyUpgrade; type OutboundOpenInfo = (); type OutboundProtocol = ReadyUpgrade; - type ToBehaviour = Infallible; + type ToBehaviour = OutEvent; fn listen_protocol(&self) -> SubstreamProtocol { self.substream_protocol() @@ -153,6 +163,10 @@ impl ConnectionHandler for Handler { ) -> Poll< ConnectionHandlerEvent, > { + if let Poll::Ready(Some(event)) = self.inbound_events_rx.poll_recv(cx) { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + } + if let Some(inbound) = self.inbound.as_mut() { match inbound.poll_unpin(cx) { Poll::Pending => {} @@ -234,8 +248,15 @@ impl ConnectionHandler for Handler { .. }) => { stream.ignore_for_keep_alive(); - self.inbound = - Some(handle_inbound_stream(self.peer_id, self.server.clone(), stream).boxed()); + self.inbound = Some( + handle_inbound_stream( + self.peer_id, + self.server.clone(), + self.inbound_events_tx.clone(), + stream, + ) + .boxed(), + ); } ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { protocol: mut stream, @@ -327,7 +348,12 @@ async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit } } -async fn handle_inbound_stream(peer_id: PeerId, server: Server, mut stream: Stream) -> Result<()> { +async fn handle_inbound_stream( + peer_id: PeerId, + server: Server, + inbound_events_tx: mpsc::UnboundedSender, + mut stream: Stream, +) -> Result<()> { if !server.is_started() { return Err(Error::ServerNotStarted); } @@ -349,6 +375,10 @@ async fn handle_inbound_stream(peer_id: PeerId, server: Server, mut stream: Stre &public_key, &message, ) { + send_inbound_event(&inbound_events_tx, OutEvent::SyncRejected { + peer_id, + error: error.clone(), + }); server .set_err(Error::message(format!( "invalid sync message: peer={peer_id} err={error}" @@ -367,16 +397,28 @@ async fn handle_inbound_stream(peer_id: PeerId, server: Server, mut stream: Stre } } - server.update_step(peer_id, message.step).await?; + if server.update_step(peer_id, message.step).await? { + send_inbound_event(&inbound_events_tx, OutEvent::PeerStepUpdated { + peer_id, + step: message.step, + }); + } pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &response) .await .map_err(Error::io)?; if message.shutdown { + send_inbound_event(&inbound_events_tx, OutEvent::PeerShutdownObserved { peer_id }); server.set_shutdown(peer_id).await; server.clear_connected(peer_id).await; return Ok(()); } } } + +fn send_inbound_event(inbound_events_tx: &mpsc::UnboundedSender, event: OutEvent) { + if let Err(error) = inbound_events_tx.send(event) { + tracing::error!(err = %error, "Failed to deliver inbound sync event"); + } +} diff --git a/crates/dkg/src/sync/server.rs b/crates/dkg/src/sync/server.rs index 02bea40b..0c8e7372 100644 --- a/crates/dkg/src/sync/server.rs +++ b/crates/dkg/src/sync/server.rs @@ -209,7 +209,7 @@ impl Server { .await; } - pub(crate) async fn update_step(&self, peer_id: PeerId, step: i64) -> Result<()> { + pub(crate) async fn update_step(&self, peer_id: PeerId, step: i64) -> Result { let mut state = self.inner.state.write().await; match state.steps.get(&peer_id).copied() { Some(current) => { @@ -225,7 +225,7 @@ impl Server { } if step == current { - return Ok(()); + return Ok(false); } } None if !(0..=1).contains(&step) => { @@ -236,7 +236,7 @@ impl Server { state.steps.insert(peer_id, step); self.inner.notify.notify_waiters(); - Ok(()) + Ok(true) } async fn mutate_state(&self, mutate: impl FnOnce(&mut ServerState)) { From d4d5c8d91c05824d3c88745eb4cd9de22db82fea Mon Sep 17 00:00:00 2001 From: Quang Le Date: Fri, 3 Apr 2026 15:45:58 +0700 Subject: [PATCH 10/14] feat: implement example sync --- crates/dkg/examples/sync.rs | 756 ++++++++++++++++++++++++++++++++ crates/dkg/src/sync/client.rs | 2 +- crates/dkg/src/sync/error.rs | 68 +-- crates/dkg/src/sync/handler.rs | 57 ++- crates/dkg/src/sync/protocol.rs | 5 +- crates/dkg/src/sync/server.rs | 6 +- 6 files changed, 810 insertions(+), 84 deletions(-) create mode 100644 crates/dkg/examples/sync.rs diff --git a/crates/dkg/examples/sync.rs b/crates/dkg/examples/sync.rs new file mode 100644 index 00000000..5b408906 --- /dev/null +++ b/crates/dkg/examples/sync.rs @@ -0,0 +1,756 @@ +//! Relay-based example for the DKG sync protocol. +//! +//! This example follows the same high-level shape as the `bcast` example: +//! - load a local private key from a data directory +//! - load cluster peers from `cluster-lock.json` +//! - resolve relay URLs with `bootnode::new_relays` +//! - create relay reservations and relay routing +//! - run `sync` over relay-mediated connectivity +//! +//! To try it locally: +//! +//! ```text +//! # Terminal 1: start a relay server +//! cargo run -p pluto-relay-server --example relay_server +//! +//! # Terminals 2-4: run three node directories from the same cluster +//! cargo run -p pluto-dkg --example sync -- \ +//! --relays http://127.0.0.1:8888 \ +//! --data-dir /path/to/node0 +//! +//! cargo run -p pluto-dkg --example sync -- \ +//! --relays http://127.0.0.1:8888 \ +//! --data-dir /path/to/node1 +//! +//! cargo run -p pluto-dkg --example sync -- \ +//! --relays http://127.0.0.1:8888 \ +//! --data-dir /path/to/node2 +//! ``` +//! +//! Assumption: +//! - the three data directories already exist +//! - each one belongs to one node in the same cluster +//! +//! Required files in each data directory: +//! - `charon-enr-private-key` +//! - `cluster-lock.json` +//! +//! Expected flow: +//! 1. Each node loads the same cluster peer order from the lock file. +//! 2. Nodes resolve the configured relays and establish relay reservations. +//! 3. The relay router dials known cluster peers through relay circuits. +//! 4. Each node starts one sync client per remote peer. +//! 5. Once all clients are connected, the demo advances through steps 1 and 2. +//! 6. The demo keeps the sync clients running in steady state. +//! 7. Press `Ctrl+C` to trigger graceful shutdown and wait for all remote +//! shutdowns. +//! +//! Success signals: +//! - `Relay reservation accepted` +//! - `Connection established` with `peer_type="CLUSTER"` +//! - `All sync clients connected` +//! - `Sync step reached` +//! - `Sync demo is now idling until Ctrl+C` +//! - `Sync steady-state heartbeat` +//! - `All peers reported shutdown` +//! +//! Transient relay warnings can occur during startup and reconnects. The demo +//! is healthy once all cluster peers are connected and the sync steps complete. +#![allow(missing_docs)] + +use std::{ + collections::{HashMap, HashSet}, + path::PathBuf, + str::FromStr, + time::Duration, +}; + +use anyhow::{Context as _, Result}; +use clap::Parser; +use futures::StreamExt; +use libp2p::{ + PeerId, identify, ping, + relay::{self}, + swarm::{NetworkBehaviour, SwarmEvent}, +}; +use pluto_cluster::lock::Lock; +use pluto_core::version::VERSION; +use pluto_dkg::sync::{self, Client, Server}; +use pluto_p2p::{ + behaviours::pluto::PlutoBehaviourEvent, + bootnode, + config::P2PConfig, + gater, k1, + p2p::{Node, NodeType}, + p2p_context::P2PContext, + relay::{MutableRelayReservation, RelayRouter}, +}; +use pluto_tracing::TracingConfig; +use tokio::{fs, signal, task::JoinHandle}; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error, info}; + +#[derive(NetworkBehaviour)] +struct ExampleBehaviour { + relay: relay::client::Behaviour, + relay_reservation: MutableRelayReservation, + relay_router: RelayRouter, + sync: sync::Behaviour, +} + +#[derive(Debug, Parser)] +#[command(name = "sync-example")] +#[command(about = "Run a relay-based DKG sync demo node")] +struct Args { + /// Relay URLs or relay multiaddrs to use. + #[arg(long, value_delimiter = ',')] + relays: Vec, + + /// Data directory containing `charon-enr-private-key` and + /// `cluster-lock.json`. + #[arg(long)] + data_dir: PathBuf, + + /// Additional known peers to allow and route via relays. + #[arg(long, value_delimiter = ',')] + known_peers: Vec, + + /// Whether to filter private addresses from advertisements. + #[arg(short, long, default_value_t = false)] + filter_private_addrs: bool, + + /// The external IP address of the node. + #[arg(long)] + external_ip: Option, + + /// The external host of the node. + #[arg(long)] + external_host: Option, + + /// TCP addresses to listen on. + #[arg(long)] + tcp_addrs: Vec, + + /// UDP addresses to listen on. + #[arg(long)] + udp_addrs: Vec, + + /// Whether to disable reuse port. + #[arg(long, default_value_t = false)] + disable_reuse_port: bool, +} + +#[derive(Debug, Clone)] +struct ClusterInfo { + peers: Vec, + indices: HashMap, + local_peer_id: PeerId, + local_node_number: u32, +} + +impl ClusterInfo { + fn expected_connections(&self) -> usize { + self.peers.len().saturating_sub(1) + } + + fn peer_label(&self, peer_id: &PeerId) -> String { + match self.indices.get(peer_id) { + Some(index) => format!( + "node={} peer_id={peer_id}", + index.checked_add(1).unwrap_or(*index) + ), + None => format!("peer_id={peer_id}"), + } + } + + fn peer_labels_where(&self, predicate: impl Fn(&PeerId) -> bool) -> Vec { + self.peers + .iter() + .filter(|peer_id| predicate(peer_id)) + .map(|peer_id| self.peer_label(peer_id)) + .collect() + } + + fn connected_peers(&self, connected: &HashSet) -> Vec { + self.peer_labels_where(|peer_id| connected.contains(peer_id)) + } + + fn missing_peers(&self, connected: &HashSet) -> Vec { + self.peer_labels_where(|peer_id| { + *peer_id != self.local_peer_id && !connected.contains(peer_id) + }) + } +} + +fn peer_type( + peer_id: &PeerId, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) -> &'static str { + if relay_peer_ids.contains(peer_id) { + "RELAY" + } else if cluster_info.indices.contains_key(peer_id) { + "CLUSTER" + } else { + "UNKNOWN" + } +} + +fn merge_known_peers( + cluster_peers: &[PeerId], + configured_known_peers: &[String], +) -> Result> { + let capacity = cluster_peers + .len() + .checked_add(configured_known_peers.len()) + .context("known peer capacity overflow")?; + let mut known_peers = Vec::with_capacity(capacity); + known_peers.extend(cluster_peers.iter().copied()); + let mut known_peer_ids = HashSet::with_capacity(capacity); + known_peer_ids.extend(known_peers.iter().copied()); + + for peer in configured_known_peers { + let peer_id = PeerId::from_str(peer) + .with_context(|| format!("failed to parse known peer id: {peer}"))?; + if known_peer_ids.insert(peer_id) { + known_peers.push(peer_id); + } + } + + Ok(known_peers) +} + +fn local_node_number(cluster_peers: &[PeerId], local_peer_id: PeerId) -> Result { + let index = cluster_peers + .iter() + .position(|peer_id| peer_id == &local_peer_id) + .context("local peer id is not present in the cluster lock")?; + let node_number = index + .checked_add(1) + .context("cluster peer index overflow")?; + u32::try_from(node_number).context("cluster peer index does not fit in u32") +} + +fn endpoint_address(endpoint: &libp2p::core::ConnectedPoint) -> &libp2p::Multiaddr { + match endpoint { + libp2p::core::ConnectedPoint::Dialer { address, .. } => address, + libp2p::core::ConnectedPoint::Listener { send_back_addr, .. } => send_back_addr, + } +} + +fn connection_log_fields<'a>( + peer_id: PeerId, + endpoint: &'a libp2p::core::ConnectedPoint, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) -> (String, &'static str, &'a libp2p::Multiaddr) { + ( + cluster_info.peer_label(&peer_id), + peer_type(&peer_id, relay_peer_ids, cluster_info), + endpoint_address(endpoint), + ) +} + +fn log_relay_event(relay_event: relay::client::Event, cluster_info: &ClusterInfo) { + match relay_event { + relay::client::Event::ReservationReqAccepted { + relay_peer_id, + renewal, + limit, + } => { + debug!( + relay_peer_id = %relay_peer_id, + renewal, + limit = ?limit, + "Relay reservation accepted" + ); + } + relay::client::Event::OutboundCircuitEstablished { + relay_peer_id, + limit, + } => { + debug!( + relay_peer_id = %relay_peer_id, + limit = ?limit, + "Outbound relay circuit established" + ); + } + relay::client::Event::InboundCircuitEstablished { src_peer_id, limit } => { + debug!( + src_peer_id = %src_peer_id, + peer_label = %cluster_info.peer_label(&src_peer_id), + limit = ?limit, + "Inbound relay circuit established" + ); + } + } +} + +fn log_connection_established( + peer_id: PeerId, + endpoint: &libp2p::core::ConnectedPoint, + num_established: std::num::NonZero, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) { + let (peer_label, peer_type, address) = + connection_log_fields(peer_id, endpoint, relay_peer_ids, cluster_info); + debug!( + peer_id = %peer_id, + peer_label = %peer_label, + peer_type, + address = %address, + num_established = num_established.get(), + "Connection established" + ); +} + +fn log_connection_closed( + peer_id: PeerId, + endpoint: &libp2p::core::ConnectedPoint, + num_established: u32, + cause: Option<&libp2p::swarm::ConnectionError>, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) { + let (peer_label, peer_type, address) = + connection_log_fields(peer_id, endpoint, relay_peer_ids, cluster_info); + debug!( + peer_id = %peer_id, + peer_label = %peer_label, + peer_type, + address = %address, + num_established, + cause = ?cause, + "Connection closed" + ); +} + +fn log_identify_event( + peer_id: PeerId, + info: identify::Info, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) { + debug!( + peer_id = %peer_id, + peer_type = peer_type(&peer_id, relay_peer_ids, cluster_info), + agent_version = %info.agent_version, + protocol_version = %info.protocol_version, + num_addresses = info.listen_addrs.len(), + "Received identify from peer" + ); +} + +fn log_ping_event( + peer: PeerId, + result: Result, + relay_peer_ids: &HashSet, + cluster_info: &ClusterInfo, +) { + match result { + Ok(rtt) => debug!( + peer_id = %peer, + peer_type = peer_type(&peer, relay_peer_ids, cluster_info), + rtt = ?rtt, + "Received ping" + ), + Err(error) => debug!( + peer_id = %peer, + peer_type = peer_type(&peer, relay_peer_ids, cluster_info), + err = %error, + "Ping failed" + ), + } +} + +fn print_cluster_overview(cluster_info: &ClusterInfo) { + info!("Cluster peer order:"); + for (index, peer_id) in cluster_info.peers.iter().enumerate() { + let local_marker = if *peer_id == cluster_info.local_peer_id { + " (local)" + } else { + "" + }; + info!( + peer_index = index.checked_add(1).unwrap_or(index), + peer_id = %peer_id, + local = %local_marker, + "Cluster peer" + ); + } +} + +async fn run_sync( + server: Server, + clients: Vec, + cluster_info: ClusterInfo, + cancellation: CancellationToken, +) -> Result<()> { + server.start(); + info!( + local_node = cluster_info.local_node_number, + expected_clients = clients.len(), + "Started sync server" + ); + + let mut client_joins = Vec::with_capacity(clients.len()); + for client in &clients { + let client = client.clone(); + let cancellation = cancellation.child_token(); + client_joins.push(tokio::spawn(async move { client.run(cancellation).await })); + } + + // First wait until all local per-peer sync clients report connected. + // The shared server barrier below then confirms the whole cluster has + // observed all peer connections. + let mut previous_connected = None; + loop { + let connected = clients + .iter() + .filter(|client| client.is_connected()) + .count(); + if previous_connected != Some(connected) { + info!( + local_node = cluster_info.local_node_number, + connected, + expected = clients.len(), + "Sync client connectivity update" + ); + previous_connected = Some(connected); + } + + if connected == clients.len() { + break; + } + + tokio::select! { + _ = cancellation.cancelled() => break, + _ = tokio::time::sleep(Duration::from_millis(100)) => {} + } + } + + // Once all local sync clients are connected, wait for the shared server + // to observe the full cluster barrier and then drive the demo through a + // couple of synchronized steps. + if !cancellation.is_cancelled() { + info!( + local_node = cluster_info.local_node_number, + connected = clients.len(), + "All sync clients connected" + ); + + for client in &clients { + client.disable_reconnect(); + } + + match server.await_all_connected(cancellation.child_token()).await { + Ok(()) | Err(sync::Error::Canceled) => {} + Err(error) => return Err(anyhow::anyhow!(error.to_string())), + } + + for step in 1_i64..=2 { + for client in &clients { + client.set_step(step); + } + info!( + local_node = cluster_info.local_node_number, + step, "Waiting for sync step" + ); + match server + .await_all_at_step(step, cancellation.child_token()) + .await + { + Ok(()) => { + info!( + local_node = cluster_info.local_node_number, + step, "Sync step reached" + ); + } + Err(sync::Error::Canceled) => break, + Err(error) => return Err(anyhow::anyhow!(error.to_string())), + } + } + } + + if !cancellation.is_cancelled() { + info!( + local_node = cluster_info.local_node_number, + "Sync demo is now idling until Ctrl+C" + ); + + let mut heartbeat = tokio::time::interval(Duration::from_secs(5)); + loop { + tokio::select! { + _ = cancellation.cancelled() => break, + _ = heartbeat.tick() => { + let connected = clients.iter().filter(|client| client.is_connected()).count(); + info!( + local_node = cluster_info.local_node_number, + connected, + expected = clients.len(), + "Sync steady-state heartbeat" + ); + } + } + } + } + + let shutdown_cancellation = CancellationToken::new(); + info!( + local_node = cluster_info.local_node_number, + "Starting graceful shutdown" + ); + for client in &clients { + client.shutdown(shutdown_cancellation.child_token()).await?; + } + + server + .await_all_shutdown(shutdown_cancellation.child_token()) + .await?; + info!( + local_node = cluster_info.local_node_number, + "All peers reported shutdown" + ); + + for join in client_joins { + match join.await { + Ok(Ok(())) => {} + Ok(Err(error)) => return Err(anyhow::anyhow!(error.to_string())), + Err(error) => return Err(anyhow::anyhow!(error.to_string())), + } + } + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<()> { + pluto_tracing::init(&TracingConfig::default()).expect("failed to initialize tracing"); + + let args = Args::parse(); + let key = k1::load_priv_key(&args.data_dir).expect("Failed to load private key"); + let local_peer_id = pluto_p2p::peer::peer_id_from_key(key.public_key()) + .expect("Failed to derive local peer ID"); + + let lock_path = args.data_dir.join("cluster-lock.json"); + let lock_str = fs::read_to_string(&lock_path) + .await + .expect("Failed to load lock"); + let lock: Lock = serde_json::from_str(&lock_str).expect("Failed to parse lock"); + + let cluster_peers = lock.peer_ids().expect("Failed to get lock peer IDs"); + let local_node_number = local_node_number(&cluster_peers, local_peer_id) + .expect("Failed to derive local node number"); + let mut indices = HashMap::with_capacity(cluster_peers.len()); + for (index, peer_id) in cluster_peers.iter().copied().enumerate() { + indices.insert(peer_id, index); + } + let cluster_info = ClusterInfo { + peers: cluster_peers.clone(), + indices, + local_peer_id, + local_node_number, + }; + + let cancellation = CancellationToken::new(); + let lock_hash_hex = hex::encode(&lock.lock_hash); + let relays = bootnode::new_relays(cancellation.child_token(), &args.relays, &lock_hash_hex) + .await + .context("failed to resolve relays")?; + let relay_peer_ids = relays + .iter() + .filter_map(|relay| relay.peer().ok().flatten().map(|peer| peer.id)) + .collect::>(); + + let known_peers = merge_known_peers(&cluster_peers, &args.known_peers)?; + + let conn_gater = gater::ConnGater::new( + gater::Config::closed() + .with_relays(relays.clone()) + .with_peer_ids(known_peers.clone()), + ); + + let p2p_config = P2PConfig { + relays: vec![], + external_ip: args.external_ip, + external_host: args.external_host, + tcp_addrs: args.tcp_addrs, + udp_addrs: args.udp_addrs, + disable_reuse_port: args.disable_reuse_port, + }; + + let version = VERSION.to_minor(); + let p2p_context = P2PContext::new(known_peers.clone()); + p2p_context.set_local_peer_id(local_peer_id); + let (sync_behaviour, server, clients) = sync::new( + cluster_peers.clone(), + p2p_context.clone(), + &key, + lock.lock_hash.clone(), + version, + )?; + + let mut node: Node = Node::new( + p2p_config, + key, + NodeType::QUIC, + args.filter_private_addrs, + known_peers, + { + let p2p_context = p2p_context.clone(); + move |builder, keypair, relay_client| { + let p2p_context = p2p_context.clone(); + let local_peer_id = keypair.public().to_peer_id(); + + builder + .with_p2p_context(p2p_context.clone()) + .with_gater(conn_gater) + .with_inner(ExampleBehaviour { + relay: relay_client, + relay_reservation: MutableRelayReservation::new(relays.clone()), + relay_router: RelayRouter::new(relays.clone(), p2p_context, local_peer_id), + sync: sync_behaviour, + }) + } + }, + )?; + + info!( + local_peer_id = %local_peer_id, + local_node = local_node_number, + data_dir = %args.data_dir.display(), + "Started sync example" + ); + print_cluster_overview(&cluster_info); + + let mut connected_cluster_peers = + HashSet::::with_capacity(cluster_info.expected_connections()); + let mut demo_task: JoinHandle> = tokio::spawn(run_sync( + server, + clients, + cluster_info.clone(), + cancellation.child_token(), + )); + + loop { + tokio::select! { + event = node.select_next_some() => { + match event { + SwarmEvent::Behaviour(PlutoBehaviourEvent::Inner( + ExampleBehaviourEvent::Relay(relay_event), + )) => { + log_relay_event(relay_event, &cluster_info); + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Ping(ping::Event { + peer, + result, + .. + })) => { + log_ping_event(peer, result, &relay_peer_ids, &cluster_info); + } + SwarmEvent::Behaviour(PlutoBehaviourEvent::Identify( + identify::Event::Received { peer_id, info, .. }, + )) => { + log_identify_event(peer_id, info, &relay_peer_ids, &cluster_info); + } + SwarmEvent::ConnectionEstablished { + peer_id, + endpoint, + num_established, + .. + } => { + log_connection_established( + peer_id, + &endpoint, + num_established, + &relay_peer_ids, + &cluster_info, + ); + if cluster_info.indices.contains_key(&peer_id) { + connected_cluster_peers.insert(peer_id); + debug!( + connected = connected_cluster_peers.len(), + expected = cluster_info.expected_connections(), + connected_peers = ?cluster_info.connected_peers(&connected_cluster_peers), + missing_peers = ?cluster_info.missing_peers(&connected_cluster_peers), + "Cluster connectivity update" + ); + } + } + SwarmEvent::ConnectionClosed { + peer_id, + endpoint, + num_established, + cause, + .. + } => { + log_connection_closed( + peer_id, + &endpoint, + num_established, + cause.as_ref(), + &relay_peer_ids, + &cluster_info, + ); + connected_cluster_peers.remove(&peer_id); + } + SwarmEvent::OutgoingConnectionError { + peer_id, + connection_id, + error: err, + } => { + debug!( + ?peer_id, + ?connection_id, + %err, + "Outgoing connection error" + ); + } + SwarmEvent::IncomingConnectionError { + connection_id, + local_addr, + send_back_addr, + error: err, + .. + } => { + debug!( + ?connection_id, + %local_addr, + %send_back_addr, + %err, + "Incoming connection error" + ); + } + SwarmEvent::NewListenAddr { address, .. } => { + debug!(%address, "Listening on address"); + } + _ => {} + } + } + result = &mut demo_task => { + match result { + Ok(Ok(())) => { + info!("Sync demo completed successfully"); + break; + } + Ok(Err(error)) => { + error!(err = %error, "Sync demo failed"); + break; + } + Err(error) => { + error!(err = %error, "Sync demo task failed"); + break; + } + } + } + _ = signal::ctrl_c() => { + info!("Ctrl+C received, shutting down"); + cancellation.cancel(); + } + } + } + + cancellation.cancel(); + Ok(()) +} diff --git a/crates/dkg/src/sync/client.rs b/crates/dkg/src/sync/client.rs index 1e47475e..5b767a01 100644 --- a/crates/dkg/src/sync/client.rs +++ b/crates/dkg/src/sync/client.rs @@ -195,7 +195,7 @@ impl Client { } changed = done_rx.changed() => { if changed.is_err() { - return Err(Error::message("sync client completion channel closed")); + return Err(Error::CompletionChannelClosed); } } } diff --git a/crates/dkg/src/sync/error.rs b/crates/dkg/src/sync/error.rs index 75e3d718..46f4487f 100644 --- a/crates/dkg/src/sync/error.rs +++ b/crates/dkg/src/sync/error.rs @@ -1,4 +1,4 @@ -use pluto_core::version::SemVer; +use libp2p::PeerId; /// Sync result type. pub type Result = std::result::Result; @@ -6,10 +6,6 @@ pub type Result = std::result::Result; /// Error type for the DKG sync protocol. #[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)] pub enum Error { - /// Generic message. - #[error("{0}")] - Message(String), - /// The sync client was canceled. #[error("sync client canceled")] Canceled, @@ -47,6 +43,10 @@ pub enum Error { #[error("peer step is too far ahead")] PeerStepTooFarAhead, + /// A checked step arithmetic operation overflowed. + #[error("step overflow")] + StepOverflow, + /// The stream protocol could not be negotiated. #[error("protocol negotiation failed")] Unsupported, @@ -83,52 +83,20 @@ pub enum Error { #[error("sync server not started")] ServerNotStarted, + /// The sync client completion channel was closed unexpectedly. + #[error("sync client completion channel closed")] + CompletionChannelClosed, + + /// An inbound sync message failed validation. + #[error("invalid sync message: peer={peer} err={error}")] + InvalidSyncMessage { + /// The peer whose message was invalid. + peer: PeerId, + /// The validation error. + error: String, + }, + /// The local peer ID was missing from the shared P2P context. #[error("local peer id missing from p2p context")] LocalPeerMissing, } - -impl Error { - /// Creates a new generic message error. - pub fn message(message: impl Into) -> Self { - Self::Message(message.into()) - } - - /// Creates an I/O error from the given source. - pub fn io(error: impl std::fmt::Display) -> Self { - Self::Io(error.to_string()) - } - - /// Creates a protobuf decode error from the given source. - pub fn decode(error: impl std::fmt::Display) -> Self { - Self::Decode(error.to_string()) - } - - /// Creates a protobuf encode error from the given source. - pub fn encode(error: impl std::fmt::Display) -> Self { - Self::Encode(error.to_string()) - } - - /// Creates a peer conversion error from the given source. - pub fn peer(error: impl std::fmt::Display) -> Self { - Self::Peer(error.to_string()) - } - - /// Creates a version mismatch error matching Go's wire string. - pub fn version_mismatch(expected: &SemVer, got: &str) -> Self { - Self::VersionMismatch { - expected: expected.to_string(), - got: got.to_string(), - } - } - - /// Returns true if the error should be treated like Go's relay reset path. - pub fn is_relay_error(&self) -> bool { - matches!(self, Self::Io(message) if { - let lowercase = message.to_ascii_lowercase(); - lowercase.contains("connection reset") - || lowercase.contains("resource scope closed") - || lowercase.contains("broken pipe") - }) - } -} diff --git a/crates/dkg/src/sync/handler.rs b/crates/dkg/src/sync/handler.rs index fc2b89ae..5c20dbff 100644 --- a/crates/dkg/src/sync/handler.rs +++ b/crates/dkg/src/sync/handler.rs @@ -52,7 +52,7 @@ enum OutboundState { enum OutboundExit { GracefulShutdown, - Reconnectable { error: Error, relay: bool }, + Reconnectable { error: Error, relay_reset: bool }, Fatal(Error), } @@ -122,17 +122,21 @@ impl Handler { client.release_outbound(); - let error = match error { - StreamUpgradeError::NegotiationFailed => Error::Unsupported, - StreamUpgradeError::Timeout => Error::io(std::io::Error::new( + let (error, relay_reset) = match error { + StreamUpgradeError::NegotiationFailed => (Error::Unsupported, false), + StreamUpgradeError::Timeout => (Error::Io(std::io::Error::new( std::io::ErrorKind::TimedOut, "sync protocol negotiation timed out", - )), + ) + .to_string()), false), StreamUpgradeError::Apply(never) => match never {}, - StreamUpgradeError::Io(error) => Error::io(error), + StreamUpgradeError::Io(error) => ( + Error::Io(error.to_string()), + error.kind() == std::io::ErrorKind::ConnectionReset, + ), }; - if client.should_reconnect() || error.is_relay_error() { + if relay_reset || client.should_reconnect() { self.schedule_retry(); } else { client.finish(Err(error)); @@ -204,7 +208,7 @@ impl ConnectionHandler for Handler { } self.outbound = OutboundState::Disabled; } - Poll::Ready(OutboundExit::Reconnectable { error, relay }) => { + Poll::Ready(OutboundExit::Reconnectable { error, relay_reset }) => { let Some(client) = self.client.as_ref() else { self.outbound = OutboundState::Disabled; return Poll::Pending; @@ -213,12 +217,8 @@ impl ConnectionHandler for Handler { client.set_connected(false); client.release_outbound(); - if relay || client.should_reconnect() { - if relay { - debug!(peer = %self.peer_id, err = %error, "Relay connection dropped, reconnecting sync client"); - } else { - info!(peer = %self.peer_id, err = %error, "Disconnected from peer"); - } + if relay_reset || client.should_reconnect() { + info!(peer = %self.peer_id, err = %error, "Disconnected from peer"); self.outbound = OutboundState::Idle; } else { client.finish(Err(error)); @@ -294,7 +294,7 @@ async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit let shutdown = client.shutdown_requested(); let timestamp = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { Ok(timestamp) => timestamp, - Err(error) => return OutboundExit::Fatal(Error::io(error)), + Err(error) => return OutboundExit::Fatal(Error::Io(error.to_string())), }; let nanos = timestamp.subsec_nanos(); let timestamp = Timestamp { @@ -309,14 +309,11 @@ async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit step: client.step(), }; - let response: Result = - match pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &request) - .await - .map_err(Error::io) + let response: std::io::Result = + match pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &request).await { Ok(()) => pluto_p2p::proto::read_fixed_size_protobuf(&mut stream) - .await - .map_err(Error::io), + .await, Err(error) => Err(error), }; @@ -324,8 +321,8 @@ async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit Ok(response) => response, Err(error) => { return OutboundExit::Reconnectable { - relay: error.is_relay_error(), - error, + relay_reset: error.kind() == std::io::ErrorKind::ConnectionReset, + error: Error::Io(error.to_string()), }; } }; @@ -358,12 +355,13 @@ async fn handle_inbound_stream( return Err(Error::ServerNotStarted); } - let public_key = pluto_p2p::peer::peer_id_to_libp2p_pk(&peer_id).map_err(Error::peer)?; + let public_key = + pluto_p2p::peer::peer_id_to_libp2p_pk(&peer_id).map_err(|error| Error::Peer(error.to_string()))?; loop { let message: MsgSync = pluto_p2p::proto::read_fixed_size_protobuf(&mut stream) .await - .map_err(Error::io)?; + .map_err(|error| Error::Io(error.to_string()))?; let mut response = MsgSyncResponse { sync_timestamp: message.timestamp, error: String::new(), @@ -380,9 +378,10 @@ async fn handle_inbound_stream( error: error.clone(), }); server - .set_err(Error::message(format!( - "invalid sync message: peer={peer_id} err={error}" - ))) + .set_err(Error::InvalidSyncMessage { + peer: peer_id, + error: error.to_string(), + }) .await; response.error = error.to_string(); } else { @@ -406,7 +405,7 @@ async fn handle_inbound_stream( pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &response) .await - .map_err(Error::io)?; + .map_err(|error| Error::Io(error.to_string()))?; if message.shutdown { send_inbound_event(&inbound_events_tx, OutEvent::PeerShutdownObserved { peer_id }); diff --git a/crates/dkg/src/sync/protocol.rs b/crates/dkg/src/sync/protocol.rs index 7aa0bf9b..477e25a5 100644 --- a/crates/dkg/src/sync/protocol.rs +++ b/crates/dkg/src/sync/protocol.rs @@ -35,7 +35,10 @@ pub fn validate_request_with_public_key( SemVer::parse(&msg.version).map_err(|error| Error::ParsePeerVersion(error.to_string()))?; if msg_version != *expected_version { - return Err(Error::version_mismatch(expected_version, &msg.version)); + return Err(Error::VersionMismatch { + expected: expected_version.to_string(), + got: msg.version.clone(), + }); } if !public_key.verify(def_hash, &msg.hash_signature) { diff --git a/crates/dkg/src/sync/server.rs b/crates/dkg/src/sync/server.rs index 0c8e7372..cb94c309 100644 --- a/crates/dkg/src/sync/server.rs +++ b/crates/dkg/src/sync/server.rs @@ -124,10 +124,10 @@ impl Server { ) -> Result<()> { let step_plus_one = step .checked_add(1) - .ok_or_else(|| Error::message("step overflow"))?; + .ok_or(Error::StepOverflow)?; let step_plus_two = step .checked_add(2) - .ok_or_else(|| Error::message("step overflow"))?; + .ok_or(Error::StepOverflow)?; loop { let notified = self.inner.notify.notified(); @@ -219,7 +219,7 @@ impl Server { let current_plus_two = current .checked_add(2) - .ok_or_else(|| Error::message("step overflow"))?; + .ok_or(Error::StepOverflow)?; if step > current_plus_two { return Err(Error::PeerStepAhead); } From 29013bf04b7e1890bab4fb3ce70d07b78e70bec8 Mon Sep 17 00:00:00 2001 From: Quang Le Date: Fri, 3 Apr 2026 16:04:35 +0700 Subject: [PATCH 11/14] fix: clear server connected state --- crates/dkg/src/sync/behaviour.rs | 5 +- crates/dkg/src/sync/client.rs | 45 +++++++++- crates/dkg/src/sync/error.rs | 8 ++ crates/dkg/src/sync/handler.rs | 143 +++++++++++++++++-------------- crates/dkg/src/sync/mod.rs | 35 +++++++- crates/dkg/src/sync/server.rs | 16 ++-- 6 files changed, 170 insertions(+), 82 deletions(-) diff --git a/crates/dkg/src/sync/behaviour.rs b/crates/dkg/src/sync/behaviour.rs index 6a8588b6..b33ca25e 100644 --- a/crates/dkg/src/sync/behaviour.rs +++ b/crates/dkg/src/sync/behaviour.rs @@ -8,8 +8,9 @@ use libp2p::{ Multiaddr, PeerId, swarm::{ ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour, THandler, THandlerInEvent, - THandlerOutEvent, ToSwarm, dummy, + THandlerOutEvent, ToSwarm, dial_opts::{DialOpts, PeerCondition}, + dummy, }, }; use pluto_p2p::p2p_context::P2PContext; @@ -228,7 +229,7 @@ mod tests { let p2p_context = P2PContext::new([peer_id]); let mut behaviour = Behaviour::new(server, [client.clone()], p2p_context, command_rx); - client.activate(); + client.activate().expect("activate should succeed"); let waker = noop_waker_ref(); let mut cx = Context::from_waker(waker); diff --git a/crates/dkg/src/sync/client.rs b/crates/dkg/src/sync/client.rs index 5b767a01..52947230 100644 --- a/crates/dkg/src/sync/client.rs +++ b/crates/dkg/src/sync/client.rs @@ -86,7 +86,7 @@ impl Client { /// Runs the client until shutdown, fatal error, or cancellation. pub async fn run(&self, cancellation: CancellationToken) -> Result<()> { - self.activate(); + self.activate()?; self.wait_finished(cancellation, true).await } @@ -165,12 +165,19 @@ impl Client { } } - pub(crate) fn activate(&self) { + pub(crate) fn activate(&self) -> Result<()> { self.inner.active.store(true, Ordering::SeqCst); - if let Some(command_tx) = &self.inner.command_tx { - let _ = command_tx.send(Command::Activate(self.inner.peer_id)); + if let Some(command_tx) = &self.inner.command_tx + && command_tx + .send(Command::Activate(self.inner.peer_id)) + .is_ok() + { + return Ok(()); } + + self.inner.active.store(false, Ordering::SeqCst); + Err(Error::ActivationChannelUnavailable) } async fn wait_finished( @@ -202,3 +209,33 @@ impl Client { } } } + +#[cfg(test)] +mod tests { + use libp2p::PeerId; + use pluto_core::version::SemVer; + + use super::*; + + #[tokio::test] + async fn run_fails_immediately_if_activation_channel_is_closed() { + let (command_tx, command_rx) = mpsc::unbounded_channel(); + drop(command_rx); + + let client = Client::new( + PeerId::random(), + vec![1, 2, 3], + SemVer::parse("v1.7").expect("version"), + ClientConfig::default(), + Some(command_tx), + ); + + let error = client + .run(CancellationToken::new()) + .await + .expect_err("closed activation channel should fail immediately"); + + assert!(matches!(error, Error::ActivationChannelUnavailable)); + assert!(!client.should_run()); + } +} diff --git a/crates/dkg/src/sync/error.rs b/crates/dkg/src/sync/error.rs index 46f4487f..3763b1e0 100644 --- a/crates/dkg/src/sync/error.rs +++ b/crates/dkg/src/sync/error.rs @@ -87,6 +87,10 @@ pub enum Error { #[error("sync client completion channel closed")] CompletionChannelClosed, + /// The sync client activation channel was unavailable. + #[error("sync client activation channel unavailable")] + ActivationChannelUnavailable, + /// An inbound sync message failed validation. #[error("invalid sync message: peer={peer} err={error}")] InvalidSyncMessage { @@ -99,4 +103,8 @@ pub enum Error { /// The local peer ID was missing from the shared P2P context. #[error("local peer id missing from p2p context")] LocalPeerMissing, + + /// The configured peer set did not include the local peer ID. + #[error("local peer id missing from sync peer set")] + LocalPeerNotInPeerSet, } diff --git a/crates/dkg/src/sync/handler.rs b/crates/dkg/src/sync/handler.rs index 5c20dbff..126cdfd1 100644 --- a/crates/dkg/src/sync/handler.rs +++ b/crates/dkg/src/sync/handler.rs @@ -27,11 +27,11 @@ use tracing::{debug, info, warn}; use crate::dkgpb::v1::sync::{MsgSync, MsgSyncResponse}; use super::{ + Event, client::Client, error::{Error, Result}, protocol, server::Server, - Event, }; const INITIAL_BACKOFF: Duration = Duration::from_millis(100); @@ -124,11 +124,16 @@ impl Handler { let (error, relay_reset) = match error { StreamUpgradeError::NegotiationFailed => (Error::Unsupported, false), - StreamUpgradeError::Timeout => (Error::Io(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "sync protocol negotiation timed out", - ) - .to_string()), false), + StreamUpgradeError::Timeout => ( + Error::Io( + std::io::Error::new( + std::io::ErrorKind::TimedOut, + "sync protocol negotiation timed out", + ) + .to_string(), + ), + false, + ), StreamUpgradeError::Apply(never) => match never {}, StreamUpgradeError::Io(error) => ( Error::Io(error.to_string()), @@ -310,10 +315,8 @@ async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit }; let response: std::io::Result = - match pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &request).await - { - Ok(()) => pluto_p2p::proto::read_fixed_size_protobuf(&mut stream) - .await, + match pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &request).await { + Ok(()) => pluto_p2p::proto::read_fixed_size_protobuf(&mut stream).await, Err(error) => Err(error), }; @@ -351,69 +354,83 @@ async fn handle_inbound_stream( inbound_events_tx: mpsc::UnboundedSender, mut stream: Stream, ) -> Result<()> { - if !server.is_started() { - return Err(Error::ServerNotStarted); - } + let result = async { + if !server.is_started() { + return Err(Error::ServerNotStarted); + } - let public_key = - pluto_p2p::peer::peer_id_to_libp2p_pk(&peer_id).map_err(|error| Error::Peer(error.to_string()))?; + let public_key = pluto_p2p::peer::peer_id_to_libp2p_pk(&peer_id) + .map_err(|error| Error::Peer(error.to_string()))?; - loop { - let message: MsgSync = pluto_p2p::proto::read_fixed_size_protobuf(&mut stream) - .await - .map_err(|error| Error::Io(error.to_string()))?; - let mut response = MsgSyncResponse { - sync_timestamp: message.timestamp, - error: String::new(), - }; + loop { + let message: MsgSync = pluto_p2p::proto::read_fixed_size_protobuf(&mut stream) + .await + .map_err(|error| Error::Io(error.to_string()))?; + let mut response = MsgSyncResponse { + sync_timestamp: message.timestamp, + error: String::new(), + }; - if let Err(error) = protocol::validate_request_with_public_key( - server.def_hash(), - server.version(), - &public_key, - &message, - ) { - send_inbound_event(&inbound_events_tx, OutEvent::SyncRejected { - peer_id, - error: error.clone(), - }); - server - .set_err(Error::InvalidSyncMessage { - peer: peer_id, - error: error.to_string(), - }) - .await; - response.error = error.to_string(); - } else { - let (inserted, count) = server.mark_connected(peer_id).await; - if inserted { - info!( - peer = %peer_id, - connected = count, - expected = server.expected_peer_count(), - "Connected to peer" + if let Err(error) = protocol::validate_request_with_public_key( + server.def_hash(), + server.version(), + &public_key, + &message, + ) { + send_inbound_event( + &inbound_events_tx, + OutEvent::SyncRejected { + peer_id, + error: error.clone(), + }, ); + server + .set_err(Error::InvalidSyncMessage { + peer: peer_id, + error: error.to_string(), + }) + .await; + response.error = error.to_string(); + } else { + let (inserted, count) = server.mark_connected(peer_id).await; + if inserted { + info!( + peer = %peer_id, + connected = count, + expected = server.expected_peer_count(), + "Connected to peer" + ); + } } - } - if server.update_step(peer_id, message.step).await? { - send_inbound_event(&inbound_events_tx, OutEvent::PeerStepUpdated { - peer_id, - step: message.step, - }); - } + if server.update_step(peer_id, message.step).await? { + send_inbound_event( + &inbound_events_tx, + OutEvent::PeerStepUpdated { + peer_id, + step: message.step, + }, + ); + } - pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &response) - .await - .map_err(|error| Error::Io(error.to_string()))?; + pluto_p2p::proto::write_fixed_size_protobuf(&mut stream, &response) + .await + .map_err(|error| Error::Io(error.to_string()))?; - if message.shutdown { - send_inbound_event(&inbound_events_tx, OutEvent::PeerShutdownObserved { peer_id }); - server.set_shutdown(peer_id).await; - server.clear_connected(peer_id).await; - return Ok(()); + if message.shutdown { + send_inbound_event( + &inbound_events_tx, + OutEvent::PeerShutdownObserved { peer_id }, + ); + server.set_shutdown(peer_id).await; + return Ok(()); + } } } + .await; + + server.clear_connected(peer_id).await; + result } fn send_inbound_event(inbound_events_tx: &mpsc::UnboundedSender, event: OutEvent) { diff --git a/crates/dkg/src/sync/mod.rs b/crates/dkg/src/sync/mod.rs index a8c4e63f..aaad833a 100644 --- a/crates/dkg/src/sync/mod.rs +++ b/crates/dkg/src/sync/mod.rs @@ -36,12 +36,19 @@ pub fn new( version: SemVer, ) -> Result<(Behaviour, Server, Vec)> { let local_peer_id = p2p_context.local_peer_id().ok_or(Error::LocalPeerMissing)?; + if !peers.contains(&local_peer_id) { + return Err(Error::LocalPeerNotInPeerSet); + } + let hash_sig = protocol::sign_definition_hash(secret, &def_hash)?; - let server = Server::new(peers.len().saturating_sub(1), def_hash, version.clone()); - let (command_tx, command_rx) = mpsc::unbounded_channel(); - let clients = peers + let remote_peers = peers .into_iter() .filter(|peer_id| *peer_id != local_peer_id) + .collect::>(); + let server = Server::new(remote_peers.len(), def_hash, version.clone()); + let (command_tx, command_rx) = mpsc::unbounded_channel(); + let clients = remote_peers + .into_iter() .map(|peer_id| { Client::new( peer_id, @@ -218,6 +225,28 @@ mod tests { assert!(matches!(error, Error::PeerStepAhead)); } + #[test] + fn new_requires_local_peer_in_peer_set() { + let key = generate_insecure_k1_key(0); + let local_peer_id = peer_id_from_key(key.public_key()).expect("peer id"); + let remote_peer = PeerId::random(); + let p2p_context = P2PContext::new([local_peer_id, remote_peer]); + p2p_context.set_local_peer_id(local_peer_id); + + let result = new( + vec![remote_peer], + p2p_context, + &key, + vec![1, 2, 3], + SemVer::parse("v1.7").expect("version"), + ); + + assert!( + matches!(result, Err(Error::LocalPeerNotInPeerSet)), + "local peer must be part of the sync peer set" + ); + } + #[tokio::test] async fn sync_round_trip_matches_go_shape() -> anyhow::Result<()> { let ports = (0..3) diff --git a/crates/dkg/src/sync/server.rs b/crates/dkg/src/sync/server.rs index cb94c309..7ec667c4 100644 --- a/crates/dkg/src/sync/server.rs +++ b/crates/dkg/src/sync/server.rs @@ -122,12 +122,8 @@ impl Server { step: i64, cancellation: CancellationToken, ) -> Result<()> { - let step_plus_one = step - .checked_add(1) - .ok_or(Error::StepOverflow)?; - let step_plus_two = step - .checked_add(2) - .ok_or(Error::StepOverflow)?; + let step_plus_one = step.checked_add(1).ok_or_else(|| Error::StepOverflow)?; + let step_plus_two = step.checked_add(2).ok_or_else(|| Error::StepOverflow)?; loop { let notified = self.inner.notify.notified(); @@ -184,7 +180,9 @@ impl Server { let mut state = self.inner.state.write().await; let inserted = state.connected.insert(peer_id); let count = state.connected.len(); - self.inner.notify.notify_waiters(); + if inserted { + self.inner.notify.notify_waiters(); + } (inserted, count) } @@ -217,9 +215,7 @@ impl Server { return Err(Error::PeerStepBehind); } - let current_plus_two = current - .checked_add(2) - .ok_or(Error::StepOverflow)?; + let current_plus_two = current.checked_add(2).ok_or_else(|| Error::StepOverflow)?; if step > current_plus_two { return Err(Error::PeerStepAhead); } From 6de0bda3554777d8f3faf96ad93b173209829c54 Mon Sep 17 00:00:00 2001 From: Quang Le Date: Fri, 3 Apr 2026 17:12:40 +0700 Subject: [PATCH 12/14] fix: simplify code --- crates/dkg/src/sync/behaviour.rs | 49 ++++++++++++++++---------------- crates/dkg/src/sync/client.rs | 6 ++-- crates/dkg/src/sync/error.rs | 2 +- crates/dkg/src/sync/handler.rs | 25 +++++----------- crates/dkg/src/sync/mod.rs | 7 ----- crates/dkg/src/sync/server.rs | 4 +-- crates/p2p/src/proto.rs | 35 +++++++++++++++++++++-- 7 files changed, 72 insertions(+), 56 deletions(-) diff --git a/crates/dkg/src/sync/behaviour.rs b/crates/dkg/src/sync/behaviour.rs index b33ca25e..0df6b2c3 100644 --- a/crates/dkg/src/sync/behaviour.rs +++ b/crates/dkg/src/sync/behaviour.rs @@ -72,14 +72,13 @@ impl Behaviour { } fn connection_handler_for_peer(&self, peer: PeerId) -> THandler { - if self.clients.contains_key(&peer) { - Either::Left(Handler::new( + match self.clients.get(&peer) { + Some(client) => Either::Left(Handler::new( peer, self.server.clone(), - self.clients.get(&peer).cloned(), - )) - } else { - Either::Right(dummy::ConnectionHandler) + Some(client.clone()), + )), + None => Either::Right(dummy::ConnectionHandler), } } @@ -90,22 +89,24 @@ impl Behaviour { return; }; - if client.should_run() && !client.is_connected() { - if !self - .p2p_context - .peer_store_lock() - .connections_to_peer(&peer_id) - .is_empty() - { - return; - } - - self.pending_events.push_back(ToSwarm::Dial { - opts: DialOpts::peer_id(peer_id) - .condition(PeerCondition::DisconnectedAndNotDialing) - .build(), - }); + if !client.should_run() || client.is_connected() || client.shutdown_requested() { + return; + } + + if !self + .p2p_context + .peer_store_lock() + .connections_to_peer(&peer_id) + .is_empty() + { + return; } + + self.pending_events.push_back(ToSwarm::Dial { + opts: DialOpts::peer_id(peer_id) + .condition(PeerCondition::DisconnectedAndNotDialing) + .build(), + }); } } } @@ -143,14 +144,14 @@ impl NetworkBehaviour for Behaviour { return; } - // TODO: Go retries sync client connections until reconnect is disabled. - // Re-queue active clients here (and on DialFailure below) so peers that - // restart before initial cluster sync can be dialed again. if let Some(client) = self.clients.get(&event.peer_id) { client.set_connected(false); client.release_outbound(); } } + FromSwarm::DialFailure(event) => { + let _ = event; + } _ => {} } } diff --git a/crates/dkg/src/sync/client.rs b/crates/dkg/src/sync/client.rs index 52947230..48d50be6 100644 --- a/crates/dkg/src/sync/client.rs +++ b/crates/dkg/src/sync/client.rs @@ -12,8 +12,10 @@ use pluto_core::version::SemVer; use tokio::sync::{mpsc, watch}; use tokio_util::sync::CancellationToken; -use super::Command; -use super::error::{Error, Result}; +use super::{ + Command, + error::{Error, Result}, +}; /// Default period between sync messages. pub const DEFAULT_PERIOD: Duration = Duration::from_millis(100); diff --git a/crates/dkg/src/sync/error.rs b/crates/dkg/src/sync/error.rs index 3763b1e0..7fffd255 100644 --- a/crates/dkg/src/sync/error.rs +++ b/crates/dkg/src/sync/error.rs @@ -15,7 +15,7 @@ pub enum Error { PeerRespondedWithError(String), /// The remote peer version did not match. - #[error("mismatching charon version; expect={expected}, got={got}")] + #[error("mismatching version; expect={expected}, got={got}")] VersionMismatch { /// The expected version string. expected: String, diff --git a/crates/dkg/src/sync/handler.rs b/crates/dkg/src/sync/handler.rs index 126cdfd1..a50f6c70 100644 --- a/crates/dkg/src/sync/handler.rs +++ b/crates/dkg/src/sync/handler.rs @@ -20,8 +20,7 @@ use libp2p::{ }, }; use prost_types::Timestamp; -use tokio::sync::mpsc; -use tokio::time::Sleep; +use tokio::{sync::mpsc, time::Sleep}; use tracing::{debug, info, warn}; use crate::dkgpb::v1::sync::{MsgSync, MsgSyncResponse}; @@ -288,7 +287,7 @@ impl ConnectionHandler for Handler { async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit { let mut interval = tokio::time::interval(client.period()); - let hash_signature = prost::bytes::Bytes::from(client.hash_sig().to_vec()); + let hash_signature = prost::bytes::Bytes::copy_from_slice(client.hash_sig()); let version = client.version().to_string(); client.set_connected(true); @@ -297,15 +296,7 @@ async fn run_outbound_stream(client: Client, mut stream: Stream) -> OutboundExit interval.tick().await; let shutdown = client.shutdown_requested(); - let timestamp = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { - Ok(timestamp) => timestamp, - Err(error) => return OutboundExit::Fatal(Error::Io(error.to_string())), - }; - let nanos = timestamp.subsec_nanos(); - let timestamp = Timestamp { - seconds: i64::try_from(timestamp.as_secs()).unwrap_or(i64::MAX), - nanos: i32::try_from(nanos).unwrap_or(i32::MAX), - }; + let timestamp = Timestamp::from(std::time::SystemTime::now()); let request = MsgSync { timestamp: Some(timestamp), hash_signature: hash_signature.clone(), @@ -377,20 +368,18 @@ async fn handle_inbound_stream( &public_key, &message, ) { + let error_string = error.to_string(); send_inbound_event( &inbound_events_tx, - OutEvent::SyncRejected { - peer_id, - error: error.clone(), - }, + OutEvent::SyncRejected { peer_id, error }, ); server .set_err(Error::InvalidSyncMessage { peer: peer_id, - error: error.to_string(), + error: error_string.clone(), }) .await; - response.error = error.to_string(); + response.error = error_string; } else { let (inserted, count) = server.mark_connected(peer_id).await; if inserted { diff --git a/crates/dkg/src/sync/mod.rs b/crates/dkg/src/sync/mod.rs index aaad833a..2b98285e 100644 --- a/crates/dkg/src/sync/mod.rs +++ b/crates/dkg/src/sync/mod.rs @@ -1,10 +1,3 @@ -//! DKG peer step synchronization protocol. -//! -//! This module ports Go `charon/dkg/sync` into Rust while keeping the public -//! API split into a per-peer [`Client`] handle and a shared [`Server`] handle. -//! Internally, Pluto drives the protocol through a libp2p behaviour and -//! connection handler because protocol streams are owned by the swarm. - mod behaviour; mod client; mod error; diff --git a/crates/dkg/src/sync/server.rs b/crates/dkg/src/sync/server.rs index 7ec667c4..02534551 100644 --- a/crates/dkg/src/sync/server.rs +++ b/crates/dkg/src/sync/server.rs @@ -122,8 +122,8 @@ impl Server { step: i64, cancellation: CancellationToken, ) -> Result<()> { - let step_plus_one = step.checked_add(1).ok_or_else(|| Error::StepOverflow)?; - let step_plus_two = step.checked_add(2).ok_or_else(|| Error::StepOverflow)?; + let step_plus_one = step.checked_add(1).ok_or(Error::StepOverflow)?; + let step_plus_two = step.checked_add(2).ok_or(Error::StepOverflow)?; loop { let notified = self.inner.notify.notified(); diff --git a/crates/p2p/src/proto.rs b/crates/p2p/src/proto.rs index 9d16cd0c..9171598c 100644 --- a/crates/p2p/src/proto.rs +++ b/crates/p2p/src/proto.rs @@ -59,12 +59,22 @@ pub async fn write_fixed_size_delimited( .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "payload length overflow"))?; stream.write_all(&len.to_le_bytes()).await?; - stream.write_all(payload).await + stream.write_all(payload).await?; + stream.flush().await } /// Reads a fixed-size length-delimited payload from the stream. pub async fn read_fixed_size_delimited( stream: &mut S, +) -> io::Result> { + read_fixed_size_delimited_with_max(stream, MAX_MESSAGE_SIZE).await +} + +/// Reads a fixed-size length-delimited payload from the stream, rejecting +/// oversized messages. +pub async fn read_fixed_size_delimited_with_max( + stream: &mut S, + max_message_size: usize, ) -> io::Result> { let mut len_buf = [0_u8; 8]; stream.read_exact(&mut len_buf).await?; @@ -73,12 +83,19 @@ pub async fn read_fixed_size_delimited( if len < 0 { return Err(io::Error::new( io::ErrorKind::InvalidData, - format!("invalid data length"), + "invalid data length", )); } let len = usize::try_from(len) .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "payload length overflow"))?; + if len > max_message_size { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("message too large: {len} bytes (max: {max_message_size})"), + )); + } + let mut payload = vec![0_u8; len]; stream.read_exact(&mut payload).await?; @@ -165,4 +182,18 @@ mod tests { .expect_err("negative sizes must fail"); assert_eq!(error.kind(), io::ErrorKind::InvalidData); } + + #[tokio::test] + async fn oversized_fixed_size_length_fails() { + let too_large = i64::try_from(MAX_MESSAGE_SIZE) + .expect("max size should fit in i64") + .checked_add(1) + .expect("test length should not overflow"); + let mut cursor = Cursor::new(too_large.to_le_bytes().to_vec()); + + let error = read_fixed_size_delimited(&mut cursor) + .await + .expect_err("oversized payloads must fail"); + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + } } From 4249e14d64961b2ca07ec9a5da601e953da4016e Mon Sep 17 00:00:00 2001 From: Quang Le Date: Fri, 3 Apr 2026 18:01:02 +0700 Subject: [PATCH 13/14] fix: sync example allow a peer is disconnected --- crates/dkg/examples/sync.rs | 81 ++++++++++++++----------------------- 1 file changed, 30 insertions(+), 51 deletions(-) diff --git a/crates/dkg/examples/sync.rs b/crates/dkg/examples/sync.rs index 5b408906..906fb11b 100644 --- a/crates/dkg/examples/sync.rs +++ b/crates/dkg/examples/sync.rs @@ -42,8 +42,8 @@ //! 4. Each node starts one sync client per remote peer. //! 5. Once all clients are connected, the demo advances through steps 1 and 2. //! 6. The demo keeps the sync clients running in steady state. -//! 7. Press `Ctrl+C` to trigger graceful shutdown and wait for all remote -//! shutdowns. +//! 7. Press `Ctrl+C` on any node to stop that node immediately and let the +//! other nodes observe the fault. //! //! Success signals: //! - `Relay reservation accepted` @@ -52,7 +52,7 @@ //! - `Sync step reached` //! - `Sync demo is now idling until Ctrl+C` //! - `Sync steady-state heartbeat` -//! - `All peers reported shutdown` +//! - `Ctrl+C received, exiting without graceful shutdown` //! //! Transient relay warnings can occur during startup and reconnects. The demo //! is healthy once all cluster peers are connected and the sync steps complete. @@ -305,27 +305,6 @@ fn log_connection_established( ); } -fn log_connection_closed( - peer_id: PeerId, - endpoint: &libp2p::core::ConnectedPoint, - num_established: u32, - cause: Option<&libp2p::swarm::ConnectionError>, - relay_peer_ids: &HashSet, - cluster_info: &ClusterInfo, -) { - let (peer_label, peer_type, address) = - connection_log_fields(peer_id, endpoint, relay_peer_ids, cluster_info); - debug!( - peer_id = %peer_id, - peer_label = %peer_label, - peer_type, - address = %address, - num_established, - cause = ?cause, - "Connection closed" - ); -} - fn log_identify_event( peer_id: PeerId, info: identify::Info, @@ -470,6 +449,13 @@ async fn run_sync( Err(sync::Error::Canceled) => break, Err(error) => return Err(anyhow::anyhow!(error.to_string())), } + + if step < 2 { + tokio::select! { + _ = cancellation.cancelled() => break, + _ = tokio::time::sleep(Duration::from_secs(3)) => {} + } + } } } @@ -496,26 +482,16 @@ async fn run_sync( } } - let shutdown_cancellation = CancellationToken::new(); - info!( - local_node = cluster_info.local_node_number, - "Starting graceful shutdown" - ); - for client in &clients { - client.shutdown(shutdown_cancellation.child_token()).await?; + if cancellation.is_cancelled() { + info!( + local_node = cluster_info.local_node_number, + "Cancellation received, exiting without graceful shutdown" + ); } - server - .await_all_shutdown(shutdown_cancellation.child_token()) - .await?; - info!( - local_node = cluster_info.local_node_number, - "All peers reported shutdown" - ); - for join in client_joins { match join.await { - Ok(Ok(())) => {} + Ok(Ok(())) | Ok(Err(sync::Error::Canceled)) => {} Ok(Err(error)) => return Err(anyhow::anyhow!(error.to_string())), Err(error) => return Err(anyhow::anyhow!(error.to_string())), } @@ -680,20 +656,23 @@ async fn main() -> Result<()> { } SwarmEvent::ConnectionClosed { peer_id, - endpoint, - num_established, cause, .. } => { - log_connection_closed( - peer_id, - &endpoint, - num_established, - cause.as_ref(), - &relay_peer_ids, - &cluster_info, - ); - connected_cluster_peers.remove(&peer_id); + if cluster_info.indices.contains_key(&peer_id) + && connected_cluster_peers.remove(&peer_id) + { + error!( + local_node = cluster_info.local_node_number, + peer_id = %peer_id, + peer_label = %cluster_info.peer_label(&peer_id), + connected = connected_cluster_peers.len(), + expected = cluster_info.expected_connections(), + missing_peers = ?cluster_info.missing_peers(&connected_cluster_peers), + cause = ?cause, + "Cluster peer disconnected" + ); + } } SwarmEvent::OutgoingConnectionError { peer_id, From bafbd0a5704d73ac55be8457360b6265646ff7fd Mon Sep 17 00:00:00 2001 From: Quang Le Date: Fri, 3 Apr 2026 18:05:26 +0700 Subject: [PATCH 14/14] fix: clippy --- crates/dkg/src/sync/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/dkg/src/sync/server.rs b/crates/dkg/src/sync/server.rs index 02534551..141d5e9a 100644 --- a/crates/dkg/src/sync/server.rs +++ b/crates/dkg/src/sync/server.rs @@ -215,7 +215,7 @@ impl Server { return Err(Error::PeerStepBehind); } - let current_plus_two = current.checked_add(2).ok_or_else(|| Error::StepOverflow)?; + let current_plus_two = current.checked_add(2).ok_or(Error::StepOverflow)?; if step > current_plus_two { return Err(Error::PeerStepAhead); }