From 7a97864ce07dbfedc3df4b652290694fb06e1188 Mon Sep 17 00:00:00 2001 From: Surbhi Jain Date: Tue, 24 Mar 2026 13:31:52 -0700 Subject: [PATCH] Update checkpoint conversion documentation PiperOrigin-RevId: 888829289 --- .coveragerc | 2 +- .pre-commit-config.yaml | 2 +- benchmarks/convergence/c4_exp.py | 1 - benchmarks/disruption_management/monitor.py | 1 - benchmarks/llama2_v6e-256_benchmarks.py | 10 +- benchmarks/maxtext_xpk_runner.py | 2 +- .../recipes/mcjax_long_running_recipe.py | 2 +- .../recipes/pw_elastic_training_recipe.py | 6 +- benchmarks/recipes/pw_headless_mode.py | 4 +- benchmarks/recipes/pw_long_running_recipe.py | 11 +- .../recipes/pw_mcjax_benchmark_recipe.py | 12 +- .../pw_mcjax_checkpoint_benchmark_recipe.py | 4 +- benchmarks/recipes/pw_remote_python_recipe.py | 2 +- benchmarks/recipes/pw_suspend_resume.py | 7 +- benchmarks/recipes/pw_utils.py | 2 +- benchmarks/recipes/runner_utils.py | 2 +- benchmarks/recipes/user_configs.py | 8 +- codecov.yml | 9 +- .../convert_checkpoint.md | 29 +- docs/guides/data_input_pipeline.md | 2 +- docs/guides/model_bringup.md | 8 +- .../optimization/benchmark_and_performance.md | 2 +- .../pallas_kernels_performance.md | 6 +- .../architecture/architecture_overview.md | 2 +- .../core_concepts/moe_configuration.md | 4 +- .../supported_models_and_architectures.md | 12 +- docs/tutorials/first_run.md | 2 +- .../tutorials/posttraining/full_finetuning.md | 38 +- docs/tutorials/posttraining/multimodal.md | 2 +- docs/tutorials/posttraining/rl.md | 2 +- .../posttraining/rl_on_multi_host.md | 2 +- docs/tutorials/posttraining/sft.md | 2 +- .../posttraining/sft_on_multi_host.md | 2 +- src/MaxText/README.md | 4 +- .../llama_mistral_mixtral_orbax_to_hf.py | 2 +- .../checkpoint_conversion/to_huggingface.py | 2 +- .../checkpoint_conversion/to_maxtext.py | 4 +- src/maxtext/configs/base.yml | 2 +- .../configs/tpu/v5p/gpt3_175b/v5p_12288.sh | 2 +- .../examples/sft_llama3_demo_gpu.ipynb | 1200 ++++++++--------- src/maxtext/examples/sft_qwen3_demo.ipynb | 1172 ++++++++-------- .../agent/ckpt_conversion_agent/README.md | 20 +- .../baselines/one-shot-agent.ipynb | 898 ++++++------ .../distillation/distillation_utils.py | 2 +- .../distillation/save_top_k_teacher_logits.py | 2 +- .../trainers/post_train/rl/train_rl.py | 4 +- tests/end_to_end/gpu/a3/test_llama2_7b.sh | 4 +- tests/end_to_end/tpu/deepseek/Run_DeepSeek.md | 4 +- .../tpu/deepseek/v2-16b/test_deepseek.sh | 4 +- .../tpu/deepseek/v3-671b/2_test_deepseek.sh | 4 +- tests/end_to_end/tpu/gemma/Run_Gemma.md | 2 +- tests/end_to_end/tpu/gemma3/Run_Gemma3.md | 2 +- tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md | 6 +- tests/end_to_end/tpu/mixtral/Run_Mixtral.md | 2 +- .../2_test_qwen3_next_80b_a3b.sh | 4 +- tests/end_to_end/tpu/test_grpo.sh | 4 +- tests/inference/test_llama2_7b_bf16.sh | 4 +- tests/inference/test_llama2_7b_int8.sh | 4 +- tests/utils/forward_pass_logit_checker.py | 2 +- tools/dev/code_style.sh | 2 +- 60 files changed, 1773 insertions(+), 1791 deletions(-) diff --git a/.coveragerc b/.coveragerc index aceea9571f..7e4acd5b39 100644 --- a/.coveragerc +++ b/.coveragerc @@ -10,7 +10,7 @@ omit = [paths] source = src/MaxText - src/MaxText + src/maxtext */site-packages/MaxText */site-packages/maxtext diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 39313ef66c..1b601dbe72 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,7 +42,7 @@ repos: # args: # - '--jobs=auto' # - '--keep-going' - # - 'src/MaxText/' + # - 'src/maxtext/' - repo: https://github.com/google/pyink rev: 24.10.1 diff --git a/benchmarks/convergence/c4_exp.py b/benchmarks/convergence/c4_exp.py index c349a935e3..daa2241893 100644 --- a/benchmarks/convergence/c4_exp.py +++ b/benchmarks/convergence/c4_exp.py @@ -23,7 +23,6 @@ from benchmarks.benchmark_utils import MaxTextModel, _add_to_model_dictionary from benchmarks.convergence.convergence_utils import DatasetHParams, ConvHParams, _setup_model_convergence_ - from benchmarks.maxtext_v5p_model_configs import deepseek_v3_ep_256_v5p_512 c4_pretrain_model_dict = {} diff --git a/benchmarks/disruption_management/monitor.py b/benchmarks/disruption_management/monitor.py index 31960a7d19..5952f5f935 100644 --- a/benchmarks/disruption_management/monitor.py +++ b/benchmarks/disruption_management/monitor.py @@ -29,7 +29,6 @@ import time from benchmarks.disruption_management.disruption_utils import wait_for_pod_to_start - from benchmarks.disruption_management.disruption_handler import DisruptionConfig from benchmarks.disruption_management.disruption_handler import TriggerType diff --git a/benchmarks/llama2_v6e-256_benchmarks.py b/benchmarks/llama2_v6e-256_benchmarks.py index 2b56e0e3e6..e7bceea2b1 100644 --- a/benchmarks/llama2_v6e-256_benchmarks.py +++ b/benchmarks/llama2_v6e-256_benchmarks.py @@ -19,11 +19,11 @@ import maxtext_trillium_model_configs as model_configs -from maxtext_xpk_runner import BenchmarkRunner -from maxtext_xpk_runner import HWConfig -from maxtext_xpk_runner import SWconfig -from maxtext_xpk_runner import xpk_benchmark_runner -from maxtext_xpk_runner import XpkConfig +from benchmarks.maxtext_xpk_runner import BenchmarkRunner +from benchmarks.maxtext_xpk_runner import HWConfig +from benchmarks.maxtext_xpk_runner import SWconfig +from benchmarks.maxtext_xpk_runner import xpk_benchmark_runner +from benchmarks.maxtext_xpk_runner import XpkConfig DATE = "20241009" diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 2a7ebc8b66..f0510533c7 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -35,9 +35,9 @@ import omegaconf import benchmarks.maxtext_trillium_model_configs as model_configs +import benchmarks.xla_flags_library as xla_flags from benchmarks.globals import MAXTEXT_PKG_DIR from benchmarks.command_utils import run_command_with_updates -import benchmarks.xla_flags_library as xla_flags from benchmarks.disruption_management.disruption_handler import DisruptionConfig from benchmarks.disruption_management.disruption_manager import DisruptionManager from benchmarks.xpk_configs import XpkClusterConfig diff --git a/benchmarks/recipes/mcjax_long_running_recipe.py b/benchmarks/recipes/mcjax_long_running_recipe.py index 17222723d0..9fcd24b5bc 100644 --- a/benchmarks/recipes/mcjax_long_running_recipe.py +++ b/benchmarks/recipes/mcjax_long_running_recipe.py @@ -27,7 +27,7 @@ import benchmarks.maxtext_trillium_model_configs as model_configs import benchmarks.maxtext_xpk_runner as mxr from benchmarks.xpk_configs import XpkClusterConfig -from . import user_configs +from benchmarks.recipes import user_configs # Cluster Params CLUSTER = "v6e-256-cluster" diff --git a/benchmarks/recipes/pw_elastic_training_recipe.py b/benchmarks/recipes/pw_elastic_training_recipe.py index 7ab1f3b7e6..3a3f68aba0 100644 --- a/benchmarks/recipes/pw_elastic_training_recipe.py +++ b/benchmarks/recipes/pw_elastic_training_recipe.py @@ -25,11 +25,11 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(parent_dir) -from . import args_helper as helper -from . import user_configs from benchmarks.disruption_management.disruption_handler import DisruptionMethod -from .runner_utils import generate_and_run_workloads +from benchmarks.recipes import args_helper as helper +from benchmarks.recipes import user_configs +from benchmarks.recipes.runner_utils import generate_and_run_workloads user_configs.USER_CONFIG.max_restarts = 10 COMPARE_WITH_MCJAX = True diff --git a/benchmarks/recipes/pw_headless_mode.py b/benchmarks/recipes/pw_headless_mode.py index eaac22782b..94064bf1fe 100644 --- a/benchmarks/recipes/pw_headless_mode.py +++ b/benchmarks/recipes/pw_headless_mode.py @@ -22,8 +22,8 @@ """ import benchmarks.recipes.args_helper as helper -from .. import maxtext_xpk_runner as mxr -from ..recipes.user_configs import USER_CONFIG +from benchmarks import maxtext_xpk_runner as mxr +from benchmarks.recipes.user_configs import USER_CONFIG def main() -> int: diff --git a/benchmarks/recipes/pw_long_running_recipe.py b/benchmarks/recipes/pw_long_running_recipe.py index f1f2b52c81..3e1c14d32c 100644 --- a/benchmarks/recipes/pw_long_running_recipe.py +++ b/benchmarks/recipes/pw_long_running_recipe.py @@ -27,13 +27,10 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(parent_dir) -import recipes.args_helper as helper - -import maxtext_trillium_model_configs as model_configs - -import maxtext_xpk_runner as mxr - -from xpk_configs import XpkClusterConfig +import benchmarks.maxtext_trillium_model_configs as model_configs +import benchmarks.maxtext_xpk_runner as mxr +import benchmarks.recipes.args_helper as helper +from benchmarks.xpk_configs import XpkClusterConfig PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server" SERVER_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server" diff --git a/benchmarks/recipes/pw_mcjax_benchmark_recipe.py b/benchmarks/recipes/pw_mcjax_benchmark_recipe.py index 575150bbaa..7deeb3cabf 100644 --- a/benchmarks/recipes/pw_mcjax_benchmark_recipe.py +++ b/benchmarks/recipes/pw_mcjax_benchmark_recipe.py @@ -18,14 +18,14 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(parent_dir) -from . import args_helper as helper -from .user_configs import UserConfig -from .user_configs import USER_CONFIG -from .runner_utils import generate_and_run_workloads -from . import parser_utils +from benchmarks.recipes import args_helper as helper +from benchmarks.recipes import parser_utils +from benchmarks.recipes.pw_utils import check_and_create_bucket +from benchmarks.recipes.runner_utils import generate_and_run_workloads +from benchmarks.recipes.user_configs import UserConfig +from benchmarks.recipes.user_configs import USER_CONFIG import argparse from google.cloud import storage -from .pw_utils import check_and_create_bucket def main(user_config) -> int: diff --git a/benchmarks/recipes/pw_mcjax_checkpoint_benchmark_recipe.py b/benchmarks/recipes/pw_mcjax_checkpoint_benchmark_recipe.py index 01ef72df63..e40deffcc7 100644 --- a/benchmarks/recipes/pw_mcjax_checkpoint_benchmark_recipe.py +++ b/benchmarks/recipes/pw_mcjax_checkpoint_benchmark_recipe.py @@ -22,10 +22,10 @@ import datetime import dataclasses import os -import args_helper as helper -from benchmarks import maxtext_trillium_model_configs as model_configs import benchmarks.maxtext_xpk_runner as mxr +from benchmarks import maxtext_trillium_model_configs as model_configs +from benchmarks.recipes import args_helper as helper from benchmarks.xpk_configs import XpkClusterConfig PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server" diff --git a/benchmarks/recipes/pw_remote_python_recipe.py b/benchmarks/recipes/pw_remote_python_recipe.py index 62653bb3ab..374fa2863e 100644 --- a/benchmarks/recipes/pw_remote_python_recipe.py +++ b/benchmarks/recipes/pw_remote_python_recipe.py @@ -21,7 +21,7 @@ import os -import args_helper as helper +import benchmarks.recipes.args_helper as helper from benchmarks import maxtext_trillium_model_configs as model_configs from benchmarks import maxtext_xpk_runner as mxr diff --git a/benchmarks/recipes/pw_suspend_resume.py b/benchmarks/recipes/pw_suspend_resume.py index addc28c1e7..490d25c0d8 100644 --- a/benchmarks/recipes/pw_suspend_resume.py +++ b/benchmarks/recipes/pw_suspend_resume.py @@ -25,11 +25,10 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(parent_dir) -from . import args_helper as helper -from . import user_configs - from benchmarks.disruption_management.disruption_handler import DisruptionMethod -from .runner_utils import generate_and_run_workloads +from benchmarks.recipes import args_helper as helper +from benchmarks.recipes import user_configs +from benchmarks.recipes.runner_utils import generate_and_run_workloads user_configs.USER_CONFIG.max_restarts = 3 DISRUPTION_METHOD = DisruptionMethod.SIGTERM diff --git a/benchmarks/recipes/pw_utils.py b/benchmarks/recipes/pw_utils.py index 4ec4bca9c3..03a070003e 100644 --- a/benchmarks/recipes/pw_utils.py +++ b/benchmarks/recipes/pw_utils.py @@ -20,7 +20,7 @@ import typing -import maxtext_xpk_runner as mxr +import benchmarks.recipes.maxtext_xpk_runner as mxr from google.api_core.exceptions import ( NotFound, Conflict, diff --git a/benchmarks/recipes/runner_utils.py b/benchmarks/recipes/runner_utils.py index 43626b59f8..f1f45feda2 100644 --- a/benchmarks/recipes/runner_utils.py +++ b/benchmarks/recipes/runner_utils.py @@ -16,7 +16,7 @@ import logging -from .. import maxtext_xpk_runner as mxr +from benchmarks import maxtext_xpk_runner as mxr from benchmarks.benchmark_utils import Framework from benchmarks.disruption_management.disruption_manager import construct_disruption_configs diff --git a/benchmarks/recipes/user_configs.py b/benchmarks/recipes/user_configs.py index 5d283c8b47..0e2ad05e1e 100644 --- a/benchmarks/recipes/user_configs.py +++ b/benchmarks/recipes/user_configs.py @@ -27,10 +27,10 @@ parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(parent_dir) -from .. import maxtext_trillium_model_configs as v6e_model_configs -from .. import maxtext_v5e_model_configs as v5e_model_configs -from .. import maxtext_v5p_model_configs as v5p_model_configs -from .pw_utils import build_user_models, get_cluster_config, get_pathways_config +from benchmarks import maxtext_trillium_model_configs as v6e_model_configs +from benchmarks import maxtext_v5e_model_configs as v5e_model_configs +from benchmarks import maxtext_v5p_model_configs as v5p_model_configs +from benchmarks.recipes.pw_utils import build_user_models, get_cluster_config, get_pathways_config AVAILABLE_MODELS_FRAMEWORKS = ["mcjax", "pathways"] diff --git a/codecov.yml b/codecov.yml index f5971c2a21..302d8bc243 100644 --- a/codecov.yml +++ b/codecov.yml @@ -27,7 +27,7 @@ codecov: token: 35742a22-fb1f-4839-97ff-b54da5588689 # By default file names in the coverage report will have their path in the file system, which in our -# runners would be /__w/maxtext/maxtext/src/MaxText/* but Codecov expects src/MaxText/* so we need to fix the path +# runners would be /__w/maxtext/maxtext/src/maxtext/* but Codecov expects src/maxtext/* so we need to fix the path fixes: # - ".*/maxtext/src/::src/" - "/github/workspace/::" @@ -35,13 +35,10 @@ ignore: - "src/maxtext/assets" - "src/maxtext/configs" - "src/maxtext/examples" - - "src/MaxText/experimental" + - "src/maxtext/experimental" - "src/maxtext/inference" - "src/maxtext/scratch_code" - - "src/MaxText/distillation" # code moved to src/maxtext/trainers/post_train/distillation - - "src/MaxText/sft" # code moved to src/maxtext/trainers/post_train/sft - - "src/MaxText/rl" # code moved to src/maxtext/trainers/post_train/rl - + - "src/MaxText" flags: # Updated ONLY by PRs (contains subset of tests, excluding scheduled_only). diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index 30e9750bdf..f1f6fbb4c6 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -1,6 +1,6 @@ # Checkpoint conversion utilities -This guide provides instructions for using the [scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion) that convert model checkpoints bidirectionally between Hugging Face and MaxText formats. +This guide provides instructions for using the [scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion) that convert model checkpoints bidirectionally between Hugging Face and MaxText formats. ## Supported models @@ -34,13 +34,8 @@ Use the `to_maxtext.py` script to convert a Hugging Face model into a MaxText ch ### Usage First, make sure python3 virtual environment for MaxText is set up and enabled. - -```bash -export VENV_NAME= # e.g., maxtext_venv -pip install uv -uv venv --python 3.12 --seed ${VENV_NAME?} -source ${VENV_NAME?}/bin/activate -``` +For instructions on installing MaxText on your VM, please refer to the [official documentation] and use the +maxtext[tpu-post-train] installation path to include all necessary post-training dependencies. Second, ensure you have the necessary dependencies installed (e.g., install PyTorch for checkpoint conversion and logit check). @@ -52,7 +47,7 @@ Third, setup following environment variables for conversion script ```bash # -- Model configuration -- -export MODEL_NAME= # e.g. 'llama3.1-8b-Instruct' +export MODEL= # e.g. 'llama3.1-8b-Instruct' export HF_TOKEN= # your token to access gated HF repos # -- MaxText configuration -- @@ -70,7 +65,7 @@ Finally, run below command to complete the conversion # customize your "HF_HOME" to redirect the cache to a larger or mounted disk (e.g., on a TPU VM). # export HF_HOME="/dev/shm/huggingface_tmp" python3 -m maxtext.checkpoint_conversion.to_maxtext \ - model_name=${MODEL_NAME?} \ + model_name=${MODEL?} \ hf_access_token=${HF_TOKEN?} \ base_output_directory=${MODEL_CHECKPOINT_DIRECTORY?} \ scan_layers=True \ @@ -90,7 +85,7 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \ - `hardware=cpu`: run the conversion script on a CPU machine. - `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: Set to True for McJAX (default, `USE_PATHWAYS=0`); set to False for Pathways (`USE_PATHWAYS=1`). Both are controlled by the `$((1 - USE_PATHWAYS))` calculation in the example above. - `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes. -- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. +- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. Above command will download the Hugging Face model to local machine if `hf_model_path` is unspecified, or reuse the checkpoint in `hf_model_path`. It will convert the checkpoint to the MaxText format and save it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`. @@ -105,7 +100,7 @@ The following command converts a MaxText checkpoint and saves it locally, to GCS ```bash python3 -m maxtext.checkpoint_conversion.to_huggingface \ - model_name= \ + model_name= \ load_parameters_path= \ base_output_directory= \ scan_layers=false \ @@ -134,7 +129,7 @@ To ensure the conversion was successful, you can use the [`tests/utils/forward_p python3 -m tests.utils.forward_pass_logit_checker src/maxtext/configs/base.yml \ tokenizer_path= \ load_parameters_path= \ - model_name= \ + model_name= \ scan_layers=false \ max_prefill_predict_length=4 \ max_target_length=8 \ @@ -213,12 +208,12 @@ To extend conversion support to a new model architecture, you must define its sp 1. **Add parameter mappings**: -- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the parameter name mappings(`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the 1-to-1 mappings of parameters names per layer. -- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer. +- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the parameter name mappings(`def {MODEL}_MAXTEXT_TO_HF_PARAM_MAPPING`). This is the 1-to-1 mappings of parameters names per layer. +- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer. -2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion. +2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion. 3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`. -4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture. +4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture. Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983) diff --git a/docs/guides/data_input_pipeline.md b/docs/guides/data_input_pipeline.md index 8772db513e..0b65bdfa6c 100644 --- a/docs/guides/data_input_pipeline.md +++ b/docs/guides/data_input_pipeline.md @@ -42,7 +42,7 @@ In MaxText, this is best supported by the ArrayRecord format using the Grain inp - **Concurrent access and uniqueness**: Grain assigns a unique set of indices to each host. ArrayRecord allows different hosts to read from different indices in the same file. -- **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/MaxText/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met. +- **Uneven completion**: Data indices are distributed evenly among hosts. Without packing, the data imbalance between hosts will be at most one batch. To handle the final steps where some hosts run out of data, you can enable the `generate_padding_batch_train`/`generate_padding_batch_eval` flag in `src/maxtext/config/base.yml` or through command line arguments. This directs hosts to generate empty "padding" batches until the training or evaluation steps are met. ```{note} When sequence packing is enabled, the difference in the number of packed examples per host can be larger. The `generate_padding_batch_train`/`generate_padding_batch_eval` flag still solves this. diff --git a/docs/guides/model_bringup.md b/docs/guides/model_bringup.md index 1ba7e7181a..cc590202d3 100644 --- a/docs/guides/model_bringup.md +++ b/docs/guides/model_bringup.md @@ -26,13 +26,13 @@ The first phase involves determining how the new model's architecture aligns wit **Tokenizer**: Supported [tokenizer options](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/input_pipeline/tokenizer.py) include `TikTokenTokenizer`, `SentencePieceTokenizer`, and `HFTokenizer`. -**Self-Attention & RoPE**: Available mechanisms include optimized [Flash Attention](https://github.com/AI-Hypercomputer/maxtext/blob/62ee818144eb037ad3fe85ab8e789cd074776f46/src/MaxText/layers/attention_op.py#L1184) (supporting MHA, GQA, and MQA), Multi-head Latent Attention ([MLA](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/attention_mla.py)), and [Gated Delta Network](https://github.com/AI-Hypercomputer/maxtext/blob/62ee818144eb037ad3fe85ab8e789cd074776f46/src/MaxText/models/qwen3.py#L358). MaxText also supports [Regular](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L108), [Llama](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L178), and [YaRN](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L282) variations of Rotary Positional Embeddings (RoPE). +**Self-Attention & RoPE**: Available mechanisms include optimized [Flash Attention](https://github.com/AI-Hypercomputer/maxtext/blob/62ee818144eb037ad3fe85ab8e789cd074776f46/src/MaxText/layers/attention_op.py#L1184) (supporting MHA, GQA, and MQA), Multi-head Latent Attention ([MLA](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/attention_mla.py)), and [Gated Delta Network](https://github.com/AI-Hypercomputer/maxtext/blob/62ee818144eb037ad3fe85ab8e789cd074776f46/src/MaxText/models/qwen3.py#L358). MaxText also supports [Regular](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L108), [Llama](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L178), and [YaRN](https://github.com/AI-Hypercomputer/maxtext/blob/88d2ffd34c0ace76f836c7ea9c2fe4cd2d271088/MaxText/layers/embeddings.py#L282) variations of Rotary Positional Embeddings (RoPE). **Multi-Layer Perceptron (MLP)**: The framework supports both traditional dense models and Mixture of Experts (MoE) architectures, including [configurations](https://maxtext.readthedocs.io/en/latest/reference/core_concepts/moe_configuration.html) for routed and shared experts. -**Normalization**: We support different [normalization strategies](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/normalizations.py), including RMSNorm and Gated RMSNorm. These can be configured before or after attention/MLP layers. +**Normalization**: We support different [normalization strategies](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/normalizations.py), including RMSNorm and Gated RMSNorm. These can be configured before or after attention/MLP layers. -**Decoder Layers**: Models can have multiple [decoder layers](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/models) with varying structures. The trend has evolved from entirely dense layers to purely MoE layers, and now towards a mix of both. +**Decoder Layers**: Models can have multiple [decoder layers](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/models) with varying structures. The trend has evolved from entirely dense layers to purely MoE layers, and now towards a mix of both. ## 2. (Optional) Feature Implementation @@ -58,7 +58,7 @@ Success starts with a clear map. You must align the parameter names from your so ### 3.2 Write Script -Use existing model scripts within the repository as templates to tailor the conversion logic for your specific architecture. We strongly recommended to use the [checkpoint conversion utility](https://maxtext.readthedocs.io/en/latest/guides/checkpointing_solutions/convert_checkpoint.html) rather than [standalone scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion/standalone_scripts). +Use existing model scripts within the repository as templates to tailor the conversion logic for your specific architecture. We strongly recommended to use the [checkpoint conversion utility](https://maxtext.readthedocs.io/en/latest/guides/checkpointing_solutions/convert_checkpoint.html) rather than [standalone scripts](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion/standalone_scripts). ### 3.3 Verify Compatibility diff --git a/docs/guides/optimization/benchmark_and_performance.md b/docs/guides/optimization/benchmark_and_performance.md index f0d1b15433..03d26a8595 100644 --- a/docs/guides/optimization/benchmark_and_performance.md +++ b/docs/guides/optimization/benchmark_and_performance.md @@ -51,7 +51,7 @@ Remat policies can be chosen from: `minimal_with_context`, `minimal`, `save_dot_ These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest) -`minimal_with_context` consumes the most HBM memory, while `full` signifies minimal checkpointing, with everything being rematerialized. [More explanation and latest support](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/decoders.py#L287) +`minimal_with_context` consumes the most HBM memory, while `full` signifies minimal checkpointing, with everything being rematerialized. [More explanation and latest support](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/decoders.py#L287) **Custom policy** diff --git a/docs/guides/optimization/pallas_kernels_performance.md b/docs/guides/optimization/pallas_kernels_performance.md index b4884f6b17..29ad2f2659 100644 --- a/docs/guides/optimization/pallas_kernels_performance.md +++ b/docs/guides/optimization/pallas_kernels_performance.md @@ -58,7 +58,7 @@ To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth - **Training Attention (Flash/Splash-style):** This kernel is the default for training Transformer models in MaxText, such as DeepSeek, Gemma and Llama. It avoids creating the large [L,L] attention matrix to save memory, processing data in smaller, tiled chunks with online softmax accumulation. - - [`src/MaxText/kernels/attention/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/attention/splash_attention_kernel.py) + - [`src/maxtext/kernels/attention/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/kernels/attention/splash_attention_kernel.py) - **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine. @@ -69,9 +69,9 @@ To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth > This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts. - - [`src/MaxText/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py) + - [`src/maxtext/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/kernels/megablox/gmm.py) - **Note:** Megablox accelerates the grouped **matmul**; **routing/gating** is separate code ([`src/MaxText/layers/moe.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/moe.py)). + **Note:** Megablox accelerates the grouped **matmul**; **routing/gating** is separate code ([`src/maxtext/layers/moe.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/moe.py)). ## 🔧 The Pallas optimization workflow: code → profile → tune → repeat diff --git a/docs/reference/architecture/architecture_overview.md b/docs/reference/architecture/architecture_overview.md index 4d3f16f5a9..7035ffda9c 100644 --- a/docs/reference/architecture/architecture_overview.md +++ b/docs/reference/architecture/architecture_overview.md @@ -161,7 +161,7 @@ Performance can be further tuned by setting specific XLA flags in the configurat One of the most significant performance levers available in MaxText is the integration of Google's Accurate Quantized Training (AQT) and Qwix libraries. These enable training with reduced numerical precision, reducing memory requirements and often increasing FLOPS, while maintaining model quality and convergence characteristics that are very close to the full-precision baseline. Integration into MaxText is seamless for the user. Quantization can be enabled by simply setting, for example, `quantization: 'int8'` in the configuration file. This flag activates quantization-aware layers (defined in -[`src/MaxText/layers/quantizations.py`](https://github.com/AI-Hypercomputer/maxtext/blob/db7b85be153e6b7ca387a8d02c991f9d35bae6bd/src/MaxText/layers/quantizations.py)) that are applied to the relevant dense layers within the model's Flax definition. The quantization library handles the complexities of simulating quantization during the forward and backward passes, allowing the model to learn weights that are robust to the reduced precision. +[`src/maxtext/layers/quantizations.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/layers/quantizations.py)) that are applied to the relevant dense layers within the model's Flax definition. The quantization library handles the complexities of simulating quantization during the forward and backward passes, allowing the model to learn weights that are robust to the reduced precision. ## The ecosystem: interoperability and advanced features diff --git a/docs/reference/core_concepts/moe_configuration.md b/docs/reference/core_concepts/moe_configuration.md index c5e42b7153..7ce7d63110 100644 --- a/docs/reference/core_concepts/moe_configuration.md +++ b/docs/reference/core_concepts/moe_configuration.md @@ -16,7 +16,7 @@ # Mixture of Experts (MoE) Configuration -This document provides a detailed explanation of the configuration parameters related to Mixture of Experts (MoE) models in MaxText. These settings control the model architecture, routing mechanisms, and performance optimizations. Default values and parameter definitions are located in `src/maxtext/configs/base.yml` and are primarily used in `src/MaxText/layers/moe.py`. +This document provides a detailed explanation of the configuration parameters related to Mixture of Experts (MoE) models in MaxText. These settings control the model architecture, routing mechanisms, and performance optimizations. Default values and parameter definitions are located in `src/maxtext/configs/base.yml` and are primarily used in `src/maxtext/layers/moe.py`. ## 1. Architecture @@ -30,7 +30,7 @@ MaxText supports both Dropless and Dropping strategies. Please refer to the deci Dropless: - [Tokamax Ragged Dot](https://github.com/openxla/tokamax/tree/main/tokamax/_src/ops/ragged_dot): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=True`. -- [Megablox](https://github.com/google/maxtext/tree/main/src/MaxText/kernels/megablox): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=True`. +- [Megablox](https://github.com/google/maxtext/tree/main/src/maxtext/kernels/megablox): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=True`. - [JAX Ragged Dot](https://docs.jax.dev/en/latest/_autosummary/jax.lax.ragged_dot.html): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=False`. - Dense Matmul: Enabled by setting `sparse_matmul=False, capacity_factor=-1`. diff --git a/docs/reference/models/supported_models_and_architectures.md b/docs/reference/models/supported_models_and_architectures.md index 5001b26f56..f329c9ba2b 100644 --- a/docs/reference/models/supported_models_and_architectures.md +++ b/docs/reference/models/supported_models_and_architectures.md @@ -80,12 +80,12 @@ The following summarizes observed runtime efficiency and scaling behaviors of Ma - **Model Implementation Guides & Source Code:** - - **Llama**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/llama2/run_llama2.md) | [Llama2 and Llama3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/llama2.py) | [Llama4 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/llama4.py) - - **Gemma**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma/Run_Gemma.md) | [Gemma Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/gemma.py) | [Gemma2 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/gemma2.py) | [Gemma3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/gemma3.py) - - **Mixtral**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/mixtral/Run_Mixtral.md) | [Mixtral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/mixtral.py) | [Mistral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/mistral.py) - - **DeepSeek**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md) | [DeepSeek Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/deepseek.py) - - **Qwen3**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md) | [Qwen3-Next Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) | [Qwen3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/qwen3.py) | [Qwen3-Next Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/qwen3.py) - - **GPT-OSS**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md) | [GPT-OSS Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/models/gpt_oss.py) + - **Llama**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/llama2/run_llama2.md) | [Llama2 and Llama3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/llama2.py) | [Llama4 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/llama4.py) + - **Gemma**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma/Run_Gemma.md) | [Gemma Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma.py) | [Gemma2 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma2.py) | [Gemma3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma3.py) + - **Mixtral**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/mixtral/Run_Mixtral.md) | [Mixtral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/mixtral.py) | [Mistral Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/mistral.py) + - **DeepSeek**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md) | [DeepSeek Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/deepseek.py) + - **Qwen3**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md) | [Qwen3-Next Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) | [Qwen3 Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/qwen3.py) | [Qwen3-Next Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/qwen3.py) + - **GPT-OSS**: [Guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md) | [GPT-OSS Source](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gpt_oss.py) - **Technical Explanations:** diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md index 3b7468129b..58b71f80ac 100644 --- a/docs/tutorials/first_run.md +++ b/docs/tutorials/first_run.md @@ -66,7 +66,7 @@ In the same TPU VM where you just installed all the dependencies of MaxText, You #### Decoding in MaxText via notebook -You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. +You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. ### Run MaxText on NVIDIA GPUs diff --git a/docs/tutorials/posttraining/full_finetuning.md b/docs/tutorials/posttraining/full_finetuning.md index 45f6e9eb5b..8ebf2f525b 100644 --- a/docs/tutorials/posttraining/full_finetuning.md +++ b/docs/tutorials/posttraining/full_finetuning.md @@ -24,29 +24,19 @@ In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get st ## Install dependencies -```sh -# 1. Clone the repository -git clone https://github.com/AI-Hypercomputer/maxtext.git -cd maxtext - -# 2. Create virtual environment -export VENV_NAME= # e.g., maxtext_venv -pip install uv -uv venv --python 3.12 --seed ${VENV_NAME?} -source ${VENV_NAME?}/bin/activate - -# 3. Install dependencies in editable mode -uv pip install -e .[tpu] --resolution=lowest -install_maxtext_github_deps -``` +For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/install_maxtext.html) and use the `maxtext[tpu]` installation path to include all necessary dependencies. ## Setup environment variables +Follow the instructions [here](https://huggingface.co/docs/huggingface_hub/v0.21.2/guides/cli) to login to Hugging Face using your access token using + +```bash +huggingface-cli login +``` + ```sh # -- Model configuration -- -export MODEL_NAME= # e.g., 'llama3.1-8b' -export MODEL_TOKENIZER= # e.g., 'meta-llama/Llama-3.1-8B-Instruct' -export HF_TOKEN= +export MODEL= # e.g., 'llama3.1-8b-Instruct' # -- MaxText configuration -- export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory @@ -62,15 +52,15 @@ This section explains how to prepare your model checkpoint for use with MaxText. If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section. ```sh -export MODEL_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items +export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items ``` ### Option 2: Converting a Hugging Face checkpoint -Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. +Refer the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. ```bash -export MODEL_CKPT_PATH= # gs://my-bucket/my-checkpoint-directory/0/items +export MAXTEXT_CKPT_PATH= # gs://my-bucket/my-checkpoint-directory/0/items ``` ## Dataset @@ -103,12 +93,10 @@ Below is a sample training script. python3 -m maxtext.trainers.pre_train.train \ run_name=${RUN_NAME?} \ base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ - load_parameters_path=${MODEL_CKPT_PATH?} \ - model_name=${MODEL_NAME?} \ + load_parameters_path=${MAXTEXT_CKPT_PATH?} \ + model_name=${MODEL?} \ dataset_path=${DATASET_GCS_BUCKET?} \ async_checkpointing=False \ - tokenizer_path=${MODEL_TOKENIZER?} \ - hf_access_token=${HF_TOKEN?} \ steps=10 per_device_batch_size=1 ``` diff --git a/docs/tutorials/posttraining/multimodal.md b/docs/tutorials/posttraining/multimodal.md index 65bbc1a78d..f19f43b6fd 100644 --- a/docs/tutorials/posttraining/multimodal.md +++ b/docs/tutorials/posttraining/multimodal.md @@ -25,7 +25,7 @@ Multimodal Large Language Models (LLMs) extend traditional text-only models by i ## Checkpoint Conversion -Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace ([README](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md)). +Recently we have onboarded a new centralized tool for bidirectional checkpoint conversion between MaxText and HuggingFace ([README](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/README.md)). Install pytorch: diff --git a/docs/tutorials/posttraining/rl.md b/docs/tutorials/posttraining/rl.md index d8cfc17a44..ec486684e2 100644 --- a/docs/tutorials/posttraining/rl.md +++ b/docs/tutorials/posttraining/rl.md @@ -87,7 +87,7 @@ export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucke ### Option 2: Converting from a Hugging Face checkpoint -Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. +Refer the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. ```bash export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items diff --git a/docs/tutorials/posttraining/rl_on_multi_host.md b/docs/tutorials/posttraining/rl_on_multi_host.md index d1c20a68b2..033400a172 100644 --- a/docs/tutorials/posttraining/rl_on_multi_host.md +++ b/docs/tutorials/posttraining/rl_on_multi_host.md @@ -101,7 +101,7 @@ export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucke ### Option 2: Converting from a Hugging Face checkpoint -Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. +Refer the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. ```bash export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items diff --git a/docs/tutorials/posttraining/sft.md b/docs/tutorials/posttraining/sft.md index c7ed9f45c5..ad7613df95 100644 --- a/docs/tutorials/posttraining/sft.md +++ b/docs/tutorials/posttraining/sft.md @@ -68,7 +68,7 @@ export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucke ### Option 2: Converting a Hugging Face checkpoint -Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. +Refer the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. ```sh export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items diff --git a/docs/tutorials/posttraining/sft_on_multi_host.md b/docs/tutorials/posttraining/sft_on_multi_host.md index b54819cee0..766dd30e63 100644 --- a/docs/tutorials/posttraining/sft_on_multi_host.md +++ b/docs/tutorials/posttraining/sft_on_multi_host.md @@ -92,7 +92,7 @@ checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) ### Option 2: Converting a Hugging Face checkpoint -Refer the steps in [Hugging Face to MaxText](../../guides/checkpointing_solutions/convert_checkpoint.md#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. +Refer the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a hugging face checkpoint to MaxText. Make sure you have correct checkpoint files converted and saved. Similar as Option 1, you can set the following environment and move on. ```bash export MAXTEXT_CKPT_PATH= # gs://my-bucket/my-checkpoint-directory/0/items diff --git a/src/MaxText/README.md b/src/MaxText/README.md index c720fed8b8..1aecb22261 100644 --- a/src/MaxText/README.md +++ b/src/MaxText/README.md @@ -14,9 +14,9 @@ # limitations under the License. --> -# src/MaxText +# src/maxtext -The contents of `src/MaxText` have moved to `src/MaxText` as part of a larger +The contents of `src/maxtext` have moved to `src/maxtext` as part of a larger [restructuring effort in MaxText](https://github.com/AI-Hypercomputer/maxtext/blob/2790ed289c0c4cb704645d5d2ab91da26711b891/RESTRUCTURE.md). This directory only contains shim files to temporarily support legacy commands like `python3 -m MaxText.train ...`. These legacy commands are now deprecated and will be removed soon. Please migrate your existing commands and avoid using diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/llama_mistral_mixtral_orbax_to_hf.py b/src/maxtext/checkpoint_conversion/standalone_scripts/llama_mistral_mixtral_orbax_to_hf.py index 29ffb3ba28..3e0f1f4574 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/llama_mistral_mixtral_orbax_to_hf.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/llama_mistral_mixtral_orbax_to_hf.py @@ -25,7 +25,7 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.llama_mistral_mixtral_orbax_to_hf src/maxtext/configs/base.yml base_output_directory=path/to/saving/intermediate_MaxText_files - load_parameters_path=/path/to/src/MaxText/checkpoint run_name= model_name= + load_parameters_path=/path/to/src/maxtext/checkpoint run_name= model_name= hardware=gpu hf_model_path=/local/path/to/save/HF/model/to diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index 4aa5429deb..94c70860c7 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -44,7 +44,7 @@ To convert a gemma2-2b MaxText checkpoint and save it to a local directory: export HF_AUTH_TOKEN="hf_YOUR_TOKEN" - python src/MaxText/checkpoint_conversion/to_huggingface.py \ + python src/maxtext/checkpoint_conversion/to_huggingface.py \ src/maxtext/configs/base.yml \ model_name="gemma2-2b" \ load_parameters_path="/path/to/your/maxtext/checkpoint/" \ diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py index 26dbeb214b..a77893df2f 100644 --- a/src/maxtext/checkpoint_conversion/to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/to_maxtext.py @@ -40,7 +40,7 @@ Example Usage: To convert a gemma2-2b model and save it to a specific directory: - /usr/bin/time -v python src/MaxText/checkpoint_conversion/to_maxtext.py \ + /usr/bin/time -v python src/maxtext/checkpoint_conversion/to_maxtext.py \ maxtext/configs/base.yml model_name="gemma2-2b" \ base_output_directory="/path/to/your/output/directory" \ hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True \ @@ -51,7 +51,7 @@ To convert a 70B model with minimal RAM usage: - /usr/bin/time -v python src/MaxText/checkpoint_conversion/to_maxtext.py \ + /usr/bin/time -v python src/maxtext/checkpoint_conversion/to_maxtext.py \ maxtext/configs/base.yml model_name="llama3.1-70b" \ base_output_directory="gs://my-bucket/maxtext-checkpoints" \ hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True \ diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index eea7b9c2d0..f3be5b6ace 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1134,7 +1134,7 @@ use_tokamax_splash: false use_jax_splash: false # vLLM Adapter Configurations -# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter) +# Path to the HuggingFace-style config directory for the adapter (e.g. src/maxtext/integration/vllm/maxtext_vllm_adapter) vllm_hf_config_path: "" # A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter. # This can be used to override specific settings without modifying the original config file. diff --git a/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh index 926dff0abc..6f9cb366a8 100644 --- a/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh +++ b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh @@ -11,5 +11,5 @@ set -euox pipefail RUNNAME=${1:-${RUNNAME:-some-run}} BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} -chmod +x "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/v5p/gpt3_175b/gpt3_175b_base.sh +chmod +x "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"/v5p/gpt3_175b/gpt3_175b_base.sh ./maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 16 48 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file diff --git a/src/maxtext/examples/sft_llama3_demo_gpu.ipynb b/src/maxtext/examples/sft_llama3_demo_gpu.ipynb index aaa0fe5fcf..467f8b8917 100644 --- a/src/maxtext/examples/sft_llama3_demo_gpu.ipynb +++ b/src/maxtext/examples/sft_llama3_demo_gpu.ipynb @@ -1,603 +1,603 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "7687de92-1dfb-4237-b663-30cda55dc8e1", - "metadata": {}, - "source": [ - "# Supervised Fine-Tuning of Llama 3.1-8B on NVIDIA GPUs with JAX and MaxText\n", - "\n", - "## Overview\n", - "\n", - "This tutorial walks you through supervised fine-tuning (SFT) of Llama 3.1-8B on NVIDIA GPUs using JAX and MaxText. You'll learn how to take a pretrained Llama checkpoint, convert it into MaxText's native format, configure an SFT training run, and verify the result with a quick inference test.\n", - "\n", - "**What you'll do:**\n", - "1. Set up the environment and authenticate with Hugging Face\n", - "2. Download and convert the Llama 3.1-8B checkpoint to MaxText format\n", - "3. Configure and launch supervised fine-tuning on the UltraChat 200k dataset\n", - "4. Visualize training metrics with TensorBoard\n", - "5. Run a quick inference sanity check\n", - "\n", - "## Prerequisites\n", - "\n", - "### Make sure you have supported hardware\n", - "\n", - "**Hardware requirements.** Full-parameter SFT of Llama 3.1-8B is memory-intensive due to optimizer state, activations, and sharded model parameters. We recommend a system with **8 NVIDIA GPUs with at least 80 GB of memory each** (e.g., A100-80GB, H100-80GB, or H200). This allows the model, optimizer state, and activations to be cleanly sharded across devices without aggressive memory tuning.\n", - "\n", - "When running `nvidia-smi`, you should see eight or more visible GPUs, each reporting at least 80 GB of total memory, with recent drivers, CUDA 12.x+ support, and minimal memory usage before training starts." - ] + "cells": [ + { + "cell_type": "markdown", + "id": "7687de92-1dfb-4237-b663-30cda55dc8e1", + "metadata": {}, + "source": [ + "# Supervised Fine-Tuning of Llama 3.1-8B on NVIDIA GPUs with JAX and MaxText\n", + "\n", + "## Overview\n", + "\n", + "This tutorial walks you through supervised fine-tuning (SFT) of Llama 3.1-8B on NVIDIA GPUs using JAX and MaxText. You'll learn how to take a pretrained Llama checkpoint, convert it into MaxText's native format, configure an SFT training run, and verify the result with a quick inference test.\n", + "\n", + "**What you'll do:**\n", + "1. Set up the environment and authenticate with Hugging Face\n", + "2. Download and convert the Llama 3.1-8B checkpoint to MaxText format\n", + "3. Configure and launch supervised fine-tuning on the UltraChat 200k dataset\n", + "4. Visualize training metrics with TensorBoard\n", + "5. Run a quick inference sanity check\n", + "\n", + "## Prerequisites\n", + "\n", + "### Make sure you have supported hardware\n", + "\n", + "**Hardware requirements.** Full-parameter SFT of Llama 3.1-8B is memory-intensive due to optimizer state, activations, and sharded model parameters. We recommend a system with **8 NVIDIA GPUs with at least 80 GB of memory each** (e.g., A100-80GB, H100-80GB, or H200). This allows the model, optimizer state, and activations to be cleanly sharded across devices without aggressive memory tuning.\n", + "\n", + "When running `nvidia-smi`, you should see eight or more visible GPUs, each reporting at least 80 GB of total memory, with recent drivers, CUDA 12.x+ support, and minimal memory usage before training starts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dhc3l20703b", + "metadata": {}, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "id": "c7671875-fea3-4abe-9f01-57c854f50f92", + "metadata": {}, + "source": [ + "### Get your Hugging Face token\n", + "\n", + "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", + "\n", + "**Follow these steps to get your token:**\n", + "\n", + "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", + "\n", + "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", + "\n", + "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", + "\n", + "4. **Copy the generated token**. You will need this in the later steps.\n", + "\n", + "**Follow these steps to store your token (only if running on Google Colab):**\n", + "\n", + "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", + "\n", + "2. Click **\"+ Add new secret\"**.\n", + "\n", + "3. Set the Name as **HF_TOKEN**.\n", + "\n", + "4. Paste your token into the Value field.\n", + "\n", + "5. Ensure the Notebook access toggle is turned On." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7999f002-65a1-4764-ba41-922f2fec43df", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from google.colab import userdata\n", + " print(\"Running the notebook on Google Colab\")\n", + " IN_COLAB = True\n", + "except ImportError:\n", + " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", + " IN_COLAB = False" + ] + }, + { + "cell_type": "markdown", + "id": "6b691306-88ee-47b3-b841-cd6d072f51eb", + "metadata": {}, + "source": [ + "### Authenticate with Hugging Face\n", + "\n", + "Verify that your Hugging Face token is set and valid by calling the Hub's `whoami` endpoint. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93a0030b-ab25-4d76-a173-1a464c19087a", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from huggingface_hub import HfApi\n", + "\n", + "if IN_COLAB:\n", + " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", + "else:\n", + " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", + "\n", + "if not HF_TOKEN:\n", + " from getpass import getpass\n", + " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n", + "\n", + "if not HF_TOKEN:\n", + " raise RuntimeError(\"Authentication failed: Hugging Face token is not set.\")\n", + "\n", + "# Ensure token is set in this process\n", + "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", + "\n", + "# Verify identity\n", + "api = HfApi()\n", + "user_info = api.whoami(token=HF_TOKEN)\n", + "username = user_info.get(\"name\") or \"Unknown user\"\n", + "\n", + "print(f\"Authenticated with Hugging Face successfully as: {username}\")" + ] + }, + { + "cell_type": "markdown", + "id": "887cd139-a776-43ad-aeea-8547fcd8d744", + "metadata": {}, + "source": [ + "### Acquire permission to use the gated model\n", + "\n", + "Llama 3.1-8B is a gated model, so you must explicitly request access before it can be downloaded. Visit the [model page](https://huggingface.co/meta-llama/Llama-3.1-8B) on Hugging Face, log in with the same account linked to your access token, and click **Request access**. You'll need to agree to Meta's license terms; approval is usually granted quickly but is not automatic. Once approved, your Hugging Face token will authorize downloads transparently. If you skip this step, model downloads will fail even with a valid token." + ] + }, + { + "cell_type": "markdown", + "id": "66b25ee3-1072-4c0f-965d-72e911873d1c", + "metadata": {}, + "source": [ + "## Get the model and convert it into MaxText format\n", + "\n", + "### Import dependencies\n", + "\n", + "#### Core libraries and installation\n", + "\n", + "Import the core libraries needed for this tutorial:\n", + "\n", + "- **JAX**: High-performance ML framework with automatic differentiation and XLA compilation\n", + "- **MaxText**: Google's production-grade training stack for JAX, providing model architectures, checkpoint management, and the SFT training loop\n", + "\n", + "The easiest way to get a working environment is the [NVIDIA NGC JAX container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax), which ships with JAX, CUDA, and MaxText preinstalled. To install the dependencies manually:\n", + "\n", + "```bash\n", + "pip install 'jax[cuda13]' maxtext\n", + "```\n", + "\n", + "On top of it, for the model conversion step you will also need **Torch**, the CPU version would be enough:\n", + "\n", + "```bash\n", + "pip install torch --index-url https://download.pytorch.org/whl/cpu\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "531951be-2f24-455a-b9ff-b07aeeb2de1d", + "metadata": {}, + "outputs": [], + "source": [ + "# Imports\n", + "from datetime import datetime\n", + "from pathlib import Path\n", + "import sys\n", + "import subprocess\n", + "import logging\n", + "\n", + "import transformers\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import MaxText\n", + "from maxtext.configs import pyconfig\n", + "from maxtext.trainers.post_train.sft import train_sft\n", + "\n", + "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", + "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")\n", + "\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX devices: {jax.devices()}\")\n", + "print(f\"Number of available devices: {jax.local_device_count()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "483873c9-b3eb-4a95-be34-2b52aeb222e7", + "metadata": {}, + "source": [ + "#### Setting up the right parallel setup on GPU\n", + "\n", + "JAX supports two different parallel setups:\n", + "\n", + "1. *Single-host* (one machine)\n", + "\n", + "* Can be 1 GPU or multiple GPUs\n", + "* JAX will discover and use all local GPUs automatically\n", + "\n", + "Does not require `jax.distributed.initialize()`\n", + "\n", + "2. *Multi-host* (multiple machines / nodes)\n", + "\n", + "* Requires coordination across processes/hosts\n", + "* Requires `jax.distributed.initialize()` (or a launcher that does it)\n", + "\n", + "Needs coordinator and process metadata (address, process count, process index):\n", + "\n", + "`JAX_COORDINATOR_ADDRESS` (reachable host:port on process 0)\n", + "\n", + "`JAX_PROCESS_COUNT` (total number of processes/hosts)\n", + "\n", + "`JAX_PROCESS_INDEX` (0..count-1)\n", + "\n", + "**Example (2 hosts)**\n", + "\n", + "On host 0:\n", + "```bash\n", + "export JAX_COORDINATOR_ADDRESS=\"10.0.0.1:1234\"\n", + "export JAX_PROCESS_COUNT=2\n", + "export JAX_PROCESS_INDEX=0\n", + "```\n", + "\n", + "On host 1:\n", + "```bash\n", + "export JAX_COORDINATOR_ADDRESS=\"10.0.0.1:1234\"\n", + "export JAX_PROCESS_COUNT=2\n", + "export JAX_PROCESS_INDEX=1\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a86465ed-69d1-4d77-bb47-b29c203e5a60", + "metadata": {}, + "outputs": [], + "source": [ + "if not jax.distributed.is_initialized() and \"JAX_COORDINATOR_ADDRESS\" in os.environ:\n", + " jax.distributed.initialize()" + ] + }, + { + "cell_type": "markdown", + "id": "1687aa03-1549-429a-8156-571c7493ca3d", + "metadata": {}, + "source": [ + "### Define model paths and run configuration\n", + "\n", + "This block defines the core paths and identifiers used throughout the tutorial: the model name, tokenizer source, checkpoint location, and output directory. You can override `MODEL_CHECKPOINT_PATH` via an environment variable to point to an existing converted checkpoint and skip the conversion step." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1aa133cf-1168-4e87-8c4a-6fc34f1cf5cc", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_NAME = \"llama3.1-8b\"\n", + "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", + "\n", + "WORKSPACE_DIR = Path(\n", + " os.environ.get(\"WORKSPACE_DIR\", os.getcwd())\n", + ")\n", + "\n", + "# If set, use it; otherwise default to llama_checkpoint\n", + "MODEL_CHECKPOINT_PATH = os.environ.get(\"MODEL_CHECKPOINT_PATH\")\n", + "MODEL_CHECKPOINT_PATH = Path(MODEL_CHECKPOINT_PATH) if MODEL_CHECKPOINT_PATH else (WORKSPACE_DIR / \"llama_checkpoint\")\n", + "\n", + "print(f\"Model checkpoint directory: {MODEL_CHECKPOINT_PATH}\")\n", + "print(\"Tip: set MODEL_CHECKPOINT_PATH to a local directory to reuse an existing converted checkpoint.\")\n", + "\n", + "BASE_OUTPUT_DIRECTORY = Path(os.environ.get(\"BASE_OUTPUT_DIRECTORY\", str(WORKSPACE_DIR / \"sft_llama3_output\")))" + ] + }, + { + "cell_type": "markdown", + "id": "6b762437-1edb-4123-8257-90cb98028e97", + "metadata": {}, + "source": [ + "### Download and convert the Llama 3.1-8B checkpoint from Hugging Face\n", + "\n", + "This block downloads the pretrained Llama 3.1-8B weights from Hugging Face and converts them into MaxText's native checkpoint format. If a converted checkpoint already exists at the target path, this step is skipped entirely.\n", + "\n", + "The conversion runs in a CPU-only JAX context (`JAX_PLATFORMS=cpu`) to avoid unnecessary GPU memory allocation. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f028b9c-89ed-4301-9a73-2967694891d3", + "metadata": {}, + "outputs": [], + "source": [ + "ckpt_dir = Path(MODEL_CHECKPOINT_PATH)\n", + "\n", + "def run_ckpt_conversion(\n", + " *,\n", + " maxtext_repo_root: str,\n", + " model_name: str,\n", + " output_dir: Path,\n", + " hf_token: str,\n", + " quiet: bool = True,\n", + ") -> None:\n", + " env = os.environ.copy()\n", + "\n", + " # Conversion should not touch GPUs\n", + " env[\"JAX_PLATFORMS\"] = \"cpu\"\n", + "\n", + " # Reduce verbosity (JAX/XLA/TensorFlow C++ logging)\n", + " env.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # 0=all, 1=INFO off, 2=INFO+WARNING off, 3=only FATAL\n", + "\n", + " cmd = [\n", + " sys.executable, \"-m\", \"MaxText.utils.ckpt_conversion.to_maxtext\",\n", + " f\"{maxtext_repo_root}/configs/base.yml\",\n", + " f\"model_name={model_name}\",\n", + " f\"base_output_directory={str(output_dir)}\",\n", + " f\"hf_access_token={hf_token}\",\n", + " \"use_multimodal=false\",\n", + " \"scan_layers=true\",\n", + " \"skip_jax_distributed_system=true\",\n", + " ]\n", + "\n", + " output_dir.parent.mkdir(parents=True, exist_ok=True)\n", + "\n", + " if quiet:\n", + " # Capture logs; show only if something goes wrong\n", + " result = subprocess.run(\n", + " cmd,\n", + " env=env,\n", + " stdout=subprocess.PIPE,\n", + " stderr=subprocess.PIPE,\n", + " text=True,\n", + " )\n", + " if result.returncode != 0:\n", + " print(\"Checkpoint conversion failed. Logs:\\n\")\n", + " if result.stdout:\n", + " print(\"----- stdout -----\")\n", + " print(result.stdout)\n", + " if result.stderr:\n", + " print(\"----- stderr -----\")\n", + " print(result.stderr)\n", + " raise RuntimeError(\"Checkpoint conversion failed. See logs above.\")\n", + " else:\n", + " # Verbose mode (streams logs)\n", + " subprocess.run(cmd, env=env, check=True)\n", + "\n", + " print(f\"Checkpoint successfully converted to MaxText format at: {output_dir}\")\n", + "\n", + "if ckpt_dir.exists():\n", + " print(f\"Converted checkpoint already exists at: {ckpt_dir}\")\n", + "else:\n", + " print(f\"Converting checkpoint to MaxText format → {ckpt_dir}\")\n", + " run_ckpt_conversion(\n", + " maxtext_repo_root=MAXTEXT_REPO_ROOT,\n", + " model_name=MODEL_NAME,\n", + " output_dir=ckpt_dir,\n", + " hf_token=HF_TOKEN,\n", + " quiet=True, \n", + " )\n", + "\n", + "if not ckpt_dir.exists():\n", + " raise RuntimeError(\"Model checkpoint conversion failed. See logs above.\")" + ] + }, + { + "cell_type": "markdown", + "id": "265e9555-f026-4bbc-a6fb-7b0cac6bd9da", + "metadata": {}, + "source": [ + "## Provide the training configuration\n", + "\n", + "This block builds the full MaxText SFT training configuration by loading the base `sft.yml` config and applying runtime overrides for the model, dataset, hyperparameters, and output paths. Each run is tagged with a timestamp-based name to keep outputs isolated across experiments. Key settings:\n", + "\n", + "- **Dataset:** [UltraChat 200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), a large instruction-style conversational dataset commonly used for SFT of chat models.\n", + "- **Training:** 100 steps, learning rate 2e-5, sequence length 1024, bfloat16 precision.\n", + "- **Checkpoint source:** The converted MaxText checkpoint from the previous step.\n", + "\n", + "To use your own dataset, ensure it follows a compatible schema and is accessible via the Hugging Face Hub or a local path. MaxText handles dataset loading, sharding, and batching automatically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d22f54a-efe1-425d-abd4-ceb34561a9a1", + "metadata": {}, + "outputs": [], + "source": [ + "ckpt_items_path = Path(MODEL_CHECKPOINT_PATH) / \"0\" / \"items\"\n", + "\n", + "if not os.environ.get(\"HF_TOKEN\"):\n", + " raise RuntimeError(\"HF_TOKEN is not set. Export it before loading the SFT config.\")\n", + "\n", + "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n", + "\n", + "# Load configuration for SFT training\n", + "config_argv = [\n", + " \"\",\n", + " f\"{MAXTEXT_REPO_ROOT}/configs/sft.yml\",\n", + " f\"load_parameters_path={ckpt_items_path}\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " \"steps=100\",\n", + " \"per_device_batch_size=1\",\n", + " \"max_target_length=1024\",\n", + " \"learning_rate=2.0e-5\",\n", + " \"weight_dtype=bfloat16\",\n", + " \"dtype=bfloat16\",\n", + " \"hf_path=HuggingFaceH4/ultrachat_200k\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", + " f\"run_name={RUN_NAME}\",\n", + " f\"tokenizer_path={TOKENIZER_PATH}\",\n", + " \"hardware=gpu\",\n", + "]\n", + "\n", + "# Suppress the verbose per-parameter config dump (hundreds of INFO lines)\n", + "_pyconfig_logger = logging.getLogger(\"MaxText.pyconfig\")\n", + "_prev_level = _pyconfig_logger.level\n", + "_pyconfig_logger.setLevel(logging.WARNING)\n", + "\n", + "config = pyconfig.initialize(config_argv)\n", + "\n", + "_pyconfig_logger.setLevel(_prev_level)\n", + "\n", + "print(\"SFT configuration loaded:\")\n", + "print(f\" Model: {config.model_name}\")\n", + "print(f\" Training Steps: {config.steps}\")\n", + "print(f\" Max sequence length: {config.max_target_length}\")\n", + "print(f\" Output Directory: {config.base_output_directory}\")" + ] + }, + { + "cell_type": "markdown", + "id": "408c6100-20e9-4dcb-b9b6-42d68bb03ae7", + "metadata": {}, + "source": [ + "## Run the SFT training\n", + "\n", + "This section launches the SFT training loop. It runs MaxText's `sft_train` with the configuration defined above, reports progress, and saves checkpoints to the output directory. On completion, it prints the checkpoint and TensorBoard log paths." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb13ddd1-57a8-469e-940c-11fe6ae2a90d", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ.setdefault(\"LIBTPU_INIT_ARGS\", \"\")\n", + "\n", + "print(\"=\" * 60)\n", + "print(f\"Starting SFT training (run_name={RUN_NAME})\")\n", + "print(\"=\" * 60)\n", + "\n", + "try:\n", + " result = train_sft.train(config)\n", + "\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"Training completed successfully\")\n", + " print(\"=\" * 60)\n", + " print(f\"Checkpoints written to: {config.checkpoint_dir}\")\n", + " if hasattr(config, \"tensorboard_dir\"):\n", + " print(f\"TensorBoard logs: {config.tensorboard_dir}\")\n", + "\n", + " if isinstance(result, tuple) and len(result) == 2:\n", + " trainer, mesh = result\n", + "except Exception as e:\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"Training failed\")\n", + " print(\"=\" * 60)\n", + " print(f\"Error details: {e}\")\n", + " raise" + ] + }, + { + "cell_type": "markdown", + "id": "24e3a3e2-027a-4bb7-bb7f-163605010d03", + "metadata": {}, + "source": [ + "## Visualize training metrics with TensorBoard\n", + "\n", + "To monitor training loss and other metrics, launch TensorBoard in a separate terminal replacing `` with the TensorBoard logs path from the training log:\n", + "\n", + "```bash\n", + "export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python\n", + "tensorboard --logdir= --host 0.0.0.0 --port 6006 --load_fast=false\n", + "```\n", + "\n", + "Then open [http://127.0.0.1:6006/](http://127.0.0.1:6006/) in your browser. " + ] + }, + { + "cell_type": "markdown", + "id": "5acaf26b-72fe-404b-827a-297045547f5b", + "metadata": {}, + "source": [ + "## Test inference\n", + "\n", + "A quick sanity check to verify the fine-tuned model produces coherent output. The code below tokenizes a prompt using the Llama 3.1 chat template, then runs greedy autoregressive generation for up to 10 tokens, stopping early if the model produces an EOS token. This confirms the model loaded correctly and produces reasonable predictions.\n", + "\n", + "**Note:** this is naive autoregressive generation without KV-caching, so each step recomputes attention over the full sequence. For production use, consider a dedicated serving framework with KV-cache support." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f80390d8-a267-4d63-84c4-9c1c7a5a5944", + "metadata": {}, + "outputs": [], + "source": [ + "# Load tokenizer\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH, token=HF_TOKEN)\n", + "\n", + "# Get model from trainer\n", + "model = trainer.model\n", + "\n", + "# Format prompt using Llama chat template\n", + "prompt = \"What is the capital of France?\"\n", + "messages = [{\"role\": \"user\", \"content\": prompt}]\n", + "text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", + "\n", + "# Tokenize\n", + "tokens = jnp.array(tokenizer(text)[\"input_ids\"])[None, :]\n", + "\n", + "# Greedy autoregressive generation\n", + "max_new_tokens = 10\n", + "generated_ids = []\n", + "eos_token_id = tokenizer.eos_token_id\n", + "\n", + "for _ in range(max_new_tokens):\n", + " seq_len = tokens.shape[1]\n", + " positions = jnp.arange(seq_len)[None, :]\n", + " attention_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))[None, :]\n", + "\n", + " with mesh:\n", + " output = model(tokens, positions, None, attention_mask)\n", + " logits = output[0] if isinstance(output, tuple) else output\n", + "\n", + " next_token_id = int(jnp.argmax(logits[0, -1]))\n", + " generated_ids.append(next_token_id)\n", + "\n", + " if next_token_id == eos_token_id:\n", + " break\n", + "\n", + " tokens = jnp.concatenate([tokens, jnp.array([[next_token_id]])], axis=1)\n", + "\n", + "# Decode all generated tokens\n", + "generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"Generated ({len(generated_ids)} tokens): '{generated_text}'\")" + ] + }, + { + "cell_type": "markdown", + "id": "8a261a8c-55af-47ff-b68f-179abff5b623", + "metadata": {}, + "source": [ + "## Learn more\n", + "\n", + "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n", + "- **Configuration**: See `src/maxtext/configs/post_train/sft.yml` for all available options\n", + "- **Documentation**: Check `src/maxtext/trainers/post_train/sft/train_sft.py` for the `train` function implementation" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } }, - { - "cell_type": "code", - "execution_count": null, - "id": "dhc3l20703b", - "metadata": {}, - "outputs": [], - "source": [ - "!nvidia-smi" - ] - }, - { - "cell_type": "markdown", - "id": "c7671875-fea3-4abe-9f01-57c854f50f92", - "metadata": {}, - "source": [ - "### Get your Hugging Face token\n", - "\n", - "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", - "\n", - "**Follow these steps to get your token:**\n", - "\n", - "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", - "\n", - "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", - "\n", - "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", - "\n", - "4. **Copy the generated token**. You will need this in the later steps.\n", - "\n", - "**Follow these steps to store your token (only if running on Google Colab):**\n", - "\n", - "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", - "\n", - "2. Click **\"+ Add new secret\"**.\n", - "\n", - "3. Set the Name as **HF_TOKEN**.\n", - "\n", - "4. Paste your token into the Value field.\n", - "\n", - "5. Ensure the Notebook access toggle is turned On." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7999f002-65a1-4764-ba41-922f2fec43df", - "metadata": {}, - "outputs": [], - "source": [ - "try:\n", - " from google.colab import userdata\n", - " print(\"Running the notebook on Google Colab\")\n", - " IN_COLAB = True\n", - "except ImportError:\n", - " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", - " IN_COLAB = False" - ] - }, - { - "cell_type": "markdown", - "id": "6b691306-88ee-47b3-b841-cd6d072f51eb", - "metadata": {}, - "source": [ - "### Authenticate with Hugging Face\n", - "\n", - "Verify that your Hugging Face token is set and valid by calling the Hub's `whoami` endpoint. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "93a0030b-ab25-4d76-a173-1a464c19087a", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from huggingface_hub import HfApi\n", - "\n", - "if IN_COLAB:\n", - " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", - "else:\n", - " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", - "\n", - "if not HF_TOKEN:\n", - " from getpass import getpass\n", - " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n", - "\n", - "if not HF_TOKEN:\n", - " raise RuntimeError(\"Authentication failed: Hugging Face token is not set.\")\n", - "\n", - "# Ensure token is set in this process\n", - "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", - "\n", - "# Verify identity\n", - "api = HfApi()\n", - "user_info = api.whoami(token=HF_TOKEN)\n", - "username = user_info.get(\"name\") or \"Unknown user\"\n", - "\n", - "print(f\"Authenticated with Hugging Face successfully as: {username}\")" - ] - }, - { - "cell_type": "markdown", - "id": "887cd139-a776-43ad-aeea-8547fcd8d744", - "metadata": {}, - "source": [ - "### Acquire permission to use the gated model\n", - "\n", - "Llama 3.1-8B is a gated model, so you must explicitly request access before it can be downloaded. Visit the [model page](https://huggingface.co/meta-llama/Llama-3.1-8B) on Hugging Face, log in with the same account linked to your access token, and click **Request access**. You'll need to agree to Meta's license terms; approval is usually granted quickly but is not automatic. Once approved, your Hugging Face token will authorize downloads transparently. If you skip this step, model downloads will fail even with a valid token." - ] - }, - { - "cell_type": "markdown", - "id": "66b25ee3-1072-4c0f-965d-72e911873d1c", - "metadata": {}, - "source": [ - "## Get the model and convert it into MaxText format\n", - "\n", - "### Import dependencies\n", - "\n", - "#### Core libraries and installation\n", - "\n", - "Import the core libraries needed for this tutorial:\n", - "\n", - "- **JAX**: High-performance ML framework with automatic differentiation and XLA compilation\n", - "- **MaxText**: Google's production-grade training stack for JAX, providing model architectures, checkpoint management, and the SFT training loop\n", - "\n", - "The easiest way to get a working environment is the [NVIDIA NGC JAX container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax), which ships with JAX, CUDA, and MaxText preinstalled. To install the dependencies manually:\n", - "\n", - "```bash\n", - "pip install 'jax[cuda13]' maxtext\n", - "```\n", - "\n", - "On top of it, for the model conversion step you will also need **Torch**, the CPU version would be enough:\n", - "\n", - "```bash\n", - "pip install torch --index-url https://download.pytorch.org/whl/cpu\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "531951be-2f24-455a-b9ff-b07aeeb2de1d", - "metadata": {}, - "outputs": [], - "source": [ - "# Imports\n", - "from datetime import datetime\n", - "from pathlib import Path\n", - "import sys\n", - "import subprocess\n", - "import logging\n", - "\n", - "import transformers\n", - "\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import MaxText\n", - "from MaxText import pyconfig\n", - "from MaxText.sft.sft_trainer import train as sft_train, setup_trainer_state\n", - "\n", - "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", - "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")\n", - "\n", - "print(f\"JAX version: {jax.__version__}\")\n", - "print(f\"JAX devices: {jax.devices()}\")\n", - "print(f\"Number of available devices: {jax.local_device_count()}\")" - ] - }, - { - "cell_type": "markdown", - "id": "483873c9-b3eb-4a95-be34-2b52aeb222e7", - "metadata": {}, - "source": [ - "#### Setting up the right parallel setup on GPU\n", - "\n", - "JAX supports two different parallel setups:\n", - "\n", - "1. *Single-host* (one machine)\n", - "\n", - "* Can be 1 GPU or multiple GPUs\n", - "* JAX will discover and use all local GPUs automatically\n", - "\n", - "Does not require `jax.distributed.initialize()`\n", - "\n", - "2. *Multi-host* (multiple machines / nodes)\n", - "\n", - "* Requires coordination across processes/hosts\n", - "* Requires `jax.distributed.initialize()` (or a launcher that does it)\n", - "\n", - "Needs coordinator and process metadata (address, process count, process index):\n", - "\n", - "`JAX_COORDINATOR_ADDRESS` (reachable host:port on process 0)\n", - "\n", - "`JAX_PROCESS_COUNT` (total number of processes/hosts)\n", - "\n", - "`JAX_PROCESS_INDEX` (0..count-1)\n", - "\n", - "**Example (2 hosts)**\n", - "\n", - "On host 0:\n", - "```bash\n", - "export JAX_COORDINATOR_ADDRESS=\"10.0.0.1:1234\"\n", - "export JAX_PROCESS_COUNT=2\n", - "export JAX_PROCESS_INDEX=0\n", - "```\n", - "\n", - "On host 1:\n", - "```bash\n", - "export JAX_COORDINATOR_ADDRESS=\"10.0.0.1:1234\"\n", - "export JAX_PROCESS_COUNT=2\n", - "export JAX_PROCESS_INDEX=1\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a86465ed-69d1-4d77-bb47-b29c203e5a60", - "metadata": {}, - "outputs": [], - "source": [ - "if not jax.distributed.is_initialized() and \"JAX_COORDINATOR_ADDRESS\" in os.environ:\n", - " jax.distributed.initialize()" - ] - }, - { - "cell_type": "markdown", - "id": "1687aa03-1549-429a-8156-571c7493ca3d", - "metadata": {}, - "source": [ - "### Define model paths and run configuration\n", - "\n", - "This block defines the core paths and identifiers used throughout the tutorial: the model name, tokenizer source, checkpoint location, and output directory. You can override `MODEL_CHECKPOINT_PATH` via an environment variable to point to an existing converted checkpoint and skip the conversion step." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1aa133cf-1168-4e87-8c4a-6fc34f1cf5cc", - "metadata": {}, - "outputs": [], - "source": [ - "MODEL_NAME = \"llama3.1-8b\"\n", - "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", - "\n", - "WORKSPACE_DIR = Path(\n", - " os.environ.get(\"WORKSPACE_DIR\", os.getcwd())\n", - ")\n", - "\n", - "# If set, use it; otherwise default to llama_checkpoint\n", - "MODEL_CHECKPOINT_PATH = os.environ.get(\"MODEL_CHECKPOINT_PATH\")\n", - "MODEL_CHECKPOINT_PATH = Path(MODEL_CHECKPOINT_PATH) if MODEL_CHECKPOINT_PATH else (WORKSPACE_DIR / \"llama_checkpoint\")\n", - "\n", - "print(f\"Model checkpoint directory: {MODEL_CHECKPOINT_PATH}\")\n", - "print(\"Tip: set MODEL_CHECKPOINT_PATH to a local directory to reuse an existing converted checkpoint.\")\n", - "\n", - "BASE_OUTPUT_DIRECTORY = Path(os.environ.get(\"BASE_OUTPUT_DIRECTORY\", str(WORKSPACE_DIR / \"sft_llama3_output\")))" - ] - }, - { - "cell_type": "markdown", - "id": "6b762437-1edb-4123-8257-90cb98028e97", - "metadata": {}, - "source": [ - "### Download and convert the Llama 3.1-8B checkpoint from Hugging Face\n", - "\n", - "This block downloads the pretrained Llama 3.1-8B weights from Hugging Face and converts them into MaxText's native checkpoint format. If a converted checkpoint already exists at the target path, this step is skipped entirely.\n", - "\n", - "The conversion runs in a CPU-only JAX context (`JAX_PLATFORMS=cpu`) to avoid unnecessary GPU memory allocation. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3f028b9c-89ed-4301-9a73-2967694891d3", - "metadata": {}, - "outputs": [], - "source": [ - "ckpt_dir = Path(MODEL_CHECKPOINT_PATH)\n", - "\n", - "def run_ckpt_conversion(\n", - " *,\n", - " maxtext_repo_root: str,\n", - " model_name: str,\n", - " output_dir: Path,\n", - " hf_token: str,\n", - " quiet: bool = True,\n", - ") -> None:\n", - " env = os.environ.copy()\n", - "\n", - " # Conversion should not touch GPUs\n", - " env[\"JAX_PLATFORMS\"] = \"cpu\"\n", - "\n", - " # Reduce verbosity (JAX/XLA/TensorFlow C++ logging)\n", - " env.setdefault(\"TF_CPP_MIN_LOG_LEVEL\", \"2\") # 0=all, 1=INFO off, 2=INFO+WARNING off, 3=only FATAL\n", - "\n", - " cmd = [\n", - " sys.executable, \"-m\", \"MaxText.utils.ckpt_conversion.to_maxtext\",\n", - " f\"{maxtext_repo_root}/configs/base.yml\",\n", - " f\"model_name={model_name}\",\n", - " f\"base_output_directory={str(output_dir)}\",\n", - " f\"hf_access_token={hf_token}\",\n", - " \"use_multimodal=false\",\n", - " \"scan_layers=true\",\n", - " \"skip_jax_distributed_system=true\",\n", - " ]\n", - "\n", - " output_dir.parent.mkdir(parents=True, exist_ok=True)\n", - "\n", - " if quiet:\n", - " # Capture logs; show only if something goes wrong\n", - " result = subprocess.run(\n", - " cmd,\n", - " env=env,\n", - " stdout=subprocess.PIPE,\n", - " stderr=subprocess.PIPE,\n", - " text=True,\n", - " )\n", - " if result.returncode != 0:\n", - " print(\"Checkpoint conversion failed. Logs:\\n\")\n", - " if result.stdout:\n", - " print(\"----- stdout -----\")\n", - " print(result.stdout)\n", - " if result.stderr:\n", - " print(\"----- stderr -----\")\n", - " print(result.stderr)\n", - " raise RuntimeError(\"Checkpoint conversion failed. See logs above.\")\n", - " else:\n", - " # Verbose mode (streams logs)\n", - " subprocess.run(cmd, env=env, check=True)\n", - "\n", - " print(f\"Checkpoint successfully converted to MaxText format at: {output_dir}\")\n", - "\n", - "if ckpt_dir.exists():\n", - " print(f\"Converted checkpoint already exists at: {ckpt_dir}\")\n", - "else:\n", - " print(f\"Converting checkpoint to MaxText format → {ckpt_dir}\")\n", - " run_ckpt_conversion(\n", - " maxtext_repo_root=MAXTEXT_REPO_ROOT,\n", - " model_name=MODEL_NAME,\n", - " output_dir=ckpt_dir,\n", - " hf_token=HF_TOKEN,\n", - " quiet=True, \n", - " )\n", - "\n", - "if not ckpt_dir.exists():\n", - " raise RuntimeError(\"Model checkpoint conversion failed. See logs above.\")" - ] - }, - { - "cell_type": "markdown", - "id": "265e9555-f026-4bbc-a6fb-7b0cac6bd9da", - "metadata": {}, - "source": [ - "## Provide the training configuration\n", - "\n", - "This block builds the full MaxText SFT training configuration by loading the base `sft.yml` config and applying runtime overrides for the model, dataset, hyperparameters, and output paths. Each run is tagged with a timestamp-based name to keep outputs isolated across experiments. Key settings:\n", - "\n", - "- **Dataset:** [UltraChat 200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), a large instruction-style conversational dataset commonly used for SFT of chat models.\n", - "- **Training:** 100 steps, learning rate 2e-5, sequence length 1024, bfloat16 precision.\n", - "- **Checkpoint source:** The converted MaxText checkpoint from the previous step.\n", - "\n", - "To use your own dataset, ensure it follows a compatible schema and is accessible via the Hugging Face Hub or a local path. MaxText handles dataset loading, sharding, and batching automatically." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4d22f54a-efe1-425d-abd4-ceb34561a9a1", - "metadata": {}, - "outputs": [], - "source": [ - "ckpt_items_path = Path(MODEL_CHECKPOINT_PATH) / \"0\" / \"items\"\n", - "\n", - "if not os.environ.get(\"HF_TOKEN\"):\n", - " raise RuntimeError(\"HF_TOKEN is not set. Export it before loading the SFT config.\")\n", - "\n", - "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n", - "\n", - "# Load configuration for SFT training\n", - "config_argv = [\n", - " \"\",\n", - " f\"{MAXTEXT_REPO_ROOT}/configs/sft.yml\",\n", - " f\"load_parameters_path={ckpt_items_path}\",\n", - " f\"model_name={MODEL_NAME}\",\n", - " \"steps=100\",\n", - " \"per_device_batch_size=1\",\n", - " \"max_target_length=1024\",\n", - " \"learning_rate=2.0e-5\",\n", - " \"weight_dtype=bfloat16\",\n", - " \"dtype=bfloat16\",\n", - " \"hf_path=HuggingFaceH4/ultrachat_200k\",\n", - " f\"hf_access_token={HF_TOKEN}\",\n", - " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", - " f\"run_name={RUN_NAME}\",\n", - " f\"tokenizer_path={TOKENIZER_PATH}\",\n", - " \"hardware=gpu\",\n", - "]\n", - "\n", - "# Suppress the verbose per-parameter config dump (hundreds of INFO lines)\n", - "_pyconfig_logger = logging.getLogger(\"MaxText.pyconfig\")\n", - "_prev_level = _pyconfig_logger.level\n", - "_pyconfig_logger.setLevel(logging.WARNING)\n", - "\n", - "config = pyconfig.initialize(config_argv)\n", - "\n", - "_pyconfig_logger.setLevel(_prev_level)\n", - "\n", - "print(\"SFT configuration loaded:\")\n", - "print(f\" Model: {config.model_name}\")\n", - "print(f\" Training Steps: {config.steps}\")\n", - "print(f\" Max sequence length: {config.max_target_length}\")\n", - "print(f\" Output Directory: {config.base_output_directory}\")" - ] - }, - { - "cell_type": "markdown", - "id": "408c6100-20e9-4dcb-b9b6-42d68bb03ae7", - "metadata": {}, - "source": [ - "## Run the SFT training\n", - "\n", - "This section launches the SFT training loop. It runs MaxText's `sft_train` with the configuration defined above, reports progress, and saves checkpoints to the output directory. On completion, it prints the checkpoint and TensorBoard log paths." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bb13ddd1-57a8-469e-940c-11fe6ae2a90d", - "metadata": {}, - "outputs": [], - "source": [ - "os.environ.setdefault(\"LIBTPU_INIT_ARGS\", \"\")\n", - "\n", - "print(\"=\" * 60)\n", - "print(f\"Starting SFT training (run_name={RUN_NAME})\")\n", - "print(\"=\" * 60)\n", - "\n", - "try:\n", - " result = sft_train(config)\n", - "\n", - " print(\"\\n\" + \"=\" * 60)\n", - " print(\"Training completed successfully\")\n", - " print(\"=\" * 60)\n", - " print(f\"Checkpoints written to: {config.checkpoint_dir}\")\n", - " if hasattr(config, \"tensorboard_dir\"):\n", - " print(f\"TensorBoard logs: {config.tensorboard_dir}\")\n", - "\n", - " if isinstance(result, tuple) and len(result) == 2:\n", - " trainer, mesh = result\n", - "except Exception as e:\n", - " print(\"\\n\" + \"=\" * 60)\n", - " print(\"Training failed\")\n", - " print(\"=\" * 60)\n", - " print(f\"Error details: {e}\")\n", - " raise" - ] - }, - { - "cell_type": "markdown", - "id": "24e3a3e2-027a-4bb7-bb7f-163605010d03", - "metadata": {}, - "source": [ - "## Visualize training metrics with TensorBoard\n", - "\n", - "To monitor training loss and other metrics, launch TensorBoard in a separate terminal replacing `` with the TensorBoard logs path from the training log:\n", - "\n", - "```bash\n", - "export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python\n", - "tensorboard --logdir= --host 0.0.0.0 --port 6006 --load_fast=false\n", - "```\n", - "\n", - "Then open [http://127.0.0.1:6006/](http://127.0.0.1:6006/) in your browser. " - ] - }, - { - "cell_type": "markdown", - "id": "5acaf26b-72fe-404b-827a-297045547f5b", - "metadata": {}, - "source": [ - "## Test inference\n", - "\n", - "A quick sanity check to verify the fine-tuned model produces coherent output. The code below tokenizes a prompt using the Llama 3.1 chat template, then runs greedy autoregressive generation for up to 10 tokens, stopping early if the model produces an EOS token. This confirms the model loaded correctly and produces reasonable predictions.\n", - "\n", - "**Note:** this is naive autoregressive generation without KV-caching, so each step recomputes attention over the full sequence. For production use, consider a dedicated serving framework with KV-cache support." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f80390d8-a267-4d63-84c4-9c1c7a5a5944", - "metadata": {}, - "outputs": [], - "source": [ - "# Load tokenizer\n", - "tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH, token=HF_TOKEN)\n", - "\n", - "# Get model from trainer\n", - "model = trainer.model\n", - "\n", - "# Format prompt using Llama chat template\n", - "prompt = \"What is the capital of France?\"\n", - "messages = [{\"role\": \"user\", \"content\": prompt}]\n", - "text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", - "\n", - "# Tokenize\n", - "tokens = jnp.array(tokenizer(text)[\"input_ids\"])[None, :]\n", - "\n", - "# Greedy autoregressive generation\n", - "max_new_tokens = 10\n", - "generated_ids = []\n", - "eos_token_id = tokenizer.eos_token_id\n", - "\n", - "for _ in range(max_new_tokens):\n", - " seq_len = tokens.shape[1]\n", - " positions = jnp.arange(seq_len)[None, :]\n", - " attention_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))[None, :]\n", - "\n", - " with mesh:\n", - " output = model(tokens, positions, None, attention_mask)\n", - " logits = output[0] if isinstance(output, tuple) else output\n", - "\n", - " next_token_id = int(jnp.argmax(logits[0, -1]))\n", - " generated_ids.append(next_token_id)\n", - "\n", - " if next_token_id == eos_token_id:\n", - " break\n", - "\n", - " tokens = jnp.concatenate([tokens, jnp.array([[next_token_id]])], axis=1)\n", - "\n", - "# Decode all generated tokens\n", - "generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)\n", - "\n", - "print(f\"Prompt: {prompt}\")\n", - "print(f\"Generated ({len(generated_ids)} tokens): '{generated_text}'\")" - ] - }, - { - "cell_type": "markdown", - "id": "8a261a8c-55af-47ff-b68f-179abff5b623", - "metadata": {}, - "source": [ - "## Learn more\n", - "\n", - "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n", - "- **Configuration**: See `src/maxtext/configs/post_train/sft.yml` for all available options\n", - "- **Documentation**: Check `src/MaxText/sft/sft_trainer.py` for the `sft_train` function implementation" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/src/maxtext/examples/sft_qwen3_demo.ipynb b/src/maxtext/examples/sft_qwen3_demo.ipynb index 0f6f2e8f59..c4fbf53529 100644 --- a/src/maxtext/examples/sft_qwen3_demo.ipynb +++ b/src/maxtext/examples/sft_qwen3_demo.ipynb @@ -1,584 +1,592 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "1nb_Ppf2ZUQL" - }, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_qwen3_demo.ipynb)\n", - "\n", - "# Qwen3-0.6B Supervised Fine-Tuning (SFT) Demo\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FGbe4_YQZUQL" - }, - "source": [ - "## Overview\n", - "\n", - "This notebook performs SFT training and evaluation workflow on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).\n", - "The primary goal is to demonstrate the end-to-end process of:\n", - "1. Pre-SFT Evaluation: Calcuating baseline accuracy for the model before training.\n", - "2. SFT Training: Fine-tune the model using MaxText & Tunix SFT trainer.\n", - "3. Post-SFT Evaluation: Re-running the evaluation loop after training to measure the performance gain achieved by SFT.\n", - "\n", - "This notebook can run on the **public TPU v5e-1**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zolxPWhQZUQL" - }, - "source": [ - "## Prerequisites\n", - "\n", - "### Change Runtime Type (only if running on Google Colab)\n", - "\n", - "**Instructions:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Change runtime type** from the dropdown menu.\n", - "4. Select **v5e-1 TPU** as the **Hardware accelerator**.\n", - "5. Click on **Save**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Rk_QpVVuZUQL" - }, - "source": [ - "### Get Your Hugging Face Token\n", - "\n", - "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", - "\n", - "**Follow these steps to get your token:**\n", - "\n", - "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", - " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", - "\n", - "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", - "\n", - "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", - "\n", - "4. **Copy the generated token**. You will need this in the later steps.\n", - "\n", - "**Follow these steps to store your token (only if running on Google Colab):**\n", - "\n", - "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", - "\n", - "2. Click **\"+ Add new secret\"**.\n", - "\n", - "3. Set the Name as **HF_TOKEN**.\n", - "\n", - "4. Paste your token into the Value field.\n", - "\n", - "5. Ensure the Notebook access toggle is turned On." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "try:\n", - " from google.colab import userdata\n", - " print(\"Running the notebook on Google Colab\")\n", - " IN_COLAB = True\n", - "except ImportError:\n", - " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", - " IN_COLAB = False" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "D9ms-jTSZUQL" - }, - "source": [ - "## Installation: MaxText and Post training Dependencies\n", - "\n", - "**Running the notebook on Visual Studio or JupyterLab:** Before proceeding, create a virtual environment and install the required post-training dependencies by following `Option 3: Installing [tpu-post-train]` in the [MaxText installation guide](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source). Once the environment is set up, ensure the notebook is running within it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if IN_COLAB:\n", - " # Clone the MaxText repository\n", - " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", - " %cd maxtext\n", - "\n", - " # Install uv, a fast Python package installer\n", - " !pip install uv\n", - " \n", - " # Install MaxText and post-training dependencies\n", - " !uv pip install -e .[tpu-post-train] --resolution=lowest\n", - " !install_maxtext_tpu_post_train_extra_deps" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Session restart Instructions for Colab:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Restart session** from the dropdown menu.\n", - "\n", - "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Clexf-j7ZUQM" - }, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PkBI9A3JZUQM" - }, - "outputs": [], - "source": [ - "import jax\n", - "import os\n", - "import sys\n", - "import transformers\n", - "\n", - "from maxtext.configs import pyconfig\n", - "from maxtext.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n", - "from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n", - "from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n", - "from maxtext.trainers.post_train.sft import train_sft\n", - "\n", - "# Suppress vLLM logging with a severity level below ERROR\n", - "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n", - "from tunix.rl.rollout import base_rollout\n", - "from tunix.rl.rollout.vllm_rollout import VllmRollout\n", - "\n", - "from datetime import datetime\n", - "from flax import nnx\n", - "from huggingface_hub import login\n", - "\n", - "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if not jax.distributed.is_initialized():\n", - " jax.distributed.initialize()\n", - "print(f\"JAX version: {jax.__version__}\")\n", - "print(f\"JAX devices: {jax.devices()}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JBbPN-uVZUQM" - }, - "outputs": [], - "source": [ - "try:\n", - " from google.colab import userdata\n", - " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", - "except ImportError:\n", - " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", - "\n", - "# If not found in the environment, prompt the user for input securely\n", - "# getpass function ensures the token is hidden while you type\n", - "if not HF_TOKEN:\n", - " from getpass import getpass\n", - " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n", - "\n", - "if HF_TOKEN:\n", - " login(token=HF_TOKEN)\n", - " print(\"Authenticated with Hugging Face successfully!\")\n", - "else:\n", - " print(\"Authentication failed: Hugging Face token is not set.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aENuzm9iZUQM" - }, - "source": [ - "## Model Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "RjPYYl3zZUQM" - }, - "outputs": [], - "source": [ - "MODEL_NAME = \"qwen3-0.6b\"\n", - "TOKENIZER_PATH = \"Qwen/Qwen3-0.6B\"\n", - "tokenizer = transformers.AutoTokenizer.from_pretrained(\n", - " TOKENIZER_PATH,\n", - " token=HF_TOKEN,\n", - ")\n", - "\n", - "# set the path to the model checkpoint (excluding `/0/items`) or leave empty to download from HuggingFace\n", - "MODEL_CHECKPOINT_PATH = \"\"\n", - "if not MODEL_CHECKPOINT_PATH:\n", - " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_PKG_DIR}/qwen_checkpoint\"\n", - " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", - " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", - "\n", - "\n", - "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%m-%S\")\n", - "\n", - "# This is the directory where the fine-tuned model checkpoint will be saved\n", - "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/maxtext_qwen06_output\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4L37Ij4NZUQM" - }, - "source": [ - "## Download Qwen3-0.6B Model Checkpoint from Hugging Face" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kJanDAc0ZUQM" - }, - "outputs": [], - "source": [ - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " # install torch for the conversion script\n", - " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", - "\n", - " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m maxtext.checkpoint_conversion.to_maxtext \\\n", - " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n", - " model_name={MODEL_NAME} \\\n", - " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", - " hf_access_token={HF_TOKEN} \\\n", - " use_multimodal=false \\\n", - " scan_layers=true \\\n", - " skip_jax_distributed_system=True\n", - "\n", - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PC-hILG0ZUQM" - }, - "source": [ - "## Dataset Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "O3MLdr9kZUQM" - }, - "outputs": [], - "source": [ - "DATASET_NAME = \"openai/gsm8k\"\n", - "TRAIN_DATA_SPLIT = \"train\"\n", - "TEST_DATA_SPLIT = \"test\"\n", - "HF_DATA_DIR = \"main\"\n", - "TRAIN_DATA_COLUMNS = [\"question\", \"answer\"]\n", - "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/src/maxtext/examples/chat_templates/math_qa.json\"\n", - "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", - " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", - "NUM_TEST_SAMPLES = 20 # Total number of samples to test\n", - "BATCH_SIZE = 1 # Number of test samples to process in a batch" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yeAHmxSYZUQM" - }, - "source": [ - "## MaxText Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "In-jdp1AAwrL" - }, - "outputs": [], - "source": [ - "%%capture\n", - "config = pyconfig.initialize(\n", - " [\n", - " \"\",\n", - " f\"{MAXTEXT_PKG_DIR}/configs/post_train/sft.yml\",\n", - " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", - " f\"model_name={MODEL_NAME}\",\n", - " f\"hf_access_token={HF_TOKEN}\",\n", - " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", - " f\"run_name={RUN_NAME}\",\n", - " f\"tokenizer_path={TOKENIZER_PATH}\",\n", - " f\"hf_path={DATASET_NAME}\",\n", - " f\"train_split={TRAIN_DATA_SPLIT}\",\n", - " f\"hf_data_dir={HF_DATA_DIR}\",\n", - " f\"train_data_columns={TRAIN_DATA_COLUMNS}\",\n", - " \"steps=500\",\n", - " \"per_device_batch_size=1\",\n", - " \"max_target_length=1024\",\n", - " \"learning_rate=3e-6\",\n", - " \"weight_dtype=bfloat16\",\n", - " \"dtype=bfloat16\",\n", - " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O9b0GWo-ZUQM" - }, - "source": [ - "## Initial Setup & Data Preparation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TDqFmvUCZUQM" - }, - "source": [ - "### Create Test Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wscWYxrtZUQM" - }, - "outputs": [], - "source": [ - "test_dataset = get_test_dataset(config, tokenizer)\n", - "test_dataset = test_dataset[:NUM_TEST_SAMPLES]\n", - "test_dataset = test_dataset.to_iter_dataset().batch(BATCH_SIZE, drop_remainder=True)\n", - "TOTAL_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE\n", - "print(\n", - " f\"Processing {NUM_TEST_SAMPLES} examples with a batch size of {BATCH_SIZE}. This will result in {TOTAL_BATCHES} total batches for the test run.\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bLSvOOEUZUQM" - }, - "source": [ - "### Create SFT Trainer State" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2IHsC0m6ZUQM" - }, - "outputs": [], - "source": [ - "trainer, mesh = train_sft.setup_trainer_state(config)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PpKtEqzFZUQM" - }, - "source": [ - "### Create vLLM Rollout" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3-pf_rbqZUQM" - }, - "outputs": [], - "source": [ - "tunix_model = TunixMaxTextAdapter(trainer.model)\n", - "vllm_rollout = VllmRollout(\n", - " model=tunix_model,\n", - " tokenizer=tokenizer,\n", - " cache_config_or_size=1280,\n", - " mesh=mesh,\n", - " rollout_config=base_rollout.RolloutConfig(\n", - " rollout_vllm_model_version=TOKENIZER_PATH,\n", - " rollout_vllm_hbm_utilization=0.8,\n", - " rollout_vllm_init_with_random_weights=True,\n", - " rollout_vllm_tpu_backend_type=\"jax\",\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "567gTxsEZUQM" - }, - "source": [ - "## Evaluation before SFT Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OnACa3zCZUQM" - }, - "outputs": [], - "source": [ - "print(\"Running Pre-SFT Evaluation...\")\n", - "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "u5-M4iYkZUQN" - }, - "outputs": [], - "source": [ - "print(\"========================= Score for PRE-SFT Evaluation =========================\")\n", - "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", - "print(\n", - " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", - ")\n", - "print(\n", - " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EJE1ookSAzz-" - }, - "source": [ - "## SFT Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "mgwpNgQYCJEd", - "tags": [] - }, - "outputs": [], - "source": [ - "print(\"Starting SFT Training...\")\n", - "trainer = train_sft.train_model(config, trainer, mesh)\n", - "print(\"SFT Training Complete!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WEdNYRhwZUQN" - }, - "source": [ - "## Evaluation after SFT Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XcsZacZdZUQN" - }, - "outputs": [], - "source": [ - "print(\"Running Post-SFT Evaluation...\")\n", - "model = TunixMaxTextAdapter(trainer.model)\n", - "state = nnx.state(model)\n", - "vllm_rollout.update_params(state)\n", - "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "-JtYTPvJZUQN", - "tags": [] - }, - "outputs": [], - "source": [ - "print(\"========================= Score for POST-SFT Evaluation =========================\")\n", - "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", - "print(\n", - " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", - ")\n", - "print(\n", - " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", - ")" - ] - } - ], - "metadata": { - "accelerator": "TPU", - "colab": { - "gpuType": "V5E1", - "provenance": [] - }, - "kernelspec": { - "display_name": "maxtext_venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.12" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "1nb_Ppf2ZUQL" + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_qwen3_demo.ipynb)\n", + "\n", + "# Qwen3-0.6B Supervised Fine-Tuning (SFT) Demo\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FGbe4_YQZUQL" + }, + "source": [ + "## Overview\n", + "\n", + "This notebook performs SFT training and evaluation workflow on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).\n", + "The primary goal is to demonstrate the end-to-end process of:\n", + "1. Pre-SFT Evaluation: Calcuating baseline accuracy for the model before training.\n", + "2. SFT Training: Fine-tune the model using MaxText & Tunix SFT trainer.\n", + "3. Post-SFT Evaluation: Re-running the evaluation loop after training to measure the performance gain achieved by SFT.\n", + "\n", + "This notebook can run on the **public TPU v5e-1**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zolxPWhQZUQL" + }, + "source": [ + "## Prerequisites\n", + "\n", + "### Change Runtime Type (only if running on Google Colab)\n", + "\n", + "**Instructions:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Change runtime type** from the dropdown menu.\n", + "4. Select **v5e-1 TPU** as the **Hardware accelerator**.\n", + "5. Click on **Save**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Rk_QpVVuZUQL" + }, + "source": [ + "### Get Your Hugging Face Token\n", + "\n", + "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", + "\n", + "**Follow these steps to get your token:**\n", + "\n", + "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", + " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", + "\n", + "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", + "\n", + "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", + "\n", + "4. **Copy the generated token**. You will need this in the later steps.\n", + "\n", + "**Follow these steps to store your token (only if running on Google Colab):**\n", + "\n", + "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", + "\n", + "2. Click **\"+ Add new secret\"**.\n", + "\n", + "3. Set the Name as **HF_TOKEN**.\n", + "\n", + "4. Paste your token into the Value field.\n", + "\n", + "5. Ensure the Notebook access toggle is turned On." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "o0gz1E8VtpsI" + }, + "outputs": [], + "source": [ + "try:\n", + " from google.colab import userdata\n", + " print(\"Running the notebook on Google Colab\")\n", + " IN_COLAB = True\n", + "except ImportError:\n", + " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", + " IN_COLAB = False" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D9ms-jTSZUQL" + }, + "source": [ + "## Installation: MaxText and Post training Dependencies\n", + "\n", + "**Running the notebook on Visual Studio or JupyterLab:** Before proceeding, create a virtual environment and install the required post-training dependencies by following `Option 3: Installing [tpu-post-train]` in the [MaxText installation guide](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source). Once the environment is set up, ensure the notebook is running within it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bjnwIv1YtpsI" + }, + "outputs": [], + "source": [ + "if IN_COLAB:\n", + " # Clone the MaxText repository\n", + " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", + " %cd maxtext\n", + "\n", + " # Install uv, a fast Python package installer\n", + " !pip install uv\n", + " \n", + " # Install MaxText and post-training dependencies\n", + " !uv pip install -e .[tpu-post-train] --resolution=lowest\n", + " !install_maxtext_tpu_post_train_extra_deps" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OKWBCMrstpsI" + }, + "source": [ + "**Session restart Instructions for Colab:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Restart session** from the dropdown menu.\n", + "\n", + "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Clexf-j7ZUQM" + }, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PkBI9A3JZUQM" + }, + "outputs": [], + "source": [ + "import jax\n", + "import os\n", + "import sys\n", + "import transformers\n", + "\n", + "from maxtext.configs import pyconfig\n", + "from maxtext.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n", + "from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n", + "from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n", + "from maxtext.trainers.post_train.sft import train_sft\n", + "\n", + "# Suppress vLLM logging with a severity level below ERROR\n", + "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n", + "from tunix.rl.rollout import base_rollout\n", + "from tunix.rl.rollout.vllm_rollout import VllmRollout\n", + "\n", + "from datetime import datetime\n", + "from flax import nnx\n", + "from huggingface_hub import login\n", + "\n", + "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NIiA2OletpsI" + }, + "outputs": [], + "source": [ + "if not jax.distributed.is_initialized():\n", + " jax.distributed.initialize()\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX devices: {jax.devices()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JBbPN-uVZUQM" + }, + "outputs": [], + "source": [ + "try:\n", + " from google.colab import userdata\n", + " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", + "except ImportError:\n", + " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", + "\n", + "# If not found in the environment, prompt the user for input securely\n", + "# getpass function ensures the token is hidden while you type\n", + "if not HF_TOKEN:\n", + " from getpass import getpass\n", + " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n", + "\n", + "if HF_TOKEN:\n", + " login(token=HF_TOKEN)\n", + " print(\"Authenticated with Hugging Face successfully!\")\n", + "else:\n", + " print(\"Authentication failed: Hugging Face token is not set.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aENuzm9iZUQM" + }, + "source": [ + "## Model Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RjPYYl3zZUQM" + }, + "outputs": [], + "source": [ + "MODEL_NAME = \"qwen3-0.6b\"\n", + "TOKENIZER_PATH = \"Qwen/Qwen3-0.6B\"\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(\n", + " TOKENIZER_PATH,\n", + " token=HF_TOKEN,\n", + ")\n", + "\n", + "# set the path to the model checkpoint (excluding `/0/items`) or leave empty to download from HuggingFace\n", + "MODEL_CHECKPOINT_PATH = \"\"\n", + "if not MODEL_CHECKPOINT_PATH:\n", + " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_PKG_DIR}/qwen_checkpoint\"\n", + " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", + " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", + "\n", + "\n", + "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%m-%S\")\n", + "\n", + "# This is the directory where the fine-tuned model checkpoint will be saved\n", + "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/maxtext_qwen06_output\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4L37Ij4NZUQM" + }, + "source": [ + "## Download Qwen3-0.6B Model Checkpoint from Hugging Face" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kJanDAc0ZUQM" + }, + "outputs": [], + "source": [ + "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", + " # install torch for the conversion script\n", + " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", + "\n", + " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m maxtext.checkpoint_conversion.to_maxtext \\\n", + " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n", + " model_name={MODEL_NAME} \\\n", + " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", + " hf_access_token={HF_TOKEN} \\\n", + " use_multimodal=false \\\n", + " scan_layers=true \\\n", + " skip_jax_distributed_system=True\n", + "\n", + "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", + " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PC-hILG0ZUQM" + }, + "source": [ + "## Dataset Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "O3MLdr9kZUQM" + }, + "outputs": [], + "source": [ + "DATASET_NAME = \"openai/gsm8k\"\n", + "TRAIN_DATA_SPLIT = \"train\"\n", + "TEST_DATA_SPLIT = \"test\"\n", + "HF_DATA_DIR = \"main\"\n", + "TRAIN_DATA_COLUMNS = [\"question\", \"answer\"]\n", + "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/src/maxtext/examples/chat_templates/math_qa.json\"\n", + "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", + " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", + "NUM_TEST_SAMPLES = 20 # Total number of samples to test\n", + "BATCH_SIZE = 1 # Number of test samples to process in a batch" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yeAHmxSYZUQM" + }, + "source": [ + "## MaxText Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "In-jdp1AAwrL" + }, + "outputs": [], + "source": [ + "%%capture\n", + "config = pyconfig.initialize(\n", + " [\n", + " \"\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/post_train/sft.yml\",\n", + " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", + " f\"run_name={RUN_NAME}\",\n", + " f\"tokenizer_path={TOKENIZER_PATH}\",\n", + " f\"hf_path={DATASET_NAME}\",\n", + " f\"train_split={TRAIN_DATA_SPLIT}\",\n", + " f\"hf_data_dir={HF_DATA_DIR}\",\n", + " f\"train_data_columns={TRAIN_DATA_COLUMNS}\",\n", + " \"steps=500\",\n", + " \"per_device_batch_size=1\",\n", + " \"max_target_length=1024\",\n", + " \"learning_rate=3e-6\",\n", + " \"weight_dtype=bfloat16\",\n", + " \"dtype=bfloat16\",\n", + " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O9b0GWo-ZUQM" + }, + "source": [ + "## Initial Setup & Data Preparation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TDqFmvUCZUQM" + }, + "source": [ + "### Create Test Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wscWYxrtZUQM" + }, + "outputs": [], + "source": [ + "test_dataset = get_test_dataset(config, tokenizer)\n", + "test_dataset = test_dataset[:NUM_TEST_SAMPLES]\n", + "test_dataset = test_dataset.to_iter_dataset().batch(BATCH_SIZE, drop_remainder=True)\n", + "TOTAL_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE\n", + "print(\n", + " f\"Processing {NUM_TEST_SAMPLES} examples with a batch size of {BATCH_SIZE}. This will result in {TOTAL_BATCHES} total batches for the test run.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bLSvOOEUZUQM" + }, + "source": [ + "### Create SFT Trainer State" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2IHsC0m6ZUQM" + }, + "outputs": [], + "source": [ + "trainer, mesh = train_sft.setup_trainer_state(config)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PpKtEqzFZUQM" + }, + "source": [ + "### Create vLLM Rollout" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3-pf_rbqZUQM" + }, + "outputs": [], + "source": [ + "tunix_model = TunixMaxTextAdapter(trainer.model)\n", + "vllm_rollout = VllmRollout(\n", + " model=tunix_model,\n", + " tokenizer=tokenizer,\n", + " cache_config_or_size=1280,\n", + " mesh=mesh,\n", + " rollout_config=base_rollout.RolloutConfig(\n", + " rollout_vllm_model_version=TOKENIZER_PATH,\n", + " rollout_vllm_hbm_utilization=0.8,\n", + " rollout_vllm_init_with_random_weights=True,\n", + " rollout_vllm_tpu_backend_type=\"jax\",\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "567gTxsEZUQM" + }, + "source": [ + "## Evaluation before SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OnACa3zCZUQM" + }, + "outputs": [], + "source": [ + "print(\"Running Pre-SFT Evaluation...\")\n", + "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "u5-M4iYkZUQN" + }, + "outputs": [], + "source": [ + "print(\"========================= Score for PRE-SFT Evaluation =========================\")\n", + "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", + "print(\n", + " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", + ")\n", + "print(\n", + " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EJE1ookSAzz-" + }, + "source": [ + "## SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "id": "mgwpNgQYCJEd", + "tags": [] + }, + "outputs": [], + "source": [ + "print(\"Starting SFT Training...\")\n", + "trainer = train_sft.train_model(config, trainer, mesh)\n", + "print(\"SFT Training Complete!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WEdNYRhwZUQN" + }, + "source": [ + "## Evaluation after SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XcsZacZdZUQN" + }, + "outputs": [], + "source": [ + "print(\"Running Post-SFT Evaluation...\")\n", + "model = TunixMaxTextAdapter(trainer.model)\n", + "state = nnx.state(model)\n", + "vllm_rollout.update_params(state)\n", + "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "id": "-JtYTPvJZUQN", + "tags": [] + }, + "outputs": [], + "source": [ + "print(\"========================= Score for POST-SFT Evaluation =========================\")\n", + "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", + "print(\n", + " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", + ")\n", + "print(\n", + " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", + ")" + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V5E1", + "provenance": [] + }, + "kernelspec": { + "display_name": "maxtext_venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/src/maxtext/experimental/agent/ckpt_conversion_agent/README.md b/src/maxtext/experimental/agent/ckpt_conversion_agent/README.md index 44cd63656a..4730c22a9c 100644 --- a/src/maxtext/experimental/agent/ckpt_conversion_agent/README.md +++ b/src/maxtext/experimental/agent/ckpt_conversion_agent/README.md @@ -1,5 +1,5 @@ # Checkpoint conversion agent -The agent is used to automate the model-specific mappings of checkpoint conversion. It is designed to cooperate with the new checkpoint conversion [framework](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion). +The agent is used to automate the model-specific mappings of checkpoint conversion. It is designed to cooperate with the new checkpoint conversion [framework](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion). ## Quick starts To begin, you'll need: @@ -16,7 +16,7 @@ pip install -q -U "google-genai>=1.0.0" ## 1. Prepare the context file -The agent requires context files about the target and source model's parameter names and tensor shapes. You can generate them using the [`save_param.py`](../ckpt_conversion_agent/utils/save_param.py) script. The output directory defined by `config.base_output_directory`. The default is `src/MaxText/experimental/agent/ckpt_conversion_agent/context/` folder. +The agent requires context files about the target and source model's parameter names and tensor shapes. You can generate them using the [`save_param.py`](../ckpt_conversion_agent/utils/save_param.py) script. The output directory defined by `config.base_output_directory`. The default is `src/maxtext/experimental/agent/ckpt_conversion_agent/context/` folder. ```bash python3 -m maxtext.experimental.agent.ckpt_conversion_agent.utils.save_param src/maxtext/configs/base.yml \ per_device_batch_size=1 run_name=param_ model_name= scan_layers=false \ @@ -30,16 +30,16 @@ After it, you can get two `*.json` files in `config.base_output_directory` folde ```bash python3 -m maxtext.experimental.agent.ckpt_conversion_agent.step1 --target_model= \ - --dir_path=src/MaxText/experimental/agent/ckpt_conversion_agent --api_key= + --dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent --api_key= ``` -Our engineer should check the `src/MaxText/experimental/agent/ckpt_conversion_agent/outputs/proposed_dsl.txt` for potential new DSL and assess if it's needed. Then we need to add this ops into `src/MaxText/experimental/agent/ckpt_conversion_agent/context/dsl.txt`. +Our engineer should check the `src/maxtext/experimental/agent/ckpt_conversion_agent/outputs/proposed_dsl.txt` for potential new DSL and assess if it's needed. Then we need to add this ops into `src/maxtext/experimental/agent/ckpt_conversion_agent/context/dsl.txt`. ### 2.2 Step 2: Generate mappings ```bash python3 -m maxtext.experimental.agent.ckpt_conversion_agent.step2 --target_model= \ - --dir_path=src/MaxText/experimental/agent/ckpt_conversion_agent --api_key= + --dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent --api_key= ``` ## Evaluation and Debugging @@ -53,14 +53,14 @@ You can automatically verify the output by comparing the generated code against ```bash python3 -m maxtext.experimental.agent.ckpt_conversion_agent.evaluation --files ground_truth/.py \ - outputs/hook_fn.py --api_key= --dir_path=src/MaxText/experimental/agent/ckpt_conversion_agent + outputs/hook_fn.py --api_key= --dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent ``` ### Manual Debugging (No Ground-Truth Code) If a ground-truth version isn't available, you'll need to debug the conversion manually. The recommended process is to: -1. Add the model mappings into [checkpoint conversion framework](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md#adding-support-for-new-models). +1. Add the model mappings into [checkpoint conversion framework](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/README.md#adding-support-for-new-models). -2. Execute the conversion process layer-by-layer, using [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md#hugging-face-to-maxtext) or [to_huggingface.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md#maxtext-to-hugging-face). +2. Execute the conversion process layer-by-layer, using [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/README.md#hugging-face-to-maxtext) or [to_huggingface.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/README.md#maxtext-to-hugging-face). - If the tensor shape are not matched after conversion, error message will print out the parameter name that caused error. 3. After the conversion is done, run a decode to check the correctness of the generated code. @@ -73,7 +73,7 @@ python3 -m maxtext.inference.decode model_name=gemma3-4b tokenizer_path=src/maxt ``` If outputs are wrong, you can use jax.debug.print() to print the layer-wise mean/max/min values for debugging. -4. To further validate the converted checkpoint, we recommend to use the [forward_pass_logit_checker.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/README.md#verifying-conversion-correctness) to compare the original ckpt with the converted ckpt: +4. To further validate the converted checkpoint, we recommend to use the [forward_pass_logit_checker.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/README.md#verifying-conversion-correctness) to compare the original ckpt with the converted ckpt: ```bash python3 -m tests.utils.forward_pass_logit_checker src/maxtext/configs/base.yml \ tokenizer_path=assets/tokenizers/ \ @@ -121,5 +121,5 @@ Run the [One-shot agent Jyputer notebook](./baselines/one-shot-agent.ipynb) ### Prompt-chain Agent: ```bash python3 -m maxtext.experimental.agent.ckpt_conversion_agent.prompt_chain --target_model= \ - --dir_path=src/MaxText/experimental/agent/ckpt_conversion_agent --api_key= + --dir_path=src/maxtext/experimental/agent/ckpt_conversion_agent --api_key= ``` \ No newline at end of file diff --git a/src/maxtext/experimental/agent/ckpt_conversion_agent/baselines/one-shot-agent.ipynb b/src/maxtext/experimental/agent/ckpt_conversion_agent/baselines/one-shot-agent.ipynb index 0f3a149474..6c82fd4489 100644 --- a/src/maxtext/experimental/agent/ckpt_conversion_agent/baselines/one-shot-agent.ipynb +++ b/src/maxtext/experimental/agent/ckpt_conversion_agent/baselines/one-shot-agent.ipynb @@ -1,458 +1,458 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "bc539d4f", - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright 2025 Google LLC\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "1f25b113", - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Note: you may need to restart the kernel to use updated packages.\n" - ] - } - ], - "source": [ - "%pip install -U -q 'google-genai>=1.0.0'" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "09a81a73", - "metadata": {}, - "outputs": [], - "source": [ - "from google import genai\n", - "from IPython.display import Markdown" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf2eab8b", - "metadata": {}, - "outputs": [], - "source": [ - "GOOGLE_API_KEY = \"\"\n", - "\n", - "client = genai.Client(api_key=GOOGLE_API_KEY)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f51eb3cd", - "metadata": {}, - "outputs": [], - "source": [ - "MODEL_ID = \"gemini-2.0-pro\"\n", - "target_model = \"Gemma3\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c7908d62", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "id": "bc539d4f", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2025 Google LLC\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Uploaded file 'files/96zbvn0jl8v9' as: https://generativelanguage.googleapis.com/v1beta/files/96zbvn0jl8v9\n", - "Uploaded file 'files/i81ci5tcyjwa' as: https://generativelanguage.googleapis.com/v1beta/files/i81ci5tcyjwa\n", - "Uploaded file 'files/poo15cv7l54e' as: https://generativelanguage.googleapis.com/v1beta/files/poo15cv7l54e\n", - "Uploaded file 'files/84qrwp42q92e' as: https://generativelanguage.googleapis.com/v1beta/files/84qrwp42q92e\n" - ] - } - ], - "source": [ - "param_file = client.files.upload(file=\"context/param_mapping.py\")\n", - "shape_file = client.files.upload(file=\"context/hf_shape.py\")\n", - "\n", - "print(f\"Uploaded file '{param_file.name}' as: {param_file.uri}\")\n", - "print(f\"Uploaded file '{shape_file.name}' as: {shape_file.uri}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "a8b3dcf0", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 2, + "id": "1f25b113", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install -U -q 'google-genai>=1.0.0'" + ] + }, { - "data": { - "text/markdown": [ - "```python\n", - "\"\"\"\n", - " Copyright 2025 Google LLC\n", - "\n", - " Licensed under the Apache License, Version 2.0 (the \"License\");\n", - " you may not use this file except in compliance with the License.\n", - " You may obtain a copy of the License at\n", - "\n", - " https://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - " Unless required by applicable law or agreed to in writing, software\n", - " distributed under the License is distributed on an \"AS IS\" BASIS,\n", - " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - " See the License for the specific language governing permissions and\n", - " limitations under the License.\n", - " \"\"\"\n", - "\n", - "import numpy as np\n", - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "\n", - "def Gemma3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):\n", - " \"\"\"Returns mapping between MaxText and HuggingFace Gemma3 weight paths.\n", - "\n", - " Args:\n", - " config (dict): Model configuration dictionary containing at least 'num_hidden_layers'.\n", - " scan_layers (bool, optional): Whether the MaxText model uses layer scanning optimization.\n", - " When True, decoder layers are stacked into a single tensor [dim1, #layers, dim2].\n", - " Defaults to False.\n", - "\n", - " Returns:\n", - " dict: A mapping where:\n", - " - Keys are MaxText parameter paths\n", - " - Values are either:\n", - " - Single strings (HF parameter path) for unscanned parameters\n", - " - Lists of strings (HF parameter paths) for stacked layers when scan_layers=True\n", - " \"\"\"\n", - "\n", - " nlayers = config[\"num_hidden_layers\"]\n", - " mapping = {\n", - " \"params-token_embedder-embedding\": \"model.embed_tokens.weight\",\n", - " \"params-decoder-decoder_norm-scale\": \"model.norm.weight\",\n", - " }\n", - " if scan_layers:\n", - " mapping = {\n", - " **mapping,\n", - " \"params-decoder-layers-attention-key-kernel\": [\n", - " f\"model.layers.{i}.self_attn.k_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-attention-value-kernel\": [\n", - " f\"model.layers.{i}.self_attn.v_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-attention-query-kernel\": [\n", - " f\"model.layers.{i}.self_attn.q_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-attention-out-kernel\": [\n", - " f\"model.layers.{i}.self_attn.o_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-mlp-wi_0-kernel\": [\n", - " f\"model.layers.{i}.mlp.gate_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-mlp-wi_1-kernel\": [\n", - " f\"model.layers.{i}.mlp.up_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-mlp-wo-kernel\": [\n", - " f\"model.layers.{i}.mlp.down_proj.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-rms_norm-scale\": [\n", - " f\"model.layers.{i}.input_layernorm.weight\" for i in range(nlayers)\n", - " ],\n", - " \"params-decoder-layers-ffn_rms_norm-scale\": [\n", - " f\"model.layers.{i}.post_attention_layernorm.weight\" for i in range(nlayers)\n", - " ],\n", - " }\n", - " else:\n", - " for layer_idx in range(nlayers):\n", - " layer_mapping = {\n", - " f\"params-decoder-layers_{layer_idx}-attention-key-kernel\": f\"model.layers.{layer_idx}.self_attn.k_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-attention-value-kernel\": f\"model.layers.{layer_idx}.self_attn.v_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-attention-query-kernel\": f\"model.layers.{layer_idx}.self_attn.q_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-attention-out-kernel\": f\"model.layers.{layer_idx}.self_attn.o_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel\": f\"model.layers.{layer_idx}.mlp.gate_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wi_1-kernel\": f\"model.layers.{layer_idx}.mlp.up_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wo-kernel\": f\"model.layers.{layer_idx}.mlp.down_proj.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-rms_norm-scale\": f\"model.layers.{layer_idx}.input_layernorm.weight\",\n", - " f\"params-decoder-layers_{layer_idx}-ffn_rms_norm-scale\": f\"model.layers.{layer_idx}.post_attention_layernorm.weight\",\n", - " }\n", - " mapping = {**mapping, **layer_mapping}\n", - " return mapping\n", - "\n", - "\n", - "def Gemma3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):\n", - " \"\"\"Creates parameter transformation functions for converting between MaxText and\n", - " HuggingFace formats.\n", - "\n", - " This function generates a mapping of transformation functions that handle the necessary\n", - " conversions between MaxText and HuggingFace parameter formats, including operations like\n", - " padding, reshaping, and scaling.\n", - "\n", - " Args:\n", - " config (dict): Model configuration dictionary that must contain:\n", - " - num_hidden_layers (int): Number of layers in the model\n", - " - head_dim (int): Dimension of attention heads\n", - " - hidden_size (int): Model's hidden dimension size\n", - "\n", - " scan_layers (bool, optional): Controls the output format for layer parameters:\n", - " - True: Returns transformation functions for batched layer parameters\n", - " - False: Returns transformation functions for individual layer parameters\n", - " Defaults to False.\n", - "\n", - " saving_to_hf (bool, optional): Determines the direction of transformation:\n", - " - True: MaxText → HuggingFace conversion\n", - " - False: HuggingFace → MaxText conversion\n", - " Defaults to False.\n", - "\n", - " Returns:\n", - " dict: Parameter transformation mapping where:\n", - " - Keys: MaxText parameter names (str)\n", - " - Values: Either:\n", - " - callable: Single transformation function\n", - " - list[callable]: List of transformation functions to be applied in sequence\n", - "\n", - " Transformation Details:\n", - " The function handles several types of parameter transformations:\n", - " 1. Embedding layer padding:\n", - " - HF shape: [vocab_size, d_model]\n", - " - MaxText shape: [padded_vocab_size, d_model] (padded for performance)\n", - " 2. Layer normalization scaling:\n", - " - Adds/subtracts 1.0 depending on direction\n", - " 3. Attention query scaling:\n", - " - Scales by sqrt(head_dim) or its inverse\n", - "\n", - " 4. Kernel reshaping:\n", - " - Handles dimension transposition and reshaping between formats\n", - " \"\"\"\n", - " nlayers = config[\"num_hidden_layers\"]\n", - "\n", - " def pad_hf_embedding_layer(input_tensor, target_shape):\n", - " \"\"\"Pads the HF embedding layer to match the MaxText embedding layer's shape.\n", - "\n", - " Note:\n", - " HF embedding weights shape = [vocab_size,d_model]\n", - " MaxText embedding weights shape = [padded_vocab_size,d_model]\n", - " MaxText pad Gemma3 embedding to padded_vocab_size for better performance.\n", - " \"\"\"\n", - " # TODO(wenxindongwork), Perhaps, this dtype should be the activation dtype\n", - " normalizer = np.dtype(\"float32\").type(config[\"hidden_size\"] ** 0.5)\n", - "\n", - " def to_hf():\n", - " target_tensor = input_tensor[: target_shape[0], : target_shape[1]]\n", - " # target_tensor = target_tensor / normalizer # no scale factor for embedding\n", - " target_tensor = target_tensor.astype(input_tensor.dtype)\n", - " return target_tensor\n", - "\n", - " def from_hf():\n", - " target_tensor = np.zeros(target_shape, dtype=input_tensor.dtype)\n", - " target_tensor[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor\n", - " # target_tensor = target_tensor * normalizer # no scale factor for embedding\n", - " target_tensor = target_tensor.astype(input_tensor.dtype)\n", - " return target_tensor\n", - "\n", - " if saving_to_hf:\n", - " return to_hf()\n", - " else:\n", - " return from_hf()\n", - "\n", - " def reshape_kernel(input_tensor, target_shape):\n", - " def to_hf():\n", - " flipped_target_shape = np.flip(np.array(target_shape))\n", - " return input_tensor.reshape(flipped_target_shape).T\n", - "\n", - " def from_hf():\n", - " return input_tensor.T.reshape(target_shape)\n", - "\n", - " if saving_to_hf:\n", - " return to_hf()\n", - " else:\n", - " return from_hf()\n", - "\n", - " def scale_rmsnorm_layer(input_tensor, target_shape):\n", - " def to_hf():\n", - " return (input_tensor - 1.0).reshape(target_shape)\n", - "\n", - " def from_hf():\n", - " return (input_tensor + 1.0).reshape(target_shape)\n", - "\n", - " if saving_to_hf:\n", - " return to_hf()\n", - " else:\n", - " return from_hf()\n", - "\n", - " def scale_query_layer(input_tensor, target_shape):\n", - " def to_hf():\n", - " depth_scale = np.dtype(\"float32\").type(np.sqrt(config[\"head_dim\"]))\n", - " return (input_tensor * depth_scale).astype(input_tensor.dtype)\n", - "\n", - " def from_hf():\n", - " depth_scale = np.dtype(\"float32\").type(1 / np.sqrt(config[\"head_dim\"]))\n", - " return (input_tensor * depth_scale).astype(input_tensor.dtype)\n", - "\n", - " if saving_to_hf:\n", - " return to_hf()\n", - " else:\n", - " return from_hf()\n", - "\n", - " mapping = {\n", - " \"params-token_embedder-embedding\": pad_hf_embedding_layer,\n", - " \"params-decoder-decoder_norm-scale\": scale_rmsnorm_layer,\n", - " }\n", - " if scan_layers:\n", - " mapping = {\n", - " **mapping,\n", - " \"params-decoder-layers-attention-query-kernel\": [\n", - " reshape_kernel,\n", - " scale_query_layer,\n", - " ],\n", - " \"params-decoder-layers-attention-key-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-attention-value-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-mlp-wo-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-mlp-wi_1-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-mlp-wi_0-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-attention-out-kernel\": reshape_kernel,\n", - " \"params-decoder-layers-rms_norm-scale\": scale_rmsnorm_layer,\n", - " \"params-decoder-layers-ffn_rms_norm-scale\": scale_rmsnorm_layer,\n", - " }\n", - " else:\n", - " for layer_idx in range(nlayers):\n", - " mapping = {\n", - " **mapping,\n", - " f\"params-decoder-layers_{layer_idx}-attention-query-kernel\": [\n", - " reshape_kernel,\n", - " scale_query_layer,\n", - " ],\n", - " f\"params-decoder-layers_{layer_idx}-attention-key-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-attention-value-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wo-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wi_1-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-attention-out-kernel\": reshape_kernel,\n", - " f\"params-decoder-layers_{layer_idx}-rms_norm-scale\": scale_rmsnorm_layer,\n", - " f\"params-decoder-layers_{layer_idx}-ffn_rms_norm-scale\": scale_rmsnorm_layer,\n", - " }\n", - " return mapping\n", - "\n", - "\n", - "def Gemma3_HF_WEIGHTS_TO_SHAPE_MAPPING(config):\n", - " \"\"\"Returns mapping between HuggingFace weights path and weights shape.\n", - "\n", - " Args:\n", - " config (dict): Model configuration dictionary, defined in `model_configs.py`\n", - "\n", - " Returns:\n", - " dict: A mapping where:\n", - " - Keys are HuggingFace model parameter paths\n", - " - Values are parameter shape as a List\n", - " \"\"\"\n", - "\n", - " mapping = {\n", - " \"model.embed_tokens.weight\": [config[\"vocab_size\"], config[\"hidden_size\"]],\n", - " \"model.norm.weight\": [config[\"hidden_size\"]],\n", - " }\n", - " for layer_idx in range(config[\"num_hidden_layers\"]):\n", - " layer_mapping = {\n", - " f\"model.layers.{layer_idx}.input_layernorm.weight\": [config[\"hidden_size\"]],\n", - " f\"model.layers.{layer_idx}.post_attention_layernorm.weight\": [config[\"hidden_size\"]],\n", - " f\"model.layers.{layer_idx}.self_attn.q_proj.weight\": [\n", - " config[\"num_attention_heads\"] * config[\"head_dim\"],\n", - " config[\"hidden_size\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.self_attn.k_proj.weight\": [\n", - " config[\"num_key_value_heads\"] * config[\"head_dim\"],\n", - " config[\"hidden_size\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.self_attn.v_proj.weight\": [\n", - " config[\"num_key_value_heads\"] * config[\"head_dim\"],\n", - " config[\"hidden_size\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.self_attn.o_proj.weight\": [\n", - " config[\"hidden_size\"],\n", - " config[\"num_attention_heads\"] * config[\"head_dim\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.mlp.gate_proj.weight\": [\n", - " config[\"intermediate_size\"],\n", - " config[\"hidden_size\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.mlp.up_proj.weight\": [\n", - " config[\"intermediate_size\"],\n", - " config[\"hidden_size\"],\n", - " ],\n", - " f\"model.layers.{layer_idx}.mlp.down_proj.weight\": [\n", - " config[\"hidden_size\"],\n", - " config[\"intermediate_size\"],\n", - " ],\n", - " }\n", - " mapping = {**mapping, **layer_mapping}\n", - " return mapping\n", - "\n", - "```" + "cell_type": "code", + "execution_count": 3, + "id": "09a81a73", + "metadata": {}, + "outputs": [], + "source": [ + "from google import genai\n", + "from IPython.display import Markdown" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf2eab8b", + "metadata": {}, + "outputs": [], + "source": [ + "GOOGLE_API_KEY = \"\"\n", + "\n", + "client = genai.Client(api_key=GOOGLE_API_KEY)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f51eb3cd", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_ID = \"gemini-2.0-pro\"\n", + "target_model = \"Gemma3\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7908d62", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Uploaded file 'files/96zbvn0jl8v9' as: https://generativelanguage.googleapis.com/v1beta/files/96zbvn0jl8v9\n", + "Uploaded file 'files/i81ci5tcyjwa' as: https://generativelanguage.googleapis.com/v1beta/files/i81ci5tcyjwa\n", + "Uploaded file 'files/poo15cv7l54e' as: https://generativelanguage.googleapis.com/v1beta/files/poo15cv7l54e\n", + "Uploaded file 'files/84qrwp42q92e' as: https://generativelanguage.googleapis.com/v1beta/files/84qrwp42q92e\n" + ] + } ], - "text/plain": [ - "" + "source": [ + "param_file = client.files.upload(file=\"context/param_mapping.py\")\n", + "shape_file = client.files.upload(file=\"context/hf_shape.py\")\n", + "\n", + "print(f\"Uploaded file '{param_file.name}' as: {param_file.uri}\")\n", + "print(f\"Uploaded file '{shape_file.name}' as: {shape_file.uri}\")" ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a8b3dcf0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "```python\n", + "\"\"\"\n", + " Copyright 2025 Google LLC\n", + "\n", + " Licensed under the Apache License, Version 2.0 (the \"License\");\n", + " you may not use this file except in compliance with the License.\n", + " You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + " Unless required by applicable law or agreed to in writing, software\n", + " distributed under the License is distributed on an \"AS IS\" BASIS,\n", + " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + " See the License for the specific language governing permissions and\n", + " limitations under the License.\n", + " \"\"\"\n", + "\n", + "import numpy as np\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "\n", + "def Gemma3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):\n", + " \"\"\"Returns mapping between MaxText and HuggingFace Gemma3 weight paths.\n", + "\n", + " Args:\n", + " config (dict): Model configuration dictionary containing at least 'num_hidden_layers'.\n", + " scan_layers (bool, optional): Whether the MaxText model uses layer scanning optimization.\n", + " When True, decoder layers are stacked into a single tensor [dim1, #layers, dim2].\n", + " Defaults to False.\n", + "\n", + " Returns:\n", + " dict: A mapping where:\n", + " - Keys are MaxText parameter paths\n", + " - Values are either:\n", + " - Single strings (HF parameter path) for unscanned parameters\n", + " - Lists of strings (HF parameter paths) for stacked layers when scan_layers=True\n", + " \"\"\"\n", + "\n", + " nlayers = config[\"num_hidden_layers\"]\n", + " mapping = {\n", + " \"params-token_embedder-embedding\": \"model.embed_tokens.weight\",\n", + " \"params-decoder-decoder_norm-scale\": \"model.norm.weight\",\n", + " }\n", + " if scan_layers:\n", + " mapping = {\n", + " **mapping,\n", + " \"params-decoder-layers-attention-key-kernel\": [\n", + " f\"model.layers.{i}.self_attn.k_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-attention-value-kernel\": [\n", + " f\"model.layers.{i}.self_attn.v_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-attention-query-kernel\": [\n", + " f\"model.layers.{i}.self_attn.q_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-attention-out-kernel\": [\n", + " f\"model.layers.{i}.self_attn.o_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-mlp-wi_0-kernel\": [\n", + " f\"model.layers.{i}.mlp.gate_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-mlp-wi_1-kernel\": [\n", + " f\"model.layers.{i}.mlp.up_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-mlp-wo-kernel\": [\n", + " f\"model.layers.{i}.mlp.down_proj.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-rms_norm-scale\": [\n", + " f\"model.layers.{i}.input_layernorm.weight\" for i in range(nlayers)\n", + " ],\n", + " \"params-decoder-layers-ffn_rms_norm-scale\": [\n", + " f\"model.layers.{i}.post_attention_layernorm.weight\" for i in range(nlayers)\n", + " ],\n", + " }\n", + " else:\n", + " for layer_idx in range(nlayers):\n", + " layer_mapping = {\n", + " f\"params-decoder-layers_{layer_idx}-attention-key-kernel\": f\"model.layers.{layer_idx}.self_attn.k_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-attention-value-kernel\": f\"model.layers.{layer_idx}.self_attn.v_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-attention-query-kernel\": f\"model.layers.{layer_idx}.self_attn.q_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-attention-out-kernel\": f\"model.layers.{layer_idx}.self_attn.o_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel\": f\"model.layers.{layer_idx}.mlp.gate_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wi_1-kernel\": f\"model.layers.{layer_idx}.mlp.up_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wo-kernel\": f\"model.layers.{layer_idx}.mlp.down_proj.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-rms_norm-scale\": f\"model.layers.{layer_idx}.input_layernorm.weight\",\n", + " f\"params-decoder-layers_{layer_idx}-ffn_rms_norm-scale\": f\"model.layers.{layer_idx}.post_attention_layernorm.weight\",\n", + " }\n", + " mapping = {**mapping, **layer_mapping}\n", + " return mapping\n", + "\n", + "\n", + "def Gemma3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):\n", + " \"\"\"Creates parameter transformation functions for converting between MaxText and\n", + " HuggingFace formats.\n", + "\n", + " This function generates a mapping of transformation functions that handle the necessary\n", + " conversions between MaxText and HuggingFace parameter formats, including operations like\n", + " padding, reshaping, and scaling.\n", + "\n", + " Args:\n", + " config (dict): Model configuration dictionary that must contain:\n", + " - num_hidden_layers (int): Number of layers in the model\n", + " - head_dim (int): Dimension of attention heads\n", + " - hidden_size (int): Model's hidden dimension size\n", + "\n", + " scan_layers (bool, optional): Controls the output format for layer parameters:\n", + " - True: Returns transformation functions for batched layer parameters\n", + " - False: Returns transformation functions for individual layer parameters\n", + " Defaults to False.\n", + "\n", + " saving_to_hf (bool, optional): Determines the direction of transformation:\n", + " - True: MaxText → HuggingFace conversion\n", + " - False: HuggingFace → MaxText conversion\n", + " Defaults to False.\n", + "\n", + " Returns:\n", + " dict: Parameter transformation mapping where:\n", + " - Keys: MaxText parameter names (str)\n", + " - Values: Either:\n", + " - callable: Single transformation function\n", + " - list[callable]: List of transformation functions to be applied in sequence\n", + "\n", + " Transformation Details:\n", + " The function handles several types of parameter transformations:\n", + " 1. Embedding layer padding:\n", + " - HF shape: [vocab_size, d_model]\n", + " - MaxText shape: [padded_vocab_size, d_model] (padded for performance)\n", + " 2. Layer normalization scaling:\n", + " - Adds/subtracts 1.0 depending on direction\n", + " 3. Attention query scaling:\n", + " - Scales by sqrt(head_dim) or its inverse\n", + "\n", + " 4. Kernel reshaping:\n", + " - Handles dimension transposition and reshaping between formats\n", + " \"\"\"\n", + " nlayers = config[\"num_hidden_layers\"]\n", + "\n", + " def pad_hf_embedding_layer(input_tensor, target_shape):\n", + " \"\"\"Pads the HF embedding layer to match the MaxText embedding layer's shape.\n", + "\n", + " Note:\n", + " HF embedding weights shape = [vocab_size,d_model]\n", + " MaxText embedding weights shape = [padded_vocab_size,d_model]\n", + " MaxText pad Gemma3 embedding to padded_vocab_size for better performance.\n", + " \"\"\"\n", + " # TODO(wenxindongwork), Perhaps, this dtype should be the activation dtype\n", + " normalizer = np.dtype(\"float32\").type(config[\"hidden_size\"] ** 0.5)\n", + "\n", + " def to_hf():\n", + " target_tensor = input_tensor[: target_shape[0], : target_shape[1]]\n", + " # target_tensor = target_tensor / normalizer # no scale factor for embedding\n", + " target_tensor = target_tensor.astype(input_tensor.dtype)\n", + " return target_tensor\n", + "\n", + " def from_hf():\n", + " target_tensor = np.zeros(target_shape, dtype=input_tensor.dtype)\n", + " target_tensor[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor\n", + " # target_tensor = target_tensor * normalizer # no scale factor for embedding\n", + " target_tensor = target_tensor.astype(input_tensor.dtype)\n", + " return target_tensor\n", + "\n", + " if saving_to_hf:\n", + " return to_hf()\n", + " else:\n", + " return from_hf()\n", + "\n", + " def reshape_kernel(input_tensor, target_shape):\n", + " def to_hf():\n", + " flipped_target_shape = np.flip(np.array(target_shape))\n", + " return input_tensor.reshape(flipped_target_shape).T\n", + "\n", + " def from_hf():\n", + " return input_tensor.T.reshape(target_shape)\n", + "\n", + " if saving_to_hf:\n", + " return to_hf()\n", + " else:\n", + " return from_hf()\n", + "\n", + " def scale_rmsnorm_layer(input_tensor, target_shape):\n", + " def to_hf():\n", + " return (input_tensor - 1.0).reshape(target_shape)\n", + "\n", + " def from_hf():\n", + " return (input_tensor + 1.0).reshape(target_shape)\n", + "\n", + " if saving_to_hf:\n", + " return to_hf()\n", + " else:\n", + " return from_hf()\n", + "\n", + " def scale_query_layer(input_tensor, target_shape):\n", + " def to_hf():\n", + " depth_scale = np.dtype(\"float32\").type(np.sqrt(config[\"head_dim\"]))\n", + " return (input_tensor * depth_scale).astype(input_tensor.dtype)\n", + "\n", + " def from_hf():\n", + " depth_scale = np.dtype(\"float32\").type(1 / np.sqrt(config[\"head_dim\"]))\n", + " return (input_tensor * depth_scale).astype(input_tensor.dtype)\n", + "\n", + " if saving_to_hf:\n", + " return to_hf()\n", + " else:\n", + " return from_hf()\n", + "\n", + " mapping = {\n", + " \"params-token_embedder-embedding\": pad_hf_embedding_layer,\n", + " \"params-decoder-decoder_norm-scale\": scale_rmsnorm_layer,\n", + " }\n", + " if scan_layers:\n", + " mapping = {\n", + " **mapping,\n", + " \"params-decoder-layers-attention-query-kernel\": [\n", + " reshape_kernel,\n", + " scale_query_layer,\n", + " ],\n", + " \"params-decoder-layers-attention-key-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-attention-value-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-mlp-wo-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-mlp-wi_1-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-mlp-wi_0-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-attention-out-kernel\": reshape_kernel,\n", + " \"params-decoder-layers-rms_norm-scale\": scale_rmsnorm_layer,\n", + " \"params-decoder-layers-ffn_rms_norm-scale\": scale_rmsnorm_layer,\n", + " }\n", + " else:\n", + " for layer_idx in range(nlayers):\n", + " mapping = {\n", + " **mapping,\n", + " f\"params-decoder-layers_{layer_idx}-attention-query-kernel\": [\n", + " reshape_kernel,\n", + " scale_query_layer,\n", + " ],\n", + " f\"params-decoder-layers_{layer_idx}-attention-key-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-attention-value-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wo-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wi_1-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-attention-out-kernel\": reshape_kernel,\n", + " f\"params-decoder-layers_{layer_idx}-rms_norm-scale\": scale_rmsnorm_layer,\n", + " f\"params-decoder-layers_{layer_idx}-ffn_rms_norm-scale\": scale_rmsnorm_layer,\n", + " }\n", + " return mapping\n", + "\n", + "\n", + "def Gemma3_HF_WEIGHTS_TO_SHAPE_MAPPING(config):\n", + " \"\"\"Returns mapping between HuggingFace weights path and weights shape.\n", + "\n", + " Args:\n", + " config (dict): Model configuration dictionary, defined in `model_configs.py`\n", + "\n", + " Returns:\n", + " dict: A mapping where:\n", + " - Keys are HuggingFace model parameter paths\n", + " - Values are parameter shape as a List\n", + " \"\"\"\n", + "\n", + " mapping = {\n", + " \"model.embed_tokens.weight\": [config[\"vocab_size\"], config[\"hidden_size\"]],\n", + " \"model.norm.weight\": [config[\"hidden_size\"]],\n", + " }\n", + " for layer_idx in range(config[\"num_hidden_layers\"]):\n", + " layer_mapping = {\n", + " f\"model.layers.{layer_idx}.input_layernorm.weight\": [config[\"hidden_size\"]],\n", + " f\"model.layers.{layer_idx}.post_attention_layernorm.weight\": [config[\"hidden_size\"]],\n", + " f\"model.layers.{layer_idx}.self_attn.q_proj.weight\": [\n", + " config[\"num_attention_heads\"] * config[\"head_dim\"],\n", + " config[\"hidden_size\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.self_attn.k_proj.weight\": [\n", + " config[\"num_key_value_heads\"] * config[\"head_dim\"],\n", + " config[\"hidden_size\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.self_attn.v_proj.weight\": [\n", + " config[\"num_key_value_heads\"] * config[\"head_dim\"],\n", + " config[\"hidden_size\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.self_attn.o_proj.weight\": [\n", + " config[\"hidden_size\"],\n", + " config[\"num_attention_heads\"] * config[\"head_dim\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.mlp.gate_proj.weight\": [\n", + " config[\"intermediate_size\"],\n", + " config[\"hidden_size\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.mlp.up_proj.weight\": [\n", + " config[\"intermediate_size\"],\n", + " config[\"hidden_size\"],\n", + " ],\n", + " f\"model.layers.{layer_idx}.mlp.down_proj.weight\": [\n", + " config[\"hidden_size\"],\n", + " config[\"intermediate_size\"],\n", + " ],\n", + " }\n", + " mapping = {**mapping, **layer_mapping}\n", + " return mapping\n", + "\n", + "```" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = f\"\"\"\n", + " You are a code assist to help me find the checkpoint conversion from MaxText to HuggingFace. \n", + " The checkpoint does not fuse QKV vectors. \n", + " The transformer configs should be completely aligned with given model config for {target_model}\n", + " You need to generate the following code functions of {target_model} Model:\n", + " {target_model}_MAXTEXT_TO_HF_PARAM_MAPPING(); \n", + " {target_model}_MAXTEXT_TO_HF_PARAM_HOOK_FN();\n", + " {target_model}_HF_WEIGHTS_TO_SHAPE_MAPPING();\n", + "\"\"\"\n", + "\n", + "response = client.models.generate_content(model=MODEL_ID, contents=[prompt, param_file, shape_file])\n", + "\n", + "Markdown(response.text)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "agent_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" } - ], - "source": [ - "prompt = f\"\"\"\n", - " You are a code assist to help me find the checkpoint conversion from maxtext to huggingface. \n", - " The checkpoint does not fuse QKV vectors. \n", - " The transformer configs should be completely aligned with given model config for {target_model}\n", - " You need to generate the following code functions of {target_model} Model:\n", - " {target_model}_MAXTEXT_TO_HF_PARAM_MAPPING(); \n", - " {target_model}_MAXTEXT_TO_HF_PARAM_HOOK_FN();\n", - " {target_model}_HF_WEIGHTS_TO_SHAPE_MAPPING();\n", - "\"\"\"\n", - "\n", - "response = client.models.generate_content(model=MODEL_ID, contents=[prompt, param_file, shape_file])\n", - "\n", - "Markdown(response.text)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "agent_env", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index 418fd970db..ff8cdde8a8 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -478,7 +478,7 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F if self._iterator is not None: # Follow MaxText's logic to handle multi-process saving - # Logic extracted from src/MaxText/common/checkpointing.py:save_checkpoint + # Logic extracted from src/maxtext/common/checkpointing.py:save_checkpoint data_iterator = self._iterator if not isinstance(data_iterator, list): data_iterator = [data_iterator] diff --git a/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py b/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py index 05c567e924..2d30fd2706 100644 --- a/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py +++ b/src/maxtext/trainers/post_train/distillation/save_top_k_teacher_logits.py @@ -36,7 +36,7 @@ from itertools import islice from absl import app -from MaxText import pyconfig +from maxtext import pyconfig from maxtext.utils import model_creation_utils from maxtext.input_pipeline import input_pipeline_interface from maxtext.utils import maxtext_utils diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index d77270f62e..5083a2194f 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -85,10 +85,10 @@ def get_maxtext_model(config, devices=None): """ Load MaxText model with Tunix adapter. # Note: pass the path to your scanned checkpoint for 'load_parameters_path'. - # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if + # To create a scanned checkpoint, you can use /maxtext/src/maxtext/checkpoint_conversion/to_maxtext.py and if # using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags: # export USE_PATHWAYS=1 - # python src/MaxText/checkpoint_conversion/to_maxtext.py \ + # python src/maxtext/checkpoint_conversion/to_maxtext.py \ # --model_name="gemma2-2b" \ # --base_output_directory="/path/to/your/output/directory" \ # --scan_layers=True \ diff --git a/tests/end_to_end/gpu/a3/test_llama2_7b.sh b/tests/end_to_end/gpu/a3/test_llama2_7b.sh index a832c66500..49d6915f54 100644 --- a/tests/end_to_end/gpu/a3/test_llama2_7b.sh +++ b/tests/end_to_end/gpu/a3/test_llama2_7b.sh @@ -16,7 +16,7 @@ idx=$(date +%Y-%m-%d-%H-%M) export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs export ASYNC_CHECKPOINTING=false -# We install torch CPU because the checkpoint conversion script "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/llama_or_mistral_ckpt.py does not need a TPU/GPU +# We install torch CPU because the checkpoint conversion script "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"/llama_or_mistral_ckpt.py does not need a TPU/GPU python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu # We define a var for the path to the Meta checkpoint. Non-Googlers please remember to update the source `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint @@ -29,7 +29,7 @@ gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/ # `CONVERTED_CHECKPOINT_PATH` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `CONVERTED_CHECKPOINT_PATH` to a GCS bucket that you own export CONVERTED_CHECKPOINT_PATH=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext-gpu -#Next, run the conversion script `src/MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in the `maxtext-model-path` +#Next, run the conversion script `src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in the `maxtext-model-path` python3 -m maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt --base-model-path /tmp/meta-ckpt --model-size llama2-7b --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH} # We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory exactly inside `CONVERTED_CHECKPOINT_PATH`. This way it is easier to use this path in the `train.py` and `decode.py` commands diff --git a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md index 0efdf5729c..10651f7186 100644 --- a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md +++ b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md @@ -57,8 +57,8 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ ## Checkpoint conversion To get started, follow the instructions at HuggingFace ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. Currently for V3, V3.1, and R1, it uses mixed precision fp8 & bf16 weights. To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16: -* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly. -* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding. +* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly. +* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding. ## Fine-tuning diff --git a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh index 2e22f344b8..9f5d4c320b 100644 --- a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh +++ b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh @@ -20,8 +20,8 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V2-Lite' # Installing torch for checkpoint conversion and forward_pass_logit_checker.py python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -# e.g., $HOME/maxtext/src/MaxText -export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}" +# e.g., $HOME/maxtext/src/maxtext +export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}" if [ -z "${BASE_OUTPUT_PATH}" ]; then # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing. diff --git a/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh index 016d435133..f2b0d62dad 100644 --- a/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh +++ b/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh @@ -18,8 +18,8 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3' # Installing torch for checkpoint conversion and forward_pass_logit_checker.py python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -# e.g., $HOME/maxtext/src/MaxText -export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}" +# e.g., $HOME/maxtext/src/maxtext +export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}" if [ -z "${BASE_OUTPUT_PATH}" ]; then # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing. diff --git a/tests/end_to_end/tpu/gemma/Run_Gemma.md b/tests/end_to_end/tpu/gemma/Run_Gemma.md index 33149c7d21..2fe8141156 100644 --- a/tests/end_to_end/tpu/gemma/Run_Gemma.md +++ b/tests/end_to_end/tpu/gemma/Run_Gemma.md @@ -19,7 +19,7 @@ Following the instructions at [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) will let you download Gemma model weights. You will have to consent to license for Gemma using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials). -After downloading the weights run [convert_gemma_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [tests/end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/tests/end_to_end/tpu/gemma). +After downloading the weights run [convert_gemma_chkpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gemma_chkpt.py), which converts the checkpoint to be compatible with MaxText and uploads them to a GCS bucket. You can run decode and finetuning using instructions mentioned in the test scripts at [tests/end_to_end/tpu/gemma](https://github.com/AI-Hypercomputer/maxtext/tree/main/tests/end_to_end/tpu/gemma). ## MaxText supports pretraining and finetuning with high performance diff --git a/tests/end_to_end/tpu/gemma3/Run_Gemma3.md b/tests/end_to_end/tpu/gemma3/Run_Gemma3.md index 962ae4a803..f95f39b54e 100644 --- a/tests/end_to_end/tpu/gemma3/Run_Gemma3.md +++ b/tests/end_to_end/tpu/gemma3/Run_Gemma3.md @@ -29,7 +29,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml model_n ``` ## Checkpoint Conversion -To obtain the Gemma3 model weights, follow the instructions provided on [Kaggle](https://www.kaggle.com/models/google/gemma-3/flax/). You will need to accept the Gemma3 license through your Kaggle account and utilize your Kaggle [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials) for authentication. Once the weights are downloaded to your GCS bucket, use the [checkpoint conversion utils](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/checkpoint_conversion#usage) to transform the checkpoint into a format compatible with MaxText. This script will also upload the converted checkpoints to a Google Cloud Storage (GCS) bucket. +To obtain the Gemma3 model weights, follow the instructions provided on [Kaggle](https://www.kaggle.com/models/google/gemma-3/flax/). You will need to accept the Gemma3 license through your Kaggle account and utilize your Kaggle [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials) for authentication. Once the weights are downloaded to your GCS bucket, use the [checkpoint conversion utils](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/checkpoint_conversion#usage) to transform the checkpoint into a format compatible with MaxText. This script will also upload the converted checkpoints to a Google Cloud Storage (GCS) bucket. ## Fine-tuning After the conversion, you will have a MaxText compatible checkpoint which allows you to fine-tune it with different datasets. One example command to fine-tune a Gemma3-4B model is as follows: diff --git a/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md b/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md index 0c70557b4f..de505cf439 100644 --- a/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md +++ b/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md @@ -31,7 +31,7 @@ hf download [openai/gpt-oss-20b|openai/gpt-oss-120b] --local-dir --output-path= --dtype-str=bf16 @@ -39,14 +39,14 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.dequantize_mxfp4 --i 3. Once downloaded and converted to BF16: -* run [convert_gpt_oss_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_gpt_oss_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) scanned format for training and fine-tuning. +* run [convert_gpt_oss_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt_oss_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) scanned format for training and fine-tuning. ``` python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gpt_oss_ckpt --base-model-path \ --maxtext-model-path --model-size [gpt-oss-20b|gpt-oss-120b] ``` -* run [convert_gpt_oss_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/standalone_scripts/convert_gpt_oss_unscanned_ckpt.py) to convert the checkpoint to unscanned format in Orbax for decoding. +* run [convert_gpt_oss_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt_oss_unscanned_ckpt.py) to convert the checkpoint to unscanned format in Orbax for decoding. ``` python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_gpt_oss_unscanned_ckpt --base-model-path \ diff --git a/tests/end_to_end/tpu/mixtral/Run_Mixtral.md b/tests/end_to_end/tpu/mixtral/Run_Mixtral.md index c0f93de5e5..243041ff27 100644 --- a/tests/end_to_end/tpu/mixtral/Run_Mixtral.md +++ b/tests/end_to_end/tpu/mixtral/Run_Mixtral.md @@ -19,7 +19,7 @@ [Mixtral](https://mistral.ai/news/mixtral-of-experts/) is a state-of-the-art AI model developed by Mistral AI, utilizing a sparse mixture-of-experts (MoE) architecture. -To get started, follow the instructions at [mistral-inference](https://github.com/mistralai/mistral-inference) to download the model. Once downloaded, run [llama_or_mistral_ckpt.py](../../../src/MaxText/llama_or_mistral_ckpt.py) to convert the checkpoint for MaxText compatibility. You can then proceed with decoding, pretraining, and finetuning. You could find Mixtral 8x7B example in the [tests/end_to_end/tpu/mixtral/8x7b](../mixtral/8x7b) test scripts. +To get started, follow the instructions at [mistral-inference](https://github.com/mistralai/mistral-inference) to download the model. Once downloaded, run [llama_or_mistral_ckpt.py](../../../src/maxtext/llama_or_mistral_ckpt.py) to convert the checkpoint for MaxText compatibility. You can then proceed with decoding, pretraining, and finetuning. You could find Mixtral 8x7B example in the [tests/end_to_end/tpu/mixtral/8x7b](../mixtral/8x7b) test scripts. Additionally, Mixtral integrates with [MegaBlocks](https://arxiv.org/abs/2211.15841), an efficient dropless MoE strategy, which can be activated by setting both sparse_matmul and megablox flags to True (default). diff --git a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh index 04a7596688..fd11eed7fa 100644 --- a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh +++ b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh @@ -20,8 +20,8 @@ export TOKENIZER_PATH='Qwen/Qwen3-Next-80B-A3B-Instruct' # Installing torch for checkpoint conversion and forward_pass_logit_checker.py python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -# e.g., $HOME/maxtext/src/MaxText -export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}" +# e.g., $HOME/maxtext/src/maxtext +export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}" if [ -z "${BASE_OUTPUT_PATH}" ]; then # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing. diff --git a/tests/end_to_end/tpu/test_grpo.sh b/tests/end_to_end/tpu/test_grpo.sh index 21bf5a6174..e69dbca005 100644 --- a/tests/end_to_end/tpu/test_grpo.sh +++ b/tests/end_to_end/tpu/test_grpo.sh @@ -54,6 +54,6 @@ ici_data_parallelism=${NUM_SAMPLERS} ici_tensor_parallelism=${DEVICES_PER_SAMPLE profiler=xplane skip_first_n_steps_for_profiler=10 profiler_steps=2" JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \ - python3 src/MaxText/experimental/rl/grpo_trainer.py src/MaxText/experimental/rl/grpo.yml \ - ${COMMON_ARGS} ${TRAINING_ARGS} src/MaxText/experimental/rl/grpo_inference.yml \ + python3 src/maxtext/experimental/rl/grpo_trainer.py src/maxtext/experimental/rl/grpo.yml \ + ${COMMON_ARGS} ${TRAINING_ARGS} src/maxtext/experimental/rl/grpo_inference.yml \ ${COMMON_ARGS} ${INFERENCE_ARGS} diff --git a/tests/inference/test_llama2_7b_bf16.sh b/tests/inference/test_llama2_7b_bf16.sh index 82f360b7ec..b02fa32d9a 100755 --- a/tests/inference/test_llama2_7b_bf16.sh +++ b/tests/inference/test_llama2_7b_bf16.sh @@ -1,8 +1,8 @@ #!/bin/bash -CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" +CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}/configs/base.yml" if [ "${DECOUPLE_GCLOUD^^}" = "TRUE" ]; then - CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/decoupled_base_test.yml" + CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}/configs/decoupled_base_test.yml" fi # Define the arguments in an array diff --git a/tests/inference/test_llama2_7b_int8.sh b/tests/inference/test_llama2_7b_int8.sh index 4056467bc6..8e11e6ab48 100755 --- a/tests/inference/test_llama2_7b_int8.sh +++ b/tests/inference/test_llama2_7b_int8.sh @@ -1,8 +1,8 @@ #!/bin/bash -CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" +CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}/configs/base.yml" if [ "${DECOUPLE_GCLOUD^^}" = "TRUE" ]; then - CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/decoupled_base_test.yml" + CONFIG_PATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}/configs/decoupled_base_test.yml" fi # Define the arguments in an array diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index c176e53883..b7a31acae6 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -34,7 +34,7 @@ # For example: # tests/assets/logits_generation/golden_llama2-7b_export.ipynb -"""Check if the logits generated by a model's src/MaxText/HF implementation matches golden logits for the same inputs""" +"""Check if the logits generated by a model's src/maxtext/HF implementation matches golden logits for the same inputs""" import argparse import os diff --git a/tools/dev/code_style.sh b/tools/dev/code_style.sh index c540715158..cf25bdfa2b 100755 --- a/tools/dev/code_style.sh +++ b/tools/dev/code_style.sh @@ -18,7 +18,7 @@ set -e # Exit immediately if any command fails REPO_ROOT="${MAXTEXT_REPO_ROOT:-$PWD}" -FOLDERS_TO_FORMAT=("${MAXTEXT_PKG_DIR:-${REPO_ROOT}/src/MaxText}" "${REPO_ROOT}/pedagogical_examples") +FOLDERS_TO_FORMAT=("${MAXTEXT_PKG_DIR:-${REPO_ROOT}/src/maxtext}" "${REPO_ROOT}/pedagogical_examples") LINE_LENGTH=$(grep -E "^max-line-length=" pylintrc | cut -d '=' -f 2) # Check for --check flag