Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,12 @@ global_rampup_samples: 500

# Direct Preference Optimization (DPO)
use_dpo: False
dpo_label_smoothing: 0.0
dpo_beta: 0.1
dpo:
algo: 'dpo'
orpo_lambda: 0.1
dpo_label_smoothing: 0.0
dpo_beta: 0.1
max_prompt_length: null

# Supervised Fine-Tuning (SFT)
use_sft: False
Expand Down
8 changes: 6 additions & 2 deletions src/maxtext/configs/post_train/dpo.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
base_config: "base.yml"

use_dpo: true
dpo:
algo: 'dpo'
orpo_lambda: 0.1
dpo_label_smoothing: 0.0
dpo_beta: 0.1
max_prompt_length: null
packing: false
train_data_columns: ['chosen', 'rejected']
eval_data_columns: ['chosen', 'rejected']
Expand All @@ -24,8 +30,6 @@ hf_eval_split: 'test'

gradient_clipping_threshold: 10.0
learning_rate: 5.0e-7
dpo_label_smoothing: 0.0
dpo_beta: 0.1

enable_goodput_recording: false
monitor_goodput: false
Expand Down
30 changes: 28 additions & 2 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,12 +1201,24 @@ class OlmoGrainDataset(BaseModel):
olmo_apply_ngram_filter: bool = Field(True, description="Mask repetitive instances per OLMo-core's repetition filter.")


class DPO(BaseModel):
  """Configuration for DPO and ORPO preference optimization algorithms."""

  # Which preference-optimization objective to train with.
  algo: Literal["dpo", "orpo"] = Field("dpo", description="Alignment algorithm to use.")
  # DPO temperature-like coefficient; no range constraint is enforced here.
  dpo_beta: float = Field(0.1, description="Beta parameter for DPO.")
  # ORPO weighting of the preference term; no range constraint is enforced here.
  orpo_lambda: float = Field(0.1, description="Weight for preference loss in ORPO.")
  # Constrained to the closed interval [0, 1].
  dpo_label_smoothing: float = Field(0.0, ge=0.0, le=1.0, description="Label smoothing for DPO.")
  # Must be strictly positive when set; None defers the choice to the input pipeline.
  max_prompt_length: int | None = Field(
      None,
      gt=0,
      description="Maximum length for prompt. If None, defaults to half of max_target_length.",
  )


class FineTuning(BaseModel):
"""Configuration for fine-tuning methods like DPO, SFT, and GRPO."""

