From d7e2c0736bdbe634bfb48dbcd880a832d44d91ad Mon Sep 17 00:00:00 2001 From: Varun Chawla Date: Sun, 22 Feb 2026 20:19:05 -0800 Subject: [PATCH 1/2] Fix Flash Attention 3 interface compatibility for new FA3 versions Newer versions of flash-attn (after Dao-AILab/flash-attention@ed20940) no longer return lse by default from flash_attn_3_func. The function now returns just the output tensor unless return_attn_probs=True is passed. Updated _wrapped_flash_attn_3 and _flash_varlen_attention_3 to pass return_attn_probs and handle both old (always tuple) and new (tensor or tuple) return formats gracefully. Fixes #12022 --- src/diffusers/models/attention_dispatch.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 90ffcac80dc5..5ebf5d5591a6 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -676,7 +676,7 @@ def _wrapped_flash_attn_3( ) -> tuple[torch.Tensor, torch.Tensor]: # Hardcoded for now because pytorch does not support tuple/int type hints window_size = (-1, -1) - out, lse, *_ = flash_attn_3_func( + result = flash_attn_3_func( q=q, k=k, v=v, @@ -693,8 +693,15 @@ def _wrapped_flash_attn_3( pack_gqa=pack_gqa, deterministic=deterministic, sm_margin=sm_margin, + return_attn_probs=True, ) - lse = lse.permute(0, 2, 1) + # Handle both old FA3 (always returns tuple) and new FA3 (returns tuple only with return_attn_probs=True) + if isinstance(result, tuple): + out, lse, *_ = result + lse = lse.permute(0, 2, 1) + else: + out = result + lse = torch.empty(q.shape[0], q.shape[2], q.shape[1], device=q.device, dtype=torch.float32) return out, lse @@ -2623,7 +2630,7 @@ def _flash_varlen_attention_3( key_packed = torch.cat(key_valid, dim=0) value_packed = torch.cat(value_valid, dim=0) - out, lse, *_ = flash_attn_3_varlen_func( + result = flash_attn_3_varlen_func( q=query_packed, k=key_packed, v=value_packed, @@ -2633,7 +2640,13 @@ def _flash_varlen_attention_3( max_seqlen_k=max_seqlen_k, softmax_scale=scale, causal=is_causal, + return_attn_probs=return_lse, ) + if isinstance(result, tuple): + out, lse, *_ = result + else: + out = result + lse = None out = out.unflatten(0, (batch_size, -1)) return (out, lse) if return_lse else out From c0961d56ea0e8990643fc79cb97f6c37bbe935cb Mon Sep 17 00:00:00 2001 From: Varun Chawla Date: Mon, 23 Feb 2026 23:27:52 -0800 Subject: [PATCH 2/2] Simplify _wrapped_flash_attn_3 return unpacking Since return_attn_probs=True is always passed, the result is guaranteed to be a tuple. Remove the unnecessary isinstance guard. --- src/diffusers/models/attention_dispatch.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 5ebf5d5591a6..d6a4fd019b5b 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -695,13 +695,8 @@ def _wrapped_flash_attn_3( sm_margin=sm_margin, return_attn_probs=True, ) - # Handle both old FA3 (always returns tuple) and new FA3 (returns tuple only with return_attn_probs=True) - if isinstance(result, tuple): - out, lse, *_ = result - lse = lse.permute(0, 2, 1) - else: - out = result - lse = torch.empty(q.shape[0], q.shape[2], q.shape[1], device=q.device, dtype=torch.float32) + out, lse, *_ = result + lse = lse.permute(0, 2, 1) return out, lse