diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index 204ff012c71..fd06e9f384c 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -305,6 +305,12 @@ def save_safetensors_by_layer_index( meta_filename = filename + ".json" ckpt_filename = filename + ".safetensors" + # Write safetensors first, then build the per-layer meta JSON from the same dict. + # Order matters: any late mutations to layer_state_dict (e.g. MTP tensors added after + # the dict was first constructed) must be captured by both files. Writing safetensors + # first ensures the JSON is always consistent with what is physically on disk. + save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) + weight_map = {} layer_total_size = 0 for key, val in layer_state_dict.items(): @@ -318,7 +324,6 @@ def save_safetensors_by_layer_index( f, indent=4, ) - save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) # [TODO]: this global barrier needs to be replaced with something safer torch.distributed.barrier()