Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 155 additions & 5 deletions lightning-custom-message/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,25 @@ macro_rules! composite_custom_message_handler {
}

fn peer_connected(&self, their_node_id: $crate::bitcoin::secp256k1::PublicKey, msg: &$crate::lightning::ln::msgs::Init, inbound: bool) -> Result<(), ()> {
let mut result = Ok(());
// Per the `CustomMessageHandler::peer_connected` contract, `peer_disconnected`
// will not be called by `PeerManager` if we return `Err`. To avoid leaking
// per-peer state in sub-handlers that already returned `Ok` when a later one
// errors, record each sub-handler's result and roll back the successful ones
// ourselves before propagating the failure.
$(
if let Err(e) = self.$field.peer_connected(their_node_id, msg, inbound) {
result = Err(e);
}
let $field = self.$field.peer_connected(their_node_id, msg, inbound);
)*
result
let any_err = false $( || $field.is_err() )*;
if any_err {
$(
if $field.is_ok() {
self.$field.peer_disconnected(their_node_id);
}
)*
Err(())
} else {
Ok(())
}
}

fn provided_node_features(&self) -> $crate::lightning::types::features::NodeFeatures {
Expand Down Expand Up @@ -376,3 +388,141 @@ macro_rules! composite_custom_message_handler {
}
}
}

#[cfg(test)]
mod tests {
use bitcoin::secp256k1::PublicKey;
use core::sync::atomic::{AtomicUsize, Ordering};
use lightning::io;
use lightning::ln::msgs::{DecodeError, Init, LightningError};
use lightning::ln::peer_handler::CustomMessageHandler;
use lightning::ln::wire::{CustomMessageReader, Type};
use lightning::types::features::{InitFeatures, NodeFeatures};
use lightning::util::ser::{LengthLimitedRead, Writeable, Writer};

#[derive(Debug)]
pub struct Foo;
impl Type for Foo {
fn type_id(&self) -> u16 {
32768
}
}
impl Writeable for Foo {
fn write<W: Writer>(&self, _: &mut W) -> Result<(), io::Error> {
Ok(())
}
}

pub struct CountingHandler {
pub connect_count: AtomicUsize,
}
impl CustomMessageReader for CountingHandler {
type CustomMessage = Foo;
fn read<R: LengthLimitedRead>(
&self, _t: u16, _b: &mut R,
) -> Result<Option<Foo>, DecodeError> {
Ok(None)
}
}
impl CustomMessageHandler for CountingHandler {
fn handle_custom_message(&self, _msg: Foo, _: PublicKey) -> Result<(), LightningError> {
Ok(())
}
fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Foo)> {
vec![]
}
fn peer_disconnected(&self, _: PublicKey) {
self.connect_count.fetch_sub(1, Ordering::SeqCst);
}
fn peer_connected(&self, _: PublicKey, _: &Init, _: bool) -> Result<(), ()> {
self.connect_count.fetch_add(1, Ordering::SeqCst);
Ok(())
}
fn provided_node_features(&self) -> NodeFeatures {
NodeFeatures::empty()
}
fn provided_init_features(&self, _: PublicKey) -> InitFeatures {
InitFeatures::empty()
}
}

#[derive(Debug)]
pub struct Bar;
impl Type for Bar {
fn type_id(&self) -> u16 {
32769
}
}
impl Writeable for Bar {
fn write<W: Writer>(&self, _: &mut W) -> Result<(), io::Error> {
Ok(())
}
}

pub struct ErroringHandler;
impl CustomMessageReader for ErroringHandler {
type CustomMessage = Bar;
fn read<R: LengthLimitedRead>(
&self, _t: u16, _b: &mut R,
) -> Result<Option<Bar>, DecodeError> {
Ok(None)
}
}
impl CustomMessageHandler for ErroringHandler {
fn handle_custom_message(&self, _msg: Bar, _: PublicKey) -> Result<(), LightningError> {
Ok(())
}
fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Bar)> {
vec![]
}
fn peer_disconnected(&self, _: PublicKey) {}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert!(false) here.

fn peer_connected(&self, _: PublicKey, _: &Init, _: bool) -> Result<(), ()> {
Err(())
}
fn provided_node_features(&self) -> NodeFeatures {
NodeFeatures::empty()
}
fn provided_init_features(&self, _: PublicKey) -> InitFeatures {
InitFeatures::empty()
}
}

composite_custom_message_handler!(
pub struct CompositeHandler {
counting: CountingHandler,
erroring: ErroringHandler,
}

pub enum CompositeMessage {
Foo(32768),
Bar(32769),
}
);

#[test]
fn peer_connected_failure_does_not_leak_subhandler_state() {
let composite = CompositeHandler {
counting: CountingHandler { connect_count: AtomicUsize::new(0) },
erroring: ErroringHandler,
};
let pk_bytes = [
0x02, 0x79, 0xBE, 0x66, 0x7E, 0xF9, 0xDC, 0xBB, 0xAC, 0x55, 0xA0, 0x62, 0x95, 0xCE,
0x87, 0x0B, 0x07, 0x02, 0x9B, 0xFC, 0xDB, 0x2D, 0xCE, 0x28, 0xD9, 0x59, 0xF2, 0x81,
0x5B, 0x16, 0xF8, 0x17, 0x98,
];
let pk = PublicKey::from_slice(&pk_bytes).unwrap();
let init =
Init { features: InitFeatures::empty(), networks: None, remote_network_address: None };

let result = composite.peer_connected(pk, &init, true);
assert!(result.is_err(), "Composite must propagate the inner Err");

let leaked = composite.counting.connect_count.load(Ordering::SeqCst);
assert_eq!(
leaked, 0,
"CountingHandler tracked {leaked} connected peer(s) after the composite \
returned Err; this state will never be cleaned up because per the trait \
contract peer_disconnected won't be called when peer_connected returns Err.",
);
}
}
Loading