Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 197 additions & 1 deletion src/maxtext/input_pipeline/grain_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,166 @@ def dpo_preprocessing_pipeline(
return dataset


def _format_chat_template_grain(element, data_columns, tokenizer_model):
  """Grain map fn: normalize raw dataset columns into a chat-message list, then apply the chat template.

  Supports three schemas: a ready-made "messages" column, a ("prompt", "completion")
  pair, or a ("question", "answer") pair. Anything else is passed through unchanged
  under the first data column (assumed to already be a single formatted string).
  """
  columns = set(data_columns)
  if "messages" in data_columns:
    # Dataset is already in conversational format.
    messages = element["messages"]
  elif columns in ({"prompt", "completion"}, {"question", "answer"}):
    # Two-column schema: map the pair onto user/assistant turns.
    user_col, assistant_col = ("prompt", "completion") if "prompt" in columns else ("question", "answer")
    messages = [
        {"role": "user", "content": element[user_col]},
        {"role": "assistant", "content": element[assistant_col]},
    ]
  else:
    # Fallback if it's already a single string.
    messages = element[data_columns[0]]

  # Standardized messages are written back to the primary column before templating.
  element[data_columns[0]] = messages

  return input_pipeline_utils.apply_chat_template(
      element, tokenizer_model=tokenizer_model, data_column_name=data_columns[0]
  )


def _tokenize_sft_chunks(element, text_column_name, tokenizer_model):
"""Tokenize each chunk individually without truncating."""
text_chunks = element[text_column_name]
element[text_column_name] = [tokenizer_model.encode(chunk) for chunk in text_chunks]
return element


def sft_preprocessing_pipeline(
    dataset,
    config,
    data_columns,
    tokenize,
    grain_worker_count,
    grain_per_worker_buffer_size,
):
  """Use grain pipeline to pre-process the dataset and return iterators for sft fine-tuning.

  Args:
    dataset: grain dataset of raw examples.
    config: job config; reads tokenizer settings, packing/batching options, and
      SFT-specific flags (e.g. sft_train_on_completion_only).
    data_columns: raw dataset column names; must match one supported SFT schema.
    tokenize: whether to tokenize the chat-formatted text chunks.
    grain_worker_count: number of grain workers; -1 selects an automatic
      performance configuration via pick_performance_config.
    grain_per_worker_buffer_size: per-worker prefetch buffer size.

  Returns:
    A grain dataset yielding packed (or padded), batched, shifted SFT batches.
  """
  if config.grain_file_type == "arrayrecord":
    dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
  else:
    dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))

  tokenizer_model = tokenizer.build_tokenizer(
      config.tokenizer_path,
      config.tokenizer_type,
      config.add_bos,
      config.add_eos,
      config.hf_access_token,
  )
  # Prefer the tokenizer's pad id, fall back to its unk id, then to -1 as a sentinel.
  if tokenizer_model.pad_id is not None:
    pad_id = tokenizer_model.pad_id
  elif tokenizer_model.unk_id is not None:
    pad_id = tokenizer_model.unk_id
  else:
    pad_id = -1

  # Unwrap to the underlying (e.g. HF) tokenizer object for chat templating / encoding.
  tokenizer_model = tokenizer_model.tokenizer

  if getattr(config, "chat_template", None) and hasattr(tokenizer_model, "chat_template"):
    tokenizer_model.chat_template = config.chat_template

  supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]]
  assert any(
      set(data_columns) == set(supported) for supported in supported_columns
  ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}"

  # Normalize raw columns into conversational messages and apply the chat template.
  dataset = dataset.map(
      functools.partial(_format_chat_template_grain, data_columns=data_columns, tokenizer_model=tokenizer_model)
  )

  if tokenize:
    dataset = dataset.map(
        functools.partial(
            _tokenize_sft_chunks,
            text_column_name=data_columns[0],
            tokenizer_model=tokenizer_model,
        )
    )

  # Build (inputs, targets) with prompt tokens masked when training on completions only.
  dataset = dataset.map(
      input_pipeline_utils.SFTPromptMasking(
          text_column_name=data_columns[0],
          completion_only=config.sft_train_on_completion_only,
          max_target_length=config.max_target_length,
          unk_id=pad_id,
      )
  )
  data_columns = ("inputs", "targets")

  # Pack and batch examples
  batch_size = config.global_batch_size_to_load // jax.process_count()
  if config.expansion_factor_real_data > 1:
    # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1.
    batch_size = int(batch_size // config.expansion_factor_real_data)

  if config.packing:
    length_struct = {col: config.max_target_length for col in data_columns}
    max_segments = config.max_segments_per_seq
    if max_segments is not None and max_segments <= 0:
      # Non-positive means "no limit" for the packer.
      max_segments = None
    if config.grain_packing_type == "first_fit":
      dataset = grain.experimental.FirstFitPackIterDataset(
          dataset,
          length_struct=length_struct,
          num_packing_bins=batch_size,
          max_sequences_per_bin=max_segments,
      )
    elif config.grain_packing_type == "best_fit":
      dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size)
    elif config.grain_packing_type == "concat_then_split":
      if config.add_bos and hasattr(tokenizer_model, "bos_id"):
        dataset = grain.experimental.ConcatThenSplitIterDataset(
            dataset,
            length_struct=length_struct,
            bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS,
            bos_token_id=tokenizer_model.bos_id,
        )
      else:
        dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct)
    else:
      # Fix: report the unrecognized packing strategy itself. The previous message
      # interpolated config.packing, which is always truthy inside this branch.
      raise ValueError(f"Unknown packing type: {config.grain_packing_type}")

    rekey_dict = {
        "targets_segmentation": "targets_segment_ids",
        "inputs_segmentation": "inputs_segment_ids",
        "targets_position": "targets_positions",
        "inputs_position": "inputs_positions",
    }
    dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict))
  else:
    dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))

  batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
  dataset = dataset.batch(batch_size, batch_fn=batch_fn)

  # Shift inputs for teacher-forced training
  dataset = dataset.map(
      input_pipeline_utils.ShiftData(
          ignored_ids=[pad_id],
          axis=1,
      )
  )
  multiprocessing_options = (
      pick_performance_config(
          ds=dataset,
          ram_budget_mb=config.grain_ram_budget_mb,
          max_workers=None,
          max_buffer_size=None,
      ).multiprocessing_options
      if grain_worker_count == -1
      else grain.MultiprocessingOptions(
          num_workers=grain_worker_count,
          per_worker_buffer_size=grain_per_worker_buffer_size,
      )
  )
  dataset = dataset.mp_prefetch(multiprocessing_options)
  return dataset


