-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathguards.rs
More file actions
192 lines (158 loc) · 5.6 KB
/
guards.rs
File metadata and controls
192 lines (158 loc) · 5.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
//! Thread spawn guards for limiting multi-agent capabilities.
use super::ThreadId;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::Mutex;
/// Deepest spawn depth that is still permitted.
///
/// Depth `0` is the initial thread and depth `1` is a direct child of it; any
/// depth strictly greater than this value is rejected.
pub const MAX_THREAD_SPAWN_DEPTH: i32 = 1;

/// Concurrency limit applied by [`Guards::new`] when no explicit limit is given.
pub const DEFAULT_MAX_THREADS: usize = 10;
/// Multi-agent spawn limits for a single user session.
///
/// A slot is occupied either by a registered thread (tracked in `threads_set`)
/// or by an outstanding [`SpawnReservation`] (tracked in `pending_count`).
pub struct Guards {
    /// Hard cap on registered threads plus pending reservations.
    max_threads: usize,
    /// IDs of threads that have been registered and not yet released.
    threads_set: Mutex<HashSet<ThreadId>>,
    /// Reservations handed out that have been neither committed nor released.
    pending_count: AtomicUsize,
    /// Reservations taken and not cancelled: committed spawns (even after the
    /// thread is released) plus reservations that are still pending.
    total_count: AtomicUsize,
}
impl Guards {
    /// Creates guards that allow at most [`DEFAULT_MAX_THREADS`] concurrent threads.
    pub fn new() -> Self {
        Self::with_max_threads(DEFAULT_MAX_THREADS)
    }

    /// Creates guards that allow at most `max_threads` concurrent threads.
    ///
    /// The limit covers registered threads plus outstanding (uncommitted)
    /// reservations. A limit of `0` makes every reservation attempt fail.
    pub fn with_max_threads(max_threads: usize) -> Self {
        Self {
            threads_set: Mutex::new(HashSet::new()),
            total_count: AtomicUsize::new(0),
            pending_count: AtomicUsize::new(0),
            max_threads,
        }
    }

    /// Reserves a spawn slot and returns an RAII [`SpawnReservation`] for it.
    ///
    /// Returns `None` when registered threads plus outstanding reservations have
    /// already reached `max_threads`. On success both the pending and the total
    /// counters are incremented; cancelling or dropping the reservation without
    /// committing it undoes both increments.
    pub async fn reserve_spawn_slot(self: &Arc<Self>) -> Option<SpawnReservation> {
        // The thread-set lock stays held across the capacity check *and* the
        // increment, so two concurrent reservers cannot both claim the last slot.
        let threads = self.threads_set.lock().await;
        let in_use = threads.len() + self.pending_count.load(Ordering::Acquire);
        if in_use >= self.max_threads {
            return None;
        }
        self.pending_count.fetch_add(1, Ordering::AcqRel);
        self.total_count.fetch_add(1, Ordering::AcqRel);
        Some(SpawnReservation {
            state: Arc::clone(self),
            active: true,
        })
    }

    /// Records `thread_id` as an active thread.
    ///
    /// This does not consult `max_threads`. Inserting an id that is already
    /// present leaves the set unchanged.
    pub(crate) async fn register_spawned_thread(&self, thread_id: ThreadId) {
        self.threads_set.lock().await.insert(thread_id);
    }

    /// Removes `thread_id` from the active set, freeing its slot.
    ///
    /// Ids that are not currently registered are ignored.
    pub async fn release_spawned_thread(&self, thread_id: ThreadId) {
        self.threads_set.lock().await.remove(&thread_id);
    }

    /// Returns the number of registered threads (pending reservations excluded).
    pub async fn active_count(&self) -> usize {
        self.threads_set.lock().await.len()
    }

    /// Returns how many reservations were taken and not cancelled or dropped.
    ///
    /// This includes committed spawns (even after the thread was released) and
    /// reservations that are still pending.
    pub fn total_spawned(&self) -> usize {
        self.total_count.load(Ordering::Acquire)
    }

    /// Returns the configured concurrency limit.
    pub fn max_threads(&self) -> usize {
        self.max_threads
    }
}
impl Default for Guards {
fn default() -> Self {
Self::new()
}
}
/// RAII guard for a reserved spawn slot, produced by [`Guards::reserve_spawn_slot`].
///
/// While alive and uncommitted, the reservation counts against the limit.
/// Committing it turns the slot into a registered thread. Cancelling or dropping
/// it without committing releases the slot and undoes the total-count increment.
#[must_use = "dropping a SpawnReservation immediately releases the reserved slot"]
pub struct SpawnReservation {
    /// Guards whose counters this reservation adjusts.
    state: Arc<Guards>,
    /// `true` until the reservation has been committed or cancelled.
    active: bool,
}
impl SpawnReservation {
    /// Commits the reservation, registering `thread_id` as an active thread.
    ///
    /// The slot moves from "pending" to "registered". The total count is left
    /// unchanged because it was already incremented when the slot was reserved.
    /// If the returned future is dropped while still waiting for the lock, the
    /// reservation is released by `Drop` instead.
    pub async fn commit(mut self, thread_id: ThreadId) {
        let mut threads = self.state.threads_set.lock().await;
        threads.insert(thread_id);
        // Decrement while still holding the lock. Otherwise a concurrent
        // `reserve_spawn_slot` could briefly see this slot counted both as a
        // registered thread and as a pending reservation, and reject a request
        // even though a slot is actually free.
        self.state.pending_count.fetch_sub(1, Ordering::AcqRel);
        drop(threads);
        // The slot now belongs to the registered thread; `Drop` must not release it.
        self.active = false;
    }

