
Fix distillation checkpoint on resumption #3464

Open
gagika wants to merge 1 commit into main from agagik-distill-checkpoint

Conversation


@gagika gagika commented Mar 20, 2026

Description

This PR fixes bugs that prevented resuming distillation training from a trained student checkpoint.

Key Changes:

  • Fixes Restoring from a Trained Student Checkpoint: Resolves a ValueError related to the missing 'norm' axis during checkpoint restoration by expanding the nn_partitioning.axis_rules context manager to wrap both model initialization and checkpoint loading.
  • Fixes Optimizer State Saving: Explicitly passes optimizer=trainer.optimizer during the final conditional save block so the optimizer state is correctly written to disk, preventing a KeyError on resume.
  • Code Refactoring: Extracts non-hardware configuration logic (Tokenizer, Strategy, TrainingConfig) into a clean build_training_components helper for better modularity and readability.
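The first fix above can be sketched in plain Python: both model initialization and checkpoint restoration must run inside the same axis-rules context, otherwise resolving the 'norm' axis during restore raises a ValueError. The `axis_rules` context manager below only mimics the name of `flax.linen.partitioning.axis_rules`; it is an illustrative stand-in, not the real API.

```python
# Pure-Python stand-in for the axis-rules fix: wrap BOTH init and restore
# in one context so axis resolution works for each. Hypothetical names.
from contextlib import contextmanager

_ACTIVE_RULES = []


@contextmanager
def axis_rules(rules):
    """Push a set of logical-axis -> mesh-axis rules for the enclosed block."""
    _ACTIVE_RULES.append(dict(rules))
    try:
        yield
    finally:
        _ACTIVE_RULES.pop()


def resolve_axis(name):
    for rules in reversed(_ACTIVE_RULES):
        if name in rules:
            return rules[name]
    # The kind of error the PR fixes: no active rules cover this axis.
    raise ValueError(f"cannot resolve {name!r} axis")


def init_model():
    return {"norm": resolve_axis("norm")}


def restore_checkpoint():
    return {"norm": resolve_axis("norm")}


# Fixed shape: one context wraps both steps. The bug was wrapping only
# init_model(), so restore_checkpoint() ran without rules and raised.
with axis_rules({"norm": "tensor"}):
    model = init_model()
    restored = restore_checkpoint()
```

Outside the `with` block, `resolve_axis("norm")` raises the same ValueError the PR describes, which is why wrapping only initialization was not enough.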

Tests

  • Added test_train_save_and_resume to verify end-to-end checkpoint saving and successful loop resumption.
  • Added test_checkpointing_and_resume to ensure the optimizer state and model weights are restored exactly upon initialization.
  • Tested end-to-end with a real run by killing the distillation job and resuming it.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@gagika gagika changed the title Refactored the distillation input pipeline and checkpoint manager to … Fix distillation checkpoint resumption: Apply axis rules during restore and include optimizer state Mar 20, 2026
@gagika gagika changed the title Fix distillation checkpoint resumption: Apply axis rules during restore and include optimizer state Fix distillation checkpoint on resumption Mar 20, 2026
@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


codecov bot commented Mar 20, 2026

Codecov Report

❌ Patch coverage is 80.14706% with 27 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| .../trainers/post_train/distillation/train_distill.py | 82.29% | 11 missing, 6 partials ⚠️ |
| ...ners/post_train/distillation/distillation_utils.py | 75.00% | 4 missing, 6 partials ⚠️ |



@github-actions github-actions bot left a comment


## 📋 Review Summary

This Pull Request successfully introduces fixes for resuming distillation training from checkpoints by managing the _train_steps and _restored_custom_metadata appropriately using maybe_restore. It also reorganizes initialization steps for modularity and handles optimizer_state explicitly in the save call.

🔍 General Feedback

  • The refactoring into build_training_components and the usage of nn_partitioning.axis_rules wrapper around initialization cleans up the setup significantly.
  • A solid set of unit tests was added to ensure end-to-end functionality, including a check that the optimizer state is loaded correctly.
  • Added an inline suggestion for a type handler during array restoration to fix potential issues with primitive types.
  • Pointed out an opportunity to conditionally compute grad_norm to enhance readability and ensure unnecessary XLA graphs are not added when unneeded.
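The last point above can be sketched in plain Python: gate the norm computation on a flag so that, in the real JAX code, the extra reduction is never traced into the compiled graph when nothing consumes it. The arithmetic below is a placeholder, not the actual trainer logic.

```python
# Hypothetical sketch of "conditionally compute grad_norm": when the flag
# is off, the norm reduction is skipped entirely (and, under jax.jit,
# would never be staged into the XLA graph).
def train_step(grads, *, compute_grad_norm: bool):
    loss = sum(g * g for g in grads)  # placeholder loss
    if not compute_grad_norm:
        return loss, None
    grad_norm = sum(g * g for g in grads) ** 0.5  # only built when needed
    return loss, grad_norm
```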

@gagika gagika force-pushed the agagik-distill-checkpoint branch 2 times, most recently from 6bbb69b to ed51fc7 on March 20, 2026 04:06

@gagika gagika force-pushed the agagik-distill-checkpoint branch from ed51fc7 to fb47422 on March 20, 2026 04:13

@github-actions github-actions bot left a comment


## 📋 Review Summary

This PR successfully fixes the checkpoint resumption and optimizer state persistence bugs during distillation training. The implementation cleanly addresses the previously missing norm axis issue by appropriately wrapping the loading logic within the axis_rules context manager, and it effectively extracts standard configurations into build_training_components for enhanced readability.

🔍 General Feedback

  • Checkpoint Resumption Logic: The new maybe_restore logic handles model bundles properly via getattr(model, "student_model", model).
  • Modularity Enhancements: Extracting the build_training_components logically separates standard Tunix setups from MaxText-specific hardware contexts.
  • Defensive Configurations: Added a comment inline regarding getattr over get_with_default for instances managed externally (e.g., Tunix dataclasses) to prevent potential attribute errors on property access.

```python
    self.optimizer,
    restore_only_lora_params=getattr(self, "_lora_enabled", False),
)
grad_accum_steps = self.config.get_with_default("gradient_accumulation_steps", 1)
```


🟠 Standard dataclasses or configurations from external libraries (like Tunix's `TrainingConfig`) typically do not implement a `get_with_default` method. This might lead to an `AttributeError` at runtime. Using `getattr` is a safer, more pythonic approach.
Suggested change

```diff
-grad_accum_steps = self.config.get_with_default("gradient_accumulation_steps", 1)
+grad_accum_steps = getattr(self.config, "gradient_accumulation_steps", 1)
```
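For illustration, a runnable sketch of why `getattr` is the safer access pattern here. The `TrainingConfig` below is a hypothetical stand-in dataclass, not the real Tunix class; the point is only that plain dataclasses have no `get_with_default` method.

```python
# Hypothetical stand-in for an externally managed config dataclass.
from dataclasses import dataclass


@dataclass
class TrainingConfig:
    max_steps: int = 100  # note: no gradient_accumulation_steps field


cfg = TrainingConfig()

# cfg.get_with_default("gradient_accumulation_steps", 1)  # AttributeError!
grad_accum_steps = getattr(cfg, "gradient_accumulation_steps", 1)  # falls back to 1
```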

…support state restoration via maybe_restore.
@gagika gagika force-pushed the agagik-distill-checkpoint branch from 7638c8f to 359ac86 Compare March 21, 2026 00:04


@github-actions github-actions bot left a comment


## 📋 Review Summary

This PR integrates and refactors the distillation components with Tunix and adds checkpoint management capabilities for MaxTextDistillationTrainer. The implementation handles complex orchestration between teacher and student models appropriately and isolates offline vs online distillation setups.

🔍 General Feedback

  • Good use of maybe_restore and _shard_optimizer to handle sharding subtleties specific to JAX when restoring models and optimizer states from checkpoints.
  • Passing trainer.model to trainer.checkpoint_manager.save() correctly mirrors the new ModelBundle abstraction.
  • Be extremely cautious with source-code introspection via inspect.getsource; while functional as a fallback, it adds latent fragility that is best replaced by explicit version checks in future iterations.

Comment on lines +495 to +498

```python
metadata = self._checkpoint_manager.metadata(step)
if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None:
  custom_metadata = metadata.custom_metadata
else:
```


🟡 If `metadata` returned by `self._checkpoint_manager.metadata(step)` is a dictionary (which depends on the Orbax version), `hasattr(metadata, "custom_metadata")` will evaluate to `False`, and `custom_metadata` will silently default to an empty dictionary, losing the metadata.

Consider handling the case where metadata is a dictionary directly:

Suggested change

```diff
 metadata = self._checkpoint_manager.metadata(step)
-if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None:
+if isinstance(metadata, dict):
+  custom_metadata = metadata
+elif metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None:
   custom_metadata = metadata.custom_metadata
 else:
+  custom_metadata = {}
```
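A runnable version of the defensive handling suggested above, factored into a hypothetical `extract_custom_metadata` helper. `SimpleNamespace` stands in for the object-style result that some Orbax versions return from `CheckpointManager.metadata`, while older versions may return a plain dict.

```python
# Hypothetical helper mirroring the suggested dict-vs-object handling.
from types import SimpleNamespace


def extract_custom_metadata(metadata):
    """Return custom metadata whether Orbax hands back a dict or an object."""
    if isinstance(metadata, dict):
        return metadata  # dict-style: the dict itself is the metadata
    if metadata is not None and getattr(metadata, "custom_metadata", None) is not None:
        return metadata.custom_metadata  # object-style: read the attribute
    return {}  # nothing recoverable: empty, but explicit


# Both shapes resolve to the same payload instead of silently losing it.
dict_style = extract_custom_metadata({"train_steps": 10})
obj_style = extract_custom_metadata(SimpleNamespace(custom_metadata={"train_steps": 10}))
```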

```python
self._tunix_expects_grad_norm = False
try:
  source = inspect.getsource(peft_trainer.PeftTrainer._train_step)
  self._tunix_expects_grad_norm = "grad_norm" in source
```


🟠 Using string matching on the source code to determine an API signature is highly brittle. This could easily cause a runtime crash (`ValueError` during unpacking) if Tunix updates its code to include the word "grad_norm" in a comment, docstring, or a different context within `_train_step` without actually changing the return signature.

Consider relying on `tunix.__version__` to check for this API change. If a version check isn't feasible, consider a safer inspection technique, such as parsing the AST or inspecting the method signature.
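A hedged sketch of the AST alternative: parse the source and check whether a `return` statement actually returns a name called `grad_norm`, so a mention in a comment or docstring cannot produce a false positive. The `_train_step` snippets below are hypothetical stand-ins, not tunix code.

```python
# AST-based check: comments and docstrings are discarded by the parser,
# so only a real `return ... grad_norm ...` counts.
import ast


def returns_grad_norm(source: str) -> bool:
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, ast.Return) and node.value is not None:
            for sub in ast.walk(node.value):
                if isinstance(sub, ast.Name) and sub.id == "grad_norm":
                    return True
    return False


# Hypothetical variant that really returns grad_norm.
WITH_NORM = """
def _train_step(state, batch):
    loss, grad_norm = 0.0, 1.0
    return loss, grad_norm
"""

# Hypothetical variant that only mentions grad_norm in a comment:
# naive substring matching is fooled, the AST check is not.
WITHOUT_NORM = """
def _train_step(state, batch):
    # grad_norm is intentionally not computed in this variant
    return 0.0
"""
```

Here `"grad_norm" in WITHOUT_NORM` is still true because of the comment, which is exactly the brittleness the review warns about, while `returns_grad_norm(WITHOUT_NORM)` correctly reports false.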

