Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 95 additions & 13 deletions src/maxtext/trainers/post_train/distillation/distillation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tensorflow as tf
from array_record.python import array_record_module

import abc
from typing import Any, Iterator, Optional, List, Callable

import flax
Expand Down Expand Up @@ -182,21 +183,87 @@ def __next__(self) -> MaxTextTrainingInput:
# -----------------------------------------------------------------------------
# Distillation Strategy
# -----------------------------------------------------------------------------
class CombinedDistillationStrategy:


class DistillationStrategy(abc.ABC):
  """Abstract base class for MaxText Distillation Strategies.

  Concrete strategies hold the student/teacher forward callables plus the
  vocabulary metadata needed to build labels, and implement the three hooks
  below (train loss, eval loss, label construction).
  """

  def __init__(
      self, student_forward_fn: Callable, teacher_forward_fn: Callable, vocab_size: int, pad_id: int = 0, **kwargs
  ):
    """Stores the shared pieces every distillation strategy needs.

    Args:
      student_forward_fn: Function to compute student model outputs.
      teacher_forward_fn: Function to compute teacher model outputs.
      vocab_size: The size of the model's vocabulary.
      pad_id: The ID used for padding tokens.
    """
    # Keep both forward callables and the label-building metadata on the
    # instance so subclasses can reach them from every hook.
    self.student_forward_fn = student_forward_fn
    self.teacher_forward_fn = teacher_forward_fn
    self.vocab_size = vocab_size
    self.pad_id = pad_id

  @abc.abstractmethod
  def compute_loss(
      self,
      student_output: "DistillationForwardOutput",
      teacher_output: "DistillationForwardOutput",
      labels: jax.Array,
  ) -> tuple[jax.Array, dict[str, jax.Array]]:
    """Computes the distillation training loss.

    Args:
      student_output: The forward pass output of the student model.
      teacher_output: The forward pass output of the frozen teacher model.
      labels: The masked one-hot encoded ground truth labels.

    Returns:
      A tuple of (scalar loss, dict of auxiliary metrics), e.g.
      {"distill/soft_loss": ..., "distill/total_loss": ...}.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def compute_eval_loss(
      self,
      student_output: "DistillationForwardOutput",
      labels: jax.Array,
  ) -> tuple[jax.Array, dict[str, jax.Array]]:
    """Computes the evaluation loss (typically just the task loss).

    Args:
      student_output: The forward pass output of the student model.
      labels: The masked one-hot encoded ground truth labels.

    Returns:
      A tuple of (scalar loss, empty-or-auxiliary metrics dict).
    """
    raise NotImplementedError

  @abc.abstractmethod
  def create_labels(self, targets: jax.Array, targets_segmentation: Optional[jax.Array] = None, **kwargs) -> jax.Array:
    """Builds the labels tensor used by the loss computations."""
    raise NotImplementedError


class CombinedDistillationStrategy(DistillationStrategy):
"""Strategy that returns detailed metrics for TensorBoard."""

def __init__(
self,
student_forward_fn: Callable[..., DistillationForwardOutput],
teacher_forward_fn: Callable[..., DistillationForwardOutput],
labels_fn: Callable[..., jax.Array],
pad_id: int = 0,
temperature: float = 2.0,
alpha: float = 0.5,
beta_feature: float = 0.0,
layer_indices: Optional[List[int]] = None,
feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
cosine_distance_axis: int | tuple[int, ...] = -1,
sft_mode: bool = False,
vocab_size: int = 0,
):
"""Initializes the Combined strategy using tunix logit.LogitStrategy.

