 
 import argparse
 import logging
+import os
 
 import torch
+from safetensors.torch import save_file as safetensors_save_file
 
 from src.utils.load_checkpoint import load_checkpoint
 
 
 logger = logging.getLogger(__name__)
 
+ALLOWED_HYPERPARAMETER_KEYS = (
+    "vocab_size",
+    "hidden_size",
+    "num_hidden_layers",
+    "num_attention_heads",
+    "intermediate_size",
+    "hidden_act",
+    "hidden_dropout_prob",
+    "attention_probs_dropout_prob",
+    "initializer_range",
+    "layer_norm_eps",
+    "pad_token_id",
+    "position_embedding_type",
+    "classifier_dropout",
+    "rotary_theta",
+    "ignore_index",
+    "loss_type",
+    "lora",
+    "lora_alpha",
+    "lora_r",
+    "lora_dropout",
+)
+
 # PyTorch -> TE keymap
 PYTORCH_TO_TE_KEYMAP = {
     "model.layers.*.pre_attn_layer_norm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
@@ -300,6 +325,11 @@ def convert_state_dict(src: dict, keymap: dict):
     return dst_state_dict
 
 
+def filter_hyper_parameters(hyper_parameters: dict) -> dict:
+    """Keep only conversion-compatible hyperparameter keys."""
+    return {key: value for key, value in hyper_parameters.items() if key in ALLOWED_HYPERPARAMETER_KEYS}
+
+
 def main():
     """Main function."""
     logging.basicConfig(level=logging.INFO)
@@ -325,6 +355,7 @@ def main():
     # Load source checkpoint (automatically detects format)
     logger.info(f"Loading checkpoint from {args.src}")
     src_checkpoint = load_checkpoint(args.src, map_location="cpu")
+    src_checkpoint["hyper_parameters"] = filter_hyper_parameters(src_checkpoint["hyper_parameters"])
 
     # Perform conversion based on direction
     if args.direction == "pytorch2te":
@@ -341,11 +372,19 @@ def main():
         dst_state_dict = split_qkv(converted_state_dict, src_checkpoint["hyper_parameters"])
 
     # Prepare final checkpoint
-    dst_checkpoint = {"state_dict": dst_state_dict, "hyper_parameters": src_checkpoint["hyper_parameters"]}
+    dst_checkpoint = {
+        "state_dict": dst_state_dict,
+        "hyper_parameters": src_checkpoint["hyper_parameters"],
+    }
 
     # Save the converted checkpoint in pickled format
     torch.save(dst_checkpoint, args.dst)
-    logger.info(f"Successfully converted checkpoint from {args.src} to {args.dst}")
+    logger.info(f"Successfully converted checkpoint saved to {args.dst}")
+
+    # Save the state_dict in safetensors format alongside the .ckpt file
+    safetensors_path = os.path.splitext(args.dst)[0] + ".safetensors"
+    safetensors_save_file(dst_state_dict, safetensors_path)
+    logger.info(f"Successfully saved safetensors checkpoint to {safetensors_path}")
 
 
 if __name__ == "__main__":
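
A minimal sketch of what the new `filter_hyper_parameters` helper does, with an invented input dict and an allow-list abbreviated to three keys for illustration: anything outside `ALLOWED_HYPERPARAMETER_KEYS` is silently dropped, so trainer-only settings never reach the converted checkpoint.

```python
# Mirrors the helper added in the diff; the allow-list here is abbreviated
# to the keys needed for this illustration.
ALLOWED_HYPERPARAMETER_KEYS = ("vocab_size", "hidden_size", "num_hidden_layers")

def filter_hyper_parameters(hyper_parameters: dict) -> dict:
    """Keep only conversion-compatible hyperparameter keys."""
    return {k: v for k, v in hyper_parameters.items() if k in ALLOWED_HYPERPARAMETER_KEYS}

# "learning_rate" (an invented trainer-only key) is not allowed, so it is dropped.
hparams = {"hidden_size": 768, "num_hidden_layers": 12, "learning_rate": 3e-4}
assert filter_hyper_parameters(hparams) == {"hidden_size": 768, "num_hidden_layers": 12}
```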
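And a hedged round-trip check for the two artifacts the script now writes. The file names below are placeholders standing in for the script's `dst` argument; the `.safetensors` sibling path mirrors the `os.path.splitext` logic in `main()`.

```python
import torch
from safetensors.torch import load_file

# Placeholder paths standing in for the script's dst argument.
ckpt = torch.load("converted.ckpt", map_location="cpu")
tensors = load_file("converted.safetensors")

# The pickled checkpoint carries the state_dict plus the filtered
# hyper_parameters; the safetensors file holds only the tensors.
assert set(ckpt["state_dict"]) == set(tensors)
for name, tensor in tensors.items():
    assert torch.equal(ckpt["state_dict"][name], tensor)
print(f"all {len(tensors)} tensors match")
```

One caveat worth noting: `safetensors.torch.save_file` raises on tensors that share storage, so if `dst_state_dict` ever contains tied weights they would need to be cloned (or saved via `safetensors.torch.save_model`) first.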