Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions csub.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def build_parser() -> argparse.ArgumentParser:
parser.add_argument("--secret-name", type=str, help="Override RUNAI_SECRET_NAME from the env file")
parser.add_argument("--pvc", type=str, help="Override SCRATCH_PVC from the env file")
parser.add_argument("--backofflimit", type=int, default=0, help="Retries before marking a training job as failed")
parser.add_argument("--node-type", type=str, choices=["", "v100", "h100", "h200", "default", "a100-40g"], default="", help="GPU node pool to target")
parser.add_argument("--node-type", nargs="*", type=str, choices=["v100", "h100", "h200", "default", "a100-40g"], default=[], help="GPU node pool(s) to target. Multiple values are accepted; the job will be scheduled on the first pool where it fits")
parser.add_argument("--host-ipc", action="store_true", help="Share the host IPC namespace")
parser.add_argument("--large-shm", action="store_true", help="Request a larger /dev/shm")
return parser
Expand Down Expand Up @@ -172,8 +172,8 @@ def build_runai_command(
cmd.append("--large-shm")

if args.node_type:
cmd.extend(["--node-pools", args.node_type])
if args.node_type in {"h200", "h100"} and not args.train:
cmd.extend(["--node-pools", ','.join(args.node_type)])
if any(i in {"h200", "h100"} for i in args.node_type) and not args.train:
cmd.append("--preemptible")

if distributed:
Expand Down
Loading