-
Notifications
You must be signed in to change notification settings - Fork 517
[DeepSeek-V4] Implement core primitives (RMSNorm, RoPE, GroupedLinear) #3865
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |||||||
|
|
||||||||
| import dataclasses | ||||||||
| import math | ||||||||
| from typing import Any | ||||||||
|
|
||||||||
| import jax | ||||||||
| from jax import lax | ||||||||
|
|
@@ -1800,3 +1801,115 @@ def qwen3_omni_mrope_embedding_as_linen( | |||||||
| metadata_fn=variable_to_logically_partitioned, | ||||||||
| name=name, | ||||||||
| ) | ||||||||
|
|
||||||||
|
|
||||||||
| class DeepSeekV4RotaryEmbedding(nnx.Module): | ||||||||
| """DeepSeek-V4 partial rotary embedding with interleaved frequencies. | ||||||||
|
|
||||||||
| DeepSeek-V4 uses an interleaved positional encoding where consecutive channels | ||||||||
| are paired together. Unlike standard rotary models that split dimensions globally | ||||||||
| into first and second halves, this implementation pairs each even channel 2i | ||||||||
| with the corresponding odd channel 2i + 1. | ||||||||
|
|
||||||||
| This results in two specific mathematical properties: | ||||||||
| 1. Inverse frequencies are computed for (dim // 2) unique theta angles. | ||||||||
| 2. Sinusoidal components are expanded consecutively (e.g., [f0, f0, f1, f1]) | ||||||||
| prior to application. | ||||||||
| """ | ||||||||
|
|
||||||||
| def __init__( | ||||||||
| self, | ||||||||
| head_dim: int, | ||||||||
| partial_rotary_factor: float = 64.0 / 512.0, | ||||||||
| rope_theta: float = 10000.0, | ||||||||
| dtype: Any = jnp.float32, | ||||||||
| ): | ||||||||
| self.head_dim = head_dim | ||||||||
| self.partial_rotary_factor = partial_rotary_factor | ||||||||
| self.rope_theta = rope_theta | ||||||||
| self.dtype = dtype | ||||||||
|
|
||||||||
| # Compute the partial rotary dimension (rope_head_dim) | ||||||||
| # e.g., 512 * (64 / 512) = 64 channels | ||||||||
| self.dim = int(head_dim * partial_rotary_factor) | ||||||||
|
|
||||||||
| # Compute base inverse frequencies for half of self.dim (dim // 2 unique theta angles). | ||||||||
| # Adjacent channels share the same base frequency, matching the reference sequence. | ||||||||
| half_dim = self.dim // 2 | ||||||||
| fraction = 2 * jnp.arange(0, half_dim, dtype=jnp.float32) / self.dim | ||||||||
| self.inv_freq = 1.0 / (self.rope_theta**fraction) | ||||||||
|
|
||||||||
| def __call__(self, x: jnp.ndarray, position_ids: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: | ||||||||
| # position_ids: [B, S] | ||||||||
| # Expand inverse frequencies for broadcasting: [1, 1, dim/2] | ||||||||
| inv_freq_expanded = self.inv_freq[jnp.newaxis, jnp.newaxis, :] | ||||||||
|
|
||||||||
| # Expand position IDs: [B, S, 1] | ||||||||
| position_ids_expanded = position_ids[:, :, jnp.newaxis].astype(jnp.float32) | ||||||||
|
|
||||||||
| # Compute outer product of positions and frequencies: [B, S, dim/2] | ||||||||
| freqs = position_ids_expanded * inv_freq_expanded | ||||||||
|
|
||||||||
|
Comment on lines
+1851
to
+1852
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🟡 In `DeepSeekV4RotaryEmbedding.__call__`, the cosine and sine values are cast to `x.dtype` (which could be `bfloat16`) immediately after computation. Since the subsequent `apply_rotary_pos_emb` explicitly performs rotation in `float32` for numerical stability, it's better to keep the sinusoids in `float32` (or the precision specified in `self.dtype`) until the final rotation is computed. This prevents premature precision loss.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||||||||
| cos = jnp.cos(freqs).astype(x.dtype) # [B, S, dim/2] | ||||||||
| sin = jnp.sin(freqs).astype(x.dtype) # [B, S, dim/2] | ||||||||
|
|
||||||||
| return cos, sin | ||||||||
|
|
||||||||
|
|
||||||||
| def _rotate_half(x: jax.Array) -> jax.Array: | ||||||||
| """Performs consecutive half-rotation to match DeepSeek-V4 interleaved layout. | ||||||||
|
|
||||||||
| Pairs adjacent elements: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2]. | ||||||||
|
|
||||||||
| Operations: | ||||||||
| 1. Slice even indices: x1 = x[..., 0::2] | ||||||||
| 2. Slice odd indices: x2 = x[..., 1::2] | ||||||||
| 3. Stack (-x2, x1) along a new trailing dimension: [..., D/2, 2] | ||||||||
| 4. Reshape back to the original dimension: [..., D] | ||||||||
| """ | ||||||||
| x1 = x[..., 0::2] # [B, S, H, D_rope/2] | ||||||||
| x2 = x[..., 1::2] # [B, S, H, D_rope/2] | ||||||||
|
|
||||||||
| # Interleave consecutive components: [-x2_0, x1_0, -x2_1, x1_1, ...] | ||||||||
| stacked = jnp.stack((-x2, x1), axis=-1) # [B, S, H, D_rope/2, 2] | ||||||||
| return stacked.reshape(x.shape) # [B, S, H, D_rope] | ||||||||
|
|
||||||||
|
|
||||||||
| def apply_rotary_pos_emb( | ||||||||
| x: jax.Array, | ||||||||
| cos: jax.Array, | ||||||||
| sin: jax.Array, | ||||||||
| unsqueeze_dim: int = 2, | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This assumes your input x will always strictly be 4D: [Batch, Sequence, Head, Dimension], I know this is very sensible assumption, but If you ever pass a single sequence without a batch dimension [S, H, D], or a 5D tensor, axis=2 will unsqueeze the wrong dimension and crash the broadcast. Please see if we can do something like this : Instead of relying on a hardcoded unsqueeze_dim |
||||||||
| ) -> jax.Array: | ||||||||
| """Applies DeepSeek-V4 interleaved RoPE to the trailing rotary slice of x. | ||||||||
|
|
||||||||
| 1. Duplicates inverse frequencies consecutively using jnp.repeat along the | ||||||||
| last dimension to match the full rotary dimension size. | ||||||||
| 2. Extracts the trailing 'rope_dim' channels of x to apply rotation, leaving | ||||||||
| the leading 'nope' channels unmodified. | ||||||||
| 3. Computes the rotation using float32 precision for numerical stability, | ||||||||
| casting the final rotated tensor back to the input data type. | ||||||||
| """ | ||||||||
| # cos/sin shape: [B, S, D_rope/2] | ||||||||
| # Duplicate frequencies consecutively to build full D_rope dimension | ||||||||
| cos = jnp.repeat(cos, 2, axis=-1) # [B, S, D_rope] | ||||||||
| sin = jnp.repeat(sin, 2, axis=-1) # [B, S, D_rope] | ||||||||
|
|
||||||||
| # Expand dimensions for head broadcasting: [B, S, 1, D_rope] | ||||||||
| cos = jnp.expand_dims(cos, axis=unsqueeze_dim) | ||||||||
| sin = jnp.expand_dims(sin, axis=unsqueeze_dim) | ||||||||
|
|
||||||||
| rope_dim = cos.shape[-1] | ||||||||
|
|
||||||||
| # Separate features into unrotated (nope) and rotated (rope) slices | ||||||||
| # x: [B, S, H, D] where D is the head dimension | ||||||||
| nope = x[..., :-rope_dim] # [B, S, H, D - D_rope] | ||||||||
| rope = x[..., -rope_dim:] # [B, S, H, D_rope] | ||||||||
|
|
||||||||
| # Cast to float32, compute rotation, and cast back to original data type | ||||||||
| rope_f32 = rope.astype(jnp.float32) | ||||||||
| rotated = (rope_f32 * cos) + (_rotate_half(rope_f32) * sin) | ||||||||
| rotated = rotated.astype(x.dtype) | ||||||||
|
|
||||||||
| # Concatenate unrotated and rotated channels | ||||||||
| return jnp.concatenate([nope, rotated], axis=-1) # [B, S, H, D] | ||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -567,3 +567,66 @@ def mlp_block( | |||||||||||||||||||||||||||||||||||||||||
| abstract_init=False, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| return module | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| class DeepSeekGroupedLinear(nnx.Module): | ||||||||||||||||||||||||||||||||||||||||||
| """Block-diagonal grouped linear projection layer. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| This layer segments the trailing dimension of the input tensor into a specified | ||||||||||||||||||||||||||||||||||||||||||
| number of groups, and projects each group independently using a distinct weight | ||||||||||||||||||||||||||||||||||||||||||
| matrix block. It minimizes parameter counts and compute overhead in the | ||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🟢 For consistency with other DeepSeek-V4 layers like `DeepSeekV4RMSNorm` and `DeepSeekV4RotaryEmbedding`, consider renaming this class to `DeepSeekV4GroupedLinear`.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
| attention output projection. | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def __init__( | ||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||
| in_features_per_group: int, | ||||||||||||||||||||||||||||||||||||||||||
| out_features: int, | ||||||||||||||||||||||||||||||||||||||||||
| n_groups: int, | ||||||||||||||||||||||||||||||||||||||||||
| weight_dtype: DType = jnp.float32, | ||||||||||||||||||||||||||||||||||||||||||
| dtype: DType = jnp.float32, | ||||||||||||||||||||||||||||||||||||||||||
| kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"), | ||||||||||||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||||||||||||
| rngs: nnx.Rngs, | ||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||
| self.in_features_per_group = in_features_per_group | ||||||||||||||||||||||||||||||||||||||||||
| self.out_features = out_features | ||||||||||||||||||||||||||||||||||||||||||
| self.n_groups = n_groups | ||||||||||||||||||||||||||||||||||||||||||
| self.weight_dtype = weight_dtype | ||||||||||||||||||||||||||||||||||||||||||
| self.dtype = dtype | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Validate divisibility of target output features by group count | ||||||||||||||||||||||||||||||||||||||||||
| if out_features % n_groups != 0: | ||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Output features ({out_features}) must be divisible by n_groups ({n_groups}).") | ||||||||||||||||||||||||||||||||||||||||||
| self.out_features_per_group = out_features // n_groups | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Grouped block-diagonal projection kernel parameters | ||||||||||||||||||||||||||||||||||||||||||
| # Kernels are stored as a 3D tensor: [n_groups, in_features_per_group, out_features_per_group] | ||||||||||||||||||||||||||||||||||||||||||
| kernel_shape = (n_groups, in_features_per_group, self.out_features_per_group) | ||||||||||||||||||||||||||||||||||||||||||
| self.weight = nnx.Param( | ||||||||||||||||||||||||||||||||||||||||||
| kernel_init( | ||||||||||||||||||||||||||||||||||||||||||
| rngs.params(), | ||||||||||||||||||||||||||||||||||||||||||
| kernel_shape, | ||||||||||||||||||||||||||||||||||||||||||
| self.weight_dtype, | ||||||||||||||||||||||||||||||||||||||||||
| in_axis=1, | ||||||||||||||||||||||||||||||||||||||||||
| out_axis=2, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+605
to
+614
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🟠 This layer is missing `sharding` support in its `nnx.Param` initialization. In MaxText, it's critical to provide explicit sharding constraints for parameters to ensure they are correctly partitioned across the device mesh. Consider adding a `kernel_axes` parameter to `__init__` and passing it to the `nnx.Param` call.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def __call__(self, x: jnp.ndarray) -> jnp.ndarray: | ||||||||||||||||||||||||||||||||||||||||||
| """Projects segmented groups from the input tensor using block weight matrices. | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||
| x: Input tensor of shape [..., n_groups, in_features_per_group] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||
| Projected tensor of shape [..., n_groups, out_features_per_group] | ||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||
| x = jnp.asarray(x, self.dtype) | ||||||||||||||||||||||||||||||||||||||||||
| weight = jnp.asarray(self.weight[...], self.dtype) | ||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change this to weight = jnp.asarray(self.weight.value, self.dtype)? for consistency with normalizations.py? |
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Execute parallel group projection via optimized einsum broadcasting. | ||||||||||||||||||||||||||||||||||||||||||
| # x: [..., g, i] | ||||||||||||||||||||||||||||||||||||||||||
| # weight: [g, i, o] | ||||||||||||||||||||||||||||||||||||||||||
| # output: [..., g, o] | ||||||||||||||||||||||||||||||||||||||||||
| return jnp.einsum("...gi,gio->...go", x, weight) | ||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -240,3 +240,62 @@ def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array: | |||||||||||
| scale_init=linen_initializers.zeros, | ||||||||||||
| scale_offset=1.0, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class DeepSeekV4RMSNorm(nnx.Module): | ||||||||||||
| """RMS normalization for DeepSeek-V4 (equivalent to T5LayerNorm).""" | ||||||||||||
|
|
||||||||||||
| def __init__( | ||||||||||||
| self, | ||||||||||||
| hidden_size: int, | ||||||||||||
| eps: float = 1e-6, | ||||||||||||
| dtype: Any = jnp.float32, | ||||||||||||
| weight_dtype: Any = jnp.float32, | ||||||||||||
| ): | ||||||||||||
| self.hidden_size = hidden_size | ||||||||||||
| self.eps = eps | ||||||||||||
| self.dtype = dtype | ||||||||||||
| self.weight_dtype = weight_dtype | ||||||||||||
|
|
||||||||||||
| # Initialize learnable scale weight to ones matching T5LayerNorm behavior | ||||||||||||
| self.weight = nnx.Param(jnp.ones((hidden_size,), dtype=weight_dtype)) | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
🟠 Like the grouped linear layer, this normalization layer is missing `sharding` support for its weight parameter. To maintain consistency with the base `RMSNorm` and ensure proper sharding in large-scale training, consider adding `kernel_axes` to the constructor and passing it to `nnx.Param`.
Suggested change
|
||||||||||||
|
|
||||||||||||
| def __call__(self, x: jnp.ndarray) -> jnp.ndarray: | ||||||||||||
| # [B, S, D] where D = hidden_size | ||||||||||||
| # Convert inputs to float32 for numerical stability during variance pooling | ||||||||||||
| x_f32 = jnp.asarray(x, jnp.float32) # [B, S, D] in float32 | ||||||||||||
|
|
||||||||||||
| # Calculate variance across features axis | ||||||||||||
| variance = jnp.mean(lax.square(x_f32), axis=-1, keepdims=True) # [B, S, 1] | ||||||||||||
|
|
||||||||||||
| # Apply reciprocal square root with epsilon offset | ||||||||||||
| normalized = x_f32 * lax.rsqrt(variance + self.eps) # [B, S, D] | ||||||||||||
|
|
||||||||||||
| # Cast back to active precision and apply scaling weight | ||||||||||||
| y = jnp.asarray(normalized, self.dtype) # [B, S, D] | ||||||||||||
| weight = jnp.asarray(self.weight.get_value(), self.dtype) # [D] | ||||||||||||
| return y * weight # [B, S, D] | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class DeepSeekV4UnweightedRMSNorm(nnx.Module): | ||||||||||||
| """Unweighted RMS normalization for DeepSeek-V4.""" | ||||||||||||
|
|
||||||||||||
| def __init__( | ||||||||||||
| self, | ||||||||||||
| eps: float = 1e-6, | ||||||||||||
| dtype: Any = jnp.float32, | ||||||||||||
| ): | ||||||||||||
| self.eps = eps | ||||||||||||
| self.dtype = dtype | ||||||||||||
|
|
||||||||||||
| def __call__(self, x: jnp.ndarray) -> jnp.ndarray: | ||||||||||||
| # [..., D] where D is feature dimension | ||||||||||||
| # Convert inputs to float32 for numerical stability during variance pooling | ||||||||||||
| x_f32 = jnp.asarray(x, jnp.float32) # [..., D] in float32 | ||||||||||||
|
|
||||||||||||
| # Calculate variance across features axis | ||||||||||||
| variance = jnp.mean(lax.square(x_f32), axis=-1, keepdims=True) # [..., 1] | ||||||||||||
|
|
||||||||||||
| # Apply reciprocal square root and cast back to active precision | ||||||||||||
| normalized = x_f32 * lax.rsqrt(variance + self.eps) # [..., D] | ||||||||||||
| return jnp.asarray(normalized, self.dtype) # [..., D] | ||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it make sense to replace these three lines with?
freqs = jnp.einsum('bs,d->bsd', position_ids.astype(jnp.float32), self.inv_freq)