From b3a8bb400e7e18c7b75fb1fbbd478878be31cf7c Mon Sep 17 00:00:00 2001 From: dandragona Date: Tue, 12 May 2026 21:31:45 +0000 Subject: [PATCH] Optimize mHC for expansion rate 4 using convex combination of permutations and add enable_mhc_k4_shortcut feature gate --- src/maxtext/configs/base.yml | 2 + src/maxtext/configs/types.py | 68 +++++++++---- src/maxtext/layers/mhc.py | 60 +++++++++--- tests/unit/grain_data_processing_test.py | 1 - tests/unit/mhc_test.py | 120 +++++++++++++++++++---- 5 files changed, 204 insertions(+), 47 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 6e19ccc445..e78eb5207a 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1217,6 +1217,8 @@ force_q_layout: false mhc_expansion_rate: 1 # The number of iterations for the Sinkhorn-Knopp algorithm. sinkhorn_iterations: 20 +# Whether to enable the permutation-based convex combination shortcut when mhc_expansion_rate is 4. +enable_mhc_k4_shortcut: True ################################## DeepSeek Engram ################################## # Indices of transformer layers where Engram are integrated; leave empty [] to disable. diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 20594bccc3..c47a8a7159 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -506,7 +506,10 @@ class ModelArchitecture(BaseModel): True, description="Whether to apply scale on query and key normalizations (default True).", ) - v_norm_with_scale: bool = Field(True, description="Whether to apply scale on value normalization (default True).") + v_norm_with_scale: bool = Field( + True, + description="Whether to apply scale on value normalization (default True).", + ) class MTP(BaseModel): @@ -685,14 +688,18 @@ class MoEGeneral(BaseModel): num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.") num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.") capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.") - ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.") + ragged_buffer_factor: float = Field( + -1.0, + description="Ragged buffer factor. If < 0, ragged buffer is worst case size.", + ) moe_expert_input_dim: int = Field( -1, description="Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim.", ) base_moe_mlp_dim: int = Field(-1, description="Intermediate dimension at MoE layer.") padded_base_moe_mlp_dim: Optional[int] = Field( - None, description="Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution." + None, + description="Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution.", ) load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.") use_custom_sort_vjp: bool = Field( @@ -873,7 +880,8 @@ class HardwareAndMesh(BaseModel): ) custom_mesh: str = Field("", description="Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']") custom_mesh_and_rule: CustomRule = Field( - CustomRule.DEFAULT, description="Customized mesh and logical rules for granularity." 
+ CustomRule.DEFAULT, + description="Customized mesh and logical rules for granularity.", ) allow_split_physical_axes: bool = Field(False, description="Allow splitting physical axes for device mesh creation.") enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.") @@ -882,7 +890,8 @@ class HardwareAndMesh(BaseModel): pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.") pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.") remove_size_one_mesh_axis_from_type: bool = Field( - True, description="Whether to remove size one mesh axis from type through jax.config." + True, + description="Whether to remove size one mesh axis from type through jax.config.", ) @@ -903,7 +912,10 @@ class LayoutAndSharding(BaseModel): description="Allowed percentage of non-sharded parameters.", ) shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.") - internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.") + internal_compile: bool = Field( + False, + description="Use internal_compile to bypass open-source topology mappings.", + ) internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.") compile_xla_flags: str = Field("", description="Compiler options for compilation only.") @@ -950,7 +962,8 @@ class PipelineParallelism(BaseModel): """Configuration for pipeline parallelism.""" pipeline_fsdp_ag_per_repeat: bool = Field( - False, description="Enable weight prefetching for circular pipeline parallelism." + False, + description="Enable weight prefetching for circular pipeline parallelism.", ) num_layers_per_pipeline_stage: int = Field(1, description="Number of layers to place on each pipeline stage.") num_pipeline_repeats: int = Field( @@ -1194,7 +1207,10 @@ class OlmoGrainDataset(BaseModel): ``data_shuffle_seed``); only OLMo-specific fields are listed here. """ - olmo_index_path: PathStr = Field("", description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.") + olmo_index_path: PathStr = Field( + "", + description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.", + ) olmo_path_remap_from: PathStr = Field( "", description="If set, rewrite index file paths starting with this prefix to olmo_path_remap_to.", @@ -1279,19 +1295,24 @@ class Distillation(BaseModel): distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.") distill_alpha_end: Optional[float] = Field(None, description="Target alpha at end of training. None keeps alpha fixed.") distill_alpha_schedule: Literal["constant", "linear", "cosine"] = Field( - "constant", description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine')." + "constant", + description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine').", ) distill_temperature_end: Optional[float] = Field( - None, description="Target temperature at end of training. None keeps temperature fixed." + None, + description="Target temperature at end of training. None keeps temperature fixed.", ) distill_temperature_schedule: Literal["constant", "linear", "cosine"] = Field( - "constant", description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine')." 
+ "constant", + description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine').", ) distill_beta_end: Optional[float] = Field( - None, description="Target beta_feature at end of training. None keeps beta fixed." + None, + description="Target beta_feature at end of training. None keeps beta fixed.", ) distill_beta_schedule: Literal["constant", "linear", "cosine"] = Field( - "constant", description="Schedule type for beta annealing ('constant', 'linear', or 'cosine')." + "constant", + description="Schedule type for beta annealing ('constant', 'linear', or 'cosine').", ) # --- Learn to init related parameters -- @@ -1314,11 +1335,13 @@ class Distillation(BaseModel): ) attn_module_name: Optional[str] = Field( - None, description="Attention nnx module attribute name to augment with LTI logic" + None, + description="Attention nnx module attribute name to augment with LTI logic", ) lti_layer_indices: Optional[list[int]] = Field( - None, description="List of layer indices to apply LTI modifications. If None, applied to all layers." + None, + description="List of layer indices to apply LTI modifications. If None, applied to all layers.", ) # --------------------------------------- @@ -1365,6 +1388,10 @@ class ManifoldConstrainedHyperConnections(BaseModel): mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.") sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.") + enable_mhc_k4_shortcut: bool = Field( + True, + description="Whether to enable the permutation-based convex combination shortcut when mhc_expansion_rate is 4.", + ) class DilocoParams(BaseModel): @@ -1655,7 +1682,8 @@ class Profiling(BaseModel): tpu_num_chips_to_profile_per_task: int = Field(1, description="Specifies the number of TPU chips to profile per task.") tpu_num_sparse_cores_to_trace: int = Field(2, description="Specifies the number of TPU chips to profile per task.") tpu_num_sparse_core_tiles_to_trace: int = Field( - 1, description="Specifies the number of tiles within each sparse core to trace on the TPU." 
+ 1, + description="Specifies the number of tiles within each sparse core to trace on the TPU.", ) xprof_tpu_power_trace_level: XProfTPUPowerTraceMode = Field( XProfTPUPowerTraceMode.POWER_TRACE_NONE, @@ -2491,7 +2519,11 @@ def validate_and_set_hlo_dump_defaults(): ) for param_name, schedule, end_value in [ ("distill_alpha", self.distill_alpha_schedule, self.distill_alpha_end), - ("distill_temperature", self.distill_temperature_schedule, self.distill_temperature_end), + ( + "distill_temperature", + self.distill_temperature_schedule, + self.distill_temperature_end, + ), ("distill_beta", self.distill_beta_schedule, self.distill_beta_end), ]: if schedule != "constant" and end_value is None: @@ -3004,7 +3036,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de self.use_grpo = False if self.use_batch_split_schedule: - if self.quantization and not self.quantization == "fp8_full": + if self.quantization and self.quantization != "fp8_full": raise ValueError("Batch split quantization only supports `quantization=fp8_full`") if self.opt_type == "muon" and self.decoder_block not in [ diff --git a/src/maxtext/layers/mhc.py b/src/maxtext/layers/mhc.py index ce700aafcd..e0fb0b971b 100644 --- a/src/maxtext/layers/mhc.py +++ b/src/maxtext/layers/mhc.py @@ -14,17 +14,29 @@ """DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer.""" +import itertools from typing import Callable + from flax import nnx import jax import jax.numpy as jnp from jax.sharding import Mesh from maxtext.common.common_types import Array, Config from maxtext.common.common_types import HyperConnectionType -from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init +from maxtext.layers.initializers import ( + default_bias_init, + default_scalar_init, + nd_dense_init, +) from maxtext.layers.normalizations import RMSNorm +def get_4x4_permutation_matrices(): + perms = list(itertools.permutations(range(4))) + perms_array = jnp.array(perms) + return jnp.eye(4)[perms_array] + + def get_functions(expansion_rate: int): """Creates functions to broadcast a single feature stream into multiple @@ -118,6 +130,15 @@ def __init__( out_sharding=(None,), ) + if self.k == 4 and self.config.enable_mhc_k4_shortcut: + res_out_dim = 24 + res_beta_shape = (24,) + res_beta_sharding = (None,) + else: + res_out_dim = self.k * self.k + res_beta_shape = (self.k, self.k) + res_beta_sharding = (None, None) + # Weight matrices scale_init = nd_dense_init(1.0, "fan_in", "normal") in_axis = 0 @@ -126,7 +147,7 @@ def __init__( self.res_alpha = nnx.Param( scale_init( self.rngs.params(), - (self.k * self.dim, self.k * self.k), + (self.k * self.dim, res_out_dim), self.weight_dtype, in_axis=in_axis, out_axis=out_axis, @@ -156,8 +177,8 @@ def __init__( # Biases self.res_beta = nnx.Param( - default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype), - out_sharding=(None, None), + default_bias_init(self.rngs.params(), res_beta_shape, self.weight_dtype), + out_sharding=res_beta_sharding, ) self.pre_beta = nnx.Param( default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype), @@ -174,13 +195,30 @@ def res_mapping(self, x: Array): res_alpha = jnp.asarray(self.res_alpha[...], self.dtype) res_beta = jnp.asarray(self.res_beta[...], self.dtype) res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype) - # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k) - h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision) - b, s, _ = h_res.shape - h_res 
= jnp.reshape(h_res, (b, s, self.k, self.k)) - intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :] - output = sinkhorn(intermediate, self.sinkhorn_iterations) - return output + + if self.k == 4 and self.config.enable_mhc_k4_shortcut: + # Apply projection: (b, s, k*d) @ (k*d, 24) -> (b, s, 24) + h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision) + intermediate = res_alpha_scale * h_res + res_beta[None, None, :] + # Use float32 for numerical stability during softmax + weights = jax.nn.softmax(intermediate.astype(jnp.float32), axis=-1).astype(self.dtype) + # Sum the 24 permutation matrices with the weights + permutation_matrices_4x4 = get_4x4_permutation_matrices().astype(self.dtype) + output = jnp.einsum( + "bsn,nkm -> bskm", + weights, + permutation_matrices_4x4, + precision=self.matmul_precision, + ) + return output + else: + # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k) + h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision) + b, s, _ = h_res.shape + h_res = jnp.reshape(h_res, (b, s, self.k, self.k)) + intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :] + output = sinkhorn(intermediate, self.sinkhorn_iterations) + return output def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int): """Helper function for both pre and post mappings.""" diff --git a/tests/unit/grain_data_processing_test.py b/tests/unit/grain_data_processing_test.py index 079dcc2a00..804cf23feb 100644 --- a/tests/unit/grain_data_processing_test.py +++ b/tests/unit/grain_data_processing_test.py @@ -464,7 +464,6 @@ def setUp(self): tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), enable_checkpointing=False, ) - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) @pytest.mark.external_training diff --git a/tests/unit/mhc_test.py b/tests/unit/mhc_test.py index f076a132fd..be0a27a284 100644 --- a/tests/unit/mhc_test.py +++ b/tests/unit/mhc_test.py @@ -14,15 +14,15 @@ """Test for DeepSeek Manifold-Constrained Hyper Connections (mHC).""" -import unittest -import pytest - +from absl.testing import absltest +from absl.testing import parameterized from flax import nnx from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp from jax.sharding import Mesh import numpy as np +import pytest from maxtext.configs import pyconfig from maxtext.common.common_types import HyperConnectionType @@ -30,10 +30,13 @@ from maxtext.layers.initializers import nd_dense_init from maxtext.layers.normalizations import RMSNorm from maxtext.utils import maxtext_utils -from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides +from tests.utils.test_helpers import ( + get_decoupled_parallelism_overrides, + get_test_config_path, +) -class TestExpandReduce(unittest.TestCase): +class TestExpandReduce(absltest.TestCase): """Unit tests for MHC dimension expansion and reduction operations.""" def setUp(self): @@ -65,7 +68,7 @@ def test_value_identity(self): np.testing.assert_allclose(out, expected, rtol=1e-5) -class TestSinkhorn(unittest.TestCase): +class TestSinkhorn(absltest.TestCase): """Unit tests for MHC Sinkhorn Algorithm.""" def setUp(self): @@ -86,19 +89,21 @@ def test_doubly_stochastic_property(self): np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-3) -class TestMHC(unittest.TestCase): +class TestMHC(parameterized.TestCase): """Test for 
MHC module""" - def setUp(self): + def _setup_mhc(self, rate): + """Sets up the common configurations and modules for MHC testing.""" self.dim = 16 extra_args = get_decoupled_parallelism_overrides() self.config = pyconfig.initialize( [None, get_test_config_path()], **extra_args, - run_name="test_mhc", + skip_jax_distributed_system=True, + run_name=f"test_mhc_k{rate}", enable_checkpointing=False, model_name="deepseek-custom", - per_device_batch_size=4, + per_device_batch_size=max(4, jax.device_count()), max_target_length=7, max_prefill_predict_length=7, attention="dot_product", @@ -107,7 +112,7 @@ def setUp(self): # override override_model_config=True, base_emb_dim=self.dim, - mhc_expansion_rate=3, + mhc_expansion_rate=rate, num_experts=4, num_experts_per_tok=2, engram_layers=[], @@ -137,7 +142,14 @@ def setUp(self): # Skip GPU due to NotImplementedError: dynamic grid bounds not supported in the Triton backend @pytest.mark.tpu_only - def test_moe_layer_output_shape(self): + @parameterized.named_parameters(("Rate3", 3), ("Rate4", 4)) + def test_moe_layer_output_shape(self, rate): + # Skip test if TPU hardware is not available + has_tpu = any(d.platform == "tpu" for d in jax.devices()) + if not has_tpu: + self.skipTest("test_moe_layer_output_shape requires TPU hardware.") + + self._setup_mhc(rate) with nn_partitioning.axis_rules(self.config.logical_axis_rules): module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) layer = moe.RoutedMoE( @@ -156,12 +168,14 @@ def test_moe_layer_output_shape(self): b, s, k, d = self.x.shape output, metadata = module(self.pre_norm, layer, x=self.x, mhc_type=HyperConnectionType.MLP_MOE) # metadata includes load_balance_loss & moe_bias_updates - self.assertEqual(len(metadata), 2) + self.assertLen(metadata, 2) for key, value in metadata.items(): self.assertIsNotNone(value, f"Key '{key}' has a value of None") self.assertEqual(output.shape, (b, s, k, d)) - def test_dense_layer_output_shape(self): + @parameterized.named_parameters(("Rate3", 3), ("Rate4", 4)) + def test_dense_layer_output_shape(self, rate): + self._setup_mhc(rate) with nn_partitioning.axis_rules(self.config.logical_axis_rules): module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) layer = linears.MlpBlock( @@ -182,8 +196,14 @@ def test_dense_layer_output_shape(self): self.assertDictEqual(metadata, {}) self.assertEqual(output.shape, (b, s, k, d)) - def test_attention_layer_output_shape(self): - inputs_shape = (self.config.per_device_batch_size, self.config.max_target_length, self.config.emb_dim) + @parameterized.named_parameters(("Rate3", 3), ("Rate4", 4)) + def test_attention_layer_output_shape(self, rate): + self._setup_mhc(rate) + inputs_shape = ( + self.config.per_device_batch_size, + self.config.max_target_length, + self.config.emb_dim, + ) with nn_partitioning.axis_rules(self.config.logical_axis_rules): module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) layer = attention_mla.MLA( @@ -221,6 +241,72 @@ def test_attention_layer_output_shape(self): self.assertDictEqual(metadata, {}) self.assertEqual(output.shape, (b, s, k, d)) + def test_compare_k4_and_sinkhorn_via_log(self): + """Verify that Sinkhorn can produce the same result as k=4 branch shortcut.""" + self._setup_mhc(4) + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) + + b, s, k, d = self.x.shape + + # Generate 
random input X + random_x = jax.random.normal(jax.random.PRNGKey(42), (b, s, k * d)) + norm_x = module.mhc_norm(random_x) + + # Output from k=4 branch + res_mapping_out = module.res_mapping(norm_x) + + # To compare with Sinkhorn, we pass log(res_mapping_out). Since Sinkhorn + # applies softmax, log will undo that. Thus, we should immediately exit + # Sinkhorn (since we started with a doubly-stochastic matrix). + epsilon = 1e-12 + log_input = jnp.log(res_mapping_out + epsilon) + # Use 0 iterations to prove we are already converged. + sinkhorn_out = mhc.sinkhorn(log_input, 0) + + # They should be close + np.testing.assert_allclose(res_mapping_out, sinkhorn_out, atol=1e-2, rtol=1e-2) + + def test_feature_flag_gates_shortcut(self): + """Verify that setting enable_mhc_k4_shortcut=False falls back to Sinkhorn for k=4.""" + self.dim = 16 + extra_args = get_decoupled_parallelism_overrides() + self.config = pyconfig.initialize( + [None, get_test_config_path()], + **extra_args, + skip_jax_distributed_system=True, + run_name="test_mhc_k4_gated", + enable_checkpointing=False, + model_name="deepseek-custom", + per_device_batch_size=4, + max_target_length=7, + max_prefill_predict_length=7, + attention="dot_product", + routed_bias_update_rate=0.01, + load_balance_loss_weight=0.02, + # override + override_model_config=True, + base_emb_dim=self.dim, + mhc_expansion_rate=4, + enable_mhc_k4_shortcut=False, + num_experts=4, + num_experts_per_tok=2, + engram_layers=[], + ) + devices_array = maxtext_utils.create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) + self.rngs = nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(42)) + + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) + + # Shape of res_alpha should be (4*16, 4*4) = (64, 16) instead of (64, 24) + self.assertEqual(module.res_alpha.shape, (64, 16)) + # Shape of res_beta should be (4, 4) instead of (24,) + self.assertEqual(module.res_beta.shape, (4, 4)) + # Permutation matrices shouldn't be defined + self.assertFalse(hasattr(module, "permutation_matrices_4x4")) + if __name__ == "__main__": - unittest.main() + absltest.main()
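Why the k=4 shortcut can skip Sinkhorn entirely: by the Birkhoff-von Neumann theorem, every convex combination of the 24 permutation matrices of S_4 is doubly stochastic, so softmaxed weights over 24 logits already sit at the fixed point that Sinkhorn-Knopp iterates toward. A minimal standalone sketch of that property follows; the helper and variable names below are illustrative and are not code from this patch.

# Illustrative sketch only; names are hypothetical and not part of the patch.
# It demonstrates the property the k=4 shortcut relies on: a softmax over 24
# logits, used as convex weights over the permutation matrices of S_4, always
# produces a doubly stochastic 4x4 mixing matrix, so no Sinkhorn-Knopp
# iterations are needed.
import itertools

import jax
import jax.numpy as jnp


def convex_combination_of_permutations(logits):
  """Maps (..., 24) logits to (..., 4, 4) doubly stochastic matrices."""
  perms = jnp.array(list(itertools.permutations(range(4))))  # (24, 4)
  perm_mats = jnp.eye(4)[perms]  # (24, 4, 4), one-hot rows per permutation
  weights = jax.nn.softmax(logits, axis=-1)  # convex combination weights
  return jnp.einsum("...n,nkm->...km", weights, perm_mats)


logits = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 24))
mixing = convex_combination_of_permutations(logits)
# Rows and columns each sum to 1 by construction.
assert jnp.allclose(mixing.sum(axis=-1), 1.0, atol=1e-6)
assert jnp.allclose(mixing.sum(axis=-2), 1.0, atol=1e-6)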