Added SFT Pre-Processing for Grain Input Pipeline #3437

Open

ajkv-google wants to merge 4 commits into main from ajkv/sft-grain-implementation
Conversation

ajkv-google (Collaborator) commented Mar 18, 2026

Description

This PR introduces SFT support to the Grain input pipeline by adding a separate sft_preprocessing_pipeline function. Rather than cluttering the existing pretrain code, the train and eval iterators use simple conditionals to route to the new SFT logic. I followed the existing Hugging Face SFT implementation and adapted its logic to be compatible with Grain's element-wise datasets.
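A minimal sketch of the routing described above, assuming a config flag selects the SFT path. All names here (Config, make_train_iterator, the two pipeline stubs) are hypothetical stand-ins for illustration, not the PR's actual code:

```python
class Config:
  """Hypothetical config stand-in; the real code reads pyconfig values."""

  def __init__(self, use_sft=False):
    self.use_sft = use_sft


def pretrain_preprocessing_pipeline(config, dataset):
  # Stub: the real function builds a Grain pretrain pipeline.
  return ("pretrain", dataset)


def sft_preprocessing_pipeline(config, dataset):
  # Stub: the real function applies SFT formatting, tokenization, etc.
  return ("sft", dataset)


def make_train_iterator(config, dataset):
  # A simple conditional routes SFT runs to the new pipeline while
  # leaving the existing pretrain path untouched.
  if config.use_sft:
    return sft_preprocessing_pipeline(config, dataset)
  return pretrain_preprocessing_pipeline(config, dataset)
```

The same conditional would appear in the eval-iterator construction.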

Tests

I added a unit test that verifies, end to end, that the Grain SFT pipeline formats the data and produces correct outputs. Ran this command to execute it:

  • pytest tests/unit/grain_data_processing_test.py::GrainSFTParquetProcessingTest -v

This is the output of the test: Test Passed Output

I also ran the MaxText training pipeline with SFT enabled on a Grain dataset, using this command:

  • python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=test_grain_sft dataset_type=grain grain_file_type=parquet grain_train_files=gs://maxtext-dataset/hf/ultrachat_200k/train_sft-*.parquet steps=10 tokenizer_type=huggingface tokenizer_path=HuggingFaceH4/zephyr-7b-beta

Verified that the SFT processing changes worked and training completed successfully: Logs

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Mar 18, 2026

Codecov Report

❌ Patch coverage is 40.57971%, with 41 lines in your changes missing coverage. Please review.

Files with missing lines                                Patch %   Lines
...rc/maxtext/input_pipeline/grain_data_processing.py   40.57%    33 missing, 8 partials ⚠️

vlad-karp (Collaborator) left a comment:

It would also be great to test not only with the general MaxText SFT but with the distillation SFT pipeline as well.

grain_per_worker_buffer_size,
):
"""Use grain pipeline to pre-process the dataset and return iterators for sft fine-tuning"""
if config.grain_file_type == "arrayrecord":

There is exactly this block in pretrain_preprocessing_pipeline(); the nearly identical code suggests it would be great to reuse it.


Yes, there are many common operations between pretrain and sft. I think it's a good idea to extract the common pattern into util functions
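A hedged sketch of the refactor suggested in this thread: hoist the shared file-type dispatch out of both pipelines into one helper. The helper name and returned labels below are hypothetical stand-ins; the real code would construct Grain data sources at this point.

```python
def make_source_dataset(grain_file_type, files):
  """Hypothetical shared helper for both pretrain and SFT pipelines.

  Dispatches on the configured file type; the tuples returned here
  stand in for the Grain source objects the real code would build.
  """
  if grain_file_type == "arrayrecord":
    return ("arrayrecord_source", files)  # stand-in for an ArrayRecord source
  if grain_file_type == "parquet":
    return ("parquet_source", files)      # stand-in for a Parquet source
  raise ValueError(f"Unsupported grain_file_type: {grain_file_type}")
```

Both pretrain_preprocessing_pipeline() and sft_preprocessing_pipeline() could then call this one helper instead of duplicating the conditional.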

else:
dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))

tokenizer_model = tokenizer.build_tokenizer(

same for the tokenizer

if getattr(config, "chat_template", None) and hasattr(tokenizer_model, "chat_template"):
tokenizer_model.chat_template = config.chat_template

supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]]

This again repeats logic from the hf pipeline; one could extract utility routines to avoid code repetition.

messages = [{"role": "user", "content": element["prompt"]}, {"role": "assistant", "content": element["completion"]}]
elif set(data_columns) == {"question", "answer"}:
messages = [{"role": "user", "content": element["question"]}, {"role": "assistant", "content": element["answer"]}]
else:

The HF pipeline asserts that SFT is running on a conversational format.
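A shared conversion utility along the lines suggested in these threads might look like the following. The function name is hypothetical; the column sets and message structure mirror the snippet above, and the final branch raises instead of silently passing data through, matching the HF pipeline's assertion on conversational format:

```python
def to_messages(element, data_columns):
  """Convert a supported column layout into the conversational
  'messages' format (list of {'role', 'content'} dicts)."""
  cols = set(data_columns)
  if cols == {"messages"}:
    # Already conversational; pass through unchanged.
    return element["messages"]
  if cols == {"prompt", "completion"}:
    return [
        {"role": "user", "content": element["prompt"]},
        {"role": "assistant", "content": element["completion"]},
    ]
  if cols == {"question", "answer"}:
    return [
        {"role": "user", "content": element["question"]},
        {"role": "assistant", "content": element["answer"]},
    ]
  raise ValueError(f"Unsupported dataset columns for SFT: {data_columns}")
```

Extracting this into a shared module would let both the hf and Grain pipelines use the same conversion and the same error behavior.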

)
)

dataset = dataset.map(

any pros/cons of doing it via Grain operations like in the hf pipeline?


dataset.map is the newer and recommended way of using Grain

# global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1.
batch_size = int(batch_size // config.expansion_factor_real_data)

if config.packing:

This block is again identical to one in pretrain_preprocessing_pipeline().
Also, it would be great to check whether it needs some SFT-related modifications.

dataset = dataset.batch(batch_size, batch_fn=batch_fn)

# Shift inputs for teacher-forced training
dataset = dataset.map(

Should it always be executed in a generic sft_preprocessing_pipeline()?
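For context, the shift under discussion is the standard teacher-forcing setup: targets are the inputs shifted by one position, so the model predicts token t+1 from the tokens up to t. A minimal pure-Python sketch (the real code operates on batched, padded arrays):

```python
def shift_for_teacher_forcing(tokens):
  """Produce (inputs, targets) for next-token prediction.

  inputs drop the final token; targets drop the first, so that
  targets[i] is the token the model should predict after inputs[i].
  """
  inputs = tokens[:-1]
  targets = tokens[1:]
  return inputs, targets
```

Whether this belongs unconditionally in a generic sft_preprocessing_pipeline(), or only behind a flag, is the open question in this thread.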


), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}"

dataset = dataset.map(
functools.partial(_format_chat_template_grain, data_columns=data_columns, tokenizer_model=tokenizer_model)

The hf pipeline calls instruction_data_processing.convert_to_conversational_format; do we support the same conversion here? https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/input_pipeline/hf_data_processing.py#L261
