diff --git a/src/saev/utils/monitoring.py b/src/saev/utils/monitoring.py index 2cf41c53..2eaab493 100644 --- a/src/saev/utils/monitoring.py +++ b/src/saev/utils/monitoring.py @@ -1,5 +1,6 @@ -import dataclasses import logging +import time +from collections.abc import Callable import beartype import psutil @@ -8,87 +9,212 @@ @beartype.beartype -@dataclasses.dataclass(slots=True) -class LoaderMonitor: - last_rb: int | None = None - last_t: float | None = None - current_pid: int | None = None - can_read_io: bool = True - can_read_cpu: bool = True - warned_io: bool = False - warned_cpu: bool = False - - def collect( +class DataloaderMonitor: + """ + Tracks IO and CPU activity for the dataloader manager process and its children. + + The monitor owns the dataloader handle and psutil processes internally, so callers + simply construct it with the dataloader and then call `compute()` whenever metrics + are needed. + """ + + def __init__( self, - p_dataloader: psutil.Process | None, - p_children: list[psutil.Process], - reservoir_fill: float, - now: float, - ) -> dict[str, float]: - if p_dataloader is None: - self.current_pid = None - return {} - - if self.current_pid != p_dataloader.pid: - self.current_pid = p_dataloader.pid - self.last_rb = None - self.last_t = None - self.can_read_io = True - self.can_read_cpu = True - self.warned_io = False - self.warned_cpu = False + dataloader: object, + process_factory: Callable[[int], psutil.Process] | None = None, + ) -> None: + self.dataloader = dataloader + self.process_factory = process_factory or psutil.Process + self._reset_state() + + def attach(self, dataloader: object) -> None: + if dataloader is self.dataloader: + return + self.dataloader = dataloader + self._reset_state() + + def compute(self, now: float | None = None) -> dict[str, float]: + if now is None: + now = time.time() + + metrics: dict[str, float] = { + "loader/buffer_fill": self._get_reservoir_fill(self.dataloader) + } + + manager_pid = self._get_manager_pid(self.dataloader) + if manager_pid <= 0: + self._reset_state(preserve_warnings=True) + return metrics - metrics = {"loader/buffer_fill": reservoir_fill} + if self.current_pid != manager_pid: + self._reset_state() + self.current_pid = manager_pid + + process = self._ensure_process(manager_pid) + if process is None: + return metrics + + self._update_children(process) if self.can_read_io: - try: - io_counters = p_dataloader.io_counters() - except ( - psutil.AccessDenied, - psutil.NoSuchProcess, - psutil.ZombieProcess, - ) as err: - self.can_read_io = False - self.last_rb = None - self.last_t = None - if not self.warned_io: - logger.warning("Disabling dataloader IO metrics: %s", err) - self.warned_io = True - else: - rb = io_counters.read_bytes - if self.last_rb is None or self.last_t is None: - read_mb = 0.0 - read_mb_s = 0.0 - else: - read_mb = max(rb - self.last_rb, 0) / (1024 * 1024) - interval = max(now - self.last_t, 1e-6) - read_mb_s = read_mb / interval - self.last_rb, self.last_t = rb, now + read = self._read_bytes(process, now) + if read is not None: + read_mb, read_mb_s = read metrics["loader/read_mb"] = read_mb metrics["loader/read_mb_s"] = read_mb_s if self.can_read_cpu: - cpu_util = 0.0 - for child in p_children: - try: - cpu_util += child.cpu_percent(None) - except (psutil.NoSuchProcess, psutil.ZombieProcess): - continue - except psutil.AccessDenied: - continue + cpu_total = 0.0 + for child in self.children: + cpu = self._read_cpu_percent(child, is_parent=False) + if cpu is not None: + cpu_total += cpu + parent_cpu = self._read_cpu_percent(process, is_parent=True) + if parent_cpu is not None: + cpu_total += parent_cpu + metrics["loader/cpu_util"] = cpu_total + else: + self.warned_cpu = True + + return metrics + + # Internal helpers ----------------------------------------------------------------- + + def _reset_state(self, *, preserve_warnings: bool = False) -> None: + self.last_rb: int | None = None + self.last_t: float | None = None + self.current_pid: int | None = None + self.process: object | None = None + self.children: list[object] = [] + self.can_read_io = True + self.can_read_cpu = True + if not preserve_warnings: + self.warned_io = False + self.warned_cpu = False + + def _ensure_process(self, pid: int) -> object | None: + process = self.process + if ( + process is None + or getattr(process, "pid", None) != pid + or not self._is_running(process) + ): try: - cpu_util += p_dataloader.cpu_percent(None) - except (psutil.NoSuchProcess, psutil.ZombieProcess): + process = self.process_factory(pid) + except Exception: # noqa: BLE001 + return None + self.process = process + return process + + @staticmethod + def _is_running(process: object) -> bool: + if not hasattr(process, "is_running"): + return True + try: + return bool(process.is_running()) + except (psutil.NoSuchProcess, psutil.ZombieProcess): + return False + except Exception: # noqa: BLE001 + return False + + def _update_children(self, process: object) -> None: + if not hasattr(process, "children"): + self.children = [] + return + try: + children = process.children(recursive=True) + except psutil.Error: + self.children = [] + except Exception: # noqa: BLE001 + self.children = [] + else: + self.children = list(children) if children is not None else [] + + def _read_bytes(self, process: object, now: float) -> tuple[float, float] | None: + if not hasattr(process, "io_counters"): + return None + try: + counters = process.io_counters() + except ( + psutil.AccessDenied, + psutil.NoSuchProcess, + psutil.ZombieProcess, + ) as err: + self._disable_io(err) + return None + except Exception as err: # noqa: BLE001 + self._disable_io(err) + return None + + rb = getattr(counters, "read_bytes", None) + if rb is None: + return None + + if self.last_rb is None or self.last_t is None: + read_mb = 0.0 + read_mb_s = 0.0 + else: + read_mb = max(rb - self.last_rb, 0) / (1024 * 1024) + interval = max(now - self.last_t, 1e-6) + read_mb_s = read_mb / interval + self.last_rb, self.last_t = rb, now + return read_mb, read_mb_s + + def _disable_io(self, err: Exception) -> None: + self.can_read_io = False + self.last_rb = None + self.last_t = None + if not self.warned_io: + logger.warning("Disabling dataloader IO metrics: %s", err) + self.warned_io = True + + def _read_cpu_percent(self, process: object, *, is_parent: bool) -> float | None: + if not hasattr(process, "cpu_percent"): + return 0.0 + try: + value = process.cpu_percent(None) + except (psutil.NoSuchProcess, psutil.ZombieProcess): + if is_parent: self.can_read_cpu = False - except psutil.AccessDenied as err: + return None + except psutil.AccessDenied as err: + if is_parent: self.can_read_cpu = False if not self.warned_cpu: logger.warning("Disabling dataloader CPU metrics: %s", err) self.warned_cpu = True - else: - metrics["loader/cpu_util"] = cpu_util + return None + except Exception: + if is_parent: + self.can_read_cpu = False + return None - if not self.can_read_cpu: - self.warned_cpu = True + try: + return float(value) + except (TypeError, ValueError): + return 0.0 - return metrics + @staticmethod + def _get_manager_pid(dataloader: object) -> int: + pid = getattr(dataloader, "manager_pid", None) + if callable(pid): + try: + pid = pid() + except Exception: # noqa: BLE001 + return -1 + if pid is None: + return -1 + try: + return int(pid) + except (TypeError, ValueError): + return -1 + + @staticmethod + def _get_reservoir_fill(dataloader: object) -> float: + reservoir = getattr(dataloader, "reservoir", None) + if reservoir is None or not hasattr(reservoir, "fill"): + return 0.0 + try: + return float(reservoir.fill()) + except Exception: # noqa: BLE001 + return 0.0 diff --git a/tests/test_utils_monitoring.py b/tests/test_utils_monitoring.py index b492c2d9..8aee7ac8 100644 --- a/tests/test_utils_monitoring.py +++ b/tests/test_utils_monitoring.py @@ -3,89 +3,445 @@ import psutil -from saev.utils.monitoring import LoaderMonitor +from saev.utils.monitoring import DataloaderMonitor + + +class _StubReservoir: + def __init__(self, fill_value: float) -> None: + self._fill_value = fill_value + + def fill(self) -> float: + return self._fill_value + + def set_fill(self, value: float) -> None: + self._fill_value = value + + +class _StubLoader: + def __init__(self, reservoir: _StubReservoir, manager_pid: int = -1) -> None: + self.reservoir = reservoir + self._manager_pid = manager_pid + + @property + def manager_pid(self) -> int: + return self._manager_pid + + def set_manager_pid(self, pid: int) -> None: + self._manager_pid = pid class _StubProcess: def __init__( self, pid: int, - io_exc: Exception | None, - read_bytes: int, - cpu_percent_value: float, + *, + read_bytes: int = 0, + cpu_percent: float = 0.0, + io_exc: Exception | None = None, + cpu_exc: Exception | None = None, + running_exc: Exception | None = None, + running: bool = True, + children: list["_StubProcess"] | None = None, + children_exc: Exception | None = None, ) -> None: self.pid = pid - self._io_exc = io_exc self._read_bytes = read_bytes - self._cpu_percent_value = cpu_percent_value + self._cpu_percent = cpu_percent + self._io_exc = io_exc + self._cpu_exc = cpu_exc + self._running_exc = running_exc + self._running = running + self._children = children or [] + self._children_exc = children_exc def io_counters(self) -> types.SimpleNamespace: if self._io_exc is not None: raise self._io_exc return types.SimpleNamespace(read_bytes=self._read_bytes) - def cpu_percent(self, interval: float | None) -> float: - return self._cpu_percent_value - def set_read_bytes(self, value: int) -> None: self._read_bytes = value + def cpu_percent(self, interval: float | None) -> float: + if self._cpu_exc is not None: + raise self._cpu_exc + return self._cpu_percent -def test_loader_monitor_resets_when_pid_changes(): - now = time.time() - monitor = LoaderMonitor() + def set_cpu_percent(self, value: float) -> None: + self._cpu_percent = value - failing_proc = _StubProcess( - pid=123, - io_exc=psutil.AccessDenied(pid=123, name="stub"), - read_bytes=1024, - cpu_percent_value=5.0, - ) + def children(self, recursive: bool) -> list["_StubProcess"]: + if self._children_exc is not None: + raise self._children_exc + return self._children + + def set_children(self, children: list["_StubProcess"]) -> None: + self._children = children + + def is_running(self) -> bool: + if self._running_exc is not None: + raise self._running_exc + return self._running + + def set_running(self, value: bool) -> None: + self._running = value + + +def test_monitor_returns_buffer_when_manager_missing(): + loader = _StubLoader(_StubReservoir(0.4), manager_pid=-1) + monitor = DataloaderMonitor(loader) + metrics = monitor.compute() + assert metrics == {"loader/buffer_fill": 0.4} + + +def test_monitor_preserves_warnings_when_manager_missing(): + loader = _StubLoader(_StubReservoir(0.1), manager_pid=-1) + monitor = DataloaderMonitor(loader) + monitor.warned_cpu = True + monitor.warned_io = True + monitor.can_read_cpu = False + metrics = monitor.compute() + assert metrics == {"loader/buffer_fill": 0.1} + assert monitor.warned_cpu is True + assert monitor.warned_io is True + + +def test_monitor_tracks_io_and_cpu_across_steps(): + reservoir = _StubReservoir(0.5) + loader = _StubLoader(reservoir, manager_pid=123) + + child = _StubProcess(pid=124, cpu_percent=5.0) + parent = _StubProcess(pid=123, read_bytes=1024, cpu_percent=7.5, children=[child]) + processes = {123: parent} + + def _factory(pid: int) -> _StubProcess: + return processes[pid] + + monitor = DataloaderMonitor(loader, process_factory=_factory) + + metrics_first = monitor.compute(now=time.time()) + assert metrics_first["loader/buffer_fill"] == 0.5 + assert metrics_first["loader/read_mb"] == 0.0 + assert metrics_first["loader/read_mb_s"] == 0.0 + assert metrics_first["loader/cpu_util"] == 12.5 + + parent.set_read_bytes(3072) + parent.set_cpu_percent(10.0) + child.set_cpu_percent(6.0) + reservoir.set_fill(0.6) + + metrics_second = monitor.compute(now=time.time() + 2.0) + assert metrics_second["loader/buffer_fill"] == 0.6 + assert metrics_second["loader/read_mb"] > 0.0 + assert metrics_second["loader/read_mb_s"] > 0.0 + assert metrics_second["loader/cpu_util"] == 16.0 - metrics = LoaderMonitor.collect.__wrapped__( - monitor, - p_dataloader=failing_proc, - p_children=[], - reservoir_fill=0.5, - now=now, + +def test_monitor_disables_io_on_access_denied(): + loader = _StubLoader(_StubReservoir(0.0), manager_pid=5) + process = _StubProcess( + pid=5, io_exc=psutil.AccessDenied(pid=5, name="io"), cpu_percent=1.0 ) - assert metrics["loader/buffer_fill"] == 0.5 + def _factory(pid: int) -> _StubProcess: + return process + + monitor = DataloaderMonitor(loader, process_factory=_factory) + metrics = monitor.compute(now=time.time()) assert "loader/read_mb" not in metrics - assert "loader/read_mb_s" not in metrics assert monitor.can_read_io is False assert monitor.warned_io is True - healthy_proc = _StubProcess( - pid=456, - io_exc=None, - read_bytes=2048, - cpu_percent_value=7.5, + +def test_monitor_disables_cpu_on_access_denied(): + loader = _StubLoader(_StubReservoir(0.2), manager_pid=9) + process = _StubProcess( + pid=9, cpu_exc=psutil.AccessDenied(pid=9, name="cpu"), read_bytes=0 ) - metrics_after_restart = LoaderMonitor.collect.__wrapped__( - monitor, - p_dataloader=healthy_proc, - p_children=[], - reservoir_fill=0.25, - now=now + 1.0, + def _factory(pid: int) -> _StubProcess: + return process + + monitor = DataloaderMonitor(loader, process_factory=_factory) + metrics = monitor.compute(now=time.time()) + assert "loader/cpu_util" not in metrics + assert monitor.can_read_cpu is False + assert monitor.warned_cpu is True + + +def test_monitor_handles_children_errors(): + loader = _StubLoader(_StubReservoir(0.1), manager_pid=11) + process = _StubProcess( + pid=11, + cpu_percent=0.0, + children_exc=psutil.AccessDenied(pid=11, name="children"), ) - assert metrics_after_restart["loader/buffer_fill"] == 0.25 - assert metrics_after_restart["loader/read_mb"] == 0.0 - assert metrics_after_restart["loader/read_mb_s"] == 0.0 - assert monitor.can_read_io is True + def _factory(pid: int) -> _StubProcess: + return process + + monitor = DataloaderMonitor(loader, process_factory=_factory) + _ = monitor.compute(now=time.time()) + assert monitor.children == [] + + +def test_monitor_attach_resets_state(): + reservoir = _StubReservoir(0.3) + loader = _StubLoader(reservoir, manager_pid=13) + process = _StubProcess(pid=13, read_bytes=512, cpu_percent=2.0) + + def _factory(pid: int) -> _StubProcess: + return process + + monitor = DataloaderMonitor(loader, process_factory=_factory) + _ = monitor.compute(now=time.time()) + assert monitor.current_pid == 13 + assert monitor.process is process + + new_loader = _StubLoader(reservoir, manager_pid=14) + monitor.attach(new_loader) + assert monitor.current_pid is None + assert monitor.process is None + assert monitor.warned_cpu is False assert monitor.warned_io is False - healthy_proc.set_read_bytes(4096) - metrics_next = LoaderMonitor.collect.__wrapped__( - monitor, - p_dataloader=healthy_proc, - p_children=[], - reservoir_fill=0.3, - now=now + 2.0, - ) +def test_monitor_attach_noop_for_same_loader(): + loader = _StubLoader(_StubReservoir(0.2), manager_pid=1) + monitor = DataloaderMonitor(loader) + monitor.warned_cpu = True + monitor.attach(loader) + assert monitor.warned_cpu is True + + +def test_monitor_process_factory_failure(): + loader = _StubLoader(_StubReservoir(0.4), manager_pid=21) + + def _factory(pid: int) -> _StubProcess: + raise psutil.NoSuchProcess(pid=pid, name="missing") + + monitor = DataloaderMonitor(loader, process_factory=_factory) + metrics = monitor.compute(now=time.time()) + assert metrics == {"loader/buffer_fill": 0.4} + + +def test_monitor_manager_pid_callable_failure(): + class _Loader: + def __init__(self) -> None: + self.reservoir = _StubReservoir(0.1) + + def manager_pid(self) -> int: + raise RuntimeError("boom") + + loader = _Loader() + monitor = DataloaderMonitor(loader, process_factory=lambda _: None) # type: ignore[arg-type] + metrics = monitor.compute(now=time.time()) + assert metrics == {"loader/buffer_fill": 0.1} + + +def test_monitor_manager_pid_cast_failure(): + class _Loader: + def __init__(self) -> None: + self.reservoir = _StubReservoir(0.2) + self.manager_pid = "not-an-int" + + loader = _Loader() + monitor = DataloaderMonitor(loader, process_factory=lambda _: None) # type: ignore[arg-type] + metrics = monitor.compute(now=time.time()) + assert metrics == {"loader/buffer_fill": 0.2} + + +def test_monitor_reservoir_fill_exception(): + class _Reservoir: + def fill(self) -> float: + raise RuntimeError("fill failed") + + class _Loader: + def __init__(self) -> None: + self.reservoir = _Reservoir() + self.manager_pid = -1 + + loader = _Loader() + monitor = DataloaderMonitor(loader, process_factory=lambda _: None) # type: ignore[arg-type] + metrics = monitor.compute(now=time.time()) + assert metrics == {"loader/buffer_fill": 0.0} + + +def test_monitor_is_running_exception() -> None: + process = _StubProcess(pid=1, running_exc=RuntimeError("boom")) + assert DataloaderMonitor._is_running.__wrapped__(process) is False + + +def test_monitor_is_running_missing_method() -> None: + class _Bare: + pass + + assert DataloaderMonitor._is_running.__wrapped__(_Bare()) is True + + +def test_monitor_is_running_nosuchprocess() -> None: + process = _StubProcess(pid=2, running_exc=psutil.NoSuchProcess(pid=2, name="p")) + assert DataloaderMonitor._is_running.__wrapped__(process) is False + + +def test_monitor_read_bytes_missing_method() -> None: + loader = _StubLoader(_StubReservoir(0.1), manager_pid=5) + + class _Process: + pid = 5 + + monitor = DataloaderMonitor(loader, process_factory=lambda _: _Process()) # type: ignore[arg-type] + metrics = monitor.compute(now=time.time()) + assert "loader/read_mb" not in metrics + + +def test_monitor_read_cpu_percent_missing_method() -> None: + loader = _StubLoader(_StubReservoir(0.1), manager_pid=7) + + class _Process: + pid = 7 + + monitor = DataloaderMonitor(loader, process_factory=lambda _: _Process()) # type: ignore[arg-type] + metrics = monitor.compute(now=time.time()) + assert metrics["loader/cpu_util"] == 0.0 + + +def test_monitor_read_bytes_generic_exception() -> None: + loader = _StubLoader(_StubReservoir(0.1), manager_pid=17) + + class _Process: + pid = 17 + + def io_counters(self): + raise RuntimeError("io boom") + + monitor = DataloaderMonitor(loader, process_factory=lambda _: _Process()) # type: ignore[arg-type] + metrics = monitor.compute(now=time.time()) + assert "loader/read_mb" not in metrics + assert monitor.can_read_io is False + + +def test_monitor_read_bytes_missing_read_bytes_field() -> None: + monitor = DataloaderMonitor(_StubLoader(_StubReservoir(0.1), manager_pid=23)) + + class _Process: + pid = 23 + + def io_counters(self): + return types.SimpleNamespace() + + result = DataloaderMonitor._read_bytes.__wrapped__(monitor, _Process(), time.time()) + assert result is None + + +def test_monitor_update_children_generic_exception() -> None: + loader = _StubLoader(_StubReservoir(0.2), manager_pid=41) + + class _BoomProcess: + pid = 41 + + def children(self, recursive: bool) -> list[object]: + raise RuntimeError("boom") + + def io_counters(self): + return types.SimpleNamespace(read_bytes=0) + + def cpu_percent(self, interval: float | None) -> float: + return 0.0 + + monitor = DataloaderMonitor(loader, process_factory=lambda _: _BoomProcess()) # type: ignore[arg-type] + metrics = monitor.compute(now=time.time()) + assert "loader/cpu_util" in metrics + assert monitor.children == [] + + +def test_monitor_read_cpu_percent_generic_exception() -> None: + loader = _StubLoader(_StubReservoir(0.2), manager_pid=19) + + class _Process: + pid = 19 + + def cpu_percent(self, interval: float | None) -> float: + raise RuntimeError("cpu boom") + + def io_counters(self): + return types.SimpleNamespace(read_bytes=0) + + monitor = DataloaderMonitor(loader, process_factory=lambda _: _Process()) # type: ignore[arg-type] + metrics = monitor.compute(now=time.time()) + assert "loader/cpu_util" not in metrics + assert monitor.can_read_cpu is False + + +def test_monitor_read_cpu_percent_non_numeric() -> None: + loader = _StubLoader(_StubReservoir(0.2), manager_pid=43) + + class _Process: + pid = 43 + + def cpu_percent(self, interval: float | None): + return "oops" + + def io_counters(self): + return types.SimpleNamespace(read_bytes=0) + + monitor = DataloaderMonitor(loader, process_factory=lambda _: _Process()) # type: ignore[arg-type] + metrics = monitor.compute(now=time.time()) + assert metrics["loader/cpu_util"] == 0.0 + + +def test_monitor_read_cpu_percent_nosuchprocess() -> None: + loader = _StubLoader(_StubReservoir(0.2), manager_pid=29) + + class _Process: + pid = 29 + + def cpu_percent(self, interval: float | None) -> float: + raise psutil.NoSuchProcess(pid=29, name="parent") + + def io_counters(self): + return types.SimpleNamespace(read_bytes=0) + + monitor = DataloaderMonitor(loader, process_factory=lambda _: _Process()) # type: ignore[arg-type] + metrics = monitor.compute(now=time.time()) + assert "loader/cpu_util" not in metrics + assert monitor.can_read_cpu is False + + +def test_monitor_get_reservoir_fill_missing_reservoir(): + class _Loader: + manager_pid = -1 + reservoir = None + + loader = _Loader() + monitor = DataloaderMonitor(loader, process_factory=lambda _: None) # type: ignore[arg-type] + metrics = monitor.compute() + assert metrics == {"loader/buffer_fill": 0.0} + + +def test_monitor_get_manager_pid_none(): + class _Loader: + def __init__(self) -> None: + self.reservoir = _StubReservoir(0.1) + self.manager_pid = None + + loader = _Loader() + monitor = DataloaderMonitor(loader, process_factory=lambda _: None) # type: ignore[arg-type] + metrics = monitor.compute() + assert metrics == {"loader/buffer_fill": 0.1} + + +def test_monitor_can_read_cpu_false_branch(): + reservoir = _StubReservoir(0.3) + loader = _StubLoader(reservoir, manager_pid=31) + process = _StubProcess(pid=31, read_bytes=0, cpu_percent=0.0) + + def _factory(pid: int) -> _StubProcess: + return process - assert metrics_next["loader/read_mb"] > 0.0 - assert metrics_next["loader/read_mb_s"] > 0.0 + monitor = DataloaderMonitor(loader, process_factory=_factory) + _ = monitor.compute(now=time.time()) + monitor.can_read_cpu = False + monitor.warned_cpu = False + _ = monitor.compute(now=time.time()) + assert monitor.warned_cpu is True diff --git a/train.py b/train.py index f1d0a273..4d1c9673 100644 --- a/train.py +++ b/train.py @@ -26,7 +26,6 @@ import beartype import einops import orjson -import psutil import torch import tyro import wandb @@ -38,6 +37,7 @@ import saev.utils.wandb from saev import configs, disk, helpers, nn from saev.utils import statistics +from saev.utils.monitoring import DataloaderMonitor logger = logging.getLogger("train.py") @@ -151,16 +151,10 @@ def worker_fn(cfgs: list[Config]) -> list[str]: metric.n_almost_dead / sae.cfg.d_sae * 100, ) - # Load metadata to get dataset paths - train_md = saev.data.Metadata.load(cfg.train_data.shards) - val_md = saev.data.Metadata.load(cfg.val_data.shards) - run = disk.Run.new( id, train_shards_dir=cfg.train_data.shards, - train_dataset=train_md.dataset, val_shards_dir=cfg.val_data.shards, - val_dataset=val_md.dataset, runs_root=cfg.runs_root, ) nn.dump(run.ckpt, sae) @@ -232,12 +226,9 @@ def train( objectives = objectives.to(cfg.device) global_step, n_patches_seen = 0, 0 - - p_dataloader, p_children, last_rb, last_t = None, [], 0, time.time() + dl_monitor = DataloaderMonitor(dataloader) for batch in helpers.progress(dataloader, every=cfg.log_every): - p_dataloader, p_children = get_p_dl(p_dataloader, dataloader.manager_pid) - acts_BD = batch["act"].to(cfg.device, non_blocking=True) for sae in saes: sae.normalize_w_dec() @@ -271,22 +262,7 @@ def train( if (global_step + 1) % cfg.log_every == 0: with torch.no_grad(): now = time.time() - # Dataloader stuff - loader_metrics = {} - if p_dataloader is not None: - rb = p_dataloader.io_counters().read_bytes - read_mb = (rb - last_rb) / (1024 * 1024) - read_mb_s = read_mb / (now - last_t) - cpu_util = sum( - t.cpu_percent(None) for t in p_children - ) + p_dataloader.cpu_percent(None) - last_rb, last_t = rb, now - loader_metrics = { - "loader/read_mb": read_mb, - "loader/read_mb_s": read_mb_s, - "loader/cpu_util": cpu_util, - "loader/buffer_fill": dataloader.reservoir.fill(), - } + dl_metrics = dl_monitor.compute(now=now) metadata = dataloader.metadata entropy_metrics = statistics.calc_batch_entropy( @@ -295,7 +271,7 @@ def train( metadata.n_examples, metadata.content_tokens_per_example, ) - loader_metrics.update(entropy_metrics) + dl_metrics.update(entropy_metrics) metrics = [] for i, (loss, sae, objective, group) in enumerate( @@ -332,7 +308,7 @@ def train( "metrics/dictionary_coherence": coherence.item(), "metrics/avg_decoder_row_norm": avg_w_row_norm.item(), "metrics/grad_norm": grad_norms[i].item(), - **loader_metrics, + **dl_metrics, } metrics.append(metric) @@ -362,23 +338,6 @@ def train( return saes, objectives, run, global_step -@beartype.beartype -def get_p_dl( - p_dataloader: psutil.Process | None, manager_pid: int -) -> tuple[psutil.Process | None, list[psutil.Process]]: - needs_updating = ( - p_dataloader is None - or not p_dataloader.is_running() - or p_dataloader.pid != manager_pid - ) - if psutil.pid_exists(manager_pid) and needs_updating: - p_dataloader = psutil.Process(manager_pid) - p_children = p_dataloader.children(recursive=True) - return p_dataloader, p_children - else: - return None, [] - - # TODO: I think this needs to be jaxtyped, but jaxtyped in a submitit context can cause real issues. @beartype.beartype @dataclasses.dataclass(frozen=True)