diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
index c668d9d8c5..ca2fd358ef 100644
--- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
+++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -32,10 +32,12 @@ using namespace metal;
   instantiate_sdpa_vector(type, 64, 64)          \
   instantiate_sdpa_vector(type, 96, 96)          \
   instantiate_sdpa_vector(type, 128, 128)        \
+  instantiate_sdpa_vector(type, 192, 192)        \
   instantiate_sdpa_vector(type, 256, 256)        \
   instantiate_sdpa_vector_aggregation(type, 64)  \
   instantiate_sdpa_vector_aggregation(type, 96)  \
   instantiate_sdpa_vector_aggregation(type, 128) \
+  instantiate_sdpa_vector_aggregation(type, 192) \
   instantiate_sdpa_vector_aggregation(type, 256)

 instantiate_sdpa_vector_heads(float)
diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
index 0ff9d91b00..4a67826951 100644
--- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
+++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
@@ -12,6 +12,7 @@
       attention, dtype, bq, bk, bd, wm, wn, mtype, float)

 #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
+  instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype)  \
   instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype)  \
   instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype)   \
   instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)
diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp
index 37e554f183..4e6ea61880 100644
--- a/mlx/backend/metal/scaled_dot_product_attention.cpp
+++ b/mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -618,9 +618,17 @@ bool ScaledDotProductAttention::use_fallback(
   const bool sdpa_vector_supported_head_dim =
       query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
-       query_head_dim == 256);
+       query_head_dim == 192 || query_head_dim == 256);
+  // For head_dim >= 192, the fused full-attention kernel is slower than
+  // unfused for short sequences. Only route to fused when kL is large enough
+  // that the unfused path would exceed Metal buffer limits (the fused kernel
+  // tiles K/V so it scales to arbitrary sequence lengths).
+  const bool sdpa_full_large_hd_ok =
+      (query_head_dim == 192 || query_head_dim == 256) &&
+      key_sequence_length > 16384;
   const bool sdpa_full_supported_head_dim =
       query_head_dim == value_head_dim &&
-      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
+      (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 ||
+       sdpa_full_large_hd_ok);
   const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
       (query_sequence_length <= key_sequence_length && do_causal);
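For reference, the routing predicate added in the `use_fallback` hunk can be read in isolation. The sketch below is illustrative only: it copies the condition and variable names from the diff into a hypothetical standalone function (`fused_full_attention_ok` and the `main` driver are not part of MLX).

```cpp
// Minimal sketch of the fused-vs-unfused routing added above.
// The predicate and variable names come from the diff; the wrapper
// function and main() are hypothetical, for illustration only.
#include <cstdio>

bool fused_full_attention_ok(
    int query_head_dim,
    int value_head_dim,
    int key_sequence_length) {
  // head_dim 192/256 is routed to the fused full-attention kernel only for
  // long sequences, where the unfused path would exceed Metal buffer limits.
  const bool sdpa_full_large_hd_ok =
      (query_head_dim == 192 || query_head_dim == 256) &&
      key_sequence_length > 16384;
  return query_head_dim == value_head_dim &&
      (query_head_dim == 64 || query_head_dim == 80 ||
       query_head_dim == 128 || sdpa_full_large_hd_ok);
}

int main() {
  // Short context at head_dim 192: stays on the unfused path (prints 0).
  std::printf("%d\n", fused_full_attention_ok(192, 192, 4096));
  // Long context at head_dim 192: eligible for the fused kernel (prints 1).
  std::printf("%d\n", fused_full_attention_ok(192, 192, 32768));
}
```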