Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion modelopt/torch/export/plugins/mcore_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,12 @@ def save_safetensors_by_layer_index(
meta_filename = filename + ".json"
ckpt_filename = filename + ".safetensors"

# Write safetensors first, then build the per-layer meta JSON from the same dict.
# Order matters: any late mutations to layer_state_dict (e.g. MTP tensors added after
Copy link
Copy Markdown
Contributor

@jenchen13 jenchen13 May 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this AI comment be more concise?

# the dict was first constructed) must be captured by both files. Writing safetensors
# first ensures the JSON is always consistent with what is physically on disk.
save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"})

weight_map = {}
layer_total_size = 0
for key, val in layer_state_dict.items():
Expand All @@ -318,7 +324,6 @@ def save_safetensors_by_layer_index(
f,
indent=4,
)
save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"})

# [TODO]: this global barrier needs to be replaced with something safer
torch.distributed.barrier()
Expand Down