Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def _wrapped_flash_attn_3(
) -> tuple[torch.Tensor, torch.Tensor]:
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
out, lse, *_ = flash_attn_3_func(
result = flash_attn_3_func(
q=q,
k=k,
v=v,
Expand All @@ -693,8 +693,15 @@ def _wrapped_flash_attn_3(
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=True,
)
lse = lse.permute(0, 2, 1)
# Handle both old FA3 (always returns tuple) and new FA3 (returns tuple only with return_attn_probs=True)
if isinstance(result, tuple):
out, lse, *_ = result
lse = lse.permute(0, 2, 1)
else:
out = result
lse = torch.empty(q.shape[0], q.shape[2], q.shape[1], device=q.device, dtype=torch.float32)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this guard. Since we pass return_attn_probs=True, don't both old and new FA3 always return a tuple? If so, why not just leave it as `out, lse, *_ = flash_attn_3_func(...)`?

return out, lse


Expand Down Expand Up @@ -2623,7 +2630,7 @@ def _flash_varlen_attention_3(
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)

out, lse, *_ = flash_attn_3_varlen_func(
result = flash_attn_3_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
Expand All @@ -2633,7 +2640,13 @@ def _flash_varlen_attention_3(
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if isinstance(result, tuple):
out, lse, *_ = result
else:
out = result
lse = None
out = out.unflatten(0, (batch_size, -1))

return (out, lse) if return_lse else out
Expand Down