diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 98f8e39efa..a59b07b11f 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -405,6 +405,16 @@ qk_clip_threshold: 100.0 # Threshold for clipping (tau in the paper) fused_qkv: False fused_mlp: False +# DeepSeek-V4 Compressed Attention parameters +compress_rope_theta: 160000.0 +compress_ratios: [] +index_head_dim: 128 +index_n_heads: 64 +index_topk: 512 +o_groups: 8 +o_lora_rank: 1024 +sliding_window: 128 + record_internal_nn_metrics: 0 # Output directory diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index b3261ab75d..e663bcfc5c 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -618,6 +618,22 @@ class AttentionIndexer(BaseModel): indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.") +class DeepSeekV4AttentionConfig(BaseModel): + """Configuration specific to DeepSeek-V4 stateless compressed attention layers.""" + + compress_rope_theta: float = Field(160000.0, description="Theta base frequency for long-range compressor layers.") + compress_ratios: list[int] = Field( + default_factory=list, + description="Layer-by-layer compressor rates (0: standard, 4: CSA, 128: HCA).", + ) + index_head_dim: int = Field(128, description="Head dim for indexer query and key.") + index_n_heads: int = Field(64, description="Number of query heads in indexer.") + index_topk: int = Field(512, description="Number of tokens selected by indexer.") + o_groups: int = Field(8, description="Number of group partitions for grouped linear output projection.") + o_lora_rank: int = Field(1024, description="Low-rank output dimension prior to grouped mix projection.") + sliding_window: int = Field(128, description="Sliding window size for attention.") + + class Llama4Attention(BaseModel): """Configuration specific to Llama4-style models.""" @@ -2224,6 +2240,7 @@ class MaxTextConfig( MlaAttention, MoBa, AttentionIndexer, + DeepSeekV4AttentionConfig, Llama4Attention, SplashAttention, PagedAttention, diff --git a/src/maxtext/layers/attention_compressed.py b/src/maxtext/layers/attention_compressed.py new file mode 100644 index 0000000000..eba547d68f --- /dev/null +++ b/src/maxtext/layers/attention_compressed.py @@ -0,0 +1,1000 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compressed Attention layers and long-range compressors.""" + +from typing import Any +import jax +import jax.numpy as jnp +from flax import nnx +from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding, apply_rotary_pos_emb +from maxtext.layers.normalizations import DeepSeekV4RMSNorm, DeepSeekV4UnweightedRMSNorm +from maxtext.layers.linears import DeepSeekGroupedLinear + + +class HCACompressor(nnx.Module): + """Heavily Compressed Attention (HCA) long-range compressor layer. + + This layer groups sequence features into non-overlapping windows of size 'compress_rate', + applies learnable pooling gates combined with static positional bias, averages the features + inside each window to emit a single compressed representation per window, and rotates the + resulting compressed sequence using interleaved rotary embeddings. + """ + + def __init__( + self, + hidden_size: int, + head_dim: int, + config: Any, + layer_idx: int, + eps: float = 1e-6, + weight_dtype: Any = jnp.float32, + dtype: Any = jnp.float32, + *, + rngs: nnx.Rngs, + ): + """Initializes the Heavily Compressed Attention (HCA) long-range compressor. + + Args: + hidden_size: The model's global hidden dimension size. + head_dim: The projection size of each attention key-value channel. + config: The DeepSeekV4 model configurations metadata. + layer_idx: The sequential layer depth index of this compressor in the decoder stack. + eps: The tiny additive variance limit for RMS normalization stability. + weight_dtype: The parameter weights numerical data type. + dtype: The mathematical execution numerical data type. + rngs: The standard Flax NNX random number generator collection. + """ + super().__init__() + self.compress_rate = config.compress_ratios[layer_idx] + self.head_dim = head_dim + self.hidden_size = hidden_size + self.eps = eps + self.weight_dtype = weight_dtype + self.dtype = dtype + rope_theta = config.compress_rope_theta + + # Linear projection of inputs to key/value representation + self.kv_proj = nnx.Linear( + in_features=hidden_size, + out_features=head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Linear projection of inputs to gate logits + self.gate_proj = nnx.Linear( + in_features=hidden_size, + out_features=head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Positional bias parameter added to gate logits inside each window + self.position_bias = nnx.Param( + jax.nn.initializers.normal(stddev=0.02)( + rngs.params(), + (self.compress_rate, head_dim), + weight_dtype, + ) + ) + + # RMS normalization applied to pooled window features + self.kv_norm = DeepSeekV4RMSNorm( + hidden_size=head_dim, + eps=eps, + dtype=dtype, + weight_dtype=weight_dtype, + ) + + # Interleaved rotary embeddings applied to the trailing slice + self.rotary_emb = DeepSeekV4RotaryEmbedding( + head_dim=head_dim, + partial_rotary_factor=64.0 / 512.0, + rope_theta=rope_theta, + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + q_residual: Any = None, + position_ids: jnp.ndarray = None, + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: + """Applies Heavily Compressed Attention (HCA) compression to sequence keys and values. + + This method splits the sequence into non-overlapping windows of size 'compress_rate', + aggregates feature representation vectors using Softmax-weighted gates, normalizes the + resulting vectors using RMS norm, applies position-aware interleaved rotary embeddings, + and expands the output dimension to match standard multi-head key-value layouts. + + Args: + hidden_states: The input hidden representation sequence of shape [B, S, D_model]. + q_residual: Ignored optional placeholder matching polymorphic calling conventions. + position_ids: Optional position indicators of shape [B, S]. + + Returns: + Compressed, position-encoded representation tensor of shape [B, 1, W, D_head], + where W is the compressed sequence length equal to S // compress_rate. + """ + # hidden_states: [B, S, D_model] + # position_ids: [B, S] + batch, seq_len, _ = hidden_states.shape + + # Project inputs to key/value and gate representations + # kv: [B, S, D_head] + # gate: [B, S, D_head] + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + + # Compute sequence multiple bound corresponding to the window stride rate + # usable: scalar integer + usable = (seq_len // self.compress_rate) * self.compress_rate + n_windows = usable // self.compress_rate + + # Slice sequences to match clean multiple dimensions + # chunk_kv: [B, S_usable, D_head] + # chunk_gate: [B, S_usable, D_head] + chunk_kv = kv[:, :usable, :] + chunk_gate = gate[:, :usable, :] + + # Reshape sliced inputs into non-overlapping windows of size 'compress_rate' + # chunk_kv: [B, W, compress_rate, D_head] + # chunk_gate: [B, W, compress_rate, D_head] + chunk_kv = chunk_kv.reshape(batch, n_windows, self.compress_rate, self.head_dim) + chunk_gate = chunk_gate.reshape(batch, n_windows, self.compress_rate, self.head_dim) + + # Add positional bias parameters to gate logits + # chunk_gate: [B, W, compress_rate, D_head] + position_bias = jnp.asarray(self.position_bias[...], self.dtype) + chunk_gate = chunk_gate + position_bias[jnp.newaxis, jnp.newaxis, :, :] + + # Compute softmax aggregation probabilities in float32 for stability + # gate_softmax: [B, W, compress_rate, D_head] + gate_softmax = jax.nn.softmax(chunk_gate.astype(jnp.float32), axis=2).astype(self.dtype) + + # Aggregate key/value features using computed gate weights + # pooled: [B, W, D_head] + pooled = jnp.sum(chunk_kv * gate_softmax, axis=2) + + # Normalize aggregated window features + # compressed: [B, W, D_head] + compressed = self.kv_norm(pooled) + + # Determine absolute sequence indexes corresponding to each window start + # positions: [B, W] + positions = jnp.arange(n_windows, dtype=jnp.int32) * self.compress_rate + positions = jnp.broadcast_to(positions[jnp.newaxis, :], (batch, n_windows)) + + # Compute interleaved rotary embeddings sine and cosine values + # cos: [B, W, D_rope/2] + # sin: [B, W, D_rope/2] + cos, sin = self.rotary_emb(compressed, positions) + + # Expand dimensions to allow broadcasting over head axis during rotary mapping + # compressed_4d: [B, W, 1, D_head] + compressed_4d = jnp.expand_dims(compressed, axis=2) + + # Apply interleaved RoPE rotation over the trailing slice + # rotated_4d: [B, W, 1, D_head] + rotated_4d = apply_rotary_pos_emb(compressed_4d, cos, sin, unsqueeze_dim=2) + + # Squeeze dummy head dimension to recover standard 3D shape layout + # rotated: [B, W, D_head] + rotated = jnp.squeeze(rotated_4d, axis=2) + + # Expand output format to match standard multi-head key/value dimensions + # compressed_kv: [B, 1, W, D_head] + compressed_kv = jnp.expand_dims(rotated, axis=1) + + # Evaluate caching dimensions boundary checks to prevent empty execution + compressed_len = n_windows + if seq_len == 1 or compressed_len == 0: + return compressed_kv, None + + # Compute causal block bias mask over compressed sequence segments to prevent query leakage. + # A query at sequence position `t` is restricted from attending to any compressed cache block + # index `w` if `t <= w * compress_rate`. This represents future sequence information that is + # mathematically unavailable at position `t`. + # + # entry_indices: [W] representing compressed block window positions + entry_indices = jnp.arange(compressed_len, dtype=jnp.int32) + # causal_threshold: [B, S] representing ready block count boundaries per sequence token + causal_threshold = (position_ids + 1) // self.compress_rate + # Construct sequence-level causal future mask via dimension broadcasting. + # future_mask: [B, 1, S, W] + future_mask = ( + entry_indices[jnp.newaxis, jnp.newaxis, jnp.newaxis, :] >= causal_threshold[:, jnp.newaxis, :, jnp.newaxis] + ) + # Initialize causal block bias containing -inf mask values for invalid future elements. + # block_bias: [B, 1, S, W] + block_bias = jnp.where(future_mask, -jnp.inf, 0.0) + return compressed_kv, block_bias + + +class DeepSeekV4Indexer(nnx.Module): + """Lightning Indexer (paper §2.3.1, eqs. 13–17). + + Used by Compressed Sparse Attention (CSA) to pick the top-k compressed KV + blocks per query. + """ + + def __init__( + self, + hidden_size: int, + q_lora_rank: int, + config: Any, + layer_idx: int, + eps: float = 1e-6, + weight_dtype: Any = jnp.float32, + dtype: Any = jnp.float32, + *, + rngs: nnx.Rngs, + ): + """Initializes the Lightning Indexer. + + Args: + hidden_size: The model's global hidden dimension size. + q_lora_rank: The projection rank dimension of Q LoRA. + config: The DeepSeekV4 model configurations metadata. + layer_idx: The decoder stack layer index containing this indexer. + eps: Tiny additive variance limit for RMS normalization stability. + weight_dtype: The parameter weights numerical data type. + dtype: The mathematical execution numerical data type. + rngs: The Flax NNX random number generator collection. + """ + super().__init__() + self.compress_rate = config.compress_ratios[layer_idx] + self.num_heads = config.index_n_heads + self.head_dim = config.index_head_dim + self.index_topk = config.index_topk + self.softmax_scale = config.index_head_dim**-0.5 + self.weights_scaling = config.index_n_heads**-0.5 + self.dtype = dtype + self.weight_dtype = weight_dtype + rope_theta = config.compress_rope_theta + + # Key projections for indexing-scale compression + self.kv_proj = nnx.Linear( + in_features=hidden_size, + out_features=2 * self.head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Gate projections for indexing-scale compression + self.gate_proj = nnx.Linear( + in_features=hidden_size, + out_features=2 * self.head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Positional bias parameters inside indexing windows + self.position_bias = nnx.Param( + jax.nn.initializers.normal(stddev=0.02)( + rngs.params(), + (self.compress_rate, 2 * self.head_dim), + weight_dtype, + ) + ) + + # RMS normalization for indexer key values + self.kv_norm = DeepSeekV4RMSNorm( + hidden_size=self.head_dim, + eps=eps, + dtype=dtype, + weight_dtype=weight_dtype, + ) + + # Query projection mapping Q LoRA rank to multi-head indexing features + self.q_b_proj = nnx.Linear( + in_features=q_lora_rank, + out_features=self.num_heads * self.head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Dynamic score scaling projection + self.weights_proj = nnx.Linear( + in_features=hidden_size, + out_features=self.num_heads, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Interleaved rotary embedding aligning query/key pos representations + self.rotary_emb = DeepSeekV4RotaryEmbedding( + head_dim=self.head_dim, + partial_rotary_factor=(config.head_dim * (64.0 / 512.0)) / self.head_dim, + rope_theta=rope_theta, + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + q_residual: jnp.ndarray, + position_ids: jnp.ndarray, + ) -> jnp.ndarray: + """Computes top-k relevant compressed block indices per query position. + + This method compresses sequence keys and values into overlapping window + segments, applies position-aware RoPE encoding, projects incoming query residuals + into alignment spaces, computes similarity matrices across query positions and + windows, dynamically scales/weights scores using projected head scaling arrays, + and selects the top-k windows using JAX optimized top_k primitives. + + Args: + hidden_states: The input sequence representations of shape [B, S, D_model]. + q_residual: The Q LoRA low-rank query projections of shape [B, S, D_rank]. + position_ids: The sequence absolute position identifiers of shape [B, S]. + + Returns: + Integer index array of shape [B, S, k] containing the gathered top-k + compressed window indices for each query position, where k = index_topk. + """ + # hidden_states: [B, S, D_model] + # q_residual: [B, S, D_rank] + # position_ids: [B, S] + batch, seq_len, _ = hidden_states.shape + + # Project inputs to index keys and gates + # kv: [B, S, 2 * D_idx] + # gate: [B, S, 2 * D_idx] + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + + # Calculate sequence bounds matching the stride rate + # usable: scalar integer + usable = (seq_len // self.compress_rate) * self.compress_rate + n_windows = usable // self.compress_rate + + # Slice sequences to valid sequence bounds + # chunk_kv: [B, S_usable, 2 * D_idx] + # chunk_gate: [B, S_usable, 2 * D_idx] + chunk_kv = kv[:, :usable, :] + chunk_gate = gate[:, :usable, :] + + # Segment sliced elements into non-overlapping windows + # chunk_kv: [B, W, compress_rate, 2 * D_idx] + # chunk_gate: [B, W, compress_rate, 2 * D_idx] + chunk_kv = chunk_kv.reshape(batch, n_windows, self.compress_rate, 2 * self.head_dim) + chunk_gate = chunk_gate.reshape(batch, n_windows, self.compress_rate, 2 * self.head_dim) + + # Incorporate static positional bias parameters + # chunk_gate: [B, W, compress_rate, 2 * D_idx] + position_bias = jnp.asarray(self.position_bias[...], self.dtype) + chunk_gate = chunk_gate + position_bias[jnp.newaxis, jnp.newaxis, :, :] + + # Overlap slicing setups: segment into Ca / Cb series + # prev_kv: [B, W, compress_rate, D_idx] (Ca) + # curr_kv: [B, W, compress_rate, D_idx] (Cb) + # prev_gate: [B, W, compress_rate, D_idx] (Ca) + # curr_gate: [B, W, compress_rate, D_idx] (Cb) + prev_kv = chunk_kv[..., : self.head_dim] + curr_kv = chunk_kv[..., self.head_dim :] + prev_gate = chunk_gate[..., : self.head_dim] + curr_gate = chunk_gate[..., self.head_dim :] + + # Set up combined padded layouts for boundary window overlap calculations + # new_kv: [B, W, 2 * compress_rate, D_idx] + # new_gate: [B, W, 2 * compress_rate, D_idx] + new_kv = jnp.zeros((batch, n_windows, 2 * self.compress_rate, self.head_dim), dtype=self.dtype) + new_gate = jnp.full((batch, n_windows, 2 * self.compress_rate, self.head_dim), -jnp.inf, dtype=self.dtype) + + # Map Cb representations into second half slots + new_kv = new_kv.at[:, :, self.compress_rate :].set(curr_kv) + new_gate = new_gate.at[:, :, self.compress_rate :].set(curr_gate) + + # Map Ca representations of preceding windows into first half slots + if n_windows > 1: + new_kv = new_kv.at[:, 1:, : self.compress_rate].set(prev_kv[:, :-1, :, :]) + new_gate = new_gate.at[:, 1:, : self.compress_rate].set(prev_gate[:, :-1, :, :]) + + # Aggregate indexing features using gate softmax probabilities computed in float32 + # gate_softmax: [B, W, 2 * compress_rate, D_idx] + gate_softmax = jax.nn.softmax(new_gate.astype(jnp.float32), axis=2).astype(self.dtype) + # pooled: [B, W, D_idx] + pooled = jnp.sum(new_kv * gate_softmax, axis=2) + + # Normalize index keys + # compressed: [B, W, D_idx] + compressed = self.kv_norm(pooled) + + # Extract absolute starting positions of index windows + # positions: [B, W] + positions = jnp.arange(n_windows, dtype=jnp.int32) * self.compress_rate + positions = jnp.broadcast_to(positions[jnp.newaxis, :], (batch, n_windows)) + + # Compute sinusoids and apply interleaved rotary embeddings + # cos: [B, W, D_idx_rope/2] + # sin: [B, W, D_idx_rope/2] + cos, sin = self.rotary_emb(compressed, positions) + # compressed_4d: [B, W, 1, D_idx] + compressed_4d = jnp.expand_dims(compressed, axis=2) + # rotated_4d: [B, W, 1, D_idx] + rotated_4d = apply_rotary_pos_emb(compressed_4d, cos, sin, unsqueeze_dim=2) + # compressed_kv: [B, W, D_idx] + compressed_kv = jnp.squeeze(rotated_4d, axis=2) + + # Project and reshape queries to multiple head alignments + # q: [B, S, H, D_idx] + q = self.q_b_proj(q_residual) + q = q.reshape(batch, seq_len, self.num_heads, self.head_dim) + + # Compute rotary components matching current query positions + # cos_q: [B, S, D_idx_rope/2] + # sin_q: [B, S, D_idx_rope/2] + cos_q, sin_q = self.rotary_emb(hidden_states, position_ids) + # Apply RoPE to query elements + # q: [B, S, H, D_idx] + q = apply_rotary_pos_emb(q, cos_q, sin_q, unsqueeze_dim=2) + + # Calculate attention alignment scores across windows + # swaped_kv: [B, 1, D_idx, W] + swaped_kv = jnp.swapaxes(compressed_kv, -1, -2) + swaped_kv = jnp.expand_dims(swaped_kv, axis=1) + # scores: [B, S, H, W] + scores = jnp.matmul(q, swaped_kv) + scores = jax.nn.relu(scores) * self.softmax_scale + + # Project and scale dynamic aggregation scoring weights + # weights: [B, S, H] + weights = self.weights_proj(hidden_states) * self.weights_scaling + # Aggregate scoring profiles over heads axis + # index_scores: [B, S, W] + index_scores = jnp.sum(scores * jnp.expand_dims(weights, axis=-1), axis=2) + + # Extract top-k scoring compressed blocks per query sequence position + # topk_indices: [B, S, k] + compressed_len = compressed_kv.shape[1] + topk_limit = min(self.index_topk, compressed_len) + + if compressed_len > 0: + # Compute sequence-level causal ready block counts. + # causal_threshold: [B, S] + causal_threshold = (position_ids + 1) // self.compress_rate + # entry_indices: [W] + entry_indices = jnp.arange(compressed_len, dtype=jnp.int32) + # Construct query-specific causal mask along compressed index dimension. + # future_mask: [B, S, W] + future_mask = entry_indices[jnp.newaxis, jnp.newaxis, :] >= causal_threshold[:, :, jnp.newaxis] + # Zero-out future block scores by masking them with -inf prior to top-k calculations. + # index_scores: [B, S, W] + index_scores = jnp.where(future_mask, -jnp.inf, index_scores) + # Select top-k indices per token position based on masked scores. + # topk_indices: [B, S, k] + _, topk_indices = jax.lax.top_k(index_scores, topk_limit) + # Early tokens with too few ready blocks will still have invalid top-k selections pointing + # to future blocks. Detect them and replace with a `-1` sentinel. + # invalid: [B, S, k] + invalid = topk_indices >= causal_threshold[:, :, jnp.newaxis] + topk_indices = jnp.where(invalid, -1, topk_indices) + return topk_indices + + # Fallback stateless default top-k select path + _, topk_indices = jax.lax.top_k(index_scores, topk_limit) + return topk_indices + + +class CSACompressor(nnx.Module): + """Compressed Sparse Attention (CSA) compressor layer. + + This layer aggregates token representations into overlapping Ca/Cb window segments, + normalizes/rotates them, and uses the DeepSeekV4Indexer to gather the top-k + relevant compressed KV blocks per query. + """ + + def __init__( + self, + hidden_size: int, + q_lora_rank: int, + head_dim: int, + config: Any, + layer_idx: int, + eps: float = 1e-6, + weight_dtype: Any = jnp.float32, + dtype: Any = jnp.float32, + *, + rngs: nnx.Rngs, + ): + """Initializes the Compressed Sparse Attention (CSA) compressor. + + Args: + hidden_size: The model's global hidden dimension size. + q_lora_rank: The projection rank dimension of Q LoRA. + head_dim: The projection size of each attention key-value channel. + config: The DeepSeekV4 model configurations metadata. + layer_idx: The decoder stack layer index containing this compressor. + eps: Tiny additive variance limit for RMS normalization stability. + weight_dtype: The parameter weights numerical data type. + dtype: The mathematical execution numerical data type. + rngs: The Flax NNX random number generator collection. + """ + super().__init__() + self.compress_rate = config.compress_ratios[layer_idx] + self.head_dim = head_dim + self.hidden_size = hidden_size + self.eps = eps + self.weight_dtype = weight_dtype + self.dtype = dtype + rope_theta = config.compress_rope_theta + + # Projections for outer compressed key/value formats + self.kv_proj = nnx.Linear( + in_features=hidden_size, + out_features=2 * head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Projections for outer gate logits + self.gate_proj = nnx.Linear( + in_features=hidden_size, + out_features=2 * head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Static positional biases added inside windows + self.position_bias = nnx.Param( + jax.nn.initializers.normal(stddev=0.02)( + rngs.params(), + (config.compress_ratios[layer_idx], 2 * head_dim), + weight_dtype, + ) + ) + + # RMS normalization applied to aggregated representations + self.kv_norm = DeepSeekV4RMSNorm( + hidden_size=head_dim, + eps=eps, + dtype=dtype, + weight_dtype=weight_dtype, + ) + + # Interleaved rotary embeddings for compressed sequences + self.rotary_emb = DeepSeekV4RotaryEmbedding( + head_dim=head_dim, + partial_rotary_factor=64.0 / 512.0, + rope_theta=rope_theta, + ) + + # Lightning Indexer component + self.indexer = DeepSeekV4Indexer( + hidden_size=hidden_size, + q_lora_rank=q_lora_rank, + config=config, + layer_idx=layer_idx, + eps=eps, + weight_dtype=weight_dtype, + dtype=dtype, + rngs=rngs, + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + q_residual: jnp.ndarray, + position_ids: jnp.ndarray, + ) -> tuple[jnp.ndarray, jnp.ndarray]: + """Applies Compressed Sparse Attention (CSA) compression and gathers top-k blocks. + + This method compresses sequence keys and values into overlapping window + segments, applies position-aware RoPE encoding, runs the Lightning Indexer to + extract the top-k scoring window indices for each query position, executes a + high-performance TPU-efficient advanced gather, and shapes the output to match + standard multi-head key-value layouts. + + Args: + hidden_states: The input sequence representations of shape [B, S, D_model]. + q_residual: The Q LoRA low-rank query projections of shape [B, S, D_rank]. + position_ids: The sequence absolute position identifiers of shape [B, S]. + + Returns: + Position-encoded, gathered key-value representation tensor of shape + [B, 1, S * k, D_head], where k = index_topk. + """ + # hidden_states: [B, S, D_model] + # q_residual: [B, S, D_rank] + # position_ids: [B, S] + batch, seq_len, _ = hidden_states.shape + + # Project input features to key/value and gate components + # kv: [B, S, 2 * D] + # gate: [B, S, 2 * D] + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + + # Determine valid sequence bounds + # usable: scalar integer + usable = (seq_len // self.compress_rate) * self.compress_rate + n_windows = usable // self.compress_rate + + # Slice inputs to sequence bounds + # chunk_kv: [B, S_usable, 2 * D] + # chunk_gate: [B, S_usable, 2 * D] + chunk_kv = kv[:, :usable, :] + chunk_gate = gate[:, :usable, :] + + # Segment sliced elements into non-overlapping windows + # chunk_kv: [B, W, compress_rate, 2 * D] + # chunk_gate: [B, W, compress_rate, 2 * D] + chunk_kv = chunk_kv.reshape(batch, n_windows, self.compress_rate, 2 * self.head_dim) + chunk_gate = chunk_gate.reshape(batch, n_windows, self.compress_rate, 2 * self.head_dim) + + # Aggregate window gate logits with static positional biases + # chunk_gate: [B, W, compress_rate, 2 * D] + position_bias = jnp.asarray(self.position_bias[...], self.dtype) + chunk_gate = chunk_gate + position_bias[jnp.newaxis, jnp.newaxis, :, :] + + # Overlap slicing: extract Ca / Cb configurations + # prev_kv: [B, W, compress_rate, D] (Ca) + # curr_kv: [B, W, compress_rate, D] (Cb) + # prev_gate: [B, W, compress_rate, D] (Ca) + # curr_gate: [B, W, compress_rate, D] (Cb) + prev_kv = chunk_kv[..., : self.head_dim] + curr_kv = chunk_kv[..., self.head_dim :] + prev_gate = chunk_gate[..., : self.head_dim] + curr_gate = chunk_gate[..., self.head_dim :] + + # Assemble padded window layouts for overlap combination + # new_kv: [B, W, 2 * compress_rate, D] + # new_gate: [B, W, 2 * compress_rate, D] + new_kv = jnp.zeros((batch, n_windows, 2 * self.compress_rate, self.head_dim), dtype=self.dtype) + new_gate = jnp.full((batch, n_windows, 2 * self.compress_rate, self.head_dim), -jnp.inf, dtype=self.dtype) + + # Set current window representations to second half slots + new_kv = new_kv.at[:, :, self.compress_rate :].set(curr_kv) + new_gate = new_gate.at[:, :, self.compress_rate :].set(curr_gate) + + # Set previous window representations to first half slots + if n_windows > 1: + new_kv = new_kv.at[:, 1:, : self.compress_rate].set(prev_kv[:, :-1, :, :]) + new_gate = new_gate.at[:, 1:, : self.compress_rate].set(prev_gate[:, :-1, :, :]) + + # Aggregate features using window gate softmax probabilities computed in float32 + # gate_softmax: [B, W, 2 * compress_rate, D] + gate_softmax = jax.nn.softmax(new_gate.astype(jnp.float32), axis=2).astype(self.dtype) + # pooled: [B, W, D] + pooled = jnp.sum(new_kv * gate_softmax, axis=2) + + # Normalize window features + # compressed: [B, W, D] + compressed = self.kv_norm(pooled) + + # Obtain starting positions of compressed windows + # positions: [B, W] + positions = jnp.arange(n_windows, dtype=jnp.int32) * self.compress_rate + positions = jnp.broadcast_to(positions[jnp.newaxis, :], (batch, n_windows)) + + # Apply interleaved rotary embeddings over aggregated outputs + # cos: [B, W, D_rope/2] + # sin: [B, W, D_rope/2] + cos, sin = self.rotary_emb(compressed, positions) + # compressed_4d: [B, W, 1, D] + compressed_4d = jnp.expand_dims(compressed, axis=2) + # rotated_4d: [B, W, 1, D] + rotated_4d = apply_rotary_pos_emb(compressed_4d, cos, sin, unsqueeze_dim=2) + # compressed_kv: [B, W, D] + compressed_kv = jnp.squeeze(rotated_4d, axis=2) + + # Execute Lightning Indexer to obtain block indices per query + # topk: [B, S, k] + topk = self.indexer(hidden_states, q_residual, position_ids) + + # Clamp indices safely using jnp.clip to avoid JAX negative/out-of-bounds indexing exceptions + # under indexer -1 sentinel conditions. + # safe_indices: [B, S, k] + safe_indices = jnp.clip(topk, a_min=0) + # batch_idx: [B, 1, 1] + batch_idx = jnp.arange(batch)[:, jnp.newaxis, jnp.newaxis] + # Perform TPU-efficient JAX Advanced Indexing Gather. + # gathered: [B, S, k, D] + gathered = compressed_kv[batch_idx, safe_indices, :] + + # Reshape gathered elements to standardized multi-head formats + # compressed_kv_out: [B, 1, S * k, D] + compressed_kv_out = gathered.reshape(batch, 1, seq_len * topk.shape[-1], self.head_dim) + + # Vectorized block bias mask construction to filter out invalid sparse gathered entries. + # valid: [B, S, k] indicating whether each top-k selection is valid (non-sentinel) + valid = topk >= 0 + # allowed: [B, S, k] containing 0.0 for valid entries and -inf for invalid sentinels + allowed = jnp.where(valid, 0.0, -jnp.inf) + # Construct an equivalence diagonal mask matching query sequence indices. + # eq_mask: [S, S, 1] representing identity query boundaries + eq_mask = jnp.arange(seq_len)[:, jnp.newaxis, jnp.newaxis] == jnp.arange(seq_len)[jnp.newaxis, :, jnp.newaxis] + # allowed_expanded: [B, S, 1, k] + allowed_expanded = allowed[:, :, jnp.newaxis, :] + # Distribute allowed masks diagonally using JAX vectorization to prevent cross-query leakage. + # block_bias_5d: [B, S, S, k] + block_bias_5d = jnp.where(eq_mask[jnp.newaxis, :, :, :], allowed_expanded, -jnp.inf) + # Reshape and format to standard key-value sequence length formats + # block_bias: [B, S, S * k] + block_bias = block_bias_5d.reshape(batch, seq_len, seq_len * topk.shape[-1]) + # block_bias: [B, 1, S, S * k] + block_bias = jnp.expand_dims(block_bias, axis=1) + return compressed_kv_out, block_bias + + +class DeepSeekV4Attention(nnx.Module): + """Main coordination attention block for DeepSeek-V4 compressed layer configurations. + + This module implements multi-head attention augmented with query-compression LoRA + projections, unweighted key/value normalizations, optional heavily or sparsely + compressed long-range context compressor integrations, learnable attention sinks, + and parallelized grouped output mixing projections. + """ + + def __init__( + self, + hidden_size: int, + q_lora_rank: int, + head_dim: int, + num_heads: int, + config: Any, + layer_idx: int, + eps: float = 1e-6, + weight_dtype: Any = jnp.float32, + dtype: Any = jnp.float32, + *, + rngs: nnx.Rngs, + ): + """Initializes the DeepSeekV4 Attention coordinator block. + + Args: + hidden_size: The model's global hidden dimension size. + q_lora_rank: The projection rank dimension of Q LoRA. + head_dim: The projection size of each attention key-value channel. + num_heads: The total number of query attention heads. + config: The DeepSeekV4 model configurations metadata. + layer_idx: The decoder stack layer index containing this attention module. + eps: Tiny additive variance limit for RMS normalization stability. + weight_dtype: The parameter weights numerical data type. + dtype: The mathematical execution numerical data type. + rngs: The Flax NNX random number generator collection. + """ + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.layer_type = config.layer_types[layer_idx] + self.num_heads = num_heads + self.head_dim = head_dim + self.sliding_window = config.sliding_window + self.scaling = head_dim**-0.5 + self.dtype = dtype + self.weight_dtype = weight_dtype + + # Projections for query extraction and low-rank compression + self.q_a_proj = nnx.Linear( + in_features=hidden_size, + out_features=q_lora_rank, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + self.q_a_norm = DeepSeekV4RMSNorm( + hidden_size=q_lora_rank, + eps=eps, + dtype=dtype, + weight_dtype=weight_dtype, + ) + self.q_b_proj = nnx.Linear( + in_features=q_lora_rank, + out_features=num_heads * head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + self.q_b_norm = DeepSeekV4UnweightedRMSNorm( + eps=eps, + dtype=dtype, + ) + + # Unified projected shared MQA key/value block + self.kv_proj = nnx.Linear( + in_features=hidden_size, + out_features=head_dim, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + self.kv_norm = DeepSeekV4RMSNorm( + hidden_size=head_dim, + eps=eps, + dtype=dtype, + weight_dtype=weight_dtype, + ) + + # Block-diagonal grouped linear layer for multi-head features mixing + self.o_a_proj = DeepSeekGroupedLinear( + in_features_per_group=num_heads * head_dim // config.o_groups, + out_features=config.o_groups * config.o_lora_rank, + n_groups=config.o_groups, + weight_dtype=weight_dtype, + dtype=dtype, + rngs=rngs, + ) + self.o_b_proj = nnx.Linear( + in_features=config.o_groups * config.o_lora_rank, + out_features=hidden_size, + use_bias=False, + dtype=dtype, + param_dtype=weight_dtype, + rngs=rngs, + ) + + # Attention Sink Parameter + self.sinks = nnx.Param(jax.nn.initializers.zeros(rngs.params(), (num_heads,), weight_dtype)) + + # Layer specific compressor allocation + if self.layer_type == "heavily_compressed_attention": + self.compressor = HCACompressor( + hidden_size=hidden_size, + head_dim=head_dim, + config=config, + layer_idx=layer_idx, + eps=eps, + weight_dtype=weight_dtype, + dtype=dtype, + rngs=rngs, + ) + elif self.layer_type == "compressed_sparse_attention": + self.compressor = CSACompressor( + hidden_size=hidden_size, + q_lora_rank=q_lora_rank, + head_dim=head_dim, + config=config, + layer_idx=layer_idx, + eps=eps, + weight_dtype=weight_dtype, + dtype=dtype, + rngs=rngs, + ) + else: + self.compressor = None + + def __call__( + self, + hidden_states: jnp.ndarray, + cos: jnp.ndarray, + sin: jnp.ndarray, + position_ids: jnp.ndarray, + attention_mask: jnp.ndarray | None = None, + ) -> tuple[jnp.ndarray, jnp.ndarray]: + """Executes DeepSeek-V4 compressed multi-head attention. + + This method projects input states to query representations, applies low-rank + LoRA, normalizes and applies positional RoPE, creates shared multi-query key-value + configurations, expands local context windows using structural compressors (if present), + calculates multi-head attention logits, applies structural masks and learnable + sinks, computes stable attention probabilities, and mixes final attention outputs + using parallel block-diagonal grouped projections. + + Args: + hidden_states: The input hidden representation sequence of shape [B, S, D_model]. + cos: Positional RoPE cosine frequencies array of shape [B, S, D_rope]. + sin: Positional RoPE sine frequencies array of shape [B, S, D_rope]. + position_ids: Absolute sequence position identifiers of shape [B, S]. + attention_mask: Optional attention mask of shape [B, 1, S, S_kv]. + + Returns: + A tuple containing: + - The projected mixed output sequence tensor of shape [B, S, D_model]. + - The final multi-head attention weights of shape [B, H, S, S_kv]. + """ + # hidden_states shape: [B, S, D_model] + # cos, sin shape: [B, S, D_rope] + # position_ids shape: [B, S] + # attention_mask shape: [B, 1, S_q, S_kv] + batch, seq_len, _ = hidden_states.shape + h_shape = (batch, seq_len, self.num_heads, self.head_dim) + + # Project inputs to query representations + # q_residual: [B, S, D_rank] + q_residual = self.q_a_norm(self.q_a_proj(hidden_states)) + # q: [B, H, S, D_head] + q = self.q_b_proj(q_residual).reshape(h_shape).transpose(0, 2, 1, 3) + q = self.q_b_norm(q) + # q: [B, H, S, D_head] + q = apply_rotary_pos_emb(q, cos, sin, unsqueeze_dim=1) + + # Project inputs to key/value representation + # kv: [B, 1, S, D_head] + kv = self.kv_norm(self.kv_proj(hidden_states)).reshape(batch, seq_len, 1, self.head_dim).transpose(0, 2, 1, 3) + # kv: [B, 1, S, D_head] + kv = apply_rotary_pos_emb(kv, cos, sin, unsqueeze_dim=1) + + block_bias = None + # Apply compressed attention key-value generation path + if self.compressor is not None: + # compressed_kv: [B, 1, W, D_head] or [B, 1, S*k, D_head] + # block_bias: [B, 1, S, W] or [B, 1, S, S*k] + compressed_kv, block_bias = self.compressor(hidden_states, q_residual, position_ids) + # kv: [B, 1, S + W, D_head] or [B, 1, S + S*k, D_head] + kv = jnp.concatenate([kv, compressed_kv], axis=2) + + # Broadcast key/value configurations to all heads + # k: [B, H, S_kv, D_head] + k = jnp.repeat(kv, self.num_heads, axis=1) + # v: [B, H, S_kv, D_head] + v = jnp.repeat(kv, self.num_heads, axis=1) + + # Compute attention logits + # logits: [B, H, S, S_kv] + logits = jnp.einsum("bhsd, bhkd -> bhsk", q, k, precision=self.config.matmul_precision) * self.scaling + + # Apply attention mask addition and block bias concatenation + if attention_mask is not None: + if block_bias is not None: + attention_mask = jnp.concatenate([attention_mask, block_bias.astype(attention_mask.dtype)], axis=-1) + elif kv.shape[2] > attention_mask.shape[-1]: + pad_width = kv.shape[2] - attention_mask.shape[-1] + attention_mask = jnp.pad(attention_mask, ((0, 0), (0, 0), (0, 0), (0, pad_width)), constant_values=0.0) + logits = logits + attention_mask + + # Concatenate learnable attention sinks + # sinks: [1, H, 1, 1] -> [B, H, S, 1] + sinks = self.sinks.value.reshape(1, -1, 1, 1) + sinks = jnp.broadcast_to(sinks, (batch, self.num_heads, seq_len, 1)) + # combined_logits: [B, H, S, S_kv + 1] + combined_logits = jnp.concatenate([logits, sinks], axis=-1) + + # Stable Softmax projection + combined_logits = combined_logits - jnp.max(combined_logits, axis=-1, keepdims=True) + probs = jax.nn.softmax(combined_logits, axis=-1) + # Drop sinks representation column + # attn_weights: [B, H, S, S_kv] + attn_weights = probs[..., :-1] + + # Project attention weights onto values + # attn_output: [B, H, S, D_head] + attn_output = jnp.einsum("bhsk, bhkd -> bhsd", attn_weights, v, precision=self.config.matmul_precision) + + # Apply conjugate RoPE transformation to restore position invariants + attn_output = apply_rotary_pos_emb(attn_output, cos, -sin, unsqueeze_dim=1) + + # Map outputs to grouped linear configuration + # grouped: [B, S, o_groups, (H / o_groups) * D_head] + grouped = attn_output.transpose(0, 2, 1, 3).reshape(batch, seq_len, self.config.o_groups, -1) + # grouped: [B, S, o_groups, o_lora_rank] + grouped = self.o_a_proj(grouped) + # Flatten back to grouped rank representations + # grouped_flat: [B, S, o_groups * o_lora_rank] + grouped_flat = grouped.reshape(batch, seq_len, -1) + + # Final linear projection to hidden dimension space + # output: [B, S, D_model] + output = self.o_b_proj(grouped_flat) + + return output, attn_weights diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py index 3b87f5b3fd..7691848b7a 100644 --- a/tests/unit/deepseek_v4_vs_reference_test.py +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -28,6 +28,10 @@ import maxtext.layers.embeddings as jax_emb_module import maxtext.layers.linears as jax_linear_module from maxtext.layers.moe import DeepSeekV4TopKRouter, DeepSeekV4HashRouter +from maxtext.layers import attention_compressed +from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding, apply_rotary_pos_emb +from maxtext.layers.normalizations import DeepSeekV4RMSNorm, DeepSeekV4UnweightedRMSNorm +from maxtext.layers.linears import DeepSeekGroupedLinear # ============================================================================== @@ -66,6 +70,7 @@ def __init__(self, **kwargs): "compressed_sparse_attention": 4, "heavily_compressed_attention": 128, } + self.compress_ratios = [128] * 43 self.sliding_window = 128 self.o_groups = 8 self.o_lora_rank = 1024 @@ -74,7 +79,8 @@ def __init__(self, **kwargs): self.index_topk = 512 self.rms_norm_eps = 1.0e-6 self.attention_dropout = 0.0 - self.layer_types = ["heavily_compressed_attention"] * 43 + self._attn_implementation = "eager" + self.matmul_precision = "default" # Setup default rope parameters dim = int(self.head_dim * self.partial_rotary_factor) @@ -510,7 +516,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: batch, _, _ = hidden_states.shape cache_layer: DeepseekV4HCACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -535,7 +541,22 @@ def forward( if cache_layer is not None: compressed = cache_layer.update_compressor_states("compressor", compressed) - return compressed.unsqueeze(1) + compressed_kv = compressed.unsqueeze(1) + + compressed_len = compressed_kv.shape[2] + seq_len = position_ids.shape[1] + if seq_len == 1 or compressed_len == 0: + return compressed_kv, None + + # query `t` may only see cache entries at pos `w` t > w * compress_rate (ex: t=7, w=2 t does not attend to it). + entry_indices = torch.arange(compressed_len, device=compressed_kv.device) + causal_threshold = (position_ids + 1) // self.compress_rate # [B, S] + block_bias = compressed_kv.new_zeros((batch, 1, seq_len, compressed_len)) + block_bias = block_bias.masked_fill( + entry_indices.view(1, 1, 1, -1) >= causal_threshold.unsqueeze(1).unsqueeze(-1), + float("-inf"), + ) + return compressed_kv, block_bias class DeepseekV4Indexer_PT(nn.Module): @@ -644,8 +665,25 @@ def forward( scores = F.relu(scores) * self.softmax_scale weights = self.weights_proj(hidden_states).float() * self.weights_scaling # [B, S, H] index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] - topk = min(self.index_topk, compressed_kv.shape[1]) - return index_scores.topk(topk, dim=-1).indices + compressed_len = compressed_kv.shape[1] + top_k = min(self.index_topk, compressed_len) + + # not all queries can attend to the compressed entries. If a query's position + # is small than the relative position of the key (say m=4, query 2 cannot attend + # to compressed key at position 4, because it compressed info for states at position + # 12 to 16. Thus we need to make sure that top_k does not land in that range. + # Picks that still point past `causal_threshold` (early queries with too few ready + # blocks) are replaced with a `-1` sentinel that the compressor treats as invalid. + if compressed_len > 0: + causal_threshold = (position_ids + 1) // self.compress_rate # [B, S] + entry_indices = torch.arange(compressed_len, device=index_scores.device) + future_mask = entry_indices.view(1, 1, -1) >= causal_threshold.unsqueeze(-1) # [B, S, T] + index_scores = index_scores.masked_fill(future_mask, float("-inf")) + top_k_indices = index_scores.topk(top_k, dim=-1).indices # [B, S, k] + invalid = top_k_indices >= causal_threshold.unsqueeze(-1) + return torch.where(invalid, torch.full_like(top_k_indices, -1), top_k_indices) + + return index_scores.topk(top_k, dim=-1).indices class DeepseekV4CSACompressor_PT(nn.Module): @@ -689,7 +727,7 @@ def forward( position_ids: torch.Tensor, past_key_values: Cache | None, layer_idx: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: batch, seq_len, _ = hidden_states.shape cache_layer: DeepseekV4CSACache = past_key_values.layers[layer_idx] if past_key_values is not None else None kv = self.kv_proj(hidden_states) @@ -742,10 +780,31 @@ def forward( compressed_kv = compressed.unsqueeze(1) # Lightning Indexer: gather top-`index_topk` compressed entries per query. - topk = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) # [B, S, k] - expanded = compressed_kv.unsqueeze(2).expand(-1, -1, seq_len, -1, -1) - idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) - return torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + # in some cases, the output index can return top-k positions that should not be attended to. + # Ex: for query at index 5, m=4, and `index_topk=1024`, 1024 index are return but only 2 should be + # attended to. The indexer marks the rest with `-1`; we clamp before the gather and keep the `valid` + # to drop them from the per-query block mask afterwards. + top_k_indices = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) # [B, S, k] + top_k = top_k_indices.shape[-1] + compressed_len = compressed_kv.shape[2] + valid = top_k_indices >= 0 # [B, S, k] + # Flatten (B, T) into one row axis and shift picks by `b * T`, then index_select once. + # Same kernel as an embedding lookup — cheaper than `gather` over an expanded view. + safe_indices = top_k_indices.clamp(min=0) + batch_offsets = (torch.arange(batch, device=compressed_kv.device) * compressed_len).view(batch, 1, 1) + flat_indices = (safe_indices + batch_offsets).view(-1) # [B*S*k] + flat_kv = compressed_kv.reshape(batch * compressed_len, self.head_dim) + gathered = flat_kv.index_select(0, flat_indices).view(batch, 1, -1, self.head_dim) # [B, 1, S*k, D] + + # Per-query block bias: query `t` may only see the cache entries that are <= `seq_len // m` + # and in these, only the ones marked valid by the indexer. Everything else is `-inf`. + # While the above negated the indexer, here we apply the "causal" masking. + block_bias = gathered.new_full((batch, 1, seq_len, seq_len, top_k), float("-inf")) + allowed = torch.where(valid, gathered.new_zeros(()), gathered.new_full((), float("-inf"))) # [B, S, k] + query_indices = torch.arange(seq_len, device=gathered.device) + block_bias[:, 0, query_indices, query_indices, :] = allowed # diagonal: q_idx == block_idx + block_bias = block_bias.view(batch, 1, seq_len, seq_len * top_k) + return gathered, block_bias def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -808,9 +867,9 @@ class DeepseekV4Attention_PT(nn.Module): Positional Embedding"). RoPE is also applied with position `-i` to the attention output's rope slice, so the contribution of each KV entry stays a function of the *relative* distance to the query. - * Per-head learnable attention sink like gpt OSS. - * Grouped low-rank output projection for perfs. - * 3 different cache mechanisms, sliding, sliding+CSA, sliding+HCA. + * Per-head learnable attention sink like gpt OSS. + * Grouped low-rank output projection for perfs. + * 3 different cache mechanisms, sliding, sliding+CSA, sliding+HCA. """ def __init__(self, config: DeepseekV4Config, layer_idx: int): @@ -818,6 +877,9 @@ def __init__(self, config: DeepseekV4Config, layer_idx: int): self.config = config self.layer_idx = layer_idx self.layer_type = config.layer_types[layer_idx] + # Sliding-only layers use the "main" (plain θ=10000) rope; CSA/HCA layers + # share the same yarn-scaled "compress" rope as their compressor. + self.rope_layer_type = "main" if self.layer_type == "sliding_attention" else "compress" self.num_heads = config.num_attention_heads self.num_key_value_groups = config.num_attention_heads # single KV head, broadcast to all self.head_dim = config.head_dim @@ -863,17 +925,21 @@ def forward( if past_key_values is not None: # sliding where K==V kv = past_key_values.update(kv, kv, self.layer_idx)[0] + block_bias = None if self.compressor is not None: # Compressed KV (CSA or HCA) - compressed_kv = self.compressor(hidden_states, q_residual, position_ids, past_key_values, self.layer_idx) + compressed_kv, block_bias = self.compressor( + hidden_states, q_residual, position_ids, past_key_values, self.layer_idx + ) kv = torch.cat([kv, compressed_kv], dim=2) - # The compressor path concatenates extra entries onto the KV axis after the - # standard sliding-window cache update, so a tensor `attention_mask` (built - # for the pre-concat KV length) needs to be right-padded to cover them. - # Flex-attention passes a `BlockMask` whose KV-length axis comes from its - # own `mask_mod`, not from a dense tensor — skip the pad in that case. - if isinstance(attention_mask, torch.Tensor) and kv.shape[2] > attention_mask.shape[-1]: - attention_mask = F.pad(attention_mask, (0, kv.shape[2] - attention_mask.shape[-1]), value=0.0) + # compressor returns a `block_bias` carrying per-query causality + indexer + # selections, which needs to be concatenated to the right of `attention_mask`. + # Eager/flash interfaces consume the combined mask directly. + if isinstance(attention_mask, torch.Tensor): + if block_bias is not None: + attention_mask = torch.cat([attention_mask, block_bias.to(attention_mask.dtype)], dim=-1) + elif kv.shape[2] > attention_mask.shape[-1]: + attention_mask = F.pad(attention_mask, (0, kv.shape[2] - attention_mask.shape[-1]), value=0.0) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward @@ -990,7 +1056,7 @@ def test_unweighted_rms_norm_parity(self): # Execute JAX equivalent target unweighted RMS normalization. # Target module instantiated from top-level imports to optimize namespace lookup. - jax_model = jax_norm_module.DeepSeekV4UnweightedRMSNorm(eps=1e-6) + jax_model = DeepSeekV4UnweightedRMSNorm(eps=1e-6) out_jax = jax_model(x_jax) # Compare outputs within numerical precision tolerance limits. @@ -1014,7 +1080,7 @@ def test_rms_norm_parity(self): # Execute JAX equivalent target RMS normalization. # JAX model state parameters are explicitly updated to match the generated weights. - jax_model = jax_norm_module.DeepSeekV4RMSNorm(hidden_size=512, eps=1e-6) + jax_model = DeepSeekV4RMSNorm(hidden_size=512, eps=1e-6) jax_model.weight.value = jnp.array(weight_np) out_jax = jax_model(x_jax) @@ -1051,11 +1117,11 @@ def test_rotary_embedding_parity(self): # Execute JAX equivalent target rotary embeddings. # The target JAX layer operates natively on [B, S, H, D] layouts, applying # dimensional unsqueezing at axis 2 to broadcast across heads. - jax_emb = jax_emb_module.DeepSeekV4RotaryEmbedding(head_dim=D, partial_rotary_factor=64.0 / 512.0, rope_theta=10000.0) + jax_emb = DeepSeekV4RotaryEmbedding(head_dim=D, partial_rotary_factor=64.0 / 512.0, rope_theta=10000.0) cos_jax, sin_jax = jax_emb(x_jax, position_ids_jax) # Execute JAX target application. - out_jax = jax_emb_module.apply_rotary_pos_emb(x_jax, cos_jax, sin_jax, unsqueeze_dim=2) + out_jax = apply_rotary_pos_emb(x_jax, cos_jax, sin_jax, unsqueeze_dim=2) out_jax_np = np.array(out_jax) # Compare both the intermediate cos/sin sinusoids and the final rotated values. @@ -1091,7 +1157,7 @@ def test_grouped_linear_parity(self): # Execute JAX equivalent target grouped linear block projection. # JAX weights are initialized using the deterministic key context. rngs = nnx.Rngs(42) - jax_model = jax_linear_module.DeepSeekGroupedLinear( + jax_model = DeepSeekGroupedLinear( in_features_per_group=i, out_features=o, n_groups=g, @@ -1248,6 +1314,388 @@ def __init__(self): np.testing.assert_array_equal(jax_indices, py_indices.numpy()) np.testing.assert_allclose(py_weights.detach().numpy(), jax_weights, atol=1e-5, rtol=1e-5) + def test_hca_compressor_parity(self): + # Configure deterministic seeds for parity reproducibility + np.random.seed(42) + B, S, D, D_head, compress_rate = 2, 128, 512, 256, 32 + + # hidden_states: [B, S, D] + x_np = np.random.randn(B, S, D).astype(np.float32) + positions_np = np.broadcast_to(np.arange(S)[np.newaxis, :], (B, S)).astype(np.int32) + + x_torch = torch.tensor(x_np) + positions_torch = torch.tensor(positions_np, dtype=torch.long) + + x_jax = jnp.array(x_np) + positions_jax = jnp.array(positions_np) + + # Initialize PyTorch configurations matching parameter spaces + config = DeepseekV4Config() + config.hidden_size = D + config.head_dim = D_head + config.compress_rates["heavily_compressed_attention"] = compress_rate + config.rms_norm_eps = 1e-6 + + # Initialize PyTorch HCA Compressor model + torch_model = DeepseekV4HCACompressor_PT(config) + torch.nn.init.normal_(torch_model.position_bias, std=0.02) + + # Map JAX layer using matching parameters + jax_config = DeepseekV4Config() + jax_config.compress_ratios = [compress_rate] * 43 + jax_config.compress_rope_theta = 160000.0 + + rngs = nnx.Rngs(42) + jax_model = attention_compressed.HCACompressor( + hidden_size=D, + head_dim=D_head, + config=jax_config, + layer_idx=0, + eps=1e-6, + rngs=rngs, + ) + + # Set JAX parameters identical to PyTorch states to guarantee numerical parity + jax_model.kv_proj.kernel[...] = jnp.array(torch_model.kv_proj.weight.detach().numpy().T) + jax_model.gate_proj.kernel[...] = jnp.array(torch_model.gate_proj.weight.detach().numpy().T) + jax_model.position_bias[...] = jnp.array(torch_model.position_bias.detach().numpy()) + jax_model.kv_norm.weight[...] = jnp.array(torch_model.kv_norm.weight.detach().numpy()) + + # Execute PyTorch stateless compressor path + # Shape out_torch: [B, 1, W, D_head] where W = S // compress_rate = 4 + out_torch, block_bias_torch = torch_model( + hidden_states=x_torch, + q_residual=None, + position_ids=positions_torch, + past_key_values=None, + layer_idx=0, + ) + out_torch = out_torch.detach().numpy() + if block_bias_torch is not None: + block_bias_torch = block_bias_torch.detach().numpy() + + # Execute JAX equivalent stateless compressor path + # Shape out_jax: [B, 1, W, D_head] + out_jax, block_bias_jax = jax_model( + hidden_states=x_jax, + position_ids=positions_jax, + ) + out_jax_np = np.array(out_jax) + if block_bias_jax is not None: + block_bias_jax = np.array(block_bias_jax) + + # Validate bit-accurate state outputs matching numerical tolerance thresholds + np.testing.assert_allclose(out_torch, out_jax_np, atol=1e-5, rtol=1e-5) + if block_bias_torch is not None or block_bias_jax is not None: + np.testing.assert_allclose(block_bias_torch, block_bias_jax, atol=1e-5, rtol=1e-5) + + def test_indexer_parity(self): + np.random.seed(42) + B, S, D, D_rank = 2, 128, 512, 1024 + num_heads, index_head_dim, index_topk, compress_rate = 64, 128, 8, 4 + + # hidden_states: [B, S, D] + x_np = np.random.randn(B, S, D).astype(np.float32) + # q_residual: [B, S, D_rank] + q_res_np = np.random.randn(B, S, D_rank).astype(np.float32) + # position_ids: [B, S] + positions_np = np.broadcast_to(np.arange(S)[np.newaxis, :], (B, S)).astype(np.int32) + + x_torch = torch.tensor(x_np) + q_res_torch = torch.tensor(q_res_np) + positions_torch = torch.tensor(positions_np, dtype=torch.long) + + x_jax = jnp.array(x_np) + q_res_jax = jnp.array(q_res_np) + positions_jax = jnp.array(positions_np) + + # Initialize PyTorch indexer configurations + config = DeepseekV4Config() + config.hidden_size = D + config.q_lora_rank = D_rank + config.index_n_heads = num_heads + config.index_head_dim = index_head_dim + config.index_topk = index_topk + config.compress_rates["compressed_sparse_attention"] = compress_rate + config.rms_norm_eps = 1e-6 + + torch_model = DeepseekV4Indexer_PT(config) + torch.nn.init.normal_(torch_model.position_bias, std=0.02) + + # Map JAX equivalent Indexer module + jax_config = DeepseekV4Config() + jax_config.index_n_heads = num_heads + jax_config.index_head_dim = index_head_dim + jax_config.index_topk = index_topk + jax_config.compress_ratios = [compress_rate] * 43 + jax_config.compress_rope_theta = 160000.0 + + rngs = nnx.Rngs(42) + jax_model = attention_compressed.DeepSeekV4Indexer( + hidden_size=D, + q_lora_rank=D_rank, + config=jax_config, + layer_idx=0, + eps=1e-6, + rngs=rngs, + ) + + # Synchronize parameter values + jax_model.kv_proj.kernel[...] = jnp.array(torch_model.kv_proj.weight.detach().numpy().T) + jax_model.gate_proj.kernel[...] = jnp.array(torch_model.gate_proj.weight.detach().numpy().T) + jax_model.position_bias[...] = jnp.array(torch_model.position_bias.detach().numpy()) + jax_model.kv_norm.weight[...] = jnp.array(torch_model.kv_norm.weight.detach().numpy()) + jax_model.q_b_proj.kernel[...] = jnp.array(torch_model.q_b_proj.weight.detach().numpy().T) + jax_model.weights_proj.kernel[...] = jnp.array(torch_model.weights_proj.weight.detach().numpy().T) + + # Execute models + out_torch = ( + torch_model( + hidden_states=x_torch, + q_residual=q_res_torch, + position_ids=positions_torch, + past_key_values=None, + layer_idx=0, + ) + .detach() + .numpy() + ) + + out_jax = jax_model( + hidden_states=x_jax, + q_residual=q_res_jax, + position_ids=positions_jax, + ) + out_jax_np = np.array(out_jax) + + # Check mathematical equivalence of top-k selection indices + np.testing.assert_allclose(out_torch, out_jax_np, atol=1e-5, rtol=1e-5) + + def test_csa_compressor_parity(self): + np.random.seed(42) + B, S, D, D_rank, D_head = 2, 128, 512, 1024, 256 + num_heads, index_head_dim, index_topk, compress_rate = 64, 128, 8, 4 + + # Inputs + x_np = np.random.randn(B, S, D).astype(np.float32) + q_res_np = np.random.randn(B, S, D_rank).astype(np.float32) + positions_np = np.broadcast_to(np.arange(S)[np.newaxis, :], (B, S)).astype(np.int32) + + x_torch = torch.tensor(x_np) + q_res_torch = torch.tensor(q_res_np) + positions_torch = torch.tensor(positions_np, dtype=torch.long) + + x_jax = jnp.array(x_np) + q_res_jax = jnp.array(q_res_np) + positions_jax = jnp.array(positions_np) + + # Configurations + config = DeepseekV4Config() + config.hidden_size = D + config.q_lora_rank = D_rank + config.head_dim = D_head + config.index_n_heads = num_heads + config.index_head_dim = index_head_dim + config.index_topk = index_topk + config.compress_rates["compressed_sparse_attention"] = compress_rate + config.rms_norm_eps = 1e-6 + + torch_model = DeepseekV4CSACompressor_PT(config) + torch.nn.init.normal_(torch_model.position_bias, std=0.02) + torch.nn.init.normal_(torch_model.indexer.position_bias, std=0.02) + + # JAX CSA Compressor + jax_config = DeepseekV4Config() + jax_config.head_dim = D_head + jax_config.index_n_heads = num_heads + jax_config.index_head_dim = index_head_dim + jax_config.index_topk = index_topk + jax_config.compress_ratios = [compress_rate] * 43 + jax_config.compress_rope_theta = 160000.0 + + rngs = nnx.Rngs(42) + jax_model = attention_compressed.CSACompressor( + hidden_size=D, + q_lora_rank=D_rank, + head_dim=D_head, + config=jax_config, + layer_idx=0, + eps=1e-6, + rngs=rngs, + ) + + # Synchronize outer compressor states + jax_model.kv_proj.kernel[...] = jnp.array(torch_model.kv_proj.weight.detach().numpy().T) + jax_model.gate_proj.kernel[...] = jnp.array(torch_model.gate_proj.weight.detach().numpy().T) + jax_model.position_bias[...] = jnp.array(torch_model.position_bias.detach().numpy()) + jax_model.kv_norm.weight[...] = jnp.array(torch_model.kv_norm.weight.detach().numpy()) + + # Synchronize inner indexer states + jax_model.indexer.kv_proj.kernel[...] = jnp.array(torch_model.indexer.kv_proj.weight.detach().numpy().T) + jax_model.indexer.gate_proj.kernel[...] = jnp.array(torch_model.indexer.gate_proj.weight.detach().numpy().T) + jax_model.indexer.position_bias[...] = jnp.array(torch_model.indexer.position_bias.detach().numpy()) + jax_model.indexer.kv_norm.weight[...] = jnp.array(torch_model.indexer.kv_norm.weight.detach().numpy()) + jax_model.indexer.q_b_proj.kernel[...] = jnp.array(torch_model.indexer.q_b_proj.weight.detach().numpy().T) + jax_model.indexer.weights_proj.kernel[...] = jnp.array(torch_model.indexer.weights_proj.weight.detach().numpy().T) + + # Execute + out_torch, block_bias_torch = torch_model( + hidden_states=x_torch, + q_residual=q_res_torch, + position_ids=positions_torch, + past_key_values=None, + layer_idx=0, + ) + out_torch = out_torch.detach().numpy() + if block_bias_torch is not None: + block_bias_torch = block_bias_torch.detach().numpy() + + out_jax, block_bias_jax = jax_model( + hidden_states=x_jax, + q_residual=q_res_jax, + position_ids=positions_jax, + ) + out_jax_np = np.array(out_jax) + if block_bias_jax is not None: + block_bias_jax = np.array(block_bias_jax) + + # Diagnose indexer parity + topk_torch = ( + torch_model.indexer( + hidden_states=x_torch, + q_residual=q_res_torch, + position_ids=positions_torch, + past_key_values=None, + layer_idx=0, + ) + .detach() + .numpy() + ) + + topk_jax = jax_model.indexer( + hidden_states=x_jax, + q_residual=q_res_jax, + position_ids=positions_jax, + ) + topk_jax_np = np.array(topk_jax) + + print("topk_torch:", topk_torch) + print("topk_jax:", topk_jax_np) + np.testing.assert_allclose(topk_torch, topk_jax_np, atol=1e-5, rtol=1e-5) + + # Check complete parity of gathered/indexed keys + np.testing.assert_allclose(out_torch, out_jax_np, atol=1e-5, rtol=1e-5) + if block_bias_torch is not None or block_bias_jax is not None: + np.testing.assert_allclose(block_bias_torch, block_bias_jax, atol=1e-5, rtol=1e-5) + + def test_attention_layer_parity(self): + np.random.seed(42) + B, S, D, D_rank, D_head, num_heads = 2, 128, 512, 1024, 256, 16 + compress_rate = 32 + + # Inputs + x_np = np.random.randn(B, S, D).astype(np.float32) + position_ids_np = np.broadcast_to(np.arange(S)[np.newaxis, :], (B, S)).astype(np.int32) + + x_torch = torch.tensor(x_np) + position_ids_torch = torch.tensor(position_ids_np, dtype=torch.long) + + x_jax = jnp.array(x_np) + position_ids_jax = jnp.array(position_ids_np) + + # Configurations + config = DeepseekV4Config() + config.hidden_size = D + config.q_lora_rank = D_rank + config.head_dim = D_head + config.num_attention_heads = num_heads + config.num_key_value_heads = 1 + config.compress_rates["heavily_compressed_attention"] = compress_rate + config.rms_norm_eps = 1e-6 + config.layer_types = ["heavily_compressed_attention"] * 10 + + # Generate reference position embeddings (cos, sin) + torch_emb = DeepseekV4RotaryEmbedding_PT(config) + cos_torch, sin_torch = torch_emb(x_torch, position_ids_torch, layer_type="main") + + cos_jax = jnp.array(cos_torch.detach().numpy()) + sin_jax = jnp.array(sin_torch.detach().numpy()) + + # Initialize PyTorch and JAX coordinate attention layers + torch_model = DeepseekV4Attention_PT(config, layer_idx=0) + torch.nn.init.normal_(torch_model.sinks, std=0.02) + if torch_model.compressor is not None: + torch.nn.init.normal_(torch_model.compressor.position_bias, std=0.02) + + jax_config = DeepseekV4Config() + jax_config.hidden_size = D + jax_config.q_lora_rank = D_rank + jax_config.head_dim = D_head + jax_config.num_attention_heads = num_heads + jax_config.compress_ratios = [compress_rate] * 10 + jax_config.compress_rope_theta = 160000.0 + jax_config.layer_types = ["heavily_compressed_attention"] * 10 + jax_config.o_groups = config.o_groups + jax_config.o_lora_rank = config.o_lora_rank + + rngs = nnx.Rngs(42) + jax_model = attention_compressed.DeepSeekV4Attention( + hidden_size=D, + q_lora_rank=D_rank, + head_dim=D_head, + num_heads=num_heads, + config=jax_config, + layer_idx=0, + eps=1e-6, + rngs=rngs, + ) + + # Copy projections and normalize weights from PyTorch to JAX + jax_model.q_a_proj.kernel[...] = jnp.array(torch_model.q_a_proj.weight.detach().numpy().T) + jax_model.q_a_norm.weight[...] = jnp.array(torch_model.q_a_norm.weight.detach().numpy()) + jax_model.q_b_proj.kernel[...] = jnp.array(torch_model.q_b_proj.weight.detach().numpy().T) + + jax_model.kv_proj.kernel[...] = jnp.array(torch_model.kv_proj.weight.detach().numpy().T) + jax_model.kv_norm.weight[...] = jnp.array(torch_model.kv_norm.weight.detach().numpy()) + + # Handle Grouped Output Projection mapping + w_o_a_np = torch_model.o_a_proj.weight.detach().numpy() + in_features_per_group = num_heads * D_head // config.o_groups + w_o_a_np = w_o_a_np.reshape(config.o_groups, -1, in_features_per_group).transpose(0, 2, 1) + jax_model.o_a_proj.weight[...] = jnp.array(w_o_a_np) + + jax_model.o_b_proj.kernel[...] = jnp.array(torch_model.o_b_proj.weight.detach().numpy().T) + jax_model.sinks[...] = jnp.array(torch_model.sinks.detach().numpy()) + + # Copy Compressor weights if present + if torch_model.compressor is not None: + jax_model.compressor.kv_proj.kernel[...] = jnp.array(torch_model.compressor.kv_proj.weight.detach().numpy().T) + jax_model.compressor.gate_proj.kernel[...] = jnp.array(torch_model.compressor.gate_proj.weight.detach().numpy().T) + jax_model.compressor.position_bias[...] = jnp.array(torch_model.compressor.position_bias.detach().numpy()) + jax_model.compressor.kv_norm.weight[...] = jnp.array(torch_model.compressor.kv_norm.weight.detach().numpy()) + + # Execute PyTorch attention layer + out_torch, _ = torch_model( + hidden_states=x_torch, + position_embeddings=(cos_torch, sin_torch), + position_ids=position_ids_torch, + attention_mask=None, + ) + out_torch_np = out_torch.detach().numpy() + + # Execute JAX attention layer + out_jax, _ = jax_model( + hidden_states=x_jax, + cos=cos_jax, + sin=sin_jax, + position_ids=position_ids_jax, + attention_mask=None, + ) + out_jax_np = np.array(out_jax) + + # Check complete numerical parity of coordination attention layers + np.testing.assert_allclose(out_torch_np, out_jax_np, atol=1e-5, rtol=1e-5) + if __name__ == "__main__": unittest.main()