use_dpo: bool = Field(False, description="If True, enables Direct Preference Optimization training.")
dpo_label_smoothing: float = Field(0.0, ge=0.0, le=1.0, description="Label smoothing for DPO.")
dpo_beta: float = Field(0.1, description="Beta parameter for DPO.")
use_sft: bool = Field(False, description="If True, enables Supervised Fine-Tuning.")
sft_train_on_completion_only: bool = Field(
False, description="If True, trains only on the completion part of the text."
Expand Down Expand Up @@ -2298,6 +2310,10 @@ class MaxTextConfig(
"""

debug: Debug = Field(default_factory=Debug, description="Configuration for debugging options.")
dpo: DPO = Field(
default_factory=DPO,
description="Configuration for DPO and ORPO alignment algorithms.",
)
rl: RL = Field(
default_factory=RL,
description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO).",
Expand Down Expand Up @@ -2883,6 +2899,16 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
raise ValueError("For multimodal SFT, `sft_train_on_completion_only` must be True.")
if self.packing:
raise ValueError("For multimodal SFT, `packing` is not yet supported.")
if self.use_dpo:
if self.packing:
raise ValueError("For DPO/ORPO, `packing` is not supported.")
if self.dpo.max_prompt_length is not None and self.dpo.max_prompt_length > self.max_target_length:
raise ValueError(
f"`max_prompt_length` ({self.dpo.max_prompt_length}) cannot be "
f"greater than `max_target_length` ({self.max_target_length})."
)
if self.use_sft and self.use_dpo:
raise ValueError("Only one of `use_sft` or `use_dpo` can be True.")
if self.shard_mode == ShardMode.EXPLICIT:
supported_decoders = {"simple", "simple_mlp", "llama2", "deepseek"}
if self.decoder_block.value not in supported_decoders:
Expand Down
108 changes: 108 additions & 0 deletions src/maxtext/input_pipeline/dpo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DPO specific input pipeline utilities."""

import dataclasses
import grain.python as grain
import numpy as np


@dataclasses.dataclass
class DPOTunixPrep(grain.MapTransform):
  """Prepares one preference example for Tunix-based DPO.

  Reads the prompt / chosen / rejected token sequences from the configured columns
  (deriving the prompt as the common prefix when only two columns are given), then
  pads or trims them to fixed lengths: prompts are left-padded to the prompt budget
  and responses are right-padded to `max_target_length` minus the prompt budget.

  Attributes:
    pad_id: Token id used for padding. Positions equal to it get mask value 0.
    max_target_length: Total length budget shared by prompt and response.
    data_column_names: Either (prompt, chosen, rejected) or (chosen, rejected).
    max_prompt_length: Prompt budget. A falsy value (None or 0) falls back to
      `max_target_length // 2`.
  """

  pad_id: int
  max_target_length: int
  data_column_names: tuple[str, ...]
  max_prompt_length: int | None = None

  def map(self, element):
    """Apply the dataset transformations for Tunix-based DPO.

    Args:
      element: Mapping from column name to a 1-D sequence of token ids. It is
        modified in place: the source columns are removed and the outputs added.

    Returns:
      The same `element`, now holding `prompt_ids`, `chosen_ids`, `rejected_ids`
      and the matching `*_mask` entries, all as int32 arrays.

    Raises:
      KeyError: If a configured column is missing from `element`.
      ValueError: If `data_column_names` does not have 2 or 3 entries, or if the
        prompt/response length budget is not positive.
    """
    # 1. Reformat/Extract Columns
    input_ids, chosen_ids, rejected_ids = self._split_columns(element)

    # 2. Padding and Masking
    max_prompt_length, max_response_length = self._length_budget()
    prompt_ids = self._pad(input_ids, max_prompt_length, left=True)
    chosen_ids = self._pad(chosen_ids, max_response_length, left=False)
    rejected_ids = self._pad(rejected_ids, max_response_length, left=False)

    # Remove old columns if they exist
    for key in self.data_column_names:
      if key in element:
        del element[key]

    element["prompt_ids"] = prompt_ids
    element["chosen_ids"] = chosen_ids
    element["rejected_ids"] = rejected_ids
    # NOTE(review): masks are derived by comparing against pad_id, so any real token
    # equal to pad_id is masked out as well -- confirm pad_id never occurs in the data.
    element["prompt_mask"] = (prompt_ids != self.pad_id).astype(np.int32)
    element["chosen_mask"] = (chosen_ids != self.pad_id).astype(np.int32)
    element["rejected_mask"] = (rejected_ids != self.pad_id).astype(np.int32)
    return element

  def _split_columns(self, element):
    """Returns the (prompt, chosen, rejected) token sequences read from `element`."""
    names = self.data_column_names
    if len(names) not in (2, 3):
      raise ValueError(f"DPOTunixPrep expects 2 or 3 columns, got {len(names)}")

    try:
      if len(names) == 3:
        return element[names[0]], element[names[1]], element[names[2]]
      # Support for datasets like Anthropic/hh-rlhf where prompt is a common prefix
      full_chosen = np.asarray(element[names[0]])
      full_rejected = np.asarray(element[names[1]])
    except KeyError as e:
      raise KeyError(
          f"Column '{e.args[0]}' not found in the dataset. "
          f"Expected columns: {self.data_column_names}. "
          f"Available columns: {list(element.keys())}. "
          "Please verify that 'train_data_columns' and 'eval_data_columns' match your dataset."
      ) from e

    prefix_len = self._common_prefix_length(full_chosen, full_rejected)
    return full_chosen[:prefix_len], full_chosen[prefix_len:], full_rejected[prefix_len:]

  @staticmethod
  def _common_prefix_length(first, second):
    """Returns how many leading positions of two 1-D arrays hold equal values."""
    overlap = min(first.shape[0], second.shape[0])
    mismatches = np.flatnonzero(first[:overlap] != second[:overlap])
    # No mismatch inside the overlap means the shorter sequence is entirely shared.
    return int(mismatches[0]) if mismatches.size else overlap

  def _length_budget(self):
    """Returns (max_prompt_length, max_response_length), validating both are positive."""
    max_prompt_length = self.max_prompt_length or (self.max_target_length // 2)
    max_response_length = self.max_target_length - max_prompt_length

    # Raised explicitly rather than asserted so the checks survive `python -O`.
    if max_prompt_length <= 0:
      raise ValueError(
          "max_prompt_length must be positive. "
          "Check the configs for 'max_prompt_length' and 'max_target_length'."
      )
    if max_response_length <= 0:
      raise ValueError(
          "max_response_length must be positive. "
          "Check the configs for 'max_prompt_length' and 'max_target_length'."
      )
    return max_prompt_length, max_response_length

  def _pad(self, x, length, left=False):
    """Pads or trims an array to a specific length.

    When left=True (for prompts), trims from the left to keep the suffix (closest context).
    When left=False (for responses), trims from the right to keep the prefix.

    Args:
      x: 1-D sequence of token ids.
      length: Target length of the result.
      left: If True, pad and trim on the left; otherwise on the right.

    Returns:
      An int32 array of exactly `length` entries, padded with `pad_id`.
    """
    x = np.asarray(x)
    pad_amount = max(length - x.shape[0], 0)
    if left:
      pad_width = ((pad_amount, 0),)
      # An explicit start index is used because `x[-length:]` would return the
      # whole array, not an empty one, when length == 0.
      x_trimmed = x[max(x.shape[0] - length, 0) :]
    else:
      pad_width = ((0, pad_amount),)
      x_trimmed = x[:length]
    return np.pad(x_trimmed, pad_width, constant_values=self.pad_id).astype(np.int32)
29 changes: 15 additions & 14 deletions src/maxtext/input_pipeline/hf_data_processing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023–2025 Google LLC
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,9 +24,8 @@

import grain.python as grain

import numpy as np

from maxtext.input_pipeline import data_processing_utils
from maxtext.input_pipeline import dpo_utils
from maxtext.input_pipeline import input_pipeline_utils
from maxtext.input_pipeline import instruction_data_processing
from maxtext.input_pipeline import multihost_dataloading
Expand Down Expand Up @@ -214,7 +213,7 @@ def preprocessing_pipeline(
num_threads=1,
drop_remainder=True,
generate_padding_batch=False,
use_dpo=None,
use_dpo=False,
use_sft=None,
use_tunix_gradient_accumulation=False,
num_microbatches=1,
Expand Down Expand Up @@ -330,19 +329,12 @@ def preprocessing_pipeline(
)
)
data_column_names = ("inputs", "targets")
elif use_dpo:

def lists2array(x):
"""Convert lists/tuples to array"""
return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple)))

operations.append(grain.MapOperation(lists2array))
else:
elif not use_dpo:
assert len(data_column_names) == 1
operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0]))
data_column_names = ("inputs", "targets")

if packing and not use_dpo:
if packing:
length_struct = {col: max_target_length for col in data_column_names}
max_segments = max_segments_per_seq
if max_segments is not None and max_segments <= 0:
Expand All @@ -356,7 +348,16 @@ def lists2array(x):
)
operations.append(input_pipeline_utils.ReformatPacking(data_column_names))
else:
operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
if use_dpo:
# Renames arbitrary DPO columns and performs DPO-aware padding.
max_prompt_length = (
config.dpo.get("max_prompt_length")
if isinstance(config.dpo, dict)
else getattr(config.dpo, "max_prompt_length", None)
)
operations.append(dpo_utils.DPOTunixPrep(pad_id, max_target_length, data_column_names, max_prompt_length))
else:
operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))

if shift and not use_dpo:
Expand Down
Loading
Loading