Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ version = "0.18.3"

[features]
default = ["url"]
custom = []

[dependencies]
arrayref = "0.3"
Expand Down
363 changes: 363 additions & 0 deletions src/custom.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,363 @@
use std::collections::HashMap;
use std::sync::Arc;

use crate::{Error, Multiaddr, Protocol, Result};

/// A transcoder defines how to encode and decode a custom protocol's data
/// between its binary representation and its human-readable string representation.
pub trait Transcoder: Send + Sync {
/// Attempts to parse the human-readable string component of a protocol into bytes.
fn string_to_bytes(
&self,
s: &str,
) -> std::result::Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>;

/// Attempts to format the binary representation of a protocol's data into a human-readable string.
fn bytes_to_string(
&self,
bytes: &[u8],
) -> std::result::Result<String, Box<dyn std::error::Error + Send + Sync>>;
}

/// A custom protocol definition.
pub struct CustomProtocolDef {
pub name: &'static str,
pub code: u32,
/// The length of the binary payload.
/// `0` means no data. `> 0` means a fixed data length. `-1` denotes a length-prefixed protocol.
pub size: i32,
pub path: bool,
pub transcoder: Option<Box<dyn Transcoder>>,
}

impl std::cmp::PartialEq for CustomProtocolDef {
fn eq(&self, other: &Self) -> bool {
self.code == other.code && self.name == other.name
}
}

impl std::cmp::Eq for CustomProtocolDef {}

impl std::fmt::Debug for CustomProtocolDef {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CustomProtocolDef")
.field("name", &self.name)
.field("code", &self.code)
.field("size", &self.size)
.field("path", &self.path)
.finish()
}
}

/// A registry mapping protocol codes and names to their custom definitions.
#[derive(Clone)]
pub struct Registry {
by_code: HashMap<u32, Arc<CustomProtocolDef>>,
by_name: HashMap<String, Arc<CustomProtocolDef>>,
}

impl Default for Registry {
fn default() -> Self {
let mut r = Self {
by_code: HashMap::new(),
by_name: HashMap::new(),
};
r.register_builtins();
r
}
}

impl Registry {
/// Create a new, empty protocol registry. Wait, new() actually uses default and adds built-ins.
pub fn new() -> Self {
Self::default()
}

/// Add all built-in standard protocols to the registry
fn register_builtins(&mut self) {
for &(name, code, size, path) in crate::protocol::BUILT_IN_PROTOCOLS.iter() {
self.register(CustomProtocolDef {
name,
code,
size,
path,
transcoder: None,
});
}
}

/// Add a custom protocol definition to this registry.
pub fn register(&mut self, mut def: CustomProtocolDef) {
if def.path && def.name.starts_with('/') {
def.name = def.name.trim_start_matches('/');
}
let name = def.name.to_string();
let code = def.code;
let arc = Arc::new(def);
self.by_code.insert(code, arc.clone());
self.by_name.insert(name, arc);
}

/// Returns a registered custom protocol by its integer code.
pub fn get_by_code(&self, code: u32) -> Option<Arc<CustomProtocolDef>> {
self.by_code.get(&code).cloned()
}

/// Returns a registered custom protocol by its string name.
pub fn get_by_name(&self, name: &str) -> Option<Arc<CustomProtocolDef>> {
self.by_name.get(name).cloned()
}

/// Unregisters a protocol by its string name.
pub fn unregister_by_name(&mut self, name: &str) {
if let Some(def) = self.by_name.remove(name) {
self.by_code.remove(&def.code);
}
}

/// Unregisters a protocol by its integer code.
pub fn unregister_by_code(&mut self, code: u32) {
if let Some(def) = self.by_code.remove(&code) {
self.by_name.remove(def.name);
}
}

/// Iterate over the protocols in a `Multiaddr` using this registry.
pub fn iter<'a>(&'a self, ma: &'a Multiaddr) -> RegistryIter<'a> {
RegistryIter {
registry: self,
data: ma.as_ref(),
}
}
}

