Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
This Pull Request introduces support for dumping Jaxpr (JAX intermediate representation) during the training and compilation processes. This functionality allows developers to inspect and verify the JAX traces of the training step, which is useful for debugging and ensuring consistency between different execution modes.
Key Changes
1. Configuration Options
New configuration flags have been added to
src/MaxText/configs/base.ymland validated insrc/MaxText/configs/types.pyto control the Jaxpr dumping behavior:dump_jaxpr: Enables or disables Jaxpr dumping.dump_jaxpr_local_dir: Specifies the local directory where the.jaxprfiles are initially saved.dump_jaxpr_gcs_dir: An optional GCS directory for uploading the dumps.dump_jaxpr_delete_local_after: Determines whether to delete local copies after a successful GCS upload.2. Implementation of Jaxpr Dumping
A new utility function
maybe_dump_jaxprhas been implemented insrc/MaxText/maxtext_utils.py. This function:jax.make_jaxprto trace the jitted training step.train_step.jaxpr.gcs_utils.upload_dump.3. Integration into Training and Compilation
maybe_dump_jaxpris called within the main training loop insrc/MaxText/train.py.src/MaxText/train_compile.pyto capture jaxpr during the ahead-of-time compilation process.4. Test Refactoring and New Jaxpr Verification
The existing HLO identity tests in
tests/integration/aot_hlo_identical_test.pyhave been replaced by a more comprehensive test suite intests/integration/aot_identical_test.py. This new file includes:AotHloIdenticalTest: Updated verification for HLO graphs.AotJaxprIdenticalTest: A new test class that ensures the jaxpr generated during AOT compilation matches the jaxpr from a real training run.Tests
https://paste.googleplex.com/5644648066449408
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.