diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 398df849fe..661914c83c 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -381,6 +381,7 @@ qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper) # Combine matmuls for QKV and MLP fused_qkv: False +fused_mla_lora_proj: False # Fuse MLA Q+KV LoRA up-projections (wq_a+wkv_a) into a single matmul. Requires q_lora_rank > 0. fused_mlp: False record_internal_nn_metrics: 0 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 952a1e9f8e..e84cbb9f76 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -435,6 +435,10 @@ class ModelArchitecture(BaseModel): ) normalization_layer_epsilon: float = Field(1.0e-05, description="Epsilon value for normalization layers.") fused_qkv: bool = Field(False, description="If supported, fuse the Q, K, and V projections.") + fused_mla_lora_proj: bool = Field( + False, + description="Fuse MLA Q and KV LoRA up-projections (wq_a + wkv_a) into a single matmul. Requires q_lora_rank > 0.", + ) attention_bias: bool = Field( False, description="If True, adds a learnable bias to the query, key, and value projections.", @@ -2558,6 +2562,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de if self.share_kv_projections and self.attention_type == "mla": raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.") + if self.fused_mla_lora_proj and self.q_lora_rank == 0: + raise ValueError("`fused_mla_lora_proj` requires `q_lora_rank > 0`.") + if self.fused_mla_lora_proj and self.attention_type != "mla": + raise ValueError("`fused_mla_lora_proj` is only valid with `attention_type='mla'`.") + # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS ici_map = { "diloco": self.ici_diloco_parallelism, diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index e0d6e4e9f1..f5605e1145 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -654,8 +654,44 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No shard_mode=self.config.shard_mode, rngs=self.rngs, ) + elif self.config.fused_mla_lora_proj: + # Fused Q+KV LoRA up-projection: single matmul (emb -> q_lora_rank + kv_lora_rank + rope_head_dim). + self.wq_kv_a = DenseGeneral( + in_features_shape=self.config.emb_dim, + out_features_shape=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim, + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "q_kv_lora_up_proj"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, + rngs=self.rngs, + ) + self.q_norm = RMSNorm( + num_features=self.q_lora_rank, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + epsilon=self.config.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=self.rngs, + ) + self.wq_b = DenseGeneral( + in_features_shape=self.q_lora_rank, + out_features_shape=(self.num_query_heads, self.qk_head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("q_lora", "q_heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, + rngs=self.rngs, + ) else: - # LoRA path for Q. + # Separate Q LoRA up-projection. self.wq_a = DenseGeneral( in_features_shape=self.config.emb_dim, out_features_shape=self.q_lora_rank, @@ -691,20 +727,21 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No rngs=self.rngs, ) - # KV LoRA path. - self.wkv_a = DenseGeneral( - in_features_shape=self.config.emb_dim, - out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim, - axis=-1, - kernel_init=self.kernel_init, - kernel_axes=("embed", "kv_lora_up_proj"), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - quant=self.quant, - matmul_precision=self.config.matmul_precision, - shard_mode=self.config.shard_mode, - rngs=self.rngs, - ) + if not self.config.fused_mla_lora_proj: + # KV LoRA up-projection. When fused, wq_kv_a handles both Q and KV. + self.wkv_a = DenseGeneral( + in_features_shape=self.config.emb_dim, + out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim, + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "kv_lora_up_proj"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, + rngs=self.rngs, + ) self.kv_norm = RMSNorm( num_features=self.kv_lora_rank, dtype=self.config.dtype, @@ -792,8 +829,11 @@ def mla_query_projection( if self.q_lora_rank == 0: q = self.query(inputs_q, out_sharding=query_sharding) else: - # LoRA path - low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank] + # LoRA path: inputs_q is either raw embeddings (unfused) or the pre-split Q slice (fused). + if not self.config.fused_mla_lora_proj: + low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank] + else: + low_rank_q = inputs_q # already the q_lora_rank slice from wq_kv_a split in __call__ low_rank_q = self.q_norm(low_rank_q) # RMSNorm on low rank low_rank_q = checkpoint_name(low_rank_q, "mla_q") q = self.wq_b(low_rank_q, out_sharding=query_sharding) # [B, L, n_heads, qk_head_dim] @@ -932,7 +972,10 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm else: wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ) wkva_out_sharding = create_sharding(self.mesh, wka_logical_name) - low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding) + if self.config.fused_mla_lora_proj: + low_rank = inputs # already the kv_lora_rank+rope_head_dim slice from wq_kv_a split in __call__ + else: + low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding) low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1) low_rank_main = self.kv_norm(low_rank_main) low_rank_main = checkpoint_name(low_rank_main, "mla_kv") @@ -1068,12 +1111,23 @@ def __call__( inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names) out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV) - query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode) - if self.config.force_q_layout: - query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) - key, value, cached_values = self.mla_kv_projection( - inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk - ) + if self.config.fused_mla_lora_proj: + # Single matmul for both Q and KV LoRA up-projections, then split. + fused_lora = self.wq_kv_a(inputs_q) + lora_q, lora_kv = jnp.split(fused_lora, [self.q_lora_rank], axis=-1) + query, low_rank_q = self.mla_query_projection(lora_q, inputs_positions, model_mode) + if self.config.force_q_layout: + query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) + key, value, cached_values = self.mla_kv_projection( + lora_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk + ) + else: + query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode) + if self.config.force_q_layout: + query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) + key, value, cached_values = self.mla_kv_projection( + inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk + ) query = checkpoint_name(query, "query_proj") key = checkpoint_name(key, "key_proj") value = checkpoint_name(value, "value_proj") diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index fc2c3c2d24..49a263a411 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -55,62 +55,52 @@ def test_one_block_mask(self): bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0]]) # pylint: disable=protected-access block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.asarray( - [ - [ - [False, False, False, False, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], - ] - ] - ) + expected_mask = np.asarray([[ + [False, False, False, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ]]) np.testing.assert_array_equal(block_mask, expected_mask) def test_two_blocks_mask(self): bidirectional_mask = np.asarray([[0, 1, 1, 0, 1, 1]]) # pylint: disable=protected-access block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.asarray( - [ - [ - [False, False, False, False, False, False], - [False, True, True, False, False, False], - [False, True, True, False, False, False], - [False, False, False, False, False, False], - [False, False, False, False, True, True], - [False, False, False, False, True, True], - ] - ] - ) + expected_mask = np.asarray([[ + [False, False, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + ]]) np.testing.assert_array_equal(block_mask, expected_mask) def test_batch_block_masks(self): bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]]) # pylint: disable=protected-access block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.asarray( + expected_mask = np.asarray([ [ - [ - [False, False, False, False, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], - ], - [ - [False, False, False, False, False, False], - [False, True, True, False, False, False], - [False, True, True, False, False, False], - [False, False, False, False, False, False], - [False, False, False, False, True, True], - [False, False, False, False, True, True], - ], - ] - ) + [False, False, False, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ], + [ + [False, False, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + ], + ]) np.testing.assert_array_equal(block_mask, expected_mask) def test_empty_block_mask(self): @@ -118,7 +108,12 @@ def test_empty_block_mask(self): # pylint: disable=protected-access block_mask = _make_bidirectional_block_mask(bidirectional_mask) expected_mask = np.zeros( - (bidirectional_mask.shape[0], bidirectional_mask.shape[1], bidirectional_mask.shape[1]), dtype=bool + ( + bidirectional_mask.shape[0], + bidirectional_mask.shape[1], + bidirectional_mask.shape[1], + ), + dtype=bool, ) np.testing.assert_array_equal(block_mask, expected_mask) @@ -127,7 +122,12 @@ def test_full_block_mask(self): # pylint: disable=protected-access block_mask = _make_bidirectional_block_mask(bidirectional_mask) expected_mask = np.ones( - (bidirectional_mask.shape[0], bidirectional_mask.shape[1], bidirectional_mask.shape[1]), dtype=bool + ( + bidirectional_mask.shape[0], + bidirectional_mask.shape[1], + bidirectional_mask.shape[1], + ), + dtype=bool, ) np.testing.assert_array_equal(block_mask, expected_mask) @@ -140,34 +140,24 @@ def test_combine_with_causal_mask(self): # pylint: disable=protected-access image_mask = _make_bidirectional_block_mask(bidirectional_mask) combined_mask = causal_mask | image_mask[:, None, None, ...] - expected_mask = np.asarray( - [ - [ - [ - [ - [True, False, False, False, False, False], - [True, True, True, True, False, False], - [True, True, True, True, False, False], - [True, True, True, True, False, False], - [True, True, True, True, True, False], - [True, True, True, True, True, True], - ] - ] - ], - [ - [ - [ - [True, False, False, False, False, False], - [True, True, True, False, False, False], - [True, True, True, False, False, False], - [True, True, True, True, False, False], - [True, True, True, True, True, True], - [True, True, True, True, True, True], - ] - ] - ], - ] - ) + expected_mask = np.asarray([ + [[[ + [True, False, False, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, False], + [True, True, True, True, True, True], + ]]], + [[[ + [True, False, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + ]]], + ]) np.testing.assert_array_equal(combined_mask, expected_mask) @@ -346,7 +336,10 @@ def get_data(self, dtype): decoder_segment_ids = jax.random.randint(self.rng, (self.global_batch_size, self.max_target_length), 0, 4) decoder_positions = jax.random.randint( - self.rng, (self.global_batch_size, self.max_target_length), 0, self.max_target_length + self.rng, + (self.global_batch_size, self.max_target_length), + 0, + self.max_target_length, ) return lnx, decoder_segment_ids, decoder_positions @@ -398,7 +391,13 @@ def test_autoregression(self): ) self.assertTrue( - jax.numpy.allclose(mha_prefill, mha_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False) + jax.numpy.allclose( + mha_prefill, + mha_full[:, :prefill_length, :], + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) ) for idx in range(prefill_length, decode_total_length): @@ -548,7 +547,13 @@ def tpu_kernel_attention_helper(self, num_kv_heads, share_kv_projections=False): ) self.assertTrue( - jax.numpy.allclose(mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False) + jax.numpy.allclose( + mha_generic_output, + mha_generic_flash_output, + rtol=1e-01, + atol=1e-01, + equal_nan=False, + ) ) def test_share_kv_projections(self): @@ -592,7 +597,10 @@ def test_share_kv_projections(self): model_mode=MODEL_MODE_TRAIN, ) - self.assertEqual(output_shared.shape, (self.global_batch_size, self.max_target_length, self.embed_dim)) + self.assertEqual( + output_shared.shape, + (self.global_batch_size, self.max_target_length, self.embed_dim), + ) # 3. Equivalence Check with standard unshared Attention attention_no_share = Attention( @@ -797,7 +805,13 @@ def test_tpu_flash_attention_context_parallel( mha_generic_flash_cp_output = jax.device_get(mha_generic_flash_cp_output) self.assertTrue( - jax.numpy.allclose(mha_generic_output, mha_generic_flash_cp_output, rtol=1e-01, atol=1e-01, equal_nan=False), + jax.numpy.allclose( + mha_generic_output, + mha_generic_flash_cp_output, + rtol=1e-01, + atol=1e-01, + equal_nan=False, + ), msg="Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", @@ -884,7 +898,11 @@ def _dot_product_attention( model_mode=MODEL_MODE_PREFILL, ) self.assertTrue( - jax.numpy.allclose(attention_w_layout_full[:, :prefill_length, :], attention_w_layout_prefill, equal_nan=False) + jax.numpy.allclose( + attention_w_layout_full[:, :prefill_length, :], + attention_w_layout_prefill, + equal_nan=False, + ) ) for idx in range(prefill_length, decode_total_length): @@ -903,7 +921,11 @@ def _dot_product_attention( self.assertTrue(attention_w_layout_full_this_idx.shape == attention_w_layout_idx.shape) self.assertTrue( jax.numpy.allclose( - attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=rtol, atol=atol, equal_nan=False + attention_w_layout_full_this_idx, + attention_w_layout_idx, + rtol=rtol, + atol=atol, + equal_nan=False, ) ) @@ -1008,7 +1030,9 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): ) self.assertTrue( jax.numpy.allclose( - attention_wo_reshape_q_full[:, :prefill_length, :], attention_wo_reshape_q_prefill, equal_nan=False + attention_wo_reshape_q_full[:, :prefill_length, :], + attention_wo_reshape_q_prefill, + equal_nan=False, ) ) @@ -1022,11 +1046,19 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): ) self.assertTrue( jax.numpy.allclose( - attention_w_reshape_q_full[:, :prefill_length, :], attention_w_reshape_q_prefill, equal_nan=False + attention_w_reshape_q_full[:, :prefill_length, :], + attention_w_reshape_q_prefill, + equal_nan=False, ) ) - self.assertTrue(jax.numpy.allclose(attention_wo_reshape_q_prefill, attention_w_reshape_q_prefill, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + attention_wo_reshape_q_prefill, + attention_w_reshape_q_prefill, + equal_nan=False, + ) + ) self.assertTrue( jax.numpy.allclose( attention_wo_reshape_q_full[:, :prefill_length, :], @@ -1051,7 +1083,11 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): self.assertTrue(attention_wo_reshape_q_full_this_idx.shape == attention_wo_reshape_q_idx.shape) self.assertTrue( jax.numpy.allclose( - attention_wo_reshape_q_full_this_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False + attention_wo_reshape_q_full_this_idx, + attention_wo_reshape_q_idx, + rtol=rtol, + atol=atol, + equal_nan=False, ) ) @@ -1067,12 +1103,22 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): self.assertTrue(attention_w_reshape_q_full_this_idx.shape == attention_w_reshape_q_idx.shape) self.assertTrue( jax.numpy.allclose( - attention_w_reshape_q_full_this_idx, attention_w_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False + attention_w_reshape_q_full_this_idx, + attention_w_reshape_q_idx, + rtol=rtol, + atol=atol, + equal_nan=False, ) ) self.assertTrue( - jax.numpy.allclose(attention_w_reshape_q_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False) + jax.numpy.allclose( + attention_w_reshape_q_idx, + attention_wo_reshape_q_idx, + rtol=rtol, + atol=atol, + equal_nan=False, + ) ) def test_sliding_window_attention(self): @@ -1147,7 +1193,10 @@ def test_sliding_window_attention(self): # Test if sliding window attention is different from global attention self.assertFalse( jax.numpy.allclose( - sliding_window_output.astype(jnp.bfloat16), global_attn_output.astype(jnp.bfloat16), rtol=1e-04, atol=1e-04 + sliding_window_output.astype(jnp.bfloat16), + global_attn_output.astype(jnp.bfloat16), + rtol=1e-04, + atol=1e-04, ) ) @@ -1198,7 +1247,10 @@ def test_sliding_window_attention(self): @pytest.mark.skip(reason="Requires `vllm-tpu` package which is not yet a MaxText dependency.") @pytest.mark.tpu_only - @mock.patch("tpu_inference.layers.jax.attention_interface.sharded_ragged_paged_attention", create=True) + @mock.patch( + "tpu_inference.layers.jax.attention_interface.sharded_ragged_paged_attention", + create=True, + ) def test_forward_serve_vllm(self, mock_sharded_ragged_paged_attention): """Tests the forward_serve_vllm method with mocked RPA attention.""" # Setup config for vLLM RPA @@ -1311,7 +1363,13 @@ def test_autoregression(self, rope_type): ) self.assertTrue( - jax.numpy.allclose(mla_prefill, mla_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False) + jax.numpy.allclose( + mla_prefill, + mla_full[:, :prefill_length, :], + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) ) for idx in range(prefill_length, decode_total_length): @@ -1342,12 +1400,16 @@ def test_projection_initialization(self): **attention_config_args, **extra_args, ) - dummy_inputs_q = jnp.ones( - (attention_cfg.global_batch_size_to_train_on, attention_cfg.max_target_length, attention_cfg.base_emb_dim) - ) - dummy_inputs_kv = jnp.ones( - (attention_cfg.global_batch_size_to_train_on, attention_cfg.max_target_length, attention_cfg.base_emb_dim) - ) + dummy_inputs_q = jnp.ones(( + attention_cfg.global_batch_size_to_train_on, + attention_cfg.max_target_length, + attention_cfg.base_emb_dim, + )) + dummy_inputs_kv = jnp.ones(( + attention_cfg.global_batch_size_to_train_on, + attention_cfg.max_target_length, + attention_cfg.base_emb_dim, + )) base_attention = Attention( config=attention_cfg, @@ -1365,10 +1427,22 @@ def test_projection_initialization(self): ) # 2. Assert that the base Attention layer HAS all its standard projections - self.assertTrue(hasattr(base_attention, "query"), "Base Attention should have 'query' projection.") - self.assertTrue(hasattr(base_attention, "key"), "Base Attention should have 'key' projection.") - self.assertTrue(hasattr(base_attention, "value"), "Base Attention should have 'value' projection.") - self.assertTrue(hasattr(base_attention, "out"), "Base Attention should have 'out' projection.") + self.assertTrue( + hasattr(base_attention, "query"), + "Base Attention should have 'query' projection.", + ) + self.assertTrue( + hasattr(base_attention, "key"), + "Base Attention should have 'key' projection.", + ) + self.assertTrue( + hasattr(base_attention, "value"), + "Base Attention should have 'value' projection.", + ) + self.assertTrue( + hasattr(base_attention, "out"), + "Base Attention should have 'out' projection.", + ) # 3. Initialize the MLA layer mla_config_args = self.config_arguments.copy() @@ -1390,6 +1464,104 @@ def test_projection_initialization(self): self.assertTrue(hasattr(mla_layer, "kv_norm"), "MLA should have 'kv_norm' projection.") self.assertTrue(hasattr(mla_layer, "out"), "MLA should have 'out' projection.") + def test_fused_mla_lora_proj_output_equivalence(self): + """Tests that fused_mla_lora_proj=True produces identical outputs to fused_mla_lora_proj=False.""" + extra_args = get_decoupled_parallelism_overrides() + + # Initialize the unfused model. + unfused_args = { + **self.config_arguments, + "fused_mla_lora_proj": False, + **extra_args, + } + cfg_unfused = pyconfig.initialize([sys.argv[0], get_test_config_path()], **unfused_args) + devices_array = maxtext_utils.create_device_mesh(cfg_unfused) + mesh = Mesh(devices_array, cfg_unfused.mesh_axes) + dummy_q = jnp.ones(( + cfg_unfused.global_batch_size_to_train_on, + cfg_unfused.max_target_length, + cfg_unfused.base_emb_dim, + )) + mla_unfused = MLA( + config=cfg_unfused, + num_query_heads=cfg_unfused.num_query_heads, + num_kv_heads=cfg_unfused.num_kv_heads, + head_dim=cfg_unfused.head_dim, + inputs_q_shape=dummy_q.shape, + inputs_kv_shape=dummy_q.shape, + max_target_length=cfg_unfused.max_target_length, + max_prefill_predict_length=cfg_unfused.max_prefill_predict_length, + mesh=mesh, + attention_kernel="dot_product", + dtype=cfg_unfused.dtype, + dropout_rate=cfg_unfused.dropout_rate, + attention_type=cfg_unfused.attention_type, + q_lora_rank=cfg_unfused.q_lora_rank, + kv_lora_rank=cfg_unfused.kv_lora_rank, + qk_nope_head_dim=cfg_unfused.qk_nope_head_dim, + qk_rope_head_dim=cfg_unfused.qk_rope_head_dim, + v_head_dim=cfg_unfused.v_head_dim, + model_mode=MODEL_MODE_TRAIN, + rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)), + ) + + # Initialize the fused model. + fused_args = { + **self.config_arguments, + "fused_mla_lora_proj": True, + **extra_args, + } + cfg_fused = pyconfig.initialize([sys.argv[0], get_test_config_path()], **fused_args) + mla_fused = MLA( + config=cfg_fused, + num_query_heads=cfg_fused.num_query_heads, + num_kv_heads=cfg_fused.num_kv_heads, + head_dim=cfg_fused.head_dim, + inputs_q_shape=dummy_q.shape, + inputs_kv_shape=dummy_q.shape, + max_target_length=cfg_fused.max_target_length, + max_prefill_predict_length=cfg_fused.max_prefill_predict_length, + mesh=mesh, + attention_kernel="dot_product", + dtype=cfg_fused.dtype, + dropout_rate=cfg_fused.dropout_rate, + attention_type=cfg_fused.attention_type, + q_lora_rank=cfg_fused.q_lora_rank, + kv_lora_rank=cfg_fused.kv_lora_rank, + qk_nope_head_dim=cfg_fused.qk_nope_head_dim, + qk_rope_head_dim=cfg_fused.qk_rope_head_dim, + v_head_dim=cfg_fused.v_head_dim, + model_mode=MODEL_MODE_TRAIN, + rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)), + ) + + # Make both models mathematically equivalent: + # fused wq_kv_a = concat(unfused wq_a, unfused wkv_a) along the output axis. + mla_fused.wq_kv_a.kernel.value = jnp.concatenate( + [mla_unfused.wq_a.kernel.value, mla_unfused.wkv_a.kernel.value], axis=-1 + ) + mla_fused.wq_b.kernel.value = mla_unfused.wq_b.kernel.value + mla_fused.q_norm.scale.value = mla_unfused.q_norm.scale.value + mla_fused.wkv_b.kernel.value = mla_unfused.wkv_b.kernel.value + mla_fused.kv_norm.scale.value = mla_unfused.kv_norm.scale.value + mla_fused.out.kernel.value = mla_unfused.out.kernel.value + + # Run both models on the same inputs and verify outputs are identical. + lnx, decoder_segment_ids, decoder_positions = self.get_data(cfg_unfused, cfg_unfused.dtype) + common_kwargs = { + "decoder_segment_ids": decoder_segment_ids, + "inputs_positions": decoder_positions, + "deterministic": True, + "model_mode": MODEL_MODE_TRAIN, + } + output_unfused, _ = mla_unfused(lnx, lnx, **common_kwargs) + output_fused, _ = mla_fused(lnx, lnx, **common_kwargs) + + self.assertTrue( + jax.numpy.allclose(output_unfused, output_fused, rtol=1e-05, atol=1e-05, equal_nan=False), + "fused_mla_lora_proj=True and fused_mla_lora_proj=False produced different outputs.", + ) + @parameterized.named_parameters( { "testcase_name": "cp_no_load_balance", @@ -1586,7 +1758,13 @@ def test_tpu_flash_attention_context_parallel( mla_generic_flash_cp_output = jax.device_get(mla_generic_flash_cp_output) self.assertTrue( - jax.numpy.allclose(mla_generic_output, mla_generic_flash_cp_output, rtol=1e-01, atol=1e-01, equal_nan=False), + jax.numpy.allclose( + mla_generic_output, + mla_generic_flash_cp_output, + rtol=1e-01, + atol=1e-01, + equal_nan=False, + ), msg="MLA Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", @@ -1725,7 +1903,11 @@ def get_structured_data(self, dtype): """get structured data for GDN (only requires hidden states)""" lnx = jax.random.normal( self.rng, - shape=(self.cfg.global_batch_size_to_train_on, self.cfg.max_target_length, self.cfg.emb_dim), + shape=( + self.cfg.global_batch_size_to_train_on, + self.cfg.max_target_length, + self.cfg.emb_dim, + ), dtype=dtype, ) return lnx @@ -1762,7 +1944,13 @@ def test_autoregression(self): ) self.assertTrue( - jax.numpy.allclose(gdn_prefill, gdn_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False) + jax.numpy.allclose( + gdn_prefill, + gdn_full[:, :prefill_length, :], + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) ) # 5. Autoregressive mode