Commit 50b4a8d

Reapply " Fix QwenImage txt_seq_lens handling (huggingface#12702)"
This reverts commit 6b77b72.
1 parent 6b77b72 commit 50b4a8d

17 files changed

Lines changed: 513 additions & 172 deletions

docs/source/en/api/pipelines/qwenimage.md

Lines changed: 36 additions & 2 deletions
@@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
 image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
 image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
 image = pipe(
-image=[image_1, image_2],
-prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
+    image=[image_1, image_2],
+    prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
     num_inference_steps=50
 ).images[0]
 ```
 
+## Performance
+
+### torch.compile
+
+Using `torch.compile` on the transformer provides a ~2.4x speedup (A100 80GB: 4.70s → 1.93s):
+
+```python
+import torch
+from diffusers import QwenImagePipeline
+
+pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
+pipe.transformer = torch.compile(pipe.transformer)
+
+# First call triggers compilation (~7s overhead)
+# Subsequent calls run ~2.4x faster
+image = pipe("a cat", num_inference_steps=50).images[0]
+```
+
+### Batched Inference with Variable-Length Prompts
+
+When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output.
+
+```python
+# CFG with different prompt lengths works correctly
+image = pipe(
+    prompt="A cat",
+    negative_prompt="blurry, low quality, distorted",
+    true_cfg_scale=3.5,
+    num_inference_steps=50,
+).images[0]
+```
+
+For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).
+
 ## QwenImagePipeline
 
 [[autodoc]] QwenImagePipeline

examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 0 additions & 2 deletions
@@ -1513,14 +1513,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     height=model_input.shape[3],
     width=model_input.shape[4],
 )
-print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
 model_pred = transformer(
     hidden_states=packed_noisy_model_input,
     encoder_hidden_states=prompt_embeds,
     encoder_hidden_states_mask=prompt_embeds_mask,
     timestep=timesteps / 1000,
     img_shapes=img_shapes,
-    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
     return_dict=False,
 )[0]
 model_pred = QwenImagePipeline._unpack_latents(
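
For context on this removal: the `txt_seq_lens` argument dropped from the training script was only a per-sample reduction of the padding mask, which the transformer can now derive internally from `encoder_hidden_states_mask`. A toy illustration with a hypothetical mask (not taken from the script):

```python
import torch

# Hypothetical padding mask for two prompts of length 5 and 3, padded to 6 tokens.
prompt_embeds_mask = torch.tensor([[1, 1, 1, 1, 1, 0],
                                   [1, 1, 1, 0, 0, 0]])

# The removed argument was exactly this reduction of the mask.
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
print(txt_seq_lens)  # [5, 3]
```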

src/diffusers/models/attention_dispatch.py

Lines changed: 76 additions & 2 deletions
@@ -2128,6 +2128,43 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
     return out
 
 
+def _prepare_additive_attn_mask(
+    attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
+) -> torch.Tensor:
+    """
+    Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA.
+
+    This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks.
+
+    Args:
+        attn_mask: 2D tensor [batch_size, seq_len_k]
+            - Boolean: True means attend, False means mask out
+            - Additive: 0.0 means attend, -inf means mask out
+        target_dtype: The dtype to convert the mask to (usually query.dtype)
+        reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting
+
+    Returns:
+        Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
+        reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
+    """
+    # Check if the mask is boolean or already additive
+    if attn_mask.dtype == torch.bool:
+        # Convert boolean to additive: True -> 0.0, False -> -inf
+        attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
+        # Convert to target dtype
+        attn_mask = attn_mask.to(dtype=target_dtype)
+    else:
+        # Already additive mask - just ensure correct dtype
+        attn_mask = attn_mask.to(dtype=target_dtype)
+
+    # Optionally reshape to 4D for broadcasting in attention mechanisms
+    if reshape_4d:
+        batch_size, seq_len_k = attn_mask.shape
+        attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)
+
+    return attn_mask
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.NATIVE,
     constraints=[_check_device, _check_shape],
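
A self-contained sketch of the conversion this helper performs, using a toy boolean key-padding mask (shapes, values, and the bfloat16 dtype are illustrative only):

```python
import torch

# Toy boolean key-padding mask: batch of 2, key length 4 (True = attend, False = padding).
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])

# Boolean -> additive: True -> 0.0, False -> -inf, cast to the attention dtype,
# then reshape to [batch, 1, 1, seq_len_k] so it broadcasts over heads and queries.
additive = torch.where(mask, 0.0, float("-inf")).to(torch.bfloat16)
additive = additive.view(2, 1, 1, 4)
print(additive)
```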
@@ -2147,6 +2184,19 @@ def _native_attention(
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("Native attention backend does not support setting `return_lse=True`.")
+
+    # Reshape 2D mask to 4D for SDPA
+    # SDPA accepts both boolean masks (torch.bool) and additive masks (float)
+    if (
+        attn_mask is not None
+        and attn_mask.ndim == 2
+        and attn_mask.shape[0] == query.shape[0]
+        and attn_mask.shape[1] == key.shape[1]
+    ):
+        # Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k]
+        # SDPA handles both boolean and additive masks correctly
+        attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
+
     if _parallel_config is None:
         query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
         out = torch.nn.functional.scaled_dot_product_attention(
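
For reference, a minimal standalone demonstration (toy shapes, not taken from the diff) that `torch.nn.functional.scaled_dot_product_attention` broadcasts a `[batch, 1, 1, seq_len_k]` boolean mask over heads and queries:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)  # [batch, heads, seq_q, head_dim]
k = torch.randn(2, 8, 20, 64)  # [batch, heads, seq_k, head_dim]
v = torch.randn(2, 8, 20, 64)

# 2D key-padding mask [batch, seq_len_k]; True = attend, False = padding.
key_padding = torch.ones(2, 20, dtype=torch.bool)
key_padding[1, 15:] = False  # the second sample has 5 padded keys

# [batch, seq_len_k] -> [batch, 1, 1, seq_len_k]; SDPA broadcasts it over heads and queries.
attn_mask = key_padding.unsqueeze(1).unsqueeze(1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```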
@@ -2713,10 +2763,34 @@ def _xformers_attention(
         attn_mask = xops.LowerTriangularMask()
     elif attn_mask is not None:
         if attn_mask.ndim == 2:
-            attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+            # Convert 2D mask to 4D for xformers
+            # Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask)
+            # xformers requires 4D additive masks [batch, heads, seq_q, seq_k]
+            # Need memory alignment - create larger tensor and slice for alignment
+            original_seq_len = attn_mask.size(1)
+            aligned_seq_len = ((original_seq_len + 7) // 8) * 8  # Round up to multiple of 8
+
+            # Create aligned 4D tensor and slice to ensure proper memory layout
+            aligned_mask = torch.zeros(
+                (batch_size, num_heads_q, seq_len_q, aligned_seq_len),
+                dtype=query.dtype,
+                device=query.device,
+            )
+            # Convert to 4D additive mask (handles both boolean and additive inputs)
+            mask_additive = _prepare_additive_attn_mask(
+                attn_mask, target_dtype=query.dtype
+            )  # [batch, 1, 1, seq_len_k]
+            # Broadcast to [batch, heads, seq_q, seq_len_k]
+            aligned_mask[:, :, :, :original_seq_len] = mask_additive
+            # Mask out the padding columns beyond the original sequence length
+            aligned_mask[:, :, :, original_seq_len:] = float("-inf")
+
+            # Slice to actual size with proper alignment
+            attn_mask = aligned_mask[:, :, :, :seq_len_kv]
         elif attn_mask.ndim != 4:
             raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
-        attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+        elif attn_mask.ndim == 4:
+            attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
 
     if enable_gqa:
         if num_heads_q % num_heads_kv != 0:
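
A standalone sketch of the alignment trick used above: allocate a buffer whose key dimension is rounded up to a multiple of 8, fill the padded tail with -inf, and slice back to the real length so the underlying storage stays aligned (toy shapes, no xformers call; names are illustrative):

```python
import torch

seq_len_kv = 77
aligned_seq_len = ((seq_len_kv + 7) // 8) * 8  # 80

bias = torch.zeros(2, 8, 16, aligned_seq_len, dtype=torch.float16)
bias[..., seq_len_kv:] = float("-inf")  # never attend to the alignment padding
bias = bias[..., :seq_len_kv]           # slice back: the row stride stays a multiple of 8

print(aligned_seq_len, bias.shape, bias.stride())
```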

src/diffusers/models/controlnets/controlnet_qwenimage.py

Lines changed: 55 additions & 16 deletions
@@ -20,7 +20,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import AttentionMixin
 from ..cache_utils import CacheMixin
 from ..controlnets.controlnet import zero_module
@@ -31,6 +31,7 @@
     QwenImageTransformerBlock,
     QwenTimestepProjEmbeddings,
     RMSNorm,
+    compute_text_seq_len_from_mask,
 )
 
 
@@ -136,7 +137,7 @@ def forward(
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
-        The [`FluxTransformer2DModel`] forward method.
+        The [`QwenImageControlNetModel`] forward method.
 
         Args:
             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -147,24 +148,39 @@ def forward(
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-           pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-               from the embeddings of input conditions.
+           encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+               Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
+               Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
+               (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
-           block_controlnet_hidden_states: (`list` of `torch.Tensor`):
-               A list of tensors that if specified are added to the residuals of transformer blocks.
+           img_shapes (`List[Tuple[int, int, int]]`, *optional*):
+               Image shapes for RoPE computation.
+           txt_seq_lens (`List[int]`, *optional*):
+               **Deprecated**. No longer needed; the text sequence length is now inferred from
+               `encoder_hidden_states`.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
-               Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-               tuple.
+               Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
 
        Returns:
-           If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-           `tuple` where the first element is the sample tensor.
+           If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
+           the first element is the controlnet block samples.
        """
+       # Handle deprecated txt_seq_lens parameter
+       if txt_seq_lens is not None:
+           deprecate(
+               "txt_seq_lens",
+               "0.39.0",
+               "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
+               "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
+               "and `encoder_hidden_states_mask`.",
+               standard_warn=False,
+           )
+
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
@@ -186,32 +202,47 @@ def forward(
 
        temb = self.time_text_embed(timestep, hidden_states)
 
-       image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+       # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
+       text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+           encoder_hidden_states, encoder_hidden_states_mask
+       )
+
+       image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
 
        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)
 
+       # Construct joint attention mask once to avoid reconstructing in every block
+       block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
+       if encoder_hidden_states_mask is not None:
+           # Build joint mask: [text_mask, all_ones_for_image]
+           batch_size, image_seq_len = hidden_states.shape[:2]
+           image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
+           joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
+           block_attention_kwargs["attention_mask"] = joint_attention_mask
+
        block_samples = ()
-       for index_block, block in enumerate(self.transformer_blocks):
+       for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
-                   encoder_hidden_states_mask,
+                   None,  # Don't pass encoder_hidden_states_mask (using attention_mask instead)
                    temb,
                    image_rotary_emb,
+                   block_attention_kwargs,
                )
 
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
-                   encoder_hidden_states_mask=encoder_hidden_states_mask,
+                   encoder_hidden_states_mask=None,  # Don't pass (using attention_mask instead)
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
-                   joint_attention_kwargs=joint_attention_kwargs,
+                   joint_attention_kwargs=block_attention_kwargs,
                )
            block_samples = block_samples + (hidden_states,)
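
A minimal sketch of the joint-mask construction introduced in this hunk, with toy shapes (the real code concatenates the text padding mask with an all-ones mask over the image tokens, so only text padding is excluded from attention):

```python
import torch

batch_size, text_len, image_len = 2, 6, 16

# Boolean text mask: the second prompt has 2 padding tokens.
encoder_hidden_states_mask = torch.ones(batch_size, text_len, dtype=torch.bool)
encoder_hidden_states_mask[1, 4:] = False

# Image tokens are never padded, so they get an all-ones mask.
image_mask = torch.ones(batch_size, image_len, dtype=torch.bool)

# Joint mask over the concatenated [text, image] sequence seen by each transformer block.
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
print(joint_attention_mask.shape)  # torch.Size([2, 22])
```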

@@ -267,6 +298,15 @@ def forward(
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[QwenImageControlNetOutput, Tuple]:
+       if txt_seq_lens is not None:
+           deprecate(
+               "txt_seq_lens",
+               "0.39.0",
+               "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
+               "removed in version 0.39.0. The text sequence length is now automatically inferred from "
+               "`encoder_hidden_states` and `encoder_hidden_states_mask`.",
+               standard_warn=False,
+           )
        # ControlNet-Union with multiple conditions
        # only load one ControlNet for saving memories
        if len(self.nets) == 1:
@@ -281,7 +321,6 @@ def forward(
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                timestep=timestep,
                img_shapes=img_shapes,
-               txt_seq_lens=txt_seq_lens,
                joint_attention_kwargs=joint_attention_kwargs,
                return_dict=return_dict,
            )
