from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
    QwenImageTransformerBlock,
    QwenTimestepProjEmbeddings,
    RMSNorm,
+    compute_text_seq_len_from_mask,
)
@@ -136,7 +137,7 @@ def forward(
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
-        The [`FluxTransformer2DModel`] forward method.
+        The [`QwenImageControlNetModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -147,24 +148,39 @@ def forward(
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
+            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
+                Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
+                Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
+                (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
-                A list of tensors that if specified are added to the residuals of transformer blocks.
+            img_shapes (`List[Tuple[int, int, int]]`, *optional*):
+                Image shapes for RoPE computation.
+            txt_seq_lens (`List[int]`, *optional*):
+                **Deprecated**. No longer needed; the text sequence length is now inferred from
+                `encoder_hidden_states`.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
+                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.

        Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
+            If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
+            the first element is the controlnet block samples.
        """
+        # Handle deprecated txt_seq_lens parameter
+        if txt_seq_lens is not None:
+            deprecate(
+                "txt_seq_lens",
+                "0.39.0",
+                "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
+                "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
+                "and `encoder_hidden_states_mask`.",
+                standard_warn=False,
+            )
+
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
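
For readers unfamiliar with the mask convention documented above, here is a minimal, self-contained sketch (not part of the diff, toy sizes only) of how an `encoder_hidden_states_mask` with 1.0 for valid tokens and 0.0 for padding can be built from per-prompt token counts; in practice the text encoder's/tokenizer's attention mask typically plays this role, and `txt_seq_lens` no longer needs to be passed alongside it.

import torch

# Toy example: 2 prompts padded to 8 text tokens, with 5 and 3 valid tokens respectively
batch_size, text_seq_len = 2, 8
valid_lengths = torch.tensor([5, 3])

# 1.0 marks a real token, 0.0 marks padding, matching the docstring above
positions = torch.arange(text_seq_len).unsqueeze(0)                            # (1, text_seq_len)
encoder_hidden_states_mask = (positions < valid_lengths.unsqueeze(1)).float()  # (batch_size, text_seq_len)

print(encoder_hidden_states_mask)
# tensor([[1., 1., 1., 1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0., 0., 0., 0.]])
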
@@ -186,32 +202,47 @@ def forward(
        temb = self.time_text_embed(timestep, hidden_states)

-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
+        text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+            encoder_hidden_states, encoder_hidden_states_mask
+        )
+
+        image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)

        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

+        # Construct joint attention mask once to avoid reconstructing in every block
+        block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
+        if encoder_hidden_states_mask is not None:
+            # Build joint mask: [text_mask, all_ones_for_image]
+            batch_size, image_seq_len = hidden_states.shape[:2]
+            image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
+            joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
+            block_attention_kwargs["attention_mask"] = joint_attention_mask
+
        block_samples = ()
-        for index_block, block in enumerate(self.transformer_blocks):
+        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
-                    encoder_hidden_states_mask,
+                    None,  # Don't pass encoder_hidden_states_mask (using attention_mask instead)
                    temb,
                    image_rotary_emb,
+                    block_attention_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_hidden_states_mask=encoder_hidden_states_mask,
+                    encoder_hidden_states_mask=None,  # Don't pass (using attention_mask instead)
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=joint_attention_kwargs,
+                    joint_attention_kwargs=block_attention_kwargs,
                )
            block_samples = block_samples + (hidden_states,)
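
To make the new mask handling concrete, the following standalone sketch (not part of the diff, toy shapes only) reproduces the joint-mask construction from the forward pass above: padded text positions are masked out, every image token position is marked valid, and the two are concatenated along the sequence dimension. The resulting mask is what the blocks receive through `joint_attention_kwargs["attention_mask"]`, which is why `encoder_hidden_states_mask` is no longer forwarded to each block directly.

import torch

# Toy shapes: 2 samples, 4 text tokens (with padding) and 6 image tokens
encoder_hidden_states_mask = torch.tensor(
    [[1, 1, 1, 0],
     [1, 1, 0, 0]], dtype=torch.bool
)
batch_size, image_seq_len = 2, 6

# Image tokens are never padding, so their mask entries are all ones
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool)

# Joint mask over the concatenated [text, image] sequence
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
print(joint_attention_mask.shape)  # torch.Size([2, 10])
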
@@ -267,6 +298,15 @@ def forward(
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[QwenImageControlNetOutput, Tuple]:
+        if txt_seq_lens is not None:
+            deprecate(
+                "txt_seq_lens",
+                "0.39.0",
+                "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
+                "removed in version 0.39.0. The text sequence length is now automatically inferred from "
+                "`encoder_hidden_states` and `encoder_hidden_states_mask`.",
+                standard_warn=False,
+            )
        # ControlNet-Union with multiple conditions
        # only load one ControlNet to save memory
        if len(self.nets) == 1:
@@ -281,7 +321,6 @@ def forward(
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    timestep=timestep,
                    img_shapes=img_shapes,
-                    txt_seq_lens=txt_seq_lens,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )