6 changes: 6 additions & 0 deletions src/MaxText/common_types.py
@@ -114,3 +114,9 @@ class AttentionType(enum.Enum):
class ShardMode(enum.Enum):
AUTO = "auto" # default
EXPLICIT = "explicit"


class HyperConnectionType(enum.Enum):
ATTENTION = "attention"
MLP_MOE = "mlp_moe"
MLP_DENSE = "mlp_dense"
6 changes: 6 additions & 0 deletions src/MaxText/configs/base.yml
@@ -1065,3 +1065,9 @@ vllm_hf_config_path: ""
vllm_additional_config: {}
# When use_jax_splash=True, force the layout of the query tensor to be [..., NUM_HEADS, HEAD_DIM, SEQ_LENGTH]
force_q_layout: false

################################## DeepSeek Manifold-Constrained Hyper Connections (mHC) ##################################
# The number of parallel streams in Hyper Connection.
mhc_expansion_rate: 0
# The number of iterations for the Sinkhorn-Knopp algorithm.
sinkhorn_iterations: 20
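
For illustration (not part of the diff), a minimal JAX sketch of what `mhc_expansion_rate` controls: the hidden state is broadcast into that many parallel residual streams and later summed back, mirroring the `expand`/`reduce` helpers added in src/MaxText/layers/mhc.py below.

# Sketch only: illustrates the stream expansion controlled by mhc_expansion_rate.
import jax.numpy as jnp

mhc_expansion_rate = 4  # the value deepseek-custom.yml uses below
x = jnp.ones((2, 8, 1024))  # (batch, length, dim)

# Expand: (batch, length, dim) -> (batch, length, streams, dim)
streams = jnp.repeat(jnp.expand_dims(x, axis=2), mhc_expansion_rate, axis=2)
assert streams.shape == (2, 8, 4, 1024)

# Reduce: sum the parallel streams back to (batch, length, dim)
assert jnp.sum(streams, axis=2).shape == x.shape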
61 changes: 61 additions & 0 deletions src/MaxText/configs/models/deepseek-custom.yml
@@ -0,0 +1,61 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Small model config for testing (derived from DeepSeek V3.2 - 671B)

base_emb_dim: 1024 # Reduced from 7168
base_num_query_heads: 16 # Reduced from 128
base_num_kv_heads: 16 # Reduced from 128
base_mlp_dim: 2048 # Reduced from 18432
base_moe_mlp_dim: 512 # Reduced from 2048
base_num_decoder_layers: 6 # Reduced from 61
first_num_dense_layers: 1 # Reduced from 3
mlp_activations: ["silu","linear"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 16 # Reduced from 256
num_experts_per_tok: 2 # Reduced from 8
shared_experts: 1
routed_scaling_factor: 2.5
routed_score_func: "sigmoid"
routed_bias: True
decoder_block: "deepseek"
# MLA
attention_type: "mla"
q_lora_rank: 384 # Reduced from 1536
kv_lora_rank: 128 # Reduced from 512
qk_nope_head_dim: 32 # Reduced from 128
qk_rope_head_dim: 16 # Reduced from 64
v_head_dim: 128
# RoPE
mscale: 1.0
rope_type: "yarn"
rope_max_timescale: 10_000
max_position_embeddings: 4096 # Reduced for local testing
original_max_position_embeddings: 4096
rope_factor: 1
beta_fast: 32
rope_interleave: True
rope_truncate: True
rope_attention_scaling: False
# Indexer for DeepSeek Sparse Attention
use_sparse_indexer: True
index_n_heads: 16 # Reduced from 64
index_head_dim: 64 # Reduced from 128
index_topk: 256 # Reduced from 2048
# Hyper-connections: mHC enabled
mhc_expansion_rate: 4
sinkhorn_iterations: 20
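
A small worked example of the mHC parameter shapes this config implies, assuming the layer is constructed with `dim` equal to `base_emb_dim` (a sketch only; the shapes follow `ManifoldConstrainedHyperConnections.__init__` in src/MaxText/layers/mhc.py below).

# Sketch only: parameter shapes for mhc_expansion_rate=4 and base_emb_dim=1024.
k, d = 4, 1024
res_alpha_shape = (k * d, k * k)   # (4096, 16): flattened streams -> k x k residual-mixing logits
pre_alpha_shape = (k * d, k)       # (4096, 4): flattened streams -> per-stream input gate
post_alpha_shape = (k * d, k)      # (4096, 4): flattened streams -> per-stream output gate
res_beta_shape = (k, k)            # bias added to the residual-mixing logits
pre_beta_shape = post_beta_shape = (k,)  # biases for the input/output gates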
9 changes: 9 additions & 0 deletions src/MaxText/configs/types.py
@@ -210,6 +210,7 @@ class ProfilerType(str, Enum):
"deepseek3-test",
"deepseek3-tiny",
"deepseek3.2-671b",
"deepseek-custom",
"kimi-k2-1t",
"gemma-7b",
"gemma-2b",
@@ -1057,6 +1058,13 @@ class TrainingLoop(BaseModel):
init_weights_seed: int = Field(0, description="Seed for model weight initialization.")


class ManifoldConstrainedHyperConnections(BaseModel):
"""Configuration for DeepSeek Manifold-Constrained Hyper Connections (mHC)."""

mhc_expansion_rate: int = Field(0, description="The number of parallel streams in Hyper Connection.")
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")


class Optimizer(BaseModel):
"""Configuration for the optimizer and learning rate schedule."""

@@ -1743,6 +1751,7 @@ class MaxTextConfig(
# Training, Optimization, and Fine-Tuning
RematAndOffload,
TrainingLoop,
ManifoldConstrainedHyperConnections,
Optimizer,
AdamW,
Muon,
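
A standalone pydantic sketch (re-declaring the new sub-model for illustration rather than importing MaxText) showing the defaults and the `PositiveInt` constraint on `sinkhorn_iterations`:

# Sketch only: mirrors the new ManifoldConstrainedHyperConnections config model.
from pydantic import BaseModel, Field, PositiveInt, ValidationError

class ManifoldConstrainedHyperConnections(BaseModel):
    mhc_expansion_rate: int = Field(0, description="The number of parallel streams in Hyper Connection.")
    sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")

print(ManifoldConstrainedHyperConnections())                      # defaults: mhc_expansion_rate=0, sinkhorn_iterations=20
print(ManifoldConstrainedHyperConnections(mhc_expansion_rate=4))  # as set by deepseek-custom.yml

try:
    ManifoldConstrainedHyperConnections(sinkhorn_iterations=0)    # rejected: PositiveInt requires > 0
except ValidationError as err:
    print(err)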
1 change: 1 addition & 0 deletions src/MaxText/layers/initializers.py
@@ -31,6 +31,7 @@
default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)

default_bias_init = jax.nn.initializers.constant(0.0)
default_scalar_init = jax.nn.initializers.constant(0.01)


def nd_dense_init(scale, mode, distribution):
235 changes: 235 additions & 0 deletions src/MaxText/layers/mhc.py
@@ -0,0 +1,235 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""

import jax
from jax.sharding import Mesh

import jax.numpy as jnp
from flax import nnx
from typing import Callable
from MaxText.common_types import Config, Array
from MaxText.layers.normalizations import RMSNorm
from MaxText.layers.initializers import nd_dense_init, default_bias_init, default_scalar_init
from MaxText.common_types import HyperConnectionType


def get_functions(expansion_rate: int):
"""
Creates functions to broadcast a single feature stream into multiple
parallel paths (expand) and aggregate them back (reduce).
"""

def expand(x: Array):
# (batch, length, dim) -> (batch, length, streams, dim)
return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2)

def reduce(x: Array):
# (batch, length, streams, dim) -> (batch, length, dim)
return jnp.sum(x, axis=2)

return expand, reduce


def sinkhorn(t, iters=20):
"""
Computes the Sinkhorn normalization of a matrix (rows and columns sum to 1).
"""
# Use float32 precision for numerical stability during normalization
initial_dtype = t.dtype
t = t.astype(jnp.float32)

# Column-wise softmax (axis=-2): entries become positive and each column sums to 1
# Equivalent to t = exp(t) / jnp.sum(jnp.exp(t), axis=-2)
t = jax.nn.softmax(t, axis=-2)

def body_fun(i, val):
# L1 Normalization: val / sum(val) with clipping of denominator
# Normalize rows (axis -1)
val = val / jnp.clip(jnp.sum(val, axis=-1, keepdims=True), min=1e-12)
# Normalize columns (axis -2)
val = val / jnp.clip(jnp.sum(val, axis=-2, keepdims=True), min=1e-12)
return val

# Use lax.fori_loop for an efficient, JIT-friendly loop
t = jax.lax.fori_loop(0, iters, body_fun, t)
return t.astype(initial_dtype)


class ManifoldConstrainedHyperConnections(nnx.Module):
"""Implements Manifold-Constrained Hyper-Connections (mHC).

Reference: https://arxiv.org/pdf/2512.24880

Args:
config: Configuration object containing hyperparameters.
dim: The feature dimensionality.
mesh: The hardware mesh for sharding.
rngs: NNX random number generators used for parameter initialization.
"""

def __init__(
self,
config: Config,
dim: int,
mesh: Mesh,
rngs: nnx.Rngs,
):
self.config = config
self.sinkhorn_iterations = config.sinkhorn_iterations
self.k = config.mhc_expansion_rate
self.dim = dim
self.rngs = rngs
self.mesh = mesh
self.weight_dtype = self.config.weight_dtype

# Norm layer
self.mhc_norm = RMSNorm(
num_features=self.k * self.dim,
dtype=self.config.dtype,
weight_dtype=self.weight_dtype,
kernel_axes=("norm",),
epsilon=self.config.normalization_layer_epsilon,
rngs=self.rngs,
)

# Scalars
self.res_alpha_scale = nnx.Param(
default_scalar_init(self.rngs.params(), (1,), self.weight_dtype),
sharding=(None,),
)
self.pre_alpha_scale = nnx.Param(
default_scalar_init(self.rngs.params(), (1,), self.weight_dtype),
sharding=(None,),
)
self.post_alpha_scale = nnx.Param(
default_scalar_init(self.rngs.params(), (1,), self.weight_dtype),
sharding=(None,),
)

# Weight matrices
scale_init = nd_dense_init(1.0, "fan_in", "normal")
in_axis = 0
out_axis = 1
weight_sharding_axis_name = ("activation_embed", None)
self.res_alpha = nnx.Param(
scale_init(
self.rngs.params(),
(self.k * self.dim, self.k * self.k),
self.weight_dtype,
in_axis=in_axis,
out_axis=out_axis,
),
sharding=weight_sharding_axis_name,
)
self.pre_alpha = nnx.Param(
scale_init(
self.rngs.params(),
(self.k * self.dim, self.k),
self.weight_dtype,
in_axis=in_axis,
out_axis=out_axis,
),
sharding=weight_sharding_axis_name,
)
self.post_alpha = nnx.Param(
scale_init(
self.rngs.params(),
(self.k * self.dim, self.k),
self.weight_dtype,
in_axis=in_axis,
out_axis=out_axis,
),
sharding=weight_sharding_axis_name,
)

# Biases
self.res_beta = nnx.Param(
default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype),
sharding=(None, None),
)
self.pre_beta = nnx.Param(
default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
sharding=(None,),
)
self.post_beta = nnx.Param(
default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
sharding=(None,),
)

def res_mapping(self, x: Array):
"""Helper function for residual mapping."""
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
h_res = jnp.einsum("bsm,mn -> bsn", x, self.res_alpha[...], precision=self.config.matmul_precision)
b, s, _ = h_res.shape
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
intermediate = self.res_alpha_scale * h_res + self.res_beta[...][None, None, :, :]
output = sinkhorn(intermediate, self.sinkhorn_iterations)
return output

def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: float):
"""Helper function for both pre and post mappings."""
# Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k)
h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.config.matmul_precision)
intermediate = alpha_scale * h + beta[None, None, :]
output = scale * jax.nn.sigmoid(intermediate)
return output

def __call__(
self,
branch_fn: Callable,
x: Array,
mhc_type: HyperConnectionType,
**kwargs,
) -> Array:
"""Applying manifold-constrained hyper connection based on callable function.

Args:
branch_fn: The function to be wrapped by the hyper-connection.
x: Input tensor of shape `(batch, seq, expansion_rate, dim)`.
mhc_type: The variant of the connection to apply.
**kwargs: Additional context passed to the branch function.

Returns:
The processed tensor, maintaining the shape of `x`.
"""
# x shape: [batch, seq, expansion_rate, emb]
b, s, k, d = x.shape

# 1. Flatten the tensor, and RMS normalization
norm_x = self.mhc_norm(jnp.reshape(x, (b, s, k * d)))

# 2. Pre mapping
pre_mapping = self.mapping(norm_x, self.pre_alpha_scale, self.pre_alpha[...], self.pre_beta[...], 1.0)
layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.config.matmul_precision)

# 3. Attention or MLP
if mhc_type == HyperConnectionType.ATTENTION:
layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs)
elif mhc_type == HyperConnectionType.MLP_DENSE:
layer_out = branch_fn(inputs=layer_input, **kwargs)
elif mhc_type == HyperConnectionType.MLP_MOE:
layer_out, _, _ = branch_fn(inputs=layer_input, **kwargs)
else:
raise ValueError(f"Unsupported type: {mhc_type}")

# 4. Post mapping
post_mapping = self.mapping(norm_x, self.post_alpha_scale, self.post_alpha[...], self.post_beta[...], 2.0)
post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.config.matmul_precision)

# 5. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb]
res_mapping = self.res_mapping(norm_x)
res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.config.matmul_precision)
return res_out + post_out
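
A standalone sketch (illustration only, re-implementing the `sinkhorn` helper above in plain JAX) showing that the result is approximately doubly stochastic: every row and every column sums to 1 after the alternating normalizations.

# Sketch only: verifies the doubly-stochastic property of the Sinkhorn normalization above.
import jax
import jax.numpy as jnp

def sinkhorn_sketch(t, iters=20):
    t = jax.nn.softmax(t.astype(jnp.float32), axis=-2)  # columns: positive, sum to 1
    def body_fun(_, val):
        val = val / jnp.clip(jnp.sum(val, axis=-1, keepdims=True), min=1e-12)  # rows sum to 1
        val = val / jnp.clip(jnp.sum(val, axis=-2, keepdims=True), min=1e-12)  # columns sum to 1
        return val
    return jax.lax.fori_loop(0, iters, body_fun, t)

m = sinkhorn_sketch(jax.random.normal(jax.random.PRNGKey(0), (4, 4)))
print(jnp.sum(m, axis=-1))  # row sums, each ~1.0
print(jnp.sum(m, axis=-2))  # column sums, each ~1.0

This is the property `res_mapping` relies on when mixing the k residual streams per token.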