2 changes: 2 additions & 0 deletions src/maxtext/configs/base.yml
@@ -268,6 +268,7 @@ topk_routing_group: -1 # number of top groups to route inputs. For EP,
# all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers.
use_batch_split_schedule: False # whether to split the batch into micro-batches to hide communication, which can yield performance benefits.
batch_split_factor: 1 # the factor by which to split the batch. Only used if use_batch_split_schedule is True.
num_hash_layers: 3 # Number of initial MoE layers that use static Hash Routing.

# For complex architectures like llama4 there are repeated sets of
# inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope]
@@ -1227,6 +1228,7 @@ force_q_layout: false
mhc_expansion_rate: 1
# The number of iterations for the Sinkhorn-Knopp algorithm.
sinkhorn_iterations: 20
# The epsilon fallback value for numerical stability in mHC.
hc_eps: 1.0e-6

################################## DeepSeek Engram ##################################
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.
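The only functional addition to `base.yml` here is `num_hash_layers`, which reserves the first few MoE layers for static hash routing instead of the learned router. The router implementation itself is not part of this diff, so the snippet below is only a minimal sketch of what a static, token-id-based hash assignment could look like; the function name, hash constants, and shapes are assumptions, not the repository's API.

```python
import jax.numpy as jnp

def static_hash_route(token_ids, num_experts, experts_per_tok):
  """Illustrative static hash router: each token id deterministically picks
  `experts_per_tok` experts, with no learned gating involved.

  token_ids: [batch, seq] int32 token ids.
  Returns:   [batch, seq, experts_per_tok] int32 expert indices.
  """
  slots = jnp.arange(experts_per_tok, dtype=jnp.int32)       # [k]
  hashed = token_ids[..., None] * 1000003 + slots * 10007    # simple multiplicative hash, offset per slot
  return (hashed % num_experts).astype(jnp.int32)
```

Under that reading, layers 0 through num_hash_layers - 1 would route deterministically, while the remaining MoE layers keep the learned sqrtsoftplus router configured in the model files below.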
72 changes: 72 additions & 0 deletions src/maxtext/configs/models/deepseek_v4-flash.yml
@@ -0,0 +1,72 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Default model configs for DeepSeek-V4-Flash (43 Layers)

base_config: base.yml
model_name: deepseek_v4-flash

base_emb_dim: 4096
base_num_query_heads: 64
base_num_kv_heads: 1
head_dim: 512
base_mlp_dim: 2048
base_moe_mlp_dim: 2048
base_num_decoder_layers: 43
first_num_dense_layers: 0
mlp_activations: ["silu"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 256
num_experts_per_tok: 6
shared_experts: 1
routed_scaling_factor: 1.5
routed_score_func: "sqrtsoftplus"
routed_bias: True
norm_topk_prob: True
decoder_block: "deepseek_v4"
pure_nnx_decoder: True
enable_nnx: True

# Manifold-Constrained Hyper-Connection configurations
mhc_expansion_rate: 4
sinkhorn_iterations: 20
compress_rope_theta: 160000.0
index_head_dim: 128
index_n_heads: 64
index_topk: 512
o_groups: 8
o_lora_rank: 1024
sliding_window: 128
num_hash_layers: 3
mlp_activations_limit: 10.0
compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4]

# Compressed Sparse Attention
q_lora_rank: 1024
kv_lora_rank: 512
qk_nope_head_dim: 128
qk_rope_head_dim: 64
v_head_dim: 128
mscale: 1.0

# RoPE
rope_type: "default"
rope_max_timescale: 10_000
max_position_embeddings: 1048576
original_max_position_embeddings: 65536
rope_factor: 16
beta_fast: 32
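Several of these values are coupled to the attention changes later in the diff: `partial_rotary_factor` is now derived as `qk_rope_head_dim / head_dim`, and `compress_ratios` is a per-layer list. A quick sanity check with the numbers copied from this file makes those relationships explicit (plain Python, illustrative only, not part of the PR):

```python
# Values copied from deepseek_v4-flash.yml above; the checks are illustrative only.
head_dim = 512
qk_rope_head_dim = 64
qk_nope_head_dim = 128
base_num_decoder_layers = 43
compress_ratios = [0, 0] + [4, 128] * 20 + [4]  # compact form of the per-layer list above

assert len(compress_ratios) == base_num_decoder_layers    # one ratio per decoder layer
assert qk_rope_head_dim / head_dim == 0.125                # partial_rotary_factor used by the RoPE change below
assert qk_nope_head_dim + qk_rope_head_dim == 192          # combined no-RoPE + RoPE query/key channels per head
```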
72 changes: 72 additions & 0 deletions src/maxtext/configs/models/deepseek_v4-tiny.yml
@@ -0,0 +1,72 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny version of DeepSeek-V4 (6 Layers) for local sharding and compilation checks.

base_config: base.yml
model_name: deepseek_v4-tiny

base_emb_dim: 128
base_num_query_heads: 16
base_num_kv_heads: 1
head_dim: 32
base_mlp_dim: 128
base_moe_mlp_dim: 128
base_num_decoder_layers: 6
first_num_dense_layers: 0
mlp_activations: ["silu"]
vocab_size: 129280
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1.0e-6
num_experts: 8
num_experts_per_tok: 4
shared_experts: 1
routed_scaling_factor: 1.5
routed_score_func: "sqrtsoftplus"
routed_bias: True
norm_topk_prob: True
decoder_block: "deepseek_v4"
pure_nnx_decoder: True
enable_nnx: True

# Manifold-Constrained Hyper-Connection configurations
mhc_expansion_rate: 4
sinkhorn_iterations: 20
compress_rope_theta: 160000.0
index_head_dim: 32
index_n_heads: 16
index_topk: 64
o_groups: 2
o_lora_rank: 64
sliding_window: 32
num_hash_layers: 3
mlp_activations_limit: 10.0
compress_ratios: [0, 4, 128, 4, 128, 0]

# Compressed Sparse Attention
q_lora_rank: 64
kv_lora_rank: 32
qk_nope_head_dim: 32
qk_rope_head_dim: 16
v_head_dim: 128
mscale: 1.0

# RoPE
rope_type: "default"
rope_max_timescale: 10_000
max_position_embeddings: 163840
original_max_position_embeddings: 4096
rope_factor: 40
beta_fast: 32
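The tiny variant mirrors the same schema at toy sizes, so sharding and compilation can be exercised locally. Running the same illustrative check with the tiny values also shows why the attention code below stops hard-coding the rotary fraction as 64.0 / 512.0:

```python
# Values copied from deepseek_v4-tiny.yml above; the checks are illustrative only.
head_dim = 32
qk_rope_head_dim = 16
base_num_decoder_layers = 6
compress_ratios = [0, 4, 128, 4, 128, 0]

assert len(compress_ratios) == base_num_decoder_layers    # one compression ratio per decoder layer
assert qk_rope_head_dim / head_dim == 0.5                  # tiny rotary fraction, not the flash model's 0.125
```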
7 changes: 7 additions & 0 deletions src/maxtext/configs/types.py
@@ -226,6 +226,8 @@ class ProfilerType(str, Enum):
"deepseek3-test",
"deepseek3-tiny",
"deepseek3.2-671b",
"deepseek_v4-tiny",
"deepseek_v4-flash",
"deepseek-custom",
"kimi-k2-1t",
"gemma-7b",
@@ -831,6 +833,10 @@ class DeepSeekMoE(BaseModel):
1,
description="Factor by which to split the batch into micro-batches. Only used if use_batch_split_schedule is True.",
)
num_hash_layers: int = Field(
3,
description="Number of initial MoE layers to apply static Hash Routing.",
)


class Qwen3Next(BaseModel):
@@ -1381,6 +1387,7 @@ class ManifoldConstrainedHyperConnections(BaseModel):

mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
hc_eps: float = Field(1e-6, description="The epsilon fallback value for numerical stability in mHC.")


class DilocoParams(BaseModel):
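`hc_eps` is documented here as a numerical-stability fallback for the manifold-constrained Hyper-Connections, whose mixing weights are normalized by `sinkhorn_iterations` rounds of Sinkhorn-Knopp. The normalization code itself is outside this diff, so the snippet below is only a generic sketch of where such an epsilon typically enters Sinkhorn-Knopp; the function name and the exact placement of `eps` are assumptions.

```python
import jax.numpy as jnp

def sinkhorn_knopp(logits, num_iters=20, eps=1e-6):
  """Generic Sinkhorn-Knopp sketch: alternately rescale rows and columns of a
  non-negative matrix toward a doubly-stochastic one. `eps` keeps the divisions
  finite when a row or column sum underflows toward zero."""
  m = jnp.exp(logits)
  for _ in range(num_iters):
    m = m / (m.sum(axis=-1, keepdims=True) + eps)  # row normalization
    m = m / (m.sum(axis=-2, keepdims=True) + eps)  # column normalization
  return m
```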
51 changes: 38 additions & 13 deletions src/maxtext/layers/attention_compressed.py
@@ -105,7 +105,7 @@ def __init__(
# Interleaved rotary embeddings applied to the trailing slice
self.rotary_emb = DeepSeekV4RotaryEmbedding(
head_dim=head_dim,
partial_rotary_factor=64.0 / 512.0,
partial_rotary_factor=config.qk_rope_head_dim / config.head_dim,
rope_theta=rope_theta,
)

Expand Down Expand Up @@ -328,7 +328,7 @@ def __init__(
# Interleaved rotary embedding aligning query/key pos representations
self.rotary_emb = DeepSeekV4RotaryEmbedding(
head_dim=self.head_dim,
partial_rotary_factor=(config.head_dim * (64.0 / 512.0)) / self.head_dim,
partial_rotary_factor=config.qk_rope_head_dim / self.head_dim,
rope_theta=rope_theta,
)

@@ -582,7 +582,7 @@ def __init__(
# Interleaved rotary embeddings for compressed sequences
self.rotary_emb = DeepSeekV4RotaryEmbedding(
head_dim=head_dim,
partial_rotary_factor=64.0 / 512.0,
partial_rotary_factor=config.qk_rope_head_dim / config.head_dim,
rope_theta=rope_theta,
)

@@ -764,6 +764,7 @@ def __init__(
eps: float = 1e-6,
weight_dtype: Any = jnp.float32,
dtype: Any = jnp.float32,
attention_type: str = "compressed_sparse_attention",
*,
rngs: nnx.Rngs,
):
@@ -779,12 +780,13 @@
eps: Tiny additive variance limit for RMS normalization stability.
weight_dtype: The parameter weights numerical data type.
dtype: The mathematical execution numerical data type.
attention_type: The type of compressed attention being instantiated.
rngs: The Flax NNX random number generator collection.
"""
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.layer_type = config.layer_types[layer_idx]
self.attention_type = attention_type
self.num_heads = num_heads
self.head_dim = head_dim
self.sliding_window = config.sliding_window
@@ -858,7 +860,7 @@ def __init__(
self.sinks = nnx.Param(jax.nn.initializers.zeros(rngs.params(), (num_heads,), weight_dtype))

# Layer specific compressor allocation
if self.layer_type == "heavily_compressed_attention":
if self.attention_type == "heavily_compressed_attention":
self.compressor = HCACompressor(
hidden_size=hidden_size,
head_dim=head_dim,
@@ -869,7 +871,7 @@
dtype=dtype,
rngs=rngs,
)
elif self.layer_type == "compressed_sparse_attention":
elif self.attention_type == "compressed_sparse_attention":
self.compressor = CSACompressor(
hidden_size=hidden_size,
q_lora_rank=q_lora_rank,
@@ -884,14 +886,33 @@
else:
self.compressor = None

# Compute partial rotary factor dynamically from config to prevent dimension mismatches.
# DeepSeek-V4 pairs consecutive channels to apply partial RoPE on qk_rope_head_dim channels,
# requiring dynamic scaling: partial_rotary_factor = qk_rope_head_dim / head_dim.
self.partial_rotary_factor = self.config.qk_rope_head_dim / self.config.head_dim

self.rope_theta = (
self.config.rope_max_timescale if self.attention_type == "sliding_attention" else self.config.compress_rope_theta
)

# Local rotary embedding block matching standard MaxText (Gemma/Llama2) paradigms.
self.rotary_embedding = DeepSeekV4RotaryEmbedding(
head_dim=self.head_dim,
partial_rotary_factor=self.partial_rotary_factor,
rope_theta=self.rope_theta,
)

def __call__(
self,
hidden_states: jnp.ndarray,
cos: jnp.ndarray,
sin: jnp.ndarray,
position_ids: jnp.ndarray,
hidden_states: jnp.ndarray | None = None,
position_ids: jnp.ndarray | None = None,
attention_mask: jnp.ndarray | None = None,
inputs_q: jnp.ndarray | None = None,
inputs_kv: jnp.ndarray | None = None,
**kwargs,
) -> tuple[jnp.ndarray, jnp.ndarray]:
if hidden_states is None:
hidden_states = inputs_q
"""Executes DeepSeek-V4 compressed multi-head attention.

This method projects input states to query representations, applies low-rank
@@ -903,8 +924,6 @@ def __call__(

Args:
hidden_states: The input hidden representation sequence of shape [B, S, D_model].
cos: Positional RoPE cosine frequencies array of shape [B, S, D_rope].
sin: Positional RoPE sine frequencies array of shape [B, S, D_rope].
position_ids: Absolute sequence position identifiers of shape [B, S].
attention_mask: Optional attention mask of shape [B, 1, S, S_kv].

@@ -914,10 +933,16 @@ def __call__(
- The final multi-head attention weights of shape [B, H, S, S_kv].
"""
# hidden_states shape: [B, S, D_model]
# cos, sin shape: [B, S, D_rope]
# position_ids shape: [B, S]
# attention_mask shape: [B, 1, S_q, S_kv]
batch, seq_len, _ = hidden_states.shape
# Unconditionally compute RoPE positional frequency embeddings locally from position IDs.
if position_ids is None:
# [B, S] position sequence index grid broadcast
position_ids = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32)[None], (batch, seq_len))
# cos/sin shape: [B, S, qk_rope_head_dim / 2]
cos, sin = self.rotary_embedding(hidden_states, position_ids)

h_shape = (batch, seq_len, self.num_heads, self.head_dim)

# Project inputs to query representations
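The last hunks move RoPE inside the attention module: `__call__` no longer receives precomputed `cos`/`sin`, builds default `position_ids` when none are supplied, and drives its own `DeepSeekV4RotaryEmbedding`, whose `partial_rotary_factor` comes from the config instead of the hard-coded `64.0 / 512.0`. The fragment below restates just that fallback-and-derive logic in isolation; it is an illustrative reduction of the diff, not the module's full code path.

```python
import jax.numpy as jnp

def default_position_ids(batch: int, seq_len: int) -> jnp.ndarray:
  """[B, S] grid of absolute positions, matching the fallback added to __call__."""
  return jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32)[None], (batch, seq_len))

def partial_rotary_factor(qk_rope_head_dim: int, head_dim: int) -> float:
  """Config-derived rotary fraction replacing 64.0 / 512.0: 0.125 for the flash
  config (64 / 512) and 0.5 for the tiny config (16 / 32)."""
  return qk_rope_head_dim / head_dim
```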