diff --git a/Cargo.toml b/Cargo.toml index 72e1d04..71fafe6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ version = "0.18.3" [features] default = ["url"] +custom = [] [dependencies] arrayref = "0.3" diff --git a/src/custom.rs b/src/custom.rs new file mode 100644 index 0000000..3311e00 --- /dev/null +++ b/src/custom.rs @@ -0,0 +1,363 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use crate::{Error, Multiaddr, Protocol, Result}; + +/// A transcoder defines how to encode and decode a custom protocol's data +/// between its binary representation and its human-readable string representation. +pub trait Transcoder: Send + Sync { + /// Attempts to parse the human-readable string component of a protocol into bytes. + fn string_to_bytes( + &self, + s: &str, + ) -> std::result::Result, Box>; + + /// Attempts to format the binary representation of a protocol's data into a human-readable string. + fn bytes_to_string( + &self, + bytes: &[u8], + ) -> std::result::Result>; +} + +/// A custom protocol definition. +pub struct CustomProtocolDef { + pub name: &'static str, + pub code: u32, + /// The length of the binary payload. + /// `0` means no data. `> 0` means a fixed data length. `-1` denotes a length-prefixed protocol. + pub size: i32, + pub path: bool, + pub transcoder: Option>, +} + +impl std::cmp::PartialEq for CustomProtocolDef { + fn eq(&self, other: &Self) -> bool { + self.code == other.code && self.name == other.name + } +} + +impl std::cmp::Eq for CustomProtocolDef {} + +impl std::fmt::Debug for CustomProtocolDef { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CustomProtocolDef") + .field("name", &self.name) + .field("code", &self.code) + .field("size", &self.size) + .field("path", &self.path) + .finish() + } +} + +/// A registry mapping protocol codes and names to their custom definitions. +#[derive(Clone)] +pub struct Registry { + by_code: HashMap>, + by_name: HashMap>, +} + +impl Default for Registry { + fn default() -> Self { + let mut r = Self { + by_code: HashMap::new(), + by_name: HashMap::new(), + }; + r.register_builtins(); + r + } +} + +impl Registry { + /// Create a new, empty protocol registry. Wait, new() actually uses default and adds built-ins. + pub fn new() -> Self { + Self::default() + } + + /// Add all built-in standard protocols to the registry + fn register_builtins(&mut self) { + for &(name, code, size, path) in crate::protocol::BUILT_IN_PROTOCOLS.iter() { + self.register(CustomProtocolDef { + name, + code, + size, + path, + transcoder: None, + }); + } + } + + /// Add a custom protocol definition to this registry. + pub fn register(&mut self, mut def: CustomProtocolDef) { + if def.path && def.name.starts_with('/') { + def.name = def.name.trim_start_matches('/'); + } + let name = def.name.to_string(); + let code = def.code; + let arc = Arc::new(def); + self.by_code.insert(code, arc.clone()); + self.by_name.insert(name, arc); + } + + /// Returns a registered custom protocol by its integer code. + pub fn get_by_code(&self, code: u32) -> Option> { + self.by_code.get(&code).cloned() + } + + /// Returns a registered custom protocol by its string name. + pub fn get_by_name(&self, name: &str) -> Option> { + self.by_name.get(name).cloned() + } + + /// Unregisters a protocol by its string name. + pub fn unregister_by_name(&mut self, name: &str) { + if let Some(def) = self.by_name.remove(name) { + self.by_code.remove(&def.code); + } + } + + /// Unregisters a protocol by its integer code. + pub fn unregister_by_code(&mut self, code: u32) { + if let Some(def) = self.by_code.remove(&code) { + self.by_name.remove(def.name); + } + } + + /// Iterate over the protocols in a `Multiaddr` using this registry. + pub fn iter<'a>(&'a self, ma: &'a Multiaddr) -> RegistryIter<'a> { + RegistryIter { + registry: self, + data: ma.as_ref(), + } + } +} + +/// Iterator over protocols using a registry. +pub struct RegistryIter<'a> { + registry: &'a Registry, + data: &'a [u8], +} + +impl<'a> Iterator for RegistryIter<'a> { + type Item = Protocol<'a>; + + fn next(&mut self) -> Option { + if self.data.is_empty() { + return None; + } + + let (p, next_data) = self.registry.parse_protocol_from_bytes(self.data).ok()?; + self.data = next_data; + Some(p) + } +} + +impl Registry { + /// Try parsing a single Protocol from bytes using the registry. + pub fn parse_protocol_from_bytes<'a>( + &self, + input: &'a [u8], + ) -> Result<(Protocol<'a>, &'a [u8])> { + let n_input = input; + let id_res = unsigned_varint::decode::u32(n_input); + if let Ok((id, _rest)) = id_res { + if !self.by_code.contains_key(&id) { + return Err(Error::UnknownProtocolId(id)); + } + } + + if let Ok(res) = Protocol::from_bytes(input) { + return Ok(res); + } + + let n_input = input; + let id_res = unsigned_varint::decode::u32(n_input); + if let Ok((id, rest)) = id_res { + if let Some(def) = self.get_by_code(id) { + let (data, out_rest) = if def.size == 0 { + (std::borrow::Cow::Borrowed(&rest[..0]), rest) + } else if def.size > 0 { + let fixed = def.size as usize; + if rest.len() < fixed { + return Err(Error::DataLessThanLen); + } + let (d, r) = rest.split_at(fixed); + (std::borrow::Cow::Borrowed(d), r) + } else { + let (len, r) = + unsigned_varint::decode::usize(rest).map_err(|_| Error::DataLessThanLen)?; + if r.len() < len { + return Err(Error::DataLessThanLen); + } + let (d, r2) = r.split_at(len); + (std::borrow::Cow::Borrowed(d), r2) + }; + return Ok((Protocol::Custom { def, data }, out_rest)); + } + } + + Err(Error::UnknownProtocolId( + id_res.map(|(i, _)| i).unwrap_or(0), + )) + } + + /// Try parsing a single Protocol from string parts using the registry. + pub fn parse_protocol_from_str_parts<'a, I>(&self, iter: &mut I) -> Result> + where + I: Iterator + Clone, + { + let mut peek_iter = iter.clone(); + if let Some(tag) = peek_iter.next() { + if !self.by_name.contains_key(tag) { + return Err(Error::UnknownProtocolString(tag.to_string())); + } + } + + let mut native_iter = iter.clone(); + if let Ok(p) = Protocol::from_str_parts(&mut native_iter) { + *iter = native_iter; + return Ok(p); + } + + let mut peek_iter = iter.clone(); + if let Some(tag) = peek_iter.next() { + if let Some(def) = self.get_by_name(tag) { + iter.next(); // consume the tag + let data = if def.size == 0 { + vec![] + } else if let Some(t) = &def.transcoder { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + t.string_to_bytes(part) + .map_err(|_| Error::InvalidProtocolString)? + } else if def.path { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + percent_encoding::percent_decode(part.as_bytes()).collect::>() + } else { + let part = iter.next().ok_or(Error::InvalidProtocolString)?; + multibase::Base::Base58Btc + .decode(part) + .map_err(|_| Error::InvalidProtocolString)? + }; + return Ok(Protocol::Custom { + def, + data: std::borrow::Cow::Owned(data), + }); + } + } + + let mut final_try = iter.clone(); + if let Some(tag) = final_try.next() { + Err(Error::UnknownProtocolString(tag.to_string())) + } else { + Err(Error::InvalidProtocolString) + } + } + + /// Parse a Multiaddr string using this registry + pub fn try_from_str(&self, input: &str) -> Result { + let mut addr = Multiaddr::empty(); + let mut parts = input.split('/'); + + if Some("") != parts.next() { + return Err(Error::InvalidMultiaddr); + } + + while parts.clone().peekable().peek().is_some() { + let p = self.parse_protocol_from_str_parts(&mut parts)?; + addr = addr.with(p); + } + + Ok(addr) + } + + /// Parse a Multiaddr from bytes using this registry + pub fn try_from_bytes(&self, mut input: &[u8]) -> Result { + let mut addr = Multiaddr::empty(); + while !input.is_empty() { + let (p, rest) = self.parse_protocol_from_bytes(input)?; + addr = addr.with(p); + input = rest; + } + Ok(addr) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct SimpleTranscoder; + impl Transcoder for SimpleTranscoder { + fn string_to_bytes( + &self, + s: &str, + ) -> std::result::Result, Box> { + Ok(s.as_bytes().to_vec()) + } + fn bytes_to_string( + &self, + bytes: &[u8], + ) -> std::result::Result> { + Ok(String::from_utf8(bytes.to_vec())?) + } + } + + #[test] + fn test_custom_protocol_registry() { + let mut registry = Registry::new(); + registry.register(CustomProtocolDef { + name: "my-custom", + code: 999, + size: -1, + path: false, + transcoder: Some(Box::new(SimpleTranscoder)), + }); + + let addr = registry + .try_from_str("/ip4/127.0.0.1/my-custom/helloworld") + .unwrap(); + + // Output via normal iter should panic because the global parser doesn't know 999, + // wait, we modified the normal fmt::Display to iterate, BUT Display iterates the multiaddr. + // If we try `addr.to_string()`, it will panic if it's not a generic iterator. + + let vec = addr.to_vec(); + // Parse back from vec + let parsed = registry.try_from_bytes(&vec).unwrap(); + + let mut iter = registry.iter(&parsed); + if let Some(Protocol::Ip4(ip)) = iter.next() { + assert_eq!(ip, std::net::Ipv4Addr::new(127, 0, 0, 1)); + } else { + panic!("expected ip4"); + } + + if let Some(Protocol::Custom { def, data }) = iter.next() { + assert_eq!(def.code, 999); + assert_eq!(def.name, "my-custom"); + assert_eq!(data.as_ref(), b"helloworld"); + } else { + panic!("expected custom protocol"); + } + } + + #[test] + fn test_unregister_builtin() { + let mut registry = Registry::default(); + + // Assert tcp works + let addr = registry.try_from_str("/ip4/127.0.0.1/tcp/80").unwrap(); + let mut iter = registry.iter(&addr); + assert!(matches!(iter.next(), Some(Protocol::Ip4(_)))); + assert!(matches!(iter.next(), Some(Protocol::Tcp(80)))); + + // Unregister tcp + registry.unregister_by_name("tcp"); + + // Assert tcp fails now + assert!(registry.try_from_str("/ip4/127.0.0.1/tcp/80").is_err()); + + // And similarly from bytes + let vec = addr.to_vec(); + assert!(registry.try_from_bytes(&vec).is_err()); + } +} diff --git a/src/lib.rs b/src/lib.rs index b6b0ad4..264fb11 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,9 +7,14 @@ mod errors; mod onion_addr; mod protocol; +#[cfg(feature = "custom")] +mod custom; + #[cfg(feature = "url")] mod from_url; +#[cfg(feature = "custom")] +pub use self::custom::{CustomProtocolDef, Registry, Transcoder}; pub use self::errors::{Error, Result}; pub use self::onion_addr::Onion3Addr; pub use self::protocol::Protocol; @@ -223,7 +228,7 @@ impl Multiaddr { /// Returns &str identifiers for the protocol names themselves. /// This omits specific info like addresses, ports, peer IDs, and the like. /// Example: `"/ip4/127.0.0.1/tcp/5001"` would return `["ip4", "tcp"]` - pub fn protocol_stack(&self) -> ProtoStackIter { + pub fn protocol_stack(&self) -> ProtoStackIter<'_> { ProtoStackIter { parts: self.iter() } } } diff --git a/src/protocol.rs b/src/protocol.rs index f0fa1d4..5137919 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -60,6 +60,49 @@ const P2P_STARDUST: u32 = 277; // Deprecated const WEBRTC: u32 = 281; const HTTP_PATH: u32 = 481; +#[cfg(feature = "custom")] +pub(crate) const BUILT_IN_PROTOCOLS: &[(&str, u32, i32, bool)] = &[ + ("ip4", IP4, 4, false), + ("tcp", TCP, 2, false), + ("udp", UDP, 2, false), + ("dccp", DCCP, 2, false), + ("ip6", IP6, 16, false), + ("p2p", P2P, -1, false), + ("ipfs", P2P, -1, false), + ("http", HTTP, 0, false), + ("https", HTTPS, 0, false), + ("onion", ONION, 12, false), + ("onion3", ONION3, 37, false), + ("quic", QUIC, 0, false), + ("quic-v1", QUIC_V1, 0, false), + ("ws", WS, 0, false), + ("wss", WSS, 0, false), + ("p2p-websocket-star", P2P_WEBSOCKET_STAR, 0, false), + ("webrtc-direct", WEBRTC_DIRECT, 0, false), + ("p2p-webrtc-direct", P2P_WEBRTC_DIRECT, 0, false), + ("certhash", CERTHASH, -1, false), + ("p2p-circuit", P2P_CIRCUIT, 0, false), + ("sctp", SCTP, 2, false), + ("udt", UDT, 0, false), + ("utp", UTP, 0, false), + ("unix", UNIX, -1, true), + ("dns", DNS, -1, false), + ("dns4", DNS4, -1, false), + ("dns6", DNS6, -1, false), + ("dnsaddr", DNSADDR, -1, false), + ("tls", TLS, 0, false), + ("noise", NOISE, 0, false), + ("webtransport", WEBTRANSPORT, 0, false), + ("ip6zone", IP6ZONE, -1, true), + ("ipcidr", IPCIDR, 1, false), + ("garlic64", GARLIC64, -1, false), + ("garlic32", GARLIC32, -1, false), + ("sni", SNI, -1, false), + ("webrtc", WEBRTC, 0, false), + ("http-path", HTTP_PATH, -1, true), + ("memory", MEMORY, 8, false), +]; + /// Type-alias for how multi-addresses use `Multihash`. /// /// The `64` defines the allocation size for the digest within the `Multihash`. @@ -130,6 +173,13 @@ pub enum Protocol<'a> { P2pStardust, WebRTC, HttpPath(Cow<'a, str>), + #[cfg(feature = "custom")] + Custom { + def: std::sync::Arc, + data: Cow<'a, [u8]>, + }, + #[cfg(feature = "custom")] + Unknown(u32, Cow<'a, [u8]>), } impl<'a> Protocol<'a> { @@ -625,6 +675,19 @@ impl<'a> Protocol<'a> { w.write_all(encode::usize(bytes.len(), &mut encode::usize_buffer()))?; w.write_all(bytes)? } + #[cfg(feature = "custom")] + Protocol::Custom { def, data } => { + w.write_all(encode::u32(def.code, &mut buf))?; + if def.size == -1 { + w.write_all(encode::usize(data.len(), &mut encode::usize_buffer()))?; + } + w.write_all(data.as_ref())? + } + #[cfg(feature = "custom")] + Protocol::Unknown(code, data) => { + w.write_all(encode::u32(*code, &mut buf))?; + w.write_all(data.as_ref())? + } } Ok(()) } @@ -673,6 +736,13 @@ impl<'a> Protocol<'a> { P2pStardust => P2pStardust, WebRTC => WebRTC, HttpPath(cow) => HttpPath(Cow::Owned(cow.into_owned())), + #[cfg(feature = "custom")] + Custom { def, data } => Custom { + def: def.clone(), + data: Cow::Owned(data.into_owned()), + }, + #[cfg(feature = "custom")] + Unknown(code, data) => Unknown(code, Cow::Owned(data.into_owned())), } } @@ -721,6 +791,10 @@ impl<'a> Protocol<'a> { P2pStardust => "p2p-stardust", WebRTC => "webrtc", HttpPath(_) => "http-path", + #[cfg(feature = "custom")] + Custom { def, .. } => def.name, + #[cfg(feature = "custom")] + Unknown(_, _) => "unknown", } } } @@ -728,7 +802,11 @@ impl<'a> Protocol<'a> { impl fmt::Display for Protocol<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use self::Protocol::*; - write!(f, "/{}", self.tag())?; + match self { + #[cfg(feature = "custom")] + Unknown(code, _) => write!(f, "/unknown-{code}"), + _ => write!(f, "/{}", self.tag()), + }?; match self { Dccp(port) => write!(f, "/{port}"), Dns(s) => write!(f, "/{s}"), @@ -783,6 +861,32 @@ impl fmt::Display for Protocol<'_> { percent_encoding::percent_encode(s.as_bytes(), PATH_SEGMENT_ENCODE_SET); write!(f, "/{encoded}") } + #[cfg(feature = "custom")] + Custom { def, data } => { + if let Some(t) = &def.transcoder { + let s = t.bytes_to_string(data.as_ref()).map_err(|_| fmt::Error)?; + if !s.is_empty() { + write!(f, "/{s}")?; + } + Ok(()) + } else if data.is_empty() { + Ok(()) + } else if def.path { + let s = std::str::from_utf8(data.as_ref()).map_err(|_| fmt::Error)?; + let encoded = + percent_encoding::percent_encode(s.as_bytes(), PATH_SEGMENT_ENCODE_SET); + write!(f, "/{encoded}") + } else { + write!(f, "/{}", multibase::Base::Base58Btc.encode(data.as_ref())) + } + } + #[cfg(feature = "custom")] + Unknown(_, data) => { + if !data.is_empty() { + write!(f, "/{}", multibase::Base::Base58Btc.encode(data.as_ref()))?; + } + Ok(()) + } _ => Ok(()), } }