diff --git a/csub.py b/csub.py index db3625d..ac1f668 100644 --- a/csub.py +++ b/csub.py @@ -56,7 +56,7 @@ def build_parser() -> argparse.ArgumentParser: parser.add_argument("--secret-name", type=str, help="Override RUNAI_SECRET_NAME from the env file") parser.add_argument("--pvc", type=str, help="Override SCRATCH_PVC from the env file") parser.add_argument("--backofflimit", type=int, default=0, help="Retries before marking a training job as failed") - parser.add_argument("--node-type", type=str, choices=["", "v100", "h100", "h200", "default", "a100-40g"], default="", help="GPU node pool to target") + parser.add_argument("--node-type", nargs="*", type=str, choices=["v100", "h100", "h200", "default", "a100-40g"], default=[], help="GPU node pool(s) to target. Multiple values are accepted; the job will be scheduled on the first pool where it fits") parser.add_argument("--host-ipc", action="store_true", help="Share the host IPC namespace") parser.add_argument("--large-shm", action="store_true", help="Request a larger /dev/shm") return parser @@ -172,8 +172,8 @@ def build_runai_command( cmd.append("--large-shm") if args.node_type: - cmd.extend(["--node-pools", args.node_type]) - if args.node_type in {"h200", "h100"} and not args.train: + cmd.extend(["--node-pools", ','.join(args.node_type)]) + if any(i in {"h200", "h100"} for i in args.node_type) and not args.train: cmd.append("--preemptible") if distributed: