From 359ac86c72e0576616fdd96b4febb6bcbc5be998 Mon Sep 17 00:00:00 2001 From: Gagik Amirkhanyan Date: Sat, 21 Mar 2026 00:04:19 +0000 Subject: [PATCH] Refactored the distillation input pipeline and checkpoint manager to support state restoration via maybe_restore. --- .../distillation/distillation_utils.py | 100 ++++- .../post_train/distillation/train_distill.py | 367 +++++++++++------- .../post_training/unit/train_distill_test.py | 301 ++++++++++++-- 3 files changed, 582 insertions(+), 186 deletions(-) diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index 6768d1e466..e2716aba44 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -48,9 +48,9 @@ class DistillationForwardOutput: """Dataclass to carry MaxText-specific output fields.""" #: logits - logits: jax.Array = None + logits: jax.Array #: out_projection_activations - out_projection_activations: jax.Array = None + out_projection_activations: jax.Array | None = None @flax.struct.dataclass(frozen=True) @@ -58,18 +58,18 @@ class MaxTextTrainingInput(peft_trainer.TrainingInput): """Extended TrainingInput dataclass to carry MaxText-specific fields.""" #: Position indices for the tokens (for RoPE). - positions: jax.Array = None + positions: jax.Array | None = None #: Segment IDs for packed sequences (0=padding, 1+=examples). - decoder_segment_ids: jax.Array = None + decoder_segment_ids: jax.Array | None = None #: Ground truth target tokens (used for loss calculation and logging). - targets: jax.Array = None + targets: jax.Array | None = None #: Position indices for the target tokens. - targets_position: jax.Array = None + targets_position: jax.Array | None = None #: Segment IDs for packed target tokens. - targets_segmentation: jax.Array = None + targets_segmentation: jax.Array | None = None #: Top-K logits from the teacher model. - top_k_logits: jax.Array = None - top_k_indices: jax.Array = None + top_k_logits: jax.Array | None = None + top_k_indices: jax.Array | None = None # ----------------------------------------------------------------------------- @@ -275,7 +275,7 @@ def compute_loss( # 3. Combine losses base_logit_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss) - feature_loss = 0.0 + feature_loss = jnp.array(0.0) if self.beta_feature > 0.0: if self.layer_indices is not None: @@ -420,6 +420,86 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F force=force, ) + def maybe_restore( + self, + model: Any, + optimizer: Any = None, + restore_only_lora_params: bool = False, + ) -> tuple[int, dict[str, Any]]: + """Restores model and optimizer state if a checkpoint exists, using correct sharding specs. + + This method checks for the latest available checkpoint. If found, it restores the + model parameters and optionally the optimizer state in-place. It automatically + maps the parameter's `sharding` attributes to Orbax restore arguments to ensure + the tensors are placed on the correct device meshes. + + Args: + model: The model to restore. If a `ModelBundle` is provided, it automatically + extracts and restores only the `student_model`. + optimizer: The optimizer state to restore. If None, optimizer restoration is skipped. + restore_only_lora_params: If True, restricts restoration to parameters marked + as `nnx.LoRAParam`. + + Returns: + A tuple containing the restored step number (0 if no checkpoint was found) + and a dictionary of custom metadata. + """ + if self._checkpoint_manager is None: + return 0, {} + + step = self._checkpoint_manager.latest_step() + if step is None: + return 0, {} + + max_logging.log(f"Restoring from checkpoint step {step}...") + + # Extract student model safely + target_model = getattr(model, "student_model", model) + + if restore_only_lora_params: + params = nnx.state(target_model, nnx.LoRAParam) + else: + params = nnx.state(target_model) + + def map_to_pspec(data): + if hasattr(data, "sharding"): + return checkpoint.type_handlers.ArrayRestoreArgs(sharding=data.sharding) + return None + + restore_args = jax.tree.map(map_to_pspec, params) + + cp_restore_args = { + "model_params": checkpoint.args.PyTreeRestore( + item=params, + restore_args=restore_args, + ) + } + + if optimizer is not None: + optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState) + opt_restore_args = jax.tree.map(map_to_pspec, optimizer_state) + cp_restore_args["optimizer_state"] = checkpoint.args.PyTreeRestore( + item=optimizer_state, + restore_args=opt_restore_args, + ) + + restored = self._checkpoint_manager.restore( + step, + args=checkpoint.args.Composite(**cp_restore_args), + ) + + nnx.update(target_model, restored.model_params) + if optimizer is not None: + nnx.update(optimizer, restored.optimizer_state) + + metadata = self._checkpoint_manager.metadata(step) + if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None: + custom_metadata = metadata.custom_metadata + else: + custom_metadata = {} + + return step, dict(custom_metadata) + def restore_iterator(self): """Restores the iterator using MaxText's logic.""" if self._checkpoint_manager is None or self._iterator is None: diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index b8d76bab7c..5f6ccf23b5 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -32,6 +32,8 @@ 3. **Tunix Integration**: We wrap the MaxText models in `TunixMaxTextAdapter` to expose a standard interface (call signature) that the Tunix `DistillationTrainer` expects. """ + +import inspect from typing import Sequence, Callable from absl import app from flax import nnx @@ -212,6 +214,34 @@ def __init__(self, model, strategy, optimizer, training_config, **kwargs): wrt = nnx.LoRAParam if self._lora_enabled else nnx.Param self.optimizer = nnx.Optimizer(model.student_model, optimizer, wrt=wrt) + # Detect if Tunix expects _train_step to return grad_norm by inspecting the source + self._tunix_expects_grad_norm = False + try: + source = inspect.getsource(peft_trainer.PeftTrainer._train_step) + self._tunix_expects_grad_norm = "grad_norm" in source + except (TypeError, OSError): + # Fallback if source code is unavailable + pass + + def _shard_optimizer(self, mesh: jax.sharding.Mesh) -> None: + """Overrides base _shard_optimizer to safely shard restored scalars. + + This is necessary because the optimizer state restored from checkpoints may contain unsharded + scalars (e.g., Adam moments). + """ + if mesh.empty: + return + optimizer_state = nnx.state(self.optimizer, nnx.optimizer.OptState) + optimizer_pspecs = nnx.get_partition_spec(optimizer_state) + + def _safe_shard(x, pspec): + if isinstance(pspec, jax.sharding.PartitionSpec): + return jax.device_put(x, jax.sharding.NamedSharding(mesh, pspec)) + return x + + optimizer_sharded_state = jax.tree.map(_safe_shard, optimizer_state, optimizer_pspecs) + nnx.update(self.optimizer, optimizer_sharded_state) + def _train_step(self, model, optimizer, inputs): """Overrides the main JIT block to natively handle ModelBundle module.""" @@ -258,9 +288,13 @@ def loss_wrapper(student, teacher, batch): out, grads = grad_fn(model.student_model, model.teacher_model, batch) + tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) + optimizer.update(model.student_model, grads) - return out[0], out[1] # loss, aux + if tunix_expects_grad_norm: + return out[0], out[1], optax.global_norm(grads) + return out[0], out[1] def _eval_step(self, model, inputs): """Evaluation only needs the student.""" @@ -321,62 +355,70 @@ def _post_process_train_step(self, aux: dict[str, jax.Array]) -> None: self._buffered_train_metrics.additional_metrics[name][0].append(value) + def setup_checkpoint_manager_and_restore(self, raw_train_iter, config): + """Configures the trainer's CheckpointManager and restores states. -def _setup_and_restore_input_pipeline(trainer, raw_train_iter, config, train_config): - """Configures the trainer to save/restore Grain iterator state. + This function unconditionally replaces the default CheckpointManager with + MaxTextCheckpointManager. This ensures consistent API availability (like + wait_until_finished) and enables Grain checkpointing if the iterator supports it. - This function unconditionally replaces the default CheckpointManager with - MaxTextCheckpointManager. This ensures consistent API availability (like - wait_until_finished) and enables Grain checkpointing if the iterator supports it. + Args: + raw_train_iter: The input pipeline iterator. + config: The MaxText HyperParameters. - Args: - trainer: The active DistillationTrainer instance. - raw_train_iter: The input pipeline iterator. - config: The MaxText HyperParameters. - train_config: The Tunix TrainingConfig. + Returns: + The iterator to use for training (restored or original). + """ + is_grain_dataset = config.dataset_type == "grain" + has_save_method = hasattr(raw_train_iter, "save") + enable_checkpointing = raw_train_iter is not None and (is_grain_dataset or has_save_method) - Returns: - The iterator to use for training (restored or original). - """ - is_grain_dataset = config.dataset_type == "grain" - has_save_method = hasattr(raw_train_iter, "save") - enable_checkpointing = raw_train_iter is not None and (is_grain_dataset or has_save_method) - - iterator_to_manage = raw_train_iter if enable_checkpointing else None - - if enable_checkpointing: - max_logging.log("Input Pipeline Checkpointing: ENABLED") - max_logging.log(f"Details: dataset_type='{config.dataset_type}', has_save={has_save_method}") - else: - max_logging.log("Input Pipeline Checkpointing: DISABLED") - if raw_train_iter is None: - max_logging.log("Reason: train_iter is None") + iterator_to_manage = raw_train_iter if enable_checkpointing else None + + if enable_checkpointing: + max_logging.log("Input Pipeline Checkpointing: ENABLED") + max_logging.log(f"Details: dataset_type='{config.dataset_type}', has_save={has_save_method}") else: - max_logging.log( - f"Reason: Iterator '{type(raw_train_iter).__name__}' is not recognized as Grain " - f"(dataset_type='{config.dataset_type}', has_save={has_save_method})" - ) + max_logging.log("Input Pipeline Checkpointing: DISABLED") + if raw_train_iter is None: + max_logging.log("Reason: train_iter is None") + else: + max_logging.log( + f"Reason: Iterator '{type(raw_train_iter).__name__}' is not recognized as Grain " + f"(dataset_type='{config.dataset_type}', has_save={has_save_method})" + ) - # 1. Create the specialized manager (always) - maxtext_manager = distillation_utils.MaxTextCheckpointManager( - raw_iterator=iterator_to_manage, - root_directory=train_config.checkpoint_root_directory, - options=train_config.checkpointing_options, - ) + # 1. Ensure clean resource release of the base class's manager + # pylint: disable=access-member-before-definition + if self.checkpoint_manager: + self.checkpoint_manager.close() + # pylint: enable=access-member-before-definition + + # 2. Assign the specialized manager + self.checkpoint_manager = distillation_utils.MaxTextCheckpointManager( + raw_iterator=iterator_to_manage, + root_directory=config.checkpoint_dir, + options=self.config.checkpointing_options, + ) - # 2. Swap managers (ensure clean resource release) - if trainer.checkpoint_manager: - trainer.checkpoint_manager.close() - trainer.checkpoint_manager = maxtext_manager + # 3. Restore Model & Optimizer State correctly via MaxTextCheckpointManager. + # Accessing protected variables of the base class IS allowed inside the subclass! + self._train_steps, self._restored_custom_metadata = self.checkpoint_manager.maybe_restore( + self.model, + self.optimizer, + restore_only_lora_params=getattr(self, "_lora_enabled", False), + ) + grad_accum_steps = self.config.get_with_default("gradient_accumulation_steps", 1) + self._iter_steps = self._train_steps * grad_accum_steps - # 3. Restore input state (if applicable) - if enable_checkpointing: - restored_iter = trainer.checkpoint_manager.restore_iterator() - if restored_iter is not None: - max_logging.log("Restored input pipeline state to match model step.") - return restored_iter + # 4. Restore input state (if applicable) + if enable_checkpointing: + restored_iter = self.checkpoint_manager.restore_iterator() + if restored_iter is not None: + max_logging.log("Restored input pipeline state to match model step.") + return restored_iter - return raw_train_iter + return raw_train_iter # ----------------------------------------------------------------------------- @@ -402,34 +444,22 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh) # ----------------------------------------------------------------------------- -def train_distill( +def build_training_components( student_config: pyconfig.HyperParameters, teacher_config: pyconfig.HyperParameters, is_offline: bool = False, offline_data_dir: str | None = None, -) -> None: - """Main distillation training loop. - - Orchestrates the loading of both student and teacher models, configures the - distillation strategy, and executes the training loop via the Tunix Trainer. +): + """Builds and returns the strategy, optimizer, and training config objects. Args: - student_config: Configuration object for the Student model (learnable). - teacher_config: Configuration object for the Teacher model (frozen). - """ - # Validate vocab size match between Student and Teacher - if student_config.vocab_size != teacher_config.vocab_size: - raise ValueError( - f"Vocab size mismatch! Student: {student_config.vocab_size}, Teacher: {teacher_config.vocab_size}. " - "Distillation requires matching vocabularies." - ) - - # 1. Setup Mesh - devices = jax.devices() - devices_array = maxtext_utils.create_device_mesh(student_config, devices) - mesh = jax.sharding.Mesh(devices_array, student_config.mesh_axes) + student_config: Configuration object for the Student model. + teacher_config: Configuration object for the Teacher model. - # 2. Load Models & Tokenizer Info + Returns: + A tuple of (DistillationStrategy, Optimizer, TrainingConfig). + """ + # 2. Load Tokenizer Info tok = tokenizer.build_tokenizer( tokenizer_path=student_config.tokenizer_path, tokenizer_type=student_config.tokenizer_type, @@ -439,20 +469,6 @@ def train_distill( ) pad_id = tok.pad_id if tok.pad_id is not None else 0 - max_logging.log(f"Loading Student from {student_config.load_parameters_path}...") - _log_config_details(student_config, "Student") - student_model = get_maxtext_model(student_config, mesh) - - # Skip teacher model loading if offline - if is_offline: - max_logging.log("Offline Distillation: Skipping Teacher Model loading.") - teacher_model = None - else: - max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...") - _log_config_details(teacher_config, "Teacher") - teacher_model = get_maxtext_model(teacher_config, mesh) - teacher_model.eval() - # 3. Define Distillation Strategy def labels_fn(targets, targets_segmentation=None, **kwargs): """Converts integer targets to masked one-hot vectors for hard label loss.""" @@ -507,85 +523,136 @@ def labels_fn(targets, targets_segmentation=None, **kwargs): eval_every_n_steps=student_config.eval_interval, metrics_logging_options=metrics_logging_options, profiler_options=profiler_options, - checkpoint_root_directory=student_config.checkpoint_dir, + checkpoint_root_directory=None, # Tunix should NOT checkpoint our ModelBundle. MaxTextCheckpointManager handles this. checkpointing_options=checkpointing_options, gradient_accumulation_steps=student_config.gradient_accumulation_steps, ) - # 5. Data Iterators (Init BEFORE Trainer) - if is_offline: - max_logging.log(f"Loading Offline Dataset from {offline_data_dir}...") - raw_train_iter = distillation_utils.OfflineArrayRecordIterator(offline_data_dir) - raw_eval_iter = None - else: - max_logging.log("Initializing Data Iterators via MaxText pipeline...") - raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh) - - student_model.train() - model_bundle = ModelBundle(teacher_model, student_model) - - # 6. Initialize Trainer - trainer = MaxTextDistillationTrainer( - model=model_bundle, - strategy=strategy, - optimizer=optimizer, - training_config=train_config, - ) - trainer.is_managed_externally = True - trainer._has_aux = True # pylint: disable=protected-access - - # 7. Input Pipeline Checkpointing - # Replace the default CheckpointManager with a Grain-aware one, which enables iterator checkpointing for grain datasets. - raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config) - - # 8. Configure Input Mapping - def custom_gen_model_input_fn(batch): - inputs_dict = { - "input_tokens": batch.input_tokens, - "positions": batch.positions, - "attention_mask": batch.input_mask, - "decoder_segment_ids": batch.decoder_segment_ids, - "targets": batch.targets, - "targets_position": batch.targets_position, - "targets_segmentation": batch.targets_segmentation, - "cache": None, - } - - # If we are in online mode then we exit - if getattr(batch, "top_k_logits", None) is None: - return inputs_dict + return strategy, optimizer, train_config - # Scatter the offline arrays into a dense tensor of -10000s - dense_shape = batch.input_tokens.shape + (student_config.vocab_size,) - dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32) - dense_logits = jnp.put_along_axis(dense_logits, batch.top_k_indices, batch.top_k_logits, axis=-1, inplace=False) - # Inject it as teacher_output so the trainer skips the teacher forward pass - inputs_dict["teacher_output"] = distillation_utils.DistillationForwardOutput( - logits=dense_logits, out_projection_activations=None - ) +def train_distill( + student_config: pyconfig.HyperParameters, + teacher_config: pyconfig.HyperParameters, + is_offline: bool = False, + offline_data_dir: str | None = None, +) -> None: + """Main distillation training loop. - return inputs_dict + Orchestrates the loading of both student and teacher models, configures the + distillation strategy, and executes the training loop via the Tunix Trainer. - trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn) + Args: + student_config: Configuration object for the Student model (learnable). + teacher_config: Configuration object for the Teacher model (frozen). + """ + # Validate vocab size match between Student and Teacher + if student_config.vocab_size != teacher_config.vocab_size: + raise ValueError( + f"Vocab size mismatch! Student: {student_config.vocab_size}, Teacher: {teacher_config.vocab_size}. " + "Distillation requires matching vocabularies." + ) - # 9. Create Iterator Wrappers (Use Utils) - train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter) + # Build Training Components (No hardware context required) + strategy, optimizer, train_config = build_training_components( + student_config, teacher_config, is_offline, offline_data_dir + ) - eval_iter = None - if raw_eval_iter is not None: - max_logging.log("Evaluation iterator successfully initialized.") - eval_iter = distillation_utils.MaxTextToTunixIterator(raw_eval_iter) - elif student_config.eval_interval > 0: - max_logging.log("Warning: eval_interval > 0 but create_data_iterator returned None for eval_iter.") + # 1. Setup Mesh + devices = jax.devices() + devices_array = maxtext_utils.create_device_mesh(student_config, devices) + mesh = jax.sharding.Mesh(devices_array, student_config.mesh_axes) - # 10. Train - max_logging.log("Starting Distillation Training...") + # Hardware Execution (Safe Context) + max_logging.log("Applying logical axis rules for model initialization and training...") with mesh, nn_partitioning.axis_rules(student_config.logical_axis_rules): + + # 2. Load Models + max_logging.log(f"Loading Student from {student_config.load_parameters_path}...") + _log_config_details(student_config, "Student") + student_model = get_maxtext_model(student_config, mesh) + + if is_offline: + max_logging.log("Offline Distillation: Skipping Teacher Model loading.") + teacher_model = None + else: + max_logging.log(f"Loading Teacher from {teacher_config.load_parameters_path}...") + _log_config_details(teacher_config, "Teacher") + teacher_model = get_maxtext_model(teacher_config, mesh) + teacher_model.eval() + + student_model.train() + model_bundle = ModelBundle(teacher_model, student_model) + + # 3. Initialize Trainer + trainer = MaxTextDistillationTrainer( + model=model_bundle, + strategy=strategy, + optimizer=optimizer, + training_config=train_config, + ) + trainer.is_managed_externally = True + trainer._has_aux = True # pylint: disable=protected-access + + # 4. Data Iterators (Init BEFORE Trainer pipeline setup) + # We use MaxText's native create_data_iterator which creates both train and eval iterators + if is_offline: + max_logging.log(f"Loading Offline Dataset from {offline_data_dir}...") + raw_train_iter = distillation_utils.OfflineArrayRecordIterator(offline_data_dir) + raw_eval_iter = None + else: + max_logging.log("Initializing Data Iterators via MaxText pipeline...") + raw_train_iter, raw_eval_iter = input_pipeline_interface.create_data_iterator(student_config, mesh) + + # 5. Input Pipeline Checkpointing & Restoration + # Replace the default CheckpointManager with a Grain-aware one, which enables iterator checkpointing for grain datasets. + raw_train_iter = trainer.setup_checkpoint_manager_and_restore(raw_train_iter, student_config) + + # 6. Configure Input Mapping + def custom_gen_model_input_fn(batch): + inputs_dict = { + "input_tokens": batch.input_tokens, + "positions": batch.positions, + "attention_mask": batch.input_mask, + "decoder_segment_ids": batch.decoder_segment_ids, + "targets": batch.targets, # Passed to strategy (labels_fn) + "targets_position": batch.targets_position, # Passed to strategy (labels_fn) + "targets_segmentation": batch.targets_segmentation, # Passed to strategy (labels_fn) + "cache": None, + } + # If we are in online mode then we exit + if getattr(batch, "top_k_logits", None) is None: + return inputs_dict + + # Scatter the offline arrays into a dense tensor of -10000s + dense_shape = batch.input_tokens.shape + (student_config.vocab_size,) + dense_logits = jnp.full(dense_shape, -10000.0, dtype=jnp.float32) + dense_logits = jnp.put_along_axis(dense_logits, batch.top_k_indices, batch.top_k_logits, axis=-1, inplace=False) + + # Inject it as teacher_output so the trainer skips the teacher forward pass + inputs_dict["teacher_output"] = distillation_utils.DistillationForwardOutput( + logits=dense_logits, out_projection_activations=None + ) + return inputs_dict + + trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn) + + # 7. Create Iterator Wrappers (Use Utils) + train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter) + + eval_iter = None + if raw_eval_iter is not None: + max_logging.log("Evaluation iterator successfully initialized.") + eval_iter = distillation_utils.MaxTextToTunixIterator(raw_eval_iter) + elif student_config.eval_interval > 0: + max_logging.log("Warning: eval_interval > 0 but create_data_iterator returned None for eval_iter.") + + # 8. Train + max_logging.log("Starting Distillation Training...") # Pass both iterators to the trainer trainer.train(train_iter, eval_iter) - # 11. Final Save (Conditional) + # 9. Final Save (Conditional) if student_config.save_checkpoint_on_completion: should_save = student_config.steps % student_config.checkpoint_period @@ -594,7 +661,8 @@ def custom_gen_model_input_fn(batch): try: saved = trainer.checkpoint_manager.save( trainer.train_steps, - trainer.model.student_model, + trainer.model, + optimizer=trainer.optimizer, save_only_lora_params=getattr(trainer, "_lora_enabled", False), force=True, ) @@ -619,6 +687,9 @@ def main(argv: Sequence[str]) -> None: Parses configuration, isolates Student and Teacher overrides, and triggers the training loop. + + Args: + argv: List of command-line arguments. Expects [script_name, config_file, ...]. """ # 1. Parse Global Config to extract Overrides global_config = pyconfig.initialize(argv) diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index 1604b7d86c..d6adadb89b 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -15,6 +15,7 @@ """Unit tests for the Distillation Trainer.""" +import os import pytest pytest.importorskip("tunix") @@ -36,6 +37,7 @@ from maxtext.trainers.post_train.distillation import train_distill from maxtext.trainers.post_train.distillation import distillation_utils from maxtext.configs import pyconfig +from tests.utils.test_helpers import get_test_config_path # pylint: disable=protected-access @@ -140,9 +142,12 @@ def test_prepare_inputs_logic(self): # 5. Verify trainer.strategy.get_teacher_outputs.assert_not_called() + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_skips_teacher_forward_when_output_present(self, mock_value_and_grad, mock_tree_map): + def test_train_step_skips_teacher_forward_when_output_present( + self, mock_value_and_grad, mock_tree_map, mock_global_norm + ): """Verifies teacher forward is skipped when model_output is already in the batch.""" # 1. Initialize Trainer # pylint: disable=no-value-for-parameter @@ -166,9 +171,10 @@ def test_train_step_skips_teacher_forward_when_output_present(self, mock_value_a optimizer, inputs = mock.Mock(), mock.Mock() # 4. Configure mocked nnx.value_and_grad - mock_loss, mock_aux, mock_grads = mock.Mock(), mock.Mock(), mock.Mock() + mock_loss, mock_aux, mock_grads = mock.Mock(), {}, mock.Mock() mock_grad_fn = mock.Mock(return_value=((mock_loss, mock_aux), mock_grads)) mock_value_and_grad.return_value = mock_grad_fn + mock_global_norm.return_value = mock.Mock() # 5. Execute outer function & trigger inner loss_wrapper trainer._train_step(model_bundle, optimizer, inputs) @@ -188,9 +194,12 @@ def test_train_step_skips_teacher_forward_when_output_present(self, mock_value_a cache=None, ) + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_and_grad, mock_tree_map): + def test_train_step_calls_teacher_forward_when_output_missing( + self, mock_value_and_grad, mock_tree_map, mock_global_norm + ): """Verifies teacher forward is called when model_output is missing from the batch.""" # 1. Initialize Trainer # pylint: disable=no-value-for-parameter @@ -214,12 +223,14 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a optimizer, inputs = mock.Mock(), mock.Mock() # 4. Configure mocked nnx.value_and_grad - mock_loss, mock_aux, mock_grads = mock.Mock(), mock.Mock(), mock.Mock() + mock_loss, mock_aux, mock_grads = mock.Mock(), {}, mock.Mock() mock_grad_fn = mock.Mock(return_value=((mock_loss, mock_aux), mock_grads)) mock_value_and_grad.return_value = mock_grad_fn + mock_gn = mock.Mock() + mock_global_norm.return_value = mock_gn # 5. Execute outer function & trigger inner loss_wrapper - loss, aux = trainer._train_step(model_bundle, optimizer, inputs) + train_step_out = trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] loss_wrapper(student_model, teacher_model, mock_batch) @@ -252,12 +263,16 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a optimizer.update.assert_called_once_with(student_model, mock_grads) # Verify the final returns match what grad_fn produced - self.assertEqual(loss, mock_loss) - self.assertEqual(aux, mock_aux) + self.assertEqual(train_step_out[0], mock_loss) + if len(train_step_out) > 2: + self.assertEqual(train_step_out[2], mock_gn) + elif "distill/grad_norm" in train_step_out[1]: + self.assertEqual(train_step_out[1]["distill/grad_norm"], mock_gn) + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map): + def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map, mock_global_norm): """Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask.""" # 1. Initialize Trainer # pylint: disable=no-value-for-parameter @@ -282,8 +297,9 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_ optimizer, inputs = mock.Mock(), mock.Mock() # 4. Configure mocked nnx.value_and_grad - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), mock.Mock()), mock.Mock())) + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn + mock_global_norm.return_value = mock.Mock() # 5. Execute outer function & trigger inner loss_wrapper trainer._train_step(model_bundle, optimizer, inputs) @@ -447,55 +463,73 @@ def test_strategy_compute_eval_loss_sft(self): self._verify_strategy_compute_eval_loss(sft_mode=True) def test_setup_pipeline_grain_enabled(self): - """Covers _setup_and_restore_input_pipeline when Grain IS detected.""" + """Covers setup_checkpoint_manager_and_restore when Grain IS detected.""" mock_trainer = mock.Mock() mock_trainer.checkpoint_manager = mock.Mock() # Mock restore returning None (no checkpoint yet) mock_trainer.checkpoint_manager.restore_iterator.return_value = None + mock_trainer.model.student_model = mock.Mock() + mock_trainer.optimizer = mock.Mock() + mock_trainer._lora_enabled = False + mock_iter = mock.Mock() mock_iter.save = mock.Mock() # Has save method config = mock.Mock() config.dataset_type = "grain" + config.checkpoint_dir = self.test_dir # Use real options to avoid Orbax validation errors caused by Mocks train_config = mock.Mock() - train_config.checkpoint_root_directory = self.test_dir + train_config.get_with_default.return_value = 1 + train_config.checkpoint_root_directory = None train_config.checkpointing_options = ocp.CheckpointManagerOptions(max_to_keep=1, create=True) + mock_trainer.config = train_config - # Run function - result = train_distill._setup_and_restore_input_pipeline(mock_trainer, mock_iter, config, train_config) + result = train_distill.MaxTextDistillationTrainer.setup_checkpoint_manager_and_restore( + mock_trainer, mock_iter, config + ) # Verify manager was swapped self.assertIsInstance(mock_trainer.checkpoint_manager, distillation_utils.MaxTextCheckpointManager) self.assertEqual(result, mock_iter) def test_setup_pipeline_restored(self): - """Covers _setup_and_restore_input_pipeline when restore succeeds.""" + """Verifies that a checkpoint accurately restores the input pipeline iterator.""" mock_trainer = mock.Mock() - - # Mock successful restore - restored_iter = mock.Mock() mock_manager = mock.Mock() + restored_iter = mock.Mock() mock_manager.restore_iterator.return_value = restored_iter + # Use real options + train_config = mock.Mock() + train_config.get_with_default.return_value = 1 + train_config.checkpoint_root_directory = None + train_config.checkpointing_options = ocp.CheckpointManagerOptions(max_to_keep=1, create=True) + mock_trainer.config = train_config # set internally + # We need to mock the constructor of MaxTextCheckpointManager to return our mock with mock.patch( "maxtext.trainers.post_train.distillation.distillation_utils.MaxTextCheckpointManager", return_value=mock_manager ): + mock_trainer.model = mock.Mock() + mock_trainer.optimizer = mock.Mock() + mock_trainer._lora_enabled = False + mock_trainer.config = train_config # Set internal Tunix config + mock_iter = mock.Mock() mock_iter.save = mock.Mock() config = mock.Mock() config.dataset_type = "grain" + config.checkpoint_dir = self.test_dir - # Use real options - train_config = mock.Mock() - train_config.checkpoint_root_directory = self.test_dir - train_config.checkpointing_options = ocp.CheckpointManagerOptions(max_to_keep=1, create=True) + mock_manager.maybe_restore.return_value = (10, {}) - result = train_distill._setup_and_restore_input_pipeline(mock_trainer, mock_iter, config, train_config) + result = train_distill.MaxTextDistillationTrainer.setup_checkpoint_manager_and_restore( + mock_trainer, mock_iter, config + ) # Verify it returned the restored iterator, NOT the raw one self.assertEqual(result, restored_iter) @@ -564,19 +598,33 @@ def test_eval_step_calls_student_forward(self): self.assertEqual(actual_loss, mock_loss) def test_setup_pipeline_disabled(self): - """Covers _setup_and_restore_input_pipeline when checkpoiting is disabled.""" + """Covers setup_checkpoint_manager_and_restore when checkpointing is disabled.""" mock_trainer = mock.Mock() + mock_trainer.model = mock.Mock() + mock_trainer.optimizer = mock.Mock() + mock_trainer._lora_enabled = False + mock_iter = object() # No save method config = mock.Mock() config.dataset_type = "tfds" # Not grain + config.checkpoint_dir = self.test_dir # Use real options train_config = mock.Mock() - train_config.checkpoint_root_directory = self.test_dir + train_config.get_with_default.return_value = 1 + train_config.checkpoint_root_directory = None train_config.checkpointing_options = ocp.CheckpointManagerOptions(max_to_keep=1, create=True) + mock_trainer.config = train_config - result = train_distill._setup_and_restore_input_pipeline(mock_trainer, mock_iter, config, train_config) + with mock.patch( + "maxtext.trainers.post_train.distillation.distillation_utils.MaxTextCheckpointManager.maybe_restore" + ) as mock_restore: + mock_restore.return_value = (10, {}) + + result = train_distill.MaxTextDistillationTrainer.setup_checkpoint_manager_and_restore( + mock_trainer, mock_iter, config + ) # Should still swap manager (to MaxTextCheckpointManager) but with None iterator self.assertIsInstance(mock_trainer.checkpoint_manager, distillation_utils.MaxTextCheckpointManager) @@ -624,7 +672,7 @@ def __call__(self, x): model_bundle = train_distill.ModelBundle(teacher_model=teacher, student_model=student) # Snapshot the initial weights - initial_weights = student.linear.kernel.value.copy() + initial_weights = student.linear.kernel.get_value().copy() # 2. Setup Optimizer with MultiSteps (Accumulate over 2 passes) base_optimizer = optax.sgd(learning_rate=0.1) @@ -661,7 +709,7 @@ def __call__(self, x): # Verify weights are completely UNCHANGED np.testing.assert_allclose( - student.linear.kernel.value, initial_weights, err_msg="Weights should not update on the first pass." + student.linear.kernel.get_value(), initial_weights, err_msg="Weights should not update on the first pass." ) # --- EXECUTE PASS 2 --- @@ -673,7 +721,204 @@ def __call__(self, x): # Verify weights HAVE changed with self.assertRaises(AssertionError, msg="Weights should have updated on the second pass."): - np.testing.assert_allclose(student.linear.kernel.value, initial_weights) + np.testing.assert_allclose(student.linear.kernel.get_value(), initial_weights) + + @mock.patch("clu.metric_writers.create_default_writer") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.tokenizer.build_tokenizer") + def test_train_save_and_resume(self, mock_build_tokenizer, mock_writer): + """Verifies that the trainer can save a checkpoint and resume from it.""" + # Provide a dummy tokenizer + mock_tok = mock.Mock() + mock_tok.pad_id = 0 + mock_build_tokenizer.return_value = mock_tok + + base_config_path = get_test_config_path() + base_args = [ + "", + base_config_path, + f"base_output_directory={self.test_dir}", + "run_name=distill_resume_test", + 'metrics_dir=""', + "dataset_type=synthetic", + "vocab_size=32", + "base_emb_dim=8", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "base_mlp_dim=16", + "base_num_decoder_layers=1", + "head_dim=8", + "per_device_batch_size=1", + "max_target_length=16", + "enable_checkpointing=True", + "async_checkpointing=False", + "checkpoint_period=1", + "save_checkpoint_on_completion=True", + "log_period=1", + "eval_interval=0", + "use_sft=False", + "distill_beta=0.0", + "dataset_path=/", # Not used for synthetic, but required by some checks + "enable_checkpointing=True", + ] + + # Run 1: Train for 1 step + argv_run1 = base_args + ["steps=1"] + global_config_1 = pyconfig.initialize(argv_run1) + student_config_1 = pyconfig.initialize(argv_run1, **global_config_1.student_overrides) + teacher_config_1 = pyconfig.initialize(argv_run1, **global_config_1.teacher_overrides) + + # Execute first run + train_distill.train_distill(student_config_1, teacher_config_1) + + # Run 2: Resume and train up to step 2 + argv_run2 = base_args + ["steps=2"] + global_config_2 = pyconfig.initialize(argv_run2) + student_config_2 = pyconfig.initialize(argv_run2, **global_config_2.student_overrides) + teacher_config_2 = pyconfig.initialize(argv_run2, **global_config_2.teacher_overrides) + + # Wrap the checkpoint manager creation to spy on maybe_restore + original_maybe_restore = distillation_utils.MaxTextCheckpointManager.maybe_restore + with mock.patch.object(distillation_utils.MaxTextCheckpointManager, "maybe_restore", autospec=True) as mock_restore: + # Actually call the original to preserve behavior + def side_effect(self, *args, **kwargs): + return original_maybe_restore(self, *args, **kwargs) + + mock_restore.side_effect = side_effect + + # Execute second run + train_distill.train_distill(student_config_2, teacher_config_2) + + # Verify that restore was called and returned train_steps = 1 + self.assertTrue(mock_restore.called) + # Check the actual return value of the mocked call would be (1, ...) but it's hard to assert directly + # On the spy's return we know it was called. To be safe, we can check the checkpoint directory. + + # Check that step 2 checkpoint was written + # The checkpoints should be stored in {test_dir}/distill_resume_test/checkpoints + checkpoint_dir = os.path.join(self.test_dir, "distill_resume_test", "checkpoints") + self.assertTrue(os.path.exists(checkpoint_dir), f"Checkpoint directory {checkpoint_dir} not found") + + # List contents of checkpoint dir + checkpoints = os.listdir(checkpoint_dir) + # Checkpoints are usually named '0/', '1/', etc. + # With steps=1 and steps=2 and checkpoint_period=1, we should have '1' and '2' (or similar). + self.assertTrue(any(c == "1" or c.endswith("1") for c in checkpoints), f"Checkpoint 1 not found in {checkpoints}") + self.assertTrue(any(c == "2" or c.endswith("2") for c in checkpoints), f"Checkpoint 2 not found in {checkpoints}") + + def test_checkpointing_and_resume(self): + """Trains a few steps, saves a checkpoint, and resumes from it.""" + + # 1. Setup minimal dummy model and models bundle + class DummyModel(nnx.Module): + + def __init__(self): + self.linear = nnx.Linear(in_features=2, out_features=2, rngs=nnx.Rngs(0)) + + def __call__(self, input_tokens, **kwargs): + # We need an output compatible with the dummy strategy + return self.linear(jnp.ones((1, 2))) + + student1 = DummyModel() + teacher1 = DummyModel() + bundle1 = train_distill.ModelBundle(teacher_model=teacher1, student_model=student1) + + # 2. Setup strategy and trainer config + strategy = mock.Mock() + strategy.compute_loss.side_effect = lambda s_out, t_out, labels: (jnp.sum(s_out.logits), {"aux": 1.0}) + strategy.labels_fn.return_value = None + strategy.student_forward_fn = lambda model, **kw: distillation_utils.DistillationForwardOutput( + logits=model(kw["input_tokens"]) + ) + strategy.teacher_forward_fn = lambda model, **kw: distillation_utils.DistillationForwardOutput( + logits=model(kw["input_tokens"]) + ) + + config = mock.Mock() + config.checkpoint_dir = self.test_dir + config.dataset_type = "tfds" + config.lora_enabled = False + + # pylint: disable=import-outside-toplevel + from tunix.sft import peft_trainer + + train_config = peft_trainer.TrainingConfig( + max_steps=2, + eval_every_n_steps=0, + checkpointing_options=ocp.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=2, create=True), + gradient_accumulation_steps=1, + ) + + optimizer1 = optax.sgd(0.1) + + trainer1 = train_distill.MaxTextDistillationTrainer( + model=bundle1, + strategy=strategy, + optimizer=optimizer1, + training_config=train_config, + ) + trainer1._lora_enabled = False + trainer1.is_managed_externally = True + + # Mock input mapping + trainer1 = trainer1.with_gen_model_input_fn( + lambda batch: { + "input_tokens": batch.input_tokens, + "positions": batch.positions, + "attention_mask": batch.input_mask, + "decoder_segment_ids": batch.decoder_segment_ids, + "targets": batch.targets, + "targets_position": batch.targets_position, + "targets_segmentation": batch.targets_segmentation, + "cache": None, + } + ) + + # 3. Restore pipeline (creates the MaxTextCheckpointManager) + # pylint: disable=unexpected-keyword-arg + dummy_input = distillation_utils.MaxTextTrainingInput( + input_tokens=jnp.ones((1, 2)), + input_mask=jnp.ones((1, 2), dtype=bool), + ) + dummy_iter = iter([dummy_input, dummy_input]) + + trainer1.setup_checkpoint_manager_and_restore(dummy_iter, config) + + # Train for 2 steps + trainer1.train(dummy_iter, None) + + trainer1.checkpoint_manager.wait_until_finished() + + # Verify checkpoint exists + self.assertEqual(trainer1.checkpoint_manager.latest_step(), 2) + saved_weights = student1.linear.kernel.get_value().copy() + + # 4. Resume + student2 = DummyModel() + teacher2 = DummyModel() + bundle2 = train_distill.ModelBundle(teacher_model=teacher2, student_model=student2) + optimizer2 = optax.sgd(0.1) + + trainer2 = train_distill.MaxTextDistillationTrainer( + model=bundle2, + strategy=strategy, + optimizer=optimizer2, + training_config=train_config, + ) + trainer2._lora_enabled = False + + # Call setup_checkpoint_manager_and_restore to resume + trainer2.setup_checkpoint_manager_and_restore(iter([]), config) + + # We expect _train_steps to be restored to 2 + self.assertEqual(trainer2._train_steps, 2) + + # Verify weights are identical to the trained ones, rather than the fresh ones + np.testing.assert_allclose(student2.linear.kernel.get_value(), saved_weights) + + if hasattr(trainer1.checkpoint_manager, "wait_until_finished"): + trainer1.checkpoint_manager.wait_until_finished() + if hasattr(trainer2.checkpoint_manager, "wait_until_finished"): + trainer2.checkpoint_manager.wait_until_finished() @mock.patch("maxtext.trainers.post_train.distillation.train_distill.distillation_utils.OfflineArrayRecordIterator") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.MaxTextDistillationTrainer")