From 2904bdff9d1def91a5b5a44e27b9881161177a07 Mon Sep 17 00:00:00 2001 From: Will Scott Date: Wed, 1 Apr 2026 08:57:06 +0200 Subject: [PATCH] Add support for extensions of the multiaddr interface. This introduces a feature flag `Custom` that can be used to allow for parsing and management of protocols that are not part of the hard-coded set of known multiaddrs. This parallels the more permissive support that is found in js and golang implementations. by default, unknown protocols will not be parsed, since the semantics of how to parse their arguments cannot be known, but they can be manually constructed and then serialized to string. Specific extension protocols can be registered in a protocol registry to extend the default set and allow for handling of multiaddrs using those additional protocols more naturally. --- Cargo.toml | 1 + src/custom.rs | 363 ++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 7 +- src/protocol.rs | 106 +++++++++++++- 4 files changed, 475 insertions(+), 2 deletions(-) create mode 100644 src/custom.rs diff --git a/Cargo.toml b/Cargo.toml index 72e1d04..71fafe6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ version = "0.18.3" [features] default = ["url"] +custom = [] [dependencies] arrayref = "0.3" diff --git a/src/custom.rs b/src/custom.rs new file mode 100644 index 0000000..3311e00 --- /dev/null +++ b/src/custom.rs @@ -0,0 +1,363 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use crate::{Error, Multiaddr, Protocol, Result}; + +/// A transcoder defines how to encode and decode a custom protocol's data +/// between its binary representation and its human-readable string representation. +pub trait Transcoder: Send + Sync { + /// Attempts to parse the human-readable string component of a protocol into bytes. + fn string_to_bytes( + &self, + s: &str, + ) -> std::result::Result, Box>; + + /// Attempts to format the binary representation of a protocol's data into a human-readable string. + fn bytes_to_string( + &self, + bytes: &[u8], + ) -> std::result::Result>; +} + +/// A custom protocol definition. +pub struct CustomProtocolDef { + pub name: &'static str, + pub code: u32, + /// The length of the binary payload. + /// `0` means no data. `> 0` means a fixed data length. `-1` denotes a length-prefixed protocol. + pub size: i32, + pub path: bool, + pub transcoder: Option>, +} + +impl std::cmp::PartialEq for CustomProtocolDef { + fn eq(&self, other: &Self) -> bool { + self.code == other.code && self.name == other.name + } +} + +impl std::cmp::Eq for CustomProtocolDef {} + +impl std::fmt::Debug for CustomProtocolDef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CustomProtocolDef") + .field("name", &self.name) + .field("code", &self.code) + .field("size", &self.size) + .field("path", &self.path) + .finish() + } +} + +/// A registry mapping protocol codes and names to their custom definitions. +#[derive(Clone)] +pub struct Registry { + by_code: HashMap>, + by_name: HashMap>, +} + +impl Default for Registry { + fn default() -> Self { + let mut r = Self { + by_code: HashMap::new(), + by_name: HashMap::new(), + }; + r.register_builtins(); + r + } +} + +impl Registry { + /// Create a new, empty protocol registry. Wait, new() actually uses default and adds built-ins. + pub fn new() -> Self { + Self::default() + } + + /// Add all built-in standard protocols to the registry + fn register_builtins(&mut self) { + for &(name, code, size, path) in crate::protocol::BUILT_IN_PROTOCOLS.iter() { + self.register(CustomProtocolDef { + name, + code, + size, + path, + transcoder: None, + }); + } + } + + /// Add a custom protocol definition to this registry. + pub fn register(&mut self, mut def: CustomProtocolDef) { + if def.path && def.name.starts_with('/') { + def.name = def.name.trim_start_matches('/'); + } + let name = def.name.to_string(); + let code = def.code; + let arc = Arc::new(def); + self.by_code.insert(code, arc.clone()); + self.by_name.insert(name, arc); + } + + /// Returns a registered custom protocol by its integer code. + pub fn get_by_code(&self, code: u32) -> Option> { + self.by_code.get(&code).cloned() + } + + /// Returns a registered custom protocol by its string name. + pub fn get_by_name(&self, name: &str) -> Option> { + self.by_name.get(name).cloned() + } + + /// Unregisters a protocol by its string name. + pub fn unregister_by_name(&mut self, name: &str) { + if let Some(def) = self.by_name.remove(name) { + self.by_code.remove(&def.code); + } + } + + /// Unregisters a protocol by its integer code. + pub fn unregister_by_code(&mut self, code: u32) { + if let Some(def) = self.by_code.remove(&code) { + self.by_name.remove(def.name); + } + } + + /// Iterate over the protocols in a `Multiaddr` using this registry. + pub fn iter<'a>(&'a self, ma: &'a Multiaddr) -> RegistryIter<'a> { + RegistryIter { + registry: self, + data: ma.as_ref(), + } + } +} + +/// Iterator over protocols using a registry. +pub struct RegistryIter<'a> { + registry: &'a Registry, + data: &'a [u8], +} + +impl<'a> Iterator for RegistryIter<'a> { + type Item = Protocol<'a>; + + fn next(&mut self) -> Option { + if self.data.is_empty() { + return None; + } + + let (p, next_data) = self.registry.parse_protocol_from_bytes(self.data).ok()?; + self.data = next_data; + Some(p) + } +} + +impl Registry { + /// Try parsing a single Protocol from bytes using the registry. + pub fn parse_protocol_from_bytes<'a>( + &self, + input: &'a [u8], + ) -> Result<(Protocol<'a>, &'a [u8])> { + let n_input = input; + let id_res = unsigned_varint::decode::u32(n_input); + if let Ok((id, _rest)) = id_res { + if !self.by_code.contains_key(&id) { + return Err(Error::UnknownProtocolId(id)); + } + } + + if let Ok(res) = Protocol::from_bytes(input) { + return Ok(res); + } + + let n_input = input; + let id_res = unsigned_varint::decode::u32(n_input); + if let Ok((id, rest)) = id_res { + if let Some(def) = self.get_by_code(id) { + let (data, out_rest) = if def.size == 0 { + (std::borrow::Cow::Borrowed(&rest[..0]), rest) + } else if def.size > 0 { + let fixed = def.size as usize; + if rest.len() < fixed { + return Err(Error::DataLessThanLen); + } + let (d, r) = rest.split_at(fixed); + (std::borrow::Cow::Borrowed(d), r) + } else { + let (len, r) = + unsigned_varint::decode::usize(rest).map_err(|_| Error::DataLessThanLen)?; + if r.len() < len { + return Err(Error::DataLessThanLen); + } + let (d, r2) = r.split_at(len); + (std::borrow::Cow::Borrowed(d), r2) + }; + return Ok((Protocol::Custom { def, data }, out_rest)); + } + } + + Err(Error::UnknownProtocolId( + id_res.map(|(i, _)| i).unwrap_or(0), + )) + } + + /// Try parsing a single Protocol from string parts using the registry. + pub fn parse_protocol_from_str_parts<'a, I>(&self, iter: &mut I) -> Result> + where + I: Iterator + Clone, + { + let mut peek_iter = iter.clone(); + if let Some(tag) = peek_iter.next() { + if !self.by_name.contains_key(tag) { + return Err(Error::UnknownProtocolString(tag.to_string())); + } + } + + let mut native_iter = iter.clone(); + if let Ok(p) = Protocol::from_str_parts(&mut native_iter) { + *iter = native_iter; + return Ok(p); + } + + let mut peek_iter = iter.clone(); + if let Some(tag) = peek_iter.next() { + if let Some(def) = self.get_by_name(tag) { + iter.next(); // consume the tag + let data = if def.size == 0 { + vec![] + } else if let Some(t) = &def.transcoder { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + t.string_to_bytes(part) + .map_err(|_| Error::InvalidProtocolString)? + } else if def.path { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + percent_encoding::percent_decode(part.as_bytes()).collect::>() + } else { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + multibase::Base::Base58Btc + .decode(part) + .map_err(|_| Error::InvalidProtocolString)? + }; + return Ok(Protocol::Custom { + def, + data: std::borrow::Cow::Owned(data), + }); + } + } + + let mut final_try = iter.clone(); + if let Some(tag) = final_try.next() { + Err(Error::UnknownProtocolString(tag.to_string())) + } else { + Err(Error::InvalidProtocolString) + } + } + + /// Parse a Multiaddr string using this registry + pub fn try_from_str(&self, input: &str) -> Result { + let mut addr = Multiaddr::empty(); + let mut parts = input.split('/'); + + if Some("") != parts.next() { + return Err(Error::InvalidMultiaddr); + } + + while parts.clone().peekable().peek().is_some() { + let p = self.parse_protocol_from_str_parts(&mut parts)?; + addr = addr.with(p); + } + + Ok(addr) + } + + /// Parse a Multiaddr from bytes using this registry + pub fn try_from_bytes(&self, mut input: &[u8]) -> Result { + let mut addr = Multiaddr::empty(); + while !input.is_empty() { + let (p, rest) = self.parse_protocol_from_bytes(input)?; + addr = addr.with(p); + input = rest; + } + Ok(addr) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct SimpleTranscoder; + impl Transcoder for SimpleTranscoder { + fn string_to_bytes( + &self, + s: &str, + ) -> std::result::Result, Box> { + Ok(s.as_bytes().to_vec()) + } + fn bytes_to_string( + &self, + bytes: &[u8], + ) -> std::result::Result> { + Ok(String::from_utf8(bytes.to_vec())?) + } + } + + #[test] + fn test_custom_protocol_registry() { + let mut registry = Registry::new(); + registry.register(CustomProtocolDef { + name: "my-custom", + code: 999, + size: -1, + path: false, + transcoder: Some(Box::new(SimpleTranscoder)), + }); + + let addr = registry + .try_from_str("/ip4/127.0.0.1/my-custom/helloworld") + .unwrap(); + + // Output via normal iter should panic because the global parser doesn't know 999, + // wait, we modified the normal fmt::Display to iterate, BUT Display iterates the multiaddr. + // If we try `addr.to_string()`, it will panic if it's not a generic iterator. + + let vec = addr.to_vec(); + // Parse back from vec + let parsed = registry.try_from_bytes(&vec).unwrap(); + + let mut iter = registry.iter(&parsed); + if let Some(Protocol::Ip4(ip)) = iter.next() { + assert_eq!(ip, std::net::Ipv4Addr::new(127, 0, 0, 1)); + } else { + panic!("expected ip4"); + } + + if let Some(Protocol::Custom { def, data }) = iter.next() { + assert_eq!(def.code, 999); + assert_eq!(def.name, "my-custom"); + assert_eq!(data.as_ref(), b"helloworld"); + } else { + panic!("expected custom protocol"); + } + } + + #[test] + fn test_unregister_builtin() { + let mut registry = Registry::default(); + + // Assert tcp works + let addr = registry.try_from_str("/ip4/127.0.0.1/tcp/80").unwrap(); + let mut iter = registry.iter(&addr); + assert!(matches!(iter.next(), Some(Protocol::Ip4(_)))); + assert!(matches!(iter.next(), Some(Protocol::Tcp(80)))); + + // Unregister tcp + registry.unregister_by_name("tcp"); + + // Assert tcp fails now + assert!(registry.try_from_str("/ip4/127.0.0.1/tcp/80").is_err()); + + // And similarly from bytes + let vec = addr.to_vec(); + assert!(registry.try_from_bytes(&vec).is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index b6b0ad4..264fb11 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,9 +7,14 @@ mod errors; mod onion_addr; mod protocol; +#[cfg(feature = "custom")] +mod custom; + #[cfg(feature = "url")] mod from_url; +#[cfg(feature = "custom")] +pub use self::custom::{CustomProtocolDef, Registry, Transcoder}; pub use self::errors::{Error, Result}; pub use self::onion_addr::Onion3Addr; pub use self::protocol::Protocol; @@ -223,7 +228,7 @@ impl Multiaddr { /// Returns &str identifiers for the protocol names themselves. /// This omits specific info like addresses, ports, peer IDs, and the like. /// Example: `"/ip4/127.0.0.1/tcp/5001"` would return `["ip4", "tcp"]` - pub fn protocol_stack(&self) -> ProtoStackIter { + pub fn protocol_stack(&self) -> ProtoStackIter<'_> { ProtoStackIter { parts: self.iter() } } } diff --git a/src/protocol.rs b/src/protocol.rs index f0fa1d4..5137919 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -60,6 +60,49 @@ const P2P_STARDUST: u32 = 277; // Deprecated const WEBRTC: u32 = 281; const HTTP_PATH: u32 = 481; +#[cfg(feature = "custom")] +pub(crate) const BUILT_IN_PROTOCOLS: &[(&str, u32, i32, bool)] = &[ + ("ip4", IP4, 4, false), + ("tcp", TCP, 2, false), + ("udp", UDP, 2, false), + ("dccp", DCCP, 2, false), + ("ip6", IP6, 16, false), + ("p2p", P2P, -1, false), + ("ipfs", P2P, -1, false), + ("http", HTTP, 0, false), + ("https", HTTPS, 0, false), + ("onion", ONION, 12, false), + ("onion3", ONION3, 37, false), + ("quic", QUIC, 0, false), + ("quic-v1", QUIC_V1, 0, false), + ("ws", WS, 0, false), + ("wss", WSS, 0, false), + ("p2p-websocket-star", P2P_WEBSOCKET_STAR, 0, false), + ("webrtc-direct", WEBRTC_DIRECT, 0, false), + ("p2p-webrtc-direct", P2P_WEBRTC_DIRECT, 0, false), + ("certhash", CERTHASH, -1, false), + ("p2p-circuit", P2P_CIRCUIT, 0, false), + ("sctp", SCTP, 2, false), + ("udt", UDT, 0, false), + ("utp", UTP, 0, false), + ("unix", UNIX, -1, true), + ("dns", DNS, -1, false), + ("dns4", DNS4, -1, false), + ("dns6", DNS6, -1, false), + ("dnsaddr", DNSADDR, -1, false), + ("tls", TLS, 0, false), + ("noise", NOISE, 0, false), + ("webtransport", WEBTRANSPORT, 0, false), + ("ip6zone", IP6ZONE, -1, true), + ("ipcidr", IPCIDR, 1, false), + ("garlic64", GARLIC64, -1, false), + ("garlic32", GARLIC32, -1, false), + ("sni", SNI, -1, false), + ("webrtc", WEBRTC, 0, false), + ("http-path", HTTP_PATH, -1, true), + ("memory", MEMORY, 8, false), +]; + /// Type-alias for how multi-addresses use `Multihash`. /// /// The `64` defines the allocation size for the digest within the `Multihash`. @@ -130,6 +173,13 @@ pub enum Protocol<'a> { P2pStardust, WebRTC, HttpPath(Cow<'a, str>), + #[cfg(feature = "custom")] + Custom { + def: std::sync::Arc, + data: Cow<'a, [u8]>, + }, + #[cfg(feature = "custom")] + Unknown(u32, Cow<'a, [u8]>), } impl<'a> Protocol<'a> { @@ -625,6 +675,19 @@ impl<'a> Protocol<'a> { w.write_all(encode::usize(bytes.len(), &mut encode::usize_buffer()))?; w.write_all(bytes)? } + #[cfg(feature = "custom")] + Protocol::Custom { def, data } => { + w.write_all(encode::u32(def.code, &mut buf))?; + if def.size == -1 { + w.write_all(encode::usize(data.len(), &mut encode::usize_buffer()))?; + } + w.write_all(data.as_ref())? + } + #[cfg(feature = "custom")] + Protocol::Unknown(code, data) => { + w.write_all(encode::u32(*code, &mut buf))?; + w.write_all(data.as_ref())? + } } Ok(()) } @@ -673,6 +736,13 @@ impl<'a> Protocol<'a> { P2pStardust => P2pStardust, WebRTC => WebRTC, HttpPath(cow) => HttpPath(Cow::Owned(cow.into_owned())), + #[cfg(feature = "custom")] + Custom { def, data } => Custom { + def: def.clone(), + data: Cow::Owned(data.into_owned()), + }, + #[cfg(feature = "custom")] + Unknown(code, data) => Unknown(code, Cow::Owned(data.into_owned())), } } @@ -721,6 +791,10 @@ impl<'a> Protocol<'a> { P2pStardust => "p2p-stardust", WebRTC => "webrtc", HttpPath(_) => "http-path", + #[cfg(feature = "custom")] + Custom { def, .. } => def.name, + #[cfg(feature = "custom")] + Unknown(_, _) => "unknown", } } } @@ -728,7 +802,11 @@ impl<'a> Protocol<'a> { impl fmt::Display for Protocol<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use self::Protocol::*; - write!(f, "/{}", self.tag())?; + match self { + #[cfg(feature = "custom")] + Unknown(code, _) => write!(f, "/unknown-{code}"), + _ => write!(f, "/{}", self.tag()), + }?; match self { Dccp(port) => write!(f, "/{port}"), Dns(s) => write!(f, "/{s}"), @@ -783,6 +861,32 @@ impl fmt::Display for Protocol<'_> { percent_encoding::percent_encode(s.as_bytes(), PATH_SEGMENT_ENCODE_SET); write!(f, "/{encoded}") } + #[cfg(feature = "custom")] + Custom { def, data } => { + if let Some(t) = &def.transcoder { + let s = t.bytes_to_string(data.as_ref()).map_err(|_| fmt::Error)?; + if !s.is_empty() { + write!(f, "/{s}")?; + } + Ok(()) + } else if data.is_empty() { + Ok(()) + } else if def.path { + let s = std::str::from_utf8(data.as_ref()).map_err(|_| fmt::Error)?; + let encoded = + percent_encoding::percent_encode(s.as_bytes(), PATH_SEGMENT_ENCODE_SET); + write!(f, "/{encoded}") + } else { + write!(f, "/{}", multibase::Base::Base58Btc.encode(data.as_ref())) + } + } + #[cfg(feature = "custom")] + Unknown(_, data) => { + if !data.is_empty() { + write!(f, "/{}", multibase::Base::Base58Btc.encode(data.as_ref()))?; + } + Ok(()) + } _ => Ok(()), } }