Commit 37ded59
PR #2831: Migrate Decoder to NNX
Imported from GitHub PR #2831
# Description
Migrate the Transformer decoder layer into NNX.
Note: The following models are currently not supported:
- DeepSeek
- Gemma3
- Llama4
Support for these models will be added in a follow-up PR.
Strategy:
A `pure_nnx_decoder` flag is added to control whether NNX or Linen decoder shall be used.
Initial migration doesn't include the pipeline NNX support.
# Tests
Conducted these tests. Details in the [GDoc file](https://docs.google.com/document/d/1NbUP3g5glgbC6bMyt44pwM_vQA1NR7U2rBUzfbTDwSs/edit?pli=1&resourcekey=0-9EUahtzL-hCycdu7l0grhQ&tab=t.htq5367h8au0)
1. Test with different model and compare with Linen training
2. Golden logits comparison
3. Inference
4. Checkpoint comparison (Including TreeStructure Comparison)
5. Sharding comparison
TODOs:
- NNX version of unit tests (future PRs)
# Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in [our documentation](https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files).
Copybara import of the project:
--
073e916 by hsuan-lun-chiang <hsuan-lun.chiang@cienet.com>:
Migrate Decoder to NNX
Adding nnx_decoders.py in parallel with decoders.py
1. Dup and modifiy decoders.py on new file nnx_decoders.py
2. add new config pure_nnx_decoder to control if model will use NNXDecoder, default false for now
3. modify relative code to accomodate the change
4. add/modify unit test
Merging this change closes #2831
COPYBARA_INTEGRATE_REVIEW=#2831 from CIeNET-International:feat/Migrate-Decoder-to-NNX 073e916
PiperOrigin-RevId: 8841709821 parent ca7e2df commit 37ded59
7 files changed
Lines changed: 1746 additions & 53 deletions
File tree
- src/maxtext
- configs
- layers
- models
- tests/unit
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1087 | 1087 | | |
1088 | 1088 | | |
1089 | 1089 | | |
| 1090 | + | |
1090 | 1091 | | |
1091 | 1092 | | |
1092 | 1093 | | |
| |||
1152 | 1153 | | |
1153 | 1154 | | |
1154 | 1155 | | |
1155 | | - | |
| 1156 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
783 | 783 | | |
784 | 784 | | |
785 | 785 | | |
| 786 | + | |
786 | 787 | | |
787 | 788 | | |
788 | 789 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| 25 | + | |
25 | 26 | | |
26 | | - | |
27 | 27 | | |
28 | 28 | | |
29 | 29 | | |
| |||
70 | 70 | | |
71 | 71 | | |
72 | 72 | | |
73 | | - | |
| 73 | + | |
74 | 74 | | |
75 | 75 | | |
76 | 76 | | |
| |||
108 | 108 | | |
109 | 109 | | |
110 | 110 | | |
111 | | - | |
| 111 | + | |
112 | 112 | | |
113 | 113 | | |
114 | 114 | | |
115 | 115 | | |
116 | | - | |
117 | | - | |
118 | | - | |
119 | | - | |
120 | | - | |
121 | | - | |
122 | | - | |
123 | | - | |
124 | | - | |
125 | | - | |
126 | | - | |
| 116 | + | |
127 | 117 | | |
128 | 118 | | |
129 | 119 | | |
| |||
212 | 202 | | |
213 | 203 | | |
214 | 204 | | |
215 | | - | |
| 205 | + | |
216 | 206 | | |
217 | 207 | | |
218 | 208 | | |
| |||
0 commit comments