Skip to content

Commit 14328b4

Browse files
committed
Add fused_mla_lora_proj config flag for MLA LoRA up-projections
Fuses the Q and KV LoRA up-projections (wq_a + wkv_a) into a single matmul, wq_kv_a: emb → q_lora_rank + kv_lora_rank + rope_head_dim. This halves the number of kernel launches for the LoRA up-projection step (two matmuls become one). Enable via fused_mla_lora_proj: True; requires q_lora_rank > 0 and attention_type=mla. Modeled after the existing fused_qkv flag. Includes a unit test verifying that the fused and unfused paths produce numerically identical outputs given equivalent weights.
1 parent 37ded59 commit 14328b4

4 files changed

Lines changed: 2061 additions & 1651 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper)
381381

382382
# Combine matmuls for QKV and MLP
383383
fused_qkv: False
384+
fused_mla_lora_proj: False # Fuse MLA Q+KV LoRA up-projections (wq_a+wkv_a) into a single matmul. Requires q_lora_rank > 0.
384385
fused_mlp: False
385386

386387
record_internal_nn_metrics: 0

src/maxtext/configs/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,10 @@ class ModelArchitecture(BaseModel):
435435
)
436436
normalization_layer_epsilon: float = Field(1.0e-05, description="Epsilon value for normalization layers.")
437437
fused_qkv: bool = Field(False, description="If supported, fuse the Q, K, and V projections.")
438+
fused_mla_lora_proj: bool = Field(
439+
False,
440+
description="Fuse MLA Q and KV LoRA up-projections (wq_a + wkv_a) into a single matmul. Requires q_lora_rank > 0.",
441+
)
438442
attention_bias: bool = Field(
439443
False,
440444
description="If True, adds a learnable bias to the query, key, and value projections.",
@@ -2558,6 +2562,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25582562
if self.share_kv_projections and self.attention_type == "mla":
25592563
raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.")
25602564

2565+
if self.fused_mla_lora_proj and self.q_lora_rank == 0:
2566+
raise ValueError("`fused_mla_lora_proj` requires `q_lora_rank > 0`.")
2567+
if self.fused_mla_lora_proj and self.attention_type != "mla":
2568+
raise ValueError("`fused_mla_lora_proj` is only valid with `attention_type='mla'`.")
2569+
25612570
# I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
25622571
ici_map = {
25632572
"diloco": self.ici_diloco_parallelism,

src/maxtext/layers/attention_mla.py

Lines changed: 78 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -654,8 +654,44 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
654654
shard_mode=self.config.shard_mode,
655655
rngs=self.rngs,
656656
)
657+
elif self.config.fused_mla_lora_proj:
658+
# Fused Q+KV LoRA up-projection: single matmul (emb -> q_lora_rank + kv_lora_rank + rope_head_dim).
659+
self.wq_kv_a = DenseGeneral(
660+
in_features_shape=self.config.emb_dim,
661+
out_features_shape=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
662+
axis=-1,
663+
kernel_init=self.kernel_init,
664+
kernel_axes=("embed", "q_kv_lora_up_proj"),
665+
dtype=self.dtype,
666+
weight_dtype=self.weight_dtype,
667+
quant=self.quant,
668+
matmul_precision=self.config.matmul_precision,
669+
shard_mode=self.config.shard_mode,
670+
rngs=self.rngs,
671+
)
672+
self.q_norm = RMSNorm(
673+
num_features=self.q_lora_rank,
674+
dtype=self.config.dtype,
675+
weight_dtype=self.config.weight_dtype,
676+
epsilon=self.config.normalization_layer_epsilon,
677+
kernel_axes=("norm",),
678+
rngs=self.rngs,
679+
)
680+
self.wq_b = DenseGeneral(
681+
in_features_shape=self.q_lora_rank,
682+
out_features_shape=(self.num_query_heads, self.qk_head_dim),
683+
axis=-1,
684+
kernel_init=self.kernel_init,
685+
kernel_axes=("q_lora", "q_heads", "kv"),
686+
dtype=self.dtype,
687+
weight_dtype=self.weight_dtype,
688+
quant=self.quant,
689+
matmul_precision=self.config.matmul_precision,
690+
shard_mode=self.config.shard_mode,
691+
rngs=self.rngs,
692+
)
657693
else:
658-
# LoRA path for Q.
694+
# Separate Q LoRA up-projection.
659695
self.wq_a = DenseGeneral(
660696
in_features_shape=self.config.emb_dim,
661697
out_features_shape=self.q_lora_rank,
@@ -691,20 +727,21 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
691727
rngs=self.rngs,
692728
)
693729

694-
# KV LoRA path.
695-
self.wkv_a = DenseGeneral(
696-
in_features_shape=self.config.emb_dim,
697-
out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
698-
axis=-1,
699-
kernel_init=self.kernel_init,
700-
kernel_axes=("embed", "kv_lora_up_proj"),
701-
dtype=self.dtype,
702-
weight_dtype=self.weight_dtype,
703-
quant=self.quant,
704-
matmul_precision=self.config.matmul_precision,
705-
shard_mode=self.config.shard_mode,
706-
rngs=self.rngs,
707-
)
730+
if not self.config.fused_mla_lora_proj:
731+
# KV LoRA up-projection. When fused, wq_kv_a handles both Q and KV.
732+
self.wkv_a = DenseGeneral(
733+
in_features_shape=self.config.emb_dim,
734+
out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim,
735+
axis=-1,
736+
kernel_init=self.kernel_init,
737+
kernel_axes=("embed", "kv_lora_up_proj"),
738+
dtype=self.dtype,
739+
weight_dtype=self.weight_dtype,
740+
quant=self.quant,
741+
matmul_precision=self.config.matmul_precision,
742+
shard_mode=self.config.shard_mode,
743+
rngs=self.rngs,
744+
)
708745
self.kv_norm = RMSNorm(
709746
num_features=self.kv_lora_rank,
710747
dtype=self.config.dtype,
@@ -792,8 +829,11 @@ def mla_query_projection(
792829
if self.q_lora_rank == 0:
793830
q = self.query(inputs_q, out_sharding=query_sharding)
794831
else:
795-
# LoRA path
796-
low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank]
832+
# LoRA path: inputs_q is either raw embeddings (unfused) or the pre-split Q slice (fused).
833+
if not self.config.fused_mla_lora_proj:
834+
low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank]
835+
else:
836+
low_rank_q = inputs_q # already the q_lora_rank slice from wq_kv_a split in __call__
797837
low_rank_q = self.q_norm(low_rank_q) # RMSNorm on low rank
798838
low_rank_q = checkpoint_name(low_rank_q, "mla_q")
799839
q = self.wq_b(low_rank_q, out_sharding=query_sharding) # [B, L, n_heads, qk_head_dim]
@@ -932,7 +972,10 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
932972
else:
933973
wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
934974
wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
935-
low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
975+
if self.config.fused_mla_lora_proj:
976+
low_rank = inputs # already the kv_lora_rank+rope_head_dim slice from wq_kv_a split in __call__
977+
else:
978+
low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
936979
low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
937980
low_rank_main = self.kv_norm(low_rank_main)
938981
low_rank_main = checkpoint_name(low_rank_main, "mla_kv")
@@ -1068,12 +1111,23 @@ def __call__(
10681111
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)
10691112
out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV)
10701113

1071-
query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
1072-
if self.config.force_q_layout:
1073-
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
1074-
key, value, cached_values = self.mla_kv_projection(
1075-
inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
1076-
)
1114+
if self.config.fused_mla_lora_proj:
1115+
# Single matmul for both Q and KV LoRA up-projections, then split.
1116+
fused_lora = self.wq_kv_a(inputs_q)
1117+
lora_q, lora_kv = jnp.split(fused_lora, [self.q_lora_rank], axis=-1)
1118+
query, low_rank_q = self.mla_query_projection(lora_q, inputs_positions, model_mode)
1119+
if self.config.force_q_layout:
1120+
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
1121+
key, value, cached_values = self.mla_kv_projection(
1122+
lora_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
1123+
)
1124+
else:
1125+
query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
1126+
if self.config.force_q_layout:
1127+
query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1)))
1128+
key, value, cached_values = self.mla_kv_projection(
1129+
inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk
1130+
)
10771131
query = checkpoint_name(query, "query_proj")
10781132
key = checkpoint_name(key, "key_proj")
10791133
value = checkpoint_name(value, "value_proj")

0 commit comments

Comments
 (0)