From 1d026885ebfac5d2611f351c1b11bcad1f8a3831 Mon Sep 17 00:00:00 2001 From: Rissy Ran Date: Fri, 30 Jan 2026 04:53:20 +0000 Subject: [PATCH] Onboard DeepSeek MHC feature --- src/MaxText/common_types.py | 6 + src/MaxText/configs/base.yml | 6 + .../configs/models/deepseek-custom.yml | 61 +++++ src/MaxText/configs/types.py | 9 + src/MaxText/layers/initializers.py | 1 + src/MaxText/layers/mhc.py | 235 ++++++++++++++++++ tests/unit/mhc_test.py | 201 +++++++++++++++ 7 files changed, 519 insertions(+) create mode 100644 src/MaxText/configs/models/deepseek-custom.yml create mode 100644 src/MaxText/layers/mhc.py create mode 100644 tests/unit/mhc_test.py diff --git a/src/MaxText/common_types.py b/src/MaxText/common_types.py index d71bf400e9..f36b991cef 100644 --- a/src/MaxText/common_types.py +++ b/src/MaxText/common_types.py @@ -114,3 +114,9 @@ class AttentionType(enum.Enum): class ShardMode(enum.Enum): AUTO = "auto" # default EXPLICIT = "explicit" + + +class HyperConnectionType(enum.Enum): + ATTENTION = "attention" + MLP_MOE = "mlp_moe" + MLP_DENSE = "mlp_dense" diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 35ce9a1301..a9bf2cbea2 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -1065,3 +1065,9 @@ vllm_hf_config_path: "" vllm_additional_config: {} # When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH] force_q_layout: false + +################################## DeepSeek Manifold-Constrained Hyper Connections (mHC) ################################## +# The number of parallel streams in Hyper Connection. +mhc_expansion_rate: 0 +# The number of iterations for the Sinkhorn-Knopp algorithm. +sinkhorn_iterations: 20 diff --git a/src/MaxText/configs/models/deepseek-custom.yml b/src/MaxText/configs/models/deepseek-custom.yml new file mode 100644 index 0000000000..46bd43e49b --- /dev/null +++ b/src/MaxText/configs/models/deepseek-custom.yml @@ -0,0 +1,61 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
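+
+# This small test config enables DeepSeek's manifold-constrained hyper-connections (mHC):
+# see the `mhc_expansion_rate` and `sinkhorn_iterations` keys at the bottom of this file,
+# and src/MaxText/layers/mhc.py for the layer implementation.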
+ +# Small model config for testing (derived from DeepSeek V3.2 - 671B) + +base_emb_dim: 1024 # Reduced from 7168 +base_num_query_heads: 16 # Reduced from 128 +base_num_kv_heads: 16 # Reduced from 128 +base_mlp_dim: 2048 # Reduced from 18432 +base_moe_mlp_dim: 512 # Reduced from 2048 +base_num_decoder_layers: 6 # Reduced from 61 +first_num_dense_layers: 1 # Reduced from 3 +mlp_activations: ["silu","linear"] +vocab_size: 129280 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-6 +num_experts: 16 # Reduced from 256 +num_experts_per_tok: 2 # Reduced from 8 +shared_experts: 1 +routed_scaling_factor: 2.5 +routed_score_func: "sigmoid" +routed_bias: True +decoder_block: "deepseek" +# MLA +attention_type: "mla" +q_lora_rank: 384 # Reduced from 1536 +kv_lora_rank: 128 # Reduced from 512 +qk_nope_head_dim: 32 # Reduced from 128 +qk_rope_head_dim: 16 # Reduced from 64 +v_head_dim: 128 +# RoPE +mscale: 1.0 +rope_type: "yarn" +rope_max_timescale: 10_000 +max_position_embeddings: 4096 # Reduced for local testing +original_max_position_embeddings: 4096 +rope_factor: 1 +beta_fast: 32 +rope_interleave: True +rope_truncate: True +rope_attention_scaling: False +# Indexer for DeepSeek Sparse Attention +use_sparse_indexer: True +index_n_heads: 16 # Reduced from 64 +index_head_dim: 64 # Reduced from 128 +index_topk: 256 # Reduced from 2048 +# Hyper-connections: mHC enabled +mhc_expansion_rate: 4 +sinkhorn_iterations: 20 diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index e1d5d811f4..8e902543b9 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -210,6 +210,7 @@ class ProfilerType(str, Enum): "deepseek3-test", "deepseek3-tiny", "deepseek3.2-671b", + "deepseek-custom", "kimi-k2-1t", "gemma-7b", "gemma-2b", @@ -1057,6 +1058,13 @@ class TrainingLoop(BaseModel): init_weights_seed: int = Field(0, description="Seed for model weight initialization.") +class ManifoldConstrainedHyperConnections(BaseModel): + """Configuration for DeepSeek Manifold-Constrained Hyper Connections (mHC).""" + + mhc_expansion_rate: int = Field(0, description="The number of parallel streams in Hyper Connection.") + sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.") + + class Optimizer(BaseModel): """Configuration for the optimizer and learning rate schedule.""" @@ -1743,6 +1751,7 @@ class MaxTextConfig( # Training, Optimization, and Fine-Tuning RematAndOffload, TrainingLoop, + ManifoldConstrainedHyperConnections, Optimizer, AdamW, Muon, diff --git a/src/MaxText/layers/initializers.py b/src/MaxText/layers/initializers.py index 9dfac8759c..955d4c3d05 100644 --- a/src/MaxText/layers/initializers.py +++ b/src/MaxText/layers/initializers.py @@ -31,6 +31,7 @@ default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0) default_bias_init = jax.nn.initializers.constant(0.0) +default_scalar_init = jax.nn.initializers.constant(0.01) def nd_dense_init(scale, mode, distribution): diff --git a/src/MaxText/layers/mhc.py b/src/MaxText/layers/mhc.py new file mode 100644 index 0000000000..f1a2da1c8c --- /dev/null +++ b/src/MaxText/layers/mhc.py @@ -0,0 +1,235 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""
+
+from typing import Callable
+
+import jax
+import jax.numpy as jnp
+from jax.sharding import Mesh
+
+from flax import nnx
+
+from MaxText.common_types import Array, Config, HyperConnectionType
+from MaxText.layers.initializers import nd_dense_init, default_bias_init, default_scalar_init
+from MaxText.layers.normalizations import RMSNorm
+
+
+def get_functions(expansion_rate: int):
+  """
+  Creates functions to broadcast a single feature stream into multiple
+  parallel streams (expand) and to aggregate them back into one (reduce).
+  """
+
+  def expand(x: Array):
+    # (batch, length, dim) -> (batch, length, streams, dim)
+    return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2)
+
+  def reduce(x: Array):
+    # (batch, length, streams, dim) -> (batch, length, dim)
+    return jnp.sum(x, axis=2)
+
+  return expand, reduce
+
+
+def sinkhorn(t, iters=20):
+  """
+  Computes the Sinkhorn normalization of a matrix so that its rows and columns each sum to 1.
+  """
+  # Use float32 precision for numerical stability during normalization.
+  initial_dtype = t.dtype
+  t = t.astype(jnp.float32)
+
+  # Column-wise softmax (axis=-2): entries become positive and each column sums to 1.
+  # Equivalent to t = exp(t) / jnp.sum(jnp.exp(t), axis=-2)
+  t = jax.nn.softmax(t, axis=-2)
+
+  def body_fun(i, val):
+    # L1 normalization with a clipped denominator to avoid division by zero.
+    # Normalize rows (axis -1).
+    val = val / jnp.clip(jnp.sum(val, axis=-1, keepdims=True), min=1e-12)
+    # Normalize columns (axis -2).
+    val = val / jnp.clip(jnp.sum(val, axis=-2, keepdims=True), min=1e-12)
+    return val
+
+  # Use lax.fori_loop for an efficient, JIT-friendly loop.
+  t = jax.lax.fori_loop(0, iters, body_fun, t)
+  return t.astype(initial_dtype)
+
+
+class ManifoldConstrainedHyperConnections(nnx.Module):
+  """Implements Manifold-Constrained Hyper-Connections (mHC).
+
+  Reference: https://arxiv.org/pdf/2512.24880
+
+  Args:
+    config: Configuration object containing hyperparameters.
+    dim: The feature dimensionality.
+    mesh: The hardware mesh for sharding.
+    rngs: NNX random number generators used for parameter initialization.
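+
+  Example (illustrative sketch; ``cfg``, ``mesh``, ``rngs``, ``mlp_block`` and a
+  ``hidden`` activation of shape ``(batch, seq, dim)`` are assumed to exist):
+
+      mhc_layer = ManifoldConstrainedHyperConnections(cfg, dim=cfg.emb_dim, mesh=mesh, rngs=rngs)
+      expand, reduce = get_functions(cfg.mhc_expansion_rate)
+      streams = expand(hidden)   # (batch, seq, k, dim)
+      streams = mhc_layer(mlp_block, streams, HyperConnectionType.MLP_DENSE)
+      hidden = reduce(streams)   # (batch, seq, dim)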
+ """ + + def __init__( + self, + config: Config, + dim: int, + mesh: Mesh, + rngs: nnx.Rngs, + ): + self.config = config + self.sinkhorn_iterations = config.sinkhorn_iterations + self.k = config.mhc_expansion_rate + self.dim = dim + self.rngs = rngs + self.mesh = mesh + self.weight_dtype = self.config.weight_dtype + + # Norm layer + self.mhc_norm = RMSNorm( + num_features=self.k * self.dim, + dtype=self.config.dtype, + weight_dtype=self.weight_dtype, + kernel_axes=("norm",), + epsilon=self.config.normalization_layer_epsilon, + rngs=self.rngs, + ) + + # Scalars + self.res_alpha_scale = nnx.Param( + default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), + sharding=(None,), + ) + self.pre_alpha_scale = nnx.Param( + default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), + sharding=(None,), + ) + self.post_alpha_scale = nnx.Param( + default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), + sharding=(None,), + ) + + # Weight matrices + scale_init = nd_dense_init(1.0, "fan_in", "normal") + in_axis = 0 + out_axis = 1 + weight_sharding_axis_name = ("activation_embed", None) + self.res_alpha = nnx.Param( + scale_init( + self.rngs.params(), + (self.k * self.dim, self.k * self.k), + self.weight_dtype, + in_axis=in_axis, + out_axis=out_axis, + ), + sharding=weight_sharding_axis_name, + ) + self.pre_alpha = nnx.Param( + scale_init( + self.rngs.params(), + (self.k * self.dim, self.k), + self.weight_dtype, + in_axis=in_axis, + out_axis=out_axis, + ), + sharding=weight_sharding_axis_name, + ) + self.post_alpha = nnx.Param( + scale_init( + self.rngs.params(), + (self.k * self.dim, self.k), + self.weight_dtype, + in_axis=in_axis, + out_axis=out_axis, + ), + sharding=weight_sharding_axis_name, + ) + + # Biases + self.res_beta = nnx.Param( + default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype), + sharding=(None, None), + ) + self.pre_beta = nnx.Param( + default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype), + sharding=(None, None), + ) + self.post_beta = nnx.Param( + default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype), + sharding=(None, None), + ) + + def res_mapping(self, x: Array): + """Helper function for residual mapping.""" + # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k) + h_res = jnp.einsum("bsm,mn -> bsn", x, self.res_alpha[...], precision=self.config.matmul_precision) + b, s, _ = h_res.shape + h_res = jnp.reshape(h_res, (b, s, self.k, self.k)) + intermediate = self.res_alpha_scale * h_res + self.res_beta[...][None, None, :, :] + output = sinkhorn(intermediate, self.sinkhorn_iterations) + return output + + def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int): + """Helper function for both pre and post mappings.""" + # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k) + h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.config.matmul_precision) + intermediate = alpha_scale * h + beta[None, None, :] + output = scale * jax.nn.sigmoid(intermediate) + return output + + def __call__( + self, + branch_fn: Callable, + x: Array, + mhc_type: HyperConnectionType, + **kwargs, + ) -> Array: + """Applying manifold-constrained hyper connection based on callable function. + + Args: + branch_fn: The function to be wrapped by the hyper-connection. + x: Input tensor of shape `(batch..., dim)`. + mhc_type: The variant of the connection to apply. + **kwargs: Additional context passed to the branch function. + + Returns: + The processed tensor, maintaining the shape of `x`. 
+ """ + # x shape: [batch, seq, expansion_rate, emb] + b, s, k, d = x.shape + + # 1. Flatten the tensor, and RMS normalization + norm_x = self.mhc_norm(jnp.reshape(x, (b, s, k * d))) + + # 2. Pre mapping + pre_mapping = self.mapping(norm_x, self.pre_alpha_scale, self.pre_alpha[...], self.pre_beta[...], 1.0) + layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.config.matmul_precision) + + # 3. Attention or MLP + if mhc_type == HyperConnectionType.ATTENTION: + layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs) + elif mhc_type == HyperConnectionType.MLP_DENSE: + layer_out = branch_fn(inputs=layer_input, **kwargs) + elif mhc_type == HyperConnectionType.MLP_MOE: + layer_out, _, _ = branch_fn(inputs=layer_input, **kwargs) + else: + raise ValueError(f"Unsupported type: {mhc_type}") + + # 4. Post mapping + post_mapping = self.mapping(norm_x, self.post_alpha_scale, self.post_alpha[...], self.post_beta[...], 2.0) + post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.config.matmul_precision) + + # 5. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb] + res_mapping = self.res_mapping(norm_x) + res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.config.matmul_precision) + return res_out + post_out diff --git a/tests/unit/mhc_test.py b/tests/unit/mhc_test.py new file mode 100644 index 0000000000..e02f920a61 --- /dev/null +++ b/tests/unit/mhc_test.py @@ -0,0 +1,201 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test for DeepSeek Manifold-Constrained Hyper Connections (mHC).""" + +import os.path +import unittest + +from flax import nnx +from flax.linen import partitioning as nn_partitioning +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +import numpy as np + +from MaxText import pyconfig +from MaxText.common_types import HyperConnectionType +from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText.layers import attention_mla, linears, mhc, moe +from MaxText.layers.initializers import nd_dense_init +from maxtext.utils import maxtext_utils + + +class TestExpandReduce(unittest.TestCase): + """Unit tests for MHC dimension expansion and reduction operations.""" + + def setUp(self): + self.rate = 4 + self.batch, self.seq_len, self.dim = 2, 8, 12 + self.shape = (self.batch, self.seq_len, self.dim) + self.expand, self.reduce = mhc.get_functions(self.rate) + + # Consistent random data for testing + self.key = jax.random.PRNGKey(0) + self.x = jax.random.normal(self.key, self.shape) + + def test_expand_shape(self): + """Verifies (B, S, D) -> (B, S, K, D)""" + out = self.expand(self.x) + expected_shape = (self.batch, self.seq_len, self.rate, self.dim) + self.assertEqual(out.shape, expected_shape) + + def test_reduce_shape(self): + """Verifies (B, S, K, D) -> (B, S, D)""" + dummy_expanded = jnp.ones((self.batch, self.seq_len, self.rate, self.dim)) + out = self.reduce(dummy_expanded) + self.assertEqual(out.shape, self.shape) + + def test_value_identity(self): + """Mathematically, reduce(expand(x)) should equal expansion_rate * x.""" + out = self.reduce(self.expand(self.x)) + expected = self.x * self.rate + np.testing.assert_allclose(out, expected, rtol=1e-5) + + +class TestSinkhorn(unittest.TestCase): + """Unit tests for MHC Sinkhorn Algorithm.""" + + def setUp(self): + self.key = jax.random.PRNGKey(42) + self.matrix_shape = (8, 8) + self.t = jax.random.normal(self.key, self.matrix_shape) + + def test_doubly_stochastic_property(self): + """After many iterations, rows and columns should sum to approximately 1.""" + # Use more iterations to ensure convergence + out = mhc.sinkhorn(self.t, iters=20) + + row_sums = jnp.sum(out, axis=-1) + col_sums = jnp.sum(out, axis=-2) + + # Check if sums are close to 1.0 + np.testing.assert_allclose(row_sums, jnp.ones_like(row_sums), atol=1e-3) + np.testing.assert_allclose(col_sums, jnp.ones_like(col_sums), atol=1e-3) + + +class TestMHC(unittest.TestCase): + """Test for MHC module""" + + def setUp(self): + self.dim = 16 + self.config = pyconfig.initialize( + [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + run_name="test_mhc", + enable_checkpointing=False, + model_name="deepseek-custom", + per_device_batch_size=4, + max_target_length=7, + max_prefill_predict_length=7, + base_emb_dim=self.dim, + mhc_expansion_rate=3, + num_experts=4, + num_experts_per_tok=2, + attention="dot_product", + ) + devices_array = maxtext_utils.create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) + + self.rngs = nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(42)) + self.x = jax.random.normal( + jax.random.PRNGKey(0), + ( + self.config.per_device_batch_size, + self.config.max_target_length, + self.config.mhc_expansion_rate, + self.config.emb_dim, + ), + ) + + def test_moe_layer_output_shape(self): + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) + layer = moe.RoutedMoE( + config=self.config, + 
num_experts=self.config.num_experts, + num_experts_per_tok=self.config.num_experts_per_tok, + mesh=self.mesh, + kernel_init=nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", "mlp"), + intermediate_dim=self.config.base_mlp_dim, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + rngs=self.rngs, + ) + + b, s, k, d = self.x.shape + output = module(layer, x=self.x, mhc_type=HyperConnectionType.MLP_MOE) + self.assertEqual(output.shape, (b, s, k, d)) + + def test_dense_layer_output_shape(self): + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) + layer = linears.MlpBlock( + config=self.config, + mesh=self.mesh, + in_features=self.config.emb_dim, + intermediate_dim=self.config.moe_mlp_dim, + activations=self.config.mlp_activations, + intermediate_dropout_rate=self.config.dropout_rate, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + model_mode=self.config.model_call_mode, + rngs=self.rngs, + ) + + b, s, k, d = self.x.shape + output = module(layer, x=self.x, mhc_type=HyperConnectionType.MLP_DENSE) + self.assertEqual(output.shape, (b, s, k, d)) + + def test_attention_layer_output_shape(self): + inputs_shape = (self.config.per_device_batch_size, self.config.max_target_length, self.config.emb_dim) + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + module = mhc.ManifoldConstrainedHyperConnections(self.config, self.dim, self.mesh, self.rngs) + layer = attention_mla.MLA( + config=self.config, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=self.config.attention_type, + inputs_q_shape=inputs_shape, + inputs_kv_shape=inputs_shape, + mesh=self.mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + name="self_attention", + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + max_position_embeddings=self.config.max_position_embeddings, + original_max_position_embeddings=self.config.original_max_position_embeddings, + mscale=self.config.mscale, + rope_factor=self.config.rope_factor, + model_mode="train", + rngs=self.rngs, + attn_logits_soft_cap=self.config.attn_logits_soft_cap, + ) + + b, s, k, d = self.x.shape + output = module(layer, x=self.x, mhc_type=HyperConnectionType.ATTENTION) + self.assertEqual(output.shape, (b, s, k, d)) + + +if __name__ == "__main__": + unittest.main()