From 895661357d28a22e8b7139de929f9d3f323e9969 Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Mon, 23 Feb 2026 10:13:36 -0800 Subject: [PATCH] Re-enable multiprocess unit tests using a new multihost setup script. PiperOrigin-RevId: 874136334 --- .github/workflows/build.yml | 108 +++++---- .github/workflows/multiprocess_tests.yml | 12 +- .../_src/testing/multiprocess_tests.txt | 1 - .../generate_multiprocess_test.py | 0 .../multiprocess_benchmark_configs.txt | 0 .../_src/testing/oss/run_multihost.py | 218 ++++++++++++++++++ .../run_tests.py | 9 - .../tagged_tests.yaml | 0 8 files changed, 273 insertions(+), 75 deletions(-) delete mode 100644 checkpoint/orbax/checkpoint/_src/testing/multiprocess_tests.txt rename checkpoint/orbax/checkpoint/_src/testing/{multiprocess_unittests => oss}/generate_multiprocess_test.py (100%) rename checkpoint/orbax/checkpoint/_src/testing/{ => oss}/multiprocess_benchmark_configs.txt (100%) create mode 100644 checkpoint/orbax/checkpoint/_src/testing/oss/run_multihost.py rename checkpoint/orbax/checkpoint/_src/testing/{multiprocess_unittests => oss}/run_tests.py (95%) rename checkpoint/orbax/checkpoint/_src/testing/{multiprocess_unittests => oss}/tagged_tests.yaml (100%) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 75703d02e..d3d450274 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -75,7 +75,7 @@ jobs: --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py \ --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py \ --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py \ - $(python3 -c "import yaml; d=yaml.safe_load(open('orbax/checkpoint/_src/testing/multiprocess_unittests/tagged_tests.yaml')); print(' '.join(['--ignore=' + t.replace(':', '/') + '.py' for k,v in d.items() if k.startswith('processes') and v for t in v]))") + $(python3 -c "import yaml; d=yaml.safe_load(open('orbax/checkpoint/_src/testing/oss/tagged_tests.yaml')); print(' '.join(['--ignore=' + t.replace(':', '/') + '.py' for k,v in d.items() if k.startswith('processes') and v for t in v]))") # The below step just reports the success or failure of tests as a "commit status". # This is needed for copybara integration. - name: Report success or failure as github status @@ -284,13 +284,13 @@ jobs: TF_FORCE_GPU_ALLOW_GROWTH: true XLA_PYTHON_CLIENT_PREALLOCATE: false run: | - cd orbax/checkpoint/_src/testing + cd orbax/checkpoint/_src/testing/oss ls failed_benchmarks="" benchmark_configs_file="multiprocess_benchmark_configs.txt" echo "Running benchmarks specified in $benchmark_configs_file" benchmark_configs_file_path="$PWD/$benchmark_configs_file" - cd benchmarks + cd ../benchmarks while IFS= read -r entry || [ -n "$entry" ]; do if [ -n "$entry" ]; then echo "Running benchmark for $entry" @@ -324,55 +324,53 @@ jobs: "context": "github-actions/build" }' - # multiprocess-unit-tests: - # name: "multiprocess-unit-tests (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" - # runs-on: linux-x86-ct5lp-4tpu-x2 - # container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e - # defaults: - # run: - # working-directory: checkpoint - # strategy: - # matrix: - # python-version: ["3.11"] - # jax-version: ["newest"] - # steps: - # - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - # - name: Set up Python ${{ matrix.python-version }} - # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - # with: - # python-version: ${{ matrix.python-version }} - # - name: Install dependencies - # run: | - # pip install -e . - # pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - # pip uninstall -y orbax - # pip install gcsfs - # pip install portpicker pytest chex pyyaml - # if [ "${{ matrix.jax-version }}" = "newest" ]; then - # pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - # elif [ "${{ matrix.jax-version }}" = "nightly" ]; then - # pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ - # else - # pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - # fi - # - name: Run multiprocess tests - # env: - # TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }} - # run: | - # python orbax/checkpoint/_src/testing/multiprocess_unittests/run_tests.py --filename=orbax/checkpoint/_src/testing/multiprocess_unittests/tagged_tests.yaml --processes=4 - # - name: Report success or failure as github status - # if: always() - # shell: bash - # run: | - # status="${{ job.status }}" - # lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') - # curl -sS --request POST \ - # --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ - # --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ - # --header 'content-type: application/json' \ - # --data '{ - # "state": "'$lowercase_status'", - # "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", - # "description": "'$status'", - # "context": "github-actions/build" - # }' + multiprocess-unit-tests: + name: "multiprocess-unit-tests (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})" + runs-on: linux-x86-ct5lp-224-8tpu + container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e + defaults: + run: + working-directory: checkpoint + strategy: + matrix: + python-version: ["3.11"] + jax-version: ["newest"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install -e . + pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + pip uninstall -y orbax + pip install gcsfs + pip install portpicker pytest chex pyyaml + if [ "${{ matrix.jax-version }}" = "newest" ]; then + pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + elif [ "${{ matrix.jax-version }}" = "nightly" ]; then + pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/ + else + pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + fi + - name: Run multiprocess tests + run: | + python orbax/checkpoint/_src/testing/oss/run_multihost.py orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=4 + - name: Report success or failure as github status + if: always() + shell: bash + run: | + status="${{ job.status }}" + lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]') + curl -sS --request POST \ + --url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \ + --header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \ + --header 'content-type: application/json' \ + --data '{ + "state": "'$lowercase_status'", + "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", + "description": "'$status'", + "context": "github-actions/build" + }' diff --git a/.github/workflows/multiprocess_tests.yml b/.github/workflows/multiprocess_tests.yml index dd8d9d462..cd76ed235 100644 --- a/.github/workflows/multiprocess_tests.yml +++ b/.github/workflows/multiprocess_tests.yml @@ -53,12 +53,12 @@ jobs: TF_FORCE_GPU_ALLOW_GROWTH: true XLA_PYTHON_CLIENT_PREALLOCATE: false run: | - cd orbax/checkpoint/_src/testing + cd orbax/checkpoint/_src/testing/oss failed_benchmarks="" benchmark_configs_file="multiprocess_benchmark_configs.txt" echo "Running benchmarks specified in $benchmark_configs_file" benchmark_configs_file_path="$PWD/$benchmark_configs_file" - cd benchmarks + cd ../benchmarks while IFS= read -r entry || [ -n "$entry" ]; do if [ -n "$entry" ]; then echo "Running benchmark for $entry" @@ -77,14 +77,6 @@ jobs: # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH # The below step just reports the success or failure of tests as a "commit status". # This is needed for copybara integration. - - name: Run multiprocess tests - env: - TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }} - run: | - python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)" - # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;" - # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH - # python -m pytest orbax/checkpoint/checkpoint_manager_test.py - name: Report success or failure as github status if: always() shell: bash diff --git a/checkpoint/orbax/checkpoint/_src/testing/multiprocess_tests.txt b/checkpoint/orbax/checkpoint/_src/testing/multiprocess_tests.txt deleted file mode 100644 index e5459132c..000000000 --- a/checkpoint/orbax/checkpoint/_src/testing/multiprocess_tests.txt +++ /dev/null @@ -1 +0,0 @@ -orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py diff --git a/checkpoint/orbax/checkpoint/_src/testing/multiprocess_unittests/generate_multiprocess_test.py b/checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py similarity index 100% rename from checkpoint/orbax/checkpoint/_src/testing/multiprocess_unittests/generate_multiprocess_test.py rename to checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py diff --git a/checkpoint/orbax/checkpoint/_src/testing/multiprocess_benchmark_configs.txt b/checkpoint/orbax/checkpoint/_src/testing/oss/multiprocess_benchmark_configs.txt similarity index 100% rename from checkpoint/orbax/checkpoint/_src/testing/multiprocess_benchmark_configs.txt rename to checkpoint/orbax/checkpoint/_src/testing/oss/multiprocess_benchmark_configs.txt diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/run_multihost.py b/checkpoint/orbax/checkpoint/_src/testing/oss/run_multihost.py new file mode 100644 index 000000000..408ceed56 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/run_multihost.py @@ -0,0 +1,218 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Launches and bootstraps tests across multiple simulated JAX processes.""" + +import argparse +import contextlib +import os +import runpy +import socket +import subprocess +import sys + +from absl import logging +import jax +import pytest + + +def find_free_port(): + with contextlib.closing( + socket.socket(socket.AF_INET, socket.SOCK_STREAM) + ) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def run_worker_and_command(command): + """Worker Mode: Initializes JAX explicitly, then executes the target command.""" + + coordinator_address = os.environ.get("JAX_COORDINATOR_ADDRESS") + num_processes = os.environ.get("JAX_NUM_PROCESSES") + process_id = os.environ.get("JAX_PROCESS_ID") + + if coordinator_address is None: + raise ValueError( + "Environment variables for JAX distributed not found. " + "Did you use launch_multihost.py?" + ) + + # Explicit Initialization + jax.distributed.initialize( + coordinator_address=coordinator_address, + num_processes=int(num_processes), + process_id=int(process_id), + ) + + print(f"[Rank {process_id}] JAX Initialized. Executing: {' '.join(command)}") + print(f"[Rank {process_id}] JAX devices: {jax.devices()}") + + # Clean up 'python' from the command if the user accidentally included it + if command[0] == "python" or command[0] == "python3": + command = command[1:] + + cmd_name = command[0] + + # Execute the requested script/tool inside this initialized process + if cmd_name == "pytest": + sys.exit(pytest.main(command[1:])) + + elif cmd_name.endswith(".py"): + # Overwrite sys.argv so the target script sees its expected arguments + sys.argv = command + runpy.run_path(cmd_name, run_name="__main__") + + else: + # Fallback for arbitrary shell commands + sys.exit(subprocess.call(command)) + + +def main(): + # 1. Parse arguments meant for launch.py + parser = argparse.ArgumentParser(description="JAX Multihost Launcher") + parser.add_argument( + "--worker_mode", action="store_true", help=argparse.SUPPRESS + ) + parser.add_argument( + "--num_processes", type=int, default=2, help="Number of simulated hosts" + ) + parser.add_argument( + "--tpu_chips_per_process", type=int, default=4, help="TPU chips per host" + ) + + # `args` gets the launcher configs, `command` gets everything else + args, command = parser.parse_known_args() + + # 2. WORKER MODE + if args.worker_mode: + if not command: + raise ValueError("No command provided for the worker to execute.") + run_worker_and_command(command) + return + + # 3. LAUNCHER MODE + if not command: + logging.error( + "Usage: python %s [LAUNCH_ARGS] [SCRIPT_ARGS]", + os.path.basename(__file__), + ) + sys.exit(1) + + coordinator_port = find_free_port() + coordinator_address = f"localhost:{coordinator_port}" + + slicebuilder_ports = [find_free_port() for _ in range(args.num_processes)] + slicebuilder_addresses = ",".join( + f"localhost:{port}" for port in slicebuilder_ports + ) + + logging.info( + "šŸš€ Starting %s JAX processes (%s chips/process)...", + args.num_processes, + args.tpu_chips_per_process, + ) + logging.info("šŸ“ Coordinator: %s", coordinator_address) + + tpu_chips_per_process = args.tpu_chips_per_process + num_tpu_chips = args.num_processes * args.tpu_chips_per_process + if num_tpu_chips == 0: + tpu_host_bounds = "" + tpu_chips_per_host_bounds = "" + elif num_tpu_chips == 1: + assert tpu_chips_per_process == 1 + tpu_host_bounds = "1,1,1" + tpu_chips_per_host_bounds = "1,1,1" + elif num_tpu_chips == 4: + if tpu_chips_per_process == 1: + tpu_host_bounds = "2,2,1" + tpu_chips_per_host_bounds = "1,1,1" + elif tpu_chips_per_process == 2: + tpu_host_bounds = "2,1,1" + tpu_chips_per_host_bounds = "1,2,1" + elif tpu_chips_per_process == 4: + tpu_host_bounds = "1,1,1" + tpu_chips_per_host_bounds = "2,2,1" + else: + raise ValueError( + "Invalid number of TPU chips per worker {}".format( + tpu_chips_per_process + ) + ) + elif num_tpu_chips == 8: + if tpu_chips_per_process == 1: + tpu_host_bounds = "4,2,1" + tpu_chips_per_host_bounds = "1,1,1" + elif tpu_chips_per_process == 4: + # Note: this branch assumes we are using 2x4 v6e LitePod, and will not + # work with 4x2 v5e LitePod. + tpu_host_bounds = "1,2,1" + tpu_chips_per_host_bounds = "2,2,1" + elif tpu_chips_per_process == 8: + tpu_host_bounds = "1,1,1" + tpu_chips_per_host_bounds = "2,4,1" + else: + # TODO(phawkins): implement other cases. + raise ValueError( + "Invalid number of TPU chips per worker {}".format( + tpu_chips_per_process + ) + ) + else: + raise ValueError(f"Invalid number of TPU chips {num_tpu_chips}") + + processes = [] + for rank in range(args.num_processes): + env = os.environ.copy() + + # JAX Distributed Setup + env["JAX_COORDINATOR_ADDRESS"] = coordinator_address + env["JAX_NUM_PROCESSES"] = str(args.num_processes) + env["JAX_PROCESS_ID"] = str(rank) + + device_ids = range( + rank * args.tpu_chips_per_process, + (rank + 1) * args.tpu_chips_per_process, + ) + + # Simulated TPU Setup + env["CLOUD_TPU_TASK_ID"] = str(rank) + env["TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_host_bounds + env["TPU_PROCESS_BOUNDS"] = tpu_host_bounds + env["TPU_PROCESS_ADDRESSES"] = slicebuilder_addresses + env["TPU_PROCESS_PORT"] = str(slicebuilder_ports[rank]) + env["TPU_VISIBLE_CHIPS"] = ",".join(map(str, device_ids)) + env["ALLOW_MULTIPLE_LIBTPU_LOAD"] = "1" + + # Format the user's command to inject the current process rank where {rank} + # is used + worker_cmd = [c.format(rank=rank) for c in command] + + # Spawn THIS script again, triggering worker_mode + cmd = [sys.executable, __file__, "--worker_mode"] + worker_cmd + + p = subprocess.Popen(cmd, env=env) + processes.append(p) + + exit_codes = [p.wait() for p in processes] + + if any(c != 0 for c in exit_codes): + logging.error("\nāŒ Some processes failed.") + sys.exit(1) + else: + logging.info("\nāœ… All processes finished successfully.") + + +if __name__ == "__main__": + main() diff --git a/checkpoint/orbax/checkpoint/_src/testing/multiprocess_unittests/run_tests.py b/checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py similarity index 95% rename from checkpoint/orbax/checkpoint/_src/testing/multiprocess_unittests/run_tests.py rename to checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py index 74e10ba47..8e66f33f0 100755 --- a/checkpoint/orbax/checkpoint/_src/testing/multiprocess_unittests/run_tests.py +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py @@ -22,7 +22,6 @@ from absl import app from absl import flags from absl import logging -import jax import pytest import yaml @@ -87,14 +86,6 @@ def main(argv: Sequence[str]) -> None: install_deps() - try: - jax.distributed.initialize() - logging.info('JAX devices: %s', jax.devices()) - except RuntimeError as e: - logging.warning( - 'Could not initialize jax.distributed: %s. Proceeding without it.', e - ) - try: with open(FLAGS.filename, 'r') as f: try: diff --git a/checkpoint/orbax/checkpoint/_src/testing/multiprocess_unittests/tagged_tests.yaml b/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests.yaml similarity index 100% rename from checkpoint/orbax/checkpoint/_src/testing/multiprocess_unittests/tagged_tests.yaml rename to checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests.yaml