/// Iterator over protocols using a registry.
pub struct RegistryIter<'a> {
registry: &'a Registry,
data: &'a [u8],
}

impl<'a> Iterator for RegistryIter<'a> {
type Item = Protocol<'a>;

fn next(&mut self) -> Option<Self::Item> {
if self.data.is_empty() {
return None;
}

let (p, next_data) = self.registry.parse_protocol_from_bytes(self.data).ok()?;
self.data = next_data;
Some(p)
}
}

impl Registry {
/// Try parsing a single Protocol from bytes using the registry.
pub fn parse_protocol_from_bytes<'a>(
&self,
input: &'a [u8],
) -> Result<(Protocol<'a>, &'a [u8])> {
let n_input = input;
let id_res = unsigned_varint::decode::u32(n_input);
if let Ok((id, _rest)) = id_res {
if !self.by_code.contains_key(&id) {
return Err(Error::UnknownProtocolId(id));
}
}

if let Ok(res) = Protocol::from_bytes(input) {
return Ok(res);
}

let n_input = input;
let id_res = unsigned_varint::decode::u32(n_input);
if let Ok((id, rest)) = id_res {
if let Some(def) = self.get_by_code(id) {
let (data, out_rest) = if def.size == 0 {
(std::borrow::Cow::Borrowed(&rest[..0]), rest)
} else if def.size > 0 {
let fixed = def.size as usize;
if rest.len() < fixed {
return Err(Error::DataLessThanLen);
}
let (d, r) = rest.split_at(fixed);
(std::borrow::Cow::Borrowed(d), r)
} else {
let (len, r) =
unsigned_varint::decode::usize(rest).map_err(|_| Error::DataLessThanLen)?;
if r.len() < len {
return Err(Error::DataLessThanLen);
}
let (d, r2) = r.split_at(len);
(std::borrow::Cow::Borrowed(d), r2)
};
return Ok((Protocol::Custom { def, data }, out_rest));
}
}

Err(Error::UnknownProtocolId(
id_res.map(|(i, _)| i).unwrap_or(0),
))
}

/// Try parsing a single Protocol from string parts using the registry.
pub fn parse_protocol_from_str_parts<'a, I>(&self, iter: &mut I) -> Result<Protocol<'a>>
where
I: Iterator<Item = &'a str> + Clone,
{
let mut peek_iter = iter.clone();
if let Some(tag) = peek_iter.next() {
if !self.by_name.contains_key(tag) {
return Err(Error::UnknownProtocolString(tag.to_string()));
}
}

let mut native_iter = iter.clone();
if let Ok(p) = Protocol::from_str_parts(&mut native_iter) {
*iter = native_iter;
return Ok(p);
}

let mut peek_iter = iter.clone();
if let Some(tag) = peek_iter.next() {
if let Some(def) = self.get_by_name(tag) {
iter.next(); // consume the tag
let data = if def.size == 0 {
vec![]
} else if let Some(t) = &def.transcoder {
let part = iter.next().ok_or(Error::InvalidProtocolString)?;
t.string_to_bytes(part)
.map_err(|_| Error::InvalidProtocolString)?
} else if def.path {
let part = iter.next().ok_or(Error::InvalidProtocolString)?;
percent_encoding::percent_decode(part.as_bytes()).collect::<Vec<u8>>()
} else {
let part = iter.next().ok_or(Error::InvalidProtocolString)?;
multibase::Base::Base58Btc
.decode(part)
.map_err(|_| Error::InvalidProtocolString)?
};
return Ok(Protocol::Custom {
def,
data: std::borrow::Cow::Owned(data),
});
}
}

let mut final_try = iter.clone();
if let Some(tag) = final_try.next() {
Err(Error::UnknownProtocolString(tag.to_string()))
} else {
Err(Error::InvalidProtocolString)
}
}

/// Parse a Multiaddr string using this registry
pub fn try_from_str(&self, input: &str) -> Result<Multiaddr> {
let mut addr = Multiaddr::empty();
let mut parts = input.split('/');

if Some("") != parts.next() {
return Err(Error::InvalidMultiaddr);
}

while parts.clone().peekable().peek().is_some() {
let p = self.parse_protocol_from_str_parts(&mut parts)?;
addr = addr.with(p);
}

Ok(addr)
}

