@@ -138,7 +138,7 @@ def main():
138138 chunk = chunks [idx ]
139139 print (f"Decoding chunk { idx } /{ len (chunks )} : { chunk .key } " )
140140 from waterz .orchestrator import TaskRecord , TaskSpec
141- record = TaskRecord (spec = TaskSpec (stage = "decode" , key = chunk .key ))
141+ record = TaskRecord (spec = TaskSpec (name = f"decode_ { chunk . key } " , stage = "decode" , key = chunk .key ))
142142 result = runner .handle_decode_chunk (record )
143143 print (f" Done: { result } " )
144144 return
@@ -157,9 +157,41 @@ def main():
157157 return
158158
159159 if args .wait :
160+ import time as _time
161+
160162 print ("Waiting for all tasks to complete..." )
161- runner .wait (timeout = None )
163+ last_print = 0
164+ while True :
165+ counts = runner .orchestrator .stage_counts ()
166+ now = _time .monotonic ()
167+ if now - last_print >= 10 :
168+ parts = []
169+ for stage , sc in sorted (counts .items ()):
170+ done = sc .get ("succeeded" , 0 )
171+ total = sum (sc .values ())
172+ running = sc .get ("running" , 0 )
173+ status = f"{ stage } : { done } /{ total } "
174+ if running :
175+ status += f" ({ running } running)"
176+ parts .append (status )
177+ print (f" Progress: { ' | ' .join (parts )} " , flush = True )
178+ last_print = now
179+
180+ all_terminal = all (
181+ all (k in ("succeeded" , "failed" ) for k in sc )
182+ for sc in counts .values ()
183+ )
184+ if counts and all_terminal :
185+ break
186+ _time .sleep (1 )
187+
162188 print ("All tasks completed." )
189+ # Check for failures
190+ for stage , sc in sorted (counts .items ()):
191+ failed = sc .get ("failed" , 0 )
192+ if failed :
193+ print (f" WARNING: { stage } has { failed } failed tasks" )
194+
163195 if args .assemble and config .write_output :
164196 print ("Assembling output..." )
165197 runner .handle_assemble_output (None )
@@ -186,8 +218,32 @@ def main():
186218 n = sum (counts )
187219 print (f"Completed { n } tasks across { n_workers } workers." )
188220 else :
189- print ("Running serial decode..." )
190- n = runner .run_serial ()
221+ import threading
222+
223+ total_tasks = len (runner .orchestrator .list_records ())
224+ stop_progress = threading .Event ()
225+
226+ def _progress_loop ():
227+ while not stop_progress .wait (10 ):
228+ counts = runner .orchestrator .stage_counts ()
229+ parts = []
230+ for stage , sc in sorted (counts .items ()):
231+ done = sc .get ("succeeded" , 0 )
232+ total = sum (sc .values ())
233+ running = sc .get ("running" , 0 )
234+ status = f"{ stage } : { done } /{ total } "
235+ if running :
236+ status += f" ({ running } running)"
237+ parts .append (status )
238+ print (f" Progress: { ' | ' .join (parts )} " , flush = True )
239+
240+ t = threading .Thread (target = _progress_loop , daemon = True )
241+ t .start ()
242+ try :
243+ n = runner .run_serial ()
244+ finally :
245+ stop_progress .set ()
246+ t .join ()
191247 print (f"Completed { n } tasks." )
192248
193249 status = runner .orchestrator .stage_counts ()
0 commit comments