From 2c4e71f17efdae3404e5caad8b5c1ead4aa3ef20 Mon Sep 17 00:00:00 2001 From: Shuning Jin Date: Fri, 6 Feb 2026 06:53:01 +0000 Subject: [PATCH] Add DeepSeek Engram layer --- pytest.ini | 1 + src/MaxText/configs/base.yml | 1 + src/MaxText/layers/engram.py | 684 +++++++++++++++++++++ src/MaxText/pyconfig.py | 2 +- tests/unit/engram_vs_reference_test.py | 784 +++++++++++++++++++++++++ 5 files changed, 1471 insertions(+), 1 deletion(-) create mode 100644 src/MaxText/layers/engram.py create mode 100644 tests/unit/engram_vs_reference_test.py diff --git a/pytest.ini b/pytest.ini index 5f220deb7b..a770702369 100644 --- a/pytest.ini +++ b/pytest.ini @@ -21,6 +21,7 @@ addopts = --ignore=tests/unit/qwen3_omni_layers_test.py --ignore=tests/unit/qwen3_next_vs_reference_test.py --ignore=tests/unit/deepseek32_vs_reference_test.py + --ignore=tests/unit/engram_vs_reference_test.py markers = tpu_only: marks tests to be run on TPUs only gpu_only: marks tests to be run on GPUs only diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index becc372955..40c58014f2 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -473,6 +473,7 @@ logical_axis_rules: [ ['paged_kv_head_dim_size', []], ['dense_layers', []], ['moe_layers', []], + ['engram_dim', ['tensor']], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] diff --git a/src/MaxText/layers/engram.py b/src/MaxText/layers/engram.py new file mode 100644 index 0000000000..b106b3e520 --- /dev/null +++ b/src/MaxText/layers/engram.py @@ -0,0 +1,684 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DeepSeek-AI, `Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models + `_, 2026 + +Reference implementation: https://github.com/deepseek-ai/Engram/blob/main/engram_demo_v1.py +""" + +from typing import List, Optional +import numpy as np +from sympy import isprime +from tokenizers import normalizers, Regex + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from flax import nnx + +from MaxText.tokenizer import HFTokenizer +from MaxText.common_types import MODEL_MODE_TRAIN, Array, Config +from MaxText.layers.embeddings import Embed +from MaxText.layers.initializers import nd_dense_init, NdInitializer +from MaxText.layers.linears import DenseGeneral +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.quantizations import AqtQuantization as Quant + + +class CompressedTokenizer: + """ + A canonicalizing wrapper that reduces vocabulary sparsity for n-gram lookup. + + This class maps semantically equivalent tokens (e.g., "Apple", " apple", "APPLE") + to a single unified ID. This many-to-one mapping significantly reduces the + combinatorial size of the n-gram space. 
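+
+  For example, tokens that decode to " Café", "cafe", or "CAFE" all normalize to
+  "cafe" and therefore receive the same compressed ID (illustrative; the exact ID
+  depends on the tokenizer's vocabulary).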
+ + Attributes: + lookup_table: Array mapping `original_id` -> `compressed_id`. + num_new_token: Size of the compressed vocabulary. + """ + + def __init__(self, tokenizer: HFTokenizer): + normalizer = self._build_normalizer() + self.lookup_table, self.num_new_token = self._build_lookup_table(tokenizer, normalizer) + + def __len__(self): + return self.num_new_token + + def _build_normalizer(self): + """ + Builds the normalization pipeline for text processing. + """ + # Private use Unicode character to protect single spaces during stripping + SENTINEL = "\uE000" + + # Normalization pipeline: ensures variations like "Café" and "cafe" map to the same ID + normalizer = normalizers.Sequence( + [ + # Compatibility decomposition (e.g., ½ -> 1/2) + normalizers.NFKC(), + # Canonical decomposition (e.g., é -> e + ´) + normalizers.NFD(), + # Strip diacritics (e.g., e + ´ -> e) + normalizers.StripAccents(), + # Lowercase conversion ("The" -> "the") + normalizers.Lowercase(), + # Collapse all whitespace variations to a single space + normalizers.Replace(Regex(r"[ \t\r\n]+"), " "), + # Protect standalone spaces from subsequent stripping + normalizers.Replace(Regex(r"^ $"), SENTINEL), + # Remove leading/trailing whitespace + normalizers.Strip(), + # Restore protected spaces + normalizers.Replace(SENTINEL, " "), + ] + ) + return normalizer + + def _build_lookup_table(self, tokenizer: HFTokenizer, normalizer: normalizers.Sequence): + """ + Builds the mapping from the original vocabulary to the compressed vocabulary. + """ + vocab_size = len(tokenizer) + # Mapping: original_tid -> compressed_nid (Many-to-One) + old2new = np.empty(vocab_size, dtype=np.int64) + # Mapping: normalized_string -> compressed_nid (One-to-One) + key2new = {} + + for tid in range(vocab_size): + # Decode token to raw text + text = tokenizer.decode([tid], skip_special_tokens=False) + + if "\ufffd" in text: + # Handle invalid UTF-8 (replacement char �). Use raw token instead. + key = tokenizer.convert_ids_to_tokens(tid) + else: + # Normalize text (e.g., " APPLE" -> "apple") + normalized_text = normalizer.normalize_str(text) + # Fallback to raw text if normalization creates an empty string + key = normalized_text if normalized_text else text + + # Assign compressed ID + nid = key2new.get(key) + if nid is None: + nid = len(key2new) + key2new[key] = nid + + old2new[tid] = nid + + return old2new, len(key2new) + + def __call__(self, input_ids): + """ + Maps original token IDs to compressed IDs. + """ + input_ids = np.asarray(input_ids, dtype=np.int64) + + # Identify valid tokens (ignore padding/masks usually marked with negative IDs) + valid_mask = input_ids >= 0 + valid_ids = input_ids[valid_mask] + + # Vectorized lookup: O(1) per token + output_ids = input_ids.copy() + output_ids[valid_mask] = self.lookup_table[valid_ids] + return output_ids + + +class NgramHashMapping: + """ + Maps n-gram sequences to hash-based indices for memory lookup. + + This class implements the Engram hashing mechanism. It converts variable-length + n-grams into fixed integer IDs. To handle the large combinatorial space, it uses: + 1. Unique Prime Vocabularies: Per-head prime moduli to minimize collision overlap. + 2. Sliding Window: Efficient shifting to generate n-gram views. + 3. Lightweight Hashing: A multiplicative-XOR function (Rabin-Karp variant). 
+ """ + + def __init__( + self, + engram_vocab_bases: List[int], + max_ngram_size: int, + engram_num_heads: int, + layer_ids: List[int], + tokenizer: HFTokenizer, + pad_id: int, + seed: int, + ): + """ + Args: + engram_vocab_size: List of minimum head vocab sizes for each n-gram order. + max_ngram_size: Max n-gram size to track (e.g., 3 tracks 2-grams and 3-grams). + engram_num_heads: Number of parallel heads per n-gram order. + layer_ids: List of layer indices using Engram. + tokenizer: Base Hugging Face tokenizer. + pad_id: Padding token ID. + seed: Random seed for hash multiplier generation. + """ + self.min_head_vocab_size_per_ngram = engram_vocab_bases + self.max_ngram_size = max_ngram_size + self.n_head_per_ngram = engram_num_heads + self.layer_ids = layer_ids + + # Initialize compressed tokenizer + self.compressed_tokenizer = CompressedTokenizer(tokenizer) + self.tokenizer_vocab_size = len(self.compressed_tokenizer) + # TODO(shuningjin): why not just use pad_id = tokenizer.pad_id + if pad_id is not None: + self.pad_id = int(self.compressed_tokenizer.lookup_table[pad_id]) + + # Pre-calculate odd multipliers for hashing: {layer_id: multipliers} + self.layer_multipliers = self._calculate_multipliers_across_layers(seed) + + # Pre-calculate unique prime vocab sizes for every head + # Structure: {layer_id: [[2gram_head1, ..., 2gram_headH], ..., [Ngram_head1, ..., Ngram_headH]]} + self.vocab_size_across_layers = self._calculate_vocab_size_across_layers() + + def _calculate_multipliers_across_layers(self, seed: int): + """ + Pre-calculates random odd multipliers for each layer and n-gram position. + + Returns: + A dictionary mapping layer_id to a list of `max_ngram_size` multipliers. + """ + # Pre-calculate bounds for random generation + max_long = np.iinfo(np.int64).max + m_max = int(max_long // self.tokenizer_vocab_size) + half_bound = max(1, m_max // 2) + LAYER_PRIME_OFFSET = 10007 + + layer_multipliers = {} + for layer_id in self.layer_ids: + # Offset seed to decorrelate layers + layer_seed = int(seed + LAYER_PRIME_OFFSET * int(layer_id)) + np_rng = np.random.default_rng(layer_seed) + # Generate random odd integers + random_value = np_rng.integers(low=0, high=half_bound, size=(self.max_ngram_size,), dtype=np.int64) + multipliers = random_value * 2 + 1 + layer_multipliers[layer_id] = multipliers + return layer_multipliers + + def _calculate_vocab_size_across_layers(self): + """ + Calculates unique prime vocabulary sizes for every head in every layer. + Using unique primes minimizes the probability of simultaneous collisions across heads. 
+ """ + + def find_next_unseen_prime(start: int, seen_primes: set): + candidate = start + 1 + while candidate in seen_primes or not isprime(candidate): + candidate += 1 + return candidate + + seen_primes = set() + vocab_size_across_layers = {} + + for layer_id in self.layer_ids: + all_ngram_vocab_sizes = [] + for n in range(2, self.max_ngram_size + 1): + current_ngram_heads_sizes = [] + + # Start search from the configured minimum size + n_gram_index = n - 2 + vocab_size = self.min_head_vocab_size_per_ngram[n_gram_index] + current_prime_search_start = vocab_size - 1 + + # Find unique primes for each head + num_heads = self.n_head_per_ngram + for _ in range(num_heads): + found_prime = find_next_unseen_prime(current_prime_search_start, seen_primes) + seen_primes.add(found_prime) + current_ngram_heads_sizes.append(found_prime) + current_prime_search_start = found_prime + + all_ngram_vocab_sizes.append(current_ngram_heads_sizes) + vocab_size_across_layers[layer_id] = all_ngram_vocab_sizes + + return vocab_size_across_layers + + def get_vocab_sizes(self, layer_id: int): + """ + Returns a flattened list of prime vocabulary sizes for a specific layer. + """ + return [head_size for ngram_size in self.vocab_size_across_layers[layer_id] for head_size in ngram_size] + + def _get_ngram_hashes(self, compressed_ids: np.ndarray, layer_id: int) -> np.ndarray: + """ + Computes hash indices for all n-grams in the input batch. + + Args: + compressed_ids: (B, S) input token IDs. + layer_id: engram layer id. + + Returns: + hash_ids: (B, S, H_total) where H_total = H * num_ngram_orders + """ + x = np.asarray(compressed_ids, dtype=np.int64) + B, T = x.shape + + # 1. Create Sliding Windows via Shifting + base_shifts = [] + for k in range(self.max_ngram_size): + if k == 0: + # e.g., [The, cat, sat] + base_shifts.append(x) + else: + # Pre-allocate full array with PAD_ID + shifted = np.full((B, T), self.pad_id, dtype=np.int64) + # Fast memory copy, slicing and assignment + # e.g., k=1, [PAD, The, cat] + shifted[:, k:] = x[:, :-k] + base_shifts.append(shifted) + + # 2. Retrieve layer-specific hash multipliers + multipliers = self.layer_multipliers[layer_id] + + # 3. Compute Hashes: multiplicative bitwise XOR + # Implements rolling hash: H_n = (Token_0 * m_0) ^ ... 
^ (Token_k * m_k) + all_hashes = [] + # Initialize rolling hash with 1-gram + rolling_hash = base_shifts[0] * multipliers[0] + # Pre-fetch vocab sizes for modulo + vocab_sizes = self.vocab_size_across_layers[layer_id] + + for n in range(2, self.max_ngram_size + 1): + # Update rolling hash with next token position + k = n - 1 + rolling_hash = np.bitwise_xor(rolling_hash, base_shifts[k] * multipliers[k]) + + # Retrieve prime vocab sizes for all heads of this n-gram order + n_gram_index = n - 2 + vocab_sizes_for_this_gram = vocab_sizes[n_gram_index] + mods = np.array(vocab_sizes_for_this_gram, dtype=np.int64) + + # Broadcast Modulo: Map hash to valid table indices + # (B, S, 1) % (H,) -> (B, S, H) + head_hashes = rolling_hash[..., None] % mods + all_hashes.append(head_hashes) + + # Concatenate all heads: (B, S, H_total) where H_total = H * num_ngram_orders + return np.concatenate(all_hashes, axis=2) + + def __call__(self, input_ids): + # input_ids from standard tokenizer + compressed_ids = self.compressed_tokenizer(input_ids) + hash_ids_for_all_layers = {} + for layer_id in self.layer_ids: + hash_ids = self._get_ngram_hashes(compressed_ids, layer_id=layer_id) + hash_ids_for_all_layers[layer_id] = hash_ids + return hash_ids_for_all_layers + + +class MultiHeadEmbedding(nnx.Module): + """ + A flattened table representation for multi-head embedding spaces across n-gram orders. + """ + + def __init__( + self, + config: Config, + vocab_sizes: List[int], + head_dim: int, + mesh: Mesh, + rngs: nnx.Rngs, + ): + """ + Args: + vocab_sizes: Flattened list of prime vocabulary sizes for all heads across all n-gram orders. + Example: [Size_2gram_Head1, Size_2gram_Head2, Size_3gram_Head1, ...]. + head_dim: Embedding dimension for a single head. + config: The model configuration. + mesh: Device mesh for partitioning. + rngs: Random number generators for initialization. + """ + self.num_heads = len(vocab_sizes) + + # Compute starting index for each head's segment in the flattened table. + # Offsets serve as the "base address" for each head. + offsets = np.cumsum([0] + vocab_sizes[:-1]) # prefix sum + self.offsets = jnp.array(offsets, dtype=jnp.int32) + + # The total embedding size is the sum of all individual head vocabularies. + self.embedding = Embed(num_embeddings=sum(vocab_sizes), num_features=head_dim, config=config, mesh=mesh, rngs=rngs) + + def __call__(self, input_ids: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array: + """ + Retrieves embeddings for multi-head indices. + + Args: + input_ids: Hashed indices. Shape (B, S, H_total), where H_total is the total number of heads. + model_mode: The model's operational mode (e.g., 'train', 'prefill'). + + Returns: + embeddings: Shape (B, S, H_total, D_head). + """ + # Broadcasting Add: (B, S, H) + (H,) -> (B, S, H) + # Shifts local indices (0..Prime-1) to global table positions. + shifted_ids = input_ids + self.offsets + + # Embedding lookup: (B, S, H_total) -> (B, S, H_total, D_head) + return self.embedding(shifted_ids, model_mode=model_mode) + + +class ShortConv(nnx.Module): + """ + Depthwise causal 1D convolution, with multi-branch integration. + + Applies local temporal smoothing + - Independent RMSNorms to each branch + - Shared convolution to mix time steps [t-k, t] + """ + + def __init__( + self, + config: Config, + hidden_size: int, + kernel_size: int = 4, + dilation: int = 1, + hc_mult: int = 4, + rngs: nnx.Rngs = None, + ): + """ + Args: + config: The model configuration. + hidden_size (D): Dimension of a single branch. 
+ kernel_size: Temporal window size. + dilation: Dilation rate for the convolution. + hc_mult (G): Number of branches. + rngs: RNG state for initialization. + """ + self.hc_mult = hc_mult + # Total channels = G * D + total_channels = hidden_size * hc_mult + + # Norm (Vectorized) + # independent weight per branch, branched input + @nnx.split_rngs(splits=hc_mult) + @nnx.vmap(in_axes=0, out_axes=0) + def create_norms(r): + return RMSNorm( + num_features=hidden_size, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + epsilon=config.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=r, + ) + + # Weights: {"scale": (G, D)} + self.norms = create_norms(rngs) + + # Convolution (Shared) + # Depthwise: feature_group_count == in_features ensures no mixing across channels. + # Causal: Ensures output at t only depends on inputs <= t. + # Weights: {"kernel": (kernel_size, in_features//feature_group_count, total_channels)} + self.conv = nnx.Conv( + in_features=total_channels, + out_features=total_channels, + kernel_size=(kernel_size,), + feature_group_count=total_channels, + kernel_dilation=(dilation,), + padding="CAUSAL", + use_bias=False, + dtype=config.dtype, + param_dtype=config.weight_dtype, + precision=config.matmul_precision, + rngs=rngs, + ) + + def __call__(self, x: Array) -> Array: + """ + Compute y^i = SiLU(Conv1D(RMSNorm^i(x^i))) for each branch i. + + Args: + x: Input tensor of shape (B, S, G, D) + Returns: + Output tensor of shape (B, S, G, D) + + Shape annotation: + B: Batch size + S: Sequence length (temporal dimension) + G: Number of branches (hc_mult) + D: Hidden size (base_emb_dim) + """ + B, S, G, D = x.shape + + # Apply Norms (Vectorized over Group dim) + # `in_axes=(0, 2)`: norms is axis 0, x is axis 2 + # `out_axes=2`: put the group dim back at axis 2 + # shape stays (B, S, G, D) + @nnx.vmap(in_axes=(0, 2), out_axes=2) + def apply_norms(norms, x): + return norms(x) + + x = apply_norms(self.norms, x) + + # Flatten branches into channel: (B, S, G, D) -> (B, S, G * D) + x_flat = x.reshape(B, S, G * D) + # Depthwise Convolution to mix temporal dimension S only. Shape stays (B, S, G * D) + y = self.conv(x_flat) + y = jax.nn.silu(y) + # Restore branch: (B, S, G * D) -> (B, S, G, D) + return y.reshape(B, S, G, D) + + +class Engram(nnx.Module): + """ + Engram Memory Layer with n-gram embedding, with multi-branch integration. + + Main components: + - Context-independent Retrieval: Fetch static n-gram embeddings via Multi-Head Hashing. + - Context-aware Gating: Compute similarity between memory (Key) and context (Query) to determine relevance. + - Mix: Apply local temporal smoothing via convolution. + """ + + def __init__( + self, + rngs: nnx.Rngs, + config: Config, + mesh: Mesh, + quant: Optional[Quant] = None, + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"), + *, + hc_mult: int = 4, + vocab_sizes: List[int], + engram_num_heads: int, + engram_head_dim: int, + engram_max_ngram_size: int, + engram_kernel_size: int, + ): + """ + Args: + rngs: RNG state for initialization + config: The model configuration. + mesh: Partitioning mesh. + quant: Quantization config. + kernel_init: Weight initializer. + hc_mult (G): Number of branches. + vocab_sizes: List of prime vocabulary sizes for the embedding table. + engram_num_heads (H): Heads per n-gram order. + engram_head_dim (D_head): Dimension per head. + engram_max_ngram_size: Max n-gram order (e.g., 3 covers 2-grams and 3-grams). + engram_kernel_size: convolution kernel size. 
+ """ + self.config = config + self.mesh = mesh + self.dtype = self.config.dtype + self.weight_dtype = self.config.dtype + self.kernel_init = kernel_init + self.quant = quant + self.rngs = rngs + self.hc_mult = hc_mult + + # Hierarchy: Engram -> n-gram Order -> h-th Head + self.max_ngram_size = engram_max_ngram_size + self.conv_kernel_size = engram_kernel_size + num_ngram_orders = self.max_ngram_size - 1 + # D_en: Final concatenated size of the retrieved memory + self.engram_dim = engram_head_dim * engram_num_heads * num_ngram_orders + + # Embedding: one flattened table to store all n-gram heads across orders + self.multi_head_embedding = MultiHeadEmbedding( + config=config, vocab_sizes=vocab_sizes, head_dim=engram_head_dim, mesh=mesh, rngs=rngs + ) + + # Key Projection (vectorized): retrieved n-gram memory -> Key + # Independent weights per branch, Shared input + @nnx.split_rngs(splits=hc_mult) + @nnx.vmap(in_axes=0, out_axes=0) + def create_key_projs(r): + return DenseGeneral( + in_features_shape=self.engram_dim, + out_features_shape=config.base_emb_dim, + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("engram_dim", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, + use_bias=True, + rngs=r, + ) + + self.key_projs = create_key_projs(rngs) + + # Norms (vectorized) + # Independent weights per branch, Branched input + @nnx.split_rngs(splits=hc_mult) + @nnx.vmap(in_axes=0, out_axes=0) + def create_norms(r): + return RMSNorm( + num_features=config.base_emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + epsilon=config.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=r, + ) + + # Key Normalization + self.k_norms = create_norms(rngs) + # Query Normalization + self.q_norms = create_norms(rngs) + + # Value Projection (shared): Retrieved memory -> Value + self.value_proj = DenseGeneral( + in_features_shape=self.engram_dim, + out_features_shape=config.base_emb_dim, + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("engram_dim", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + matmul_precision=self.config.matmul_precision, + shard_mode=self.config.shard_mode, + use_bias=True, + rngs=self.rngs, + ) + + # Short Convolution (vectorized internally) + # Applies depthwise causal convolution to smooth the retrieved memory over time. + self.short_conv = ShortConv( + config=config, + hidden_size=config.base_emb_dim, + kernel_size=self.conv_kernel_size, + dilation=self.max_ngram_size, + hc_mult=hc_mult, + rngs=rngs, + ) + + def __call__(self, hidden_states: Array, hash_input_ids: Array) -> Array: + """ + Computes the Engram output by retrieving, gating, and smoothing n-gram memory. + + Args: + hidden_states: Current transformer state (Query). Shape: (B, S, G, D). + hash_input_ids: Hashed token IDs. Shape: (B, S, H_total). + Produced by `hash_mapping.hash(input_ids)[layer_id]`. + + Returns: + Shape: (B, S, G, D) + + Shape annotation: + B: Batch Size + S: Sequence Length + G: hc_mult, Number of Branches + H_total: Total number of heads across n-grams. num_head * num_ngrams + D: base_emb_dim + D_head: Dimension of a single head embedding + D_en: Dimension of flattened embedding across heads and n-grams + """ + B, S, _, D = hidden_states.shape + + # 1. 
Retrieve Memory from Embedding + # (B, S, H_total) -> (B, S, H_total, D_head) + embeddings = self.multi_head_embedding(hash_input_ids) + # (B, S, H_total, D_head) -> (B, S, D_en) + embeddings = embeddings.reshape(B, S, -1) + + # 2. Static Memory as Key + # Vectorized broadcast: apply each of the G key_projs to the SAME embeddings. + # in_axes: (0, None) -> 0 splits the Dense layers, None broadcasts embeddings + # out_axes: 2 -> Stack the results at axis 2 to get (B, S, G, D) + @nnx.vmap(in_axes=(0, None), out_axes=2) + def apply_projs(projs, x): + return projs(x) + + # (B, S, D_en) -> (B, S, G, D) + key = apply_projs(self.key_projs, embeddings) + + # 3. Compute Norms + # Vectorized Map: Map over the G dimension (Axis 2) for both weights and input + @nnx.vmap(in_axes=(0, 2), out_axes=2) + def apply_norms(norms, x): + return norms(x) + + # (B, S, G, D) shape stays + key = apply_norms(self.k_norms, key) + + # 4. Dynamic Context as Query + # (B, S, G, D) shape stays + query = apply_norms(self.q_norms, hidden_states) + + # 5. QK product as Gates + # Compute similarity of memory (Key) and current state (Query) + qk_product = jnp.einsum("bsgc,bsgc->bsg", query, key, precision=self.config.matmul_precision) + gate = qk_product / jnp.sqrt(D) + # Range Compression: Apply signed square-root to prevent sigmoid saturation + gate = jnp.sqrt(jnp.maximum(jnp.abs(gate), 1e-6)) * jnp.sign(gate) + # Sigmoid activation to get gating probability [0, 1] + gate = jax.nn.sigmoid(gate) # (B, S, G) + + # 6. Static Memory as Value + # (B, S, D_en) -> (B, S, D) + value = self.value_proj(embeddings) + + # 7. Apply Gates to Value + # (B, S, G, 1) * (B, S, 1, D) -> (B, S, G, D) + gated_value = gate[:, :, :, None] * value[:, :, None, :] + + # 8. ShortConv as Temporal Smoothing + # Shape remains, (B, S, G, D) + # Apply depthwise conv to mix S + conv_output = self.short_conv(gated_value) + # residual for conv component + output = gated_value + conv_output + + # Note: residual connection for hidden_states will be added by the caller + return output diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index 6a3a160a28..d3c838a096 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -33,7 +33,7 @@ from maxtext.utils import max_utils logger = logging.getLogger(__name__) -logger.setLevel(os.environ.get("LOGLEVEL", "INFO")) +# logger.setLevel(os.environ.get("LOGLEVEL", "INFO")) _BASE_CONFIG_ATTR = "base_config" _MAX_PREFIX = "M_" diff --git a/tests/unit/engram_vs_reference_test.py b/tests/unit/engram_vs_reference_test.py new file mode 100644 index 0000000000..c13c4e68d5 --- /dev/null +++ b/tests/unit/engram_vs_reference_test.py @@ -0,0 +1,784 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +Tests for Engram: CompressedTokenizer, NgramHashMapping, MultiHeadEmbedding, ShortConv, Engram + +Reference implementation: +https://github.com/deepseek-ai/Engram/blob/fb7f84a21f91223715394a33a1dc24bbfb7f788e/engram_demo_v1.py + +To run the test: + python3 -m pip install torch numpy transformers sympy + python3 -m pytest -v --pyargs tests.unit.engram_vs_reference_test -rP -s +""" + + +from typing import List +from dataclasses import dataclass, field +import math +import os +import unittest +from absl.testing import parameterized + +import numpy as np +from sympy import isprime + +from tokenizers import normalizers, Regex +from transformers import AutoTokenizer +import torch +from torch import nn + +from flax import nnx +import jax +import jax.numpy as jnp +from jax.sharding import Mesh + +from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText import pyconfig +from MaxText import maxtext_utils + +from MaxText.layers.engram import CompressedTokenizer as CompressedTokenizerJAX +from MaxText.layers.engram import NgramHashMapping as NgramHashMappingJAX +from MaxText.layers.engram import MultiHeadEmbedding as MultiHeadEmbeddingJAX +from MaxText.layers.engram import ShortConv as ShortConvJAX +from MaxText.layers.engram import Engram as EngramJAX + + +# ----------------------------------------------------------------------------- +# Config +# ----------------------------------------------------------------------------- + + +@dataclass +class Config: + """MaxText config""" + + base_emb_dim: int = 1024 + tokenizer_path: str = "deepseek-ai/DeepSeek-V3" + # TODO (ranran, shuningjin): add configs to base.yml during engram integration + # mhc + hc_mult: int = 2 # if > 1 use mhc, if 1 not use mhc + # Engram + engram_max_ngram_size: int = 3 # max_ngram_size, use 2...N + # List of minimum head vocab sizes for each n-gram order + engram_vocab_bases: List[int] = field(default_factory=lambda: [129280 * 5, 129280 * 5]) + engram_layer_ids: List[int] = field(default_factory=lambda: [1, 15]) + engram_kernel_size: int = 4 # conv kernel size + engram_head_dim: int = 32 + engram_num_heads: int = 8 # num heads per n-gram + # Hashing + engram_pad_id: int = 2 # TODO(shuningjin): not the same as tokenizer.pad_id? 
+ engram_seed: int = 0 + + +class EngramConfig: + """Torch Engram Config""" + + def __init__(self, config): + self.tokenizer_name_or_path = config.tokenizer_path + self.engram_vocab_size = config.engram_vocab_bases + self.max_ngram_size = config.engram_max_ngram_size + self.n_embed_per_ngram = config.engram_head_dim * config.engram_num_heads + self.n_head_per_ngram = config.engram_num_heads + self.layer_ids = config.engram_layer_ids + self.pad_id = config.engram_pad_id + self.seed = config.engram_seed + self.kernel_size = config.engram_kernel_size + + +class BackBoneConfig: + """Torch Backbone Config""" + + def __init__(self, config): + + self.hidden_size = config.base_emb_dim + self.hc_mult = config.hc_mult + + +# ----------------------------------------------------------------------------- +# Torch Reference Implementation +# ----------------------------------------------------------------------------- + + +class CompressedTokenizer: + + def __init__( + self, + tokenizer_name_or_path, + ): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True) + + SENTINEL = "\uE000" + self.normalizer = normalizers.Sequence( + [ + normalizers.NFKC(), + normalizers.NFD(), + normalizers.StripAccents(), + normalizers.Lowercase(), + normalizers.Replace(Regex(r"[ \t\r\n]+"), " "), + normalizers.Replace(Regex(r"^ $"), SENTINEL), + normalizers.Strip(), + normalizers.Replace(SENTINEL, " "), + ] + ) + + self.lookup_table, self.num_new_token = self._build_lookup_table() + + def __len__(self): + return self.num_new_token + + def _build_lookup_table(self): + old2new = {} + key2new = {} + new_tokens = [] + + vocab_size = len(self.tokenizer) + for tid in range(vocab_size): + text = self.tokenizer.decode([tid], skip_special_tokens=False) + + if "�" in text: + key = self.tokenizer.convert_ids_to_tokens(tid) + else: + norm = self.normalizer.normalize_str(text) + key = norm if norm else text + + nid = key2new.get(key) + if nid is None: + nid = len(new_tokens) + key2new[key] = nid + new_tokens.append(key) + old2new[tid] = nid + + lookup = np.empty(vocab_size, dtype=np.int64) + for tid in range(vocab_size): + lookup[tid] = old2new[tid] + + return lookup, len(new_tokens) + + def _compress(self, input_ids): + arr = np.asarray(input_ids, dtype=np.int64) + pos_mask = arr >= 0 + out = arr.copy() + valid_ids = arr[pos_mask] + out[pos_mask] = self.lookup_table[valid_ids] + return out + + def __call__(self, input_ids): + return self._compress(input_ids) + + +class ShortConv(nn.Module): + + def __init__( + self, + hidden_size: int, + kernel_size: int = 4, + dilation: int = 1, + norm_eps: float = 1e-5, + hc_mult: int = 4, + activation: bool = True, + ): + super().__init__() + self.hc_mult = hc_mult + self.activation = activation + + total_channels = hidden_size * hc_mult + self.conv = nn.Conv1d( + in_channels=total_channels, + out_channels=total_channels, + kernel_size=kernel_size, + groups=total_channels, + bias=False, + padding=(kernel_size - 1) * dilation, + dilation=dilation, + ) + + self.norms = nn.ModuleList([nn.RMSNorm(hidden_size, eps=norm_eps) for _ in range(hc_mult)]) + + if self.activation: + self.act_fn = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Input: (B,L,HC_MULT,D) + Output: (B,L,HC_MULT,D) + """ + B, T, G, C = x.shape + + assert G == self.hc_mult, f"Input groups {G} != hc_mult {self.hc_mult}" + + normed_chunks = [] + for i in range(G): + chunk = x[:, :, i, :] + normed_chunks.append(self.norms[i](chunk)) + + x_norm = torch.cat(normed_chunks, 
dim=-1) + x_bct = x_norm.transpose(1, 2) + y_bct = self.conv(x_bct) + y_bct = y_bct[..., :T] + + if self.activation: + y_bct = self.act_fn(y_bct) + y = y_bct.transpose(1, 2).view(B, T, G, C).contiguous() + + return y + + +def find_next_prime(start, seen_primes): + candidate = start + 1 + while True: + if isprime(candidate) and candidate not in seen_primes: + return candidate + candidate += 1 + + +class NgramHashMapping: + + def __init__( + self, + engram_vocab_size, + max_ngram_size, + n_embed_per_ngram, + n_head_per_ngram, + layer_ids, + tokenizer_name_or_path, + pad_id, + seed, + ): + self.vocab_size_per_ngram = engram_vocab_size + self.max_ngram_size = max_ngram_size + self.n_embed_per_ngram = n_embed_per_ngram + self.n_head_per_ngram = n_head_per_ngram + self.pad_id = pad_id + self.layer_ids = layer_ids + + self.compressed_tokenizer = CompressedTokenizer(tokenizer_name_or_path=tokenizer_name_or_path) + self.tokenizer_vocab_size = len(self.compressed_tokenizer) + if self.pad_id is not None: + self.pad_id = int(self.compressed_tokenizer.lookup_table[self.pad_id]) + + max_long = np.iinfo(np.int64).max + M_max = int(max_long // self.tokenizer_vocab_size) + half_bound = max(1, M_max // 2) + PRIME_1 = 10007 + + self.layer_multipliers = {} + + for layer_id in self.layer_ids: + base_seed = int(seed + PRIME_1 * int(layer_id)) + g = np.random.default_rng(base_seed) + r = g.integers(low=0, high=half_bound, size=(self.max_ngram_size,), dtype=np.int64) + multipliers = r * 2 + 1 + self.layer_multipliers[layer_id] = multipliers + + self.vocab_size_across_layers = self.calculate_vocab_size_across_layers() + + def calculate_vocab_size_across_layers(self): + seen_primes = set() + vocab_size_across_layers = {} + + for layer_id in self.layer_ids: + all_ngram_vocab_sizes = [] + for ngram in range(2, self.max_ngram_size + 1): + current_ngram_heads_sizes = [] + + vocab_size = self.vocab_size_per_ngram[ngram - 2] + num_head = self.n_head_per_ngram + current_prime_search_start = vocab_size - 1 + + for _ in range(num_head): + found_prime = find_next_prime(current_prime_search_start, seen_primes) + seen_primes.add(found_prime) + current_ngram_heads_sizes.append(found_prime) + current_prime_search_start = found_prime + + all_ngram_vocab_sizes.append(current_ngram_heads_sizes) + vocab_size_across_layers[layer_id] = all_ngram_vocab_sizes + + return vocab_size_across_layers + + def _get_ngram_hashes( + self, + input_ids: np.ndarray, + layer_id: int, + ) -> np.ndarray: + x = np.asarray(input_ids, dtype=np.int64) + B, T = x.shape + + multipliers = self.layer_multipliers[layer_id] + + def shift_k(k: int) -> np.ndarray: + if k == 0: + return x + shifted = np.pad(x, ((0, 0), (k, 0)), mode="constant", constant_values=self.pad_id)[:, :T] + return shifted + + base_shifts = [shift_k(k) for k in range(self.max_ngram_size)] + + all_hashes = [] + + for n in range(2, self.max_ngram_size + 1): + n_gram_index = n - 2 + tokens = base_shifts[:n] + mix = tokens[0] * multipliers[0] + for k in range(1, n): + mix = np.bitwise_xor(mix, tokens[k] * multipliers[k]) + num_heads_for_this_ngram = self.n_head_per_ngram + head_vocab_sizes = self.vocab_size_across_layers[layer_id][n_gram_index] + + for j in range(num_heads_for_this_ngram): + mod = int(head_vocab_sizes[j]) + head_hash = mix % mod + all_hashes.append(head_hash.astype(np.int64, copy=False)) + + return np.stack(all_hashes, axis=2) + + def hash(self, input_ids): + input_ids = self.compressed_tokenizer(input_ids) + hash_ids_for_all_layers = {} + for layer_id in self.layer_ids: + 
hash_ids_for_all_layers[layer_id] = self._get_ngram_hashes(input_ids, layer_id=layer_id) + return hash_ids_for_all_layers + + +class MultiHeadEmbedding(nn.Module): + + def __init__(self, list_of_N: List[int], D: int): + super().__init__() + self.num_heads = len(list_of_N) + self.embedding_dim = D + + offsets = [0] + for n in list_of_N[:-1]: + offsets.append(offsets[-1] + n) + + self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long)) + + total_N = sum(list_of_N) + self.embedding = nn.Embedding(num_embeddings=total_N, embedding_dim=D) + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + shifted_input_ids = input_ids + self.offsets + output = self.embedding(shifted_input_ids) + + return output + + +class Engram(nn.Module): + + # added argument: engram_cfg, backbone_config + def __init__(self, layer_id, backbone_config, engram_cfg): + super().__init__() + self.layer_id = layer_id + self.hash_mapping = NgramHashMapping( + engram_vocab_size=engram_cfg.engram_vocab_size, + max_ngram_size=engram_cfg.max_ngram_size, + n_embed_per_ngram=engram_cfg.n_embed_per_ngram, + n_head_per_ngram=engram_cfg.n_head_per_ngram, + layer_ids=engram_cfg.layer_ids, + tokenizer_name_or_path=engram_cfg.tokenizer_name_or_path, + pad_id=engram_cfg.pad_id, + seed=engram_cfg.seed, + ) + self.multi_head_embedding = MultiHeadEmbedding( + list_of_N=[x for y in self.hash_mapping.vocab_size_across_layers[self.layer_id] for x in y], + D=engram_cfg.n_embed_per_ngram // engram_cfg.n_head_per_ngram, + ) + self.short_conv = ShortConv( + hidden_size=backbone_config.hidden_size, + kernel_size=engram_cfg.kernel_size, + dilation=engram_cfg.max_ngram_size, + hc_mult=backbone_config.hc_mult, + ) + engram_hidden_size = (engram_cfg.max_ngram_size - 1) * engram_cfg.n_embed_per_ngram + self.value_proj = nn.Linear(engram_hidden_size, backbone_config.hidden_size) + self.key_projs = nn.ModuleList( + [nn.Linear(engram_hidden_size, backbone_config.hidden_size) for _ in range(backbone_config.hc_mult)] + ) + self.norm1 = nn.ModuleList([nn.RMSNorm(backbone_config.hidden_size) for _ in range(backbone_config.hc_mult)]) + self.norm2 = nn.ModuleList([nn.RMSNorm(backbone_config.hidden_size) for _ in range(backbone_config.hc_mult)]) + + # added argument: backbone_config + def forward(self, hidden_states, input_ids, backbone_config): + """ + hidden_states: [B, L, HC_MULT, D] + input_ids: [B, L] + """ + hash_input_ids = torch.from_numpy(self.hash_mapping.hash(input_ids)[self.layer_id]) + embeddings = self.multi_head_embedding(hash_input_ids).flatten(start_dim=-2) + gates = [] + for hc_idx in range(backbone_config.hc_mult): + key = self.key_projs[hc_idx](embeddings) + normed_key = self.norm1[hc_idx](key) + query = hidden_states[:, :, hc_idx, :] + normed_query = self.norm2[hc_idx](query) + gate = (normed_key * normed_query).sum(dim=-1) / math.sqrt(backbone_config.hidden_size) + gate = gate.abs().clamp_min(1e-6).sqrt() * gate.sign() + gate = gate.sigmoid().unsqueeze(-1) + gates.append(gate) + gates = torch.stack(gates, dim=2) + value = gates * self.value_proj(embeddings).unsqueeze(2) + output = value + self.short_conv(value) + return output + + +# ----------------------------------------------------------------------------- +# Test JAX Module: Helper +# ----------------------------------------------------------------------------- + + +def to_jax(pt_tensor: torch.Tensor) -> jax.Array: + return jnp.asarray(pt_tensor.detach().cpu().numpy()) + + +def init_torch_weights(module, std=1): + """ + Initialize all parameters in the module with 
N(0,std). + This simple strategy is intended only for unit test. + """ + with torch.no_grad(): + for _, param in module.named_parameters(): + torch.nn.init.normal_(param, mean=0.0, std=std) + + +def get_cfg_and_mesh(config): + """Returns MaxText configuration and mesh.""" + cfg = pyconfig.initialize( + [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + run_name="", + enable_checkpointing=False, + model_name="default", + dtype="float32", + # high precision + weight_dtype="float32", + matmul_precision="highest", + float32_qk_product=True, + float32_logits=True, + base_emb_dim=config.base_emb_dim, + ) + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + return cfg, mesh + + +# ----------------------------------------------------------------------------- +# Test JAX Module (non-pareamteric): CompressedTokenizer, NgramHashMapping +# ----------------------------------------------------------------------------- + + +class CompressedTokenizerTest(parameterized.TestCase): + + def test_tokenierzer_match(self): + np.random.seed(42) + tokenizer_path = "deepseek-ai/DeepSeek-V3" + hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) + # input + batch_size, seq_len = 2, 32 + input_ids = np.random.randint(0, len(hf_tokenizer), (batch_size, seq_len)) + # 1. PyTorch + pt_tokenizer = CompressedTokenizer(tokenizer_path) + pt_lookup_table = pt_tokenizer.lookup_table + pt_out = pt_tokenizer(input_ids) + # 2. JAX + jax_tokenizer = CompressedTokenizerJAX(hf_tokenizer) + jax_lookup_table = jax_tokenizer.lookup_table + jax_out = jax_tokenizer(input_ids) + # 3. Compare + np.testing.assert_equal(jax_lookup_table, pt_lookup_table) + np.testing.assert_equal(len(pt_tokenizer), len(jax_tokenizer)) + np.testing.assert_array_equal(pt_out, jax_out) + + +class NgramHashMappingTest(parameterized.TestCase): + + def test_hash_match(self): + np.random.seed(42) + self.config = Config() + self.engram_cfg = EngramConfig(self.config) + self.backbone_config = BackBoneConfig(self.config) + tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True) + # input + batch_size, seq_len = 2, 32 + input_ids = np.random.randint(0, len(tokenizer), (batch_size, seq_len)) + # 1. PyTorch + pt_hash_mapping = NgramHashMapping( + engram_vocab_size=self.engram_cfg.engram_vocab_size, + max_ngram_size=self.engram_cfg.max_ngram_size, + n_embed_per_ngram=self.engram_cfg.n_embed_per_ngram, + n_head_per_ngram=self.engram_cfg.n_head_per_ngram, + layer_ids=self.engram_cfg.layer_ids, + tokenizer_name_or_path=self.engram_cfg.tokenizer_name_or_path, + pad_id=self.engram_cfg.pad_id, + seed=self.engram_cfg.seed, + ) + pt_out = pt_hash_mapping.hash(input_ids) + # 2. JAX + jax_hash_mapping = NgramHashMappingJAX( + engram_vocab_bases=self.config.engram_vocab_bases, + max_ngram_size=self.config.engram_max_ngram_size, + engram_num_heads=self.config.engram_num_heads, + layer_ids=self.config.engram_layer_ids, + tokenizer=tokenizer, + pad_id=self.config.engram_pad_id, + seed=self.config.engram_seed, + ) + jax_out = jax_hash_mapping(input_ids) + # 3. 
Compare + # keys are layer_ids + self.assertDictEqual(jax_hash_mapping.vocab_size_across_layers, pt_hash_mapping.vocab_size_across_layers) + np.testing.assert_equal(pt_out, jax_out) + + +# ----------------------------------------------------------------------------- +# Test JAX Module: MultiHeadEmbedding +# ----------------------------------------------------------------------------- + + +def to_jax_mhe(pt_layer): + """ + Extracts weights from PyTorch MultiHeadEmbedding. + """ + return {"embedding": {"embedding": to_jax(pt_layer.embedding.weight)}} + + +class MultiHeadEmbeddingTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + torch.set_default_dtype(torch.float32) + torch.manual_seed(42) + np.random.seed(42) + self.rngs = nnx.Rngs(params=0) + + @parameterized.named_parameters( + {"testcase_name": "multiple_head", "vocab_sizes": [100, 200, 150]}, + {"testcase_name": "single_head", "vocab_sizes": [500]}, + ) + def test_mhe_match(self, vocab_sizes, head_dim=32): + # vocab_sizes: a flattened list of sizes for all heads across all n-grams + # Input + num_total_heads = len(vocab_sizes) + # indices must be within the range of each specific head's vocab. + batch_size, seq_len = 2, 32 + input_np = np.zeros((batch_size, seq_len, num_total_heads), dtype=np.int32) + for i, v_size in enumerate(vocab_sizes): + input_np[:, :, i] = np.random.randint(0, v_size, (batch_size, seq_len)) + x_pt = torch.from_numpy(input_np).long() + x_jax = jnp.array(input_np) + + # 1. PyTorch + pt_model = MultiHeadEmbedding(vocab_sizes, head_dim) + init_torch_weights(pt_model) + pt_model.eval() + with torch.no_grad(): + y_pt = pt_model(x_pt) + + # 2. JAX + config = Config() + cfg, mesh = get_cfg_and_mesh(config) + jax_model = MultiHeadEmbeddingJAX(cfg, vocab_sizes, head_dim, mesh, self.rngs) + # weight transfer + weights = to_jax_mhe(pt_model) + nnx.update(jax_model, weights) + # forward + y_jax = jax_model(x_jax) + + # 3. Compare + # Check offsets + np.testing.assert_array_equal(jax_model.offsets, to_jax(pt_model.offsets)) + # Check outputs + np.testing.assert_allclose(y_jax, to_jax(y_pt), rtol=1e-5, atol=1e-5) + + +# ----------------------------------------------------------------------------- +# Test JAX Module: ShortConv +# ----------------------------------------------------------------------------- + + +def to_jax_norm(pt_norm): + """Extracts scale parameter from a norm layer.""" + return {"scale": to_jax(pt_norm.weight)} + + +def to_jax_vmap(pt_module_list, transform_fn): + """ + Applies transform_fn to a list of modules and stacks the + resulting PyTrees along a new leading axis. + """ + jax_trees = [transform_fn(m) for m in pt_module_list] + # Stacks all keys (kernel, bias, etc.) along a new 0-th dimension + return jax.tree.map(lambda *xs: jnp.stack(xs), *jax_trees) + + +def to_jax_shortconv(pt_layer): + """ + Converts a ShortConv layer containing a Conv and a ModuleList of Norms. 
+ """ + return { + # (Out, In//Groups, Kernel) -> (Kernel, In//Groups, Out) + "conv": {"kernel": to_jax(pt_layer.conv.weight.permute(2, 1, 0))}, + # List[Norm] -> Stacked norm (Groups, Channels) + "norms": to_jax_vmap(pt_layer.norms, to_jax_norm), + } + + +class ShortConvTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + torch.set_default_dtype(torch.float32) + torch.manual_seed(42) + np.random.seed(42) + self.nnx_rngs = nnx.Rngs(params=0) + + @parameterized.named_parameters( + {"testcase_name": "multi_branch", "hc_mult": 4}, + {"testcase_name": "single_branch", "hc_mult": 1}, + ) + def test_shortconv_match(self, hc_mult, dilation=2, hidden_size=16, kernel_size=3): + + # Input Data, Shape: (B, S, G, D) + batch_size, seq_len = 2, 32 + x_pt = torch.randn(batch_size, seq_len, hc_mult, hidden_size) + x_jax = to_jax(x_pt) + + config = Config() + cfg, mesh = get_cfg_and_mesh(config) + + # 1. PyTorch + pt_model = ShortConv(hidden_size, kernel_size, dilation, hc_mult=hc_mult, norm_eps=cfg.normalization_layer_epsilon) + init_torch_weights(pt_model) + pt_model.eval() + with torch.no_grad(): + y_pt = pt_model(x_pt) + + # 2. JAX + jax_model = ShortConvJAX(cfg, hidden_size, kernel_size, dilation, hc_mult=hc_mult, rngs=self.nnx_rngs) + # Transfer Weights + weights = to_jax_shortconv(pt_model) + nnx.update(jax_model, weights) + # Forward Pass + y_jax = jax_model(x_jax) + + # 3. Compare + np.testing.assert_allclose(y_jax, to_jax(y_pt), rtol=1e-3, atol=1e-3) + + +# ----------------------------------------------------------------------------- +# Test JAX Module: Engram +# ----------------------------------------------------------------------------- + + +def to_jax_linear(pt_linear): + """(Out, In) -> {'kernel': (In, Out), 'bias': (Out)}""" + out = {"kernel": to_jax(pt_linear.weight.T)} + if pt_linear.bias is not None: + out["bias"] = to_jax(pt_linear.bias) + return out + + +def to_jax_engram(pt_engram) -> dict: + return { + "multi_head_embedding": to_jax_mhe(pt_engram.multi_head_embedding), + # Standard Single Layer (No Stacking needed) + "value_proj": to_jax_linear(pt_engram.value_proj), + # Vectorized Layers (Stacking needed) + # Result shapes: Kernel (G, In, Out), Bias (G, Out) + "key_projs": to_jax_vmap(pt_engram.key_projs, to_jax_linear), + # Result shapes: Scale (G, D) + "k_norms": to_jax_vmap(pt_engram.norm1, to_jax_norm), + "q_norms": to_jax_vmap(pt_engram.norm2, to_jax_norm), + "short_conv": to_jax_shortconv(pt_engram.short_conv), + } + + +class EngramTest(parameterized.TestCase): + """Verifies JAX Engram implementation matches PyTorch reference.""" + + def setUp(self): + super().setUp() + torch.set_default_dtype(torch.float32) + torch.manual_seed(42) + np.random.seed(42) + self.nnx_rng = nnx.Rngs(params=0) + self.layer_id = 1 # must belong to config.engram_layer_ids + + @parameterized.named_parameters( + {"testcase_name": "multi_branch", "hc_mult": 4}, + {"testcase_name": "single_branch", "hc_mult": 1}, + ) + def test_engram_match(self, hc_mult): + # config + self.config = Config(hc_mult=hc_mult) + self.engram_cfg = EngramConfig(self.config) + self.backbone_config = BackBoneConfig(self.config) + # Prepare Inputs + # random input_ids (B, S) + batch_size, seq_len = 2, 32 + input_ids_np = np.random.randint(0, 1000, (batch_size, seq_len)) + pt_input_ids = torch.from_numpy(input_ids_np) + # hidden_states (B, S, G, D) + pt_hidden_states = torch.randn( + batch_size, seq_len, self.backbone_config.hc_mult, self.backbone_config.hidden_size, dtype=torch.float32 + ) + + # 1. 
PyTorch + pt_layer = Engram(layer_id=self.layer_id, backbone_config=self.backbone_config, engram_cfg=self.engram_cfg) + init_torch_weights(pt_layer) + pt_layer.eval() + # forward + with torch.no_grad(): + pt_out = pt_layer(pt_hidden_states, pt_input_ids, self.backbone_config) + + # 2. JAX + # Data pipeline + tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True) + jax_hash_mapping = NgramHashMappingJAX( + engram_vocab_bases=self.config.engram_vocab_bases, + max_ngram_size=self.config.engram_max_ngram_size, + engram_num_heads=self.config.engram_num_heads, + layer_ids=self.config.engram_layer_ids, + tokenizer=tokenizer, + pad_id=self.config.engram_pad_id, + seed=self.config.engram_seed, + ) + vocab_sizes = jax_hash_mapping.get_vocab_sizes(self.layer_id) # layer specific + jax_hash_input_ids = jax_hash_mapping(input_ids_np)[self.layer_id] # layer specific + + # Setup model + cfg, mesh = get_cfg_and_mesh(self.config) + jax_layer = EngramJAX( + rngs=self.nnx_rng, + config=cfg, + mesh=mesh, + hc_mult=self.config.hc_mult, + vocab_sizes=vocab_sizes, + engram_num_heads=self.config.engram_num_heads, + engram_head_dim=self.config.engram_head_dim, + engram_max_ngram_size=self.config.engram_max_ngram_size, + engram_kernel_size=self.config.engram_kernel_size, + ) + # Synchronize Weights + jax_weights = to_jax_engram(pt_layer) + nnx.update(jax_layer, jax_weights) + # Forward + jax_out = jax_layer(to_jax(pt_hidden_states), jax_hash_input_ids) + + # 3. Compare + np.testing.assert_allclose(to_jax(pt_out), jax_out, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + unittest.main()
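+
+
+# -----------------------------------------------------------------------------
+# Illustrative wiring sketch (comments only): how a caller is expected to combine
+# NgramHashMapping and Engram, condensed from EngramTest above. `cfg`, `mesh`,
+# `input_ids`, and `hidden` are placeholders assumed to come from the host
+# decoder integration, which this change does not define.
+#
+#   tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
+#   hash_mapping = NgramHashMappingJAX(
+#       engram_vocab_bases=[129280 * 5, 129280 * 5], max_ngram_size=3,
+#       engram_num_heads=8, layer_ids=[1, 15], tokenizer=tokenizer, pad_id=2, seed=0,
+#   )
+#   # Hashing runs host-side in NumPy, once per batch, for every Engram layer id.
+#   hash_ids = hash_mapping(input_ids)                 # {layer_id: (B, S, H_total)}
+#   engram = EngramJAX(
+#       rngs=nnx.Rngs(params=0), config=cfg, mesh=mesh, hc_mult=2,
+#       vocab_sizes=hash_mapping.get_vocab_sizes(1),
+#       engram_num_heads=8, engram_head_dim=32,
+#       engram_max_ngram_size=3, engram_kernel_size=4,
+#   )
+#   # Engram returns gated, temporally smoothed memory; the caller adds the residual.
+#   hidden = hidden + engram(hidden, hash_ids[1])      # hidden: (B, S, G, D)
+# -----------------------------------------------------------------------------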