Conversation
Force-pushed from e8ae3c9 to f095801
Force-pushed from 93458cf to 21cec5f
Force-pushed from bb190ed to 2dc37df
@gemini-cli /review
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This pull request introduces a JAX implementation of the DeepSeek Engram layer, along with comprehensive unit tests that validate its behavior against a PyTorch reference. The code is well-structured and the implementation appears to be correct and thorough. The core logic is sound, and the use of vectorization with nnx.vmap is a good practice for performance.
🔍 General Feedback
- Good Testing: The inclusion of unit tests comparing the JAX implementation to a PyTorch reference is excellent. This provides high confidence in the correctness of the implementation.
- Clear Implementation: The code in `engram.py` is well-commented and organized, making it easy to follow the logic from the original paper.
- TODOs: I've commented on the TODOs left in the code. Addressing them will improve the clarity and robustness of the implementation.
# Structure: {layer_id: [[2gram_head1, ..., 2gram_headH], ..., [Ngram_head1, ..., Ngram_headH]]}
self.vocab_size_across_layers = self._calculate_vocab_size_across_layers()

def _calculate_multipliers_across_layers(self, seed: int):
    kernel_axes=("engram_dim", "embed"),
    dtype=self.dtype,
    weight_dtype=self.weight_dtype,
    quant=self.quant,
# Value Projection (shared): Retrieved memory -> Value
self.value_proj = DenseGeneral(
    in_features_shape=self.engram_dim,
    out_features_shape=config.base_emb_dim,
🤖 I'm sorry @RissyRan, but I was unable to process your request. Please see the logs for more details.
Force-pushed from 2dc37df to 5371cae
RissyRan left a comment
I reviewed the test and CompressedTokenizer. I will continue reviewing the rest tomorrow.
| """ | ||
| DeepSeek-AI, `Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models | ||
| <https://arxiv.org/pdf/2601.07372>`_, 2026 |
is this extra "_" on purpose?
| """ | ||
|
|
||
| def __init__(self, tokenizer: HFTokenizer): | ||
| # TODO(shuningjin): maybe don't need to hold tokenizer, if we only use the lookup table as bridge |
What's the consequence if we remove it here?
engram_head_dim: int = 32
engram_num_heads: int = 8  # num heads per n-gram
# Hashing
engram_pad_id: int = 2  # TODO(shuningjin): not the same as tokenizer.pad_id?
Does this need to be defined by users?
RissyRan
left a comment
There was a problem hiding this comment.
Thanks for the change! I left some initial comments, and I may still need to go over the multi-head embedding and conv parts. It should be quick.
n-grams into fixed integer IDs. To handle the large combinatorial space, it uses:
1. Unique Prime Vocabularies: Per-head prime moduli to minimize collision overlap.
2. Sliding Window: Efficient shifting to generate n-gram views.
3. Lightweight Hashing: A multiplicative-XOR function (Rabin-Karp variant).
Nice! I may have missed this part in the reference implementation. Did you add this optimization?
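For readers following this thread, here is a minimal, self-contained sketch of the kind of multiplicative-XOR (Rabin-Karp-style) n-gram hashing the quoted docstring describes; the function name, padding convention, and exact mixing order are illustrative assumptions, not the PR's actual code.

```python
import numpy as np

def hash_ngrams(token_ids: np.ndarray, n: int, multipliers: np.ndarray, prime: int, pad_id: int = 2) -> np.ndarray:
  """token_ids: (batch, seq) ints. Returns hashed n-gram ids in [0, prime)."""
  _, seq = token_ids.shape
  # Sliding window: left-pad so every position has an n-gram ending at it.
  padded = np.pad(token_ids, ((0, 0), (n - 1, 0)), constant_values=pad_id)
  # (batch, seq, n): the n token ids ending at each position.
  windows = np.stack([padded[:, i : i + seq] for i in range(n)], axis=-1)
  # Multiplicative-XOR mix: scale each window position by its own multiplier,
  # XOR-reduce across the n-gram, then map into this head's prime vocabulary.
  mixed = windows.astype(np.int64) * multipliers[:n]
  hashed = mixed[..., 0]
  for k in range(1, n):
    hashed = np.bitwise_xor(hashed, mixed[..., k])
  return hashed % prime  # non-negative because the modulus is positive
```

Using a different prime modulus per head (point 1 above) then means two n-grams that collide in one head are unlikely to collide in every head.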
A dictionary mapping layer_id to a list of `max_ngram_size` multipliers.
"""
# Pre-calculate bounds for random generation
max_long = np.iinfo(np.int64).max
Could you cross-check whether we could update all np to jnp?
LAYER_PRIME_OFFSET = 10007

layer_multipliers = {}
for layer_id in self.layer_ids:
Do you think we could update this block using a vectorized operation? The dim will depend on len(layer_ids), which is fixed at compile time.
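To make the suggestion concrete, here is a rough sketch of what a batched version could look like, assuming the intent is a single (num_layers, max_ngram_size) draw; the ranges, the odd-multiplier choice, and the helper name are assumptions rather than the PR's logic.

```python
import numpy as np

def multipliers_across_layers(layer_ids, max_ngram_size, seed):
  rng = np.random.default_rng(seed)
  # One batched draw instead of a Python loop; len(layer_ids) is fixed at compile time.
  draws = rng.integers(
      low=1,
      high=np.iinfo(np.int64).max // 2,
      size=(len(layer_ids), max_ngram_size),
      dtype=np.int64,
  )
  multipliers = draws * 2 + 1  # keep multipliers odd, a common choice for multiplicative hashing
  return {layer_id: multipliers[i] for i, layer_id in enumerate(layer_ids)}
```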
    quant: Optional[Quant] = None,
    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"),
    *,
    hc_mult: int = 4,
Shall we put params with default values at the very end?
    axis=-1,
    kernel_init=self.kernel_init,
    # TODO(shuningjin): this needs to be actual logical axis? @reviewer
    kernel_axes=("engram_dim", "embed"),
You could add the sharding constraint to base.yml (maxtext/src/MaxText/configs/base.yml, line 402 at 352dd58).
I see it is a smaller dim compared to emb; we could shard it on tensor as a starting point. I see embed is usually sharded on fsdp, sequence, context, etc.
Shape annotation:
B: Batch Size
S: Sequence Length
G: hc_mult, Number of Branches
Do you plan to separate this config or treat it the same as mhc_expansion_rate?
# Norms (vectorized)
# Independent weights per branch, Branched input
@nnx.split_rngs(splits=hc_mult)
@nnx.vmap(in_axes=0, out_axes=0)
Are you sharding on the batch dimension? Why is that? Similar comment for the other in_axes=0 vmap ops.
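For context, a small self-contained flax.nnx sketch of the quoted pattern with toy dimensions (module choice and shapes are made up, not the PR's code): in this layout, `in_axes=0` maps over the leading branch axis `G` of the stacked weights and of the branched input rather than over the batch axis.

```python
import jax.numpy as jnp
from flax import nnx

G, B, S, D = 4, 2, 6, 8  # branches (hc_mult), batch, seq, feature

@nnx.split_rngs(splits=G)
@nnx.vmap(in_axes=0, out_axes=0)
def make_norms(rngs: nnx.Rngs) -> nnx.LayerNorm:
  # Independent weights per branch: parameters carry a leading G axis.
  return nnx.LayerNorm(D, rngs=rngs)

norms = make_norms(nnx.Rngs(0))

@nnx.vmap(in_axes=0, out_axes=0)
def apply_norms(norm: nnx.LayerNorm, x: jnp.ndarray) -> jnp.ndarray:
  # Axis 0 here is the branch axis G; each branch's norm sees its own (B, S, D) slice.
  return norm(x)

y = apply_norms(norms, jnp.ones((G, B, S, D)))  # -> (G, B, S, D)
```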
# Vectorized broadcast: apply each of the G key_projs to the SAME embeddings.
# in_axes: (0, None) -> 0 splits the Dense layers, None broadcasts embeddings
# out_axes: 2 -> Stack the results at axis 2 to get (B, S, G, D)
@nnx.vmap(in_axes=(0, None), out_axes=2)
Do you know if this in_axes is working properly? I see your unit test has a b=2 setup. When I integrated flash attention with sparse attention, I had to change the unit test from b=2 to b=4 when sharding on fsdp; otherwise, it would fail on a v5p-8 local machine.
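As a quick illustration of the `in_axes=(0, None), out_axes=2` semantics in the quoted snippet (plain `jax.vmap` with toy shapes, not the PR's test): axis 0 of the stacked kernels is mapped, the shared input is broadcast, and the mapped branch axis is inserted at position 2 of the result.

```python
import jax
import jax.numpy as jnp

G, B, S, D_in, D_out = 4, 2, 6, 16, 8
kernels = jnp.ones((G, D_in, D_out))  # one projection per branch, stacked on axis 0
x = jnp.ones((B, S, D_in))            # shared embeddings

# in_axes=(0, None): map over the G kernels, broadcast x unchanged.
# out_axes=2: place the mapped G axis at position 2 of the (B, S, D_out) result.
project = jax.vmap(lambda k, e: e @ k, in_axes=(0, None), out_axes=2)
y = project(kernels, x)
assert y.shape == (B, S, G, D_out)
```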
    max_ngram_size: int,
    engram_num_heads: int,
    layer_ids: List[int],
    tokenizer: HFTokenizer,
You mentioned you would like to put the lookup table into the data pipeline. Is this beneficial for structure or performance? When we call the engram from the decoder layer, we need to pass this tokenizer. So are you thinking this engram module will call/depend on the data pipeline for the lookup?
self.backbone_config = BackBoneConfig(self.config)
tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_path, trust_remote_code=True)
# input
batch, seq_len = 2, 3
Could we set up a longer sequence, like 8, to test the overlap of 2-/3-grams?
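Illustration of the point above (made-up data, not the test's setup): with a longer sequence, adjacent n-gram windows actually overlap, which seq_len = 3 cannot exercise for 3-grams.

```python
import numpy as np

batch, seq_len = 2, 8
token_ids = np.arange(batch * seq_len).reshape(batch, seq_len)
# Consecutive 3-gram windows share two tokens.
windows = [token_ids[:, i : i + 3] for i in range(seq_len - 2)]
assert (windows[0][:, 1:] == windows[1][:, :-1]).all()
```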
Description
Background
What this PR does
Add Engram layer: `engram.py`
Add unit test: `tests.unit.engram_vs_reference_test`
Implementation Notes
Relationship of components
- `n-gram hash mapping`: encompasses the `compressed tokenizer`.
- `Engram`: encompasses `multi-head embedding` and `short convolution`.
- `n-gram hash mapping` and `Engram`: `n-gram hash mapping` converts vanilla token-ids to hashed n-gram token-ids, which `Engram` consumes for embedding lookup.
- `n-gram hash mapping` will need to be inserted into the Data Input Pipeline in a future integration.
Multi-branch
- `engram` and `shortconv` handle multi-branch input and multi-branch output (if `hc_mult > 1`), optimized with `nnx.vmap`.
Tests
unit test against reference
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.