diff --git a/examples/distillation/training.py b/examples/distillation/training.py index 34626e7..a42225e 100644 --- a/examples/distillation/training.py +++ b/examples/distillation/training.py @@ -95,7 +95,7 @@ def convert_checkpoints_to_hf(model_training_config, output_path, best_model_pat def train(): pl.seed_everything(42) parser = HfArgumentParser((TrainingArgs, DataLoadingConfig, DistillTrainingConfig)) - (model_training_config, data_config, distill_config, _) = parser.parse_args_into_dataclasses( + model_training_config, data_config, distill_config, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) if ( diff --git a/examples/pruning/main_pruning.py b/examples/pruning/main_pruning.py index 36647fb..84b05a1 100644 --- a/examples/pruning/main_pruning.py +++ b/examples/pruning/main_pruning.py @@ -11,7 +11,7 @@ if __name__ == "__main__": parser = HfArgumentParser((PruningConfig, CalibrationDataConfig)) - (pruning_config, data_config) = parser.parse_args_into_dataclasses() + pruning_config, data_config = parser.parse_args_into_dataclasses() logger.info(f"pruning_config = {pruning_config}") logger.info(f"data_config = {data_config}") diff --git a/examples/quantization/main_quantization.py b/examples/quantization/main_quantization.py index 4ce8ecd..44a130a 100644 --- a/examples/quantization/main_quantization.py +++ b/examples/quantization/main_quantization.py @@ -11,7 +11,7 @@ if __name__ == "__main__": parser = HfArgumentParser((QuantizationConfig, CalibrationDataConfig)) - (quantization_config, data_config) = parser.parse_args_into_dataclasses() + quantization_config, data_config = parser.parse_args_into_dataclasses() logger.info(f"quantization_config = {quantization_config}") logger.info(f"data_config = {data_config}") diff --git a/examples/structured_pruning/main_structured_pruning.py b/examples/structured_pruning/main_structured_pruning.py index e3bfa3c..c6673f4 100644 --- a/examples/structured_pruning/main_structured_pruning.py +++ b/examples/structured_pruning/main_structured_pruning.py @@ -11,7 +11,7 @@ if __name__ == "__main__": parser = HfArgumentParser((StructuredPruningConfig, CalibrationDataConfig)) - (pruning_config, data_config) = parser.parse_args_into_dataclasses() + pruning_config, data_config = parser.parse_args_into_dataclasses() logger.info(f"pruning_config = {pruning_config}") logger.info(f"data_config = {data_config}") diff --git a/src/fmchisel/distillation/config.py b/src/fmchisel/distillation/config.py index dc2bfa3..3aa89a9 100644 --- a/src/fmchisel/distillation/config.py +++ b/src/fmchisel/distillation/config.py @@ -73,13 +73,11 @@ class DistillTrainingConfig: sample_method: Literal["supervised", "on-policy", "sequence-level"] = field(default="supervised") sample_fraction: float = field( default=1.0, - metadata={ - "help": "Fraction of batches whose responses are sampled from student (on-policy) distribution \ + metadata={"help": "Fraction of batches whose responses are sampled from student (on-policy) distribution \ or teacher (sequence-evel) distribution rather than using the original responses, \ same as the huggingface GKD trainer (parameter self.lmbda). https://huggingface.co/docs/trl/gkd_trainer#trl.GKDConfig \ e.g., 0.4 means 40% of batches are using the responses sampled from student/teacher model, with 60% using original data \ - Ignored when using supervised methods (ground-truth tokens)." - }, + Ignored when using supervised methods (ground-truth tokens)."}, ) max_new_tokens: int = field( default=100, @@ -89,10 +87,8 @@ class DistillTrainingConfig: ) sample_temperature: float = field( default=0.8, - metadata={ - "help": "Sample temperature used for on-policy or sequence-level response token generation. \ - The higher the temperature, the more random the completions." - }, + metadata={"help": "Sample temperature used for on-policy or sequence-level response token generation. \ + The higher the temperature, the more random the completions."}, ) # [end] sampling and generation configs include_prompt_loss: bool = field(