|
| 1 | +# Copyright 2026 The Orbax Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Launches and bootstraps tests across multiple simulated JAX processes.""" |
| 16 | + |
| 17 | +import argparse |
| 18 | +import contextlib |
| 19 | +import os |
| 20 | +import runpy |
| 21 | +import socket |
| 22 | +import subprocess |
| 23 | +import sys |
| 24 | + |
| 25 | +from absl import logging |
| 26 | +import jax |
| 27 | +import pytest |
| 28 | + |
| 29 | + |
| 30 | +def find_free_port(): |
| 31 | + with contextlib.closing( |
| 32 | + socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| 33 | + ) as s: |
| 34 | + s.bind(("", 0)) |
| 35 | + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| 36 | + return s.getsockname()[1] |
| 37 | + |
| 38 | + |
| 39 | +def run_worker_and_command(command): |
| 40 | + """Worker Mode: Initializes JAX explicitly, then executes the target command.""" |
| 41 | + |
| 42 | + coordinator_address = os.environ.get("JAX_COORDINATOR_ADDRESS") |
| 43 | + num_processes = os.environ.get("JAX_NUM_PROCESSES") |
| 44 | + process_id = os.environ.get("JAX_PROCESS_ID") |
| 45 | + |
| 46 | + if coordinator_address is None: |
| 47 | + raise ValueError( |
| 48 | + "Environment variables for JAX distributed not found. " |
| 49 | + "Did you use launch_multihost.py?" |
| 50 | + ) |
| 51 | + |
| 52 | + # Explicit Initialization |
| 53 | + jax.distributed.initialize( |
| 54 | + coordinator_address=coordinator_address, |
| 55 | + num_processes=int(num_processes), |
| 56 | + process_id=int(process_id), |
| 57 | + ) |
| 58 | + |
| 59 | + print(f"[Rank {process_id}] JAX Initialized. Executing: {' '.join(command)}") |
| 60 | + print(f"[Rank {process_id}] JAX devices: {jax.devices()}") |
| 61 | + |
| 62 | + # Clean up 'python' from the command if the user accidentally included it |
| 63 | + if command[0] == "python" or command[0] == "python3": |
| 64 | + command = command[1:] |
| 65 | + |
| 66 | + cmd_name = command[0] |
| 67 | + |
| 68 | + # Execute the requested script/tool inside this initialized process |
| 69 | + if cmd_name == "pytest": |
| 70 | + sys.exit(pytest.main(command[1:])) |
| 71 | + |
| 72 | + elif cmd_name.endswith(".py"): |
| 73 | + # Overwrite sys.argv so the target script sees its expected arguments |
| 74 | + sys.argv = command |
| 75 | + runpy.run_path(cmd_name, run_name="__main__") |
| 76 | + |
| 77 | + else: |
| 78 | + # Fallback for arbitrary shell commands |
| 79 | + sys.exit(subprocess.call(command)) |
| 80 | + |
| 81 | + |
| 82 | +def main(): |
| 83 | + # 1. Parse arguments meant for launch.py |
| 84 | + parser = argparse.ArgumentParser(description="JAX Multihost Launcher") |
| 85 | + parser.add_argument( |
| 86 | + "--worker_mode", action="store_true", help=argparse.SUPPRESS |
| 87 | + ) |
| 88 | + parser.add_argument( |
| 89 | + "--num_processes", type=int, default=2, help="Number of simulated hosts" |
| 90 | + ) |
| 91 | + parser.add_argument( |
| 92 | + "--tpu_chips_per_process", type=int, default=4, help="TPU chips per host" |
| 93 | + ) |
| 94 | + |
| 95 | + # `args` gets the launcher configs, `command` gets everything else |
| 96 | + args, command = parser.parse_known_args() |
| 97 | + |
| 98 | + # 2. WORKER MODE |
| 99 | + if args.worker_mode: |
| 100 | + if not command: |
| 101 | + raise ValueError("No command provided for the worker to execute.") |
| 102 | + run_worker_and_command(command) |
| 103 | + return |
| 104 | + |
| 105 | + # 3. LAUNCHER MODE |
| 106 | + if not command: |
| 107 | + logging.error( |
| 108 | + "Usage: python %s [LAUNCH_ARGS] <script.py> [SCRIPT_ARGS]", |
| 109 | + os.path.basename(__file__), |
| 110 | + ) |
| 111 | + sys.exit(1) |
| 112 | + |
| 113 | + coordinator_port = find_free_port() |
| 114 | + coordinator_address = f"localhost:{coordinator_port}" |
| 115 | + |
| 116 | + slicebuilder_ports = [find_free_port() for _ in range(args.num_processes)] |
| 117 | + slicebuilder_addresses = ",".join( |
| 118 | + f"localhost:{port}" for port in slicebuilder_ports |
| 119 | + ) |
| 120 | + |
| 121 | + logging.info( |
| 122 | + "🚀 Starting %s JAX processes (%s chips/process)...", |
| 123 | + args.num_processes, |
| 124 | + args.tpu_chips_per_process, |
| 125 | + ) |
| 126 | + logging.info("📍 Coordinator: %s", coordinator_address) |
| 127 | + |
| 128 | + tpu_chips_per_process = args.tpu_chips_per_process |
| 129 | + num_tpu_chips = args.num_processes * args.tpu_chips_per_process |
| 130 | + if num_tpu_chips == 0: |
| 131 | + tpu_host_bounds = "" |
| 132 | + tpu_chips_per_host_bounds = "" |
| 133 | + elif num_tpu_chips == 1: |
| 134 | + assert tpu_chips_per_process == 1 |
| 135 | + tpu_host_bounds = "1,1,1" |
| 136 | + tpu_chips_per_host_bounds = "1,1,1" |
| 137 | + elif num_tpu_chips == 4: |
| 138 | + if tpu_chips_per_process == 1: |
| 139 | + tpu_host_bounds = "2,2,1" |
| 140 | + tpu_chips_per_host_bounds = "1,1,1" |
| 141 | + elif tpu_chips_per_process == 2: |
| 142 | + tpu_host_bounds = "2,1,1" |
| 143 | + tpu_chips_per_host_bounds = "1,2,1" |
| 144 | + elif tpu_chips_per_process == 4: |
| 145 | + tpu_host_bounds = "1,1,1" |
| 146 | + tpu_chips_per_host_bounds = "2,2,1" |
| 147 | + else: |
| 148 | + raise ValueError( |
| 149 | + "Invalid number of TPU chips per worker {}".format( |
| 150 | + tpu_chips_per_process |
| 151 | + ) |
| 152 | + ) |
| 153 | + elif num_tpu_chips == 8: |
| 154 | + if tpu_chips_per_process == 1: |
| 155 | + tpu_host_bounds = "4,2,1" |
| 156 | + tpu_chips_per_host_bounds = "1,1,1" |
| 157 | + elif tpu_chips_per_process == 4: |
| 158 | + # Note: this branch assumes we are using 2x4 v6e LitePod, and will not |
| 159 | + # work with 4x2 v5e LitePod. |
| 160 | + tpu_host_bounds = "1,2,1" |
| 161 | + tpu_chips_per_host_bounds = "2,2,1" |
| 162 | + elif tpu_chips_per_process == 8: |
| 163 | + tpu_host_bounds = "1,1,1" |
| 164 | + tpu_chips_per_host_bounds = "2,4,1" |
| 165 | + else: |
| 166 | + # TODO(phawkins): implement other cases. |
| 167 | + raise ValueError( |
| 168 | + "Invalid number of TPU chips per worker {}".format( |
| 169 | + tpu_chips_per_process |
| 170 | + ) |
| 171 | + ) |
| 172 | + else: |
| 173 | + raise ValueError(f"Invalid number of TPU chips {num_tpu_chips}") |
| 174 | + |
| 175 | + processes = [] |
| 176 | + for rank in range(args.num_processes): |
| 177 | + env = os.environ.copy() |
| 178 | + |
| 179 | + # JAX Distributed Setup |
| 180 | + env["JAX_COORDINATOR_ADDRESS"] = coordinator_address |
| 181 | + env["JAX_NUM_PROCESSES"] = str(args.num_processes) |
| 182 | + env["JAX_PROCESS_ID"] = str(rank) |
| 183 | + |
| 184 | + device_ids = range( |
| 185 | + rank * args.tpu_chips_per_process, |
| 186 | + (rank + 1) * args.tpu_chips_per_process, |
| 187 | + ) |
| 188 | + |
| 189 | + # Simulated TPU Setup |
| 190 | + env["CLOUD_TPU_TASK_ID"] = str(rank) |
| 191 | + env["TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_host_bounds |
| 192 | + env["TPU_PROCESS_BOUNDS"] = tpu_host_bounds |
| 193 | + env["TPU_PROCESS_ADDRESSES"] = slicebuilder_addresses |
| 194 | + env["TPU_PROCESS_PORT"] = str(slicebuilder_ports[rank]) |
| 195 | + env["TPU_VISIBLE_CHIPS"] = ",".join(map(str, device_ids)) |
| 196 | + env["ALLOW_MULTIPLE_LIBTPU_LOAD"] = "1" |
| 197 | + |
| 198 | + # Format the user's command to inject the current process rank where {rank} |
| 199 | + # is used |
| 200 | + worker_cmd = [c.format(rank=rank) for c in command] |
| 201 | + |
| 202 | + # Spawn THIS script again, triggering worker_mode |
| 203 | + cmd = [sys.executable, __file__, "--worker_mode"] + worker_cmd |
| 204 | + |
| 205 | + p = subprocess.Popen(cmd, env=env) |
| 206 | + processes.append(p) |
| 207 | + |
| 208 | + exit_codes = [p.wait() for p in processes] |
| 209 | + |
| 210 | + if any(c != 0 for c in exit_codes): |
| 211 | + logging.error("\n❌ Some processes failed.") |
| 212 | + sys.exit(1) |
| 213 | + else: |
| 214 | + logging.info("\n✅ All processes finished successfully.") |
| 215 | + |
| 216 | + |
| 217 | +if __name__ == "__main__": |
| 218 | + main() |
0 commit comments