/// Parse a Multiaddr from bytes using this registry
pub fn try_from_bytes(&self, mut input: &[u8]) -> Result<Multiaddr> {
let mut addr = Multiaddr::empty();
while !input.is_empty() {
let (p, rest) = self.parse_protocol_from_bytes(input)?;
addr = addr.with(p);
input = rest;
}
Ok(addr)
}
}

#[cfg(test)]
mod tests {
use super::*;

struct SimpleTranscoder;
impl Transcoder for SimpleTranscoder {
fn string_to_bytes(
&self,
s: &str,
) -> std::result::Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
Ok(s.as_bytes().to_vec())
}
fn bytes_to_string(
&self,
bytes: &[u8],
) -> std::result::Result<String, Box<dyn std::error::Error + Send + Sync>> {
Ok(String::from_utf8(bytes.to_vec())?)
}
}

#[test]
fn test_custom_protocol_registry() {
let mut registry = Registry::new();
registry.register(CustomProtocolDef {
name: "my-custom",
code: 999,
size: -1,
path: false,
transcoder: Some(Box::new(SimpleTranscoder)),
});

let addr = registry
.try_from_str("/ip4/127.0.0.1/my-custom/helloworld")
.unwrap();

// Output via normal iter should panic because the global parser doesn't know 999,
// wait, we modified the normal fmt::Display to iterate, BUT Display iterates the multiaddr.
// If we try `addr.to_string()`, it will panic if it's not a generic iterator.

let vec = addr.to_vec();
// Parse back from vec
let parsed = registry.try_from_bytes(&vec).unwrap();

let mut iter = registry.iter(&parsed);
if let Some(Protocol::Ip4(ip)) = iter.next() {
assert_eq!(ip, std::net::Ipv4Addr::new(127, 0, 0, 1));
} else {
panic!("expected ip4");
}

if let Some(Protocol::Custom { def, data }) = iter.next() {
assert_eq!(def.code, 999);
assert_eq!(def.name, "my-custom");
assert_eq!(data.as_ref(), b"helloworld");
} else {
panic!("expected custom protocol");
}
}

#[test]
fn test_unregister_builtin() {
let mut registry = Registry::default();

// Assert tcp works
let addr = registry.try_from_str("/ip4/127.0.0.1/tcp/80").unwrap();
let mut iter = registry.iter(&addr);
assert!(matches!(iter.next(), Some(Protocol::Ip4(_))));
assert!(matches!(iter.next(), Some(Protocol::Tcp(80))));

// Unregister tcp
registry.unregister_by_name("tcp");

// Assert tcp fails now
assert!(registry.try_from_str("/ip4/127.0.0.1/tcp/80").is_err());

// And similarly from bytes
let vec = addr.to_vec();
assert!(registry.try_from_bytes(&vec).is_err());
}
}
7 changes: 6 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@ mod errors;
mod onion_addr;
mod protocol;

#[cfg(feature = "custom")]
mod custom;

#[cfg(feature = "url")]
mod from_url;

#[cfg(feature = "custom")]
pub use self::custom::{CustomProtocolDef, Registry, Transcoder};
pub use self::errors::{Error, Result};
pub use self::onion_addr::Onion3Addr;
pub use self::protocol::Protocol;
Expand Down Expand Up @@ -223,7 +228,7 @@ impl Multiaddr {
/// Returns &str identifiers for the protocol names themselves.
/// This omits specific info like addresses, ports, peer IDs, and the like.
/// Example: `"/ip4/127.0.0.1/tcp/5001"` would return `["ip4", "tcp"]`
pub fn protocol_stack(&self) -> ProtoStackIter {
pub fn protocol_stack(&self) -> ProtoStackIter<'_> {
ProtoStackIter { parts: self.iter() }
}
}
Expand Down
Loading
Loading