diff --git a/src/maxtext/input_pipeline/grain_data_processing.py b/src/maxtext/input_pipeline/grain_data_processing.py index 154dac457a..3e835e3932 100644 --- a/src/maxtext/input_pipeline/grain_data_processing.py +++ b/src/maxtext/input_pipeline/grain_data_processing.py @@ -352,6 +352,166 @@ def dpo_preprocessing_pipeline( return dataset +def _format_chat_template_grain(element, data_columns, tokenizer_model): + """Grain-compatible mapping function to format raw columns into conversational messages.""" + # Convert raw columns to conversational messages + if "messages" in data_columns: + messages = element["messages"] + elif set(data_columns) == {"prompt", "completion"}: + messages = [{"role": "user", "content": element["prompt"]}, {"role": "assistant", "content": element["completion"]}] + elif set(data_columns) == {"question", "answer"}: + messages = [{"role": "user", "content": element["question"]}, {"role": "assistant", "content": element["answer"]}] + else: + # Fallback if it's already a single string + messages = element[data_columns[0]] + + # Assign the standardized messages back to the primary column + element[data_columns[0]] = messages + + return input_pipeline_utils.apply_chat_template( + element, tokenizer_model=tokenizer_model, data_column_name=data_columns[0] + ) + + +def _tokenize_sft_chunks(element, text_column_name, tokenizer_model): + """Tokenize each chunk individually without truncating.""" + text_chunks = element[text_column_name] + element[text_column_name] = [tokenizer_model.encode(chunk) for chunk in text_chunks] + return element + + +def sft_preprocessing_pipeline( + dataset, + config, + data_columns, + tokenize, + grain_worker_count, + grain_per_worker_buffer_size, +): + """Use grain pipeline to pre-process the dataset and return iterators for sft fine-tuning""" + if config.grain_file_type == "arrayrecord": + dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize)) + dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) + else: + dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns)) + + tokenizer_model = tokenizer.build_tokenizer( + config.tokenizer_path, + config.tokenizer_type, + config.add_bos, + config.add_eos, + config.hf_access_token, + ) + if tokenizer_model.pad_id is not None: + pad_id = tokenizer_model.pad_id + elif tokenizer_model.unk_id is not None: + pad_id = tokenizer_model.unk_id + else: + pad_id = -1 + + tokenizer_model = tokenizer_model.tokenizer + + if getattr(config, "chat_template", None) and hasattr(tokenizer_model, "chat_template"): + tokenizer_model.chat_template = config.chat_template + + supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]] + assert any( + set(data_columns) == set(supported) for supported in supported_columns + ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}" + + dataset = dataset.map( + functools.partial(_format_chat_template_grain, data_columns=data_columns, tokenizer_model=tokenizer_model) + ) + + if tokenize: + dataset = dataset.map( + functools.partial( + _tokenize_sft_chunks, + text_column_name=data_columns[0], + tokenizer_model=tokenizer_model, + ) + ) + + dataset = dataset.map( + input_pipeline_utils.SFTPromptMasking( + text_column_name=data_columns[0], + completion_only=config.sft_train_on_completion_only, + max_target_length=config.max_target_length, + unk_id=pad_id, + ) + ) + data_columns = ("inputs", "targets") + + # Pack and batch examples + batch_size = config.global_batch_size_to_load // jax.process_count() + if config.expansion_factor_real_data > 1: + # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1. + batch_size = int(batch_size // config.expansion_factor_real_data) + + if config.packing: + length_struct = {col: config.max_target_length for col in data_columns} + max_segments = config.max_segments_per_seq + if max_segments is not None and max_segments <= 0: + max_segments = None + if config.grain_packing_type == "first_fit": + dataset = grain.experimental.FirstFitPackIterDataset( + dataset, + length_struct=length_struct, + num_packing_bins=batch_size, + max_sequences_per_bin=max_segments, + ) + elif config.grain_packing_type == "best_fit": + dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size) + elif config.grain_packing_type == "concat_then_split": + if config.add_bos and hasattr(tokenizer_model, "bos_id"): + dataset = grain.experimental.ConcatThenSplitIterDataset( + dataset, + length_struct=length_struct, + bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS, + bos_token_id=tokenizer_model.bos_id, + ) + else: + dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct) + else: + raise ValueError(f"Unknown packing type: {config.packing}") + + rekey_dict = { + "targets_segmentation": "targets_segment_ids", + "inputs_segmentation": "inputs_segment_ids", + "targets_position": "targets_positions", + "inputs_position": "inputs_positions", + } + dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict)) + else: + dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id)) + + batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id) + dataset = dataset.batch(batch_size, batch_fn=batch_fn) + + # Shift inputs for teacher-forced training + dataset = dataset.map( + input_pipeline_utils.ShiftData( + ignored_ids=[pad_id], + axis=1, + ) + ) + multiprocessing_options = ( + pick_performance_config( + ds=dataset, + ram_budget_mb=config.grain_ram_budget_mb, + max_workers=None, + max_buffer_size=None, + ).multiprocessing_options + if grain_worker_count == -1 + else grain.MultiprocessingOptions( + num_workers=grain_worker_count, + per_worker_buffer_size=grain_per_worker_buffer_size, + ) + ) + dataset = dataset.mp_prefetch(multiprocessing_options) + return dataset + + def make_grain_train_iterator( config: ml_collections.ConfigDict, global_mesh, @@ -385,6 +545,15 @@ def make_grain_train_iterator( grain_worker_count=config.grain_worker_count, grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, ) + elif config.use_sft: + train_dataloader = sft_preprocessing_pipeline( + train_ds, + config, + data_columns=config.train_data_columns, + tokenize=config.tokenize_train_data, + grain_worker_count=config.grain_worker_count, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, + ) else: train_dataloader = pretrain_preprocessing_pipeline( train_ds, @@ -415,7 +584,16 @@ def make_grain_train_iterator( ) if config.use_dpo: preprocessing_fn = functools.partial( - pretrain_preprocessing_pipeline, + dpo_preprocessing_pipeline, + config=config, + data_columns=config.train_data_columns, + tokenize=config.tokenize_train_data, + grain_worker_count=config.grain_worker_count, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, + ) + elif config.use_sft: + preprocessing_fn = functools.partial( + sft_preprocessing_pipeline, config=config, data_columns=config.train_data_columns, tokenize=config.tokenize_train_data, @@ -482,6 +660,15 @@ def make_grain_eval_iterator( grain_worker_count=config.grain_worker_count_eval, grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, ) + elif config.use_sft: + eval_dataloader = sft_preprocessing_pipeline( + eval_ds, + config, + data_columns=config.eval_data_columns, + tokenize=config.tokenize_eval_data, + grain_worker_count=config.grain_worker_count_eval, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, + ) else: eval_dataloader = pretrain_preprocessing_pipeline( eval_ds, @@ -516,6 +703,15 @@ def make_grain_eval_iterator( grain_worker_count=config.grain_worker_count_eval, grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, ) + elif config.use_sft: + preprocessing_fn = functools.partial( + sft_preprocessing_pipeline, + config=config, + data_columns=config.eval_data_columns, + tokenize=config.tokenize_eval_data, + grain_worker_count=config.grain_worker_count_eval, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, + ) else: preprocessing_fn = functools.partial( pretrain_preprocessing_pipeline, diff --git a/tests/unit/grain_data_processing_test.py b/tests/unit/grain_data_processing_test.py index 90700b3790..35b1b3b5b5 100644 --- a/tests/unit/grain_data_processing_test.py +++ b/tests/unit/grain_data_processing_test.py @@ -21,6 +21,7 @@ import tempfile import unittest import json +import numpy as np import jax import pytest @@ -525,6 +526,71 @@ def mount_gcsfuse(): raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}") +class GrainSFTParquetProcessingTest(unittest.TestCase): + """Tests the SFT pipeline end-to-end using the real ultrachat_200k parquet dataset.""" + + def setUp(self): + super().setUp() + + grain_train_file = "gs://maxtext-dataset/hf/ultrachat_200k/train_sft-*.parquet" + base_output_directory = "gs://max-experiments/" + config_file = get_test_config_path() + + self.config = pyconfig.initialize( + [sys.argv[0], config_file], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory=base_output_directory, + dataset_type="grain", + grain_file_type="parquet", + grain_train_files=grain_train_file, + use_sft=True, # Triggers your new SFT pipeline + sft_train_on_completion_only=True, + train_data_columns=["messages"], + tokenizer_type="huggingface", + tokenizer_path="HuggingFaceH4/zephyr-7b-beta", # The ungated tokenizer + max_target_length=128, + packing=True, + grain_worker_count=1, + grain_per_worker_buffer_size=1, + enable_checkpointing=False, + ) + self.mesh_shape_1d = (len(jax.devices()),) + self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.process_indices = input_pipeline_interface.get_process_loading_real_data( + self.config.data_sharding, + self.config.global_batch_size_to_load, + self.config.global_batch_size_to_train_on, + self.config.max_target_length, + self.mesh, + ) + self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + + def test_train_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + batch = next(self.train_iter) + + # Assert all the required packing and target tensors were generated + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) + + # check to see that if prompts are masked, targets will differ from inputs + has_masked_tokens = np.any(batch["inputs"] != batch["targets"]) + self.assertTrue(bool(has_masked_tokens), "Targets array should differ from inputs array due to prompt masking.") + + if __name__ == "__main__": mount_gcsfuse() unittest.main()