From ce6fce8c4fc661890714362fc419d65966509d54 Mon Sep 17 00:00:00 2001 From: vlad-karp Date: Fri, 20 Mar 2026 23:31:57 +0000 Subject: [PATCH 1/2] fixed eod masking + refactored distill strategy --- .../distillation/distillation_utils.py | 111 ++++++++++++++++-- .../post_train/distillation/train_distill.py | 32 +++-- .../post_training/unit/train_distill_test.py | 74 +++++++++--- 3 files changed, 176 insertions(+), 41 deletions(-) diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index 6768d1e466..6cfc3a5b68 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -18,10 +18,14 @@ model structures with Tunix's training interfaces. """ +<<<<<<< Updated upstream import pickle import tensorflow as tf from array_record.python import array_record_module +======= +import abc +>>>>>>> Stashed changes from typing import Any, Iterator, Optional, List, Callable import flax @@ -182,21 +186,87 @@ def __next__(self) -> MaxTextTrainingInput: # ----------------------------------------------------------------------------- # Distillation Strategy # ----------------------------------------------------------------------------- -class CombinedDistillationStrategy: + + +class DistillationStrategy(abc.ABC): + """Abstract base class for MaxText Distillation Strategies.""" + + def __init__( + self, student_forward_fn: Callable, teacher_forward_fn: Callable, vocab_size: int, pad_id: int = 0, **kwargs + ): + """Initializes the generic distillation strategy. + + Args: + student_forward_fn: Function to compute student model outputs. + teacher_forward_fn: Function to compute teacher model outputs. + vocab_size: The size of the model's vocabulary. + pad_id: The ID used for padding tokens. + """ + self.student_forward_fn = student_forward_fn + self.teacher_forward_fn = teacher_forward_fn + self.vocab_size = vocab_size + self.pad_id = pad_id + + @abc.abstractmethod + def compute_loss( + self, + student_output: "DistillationForwardOutput", + teacher_output: "DistillationForwardOutput", + labels: jax.Array, + ) -> tuple[jax.Array, dict[str, jax.Array]]: + """Computes the distillation loss. + + Args: + student_output: The forward pass output of the student model. + teacher_output: The forward pass output of the frozen teacher model. + labels: The masked one-hot encoded ground truth labels. + + Returns: + A tuple containing the scalar loss and a dictionary of auxiliary metrics + (e.g., {"distill/soft_loss": ..., "distill/total_loss": ...}) + """ + raise NotImplementedError + + @abc.abstractmethod + def compute_eval_loss( + self, + student_output: "DistillationForwardOutput", + labels: jax.Array, + ) -> tuple[jax.Array, dict[str, jax.Array]]: + """Computes the evaluation loss (typically just the task loss). + + Args: + student_output: The forward pass output of the student model. + labels: The masked one-hot encoded ground truth labels. + + Returns: + A tuple containing the scalar loss and an empty (or auxiliary) dict. + """ + raise NotImplementedError + + @abc.abstractmethod + def create_labels(self, targets: jax.Array, targets_segmentation: Optional[jax.Array] = None, **kwargs) -> jax.Array: + """ + Creates labels tensor to compute the loss + """ + raise NotImplementedError + + +class CombinedDistillationStrategy(DistillationStrategy): """Strategy that returns detailed metrics for TensorBoard.""" def __init__( self, student_forward_fn: Callable[..., DistillationForwardOutput], teacher_forward_fn: Callable[..., DistillationForwardOutput], - labels_fn: Callable[..., jax.Array], + pad_id: int = 0, temperature: float = 2.0, alpha: float = 0.5, beta_feature: float = 0.0, layer_indices: Optional[List[int]] = None, feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None, cosine_distance_axis: int | tuple[int, ...] = -1, - sft_mode: bool = False, + vocab_size: int = 0, ): """Initializes the Combined strategy using tunix logit.LogitStrategy. @@ -213,9 +283,14 @@ def __init__( cosine_distance_axis: The axis to use for cosine distance computation if feature_loss_fn is not provided. Defaults to -1. """ - self.student_forward_fn = student_forward_fn - self.teacher_forward_fn = teacher_forward_fn - self.labels_fn = labels_fn + + super().__init__( + student_forward_fn=student_forward_fn, + teacher_forward_fn=teacher_forward_fn, + vocab_size=vocab_size, + pad_id=pad_id, + ) + self.temperature = temperature self.alpha = alpha self.beta_feature = beta_feature @@ -226,7 +301,6 @@ def __init__( self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean( optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis) ) - self.sft_mode = sft_mode def compute_loss( self, @@ -253,10 +327,9 @@ def compute_loss( log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1) teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1) - # labels are supposed to have all sft masks applied by this moment - labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True) if self.sft_mode else None - mean_mask = jnp.squeeze(labels_mask, axis=-1) if labels_mask is not None else None + labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True) + mean_mask = jnp.squeeze(labels_mask, axis=-1) # KL(Teacher || Student) kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp, where=labels_mask) @@ -297,7 +370,7 @@ def compute_loss( metrics = { "distill/soft_loss": soft_loss, "distill/hard_loss": hard_loss, - "distill/kl_div": jnp.mean(kl_div), + "distill/kl_div": jnp.mean(kl_div, where=mean_mask), "distill/teacher_loss": teacher_hard_loss, "distill/out_proj_feature_loss": feature_loss, "distill/total_loss": total_loss, @@ -316,12 +389,24 @@ def compute_eval_loss( # Parent logic for task loss # We re-implement simple CE here to ensure float32 casting s_logits = student_output.logits.astype(jnp.float32) - ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels) - task_loss = jnp.mean(ce_loss) + + labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True) + mean_mask = jnp.squeeze(labels_mask, axis=-1) + ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask) + task_loss = jnp.mean(ce_loss, where=mean_mask) # Must return a tuple because _has_aux=True expects it return task_loss, {} + def create_labels(self, targets, targets_segmentation=None, **kwargs): + """Converts integer targets to masked one-hot vectors for hard label loss.""" + del kwargs # Unused + one_hot = jax.nn.one_hot(targets, self.vocab_size) + mask = jnp.not_equal(targets, self.pad_id).astype(one_hot.dtype)[..., None] + if targets_segmentation is not None: + mask = mask * (targets_segmentation != 0)[..., None] + return one_hot * mask + # ----------------------------------------------------------------------------- # Checkpoint Manager diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index b8d76bab7c..53f4c1f3a7 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -197,7 +197,7 @@ class MaxTextDistillationTrainer(peft_trainer.PeftTrainer): (positions, segment_ids) are passed to the model. """ - def __init__(self, model, strategy, optimizer, training_config, **kwargs): + def __init__(self, model, strategy: distillation_utils.DistillationStrategy, optimizer, training_config, **kwargs): # We pass a dummy optimizer to the base PeftTrainer temporarily to prevent PeftTrainer from eagerly # allocating massive optimizer states for the entire ModelBundle (including the frozen teacher) before # redefining the trainer optimizer here. @@ -245,7 +245,7 @@ def loss_wrapper(student, teacher, batch): cache=None, ) # we should apply a mask for labels to disable segment-separator tokens - labels = self.strategy.labels_fn(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None)) + labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None)) return self.strategy.compute_loss(student_output, teacher_output, labels) # Because student is the 0th argument, argnums=0 guarantees @@ -274,7 +274,7 @@ def _eval_step(self, model, inputs): decoder_segment_ids=inputs.get("decoder_segment_ids"), cache=None, ) - labels = self.strategy.labels_fn(inputs["targets"]) + labels = self.strategy.create_labels(inputs["targets"], targets_segmentation=inputs.get("targets_segmentation", None)) return self.strategy.compute_eval_loss(student_output, labels) def _prepare_inputs( @@ -454,14 +454,6 @@ def train_distill( teacher_model.eval() # 3. Define Distillation Strategy - def labels_fn(targets, targets_segmentation=None, **kwargs): - """Converts integer targets to masked one-hot vectors for hard label loss.""" - del kwargs # Unused - one_hot = jax.nn.one_hot(targets, student_config.vocab_size) - mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None] - if targets_segmentation is not None: - mask = mask * (targets_segmentation != 0)[..., None] - return one_hot * mask # Both Student and Teacher use the same forward logic via the adapter student_forward_fn = create_forward_fn(student_config) @@ -471,12 +463,11 @@ def labels_fn(targets, targets_segmentation=None, **kwargs): strategy = distillation_utils.CombinedDistillationStrategy( student_forward_fn=student_forward_fn, teacher_forward_fn=teacher_forward_fn, - labels_fn=labels_fn, temperature=student_config.distill_temperature, alpha=student_config.distill_alpha, beta_feature=student_config.distill_beta, layer_indices=student_config.distill_layer_indices, - sft_mode=student_config.use_sft, + vocab_size=student_config.vocab_size, ) # 4. Optimizer & Config @@ -539,6 +530,7 @@ def labels_fn(targets, targets_segmentation=None, **kwargs): raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config) # 8. Configure Input Mapping +<<<<<<< Updated upstream def custom_gen_model_input_fn(batch): inputs_dict = { "input_tokens": batch.input_tokens, @@ -568,6 +560,20 @@ def custom_gen_model_input_fn(batch): return inputs_dict trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn) +======= + trainer = trainer.with_gen_model_input_fn( + lambda batch: { + "input_tokens": batch.input_tokens, + "positions": batch.positions, + "attention_mask": batch.input_mask, + "decoder_segment_ids": batch.decoder_segment_ids, + "targets": batch.targets, # Passed to strategy (create_labels) + "targets_position": batch.targets_position, # Passed to strategy (create_labels) + "targets_segmentation": batch.targets_segmentation, # Passed to strategy (create_labels) + "cache": None, + } + ) +>>>>>>> Stashed changes # 9. Create Iterator Wrappers (Use Utils) train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter) diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index 1604b7d86c..d93161a9f0 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -247,7 +247,7 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a ) # Verify loss computation and optimizer update - trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"], targets_segmentation=None) + trainer.strategy.create_labels.assert_called_once_with(mock_batch["targets"], targets_segmentation=None) trainer.strategy.compute_loss.assert_called_once() optimizer.update.assert_called_once_with(student_model, mock_grads) @@ -291,7 +291,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_ loss_wrapper(student_model, teacher_model, mock_batch) # 6. Assertions - trainer.strategy.labels_fn.assert_called_once_with( + trainer.strategy.create_labels.assert_called_once_with( mock_batch["targets"], targets_segmentation=mock_targets_segmentation ) trainer.strategy.student_forward_fn.assert_called_once_with( @@ -359,15 +359,16 @@ def test_monitored_strategy_sft(self): def _test_monitored_strategy(self, sft_mode: bool): """Verifies the strategy calculates metrics and returns the correct tuple.""" + mock_config = mock.Mock() + mock_config.vocab_size = 4 strategy = distillation_utils.CombinedDistillationStrategy( student_forward_fn=lambda m, **k: None, teacher_forward_fn=lambda m, **k: None, - labels_fn=lambda t: t, + vocab_size=mock_config.vocab_size, temperature=1.0, alpha=0.5, beta_feature=1.0, layer_indices=None, - sft_mode=sft_mode, ) # Dummy inputs (batch=1, seq=2, vocab=4) @@ -410,18 +411,17 @@ def _test_monitored_strategy(self, sft_mode: bool): self.assertLess(metrics["distill/kl_div"], 1e-5) self.assertLess(metrics["distill/out_proj_feature_loss"], 1e-5) - def test_strategy_compute_eval_loss(self): - self._verify_strategy_compute_eval_loss(sft_mode=False) - - def _verify_strategy_compute_eval_loss(self, sft_mode): + def verify_strategy_compute_eval_loss(self): """Covers MonitoredLogitStrategy.compute_eval_loss.""" + mock_config = mock.Mock() + mock_config.vocab_size = 4 strategy = distillation_utils.CombinedDistillationStrategy( student_forward_fn=mock.Mock(), teacher_forward_fn=mock.Mock(), - labels_fn=mock.Mock(), + vocab_size=mock_config.vocab_size, + # student_config=mock_config, temperature=1.0, alpha=0.5, - sft_mode=sft_mode, ) # Case where feature loss is enabled logits = distillation_utils.DistillationForwardOutput( @@ -443,8 +443,51 @@ def _verify_strategy_compute_eval_loss(self, sft_mode): self.assertTrue(isinstance(loss, jax.Array)) self.assertEqual(aux, {}) - def test_strategy_compute_eval_loss_sft(self): - self._verify_strategy_compute_eval_loss(sft_mode=True) + def test_strategy_ignores_segmentation_zero_tokens(self): + """Verifies that 0 tokens in targets_segmentation are ignored in loss computation.""" + strategy = distillation_utils.CombinedDistillationStrategy( + student_forward_fn=mock.Mock(), + teacher_forward_fn=mock.Mock(), + vocab_size=4, + temperature=1.0, + alpha=0.5, + pad_id=0, + ) + + # 1. Leverage the targets_segmentation tensor and put a 0 token in between. + # Token 1 is a delimiter (targets_segmentation = 0). + targets = jnp.array([[2, 1, 3]]) + targets_segmentation = jnp.array([[1, 0, 1]]) + + # 2. Create labels with the zeroed out segment delimiter mask. + labels = strategy.create_labels(targets, targets_segmentation=targets_segmentation) + + # Student has all predictions incorrect + s_logits = jnp.array( + [ + [ + [10.0, -10.0, -10.0, -10.0], + [-10.0, 10.0, -10.0, -10.0], + [-10.0, 10.0, -10.0, -10.0], + ] + ] # correct + ) + student_output = distillation_utils.DistillationForwardOutput(logits=s_logits, out_projection_activations=None) + + # Teacher perfectly predicts the target for Token 0 and Token 2, and class 1 for Token 1 + t_logits = jnp.array([[[-10.0, -10.0, 10.0, -10.0], [10.0, -10.0, -10.0, -10.0], [-10.0, -10.0, -10.0, 10.0]]]) + teacher_output = distillation_utils.DistillationForwardOutput(logits=t_logits, out_projection_activations=None) + + # 3. Call compute_loss() + _, metrics = strategy.compute_loss(student_output, teacher_output, labels) + + # all tokens are predicted incorrect so the loss should be 10*2 since + # token at position 1 should be excluded from the loss + # mean kl_div should also be equal to 20 + self.assertTrue(19.0 < metrics["distill/hard_loss"] < 21.0) + self.assertTrue(19.0 < metrics["distill/soft_loss"] < 21.0) + self.assertTrue(19.0 < metrics["distill/kl_div"] < 21.0) + self.assertTrue(metrics["distill/teacher_loss"] == 0.0) def test_setup_pipeline_grain_enabled(self): """Covers _setup_and_restore_input_pipeline when Grain IS detected.""" @@ -515,6 +558,7 @@ def test_eval_step_calls_student_forward(self): "attention_mask": mock.Mock(), "decoder_segment_ids": mock.Mock(), "targets": mock.Mock(), + "targets_segmentation": None, } trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch) @@ -528,7 +572,7 @@ def test_eval_step_calls_student_forward(self): trainer.strategy.student_forward_fn.return_value = mock_student_output mock_labels = mock.Mock() - trainer.strategy.labels_fn.return_value = mock_labels + trainer.strategy.create_labels.return_value = mock_labels mock_loss = mock.Mock() trainer.strategy.compute_eval_loss.return_value = mock_loss @@ -557,7 +601,7 @@ def test_eval_step_calls_student_forward(self): trainer.strategy.teacher_forward_fn.assert_not_called() # Verify loss computation pipeline - trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"]) + trainer.strategy.create_labels.assert_called_once_with(mock_batch["targets"], targets_segmentation=None) trainer.strategy.compute_eval_loss.assert_called_once_with(mock_student_output, mock_labels) # Verify it returns the correct loss @@ -643,7 +687,7 @@ def __call__(self, x): "teacher_output": jnp.array([1.0, 1.0]), } trainer.gen_model_input_fn = mock.Mock(return_value=dummy_batch) - trainer.strategy.labels_fn.return_value = None + trainer.strategy.create_labels.return_value = None # 4. Mock the forward pass to COUNT how many times it executes # We wrap the actual dummy model execution in a mock to track it. From 699ca4f25bec8acb99c2f64073d0a3c441aa782d Mon Sep 17 00:00:00 2001 From: vlad-karp Date: Fri, 20 Mar 2026 23:37:57 +0000 Subject: [PATCH 2/2] fix merge --- .../post_train/distillation/distillation_utils.py | 3 --- .../post_train/distillation/train_distill.py | 15 --------------- tests/post_training/unit/train_distill_test.py | 8 ++------ 3 files changed, 2 insertions(+), 24 deletions(-) diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index 6cfc3a5b68..c33eb84bd1 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -18,14 +18,11 @@ model structures with Tunix's training interfaces. """ -<<<<<<< Updated upstream import pickle import tensorflow as tf from array_record.python import array_record_module -======= import abc ->>>>>>> Stashed changes from typing import Any, Iterator, Optional, List, Callable import flax diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 53f4c1f3a7..2869574577 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -530,7 +530,6 @@ def train_distill( raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config) # 8. Configure Input Mapping -<<<<<<< Updated upstream def custom_gen_model_input_fn(batch): inputs_dict = { "input_tokens": batch.input_tokens, @@ -560,20 +559,6 @@ def custom_gen_model_input_fn(batch): return inputs_dict trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn) -======= - trainer = trainer.with_gen_model_input_fn( - lambda batch: { - "input_tokens": batch.input_tokens, - "positions": batch.positions, - "attention_mask": batch.input_mask, - "decoder_segment_ids": batch.decoder_segment_ids, - "targets": batch.targets, # Passed to strategy (create_labels) - "targets_position": batch.targets_position, # Passed to strategy (create_labels) - "targets_segmentation": batch.targets_segmentation, # Passed to strategy (create_labels) - "cache": None, - } - ) ->>>>>>> Stashed changes # 9. Create Iterator Wrappers (Use Utils) train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter) diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index d93161a9f0..4565c2a3d6 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -359,12 +359,10 @@ def test_monitored_strategy_sft(self): def _test_monitored_strategy(self, sft_mode: bool): """Verifies the strategy calculates metrics and returns the correct tuple.""" - mock_config = mock.Mock() - mock_config.vocab_size = 4 strategy = distillation_utils.CombinedDistillationStrategy( student_forward_fn=lambda m, **k: None, teacher_forward_fn=lambda m, **k: None, - vocab_size=mock_config.vocab_size, + vocab_size=4, temperature=1.0, alpha=0.5, beta_feature=1.0, @@ -413,12 +411,10 @@ def _test_monitored_strategy(self, sft_mode: bool): def verify_strategy_compute_eval_loss(self): """Covers MonitoredLogitStrategy.compute_eval_loss.""" - mock_config = mock.Mock() - mock_config.vocab_size = 4 strategy = distillation_utils.CombinedDistillationStrategy( student_forward_fn=mock.Mock(), teacher_forward_fn=mock.Mock(), - vocab_size=mock_config.vocab_size, + vocab_size=4, # student_config=mock_config, temperature=1.0, alpha=0.5,