Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions src/maxtext/trainers/post_train/distillation/distillation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,28 @@ class DistillationForwardOutput:
"""Dataclass to carry MaxText-specific output fields."""

#: logits
logits: jax.Array = None
logits: jax.Array
#: out_projection_activations
out_projection_activations: jax.Array = None
out_projection_activations: jax.Array | None = None


@flax.struct.dataclass(frozen=True)
class MaxTextTrainingInput(peft_trainer.TrainingInput):
"""Extended TrainingInput dataclass to carry MaxText-specific fields."""

#: Position indices for the tokens (for RoPE).
positions: jax.Array = None
positions: jax.Array | None = None
#: Segment IDs for packed sequences (0=padding, 1+=examples).
decoder_segment_ids: jax.Array = None
decoder_segment_ids: jax.Array | None = None
#: Ground truth target tokens (used for loss calculation and logging).
targets: jax.Array = None
targets: jax.Array | None = None
#: Position indices for the target tokens.
targets_position: jax.Array = None
targets_position: jax.Array | None = None
#: Segment IDs for packed target tokens.
targets_segmentation: jax.Array = None
targets_segmentation: jax.Array | None = None
#: Top-K logits from the teacher model.
top_k_logits: jax.Array = None
top_k_indices: jax.Array = None
top_k_logits: jax.Array | None = None
top_k_indices: jax.Array | None = None


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -275,7 +275,7 @@ def compute_loss(
# 3. Combine losses
base_logit_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)

feature_loss = 0.0
feature_loss = jnp.array(0.0)
if self.beta_feature > 0.0:

if self.layer_indices is not None:
Expand Down Expand Up @@ -420,6 +420,86 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F
force=force,
)

def maybe_restore(
    self,
    model: Any,
    optimizer: Any = None,
    restore_only_lora_params: bool = False,
) -> tuple[int, dict[str, Any]]:
  """Restores model and optimizer state if a checkpoint exists, using correct sharding specs.

  This method checks for the latest available checkpoint. If found, it restores the
  model parameters and optionally the optimizer state in-place. It automatically
  maps the parameter's `sharding` attributes to Orbax restore arguments to ensure
  the tensors are placed on the correct device meshes.

  Args:
    model: The model to restore. If a `ModelBundle` is provided, it automatically
      extracts and restores only the `student_model`.
    optimizer: The optimizer state to restore. If None, optimizer restoration is skipped.
    restore_only_lora_params: If True, restricts restoration to parameters marked
      as `nnx.LoRAParam`.

  Returns:
    A tuple containing the restored step number (0 if no checkpoint was found)
    and a dictionary of custom metadata.
  """
  # No checkpoint manager configured: nothing to restore.
  if self._checkpoint_manager is None:
    return 0, {}

  step = self._checkpoint_manager.latest_step()
  if step is None:
    return 0, {}

  max_logging.log(f"Restoring from checkpoint step {step}...")

  # If a ModelBundle was passed, restore only the wrapped student model.
  target_model = getattr(model, "student_model", model)

  if restore_only_lora_params:
    params = nnx.state(target_model, nnx.LoRAParam)
  else:
    params = nnx.state(target_model)

  def map_to_pspec(data):
    # Map each array's sharding onto Orbax restore args so restored tensors
    # land on the correct device mesh; leaves without a sharding attribute
    # fall back to Orbax's default placement.
    if hasattr(data, "sharding"):
      return checkpoint.type_handlers.ArrayRestoreArgs(sharding=data.sharding)
    return None

  restore_args = jax.tree.map(map_to_pspec, params)

  cp_restore_args = {
      "model_params": checkpoint.args.PyTreeRestore(
          item=params,
          restore_args=restore_args,
      )
  }

  if optimizer is not None:
    optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState)
    opt_restore_args = jax.tree.map(map_to_pspec, optimizer_state)
    cp_restore_args["optimizer_state"] = checkpoint.args.PyTreeRestore(
        item=optimizer_state,
        restore_args=opt_restore_args,
    )

  restored = self._checkpoint_manager.restore(
      step,
      args=checkpoint.args.Composite(**cp_restore_args),
  )

  # Apply restored state in-place.
  nnx.update(target_model, restored.model_params)
  if optimizer is not None:
    nnx.update(optimizer, restored.optimizer_state)

  # Depending on the Orbax version, `metadata(step)` may return either an
  # object exposing a `custom_metadata` attribute or a plain dict. The
  # previous hasattr-only check silently dropped metadata in the dict case.
  metadata = self._checkpoint_manager.metadata(step)
  if isinstance(metadata, dict):
    custom_metadata = metadata
  elif metadata is not None and getattr(metadata, "custom_metadata", None) is not None:
    custom_metadata = metadata.custom_metadata
  else:
    custom_metadata = {}

  return step, dict(custom_metadata)

def restore_iterator(self):
"""Restores the iterator using MaxText's logic."""
if self._checkpoint_manager is None or self._iterator is None:
Expand Down
Loading
Loading