Skip to content

Commit 0e1708e

Browse files
author
Orbax Authors
committed
Re-enable multiprocess unit tests using a new multihost setup script.
PiperOrigin-RevId: 872217278
1 parent d8183bd commit 0e1708e

8 files changed

Lines changed: 273 additions & 75 deletions

File tree

.github/workflows/build.yml

Lines changed: 53 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ jobs:
7575
--ignore=orbax/checkpoint/experimental/emergency/multihost_test.py \
7676
--ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py \
7777
--ignore=orbax/checkpoint/_src/testing/multiprocess_test.py \
78-
$(python3 -c "import yaml; d=yaml.safe_load(open('orbax/checkpoint/_src/testing/multiprocess_unittests/tagged_tests.yaml')); print(' '.join(['--ignore=' + t.replace(':', '/') + '.py' for k,v in d.items() if k.startswith('processes') and v for t in v]))")
78+
$(python3 -c "import yaml; d=yaml.safe_load(open('orbax/checkpoint/_src/testing/oss/tagged_tests.yaml')); print(' '.join(['--ignore=' + t.replace(':', '/') + '.py' for k,v in d.items() if k.startswith('processes') and v for t in v]))")
7979
# The below step just reports the success or failure of tests as a "commit status".
8080
# This is needed for copybara integration.
8181
- name: Report success or failure as github status
@@ -284,13 +284,13 @@ jobs:
284284
TF_FORCE_GPU_ALLOW_GROWTH: true
285285
XLA_PYTHON_CLIENT_PREALLOCATE: false
286286
run: |
287-
cd orbax/checkpoint/_src/testing
287+
cd orbax/checkpoint/_src/testing/oss
288288
ls
289289
failed_benchmarks=""
290290
benchmark_configs_file="multiprocess_benchmark_configs.txt"
291291
echo "Running benchmarks specified in $benchmark_configs_file"
292292
benchmark_configs_file_path="$PWD/$benchmark_configs_file"
293-
cd benchmarks
293+
cd ../benchmarks
294294
while IFS= read -r entry || [ -n "$entry" ]; do
295295
if [ -n "$entry" ]; then
296296
echo "Running benchmark for $entry"
@@ -324,55 +324,53 @@ jobs:
324324
"context": "github-actions/build"
325325
}'
326326
327-
# multiprocess-unit-tests:
328-
# name: "multiprocess-unit-tests (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
329-
# runs-on: linux-x86-ct5lp-4tpu-x2
330-
# container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
331-
# defaults:
332-
# run:
333-
# working-directory: checkpoint
334-
# strategy:
335-
# matrix:
336-
# python-version: ["3.11"]
337-
# jax-version: ["newest"]
338-
# steps:
339-
# - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
340-
# - name: Set up Python ${{ matrix.python-version }}
341-
# uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
342-
# with:
343-
# python-version: ${{ matrix.python-version }}
344-
# - name: Install dependencies
345-
# run: |
346-
# pip install -e .
347-
# pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
348-
# pip uninstall -y orbax
349-
# pip install gcsfs
350-
# pip install portpicker pytest chex pyyaml
351-
# if [ "${{ matrix.jax-version }}" = "newest" ]; then
352-
# pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
353-
# elif [ "${{ matrix.jax-version }}" = "nightly" ]; then
354-
# pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
355-
# else
356-
# pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
357-
# fi
358-
# - name: Run multiprocess tests
359-
# env:
360-
# TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
361-
# run: |
362-
# python orbax/checkpoint/_src/testing/multiprocess_unittests/run_tests.py --filename=orbax/checkpoint/_src/testing/multiprocess_unittests/tagged_tests.yaml --processes=4
363-
# - name: Report success or failure as github status
364-
# if: always()
365-
# shell: bash
366-
# run: |
367-
# status="${{ job.status }}"
368-
# lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
369-
# curl -sS --request POST \
370-
# --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
371-
# --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
372-
# --header 'content-type: application/json' \
373-
# --data '{
374-
# "state": "'$lowercase_status'",
375-
# "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
376-
# "description": "'$status'",
377-
# "context": "github-actions/build"
378-
# }'
327+
multiprocess-unit-tests:
328+
name: "multiprocess-unit-tests (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
329+
runs-on: linux-x86-ct5lp-224-8tpu
330+
container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
331+
defaults:
332+
run:
333+
working-directory: checkpoint
334+
strategy:
335+
matrix:
336+
python-version: ["3.11"]
337+
jax-version: ["newest"]
338+
steps:
339+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
340+
- name: Set up Python ${{ matrix.python-version }}
341+
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
342+
with:
343+
python-version: ${{ matrix.python-version }}
344+
- name: Install dependencies
345+
run: |
346+
pip install -e .
347+
pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
348+
pip uninstall -y orbax
349+
pip install gcsfs
350+
pip install portpicker pytest chex pyyaml
351+
if [ "${{ matrix.jax-version }}" = "newest" ]; then
352+
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
353+
elif [ "${{ matrix.jax-version }}" = "nightly" ]; then
354+
pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
355+
else
356+
pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
357+
fi
358+
- name: Run multiprocess tests
359+
run: |
360+
python orbax/checkpoint/_src/testing/oss/run_multihost.py orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=4
361+
- name: Report success or failure as github status
362+
if: always()
363+
shell: bash
364+
run: |
365+
status="${{ job.status }}"
366+
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')
367+
curl -sS --request POST \
368+
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
369+
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
370+
--header 'content-type: application/json' \
371+
--data '{
372+
"state": "'$lowercase_status'",
373+
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
374+
"description": "'$status'",
375+
"context": "github-actions/build"
376+
}'

