diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index cc2f674fd4..fb95d3ef64 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -396,6 +396,10 @@
 qk_nope_head_dim: 128
 qk_rope_head_dim: 64
 v_head_dim: 128
+# Constant-std init for MLA proj; output proj scaled by 1/sqrt(2*num_decoder_layers).
+# 0 keeps fan_in scaling.
+mla_init_std: 0.0
+
 # QK-Clip (Muon Clip) Configuration
 use_qk_clip: False # Enable QK-Clip (supported in MLA with DotProduct or Tokamax Splash)
 qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper)
@@ -847,6 +851,11 @@
 # You may disable clipping by setting gradient_clipping_threshold to zero.
 gradient_clipping_threshold: 1.0
+# Per-token gradient mask at the decoder-layer boundary (DeepSeek-V3 only).
+# Tokens whose feature-axis RMS exceeds threshold are zeroed in backward;
+# healthy tokens pass through unchanged. 0 disables.
+grad_mask_threshold: 0.0
+
 # Instead of updating the weights every step, you may effectively use a larger
 # batch by accumulating the gradient over a set of steps.
 gradient_accumulation_steps: 1
diff --git a/src/maxtext/configs/models/deepseek3-671b.yml b/src/maxtext/configs/models/deepseek3-671b.yml
index 18e566cf57..5fb698ce8b 100644
--- a/src/maxtext/configs/models/deepseek3-671b.yml
+++ b/src/maxtext/configs/models/deepseek3-671b.yml
@@ -44,6 +44,12 @@
 qk_nope_head_dim: 128
 qk_rope_head_dim: 64
 v_head_dim: 128
 mscale: 1.0
+# Initialize MLA projections with N(0, std); output proj further scaled
+# by 1/sqrt(2*num_decoder_layers). No effect when loading a checkpoint.
+mla_init_std: 0.001
+# Mask tokens whose backward-gradient RMS exceeds threshold at each
+# decoder-layer boundary. Defensive against bf16 overflow; rarely fires.
+grad_mask_threshold: 100.0
 # RoPE
 rope_type: "yarn"
 rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index c35274cd24..d99758dbb4 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -589,6 +589,13 @@
   qk_nope_head_dim: NonNegativeInt = Field(128, description="Dimension for non-RoPE part of QK heads in MLA.")
   qk_rope_head_dim: NonNegativeInt = Field(64, description="Dimension for RoPE part of QK heads in MLA.")
   v_head_dim: NonNegativeInt = Field(128, description="Dimension of V heads in MLA.")
+  mla_init_std: NonNegativeFloat = Field(
+      0.0,
+      description=(
+          "Constant-std init for MLA projections; output proj scaled by "
+          "1/sqrt(2*num_decoder_layers). 0 keeps fan_in scaling."
+      ),
+  )
 
 
 class AttentionIndexer(BaseModel):
@@ -1347,6 +1354,14 @@
   gradient_clipping_threshold: NonNegativeFloat = Field(
       1.0, description="The threshold for gradient clipping. 0 disables clipping."
   )
+  grad_mask_threshold: NonNegativeFloat = Field(
+      0.0,
+      description=(
+          "Per-token gradient mask at the decoder-layer boundary "
+          "(DeepSeek-V3 only). Forward identity; backward zeros tokens "
+          "whose feature-axis RMS exceeds threshold. 0 disables."
+      ),
+  )
   learning_rate: NonNegativeFloat = Field(3.0e-5, description="The peak learning rate.")
   lr_schedule_type: LearningRateScheduleType = Field(
       LearningRateScheduleType.COSINE,
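For clarity on how the two init knobs compose: with the 671B settings above, every MLA projection is drawn from N(0, 0.001), and only the output projection gets the extra depth scaling. A minimal sketch of the effective stddevs; the num_decoder_layers value here is an assumption for illustration, not part of this diff:

import math

# Hypothetical values mirroring deepseek3-671b.yml; the layer count is an
# assumption for illustration (it is set elsewhere in the model config).
mla_init_std = 0.001
num_decoder_layers = 61

# Non-output MLA projections: N(0, mla_init_std).
proj_std = mla_init_std
# Output projection: further scaled by 1/sqrt(2 * num_decoder_layers),
# matching the attention_mla.py hunk below.
out_std = mla_init_std / math.sqrt(2.0 * max(1, num_decoder_layers))
print(f"proj std = {proj_std:.2e}, out-proj std = {out_std:.2e}")
# proj std = 1.00e-03, out-proj std = 9.05e-05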
diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py
index 133273a36d..d54accd1fc 100644
--- a/src/maxtext/layers/attention_mla.py
+++ b/src/maxtext/layers/attention_mla.py
@@ -65,7 +65,7 @@
 from maxtext.layers import nnx_wrappers
 from maxtext.layers.attentions import Attention
-from maxtext.layers.initializers import nd_dense_init, NdInitializer, variable_to_logically_partitioned
+from maxtext.layers.initializers import nd_dense_init, nd_normal_const_std, NdInitializer, variable_to_logically_partitioned
 from maxtext.layers.linears import DenseGeneral
 from maxtext.layers.normalizations import RMSNorm
 from maxtext.layers.quantizations import AqtQuantization as Quant
@@ -726,6 +726,9 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
     assert self.num_query_heads == self.num_kv_heads, "MLA requires equal number of query and kv heads"
     assert not self.config.fused_qkv, "Fused QKV is not supported for MLA"
 
+    # Constant-std init for MLA projections; output proj rescaled below.
+    if self.config.mla_init_std > 0.0:
+      self.kernel_init = nd_normal_const_std(self.config.mla_init_std)
     if self.q_lora_rank == 0:
       # Standard Q projection (without LoRA).
       self.query = DenseGeneral(
@@ -823,6 +826,12 @@
       mscale = 0.1 * self.mscale * math.log(self.rope_factor) + 1.0
       self.softmax_scale = self.softmax_scale * mscale * mscale
 
+    # Output-proj residual scaling: std / sqrt(2 * num_decoder_layers).
+    if self.config.mla_init_std > 0.0:
+      self.kernel_init = nd_normal_const_std(
+          self.config.mla_init_std / math.sqrt(2.0 * max(1, self.config.num_decoder_layers))
+      )
+
     self.out = self.init_out_w(output_dim=inputs_q_shape[-1])
 
     # Setup paged attention op
diff --git a/src/maxtext/layers/initializers.py b/src/maxtext/layers/initializers.py
index bbc6605057..063143b5b1 100644
--- a/src/maxtext/layers/initializers.py
+++ b/src/maxtext/layers/initializers.py
@@ -34,6 +34,21 @@
 default_scalar_init = jax.nn.initializers.constant(0.01)
 
 
+def nd_normal_const_std(std: float):
+  """Creates a constant-std normal initializer with the NdInitializer signature.
+
+  Returns an initializer that produces N(0, std) regardless of fan_in/fan_out;
+  useful when a layer needs a fixed-stddev init independent of input shape
+  (e.g. scaled init for residual output projections).
+  """
+
+  def init_fn(key, shape, dtype, in_axis, out_axis):
+    del in_axis, out_axis
+    return jax.random.normal(key, shape, dtype=dtype) * std
+
+  return init_fn
+
+
 def nd_dense_init(scale, mode, distribution):
   """Creates a variance-scaling initializer with dynamic in/out axes.
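A quick way to see what nd_normal_const_std buys over the default variance-scaling init: the sample stddev stays at the requested value regardless of fan-in. A minimal sketch using only the initializer added above:

import jax
import jax.numpy as jnp

from maxtext.layers.initializers import nd_normal_const_std

init = nd_normal_const_std(0.001)
key = jax.random.PRNGKey(0)
# With fan_in scaling the stddev would shrink as 1/sqrt(fan_in); here it
# stays ~1e-3 for both shapes because in_axis/out_axis are ignored.
for shape in [(256, 512), (4096, 512)]:
  w = init(key, shape, jnp.float32, in_axis=-2, out_axis=-1)
  print(shape, float(jnp.std(w)))  # ~0.001 in both cases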
diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py
index 0980b78599..0bff87683d 100644
--- a/src/maxtext/models/deepseek.py
+++ b/src/maxtext/models/deepseek.py
@@ -40,6 +40,7 @@
 from maxtext.layers.normalizations import RMSNorm
 from maxtext.models import deepseek_batchsplit
 from maxtext.models import deepseek_batchsplit_fp8
+from maxtext.utils import grad_mask_utils
 from maxtext.utils import max_utils
 from maxtext.utils.sharding import create_sharding
 from maxtext.utils.sharding import maybe_shard_with_logical
@@ -260,6 +261,8 @@ def post_process(self, layer_output, load_balance_loss, moe_bias_updates, kv_cac
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
 
+      layer_output = grad_mask_utils.maybe_grad_mask(layer_output, self.config)
+
     if self.config.scan_layers:
       return layer_output, None
     return layer_output, kv_cache
diff --git a/src/maxtext/utils/grad_mask_utils.py b/src/maxtext/utils/grad_mask_utils.py
new file mode 100644
index 0000000000..f7961de65b
--- /dev/null
+++ b/src/maxtext/utils/grad_mask_utils.py
@@ -0,0 +1,48 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Per-token gradient mask applied at a layer boundary.
+
+Forward is identity; backward zeros tokens whose feature-axis RMS exceeds
+the configured threshold. Healthy tokens pass through unchanged. Used at
+decoder-layer boundaries to bound per-layer cotangent magnitudes
+(see deepseek model usage)."""
+
+import jax
+import jax.numpy as jnp
+
+
+@jax.custom_vjp
+def _grad_mask(x: jax.Array, threshold: jax.Array) -> jax.Array:
+  return x
+
+
+def _grad_mask_fwd(x: jax.Array, threshold: jax.Array):
+  return x, threshold
+
+
+def _grad_mask_bwd(threshold: jax.Array, g: jax.Array):
+  rms = jnp.sqrt(jnp.mean(jnp.square(g.astype(jnp.float32)), axis=-1, keepdims=True))
+  mask = rms <= threshold
+  return (jnp.where(mask, g, jnp.zeros_like(g)), jnp.zeros_like(threshold))
+
+
+_grad_mask.defvjp(_grad_mask_fwd, _grad_mask_bwd)
+
+
+def maybe_grad_mask(x: jax.Array, cfg) -> jax.Array:
+  """Per-token gradient mask if cfg.grad_mask_threshold > 0; else identity."""
+  if cfg.grad_mask_threshold > 0.0:
+    return _grad_mask(x, jnp.asarray(cfg.grad_mask_threshold, jnp.float32))
+  return x
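Behavior of the custom VJP in isolation, for review: forward is exactly identity, and in backward each token keeps or loses its entire cotangent based on its feature-axis RMS. A minimal sketch (using the private _grad_mask directly, as the unit tests below also do):

import jax
import jax.numpy as jnp

from maxtext.utils.grad_mask_utils import _grad_mask

threshold = jnp.float32(1.0)
x = jnp.zeros((1, 2, 4))  # (batch, tokens, features)

y, vjp = jax.vjp(lambda t: _grad_mask(t, threshold), x)
assert (y == x).all()  # forward: identity

# Token 0 is healthy (RMS = 1.0 <= threshold); token 1 is an outlier (RMS = 50).
g = jnp.ones_like(x).at[0, 1].set(50.0)
(g_masked,) = vjp(g)
print(g_masked[0, 0])  # [1. 1. 1. 1.] -> passes through
print(g_masked[0, 1])  # [0. 0. 0. 0.] -> zeroed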
diff --git a/tests/unit/grad_mask_utils_test.py b/tests/unit/grad_mask_utils_test.py
new file mode 100644
index 0000000000..bdf4f94c82
--- /dev/null
+++ b/tests/unit/grad_mask_utils_test.py
@@ -0,0 +1,124 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for per-layer per-token gradient mask (grad_mask_utils)."""
+
+import unittest
+from collections import namedtuple
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from maxtext.utils.grad_mask_utils import _grad_mask, maybe_grad_mask
+
+
+class GradMaskTest(unittest.TestCase):
+
+  def setUp(self):
+    self.rng = jax.random.PRNGKey(0)
+
+  def test_forward_is_identity(self):
+    """Forward pass must return input unchanged regardless of threshold."""
+    x = jax.random.normal(self.rng, (2, 8, 16))
+    for thr in [0.5, 1.0, 100.0]:
+      y = _grad_mask(x, jnp.float32(thr))
+      np.testing.assert_array_equal(np.asarray(y), np.asarray(x))
+
+  def test_backward_below_threshold_passthrough(self):
+    """When per-token RMS <= threshold, backward must return g unchanged."""
+    x = jnp.ones((2, 4, 8), dtype=jnp.bfloat16)
+    threshold = jnp.float32(1e6)  # huge → never clip
+
+    def loss_fn(x):
+      return jnp.sum(_grad_mask(x, threshold) * 0.5)
+
+    g = jax.grad(loss_fn)(x)
+    expected = jnp.full_like(x, 0.5)
+    np.testing.assert_allclose(np.asarray(g), np.asarray(expected), atol=1e-3)
+
+  def test_backward_outlier_tokens_are_masked(self):
+    """Tokens with RMS > threshold get zero gradient; healthy tokens unchanged."""
+    x = jnp.zeros((2, 3, 8), dtype=jnp.float32)
+    threshold = jnp.float32(1.0)
+    upstream = jnp.ones_like(x)
+    # Make token (0, 0) an outlier (RMS = 10) and token (1, 2) an outlier (RMS = 100).
+    upstream = upstream.at[0, 0].set(10.0)
+    upstream = upstream.at[1, 2].set(100.0)
+
+    def fn(x):
+      return _grad_mask(x, threshold)
+
+    _, vjp = jax.vjp(fn, x)
+    (g_masked,) = vjp(upstream)
+    g_masked = np.asarray(g_masked)
+    # Outlier tokens zeroed.
+    np.testing.assert_array_equal(g_masked[0, 0], np.zeros(8, dtype=np.float32))
+    np.testing.assert_array_equal(g_masked[1, 2], np.zeros(8, dtype=np.float32))
+    # Healthy tokens (RMS = 1.0 == threshold, passes through).
+    np.testing.assert_array_equal(g_masked[0, 1], np.ones(8, dtype=np.float32))
+    np.testing.assert_array_equal(g_masked[1, 0], np.ones(8, dtype=np.float32))
+
+  def test_backward_threshold_grad_is_zero(self):
+    """Threshold arg must receive a zero gradient (it's not differentiable)."""
+    x = jnp.ones((2, 4, 8), dtype=jnp.float32)
+
+    def fn(x, threshold):
+      return _grad_mask(x, threshold)
+
+    threshold = jnp.float32(1.0)
+    _, vjp = jax.vjp(fn, x, threshold)
+    upstream = jnp.ones_like(x)
+    _, g_threshold = vjp(upstream)
+    self.assertEqual(float(g_threshold), 0.0)
+
+  def test_maybe_grad_mask_threshold_zero_is_noop(self):
+    """maybe_grad_mask with threshold=0 returns input unchanged and inserts no boundary."""
+    Cfg = namedtuple("Cfg", ["grad_mask_threshold"])
+    cfg = Cfg(grad_mask_threshold=0.0)
+    x = jax.random.normal(self.rng, (2, 4, 8))
+    y = maybe_grad_mask(x, cfg)
+    self.assertIs(y, x)  # exact identity, no jnp.array wrapping
+
+  def test_maybe_grad_mask_threshold_positive_applies_mask(self):
+    """maybe_grad_mask with threshold > 0 zeros tokens whose RMS exceeds threshold."""
+    Cfg = namedtuple("Cfg", ["grad_mask_threshold"])
+    cfg = Cfg(grad_mask_threshold=0.5)
+    x = jnp.zeros((2, 4, 8), dtype=jnp.float32)
+
+    def fn(x):
+      return maybe_grad_mask(x, cfg)
+
+    # All tokens have RMS = 10.0 (every element = 10.0); threshold = 0.5 → all masked.
+    upstream = jnp.full_like(x, 10.0)
+    _, vjp = jax.vjp(fn, x)
+    (g,) = vjp(upstream)
+    np.testing.assert_array_equal(np.asarray(g), np.zeros_like(np.asarray(x)))
+
+  def test_dtype_preserved_in_backward(self):
+    """Backward must preserve the gradient's dtype (bf16 in, bf16 out)."""
+    x = jnp.zeros((2, 4, 8), dtype=jnp.bfloat16)
+    threshold = jnp.float32(0.1)
+
+    def fn(x):
+      return _grad_mask(x, threshold)
+
+    upstream = (jax.random.normal(self.rng, x.shape, dtype=jnp.float32) * 10.0).astype(jnp.bfloat16)
+    _, vjp = jax.vjp(fn, x)
+    (g,) = vjp(upstream)
+    self.assertEqual(g.dtype, jnp.bfloat16)
+
+
+if __name__ == "__main__":
+  unittest.main()
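The tests exercise the primitive directly; end to end, the point of placing maybe_grad_mask at every layer boundary is that one layer's gradient spike is cut off per token before it reaches the layers below. A toy residual stack illustrating this (hypothetical shapes and Cfg, not the MaxText decoder):

from collections import namedtuple

import jax
import jax.numpy as jnp

from maxtext.utils.grad_mask_utils import maybe_grad_mask

Cfg = namedtuple("Cfg", ["grad_mask_threshold"])
cfg = Cfg(grad_mask_threshold=10.0)

def layer(x, w):
  return x + jnp.tanh(x @ w)  # stand-in for a decoder layer

def loss(x, ws):
  for w in ws:
    x = maybe_grad_mask(layer(x, w), cfg)  # mask at each layer boundary
  return jnp.sum(x) * 1e4  # large upstream gradient: per-token RMS = 1e4

x = jnp.zeros((2, 3, 8))
ws = [jnp.eye(8) * 0.1] * 2
g = jax.grad(loss)(x, ws)
print(bool((g == 0).all()))  # True: every token's cotangent RMS exceeded 10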