[DeepSeek-V4] Implement core primitives (RMSNorm, RoPE, GroupedLinear) #3865
parambole wants to merge 1 commit into
Conversation
Codecov Report: ❌ Patch coverage is
68c44a6 to 72a92a7:

…near) Implement architectural core primitives required for DeepSeek-V4 integration into MaxText:
- DeepSeekV4RMSNorm & DeepSeekV4UnweightedRMSNorm: RMS normalization layers utilizing float32 variance pooling. Includes unweighted scale-free variants that avoid allocating or synchronizing trainable weight parameters.
- DeepSeekGroupedLinear: Block-diagonal grouped linear projection layer supporting parallel group projection via einsum broadcasting ([B, S, hc_mult, D] -> [B, S, D]).
- DeepSeekV4RotaryEmbedding: Interleaved partial rotary positional embedding pairing consecutive even/odd channels.
- Unit test suite (deepseek_v4_vs_reference_test.py) validating numerical parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.

72a92a7 to e81f52d
🤖 Hi @parambole, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request implements foundational architectural primitives for DeepSeek-V4, specifically focused on JAX/NNX implementations of interleaved Rotary Embeddings, Grouped Linear layers, and RMS Normalization. The core logic for these primitives appears correct and aligns well with the DeepSeek-V4 specification, and the inclusion of cross-framework parity tests is excellent.
🔍 General Feedback
- Sharding Support: The most critical feedback is the current lack of explicit sharding support in the nnx.Param initializations for the new DeepSeekV4RMSNorm and DeepSeekGroupedLinear layers. In the MaxText ecosystem, explicitly defining parameter sharding is essential for performance and reliability during large-scale distributed training.
- Precision Management: The implementation of DeepSeekV4RotaryEmbedding should ensure that sinusoids are maintained in float32 precision until the final rotation application to avoid premature truncation, especially when working with bfloat16 inputs (see the sketch after this list).
- Test Quality: The use of a PyTorch reference implementation within the unit tests to verify numerical parity is a great practice and provides high confidence in the correctness of the JAX primitives.
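For reference, here is a minimal, hedged sketch of the precision pattern described in the Precision Management point: sinusoids and the rotation are kept in float32, and the cast to the activation dtype happens only at the end. The function name, argument layout, and the [B, S, D] shape assumption are illustrative and are not the PR's actual DeepSeekV4RotaryEmbedding.

```python
import jax.numpy as jnp

def apply_interleaved_rope(x, position_ids, inv_freq, out_dtype=jnp.bfloat16):
  """Illustrative only: keep sinusoids in float32 and cast after the rotation.

  x:            [B, S, D] activations (possibly bfloat16)
  position_ids: [B, S] integer positions
  inv_freq:     [D // 2] inverse frequencies
  """
  # Outer product of positions and inverse frequencies, all in float32.
  freqs = position_ids.astype(jnp.float32)[..., None] * inv_freq.astype(jnp.float32)  # [B, S, D/2]
  cos = jnp.cos(freqs)  # stays float32 until the rotation below
  sin = jnp.sin(freqs)
  x_f32 = x.astype(jnp.float32)
  # Interleaved variant: pair consecutive even/odd channels.
  x_even, x_odd = x_f32[..., 0::2], x_f32[..., 1::2]
  rot_even = x_even * cos - x_odd * sin
  rot_odd = x_even * sin + x_odd * cos
  # Re-interleave the rotated pairs; only now drop to the activation dtype.
  out = jnp.stack([rot_even, rot_odd], axis=-1).reshape(x.shape)
  return out.astype(out_dtype)
```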
Review comment on the rotary frequency computation:

```python
freqs = position_ids_expanded * inv_freq_expanded
```

Suggested change:

```python
freqs = position_ids_expanded * inv_freq_expanded
cos = jnp.cos(freqs).astype(self.dtype)  # [B, S, dim/2]
sin = jnp.sin(freqs).astype(self.dtype)  # [B, S, dim/2]
```
Review comment on the grouped-linear docstring:

```
This layer segments the trailing dimension of the input tensor into a specified
number of groups, and projects each group independently using a distinct weight
matrix block. It minimizes parameter counts and compute overhead in the
```

Suggested change:

```
matrix block. It minimizes parameter counts and compute overhead in the
class DeepSeekV4GroupedLinear(nnx.Module):
```
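For orientation, a minimal standalone sketch of the block-diagonal grouped projection that the docstring quoted above describes. The function name `grouped_linear`, the einsum labels, and the toy shapes are illustrative assumptions, not the PR's actual implementation.

```python
import jax.numpy as jnp
from jax import random

def grouped_linear(x, weight):
  """Illustrative sketch of a block-diagonal grouped projection.

  x:      [B, S, n_groups * in_per_group]
  weight: [n_groups, in_per_group, out_per_group]
  returns [B, S, n_groups * out_per_group]
  """
  n_groups, in_per_group, out_per_group = weight.shape
  b, s, _ = x.shape
  # Segment the trailing dimension into groups.
  x_grouped = x.reshape(b, s, n_groups, in_per_group)
  # Project each group with its own weight block, in parallel, via einsum.
  y = jnp.einsum("bsgi,gio->bsgo", x_grouped, weight)
  return y.reshape(b, s, n_groups * out_per_group)

# Toy usage with hypothetical sizes.
key_x, key_w = random.split(random.PRNGKey(0))
x = random.normal(key_x, (2, 4, 8))   # B=2, S=4, D_in=8
w = random.normal(key_w, (4, 2, 3))   # 4 groups, each projecting 2 -> 3
print(grouped_linear(x, w).shape)     # (2, 4, 12)
```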
Review comment on the DeepSeekGroupedLinear weight initialization. Suggested change (attach explicit sharding metadata to the parameter):

```python
kernel_shape = (n_groups, in_features_per_group, self.out_features_per_group)
self.weight = nnx.Param(
    kernel_init(
        rngs.params(),
        kernel_shape,
        self.weight_dtype,
        in_axis=1,
        out_axis=2,
    ),
    sharding=kernel_axes,
)
```
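As context for why the metadata matters, a hedged sketch of how such `sharding` annotations are typically consumed downstream with Flax NNX partitioning helpers. The module, parameter shape, and mesh axis names here are illustrative placeholders and are assumptions, not MaxText's actual sharding setup.

```python
from flax import nnx

class TinyGrouped(nnx.Module):
  """Toy module with a sharding-annotated parameter (illustrative only)."""

  def __init__(self, rngs: nnx.Rngs):
    self.weight = nnx.Param(
        nnx.initializers.lecun_normal()(rngs.params(), (4, 2, 3)),
        sharding=(None, "fsdp", "tensor"),  # assumed mesh axis names
    )

model = TinyGrouped(nnx.Rngs(0))
_, state = nnx.split(model)
# get_partition_spec reads the `sharding` metadata into PartitionSpecs,
# which can then drive parameter placement under a device mesh.
pspecs = nnx.get_partition_spec(state)
print(pspecs)
```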
Review comment on the DeepSeekV4RMSNorm scale initialization:

```python
self.weight_dtype = weight_dtype

# Initialize learnable scale weight to ones matching T5LayerNorm behavior
self.weight = nnx.Param(jnp.ones((hidden_size,), dtype=weight_dtype))
```

Suggested change (attach explicit sharding metadata to the scale parameter):

```python
self.weight = nnx.Param(
    jnp.ones((hidden_size,), dtype=weight_dtype),
    sharding=kernel_axes,
)
```
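A minimal sketch of the float32 variance-pooling pattern the PR description attributes to the RMS norm layers. The function name, epsilon value, and weighted/unweighted switch are assumptions for illustration, not the PR's actual code.

```python
import jax
import jax.numpy as jnp

def rms_norm(x, weight=None, eps=1e-6):
  """Illustrative RMSNorm sketch: variance pooled in float32, cast back at the end.

  Passing weight=None mimics the unweighted, scale-free variant, which
  allocates no trainable parameter at all.
  """
  x_f32 = x.astype(jnp.float32)
  # Mean of squares over the feature axis, computed in float32.
  variance = jnp.mean(jnp.square(x_f32), axis=-1, keepdims=True)
  normed = x_f32 * jax.lax.rsqrt(variance + eps)
  if weight is not None:
    normed = normed * weight.astype(jnp.float32)
  return normed.astype(x.dtype)
```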
Description
Implement foundational architectural core primitives required for DeepSeek-V4 integration into MaxText:
- DeepSeekV4RMSNorm & DeepSeekV4UnweightedRMSNorm: RMS normalization layers utilizing float32 variance pooling. Includes unweighted scale-free variants that avoid allocating or synchronizing trainable weight parameters.
- DeepSeekGroupedLinear: Block-diagonal grouped linear projection layer supporting parallel group projection via einsum broadcasting ([B, S, hc_mult, D] -> [B, S, D]).
- DeepSeekV4RotaryEmbedding: Interleaved partial rotary positional embedding pairing consecutive even/odd channels.
- Unit test suite (tests/unit/deepseek_v4_vs_reference_test.py) validating numerical parity against PyTorch reference implementations at atol=1e-5, rtol=1e-5.

Tests
Tested on CPU
bash
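To illustrate the parity-testing approach described in the PR, here is a hedged sketch of the kind of check such a test might perform, using a NumPy stand-in where the PR uses a PyTorch reference. The function names, toy shapes, and the RMSNorm example are illustrative; this is not the contents of deepseek_v4_vs_reference_test.py, only the tolerances mirror the description.

```python
import numpy as np
import jax.numpy as jnp

def reference_rms_norm(x, weight, eps=1e-6):
  """NumPy stand-in for a framework reference implementation."""
  variance = np.mean(np.square(x.astype(np.float32)), axis=-1, keepdims=True)
  return (x.astype(np.float32) / np.sqrt(variance + eps)) * weight

def jax_rms_norm(x, weight, eps=1e-6):
  """JAX counterpart under test (same math, illustrative only)."""
  variance = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
  return (x.astype(jnp.float32) / jnp.sqrt(variance + eps)) * weight

def test_rms_norm_parity():
  rng = np.random.default_rng(0)
  x = rng.standard_normal((2, 4, 8), dtype=np.float32)
  w = rng.standard_normal(8, dtype=np.float32)
  # Compare JAX output against the reference at the tolerances used in the PR.
  np.testing.assert_allclose(
      np.asarray(jax_rms_norm(jnp.asarray(x), jnp.asarray(w))),
      reference_rms_norm(x, w),
      atol=1e-5,
      rtol=1e-5,
  )
```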
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.