diff --git a/.github/workflows/run_jupyter_notebooks.yml b/.github/workflows/run_jupyter_notebooks.yml
index 528ff15e5f..78a1ec4375 100644
--- a/.github/workflows/run_jupyter_notebooks.yml
+++ b/.github/workflows/run_jupyter_notebooks.yml
@@ -87,10 +87,11 @@ jobs:
- name: Run Post-Training Notebooks
shell: bash
env:
+ PYTHONPATH: "${{ github.workspace }}/src"
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
MAXTEXT_REPO_ROOT=$(pwd)
- MAXTEXT_NOTEBOOKS_ROOT="$MAXTEXT_REPO_ROOT/src/MaxText/examples"
+ MAXTEXT_NOTEBOOKS_ROOT="$MAXTEXT_REPO_ROOT/src/maxtext/examples"
for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl}*.ipynb; do
filename=$(basename "$notebook")
diff --git a/README.md b/README.md
index 5373a11360..aae9502b0b 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies
* \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported.
* \[December 10, 2025\] DeepSeek V3.1 is now supported. Use existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/models/deepseek3-671b.yml) and load in V3.1 checkpoint to use model.
-* \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/examples) are available.
+* \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) are available.
* \[December 4, 2025\] The [ReadTheDocs documentation site](https://maxtext.readthedocs.io/en/latest/index.html) has been reorganized.
* \[December 3, 2025\] Multi-host support for GSPO and GRPO is now available via [new RL tutorials](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl_on_multi_host.html).
* \[November 20, 2025\] A new guide, [What is Post Training in MaxText?](https://maxtext.readthedocs.io/en/latest/tutorials/post_training_index.html), is now available.
diff --git a/codecov.yml b/codecov.yml
index 9765fe5edd..0b616cf714 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -34,7 +34,7 @@ fixes:
ignore:
- "src/maxtext/assets"
- "src/MaxText/configs"
- - "src/MaxText/examples"
+ - "src/maxtext/examples"
- "src/MaxText/experimental"
- "src/MaxText/inference"
- "src/MaxText/inference_mlperf"
diff --git a/docs/guides/run_python_notebook.md b/docs/guides/run_python_notebook.md
index 7c5dc14ac6..7d24c271d6 100644
--- a/docs/guides/run_python_notebook.md
+++ b/docs/guides/run_python_notebook.md
@@ -19,6 +19,7 @@ Before starting, make sure you have:
- ✅ Basic familiarity with Jupyter, Python, and Git
**For Method 2 (Visual Studio Code) and Method 3 (Local Jupyter Lab) only:**
+
- ✅ A Google Cloud Platform (GCP) account with billing enabled
- ✅ TPU quota available in your region (check under IAM & Admin → Quotas)
- ✅ `tpu.nodes.create` permission to create a TPU VM
@@ -36,16 +37,18 @@ Currently, this method only supports the **`sft_qwen3_demo.ipynb`** notebook, wh
Before proceeding, please verify that the specific notebook you are running works reliably on the free-tier TPU resources. If you encounter frequent disconnections or resource limitations, you may need to:
-* Upgrade to a Colab Pro or Pro+ subscription for more stable and powerful TPU access.
+- Upgrade to a Colab Pro or Pro+ subscription for more stable and powerful TPU access.
-* Move to local Jupyter Lab setup method with access to a powerful TPU machine.
+- Move to the local Jupyter Lab setup (Method 3) with access to a powerful TPU machine.
### Step 1: Choose an Example
-1.a. Visit the [MaxText examples directory](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/examples) on Github.
+
+1.a. Visit the [MaxText examples directory](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) on GitHub.
1.b. Find the notebook you want to run (e.g., `sft_qwen3_demo.ipynb`) and copy its URL.
### Step 2: Import into Colab
+
2.a. Go to [Google Colab](https://colab.research.google.com/) and sign in.
2.b. Select **File** -> **Open Notebook**.
@@ -63,9 +66,11 @@ Before proceeding, please verify that the specific notebook you are running work
3.c. Click **Save**
### Step 4: Run the Notebook
+
Follow the instructions within the notebook cells to install dependencies and run the training/inference.
## Method 2: Visual Studio Code with TPU (Recommended)
+
Running Jupyter notebooks in Visual Studio Code (VS Code) provides a powerful, interactive environment that combines the flexibility of notebooks with the robust features of a code editor. Follow these steps to get your environment up and running.
### Step 1: Set Up TPU VM
@@ -75,9 +80,10 @@ In Google Cloud Console, create a standalone TPU VM:
1.a. **Compute Engine** → **TPUs** → **Create TPU**
1.b. Example config:
- - **Name:** `maxtext-tpu-node`
- - **TPU type:** Choose your desired TPU type
- - **Runtime Version:** `tpu-ubuntu2204-base` (or other compatible runtime)
+
+- **Name:** `maxtext-tpu-node`
+- **TPU type:** Choose your desired TPU type
+- **Runtime Version:** `tpu-ubuntu2204-base` (or other compatible runtime)
### Step 2: SSH to TPU-VM via VS Code
@@ -86,11 +92,12 @@ In Google Cloud Console, create a standalone TPU VM:
2.b. Follow [Connect to a remote host](https://code.visualstudio.com/docs/remote/ssh#_connect-to-a-remote-host) guide to connect to your TPU-VM via VS Code.
### Step 3. Install Necessary Extensions on VS Code
+
To enable notebook support, you must install two official extensions from the VS Code Marketplace:
-* Python Extension: Provides support for the Python language.
+- Python Extension: Provides support for the Python language.
-* Jupyter Extension: Enables you to create, edit, and run `.ipynb` files directly inside VS Code.
+- Jupyter Extension: Enables you to create, edit, and run `.ipynb` files directly inside VS Code.
To install, click the `Extensions` icon on the left sidebar (or press `Ctrl+Shift+X` or `Cmd+Shift+X`), search for `Jupyter` and `Python`, and click `Install`.
@@ -99,6 +106,7 @@ To install, click the `Extensions` icon on the left sidebar (or press `Ctrl+Shif
To execute post-training notebooks on your TPU-VM, follow the official [MaxText installation guides](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl.html#create-virtual-environment-and-install-maxtext-dependencies) to install MaxText and its dependencies inside a dedicated virtual environment.
### Step 5: Install the necessary library for Jupyter
+
Jupyter requires a kernel to execute code. This kernel is tied to a specific Python environment. Open your terminal inside VS Code and run:
```bash
@@ -110,9 +118,9 @@ uv pip install ipykernel
Before you can run the notebook, you must tell VS Code which Python environment to use.
1. Look at the top-right corner of the notebook editor.
-2. Click `Select Kernel`.
-3. Choose Python Environments and select the virtual environment you created in Step 4.
-4. Open [available post-training notebooks in MaxText](#available-examples) inside VS Code and run the jupyter notebook cells.
+1. Click `Select Kernel`.
+1. Choose Python Environments and select the virtual environment you created in Step 4.
+1. Open [available post-training notebooks in MaxText](#available-examples) inside VS Code and run the Jupyter notebook cells.
## Method 3: Local Jupyter Lab with TPU (Recommended)
@@ -125,12 +133,15 @@ In Google Cloud Console, create a standalone TPU VM:
1.a. **Compute Engine** → **TPUs** → **Create TPU**
1.b. Example config:
- - **Name:** `maxtext-tpu-node`
- - **TPU type:** Choose your desired TPU type
- - **Runtime Version:** `tpu-ubuntu2204-base` (or other compatible runtime)
+
+- **Name:** `maxtext-tpu-node`
+- **TPU type:** Choose your desired TPU type
+- **Runtime Version:** `tpu-ubuntu2204-base` (or other compatible runtime)
### Step 2: Connect with Port Forwarding
+
Run the following command on your local machine:
+
> **Note**: The `--` separator before the `-L` flag is required. This tunnels the remote port 8888 to your local machine securely.
```bash
@@ -170,13 +181,15 @@ jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root
```
### Step 7: Access the Notebook
+
7.a. Look at the terminal output for a URL that looks like: `http://127.0.0.1:8888/lab?token=...`.
7.b. Copy that URL.
7.c. Paste it into your **local computer's browser**.
- * **Important:** If you changed the port in Step 2 (e.g., to `9999`), you must manually replace `8888` in the URL with `9999`.
- * *Example:* `http://127.0.0.1:9999/lab?token=...`
+
+- **Important:** If you changed the port in Step 2 (e.g., to `9999`), you must manually replace `8888` in the URL with `9999`.
+- *Example:* `http://127.0.0.1:9999/lab?token=...`
7.d. Once the interface opens in your browser, Click on the current kernel name (e.g., `Python 3 (ipykernel)`).
@@ -197,13 +210,13 @@ jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root
## Common Pitfalls & Debugging
-| Issue | Solution |
-|-------|----------|
-| ❌ TPU runtime mismatch | Check TPU runtime version matches VM image |
-| ❌ Colab disconnects | Save checkpoints to GCS or Drive regularly |
-| ❌ "RESOURCE_EXHAUSTED" errors | Use smaller batch size or v5e-8 instead of v5e-1 |
-| ❌ Firewall blocked | Ensure port 8888 open, or always use SSH tunneling |
-| ❌ Path confusion | In Colab use `/content/maxtext`; in TPU VM use `~/maxtext` |
+| Issue | Solution |
+| ------------------------------ | ---------------------------------------------------------- |
+| ❌ TPU runtime mismatch | Check TPU runtime version matches VM image |
+| ❌ Colab disconnects | Save checkpoints to GCS or Drive regularly |
+| ❌ "RESOURCE_EXHAUSTED" errors | Use smaller batch size or v5e-8 instead of v5e-1 |
+| ❌ Firewall blocked | Ensure port 8888 open, or always use SSH tunneling |
+| ❌ Path confusion | In Colab use `/content/maxtext`; in TPU VM use `~/maxtext` |
## Support and Resources
@@ -217,9 +230,9 @@ jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root
If you encounter issues or have improvements for this guide, please:
1. Open an issue on the MaxText repository
-2. Submit a pull request with your improvements
-3. Share your experience in the discussions
+1. Submit a pull request with your improvements
+1. Share your experience in the discussions
----
+______________________________________________________________________
-**Happy Training! 🚀**
\ No newline at end of file
+**Happy Training! 🚀**
diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md
index 7fefa80344..960ecfbb98 100644
--- a/docs/tutorials/first_run.md
+++ b/docs/tutorials/first_run.md
@@ -75,7 +75,7 @@ In the same TPU VM where you just installed all the dependencies of MaxText, You
#### Decoding in MaxText via notebook
-You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint.
+You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Hugging Face checkpoint.
### Run MaxText on NVIDIA GPUs
diff --git a/docs/tutorials/posttraining/multimodal.md b/docs/tutorials/posttraining/multimodal.md
index 980c2d8aca..11c6982c66 100644
--- a/docs/tutorials/posttraining/multimodal.md
+++ b/docs/tutorials/posttraining/multimodal.md
@@ -6,7 +6,7 @@ This document provides a guide to use the multimodal functionalities in MaxText
- **Multimodal Decode**: Inference with text+images as input.
- **Supervised Fine-Tuning (SFT)**: Apply SFT to the model using a visual-question-answering dataset.
-We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb) for multimodal features demonstration. The following table provides a list of models and modalities we currently support:
+We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/multimodal_gemma3_demo.ipynb) demonstrating the multimodal features. The following table lists the models and modalities we currently support:
| Models | Input Modalities | Output Modalities |
| :--------------------------------------------- | :--------------- | :---------------- |
diff --git a/pedagogical_examples/__init__.py b/pedagogical_examples/__init__.py
deleted file mode 100644
index 2237c9162e..0000000000
--- a/pedagogical_examples/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2023–2025 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml
index 283d388d92..5a8f57f664 100644
--- a/src/MaxText/configs/rl.yml
+++ b/src/MaxText/configs/rl.yml
@@ -171,7 +171,7 @@ reasoning_start_token: ''
reasoning_end_token: ''
solution_start_token: ''
solution_end_token: ''
-chat_template_path: 'src/MaxText/examples/chat_templates/gsm8k_rl.json'
+chat_template_path: 'src/maxtext/examples/chat_templates/gsm8k_rl.json'
skip_jax_distributed_system: True
# # TODO(@mazumdera): fix this
diff --git a/src/MaxText/examples/rl_llama3_demo.ipynb b/src/MaxText/examples/rl_llama3_demo.ipynb
deleted file mode 100644
index a3d7cd7506..0000000000
--- a/src/MaxText/examples/rl_llama3_demo.ipynb
+++ /dev/null
@@ -1,410 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "aPIBMUYfh6NF"
- },
- "source": [
- "# Llama3.1-8B-Instruct Reinforcement Learning Demo\n",
- "\n",
- "This notebook demonstrates training on Llama3.1-8B-Instruct model with either GRPO (Group Relative Policy Optimization) or GSPO (Group Sequence Policy Optimization).\n",
- "\n",
- "This notebook can run on **TPU v5e-8** or **v5p-8**.\n",
- "\n",
- "## What is GRPO/GSPO?\n",
- "\n",
- "GRPO/GSPO is an RL algorithm that enhances reasoning abilities of LLMs by:\n",
- "1. Generating multiple responses for each prompt\n",
- "2. Evaluating responses using reward models \n",
- "3. Calculating relative advantages to update the policy\n",
- "\n",
- "The difference is in the loss function - either it's optimizing each token (GRPO) or the whole sequence(GSPO)."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3hIEGGBbh6NF"
- },
- "source": [
- "## Prerequisites\n",
- "\n",
- "### Change Runtime Type (only if running on Google Colab)\n",
- "\n",
- "**Instructions:**\n",
- "1. Navigate to the menu at the top of the screen.\n",
- "2. Click on **Runtime**.\n",
- "3. Select **Change runtime type** from the dropdown menu.\n",
- "4. Select **v5e-8** or **v5p-8 TPU** as the **Hardware accelerator**.\n",
- "5. Click on **Save**."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "t9J8jIJuh6NF"
- },
- "source": [
- "### Get Your Hugging Face Token\n",
- "\n",
- "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n",
- "\n",
- "**Follow these steps to get your token:**\n",
- "\n",
- "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n",
- " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
- "\n",
- "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n",
- "\n",
- "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n",
- "\n",
- "4. **Copy the generated token**. You will need this in the later steps.\n",
- "\n",
- "**Follow these steps to store your token (only if running on Google Colab):**\n",
- "\n",
- "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n",
- "\n",
- "2. Click **\"+ Add new secret\"**.\n",
- "\n",
- "3. Set the Name as **HF_TOKEN**.\n",
- "\n",
- "4. Paste your token into the Value field.\n",
- "\n",
- "5. Ensure the Notebook access toggle is turned On."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "NasqFEooh6NF"
- },
- "outputs": [],
- "source": [
- "try:\n",
- " from google.colab import userdata\n",
- " print(\"Running the notebook on Google Colab\")\n",
- " IN_COLAB = True\n",
- "except ImportError:\n",
- " print(\"Running the notebook on Visual Studio or JupyterLab\")\n",
- " IN_COLAB = False"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "_5dzCPYeh6NF"
- },
- "source": [
- "## Installation: MaxText and Dependencies\n",
- "\n",
- "**⚠️ Note:** The installation process in following cell may take a few minutes to complete. Please be patient."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "2G9dGRaoh6NF"
- },
- "outputs": [],
- "source": [
- "if IN_COLAB:\n",
- " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
- " %cd /content/maxtext\n",
- "\n",
- " # Install uv, a fast Python package installer\n",
- " !pip install uv\n",
- "\n",
- " # Install MaxText and its dependencies\n",
- " !uv pip install -e .[tpu] --resolution=lowest\n",
- " !python3 -m MaxText.install_maxtext_extra_deps\n",
- "\n",
- " # Install vLLM for Jax and TPUs\n",
- " !uv pip install vllm-tpu\n",
- " !uv pip install --no-deps qwix==0.1.4"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "nelhTNVGh6NF"
- },
- "source": [
- "### Restart Session (only if running on Google Colab)\n",
- "To apply certain changes, you need to restart the session.\n",
- "\n",
- "**Instructions:**\n",
- "1. Navigate to the menu at the top of the screen.\n",
- "2. Click on **Runtime**.\n",
- "3. Select **Restart session** from the dropdown menu.\n",
- "\n",
- "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7Q4NS47Mh6NF"
- },
- "source": [
- "## Environment Setup"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Ml0orOyEh6NF"
- },
- "outputs": [],
- "source": [
- "import datetime\n",
- "import os\n",
- "import sys\n",
- "from pathlib import Path\n",
- "import MaxText\n",
- "from huggingface_hub import login\n",
- "import jax\n",
- "\n",
- "from MaxText.rl.train_rl import rl_train, setup_configs_and_devices\n",
- "\n",
- "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n",
- "os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n",
- "\n",
- "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n",
- "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "BHOtzHWBh6NF"
- },
- "outputs": [],
- "source": [
- "if not jax.distributed.is_initialized():\n",
- " jax.distributed.initialize()\n",
- "print(f\"JAX version: {jax.__version__}\")\n",
- "print(f\"JAX devices: {jax.devices()}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "jqewgodrh6NF"
- },
- "outputs": [],
- "source": [
- "if IN_COLAB:\n",
- " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n",
- "else:\n",
- " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n",
- "\n",
- "# If not found in the environment, prompt the user for input securely\n",
- "# getpass function ensures the token is hidden while you type\n",
- "if not HF_TOKEN:\n",
- " from getpass import getpass\n",
- " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n",
- "\n",
- "if HF_TOKEN:\n",
- " os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
- " login(token=HF_TOKEN)\n",
- " print(\"Authenticated with Hugging Face successfully!\")\n",
- "else:\n",
- " print(\"Authentication failed: Hugging Face token is not set.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "_uxwEwuah6NF"
- },
- "source": [
- "## Model Configurations"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "dSQzsP3th6NF"
- },
- "outputs": [],
- "source": [
- "MODEL_NAME = \"llama3.1-8b\"\n",
- "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
- "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
- "LOSS_ALGO=\"grpo\" # or \"gspo-token\" if you want to use GSPO\n",
- "\n",
- "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n",
- "if not os.path.exists(CHAT_TEMPLATE_PATH):\n",
- " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n",
- "\n",
- "# set the path to the model checkpoint or leave empty to download from HuggingFace\n",
- "MODEL_CHECKPOINT_PATH = \"\"\n",
- "if not MODEL_CHECKPOINT_PATH:\n",
- " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_REPO_ROOT}/llama_checkpoint\"\n",
- " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n",
- " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n",
- " \n",
- "OUTPUT_DIRECTORY = f\"{MAXTEXT_REPO_ROOT}/rl_llama3_output\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "jHbSg0vih6NF"
- },
- "source": [
- "## Download Llama3.1-8B Model Checkpoint from Hugging Face"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "I3O4Pr-1h6NF"
- },
- "outputs": [],
- "source": [
- "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
- " # install torch for the conversion script\n",
- " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n",
- "\n",
- " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_REPO_ROOT} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n",
- " {MAXTEXT_REPO_ROOT}/configs/base.yml \\\n",
- " model_name={MODEL_NAME} \\\n",
- " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n",
- " hf_access_token={HF_TOKEN} \\\n",
- " use_multimodal=false \\\n",
- " scan_layers=true \\\n",
- " skip_jax_distributed_system=True\n",
- "\n",
- "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
- " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HHuww-GWh6NF"
- },
- "source": [
- "## MaxText Configurations"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "UNzZPHYsh6NF"
- },
- "outputs": [],
- "source": [
- "# Load configuration for RL training\n",
- "config_argv = [\n",
- " \"\",\n",
- " f\"{MAXTEXT_REPO_ROOT}/configs/rl.yml\",\n",
- " f\"model_name={MODEL_NAME}\",\n",
- " f\"tokenizer_path={TOKENIZER_PATH}\",\n",
- " f\"run_name={RUN_NAME}\",\n",
- " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n",
- " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n",
- " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n",
- " f\"hf_access_token={HF_TOKEN}\",\n",
- " \"debug.rl=False\",\n",
- " f\"rl.loss_algo={LOSS_ALGO}\",\n",
- " \"use_pathways=False\"\n",
- "]\n",
- "\n",
- "trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)\n",
- "\n",
- "rl_train_steps = int(\n",
- " trainer_config.num_batches\n",
- " * trainer_config.rl.num_iterations\n",
- " * trainer_config.train_fraction\n",
- " * trainer_config.num_epoch\n",
- ")\n",
- "\n",
- "print(\"✓ Configuration initialized successfully\")\n",
- "print(f\"📁 Output directory: {trainer_config.base_output_directory}\")\n",
- "print(f\"🤖 Model: {trainer_config.model_name}\")\n",
- "print(f\"📊 RL Train Steps: {rl_train_steps}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "uKDvQ-x4h6NG"
- },
- "source": [
- "## RL Training"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "zGm_j3Vrh6NG"
- },
- "outputs": [],
- "source": [
- "print(\"\\n\" + \"=\" * 80)\n",
- "print(f\"🚀 Starting {LOSS_ALGO} Training...\")\n",
- "print(\"=\" * 80)\n",
- "try:\n",
- " rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)\n",
- " print(\"\\n\" + \"=\" * 80)\n",
- " print(\"✅ Training Completed Successfully!\")\n",
- " print(f\"✍️ Note the improved evaluation accuracy metrics with just {rl_train_steps} RL training steps!\")\n",
- " print(\"=\" * 80)\n",
- " print(f\"📁 Checkpoints saved to: {trainer_config.checkpoint_dir}\")\n",
- " print(f\"📊 TensorBoard logs: {trainer_config.tensorboard_dir}\")\n",
- " print(f\"🎯 Model ready for inference!\")\n",
- "except Exception as e:\n",
- " print(\"\\n\" + \"=\" * 80)\n",
- " print(\"❌Training Failed!\")\n",
- " print(\"=\" * 80)\n",
- " print(f\"Error: {str(e)}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TdlhG-wRh6NG"
- },
- "source": [
- "## 📚 Learn More\n",
- "\n",
- "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/rl.html\n",
- "- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n",
- "- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.11"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/src/MaxText/examples/sft_llama3_demo.ipynb b/src/MaxText/examples/sft_llama3_demo.ipynb
deleted file mode 100644
index 03e7b078fc..0000000000
--- a/src/MaxText/examples/sft_llama3_demo.ipynb
+++ /dev/null
@@ -1,407 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "hcsWywuVjQq8"
- },
- "source": [
- "[](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/sft_llama3_demo.ipynb)\n",
- "\n",
- "# Llama3.1-8B-Instruct Supervised Fine-Tuning (SFT) Demo\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3faYjYXojQq8"
- },
- "source": [
- "## Overview\n",
- "\n",
- "This notebook demonstrates how to perform Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct using the Hugging Face ultrachat_200k dataset with MaxText and Tunix integration for efficient training.\n",
- "\n",
- "This notebook can run on **TPU v5e-8** or **v5p-8**."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "6QotQjsMjQq8"
- },
- "source": [
- "## Prerequisites\n",
- "\n",
- "### Change Runtime Type (only if running on Google Colab)\n",
- "\n",
- "**Instructions:**\n",
- "1. Navigate to the menu at the top of the screen.\n",
- "2. Click on **Runtime**.\n",
- "3. Select **Change runtime type** from the dropdown menu.\n",
- "4. Select **v5e-8** or **v5p-8 TPU** as the **Hardware accelerator**.\n",
- "5. Click on **Save**."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8XMWNuEwjQq8"
- },
- "source": [
- "### Get Your Hugging Face Token\n",
- "\n",
- "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n",
- "\n",
- "**Follow these steps to get your token:**\n",
- "\n",
- "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n",
- " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
- "\n",
- "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n",
- "\n",
- "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n",
- "\n",
- "4. **Copy the generated token**. You will need this in the later steps.\n",
- "\n",
- "**Follow these steps to store your token (only if running on Google Colab):**\n",
- "\n",
- "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n",
- "\n",
- "2. Click **\"+ Add new secret\"**.\n",
- "\n",
- "3. Set the Name as **HF_TOKEN**.\n",
- "\n",
- "4. Paste your token into the Value field.\n",
- "\n",
- "5. Ensure the Notebook access toggle is turned On."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Y53cfWyujQq9"
- },
- "outputs": [],
- "source": [
- "try:\n",
- " from google.colab import userdata\n",
- " print(\"Running the notebook on Google Colab\")\n",
- " IN_COLAB = True\n",
- "except ImportError:\n",
- " print(\"Running the notebook on Visual Studio or JupyterLab\")\n",
- " IN_COLAB = False"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "GeTXb9DnjQq9"
- },
- "source": [
- "### Installation: MaxText & Other Dependencies"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "KTRkrhp1jQq9"
- },
- "source": [
- "**⚠️ Note:** The installation process in following cell may take a few minutes to complete. Please be patient."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "aUoJbrTDjQq9"
- },
- "outputs": [],
- "source": [
- "if IN_COLAB:\n",
- " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
- " %cd /content/maxtext\n",
- "\n",
- " # Install uv, a fast Python package installer\n",
- " !pip install uv\n",
- "\n",
- " # Install MaxText and its dependencies\n",
- " !uv pip install -e .[tpu] --resolution=lowest\n",
- " !python3 -m MaxText.install_maxtext_extra_deps"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8YkY4rSqjQq9"
- },
- "source": [
- "### Restart Session (only if running on Google Colab)\n",
- "To apply certain changes, you need to restart the session.\n",
- "\n",
- "**Instructions:**\n",
- "1. Navigate to the menu at the top of the screen.\n",
- "2. Click on **Runtime**.\n",
- "3. Select **Restart session** from the dropdown menu.\n",
- "\n",
- "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "6mrIF5acjQq9"
- },
- "source": [
- "## Environment Setup"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "mdeFSFymjQq9"
- },
- "outputs": [],
- "source": [
- "import datetime\n",
- "import os\n",
- "import subprocess\n",
- "import sys\n",
- "import MaxText\n",
- "from MaxText import pyconfig\n",
- "from maxtext.trainers.post_train.sft import train_sft\n",
- "import jax\n",
- "from huggingface_hub import login\n",
- "\n",
- "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n",
- "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "4RBhx-PBjQq9"
- },
- "outputs": [],
- "source": [
- "if not jax.distributed.is_initialized():\n",
- " jax.distributed.initialize()\n",
- "print(f\"JAX version: {jax.__version__}\")\n",
- "print(f\"JAX devices: {jax.devices()}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "yaw60bgPjQq9"
- },
- "outputs": [],
- "source": [
- "if IN_COLAB:\n",
- " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n",
- "else:\n",
- " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n",
- "\n",
- "# If not found in the environment, prompt the user for input securely\n",
- "# getpass function ensures the token is hidden while you type\n",
- "if not HF_TOKEN:\n",
- " from getpass import getpass\n",
- " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n",
- "\n",
- "if HF_TOKEN:\n",
- " login(token=HF_TOKEN)\n",
- " print(\"Authenticated with Hugging Face successfully!\")\n",
- "else:\n",
- " print(\"Authentication failed: Hugging Face token is not set.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "bxNSmUV8jQq9"
- },
- "source": [
- "## Model Configurations"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "kWjjVLgUjQq9"
- },
- "outputs": [],
- "source": [
- "MODEL_NAME = \"llama3.1-8b\"\n",
- "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
- "\n",
- "# set the path to the model checkpoint or leave empty to download from HuggingFace\n",
- "MODEL_CHECKPOINT_PATH = \"\"\n",
- "if not MODEL_CHECKPOINT_PATH:\n",
- " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_REPO_ROOT}/llama_checkpoint\"\n",
- " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n",
- " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n",
- "\n",
- "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_REPO_ROOT}/sft_llama3_output\"\n",
- "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "yCO9GaQMjQq9"
- },
- "source": [
- "## Download Llama3.1-8B Model Checkpoint from Hugging Face"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "bwzsJpHqjQq9"
- },
- "outputs": [],
- "source": [
- "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
- " # install torch for the conversion script\n",
- " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n",
- "\n",
- " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_REPO_ROOT} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n",
- " {MAXTEXT_REPO_ROOT}/configs/base.yml \\\n",
- " model_name={MODEL_NAME} \\\n",
- " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n",
- " hf_access_token={HF_TOKEN} \\\n",
- " use_multimodal=false \\\n",
- " scan_layers=true \\\n",
- " skip_jax_distributed_system=True\n",
- "\n",
- "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
- " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "imqH5KL4jQq9"
- },
- "source": [
- "## MaxText Configurations"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "In-jdp1AAwrL"
- },
- "outputs": [],
- "source": [
- "# Load configuration for SFT training\n",
- "config_argv = [\n",
- " \"\",\n",
- " f\"{MAXTEXT_REPO_ROOT}/configs/sft.yml\",\n",
- " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n",
- " f\"model_name={MODEL_NAME}\",\n",
- " \"steps=100\",\n",
- " \"per_device_batch_size=1\",\n",
- " \"max_target_length=1024\",\n",
- " \"learning_rate=2.0e-5\",\n",
- " \"weight_dtype=bfloat16\",\n",
- " \"dtype=bfloat16\",\n",
- " \"hf_path=HuggingFaceH4/ultrachat_200k\",\n",
- " f\"hf_access_token={HF_TOKEN}\",\n",
- " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n",
- " f\"run_name={RUN_NAME}\",\n",
- " f\"tokenizer_path={TOKENIZER_PATH}\",\n",
- " \"profiler=xplane\",\n",
- "]\n",
- "\n",
- "config = pyconfig.initialize(config_argv)\n",
- "\n",
- "print(\"✓ SFT configuration loaded:\")\n",
- "print(f\" Model: {config.model_name}\")\n",
- "print(f\" Training Steps: {config.steps}\")\n",
- "print(f\" Output Directory: {config.base_output_directory}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "_fdeK75ajQq9"
- },
- "source": [
- "## SFT Training"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "mgwpNgQYCJEd"
- },
- "outputs": [],
- "source": [
- "print(\"=\" * 60)\n",
- "print(\"🚀 Starting SFT Training...\")\n",
- "print(\"=\" * 60)\n",
- "\n",
- "try:\n",
- " trainer, mesh = train_sft.train(config)\n",
- "\n",
- " print(\"\\n\" + \"=\" * 60)\n",
- " print(\"✅ Training Completed Successfully!\")\n",
- " print(\"=\" * 60)\n",
- " print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n",
- "except Exception as e:\n",
- " print(\"\\n\" + \"=\" * 60)\n",
- " print(\"❌Training Failed!\")\n",
- " print(\"=\" * 60)\n",
- " print(f\"Error: {str(e)}\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "85nvNASTjQq9"
- },
- "source": [
- "## 📚 Learn More\n",
- "\n",
- "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n",
- "- **Configuration**: See `src/MaxText/configs/sft.yml` for all available options\n",
- "- **Documentation**: Check `src/MaxText/sft/sft_trainer.py` for the `sft_train` function implementation"
- ]
- }
- ],
- "metadata": {
- "accelerator": "TPU",
- "colab": {
- "gpuType": "V5E1",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/src/MaxText/examples/sft_qwen3_demo.ipynb b/src/MaxText/examples/sft_qwen3_demo.ipynb
deleted file mode 100644
index a56cb2465f..0000000000
--- a/src/MaxText/examples/sft_qwen3_demo.ipynb
+++ /dev/null
@@ -1,631 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "1nb_Ppf2ZUQL"
- },
- "source": [
- "[](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/sft_qwen3_demo.ipynb)\n",
- "\n",
- "# Qwen3-0.6B Supervised Fine-Tuning (SFT) Demo\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "FGbe4_YQZUQL"
- },
- "source": [
- "## Overview\n",
- "\n",
- "This notebook performs SFT training and evaluation workflow on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).\n",
- "The primary goal is to demonstrate the end-to-end process of:\n",
- "1. Pre-SFT Evaluation: Calcuating baseline accuracy for the model before training.\n",
- "2. SFT Training: Fine-tune the model using MaxText & Tunix SFT trainer.\n",
- "3. Post-SFT Evaluation: Re-running the evaluation loop after training to measure the performance gain achieved by SFT.\n",
- "\n",
- "This notebook can run on the **public TPU v5e-1**."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "zolxPWhQZUQL"
- },
- "source": [
- "## Prerequisites\n",
- "\n",
- "### Change Runtime Type (only if running on Google Colab)\n",
- "\n",
- "**Instructions:**\n",
- "1. Navigate to the menu at the top of the screen.\n",
- "2. Click on **Runtime**.\n",
- "3. Select **Change runtime type** from the dropdown menu.\n",
- "4. Select **v5e-1 TPU** as the **Hardware accelerator**.\n",
- "5. Click on **Save**."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Rk_QpVVuZUQL"
- },
- "source": [
- "### Get Your Hugging Face Token\n",
- "\n",
- "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n",
- "\n",
- "**Follow these steps to get your token:**\n",
- "\n",
- "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n",
- " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
- "\n",
- "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n",
- "\n",
- "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n",
- "\n",
- "4. **Copy the generated token**. You will need this in the later steps.\n",
- "\n",
- "**Follow these steps to store your token (only if running on Google Colab):**\n",
- "\n",
- "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n",
- "\n",
- "2. Click **\"+ Add new secret\"**.\n",
- "\n",
- "3. Set the Name as **HF_TOKEN**.\n",
- "\n",
- "4. Paste your token into the Value field.\n",
- "\n",
- "5. Ensure the Notebook access toggle is turned On."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "VQaxueyfjLwX"
- },
- "outputs": [],
- "source": [
- "try:\n",
- " from google.colab import userdata\n",
- " print(\"Running the notebook on Google Colab\")\n",
- " IN_COLAB = True\n",
- "except ImportError:\n",
- " print(\"Running the notebook on Visual Studio or JupyterLab\")\n",
- " IN_COLAB = False"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "D9ms-jTSZUQL"
- },
- "source": [
- "## Installation: MaxText & Other Dependencies"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "wrSXHHrGjLwX"
- },
- "source": [
- "**⚠️ Note:** The installation process in following cell may take a few minutes to complete. Please be patient."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "jYAhyzQJjLwX"
- },
- "outputs": [],
- "source": [
- "try:\n",
- " import google.colab\n",
- " print(\"Running the notebook in Google Colab\")\n",
- " IN_COLAB = True\n",
- "except ImportError:\n",
- " print(\"Running the notebook on JupyterLab\")\n",
- " IN_COLAB = False"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "OSPRVbi7n6tB"
- },
- "outputs": [],
- "source": [
- "if IN_COLAB:\n",
- " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
- " %cd /content/maxtext\n",
- "\n",
- " # Install uv, a fast Python package installer\n",
- " !pip install uv\n",
- "\n",
- " # Install MaxText and its dependencies\n",
- " !uv pip install -e .[tpu] --resolution=lowest\n",
- " !python3 -m MaxText.install_maxtext_extra_deps\n",
- "\n",
- " # Install Tunix for post-training notebooks\n",
- " !uv pip install git+https://github.com/google/tunix\n",
- " \n",
- " # Install vllm for post-training notebooks\n",
- " !git clone https://github.com/vllm-project/vllm.git\n",
- " !VLLM_TARGET_DEVICE=\"tpu\" uv pip install ./vllm\n",
- "\n",
- " # Install tpu-inference for post-training notebooks\n",
- " !git clone https://github.com/vllm-project/tpu-inference.git\n",
- " !uv pip install ./tpu-inference\n",
- " !uv pip install --no-deps qwix==0.1.4"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ywtealAxZUQM"
- },
- "source": [
- "### Restart Session (only if running on Google Colab)\n",
- "To apply certain changes, you need to restart the session.\n",
- "\n",
- "**Instructions:**\n",
- "1. Navigate to the menu at the top of the screen.\n",
- "2. Click on **Runtime**.\n",
- "3. Select **Restart session** from the dropdown menu.\n",
- "\n",
- "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Clexf-j7ZUQM"
- },
- "source": [
- "## Imports"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "PkBI9A3JZUQM"
- },
- "outputs": [],
- "source": [
- "import jax\n",
- "import os\n",
- "import sys\n",
- "import transformers\n",
- "\n",
- "import MaxText\n",
- "from MaxText import pyconfig\n",
- "from MaxText.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n",
- "from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n",
- "from maxtext.trainers.post_train.sft import train_sft\n",
- "\n",
- "# Suppress vLLM logging with a severity level below ERROR\n",
- "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n",
- "from tunix.rl.rollout import base_rollout\n",
- "from tunix.rl.rollout.vllm_rollout import VllmRollout\n",
- "\n",
- "from datetime import datetime\n",
- "from flax import nnx\n",
- "from huggingface_hub import login\n",
- "\n",
- "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n",
- "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "0nXfUgcWjLwX"
- },
- "outputs": [],
- "source": [
- "if not jax.distributed.is_initialized():\n",
- " jax.distributed.initialize()\n",
- "print(f\"JAX version: {jax.__version__}\")\n",
- "print(f\"JAX devices: {jax.devices()}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "JBbPN-uVZUQM"
- },
- "outputs": [],
- "source": [
- "try:\n",
- " from google.colab import userdata\n",
- " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n",
- "except ImportError:\n",
- " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n",
- "\n",
- "# If not found in the environment, prompt the user for input securely\n",
- "# getpass function ensures the token is hidden while you type\n",
- "if not HF_TOKEN:\n",
- " from getpass import getpass\n",
- " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n",
- "\n",
- "if HF_TOKEN:\n",
- " login(token=HF_TOKEN)\n",
- " print(\"Authenticated with Hugging Face successfully!\")\n",
- "else:\n",
- " print(\"Authentication failed: Hugging Face token is not set.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "aENuzm9iZUQM"
- },
- "source": [
- "## Model Configurations"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "RjPYYl3zZUQM"
- },
- "outputs": [],
- "source": [
- "MODEL_NAME = \"qwen3-0.6b\"\n",
- "TOKENIZER_PATH = \"Qwen/Qwen3-0.6B\"\n",
- "tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
- " TOKENIZER_PATH,\n",
- " token=HF_TOKEN,\n",
- ")\n",
- "\n",
- "# set the path to the model checkpoint or leave empty to download from HuggingFace\n",
- "MODEL_CHECKPOINT_PATH = \"\"\n",
- "if not MODEL_CHECKPOINT_PATH:\n",
- " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_REPO_ROOT}/qwen_checkpoint\"\n",
- " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n",
- " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n",
- "\n",
- "\n",
- "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%m-%S\")\n",
- "\n",
- "# This is the directory where the fine-tuned model checkpoint will be saved\n",
- "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_REPO_ROOT}/maxtext_qwen06_output\""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "4L37Ij4NZUQM"
- },
- "source": [
- "## Download Qwen3-0.6B Model Checkpoint from Hugging Face"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "kJanDAc0ZUQM"
- },
- "outputs": [],
- "source": [
- "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
- " # install torch for the conversion script\n",
- " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n",
- "\n",
- " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_REPO_ROOT} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n",
- " {MAXTEXT_REPO_ROOT}/configs/base.yml \\\n",
- " model_name={MODEL_NAME} \\\n",
- " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n",
- " hf_access_token={HF_TOKEN} \\\n",
- " use_multimodal=false \\\n",
- " scan_layers=true \\\n",
- " skip_jax_distributed_system=True\n",
- "\n",
- "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
- " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "PC-hILG0ZUQM"
- },
- "source": [
- "## Dataset Configurations"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "O3MLdr9kZUQM"
- },
- "outputs": [],
- "source": [
- "DATASET_NAME = \"openai/gsm8k\"\n",
- "TRAIN_DATA_SPLIT = \"train\"\n",
- "TEST_DATA_SPLIT = \"test\"\n",
- "HF_DATA_DIR = \"main\"\n",
- "TRAIN_DATA_COLUMNS = [\"question\", \"answer\"]\n",
- "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/math_qa.json\"\n",
- "if not os.path.exists(CHAT_TEMPLATE_PATH):\n",
- " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n",
- "NUM_TEST_SAMPLES = 20 # Total number of samples to test\n",
- "BATCH_SIZE = 1 # Number of test samples to process in a batch"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "yeAHmxSYZUQM"
- },
- "source": [
- "## MaxText Configurations"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "In-jdp1AAwrL"
- },
- "outputs": [],
- "source": [
- "%%capture\n",
- "config = pyconfig.initialize(\n",
- " [\n",
- " \"\",\n",
- " f\"{MAXTEXT_REPO_ROOT}/configs/sft.yml\",\n",
- " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n",
- " f\"model_name={MODEL_NAME}\",\n",
- " f\"hf_access_token={HF_TOKEN}\",\n",
- " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n",
- " f\"run_name={RUN_NAME}\",\n",
- " f\"tokenizer_path={TOKENIZER_PATH}\",\n",
- " f\"hf_path={DATASET_NAME}\",\n",
- " f\"train_split={TRAIN_DATA_SPLIT}\",\n",
- " f\"hf_data_dir={HF_DATA_DIR}\",\n",
- " f\"train_data_columns={TRAIN_DATA_COLUMNS}\",\n",
- " \"steps=500\",\n",
- " \"per_device_batch_size=1\",\n",
- " \"max_target_length=1024\",\n",
- " \"learning_rate=3e-6\",\n",
- " \"weight_dtype=bfloat16\",\n",
- " \"dtype=bfloat16\",\n",
- " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n",
- " ]\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "O9b0GWo-ZUQM"
- },
- "source": [
- "## Initial Setup & Data Preparation"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TDqFmvUCZUQM"
- },
- "source": [
- "### Create Test Dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "wscWYxrtZUQM"
- },
- "outputs": [],
- "source": [
- "test_dataset = get_test_dataset(config, tokenizer)\n",
- "test_dataset = test_dataset[:NUM_TEST_SAMPLES]\n",
- "test_dataset = test_dataset.to_iter_dataset().batch(BATCH_SIZE, drop_remainder=True)\n",
- "TOTAL_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE\n",
- "print(\n",
- " f\"Processing {NUM_TEST_SAMPLES} examples with a batch size of {BATCH_SIZE}. This will result in {TOTAL_BATCHES} total batches for the test run.\"\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "bLSvOOEUZUQM"
- },
- "source": [
- "### Create SFT Trainer State"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "2IHsC0m6ZUQM"
- },
- "outputs": [],
- "source": [
- "trainer, mesh = train_sft.setup_trainer_state(config)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "PpKtEqzFZUQM"
- },
- "source": [
- "### Create vLLM Rollout"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "3-pf_rbqZUQM"
- },
- "outputs": [],
- "source": [
- "tunix_model = TunixMaxTextAdapter(trainer.model)\n",
- "vllm_rollout = VllmRollout(\n",
- " model=tunix_model,\n",
- " tokenizer=tokenizer,\n",
- " cache_config_or_size=1280,\n",
- " mesh=mesh,\n",
- " rollout_config=base_rollout.RolloutConfig(\n",
- " rollout_vllm_model_version=TOKENIZER_PATH,\n",
- " rollout_vllm_hbm_utilization=0.8,\n",
- " rollout_vllm_init_with_random_weights=True,\n",
- " rollout_vllm_tpu_backend_type=\"jax\",\n",
- " ),\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "567gTxsEZUQM"
- },
- "source": [
- "## Evaluation before SFT Training"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "OnACa3zCZUQM"
- },
- "outputs": [],
- "source": [
- "print(\"Running Pre-SFT Evaluation...\")\n",
- "score = evaluate_model(test_dataset, vllm_rollout, debug=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "u5-M4iYkZUQN"
- },
- "outputs": [],
- "source": [
- "print(\"========================= Score for PRE-SFT Evaluation =========================\")\n",
- "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n",
- "print(\n",
- " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n",
- ")\n",
- "print(\n",
- " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "EJE1ookSAzz-"
- },
- "source": [
- "## SFT Training"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "editable": true,
- "id": "mgwpNgQYCJEd",
- "tags": []
- },
- "outputs": [],
- "source": [
- "print(\"Starting SFT Training...\")\n",
- "trainer = train_sft.train_model(config, trainer, mesh)\n",
- "print(\"SFT Training Complete!\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "WEdNYRhwZUQN"
- },
- "source": [
- "## Evaluation after SFT Training"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "XcsZacZdZUQN"
- },
- "outputs": [],
- "source": [
- "print(\"Running Post-SFT Evaluation...\")\n",
- "model = TunixMaxTextAdapter(trainer.model)\n",
- "state = nnx.state(model)\n",
- "vllm_rollout.update_params(state)\n",
- "score = evaluate_model(test_dataset, vllm_rollout, debug=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "editable": true,
- "id": "-JtYTPvJZUQN",
- "tags": []
- },
- "outputs": [],
- "source": [
- "print(\"========================= Score for POST-SFT Evaluation =========================\")\n",
- "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n",
- "print(\n",
- " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n",
- ")\n",
- "print(\n",
- " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n",
- ")"
- ]
- }
- ],
- "metadata": {
- "accelerator": "TPU",
- "colab": {
- "gpuType": "V5E1",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.11"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/src/MaxText/examples/chat_templates/gsm8k_rl.json b/src/maxtext/examples/chat_templates/gsm8k_rl.json
similarity index 100%
rename from src/MaxText/examples/chat_templates/gsm8k_rl.json
rename to src/maxtext/examples/chat_templates/gsm8k_rl.json
diff --git a/src/MaxText/examples/chat_templates/math_qa.json b/src/maxtext/examples/chat_templates/math_qa.json
similarity index 100%
rename from src/MaxText/examples/chat_templates/math_qa.json
rename to src/maxtext/examples/chat_templates/math_qa.json
diff --git a/src/MaxText/examples/demo_decoding.ipynb b/src/maxtext/examples/demo_decoding.ipynb
similarity index 99%
rename from src/MaxText/examples/demo_decoding.ipynb
rename to src/maxtext/examples/demo_decoding.ipynb
index 99757c569c..4698a260a0 100644
--- a/src/MaxText/examples/demo_decoding.ipynb
+++ b/src/maxtext/examples/demo_decoding.ipynb
@@ -5,7 +5,7 @@
"id": "e017d77b",
"metadata": {},
"source": [
- "[](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/demo_decoding.ipynb)\n",
+ "[](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb)\n",
" \n",
"# Qwen3-0.6B Decoding Demo"
]
diff --git a/src/MaxText/examples/multimodal_gemma3_demo.ipynb b/src/maxtext/examples/multimodal_gemma3_demo.ipynb
similarity index 99%
rename from src/MaxText/examples/multimodal_gemma3_demo.ipynb
rename to src/maxtext/examples/multimodal_gemma3_demo.ipynb
index 8410cd6424..4df0157314 100644
--- a/src/MaxText/examples/multimodal_gemma3_demo.ipynb
+++ b/src/maxtext/examples/multimodal_gemma3_demo.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "[](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb)\n",
+ "[](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/multimodal_gemma3_demo.ipynb)\n",
"\n",
"# Gemma3 Multimodal Inference/Training Demo"
]
diff --git a/pedagogical_examples/non_spmd.py b/src/maxtext/examples/non_spmd.py
similarity index 100%
rename from pedagogical_examples/non_spmd.py
rename to src/maxtext/examples/non_spmd.py
diff --git a/src/maxtext/examples/rl_llama3_demo.ipynb b/src/maxtext/examples/rl_llama3_demo.ipynb
new file mode 100644
index 0000000000..4eacf0b34a
--- /dev/null
+++ b/src/maxtext/examples/rl_llama3_demo.ipynb
@@ -0,0 +1,372 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Llama3.1-8B-Instruct Reinforcement Learning Demo\n",
+ "\n",
+ "This notebook demonstrates training on Llama3.1-8B-Instruct model with either GRPO (Group Relative Policy Optimization) or GSPO (Group Sequence Policy Optimization).\n",
+ "\n",
+ "This notebook can run on **TPU v5e-8** or **v5p-8**.\n",
+ "\n",
+ "## What is GRPO/GSPO?\n",
+ "\n",
+ "GRPO/GSPO is an RL algorithm that enhances reasoning abilities of LLMs by:\n",
+ "1. Generating multiple responses for each prompt\n",
+ "2. Evaluating responses using reward models \n",
+ "3. Calculating relative advantages to update the policy\n",
+ "\n",
+ "The difference is in the loss function - either it's optimizing each token (GRPO) or the whole sequence(GSPO)."
+ ]
+ },
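+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a rough illustration only (not MaxText's actual implementation), the group-relative advantage at the heart of both algorithms can be sketched as follows, where `rewards` holds the scores for one group of sampled responses to the same prompt:\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "\n",
+    "def group_relative_advantages(rewards, eps=1e-6):\n",
+    "    # Normalize each response's reward against its own group's statistics\n",
+    "    rewards = np.asarray(rewards, dtype=np.float32)\n",
+    "    return (rewards - rewards.mean()) / (rewards.std() + eps)\n",
+    "\n",
+    "# e.g. four sampled responses for one prompt\n",
+    "print(group_relative_advantages([1.0, 0.0, 0.5, 1.0]))\n",
+    "```"
+   ]
+  },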
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Prerequisites\n",
+ "\n",
+ "### Change Runtime Type (only if running on Google Colab)\n",
+ "\n",
+ "**Instructions:**\n",
+ "1. Navigate to the menu at the top of the screen.\n",
+ "2. Click on **Runtime**.\n",
+ "3. Select **Change runtime type** from the dropdown menu.\n",
+ "4. Select **v5e-8** or **v5p-8 TPU** as the **Hardware accelerator**.\n",
+ "5. Click on **Save**."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Get Your Hugging Face Token\n",
+ "\n",
+ "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n",
+ "\n",
+ "**Follow these steps to get your token:**\n",
+ "\n",
+ "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n",
+ " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
+ "\n",
+ "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n",
+ "\n",
+ "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n",
+ "\n",
+ "4. **Copy the generated token**. You will need this in the later steps.\n",
+ "\n",
+ "**Follow these steps to store your token (only if running on Google Colab):**\n",
+ "\n",
+ "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n",
+ "\n",
+ "2. Click **\"+ Add new secret\"**.\n",
+ "\n",
+ "3. Set the Name as **HF_TOKEN**.\n",
+ "\n",
+ "4. Paste your token into the Value field.\n",
+ "\n",
+ "5. Ensure the Notebook access toggle is turned On."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "try:\n",
+ " from google.colab import userdata\n",
+ " print(\"Running the notebook on Google Colab\")\n",
+ " IN_COLAB = True\n",
+ "except ImportError:\n",
+ " print(\"Running the notebook on Visual Studio or JupyterLab\")\n",
+ " IN_COLAB = False"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Installation: MaxText and Dependencies\n",
+ "\n",
+ "**⚠️ Note:** The installation process in following cell may take a few minutes to complete. Please be patient."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if IN_COLAB:\n",
+ " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
+ " %cd /content/maxtext\n",
+ "\n",
+ " # Install uv, a fast Python package installer\n",
+ " !pip install uv\n",
+ "\n",
+ " # Install MaxText and its dependencies\n",
+ " !uv pip install -e .[tpu] --resolution=lowest\n",
+ " !python3 -m MaxText.install_maxtext_extra_deps\n",
+ "\n",
+ " # Install vLLM for Jax and TPUs\n",
+ " !uv pip install vllm-tpu\n",
+ " !uv pip install --no-deps qwix==0.1.4"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Restart Session (only if running on Google Colab)\n",
+ "To apply certain changes, you need to restart the session.\n",
+ "\n",
+ "**Instructions:**\n",
+ "1. Navigate to the menu at the top of the screen.\n",
+ "2. Click on **Runtime**.\n",
+ "3. Select **Restart session** from the dropdown menu.\n",
+ "\n",
+ "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Environment Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import datetime\n",
+ "import os\n",
+ "import sys\n",
+ "from pathlib import Path\n",
+ "import MaxText\n",
+ "from huggingface_hub import login\n",
+ "import jax\n",
+ "\n",
+ "from MaxText import max_utils\n",
+ "from MaxText.rl.train_rl import rl_train, setup_configs_and_devices\n",
+ "\n",
+ "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n",
+ "os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n",
+ "\n",
+ "MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n",
+ "MAXTEXT_REPO_ROOT = os.sep.join([\"maxtext\" if p == \"MaxText\" else p for p in MAXTEXT_PKG_DIR.split(os.sep)])\n",
+ "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if not jax.distributed.is_initialized():\n",
+ " jax.distributed.initialize()\n",
+ "print(f\"JAX version: {jax.__version__}\")\n",
+ "print(f\"JAX devices: {jax.devices()}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if IN_COLAB:\n",
+ " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n",
+ "else:\n",
+ " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n",
+ "\n",
+ "# If not found in the environment, prompt the user for input securely\n",
+ "# getpass function ensures the token is hidden while you type\n",
+ "if not HF_TOKEN:\n",
+ " from getpass import getpass\n",
+ " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n",
+ "\n",
+ "if HF_TOKEN:\n",
+ " os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
+ " login(token=HF_TOKEN)\n",
+ " print(\"Authenticated with Hugging Face successfully!\")\n",
+ "else:\n",
+ " print(\"Authentication failed: Hugging Face token is not set.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_NAME = \"llama3.1-8b\"\n",
+ "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
+ "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
+ "LOSS_ALGO=\"grpo\" # or \"gspo-token\" if you want to use GSPO\n",
+ "\n",
+ "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n",
+ "if not os.path.exists(CHAT_TEMPLATE_PATH):\n",
+ " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n",
+ "\n",
+ "# set the path to the model checkpoint or leave empty to download from HuggingFace\n",
+ "MODEL_CHECKPOINT_PATH = \"\"\n",
+ "if not MODEL_CHECKPOINT_PATH:\n",
+ " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_PKG_DIR}/llama_checkpoint\"\n",
+ " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n",
+ " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n",
+ " \n",
+ "OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/rl_llama3_output\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Download Llama3.1-8B Model Checkpoint from Hugging Face"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
+ " # install torch for the conversion script\n",
+ " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n",
+ "\n",
+ " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n",
+ " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n",
+ " model_name={MODEL_NAME} \\\n",
+ " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n",
+ " hf_access_token={HF_TOKEN} \\\n",
+ " use_multimodal=false \\\n",
+ " scan_layers=true \\\n",
+ " skip_jax_distributed_system=True\n",
+ "\n",
+ "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
+ " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## MaxText Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load configuration for RL training\n",
+ "config_argv = [\n",
+ " \"\",\n",
+ " f\"{MAXTEXT_PKG_DIR}/configs/rl.yml\",\n",
+ " f\"model_name={MODEL_NAME}\",\n",
+ " f\"tokenizer_path={TOKENIZER_PATH}\",\n",
+ " f\"run_name={RUN_NAME}\",\n",
+ " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n",
+ " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n",
+ " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n",
+ " f\"hf_access_token={HF_TOKEN}\",\n",
+ " \"debug.rl=False\",\n",
+ " f\"rl.loss_algo={LOSS_ALGO}\",\n",
+ " \"use_pathways=False\"\n",
+ "]\n",
+ "\n",
+ "trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)\n",
+ "\n",
+ "rl_train_steps = int(\n",
+ " trainer_config.num_batches\n",
+ " * trainer_config.rl.num_iterations\n",
+ " * trainer_config.train_fraction\n",
+ " * trainer_config.num_epoch\n",
+ ")\n",
+ "\n",
+ "print(\"✓ Configuration initialized successfully\")\n",
+ "print(f\"📁 Output directory: {trainer_config.base_output_directory}\")\n",
+ "print(f\"🤖 Model: {trainer_config.model_name}\")\n",
+ "print(f\"📊 RL Train Steps: {rl_train_steps}\")"
+ ]
+ },
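+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a worked example with purely hypothetical values: if the resolved config had `num_batches=10`, `rl.num_iterations=1`, `train_fraction=1.0`, and `num_epoch=1`, the cell above would report `10 * 1 * 1.0 * 1 = 10` RL training steps. The actual numbers come from `rl.yml` and the dataset."
+   ]
+  },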
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## RL Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"\\n\" + \"=\" * 80)\n",
+ "print(f\"🚀 Starting {LOSS_ALGO} Training...\")\n",
+ "print(\"=\" * 80)\n",
+ "try:\n",
+ " rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)\n",
+ " print(\"\\n\" + \"=\" * 80)\n",
+ " print(\"✅ Training Completed Successfully!\")\n",
+ " print(f\"✍️ Note the improved evaluation accuracy metrics with just {rl_train_steps} RL training steps!\")\n",
+ " print(\"=\" * 80)\n",
+ " print(f\"📁 Checkpoints saved to: {trainer_config.checkpoint_dir}\")\n",
+ " print(f\"📊 TensorBoard logs: {trainer_config.tensorboard_dir}\")\n",
+ " print(f\"🎯 Model ready for inference!\")\n",
+ "except Exception as e:\n",
+ " print(\"\\n\" + \"=\" * 80)\n",
+ " print(\"❌Training Failed!\")\n",
+ " print(\"=\" * 80)\n",
+ " print(f\"Error: {str(e)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 📚 Learn More\n",
+ "\n",
+ "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/rl.html\n",
+ "- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n",
+ "- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/maxtext/examples/sft_llama3_demo.ipynb b/src/maxtext/examples/sft_llama3_demo.ipynb
new file mode 100644
index 0000000000..0b7dd227ce
--- /dev/null
+++ b/src/maxtext/examples/sft_llama3_demo.ipynb
@@ -0,0 +1,367 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb)\n",
+ "\n",
+ "# Llama3.1-8B-Instruct Supervised Fine-Tuning (SFT) Demo\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Overview\n",
+ "\n",
+ "This notebook demonstrates how to perform Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct using the Hugging Face ultrachat_200k dataset with MaxText and Tunix integration for efficient training.\n",
+ "\n",
+ "This notebook can run on **TPU v5e-8** or **v5p-8**."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Prerequisites\n",
+ "\n",
+ "### Change Runtime Type (only if running on Google Colab)\n",
+ "\n",
+ "**Instructions:**\n",
+ "1. Navigate to the menu at the top of the screen.\n",
+ "2. Click on **Runtime**.\n",
+ "3. Select **Change runtime type** from the dropdown menu.\n",
+ "4. Select **v5e-8** or **v5p-8 TPU** as the **Hardware accelerator**.\n",
+ "5. Click on **Save**."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Get Your Hugging Face Token\n",
+ "\n",
+ "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n",
+ "\n",
+ "**Follow these steps to get your token:**\n",
+ "\n",
+ "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n",
+ " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
+ "\n",
+ "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n",
+ "\n",
+ "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n",
+ "\n",
+ "4. **Copy the generated token**. You will need this in the later steps.\n",
+ "\n",
+ "**Follow these steps to store your token (only if running on Google Colab):**\n",
+ "\n",
+ "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n",
+ "\n",
+ "2. Click **\"+ Add new secret\"**.\n",
+ "\n",
+ "3. Set the Name as **HF_TOKEN**.\n",
+ "\n",
+ "4. Paste your token into the Value field.\n",
+ "\n",
+ "5. Ensure the Notebook access toggle is turned On."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "try:\n",
+ " from google.colab import userdata\n",
+ " print(\"Running the notebook on Google Colab\")\n",
+ " IN_COLAB = True\n",
+ "except ImportError:\n",
+ " print(\"Running the notebook on Visual Studio or JupyterLab\")\n",
+ " IN_COLAB = False"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Installation: MaxText & Other Dependencies"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**⚠️ Note:** The installation process in following cell may take a few minutes to complete. Please be patient."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if IN_COLAB:\n",
+ " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
+ " %cd /content/maxtext\n",
+ "\n",
+ " # Install uv, a fast Python package installer\n",
+ " !pip install uv\n",
+ "\n",
+ " # Install MaxText and its dependencies\n",
+ " !uv pip install -e .[tpu] --resolution=lowest\n",
+ " !python3 -m MaxText.install_maxtext_extra_deps"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Restart Session (only if running on Google Colab)\n",
+ "To apply certain changes, you need to restart the session.\n",
+ "\n",
+ "**Instructions:**\n",
+ "1. Navigate to the menu at the top of the screen.\n",
+ "2. Click on **Runtime**.\n",
+ "3. Select **Restart session** from the dropdown menu.\n",
+ "\n",
+ "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Environment Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import datetime\n",
+ "import os\n",
+ "import subprocess\n",
+ "import sys\n",
+ "import MaxText\n",
+ "from MaxText import pyconfig\n",
+ "from maxtext.trainers.post_train.sft import train_sft\n",
+ "import jax\n",
+ "from huggingface_hub import login\n",
+ "\n",
+ "MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n",
+ "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if not jax.distributed.is_initialized():\n",
+ " jax.distributed.initialize()\n",
+ "print(f\"JAX version: {jax.__version__}\")\n",
+ "print(f\"JAX devices: {jax.devices()}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if IN_COLAB:\n",
+ " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n",
+ "else:\n",
+ " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n",
+ "\n",
+ "# If not found in the environment, prompt the user for input securely\n",
+ "# getpass function ensures the token is hidden while you type\n",
+ "if not HF_TOKEN:\n",
+ " from getpass import getpass\n",
+ " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n",
+ "\n",
+ "if HF_TOKEN:\n",
+ " login(token=HF_TOKEN)\n",
+ " print(\"Authenticated with Hugging Face successfully!\")\n",
+ "else:\n",
+ " print(\"Authentication failed: Hugging Face token is not set.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_NAME = \"llama3.1-8b\"\n",
+ "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n",
+ "\n",
+ "# set the path to the model checkpoint or leave empty to download from HuggingFace\n",
+ "MODEL_CHECKPOINT_PATH = \"\"\n",
+ "if not MODEL_CHECKPOINT_PATH:\n",
+ " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_PKG_DIR}/llama_checkpoint\"\n",
+ " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n",
+ " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n",
+ "\n",
+ "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/sft_llama3_output\"\n",
+ "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Download Llama3.1-8B Model Checkpoint from Hugging Face"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
+ " # install torch for the conversion script\n",
+ " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n",
+ "\n",
+ " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n",
+ " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n",
+ " model_name={MODEL_NAME} \\\n",
+ " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n",
+ " hf_access_token={HF_TOKEN} \\\n",
+ " use_multimodal=false \\\n",
+ " scan_layers=true \\\n",
+ " skip_jax_distributed_system=True\n",
+ "\n",
+ "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
+ " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## MaxText Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "In-jdp1AAwrL"
+ },
+ "outputs": [],
+ "source": [
+ "# Load configuration for SFT training\n",
+ "config_argv = [\n",
+ " \"\",\n",
+ " f\"{MAXTEXT_PKG_DIR}/configs/sft.yml\",\n",
+ " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n",
+ " f\"model_name={MODEL_NAME}\",\n",
+ " \"steps=100\",\n",
+ " \"per_device_batch_size=1\",\n",
+ " \"max_target_length=1024\",\n",
+ " \"learning_rate=2.0e-5\",\n",
+ " \"weight_dtype=bfloat16\",\n",
+ " \"dtype=bfloat16\",\n",
+ " \"hf_path=HuggingFaceH4/ultrachat_200k\",\n",
+ " f\"hf_access_token={HF_TOKEN}\",\n",
+ " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n",
+ " f\"run_name={RUN_NAME}\",\n",
+ " f\"tokenizer_path={TOKENIZER_PATH}\",\n",
+ " \"profiler=xplane\",\n",
+ "]\n",
+ "\n",
+ "config = pyconfig.initialize(config_argv)\n",
+ "\n",
+ "print(\"✓ SFT configuration loaded:\")\n",
+ "print(f\" Model: {config.model_name}\")\n",
+ "print(f\" Training Steps: {config.steps}\")\n",
+ "print(f\" Output Directory: {config.base_output_directory}\")"
+ ]
+ },
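+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick sanity check on these numbers: with `per_device_batch_size=1`, the global batch size equals the number of JAX devices (e.g. 8 on a v5e-8), so `steps=100` processes roughly `8 * 100 = 800` training examples on that topology."
+   ]
+  },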
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## SFT Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mgwpNgQYCJEd"
+ },
+ "outputs": [],
+ "source": [
+ "print(\"=\" * 60)\n",
+ "print(\"🚀 Starting SFT Training...\")\n",
+ "print(\"=\" * 60)\n",
+ "\n",
+ "try:\n",
+ " trainer, mesh = train_sft.train(config)\n",
+ "\n",
+ " print(\"\\n\" + \"=\" * 60)\n",
+ " print(\"✅ Training Completed Successfully!\")\n",
+ " print(\"=\" * 60)\n",
+ " print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n",
+ "except Exception as e:\n",
+ " print(\"\\n\" + \"=\" * 60)\n",
+ " print(\"❌Training Failed!\")\n",
+ " print(\"=\" * 60)\n",
+ " print(f\"Error: {str(e)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 📚 Learn More\n",
+ "\n",
+ "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n",
+ "- **Configuration**: See `src/MaxText/configs/sft.yml` for all available options\n",
+ "- **Documentation**: Check `src/MaxText/sft/sft_trainer.py` for the `sft_train` function implementation"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "TPU",
+ "colab": {
+ "gpuType": "V5E1",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/src/maxtext/examples/sft_qwen3_demo.ipynb b/src/maxtext/examples/sft_qwen3_demo.ipynb
new file mode 100644
index 0000000000..bd44c741c0
--- /dev/null
+++ b/src/maxtext/examples/sft_qwen3_demo.ipynb
@@ -0,0 +1,624 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1nb_Ppf2ZUQL"
+ },
+ "source": [
+ "[](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/sft_qwen3_demo.ipynb)\n",
+ "\n",
+ "# Qwen3-0.6B Supervised Fine-Tuning (SFT) Demo\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FGbe4_YQZUQL"
+ },
+ "source": [
+ "## Overview\n",
+ "\n",
+ "This notebook performs SFT training and evaluation workflow on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).\n",
+ "The primary goal is to demonstrate the end-to-end process of:\n",
+ "1. Pre-SFT Evaluation: Calcuating baseline accuracy for the model before training.\n",
+ "2. SFT Training: Fine-tune the model using MaxText & Tunix SFT trainer.\n",
+ "3. Post-SFT Evaluation: Re-running the evaluation loop after training to measure the performance gain achieved by SFT.\n",
+ "\n",
+ "This notebook can run on the **public TPU v5e-1**."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zolxPWhQZUQL"
+ },
+ "source": [
+ "## Prerequisites\n",
+ "\n",
+ "### Change Runtime Type (only if running on Google Colab)\n",
+ "\n",
+ "**Instructions:**\n",
+ "1. Navigate to the menu at the top of the screen.\n",
+ "2. Click on **Runtime**.\n",
+ "3. Select **Change runtime type** from the dropdown menu.\n",
+ "4. Select **v5e-1 TPU** as the **Hardware accelerator**.\n",
+ "5. Click on **Save**."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Rk_QpVVuZUQL"
+ },
+ "source": [
+ "### Get Your Hugging Face Token\n",
+ "\n",
+ "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n",
+ "\n",
+ "**Follow these steps to get your token:**\n",
+ "\n",
+ "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n",
+ " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n",
+ "\n",
+ "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n",
+ "\n",
+ "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n",
+ "\n",
+ "4. **Copy the generated token**. You will need this in the later steps.\n",
+ "\n",
+ "**Follow these steps to store your token (only if running on Google Colab):**\n",
+ "\n",
+ "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n",
+ "\n",
+ "2. Click **\"+ Add new secret\"**.\n",
+ "\n",
+ "3. Set the Name as **HF_TOKEN**.\n",
+ "\n",
+ "4. Paste your token into the Value field.\n",
+ "\n",
+ "5. Ensure the Notebook access toggle is turned On."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "try:\n",
+ " from google.colab import userdata\n",
+ " print(\"Running the notebook on Google Colab\")\n",
+ " IN_COLAB = True\n",
+ "except ImportError:\n",
+ " print(\"Running the notebook on Visual Studio or JupyterLab\")\n",
+ " IN_COLAB = False"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "D9ms-jTSZUQL"
+ },
+ "source": [
+ "## Installation: MaxText & Other Dependencies"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**⚠️ Note:** The installation process in following cell may take a few minutes to complete. Please be patient."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OSPRVbi7n6tB"
+ },
+ "outputs": [],
+ "source": [
+ "if IN_COLAB:\n",
+ " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
+ " %cd /content/maxtext\n",
+ "\n",
+ " # Install uv, a fast Python package installer\n",
+ " !pip install uv\n",
+ "\n",
+ " # Install MaxText and its dependencies\n",
+ " !uv pip install -e .[tpu] --resolution=lowest\n",
+ " !python3 -m MaxText.install_maxtext_extra_deps\n",
+ "\n",
+ " # Install Tunix for post-training notebooks\n",
+ " !uv pip install git+https://github.com/google/tunix\n",
+ " \n",
+ " # Install vllm for post-training notebooks\n",
+ " !git clone https://github.com/vllm-project/vllm.git\n",
+ " !VLLM_TARGET_DEVICE=\"tpu\" uv pip install ./vllm\n",
+ "\n",
+ " # Install tpu-inference for post-training notebooks\n",
+ " !git clone https://github.com/vllm-project/tpu-inference.git\n",
+ " !uv pip install ./tpu-inference\n",
+ " !uv pip install --no-deps qwix==0.1.4"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ywtealAxZUQM"
+ },
+ "source": [
+ "### Restart Session (only if running on Google Colab)\n",
+ "To apply certain changes, you need to restart the session.\n",
+ "\n",
+ "**Instructions:**\n",
+ "1. Navigate to the menu at the top of the screen.\n",
+ "2. Click on **Runtime**.\n",
+ "3. Select **Restart session** from the dropdown menu.\n",
+ "\n",
+ "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Clexf-j7ZUQM"
+ },
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "PkBI9A3JZUQM"
+ },
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import os\n",
+ "import sys\n",
+ "import transformers\n",
+ "\n",
+ "import MaxText\n",
+ "from MaxText import pyconfig\n",
+ "from maxtext.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n",
+ "from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n",
+ "from maxtext.trainers.post_train.sft import train_sft\n",
+ "\n",
+ "# Suppress vLLM logging with a severity level below ERROR\n",
+ "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n",
+ "from tunix.rl.rollout import base_rollout\n",
+ "from tunix.rl.rollout.vllm_rollout import VllmRollout\n",
+ "\n",
+ "from datetime import datetime\n",
+ "from flax import nnx\n",
+ "from huggingface_hub import login\n",
+ "\n",
+ "MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n",
+ "MAXTEXT_REPO_ROOT = os.sep.join([\"maxtext\" if p == \"MaxText\" else p for p in MAXTEXT_PKG_DIR.split(os.sep)])\n",
+ "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if not jax.distributed.is_initialized():\n",
+ " jax.distributed.initialize()\n",
+ "print(f\"JAX version: {jax.__version__}\")\n",
+ "print(f\"JAX devices: {jax.devices()}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JBbPN-uVZUQM"
+ },
+ "outputs": [],
+ "source": [
+ "try:\n",
+ " from google.colab import userdata\n",
+ " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n",
+ "except ImportError:\n",
+ " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n",
+ "\n",
+ "# If not found in the environment, prompt the user for input securely\n",
+ "# getpass function ensures the token is hidden while you type\n",
+ "if not HF_TOKEN:\n",
+ " from getpass import getpass\n",
+ " HF_TOKEN = getpass(\"Hugging Face token not found in environment. Please enter it here: \")\n",
+ "\n",
+ "if HF_TOKEN:\n",
+ " login(token=HF_TOKEN)\n",
+ " print(\"Authenticated with Hugging Face successfully!\")\n",
+ "else:\n",
+ " print(\"Authentication failed: Hugging Face token is not set.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aENuzm9iZUQM"
+ },
+ "source": [
+ "## Model Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RjPYYl3zZUQM"
+ },
+ "outputs": [],
+ "source": [
+ "MODEL_NAME = \"qwen3-0.6b\"\n",
+ "TOKENIZER_PATH = \"Qwen/Qwen3-0.6B\"\n",
+ "tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
+ " TOKENIZER_PATH,\n",
+ " token=HF_TOKEN,\n",
+ ")\n",
+ "\n",
+ "# set the path to the model checkpoint or leave empty to download from HuggingFace\n",
+ "MODEL_CHECKPOINT_PATH = \"\"\n",
+ "if not MODEL_CHECKPOINT_PATH:\n",
+ " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_PKG_DIR}/qwen_checkpoint\"\n",
+ " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n",
+ " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n",
+ "\n",
+ "\n",
+ "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%m-%S\")\n",
+ "\n",
+ "# This is the directory where the fine-tuned model checkpoint will be saved\n",
+ "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/maxtext_qwen06_output\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4L37Ij4NZUQM"
+ },
+ "source": [
+ "## Download Qwen3-0.6B Model Checkpoint from Hugging Face"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "kJanDAc0ZUQM"
+ },
+ "outputs": [],
+ "source": [
+ "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
+ " # install torch for the conversion script\n",
+ " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n",
+ "\n",
+ " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n",
+ " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n",
+ " model_name={MODEL_NAME} \\\n",
+ " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n",
+ " hf_access_token={HF_TOKEN} \\\n",
+ " use_multimodal=false \\\n",
+ " scan_layers=true \\\n",
+ " skip_jax_distributed_system=True\n",
+ "\n",
+ "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n",
+ " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PC-hILG0ZUQM"
+ },
+ "source": [
+ "## Dataset Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "O3MLdr9kZUQM"
+ },
+ "outputs": [],
+ "source": [
+ "DATASET_NAME = \"openai/gsm8k\"\n",
+ "TRAIN_DATA_SPLIT = \"train\"\n",
+ "TEST_DATA_SPLIT = \"test\"\n",
+ "HF_DATA_DIR = \"main\"\n",
+ "TRAIN_DATA_COLUMNS = [\"question\", \"answer\"]\n",
+ "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/math_qa.json\"\n",
+ "if not os.path.exists(CHAT_TEMPLATE_PATH):\n",
+ " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n",
+ "NUM_TEST_SAMPLES = 20 # Total number of samples to test\n",
+ "BATCH_SIZE = 1 # Number of test samples to process in a batch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yeAHmxSYZUQM"
+ },
+ "source": [
+ "## MaxText Configurations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "In-jdp1AAwrL"
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "config = pyconfig.initialize(\n",
+ " [\n",
+ " \"\",\n",
+ " f\"{MAXTEXT_PKG_DIR}/configs/sft.yml\",\n",
+ " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n",
+ " f\"model_name={MODEL_NAME}\",\n",
+ " f\"hf_access_token={HF_TOKEN}\",\n",
+ " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n",
+ " f\"run_name={RUN_NAME}\",\n",
+ " f\"tokenizer_path={TOKENIZER_PATH}\",\n",
+ " f\"hf_path={DATASET_NAME}\",\n",
+ " f\"train_split={TRAIN_DATA_SPLIT}\",\n",
+ " f\"hf_data_dir={HF_DATA_DIR}\",\n",
+ " f\"train_data_columns={TRAIN_DATA_COLUMNS}\",\n",
+ " \"steps=500\",\n",
+ " \"per_device_batch_size=1\",\n",
+ " \"max_target_length=1024\",\n",
+ " \"learning_rate=3e-6\",\n",
+ " \"weight_dtype=bfloat16\",\n",
+ " \"dtype=bfloat16\",\n",
+ " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "O9b0GWo-ZUQM"
+ },
+ "source": [
+ "## Initial Setup & Data Preparation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TDqFmvUCZUQM"
+ },
+ "source": [
+ "### Create Test Dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "wscWYxrtZUQM"
+ },
+ "outputs": [],
+ "source": [
+ "test_dataset = get_test_dataset(config, tokenizer)\n",
+ "test_dataset = test_dataset[:NUM_TEST_SAMPLES]\n",
+ "test_dataset = test_dataset.to_iter_dataset().batch(BATCH_SIZE, drop_remainder=True)\n",
+ "TOTAL_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE\n",
+ "print(\n",
+ " f\"Processing {NUM_TEST_SAMPLES} examples with a batch size of {BATCH_SIZE}. This will result in {TOTAL_BATCHES} total batches for the test run.\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bLSvOOEUZUQM"
+ },
+ "source": [
+ "### Create SFT Trainer State"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2IHsC0m6ZUQM"
+ },
+ "outputs": [],
+ "source": [
+ "trainer, mesh = train_sft.setup_trainer_state(config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PpKtEqzFZUQM"
+ },
+ "source": [
+ "### Create vLLM Rollout"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "3-pf_rbqZUQM"
+ },
+ "outputs": [],
+ "source": [
+ "tunix_model = TunixMaxTextAdapter(trainer.model)\n",
+ "vllm_rollout = VllmRollout(\n",
+ " model=tunix_model,\n",
+ " tokenizer=tokenizer,\n",
+ " cache_config_or_size=1280,\n",
+ " mesh=mesh,\n",
+ " rollout_config=base_rollout.RolloutConfig(\n",
+ " rollout_vllm_model_version=TOKENIZER_PATH,\n",
+ " rollout_vllm_hbm_utilization=0.8,\n",
+ " rollout_vllm_init_with_random_weights=True,\n",
+ " rollout_vllm_tpu_backend_type=\"jax\",\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "567gTxsEZUQM"
+ },
+ "source": [
+ "## Evaluation before SFT Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OnACa3zCZUQM"
+ },
+ "outputs": [],
+ "source": [
+ "print(\"Running Pre-SFT Evaluation...\")\n",
+ "score = evaluate_model(test_dataset, vllm_rollout, debug=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "u5-M4iYkZUQN"
+ },
+ "outputs": [],
+ "source": [
+ "print(\"========================= Score for PRE-SFT Evaluation =========================\")\n",
+ "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n",
+ "print(\n",
+ " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n",
+ ")\n",
+ "print(\n",
+ " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EJE1ookSAzz-"
+ },
+ "source": [
+ "## SFT Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "editable": true,
+ "id": "mgwpNgQYCJEd",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "print(\"Starting SFT Training...\")\n",
+ "trainer = train_sft.train_model(config, trainer, mesh)\n",
+ "print(\"SFT Training Complete!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WEdNYRhwZUQN"
+ },
+ "source": [
+ "## Evaluation after SFT Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XcsZacZdZUQN"
+ },
+ "outputs": [],
+ "source": [
+ "print(\"Running Post-SFT Evaluation...\")\n",
+ "model = TunixMaxTextAdapter(trainer.model)\n",
+ "state = nnx.state(model)\n",
+ "vllm_rollout.update_params(state)\n",
+ "score = evaluate_model(test_dataset, vllm_rollout, debug=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "editable": true,
+ "id": "-JtYTPvJZUQN",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "print(\"========================= Score for POST-SFT Evaluation =========================\")\n",
+ "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n",
+ "print(\n",
+ " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n",
+ ")\n",
+ "print(\n",
+ " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "TPU",
+ "colab": {
+ "gpuType": "V5E1",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/src/MaxText/examples/sft_train_and_evaluate.py b/src/maxtext/examples/sft_train_and_evaluate.py
similarity index 98%
rename from src/MaxText/examples/sft_train_and_evaluate.py
rename to src/maxtext/examples/sft_train_and_evaluate.py
index b38985fa38..7263169362 100644
--- a/src/MaxText/examples/sft_train_and_evaluate.py
+++ b/src/maxtext/examples/sft_train_and_evaluate.py
@@ -35,7 +35,7 @@
export MODEL_CHECKPOINT_PATH=
export HF_ACCESS_TOKEN=
-python3 -m MaxText.examples.sft_train_and_evaluate MaxText/configs/sft.yml \
+python3 -m maxtext.examples.sft_train_and_evaluate MaxText/configs/sft.yml \
run_name=$RUN_NAME base_output_directory=$OUTPUT_PATH \
model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH \
hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH
@@ -67,7 +67,7 @@
--workload=sft-${RUN_NAME} \
--tpu-type ${TPU_TYPE} --num-slices=1 --zone=${ZONE} \
--project=${PROJECT} \
---command "HF_TOKEN=$HF_ACCESS_TOKEN python3 -m MaxText.examples.sft_train_and_evaluate MaxText/configs/sft.yml \
+--command "HF_TOKEN=$HF_ACCESS_TOKEN python3 -m maxtext.examples.sft_train_and_evaluate MaxText/configs/sft.yml \
run_name=$RUN_NAME base_output_directory=$OUTPUT_PATH \
model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH \
hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH"
@@ -125,7 +125,7 @@
)
# Regex to extract the final numerical answer
MATCH_ANSWER = re.compile(rf"{ANSWER_START}.*?([\d\.\,\$]{{1,}})", flags=re.MULTILINE | re.DOTALL)
-CHAT_TEMPLATE_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "examples", "chat_templates", "math_qa.json")
+CHAT_TEMPLATE_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "examples", "chat_templates", "math_qa.json")
def get_test_dataset(config, tokenizer):
diff --git a/pedagogical_examples/shardings.py b/src/maxtext/examples/shardings.py
similarity index 100%
rename from pedagogical_examples/shardings.py
rename to src/maxtext/examples/shardings.py
diff --git a/pedagogical_examples/shmap_collective_matmul.py b/src/maxtext/examples/shmap_collective_matmul.py
similarity index 100%
rename from pedagogical_examples/shmap_collective_matmul.py
rename to src/maxtext/examples/shmap_collective_matmul.py
diff --git a/tests/integration/shmap_collective_matmul_test.py b/tests/integration/shmap_collective_matmul_test.py
index e89edfa5a7..8c966398dc 100644
--- a/tests/integration/shmap_collective_matmul_test.py
+++ b/tests/integration/shmap_collective_matmul_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Integration test for pedagogical_examples/shmap_collective_matmul.py"""
+"""Integration test for maxtext/examples/shmap_collective_matmul.py"""
import os.path
import sys
@@ -24,7 +24,7 @@
sys.path.append(os.path.join(MAXTEXT_REPO_ROOT, "pedagogical_examples"))
# Uncomment the import when b/415022795 is fixed
-# from pedagogical_examples.shmap_collective_matmul import main
+# from maxtext.examples.shmap_collective_matmul import main
@pytest.mark.skip(reason="Enable when b/415022795 is fixed")