-
Notifications
You must be signed in to change notification settings - Fork 491
Added SFT Pre-Processing for Grain Input Pipeline #3437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
30cfc2c
153953a
a5949cf
15fcafa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -352,6 +352,166 @@ def dpo_preprocessing_pipeline( | |
| return dataset | ||
|
|
||
|
|
||
def _format_chat_template_grain(element, data_columns, tokenizer_model):
  """Grain-compatible map function: normalize raw dataset columns into conversational messages.

  The supported schemas are a ready-made "messages" column, a
  prompt/completion pair, or a question/answer pair; anything else is assumed
  to already be a single string in the primary column. The normalized
  messages are written back to the primary column and the tokenizer's chat
  template is applied.
  """
  primary_column = data_columns[0]
  column_set = set(data_columns)

  if "messages" in data_columns:
    conversation = element["messages"]
  elif column_set == {"prompt", "completion"}:
    conversation = [
        {"role": "user", "content": element["prompt"]},
        {"role": "assistant", "content": element["completion"]},
    ]
  elif column_set == {"question", "answer"}:
    conversation = [
        {"role": "user", "content": element["question"]},
        {"role": "assistant", "content": element["answer"]},
    ]
  else:
    # Fallback: the primary column already holds a single string.
    conversation = element[primary_column]

  # Assign the standardized messages back to the primary column.
  element[primary_column] = conversation

  return input_pipeline_utils.apply_chat_template(
      element, tokenizer_model=tokenizer_model, data_column_name=primary_column
  )
|
|
||
|
|
||
| def _tokenize_sft_chunks(element, text_column_name, tokenizer_model): | ||
| """Tokenize each chunk individually without truncating.""" | ||
| text_chunks = element[text_column_name] | ||
| element[text_column_name] = [tokenizer_model.encode(chunk) for chunk in text_chunks] | ||
| return element | ||
|
|
||
|
|
||
def sft_preprocessing_pipeline(
    dataset,
    config,
    data_columns,
    tokenize,
    grain_worker_count,
    grain_per_worker_buffer_size,
):
  """Use grain pipeline to pre-process the dataset and return iterators for SFT fine-tuning.

  Args:
    dataset: the grain dataset to preprocess.
    config: job config; consulted for tokenizer, chat-template, packing, and batching options.
    data_columns: raw column names; must match one of the supported SFT schemas
      (["prompt", "completion"], ["messages"], or ["question", "answer"]).
    tokenize: whether to tokenize the formatted text chunks.
    grain_worker_count: number of grain workers; -1 selects a performance config automatically.
    grain_per_worker_buffer_size: per-worker prefetch buffer size.

  Returns:
    The preprocessed grain dataset: chat-formatted, (optionally) tokenized,
    prompt-masked, packed or padded, batched, shifted for teacher forcing,
    and wrapped in multiprocess prefetch.
  """
  # TODO(review): this feature-parsing block duplicates pretrain_preprocessing_pipeline();
  # extract a shared utility.
  if config.grain_file_type == "arrayrecord":
    dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize))
    dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
  else:
    dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns))

  # TODO(review): tokenizer construction also duplicates the pretrain pipeline.
  tokenizer_model = tokenizer.build_tokenizer(
      config.tokenizer_path,
      config.tokenizer_type,
      config.add_bos,
      config.add_eos,
      config.hf_access_token,
  )
  # Prefer the tokenizer's pad id, fall back to unk, then to -1 (no usable pad token).
  if tokenizer_model.pad_id is not None:
    pad_id = tokenizer_model.pad_id
  elif tokenizer_model.unk_id is not None:
    pad_id = tokenizer_model.unk_id
  else:
    pad_id = -1

  tokenizer_model = tokenizer_model.tokenizer

  # A config-supplied chat template overrides the tokenizer's default, if it has one.
  if getattr(config, "chat_template", None) and hasattr(tokenizer_model, "chat_template"):
    tokenizer_model.chat_template = config.chat_template

  # TODO(review): supported-column validation repeats logic from the hf pipeline
  # (hf_data_processing uses instruction_data_processing.convert_to_conversational_format);
  # consider sharing a utility routine.
  supported_columns = [["prompt", "completion"], ["messages"], ["question", "answer"]]
  assert any(
      set(data_columns) == set(supported) for supported in supported_columns
  ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_columns}"

  dataset = dataset.map(
      functools.partial(_format_chat_template_grain, data_columns=data_columns, tokenizer_model=tokenizer_model)
  )

  if tokenize:
    dataset = dataset.map(
        functools.partial(
            _tokenize_sft_chunks,
            text_column_name=data_columns[0],
            tokenizer_model=tokenizer_model,
        )
    )

  # Mask prompt tokens so the loss is computed on completions only (when configured).
  dataset = dataset.map(
      input_pipeline_utils.SFTPromptMasking(
          text_column_name=data_columns[0],
          completion_only=config.sft_train_on_completion_only,
          max_target_length=config.max_target_length,
          unk_id=pad_id,
      )
  )
  data_columns = ("inputs", "targets")

  # Pack and batch examples
  batch_size = config.global_batch_size_to_load // jax.process_count()
  if config.expansion_factor_real_data > 1:
    # global_batch_size_to_load has been expanded in pyconfig.py when expansion_factor_real_data > 1.
    batch_size = int(batch_size // config.expansion_factor_real_data)

  if config.packing:
    # TODO(review): this packing block is identical to one in pretrain_preprocessing_pipeline();
    # it would be great to reuse it.
    length_struct = {col: config.max_target_length for col in data_columns}
    max_segments = config.max_segments_per_seq
    if max_segments is not None and max_segments <= 0:
      max_segments = None
    if config.grain_packing_type == "first_fit":
      dataset = grain.experimental.FirstFitPackIterDataset(
          dataset,
          length_struct=length_struct,
          num_packing_bins=batch_size,
          max_sequences_per_bin=max_segments,
      )
    elif config.grain_packing_type == "best_fit":
      dataset = BestFitPackIterDataset(dataset, length_struct=length_struct, num_packing_bins=batch_size)
    elif config.grain_packing_type == "concat_then_split":
      if config.add_bos and hasattr(tokenizer_model, "bos_id"):
        dataset = grain.experimental.ConcatThenSplitIterDataset(
            dataset,
            length_struct=length_struct,
            bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS,
            bos_token_id=tokenizer_model.bos_id,
        )
      else:
        dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct)
    else:
      # Bug fix: report the offending value (grain_packing_type), not the boolean packing flag.
      raise ValueError(f"Unknown packing type: {config.grain_packing_type}")

    rekey_dict = {
        "targets_segmentation": "targets_segment_ids",
        "inputs_segmentation": "inputs_segment_ids",
        "targets_position": "targets_positions",
        "inputs_position": "inputs_positions",
    }
    dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict))
  else:
    dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))

  batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
  dataset = dataset.batch(batch_size, batch_fn=batch_fn)

  # Shift inputs for teacher-forced training
  dataset = dataset.map(
      input_pipeline_utils.ShiftData(
          ignored_ids=[pad_id],
          axis=1,
      )
  )
  # grain_worker_count == -1 delegates worker/buffer sizing to pick_performance_config.
  multiprocessing_options = (
      pick_performance_config(
          ds=dataset,
          ram_budget_mb=config.grain_ram_budget_mb,
          max_workers=None,
          max_buffer_size=None,
      ).multiprocessing_options
      if grain_worker_count == -1
      else grain.MultiprocessingOptions(
          num_workers=grain_worker_count,
          per_worker_buffer_size=grain_per_worker_buffer_size,
      )
  )
  dataset = dataset.mp_prefetch(multiprocessing_options)
  return dataset
