Skip to content

Add DeepSeek Engram layer#3010

Open
shuningjin wants to merge 1 commit intomainfrom
shuningjin-engram
Open

Add DeepSeek Engram layer#3010
shuningjin wants to merge 1 commit intomainfrom
shuningjin-engram

Conversation

@shuningjin
Copy link
Collaborator

@shuningjin shuningjin commented Jan 26, 2026

Description

Background

What this PR does

Add Engram layer: engram.py

  • compressed tokenizer (non-parametric)
  • n-gram hash mapping (non-parametric)
  • multi-head embedding
  • short convolution (multi-branch)
  • engram (multi-branch)

Add unit test: tests.unit.engram_vs_reference_test

  • verify each module match output of reference code

Implementation Notes

Relationship of components

  • n-gram hash mapping: encompass compressed tokenizer.
  • Engram: encompass multi-head embedding, short convolution
  • n-gram hash mapping and Engram
    • n-gram hash mapping converts vanilla token-ids to hashed ngram token-ids, which Engram consumes for embedding lookup
    • Future: n-gram hash mapping will need to be inserted in Data Input Pipeline in future integration

Multi-branch

  • engram and shortconv handles multi-branch input and multi-branch output (if hc_mult > 1), optimized with nnx.vmap
  • Future: to be integrated into multi-branch backbone like mHC.

Tests

unit test against reference

python3 -m pytest -v --pyargs tests.unit.engram_vs_reference_test -rP -s

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link

codecov bot commented Jan 26, 2026

Codecov Report

❌ Patch coverage is 0% with 209 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/MaxText/layers/engram.py 0.00% 209 Missing ⚠️

📢 Thoughts on this report? Let us know!

@shuningjin shuningjin changed the title [DRAFT] do no merge [DRAFT] engram Jan 29, 2026
@shuningjin shuningjin force-pushed the shuningjin-engram branch 2 times, most recently from 93458cf to 21cec5f Compare January 30, 2026 17:52
@shuningjin shuningjin changed the title [DRAFT] engram Add DeepSeek Engram layer Feb 4, 2026
@shuningjin shuningjin marked this pull request as ready for review February 4, 2026 21:48
@shuningjin
Copy link
Collaborator Author

@gemini-cli /review

Copy link
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gemini-cli /review

@github-actions
Copy link

github-actions bot commented Feb 5, 2026

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This pull request introduces a JAX implementation of the DeepSeek Engram layer, along with comprehensive unit tests that validate its behavior against a PyTorch reference. The code is well-structured and the implementation appears to be correct and thorough. The core logic is sound, and the use of vectorization with nnx.vmap is a good practice for performance.

🔍 General Feedback

  • Good Testing: The inclusion of unit tests comparing the JAX implementation to a PyTorch reference is excellent. This provides high confidence in the correctness of the implementation.
  • Clear Implementation: The code in engram.py is well-commented and organized, making it easy to follow the logic from the original paper.
  • TODOs: I've commented on the TODOs left in the code. Addressing them will improve the clarity and robustness of the implementation.

# Structure: {layer_id: [[2gram_head1, ..., 2gram_headH], ..., [Ngram_head1, ..., Ngram_headH]]}
self.vocab_size_across_layers = self._calculate_vocab_size_across_layers()

def _calculate_multipliers_across_layers(self, seed: int):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Your comment here is valid and raises a good question. Using `tokenizer.pad_id` directly would be more robust and less prone to configuration errors. If there's a specific reason to pass `pad_id` separately and then look it up in the `lookup_table`, it would be beneficial to document that reasoning here. Otherwise, I'd recommend simplifying this to use the tokenizer's padding ID directly.

kernel_axes=("engram_dim", "embed"),
dtype=self.dtype,
weight_dtype=self.weight_dtype,
quant=self.quant,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 You've correctly identified a point of potential confusion regarding the logical axes for sharding. For clarity and to ensure correct behavior in a distributed environment, it would be best to confirm the intended logical axis names. This will help maintainers and future contributors understand the sharding strategy.

