From 3b46b3c0d78a4f5e2d2d3df7ea3092471ee62e47 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Thu, 19 Mar 2026 12:03:40 -0700 Subject: [PATCH 1/2] Add fused_mla_lora_proj config flag for MLA LoRA up-projections MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fuses the Q and KV LoRA up-projections (wq_a + wkv_a) into a single matmul (wq_kv_a: emb → q_lora_rank + kv_lora_rank + rope_head_dim), halving the number of kernel launches for the LoRA up-projection step. Enabled via fused_mla_lora_proj: True (requires q_lora_rank > 0 and attention_type=mla). Modeled after the existing fused_qkv flag. Includes a unit test verifying that fused and unfused paths produce numerically identical outputs given equivalent weights. --- src/maxtext/configs/base.yml | 1 + src/maxtext/configs/types.py | 9 + src/maxtext/layers/attention_mla.py | 102 +- tests/unit/attention_test.py | 3600 +++++++++++++++------- 4 files changed, 2061 insertions(+), 1651 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 398df849fe..661914c83c 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -381,6 +381,7 @@ qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper) # Combine matmuls for QKV and MLP fused_qkv: False +fused_mla_lora_proj: False # Fuse MLA Q+KV LoRA up-projections (wq_a+wkv_a) into a single matmul. Requires q_lora_rank > 0 and attention_type='mla'. 
fused_mlp: False record_internal_nn_metrics: 0 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 952a1e9f8e..e84cbb9f76 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -435,6 +435,10 @@ class ModelArchitecture(BaseModel): ) normalization_layer_epsilon: float = Field(1.0e-05, description="Epsilon value for normalization layers.") fused_qkv: bool = Field(False, description="If supported, fuse the Q, K, and V projections.") + fused_mla_lora_proj: bool = Field( + False, + description="Fuse MLA Q and KV LoRA up-projections (wq_a + wkv_a) into a single matmul. Requires q_lora_rank > 0.", + ) attention_bias: bool = Field( False, description="If True, adds a learnable bias to the query, key, and value projections.", @@ -2558,6 +2562,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de if self.share_kv_projections and self.attention_type == "mla": raise ValueError("`share_kv_projections` is not compatible with `attention_type='mla'`.") + if self.fused_mla_lora_proj and self.q_lora_rank == 0: + raise ValueError("`fused_mla_lora_proj` requires `q_lora_rank > 0`.") + if self.fused_mla_lora_proj and self.attention_type != "mla": + raise ValueError("`fused_mla_lora_proj` is only valid with `attention_type='mla'`.") + # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS ici_map = { "diloco": self.ici_diloco_parallelism, diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index e0d6e4e9f1..f5605e1145 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -654,8 +654,44 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No shard_mode=self.config.shard_mode, rngs=self.rngs, ) + elif self.config.fused_mla_lora_proj: + # Fused Q+KV LoRA up-projection: single matmul (emb -> q_lora_rank + kv_lora_rank + rope_head_dim). 
+ self.wq_kv_a = DenseGeneral( + in_features_shape=self.config.emb_dim, + out_features_shape=self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim, + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "q_kv_lora_up_proj"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, + rngs=self.rngs, + ) + self.q_norm = RMSNorm( + num_features=self.q_lora_rank, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + epsilon=self.config.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=self.rngs, + ) + self.wq_b = DenseGeneral( + in_features_shape=self.q_lora_rank, + out_features_shape=(self.num_query_heads, self.qk_head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("q_lora", "q_heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, + rngs=self.rngs, + ) else: - # LoRA path for Q. + # Separate Q LoRA up-projection. self.wq_a = DenseGeneral( in_features_shape=self.config.emb_dim, out_features_shape=self.q_lora_rank, @@ -691,20 +727,21 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No rngs=self.rngs, ) - # KV LoRA path. - self.wkv_a = DenseGeneral( - in_features_shape=self.config.emb_dim, - out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim, - axis=-1, - kernel_init=self.kernel_init, - kernel_axes=("embed", "kv_lora_up_proj"), - dtype=self.dtype, - weight_dtype=self.weight_dtype, - quant=self.quant, - matmul_precision=self.config.matmul_precision, - shard_mode=self.config.shard_mode, - rngs=self.rngs, - ) + if not self.config.fused_mla_lora_proj: + # KV LoRA up-projection. When fused, wq_kv_a handles both Q and KV. 
+ self.wkv_a = DenseGeneral( + in_features_shape=self.config.emb_dim, + out_features_shape=self.kv_lora_rank + self.qk_rope_head_dim, + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "kv_lora_up_proj"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, + rngs=self.rngs, + ) self.kv_norm = RMSNorm( num_features=self.kv_lora_rank, dtype=self.config.dtype, @@ -792,8 +829,11 @@ def mla_query_projection( if self.q_lora_rank == 0: q = self.query(inputs_q, out_sharding=query_sharding) else: - # LoRA path - low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank] + # LoRA path: inputs_q is either raw embeddings (unfused) or the pre-split Q slice (fused). + if not self.config.fused_mla_lora_proj: + low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank] + else: + low_rank_q = inputs_q # already the q_lora_rank slice from wq_kv_a split in __call__ low_rank_q = self.q_norm(low_rank_q) # RMSNorm on low rank low_rank_q = checkpoint_name(low_rank_q, "mla_q") q = self.wq_b(low_rank_q, out_sharding=query_sharding) # [B, L, n_heads, qk_head_dim] @@ -932,7 +972,10 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm else: wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ) wkva_out_sharding = create_sharding(self.mesh, wka_logical_name) - low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding) + if self.config.fused_mla_lora_proj: + low_rank = inputs # already the kv_lora_rank+rope_head_dim slice from wq_kv_a split in __call__ + else: + low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding) low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1) low_rank_main = self.kv_norm(low_rank_main) low_rank_main = checkpoint_name(low_rank_main, "mla_kv") @@ -1068,12 +1111,23 @@ def __call__( inputs_kv = 
self._maybe_shard_with_logical(inputs_kv, self.input_axis_names) out_logical_name = (BATCH, LENGTH_NO_EXP, HEAD, D_KV) - query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode) - if self.config.force_q_layout: - query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) - key, value, cached_values = self.mla_kv_projection( - inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk - ) + if self.config.fused_mla_lora_proj: + # Single matmul for both Q and KV LoRA up-projections, then split. + fused_lora = self.wq_kv_a(inputs_q) + lora_q, lora_kv = jnp.split(fused_lora, [self.q_lora_rank], axis=-1) + query, low_rank_q = self.mla_query_projection(lora_q, inputs_positions, model_mode) + if self.config.force_q_layout: + query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) + key, value, cached_values = self.mla_kv_projection( + lora_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk + ) + else: + query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode) + if self.config.force_q_layout: + query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) + key, value, cached_values = self.mla_kv_projection( + inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk + ) query = checkpoint_name(query, "query_proj") key = checkpoint_name(key, "key_proj") value = checkpoint_name(value, "value_proj") diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index fc2c3c2d24..501abe27c7 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -49,1736 +49,2082 @@ class BidirectionalBlockMaskTest(unittest.TestCase): - """Test for make_bidirectional_block_mask.""" - - def test_one_block_mask(self): - bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0]]) - # pylint: disable=protected-access - block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask 
= np.asarray( - [ + """Test for make_bidirectional_block_mask.""" + + def test_one_block_mask(self): + bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0]]) + # pylint: disable=protected-access + block_mask = _make_bidirectional_block_mask(bidirectional_mask) + expected_mask = np.asarray( [ - [False, False, False, False, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], + [ + [False, False, False, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ] ] - ] - ) - np.testing.assert_array_equal(block_mask, expected_mask) - - def test_two_blocks_mask(self): - bidirectional_mask = np.asarray([[0, 1, 1, 0, 1, 1]]) - # pylint: disable=protected-access - block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.asarray( - [ + ) + np.testing.assert_array_equal(block_mask, expected_mask) + + def test_two_blocks_mask(self): + bidirectional_mask = np.asarray([[0, 1, 1, 0, 1, 1]]) + # pylint: disable=protected-access + block_mask = _make_bidirectional_block_mask(bidirectional_mask) + expected_mask = np.asarray( [ - [False, False, False, False, False, False], - [False, True, True, False, False, False], - [False, True, True, False, False, False], - [False, False, False, False, False, False], - [False, False, False, False, True, True], - [False, False, False, False, True, True], + [ + [False, False, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + ] ] - ] - ) - 
np.testing.assert_array_equal(block_mask, expected_mask) - - def test_batch_block_masks(self): - bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]]) - # pylint: disable=protected-access - block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.asarray( - [ - [ - [False, False, False, False, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], - ], + ) + np.testing.assert_array_equal(block_mask, expected_mask) + + def test_batch_block_masks(self): + bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]]) + # pylint: disable=protected-access + block_mask = _make_bidirectional_block_mask(bidirectional_mask) + expected_mask = np.asarray( [ - [False, False, False, False, False, False], - [False, True, True, False, False, False], - [False, True, True, False, False, False], - [False, False, False, False, False, False], - [False, False, False, False, True, True], - [False, False, False, False, True, True], - ], - ] - ) - np.testing.assert_array_equal(block_mask, expected_mask) - - def test_empty_block_mask(self): - bidirectional_mask = np.asarray([[0, 0, 0, 0, 0, 0]]) - # pylint: disable=protected-access - block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.zeros( - (bidirectional_mask.shape[0], bidirectional_mask.shape[1], bidirectional_mask.shape[1]), dtype=bool - ) - np.testing.assert_array_equal(block_mask, expected_mask) - - def test_full_block_mask(self): - bidirectional_mask = np.asarray([[1, 1, 1, 1, 1, 1]]) - # pylint: disable=protected-access - block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.ones( - (bidirectional_mask.shape[0], bidirectional_mask.shape[1], bidirectional_mask.shape[1]), dtype=bool - ) - np.testing.assert_array_equal(block_mask, 
expected_mask) - - def test_combine_with_causal_mask(self): - seq_len = 6 - row_ids = np.arange(seq_len, dtype=np.int32)[:, None] - col_ids = np.arange(seq_len, dtype=np.int32)[None, :] - causal_mask = (col_ids <= row_ids)[None, None, None, :, :] - bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]]) - # pylint: disable=protected-access - image_mask = _make_bidirectional_block_mask(bidirectional_mask) - combined_mask = causal_mask | image_mask[:, None, None, ...] - expected_mask = np.asarray( - [ + [ + [False, False, False, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ], + [ + [False, False, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + ], + ] + ) + np.testing.assert_array_equal(block_mask, expected_mask) + + def test_empty_block_mask(self): + bidirectional_mask = np.asarray([[0, 0, 0, 0, 0, 0]]) + # pylint: disable=protected-access + block_mask = _make_bidirectional_block_mask(bidirectional_mask) + expected_mask = np.zeros( + ( + bidirectional_mask.shape[0], + bidirectional_mask.shape[1], + bidirectional_mask.shape[1], + ), + dtype=bool, + ) + np.testing.assert_array_equal(block_mask, expected_mask) + + def test_full_block_mask(self): + bidirectional_mask = np.asarray([[1, 1, 1, 1, 1, 1]]) + # pylint: disable=protected-access + block_mask = _make_bidirectional_block_mask(bidirectional_mask) + expected_mask = np.ones( + ( + bidirectional_mask.shape[0], + bidirectional_mask.shape[1], + bidirectional_mask.shape[1], + ), + dtype=bool, + ) + np.testing.assert_array_equal(block_mask, expected_mask) + + def test_combine_with_causal_mask(self): + 
seq_len = 6 + row_ids = np.arange(seq_len, dtype=np.int32)[:, None] + col_ids = np.arange(seq_len, dtype=np.int32)[None, :] + causal_mask = (col_ids <= row_ids)[None, None, None, :, :] + bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]]) + # pylint: disable=protected-access + image_mask = _make_bidirectional_block_mask(bidirectional_mask) + combined_mask = causal_mask | image_mask[:, None, None, ...] + expected_mask = np.asarray( [ [ [ - [True, False, False, False, False, False], - [True, True, True, True, False, False], - [True, True, True, True, False, False], - [True, True, True, True, False, False], - [True, True, True, True, True, False], - [True, True, True, True, True, True], + [ + [True, False, False, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, False], + [True, True, True, True, True, True], + ] ] - ] - ], - [ + ], [ [ - [True, False, False, False, False, False], - [True, True, True, False, False, False], - [True, True, True, False, False, False], - [True, True, True, True, False, False], - [True, True, True, True, True, True], - [True, True, True, True, True, True], + [ + [True, False, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + ] ] - ] - ], - ] - ) - np.testing.assert_array_equal(combined_mask, expected_mask) + ], + ] + ) + np.testing.assert_array_equal(combined_mask, expected_mask) class ChunkedCausalMaskTest(unittest.TestCase): - """Test for the ChunkedCausalMask.""" - - def test_basic_chunking(self): - """Tests the mask with a simple chunk size.""" - seq_len = 8 - chunk_size = 4 - mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) - - # Manually compute the expected mask - # Causal 
within chunks (0-3, 4-7) - expected_mask = np.zeros((seq_len, seq_len), dtype=np.bool_) - for r in range(seq_len): - for c in range(seq_len): - q_chunk = r // chunk_size - kv_chunk = c // chunk_size - if q_chunk == kv_chunk and r >= c: - expected_mask[r, c] = True - - # Get the actual mask by slicing - actual_mask = mask[:, :] - - np.testing.assert_array_equal(actual_mask, expected_mask) - # Make sure _generate_chunk_attention_mask also produces the same mask - # pylint: disable=protected-access - actual_mask = _generate_chunk_attention_mask(mask_shape=mask.shape, chunk_size=chunk_size) - np.testing.assert_array_equal(actual_mask, expected_mask) - - def test_full_length_chunk(self): - """Tests when chunk size equals sequence length (should be causal).""" - seq_len = 6 - chunk_size = 6 # Same as seq_len - mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) - - # Expected mask is a standard lower triangular causal mask - expected_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.bool_)) - - actual_mask = mask[:, :] - np.testing.assert_array_equal(actual_mask, expected_mask) - # Make sure _generate_chunk_attention_mask also produces the same mask - # pylint: disable=protected-access - actual_mask = _generate_chunk_attention_mask(mask_shape=mask.shape, chunk_size=chunk_size) - np.testing.assert_array_equal(actual_mask, expected_mask) - - def test_single_token_chunk(self): - """Tests when chunk size is 1 (only attend to self).""" - seq_len = 5 - chunk_size = 1 - mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) - - # Expected mask is just the identity matrix - expected_mask = np.eye(seq_len, dtype=np.bool_) - - actual_mask = mask[:, :] - np.testing.assert_array_equal(actual_mask, expected_mask) - # Make sure _generate_chunk_attention_mask also produces the same mask - # pylint: disable=protected-access - actual_mask = _generate_chunk_attention_mask(mask_shape=mask.shape, chunk_size=chunk_size) - 
np.testing.assert_array_equal(actual_mask, expected_mask) - - def test_non_square_shape(self): - """Tests with different query and key sequence lengths.""" - q_len = 6 - kv_len = 8 - chunk_size = 3 - mask = ChunkedCausalMask(shape=(q_len, kv_len), chunk_size=chunk_size) - - # Manually compute expected mask - expected_mask = np.zeros((q_len, kv_len), dtype=np.bool_) - for r in range(q_len): - for c in range(kv_len): - q_chunk = r // chunk_size - kv_chunk = c // chunk_size - if q_chunk == kv_chunk and r >= c: - expected_mask[r, c] = True - - actual_mask = mask[:, :] - np.testing.assert_array_equal(actual_mask, expected_mask) - # Make sure _generate_chunk_attention_mask also produces the same mask - # pylint: disable=protected-access - actual_mask = _generate_chunk_attention_mask(mask_shape=mask.shape, chunk_size=chunk_size) - np.testing.assert_array_equal(actual_mask, expected_mask) - - def test_value_error_on_zero_chunk_size(self): - """Tests that a ValueError is raised for chunk_size <= 0.""" - with self.assertRaises(ValueError): - ChunkedCausalMask(shape=(4, 4), chunk_size=0) - with self.assertRaises(ValueError): - ChunkedCausalMask(shape=(4, 4), chunk_size=-2) - with self.assertRaises(ValueError): - # pylint: disable=protected-access - _generate_chunk_attention_mask(mask_shape=(4, 4), chunk_size=0) + """Test for the ChunkedCausalMask.""" + + def test_basic_chunking(self): + """Tests the mask with a simple chunk size.""" + seq_len = 8 + chunk_size = 4 + mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) + + # Manually compute the expected mask + # Causal within chunks (0-3, 4-7) + expected_mask = np.zeros((seq_len, seq_len), dtype=np.bool_) + for r in range(seq_len): + for c in range(seq_len): + q_chunk = r // chunk_size + kv_chunk = c // chunk_size + if q_chunk == kv_chunk and r >= c: + expected_mask[r, c] = True + + # Get the actual mask by slicing + actual_mask = mask[:, :] + + np.testing.assert_array_equal(actual_mask, expected_mask) + # 
Make sure _generate_chunk_attention_mask also produces the same mask + # pylint: disable=protected-access + actual_mask = _generate_chunk_attention_mask( + mask_shape=mask.shape, chunk_size=chunk_size + ) + np.testing.assert_array_equal(actual_mask, expected_mask) + + def test_full_length_chunk(self): + """Tests when chunk size equals sequence length (should be causal).""" + seq_len = 6 + chunk_size = 6 # Same as seq_len + mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) + + # Expected mask is a standard lower triangular causal mask + expected_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.bool_)) + + actual_mask = mask[:, :] + np.testing.assert_array_equal(actual_mask, expected_mask) + # Make sure _generate_chunk_attention_mask also produces the same mask + # pylint: disable=protected-access + actual_mask = _generate_chunk_attention_mask( + mask_shape=mask.shape, chunk_size=chunk_size + ) + np.testing.assert_array_equal(actual_mask, expected_mask) + + def test_single_token_chunk(self): + """Tests when chunk size is 1 (only attend to self).""" + seq_len = 5 + chunk_size = 1 + mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) + + # Expected mask is just the identity matrix + expected_mask = np.eye(seq_len, dtype=np.bool_) + + actual_mask = mask[:, :] + np.testing.assert_array_equal(actual_mask, expected_mask) + # Make sure _generate_chunk_attention_mask also produces the same mask + # pylint: disable=protected-access + actual_mask = _generate_chunk_attention_mask( + mask_shape=mask.shape, chunk_size=chunk_size + ) + np.testing.assert_array_equal(actual_mask, expected_mask) + + def test_non_square_shape(self): + """Tests with different query and key sequence lengths.""" + q_len = 6 + kv_len = 8 + chunk_size = 3 + mask = ChunkedCausalMask(shape=(q_len, kv_len), chunk_size=chunk_size) + + # Manually compute expected mask + expected_mask = np.zeros((q_len, kv_len), dtype=np.bool_) + for r in range(q_len): + for c in 
range(kv_len): + q_chunk = r // chunk_size + kv_chunk = c // chunk_size + if q_chunk == kv_chunk and r >= c: + expected_mask[r, c] = True + + actual_mask = mask[:, :] + np.testing.assert_array_equal(actual_mask, expected_mask) + # Make sure _generate_chunk_attention_mask also produces the same mask + # pylint: disable=protected-access + actual_mask = _generate_chunk_attention_mask( + mask_shape=mask.shape, chunk_size=chunk_size + ) + np.testing.assert_array_equal(actual_mask, expected_mask) + + def test_value_error_on_zero_chunk_size(self): + """Tests that a ValueError is raised for chunk_size <= 0.""" + with self.assertRaises(ValueError): + ChunkedCausalMask(shape=(4, 4), chunk_size=0) + with self.assertRaises(ValueError): + ChunkedCausalMask(shape=(4, 4), chunk_size=-2) + with self.assertRaises(ValueError): + # pylint: disable=protected-access + _generate_chunk_attention_mask(mask_shape=(4, 4), chunk_size=0) class AttentionTest(parameterized.TestCase): - """Test for the Attention""" - - # Note: if you are changing these configs, please make sure to change the configs in - # context_parallelism.py as well, since we are using the same configs for both - # tests to get the same mesh and other config - config_arguments = { - "per_device_batch_size": 1.0, - "run_name": "test", - "enable_checkpointing": False, - "max_prefill_predict_length": 16, - "max_target_length": 512, - "sa_block_q": 128, - "sa_block_kv": 128, - "sa_block_kv_compute": 128, - "sa_block_q_dkv": 128, - "sa_block_kv_dkv": 128, - "sa_block_kv_dkv_compute": 128, - "sa_block_q_dq": 128, - "sa_block_kv_dq": 128, - } - - def setUp(self): - """Initializes the configuration for each test""" - super().setUp() - # Conditionally set ici_fsdp_parallelism to match device count in decoupled mode - extra_args = get_decoupled_parallelism_overrides() - if not is_decoupled(): - jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) - config = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - 
**self.config_arguments, - **extra_args, - ) - self.cfg = config - - self.rng = jax.random.PRNGKey(0) - self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) - - devices_array = maxtext_utils.create_device_mesh(self.cfg) - self.mesh = Mesh(devices_array, self.cfg.mesh_axes) - self.global_batch_size = self.cfg.global_batch_size_to_train_on - self.num_kv_heads = self.cfg.num_kv_heads - self.num_query_heads = self.cfg.num_query_heads - self.max_target_length = self.cfg.max_target_length - self.max_prefill_predict_length = self.cfg.max_prefill_predict_length - self.head_dim = self.cfg.head_dim - self.embed_dim = self.cfg.base_emb_dim - self.dtype = self.cfg.dtype - self.attention_type = self.cfg.attention_type - - dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - self._attention_as_mha_generic = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - attention_type=self.attention_type, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) + """Test for the Attention""" - def get_data(self, dtype): - """get data""" - lnx = jax.random.normal( - self.rng, - shape=(self.global_batch_size, self.max_target_length, self.embed_dim), - dtype=dtype, - ) + # Note: if you are changing these configs, please make sure to change the configs in + # context_parallelism.py as well, since we are using the same configs for both + # tests to get the same mesh and other config + config_arguments = { + "per_device_batch_size": 1.0, + "run_name": "test", + 
"enable_checkpointing": False, + "max_prefill_predict_length": 16, + "max_target_length": 512, + "sa_block_q": 128, + "sa_block_kv": 128, + "sa_block_kv_compute": 128, + "sa_block_q_dkv": 128, + "sa_block_kv_dkv": 128, + "sa_block_kv_dkv_compute": 128, + "sa_block_q_dq": 128, + "sa_block_kv_dq": 128, + } - decoder_segment_ids = jax.random.randint(self.rng, (self.global_batch_size, self.max_target_length), 0, 4) - decoder_positions = jax.random.randint( - self.rng, (self.global_batch_size, self.max_target_length), 0, self.max_target_length - ) + def setUp(self): + """Initializes the configuration for each test""" + super().setUp() + # Conditionally set ici_fsdp_parallelism to match device count in decoupled mode + extra_args = get_decoupled_parallelism_overrides() + if not is_decoupled(): + jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) + config = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **self.config_arguments, + **extra_args, + ) + self.cfg = config + + self.rng = jax.random.PRNGKey(0) + self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) + + devices_array = maxtext_utils.create_device_mesh(self.cfg) + self.mesh = Mesh(devices_array, self.cfg.mesh_axes) + self.global_batch_size = self.cfg.global_batch_size_to_train_on + self.num_kv_heads = self.cfg.num_kv_heads + self.num_query_heads = self.cfg.num_query_heads + self.max_target_length = self.cfg.max_target_length + self.max_prefill_predict_length = self.cfg.max_prefill_predict_length + self.head_dim = self.cfg.head_dim + self.embed_dim = self.cfg.base_emb_dim + self.dtype = self.cfg.dtype + self.attention_type = self.cfg.attention_type + + dummy_inputs_q = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + dummy_inputs_kv = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + self._attention_as_mha_generic = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + 
num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + attention_type=self.attention_type, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) - return lnx, decoder_segment_ids, decoder_positions + def get_data(self, dtype): + """get data""" + lnx = jax.random.normal( + self.rng, + shape=(self.global_batch_size, self.max_target_length, self.embed_dim), + dtype=dtype, + ) - def get_structured_data(self, dtype): - """get structured data""" - lnx = jax.random.normal( - self.rng, - shape=(self.global_batch_size, self.max_target_length, self.embed_dim), - dtype=dtype, - ) + decoder_segment_ids = jax.random.randint( + self.rng, (self.global_batch_size, self.max_target_length), 0, 4 + ) + decoder_positions = jax.random.randint( + self.rng, + (self.global_batch_size, self.max_target_length), + 0, + self.max_target_length, + ) - decoder_positions = jnp.stack( - [jnp.arange(self.max_target_length, dtype=jnp.int32) for _ in range(self.global_batch_size)] - ) + return lnx, decoder_segment_ids, decoder_positions - decoder_segment_ids = ( - jax.numpy.zeros((self.global_batch_size, self.max_target_length)) + DECODING_ACTIVE_SEQUENCE_INDICATOR - ) + def get_structured_data(self, dtype): + """get structured data""" + lnx = jax.random.normal( + self.rng, + shape=(self.global_batch_size, self.max_target_length, self.embed_dim), + dtype=dtype, + ) - return lnx, decoder_segment_ids, decoder_positions - - @pytest.mark.tpu_only - def test_autoregression(self): - prefill_length = self.cfg.max_prefill_predict_length - decode_total_length = self.cfg.max_target_length - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype) - - mha_full, _ = 
self._attention_as_mha_generic( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + decoder_positions = jnp.stack( + [ + jnp.arange(self.max_target_length, dtype=jnp.int32) + for _ in range(self.global_batch_size) + ] + ) - lnx_prefill = lnx[:, 0:prefill_length, :] - decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] - decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - - mha_prefill, _ = self._attention_as_mha_generic( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) + decoder_segment_ids = ( + jax.numpy.zeros((self.global_batch_size, self.max_target_length)) + + DECODING_ACTIVE_SEQUENCE_INDICATOR + ) - self.assertTrue( - jax.numpy.allclose(mha_prefill, mha_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False) - ) + return lnx, decoder_segment_ids, decoder_positions - for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx : idx + 1, :] - decoder_positions_idx = decoder_positions[:, idx : idx + 1] - mha_idx, _ = self._attention_as_mha_generic( - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) - - mha_full_this_idx = mha_full[:, idx : idx + 1, :] - self.assertTrue(mha_full_this_idx.shape == mha_idx.shape) - self.assertTrue(jax.numpy.allclose(mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) - - @pytest.mark.tpu_only - def test_model_mode_prefill_dtype_float32(self): - self._test_model_mode_prefill_dtype(jnp.float32) - - @pytest.mark.tpu_only - def test_model_mode_prefill_dtype_bfloat16(self): - """test model mode prefill for dtype bfloat16""" - self._test_model_mode_prefill_dtype(jnp.bfloat16) - - def _test_model_mode_prefill_dtype(self, dtype): - """test 
model mode prefill for specified dtype""" - lnx, decoder_segment_ids, decoder_positions = self.get_data(dtype) - prefill_length = self.cfg.max_prefill_predict_length - lnx_prefill = lnx[:, 0:prefill_length, :] - decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] - decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - - dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - attention_as_mha_generic = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=dtype, - dropout_rate=self.cfg.dropout_rate, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) + @pytest.mark.tpu_only + def test_autoregression(self): + prefill_length = self.cfg.max_prefill_predict_length + decode_total_length = self.cfg.max_target_length + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( + self.dtype + ) - mha_prefill, _ = attention_as_mha_generic( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) + mha_full, _ = self._attention_as_mha_generic( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - self.assertEqual(dtype, mha_prefill.dtype) - - @pytest.mark.tpu_only - def test_tpu_kernel_attention_mha(self): - self.tpu_kernel_attention_helper(self.num_kv_heads) - - @pytest.mark.tpu_only - def test_tpu_kernel_attention_gqa(self): - 
self.tpu_kernel_attention_helper(self.num_kv_heads // 2) - - @pytest.mark.tpu_only - def test_tpu_kernel_attention_mqa(self): - self.tpu_kernel_attention_helper(1) - - @pytest.mark.tpu_only - def test_tpu_kernel_attention_mha_share_kv(self): - self.tpu_kernel_attention_helper(self.num_kv_heads, share_kv_projections=True) - - @pytest.mark.tpu_only - def test_tpu_kernel_attention_gqa_share_kv(self): - self.tpu_kernel_attention_helper(self.num_kv_heads // 2, share_kv_projections=True) - - def tpu_kernel_attention_helper(self, num_kv_heads, share_kv_projections=False): - """Test equivalence between dot_product and TPU accelerated""" - - lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) - - dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - attention_as_mha_generic = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - share_kv_projections=share_kv_projections, - rngs=self.nnx_rng, - ) + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + + mha_prefill, _ = self._attention_as_mha_generic( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) - generic_state = nnx.state(attention_as_mha_generic) + self.assertTrue( + jax.numpy.allclose( + mha_prefill, + mha_full[:, :prefill_length, :], + 
rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) + ) - mha_generic_output, _ = attention_as_mha_generic( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] + mha_idx, _ = self._attention_as_mha_generic( + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) + + mha_full_this_idx = mha_full[:, idx : idx + 1, :] + self.assertTrue(mha_full_this_idx.shape == mha_idx.shape) + self.assertTrue( + jax.numpy.allclose( + mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False + ) + ) + + @pytest.mark.tpu_only + def test_model_mode_prefill_dtype_float32(self): + self._test_model_mode_prefill_dtype(jnp.float32) + + @pytest.mark.tpu_only + def test_model_mode_prefill_dtype_bfloat16(self): + """test model mode prefill for dtype bfloat16""" + self._test_model_mode_prefill_dtype(jnp.bfloat16) + + def _test_model_mode_prefill_dtype(self, dtype): + """test model mode prefill for specified dtype""" + lnx, decoder_segment_ids, decoder_positions = self.get_data(dtype) + prefill_length = self.cfg.max_prefill_predict_length + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + + dummy_inputs_q = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + dummy_inputs_kv = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + attention_as_mha_generic = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + 
max_prefill_predict_length=self.cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=dtype, + dropout_rate=self.cfg.dropout_rate, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) - dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - attention_as_mha_flash = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="flash", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - share_kv_projections=share_kv_projections, - rngs=self.nnx_rng, - ) - nnx.update(attention_as_mha_flash, generic_state) - - mha_generic_flash_output, _ = attention_as_mha_flash( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + mha_prefill, _ = attention_as_mha_generic( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) - self.assertTrue( - jax.numpy.allclose(mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False) - ) + self.assertEqual(dtype, mha_prefill.dtype) - def test_share_kv_projections(self): - """Test that kv projections are shared.""" - dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - attention_share_kv = Attention( - 
config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - share_kv_projections=True, - rngs=self.nnx_rng, - ) + @pytest.mark.tpu_only + def test_tpu_kernel_attention_mha(self): + self.tpu_kernel_attention_helper(self.num_kv_heads) - self.assertFalse(hasattr(attention_share_kv, "value")) - self.assertTrue(hasattr(attention_share_kv, "key")) + @pytest.mark.tpu_only + def test_tpu_kernel_attention_gqa(self): + self.tpu_kernel_attention_helper(self.num_kv_heads // 2) - # 1. Check NNX state - state_shared = nnx.state(attention_share_kv) - self.assertNotIn("value", state_shared) - self.assertIn("key", state_shared) + @pytest.mark.tpu_only + def test_tpu_kernel_attention_mqa(self): + self.tpu_kernel_attention_helper(1) - # 2. Forward Pass Verification - lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) + @pytest.mark.tpu_only + def test_tpu_kernel_attention_mha_share_kv(self): + self.tpu_kernel_attention_helper(self.num_kv_heads, share_kv_projections=True) - output_shared, _ = attention_share_kv( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + @pytest.mark.tpu_only + def test_tpu_kernel_attention_gqa_share_kv(self): + self.tpu_kernel_attention_helper( + self.num_kv_heads // 2, share_kv_projections=True + ) - self.assertEqual(output_shared.shape, (self.global_batch_size, self.max_target_length, self.embed_dim)) - - # 3. 
Equivalence Check with standard unshared Attention - attention_no_share = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - share_kv_projections=False, - rngs=self.nnx_rng, - ) + def tpu_kernel_attention_helper(self, num_kv_heads, share_kv_projections=False): + """Test equivalence between dot_product and TPU accelerated""" - # Force unshared layer to copy weights from shared layer, mapping 'key' to 'value' - attention_no_share.query.kernel.value = attention_share_kv.query.kernel.value - attention_no_share.key.kernel.value = attention_share_kv.key.kernel.value - attention_no_share.value.kernel.value = attention_share_kv.key.kernel.value - attention_no_share.out.kernel.value = attention_share_kv.out.kernel.value - - output_no_share, _ = attention_no_share( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) - self.assertTrue(jax.numpy.allclose(output_shared, output_no_share, rtol=1e-04, atol=1e-04, equal_nan=False)) - - @parameterized.named_parameters( - { - "testcase_name": "cp_no_load_balance", - "ici_context_parallelism": 4, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_with_load_balance", - "ici_context_parallelism": 4, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "auto", - }, - { - "testcase_name": 
"cp_ep_no_load_balance", - "ici_context_parallelism": 2, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_ep_with_load_balance", - "ici_context_parallelism": 2, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_no_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_with_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_no_load_balance_explicit", - "ici_context_parallelism": 4, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_with_load_balance_explicit", - "ici_context_parallelism": 4, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_ep_no_load_balance_explicit", - "ici_context_parallelism": 2, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_ep_with_load_balance_explicit", - "ici_context_parallelism": 2, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_no_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - 
"expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_with_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - ) - # TODO (b/454764135.) : This tests fails with new tokamax kernel - @pytest.mark.tpu_only - def test_tpu_flash_attention_context_parallel( - self, - ici_context_parallelism, - context_parallel_load_balance, - ici_expert_parallelism, - expert_shard_attention_option, - shard_mode, - ): - """Test equivalence between dot_product and flash attention + context/expert parallelism""" - num_kv_heads = self.num_kv_heads - lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) - # Dot product - mha_generic_output, _ = self._attention_as_mha_generic( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) - generic_state = nnx.state(self._attention_as_mha_generic) - - # Test with Context Parallelism - cfg_cp = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **self.config_arguments, - ici_context_parallelism=ici_context_parallelism, - context_parallel_load_balance=context_parallel_load_balance, - ici_expert_parallelism=ici_expert_parallelism, - expert_shard_attention_option=expert_shard_attention_option, - shard_mode=shard_mode, - ) - devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) - axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto - axis_names = [axis_type for _ in cfg_cp.mesh_axes] - mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes, axis_types=tuple(axis_names)) - attention_as_mha_flash_cp = Attention( - config=cfg_cp, - num_query_heads=cfg_cp.num_query_heads, - num_kv_heads=num_kv_heads, - head_dim=cfg_cp.head_dim, - max_target_length=cfg_cp.max_target_length, - 
max_prefill_predict_length=cfg_cp.max_prefill_predict_length, - inputs_q_shape=lnx.shape, - inputs_kv_shape=lnx.shape, - mesh=mesh_cp, - attention_kernel="flash", - dtype=self.dtype, - dropout_rate=cfg_cp.dropout_rate, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) - nnx.update(attention_as_mha_flash_cp, generic_state) - - mha_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism( - cfg_cp, - mesh_cp, - attention_as_mha_flash_cp, - lnx, - decoder_segment_ids, - decoder_positions, - ) + dummy_inputs_q = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + dummy_inputs_kv = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + attention_as_mha_generic = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + share_kv_projections=share_kv_projections, + rngs=self.nnx_rng, + ) - # This removes all sharding information and makes them standard NumPy arrays. 
- mha_generic_output = jax.device_get(mha_generic_output) - mha_generic_flash_cp_output = jax.device_get(mha_generic_flash_cp_output) + generic_state = nnx.state(attention_as_mha_generic) - self.assertTrue( - jax.numpy.allclose(mha_generic_output, mha_generic_flash_cp_output, rtol=1e-01, atol=1e-01, equal_nan=False), - msg="Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" - f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," - f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", - ) + mha_generic_output, _ = attention_as_mha_generic( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - @pytest.mark.tpu_only - def test_dot_product_cache_axis_order(self): - all_axis_orders = tuple(itertools.permutations(range(4))) - for axis_order in random.choices(all_axis_orders, k=4): - self.dot_product_attention_helper(prefill_cache_axis_order=axis_order, ar_cache_axis_order=axis_order) - print(f"passed test for {axis_order=}") - - def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order): - for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: - self._dot_product_attention( - prefill_cache_axis_order, - ar_cache_axis_order, - compute_axis_order=compute_axis_order, - ) - print(f"passed subtest for {compute_axis_order=}") - - def _dot_product_attention( - self, - prefill_cache_axis_order, - ar_cache_axis_order, - compute_axis_order, - ): - """Test equalvant between different layout control in dot_product""" - - rtol, atol = 1e-02, 1e-02 - - config = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - per_device_batch_size=1.0, - run_name="test", - enable_checkpointing=False, - max_target_length=128, - max_prefill_predict_length=16, - attention="dot_product", - ) + 
dummy_inputs_q = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + dummy_inputs_kv = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + attention_as_mha_flash = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="flash", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + share_kv_projections=share_kv_projections, + rngs=self.nnx_rng, + ) + nnx.update(attention_as_mha_flash, generic_state) + + mha_generic_flash_output, _ = attention_as_mha_flash( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - prefill_length = config.max_prefill_predict_length - decode_total_length = config.max_target_length - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype) - lnx_prefill = lnx[:, 0:prefill_length, :] - decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] - decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - - dummy_inputs_q = jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)) - dummy_inputs_kv = jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)) - attention_w_layout = Attention( - mesh=self.mesh, - config=config, - num_query_heads=config.num_query_heads, - num_kv_heads=config.num_kv_heads, - head_dim=config.head_dim, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - max_target_length=config.max_target_length, - max_prefill_predict_length=config.max_prefill_predict_length, - attention_kernel=config.attention, - dtype=config.dtype, - 
prefill_cache_axis_order=prefill_cache_axis_order, - ar_cache_axis_order=ar_cache_axis_order, - compute_axis_order=compute_axis_order, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) - attention_w_layout_full, _ = attention_w_layout( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + self.assertTrue( + jax.numpy.allclose( + mha_generic_output, + mha_generic_flash_output, + rtol=1e-01, + atol=1e-01, + equal_nan=False, + ) + ) - attention_w_layout_prefill, _ = attention_w_layout( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) - self.assertTrue( - jax.numpy.allclose(attention_w_layout_full[:, :prefill_length, :], attention_w_layout_prefill, equal_nan=False) - ) + def test_share_kv_projections(self): + """Test that kv projections are shared.""" + dummy_inputs_q = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + dummy_inputs_kv = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + attention_share_kv = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + share_kv_projections=True, + rngs=self.nnx_rng, + ) - for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx : idx + 1, :] - decoder_positions_idx = decoder_positions[:, idx : idx + 1] - - attention_w_layout_idx, _ = attention_w_layout( - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - 
deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) - - attention_w_layout_full_this_idx = attention_w_layout_full[:, idx : idx + 1, :] - self.assertTrue(attention_w_layout_full_this_idx.shape == attention_w_layout_idx.shape) - self.assertTrue( - jax.numpy.allclose( - attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=rtol, atol=atol, equal_nan=False - ) - ) - - @pytest.mark.tpu_only - def test_dot_product_reshape_q(self): - for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: - self._dot_product_attention_reshape_q( - compute_axis_order=compute_axis_order, - ) - print(f"test passed for compute_axis_order: {compute_axis_order}") - - def _dot_product_attention_reshape_q(self, compute_axis_order): - """Test equalvant between q and reshape q in dot_product""" - - rtol, atol = 1e-02, 1e-02 - - config = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - per_device_batch_size=1.0, - run_name="test", - enable_checkpointing=False, - max_target_length=128, - max_prefill_predict_length=16, - attention="dot_product", - ) + self.assertFalse(hasattr(attention_share_kv, "value")) + self.assertTrue(hasattr(attention_share_kv, "key")) - prefill_length = config.max_prefill_predict_length - decode_total_length = config.max_target_length - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype) - - lnx_prefill = lnx[:, 0:prefill_length, :] - decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] - decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - - dummy_inputs_q = jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)) - dummy_inputs_kv = jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)) - - attention_wo_reshape_q = Attention( - mesh=self.mesh, - config=config, - num_query_heads=config.num_query_heads, - num_kv_heads=config.num_kv_heads, - head_dim=config.head_dim, - max_target_length=config.max_target_length, - 
max_prefill_predict_length=config.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - attention_kernel=config.attention, - dtype=config.dtype, - compute_axis_order=compute_axis_order, - reshape_q=False, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) + # 1. Check NNX state + state_shared = nnx.state(attention_share_kv) + self.assertNotIn("value", state_shared) + self.assertIn("key", state_shared) - attention_w_reshape_q = Attention( - mesh=self.mesh, - config=config, - num_query_heads=config.num_query_heads, - num_kv_heads=config.num_kv_heads, - head_dim=config.head_dim, - max_target_length=config.max_target_length, - max_prefill_predict_length=config.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - attention_kernel=config.attention, - dtype=config.dtype, - compute_axis_order=compute_axis_order, - reshape_q=True, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) + # 2. Forward Pass Verification + lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) - attention_wo_reshape_q_state = nnx.state(attention_wo_reshape_q) - nnx.update(attention_w_reshape_q, attention_wo_reshape_q_state) + output_shared, _ = attention_share_kv( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - attention_wo_reshape_q_full, _ = attention_wo_reshape_q( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + self.assertEqual( + output_shared.shape, + (self.global_batch_size, self.max_target_length, self.embed_dim), + ) - attention_w_reshape_q_full, _ = attention_w_reshape_q( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + # 3. 
Equivalence Check with standard unshared Attention + attention_no_share = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + share_kv_projections=False, + rngs=self.nnx_rng, + ) - attention_wo_reshape_q_prefill, _ = attention_wo_reshape_q( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) - self.assertTrue( - jax.numpy.allclose( - attention_wo_reshape_q_full[:, :prefill_length, :], attention_wo_reshape_q_prefill, equal_nan=False + # Force unshared layer to copy weights from shared layer, mapping 'key' to 'value' + attention_no_share.query.kernel.value = attention_share_kv.query.kernel.value + attention_no_share.key.kernel.value = attention_share_kv.key.kernel.value + attention_no_share.value.kernel.value = attention_share_kv.key.kernel.value + attention_no_share.out.kernel.value = attention_share_kv.out.kernel.value + + output_no_share, _ = attention_no_share( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, ) - ) - attention_w_reshape_q_prefill, _ = attention_w_reshape_q( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) - self.assertTrue( - jax.numpy.allclose( - attention_w_reshape_q_full[:, :prefill_length, :], attention_w_reshape_q_prefill, equal_nan=False + self.assertTrue( + jax.numpy.allclose( + output_shared, 
output_no_share, rtol=1e-04, atol=1e-04, equal_nan=False + ) ) - ) - self.assertTrue(jax.numpy.allclose(attention_wo_reshape_q_prefill, attention_w_reshape_q_prefill, equal_nan=False)) - self.assertTrue( - jax.numpy.allclose( - attention_wo_reshape_q_full[:, :prefill_length, :], - attention_w_reshape_q_full[:, :prefill_length, :], - equal_nan=False, + @parameterized.named_parameters( + { + "testcase_name": "cp_no_load_balance", + "ici_context_parallelism": 4, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_with_load_balance", + "ici_context_parallelism": 4, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_ep_no_load_balance", + "ici_context_parallelism": 2, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_ep_with_load_balance", + "ici_context_parallelism": 2, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "ep_no_load_balance", + "ici_context_parallelism": 1, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "ep_with_load_balance", + "ici_context_parallelism": 1, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_no_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": 
"cp_with_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_no_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_with_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "ep_no_load_balance_explicit", + "ici_context_parallelism": 1, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "ep_with_load_balance_explicit", + "ici_context_parallelism": 1, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + ) + # TODO (b/454764135.) 
: This tests fails with new tokamax kernel + @pytest.mark.tpu_only + def test_tpu_flash_attention_context_parallel( + self, + ici_context_parallelism, + context_parallel_load_balance, + ici_expert_parallelism, + expert_shard_attention_option, + shard_mode, + ): + """Test equivalence between dot_product and flash attention + context/expert parallelism""" + num_kv_heads = self.num_kv_heads + lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) + # Dot product + mha_generic_output, _ = self._attention_as_mha_generic( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + generic_state = nnx.state(self._attention_as_mha_generic) + + # Test with Context Parallelism + cfg_cp = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **self.config_arguments, + ici_context_parallelism=ici_context_parallelism, + context_parallel_load_balance=context_parallel_load_balance, + ici_expert_parallelism=ici_expert_parallelism, + expert_shard_attention_option=expert_shard_attention_option, + shard_mode=shard_mode, + ) + devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) + axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto + axis_names = [axis_type for _ in cfg_cp.mesh_axes] + mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes, axis_types=tuple(axis_names)) + attention_as_mha_flash_cp = Attention( + config=cfg_cp, + num_query_heads=cfg_cp.num_query_heads, + num_kv_heads=num_kv_heads, + head_dim=cfg_cp.head_dim, + max_target_length=cfg_cp.max_target_length, + max_prefill_predict_length=cfg_cp.max_prefill_predict_length, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, + mesh=mesh_cp, + attention_kernel="flash", + dtype=self.dtype, + dropout_rate=cfg_cp.dropout_rate, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + nnx.update(attention_as_mha_flash_cp, generic_state) + + mha_generic_flash_cp_output = ( + 
attention_test_util.forward_with_context_expert_parallelism( + cfg_cp, + mesh_cp, + attention_as_mha_flash_cp, + lnx, + decoder_segment_ids, + decoder_positions, + ) ) - ) - for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx : idx + 1, :] - decoder_positions_idx = decoder_positions[:, idx : idx + 1] - - attention_wo_reshape_q_idx, _ = attention_wo_reshape_q( - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) - - attention_wo_reshape_q_full_this_idx = attention_wo_reshape_q_full[:, idx : idx + 1, :] - self.assertTrue(attention_wo_reshape_q_full_this_idx.shape == attention_wo_reshape_q_idx.shape) - self.assertTrue( - jax.numpy.allclose( - attention_wo_reshape_q_full_this_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False - ) - ) - - attention_w_reshape_q_idx, _ = attention_w_reshape_q( - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) - - attention_w_reshape_q_full_this_idx = attention_w_reshape_q_full[:, idx : idx + 1, :] - self.assertTrue(attention_w_reshape_q_full_this_idx.shape == attention_w_reshape_q_idx.shape) - self.assertTrue( - jax.numpy.allclose( - attention_w_reshape_q_full_this_idx, attention_w_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False - ) - ) - - self.assertTrue( - jax.numpy.allclose(attention_w_reshape_q_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False) - ) - - def test_sliding_window_attention(self): - """Test sliding window attention""" - - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype) - - dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) - - # Global Attention - global_attn = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, 
- num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - mesh=self.mesh, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - attention_type=AttentionType.GLOBAL, - model_mode=MODEL_MODE_TRAIN, - rngs=self.nnx_rng, - ) + # This removes all sharding information and makes them standard NumPy arrays. + mha_generic_output = jax.device_get(mha_generic_output) + mha_generic_flash_cp_output = jax.device_get(mha_generic_flash_cp_output) + + self.assertTrue( + jax.numpy.allclose( + mha_generic_output, + mha_generic_flash_cp_output, + rtol=1e-01, + atol=1e-01, + equal_nan=False, + ), + msg="Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" + f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," + f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", + ) - # Attention with sliding window of size 8 - sliding_attn = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - mesh=self.mesh, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - attention_type=AttentionType.LOCAL_SLIDING, - sliding_window_size=8, - model_mode=MODEL_MODE_TRAIN, - rngs=self.nnx_rng, - ) + @pytest.mark.tpu_only + def test_dot_product_cache_axis_order(self): + all_axis_orders = tuple(itertools.permutations(range(4))) + for axis_order in random.choices(all_axis_orders, k=4): + 
self.dot_product_attention_helper( + prefill_cache_axis_order=axis_order, ar_cache_axis_order=axis_order + ) + print(f"passed test for {axis_order=}") + + def dot_product_attention_helper( + self, prefill_cache_axis_order, ar_cache_axis_order + ): + for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: + self._dot_product_attention( + prefill_cache_axis_order, + ar_cache_axis_order, + compute_axis_order=compute_axis_order, + ) + print(f"passed subtest for {compute_axis_order=}") + + def _dot_product_attention( + self, + prefill_cache_axis_order, + ar_cache_axis_order, + compute_axis_order, + ): + """Test equalvant between different layout control in dot_product""" + + rtol, atol = 1e-02, 1e-02 + + config = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + per_device_batch_size=1.0, + run_name="test", + enable_checkpointing=False, + max_target_length=128, + max_prefill_predict_length=16, + attention="dot_product", + ) - # To share parameters, we copy the state from sliding_attn to global_attn. 
- sliding_attn_state = nnx.state(sliding_attn) - nnx.update(global_attn, sliding_attn_state) - - global_attn_output, _ = global_attn( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + prefill_length = config.max_prefill_predict_length + decode_total_length = config.max_target_length + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( + config.dtype + ) + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - sliding_window_output, _ = sliding_attn( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + dummy_inputs_q = jnp.ones( + (self.global_batch_size, config.max_target_length, config.base_emb_dim) + ) + dummy_inputs_kv = jnp.ones( + (self.global_batch_size, config.max_target_length, config.base_emb_dim) + ) + attention_w_layout = Attention( + mesh=self.mesh, + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + dtype=config.dtype, + prefill_cache_axis_order=prefill_cache_axis_order, + ar_cache_axis_order=ar_cache_axis_order, + compute_axis_order=compute_axis_order, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + attention_w_layout_full, _ = attention_w_layout( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - # Test if sliding window attention is different from global attention - self.assertFalse( - 
jax.numpy.allclose( - sliding_window_output.astype(jnp.bfloat16), global_attn_output.astype(jnp.bfloat16), rtol=1e-04, atol=1e-04 + attention_w_layout_prefill, _ = attention_w_layout( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) + self.assertTrue( + jax.numpy.allclose( + attention_w_layout_full[:, :prefill_length, :], + attention_w_layout_prefill, + equal_nan=False, + ) ) - ) - # Attention with sliding window of size max_target_length - # This should be equivalent to global attention. - sliding_attn_full_window = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - mesh=self.mesh, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - attention_type=AttentionType.LOCAL_SLIDING, - sliding_window_size=self.max_target_length, - model_mode=MODEL_MODE_TRAIN, - rngs=self.nnx_rng, - ) + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] - nnx.update(sliding_attn_full_window, sliding_attn_state) + attention_w_layout_idx, _ = attention_w_layout( + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) - sliding_window_output_full, _ = sliding_attn_full_window( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + attention_w_layout_full_this_idx = attention_w_layout_full[ + :, idx : idx + 1, : + ] + self.assertTrue( + attention_w_layout_full_this_idx.shape == 
attention_w_layout_idx.shape + ) + self.assertTrue( + jax.numpy.allclose( + attention_w_layout_full_this_idx, + attention_w_layout_idx, + rtol=rtol, + atol=atol, + equal_nan=False, + ) + ) + + @pytest.mark.tpu_only + def test_dot_product_reshape_q(self): + for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: + self._dot_product_attention_reshape_q( + compute_axis_order=compute_axis_order, + ) + print(f"test passed for compute_axis_order: {compute_axis_order}") + + def _dot_product_attention_reshape_q(self, compute_axis_order): + """Test equalvant between q and reshape q in dot_product""" + + rtol, atol = 1e-02, 1e-02 + + config = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + per_device_batch_size=1.0, + run_name="test", + enable_checkpointing=False, + max_target_length=128, + max_prefill_predict_length=16, + attention="dot_product", + ) - print(f"{sliding_window_output_full.astype(jnp.bfloat16)=}") - print(f"{global_attn_output.astype(jnp.bfloat16)=}") + prefill_length = config.max_prefill_predict_length + decode_total_length = config.max_target_length + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( + config.dtype + ) + + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - # Test if sliding window attention with max_target_length size is the same as global attention - self.assertTrue( - jax.numpy.allclose( - sliding_window_output_full.astype(jnp.bfloat16), - global_attn_output.astype(jnp.bfloat16), - rtol=1e-04, - atol=1e-04, + dummy_inputs_q = jnp.ones( + (self.global_batch_size, config.max_target_length, config.base_emb_dim) + ) + dummy_inputs_kv = jnp.ones( + (self.global_batch_size, config.max_target_length, config.base_emb_dim) ) - ) - @pytest.mark.skip(reason="Requires `vllm-tpu` package which is not yet a MaxText dependency.") - @pytest.mark.tpu_only - 
@mock.patch("tpu_inference.layers.jax.attention_interface.sharded_ragged_paged_attention", create=True) - def test_forward_serve_vllm(self, mock_sharded_ragged_paged_attention): - """Tests the forward_serve_vllm method with mocked RPA attention.""" - # Setup config for vLLM RPA - vllm_config_arguments = self.config_arguments.copy() - vllm_config_arguments["attention"] = "vllm_rpa" - vllm_config_arguments["chunk_attn_window_size"] = 128 - config = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **vllm_config_arguments, - ) + attention_wo_reshape_q = Attention( + mesh=self.mesh, + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + attention_kernel=config.attention, + dtype=config.dtype, + compute_axis_order=compute_axis_order, + reshape_q=False, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) - seq_len = self.max_target_length - - # Create Attention instance - dummy_inputs_q = jnp.ones((self.global_batch_size, seq_len, self.embed_dim)) - dummy_inputs_kv = jnp.ones((self.global_batch_size, seq_len, self.embed_dim)) - attention_vllm = Attention( - config=config, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=self.dtype, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - rngs=self.nnx_rng, - ) + attention_w_reshape_q = Attention( + mesh=self.mesh, + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + 
max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + attention_kernel=config.attention, + dtype=config.dtype, + compute_axis_order=compute_axis_order, + reshape_q=True, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) - # Prepare inputs - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype) - mock_kv_cache = [jnp.ones((1,))] - - mock_attention_metadata = mock.Mock() - mock_attention_metadata.seq_lens = jnp.array([1] * self.global_batch_size) - mock_attention_metadata.block_tables = jnp.array([[0]] * self.global_batch_size) - mock_attention_metadata.query_start_loc = jnp.array(list(range(self.global_batch_size))) - mock_attention_metadata.request_distribution = jnp.array([self.global_batch_size]) - - # Mock the return value of sharded_ragged_paged_attention - total_tokens = self.global_batch_size * seq_len - mock_output_shape = (total_tokens, self.num_query_heads, self.head_dim) - mock_output = jnp.ones(mock_output_shape, dtype=self.dtype) - mock_updated_kv_cache = [jnp.zeros((1,))] - - mock_callable = mock.Mock(return_value=(mock_output, mock_updated_kv_cache)) - mock_sharded_ragged_paged_attention.return_value = mock_callable - - # Call the attention layer - output, updated_kv_cache = attention_vllm( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - kv_cache=mock_kv_cache, - attention_metadata=mock_attention_metadata, - ) + attention_wo_reshape_q_state = nnx.state(attention_wo_reshape_q) + nnx.update(attention_w_reshape_q, attention_wo_reshape_q_state) + + attention_wo_reshape_q_full, _ = attention_wo_reshape_q( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + + 
attention_w_reshape_q_full, _ = attention_w_reshape_q( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + + attention_wo_reshape_q_prefill, _ = attention_wo_reshape_q( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) + self.assertTrue( + jax.numpy.allclose( + attention_wo_reshape_q_full[:, :prefill_length, :], + attention_wo_reshape_q_prefill, + equal_nan=False, + ) + ) + + attention_w_reshape_q_prefill, _ = attention_w_reshape_q( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) + self.assertTrue( + jax.numpy.allclose( + attention_w_reshape_q_full[:, :prefill_length, :], + attention_w_reshape_q_prefill, + equal_nan=False, + ) + ) + + self.assertTrue( + jax.numpy.allclose( + attention_wo_reshape_q_prefill, + attention_w_reshape_q_prefill, + equal_nan=False, + ) + ) + self.assertTrue( + jax.numpy.allclose( + attention_wo_reshape_q_full[:, :prefill_length, :], + attention_w_reshape_q_full[:, :prefill_length, :], + equal_nan=False, + ) + ) + + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] + + attention_wo_reshape_q_idx, _ = attention_wo_reshape_q( + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) + + attention_wo_reshape_q_full_this_idx = attention_wo_reshape_q_full[ + :, idx : idx + 1, : + ] + self.assertTrue( + attention_wo_reshape_q_full_this_idx.shape + == attention_wo_reshape_q_idx.shape + ) + self.assertTrue( + jax.numpy.allclose( + attention_wo_reshape_q_full_this_idx, + attention_wo_reshape_q_idx, + 
rtol=rtol, + atol=atol, + equal_nan=False, + ) + ) + + attention_w_reshape_q_idx, _ = attention_w_reshape_q( + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) + + attention_w_reshape_q_full_this_idx = attention_w_reshape_q_full[ + :, idx : idx + 1, : + ] + self.assertTrue( + attention_w_reshape_q_full_this_idx.shape + == attention_w_reshape_q_idx.shape + ) + self.assertTrue( + jax.numpy.allclose( + attention_w_reshape_q_full_this_idx, + attention_w_reshape_q_idx, + rtol=rtol, + atol=atol, + equal_nan=False, + ) + ) + + self.assertTrue( + jax.numpy.allclose( + attention_w_reshape_q_idx, + attention_wo_reshape_q_idx, + rtol=rtol, + atol=atol, + equal_nan=False, + ) + ) + + def test_sliding_window_attention(self): + """Test sliding window attention""" + + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( + self.dtype + ) + + dummy_inputs_q = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + dummy_inputs_kv = jnp.ones( + (self.global_batch_size, self.max_target_length, self.embed_dim) + ) + + # Global Attention + global_attn = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + mesh=self.mesh, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + attention_type=AttentionType.GLOBAL, + model_mode=MODEL_MODE_TRAIN, + rngs=self.nnx_rng, + ) + + # Attention with sliding window of size 8 + sliding_attn = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + 
mesh=self.mesh, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + attention_type=AttentionType.LOCAL_SLIDING, + sliding_window_size=8, + model_mode=MODEL_MODE_TRAIN, + rngs=self.nnx_rng, + ) + + # To share parameters, we copy the state from sliding_attn to global_attn. + sliding_attn_state = nnx.state(sliding_attn) + nnx.update(global_attn, sliding_attn_state) + + global_attn_output, _ = global_attn( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + + sliding_window_output, _ = sliding_attn( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + + # Test if sliding window attention is different from global attention + self.assertFalse( + jax.numpy.allclose( + sliding_window_output.astype(jnp.bfloat16), + global_attn_output.astype(jnp.bfloat16), + rtol=1e-04, + atol=1e-04, + ) + ) - # Assertions - mock_sharded_ragged_paged_attention.assert_called_once() - mock_callable.assert_called_once() - self.assertEqual(updated_kv_cache, mock_updated_kv_cache) + # Attention with sliding window of size max_target_length + # This should be equivalent to global attention. 
+ sliding_attn_full_window = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + mesh=self.mesh, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + attention_type=AttentionType.LOCAL_SLIDING, + sliding_window_size=self.max_target_length, + model_mode=MODEL_MODE_TRAIN, + rngs=self.nnx_rng, + ) + + nnx.update(sliding_attn_full_window, sliding_attn_state) + + sliding_window_output_full, _ = sliding_attn_full_window( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + + print(f"{sliding_window_output_full.astype(jnp.bfloat16)=}") + print(f"{global_attn_output.astype(jnp.bfloat16)=}") + + # Test if sliding window attention with max_target_length size is the same as global attention + self.assertTrue( + jax.numpy.allclose( + sliding_window_output_full.astype(jnp.bfloat16), + global_attn_output.astype(jnp.bfloat16), + rtol=1e-04, + atol=1e-04, + ) + ) - # The output of forward_serve_vllm is reshaped back to (batch, seq, ...) - reshaped_mock_output = mock_output.reshape(self.global_batch_size, seq_len, self.num_query_heads, self.head_dim) - expected_output = attention_vllm.out_projection(reshaped_mock_output) - self.assertTrue(jnp.allclose(output, expected_output)) - self.assertEqual(output.shape, (self.global_batch_size, seq_len, self.embed_dim)) + @pytest.mark.skip( + reason="Requires `vllm-tpu` package which is not yet a MaxText dependency." 
+ ) + @pytest.mark.tpu_only + @mock.patch( + "tpu_inference.layers.jax.attention_interface.sharded_ragged_paged_attention", + create=True, + ) + def test_forward_serve_vllm(self, mock_sharded_ragged_paged_attention): + """Tests the forward_serve_vllm method with mocked RPA attention.""" + # Setup config for vLLM RPA + vllm_config_arguments = self.config_arguments.copy() + vllm_config_arguments["attention"] = "vllm_rpa" + vllm_config_arguments["chunk_attn_window_size"] = 128 + config = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **vllm_config_arguments, + ) + + seq_len = self.max_target_length + + # Create Attention instance + dummy_inputs_q = jnp.ones((self.global_batch_size, seq_len, self.embed_dim)) + dummy_inputs_kv = jnp.ones((self.global_batch_size, seq_len, self.embed_dim)) + attention_vllm = Attention( + config=config, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=self.dtype, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + rngs=self.nnx_rng, + ) + + # Prepare inputs + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( + self.dtype + ) + mock_kv_cache = [jnp.ones((1,))] + + mock_attention_metadata = mock.Mock() + mock_attention_metadata.seq_lens = jnp.array([1] * self.global_batch_size) + mock_attention_metadata.block_tables = jnp.array([[0]] * self.global_batch_size) + mock_attention_metadata.query_start_loc = jnp.array( + list(range(self.global_batch_size)) + ) + mock_attention_metadata.request_distribution = jnp.array( + [self.global_batch_size] + ) + + # Mock the return value of sharded_ragged_paged_attention + total_tokens = self.global_batch_size * seq_len + mock_output_shape = (total_tokens, self.num_query_heads, 
self.head_dim) + mock_output = jnp.ones(mock_output_shape, dtype=self.dtype) + mock_updated_kv_cache = [jnp.zeros((1,))] + + mock_callable = mock.Mock(return_value=(mock_output, mock_updated_kv_cache)) + mock_sharded_ragged_paged_attention.return_value = mock_callable + + # Call the attention layer + output, updated_kv_cache = attention_vllm( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + kv_cache=mock_kv_cache, + attention_metadata=mock_attention_metadata, + ) + + # Assertions + mock_sharded_ragged_paged_attention.assert_called_once() + mock_callable.assert_called_once() + self.assertEqual(updated_kv_cache, mock_updated_kv_cache) + + # The output of forward_serve_vllm is reshaped back to (batch, seq, ...) + reshaped_mock_output = mock_output.reshape( + self.global_batch_size, seq_len, self.num_query_heads, self.head_dim + ) + expected_output = attention_vllm.out_projection(reshaped_mock_output) + self.assertTrue(jnp.allclose(output, expected_output)) + self.assertEqual( + output.shape, (self.global_batch_size, seq_len, self.embed_dim) + ) class MLATest(attention_test_util.MLATestBase): - """Test for the Multi-Headed Latent Attention""" - - @parameterized.named_parameters( - {"testcase_name": "RoPE_Yarn_Autoregression", "rope_type": "yarn"}, - {"testcase_name": "Default_Autoregression", "rope_type": "default"}, - ) - @pytest.mark.tpu_only - def test_autoregression(self, rope_type): - cfg, mla = self.init_mla(self.config_arguments, rope_type) - prefill_length = cfg.max_prefill_predict_length - decode_total_length = cfg.max_target_length - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(cfg, cfg.dtype) - - mla_full, _ = mla( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + """Test for the Multi-Headed Latent Attention""" + + 
@parameterized.named_parameters( + {"testcase_name": "RoPE_Yarn_Autoregression", "rope_type": "yarn"}, + {"testcase_name": "Default_Autoregression", "rope_type": "default"}, + ) + @pytest.mark.tpu_only + def test_autoregression(self, rope_type): + cfg, mla = self.init_mla(self.config_arguments, rope_type) + prefill_length = cfg.max_prefill_predict_length + decode_total_length = cfg.max_target_length + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( + cfg, cfg.dtype + ) - lnx_prefill = lnx[:, 0:prefill_length, :] - decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] - decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - - mla_prefill, _ = mla( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) + mla_full, _ = mla( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - self.assertTrue( - jax.numpy.allclose(mla_prefill, mla_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False) - ) + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + + mla_prefill, _ = mla( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) - for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx : idx + 1, :] - decoder_positions_idx = decoder_positions[:, idx : idx + 1] - mla_idx, _ = mla( - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) - - mla_full_this_idx = mla_full[:, idx : idx + 1, :] - 
self.assertEqual(mla_full_this_idx.shape, mla_idx.shape) - # TODO (b/394626702) uncomment last check when decode and kv_cache are implemented for MLA - # self.assertTrue(jax.numpy.allclose(mla_full_this_idx, mla_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) - - def test_projection_initialization(self): - """Tests that MLA and Attention layers initialize the correct projection weights.""" - # 1. Initialize a standard Attention layer for comparison - # Create a copy of the arguments and override the attention_type for the base model - attention_config_args = self.config_arguments.copy() - attention_config_args["attention_type"] = AttentionType.GLOBAL.value - extra_args = get_decoupled_parallelism_overrides() - attention_cfg = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **attention_config_args, - **extra_args, - ) - dummy_inputs_q = jnp.ones( - (attention_cfg.global_batch_size_to_train_on, attention_cfg.max_target_length, attention_cfg.base_emb_dim) - ) - dummy_inputs_kv = jnp.ones( - (attention_cfg.global_batch_size_to_train_on, attention_cfg.max_target_length, attention_cfg.base_emb_dim) - ) + self.assertTrue( + jax.numpy.allclose( + mla_prefill, + mla_full[:, :prefill_length, :], + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) + ) - base_attention = Attention( - config=attention_cfg, - num_query_heads=attention_cfg.num_query_heads, - num_kv_heads=attention_cfg.num_kv_heads, - head_dim=attention_cfg.head_dim, - max_target_length=attention_cfg.max_target_length, - max_prefill_predict_length=attention_cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=attention_cfg.dtype, - rngs=self.nnx_rng, - ) + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] + mla_idx, _ = mla( + lnx_idx, + lnx_idx, + 
inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) + + mla_full_this_idx = mla_full[:, idx : idx + 1, :] + self.assertEqual(mla_full_this_idx.shape, mla_idx.shape) + # TODO (b/394626702) uncomment last check when decode and kv_cache are implemented for MLA + # self.assertTrue(jax.numpy.allclose(mla_full_this_idx, mla_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) + + def test_projection_initialization(self): + """Tests that MLA and Attention layers initialize the correct projection weights.""" + # 1. Initialize a standard Attention layer for comparison + # Create a copy of the arguments and override the attention_type for the base model + attention_config_args = self.config_arguments.copy() + attention_config_args["attention_type"] = AttentionType.GLOBAL.value + extra_args = get_decoupled_parallelism_overrides() + attention_cfg = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **attention_config_args, + **extra_args, + ) + dummy_inputs_q = jnp.ones( + ( + attention_cfg.global_batch_size_to_train_on, + attention_cfg.max_target_length, + attention_cfg.base_emb_dim, + ) + ) + dummy_inputs_kv = jnp.ones( + ( + attention_cfg.global_batch_size_to_train_on, + attention_cfg.max_target_length, + attention_cfg.base_emb_dim, + ) + ) - # 2. Assert that the base Attention layer HAS all its standard projections - self.assertTrue(hasattr(base_attention, "query"), "Base Attention should have 'query' projection.") - self.assertTrue(hasattr(base_attention, "key"), "Base Attention should have 'key' projection.") - self.assertTrue(hasattr(base_attention, "value"), "Base Attention should have 'value' projection.") - self.assertTrue(hasattr(base_attention, "out"), "Base Attention should have 'out' projection.") - - # 3. 
Initialize the MLA layer - mla_config_args = self.config_arguments.copy() - mla_extra_args = get_decoupled_parallelism_overrides() - mla_config_args.update(mla_extra_args) - _, mla_layer = self.init_mla(mla_config_args, rope_type="default") - - # 4. Assert that the MLA layer DOES NOT HAVE the base projections - self.assertFalse(hasattr(mla_layer, "query"), "MLA should not have 'query' projection.") - self.assertFalse(hasattr(mla_layer, "key"), "MLA should not have 'key' projection.") - self.assertFalse(hasattr(mla_layer, "value"), "MLA should not have 'value' projection.") - - # 5. Assert that the MLA layer HAS all of its own specific projections AND the common 'out' projection - self.assertTrue(hasattr(mla_layer, "wq_a"), "MLA should have 'wq_a' projection.") - self.assertTrue(hasattr(mla_layer, "wq_b"), "MLA should have 'wq_b' projection.") - self.assertTrue(hasattr(mla_layer, "wkv_a"), "MLA should have 'wkv_a' projection.") - self.assertTrue(hasattr(mla_layer, "wkv_b"), "MLA should have 'wkv_b' projection.") - self.assertTrue(hasattr(mla_layer, "q_norm"), "MLA should have 'q_norm' projection.") - self.assertTrue(hasattr(mla_layer, "kv_norm"), "MLA should have 'kv_norm' projection.") - self.assertTrue(hasattr(mla_layer, "out"), "MLA should have 'out' projection.") - - @parameterized.named_parameters( - { - "testcase_name": "cp_no_load_balance", - "ici_context_parallelism": 4, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_with_load_balance", - "ici_context_parallelism": 4, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_ep_no_load_balance", - "ici_context_parallelism": 2, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - 
{ - "testcase_name": "cp_ep_with_load_balance", - "ici_context_parallelism": 2, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_no_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_with_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_no_load_balance_explicit", - "ici_context_parallelism": 4, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_with_load_balance_explicit", - "ici_context_parallelism": 4, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_ep_no_load_balance_explicit", - "ici_context_parallelism": 2, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_ep_with_load_balance_explicit", - "ici_context_parallelism": 2, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_no_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_with_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, 
- "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - ) - # TODO (b/454764135.) : This tests fails with new tokamax kernel - @pytest.mark.tpu_only - def test_tpu_flash_attention_context_parallel( - self, - ici_context_parallelism, - context_parallel_load_balance, - ici_expert_parallelism, - expert_shard_attention_option, - shard_mode, - ): - """Test equivalence between dot_product and flash attention + context/expert parallelism""" + base_attention = Attention( + config=attention_cfg, + num_query_heads=attention_cfg.num_query_heads, + num_kv_heads=attention_cfg.num_kv_heads, + head_dim=attention_cfg.head_dim, + max_target_length=attention_cfg.max_target_length, + max_prefill_predict_length=attention_cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=attention_cfg.dtype, + rngs=self.nnx_rng, + ) - config_arguments = { - "per_device_batch_size": 1.0, - "run_name": "test", - "enable_checkpointing": False, - "max_target_length": 512, - "sa_block_q": 128, - "sa_block_kv": 128, - "sa_block_kv_compute": 128, - "sa_block_q_dkv": 128, - "sa_block_kv_dkv": 128, - "sa_block_kv_dkv_compute": 128, - "sa_block_q_dq": 128, - "sa_block_kv_dq": 128, - "attention_type": AttentionType.MLA.value, - "q_lora_rank": 4, - "kv_lora_rank": 4, - "qk_nope_head_dim": 128, - "qk_rope_head_dim": 64, - "v_head_dim": 128, - "shard_mode": shard_mode, - } + # 2. 
Assert that the base Attention layer HAS all its standard projections + self.assertTrue( + hasattr(base_attention, "query"), + "Base Attention should have 'query' projection.", + ) + self.assertTrue( + hasattr(base_attention, "key"), + "Base Attention should have 'key' projection.", + ) + self.assertTrue( + hasattr(base_attention, "value"), + "Base Attention should have 'value' projection.", + ) + self.assertTrue( + hasattr(base_attention, "out"), + "Base Attention should have 'out' projection.", + ) - cfg, mla = self.init_mla(config_arguments, rope_type="default") - lnx, decoder_segment_ids, decoder_positions = self.get_data(cfg, cfg.dtype) - # Dot product - mla_generic_output, _ = mla( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) - generic_state = nnx.state(mla) - - # Test with Context Parallelism - cfg_cp = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **config_arguments, - rope_type=cfg.rope_type, - ici_context_parallelism=ici_context_parallelism, - context_parallel_load_balance=context_parallel_load_balance, - ici_expert_parallelism=ici_expert_parallelism, - expert_shard_attention_option=expert_shard_attention_option, - ) - devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) - axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto - axis_names = [axis_type for _ in cfg_cp.mesh_axes] - mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes, axis_types=tuple(axis_names)) - attention_as_mla_flash_cp = MLA( - config=cfg_cp, - num_query_heads=cfg_cp.num_query_heads, - num_kv_heads=cfg_cp.num_kv_heads, - head_dim=cfg_cp.head_dim, - inputs_q_shape=lnx.shape, - inputs_kv_shape=lnx.shape, - max_target_length=cfg_cp.max_target_length, - max_prefill_predict_length=cfg_cp.max_prefill_predict_length, - mesh=mesh_cp, - attention_kernel="flash", - dtype=cfg_cp.dtype, - dropout_rate=cfg_cp.dropout_rate, - 
attention_type=cfg_cp.attention_type, - q_lora_rank=cfg_cp.q_lora_rank, - kv_lora_rank=cfg_cp.kv_lora_rank, - qk_nope_head_dim=cfg_cp.qk_nope_head_dim, - qk_rope_head_dim=cfg_cp.qk_rope_head_dim, - v_head_dim=cfg_cp.v_head_dim, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) - nnx.update(attention_as_mla_flash_cp, generic_state) - mla_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism( - cfg_cp, - mesh_cp, - attention_as_mla_flash_cp, - lnx, - decoder_segment_ids, - decoder_positions, - ) + # 3. Initialize the MLA layer + mla_config_args = self.config_arguments.copy() + mla_extra_args = get_decoupled_parallelism_overrides() + mla_config_args.update(mla_extra_args) + _, mla_layer = self.init_mla(mla_config_args, rope_type="default") - # This removes all sharding information and makes them standard NumPy arrays. - mla_generic_output = jax.device_get(mla_generic_output) - mla_generic_flash_cp_output = jax.device_get(mla_generic_flash_cp_output) + # 4. Assert that the MLA layer DOES NOT HAVE the base projections + self.assertFalse( + hasattr(mla_layer, "query"), "MLA should not have 'query' projection." + ) + self.assertFalse( + hasattr(mla_layer, "key"), "MLA should not have 'key' projection." + ) + self.assertFalse( + hasattr(mla_layer, "value"), "MLA should not have 'value' projection." + ) - self.assertTrue( - jax.numpy.allclose(mla_generic_output, mla_generic_flash_cp_output, rtol=1e-01, atol=1e-01, equal_nan=False), - msg="MLA Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" - f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," - f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", - ) + # 5. 
Assert that the MLA layer HAS all of its own specific projections AND the common 'out' projection + self.assertTrue( + hasattr(mla_layer, "wq_a"), "MLA should have 'wq_a' projection." + ) + self.assertTrue( + hasattr(mla_layer, "wq_b"), "MLA should have 'wq_b' projection." + ) + self.assertTrue( + hasattr(mla_layer, "wkv_a"), "MLA should have 'wkv_a' projection." + ) + self.assertTrue( + hasattr(mla_layer, "wkv_b"), "MLA should have 'wkv_b' projection." + ) + self.assertTrue( + hasattr(mla_layer, "q_norm"), "MLA should have 'q_norm' projection." + ) + self.assertTrue( + hasattr(mla_layer, "kv_norm"), "MLA should have 'kv_norm' projection." + ) + self.assertTrue(hasattr(mla_layer, "out"), "MLA should have 'out' projection.") + + def test_fused_mla_lora_proj_output_equivalence(self): + """Tests that fused_mla_lora_proj=True produces identical outputs to fused_mla_lora_proj=False.""" + extra_args = get_decoupled_parallelism_overrides() + + # Initialize the unfused model. + unfused_args = { + **self.config_arguments, + "fused_mla_lora_proj": False, + **extra_args, + } + cfg_unfused = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], **unfused_args + ) + devices_array = maxtext_utils.create_device_mesh(cfg_unfused) + mesh = Mesh(devices_array, cfg_unfused.mesh_axes) + dummy_q = jnp.ones( + ( + cfg_unfused.global_batch_size_to_train_on, + cfg_unfused.max_target_length, + cfg_unfused.base_emb_dim, + ) + ) + mla_unfused = MLA( + config=cfg_unfused, + num_query_heads=cfg_unfused.num_query_heads, + num_kv_heads=cfg_unfused.num_kv_heads, + head_dim=cfg_unfused.head_dim, + inputs_q_shape=dummy_q.shape, + inputs_kv_shape=dummy_q.shape, + max_target_length=cfg_unfused.max_target_length, + max_prefill_predict_length=cfg_unfused.max_prefill_predict_length, + mesh=mesh, + attention_kernel="dot_product", + dtype=cfg_unfused.dtype, + dropout_rate=cfg_unfused.dropout_rate, + attention_type=cfg_unfused.attention_type, + q_lora_rank=cfg_unfused.q_lora_rank, + 
kv_lora_rank=cfg_unfused.kv_lora_rank, + qk_nope_head_dim=cfg_unfused.qk_nope_head_dim, + qk_rope_head_dim=cfg_unfused.qk_rope_head_dim, + v_head_dim=cfg_unfused.v_head_dim, + model_mode=MODEL_MODE_TRAIN, + rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)), + ) - def get_indexer_test_data(self, batch_size, q_len, kv_len, num_heads, head_dim): - """Helper to generate random data for indexer tests.""" - key_q, key_k, key_is = jax.random.split(self.rng, 3) - query = jax.random.normal(key_q, (batch_size, q_len, num_heads, head_dim)) - key = jax.random.normal(key_k, (batch_size, kv_len, num_heads, head_dim)) - indexer_score = jax.random.normal(key_is, (batch_size, q_len, kv_len)) - return query, key, indexer_score - - def get_causal_mask_for_indexer(self, batch_size, q_len, kv_len): - """Helper to generate a causal mask with DEFAULT_MASK_VALUE.""" - row_ids = jnp.arange(q_len)[:, None] - col_ids = jnp.arange(kv_len)[None, :] - attention_mask = jnp.where(col_ids <= row_ids, 0.0, DEFAULT_MASK_VALUE) - attention_mask = jnp.broadcast_to(attention_mask, (batch_size, q_len, kv_len)) - return attention_mask - - def test_indexer_loss(self): - """Test indexer loss computation.""" - mla_config_args = self.config_arguments.copy() - mla_config_args["use_sparse_indexer"] = True - mla_config_args["attention"] = "dot_product" - _, mla = self.init_mla(mla_config_args, rope_type="default") - - batch_size = 2 - q_len = 3 - kv_len = 4 - num_heads = 5 - head_dim = 6 - scaling_factor = 0.5 - - query, key, indexer_score = self.get_indexer_test_data(batch_size, q_len, kv_len, num_heads, head_dim) - - # Causal mask - attention_mask = self.get_causal_mask_for_indexer(batch_size, q_len, kv_len) - indexer_score += attention_mask - - topk_indices = jnp.array([[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]]) - indexer_mask = mla.indexer.generate_mask(topk_indices, kv_len) + attention_mask - - loss_dense = mla.calculate_indexer_loss( - indexer_score=indexer_score, - query=query, - key=key, 
- attention_mask=attention_mask, - indexer_mask=indexer_mask, - sparse_loss=False, - scaling_factor=scaling_factor, - ) + # Initialize the fused model. + fused_args = { + **self.config_arguments, + "fused_mla_lora_proj": True, + **extra_args, + } + cfg_fused = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], **fused_args + ) + mla_fused = MLA( + config=cfg_fused, + num_query_heads=cfg_fused.num_query_heads, + num_kv_heads=cfg_fused.num_kv_heads, + head_dim=cfg_fused.head_dim, + inputs_q_shape=dummy_q.shape, + inputs_kv_shape=dummy_q.shape, + max_target_length=cfg_fused.max_target_length, + max_prefill_predict_length=cfg_fused.max_prefill_predict_length, + mesh=mesh, + attention_kernel="dot_product", + dtype=cfg_fused.dtype, + dropout_rate=cfg_fused.dropout_rate, + attention_type=cfg_fused.attention_type, + q_lora_rank=cfg_fused.q_lora_rank, + kv_lora_rank=cfg_fused.kv_lora_rank, + qk_nope_head_dim=cfg_fused.qk_nope_head_dim, + qk_rope_head_dim=cfg_fused.qk_rope_head_dim, + v_head_dim=cfg_fused.v_head_dim, + model_mode=MODEL_MODE_TRAIN, + rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)), + ) - loss_sparse = mla.calculate_indexer_loss( - indexer_score=indexer_score, - query=query, - key=key, - attention_mask=attention_mask, - indexer_mask=indexer_mask, - sparse_loss=True, - scaling_factor=scaling_factor, - ) + # Make both models mathematically equivalent: + # fused wq_kv_a = concat(unfused wq_a, unfused wkv_a) along the output axis. + mla_fused.wq_kv_a.kernel.value = jnp.concatenate( + [mla_unfused.wq_a.kernel.value, mla_unfused.wkv_a.kernel.value], axis=-1 + ) + mla_fused.wq_b.kernel.value = mla_unfused.wq_b.kernel.value + mla_fused.q_norm.scale.value = mla_unfused.q_norm.scale.value + mla_fused.wkv_b.kernel.value = mla_unfused.wkv_b.kernel.value + mla_fused.kv_norm.scale.value = mla_unfused.kv_norm.scale.value + mla_fused.out.kernel.value = mla_unfused.out.kernel.value + + # Run both models on the same inputs and verify outputs are identical. 
+ lnx, decoder_segment_ids, decoder_positions = self.get_data( + cfg_unfused, cfg_unfused.dtype + ) + common_kwargs = { + "decoder_segment_ids": decoder_segment_ids, + "inputs_positions": decoder_positions, + "deterministic": True, + "model_mode": MODEL_MODE_TRAIN, + } + output_unfused, _ = mla_unfused(lnx, lnx, **common_kwargs) + output_fused, _ = mla_fused(lnx, lnx, **common_kwargs) + + self.assertTrue( + jax.numpy.allclose( + output_unfused, output_fused, rtol=1e-05, atol=1e-05, equal_nan=False + ), + "fused_mla_lora_proj=True and fused_mla_lora_proj=False produced different outputs.", + ) - np.testing.assert_array_less(0.0, loss_dense) - np.testing.assert_array_less(0.0, loss_sparse) - - def test_indexer_loss_kl_divergence_zero(self): - """Test that KL divergence is 0 when target and pred distributions match exactly.""" - mla_config_args = self.config_arguments.copy() - mla_config_args["use_sparse_indexer"] = True - mla_config_args["attention"] = "dot_product" - _, mla = self.init_mla(mla_config_args, rope_type="default") - - batch_size = 2 - q_len = 3 - kv_len = 4 - num_heads = 5 - head_dim = 6 - - # Setup perfectly matching distributions - # Make query and key such that einsum yields zeros (so softmax gives uniform distribution over unmasked) - query = jnp.zeros((batch_size, q_len, num_heads, head_dim)) - key = jnp.zeros((batch_size, kv_len, num_heads, head_dim)) - - # Causal mask - attention_mask = self.get_causal_mask_for_indexer(batch_size, q_len, kv_len) - - # Indexer score matches the shape and is uniform - indexer_score = jnp.zeros((batch_size, q_len, kv_len)) + attention_mask - - topk_indices = jnp.array([[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]]) - indexer_mask = mla.indexer.generate_mask(topk_indices, kv_len) + attention_mask - - loss = mla.calculate_indexer_loss( - indexer_score=indexer_score, - query=query, - key=key, - attention_mask=attention_mask, - indexer_mask=indexer_mask, - sparse_loss=False, - scaling_factor=1.0, - ) + 
@parameterized.named_parameters( + { + "testcase_name": "cp_no_load_balance", + "ici_context_parallelism": 4, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_with_load_balance", + "ici_context_parallelism": 4, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_ep_no_load_balance", + "ici_context_parallelism": 2, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_ep_with_load_balance", + "ici_context_parallelism": 2, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "ep_no_load_balance", + "ici_context_parallelism": 1, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "ep_with_load_balance", + "ici_context_parallelism": 1, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_no_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_with_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_no_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": False, + 
"ici_expert_parallelism": 2,
+            "expert_shard_attention_option": "context",
+            "shard_mode": "explicit",
+        },
+        {
+            "testcase_name": "cp_ep_with_load_balance_explicit",
+            "ici_context_parallelism": 2,
+            "context_parallel_load_balance": True,
+            "ici_expert_parallelism": 2,
+            "expert_shard_attention_option": "context",
+            "shard_mode": "explicit",
+        },
+        {
+            "testcase_name": "ep_no_load_balance_explicit",
+            "ici_context_parallelism": 1,
+            "context_parallel_load_balance": False,
+            "ici_expert_parallelism": 4,
+            "expert_shard_attention_option": "context",
+            "shard_mode": "explicit",
+        },
+        {
+            "testcase_name": "ep_with_load_balance_explicit",
+            "ici_context_parallelism": 1,
+            "context_parallel_load_balance": True,
+            "ici_expert_parallelism": 4,
+            "expert_shard_attention_option": "context",
+            "shard_mode": "explicit",
+        },
+    )
+    # TODO (b/454764135) : This test fails with the new tokamax kernel
+    @pytest.mark.tpu_only
+    def test_tpu_flash_attention_context_parallel(
+        self,
+        ici_context_parallelism,
+        context_parallel_load_balance,
+        ici_expert_parallelism,
+        expert_shard_attention_option,
+        shard_mode,
+    ):
+        """Test equivalence between dot_product and flash attention + context/expert parallelism"""
+
+        config_arguments = {
+            "per_device_batch_size": 1.0,
+            "run_name": "test",
+            "enable_checkpointing": False,
+            "max_target_length": 512,
+            "sa_block_q": 128,
+            "sa_block_kv": 128,
+            "sa_block_kv_compute": 128,
+            "sa_block_q_dkv": 128,
+            "sa_block_kv_dkv": 128,
+            "sa_block_kv_dkv_compute": 128,
+            "sa_block_q_dq": 128,
+            "sa_block_kv_dq": 128,
+            "attention_type": AttentionType.MLA.value,
+            "q_lora_rank": 4,
+            "kv_lora_rank": 4,
+            "qk_nope_head_dim": 128,
+            "qk_rope_head_dim": 64,
+            "v_head_dim": 128,
+            "shard_mode": shard_mode,
+        }
+
+        cfg, mla = self.init_mla(config_arguments, rope_type="default")
+        lnx, decoder_segment_ids, decoder_positions = self.get_data(cfg, cfg.dtype)
+        # Dot product
+        mla_generic_output, _ = mla(
+            lnx,
+            lnx,
+            decoder_segment_ids=decoder_segment_ids,
+ inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + generic_state = nnx.state(mla) + + # Test with Context Parallelism + cfg_cp = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **config_arguments, + rope_type=cfg.rope_type, + ici_context_parallelism=ici_context_parallelism, + context_parallel_load_balance=context_parallel_load_balance, + ici_expert_parallelism=ici_expert_parallelism, + expert_shard_attention_option=expert_shard_attention_option, + ) + devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) + axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto + axis_names = [axis_type for _ in cfg_cp.mesh_axes] + mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes, axis_types=tuple(axis_names)) + attention_as_mla_flash_cp = MLA( + config=cfg_cp, + num_query_heads=cfg_cp.num_query_heads, + num_kv_heads=cfg_cp.num_kv_heads, + head_dim=cfg_cp.head_dim, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, + max_target_length=cfg_cp.max_target_length, + max_prefill_predict_length=cfg_cp.max_prefill_predict_length, + mesh=mesh_cp, + attention_kernel="flash", + dtype=cfg_cp.dtype, + dropout_rate=cfg_cp.dropout_rate, + attention_type=cfg_cp.attention_type, + q_lora_rank=cfg_cp.q_lora_rank, + kv_lora_rank=cfg_cp.kv_lora_rank, + qk_nope_head_dim=cfg_cp.qk_nope_head_dim, + qk_rope_head_dim=cfg_cp.qk_rope_head_dim, + v_head_dim=cfg_cp.v_head_dim, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + nnx.update(attention_as_mla_flash_cp, generic_state) + mla_generic_flash_cp_output = ( + attention_test_util.forward_with_context_expert_parallelism( + cfg_cp, + mesh_cp, + attention_as_mla_flash_cp, + lnx, + decoder_segment_ids, + decoder_positions, + ) + ) - np.testing.assert_allclose(loss, 0.0, atol=1e-5) + # This removes all sharding information and makes them standard NumPy arrays. 
+ mla_generic_output = jax.device_get(mla_generic_output) + mla_generic_flash_cp_output = jax.device_get(mla_generic_flash_cp_output) + + self.assertTrue( + jax.numpy.allclose( + mla_generic_output, + mla_generic_flash_cp_output, + rtol=1e-01, + atol=1e-01, + equal_nan=False, + ), + msg="MLA Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" + f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," + f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", + ) + def get_indexer_test_data(self, batch_size, q_len, kv_len, num_heads, head_dim): + """Helper to generate random data for indexer tests.""" + key_q, key_k, key_is = jax.random.split(self.rng, 3) + query = jax.random.normal(key_q, (batch_size, q_len, num_heads, head_dim)) + key = jax.random.normal(key_k, (batch_size, kv_len, num_heads, head_dim)) + indexer_score = jax.random.normal(key_is, (batch_size, q_len, kv_len)) + return query, key, indexer_score + + def get_causal_mask_for_indexer(self, batch_size, q_len, kv_len): + """Helper to generate a causal mask with DEFAULT_MASK_VALUE.""" + row_ids = jnp.arange(q_len)[:, None] + col_ids = jnp.arange(kv_len)[None, :] + attention_mask = jnp.where(col_ids <= row_ids, 0.0, DEFAULT_MASK_VALUE) + attention_mask = jnp.broadcast_to(attention_mask, (batch_size, q_len, kv_len)) + return attention_mask + + def test_indexer_loss(self): + """Test indexer loss computation.""" + mla_config_args = self.config_arguments.copy() + mla_config_args["use_sparse_indexer"] = True + mla_config_args["attention"] = "dot_product" + _, mla = self.init_mla(mla_config_args, rope_type="default") + + batch_size = 2 + q_len = 3 + kv_len = 4 + num_heads = 5 + head_dim = 6 + scaling_factor = 0.5 + + query, key, indexer_score = self.get_indexer_test_data( + batch_size, q_len, kv_len, num_heads, head_dim + ) -class 
Qwen3NextGatedDeltaNetTest(unittest.TestCase): - """Test for the Gated Delta Net in Qwen3-Next""" + # Causal mask + attention_mask = self.get_causal_mask_for_indexer(batch_size, q_len, kv_len) + indexer_score += attention_mask + + topk_indices = jnp.array([[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]]) + indexer_mask = mla.indexer.generate_mask(topk_indices, kv_len) + attention_mask + + loss_dense = mla.calculate_indexer_loss( + indexer_score=indexer_score, + query=query, + key=key, + attention_mask=attention_mask, + indexer_mask=indexer_mask, + sparse_loss=False, + scaling_factor=scaling_factor, + ) - def setUp(self): - super().setUp() - self.config_arguments = { - "per_device_batch_size": 1.0, - "run_name": "test", - "enable_checkpointing": False, - "max_prefill_predict_length": 16, - "max_target_length": 32, - "base_emb_dim": 128, # changed to base_emb_dim so it properly overrides the default 2048 - "gdn_num_value_heads": 4, - "gdn_num_key_heads": 4, - "gdn_key_head_dim": 32, - "gdn_value_head_dim": 32, - "gdn_conv_kernel_dim": 4, - "gdn_chunk_size": 16, - "dtype": "bfloat16", - } - self.cfg = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **self.config_arguments, - ) - self.rng = jax.random.PRNGKey(0) - self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) - - def get_structured_data(self, dtype): - """get structured data for GDN (only requires hidden states)""" - lnx = jax.random.normal( - self.rng, - shape=(self.cfg.global_batch_size_to_train_on, self.cfg.max_target_length, self.cfg.emb_dim), - dtype=dtype, - ) - return lnx - - @pytest.mark.tpu_only - def test_autoregression(self): - cfg = self.cfg - prefill_length = cfg.max_prefill_predict_length - decode_total_length = cfg.max_target_length - - # 1. Init Data - lnx = self.get_structured_data(cfg.dtype) - - # 2. 
Init GDN Layer - gdn = Qwen3NextGatedDeltaNet( - config=cfg, - dtype=cfg.dtype, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) + loss_sparse = mla.calculate_indexer_loss( + indexer_score=indexer_score, + query=query, + key=key, + attention_mask=attention_mask, + indexer_mask=indexer_mask, + sparse_loss=True, + scaling_factor=scaling_factor, + ) - # 3. Full / Train mode - gdn_full = gdn( - lnx, - model_mode=MODEL_MODE_TRAIN, - ) + np.testing.assert_array_less(0.0, loss_dense) + np.testing.assert_array_less(0.0, loss_sparse) + + def test_indexer_loss_kl_divergence_zero(self): + """Test that KL divergence is 0 when target and pred distributions match exactly.""" + mla_config_args = self.config_arguments.copy() + mla_config_args["use_sparse_indexer"] = True + mla_config_args["attention"] = "dot_product" + _, mla = self.init_mla(mla_config_args, rope_type="default") + + batch_size = 2 + q_len = 3 + kv_len = 4 + num_heads = 5 + head_dim = 6 + + # Setup perfectly matching distributions + # Make query and key such that einsum yields zeros (so softmax gives uniform distribution over unmasked) + query = jnp.zeros((batch_size, q_len, num_heads, head_dim)) + key = jnp.zeros((batch_size, kv_len, num_heads, head_dim)) + + # Causal mask + attention_mask = self.get_causal_mask_for_indexer(batch_size, q_len, kv_len) + + # Indexer score matches the shape and is uniform + indexer_score = jnp.zeros((batch_size, q_len, kv_len)) + attention_mask + + topk_indices = jnp.array([[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]]) + indexer_mask = mla.indexer.generate_mask(topk_indices, kv_len) + attention_mask + + loss = mla.calculate_indexer_loss( + indexer_score=indexer_score, + query=query, + key=key, + attention_mask=attention_mask, + indexer_mask=indexer_mask, + sparse_loss=False, + scaling_factor=1.0, + ) - # 4. 
Prefill mode - lnx_prefill = lnx[:, 0:prefill_length, :] + np.testing.assert_allclose(loss, 0.0, atol=1e-5) - gdn_prefill = gdn( - lnx_prefill, - model_mode=MODEL_MODE_PREFILL, - ) - self.assertTrue( - jax.numpy.allclose(gdn_prefill, gdn_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False) - ) +class Qwen3NextGatedDeltaNetTest(unittest.TestCase): + """Test for the Gated Delta Net in Qwen3-Next""" + + def setUp(self): + super().setUp() + self.config_arguments = { + "per_device_batch_size": 1.0, + "run_name": "test", + "enable_checkpointing": False, + "max_prefill_predict_length": 16, + "max_target_length": 32, + "base_emb_dim": 128, # changed to base_emb_dim so it properly overrides the default 2048 + "gdn_num_value_heads": 4, + "gdn_num_key_heads": 4, + "gdn_key_head_dim": 32, + "gdn_value_head_dim": 32, + "gdn_conv_kernel_dim": 4, + "gdn_chunk_size": 16, + "dtype": "bfloat16", + } + self.cfg = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **self.config_arguments, + ) + self.rng = jax.random.PRNGKey(0) + self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) + + def get_structured_data(self, dtype): + """get structured data for GDN (only requires hidden states)""" + lnx = jax.random.normal( + self.rng, + shape=( + self.cfg.global_batch_size_to_train_on, + self.cfg.max_target_length, + self.cfg.emb_dim, + ), + dtype=dtype, + ) + return lnx + + @pytest.mark.tpu_only + def test_autoregression(self): + cfg = self.cfg + prefill_length = cfg.max_prefill_predict_length + decode_total_length = cfg.max_target_length + + # 1. Init Data + lnx = self.get_structured_data(cfg.dtype) + + # 2. Init GDN Layer + gdn = Qwen3NextGatedDeltaNet( + config=cfg, + dtype=cfg.dtype, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + + # 3. Full / Train mode + gdn_full = gdn( + lnx, + model_mode=MODEL_MODE_TRAIN, + ) + + # 4. 
Prefill mode + lnx_prefill = lnx[:, 0:prefill_length, :] + + gdn_prefill = gdn( + lnx_prefill, + model_mode=MODEL_MODE_PREFILL, + ) + + self.assertTrue( + jax.numpy.allclose( + gdn_prefill, + gdn_full[:, :prefill_length, :], + rtol=1e-02, + atol=1e-02, + equal_nan=False, + ) + ) - # 5. Autoregressive mode - for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx : idx + 1, :] + # 5. Autoregressive mode + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] - gdn_idx = gdn( - lnx_idx, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) + gdn_idx = gdn( + lnx_idx, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) - gdn_full_this_idx = gdn_full[:, idx : idx + 1, :] - self.assertEqual(gdn_full_this_idx.shape, gdn_idx.shape) + gdn_full_this_idx = gdn_full[:, idx : idx + 1, :] + self.assertEqual(gdn_full_this_idx.shape, gdn_idx.shape) - self.assertTrue(jax.numpy.allclose(gdn_full_this_idx, gdn_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) + self.assertTrue( + jax.numpy.allclose( + gdn_full_this_idx, gdn_idx, rtol=1e-02, atol=1e-02, equal_nan=False + ) + ) if __name__ == "__main__": - unittest.main() + unittest.main() From d995983f71213389f46c03906e3f54dd000cf6c9 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Tue, 24 Mar 2026 11:51:28 -0700 Subject: [PATCH 2/2] Fix pyink formatting in attention_test.py --- tests/unit/attention_test.py | 3794 ++++++++++++++++------------------ 1 file changed, 1818 insertions(+), 1976 deletions(-) diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 501abe27c7..49a263a411 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -49,2082 +49,1924 @@ class BidirectionalBlockMaskTest(unittest.TestCase): - """Test for make_bidirectional_block_mask.""" - - def test_one_block_mask(self): - bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0]]) - # pylint: disable=protected-access - block_mask = _make_bidirectional_block_mask(bidirectional_mask) - 
expected_mask = np.asarray( - [ - [ - [False, False, False, False, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], - ] - ] - ) - np.testing.assert_array_equal(block_mask, expected_mask) - - def test_two_blocks_mask(self): - bidirectional_mask = np.asarray([[0, 1, 1, 0, 1, 1]]) - # pylint: disable=protected-access - block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.asarray( - [ - [ - [False, False, False, False, False, False], - [False, True, True, False, False, False], - [False, True, True, False, False, False], - [False, False, False, False, False, False], - [False, False, False, False, True, True], - [False, False, False, False, True, True], - ] - ] - ) - np.testing.assert_array_equal(block_mask, expected_mask) - - def test_batch_block_masks(self): - bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]]) - # pylint: disable=protected-access - block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.asarray( - [ - [ - [False, False, False, False, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], - ], - [ - [False, False, False, False, False, False], - [False, True, True, False, False, False], - [False, True, True, False, False, False], - [False, False, False, False, False, False], - [False, False, False, False, True, True], - [False, False, False, False, True, True], - ], - ] - ) - np.testing.assert_array_equal(block_mask, expected_mask) - - def test_empty_block_mask(self): - bidirectional_mask = np.asarray([[0, 0, 0, 0, 0, 0]]) - # pylint: disable=protected-access - block_mask = 
_make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.zeros( - ( - bidirectional_mask.shape[0], - bidirectional_mask.shape[1], - bidirectional_mask.shape[1], - ), - dtype=bool, - ) - np.testing.assert_array_equal(block_mask, expected_mask) - - def test_full_block_mask(self): - bidirectional_mask = np.asarray([[1, 1, 1, 1, 1, 1]]) - # pylint: disable=protected-access - block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.ones( - ( - bidirectional_mask.shape[0], - bidirectional_mask.shape[1], - bidirectional_mask.shape[1], - ), - dtype=bool, - ) - np.testing.assert_array_equal(block_mask, expected_mask) - - def test_combine_with_causal_mask(self): - seq_len = 6 - row_ids = np.arange(seq_len, dtype=np.int32)[:, None] - col_ids = np.arange(seq_len, dtype=np.int32)[None, :] - causal_mask = (col_ids <= row_ids)[None, None, None, :, :] - bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]]) - # pylint: disable=protected-access - image_mask = _make_bidirectional_block_mask(bidirectional_mask) - combined_mask = causal_mask | image_mask[:, None, None, ...] 
- expected_mask = np.asarray( - [ - [ - [ - [ - [True, False, False, False, False, False], - [True, True, True, True, False, False], - [True, True, True, True, False, False], - [True, True, True, True, False, False], - [True, True, True, True, True, False], - [True, True, True, True, True, True], - ] - ] - ], - [ - [ - [ - [True, False, False, False, False, False], - [True, True, True, False, False, False], - [True, True, True, False, False, False], - [True, True, True, True, False, False], - [True, True, True, True, True, True], - [True, True, True, True, True, True], - ] - ] - ], - ] - ) - np.testing.assert_array_equal(combined_mask, expected_mask) + """Test for make_bidirectional_block_mask.""" + + def test_one_block_mask(self): + bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0]]) + # pylint: disable=protected-access + block_mask = _make_bidirectional_block_mask(bidirectional_mask) + expected_mask = np.asarray([[ + [False, False, False, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ]]) + np.testing.assert_array_equal(block_mask, expected_mask) + + def test_two_blocks_mask(self): + bidirectional_mask = np.asarray([[0, 1, 1, 0, 1, 1]]) + # pylint: disable=protected-access + block_mask = _make_bidirectional_block_mask(bidirectional_mask) + expected_mask = np.asarray([[ + [False, False, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + ]]) + np.testing.assert_array_equal(block_mask, expected_mask) + + def test_batch_block_masks(self): + bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]]) + # pylint: disable=protected-access + block_mask = 
_make_bidirectional_block_mask(bidirectional_mask) + expected_mask = np.asarray([ + [ + [False, False, False, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ], + [ + [False, False, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + ], + ]) + np.testing.assert_array_equal(block_mask, expected_mask) + + def test_empty_block_mask(self): + bidirectional_mask = np.asarray([[0, 0, 0, 0, 0, 0]]) + # pylint: disable=protected-access + block_mask = _make_bidirectional_block_mask(bidirectional_mask) + expected_mask = np.zeros( + ( + bidirectional_mask.shape[0], + bidirectional_mask.shape[1], + bidirectional_mask.shape[1], + ), + dtype=bool, + ) + np.testing.assert_array_equal(block_mask, expected_mask) + + def test_full_block_mask(self): + bidirectional_mask = np.asarray([[1, 1, 1, 1, 1, 1]]) + # pylint: disable=protected-access + block_mask = _make_bidirectional_block_mask(bidirectional_mask) + expected_mask = np.ones( + ( + bidirectional_mask.shape[0], + bidirectional_mask.shape[1], + bidirectional_mask.shape[1], + ), + dtype=bool, + ) + np.testing.assert_array_equal(block_mask, expected_mask) + + def test_combine_with_causal_mask(self): + seq_len = 6 + row_ids = np.arange(seq_len, dtype=np.int32)[:, None] + col_ids = np.arange(seq_len, dtype=np.int32)[None, :] + causal_mask = (col_ids <= row_ids)[None, None, None, :, :] + bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]]) + # pylint: disable=protected-access + image_mask = _make_bidirectional_block_mask(bidirectional_mask) + combined_mask = causal_mask | image_mask[:, None, None, ...] 
+ expected_mask = np.asarray([ + [[[ + [True, False, False, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, False], + [True, True, True, True, True, True], + ]]], + [[[ + [True, False, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + ]]], + ]) + np.testing.assert_array_equal(combined_mask, expected_mask) class ChunkedCausalMaskTest(unittest.TestCase): - """Test for the ChunkedCausalMask.""" - - def test_basic_chunking(self): - """Tests the mask with a simple chunk size.""" - seq_len = 8 - chunk_size = 4 - mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) - - # Manually compute the expected mask - # Causal within chunks (0-3, 4-7) - expected_mask = np.zeros((seq_len, seq_len), dtype=np.bool_) - for r in range(seq_len): - for c in range(seq_len): - q_chunk = r // chunk_size - kv_chunk = c // chunk_size - if q_chunk == kv_chunk and r >= c: - expected_mask[r, c] = True - - # Get the actual mask by slicing - actual_mask = mask[:, :] - - np.testing.assert_array_equal(actual_mask, expected_mask) - # Make sure _generate_chunk_attention_mask also produces the same mask - # pylint: disable=protected-access - actual_mask = _generate_chunk_attention_mask( - mask_shape=mask.shape, chunk_size=chunk_size - ) - np.testing.assert_array_equal(actual_mask, expected_mask) - - def test_full_length_chunk(self): - """Tests when chunk size equals sequence length (should be causal).""" - seq_len = 6 - chunk_size = 6 # Same as seq_len - mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) - - # Expected mask is a standard lower triangular causal mask - expected_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.bool_)) - - 
actual_mask = mask[:, :] - np.testing.assert_array_equal(actual_mask, expected_mask) - # Make sure _generate_chunk_attention_mask also produces the same mask - # pylint: disable=protected-access - actual_mask = _generate_chunk_attention_mask( - mask_shape=mask.shape, chunk_size=chunk_size - ) - np.testing.assert_array_equal(actual_mask, expected_mask) - - def test_single_token_chunk(self): - """Tests when chunk size is 1 (only attend to self).""" - seq_len = 5 - chunk_size = 1 - mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) - - # Expected mask is just the identity matrix - expected_mask = np.eye(seq_len, dtype=np.bool_) - - actual_mask = mask[:, :] - np.testing.assert_array_equal(actual_mask, expected_mask) - # Make sure _generate_chunk_attention_mask also produces the same mask - # pylint: disable=protected-access - actual_mask = _generate_chunk_attention_mask( - mask_shape=mask.shape, chunk_size=chunk_size - ) - np.testing.assert_array_equal(actual_mask, expected_mask) - - def test_non_square_shape(self): - """Tests with different query and key sequence lengths.""" - q_len = 6 - kv_len = 8 - chunk_size = 3 - mask = ChunkedCausalMask(shape=(q_len, kv_len), chunk_size=chunk_size) - - # Manually compute expected mask - expected_mask = np.zeros((q_len, kv_len), dtype=np.bool_) - for r in range(q_len): - for c in range(kv_len): - q_chunk = r // chunk_size - kv_chunk = c // chunk_size - if q_chunk == kv_chunk and r >= c: - expected_mask[r, c] = True - - actual_mask = mask[:, :] - np.testing.assert_array_equal(actual_mask, expected_mask) - # Make sure _generate_chunk_attention_mask also produces the same mask - # pylint: disable=protected-access - actual_mask = _generate_chunk_attention_mask( - mask_shape=mask.shape, chunk_size=chunk_size - ) - np.testing.assert_array_equal(actual_mask, expected_mask) - - def test_value_error_on_zero_chunk_size(self): - """Tests that a ValueError is raised for chunk_size <= 0.""" - with 
self.assertRaises(ValueError): - ChunkedCausalMask(shape=(4, 4), chunk_size=0) - with self.assertRaises(ValueError): - ChunkedCausalMask(shape=(4, 4), chunk_size=-2) - with self.assertRaises(ValueError): - # pylint: disable=protected-access - _generate_chunk_attention_mask(mask_shape=(4, 4), chunk_size=0) + """Test for the ChunkedCausalMask.""" + + def test_basic_chunking(self): + """Tests the mask with a simple chunk size.""" + seq_len = 8 + chunk_size = 4 + mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) + + # Manually compute the expected mask + # Causal within chunks (0-3, 4-7) + expected_mask = np.zeros((seq_len, seq_len), dtype=np.bool_) + for r in range(seq_len): + for c in range(seq_len): + q_chunk = r // chunk_size + kv_chunk = c // chunk_size + if q_chunk == kv_chunk and r >= c: + expected_mask[r, c] = True + + # Get the actual mask by slicing + actual_mask = mask[:, :] + + np.testing.assert_array_equal(actual_mask, expected_mask) + # Make sure _generate_chunk_attention_mask also produces the same mask + # pylint: disable=protected-access + actual_mask = _generate_chunk_attention_mask(mask_shape=mask.shape, chunk_size=chunk_size) + np.testing.assert_array_equal(actual_mask, expected_mask) + + def test_full_length_chunk(self): + """Tests when chunk size equals sequence length (should be causal).""" + seq_len = 6 + chunk_size = 6 # Same as seq_len + mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) + + # Expected mask is a standard lower triangular causal mask + expected_mask = np.tril(np.ones((seq_len, seq_len), dtype=np.bool_)) + + actual_mask = mask[:, :] + np.testing.assert_array_equal(actual_mask, expected_mask) + # Make sure _generate_chunk_attention_mask also produces the same mask + # pylint: disable=protected-access + actual_mask = _generate_chunk_attention_mask(mask_shape=mask.shape, chunk_size=chunk_size) + np.testing.assert_array_equal(actual_mask, expected_mask) + + def 
test_single_token_chunk(self): + """Tests when chunk size is 1 (only attend to self).""" + seq_len = 5 + chunk_size = 1 + mask = ChunkedCausalMask(shape=(seq_len, seq_len), chunk_size=chunk_size) + + # Expected mask is just the identity matrix + expected_mask = np.eye(seq_len, dtype=np.bool_) + + actual_mask = mask[:, :] + np.testing.assert_array_equal(actual_mask, expected_mask) + # Make sure _generate_chunk_attention_mask also produces the same mask + # pylint: disable=protected-access + actual_mask = _generate_chunk_attention_mask(mask_shape=mask.shape, chunk_size=chunk_size) + np.testing.assert_array_equal(actual_mask, expected_mask) + + def test_non_square_shape(self): + """Tests with different query and key sequence lengths.""" + q_len = 6 + kv_len = 8 + chunk_size = 3 + mask = ChunkedCausalMask(shape=(q_len, kv_len), chunk_size=chunk_size) + + # Manually compute expected mask + expected_mask = np.zeros((q_len, kv_len), dtype=np.bool_) + for r in range(q_len): + for c in range(kv_len): + q_chunk = r // chunk_size + kv_chunk = c // chunk_size + if q_chunk == kv_chunk and r >= c: + expected_mask[r, c] = True + + actual_mask = mask[:, :] + np.testing.assert_array_equal(actual_mask, expected_mask) + # Make sure _generate_chunk_attention_mask also produces the same mask + # pylint: disable=protected-access + actual_mask = _generate_chunk_attention_mask(mask_shape=mask.shape, chunk_size=chunk_size) + np.testing.assert_array_equal(actual_mask, expected_mask) + + def test_value_error_on_zero_chunk_size(self): + """Tests that a ValueError is raised for chunk_size <= 0.""" + with self.assertRaises(ValueError): + ChunkedCausalMask(shape=(4, 4), chunk_size=0) + with self.assertRaises(ValueError): + ChunkedCausalMask(shape=(4, 4), chunk_size=-2) + with self.assertRaises(ValueError): + # pylint: disable=protected-access + _generate_chunk_attention_mask(mask_shape=(4, 4), chunk_size=0) class AttentionTest(parameterized.TestCase): - """Test for the Attention""" - - # Note: 
if you are changing these configs, please make sure to change the configs in - # context_parallelism.py as well, since we are using the same configs for both - # tests to get the same mesh and other config - config_arguments = { - "per_device_batch_size": 1.0, - "run_name": "test", - "enable_checkpointing": False, - "max_prefill_predict_length": 16, - "max_target_length": 512, - "sa_block_q": 128, - "sa_block_kv": 128, - "sa_block_kv_compute": 128, - "sa_block_q_dkv": 128, - "sa_block_kv_dkv": 128, - "sa_block_kv_dkv_compute": 128, - "sa_block_q_dq": 128, - "sa_block_kv_dq": 128, - } - - def setUp(self): - """Initializes the configuration for each test""" - super().setUp() - # Conditionally set ici_fsdp_parallelism to match device count in decoupled mode - extra_args = get_decoupled_parallelism_overrides() - if not is_decoupled(): - jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) - config = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **self.config_arguments, - **extra_args, - ) - self.cfg = config - - self.rng = jax.random.PRNGKey(0) - self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) - - devices_array = maxtext_utils.create_device_mesh(self.cfg) - self.mesh = Mesh(devices_array, self.cfg.mesh_axes) - self.global_batch_size = self.cfg.global_batch_size_to_train_on - self.num_kv_heads = self.cfg.num_kv_heads - self.num_query_heads = self.cfg.num_query_heads - self.max_target_length = self.cfg.max_target_length - self.max_prefill_predict_length = self.cfg.max_prefill_predict_length - self.head_dim = self.cfg.head_dim - self.embed_dim = self.cfg.base_emb_dim - self.dtype = self.cfg.dtype - self.attention_type = self.cfg.attention_type - - dummy_inputs_q = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) - ) - dummy_inputs_kv = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) - ) - self._attention_as_mha_generic = Attention( - config=self.cfg, - 
num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - attention_type=self.attention_type, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) - - def get_data(self, dtype): - """get data""" - lnx = jax.random.normal( - self.rng, - shape=(self.global_batch_size, self.max_target_length, self.embed_dim), - dtype=dtype, - ) - - decoder_segment_ids = jax.random.randint( - self.rng, (self.global_batch_size, self.max_target_length), 0, 4 - ) - decoder_positions = jax.random.randint( - self.rng, - (self.global_batch_size, self.max_target_length), - 0, - self.max_target_length, - ) - - return lnx, decoder_segment_ids, decoder_positions - - def get_structured_data(self, dtype): - """get structured data""" - lnx = jax.random.normal( - self.rng, - shape=(self.global_batch_size, self.max_target_length, self.embed_dim), - dtype=dtype, - ) - - decoder_positions = jnp.stack( - [ - jnp.arange(self.max_target_length, dtype=jnp.int32) - for _ in range(self.global_batch_size) - ] - ) - - decoder_segment_ids = ( - jax.numpy.zeros((self.global_batch_size, self.max_target_length)) - + DECODING_ACTIVE_SEQUENCE_INDICATOR - ) - - return lnx, decoder_segment_ids, decoder_positions - - @pytest.mark.tpu_only - def test_autoregression(self): - prefill_length = self.cfg.max_prefill_predict_length - decode_total_length = self.cfg.max_target_length - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( - self.dtype - ) - - mha_full, _ = self._attention_as_mha_generic( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) - - lnx_prefill = 
lnx[:, 0:prefill_length, :] - decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] - decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - - mha_prefill, _ = self._attention_as_mha_generic( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) - - self.assertTrue( - jax.numpy.allclose( - mha_prefill, - mha_full[:, :prefill_length, :], - rtol=1e-02, - atol=1e-02, - equal_nan=False, - ) - ) - - for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx : idx + 1, :] - decoder_positions_idx = decoder_positions[:, idx : idx + 1] - mha_idx, _ = self._attention_as_mha_generic( - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) - - mha_full_this_idx = mha_full[:, idx : idx + 1, :] - self.assertTrue(mha_full_this_idx.shape == mha_idx.shape) - self.assertTrue( - jax.numpy.allclose( - mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False - ) - ) - - @pytest.mark.tpu_only - def test_model_mode_prefill_dtype_float32(self): - self._test_model_mode_prefill_dtype(jnp.float32) - - @pytest.mark.tpu_only - def test_model_mode_prefill_dtype_bfloat16(self): - """test model mode prefill for dtype bfloat16""" - self._test_model_mode_prefill_dtype(jnp.bfloat16) - - def _test_model_mode_prefill_dtype(self, dtype): - """test model mode prefill for specified dtype""" - lnx, decoder_segment_ids, decoder_positions = self.get_data(dtype) - prefill_length = self.cfg.max_prefill_predict_length - lnx_prefill = lnx[:, 0:prefill_length, :] - decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] - decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - - dummy_inputs_q = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) - ) - dummy_inputs_kv = jnp.ones( - 
(self.global_batch_size, self.max_target_length, self.embed_dim) - ) - attention_as_mha_generic = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=dtype, - dropout_rate=self.cfg.dropout_rate, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) - - mha_prefill, _ = attention_as_mha_generic( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) - - self.assertEqual(dtype, mha_prefill.dtype) - - @pytest.mark.tpu_only - def test_tpu_kernel_attention_mha(self): - self.tpu_kernel_attention_helper(self.num_kv_heads) + """Test for the Attention""" + + # Note: if you are changing these configs, please make sure to change the configs in + # context_parallelism.py as well, since we are using the same configs for both + # tests to get the same mesh and other config + config_arguments = { + "per_device_batch_size": 1.0, + "run_name": "test", + "enable_checkpointing": False, + "max_prefill_predict_length": 16, + "max_target_length": 512, + "sa_block_q": 128, + "sa_block_kv": 128, + "sa_block_kv_compute": 128, + "sa_block_q_dkv": 128, + "sa_block_kv_dkv": 128, + "sa_block_kv_dkv_compute": 128, + "sa_block_q_dq": 128, + "sa_block_kv_dq": 128, + } + + def setUp(self): + """Initializes the configuration for each test""" + super().setUp() + # Conditionally set ici_fsdp_parallelism to match device count in decoupled mode + extra_args = get_decoupled_parallelism_overrides() + if not is_decoupled(): + jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) + config = pyconfig.initialize( + [sys.argv[0], 
get_test_config_path()], + **self.config_arguments, + **extra_args, + ) + self.cfg = config + + self.rng = jax.random.PRNGKey(0) + self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) + + devices_array = maxtext_utils.create_device_mesh(self.cfg) + self.mesh = Mesh(devices_array, self.cfg.mesh_axes) + self.global_batch_size = self.cfg.global_batch_size_to_train_on + self.num_kv_heads = self.cfg.num_kv_heads + self.num_query_heads = self.cfg.num_query_heads + self.max_target_length = self.cfg.max_target_length + self.max_prefill_predict_length = self.cfg.max_prefill_predict_length + self.head_dim = self.cfg.head_dim + self.embed_dim = self.cfg.base_emb_dim + self.dtype = self.cfg.dtype + self.attention_type = self.cfg.attention_type + + dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + self._attention_as_mha_generic = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + attention_type=self.attention_type, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) - @pytest.mark.tpu_only - def test_tpu_kernel_attention_gqa(self): - self.tpu_kernel_attention_helper(self.num_kv_heads // 2) + def get_data(self, dtype): + """get data""" + lnx = jax.random.normal( + self.rng, + shape=(self.global_batch_size, self.max_target_length, self.embed_dim), + dtype=dtype, + ) - @pytest.mark.tpu_only - def test_tpu_kernel_attention_mqa(self): - self.tpu_kernel_attention_helper(1) + decoder_segment_ids = jax.random.randint(self.rng, (self.global_batch_size, 
self.max_target_length), 0, 4) + decoder_positions = jax.random.randint( + self.rng, + (self.global_batch_size, self.max_target_length), + 0, + self.max_target_length, + ) - @pytest.mark.tpu_only - def test_tpu_kernel_attention_mha_share_kv(self): - self.tpu_kernel_attention_helper(self.num_kv_heads, share_kv_projections=True) + return lnx, decoder_segment_ids, decoder_positions - @pytest.mark.tpu_only - def test_tpu_kernel_attention_gqa_share_kv(self): - self.tpu_kernel_attention_helper( - self.num_kv_heads // 2, share_kv_projections=True - ) - - def tpu_kernel_attention_helper(self, num_kv_heads, share_kv_projections=False): - """Test equivalence between dot_product and TPU accelerated""" + def get_structured_data(self, dtype): + """get structured data""" + lnx = jax.random.normal( + self.rng, + shape=(self.global_batch_size, self.max_target_length, self.embed_dim), + dtype=dtype, + ) - lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) + decoder_positions = jnp.stack( + [jnp.arange(self.max_target_length, dtype=jnp.int32) for _ in range(self.global_batch_size)] + ) - dummy_inputs_q = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) - ) - dummy_inputs_kv = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) - ) - attention_as_mha_generic = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - share_kv_projections=share_kv_projections, - rngs=self.nnx_rng, - ) + decoder_segment_ids = ( + jax.numpy.zeros((self.global_batch_size, self.max_target_length)) + DECODING_ACTIVE_SEQUENCE_INDICATOR + ) - generic_state = 
nnx.state(attention_as_mha_generic) + return lnx, decoder_segment_ids, decoder_positions + + @pytest.mark.tpu_only + def test_autoregression(self): + prefill_length = self.cfg.max_prefill_predict_length + decode_total_length = self.cfg.max_target_length + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype) + + mha_full, _ = self._attention_as_mha_generic( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - mha_generic_output, _ = attention_as_mha_generic( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + + mha_prefill, _ = self._attention_as_mha_generic( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) - dummy_inputs_q = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) - ) - dummy_inputs_kv = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) - ) - attention_as_mha_flash = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="flash", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - share_kv_projections=share_kv_projections, - rngs=self.nnx_rng, - ) - nnx.update(attention_as_mha_flash, generic_state) - - mha_generic_flash_output, _ = attention_as_mha_flash( - 
lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, + self.assertTrue( + jax.numpy.allclose( + mha_prefill, + mha_full[:, :prefill_length, :], + rtol=1e-02, + atol=1e-02, + equal_nan=False, ) + ) - self.assertTrue( - jax.numpy.allclose( - mha_generic_output, - mha_generic_flash_output, - rtol=1e-01, - atol=1e-01, - equal_nan=False, - ) - ) + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] + mha_idx, _ = self._attention_as_mha_generic( + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) + + mha_full_this_idx = mha_full[:, idx : idx + 1, :] + self.assertTrue(mha_full_this_idx.shape == mha_idx.shape) + self.assertTrue(jax.numpy.allclose(mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) + + @pytest.mark.tpu_only + def test_model_mode_prefill_dtype_float32(self): + self._test_model_mode_prefill_dtype(jnp.float32) + + @pytest.mark.tpu_only + def test_model_mode_prefill_dtype_bfloat16(self): + """test model mode prefill for dtype bfloat16""" + self._test_model_mode_prefill_dtype(jnp.bfloat16) + + def _test_model_mode_prefill_dtype(self, dtype): + """test model mode prefill for specified dtype""" + lnx, decoder_segment_ids, decoder_positions = self.get_data(dtype) + prefill_length = self.cfg.max_prefill_predict_length + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + + dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + attention_as_mha_generic = Attention( + config=self.cfg, + 
num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=dtype, + dropout_rate=self.cfg.dropout_rate, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) - def test_share_kv_projections(self): - """Test that kv projections are shared.""" - dummy_inputs_q = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) - ) - dummy_inputs_kv = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) - ) - attention_share_kv = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - share_kv_projections=True, - rngs=self.nnx_rng, - ) + mha_prefill, _ = attention_as_mha_generic( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) - self.assertFalse(hasattr(attention_share_kv, "value")) - self.assertTrue(hasattr(attention_share_kv, "key")) + self.assertEqual(dtype, mha_prefill.dtype) + + @pytest.mark.tpu_only + def test_tpu_kernel_attention_mha(self): + self.tpu_kernel_attention_helper(self.num_kv_heads) + + @pytest.mark.tpu_only + def test_tpu_kernel_attention_gqa(self): + self.tpu_kernel_attention_helper(self.num_kv_heads // 2) + + @pytest.mark.tpu_only + def test_tpu_kernel_attention_mqa(self): + self.tpu_kernel_attention_helper(1) + + 
@pytest.mark.tpu_only + def test_tpu_kernel_attention_mha_share_kv(self): + self.tpu_kernel_attention_helper(self.num_kv_heads, share_kv_projections=True) + + @pytest.mark.tpu_only + def test_tpu_kernel_attention_gqa_share_kv(self): + self.tpu_kernel_attention_helper(self.num_kv_heads // 2, share_kv_projections=True) + + def tpu_kernel_attention_helper(self, num_kv_heads, share_kv_projections=False): + """Test equivalence between dot_product and TPU accelerated""" + + lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) + + dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + attention_as_mha_generic = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + share_kv_projections=share_kv_projections, + rngs=self.nnx_rng, + ) - # 1. Check NNX state - state_shared = nnx.state(attention_share_kv) - self.assertNotIn("value", state_shared) - self.assertIn("key", state_shared) + generic_state = nnx.state(attention_as_mha_generic) - # 2. 
Forward Pass Verification - lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) + mha_generic_output, _ = attention_as_mha_generic( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - output_shared, _ = attention_share_kv( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + attention_as_mha_flash = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="flash", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + share_kv_projections=share_kv_projections, + rngs=self.nnx_rng, + ) + nnx.update(attention_as_mha_flash, generic_state) + + mha_generic_flash_output, _ = attention_as_mha_flash( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - self.assertEqual( - output_shared.shape, - (self.global_batch_size, self.max_target_length, self.embed_dim), + self.assertTrue( + jax.numpy.allclose( + mha_generic_output, + mha_generic_flash_output, + rtol=1e-01, + atol=1e-01, + equal_nan=False, ) + ) - # 3. 
Equivalence Check with standard unshared Attention - attention_no_share = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - share_kv_projections=False, - rngs=self.nnx_rng, - ) + def test_share_kv_projections(self): + """Test that kv projections are shared.""" + dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + attention_share_kv = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + share_kv_projections=True, + rngs=self.nnx_rng, + ) - # Force unshared layer to copy weights from shared layer, mapping 'key' to 'value' - attention_no_share.query.kernel.value = attention_share_kv.query.kernel.value - attention_no_share.key.kernel.value = attention_share_kv.key.kernel.value - attention_no_share.value.kernel.value = attention_share_kv.key.kernel.value - attention_no_share.out.kernel.value = attention_share_kv.out.kernel.value - - output_no_share, _ = attention_no_share( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + self.assertFalse(hasattr(attention_share_kv, "value")) + 
self.assertTrue(hasattr(attention_share_kv, "key")) - self.assertTrue( - jax.numpy.allclose( - output_shared, output_no_share, rtol=1e-04, atol=1e-04, equal_nan=False - ) - ) + # 1. Check NNX state + state_shared = nnx.state(attention_share_kv) + self.assertNotIn("value", state_shared) + self.assertIn("key", state_shared) - @parameterized.named_parameters( - { - "testcase_name": "cp_no_load_balance", - "ici_context_parallelism": 4, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_with_load_balance", - "ici_context_parallelism": 4, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_ep_no_load_balance", - "ici_context_parallelism": 2, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_ep_with_load_balance", - "ici_context_parallelism": 2, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_no_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_with_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_no_load_balance_explicit", - "ici_context_parallelism": 4, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "explicit", - }, - { - "testcase_name": 
"cp_with_load_balance_explicit", - "ici_context_parallelism": 4, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_ep_no_load_balance_explicit", - "ici_context_parallelism": 2, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_ep_with_load_balance_explicit", - "ici_context_parallelism": 2, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_no_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_with_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - ) - # TODO (b/454764135.) 
: This tests fails with new tokamax kernel - @pytest.mark.tpu_only - def test_tpu_flash_attention_context_parallel( - self, - ici_context_parallelism, - context_parallel_load_balance, - ici_expert_parallelism, - expert_shard_attention_option, - shard_mode, - ): - """Test equivalence between dot_product and flash attention + context/expert parallelism""" - num_kv_heads = self.num_kv_heads - lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) - # Dot product - mha_generic_output, _ = self._attention_as_mha_generic( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) - generic_state = nnx.state(self._attention_as_mha_generic) - - # Test with Context Parallelism - cfg_cp = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **self.config_arguments, - ici_context_parallelism=ici_context_parallelism, - context_parallel_load_balance=context_parallel_load_balance, - ici_expert_parallelism=ici_expert_parallelism, - expert_shard_attention_option=expert_shard_attention_option, - shard_mode=shard_mode, - ) - devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) - axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto - axis_names = [axis_type for _ in cfg_cp.mesh_axes] - mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes, axis_types=tuple(axis_names)) - attention_as_mha_flash_cp = Attention( - config=cfg_cp, - num_query_heads=cfg_cp.num_query_heads, - num_kv_heads=num_kv_heads, - head_dim=cfg_cp.head_dim, - max_target_length=cfg_cp.max_target_length, - max_prefill_predict_length=cfg_cp.max_prefill_predict_length, - inputs_q_shape=lnx.shape, - inputs_kv_shape=lnx.shape, - mesh=mesh_cp, - attention_kernel="flash", - dtype=self.dtype, - dropout_rate=cfg_cp.dropout_rate, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) - nnx.update(attention_as_mha_flash_cp, generic_state) - - mha_generic_flash_cp_output = ( - 
attention_test_util.forward_with_context_expert_parallelism( - cfg_cp, - mesh_cp, - attention_as_mha_flash_cp, - lnx, - decoder_segment_ids, - decoder_positions, - ) - ) + # 2. Forward Pass Verification + lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) - # This removes all sharding information and makes them standard NumPy arrays. - mha_generic_output = jax.device_get(mha_generic_output) - mha_generic_flash_cp_output = jax.device_get(mha_generic_flash_cp_output) - - self.assertTrue( - jax.numpy.allclose( - mha_generic_output, - mha_generic_flash_cp_output, - rtol=1e-01, - atol=1e-01, - equal_nan=False, - ), - msg="Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" - f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," - f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", - ) + output_shared, _ = attention_share_kv( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - @pytest.mark.tpu_only - def test_dot_product_cache_axis_order(self): - all_axis_orders = tuple(itertools.permutations(range(4))) - for axis_order in random.choices(all_axis_orders, k=4): - self.dot_product_attention_helper( - prefill_cache_axis_order=axis_order, ar_cache_axis_order=axis_order - ) - print(f"passed test for {axis_order=}") - - def dot_product_attention_helper( - self, prefill_cache_axis_order, ar_cache_axis_order - ): - for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: - self._dot_product_attention( - prefill_cache_axis_order, - ar_cache_axis_order, - compute_axis_order=compute_axis_order, - ) - print(f"passed subtest for {compute_axis_order=}") - - def _dot_product_attention( - self, - prefill_cache_axis_order, - ar_cache_axis_order, - compute_axis_order, - ): - """Test equalvant 
between different layout control in dot_product""" - - rtol, atol = 1e-02, 1e-02 - - config = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - per_device_batch_size=1.0, - run_name="test", - enable_checkpointing=False, - max_target_length=128, - max_prefill_predict_length=16, - attention="dot_product", - ) + self.assertEqual( + output_shared.shape, + (self.global_batch_size, self.max_target_length, self.embed_dim), + ) - prefill_length = config.max_prefill_predict_length - decode_total_length = config.max_target_length - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( - config.dtype - ) - lnx_prefill = lnx[:, 0:prefill_length, :] - decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] - decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + # 3. Equivalence Check with standard unshared Attention + attention_no_share = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + share_kv_projections=False, + rngs=self.nnx_rng, + ) - dummy_inputs_q = jnp.ones( - (self.global_batch_size, config.max_target_length, config.base_emb_dim) - ) - dummy_inputs_kv = jnp.ones( - (self.global_batch_size, config.max_target_length, config.base_emb_dim) - ) - attention_w_layout = Attention( - mesh=self.mesh, - config=config, - num_query_heads=config.num_query_heads, - num_kv_heads=config.num_kv_heads, - head_dim=config.head_dim, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - max_target_length=config.max_target_length, - max_prefill_predict_length=config.max_prefill_predict_length, - 
attention_kernel=config.attention, - dtype=config.dtype, - prefill_cache_axis_order=prefill_cache_axis_order, - ar_cache_axis_order=ar_cache_axis_order, - compute_axis_order=compute_axis_order, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) - attention_w_layout_full, _ = attention_w_layout( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + # Force unshared layer to copy weights from shared layer, mapping 'key' to 'value' + attention_no_share.query.kernel.value = attention_share_kv.query.kernel.value + attention_no_share.key.kernel.value = attention_share_kv.key.kernel.value + attention_no_share.value.kernel.value = attention_share_kv.key.kernel.value + attention_no_share.out.kernel.value = attention_share_kv.out.kernel.value + + output_no_share, _ = attention_no_share( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - attention_w_layout_prefill, _ = attention_w_layout( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) - self.assertTrue( - jax.numpy.allclose( - attention_w_layout_full[:, :prefill_length, :], - attention_w_layout_prefill, - equal_nan=False, - ) - ) + self.assertTrue(jax.numpy.allclose(output_shared, output_no_share, rtol=1e-04, atol=1e-04, equal_nan=False)) + + @parameterized.named_parameters( + { + "testcase_name": "cp_no_load_balance", + "ici_context_parallelism": 4, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_with_load_balance", + "ici_context_parallelism": 4, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", 
+ "shard_mode": "auto", + }, + { + "testcase_name": "cp_ep_no_load_balance", + "ici_context_parallelism": 2, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_ep_with_load_balance", + "ici_context_parallelism": 2, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "ep_no_load_balance", + "ici_context_parallelism": 1, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "ep_with_load_balance", + "ici_context_parallelism": 1, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_no_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_with_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_no_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_with_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "ep_no_load_balance_explicit", + "ici_context_parallelism": 1, + 
"context_parallel_load_balance": False, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "ep_with_load_balance_explicit", + "ici_context_parallelism": 1, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + ) + # TODO (b/454764135.) : This tests fails with new tokamax kernel + @pytest.mark.tpu_only + def test_tpu_flash_attention_context_parallel( + self, + ici_context_parallelism, + context_parallel_load_balance, + ici_expert_parallelism, + expert_shard_attention_option, + shard_mode, + ): + """Test equivalence between dot_product and flash attention + context/expert parallelism""" + num_kv_heads = self.num_kv_heads + lnx, decoder_segment_ids, decoder_positions = self.get_data(self.dtype) + # Dot product + mha_generic_output, _ = self._attention_as_mha_generic( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + generic_state = nnx.state(self._attention_as_mha_generic) + + # Test with Context Parallelism + cfg_cp = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **self.config_arguments, + ici_context_parallelism=ici_context_parallelism, + context_parallel_load_balance=context_parallel_load_balance, + ici_expert_parallelism=ici_expert_parallelism, + expert_shard_attention_option=expert_shard_attention_option, + shard_mode=shard_mode, + ) + devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) + axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto + axis_names = [axis_type for _ in cfg_cp.mesh_axes] + mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes, axis_types=tuple(axis_names)) + attention_as_mha_flash_cp = Attention( + config=cfg_cp, + num_query_heads=cfg_cp.num_query_heads, + num_kv_heads=num_kv_heads, + head_dim=cfg_cp.head_dim, + 
max_target_length=cfg_cp.max_target_length, + max_prefill_predict_length=cfg_cp.max_prefill_predict_length, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, + mesh=mesh_cp, + attention_kernel="flash", + dtype=self.dtype, + dropout_rate=cfg_cp.dropout_rate, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + nnx.update(attention_as_mha_flash_cp, generic_state) + + mha_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism( + cfg_cp, + mesh_cp, + attention_as_mha_flash_cp, + lnx, + decoder_segment_ids, + decoder_positions, + ) - for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx : idx + 1, :] - decoder_positions_idx = decoder_positions[:, idx : idx + 1] - - attention_w_layout_idx, _ = attention_w_layout( - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) - - attention_w_layout_full_this_idx = attention_w_layout_full[ - :, idx : idx + 1, : - ] - self.assertTrue( - attention_w_layout_full_this_idx.shape == attention_w_layout_idx.shape - ) - self.assertTrue( - jax.numpy.allclose( - attention_w_layout_full_this_idx, - attention_w_layout_idx, - rtol=rtol, - atol=atol, - equal_nan=False, - ) - ) - - @pytest.mark.tpu_only - def test_dot_product_reshape_q(self): - for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: - self._dot_product_attention_reshape_q( - compute_axis_order=compute_axis_order, - ) - print(f"test passed for compute_axis_order: {compute_axis_order}") - - def _dot_product_attention_reshape_q(self, compute_axis_order): - """Test equalvant between q and reshape q in dot_product""" - - rtol, atol = 1e-02, 1e-02 - - config = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - per_device_batch_size=1.0, - run_name="test", - enable_checkpointing=False, - max_target_length=128, - max_prefill_predict_length=16, - attention="dot_product", - ) + # This removes all sharding information and makes them 
standard NumPy arrays. + mha_generic_output = jax.device_get(mha_generic_output) + mha_generic_flash_cp_output = jax.device_get(mha_generic_flash_cp_output) + + self.assertTrue( + jax.numpy.allclose( + mha_generic_output, + mha_generic_flash_cp_output, + rtol=1e-01, + atol=1e-01, + equal_nan=False, + ), + msg="Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" + f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," + f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", + ) - prefill_length = config.max_prefill_predict_length - decode_total_length = config.max_target_length - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( - config.dtype - ) + @pytest.mark.tpu_only + def test_dot_product_cache_axis_order(self): + all_axis_orders = tuple(itertools.permutations(range(4))) + for axis_order in random.choices(all_axis_orders, k=4): + self.dot_product_attention_helper(prefill_cache_axis_order=axis_order, ar_cache_axis_order=axis_order) + print(f"passed test for {axis_order=}") + + def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order): + for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: + self._dot_product_attention( + prefill_cache_axis_order, + ar_cache_axis_order, + compute_axis_order=compute_axis_order, + ) + print(f"passed subtest for {compute_axis_order=}") + + def _dot_product_attention( + self, + prefill_cache_axis_order, + ar_cache_axis_order, + compute_axis_order, + ): + """Test equalvant between different layout control in dot_product""" + + rtol, atol = 1e-02, 1e-02 + + config = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + per_device_batch_size=1.0, + run_name="test", + enable_checkpointing=False, + max_target_length=128, + max_prefill_predict_length=16, + attention="dot_product", + ) - lnx_prefill = lnx[:, 
0:prefill_length, :] - decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] - decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + prefill_length = config.max_prefill_predict_length + decode_total_length = config.max_target_length + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype) + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + + dummy_inputs_q = jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)) + dummy_inputs_kv = jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)) + attention_w_layout = Attention( + mesh=self.mesh, + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + dtype=config.dtype, + prefill_cache_axis_order=prefill_cache_axis_order, + ar_cache_axis_order=ar_cache_axis_order, + compute_axis_order=compute_axis_order, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + attention_w_layout_full, _ = attention_w_layout( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - dummy_inputs_q = jnp.ones( - (self.global_batch_size, config.max_target_length, config.base_emb_dim) - ) - dummy_inputs_kv = jnp.ones( - (self.global_batch_size, config.max_target_length, config.base_emb_dim) + attention_w_layout_prefill, _ = attention_w_layout( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + 
model_mode=MODEL_MODE_PREFILL, + ) + self.assertTrue( + jax.numpy.allclose( + attention_w_layout_full[:, :prefill_length, :], + attention_w_layout_prefill, + equal_nan=False, ) + ) - attention_wo_reshape_q = Attention( - mesh=self.mesh, - config=config, - num_query_heads=config.num_query_heads, - num_kv_heads=config.num_kv_heads, - head_dim=config.head_dim, - max_target_length=config.max_target_length, - max_prefill_predict_length=config.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - attention_kernel=config.attention, - dtype=config.dtype, - compute_axis_order=compute_axis_order, - reshape_q=False, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] + + attention_w_layout_idx, _ = attention_w_layout( + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) + + attention_w_layout_full_this_idx = attention_w_layout_full[:, idx : idx + 1, :] + self.assertTrue(attention_w_layout_full_this_idx.shape == attention_w_layout_idx.shape) + self.assertTrue( + jax.numpy.allclose( + attention_w_layout_full_this_idx, + attention_w_layout_idx, + rtol=rtol, + atol=atol, + equal_nan=False, + ) + ) + + @pytest.mark.tpu_only + def test_dot_product_reshape_q(self): + for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: + self._dot_product_attention_reshape_q( + compute_axis_order=compute_axis_order, + ) + print(f"test passed for compute_axis_order: {compute_axis_order}") + + def _dot_product_attention_reshape_q(self, compute_axis_order): + """Test equalvant between q and reshape q in dot_product""" + + rtol, atol = 1e-02, 1e-02 + + config = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + per_device_batch_size=1.0, + run_name="test", + enable_checkpointing=False, + 
max_target_length=128, + max_prefill_predict_length=16, + attention="dot_product", + ) - attention_w_reshape_q = Attention( - mesh=self.mesh, - config=config, - num_query_heads=config.num_query_heads, - num_kv_heads=config.num_kv_heads, - head_dim=config.head_dim, - max_target_length=config.max_target_length, - max_prefill_predict_length=config.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - attention_kernel=config.attention, - dtype=config.dtype, - compute_axis_order=compute_axis_order, - reshape_q=True, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) + prefill_length = config.max_prefill_predict_length + decode_total_length = config.max_target_length + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype) + + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + + dummy_inputs_q = jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)) + dummy_inputs_kv = jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)) + + attention_wo_reshape_q = Attention( + mesh=self.mesh, + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + attention_kernel=config.attention, + dtype=config.dtype, + compute_axis_order=compute_axis_order, + reshape_q=False, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) - attention_wo_reshape_q_state = nnx.state(attention_wo_reshape_q) - nnx.update(attention_w_reshape_q, attention_wo_reshape_q_state) + attention_w_reshape_q = Attention( + mesh=self.mesh, + config=config, + 
num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + attention_kernel=config.attention, + dtype=config.dtype, + compute_axis_order=compute_axis_order, + reshape_q=True, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) - attention_wo_reshape_q_full, _ = attention_wo_reshape_q( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + attention_wo_reshape_q_state = nnx.state(attention_wo_reshape_q) + nnx.update(attention_w_reshape_q, attention_wo_reshape_q_state) - attention_w_reshape_q_full, _ = attention_w_reshape_q( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) - - attention_wo_reshape_q_prefill, _ = attention_wo_reshape_q( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) - self.assertTrue( - jax.numpy.allclose( - attention_wo_reshape_q_full[:, :prefill_length, :], - attention_wo_reshape_q_prefill, - equal_nan=False, - ) - ) + attention_wo_reshape_q_full, _ = attention_wo_reshape_q( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - attention_w_reshape_q_prefill, _ = attention_w_reshape_q( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) - self.assertTrue( - jax.numpy.allclose( - attention_w_reshape_q_full[:, :prefill_length, :], - 
attention_w_reshape_q_prefill, - equal_nan=False, - ) - ) + attention_w_reshape_q_full, _ = attention_w_reshape_q( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - self.assertTrue( - jax.numpy.allclose( - attention_wo_reshape_q_prefill, - attention_w_reshape_q_prefill, - equal_nan=False, - ) - ) - self.assertTrue( - jax.numpy.allclose( - attention_wo_reshape_q_full[:, :prefill_length, :], - attention_w_reshape_q_full[:, :prefill_length, :], - equal_nan=False, - ) + attention_wo_reshape_q_prefill, _ = attention_wo_reshape_q( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) + self.assertTrue( + jax.numpy.allclose( + attention_wo_reshape_q_full[:, :prefill_length, :], + attention_wo_reshape_q_prefill, + equal_nan=False, ) + ) - for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx : idx + 1, :] - decoder_positions_idx = decoder_positions[:, idx : idx + 1] - - attention_wo_reshape_q_idx, _ = attention_wo_reshape_q( - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) - - attention_wo_reshape_q_full_this_idx = attention_wo_reshape_q_full[ - :, idx : idx + 1, : - ] - self.assertTrue( - attention_wo_reshape_q_full_this_idx.shape - == attention_wo_reshape_q_idx.shape - ) - self.assertTrue( - jax.numpy.allclose( - attention_wo_reshape_q_full_this_idx, - attention_wo_reshape_q_idx, - rtol=rtol, - atol=atol, - equal_nan=False, - ) - ) - - attention_w_reshape_q_idx, _ = attention_w_reshape_q( - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) - - attention_w_reshape_q_full_this_idx = attention_w_reshape_q_full[ - :, idx : idx + 1, : - ] - self.assertTrue( - 
attention_w_reshape_q_full_this_idx.shape - == attention_w_reshape_q_idx.shape - ) - self.assertTrue( - jax.numpy.allclose( - attention_w_reshape_q_full_this_idx, - attention_w_reshape_q_idx, - rtol=rtol, - atol=atol, - equal_nan=False, - ) - ) - - self.assertTrue( - jax.numpy.allclose( - attention_w_reshape_q_idx, - attention_wo_reshape_q_idx, - rtol=rtol, - atol=atol, - equal_nan=False, - ) - ) - - def test_sliding_window_attention(self): - """Test sliding window attention""" - - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( - self.dtype + attention_w_reshape_q_prefill, _ = attention_w_reshape_q( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) + self.assertTrue( + jax.numpy.allclose( + attention_w_reshape_q_full[:, :prefill_length, :], + attention_w_reshape_q_prefill, + equal_nan=False, ) + ) - dummy_inputs_q = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) - ) - dummy_inputs_kv = jnp.ones( - (self.global_batch_size, self.max_target_length, self.embed_dim) + self.assertTrue( + jax.numpy.allclose( + attention_wo_reshape_q_prefill, + attention_w_reshape_q_prefill, + equal_nan=False, ) - - # Global Attention - global_attn = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - mesh=self.mesh, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - attention_type=AttentionType.GLOBAL, - model_mode=MODEL_MODE_TRAIN, - rngs=self.nnx_rng, + ) + self.assertTrue( + jax.numpy.allclose( + attention_wo_reshape_q_full[:, :prefill_length, :], + attention_w_reshape_q_full[:, :prefill_length, 
:], + equal_nan=False, ) + ) - # Attention with sliding window of size 8 - sliding_attn = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - mesh=self.mesh, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - attention_type=AttentionType.LOCAL_SLIDING, - sliding_window_size=8, - model_mode=MODEL_MODE_TRAIN, - rngs=self.nnx_rng, - ) + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] + + attention_wo_reshape_q_idx, _ = attention_wo_reshape_q( + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) + + attention_wo_reshape_q_full_this_idx = attention_wo_reshape_q_full[:, idx : idx + 1, :] + self.assertTrue(attention_wo_reshape_q_full_this_idx.shape == attention_wo_reshape_q_idx.shape) + self.assertTrue( + jax.numpy.allclose( + attention_wo_reshape_q_full_this_idx, + attention_wo_reshape_q_idx, + rtol=rtol, + atol=atol, + equal_nan=False, + ) + ) + + attention_w_reshape_q_idx, _ = attention_w_reshape_q( + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) + + attention_w_reshape_q_full_this_idx = attention_w_reshape_q_full[:, idx : idx + 1, :] + self.assertTrue(attention_w_reshape_q_full_this_idx.shape == attention_w_reshape_q_idx.shape) + self.assertTrue( + jax.numpy.allclose( + attention_w_reshape_q_full_this_idx, + attention_w_reshape_q_idx, + rtol=rtol, + atol=atol, + equal_nan=False, + ) + ) + + self.assertTrue( + jax.numpy.allclose( + attention_w_reshape_q_idx, + attention_wo_reshape_q_idx, + rtol=rtol, + 
atol=atol, + equal_nan=False, + ) + ) + + def test_sliding_window_attention(self): + """Test sliding window attention""" + + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype) + + dummy_inputs_q = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + dummy_inputs_kv = jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)) + + # Global Attention + global_attn = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + mesh=self.mesh, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + attention_type=AttentionType.GLOBAL, + model_mode=MODEL_MODE_TRAIN, + rngs=self.nnx_rng, + ) - # To share parameters, we copy the state from sliding_attn to global_attn. 
- sliding_attn_state = nnx.state(sliding_attn) - nnx.update(global_attn, sliding_attn_state) - - global_attn_output, _ = global_attn( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + # Attention with sliding window of size 8 + sliding_attn = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + mesh=self.mesh, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + attention_type=AttentionType.LOCAL_SLIDING, + sliding_window_size=8, + model_mode=MODEL_MODE_TRAIN, + rngs=self.nnx_rng, + ) - sliding_window_output, _ = sliding_attn( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + # To share parameters, we copy the state from sliding_attn to global_attn. 
+ sliding_attn_state = nnx.state(sliding_attn) + nnx.update(global_attn, sliding_attn_state) + + global_attn_output, _ = global_attn( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - # Test if sliding window attention is different from global attention - self.assertFalse( - jax.numpy.allclose( - sliding_window_output.astype(jnp.bfloat16), - global_attn_output.astype(jnp.bfloat16), - rtol=1e-04, - atol=1e-04, - ) - ) + sliding_window_output, _ = sliding_attn( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - # Attention with sliding window of size max_target_length - # This should be equivalent to global attention. - sliding_attn_full_window = Attention( - config=self.cfg, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - mesh=self.mesh, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - attention_kernel="dot_product", - dtype=self.dtype, - dropout_rate=self.cfg.dropout_rate, - attention_type=AttentionType.LOCAL_SLIDING, - sliding_window_size=self.max_target_length, - model_mode=MODEL_MODE_TRAIN, - rngs=self.nnx_rng, + # Test if sliding window attention is different from global attention + self.assertFalse( + jax.numpy.allclose( + sliding_window_output.astype(jnp.bfloat16), + global_attn_output.astype(jnp.bfloat16), + rtol=1e-04, + atol=1e-04, ) + ) - nnx.update(sliding_attn_full_window, sliding_attn_state) + # Attention with sliding window of size max_target_length + # This should be equivalent to global attention. 
+ sliding_attn_full_window = Attention( + config=self.cfg, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + mesh=self.mesh, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + attention_kernel="dot_product", + dtype=self.dtype, + dropout_rate=self.cfg.dropout_rate, + attention_type=AttentionType.LOCAL_SLIDING, + sliding_window_size=self.max_target_length, + model_mode=MODEL_MODE_TRAIN, + rngs=self.nnx_rng, + ) - sliding_window_output_full, _ = sliding_attn_full_window( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) + nnx.update(sliding_attn_full_window, sliding_attn_state) - print(f"{sliding_window_output_full.astype(jnp.bfloat16)=}") - print(f"{global_attn_output.astype(jnp.bfloat16)=}") - - # Test if sliding window attention with max_target_length size is the same as global attention - self.assertTrue( - jax.numpy.allclose( - sliding_window_output_full.astype(jnp.bfloat16), - global_attn_output.astype(jnp.bfloat16), - rtol=1e-04, - atol=1e-04, - ) - ) + sliding_window_output_full, _ = sliding_attn_full_window( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - @pytest.mark.skip( - reason="Requires `vllm-tpu` package which is not yet a MaxText dependency." 
- ) - @pytest.mark.tpu_only - @mock.patch( - "tpu_inference.layers.jax.attention_interface.sharded_ragged_paged_attention", - create=True, - ) - def test_forward_serve_vllm(self, mock_sharded_ragged_paged_attention): - """Tests the forward_serve_vllm method with mocked RPA attention.""" - # Setup config for vLLM RPA - vllm_config_arguments = self.config_arguments.copy() - vllm_config_arguments["attention"] = "vllm_rpa" - vllm_config_arguments["chunk_attn_window_size"] = 128 - config = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **vllm_config_arguments, - ) + print(f"{sliding_window_output_full.astype(jnp.bfloat16)=}") + print(f"{global_attn_output.astype(jnp.bfloat16)=}") - seq_len = self.max_target_length - - # Create Attention instance - dummy_inputs_q = jnp.ones((self.global_batch_size, seq_len, self.embed_dim)) - dummy_inputs_kv = jnp.ones((self.global_batch_size, seq_len, self.embed_dim)) - attention_vllm = Attention( - config=config, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, - max_target_length=self.max_target_length, - max_prefill_predict_length=self.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=self.dtype, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - rngs=self.nnx_rng, + # Test if sliding window attention with max_target_length size is the same as global attention + self.assertTrue( + jax.numpy.allclose( + sliding_window_output_full.astype(jnp.bfloat16), + global_attn_output.astype(jnp.bfloat16), + rtol=1e-04, + atol=1e-04, ) + ) - # Prepare inputs - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( - self.dtype - ) - mock_kv_cache = [jnp.ones((1,))] + @pytest.mark.skip(reason="Requires `vllm-tpu` package which is not yet a MaxText dependency.") + @pytest.mark.tpu_only + @mock.patch( + 
"tpu_inference.layers.jax.attention_interface.sharded_ragged_paged_attention", + create=True, + ) + def test_forward_serve_vllm(self, mock_sharded_ragged_paged_attention): + """Tests the forward_serve_vllm method with mocked RPA attention.""" + # Setup config for vLLM RPA + vllm_config_arguments = self.config_arguments.copy() + vllm_config_arguments["attention"] = "vllm_rpa" + vllm_config_arguments["chunk_attn_window_size"] = 128 + config = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **vllm_config_arguments, + ) - mock_attention_metadata = mock.Mock() - mock_attention_metadata.seq_lens = jnp.array([1] * self.global_batch_size) - mock_attention_metadata.block_tables = jnp.array([[0]] * self.global_batch_size) - mock_attention_metadata.query_start_loc = jnp.array( - list(range(self.global_batch_size)) - ) - mock_attention_metadata.request_distribution = jnp.array( - [self.global_batch_size] - ) + seq_len = self.max_target_length + + # Create Attention instance + dummy_inputs_q = jnp.ones((self.global_batch_size, seq_len, self.embed_dim)) + dummy_inputs_kv = jnp.ones((self.global_batch_size, seq_len, self.embed_dim)) + attention_vllm = Attention( + config=config, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=self.dtype, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + rngs=self.nnx_rng, + ) - # Mock the return value of sharded_ragged_paged_attention - total_tokens = self.global_batch_size * seq_len - mock_output_shape = (total_tokens, self.num_query_heads, self.head_dim) - mock_output = jnp.ones(mock_output_shape, dtype=self.dtype) - mock_updated_kv_cache = [jnp.zeros((1,))] - - mock_callable = mock.Mock(return_value=(mock_output, mock_updated_kv_cache)) - 
mock_sharded_ragged_paged_attention.return_value = mock_callable - - # Call the attention layer - output, updated_kv_cache = attention_vllm( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - kv_cache=mock_kv_cache, - attention_metadata=mock_attention_metadata, - ) + # Prepare inputs + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype) + mock_kv_cache = [jnp.ones((1,))] + + mock_attention_metadata = mock.Mock() + mock_attention_metadata.seq_lens = jnp.array([1] * self.global_batch_size) + mock_attention_metadata.block_tables = jnp.array([[0]] * self.global_batch_size) + mock_attention_metadata.query_start_loc = jnp.array(list(range(self.global_batch_size))) + mock_attention_metadata.request_distribution = jnp.array([self.global_batch_size]) + + # Mock the return value of sharded_ragged_paged_attention + total_tokens = self.global_batch_size * seq_len + mock_output_shape = (total_tokens, self.num_query_heads, self.head_dim) + mock_output = jnp.ones(mock_output_shape, dtype=self.dtype) + mock_updated_kv_cache = [jnp.zeros((1,))] + + mock_callable = mock.Mock(return_value=(mock_output, mock_updated_kv_cache)) + mock_sharded_ragged_paged_attention.return_value = mock_callable + + # Call the attention layer + output, updated_kv_cache = attention_vllm( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + kv_cache=mock_kv_cache, + attention_metadata=mock_attention_metadata, + ) - # Assertions - mock_sharded_ragged_paged_attention.assert_called_once() - mock_callable.assert_called_once() - self.assertEqual(updated_kv_cache, mock_updated_kv_cache) + # Assertions + mock_sharded_ragged_paged_attention.assert_called_once() + mock_callable.assert_called_once() + self.assertEqual(updated_kv_cache, mock_updated_kv_cache) - # The output of 
forward_serve_vllm is reshaped back to (batch, seq, ...) - reshaped_mock_output = mock_output.reshape( - self.global_batch_size, seq_len, self.num_query_heads, self.head_dim - ) - expected_output = attention_vllm.out_projection(reshaped_mock_output) - self.assertTrue(jnp.allclose(output, expected_output)) - self.assertEqual( - output.shape, (self.global_batch_size, seq_len, self.embed_dim) - ) + # The output of forward_serve_vllm is reshaped back to (batch, seq, ...) + reshaped_mock_output = mock_output.reshape(self.global_batch_size, seq_len, self.num_query_heads, self.head_dim) + expected_output = attention_vllm.out_projection(reshaped_mock_output) + self.assertTrue(jnp.allclose(output, expected_output)) + self.assertEqual(output.shape, (self.global_batch_size, seq_len, self.embed_dim)) class MLATest(attention_test_util.MLATestBase): - """Test for the Multi-Headed Latent Attention""" - - @parameterized.named_parameters( - {"testcase_name": "RoPE_Yarn_Autoregression", "rope_type": "yarn"}, - {"testcase_name": "Default_Autoregression", "rope_type": "default"}, - ) - @pytest.mark.tpu_only - def test_autoregression(self, rope_type): - cfg, mla = self.init_mla(self.config_arguments, rope_type) - prefill_length = cfg.max_prefill_predict_length - decode_total_length = cfg.max_target_length - lnx, decoder_segment_ids, decoder_positions = self.get_structured_data( - cfg, cfg.dtype - ) - - mla_full, _ = mla( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) - - lnx_prefill = lnx[:, 0:prefill_length, :] - decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] - decoder_positions_prefill = decoder_positions[:, 0:prefill_length] - - mla_prefill, _ = mla( - lnx_prefill, - lnx_prefill, - decoder_segment_ids=decoder_segment_ids_prefill, - inputs_positions=decoder_positions_prefill, - deterministic=True, - model_mode=MODEL_MODE_PREFILL, - ) + """Test for the 
Multi-Headed Latent Attention""" + + @parameterized.named_parameters( + {"testcase_name": "RoPE_Yarn_Autoregression", "rope_type": "yarn"}, + {"testcase_name": "Default_Autoregression", "rope_type": "default"}, + ) + @pytest.mark.tpu_only + def test_autoregression(self, rope_type): + cfg, mla = self.init_mla(self.config_arguments, rope_type) + prefill_length = cfg.max_prefill_predict_length + decode_total_length = cfg.max_target_length + lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(cfg, cfg.dtype) + + mla_full, _ = mla( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) - self.assertTrue( - jax.numpy.allclose( - mla_prefill, - mla_full[:, :prefill_length, :], - rtol=1e-02, - atol=1e-02, - equal_nan=False, - ) - ) + lnx_prefill = lnx[:, 0:prefill_length, :] + decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length] + decoder_positions_prefill = decoder_positions[:, 0:prefill_length] + + mla_prefill, _ = mla( + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) - for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx : idx + 1, :] - decoder_positions_idx = decoder_positions[:, idx : idx + 1] - mla_idx, _ = mla( - lnx_idx, - lnx_idx, - inputs_positions=decoder_positions_idx, - deterministic=True, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) - - mla_full_this_idx = mla_full[:, idx : idx + 1, :] - self.assertEqual(mla_full_this_idx.shape, mla_idx.shape) - # TODO (b/394626702) uncomment last check when decode and kv_cache are implemented for MLA - # self.assertTrue(jax.numpy.allclose(mla_full_this_idx, mla_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) - - def test_projection_initialization(self): - """Tests that MLA and Attention layers initialize the correct projection 
weights.""" - # 1. Initialize a standard Attention layer for comparison - # Create a copy of the arguments and override the attention_type for the base model - attention_config_args = self.config_arguments.copy() - attention_config_args["attention_type"] = AttentionType.GLOBAL.value - extra_args = get_decoupled_parallelism_overrides() - attention_cfg = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **attention_config_args, - **extra_args, - ) - dummy_inputs_q = jnp.ones( - ( - attention_cfg.global_batch_size_to_train_on, - attention_cfg.max_target_length, - attention_cfg.base_emb_dim, - ) - ) - dummy_inputs_kv = jnp.ones( - ( - attention_cfg.global_batch_size_to_train_on, - attention_cfg.max_target_length, - attention_cfg.base_emb_dim, - ) + self.assertTrue( + jax.numpy.allclose( + mla_prefill, + mla_full[:, :prefill_length, :], + rtol=1e-02, + atol=1e-02, + equal_nan=False, ) + ) - base_attention = Attention( - config=attention_cfg, - num_query_heads=attention_cfg.num_query_heads, - num_kv_heads=attention_cfg.num_kv_heads, - head_dim=attention_cfg.head_dim, - max_target_length=attention_cfg.max_target_length, - max_prefill_predict_length=attention_cfg.max_prefill_predict_length, - inputs_q_shape=dummy_inputs_q.shape, - inputs_kv_shape=dummy_inputs_kv.shape, - mesh=self.mesh, - attention_kernel="dot_product", - dtype=attention_cfg.dtype, - rngs=self.nnx_rng, - ) + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] + mla_idx, _ = mla( + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) + + mla_full_this_idx = mla_full[:, idx : idx + 1, :] + self.assertEqual(mla_full_this_idx.shape, mla_idx.shape) + # TODO (b/394626702) uncomment last check when decode and kv_cache are implemented for MLA + # self.assertTrue(jax.numpy.allclose(mla_full_this_idx, mla_idx, rtol=1e-02, 
atol=1e-02, equal_nan=False)) + + def test_projection_initialization(self): + """Tests that MLA and Attention layers initialize the correct projection weights.""" + # 1. Initialize a standard Attention layer for comparison + # Create a copy of the arguments and override the attention_type for the base model + attention_config_args = self.config_arguments.copy() + attention_config_args["attention_type"] = AttentionType.GLOBAL.value + extra_args = get_decoupled_parallelism_overrides() + attention_cfg = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **attention_config_args, + **extra_args, + ) + dummy_inputs_q = jnp.ones(( + attention_cfg.global_batch_size_to_train_on, + attention_cfg.max_target_length, + attention_cfg.base_emb_dim, + )) + dummy_inputs_kv = jnp.ones(( + attention_cfg.global_batch_size_to_train_on, + attention_cfg.max_target_length, + attention_cfg.base_emb_dim, + )) + + base_attention = Attention( + config=attention_cfg, + num_query_heads=attention_cfg.num_query_heads, + num_kv_heads=attention_cfg.num_kv_heads, + head_dim=attention_cfg.head_dim, + max_target_length=attention_cfg.max_target_length, + max_prefill_predict_length=attention_cfg.max_prefill_predict_length, + inputs_q_shape=dummy_inputs_q.shape, + inputs_kv_shape=dummy_inputs_kv.shape, + mesh=self.mesh, + attention_kernel="dot_product", + dtype=attention_cfg.dtype, + rngs=self.nnx_rng, + ) - # 2. Assert that the base Attention layer HAS all its standard projections - self.assertTrue( - hasattr(base_attention, "query"), - "Base Attention should have 'query' projection.", - ) - self.assertTrue( - hasattr(base_attention, "key"), - "Base Attention should have 'key' projection.", - ) - self.assertTrue( - hasattr(base_attention, "value"), - "Base Attention should have 'value' projection.", - ) - self.assertTrue( - hasattr(base_attention, "out"), - "Base Attention should have 'out' projection.", - ) + # 2. 
Assert that the base Attention layer HAS all its standard projections + self.assertTrue( + hasattr(base_attention, "query"), + "Base Attention should have 'query' projection.", + ) + self.assertTrue( + hasattr(base_attention, "key"), + "Base Attention should have 'key' projection.", + ) + self.assertTrue( + hasattr(base_attention, "value"), + "Base Attention should have 'value' projection.", + ) + self.assertTrue( + hasattr(base_attention, "out"), + "Base Attention should have 'out' projection.", + ) - # 3. Initialize the MLA layer - mla_config_args = self.config_arguments.copy() - mla_extra_args = get_decoupled_parallelism_overrides() - mla_config_args.update(mla_extra_args) - _, mla_layer = self.init_mla(mla_config_args, rope_type="default") + # 3. Initialize the MLA layer + mla_config_args = self.config_arguments.copy() + mla_extra_args = get_decoupled_parallelism_overrides() + mla_config_args.update(mla_extra_args) + _, mla_layer = self.init_mla(mla_config_args, rope_type="default") + + # 4. Assert that the MLA layer DOES NOT HAVE the base projections + self.assertFalse(hasattr(mla_layer, "query"), "MLA should not have 'query' projection.") + self.assertFalse(hasattr(mla_layer, "key"), "MLA should not have 'key' projection.") + self.assertFalse(hasattr(mla_layer, "value"), "MLA should not have 'value' projection.") + + # 5. 
Assert that the MLA layer HAS all of its own specific projections AND the common 'out' projection + self.assertTrue(hasattr(mla_layer, "wq_a"), "MLA should have 'wq_a' projection.") + self.assertTrue(hasattr(mla_layer, "wq_b"), "MLA should have 'wq_b' projection.") + self.assertTrue(hasattr(mla_layer, "wkv_a"), "MLA should have 'wkv_a' projection.") + self.assertTrue(hasattr(mla_layer, "wkv_b"), "MLA should have 'wkv_b' projection.") + self.assertTrue(hasattr(mla_layer, "q_norm"), "MLA should have 'q_norm' projection.") + self.assertTrue(hasattr(mla_layer, "kv_norm"), "MLA should have 'kv_norm' projection.") + self.assertTrue(hasattr(mla_layer, "out"), "MLA should have 'out' projection.") + + def test_fused_mla_lora_proj_output_equivalence(self): + """Tests that fused_mla_lora_proj=True produces identical outputs to fused_mla_lora_proj=False.""" + extra_args = get_decoupled_parallelism_overrides() + + # Initialize the unfused model. + unfused_args = { + **self.config_arguments, + "fused_mla_lora_proj": False, + **extra_args, + } + cfg_unfused = pyconfig.initialize([sys.argv[0], get_test_config_path()], **unfused_args) + devices_array = maxtext_utils.create_device_mesh(cfg_unfused) + mesh = Mesh(devices_array, cfg_unfused.mesh_axes) + dummy_q = jnp.ones(( + cfg_unfused.global_batch_size_to_train_on, + cfg_unfused.max_target_length, + cfg_unfused.base_emb_dim, + )) + mla_unfused = MLA( + config=cfg_unfused, + num_query_heads=cfg_unfused.num_query_heads, + num_kv_heads=cfg_unfused.num_kv_heads, + head_dim=cfg_unfused.head_dim, + inputs_q_shape=dummy_q.shape, + inputs_kv_shape=dummy_q.shape, + max_target_length=cfg_unfused.max_target_length, + max_prefill_predict_length=cfg_unfused.max_prefill_predict_length, + mesh=mesh, + attention_kernel="dot_product", + dtype=cfg_unfused.dtype, + dropout_rate=cfg_unfused.dropout_rate, + attention_type=cfg_unfused.attention_type, + q_lora_rank=cfg_unfused.q_lora_rank, + kv_lora_rank=cfg_unfused.kv_lora_rank, + 
qk_nope_head_dim=cfg_unfused.qk_nope_head_dim, + qk_rope_head_dim=cfg_unfused.qk_rope_head_dim, + v_head_dim=cfg_unfused.v_head_dim, + model_mode=MODEL_MODE_TRAIN, + rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)), + ) - # 4. Assert that the MLA layer DOES NOT HAVE the base projections - self.assertFalse( - hasattr(mla_layer, "query"), "MLA should not have 'query' projection." - ) - self.assertFalse( - hasattr(mla_layer, "key"), "MLA should not have 'key' projection." - ) - self.assertFalse( - hasattr(mla_layer, "value"), "MLA should not have 'value' projection." - ) + # Initialize the fused model. + fused_args = { + **self.config_arguments, + "fused_mla_lora_proj": True, + **extra_args, + } + cfg_fused = pyconfig.initialize([sys.argv[0], get_test_config_path()], **fused_args) + mla_fused = MLA( + config=cfg_fused, + num_query_heads=cfg_fused.num_query_heads, + num_kv_heads=cfg_fused.num_kv_heads, + head_dim=cfg_fused.head_dim, + inputs_q_shape=dummy_q.shape, + inputs_kv_shape=dummy_q.shape, + max_target_length=cfg_fused.max_target_length, + max_prefill_predict_length=cfg_fused.max_prefill_predict_length, + mesh=mesh, + attention_kernel="dot_product", + dtype=cfg_fused.dtype, + dropout_rate=cfg_fused.dropout_rate, + attention_type=cfg_fused.attention_type, + q_lora_rank=cfg_fused.q_lora_rank, + kv_lora_rank=cfg_fused.kv_lora_rank, + qk_nope_head_dim=cfg_fused.qk_nope_head_dim, + qk_rope_head_dim=cfg_fused.qk_rope_head_dim, + v_head_dim=cfg_fused.v_head_dim, + model_mode=MODEL_MODE_TRAIN, + rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)), + ) - # 5. Assert that the MLA layer HAS all of its own specific projections AND the common 'out' projection - self.assertTrue( - hasattr(mla_layer, "wq_a"), "MLA should have 'wq_a' projection." - ) - self.assertTrue( - hasattr(mla_layer, "wq_b"), "MLA should have 'wq_b' projection." - ) - self.assertTrue( - hasattr(mla_layer, "wkv_a"), "MLA should have 'wkv_a' projection." 
- ) - self.assertTrue( - hasattr(mla_layer, "wkv_b"), "MLA should have 'wkv_b' projection." - ) - self.assertTrue( - hasattr(mla_layer, "q_norm"), "MLA should have 'q_norm' projection." - ) - self.assertTrue( - hasattr(mla_layer, "kv_norm"), "MLA should have 'kv_norm' projection." - ) - self.assertTrue(hasattr(mla_layer, "out"), "MLA should have 'out' projection.") - - def test_fused_mla_lora_proj_output_equivalence(self): - """Tests that fused_mla_lora_proj=True produces identical outputs to fused_mla_lora_proj=False.""" - extra_args = get_decoupled_parallelism_overrides() - - # Initialize the unfused model. - unfused_args = { - **self.config_arguments, - "fused_mla_lora_proj": False, - **extra_args, - } - cfg_unfused = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], **unfused_args - ) - devices_array = maxtext_utils.create_device_mesh(cfg_unfused) - mesh = Mesh(devices_array, cfg_unfused.mesh_axes) - dummy_q = jnp.ones( - ( - cfg_unfused.global_batch_size_to_train_on, - cfg_unfused.max_target_length, - cfg_unfused.base_emb_dim, - ) - ) - mla_unfused = MLA( - config=cfg_unfused, - num_query_heads=cfg_unfused.num_query_heads, - num_kv_heads=cfg_unfused.num_kv_heads, - head_dim=cfg_unfused.head_dim, - inputs_q_shape=dummy_q.shape, - inputs_kv_shape=dummy_q.shape, - max_target_length=cfg_unfused.max_target_length, - max_prefill_predict_length=cfg_unfused.max_prefill_predict_length, - mesh=mesh, - attention_kernel="dot_product", - dtype=cfg_unfused.dtype, - dropout_rate=cfg_unfused.dropout_rate, - attention_type=cfg_unfused.attention_type, - q_lora_rank=cfg_unfused.q_lora_rank, - kv_lora_rank=cfg_unfused.kv_lora_rank, - qk_nope_head_dim=cfg_unfused.qk_nope_head_dim, - qk_rope_head_dim=cfg_unfused.qk_rope_head_dim, - v_head_dim=cfg_unfused.v_head_dim, - model_mode=MODEL_MODE_TRAIN, - rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)), - ) + # Make both models mathematically equivalent: + # fused wq_kv_a = concat(unfused wq_a, unfused wkv_a) along the 
output axis. + mla_fused.wq_kv_a.kernel.value = jnp.concatenate( + [mla_unfused.wq_a.kernel.value, mla_unfused.wkv_a.kernel.value], axis=-1 + ) + mla_fused.wq_b.kernel.value = mla_unfused.wq_b.kernel.value + mla_fused.q_norm.scale.value = mla_unfused.q_norm.scale.value + mla_fused.wkv_b.kernel.value = mla_unfused.wkv_b.kernel.value + mla_fused.kv_norm.scale.value = mla_unfused.kv_norm.scale.value + mla_fused.out.kernel.value = mla_unfused.out.kernel.value + + # Run both models on the same inputs and verify outputs are identical. + lnx, decoder_segment_ids, decoder_positions = self.get_data(cfg_unfused, cfg_unfused.dtype) + common_kwargs = { + "decoder_segment_ids": decoder_segment_ids, + "inputs_positions": decoder_positions, + "deterministic": True, + "model_mode": MODEL_MODE_TRAIN, + } + output_unfused, _ = mla_unfused(lnx, lnx, **common_kwargs) + output_fused, _ = mla_fused(lnx, lnx, **common_kwargs) - # Initialize the fused model. - fused_args = { - **self.config_arguments, - "fused_mla_lora_proj": True, - **extra_args, - } - cfg_fused = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], **fused_args - ) - mla_fused = MLA( - config=cfg_fused, - num_query_heads=cfg_fused.num_query_heads, - num_kv_heads=cfg_fused.num_kv_heads, - head_dim=cfg_fused.head_dim, - inputs_q_shape=dummy_q.shape, - inputs_kv_shape=dummy_q.shape, - max_target_length=cfg_fused.max_target_length, - max_prefill_predict_length=cfg_fused.max_prefill_predict_length, - mesh=mesh, - attention_kernel="dot_product", - dtype=cfg_fused.dtype, - dropout_rate=cfg_fused.dropout_rate, - attention_type=cfg_fused.attention_type, - q_lora_rank=cfg_fused.q_lora_rank, - kv_lora_rank=cfg_fused.kv_lora_rank, - qk_nope_head_dim=cfg_fused.qk_nope_head_dim, - qk_rope_head_dim=cfg_fused.qk_rope_head_dim, - v_head_dim=cfg_fused.v_head_dim, - model_mode=MODEL_MODE_TRAIN, - rngs=nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)), - ) + self.assertTrue( + jax.numpy.allclose(output_unfused, output_fused, 
rtol=1e-05, atol=1e-05, equal_nan=False), + "fused_mla_lora_proj=True and fused_mla_lora_proj=False produced different outputs.", + ) - # Make both models mathematically equivalent: - # fused wq_kv_a = concat(unfused wq_a, unfused wkv_a) along the output axis. - mla_fused.wq_kv_a.kernel.value = jnp.concatenate( - [mla_unfused.wq_a.kernel.value, mla_unfused.wkv_a.kernel.value], axis=-1 - ) - mla_fused.wq_b.kernel.value = mla_unfused.wq_b.kernel.value - mla_fused.q_norm.scale.value = mla_unfused.q_norm.scale.value - mla_fused.wkv_b.kernel.value = mla_unfused.wkv_b.kernel.value - mla_fused.kv_norm.scale.value = mla_unfused.kv_norm.scale.value - mla_fused.out.kernel.value = mla_unfused.out.kernel.value - - # Run both models on the same inputs and verify outputs are identical. - lnx, decoder_segment_ids, decoder_positions = self.get_data( - cfg_unfused, cfg_unfused.dtype - ) - common_kwargs = { - "decoder_segment_ids": decoder_segment_ids, - "inputs_positions": decoder_positions, - "deterministic": True, - "model_mode": MODEL_MODE_TRAIN, - } - output_unfused, _ = mla_unfused(lnx, lnx, **common_kwargs) - output_fused, _ = mla_fused(lnx, lnx, **common_kwargs) - - self.assertTrue( - jax.numpy.allclose( - output_unfused, output_fused, rtol=1e-05, atol=1e-05, equal_nan=False - ), - "fused_mla_lora_proj=True and fused_mla_lora_proj=False produced different outputs.", - ) + @parameterized.named_parameters( + { + "testcase_name": "cp_no_load_balance", + "ici_context_parallelism": 4, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_with_load_balance", + "ici_context_parallelism": 4, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_ep_no_load_balance", + "ici_context_parallelism": 2, + "context_parallel_load_balance": False, + 
"ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_ep_with_load_balance", + "ici_context_parallelism": 2, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "ep_no_load_balance", + "ici_context_parallelism": 1, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "ep_with_load_balance", + "ici_context_parallelism": 1, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "auto", + }, + { + "testcase_name": "cp_no_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_with_load_balance_explicit", + "ici_context_parallelism": 4, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 1, + "expert_shard_attention_option": "fsdp", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_no_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "cp_ep_with_load_balance_explicit", + "ici_context_parallelism": 2, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 2, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": "ep_no_load_balance_explicit", + "ici_context_parallelism": 1, + "context_parallel_load_balance": False, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + { + "testcase_name": 
"ep_with_load_balance_explicit", + "ici_context_parallelism": 1, + "context_parallel_load_balance": True, + "ici_expert_parallelism": 4, + "expert_shard_attention_option": "context", + "shard_mode": "explicit", + }, + ) + # TODO (b/454764135.) : This tests fails with new tokamax kernel + @pytest.mark.tpu_only + def test_tpu_flash_attention_context_parallel( + self, + ici_context_parallelism, + context_parallel_load_balance, + ici_expert_parallelism, + expert_shard_attention_option, + shard_mode, + ): + """Test equivalence between dot_product and flash attention + context/expert parallelism""" - @parameterized.named_parameters( - { - "testcase_name": "cp_no_load_balance", - "ici_context_parallelism": 4, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_with_load_balance", - "ici_context_parallelism": 4, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_ep_no_load_balance", - "ici_context_parallelism": 2, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "cp_ep_with_load_balance", - "ici_context_parallelism": 2, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_no_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - "testcase_name": "ep_with_load_balance", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "auto", - }, - { - 
"testcase_name": "cp_no_load_balance_explicit", - "ici_context_parallelism": 4, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_with_load_balance_explicit", - "ici_context_parallelism": 4, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 1, - "expert_shard_attention_option": "fsdp", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_ep_no_load_balance_explicit", - "ici_context_parallelism": 2, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "cp_ep_with_load_balance_explicit", - "ici_context_parallelism": 2, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 2, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_no_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": False, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - { - "testcase_name": "ep_with_load_balance_explicit", - "ici_context_parallelism": 1, - "context_parallel_load_balance": True, - "ici_expert_parallelism": 4, - "expert_shard_attention_option": "context", - "shard_mode": "explicit", - }, - ) - # TODO (b/454764135.) 
: This tests fails with new tokamax kernel - @pytest.mark.tpu_only - def test_tpu_flash_attention_context_parallel( - self, - ici_context_parallelism, - context_parallel_load_balance, - ici_expert_parallelism, - expert_shard_attention_option, - shard_mode, - ): - """Test equivalence between dot_product and flash attention + context/expert parallelism""" - - config_arguments = { - "per_device_batch_size": 1.0, - "run_name": "test", - "enable_checkpointing": False, - "max_target_length": 512, - "sa_block_q": 128, - "sa_block_kv": 128, - "sa_block_kv_compute": 128, - "sa_block_q_dkv": 128, - "sa_block_kv_dkv": 128, - "sa_block_kv_dkv_compute": 128, - "sa_block_q_dq": 128, - "sa_block_kv_dq": 128, - "attention_type": AttentionType.MLA.value, - "q_lora_rank": 4, - "kv_lora_rank": 4, - "qk_nope_head_dim": 128, - "qk_rope_head_dim": 64, - "v_head_dim": 128, - "shard_mode": shard_mode, - } - - cfg, mla = self.init_mla(config_arguments, rope_type="default") - lnx, decoder_segment_ids, decoder_positions = self.get_data(cfg, cfg.dtype) - # Dot product - mla_generic_output, _ = mla( - lnx, - lnx, - decoder_segment_ids=decoder_segment_ids, - inputs_positions=decoder_positions, - deterministic=True, - model_mode=MODEL_MODE_TRAIN, - ) - generic_state = nnx.state(mla) - - # Test with Context Parallelism - cfg_cp = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **config_arguments, - rope_type=cfg.rope_type, - ici_context_parallelism=ici_context_parallelism, - context_parallel_load_balance=context_parallel_load_balance, - ici_expert_parallelism=ici_expert_parallelism, - expert_shard_attention_option=expert_shard_attention_option, - ) - devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) - axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto - axis_names = [axis_type for _ in cfg_cp.mesh_axes] - mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes, axis_types=tuple(axis_names)) - attention_as_mla_flash_cp = MLA( - config=cfg_cp, - 
num_query_heads=cfg_cp.num_query_heads, - num_kv_heads=cfg_cp.num_kv_heads, - head_dim=cfg_cp.head_dim, - inputs_q_shape=lnx.shape, - inputs_kv_shape=lnx.shape, - max_target_length=cfg_cp.max_target_length, - max_prefill_predict_length=cfg_cp.max_prefill_predict_length, - mesh=mesh_cp, - attention_kernel="flash", - dtype=cfg_cp.dtype, - dropout_rate=cfg_cp.dropout_rate, - attention_type=cfg_cp.attention_type, - q_lora_rank=cfg_cp.q_lora_rank, - kv_lora_rank=cfg_cp.kv_lora_rank, - qk_nope_head_dim=cfg_cp.qk_nope_head_dim, - qk_rope_head_dim=cfg_cp.qk_rope_head_dim, - v_head_dim=cfg_cp.v_head_dim, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) - nnx.update(attention_as_mla_flash_cp, generic_state) - mla_generic_flash_cp_output = ( - attention_test_util.forward_with_context_expert_parallelism( - cfg_cp, - mesh_cp, - attention_as_mla_flash_cp, - lnx, - decoder_segment_ids, - decoder_positions, - ) - ) + config_arguments = { + "per_device_batch_size": 1.0, + "run_name": "test", + "enable_checkpointing": False, + "max_target_length": 512, + "sa_block_q": 128, + "sa_block_kv": 128, + "sa_block_kv_compute": 128, + "sa_block_q_dkv": 128, + "sa_block_kv_dkv": 128, + "sa_block_kv_dkv_compute": 128, + "sa_block_q_dq": 128, + "sa_block_kv_dq": 128, + "attention_type": AttentionType.MLA.value, + "q_lora_rank": 4, + "kv_lora_rank": 4, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "shard_mode": shard_mode, + } - # This removes all sharding information and makes them standard NumPy arrays. 
- mla_generic_output = jax.device_get(mla_generic_output) - mla_generic_flash_cp_output = jax.device_get(mla_generic_flash_cp_output) - - self.assertTrue( - jax.numpy.allclose( - mla_generic_output, - mla_generic_flash_cp_output, - rtol=1e-01, - atol=1e-01, - equal_nan=False, - ), - msg="MLA Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" - f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," - f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", - ) + cfg, mla = self.init_mla(config_arguments, rope_type="default") + lnx, decoder_segment_ids, decoder_positions = self.get_data(cfg, cfg.dtype) + # Dot product + mla_generic_output, _ = mla( + lnx, + lnx, + decoder_segment_ids=decoder_segment_ids, + inputs_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + generic_state = nnx.state(mla) + + # Test with Context Parallelism + cfg_cp = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **config_arguments, + rope_type=cfg.rope_type, + ici_context_parallelism=ici_context_parallelism, + context_parallel_load_balance=context_parallel_load_balance, + ici_expert_parallelism=ici_expert_parallelism, + expert_shard_attention_option=expert_shard_attention_option, + ) + devices_array_cp = maxtext_utils.create_device_mesh(cfg_cp) + axis_type = AxisType.Explicit if shard_mode == "explicit" else AxisType.Auto + axis_names = [axis_type for _ in cfg_cp.mesh_axes] + mesh_cp = Mesh(devices_array_cp, cfg_cp.mesh_axes, axis_types=tuple(axis_names)) + attention_as_mla_flash_cp = MLA( + config=cfg_cp, + num_query_heads=cfg_cp.num_query_heads, + num_kv_heads=cfg_cp.num_kv_heads, + head_dim=cfg_cp.head_dim, + inputs_q_shape=lnx.shape, + inputs_kv_shape=lnx.shape, + max_target_length=cfg_cp.max_target_length, + max_prefill_predict_length=cfg_cp.max_prefill_predict_length, 
+ mesh=mesh_cp, + attention_kernel="flash", + dtype=cfg_cp.dtype, + dropout_rate=cfg_cp.dropout_rate, + attention_type=cfg_cp.attention_type, + q_lora_rank=cfg_cp.q_lora_rank, + kv_lora_rank=cfg_cp.kv_lora_rank, + qk_nope_head_dim=cfg_cp.qk_nope_head_dim, + qk_rope_head_dim=cfg_cp.qk_rope_head_dim, + v_head_dim=cfg_cp.v_head_dim, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + nnx.update(attention_as_mla_flash_cp, generic_state) + mla_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism( + cfg_cp, + mesh_cp, + attention_as_mla_flash_cp, + lnx, + decoder_segment_ids, + decoder_positions, + ) - def get_indexer_test_data(self, batch_size, q_len, kv_len, num_heads, head_dim): - """Helper to generate random data for indexer tests.""" - key_q, key_k, key_is = jax.random.split(self.rng, 3) - query = jax.random.normal(key_q, (batch_size, q_len, num_heads, head_dim)) - key = jax.random.normal(key_k, (batch_size, kv_len, num_heads, head_dim)) - indexer_score = jax.random.normal(key_is, (batch_size, q_len, kv_len)) - return query, key, indexer_score - - def get_causal_mask_for_indexer(self, batch_size, q_len, kv_len): - """Helper to generate a causal mask with DEFAULT_MASK_VALUE.""" - row_ids = jnp.arange(q_len)[:, None] - col_ids = jnp.arange(kv_len)[None, :] - attention_mask = jnp.where(col_ids <= row_ids, 0.0, DEFAULT_MASK_VALUE) - attention_mask = jnp.broadcast_to(attention_mask, (batch_size, q_len, kv_len)) - return attention_mask - - def test_indexer_loss(self): - """Test indexer loss computation.""" - mla_config_args = self.config_arguments.copy() - mla_config_args["use_sparse_indexer"] = True - mla_config_args["attention"] = "dot_product" - _, mla = self.init_mla(mla_config_args, rope_type="default") - - batch_size = 2 - q_len = 3 - kv_len = 4 - num_heads = 5 - head_dim = 6 - scaling_factor = 0.5 - - query, key, indexer_score = self.get_indexer_test_data( - batch_size, q_len, kv_len, num_heads, head_dim - ) + # This removes 
all sharding information and makes them standard NumPy arrays. + mla_generic_output = jax.device_get(mla_generic_output) + mla_generic_flash_cp_output = jax.device_get(mla_generic_flash_cp_output) + + self.assertTrue( + jax.numpy.allclose( + mla_generic_output, + mla_generic_flash_cp_output, + rtol=1e-01, + atol=1e-01, + equal_nan=False, + ), + msg="MLA Logits from generic dot product and flash attention + context/expert parallelism are not close.\n" + f"ici_context_parallelism={ici_context_parallelism}, context_parallel_load_balance={context_parallel_load_balance}," + f" ici_expert_parallelism={ici_expert_parallelism}, expert_shard_attention_option={expert_shard_attention_option}.", + ) - # Causal mask - attention_mask = self.get_causal_mask_for_indexer(batch_size, q_len, kv_len) - indexer_score += attention_mask - - topk_indices = jnp.array([[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]]) - indexer_mask = mla.indexer.generate_mask(topk_indices, kv_len) + attention_mask - - loss_dense = mla.calculate_indexer_loss( - indexer_score=indexer_score, - query=query, - key=key, - attention_mask=attention_mask, - indexer_mask=indexer_mask, - sparse_loss=False, - scaling_factor=scaling_factor, - ) + def get_indexer_test_data(self, batch_size, q_len, kv_len, num_heads, head_dim): + """Helper to generate random data for indexer tests.""" + key_q, key_k, key_is = jax.random.split(self.rng, 3) + query = jax.random.normal(key_q, (batch_size, q_len, num_heads, head_dim)) + key = jax.random.normal(key_k, (batch_size, kv_len, num_heads, head_dim)) + indexer_score = jax.random.normal(key_is, (batch_size, q_len, kv_len)) + return query, key, indexer_score + + def get_causal_mask_for_indexer(self, batch_size, q_len, kv_len): + """Helper to generate a causal mask with DEFAULT_MASK_VALUE.""" + row_ids = jnp.arange(q_len)[:, None] + col_ids = jnp.arange(kv_len)[None, :] + attention_mask = jnp.where(col_ids <= row_ids, 0.0, DEFAULT_MASK_VALUE) + attention_mask = 
jnp.broadcast_to(attention_mask, (batch_size, q_len, kv_len)) + return attention_mask + + def test_indexer_loss(self): + """Test indexer loss computation.""" + mla_config_args = self.config_arguments.copy() + mla_config_args["use_sparse_indexer"] = True + mla_config_args["attention"] = "dot_product" + _, mla = self.init_mla(mla_config_args, rope_type="default") + + batch_size = 2 + q_len = 3 + kv_len = 4 + num_heads = 5 + head_dim = 6 + scaling_factor = 0.5 + + query, key, indexer_score = self.get_indexer_test_data(batch_size, q_len, kv_len, num_heads, head_dim) + + # Causal mask + attention_mask = self.get_causal_mask_for_indexer(batch_size, q_len, kv_len) + indexer_score += attention_mask + + topk_indices = jnp.array([[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]]) + indexer_mask = mla.indexer.generate_mask(topk_indices, kv_len) + attention_mask + + loss_dense = mla.calculate_indexer_loss( + indexer_score=indexer_score, + query=query, + key=key, + attention_mask=attention_mask, + indexer_mask=indexer_mask, + sparse_loss=False, + scaling_factor=scaling_factor, + ) - loss_sparse = mla.calculate_indexer_loss( - indexer_score=indexer_score, - query=query, - key=key, - attention_mask=attention_mask, - indexer_mask=indexer_mask, - sparse_loss=True, - scaling_factor=scaling_factor, - ) + loss_sparse = mla.calculate_indexer_loss( + indexer_score=indexer_score, + query=query, + key=key, + attention_mask=attention_mask, + indexer_mask=indexer_mask, + sparse_loss=True, + scaling_factor=scaling_factor, + ) - np.testing.assert_array_less(0.0, loss_dense) - np.testing.assert_array_less(0.0, loss_sparse) - - def test_indexer_loss_kl_divergence_zero(self): - """Test that KL divergence is 0 when target and pred distributions match exactly.""" - mla_config_args = self.config_arguments.copy() - mla_config_args["use_sparse_indexer"] = True - mla_config_args["attention"] = "dot_product" - _, mla = self.init_mla(mla_config_args, rope_type="default") - - batch_size = 2 - q_len = 3 
- kv_len = 4 - num_heads = 5 - head_dim = 6 - - # Setup perfectly matching distributions - # Make query and key such that einsum yields zeros (so softmax gives uniform distribution over unmasked) - query = jnp.zeros((batch_size, q_len, num_heads, head_dim)) - key = jnp.zeros((batch_size, kv_len, num_heads, head_dim)) - - # Causal mask - attention_mask = self.get_causal_mask_for_indexer(batch_size, q_len, kv_len) - - # Indexer score matches the shape and is uniform - indexer_score = jnp.zeros((batch_size, q_len, kv_len)) + attention_mask - - topk_indices = jnp.array([[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 1], [0, 1]]]) - indexer_mask = mla.indexer.generate_mask(topk_indices, kv_len) + attention_mask - - loss = mla.calculate_indexer_loss( - indexer_score=indexer_score, - query=query, - key=key, - attention_mask=attention_mask, - indexer_mask=indexer_mask, - sparse_loss=False, - scaling_factor=1.0, - ) + np.testing.assert_array_less(0.0, loss_dense) + np.testing.assert_array_less(0.0, loss_sparse) + + def test_indexer_loss_kl_divergence_zero(self): + """Test that KL divergence is 0 when target and pred distributions match exactly.""" + mla_config_args = self.config_arguments.copy() + mla_config_args["use_sparse_indexer"] = True + mla_config_args["attention"] = "dot_product" + _, mla = self.init_mla(mla_config_args, rope_type="default") + + batch_size = 2 + q_len = 3 + kv_len = 4 + num_heads = 5 + head_dim = 6 + + # Setup perfectly matching distributions + # Make query and key such that einsum yields zeros (so softmax gives uniform distribution over unmasked) + query = jnp.zeros((batch_size, q_len, num_heads, head_dim)) + key = jnp.zeros((batch_size, kv_len, num_heads, head_dim)) + + # Causal mask + attention_mask = self.get_causal_mask_for_indexer(batch_size, q_len, kv_len) + + # Indexer score matches the shape and is uniform + indexer_score = jnp.zeros((batch_size, q_len, kv_len)) + attention_mask + + topk_indices = jnp.array([[[0, 1], [0, 1], [0, 1]], [[0, 1], [0, 
1], [0, 1]]]) + indexer_mask = mla.indexer.generate_mask(topk_indices, kv_len) + attention_mask + + loss = mla.calculate_indexer_loss( + indexer_score=indexer_score, + query=query, + key=key, + attention_mask=attention_mask, + indexer_mask=indexer_mask, + sparse_loss=False, + scaling_factor=1.0, + ) - np.testing.assert_allclose(loss, 0.0, atol=1e-5) + np.testing.assert_allclose(loss, 0.0, atol=1e-5) class Qwen3NextGatedDeltaNetTest(unittest.TestCase): - """Test for the Gated Delta Net in Qwen3-Next""" - - def setUp(self): - super().setUp() - self.config_arguments = { - "per_device_batch_size": 1.0, - "run_name": "test", - "enable_checkpointing": False, - "max_prefill_predict_length": 16, - "max_target_length": 32, - "base_emb_dim": 128, # changed to base_emb_dim so it properly overrides the default 2048 - "gdn_num_value_heads": 4, - "gdn_num_key_heads": 4, - "gdn_key_head_dim": 32, - "gdn_value_head_dim": 32, - "gdn_conv_kernel_dim": 4, - "gdn_chunk_size": 16, - "dtype": "bfloat16", - } - self.cfg = pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **self.config_arguments, - ) - self.rng = jax.random.PRNGKey(0) - self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) - - def get_structured_data(self, dtype): - """get structured data for GDN (only requires hidden states)""" - lnx = jax.random.normal( - self.rng, - shape=( - self.cfg.global_batch_size_to_train_on, - self.cfg.max_target_length, - self.cfg.emb_dim, - ), - dtype=dtype, - ) - return lnx - - @pytest.mark.tpu_only - def test_autoregression(self): - cfg = self.cfg - prefill_length = cfg.max_prefill_predict_length - decode_total_length = cfg.max_target_length - - # 1. Init Data - lnx = self.get_structured_data(cfg.dtype) - - # 2. Init GDN Layer - gdn = Qwen3NextGatedDeltaNet( - config=cfg, - dtype=cfg.dtype, - model_mode=MODEL_MODE_PREFILL, - rngs=self.nnx_rng, - ) + """Test for the Gated Delta Net in Qwen3-Next""" - # 3. 
Full / Train mode - gdn_full = gdn( - lnx, - model_mode=MODEL_MODE_TRAIN, - ) + def setUp(self): + super().setUp() + self.config_arguments = { + "per_device_batch_size": 1.0, + "run_name": "test", + "enable_checkpointing": False, + "max_prefill_predict_length": 16, + "max_target_length": 32, + "base_emb_dim": 128, # changed to base_emb_dim so it properly overrides the default 2048 + "gdn_num_value_heads": 4, + "gdn_num_key_heads": 4, + "gdn_key_head_dim": 32, + "gdn_value_head_dim": 32, + "gdn_conv_kernel_dim": 4, + "gdn_chunk_size": 16, + "dtype": "bfloat16", + } + self.cfg = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **self.config_arguments, + ) + self.rng = jax.random.PRNGKey(0) + self.nnx_rng = nnx.Rngs(params=0, dropout=jax.random.PRNGKey(42)) + + def get_structured_data(self, dtype): + """get structured data for GDN (only requires hidden states)""" + lnx = jax.random.normal( + self.rng, + shape=( + self.cfg.global_batch_size_to_train_on, + self.cfg.max_target_length, + self.cfg.emb_dim, + ), + dtype=dtype, + ) + return lnx + + @pytest.mark.tpu_only + def test_autoregression(self): + cfg = self.cfg + prefill_length = cfg.max_prefill_predict_length + decode_total_length = cfg.max_target_length + + # 1. Init Data + lnx = self.get_structured_data(cfg.dtype) + + # 2. Init GDN Layer + gdn = Qwen3NextGatedDeltaNet( + config=cfg, + dtype=cfg.dtype, + model_mode=MODEL_MODE_PREFILL, + rngs=self.nnx_rng, + ) + + # 3. Full / Train mode + gdn_full = gdn( + lnx, + model_mode=MODEL_MODE_TRAIN, + ) - # 4. Prefill mode - lnx_prefill = lnx[:, 0:prefill_length, :] + # 4. 
Prefill mode + lnx_prefill = lnx[:, 0:prefill_length, :] - gdn_prefill = gdn( - lnx_prefill, - model_mode=MODEL_MODE_PREFILL, - ) + gdn_prefill = gdn( + lnx_prefill, + model_mode=MODEL_MODE_PREFILL, + ) - self.assertTrue( - jax.numpy.allclose( - gdn_prefill, - gdn_full[:, :prefill_length, :], - rtol=1e-02, - atol=1e-02, - equal_nan=False, - ) + self.assertTrue( + jax.numpy.allclose( + gdn_prefill, + gdn_full[:, :prefill_length, :], + rtol=1e-02, + atol=1e-02, + equal_nan=False, ) + ) - # 5. Autoregressive mode - for idx in range(prefill_length, decode_total_length): - lnx_idx = lnx[:, idx : idx + 1, :] + # 5. Autoregressive mode + for idx in range(prefill_length, decode_total_length): + lnx_idx = lnx[:, idx : idx + 1, :] - gdn_idx = gdn( - lnx_idx, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - ) + gdn_idx = gdn( + lnx_idx, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + ) - gdn_full_this_idx = gdn_full[:, idx : idx + 1, :] - self.assertEqual(gdn_full_this_idx.shape, gdn_idx.shape) + gdn_full_this_idx = gdn_full[:, idx : idx + 1, :] + self.assertEqual(gdn_full_this_idx.shape, gdn_idx.shape) - self.assertTrue( - jax.numpy.allclose( - gdn_full_this_idx, gdn_idx, rtol=1e-02, atol=1e-02, equal_nan=False - ) - ) + self.assertTrue(jax.numpy.allclose(gdn_full_this_idx, gdn_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) if __name__ == "__main__": - unittest.main() + unittest.main()