11use event_listener:: { Event , IntoNotification } ;
2- use parking_lot:: Mutex ;
32use std:: future:: Future ;
3+ use std:: num:: NonZero ;
44use std:: pin:: pin;
55use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
66use std:: sync:: Arc ;
77use std:: task:: Poll ;
88use std:: time:: Duration ;
99use std:: { array, iter} ;
1010
11+ use spin:: lock_api:: Mutex ;
12+
/// Identifier of a shard; `Sharded::new` passes each shard's `enumerate()` position as its ID.
1113type ShardId = usize ;
/// Type of `LockGuard::index` — presumably the connection slot's position within its shard (TODO confirm).
1214type ConnectionIndex = usize ;
1315
1416/// Delay before a task waiting in a call to `acquire()` enters the global wait queue.
1517///
1618/// We want tasks to acquire from their local shards where possible, so they don't enter
1719/// the global queue immediately.
18- const GLOBAL_QUEUE_DELAY : Duration = Duration :: from_millis ( 5 ) ;
20+ const GLOBAL_QUEUE_DELAY : Duration = Duration :: from_millis ( 10 ) ;
21+
22+ /// Delay before the first attempt to acquire from a non-local shard,
23+ /// as well as the backoff slept before each later attempt when iterating through shards.
24+ const NON_LOCAL_ACQUIRE_DELAY : Duration = Duration :: from_micros ( 100 ) ;
1925
2026pub struct Sharded < T > {
2127 shards : Box < [ ArcShard < T > ] > ,
@@ -29,11 +35,10 @@ struct Global<T> {
2935 disconnect_event : Event < LockGuard < T > > ,
3036}
3137
32- type ArcMutexGuard < T > = parking_lot :: ArcMutexGuard < parking_lot :: RawMutex , Option < T > > ;
38+ type ArcMutexGuard < T > = lock_api :: ArcMutexGuard < spin :: Mutex < ( ) > , Option < T > > ;
3339
3440pub struct LockGuard < T > {
35- // `Option` allows us to drop the guard before sending the notification.
36- // Otherwise, if the receiver wakes too quickly, it might fail to lock the mutex.
41+ // `Option` allows us to take the guard in the drop handler.
3742 locked : Option < ArcMutexGuard < T > > ,
3843 shard : ArcShard < T > ,
3944 index : ConnectionIndex ,
@@ -73,13 +78,13 @@ const MAX_SHARD_SIZE: usize = if usize::BITS > 64 {
7378} ;
7479
7580impl < T > Sharded < T > {
76- pub fn new ( connections : usize , shards : usize ) -> Sharded < T > {
81+ pub fn new ( connections : usize , shards : NonZero < usize > ) -> Sharded < T > {
7782 let global = Arc :: new ( Global {
7883 unlock_event : Event :: with_tag ( ) ,
7984 disconnect_event : Event :: with_tag ( ) ,
8085 } ) ;
8186
82- let shards = Params :: calc ( connections, shards)
87+ let shards = Params :: calc ( connections, shards. get ( ) )
8388 . shard_sizes ( )
8489 . enumerate ( )
8590 . map ( |( shard_id, size) | Shard :: new ( shard_id, size, global. clone ( ) ) )
@@ -89,8 +94,28 @@ impl<T> Sharded<T> {
8994 }
9095
9196 pub async fn acquire ( & self , connected : bool ) -> LockGuard < T > {
92- let mut acquire_local =
93- pin ! ( self . shards[ thread_id( ) % self . shards. len( ) ] . acquire( connected) ) ;
97+ if self . shards . len ( ) == 1 {
98+ return self . shards [ 0 ] . acquire ( connected) . await ;
99+ }
100+
101+ let thread_id = current_thread_id ( ) ;
102+
103+ let mut acquire_local = pin ! ( self . shards[ thread_id % self . shards. len( ) ] . acquire( connected) ) ;
104+
105+ let mut acquire_nonlocal = pin ! ( async {
106+ let mut next_shard = thread_id;
107+
108+ loop {
109+ crate :: rt:: sleep( NON_LOCAL_ACQUIRE_DELAY ) . await ;
110+
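111+ // Choose shards pseudorandomly by multiplying with a (relatively) large prime. NOTE(review): 0 is a fixed point of this step (0 * 547 % len == 0), so once `next_shard` is 0 only shard 0 is ever retried.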
112+ next_shard = ( next_shard. wrapping_mul( 547 ) ) % self . shards. len( ) ;
113+
114+ if let Some ( locked) = self . shards[ next_shard] . try_acquire( connected) {
115+ return locked;
116+ }
117+ }
118+ } ) ;
94119
95120 let mut acquire_global = pin ! ( async {
96121 crate :: rt:: sleep( GLOBAL_QUEUE_DELAY ) . await ;
@@ -113,6 +138,10 @@ impl<T> Sharded<T> {
113138 return Poll :: Ready ( locked) ;
114139 }
115140
141+ if let Poll :: Ready ( locked) = acquire_nonlocal. as_mut ( ) . poll ( cx) {
142+ return Poll :: Ready ( locked) ;
143+ }
144+
116145 if let Poll :: Ready ( locked) = acquire_global. as_mut ( ) . poll ( cx) {
117146 return Poll :: Ready ( locked) ;
118147 }
@@ -125,6 +154,9 @@ impl<T> Sharded<T> {
125154
126155impl < T > Shard < T , [ Arc < Mutex < Option < T > > > ] > {
127156 fn new ( shard_id : ShardId , len : usize , global : Arc < Global < T > > ) -> Arc < Self > {
157+ // On stable Rust, there's no way to directly construct a DST like this inside a `std::sync::Arc`.
158+ //
159+ // Instead, we construct it with an array and coerce that to a slice.
128160 macro_rules! make_array {
129161 ( $( $n: literal) ,+) => {
130162 match len {
@@ -206,6 +238,8 @@ impl<T> Shard<T, [Arc<Mutex<Option<T>>>]> {
206238
207239impl Params {
208240 fn calc ( connections : usize , mut shards : usize ) -> Params {
241+ assert_ne ! ( shards, 0 ) ;
242+
209243 let mut shard_size = connections / shards;
210244 let mut remainder = connections % shards;
211245
@@ -217,7 +251,11 @@ impl Params {
217251 } else if shard_size >= MAX_SHARD_SIZE {
218252 let new_shards = connections. div_ceil ( MAX_SHARD_SIZE ) ;
219253
220- tracing:: debug!( connections, shards, "clamping shard count to {new_shards}" ) ;
254+ tracing:: debug!(
255+ connections,
256+ shards,
257+ "shard size exceeds {MAX_SHARD_SIZE}, clamping shard count to {new_shards}"
258+ ) ;
221259
222260 shards = new_shards;
223261 shard_size = connections / shards;
@@ -239,7 +277,7 @@ impl Params {
239277 }
240278}
241279
242- fn thread_id ( ) -> usize {
280+ fn current_thread_id ( ) -> usize {
243281 // FIXME: replace this with `ThreadId::as_u64()` once that method is stabilized:
244282 // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64
245283 static THREAD_ID : AtomicUsize = AtomicUsize :: new ( 0 ) ;
0 commit comments