diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index ca80cb450..bd6df260c 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -28,7 +28,7 @@
 import torch
 import torch.nn as nn
 
-from safetensors.torch import save_file
+from safetensors.torch import load_file, safe_open, save_file
 
 try:
     import diffusers
@@ -111,21 +111,128 @@ def _is_enabled_quantizer(quantizer):
     return False
 
 
+def _merge_diffusion_transformer_with_non_transformer_components(
+    diffusion_transformer_state_dict: dict[str, torch.Tensor],
+    merged_base_safetensor_path: str,
+) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
+    """Merge diffusion transformer weights with non-transformer components from a safetensors file.
+
+    Non-transformer components (VAE, vocoder, text encoders) and embeddings connectors are
+    taken from the base checkpoint. Transformer keys are prefixed with 'model.diffusion_model.'
+    for ComfyUI compatibility.
+
+    Args:
+        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
+        merged_base_safetensor_path: Path to the full base model safetensors file containing
+            all components (transformer, VAE, vocoder, etc.).
+
+    Returns:
+        Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
+        safetensors metadata from the base checkpoint.
+    """
+    base_state = load_file(merged_base_safetensor_path)
+
+    non_transformer_prefixes = [
+        "vae.",
+        "audio_vae.",
+        "vocoder.",
+        "text_embedding_projection.",
+        "text_encoders.",
+        "first_stage_model.",
+        "cond_stage_model.",
+        "conditioner.",
+    ]
+    correct_prefix = "model.diffusion_model."
+    strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]
+
+    base_non_transformer = {
+        k: v
+        for k, v in base_state.items()
+        if any(k.startswith(p) for p in non_transformer_prefixes)
+    }
+    base_connectors = {
+        k: v
+        for k, v in base_state.items()
+        if "embeddings_connector" in k and k.startswith(correct_prefix)
+    }
+
+    prefixed = {}
+    for k, v in diffusion_transformer_state_dict.items():
+        clean_k = k
+        for prefix in strip_prefixes:
+            if clean_k.startswith(prefix):
+                clean_k = clean_k[len(prefix) :]
+                break
+        prefixed[f"{correct_prefix}{clean_k}"] = v
+
+    merged = dict(base_non_transformer)
+    merged.update(base_connectors)
+    merged.update(prefixed)
+    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
+        base_metadata = f.metadata() or {}
+
+    del base_state
+    return merged, base_metadata
+
+
 def _save_component_state_dict_safetensors(
-    component: nn.Module, component_export_dir: Path
+    component: nn.Module,
+    component_export_dir: Path,
+    merged_base_safetensor_path: str | None = None,
+    hf_quant_config: dict | None = None,
 ) -> None:
+    """Save component state dict as safetensors with optional base checkpoint merge.
+
+    Args:
+        component: The nn.Module to save.
+        component_export_dir: Directory to save model.safetensors and config.json.
+        merged_base_safetensor_path: If provided, merge with non-transformer components
+            from this base safetensors file.
+        hf_quant_config: If provided, embed quantization config in safetensors metadata
+            and per-layer _quantization_metadata for ComfyUI.
+    """
     cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()}
-    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
-    with open(component_export_dir / "config.json", "w") as f:
-        json.dump(
+    metadata: dict[str, str] = {}
+    metadata_full: dict[str, str] = {}
+    if merged_base_safetensor_path is not None:
+        cpu_state_dict, metadata_full = (
+            _merge_diffusion_transformer_with_non_transformer_components(
+                cpu_state_dict, merged_base_safetensor_path
+            )
+        )
+    metadata["_export_format"] = "safetensors_state_dict"
+    metadata["_class_name"] = type(component).__name__
+
+    if hf_quant_config is not None:
+        metadata_full["quantization_config"] = json.dumps(hf_quant_config)
+
+        # Build per-layer _quantization_metadata for ComfyUI
+        quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
+        layer_metadata = {}
+        for k in cpu_state_dict:
+            if k.endswith((".weight_scale", ".weight_scale_2")):
+                layer_name = k.rsplit(".", 1)[0]
+                if layer_name.endswith(".weight"):
+                    layer_name = layer_name.rsplit(".", 1)[0]
+                if layer_name not in layer_metadata:
+                    layer_metadata[layer_name] = {"format": quant_algo}
+        metadata_full["_quantization_metadata"] = json.dumps(
             {
-                "_class_name": type(component).__name__,
-                "_export_format": "safetensors_state_dict",
-            },
-            f,
-            indent=4,
+                "format_version": "1.0",
+                "layers": layer_metadata,
+            }
         )
+    metadata_full.update(metadata)
+    save_file(
+        cpu_state_dict,
+        str(component_export_dir / "model.safetensors"),
+        metadata=metadata_full if merged_base_safetensor_path is not None else None,
+    )
+
+    with open(component_export_dir / "config.json", "w") as f:
+        json.dump(metadata, f, indent=4)
+
 
 def _collect_shared_input_modules(
     model: nn.Module,
@@ -807,6 +914,7 @@ def _export_diffusers_checkpoint(
     dtype: torch.dtype | None,
     export_dir: Path,
     components: list[str] | None,
+    merged_base_safetensor_path: str | None = None,
     max_shard_size: int | str = "10GB",
 ) -> None:
     """Internal: Export diffusion(-like) model/pipeline checkpoint.
@@ -821,6 +929,8 @@ def _export_diffusers_checkpoint(
         export_dir: The directory to save the exported checkpoint.
         components: Optional list of component names to export. Only used for pipelines.
             If None, all components are exported.
+        merged_base_safetensor_path: If provided, merge the exported transformer with
+            non-transformer components from this base safetensors file.
         max_shard_size: Maximum size of each shard file. If the model exceeds this size,
             it will be sharded into multiple files and a .safetensors.index.json will be
            created. Use smaller values like "5GB" or "2GB" to force sharding.
@@ -879,6 +989,7 @@ def _export_diffusers_checkpoint(
 
             # Step 5: Build quantization config
             quant_config = get_quant_config(component, is_modelopt_qlora=False)
+            hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None
 
             # Step 6: Save the component
             # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
@@ -888,12 +999,14 @@ def _export_diffusers_checkpoint(
                 component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
             else:
                 with hide_quantizers_from_state_dict(component):
-                    _save_component_state_dict_safetensors(component, component_export_dir)
-
+                    _save_component_state_dict_safetensors(
+                        component,
+                        component_export_dir,
+                        merged_base_safetensor_path,
+                        hf_quant_config,
+                    )
             # Step 7: Update config.json with quantization info
-            if quant_config is not None:
-                hf_quant_config = convert_hf_quant_config_format(quant_config)
-
+            if hf_quant_config is not None:
                 config_path = component_export_dir / "config.json"
                 if config_path.exists():
                     with open(config_path) as file:
@@ -905,7 +1018,9 @@ def _export_diffusers_checkpoint(
         elif hasattr(component, "save_pretrained"):
             component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
         else:
-            _save_component_state_dict_safetensors(component, component_export_dir)
+            _save_component_state_dict_safetensors(
+                component, component_export_dir, merged_base_safetensor_path
+            )
 
         print(f" Saved to: {component_export_dir}")
 
@@ -985,6 +1100,7 @@ def export_hf_checkpoint(
     save_modelopt_state: bool = False,
    components: list[str] | None = None,
    extra_state_dict: dict[str, torch.Tensor] | None = None,
+    merged_base_safetensor_path: str | None = None,
 ):
     """Export quantized HuggingFace model checkpoint (transformers or diffusers).
 
@@ -1002,6 +1118,9 @@ def export_hf_checkpoint(
         components: Only used for diffusers pipelines. Optional list of component names
            to export. If None, all quantized components are exported.
         extra_state_dict: Extra state dictionary to add to the exported model.
+        merged_base_safetensor_path: If provided, merge the exported diffusion transformer
+            with non-transformer components (VAE, vocoder, etc.) from this base safetensors
+            file. Only used for diffusion model exports (e.g., LTX-2).
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
@@ -1010,7 +1129,9 @@ def export_hf_checkpoint(
     if HAS_DIFFUSERS:
         is_diffusers_obj = is_diffusers_object(model)
         if is_diffusers_obj:
-            _export_diffusers_checkpoint(model, dtype, export_dir, components)
+            _export_diffusers_checkpoint(
+                model, dtype, export_dir, components, merged_base_safetensor_path
+            )
             return
 
     # Transformers model export
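
Usage note (not part of the diff): a minimal sketch of how the new merged_base_safetensor_path argument could be exercised end to end. It assumes an already-quantized diffusers pipeline `pipe`, hypothetical paths, that export_hf_checkpoint is re-exported from modelopt.torch.export, and that the per-component output lands at <export_dir>/<component>/model.safetensors; only the parameter names and safetensors metadata keys come from the change above.

    # Hypothetical example; paths, `pipe`, and the output layout are assumptions.
    import json

    from safetensors.torch import safe_open

    from modelopt.torch.export import export_hf_checkpoint

    # Export the quantized transformer and merge it with the VAE / vocoder /
    # text-encoder weights from the full base checkpoint into a single
    # ComfyUI-style safetensors file.
    export_hf_checkpoint(
        pipe,
        export_dir="exported_ltx2",
        components=["transformer"],
        merged_base_safetensor_path="checkpoints/base_model_full.safetensors",
    )

    # When a base checkpoint is merged, the quantization info is embedded in the
    # safetensors header metadata as JSON strings.
    with safe_open("exported_ltx2/transformer/model.safetensors", framework="pt", device="cpu") as f:
        meta = f.metadata()
        print(json.loads(meta["quantization_config"])["quant_algo"])
        print(json.loads(meta["_quantization_metadata"])["format_version"])  # "1.0"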