@@ -1,4 +1,5 @@
-from typing import Callable, Optional
+from contextlib import ExitStack
+from typing import Callable, Iterator, Optional, Tuple
 
 import torch
 import torchvision.transforms as tv_transforms
@@ -22,7 +23,12 @@
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
-from invokeai.backend.model_manager.taxonomy import BaseModelType
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.lora_conversions.qwen_image_edit_lora_constants import (
+    QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
+)
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import QwenImageEditConditioningInfo
@@ -70,6 +76,12 @@ class QwenImageEditDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
     steps: int = InputField(default=40, gt=0, description=FieldDescriptions.steps)
     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+    shift: Optional[float] = InputField(
+        default=None,
+        description="Override the sigma schedule shift. "
+        "When set, uses a fixed shift (e.g. 3.0 for Lightning LoRAs) instead of the default dynamic shifting. "
+        "Leave unset for the base model's default schedule.",
+    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -143,39 +155,47 @@ def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
143155 raise ValueError (f"Invalid CFG scale type: { type (self .cfg_scale )} " )
144156 return cfg_scale
145157
146- def _compute_sigmas (self , image_seq_len : int , num_steps : int ) -> list [float ]:
158+ def _compute_sigmas (self , image_seq_len : int , num_steps : int , shift_override : float | None = None ) -> list [float ]:
147159 """Compute sigmas matching the diffusers FlowMatchEulerDiscreteScheduler.
148160
149- Reproduces the full pipeline: linspace → exponential time_shift → stretch_shift_to_terminal → append 0.
161+ When shift_override is None, reproduces the full base-model pipeline:
162+ linspace → dynamic exponential time_shift → stretch_shift_to_terminal → append 0.
163+
164+ When shift_override is set (e.g. 3.0 for Lightning LoRAs), uses a fixed mu = log(shift)
165+ with no shift_terminal stretching.
150166 """
151167 import math
152168
153169 import numpy as np
154170
155- # Scheduler config values (from scheduler_config.json)
156- base_shift = 0.5
157- max_shift = 0.9
158- base_image_seq_len = 256
159- max_image_seq_len = 8192
160- shift_terminal = 0.02
161-
162171 # 1. Initial sigmas: N values from 1.0 to 1/N (same as diffusers pipeline)
163172 sigmas = np .linspace (1.0 , 1.0 / num_steps , num_steps ).astype (np .float64 )
164173
165- # 2. Calculate mu (linear interpolation, matching diffusers calculate_shift)
166- m = (max_shift - base_shift ) / (max_image_seq_len - base_image_seq_len )
167- b = base_shift - m * base_image_seq_len
168- mu = image_seq_len * m + b
174+ if shift_override is not None :
175+ # Fixed shift (e.g. Lightning LoRA): mu = log(shift), no terminal stretching
176+ mu = math .log (shift_override )
177+ else :
178+ # Dynamic shift from scheduler config
179+ base_shift = 0.5
180+ max_shift = 0.9
181+ base_image_seq_len = 256
182+ max_image_seq_len = 8192
183+
184+ m = (max_shift - base_shift ) / (max_image_seq_len - base_image_seq_len )
185+ b = base_shift - m * base_image_seq_len
186+ mu = image_seq_len * m + b
169187
170- # 3 . Exponential time shift
188+ # 2 . Exponential time shift
171189 sigmas = np .array ([math .exp (mu ) / (math .exp (mu ) + (1.0 / s - 1.0 )) for s in sigmas ])
172190
173- # 4. Stretch shift to terminal
174- one_minus = 1.0 - sigmas
175- scale_factor = one_minus [- 1 ] / (1.0 - shift_terminal )
176- sigmas = 1.0 - (one_minus / scale_factor )
191+ # 3. Stretch shift to terminal (only for base model schedule)
192+ if shift_override is None :
193+ shift_terminal = 0.02
194+ one_minus = 1.0 - sigmas
195+ scale_factor = one_minus [- 1 ] / (1.0 - shift_terminal )
196+ sigmas = 1.0 - (one_minus / scale_factor )
177197
178- # 5 . Append terminal 0
198+ # 4 . Append terminal 0
179199 sigmas = np .append (sigmas , 0.0 )
180200
181201 return sigmas .tolist ()
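The schedule logic in `_compute_sigmas` can be sanity-checked outside the invocation. The sketch below is a standalone re-derivation using only `math` and `numpy`; the constants are the defaults from the hunk above, and the helper name `sketch_sigmas` is mine, not part of the codebase. With a fixed shift of 3.0 the exponential time shift reduces to sigma' = 3 * sigma / (2 * sigma + 1), which spends more of the schedule at high noise.

import math

import numpy as np


def sketch_sigmas(num_steps: int, image_seq_len: int, shift_override: float | None = None) -> list[float]:
    """Standalone re-derivation of the schedule described in the docstring above (illustrative only)."""
    sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps).astype(np.float64)

    if shift_override is not None:
        # Fixed shift: exp(mu) = shift, so sigma' = shift * sigma / ((shift - 1) * sigma + 1).
        mu = math.log(shift_override)
    else:
        # Dynamic shift: mu interpolated linearly in the image sequence length.
        m = (0.9 - 0.5) / (8192 - 256)
        mu = image_seq_len * m + (0.5 - m * 256)

    # Exponential time shift: sigma' = exp(mu) / (exp(mu) + (1 / sigma - 1)).
    sigmas = np.array([math.exp(mu) / (math.exp(mu) + (1.0 / s - 1.0)) for s in sigmas])

    if shift_override is None:
        # Stretch so the smallest sigma lands on shift_terminal = 0.02 (base-model schedule only).
        one_minus = 1.0 - sigmas
        sigmas = 1.0 - one_minus / (one_minus[-1] / (1.0 - 0.02))

    return np.append(sigmas, 0.0).tolist()


# With shift=3.0, sigma=0.5 maps to 3 * 0.5 / (2 * 0.5 + 1) = 0.75.
print(sketch_sigmas(num_steps=4, image_seq_len=4096))
print(sketch_sigmas(num_steps=4, image_seq_len=4096, shift_override=3.0))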
@@ -219,7 +239,10 @@ def _run_diffusion(self, context: InvocationContext):
 
         neg_prompt_embeds = None
         neg_prompt_mask = None
-        do_classifier_free_guidance = self.negative_conditioning is not None
+        # Match the diffusers pipeline: only enable CFG when cfg_scale > 1 AND negative conditioning is provided.
+        # With cfg_scale <= 1, the negative prediction is unused, so skip it entirely.
+        cfg_scale_value = self.cfg_scale if isinstance(self.cfg_scale, float) else self.cfg_scale[0]
+        do_classifier_free_guidance = self.negative_conditioning is not None and cfg_scale_value > 1.0
         if do_classifier_free_guidance:
             neg_prompt_embeds, neg_prompt_mask = self._load_text_conditioning(
                 context=context,
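The gating above relies on how classifier-free guidance is combined later in the loop (the standard `uncond + scale * (cond - uncond)` form; the exact combination line sits outside this hunk). At a scale of exactly 1.0 the result is just the conditional prediction, so computing the negative branch would be wasted work. A minimal numeric check:

import torch

# Classifier-free guidance: pred = uncond + scale * (cond - uncond).
# At scale == 1.0 this collapses to `cond`, which is why the negative branch is skipped.
cond = torch.tensor([0.8, -0.2])
uncond = torch.tensor([0.5, 0.1])

for scale in (1.0, 2.5):
    pred = uncond + scale * (cond - uncond)
    print(scale, pred)
# 1.0 -> tensor([ 0.8000, -0.2000])  (matches cond)
# 2.5 -> tensor([ 1.2500, -0.6500])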
@@ -238,11 +261,40 @@ def _run_diffusion(self, context: InvocationContext):
         latent_height = self.height // LATENT_SCALE_FACTOR
         latent_width = self.width // LATENT_SCALE_FACTOR
         image_seq_len = (latent_height * latent_width) // (patch_size**2)
-        # Compute the shifted sigma schedule (N+1 values, last is 0.0).
-        # The sigmas serve both as the Euler step sizes AND the timestep conditioning to the model.
-        sigmas = self._compute_sigmas(image_seq_len, self.steps)
-        sigmas = clip_timestep_schedule_fractional(sigmas, self.denoising_start, self.denoising_end)
-        total_steps = len(sigmas) - 1
+
+        # Use the actual FlowMatchEulerDiscreteScheduler to compute sigmas/timesteps,
+        # exactly matching the diffusers pipeline.
+        from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
+
+        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+            str(context.models.get_absolute_path(context.models.get_config(self.transformer.transformer)) / "scheduler"),
+            local_files_only=True,
+        )
+
+        import math
+        import numpy as np
+
+        if self.shift is not None:
+            # Lightning LoRA: fixed shift
+            mu = math.log(self.shift)
+        else:
+            # Default dynamic shifting from scheduler config
+            from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit import calculate_shift
+
+            mu = calculate_shift(
+                image_seq_len,
+                scheduler.config.get("base_image_seq_len", 256),
+                scheduler.config.get("max_image_seq_len", 4096),
+                scheduler.config.get("base_shift", 0.5),
+                scheduler.config.get("max_shift", 1.15),
+            )
+
+        init_sigmas = np.linspace(1.0, 1.0 / self.steps, self.steps).tolist()
+        scheduler.set_timesteps(sigmas=init_sigmas, mu=mu, device=device)
+
+        timesteps_sched = scheduler.timesteps
+        sigmas_sched = scheduler.sigmas
+        total_steps = len(timesteps_sched)
 
         cfg_scale = self._prepare_cfg_scale(total_steps)
 
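For context on the two tensors pulled from the scheduler here: after `set_timesteps`, `scheduler.sigmas` holds N+1 values ending in 0.0 (used for the Euler step sizes) and `scheduler.timesteps` holds N values equal to the shifted sigmas scaled by `num_train_timesteps` (1000 by default), which is why the denoise loop later passes `timestep / 1000` to the transformer. A small standalone sketch, using the scheduler's default config with dynamic shifting enabled and `mu = log(3.0)` standing in for a Lightning-style fixed shift rather than the checkpoint's scheduler config:

import math

import numpy as np
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

steps = 8
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)
init_sigmas = np.linspace(1.0, 1.0 / steps, steps).tolist()
scheduler.set_timesteps(sigmas=init_sigmas, mu=math.log(3.0), device="cpu")

print(scheduler.sigmas)     # 9 values: shifted sigmas with a terminal 0.0 appended
print(scheduler.timesteps)  # 8 values: each shifted sigma * 1000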
@@ -276,14 +328,14 @@ def _run_diffusion(self, context: InvocationContext):
 
         # Prepare input latent image
         if init_latents is not None:
-            s_0 = sigmas[0]
+            s_0 = sigmas_sched[0].item()
             latents = s_0 * noise + (1.0 - s_0) * init_latents
         else:
             if self.denoising_start > 1e-5:
                 raise ValueError("denoising_start should be 0 when initial latents are not provided.")
             latents = noise
 
-        if len(sigmas) <= 1:
+        if total_steps <= 0:
             return latents
 
         # Pack latents into 2x2 patches: (B, C, H, W) -> (B, H/2*W/2, C*4)
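The 2x2 packing mentioned in this hunk turns a `(B, C, H, W)` latent into a `(B, H/2 * W/2, C * 4)` sequence. `_pack_latents` / `_unpack_latents` themselves are outside the diff, so the sketch below is a generic reshape/permute that produces that layout and round-trips cleanly, not necessarily the exact implementation in this file.

import torch


def pack_latents(x: torch.Tensor) -> torch.Tensor:
    """(B, C, H, W) -> (B, H/2 * W/2, C * 4) by grouping each 2x2 spatial patch."""
    b, c, h, w = x.shape
    x = x.view(b, c, h // 2, 2, w // 2, 2)           # split H and W into (H/2, 2) and (W/2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)                  # (B, H/2, W/2, C, 2, 2)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)  # flatten patches into the sequence dim


def unpack_latents(x: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Inverse of pack_latents: (B, H/2 * W/2, C * 4) -> (B, C, H, W)."""
    b, _, packed_c = x.shape
    c = packed_c // 4
    x = x.view(b, h // 2, w // 2, c, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, c, h, w)


latents = torch.randn(1, 16, 64, 64)
packed = pack_latents(latents)
assert packed.shape == (1, 32 * 32, 64)
assert torch.equal(unpack_latents(packed, 64, 64), latents)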
@@ -292,7 +344,7 @@ def _run_diffusion(self, context: InvocationContext):
 
         # Pack reference image latents and concatenate along the sequence dimension.
         # The edit transformer always expects [noisy_patches; ref_patches] in its sequence.
-        _, _, rh, rw = ref_latents.shape
+        _, ref_ch, rh, rw = ref_latents.shape
         if rh != latent_height or rw != latent_width:
             ref_latents = torch.nn.functional.interpolate(
                 ref_latents, size=(latent_height, latent_width), mode="bilinear"
@@ -329,23 +381,38 @@ def _run_diffusion(self, context: InvocationContext):
                 step=0,
                 order=1,
                 total_steps=total_steps,
-                timestep=int(sigmas[0] * 1000),
+                timestep=int(timesteps_sched[0].item()) if len(timesteps_sched) > 0 else 0,
                 latents=self._unpack_latents(latents, latent_height, latent_width),
             ),
         )
 
         noisy_seq_len = latents.shape[1]
 
-        with transformer_info.model_on_device() as (_, transformer):
+        # Determine if the model is quantized; GGUF models need sidecar patching for LoRAs.
+        transformer_config = context.models.get_config(self.transformer.transformer)
+        model_is_quantized = transformer_config.format in (ModelFormat.GGUFQuantized,)
+
+        with ExitStack() as exit_stack:
+            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
             assert isinstance(transformer, QwenImageTransformer2DModel)
 
-            for step_idx in tqdm(range(total_steps)):
-                sigma_curr = sigmas[step_idx]
-                sigma_next = sigmas[step_idx + 1]
+            # Apply LoRA patches to the transformer
+            exit_stack.enter_context(
+                LayerPatcher.apply_smart_model_patches(
+                    model=transformer,
+                    patches=self._lora_iterator(context),
+                    prefix=QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
+                    dtype=inference_dtype,
+                    cached_weights=cached_weights,
+                    force_sidecar_patching=model_is_quantized,
+                )
+            )
+
+            scheduler.set_begin_index(0)
 
-                # The model receives the shifted sigma as its time conditioning.
-                # Diffusers stores timesteps = sigma * 1000 and passes timestep / 1000.
-                timestep = torch.tensor([sigma_curr], device=device).expand(1).to(inference_dtype)
+            for step_idx, t in enumerate(tqdm(timesteps_sched)):
+                # The pipeline passes timestep / 1000 to the transformer
+                timestep = t.expand(latents.shape[0]).to(inference_dtype)
 
                 # Concatenate noisy and reference patches along the sequence dim
                 model_input = torch.cat([latents, ref_latents_packed], dim=1)
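On `force_sidecar_patching`: LoRA deltas cannot be merged into GGUF-quantized weights in place, so for quantized transformers the patch has to run as a sidecar next to the frozen layer at forward time instead of rewriting its weights. A toy illustration of the two strategies in plain torch (not InvokeAI's `LayerPatcher`):

import torch
import torch.nn as nn

base = nn.Linear(8, 8, bias=False)
down = torch.randn(4, 8) * 0.01   # LoRA "A" matrix
up = torch.randn(8, 4) * 0.01     # LoRA "B" matrix
weight = 0.75

x = torch.randn(2, 8)

# Direct patching: merge the LoRA delta into the layer's weights (requires full-precision weights).
merged = nn.Linear(8, 8, bias=False)
merged.weight.data = base.weight.data + weight * (up @ down)

# Sidecar patching: leave the (possibly quantized) base layer untouched and add the
# LoRA contribution alongside it at forward time.
sidecar_out = base(x) + weight * (x @ down.T @ up.T)

assert torch.allclose(merged(x), sidecar_out, atol=1e-5)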
@@ -354,7 +421,7 @@ def _run_diffusion(self, context: InvocationContext):
                     hidden_states=model_input,
                     encoder_hidden_states=pos_prompt_embeds,
                     encoder_hidden_states_mask=pos_prompt_mask,
-                    timestep=timestep,
+                    timestep=timestep / 1000,
                     img_shapes=img_shapes,
                     return_dict=False,
                 )[0]
@@ -366,7 +433,7 @@ def _run_diffusion(self, context: InvocationContext):
                         hidden_states=model_input,
                         encoder_hidden_states=neg_prompt_embeds,
                         encoder_hidden_states_mask=neg_prompt_mask,
-                        timestep=timestep,
+                        timestep=timestep / 1000,
                         img_shapes=img_shapes,
                         return_dict=False,
                     )[0]
@@ -376,14 +443,11 @@ def _run_diffusion(self, context: InvocationContext):
                 else:
                     noise_pred = noise_pred_cond
 
-                latents_dtype = latents.dtype
-                latents = latents.to(dtype=torch.float32)
-                dt = sigma_next - sigma_curr
-                latents = latents + dt * noise_pred
-                latents = latents.to(dtype=latents_dtype)
+                # Use the scheduler's step method, exactly matching the pipeline.
+                latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
                 if inpaint_extension is not None:
-                    # Unpack to 4D for inpaint merging, then repack
+                    sigma_next = sigmas_sched[step_idx + 1].item()
                     latents_4d = self._unpack_latents(latents, latent_height, latent_width)
                     latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(latents_4d, sigma_next)
                     latents = self._pack_latents(latents_4d, 1, out_channels, latent_height, latent_width)
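`scheduler.step` performs the same first-order Euler update that the removed lines did by hand, `x <- x + (sigma_next - sigma_curr) * v`, with the scheduler tracking its own step index (hence `set_begin_index(0)` before the loop). A minimal equivalence check, assuming diffusers is installed and using an 8-step schedule with a fixed `mu`; the dtype round-trip from the removed code is omitted since everything here is float32:

import math

import numpy as np
import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)
scheduler.set_timesteps(sigmas=np.linspace(1.0, 1.0 / 8, 8).tolist(), mu=math.log(3.0), device="cpu")

latents = torch.randn(1, 16, dtype=torch.float32)
noise_pred = torch.randn_like(latents)
t = scheduler.timesteps[0]

# Manual Euler step, as in the removed code.
sigma_curr = scheduler.sigmas[0].item()
sigma_next = scheduler.sigmas[1].item()
manual = latents + (sigma_next - sigma_curr) * noise_pred

# Scheduler step.
scheduler.set_begin_index(0)
stepped = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

assert torch.allclose(stepped, manual, atol=1e-5)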
@@ -393,7 +457,7 @@ def _run_diffusion(self, context: InvocationContext):
                         step=step_idx + 1,
                         order=1,
                         total_steps=total_steps,
-                        timestep=int(sigma_curr * 1000),
+                        timestep=int(t.item()),
                         latents=self._unpack_latents(latents, latent_height, latent_width),
                     ),
                 )
@@ -408,3 +472,14 @@ def step_callback(state: PipelineIntermediateState) -> None:
             context.util.sd_step_callback(state, BaseModelType.QwenImageEdit)
 
         return step_callback
+
+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
+        """Iterate over LoRA models to apply to the transformer."""
+        for lora in self.transformer.loras:
+            lora_info = context.models.load(lora.lora)
+            if not isinstance(lora_info.model, ModelPatchRaw):
+                raise TypeError(
+                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}."
+                )
+            yield (lora_info.model, lora.weight)
+            del lora_info