class DPO(BaseModel):
  """Settings for preference-optimization training (DPO or ORPO).

  `algo` selects which of the two algorithms is used. The remaining fields hold
  the per-algorithm coefficients and the optional prompt-length budget.
  """

  algo: Literal["dpo", "orpo"] = Field(
      default="dpo",
      description="Alignment algorithm to use.",
  )
  dpo_beta: float = Field(
      default=0.1,
      description="Beta parameter for DPO.",
  )
  orpo_lambda: float = Field(
      default=0.1,
      description="Weight for preference loss in ORPO.",
  )
  # Constrained to the closed interval [0, 1].
  dpo_label_smoothing: float = Field(
      default=0.0,
      ge=0.0,
      le=1.0,
      description="Label smoothing for DPO.",
  )
  # Must be strictly positive when given; None selects the default split.
  max_prompt_length: int | None = Field(
      default=None,
      gt=0,
      description="Maximum length for prompt. If None, defaults to half of max_target_length.",
  )
@dataclasses.dataclass
class DPOTunixPrep(grain.MapTransform):
  """Prepares tokenized preference data for Tunix-based DPO/ORPO training.

  For every example the transform:
    1. Reads the prompt / chosen / rejected token sequences from the configured
       columns. With three columns they are used as-is, in that order. With two
       columns (chosen, rejected) the longest common token prefix of the two
       sequences becomes the prompt and the remainders become the responses.
    2. Left-pads (or trims from the left, keeping the suffix) the prompt to the
       prompt budget, and right-pads (or trims from the right, keeping the
       prefix) both responses to the response budget.
    3. Replaces the source columns with `prompt_ids`, `chosen_ids`,
       `rejected_ids` and the matching `prompt_mask`, `chosen_mask`,
       `rejected_mask` (1 for real tokens, 0 for padding), all int32.

  Attributes:
    pad_id: Token id written into padding positions.
    max_target_length: Total length budget shared by prompt and response.
    data_column_names: Two or three source column names (see above).
    max_prompt_length: Prompt budget. None means `max_target_length // 2`. The
      response budget is always `max_target_length` minus the prompt budget.
  """

  pad_id: int
  max_target_length: int
  data_column_names: tuple[str, ...]
  max_prompt_length: int | None = None

  def __post_init__(self):
    """Fails fast on a configuration that `map` could never satisfy.

    Raises:
      ValueError: If the number of columns is not 2 or 3, or if the prompt or
        response budget is not strictly positive.
    """
    self._check_column_count()
    self._split_lengths()

  def map(self, element):
    """Converts one example into the padded DPO layout.

    Args:
      element: Mapping from column name to a 1-D token-id sequence. It is
        modified in place: the source columns are removed and the six output
        keys are added.

    Returns:
      The same `element` object with the DPO keys populated.

    Raises:
      KeyError: If a configured column is missing from `element`.
      ValueError: If the column count or the length budgets are invalid.
    """
    input_ids, chosen_ids, rejected_ids = self._extract_columns(element)
    prompt_length, response_length = self._split_lengths()

    # Masks are derived from the real sequence lengths rather than from
    # `ids != pad_id`, so a genuine token that happens to equal `pad_id`
    # (e.g. when the pad id is shared with EOS) is not masked out.
    prompt_ids, prompt_mask = self._pad_with_mask(input_ids, prompt_length, left=True)
    chosen_ids, chosen_mask = self._pad_with_mask(chosen_ids, response_length, left=False)
    rejected_ids, rejected_mask = self._pad_with_mask(rejected_ids, response_length, left=False)

    # Drop the source columns only after everything above has succeeded, so a
    # failing example is left untouched.
    for key in self.data_column_names:
      element.pop(key, None)

    element["prompt_ids"] = prompt_ids
    element["chosen_ids"] = chosen_ids
    element["rejected_ids"] = rejected_ids
    element["prompt_mask"] = prompt_mask
    element["chosen_mask"] = chosen_mask
    element["rejected_mask"] = rejected_mask
    return element

  def _check_column_count(self):
    """Raises ValueError unless exactly 2 or 3 column names are configured."""
    if len(self.data_column_names) not in (2, 3):
      raise ValueError(f"DPOTunixPrep expects 2 or 3 columns, got {len(self.data_column_names)}")

  def _split_lengths(self):
    """Returns `(prompt_length, response_length)` for the configured budgets.

    Raises ValueError (not AssertionError, which disappears under `python -O`)
    when either side of the split would be empty.
    """
    # `is None` rather than truthiness: an explicit 0 is an error, not a
    # request for the default split.
    if self.max_prompt_length is None:
      prompt_length = self.max_target_length // 2
    else:
      prompt_length = self.max_prompt_length
    response_length = self.max_target_length - prompt_length

    if prompt_length <= 0:
      raise ValueError(
          "max_prompt_length must be positive. " "Check the configs for 'max_prompt_length' and 'max_target_length'."
      )
    if response_length <= 0:
      raise ValueError(
          "max_response_length must be positive. " "Check the configs for 'max_prompt_length' and 'max_target_length'."
      )
    return prompt_length, response_length

  def _extract_columns(self, element):
    """Returns the unpadded `(prompt, chosen, rejected)` sequences of `element`."""
    self._check_column_count()
    try:
      columns = [element[name] for name in self.data_column_names]
    except KeyError as e:
      raise KeyError(
          f"Column '{e.args[0]}' not found in the dataset. "
          f"Expected columns: {self.data_column_names}. "
          f"Available columns: {list(element.keys())}. "
          "Please verify that 'train_data_columns' and 'eval_data_columns' match your dataset."
      ) from e

    if len(columns) == 3:
      input_ids, chosen_ids, rejected_ids = columns
      return input_ids, chosen_ids, rejected_ids

    # Two columns: datasets such as Anthropic/hh-rlhf store the prompt as a
    # shared prefix of both responses. Note that any tokens the two responses
    # have in common right after the real prompt also land in the prompt.
    full_chosen = np.asarray(columns[0])
    full_rejected = np.asarray(columns[1])
    prefix_len = self._common_prefix_length(full_chosen, full_rejected)
    return full_chosen[:prefix_len], full_chosen[prefix_len:], full_rejected[prefix_len:]

  @staticmethod
  def _common_prefix_length(a, b):
    """Returns the number of leading positions at which 1-D arrays `a` and `b` agree."""
    overlap = min(a.shape[0], b.shape[0])
    mismatches = np.flatnonzero(a[:overlap] != b[:overlap])
    # No mismatch inside the overlap means the shorter sequence is a prefix of
    # the longer one (or both are identical).
    return int(mismatches[0]) if mismatches.size else overlap

  def _pad_with_mask(self, x, length, left=False):
    """Returns `(padded_ids, mask)` where `mask` is 1 exactly on the kept real tokens."""
    x = np.asarray(x)
    padded = self._pad(x, length, left=left)
    num_real = min(x.shape[0], length)
    mask = np.zeros(length, dtype=np.int32)
    if num_real:  # guard: `mask[-0:]` would select the whole array
      if left:
        mask[-num_real:] = 1
      else:
        mask[:num_real] = 1
    return padded, mask

  def _pad(self, x, length, left=False):
    """Pads or trims a 1-D sequence to exactly `length` int32 entries.

    When left=True (for prompts), padding is added on the left and overlong
    input is trimmed from the left, keeping the suffix (closest context).
    When left=False (for responses), padding is added on the right and overlong
    input is trimmed from the right, keeping the prefix.
    """
    x = np.asarray(x)
    pad_amount = max(length - x.shape[0], 0)
    if left:
      pad_width = ((pad_amount, 0),)
      x_trimmed = x[-length:]
    else:
      pad_width = ((0, pad_amount),)
      x_trimmed = x[:length]
    return np.pad(x_trimmed, pad_width, constant_values=self.pad_id).astype(np.int32)
