[feat] JoyAI-JoyImage-Edit support #13444
Conversation
yiyixuxu
left a comment
thanks for the PR! I left some initial feedback
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
```
can we refactor out the einops usage? it is not a diffusers dependency
ok, I will refactor and remove einops
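For reference, the `rearrange` calls used later in the diff can be replaced with plain `torch` reshapes. This is an illustrative sketch with stand-in shapes, not code from the PR:

```python
import torch

# einops-free equivalent of the pattern used in the diff:
#   rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=heads_num)
# B/L/K/H/D values here are arbitrary stand-ins for illustration
B, L, K, H, D = 2, 4, 3, 3, 8
qkv = torch.randn(B, L, K * H * D)

# split the fused last dim into (K, H, D), then move K to the front
q, k, v = qkv.view(B, L, K, H, D).permute(2, 0, 1, 3, 4).unbind(0)

assert q.shape == (B, L, H, D)
# q is the K=0 slice of the unfused tensor, matching einops semantics
assert torch.equal(q, qkv.view(B, L, K, H, D)[:, :, 0])
```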
```python
    return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
```
ohh what's going on here? is this some legacy code? can we remove it?
We first developed JoyImage, then trained JoyImage-Edit on top of it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit inherits from it; we will also open-source JoyImage in the future.
The two models share essentially the same Transformer 3D architecture. Since each pipeline requires its own Transformer class, we implemented the relationship via inheritance.
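The inheritance pattern described above can be sketched like this. This is a toy illustration, not the real model: the layer names and sizes are invented, and only the subclassing relationship mirrors the PR.

```python
import torch
import torch.nn as nn

# Toy base model standing in for the JoyImage transformer
class JoyImageTransformer3DModel(nn.Module):
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


# The edit model reuses the base wholesale; edit-specific
# conditioning would be added by overriding __init__/forward
class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
    pass
```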
```python
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
```
Suggested change:

```diff
-img_qkv = self.img_attn_qkv(img_modulated)
-img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-img_q = self.img_attn_q_norm(img_q).to(img_v)
-img_k = self.img_attn_k_norm(img_k).to(img_v)
-if vis_freqs_cis is not None:
-    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
-txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
-txt_qkv = self.txt_attn_qkv(txt_modulated)
-txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
-txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
-if txt_freqs_cis is not None:
-    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)
-q = torch.cat((img_q, txt_q), dim=1)
-k = torch.cat((img_k, txt_k), dim=1)
-v = torch.cat((img_v, txt_v), dim=1)
-attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
-img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
+attn_output, text_attn_output = self.attn(...)
```
can we refactor the attention implementation to follow diffusers style?
basically you need to move all the layers used in the attention calculation into a JoyImageAttention module (similar to FluxAttention: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)
also create a JoyImageAttnProcessor (see FluxAttnProcessor as an example, I think it is the same: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75)
Thanks for the reminder. I'll clean up this messy code.
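The attention-module/processor split the reviewer asks for can be sketched as follows. This is an illustrative outline only: the class names follow the review suggestion, but the layer layout and shapes are assumptions (for instance, the real code keeps separate image/text QKV projections, while this sketch uses a single joint projection for brevity).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class JoyImageAttnProcessor:
    """Stateless callable holding the attention math, diffusers-style."""

    def __call__(self, attn: "JoyImageAttention", img: torch.Tensor, txt: torch.Tensor):
        B = img.shape[0]
        # project the concatenated image+text sequence to fused QKV
        qkv = attn.to_qkv(torch.cat((img, txt), dim=1))
        q, k, v = qkv.chunk(3, dim=-1)
        # (B, L, dim) -> (B, heads, L, head_dim) for SDPA
        q = q.view(B, -1, attn.heads, attn.head_dim).transpose(1, 2)
        k = k.view(B, -1, attn.heads, attn.head_dim).transpose(1, 2)
        v = v.view(B, -1, attn.heads, attn.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(B, -1, attn.heads * attn.head_dim)
        # split back into image and text streams
        return out[:, : img.shape[1]], out[:, img.shape[1] :]


class JoyImageAttention(nn.Module):
    """Owns the parameters; delegates the math to a swappable processor."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.processor = JoyImageAttnProcessor()

    def forward(self, img: torch.Tensor, txt: torch.Tensor):
        return self.processor(self, img, txt)
```

The point of the split is that custom backends (flash attention, quantized attention) can be swapped in by replacing `processor` without touching the module's weights.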
```python
def load_modulation(modulate_type: str, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
    factory_kwargs = {"dtype": dtype, "device": device}
    if modulate_type == "wanx":
        return ModulateWan(hidden_size, factor, **factory_kwargs)
    if modulate_type == "adaLN":
        return ModulateDiT(hidden_size, factor, act_layer, **factory_kwargs)
    if modulate_type == "jdx":
        return ModulateX(hidden_size, factor, **factory_kwargs)
    raise ValueError(f"Unknown modulation type: {modulate_type}.")
```
Suggested change:

```diff
-def load_modulation(modulate_type: str, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
-    factory_kwargs = {"dtype": dtype, "device": device}
-    if modulate_type == "wanx":
-        return ModulateWan(hidden_size, factor, **factory_kwargs)
-    if modulate_type == "adaLN":
-        return ModulateDiT(hidden_size, factor, act_layer, **factory_kwargs)
-    if modulate_type == "jdx":
-        return ModulateX(hidden_size, factor, **factory_kwargs)
-    raise ValueError(f"Unknown modulation type: {modulate_type}.")
```
```python
class ModulateX(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor):
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
```
Suggested change:

```diff
-class ModulateX(nn.Module):
-    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
-        super().__init__()
-        self.factor = factor
-
-    def forward(self, x: torch.Tensor):
-        if len(x.shape) != 3:
-            x = x.unsqueeze(1)
-        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
```
```python
class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.factor = factor
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)
```
Suggested change:

```diff
-class ModulateDiT(nn.Module):
-    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
-        factory_kwargs = {"dtype": dtype, "device": device}
-        super().__init__()
-        self.factor = factor
-        self.act = act_layer()
-        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
-        nn.init.zeros_(self.linear.weight)
-        nn.init.zeros_(self.linear.bias)
-
-    def forward(self, x: torch.Tensor):
-        return self.linear(self.act(x)).chunk(self.factor, dim=-1)
```
is ModulateWan the one used in the model? if so, let's remove ModulateDiT and ModulateX
```python
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
```
Suggested change:

```diff
-self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
+self.img_mod = JoyImageModulate(...)
```
let's remove the load_modulation function and use the layer directly; better to rename it to JoyImageModulate too
Ok, I will refactor the modulation and use ModulateWan.
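For context, the `modulate` helper used alongside these layers presumably applies a per-sample shift and scale broadcast over the sequence dimension. This is a guess at its semantics, motivated by the zero-initialized `ModulateDiT.linear` above (zero init makes the modulation an identity at the start of training, the standard adaLN-Zero trick):

```python
import torch

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # assumed semantics: x is (B, L, D); shift/scale are per-sample (B, D)
    # and broadcast over the sequence dimension L
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

x = torch.randn(2, 16, 64)
zeros = torch.zeros(2, 64)
# zero shift/scale (what zero-initialized modulation produces) is a no-op
assert torch.allclose(modulate(x, zeros, zeros), x)
```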
New `model="joyai-edit"` on /v1/image-edit and /v2/image-edit, routed to a separate FastAPI sidecar on 127.0.0.1:8092 that runs JoyImageEditPipeline from the Moran232/diffusers fork + transformers 4.57.1. Process isolation is needed because the fork's diffusers core registry patches cannot be vendored (PR huggingface/diffusers#13444 pending) and transformers 4.57.x is incompatible with our 5.3.0 stack.

Phase 0 VRAM measurement: 50.3 GB resident, 65.5 GB peak reserved at 1024² / 30 steps (well under the 80 GB gate). Passed.

- `joyai_client.py` (NEW, 167 lines): thin httpx wrapper with a per-call short-lived AsyncClient, split timeouts (180 s edit / 60 s mgmt), and HTTPStatus→JoyAIError mapping. Singleton `joyai` exported.
- `config.py`: `JOYAI_SIDECAR_URL` (default http://127.0.0.1:8092) and `LOAD_JOYAI` env flag. Off by default.
- `server.py`: three-tenant swap protocol replaces the two-tenant v1.1.4 helpers. New `_last_gpu_tenant` tracker + `_evict_other_tenants(new)` helper. All three `_ensure_*_ready()` helpers are now `async def`; 13 call sites updated across _dispatch_job and the v1 sync handlers. The IMAGE_EDIT dispatch arm routes `model=="joyai-edit"` to joyai_client and validates len(image_paths)==1 (422 otherwise). Lifespan health-probes the sidecar when LOAD_JOYAI=1 (non-blocking; joyai-edit returns 503 if unreachable).
- `flux_manager.py`: pre-existing bug fix. _edit() hardcoded ensure_model("flux2-klein"), silently ignoring the dispatcher's `model` kwarg; it now accepts and respects `model`. `guidance_scale` is now conditional on model != "flux2-klein" (Klein strips CFG, Dev uses it).
- `tests/test_joyai_client.py` (NEW, 7 tests) + `tests/test_validation.py` (+3 tests): 89 tests passing (was 79).
- Docs: API.md, QUICKSTART.md, README.md, CLAUDE.md, AGENTS.md all updated with the joyai-edit model entry, three-tenant swap diagram, latency table, sidecar location/port, LOAD_JOYAI env var, and v1.1.8 changelog entry.
Out-of-tree (not committed here, installed separately):
- /mnt/nvme-1/servers/joyai-sidecar/ (sidecar venv + sidecar.py + run.sh)
- ~/.config/systemd/user/joyai-sidecar.service

Smoke-tested end-to-end: upload → /v2/image-edit joyai-edit → SSE stream (phase denoising → encoding → None) → fetch WEBP result (352 KB, 91 s wall clock for 20 steps at 1024²). The three-tenant swap evicted LTX and reloaded it cleanly via _evict_other_tenants.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Description
We are the JoyAI Team, and this is the Diffusers implementation of the JoyAI-Image-Edit model.
GitHub repository: https://github.com/jd-opensource/JoyAI-Image
Hugging Face model: https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers
Original open-source weights: https://huggingface.co/jdopensource/JoyAI-Image-Edit
Fixes #13430
Model Overview
JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).
Key Features
Image edit examples