Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions src/maxtext/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import dataclasses
import math
from typing import Any

import jax
from jax import lax
Expand Down Expand Up @@ -1800,3 +1801,115 @@ def qwen3_omni_mrope_embedding_as_linen(
metadata_fn=variable_to_logically_partitioned,
name=name,
)


class DeepSeekV4RotaryEmbedding(nnx.Module):
  """DeepSeek-V4 partial rotary embedding with interleaved frequencies.

  DeepSeek-V4 uses an interleaved positional encoding where consecutive channels
  are paired together. Unlike standard rotary models that split dimensions
  globally into first and second halves, this implementation pairs each even
  channel 2i with the corresponding odd channel 2i + 1.

  This results in two specific mathematical properties:
    1. Inverse frequencies are computed for (dim // 2) unique theta angles.
    2. Sinusoidal components are expanded consecutively (e.g., [f0, f0, f1, f1])
       prior to application. This module returns the unexpanded tables of width
       dim // 2; the expansion is done where the rotation is applied.

  Attributes:
    head_dim: Full per-head feature dimension.
    partial_rotary_factor: Fraction of head_dim that is rotated.
    rope_theta: Base of the geometric frequency progression.
    dtype: Data type of the cos/sin tables returned by __call__.
    dim: Number of rotated channels, int(head_dim * partial_rotary_factor).
    inv_freq: float32 array of shape [dim // 2] with the inverse frequencies.
  """

  def __init__(
      self,
      head_dim: int,
      partial_rotary_factor: float = 64.0 / 512.0,
      rope_theta: float = 10000.0,
      dtype: Any = jnp.float32,
  ):
    self.head_dim = head_dim
    self.partial_rotary_factor = partial_rotary_factor
    self.rope_theta = rope_theta
    self.dtype = dtype

    # Compute the partial rotary dimension (rope_head_dim),
    # e.g., 512 * (64 / 512) = 64 channels. int() truncates toward zero.
    self.dim = int(head_dim * partial_rotary_factor)

    # Compute base inverse frequencies for half of self.dim (dim // 2 unique
    # theta angles). Adjacent channels share the same base frequency.
    half_dim = self.dim // 2
    fraction = 2 * jnp.arange(0, half_dim, dtype=jnp.float32) / self.dim
    self.inv_freq = 1.0 / (self.rope_theta**fraction)

  def __call__(self, x: jnp.ndarray, position_ids: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Computes the cos/sin tables for the given positions.

    Args:
      x: Activations the tables will later be applied to. Accepted for
        interface compatibility only; neither its values nor its dtype are used.
      position_ids: Token positions of shape [..., S] (typically [B, S]).

    Returns:
      A (cos, sin) tuple, each of shape position_ids.shape + (dim // 2,) and of
      type self.dtype.
    """
    del x  # Unused: the table precision is governed by self.dtype, not by x.

    # Outer product of positions and inverse frequencies via broadcasting:
    # [..., S, 1] * [dim/2] -> [..., S, dim/2], evaluated in float32.
    positions = position_ids[..., jnp.newaxis].astype(jnp.float32)
    freqs = positions * self.inv_freq

    # Keep the sinusoids in self.dtype (float32 by default) rather than the
    # activation dtype, so a low-precision x does not truncate the tables
    # before the rotation is computed.
    cos = jnp.cos(freqs).astype(self.dtype)  # [..., S, dim/2]
    sin = jnp.sin(freqs).astype(self.dtype)  # [..., S, dim/2]

    return cos, sin


def _rotate_half(x: jax.Array) -> jax.Array:
  """Rotates adjacent channel pairs to match the DeepSeek-V4 interleaved layout.

  Every consecutive pair (x[2i], x[2i + 1]) along the last axis is mapped to
  (-x[2i + 1], x[2i]), e.g. [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].

  Args:
    x: Array whose last axis has even length, e.g. [B, S, H, D_rope].

  Returns:
    Array with the same shape as x holding the rotated pairs.

  Raises:
    ValueError: If the last axis of x has odd length, since the channels can
      then not be grouped into pairs.
  """
  last_dim = x.shape[-1]
  if last_dim % 2 != 0:
    raise ValueError(f"_rotate_half requires an even trailing dimension to pair channels, got {last_dim}.")

  even = x[..., 0::2]  # [..., D_rope/2], the first element of every pair
  odd = x[..., 1::2]  # [..., D_rope/2], the second element of every pair

  # Stacking on a new trailing axis yields [..., D_rope/2, 2]; flattening that
  # axis back re-interleaves the result as [-odd_0, even_0, -odd_1, even_1, ...].
  return jnp.stack((-odd, even), axis=-1).reshape(x.shape)


def apply_rotary_pos_emb(
    x: jax.Array,
    cos: jax.Array,
    sin: jax.Array,
    unsqueeze_dim: int = -2,
) -> jax.Array:
  """Applies DeepSeek-V4 interleaved RoPE to the trailing rotary slice of x.

  1. Duplicates every cos/sin entry consecutively using jnp.repeat along the
     last dimension to match the full rotary dimension size.
  2. Extracts the trailing 'rope_dim' channels of x to apply rotation, leaving
     the leading 'nope' channels unmodified.
  3. Computes the rotation using float32 precision for numerical stability,
     casting the final rotated tensor back to the input data type.

  Args:
    x: Input of shape [..., H, D] (typically [B, S, H, D]), D being the head
      dimension.
    cos: Cosine table of shape [..., D_rope/2] (typically [B, S, D_rope/2]).
    sin: Sine table with the same shape as cos.
    unsqueeze_dim: Axis at which a singleton head axis is inserted into cos/sin
      so they broadcast against x. The default of -2 inserts it directly before
      the feature axis; for [B, S, D_rope] tables this is the same as axis 2,
      and it also works for unbatched or higher-rank inputs.

  Returns:
    Array with the shape and dtype of x whose trailing D_rope channels are
    rotated and whose leading D - D_rope channels are copied unchanged.

  Raises:
    ValueError: If the rotary width D_rope exceeds the last dimension of x.
  """
  # cos/sin shape: [..., D_rope/2]
  # Duplicate entries consecutively to build the full D_rope dimension, in
  # float32 so the rotation below is evaluated in float32 throughout.
  cos = jnp.repeat(jnp.asarray(cos, jnp.float32), 2, axis=-1)  # [..., D_rope]
  sin = jnp.repeat(jnp.asarray(sin, jnp.float32), 2, axis=-1)  # [..., D_rope]

  # Read the rotary width before inserting the head axis so it is independent
  # of where that axis is placed.
  rope_dim = cos.shape[-1]
  head_dim = x.shape[-1]
  if rope_dim > head_dim:
    raise ValueError(f"Rotary dimension ({rope_dim}) must not exceed the last dimension of x ({head_dim}).")

  # Expand dimensions for head broadcasting: [..., 1, D_rope]
  cos = jnp.expand_dims(cos, axis=unsqueeze_dim)
  sin = jnp.expand_dims(sin, axis=unsqueeze_dim)

  # Separate features into unrotated (nope) and rotated (rope) slices.
  # An explicit split index is used instead of negative slicing because
  # x[..., :-0] is empty and x[..., -0:] is everything when rope_dim == 0.
  split = head_dim - rope_dim
  nope = x[..., :split]  # [..., H, D - D_rope]
  rope = x[..., split:]  # [..., H, D_rope]

  # Cast to float32, compute rotation, and cast back to original data type
  rope_f32 = rope.astype(jnp.float32)
  rotated = (rope_f32 * cos) + (_rotate_half(rope_f32) * sin)
  rotated = rotated.astype(x.dtype)

  # Concatenate unrotated and rotated channels
  return jnp.concatenate([nope, rotated], axis=-1)  # [..., H, D]
63 changes: 63 additions & 0 deletions src/maxtext/layers/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,3 +567,66 @@ def mlp_block(
abstract_init=False,
)
return module


class DeepSeekGroupedLinear(nnx.Module):
  """Block-diagonal grouped linear projection layer.

  This layer segments the trailing dimension of the input tensor into a specified
  number of groups, and projects each group independently using a distinct weight
  matrix block. It minimizes parameter counts and compute overhead in the
  attention output projection.

  Attributes:
    in_features_per_group: Input feature count of every group.
    out_features: Total output feature count across all groups.
    n_groups: Number of independent projection groups.
    out_features_per_group: out_features // n_groups.
    weight_dtype: Data type the kernel is initialized in.
    dtype: Data type the projection is computed in.
    kernel_axes: Optional sharding annotation attached to the kernel.
    weight: Kernel of shape
      [n_groups, in_features_per_group, out_features_per_group].
  """

  def __init__(
      self,
      in_features_per_group: int,
      out_features: int,
      n_groups: int,
      weight_dtype: DType = jnp.float32,
      dtype: DType = jnp.float32,
      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
      *,
      rngs: nnx.Rngs,
      kernel_axes: tuple[str | None, ...] | None = None,
  ):
    """Initializes the grouped projection kernel.

    Args:
      in_features_per_group: Input feature count of every group.
      out_features: Total output feature count; must be divisible by n_groups.
      n_groups: Number of groups; must be positive.
      weight_dtype: Data type passed to kernel_init.
      dtype: Data type inputs and kernel are cast to in __call__.
      kernel_init: Initializer called with (key, shape, dtype, in_axis, out_axis).
      rngs: Random number generators; rngs.params() seeds the kernel.
      kernel_axes: Optional axis names, one per kernel dimension
        (group, in, out), stored as the `sharding` metadata of the parameter.
        When None, no sharding metadata is attached.

    Raises:
      ValueError: If n_groups is not positive or does not divide out_features.
    """
    self.in_features_per_group = in_features_per_group
    self.out_features = out_features
    self.n_groups = n_groups
    self.weight_dtype = weight_dtype
    self.dtype = dtype
    self.kernel_axes = kernel_axes

    # Reject a non-positive group count up front; otherwise the divisibility
    # check below would fail with an unrelated ZeroDivisionError.
    if n_groups <= 0:
      raise ValueError(f"n_groups must be a positive integer, got {n_groups}.")

    # Validate divisibility of target output features by group count
    if out_features % n_groups != 0:
      raise ValueError(f"Output features ({out_features}) must be divisible by n_groups ({n_groups}).")
    self.out_features_per_group = out_features // n_groups

    # Grouped block-diagonal projection kernel parameters
    # Kernels are stored as a 3D tensor: [n_groups, in_features_per_group, out_features_per_group]
    kernel_shape = (n_groups, in_features_per_group, self.out_features_per_group)

    # Only attach sharding metadata when the caller asked for it, so the
    # parameter is unchanged for existing call sites.
    sharding_metadata = {} if kernel_axes is None else {"sharding": kernel_axes}
    self.weight = nnx.Param(
        kernel_init(
            rngs.params(),
            kernel_shape,
            self.weight_dtype,
            in_axis=1,
            out_axis=2,
        ),
        **sharding_metadata,
    )

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Projects segmented groups from the input tensor using block weight matrices.

    Args:
      x: Input tensor of shape [..., n_groups, in_features_per_group]

    Returns:
      Projected tensor of shape [..., n_groups, out_features_per_group]
    """
    x = jnp.asarray(x, self.dtype)
    weight = jnp.asarray(self.weight[...], self.dtype)

    # Execute parallel group projection via einsum broadcasting.
    # x: [..., g, i]
    # weight: [g, i, o]
    # output: [..., g, o]
    return jnp.einsum("...gi,gio->...go", x, weight)
59 changes: 59 additions & 0 deletions src/maxtext/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,62 @@ def l2norm(x: Array, dim: int = -1, eps: float = 1e-6) -> Array:
scale_init=linen_initializers.zeros,
scale_offset=1.0,
)


class DeepSeekV4RMSNorm(nnx.Module):
  """RMS normalization for DeepSeek-V4 (equivalent to T5LayerNorm).

  The input is divided by the root mean square of its last axis (computed in
  float32) and then multiplied by a learnable per-feature scale.

  Attributes:
    hidden_size: Size of the normalized (last) axis.
    eps: Offset added to the mean square before the reciprocal square root.
    dtype: Data type of the output and of the scale multiplication.
    weight_dtype: Data type the scale parameter is stored in.
    kernel_axes: Optional sharding annotation attached to the scale.
    weight: Learnable scale of shape [hidden_size], initialized to ones.
  """

  def __init__(
      self,
      hidden_size: int,
      eps: float = 1e-6,
      dtype: Any = jnp.float32,
      weight_dtype: Any = jnp.float32,
      kernel_axes: tuple[str | None, ...] | None = None,
  ):
    """Initializes the scale parameter.

    Args:
      hidden_size: Size of the feature axis being normalized.
      eps: Offset added to the mean square for numerical stability.
      dtype: Data type of the returned activations.
      weight_dtype: Data type of the stored scale parameter.
      kernel_axes: Optional axis names for the [hidden_size] scale, stored as
        the `sharding` metadata of the parameter. When None, no sharding
        metadata is attached.
    """
    self.hidden_size = hidden_size
    self.eps = eps
    self.dtype = dtype
    self.weight_dtype = weight_dtype
    self.kernel_axes = kernel_axes

    # Initialize learnable scale weight to ones matching T5LayerNorm behavior.
    # Sharding metadata is attached only on request so that existing call
    # sites keep producing the same parameter.
    sharding_metadata = {} if kernel_axes is None else {"sharding": kernel_axes}
    self.weight = nnx.Param(
        jnp.ones((hidden_size,), dtype=weight_dtype),
        **sharding_metadata,
    )

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Normalizes x over its last axis and applies the learned scale.

    Args:
      x: Input of shape [..., D] (typically [B, S, D]) where D = hidden_size.

    Returns:
      Array of the same shape as x in self.dtype.
    """
    # Convert inputs to float32 for numerical stability during variance pooling
    x_f32 = jnp.asarray(x, jnp.float32)  # [..., D] in float32

    # Calculate variance (mean square) across the features axis
    variance = jnp.mean(lax.square(x_f32), axis=-1, keepdims=True)  # [..., 1]

    # Apply reciprocal square root with epsilon offset
    normalized = x_f32 * lax.rsqrt(variance + self.eps)  # [..., D]

    # Cast back to active precision and apply scaling weight
    y = jnp.asarray(normalized, self.dtype)  # [..., D]
    weight = jnp.asarray(self.weight.get_value(), self.dtype)  # [D]
    return y * weight  # [..., D]


class DeepSeekV4UnweightedRMSNorm(nnx.Module):
  """Scale-free RMS normalization for DeepSeek-V4 (no learnable weight).

  Attributes:
    eps: Offset added to the mean square before the reciprocal square root.
    dtype: Data type of the returned activations.
  """

  def __init__(self, eps: float = 1e-6, dtype: Any = jnp.float32):
    self.dtype = dtype
    self.eps = eps

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Divides x by the root mean square of its last axis.

    Args:
      x: Input of shape [..., D], D being the feature dimension.

    Returns:
      Array of the same shape as x in self.dtype.
    """
    # The statistics are accumulated in float32 regardless of the input dtype.
    upcast = jnp.asarray(x, jnp.float32)  # [..., D]
    mean_square = jnp.mean(lax.square(upcast), axis=-1, keepdims=True)  # [..., 1]

    # Scale by 1 / sqrt(mean_square + eps), then return to the active precision.
    rescaled = upcast * lax.rsqrt(mean_square + self.eps)  # [..., D]
    return jnp.asarray(rescaled, self.dtype)
Loading
Loading