@@ -24,9 +24,8 @@ import grain.python as grain -import numpy as np - from maxtext.input_pipeline import data_processing_utils +from maxtext.input_pipeline import dpo_utils from maxtext.input_pipeline import input_pipeline_utils from maxtext.input_pipeline import instruction_data_processing from maxtext.input_pipeline import multihost_dataloading @@ -214,7 +213,7 @@ def preprocessing_pipeline( num_threads=1, drop_remainder=True, generate_padding_batch=False, - use_dpo=None, + use_dpo=False, use_sft=None, use_tunix_gradient_accumulation=False, num_microbatches=1, @@ -330,19 +329,12 @@ def preprocessing_pipeline( ) ) data_column_names = ("inputs", "targets") - elif use_dpo: - - def lists2array(x): - """Convert lists/tuples to array""" - return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple))) - - operations.append(grain.MapOperation(lists2array)) - else: + elif not use_dpo: assert len(data_column_names) == 1 operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0])) data_column_names = ("inputs", "targets") - if packing and not use_dpo: + if packing: length_struct = {col: max_target_length for col in data_column_names} max_segments = max_segments_per_seq if max_segments is not None and max_segments <= 0: @@ -356,7 +348,16 @@ def lists2array(x): ) operations.append(input_pipeline_utils.ReformatPacking(data_column_names)) else: - operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) + if use_dpo: + # Renames arbitrary DPO columns and performs DPO-aware padding. 
+ max_prompt_length = ( + config.dpo.get("max_prompt_length") + if isinstance(config.dpo, dict) + else getattr(config.dpo, "max_prompt_length", None) + ) + operations.append(dpo_utils.DPOTunixPrep(pad_id, max_target_length, data_column_names, max_prompt_length)) + else: + operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)) if shift and not use_dpo: diff --git a/tests/post_training/unit/dpo_data_processing_test.py b/tests/post_training/unit/dpo_data_processing_test.py new file mode 100644 index 0000000000..d7836b538d --- /dev/null +++ b/tests/post_training/unit/dpo_data_processing_test.py @@ -0,0 +1,393 @@ +# Copyright 2025–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TestDPOTunixPrep(unittest.TestCase):
  """Unit tests for the DPOTunixPrep transform, run on hand-built token arrays."""

  def setUp(self):
    # Every sample below uses non-zero token ids, so 0 is unambiguous as padding.
    self.pad_id = 0

  def test_column_remapping(self):
    """Verify that source columns are replaced by the six DPO output keys."""
    prep = dpo_utils.DPOTunixPrep(
        pad_id=self.pad_id,
        max_target_length=21,
        data_column_names=("input", "chosen", "rejected"),
    )
    sample = {
        "input": np.array([1, 2, 3]),
        "chosen": np.array([4, 5]),
        "rejected": np.array([6, 7, 8]),
    }
    output = prep.map(sample)

    # Check that old keys are removed
    self.assertNotIn("input", output)
    self.assertNotIn("chosen", output)
    self.assertNotIn("rejected", output)

    # Check that new keys exist
    self.assertIn("prompt_ids", output)
    self.assertIn("chosen_ids", output)
    self.assertIn("rejected_ids", output)
    self.assertIn("prompt_mask", output)
    self.assertIn("chosen_mask", output)
    self.assertIn("rejected_mask", output)

  def test_two_column_prefix_extraction(self):
    """Verify common prefix extraction for 2-column datasets."""
    # Arbitrary column names are accepted: the first column is treated as the
    # chosen sequence and the second as the rejected sequence.
    prep = dpo_utils.DPOTunixPrep(
        pad_id=self.pad_id,
        max_target_length=21,
        data_column_names=("liked", "disliked"),
    )
    sample = {
        "liked": np.array([1, 2, 3, 10, 11]),
        "disliked": np.array([1, 2, 3, 20, 21, 22]),
    }
    output = prep.map(sample)

    # Prompt budget is 21 // 2 = 10. The shared prefix [1, 2, 3] is left-padded
    # with 7 pad tokens.
    self.assertEqual(output["prompt_ids"].shape[0], 10)
    np.testing.assert_array_equal(output["prompt_ids"], [self.pad_id] * 7 + [1, 2, 3])
    # Prompt mask is 0 over the 7 pad positions and 1 over the 3 real tokens.
    np.testing.assert_array_equal(output["prompt_mask"], [0] * 7 + [1, 1, 1])

    # Response budget is 21 - 10 = 11; chosen_ids should be right-padded to it.
    self.assertEqual(output["chosen_ids"].shape[0], 11)
    np.testing.assert_array_equal(output["chosen_ids"], [10, 11] + [self.pad_id] * 9)
    # chosen_mask for [10, 11, 0, 0, ...] should be [1, 1, 0, 0, ...]
    np.testing.assert_array_equal(output["chosen_mask"], [1, 1] + [0] * 9)

    # rejected_ids (len 11) should be right-padded
    self.assertEqual(output["rejected_ids"].shape[0], 11)
    np.testing.assert_array_equal(output["rejected_ids"], [20, 21, 22] + [self.pad_id] * 8)
    # rejected_mask for [20, 21, 22, 0, 0, ...] should be [1, 1, 1, 0, 0, ...]
    np.testing.assert_array_equal(output["rejected_mask"], [1, 1, 1] + [0] * 8)

  def test_three_column_remapping(self):
    """Verify standard 3-column remapping."""
    prep = dpo_utils.DPOTunixPrep(
        pad_id=self.pad_id,
        max_target_length=20,
        data_column_names=("input", "chosen", "rejected"),
    )
    sample = {
        "input": np.array([1, 2]),
        "chosen": np.array([3, 4]),
        "rejected": np.array([5, 6, 7]),
    }
    output = prep.map(sample)

    self.assertNotIn("input", output)
    # Prompt should be left-padded (budget 20 // 2 = 10)
    np.testing.assert_array_equal(output["prompt_ids"], [self.pad_id] * 8 + [1, 2])
    np.testing.assert_array_equal(output["prompt_mask"], [0] * 8 + [1, 1])

    # Chosen and rejected are right-padded (budget 20 - 10 = 10)
    np.testing.assert_array_equal(output["chosen_ids"], [3, 4] + [self.pad_id] * 8)
    np.testing.assert_array_equal(output["chosen_mask"], [1, 1] + [0] * 8)
    np.testing.assert_array_equal(output["rejected_ids"], [5, 6, 7] + [self.pad_id] * 7)
    np.testing.assert_array_equal(output["rejected_mask"], [1, 1, 1] + [0] * 7)

  def test_three_column_truncation(self):
    """Verify that overlong prompts keep their suffix and overlong responses keep their prefix."""
    prep = dpo_utils.DPOTunixPrep(
        pad_id=self.pad_id,
        max_target_length=9,
        data_column_names=("input", "chosen", "rejected"),
    )
    sample = {
        "input": np.arange(1, 10),
        "chosen": np.arange(10, 20),
        "rejected": np.arange(20, 30),
    }
    output = prep.map(sample)

    # Prompt budget is 9 // 2 = 4 tokens; the last 4 tokens are kept.
    np.testing.assert_array_equal(output["prompt_ids"], np.arange(6, 10))
    np.testing.assert_array_equal(output["prompt_mask"], [1] * 4)

    # Response budget is 9 - 4 = 5 tokens; the first 5 tokens are kept.
    np.testing.assert_array_equal(output["chosen_ids"], np.arange(10, 15))
    np.testing.assert_array_equal(output["chosen_mask"], [1] * 5)
    np.testing.assert_array_equal(output["rejected_ids"], np.arange(20, 25))
    np.testing.assert_array_equal(output["rejected_mask"], [1] * 5)

  def test_two_column_prefix_edge_cases(self):
    """Verify prefix extraction when the sequences are identical or one is a prefix of the other."""
    prep = dpo_utils.DPOTunixPrep(
        pad_id=self.pad_id,
        max_target_length=20,
        data_column_names=("chosen", "rejected"),
    )

    # Case 1: Identical sequences
    identical_sample = {
        "chosen": np.array([1, 2, 3]),
        "rejected": np.array([1, 2, 3]),
    }
    out_identical = prep.map(identical_sample)
    # The entire sequence is the prefix; both responses are empty, i.e. all padding.
    np.testing.assert_array_equal(out_identical["prompt_ids"][-3:], [1, 2, 3])
    np.testing.assert_array_equal(out_identical["chosen_ids"], [self.pad_id] * 10)
    np.testing.assert_array_equal(out_identical["rejected_ids"], [self.pad_id] * 10)

    # Case 2: One sequence is a prefix of the other
    prefix_sample = {
        "chosen": np.array([1, 2, 3]),
        "rejected": np.array([1, 2, 3, 4, 5]),
    }
    out_prefix = prep.map(prefix_sample)
    # The shorter sequence is consumed entirely by the prompt, so chosen is all padding.
    np.testing.assert_array_equal(out_prefix["prompt_ids"][-3:], [1, 2, 3])
    np.testing.assert_array_equal(out_prefix["chosen_ids"], [self.pad_id] * 10)
    np.testing.assert_array_equal(out_prefix["rejected_ids"][:2], [4, 5])

  def test_max_prompt_length_override(self):
    """Verify that max_prompt_length can be customized to adjust prompt/response split ratio."""
    prep = dpo_utils.DPOTunixPrep(
        pad_id=self.pad_id,
        max_target_length=20,
        data_column_names=("input", "chosen", "rejected"),
        max_prompt_length=15,
    )
    sample = {
        "input": np.array([1, 2]),
        "chosen": np.array([3, 4]),
        "rejected": np.array([5, 6]),
    }
    output = prep.map(sample)

    # Prompt length should be 15, response length should be 20 - 15 = 5
    self.assertEqual(output["prompt_ids"].shape[0], 15)
    self.assertEqual(output["chosen_ids"].shape[0], 5)
    self.assertEqual(output["rejected_ids"].shape[0], 5)

  def test_missing_column_error(self):
    """Verify that a helpful error is raised when a column is missing."""
    prep = dpo_utils.DPOTunixPrep(
        pad_id=self.pad_id,
        max_target_length=20,
        data_column_names=("input", "chosen", "rejected"),
    )
    # 'rejected' column is missing
    sample = {"input": np.array([1, 2]), "chosen": np.array([3, 4])}
    with self.assertRaisesRegex(KeyError, "Column 'rejected' not found in the dataset"):
      prep.map(sample)
