Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ omit =
[paths]
source =
src/MaxText
src/MaxText
src/maxtext
*/site-packages/MaxText
*/site-packages/maxtext

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ repos:
# args:
# - '--jobs=auto'
# - '--keep-going'
# - 'src/MaxText/'
# - 'src/maxtext/'

- repo: https://github.com/google/pyink
rev: 24.10.1
Expand Down
1 change: 0 additions & 1 deletion benchmarks/convergence/c4_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from benchmarks.benchmark_utils import MaxTextModel, _add_to_model_dictionary
from benchmarks.convergence.convergence_utils import DatasetHParams, ConvHParams, _setup_model_convergence_

from benchmarks.maxtext_v5p_model_configs import deepseek_v3_ep_256_v5p_512

c4_pretrain_model_dict = {}
Expand Down
1 change: 0 additions & 1 deletion benchmarks/disruption_management/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import time

from benchmarks.disruption_management.disruption_utils import wait_for_pod_to_start

from benchmarks.disruption_management.disruption_handler import DisruptionConfig
from benchmarks.disruption_management.disruption_handler import TriggerType

Expand Down
10 changes: 5 additions & 5 deletions benchmarks/llama2_v6e-256_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

import maxtext_trillium_model_configs as model_configs

from maxtext_xpk_runner import BenchmarkRunner
from maxtext_xpk_runner import HWConfig
from maxtext_xpk_runner import SWconfig
from maxtext_xpk_runner import xpk_benchmark_runner
from maxtext_xpk_runner import XpkConfig
from benchmarks.maxtext_xpk_runner import BenchmarkRunner
from benchmarks.maxtext_xpk_runner import HWConfig
from benchmarks.maxtext_xpk_runner import SWconfig
from benchmarks.maxtext_xpk_runner import xpk_benchmark_runner
from benchmarks.maxtext_xpk_runner import XpkConfig


DATE = "20241009"
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/maxtext_xpk_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
import omegaconf

import benchmarks.maxtext_trillium_model_configs as model_configs
import benchmarks.xla_flags_library as xla_flags
from benchmarks.globals import MAXTEXT_PKG_DIR
from benchmarks.command_utils import run_command_with_updates
import benchmarks.xla_flags_library as xla_flags
from benchmarks.disruption_management.disruption_handler import DisruptionConfig
from benchmarks.disruption_management.disruption_manager import DisruptionManager
from benchmarks.xpk_configs import XpkClusterConfig
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/recipes/mcjax_long_running_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import benchmarks.maxtext_trillium_model_configs as model_configs
import benchmarks.maxtext_xpk_runner as mxr
from benchmarks.xpk_configs import XpkClusterConfig
from . import user_configs
from benchmarks.recipes import user_configs

# Cluster Params
CLUSTER = "v6e-256-cluster"
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/recipes/pw_elastic_training_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
from . import args_helper as helper
from . import user_configs

from benchmarks.disruption_management.disruption_handler import DisruptionMethod
from .runner_utils import generate_and_run_workloads
from benchmarks.recipes import args_helper as helper
from benchmarks.recipes import user_configs
from benchmarks.recipes.runner_utils import generate_and_run_workloads