.github/workflows/multiprocess_tests.yml

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ jobs:
5353
TF_FORCE_GPU_ALLOW_GROWTH: true
5454
XLA_PYTHON_CLIENT_PREALLOCATE: false
5555
run: |
56-
cd orbax/checkpoint/_src/testing
56+
cd orbax/checkpoint/_src/testing/oss
5757
failed_benchmarks=""
5858
benchmark_configs_file="multiprocess_benchmark_configs.txt"
5959
echo "Running benchmarks specified in $benchmark_configs_file"
6060
benchmark_configs_file_path="$PWD/$benchmark_configs_file"
61-
cd benchmarks
61+
cd ../benchmarks
6262
while IFS= read -r entry || [ -n "$entry" ]; do
6363
if [ -n "$entry" ]; then
6464
echo "Running benchmark for $entry"
@@ -77,14 +77,6 @@ jobs:
7777
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
7878
# The below step just reports the success or failure of tests as a "commit status".
7979
# This is needed for copybara integration.
80-
- name: Run multiprocess tests
81-
env:
82-
TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
83-
run: |
84-
python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
85-
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;"
86-
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
87-
# python -m pytest orbax/checkpoint/checkpoint_manager_test.py
8880
- name: Report success or failure as github status
8981
if: always()
9082
shell: bash

checkpoint/orbax/checkpoint/_src/testing/multiprocess_tests.txt

Lines changed: 0 additions & 1 deletion
This file was deleted.

checkpoint/orbax/checkpoint/_src/testing/multiprocess_unittests/generate_multiprocess_test.py renamed to checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py

File renamed without changes.

checkpoint/orbax/checkpoint/_src/testing/multiprocess_benchmark_configs.txt renamed to checkpoint/orbax/checkpoint/_src/testing/oss/multiprocess_benchmark_configs.txt