self.config.mesh_axes) + self.process_indices = input_pipeline_interface.get_process_loading_real_data( + self.config.data_sharding, + self.config.global_batch_size_to_load, + self.config.global_batch_size_to_train_on, + self.config.max_target_length, + self.mesh, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + self.config.tokenizer_path, + add_bos_token=False, + add_eos_token=False, + legacy=False, + ) + self.pad_id = hf_data_processing._get_pad_id(self.tokenizer) # pylint: disable=protected-access + + def get_data_iterator(self, dataset, data_columns): + """Helper to initialize the preprocessing pipeline.""" + return hf_data_processing.preprocessing_pipeline( + dataloading_host_index=self.process_indices.index(jax.process_index()), + dataloading_host_count=len(self.process_indices), + global_mesh=self.mesh, + dataset=dataset, + config=self.config, + data_column_names=data_columns, + tokenize=self.config.tokenize_train_data, + tokenizer_path=self.config.tokenizer_path, + hf_access_token=self.config.hf_access_token, + global_batch_size=self.config.global_batch_size_to_load, + max_target_length=self.config.max_target_length, + shuffle=self.config.enable_data_shuffling, + data_shuffle_seed=self.config.data_shuffle_seed, + add_bos=self.config.add_bos, + add_eos=self.config.add_eos, + packing=self.config.packing, + generate_padding_batch=False, + use_dpo=self.config.use_dpo, + use_sft=self.config.use_sft, + sft_train_on_completion_only=self.config.sft_train_on_completion_only, + grain_worker_count=0, + ) + + def test_dpo_format_3_columns(self): + """Verify that the 3-column explicit DPO dataset is processed correctly.""" + prompt_str = "Question: What is 2+2?" 
+ chosen_str = "Answer: 4" + rejected_str = "Answer: 5" + + dataset = Dataset.from_dict( + { + "input": [prompt_str] * 10, + "chosen": [chosen_str] * 10, + "rejected": [rejected_str] * 10, + } + ) + data_iter = self.get_data_iterator(dataset, ["input", "chosen", "rejected"]) + batch = next(data_iter) + + # Verify expected keys + for key in ( + "prompt_ids", + "chosen_ids", + "rejected_ids", + "prompt_mask", + "chosen_mask", + "rejected_mask", + ): + self.assertIn(key, batch) + + # Verify batch dimensions match global batch size and split max_target_length + max_prompt_len = self.config.max_target_length // 2 + max_response_len = self.config.max_target_length - max_prompt_len + self.assertEqual( + batch["prompt_ids"].shape, + (self.config.global_batch_size_to_load, max_prompt_len), + ) + self.assertEqual( + batch["chosen_ids"].shape, + (self.config.global_batch_size_to_load, max_response_len), + ) + self.assertEqual( + batch["rejected_ids"].shape, + (self.config.global_batch_size_to_load, max_response_len), + ) + + # Verify decoded content directly + decoded_prompt = self.tokenizer.decode(batch["prompt_ids"][0], skip_special_tokens=True) + decoded_chosen = self.tokenizer.decode(batch["chosen_ids"][0], skip_special_tokens=True) + decoded_rejected = self.tokenizer.decode(batch["rejected_ids"][0], skip_special_tokens=True) + + self.assertEqual(decoded_prompt, prompt_str) + self.assertEqual(decoded_chosen, chosen_str) + self.assertEqual(decoded_rejected, rejected_str) + + # Verify mask structure (left padding for prompt -> 1s at the end; right padding for responses -> 1s at start) + self.assertEqual(batch["prompt_mask"][0][-1], 1) + self.assertEqual(batch["chosen_mask"][0][0], 1) + self.assertEqual(batch["rejected_mask"][0][0], 1) + + def test_dpo_format_2_columns(self): + """Verify that 2-column DPO datasets correctly extract common prefixes.""" + # We use a clear common prefix and different suffixes + prefix = "Common prompt context for DPO:" + chosen_suffix = " the 
chosen completion" + rejected_suffix = " the rejected completion" + + dataset = Dataset.from_dict( + { + "chosen": [prefix + chosen_suffix] * 10, + "rejected": [prefix + rejected_suffix] * 10, + } + ) + data_iter = self.get_data_iterator(dataset, ["chosen", "rejected"]) + batch = next(data_iter) + + # Verify decoded extracted prefix and completions robustly against BPE token boundary quirks + decoded_prompt = self.tokenizer.decode(batch["prompt_ids"][0], skip_special_tokens=True) + decoded_chosen = self.tokenizer.decode(batch["chosen_ids"][0], skip_special_tokens=True) + decoded_rejected = self.tokenizer.decode(batch["rejected_ids"][0], skip_special_tokens=True) + + self.assertIn("Common prompt context", decoded_prompt) + self.assertIn("chosen", decoded_chosen) + self.assertIn("rejected", decoded_rejected) + + def test_dpo_invalid_column_count(self): + """Verify that passing an unsupported number of columns raises an error.""" + dataset = Dataset.from_dict({"col1": ["a"] * 10}) + with self.assertRaises((ValueError, KeyError)): + # DPOTunixPrep expects 2 or 3 columns + data_iter = self.get_data_iterator(dataset, ["col1"]) + next(data_iter) + + def test_dpo_non_positive_max_prompt_length(self): + """Verify that max_prompt_length <= 0 raises a validation error.""" + with self.assertRaises(ValueError): + pyconfig.initialize( + [ + os.path.join(MAXTEXT_PKG_DIR, "dpo_trainer"), + os.path.join(MAXTEXT_CONFIGS_DIR, "post_train", "dpo.yml"), + ], + per_device_batch_size=2, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory="gs://max-experiments/", + tokenizer_path="Qwen/Qwen3-4B", + train_split="train", + enable_checkpointing=False, + use_dpo=True, + enable_data_shuffling=False, + max_target_length=64, + dpo={"algo": "dpo", "max_prompt_length": 0}, + ) + + +if __name__ == "__main__": + unittest.main()