From 2d7b6e05ac76b9b306c641ab54f49527427eef13 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Thu, 14 May 2026 17:01:37 -0700 Subject: [PATCH 1/3] Feature: add DPO and ORPO preference data preprocessing pipeline utilities Includes robust common prefix extraction for 2-column datasets, prompt suffix truncation, customizable max_prompt_length with validation against max_target_length, and complete integration unit test coverage. --- src/maxtext/configs/base.yml | 16 +- src/maxtext/configs/post_train/dpo.yml | 8 +- src/maxtext/configs/types.py | 91 ++++- src/maxtext/input_pipeline/dpo_utils.py | 101 +++++ .../input_pipeline/hf_data_processing.py | 29 +- .../unit/dpo_data_processing_test.py | 370 ++++++++++++++++++ 6 files changed, 574 insertions(+), 41 deletions(-) create mode 100644 src/maxtext/input_pipeline/dpo_utils.py create mode 100644 tests/post_training/unit/dpo_data_processing_test.py diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 5667b6ec00..45588b0c8b 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -86,7 +86,7 @@ checkpoint_conversion_fn: none # optional checkpoint context to use for loading. options: "orbax", "safetensors" source_checkpoint_layout: "orbax" -# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing +# Only applicable to Single Controller/Pathways on Cloud. Experimental feature, under testing colocated_python_checkpointing: False # enables autocheckpoint, which saves a checkpoint at the preemption step. @@ -448,7 +448,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' # internal_compile allows bypassing open-source topology name mappings when using internal topologies directly via get_topology_desc. internal_compile: False internal_compile_num_devices: -1 # You must specify the number of devices when using internal_compile. -compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_sparse_cores_for_gather_offloading=1 --xla_tpu_scoped_vmem_limit_kib=65536" +compile_xla_flags: "" # Compiler options e.g. compile_xla_flags="--xla_tpu_num_sparse_cores_for_gather_offloading=1 --xla_tpu_scoped_vmem_limit_kib=65536" # Parallelism shard_mode: "auto" # can be either auto or explicit @@ -677,8 +677,12 @@ global_rampup_samples: 500 # direct preference optimization (DPO) use_dpo: False -dpo_label_smoothing: 0.0 -dpo_beta: 0.1 +dpo: + algo: 'dpo' + orpo_lambda: 0.1 + dpo_label_smoothing: 0.0 + dpo_beta: 0.1 + max_prompt_length: null # Supervised Fine-Tuning (SFT) use_sft: False @@ -1203,7 +1207,7 @@ use_jax_splash: false # Path to the HuggingFace-style config directory for the adapter (e.g. src/maxtext/integration/vllm/maxtext_vllm_adapter) vllm_hf_config_path: "" # A JSON string of overrides to apply to the HuggingFace-style config for the vLLM adapter. -# This can be used to override specific settings without modifying the original config file. +# This can be used to override specific settings without modifying the original config file. vllm_hf_overrides: {} # JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}') vllm_additional_config: {} @@ -1218,7 +1222,7 @@ sinkhorn_iterations: 20 ################################## DeepSeek Engram ################################## # Indices of transformer layers where Engram are integrated; leave empty [] to disable. -# Example: [1, 4] attaches to the 2nd and 5th layer. +# Example: [1, 4] attaches to the 2nd and 5th layer. 
engram_layers: [] # The max 'n' in N-gram. Example: n=3 means it covers both 2-grams and 3-grams. engram_max_ngram_size: 3 diff --git a/src/maxtext/configs/post_train/dpo.yml b/src/maxtext/configs/post_train/dpo.yml index 867d9ed228..5307b845ce 100644 --- a/src/maxtext/configs/post_train/dpo.yml +++ b/src/maxtext/configs/post_train/dpo.yml @@ -1,6 +1,12 @@ base_config: "base.yml" use_dpo: true +dpo: + algo: 'dpo' + orpo_lambda: 0.1 + dpo_label_smoothing: 0.0 + dpo_beta: 0.1 + max_prompt_length: null packing: false train_data_columns: ['chosen', 'rejected'] eval_data_columns: ['chosen', 'rejected'] @@ -24,8 +30,6 @@ hf_eval_split: 'test' gradient_clipping_threshold: 10.0 learning_rate: 5.0e-7 -dpo_label_smoothing: 0.0 -dpo_beta: 0.1 enable_goodput_recording: false monitor_goodput: false diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index a0f436dff3..4fe43c7052 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -505,7 +505,10 @@ class ModelArchitecture(BaseModel): True, description="Whether to apply scale on query and key normalizations (default True).", ) - v_norm_with_scale: bool = Field(True, description="Whether to apply scale on value normalization (default True).") + v_norm_with_scale: bool = Field( + True, + description="Whether to apply scale on value normalization (default True).", + ) class MTP(BaseModel): @@ -684,14 +687,18 @@ class MoEGeneral(BaseModel): num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.") num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.") capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.") - ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.") + ragged_buffer_factor: float = Field( + -1.0, + description="Ragged buffer factor. If < 0, ragged buffer is worst case size.", + ) moe_expert_input_dim: int = Field( -1, description="Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim.", ) base_moe_mlp_dim: int = Field(-1, description="Intermediate dimension at MoE layer.") padded_base_moe_mlp_dim: Optional[int] = Field( - None, description="Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution." + None, + description="Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution.", ) load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.") use_custom_sort_vjp: bool = Field( @@ -868,7 +875,8 @@ class HardwareAndMesh(BaseModel): ) custom_mesh: str = Field("", description="Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']") custom_mesh_and_rule: CustomRule = Field( - CustomRule.DEFAULT, description="Customized mesh and logical rules for granularity." 
+ CustomRule.DEFAULT, + description="Customized mesh and logical rules for granularity.", ) allow_split_physical_axes: bool = Field(False, description="Allow splitting physical axes for device mesh creation.") enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.") @@ -877,7 +885,8 @@ class HardwareAndMesh(BaseModel): pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.") pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.") remove_size_one_mesh_axis_from_type: bool = Field( - True, description="Whether to remove size one mesh axis from type through jax.config." + True, + description="Whether to remove size one mesh axis from type through jax.config.", ) @@ -898,7 +907,10 @@ class LayoutAndSharding(BaseModel): description="Allowed percentage of non-sharded parameters.", ) shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.") - internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.") + internal_compile: bool = Field( + False, + description="Use internal_compile to bypass open-source topology mappings.", + ) internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.") compile_xla_flags: str = Field("", description="Compiler options for compilation only.") @@ -945,7 +957,8 @@ class PipelineParallelism(BaseModel): """Configuration for pipeline parallelism.""" pipeline_fsdp_ag_per_repeat: bool = Field( - False, description="Enable weight prefetching for circular pipeline parallelism." + False, + description="Enable weight prefetching for circular pipeline parallelism.", ) num_layers_per_pipeline_stage: int = Field(1, description="Number of layers to place on each pipeline stage.") num_pipeline_repeats: int = Field( @@ -1189,7 +1202,10 @@ class OlmoGrainDataset(BaseModel): ``data_shuffle_seed``); only OLMo-specific fields are listed here. """ - olmo_index_path: PathStr = Field("", description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.") + olmo_index_path: PathStr = Field( + "", + description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.", + ) olmo_path_remap_from: PathStr = Field( "", description="If set, rewrite index file paths starting with this prefix to olmo_path_remap_to.", @@ -1201,12 +1217,23 @@ class OlmoGrainDataset(BaseModel): olmo_apply_ngram_filter: bool = Field(True, description="Mask repetitive instances per OLMo-core's repetition filter.") +class DPO(BaseModel): + """Configuration for DPO and ORPO preference optimization algorithms.""" + + algo: Literal["dpo", "orpo"] = Field("dpo", description="Alignment algorithm to use.") + dpo_beta: float = Field(0.1, description="Beta parameter for DPO.") + orpo_lambda: float = Field(0.1, description="Weight for preference loss in ORPO.") + dpo_label_smoothing: float = Field(0.0, ge=0.0, le=1.0, description="Label smoothing for DPO.") + max_prompt_length: int | None = Field( + None, + description="Maximum length for prompt. 
If None, defaults to half of max_target_length.", + ) + + class FineTuning(BaseModel): """Configuration for fine-tuning methods like DPO, SFT, and GRPO.""" use_dpo: bool = Field(False, description="If True, enables Direct Preference Optimization training.") - dpo_label_smoothing: float = Field(0.0, ge=0.0, le=1.0, description="Label smoothing for DPO.") - dpo_beta: float = Field(0.1, description="Beta parameter for DPO.") use_sft: bool = Field(False, description="If True, enables Supervised Fine-Tuning.") sft_train_on_completion_only: bool = Field( False, description="If True, trains only on the completion part of the text." @@ -1274,19 +1301,24 @@ class Distillation(BaseModel): distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.") distill_alpha_end: Optional[float] = Field(None, description="Target alpha at end of training. None keeps alpha fixed.") distill_alpha_schedule: Literal["constant", "linear", "cosine"] = Field( - "constant", description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine')." + "constant", + description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine').", ) distill_temperature_end: Optional[float] = Field( - None, description="Target temperature at end of training. None keeps temperature fixed." + None, + description="Target temperature at end of training. None keeps temperature fixed.", ) distill_temperature_schedule: Literal["constant", "linear", "cosine"] = Field( - "constant", description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine')." + "constant", + description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine').", ) distill_beta_end: Optional[float] = Field( - None, description="Target beta_feature at end of training. None keeps beta fixed." + None, + description="Target beta_feature at end of training. None keeps beta fixed.", ) distill_beta_schedule: Literal["constant", "linear", "cosine"] = Field( - "constant", description="Schedule type for beta annealing ('constant', 'linear', or 'cosine')." + "constant", + description="Schedule type for beta annealing ('constant', 'linear', or 'cosine').", ) # --- Learn to init related parameters -- @@ -1309,11 +1341,13 @@ class Distillation(BaseModel): ) attn_module_name: Optional[str] = Field( - None, description="Attention nnx module attribute name to augment with LTI logic" + None, + description="Attention nnx module attribute name to augment with LTI logic", ) lti_layer_indices: Optional[list[int]] = Field( - None, description="List of layer indices to apply LTI modifications. If None, applied to all layers." + None, + description="List of layer indices to apply LTI modifications. If None, applied to all layers.", ) # --------------------------------------- @@ -1650,7 +1684,8 @@ class Profiling(BaseModel): tpu_num_chips_to_profile_per_task: int = Field(1, description="Specifies the number of TPU chips to profile per task.") tpu_num_sparse_cores_to_trace: int = Field(2, description="Specifies the number of TPU chips to profile per task.") tpu_num_sparse_core_tiles_to_trace: int = Field( - 1, description="Specifies the number of tiles within each sparse core to trace on the TPU." 
+ 1, + description="Specifies the number of tiles within each sparse core to trace on the TPU.", ) xprof_tpu_power_trace_level: XProfTPUPowerTraceMode = Field( XProfTPUPowerTraceMode.POWER_TRACE_NONE, @@ -2298,6 +2333,10 @@ class MaxTextConfig( """ debug: Debug = Field(default_factory=Debug, description="Configuration for debugging options.") + dpo: DPO = Field( + default_factory=DPO, + description="Configuration for DPO and ORPO alignment algorithms.", + ) rl: RL = Field( default_factory=RL, description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO).", @@ -2486,7 +2525,11 @@ def validate_and_set_hlo_dump_defaults(): ) for param_name, schedule, end_value in [ ("distill_alpha", self.distill_alpha_schedule, self.distill_alpha_end), - ("distill_temperature", self.distill_temperature_schedule, self.distill_temperature_end), + ( + "distill_temperature", + self.distill_temperature_schedule, + self.distill_temperature_end, + ), ("distill_beta", self.distill_beta_schedule, self.distill_beta_end), ]: if schedule != "constant" and end_value is None: @@ -2883,6 +2926,16 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de raise ValueError("For multimodal SFT, `sft_train_on_completion_only` must be True.") if self.packing: raise ValueError("For multimodal SFT, `packing` is not yet supported.") + if self.use_dpo: + if self.packing: + raise ValueError("For DPO/ORPO, `packing` is not supported.") + if self.dpo.max_prompt_length is not None and self.dpo.max_prompt_length > self.max_target_length: + raise ValueError( + f"`max_prompt_length` ({self.dpo.max_prompt_length}) cannot be " + f"greater than `max_target_length` ({self.max_target_length})." + ) + if self.use_sft and self.use_dpo: + raise ValueError("Only one of `use_sft` or `use_dpo` can be True.") if self.shard_mode == ShardMode.EXPLICIT: supported_decoders = {"simple", "simple_mlp", "llama2", "deepseek"} if self.decoder_block.value not in supported_decoders: diff --git a/src/maxtext/input_pipeline/dpo_utils.py b/src/maxtext/input_pipeline/dpo_utils.py new file mode 100644 index 0000000000..3f12be610a --- /dev/null +++ b/src/maxtext/input_pipeline/dpo_utils.py @@ -0,0 +1,101 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DPO specific input pipeline utilities.""" + +import dataclasses +import grain.python as grain +import numpy as np + + +@dataclasses.dataclass +class DPOTunixPrep(grain.MapTransform): + """Prepares DPO data for Tunix. + Renames input columns, extracts common prefix if needed, generates masks, and performs + DPO-aware padding (left-padded prompts, right-padded responses). + """ + + pad_id: int + max_target_length: int + data_column_names: tuple[str, ...] + max_prompt_length: int | None = None + + def map(self, element): + "Apply the dataset transformations for Tunix-based DPO." + # 1. 
Reformat/Extract Columns + try: + if len(self.data_column_names) == 3: + input_ids = element[self.data_column_names[0]] + chosen_ids = element[self.data_column_names[1]] + rejected_ids = element[self.data_column_names[2]] + elif len(self.data_column_names) == 2: + # Support for datasets like Anthropic/hh-rlhf where prompt is a common prefix + full_chosen = element[self.data_column_names[0]] + full_rejected = element[self.data_column_names[1]] + + # Find common prefix length + prefix_len = 0 + for c, r in zip(full_chosen, full_rejected): + if c != r: + break + prefix_len += 1 + input_ids = full_chosen[:prefix_len] + chosen_ids = full_chosen[prefix_len:] + rejected_ids = full_rejected[prefix_len:] + else: + raise ValueError(f"DPOTunixPrep expects 2 or 3 columns, got {len(self.data_column_names)}") + except KeyError as e: + raise KeyError( + f"Column '{e.args[0]}' not found in the dataset. " + f"Expected columns: {self.data_column_names}. " + f"Available columns: {list(element.keys())}. " + "Please verify that 'train_data_columns' and 'eval_data_columns' match your dataset." + ) from e + + # 2. Padding and Masking + max_prompt_length = self.max_prompt_length or (self.max_target_length // 2) + max_response_length = self.max_target_length - max_prompt_length + + prompt_ids = self._pad(input_ids, max_prompt_length, left=True) + chosen_ids = self._pad(chosen_ids, max_response_length, left=False) + rejected_ids = self._pad(rejected_ids, max_response_length, left=False) + + # Remove old columns if they exist + for key in self.data_column_names: + if key in element: + del element[key] + + element["prompt_ids"] = prompt_ids + element["chosen_ids"] = chosen_ids + element["rejected_ids"] = rejected_ids + element["prompt_mask"] = (prompt_ids != self.pad_id).astype(np.int32) + element["chosen_mask"] = (chosen_ids != self.pad_id).astype(np.int32) + element["rejected_mask"] = (rejected_ids != self.pad_id).astype(np.int32) + return element + + def _pad(self, x, length, left=False): + """Pads or trims an array to a specific length. + + When left=True (for prompts), trims from the left to keep the suffix (closest context). + When left=False (for responses), trims from the right to keep the prefix. + """ + x = np.asarray(x) + pad_amount = max(length - x.shape[0], 0) + if left: + pad_width = ((pad_amount, 0),) + x_trimmed = x[-length:] + else: + pad_width = ((0, pad_amount),) + x_trimmed = x[:length] + return np.pad(x_trimmed, pad_width, constant_values=self.pad_id).astype(np.int32) diff --git a/src/maxtext/input_pipeline/hf_data_processing.py b/src/maxtext/input_pipeline/hf_data_processing.py index 489e55c77f..057bc2c59a 100644 --- a/src/maxtext/input_pipeline/hf_data_processing.py +++ b/src/maxtext/input_pipeline/hf_data_processing.py @@ -1,4 +1,4 @@ -# Copyright 2023–2025 Google LLC +# Copyright 2023–2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -24,9 +24,8 @@ import grain.python as grain -import numpy as np - from maxtext.input_pipeline import data_processing_utils +from maxtext.input_pipeline import dpo_utils from maxtext.input_pipeline import input_pipeline_utils from maxtext.input_pipeline import instruction_data_processing from maxtext.input_pipeline import multihost_dataloading @@ -214,7 +213,7 @@ def preprocessing_pipeline( num_threads=1, drop_remainder=True, generate_padding_batch=False, - use_dpo=None, + use_dpo=False, use_sft=None, use_tunix_gradient_accumulation=False, num_microbatches=1, @@ -330,19 +329,12 @@ def preprocessing_pipeline( ) ) data_column_names = ("inputs", "targets") - elif use_dpo: - - def lists2array(x): - """Convert lists/tuples to array""" - return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple))) - - operations.append(grain.MapOperation(lists2array)) - else: + elif not use_dpo: assert len(data_column_names) == 1 operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0])) data_column_names = ("inputs", "targets") - if packing and not use_dpo: + if packing: length_struct = {col: max_target_length for col in data_column_names} max_segments = max_segments_per_seq if max_segments is not None and max_segments <= 0: @@ -356,7 +348,16 @@ def lists2array(x): ) operations.append(input_pipeline_utils.ReformatPacking(data_column_names)) else: - operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) + if use_dpo: + # Renames arbitrary DPO columns and performs DPO-aware padding. + max_prompt_length = ( + config.dpo.get("max_prompt_length") + if isinstance(config.dpo, dict) + else getattr(config.dpo, "max_prompt_length", None) + ) + operations.append(dpo_utils.DPOTunixPrep(pad_id, max_target_length, data_column_names, max_prompt_length)) + else: + operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)) if shift and not use_dpo: diff --git a/tests/post_training/unit/dpo_data_processing_test.py b/tests/post_training/unit/dpo_data_processing_test.py new file mode 100644 index 0000000000..3b0a00814a --- /dev/null +++ b/tests/post_training/unit/dpo_data_processing_test.py @@ -0,0 +1,370 @@ +# Copyright 2025–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Unit tests for DPO data preparation."""
+import os
+import unittest
+from datasets import Dataset
+import jax
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh
+import numpy as np
+import pytest
+import transformers
+
+from maxtext.configs import pyconfig
+from maxtext.input_pipeline import dpo_utils
+from maxtext.input_pipeline import hf_data_processing
+from maxtext.input_pipeline import input_pipeline_interface
+from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR, MAXTEXT_PKG_DIR
+
+pytestmark = [pytest.mark.post_training, pytest.mark.cpu_only]
+
+
+class TestDPOTunixPrep(unittest.TestCase):
+  """Tests for DPOTunixPrep transform."""
+
+  def setUp(self):
+    self.pad_id = 0
+
+  def test_column_remapping(self):
+    """Verify that columns are renamed to match Tunix expectations."""
+    prep = dpo_utils.DPOTunixPrep(
+        pad_id=self.pad_id,
+        max_target_length=21,
+        data_column_names=("input", "chosen", "rejected"),
+    )
+    sample = {
+        "input": np.array([1, 2, 3]),
+        "chosen": np.array([4, 5]),
+        "rejected": np.array([6, 7, 8]),
+    }
+    output = prep.map(sample)
+
+    # Check that old keys are removed
+    self.assertNotIn("input", output)
+    self.assertNotIn("chosen", output)
+    self.assertNotIn("rejected", output)
+
+    # Check that new keys exist
+    self.assertIn("prompt_ids", output)
+    self.assertIn("chosen_ids", output)
+    self.assertIn("rejected_ids", output)
+    self.assertIn("prompt_mask", output)
+    self.assertIn("chosen_mask", output)
+    self.assertIn("rejected_mask", output)
+
+  def test_two_column_prefix_extraction(self):
+    """Verify common prefix extraction for 2-column datasets."""
+    # Arbitrary column names are remapped onto "chosen" and "rejected".
+    prep = dpo_utils.DPOTunixPrep(
+        pad_id=self.pad_id,
+        max_target_length=21,
+        data_column_names=("liked", "disliked"),
+    )
+    sample = {
+        "liked": np.array([1, 2, 3, 10, 11]),
+        "disliked": np.array([1, 2, 3, 20, 21, 22]),
+    }
+    output = prep.map(sample)
+
+    # Prefix is [1, 2, 3], left-padded to max_prompt_length = 21 // 2 = 10.
+    self.assertEqual(output["prompt_ids"].shape[0], 10)
+    np.testing.assert_array_equal(output["prompt_ids"], [self.pad_id] * 7 + [1, 2, 3])
+    # Prompt mask for [0] * 7 + [1, 2, 3] should be [0] * 7 + [1, 1, 1]
+    np.testing.assert_array_equal(output["prompt_mask"], [0] * 7 + [1, 1, 1])
+
+    # chosen_ids (len 11) should be right-padded
+    self.assertEqual(output["chosen_ids"].shape[0], 11)
+    np.testing.assert_array_equal(output["chosen_ids"], [10, 11] + [self.pad_id] * 9)
+    # chosen_mask for [10, 11, 0, 0, ...] should be [1, 1, 0, 0, ...]
+    np.testing.assert_array_equal(output["chosen_mask"], [1, 1] + [0] * 9)
+
+    # rejected_ids (len 11) should be right-padded
+    self.assertEqual(output["rejected_ids"].shape[0], 11)
+    np.testing.assert_array_equal(output["rejected_ids"], [20, 21, 22] + [self.pad_id] * 8)
+    # rejected_mask for [20, 21, 22, 0, 0, ...] should be [1, 1, 1, 0, 0, ...]
+ np.testing.assert_array_equal(output["rejected_mask"], [1, 1, 1] + [0] * 8) + + def test_three_column_remapping(self): + """Verify standard 3-column remapping.""" + prep = dpo_utils.DPOTunixPrep( + pad_id=self.pad_id, + max_target_length=20, + data_column_names=("input", "chosen", "rejected"), + ) + sample = { + "input": np.array([1, 2]), + "chosen": np.array([3, 4]), + "rejected": np.array([5, 6, 7]), + } + output = prep.map(sample) + + self.assertNotIn("input", output) + # Prompt should be left-padded + np.testing.assert_array_equal(output["prompt_ids"], [self.pad_id] * 8 + [1, 2]) + np.testing.assert_array_equal(output["prompt_mask"], [0] * 8 + [1, 1]) + + # Chosen and rejected are right-padded + np.testing.assert_array_equal(output["chosen_ids"], [3, 4] + [self.pad_id] * 8) + np.testing.assert_array_equal(output["chosen_mask"], [1, 1] + [0] * 8) + np.testing.assert_array_equal(output["rejected_ids"], [5, 6, 7] + [self.pad_id] * 7) + np.testing.assert_array_equal(output["rejected_mask"], [1, 1, 1] + [0] * 7) + + def test_three_column_truncation(self): + """Verify that prompts are suffix-truncated and responses are prefix-truncated.""" + prep = dpo_utils.DPOTunixPrep( + pad_id=self.pad_id, + max_target_length=9, + data_column_names=("input", "chosen", "rejected"), + ) + sample = { + "input": np.arange(1, 10), + "chosen": np.arange(10, 20), + "rejected": np.arange(20, 30), + } + output = prep.map(sample) + + # Prompt is suffix-truncated to 4 chars (keeps the end). + np.testing.assert_array_equal(output["prompt_ids"], np.arange(6, 10)) + np.testing.assert_array_equal(output["prompt_mask"], [1] * 4) + + # Chosen and rejected are prefix-truncated to 5 chars (keeps the start). + np.testing.assert_array_equal(output["chosen_ids"], np.arange(10, 15)) + np.testing.assert_array_equal(output["chosen_mask"], [1] * 5) + np.testing.assert_array_equal(output["rejected_ids"], np.arange(20, 25)) + np.testing.assert_array_equal(output["rejected_mask"], [1] * 5) + + def test_two_column_prefix_edge_cases(self): + """Verify prefix extraction robustness with identical strings or prefix strings.""" + prep = dpo_utils.DPOTunixPrep( + pad_id=self.pad_id, + max_target_length=20, + data_column_names=("chosen", "rejected"), + ) + + # Case 1: Identical strings + identical_sample = { + "chosen": np.array([1, 2, 3]), + "rejected": np.array([1, 2, 3]), + } + out_identical = prep.map(identical_sample) + # Entire string is prefix; suffixes are empty (padded with pad_id) + np.testing.assert_array_equal(out_identical["prompt_ids"][-3:], [1, 2, 3]) + np.testing.assert_array_equal(out_identical["chosen_ids"], [self.pad_id] * 10) + np.testing.assert_array_equal(out_identical["rejected_ids"], [self.pad_id] * 10) + + # Case 2: One is a prefix of another + prefix_sample = { + "chosen": np.array([1, 2, 3]), + "rejected": np.array([1, 2, 3, 4, 5]), + } + out_prefix = prep.map(prefix_sample) + np.testing.assert_array_equal(out_prefix["prompt_ids"][-3:], [1, 2, 3]) + np.testing.assert_array_equal(out_prefix["chosen_ids"], [self.pad_id] * 10) + np.testing.assert_array_equal(out_prefix["rejected_ids"][:2], [4, 5]) + + def test_max_prompt_length_override(self): + """Verify that max_prompt_length can be customized to adjust prompt/response split ratio.""" + prep = dpo_utils.DPOTunixPrep( + pad_id=self.pad_id, + max_target_length=20, + data_column_names=("input", "chosen", "rejected"), + max_prompt_length=15, + ) + sample = { + "input": np.array([1, 2]), + "chosen": np.array([3, 4]), + "rejected": np.array([5, 6]), + } + output = 
prep.map(sample) + + # Prompt length should be 15, response length should be 5 + self.assertEqual(output["prompt_ids"].shape[0], 15) + self.assertEqual(output["chosen_ids"].shape[0], 5) + self.assertEqual(output["rejected_ids"].shape[0], 5) + + def test_missing_column_error(self): + """Verify that a helpful error is raised when a column is missing.""" + prep = dpo_utils.DPOTunixPrep( + pad_id=self.pad_id, + max_target_length=20, + data_column_names=("input", "chosen", "rejected"), + ) + # 'rejected' column is missing + sample = {"input": np.array([1, 2]), "chosen": np.array([3, 4])} + with self.assertRaisesRegex(KeyError, "Column 'rejected' not found in the dataset"): + prep.map(sample) + + +@pytest.mark.external_training +class TestDPOPipelineIntegration(unittest.TestCase): + """End-to-end DPO pipeline integration tests.""" + + def setUp(self): + super().setUp() + self.config = pyconfig.initialize( + [ + os.path.join(MAXTEXT_PKG_DIR, "dpo_trainer"), + os.path.join(MAXTEXT_CONFIGS_DIR, "post_train", "dpo.yml"), + ], + per_device_batch_size=2, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + tokenizer_path="Qwen/Qwen3-4B", + train_split="train", + enable_checkpointing=False, + use_dpo=True, + enable_data_shuffling=False, + max_target_length=64, + ) + self.mesh_shape_1d = (len(jax.devices()),) + self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.process_indices = input_pipeline_interface.get_process_loading_real_data( + self.config.data_sharding, + self.config.global_batch_size_to_load, + self.config.global_batch_size_to_train_on, + self.config.max_target_length, + self.mesh, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + self.config.tokenizer_path, + add_bos_token=False, + add_eos_token=False, + legacy=False, + ) + self.pad_id = hf_data_processing._get_pad_id(self.tokenizer) # pylint: disable=protected-access + + def get_data_iterator(self, dataset, data_columns): + """Helper to initialize the preprocessing pipeline.""" + return hf_data_processing.preprocessing_pipeline( + dataloading_host_index=self.process_indices.index(jax.process_index()), + dataloading_host_count=len(self.process_indices), + global_mesh=self.mesh, + dataset=dataset, + config=self.config, + data_column_names=data_columns, + tokenize=self.config.tokenize_train_data, + tokenizer_path=self.config.tokenizer_path, + hf_access_token=self.config.hf_access_token, + global_batch_size=self.config.global_batch_size_to_load, + max_target_length=self.config.max_target_length, + shuffle=self.config.enable_data_shuffling, + data_shuffle_seed=self.config.data_shuffle_seed, + add_bos=self.config.add_bos, + add_eos=self.config.add_eos, + packing=self.config.packing, + generate_padding_batch=False, + use_dpo=self.config.use_dpo, + use_sft=self.config.use_sft, + sft_train_on_completion_only=self.config.sft_train_on_completion_only, + grain_worker_count=0, + ) + + def test_dpo_format_3_columns(self): + """Verify that the 3-column explicit DPO dataset is processed correctly.""" + prompt_str = "Question: What is 2+2?" 
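+    # Short ASCII strings: the test decodes the processed token ids below and
+    # expects exactly these strings back.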
+ chosen_str = "Answer: 4" + rejected_str = "Answer: 5" + + dataset = Dataset.from_dict( + { + "input": [prompt_str] * 10, + "chosen": [chosen_str] * 10, + "rejected": [rejected_str] * 10, + } + ) + data_iter = self.get_data_iterator(dataset, ["input", "chosen", "rejected"]) + batch = next(data_iter) + + # Verify expected keys + for key in ( + "prompt_ids", + "chosen_ids", + "rejected_ids", + "prompt_mask", + "chosen_mask", + "rejected_mask", + ): + self.assertIn(key, batch) + + # Verify batch dimensions match global batch size and split max_target_length + max_prompt_len = self.config.max_target_length // 2 + max_response_len = self.config.max_target_length - max_prompt_len + self.assertEqual( + batch["prompt_ids"].shape, + (self.config.global_batch_size_to_load, max_prompt_len), + ) + self.assertEqual( + batch["chosen_ids"].shape, + (self.config.global_batch_size_to_load, max_response_len), + ) + self.assertEqual( + batch["rejected_ids"].shape, + (self.config.global_batch_size_to_load, max_response_len), + ) + + # Verify decoded content directly + decoded_prompt = self.tokenizer.decode(batch["prompt_ids"][0], skip_special_tokens=True) + decoded_chosen = self.tokenizer.decode(batch["chosen_ids"][0], skip_special_tokens=True) + decoded_rejected = self.tokenizer.decode(batch["rejected_ids"][0], skip_special_tokens=True) + + self.assertEqual(decoded_prompt, prompt_str) + self.assertEqual(decoded_chosen, chosen_str) + self.assertEqual(decoded_rejected, rejected_str) + + # Verify mask structure (left padding for prompt -> 1s at the end; right padding for responses -> 1s at start) + self.assertEqual(batch["prompt_mask"][0][-1], 1) + self.assertEqual(batch["chosen_mask"][0][0], 1) + self.assertEqual(batch["rejected_mask"][0][0], 1) + + def test_dpo_format_2_columns(self): + """Verify that 2-column DPO datasets correctly extract common prefixes.""" + # We use a clear common prefix and different suffixes + prefix = "Common prompt context for DPO:" + chosen_suffix = " the chosen completion" + rejected_suffix = " the rejected completion" + + dataset = Dataset.from_dict( + { + "chosen": [prefix + chosen_suffix] * 10, + "rejected": [prefix + rejected_suffix] * 10, + } + ) + data_iter = self.get_data_iterator(dataset, ["chosen", "rejected"]) + batch = next(data_iter) + + # Verify decoded extracted prefix and completions robustly against BPE token boundary quirks + decoded_prompt = self.tokenizer.decode(batch["prompt_ids"][0], skip_special_tokens=True) + decoded_chosen = self.tokenizer.decode(batch["chosen_ids"][0], skip_special_tokens=True) + decoded_rejected = self.tokenizer.decode(batch["rejected_ids"][0], skip_special_tokens=True) + + self.assertIn("Common prompt context", decoded_prompt) + self.assertIn("chosen", decoded_chosen) + self.assertIn("rejected", decoded_rejected) + + def test_dpo_invalid_column_count(self): + """Verify that passing an unsupported number of columns raises an error.""" + dataset = Dataset.from_dict({"col1": ["a"] * 10}) + with self.assertRaises((ValueError, KeyError)): + # DPOTunixPrep expects 2 or 3 columns + data_iter = self.get_data_iterator(dataset, ["col1"]) + next(data_iter) + + +if __name__ == "__main__": + unittest.main() From c123a44b185b9fa3604e182d9dca3b3888a72e93 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Thu, 14 May 2026 15:20:41 -0700 Subject: [PATCH 2/3] Feature: integrate post-training Direct Preference Optimization (DPO) trainer and execution utilities --- docs/tutorials/post_training_index.md | 6 +- docs/tutorials/posttraining/dpo.md | 110 ++++++ 
 src/maxtext/common/metric_logger.py          |  21 +-
 src/maxtext/configs/pyconfig.py              |   6 +-
 .../trainers/post_train/dpo/train_dpo.py     | 188 +++++++++++
 src/maxtext/utils/maxtext_utils.py           | 316 ++++--------------
 tests/end_to_end/tpu/test_dpo.sh             |  30 +-
 7 files changed, 378 insertions(+), 299 deletions(-)
 create mode 100644 docs/tutorials/posttraining/dpo.md
 create mode 100644 src/maxtext/trainers/post_train/dpo/train_dpo.py

diff --git a/docs/tutorials/post_training_index.md b/docs/tutorials/post_training_index.md
index d277cfb4be..e637ac0c14 100644
--- a/docs/tutorials/post_training_index.md
+++ b/docs/tutorials/post_training_index.md
@@ -26,8 +26,8 @@ MaxText was co-designed with key Google led innovations to provide a unified pos
 - **SFT (Supervised Fine-Tuning)**
   - [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html)
   - [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft_on_multi_host.html)
-- **LoRA (Low-Rank Adaptation)**
-  - [LoRA on Single-Host TPUs](posttraining/lora.md)
+- **DPO (Direct Preference Optimization) and ORPO (Odds Ratio Preference Optimization)**
+  - [DPO/ORPO on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/dpo.html)
 - **Multimodal SFT**
   - [Multimodal Support](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/multimodal.html)
 - **Reinforcement Learning (RL)**
@@ -67,10 +67,10 @@ maxdepth: 1
 ---
 posttraining/sft.md
 posttraining/sft_on_multi_host.md
+posttraining/dpo.md
 posttraining/rl.md
 posttraining/rl_on_multi_host.md
 posttraining/knowledge_distillation.md
-posttraining/lora.md
 posttraining/multimodal.md
 posttraining/full_finetuning.md
 posttraining/gepa_optimization.md
diff --git a/docs/tutorials/posttraining/dpo.md b/docs/tutorials/posttraining/dpo.md
new file mode 100644
index 0000000000..582438abed
--- /dev/null
+++ b/docs/tutorials/posttraining/dpo.md
@@ -0,0 +1,110 @@
+
+# Preference Optimization (DPO & ORPO) on Single-Host TPUs
+
+MaxText supports two primary methods for aligning models with human preferences: **Direct Preference Optimization (DPO)** and **Odds Ratio Preference Optimization (ORPO)**. Both methods avoid the complexity of traditional Reinforcement Learning from Human Feedback (RLHF) by optimizing directly on preference data.
+
+## DPO vs. ORPO
+
+- **Direct Preference Optimization (DPO):** Optimizes the policy by maximizing the relative log-probability of preferred responses over rejected ones. DPO requires a **reference model** (a frozen copy of the base model) to regularize training and keep the policy from drifting too far from the original model's distribution.
+- **Odds Ratio Preference Optimization (ORPO):** A newer, reference-free alignment method that integrates the preference loss directly into the supervised fine-tuning objective using an odds ratio. Because it **does not require a reference model**, ORPO is more memory-efficient and faster than DPO.
+
+## Data Requirements
+
+Both methods consume preference data in a **triplet format** consisting of a Prompt, a Chosen response, and a Rejected response. MaxText supports two ways to provide this data via the `train_data_columns` configuration:
+
+1. **Explicit Triplets (3 Columns):** The dataset provides three distinct columns for the prompt, chosen response, and rejected response.
+2. **Shared Prefix (2 Columns):** For datasets like `Anthropic/hh-rlhf`, where the prompt is embedded at the beginning of the responses, you can provide just two columns (e.g., `chosen` and `rejected`). MaxText will automatically extract the shared common prefix as the **Prompt** and treat the differing suffixes as the responses; see the sketch after this list.
+
+During the input pipeline, prompts are left-padded (and truncated from the left) so that the prompt tokens closest to the response are preserved, while responses are right-padded (and truncated from the right).
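+
+As an illustration (an informal sketch mirroring this patch's unit tests, not a public API), the 2-column split operates on token ids roughly like this:
+
+```python
+chosen = [1, 2, 3, 10, 11]        # tokenized "prompt + chosen response"
+rejected = [1, 2, 3, 20, 21, 22]  # tokenized "prompt + rejected response"
+
+# The pipeline walks both sequences until they first diverge:
+#   prompt   -> [1, 2, 3]       (shared prefix, left-padded to max_prompt_length)
+#   chosen   -> [10, 11]        (suffix, right-padded to the response budget)
+#   rejected -> [20, 21, 22]
+```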
+
+## Prerequisites
+
+For instructions on installing MaxText with post-training dependencies on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) and use the `maxtext[tpu-post-train]` installation path.
+
+## Local run on a single-host TPU VM
+
+### Set up environment variables
+
+Log in to Hugging Face:
+
+```bash
+hf auth login
+```
+
+Set up your training environment:
+
+```bash
+# -- Model configuration --
+# The MaxText model name. See `ModelName` in `src/maxtext/configs/types.py` for a
+# full list of supported models.
+export MODEL= # e.g., "qwen3-0.6b"
+
+# -- MaxText configuration --
+# Use a GCS bucket you own to store logs and checkpoints, ideally in the same
+# region as your TPUs to minimize latency and cost.
+# You can list your buckets and their locations in the
+# [Cloud Console](https://console.cloud.google.com/storage/browser).
+export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/maxtext-runs
+
+# An arbitrary string to identify this specific run.
+# We recommend including the model name, your username, and a timestamp.
+# Note: Kubernetes requires workload names to be valid DNS labels (lowercase, no underscores or periods).
+export RUN_NAME=
+
+export STEPS= # e.g., 1000
+export PER_DEVICE_BATCH_SIZE= # e.g., 1
+
+export ALGORITHM= # Set to either "dpo" or "orpo"
+
+# -- Dataset configuration --
+export DATASET_NAME= # e.g., "argilla/distilabel-intel-orca-dpo-pairs"
+export TRAIN_SPLIT= # e.g., train
+
+# Map your dataset columns to [Prompt, Chosen, Rejected].
+# For 3-column datasets:
+export TRAIN_DATA_COLUMNS="['input', 'chosen', 'rejected']"
+
+# For 2-column datasets (prefix extraction):
+# export TRAIN_DATA_COLUMNS="['chosen', 'rejected']"
+```
+
+## Running DPO Training
+
+You can run DPO or ORPO training with the specialized post-training script:
+
+```{note}
+The script below uses `eval_interval=0` because the default "argilla/distilabel-intel-orca-dpo-pairs" dataset only has a "train" split.
+To evaluate on the same split, set a non-zero `eval_interval` and add `hf_eval_split=train`.
+```
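+
+The preference-specific defaults come from the `dpo` block in `src/maxtext/configs/post_train/dpo.yml` (reproduced here from this patch). Any of them can be overridden on the command line with dotted keys, as the launch command below does with `dpo.algo`:
+
+```yaml
+dpo:
+  algo: 'dpo'              # 'dpo' or 'orpo'
+  orpo_lambda: 0.1         # weight of the preference term in the ORPO loss
+  dpo_label_smoothing: 0.0
+  dpo_beta: 0.1
+  max_prompt_length: null  # null -> half of max_target_length
+```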
+
+```bash
+python3 -m maxtext.trainers.post_train.dpo.train_dpo \
+    run_name=${RUN_NAME?} \
+    base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
+    model_name=${MODEL?} \
+    dataset_type=hf \
+    hf_path=${DATASET_NAME?} \
+    train_split=${TRAIN_SPLIT?} \
+    train_data_columns="${TRAIN_DATA_COLUMNS?}" \
+    steps=${STEPS?} \
+    eval_interval=0 \
+    per_device_batch_size=${PER_DEVICE_BATCH_SIZE?} \
+    max_target_length=1024 \
+    use_dpo=1 \
+    dpo.algo=${ALGORITHM?}
+```
diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py
index 114c0e1519..f1b83f5349 100644
--- a/src/maxtext/common/metric_logger.py
+++ b/src/maxtext/common/metric_logger.py
@@ -180,9 +180,10 @@ def _log_training_metrics(self, metrics, step):
            f"perplexity: {perplexity:.3f}",
        ]
    )
-    if self.config.use_dpo:
-      dpo_loss = scalars.get("learning/dpo_loss", 0.0)
-      log_parts.append(f"dpo_loss: {dpo_loss:.3f}")
+    if "learning/dpo_loss" in scalars:
+      log_parts.append(f"dpo_loss: {scalars['learning/dpo_loss']:.3f}")
+    if "learning/reward_accuracy" in scalars:
+      log_parts.append(f"reward_accuracy: {scalars['learning/reward_accuracy']:.3f}")

     if self.config.num_experts > 1:
       moe_lb_loss = scalars.get("learning/moe_lb_loss", 0.0)
@@ -295,12 +296,12 @@ def write_metrics_to_managed_mldiagnostics(self, metrics, step):

   def write_setup_info_to_tensorboard(self, params):
     """Writes setup information like train config params, num model params, and XLA flags to TensorBoard."""
+    if not self.config.enable_tensorboard:
+      return
     num_model_parameters = max_utils.calculate_num_params_from_pytree(params)
     self.metadata[MetadataKey.PER_DEVICE_TFLOPS], _, _ = maxtext_utils.calculate_tflops_training_per_device(self.config)
     self.metadata[MetadataKey.PER_DEVICE_TOKENS] = maxtext_utils.calculate_tokens_training_per_device(self.config)
     max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
-    if not self.config.enable_tensorboard:
-      return
     max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), self.writer)
     max_utils.add_text_to_summary_writer("libtpu_init_args", os.getenv("LIBTPU_INIT_ARGS", ""), self.writer)
     maxtext_utils.add_config_to_summary_writer(self.config, self.writer)
@@ -374,9 +375,9 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None):
           metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0)
       )
       self.cumulative_eval_metrics["scalar"]["eval/z_loss"] += float(metrics["scalar"].get("evaluation/z_loss", 0.0))
-      if self.config.use_dpo:
+      if "evaluation/dpo_reward_accuracy" in metrics["scalar"]:
         self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] += float(
-            metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0)
+            metrics["scalar"]["evaluation/dpo_reward_accuracy"]
         )

     if eval_step_count:
@@ -400,10 +401,8 @@
       self.cumulative_eval_metrics["scalar"]["eval/avg_z_loss"] = (
           self.cumulative_eval_metrics["scalar"]["eval/z_loss"] / eval_step_count
       )
-      if self.config.use_dpo:
-        self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = (
-            self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] / eval_step_count
-        )
+      if "eval/dpo_reward_accuracy" in self.cumulative_eval_metrics["scalar"]:
+        self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] /= eval_step_count

       self.write_metrics(self.cumulative_eval_metrics, step, is_training=False)
diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py
index c8da22b504..b4541596de 100644
---
a/src/maxtext/configs/pyconfig.py +++ b/src/maxtext/configs/pyconfig.py @@ -52,6 +52,7 @@ "maxtext.trainers.pre_train.train": "base.yml", "maxtext.trainers.pre_train.train_compile": "base.yml", "maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml", + "maxtext.trainers.post_train.dpo.train_dpo": "post_train/dpo.yml", "maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml", "maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml", "maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml", @@ -324,11 +325,6 @@ def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfi # 2. Get overrides from CLI and kwargs cli_cfg = omegaconf.OmegaConf.from_cli(cli_args) - if "hf_access_token" in cli_cfg: - logger.warning( - "WARNING: Passing 'hf_access_token' via command-line arguments is deprecated and insecure because it makes " - "your token visible in 'ps' and shell history. Please set the 'HF_TOKEN' environment variable instead." - ) kwargs_cfg = omegaconf.OmegaConf.create(kwargs) overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg) diff --git a/src/maxtext/trainers/post_train/dpo/train_dpo.py b/src/maxtext/trainers/post_train/dpo/train_dpo.py new file mode 100644 index 0000000000..adf76c861e --- /dev/null +++ b/src/maxtext/trainers/post_train/dpo/train_dpo.py @@ -0,0 +1,188 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DPO Training script that uses Tunix DPOTrainer on a MaxText model. 
+ +Example command: +Training & Evaluation: + python3 -m maxtext.trainers.post_train.dpo.train_dpo \ + run_name=${WORKLOAD?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ + tokenizer_path="google/gemma-2-2b-it" tokenizer_type=huggingface \ + dataset_type="hf" hf_path="Anthropic/hh-rlhf" hf_eval_split="test" \ + train_data_columns="['chosen', 'rejected']" eval_data_columns="['chosen', 'rejected']" \ + model_name=${MODEL?} load_parameters_path=${MAXTEXT_CONVERTED_CHECKPOINT?}/0/items \ + hf_access_token=${HF_TOKEN?} per_device_batch_size=1 max_target_length=1024 \ + eval_interval=2 eval_steps=2 steps=10 profiler=xplane weight_dtype=bfloat16 +""" + +from absl import app +import jax +import optax +from orbax import checkpoint as ocp +import pathwaysutils + +from flax import nnx +from flax.linen import partitioning as nn_partitioning + +from tunix.sft import metrics_logger, profiler +from tunix.sft.dpo.dpo_trainer import DPOTrainer, DPOTrainingConfig + +from maxtext.common.goodput import ( + GoodputEvent, + RECORD_JOB_END_TIME, + RECORD_JOB_START_TIME, + create_goodput_recorder, + maybe_monitor_goodput, + maybe_record_goodput, + record_goodput, +) +from maxtext.configs import pyconfig +from maxtext.optimizers import optimizers +from maxtext.trainers.post_train.dpo import hooks +from maxtext.utils import max_logging +from maxtext.utils import max_utils +from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils + + +def get_tunix_config(mt_config: pyconfig.HyperParameters) -> DPOTrainingConfig: + """Gets the Tunix training configurations from the MaxText config. + + Args: + mt_config: MaxText config. + + Returns: + A Tunix `DPOTrainingConfig` object. + """ + # Checkpointing configurations + checkpointing_options = ocp.CheckpointManagerOptions( + save_interval_steps=mt_config.checkpoint_period, + enable_async_checkpointing=mt_config.async_checkpointing, + ) + + # Metrics configurations + metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=mt_config.tensorboard_dir) + + # Profiler configurations + profiler_options = None + if mt_config.profiler: + set_profile_options = True + platform_version = jax.extend.backend.get_backend().platform_version.strip() + if platform_version.startswith("Pathways"): + max_logging.log("Pathways backend detected. 
Profile options will not be set.")
+      set_profile_options = False
+    profiler_options = profiler.ProfilerOptions(
+        log_dir=mt_config.tensorboard_dir,
+        skip_first_n_steps=mt_config.skip_first_n_steps_for_profiler,
+        profiler_steps=mt_config.profiler_steps,
+        set_profile_options=set_profile_options,
+    )
+
+  # Honor an explicit dpo.max_prompt_length; otherwise default to an even
+  # prompt/response split, matching the input pipeline's behavior.
+  max_prompt_length = mt_config.dpo.max_prompt_length or (mt_config.max_target_length // 2)
+  return DPOTrainingConfig(
+      eval_every_n_steps=mt_config.eval_interval,
+      max_steps=mt_config.steps,
+      gradient_accumulation_steps=mt_config.gradient_accumulation_steps,
+      checkpoint_root_directory=mt_config.checkpoint_dir,
+      checkpointing_options=checkpointing_options,
+      metrics_logging_options=metrics_logging_options,
+      profiler_options=profiler_options,
+      algorithm=mt_config.dpo.algo,
+      lambda_orpo=mt_config.dpo.orpo_lambda,
+      beta=mt_config.dpo.dpo_beta,
+      label_smoothing=mt_config.dpo.dpo_label_smoothing,
+      max_prompt_length=max_prompt_length,
+      max_response_length=mt_config.max_target_length - max_prompt_length,
+  )
+
+
+def setup_trainer_state(mt_config, goodput_recorder=None):
+  """Set up prerequisites for the training loop."""
+  tunix_config = get_tunix_config(mt_config)
+
+  with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
+    model, mesh = model_creation_utils.from_pretrained(mt_config, wrap_with_tunix_adapter=True)
+
+  learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config)
+  # Pass in the model (needed for the Muon optimizer).
+  optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)
+
+  if mt_config.gradient_clipping_threshold > 0:
+    optimizer = optax.chain(
+        optax.clip_by_global_norm(max_norm=mt_config.gradient_clipping_threshold),
+        optimizer,
+    )
+
+  # ORPO does not require a reference model.
+  ref_model = nnx.clone(model) if mt_config.dpo.algo == "dpo" else None
+
+  with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION):
+    training_hooks = hooks.DPOTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
+    data_hooks = hooks.DPODataHooks(mt_config, mesh, goodput_recorder)
+
+  # Provide rules context so logical axes (e.g. 'norm') are translated to mesh axes during maybe_restore
+  with nn_partitioning.axis_rules(mt_config.logical_axis_rules):
+    trainer = DPOTrainer(
+        model=model, ref_model=ref_model, optimizer=optimizer, training_config=tunix_config, tokenizer=None
+    )
+  trainer.with_training_hooks(training_hooks)
+  trainer.with_data_hooks(data_hooks)
+
+  return trainer, mesh
+
+
+def train_model(mt_config: pyconfig.HyperParameters, trainer, mesh):
+  """Runs the DPO training loop in Tunix."""
+  with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
+    trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
+  return trainer
+
+
+def train(mt_config, goodput_recorder=None):
+  """Main method for DPO training.
+
+  Args:
+    mt_config: MaxText config.
+    goodput_recorder: An optional GoodputRecorder to record performance metrics.
+  """
+  trainer, mesh = setup_trainer_state(mt_config, goodput_recorder)
+  _job_completed_gracefully = False
+  try:
+    trainer = train_model(mt_config, trainer, mesh)
+    _job_completed_gracefully = True
+  finally:
+    if _job_completed_gracefully:
+      record_goodput(goodput_recorder, RECORD_JOB_END_TIME)
+  return trainer, mesh
+
+
+def main(argv: list[str]) -> None:
+  """Main function to run DPO training.
+
+  Args:
+    argv: Command-line arguments.
+ """ + pathwaysutils.initialize() + + mt_config = pyconfig.initialize(argv) + max_utils.print_system_information() + + goodput_recorder = create_goodput_recorder(mt_config) + record_goodput(goodput_recorder, RECORD_JOB_START_TIME) + with maybe_monitor_goodput(mt_config): + train(mt_config, goodput_recorder) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 0f07f5c14d..7fbf17e8e7 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -16,23 +16,25 @@ """Utils that are only interesting to MaxText.""" import functools +import pickle import os from typing import Sequence -from flax import nnx, linen as nn -from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules +from flax import linen as nn from flax.linen import partitioning as nn_partitioning -from flax.training.train_state import TrainState +from flax.training import train_state import numpy as np -import jax -import jax.numpy as jnp -from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec from jax.experimental import mesh_utils from jax.experimental.serialize_executable import deserialize_and_load +from jax.sharding import AxisType, Mesh + +import jax +import jax.numpy as jnp import optax + import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager @@ -53,7 +55,6 @@ from maxtext.utils import max_utils from maxtext.utils import sharding from maxtext.utils import elastic_utils -from maxtext.utils import maxtext_utils_nnx OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" @@ -101,10 +102,7 @@ def get_functional_train_with_signature( """Get the shardings (both state and data) for `train_step`.""" functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) functional_train.__name__ = "train_step" - if config.pure_nnx: - in_shardings = (state_mesh_shardings, data_sharding) # State, batch - else: - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = (state_mesh_shardings, None) # State, metrics static_argnums = () # We partial out the static argnums of model and config donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. 
@@ -115,10 +113,7 @@ def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shar """Get the shardings (both state and data) for `eval_step`.""" functional_eval = functools.partial(eval_step, model, config) functional_eval.__name__ = "eval_step" - if config.pure_nnx: - in_shardings = (state_mesh_shardings, data_sharding) # State, batch (NNX: no rng) - else: - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = None # metrics static_argnums = () # We partial out the static argnums of model, config donate_argnums = () # state will be kept instead of being donated in eval_step @@ -215,7 +210,8 @@ def load_compiled(config, partial_train, state, execution_devices): # Parker is working on a serializing these def load_serialized_compiled(save_name): with open(save_name, "rb") as f: - return f.read() + serialized_compiled = pickle.load(f) + return serialized_compiled def get_train_input_output_trees(func, input_args, input_kwargs): _, in_tree_recreated = jax.tree_util.tree_flatten((input_args, input_kwargs)) @@ -1069,11 +1065,13 @@ def calculate_tflops_training_per_device(config, log=True): learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps attention_tflops = attention_tflops * config.gradient_accumulation_steps - # DPO includes one additional forward pass per gradient accumulation step + # DPO includes one additional forward pass with the reference model per gradient accumulation step + # ORPO does not need the extra forward pass. if config.use_dpo: + # Typically flops = ForwardFlops + BackwardFlops. Backward is ~2x of ForwardFlops. + # The reference model only runs forward. reference_model_tflops = learnable_weight_tflops / 3 # additional forward pass - reference_model_attention_tflops = attention_tflops / 3 - attention_tflops = attention_tflops + reference_model_attention_tflops + attention_tflops += attention_tflops / 3 else: reference_model_tflops = 0 @@ -1240,15 +1238,15 @@ def _apply_update(path, param): return state.replace(params=new_params) -def init_decode_state(apply_fn, params) -> TrainState: +def init_decode_state(apply_fn, params) -> train_state.TrainState: """Init train state with null opt state for decode.""" - state = TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore + state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore return state def init_training_state(apply_fn, params, tx): """Init train state with null opt state for decode.""" - state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx) + state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) return state @@ -1376,7 +1374,7 @@ def setup_initial_state( is_training: True to initialize training state, False for decode state Returns: - train_state: the initialized train state. 
For NNX, this is a TrainStateNNX instance + state: the initialized train state state_mesh_annotations: the mesh annotations for the train state """ @@ -1415,48 +1413,29 @@ def setup_initial_state( else: # The update of data_iterator state happens in place, no need to assign explicitly state = restored["items"] - - # For NNX, convert the pure dict to nnx.State using the abstract state as template - if config.pure_nnx: - nnx.replace_by_pure_dict(unboxed_abstract_state, state) - state = unboxed_abstract_state else: init_state_partial = init_state_fn init_state_partial.__name__ = "initialize_state" - if config.pure_nnx: - state = jax.jit( - lambda: nnx.state(init_state_partial()), # Get state only, mapping to out_sharding structure - in_shardings=None, - out_shardings=state_mesh_shardings, - )() - else: - # pylint: disable=not-callable - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )() - if raw_params: # If we loaded a partial state, we need to merge it. - if config.pure_nnx: - # raw_params should have the same sharding info as in the model - nnx.update(state.model, raw_params) - else: - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - if sparsity_enabled: - # Sparsity-init keeps freshly initialized params for any leaf still - # represented as an abstract ShapeDtypeStruct in raw_params (i.e. not - # actually restored), and uses the restored value otherwise. - def _merge_params(p_raw, p_init): - if isinstance(p_raw, jax.ShapeDtypeStruct): - return p_init - return p_raw - - merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params) - state = state.replace(params=merged_params) - else: - state = state.replace(params=raw_params) - if not config.pure_nnx: - state = max_utils.unbox_logicallypartioned(state) + # pylint: disable=not-callable + state = jax.jit( + init_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + if sparsity_enabled and raw_params: # If we loaded a partial state, we need to merge it. + + def _merge_params(p_raw, p_init): + if isinstance(p_raw, jax.ShapeDtypeStruct): + return p_init + return p_raw + + merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params) + state = state.replace(params=merged_params) + elif raw_params: + state = state.replace(params=raw_params) + + state = max_utils.unbox_logicallypartioned(state) return state, state_mesh_annotations, state_mesh_shardings, data_iterator @@ -1471,9 +1450,6 @@ def get_logical_annotations(config, mesh, init_state_fn): def get_abstract_state(config, mesh, init_state_fn, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" - if config.pure_nnx: - return get_abstract_state_nnx(config, mesh, init_state_fn, is_training) - init_state_partial = init_state_fn with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -1517,148 +1493,6 @@ def move(path, x): ) -def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State: - """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. - - Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also - inserts the partition_name axis at the correct scan_axis position for parameters created by - _create_scanned_layers. 
Without this, scanned parameters get a 2D partition spec applied to a - 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. - - Args: - abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). - mesh: JAX physical mesh. - - Returns: - Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding. - """ - - def _make_named_sharding(v): - val = v.get_value() - if not hasattr(val, "shape"): - # Non-tensor value (e.g., optax MaskedNode for non-trainable params). Preserve - # as-is so the treedef matches abs_var_state in the downstream jax.tree.map. - return v - metadata = v.get_metadata() - out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") - if not out_sharding: - pspec = PartitionSpec() - else: - # Insert the scan axis for parameters created by _create_scanned_layers. - # _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the - # axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these. - if nnx.PARTITION_NAME in metadata: - partition_name = metadata[nnx.PARTITION_NAME] - # Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits - # param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode - # scan_axis=0 for non-Param types. stacked_rest non-Param variables have - # param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct. - scan_axis = metadata.get("param_scan_axis", 0) - out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) - # Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames - # 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted - # the scan axis. Only insert if not already present. - if partition_name not in out_sharding: - out_sharding.insert(scan_axis, partition_name) - out_sharding = tuple(out_sharding) - # Convert logical axis names to physical mesh axes using current context rules. - context_rules = get_logical_axis_rules() - local_rules = metadata.get("sharding_rules", ()) - if context_rules or local_rules: - rules = composite_rules(context_rules, local_rules) - pspec = PartitionSpec(*from_sharding_rules(out_sharding, rules)) - else: - pspec = PartitionSpec(*out_sharding) - return v.replace(NamedSharding(mesh, pspec)) - - return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) - - -def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True): - """Calculates the abstract sharded state and memory placement for an NNX TrainState. - - This function performs an abstract trace of the NNX model and optimizer using - `nnx.get_abstract_model`. It resolves logical sharding annotations into physical - JAX shardings and applies memory placement optimizations such as optimizer - sharding and host memory offloading (pinning to CPU RAM). - - Args: - config: Configuration object containing sharding and offloading hyperparameters - (e.g., shard_optimizer_over_data, optimizer_memory_host_offload). - mesh: JAX physical mesh used to resolve logical axis names to physical devices. - nnx_init_trainstate_fn: A zero-argument factory function that produces a - TrainStateNNX instance during the abstract trace. - is_training: Boolean indicating if the state is for training. If True, - optimizer state is processed and memory offloading strategies are applied. 
- - Returns: - A tuple containing (abstract_sharded_state, None, state_mesh_shardings): - abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with - fully resolved physical sharding and memory_kind metadata. - state_mesh_annotations: An nnx.State tree consisting of the raw PartitionSpec - objects corresponding to each parameter/variable. - state_mesh_shardings: An nnx.State tree consisting of the raw JAX - Sharding objects corresponding to each parameter/variable. - """ - assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given." - - with nn_partitioning.axis_rules(config.logical_axis_rules): - # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply - # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers - # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally - # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers, - # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor. - # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls - # var.shape for every variable when a global mesh is active, but masked optimizer - # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode() - # which has no .shape and would raise AttributeError. We handle sharding - # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not - # needed here. - abs_model = nnx.eval_shape(nnx_init_trainstate_fn) - _, abs_var_state = nnx.split(abs_model) - named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) - abstract_state = jax.tree.map( - lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), - abs_var_state, - named_sharding_state, - ) - - state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) - - if is_training and config.shard_optimizer_over_data: - # Add data to sharding for optimizer state - optimizer_sharding = jax.tree_util.tree_map_with_path( - functools.partial(sharding.add_data_to_sharding, mesh), - abstract_state.optimizer, - state_mesh_shardings.optimizer, - ) - state_mesh_shardings.optimizer = optimizer_sharding - if is_training and config.optimizer_memory_host_offload: - optimizer_sharding = jax.tree_util.tree_map_with_path( - maxtext_utils_nnx.move_memory_to_host, - state_mesh_shardings.optimizer, - is_leaf=lambda x: isinstance(x, NamedSharding), - ) - state_mesh_shardings.optimizer = optimizer_sharding - if is_training and config.parameter_memory_host_offload: - assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading." - _, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...) 
- state_params = jax.tree_util.tree_map_with_path( - maxtext_utils_nnx.move_memory_to_host, - state_params, - is_leaf=lambda x: isinstance(x, NamedSharding), - ) - nnx.update(state_mesh_shardings, state_params) - - abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings) - state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) - return ( - abstract_sharded_state, - state_mesh_annotations, - state_mesh_shardings, - ) - - def get_prefill_kv_cache_annotations(model, config, rng, mesh, page_state: None | PageState = None): """Get a shaped abstraction of the state (including optimizer)""" @@ -1741,13 +1575,7 @@ def save_quantized_checkpoint_if_configured(config, params): def add_config_to_summary_writer(config, summary_writer): """Writes config params to tensorboard""" if jax.process_index() == 0: - if hasattr(config, "get_keys"): - config_dict = config.get_keys() - elif hasattr(config, "model_dump"): - config_dict = config.model_dump() - else: - config_dict = dict(config) - for key, value in config_dict.items(): + for key, value in config.get_keys().items(): max_utils.add_text_to_summary_writer(key, str(value), summary_writer) @@ -1910,41 +1738,26 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No """ Print state shardings comparing Logical Definition vs Physical Result. """ - if not isinstance(params, nnx.State): - if not hasattr(params, "params"): - params = {"params": params} - if not hasattr(params_sharding, "params"): - params_sharding = {"params": params_sharding} - if logical_annotations and not hasattr(logical_annotations, "params"): - logical_annotations = {"params": logical_annotations} + if not hasattr(params, "params"): + params = {"params": params} + if not hasattr(params_sharding, "params"): + params_sharding = {"params": params_sharding} + if logical_annotations and not hasattr(logical_annotations, "params"): + logical_annotations = {"params": logical_annotations} leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) + leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) - if logical_annotations is not None: - leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) - for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip( - leaves_params, leaves_sharding, leaves_logical - ): - path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) - shape = jax.typeof(leaf_val) - pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) - pspec_str = str(tuple(pspec)) - logical_str = str(leaf_logical_val) - - message = ( - f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" - ) - max_logging.info(message) - else: - for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding): - path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) - shape = jax.typeof(leaf_val) - pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) - pspec_str = str(tuple(pspec)) + for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) + logical_str = 
str(leaf_logical_val) - message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}" - max_logging.info(message) + message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" + max_logging.info(message) print(flush=True) @@ -1955,19 +1768,8 @@ def maybe_dump_jaxpr(config, p_train_step, train_step_inputs): return max_logging.log("Tracing train_step to jaxpr...") - # Trace the underlying un-jitted function via __wrapped__ to avoid heavy remote - # compilation/gRPC round-trips to the Pathways controller. - unwrapped_step = getattr(p_train_step, "__wrapped__", p_train_step) - - def to_abstract(x): - if hasattr(x, "shape") and hasattr(x, "dtype"): - return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype) - return x - - # Convert all input arguments recursively to purely local abstract ShapeDtypeStruct objects - # to completely bypass remote Array objects and proxy tracing overhead. - abstract_inputs = jax.tree.map(to_abstract, train_step_inputs) - p_train_jaxpr = jax.make_jaxpr(unwrapped_step)(*abstract_inputs) + # We use the p_train_step (the JIT-decorated function) + p_train_jaxpr = jax.make_jaxpr(p_train_step)(*train_step_inputs) local_filename = "train_step.jaxpr" local_path = os.path.join(config.dump_jaxpr_local_dir, local_filename) diff --git a/tests/end_to_end/tpu/test_dpo.sh b/tests/end_to_end/tpu/test_dpo.sh index afca64124a..2d036ec353 100644 --- a/tests/end_to_end/tpu/test_dpo.sh +++ b/tests/end_to_end/tpu/test_dpo.sh @@ -8,29 +8,13 @@ RUN_NAME=dpo_$(date +%Y-%m-%d-%H-%M-%S) export GEMMA_2B_CKPT_PATH=$(gcloud storage ls gs://maxtext-gemma/gemma2/2b | sort -r | head -1) LOGS="gs://maxtext-external/logs" -# tfds pipeline -python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma \ - run_name="$RUN_NAME-tfds" model_name=gemma2-2b base_output_directory=${LOGS} \ - load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ - per_device_batch_size=0.5 allow_split_physical_axes=True \ - ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1 - -# grain pipeline -mkdir -p /tmp/anthropic_rlhf || true -gcloud storage cp -r gs://maxtext-dataset/dpo/anthropic_rlhf/array_record /tmp/anthropic_rlhf -python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma \ - run_name="$RUN_NAME-grain" model_name=gemma2-2b base_output_directory=${LOGS} \ - load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ - dataset_type=grain grain_worker_count=16 \ - grain_train_files='/tmp/anthropic_rlhf/array_record/anthropic_rlhf_tfds-train.array_record*' \ - grain_eval_files='/tmp/anthropic_rlhf/array_record/anthropic_rlhf_tfds-test.array_record*' \ - per_device_batch_size=0.5 allow_split_physical_axes=True \ - ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1 - # hf pipeline -python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/dpo.yml tokenizer_path='google/gemma-2-2b-it' \ - run_name="$RUN_NAME-grain" model_name=gemma2-2b base_output_directory=${LOGS} \ +python3 -m maxtext.trainers.post_train.dpo.train_dpo 
"${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/post_train/dpo.yml tokenizer_path='google/gemma-2-2b-it' \ + run_name="$RUN_NAME-hf" model_name=gemma2-2b base_output_directory=${LOGS} \ load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ dataset_type=hf hf_access_token=$HF_TOKEN hf_path='Anthropic/hh-rlhf' \ - per_device_batch_size=0.5 allow_split_physical_axes=True ici_tensor_parallelism=2 \ - ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1 + train_data_columns="['chosen', 'rejected']" \ + per_device_batch_size=0.5 allow_split_physical_axes=True eval_interval=0 \ + ici_data_parallelism=1 ici_tensor_parallelism=1 ici_fsdp_parallelism=1 \ + use_grpo=False steps=2 + From 91bb06359200a4f0389d9a9a8e687fefca632931 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Thu, 14 May 2026 15:21:02 -0700 Subject: [PATCH 3/3] Feature: integrate standalone DPO and ORPO demonstration notebook with dynamic algorithm toggle and CI validation --- docs/guides/run_python_notebook.md | 4 + src/maxtext/examples/dpo_qwen3_demo.ipynb | 278 +++++++++++++++++++++ tests/post_training/unit/dpo_hooks_test.py | 2 +- 3 files changed, 283 insertions(+), 1 deletion(-) create mode 100644 src/maxtext/examples/dpo_qwen3_demo.ipynb diff --git a/docs/guides/run_python_notebook.md b/docs/guides/run_python_notebook.md index 6e6cc08091..454dd93600 100644 --- a/docs/guides/run_python_notebook.md +++ b/docs/guides/run_python_notebook.md @@ -186,6 +186,10 @@ jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root - **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B SFT training and evaluation on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k). This notebook is friendly for beginners and runs successfully on Google Colab's free-tier v5e-1 TPU runtime. - **`sft_llama3_demo_tpu.ipynb`** → Llama3.1-8B SFT training on [Hugging Face ultrachat_200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k). We recommend running this on a v5p-8 TPU VM using [Method 2](#method-2-visual-studio-code-with-tpu-recommended) or [Method 3](#method-3-local-jupyter-lab-with-tpu-recommended). +### Preference Optimization (DPO & ORPO) Training + +- **`dpo_qwen3_demo.ipynb`** → Direct Preference Optimization (DPO) and Odds Ratio Preference Optimization (ORPO) training on [Hugging Face argilla/distilabel-intel-orca-dpo-pairs dataset](https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs). Friendly for beginners and runs successfully on single-host TPU environments. Includes a dropdown parameter toggle for switching between algorithms. + ### Reinforcement Learning (GRPO/GSPO) Training - **`rl_llama3_demo.ipynb`** → GRPO/GSPO training on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k). We recommend running this on a v5p-8 TPU VM using [Method 2](#method-2-visual-studio-code-with-tpu-recommended) or [Method 3](#method-3-local-jupyter-lab-with-tpu-recommended). 
diff --git a/src/maxtext/examples/dpo_qwen3_demo.ipynb b/src/maxtext/examples/dpo_qwen3_demo.ipynb new file mode 100644 index 0000000000..d6e8bb4003 --- /dev/null +++ b/src/maxtext/examples/dpo_qwen3_demo.ipynb @@ -0,0 +1,278 @@ +{ + "cells": [ + { + "id": "9daae20d", + "cell_type": "markdown", + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/dpo_qwen3_demo.ipynb)\n", + "\n", + "# Qwen3 Preference Optimization (DPO/ORPO) Demo\n", + "\n", + "This notebook demonstrates how to perform Direct Preference Optimization (DPO) and Odds Ratio Preference Optimization (ORPO) on Qwen3-0.6B using the `argilla/distilabel-intel-orca-dpo-pairs` dataset with MaxText and Tunix integration." + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "04803667", + "cell_type": "markdown", + "source": [ + "## Overview\n", + "\n", + "Direct Preference Optimization (DPO) is a stable and efficient alternative to RLHF (Reinforcement Learning from Human Feedback) that optimizes language models directly on preference pairs without requiring a separate reward model. \n", + "\n", + "In this notebook, we will:\n", + "1. Set up the environment and install dependencies.\n", + "2. Load a pre-trained Qwen3-0.6B model configuration.\n", + "3. Fine-tune it using DPO on a preference dataset.\n", + "4. Save the resulting model checkpoints." + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "119988b1", + "cell_type": "markdown", + "source": [ + "## Prerequisites\n", + "\n", + "### Get Your Hugging Face Token\n", + "\n", + "To access model checkpoints from the Hugging Face Hub, you need to authenticate with a personal access token.\n", + "\n", + "1. **Navigate to the Access Tokens page** in your Hugging Face account settings: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)\n", + "2. **Create a new token** with a `read` role.\n", + "3. **Copy the generated token**." 
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "20641891",
+      "cell_type": "code",
+      "source": [
+        "try:\n",
+        "    from google.colab import userdata\n",
+        "    print(\"Running the notebook on Google Colab\")\n",
+        "    IN_COLAB = True\n",
+        "except ImportError:\n",
+        "    print(\"Running the notebook on Visual Studio Code or JupyterLab\")\n",
+        "    IN_COLAB = False"
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "de871451",
+      "cell_type": "markdown",
+      "source": [
+        "## Installation: MaxText and Post-training Dependencies"
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "cc6dbe7a",
+      "cell_type": "code",
+      "source": [
+        "if IN_COLAB:\n",
+        "    # Clone the MaxText repository\n",
+        "    !git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
+        "    %cd maxtext\n",
+        "\n",
+        "    # Install uv, a fast Python package installer\n",
+        "    !pip install uv\n",
+        "\n",
+        "    # Install MaxText and post-training dependencies\n",
+        "    !uv pip install -e .[tpu-post-train] --resolution=lowest\n",
+        "    !install_tpu_post_train_extra_deps\n",
+        "\n",
+        "    print(\"\\nIMPORTANT: Restart your session (Runtime \u003e Restart session) and run the notebook from the 'Environment Setup' section.\")"
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "5581f437",
+      "cell_type": "markdown",
+      "source": [
+        "## Environment Setup"
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "e8e6aa2e",
+      "cell_type": "code",
+      "source": [
+        "import datetime\n",
+        "import jax\n",
+        "from maxtext.configs import pyconfig\n",
+        "from maxtext.utils.globals import MAXTEXT_PKG_DIR\n",
+        "from maxtext.trainers.post_train.dpo import train_dpo\n",
+        "\n",
+        "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")\n",
+        "print(f\"JAX version: {jax.__version__}\")\n",
+        "print(f\"JAX devices: {jax.devices()}\")"
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "f3933aba",
+      "cell_type": "code",
+      "source": [
+        "if IN_COLAB:\n",
+        "    from huggingface_hub import notebook_login\n",
+        "    notebook_login()\n",
+        "else:\n",
+        "    from huggingface_hub import login\n",
+        "    login()"
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "e585f936",
+      "cell_type": "markdown",
+      "source": [
+        "## Configuration\n",
+        "\n",
+        "We will set up the configuration for preference optimization training. The `ALGORITHM` parameter in the next cell toggles between DPO and ORPO."
+ ], + "metadata": {}, + "execution_count": null + }, + { + "id": "d838f4e9", + "cell_type": "code", + "source": [ + "# Model and Experiment details\n", + "ALGORITHM = 'dpo' # @param ['dpo', 'orpo']\n", + "MODEL_NAME = \"qwen3-0.6b\" # @param [\"qwen3-0.6b\", \"gemma2-2b\", \"llama3-8b\"]\n", + "TOKENIZER_PATH = \"Qwen/Qwen3-0.6B\" # @param {type:\"string\"}\n", + "STEPS = 20 # @param {type:\"integer\"}\n", + "BASE_OUTPUT_DIRECTORY = f\"{MAXTEXT_PKG_DIR}/align_qwen06_output\"\n", + "RUN_NAME = f\"{ALGORITHM}-demo-{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}\"\n", + "\n", + "# Construct configuration arguments\n", + "config_argv = [\n", + " \"\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/post_train/dpo.yml\",\n", + " f\"run_name={RUN_NAME}\",\n", + " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"tokenizer_path={TOKENIZER_PATH}\",\n", + " \"dataset_type=hf\",\n", + " \"hf_path=argilla/distilabel-intel-orca-dpo-pairs\",\n", + " \"hf_eval_split=train\",\n", + " \"train_split=train\",\n", + " f\"steps={STEPS}\",\n", + " \"per_device_batch_size=1\",\n", + " \"max_target_length=1024\",\n", + " \"eval_interval=10\",\n", + " \"enable_checkpointing=False\",\n", + " \"log_config=False\",\n", + " \"use_dpo=True\",\n", + " f\"dpo.algo='{ALGORITHM}'\",\n", + " \"enable_goodput_recording=False\",\n", + " \"monitor_goodput=False\",\n", + " \"profiler=xplane\",\n", + "]\n", + "\n", + "config = pyconfig.initialize(config_argv)\n", + "\n", + "print(f\"\\n✓ {ALGORITHM.upper()} configuration loaded for {MODEL_NAME}\")\n", + "print(f\" Run name: {config.run_name}\")\n", + "print(f\" Output directory: {config.base_output_directory}/{config.run_name}\")" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "4bb3e2fb", + "cell_type": "markdown", + "source": [ + "## Data Preview\n", + "\n", + "It is always a good practice to inspect the dataset before training. DPO requires a dataset with preference pairs (prompt, chosen, rejected)." 
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "faa74a17",
+      "cell_type": "code",
+      "source": [
+        "from datasets import load_dataset\n",
+        "\n",
+        "dataset_name = \"argilla/distilabel-intel-orca-dpo-pairs\"\n",
+        "print(f\"Loading a few samples from {dataset_name}...\")\n",
+        "preview_ds = load_dataset(dataset_name, split=\"train\", streaming=True)\n",
+        "samples = list(preview_ds.take(2))\n",
+        "\n",
+        "for i, sample in enumerate(samples):\n",
+        "    print(f\"\\n--- Sample {i+1} ---\")\n",
+        "    print(f\"PROMPT:\\n{sample['input'][:300]}...\")\n",
+        "    print(f\"\\nCHOSEN:\\n{sample['chosen'][:300]}...\")\n",
+        "    print(f\"\\nREJECTED:\\n{sample['rejected'][:300]}...\")"
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "d376f43a",
+      "cell_type": "markdown",
+      "source": [
+        "## Run Preference Optimization Training"
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "a286dc52",
+      "cell_type": "code",
+      "source": [
+        "print(f\"Starting {ALGORITHM.upper()} Training...\")\n",
+        "try:\n",
+        "    trainer, mesh = train_dpo.train(config)\n",
+        "    print(f\"{ALGORITHM.upper()} Training Complete!\")\n",
+        "    print(f\"Checkpoints saved to: {config.checkpoint_dir}\")\n",
+        "except Exception as e:\n",
+        "    print(\"Training Failed!\")\n",
+        "    print(f\"Error: {str(e)}\")\n",
+        "    import traceback\n",
+        "    traceback.print_exc()"
+      ],
+      "metadata": {},
+      "execution_count": null
+    },
+    {
+      "id": "c878a929",
+      "cell_type": "markdown",
+      "source": [
+        "## 📚 Learn More\n",
+        "\n",
+        "- **DPO Paper**: [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290)\n",
+        "- **MaxText Documentation**: https://maxtext.readthedocs.io/\n",
+        "- **Tunix Documentation**: https://github.com/google/tunix"
+      ],
+      "metadata": {},
+      "execution_count": null
+    }
+  ],
+  "metadata": {
+    "kernelspec": {
+      "display_name": "Python 3",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "nbformat_minor": 5,
+  "nbformat": 4
+}
\ No newline at end of file
diff --git a/tests/post_training/unit/dpo_hooks_test.py b/tests/post_training/unit/dpo_hooks_test.py
index 93c36be81f..89e349bb8a 100644
--- a/tests/post_training/unit/dpo_hooks_test.py
+++ b/tests/post_training/unit/dpo_hooks_test.py
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
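A closing note on the partial-restore merge in setup_initial_state (PATCH 2/3): leaves of
raw_params that are still abstract jax.ShapeDtypeStruct were never actually restored, so the
freshly initialized value is kept for them. A self-contained toy sketch of that rule, operating
on a two-leaf dict rather than real MaxText state:

    import jax
    import jax.numpy as jnp

    def _merge_params(p_raw, p_init):
      # An abstract leaf means nothing was restored for it; keep the fresh init.
      return p_init if isinstance(p_raw, jax.ShapeDtypeStruct) else p_raw

    raw_params = {"w": jnp.ones((2,)), "b": jax.ShapeDtypeStruct((2,), jnp.float32)}
    init_params = {"w": jnp.zeros((2,)), "b": jnp.full((2,), 0.5)}

    merged = jax.tree_util.tree_map(_merge_params, raw_params, init_params)
    # merged["w"] keeps the restored ones; merged["b"] falls back to the 0.5 init.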