user_configs.USER_CONFIG.max_restarts = 10
COMPARE_WITH_MCJAX = True
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/recipes/pw_headless_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
"""

import benchmarks.recipes.args_helper as helper
from .. import maxtext_xpk_runner as mxr
from ..recipes.user_configs import USER_CONFIG
from benchmarks import maxtext_xpk_runner as mxr
from benchmarks.recipes.user_configs import USER_CONFIG


def main() -> int:
Expand Down
11 changes: 4 additions & 7 deletions benchmarks/recipes/pw_long_running_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,10 @@
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import recipes.args_helper as helper

import maxtext_trillium_model_configs as model_configs

import maxtext_xpk_runner as mxr

from xpk_configs import XpkClusterConfig
import benchmarks.maxtext_trillium_model_configs as model_configs
import benchmarks.maxtext_xpk_runner as mxr
import benchmarks.recipes.args_helper as helper
from benchmarks.xpk_configs import XpkClusterConfig

PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server"
SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server"
Expand Down
12 changes: 6 additions & 6 deletions benchmarks/recipes/pw_mcjax_benchmark_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
from . import args_helper as helper
from .user_configs import UserConfig
from .user_configs import USER_CONFIG
from .runner_utils import generate_and_run_workloads
from . import parser_utils
from benchmarks.recipes import args_helper as helper
from benchmarks.recipes import parser_utils
from benchmarks.recipes.pw_utils import check_and_create_bucket
from benchmarks.recipes.runner_utils import generate_and_run_workloads
from benchmarks.recipes.user_configs import UserConfig
from benchmarks.recipes.user_configs import USER_CONFIG
import argparse
from google.cloud import storage
from .pw_utils import check_and_create_bucket


def main(user_config) -> int:
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/recipes/pw_mcjax_checkpoint_benchmark_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import datetime
import dataclasses
import os
import args_helper as helper

from benchmarks import maxtext_trillium_model_configs as model_configs
import benchmarks.maxtext_xpk_runner as mxr
from benchmarks import maxtext_trillium_model_configs as model_configs
from benchmarks.recipes import args_helper as helper
from benchmarks.xpk_configs import XpkClusterConfig

PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server"
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/recipes/pw_remote_python_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import os

import args_helper as helper
import benchmarks.recipes.args_helper as helper

from benchmarks import maxtext_trillium_model_configs as model_configs
from benchmarks import maxtext_xpk_runner as mxr
Expand Down
7 changes: 3 additions & 4 deletions benchmarks/recipes/pw_suspend_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
from . import args_helper as helper
from . import user_configs

from benchmarks.disruption_management.disruption_handler import DisruptionMethod
from .runner_utils import generate_and_run_workloads
from benchmarks.recipes import args_helper as helper
from benchmarks.recipes import user_configs
from benchmarks.recipes.runner_utils import generate_and_run_workloads

user_configs.USER_CONFIG.max_restarts = 3
DISRUPTION_METHOD = DisruptionMethod.SIGTERM
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/recipes/pw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import typing

import maxtext_xpk_runner as mxr
import benchmarks.recipes.maxtext_xpk_runner as mxr
from google.api_core.exceptions import (
NotFound,
Conflict,
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/recipes/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import logging

from .. import maxtext_xpk_runner as mxr
from benchmarks import maxtext_xpk_runner as mxr
from benchmarks.benchmark_utils import Framework
from benchmarks.disruption_management.disruption_manager import construct_disruption_configs

Expand Down
8 changes: 4 additions & 4 deletions benchmarks/recipes/user_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@

parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
from .. import maxtext_trillium_model_configs as v6e_model_configs
from .. import maxtext_v5e_model_configs as v5e_model_configs
from .. import maxtext_v5p_model_configs as v5p_model_configs
from .pw_utils import build_user_models, get_cluster_config, get_pathways_config
from benchmarks import maxtext_trillium_model_configs as v6e_model_configs
from benchmarks import maxtext_v5e_model_configs as v5e_model_configs
from benchmarks import maxtext_v5p_model_configs as v5p_model_configs
from benchmarks.recipes.pw_utils import build_user_models, get_cluster_config, get_pathways_config


AVAILABLE_MODELS_FRAMEWORKS = ["mcjax", "pathways"]
Expand Down
9 changes: 3 additions & 6 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,18 @@
codecov:
token: 35742a22-fb1f-4839-97ff-b54da5588689
# By default file names in the coverage report will have their path in the file system, which in our
# runners would be /__w/maxtext/maxtext/src/MaxText/* but Codecov expects src/MaxText/* so we need to fix the path
# runners would be /__w/maxtext/maxtext/src/maxtext/* but Codecov expects src/maxtext/* so we need to fix the path
fixes:
# - ".*/maxtext/src/::src/"
- "/github/workspace/::"
ignore:
- "src/maxtext/assets"
- "src/maxtext/configs"
- "src/maxtext/examples"
- "src/MaxText/experimental"
- "src/maxtext/experimental"
- "src/maxtext/inference"
- "src/maxtext/scratch_code"
- "src/MaxText/distillation" # code moved to src/maxtext/trainers/post_train/distillation
- "src/MaxText/sft" # code moved to src/maxtext/trainers/post_train/sft
- "src/MaxText/rl" # code moved to src/maxtext/trainers/post_train/rl

- "src/MaxText"

flags:
# Updated ONLY by PRs (contains subset of tests, excluding scheduled_only).
Expand Down
29 changes: 12 additions & 17 deletions docs/guides/checkpointing_solutions/convert_checkpoint.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Checkpoint conversion utilities

This guide provides instructions for using the [scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion) that convert model checkpoints bidirectionally between Hugging Face and MaxText formats.
This guide provides instructions for using the [scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion) that convert model checkpoints bidirectionally between Hugging Face and MaxText formats.

## Supported models

Expand Down Expand Up @@ -34,13 +34,8 @@ Use the `to_maxtext.py` script to convert a Hugging Face model into a MaxText ch
### Usage

First, make sure python3 virtual environment for MaxText is set up and enabled.

```bash
export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
pip install uv
uv venv --python 3.12 --seed ${VENV_NAME?}
source ${VENV_NAME?}/bin/activate
```
For instructions on installing MaxText on your VM, please refer to the [official documentation] and use the
maxtext[tpu-post-train] installation path to include all necessary post-training dependencies.

Second, ensure you have the necessary dependencies installed (e.g., install PyTorch for checkpoint conversion and logit check).

Expand All @@ -52,7 +47,7 @@ Third, setup following environment variables for conversion script

```bash
# -- Model configuration --
export MODEL_NAME=<Hugging Face Model to be converted to MaxText> # e.g. 'llama3.1-8b-Instruct'
export MODEL=<Hugging Face Model to be converted to MaxText> # e.g. 'llama3.1-8b-Instruct'
export HF_TOKEN=<Hugging Face access token> # your token to access gated HF repos

# -- MaxText configuration --
Expand All @@ -70,7 +65,7 @@ Finally, run below command to complete the conversion
# customize your "HF_HOME" to redirect the cache to a larger or mounted disk (e.g., on a TPU VM).
# export HF_HOME="/dev/shm/huggingface_tmp"
python3 -m maxtext.checkpoint_conversion.to_maxtext \
model_name=${MODEL_NAME?} \
model_name=${MODEL?} \
hf_access_token=${HF_TOKEN?} \
base_output_directory=${MODEL_CHECKPOINT_DIRECTORY?} \
scan_layers=True \
Expand All @@ -90,7 +85,7 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
- `hardware=cpu`: run the conversion script on a CPU machine.
- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: Set to True for McJAX (default, `USE_PATHWAYS=0`); set to False for Pathways (`USE_PATHWAYS=1`). Both are controlled by the `$((1 - USE_PATHWAYS))` calculation in the example above.
- `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

Above command will download the Hugging Face model to local machine if `hf_model_path` is unspecified, or reuse the checkpoint in `hf_model_path`. It will convert the checkpoint to the MaxText format and save it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`.

Expand All @@ -105,7 +100,7 @@ The following command converts a MaxText checkpoint and saves it locally, to GCS

```bash
python3 -m maxtext.checkpoint_conversion.to_huggingface \
model_name=<MODEL_NAME> \
model_name=<MODEL> \
load_parameters_path=<path-to-maxtext-checkpoint> \
base_output_directory=<path-to-save-converted-checkpoint> \
scan_layers=false \
Expand Down Expand Up @@ -134,7 +129,7 @@ To ensure the conversion was successful, you can use the [`tests/utils/forward_p
python3 -m tests.utils.forward_pass_logit_checker src/maxtext/configs/base.yml \
tokenizer_path=<tokenizer> \
load_parameters_path=<path-to-maxtext-checkpoint> \
model_name=<MODEL_NAME> \
model_name=<MODEL> \
scan_layers=false \
max_prefill_predict_length=4 \
max_target_length=8 \
Expand Down Expand Up @@ -213,12 +208,12 @@ To extend conversion support to a new model architecture, you must define its sp

1. **Add parameter mappings**:

- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the parameter name mappings(`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the 1-to-1 mappings of parameters names per layer.
- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.
- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the parameter name mappings(`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the 1-to-1 mappings of parameters names per layer.
- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.

2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion.
2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion.
3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`.
4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.
4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.

Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983)

Expand Down
2 changes: 1 addition & 1 deletion docs/guides/data_input_pipeline.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ In MaxText, this is best supported by the ArrayRecord format using the Grain inp

- **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file.

- **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/MaxText/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met.
- **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/maxtext/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met.

```{note}
When sequence packing is enabled, the difference in the number of packed examples per host can be larger. The `generate_padding_batch_train`/`generate_padding_batch_eval` flag still solves this.
Expand Down
Loading
Loading