Skip to content

[WIP] [fusilli] Emit flex_attention for SDPA#413

Draft
keshavvinayak01 wants to merge 6 commits into
mainfrom
issue-275-generate-stats
Draft

[WIP] [fusilli] Emit flex_attention for SDPA#413
keshavvinayak01 wants to merge 6 commits into
mainfrom
issue-275-generate-stats

Conversation

@keshavvinayak01
Copy link
Copy Markdown
Contributor

@keshavvinayak01 keshavvinayak01 commented May 12, 2026

Emit torch.hop_flex_attention for Fusilli SDPA cases without an explicit tensor mask or dropout. Causal SDPA attaches a generated mask_mod_fn, GQA sets enable_gqa, and explicit masks/dropout continue to use torch.aten.scaled_dot_product_attention until explicit mask arguments are available.

Adds generate_stats support by returning the flex attention logsumexp as the STATS output, with sample, lit, and unit coverage for the new path.

@keshavvinayak01 keshavvinayak01 changed the title [fusilli] Emit flex_attention for SDPA [WIP] [fusilli] Emit flex_attention for SDPA May 12, 2026
@keshavvinayak01 keshavvinayak01 force-pushed the issue-275-generate-stats branch 7 times, most recently from 36e528a to 55346cb Compare May 12, 2026 19:16
Emit torch.hop_flex_attention for SDPA cases without explicit masks or dropout, including causal mask_mod and enable_gqa handling.

Add generate_stats support by returning logsumexp stats and update samples/tests for the new SDPA path.

Keep Graph::sdpa as a single-output API and leave explicit tensor masks/dropout on the legacy SDPA path.

Co-authored-by: Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01 keshavvinayak01 force-pushed the issue-275-generate-stats branch from 55346cb to 4066b13 Compare May 12, 2026 19:27
keshavvinayak01 and others added 2 commits May 13, 2026 11:22
@keshavvinayak01 keshavvinayak01 force-pushed the issue-275-generate-stats branch from 43366e5 to 64ddd2c Compare May 13, 2026 08:27
keshavvinayak01 and others added 3 commits May 13, 2026 17:19
Allow causal non-GQA SDPA to use the flex_attention path now that the IREE mask_mod lowering handles the generated causal callback.

Update the causal emitter lit checks to expect the mask_mod callback and adjusted compile statistics.

Co-authored-by: GPT-5 Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Emit torch.hop_flex_attention from the SDPA custom-op sample when a custom scale is present and the case does not require explicit mask, dropout, or causal handling.

Remove the stale xfail for the custom-op custom-scale sample while keeping the regular SDPA custom-scale xfail on the legacy path.

Co-authored-by: GPT-5 Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Remove the regular SDPA scale fallback now that torch.hop_flex_attention supports an explicit scale operand.

Update the custom-scale emitter test to expect flex_attention and remove the stale custom-scale xfail.

Co-authored-by: GPT-5 Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@keshavvinayak01
Copy link
Copy Markdown
Contributor Author

keshavvinayak01 commented May 13, 2026

@sjain-stanford could you review this? The failing CI tests would work once iree-org/iree#24426 is in.

@keshavvinayak01
Copy link
Copy Markdown
Contributor Author

Also about the mask support that we discussed on discord, I could add that in a follow up PR too.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant