feat(fast): add sliding-window SDPA kernel path #3552
Open: rabbitson87 wants to merge 2 commits into ml-explore:main from rabbitson87:sliding-window-attention-kernel
Is the `+ 1` here needed?
For a given (adjusted) query index q, the window should attend to keys `[q - window_size, q]` inclusive, meaning the first k would be just `q_min - params->window_size` to match the logic.
Yes, the + 1 is intentional with the current definition of window_size.
I interpreted window_size = W as the number of visible positions in the left window, including the current query position. So the valid key range is:
[q - W + 1, q]
For example, with window_size = 2, query q should attend to {q - 1, q}, not {q - 2, q - 1, q}. Without the + 1, the window would include W + 1 positions.
That said, if we want window_size to mean “number of previous keys in addition to the current key”, then your suggested formula would be correct. I’ll make sure the docs and implementation use the same convention.
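To make the two readings of window_size concrete, here is a minimal sketch in plain Python. The function name `visible_keys` and the flag `inclusive_of_current` are hypothetical and are not part of this PR or of MLX; the sketch only enumerates the key indices each convention would admit for a causal query.

```python
def visible_keys(q, window_size, inclusive_of_current=True):
    """Return the key indices a causal query at position q may attend to.

    inclusive_of_current=True : window_size counts the current position,
                                so the range is [q - W + 1, q].
    inclusive_of_current=False: window_size counts only previous keys,
                                so the range is [q - W, q].
    """
    if inclusive_of_current:
        first = q - window_size + 1
    else:
        first = q - window_size
    # Clamp at 0 because there are no keys before the start of the sequence.
    return list(range(max(first, 0), q + 1))


print(visible_keys(5, 2))                              # [4, 5]
print(visible_keys(5, 2, inclusive_of_current=False))  # [3, 4, 5]
```

With window_size = 2 the first reading yields two positions and the second yields three, which is exactly the off-by-one under discussion.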
Good point. There seems to be a gap between the lower-level implementation providers (FlashAttention and JAX SDPA) and the upper-level usage (Hugging Face).
See below: https://github.com/huggingface/transformers/blob/52b82b299171721fbe7b04fe056187f7aed2e2cc/src/transformers/modeling_flash_attention_utils.py#L627
This decision comes down to which convention we wish to follow. Give me a bit to get back to you on that front.
Sounds good, thanks for checking.
Just to make the current PR intent explicit: the implementation is currently following the HF/model-level convention where `window_size=W` means a total causal window of W visible positions including the current token, so the valid range is `[q - W + 1, q]`.
I agree the lower-level provider convention is different: FlashAttention/JAX-style APIs expose left/right window sizes with inclusive bounds, where a left window of W plus right 0 would cover W previous keys plus the current key.
I’ll hold off on further changes until MLX decides which convention the public API should follow. If MLX wants provider-style semantics, I can update the kernel, fallback mask, docs, and tests to use `[q - W, q]`; if MLX wants HF/model-style semantics, the current `+ 1` should stay and I’ll make that convention explicit in the docs/tests.
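For reference, a fallback-style boolean mask covering both options could look roughly like the sketch below. This is not the mask code from the PR: the function name `sliding_window_mask`, the `convention` argument, and the assumption that queries are aligned to the last `q_len` key positions (the "adjusted" query index) are all illustrative.

```python
def sliding_window_mask(q_len, k_len, window_size, convention="total"):
    """Build a causal sliding-window mask as a list of rows of booleans.

    convention="total": W visible positions including the current one,
                        i.e. keys in [q - W + 1, q] (HF/model-style).
    convention="left" : W previous keys plus the current one,
                        i.e. keys in [q - W, q] (provider-style, right = 0).
    """
    if convention not in ("total", "left"):
        raise ValueError(f"unknown convention: {convention!r}")

    # Assumption: queries correspond to the last q_len key positions.
    offset = k_len - q_len
    # Number of keys strictly before q that remain visible.
    span = window_size - 1 if convention == "total" else window_size

    mask = []
    for i in range(q_len):
        q = i + offset  # adjusted query index in key coordinates
        mask.append([q - span <= k <= q for k in range(k_len)])
    return mask


total = sliding_window_mask(3, 3, 2, convention="total")
left = sliding_window_mask(3, 3, 2, convention="left")
print(total[2])  # [False, True, True]
print(left[2])   # [True, True, True]
```

Whichever convention is chosen, the kernel, the fallback mask, and the tests would need to agree on the same value of `span`.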