
[DeepSeek-V4] Implement core primitives (RMSNorm, RoPE, GroupedLinear) #3865

Open

parambole wants to merge 1 commit into main from deepseek_v4_core_primitives

Conversation


@parambole parambole commented May 11, 2026

Description

Implement foundational architectural core primitives required for DeepSeek-V4 integration into MaxText:

  • DeepSeekV4RMSNorm & DeepSeekV4UnweightedRMSNorm: RMS normalization layers that compute the variance in float32 for numerical stability. Includes an unweighted, scale-free variant that avoids allocating or synchronizing a trainable weight parameter (a minimal sketch follows this list).
  • DeepSeekGroupedLinear: Block-diagonal grouped linear projection layer supporting parallel group projection via einsum broadcasting ([B, S, hc_mult, D] -> [B, S, D]).
  • DeepSeekV4RotaryEmbedding: Interleaved partial rotary positional embedding pairing consecutive even/odd channels.
  • Unit test suite (tests/unit/deepseek_v4_vs_reference_test.py) validating numerical parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.
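
The first bullet's two variants can be sketched as follows. This is a minimal illustration under assumed names and signatures (`hidden_size`, `eps`, the `*Sketch` classes), not the PR's actual code:

```python
# Minimal sketch, assuming flax.nnx; class names and signatures are
# illustrative stand-ins for DeepSeekV4UnweightedRMSNorm / DeepSeekV4RMSNorm.
import jax
import jax.numpy as jnp
from flax import nnx


class UnweightedRMSNormSketch(nnx.Module):
  """Scale-free RMSNorm: no trainable weight is allocated or synchronized."""

  def __init__(self, eps: float = 1e-6):
    self.eps = eps

  def __call__(self, x):
    # Compute the variance (mean of squares) in float32, even for
    # bfloat16 inputs, then cast back to the input dtype.
    x32 = x.astype(jnp.float32)
    variance = jnp.mean(x32 * x32, axis=-1, keepdims=True)
    return (x32 * jax.lax.rsqrt(variance + self.eps)).astype(x.dtype)


class RMSNormSketch(UnweightedRMSNormSketch):
  """Weighted variant: the scale-free core times a learnable scale."""

  def __init__(self, hidden_size: int, eps: float = 1e-6,
               weight_dtype=jnp.float32):
    super().__init__(eps)
    # Learnable scale initialized to ones, matching T5LayerNorm behavior.
    self.weight = nnx.Param(jnp.ones((hidden_size,), dtype=weight_dtype))

  def __call__(self, x):
    return super().__call__(x) * self.weight.value.astype(x.dtype)
```

Sharing the scale-free core this way makes it explicit that the unweighted variant touches no parameters at all.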

Tests

Tested on CPU

```bash
pytest tests/unit/deepseek_v4_vs_reference_test.py
```

Results:

```
tests/unit/deepseek_v4_vs_reference_test.py ....                         [100%]
======================== 4 passed, 4 warnings in 2.58s =========================
```

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@parambole parambole changed the title from feat: DeepSeek-V4 Core Primitives to DeepSeek-V4 Core Primitives May 11, 2026
codecov Bot commented May 11, 2026

Codecov Report

❌ Patch coverage is 20.27027% with 59 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/layers/embeddings.py | 17.14% | 29 Missing ⚠️ |
| src/maxtext/layers/normalizations.py | 26.08% | 17 Missing ⚠️ |
| src/maxtext/layers/linears.py | 18.75% | 13 Missing ⚠️ |


@parambole parambole force-pushed the deepseek_v4_core_primitives branch 3 times, most recently from 68c44a6 to 72a92a7 Compare May 14, 2026 17:36
[DeepSeek-V4] Implement core primitives (RMSNorm, RoPE, GroupedLinear)

Implement architectural core primitives required for DeepSeek-V4 integration into MaxText:

- DeepSeekV4RMSNorm & DeepSeekV4UnweightedRMSNorm: RMS normalization layers that compute the variance in float32 for numerical stability. Includes an unweighted, scale-free variant that avoids allocating or synchronizing a trainable weight parameter.
- DeepSeekGroupedLinear: Block-diagonal grouped linear projection layer supporting parallel group projection via einsum broadcasting ([B, S, hc_mult, D] -> [B, S, D]).
- DeepSeekV4RotaryEmbedding: Interleaved partial rotary positional embedding pairing consecutive even/odd channels.
- Unit test suite (deepseek_v4_vs_reference_test.py) validating numerical parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.
@parambole parambole force-pushed the deepseek_v4_core_primitives branch from 72a92a7 to e81f52d Compare May 14, 2026 17:45
@parambole parambole changed the title from DeepSeek-V4 Core Primitives to Implement DeepSeek-V4 core primitives (RMSNorm, RoPE, GroupedLinear) May 14, 2026
@parambole parambole changed the title from Implement DeepSeek-V4 core primitives (RMSNorm, RoPE, GroupedLinear) to [DeepSeek-V4] Implement core primitives (RMSNorm, RoPE, GroupedLinear) May 14, 2026

🤖 Hi @parambole, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request implements foundational architectural primitives for DeepSeek-V4, specifically focused on JAX/NNX implementations of interleaved Rotary Embeddings, Grouped Linear layers, and RMS Normalization. The core logic for these primitives appears correct and aligns well with the DeepSeek-V4 specification, and the inclusion of cross-framework parity tests is excellent.

🔍 General Feedback

  • Sharding Support: The most critical feedback is the current lack of explicit sharding support in the nnx.Param initializations for the new DeepSeekV4RMSNorm and DeepSeekGroupedLinear layers. In the MaxText ecosystem, explicitly defining parameter sharding is essential for performance and reliability during large-scale distributed training.
  • Precision Management: The implementation of DeepSeekV4RotaryEmbedding should keep the sinusoids in float32 until the final rotation is applied, to avoid premature precision loss when working with bfloat16 inputs (see the sketch after this list).
  • Test Quality: The use of a PyTorch reference implementation within the unit tests to verify numerical parity is a great practice and provides high confidence in the correctness of the JAX primitives.
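
To make the interleaving and precision points concrete, here is a minimal sketch of an interleaved partial RoPE that keeps the sinusoids in float32 until the rotation is applied. The function and argument names (`apply_interleaved_rope`, `rotary_dim`, `base`) are assumptions, not the PR's API:

```python
# Hedged sketch of interleaved partial RoPE; not the PR's implementation.
import jax.numpy as jnp


def apply_interleaved_rope(x, positions, rotary_dim, base=10000.0):
  """Rotate the first `rotary_dim` channels of x, pairing consecutive
  (even, odd) channels; the remaining channels pass through unchanged.

  x: [B, S, H, D] queries or keys; positions: [B, S] token positions.
  """
  x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
  # One inverse frequency per channel pair; sinusoids stay in float32
  # until after the rotation, even when x is bfloat16.
  inv_freq = 1.0 / (base ** (jnp.arange(0, rotary_dim, 2) / rotary_dim))
  freqs = positions[..., None].astype(jnp.float32) * inv_freq  # [B, S, dim/2]
  cos = jnp.cos(freqs)[:, :, None, :]  # broadcast across heads
  sin = jnp.sin(freqs)[:, :, None, :]
  x_even, x_odd = x_rot[..., 0::2], x_rot[..., 1::2]
  rotated = jnp.stack(
      [x_even * cos - x_odd * sin, x_even * sin + x_odd * cos], axis=-1
  ).reshape(x_rot.shape)  # re-interleave the rotated pairs
  return jnp.concatenate([rotated.astype(x.dtype), x_pass], axis=-1)
```

Pairing `x[..., 0::2]` with `x[..., 1::2]` is what distinguishes this interleaved layout from the "rotate-half" layout, which splits the head dimension into two contiguous halves.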

Comment on lines +1851 to +1852

```python
freqs = position_ids_expanded * inv_freq_expanded
```


🟡 In `DeepSeekV4RotaryEmbedding.__call__`, the cosine and sine values are cast to `x.dtype` (which could be `bfloat16`) immediately after computation. Since the subsequent `apply_rotary_pos_emb` explicitly performs rotation in `float32` for numerical stability, it's better to keep the sinusoids in `float32` (or the precision specified in `self.dtype`) until the final rotation is computed. This prevents premature precision loss.
Suggested change:

```python
freqs = position_ids_expanded * inv_freq_expanded
cos = jnp.cos(freqs).astype(self.dtype)  # [B, S, dim/2]
sin = jnp.sin(freqs).astype(self.dtype)  # [B, S, dim/2]
```


```
This layer segments the trailing dimension of the input tensor into a specified
number of groups, and projects each group independently using a distinct weight
matrix block. It minimizes parameter counts and compute overhead in the
```


🟢 For consistency with other DeepSeek-V4 layers like `DeepSeekV4RMSNorm` and `DeepSeekV4RotaryEmbedding`, consider renaming this class to `DeepSeekV4GroupedLinear`.
Suggested change:

```python
class DeepSeekV4GroupedLinear(nnx.Module):
```

Comment on lines +605 to +614

```python
kernel_shape = (n_groups, in_features_per_group, self.out_features_per_group)
self.weight = nnx.Param(
    kernel_init(
        rngs.params(),
        kernel_shape,
        self.weight_dtype,
        in_axis=1,
        out_axis=2,
    )
)
```


🟠 This layer is missing `sharding` support in its `nnx.Param` initialization. In MaxText, it's critical to provide explicit sharding constraints for parameters to ensure they are correctly partitioned across the device mesh. Consider adding a `kernel_axes` parameter to `__init__` and passing it to the `nnx.Param` call (a hedged constructor sketch follows the suggestion below).
Suggested change:

```python
kernel_shape = (n_groups, in_features_per_group, self.out_features_per_group)
self.weight = nnx.Param(
    kernel_init(
        rngs.params(),
        kernel_shape,
        self.weight_dtype,
        in_axis=1,
        out_axis=2,
    ),
    sharding=kernel_axes,
)
```
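
Building on that suggestion, here is a hedged sketch of how a `kernel_axes` argument might be threaded through the layer's constructor. The parameter list, defaults, and logical axis names are illustrative assumptions, and it presumes `nnx.Param` accepts `sharding` metadata as the suggestion implies:

```python
# Illustrative constructor sketch only; names and defaults are assumed.
import jax.numpy as jnp
from flax import nnx


class GroupedLinearSketch(nnx.Module):
  def __init__(self, n_groups, in_features_per_group, out_features_per_group,
               kernel_init, weight_dtype=jnp.float32,
               kernel_axes=(None, "embed", "mlp"), *, rngs: nnx.Rngs):
    kernel_shape = (n_groups, in_features_per_group, out_features_per_group)
    # The sharding metadata lets MaxText's logical-axis rules map this
    # kernel onto the device mesh during initialization and training.
    self.weight = nnx.Param(
        kernel_init(rngs.params(), kernel_shape, weight_dtype,
                    in_axis=1, out_axis=2),
        sharding=kernel_axes,
    )
```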

```python
self.weight_dtype = weight_dtype

# Initialize learnable scale weight to ones matching T5LayerNorm behavior
self.weight = nnx.Param(jnp.ones((hidden_size,), dtype=weight_dtype))
```


🟠 Like the grouped linear layer, this normalization layer is missing `sharding` support for its weight parameter. To maintain consistency with the base `RMSNorm` and ensure proper sharding in large-scale training, consider adding `kernel_axes` to the constructor and passing it to `nnx.Param`.
Suggested change:

```python
self.weight = nnx.Param(
    jnp.ones((hidden_size,), dtype=weight_dtype),
    sharding=kernel_axes,
)
```
