Skip to content

Commit 5a16435

Browse files
Donglai Wei and claude committed
Fix multiprocessing: pass args as tuple instead of module globals
Module-level globals are not inherited by spawn-context child processes. Pass workflow_root, idle_timeout, max_tasks as a tuple argument instead. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 66bedda commit 5a16435

1 file changed

Lines changed: 37 additions & 12 deletions

File tree

scripts/decode_large.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,21 @@
2323

2424
import yaml
2525

26+
def _worker_fn(args_tuple):
27+
"""Worker function for parallel decode (takes all args as tuple for spawn compatibility)."""
28+
worker_idx, workflow_root, idle_timeout, max_tasks = args_tuple
29+
os.environ["CCACHE_DISABLE"] = "1"
30+
os.environ["OMP_NUM_THREADS"] = "1"
31+
os.environ["OPENBLAS_NUM_THREADS"] = "1"
32+
os.environ["MKL_NUM_THREADS"] = "1"
33+
from waterz import LargeDecodeRunner as _LDR
34+
w = _LDR.load(workflow_root)
35+
return w.run_worker(
36+
worker_id=f"local-{worker_idx}",
37+
idle_timeout=idle_timeout,
38+
max_tasks=max_tasks,
39+
)
40+
2641

2742
def main():
2843
parser = argparse.ArgumentParser(description="Large-volume waterz decoding")
@@ -37,6 +52,10 @@ def main():
3752
parser.add_argument("--idle-timeout", type=float, default=60.0, help="Worker idle timeout (seconds)")
3853
parser.add_argument("--worker-id", type=str, default=None, help="Worker identifier")
3954
parser.add_argument("--job-id", type=str, default=None, help="SLURM job ID")
55+
parser.add_argument("--stale-timeout", type=float, default=600,
56+
help="Reset RUNNING tasks older than this many seconds (default: 600)")
57+
parser.add_argument("--no-reset-stale", action="store_true",
58+
help="Skip resetting stale RUNNING tasks")
4059
parser.add_argument("overrides", nargs="*", help="Config overrides (key=value)")
4160
args = parser.parse_args()
4261

@@ -76,6 +95,13 @@ def main():
7695
runner = LargeDecodeRunner(config)
7796
runner.initialize()
7897

98+
# Reset stale/failed tasks so re-runs recover from crashed workers
99+
if not args.no_reset_stale:
100+
n_stale = runner.orchestrator.reset_stale_tasks(max_age_seconds=args.stale_timeout)
101+
runner.orchestrator.reset_failed_tasks()
102+
if n_stale:
103+
print(f"Reset {n_stale} stale RUNNING tasks (older than {args.stale_timeout}s).")
104+
79105
chunks = runner.chunks
80106
borders = runner.borders
81107
print(f"Volume shape: {config.volume_shape}")
@@ -115,21 +141,20 @@ def main():
115141
if args.parallel and args.parallel > 1:
116142
import multiprocessing as mp
117143

144+
workflow_root = large_cfg["workflow_root"]
145+
idle_timeout = args.idle_timeout or 120
146+
max_tasks = args.max_tasks
147+
118148
n_workers = args.parallel
119149
print(f"Running parallel decode with {n_workers} workers...")
120150

121-
def _worker_fn(worker_idx):
122-
os.environ["CCACHE_DISABLE"] = "1"
123-
from waterz import LargeDecodeRunner as _LDR
124-
w = _LDR.load(large_cfg["workflow_root"])
125-
return w.run_worker(
126-
worker_id=f"local-{worker_idx}",
127-
idle_timeout=args.idle_timeout or 120,
128-
max_tasks=args.max_tasks,
129-
)
130-
131-
with mp.Pool(n_workers) as pool:
132-
counts = pool.map(_worker_fn, range(n_workers))
151+
worker_args = [
152+
(i, workflow_root, idle_timeout, max_tasks)
153+
for i in range(n_workers)
154+
]
155+
ctx = mp.get_context("spawn")
156+
with ctx.Pool(n_workers) as pool:
157+
counts = pool.map(_worker_fn, worker_args)
133158
n = sum(counts)
134159
print(f"Completed {n} tasks across {n_workers} workers.")
135160
else:

0 commit comments

Comments (0)