Commit 21cec5f

isolate HashMapping
1 parent 69ac35d commit 21cec5f

2 files changed

Lines changed: 99 additions & 101 deletions

src/MaxText/layers/engram.py

Lines changed: 43 additions & 51 deletions
@@ -13,13 +13,12 @@
 # limitations under the License.
 
 
-from typing import List, Optional
+from typing import List, Optional, Callable
 from dataclasses import dataclass, field
 import math
-from typing import List, Callable
-
 import numpy as np
 from sympy import isprime
+
 import torch
 import torch.nn as nn
 from transformers import AutoTokenizer
@@ -40,7 +39,7 @@
 DeepSeek-AI, `Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models
 <https://arxiv.org/pdf/2601.07372>`_, 2026
 
-Implementation: https://github.com/deepseek-ai/Engram/blob/main/engram_demo_v1.py
+Reference implementation: https://github.com/deepseek-ai/Engram/blob/main/engram_demo_v1.py
 """
 
 
@@ -165,7 +164,7 @@ def __init__(
 n_embed_per_ngram,
 n_head_per_ngram,
 layer_ids,
-tokenizer, # pass the global tokenizer to avoid re-loading
+tokenizer,
 pad_id,
 seed,
 ):
@@ -430,6 +429,38 @@ def apply_norms(norms, x):
 # (B, L, G * C) -> (B, L, G, C)
 return y.reshape(B, L, G, C)
 
+# -----------------------------------------------------------------------
+# Vocabulary Size Calculation (Global Prime Sequence)
+# -----------------------------------------------------------------------
+# Engram uses unique prime numbers for the vocabulary size of every head
+# in every layer to maximize hash collision independence.
+
+# We instantiate NgramHashMapping here to replicate this deterministic
+# sequence. Ideally, this mapping should be created once globally and
+# the resulting vocab_sizes passed into this layer to improve startup time.
+
+# # --- Hash Mapping ---
+# self.hash_mapping = NgramHashMapping(
+# engram_vocab_size=engram_vocab_size,
+# max_ngram_size=engram_max_ngram_size,
+# n_embed_per_ngram=engram_embed_dim_per_ngram,
+# n_head_per_ngram=engram_heads_per_ngram,
+# # IMPORTANT: We must pass the FULL list of layer_ids, not just self.layer_id.
+# # The mapping finds primes sequentially across all layers; passing a partial
+# # list would reset the prime search and break alignment with the reference model.
+# layer_ids=layer_ids,
+# # Inject the pre-loaded tokenizer to avoid redundant disk I/O per layer.
+# tokenizer=tokenizer,
+# pad_id=pad_id,
+# seed=seed,
+# )
+
+# # Extract the specific list of primes [M_{n,k}] for THIS layer only.
+# # The structure is [[ngram2_head1, ngram2_head2...], [ngram3_head1...]]
+# # We flatten it into a single list of ints: [N1, N2, N3, ...]
+# vocab_sizes = self.hash_mapping.get_vocab_sizes(self.layer_id)
+# print(f"DEBUG JAX Layer, Vocab Sizes: {vocab_sizes}")
+
 
 class Engram(nnx.Module):
 """
@@ -440,27 +471,24 @@ class Engram(nnx.Module):
 1. Retrieve: Fetch embeddings for current n-gram contexts using Multi-Head Hashing.
 2. Gate: Decide how much of this retrieved memory to merge based on the current state.
 3. Mix: Apply local temporal smoothing via ShortConv.
+
+Note: vocab_sizes = hash_mapping.get_vocab_sizes(layer_id)
 """
 
 def __init__(
 self,
-layer_id: int,
 rngs: nnx.Rngs,
 config,
 mesh,
-tokenizer, # pass the tokenizer to avoid re-loading
 quant: Optional[Quant] = None,
 kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
 *,
+vocab_sizes,
 hc_mult,
 engram_heads_per_ngram,
 engram_embed_dim_per_ngram,
 engram_max_ngram_size,
 engram_kernel_size,
-engram_vocab_size,
-layer_ids,
-pad_id,
-seed,
 ):
 self.config = config
 self.mesh = mesh
@@ -469,11 +497,6 @@ def __init__(
 self.kernel_init = kernel_init
 self.quant = quant
 self.rngs = rngs
-
-self.layer_id = layer_id
-self.layer_ids = layer_ids
-self.pad_id = pad_id
-self.seed = seed
 self.hc_mult = hc_mult
 
 # --- Dimensions ---
@@ -484,7 +507,7 @@ def __init__(
 # Hierarchy: Engram -> n-gram order (n) -> k-th head (k)
 # Raw Inputs
 self.max_ngram_size = engram_max_ngram_size # e.g., 4 (tracks 2,3,4-grams)
-self.vocab_size = engram_vocab_size
+# self.vocab_size = engram_vocab_size
 self.conv_kernel_size = engram_kernel_size
 # The Hierarchy (Paper Notation)
 # K: Number of heads per n-gram order
@@ -500,38 +523,6 @@ def __init__(
 # Final concatenated size: (Num Orders) * (Dim per Order)
 self.engram_dim = self.num_orders * self.dim_per_ngram
 
-# -----------------------------------------------------------------------
-# Vocabulary Size Calculation (Global Prime Sequence)
-# -----------------------------------------------------------------------
-# Engram uses unique prime numbers for the vocabulary size of every head
-# in every layer to maximize hash collision independence.
-
-# We instantiate NgramHashMapping here to replicate this deterministic
-# sequence. Ideally, this mapping should be created once globally and
-# the resulting vocab_sizes passed into this layer to improve startup time.
-
-# --- Hash Mapping ---
-self.hash_mapping = NgramHashMapping(
-engram_vocab_size=engram_vocab_size,
-max_ngram_size=engram_max_ngram_size,
-n_embed_per_ngram=engram_embed_dim_per_ngram,
-n_head_per_ngram=engram_heads_per_ngram,
-# IMPORTANT: We must pass the FULL list of layer_ids, not just self.layer_id.
-# The mapping finds primes sequentially across all layers; passing a partial
-# list would reset the prime search and break alignment with the reference model.
-layer_ids=layer_ids,
-# Inject the pre-loaded tokenizer to avoid redundant disk I/O per layer.
-tokenizer=tokenizer,
-pad_id=pad_id,
-seed=seed,
-)
-
-# Extract the specific list of primes [M_{n,k}] for THIS layer only.
-# The structure is [[ngram2_head1, ngram2_head2...], [ngram3_head1...]]
-# We flatten it into a single list of ints: [N1, N2, N3, ...]
-vocab_sizes = self.hash_mapping.get_vocab_sizes(self.layer_id)
-print(f"DEBUG JAX Layer, Vocab Sizes: {vocab_sizes}")
-
 # --- 1. Multi-Head Embedding ---
 # Stores the learnable vectors E_{n,k} for all n-gram heads in one flattened table.
 self.mhe = MultiHeadEmbedding(list_of_N=vocab_sizes, D=self.dim_per_head, config=config, mesh=mesh, rngs=rngs)
@@ -610,7 +601,7 @@ def create_norms(r):
 # Note: Creates separate parameters for Q norms
 self.q_norms = create_norms(rngs)
 
-def __call__(self, hidden_states: jax.Array, input_ids: jax.Array) -> jax.Array:
+def __call__(self, hidden_states: jax.Array, hash_input_ids: jax.Array) -> jax.Array:
 """
 Args:
 hidden_states: Current transformer state (Query). Shape: (B, L, G, C)
@@ -619,14 +610,15 @@ def __call__(self, hidden_states: jax.Array, input_ids: jax.Array) -> jax.Array:
 output: Engram-augmented residuals. Shape: (B, L, G, C)
 
 Note: G = hc_mult
+Note: hash_input_ids = hash_mapping.hash(input_ids)[layer_id]
 """
 B, L, G, C = hidden_states.shape
 
 # 1. Retrieve Memory
 # 1. Generate Hash Indices
 # Map raw text -> n-gram contexts -> hash indices z_{t,n,k}
 # (B, L) -> (B, L, H_en), where H_en is the total count of heads across all n-gram orders.
-hash_input_ids = jnp.array(self.hash_mapping.hash(input_ids)[self.layer_id])
+# hash_input_ids = jnp.array(self.hash_mapping.hash(input_ids)[self.layer_id])
 
 # 2. Retrieve Memory
 # Fetch e_{t,n,k} from the embedding table.
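
The net effect of the engram.py changes above: the layer no longer owns NgramHashMapping. The caller builds the mapping once, passes this layer's vocab_sizes into the constructor, and hashes the token ids before every call. The sketch below is not part of the commit; build_engram_layer and run_engram_layer are hypothetical helpers, and config, cfg, mesh, tokenizer, rngs, and layer_id stand for whatever setup the caller already has (the test diff below shows the concrete wiring).

# Sketch only (not part of this commit): the Engram wiring implied by the diff above.
import jax.numpy as jnp

from MaxText.layers.engram import Engram, NgramHashMapping


def build_engram_layer(config, cfg, mesh, tokenizer, rngs, layer_id):
  """Create the hash mapping once, then an Engram layer that only receives
  this layer's prime vocab sizes (no tokenizer / layer_ids / pad_id / seed)."""
  hash_mapping = NgramHashMapping(
      engram_vocab_size=config.engram_vocab_size,
      max_ngram_size=config.engram_max_ngram_size,
      n_embed_per_ngram=config.engram_embed_dim_per_ngram,
      n_head_per_ngram=config.engram_heads_per_ngram,
      layer_ids=config.engram_layer_ids,  # FULL list, so the global prime sequence stays aligned
      tokenizer=tokenizer,  # pre-loaded tokenizer, loaded once by the caller
      pad_id=config.engram_pad_id,
      seed=config.engram_seed,
  )
  engram = Engram(
      rngs=rngs,
      config=cfg,
      mesh=mesh,
      vocab_sizes=hash_mapping.get_vocab_sizes(layer_id),
      hc_mult=config.hc_mult,
      engram_heads_per_ngram=config.engram_heads_per_ngram,
      engram_embed_dim_per_ngram=config.engram_embed_dim_per_ngram,
      engram_max_ngram_size=config.engram_max_ngram_size,
      engram_kernel_size=config.engram_kernel_size,
  )
  return hash_mapping, engram


def run_engram_layer(hash_mapping, engram, layer_id, hidden_states, input_ids):
  """Hashing now happens outside the layer; only the hashed ids are passed in."""
  hash_input_ids = jnp.array(hash_mapping.hash(input_ids)[layer_id])
  return engram(hidden_states, hash_input_ids)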

tests/unit/engram_vs_reference_test.py

Lines changed: 56 additions & 50 deletions
@@ -16,8 +16,9 @@
 """
 Tests for Engram: MultiHeadEmbedding, ShortConv, Engram
 
-reference: https://github.com/deepseek-ai/Engram/blob/fb7f84a21f91223715394a33a1dc24bbfb7f788e/engram_demo_v1.py
+Reference implementation: https://github.com/deepseek-ai/Engram/blob/fb7f84a21f91223715394a33a1dc24bbfb7f788e/engram_demo_v1.py
 s
+
 To run the test
 pip install torch numpy transformers sympy
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
@@ -54,6 +55,7 @@
 from MaxText.layers.engram import Engram as EngramJAX
 from MaxText.layers.engram import ShortConv as ShortConvJAX
 from MaxText.layers.engram import MultiHeadEmbedding as MultiHeadEmbeddingJAX
+from MaxText.layers.engram import NgramHashMapping as NgramHashMappingJAX
 
 
 # -----------------------------------------------------------------------------
@@ -605,15 +607,6 @@ def to_jax_norm(pt_norm):
 """Extracts scale parameter from a norm layer."""
 return {"scale": to_jax(pt_norm.weight)}
 
-
-def to_jax_linear(pt_linear):
-"""(Out, In) -> {'kernel': (In, Out), 'bias': (Out)}"""
-out = {"kernel": to_jax(pt_linear.weight.T)}
-if pt_linear.bias is not None:
-out["bias"] = to_jax(pt_linear.bias)
-return out
-
-
 def to_jax_vmap(pt_module_list, transform_fn):
 """
 Applies transform_fn to a list of modules and stacks the
@@ -628,12 +621,10 @@ def to_jax_shortconv(pt_layer):
 """
 Converts a ShortConv layer containing a Conv and a ModuleList of Norms.
 """
-# 1. Conv Weights. PyTorch: (Out, In//Groups, Kernel) -> JAX: (Kernel, In//Groups, Out)
-conv_kernel = pt_layer.conv.weight.permute(2, 1, 0)
-
 return {
-"conv": {"kernel": to_jax(conv_kernel)},
-# 2. Weights for the Norms: List[Norm] -> {'scale': (Groups, Channels)}
+# (Out, In//Groups, Kernel) -> (Kernel, In//Groups, Out)
+"conv": {"kernel": to_jax(pt_layer.conv.weight.permute(2, 1, 0))},
+# List[Norm] -> Stacked norm (Groups, Channels)
 "norms": to_jax_vmap(pt_layer.norms, to_jax_norm),
 }
 
@@ -644,6 +635,7 @@ def setUp(self):
 super().setUp()
 torch.manual_seed(42)
 np.random.seed(42)
+self.nnx_rngs = nnx.Rngs(params=0)
 
 @parameterized.named_parameters(
 # {"testcase_name": "base", "hidden_size": 32, "hc_mult": 4, "kernel_size": 4, "dilation": 1},
@@ -666,10 +658,11 @@ def test_shortconv_match(self, hidden_size, hc_mult, kernel_size, dilation):
 pt_model.eval()
 
 # 2. Init JAX
-rngs = nnx.Rngs(params=0)
 config = Config()
 cfg, mesh = get_cfg_and_mesh(config)
-jax_model = ShortConvJAX(cfg, hidden_size, kernel_size, dilation, hc_mult=hc_mult, activation=activation, rngs=rngs)
+jax_model = ShortConvJAX(
+cfg, hidden_size, kernel_size, dilation, hc_mult=hc_mult, activation=activation, rngs=self.nnx_rngs
+)
 print(jax_model)
 
 # 3. Transfer Weights
@@ -705,6 +698,14 @@ def test_shortconv_match(self, hidden_size, hc_mult, kernel_size, dilation):
 # -----------------------------------------------------------------------------
 
 
+def to_jax_linear(pt_linear):
+"""(Out, In) -> {'kernel': (In, Out), 'bias': (Out)}"""
+out = {"kernel": to_jax(pt_linear.weight.T)}
+if pt_linear.bias is not None:
+out["bias"] = to_jax(pt_linear.bias)
+return out
+
+
 def to_jax_engram(pt_engram) -> dict:
 return {
 "mhe": to_jax_mhe(pt_engram.multi_head_embedding),
@@ -728,6 +729,7 @@ def setUp(self):
 torch.set_default_dtype(torch.float32)
 torch.manual_seed(42)
 np.random.seed(42)
+self.nnx_rng = nnx.Rngs(params=0)
 
 self.batch_size = 2
 self.seq_len = 8
@@ -737,68 +739,72 @@ def setUp(self):
 self.engram_cfg = EngramConfig(self.config)
 self.backbone_config = BackBoneConfig(self.config)
 
-self.nnx_rng = nnx.Rngs(params=0)
-
 @parameterized.named_parameters(
 {"testcase_name": "standard_run", "batch_size": 2, "seq_len": 16},
 )
 def test_engram_match(self, batch_size, seq_len):
-# 1. Setup PyTorch Reference
-
+# 1. torch
 EngramPT = Engram
 pt_layer = EngramPT(layer_id=self.layer_id, backbone_config=self.backbone_config, engram_cfg=self.engram_cfg)
 init_torch_weights(pt_layer)
 pt_layer.eval()
 
+# Prepare Inputs
+# Create random input_ids and hidden_states
+input_ids_np = np.random.randint(0, 1000, (batch_size, seq_len))
+pt_input_ids = torch.from_numpy(input_ids_np)
+# (B, L, G, D)
+pt_hidden_states = torch.randn(
+batch_size, seq_len, self.backbone_config.hc_mult, self.backbone_config.hidden_size, dtype=torch.float32
+)
+
+# Run Inference
+with torch.no_grad():
+pt_out = pt_layer(pt_hidden_states, pt_input_ids, self.backbone_config)
+
+# 2 Jax
 # "deepseek-ai/DeepSeek-V3"
 tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True)
 
-# 2. Setup JAX NNX Implementation
-config = Config()
-cfg, mesh = get_cfg_and_mesh(config)
+jax_hash_mapping = NgramHashMappingJAX(
+engram_vocab_size=self.config.engram_vocab_size,
+max_ngram_size=self.config.engram_max_ngram_size,
+n_embed_per_ngram=self.config.engram_embed_dim_per_ngram,
+n_head_per_ngram=self.config.engram_heads_per_ngram,
+# IMPORTANT: We must pass the FULL list of layer_ids
+# The mapping finds primes sequentially across all layers
+layer_ids=self.config.engram_layer_ids,
+tokenizer=tokenizer,
+pad_id=self.config.engram_pad_id,
+seed=self.config.engram_seed,
+)
+
+vocab_sizes = jax_hash_mapping.get_vocab_sizes(self.layer_id)
+
+# Setup model
+cfg, mesh = get_cfg_and_mesh(self.config)
 jax_layer = EngramJAX(
-layer_id=self.layer_id,
 rngs=self.nnx_rng,
 config=cfg,
 mesh=mesh,
-tokenizer=tokenizer,
+vocab_sizes=vocab_sizes,
 hc_mult=self.config.hc_mult,
 engram_heads_per_ngram=self.config.engram_heads_per_ngram,
 engram_embed_dim_per_ngram=self.config.engram_embed_dim_per_ngram,
 engram_max_ngram_size=self.config.engram_max_ngram_size,
 engram_kernel_size=self.config.engram_kernel_size,
-engram_vocab_size=self.config.engram_vocab_size,
-layer_ids=self.config.engram_layer_ids,
-pad_id=self.config.engram_pad_id,
-seed=self.config.engram_seed,
 )
 
-print("torch_layer", pt_layer.state_dict())
-print("jax_layer", jax_layer)
-
-# 3. Synchronize Weights
+# Synchronize Weights
 jax_weights = to_jax_engram(pt_layer)
 nnx.update(jax_layer, jax_weights)
 
-# 4. Prepare Inputs
-# Create random input_ids and hidden_states
-input_ids_np = np.random.randint(0, 1000, (batch_size, seq_len))
-
-pt_input_ids = torch.from_numpy(input_ids_np)
-
-# (B, L, G, D)
-pt_hidden_states = torch.randn(
-batch_size, seq_len, self.backbone_config.hc_mult, self.backbone_config.hidden_size, dtype=torch.float32
-)
+jax_hash_input_ids = jax_hash_mapping.hash(input_ids_np)[self.layer_id]
 jax_hidden_states = to_jax(pt_hidden_states)
 
-# 5. Run Inference
-with torch.no_grad():
-pt_out = pt_layer(pt_hidden_states, pt_input_ids, self.backbone_config)
-
-jax_out = jax_layer(jax_hidden_states, to_jax(pt_input_ids))
+jax_out = jax_layer(jax_hidden_states, jax_hash_input_ids)
 
-# 6. Numerical Comparison
+# 3 Compare
 print(f"\nPT Output Mean: {pt_out.mean().item():.6f}")
 print(f"JAX Output Mean: {jax_out.mean():.6f}")
 
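A note on the "Vocabulary Size Calculation (Global Prime Sequence)" comments that this commit moves around: every head of every Engram layer gets its own prime-sized hash table, and the primes are drawn from one sequence that runs across all layers, which is why NgramHashMapping must always see the full layer_ids list. The snippet below only illustrates that idea with sympy.isprime (already imported in engram.py); it is not the NgramHashMapping implementation, and base_vocab_size, n_layers, and heads_per_layer are made-up values.

# Illustration only: one way to generate a global sequence of distinct prime
# table sizes across layers and heads. NOT the NgramHashMapping internals.
from sympy import isprime


def next_prime_at_least(n: int) -> int:
  """Smallest prime >= n."""
  while not isprime(n):
    n += 1
  return n


def prime_vocab_sizes(base_vocab_size: int, n_layers: int, heads_per_layer: int):
  """One unique prime per (layer, head); the scan continues across layers so no
  two heads anywhere share a table size and hash collisions stay independent."""
  sizes, candidate = [], base_vocab_size
  for _ in range(n_layers):
    layer_sizes = []
    for _ in range(heads_per_layer):
      p = next_prime_at_least(candidate)
      layer_sizes.append(p)
      candidate = p + 1  # never reuse a prime
    sizes.append(layer_sizes)
  return sizes


print(prime_vocab_sizes(base_vocab_size=1000, n_layers=2, heads_per_layer=3))
# e.g. [[1009, 1013, 1019], [1021, 1031, 1033]]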