Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper)

# Combine matmuls for QKV and MLP
fused_qkv: False
fused_mla_lora_proj: False # Fuse the MLA Q and KV LoRA up-projections (wq_a + wkv_a) into a single matmul. Requires attention_type='mla' and q_lora_rank > 0.
fused_mlp: False

record_internal_nn_metrics: 0
Expand Down
9 changes: 9 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,10 @@ class ModelArchitecture(BaseModel):
)
normalization_layer_epsilon: float = Field(1.0e-05, description="Epsilon value for normalization layers.")
fused_qkv: bool = Field(False, description="If supported, fuse the Q, K, and V projections.")
fused_mla_lora_proj: bool = Field(
False,
description="Fuse MLA Q and KV LoRA up-projections (wq_a + wkv_a) into a single matmul. Requires q_lora_rank > 0.",
)
attention_bias: bool = Field(
False,
description="If True, adds a learnable bias to the query, key, and value projections.",
Expand Down Expand Up @@ -2558,6 +2562,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
if self.share_kv_projections and self.attention_type == "mla":
raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.")

if self.fused_mla_lora_proj and self.q_lora_rank == 0:
raise ValueError("`fused_mla_lora_proj` requires `q_lora_rank > 0`.")
if self.fused_mla_lora_proj and self.attention_type != "mla":
raise ValueError("`fused_mla_lora_proj` is only valid with `attention_type='mla'`.")

# I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
ici_map = {
"diloco": self.ici_diloco_parallelism,
Expand Down
102 changes: 78 additions & 24 deletions src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,44 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
elif self.config.fused_mla_lora_proj:
# Fused Q+KV LoRA up-projection: single matmul (emb -> q_lora_rank + kv_lora_rank + rope_head_dim).
self.wq_kv_a = DenseGeneral(
in_features_shape=self.config.emb_dim,
out_features_shape=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=("embed", "q_kv_lora_up_proj"),
dtype=self.dtype,
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
self.q_norm = RMSNorm(
num_features=self.q_lora_rank,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
epsilon=self.config.normalization_layer_epsilon,
kernel_axes=("norm",),
rngs=self.rngs,
)
self.wq_b = DenseGeneral(
in_features_shape=self.q_lora_rank,
out_features_shape=(self.num_query_heads, self.qk_head_dim),
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=("q_lora", "q_heads", "kv"),
dtype=self.dtype,
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
else:
# LoRA path for Q.
# Separate Q LoRA up-projection.
self.wq_a = DenseGeneral(
in_features_shape=self.config.emb_dim,
out_features_shape=self.q_lora_rank,
Expand Down Expand Up @@ -691,20 +727,21 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
rngs=self.rngs,
)

# KV LoRA path.
self.wkv_a = DenseGeneral(
in_features_shape=self.config.emb_dim,
out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=("embed", "kv_lora_up_proj"),
dtype=self.dtype,
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
if not self.config.fused_mla_lora_proj:
# KV LoRA up-projection. When fused, wq_kv_a handles both Q and KV.
self.wkv_a = DenseGeneral(
in_features_shape=self.config.emb_dim,
out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=("embed", "kv_lora_up_proj"),
dtype=self.dtype,
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
self.kv_norm = RMSNorm(
num_features=self.kv_lora_rank,
dtype=self.config.dtype,
Expand Down Expand Up @@ -792,8 +829,11 @@ def mla_query_projection(
if self.q_lora_rank == 0:
q = self.query(inputs_q, out_sharding=query_sharding)
else:
# LoRA path
low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank]
# LoRA path: inputs_q is either raw embeddings (unfused) or the pre-split Q slice (fused).
if not self.config.fused_mla_lora_proj:
low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank]
else:
low_rank_q = inputs_q # already the q_lora_rank slice from wq_kv_a split in __call__
low_rank_q = self.q_norm(low_rank_q) # RMSNorm on low rank
low_rank_q = checkpoint_name(low_rank_q, "mla_q")
q = self.wq_b(low_rank_q, out_sharding=query_sharding) # [B, L, n_heads, qk_head_dim]
Expand Down Expand Up @@ -932,7 +972,10 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
else:
wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
if self.config.fused_mla_lora_proj:
low_rank = inputs # already the kv_lora_rank+rope_head_dim slice from wq_kv_a split in __call__
else:
low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
low_rank_main = self.kv_norm(low_rank_main)
low_rank_main = checkpoint_name(low_rank_main, "mla_kv")
Expand Down Expand Up @@ -1068,12 +1111,23 @@ def __call__(
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)

query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
if self.config.force_q_layout:
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
key, value, cached_values = self.mla_kv_projection(
inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
)
if self.config.fused_mla_lora_proj:
# Single matmul for both Q and KV LoRA up-projections, then split.
fused_lora = self.wq_kv_a(inputs_q)
lora_q, lora_kv = jnp.split(fused_lora, [self.q_lora_rank], axis=-1)
query, low_rank_q = self.mla_query_projection(lora_q, inputs_positions, model_mode)
if self.config.force_q_layout:
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
key, value, cached_values = self.mla_kv_projection(
lora_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
)
else:
query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
if self.config.force_q_layout:
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
key, value, cached_values = self.mla_kv_projection(
inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
)
query = checkpoint_name(query, "query_proj")
key = checkpoint_name(key, "key_proj")
value = checkpoint_name(value, "value_proj")
Expand Down
Loading