Skip to content

Commit b270cc0

Browse files
author
Orbax Authors
committed
Update multiprocess unit tests to use an 8-TPU runner.
PiperOrigin-RevId: 868179942
1 parent 3ccf74c commit b270cc0

12 files changed

Lines changed: 3400 additions & 797 deletions

File tree

.github/workflows/build.yml

Lines changed: 82 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ jobs:
6464
pip install "jax>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
6565
fi
6666
- name: Test with pytest
67-
# TODO(yaning): Move these to an exclude target within pytest.ini.
67+
# TODO(nikhilbansall): Move these to an exclude target within pytest.ini.
6868
run: |
69-
python -m pytest --import-mode=importlib --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py --ignore=orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py --ignore=orbax/checkpoint/checkpoint_manager_test.py
69+
python -m pytest --import-mode=importlib --ignore=orbax/checkpoint/experimental/emergency/broadcast_multislice_test.py --ignore=orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/single_slice_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py --ignore=orbax/checkpoint/experimental/emergency/local_checkpoint_manager_test.py --ignore=orbax/checkpoint/experimental/emergency/multihost_test.py --ignore=orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/testing/multiprocess_test.py --ignore=orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py --ignore=orbax/checkpoint/checkpoint_manager_test.py --ignore=orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py
7070
# The below step just reports the success or failure of tests as a "commit status".
7171
# This is needed for copybara integration.
7272
- name: Report success or failure as github status
@@ -260,9 +260,9 @@ jobs:
260260
pip install -e .
261261
pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
262262
pip uninstall -y orbax
263-
if [[ "${{ matrix.jax-version }}" == "newest" ]]; then
263+
if [ "${{ matrix.jax-version }}" = "newest" ]; then
264264
pip install -U jax[k8s,cuda12] jaxlib
265-
elif [[ "${{ matrix.jax-version }}" == "nightly" ]]; then
265+
elif [ "${{ matrix.jax-version }}" = "nightly" ]; then
266266
pip install -U --pre jax[k8s,cuda12] jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
267267
else
268268
pip install "jax[k8s,cuda12]>=${{ matrix.jax-version }}" "jaxlib>=${{ matrix.jax-version }}"
@@ -296,20 +296,88 @@ jobs:
296296
echo "The following benchmarks failed:$failed_benchmarks"
297297
exit 1
298298
fi
299-
# cd orbax/checkpoint/_src/testing/benchmarks && python -c "import sys; import jax; jax.distributed.initialize(); print(jax.devices()); from absl import app; import run_benchmarks; sys.argv = ['run_benchmarks.py', '--config_file=configs/pytree_checkpoint_benchmark.yaml', '--output_directory=$GCS_BUCKET_PATH']; app.run(run_benchmarks.main)"
300-
# cd ../../../../..
301-
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
302-
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
303299
# The below step just reports the success or failure of tests as a "commit status".
304300
# This is needed for copybara integration.
301+
- name: Report success or failure as github status
302+
if: always()  # run this step even when an earlier step failed, so a status is always posted
303+
shell: bash
304+
run: |
305+
status="${{ job.status }}"  # raw job status; also sent verbatim as the "description" field below
306+
lowercase_status=$(echo $status | tr '[:upper:]' '[:lower:]')  # lower-cased copy, sent as the "state" field below
307+
curl -sS --request POST \
308+
--url https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }} \
309+
--header 'authorization: Bearer ${{ secrets.GITHUB_TOKEN }}' \
310+
--header 'content-type: application/json' \
311+
--data '{
312+
"state": "'$lowercase_status'",
313+
"target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
314+
"description": "'$status'",
315+
"context": "github-actions/build"
316+
}'
317+
318+
multiprocess-unit-tests:
319+
name: "multiprocess-unit-tests (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
320+
# runs-on: linux-x86-ct5lp-4tpu-x2
321+
runs-on: linux-x86-ct5lp-224-8tpu
322+
container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
323+
defaults:
324+
run:
325+
working-directory: checkpoint
326+
strategy:
327+
matrix:
328+
python-version: ["3.11"]
329+
jax-version: ["newest"]
330+
steps:
331+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
332+
- name: Set up Python ${{ matrix.python-version }}
333+
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
334+
with:
335+
python-version: ${{ matrix.python-version }}
336+
- name: Install dependencies
337+
run: |
338+
pip install -e .
339+
pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
340+
pip uninstall -y orbax
341+
pip install gcsfs
342+
pip install portpicker pytest chex pyyaml
343+
if [ "${{ matrix.jax-version }}" = "newest" ]; then
344+
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
345+
elif [ "${{ matrix.jax-version }}" = "nightly" ]; then
346+
pip install -U --pre "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
347+
else
348+
pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
349+
fi
305350
- name: Run multiprocess tests
306-
env:
307-
TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
351+
# env:
352+
# TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
353+
# export JAX_TPU_CHIPS_PER_HOST_BOUNDS=1,2,1
354+
# export LIBTPU_INIT_ARGS="--deepsea_chips_per_host_bounds=2,2,1 --deepsea_host_bounds=2,1,1"
308355
run: |
309-
python -c "import jax; jax.distributed.initialize(); print(jax.devices()); import pytest; test_files = [line.strip() for line in open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') if line.strip()]; pytest.main(['-c', '/dev/null'] + test_files)"
310-
# python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;"
311-
# cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
312-
# python -m pytest orbax/checkpoint/checkpoint_manager_test.py
356+
export JAX_TPU_CHIPS_PER_HOST_BOUNDS=2,2,1
357+
export JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT=10
358+
export LIBTPU_INIT_ARGS="--deepsea_chips_per_host_bounds=2,2,1 --deepsea_host_bounds=2,1,1"
359+
# Process 0 of 2: limited to TPU chips 0-3 via TPU_VISIBLE_DEVICES and started in the background.
360+
JAX_PLATFORMS=tpu \
361+
TPU_VISIBLE_DEVICES=0,1,2,3 \
362+
DEEPSEA_HAL_EXCLUDED_DEVS=0,1,2,3 \
363+
JAX_PROCESS_ID=0 JAX_NUM_PROCESSES=2 JAX_DISTRIBUTED_SERVICE_ADDR=localhost:1234 \
364+
JAX_NUM_TASKS=2 JAX_TASK_ID=0 NUM_PROCESSES=2 MULTIPROCESS_TEST_WORKER_ID=0 \
365+
JAX_ALLOW_UNUSED_TPUS=true \
366+
JAX_PORT=1234 \
367+
python orbax/checkpoint/_src/testing/multiprocess_unittests/run_tests.py \
368+
--filename=../tagged_tests.yaml --processes=4 --process_id=0 &
369+
sleep 5  # presumably gives process 0 time to start listening on localhost:1234 - TODO confirm
370+
# Process 1 of 2: limited to TPU chips 4-7 and run in the foreground.
371+
JAX_PLATFORMS=tpu \
372+
TPU_VISIBLE_DEVICES=4,5,6,7 \
373+
DEEPSEA_HAL_EXCLUDED_DEVS=4,5,6,7 \
374+
JAX_ALLOW_UNUSED_TPUS=true \
375+
JAX_PROCESS_ID=1 JAX_NUM_PROCESSES=2 JAX_DISTRIBUTED_SERVICE_ADDR=localhost:1234 \
376+
JAX_NUM_TASKS=2 JAX_TASK_ID=1 NUM_PROCESSES=2 MULTIPROCESS_TEST_WORKER_ID=1 \
377+
python orbax/checkpoint/_src/testing/multiprocess_unittests/run_tests.py \
378+
--filename=../tagged_tests.yaml --processes=4 --process_id=1  # NOTE(review): --processes=4 while JAX_NUM_PROCESSES=2 - confirm this is intended
379+
# A bare `wait` always exits 0 and would hide a failure of the background process; wait on its PID so its exit status fails the step.
380+
wait $!
313381
- name: Report success or failure as github status
314382
if: always()
315383
shell: bash

0 commit comments

Comments
 (0)