feat: migrate pipeline to nnx #2885
mesakhcienet wants to merge 5 commits into AI-Hypercomputer:main
bvandermoon left a comment
@gobbleturk what testing do you recommend for migrating pipeline parallelism to NNX? I'll send over an internal doc that @hsuan-lun-chiang, @mesakhcienet, and others put together, which shows the tests they have already run.
@NuojCheng any thoughts here?
NuojCheng left a comment
Some additional train compile tests for the pipeline NNX migration:
- Train compile test 1: https://paste.googleplex.com/5960957017849856
- Train compile test 2: https://paste.googleplex.com/5749974483730432
- Train compile test 3: https://paste.googleplex.com/5201745681711104
If the train compile tests above pass without hitting OOM and the current tests in pipeline_parallelism_test.py all pass, then I think it is good to go! Please ping me when the PR is ready for review.
There is also some Linen usage in
I don't see it updated in this PR, but I think it probably should be. Another thing is the function usage in maxtext/src/maxtext/utils/pipeline_utils.py (lines 151 to 162 in 77f5334).
Description
Implement an NNX-based pipeline.
This PR extends PR #2831.
Main changes:
- NNXPipeline, an NNX-based pipeline class.

Tests
We ran the pipeline with the command below:
MODEL_NAME=llama2-7b
python -m MaxText.train src/maxtext/configs/base.yml \
  run_name=pipeline_test_${MODEL_NAME}_nnx \
  base_output_directory=/dev/shm/pipeline_test_nnx \
  model_name=${MODEL_NAME} \
  dataset_type=synthetic \
  steps=15 \
  debug_sharding=true \
  per_device_batch_size=2 \
  max_target_length=32 \
  ici_pipeline_parallelism=2 \
  num_pipeline_microbatches=4 \
  num_layers_per_pipeline_stage=2 \
  enable_checkpointing=false \
  enable_nnx=true \
  pure_nnx_decoder=true \
  scan_layers_per_stage=false \
  async_checkpointing=false > nnx-porting-log/pipeline/custom_${MODEL_NAME}.log 2>&1

Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.
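
For readers checking the test command in the Tests section, here is a rough sketch of what its pipeline flags imply. `PipelineShape` is a hypothetical helper, not part of MaxText. It assumes llama2-7b has 32 decoder layers, that layers are split evenly across stages and repeats, and it uses the textbook GPipe bubble estimate, which may not match the schedule this PR actually runs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PipelineShape:
    """Hypothetical helper (not a MaxText API) relating the pipeline flags."""

    num_decoder_layers: int  # assumed 32 for llama2-7b
    num_stages: int          # ici_pipeline_parallelism
    layers_per_stage: int    # num_layers_per_pipeline_stage
    num_microbatches: int    # num_pipeline_microbatches

    @property
    def num_repeats(self) -> int:
        # Assumption: each pass through the pipeline covers
        # num_stages * layers_per_stage layers, so the full decoder
        # needs this many passes.
        layers_per_pass = self.num_stages * self.layers_per_stage
        if self.num_decoder_layers % layers_per_pass != 0:
            raise ValueError(
                f"{self.num_decoder_layers} layers do not divide evenly "
                f"into passes of {layers_per_pass} layers"
            )
        return self.num_decoder_layers // layers_per_pass

    def bubble_fraction(self) -> float:
        # Textbook GPipe estimate of idle time: (S - 1) / (M + S - 1).
        # This ignores repeats and any circular scheduling.
        s, m = self.num_stages, self.num_microbatches
        return (s - 1) / (m + s - 1)


# Values taken from the test command: 2 stages, 2 layers per stage,
# 4 microbatches; 32 layers is the llama2-7b assumption noted above.
shape = PipelineShape(
    num_decoder_layers=32,
    num_stages=2,
    layers_per_stage=2,
    num_microbatches=4,
)
print(shape.num_repeats)        # 8
print(shape.bubble_fraction())  # 0.2
```

Under these assumptions, the run exercises the repeated (multi-pass) code path rather than a single pass through the stages, which is worth keeping in mind when comparing the NNX and Linen logs.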