Added SFT Pre-Processing for Grain Input Pipeline #3437
ajkv-google wants to merge 4 commits into main from
Conversation
Codecov Report ❌ Patch coverage is
vlad-karp
left a comment
It would also be great to test not only with the MaxText general SFT pipeline but with the distillation SFT pipeline as well.
    grain_per_worker_buffer_size,
):
  """Use grain pipeline to pre-process the dataset and return iterators for sft fine-tuning"""
  if config.grain_file_type == "arrayrecord":
There is exactly this block in pretrain_preprocessing_pipeline(); since the code there is almost identical, it would be great to reuse it.
Yes, there are many common operations between pretrain and SFT. I think it's a good idea to extract the common pattern into util functions.
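One way to factor the shared file-type handling out of both pipelines is a small dispatch helper that both pretrain and SFT preprocessing call. This is a hypothetical sketch, not MaxText's actual code: the function name, registry, and loader callables are assumptions for illustration.

```python
def get_source_loader(grain_file_type, loaders):
  """Look up the source loader for a Grain file type from a shared registry."""
  try:
    return loaders[grain_file_type]
  except KeyError:
    raise ValueError(f"Unsupported grain_file_type: {grain_file_type!r}")


# Example registry; in MaxText these would be the real ArrayRecord/Parquet loaders.
LOADERS = {
    "arrayrecord": lambda paths: ("arrayrecord", paths),
    "parquet": lambda paths: ("parquet", paths),
}
```

Both pipelines could then call `get_source_loader(config.grain_file_type, LOADERS)` instead of duplicating the `if config.grain_file_type == ...` branches.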
  else:
    dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))
  tokenizer_model = tokenizer.build_tokenizer(
  if getattr(config, "chat_template", None) and hasattr(tokenizer_model, "chat_template"):
    tokenizer_model.chat_template = config.chat_template
  supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]]
This again repeats logic from the HF pipeline; one could extract utility routines to avoid the code repetition.
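A shared column-validation utility could capture the supported-columns check that both the HF and Grain SFT pipelines repeat. This is a minimal sketch with assumed names, not MaxText's existing API:

```python
# Column sets the SFT pipelines accept, as shown in the diff above.
SUPPORTED_COLUMNS = [["prompt", "completion"], ["messages"], ["question", "answer"]]


def validate_sft_columns(data_columns, supported=SUPPORTED_COLUMNS):
  """Return True if data_columns matches one of the supported column sets."""
  return any(set(data_columns) == set(cols) for cols in supported)
```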
    messages = [{"role": "user", "content": element["prompt"]}, {"role": "assistant", "content": element["completion"]}]
  elif set(data_columns) == {"question", "answer"}:
    messages = [{"role": "user", "content": element["question"]}, {"role": "assistant", "content": element["answer"]}]
  else:
The HF pipeline asserts that SFT is running on a conversational format.
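The conversion the diff performs can be summarized as a pure function mapping a row to chat-style messages. The function name here is hypothetical; the field names and message structure follow the diff:

```python
def to_messages(element, data_columns):
  """Map a prompt/completion or question/answer row to chat-style messages."""
  cols = set(data_columns)
  if cols == {"prompt", "completion"}:
    return [{"role": "user", "content": element["prompt"]},
            {"role": "assistant", "content": element["completion"]}]
  if cols == {"question", "answer"}:
    return [{"role": "user", "content": element["question"]},
            {"role": "assistant", "content": element["answer"]}]
  if cols == {"messages"}:
    # Already in conversational format.
    return element["messages"]
  raise ValueError(f"Unsupported columns: {data_columns}")
```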
| ) | ||
| ) | ||
|
|
||
  dataset = dataset.map(
Any pros/cons of doing it via Grain operations like in the HF pipeline?
dataset.map is the newer and recommended way of using Grain
  # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1.
  batch_size = int(batch_size // config.expansion_factor_real_data)
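The arithmetic in that comment can be illustrated with made-up numbers: pyconfig expands the batch size to load by `expansion_factor_real_data`, and the pipeline divides it back out before batching.

```python
# Illustrative values only; the real numbers come from the MaxText config.
expansion_factor_real_data = 2
global_batch_size_to_load = 64  # already expanded in pyconfig.py
batch_size = int(global_batch_size_to_load // expansion_factor_real_data)
```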
  if config.packing:
This block is again identical to the one in pretrain_preprocessing_pipeline(). Also, it would be great to check whether it needs some SFT-related modifications.
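As a rough illustration of what the packing block does conceptually (not MaxText's implementation, which handles segment IDs and positions), here is a greedy first-fit packer that concatenates token sequences up to a maximum length:

```python
def pack_sequences(sequences, max_len):
  """Greedily pack token sequences into bins of at most max_len tokens."""
  bins = []
  for seq in sequences:
    for b in bins:
      if len(b) + len(seq) <= max_len:
        b.extend(seq)  # fits in an existing bin
        break
    else:
      bins.append(list(seq))  # start a new bin
  return bins
```

For SFT specifically, packing also has to keep prompt/completion masking consistent across packed segments, which is one thing to check against the pretrain version.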
    dataset = dataset.batch(batch_size, batch_fn=batch_fn)
  # Shift inputs for teacher-forced training
  dataset = dataset.map(
Should it always be executed in a generic sft_preprocessing_pipeline()?
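The shift the comment describes is the standard teacher-forcing transform: targets are the inputs shifted left by one token. This is a minimal stand-in with an assumed function name and padding choice, not MaxText's exact implementation:

```python
def shift_for_teacher_forcing(tokens, pad_id=0):
  """Produce inputs/targets where targets[i] = tokens[i + 1] (padded at the end)."""
  return {"inputs": list(tokens), "targets": list(tokens[1:]) + [pad_id]}
```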
| ) | ||
| ) | ||
|
|
||
  dataset = dataset.map(
dataset.map is the newer and recommended way of using Grain
  ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}"

  dataset = dataset.map(
      functools.partial(_format_chat_template_grain, data_columns=data_columns, tokenizer_model=tokenizer_model)
The hf pipeline calls instruction_data_processing.convert_to_conversational_format, do we support the same conversion here? https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/input_pipeline/hf_data_processing.py#L261
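Per the linked source, the HF pipeline's convert_to_conversational_format rewrites prompt/completion-style rows into a single "messages" column before templating. A minimal sketch of that idea (field names assumed, not the actual MaxText function body):

```python
def convert_to_conversational_format(element):
  """Rewrite a prompt/completion row into a single 'messages' column."""
  if "messages" in element:
    return element  # already conversational
  return {"messages": [
      {"role": "user", "content": element["prompt"]},
      {"role": "assistant", "content": element["completion"]},
  ]}
```

Supporting the same conversion in the Grain path would let both pipelines feed identical conversational rows into the chat-template step.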
Description
This PR introduces SFT support to the Grain input pipeline by adding a separate
sft_preprocessing_pipelinefunction. Rather than cluttering the existing pretrain code, it uses simple conditionals inside the train and eval iterators to route to this new SFT logic. I followed the existing Hugging Face SFT implementation and adapted its logic to be compatible with Grain's element-wise datasets.Tests
I added a unit test to verify end-to-end functionality to make sure the Grain SFT pipeline formats the data and outputs correctly. Ran this command to execute the unit test:
pytest tests/unit/grain_data_processing_test.py::GrainSFTParquetProcessingTest -v
This is the output of the test: Test Passed Output
Also, ran the training pipeline in Maxtext with sft enabled using a grain dataset with this command:
python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=test_grain_sft dataset_type=grain grain_file_type=parquet grain_train_files=gs://maxtext-dataset/hf/ultrachat_200k/train_sft-*.parquet steps=10 tokenizer_type=huggingface tokenizer_path=HuggingFaceH4/zephyr-7b-beta
Verified that the SFT processing changes worked and trained successfully: Logs
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.