@@ -654,8 +654,44 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
654654 shard_mode = self .config .shard_mode ,
655655 rngs = self .rngs ,
656656 )
657+ elif self .config .fused_mla_lora_proj :
658+ # Fused Q+KV LoRA up-projection: single matmul (emb_dim -> q_lora_rank + kv_lora_rank + qk_rope_head_dim).
659+ self .wq_kv_a = DenseGeneral (
660+ in_features_shape = self .config .emb_dim ,
661+ out_features_shape = self .q_lora_rank + self .kv_lora_rank + self .qk_rope_head_dim ,
662+ axis = - 1 ,
663+ kernel_init = self .kernel_init ,
664+ kernel_axes = ("embed" , "q_kv_lora_up_proj" ),
665+ dtype = self .dtype ,
666+ weight_dtype = self .weight_dtype ,
667+ quant = self .quant ,
668+ matmul_precision = self .config .matmul_precision ,
669+ shard_mode = self .config .shard_mode ,
670+ rngs = self .rngs ,
671+ )
672+ self .q_norm = RMSNorm (
673+ num_features = self .q_lora_rank ,
674+ dtype = self .config .dtype ,
675+ weight_dtype = self .config .weight_dtype ,
676+ epsilon = self .config .normalization_layer_epsilon ,
677+ kernel_axes = ("norm" ,),
678+ rngs = self .rngs ,
679+ )
680+ self .wq_b = DenseGeneral (
681+ in_features_shape = self .q_lora_rank ,
682+ out_features_shape = (self .num_query_heads , self .qk_head_dim ),
683+ axis = - 1 ,
684+ kernel_init = self .kernel_init ,
685+ kernel_axes = ("q_lora" , "q_heads" , "kv" ),
686+ dtype = self .dtype ,
687+ weight_dtype = self .weight_dtype ,
688+ quant = self .quant ,
689+ matmul_precision = self .config .matmul_precision ,
690+ shard_mode = self .config .shard_mode ,
691+ rngs = self .rngs ,
692+ )
657693 else :
658- # LoRA path for Q .
694+ # Separate Q LoRA up-projection .
659695 self .wq_a = DenseGeneral (
660696 in_features_shape = self .config .emb_dim ,
661697 out_features_shape = self .q_lora_rank ,
@@ -691,20 +727,21 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
691727 rngs = self .rngs ,
692728 )
693729
694- # KV LoRA path.
695- self .wkv_a = DenseGeneral (
696- in_features_shape = self .config .emb_dim ,
697- out_features_shape = self .kv_lora_rank + self .qk_rope_head_dim ,
698- axis = - 1 ,
699- kernel_init = self .kernel_init ,
700- kernel_axes = ("embed" , "kv_lora_up_proj" ),
701- dtype = self .dtype ,
702- weight_dtype = self .weight_dtype ,
703- quant = self .quant ,
704- matmul_precision = self .config .matmul_precision ,
705- shard_mode = self .config .shard_mode ,
706- rngs = self .rngs ,
707- )
730+ if not self .config .fused_mla_lora_proj :
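731+ # KV LoRA up-projection, created only when not fused; on the fused path wq_kv_a produces both the Q and KV slices. NOTE(review): wq_kv_a is built only in the fused elif branch above -- confirm fused_mla_lora_proj is rejected whenever that branch is not reached, otherwise neither projection exists.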
732+ self .wkv_a = DenseGeneral (
733+ in_features_shape = self .config .emb_dim ,
734+ out_features_shape = self .kv_lora_rank + self .qk_rope_head_dim ,
735+ axis = - 1 ,
736+ kernel_init = self .kernel_init ,
737+ kernel_axes = ("embed" , "kv_lora_up_proj" ),
738+ dtype = self .dtype ,
739+ weight_dtype = self .weight_dtype ,
740+ quant = self .quant ,
741+ matmul_precision = self .config .matmul_precision ,
742+ shard_mode = self .config .shard_mode ,
743+ rngs = self .rngs ,
744+ )
708745 self .kv_norm = RMSNorm (
709746 num_features = self .kv_lora_rank ,
710747 dtype = self .config .dtype ,
@@ -792,8 +829,11 @@ def mla_query_projection(
792829 if self .q_lora_rank == 0 :
793830 q = self .query (inputs_q , out_sharding = query_sharding )
794831 else :
795- # LoRA path
796- low_rank_q = self .wq_a (inputs_q , out_sharding = wqa_out_sharding ) # [B, L, q_lora_rank]
832+ # LoRA path: inputs_q is either raw embeddings (unfused) or the pre-split Q slice (fused).
833+ if not self .config .fused_mla_lora_proj :
834+ low_rank_q = self .wq_a (inputs_q , out_sharding = wqa_out_sharding ) # [B, L, q_lora_rank]
835+ else :
836+ low_rank_q = inputs_q # already the q_lora_rank slice from wq_kv_a split in __call__
797837 low_rank_q = self .q_norm (low_rank_q ) # RMSNorm on low rank
798838 low_rank_q = checkpoint_name (low_rank_q , "mla_q" )
799839 q = self .wq_b (low_rank_q , out_sharding = query_sharding ) # [B, L, n_heads, qk_head_dim]
@@ -932,7 +972,10 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
932972 else :
933973 wka_logical_name = (KV_BATCH , LENGTH_NO_EXP , KV_LORA_UP_PROJ )
934974 wkva_out_sharding = create_sharding (self .mesh , wka_logical_name )
935- low_rank = self .wkv_a (inputs , out_sharding = wkva_out_sharding )
975+ if self .config .fused_mla_lora_proj :
976+ low_rank = inputs # already the kv_lora_rank+rope_head_dim slice from wq_kv_a split in __call__
977+ else :
978+ low_rank = self .wkv_a (inputs , out_sharding = wkva_out_sharding )
936979 low_rank_main , low_rank_rope = jnp .split (low_rank , [self .kv_lora_rank ], axis = - 1 )
937980 low_rank_main = self .kv_norm (low_rank_main )
938981 low_rank_main = checkpoint_name (low_rank_main , "mla_kv" )
@@ -1068,12 +1111,23 @@ def __call__(
10681111 inputs_kv = self ._maybe_shard_with_logical (inputs_kv , self .input_axis_names )
10691112 out_logical_name = (BATCH , LENGTH_NO_EXP , HEAD , D_KV )
10701113
1071- query , low_rank_q = self .mla_query_projection (inputs_q , inputs_positions , model_mode )
1072- if self .config .force_q_layout :
1073- query = layout .with_layout_constraint (query , DLL (major_to_minor = (0 , 2 , 3 , 1 )))
1074- key , value , cached_values = self .mla_kv_projection (
1075- inputs_kv , inputs_positions , decoder_segment_ids , model_mode , previous_chunk
1076- )
1114+ if self .config .fused_mla_lora_proj :
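1115+ # Single matmul over inputs_q for both Q and KV LoRA up-projections, then split; the KV slice is derived from inputs_q, not inputs_kv (presumably assumes self-attention where both are the same tensor -- TODO confirm).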
1116+ fused_lora = self .wq_kv_a (inputs_q )
1117+ lora_q , lora_kv = jnp .split (fused_lora , [self .q_lora_rank ], axis = - 1 )
1118+ query , low_rank_q = self .mla_query_projection (lora_q , inputs_positions , model_mode )
1119+ if self .config .force_q_layout :
1120+ query = layout .with_layout_constraint (query , DLL (major_to_minor = (0 , 2 , 3 , 1 )))
1121+ key , value , cached_values = self .mla_kv_projection (
1122+ lora_kv , inputs_positions , decoder_segment_ids , model_mode , previous_chunk
1123+ )
1124+ else :
1125+ query , low_rank_q = self .mla_query_projection (inputs_q , inputs_positions , model_mode )
1126+ if self .config .force_q_layout :
1127+ query = layout .with_layout_constraint (query , DLL (major_to_minor = (0 , 2 , 3 , 1 )))
1128+ key , value , cached_values = self .mla_kv_projection (
1129+ inputs_kv , inputs_positions , decoder_segment_ids , model_mode , previous_chunk
1130+ )
10771131 query = checkpoint_name (query , "query_proj" )
10781132 key = checkpoint_name (key , "key_proj" )
10791133 value = checkpoint_name (value , "value_proj" )
0 commit comments