Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlx/backend/cuda/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,11 @@ bool ScaledDotProductAttention::use_fallback(
bool do_causal,
bool is_training,
bool output_logsumexp,
int window_size,
Stream s) {
if (window_size > 0) {
return true;
}
if (s.device == Device::cpu) {
return true;
}
Expand Down
43 changes: 39 additions & 4 deletions mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
constant bool has_window [[function_constant(303)]];

struct MaxOp {
template <typename T>
Expand Down Expand Up @@ -234,9 +235,10 @@ template <
}

int kb_lim = params->NK;
int kb_start = 0;
int kb_min_causal = params->NK;

if (do_causal) {
if (do_causal || has_window) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);
Expand All @@ -246,8 +248,19 @@ template <
kb_min_causal = (q_min / BK);
}

if (has_window) {
int q_min = tid.x * BQ + params->qL_off;
int k_min = q_min - params->window_size + 1;
if (k_min > 0) {
kb_start = k_min / BK;
if (kb_start > kb_lim) {
kb_start = kb_lim;
}
}
}

// Loop over KV seq length
for (int kb = 0; kb < kb_lim; kb++) {
for (int kb = kb_start; kb < kb_lim; kb++) {
// Load K block and apply scale
threadgroup_barrier(mem_flags::mem_threadgroup);
if (!align_K && kb == (params->NK_aligned)) {
Expand Down Expand Up @@ -302,8 +315,8 @@ template <
}
}

// Mask out if causal
if (do_causal && kb >= kb_min_causal) {
// Mask out if causal or past the right edge of the sliding window.
if ((do_causal || has_window) && kb >= kb_min_causal) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = Limits<selem_t>::finite_min;
Expand All @@ -325,6 +338,28 @@ template <
}
}

if (has_window && kb < (kb_start + ((BQ + BK - 1) / BK))) {
using stile_t = decltype(Stile);
using selem_t = typename stile_t::elem_type;
constexpr auto neg_inf = Limits<selem_t>::finite_min;

STEEL_PRAGMA_UNROLL
for (short i = 0; i < stile_t::kTileRows; i++) {
const int row_pos =
tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows);
STEEL_PRAGMA_UNROLL
for (short j = 0; j < stile_t::kTileCols; j++) {
const int col_pos = kb * BK + sn + (j * stile_t::kFragCols);
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) {
if (row_pos - (col_pos + jj) >= params->window_size) {
Stile.frag_at(i, j)[jj] = neg_inf;
}
}
}
}
}