def make_grain_train_iterator(
config: ml_collections.ConfigDict,
global_mesh,
Expand Down Expand Up @@ -385,6 +545,15 @@ def make_grain_train_iterator(
grain_worker_count=config.grain_worker_count,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
)
elif config.use_sft:
train_dataloader = sft_preprocessing_pipeline(
train_ds,
config,
data_columns=config.train_data_columns,
tokenize=config.tokenize_train_data,
grain_worker_count=config.grain_worker_count,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
)
else:
train_dataloader = pretrain_preprocessing_pipeline(
train_ds,
Expand Down Expand Up @@ -415,7 +584,16 @@ def make_grain_train_iterator(
)
if config.use_dpo:
preprocessing_fn = functools.partial(
pretrain_preprocessing_pipeline,
dpo_preprocessing_pipeline,
config=config,
data_columns=config.train_data_columns,
tokenize=config.tokenize_train_data,
grain_worker_count=config.grain_worker_count,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
)
elif config.use_sft:
preprocessing_fn = functools.partial(
sft_preprocessing_pipeline,
config=config,
data_columns=config.train_data_columns,
tokenize=config.tokenize_train_data,
Expand Down Expand Up @@ -482,6 +660,15 @@ def make_grain_eval_iterator(
grain_worker_count=config.grain_worker_count_eval,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)
elif config.use_sft:
eval_dataloader = sft_preprocessing_pipeline(
eval_ds,
config,
data_columns=config.eval_data_columns,
tokenize=config.tokenize_eval_data,
grain_worker_count=config.grain_worker_count_eval,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)
else:
eval_dataloader = pretrain_preprocessing_pipeline(
eval_ds,
Expand Down Expand Up @@ -516,6 +703,15 @@ def make_grain_eval_iterator(
grain_worker_count=config.grain_worker_count_eval,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)
elif config.use_sft:
preprocessing_fn = functools.partial(
sft_preprocessing_pipeline,
config=config,
data_columns=config.eval_data_columns,
tokenize=config.tokenize_eval_data,
grain_worker_count=config.grain_worker_count_eval,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)
else:
preprocessing_fn = functools.partial(
pretrain_preprocessing_pipeline,
Expand Down
66 changes: 66 additions & 0 deletions tests/unit/grain_data_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tempfile
import unittest
import json
import numpy as np

import jax
import pytest
Expand Down Expand Up @@ -525,6 +526,71 @@ def mount_gcsfuse():
raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}")


class GrainSFTParquetProcessingTest(unittest.TestCase):
  """Tests the SFT pipeline end-to-end using the real ultrachat_200k parquet dataset."""

  def setUp(self):
    super().setUp()

    # NOTE(review): reads live GCS paths — requires network/bucket access; verify CI has it.
    grain_train_file = "gs://maxtext-dataset/hf/ultrachat_200k/train_sft-*.parquet"
    base_output_directory = "gs://max-experiments/"
    config_file = get_test_config_path()

    self.config = pyconfig.initialize(
        [sys.argv[0], config_file],
        per_device_batch_size=1,
        run_name="test",
        mesh_axes=["data"],
        logical_axis_rules=[["batch", "data"]],
        data_sharding=["data"],
        base_output_directory=base_output_directory,
        dataset_type="grain",
        grain_file_type="parquet",
        grain_train_files=grain_train_file,
        use_sft=True,  # Routes make_grain_train_iterator through the SFT pipeline
        sft_train_on_completion_only=True,
        train_data_columns=["messages"],
        tokenizer_type="huggingface",
        tokenizer_path="HuggingFaceH4/zephyr-7b-beta",  # Ungated tokenizer, usable without an access token
        max_target_length=128,
        packing=True,
        grain_worker_count=1,
        grain_per_worker_buffer_size=1,
        enable_checkpointing=False,
    )
    # One-dimensional device mesh over all local devices; sharded along "data".
    self.mesh_shape_1d = (len(jax.devices()),)
    self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes)
    self.process_indices = input_pipeline_interface.get_process_loading_real_data(
        self.config.data_sharding,
        self.config.global_batch_size_to_load,
        self.config.global_batch_size_to_train_on,
        self.config.max_target_length,
        self.mesh,
    )
    self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)

  def test_train_ds(self):
    """One batch from the SFT iterator has packed tensors of the expected shape and masked prompts."""
    expected_shape = [jax.device_count(), self.config.max_target_length]
    batch = next(self.train_iter)

    # Assert all the required packing and target tensors were generated
    self.assertEqual(
        {k: list(v.shape) for k, v in batch.items()},
        {
            "inputs": expected_shape,
            "inputs_position": expected_shape,
            "inputs_segmentation": expected_shape,
            "targets": expected_shape,
            "targets_position": expected_shape,
            "targets_segmentation": expected_shape,
        },
    )

    # With sft_train_on_completion_only=True, prompt positions are masked in targets,
    # so at least one token must differ between inputs and targets.
    has_masked_tokens = np.any(batch["inputs"] != batch["targets"])
    self.assertTrue(bool(has_masked_tokens), "Targets array should differ from inputs array due to prompt masking.")


if __name__ == "__main__":
  # Mount the GCS fuse filesystem first so tests can read the bucket-hosted datasets;
  # mount_gcsfuse raises if the setup script exits non-zero.
  mount_gcsfuse()
  unittest.main()
Loading