"""Large-volume parallel waterz decoding using a file-backed orchestrator.

Usage:
    # Serial (single process, all stages)
    python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml

    # Initialize the workflow only (for a parallel launch)
    python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml --init-only

    # Run as a worker (claims tasks from the shared workflow directory)
    python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml --worker

    # Wait for all workers to finish, then assemble the output
    python scripts/decode_large.py --config tutorials/waterz_decoding_large.yaml --wait --assemble
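
Example config (paths and values are illustrative; the keys are exactly the
ones this script reads from the "large_decode" section):

    large_decode:
      affinity_path: /path/to/affinities.h5   # required
      workflow_root: /path/to/workflow        # required
      chunk_shape: [256, 512, 512]
      thresholds: [0.5]
      merge_function: aff85_his256
      channel_order: xyz
      output_path: /path/to/segmentation.h5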
"""

import argparse
import os
import sys

import yaml


def main():
    parser = argparse.ArgumentParser(description="Large-volume waterz decoding")
    parser.add_argument("--config", required=True, help="YAML config file")
    parser.add_argument("--init-only", action="store_true", help="Initialize workflow and exit")
    parser.add_argument("--worker", action="store_true", help="Run as a worker (claim tasks)")
    parser.add_argument("--wait", action="store_true", help="Wait for all tasks to complete")
    parser.add_argument("--assemble", action="store_true", help="Assemble final output volume")
    parser.add_argument("--max-tasks", type=int, default=None, help="Max tasks per worker")
    parser.add_argument("--idle-timeout", type=float, default=60.0, help="Worker idle timeout (seconds)")
    parser.add_argument("--worker-id", type=str, default=None, help="Worker identifier")
    parser.add_argument("--job-id", type=str, default=None, help="SLURM job ID")
    # Allow CLI overrides in key=value format
    parser.add_argument("overrides", nargs="*", help="Config overrides (key=value)")
    args = parser.parse_args()

    # Load config
    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    large_cfg = cfg.get("large_decode", {})

    # Apply CLI overrides
    for override in args.overrides:
        if "=" not in override:
            print(f"Warning: skipping invalid override '{override}' (expected key=value)")
            continue
        key, value = override.split("=", 1)
        # Try numeric conversion
        try:
            value = int(value)
        except ValueError:
            try:
                value = float(value)
            except ValueError:
                pass
        large_cfg[key] = value
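
    # Override example (values illustrative; any large_decode key works):
    #   python scripts/decode_large.py --config cfg.yaml \
    #       merge_function=aff85_his256 compression_level=6 aff_threshold_low=0.05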

    if not large_cfg.get("affinity_path"):
        print("Error: large_decode.affinity_path is required")
        sys.exit(1)
    if not large_cfg.get("workflow_root"):
        print("Error: large_decode.workflow_root is required")
        sys.exit(1)

    # Must be set before waterz is imported below
    os.environ.setdefault("CCACHE_DISABLE", "1")

    from waterz import LargeDecodeRunner

    # Parse config
    chunk_shape = large_cfg.get("chunk_shape", [256, 512, 512])
    if isinstance(chunk_shape, list):
        chunk_shape = tuple(chunk_shape)

    thresholds = large_cfg.get("thresholds", [0.5])
    if isinstance(thresholds, (int, float)):
        thresholds = [thresholds]
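    # Assumption: as in core waterz, each threshold yields its own
    # segmentation, so a list here requests one output per threshold.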
    runner = LargeDecodeRunner.create(
        affinity_path=large_cfg["affinity_path"],
        workflow_root=large_cfg["workflow_root"],
        chunk_shape=chunk_shape,
        thresholds=thresholds,
        merge_function=large_cfg.get("merge_function", "aff85_his256"),
        aff_threshold_low=float(large_cfg.get("aff_threshold_low", 0.1)),
        aff_threshold_high=float(large_cfg.get("aff_threshold_high", 0.999)),
        channel_order=large_cfg.get("channel_order", "xyz"),
        write_output=bool(large_cfg.get("write_output", True)),
        output_path=large_cfg.get("output_path") or None,
        border_min_overlap=int(large_cfg.get("border_min_overlap", 1)),
        border_one_sided_threshold=float(large_cfg.get("border_one_sided_threshold", 0.9)),
        border_iou_threshold=float(large_cfg.get("border_iou_threshold", 0.0)),
        border_affinity_threshold=float(large_cfg.get("border_affinity_threshold", 0.0)),
        compression=large_cfg.get("compression", "gzip"),
        compression_level=int(large_cfg.get("compression_level", 4)),
    )

    chunks = runner.chunks
    borders = runner.borders
    print(f"Volume shape: {runner.config.volume_shape}")
    print(f"Chunk shape: {runner.config.chunk_shape}")
    print(f"Chunks: {len(chunks)}")
    print(f"Borders: {len(borders)}")
    print(f"Workflow: {runner.config.workflow_root}")

    if args.init_only:
        print("Workflow initialized. Launch workers to execute tasks.")
        return

    if args.worker:
        # In a SLURM job array each array task gets its own SLURM_JOB_ID, so it
        # doubles as a unique worker identifier when --worker-id is not given.
        worker_id = args.worker_id or os.environ.get("SLURM_JOB_ID", None)
        job_id = args.job_id or os.environ.get("SLURM_ARRAY_TASK_ID", None)
        print(f"Starting worker: {worker_id or 'auto'} (job={job_id or 'none'})")
        n = runner.run_worker(
            worker_id=worker_id,
            max_tasks=args.max_tasks,
            idle_timeout=args.idle_timeout,
            job_id=job_id,
        )
        print(f"Worker completed {n} tasks.")
        return

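    # Typical parallel recipe (sbatch usage illustrative, not part of this repo):
    #   python scripts/decode_large.py --config cfg.yaml --init-only
    #   sbatch --array=0-15 --wrap "python scripts/decode_large.py --config cfg.yaml --worker"
    #   python scripts/decode_large.py --config cfg.yaml --wait --assemble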
    if args.wait:
        print("Waiting for all tasks to complete...")
        runner.wait(timeout=None)
        print("All tasks completed.")
        # Note: --assemble is only honored together with --wait
        if args.assemble and runner.config.write_output:
            print("Assembling output...")
            runner.handle_assemble_output(None)
            print(f"Output: {runner.config.resolved_output_path}")
        return

    # Default: run all stages serially in one process
    print("Running serial decode...")
    n = runner.run_serial()
    print(f"Completed {n} tasks.")

    status = runner.orchestrator.stage_counts()
    for stage, counts in sorted(status.items()):
        print(f"  {stage}: {counts}")


if __name__ == "__main__":
    main()