Expand All @@ -213,9 +280,14 @@ def __init__(
cosine_distance_axis: The axis to use for cosine distance computation if
feature_loss_fn is not provided. Defaults to -1.
"""
self.student_forward_fn = student_forward_fn
self.teacher_forward_fn = teacher_forward_fn
self.labels_fn = labels_fn

super().__init__(
student_forward_fn=student_forward_fn,
teacher_forward_fn=teacher_forward_fn,
vocab_size=vocab_size,
pad_id=pad_id,
)

self.temperature = temperature
self.alpha = alpha
self.beta_feature = beta_feature
Expand All @@ -226,7 +298,6 @@ def __init__(
self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
)
self.sft_mode = sft_mode

def compute_loss(
self,
Expand All @@ -253,10 +324,9 @@ def compute_loss(

log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1)
teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1)

# labels are supposed to have all sft masks applied by this moment
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True) if self.sft_mode else None
mean_mask = jnp.squeeze(labels_mask, axis=-1) if labels_mask is not None else None
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
mean_mask = jnp.squeeze(labels_mask, axis=-1)

# KL(Teacher || Student)
kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp, where=labels_mask)
Expand Down Expand Up @@ -297,7 +367,7 @@ def compute_loss(
metrics = {
"distill/soft_loss": soft_loss,
"distill/hard_loss": hard_loss,
"distill/kl_div": jnp.mean(kl_div),
"distill/kl_div": jnp.mean(kl_div, where=mean_mask),
"distill/teacher_loss": teacher_hard_loss,
"distill/out_proj_feature_loss": feature_loss,
"distill/total_loss": total_loss,
Expand All @@ -316,12 +386,24 @@ def compute_eval_loss(
# Parent logic for task loss
# We re-implement simple CE here to ensure float32 casting
s_logits = student_output.logits.astype(jnp.float32)
ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
task_loss = jnp.mean(ce_loss)

labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
mean_mask = jnp.squeeze(labels_mask, axis=-1)
ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
task_loss = jnp.mean(ce_loss, where=mean_mask)

# Must return a tuple because _has_aux=True expects it
return task_loss, {}

def create_labels(self, targets, targets_segmentation=None, **kwargs):
"""Converts integer targets to masked one-hot vectors for hard label loss."""
del kwargs # Unused
one_hot = jax.nn.one_hot(targets, self.vocab_size)
mask = jnp.not_equal(targets, self.pad_id).astype(one_hot.dtype)[..., None]
if targets_segmentation is not None:
mask = mask * (targets_segmentation != 0)[..., None]
return one_hot * mask


# -----------------------------------------------------------------------------
# Checkpoint Manager
Expand Down
17 changes: 4 additions & 13 deletions src/maxtext/trainers/post_train/distillation/train_distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class MaxTextDistillationTrainer(peft_trainer.PeftTrainer):
(positions, segment_ids) are passed to the model.
"""

def __init__(self, model, strategy, optimizer, training_config, **kwargs):
def __init__(self, model, strategy: distillation_utils.DistillationStrategy, optimizer, training_config, **kwargs):
# We pass a dummy optimizer to the base PeftTrainer temporarily to prevent PeftTrainer from eagerly
# allocating massive optimizer states for the entire ModelBundle (including the frozen teacher) before
# redefining the trainer optimizer here.
Expand Down Expand Up @@ -245,7 +245,7 @@ def loss_wrapper(student, teacher, batch):
cache=None,
)
# we should apply a mask for labels to disable segment-separator tokens
labels = self.strategy.labels_fn(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
return self.strategy.compute_loss(student_output, teacher_output, labels)

# Because student is the 0th argument, argnums=0 guarantees
Expand Down Expand Up @@ -274,7 +274,7 @@ def _eval_step(self, model, inputs):
decoder_segment_ids=inputs.get("decoder_segment_ids"),
cache=None,
)
labels = self.strategy.labels_fn(inputs["targets"])
labels = self.strategy.create_labels(inputs["targets"], targets_segmentation=inputs.get("targets_segmentation", None))
return self.strategy.compute_eval_loss(student_output, labels)