|
|
||
|
|
||
| def make_grain_train_iterator( | ||
| config: ml_collections.ConfigDict, | ||
| global_mesh, | ||
|
|
@@ -385,6 +545,15 @@ def make_grain_train_iterator( | |
| grain_worker_count=config.grain_worker_count, | ||
| grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, | ||
| ) | ||
| elif config.use_sft: | ||
| train_dataloader = sft_preprocessing_pipeline( | ||
| train_ds, | ||
| config, | ||
| data_columns=config.train_data_columns, | ||
| tokenize=config.tokenize_train_data, | ||
| grain_worker_count=config.grain_worker_count, | ||
| grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, | ||
| ) | ||
| else: | ||
| train_dataloader = pretrain_preprocessing_pipeline( | ||
| train_ds, | ||
|
|
@@ -415,7 +584,16 @@ def make_grain_train_iterator( | |
| ) | ||
| if config.use_dpo: | ||
| preprocessing_fn = functools.partial( | ||
| pretrain_preprocessing_pipeline, | ||
| dpo_preprocessing_pipeline, | ||
| config=config, | ||
| data_columns=config.train_data_columns, | ||
| tokenize=config.tokenize_train_data, | ||
| grain_worker_count=config.grain_worker_count, | ||
| grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, | ||
| ) | ||
| elif config.use_sft: | ||
| preprocessing_fn = functools.partial( | ||
| sft_preprocessing_pipeline, | ||
| config=config, | ||
| data_columns=config.train_data_columns, | ||
| tokenize=config.tokenize_train_data, | ||
|
|
@@ -482,6 +660,15 @@ def make_grain_eval_iterator( | |
| grain_worker_count=config.grain_worker_count_eval, | ||
| grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, | ||
| ) | ||
| elif config.use_sft: | ||
| eval_dataloader = sft_preprocessing_pipeline( | ||
| eval_ds, | ||
| config, | ||
| data_columns=config.eval_data_columns, | ||
| tokenize=config.tokenize_eval_data, | ||
| grain_worker_count=config.grain_worker_count_eval, | ||
| grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, | ||
| ) | ||
| else: | ||
| eval_dataloader = pretrain_preprocessing_pipeline( | ||
| eval_ds, | ||
|
|
@@ -516,6 +703,15 @@ def make_grain_eval_iterator( | |
| grain_worker_count=config.grain_worker_count_eval, | ||
| grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, | ||
| ) | ||
| elif config.use_sft: | ||
| preprocessing_fn = functools.partial( | ||
| sft_preprocessing_pipeline, | ||
| config=config, | ||
| data_columns=config.eval_data_columns, | ||
| tokenize=config.tokenize_eval_data, | ||
| grain_worker_count=config.grain_worker_count_eval, | ||
| grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, | ||
| ) | ||
| else: | ||
| preprocessing_fn = functools.partial( | ||
| pretrain_preprocessing_pipeline, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The HF pipeline asserts that SFT is running on a conversational format — should the grain pipeline enforce the same invariant?