# Value Projection (shared): Retrieved memory -> Value
self.value_proj = DenseGeneral(
in_features_shape=self.engram_dim,
out_features_shape=config.base_emb_dim,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 This is another good catch regarding the logical axis names. Explicitly defining these based on your sharding plan will improve the code's readability and prevent potential issues with model parallelism. Please verify the correct logical axis names to be used here.

@github-actions
Copy link

github-actions bot commented Feb 5, 2026

🤖 I'm sorry @RissyRan, but I was unable to process your request. Please see the logs for more details.

Copy link
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reviewed the test and CompressedTokenizer. Will continue to review the rest part tomorrow.


"""
DeepSeek-AI, `Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models
<https://arxiv.org/pdf/2601.07372>`_, 2026
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this extra "_" on purpose?

"""

def __init__(self, tokenizer: HFTokenizer):
# TODO(shuningjin): maybe don't need to hold tokenizer, if we only use the lookup table as bridge
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the consequence if we remove it here?

engram_head_dim: int = 32
engram_num_heads: int = 8 # num heads per n-gram
# Hashing
engram_pad_id: int = 2 # TODO(shuningjin): not the same as tokenizer.pad_id?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be defined by users?

Copy link
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change! I left some initial comments, and may need to go over multihead embedding and conv parts. It should be quick.

n-grams into fixed integer IDs. To handle the large combinatorial space, it uses:
1. Unique Prime Vocabularies: Per-head prime moduli to minimize collision overlap.
2. Sliding Window: Efficient shifting to generate n-gram views.
3. Lightweight Hashing: A multiplicative-XOR function (Rabin-Karp variant).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I may miss this part in reference implementation. Did you add this optimization?

A dictionary mapping layer_id to a list of `max_ngram_size` multipliers.
"""
# Pre-calculate bounds for random generation
max_long = np.iinfo(np.int64).max
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you cross check if we could update all np to jnp?

LAYER_PRIME_OFFSET = 10007

layer_multipliers = {}
for layer_id in self.layer_ids:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we could update this block using vectorized operation? dim will depends on len(layer_ids). It's fixed at compile time.

quant: Optional[Quant] = None,
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
*,
hc_mult: int = 4,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we put params with default value at the very end?

axis=-1,
kernel_init=self.kernel_init,
# TODO(shuningjin): this needs to be actual logical axis? @reviewer
kernel_axes=("engram_dim", "embed"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could add the sharding constraint into base.yml.

logical_axis_rules: [

I see it is smaller dim compared to emb, we could shard on tensor as a starting point. I see embed usually sharding on fsdp, sequence, context etc.

Shape annotation:
B: Batch Size
S: Sequence Length
G: hc_mult, Number of Branches
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you plan to separate this config or treat it same as mhc_expansion_rate?

# Norms (vectorized)
# Independent weights per branch, Branched input
@nnx.split_rngs(splits=hc_mult)
@nnx.vmap(in_axes=0, out_axes=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sharding on batch dimension? Why is that? Similar comment for other in_axes=0 vmap op.

# Vectorized broadcast: apply each of the G key_projs to the SAME embeddings.
# in_axes: (0, None) -> 0 splits the Dense layers, None broadcasts embeddings
# out_axes: 2 -> Stack the results at axis 2 to get (B, S, G, D)
@nnx.vmap(in_axes=(0, None), out_axes=2)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if this in_axes working properly? I see your unit test has b=2 setup. When I integrated flash attn with sparse attn, I have to change the unit test to from b=2 to b=4 when sharding on fsdp, otherwise, it will fail on v5p-8 local machine.

max_ngram_size: int,
engram_num_heads: int,
layer_ids: List[int],
tokenizer: HFTokenizer,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you saying would like to put the look up table into data pipeline. Is this structure or performance beneficial? When we call the engram from decoder layer, we need to pass this tokenizer. So you are thinking, this engram module will call/depend on data pipeline to look up?

self.backbone_config = BackBoneConfig(self.config)
tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True)
# input
batch, seq_len = 2, 3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we set up a longer sequence, like 8, so test overlap of 2/3-grams?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants