Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
54afa65
[feat] JoyAI-JoyImage-Edit support
Apr 2, 2026
8459759
[fix] remove rearrange
Apr 14, 2026
e6e6df5
[refactor] two pass when do cfg
Apr 14, 2026
f557113
[refactor] remove repa, use wantimetextembeding, refactor modulate code
Apr 14, 2026
d397b68
[refactor] Joyimage Attention refactor
Apr 14, 2026
9d78e4e
remove vae tiling and autocast
Apr 14, 2026
cc9d134
[fix] remove einops from setup.py
Apr 20, 2026
001f7d3
[refactor] Refactor JoyImageEditPipeline to use explicit arguments in…
Apr 20, 2026
657b4b0
[fix] remove deprecated method decode_latents
Apr 20, 2026
19027dd
[refactor] refactor the image pre-processing logic into a separate Va…
Apr 20, 2026
0a06f33
[refactor] add JoyImageAttention to align with Attention + AttnProces…
Apr 20, 2026
02d947c
[refactor] simplify bucket logic in JoyImageEditImageProcessor by rep…
Apr 23, 2026
2b7fb86
[fix] remove leftover training-only parameters
Apr 25, 2026
79c48a8
[fix] add layerwise casting and fp32 module patterns to JoyImageTrans…
Apr 27, 2026
98cee97
[test] add JoyImageEditPipeline fast tests and JoyImageEditTransforme…
Apr 27, 2026
a716050
[fix] fix some pipeline args to support batch inference
Apr 27, 2026
320bde7
[fix] duplicate images to match batch size when fewer images than pro…
Apr 27, 2026
3ed6ca9
[fix] remove no longer used config parameters
Apr 28, 2026
92f4d85
Merge branch 'main' into joyimage_edit
dg845 Apr 28, 2026
261613b
Apply style fixes
github-actions[bot] Apr 28, 2026
f364da3
[fix] remove unused dataclass and rewrite helpers as inline functions
Apr 28, 2026
c7bd284
[fix] make dummy objects for JoyImageEdit
Apr 28, 2026
e45e1ad
[fix] allow test_torch_compile_repeated_blocks to pass
Apr 28, 2026
aeaa334
[fix] add examples on JoyImageEditPipeline
Apr 28, 2026
ce4a3d9
fix code style issues with ruff and black
Apr 28, 2026
82e5cd2
Apply style fixes
github-actions[bot] Apr 29, 2026
844f3f1
[fix] change default num_inference_steps to 40
Apr 29, 2026
fd29a73
[fix] use forward hook to extract pre-norm hidden states for transfor…
Apr 29, 2026
3a6b658
[fix] change the assert to ValueError in pipeline
Apr 29, 2026
76c1647
[fix] rename JoyImageTransformer3DModel to JoyImageEditTransformer3DM…
Apr 29, 2026
deb5d4f
[fix] support gradient checkpointing
Apr 29, 2026
d8c7c0e
[refactor] simplify RoPE utilities, inline helpers, copy WanTimeTextI…
Apr 29, 2026
203494e
[fix] remove _get_text_encoder_ckpt and qwen_processor
Apr 29, 2026
8408f55
[fix] change nn.RMSNorm to FP32LayerNorm
Apr 29, 2026
88579cc
[fix] small fixes for suggestions given by Claude
Apr 29, 2026
ea52935
Merge branch 'main' into joyimage_edit
dg845 Apr 30, 2026
9d9ef52
[refactor] build model using from _pretained instead of config
Apr 30, 2026
87b5383
[refactor] auto-wrap prompt and support text-to-image in JoyImage Edi…
Apr 30, 2026
cf61b0a
make style, make quality and make fix-copies
Apr 30, 2026
dbfbb59
[test] small fix to use vocab_size=1024
Apr 30, 2026
756904e
Merge branch 'main' into joyimage_edit
dg845 Apr 30, 2026
dc5a0a2
Merge branch 'main' into joyimage_edit
yiyixuxu May 1, 2026
63baf43
[refactor] separate encode_prompt_multiple_images from encode_prompt,…
May 2, 2026
e8c4db7
[test] fix CI: use strict=False for xfail and add @require_torch_acce…
May 4, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 306 additions & 0 deletions scripts/convert_joyimage_edit_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
import argparse
import pathlib
from typing import Any, Dict, Tuple
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from safetensors.torch import load_file
from diffusers import (
AutoencoderKLWan,
JoyImageEditTransformer3DModel,
JoyImageEditPipeline,
)
# This code is modified from convert_wan_to_diffusers.py to support input ckpt path
def convert_vae(vae_ckpt_path):
    """Convert an original JoyImage VAE checkpoint into an ``AutoencoderKLWan``.

    Renames every key of the original state dict to the diffusers layout,
    then loads the result into an ``AutoencoderKLWan`` instantiated on the
    meta device (weights are assigned in place via ``assign=True``).

    Args:
        vae_ckpt_path: Path to the original VAE checkpoint
            (``torch.load``-able, loaded with ``weights_only=True``).

    Returns:
        An ``AutoencoderKLWan`` with the converted weights loaded.
    """
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
    new_state_dict = {}

    # Create mappings for specific components
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Create a mapping for attention blocks
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Create a mapping for the head components
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Create a mapping for the quant components
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    # Process each key in the state dict
    for key, value in old_state_dict.items():
        # Handle middle block keys using the mapping
        if key in middle_key_mapping:
            new_key = middle_key_mapping[key]
            new_state_dict[new_key] = value
        # Handle attention blocks using the mapping
        elif key in attention_mapping:
            new_key = attention_mapping[key]
            new_state_dict[new_key] = value
        # Handle head keys using the mapping
        elif key in head_mapping:
            new_key = head_mapping[key]
            new_state_dict[new_key] = value
        # Handle quant keys using the mapping
        elif key in quant_mapping:
            new_key = quant_mapping[key]
            new_state_dict[new_key] = value
        # Handle encoder conv1
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        # Handle decoder conv1
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        # Handle encoder downsamples
        elif key.startswith("encoder.downsamples."):
            # Convert to down_blocks
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

            # Convert residual block naming but keep the original structure.
            # The original module is an nn.Sequential, so indices 0/2/3/6 are
            # norm1/conv1/norm2/conv2 respectively.
            if ".residual.0.gamma" in new_key:
                new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
            elif ".residual.2.bias" in new_key:
                new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
            elif ".residual.2.weight" in new_key:
                new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
            elif ".residual.3.gamma" in new_key:
                new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
            elif ".residual.6.bias" in new_key:
                new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
            elif ".residual.6.weight" in new_key:
                new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
            elif ".shortcut.bias" in new_key:
                new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
            elif ".shortcut.weight" in new_key:
                new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")

            new_state_dict[new_key] = value

        # Handle decoder upsamples
        elif key.startswith("decoder.upsamples."):
            # Convert to up_blocks
            parts = key.split(".")
            block_idx = int(parts[2])

            # Group residual blocks: the flat decoder layout interleaves
            # 3 resnets + 1 upsampler per stage, so flat indices 0-2, 4-6,
            # 8-10, 12-14 map to resnets 0-2 of up_blocks 0-3.
            if "residual" in key:
                if block_idx in [0, 1, 2]:
                    new_block_idx = 0
                    resnet_idx = block_idx
                elif block_idx in [4, 5, 6]:
                    new_block_idx = 1
                    resnet_idx = block_idx - 4
                elif block_idx in [8, 9, 10]:
                    new_block_idx = 2
                    resnet_idx = block_idx - 8
                elif block_idx in [12, 13, 14]:
                    new_block_idx = 3
                    resnet_idx = block_idx - 12
                else:
                    # Keep as is for other blocks
                    new_state_dict[key] = value
                    continue

                # Convert residual block naming
                if ".residual.0.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
                elif ".residual.2.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
                elif ".residual.2.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
                elif ".residual.3.gamma" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
                elif ".residual.6.bias" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
                elif ".residual.6.weight" in key:
                    new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
                else:
                    new_key = key

                new_state_dict[new_key] = value

            # Handle shortcut connections
            elif ".shortcut." in key:
                # NOTE(review): only flat index 4 carries a channel-changing
                # shortcut that becomes a resnet conv_shortcut — presumably
                # the first resnet of up_blocks.1; confirm against the model.
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")

                new_state_dict[new_key] = value

            # Handle upsamplers: flat indices 3/7/11 are the upsampler
            # modules that follow each group of three resnets.
            elif ".resample." in key or ".time_conv." in key:
                if block_idx == 3:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
                elif block_idx == 7:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
                elif block_idx == 11:
                    new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

                new_state_dict[new_key] = value
            else:
                new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                new_state_dict[new_key] = value
        else:
            # Keep other keys unchanged
            new_state_dict[key] = value

    # Build the model without allocating real weights, then assign the
    # converted tensors directly (strict=True catches any missed key).
    with init_empty_weights():
        vae = AutoencoderKLWan()
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae

def get_transformer_config() -> Dict[str, Any]:
    """Return the hard-coded configuration for the JoyImage Edit transformer.

    The checkpoint being converted has a single known architecture, so the
    config is fixed rather than read from the checkpoint.

    Returns:
        A dict with a single ``"diffusers_config"`` key holding the kwargs
        for ``JoyImageEditTransformer3DModel``.
    """
    # Fix: the previous annotation claimed Tuple[Dict[str, Any], ...] but a
    # plain dict is returned (and consumed via config["diffusers_config"]).
    config = {
        "diffusers_config": {
            "hidden_size": 4096,
            "in_channels": 16,
            "heads_num": 32,
            "mm_double_blocks_depth": 40,
            "out_channels": 16,
            "patch_size": [1, 2, 2],
            "rope_dim_list": [16, 56, 56],
            "text_states_dim": 4096,
            "rope_type": "rope",
            "dit_modulation_type": "wanx",
            "unpatchify_new": True,
            "rope_theta": 10000,
        },
    }
    return config
def convert_transformer(ckpt_path: str):
    """Load an original transformer checkpoint into a ``JoyImageEditTransformer3DModel``.

    Args:
        ckpt_path: Path to the original checkpoint. The weights may either be
            the state dict itself or be nested under a ``"model"`` key.

    Returns:
        A ``JoyImageEditTransformer3DModel`` with the checkpoint weights loaded.
    """
    checkpoint = torch.load(ckpt_path, weights_only=True)
    # Unwrap the optional "model" container; otherwise the checkpoint is
    # already the state dict.
    original_state_dict = checkpoint.get("model", checkpoint)

    config = get_transformer_config()
    # Instantiate on the meta device and assign the real tensors in place;
    # strict=True surfaces any unexpected/missing key immediately.
    with init_empty_weights():
        transformer = JoyImageEditTransformer3DModel(**config["diffusers_config"])
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer

