From 30cfc2cf4be3d88adf90eeee1b2fbb383b5da89f Mon Sep 17 00:00:00 2001 From: ajkv-google Date: Wed, 18 Mar 2026 00:02:00 +0000 Subject: [PATCH 1/4] Added SFT support for grain input pipeline --- .../input_pipeline/grain_data_processing.py | 191 +++++++++++++++++- tests/unit/grain_data_processing_test.py | 75 +++++++ 2 files changed, 265 insertions(+), 1 deletion(-) diff --git a/src/maxtext/input_pipeline/grain_data_processing.py b/src/maxtext/input_pipeline/grain_data_processing.py index 154dac457a..e3ce398d0a 100644 --- a/src/maxtext/input_pipeline/grain_data_processing.py +++ b/src/maxtext/input_pipeline/grain_data_processing.py @@ -352,6 +352,159 @@ def dpo_preprocessing_pipeline( return dataset +def _format_chat_template_grain(element, data_columns, tokenizer_model): + """Grain-compatible mapping function to format raw columns into conversational messages.""" + # Convert raw columns to conversational messages + if "messages" in data_columns: + messages = element["messages"] + elif set(data_columns) == {"prompt", "completion"}: + messages = [{"role": "user", "content": element["prompt"]}, {"role": "assistant", "content": element["completion"]}] + elif set(data_columns) == {"question", "answer"}: + messages = [{"role": "user", "content": element["question"]}, {"role": "assistant", "content": element["answer"]}] + else: + # Fallback if it's already a single string + messages = element[data_columns[0]] + + # Assign the standardized messages back to the primary column + element[data_columns[0]] = messages + + return input_pipeline_utils.apply_chat_template( + element, tokenizer_model=tokenizer_model, data_column_name=data_columns[0] + ) + + +def sft_preprocessing_pipeline( + dataset, + config, + data_columns, + tokenize, + grain_worker_count, + grain_per_worker_buffer_size, +): + """Use grain pipeline to pre-process the dataset and return iterators for sft fine-tuning""" + if config.grain_file_type == "arrayrecord": + dataset = 
dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize)) + dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) + else: + dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns)) + + tokenizer_model = tokenizer.build_tokenizer( + config.tokenizer_path, + config.tokenizer_type, + config.add_bos, + config.add_eos, + config.hf_access_token, + ) + if tokenizer_model.pad_id is not None: + pad_id = tokenizer_model.pad_id + elif tokenizer_model.unk_id is not None: + pad_id = tokenizer_model.unk_id + else: + pad_id = -1 + + if getattr(config, "chat_template", None) and hasattr(tokenizer_model, "chat_template"): + tokenizer_model.chat_template = config.chat_template + + supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]] + assert any( + set(data_columns) == set(supported) for supported in supported_columns + ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}" + + dataset = dataset.map( + functools.partial(_format_chat_template_grain, data_columns=data_columns, tokenizer_model=tokenizer_model) + ) + + if tokenize: + text_column_name = data_columns[0] + + def tokenize_sft_chunks(element, col=text_column_name): + # Tokenize each chunk individually without truncating + text_chunks = element[col] + element[col] = [tokenizer_model.encode(chunk) for chunk in text_chunks] + return element + + dataset = dataset.map(tokenize_sft_chunks) + + dataset = dataset.map( + input_pipeline_utils.SFTPromptMasking( + text_column_name=data_columns[0], + completion_only=config.sft_train_on_completion_only, + max_target_length=config.max_target_length, + unk_id=pad_id, + ) + ) + data_columns = ("inputs", "targets") + + # Pack and batch examples + batch_size = config.global_batch_size_to_load // jax.process_count() + if config.expansion_factor_real_data > 1: + # global_batch_size_to_load has been expanded in pyconfig.py when 
expansion_factor_real_data > 1.
+    batch_size = int(batch_size // config.expansion_factor_real_data)
+
+  if config.packing:
+    length_struct = {col: config.max_target_length for col in data_columns}
+    max_segments = config.max_segments_per_seq
+    if max_segments is not None and max_segments <= 0:
+      max_segments = None
+    if config.grain_packing_type == "first_fit":
+      dataset = grain.experimental.FirstFitPackIterDataset(
+          dataset,
+          length_struct=length_struct,
+          num_packing_bins=batch_size,
+          max_sequences_per_bin=max_segments,
+      )
+    elif config.grain_packing_type == "best_fit":
+      dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size)
+    elif config.grain_packing_type == "concat_then_split":
+      if config.add_bos and hasattr(tokenizer_model, "bos_id"):
+        dataset = grain.experimental.ConcatThenSplitIterDataset(
+            dataset,
+            length_struct=length_struct,
+            bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS,
+            bos_token_id=tokenizer_model.bos_id,
+        )
+      else:
+        dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct)
+    else:
+      raise ValueError(f"Unknown packing type: {config.grain_packing_type}")
+
+    rekey_dict = {
+        "targets_segmentation": "targets_segment_ids",
+        "inputs_segmentation": "inputs_segment_ids",
+        "targets_position": "targets_positions",
+        "inputs_position": "inputs_positions",
+    }
+    dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict))
+  else:
+    dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
+
+  batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
+  dataset = dataset.batch(batch_size, batch_fn=batch_fn)
+
+  # Shift inputs for teacher-forced training
+  dataset = dataset.map(
+      input_pipeline_utils.ShiftData(
+          ignored_ids=[pad_id],
+          axis=1,
+      )
+  )
+  multiprocessing_options = (
+      pick_performance_config(
+          ds=dataset,
+          ram_budget_mb=config.grain_ram_budget_mb,
+ max_workers=None, + max_buffer_size=None, + ).multiprocessing_options + if grain_worker_count == -1 + else grain.MultiprocessingOptions( + num_workers=grain_worker_count, + per_worker_buffer_size=grain_per_worker_buffer_size, + ) + ) + dataset = dataset.mp_prefetch(multiprocessing_options) + return dataset + + def make_grain_train_iterator( config: ml_collections.ConfigDict, global_mesh, @@ -385,6 +538,15 @@ def make_grain_train_iterator( grain_worker_count=config.grain_worker_count, grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, ) + elif config.use_sft: + train_dataloader = sft_preprocessing_pipeline( + train_ds, + config, + data_columns=config.train_data_columns, + tokenize=config.tokenize_train_data, + grain_worker_count=config.grain_worker_count, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, + ) else: train_dataloader = pretrain_preprocessing_pipeline( train_ds, @@ -415,7 +577,16 @@ def make_grain_train_iterator( ) if config.use_dpo: preprocessing_fn = functools.partial( - pretrain_preprocessing_pipeline, + dpo_preprocessing_pipeline, + config=config, + data_columns=config.train_data_columns, + tokenize=config.tokenize_train_data, + grain_worker_count=config.grain_worker_count, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, + ) + elif config.use_sft: + preprocessing_fn = functools.partial( + sft_preprocessing_pipeline, config=config, data_columns=config.train_data_columns, tokenize=config.tokenize_train_data, @@ -482,6 +653,15 @@ def make_grain_eval_iterator( grain_worker_count=config.grain_worker_count_eval, grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, ) + elif config.use_sft: + eval_dataloader = sft_preprocessing_pipeline( + eval_ds, + config, + data_columns=config.eval_data_columns, + tokenize=config.tokenize_eval_data, + grain_worker_count=config.grain_worker_count_eval, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, + ) else: eval_dataloader 
= pretrain_preprocessing_pipeline( eval_ds, @@ -516,6 +696,15 @@ def make_grain_eval_iterator( grain_worker_count=config.grain_worker_count_eval, grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, ) + elif config.use_sft: + preprocessing_fn = functools.partial( + sft_preprocessing_pipeline, + config=config, + data_columns=config.eval_data_columns, + tokenize=config.tokenize_eval_data, + grain_worker_count=config.grain_worker_count_eval, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, + ) else: preprocessing_fn = functools.partial( pretrain_preprocessing_pipeline, diff --git a/tests/unit/grain_data_processing_test.py b/tests/unit/grain_data_processing_test.py index 90700b3790..b906431862 100644 --- a/tests/unit/grain_data_processing_test.py +++ b/tests/unit/grain_data_processing_test.py @@ -21,11 +21,15 @@ import tempfile import unittest import json +import ml_collections +import numpy as np import jax import pytest from jax.sharding import Mesh from jax.experimental import mesh_utils +from unittest import mock +import grain.python as grain from maxtext.configs import pyconfig from maxtext.input_pipeline import grain_data_processing @@ -524,7 +528,78 @@ def mount_gcsfuse(): if exit_code != os.EX_OK: raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}") +class GrainSFTPipelineTest(unittest.TestCase): + """Tests the full SFT preprocessing pipeline end-to-end using dummy data.""" + def setUp(self): + super().setUp() + # Create a minimal config to satisfy the pipeline's requirements + self.config = ml_collections.ConfigDict({ + "grain_file_type": "in_memory", # Skips arrayrecord parsing + "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), + "tokenizer_type": "sentencepiece", + "add_bos": True, + "add_eos": True, + "hf_access_token": "", + "use_truncation": False, + "max_target_length": 16, + "sft_train_on_completion_only": True, + "packing": False, + 
"global_batch_size_to_load": 2, # Using 2 examples + "expansion_factor_real_data": 1.0, + "grain_ram_budget_mb": 512, + # A very basic chat template for testing purposes + "chat_template": "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + ' ' }}{% endfor %}", + }) + + @mock.patch('maxtext.input_pipeline.input_pipeline_utils.apply_chat_template') + def test_sft_preprocessing_pipeline(self, mock_apply_chat_template): + # Fake the exact chunked structure and is_prompt array that MaxText expects + def fake_apply_chat(element, tokenizer_model, data_column_name): + # Return hardcoded strings instead of referencing the overwritten dictionary + element[data_column_name] = ["What is 2+2? ", "It is 4."] + element['is_prompt'] = [True, False] + return element + + mock_apply_chat_template.side_effect = fake_apply_chat + + # Create a dummy in-memory dataset + dummy_data = [ + {"prompt": "What is 2+2?", "completion": "It is 4."}, + {"prompt": "Say hello.", "completion": "Hello!"} + ] + dataset = grain.MapDataset.source(dummy_data) + dataset = dataset.to_iter_dataset() + data_columns = ["prompt", "completion"] + + # Run pipeline + pipeline_iterator = grain_data_processing.sft_preprocessing_pipeline( + dataset=dataset, + config=self.config, + data_columns=data_columns, + tokenize=True, + grain_worker_count=0, + grain_per_worker_buffer_size=1, + ) + + # Get the first batch + iterator = iter(pipeline_iterator) + batch = next(iterator) + + # Assert the pipeline output is correct + self.assertIn("inputs", batch) + self.assertIn("targets", batch) + + expected_shape = (2, self.config.max_target_length) + + # Check shapes + self.assertEqual(batch["inputs"].shape, expected_shape) + self.assertEqual(batch["targets"].shape, expected_shape) + + # Check for masked tokens + has_masked_tokens = np.any((batch["targets"] == 0) | (batch["targets"] == -1)) + self.assertTrue(has_masked_tokens, "Targets array should contain masked (ignore) IDs for the prompt sections.") + 
if __name__ == "__main__": mount_gcsfuse() unittest.main() From 153953afbb6310f0f6cd4eef3d418148d3d87321 Mon Sep 17 00:00:00 2001 From: ajkv-google Date: Wed, 18 Mar 2026 00:28:41 +0000 Subject: [PATCH 2/4] Updated code formatting and spacing --- tests/unit/grain_data_processing_test.py | 65 +++++++++++++----------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/tests/unit/grain_data_processing_test.py b/tests/unit/grain_data_processing_test.py index b906431862..ba1a70ff49 100644 --- a/tests/unit/grain_data_processing_test.py +++ b/tests/unit/grain_data_processing_test.py @@ -528,46 +528,48 @@ def mount_gcsfuse(): if exit_code != os.EX_OK: raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}") + class GrainSFTPipelineTest(unittest.TestCase): """Tests the full SFT preprocessing pipeline end-to-end using dummy data.""" def setUp(self): super().setUp() # Create a minimal config to satisfy the pipeline's requirements - self.config = ml_collections.ConfigDict({ - "grain_file_type": "in_memory", # Skips arrayrecord parsing - "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), - "tokenizer_type": "sentencepiece", - "add_bos": True, - "add_eos": True, - "hf_access_token": "", - "use_truncation": False, - "max_target_length": 16, - "sft_train_on_completion_only": True, - "packing": False, - "global_batch_size_to_load": 2, # Using 2 examples - "expansion_factor_real_data": 1.0, - "grain_ram_budget_mb": 512, - # A very basic chat template for testing purposes - "chat_template": "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + ' ' }}{% endfor %}", - }) - - @mock.patch('maxtext.input_pipeline.input_pipeline_utils.apply_chat_template') + self.config = ml_collections.ConfigDict( + { + "grain_file_type": "in_memory", # Skips arrayrecord parsing + "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), + "tokenizer_type": "sentencepiece", + 
"add_bos": True, + "add_eos": True, + "hf_access_token": "", + "use_truncation": False, + "max_target_length": 16, + "sft_train_on_completion_only": True, + "packing": False, + "global_batch_size_to_load": 2, # Using 2 examples + "expansion_factor_real_data": 1.0, + "grain_ram_budget_mb": 512, + # A very basic chat template for testing purposes + "chat_template": ( + "{% for message in messages %}" "{{ message['role'] + ': ' + message['content'] + ' ' }}" "{% endfor %}" + ), + } + ) + + @mock.patch("maxtext.input_pipeline.input_pipeline_utils.apply_chat_template") def test_sft_preprocessing_pipeline(self, mock_apply_chat_template): # Fake the exact chunked structure and is_prompt array that MaxText expects def fake_apply_chat(element, tokenizer_model, data_column_name): - # Return hardcoded strings instead of referencing the overwritten dictionary - element[data_column_name] = ["What is 2+2? ", "It is 4."] - element['is_prompt'] = [True, False] - return element - + # Return hardcoded strings instead of referencing the overwritten dictionary + element[data_column_name] = ["What is 2+2? 
", "It is 4."] + element["is_prompt"] = [True, False] + return element + mock_apply_chat_template.side_effect = fake_apply_chat # Create a dummy in-memory dataset - dummy_data = [ - {"prompt": "What is 2+2?", "completion": "It is 4."}, - {"prompt": "Say hello.", "completion": "Hello!"} - ] + dummy_data = [{"prompt": "What is 2+2?", "completion": "It is 4."}, {"prompt": "Say hello.", "completion": "Hello!"}] dataset = grain.MapDataset.source(dummy_data) dataset = dataset.to_iter_dataset() data_columns = ["prompt", "completion"] @@ -589,9 +591,9 @@ def fake_apply_chat(element, tokenizer_model, data_column_name): # Assert the pipeline output is correct self.assertIn("inputs", batch) self.assertIn("targets", batch) - + expected_shape = (2, self.config.max_target_length) - + # Check shapes self.assertEqual(batch["inputs"].shape, expected_shape) self.assertEqual(batch["targets"].shape, expected_shape) @@ -599,7 +601,8 @@ def fake_apply_chat(element, tokenizer_model, data_column_name): # Check for masked tokens has_masked_tokens = np.any((batch["targets"] == 0) | (batch["targets"] == -1)) self.assertTrue(has_masked_tokens, "Targets array should contain masked (ignore) IDs for the prompt sections.") - + + if __name__ == "__main__": mount_gcsfuse() unittest.main() From a5949cf46c4d240772ab6379fbedd4506f5bd669 Mon Sep 17 00:00:00 2001 From: ajkv-google Date: Wed, 18 Mar 2026 21:54:09 +0000 Subject: [PATCH 3/4] Updated grain sft unit test and sft implementation to use tokenizer --- .../input_pipeline/grain_data_processing.py | 2 + tests/unit/grain_data_processing_test.py | 120 ++++++++---------- 2 files changed, 56 insertions(+), 66 deletions(-) diff --git a/src/maxtext/input_pipeline/grain_data_processing.py b/src/maxtext/input_pipeline/grain_data_processing.py index e3ce398d0a..9664170aee 100644 --- a/src/maxtext/input_pipeline/grain_data_processing.py +++ b/src/maxtext/input_pipeline/grain_data_processing.py @@ -402,6 +402,8 @@ def sft_preprocessing_pipeline( else: pad_id 
= -1 + tokenizer_model = tokenizer_model.tokenizer + if getattr(config, "chat_template", None) and hasattr(tokenizer_model, "chat_template"): tokenizer_model.chat_template = config.chat_template diff --git a/tests/unit/grain_data_processing_test.py b/tests/unit/grain_data_processing_test.py index ba1a70ff49..35b1b3b5b5 100644 --- a/tests/unit/grain_data_processing_test.py +++ b/tests/unit/grain_data_processing_test.py @@ -21,15 +21,12 @@ import tempfile import unittest import json -import ml_collections import numpy as np import jax import pytest from jax.sharding import Mesh from jax.experimental import mesh_utils -from unittest import mock -import grain.python as grain from maxtext.configs import pyconfig from maxtext.input_pipeline import grain_data_processing @@ -529,78 +526,69 @@ def mount_gcsfuse(): raise ValueError(f"Running setup_gcsfuse.sh failed with exit code: {exit_code}") -class GrainSFTPipelineTest(unittest.TestCase): - """Tests the full SFT preprocessing pipeline end-to-end using dummy data.""" +class GrainSFTParquetProcessingTest(unittest.TestCase): + """Tests the SFT pipeline end-to-end using the real ultrachat_200k parquet dataset.""" def setUp(self): super().setUp() - # Create a minimal config to satisfy the pipeline's requirements - self.config = ml_collections.ConfigDict( - { - "grain_file_type": "in_memory", # Skips arrayrecord parsing - "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), - "tokenizer_type": "sentencepiece", - "add_bos": True, - "add_eos": True, - "hf_access_token": "", - "use_truncation": False, - "max_target_length": 16, - "sft_train_on_completion_only": True, - "packing": False, - "global_batch_size_to_load": 2, # Using 2 examples - "expansion_factor_real_data": 1.0, - "grain_ram_budget_mb": 512, - # A very basic chat template for testing purposes - "chat_template": ( - "{% for message in messages %}" "{{ message['role'] + ': ' + message['content'] + ' ' }}" "{% endfor %}" - ), - } - ) - 
@mock.patch("maxtext.input_pipeline.input_pipeline_utils.apply_chat_template") - def test_sft_preprocessing_pipeline(self, mock_apply_chat_template): - # Fake the exact chunked structure and is_prompt array that MaxText expects - def fake_apply_chat(element, tokenizer_model, data_column_name): - # Return hardcoded strings instead of referencing the overwritten dictionary - element[data_column_name] = ["What is 2+2? ", "It is 4."] - element["is_prompt"] = [True, False] - return element - - mock_apply_chat_template.side_effect = fake_apply_chat - - # Create a dummy in-memory dataset - dummy_data = [{"prompt": "What is 2+2?", "completion": "It is 4."}, {"prompt": "Say hello.", "completion": "Hello!"}] - dataset = grain.MapDataset.source(dummy_data) - dataset = dataset.to_iter_dataset() - data_columns = ["prompt", "completion"] - - # Run pipeline - pipeline_iterator = grain_data_processing.sft_preprocessing_pipeline( - dataset=dataset, - config=self.config, - data_columns=data_columns, - tokenize=True, - grain_worker_count=0, + grain_train_file = "gs://maxtext-dataset/hf/ultrachat_200k/train_sft-*.parquet" + base_output_directory = "gs://max-experiments/" + config_file = get_test_config_path() + + self.config = pyconfig.initialize( + [sys.argv[0], config_file], + per_device_batch_size=1, + run_name="test", + mesh_axes=["data"], + logical_axis_rules=[["batch", "data"]], + data_sharding=["data"], + base_output_directory=base_output_directory, + dataset_type="grain", + grain_file_type="parquet", + grain_train_files=grain_train_file, + use_sft=True, # Triggers your new SFT pipeline + sft_train_on_completion_only=True, + train_data_columns=["messages"], + tokenizer_type="huggingface", + tokenizer_path="HuggingFaceH4/zephyr-7b-beta", # The ungated tokenizer + max_target_length=128, + packing=True, + grain_worker_count=1, grain_per_worker_buffer_size=1, + enable_checkpointing=False, ) + self.mesh_shape_1d = (len(jax.devices()),) + self.mesh = 
Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) + self.process_indices = input_pipeline_interface.get_process_loading_real_data( + self.config.data_sharding, + self.config.global_batch_size_to_load, + self.config.global_batch_size_to_train_on, + self.config.max_target_length, + self.mesh, + ) + self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) - # Get the first batch - iterator = iter(pipeline_iterator) - batch = next(iterator) - - # Assert the pipeline output is correct - self.assertIn("inputs", batch) - self.assertIn("targets", batch) - - expected_shape = (2, self.config.max_target_length) + def test_train_ds(self): + expected_shape = [jax.device_count(), self.config.max_target_length] + batch = next(self.train_iter) - # Check shapes - self.assertEqual(batch["inputs"].shape, expected_shape) - self.assertEqual(batch["targets"].shape, expected_shape) + # Assert all the required packing and target tensors were generated + self.assertEqual( + {k: list(v.shape) for k, v in batch.items()}, + { + "inputs": expected_shape, + "inputs_position": expected_shape, + "inputs_segmentation": expected_shape, + "targets": expected_shape, + "targets_position": expected_shape, + "targets_segmentation": expected_shape, + }, + ) - # Check for masked tokens - has_masked_tokens = np.any((batch["targets"] == 0) | (batch["targets"] == -1)) - self.assertTrue(has_masked_tokens, "Targets array should contain masked (ignore) IDs for the prompt sections.") + # check to see that if prompts are masked, targets will differ from inputs + has_masked_tokens = np.any(batch["inputs"] != batch["targets"]) + self.assertTrue(bool(has_masked_tokens), "Targets array should differ from inputs array due to prompt masking.") if __name__ == "__main__": From 15fcafa75dbb18d4a42e0863540bce0837a29ff9 Mon Sep 17 00:00:00 2001 From: ajkv-google Date: Wed, 18 Mar 2026 23:30:41 +0000 Subject: [PATCH 4/4] Cleaned up code for 
readability --- .../input_pipeline/grain_data_processing.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/maxtext/input_pipeline/grain_data_processing.py b/src/maxtext/input_pipeline/grain_data_processing.py index 9664170aee..3e835e3932 100644 --- a/src/maxtext/input_pipeline/grain_data_processing.py +++ b/src/maxtext/input_pipeline/grain_data_processing.py @@ -373,6 +373,13 @@ def _format_chat_template_grain(element, data_columns, tokenizer_model): ) +def _tokenize_sft_chunks(element, text_column_name, tokenizer_model): + """Tokenize each chunk individually without truncating.""" + text_chunks = element[text_column_name] + element[text_column_name] = [tokenizer_model.encode(chunk) for chunk in text_chunks] + return element + + def sft_preprocessing_pipeline( dataset, config, @@ -417,15 +424,13 @@ def sft_preprocessing_pipeline( ) if tokenize: - text_column_name = data_columns[0] - - def tokenize_sft_chunks(element, col=text_column_name): - # Tokenize each chunk individually without truncating - text_chunks = element[col] - element[col] = [tokenizer_model.encode(chunk) for chunk in text_chunks] - return element - - dataset = dataset.map(tokenize_sft_chunks) + dataset = dataset.map( + functools.partial( + _tokenize_sft_chunks, + text_column_name=data_columns[0], + tokenizer_model=tokenizer_model, + ) + ) dataset = dataset.map( input_pipeline_utils.SFTPromptMasking(