// Other masking as needed
if (has_mask) {
using stile_t = decltype(Stile);
Expand Down
48 changes: 44 additions & 4 deletions mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ constant bool align_K [[function_constant(201)]];
constant bool has_mask [[function_constant(300)]];
constant bool do_causal [[function_constant(301)]];
constant bool has_sinks [[function_constant(302)]];
constant bool has_window [[function_constant(303)]];

template <typename T>
struct TransformScale {
Expand Down Expand Up @@ -174,9 +175,10 @@ template <
}

int kb_lim = params->NK;
int kb_start = 0;
int kb_min_causal = params->NK;

if (do_causal) {
if (do_causal || has_window) {
int q_max = (tid.x + 1) * BQ + params->qL_off;
kb_lim = (q_max + BK - 1) / BK;
kb_lim = min(params->NK, kb_lim);
Expand All @@ -186,6 +188,17 @@ template <
kb_min_causal = (q_min / BK);
}

if (has_window) {
int q_min = tid.x * BQ + params->qL_off;
int k_min = q_min - params->window_size + 1;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the + 1 here needed ?
For given (adjusted) query index q, the window should attend to keys [q - window_size, q] inclusive - meaning the first k would be just q_min - params->window_size to match the logic

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the + 1 is intentional with the current definition of window_size.

I interpreted window_size = W as the number of visible positions in the left window, including the current query position. So the valid key range is:

[q - W + 1, q]
For example, with window_size = 2, query q should attend to {q - 1, q}, not {q - 2, q}. Without the + 1, the window would include W + 1 positions.

That said, if we want window_size to mean “number of previous keys in addition to the current key”, then your suggested formula would be correct. I’ll make sure the docs and implementation use the same convention.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, there seems to be a gap between between the lower level implementation providers (flash attention and Jax sdpa) vs the upper level implementation use (hugging face)
See below: https://github.com/huggingface/transformers/blob/52b82b299171721fbe7b04fe056187f7aed2e2cc/src/transformers/modeling_flash_attention_utils.py#L627

Screenshot 2026-05-20 at 1 46 57 PM

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This decision comes down to what convention we wish to follow, give me a bit to get back to you on that front

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, thanks for checking.

Just to make the current PR intent explicit: the implementation is currently following the HF/model-level convention where window_size=W means a total causal window of W visible positions including the current token, so the valid range is [q - W + 1, q].

I agree the lower-level provider convention is different: FlashAttention/JAX-style APIs expose left/right window sizes with inclusive bounds, where a left window of W plus right 0 would cover W previous keys plus the current key.

I’ll hold off on further changes until MLX decides which convention the public API should follow. If MLX wants provider-style semantics, I can update the kernel, fallback mask, docs, and tests to use [q - W, q]; if MLX wants HF/model-style semantics, the current + 1 should stay and I’ll make that convention explicit in the docs/tests.

if (k_min > 0) {
kb_start = k_min / BK;
if (kb_start > kb_lim) {
kb_start = kb_lim;
}
}
}

const bool is_last_bq = int(tid.x) == (params->NQ_aligned);
// const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ);
const bool is_last_q = is_last_bq;
Expand All @@ -194,7 +207,7 @@ template <
const short lim_rows_k = params->kL_rem;

// Loop over KV seq length
for (int kb = 0; kb < kb_lim; kb++) {
for (int kb = kb_start; kb < kb_lim; kb++) {
const int is_last_k = (kb == (params->NK_aligned));

// Do S = Q @ K.T
Expand Down Expand Up @@ -275,8 +288,8 @@ template <
}
}

// Mask out if causal
if (do_causal && kb >= kb_min_causal) {
// Mask out if causal or past the right edge of the sliding window.
if ((do_causal || has_window) && kb >= kb_min_causal) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;

const int base_row = tid.x * BQ + params->qL_off + tm;
Expand All @@ -303,6 +316,33 @@ template <
}
}

if (has_window && kb < (kb_start + ((BQ + BK - 1) / BK))) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;

const int base_row = tid.x * BQ + params->qL_off + tm;
const int base_col = kb * BK;

STEEL_PRAGMA_UNROLL
for (short iq = 0; iq < TQ; iq++) {
STEEL_PRAGMA_UNROLL
for (short ik = 0; ik < TK; ik++) {
thread auto& fg = Stile.frag_at(iq, ik);

STEEL_PRAGMA_UNROLL
for (short ii = 0; ii < stile_t::kFragThrRows; ii++) {
STEEL_PRAGMA_UNROLL
for (short jj = 0; jj < stile_t::kFragThrCols; jj++) {
const auto r =
base_row + iq * kU + ii * stile_t::kFragRowsJump + sm;
const auto c = base_col + ik * kU + jj + sn;
const auto loc = ii * stile_t::kFragThrCols + jj;
fg[loc] = ((r - c) >= params->window_size) ? neg_inf : fg[loc];
}
}
}
}
}

