
Commit b0ed511

fhoushmand authored and Google-ML-Automation committed
Force layout on Q for MLA.
This helps in non-Pallas splash attention and removes copies when num_heads is 128.

Major-to-minor layout, original query: 1, 2, 192, 1024, 128
Major-to-minor layout, attention expectation: 1, 2, 128, 192, 1024

PiperOrigin-RevId: 855382451
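For intuition, "major to minor" describes the physical memory order of the query's axes, not its logical shape. A quick NumPy sketch of the reordering (shapes here are illustrative, not the ones above; the actual constraint is applied in attention_mla.py below):

```python
import numpy as np

# Row-major [B, S, N, H]: major_to_minor is (0, 1, 2, 3), so H varies fastest.
q = np.zeros((2, 1024, 128, 128), dtype=np.float32)

# Forcing major_to_minor=(0, 2, 3, 1) corresponds to a physical [B, N, H, S]
# order: the same elements, with SEQ as the fastest-varying axis in memory.
q_physical = np.ascontiguousarray(q.transpose(0, 2, 3, 1))
print(q_physical.shape)  # (2, 128, 128, 1024)
```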
1 parent 4bcee99 commit b0ed511

4 files changed: 27 additions & 8 deletions


src/MaxText/configs/base.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -1016,3 +1016,5 @@ use_jax_splash: false
 vllm_hf_config_path: ""
 # JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
 vllm_additional_config: {}
+# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
+force_q_layout: false
```

src/MaxText/configs/types.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -480,6 +480,7 @@ class Attention(BaseModel):
   enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.")
   use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
   use_jax_splash: bool = Field(False, description="Whether to use jax splash attention.")
+  force_q_layout: bool = Field(False, description="Force the Q layout")
 
 
 class MoBa(BaseModel):
@@ -2231,6 +2232,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "Muon dimension numbers haven't been tested for this model. Run this command first: "
         f"`python3 -m MaxText.muon_utils {self.model_name} True`"
     )
+    if self.force_q_layout and not self.use_jax_splash:
+      raise ValueError(
+          "`force_q_layout` can only be true if `use_jax_splash` is also true."
+      )
 
   # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
   # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
```
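A minimal runnable sketch of the new guard, assuming a Pydantic v2 model; `AttentionConfig` is hypothetical — in MaxText the fields live on the larger config models in types.py:

```python
from pydantic import BaseModel, Field, model_validator

class AttentionConfig(BaseModel):
  use_jax_splash: bool = Field(False, description="Whether to use jax splash attention.")
  force_q_layout: bool = Field(False, description="Force the Q layout")

  @model_validator(mode="after")
  def _check_force_q_layout(self):
    # Mirrors the validation added above: the layout override only makes
    # sense on the jax splash attention path.
    if self.force_q_layout and not self.use_jax_splash:
      raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")
    return self

AttentionConfig(use_jax_splash=True, force_q_layout=True)  # OK
# AttentionConfig(force_q_layout=True)  # raises ValueError
```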

src/MaxText/kernels/jax_flash_attention.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -107,14 +107,14 @@ def flash_attention_block_masked(
   # `l` is initialized to 0 since no blocks have been processed yet and the sum
   # is 0.
   l = jnp.zeros(
-      (batch_size, num_kv_heads, q_groups, q_seq_len), dtype=jnp.float32
+      (batch_size, num_kv_heads, q_groups, q_seq_len), dtype=data_type
   )
   # `m` is initialized to the mask_value so that the first block's maximum logit
   # correctly becomes the running maximum.
   m = jnp.full(
       (batch_size, num_kv_heads, q_groups, q_seq_len),
       mask_value,
-      dtype=jnp.float32,
+      dtype=data_type,
   )
 
   output = jnp.zeros(
@@ -138,11 +138,12 @@ def outer_loop_body(j, carried):
     def inner_loop_body(i, carried_inner):
       output, l, m = carried_inner
 
+      # let's get the slice of Q in N dimension
+      q_slice = jax.lax.dynamic_slice_in_dim(q, i * block_q, block_q, axis=-2)
+
       # Calculates the attention computation (Q@K.T)@V with online softmax for
       # the current query and key/value blocks.
       def compute_attention_block(output, l, m):
-        # let's get the slice of Q in N dimension
-        q_slice = jax.lax.dynamic_slice_in_dim(q, i * block_q, block_q, axis=-2)
         output_i_slice = jax.lax.dynamic_slice_in_dim(
             output, i * block_q, block_q, axis=-2
         )
@@ -156,7 +157,7 @@ def compute_attention_block(output, l, m):
             "bxhqc,bxkc->bxhqk",
             q_slice,
             k_j_slice,
-            preferred_element_type=jnp.float32,
+            preferred_element_type=data_type,
         )
         full_mask_i_j_slice = jax.lax.dynamic_slice(
             mask_full,
@@ -193,7 +194,7 @@ def compute_attention_block(output, l, m):
 
         output_i_slice_new = numerator / divider
         output = jax.lax.dynamic_update_index_in_dim(
-            output, output_i_slice_new.astype(data_type), i * block_q, axis=-2
+            output, output_i_slice_new, i * block_q, axis=-2
         )
         l = jax.lax.dynamic_update_index_in_dim(
             l, l_i_new, i * block_q, axis=-1
```
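Two things change in this file: the running sum `l`, running max `m`, and the Q@K.T product now use the kernel's `data_type` instead of hard-coded `jnp.float32` (so the final `astype(data_type)` cast is dropped), and the Q block slice is hoisted out of `compute_attention_block` so it is taken once per inner-loop step. A standalone sketch of the slicing, with shapes following the kernel's `bxhqc` einsum convention (sizes are illustrative):

```python
import jax
import jax.numpy as jnp

batch, kv_heads, q_groups, q_seq_len, head_dim = 1, 2, 4, 1024, 128
block_q = 256
q = jnp.zeros((batch, kv_heads, q_groups, q_seq_len, head_dim), dtype=jnp.bfloat16)

# Slice the i-th block of Q along the sequence axis (axis=-2), as in the kernel.
i = 3
q_slice = jax.lax.dynamic_slice_in_dim(q, i * block_q, block_q, axis=-2)
print(q_slice.shape)  # (1, 2, 4, 256, 128)
```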

src/MaxText/layers/attention_mla.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -16,10 +16,17 @@
 
 import math
 from typing import Any, Optional, Tuple
-
+import jax
 from jax.ad_checkpoint import checkpoint_name
-from jax.sharding import Mesh, NamedSharding
+from jax.experimental import layout
 import jax.numpy as jnp
+from jax.sharding import Mesh, NamedSharding
+
+Layout = layout.Format
+if jax.__version_info__ >= (0, 6, 3):
+  DLL = layout.Layout
+else:
+  DLL = layout.DeviceLocalLayout  # type: ignore
 
 from flax import nnx
@@ -738,6 +745,10 @@ def __call__(
     out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
 
     query = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
+    if self.config.force_q_layout:
+      query = layout.with_layout_constraint(
+          query, DLL(major_to_minor=(0, 2, 3, 1))
+      )
     key, value, cached_values = self.mla_kv_projection(
         inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
     )
```