File renamed without changes.
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Launches and bootstraps tests across multiple simulated JAX processes."""
16+
17+
import argparse
18+
import contextlib
19+
import os
20+
import runpy
21+
import socket
22+
import subprocess
23+
import sys
24+
25+
from absl import logging
26+
import jax
27+
import pytest
28+
29+
30+
def find_free_port():
31+
with contextlib.closing(
32+
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
33+
) as s:
34+
s.bind(("", 0))
35+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
36+
return s.getsockname()[1]
37+
38+
39+
def run_worker_and_command(command):
40+
"""Worker Mode: Initializes JAX explicitly, then executes the target command."""
41+
42+
coordinator_address = os.environ.get("JAX_COORDINATOR_ADDRESS")
43+
num_processes = os.environ.get("JAX_NUM_PROCESSES")
44+
process_id = os.environ.get("JAX_PROCESS_ID")
45+
46+
if coordinator_address is None:
47+
raise ValueError(
48+
"Environment variables for JAX distributed not found. "
49+
"Did you use launch_multihost.py?"
50+
)
51+
52+
# Explicit Initialization
53+
jax.distributed.initialize(
54+
coordinator_address=coordinator_address,
55+
num_processes=int(num_processes),
56+
process_id=int(process_id),
57+
)
58+
59+
print(f"[Rank {process_id}] JAX Initialized. Executing: {' '.join(command)}")
60+
print(f"[Rank {process_id}] JAX devices: {jax.devices()}")
61+
62+
# Clean up 'python' from the command if the user accidentally included it
63+
if command[0] == "python" or command[0] == "python3":
64+
command = command[1:]
65+
66+
cmd_name = command[0]
67+
68+
# Execute the requested script/tool inside this initialized process
69+
if cmd_name == "pytest":
70+
sys.exit(pytest.main(command[1:]))
71+
72+
elif cmd_name.endswith(".py"):
73+
# Overwrite sys.argv so the target script sees its expected arguments
74+
sys.argv = command
75+
runpy.run_path(cmd_name, run_name="__main__")
76+
77+
else:
78+
# Fallback for arbitrary shell commands
79+
sys.exit(subprocess.call(command))
80+
81+
82+
def main():
83+
# 1. Parse arguments meant for launch.py
84+
parser = argparse.ArgumentParser(description="JAX Multihost Launcher")
85+
parser.add_argument(
86+
"--worker_mode", action="store_true", help=argparse.SUPPRESS
87+
)
88+
parser.add_argument(
89+
"--num_processes", type=int, default=2, help="Number of simulated hosts"
90+
)
91+
parser.add_argument(
92+
"--tpu_chips_per_process", type=int, default=4, help="TPU chips per host"
93+
)
94+
95+
# `args` gets the launcher configs, `command` gets everything else
96+
args, command = parser.parse_known_args()
97+
98+
# 2. WORKER MODE
99+
if args.worker_mode:
100+
if not command:
101+
raise ValueError("No command provided for the worker to execute.")
102+
run_worker_and_command(command)
103+
return
104+
105+
# 3. LAUNCHER MODE
106+
if not command:
107+
logging.error(
108+
"Usage: python %s [LAUNCH_ARGS] <script.py> [SCRIPT_ARGS]",
109+
os.path.basename(__file__),
110+
)
111+
sys.exit(1)
112+
113+
coordinator_port = find_free_port()
114+
coordinator_address = f"localhost:{coordinator_port}"
115+
116+
slicebuilder_ports = [find_free_port() for _ in range(args.num_processes)]
117+
slicebuilder_addresses = ",".join(
118+
f"localhost:{port}" for port in slicebuilder_ports
119+
)
120+
121+
logging.info(
122+
"🚀 Starting %s JAX processes (%s chips/process)...",
123+
args.num_processes,
124+
args.tpu_chips_per_process,
125+
)
126+
logging.info("📍 Coordinator: %s", coordinator_address)
127+
128+
tpu_chips_per_process = args.tpu_chips_per_process
129+
num_tpu_chips = args.num_processes * args.tpu_chips_per_process
130+
if num_tpu_chips == 0:
131+
tpu_host_bounds = ""
132+
tpu_chips_per_host_bounds = ""
133+
elif num_tpu_chips == 1:
134+
assert tpu_chips_per_process == 1
135+
tpu_host_bounds = "1,1,1"
136+
tpu_chips_per_host_bounds = "1,1,1"
137+
elif num_tpu_chips == 4:
138+
if tpu_chips_per_process == 1:
139+
tpu_host_bounds = "2,2,1"
140+
tpu_chips_per_host_bounds = "1,1,1"
141+
elif tpu_chips_per_process == 2:
142+
tpu_host_bounds = "2,1,1"
143+
tpu_chips_per_host_bounds = "1,2,1"
144+
elif tpu_chips_per_process == 4:
145+
tpu_host_bounds = "1,1,1"
146+
tpu_chips_per_host_bounds = "2,2,1"
147+
else:
148+
raise ValueError(
149+
"Invalid number of TPU chips per worker {}".format(
150+
tpu_chips_per_process
151+
)
152+
)
153+
elif num_tpu_chips == 8:
154+
if tpu_chips_per_process == 1:
155+
tpu_host_bounds = "4,2,1"
156+
tpu_chips_per_host_bounds = "1,1,1"
157+
elif tpu_chips_per_process == 4:
158+
# Note: this branch assumes we are using 2x4 v6e LitePod, and will not
159+
# work with 4x2 v5e LitePod.
160+
tpu_host_bounds = "1,2,1"
161+
tpu_chips_per_host_bounds = "2,2,1"
162+
elif tpu_chips_per_process == 8:
163+
tpu_host_bounds = "1,1,1"
164+
tpu_chips_per_host_bounds = "2,4,1"
165+
else:
166+
# TODO(phawkins): implement other cases.
167+
raise ValueError(
168+
"Invalid number of TPU chips per worker {}".format(
169+
tpu_chips_per_process
170+
)
171+
)
172+
else:
173+
raise ValueError(f"Invalid number of TPU chips {num_tpu_chips}")
174+
175+
processes = []
176+
for rank in range(args.num_processes):
177+
env = os.environ.copy()
178+
179+
# JAX Distributed Setup
180+
env["JAX_COORDINATOR_ADDRESS"] = coordinator_address
181+
env["JAX_NUM_PROCESSES"] = str(args.num_processes)
182+
env["JAX_PROCESS_ID"] = str(rank)
183+
184+
device_ids = range(
185+
rank * args.tpu_chips_per_process,
186+
(rank + 1) * args.tpu_chips_per_process,
187+
)
188+
189+
# Simulated TPU Setup
190+
env["CLOUD_TPU_TASK_ID"] = str(rank)
191+
env["TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_host_bounds
192+
env["TPU_PROCESS_BOUNDS"] = tpu_host_bounds
193+
env["TPU_PROCESS_ADDRESSES"] = slicebuilder_addresses
194+
env["TPU_PROCESS_PORT"] = str(slicebuilder_ports[rank])
195+
env["TPU_VISIBLE_CHIPS"] = ",".join(map(str, device_ids))
196+
env["ALLOW_MULTIPLE_LIBTPU_LOAD"] = "1"
197+
198+
# Format the user's command to inject the current process rank where {rank}
199+
# is used
200+
worker_cmd = [c.format(rank=rank) for c in command]
201+
202+
# Spawn THIS script again, triggering worker_mode
203+
cmd = [sys.executable, __file__, "--worker_mode"] + worker_cmd
204+
205+
p = subprocess.Popen(cmd, env=env)
206+
processes.append(p)
207+
208+
exit_codes = [p.wait() for p in processes]
209+
210+
if any(c != 0 for c in exit_codes):
211+
logging.error("\n❌ Some processes failed.")
212+
sys.exit(1)
213+
else:
214+
logging.info("\n✅ All processes finished successfully.")
215+
216+
217+
if __name__ == "__main__":
218+
main()

0 commit comments

Comments
 (0)