    /// Cancels the reservation without spawning, releasing its slot immediately.
    ///
    /// Equivalent to dropping the reservation.
    pub fn cancel(self) {
        drop(self);
    }
}
impl Drop for SpawnReservation {
    /// Releases an uncommitted reservation by undoing both counter increments
    /// made in `reserve_spawn_slot`; committed reservations are left alone.
    fn drop(&mut self) {
        if !self.active {
            return;
        }
        let guards = &self.state;
        guards.pending_count.fetch_sub(1, Ordering::AcqRel);
        guards.total_count.fetch_sub(1, Ordering::AcqRel);
    }
}
/// Returns `true` when `depth` is deeper than [`MAX_THREAD_SPAWN_DEPTH`] permits.
pub fn exceeds_thread_spawn_depth_limit(depth: i32) -> bool {
    let within_limit = depth <= MAX_THREAD_SPAWN_DEPTH;
    !within_limit
}
/// Returns the spawn depth of a child of a thread at `current_depth`.
///
/// Saturates at `i32::MAX` instead of overflowing. A plain `+ 1` would panic in
/// debug builds and wrap to `i32::MIN` in release builds. A wrapped (negative)
/// depth would no longer compare as greater than the spawn-depth limit.
pub fn next_spawn_depth(current_depth: i32) -> i32 {
    current_depth.saturating_add(1)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_guards_reservation() {
        let guards = Arc::new(Guards::with_max_threads(2));
        // Reserve first slot
        let res1 = guards.reserve_spawn_slot().await;
        assert!(res1.is_some());
        // Reserve second slot
        let res2 = guards.reserve_spawn_slot().await;
        assert!(res2.is_some());
        // Third should fail
        let res3 = guards.reserve_spawn_slot().await;
        assert!(res3.is_none());
        // Commit first reservation
        let thread_id = ThreadId::new();
        res1.unwrap().commit(thread_id).await;
        assert_eq!(guards.active_count().await, 1);
        // Release the thread
        guards.release_spawned_thread(thread_id).await;
        assert_eq!(guards.active_count().await, 0);
    }

    #[tokio::test]
    async fn test_uncommitted_reservation_released_on_drop_and_cancel() {
        let guards = Arc::new(Guards::with_max_threads(1));

        let res = guards.reserve_spawn_slot().await.expect("first slot is free");
        assert_eq!(guards.total_spawned(), 1);
        assert!(guards.reserve_spawn_slot().await.is_none());

        // Dropping without committing frees the slot and undoes the total count.
        drop(res);
        assert_eq!(guards.total_spawned(), 0);

        // Cancelling behaves the same way.
        let res = guards.reserve_spawn_slot().await.expect("slot freed by drop");
        res.cancel();
        assert_eq!(guards.total_spawned(), 0);
        assert!(guards.reserve_spawn_slot().await.is_some());
    }

    #[tokio::test]
    async fn test_released_thread_frees_slot() {
        let guards = Arc::new(Guards::with_max_threads(1));
        let thread_id = ThreadId::new();
        guards
            .reserve_spawn_slot()
            .await
            .expect("first slot is free")
            .commit(thread_id)
            .await;
        assert_eq!(guards.active_count().await, 1);
        assert!(guards.reserve_spawn_slot().await.is_none());

        guards.release_spawned_thread(thread_id).await;
        let res = guards.reserve_spawn_slot().await;
        assert!(res.is_some());
        // Committed spawns stay in the total even after the thread is released.
        assert_eq!(guards.total_spawned(), 2);
    }

    #[tokio::test]
    async fn test_zero_limit_rejects_everything() {
        let guards = Arc::new(Guards::with_max_threads(0));
        assert!(guards.reserve_spawn_slot().await.is_none());
        assert_eq!(guards.total_spawned(), 0);
    }

    #[test]
    fn test_default_limits() {
        let guards = Guards::default();
        assert_eq!(guards.max_threads(), DEFAULT_MAX_THREADS);
        assert_eq!(guards.total_spawned(), 0);
    }

    #[test]
    fn test_depth_limit() {
        assert!(!exceeds_thread_spawn_depth_limit(0));
        assert!(!exceeds_thread_spawn_depth_limit(1));
        assert!(exceeds_thread_spawn_depth_limit(2));
        assert!(exceeds_thread_spawn_depth_limit(5));
    }

    #[test]
    fn test_next_spawn_depth() {
        assert_eq!(next_spawn_depth(0), 1);
        assert!(exceeds_thread_spawn_depth_limit(next_spawn_depth(
            MAX_THREAD_SPAWN_DEPTH
        )));
    }
}