def get_args():
    """Parse the command-line arguments for the conversion script.

    Returns:
        An ``argparse.Namespace`` with the checkpoint paths, output path,
        dtype, flow shift, and the ``save_pipeline`` flag.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    # Restrict to the keys of DTYPE_MAPPING so an invalid value fails at
    # parse time with a clear message instead of a KeyError later.
    parser.add_argument(
        "--dtype",
        default="bf16",
        choices=["fp32", "fp16", "bf16"],
        help="Torch dtype to save the transformer in.",
    )
    parser.add_argument("--flow_shift", type=float, default=7.0)
    return parser.parse_args()

# Maps the --dtype CLI string to the corresponding torch dtype.
DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}
if __name__ == "__main__":
args = get_args()
transformer = None
vae = None
dtype = DTYPE_MAPPING[args.dtype]

if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
assert args.text_encoder_path is not None
# assert args.tokenizer_path is not None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path)
vae = vae.to(dtype=dtype)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
processor = AutoProcessor.from_pretrained(args.text_encoder_path)
text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_path)
flow_shift = 1.5
scheduler = FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000, shift=flow_shift
)
transformer = transformer.to("cuda")
vae = vae.to("cuda")
pipe = JoyImageEditPipeline(
processor=processor,
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
).to("cuda")
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
processor.save_pretrained(f"{args.output_path}/processor")
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
"accelerate>=0.31.0",
"compel==0.1.8",
"datasets",
"einops",
Comment thread
dg845 marked this conversation as resolved.
Outdated
"filelock",
"flax>=0.4.1",
"ftfy",
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanImageTransformer2DModel",
"JoyImageEditTransformer3DModel",
"HunyuanVideo15Transformer3DModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
Expand Down Expand Up @@ -596,6 +597,7 @@
"LTXLatentUpsamplePipeline",
"LTXPipeline",
"LucyEditPipeline",
"JoyImageEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
Expand Down Expand Up @@ -1025,6 +1027,7 @@
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
HunyuanImageTransformer2DModel,
JoyImageEditTransformer3DModel,
HunyuanVideo15Transformer3DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
Expand Down Expand Up @@ -1359,6 +1362,7 @@
LTXLatentUpsamplePipeline,
LTXPipeline,
LucyEditPipeline,
JoyImageEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
LuminaPipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel", "JoyImageTransformer3DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
Expand Down Expand Up @@ -225,6 +226,7 @@
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
JoyImageEditTransformer3DModel,
HunyuanVideo15Transformer3DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
from .transformer_joyimage import JoyImageEditTransformer3DModel, JoyImageTransformer3DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_longcat_image import LongCatImageTransformer2DModel
from .transformer_ltx import LTXVideoTransformer3DModel
Expand Down
Loading
Loading