Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ jobs:
name: ${{ matrix.artifact }}

- name: Install dependencies
run: npm install -g npm@latest && npm ci
run: npm ci

- name: Run tests
run: npm test
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ libc = "0.2"
socket2 = { version = "0.5", features = ["all"] }
ring = "0.17"
md-5 = "0.10"
zeroize = "1"

[build-dependencies]
napi-build = "2.1"
Expand Down
2 changes: 1 addition & 1 deletion __test__/suspended.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ describe('Suspended routes – hold then resolve', () => {
// resolved route builds its own TLS config (the original route's config
// is not reused for resolved connections).
proxy.resolveConnection(capturedConn!.id, {
upstreams: [{ kind: 'tcp', host: '127.0.0.1', port: echo.port }],
upstream: { kind: 'tcp', host: '127.0.0.1', port: echo.port },
terminateTls: true,
cert: { certChain: cert.cert, privateKey: cert.key },
});
Expand Down
60 changes: 56 additions & 4 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 19 additions & 9 deletions src/protection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ impl BlockReason {

#[derive(Debug)]
pub enum Decision {
/// Connection allowed and active counter incremented — release() must be called on close.
Allow,
/// Connection allowed via allowlist — active counter was NOT incremented; release() is a no-op.
AllowBypassed,
Block(BlockReason),
}

Expand Down Expand Up @@ -135,10 +138,10 @@ impl ProtectionState {
pub fn check(&self, peer_ip: IpAddr, peek_info: &PeekInfo) -> Decision {
let cfg = self.config.load();

// 1. Allowlist — skip all other checks
// 1. Allowlist — skip all other checks; active counter is NOT incremented
for network in &self.allowlist {
if network.contains(peer_ip) {
return Decision::Allow;
return Decision::AllowBypassed;
}
}

Expand Down Expand Up @@ -187,7 +190,11 @@ impl ProtectionState {
.compare_exchange(old_tokens, new_tokens, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
state.last_refill_ns.store(now, Ordering::Relaxed);
// CAS on the timestamp so only the first winner advances the refill window.
// A losing compare_exchange here is harmless — another thread already wrote it.
let _ = state.last_refill_ns.compare_exchange(
last, now, Ordering::Relaxed, Ordering::Relaxed,
);
break;
}
// CAS failed — another thread beat us; retry
Expand All @@ -210,16 +217,19 @@ impl ProtectionState {
}
}

// 6. Concurrency limit
// 6. Concurrency limit — atomic test-and-increment to avoid TOCTOU
if cfg.max_concurrent_per_ip > 0 {
let active = state.active.load(Ordering::Relaxed);
if active >= cfg.max_concurrent_per_ip {
let max = cfg.max_concurrent_per_ip;
let result = state.active.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
if v < max { Some(v + 1) } else { None }
});
if result.is_err() {
return Decision::Block(BlockReason::TooManyConnections);
}
// fetch_update already incremented the counter
} else {
state.active.fetch_add(1, Ordering::Relaxed);
}

// Allow — increment active counter
state.active.fetch_add(1, Ordering::Relaxed);
Decision::Allow
}

Expand Down
74 changes: 46 additions & 28 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pub struct JsBlockedIpsInfo {

#[napi(object)]
pub struct JsResolveRoute {
pub upstreams: Vec<JsUpstream>,
pub upstream: JsUpstream,
pub terminate_tls: bool,
pub cert: Option<JsCertConfig>,
pub mtls: Option<JsMtlsConfig>,
Expand Down Expand Up @@ -197,13 +197,17 @@ impl SymphonyProxyWrap {
.unwrap_or_else(ListenerTlsSpec::empty);

let worker_threads = config.worker_threads.unwrap_or(num_cpus()) as usize;
let idle_timeout = Duration::from_millis(
config
.listeners
.first()
.and_then(|l| l.idle_timeout_ms)
.unwrap_or(60_000.0) as u64,
);
// 0 means "no idle timeout" — stored as Duration::ZERO and checked in proxy_conn.rs.
let idle_timeout_ms = config
.listeners
.first()
.and_then(|l| l.idle_timeout_ms)
.unwrap_or(60_000.0);
let idle_timeout = if idle_timeout_ms > 0.0 {
Duration::from_millis(idle_timeout_ms as u64)
} else {
Duration::ZERO
};
let read_buffer_size = config.read_buffer_size.unwrap_or(65_536) as usize;

let mut internal_listeners = Vec::new();
Expand Down Expand Up @@ -308,6 +312,14 @@ impl SymphonyProxyWrap {

for (i, listener) in self.listeners.iter().enumerate() {
let state = &self.listener_states[i];
// Use the TLS handshake timeout as the upstream connect timeout when
// protection is configured; otherwise fall back to 30 s.
let upstream_connect_timeout = state
.protection
.as_ref()
.map(|p| p.config.load().tls_handshake_timeout())
.unwrap_or(std::time::Duration::from_secs(30));

let ctx = Arc::new(ConnContext {
route_table: self.route_table.clone(),
protection: state.protection.clone(),
Expand All @@ -316,6 +328,7 @@ impl SymphonyProxyWrap {
listener_metrics: state.metrics.clone(),
listener_addr: state.addr.clone(),
idle_timeout: self.idle_timeout,
upstream_connect_timeout,
read_buffer_size: self.read_buffer_size,
js_emit: self.js_emit.clone(),
});
Expand Down Expand Up @@ -463,7 +476,7 @@ fn parse_route_spec(r: &JsRouteConfig) -> Result<RouteSpec> {
cert_pem: r.cert.as_ref().map(|c| pem_bytes(&c.cert_chain)),
key_pem: r.cert.as_ref().map(|c| pem_bytes(&c.private_key)),
mtls_ca_pem: r.mtls.as_ref().map(|m| pem_bytes(&m.client_ca_cert)),
require_client_cert: r.mtls.as_ref().and_then(|m| m.require_client_cert).unwrap_or(true),
require_client_cert: r.mtls.as_ref().and_then(|m| m.require_client_cert).unwrap_or(false),
suspended: r.suspended.unwrap_or(false),
suspend_timeout_ms: r.suspend_timeout_ms.unwrap_or(30_000.0) as u64,
max_cps: r.max_connections_per_second,
Expand Down Expand Up @@ -512,21 +525,17 @@ fn parse_upstream_spec(u: &JsUpstream, sni: &str) -> Result<UpstreamSpec> {
}

fn parse_resolve_spec(r: &JsResolveRoute) -> Result<ResolveSpec> {
let upstream = r
.upstreams
.first()
.ok_or_else(|| napi::Error::from_reason("resolveConnection: upstreams must not be empty".to_string()))
.and_then(|u| match parse_upstream_spec(u, "<resolved>")? {
UpstreamSpec::Tcp { host, port } => {
let addr = format!("{host}:{port}")
.parse()
.map_err(|e| napi::Error::from_reason(format!("invalid address: {e}")))?;
Ok(ResolveUpstream::Tcp(addr))
}
UpstreamSpec::Uds { paths, ip_affinity, affinity_ttl_ms, .. } => {
Ok(ResolveUpstream::Uds { paths, ip_affinity, affinity_ttl_ms })
}
})?;
let upstream = match parse_upstream_spec(&r.upstream, "<resolved>")? {
UpstreamSpec::Tcp { host, port } => {
let addr = format!("{host}:{port}")
.parse()
.map_err(|e| napi::Error::from_reason(format!("invalid address: {e}")))?;
ResolveUpstream::Tcp(addr)
}
UpstreamSpec::Uds { paths, ip_affinity, affinity_ttl_ms, .. } => {
ResolveUpstream::Uds { paths, ip_affinity, affinity_ttl_ms }
}
};

let has_uds = matches!(&upstream, ResolveUpstream::Uds { .. });
let source_address_mode = parse_source_address_mode(r.source_address_header.as_deref(), has_uds)?;
Expand All @@ -537,7 +546,7 @@ fn parse_resolve_spec(r: &JsResolveRoute) -> Result<ResolveSpec> {
cert_pem: r.cert.as_ref().map(|c| pem_bytes(&c.cert_chain)),
key_pem: r.cert.as_ref().map(|c| pem_bytes(&c.private_key)),
mtls_ca_pem: r.mtls.as_ref().map(|m| pem_bytes(&m.client_ca_cert)),
require_client_cert: r.mtls.as_ref().and_then(|m| m.require_client_cert).unwrap_or(true),
require_client_cert: r.mtls.as_ref().and_then(|m| m.require_client_cert).unwrap_or(false),
source_address_mode,
http2: r.http2.unwrap_or(false),
})
Expand Down Expand Up @@ -575,8 +584,11 @@ fn parse_protection_config(
.as_deref()
.unwrap_or(&[])
.iter()
.filter_map(|s| s.parse().ok())
.collect();
.map(|s| {
s.parse::<IpNetwork>()
.map_err(|e| napi::Error::from_reason(format!("invalid allowlist CIDR '{s}': {e}")))
})
.collect::<Result<Vec<_>>>()?;

let blocklist_strings: Vec<String> = prot
.blocklist
Expand All @@ -586,7 +598,13 @@ fn parse_protection_config(
.cloned()
.collect();

let blocklist: Vec<IpNetwork> = blocklist_strings.iter().filter_map(|s| s.parse().ok()).collect();
let blocklist: Vec<IpNetwork> = blocklist_strings
.iter()
.map(|s| {
s.parse::<IpNetwork>()
.map_err(|e| napi::Error::from_reason(format!("invalid blocklist CIDR '{s}': {e}")))
})
.collect::<Result<Vec<_>>>()?;

Ok((cfg, allowlist, blocklist, blocklist_strings))
}
Expand Down
Loading
Loading