Skip to content

Commit faca367

Browse files
committed
restrict ServiceFlags api
1 parent d0d40cd commit faca367

5 files changed

Lines changed: 62 additions & 99 deletions

File tree

dash-spv/src/network/handshake.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::net::SocketAddr;
44
use std::time::{Duration, SystemTime, UNIX_EPOCH};
55

66
use dashcore::network::constants;
7-
use dashcore::network::constants::{ServiceFlags, NODE_HEADERS_COMPRESSED};
7+
use dashcore::network::constants::ServiceFlags;
88
use dashcore::network::message::NetworkMessage;
99
use dashcore::network::message_network::VersionMessage;
1010
use dashcore::Network;
@@ -36,7 +36,7 @@ pub struct HandshakeManager {
3636
state: HandshakeState,
3737
our_version: u32,
3838
peer_version: Option<u32>,
39-
peer_services: Option<ServiceFlags>,
39+
peer_services: ServiceFlags,
4040
version_received: bool,
4141
verack_received: bool,
4242
version_sent: bool,
@@ -56,7 +56,7 @@ impl HandshakeManager {
5656
state: HandshakeState::Init,
5757
our_version: constants::PROTOCOL_VERSION,
5858
peer_version: None,
59-
peer_services: None,
59+
peer_services: ServiceFlags::NONE,
6060
version_received: false,
6161
verack_received: false,
6262
version_sent: false,
@@ -157,7 +157,7 @@ impl HandshakeManager {
157157
version_msg
158158
);
159159
self.peer_version = Some(version_msg.version);
160-
self.peer_services = Some(version_msg.services);
160+
self.peer_services = version_msg.services;
161161
self.version_received = true;
162162

163163
// Update connection's peer information
@@ -261,7 +261,7 @@ impl HandshakeManager {
261261
.as_secs() as i64;
262262

263263
// Advertise headers2 support (NODE_HEADERS_COMPRESSED)
264-
let services = ServiceFlags::NONE | NODE_HEADERS_COMPRESSED;
264+
let services = ServiceFlags::NODE_HEADERS_COMPRESSED;
265265

266266
// Parse the local address safely
267267
let local_addr = "127.0.0.1:0"
@@ -313,7 +313,7 @@ impl HandshakeManager {
313313

314314
/// Check if peer supports headers2 compression.
315315
pub fn peer_supports_headers2(&self) -> bool {
316-
self.peer_services.map(|services| services.has(NODE_HEADERS_COMPRESSED)).unwrap_or(false)
316+
self.peer_services.has(ServiceFlags::NODE_HEADERS_COMPRESSED)
317317
}
318318

319319
/// Negotiate headers2 support with the peer after handshake completion.

dash-spv/src/network/peer.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ pub struct Peer {
4040
pending_pings: HashMap<u64, SystemTime>, // nonce -> sent_time
4141
// Peer information from Version message
4242
version: Option<u32>,
43-
services: Option<u64>,
43+
services: ServiceFlags,
4444
user_agent: Option<String>,
4545
best_height: Option<u32>,
4646
relay: Option<bool>,
@@ -68,7 +68,7 @@ impl Peer {
6868
last_pong_received: None,
6969
pending_pings: HashMap::new(),
7070
version: None,
71-
services: None,
71+
services: ServiceFlags::NONE,
7272
user_agent: None,
7373
best_height: None,
7474
relay: None,
@@ -115,7 +115,7 @@ impl Peer {
115115
last_pong_received: None,
116116
pending_pings: HashMap::new(),
117117
version: None,
118-
services: None,
118+
services: ServiceFlags::NONE,
119119
user_agent: None,
120120
best_height: None,
121121
relay: None,
@@ -144,7 +144,7 @@ impl Peer {
144144
}
145145

146146
pub fn has_service(&self, flags: ServiceFlags) -> bool {
147-
self.services.map(|s| ServiceFlags::from(s).has(flags)).unwrap_or(false)
147+
self.services.has(flags)
148148
}
149149

150150
/// Connect to the peer (instance method for compatibility).
@@ -273,7 +273,7 @@ impl Peer {
273273

274274
// All validations passed, update peer info
275275
self.version = Some(version_msg.version);
276-
self.services = Some(version_msg.services.as_u64());
276+
self.services = version_msg.services;
277277
self.user_agent = Some(version_msg.user_agent.clone());
278278
self.best_height = Some(version_msg.start_height as u32);
279279
self.relay = Some(version_msg.relay);
@@ -824,12 +824,7 @@ impl Peer {
824824
// We can request headers2 if peer has the service flag for headers2 support
825825
// Note: We don't wait for SendHeaders2 from peer as that creates a race condition
826826
// during initial sync. The service flag is sufficient to know they support headers2.
827-
if let Some(services) = self.services {
828-
dashcore::network::constants::ServiceFlags::from(services)
829-
.has(dashcore::network::constants::NODE_HEADERS_COMPRESSED)
830-
} else {
831-
false
832-
}
827+
self.services.has(ServiceFlags::NODE_HEADERS_COMPRESSED)
833828
}
834829
}
835830

dash/src/network/address.rs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ impl Encodable for AddrV2Message {
308308
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
309309
let mut len = 0;
310310
len += self.time.consensus_encode(w)?;
311-
len += VarInt(self.services.as_u64()).consensus_encode(w)?;
311+
len += self.services.consensus_encode(w)?;
312312
len += self.addr.consensus_encode(w)?;
313313

314314
w.write_all(&self.port.to_be_bytes())?;
@@ -322,7 +322,7 @@ impl Decodable for AddrV2Message {
322322
fn consensus_decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, encode::Error> {
323323
Ok(AddrV2Message {
324324
time: Decodable::consensus_decode(r)?,
325-
services: ServiceFlags::from(VarInt::consensus_decode(r)?.0),
325+
services: ServiceFlags::consensus_decode(r)?,
326326
addr: Decodable::consensus_decode(r)?,
327327
port: u16::swap_bytes(Decodable::consensus_decode(r)?),
328328
})
@@ -412,16 +412,20 @@ mod test {
412412

413413
#[test]
414414
fn test_socket_addr() {
415+
let mut services = ServiceFlags::NETWORK;
416+
services.add(ServiceFlags::WITNESS);
417+
415418
let s4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(111, 222, 123, 4)), 5555);
416-
let a4 = Address::new(&s4, ServiceFlags::NETWORK | ServiceFlags::WITNESS);
419+
let a4 = Address::new(&s4, services);
417420
assert_eq!(a4.socket_addr().unwrap(), s4);
421+
418422
let s6 = SocketAddr::new(
419423
IpAddr::V6(Ipv6Addr::new(
420424
0x1111, 0x2222, 0x3333, 0x4444, 0x5555, 0x6666, 0x7777, 0x8888,
421425
)),
422426
9999,
423427
);
424-
let a6 = Address::new(&s6, ServiceFlags::NETWORK | ServiceFlags::WITNESS);
428+
let a6 = Address::new(&s6, services);
425429
assert_eq!(a6.socket_addr().unwrap(), s6);
426430
}
427431

@@ -577,19 +581,26 @@ mod test {
577581
let raw = hex!("0261bc6649019902abab208d79627683fd4804010409090909208d");
578582
let addresses: Vec<AddrV2Message> = deserialize(&raw).unwrap();
579583

584+
let mut services1 = ServiceFlags::NONE;
585+
services1.add(ServiceFlags::NETWORK);
586+
587+
let mut services2 = ServiceFlags::NONE;
588+
services2
589+
.add(ServiceFlags::NETWORK_LIMITED)
590+
.add(ServiceFlags::WITNESS)
591+
.add(ServiceFlags::COMPACT_FILTERS);
592+
580593
assert_eq!(
581594
addresses,
582595
vec![
583596
AddrV2Message {
584-
services: ServiceFlags::NETWORK,
597+
services: services1,
585598
time: 0x4966bc61,
586599
port: 8333,
587600
addr: AddrV2::Unknown(153, hex!("abab"))
588601
},
589602
AddrV2Message {
590-
services: ServiceFlags::NETWORK_LIMITED
591-
| ServiceFlags::WITNESS
592-
| ServiceFlags::COMPACT_FILTERS,
603+
services: services2,
593604
time: 0x83766279,
594605
port: 8333,
595606
addr: AddrV2::Ipv4(Ipv4Addr::new(9, 9, 9, 9))

dash/src/network/constants.rs

Lines changed: 17 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,12 @@
3333
//! assert_eq!(&bytes[..], &[0xBF, 0x0C, 0x6B, 0xBD]);
3434
//! ```
3535
36-
use core::convert::From;
37-
use core::{fmt, ops};
36+
use core::fmt;
3837

3938
use hashes::Hash;
4039

4140
use crate::consensus::encode::{self, Decodable, Encodable};
42-
use crate::{BlockHash, io};
43-
44-
// Re-export NODE_HEADERS_COMPRESSED for convenience
45-
pub const NODE_HEADERS_COMPRESSED: ServiceFlags = ServiceFlags::NODE_HEADERS_COMPRESSED;
41+
use crate::{BlockHash, VarInt, io};
4642

4743
/// Version of the protocol as appearing in network message headers
4844
/// This constant is used to signal to other peers which features you support.
@@ -307,57 +303,18 @@ impl fmt::Display for ServiceFlags {
307303
}
308304
}
309305

310-
impl From<u64> for ServiceFlags {
311-
fn from(f: u64) -> Self {
312-
ServiceFlags(f)
313-
}
314-
}
315-
316-
impl From<ServiceFlags> for u64 {
317-
fn from(val: ServiceFlags) -> Self {
318-
val.0
319-
}
320-
}
321-
322-
impl ops::BitOr for ServiceFlags {
323-
type Output = Self;
324-
325-
fn bitor(mut self, rhs: Self) -> Self {
326-
self.add(rhs)
327-
}
328-
}
329-
330-
impl ops::BitOrAssign for ServiceFlags {
331-
fn bitor_assign(&mut self, rhs: Self) {
332-
self.add(rhs);
333-
}
334-
}
335-
336-
impl ops::BitXor for ServiceFlags {
337-
type Output = Self;
338-
339-
fn bitxor(mut self, rhs: Self) -> Self {
340-
self.remove(rhs)
341-
}
342-
}
343-
344-
impl ops::BitXorAssign for ServiceFlags {
345-
fn bitxor_assign(&mut self, rhs: Self) {
346-
self.remove(rhs);
347-
}
348-
}
349-
350306
impl Encodable for ServiceFlags {
351307
#[inline]
352308
fn consensus_encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<usize, io::Error> {
353-
self.0.consensus_encode(w)
309+
VarInt(self.0).consensus_encode(w)
354310
}
355311
}
356312

357313
impl Decodable for ServiceFlags {
358314
#[inline]
359315
fn consensus_decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, encode::Error> {
360-
Ok(ServiceFlags(Decodable::consensus_decode(r)?))
316+
let services = VarInt::consensus_decode(r)?;
317+
Ok(ServiceFlags(services.0))
361318
}
362319
}
363320

@@ -434,27 +391,27 @@ mod tests {
434391
assert!(!flags.has(*f));
435392
}
436393

437-
flags |= ServiceFlags::WITNESS;
394+
flags.add(ServiceFlags::WITNESS);
438395
assert_eq!(flags, ServiceFlags::WITNESS);
439396

440-
let mut flags2 = flags | ServiceFlags::GETUTXO;
397+
flags.add(ServiceFlags::GETUTXO);
441398
for f in all.iter() {
442-
assert_eq!(flags2.has(*f), *f == ServiceFlags::WITNESS || *f == ServiceFlags::GETUTXO);
399+
assert_eq!(flags.has(*f), *f == ServiceFlags::WITNESS || *f == ServiceFlags::GETUTXO);
443400
}
444401

445-
flags2 ^= ServiceFlags::WITNESS;
446-
assert_eq!(flags2, ServiceFlags::GETUTXO);
402+
flags.remove(ServiceFlags::WITNESS);
403+
assert_eq!(flags, ServiceFlags::GETUTXO);
447404

448-
flags2 |= ServiceFlags::COMPACT_FILTERS;
449-
flags2 ^= ServiceFlags::GETUTXO;
450-
assert_eq!(flags2, ServiceFlags::COMPACT_FILTERS);
405+
flags.add(ServiceFlags::COMPACT_FILTERS);
406+
flags.remove(ServiceFlags::GETUTXO);
407+
assert_eq!(flags, ServiceFlags::COMPACT_FILTERS);
451408

452409
// Test formatting.
453410
assert_eq!("ServiceFlags(NONE)", ServiceFlags::NONE.to_string());
454411
assert_eq!("ServiceFlags(WITNESS)", ServiceFlags::WITNESS.to_string());
455-
let flag = ServiceFlags::WITNESS | ServiceFlags::BLOOM | ServiceFlags::NETWORK;
456-
assert_eq!("ServiceFlags(NETWORK|BLOOM|WITNESS)", flag.to_string());
457-
let flag = ServiceFlags::WITNESS | 0xf0.into();
458-
assert_eq!("ServiceFlags(WITNESS|COMPACT_FILTERS|0xb0)", flag.to_string());
412+
413+
let mut flags = ServiceFlags::NONE;
414+
flags.add(ServiceFlags::WITNESS).add(ServiceFlags::BLOOM).add(ServiceFlags::NETWORK);
415+
assert_eq!("ServiceFlags(NETWORK|BLOOM|WITNESS)", flags.to_string());
459416
}
460417
}

dash/src/network/message.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -934,13 +934,13 @@ mod test {
934934
assert_eq!(msg.magic, 0xd9b4bef9);
935935
if let NetworkMessage::Version(version_msg) = msg.payload {
936936
assert_eq!(version_msg.version, 70015);
937-
assert_eq!(
938-
version_msg.services,
939-
ServiceFlags::NETWORK
940-
| ServiceFlags::BLOOM
941-
| ServiceFlags::WITNESS
942-
| ServiceFlags::NETWORK_LIMITED
943-
);
937+
let mut expected_services = ServiceFlags::NONE;
938+
expected_services
939+
.add(ServiceFlags::NETWORK)
940+
.add(ServiceFlags::BLOOM)
941+
.add(ServiceFlags::WITNESS)
942+
.add(ServiceFlags::NETWORK_LIMITED);
943+
assert_eq!(version_msg.services, expected_services);
944944
assert_eq!(version_msg.timestamp, 1548554224);
945945
assert_eq!(version_msg.nonce, 13952548347456104954);
946946
assert_eq!(version_msg.user_agent, "/Satoshi:0.17.1/");
@@ -979,13 +979,13 @@ mod test {
979979
assert_eq!(msg.magic, 0xd9b4bef9);
980980
if let NetworkMessage::Version(version_msg) = msg.payload {
981981
assert_eq!(version_msg.version, 70015);
982-
assert_eq!(
983-
version_msg.services,
984-
ServiceFlags::NETWORK
985-
| ServiceFlags::BLOOM
986-
| ServiceFlags::WITNESS
987-
| ServiceFlags::NETWORK_LIMITED
988-
);
982+
let mut expected_services = ServiceFlags::NONE;
983+
expected_services
984+
.add(ServiceFlags::NETWORK)
985+
.add(ServiceFlags::BLOOM)
986+
.add(ServiceFlags::WITNESS)
987+
.add(ServiceFlags::NETWORK_LIMITED);
988+
assert_eq!(version_msg.services, expected_services);
989989
assert_eq!(version_msg.timestamp, 1548554224);
990990
assert_eq!(version_msg.nonce, 13952548347456104954);
991991
assert_eq!(version_msg.user_agent, "/Satoshi:0.17.1/");

0 commit comments

Comments
 (0)