// Other masking as needed
if (has_mask) {
constexpr auto neg_inf = Limits<AccumType>::finite_min;
Expand Down
2 changes: 2 additions & 0 deletions mlx/backend/metal/kernels/steel/attn/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ struct AttnParams {
int kL_rem; ///< Remainder in last key/value block
int qL_off; ///< Offset in query sequence start

int window_size; ///< Sliding window size (-1 or 0 = no window)

int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
Expand Down
31 changes: 24 additions & 7 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ void sdpa_full_self_attention_nax(
const float scale,
array& o,
bool do_causal_,
int window_size_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
using namespace mlx::steel;
Expand All @@ -48,13 +49,15 @@ void sdpa_full_self_attention_nax(
const bool has_mask = mask.has_value();
const bool do_causal = do_causal_;
const bool has_sinks = sinks.has_value();
const bool has_window = window_size_ > 0;

metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301},
{&has_sinks, MTL::DataType::DataTypeBool, 302}};
{&has_sinks, MTL::DataType::DataTypeBool, 302},
{&has_window, MTL::DataType::DataTypeBool, 303}};

std::string base_name;
concatenate(
Expand Down Expand Up @@ -87,7 +90,9 @@ void sdpa_full_self_attention_nax(
"_do_causal_",
(do_causal ? 't' : 'n'),
"_has_sinks_",
(has_sinks ? 't' : 'n'));
(has_sinks ? 't' : 'n'),
"_has_window_",
(has_window ? 't' : 'n'));

auto& compute_encoder = metal::get_command_encoder(s);

Expand Down Expand Up @@ -133,6 +138,8 @@ void sdpa_full_self_attention_nax(
/* int kL_rem = */ (kL - NK_aligned * bk),
/* int qL_off = */ (kL - qL),

/* int window_size = */ window_size_,

/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
Expand Down Expand Up @@ -172,6 +179,7 @@ void sdpa_full_self_attention_metal(
const float scale,
array& o,
bool do_causal_,
int window_size_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
if (metal::is_nax_available() && q.shape(3) != 80 &&
Expand All @@ -185,6 +193,7 @@ void sdpa_full_self_attention_metal(
/* const float scale = */ scale,
/* array& o = */ o,
/* bool do_causal_ = */ do_causal_,
/* int window_size_ = */ window_size_,
/* const std::optional<array>& mask = */ mask,
/* const std::optional<array>& sinks = */ sinks);
}
Expand All @@ -211,13 +220,15 @@ void sdpa_full_self_attention_metal(
const bool has_mask = mask.has_value();
const bool do_causal = do_causal_;
const bool has_sinks = sinks.has_value();
const bool has_window = window_size_ > 0;

metal::MTLFCList func_consts = {
{&align_Q, MTL::DataType::DataTypeBool, 200},
{&align_K, MTL::DataType::DataTypeBool, 201},
{&has_mask, MTL::DataType::DataTypeBool, 300},
{&do_causal, MTL::DataType::DataTypeBool, 301},
{&has_sinks, MTL::DataType::DataTypeBool, 302}};
{&has_sinks, MTL::DataType::DataTypeBool, 302},
{&has_window, MTL::DataType::DataTypeBool, 303}};

std::string base_name;
concatenate(
Expand Down Expand Up @@ -250,7 +261,9 @@ void sdpa_full_self_attention_metal(
"_do_causal_",
(do_causal ? 't' : 'n'),
"_has_sinks_",
(has_sinks ? 't' : 'n'));
(has_sinks ? 't' : 'n'),
"_has_window_",
(has_window ? 't' : 'n'));

auto& compute_encoder = metal::get_command_encoder(s);

Expand Down Expand Up @@ -296,6 +309,8 @@ void sdpa_full_self_attention_metal(
/* int kL_rem = */ (kL - NK_aligned * bk),
/* int qL_off = */ (kL - qL),

/* int window_size = */ window_size_,

/* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
/* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
/* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
Expand Down Expand Up @@ -597,6 +612,7 @@ bool ScaledDotProductAttention::use_fallback(
bool do_causal,
bool is_training,
bool output_logsumexp,
int window_size,
Stream s) {
if (is_training) {
// It's faster for training on Metal to use the unfused SDPA for both
Expand All @@ -623,7 +639,8 @@ bool ScaledDotProductAttention::use_fallback(
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
(query_head_dim == 64 || query_head_dim == 80 ||
query_head_dim == 128 || (query_head_dim == 256 && window_size > 0));

const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
(query_sequence_length <= key_sequence_length && do_causal);
Expand All @@ -633,7 +650,7 @@ bool ScaledDotProductAttention::use_fallback(

const bool supports_sdpa_vector = (query_sequence_length <= 8) &&
(query_sequence_length <= key_sequence_length) &&
sdpa_vector_supported_head_dim &&
sdpa_vector_supported_head_dim && window_size <= 0 &&
(query_sequence_length * gqa_factor) <= 32;

return !(supports_sdpa_full || supports_sdpa_vector);
Expand Down Expand Up @@ -782,7 +799,7 @@ void ScaledDotProductAttention::eval_gpu(
: std::nullopt;

sdpa_full_self_attention_metal(
s, d, q, k, v, scale_, o, do_causal_, mask, sinks);
s, d, q, k, v, scale_, o, do_causal_, window_size_, mask, sinks);
}

metal::get_command_encoder(s).add_temporaries(std::move(copies));
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/no_gpu/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ bool fast::ScaledDotProductAttention::use_fallback(
bool do_causal,
bool is_training,
bool output_logsumexp,
int window_size,
Stream s) {
return true;
}
Expand Down
Loading