
Commit 0798438

Merge branch 'main' of github.com:AI-Hypercomputer/maxtext into shuningjin-ckpt-opt3

2 parents: 10bef5d + 9f6b09a

83 files changed: 1,918 additions & 1,068 deletions


.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion
```diff
@@ -22,7 +22,7 @@ tests/inference/ @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @p
 src/maxtext/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
 
 # Dockerfiles and dependencies
-src/dependencies/ @bvandermoon @parambole @richjames0 @shralex
+src/dependencies/ @bvandermoon @SurbhiJainUSC @parambole @richjames0 @shralex
 
 # Docs
 docs/ @jacoguzo @bvandermoon @richjames0 @shralex @gobbleturk @RissyRan @gagika @A9isha @jiangjy1982 @vipannalla
```

.github/workflows/build_and_push_docker_image.yml

Lines changed: 2 additions & 19 deletions
```diff
@@ -54,6 +54,7 @@ jobs:
     runs-on: linux-x86-n2-16-buildkit
     container: google/cloud-sdk:524.0.0
     if: >
+      github.event_name == 'release' ||
       github.event_name == 'schedule' ||
       github.event_name == 'pull_request' ||
       github.event_name == 'workflow_dispatch' && (
@@ -86,15 +87,8 @@ jobs:
           # This ensures that every job clones the exact same commit as "setup" job
           ref: ${{ inputs.maxtext_sha }}
 
-      - name: Checkout post-training dependencies
-        if: steps.check.outputs.should_run == 'true' && inputs.image_name == 'maxtext_post_training_nightly'
-        run: |
-          git clone https://github.com/google/tunix.git ./tunix
-          git clone https://github.com/vllm-project/vllm.git ./vllm
-          git clone https://github.com/vllm-project/tpu-inference.git ./tpu-inference
-
       - name: Mark git repositories as safe
-        run: git config --global --add safe.directory '*'
+        run: git config --global --add safe.directory ${GITHUB_WORKSPACE}
         if: steps.check.outputs.should_run == 'true'
 
       - name: Configure Docker
@@ -123,7 +117,6 @@ jobs:
            MODE=${{ inputs.build_mode }}
            WORKFLOW=${{ inputs.workflow }}
            PACKAGE_DIR=./src
-           TESTS_DIR=./tests
            JAX_VERSION=NONE
            LIBTPU_VERSION=NONE
            INCLUDE_TEST_ASSETS=true
@@ -149,16 +142,6 @@ jobs:
            # Add MaxText tag
            maxtext_hash=$(git rev-parse --short HEAD)
            gcloud container images add-tag "$SOURCE_IMAGE:${{ github.run_id }}" "$SOURCE_IMAGE:maxtext_${maxtext_hash}_${clean_date}" --quiet
-
-           # Add post-training dependencies tags
-           if [ "${{ inputs.workflow }}" == "post-training" ]; then
-             for dir in tunix vllm tpu-inference; do
-               if [ -d "./$dir" ]; then
-                 dir_hash=$(git -C "$dir" rev-parse --short HEAD)
-                 gcloud container images add-tag "$SOURCE_IMAGE:${{ github.run_id }}" "$SOURCE_IMAGE:${dir}_${dir_hash}_${clean_date}" --quiet
-               fi
-             done
-           fi
          fi
        env:
          INPUTS_IMAGE_NAME: ${{ inputs.image_name }}
```

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
```diff
@@ -52,7 +52,6 @@ repos:
       args:
         - '--pyink-indentation=2'
         - '--line-length=122'
-        - '--check'
 
   - repo: https://github.com/executablebooks/mdformat
     rev: 0.7.22
```
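Dropping `--check` changes the pyink hook's behavior: with the flag, the hook only reported formatting violations and failed; without it, pyink rewrites non-conforming files in place when the hook runs. A minimal usage sketch with the standard pre-commit CLI (nothing here is MaxText-specific):

```sh
# Run every configured hook against the whole repository.
# With --check removed, pyink now fixes files instead of only failing,
# so a second run after staging the rewrites should come back clean.
pre-commit run --all-files
```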

LICENSE_HEADER

Lines changed: 13 additions & 0 deletions
```diff
@@ -0,0 +1,13 @@
+Copyright 2023–2026 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
```

benchmarks/maxtext_xpk_runner.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -428,7 +428,7 @@ def build_user_command(
   if wl_config.hlo_dump:
     hlo_dump = "XLA_FLAGS='--xla_dump_large_constants --xla_dump_to=/tmp/xla_dump'"
     upload_hlo_dump = (
-        f" && gsutil -m cp -r /tmp/xla_dump {wl_config.base_output_directory}/{wl_config.run_name}/hlo_dump"
+        f" && gcloud storage cp -r /tmp/xla_dump {wl_config.base_output_directory}/{wl_config.run_name}/hlo_dump"
     )
   # Construct the command string with proper formatting and line continuations
   command = " ".join(
```

benchmarks/upload_metrics_to_bq.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -187,7 +187,7 @@ def add_parser_arguments(parser: argparse.ArgumentParser):
 
 
 def download_metrics_file_locally(metrics_gcs_file: str, local_file: str) -> int:
-  command = f"gsutil cp -r {metrics_gcs_file} {local_file}"
+  command = f"gcloud storage cp --recursive {metrics_gcs_file} {local_file}"
   return run_command_with_updates(command, f"Download {metrics_gcs_file} in {local_file}")
 
 
```
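Both benchmark changes in this commit swap the legacy `gsutil` CLI for `gcloud storage`. A rough before/after sketch of the mapping, using a hypothetical bucket path (note that `gcloud storage cp` parallelizes transfers by default, which is why the old `-m` flag simply disappears; `-r` and `--recursive` are interchangeable):

```sh
# Before: gsutil, with -m enabling parallel (multi-threaded) copies
gsutil -m cp -r /tmp/xla_dump gs://my-bucket/my-run/hlo_dump

# After: gcloud storage parallelizes by default, so no -m equivalent is needed
gcloud storage cp --recursive /tmp/xla_dump gs://my-bucket/my-run/hlo_dump
```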

docs/guides/data_input_pipeline.md

Lines changed: 20 additions & 13 deletions
````diff
@@ -15,29 +15,34 @@
 -->
 
 (data-input-pipeline)=
+
 # Data pipelines
 
 Currently MaxText has three data input pipelines:
 
-| Pipeline | Dataset formats | Features | Limitations |
-| -------- | --------------- | -------- | ----------- |
-| **[Grain](data_input_pipeline/data_input_grain.md)** (recommended)| [ArrayRecord](https://github.com/google/array_record) (random access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview), or [conversion](https://github.com/google/array_record/tree/main/beam))<br>[Parquet](https://arrow.apache.org/docs/python/parquet.html) (sequential access) | With arrayrecord: fully deterministic, resilient to preemption; global shuffle <br>With parquet: performant; fully deterministic, resilient to preemption; hierarchical shuffle | |
-| **[Hugging Face](data_input_pipeline/data_input_hf.md)** | datasets in [Hugging Face Hub](https://huggingface.co/datasets)<br>local/Cloud Storage datasets in json, parquet, arrow, csv, txt (sequential access) | no download needed, convenience; <br>multiple formats | limit scalability using the Hugging Face Hub (no limit using Cloud Storage); <br>non-deterministic with preemption<br>(deterministic without preemption)<br> |
-| **[TFDS](data_input_pipeline/data_input_tfds.md)** | TFRecord (sequential access), available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview) | performant | only supports TFRecords; <br>non-deterministic with preemption<br>(deterministic without preemption) |
+| Pipeline                                                            | Dataset formats                                                                                                                                                                                                                                                                                                            | Features                                                                                                                                                                         | Limitations                                                                                                                                                   |
+| ------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| **[Grain](data_input_pipeline/data_input_grain.md)** (recommended) | [ArrayRecord](https://github.com/google/array_record) (random access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview), or [conversion](https://github.com/google/array_record/tree/main/beam))<br>[TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access, available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview))<br>[Parquet](https://arrow.apache.org/docs/python/parquet.html) (sequential access) | With arrayrecord: fully deterministic, resilient to preemption; global shuffle <br>With parquet: performant; fully deterministic, resilient to preemption; hierarchical shuffle | |
+| **[Hugging Face](data_input_pipeline/data_input_hf.md)** | datasets in [Hugging Face Hub](https://huggingface.co/datasets)<br>local/Cloud Storage datasets in json, parquet, arrow, csv, txt (sequential access) | no download needed, convenience; <br>multiple formats | limit scalability using the Hugging Face Hub (no limit using Cloud Storage); <br>non-deterministic with preemption<br>(deterministic without preemption)<br> |
+| **[TFDS](data_input_pipeline/data_input_tfds.md)** | TFRecord (sequential access), available through [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview) | performant | only supports TFRecords; <br>non-deterministic with preemption<br>(deterministic without preemption) |
 
 ## Multihost dataloading best practice
+
 Training in a multi-host environment presents unique challenges for data input pipelines. An effective data loading strategy must address three key issues:
+
 1. **Concurrent access**: Multiple hosts need to read from the same dataset simultaneously without causing conflicts.
 2. **Data uniqueness**: Each host must be fed a unique, non-overlapping subset of the data to ensure the model sees each example correctly.
-3. **Uneven completion**: Handling the scenario where some hosts run out of data before others, which can lead to hanging.
-The approaches to solve these challenges depend on whether your dataset supports random access or is limited to sequential access.
+3. **Uneven completion**: Handling the scenario where some hosts run out of data before others, which can lead to hanging.
+   The approaches to solve these challenges depend on whether your dataset supports random access or is limited to sequential access.
 
 ### Random access dataset (Recommended)
+
 Random-access formats are highly recommended for multi-host training because they allow any part of the file to be read directly by its index.<br>
 In MaxText, this is best supported by the ArrayRecord format using the Grain input pipeline. This approach gracefully handles the key challenges:
-* **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file.
 
-* **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/MaxText/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met.
+- **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file.
+
+- **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/MaxText/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met.
 
 ```{note}
 When sequence packing is enabled, the difference in the number of packed examples per host can be larger. The `generate_padding_batch_train`/`generate_padding_batch_eval` flag still solves this.
@@ -48,12 +53,14 @@ If all hosts exhaust their data before the target step count is reached, both `t
 ```
 
 ### Sequential access dataset
-* **Concurrent access and uniqueness**: Sequential-access datasets (e.g., Parquet, JSON, TFRecord) cannot be accessed by index, requiring a different strategy -- file-based sharding, where each host is given exclusive access to a specific subset of data files. **Key requirement**: `(Number of data files) % (Number of data-loading hosts) == 0`. If the file count isn't a multiple of the host count, the files will be distributed unevenly. For example, with 10 files and 8 hosts, some hosts will get two files while others get one, significantly worsening the "uneven completion" problem. If you have fewer files than hosts, performance will be severely degraded as all hosts are concurrently accessing all the files.
-* **Uneven completion**: Similar to random-access datasets, you can use the `generate_padding_batch_train`/`generate_padding_batch_eval` flag to handle hosts that finish their file shards early.
 
-```{toctree}
-:hidden:
+- **Concurrent access and uniqueness**: Sequential-access datasets (e.g., Parquet, JSON, TFRecord) cannot be accessed by index, requiring a different strategy -- file-based sharding, where each host is given exclusive access to a specific subset of data files. **Key requirement**: `(Number of data files) % (Number of data-loading hosts) == 0`. If the file count isn't a multiple of the host count, the files will be distributed unevenly. For example, with 10 files and 8 hosts, some hosts will get two files while others get one, significantly worsening the "uneven completion" problem. If you have fewer files than hosts, performance will be severely degraded as all hosts are concurrently accessing all the files.
+- **Uneven completion**: Similar to random-access datasets, you can use the `generate_padding_batch_train`/`generate_padding_batch_eval` flag to handle hosts that finish their file shards early.
 
+```{toctree}
+---
+hidden:
+---
 data_input_pipeline/data_input_grain
 data_input_pipeline/data_input_hf
 data_input_pipeline/data_input_tfds
````
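The sequential-access sharding rule in the doc text above is easy to see with a toy example. Below is a minimal, hypothetical Python sketch (not MaxText code; `shard_files` is an illustration only) of round-robin file sharding, showing why `(number of files) % (number of hosts) == 0` keeps per-host shards balanced:

```python
def shard_files(files: list[str], host_index: int, host_count: int) -> list[str]:
  """Round-robin assignment: each host gets an exclusive subset of files."""
  if len(files) % host_count != 0:
    # Uneven shards: some hosts own more files, finish later, and the
    # "uneven completion" problem described above gets worse.
    print(f"warning: {len(files)} files do not divide evenly over {host_count} hosts")
  return files[host_index::host_count]


# 8 files over 4 hosts -> every host owns exactly 2 files.
files = [f"gs://my-bucket/shard-{i:05d}.parquet" for i in range(8)]
for host in range(4):
  print(host, shard_files(files, host, 4))
```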

docs/guides/data_input_pipeline/data_input_grain.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -32,9 +32,9 @@ Grain ensures determinism in data input pipelines by saving the pipeline's state
 
 ## Using Grain
 
-1. Grain currently supports two data formats: [ArrayRecord](https://github.com/google/array_record) (random access) and [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random-access through row groups). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources.html) class.
+1. Grain currently supports three data formats: [ArrayRecord](https://github.com/google/array_record) (random access), [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random-access through row groups), and [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) (sequential access). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources.html) class.
    - **Community Resource**: The MaxText community has created a [ArrayRecord Documentation](https://array-record.readthedocs.io/). Note: we appreciate the contribution from the community, but as of now it has not been verified by the MaxText or ArrayRecord developers yet.
-2. When the dataset is hosted on a Cloud Storage bucket, Grain can read it through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path for each worker, using the script [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh). The script configures some parameters for the mount.
+2. If the dataset is hosted on a Cloud Storage bucket, a `gs://` path can be provided directly. However, for the best performance, it is recommended to read the bucket through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). This significantly improves performance for the ArrayRecord format, since metadata caching speeds up random access. The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/src/dependencies/scripts/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path for each worker, using the script [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh). The script configures some parameters for the mount.
 
 ```sh
 bash tools/setup/setup_gcsfuse.sh \
@@ -45,7 +45,7 @@ MOUNT_PATH=${MOUNT_PATH?} \
 
 Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pre-filling the metadata cache (see ["Performance tuning best practices" on the Google Cloud documentation](https://cloud.google.com/storage/docs/cloud-storage-fuse/performance#improve-first-time-reads)).
 
-1. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet}`, `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path.
+1. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet|tfrecord}`, `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path.
 
 2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling.
@@ -112,7 +112,7 @@ Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pr
 bash tools/setup/setup_gcsfuse.sh \
 DATASET_GCS_BUCKET=maxtext-dataset \
 MOUNT_PATH=/tmp/gcsfuse && \
-python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
+python3 -m maxtext.trainers.pre_train.train \
 run_name=<RUN_NAME> base_output_directory=gs://<MY_BUCKET> \
 dataset_type=grain \
 grain_file_type=arrayrecord # or parquet \
````
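The updated doc text keeps pointing at Grain's custom data source mechanism for formats it ships no reader for. A minimal sketch of that idea, assuming the protocol described in the Grain documentation (any object exposing `__len__` and `__getitem__` can back a `MapDataset`); the `InMemorySource` class and its records are hypothetical:

```python
import grain.python as grain


class InMemorySource:
  """Hypothetical random-access source serving records from a Python list."""

  def __init__(self, records):
    self._records = records

  def __len__(self) -> int:
    return len(self._records)

  def __getitem__(self, index: int):
    return self._records[index]


# Index-based access is what lets Grain shard deterministically across
# hosts and perform a global shuffle, as with ArrayRecord files.
source = InMemorySource([{"text": f"example {i}"} for i in range(100)])
dataset = grain.MapDataset.source(source).shuffle(seed=0).batch(8)
```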
