[feat] JoyAI-JoyImage-Edit support #13444
Open

Moran232 wants to merge 44 commits into huggingface:main from Moran232:joyimage_edit.
Changes shown below are from a single commit, 54afa65.

Commits (44):

- 54afa65 [feat] JoyAI-JoyImage-Edit support
- 8459759 [fix] remove rearrange
- e6e6df5 [refactor] two pass when do cfg
- f557113 [refactor] remove repa, use wantimetextembeding, refactor modulate code
- d397b68 [refactor] Joyimage Attention refactor
- 9d78e4e remove vae tiling and autocast
- cc9d134 [fix] remove einops from setup.py
- 001f7d3 [refactor] Refactor JoyImageEditPipeline to use explicit arguments in…
- 657b4b0 [fix] remove deprecated method decode_latents
- 19027dd [refactor] refactor the image pre-processing logic into a separate Va…
- 0a06f33 [refactor] add JoyImageAttention to align with Attention + AttnProces…
- 02d947c [refactor] simplify bucket logic in JoyImageEditImageProcessor by rep…
- 2b7fb86 [fix] remove leftover training-only parameters
- 79c48a8 [fix] add layerwise casting and fp32 module patterns to JoyImageTrans…
- 98cee97 [test] add JoyImageEditPipeline fast tests and JoyImageEditTransforme…
- a716050 [fix] fix some pipeline args to support batch inference
- 320bde7 [fix] duplicate images to match batch size when fewer images than pro…
- 3ed6ca9 [fix] remove no longer used config parameters
- 92f4d85 Merge branch 'main' into joyimage_edit (dg845)
- 261613b Apply style fixes (github-actions[bot])
- f364da3 [fix] remove unused dataclass and rewrite helpers as inline functions
- c7bd284 [fix] make dummy objects for JoyImageEdit
- e45e1ad [fix] allow test_torch_compile_repeated_blocks to pass
- aeaa334 [fix] add examples on JoyImageEditPipeline
- ce4a3d9 fix code style issues with ruff and black
- 82e5cd2 Apply style fixes (github-actions[bot])
- 844f3f1 [fix] change default num_inference_steps to 40
- fd29a73 [fix] use forward hook to extract pre-norm hidden states for transfor…
- 3a6b658 [fix] change the assert to ValueError in pipeline
- 76c1647 [fix] rename JoyImageTransformer3DModel to JoyImageEditTransformer3DM…
- deb5d4f [fix] support gradient checkpointing
- d8c7c0e [refactor] simplify RoPE utilities, inline helpers, copy WanTimeTextI…
- 203494e [fix] remove _get_text_encoder_ckpt and qwen_processor
- 8408f55 [fix] change nn.RMSNorm to FP32LayerNorm
- 88579cc [fix] small fixes for suggestions given by Claude
- ea52935 Merge branch 'main' into joyimage_edit (dg845)
- 9d9ef52 [refactor] build model using from_pretrained instead of config
- 87b5383 [refactor] auto-wrap prompt and support text-to-image in JoyImage Edi…
- cf61b0a make style, make quality and make fix-copies
- dbfbb59 [test] small fix to use vocab_size=1024
- 756904e Merge branch 'main' into joyimage_edit (dg845)
- dc5a0a2 Merge branch 'main' into joyimage_edit (yiyixuxu)
- 63baf43 [refactor] separate encode_prompt_multiple_images from encode_prompt,…
- e8c4db7 [test] fix CI: use strict=False for xfail and add @require_torch_acce…
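One technique worth calling out from the list above: commit fd29a73 extracts pre-norm hidden states with a forward hook rather than changing the model's forward signature. A minimal sketch of that general PyTorch pattern, using a toy module, not the PR's actual code:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
captured = {}

def save_pre_norm(module, inputs):
    # inputs is the tuple of positional arguments passed to the module,
    # i.e. the hidden states *before* the LayerNorm is applied
    captured["pre_norm"] = inputs[0].detach()

# A forward pre-hook fires before the hooked module's forward runs
handle = model[1].register_forward_pre_hook(save_pre_norm)
_ = model(torch.randn(2, 8))
handle.remove()
print(captured["pre_norm"].shape)  # torch.Size([2, 8])
```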
Files changed

New file (@@ -0,0 +1,306 @@): a checkpoint conversion script, adapted from convert_wan_to_diffusers.py.
```python
import argparse
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration

from diffusers import (
    AutoencoderKLWan,
    JoyImageEditPipeline,
    JoyImageEditTransformer3DModel,
)
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler


# This script is adapted from convert_wan_to_diffusers.py to support a local input checkpoint path.
def convert_vae(vae_ckpt_path):
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
    new_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for attention blocks
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant components
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in old_state_dict.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            new_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            new_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            new_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            new_state_dict[new_key] = value
        # Handle encoder conv1
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        # Handle decoder conv1
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Convert to down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Convert residual block naming but keep the original structure
            if ".residual.0.gamma" in new_key:
                new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
            elif ".residual.2.bias" in new_key:
                new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
            elif ".residual.2.weight" in new_key:
                new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
            elif ".residual.3.gamma" in new_key:
                new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
            elif ".residual.6.bias" in new_key:
                new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
            elif ".residual.6.weight" in new_key:
                new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
            elif ".shortcut.bias" in new_key:
                new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
            elif ".shortcut.weight" in new_key:
                new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")

            new_state_dict[new_key] = value

        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Convert to up_blocks
            parts = key.split(".")
            block_idx = int(parts[2])

            # Group residual blocks
            if "residual" in key:
                if block_idx in [0, 1, 2]:
                    new_block_idx = 0
                    resnet_idx = block_idx
                elif block_idx in [4, 5, 6]:
                    new_block_idx = 1
                    resnet_idx = block_idx - 4
                elif block_idx in [8, 9, 10]:
                    new_block_idx = 2
                    resnet_idx = block_idx - 8
                elif block_idx in [12, 13, 14]:
                    new_block_idx = 3
                    resnet_idx = block_idx - 12
                else:
                    # Keep as is for other blocks
                    new_state_dict[key] = value
                    continue

                # Convert residual block naming
                if ".residual.0.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
                elif ".residual.2.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
                elif ".residual.2.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
                elif ".residual.3.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
                elif ".residual.6.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
                elif ".residual.6.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
                else:
                    new_key = key

                new_state_dict[new_key] = value

            # Handle shortcut connections
            elif ".shortcut." in key:
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")

                new_state_dict[new_key] = value

            # Handle upsamplers
            elif ".resample." in key or ".time_conv." in key:
                if block_idx == 3:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
                elif block_idx == 7:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
                elif block_idx == 11:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

                new_state_dict[new_key] = value
            else:
                new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                new_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            new_state_dict[key] = value

    with init_empty_weights():
        vae = AutoencoderKLWan()
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae
```
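Both converters use the same loading trick: the module skeleton is built under accelerate's init_empty_weights(), so its parameters live on the meta device and no memory is allocated at construction time, and load_state_dict(..., assign=True) then attaches the converted tensors directly instead of copying into preallocated storage (assign=True requires PyTorch 2.1+). A minimal sketch of the pattern, using a toy nn.Linear rather than the real models:

```python
import torch
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights():
    layer = nn.Linear(4, 4)  # parameters are on the "meta" device: shapes only, no data

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
# assign=True swaps the meta tensors for the loaded ones instead of copying,
# which is required because meta tensors have no storage to copy into.
layer.load_state_dict(state_dict, assign=True)
print(layer.weight.device)  # cpu
```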
The transformer config and converter follow the same pattern:

```python
def get_transformer_config() -> Dict[str, Any]:
    config = {
        "diffusers_config": {
            "hidden_size": 4096,
            "in_channels": 16,
            "heads_num": 32,
            "mm_double_blocks_depth": 40,
            "out_channels": 16,
            "patch_size": [1, 2, 2],
            "rope_dim_list": [16, 56, 56],
            "text_states_dim": 4096,
            "rope_type": "rope",
            "dit_modulation_type": "wanx",
            "unpatchify_new": True,
            "rope_theta": 10000,
        },
    }
    return config


def convert_transformer(ckpt_path: str):
    checkpoint = torch.load(ckpt_path, weights_only=True)
    if "model" in checkpoint:
        original_state_dict = checkpoint["model"]
    else:
        original_state_dict = checkpoint

    config = get_transformer_config()
    with init_empty_weights():
        transformer = JoyImageEditTransformer3DModel(**config["diffusers_config"])
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer
```
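A quick consistency check on the config above: the attention head dimension is hidden_size / heads_num = 4096 / 32 = 128, and rope_dim_list = [16, 56, 56] (presumably the temporal, height, and width rotary axes) sums to exactly that, since the rotary embedding is split across the three axes:

```python
# Values taken from get_transformer_config() above
hidden_size, heads_num = 4096, 32
rope_dim_list = [16, 56, 56]

head_dim = hidden_size // heads_num  # 128
assert sum(rope_dim_list) == head_dim  # 16 + 56 + 56 == 128
```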
Finally, the CLI entry point:

```python
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument(
        "--text_encoder_path", type=str, default=None, help="Path to original text encoder (Qwen3VL) checkpoint"
    )
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original tokenizer")
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
    parser.add_argument("--flow_shift", type=float, default=7.0)
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


if __name__ == "__main__":
    args = get_args()
    transformer = None
    vae = None
    dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
        assert args.text_encoder_path is not None
        # assert args.tokenizer_path is not None

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path)
        vae = vae.to(dtype=dtype)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.save_pipeline:
        processor = AutoProcessor.from_pretrained(args.text_encoder_path)
        text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(
            args.text_encoder_path, torch_dtype=torch.bfloat16
        ).to("cuda")
        tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path)
        # NOTE: the scheduler shift is hardcoded here; the --flow_shift CLI argument is currently unused.
        flow_shift = 1.5
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=flow_shift)
        transformer = transformer.to("cuda")
        vae = vae.to("cuda")
        pipe = JoyImageEditPipeline(
            processor=processor,
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        ).to("cuda")
        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
        processor.save_pretrained(f"{args.output_path}/processor")
```
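Assuming the script is run with --save_pipeline so that all components land in output_path, loading the converted checkpoint would presumably follow the standard diffusers flow (the path below is a placeholder, not an actual repo):

```python
import torch

from diffusers import JoyImageEditPipeline

# "path/to/converted_joyimage_edit" stands in for the script's --output_path
pipe = JoyImageEditPipeline.from_pretrained(
    "path/to/converted_joyimage_edit", torch_dtype=torch.bfloat16
).to("cuda")
```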
setup.py (@@ -99,6 +99,7 @@): this commit adds einops to the dependency list.

```diff
 "accelerate>=0.31.0",
 "compel==0.1.8",
 "datasets",
+"einops",
 "filelock",
 "flax>=0.4.1",
 "ftfy",
```
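Commit cc9d134 later removes this einops dependency again, in line with the earlier "[fix] remove rearrange" commit. For context, the kind of substitution that makes the dependency unnecessary is rewriting einops.rearrange with native permute/reshape; the tensor layout here is illustrative, not taken from the PR:

```python
import torch

x = torch.randn(2, 16, 8, 4, 4)  # (batch, channels, frames, height, width)

# With einops:
#   from einops import rearrange
#   y = rearrange(x, "b c f h w -> b (f h w) c")

# Native PyTorch equivalent:
b, c, f, h, w = x.shape
y = x.permute(0, 2, 3, 4, 1).reshape(b, f * h * w, c)
print(y.shape)  # torch.Size([2, 128, 16])
```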