From 5840964b40c901ade0d92d2b68ab10fec5260c44 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Thu, 22 Jan 2026 22:39:54 +0000 Subject: [PATCH] move examples --- .github/workflows/run_jupyter_notebooks.yml | 3 +- README.md | 2 +- codecov.yml | 2 +- docs/guides/run_python_notebook.md | 67 +- docs/tutorials/first_run.md | 2 +- docs/tutorials/posttraining/multimodal.md | 2 +- pedagogical_examples/__init__.py | 13 - src/MaxText/configs/rl.yml | 2 +- src/MaxText/examples/rl_llama3_demo.ipynb | 410 ------------ src/MaxText/examples/sft_llama3_demo.ipynb | 407 ----------- src/MaxText/examples/sft_qwen3_demo.ipynb | 631 ------------------ .../examples/chat_templates/gsm8k_rl.json | 0 .../examples/chat_templates/math_qa.json | 0 .../examples/demo_decoding.ipynb | 4 +- .../examples/multimodal_gemma3_demo.ipynb | 2 +- .../maxtext/examples}/non_spmd.py | 0 src/maxtext/examples/rl_llama3_demo.ipynb | 372 +++++++++++ src/maxtext/examples/sft_llama3_demo.ipynb | 367 ++++++++++ src/maxtext/examples/sft_qwen3_demo.ipynb | 624 +++++++++++++++++ .../examples/sft_train_and_evaluate.py | 6 +- .../maxtext/examples}/shardings.py | 0 .../examples}/shmap_collective_matmul.py | 0 .../shmap_collective_matmul_test.py | 4 +- 23 files changed, 1418 insertions(+), 1502 deletions(-) delete mode 100644 pedagogical_examples/__init__.py delete mode 100644 src/MaxText/examples/rl_llama3_demo.ipynb delete mode 100644 src/MaxText/examples/sft_llama3_demo.ipynb delete mode 100644 src/MaxText/examples/sft_qwen3_demo.ipynb rename src/{MaxText => maxtext}/examples/chat_templates/gsm8k_rl.json (100%) rename src/{MaxText => maxtext}/examples/chat_templates/math_qa.json (100%) rename src/{MaxText => maxtext}/examples/demo_decoding.ipynb (99%) rename src/{MaxText => maxtext}/examples/multimodal_gemma3_demo.ipynb (99%) rename {pedagogical_examples => src/maxtext/examples}/non_spmd.py (100%) create mode 100644 src/maxtext/examples/rl_llama3_demo.ipynb create mode 100644 src/maxtext/examples/sft_llama3_demo.ipynb create mode 100644 src/maxtext/examples/sft_qwen3_demo.ipynb rename src/{MaxText => maxtext}/examples/sft_train_and_evaluate.py (98%) rename {pedagogical_examples => src/maxtext/examples}/shardings.py (100%) rename {pedagogical_examples => src/maxtext/examples}/shmap_collective_matmul.py (100%) diff --git a/.github/workflows/run_jupyter_notebooks.yml b/.github/workflows/run_jupyter_notebooks.yml index 528ff15e5f..78a1ec4375 100644 --- a/.github/workflows/run_jupyter_notebooks.yml +++ b/.github/workflows/run_jupyter_notebooks.yml @@ -87,10 +87,11 @@ jobs: - name: Run Post-Training Notebooks shell: bash env: + PYTHONPATH: "${{ github.workspace }}/src" HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | MAXTEXT_REPO_ROOT=$(pwd) - MAXTEXT_NOTEBOOKS_ROOT="$MAXTEXT_REPO_ROOT/src/MaxText/examples" + MAXTEXT_NOTEBOOKS_ROOT="$MAXTEXT_REPO_ROOT/src/maxtext/examples" for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl}*.ipynb; do filename=$(basename "$notebook") diff --git a/README.md b/README.md index 5373a11360..aae9502b0b 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies * \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported. * \[December 10, 2025\] DeepSeek V3.1 is now supported. Use existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/models/deepseek3-671b.yml) and load in V3.1 checkpoint to use model. 
-* \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/examples) are available. +* \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) are available. * \[December 4, 2025\] The [ReadTheDocs documentation site](https://maxtext.readthedocs.io/en/latest/index.html) has been reorganized. * \[December 3, 2025\] Multi-host support for GSPO and GRPO is now available via [new RL tutorials](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl_on_multi_host.html). * \[November 20, 2025\] A new guide, [What is Post Training in MaxText?](https://maxtext.readthedocs.io/en/latest/tutorials/post_training_index.html), is now available. diff --git a/codecov.yml b/codecov.yml index 9765fe5edd..0b616cf714 100644 --- a/codecov.yml +++ b/codecov.yml @@ -34,7 +34,7 @@ fixes: ignore: - "src/maxtext/assets" - "src/MaxText/configs" - - "src/MaxText/examples" + - "src/maxtext/examples" - "src/MaxText/experimental" - "src/MaxText/inference" - "src/MaxText/inference_mlperf" diff --git a/docs/guides/run_python_notebook.md b/docs/guides/run_python_notebook.md index 7c5dc14ac6..7d24c271d6 100644 --- a/docs/guides/run_python_notebook.md +++ b/docs/guides/run_python_notebook.md @@ -19,6 +19,7 @@ Before starting, make sure you have: - ✅ Basic familiarity with Jupyter, Python, and Git **For Method 2 (Visual Studio Code) and Method 3 (Local Jupyter Lab) only:** + - ✅ A Google Cloud Platform (GCP) account with billing enabled - ✅ TPU quota available in your region (check under IAM & Admin → Quotas) - ✅ `tpu.nodes.create` permission to create a TPU VM @@ -36,16 +37,18 @@ Currently, this method only supports the **`sft_qwen3_demo.ipynb`** notebook, wh Before proceeding, please verify that the specific notebook you are running works reliably on the free-tier TPU resources. If you encounter frequent disconnections or resource limitations, you may need to: -* Upgrade to a Colab Pro or Pro+ subscription for more stable and powerful TPU access. +- Upgrade to a Colab Pro or Pro+ subscription for more stable and powerful TPU access. -* Move to local Jupyter Lab setup method with access to a powerful TPU machine. +- Move to local Jupyter Lab setup method with access to a powerful TPU machine. ### Step 1: Choose an Example -1.a. Visit the [MaxText examples directory](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/examples) on Github. + +1.a. Visit the [MaxText examples directory](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) on Github. 1.b. Find the notebook you want to run (e.g., `sft_qwen3_demo.ipynb`) and copy its URL. ### Step 2: Import into Colab + 2.a. Go to [Google Colab](https://colab.research.google.com/) and sign in. 2.b. Select **File** -> **Open Notebook**. @@ -63,9 +66,11 @@ Before proceeding, please verify that the specific notebook you are running work 3.c. Click **Save** ### Step 4: Run the Notebook + Follow the instructions within the notebook cells to install dependencies and run the training/inference. ## Method 2: Visual Studio Code with TPU (Recommended) + Running Jupyter notebooks in Visual Studio Code (VS Code) provides a powerful, interactive environment that combines the flexibility of notebooks with the robust features of a code editor. Follow these steps to get your environment up and running. ### Step 1: Set Up TPU VM @@ -75,9 +80,10 @@ In Google Cloud Console, create a standalone TPU VM: 1.a. 
**Compute Engine** → **TPUs** → **Create TPU** 1.b. Example config: - - **Name:** `maxtext-tpu-node` - - **TPU type:** Choose your desired TPU type - - **Runtime Version:** `tpu-ubuntu2204-base` (or other compatible runtime) + +- **Name:** `maxtext-tpu-node` +- **TPU type:** Choose your desired TPU type +- **Runtime Version:** `tpu-ubuntu2204-base` (or other compatible runtime) ### Step 2: SSH to TPU-VM via VS Code @@ -86,11 +92,12 @@ In Google Cloud Console, create a standalone TPU VM: 2.b. Follow [Connect to a remote host](https://code.visualstudio.com/docs/remote/ssh#_connect-to-a-remote-host) guide to connect to your TPU-VM via VS Code. ### Step 3. Install Necessary Extensions on VS Code + To enable notebook support, you must install two official extensions from the VS Code Marketplace: -* Python Extension: Provides support for the Python language. +- Python Extension: Provides support for the Python language. -* Jupyter Extension: Enables you to create, edit, and run `.ipynb` files directly inside VS Code. +- Jupyter Extension: Enables you to create, edit, and run `.ipynb` files directly inside VS Code. To install, click the `Extensions` icon on the left sidebar (or press `Ctrl+Shift+X` or `Cmd+Shift+X`), search for `Jupyter` and `Python`, and click `Install`. @@ -99,6 +106,7 @@ To install, click the `Extensions` icon on the left sidebar (or press `Ctrl+Shif To execute post-training notebooks on your TPU-VM, follow the official [MaxText installation guides](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl.html#create-virtual-environment-and-install-maxtext-dependencies) to install MaxText and its dependencies inside a dedicated virtual environment. ### Step 5: Install the necessary library for Jupyter + Jupyter requires a kernel to execute code. This kernel is tied to a specific Python environment. Open your terminal inside VS Code and run: ```bash @@ -110,9 +118,9 @@ uv pip install ipykernel Before you can run the notebook, you must tell VS Code which Python environment to use. 1. Look at the top-right corner of the notebook editor. -2. Click `Select Kernel`. -3. Choose Python Environments and select the virtual environment you created in Step 4. -4. Open [available post-training notebooks in MaxText](#available-examples) inside VS Code and run the jupyter notebook cells. +1. Click `Select Kernel`. +1. Choose Python Environments and select the virtual environment you created in Step 4. +1. Open [available post-training notebooks in MaxText](#available-examples) inside VS Code and run the jupyter notebook cells. ## Method 3: Local Jupyter Lab with TPU (Recommended) @@ -125,12 +133,15 @@ In Google Cloud Console, create a standalone TPU VM: 1.a. **Compute Engine** → **TPUs** → **Create TPU** 1.b. Example config: - - **Name:** `maxtext-tpu-node` - - **TPU type:** Choose your desired TPU type - - **Runtime Version:** `tpu-ubuntu2204-base` (or other compatible runtime) + +- **Name:** `maxtext-tpu-node` +- **TPU type:** Choose your desired TPU type +- **Runtime Version:** `tpu-ubuntu2204-base` (or other compatible runtime) ### Step 2: Connect with Port Forwarding + Run the following command on your local machine: + > **Note**: The `--` separator before the `-L` flag is required. This tunnels the remote port 8888 to your local machine securely. ```bash @@ -170,13 +181,15 @@ jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root ``` ### Step 7: Access the Notebook + 7.a. Look at the terminal output for a URL that looks like: `http://127.0.0.1:8888/lab?token=...`. 
7.b. Copy that URL. 7.c. Paste it into your **local computer's browser**. - * **Important:** If you changed the port in Step 2 (e.g., to `9999`), you must manually replace `8888` in the URL with `9999`. - * *Example:* `http://127.0.0.1:9999/lab?token=...` + +- **Important:** If you changed the port in Step 2 (e.g., to `9999`), you must manually replace `8888` in the URL with `9999`. +- *Example:* `http://127.0.0.1:9999/lab?token=...` 7.d. Once the interface opens in your browser, Click on the current kernel name (e.g., `Python 3 (ipykernel)`). @@ -197,13 +210,13 @@ jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root ## Common Pitfalls & Debugging -| Issue | Solution | -|-------|----------| -| ❌ TPU runtime mismatch | Check TPU runtime version matches VM image | -| ❌ Colab disconnects | Save checkpoints to GCS or Drive regularly | -| ❌ "RESOURCE_EXHAUSTED" errors | Use smaller batch size or v5e-8 instead of v5e-1 | -| ❌ Firewall blocked | Ensure port 8888 open, or always use SSH tunneling | -| ❌ Path confusion | In Colab use `/content/maxtext`; in TPU VM use `~/maxtext` | +| Issue | Solution | +| ------------------------------ | ---------------------------------------------------------- | +| ❌ TPU runtime mismatch | Check TPU runtime version matches VM image | +| ❌ Colab disconnects | Save checkpoints to GCS or Drive regularly | +| ❌ "RESOURCE_EXHAUSTED" errors | Use smaller batch size or v5e-8 instead of v5e-1 | +| ❌ Firewall blocked | Ensure port 8888 open, or always use SSH tunneling | +| ❌ Path confusion | In Colab use `/content/maxtext`; in TPU VM use `~/maxtext` | ## Support and Resources @@ -217,9 +230,9 @@ jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root If you encounter issues or have improvements for this guide, please: 1. Open an issue on the MaxText repository -2. Submit a pull request with your improvements -3. Share your experience in the discussions +1. Submit a pull request with your improvements +1. Share your experience in the discussions ---- +______________________________________________________________________ -**Happy Training! 🚀** \ No newline at end of file +**Happy Training! 🚀** diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md index 7fefa80344..960ecfbb98 100644 --- a/docs/tutorials/first_run.md +++ b/docs/tutorials/first_run.md @@ -75,7 +75,7 @@ In the same TPU VM where you just installed all the dependencies of MaxText, You #### Decoding in MaxText via notebook -You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. +You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb) to try out decoding on MaxText's `Llama3.1-8b` model implementation. In this notebook, we give `"I love to"` as the prompt, and the greedily sampled first output token is `" cook"`. 
Please remember to provide the path to your `Llama3.1-8b` checkpoint for the `load_parameters_path` argument in the config inside the notebook. You can use [to_maxtext.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/to_maxtext.py) to create a MaxText/Orbax checkpoint from a Huggingface checkpoint. ### Run MaxText on NVIDIA GPUs diff --git a/docs/tutorials/posttraining/multimodal.md b/docs/tutorials/posttraining/multimodal.md index 980c2d8aca..11c6982c66 100644 --- a/docs/tutorials/posttraining/multimodal.md +++ b/docs/tutorials/posttraining/multimodal.md @@ -6,7 +6,7 @@ This document provides a guide to use the multimodal functionalities in MaxText - **Multimodal Decode**: Inference with text+images as input. - **Supervised Fine-Tuning (SFT)**: Apply SFT to the model using a visual-question-answering dataset. -We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb) for multimodal features demonstration. The following table provides a list of models and modalities we currently support: +We also provide a [colab](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/multimodal_gemma3_demo.ipynb) for multimodal features demonstration. The following table provides a list of models and modalities we currently support: | Models | Input Modalities | Output Modalities | | :--------------------------------------------- | :--------------- | :---------------- | diff --git a/pedagogical_examples/__init__.py b/pedagogical_examples/__init__.py deleted file mode 100644 index 2237c9162e..0000000000 --- a/pedagogical_examples/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2023–2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
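Note on the `pedagogical_examples/` removal above: with `__init__.py` deleted, the moved scripts (`non_spmd.py`, `shardings.py`, `shmap_collective_matmul.py`) now live as plain files under `src/maxtext/examples/` instead of a top-level package. A minimal invocation sketch under the new layout, assuming the `PYTHONPATH` entry this patch adds to the notebook workflow; the flags each script accepts are unchanged by this patch:

```bash
# Old layout: run from the top-level package directory
# python3 pedagogical_examples/shardings.py

# New layout: expose src/ on PYTHONPATH (as the updated GitHub Actions
# workflow does) and run the moved file directly
PYTHONPATH=src python3 src/maxtext/examples/shardings.py
```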
diff --git a/src/MaxText/configs/rl.yml b/src/MaxText/configs/rl.yml index 283d388d92..5a8f57f664 100644 --- a/src/MaxText/configs/rl.yml +++ b/src/MaxText/configs/rl.yml @@ -171,7 +171,7 @@ reasoning_start_token: '' reasoning_end_token: '' solution_start_token: '' solution_end_token: '' -chat_template_path: 'src/MaxText/examples/chat_templates/gsm8k_rl.json' +chat_template_path: 'src/maxtext/examples/chat_templates/gsm8k_rl.json' skip_jax_distributed_system: True # # TODO(@mazumdera): fix this diff --git a/src/MaxText/examples/rl_llama3_demo.ipynb b/src/MaxText/examples/rl_llama3_demo.ipynb deleted file mode 100644 index a3d7cd7506..0000000000 --- a/src/MaxText/examples/rl_llama3_demo.ipynb +++ /dev/null @@ -1,410 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "aPIBMUYfh6NF" - }, - "source": [ - "# Llama3.1-8B-Instruct Reinforcement Learning Demo\n", - "\n", - "This notebook demonstrates training on Llama3.1-8B-Instruct model with either GRPO (Group Relative Policy Optimization) or GSPO (Group Sequence Policy Optimization).\n", - "\n", - "This notebook can run on **TPU v5e-8** or **v5p-8**.\n", - "\n", - "## What is GRPO/GSPO?\n", - "\n", - "GRPO/GSPO is an RL algorithm that enhances reasoning abilities of LLMs by:\n", - "1. Generating multiple responses for each prompt\n", - "2. Evaluating responses using reward models \n", - "3. Calculating relative advantages to update the policy\n", - "\n", - "The difference is in the loss function - either it's optimizing each token (GRPO) or the whole sequence(GSPO)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3hIEGGBbh6NF" - }, - "source": [ - "## Prerequisites\n", - "\n", - "### Change Runtime Type (only if running on Google Colab)\n", - "\n", - "**Instructions:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Change runtime type** from the dropdown menu.\n", - "4. Select **v5e-8** or **v5p-8 TPU** as the **Hardware accelerator**.\n", - "5. Click on **Save**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "t9J8jIJuh6NF" - }, - "source": [ - "### Get Your Hugging Face Token\n", - "\n", - "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", - "\n", - "**Follow these steps to get your token:**\n", - "\n", - "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", - " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", - "\n", - "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", - "\n", - "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", - "\n", - "4. **Copy the generated token**. You will need this in the later steps.\n", - "\n", - "**Follow these steps to store your token (only if running on Google Colab):**\n", - "\n", - "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", - "\n", - "2. Click **\"+ Add new secret\"**.\n", - "\n", - "3. Set the Name as **HF_TOKEN**.\n", - "\n", - "4. Paste your token into the Value field.\n", - "\n", - "5. Ensure the Notebook access toggle is turned On." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NasqFEooh6NF" - }, - "outputs": [], - "source": [ - "try:\n", - " from google.colab import userdata\n", - " print(\"Running the notebook on Google Colab\")\n", - " IN_COLAB = True\n", - "except ImportError:\n", - " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", - " IN_COLAB = False" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_5dzCPYeh6NF" - }, - "source": [ - "## Installation: MaxText and Dependencies\n", - "\n", - "**⚠️ Note:** The installation process in following cell may take a few minutes to complete. Please be patient." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2G9dGRaoh6NF" - }, - "outputs": [], - "source": [ - "if IN_COLAB:\n", - " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", - " %cd /content/maxtext\n", - "\n", - " # Install uv, a fast Python package installer\n", - " !pip install uv\n", - "\n", - " # Install MaxText and its dependencies\n", - " !uv pip install -e .[tpu] --resolution=lowest\n", - " !python3 -m MaxText.install_maxtext_extra_deps\n", - "\n", - " # Install vLLM for Jax and TPUs\n", - " !uv pip install vllm-tpu\n", - " !uv pip install --no-deps qwix==0.1.4" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nelhTNVGh6NF" - }, - "source": [ - "### Restart Session (only if running on Google Colab)\n", - "To apply certain changes, you need to restart the session.\n", - "\n", - "**Instructions:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Restart session** from the dropdown menu.\n", - "\n", - "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7Q4NS47Mh6NF" - }, - "source": [ - "## Environment Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Ml0orOyEh6NF" - }, - "outputs": [], - "source": [ - "import datetime\n", - "import os\n", - "import sys\n", - "from pathlib import Path\n", - "import MaxText\n", - "from huggingface_hub import login\n", - "import jax\n", - "\n", - "from MaxText.rl.train_rl import rl_train, setup_configs_and_devices\n", - "\n", - "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n", - "os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n", - "\n", - "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", - "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BHOtzHWBh6NF" - }, - "outputs": [], - "source": [ - "if not jax.distributed.is_initialized():\n", - " jax.distributed.initialize()\n", - "print(f\"JAX version: {jax.__version__}\")\n", - "print(f\"JAX devices: {jax.devices()}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jqewgodrh6NF" - }, - "outputs": [], - "source": [ - "if IN_COLAB:\n", - " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", - "else:\n", - " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", - "\n", - "# If not found in the environment, prompt the user for input securely\n", - "# getpass function ensures the token is hidden while you type\n", - "if not HF_TOKEN:\n", - " from getpass import getpass\n", - " HF_TOKEN = getpass(\"Hugging Face token not found in environment. 
Please enter it here: \")\n", - "\n", - "if HF_TOKEN:\n", - " os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", - " login(token=HF_TOKEN)\n", - " print(\"Authenticated with Hugging Face successfully!\")\n", - "else:\n", - " print(\"Authentication failed: Hugging Face token is not set.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_uxwEwuah6NF" - }, - "source": [ - "## Model Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dSQzsP3th6NF" - }, - "outputs": [], - "source": [ - "MODEL_NAME = \"llama3.1-8b\"\n", - "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", - "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n", - "LOSS_ALGO=\"grpo\" # or \"gspo-token\" if you want to use GSPO\n", - "\n", - "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n", - "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", - " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", - "\n", - "# set the path to the model checkpoint or leave empty to download from HuggingFace\n", - "MODEL_CHECKPOINT_PATH = \"\"\n", - "if not MODEL_CHECKPOINT_PATH:\n", - " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_REPO_ROOT}/llama_checkpoint\"\n", - " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", - " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", - " \n", - "OUTPUT_DIRECTORY = f\"{MAXTEXT_REPO_ROOT}/rl_llama3_output\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jHbSg0vih6NF" - }, - "source": [ - "## Download Llama3.1-8B Model Checkpoint from Hugging Face" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "I3O4Pr-1h6NF" - }, - "outputs": [], - "source": [ - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " # install torch for the conversion script\n", - " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", - "\n", - " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_REPO_ROOT} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n", - " {MAXTEXT_REPO_ROOT}/configs/base.yml \\\n", - " model_name={MODEL_NAME} \\\n", - " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", - " hf_access_token={HF_TOKEN} \\\n", - " use_multimodal=false \\\n", - " scan_layers=true \\\n", - " skip_jax_distributed_system=True\n", - "\n", - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " raise ValueError(\"Model checkpoint conversion failed. 
Check the logs above.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HHuww-GWh6NF" - }, - "source": [ - "## MaxText Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "UNzZPHYsh6NF" - }, - "outputs": [], - "source": [ - "# Load configuration for RL training\n", - "config_argv = [\n", - " \"\",\n", - " f\"{MAXTEXT_REPO_ROOT}/configs/rl.yml\",\n", - " f\"model_name={MODEL_NAME}\",\n", - " f\"tokenizer_path={TOKENIZER_PATH}\",\n", - " f\"run_name={RUN_NAME}\",\n", - " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", - " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", - " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n", - " f\"hf_access_token={HF_TOKEN}\",\n", - " \"debug.rl=False\",\n", - " f\"rl.loss_algo={LOSS_ALGO}\",\n", - " \"use_pathways=False\"\n", - "]\n", - "\n", - "trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)\n", - "\n", - "rl_train_steps = int(\n", - " trainer_config.num_batches\n", - " * trainer_config.rl.num_iterations\n", - " * trainer_config.train_fraction\n", - " * trainer_config.num_epoch\n", - ")\n", - "\n", - "print(\"✓ Configuration initialized successfully\")\n", - "print(f\"📁 Output directory: {trainer_config.base_output_directory}\")\n", - "print(f\"🤖 Model: {trainer_config.model_name}\")\n", - "print(f\"📊 RL Train Steps: {rl_train_steps}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uKDvQ-x4h6NG" - }, - "source": [ - "## RL Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zGm_j3Vrh6NG" - }, - "outputs": [], - "source": [ - "print(\"\\n\" + \"=\" * 80)\n", - "print(f\"🚀 Starting {LOSS_ALGO} Training...\")\n", - "print(\"=\" * 80)\n", - "try:\n", - " rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"✅ Training Completed Successfully!\")\n", - " print(f\"✍️ Note the improved evaluation accuracy metrics with just {rl_train_steps} RL training steps!\")\n", - " print(\"=\" * 80)\n", - " print(f\"📁 Checkpoints saved to: {trainer_config.checkpoint_dir}\")\n", - " print(f\"📊 TensorBoard logs: {trainer_config.tensorboard_dir}\")\n", - " print(f\"🎯 Model ready for inference!\")\n", - "except Exception as e:\n", - " print(\"\\n\" + \"=\" * 80)\n", - " print(\"❌Training Failed!\")\n", - " print(\"=\" * 80)\n", - " print(f\"Error: {str(e)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TdlhG-wRh6NG" - }, - "source": [ - "## 📚 Learn More\n", - "\n", - "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/rl.html\n", - "- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n", - "- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.11" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/src/MaxText/examples/sft_llama3_demo.ipynb b/src/MaxText/examples/sft_llama3_demo.ipynb deleted file mode 100644 index 03e7b078fc..0000000000 --- a/src/MaxText/examples/sft_llama3_demo.ipynb +++ /dev/null @@ -1,407 +0,0 
@@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "hcsWywuVjQq8" - }, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/sft_llama3_demo.ipynb)\n", - "\n", - "# Llama3.1-8B-Instruct Supervised Fine-Tuning (SFT) Demo\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3faYjYXojQq8" - }, - "source": [ - "## Overview\n", - "\n", - "This notebook demonstrates how to perform Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct using the Hugging Face ultrachat_200k dataset with MaxText and Tunix integration for efficient training.\n", - "\n", - "This notebook can run on **TPU v5e-8** or **v5p-8**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6QotQjsMjQq8" - }, - "source": [ - "## Prerequisites\n", - "\n", - "### Change Runtime Type (only if running on Google Colab)\n", - "\n", - "**Instructions:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Change runtime type** from the dropdown menu.\n", - "4. Select **v5e-8** or **v5p-8 TPU** as the **Hardware accelerator**.\n", - "5. Click on **Save**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8XMWNuEwjQq8" - }, - "source": [ - "### Get Your Hugging Face Token\n", - "\n", - "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", - "\n", - "**Follow these steps to get your token:**\n", - "\n", - "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", - " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", - "\n", - "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", - "\n", - "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", - "\n", - "4. **Copy the generated token**. You will need this in the later steps.\n", - "\n", - "**Follow these steps to store your token (only if running on Google Colab):**\n", - "\n", - "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", - "\n", - "2. Click **\"+ Add new secret\"**.\n", - "\n", - "3. Set the Name as **HF_TOKEN**.\n", - "\n", - "4. Paste your token into the Value field.\n", - "\n", - "5. Ensure the Notebook access toggle is turned On." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Y53cfWyujQq9" - }, - "outputs": [], - "source": [ - "try:\n", - " from google.colab import userdata\n", - " print(\"Running the notebook on Google Colab\")\n", - " IN_COLAB = True\n", - "except ImportError:\n", - " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", - " IN_COLAB = False" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GeTXb9DnjQq9" - }, - "source": [ - "### Installation: MaxText & Other Dependencies" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KTRkrhp1jQq9" - }, - "source": [ - "**⚠️ Note:** The installation process in following cell may take a few minutes to complete. Please be patient." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "aUoJbrTDjQq9" - }, - "outputs": [], - "source": [ - "if IN_COLAB:\n", - " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", - " %cd /content/maxtext\n", - "\n", - " # Install uv, a fast Python package installer\n", - " !pip install uv\n", - "\n", - " # Install MaxText and its dependencies\n", - " !uv pip install -e .[tpu] --resolution=lowest\n", - " !python3 -m MaxText.install_maxtext_extra_deps" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8YkY4rSqjQq9" - }, - "source": [ - "### Restart Session (only if running on Google Colab)\n", - "To apply certain changes, you need to restart the session.\n", - "\n", - "**Instructions:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Restart session** from the dropdown menu.\n", - "\n", - "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6mrIF5acjQq9" - }, - "source": [ - "## Environment Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mdeFSFymjQq9" - }, - "outputs": [], - "source": [ - "import datetime\n", - "import os\n", - "import subprocess\n", - "import sys\n", - "import MaxText\n", - "from MaxText import pyconfig\n", - "from maxtext.trainers.post_train.sft import train_sft\n", - "import jax\n", - "from huggingface_hub import login\n", - "\n", - "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", - "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4RBhx-PBjQq9" - }, - "outputs": [], - "source": [ - "if not jax.distributed.is_initialized():\n", - " jax.distributed.initialize()\n", - "print(f\"JAX version: {jax.__version__}\")\n", - "print(f\"JAX devices: {jax.devices()}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yaw60bgPjQq9" - }, - "outputs": [], - "source": [ - "if IN_COLAB:\n", - " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", - "else:\n", - " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", - "\n", - "# If not found in the environment, prompt the user for input securely\n", - "# getpass function ensures the token is hidden while you type\n", - "if not HF_TOKEN:\n", - " from getpass import getpass\n", - " HF_TOKEN = getpass(\"Hugging Face token not found in environment. 
Please enter it here: \")\n", - "\n", - "if HF_TOKEN:\n", - " login(token=HF_TOKEN)\n", - " print(\"Authenticated with Hugging Face successfully!\")\n", - "else:\n", - " print(\"Authentication failed: Hugging Face token is not set.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bxNSmUV8jQq9" - }, - "source": [ - "## Model Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kWjjVLgUjQq9" - }, - "outputs": [], - "source": [ - "MODEL_NAME = \"llama3.1-8b\"\n", - "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", - "\n", - "# set the path to the model checkpoint or leave empty to download from HuggingFace\n", - "MODEL_CHECKPOINT_PATH = \"\"\n", - "if not MODEL_CHECKPOINT_PATH:\n", - " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_REPO_ROOT}/llama_checkpoint\"\n", - " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", - " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", - "\n", - "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_REPO_ROOT}/sft_llama3_output\"\n", - "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yCO9GaQMjQq9" - }, - "source": [ - "## Download Llama3.1-8B Model Checkpoint from Hugging Face" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bwzsJpHqjQq9" - }, - "outputs": [], - "source": [ - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " # install torch for the conversion script\n", - " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", - "\n", - " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_REPO_ROOT} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n", - " {MAXTEXT_REPO_ROOT}/configs/base.yml \\\n", - " model_name={MODEL_NAME} \\\n", - " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", - " hf_access_token={HF_TOKEN} \\\n", - " use_multimodal=false \\\n", - " scan_layers=true \\\n", - " skip_jax_distributed_system=True\n", - "\n", - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " raise ValueError(\"Model checkpoint conversion failed. 
Check the logs above.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "imqH5KL4jQq9" - }, - "source": [ - "## MaxText Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "In-jdp1AAwrL" - }, - "outputs": [], - "source": [ - "# Load configuration for SFT training\n", - "config_argv = [\n", - " \"\",\n", - " f\"{MAXTEXT_REPO_ROOT}/configs/sft.yml\",\n", - " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", - " f\"model_name={MODEL_NAME}\",\n", - " \"steps=100\",\n", - " \"per_device_batch_size=1\",\n", - " \"max_target_length=1024\",\n", - " \"learning_rate=2.0e-5\",\n", - " \"weight_dtype=bfloat16\",\n", - " \"dtype=bfloat16\",\n", - " \"hf_path=HuggingFaceH4/ultrachat_200k\",\n", - " f\"hf_access_token={HF_TOKEN}\",\n", - " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", - " f\"run_name={RUN_NAME}\",\n", - " f\"tokenizer_path={TOKENIZER_PATH}\",\n", - " \"profiler=xplane\",\n", - "]\n", - "\n", - "config = pyconfig.initialize(config_argv)\n", - "\n", - "print(\"✓ SFT configuration loaded:\")\n", - "print(f\" Model: {config.model_name}\")\n", - "print(f\" Training Steps: {config.steps}\")\n", - "print(f\" Output Directory: {config.base_output_directory}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_fdeK75ajQq9" - }, - "source": [ - "## SFT Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mgwpNgQYCJEd" - }, - "outputs": [], - "source": [ - "print(\"=\" * 60)\n", - "print(\"🚀 Starting SFT Training...\")\n", - "print(\"=\" * 60)\n", - "\n", - "try:\n", - " trainer, mesh = train_sft.train(config)\n", - "\n", - " print(\"\\n\" + \"=\" * 60)\n", - " print(\"✅ Training Completed Successfully!\")\n", - " print(\"=\" * 60)\n", - " print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n", - "except Exception as e:\n", - " print(\"\\n\" + \"=\" * 60)\n", - " print(\"❌Training Failed!\")\n", - " print(\"=\" * 60)\n", - " print(f\"Error: {str(e)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "85nvNASTjQq9" - }, - "source": [ - "## 📚 Learn More\n", - "\n", - "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n", - "- **Configuration**: See `src/MaxText/configs/sft.yml` for all available options\n", - "- **Documentation**: Check `src/MaxText/sft/sft_trainer.py` for the `sft_train` function implementation" - ] - } - ], - "metadata": { - "accelerator": "TPU", - "colab": { - "gpuType": "V5E1", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/src/MaxText/examples/sft_qwen3_demo.ipynb b/src/MaxText/examples/sft_qwen3_demo.ipynb deleted file mode 100644 index a56cb2465f..0000000000 --- a/src/MaxText/examples/sft_qwen3_demo.ipynb +++ /dev/null @@ -1,631 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "1nb_Ppf2ZUQL" - }, - "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/sft_qwen3_demo.ipynb)\n", - "\n", - "# Qwen3-0.6B Supervised 
Fine-Tuning (SFT) Demo\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FGbe4_YQZUQL" - }, - "source": [ - "## Overview\n", - "\n", - "This notebook performs SFT training and evaluation workflow on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).\n", - "The primary goal is to demonstrate the end-to-end process of:\n", - "1. Pre-SFT Evaluation: Calcuating baseline accuracy for the model before training.\n", - "2. SFT Training: Fine-tune the model using MaxText & Tunix SFT trainer.\n", - "3. Post-SFT Evaluation: Re-running the evaluation loop after training to measure the performance gain achieved by SFT.\n", - "\n", - "This notebook can run on the **public TPU v5e-1**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zolxPWhQZUQL" - }, - "source": [ - "## Prerequisites\n", - "\n", - "### Change Runtime Type (only if running on Google Colab)\n", - "\n", - "**Instructions:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Change runtime type** from the dropdown menu.\n", - "4. Select **v5e-1 TPU** as the **Hardware accelerator**.\n", - "5. Click on **Save**." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Rk_QpVVuZUQL" - }, - "source": [ - "### Get Your Hugging Face Token\n", - "\n", - "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", - "\n", - "**Follow these steps to get your token:**\n", - "\n", - "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", - " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", - "\n", - "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", - "\n", - "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", - "\n", - "4. **Copy the generated token**. You will need this in the later steps.\n", - "\n", - "**Follow these steps to store your token (only if running on Google Colab):**\n", - "\n", - "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", - "\n", - "2. Click **\"+ Add new secret\"**.\n", - "\n", - "3. Set the Name as **HF_TOKEN**.\n", - "\n", - "4. Paste your token into the Value field.\n", - "\n", - "5. Ensure the Notebook access toggle is turned On." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VQaxueyfjLwX" - }, - "outputs": [], - "source": [ - "try:\n", - " from google.colab import userdata\n", - " print(\"Running the notebook on Google Colab\")\n", - " IN_COLAB = True\n", - "except ImportError:\n", - " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", - " IN_COLAB = False" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "D9ms-jTSZUQL" - }, - "source": [ - "## Installation: MaxText & Other Dependencies" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wrSXHHrGjLwX" - }, - "source": [ - "**⚠️ Note:** The installation process in following cell may take a few minutes to complete. Please be patient." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jYAhyzQJjLwX" - }, - "outputs": [], - "source": [ - "try:\n", - " import google.colab\n", - " print(\"Running the notebook in Google Colab\")\n", - " IN_COLAB = True\n", - "except ImportError:\n", - " print(\"Running the notebook on JupyterLab\")\n", - " IN_COLAB = False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OSPRVbi7n6tB" - }, - "outputs": [], - "source": [ - "if IN_COLAB:\n", - " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", - " %cd /content/maxtext\n", - "\n", - " # Install uv, a fast Python package installer\n", - " !pip install uv\n", - "\n", - " # Install MaxText and its dependencies\n", - " !uv pip install -e .[tpu] --resolution=lowest\n", - " !python3 -m MaxText.install_maxtext_extra_deps\n", - "\n", - " # Install Tunix for post-training notebooks\n", - " !uv pip install git+https://github.com/google/tunix\n", - " \n", - " # Install vllm for post-training notebooks\n", - " !git clone https://github.com/vllm-project/vllm.git\n", - " !VLLM_TARGET_DEVICE=\"tpu\" uv pip install ./vllm\n", - "\n", - " # Install tpu-inference for post-training notebooks\n", - " !git clone https://github.com/vllm-project/tpu-inference.git\n", - " !uv pip install ./tpu-inference\n", - " !uv pip install --no-deps qwix==0.1.4" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ywtealAxZUQM" - }, - "source": [ - "### Restart Session (only if running on Google Colab)\n", - "To apply certain changes, you need to restart the session.\n", - "\n", - "**Instructions:**\n", - "1. Navigate to the menu at the top of the screen.\n", - "2. Click on **Runtime**.\n", - "3. Select **Restart session** from the dropdown menu.\n", - "\n", - "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Clexf-j7ZUQM" - }, - "source": [ - "## Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PkBI9A3JZUQM" - }, - "outputs": [], - "source": [ - "import jax\n", - "import os\n", - "import sys\n", - "import transformers\n", - "\n", - "import MaxText\n", - "from MaxText import pyconfig\n", - "from MaxText.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n", - "from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n", - "from maxtext.trainers.post_train.sft import train_sft\n", - "\n", - "# Suppress vLLM logging with a severity level below ERROR\n", - "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n", - "from tunix.rl.rollout import base_rollout\n", - "from tunix.rl.rollout.vllm_rollout import VllmRollout\n", - "\n", - "from datetime import datetime\n", - "from flax import nnx\n", - "from huggingface_hub import login\n", - "\n", - "MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)\n", - "print(f\"MaxText installation path: {MAXTEXT_REPO_ROOT}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0nXfUgcWjLwX" - }, - "outputs": [], - "source": [ - "if not jax.distributed.is_initialized():\n", - " jax.distributed.initialize()\n", - "print(f\"JAX version: {jax.__version__}\")\n", - "print(f\"JAX devices: {jax.devices()}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JBbPN-uVZUQM" - }, - "outputs": [], - "source": [ - "try:\n", - " from google.colab import userdata\n", - " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", - "except ImportError:\n", - " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", - "\n", - "# If not found in the environment, prompt the user for input securely\n", - "# getpass function ensures the token is hidden while you type\n", - "if not HF_TOKEN:\n", - " from getpass import getpass\n", - " HF_TOKEN = getpass(\"Hugging Face token not found in environment. 
Please enter it here: \")\n", - "\n", - "if HF_TOKEN:\n", - " login(token=HF_TOKEN)\n", - " print(\"Authenticated with Hugging Face successfully!\")\n", - "else:\n", - " print(\"Authentication failed: Hugging Face token is not set.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aENuzm9iZUQM" - }, - "source": [ - "## Model Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "RjPYYl3zZUQM" - }, - "outputs": [], - "source": [ - "MODEL_NAME = \"qwen3-0.6b\"\n", - "TOKENIZER_PATH = \"Qwen/Qwen3-0.6B\"\n", - "tokenizer = transformers.AutoTokenizer.from_pretrained(\n", - " TOKENIZER_PATH,\n", - " token=HF_TOKEN,\n", - ")\n", - "\n", - "# set the path to the model checkpoint or leave empty to download from HuggingFace\n", - "MODEL_CHECKPOINT_PATH = \"\"\n", - "if not MODEL_CHECKPOINT_PATH:\n", - " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_REPO_ROOT}/qwen_checkpoint\"\n", - " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", - " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", - "\n", - "\n", - "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%m-%S\")\n", - "\n", - "# This is the directory where the fine-tuned model checkpoint will be saved\n", - "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_REPO_ROOT}/maxtext_qwen06_output\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4L37Ij4NZUQM" - }, - "source": [ - "## Download Qwen3-0.6B Model Checkpoint from Hugging Face" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kJanDAc0ZUQM" - }, - "outputs": [], - "source": [ - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " # install torch for the conversion script\n", - " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", - "\n", - " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_REPO_ROOT} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n", - " {MAXTEXT_REPO_ROOT}/configs/base.yml \\\n", - " model_name={MODEL_NAME} \\\n", - " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", - " hf_access_token={HF_TOKEN} \\\n", - " use_multimodal=false \\\n", - " scan_layers=true \\\n", - " skip_jax_distributed_system=True\n", - "\n", - "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " raise ValueError(\"Model checkpoint conversion failed. 
Check the logs above.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PC-hILG0ZUQM" - }, - "source": [ - "## Dataset Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "O3MLdr9kZUQM" - }, - "outputs": [], - "source": [ - "DATASET_NAME = \"openai/gsm8k\"\n", - "TRAIN_DATA_SPLIT = \"train\"\n", - "TEST_DATA_SPLIT = \"test\"\n", - "HF_DATA_DIR = \"main\"\n", - "TRAIN_DATA_COLUMNS = [\"question\", \"answer\"]\n", - "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/math_qa.json\"\n", - "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", - " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", - "NUM_TEST_SAMPLES = 20 # Total number of samples to test\n", - "BATCH_SIZE = 1 # Number of test samples to process in a batch" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yeAHmxSYZUQM" - }, - "source": [ - "## MaxText Configurations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "In-jdp1AAwrL" - }, - "outputs": [], - "source": [ - "%%capture\n", - "config = pyconfig.initialize(\n", - " [\n", - " \"\",\n", - " f\"{MAXTEXT_REPO_ROOT}/configs/sft.yml\",\n", - " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", - " f\"model_name={MODEL_NAME}\",\n", - " f\"hf_access_token={HF_TOKEN}\",\n", - " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", - " f\"run_name={RUN_NAME}\",\n", - " f\"tokenizer_path={TOKENIZER_PATH}\",\n", - " f\"hf_path={DATASET_NAME}\",\n", - " f\"train_split={TRAIN_DATA_SPLIT}\",\n", - " f\"hf_data_dir={HF_DATA_DIR}\",\n", - " f\"train_data_columns={TRAIN_DATA_COLUMNS}\",\n", - " \"steps=500\",\n", - " \"per_device_batch_size=1\",\n", - " \"max_target_length=1024\",\n", - " \"learning_rate=3e-6\",\n", - " \"weight_dtype=bfloat16\",\n", - " \"dtype=bfloat16\",\n", - " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", - " ]\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O9b0GWo-ZUQM" - }, - "source": [ - "## Initial Setup & Data Preparation" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TDqFmvUCZUQM" - }, - "source": [ - "### Create Test Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wscWYxrtZUQM" - }, - "outputs": [], - "source": [ - "test_dataset = get_test_dataset(config, tokenizer)\n", - "test_dataset = test_dataset[:NUM_TEST_SAMPLES]\n", - "test_dataset = test_dataset.to_iter_dataset().batch(BATCH_SIZE, drop_remainder=True)\n", - "TOTAL_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE\n", - "print(\n", - " f\"Processing {NUM_TEST_SAMPLES} examples with a batch size of {BATCH_SIZE}. 
This will result in {TOTAL_BATCHES} total batches for the test run.\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bLSvOOEUZUQM" - }, - "source": [ - "### Create SFT Trainer State" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2IHsC0m6ZUQM" - }, - "outputs": [], - "source": [ - "trainer, mesh = train_sft.setup_trainer_state(config)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PpKtEqzFZUQM" - }, - "source": [ - "### Create vLLM Rollout" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3-pf_rbqZUQM" - }, - "outputs": [], - "source": [ - "tunix_model = TunixMaxTextAdapter(trainer.model)\n", - "vllm_rollout = VllmRollout(\n", - " model=tunix_model,\n", - " tokenizer=tokenizer,\n", - " cache_config_or_size=1280,\n", - " mesh=mesh,\n", - " rollout_config=base_rollout.RolloutConfig(\n", - " rollout_vllm_model_version=TOKENIZER_PATH,\n", - " rollout_vllm_hbm_utilization=0.8,\n", - " rollout_vllm_init_with_random_weights=True,\n", - " rollout_vllm_tpu_backend_type=\"jax\",\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "567gTxsEZUQM" - }, - "source": [ - "## Evaluation before SFT Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "OnACa3zCZUQM" - }, - "outputs": [], - "source": [ - "print(\"Running Pre-SFT Evaluation...\")\n", - "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "u5-M4iYkZUQN" - }, - "outputs": [], - "source": [ - "print(\"========================= Score for PRE-SFT Evaluation =========================\")\n", - "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", - "print(\n", - " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", - ")\n", - "print(\n", - " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EJE1ookSAzz-" - }, - "source": [ - "## SFT Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "mgwpNgQYCJEd", - "tags": [] - }, - "outputs": [], - "source": [ - "print(\"Starting SFT Training...\")\n", - "trainer = train_sft.train_model(config, trainer, mesh)\n", - "print(\"SFT Training Complete!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WEdNYRhwZUQN" - }, - "source": [ - "## Evaluation after SFT Training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XcsZacZdZUQN" - }, - "outputs": [], - "source": [ - "print(\"Running Post-SFT Evaluation...\")\n", - "model = TunixMaxTextAdapter(trainer.model)\n", - "state = nnx.state(model)\n", - "vllm_rollout.update_params(state)\n", - "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "-JtYTPvJZUQN", - "tags": [] - }, - "outputs": [], - "source": [ - "print(\"========================= Score for POST-SFT Evaluation =========================\")\n", - "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", - "print(\n", - " f\"Percentage of 
test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", - ")\n", - "print(\n", - " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", - ")" - ] - } - ], - "metadata": { - "accelerator": "TPU", - "colab": { - "gpuType": "V5E1", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.11" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/src/MaxText/examples/chat_templates/gsm8k_rl.json b/src/maxtext/examples/chat_templates/gsm8k_rl.json similarity index 100% rename from src/MaxText/examples/chat_templates/gsm8k_rl.json rename to src/maxtext/examples/chat_templates/gsm8k_rl.json diff --git a/src/MaxText/examples/chat_templates/math_qa.json b/src/maxtext/examples/chat_templates/math_qa.json similarity index 100% rename from src/MaxText/examples/chat_templates/math_qa.json rename to src/maxtext/examples/chat_templates/math_qa.json diff --git a/src/MaxText/examples/demo_decoding.ipynb b/src/maxtext/examples/demo_decoding.ipynb similarity index 99% rename from src/MaxText/examples/demo_decoding.ipynb rename to src/maxtext/examples/demo_decoding.ipynb index 99757c569c..4698a260a0 100644 --- a/src/MaxText/examples/demo_decoding.ipynb +++ b/src/maxtext/examples/demo_decoding.ipynb @@ -5,7 +5,7 @@ "id": "e017d77b", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/demo_decoding.ipynb)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/demo_decoding.ipynb)\n", " \n", "# Qwen3-0.6B Decoding Demo" ] @@ -437,4 +437,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/src/MaxText/examples/multimodal_gemma3_demo.ipynb b/src/maxtext/examples/multimodal_gemma3_demo.ipynb similarity index 99% rename from src/MaxText/examples/multimodal_gemma3_demo.ipynb rename to src/maxtext/examples/multimodal_gemma3_demo.ipynb index 8410cd6424..4df0157314 100644 --- a/src/MaxText/examples/multimodal_gemma3_demo.ipynb +++ b/src/maxtext/examples/multimodal_gemma3_demo.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/multimodal_gemma3_demo.ipynb)\n", "\n", "# Gemma3 Multimodal Inference/Training Demo" ] diff --git a/pedagogical_examples/non_spmd.py b/src/maxtext/examples/non_spmd.py similarity index 100% rename from pedagogical_examples/non_spmd.py rename to src/maxtext/examples/non_spmd.py diff --git a/src/maxtext/examples/rl_llama3_demo.ipynb b/src/maxtext/examples/rl_llama3_demo.ipynb new file mode 100644 index 0000000000..4eacf0b34a --- /dev/null +++ 
b/src/maxtext/examples/rl_llama3_demo.ipynb @@ -0,0 +1,372 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Llama3.1-8B-Instruct Reinforcement Learning Demo\n", + "\n", + "This notebook demonstrates training the Llama3.1-8B-Instruct model with either GRPO (Group Relative Policy Optimization) or GSPO (Group Sequence Policy Optimization).\n", + "\n", + "This notebook can run on **TPU v5e-8** or **v5p-8**.\n", + "\n", + "## What is GRPO/GSPO?\n", + "\n", + "GRPO and GSPO are RL algorithms that enhance the reasoning abilities of LLMs by:\n", + "1. Generating multiple responses for each prompt\n", + "2. Evaluating responses using reward models\n", + "3. Calculating relative advantages to update the policy\n", + "\n", + "The two differ in the loss function: GRPO optimizes each token, while GSPO optimizes the whole sequence. A minimal sketch of the group-relative advantage computation they share appears just before the installation cell below." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "### Change Runtime Type (only if running on Google Colab)\n", + "\n", + "**Instructions:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Change runtime type** from the dropdown menu.\n", + "4. Select **v5e-8** or **v5p-8 TPU** as the **Hardware accelerator**.\n", + "5. Click on **Save**." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get Your Hugging Face Token\n", + "\n", + "To access model checkpoints from the Hugging Face Hub, you need to authenticate with a personal access token.\n", + "\n", + "**Follow these steps to get your token:**\n", + "\n", + "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", + " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", + "\n", + "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", + "\n", + "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", + "\n", + "4. **Copy the generated token**. You will need it in later steps.\n", + "\n", + "**Follow these steps to store your token (only if running on Google Colab):**\n", + "\n", + "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", + "\n", + "2. Click **\"+ Add new secret\"**.\n", + "\n", + "3. Set the Name as **HF_TOKEN**.\n", + "\n", + "4. Paste your token into the Value field.\n", + "\n", + "5. Ensure the Notebook access toggle is turned On." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from google.colab import userdata\n", + " print(\"Running the notebook on Google Colab\")\n", + " IN_COLAB = True\n", + "except ImportError:\n", + " print(\"Running the notebook on Visual Studio Code or JupyterLab\")\n", + " IN_COLAB = False" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installation: MaxText and Dependencies\n", + "\n", + "**⚠️ Note:** The installation process in the following cell may take a few minutes to complete. Please be patient."
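+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Optional: Group-Relative Advantages, Illustrated\n", + "\n", + "The cell below is a minimal, self-contained NumPy sketch of the group-relative advantage computation described above. It is illustrative only: the helper name is hypothetical and this is not the MaxText/Tunix implementation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "def group_relative_advantages(rewards):\n", + " # rewards: (num_prompts, group_size), one scalar reward per sampled response\n", + " mean = rewards.mean(axis=-1, keepdims=True)\n", + " std = rewards.std(axis=-1, keepdims=True) + 1e-6 # avoid division by zero\n", + " # Responses scoring above their group's mean get a positive advantage.\n", + " return (rewards - mean) / std\n", + "\n", + "# Two correct and two incorrect responses sampled for one prompt:\n", + "print(group_relative_advantages(np.array([[0.0, 1.0, 1.0, 0.0]])))"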
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "if IN_COLAB:\n", + " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", + " %cd /content/maxtext\n", + "\n", + " # Install uv, a fast Python package installer\n", + " !pip install uv\n", + "\n", + " # Install MaxText and its dependencies\n", + " !uv pip install -e .[tpu] --resolution=lowest\n", + " !python3 -m MaxText.install_maxtext_extra_deps\n", + "\n", + " # Install vLLM for Jax and TPUs\n", + " !uv pip install vllm-tpu\n", + " !uv pip install --no-deps qwix==0.1.4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Restart Session (only if running on Google Colab)\n", + "To apply certain changes, you need to restart the session.\n", + "\n", + "**Instructions:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Restart session** from the dropdown menu.\n", + "\n", + "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Environment Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import datetime\n", + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "import MaxText\n", + "from huggingface_hub import login\n", + "import jax\n", + "\n", + "from MaxText import max_utils\n", + "from MaxText.rl.train_rl import rl_train, setup_configs_and_devices\n", + "\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n", + "os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n", + "\n", + "MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n", + "MAXTEXT_REPO_ROOT = os.sep.join([\"maxtext\" if p == \"MaxText\" else p for p in MAXTEXT_PKG_DIR.split(os.sep)])\n", + "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not jax.distributed.is_initialized():\n", + " jax.distributed.initialize()\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX devices: {jax.devices()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if IN_COLAB:\n", + " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", + "else:\n", + " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", + "\n", + "# If not found in the environment, prompt the user for input securely\n", + "# getpass function ensures the token is hidden while you type\n", + "if not HF_TOKEN:\n", + " from getpass import getpass\n", + " HF_TOKEN = getpass(\"Hugging Face token not found in environment. 
Please enter it here: \")\n", + "\n", + "if HF_TOKEN:\n", + " os.environ[\"HF_TOKEN\"] = HF_TOKEN\n", + " login(token=HF_TOKEN)\n", + " print(\"Authenticated with Hugging Face successfully!\")\n", + "else:\n", + " print(\"Authentication failed: Hugging Face token is not set.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_NAME = \"llama3.1-8b\"\n", + "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", + "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n", + "LOSS_ALGO = \"grpo\" # or \"gspo-token\" if you want to use GSPO\n", + "\n", + "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/gsm8k_rl.json\"\n", + "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", + " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", + "\n", + "# set the path to the model checkpoint or leave empty to download from HuggingFace\n", + "MODEL_CHECKPOINT_PATH = \"\"\n", + "if not MODEL_CHECKPOINT_PATH:\n", + " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_PKG_DIR}/llama_checkpoint\"\n", + " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", + " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", + "\n", + "OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/rl_llama3_output\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download Llama3.1-8B Model Checkpoint from Hugging Face" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", + " # install torch for the conversion script\n", + " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", + "\n", + " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n", + " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n", + " model_name={MODEL_NAME} \\\n", + " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", + " hf_access_token={HF_TOKEN} \\\n", + " use_multimodal=false \\\n", + " scan_layers=true \\\n", + " skip_jax_distributed_system=True\n", + "\n", + "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", + " raise ValueError(\"Model checkpoint conversion failed.
Check the logs above.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MaxText Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load configuration for RL training\n", + "config_argv = [\n", + " \"\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/rl.yml\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"tokenizer_path={TOKENIZER_PATH}\",\n", + " f\"run_name={RUN_NAME}\",\n", + " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", + " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", + " f\"base_output_directory={OUTPUT_DIRECTORY}\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " \"debug.rl=False\",\n", + " f\"rl.loss_algo={LOSS_ALGO}\",\n", + " \"use_pathways=False\"\n", + "]\n", + "\n", + "trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(config_argv)\n", + "\n", + "rl_train_steps = int(\n", + " trainer_config.num_batches\n", + " * trainer_config.rl.num_iterations\n", + " * trainer_config.train_fraction\n", + " * trainer_config.num_epoch\n", + ")\n", + "\n", + "print(\"✓ Configuration initialized successfully\")\n", + "print(f\"📁 Output directory: {trainer_config.base_output_directory}\")\n", + "print(f\"🤖 Model: {trainer_config.model_name}\")\n", + "print(f\"📊 RL Train Steps: {rl_train_steps}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## RL Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n\" + \"=\" * 80)\n", + "print(f\"🚀 Starting {LOSS_ALGO} Training...\")\n", + "print(\"=\" * 80)\n", + "try:\n", + " rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"✅ Training Completed Successfully!\")\n", + " print(f\"✍️ Note the improved evaluation accuracy metrics with just {rl_train_steps} RL training steps!\")\n", + " print(\"=\" * 80)\n", + " print(f\"📁 Checkpoints saved to: {trainer_config.checkpoint_dir}\")\n", + " print(f\"📊 TensorBoard logs: {trainer_config.tensorboard_dir}\")\n", + " print(\"🎯 Model ready for inference!\")\n", + "except Exception as e:\n", + " print(\"\\n\" + \"=\" * 80)\n", + " print(\"❌ Training Failed!\")\n", + " print(\"=\" * 80)\n", + " print(f\"Error: {str(e)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 📚 Learn More\n", + "\n", + "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/rl.html\n", + "- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n", + "- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/maxtext/examples/sft_llama3_demo.ipynb b/src/maxtext/examples/sft_llama3_demo.ipynb new file mode 100644 index 0000000000..0b7dd227ce --- /dev/null +++ b/src/maxtext/examples/sft_llama3_demo.ipynb @@ -0,0 +1,367 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb)\n", + "\n", + "# Llama3.1-8B-Instruct Supervised Fine-Tuning (SFT) Demo\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "This notebook demonstrates how to perform Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct using the Hugging Face `ultrachat_200k` dataset, with the MaxText and Tunix integration for efficient training. A minimal sketch of the SFT objective appears just before the installation cell below.\n", + "\n", + "This notebook can run on **TPU v5e-8** or **v5p-8**." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "\n", + "### Change Runtime Type (only if running on Google Colab)\n", + "\n", + "**Instructions:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Change runtime type** from the dropdown menu.\n", + "4. Select **v5e-8** or **v5p-8 TPU** as the **Hardware accelerator**.\n", + "5. Click on **Save**." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get Your Hugging Face Token\n", + "\n", + "To access model checkpoints from the Hugging Face Hub, you need to authenticate with a personal access token.\n", + "\n", + "**Follow these steps to get your token:**\n", + "\n", + "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", + " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", + "\n", + "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", + "\n", + "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", + "\n", + "4. **Copy the generated token**. You will need it in later steps.\n", + "\n", + "**Follow these steps to store your token (only if running on Google Colab):**\n", + "\n", + "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", + "\n", + "2. Click **\"+ Add new secret\"**.\n", + "\n", + "3. Set the Name as **HF_TOKEN**.\n", + "\n", + "4. Paste your token into the Value field.\n", + "\n", + "5. Ensure the Notebook access toggle is turned On." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from google.colab import userdata\n", + " print(\"Running the notebook on Google Colab\")\n", + " IN_COLAB = True\n", + "except ImportError:\n", + " print(\"Running the notebook on Visual Studio Code or JupyterLab\")\n", + " IN_COLAB = False" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation: MaxText & Other Dependencies" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**⚠️ Note:** The installation process in the following cell may take a few minutes to complete. Please be patient."
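+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Optional: The SFT Objective, Illustrated\n", + "\n", + "The cell below is a minimal, self-contained NumPy sketch of the SFT objective: next-token cross-entropy where prompt tokens are masked out so that only completion tokens contribute to the loss. The names are hypothetical and this is not the MaxText implementation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "def sft_loss(logits, targets, completion_mask):\n", + " # logits: (seq_len, vocab); targets: (seq_len,); completion_mask: 1.0 on completion tokens\n", + " shifted = logits - logits.max(axis=-1, keepdims=True) # numerically stable log-softmax\n", + " logp = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n", + " token_logp = logp[np.arange(targets.shape[0]), targets]\n", + " # Average negative log-likelihood over completion tokens only.\n", + " return -(token_logp * completion_mask).sum() / completion_mask.sum()\n", + "\n", + "rng = np.random.default_rng(0)\n", + "logits = rng.normal(size=(4, 8)) # 4 tokens, vocabulary of 8\n", + "targets = np.array([1, 3, 0, 2])\n", + "mask = np.array([0.0, 0.0, 1.0, 1.0]) # first two tokens are prompt\n", + "print(sft_loss(logits, targets, mask))"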
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if IN_COLAB:\n", + " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", + " %cd /content/maxtext\n", + "\n", + " # Install uv, a fast Python package installer\n", + " !pip install uv\n", + "\n", + " # Install MaxText and its dependencies\n", + " !uv pip install -e .[tpu] --resolution=lowest\n", + " !python3 -m MaxText.install_maxtext_extra_deps" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Restart Session (only if running on Google Colab)\n", + "To apply certain changes, you need to restart the session.\n", + "\n", + "**Instructions:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Restart session** from the dropdown menu.\n", + "\n", + "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Environment Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import datetime\n", + "import os\n", + "import sys\n", + "import MaxText\n", + "from MaxText import pyconfig\n", + "from maxtext.trainers.post_train.sft import train_sft\n", + "import jax\n", + "from huggingface_hub import login\n", + "\n", + "MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n", + "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not jax.distributed.is_initialized():\n", + " jax.distributed.initialize()\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX devices: {jax.devices()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if IN_COLAB:\n", + " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", + "else:\n", + " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", + "\n", + "# If not found in the environment, prompt the user for input securely\n", + "# getpass function ensures the token is hidden while you type\n", + "if not HF_TOKEN:\n", + " from getpass import getpass\n", + " HF_TOKEN = getpass(\"Hugging Face token not found in environment. 
Please enter it here: \")\n", + "\n", + "if HF_TOKEN:\n", + " login(token=HF_TOKEN)\n", + " print(\"Authenticated with Hugging Face successfully!\")\n", + "else:\n", + " print(\"Authentication failed: Hugging Face token is not set.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_NAME = \"llama3.1-8b\"\n", + "TOKENIZER_PATH = \"meta-llama/Llama-3.1-8B-Instruct\"\n", + "\n", + "# set the path to the model checkpoint or leave empty to download from HuggingFace\n", + "MODEL_CHECKPOINT_PATH = \"\"\n", + "if not MODEL_CHECKPOINT_PATH:\n", + " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_PKG_DIR}/llama_checkpoint\"\n", + " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", + " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", + "\n", + "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/sft_llama3_output\"\n", + "RUN_NAME = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download Llama3.1-8B Model Checkpoint from Hugging Face" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", + " # install torch for the conversion script\n", + " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", + "\n", + " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n", + " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n", + " model_name={MODEL_NAME} \\\n", + " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", + " hf_access_token={HF_TOKEN} \\\n", + " use_multimodal=false \\\n", + " scan_layers=true \\\n", + " skip_jax_distributed_system=True\n", + "\n", + "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", + " raise ValueError(\"Model checkpoint conversion failed. 
Check the logs above.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MaxText Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "In-jdp1AAwrL" + }, + "outputs": [], + "source": [ + "# Load configuration for SFT training\n", + "config_argv = [\n", + " \"\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/sft.yml\",\n", + " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " \"steps=100\",\n", + " \"per_device_batch_size=1\",\n", + " \"max_target_length=1024\",\n", + " \"learning_rate=2.0e-5\",\n", + " \"weight_dtype=bfloat16\",\n", + " \"dtype=bfloat16\",\n", + " \"hf_path=HuggingFaceH4/ultrachat_200k\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", + " f\"run_name={RUN_NAME}\",\n", + " f\"tokenizer_path={TOKENIZER_PATH}\",\n", + " \"profiler=xplane\",\n", + "]\n", + "\n", + "config = pyconfig.initialize(config_argv)\n", + "\n", + "print(\"✓ SFT configuration loaded:\")\n", + "print(f\" Model: {config.model_name}\")\n", + "print(f\" Training Steps: {config.steps}\")\n", + "print(f\" Output Directory: {config.base_output_directory}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mgwpNgQYCJEd" + }, + "outputs": [], + "source": [ + "print(\"=\" * 60)\n", + "print(\"🚀 Starting SFT Training...\")\n", + "print(\"=\" * 60)\n", + "\n", + "try:\n", + " trainer, mesh = train_sft.train(config)\n", + "\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"✅ Training Completed Successfully!\")\n", + " print(\"=\" * 60)\n", + " print(f\"📁 Checkpoints saved to: {config.checkpoint_dir}\")\n", + "except Exception as e:\n", + " print(\"\\n\" + \"=\" * 60)\n", + " print(\"❌ Training Failed!\")\n", + " print(\"=\" * 60)\n", + " print(f\"Error: {str(e)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 📚 Learn More\n", + "\n", + "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n", + "- **Configuration**: See `src/MaxText/configs/sft.yml` for all available options\n", + "- **Documentation**: Check `src/maxtext/trainers/post_train/sft/train_sft.py` for the `train` function implementation" + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V5E1", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/src/maxtext/examples/sft_qwen3_demo.ipynb b/src/maxtext/examples/sft_qwen3_demo.ipynb new file mode 100644 index 0000000000..bd44c741c0 --- /dev/null +++ b/src/maxtext/examples/sft_qwen3_demo.ipynb @@ -0,0 +1,624 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "1nb_Ppf2ZUQL" + }, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_qwen3_demo.ipynb)\n", + "\n", + "# Qwen3-0.6B Supervised Fine-Tuning (SFT) Demo\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": 
"FGbe4_YQZUQL" + }, + "source": [ + "## Overview\n", + "\n", + "This notebook performs SFT training and evaluation workflow on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).\n", + "The primary goal is to demonstrate the end-to-end process of:\n", + "1. Pre-SFT Evaluation: Calcuating baseline accuracy for the model before training.\n", + "2. SFT Training: Fine-tune the model using MaxText & Tunix SFT trainer.\n", + "3. Post-SFT Evaluation: Re-running the evaluation loop after training to measure the performance gain achieved by SFT.\n", + "\n", + "This notebook can run on the **public TPU v5e-1**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zolxPWhQZUQL" + }, + "source": [ + "## Prerequisites\n", + "\n", + "### Change Runtime Type (only if running on Google Colab)\n", + "\n", + "**Instructions:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Change runtime type** from the dropdown menu.\n", + "4. Select **v5e-1 TPU** as the **Hardware accelerator**.\n", + "5. Click on **Save**." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Rk_QpVVuZUQL" + }, + "source": [ + "### Get Your Hugging Face Token\n", + "\n", + "To access model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.\n", + "\n", + "**Follow these steps to get your token:**\n", + "\n", + "1. **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:\n", + " * [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", + "\n", + "2. **Create a new token** by clicking the **\"+ Create new token\"** button.\n", + "\n", + "3. **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.\n", + "\n", + "4. **Copy the generated token**. You will need this in the later steps.\n", + "\n", + "**Follow these steps to store your token (only if running on Google Colab):**\n", + "\n", + "1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).\n", + "\n", + "2. Click **\"+ Add new secret\"**.\n", + "\n", + "3. Set the Name as **HF_TOKEN**.\n", + "\n", + "4. Paste your token into the Value field.\n", + "\n", + "5. Ensure the Notebook access toggle is turned On." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " from google.colab import userdata\n", + " print(\"Running the notebook on Google Colab\")\n", + " IN_COLAB = True\n", + "except ImportError:\n", + " print(\"Running the notebook on Visual Studio or JupyterLab\")\n", + " IN_COLAB = False" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D9ms-jTSZUQL" + }, + "source": [ + "## Installation: MaxText & Other Dependencies" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**⚠️ Note:** The installation process in following cell may take a few minutes to complete. Please be patient." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OSPRVbi7n6tB" + }, + "outputs": [], + "source": [ + "if IN_COLAB:\n", + " !git clone https://github.com/AI-Hypercomputer/maxtext.git\n", + " %cd /content/maxtext\n", + "\n", + " # Install uv, a fast Python package installer\n", + " !pip install uv\n", + "\n", + " # Install MaxText and its dependencies\n", + " !uv pip install -e .[tpu] --resolution=lowest\n", + " !python3 -m MaxText.install_maxtext_extra_deps\n", + "\n", + " # Install Tunix for post-training notebooks\n", + " !uv pip install git+https://github.com/google/tunix\n", + "\n", + " # Install vLLM for post-training notebooks\n", + " !git clone https://github.com/vllm-project/vllm.git\n", + " !VLLM_TARGET_DEVICE=\"tpu\" uv pip install ./vllm\n", + "\n", + " # Install tpu-inference for post-training notebooks\n", + " !git clone https://github.com/vllm-project/tpu-inference.git\n", + " !uv pip install ./tpu-inference\n", + " !uv pip install --no-deps qwix==0.1.4" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ywtealAxZUQM" + }, + "source": [ + "### Restart Session (only if running on Google Colab)\n", + "To apply certain changes, you need to restart the session.\n", + "\n", + "**Instructions:**\n", + "1. Navigate to the menu at the top of the screen.\n", + "2. Click on **Runtime**.\n", + "3. Select **Restart session** from the dropdown menu.\n", + "\n", + "You will be asked to confirm the action in a pop-up dialog. Click on **Yes**."
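+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Optional: What a Chat Template Does\n", + "\n", + "Later in this notebook, GSM8K questions are formatted with the chat template at `examples/chat_templates/math_qa.json`. As a rough illustration of prompt templating in general, the optional cell below applies the tokenizer's built-in chat template to a toy message. This uses the standard Hugging Face `apply_chat_template` API and is not the template used by this demo's data pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "tok = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-0.6B\")\n", + "messages = [{\"role\": \"user\", \"content\": \"What is 2 + 3?\"}]\n", + "# Render the conversation into the plain-text prompt the model actually sees.\n", + "print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))"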
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Clexf-j7ZUQM" + }, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PkBI9A3JZUQM" + }, + "outputs": [], + "source": [ + "import jax\n", + "import os\n", + "import sys\n", + "import transformers\n", + "\n", + "import MaxText\n", + "from MaxText import pyconfig\n", + "from maxtext.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n", + "from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n", + "from maxtext.trainers.post_train.sft import train_sft\n", + "\n", + "# Suppress vLLM logging with a severity level below ERROR\n", + "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n", + "from tunix.rl.rollout import base_rollout\n", + "from tunix.rl.rollout.vllm_rollout import VllmRollout\n", + "\n", + "from datetime import datetime\n", + "from flax import nnx\n", + "from huggingface_hub import login\n", + "\n", + "MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n", + "MAXTEXT_REPO_ROOT = os.sep.join([\"maxtext\" if p == \"MaxText\" else p for p in MAXTEXT_PKG_DIR.split(os.sep)])\n", + "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not jax.distributed.is_initialized():\n", + " jax.distributed.initialize()\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX devices: {jax.devices()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JBbPN-uVZUQM" + }, + "outputs": [], + "source": [ + "try:\n", + " from google.colab import userdata\n", + " HF_TOKEN = userdata.get(\"HF_TOKEN\")\n", + "except ImportError:\n", + " HF_TOKEN = os.environ.get(\"HF_TOKEN\", \"\")\n", + "\n", + "# If not found in the environment, prompt the user for input securely\n", + "# getpass function ensures the token is hidden while you type\n", + "if not HF_TOKEN:\n", + " from getpass import getpass\n", + " HF_TOKEN = getpass(\"Hugging Face token not found in environment. 
Please enter it here: \")\n", + "\n", + "if HF_TOKEN:\n", + " login(token=HF_TOKEN)\n", + " print(\"Authenticated with Hugging Face successfully!\")\n", + "else:\n", + " print(\"Authentication failed: Hugging Face token is not set.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aENuzm9iZUQM" + }, + "source": [ + "## Model Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RjPYYl3zZUQM" + }, + "outputs": [], + "source": [ + "MODEL_NAME = \"qwen3-0.6b\"\n", + "TOKENIZER_PATH = \"Qwen/Qwen3-0.6B\"\n", + "tokenizer = transformers.AutoTokenizer.from_pretrained(\n", + " TOKENIZER_PATH,\n", + " token=HF_TOKEN,\n", + ")\n", + "\n", + "# set the path to the model checkpoint or leave empty to download from HuggingFace\n", + "MODEL_CHECKPOINT_PATH = \"\"\n", + "if not MODEL_CHECKPOINT_PATH:\n", + " MODEL_CHECKPOINT_PATH = f\"{MAXTEXT_PKG_DIR}/qwen_checkpoint\"\n", + " print(\"Model checkpoint will be downloaded from HuggingFace at: \", MODEL_CHECKPOINT_PATH)\n", + " print(\"Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.\")\n", + "\n", + "RUN_NAME = datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n", + "\n", + "# This is the directory where the fine-tuned model checkpoint will be saved\n", + "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/maxtext_qwen06_output\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4L37Ij4NZUQM" + }, + "source": [ + "## Download Qwen3-0.6B Model Checkpoint from Hugging Face" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kJanDAc0ZUQM" + }, + "outputs": [], + "source": [ + "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", + " # install torch for the conversion script\n", + " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", + "\n", + " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n", + " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n", + " model_name={MODEL_NAME} \\\n", + " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", + " hf_access_token={HF_TOKEN} \\\n", + " use_multimodal=false \\\n", + " scan_layers=true \\\n", + " skip_jax_distributed_system=True\n", + "\n", + "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", + " raise ValueError(\"Model checkpoint conversion failed. 
Check the logs above.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PC-hILG0ZUQM" + }, + "source": [ + "## Dataset Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "O3MLdr9kZUQM" + }, + "outputs": [], + "source": [ + "DATASET_NAME = \"openai/gsm8k\"\n", + "TRAIN_DATA_SPLIT = \"train\"\n", + "TEST_DATA_SPLIT = \"test\"\n", + "HF_DATA_DIR = \"main\"\n", + "TRAIN_DATA_COLUMNS = [\"question\", \"answer\"]\n", + "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/math_qa.json\"\n", + "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", + " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", + "NUM_TEST_SAMPLES = 20 # Total number of samples to test\n", + "BATCH_SIZE = 1 # Number of test samples to process in a batch" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yeAHmxSYZUQM" + }, + "source": [ + "## MaxText Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "In-jdp1AAwrL" + }, + "outputs": [], + "source": [ + "%%capture\n", + "config = pyconfig.initialize(\n", + " [\n", + " \"\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/sft.yml\",\n", + " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", + " f\"run_name={RUN_NAME}\",\n", + " f\"tokenizer_path={TOKENIZER_PATH}\",\n", + " f\"hf_path={DATASET_NAME}\",\n", + " f\"train_split={TRAIN_DATA_SPLIT}\",\n", + " f\"hf_data_dir={HF_DATA_DIR}\",\n", + " f\"train_data_columns={TRAIN_DATA_COLUMNS}\",\n", + " \"steps=500\",\n", + " \"per_device_batch_size=1\",\n", + " \"max_target_length=1024\",\n", + " \"learning_rate=3e-6\",\n", + " \"weight_dtype=bfloat16\",\n", + " \"dtype=bfloat16\",\n", + " f\"chat_template_path={CHAT_TEMPLATE_PATH}\",\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O9b0GWo-ZUQM" + }, + "source": [ + "## Initial Setup & Data Preparation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TDqFmvUCZUQM" + }, + "source": [ + "### Create Test Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wscWYxrtZUQM" + }, + "outputs": [], + "source": [ + "test_dataset = get_test_dataset(config, tokenizer)\n", + "test_dataset = test_dataset[:NUM_TEST_SAMPLES]\n", + "test_dataset = test_dataset.to_iter_dataset().batch(BATCH_SIZE, drop_remainder=True)\n", + "TOTAL_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE\n", + "print(\n", + " f\"Processing {NUM_TEST_SAMPLES} examples with a batch size of {BATCH_SIZE}. 
This will result in {TOTAL_BATCHES} total batches for the test run.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bLSvOOEUZUQM" + }, + "source": [ + "### Create SFT Trainer State" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2IHsC0m6ZUQM" + }, + "outputs": [], + "source": [ + "trainer, mesh = train_sft.setup_trainer_state(config)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PpKtEqzFZUQM" + }, + "source": [ + "### Create vLLM Rollout" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3-pf_rbqZUQM" + }, + "outputs": [], + "source": [ + "tunix_model = TunixMaxTextAdapter(trainer.model)\n", + "vllm_rollout = VllmRollout(\n", + " model=tunix_model,\n", + " tokenizer=tokenizer,\n", + " cache_config_or_size=1280,\n", + " mesh=mesh,\n", + " rollout_config=base_rollout.RolloutConfig(\n", + " rollout_vllm_model_version=TOKENIZER_PATH,\n", + " rollout_vllm_hbm_utilization=0.8,\n", + " rollout_vllm_init_with_random_weights=True,\n", + " rollout_vllm_tpu_backend_type=\"jax\",\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "567gTxsEZUQM" + }, + "source": [ + "## Evaluation before SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OnACa3zCZUQM" + }, + "outputs": [], + "source": [ + "print(\"Running Pre-SFT Evaluation...\")\n", + "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "u5-M4iYkZUQN" + }, + "outputs": [], + "source": [ + "print(\"========================= Score for PRE-SFT Evaluation =========================\")\n", + "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", + "print(\n", + " f\"Percentage of test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", + ")\n", + "print(\n", + " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EJE1ookSAzz-" + }, + "source": [ + "## SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "id": "mgwpNgQYCJEd", + "tags": [] + }, + "outputs": [], + "source": [ + "print(\"Starting SFT Training...\")\n", + "trainer = train_sft.train_model(config, trainer, mesh)\n", + "print(\"SFT Training Complete!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WEdNYRhwZUQN" + }, + "source": [ + "## Evaluation after SFT Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XcsZacZdZUQN" + }, + "outputs": [], + "source": [ + "print(\"Running Post-SFT Evaluation...\")\n", + "model = TunixMaxTextAdapter(trainer.model)\n", + "state = nnx.state(model)\n", + "vllm_rollout.update_params(state)\n", + "score = evaluate_model(test_dataset, vllm_rollout, debug=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "id": "-JtYTPvJZUQN", + "tags": [] + }, + "outputs": [], + "source": [ + "print(\"========================= Score for POST-SFT Evaluation =========================\")\n", + "print(f\"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%\")\n", + "print(\n", + " f\"Percentage of 
test samples where the model produced the numerical answer within 10%: {score['partially_correct']}%\"\n", + ")\n", + "print(\n", + " f\"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%\"\n", + ")" + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V5E1", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/src/MaxText/examples/sft_train_and_evaluate.py b/src/maxtext/examples/sft_train_and_evaluate.py similarity index 98% rename from src/MaxText/examples/sft_train_and_evaluate.py rename to src/maxtext/examples/sft_train_and_evaluate.py index b38985fa38..7263169362 100644 --- a/src/MaxText/examples/sft_train_and_evaluate.py +++ b/src/maxtext/examples/sft_train_and_evaluate.py @@ -35,7 +35,7 @@ export MODEL_CHECKPOINT_PATH= export HF_ACCESS_TOKEN= -python3 -m MaxText.examples.sft_train_and_evaluate MaxText/configs/sft.yml \ +python3 -m maxtext.examples.sft_train_and_evaluate MaxText/configs/sft.yml \ run_name=$RUN_NAME base_output_directory=$OUTPUT_PATH \ model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH \ hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH @@ -67,7 +67,7 @@ --workload=sft-${RUN_NAME} \ --tpu-type ${TPU_TYPE} --num-slices=1 --zone=${ZONE} \ --project=${PROJECT} \ ---command "HF_TOKEN=$HF_ACCESS_TOKEN python3 -m MaxText.examples.sft_train_and_evaluate MaxText/configs/sft.yml \ +--command "HF_TOKEN=$HF_ACCESS_TOKEN python3 -m maxtext.examples.sft_train_and_evaluate MaxText/configs/sft.yml \ run_name=$RUN_NAME base_output_directory=$OUTPUT_PATH \ model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH \ hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH" @@ -125,7 +125,7 @@ ) # Regex to extract the final numerical answer MATCH_ANSWER = re.compile(rf"{ANSWER_START}.*?([\d\.\,\$]{{1,}})", flags=re.MULTILINE | re.DOTALL) -CHAT_TEMPLATE_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "examples", "chat_templates", "math_qa.json") +CHAT_TEMPLATE_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "examples", "chat_templates", "math_qa.json") def get_test_dataset(config, tokenizer): diff --git a/pedagogical_examples/shardings.py b/src/maxtext/examples/shardings.py similarity index 100% rename from pedagogical_examples/shardings.py rename to src/maxtext/examples/shardings.py diff --git a/pedagogical_examples/shmap_collective_matmul.py b/src/maxtext/examples/shmap_collective_matmul.py similarity index 100% rename from pedagogical_examples/shmap_collective_matmul.py rename to src/maxtext/examples/shmap_collective_matmul.py diff --git a/tests/integration/shmap_collective_matmul_test.py b/tests/integration/shmap_collective_matmul_test.py index e89edfa5a7..8c966398dc 100644 --- a/tests/integration/shmap_collective_matmul_test.py +++ b/tests/integration/shmap_collective_matmul_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Integration test for pedagogical_examples/shmap_collective_matmul.py""" +"""Integration test for maxtext/examples/shmap_collective_matmul.py""" import os.path import sys @@ -24,7 +24,7 @@ sys.path.append(os.path.join(MAXTEXT_REPO_ROOT, "pedagogical_examples")) # Uncomment the import when b/415022795 is fixed -# from pedagogical_examples.shmap_collective_matmul import main +# from maxtext.examples.shmap_collective_matmul import main @pytest.mark.skip(reason="Enable when b/415022795 is fixed")