def _prepare_inputs(
Expand Down Expand Up @@ -454,14 +454,6 @@ def train_distill(
teacher_model.eval()

# 3. Define Distillation Strategy
def labels_fn(targets, targets_segmentation=None, **kwargs):
"""Converts integer targets to masked one-hot vectors for hard label loss."""
del kwargs # Unused
one_hot = jax.nn.one_hot(targets, student_config.vocab_size)
mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
if targets_segmentation is not None:
mask = mask * (targets_segmentation != 0)[..., None]
return one_hot * mask

# Both Student and Teacher use the same forward logic via the adapter
student_forward_fn = create_forward_fn(student_config)
Expand All @@ -471,12 +463,11 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
strategy = distillation_utils.CombinedDistillationStrategy(
student_forward_fn=student_forward_fn,
teacher_forward_fn=teacher_forward_fn,
labels_fn=labels_fn,
temperature=student_config.distill_temperature,
alpha=student_config.distill_alpha,
beta_feature=student_config.distill_beta,
layer_indices=student_config.distill_layer_indices,
sft_mode=student_config.use_sft,
vocab_size=student_config.vocab_size,
)

# 4. Optimizer & Config
Expand Down
70 changes: 55 additions & 15 deletions tests/post_training/unit/train_distill_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
)

# Verify loss computation and optimizer update
trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
trainer.strategy.create_labels.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
trainer.strategy.compute_loss.assert_called_once()
optimizer.update.assert_called_once_with(student_model, mock_grads)

Expand Down Expand Up @@ -291,7 +291,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_
loss_wrapper(student_model, teacher_model, mock_batch)

# 6. Assertions
trainer.strategy.labels_fn.assert_called_once_with(
trainer.strategy.create_labels.assert_called_once_with(
mock_batch["targets"], targets_segmentation=mock_targets_segmentation
)
trainer.strategy.student_forward_fn.assert_called_once_with(
Expand Down Expand Up @@ -362,12 +362,11 @@ def _test_monitored_strategy(self, sft_mode: bool):
strategy = distillation_utils.CombinedDistillationStrategy(
student_forward_fn=lambda m, **k: None,
teacher_forward_fn=lambda m, **k: None,
labels_fn=lambda t: t,
vocab_size=4,
temperature=1.0,
alpha=0.5,
beta_feature=1.0,
layer_indices=None,
sft_mode=sft_mode,
)

# Dummy inputs (batch=1, seq=2, vocab=4)
Expand Down Expand Up @@ -410,18 +409,15 @@ def _test_monitored_strategy(self, sft_mode: bool):
self.assertLess(metrics["distill/kl_div"], 1e-5)
self.assertLess(metrics["distill/out_proj_feature_loss"], 1e-5)

def test_strategy_compute_eval_loss(self):
self._verify_strategy_compute_eval_loss(sft_mode=False)

def _verify_strategy_compute_eval_loss(self, sft_mode):
def verify_strategy_compute_eval_loss(self):
"""Covers MonitoredLogitStrategy.compute_eval_loss."""
strategy = distillation_utils.CombinedDistillationStrategy(
student_forward_fn=mock.Mock(),
teacher_forward_fn=mock.Mock(),
labels_fn=mock.Mock(),
vocab_size=4,
# student_config=mock_config,
temperature=1.0,
alpha=0.5,
sft_mode=sft_mode,
)
# Case where feature loss is enabled
logits = distillation_utils.DistillationForwardOutput(
Expand All @@ -443,8 +439,51 @@ def _verify_strategy_compute_eval_loss(self, sft_mode):
self.assertTrue(isinstance(loss, jax.Array))
self.assertEqual(aux, {})

def test_strategy_compute_eval_loss_sft(self):
self._verify_strategy_compute_eval_loss(sft_mode=True)
def test_strategy_ignores_segmentation_zero_tokens(self):
"""Verifies that 0 tokens in targets_segmentation are ignored in loss computation."""
strategy = distillation_utils.CombinedDistillationStrategy(
student_forward_fn=mock.Mock(),
teacher_forward_fn=mock.Mock(),
vocab_size=4,
temperature=1.0,
alpha=0.5,
pad_id=0,
)

# 1. Leverage the targets_segmentation tensor and put a 0 token in between.
# Token 1 is a delimiter (targets_segmentation = 0).
targets = jnp.array([[2, 1, 3]])
targets_segmentation = jnp.array([[1, 0, 1]])

# 2. Create labels with the zeroed out segment delimiter mask.
labels = strategy.create_labels(targets, targets_segmentation=targets_segmentation)

# Student has all predictions incorrect
s_logits = jnp.array(
[
[
[10.0, -10.0, -10.0, -10.0],
[-10.0, 10.0, -10.0, -10.0],
[-10.0, 10.0, -10.0, -10.0],
]
] # correct
)
student_output = distillation_utils.DistillationForwardOutput(logits=s_logits, out_projection_activations=None)

# Teacher perfectly predicts the target for Token 0 and Token 2, and class 1 for Token 1
t_logits = jnp.array([[[-10.0, -10.0, 10.0, -10.0], [10.0, -10.0, -10.0, -10.0], [-10.0, -10.0, -10.0, 10.0]]])
teacher_output = distillation_utils.DistillationForwardOutput(logits=t_logits, out_projection_activations=None)

# 3. Call compute_loss()
_, metrics = strategy.compute_loss(student_output, teacher_output, labels)

# all tokens are predicted incorrect so the loss should be 10*2 since
# token at position 1 should be excluded from the loss
# mean kl_div should also be equal to 20
self.assertTrue(19.0 < metrics["distill/hard_loss"] < 21.0)
self.assertTrue(19.0 < metrics["distill/soft_loss"] < 21.0)
self.assertTrue(19.0 < metrics["distill/kl_div"] < 21.0)
self.assertTrue(metrics["distill/teacher_loss"] == 0.0)

def test_setup_pipeline_grain_enabled(self):
"""Covers _setup_and_restore_input_pipeline when Grain IS detected."""
Expand Down Expand Up @@ -515,6 +554,7 @@ def test_eval_step_calls_student_forward(self):
"attention_mask": mock.Mock(),
"decoder_segment_ids": mock.Mock(),
"targets": mock.Mock(),
"targets_segmentation": None,
}
trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch)

Expand All @@ -528,7 +568,7 @@ def test_eval_step_calls_student_forward(self):
trainer.strategy.student_forward_fn.return_value = mock_student_output

mock_labels = mock.Mock()
trainer.strategy.labels_fn.return_value = mock_labels
trainer.strategy.create_labels.return_value = mock_labels

mock_loss = mock.Mock()
trainer.strategy.compute_eval_loss.return_value = mock_loss
Expand Down Expand Up @@ -557,7 +597,7 @@ def test_eval_step_calls_student_forward(self):
trainer.strategy.teacher_forward_fn.assert_not_called()

# Verify loss computation pipeline
trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"])
trainer.strategy.create_labels.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
trainer.strategy.compute_eval_loss.assert_called_once_with(mock_student_output, mock_labels)

# Verify it returns the correct loss
Expand Down Expand Up @@ -643,7 +683,7 @@ def __call__(self, x):
"teacher_output": jnp.array([1.0, 1.0]),
}
trainer.gen_model_input_fn = mock.Mock(return_value=dummy_batch)
trainer.strategy.labels_fn.return_value = None
trainer.strategy.create_labels.return_value = None

# 4. Mock the forward pass to COUNT how many times it executes
# We wrap the actual dummy model execution in a mock to track it.
Expand Down
Loading