
Commit e897fa0

lstein and claude committed
feat: complete Qwen Image Edit pipeline with LoRA, GGUF, quantization, and UI support
Major additions:

- LoRA support: loader invocation, config detection, conversion utils, prefix constants, and LayerPatcher integration in denoise with sidecar patching for GGUF models
- Lightning LoRA: starter models (4-step and 8-step bf16), shift override parameter for the distilled sigma schedule
- GGUF fixes: correct base class (ModelLoader), zero_cond_t=True, correct in_channels (no /4 division)
- Denoise: use FlowMatchEulerDiscreteScheduler directly, proper CFG gating (skip negative when cfg<=1), reference latent pixel-space resize
- I2L: resize reference image to generation dimensions before VAE encoding
- Graph builder: wire LoRAs via collection loader, VAE-encode reference image as latents for spatial conditioning, pass shift/quantization params
- Frontend: shift override (checkbox+slider), LoRA graph wiring, scheduler hidden for Qwen Image Edit, model switching cleanup
- Starter model bundle for Qwen Image Edit
- LoRA config registered in discriminated union (factory.py)
- Downgrade transformers requirement back to >=4.56.0

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 94958c9 commit e897fa0

19 files changed

Lines changed: 803 additions & 68 deletions

File tree

invokeai/app/invocations/qwen_image_edit_denoise.py

Lines changed: 122 additions & 47 deletions
```diff
@@ -1,4 +1,5 @@
-from typing import Callable, Optional
+from contextlib import ExitStack
+from typing import Callable, Iterator, Optional, Tuple
 
 import torch
 import torchvision.transforms as tv_transforms
@@ -22,7 +23,12 @@
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
-from invokeai.backend.model_manager.taxonomy import BaseModelType
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.lora_conversions.qwen_image_edit_lora_constants import (
+    QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
+)
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import QwenImageEditConditioningInfo
@@ -70,6 +76,12 @@ class QwenImageEditDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
     steps: int = InputField(default=40, gt=0, description=FieldDescriptions.steps)
     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+    shift: Optional[float] = InputField(
+        default=None,
+        description="Override the sigma schedule shift. "
+        "When set, uses a fixed shift (e.g. 3.0 for Lightning LoRAs) instead of the default dynamic shifting. "
+        "Leave unset for the base model's default schedule.",
+    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -143,39 +155,47 @@ def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
             raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
         return cfg_scale
 
-    def _compute_sigmas(self, image_seq_len: int, num_steps: int) -> list[float]:
+    def _compute_sigmas(self, image_seq_len: int, num_steps: int, shift_override: float | None = None) -> list[float]:
         """Compute sigmas matching the diffusers FlowMatchEulerDiscreteScheduler.
 
-        Reproduces the full pipeline: linspace → exponential time_shift → stretch_shift_to_terminal → append 0.
+        When shift_override is None, reproduces the full base-model pipeline:
+        linspace → dynamic exponential time_shift → stretch_shift_to_terminal → append 0.
+
+        When shift_override is set (e.g. 3.0 for Lightning LoRAs), uses a fixed mu = log(shift)
+        with no shift_terminal stretching.
         """
         import math
 
         import numpy as np
 
-        # Scheduler config values (from scheduler_config.json)
-        base_shift = 0.5
-        max_shift = 0.9
-        base_image_seq_len = 256
-        max_image_seq_len = 8192
-        shift_terminal = 0.02
-
         # 1. Initial sigmas: N values from 1.0 to 1/N (same as diffusers pipeline)
         sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps).astype(np.float64)
 
-        # 2. Calculate mu (linear interpolation, matching diffusers calculate_shift)
-        m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
-        b = base_shift - m * base_image_seq_len
-        mu = image_seq_len * m + b
+        if shift_override is not None:
+            # Fixed shift (e.g. Lightning LoRA): mu = log(shift), no terminal stretching
+            mu = math.log(shift_override)
+        else:
+            # Dynamic shift from scheduler config
+            base_shift = 0.5
+            max_shift = 0.9
+            base_image_seq_len = 256
+            max_image_seq_len = 8192
+
+            m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
+            b = base_shift - m * base_image_seq_len
+            mu = image_seq_len * m + b
 
-        # 3. Exponential time shift
+        # 2. Exponential time shift
        sigmas = np.array([math.exp(mu) / (math.exp(mu) + (1.0 / s - 1.0)) for s in sigmas])
 
-        # 4. Stretch shift to terminal
-        one_minus = 1.0 - sigmas
-        scale_factor = one_minus[-1] / (1.0 - shift_terminal)
-        sigmas = 1.0 - (one_minus / scale_factor)
+        # 3. Stretch shift to terminal (only for base model schedule)
+        if shift_override is None:
+            shift_terminal = 0.02
+            one_minus = 1.0 - sigmas
+            scale_factor = one_minus[-1] / (1.0 - shift_terminal)
+            sigmas = 1.0 - (one_minus / scale_factor)
 
-        # 5. Append terminal 0
+        # 4. Append terminal 0
         sigmas = np.append(sigmas, 0.0)
 
         return sigmas.tolist()
```
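To make the schedule math above concrete, here is a small standalone sketch (illustrative, not part of the commit) evaluating both branches with the constants from the code. With the fixed Lightning shift of 3.0 the 4-step schedule lands on round numbers, and for a 1024×1024 generation (image_seq_len = 4096) the dynamic mu works out to roughly log 2, i.e. an effective shift of about 2:

```python
# Standalone check of _compute_sigmas' two branches (illustrative sketch).
import math

import numpy as np

def exponential_time_shift(num_steps: int, mu: float) -> list[float]:
    # linspace from 1.0 to 1/N, then sigma' = e^mu / (e^mu + 1/sigma - 1)
    sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
    return [math.exp(mu) / (math.exp(mu) + (1.0 / s - 1.0)) for s in sigmas]

# Fixed shift 3.0 (Lightning), 4 steps: steps concentrate at high noise.
print(exponential_time_shift(4, math.log(3.0)))  # [1.0, 0.9, 0.75, 0.5]

# Dynamic shift for 1024x1024: image_seq_len = (128 * 128) / 2**2 = 4096.
m = (0.9 - 0.5) / (8192 - 256)
mu = 4096 * m + (0.5 - m * 256)
print(round(mu, 4), round(math.exp(mu), 3))  # ~0.6935 -> effective shift ~2.0
```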
```diff
@@ -219,7 +239,10 @@ def _run_diffusion(self, context: InvocationContext):
 
         neg_prompt_embeds = None
         neg_prompt_mask = None
-        do_classifier_free_guidance = self.negative_conditioning is not None
+        # Match the diffusers pipeline: only enable CFG when cfg_scale > 1 AND negative conditioning is provided.
+        # With cfg_scale <= 1, the negative prediction is unused, so skip it entirely.
+        cfg_scale_value = self.cfg_scale if isinstance(self.cfg_scale, float) else self.cfg_scale[0]
+        do_classifier_free_guidance = self.negative_conditioning is not None and cfg_scale_value > 1.0
         if do_classifier_free_guidance:
             neg_prompt_embeds, neg_prompt_mask = self._load_text_conditioning(
                 context=context,
```
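The gating matters because of how the CFG combine behaves at cfg_scale = 1.0. Assuming the usual flow-matching formulation neg + cfg · (cond − neg) (the combine itself is outside this hunk), the negative prediction cancels exactly, so computing it is wasted model-forward time. A quick illustration:

```python
# Illustrative: at cfg == 1.0 the CFG combine collapses to the conditional
# prediction, so the negative forward pass would double cost for no effect.
import torch

noise_pred_cond = torch.randn(1, 4)
noise_pred_neg = torch.randn(1, 4)
combined = noise_pred_neg + 1.0 * (noise_pred_cond - noise_pred_neg)
assert torch.allclose(combined, noise_pred_cond, atol=1e-6)
```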
```diff
@@ -238,11 +261,40 @@ def _run_diffusion(self, context: InvocationContext):
         latent_height = self.height // LATENT_SCALE_FACTOR
         latent_width = self.width // LATENT_SCALE_FACTOR
         image_seq_len = (latent_height * latent_width) // (patch_size**2)
-        # Compute the shifted sigma schedule (N+1 values, last is 0.0).
-        # The sigmas serve both as the Euler step sizes AND the timestep conditioning to the model.
-        sigmas = self._compute_sigmas(image_seq_len, self.steps)
-        sigmas = clip_timestep_schedule_fractional(sigmas, self.denoising_start, self.denoising_end)
-        total_steps = len(sigmas) - 1
+
+        # Use the actual FlowMatchEulerDiscreteScheduler to compute sigmas/timesteps,
+        # exactly matching the diffusers pipeline.
+        from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+
+        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+            str(context.models.get_absolute_path(context.models.get_config(self.transformer.transformer)) / "scheduler"),
+            local_files_only=True,
+        )
+
+        import math
+        import numpy as np
+
+        if self.shift is not None:
+            # Lightning LoRA: fixed shift
+            mu = math.log(self.shift)
+        else:
+            # Default dynamic shifting from scheduler config
+            from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit import calculate_shift
+
+            mu = calculate_shift(
+                image_seq_len,
+                scheduler.config.get("base_image_seq_len", 256),
+                scheduler.config.get("max_image_seq_len", 4096),
+                scheduler.config.get("base_shift", 0.5),
+                scheduler.config.get("max_shift", 1.15),
+            )
+
+        init_sigmas = np.linspace(1.0, 1.0 / self.steps, self.steps).tolist()
+        scheduler.set_timesteps(sigmas=init_sigmas, mu=mu, device=device)
+
+        timesteps_sched = scheduler.timesteps
+        sigmas_sched = scheduler.sigmas
+        total_steps = len(timesteps_sched)
 
         cfg_scale = self._prepare_cfg_scale(total_steps)
 
```
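The same setup runs standalone against a stock diffusers scheduler, which can help when debugging schedules outside InvokeAI. A sketch assuming a recent diffusers release whose FlowMatchEulerDiscreteScheduler.set_timesteps accepts sigmas and mu (the node above loads the real config from the model's scheduler folder instead of constructing one):

```python
# Minimal standalone schedule setup (sketch; assumes mu-based dynamic shifting
# is available on FlowMatchEulerDiscreteScheduler, as used in the node above).
import math

import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler

steps = 8
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)
init_sigmas = np.linspace(1.0, 1.0 / steps, steps).tolist()
scheduler.set_timesteps(sigmas=init_sigmas, mu=math.log(3.0))  # fixed Lightning-style shift

print(scheduler.timesteps)  # N conditioning timesteps (sigma * 1000)
print(scheduler.sigmas)     # N+1 sigmas, terminal 0.0 appended
```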
```diff
@@ -276,14 +328,14 @@ def _run_diffusion(self, context: InvocationContext):
 
         # Prepare input latent image
         if init_latents is not None:
-            s_0 = sigmas[0]
+            s_0 = sigmas_sched[0].item()
             latents = s_0 * noise + (1.0 - s_0) * init_latents
         else:
             if self.denoising_start > 1e-5:
                 raise ValueError("denoising_start should be 0 when initial latents are not provided.")
             latents = noise
 
-        if len(sigmas) <= 1:
+        if total_steps <= 0:
             return latents
 
         # Pack latents into 2x2 patches: (B, C, H, W) -> (B, H/2*W/2, C*4)
```
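The s_0 blend is the rectified-flow forward interpolation: a latent at noise level sigma sits on the straight line x_sigma = sigma · noise + (1 − sigma) · x0, so seeding with sigmas_sched[0] resumes denoising partway along that line. Sketch:

```python
# Illustrative: s_0 = 1.0 starts from pure noise (text-to-image); s_0 < 1.0
# keeps a (1 - s_0) share of the init latents (image-to-image).
import torch

init_latents = torch.randn(1, 16, 128, 128)
noise = torch.randn_like(init_latents)

start = 1.0 * noise + 0.0 * init_latents  # s_0 = 1.0: pure noise
assert torch.equal(start, noise)

start = 0.6 * noise + 0.4 * init_latents  # s_0 = 0.6: 40% of the way to data
```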
```diff
@@ -292,7 +344,7 @@ def _run_diffusion(self, context: InvocationContext):
         # Pack reference image latents and concatenate along the sequence dimension.
         # The edit transformer always expects [noisy_patches ; ref_patches] in its sequence.
         if ref_latents is not None:
-            _, _, rh, rw = ref_latents.shape
+            _, ref_ch, rh, rw = ref_latents.shape
             if rh != latent_height or rw != latent_width:
                 ref_latents = torch.nn.functional.interpolate(
                     ref_latents, size=(latent_height, latent_width), mode="bilinear"
@@ -329,23 +381,38 @@ def _run_diffusion(self, context: InvocationContext):
                 step=0,
                 order=1,
                 total_steps=total_steps,
-                timestep=int(sigmas[0] * 1000),
+                timestep=int(timesteps_sched[0].item()) if len(timesteps_sched) > 0 else 0,
                 latents=self._unpack_latents(latents, latent_height, latent_width),
             ),
         )
 
         noisy_seq_len = latents.shape[1]
 
-        with transformer_info.model_on_device() as (_, transformer):
+        # Determine if the model is quantized — GGUF models need sidecar patching for LoRAs
+        transformer_config = context.models.get_config(self.transformer.transformer)
+        model_is_quantized = transformer_config.format in (ModelFormat.GGUFQuantized,)
+
+        with ExitStack() as exit_stack:
+            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
             assert isinstance(transformer, QwenImageTransformer2DModel)
 
-            for step_idx in tqdm(range(total_steps)):
-                sigma_curr = sigmas[step_idx]
-                sigma_next = sigmas[step_idx + 1]
+            # Apply LoRA patches to the transformer
+            exit_stack.enter_context(
+                LayerPatcher.apply_smart_model_patches(
+                    model=transformer,
+                    patches=self._lora_iterator(context),
+                    prefix=QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
+                    dtype=inference_dtype,
+                    cached_weights=cached_weights,
+                    force_sidecar_patching=model_is_quantized,
+                )
+            )
+
+            scheduler.set_begin_index(0)
 
-                # The model receives the shifted sigma as its time conditioning.
-                # Diffusers stores timesteps = sigma * 1000 and passes timestep / 1000.
-                timestep = torch.tensor([sigma_curr], device=device).expand(1).to(inference_dtype)
+            for step_idx, t in enumerate(tqdm(timesteps_sched)):
+                # The pipeline passes timestep / 1000 to the transformer
+                timestep = t.expand(latents.shape[0]).to(inference_dtype)
 
                 # Concatenate noisy and reference patches along the sequence dim
                 model_input = torch.cat([latents, ref_latents_packed], dim=1)
```
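The force_sidecar_patching flag is why GGUF models can take LoRAs at all: quantized weights cannot be rewritten in place, so instead of merging the low-rank delta into the weight matrix, the patcher keeps it alongside and applies it at runtime. A conceptual contrast (names here are illustrative, not the actual LayerPatcher internals):

```python
# Conceptual sketch of merged vs sidecar LoRA application (illustrative names).
import torch

def merged_patch(linear: torch.nn.Linear, A: torch.Tensor, B: torch.Tensor, scale: float) -> None:
    # Merged: rewrite the weight in place, W' = W + scale * (B @ A).
    # Needs mutable full-precision weights -- impossible on GGUF tensors.
    linear.weight.data += scale * (B @ A)

def sidecar_forward(x: torch.Tensor, base_forward, A: torch.Tensor, B: torch.Tensor, scale: float) -> torch.Tensor:
    # Sidecar: leave the (quantized) weights untouched and add the low-rank
    # delta to the activation instead: y = f(x) + scale * (x @ A.T) @ B.T.
    return base_forward(x) + scale * (x @ A.T) @ B.T
```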
```diff
@@ -354,7 +421,7 @@ def _run_diffusion(self, context: InvocationContext):
                     hidden_states=model_input,
                     encoder_hidden_states=pos_prompt_embeds,
                     encoder_hidden_states_mask=pos_prompt_mask,
-                    timestep=timestep,
+                    timestep=timestep / 1000,
                     img_shapes=img_shapes,
                     return_dict=False,
                 )[0]
```
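The / 1000 is the unit conversion the removed comments in the old loop described: the scheduler stores timesteps = sigmas × num_train_timesteps (1000 here), while the transformer is conditioned on the sigma itself:

```python
# The scheduler's timestep for sigma = 0.9 is 900; dividing by the training
# horizon (1000) recovers the sigma the transformer expects.
import math

sigma = 0.9
timestep = sigma * 1000  # what scheduler.timesteps holds
assert math.isclose(timestep / 1000, sigma)
```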
```diff
@@ -366,7 +433,7 @@ def _run_diffusion(self, context: InvocationContext):
                         hidden_states=model_input,
                         encoder_hidden_states=neg_prompt_embeds,
                         encoder_hidden_states_mask=neg_prompt_mask,
-                        timestep=timestep,
+                        timestep=timestep / 1000,
                         img_shapes=img_shapes,
                         return_dict=False,
                     )[0]
@@ -376,14 +443,11 @@ def _run_diffusion(self, context: InvocationContext):
                 else:
                     noise_pred = noise_pred_cond
 
-                latents_dtype = latents.dtype
-                latents = latents.to(dtype=torch.float32)
-                dt = sigma_next - sigma_curr
-                latents = latents + dt * noise_pred
-                latents = latents.to(dtype=latents_dtype)
+                # Use the scheduler's step method — exactly matching the pipeline
+                latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
                 if inpaint_extension is not None:
-                    # Unpack to 4D for inpaint merging, then repack
+                    sigma_next = sigmas_sched[step_idx + 1].item()
                     latents_4d = self._unpack_latents(latents, latent_height, latent_width)
                     latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(latents_4d, sigma_next)
                     latents = self._pack_latents(latents_4d, 1, out_channels, latent_height, latent_width)
```
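Delegating to scheduler.step() doesn't change the update rule: for this scheduler a step is the same explicit Euler move the deleted lines performed, x ← x + (sigma_next − sigma_curr) · v, with the float32 upcast and step-index bookkeeping handled internally. A sketch against a default-configured scheduler:

```python
# Sketch: scheduler.step() matches the removed manual Euler update
# (assumes a stock FlowMatchEulerDiscreteScheduler with default config).
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=4)

latents = torch.randn(1, 8)
noise_pred = torch.randn(1, 8)
t = scheduler.timesteps[0]

manual = latents + (scheduler.sigmas[1] - scheduler.sigmas[0]) * noise_pred
stepped = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
assert torch.allclose(stepped, manual, atol=1e-5)
```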
```diff
@@ -393,7 +457,7 @@ def _run_diffusion(self, context: InvocationContext):
                         step=step_idx + 1,
                         order=1,
                         total_steps=total_steps,
-                        timestep=int(sigma_curr * 1000),
+                        timestep=int(t.item()),
                         latents=self._unpack_latents(latents, latent_height, latent_width),
                     ),
                 )
@@ -408,3 +472,14 @@ def step_callback(state: PipelineIntermediateState) -> None:
             context.util.sd_step_callback(state, BaseModelType.QwenImageEdit)
 
         return step_callback
+
+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
+        """Iterate over LoRA models to apply to the transformer."""
+        for lora in self.transformer.loras:
+            lora_info = context.models.load(lora.lora)
+            if not isinstance(lora_info.model, ModelPatchRaw):
+                raise TypeError(
+                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}."
+                )
+            yield (lora_info.model, lora.weight)
+            del lora_info
```

invokeai/app/invocations/qwen_image_edit_image_to_latents.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -1,6 +1,7 @@
 import einops
 import torch
 from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
+from PIL import Image as PILImage
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.fields import (
@@ -32,6 +33,14 @@ class QwenImageEditImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBo
 
     image: ImageField = InputField(description="The image to encode.")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
+    width: int | None = InputField(
+        default=None,
+        description="Resize the image to this width before encoding. If not set, encodes at the image's original size.",
+    )
+    height: int | None = InputField(
+        default=None,
+        description="Resize the image to this height before encoding. If not set, encodes at the image's original size.",
+    )
 
     @staticmethod
     def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
@@ -69,6 +78,11 @@ def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tenso
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)
 
+        # If target dimensions are specified, resize the image BEFORE encoding
+        # (matching the diffusers pipeline which resizes in pixel space, not latent space).
+        if self.width is not None and self.height is not None:
+            image = image.convert("RGB").resize((self.width, self.height), resample=PILImage.LANCZOS)
+
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
```
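Resizing in pixel space before encoding is the key difference from the denoise node's fallback, which bilinearly interpolates already-encoded latents: the VAE then emits latents natively on the generation grid instead of ones smeared by latent-space interpolation. Shape bookkeeping for an 8× spatial VAE (sketch):

```python
# Illustrative shape check: a pixel-space resize lands the encode exactly on
# the generation's latent grid, so no latent interpolation is needed later.
from PIL import Image

gen_w, gen_h = 1024, 1024
ref = Image.new("RGB", (1536, 640))              # reference image of any size
ref = ref.resize((gen_w, gen_h), Image.LANCZOS)  # resize BEFORE the VAE
latent_w, latent_h = gen_w // 8, gen_h // 8      # 8x spatial compression
assert (latent_w, latent_h) == (128, 128)
```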
