Commit ee0f01b
Onboard DeepSeek MHC feature
1 parent af14e43 commit ee0f01b

6 files changed

Lines changed: 520 additions & 0 deletions


src/MaxText/common_types.py

Lines changed: 6 additions & 0 deletions
@@ -114,3 +114,9 @@ class AttentionType(enum.Enum):
 class ShardMode(enum.Enum):
   AUTO = "auto" # default
   EXPLICIT = "explicit"
+
+
+class HyperConnectionType(enum.Enum):
+  ATTENTION = "attention"
+  MLP_MOE = "mlp_moe"
+  MLP_DENSE = "mlp_dense"

src/MaxText/configs/base.yml

Lines changed: 12 additions & 0 deletions
@@ -1051,3 +1051,15 @@ vllm_hf_config_path: ""
 vllm_additional_config: {}
 # When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
 force_q_layout: false
+
+################################## DeepSeek Manifold-Constrained Hyper Connections (mHC) ##################################
+# The number of parallel streams in Hyper Connections.
+mhc_expansion_rate: 0
+# The scale for the residual mapping.
+mhc_res_alpha_scale: 0.01
+# The scale for the pre mapping.
+mhc_pre_alpha_scale: 0.01
+# The scale for the post mapping.
+mhc_post_alpha_scale: 0.01
+# The number of iterations for the Sinkhorn-Knopp algorithm.
+sinkhorn_iterations: 20
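To see concretely what sinkhorn_iterations controls, here is a hedged standalone sketch (not part of the commit) of the Sinkhorn-Knopp normalization; the function name is illustrative, and the 4x4 size matches an expansion rate of 4 as used in the test config further down. The real helper lands in src/MaxText/layers/mhc.py below.

import jax
import jax.numpy as jnp

def sinkhorn_sketch(t, iters=20):
  # Softmax makes every entry positive, then rows and columns are rescaled in turn
  # until the matrix is (approximately) doubly stochastic.
  t = jax.nn.softmax(t.astype(jnp.float32), axis=-2)
  for _ in range(iters):
    t = t / jnp.clip(jnp.sum(t, axis=-1, keepdims=True), min=1e-12)
    t = t / jnp.clip(jnp.sum(t, axis=-2, keepdims=True), min=1e-12)
  return t

m = sinkhorn_sketch(jax.random.normal(jax.random.PRNGKey(0), (4, 4)))
print(jnp.sum(m, axis=-1))  # every row sums to ~1
print(jnp.sum(m, axis=-2))  # every column sums to ~1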
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Small model config for testing (derived from DeepSeek V3.2 - 671B)

base_emb_dim: 1024  # Reduced from 7168
base_num_query_heads: 16  # Reduced from 128
base_num_kv_heads: 16  # Reduced from 128
base_mlp_dim: 2048  # Reduced from 18432
base_moe_mlp_dim: 512  # Reduced from 2048
base_num_decoder_layers: 6  # Reduced from 61
first_num_dense_layers: 1  # Reduced from 3
mlp_activations: ["silu","linear"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 16  # Reduced from 256
num_experts_per_tok: 2  # Reduced from 8
shared_experts: 1
routed_scaling_factor: 2.5
routed_score_func: "sigmoid"
routed_bias: True
decoder_block: "deepseek"
# MLA
attention_type: "mla"
q_lora_rank: 384  # Reduced from 1536
kv_lora_rank: 128  # Reduced from 512
qk_nope_head_dim: 32  # Reduced from 128
qk_rope_head_dim: 16  # Reduced from 64
v_head_dim: 128
# RoPE
mscale: 1.0
rope_type: "yarn"
rope_max_timescale: 10_000
max_position_embeddings: 4096  # Reduced for local testing
original_max_position_embeddings: 4096
rope_factor: 1
beta_fast: 32
rope_interleave: True
rope_truncate: True
rope_attention_scaling: False
# Indexer for DeepSeek Sparse Attention
use_sparse_indexer: True
index_n_heads: 16  # Reduced from 64
index_head_dim: 64  # Reduced from 128
index_topk: 256  # Reduced from 2048
# Hyper-connections: mHC enabled
mhc_expansion_rate: 4
mhc_res_alpha_scale: 0.01
mhc_pre_alpha_scale: 0.01
mhc_post_alpha_scale: 0.01
sinkhorn_iterations: 20
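A quick parameter-count check for what mHC adds at this scale (a sketch; the shapes come from the ManifoldConstrainedHyperConnections layer in src/MaxText/layers/mhc.py below, and "per instance" assumes one instance per wrapped attention/MLP branch):

n, d = 4, 1024  # mhc_expansion_rate and base_emb_dim from this config
res_alpha = n * d * (n * n)         # res_alpha: (n, d, n*n)
pre_alpha = post_alpha = n * d * n  # pre_alpha / post_alpha: (n, d, n)
betas = n * n + n + n               # res_beta (n, n), pre_beta (n,), post_beta (n,)
norm = d                            # RMSNorm scale
print(res_alpha + pre_alpha + post_alpha + betas + norm)  # 99,352 extra params per instance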

src/MaxText/configs/types.py

Lines changed: 12 additions & 0 deletions
@@ -209,6 +209,7 @@ class ProfilerType(str, Enum):
     "deepseek3-test",
     "deepseek3-tiny",
     "deepseek3.2-671b",
+    "deepseek-custom",
     "kimi-k2-1t",
     "gemma-7b",
     "gemma-2b",
@@ -1048,6 +1049,16 @@ class TrainingLoop(BaseModel):
   init_weights_seed: int = Field(0, description="Seed for model weight initialization.")


+class ManifoldConstrainedHyperConnections(BaseModel):
+  """Configuration for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""
+
+  mhc_expansion_rate: int = Field(0, description="The number of parallel streams in Hyper Connections.")
+  mhc_res_alpha_scale: float = Field(0.01, description="The scale for the residual mapping.")
+  mhc_pre_alpha_scale: float = Field(0.01, description="The scale for the pre mapping.")
+  mhc_post_alpha_scale: float = Field(0.01, description="The scale for the post mapping.")
+  sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
+
+
 class Optimizer(BaseModel):
   """Configuration for the optimizer and learning rate schedule."""

@@ -1727,6 +1738,7 @@ class MaxTextConfig(
     # Training, Optimization, and Fine-Tuning
     RematAndOffload,
     TrainingLoop,
+    ManifoldConstrainedHyperConnections,
     Optimizer,
     AdamW,
     Muon,
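For illustration, a hedged standalone sketch of what the new fields accept; MHCSettings below is a hypothetical stand-in (in the commit these fields are mixed into MaxTextConfig via the class above):

from pydantic import BaseModel, Field, PositiveInt, ValidationError

class MHCSettings(BaseModel):  # stand-in, not the MaxText class
  mhc_expansion_rate: int = Field(0, description="The number of parallel streams in Hyper Connections.")
  mhc_res_alpha_scale: float = 0.01
  mhc_pre_alpha_scale: float = 0.01
  mhc_post_alpha_scale: float = 0.01
  sinkhorn_iterations: PositiveInt = 20

print(MHCSettings())                      # defaults leave mHC disabled (expansion rate 0)
print(MHCSettings(mhc_expansion_rate=4))  # enable mHC with 4 parallel streams
try:
  MHCSettings(sinkhorn_iterations=0)      # PositiveInt rejects non-positive values
except ValidationError as err:
  print(err)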

src/MaxText/layers/mhc.py

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""

import jax
from jax.sharding import Mesh

import jax.numpy as jnp
from flax import nnx
from typing import Callable
from MaxText.common_types import Config
from MaxText.layers.normalizations import RMSNorm
from MaxText.layers.initializers import nd_dense_init, default_bias_init
from MaxText.common_types import HyperConnectionType


def get_functions(expansion_rate: int):
  """
  Creates functions to broadcast a single feature stream into multiple
  parallel paths (expand) and aggregate them back (reduce).
  """

  def expand(x: jnp.ndarray):
    # (batch, length, dim) -> (streams, batch, length, dim)
    return jnp.repeat(jnp.expand_dims(x, axis=0), expansion_rate, axis=0)

  def reduce(x: jnp.ndarray):
    # (streams, batch, length, dim) -> (batch, length, dim)
    return jnp.sum(x, axis=0)

  return expand, reduce


def sinkhorn(t, iters=20):
  """
  Computes the Sinkhorn normalization of a matrix (rows and columns sum to 1).
  """
  # Use float32 precision for numerical stability during normalization
  initial_dtype = t.dtype
  t = t.astype(jnp.float32)

  # Initial softmax over axis -2: makes every entry positive and each column sum to 1
  t = jax.nn.softmax(t, axis=-2)

  def body_fun(i, val):
    # L1 normalization: val / sum(val), with the denominator clipped away from zero
    # Normalize rows (axis -1)
    val = val / jnp.clip(jnp.sum(val, axis=-1, keepdims=True), min=1e-12)
    # Normalize columns (axis -2)
    val = val / jnp.clip(jnp.sum(val, axis=-2, keepdims=True), min=1e-12)
    return val

  # Use lax.fori_loop for an efficient, JIT-friendly loop
  t = jax.lax.fori_loop(0, iters, body_fun, t)
  return t.astype(initial_dtype)


class ManifoldConstrainedHyperConnections(nnx.Module):
  """Implements Manifold-Constrained Hyper-Connections (mHC).

  Reference: https://arxiv.org/pdf/2512.24880

  Args:
    config: Configuration object containing hyperparameters.
    model_mode: String indicating the execution context.
    dim: The feature dimensionality.
    mesh: The hardware mesh for sharding.
    rngs: NNX random number generator streams.
  """

  def __init__(
      self,
      config: Config,
      model_mode: str,
      dim: int,
      mesh: Mesh,
      rngs: nnx.Rngs,
  ):
    self.config = config
    self.sinkhorn_iterations = config.sinkhorn_iterations
    self.k = config.mhc_expansion_rate
    self.dim = dim
    self.rngs = rngs
    self.mesh = mesh
    self.weight_dtype = self.config.weight_dtype

    # Norm layer
    self.mhc_norm = RMSNorm(
        num_features=self.dim,
        dtype=self.config.dtype,
        weight_dtype=self.weight_dtype,
        kernel_axes=("norm",),
        epsilon=self.config.normalization_layer_epsilon,
        rngs=self.rngs,
    )

    # Scalers
    self.mhc_res_alpha_scale = self.config.mhc_res_alpha_scale
    self.mhc_pre_alpha_scale = self.config.mhc_pre_alpha_scale
    self.mhc_post_alpha_scale = self.config.mhc_post_alpha_scale

    # Weight matrices
    scale_init = nd_dense_init(1.0, "fan_in", "normal")
    in_axis = (0, 1)
    out_axis = 2
    weight_sharding_axis_name = (None, "activation_embed", None)
    self.res_alpha = nnx.Param(
        scale_init(
            self.rngs.params(),
            (self.k, self.dim, self.k * self.k),
            self.weight_dtype,
            in_axis=in_axis,
            out_axis=out_axis,
        ),
        sharding=weight_sharding_axis_name,
    )
    self.pre_alpha = nnx.Param(
        scale_init(
            self.rngs.params(),
            (self.k, self.dim, self.k),
            self.weight_dtype,
            in_axis=in_axis,
            out_axis=out_axis,
        ),
        sharding=weight_sharding_axis_name,
    )
    self.post_alpha = nnx.Param(
        scale_init(
            self.rngs.params(),
            (self.k, self.dim, self.k),
            self.weight_dtype,
            in_axis=in_axis,
            out_axis=out_axis,
        ),
        sharding=weight_sharding_axis_name,
    )

    # Biases
    self.res_beta = nnx.Param(
        default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype),
        sharding=(None, None),
    )
    self.pre_beta = nnx.Param(
        default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
        sharding=(None, None),
    )
    self.post_beta = nnx.Param(
        default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
        sharding=(None, None),
    )

  def res_mapping(self, x: jnp.ndarray):
    """Helper function for the residual mapping."""
    # Apply projection: (k, b, s, d) @ (k, d, k*k) -> (k*k)
    h_res = jnp.einsum("kbsd,kdm -> m", x, self.res_alpha[...])
    h_res = jnp.reshape(h_res, (self.k, self.k))
    intermediate = self.mhc_res_alpha_scale * h_res + self.res_beta[...]
    output = sinkhorn(intermediate, self.sinkhorn_iterations)
    return output

  def mapping(self, x: jnp.ndarray, alpha_scale: jnp.ndarray, alpha: jnp.ndarray, beta: jnp.ndarray, scale: float):
    """Helper function for both the pre and post mappings."""
    # Apply projection: (k, b, s, d) @ (k, d, k) -> (k)
    h = jnp.einsum("kbsd,kdm -> m", x, alpha)
    intermediate = alpha_scale * h + beta
    output = scale * jax.nn.sigmoid(intermediate)
    return output

  def __call__(
      self,
      branch_fn: Callable,
      x: jnp.ndarray,
      mhc_type: HyperConnectionType,
      **kwargs,
  ) -> jnp.ndarray:
    """Applies the manifold-constrained hyper-connection around a callable branch function.

    Args:
      branch_fn: The function to be wrapped by the hyper-connection.
      x: Input tensor of shape `(expansion_rate, batch, seq, emb)`.
      mhc_type: The variant of the connection to apply.
      **kwargs: Additional context passed to the branch function.

    Returns:
      The processed tensor, maintaining the shape of `x`.
    """
    # x shape: [expansion_rate, batch, seq, emb]
    # 1. RMS normalization
    x = self.mhc_norm(x)

    # 2. Pre mapping
    pre_mapping = self.mapping(x, self.mhc_pre_alpha_scale, self.pre_alpha[...], self.pre_beta[...], 1.0)
    layer_input = jnp.einsum("kbsd,k -> bsd", x, pre_mapping)

    # 3. Attention or MLP
    if mhc_type == HyperConnectionType.ATTENTION:
      layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs)
    elif mhc_type == HyperConnectionType.MLP_DENSE:
      layer_out = branch_fn(inputs=layer_input, **kwargs)
    elif mhc_type == HyperConnectionType.MLP_MOE:
      layer_out, _, _ = branch_fn(inputs=layer_input, **kwargs)
    else:
      raise ValueError(f"Unsupported type: {mhc_type}")

    # 4. Post mapping
    post_mapping = self.mapping(x, self.mhc_post_alpha_scale, self.post_alpha[...], self.post_beta[...], 2.0)
    post_out = jnp.einsum("bsd,k -> kbsd", layer_out, post_mapping)

    # 5. Residual mapping, res_out shape as [expansion_rate, batch, seq, emb]
    res_mapping = self.res_mapping(x)
    res_out = jnp.einsum("kbsd,km -> mbsd", x, res_mapping)
    return res_out + post_out
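For orientation, a self-contained numeric sketch of the data flow in __call__ above (a hedged illustration, not MaxText code: the parameters are random stand-ins, the branch is a toy function in place of a real attention/MLP, and RMSNorm and the beta biases are omitted):

import jax
import jax.numpy as jnp

n, b, s, d = 4, 2, 8, 16  # streams (mhc_expansion_rate), batch, seq, emb
k0, k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 4)

x = jnp.repeat(jax.random.normal(k0, (1, b, s, d)), n, axis=0)  # expand: n identical streams
res_alpha = 0.01 * jax.random.normal(k1, (n, d, n * n))
pre_alpha = 0.01 * jax.random.normal(k2, (n, d, n))
post_alpha = 0.01 * jax.random.normal(k3, (n, d, n))

def sinkhorn(t, iters=20):
  # Same normalization as the helper above: positive entries, rows and columns sum to ~1.
  t = jax.nn.softmax(t, axis=-2)
  for _ in range(iters):
    t = t / jnp.sum(t, axis=-1, keepdims=True)
    t = t / jnp.sum(t, axis=-2, keepdims=True)
  return t

pre = jax.nn.sigmoid(0.01 * jnp.einsum("kbsd,kdm->m", x, pre_alpha))          # (n,) stream-mixing weights
layer_input = jnp.einsum("kbsd,k->bsd", x, pre)                               # collapse streams for the branch
layer_out = layer_input * 2.0                                                 # toy branch_fn
post = 2.0 * jax.nn.sigmoid(0.01 * jnp.einsum("kbsd,kdm->m", x, post_alpha))  # (n,) redistribution weights
post_out = jnp.einsum("bsd,k->kbsd", layer_out, post)
res = sinkhorn(jnp.reshape(0.01 * jnp.einsum("kbsd,kdm->m", x, res_alpha), (n, n)))  # doubly-stochastic mix
res_out = jnp.einsum("kbsd,km->mbsd", x, res)
out = res_out + post_out  # same (n, b, s, d) shape as x
print(out.shape, jnp.sum(res, axis=0), jnp.sum(res, axis=1))

In a decoder block, get_functions(expansion_rate) would supply the expand/reduce pair that brackets a stack of such wrapped branches: expand once into (n, batch, seq, emb) streams, pass them through the mHC-wrapped attention and MLP calls, and reduce back to (batch, seq, emb) at the end.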
