1616"""
1717Tests for Engram: MultiHeadEmbedding, ShortConv, Engram
1818
19- reference : https://github.com/deepseek-ai/Engram/blob/fb7f84a21f91223715394a33a1dc24bbfb7f788e/engram_demo_v1.py
19+ Reference implementation : https://github.com/deepseek-ai/Engram/blob/fb7f84a21f91223715394a33a1dc24bbfb7f788e/engram_demo_v1.py
2020s
21+
2122To run the test
2223 pip install torch numpy transformers sympy
2324 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 from MaxText.layers.engram import Engram as EngramJAX
 from MaxText.layers.engram import ShortConv as ShortConvJAX
 from MaxText.layers.engram import MultiHeadEmbedding as MultiHeadEmbeddingJAX
+from MaxText.layers.engram import NgramHashMapping as NgramHashMappingJAX


 # -----------------------------------------------------------------------------
@@ -605,15 +607,6 @@ def to_jax_norm(pt_norm):
605607 """Extracts scale parameter from a norm layer."""
606608 return {"scale" : to_jax (pt_norm .weight )}
607609
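+# Note: assumes an RMSNorm-style layer whose PyTorch `weight` plays the role of
+# the Flax/NNX `scale`; a norm carrying a bias would need an extra entry here.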
-
-def to_jax_linear(pt_linear):
-  """(Out, In) -> {'kernel': (In, Out), 'bias': (Out)}"""
-  out = {"kernel": to_jax(pt_linear.weight.T)}
-  if pt_linear.bias is not None:
-    out["bias"] = to_jax(pt_linear.bias)
-  return out
-
-
 def to_jax_vmap(pt_module_list, transform_fn):
   """
   Applies transform_fn to a list of modules and stacks the
@@ -628,12 +621,10 @@ def to_jax_shortconv(pt_layer):
   """
   Converts a ShortConv layer containing a Conv and a ModuleList of Norms.
   """
-  # 1. Conv Weights. PyTorch: (Out, In//Groups, Kernel) -> JAX: (Kernel, In//Groups, Out)
-  conv_kernel = pt_layer.conv.weight.permute(2, 1, 0)
-
   return {
-      "conv": {"kernel": to_jax(conv_kernel)},
-      # 2. Weights for the Norms: List[Norm] -> {'scale': (Groups, Channels)}
+      # (Out, In//Groups, Kernel) -> (Kernel, In//Groups, Out)
+      "conv": {"kernel": to_jax(pt_layer.conv.weight.permute(2, 1, 0))},
+      # List[Norm] -> stacked norm scales (Groups, Channels)
       "norms": to_jax_vmap(pt_layer.norms, to_jax_norm),
   }
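+# Note on the permute above: PyTorch nn.Conv1d stores its weight as
+# (out_channels, in_channels // groups, kernel_size), while the JAX conv expects
+# (kernel_size, in_channels // groups, out_channels). Hypothetical shape check,
+# not part of the test:
+#   w = torch.nn.Conv1d(8, 8, kernel_size=4, groups=8).weight  # (8, 1, 4)
+#   assert tuple(w.permute(2, 1, 0).shape) == (4, 1, 8)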

@@ -644,6 +635,7 @@ def setUp(self):
     super().setUp()
     torch.manual_seed(42)
     np.random.seed(42)
+    self.nnx_rngs = nnx.Rngs(params=0)

   @parameterized.named_parameters(
       # {"testcase_name": "base", "hidden_size": 32, "hc_mult": 4, "kernel_size": 4, "dilation": 1},
@@ -666,10 +658,11 @@ def test_shortconv_match(self, hidden_size, hc_mult, kernel_size, dilation):
     pt_model.eval()

     # 2. Init JAX
-    rngs = nnx.Rngs(params=0)
     config = Config()
     cfg, mesh = get_cfg_and_mesh(config)
-    jax_model = ShortConvJAX(cfg, hidden_size, kernel_size, dilation, hc_mult=hc_mult, activation=activation, rngs=rngs)
+    jax_model = ShortConvJAX(
+        cfg, hidden_size, kernel_size, dilation, hc_mult=hc_mult, activation=activation, rngs=self.nnx_rngs
+    )
     print(jax_model)

     # 3. Transfer Weights
@@ -705,6 +698,14 @@ def test_shortconv_match(self, hidden_size, hc_mult, kernel_size, dilation):
 # -----------------------------------------------------------------------------


+def to_jax_linear(pt_linear):
+  """(Out, In) -> {'kernel': (In, Out), 'bias': (Out)}"""
+  out = {"kernel": to_jax(pt_linear.weight.T)}
+  if pt_linear.bias is not None:
+    out["bias"] = to_jax(pt_linear.bias)
+  return out
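+# Note: PyTorch nn.Linear stores weight as (out_features, in_features); the
+# Flax/NNX kernel convention is (in_features, out_features), hence the `.T`.
+# Hypothetical shape check, not part of the test:
+#   lin = torch.nn.Linear(3, 5)  # lin.weight.shape == (5, 3)
+#   assert to_jax_linear(lin)["kernel"].shape == (3, 5)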
+
+
 def to_jax_engram(pt_engram) -> dict:
   return {
       "mhe": to_jax_mhe(pt_engram.multi_head_embedding),
@@ -728,6 +729,7 @@ def setUp(self):
     torch.set_default_dtype(torch.float32)
     torch.manual_seed(42)
     np.random.seed(42)
+    self.nnx_rng = nnx.Rngs(params=0)

     self.batch_size = 2
     self.seq_len = 8
@@ -737,68 +739,72 @@ def setUp(self):
     self.engram_cfg = EngramConfig(self.config)
     self.backbone_config = BackBoneConfig(self.config)

-    self.nnx_rng = nnx.Rngs(params=0)
-
   @parameterized.named_parameters(
       {"testcase_name": "standard_run", "batch_size": 2, "seq_len": 16},
   )
   def test_engram_match(self, batch_size, seq_len):
-    # 1. Setup PyTorch Reference
-
+    # 1. PyTorch reference
     EngramPT = Engram
     pt_layer = EngramPT(layer_id=self.layer_id, backbone_config=self.backbone_config, engram_cfg=self.engram_cfg)
     init_torch_weights(pt_layer)
     pt_layer.eval()

+    # Prepare inputs: random input_ids and hidden_states
+    input_ids_np = np.random.randint(0, 1000, (batch_size, seq_len))
+    pt_input_ids = torch.from_numpy(input_ids_np)
+    # (B, L, G, D) = (batch, seq_len, hc_mult, hidden_size)
+    pt_hidden_states = torch.randn(
+        batch_size, seq_len, self.backbone_config.hc_mult, self.backbone_config.hidden_size, dtype=torch.float32
+    )
+
+    # Run the PyTorch forward pass
+    with torch.no_grad():
+      pt_out = pt_layer(pt_hidden_states, pt_input_ids, self.backbone_config)
+
+    # 2. JAX
     # "deepseek-ai/DeepSeek-V3"
     tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True)

-    # 2. Setup JAX NNX Implementation
-    config = Config()
-    cfg, mesh = get_cfg_and_mesh(config)
+    jax_hash_mapping = NgramHashMappingJAX(
+        engram_vocab_size=self.config.engram_vocab_size,
+        max_ngram_size=self.config.engram_max_ngram_size,
+        n_embed_per_ngram=self.config.engram_embed_dim_per_ngram,
+        n_head_per_ngram=self.config.engram_heads_per_ngram,
+        # IMPORTANT: pass the FULL list of layer_ids; the mapping assigns primes
+        # sequentially across all layers.
+        layer_ids=self.config.engram_layer_ids,
+        tokenizer=tokenizer,
+        pad_id=self.config.engram_pad_id,
+        seed=self.config.engram_seed,
+    )
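+    # Why the full list matters (hypothetical illustration): with layer_ids=[2, 5],
+    # primes are consumed in order, so the moduli chosen for layer 5 depend on
+    # layer 2 having been processed first; hashing with layer_ids=[5] alone would
+    # yield different vocab sizes and hash values.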
+
+    vocab_sizes = jax_hash_mapping.get_vocab_sizes(self.layer_id)
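+    # Presumably the per-head hash vocab sizes for this layer; passing them in
+    # lets EngramJAX size its embedding tables without the tokenizer or mapping.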
+
+    # Set up the JAX model
+    cfg, mesh = get_cfg_and_mesh(self.config)
     jax_layer = EngramJAX(
-        layer_id=self.layer_id,
         rngs=self.nnx_rng,
         config=cfg,
         mesh=mesh,
-        tokenizer=tokenizer,
+        vocab_sizes=vocab_sizes,
         hc_mult=self.config.hc_mult,
         engram_heads_per_ngram=self.config.engram_heads_per_ngram,
         engram_embed_dim_per_ngram=self.config.engram_embed_dim_per_ngram,
         engram_max_ngram_size=self.config.engram_max_ngram_size,
         engram_kernel_size=self.config.engram_kernel_size,
-        engram_vocab_size=self.config.engram_vocab_size,
-        layer_ids=self.config.engram_layer_ids,
-        pad_id=self.config.engram_pad_id,
-        seed=self.config.engram_seed,
     )

-    print("torch_layer", pt_layer.state_dict())
-    print("jax_layer", jax_layer)
-
-    # 3. Synchronize Weights
+    # Synchronize weights
     jax_weights = to_jax_engram(pt_layer)
     nnx.update(jax_layer, jax_weights)
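+    # nnx.update merges the nested weight dict into the module's state, so the
+    # JAX layer now carries the PyTorch-initialized parameters exactly.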

-    # 4. Prepare Inputs
-    # Create random input_ids and hidden_states
-    input_ids_np = np.random.randint(0, 1000, (batch_size, seq_len))
-
-    pt_input_ids = torch.from_numpy(input_ids_np)
-
-    # (B, L, G, D)
-    pt_hidden_states = torch.randn(
-        batch_size, seq_len, self.backbone_config.hc_mult, self.backbone_config.hidden_size, dtype=torch.float32
-    )
+    jax_hash_input_ids = jax_hash_mapping.hash(input_ids_np)[self.layer_id]
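+    # hash() appears to return per-layer n-gram hash ids; indexing with
+    # self.layer_id selects the ids this Engram layer consumes.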
     jax_hidden_states = to_jax(pt_hidden_states)

-    # 5. Run Inference
-    with torch.no_grad():
-      pt_out = pt_layer(pt_hidden_states, pt_input_ids, self.backbone_config)
-
-    jax_out = jax_layer(jax_hidden_states, to_jax(pt_input_ids))
+    jax_out = jax_layer(jax_hidden_states, jax_hash_input_ids)

-    # 6. Numerical Comparison
+    # 3. Numerical comparison
     print(f"\nPT Output Mean: {pt_out.mean().item():.6f}")
     print(f"JAX Output Mean: {jax_out.mean():.6f}")
