92 changes: 75 additions & 17 deletions anyscan_rate_controller.py
@@ -19,16 +19,18 @@

from __future__ import annotations

import fcntl
import json
import os
import re
import subprocess
import sys
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Callable, Iterable, Mapping, Optional
from typing import Callable, Iterable, Iterator, Mapping, Optional


DEFAULT_FLOOR = 100_000
@@ -251,8 +253,13 @@ class RateCalibrationStore:
"""Thin JSON-on-disk persistence for per-interface learned rates.

Keeps state across scans so a freshly dispatched worker doesn't have to
re-discover its ceiling on every job. Writes go through a per-pid
tempfile + atomic rename so a half-written file can never be observed,
and the read-modify-write cycle is serialized across processes via
fcntl.flock on a sibling lockfile so the multi-NIC parent's concurrent
shards don't clobber each other (each shard converges on its own
    interface; before the lock existed, the last writer's stale view of
    {interfaces: {iface_X: ...}} silently wiped every other shard's entry).
"""

SCHEMA_VERSION = 1
@@ -264,6 +271,10 @@ def __init__(self, path: os.PathLike[str] | str = DEFAULT_CALIBRATION_PATH) -> None:
def path(self) -> Path:
return self._path

@property
def _lock_path(self) -> Path:
return self._path.with_suffix(self._path.suffix + ".lock")

def load(self) -> dict[str, CalibrationEntry]:
try:
raw = self._path.read_text()
Expand Down Expand Up @@ -296,27 +307,74 @@ def lookup(self, interface: str) -> Optional[CalibrationEntry]:
def store(self, interface: str, learned_rate: int, *, now_iso: Optional[str] = None) -> None:
if learned_rate <= 0:
return
try:
self._path.parent.mkdir(parents=True, exist_ok=True)
except OSError:
return
timestamp = now_iso if now_iso is not None else _utc_now_iso()
with self._locked_for_write():
# Re-read inside the lock so concurrent shards observe each
# other's entries instead of clobbering with a stale view.
entries = self.load()
entries[interface] = CalibrationEntry(learned_rate, timestamp)
payload = {
"version": self.SCHEMA_VERSION,
"interfaces": {
key: {"learned_rate": entry.learned_rate, "updated_at": entry.updated_at}
for key, entry in entries.items()
},
}
# Per-pid tmp filename so concurrent shards don't overwrite
# each other's pre-rename payloads. The flock above already
# serializes them, but per-pid tmp keeps cleanup well-defined
# if a shard dies between write and rename.
tmp_path = self._path.with_suffix(
self._path.suffix + f".tmp.{os.getpid()}"
)
try:
tmp_path.write_text(json.dumps(payload, sort_keys=True))
os.replace(tmp_path, self._path)
except OSError:
try:
tmp_path.unlink()
except OSError:
pass

@contextmanager
def _locked_for_write(self) -> Iterator[None]:
"""Hold an exclusive flock for the lifetime of a read-modify-write.

        On hosts where flock is unsupported (rare; most local filesystems
        on Linux support it) the lock acquisition is best-effort: we fall
        through to the unprotected write rather than blocking calibration
        entirely. Contention is bounded: the held interval is a few
        milliseconds (json.dumps + tempfile write + rename).
"""

lock_handle = None
try:
lock_handle = open(self._lock_path, "w")
except OSError:
yield
return
try:
try:
fcntl.flock(lock_handle.fileno(), fcntl.LOCK_EX)
except OSError:
# Filesystem refused the lock; proceed unprotected so a
# missing lock primitive doesn't drop calibration entirely.
yield
return
try:
yield
finally:
try:
fcntl.flock(lock_handle.fileno(), fcntl.LOCK_UN)
except OSError:
pass
finally:
try:
lock_handle.close()
except OSError:
pass
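
For reviewers who haven't met this pattern before: a minimal, self-contained sketch of the same flock + per-pid tempfile + atomic-rename discipline, reduced to its essentials. The names here (update_shared_json and friends) are illustrative, not part of this module:

import fcntl
import json
import os

def update_shared_json(path: str, key: str, value: int) -> None:
    """Serialized read-modify-write of a JSON file shared by processes."""
    with open(path + ".lock", "w") as lock_handle:
        # Exclusive lock: one writer reads, merges, and publishes at a
        # time, so nobody writes from a stale snapshot of the file.
        fcntl.flock(lock_handle.fileno(), fcntl.LOCK_EX)
        try:
            with open(path) as fh:
                data = json.load(fh)
        except (OSError, ValueError):
            data = {}
        data[key] = value
        # Per-pid temp name: a writer killed between write and rename
        # leaves debris only under its own name, never a shared one.
        tmp = f"{path}.tmp.{os.getpid()}"
        with open(tmp, "w") as fh:
            fh.write(json.dumps(data, sort_keys=True))
        # os.replace is atomic on POSIX: readers observe either the old
        # file or the new one, never a half-written intermediate.
        os.replace(tmp, path)
        # The flock is released when lock_handle closes at "with" exit.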

142 changes: 142 additions & 0 deletions test_anyscan_rate_controller.py
@@ -151,6 +151,96 @@ def test_zero_or_negative_rates_skipped(self) -> None:
store.store("eth0", -100)
self.assertFalse(path.exists())

def test_concurrent_writes_from_multiple_processes_all_persist(self) -> None:
"""Multi-NIC parent spawns N shard children; each child's controller
        writes to the SAME calibration file when it converges. Pre-fix, the
        children clobbered each other (shared tmp filename + read-modify-write
        race), so only one shard's calibration survived; post-fix, every
        shard's entry lands.
"""

import multiprocessing as mp

def _store_in_subprocess(path_str: str, interface: str, rate: int) -> None:
# Re-import inside the worker so the subprocess has a clean
# module state (no shared file handles across fork).
from pathlib import Path as _Path
import anyscan_rate_controller as _rc

store = _rc.RateCalibrationStore(_Path(path_str))
store.store(interface, rate, now_iso="2026-04-27T12:00:00Z")

with tempfile.TemporaryDirectory() as tmpdir:
path = Path(tmpdir) / "rate-calibration.json"
# Six shards, fired simultaneously. Each writes a distinct
# interface; the per-interface rate value is the rate the
# synthetic AIMD loop "converged" to.
interfaces = [(f"eth{i}", 1_000_000 + i * 100_000) for i in range(6)]
ctx = mp.get_context("fork")
workers = [
ctx.Process(
target=_store_in_subprocess,
args=(str(path), iface, rate),
)
for iface, rate in interfaces
]
for w in workers:
w.start()
for w in workers:
w.join(timeout=10)
self.assertEqual(w.exitcode, 0, msg=f"worker {w.pid} failed")

store = rc.RateCalibrationStore(path)
entries = store.load()
# Pre-fix: only one entry survived (last-writer-wins on the
# shared tmp + clobbered .json). Post-fix: all six landed.
self.assertEqual(
set(entries.keys()),
{iface for iface, _ in interfaces},
msg=f"expected all 6 interfaces persisted, got {list(entries.keys())}",
)
for iface, rate in interfaces:
self.assertEqual(entries[iface].learned_rate, rate)

def test_concurrent_writes_leave_no_dangling_tmp_files(self) -> None:
"""The pre-fix race left orphan .tmp files on disk when the second
os.replace failed (source already moved). Post-fix every writer
cleans up its own per-pid tmp regardless of outcome.
"""

import multiprocessing as mp

def _store_in_subprocess(path_str: str, interface: str, rate: int) -> None:
from pathlib import Path as _Path
import anyscan_rate_controller as _rc

store = _rc.RateCalibrationStore(_Path(path_str))
store.store(interface, rate)

with tempfile.TemporaryDirectory() as tmpdir:
path = Path(tmpdir) / "rate-calibration.json"
ctx = mp.get_context("fork")
workers = [
ctx.Process(
target=_store_in_subprocess,
args=(str(path), f"eth{i}", 500_000 + i * 100_000),
)
for i in range(8)
]
for w in workers:
w.start()
for w in workers:
w.join(timeout=10)
self.assertEqual(w.exitcode, 0)

# No stray .tmp.* files: success and failure paths both clean
# up after themselves.
stray = list(Path(tmpdir).glob("rate-calibration.json.tmp*"))
self.assertEqual(stray, [], msg=f"orphan tmp files left: {stray}")
# The final file is valid JSON we can round-trip.
store = rc.RateCalibrationStore(path)
self.assertEqual(len(store.load()), 8)


@dataclass
class StubWindow:
@@ -853,6 +943,58 @@ def test_persists_when_crash_strikes_before_any_clean_window(self) -> None:
self.assertIsNotNone(entry)
self.assertEqual(entry.learned_rate, 1_500_000) # untouched

def test_persists_when_interrupted_by_systemexit_mid_loop(self) -> None:
"""Signal-handler exit path: SIGTERM/SIGINT in the adapter raises
SystemExit (handle_termination in vulnscanner-zmap-adapter.py).
That exception unwinds through SubprocessWindowRunner.run() and
out of RateController.run()'s while loop. The controller's
try/finally must still persist max_clean_rate so the calibration
learned in the windows that DID complete is not lost on a SIGTERM
from the multi-NIC parent or agentd.
"""

class InterruptingRunner(rc.WindowRunner):
def __init__(self) -> None:
self._idx = 0

def run(self, *, rate, window_seconds, is_first_window):
self._idx += 1
if self._idx == 1:
return make_measurement(set_rate=rate, achieved_pps=rate * 0.97)
if self._idx == 2:
return make_measurement(
set_rate=rate, achieved_pps=rate * 0.97
)
# Window 3: simulate the adapter's SIGTERM handler which
                # raises SystemExit(128 + signum). A real signal handler
                # cannot reliably raise while blocked in a child.wait()
                # syscall, but the exception path the controller has to
                # survive is identical to the one a Python-level handler
                # would produce.
raise SystemExit(143)

with tempfile.TemporaryDirectory() as tmpdir:
calib = rc.RateCalibrationStore(Path(tmpdir) / "rate-calibration.json")
policy = rc.AimdPolicy(window_seconds=30)
controller = rc.RateController(
options=rc.ControllerOptions(
policy=policy,
window_seconds=float(policy.window_seconds),
interface="eth2",
starting_rate=500_000,
calibration=calib,
),
runner=InterruptingRunner(),
log_sink=io.StringIO(),
)
with self.assertRaises(SystemExit):
controller.run()
entry = calib.lookup("eth2")
self.assertIsNotNone(entry, "SystemExit must not bypass terminal persist")
# Highest CLEAN rate observed before the interrupt was 700k
# (window 2 ran at 700k after window 1's clean@500k bumped it).
self.assertEqual(entry.learned_rate, 700_000)
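
The guarantee this test leans on is plain try/finally semantics: SystemExit unwinds like any other exception, so a finally-guarded persist still runs. A runnable toy showing just that property; run_windows, flaky_measure, and the 0.95 clean threshold are illustrative stand-ins, not the controller's real internals:

def run_windows(rates, measure, persist):
    max_clean = 0
    try:
        for rate in rates:
            achieved = measure(rate)  # may raise SystemExit (SIGTERM path)
            if achieved >= rate * 0.95:  # window counted as "clean"
                max_clean = max(max_clean, rate)
    finally:
        # Runs on natural finish, on crash, and on SystemExit alike.
        if max_clean > 0:
            persist(max_clean)

def flaky_measure(rate):
    if rate >= 900_000:
        raise SystemExit(143)  # 128 + SIGTERM, as in the adapter
    return rate * 0.97

learned = []
try:
    run_windows([500_000, 700_000, 900_000], flaky_measure, learned.append)
except SystemExit:
    pass
assert learned == [700_000]  # max of the two clean windows survived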

def test_persists_after_natural_finish(self) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
calib = rc.RateCalibrationStore(Path(tmpdir) / "rate-calibration.json")