2323
import yaml
2525
def _worker_fn(args_tuple):
    """Entry point for one local decode worker process.

    All arguments arrive packed in a single tuple so the function can be
    passed to ``Pool.map`` under the "spawn" start method (which pickles
    the callable and its single argument).

    Args:
        args_tuple: ``(worker_idx, workflow_root, idle_timeout, max_tasks)``.

    Returns:
        The task count reported by ``run_worker`` for this worker.
    """
    worker_idx, workflow_root, idle_timeout, max_tasks = args_tuple

    # Disable ccache and pin the math libraries to a single thread each,
    # so N worker processes don't oversubscribe the machine's cores.
    os.environ["CCACHE_DISABLE"] = "1"
    for thread_var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
        os.environ[thread_var] = "1"

    # Import inside the worker: under "spawn" each child starts a fresh
    # interpreter, and the env vars above must be set before the import.
    from waterz import LargeDecodeRunner as _LDR

    runner = _LDR.load(workflow_root)
    return runner.run_worker(
        worker_id=f"local-{worker_idx}",
        idle_timeout=idle_timeout,
        max_tasks=max_tasks,
    )
40+
2641
2742def main ():
2843 parser = argparse .ArgumentParser (description = "Large-volume waterz decoding" )
@@ -37,6 +52,10 @@ def main():
3752 parser .add_argument ("--idle-timeout" , type = float , default = 60.0 , help = "Worker idle timeout (seconds)" )
3853 parser .add_argument ("--worker-id" , type = str , default = None , help = "Worker identifier" )
3954 parser .add_argument ("--job-id" , type = str , default = None , help = "SLURM job ID" )
55+ parser .add_argument ("--stale-timeout" , type = float , default = 600 ,
56+ help = "Reset RUNNING tasks older than this many seconds (default: 600)" )
57+ parser .add_argument ("--no-reset-stale" , action = "store_true" ,
58+ help = "Skip resetting stale RUNNING tasks" )
4059 parser .add_argument ("overrides" , nargs = "*" , help = "Config overrides (key=value)" )
4160 args = parser .parse_args ()
4261
@@ -76,6 +95,13 @@ def main():
7695 runner = LargeDecodeRunner (config )
7796 runner .initialize ()
7897
98+ # Reset stale/failed tasks so re-runs recover from crashed workers
99+ if not args .no_reset_stale :
100+ n_stale = runner .orchestrator .reset_stale_tasks (max_age_seconds = args .stale_timeout )
101+ runner .orchestrator .reset_failed_tasks ()
102+ if n_stale :
103+ print (f"Reset { n_stale } stale RUNNING tasks (older than { args .stale_timeout } s)." )
104+
79105 chunks = runner .chunks
80106 borders = runner .borders
81107 print (f"Volume shape: { config .volume_shape } " )
@@ -115,21 +141,20 @@ def main():
115141 if args .parallel and args .parallel > 1 :
116142 import multiprocessing as mp
117143
144+ workflow_root = large_cfg ["workflow_root" ]
145+ idle_timeout = args .idle_timeout or 120
146+ max_tasks = args .max_tasks
147+
118148 n_workers = args .parallel
119149 print (f"Running parallel decode with { n_workers } workers..." )
120150
121- def _worker_fn (worker_idx ):
122- os .environ ["CCACHE_DISABLE" ] = "1"
123- from waterz import LargeDecodeRunner as _LDR
124- w = _LDR .load (large_cfg ["workflow_root" ])
125- return w .run_worker (
126- worker_id = f"local-{ worker_idx } " ,
127- idle_timeout = args .idle_timeout or 120 ,
128- max_tasks = args .max_tasks ,
129- )
130-
131- with mp .Pool (n_workers ) as pool :
132- counts = pool .map (_worker_fn , range (n_workers ))
151+ worker_args = [
152+ (i , workflow_root , idle_timeout , max_tasks )
153+ for i in range (n_workers )
154+ ]
155+ ctx = mp .get_context ("spawn" )
156+ with ctx .Pool (n_workers ) as pool :
157+ counts = pool .map (_worker_fn , worker_args )
133158 n = sum (counts )
134159 print (f"Completed { n } tasks across { n_workers } workers." )
135160 else :
0 commit comments