
add causal_upper_left mask option to scaled_dot_product_attention#3254

Open
mm65x wants to merge 2 commits into ml-explore:main from mm65x:sdpa-causal-upper-left

Conversation

@mm65x
Contributor

@mm65x mm65x commented Mar 14, 2026

Proposed changes

#2835

Adds "causal_upper_left" and "causal_lower_right" as explicit mask options
to mx.fast.scaled_dot_product_attention. "causal" remains an alias for
"causal_lower_right", so existing behavior is unchanged.

When S_Q != S_KV, lower-right alignment pairs the last query with the last
key, while upper-left lets query i attend keys 0..i (matching PyTorch's
is_causal=True). When S_Q == S_KV the two masks are identical.
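The two alignments can be illustrated with a minimal pure-Python sketch (a reference description of the mask semantics, not the actual kernel code; the helper name `causal_mask` is hypothetical):

```python
def causal_mask(S_q, S_kv, align="lower_right"):
    """Boolean causal mask: entry [i][j] is True when query i may attend key j."""
    # upper_left: offset 0 -> query i sees keys 0..i (PyTorch is_causal=True).
    # lower_right: offset S_kv - S_q -> the last query aligns with the last key.
    offset = 0 if align == "upper_left" else S_kv - S_q
    return [[j <= i + offset for j in range(S_kv)] for i in range(S_q)]
```

For example, with 2 queries and 4 keys, upper-left gives each query only its own prefix of keys, while lower-right lets the last query see every key.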

On Metal, the full-attention kernels already parameterize the causal diagonal
via qL_off in AttnParams, so the change is simply passing 0 instead of
kL - qL for upper-left. The vector kernels previously hardcoded the offset
inline, so a causal_offset buffer argument was added instead. On CUDA, the
cuDNN path uses set_causal_mask vs. set_causal_mask_bottom_right, and the
vector kernels use a causal_offset field in AttnParams (the same approach as
Metal).
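The single-offset parameterization the kernels share can be sketched as a pure-Python reference attention (an illustration of the semantics under the additive -inf masking convention, not the Metal/CUDA code; `sdpa_ref` is a hypothetical name):

```python
import math

def sdpa_ref(q, k, v, align="lower_right"):
    # Reference scaled dot-product attention over 2-D lists:
    # q is S_q x d; k and v are S_kv x d. Masked keys are simply skipped,
    # which is equivalent to adding -inf to their scores before softmax.
    S_q, S_kv, d = len(q), len(k), len(q[0])
    # One offset covers both alignments, mirroring qL_off / causal_offset.
    offset = 0 if align == "upper_left" else S_kv - S_q
    out = []
    for i in range(S_q):
        allowed = [j for j in range(S_kv) if j <= i + offset]
        scores = [
            sum(q[i][t] * k[j][t] for t in range(d)) / math.sqrt(d)
            for j in allowed
        ]
        m = max(scores)  # subtract the max for a numerically stable softmax
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append(
            [sum(w[n] * v[allowed[n]][t] for n in range(len(allowed))) / z
             for t in range(d)]
        )
    return out
```

With equal sequence lengths the offset is 0 either way, so both alignments produce identical output; the first query attends only key 0 under upper-left.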

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@mm65x mm65x force-pushed the sdpa-causal-upper-left branch from 7693beb to 6b2d587 on March 14, 2026 23:07
@mm65x mm65x force-pushed the sdpa-causal-upper-left branch from 670373d to a435cd1 on March 20, 2026 10:33
@mm65x mm65x marked this pull request as ready for review March 20, 2026 10:33
@zcbenz
Collaborator

zcbenz commented Mar 25, 2026

Can you fix the lint error?
