[WIP] [fusilli] Emit flex_attention for SDPA #413
Draft
keshavvinayak01 wants to merge 6 commits into
36e528a to 55346cb
Emit torch.hop_flex_attention for SDPA cases without explicit masks or dropout, including causal mask_mod and enable_gqa handling. Add generate_stats support by returning logsumexp stats and update samples/tests for the new SDPA path. Keep Graph::sdpa as a single-output API and leave explicit tensor masks/dropout on the legacy SDPA path.

Co-authored-by: Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
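The path selection described in this commit message can be sketched roughly as below. This is an illustrative model only, not the Fusilli emitter: SdpaCase, its fields, and select_op are hypothetical names, and the real emitter may check further conditions. Only the two op names come from this PR.

```python
from dataclasses import dataclass


@dataclass
class SdpaCase:
    """Hypothetical stand-in for the SDPA attributes an emitter might inspect."""
    has_tensor_mask: bool = False  # an explicit mask tensor was supplied
    has_dropout: bool = False      # dropout was requested
    is_causal: bool = False        # causal masking requested (no mask tensor)
    enable_gqa: bool = False       # grouped-query attention


def select_op(case: SdpaCase) -> str:
    """Pick the op name under the rule stated in the commit message (assumed)."""
    # Explicit tensor masks and dropout stay on the legacy SDPA path.
    if case.has_tensor_mask or case.has_dropout:
        return "torch.aten.scaled_dot_product_attention"
    # Everything else, including causal and GQA cases, takes the new path.
    return "torch.hop_flex_attention"
```

For example, select_op(SdpaCase(is_causal=True)) selects the flex_attention op, while select_op(SdpaCase(has_dropout=True)) stays on the legacy op.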
55346cb to 4066b13
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
43366e5 to 64ddd2c
Allow causal non-GQA SDPA to use the flex_attention path now that the IREE mask_mod lowering handles the generated causal callback. Update the causal emitter lit checks to expect the mask_mod callback and adjusted compile statistics.

Co-authored-by: GPT-5 Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
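For readers unfamiliar with mask_mod callbacks, here is a minimal sketch of what a causal one computes, following the PyTorch flex_attention convention of a (b, h, q_idx, kv_idx) predicate. The callback this PR actually generates is emitted as IR and is not shown on this page, so the function below is an assumption about its semantics, and materialize is a hypothetical helper used only to visualize the result.

```python
def causal_mask_mod(b, h, q_idx, kv_idx):
    """Return True where a query position may attend to a key position.

    b and h (batch and head index) are unused for a plain causal mask.
    """
    return q_idx >= kv_idx


def materialize(mask_mod, q_len, kv_len):
    """Hypothetical helper: expand a mask_mod into a dense boolean grid."""
    return [
        [mask_mod(0, 0, q, k) for k in range(kv_len)]
        for q in range(q_len)
    ]


if __name__ == "__main__":
    for row in materialize(causal_mask_mod, 3, 3):
        print(row)
```

The dense grid is lower triangular, which matches the top-left-aligned causal mask that scaled_dot_product_attention applies for is_causal in PyTorch.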
Emit torch.hop_flex_attention from the SDPA custom-op sample when a custom scale is present and the case does not require explicit mask, dropout, or causal handling. Remove the stale xfail for the custom-op custom-scale sample while keeping the regular SDPA custom-scale xfail on the legacy path.

Co-authored-by: GPT-5 Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Remove the regular SDPA scale fallback now that torch.hop_flex_attention supports an explicit scale operand. Update the custom-scale emitter test to expect flex_attention and remove the stale custom-scale xfail.

Co-authored-by: GPT-5 Codex <noreply@openai.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Contributor, Author:
@sjain-stanford could you review this? The failing CI tests would work once iree-org/iree#24426 is in.
Contributor, Author:
Also, about the mask support that we discussed on Discord: I could add that in a follow-up PR too.
Emit torch.hop_flex_attention for Fusilli SDPA cases without an explicit tensor mask or dropout. Causal SDPA attaches a generated mask_mod_fn, GQA sets enable_gqa, and explicit masks/dropout continue to use torch.aten.scaled_dot_product_attention until explicit mask arguments are available.

Adds generate_stats support by returning the flex attention logsumexp as the STATS output, with sample, lit, and unit coverage for the new path.
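As a reference for what a logsumexp statistic represents, the sketch below computes attention for a single query row in plain Python and returns the row's logsumexp alongside the output. It is a numerical illustration only: attention_row and logsumexp are hypothetical helpers, and the natural-log base and the exact layout of the STATS tensor in this PR are assumptions not confirmed here.

```python
import math


def logsumexp(scores):
    """Numerically stable log(sum(exp(s))) over a list of floats."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def attention_row(q, keys, values, scale):
    """Attention for one query vector; returns (output, logsumexp)."""
    # Scaled dot-product score of the query against every key.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    lse = logsumexp(scores)
    # Softmax weights can be recovered from the scores and the logsumexp.
    weights = [math.exp(s - lse) for s in scores]
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return out, lse
```

Keeping the logsumexp is enough to rebuild the softmax weights later as exp(score - lse), which is why it is a convenient per-row statistic to return from an attention op.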