-
Notifications
You must be signed in to change notification settings - Fork 462
DeepSeek MHC feature #3065
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
DeepSeek MHC feature #3065
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| # Copyright 2023–2026 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # Small model config for testing (derived from DeepSeek V3.2 - 671B) | ||
|
|
||
| base_emb_dim: 1024 # Reduced from 7168 | ||
| base_num_query_heads: 16 # Reduced from 128 | ||
| base_num_kv_heads: 16 # Reduced from 128 | ||
| base_mlp_dim: 2048 # Reduced from 18432 | ||
| base_moe_mlp_dim: 512 # Reduced from 2048 | ||
| base_num_decoder_layers: 6 # Reduced from 61 | ||
| first_num_dense_layers: 1 # Reduced from 3 | ||
| mlp_activations: ["silu","linear"] | ||
| vocab_size: 129280 | ||
| enable_dropout: False | ||
| logits_via_embedding: False | ||
| normalization_layer_epsilon: 1.0e-6 | ||
| num_experts: 16 # Reduced from 256 | ||
| num_experts_per_tok: 2 # Reduced from 8 | ||
| shared_experts: 1 | ||
| routed_scaling_factor: 2.5 | ||
| routed_score_func: "sigmoid" | ||
| routed_bias: True | ||
| decoder_block: "deepseek" | ||
| # MLA | ||
| attention_type: "mla" | ||
| q_lora_rank: 384 # Reduced from 1536 | ||
| kv_lora_rank: 128 # Reduced from 512 | ||
| qk_nope_head_dim: 32 # Reduced from 128 | ||
| qk_rope_head_dim: 16 # Reduced from 64 | ||
| v_head_dim: 128 | ||
| # RoPE | ||
| mscale: 1.0 | ||
| rope_type: "yarn" | ||
| rope_max_timescale: 10_000 | ||
| max_position_embeddings: 4096 # Reduced for local testing | ||
| original_max_position_embeddings: 4096 | ||
| rope_factor: 1 | ||
| beta_fast: 32 | ||
| rope_interleave: True | ||
| rope_truncate: True | ||
| rope_attention_scaling: False | ||
| # Indexer for DeepSeek Sparse Attention | ||
| use_sparse_indexer: True | ||
| index_n_heads: 16 # Reduced from 64 | ||
| index_head_dim: 64 # Reduced from 128 | ||
| index_topk: 256 # Reduced from 2048 | ||
| # Hyper-connections: mHC enabled | ||
| mhc_expansion_rate: 4 | ||
| sinkhorn_iterations: 20 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,235 @@ | ||
| # Copyright 2023–2026 Google LLC | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer.""" | ||
|
|
||
| import jax | ||
| from jax.sharding import Mesh | ||
|
|
||
| import jax.numpy as jnp | ||
| from flax import nnx | ||
| from typing import Callable | ||
| from MaxText.common_types import Config, Array | ||
| from MaxText.layers.normalizations import RMSNorm | ||
| from MaxText.layers.initializers import nd_dense_init, default_bias_init, default_scalar_init | ||
| from MaxText.common_types import HyperConnectionType | ||
|
|
||
|
|
||
| def get_functions(expansion_rate: int): | ||
| """ | ||
| Creates functions to broadcast a single feature stream into multiple | ||
| parallel paths (expand) and aggregate them back (reduce). | ||
| """ | ||
|
|
||
| def expand(x: Array): | ||
| # (batch, length, dim) -> (batch, length, streams, dim) | ||
| return jnp.repeat(jnp.expand_dims(x, axis=2), expansion_rate, axis=2) | ||
|
|
||
| def reduce(x: Array): | ||
| # (batch, length, streams, dim) -> (batch, length, dim) | ||
| return jnp.sum(x, axis=2) | ||
|
|
||
| return expand, reduce | ||
|
|
||
|
|
||
| def sinkhorn(t, iters=20): | ||
| """ | ||
| Computes the Sinkhorn normalization of a matrix (rows and columns sum to 1). | ||
| """ | ||
| # Use float32 precision for numerical stability during normalization | ||
| initial_dtype = t.dtype | ||
| t = t.astype(jnp.float32) | ||
|
|
||
| # Column-wise normalization (axis=-2) - positive and sum up to 1 across columns | ||
| # Equivalent to t = exp(t) / jnp.sum(jnp.exp(t), axis=-2) | ||
| t = jax.nn.softmax(t, axis=-2) | ||
|
|
||
| def body_fun(i, val): | ||
| # L1 Normalization: val / sum(val) with clipping of denominator | ||
| # Normalize rows (axis -1) | ||
| val = val / jnp.clip(jnp.sum(val, axis=-1, keepdims=True), min=1e-12) | ||
| # Normalize columns (axis -2) | ||
| val = val / jnp.clip(jnp.sum(val, axis=-2, keepdims=True), min=1e-12) | ||
| return val | ||
|
|
||
| # Use lax.fori_loop for an efficient, JIT-friendly loop | ||
| t = jax.lax.fori_loop(0, iters, body_fun, t) | ||
| return t.astype(initial_dtype) | ||
|
|
||
|
|
||
| class ManifoldConstrainedHyperConnections(nnx.Module): | ||
| """Implements Manifold-Constrained Hyper-Connections (mHC). | ||
|
|
||
| Reference: https://arxiv.org/pdf/2512.24880 | ||
|
|
||
| Args: | ||
| config: Configuration object containing hyperparameters. | ||
| dim: The feature dimensionality. | ||
| mesh: The hardware mesh for sharding. | ||
| rngs: Random number generation in NNX. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| config: Config, | ||
| dim: int, | ||
| mesh: Mesh, | ||
| rngs: nnx.Rngs, | ||
| ): | ||
| self.config = config | ||
| self.sinkhorn_iterations = config.sinkhorn_iterations | ||
| self.k = config.mhc_expansion_rate | ||
| self.dim = dim | ||
| self.rngs = rngs | ||
| self.mesh = mesh | ||
| self.weight_dtype = self.config.weight_dtype | ||
|
|
||
| # Norm layer | ||
| self.mhc_norm = RMSNorm( | ||
| num_features=self.k * self.dim, | ||
| dtype=self.config.dtype, | ||
| weight_dtype=self.weight_dtype, | ||
RissyRan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| kernel_axes=("norm",), | ||
| epsilon=self.config.normalization_layer_epsilon, | ||
| rngs=self.rngs, | ||
| ) | ||
|
|
||
| # Scalars | ||
| self.res_alpha_scale = nnx.Param( | ||
| default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), | ||
| sharding=(None,), | ||
| ) | ||
| self.pre_alpha_scale = nnx.Param( | ||
| default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), | ||
| sharding=(None,), | ||
| ) | ||
| self.post_alpha_scale = nnx.Param( | ||
| default_scalar_init(self.rngs.params(), (1,), self.weight_dtype), | ||
| sharding=(None,), | ||
| ) | ||
|
|
||
| # Weight matrices | ||
| scale_init = nd_dense_init(1.0, "fan_in", "normal") | ||
| in_axis = 0 | ||
| out_axis = 1 | ||
| weight_sharding_axis_name = ("activation_embed", None) | ||
| self.res_alpha = nnx.Param( | ||
RissyRan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| scale_init( | ||
| self.rngs.params(), | ||
| (self.k * self.dim, self.k * self.k), | ||
| self.weight_dtype, | ||
| in_axis=in_axis, | ||
| out_axis=out_axis, | ||
| ), | ||
| sharding=weight_sharding_axis_name, | ||
| ) | ||
| self.pre_alpha = nnx.Param( | ||
| scale_init( | ||
| self.rngs.params(), | ||
| (self.k * self.dim, self.k), | ||
| self.weight_dtype, | ||
| in_axis=in_axis, | ||
| out_axis=out_axis, | ||
| ), | ||
| sharding=weight_sharding_axis_name, | ||
| ) | ||
| self.post_alpha = nnx.Param( | ||
| scale_init( | ||
| self.rngs.params(), | ||
| (self.k * self.dim, self.k), | ||
| self.weight_dtype, | ||
| in_axis=in_axis, | ||
| out_axis=out_axis, | ||
| ), | ||
| sharding=weight_sharding_axis_name, | ||
| ) | ||
|
|
||
| # Biases | ||
| self.res_beta = nnx.Param( | ||
| default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype), | ||
| sharding=(None, None), | ||
| ) | ||
| self.pre_beta = nnx.Param( | ||
| default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype), | ||
| sharding=(None, None), | ||
| ) | ||
| self.post_beta = nnx.Param( | ||
| default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype), | ||
| sharding=(None, None), | ||
| ) | ||
|
|
||
| def res_mapping(self, x: Array): | ||
| """Helper function for residual mapping.""" | ||
| # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k) | ||
| h_res = jnp.einsum("bsm,mn -> bsn", x, self.res_alpha[...], precision=self.config.matmul_precision) | ||
| b, s, _ = h_res.shape | ||
| h_res = jnp.reshape(h_res, (b, s, self.k, self.k)) | ||
| intermediate = self.res_alpha_scale * h_res + self.res_beta[...][None, None, :, :] | ||
| output = sinkhorn(intermediate, self.sinkhorn_iterations) | ||
| return output | ||
|
|
||
| def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int): | ||
| """Helper function for both pre and post mappings.""" | ||
| # Apply projection: (b, s, k*d) @ (k*d, k) -> (b, s, k) | ||
| h = jnp.einsum("bsm,mk -> bsk", x, alpha, precision=self.config.matmul_precision) | ||
| intermediate = alpha_scale * h + beta[None, None, :] | ||
| output = scale * jax.nn.sigmoid(intermediate) | ||
| return output | ||
|
|
||
| def __call__( | ||
| self, | ||
| branch_fn: Callable, | ||
| x: Array, | ||
| mhc_type: HyperConnectionType, | ||
| **kwargs, | ||
| ) -> Array: | ||
| """Applying manifold-constrained hyper connection based on callable function. | ||
|
|
||
| Args: | ||
| branch_fn: The function to be wrapped by the hyper-connection. | ||
| x: Input tensor of shape `(batch..., dim)`. | ||
| mhc_type: The variant of the connection to apply. | ||
| **kwargs: Additional context passed to the branch function. | ||
|
|
||
| Returns: | ||
| The processed tensor, maintaining the shape of `x`. | ||
| """ | ||
| # x shape: [batch, seq, expansion_rate, emb] | ||
| b, s, k, d = x.shape | ||
|
|
||
| # 1. Flatten the tensor, and RMS normalization | ||
| norm_x = self.mhc_norm(jnp.reshape(x, (b, s, k * d))) | ||
RissyRan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # 2. Pre mapping | ||
| pre_mapping = self.mapping(norm_x, self.pre_alpha_scale, self.pre_alpha[...], self.pre_beta[...], 1.0) | ||
| layer_input = jnp.einsum("bskd,bsk -> bsd", x, pre_mapping, precision=self.config.matmul_precision) | ||
|
|
||
| # 3. Attention or MLP | ||
| if mhc_type == HyperConnectionType.ATTENTION: | ||
| layer_out, _ = branch_fn(inputs_q=layer_input, inputs_kv=layer_input, **kwargs) | ||
| elif mhc_type == HyperConnectionType.MLP_DENSE: | ||
| layer_out = branch_fn(inputs=layer_input, **kwargs) | ||
| elif mhc_type == HyperConnectionType.MLP_MOE: | ||
| layer_out, _, _ = branch_fn(inputs=layer_input, **kwargs) | ||
| else: | ||
| raise ValueError(f"Unsupported type: {mhc_type}") | ||
|
|
||
| # 4. Post mapping | ||
| post_mapping = self.mapping(norm_x, self.post_alpha_scale, self.post_alpha[...], self.post_beta[...], 2.0) | ||
| post_out = jnp.einsum("bsd,bsk -> bskd", layer_out, post_mapping, precision=self.config.matmul_precision) | ||
|
|
||
| # 5. Residual mapping, res_out shape as [batch, seq, expansion_rate, emb] | ||
| res_mapping = self.res_mapping(norm_x) | ||
| res_out = jnp.einsum("bskd,bskm -> bsmd", x, res_mapping, precision=self.config.matmul_precision) | ||
| return res_out + post_out | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.