Conversation
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request successfully introduces fixes for resuming distillation training from checkpoints by managing the _train_steps and _restored_custom_metadata appropriately using maybe_restore. It also reorganizes initialization steps for modularity and handles optimizer_state explicitly in the save call.
🔍 General Feedback
- The refactoring into `build_training_components` and the use of the `nn_partitioning.axis_rules` wrapper around initialization clean up the setup significantly.
- A solid set of unit tests was added to ensure end-to-end functionality, including checking whether the optimizer state was loaded correctly.
- Added an inline suggestion for a type handler during array restoration to fix potential issues with primitive types.
- Pointed out an opportunity to conditionally compute `grad_norm` to enhance readability and ensure unnecessary XLA graphs are not added when unneeded.
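The conditional `grad_norm` point can be sketched without JAX; the `train_step` shape and metrics dict below are hypothetical stand-ins, but the pattern (only add the reduction when the caller actually reports it) is what keeps the extra ops out of a jitted graph.

```python
import math

def global_norm(grad_leaves):
    # L2 norm over all gradient leaves (here: flat lists of floats).
    return math.sqrt(sum(v * v for leaf in grad_leaves for v in leaf))

def train_step(grads, *, compute_grad_norm=False):
    # Hypothetical trainer step: the loss value is a placeholder.
    metrics = {"loss": 0.0}
    if compute_grad_norm:
        # Skipping this branch means jax.jit never traces the reduction,
        # so no unnecessary ops land in the compiled graph.
        metrics["grad_norm"] = global_norm(grads)
    return metrics
```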
src/maxtext/trainers/post_train/distillation/distillation_utils.py
Force-pushed from 6bbb69b to ed51fc7.
Force-pushed from ed51fc7 to fb47422.
This PR successfully fixes the checkpoint resumption and optimizer state persistence bugs during distillation training. The implementation cleanly addresses the previously missing norm axis issue by appropriately wrapping the loading logic within the axis_rules context manager, and it effectively extracts standard configurations into build_training_components for enhanced readability.
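The structural fix described above can be illustrated with a stand-in context manager; `axis_rules`, `init_model`, and `restore_checkpoint` below are simplified hypothetical stand-ins for the flax/MaxText pieces, showing only the shape of the fix: both initialization and checkpoint loading must run under the same partitioning rules.

```python
from contextlib import contextmanager

ACTIVE_RULES = []

@contextmanager
def axis_rules(rules):
    # Stand-in for flax's nn_partitioning.axis_rules context manager.
    ACTIVE_RULES.append(rules)
    try:
        yield
    finally:
        ACTIVE_RULES.pop()

def init_model():
    assert ACTIVE_RULES, "model init needs active axis rules"
    return {"params": "..."}

def restore_checkpoint(model):
    # Before the fix, restoration ran outside the context, so sharding
    # rules such as the 'norm' axis were missing and a ValueError was raised.
    assert ACTIVE_RULES, "restore needs the same axis rules as init"
    return model

# The fix: one context wraps both phases.
with axis_rules([("norm", None)]):
    model = init_model()
    model = restore_checkpoint(model)
```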
🔍 General Feedback
- Checkpoint Resumption Logic: The new `maybe_restore` logic handles model bundles properly via `getattr(model, "student_model", model)`.
- Modularity Enhancements: Extracting `build_training_components` logically separates standard Tunix setups from MaxText-specific hardware contexts.
- Defensive Configurations: Added an inline comment regarding `getattr` over `get_with_default` for instances managed externally (e.g., Tunix dataclasses) to prevent potential attribute errors on property access.
```python
    self.optimizer,
    restore_only_lora_params=getattr(self, "_lora_enabled", False),
)
grad_accum_steps = self.config.get_with_default("gradient_accumulation_steps", 1)
```
Suggested change:

```diff
- grad_accum_steps = self.config.get_with_default("gradient_accumulation_steps", 1)
+ grad_accum_steps = getattr(self.config, "gradient_accumulation_steps", 1)
```
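A minimal sketch of why the suggested `getattr` form is more defensive; `ExternalConfig` is a hypothetical stand-in for an externally managed (e.g., Tunix) dataclass that has neither the field nor a `get_with_default` method.

```python
from dataclasses import dataclass

@dataclass
class ExternalConfig:
    # Hypothetical external config: no gradient_accumulation_steps field
    # and no get_with_default method.
    learning_rate: float = 3e-4

cfg = ExternalConfig()

# cfg.get_with_default(...) would raise AttributeError here;
# getattr degrades gracefully to the default value.
grad_accum_steps = getattr(cfg, "gradient_accumulation_steps", 1)
```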
Force-pushed from 4892155 to 7638c8f.
…support state restoration via maybe_restore.
Force-pushed from 7638c8f to 359ac86.
This PR integrates and refactors the distillation components with Tunix and adds checkpoint management capabilities for MaxTextDistillationTrainer. The implementation handles complex orchestration between teacher and student models appropriately and isolates offline vs online distillation setups.
🔍 General Feedback
- Good use of `maybe_restore` and `_shard_optimizer` to handle sharding subtleties specific to JAX when restoring models and optimizer states from checkpoints.
- Passing `trainer.model` to `trainer.checkpoint_manager.save()` correctly mirrors the new ModelBundle abstraction.
- Be extremely cautious with source-code introspection via `inspect.getsource`; while functional as a fallback, it adds latent fragility that is best replaced by explicit version checks in future iterations.
```python
metadata = self._checkpoint_manager.metadata(step)
if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None:
  custom_metadata = metadata.custom_metadata
else:
```
Consider handling the case where metadata is a dictionary directly:
Suggested change:

```python
metadata = self._checkpoint_manager.metadata(step)
if isinstance(metadata, dict):
  custom_metadata = metadata
elif metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None:
  custom_metadata = metadata.custom_metadata
else:
  custom_metadata = {}
```
```python
self._tunix_expects_grad_norm = False
try:
  source = inspect.getsource(peft_trainer.PeftTrainer._train_step)
  self._tunix_expects_grad_norm = "grad_norm" in source
```
Consider relying on tunix.__version__ to check for this API change. If a version check isn't feasible, consider a safer inspection technique like parsing the AST or inspecting the method signature.
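One way the reviewer's suggestion could look; the minimum-version tuple and function names below are assumptions, not the real tunix API, and signature inspection replaces the brittle source-text search as the fallback.

```python
import inspect

# Hypothetical threshold: the tunix release assumed to introduce grad_norm.
GRAD_NORM_MIN_VERSION = (0, 2, 0)

def expects_grad_norm(train_step_fn, version=None):
    if version is not None:
        # Prefer an explicit version gate when tunix.__version__ is available.
        parsed = tuple(int(p) for p in version.split(".")[:3])
        return parsed >= GRAD_NORM_MIN_VERSION
    # Fallback: inspect the method signature instead of substring-matching
    # its source text, which survives reformatting of the upstream code.
    try:
        return "grad_norm" in inspect.signature(train_step_fn).parameters
    except (TypeError, ValueError):
        return False
```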
Description
This PR fixes bugs that prevented resuming distillation training from a trained student checkpoint.
Key Changes:
- Fixed a `ValueError` related to the missing `'norm'` axis during checkpoint restoration by expanding the `nn_partitioning.axis_rules` context manager to wrap both model initialization and checkpoint loading.
- Passed `optimizer=trainer.optimizer` during the final conditional save block so the optimizer state is correctly written to disk, preventing a `KeyError` on resume.
- Extracted the `build_training_components` helper for better modularity and readability.

Tests
- Added `test_train_save_and_resume` to verify end-to-end checkpoint saving and successful loop resumption.
- Added `test_checkpointing_and_resume` to ensure the optimizer state and model weights are restored correctly upon initialization.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [ ] Added the `gemini-review` label.