diff --git a/modelopt/torch/puzzletron/__init__.py b/modelopt/torch/puzzletron/__init__.py index 15389dedfa2..0af53b5cef3 100644 --- a/modelopt/torch/puzzletron/__init__.py +++ b/modelopt/torch/puzzletron/__init__.py @@ -19,6 +19,7 @@ anymodel, block_config, build_library_and_stats, + bypass_distillation, dataset, entrypoint, mip, diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py index 3c1749d46ec..58b045bd21c 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py @@ -169,6 +169,19 @@ def uses_autocast() -> bool: """ return True + @staticmethod + def pruning_mixins() -> Dict[str, Any]: + """Return available pruning mixins for bypass distillation. + + Override in subclasses to provide model-specific pruning mixins, e.g. + ``{"kv_heads": KVHeadsPruningMixIn(...), "experts_removal": ExpertRemovalPruningMixIn(...)}``. + + Returns an empty dict by default so that descriptors that do not need + model-specific weight-slicing (e.g. Llama with standard FFN truncation) + can rely on the generic ``create_child_state_dict`` fallback path. + """ + return {} + @staticmethod def get_language_model_config(config): """Get the language model config from a PretrainedConfig. diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py index c8fd86b4bb6..1abecdec0c2 100644 --- a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py @@ -28,6 +28,7 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn # Expert removal is supported for unquantized models (test models). # Production models use MXFP4 quantized MoE with combined tensors @@ -37,7 +38,11 @@ from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size -__all__ = ["GptOssModelDescriptor", "GptOssExpertRemovalLayerDescriptor"] +__all__ = [ + "GptOssExpertRemovalLayerDescriptor", + "GptOssKVHeadsLayerDescriptor", + "GptOssModelDescriptor", +] @ModelDescriptorFactory.register_decorator("gpt_oss") @@ -173,7 +178,29 @@ def pruning_mixins() -> Dict[str, PruningMixIn]: Note: Expert removal works for unquantized models (test models). Production models use MXFP4 quantization which is not yet supported. """ - return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())} + # Single instance shared between the canonical key and the legacy alias + # so resolve_pruning_mixin returns the same object regardless of which + # name a caller uses. + expert_mixin = ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor()) + return { + "experts_removal": expert_mixin, + # Backward-compat alias: this key was "expert_removal" before the + # bypass branch standardised on "experts_removal" (matching the + # NemotronH descriptor). Kept so external scripts that still call + # `resolve_pruning_mixin("expert_removal", GptOssModelDescriptor)` + # continue to work. Remove after a deprecation cycle. 
+ "expert_removal": expert_mixin, + "kv_heads": KVHeadsPruningMixIn(GptOssKVHeadsLayerDescriptor()), + } + + +@dataclass +class GptOssKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) @dataclass diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py index 1c5706d1944..b3f33887367 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py @@ -29,11 +29,16 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same -__all__ = ["NemotronHExpertRemovalLayerDescriptor", "NemotronHModelDescriptor"] +__all__ = [ + "NemotronHExpertRemovalLayerDescriptor", + "NemotronHKVHeadsLayerDescriptor", + "NemotronHModelDescriptor", +] def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: @@ -51,6 +56,15 @@ def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: return matches +@dataclass +class NemotronHKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "mixer.o_proj" + attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + @dataclass class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): target_name: str = "mixer.gate" @@ -251,4 +265,5 @@ def build_attention_predicates() -> Dict[str, re.Pattern]: def pruning_mixins() -> Dict[str, PruningMixIn]: return { "experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(NemotronHKVHeadsLayerDescriptor()), } diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py index a1e326f2357..0c677f67542 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py @@ -29,11 +29,16 @@ FFNIntermediateLayerDescriptor, FFNIntermediatePruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same -__all__ = ["NemotronHV2FFNIntermediateLayerDescriptor", "NemotronHV2ModelDescriptor"] +__all__ = [ + "NemotronHV2FFNIntermediateLayerDescriptor", + "NemotronHV2KVHeadsLayerDescriptor", + "NemotronHV2ModelDescriptor", +] def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: @@ -69,6 +74,15 @@ class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"]) +@dataclass +class 
NemotronHV2KVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "mixer.o_proj" + attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + @ModelDescriptorFactory.register_decorator("nemotron_h_v2") class NemotronHV2ModelDescriptor(ModelDescriptor): _DECODER_LAYER_CLS: Type[nn.Module] = None @@ -251,5 +265,6 @@ def pruning_mixins() -> Dict[str, PruningMixIn]: "ffn_intermediate": FFNIntermediatePruningMixIn( NemotronHV2FFNIntermediateLayerDescriptor() ), + "kv_heads": KVHeadsPruningMixIn(NemotronHV2KVHeadsLayerDescriptor()), # TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated } diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py index aeedd419923..a0f9c95c6ce 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py @@ -26,9 +26,13 @@ ) from ....block_config import BlockConfig -from ....pruning.expert_removal_pruning_mixin import ExpertRemovalLayerDescriptor +from ....pruning.expert_removal_pruning_mixin import ( + ExpertRemovalLayerDescriptor, + ExpertRemovalPruningMixIn, +) from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor -from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn +from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size @@ -56,6 +60,13 @@ def get_language_model_config(config): """Qwen3-VL has nested text_config for language model parameters.""" return config.text_config if hasattr(config, "text_config") else config + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + return { + "experts_removal": ExpertRemovalPruningMixIn(Qwen3VLExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(Qwen3VLKVHeadsLayerDescriptor()), + } + @staticmethod def decoder_layer_cls(): return Qwen3VLMoeTextDecoderLayer diff --git a/modelopt/torch/puzzletron/bypass_distillation/__init__.py b/modelopt/torch/puzzletron/bypass_distillation/__init__.py new file mode 100644 index 00000000000..119cbd5cdaf --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bypass distillation (blockwise local distillation) for the PUZZLE framework. 
+ +This module implements Stage 1 of the PUZZLE pipeline: training alternative transformer +block configurations using per-block knowledge distillation from a teacher model. +""" + +from .training_loop import launch_bypass_distillation + +__all__ = ["launch_bypass_distillation"] diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py new file mode 100644 index 00000000000..c1677900fe7 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py @@ -0,0 +1,263 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint utilities for bypass distillation.""" + +import os +import re +from collections import OrderedDict +from pathlib import Path +from typing import Optional, Union + +import torch +from omegaconf import DictConfig +from tqdm import tqdm + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint_from_shards +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.utils.robust_json import json_dump + +from .bypass_utils import load_bypass_state, update_bypass_checkpoint_state +from .stitched_model_factory import StitchedModuleDescriptor + + +def find_latest_run_dir(run_parent_dir: Union[str, Path]) -> str | None: + """Find the latest plain-step checkpoint directory within a run parent directory. + + Resume prefers the manifest's final checkpoint, then the latest plain step + checkpoint. It must not pick ``best-step-*`` because validation-best snapshots + can be stale relative to the latest optimizer state, nor ``start-step-*``. + """ + run_parent_dir = Path(run_parent_dir) + + state = load_bypass_state(run_parent_dir) + if state is not None: + checkpoints = state.get("checkpoints", {}) + for role in ("final", "resume"): + candidate = checkpoints.get(role) + if candidate and (Path(candidate) / "saving_completed").exists(): + return str(candidate) + + # Check for the "latest" symlink. Current checkpoints only update it for + # plain periodic resume checkpoints, but older runs may have pointed it at a + # best/start/final checkpoint. Validate the target name before accepting it. + latest_dir = run_parent_dir / "latest" + if ( + latest_dir.exists() + and re.match(r"^step-\d+-ckpt$", latest_dir.resolve().name) + and (latest_dir / "saving_completed").exists() + ): + return str(latest_dir) + + # Fallback: scan plain ``step-NNNNNN-ckpt`` directories only. + # Treat a missing parent dir as "no previous runs" rather than fatal — this + # handles two cases cleanly: a freshly-wiped bypass dir, and the race where + # non-master ranks reach this function before master finishes the + # ``set_experiment_dir`` mkdir on a shared filesystem. 
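+    # Illustrative layout of a run parent dir (names are examples only):
+    #   step-000500-ckpt/saving_completed        <- resumable
+    #   step-001000-ckpt/                        <- save did not complete, skipped
+    #   best-step-000750-ckpt/saving_completed   <- validation-best, never resumed from
+    # The fallback scan below returns the highest-numbered plain step directory
+    # whose save completed.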
+ if not run_parent_dir.exists(): + return None + step_re = re.compile(r"^step-(\d+)-ckpt$") + candidate_dirs: list[tuple[int, Path]] = [] + for d in run_parent_dir.iterdir(): + if not d.is_dir(): + continue + match = step_re.match(d.name) + if match: + candidate_dirs.append((int(match.group(1)), d)) + + if not candidate_dirs: + return None + + candidate_dirs.sort(key=lambda x: x[0], reverse=True) + for _, ckpt_dir in candidate_dirs: + if (ckpt_dir / "saving_completed").exists(): + return str(ckpt_dir) + return None + + +def load_local_state( + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_path: str | Path, +) -> None: + """Load optimizer and grad-scaler state for each stitched module. + + Weights are NOT loaded here — they live in the HF checkpoint at + ``checkpoint_path`` and must be loaded into the student model via + ``load_and_shard_model`` before this function runs (typically by setting + ``init_checkpoint_path`` to the resume directory). This avoids + persisting the same parameters twice (once in ``stitched/*.pth`` and + once in the HF state dict). + + Modifies ``stitched_module_descriptors`` in place. + """ + device = torch.device(f"cuda:{dist.local_rank()}") + load_dir = Path(checkpoint_path) + + if not load_dir.exists(): + raise RuntimeError(f'Can\'t load local state. "{load_dir}" does not exist.') + + for stitched_module_name, stitched_module_descriptor in stitched_module_descriptors.items(): + optimizer = stitched_module_descriptor.optimizer + grad_scaler = stitched_module_descriptor.grad_scaler + + if optimizer is not None: + optimizer_state_path = ( + load_dir / "stitched" / f"{stitched_module_name}.optimizer_state.pth" + ) + mprint( + f"Loading optimizer state for module {stitched_module_name} from {optimizer_state_path}" + ) + loaded_optimizer_state = torch.load( + optimizer_state_path, map_location=device, weights_only=True + ) + optimizer.load_state_dict(loaded_optimizer_state) + del loaded_optimizer_state + + # Restore GradScaler state (only relevant when use_grad_scaling=True; for the + # default bf16 / use_grad_scaling=False path the scaler is disabled and its + # state is a no-op, but we still load it if present for forward-compatibility). + # Older checkpoints predating this save path won't have the file — skip silently. + if grad_scaler is not None: + grad_scaler_state_path = ( + load_dir / "stitched" / f"{stitched_module_name}.grad_scaler.pth" + ) + if grad_scaler_state_path.exists(): + mprint( + f"Loading grad_scaler state for module {stitched_module_name} " + f"from {grad_scaler_state_path}" + ) + loaded_scaler_state = torch.load( + grad_scaler_state_path, map_location=device, weights_only=True + ) + grad_scaler.load_state_dict(loaded_scaler_state) + del loaded_scaler_state + + +def _save_local_file(obj, save_path: Path | str, overwrite=True): + save_path = Path(save_path) + if save_path.exists(): + if not overwrite: + mprint(f'WARNING: Local save path "{save_path}" already exists. Skipping') + return + torch.save(obj, save_path) + + +def _save_local_state( + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_dir: Path | str, + overwrite=True, +) -> None: + """Persist optimizer and grad-scaler state for each stitched module. + + Weights are intentionally NOT saved here. 
The same trainable parameters + would otherwise land on disk twice — once as ``stitched/{block}.state_dict.pth`` + and once as part of the HF checkpoint that ``save_bypass_checkpoint`` + writes at the top level via ``save_checkpoint(model, ...)``. The HF + checkpoint is the single source of truth for weights; this directory + only carries the optimizer/scaler state that the HF format doesn't + cover. + """ + save_dir = Path(checkpoint_dir) / "stitched" + + if dist.is_master(): + save_dir.mkdir(parents=True, exist_ok=True) + + # Main process creates the directory, so we must wait for it to finish + dist.barrier() + + for stitched_module_name, stitched_module_descriptor in tqdm( + stitched_module_descriptors.items() + ): + optimizer = stitched_module_descriptor.optimizer + grad_scaler = stitched_module_descriptor.grad_scaler + + if optimizer is not None: + optimizer_state_path = save_dir / f"{stitched_module_name}.optimizer_state.pth" + aprint( + f"Saving optimizer state for module {stitched_module_name} to {optimizer_state_path}" + ) + _save_local_file(optimizer.state_dict(), optimizer_state_path, overwrite=overwrite) + + # Persist GradScaler state. Required for correct resume when + # use_grad_scaling=True (state dict carries running scale + growth tracker). + # For the default bf16 / use_grad_scaling=False path the state dict is trivial + # but cheap, so save unconditionally whenever a scaler exists — keeps the + # save/load paths symmetric with the optimizer. + if grad_scaler is not None: + grad_scaler_state_path = save_dir / f"{stitched_module_name}.grad_scaler.pth" + mprint( + f"Saving grad_scaler state for module {stitched_module_name} " + f"to {grad_scaler_state_path}" + ) + _save_local_file(grad_scaler.state_dict(), grad_scaler_state_path, overwrite=overwrite) + + dist.barrier() + + +def save_bypass_checkpoint( + cfg: DictConfig, + descriptor: ModelDescriptor, + model: torch.nn.Module, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_dir: Path | str, + reference_checkpoint_dir: Optional[Path] = None, + checkpoint_role: str = "resume", +) -> None: + """Save a bypass distillation checkpoint.""" + checkpoint_dir = Path(checkpoint_dir) + mprint("Starting checkpoint save") + mprint(f"Saving checkpoint to {checkpoint_dir}") + + # Save stitched module states + _save_local_state( + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=checkpoint_dir, + overwrite=cfg.bypass.model.model_overrides.delete_old_checkpoints, + ) + # Save as HF checkpoint. Must use the gather-aware variant: bypass training is + # pipeline-parallel so each rank's `model.state_dict()` only carries its own + # owned blocks. The unsharded `save_checkpoint` would have every rank write a + # partial `model.safetensors.index.json` to the same path (last writer wins), + # producing an index that omits most ranks' weights — resume then leaves params + # on the meta device. + save_checkpoint_from_shards(model=model, checkpoint_dir=checkpoint_dir, descriptor=descriptor) + + if dist.is_master(): + if checkpoint_role == "resume": + # Create 'latest' symlink via tmp-symlink + atomic rename so concurrent + # readers on a shared filesystem never observe a missing `latest`. The + # plain unlink + symlink_to pair leaves a brief window where the link + # doesn't exist; Path.replace (== os.replace) is atomic on POSIX. 
+ latest_symlink = Path(cfg.bypass.experiment_dir) / "latest" + tmp_symlink = latest_symlink.with_name(f".latest_tmp_{os.getpid()}") + tmp_symlink.unlink(missing_ok=True) + tmp_symlink.symlink_to(checkpoint_dir.name) + tmp_symlink.replace(latest_symlink) + # Save config args json + json_dump(cfg.bypass, checkpoint_dir / "args.json") + model_factory_cfg = cfg.bypass.get("model_factory", {}) + json_dump( + {"keys_to_learn": model_factory_cfg.get("keys_to_learn", "entire_block")}, + checkpoint_dir / "bypass_config.json", + ) + # Save completed file + completed_file = checkpoint_dir / "saving_completed" + completed_file.touch() + update_bypass_checkpoint_state(cfg, checkpoint_dir, checkpoint_role) + + dist.barrier() + mprint("Checkpoint save done") diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py new file mode 100644 index 00000000000..1a5c7feb21f --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for bypass distillation.""" + +import hashlib +import json +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig, ListConfig, OmegaConf + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.utils.robust_json import json_dump, json_load + +BYPASS_STATE_FILENAME = "bypass_state.json" +BYPASS_SUBBLOCK_KEYS_TO_LEARN = frozenset( + {"subblock_ffn", "subblock_attention", "subblock_mamba", "entire_block"} +) + + +def _to_plain_container(value: Any) -> Any: + if isinstance(value, (DictConfig, ListConfig)): + return OmegaConf.to_container(value, resolve=True) + return value + + +def normalize_keys_to_learn(keys_to_learn: Any) -> dict[str, Any]: + """Normalize bypass ``keys_to_learn`` into v1 subblock semantics.""" + keys_to_learn = _to_plain_container(keys_to_learn) + if isinstance(keys_to_learn, str): + if keys_to_learn in BYPASS_SUBBLOCK_KEYS_TO_LEARN: + return {"mode": "subblocks", "subblocks": (keys_to_learn,)} + raise ValueError( + "keys_to_learn must be one of " + f"{sorted(BYPASS_SUBBLOCK_KEYS_TO_LEARN)}, got {keys_to_learn!r}" + ) + + if isinstance(keys_to_learn, Sequence): + values = tuple(keys_to_learn) + if not all(isinstance(value, str) for value in values): + raise TypeError(f"keys_to_learn entries must be strings, got {keys_to_learn!r}") + if not values: + raise ValueError("keys_to_learn cannot be empty") + invalid = [value for value in values if value not in BYPASS_SUBBLOCK_KEYS_TO_LEARN] + if invalid: + raise ValueError( + "keys_to_learn supports only subblock keys in v1; " + f"invalid entries: {invalid!r}" + ) + if "entire_block" in values and len(set(values)) > 1: + raise ValueError("keys_to_learn cannot mix 'entire_block' with other subblock keys") + return 
{"mode": "subblocks", "subblocks": tuple(dict.fromkeys(values))} + + raise TypeError(f"Unsupported keys_to_learn={keys_to_learn!r}") + + +def learned_subblocks_from_keys_to_learn(keys_to_learn: Any) -> list[str]: + """Return replacement-library subblocks represented by ``keys_to_learn``.""" + normalized = normalize_keys_to_learn(keys_to_learn) + subblocks = set(normalized["subblocks"]) + if subblocks == {"entire_block"}: + return ["block"] + + out: list[str] = [] + if "subblock_attention" in subblocks or "subblock_mamba" in subblocks: + out.append("attention") + if "subblock_ffn" in subblocks: + out.append("ffn") + return out + + +def _slug(value: Any) -> str: + text = str(value).strip().lower().replace("subblock_", "") + keep = [ch if ch.isalnum() else "_" for ch in text] + slug = "".join(keep).strip("_") + while "__" in slug: + slug = slug.replace("__", "_") + return slug or "custom" + + +def get_bypass_run_identity(cfg: DictConfig) -> dict[str, Any]: + """Return the config subset that defines a bypass output. + + The full Hydra config carries mutable runtime counters, checkpoint paths and + logging fields. Those should not decide whether a completed bypass run can + be reused. This identity intentionally keeps architecture, training budget, + data shape and learning-target fields, because changing any of them changes + the produced checkpoint. + """ + bypass = _to_plain_container(cfg.bypass) + training = bypass.get("training", {}) + data = bypass.get("data", {}) + model = bypass.get("model", {}) + model_factory = bypass.get("model_factory", {}) + return { + "model": { + "student_weights_dtype": model.get("student_weights_dtype"), + "model_config_overrides": model.get("model_config_overrides"), + }, + "model_factory": { + "factory": model_factory.get("factory"), + "block_loss_func": model_factory.get("block_loss_func"), + "gqa_init_mode": model_factory.get("gqa_init_mode"), + "mlp_init_mode": model_factory.get("mlp_init_mode"), + "mlp_init_config": model_factory.get("mlp_init_config"), + "linear_init_mode": model_factory.get("linear_init_mode"), + "submodule_for_loss_calculation": model_factory.get("submodule_for_loss_calculation"), + "keys_to_learn": model_factory.get("keys_to_learn"), + }, + "training": { + "learning_rate": training.get("learning_rate"), + "training_tokens": training.get("training_tokens"), + "micro_batch_size": training.get("micro_batch_size"), + "grad_accumulation_steps": training.get("grad_accumulation_steps"), + "weight_decay": training.get("weight_decay"), + "decay_lr": training.get("decay_lr"), + "beta1": training.get("beta1"), + "beta2": training.get("beta2"), + "grad_clip": training.get("grad_clip"), + "grad_clip_type": training.get("grad_clip_type"), + "warmup_ratio": training.get("warmup_ratio"), + "min_lr_factor": training.get("min_lr_factor"), + }, + "data": { + "dataset_path": cfg.get("dataset_path", None), + "block_size": data.get("block_size"), + "data_column": data.get("data_column"), + "fim_rate": data.get("fim_rate"), + "fim_spm_rate": data.get("fim_spm_rate"), + "bos_rate": data.get("bos_rate"), + "source_datasets_to_discard": data.get("source_datasets_to_discard"), + "load_from_disk": data.get("load_from_disk"), + "keep_in_memory": data.get("keep_in_memory"), + "shuffle_train_data_seed": data.get("shuffle_train_data_seed"), + "val_dataset_name": data.get("val_dataset_name"), + "max_eval_samples": data.get("max_eval_samples"), + "eval_samples_per_process": data.get("eval_samples_per_process"), + }, + "validation": { + "disable_validation": 
bypass.get("disable_validation"), + "save_best_ckpt": bypass.get("save_best_ckpt"), + "realize_best_or_latest": bypass.get("realize_best_or_latest"), + "eval_interval": training.get("eval_interval"), + "val_micro_batch_size": training.get("val_micro_batch_size"), + }, + "seed": bypass.get("seed"), + "dtype": bypass.get("dtype"), + } + + +def get_bypass_config_fingerprint(cfg: DictConfig) -> str: + identity = get_bypass_run_identity(cfg) + payload = json.dumps(identity, sort_keys=True, default=str, separators=(",", ":")) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def get_bypass_experiment_fingerprint(cfg: DictConfig) -> str: + """Return a stable ID fingerprint for the architecture and learning target. + + Training budget and data settings are deliberately excluded so a longer + rerun can resume the same architecture from its previous final checkpoint. + The full config fingerprint is still recorded in bypass_state.json and used + for skip-if-complete decisions. + """ + identity = get_bypass_run_identity(cfg) + experiment_identity = { + "model": identity["model"], + "model_factory": { + "factory": identity["model_factory"]["factory"], + "block_loss_func": identity["model_factory"]["block_loss_func"], + "keys_to_learn": identity["model_factory"]["keys_to_learn"], + "gqa_init_mode": identity["model_factory"]["gqa_init_mode"], + "mlp_init_mode": identity["model_factory"]["mlp_init_mode"], + "mlp_init_config": identity["model_factory"]["mlp_init_config"], + "linear_init_mode": identity["model_factory"]["linear_init_mode"], + "submodule_for_loss_calculation": identity["model_factory"][ + "submodule_for_loss_calculation" + ], + }, + } + payload = json.dumps(experiment_identity, sort_keys=True, default=str, separators=(",", ":")) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def set_experiment_id(cfg: DictConfig) -> None: + """Set the experiment ID based on the model config overrides. + + The ID encodes every override that affects the produced student so that + sweeps over (FFN size × KV heads) or (num_experts × KV heads) get distinct + directories instead of clobbering each other. + """ + if cfg.bypass.experiment_id is not None: + return + + overrides = cfg.bypass.model.model_config_overrides + parts: list[str] = [] + + if "ffn" in overrides: + ffn_override = overrides.ffn[0] + if "intermediate_size" in ffn_override and ffn_override["intermediate_size"] is not None: + parts.append(f"ffn_{ffn_override['intermediate_size']}") + elif "moe" in ffn_override and ffn_override["moe"] is not None: + parts.append(f"experts_{ffn_override['moe']['num_local_experts']}") + + if "attention" in overrides: + attn_override = overrides.attention[0] + if ( + "num_key_value_heads" in attn_override + and attn_override["num_key_value_heads"] is not None + ): + parts.append(f"heads_{attn_override['num_key_value_heads']}") + + keys_to_learn = cfg.bypass.model_factory.get("keys_to_learn", None) + if keys_to_learn not in (None, "entire_block"): + parts.append(_slug(keys_to_learn)) + + if not parts: + parts.append("custom") + + # Keep the readable architecture prefix, but suffix it with the config + # fingerprint so two runs with the same architecture but different learning + # target or training budget cannot collide in the same experiment_dir. 
+ cfg.bypass.experiment_id = "bypass_" + "_".join(parts) + cfg.bypass.experiment_id += f"_{get_bypass_experiment_fingerprint(cfg)[:8]}" + + +def set_experiment_dir(cfg: DictConfig) -> None: + """Set the experiment directory for the bypass run. + + Stores the path as a string in the OmegaConf node (OmegaConf only supports + primitive types natively). Use sites should reconstruct ``Path(...)`` as needed. + """ + experiment_dir = Path(cfg.puzzle_dir) / "bypass" / "bypass_runs" / cfg.bypass.experiment_id + cfg.bypass.experiment_dir = str(experiment_dir) + if dist.is_master(): + experiment_dir.mkdir(parents=True, exist_ok=True) + + +def get_bypass_state_path(experiment_dir: str | Path) -> Path: + return Path(experiment_dir) / BYPASS_STATE_FILENAME + + +def load_bypass_state(experiment_dir: str | Path) -> dict[str, Any] | None: + state_path = get_bypass_state_path(experiment_dir) + if not state_path.exists(): + return None + return json_load(state_path) + + +def write_bypass_state(cfg: DictConfig, state: dict[str, Any]) -> None: + if not dist.is_master(): + return + json_dump(state, get_bypass_state_path(cfg.bypass.experiment_dir)) + + +def _base_bypass_state(cfg: DictConfig) -> dict[str, Any]: + return { + "version": 1, + "experiment_id": cfg.bypass.get("experiment_id", None), + "config_fingerprint": get_bypass_config_fingerprint(cfg), + "identity": get_bypass_run_identity(cfg), + "status": "running", + "checkpoints": {}, + "realized_checkpoint": None, + "ckpts_symlink": None, + } + + +def update_bypass_checkpoint_state( + cfg: DictConfig, checkpoint_dir: str | Path, checkpoint_role: str +) -> None: + if not dist.is_master(): + return + state = load_bypass_state(cfg.bypass.experiment_dir) or _base_bypass_state(cfg) + state["status"] = "running" + state["config_fingerprint"] = get_bypass_config_fingerprint(cfg) + state["identity"] = get_bypass_run_identity(cfg) + state.setdefault("checkpoints", {})[checkpoint_role] = str(Path(checkpoint_dir)) + write_bypass_state(cfg, state) + + +def mark_bypass_run_completed( + cfg: DictConfig, realized_checkpoint: str | Path, ckpts_symlink: str | Path +) -> None: + state = load_bypass_state(cfg.bypass.experiment_dir) or _base_bypass_state(cfg) + state["status"] = "completed" + state["config_fingerprint"] = get_bypass_config_fingerprint(cfg) + state["identity"] = get_bypass_run_identity(cfg) + state["realized_checkpoint"] = str(realized_checkpoint) + state["ckpts_symlink"] = str(ckpts_symlink) + write_bypass_state(cfg, state) + if dist.is_master(): + (Path(cfg.bypass.experiment_dir) / "_DONE").touch() + + +def bypass_run_is_complete(cfg: DictConfig) -> bool: + state = load_bypass_state(cfg.bypass.experiment_dir) + if state is None: + return False + if state.get("status") != "completed": + return False + if state.get("config_fingerprint") != get_bypass_config_fingerprint(cfg): + return False + realized = state.get("realized_checkpoint") + symlink = state.get("ckpts_symlink") + if not realized or not Path(realized).exists(): + return False + if not symlink or not Path(symlink).exists(): + return False + return True + + +def expected_bypass_runs(cfg: DictConfig) -> list[dict[str, Any]]: + """Return expected run metadata for the current bypass config or sweep.""" + runs: list[dict[str, Any]] = [] + configs_list = cfg.bypass.get("configs", None) + overrides = configs_list if configs_list else [None] + + for override in overrides: + run_cfg = OmegaConf.create( + { + "puzzle_dir": cfg.puzzle_dir, + "dataset_path": cfg.get("dataset_path", None), + "descriptor": 
cfg.get("descriptor", None), + "bypass": OmegaConf.to_container(cfg.bypass, resolve=True), + } + ) + OmegaConf.set_struct(run_cfg, False) + if override: + run_cfg.bypass.experiment_id = None + if "model_config_overrides" in override: + run_cfg.bypass.model.model_config_overrides = override.model_config_overrides + if "keys_to_learn" in override: + run_cfg.bypass.model_factory.keys_to_learn = override.keys_to_learn + set_experiment_id(run_cfg) + experiment_dir = ( + Path(run_cfg.puzzle_dir) / "bypass" / "bypass_runs" / run_cfg.bypass.experiment_id + ) + runs.append( + { + "experiment_id": run_cfg.bypass.experiment_id, + "experiment_dir": str(experiment_dir), + "config_fingerprint": get_bypass_config_fingerprint(run_cfg), + } + ) + return runs + + +def get_distributed_modules_ownership(module_count: int, world_size: int) -> list[int]: + """Map module (block) indices to GPU ranks for pipeline-parallel distribution.""" + modules_process_ownership: list[int] = [] + + for i in range(world_size): + num_modules_for_process = module_count // world_size + if i < module_count % world_size: + num_modules_for_process += 1 + + modules_process_ownership.extend([i] * num_modules_for_process) + + return modules_process_ownership + + +def get_pipeline_ownership_context( + module_ownership: Sequence[int], rank: int | None = None +) -> dict[str, Any]: + """Return local module indices and neighboring pipeline ranks for ``rank``.""" + if rank is None: + rank = dist.rank() + owned_indices = [i for i, owner in enumerate(module_ownership) if owner == rank] + if not owned_indices: + raise RuntimeError( + f"rank {rank} owns no modules in pipeline ownership map {list(module_ownership)}" + ) + + min_owned_index = min(owned_indices) + max_owned_index = max(owned_indices) + prev_rank = None if min_owned_index == 0 else module_ownership[min_owned_index - 1] + next_rank = ( + None + if max_owned_index + 1 >= len(module_ownership) + else module_ownership[max_owned_index + 1] + ) + return { + "owned_indices": owned_indices, + "owned_index_set": set(owned_indices), + "prev_rank": prev_rank, + "next_rank": next_rank, + } diff --git a/modelopt/torch/puzzletron/bypass_distillation/data_classes.py b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py new file mode 100644 index 00000000000..a6b37099ceb --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Data classes for bypass distillation training.""" + +import dataclasses +from typing import TypeAlias + +IterNum: TypeAlias = int +GlobalRank: TypeAlias = int + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class IterStatistics: + step_num: int + token_count: int + iter_duration: float + lr: float + clipping_count: int + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class LocalTrainingStats: + iter_num: int + stitched_module_losses: dict[str, float] + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class TimeToSaveSignal: + step_num: int diff --git a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py new file mode 100644 index 00000000000..41857721a9a --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py @@ -0,0 +1,635 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factory for creating stitched teacher/student models for bypass distillation.""" + +import copy +import dataclasses +import re +from argparse import Namespace +from collections import OrderedDict +from contextlib import nullcontext +from pathlib import Path +from typing import Any, Callable, Mapping, Optional, Sequence + +import torch +from omegaconf import DictConfig, OmegaConf +from torch.amp.grad_scaler import GradScaler +from torch.optim import AdamW, Optimizer +from transformers import PretrainedConfig, PreTrainedModel + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.pruning.pruning_utils import GQAInitMode, LinearInitMode, MlpInitMode +from modelopt.torch.puzzletron.sewing_kit import ( + ExternalTarget, + FunctionTarget, + InputArgs, + ModuleTarget, + Needle, + RemoteTarget, + StitchedModule, + always_true_predicate, +) +from modelopt.torch.puzzletron.sewing_kit.core import InputReducer +from modelopt.torch.puzzletron.sewing_kit.utils import ( + batched_normalized_mse_loss, + normalized_mse_loss, + vectorwise_normalized_mse_loss, +) +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( + create_child_state_dict, + update_model_config, +) +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import create_sharded_model +from modelopt.torch.puzzletron.utils.parsing import format_block_configs, parse_dtype + +from .bypass_utils import get_pipeline_ownership_context, normalize_keys_to_learn + +StitchedModulesProcessOwnership = list[int] +SyncDistributedModelWeightsFn = Callable[[], None] +Config = Mapping[str, Any] +Args = Namespace + + +@dataclasses.dataclass +class StitchedModuleDescriptor: + stitched_module: StitchedModule + owned_parameters: dict[str, torch.nn.Parameter] + owned_buffers: dict[str, torch.Tensor] + 
optimizer: Optional[Optimizer] = None + grad_scaler: Optional[GradScaler] = None + + +def _autocast_context(descriptor: ModelDescriptor): + return ( + torch.autocast(device_type="cuda", dtype=torch.bfloat16) + if descriptor.uses_autocast() + else nullcontext() + ) + + +def _param_names_for_subblock_key( + model: PreTrainedModel, + descriptor: ModelDescriptor, + subblock_key: str, +) -> set[str]: + lm_config = descriptor.get_language_model_config(model.config) + weight_groups = descriptor.get_weight_groups( + model.state_dict().keys(), lm_config.num_hidden_layers + ) + + attn_group_names = [ + group_name for group_name in weight_groups.keys() if group_name.endswith("_attention") + ] + ffn_group_names = [ + group_name for group_name in weight_groups.keys() if group_name.endswith("_ffn") + ] + if subblock_key == "subblock_attention": + group_names = attn_group_names + elif subblock_key == "subblock_ffn": + group_names = ffn_group_names + elif subblock_key == "subblock_mamba": + group_names = attn_group_names # Mamba params live in _attention groups + elif subblock_key == "entire_block": + group_names = attn_group_names + ffn_group_names + else: + raise ValueError(f"Unsupported subblock key: {subblock_key!r}") + + # block_configs lives on the outer puzzletron-converted config for nested + # HF configs (for example Qwen3-VL), not necessarily on the language sub-config. + block_configs = getattr(model.config, "block_configs", None) or getattr( + lm_config, "block_configs", None + ) + + collected: list[str] = [] + for group_name in group_names: + if block_configs is not None: + m = re.match(r"block_(\d+)_attention", group_name) + if m: + block_idx = int(m.group(1)) + if block_idx < len(block_configs): + attention_cfg = getattr(block_configs[block_idx], "attention", None) + is_mamba = getattr(attention_cfg, "mamba", None) is not None + if subblock_key == "subblock_attention" and is_mamba: + continue + if subblock_key == "subblock_mamba" and not is_mamba: + continue + collected.extend(weight_groups[group_name]) + return set(collected) + + +def _set_keys_to_learn( + model: PreTrainedModel, + descriptor: ModelDescriptor, + keys_to_learn: str | Sequence[str], +) -> None: + """Set ``requires_grad=True`` on parameters selected by ``keys_to_learn``. + + Bypass v1 supports only descriptor-backed subblock keys. This keeps training + selection aligned with replacement-library extraction. + """ + normalized = normalize_keys_to_learn(keys_to_learn) + param_names = set() + for subblock_key in normalized["subblocks"]: + param_names.update(_param_names_for_subblock_key(model, descriptor, subblock_key)) + # In pipeline-parallel training a rank may own only blocks that don't match + # keys_to_learn (e.g. a rank with only Mamba blocks during subblock_attention + # bypass has no GQA params after the _mamba rename). That is a valid state: + # those blocks are tracked as non-trainable and omitted from numeric loss stats. + if not param_names: + return + + # Set requires_grad to True for the selected parameters. 
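+    # For example (names illustrative): "subblock_ffn" on a Llama-style layout selects
+    # parameters such as "model.layers.3.mlp.up_proj.weight", while attention / Mamba
+    # weights stay frozen.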
+ for param_name, param in model.named_parameters(): + if param_name in param_names and torch.is_floating_point(param): + param.requires_grad_(True) + + +def _get_all_non_persistent_buffers_set(module: torch.nn.Module) -> set[str]: + all_non_persistent = set() + for module_name, submodule in module.named_modules(): + for buffer_name in submodule._non_persistent_buffers_set: + full_name = f"{module_name}.{buffer_name}" if module_name else buffer_name + all_non_persistent.add(full_name) + return all_non_persistent + + +def bypass_factory_fn( + teacher_model: PreTrainedModel, + descriptor: ModelDescriptor, + cfg: DictConfig, + model_blocks_process_ownership: Sequence[int], + student_model: Optional[PreTrainedModel] = None, +) -> tuple[ + PreTrainedModel, + StitchedModule, + StitchedModule, + StitchedModule, + OrderedDict[str, StitchedModuleDescriptor], + PretrainedConfig, +]: + """Unified factory function for bypass (blockwise local) distillation. + + Handles all layer types — FFN, attention (GQA/MHA), MoE experts, Mamba, and whole blocks — + through a single pipeline. Behavior is driven entirely by ``model_factory`` config fields: + + - ``mlp_init_mode``: how student FFN / MoE weights are initialised + - ``"ExpertRemoval"``: select top-N experts from teacher (MoE models) + - ``"Truncate"`` / ``"PruneByActivationsLog"``: prune FFN channels (dense models) + - ``"CopyAsIs"``: copy weights unchanged (attention-only or Mamba-only runs) + - ``gqa_init_mode``: how attention KV heads are initialised (optional, default ``AverageKV``). + Irrelevant when the student has the same number of KV heads as the teacher. + - ``keys_to_learn``: which subblock parameters to train. + Accepts ``"subblock_ffn"``, ``"subblock_attention"``, ``"subblock_mamba"``, + ``"entire_block"``, or a list of those keys. + + The stitching logic (pipeline-parallel per-block KD) is architecture-agnostic and unchanged + regardless of which layer type is being distilled. + + Args: + teacher_model: The teacher model to use for stitching. + descriptor: Model descriptor for layer naming and pruning mixin lookup. + cfg: The bypass config section. + model_blocks_process_ownership: Ownership mapping of model blocks to process ranks. + student_model: Optionally provided pre-built student model (skips initialisation). 
+ + Returns: + Tuple of (student_model, teacher_stitched, teacher_val_stitched, + student_val_stitched, stitched_module_descriptors, student_config) + """ + device = torch.device(f"cuda:{dist.local_rank()}") + model_config_overrides = cfg.model.model_config_overrides + + _block_loss_funcs: dict[str, Callable[..., Any]] = { + "normalized_mse_loss": normalized_mse_loss, + "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss, + "batched_normalized_mse_loss": batched_normalized_mse_loss, + } + block_loss_func = _block_loss_funcs[cfg.model_factory.block_loss_func] + mprint(f"{block_loss_func.__name__=}") + + owned_block_indexes = set( + block_index + for block_index, owner_rank in enumerate(model_blocks_process_ownership) + if owner_rank == dist.rank() + ) + + # Initialize student_model + if student_model is None: + mprint("Creating student model from teacher model") + + with _autocast_context(descriptor): + if isinstance(model_config_overrides, DictConfig): + config_to_override = OmegaConf.to_container(model_config_overrides, resolve=True) + else: + config_to_override = model_config_overrides + mprint(f"{config_to_override=}") + student_model_config = update_model_config( + model_config=teacher_model.config, + model_config_overrides=config_to_override, + ) + student_model_config.use_cache = False + + mprint(f"Student model config:\n {format_block_configs(student_model_config)}") + + from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + runtime = Namespace( + device=device, + dtype=torch.bfloat16, + global_rank=dist.rank(), + world_size=dist.size(), + is_main_process=dist.is_master(), + is_last_process=dist.is_last_process(), + ) + + with deci_x_patcher( + model_descriptor=descriptor, + block_configs=getattr(student_model_config, "block_configs", None), + ): + student_model = create_sharded_model( + runtime=runtime, + descriptor=descriptor, + model_config=student_model_config, + owned_block_indexes=owned_block_indexes, + device=device, + ) + # `_init_weights` is HF's per-module initializer; apply it across the + # whole model rather than passing the model itself as a single module. + student_model.apply(student_model._init_weights) + + student_weights_dtype = parse_dtype(cfg.model.student_weights_dtype) + descriptor.init_rotary_embedding(student_model, runtime) + student_model.type(student_weights_dtype) + + mlp_init_mode = MlpInitMode(cfg.model_factory.mlp_init_mode or MlpInitMode.CopyAsIs) + + # For expert removal, use the model-specific pruning mixin so that model-specific + # key paths (e.g. backbone.layers.{i}.mixer for Nemotron-H vs model.layers.{i}.mlp + # for GPT-OSS) are handled correctly. For all other init modes the legacy inline + # key logic in create_child_state_dict is sufficient. + _mixins = [] + if mlp_init_mode == MlpInitMode.ExpertRemoval: + _expert_mixin = descriptor.pruning_mixins().get("experts_removal") + if _expert_mixin is not None: + _mixins.append(_expert_mixin) + + # If any attention layer has fewer KV heads in the student than the teacher, use the + # model-specific KV heads mixin so that k_proj/v_proj weights are correctly sliced + # rather than copied verbatim from the (larger) teacher state dict. 
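+        # An override shaped like {"attention": [{"num_key_value_heads": 2}]} in
+        # cfg.model.model_config_overrides (values illustrative) is what produces the
+        # student/teacher KV-head mismatch checked below.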
+ _kv_mixin = descriptor.pruning_mixins().get("kv_heads") + if _kv_mixin is not None: + _student_kv = [ + b.attention.num_key_value_heads + for b in student_model_config.block_configs + if b.attention is not None and b.attention.num_key_value_heads is not None + ] + _teacher_kv = [ + b.attention.num_key_value_heads + for b in teacher_model.config.block_configs + if b.attention is not None and b.attention.num_key_value_heads is not None + ] + assert len(_student_kv) == len(_teacher_kv), ( + f"KV-head block-config length mismatch: student={len(_student_kv)} " + f"teacher={len(_teacher_kv)} — check model_config_overrides" + ) + if _student_kv != _teacher_kv: + _mixins.append(_kv_mixin) + + # If any FFN layer has a smaller intermediate_size in the student than the teacher, + # use the model-specific FFN-intermediate mixin. The generic create_child_state_dict + # path is hardcoded to `model.layers.{i}.mlp.*` (Llama-style), so for families that + # place FFN under a different prefix (e.g. `backbone.layers.{i}.mixer.*` for + # Nemotron-H/H_v2) the mixin is required to slice up_proj/down_proj correctly. + # Filter out no_op FFN blocks (their intermediate_size is None) — relevant for + # hybrid families where each layer is exactly one of {attention, ffn, mamba}. + _ffn_mixin = descriptor.pruning_mixins().get("ffn_intermediate") + if _ffn_mixin is not None and mlp_init_mode in ( + MlpInitMode.Truncate, + MlpInitMode.PruneByActivationsLog, + ): + _student_ffn = [ + b.ffn.intermediate_size + for b in student_model_config.block_configs + if b.ffn is not None and b.ffn.intermediate_size is not None + ] + _teacher_ffn = [ + b.ffn.intermediate_size + for b in teacher_model.config.block_configs + if b.ffn is not None and b.ffn.intermediate_size is not None + ] + assert len(_student_ffn) == len(_teacher_ffn), ( + f"FFN-intermediate block-config length mismatch: student={len(_student_ffn)} " + f"teacher={len(_teacher_ffn)} — check model_config_overrides" + ) + if _student_ffn != _teacher_ffn: + _mixins.append(_ffn_mixin) + + if len(_mixins) == 0: + pruning_mixin = None + elif len(_mixins) == 1: + pruning_mixin = _mixins[0] + else: + pruning_mixin = _mixins + + # GQA init mode is optional: only relevant when the student has fewer KV heads than + # the teacher. Defaults to AverageKV and is a no-op when head counts are equal. + gqa_init_mode = GQAInitMode(cfg.model_factory.get("gqa_init_mode", GQAInitMode.AverageKV)) + + student_state_dict = create_child_state_dict( + pruning_mixin=pruning_mixin, + descriptor=descriptor, + original_state_dict=teacher_model.state_dict(), + new_state_dict=student_model.state_dict(), + original_config=teacher_model.config, + new_config=student_model_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=cfg.model_factory.mlp_init_config, + owned_block_indexes=owned_block_indexes, + linear_init_mode=LinearInitMode( + cfg.model_factory.linear_init_mode or LinearInitMode.Random + ), + ) + + # Load student state dict + missing_keys, unexpected_keys = student_model.load_state_dict( + student_state_dict, strict=False + ) + assert len(unexpected_keys) == 0, f"{unexpected_keys=}" + # GQA models have learnable logit parameters not present in the teacher state dict; + # allow those to be absent and assert nothing else is missing. 
+ non_gqa_missing = [k for k in missing_keys if not re.search(r"gqa_\w+_logits", k)] + assert len(non_gqa_missing) == 0, f"Unexpected missing keys: {non_gqa_missing}" + + else: + mprint("Student model provided explicitly, not using teacher model to instantiate") + student_model_config = student_model.config + + # Set up training parameters + lm_config = descriptor.get_language_model_config(student_model_config) + all_block_indices = list(range(lm_config.num_hidden_layers)) + + student_model.requires_grad_(False) + keys_to_learn = cfg.model_factory.keys_to_learn + mprint(f"Keys to learn: {keys_to_learn}") + + _set_keys_to_learn(model=student_model, descriptor=descriptor, keys_to_learn=keys_to_learn) + + dist.barrier() + mprint(f"Global rank: {dist.rank()}, {owned_block_indexes=}") + dist.barrier() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + dist.barrier() + + # Every rank derives ownership from the same `model_blocks_process_ownership` + # list, so this guard fires identically on every rank when world_size exceeds + # num_hidden_layers — no NCCL hang from a single rank diverging. + ranks_with_blocks = set(model_blocks_process_ownership) + empty_ranks = [r for r in range(dist.size()) if r not in ranks_with_blocks] + if empty_ranks: + raise RuntimeError( + f"world_size ({dist.size()}) exceeds num_hidden_layers " + f"({len(all_block_indices)}); ranks {empty_ranks} would own 0 blocks. " + f"Pipeline-parallel bypass distillation does not support idle ranks — " + f"reduce nproc_per_node to at most num_hidden_layers." + ) + + ownership_context = get_pipeline_ownership_context(model_blocks_process_ownership) + prev_rank: Optional[int] = ownership_context["prev_rank"] + next_rank: Optional[int] = ownership_context["next_rank"] + + teacher_parameters = set(teacher_model.parameters()) + teacher_buffers = set(teacher_model.buffers()) + + # Setup the student model's submodules for knowledge distillation training + with _autocast_context(descriptor), torch.device(device): + stitched_module_descriptors = OrderedDict[str, StitchedModuleDescriptor]() + submodule_for_loss_calculation = cfg.model_factory.submodule_for_loss_calculation + + teacher_target = ModuleTarget("teacher", teacher_model) + teacher_stitcher = Needle() + teacher_val_stitcher = Needle() + + student_target = ModuleTarget("student", student_model) + student_val_stitcher = Needle() + + for local_block_index, global_block_index in enumerate(sorted(owned_block_indexes)): + module_name = descriptor.layer_block_name(global_block_index) + module = student_model.get_submodule(module_name) + + submodule_name = "" + submodule_input_descriptor = submodule_name + submodule_output_descriptor = submodule_name + + if submodule_for_loss_calculation is not None: + assert hasattr(module, submodule_for_loss_calculation) + submodule_output_descriptor = submodule_for_loss_calculation + + input_descriptor = f"{module_name}.{submodule_input_descriptor}".rstrip(".") + output_descriptor = f"{module_name}.{submodule_output_descriptor}".rstrip(".") + + # Receive activations from previous rank + if global_block_index > 0 and local_block_index == 0 and prev_rank is not None: + teacher_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="teacher_activations", adapter=lambda x: InputArgs(x) + ), + teacher_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + teacher_val_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="teacher_activations", 
adapter=lambda x: InputArgs(x) + ), + teacher_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + student_val_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="student_activations", adapter=lambda x: InputArgs(x) + ), + student_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + + # Send activations to next rank or register model output + if local_block_index + 1 == len(owned_block_indexes): + if next_rank is None: + student_val_stitcher.stitch( + student_target.output(name=""), + ExternalTarget().output("model_output"), + ) + teacher_val_stitcher.stitch( + teacher_target.output(name=""), + ExternalTarget().output("model_output"), + ) + else: + teacher_stitcher.stitch( + teacher_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="teacher_activations"), + ) + teacher_val_stitcher.stitch( + teacher_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="teacher_activations"), + ) + student_val_stitcher.stitch( + student_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="student_activations"), + ) + + # Bypass training stitches + teacher_stitcher.stitch( + teacher_target.input(name=input_descriptor), + ExternalTarget().input(name=input_descriptor), + ).stitch( + teacher_target.output(name=output_descriptor), + ExternalTarget().output(name=output_descriptor), + ) + + # Create the student block stitched module + student_stitched_module_loss_target = FunctionTarget( + "module_loss_func", block_loss_func + ) + student_stitched_module_name = f"block_{global_block_index}" + student_submodule_target = ModuleTarget("student_submodule", module) + # When a block returns a tuple, ``v[0]`` is the hidden state by + # HF convention — every HF transformer block (Llama, Qwen, GPT-OSS, + # NemotronH, …) returns ``(hidden_states, *aux)``, with ``aux`` + # varying (attention weights, KV cache, router logits, …) but + # element 0 always being the hidden state. Puzzletron is HF-format- + # only, so this assumption holds across every supported family. + student_stitched_module = ( + Needle() + .stitch( + ExternalTarget().input(name=input_descriptor), + student_submodule_target.input(name=submodule_input_descriptor), + ) + .stitch( + ExternalTarget().output( + name=output_descriptor, + adapter=lambda v: InputArgs(target=v) + if not isinstance(v, tuple) + else InputArgs(target=v[0]), + ), + student_stitched_module_loss_target.input(), + ) + .stitch( + student_submodule_target.output( + name=submodule_output_descriptor, + adapter=lambda v: InputArgs(input=v) + if not isinstance(v, tuple) + else InputArgs(input=v[0]), + ), + student_stitched_module_loss_target.input(), + ) + .stitch( + student_stitched_module_loss_target.output(), + ExternalTarget().output(name="loss"), + ) + .knot( + ignore_extra_overrides=True, + capture_cache_outputs_predicate=always_true_predicate, + ) + ) + + assert "learning_rate" in cfg.training + # Do NOT enable dummy params: blocks with no real trainable parameters + # (e.g. Mamba blocks during an attention-only bypass run) should produce + # NaN loss so they are excluded from statistics — identical to the + # optimizer=None path in the training loop. 
+ + student_module_parameters = { + p_name: p + for p_name, p in student_stitched_module.named_parameters() + if p not in teacher_parameters and "dummy_param" not in p_name + } + student_module_buffers = { + p_name: p + for p_name, p in student_stitched_module.named_buffers() + if p not in teacher_buffers + and p_name not in _get_all_non_persistent_buffers_set(student_stitched_module) + } + + trainable_params = { + p_name: p for p_name, p in student_module_parameters.items() if p.requires_grad + } + + optimizer = ( + AdamW( + list(trainable_params.values()), + lr=cfg.training.learning_rate, + weight_decay=cfg.training.weight_decay, + betas=(cfg.training.beta1, cfg.training.beta2), + fused=True, + ) + if len(trainable_params) > 0 + else None + ) + + grad_scaler = ( + None + if optimizer is None + else GradScaler(device=device.type, enabled=cfg.training.use_grad_scaling) + ) + + stitched_module_descriptors[student_stitched_module_name] = StitchedModuleDescriptor( + stitched_module=student_stitched_module, + owned_parameters=student_module_parameters, + owned_buffers=student_module_buffers, + optimizer=optimizer, + grad_scaler=grad_scaler, + ) + + teacher_stitched_module = teacher_stitcher.knot(ignore_extra_overrides=True) + teacher_val_stitched_module = teacher_val_stitcher.knot(ignore_extra_overrides=True) + student_val_stitched_module = student_val_stitcher.knot(ignore_extra_overrides=True) + + local_trainable_param_count = sum( + p.numel() + for descriptor_ in stitched_module_descriptors.values() + for p in descriptor_.owned_parameters.values() + if p.requires_grad + ) + global_trainable_param_count = dist.allreduce(local_trainable_param_count, reduction="sum") + if global_trainable_param_count == 0: + raise ValueError( + f"keys_to_learn={keys_to_learn!r} did not match any trainable student parameters" + ) + + return ( + student_model, + teacher_stitched_module, + teacher_val_stitched_module, + student_val_stitched_module, + stitched_module_descriptors, + student_model_config, + ) diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py new file mode 100644 index 00000000000..885bf24d149 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -0,0 +1,1232 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bypass distillation training loop for per-block knowledge distillation. + +This module implements the blockwise local distillation (BLD) stage of the PUZZLE framework. +It trains alternative transformer block configurations using per-block knowledge distillation +from a teacher model, producing a library of "puzzle pieces" with different efficiency/performance +trade-offs. 
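+
+``launch_bypass_distillation`` is the top-level entry point; ``run_bypassed_training`` builds the
+teacher/student stitched modules and dataloaders, and ``train`` runs the pipeline-parallel loop.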
+""" + +import logging +import math +import os +import shutil +import sys +import time +import traceback +from collections import OrderedDict, defaultdict +from contextlib import nullcontext +from pathlib import Path +from statistics import mean +from typing import Optional + +import datasets +import torch +import torch.distributed +import transformers +from omegaconf import DictConfig, OmegaConf +from torch.utils.data.dataloader import DataLoader +from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase + +import modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory as stitched_model_factory_module +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.sewing_kit import InputArgs, StitchedModule +from modelopt.torch.puzzletron.sewing_kit.utils import fake_tensor +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch.puzzletron.utils.parsing import format_global_config, format_stitched_losses +from modelopt.torch.utils.robust_json import json_load + +from .bypass_checkpoint_utils import find_latest_run_dir, load_local_state, save_bypass_checkpoint +from .bypass_utils import ( + bypass_run_is_complete, + get_distributed_modules_ownership, + get_pipeline_ownership_context, + load_bypass_state, + mark_bypass_run_completed, + set_experiment_dir, + set_experiment_id, +) +from .data_classes import GlobalRank, IterNum, IterStatistics, TimeToSaveSignal +from .stitched_model_factory import StitchedModuleDescriptor, StitchedModulesProcessOwnership + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def _autocast_context(descriptor: ModelDescriptor): + return ( + torch.autocast(device_type="cuda", dtype=torch.bfloat16) + if descriptor.uses_autocast() + else nullcontext() + ) + + +def launch_bypass_distillation(hydra_cfg: DictConfig) -> None: + """Top-level entry point for bypass distillation stage. + + Runs sewing-kit pipeline-parallel per-block knowledge distillation. + + Supports multiple bypass configurations via ``bypass.configs`` list. + Each entry overrides ``bypass.model.model_config_overrides`` and optionally + ``bypass.model_factory.keys_to_learn``, then runs a full bypass training. + + If ``bypass.configs`` is absent or empty, runs a single bypass training + with the settings already in ``bypass``. + + Args: + hydra_cfg: The full Hydra configuration with a 'bypass' section. 
+ """ + configs_list = hydra_cfg.bypass.get("configs", None) + + if not configs_list: + # Single config mode — run once with whatever is in bypass already + set_experiment_id(hydra_cfg) + set_experiment_dir(hydra_cfg) + dist.barrier() + bypass_complete = bypass_run_is_complete(hydra_cfg) if dist.is_master() else None + bypass_complete = dist.broadcast(bypass_complete, src=0) + if bypass_complete: + mprint( + f"Bypass distillation already completed for {hydra_cfg.bypass.experiment_id}, skipping" + ) + return + mprint("Starting bypass distillation (single config)") + run_bypassed_training(hydra_cfg) + mprint("Bypass distillation completed") + return + + base_model_config_overrides = OmegaConf.to_container( + hydra_cfg.bypass.model.model_config_overrides, resolve=True + ) + base_keys_to_learn = hydra_cfg.bypass.model_factory.keys_to_learn + + mprint(f"Starting bypass distillation sweep ({len(configs_list)} configs)") + for i, override in enumerate(configs_list): + mprint(f"Bypass config {i + 1}/{len(configs_list)}: {override}") + + hydra_cfg.bypass.model.model_config_overrides = OmegaConf.create( + base_model_config_overrides + ) + hydra_cfg.bypass.model_factory.keys_to_learn = base_keys_to_learn + + # Apply overrides for this run + if "model_config_overrides" in override: + hydra_cfg.bypass.model.model_config_overrides = override.model_config_overrides + if "keys_to_learn" in override: + hydra_cfg.bypass.model_factory.keys_to_learn = override.keys_to_learn + + # Reset per-run state so each config starts fresh + hydra_cfg.bypass.experiment_id = None + hydra_cfg.bypass.iter_num = 1 + hydra_cfg.bypass.step_num = 1 + hydra_cfg.bypass.token_count = 0 + hydra_cfg.bypass.best_val_loss = 1e9 + hydra_cfg.bypass.training.clipping_count = 0 + # Per-block bookkeeping for the Stitched-Module-Losses table. Mirrored + # into cfg.bypass on every log chunk so save_bypass_checkpoint's + # args.json snapshot carries them, and resume can restore the columns + # instead of trivially re-anchoring to the first post-resume chunk. + hydra_cfg.bypass.best_losses_by_name = {} + hydra_cfg.bypass.best_steps_by_name = {} + hydra_cfg.bypass.initial_losses_by_name = {} + + set_experiment_id(hydra_cfg) + set_experiment_dir(hydra_cfg) + dist.barrier() + bypass_complete = bypass_run_is_complete(hydra_cfg) if dist.is_master() else None + bypass_complete = dist.broadcast(bypass_complete, src=0) + if bypass_complete: + mprint( + f"Bypass config {i + 1}/{len(configs_list)} " + f"({hydra_cfg.bypass.experiment_id}) already completed, skipping" + ) + else: + run_bypassed_training(hydra_cfg) + mprint(f"Bypass config {i + 1}/{len(configs_list)} completed") + + mprint("Bypass distillation sweep completed") + + +def _flush_loss_buffer( + local_buffer: dict[int, dict[str, float]], + stitched_losses_history: Optional[dict[int, dict[str, float]]], +) -> None: + """All-gather buffered per-iter losses and merge into master's history. + + Pickle-based ``all_gather_object`` was previously called on every micro-batch; + batching to log-chunk boundaries reduces that cost ~``iters_per_log_chunk``×. + All ranks must call this so the collective doesn't deadlock; only master + actually accumulates into ``stitched_losses_history``. 
+ """ + if not local_buffer: + return + gathered: list[Optional[dict[int, dict[str, float]]]] = [None] * dist.size() + torch.distributed.all_gather_object(gathered, local_buffer) + if dist.is_master(): + assert stitched_losses_history is not None + for rank_buf in gathered: + if rank_buf is None: + continue + for it, losses in rank_buf.items(): + stitched_losses_history.setdefault(it, {}).update(losses) + + +def _delete_old_checkpoints( + experiment_dir: Path, + glob_pattern: str, + keep_name: str, +) -> None: + if not dist.is_master(): + return + for old_ckpt_path in experiment_dir.glob(glob_pattern): + if old_ckpt_path.name != keep_name: + shutil.rmtree(str(old_ckpt_path)) + + +def _save_training_checkpoint( + *, + cfg: DictConfig, + descriptor: ModelDescriptor, + model: torch.nn.Module, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + subdir_name: str, + checkpoint_role: str, + cleanup_glob: str | None = None, +) -> None: + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=Path(cfg.bypass.experiment_dir) / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + checkpoint_role=checkpoint_role, + ) + if cleanup_glob and cfg.bypass.model.model_overrides.delete_old_checkpoints: + _delete_old_checkpoints(Path(cfg.bypass.experiment_dir), cleanup_glob, subdir_name) + + +def train( + cfg: DictConfig, + descriptor: ModelDescriptor, + student_model: torch.nn.Module, + student_stitched_model: StitchedModule, + teacher_stitched_model: StitchedModule, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + stitched_modules_process_ownership: StitchedModulesProcessOwnership, + train_dataloader: Optional[DataLoader], + val_dataloader: Optional[DataLoader], + student_model_config: PretrainedConfig, + skip_first_batches: int = 0, + tokenizer: Optional[PreTrainedTokenizerBase] = None, +) -> None: + """Inner training loop for bypass distillation.""" + device = torch.device(f"cuda:{dist.local_rank()}") + + dist.barrier() + + # Anchor the time-based save interval at training start, not module import. + # Earlier this was a module-level `time_start = time.time()`, which made + # the first time-based save fire immediately if the module was imported + # well before train() actually ran (e.g. via test collection or Hydra config + # resolution). 
+ time_last_save = time.time() + iter_t0 = time.time() + + resumed_iter_num = cfg.bypass.iter_num + mprint(f"resumed_iter_num: {resumed_iter_num}") + + # Number of total stitched modules + global_stitched_modules_count = len(stitched_modules_process_ownership) + # Number of stitched modules per process + num_stitched_modules_per_process = [ + sum(1 for x in stitched_modules_process_ownership if x == owner_rank) + for owner_rank in range(dist.size()) + ] + ownership_context = get_pipeline_ownership_context(stitched_modules_process_ownership) + owned_stitched_module_indices = ownership_context["owned_indices"] + mprint(f"{global_stitched_modules_count=}") + mprint(f"{num_stitched_modules_per_process=}") + dist.barrier() + + if dist.is_master(): + # {iter_num: {stitched_module_name: loss}} + stitched_losses_history = dict[IterNum, dict[str, float]]() + else: + stitched_losses_history = None + + # Save checkpoint before training starts + if cfg.bypass.save_checkpoint_before_training and not cfg.bypass.disable_checkpoint_save: + subdir_name = f"start-step-{cfg.bypass.step_num:06d}-ckpt" + _save_training_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + subdir_name=subdir_name, + checkpoint_role="start", + ) + + # Track statistics for each iteration + iter_stats_history: dict[IterNum, IterStatistics] = {} + + # Create fake input ids for the teacher model + fake_input_ids = fake_tensor( + torch.ones( + size=(cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size), + dtype=torch.long, + device=device, + ) + ) + + prev_rank: Optional[int] = ownership_context["prev_rank"] + next_rank: Optional[int] = ownership_context["next_rank"] + + torch.cuda.synchronize() + + mprint( + f"Grad scaling status: {'enabled' if cfg.bypass.training.use_grad_scaling else 'disabled'}" + ) + + # Only master consumes the dataloader — `next(train_iterator)` is gated by + # `if dist.is_master()` further down. Building the iterator (or running + # skip_first_batches against it) on non-master ranks wastes startup time + # and memory proportional to the dataset, since each tokenizes the full + # corpus only to throw it away. + train_iterator = None + if dist.is_master(): + assert train_dataloader is not None + train_iterator = iter(train_dataloader) + + # Advance past the first `skip_first_batches` batches before the training loop + # starts. Used either to skip a known-bad batch range during debugging, or to + # roll the data iterator forward when resuming a run (model + optimizer state + # are restored from the checkpoint, but the dataloader itself starts fresh). + if dist.is_master() and skip_first_batches > 0: + assert train_iterator is not None + mprint(f"Skipping first {skip_first_batches} batches before training") + for _ in range(skip_first_batches): + next(train_iterator) + + mprint("Waiting for everyone before training starts") + dist.barrier() + + step_to_save = None + # Track best loss value for each block. Seeded from cfg.bypass so resume + # picks up where the previous run left off (run_bypassed_training restores + # these from args.json before train_pipeline_parallel runs). + best_losses_by_name: dict[str, float] = dict(cfg.bypass.get("best_losses_by_name", {})) + best_steps_by_name: dict[str, int] = dict(cfg.bypass.get("best_steps_by_name", {})) + # Anchor for the "Δ from initial" column: per-block loss from the first log chunk. 
+ initial_losses_by_name: dict[str, float] = dict(cfg.bypass.get("initial_losses_by_name", {})) + non_trainable_stitched_module_names = { + name + for name, descriptor in stitched_module_descriptors.items() + if descriptor.optimizer is None + } + + # log_interval is in optimizer-step units; multiply by grad_accum to land in + # micro-batch units, which is what the per-iter loss collection counts. + iters_per_log_chunk = ( + cfg.bypass.training.log_interval * cfg.bypass.training.grad_accumulation_steps + ) + # Per-rank local buffer of {iter_num: {block_name: loss}}. We accumulate + # losses locally on every rank and only collide them via all_gather_object + # at log-chunk boundaries — the object collective is pickle-based and + # was previously the per-iter sync cost. See `_flush_loss_buffer` below. + local_losses_buffer: dict[int, dict[str, float]] = {} + # Buffer variables. Initialise on the active device so non-master ranks + # never hand a CPU tensor to a downstream GPU op if the master-only-fetch + # invariant is ever relaxed (today only master replaces this in the loop). + input_ids = torch.zeros(1, 1, dtype=torch.int64, device=device) + + aprint( + f"previous rank: {str(prev_rank):<5} next rank: {str(next_rank):<5} {owned_stitched_module_indices=}" + ) + + # Train loop start + while True: + time_now = time.time() + # Check if we've reached the maximum number of steps. `step_num` is 1-based + # and incremented at the END of each iteration, so we must use `>` (not `>=`) + # to ensure step `max_steps` itself runs before exiting. + if cfg.bypass.step_num > cfg.bypass.training.max_steps: + # Drain any residual buffered losses (< log-chunk boundary) so the + # final partial chunk's stats reach master and can be logged before + # the function returns. Must run on every rank — collective op. 
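+ # (Illustrative, hypothetical numbers: with log_interval=10 and grad_accumulation_steps=4 a
+ # chunk spans 40 micro-batches; hitting max_steps 7 iters into a chunk still gathers those 7
+ # buffered losses into master's history instead of dropping them.)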
+ _flush_loss_buffer(local_losses_buffer, stitched_losses_history) + local_losses_buffer.clear() + if ( + cfg.bypass.model.model_overrides.save_checkpoint_when_done + and not cfg.bypass.disable_checkpoint_save + ): + mprint("Saving final checkpoint before training completion") + subdir_name = f"final-step-{cfg.bypass.step_num:06d}-ckpt" + _save_training_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_role="final", + subdir_name=subdir_name, + cleanup_glob="step-*", + ) + break + + is_accumulating = cfg.bypass.iter_num % cfg.bypass.training.grad_accumulation_steps != 0 + # Determine and set the learning rate for this iteration + lr = ( + _get_lr(cfg, cfg.bypass.step_num) + if cfg.bypass.training.decay_lr + else cfg.bypass.training.learning_rate + ) + for stitched_module_descriptor in stitched_module_descriptors.values(): + optimizer = stitched_module_descriptor.optimizer + if optimizer is not None: + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + if dist.is_master(): + assert train_iterator is not None + train_data = next(train_iterator) + input_ids = train_data["input_ids"] + input_ids = input_ids.to(device) + + with _autocast_context(descriptor), torch.no_grad(): + teacher_input_ids = input_ids if prev_rank is None else fake_input_ids + teacher_output = teacher_stitched_model({}, {}, teacher_input_ids) + + input_overrides = teacher_output.captured_inputs + output_overrides = teacher_output.captured_outputs + + del teacher_output + + input_overrides["teacher_inputs"] = InputArgs(fake_input_ids) + + # Collect per-block loss tensors and batch the GPU→CPU copy to a + # single sync point at the end of the per-block loop. Doing + # ``.to("cpu").item()`` per block forced one CUDA synchronization per + # block per iter, serialising the GPU pipeline across N blocks. + iter_loss_tensors: dict[str, torch.Tensor] = {} + + for local_stitched_module_index, ( + stitched_module_name, + stitched_module_descriptor, + ) in enumerate(stitched_module_descriptors.items()): + stitched_module = stitched_module_descriptor.stitched_module + optimizer = stitched_module_descriptor.optimizer + grad_scaler = stitched_module_descriptor.grad_scaler + + if optimizer is not None: + assert grad_scaler is not None + + with _autocast_context(descriptor): + stitched_module_output = stitched_module( + input_overrides=input_overrides, + output_overrides=output_overrides, + ) + stitched_module_loss = stitched_module_output.captured_outputs["loss"] + del stitched_module_output + scaled_stitched_module_loss = ( + stitched_module_loss / cfg.bypass.training.grad_accumulation_steps + ) + grad_scaler.scale(scaled_stitched_module_loss).backward() + iter_loss_tensors[stitched_module_name] = stitched_module_loss.detach() + del scaled_stitched_module_loss + else: + # No real trainable parameters on this rank/block. Keep this out + # of the numeric loss stream so genuine non-finite losses from + # trainable blocks remain visible instead of being conflated with + # an intentional "not trainable" sentinel. 
+ stitched_module_loss = None + + del stitched_module_loss + + if not is_accumulating: + if optimizer is not None: + grad_clip = cfg.bypass.training.grad_clip + if grad_clip is not None: + if cfg.bypass.training.grad_clip_type == "norm": + grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=stitched_module.parameters(), + max_norm=grad_clip, + ) + if grad_norm > grad_clip: + cfg.bypass.training.clipping_count += 1 + elif cfg.bypass.training.grad_clip_type == "value": + # Stack per-param maxes into a single GPU tensor and + # reduce before `.item()` so we sync once per block + # instead of once per parameter (see per-block batching + # rationale at lines 301-304). + grad_maxes = [ + p.grad.abs().max() + for p in stitched_module.parameters() + if p.grad is not None + ] + if grad_maxes: + max_abs_grad = torch.stack(grad_maxes).max().item() + else: + max_abs_grad = 0.0 + if max_abs_grad > grad_clip: + cfg.bypass.training.clipping_count += 1 + torch.nn.utils.clip_grad_value_( + parameters=stitched_module.parameters(), + clip_value=grad_clip, + ) + else: + raise RuntimeError(f"Invalid {cfg.bypass.training.grad_clip_type}") + + assert grad_scaler is not None + grad_scaler.step(optimizer) + grad_scaler.update() + optimizer.zero_grad(set_to_none=True) + + # Single GPU→CPU sync for all per-block losses collected above. Stacking + # into a 1-D tensor lets us issue exactly one ``.to("cpu")`` instead of + # one per block. + if iter_loss_tensors: + loss_stack = torch.stack([t.flatten()[0] for t in iter_loss_tensors.values()]) + iter_stitched_module_losses: dict[str, float] = dict( + zip(iter_loss_tensors.keys(), loss_stack.to("cpu").tolist()) + ) + else: + iter_stitched_module_losses = {} + + if dist.is_master() and cfg.bypass.iter_num == resumed_iter_num: + mprint(f"Starting from iter {cfg.bypass.iter_num}") + + # Buffer this rank's per-block losses locally. The collide-across-ranks + # gather happens only at log-chunk boundaries (`_flush_loss_buffer`), + # which cuts the per-iter pickle-based all_gather_object cost down to + # one gather per `iters_per_log_chunk` micro-batches. + local_losses_buffer[cfg.bypass.iter_num] = iter_stitched_module_losses + if len(local_losses_buffer) >= iters_per_log_chunk: + _flush_loss_buffer(local_losses_buffer, stitched_losses_history) + local_losses_buffer.clear() + + cfg.bypass.token_count += cfg.bypass.training.tokens_per_iter + iter_t1 = time.time() + iter_duration = iter_t1 - iter_t0 + iter_stats_history[cfg.bypass.iter_num] = IterStatistics( + token_count=cfg.bypass.token_count, + iter_duration=iter_duration, + step_num=cfg.bypass.step_num, + lr=lr, + clipping_count=cfg.bypass.training.clipping_count, + ) + iter_t0 = iter_t1 + + # Time-based save signal (broadcast from master) + save_signal = [step_to_save] + if dist.is_master(): + if cfg.bypass.model.model_overrides.save_interval_seconds is not None: + time_now = time.time() + if ( + time_now - time_last_save + >= cfg.bypass.model.model_overrides.save_interval_seconds + ): + mprint( + f"Time to save! 
{cfg.bypass.model.model_overrides.save_interval_seconds=}, " + f"{time_last_save=}, {time_now=}" + ) + step_to_save = cfg.bypass.step_num + 5 + save_signal = [step_to_save] + time_last_save = time_now + + torch.distributed.broadcast_object_list(save_signal, src=0) + step_to_save = save_signal[0] + + # Logging + if dist.is_master(): + assert stitched_losses_history is not None + # `iters_per_log_chunk` is computed once before the loop (in + # micro-batch units = log_interval × grad_accum) and reused for + # both the gather-batching threshold and this log drain. + while len(stitched_losses_history) >= iters_per_log_chunk: + lowest_iter = next(iter(stitched_losses_history.keys())) + + log_chunk = { + it: losses + for it, losses in stitched_losses_history.items() + if it - lowest_iter < iters_per_log_chunk + } + if len(log_chunk) < iters_per_log_chunk: + break + + highest_iter = list(log_chunk.keys())[-1] + highest_iter_stats = iter_stats_history[highest_iter] + + losses_by_name = defaultdict[str, list[float]](list) + for losses in log_chunk.values(): + for name, loss in losses.items(): + losses_by_name[name].append(loss) + + losses_by_name_avg = {name: mean(losses) for name, losses in losses_by_name.items()} + non_finite_losses_by_name = { + name: loss + for name, loss in losses_by_name_avg.items() + if not math.isfinite(loss) + } + if non_finite_losses_by_name: + cfg.bypass.non_finite_losses_by_name = dict(non_finite_losses_by_name) + mprint(f"Non-finite stitched losses detected: {non_finite_losses_by_name}") + + # Anchor "Δ from initial" at the very first iter's per-block losses + # (lowest_iter — typically iter 1 on a fresh run, the resumed iter + # otherwise). Using the first chunk's *average* would tautologically + # make Δ == 0 on the first row, since "Loss Value" is that same average. + if not initial_losses_by_name: + initial_losses_by_name.update(stitched_losses_history[lowest_iter]) + + # Update best losses tracking. Record the optimizer-step number + # so the "Best Step" column matches the header's "step N/max" units. + for name, current_loss in losses_by_name_avg.items(): + if not math.isfinite(current_loss): + continue + if name not in best_losses_by_name or current_loss < best_losses_by_name[name]: + best_losses_by_name[name] = current_loss + best_steps_by_name[name] = highest_iter_stats.step_num + + # Mirror to cfg.bypass so save_bypass_checkpoint's args.json snapshot + # carries these forward across resumes. + cfg.bypass.best_losses_by_name = dict(best_losses_by_name) + cfg.bypass.best_steps_by_name = dict(best_steps_by_name) + cfg.bypass.initial_losses_by_name = dict(initial_losses_by_name) + + chunk_iter_durations = [ + iter_stats_history[it].iter_duration for it in log_chunk.keys() + ] + avg_chunk_iter_duration = mean(chunk_iter_durations) + # Report time in step units (= grad_accum × per-iter), since one + # step is one optimizer update — what the user actually thinks of + # as "a training step." Tokens/sec is invariant to that framing. 
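+ # Illustrative (hypothetical numbers): grad_accumulation_steps=4 and an average
+ # micro-batch duration of 0.5 s give avg_step_time = 2.0 s, while avg_token_speed
+ # is still tokens_per_iter / 0.5, unaffected by accumulation.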
+ avg_step_time = ( + avg_chunk_iter_duration * cfg.bypass.training.grad_accumulation_steps + ) + avg_token_speed = cfg.bypass.training.tokens_per_iter / avg_chunk_iter_duration + mprint( + f"step {highest_iter_stats.step_num}/{cfg.bypass.training.max_steps:,}:" + f" avg_step_time={avg_step_time * 1000:.2f}ms" + f" avg_token_speed={avg_token_speed:,.0f}[tok/s]" + ) + mprint( + format_stitched_losses( + losses_dict=losses_by_name_avg, + best_steps_dict=best_steps_by_name, + best_values_dict=best_losses_by_name, + initial_values_dict=initial_losses_by_name, + not_trainable_names=non_trainable_stitched_module_names, + step_number=highest_iter_stats.step_num, + title="Stitched Module Losses", + ) + ) + + if cfg.bypass.wandb_log: + try: + import wandb + + wandb.log( + { + "step": highest_iter_stats.step_num, + "token_count": highest_iter_stats.token_count, + "token_speed": avg_token_speed, + "lr": highest_iter_stats.lr, + "grad_clipping": highest_iter_stats.clipping_count, + }, + step=highest_iter_stats.step_num, + ) + except ImportError: + pass + + for it in log_chunk.keys(): + del iter_stats_history[it] + del stitched_losses_history[it] + + # Validation + if ( + not is_accumulating + and (cfg.bypass.step_num % cfg.bypass.training.eval_interval) == 0 + and val_dataloader is not None + ): + from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( + calculate_losses_pipeline, + ) + + losses, _ = calculate_losses_pipeline( + stitched_model=student_stitched_model, + dataloader=val_dataloader, + descriptor=descriptor, + ) + + val_loss = float("inf") + if losses is not None and "lm_loss" in losses: + val_loss = losses["lm_loss"]["avg"] + mprint(f"Validation loss at iter {cfg.bypass.iter_num}: {val_loss:.4f}") + + # Broadcast val_loss so all ranks agree on checkpoint decisions + val_loss_tensor = torch.tensor([val_loss], device=device) + torch.distributed.broadcast(val_loss_tensor, src=dist.size() - 1) + val_loss = val_loss_tensor.item() + + if val_loss < cfg.bypass.best_val_loss: + cfg.bypass.best_val_loss = val_loss + if not cfg.bypass.disable_checkpoint_save and cfg.bypass.save_best_ckpt: + subdir_name = f"best-step-{cfg.bypass.step_num:06d}-ckpt" + _save_training_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_role="best", + subdir_name=subdir_name, + cleanup_glob="best-step-*", + ) + if cfg.bypass.kill_after_first_save: + raise RuntimeError("Done saving checkpoint, kill_after_first_save=True") + + # Checkpoint saving (step-based or time-based) + if not is_accumulating and ( + (cfg.bypass.step_num % cfg.bypass.model.model_overrides.save_interval) == 0 + or step_to_save == cfg.bypass.step_num + ): + if not cfg.bypass.disable_checkpoint_save: + if (cfg.bypass.step_num % cfg.bypass.model.model_overrides.save_interval) == 0: + mprint("Saving step-interval checkpoint") + elif step_to_save == cfg.bypass.step_num: + mprint("Saving time-based checkpoint") + + subdir_name = f"step-{cfg.bypass.step_num:06d}-ckpt" + _save_training_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_role="resume", + subdir_name=subdir_name, + cleanup_glob="step-*", + ) + + if cfg.bypass.kill_after_first_save: + dist.barrier() + raise RuntimeError("Done saving checkpoint, kill_after_first_save=True") + + cfg.bypass.iter_num += 1 + if not is_accumulating: + cfg.bypass.step_num += 1 + + mprint("Finished successfully!") + + +# 
Learning rate decay scheduler (cosine with warmup) +def _get_lr(cfg: DictConfig, step: int) -> float: + warmup_steps = cfg.bypass.training.warmup_steps + lr_decay_steps = cfg.bypass.training.lr_decay_steps + # Degenerate budget (e.g. tiny `training_tokens` in tests): no room for cosine decay. + # Skip warmup/decay entirely and return base LR — avoids ZeroDivisionError on + # `lr_decay_steps - warmup_steps` and `step / warmup_steps`. + if lr_decay_steps <= warmup_steps: + return cfg.bypass.training.learning_rate + + # 1) linear warmup for warmup_steps steps + if step <= warmup_steps: + if warmup_steps == 0: + # Defensive: training loop's step starts at 1 so this branch is + # unreachable today, but a future caller passing step=0 would hit + # a ZeroDivisionError on `step / warmup_steps` below. + return cfg.bypass.training.learning_rate + lr = cfg.bypass.training.learning_rate * step / warmup_steps + # 2) if step > lr_decay_steps, return min learning rate + elif step > lr_decay_steps: + lr = cfg.bypass.training.min_lr + # 3) in between, use cosine decay down to min learning rate + else: + decay_ratio = (step - warmup_steps) / (lr_decay_steps - warmup_steps) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + lr = cfg.bypass.training.min_lr + coeff * ( + cfg.bypass.training.learning_rate - cfg.bypass.training.min_lr + ) + + return lr + + +def run_bypassed_training(cfg: DictConfig): + """Setup and orchestrate bypass distillation training.""" + logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.WARN + ) + + # Suppress debug messages from HuggingFace libraries + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + device = torch.device(f"cuda:{dist.local_rank()}") + + set_experiment_id(cfg) + set_experiment_dir(cfg) + if bypass_run_is_complete(cfg): + mprint(f"Bypass run {cfg.bypass.experiment_id} is already complete, skipping") + return + + descriptor = ModelDescriptorFactory.get(cfg.descriptor) + trust_remote_code = descriptor.requires_trust_remote_code() + teacher_model_config = load_model_config(cfg.teacher_dir, trust_remote_code=trust_remote_code) + + try: + mprint("Waiting for distributed setup...") + dist.barrier() + + if cfg.bypass.disable_initial_validate: + cfg.bypass.validate_teacher_model = False + cfg.bypass.validate_student_model = False + + if cfg.bypass.teacher_model_load_on_cpu: + assert not cfg.bypass.validate_teacher_model, ( + "Teacher model validation is too slow on CPU" + ) + + num_hidden_layers = descriptor.get_language_model_config( + teacher_model_config + ).num_hidden_layers + + model_blocks_process_ownership = get_distributed_modules_ownership( + module_count=num_hidden_layers, + world_size=dist.size(), + ) + + owned_block_indexes = set( + block_index + for block_index, owner_rank in enumerate(model_blocks_process_ownership) + if owner_rank == dist.rank() + ) + + cfg.teacher_dir = str(Path(cfg.teacher_dir).expanduser()) + teacher_model_config = load_model_config( + cfg.teacher_dir, + trust_remote_code=trust_remote_code, + ) + # Disable KV cache during bypass forward passes. Set the attribute directly rather + # than passing it as an AutoConfig override — some custom configs (GptOss, Qwen3-VL, etc.) + # don't accept it as a known kwarg and would raise via the strict unused-kwargs check. 
+ if hasattr(teacher_model_config, "use_cache"): + teacher_model_config.use_cache = False + if hasattr(teacher_model_config, "text_config") and hasattr( + teacher_model_config.text_config, "use_cache" + ): + teacher_model_config.text_config.use_cache = False + + # Resume detection has to run BEFORE the weight-loading branch below + # so a resume can route through ``load_and_shard_model`` (the HF + # checkpoint at ``resume_checkpoint_path`` is now the single source + # of truth for weights — see _save_local_state docstring). + # set_experiment_id / set_experiment_dir are idempotent and only + # depend on cfg.bypass.model.model_config_overrides + cfg.puzzle_dir, + # so it's safe to call them this early. + resume_checkpoint_path: Optional[str] = None + resume_cfg: Optional[DictConfig] = None + resume_skip_first_batches = cfg.bypass.training.skip_first_batches + if cfg.bypass.resume_checkpoint_path is not None: + resume_checkpoint_path = cfg.bypass.resume_checkpoint_path + elif cfg.bypass.find_last_ckpt_for_resume: + _ckpt_dir = find_latest_run_dir(run_parent_dir=cfg.bypass.experiment_dir) + if _ckpt_dir is None: + mprint("Couldn't find any run dir for resume, assuming this is the first job") + else: + mprint( + f"`cfg.bypass.find_last_ckpt_for_resume` is True. " + f"Auto-found a checkpoint to resume: `{_ckpt_dir}`" + ) + resume_checkpoint_path = _ckpt_dir + + if resume_checkpoint_path: + resume_cfg = DictConfig(json_load(Path(resume_checkpoint_path) / "args.json")) + saved_skip = resume_cfg.training.get( + "skip_first_batches", cfg.bypass.training.skip_first_batches + ) + resume_skip_first_batches = saved_skip + resume_cfg.iter_num + if "data" in resume_cfg and "shuffle_train_data_seed" in resume_cfg.data: + cfg.bypass.data.shuffle_train_data_seed = resume_cfg.data.shuffle_train_data_seed + if "seed" in resume_cfg: + cfg.bypass.seed = resume_cfg.seed + + # Both ``init_checkpoint_path`` and ``resume_checkpoint_path`` point at + # an HF-format directory; share the same loader. ``init_checkpoint_path`` + # wins if both are set (explicit user override beats auto-detect). + weight_load_path = cfg.bypass.init_checkpoint_path or resume_checkpoint_path + student_model = None + if weight_load_path is not None: + mprint(f"Loading student model from {weight_load_path}") + student_model = load_and_shard_model( + descriptor=descriptor, + checkpoint_path=weight_load_path, + owned_block_indexes=owned_block_indexes, + ) + + cfg.bypass.training.min_lr = ( + cfg.bypass.training.learning_rate * cfg.bypass.training.min_lr_factor + ) + cfg.bypass.training.batch_size_per_iter = cfg.bypass.training.micro_batch_size + cfg.bypass.training.tokens_per_iter = ( + cfg.bypass.data.block_size * cfg.bypass.training.batch_size_per_iter + ) + requested_iters = math.ceil( + cfg.bypass.training.training_tokens / cfg.bypass.training.tokens_per_iter + ) + # The loop steps optimizers only after a full grad-accum window, so round + # the requested token budget up to complete optimizer-step units and report + # that actual budget back to the user. 
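+ # Worked example (hypothetical values): training_tokens=1e9, block_size=8192,
+ # micro_batch_size=4 gives tokens_per_iter=32768 and requested_iters=30518;
+ # with grad_accumulation_steps=8 that rounds to max_steps=3815, max_iters=30520,
+ # i.e. a max_token_count slightly above the requested 1e9.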
+ cfg.bypass.training.max_steps = math.ceil( + requested_iters / cfg.bypass.training.grad_accumulation_steps + ) + cfg.bypass.training.max_iters = ( + cfg.bypass.training.max_steps * cfg.bypass.training.grad_accumulation_steps + ) + cfg.bypass.training.max_token_count = ( + cfg.bypass.training.max_iters * cfg.bypass.training.tokens_per_iter + ) + cfg.bypass.training.lr_decay_steps = cfg.bypass.training.max_steps + + if cfg.bypass.training.val_micro_batch_size is None: + cfg.bypass.training.val_micro_batch_size = cfg.bypass.training.micro_batch_size + + if cfg.bypass.training.warmup_steps is None: + cfg.bypass.training.warmup_steps = 0 + + mprint(f"\n{format_global_config(cfg.bypass, 'Bypass Configurations')}") + mprint(f"Max token count: {cfg.bypass.training.max_token_count:,}") + + seed = cfg.bypass.seed + torch.manual_seed(seed) + + tokenizer = AutoTokenizer.from_pretrained( + cfg.teacher_dir, + trust_remote_code=trust_remote_code, + token=True, + ) + + assert teacher_model_config is not None + + mprint(f"Load and shard model with: {owned_block_indexes=}, {cfg.teacher_dir=}") + teacher_model = load_and_shard_model( + descriptor=descriptor, + checkpoint_path=cfg.teacher_dir, + owned_block_indexes=owned_block_indexes, + model_config=teacher_model_config, + ) + + teacher_model.requires_grad_(False) + + # Create dataloaders + from modelopt.torch.puzzletron.utils.data.dataloaders import ( + create_train_dataloader, + create_validation_dataloader, + load_from_disk_fn, + load_streaming_fn, + ) + + if cfg.bypass.data.eval_samples_per_process is not None: + max_eval_samples = cfg.bypass.data.eval_samples_per_process * dist.size() + else: + max_eval_samples = cfg.bypass.data.max_eval_samples + + load_dataset_fn = ( + load_streaming_fn if not cfg.bypass.data.load_from_disk else load_from_disk_fn + ) + + # Only master ever fetches from the train dataloader (training_loop.train + # gates `next(train_iterator)` on `dist.is_master()`), so skip the + # potentially-large HF dataset load + tokenisation on non-master ranks. + if dist.is_master(): + train_dataloader = create_train_dataloader( + seed=seed, + tokenizer=tokenizer, + block_size=cfg.bypass.data.block_size, + dataset_path=cfg.dataset_path, + content_field=cfg.bypass.data.data_column, + fim_rate=cfg.bypass.data.fim_rate, + fim_spm_rate=cfg.bypass.data.fim_spm_rate, + micro_batch_size=cfg.bypass.training.micro_batch_size, + load_dataset_fn=load_dataset_fn, + keep_in_memory=cfg.bypass.data.keep_in_memory, + source_datasets_to_discard=cfg.bypass.data.get( + "source_datasets_to_discard", tuple() + ), + bos_rate=cfg.bypass.data.bos_rate, + shuffle_seed=cfg.bypass.data.shuffle_train_data_seed, + ) + else: + train_dataloader = None + + val_dataloader = None + # Note: val_dataloader is kept constructed on every rank even though only + # master reads from it inside calculate_losses_pipeline. The validation + # block uses `val_dataloader is not None` as a "validation enabled" gate + # that must agree across ranks — and calculate_losses_pipeline itself is + # pipeline-parallel and requires every rank to enter it. Skipping + # construction on non-master ranks would break those invariants. 
+ if not cfg.bypass.disable_validation: + val_dataloader = create_validation_dataloader( + accelerator=None, + seed=seed, + tokenizer=tokenizer, + block_size=cfg.bypass.data.block_size, + dataset=cfg.dataset_path, + content_field=cfg.bypass.data.data_column, + fim_rate=cfg.bypass.data.fim_rate, + fim_spm_rate=cfg.bypass.data.fim_spm_rate, + micro_batch_size=cfg.bypass.training.val_micro_batch_size, + eval_samples=max_eval_samples, + load_dataset_fn=load_dataset_fn, + dataset_name=cfg.bypass.data.val_dataset_name, + keep_in_memory=cfg.bypass.data.keep_in_memory, + source_datasets_to_discard=cfg.bypass.data.get( + "source_datasets_to_discard", tuple() + ), + bos_rate=cfg.bypass.data.bos_rate, + ) + + # set_experiment_id / set_experiment_dir already ran above (before + # weight loading) so the resume detection could use experiment_dir. + + dist.barrier() + + with torch.device(device): + stitched_model_factory_fn = getattr( + stitched_model_factory_module, cfg.bypass.model_factory.factory + ) + ( + student_model, + teacher_stitched_model, + teacher_val_stitched_module, + student_val_stitched_model, + stitched_module_descriptors, + student_model_config, + ) = stitched_model_factory_fn( + teacher_model=teacher_model, + descriptor=descriptor, + cfg=cfg.bypass, + model_blocks_process_ownership=model_blocks_process_ownership, + student_model=student_model, + ) + + # ``resume_checkpoint_path`` was determined earlier (before weight + # loading); the student weights are already in place via + # ``load_and_shard_model``. Only the optimizer/scaler state needs to + # be restored from the per-block ``stitched/`` files. + if resume_checkpoint_path: + load_local_state( + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_path=resume_checkpoint_path, + ) + + assert resume_cfg is not None + + # Periodic checkpoints are saved before the loop increments counters, + # so their args.json is inclusive and needs a +1 bump. Final + # checkpoints are saved after the loop already advanced beyond the + # last completed step, so their counters are already the next values. + resume_from_final = Path(resume_checkpoint_path).name.startswith("final-step-") + counter_bump = 0 if resume_from_final else 1 + cfg.bypass.iter_num = resume_cfg.iter_num + counter_bump + cfg.bypass.token_count = resume_cfg.token_count + cfg.bypass.step_num = resume_cfg.step_num + counter_bump + cfg.bypass.best_val_loss = resume_cfg.best_val_loss + cfg.bypass.training.clipping_count = resume_cfg.training.clipping_count + # Per-block bookkeeping. .get() defaults handle resume from older ckpts + # that predate these fields. 
+ cfg.bypass.best_losses_by_name = resume_cfg.get("best_losses_by_name", {}) + cfg.bypass.best_steps_by_name = resume_cfg.get("best_steps_by_name", {}) + cfg.bypass.initial_losses_by_name = resume_cfg.get("initial_losses_by_name", {}) + mprint(f"Resume from iter_num: {cfg.bypass.iter_num}") + + # Only copy wandb.run_id if it exists in resume config + if hasattr(resume_cfg, "wandb") and hasattr(resume_cfg.wandb, "run_id"): + cfg.bypass.wandb.run_id = resume_cfg.wandb.run_id + + cfg.bypass.save_checkpoint_before_training = False + cfg.bypass.validate_teacher_model = False + cfg.bypass.validate_student_model = False + + cfg.bypass.resume_checkpoint_path = resume_checkpoint_path + + # Initialize Weights and Biases + if cfg.bypass.wandb_log: + try: + import wandb + + wandb.init( + project=cfg.bypass.wandb.project, + entity=cfg.bypass.wandb.entity, + config=dict(cfg.bypass), + ) + except ImportError: + mprint("wandb not installed, disabling wandb logging") + cfg.bypass.wandb_log = False + else: + mprint("Weights & Biases logging disabled (wandb_log=False)") + + if cfg.bypass.validate_teacher_model and val_dataloader is not None: + from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( + calculate_losses_pipeline, + ) + + mprint("Evaluating teacher model:") + losses, _ = calculate_losses_pipeline( + stitched_model=teacher_val_stitched_module, + dataloader=val_dataloader, + descriptor=descriptor, + ) + if losses is not None: + mprint(f"Teacher validation losses: {losses}") + mprint("Evaluated teacher model") + + torch.cuda.empty_cache() + dist.barrier() + + parameter_count = sum(p.numel() for p in student_model.parameters()) + aprint(f"Model parameter count: {parameter_count:,}") + cfg.bypass.parameter_count = parameter_count + + dist.barrier() + mprint("Performing dummy runs on stitched modules:") + torch.cuda.synchronize() + with ( + torch.no_grad(), + _autocast_context(descriptor), + torch.device(device), + ): + input_ids = torch.ones( + (cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size), + dtype=torch.long, + ) + dummy_fake_input_ids = fake_tensor(input_ids) + mprint(f"Dummy runs on stitched modules with shape: {dummy_fake_input_ids.shape=}") + teacher_output = teacher_stitched_model({}, {}, input_ids) + for stitched_module_descriptor in stitched_module_descriptors.values(): + stitched_module = stitched_module_descriptor.stitched_module + stitched_module( + input_overrides={ + **teacher_output.captured_inputs, + "teacher_inputs": InputArgs(dummy_fake_input_ids), + }, + output_overrides=teacher_output.captured_outputs, + ) + for name, param in stitched_module.named_parameters(recurse=True): + if "iter_num" in name: + param.data = torch.zeros_like(param.data) + del name, param + del input_ids, dummy_fake_input_ids, teacher_output + torch.cuda.synchronize() + dist.barrier() + + del teacher_model + + if cfg.bypass.validate_student_model and val_dataloader is not None: + from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( + calculate_losses_pipeline, + ) + + mprint("Validating model before training:") + losses, _ = calculate_losses_pipeline( + stitched_model=student_val_stitched_model, + dataloader=val_dataloader, + descriptor=descriptor, + ) + if losses is not None: + mprint(f"Student validation losses: {losses}") + + dist.barrier() + torch.cuda.empty_cache() + dist.barrier() + + train( + cfg=cfg, + descriptor=descriptor, + student_model=student_model, + student_stitched_model=student_val_stitched_model, + 
teacher_stitched_model=teacher_stitched_model, + stitched_module_descriptors=stitched_module_descriptors, + stitched_modules_process_ownership=model_blocks_process_ownership, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + student_model_config=student_model_config, + skip_first_batches=resume_skip_first_batches, + tokenizer=tokenizer, + ) + + aprint("Finished training successfully!") + dist.barrier() + + except Exception: + # Print the traceback explicitly so distributed runs surface it on every + # rank's stderr (workers under torchrun otherwise lose ordering), then + # re-raise so test frameworks see the real exception instead of a + # generic SystemExit(1). + print(traceback.format_exc(), file=sys.stderr) + raise + + dist.barrier() + if dist.is_master(): + mprint("Realizing bypass checkpoints") + realized_checkpoint, ckpts_symlink = realize_bypass_checkpoints(cfg) + mark_bypass_run_completed(cfg, realized_checkpoint, ckpts_symlink) + dist.barrier() + + +def realize_bypass_checkpoints(cfg: DictConfig) -> tuple[Path, Path]: + """Create symlinks from bypass checkpoint directories to the ckpts directory.""" + state = load_bypass_state(cfg.bypass.experiment_dir) or {} + checkpoints = state.get("checkpoints", {}) + realize_mode = cfg.bypass.get("realize_best_or_latest", "latest") + if realize_mode == "best": + role_preference = ("best", "final", "resume") + elif realize_mode == "latest": + role_preference = ("final", "resume", "best") + else: + raise ValueError(f"Invalid bypass.realize_best_or_latest={realize_mode!r}") + + checkpoint_dir = None + for role in role_preference: + candidate = checkpoints.get(role) + if candidate and Path(candidate).exists(): + checkpoint_dir = Path(candidate) + break + + if checkpoint_dir is None: + fallback = Path(cfg.bypass.experiment_dir) / "latest" + if fallback.exists(): + checkpoint_dir = fallback + else: + raise FileNotFoundError( + f"Could not find a bypass checkpoint to realize in {cfg.bypass.experiment_dir}" + ) + + ckpts_dir = Path(cfg.puzzle_dir) / "ckpts" + ckpts_dir.mkdir(parents=True, exist_ok=True) + + symlink_name = ckpts_dir / cfg.bypass.experiment_id + if symlink_name.exists() or symlink_name.is_symlink(): + symlink_name.unlink() + + symlink_name.symlink_to(checkpoint_dir, target_is_directory=True) + mprint(f"Created symlink: {symlink_name} -> {checkpoint_dir}") + return checkpoint_dir, symlink_name diff --git a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py index 740d1fada3c..e46f615f6f6 100644 --- a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py +++ b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py @@ -24,7 +24,12 @@ ) from .pruning_mixin import LayerDescriptor, PruningMixIn -from .pruning_utils import GQAInitMode, _init_attention_biases, _init_attention_weights +from .pruning_utils import ( + GQAInitMode, + _init_attention_biases, + _init_attention_weights, + _lm_head_dim, +) __all__ = [ "KVHeadsLayerDescriptor", @@ -74,7 +79,7 @@ def prune_single_layer( f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names ] - head_size = new_config.head_dim + head_size = _lm_head_dim(new_config) for part in ["weight", "bias"]: attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]] q_key, k_key, v_key, o_key = attn_keys diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py index c600e119cfa..3b8e94347cb 100644 --- 
a/modelopt/torch/puzzletron/pruning/pruning_utils.py +++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py @@ -52,6 +52,7 @@ class MlpInitMode(Enum): PruneByActivationsLog = "PruneByActivationsLog" ExpertRemoval = "ExpertRemoval" ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" + MoEChannelPruning = "MoEChannelPruning" class LinearInitMode(Enum): @@ -66,6 +67,30 @@ class HiddenSizeInitMode(Enum): CopyAsIs = "CopyAsIs" +def _lm_attrs(config): + """Return the language-model sub-config for VL configs, else the config itself. + + VL configs nest language-model fields like ``num_attention_heads``, ``head_dim``, + and ``hidden_size`` under a sub-config. The attribute name varies by family — + ``text_config`` (Qwen3-VL, Llava, Idefics), ``language_config`` (Llama-4 and a + handful of others), and ``llm_config`` (InternVL and friends) are all common. + Probe each before falling back to the raw config. + """ + for attr in ("text_config", "language_config", "llm_config"): + sub = getattr(config, attr, None) + if sub is not None: + return sub + return config + + +def _lm_head_dim(config) -> int: + lm_config = _lm_attrs(config) + head_dim = getattr(lm_config, "head_dim", None) + if head_dim is not None: + return head_dim + return lm_config.hidden_size // lm_config.num_attention_heads + + def resolve_pruning_mixin( pruning_mixin, descriptor: Type[ModelDescriptor] ) -> PruningMixIn | List[PruningMixIn]: @@ -224,10 +249,13 @@ def _init_attention_weights( head_size, mlp_init_config, ): - assert new_config.num_attention_heads == original_config.num_attention_heads, ( - f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + new_lm = _lm_attrs(new_config) + orig_lm = _lm_attrs(original_config) + assert new_lm.num_attention_heads == orig_lm.num_attention_heads, ( + f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})" ) - num_q_heads = new_config.num_attention_heads + num_q_heads = new_lm.num_attention_heads + # block_configs lives on the outer puzzletron-converted config, not on text_config. num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads @@ -372,17 +400,29 @@ def _init_attention_biases( head_size, mlp_init_config, ): - assert new_config.num_attention_heads == original_config.num_attention_heads, ( - f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + new_lm = _lm_attrs(new_config) + orig_lm = _lm_attrs(original_config) + assert new_lm.num_attention_heads == orig_lm.num_attention_heads, ( + f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})" ) - num_q_heads = new_config.num_attention_heads + num_q_heads = new_lm.num_attention_heads + # block_configs lives on the outer puzzletron-converted config, not on text_config. num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads n_heads_in_group = num_q_heads // num_kv_heads orig_n_heads_in_group = num_q_heads // orig_num_kv_heads - o_proj_bias = new_config.o_proj_bias - attention_bias = new_config.attention_bias + # Some HF native configs (e.g. GptOssConfig) don't expose o_proj_bias / attention_bias as + # top-level attributes the way puzzletron's DeciLM-style configs do. Fall back to probing + # the new state dict for the actual bias keys when the attribute is missing. 
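+ # (Hypothetical example: if the config exposes neither attribute and new_state_dict contains
+ # an "...o_proj.bias" entry but no "...q_proj.bias", the probe resolves o_proj_bias=True and
+ # attention_bias=False.)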
+ # KVHeadsPruningMixIn only calls this helper after filtering to keys present in + # new_state_dict, so the probe mirrors the caller's already-selected bias tensors. + o_proj_bias = getattr(new_config, "o_proj_bias", None) + if o_proj_bias is None: + o_proj_bias = o_key in new_state_dict + attention_bias = getattr(new_config, "attention_bias", None) + if attention_bias is None: + attention_bias = q_key in new_state_dict # If no biases if not (o_proj_bias or attention_bias): @@ -438,8 +478,8 @@ def _init_attention_biases( assert not is_original_mha, ( "Degrouping can only be done on original models that are GQA themselves." ) - n_groups = new_config.num_attention_heads // n_heads_in_group - orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group + n_groups = new_lm.num_attention_heads // n_heads_in_group + orig_n_groups = orig_lm.num_attention_heads // orig_n_heads_in_group assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" n_repeats = n_groups // orig_n_groups if n_repeats > 1: diff --git a/modelopt/torch/puzzletron/sewing_kit/passage.py b/modelopt/torch/puzzletron/sewing_kit/passage.py index d8fa1f51cf9..c77b9dd41cd 100644 --- a/modelopt/torch/puzzletron/sewing_kit/passage.py +++ b/modelopt/torch/puzzletron/sewing_kit/passage.py @@ -45,6 +45,7 @@ "PassageOutput", "Predicate", "always_false_predicate", + "always_true_predicate", "Passage", "patch_module", ] diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py index 3db63f60013..106b0b3e4c3 100644 --- a/modelopt/torch/puzzletron/sewing_kit/utils.py +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -16,6 +16,7 @@ from __future__ import annotations import inspect +import operator from contextlib import contextmanager from typing import ( TYPE_CHECKING, @@ -451,3 +452,95 @@ def _get_group_kwarg_if_necessary() -> dict: torch.distributed.distributed_c10d._object_to_tensor ).parameters.keys() return dict(group=None) if "group" in arg_names else dict() + + +# ────────────────────────────────────────────────────────────────────────────── +# Loss functions for bypass distillation (blockwise local knowledge distillation) +# ────────────────────────────────────────────────────────────────────────────── + +# `normalized_mse_loss` already lives in tools.kd_model — re-export it here so +# bypass-distillation imports stay co-located with the per-vector / per-batch +# variants below, without duplicating the implementation. The `as +# normalized_mse_loss` form is PEP 484's explicit re-export (mypy treats +# `from X import Y` as a private import otherwise). +from modelopt.torch.puzzletron.tools.kd_model import ( # noqa: E402 + normalized_mse_loss as normalized_mse_loss, +) + + +def vectorwise_normalized_mse_loss( + input: torch.Tensor, + target: torch.Tensor, + epsilon: float = 1e-6, +) -> torch.Tensor: + """Like normalized_mse_loss, but normalization is done per-vector (last dim), then averaged.""" + return batched_normalized_mse_loss(input, target, epsilon, batch_dims=range(input.ndim - 1)) + + +def batched_normalized_mse_loss( + input: torch.Tensor, + target: torch.Tensor, + epsilon: float = 1e-6, + batch_dims: Sequence[int] = (0,), +) -> torch.Tensor: + """Per-batch-element relative-L2 loss. + + For each batch element, computes ``||input - target||^2 / (||target||^2 + eps)`` + over the non-batch dims, then averages across batch elements. 
The additive + ``epsilon`` in the denominator handles all-zero target slices without a hard + clamp and makes the loss scale-invariant when ``||target||^2 >> eps``. + """ + input_shape = tuple(input.shape) + target_shape = tuple(target.shape) + + if epsilon <= 0: + raise ValueError(f"epsilon must be strictly positive, got {epsilon!r}") + + try: + raw_batch_dims = tuple(operator.index(dim) for dim in batch_dims) + except TypeError as exc: + raise ValueError( + f"batch_dims must be an iterable of integer dimensions; got {batch_dims!r} " + f"for input shape {input_shape} and target shape {target_shape}" + ) from exc + + resolved_batch_dims = [] + for dim in raw_batch_dims: + if dim < -input.ndim or dim >= input.ndim: + raise ValueError( + f"batch_dims contains invalid dimension {dim} for input.ndim={input.ndim}; " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={raw_batch_dims}, norm_dims=None" + ) + resolved_batch_dims.append(dim % input.ndim) + + if len(set(resolved_batch_dims)) != len(resolved_batch_dims): + raise ValueError( + f"batch_dims contains duplicate dimensions after normalization; " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={tuple(resolved_batch_dims)}, norm_dims=None" + ) + + norm_dims = tuple(d for d in range(input.ndim) if d not in set(resolved_batch_dims)) + + if input.ndim != target.ndim: + raise ValueError( + f"input and target must have the same number of dimensions; " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={tuple(resolved_batch_dims)}, norm_dims={norm_dims}" + ) + if input_shape != target_shape: + mismatched_dims = tuple( + dim + for dim, (input_size, target_size) in enumerate(zip(input_shape, target_shape)) + if input_size != target_size + ) + raise ValueError( + f"input and target shapes must match exactly; mismatched_dims={mismatched_dims}, " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={tuple(resolved_batch_dims)}, norm_dims={norm_dims}" + ) + + num = ((input - target) ** 2).sum(dim=norm_dims) + den = (target**2).sum(dim=norm_dims) + epsilon + return (num / den).mean() diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py index b242c7d48ac..e041d884b0d 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -22,6 +22,8 @@ import os import re import time +from collections import ChainMap +from collections.abc import Iterator, MutableMapping from copy import deepcopy from functools import partial from pathlib import Path @@ -52,6 +54,49 @@ default_ignore_fn: IgnoreFn = lambda _: False +class _PerLayerKeysView(MutableMapping[str, str]): + def __init__(self, base: dict[str, str]) -> None: + self._base = base + self._overrides: dict[str, str] = {} + self._removed: dict[str, str] = {} + + def __getitem__(self, key: str) -> str: + if key in self._removed: + raise KeyError(key) + if key in self._overrides: + return self._overrides[key] + return self._base[key] + + def __setitem__(self, key: str, value: str) -> None: + self._removed.pop(key, None) + self._overrides[key] = value + + def __delitem__(self, key: str) -> None: + if key in self._removed: + raise KeyError(key) + if key in self._overrides: + self._removed[key] = self._overrides.pop(key) + elif key in self._base: + self._removed[key] = self._base[key] + else: + raise KeyError(key) + + def __iter__(self) -> 
Iterator[str]: + yield from self._overrides.keys() + for key in self._base: + if key not in self._overrides and key not in self._removed: + yield key + + def __len__(self) -> int: + return sum(1 for _ in self) + + def __contains__(self, key: object) -> bool: + return key not in self._removed and (key in self._overrides or key in self._base) + + def removed_items(self) -> dict[str, str]: + return dict(self._removed) + + class Printer: @staticmethod def print(s: str) -> None: @@ -83,27 +128,43 @@ def _process_single_layer( keys_to_remove = {} layer_out_state_dict = {} - # Delegate to pruning_mixin if available + # Delegate to pruning_mixin if available (supports a single mixin or a list of mixins). + # Mixins run sequentially. Each mixin sees the state dict produced by earlier mixins, + # which lets independent pruning methods compose on the same tensor (for example one + # pruning FFN channels and another pruning hidden-size dimensions). if pruning_mixin is not None: - _layer_out = pruning_mixin.prune_single_layer( - layer_idx=layer_idx, - parent_state_dict=parent_state_dict, - new_state_dict=new_state_dict, - original_config=original_config, - new_config=new_config, - gqa_init_mode=gqa_init_mode, - mlp_init_mode=mlp_init_mode, - mlp_init_config=mlp_init_config, - linear_init_mode=linear_init_mode, - ignored_keys=ignored_keys, - keys=keys, - is_original_mha=is_original_mha, - head_size=head_size, - hidden_size=hidden_size, - keys_to_remove=keys_to_remove, - ) - layer_out_state_dict.update(_layer_out) - return layer_out_state_dict, keys_to_remove + _mixins = pruning_mixin if isinstance(pruning_mixin, list) else [pruning_mixin] + merged_keys_to_remove = {} + parent_layer_updates = {} + new_layer_updates = {} + current_parent_state_dict = ChainMap(parent_layer_updates, parent_state_dict) + current_new_state_dict = ChainMap(new_layer_updates, new_state_dict) + current_keys = _PerLayerKeysView(keys) + for _mixin in _mixins: + mixin_keys_to_remove = {} + _layer_out = _mixin.prune_single_layer( + layer_idx=layer_idx, + parent_state_dict=current_parent_state_dict, + new_state_dict=current_new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=current_keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + keys_to_remove=mixin_keys_to_remove, + ) + layer_out_state_dict.update(_layer_out) + parent_layer_updates.update(_layer_out) + new_layer_updates.update(_layer_out) + merged_keys_to_remove.update(current_keys.removed_items()) + merged_keys_to_remove.update(mixin_keys_to_remove) + return layer_out_state_dict, merged_keys_to_remove # Legacy inline processing (fallback when no pruning_mixin) @@ -791,7 +852,10 @@ def update_model_config( def override(item, item_overrides): if item_overrides is None: - return item_overrides + # Hydra/OmegaConf ``null`` means "leave this field unchanged" in + # model_config_overrides. This lets compact overrides update only one + # sibling field without clearing the rest of the dataclass. 
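+            # Hypothetical YAML override (field names chosen for illustration only):
+            #   attention: {num_key_value_heads: 4, mamba: null}
+            # updates num_key_value_heads while the existing mamba sub-config stays untouched.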
+ return item if dataclasses.is_dataclass(item): assert isinstance(item_overrides, dict) return dataclass_override(item, item_overrides) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 1240d1c9b65..8d2a8c48710 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -21,7 +21,10 @@ import concurrent.futures import dataclasses import fcntl +import inspect import os +import re +import shutil import time from collections import defaultdict from collections.abc import Callable, Mapping @@ -88,10 +91,20 @@ def force_cache_dynamic_modules( and "AutoConfig" in config.auto_map.keys() ) if has_remote_code and trust_remote_code: - for class_reference in config.auto_map.values(): + for class_reference in _iter_auto_map_class_refs(config.auto_map): _ = get_class_from_dynamic_module(class_reference, checkpoint_dir) +def _iter_auto_map_class_refs(auto_map: Mapping[str, Any]): + for value in auto_map.values(): + if isinstance(value, str): + yield value + elif isinstance(value, (list, tuple)): + for item in value: + if isinstance(item, str): + yield item + + def load_model_config( checkpoint_dir: Path | str, model_config_overrides: Mapping | None = None, @@ -135,16 +148,23 @@ def load_model_config( return config +_FALLBACK_WARNED_CLASSES: set[str] = set() + + def _get_model_class_from_config(config: PretrainedConfig) -> type: """Resolve HuggingFace model class from ``config.architectures`` (see puzzletron checkpoint_utils_hf).""" if hasattr(config, "architectures") and config.architectures: model_class_name = config.architectures[0] if hasattr(transformers, model_class_name): return getattr(transformers, model_class_name) - mprint( - f"Warning: {model_class_name} not found in transformers, " - "falling back to AutoModelForCausalLM" - ) + # Warn at most once per missing class per process — the fallback path + # may be hit thousands of times during scoring/realize loops. + if model_class_name not in _FALLBACK_WARNED_CLASSES: + _FALLBACK_WARNED_CLASSES.add(model_class_name) + mprint( + f"Warning: {model_class_name} not found in transformers, " + "falling back to AutoModelForCausalLM" + ) return AutoModelForCausalLM @@ -209,10 +229,9 @@ def save_checkpoint_from_shards( """ Save a checkpoint when the model's weights are sharded across distributed ranks. - Gathers each rank's partial state dictionary onto rank 0 and writes a complete checkpoint - (including the safetensors index and subblocks) from the merged weights. On a single-process - run, saves directly from the local state dict. Only rank 0 performs the filesystem write; - non-master ranks only participate in the gather. + On distributed runs, rank 0 gathers only tensor-name metadata up front and then gathers + tensors one safetensors file at a time. This avoids materializing the full model from all + ranks on rank 0 while still producing a single HF-compatible checkpoint/index. Parameters: model (PreTrainedModel): The model instance whose local state_dict contains this rank's @@ -222,31 +241,112 @@ def save_checkpoint_from_shards( the safetensors index. 
""" - local_sd = {k: v.cpu() for k, v in model.state_dict().items()} + local_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()} if dist_utils.size() > 1: - save_err: str | None = None + _save_checkpoint_from_distributed_shards(model.config, local_sd, checkpoint_dir, descriptor) + dist_utils.barrier() + else: + _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor) + + +def _save_checkpoint_from_distributed_shards( + model_config: PretrainedConfig, + local_state_dict: dict[str, torch.Tensor], + checkpoint_dir: Path | str, + descriptor: "ModelDescriptor", +) -> None: + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + local_keys = list(local_state_dict.keys()) + gathered_keys: list[list[str] | None] | None = ( + [None] * dist_utils.size() if dist_utils.is_master() else None + ) + tdist.gather_object(local_keys, gathered_keys, dst=0) + + owner_by_key = None + weight_map = None + setup_err = None + if dist_utils.is_master(): + try: + assert gathered_keys is not None + checkpoint_dir.mkdir(parents=True, exist_ok=True) + save_model_config(model_config, checkpoint_dir) + + # Match the old full_sd.update(rank_order) behavior for duplicate tied + # weights by letting the highest rank that owns a key supply it. + owner_by_key = { + key: rank + for rank, keys in enumerate(gathered_keys) + if keys is not None + for key in keys + } + full_keys = list(owner_by_key) + + output_emb_weight_name = f"{descriptor.output_embedding_name()}.weight" + if getattr(model_config, "tie_word_embeddings", False): + owner_by_key.pop(output_emb_weight_name, None) + full_keys = [key for key in full_keys if key != output_emb_weight_name] + + lm_config = descriptor.get_language_model_config(model_config) + subblock_keys = descriptor.get_weight_groups( + layer_names=full_keys, + num_hidden_layers=lm_config.num_hidden_layers, + ) + weight_map = { + key: f"subblocks_safetensors/{subblock}.safetensors" + for subblock, layer_keys in subblock_keys.items() + for key in layer_keys + } + + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + _write_file_process_safe(json_dumps(index), checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME) + (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).mkdir(parents=True, exist_ok=True) + except Exception as e: + setup_err = repr(e) + owner_by_key = {} + weight_map = {} + + payload = [setup_err, owner_by_key, weight_map] + tdist.broadcast_object_list(payload, src=0) + setup_err, owner_by_key, weight_map = payload + if setup_err is not None: + raise RuntimeError(f"Checkpoint setup failed on rank 0: {setup_err}") + assert owner_by_key is not None + assert weight_map is not None + + for relative_filename in sorted(set(weight_map.values())): + local_file_tensors = { + key: tensor.contiguous() + for key, tensor in local_state_dict.items() + if owner_by_key.get(key) == dist_utils.rank() and weight_map.get(key) == relative_filename + } + gathered_tensors: list[dict[str, torch.Tensor] | None] | None = ( + [None] * dist_utils.size() if dist_utils.is_master() else None + ) + tdist.gather_object(local_file_tensors, gathered_tensors, dst=0) if dist_utils.is_master(): - gathered: list[dict] = [None] * dist_utils.size() - tdist.gather_object(local_sd, gathered, dst=0) - full_sd: dict[str, torch.Tensor] = {} - for shard_sd in gathered: - if shard_sd is None: - continue - full_sd.update(shard_sd) - try: - _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor) - except Exception as e: - save_err = repr(e) + assert gathered_tensors is 
not None + file_state_dict: dict[str, torch.Tensor] = {} + for shard_tensors in gathered_tensors: + if shard_tensors: + file_state_dict.update(shard_tensors) + file_err = None + if file_state_dict: + try: + safe_save_file( + tensors=file_state_dict, + filename=checkpoint_dir / relative_filename, + metadata={"format": "pt"}, + ) + except Exception as e: + file_err = repr(e) else: - tdist.gather_object(local_sd, dst=0) - err_box = [save_err] + file_err = None + err_box = [file_err] tdist.broadcast_object_list(err_box, src=0) - # Barrier ensures all ranks wait until file I/O completes before continuing - dist_utils.barrier() if err_box[0] is not None: - raise RuntimeError(f"Checkpoint save failed on rank 0: {err_box[0]}") - else: - _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor) + raise RuntimeError(f"Checkpoint save failed for {relative_filename}: {err_box[0]}") def _save_checkpoint( @@ -265,9 +365,10 @@ def _save_checkpoint( save_model_config(model_config, checkpoint_dir) # Phase 2: Build weight map using descriptor and write index + lm_config = descriptor.get_language_model_config(model_config) subblock_keys = descriptor.get_weight_groups( layer_names=state_dict.keys(), - num_hidden_layers=model_config.num_hidden_layers, + num_hidden_layers=lm_config.num_hidden_layers, ) weight_map = {} @@ -490,6 +591,46 @@ def _build_safetensors_weight_map( return weight_map +def _copy_auto_map_code_files(model_config: PretrainedConfig, checkpoint_dir: Path) -> None: + """Copy custom modeling Python files referenced in ``auto_map`` to the checkpoint dir. + + ``PretrainedConfig.save_pretrained()`` only copies the config class's own source file + (e.g. ``configuration_nemotron_h.py``). Trust-remote-code models also need ``modeling_*.py`` + (and any other auto_map-referenced ``.py``) present alongside ``config.json``, otherwise + later ``AutoConfig.from_pretrained(..., trust_remote_code=True)`` calls fail with + "does not appear to have a file named modeling_*.py". + + We discover the source directory from the config class itself (via ``inspect.getfile``) + and copy every distinct ``.py`` referenced by the auto_map values. + """ + if not hasattr(model_config, "auto_map") or not isinstance(model_config.auto_map, dict): + return + + try: + source_dir = Path(inspect.getfile(type(model_config))).parent + except (TypeError, OSError): + # Built-in / non-file-backed config class — nothing to copy. + return + + # Module names must look like Python identifiers — refuse anything with separators + # or relative-path components so a malformed/hostile auto_map can't drive shutil.copy + # outside source_dir / checkpoint_dir. 
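+    # Illustrative auto_map (the class names here are invented for this comment):
+    #   {"AutoConfig": "configuration_nemotron_h.MyConfig",
+    #    "AutoModelForCausalLM": "modeling_nemotron_h.MyModel"}
+    # yields module_names == {"configuration_nemotron_h", "modeling_nemotron_h"}, so
+    # configuration_nemotron_h.py and modeling_nemotron_h.py are copied next to config.json.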
+ _module_name_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + module_names = { + class_ref.split(".")[0] for class_ref in _iter_auto_map_class_refs(model_config.auto_map) + } + + for module_name in module_names: + if not _module_name_re.match(module_name): + mprint(f"Warning: skipping non-identifier auto_map module name: {module_name!r}") + continue + filename = f"{module_name}.py" + src = source_dir / filename + dst = Path(checkpoint_dir) / filename + if src.exists() and not dst.exists(): + shutil.copy(src, dst) + + def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str) -> None: if hasattr(model_config, "block_configs"): model_config.block_configs = [ @@ -497,3 +638,4 @@ def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str for conf in model_config.block_configs ] model_config.save_pretrained(checkpoint_dir) + _copy_auto_map_code_files(model_config, Path(checkpoint_dir)) diff --git a/modelopt/torch/puzzletron/tools/hydra_utils.py b/modelopt/torch/puzzletron/tools/hydra_utils.py index c30be4efde8..c3e282d5e2b 100644 --- a/modelopt/torch/puzzletron/tools/hydra_utils.py +++ b/modelopt/torch/puzzletron/tools/hydra_utils.py @@ -32,16 +32,57 @@ ] -def warmup_steps(tokens: int, block: int, mbs: int, pct: float = 0.05) -> int: +def warmup_steps(tokens: int, block: int, mbs: int, grad_accum: int, pct: float) -> int: """ - Calculate warmup steps based on total tokens, block size, micro batch size, and warmup percentage. - Used as a resolver in hydra configs. + Calculate warmup steps in optimizer-step units. + + total_iters = tokens / (block * mbs) gives micro-batches; one optimizer step + consumes ``grad_accum`` micro-batches, so total optimizer steps = total_iters + / grad_accum. The LR scheduler in ``_get_lr`` is indexed by ``step_num`` + (optimizer steps), so warmup must be in the same units. 
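+
+    Illustrative example (numbers invented for this docstring, not defaults): with
+    ``tokens=1_000_000``, ``block=1024`` and ``mbs=4`` there are
+    ``(1_000_000 // 1024) // 4 = 244`` micro-batches; ``grad_accum=8`` turns that into
+    ``244 // 8 = 30`` optimizer steps, so ``pct=0.05`` yields
+    ``max(1, round(0.05 * 30)) = 2`` warmup steps.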
""" - steps = (int(tokens) // int(block)) // int(mbs) + try: + tokens = int(tokens) + block = int(block) + mbs = int(mbs) + grad_accum = int(grad_accum) + except (TypeError, ValueError) as exc: + raise ValueError( + "tokens, block, mbs, and grad_accum must be integers or castable to int; " + f"got tokens={tokens!r}, block={block!r}, mbs={mbs!r}, grad_accum={grad_accum!r}" + ) from exc + + try: + pct = float(pct) + except (TypeError, ValueError) as exc: + raise ValueError(f"pct must be a float or castable to float, got {pct!r}") from exc + + if tokens < 0: + raise ValueError(f"tokens must be >= 0, got {tokens!r}") + if block <= 0: + raise ValueError(f"block must be > 0, got {block!r}") + if mbs <= 0: + raise ValueError(f"mbs must be > 0, got {mbs!r}") + if grad_accum < 1: + raise ValueError(f"grad_accum must be >= 1, got {grad_accum!r}") + if not 0.0 <= pct <= 1.0: + raise ValueError(f"pct must be between 0.0 and 1.0 inclusive, got {pct!r}") + + iters = (tokens // block) // mbs + steps = max(1, iters // grad_accum) w = pct * steps return max(1, round(w)) +def _warmup_steps_resolver(*args): + if len(args) != 5: + raise ValueError( + "warmup_steps resolver expects exactly 5 arguments: " + "(tokens, block, micro_batch_size, grad_accumulation_steps, warmup_ratio)" + ) + return warmup_steps(*args) + + def register_hydra_resolvers(): OmegaConf.register_new_resolver("to_path", lambda x: Path(x)) OmegaConf.register_new_resolver( @@ -50,7 +91,7 @@ def register_hydra_resolvers(): OmegaConf.register_new_resolver( "timedelta_minutes", lambda x: datetime.timedelta(minutes=x) if x is not None else None ) - OmegaConf.register_new_resolver("warmup_steps", lambda t, b, m, p: warmup_steps(t, b, m, p)) + OmegaConf.register_new_resolver("warmup_steps", _warmup_steps_resolver) OmegaConf.register_new_resolver("get_object", lambda x: get_object(x)) diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py index f4046531491..3d8b94c82cc 100644 --- a/modelopt/torch/puzzletron/utils/data/dataloaders.py +++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py @@ -31,7 +31,7 @@ from ...tools.logger import mprint from .dataset import ConstantLengthDataset -__all__ = ["create_validation_dataloader", "create_padded_tensor"] +__all__ = ["create_train_dataloader", "create_validation_dataloader", "create_padded_tensor"] def collate_none_fn( @@ -73,6 +73,74 @@ def load_streaming_fn( return dataset +def create_train_dataloader( + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset_path: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name: str = "train", + keep_in_memory: bool = False, + shuffle_seed: int | None = None, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + num_workers: int = 0, +) -> DataLoader: + """Create an infinite training DataLoader over ConstantLengthDataset.""" + # ConstantLengthDataset.__iter__ does not consult torch.utils.data.get_worker_info() + # to shard work across DataLoader workers, so num_workers > 0 would have every + # worker iterate the full dataset and emit duplicate samples. Reject explicitly + # until ConstantLengthDataset gains worker-aware iteration; the guard can then + # be removed. 
+ if num_workers > 0: + raise ValueError( + f"create_train_dataloader: num_workers={num_workers} is not supported " + f"because ConstantLengthDataset.__iter__ does not shard via " + f"torch.utils.data.get_worker_info(). Use num_workers=0 (the default) " + f"or add worker-aware sharding to ConstantLengthDataset.__iter__." + ) + + if isinstance(dataset_path, str): + dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory) + else: + dataset = dataset_path + + train_data = dataset[dataset_name] + if shuffle_seed is not None: + # `keep_in_memory` is only valid on map-style HF Datasets; streaming + # `IterableDataset.shuffle()` only accepts `seed` (and an optional + # `buffer_size`). Branch on the dataset type so streaming users + # (`load_from_disk: false`) don't crash on this call. + if isinstance(train_data, datasets.IterableDataset): + train_data = train_data.shuffle(seed=shuffle_seed) + else: + train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=keep_in_memory) + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + infinite=True, + seq_length=block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + ) + + return DataLoader( + train_dataset, + batch_size=micro_batch_size, + pin_memory=True, + num_workers=num_workers, + ) + + def create_validation_dataloader( accelerator: Accelerator | None, seed: int, diff --git a/modelopt/torch/puzzletron/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py index f88e44a234b..01422e5a4b7 100644 --- a/modelopt/torch/puzzletron/utils/data/dataset.py +++ b/modelopt/torch/puzzletron/utils/data/dataset.py @@ -35,6 +35,14 @@ CODEGEN_FIM_TOKENS = ["", "<|endoftext|>", ""] +def _message_content_to_text(content) -> str: + if isinstance(content, str): + return content + if isinstance(content, dict) and "text" in content: + return str(content["text"]) + return str(content) + + class ConstantLengthDataset(IterableDataset): """Iterable dataset that returns constant length chunks of tokens from stream of text files. @@ -128,9 +136,18 @@ def __iter__(self) -> dict[str, torch.Tensor]: and {"content", "role"}.issubset(sample[0]) ): if len(sample) > 1: - sample = self.tokenizer.apply_chat_template(sample, tokenize=False) + if getattr(self.tokenizer, "chat_template", None) is not None: + sample = self.tokenizer.apply_chat_template( + sample, tokenize=False + ) + else: + # Base models have no chat template — concatenate message + # contents separated by newlines as plain text. 
+ sample = "\n".join( + _message_content_to_text(m["content"]) for m in sample + ) else: - sample = sample[0]["content"] + sample = _message_content_to_text(sample[0]["content"]) else: sample = sample[self.tokens_field] sample = sample[: self.max_sample_length] diff --git a/modelopt/torch/puzzletron/utils/parsing.py b/modelopt/torch/puzzletron/utils/parsing.py index 149563b4321..69e21e0599b 100644 --- a/modelopt/torch/puzzletron/utils/parsing.py +++ b/modelopt/torch/puzzletron/utils/parsing.py @@ -24,6 +24,7 @@ # mypy: ignore-errors import json +import math from pathlib import Path from typing import Any @@ -116,7 +117,7 @@ def format_block_configs(config) -> str: ╭─────────────────────── Model Architecture ────────────────────────╮ │ Layer 1 │ Attention: no_op │ FFN: mult = 4.95 │ │ Layer 2 │ Attention: 4 heads in group │ FFN: mult = 4.95 │ - │ Layer 3 │ Attention: 4 heads in group │ FFN: no_op │ + │ Layer 3 │ Attention: no_op │ FFN: no_op │ ╰────────────────────────────────────────────────────────────────────╯ """ if not hasattr(config, "block_configs") or not config.block_configs: @@ -158,7 +159,7 @@ def _format_attention_config(attention_config) -> str: num_kv_heads = attention_config.num_key_value_heads if num_kv_heads is not None: - return f"{num_kv_heads} kv heads" + return f"🐙 {num_kv_heads} kv heads" if attention_config.replace_with_linear: return "linear replacement" @@ -192,12 +193,12 @@ def _format_ffn_config(ffn_config) -> str: ffn_intermediate = ffn_config.intermediate_size if ffn_intermediate is not None: - return f"ffn_intermediate = {ffn_intermediate}" + return f"🧱 ffn_dim = {ffn_intermediate}" # Check for MoE configuration moe_config = ffn_config.moe if moe_config: - return "MoE" + return "🔀 MoE" if ffn_config.sparsify: return "sparse" @@ -287,7 +288,7 @@ def _add_config_section(cfg: DictConfig, section_name: str = "", indent: int = 0 # Regular key-value pair indent_str = " " * (indent + 1) value_str = _format_value(value).replace(" " * 0, "").strip() - line = f"│ {indent_str} {key}: {value_str}" + line = f"│ {indent_str} • {key}: {value_str}" # Pad to box width if len(line) >= box_width - 1: # Truncate long lines @@ -310,6 +311,8 @@ def format_stitched_losses( losses_dict: dict[str, float], best_steps_dict: dict[str, int] | None = None, best_values_dict: dict[str, float] | None = None, + initial_values_dict: dict[str, float] | None = None, + not_trainable_names: set[str] | None = None, step_number: int | None = None, title: str = "Stitched Module Losses", ) -> str: @@ -320,6 +323,9 @@ def format_stitched_losses( losses_dict: Dictionary with block names as keys and current loss values as floats best_steps_dict: Optional dictionary with block names as keys and best step numbers as values best_values_dict: Optional dictionary with block names as keys and best loss values as floats + initial_values_dict: Optional dictionary with block names as keys and initial loss values + (from the first log chunk) as floats. Used to render the "Δ from initial" column as + a per-block training-progress signal. 
step_number: Optional current step number to include in summary title: Title to display at the top of the formatted output @@ -328,23 +334,39 @@ def format_stitched_losses( Example output: ╭─────────────────── Stitched Module Losses ──────────────────╮ - │ Block │ Loss Value │ Best Step │ Best Value │ Change from avg │ - │───────┼────────────┼───────────┼────────────┼──────────────────│ - │ 00 │ 6.21e-03 │ Step 5 │ 5.95e-03 │ ↑ +2.6e-04 │ - │ 01 │ 5.14e-04 │ Step 12 │ 5.14e-04 │ ↓ -1.2e-04 │ - │ 02 │ 9.84e-05 │ Step 15 │ 9.84e-05 │ ↓ -3.1e-04 │ + │ Block │ Loss Value │ Δ from initial │ Best Value │ Best Step │ + │───────┼────────────┼──────────────────┼────────────┼───────────│ + │ 00 │ 6.21e-03 │ ↓ -3.2e-04 (-5%) │ 5.95e-03 │ Step 5 │ + │ 01 │ 5.14e-04 │ ↓ -1.8e-03 (-78%)│ 5.14e-04 │ Step 12 │ + │ 02 │ 9.84e-05 │ ↓ -4.1e-04 (-81%)│ 9.84e-05 │ Step 15 │ ╰──────────────────────────────────────────────────────────────╯ """ if not losses_dict: + if not_trainable_names: + return ( + "No trainable losses found; " + f"skipped {len(not_trainable_names)} non-trainable blocks" + ) return "❌ No losses found" + if best_steps_dict: + best_steps_dict = {k: v for k, v in best_steps_dict.items() if k in losses_dict} + if best_values_dict: + best_values_dict = {k: v for k, v in best_values_dict.items() if k in losses_dict} + if initial_values_dict: + initial_values_dict = {k: v for k, v in initial_values_dict.items() if k in losses_dict} + lines = [] # Calculate statistics loss_values = list(losses_dict.values()) - max_loss = max(loss_values) - min_loss = min(loss_values) - avg_loss = sum(loss_values) / len(loss_values) + finite_loss_values = [value for value in loss_values if math.isfinite(value)] + if finite_loss_values: + max_loss = max(finite_loss_values) + min_loss = min(finite_loss_values) + avg_loss = sum(finite_loss_values) / len(finite_loss_values) + else: + max_loss = min_loss = avg_loss = float("nan") # Calculate box width for new layout (removed Bar column) box_width = 74 @@ -356,10 +378,10 @@ def format_stitched_losses( f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" ) separator = ( - f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Best Step':<10} │ " - f"{'Best Value':<12} │ {'Change from avg':<18} │" + f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Δ from initial':<18} │ " + f"{'Best Value':<12} │ {'Best Step':<10} │" ) - divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 12}┼{'─' * 14}┼{'─' * 20}│" + divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 20}┼{'─' * 14}┼{'─' * 12}│" lines.extend([header, title_line, separator, divider]) @@ -382,26 +404,36 @@ def format_stitched_losses( best_value = loss_value # Assume current is best if no history best_value_str = f"{best_value:.2e}" - # Calculate change from average - change_from_avg = loss_value - avg_loss - if abs(change_from_avg) > 1e-8: # Only show if meaningful - change_str = f"{abs(change_from_avg):.1e}" - if change_from_avg > 0: - # Current is above average (worse for loss) - change_display = f"↑ +{change_str}" + # Calculate change from initial: current loss minus the block's loss in the + # first log chunk we saw. Per-block training-progress signal — answers "is + # bypass distillation actually reducing this block's loss?" and stays + # apples-to-apples even when blocks have very different intrinsic loss scales. 
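+            # Worked example (matches the sample row in the docstring above): an initial
+            # loss of 6.53e-03 and a current loss of 6.21e-03 give delta = -3.2e-04 and
+            # pct = -3.2e-04 / 6.53e-03 * 100 ≈ -5%, rendered as "↓ -3.2e-04 (-5%)".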
+ if initial_values_dict and block_name in initial_values_dict: + initial_value = initial_values_dict[block_name] + if not math.isfinite(loss_value) or not math.isfinite(initial_value): + change_display = "non-finite" else: - # Current is below average (better for loss) - change_display = f"↓ -{change_str}" + delta = loss_value - initial_value + if math.isfinite(loss_value) and math.isfinite(initial_value) and abs(delta) > 1e-8: + pct = (delta / initial_value * 100.0) if initial_value != 0.0 else 0.0 + # Clamp percentage display to keep the cell within the 18-char column + # even on pathological divergence (e.g. a block whose loss 10x'd). + pct_clamped = max(-999.0, min(999.0, pct)) + arrow = "↓" if delta < 0 else "↑" + sign = "-" if delta < 0 else "+" + change_display = f"{arrow} {sign}{abs(delta):.1e} ({pct_clamped:+.0f}%)" + elif math.isfinite(loss_value) and math.isfinite(initial_value): + change_display = "↔ 0.0e+00" else: - # At average value - change_display = "↔ 0.0e+00" + # No baseline supplied (callers may omit initial_values_dict). + change_display = " --" # Format the line block_display = block_name.replace("block_", "").zfill(2) line = ( - f"│ {block_display:<5} │ {loss_str:<12} │ {best_step_str:<10} │ " - f"{best_value_str:<12} │ {change_display:<18} │" + f"│ {block_display:<5} │ {loss_str:<12} │ {change_display:<18} │ " + f"{best_value_str:<12} │ {best_step_str:<10} │" ) lines.append(line) @@ -413,6 +445,8 @@ def format_stitched_losses( if step_number is not None: summary_parts.append(f"Step {step_number}") summary_parts.extend([f"Avg={avg_loss:.2e}", f"Max={max_loss:.2e}", f"Min={min_loss:.2e}"]) + if not_trainable_names: + summary_parts.append(f"Skipped={len(not_trainable_names)}") summary_text = ", ".join(summary_parts) summary = f"│ Summary: {summary_text}" @@ -436,7 +470,9 @@ def format_stitched_losses( best_step_values = [] for block_name, best_step in best_steps_dict.items(): if best_step == modal_best_step and block_name in best_values_dict: - best_step_values.append(best_values_dict[block_name]) + best_value = best_values_dict[block_name] + if math.isfinite(best_value): + best_step_values.append(best_value) if best_step_values: best_step_avg = sum(best_step_values) / len(best_step_values) diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index ea0a6fd2193..82a66c1bb00 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -16,6 +16,7 @@ import os from pathlib import Path +import pytest import torch from _test_utils.torch.transformers_models import get_tiny_tokenizer from datasets import Dataset, DatasetDict @@ -25,6 +26,40 @@ import modelopt.torch.utils.distributed as dist from modelopt.torch.export import copy_hf_ckpt_remote_code +# Shared parametrize tuple for puzzletron GPU integration tests. +# Fields: (hf_model_name, converter, hybrid_override_pattern, has_moe_layers). +# To add a new model family, append a single pytest.param row here — every test +# that imports PUZZLETRON_FAMILIES picks it up automatically. 
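+# A hypothetical consumer (illustrative only; the test name and body below are not
+# part of this change):
+#
+#     @pytest.mark.parametrize(
+#         ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"),
+#         PUZZLETRON_FAMILIES,
+#     )
+#     def test_family_smoke(hf_model_name, converter, hybrid_override_pattern, has_moe_layers):
+#         ...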
+PUZZLETRON_FAMILIES = [ + pytest.param("meta-llama/Llama-3.1-8B-Instruct", "llama", None, False, id="llama-3.1-8B"), + pytest.param("meta-llama/Llama-3.2-3B-Instruct", "llama", None, False, id="llama-3.2-3B"), + pytest.param( + "mistralai/Mistral-Small-24B-Instruct-2501", + "mistral_small", + None, + False, + id="mistral-small-24B", + ), + pytest.param( + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16", + "nemotron_h", + "*E", + True, + id="nemotron-3-30B-A3B", + ), + pytest.param( + "nvidia/NVIDIA-Nemotron-Nano-12B-v2", + "nemotron_h_v2", + "*-", + False, + id="nemotron-nano-12B-v2", + ), + pytest.param("openai/gpt-oss-20b", "gpt_oss", None, True, id="gpt-oss-20b"), + pytest.param("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False, id="qwen2.5-7B"), + pytest.param("Qwen/Qwen3-8B", "qwen3", None, False, id="qwen3-8B"), + pytest.param("Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True, id="qwen3-VL-30B-A3B"), +] + def setup_test_model_and_data( tmp_path: Path, rank: int, hf_model_name: str, hybrid_override_pattern: str | None = None diff --git a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py new file mode 100644 index 00000000000..931b0bfcb17 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -0,0 +1,380 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for ``bypass_checkpoint_utils``. + +The save/resume contract here is the most important regression surface in the +bypass feature: a wrong checkpoint pick or a missing ``saving_completed`` +marker silently restarts training from the wrong iteration. + +What's covered here (CPU-only, codecov-visible): + * ``find_latest_run_dir`` — every branch of the regex/scan/symlink logic. + * ``_save_local_file`` — overwrite/skip semantics. + * ``_save_local_state`` — same three save-path assertions as the GPU file + (state_dict / optimizer / grad_scaler), but on CPU so codecov picks them + up. The GPU file's ``test_load_local_state_*`` cases stay there because + ``load_local_state`` constructs ``torch.device(f"cuda:{rank}")`` directly. + * ``save_bypass_checkpoint`` — orchestration: ``latest`` symlink update, + ``args.json`` dump, ``saving_completed`` marker, master-only gating. +""" + +import os +from collections import OrderedDict +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from omegaconf import OmegaConf +from torch.amp.grad_scaler import GradScaler + +from modelopt.torch.puzzletron.bypass_distillation import bypass_checkpoint_utils as bcu +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import ( + StitchedModuleDescriptor, +) + +# --------------------------------------------------------------------------- +# Shared fixture: silence the dist helpers so these run single-process / CPU. 
+# Mirrors tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py:56-62. +# --------------------------------------------------------------------------- + + +@pytest.fixture +def bcu_no_dist(monkeypatch): + monkeypatch.setattr(bcu.dist, "local_rank", lambda: 0) + monkeypatch.setattr(bcu.dist, "is_master", lambda: True) + monkeypatch.setattr(bcu.dist, "barrier", lambda: None) + return bcu + + +def _make_descriptor(*, with_optimizer: bool = True, with_scaler: bool = True): + """Build a CPU-only StitchedModuleDescriptor — the GPU file's helper minus + the configurable init_scale (we don't round-trip the scaler here).""" + module = nn.Linear(4, 4, bias=False) + owned_parameters = dict(module.named_parameters()) + optimizer = torch.optim.AdamW(list(module.parameters()), lr=1e-3) if with_optimizer else None + scaler = GradScaler(device="cpu", enabled=True, init_scale=2.0**16) if with_scaler else None + return StitchedModuleDescriptor( + stitched_module=module, + owned_parameters=owned_parameters, + owned_buffers={}, + optimizer=optimizer, + grad_scaler=scaler, + ) + + +# --------------------------------------------------------------------------- +# find_latest_run_dir +# --------------------------------------------------------------------------- + + +def test_find_latest_run_dir_returns_none_for_empty_dir(tmp_path: Path): + assert bcu.find_latest_run_dir(tmp_path) is None + + +def test_find_latest_run_dir_picks_only_step_with_marker(tmp_path: Path): + step_dir = tmp_path / "step-000010-ckpt" + step_dir.mkdir() + (step_dir / "saving_completed").touch() + assert bcu.find_latest_run_dir(tmp_path) == str(step_dir) + + +def test_find_latest_run_dir_picks_highest_step_number(tmp_path: Path): + """When several plain step checkpoints have completed markers, the highest + integer wins — not lexicographic order, not insertion order.""" + for i in (5, 10, 20): + d = tmp_path / f"step-{i:06d}-ckpt" + d.mkdir() + (d / "saving_completed").touch() + assert bcu.find_latest_run_dir(tmp_path) == str(tmp_path / "step-000020-ckpt") + + +def test_find_latest_run_dir_skips_step_without_marker(tmp_path: Path): + """A partially-written checkpoint (no ``saving_completed``) must be skipped + even when it has a higher step number — otherwise resume would crash on a + truncated state dict.""" + high = tmp_path / "step-000099-ckpt" + high.mkdir() + # No saving_completed → must be ignored. + low = tmp_path / "step-000050-ckpt" + low.mkdir() + (low / "saving_completed").touch() + assert bcu.find_latest_run_dir(tmp_path) == str(low) + + +def test_find_latest_run_dir_returns_none_when_no_step_has_marker(tmp_path: Path): + (tmp_path / "step-000010-ckpt").mkdir() + (tmp_path / "step-000020-ckpt").mkdir() + # No saving_completed anywhere. + assert bcu.find_latest_run_dir(tmp_path) is None + + +def test_find_latest_run_dir_excludes_non_plain_step_names(tmp_path: Path): + """``best-step-*`` / ``start-step-*`` / ``final-step-*`` aren't valid resume + targets — pinned by the docstring on lines 39-42.""" + for name in ("best-step-000099-ckpt", "start-step-000001-ckpt", "final-step-000050-ckpt"): + d = tmp_path / name + d.mkdir() + (d / "saving_completed").touch() + # No plain step-*-ckpt at all. + assert bcu.find_latest_run_dir(tmp_path) is None + + +def test_find_latest_run_dir_uses_latest_symlink_fast_path(tmp_path: Path): + """The ``latest`` symlink, when present and complete, short-circuits the + scan — even when a numerically higher step dir also has a marker. 
This + matters because the scan branch can be slow on filesystems with many + step dirs (NFS, lustre).""" + target = tmp_path / "step-000010-ckpt" + target.mkdir() + (target / "saving_completed").touch() + (tmp_path / "latest").symlink_to(target.name) + + higher = tmp_path / "step-000020-ckpt" + higher.mkdir() + (higher / "saving_completed").touch() + + # Symlink wins despite higher step existing. + assert bcu.find_latest_run_dir(tmp_path) == str(tmp_path / "latest") + + +def test_find_latest_run_dir_falls_through_when_latest_lacks_marker(tmp_path: Path): + """A ``latest`` symlink whose target lacks ``saving_completed`` (interrupted + save) must be ignored, falling through to the highest completed step.""" + incomplete = tmp_path / "step-000020-ckpt" + incomplete.mkdir() + # No saving_completed. + (tmp_path / "latest").symlink_to(incomplete.name) + + completed = tmp_path / "step-000010-ckpt" + completed.mkdir() + (completed / "saving_completed").touch() + + assert bcu.find_latest_run_dir(tmp_path) == str(completed) + + +def test_find_latest_run_dir_ignores_latest_to_best_checkpoint(tmp_path: Path): + """`latest` is a resume pointer, so old symlinks to best checkpoints are ignored.""" + best = tmp_path / "best-step-000020-ckpt" + best.mkdir() + (best / "saving_completed").touch() + (tmp_path / "latest").symlink_to(best.name) + + completed = tmp_path / "step-000010-ckpt" + completed.mkdir() + (completed / "saving_completed").touch() + + assert bcu.find_latest_run_dir(tmp_path) == str(completed) + + +# --------------------------------------------------------------------------- +# _save_local_file +# --------------------------------------------------------------------------- + + +def test_save_local_file_writes_object_to_disk(tmp_path: Path): + target = tmp_path / "blob.pth" + bcu._save_local_file({"a": torch.tensor([1, 2, 3])}, target) + assert target.exists() + loaded = torch.load(target, weights_only=True) + assert torch.equal(loaded["a"], torch.tensor([1, 2, 3])) + + +def test_save_local_file_overwrite_true_replaces_contents(tmp_path: Path): + target = tmp_path / "blob.pth" + bcu._save_local_file({"v": torch.tensor([1])}, target) + bcu._save_local_file({"v": torch.tensor([99])}, target, overwrite=True) + loaded = torch.load(target, weights_only=True) + assert torch.equal(loaded["v"], torch.tensor([99])) + + +def test_save_local_file_overwrite_false_skips_existing(tmp_path: Path): + target = tmp_path / "blob.pth" + bcu._save_local_file({"v": torch.tensor([1])}, target) + # Second save should be a no-op. + bcu._save_local_file({"v": torch.tensor([99])}, target, overwrite=False) + loaded = torch.load(target, weights_only=True) + assert torch.equal(loaded["v"], torch.tensor([1])) + + +# --------------------------------------------------------------------------- +# _save_local_state: optimizer + grad_scaler only. +# Weights deliberately do NOT land here — the HF checkpoint at the same +# directory carries the full student state dict via ``save_checkpoint``. +# Saving the per-block weights again would just double the disk footprint. 
+# --------------------------------------------------------------------------- + + +def test_save_local_state_writes_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): + descriptors = OrderedDict([("block_0", _make_descriptor())]) + bcu_no_dist._save_local_state(descriptors, tmp_path) + stitched = tmp_path / "stitched" + assert (stitched / "block_0.optimizer_state.pth").exists() + assert (stitched / "block_0.grad_scaler.pth").exists() + + +def test_save_local_state_does_not_write_weights_state_dict(tmp_path: Path, bcu_no_dist): + """Pin the de-duplication: weights live in the HF checkpoint, not here.""" + descriptors = OrderedDict([("block_0", _make_descriptor())]) + bcu_no_dist._save_local_state(descriptors, tmp_path) + assert not (tmp_path / "stitched" / "block_0.state_dict.pth").exists() + + +def test_save_local_state_skips_grad_scaler_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): + descriptors = OrderedDict([("block_0", _make_descriptor(with_scaler=False))]) + bcu_no_dist._save_local_state(descriptors, tmp_path) + stitched = tmp_path / "stitched" + assert (stitched / "block_0.optimizer_state.pth").exists() + assert not (stitched / "block_0.grad_scaler.pth").exists() + + +def test_save_local_state_skips_optimizer_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): + descriptors = OrderedDict( + [("block_0", _make_descriptor(with_optimizer=False, with_scaler=False))] + ) + bcu_no_dist._save_local_state(descriptors, tmp_path) + stitched = tmp_path / "stitched" + assert not (stitched / "block_0.optimizer_state.pth").exists() + assert not (stitched / "block_0.grad_scaler.pth").exists() + + +# --------------------------------------------------------------------------- +# save_bypass_checkpoint — orchestration: symlink, args.json, marker +# --------------------------------------------------------------------------- + + +def _make_save_cfg(experiment_dir: Path, *, delete_old: bool = True): + """Minimal cfg shape used by ``save_bypass_checkpoint``. + + ``cfg.bypass`` is the object that gets dumped to ``args.json``, so it must + be JSON-serialisable (or DictConfig-with-primitives, which json_dump handles). + """ + return OmegaConf.create( + { + "bypass": { + "experiment_dir": str(experiment_dir), + "model": {"model_overrides": {"delete_old_checkpoints": delete_old}}, + "iter_num": 7, + } + } + ) + + +@pytest.fixture +def patched_save(monkeypatch, bcu_no_dist): + """Stub out the heavy callees so the test only exercises the orchestration + logic in ``save_bypass_checkpoint``.""" + monkeypatch.setattr(bcu_no_dist, "_save_local_state", lambda **kwargs: None) + monkeypatch.setattr(bcu_no_dist, "save_checkpoint_from_shards", lambda **kwargs: None) + return bcu_no_dist + + +def test_save_bypass_checkpoint_creates_latest_symlink_and_marker(tmp_path: Path, patched_save): + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + checkpoint_dir = experiment_dir / "step-000007-ckpt" + checkpoint_dir.mkdir() + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=checkpoint_dir, + ) + + latest = experiment_dir / "latest" + assert latest.is_symlink() + # Symlink target is relative — just the dir name, so it resolves under experiment_dir. 
+ assert os.readlink(latest) == "step-000007-ckpt" + assert latest.resolve() == checkpoint_dir.resolve() + assert (checkpoint_dir / "args.json").exists() + assert (checkpoint_dir / "saving_completed").exists() + + +def test_save_bypass_checkpoint_replaces_existing_latest_symlink(tmp_path: Path, patched_save): + """A stale ``latest`` from a prior save must be replaced, not appended to. + Without ``unlink(missing_ok=True)`` the symlink_to() call would raise + FileExistsError mid-save and leave the run unable to checkpoint.""" + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + old_target = experiment_dir / "step-000003-ckpt" + old_target.mkdir() + new_target = experiment_dir / "step-000007-ckpt" + new_target.mkdir() + (experiment_dir / "latest").symlink_to(old_target.name) + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=new_target, + ) + + assert os.readlink(experiment_dir / "latest") == "step-000007-ckpt" + + +def test_save_bypass_checkpoint_best_does_not_replace_latest(tmp_path: Path, patched_save): + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + resume_target = experiment_dir / "step-000003-ckpt" + resume_target.mkdir() + best_target = experiment_dir / "best-step-000007-ckpt" + best_target.mkdir() + (experiment_dir / "latest").symlink_to(resume_target.name) + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=best_target, + checkpoint_role="best", + ) + + assert os.readlink(experiment_dir / "latest") == "step-000003-ckpt" + assert (best_target / "saving_completed").exists() + assert (best_target / "bypass_config.json").exists() + + +def test_save_bypass_checkpoint_master_only_skips_symlink_on_non_master( + tmp_path: Path, monkeypatch, patched_save +): + """Non-master ranks must not write the symlink, args.json, or marker — + only rank 0 owns those files. The other ranks still call _save_local_state + (their owned blocks) but stop short of the per-experiment metadata.""" + monkeypatch.setattr(patched_save.dist, "is_master", lambda: False) + + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + checkpoint_dir = experiment_dir / "step-000007-ckpt" + checkpoint_dir.mkdir() + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=checkpoint_dir, + ) + + assert not (experiment_dir / "latest").exists() + assert not (checkpoint_dir / "args.json").exists() + assert not (checkpoint_dir / "saving_completed").exists() diff --git a/tests/unit/torch/puzzletron/test_bypass_dataloaders.py b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py new file mode 100644 index 00000000000..1bcea14633e --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py @@ -0,0 +1,328 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for bypass-distillation dataloader utilities. + +Covers the pure-Python branches of ``utils/data/dataloaders.py`` that don't +need a real tokenizer / GPU / distributed init: the validation-split +auto-detect rules, the ``num_workers`` guard rail, the dataset-loader +delegators, the ``Printer`` fake accelerator, and the small numeric helpers +(``create_padded_tensor``, ``realize_dataset_in_memory``, ``collate_none_fn``). +""" + +import datasets +import pytest +import torch +from datasets import Dataset, DatasetDict + +import modelopt.torch.puzzletron.utils.data.dataloaders as dl +from modelopt.torch.puzzletron.utils.data.dataloaders import ( + Printer, + collate_fn_with_none_support, + collate_none_fn, + create_padded_tensor, + create_train_dataloader, + create_validation_dataloader, + load_from_disk_fn, + load_streaming_fn, + realize_dataset_in_memory, +) +from modelopt.torch.puzzletron.utils.data.dataset import ConstantLengthDataset + +# --------------------------------------------------------------------------- +# realize_dataset_in_memory: pure list materialisation with optional cap +# --------------------------------------------------------------------------- + + +def test_realize_dataset_in_memory_full(): + items = [{"a": 1}, {"a": 2}, {"a": 3}] + out = realize_dataset_in_memory(iter(items), eval_samples=None) + assert out == items + + +def test_realize_dataset_in_memory_capped(): + items = [{"a": 1}, {"a": 2}, {"a": 3}] + out = realize_dataset_in_memory(iter(items), eval_samples=2) + assert out == [{"a": 1}, {"a": 2}] + + +# --------------------------------------------------------------------------- +# create_padded_tensor: identity, 1D pad, 2D pad with non-zero pad value +# --------------------------------------------------------------------------- + + +def test_create_padded_tensor_identity(): + t = torch.arange(6, dtype=torch.float32).reshape(2, 3) + out = create_padded_tensor(t, desired_shape=(2, 3)) + assert out is t # short-circuit, no copy + + +def test_create_padded_tensor_pads_1d_with_default_zero(): + t = torch.tensor([1, 2, 3], dtype=torch.int32) + out = create_padded_tensor(t, desired_shape=(5,)) + assert out.tolist() == [1, 2, 3, 0, 0] + assert out.dtype == torch.int32 + + +def test_create_padded_tensor_pads_2d_with_custom_value(): + t = torch.tensor([[1.0, 2.0]]) + out = create_padded_tensor(t, desired_shape=(2, 3), padding_value=-100.0) + assert out.tolist() == [[1.0, 2.0, -100.0], [-100.0, -100.0, -100.0]] + + +# --------------------------------------------------------------------------- +# Collate helpers: None-aware default collator +# --------------------------------------------------------------------------- + + +def test_collate_none_fn_returns_none(): + assert collate_none_fn([None, None]) is None + assert collate_none_fn([1, 2, 3]) is None # unconditional + + +def test_collate_fn_with_none_support_passes_none_through(): + """A label tensor of None should not be coerced to ``[None, None]`` — the + bypass val loop expects a single ``None`` so it can short-circuit loss + computation. 
This pins the ``type(None) -> collate_none_fn`` registration.""" + batch = [{"x": torch.tensor([1.0]), "y": None}, {"x": torch.tensor([2.0]), "y": None}] + out = collate_fn_with_none_support(batch) + assert out["y"] is None + assert torch.equal(out["x"], torch.tensor([[1.0], [2.0]])) + + +# --------------------------------------------------------------------------- +# Printer: degenerate "main process" stand-in for Accelerator +# --------------------------------------------------------------------------- + + +def test_printer_attributes_match_main_process_contract(): + assert Printer.is_main_process is True + assert Printer.process_index is None + Printer.print("hello world") # must not raise + + +# --------------------------------------------------------------------------- +# load_from_disk_fn / load_streaming_fn: thin wrappers around datasets.* +# --------------------------------------------------------------------------- + + +def test_load_from_disk_fn_delegates_to_datasets(monkeypatch): + captured = {} + + def fake_load_from_disk(path, keep_in_memory=False): + captured["path"] = path + captured["keep_in_memory"] = keep_in_memory + return "sentinel" + + monkeypatch.setattr(datasets, "load_from_disk", fake_load_from_disk) + out = load_from_disk_fn("/some/path", content_field="conversation", keep_in_memory=True) + assert out == "sentinel" + assert captured == {"path": "/some/path", "keep_in_memory": True} + + +def test_load_streaming_fn_uses_streaming_with_features(monkeypatch): + """``load_streaming_fn`` must request streaming and pin the content field's + feature schema — without ``features=`` HuggingFace would auto-infer types + per-shard, which has caused bypass jobs to crash on schema drift in the past. + """ + captured = {} + + def fake_load_dataset(path, streaming, features, keep_in_memory): + captured["path"] = path + captured["streaming"] = streaming + captured["features"] = features + captured["keep_in_memory"] = keep_in_memory + return "stream-sentinel" + + monkeypatch.setattr(datasets, "load_dataset", fake_load_dataset) + out = load_streaming_fn("hf-org/dataset", content_field="text", keep_in_memory=False) + assert out == "stream-sentinel" + assert captured["path"] == "hf-org/dataset" + assert captured["streaming"] is True + assert captured["keep_in_memory"] is False + # features must be a Features object keyed by the requested content_field + # with a string Value — schema-drift protection is the whole point of this fn. + assert isinstance(captured["features"], datasets.Features) + assert "text" in captured["features"] + assert captured["features"]["text"].dtype == "string" + + +# --------------------------------------------------------------------------- +# create_train_dataloader: ``num_workers > 0`` is a configuration error +# --------------------------------------------------------------------------- + + +def test_create_train_dataloader_rejects_num_workers_gt_zero(): + """ConstantLengthDataset doesn't shard work via ``get_worker_info`` — every + worker would emit the same samples. 
The guard fires before tokenizer or + dataset are touched, so bare-bones args are enough.""" + with pytest.raises(ValueError, match="num_workers"): + create_train_dataloader( + seed=0, + tokenizer=None, + block_size=8, + dataset_path={"train": []}, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + micro_batch_size=1, + num_workers=2, + ) + + +class _NoChatTemplateTokenizer: + eos_token_id = 1 + bos_token_id = None + + def __init__(self): + self.seen_texts = None + self.vocab = {} + + def __call__(self, texts, truncation=False): + self.seen_texts = texts + return {"input_ids": [[0] for _ in texts]} + + +class _ConversationDataset: + column_names = ("text",) + + def __iter__(self): + yield { + "text": [ + {"role": "user", "content": {"text": "hello"}}, + {"role": "assistant", "content": {"value": 3}}, + ] + } + + +def test_constant_length_dataset_no_chat_template_normalizes_message_content(): + tokenizer = _NoChatTemplateTokenizer() + dataset = ConstantLengthDataset( + tokenizer, + _ConversationDataset(), + infinite=False, + seq_length=2, + num_of_sequences=1, + chars_per_token=100, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + label_shift=False, + ) + + realized = list(dataset) + + assert tokenizer.seen_texts == ["hello\n{'value': 3}"] + assert len(realized) == 1 + assert torch.equal(realized[0]["input_ids"], torch.tensor([0, 1])) + assert torch.equal(realized[0]["targets"], torch.tensor([0, 1])) + + +# --------------------------------------------------------------------------- +# create_validation_dataloader: split auto-detect + explicit override +# --------------------------------------------------------------------------- + + +class _FakeConstantLengthDataset: + """Stub for ``ConstantLengthDataset`` that records its ``dataset`` arg. + + Yields one trivial item so ``realize_dataset_in_memory`` can iterate over + it without touching a tokenizer. + """ + + last_dataset = None # class-level capture so tests can read after construction + + def __init__(self, tokenizer, dataset, **kwargs): + type(self).last_dataset = dataset + self._dataset = dataset + + def __iter__(self): + yield {"input_ids": torch.tensor([0])} + + +@pytest.fixture +def patched_dataloader(monkeypatch): + """Replace the heavy bits inside ``create_validation_dataloader`` so the + function exercises only its pure split-selection logic + DataLoader build.""" + monkeypatch.setattr(dl, "ConstantLengthDataset", _FakeConstantLengthDataset) + # Force a tiny in-memory list so we don't drain a real iterable. + monkeypatch.setattr( + dl, + "realize_dataset_in_memory", + lambda dataset, eval_samples: [{"input_ids": torch.tensor([0])}], + ) + _FakeConstantLengthDataset.last_dataset = None + return _FakeConstantLengthDataset + + +def _make_dict_dataset(splits: dict[str, list]) -> DatasetDict: + return DatasetDict({k: Dataset.from_list(v) for k, v in splits.items()}) + + +def _kwargs(): + return { + "accelerator": None, # → Printer (single-process path) + "seed": 0, + "tokenizer": None, + "block_size": 4, + "content_field": "text", + "fim_rate": 0.0, + "fim_spm_rate": 0.0, + "micro_batch_size": 1, + } + + +def test_validation_split_auto_picks_validation_when_present(patched_dataloader): + dd = _make_dict_dataset({"train": [{"text": "t"}], "validation": [{"text": "v"}]}) + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + # The "validation" split must have been the one passed to ConstantLengthDataset. 
+ assert patched_dataloader.last_dataset is dd["validation"] + + +def test_validation_split_auto_falls_back_to_test_when_no_val(patched_dataloader): + dd = _make_dict_dataset({"train": [{"text": "t"}], "test": [{"text": "te"}]}) + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + assert patched_dataloader.last_dataset is dd["test"] + + +def test_validation_split_auto_prefers_val_over_test(patched_dataloader): + """If both ``validation`` and ``test`` exist, the val* prefix must win — + bypass relies on this to score against held-out data, not test data.""" + dd = _make_dict_dataset( + {"train": [{"text": "t"}], "validation": [{"text": "v"}], "test": [{"text": "te"}]} + ) + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + assert patched_dataloader.last_dataset is dd["validation"] + + +def test_validation_split_auto_assertion_on_multiple_val_options(patched_dataloader): + """Ambiguity must fail loudly — silently picking one would be a footgun.""" + dd = _make_dict_dataset({"validation": [{"text": "a"}], "valtest": [{"text": "b"}]}) + with pytest.raises(AssertionError, match="exactly one validation split"): + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + + +def test_validation_split_auto_assertion_on_no_val_or_test(patched_dataloader): + dd = _make_dict_dataset({"train": [{"text": "t"}], "extra": [{"text": "e"}]}) + with pytest.raises(AssertionError, match="exactly one validation split"): + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + + +def test_validation_split_explicit_override_bypasses_auto(patched_dataloader): + """Explicit ``dataset_name`` must skip the auto-detect, even when the + chosen name doesn't match val* / test* prefixes.""" + dd = _make_dict_dataset({"my_eval": [{"text": "x"}]}) + create_validation_dataloader(dataset=dd, dataset_name="my_eval", **_kwargs()) + assert patched_dataloader.last_dataset is dd["my_eval"] diff --git a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py new file mode 100644 index 00000000000..4d4031ba197 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py @@ -0,0 +1,256 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``_set_keys_to_learn`` in stitched_model_factory.py. + +This function is the single source of truth for which subblock parameters get +trained during a bypass run. Its branches (subblock_ffn / subblock_attention / +subblock_mamba / entire_block / list) and its hybrid-model ``block_configs`` +filter are all silent on misuse — a regression here would freeze the wrong +layers and produce a worse-than-teacher checkpoint with no loud failure. 
+""" + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import _set_keys_to_learn + +# --------------------------------------------------------------------------- +# Fixtures: a minimal Llama-shaped model and a Llama-shaped descriptor stub +# --------------------------------------------------------------------------- + + +def _make_dense_model(num_layers: int = 2) -> nn.Module: + """Build a tiny model whose named_parameters mimic Llama's naming. + + Parameters live under ``model.layers.{i}.self_attn.{q,k,v,o}_proj.weight`` + and ``model.layers.{i}.mlp.{up,down}_proj.weight``. The function never reads + parameter shapes, so size doesn't matter — what matters is that the names + match what `_set_keys_to_learn` expects to see in `named_parameters()` and + `state_dict().keys()`. + """ + model = nn.Module() + model_inner = nn.Module() + layers = nn.ModuleList() + for _ in range(num_layers): + layer = nn.Module() + # attention + layer.self_attn = nn.Module() + for proj in ("q_proj", "k_proj", "v_proj", "o_proj"): + setattr(layer.self_attn, proj, nn.Linear(4, 4, bias=False)) + # feed-forward + layer.mlp = nn.Module() + for proj in ("up_proj", "down_proj"): + setattr(layer.mlp, proj, nn.Linear(4, 4, bias=False)) + layers.append(layer) + model_inner.layers = layers + model.model = model_inner + # `_set_keys_to_learn` reads `model.config` only to pass through to + # `descriptor.get_language_model_config` — a SimpleNamespace is enough. + model.config = SimpleNamespace() + # Start with everything frozen so any True flag is something the function set. + for p in model.parameters(): + p.requires_grad_(False) + return model + + +def _make_descriptor(num_layers: int, *, block_configs=None): + """Build a descriptor stub exposing only what ``_set_keys_to_learn`` calls. + + - ``get_language_model_config(config)`` returns an object with + ``num_hidden_layers`` and (optionally) ``block_configs``. + - ``get_weight_groups(state_dict_keys, num_hidden_layers)`` returns + ``{"block_{i}_attention": [...], "block_{i}_ffn": [...]}``. + """ + + def get_language_model_config(_config): + ns = SimpleNamespace(num_hidden_layers=num_layers) + if block_configs is not None: + ns.block_configs = block_configs + return ns + + def get_weight_groups(state_dict_keys, n): + groups: dict[str, list[str]] = {} + for i in range(n): + attn_prefix = f"model.layers.{i}.self_attn." + ffn_prefix = f"model.layers.{i}.mlp." + groups[f"block_{i}_attention"] = [ + k for k in state_dict_keys if k.startswith(attn_prefix) + ] + groups[f"block_{i}_ffn"] = [k for k in state_dict_keys if k.startswith(ffn_prefix)] + return groups + + return SimpleNamespace( + get_language_model_config=get_language_model_config, + get_weight_groups=get_weight_groups, + ) + + +def _trainable_names(model: nn.Module) -> set[str]: + return {n for n, p in model.named_parameters() if p.requires_grad} + + +# --------------------------------------------------------------------------- +# Single-string subblock keys (dense model) +# --------------------------------------------------------------------------- + + +def test_subblock_ffn_trains_only_mlp(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, "subblock_ffn") + trainable = _trainable_names(model) + assert all(".mlp." in n for n in trainable), trainable + assert not any(".self_attn." 
in n for n in trainable), trainable + # Both layers' mlp params must be trainable, not just one. + assert any("model.layers.0.mlp." in n for n in trainable) + assert any("model.layers.1.mlp." in n for n in trainable) + + +def test_subblock_attention_trains_only_self_attn(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, "subblock_attention") + trainable = _trainable_names(model) + assert all(".self_attn." in n for n in trainable), trainable + assert not any(".mlp." in n for n in trainable), trainable + + +def test_entire_block_trains_attention_and_mlp(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, "entire_block") + trainable = _trainable_names(model) + # Both groups present. + assert any(".self_attn." in n for n in trainable), trainable + assert any(".mlp." in n for n in trainable), trainable + # Equal to the union of every model parameter. + assert trainable == {n for n, _ in model.named_parameters()} + + +def test_subblock_key_list_trains_union_of_subblocks(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, ["subblock_attention", "subblock_ffn"]) + trainable = _trainable_names(model) + assert any(".self_attn." in n for n in trainable), trainable + assert any(".mlp." in n for n in trainable), trainable + assert trainable == {n for n, _ in model.named_parameters()} + + +def test_mixed_subblock_and_exact_name_list_raises(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + with pytest.raises(ValueError, match="mix subblock keys"): + _set_keys_to_learn( + model, + descriptor, + ["subblock_attention", "model.layers.0.self_attn.q_proj.weight"], + ) + + +# --------------------------------------------------------------------------- +# Hybrid model: subblock_mamba vs subblock_attention should partition by +# block_configs[i].attention.mamba — this is the path most likely to +# silently misroute training under future descriptor changes. +# --------------------------------------------------------------------------- + + +def _hybrid_block_configs(): + """Block 0: Mamba. Block 1: GQA. Detected via ``attention.mamba is not None``.""" + return [ + SimpleNamespace(attention=SimpleNamespace(mamba=SimpleNamespace())), # Mamba + SimpleNamespace(attention=SimpleNamespace(mamba=None)), # GQA + ] + + +def test_subblock_mamba_on_hybrid_trains_only_mamba_block(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2, block_configs=_hybrid_block_configs()) + _set_keys_to_learn(model, descriptor, "subblock_mamba") + trainable = _trainable_names(model) + # Block 0 (Mamba) attention-group params should be trainable; block 1 (GQA) must not. + assert any("model.layers.0.self_attn." in n for n in trainable), trainable + assert not any("model.layers.1.self_attn." in n for n in trainable), trainable + # FFN params are never trainable under subblock_mamba. + assert not any(".mlp." in n for n in trainable), trainable + + +def test_subblock_attention_on_hybrid_trains_only_gqa_block(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2, block_configs=_hybrid_block_configs()) + _set_keys_to_learn(model, descriptor, "subblock_attention") + trainable = _trainable_names(model) + # Block 1 (GQA) attention-group params are trainable; block 0 (Mamba) must not. 
+ assert any("model.layers.1.self_attn." in n for n in trainable), trainable + assert not any("model.layers.0.self_attn." in n for n in trainable), trainable + assert not any(".mlp." in n for n in trainable), trainable + + +# --------------------------------------------------------------------------- +# Unsupported free-form key forms +# --------------------------------------------------------------------------- + + +def test_explicit_param_name_list_is_rejected(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + target = "model.layers.0.self_attn.q_proj.weight" + with pytest.raises(ValueError, match="subblock keys"): + _set_keys_to_learn(model, descriptor, [target]) + + +def test_regex_string_is_rejected(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + with pytest.raises(ValueError, match="keys_to_learn must be one of"): + _set_keys_to_learn(model, descriptor, r"q_proj") + + +def test_empty_key_list_is_rejected(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + with pytest.raises(ValueError, match="cannot be empty"): + _set_keys_to_learn(model, descriptor, []) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "keys_to_learn", + ["subblock_ffn", "subblock_attention", "entire_block"], +) +def test_subblock_keys_skip_non_floating_point_params(keys_to_learn): + """Integer / non-floating buffers exposed as parameters must stay frozen. + + The function explicitly guards on ``torch.is_floating_point(param)``; this + test pins that guard so a future refactor doesn't accidentally try to + enable grad on int tensors (which would raise at runtime). + """ + model = _make_dense_model(num_layers=2) + # Inject an int "param" alongside a real one. + int_param = nn.Parameter(torch.zeros(2, dtype=torch.long), requires_grad=False) + model.model.layers[0].self_attn.register_parameter("int_counter", int_param) + descriptor = _make_descriptor(num_layers=2) + # Should not raise even though the int param's name matches the attention group. + _set_keys_to_learn(model, descriptor, keys_to_learn) + # The int counter must remain frozen regardless. + assert not model.model.layers[0].self_attn.int_counter.requires_grad diff --git a/tests/unit/torch/puzzletron/test_bypass_losses.py b/tests/unit/torch/puzzletron/test_bypass_losses.py new file mode 100644 index 00000000000..2d59b25716a --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_losses.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for normalized MSE loss functions in sewing_kit/utils.py.""" + +import pytest +import torch + +from modelopt.torch.puzzletron.sewing_kit.utils import ( + batched_normalized_mse_loss, + normalized_mse_loss, + vectorwise_normalized_mse_loss, +) +from modelopt.torch.puzzletron.utils.parsing import format_stitched_losses + +# --------------------------------------------------------------------------- +# normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_normalized_mse_loss_identical_tensors(): + """Identical input and target should produce a loss of approximately 0.""" + torch.manual_seed(42) + x = torch.randn(4, 8) + loss = normalized_mse_loss(x, x) + assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6) + + +def test_normalized_mse_loss_basic(): + """Loss should be positive and finite for random, non-identical tensors.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target) + assert loss.item() > 0.0 + assert torch.isfinite(loss) + + +def test_normalized_mse_loss_reduction_none(): + """With reduction='none' the output shape should match the input shape.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target, reduction="none") + assert loss.shape == input_.shape + + +def test_normalized_mse_loss_reduction_sum(): + """With reduction='sum' the output should be a scalar tensor.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target, reduction="sum") + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + + +# --------------------------------------------------------------------------- +# vectorwise_normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_vectorwise_normalized_mse_loss_shape(): + """vectorwise_normalized_mse_loss should return a scalar for any 2-D input.""" + torch.manual_seed(42) + input_ = torch.randn(4, 16) + target = torch.randn(4, 16) + loss = vectorwise_normalized_mse_loss(input_, target) + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + + +def test_vectorwise_normalized_mse_loss_identical(): + """Identical input and target should give a loss of approximately 0.""" + torch.manual_seed(42) + x = torch.randn(4, 16) + loss = vectorwise_normalized_mse_loss(x, x) + assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6) + + +# --------------------------------------------------------------------------- +# batched_normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_batched_normalized_mse_loss_basic(): + """Should return a scalar with a positive, finite value for random tensors.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = batched_normalized_mse_loss(input_, target) + assert loss.ndim == 0 # scalar + assert loss.item() > 0.0 + assert torch.isfinite(loss) + + +def test_batched_normalized_mse_loss_custom_dims(): + """Custom batch_dims=(0, 1) on a 3-D tensor should still return a scalar.""" + torch.manual_seed(42) + input_ = torch.randn(2, 3, 8) + target = torch.randn(2, 3, 8) + loss = batched_normalized_mse_loss(input_, target, batch_dims=(0, 1)) + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + assert loss.item() > 0.0 + + +def 
test_batched_normalized_mse_loss_zero_target_is_finite(): + """All-zero target slice must not produce NaN/Inf. + + With the relative-L2 formula ``sum((x-t)^2) / (sum(t^2) + eps)``, an all-zero + target reduces the denominator to exactly ``eps`` — finite, no division by + zero — so the loss equals ``||input||^2 / eps``. The numeric value is large + by construction (that's what zero-magnitude targets mean), but the test + pins the property we actually care about: finiteness, not magnitude. + """ + input_ = torch.full((1, 8), 1.0) + target = torch.zeros(1, 8) + loss = batched_normalized_mse_loss(input_, target) + assert torch.isfinite(loss) + assert not torch.isnan(loss) + + +def test_batched_normalized_mse_loss_zero_input_and_target(): + """Both zero should give exactly 0.0 — numerator is zero, denominator is eps.""" + input_ = torch.zeros(2, 4) + target = torch.zeros(2, 4) + loss = batched_normalized_mse_loss(input_, target) + assert loss.item() == 0.0 + + +def test_batched_normalized_mse_loss_scale_invariance(): + """Scaling both input and target by the same constant must leave the loss + unchanged for non-tiny targets — the defining property of relative-L2.""" + torch.manual_seed(0) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + baseline = batched_normalized_mse_loss(input_, target) + scaled = batched_normalized_mse_loss(10.0 * input_, 10.0 * target) + assert torch.allclose(baseline, scaled, rtol=1e-4, atol=1e-6) + + +def test_batched_normalized_mse_loss_rejects_shape_mismatch(): + input_ = torch.randn(2, 3) + target = torch.randn(2, 1) + + with pytest.raises(ValueError, match="input and target shapes must match exactly"): + batched_normalized_mse_loss(input_, target) + + +def test_batched_normalized_mse_loss_rejects_invalid_batch_dim(): + input_ = torch.randn(2, 3) + target = torch.randn(2, 3) + + with pytest.raises(ValueError, match="batch_dims contains invalid dimension"): + batched_normalized_mse_loss(input_, target, batch_dims=(2,)) + + +def test_format_stitched_losses_keeps_trainable_nan_visible(): + out = format_stitched_losses( + {"block_0": float("nan"), "block_1": 1.0}, + initial_values_dict={"block_0": 0.5, "block_1": 2.0}, + not_trainable_names={"block_2"}, + step_number=3, + ) + + assert "nan" in out + assert "non-finite" in out + assert "Skipped=1" in out + assert "No trainable blocks found" not in out diff --git a/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py b/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py new file mode 100644 index 00000000000..38701ba8be3 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the cosine-with-warmup LR scheduler used by bypass distillation. + +``_get_lr`` is the scheduler invoked every step inside ``train``. 
An off-by-one +in the cosine ramp would silently degrade convergence — bypass jobs run for +hours and produce subtly worse student weights. The degenerate-budget guard +matters for tests and short sweeps where ``training_tokens`` is small. + +Schedule shape (warmup_steps=W, lr_decay_steps=D): + + step ∈ [0, W]: linear ramp 0 → base_lr (warmup branch) + step ∈ (W, D]: cosine decay base_lr → min_lr (cosine branch) + step > D: clamped to min_lr (post-decay branch) + +The cosine uses ``decay_ratio = (step - W) / (D - W)`` so the boundary cases +align: at step=W+1 the cosine has just started (decay_ratio = 1/(D-W)) and at +step=D it reaches min_lr exactly (decay_ratio=1, coeff=0). +""" + +import math + +import pytest +from omegaconf import OmegaConf + +from modelopt.torch.puzzletron.bypass_distillation.training_loop import _get_lr + + +def _make_cfg( + *, + warmup_steps: int, + lr_decay_steps: int, + learning_rate: float = 1.0, + min_lr: float = 0.1, +): + return OmegaConf.create( + { + "bypass": { + "training": { + "warmup_steps": warmup_steps, + "lr_decay_steps": lr_decay_steps, + "learning_rate": learning_rate, + "min_lr": min_lr, + } + } + } + ) + + +def test_degenerate_budget_returns_base_lr(): + """When ``lr_decay_steps <= warmup_steps`` (tiny test budgets), the scheduler + must short-circuit to ``learning_rate`` rather than divide by zero.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=10, learning_rate=0.5) + assert _get_lr(cfg, step=0) == 0.5 + assert _get_lr(cfg, step=1) == 0.5 + assert _get_lr(cfg, step=99) == 0.5 + + +def test_degenerate_budget_warmup_greater_than_decay(): + """``lr_decay_steps < warmup_steps`` is also caught by the same guard.""" + cfg = _make_cfg(warmup_steps=20, lr_decay_steps=10, learning_rate=0.7) + assert _get_lr(cfg, step=5) == 0.7 + + +def test_warmup_linear_ramp(): + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=100, learning_rate=1.0) + assert _get_lr(cfg, step=0) == pytest.approx(0.0) + assert _get_lr(cfg, step=5) == pytest.approx(0.5) + assert _get_lr(cfg, step=10) == pytest.approx(1.0) + + +def test_cosine_starts_decaying_immediately_after_warmup(): + """At ``step == warmup_steps + 1`` the cosine branch is entered with + ``decay_ratio = 1/(D-W)`` — already a small step below base LR, not a + duplicate plateau at base LR. This is the boundary the previous formula + got wrong (it used ``step - W - 1`` and gave ``decay_ratio == 0`` here).""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.0) + # decay_ratio = (11 - 10) / 10 = 0.1 + expected = 0.5 * (1.0 + math.cos(math.pi * 0.1)) + assert _get_lr(cfg, step=11) == pytest.approx(expected) + # Strictly below base LR — the cosine has begun. + assert _get_lr(cfg, step=11) < 1.0 + + +def test_cosine_endpoint_returns_min_lr(): + """At ``step == lr_decay_steps`` the cosine branch reaches its endpoint: + ``decay_ratio == 1`` → ``coeff == 0`` → returns ``min_lr`` exactly. The + post-decay clamp at ``step == lr_decay_steps + 1`` is then a no-op + continuation, not a correction for an off-by-one.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.1) + assert _get_lr(cfg, step=20) == pytest.approx(0.1) + + +def test_cosine_midpoint_is_halfway(): + """At the cosine midpoint, ``coeff == 0.5`` → returns ``(lr + min_lr) / 2``.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.0) + # Midpoint of the post-warmup window: step such that decay_ratio == 0.5. 
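+    # With learning_rate=1.0 and min_lr=0.0, the "(lr + min_lr) / 2" halfway value
+    # from the docstring equals the cosine coefficient itself, so the coefficient
+    # computed below doubles as the expected LR.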
+ # decay_ratio = (step - 10) / (20 - 10) → step = 15 gives ratio 0.5. + expected_coeff = 0.5 * (1.0 + math.cos(math.pi * 0.5)) + assert _get_lr(cfg, step=15) == pytest.approx(expected_coeff) + + +def test_post_decay_clamps_to_min_lr(): + """``step > lr_decay_steps`` always returns ``min_lr`` exactly.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.1) + assert _get_lr(cfg, step=21) == 0.1 + assert _get_lr(cfg, step=1000) == 0.1 + + +def test_min_lr_zero_decays_to_zero(): + """Common config: ``min_lr=0`` → cosine endpoint is exactly 0.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=30, learning_rate=2.0, min_lr=0.0) + assert _get_lr(cfg, step=30) == pytest.approx(0.0) + assert _get_lr(cfg, step=31) == 0.0 diff --git a/tests/unit/torch/puzzletron/test_bypass_utils.py b/tests/unit/torch/puzzletron/test_bypass_utils.py new file mode 100644 index 00000000000..0b43a97c01c --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_utils.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for get_distributed_modules_ownership in bypass_utils.py.""" + +import pytest +from omegaconf import OmegaConf + +from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import ( + get_bypass_config_fingerprint, + get_distributed_modules_ownership, + get_pipeline_ownership_context, + set_experiment_id, +) + + +def test_single_gpu_all_to_rank_0(): + """With world_size=1, all 4 modules should be assigned to rank 0.""" + ownership = get_distributed_modules_ownership(module_count=4, world_size=1) + assert ownership == [0, 0, 0, 0] + + +def test_even_distribution(): + """With world_size=2 and 4 modules, each rank should own exactly 2 modules.""" + ownership = get_distributed_modules_ownership(module_count=4, world_size=2) + assert ownership.count(0) == 2 + assert ownership.count(1) == 2 + assert len(ownership) == 4 + + +def test_uneven_distribution(): + """With world_size=2 and 3 modules, rank 0 should own 2 and rank 1 should own 1.""" + ownership = get_distributed_modules_ownership(module_count=3, world_size=2) + assert ownership.count(0) == 2 + assert ownership.count(1) == 1 + assert len(ownership) == 3 + + +@pytest.mark.parametrize( + ("module_count", "world_size"), + [ + (1, 1), + (4, 1), + (4, 2), + (4, 4), + (7, 3), + (10, 4), + (1, 2), + ], +) +def test_total_equals_module_count(module_count, world_size): + """The length of the ownership list must always equal module_count.""" + ownership = get_distributed_modules_ownership(module_count=module_count, world_size=world_size) + assert len(ownership) == module_count + + +def test_consecutive_ownership(): + """Each rank should own a contiguous block of indices (no interleaving).""" + ownership = get_distributed_modules_ownership(module_count=7, world_size=3) + # Verify that once we see a new rank, we never see the previous rank again. 
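+    # (Illustrative only: an acceptable contiguous layout would be something like
+    # [0, 0, 0, 1, 1, 2, 2]; the test pins contiguity, not the exact per-rank counts.)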
+ seen_ranks = set() + prev_rank = ownership[0] + seen_ranks.add(prev_rank) + for rank in ownership[1:]: + if rank != prev_rank: + assert rank not in seen_ranks, ( + f"Rank {rank} appears non-consecutively in ownership list: {ownership}" + ) + seen_ranks.add(rank) + prev_rank = rank + + +def test_single_module(): + """With world_size=2 and only 1 module, rank 0 should be the sole owner.""" + ownership = get_distributed_modules_ownership(module_count=1, world_size=2) + assert ownership == [0] + assert len(ownership) == 1 + + +def test_pipeline_ownership_context_returns_neighbors(): + ownership = [0, 0, 1, 1, 2] + + assert get_pipeline_ownership_context(ownership, rank=0) == { + "owned_indices": [0, 1], + "owned_index_set": {0, 1}, + "prev_rank": None, + "next_rank": 1, + } + assert get_pipeline_ownership_context(ownership, rank=1) == { + "owned_indices": [2, 3], + "owned_index_set": {2, 3}, + "prev_rank": 0, + "next_rank": 2, + } + assert get_pipeline_ownership_context(ownership, rank=2) == { + "owned_indices": [4], + "owned_index_set": {4}, + "prev_rank": 1, + "next_rank": None, + } + + +def test_pipeline_ownership_context_rejects_idle_rank(): + with pytest.raises(RuntimeError, match="owns no modules"): + get_pipeline_ownership_context([0, 0, 1], rank=2) + + +def _experiment_cfg(keys_to_learn: str): + return OmegaConf.create( + { + "descriptor": "test_descriptor", + "dataset_path": "/tmp/dataset_a", + "bypass": { + "experiment_id": None, + "dtype": "bf16", + "seed": 42, + "data": { + "block_size": 64, + "data_column": "text", + "fim_rate": 0, + "fim_spm_rate": 0, + "bos_rate": 1.0, + "source_datasets_to_discard": [], + "load_from_disk": True, + "keep_in_memory": False, + "shuffle_train_data_seed": 123, + "val_dataset_name": "valid", + "max_eval_samples": 4, + "eval_samples_per_process": None, + }, + "training": { + "learning_rate": 1e-4, + "training_tokens": 1024, + "micro_batch_size": 1, + "grad_accumulation_steps": 1, + "weight_decay": 0.1, + "decay_lr": True, + "beta1": 0.9, + "beta2": 0.95, + "grad_clip": 1.0, + "grad_clip_type": "norm", + "warmup_ratio": 0.05, + "min_lr_factor": 1e-5, + }, + "model": { + "student_weights_dtype": "bf16", + "model_config_overrides": { + "attention": [{"num_key_value_heads": 1, "no_op": None}] + }, + }, + "model_factory": { + "factory": "bypass_factory_fn", + "block_loss_func": "normalized_mse_loss", + "gqa_init_mode": "AverageKV", + "mlp_init_mode": "Truncate", + "mlp_init_config": {"activations_log_dir": None}, + "linear_init_mode": "FromTeacher", + "submodule_for_loss_calculation": None, + "keys_to_learn": keys_to_learn, + }, + "disable_validation": False, + "save_best_ckpt": True, + "realize_best_or_latest": "best", + }, + } + ) + + +def test_experiment_id_includes_learning_target_and_fingerprint(): + attention_cfg = _experiment_cfg("subblock_attention") + ffn_cfg = _experiment_cfg("subblock_ffn") + + set_experiment_id(attention_cfg) + set_experiment_id(ffn_cfg) + + assert attention_cfg.bypass.experiment_id.startswith("bypass_heads_1_attention_") + assert ffn_cfg.bypass.experiment_id.startswith("bypass_heads_1_ffn_") + assert attention_cfg.bypass.experiment_id != ffn_cfg.bypass.experiment_id + + +def test_experiment_id_falls_back_when_no_architecture_parts_exist(): + cfg = _experiment_cfg("entire_block") + cfg.bypass.model.model_config_overrides = {} + + set_experiment_id(cfg) + + assert cfg.bypass.experiment_id.startswith("bypass_custom_") + assert cfg.bypass.experiment_id != "bypass_None" + + +def 
test_config_fingerprint_changes_with_dataset_path(): + cfg = _experiment_cfg("subblock_attention") + original = get_bypass_config_fingerprint(cfg) + cfg.dataset_path = "/tmp/dataset_b" + assert get_bypass_config_fingerprint(cfg) != original + + +def test_config_fingerprint_changes_with_shuffle_seed(): + cfg = _experiment_cfg("subblock_attention") + original = get_bypass_config_fingerprint(cfg) + cfg.bypass.data.shuffle_train_data_seed = 456 + assert get_bypass_config_fingerprint(cfg) != original + + +def test_experiment_id_does_not_change_with_dataset_path(): + cfg_a = _experiment_cfg("subblock_attention") + cfg_b = _experiment_cfg("subblock_attention") + cfg_b.dataset_path = "/tmp/dataset_b" + set_experiment_id(cfg_a) + set_experiment_id(cfg_b) + assert cfg_a.bypass.experiment_id == cfg_b.bypass.experiment_id diff --git a/tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py b/tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py new file mode 100644 index 00000000000..c486808e387 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from types import SimpleNamespace + +import torch + +from modelopt.torch.puzzletron.tools import checkpoint_utils_hf as cuhf + + +def test_save_checkpoint_uses_descriptor_language_model_config(tmp_path, monkeypatch): + calls = {} + + class Descriptor: + @staticmethod + def get_language_model_config(config): + return config.text_config + + @staticmethod + def get_weight_groups(layer_names, num_hidden_layers): + calls["num_hidden_layers"] = num_hidden_layers + return {"weights": list(layer_names)} + + @staticmethod + def output_embedding_name(): + return "lm_head" + + monkeypatch.setattr(cuhf, "save_model_config", lambda *args, **kwargs: None) + monkeypatch.setattr(cuhf, "save_subblocks", lambda *args, **kwargs: None) + + cfg = SimpleNamespace( + text_config=SimpleNamespace(num_hidden_layers=7), + tie_word_embeddings=False, + ) + cuhf._save_checkpoint(cfg, {"some.weight": torch.zeros(1)}, tmp_path, Descriptor) + + assert calls["num_hidden_layers"] == 7 + + +def test_copy_auto_map_code_files_ignores_non_string_entries(tmp_path, monkeypatch): + source_dir = tmp_path / "source" + checkpoint_dir = tmp_path / "checkpoint" + source_dir.mkdir() + checkpoint_dir.mkdir() + (source_dir / "modeling_custom.py").write_text("# modeling\n") + (source_dir / "tokenization_custom.py").write_text("# tokenizer\n") + + monkeypatch.setattr(cuhf.inspect, "getfile", lambda _cls: source_dir / "configuration.py") + + cfg = SimpleNamespace( + auto_map={ + "AutoConfig": "configuration_custom.CustomConfig", + "AutoModelForCausalLM": "modeling_custom.CustomModel", + "AutoTokenizer": [None, "tokenization_custom.CustomTokenizer"], + } + ) + + cuhf._copy_auto_map_code_files(cfg, checkpoint_dir) + + assert (checkpoint_dir / "modeling_custom.py").exists() + assert (checkpoint_dir / "tokenization_custom.py").exists() diff --git a/tests/unit/torch/puzzletron/test_child_init_mixins.py b/tests/unit/torch/puzzletron/test_child_init_mixins.py new file mode 100644 index 00000000000..b68313245e4 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_child_init_mixins.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from types import SimpleNamespace + +import torch + +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import _process_single_layer + + +class _AddOneMixin: + def prune_single_layer(self, parent_state_dict, keys_to_remove, **kwargs): + keys_to_remove["w"] = "w" + return {"w": parent_state_dict["w"] + 1} + + +class _TimesTwoMixin: + def prune_single_layer(self, parent_state_dict, keys_to_remove, **kwargs): + keys_to_remove["w"] = "w" + return {"w": parent_state_dict["w"] * 2} + + +class _PopKeyMixin: + def prune_single_layer(self, parent_state_dict, keys, **kwargs): + keys.pop("w") + return {"w": parent_state_dict["w"]} + + +def _process_with_mixins(mixins, keys): + return _process_single_layer( + layer_idx=0, + pruning_mixin=mixins, + descriptor=None, + parent_state_dict={"w": torch.tensor([1.0])}, + new_state_dict={"w": torch.tensor([0.0])}, + original_config=SimpleNamespace(), + new_config=SimpleNamespace(), + gqa_init_mode=None, + mlp_init_mode=None, + mlp_init_config=None, + linear_init_mode=None, + ignored_keys=set(), + keys=keys, + is_original_mha=False, + head_size=1, + hidden_size=1, + ) + + +def test_pruning_mixins_compose_overlapping_outputs_sequentially(): + layer_state_dict, keys_to_remove = _process_with_mixins( + [_AddOneMixin(), _TimesTwoMixin()], {"w": "w"} + ) + + assert torch.equal(layer_state_dict["w"], torch.tensor([4.0])) + assert keys_to_remove == {"w": "w"} + + +def test_pruning_mixin_key_mutation_is_tracked_without_mutating_shared_keys(): + shared_keys = {"w": "w"} + + _, keys_to_remove = _process_with_mixins([_PopKeyMixin()], shared_keys) + + assert keys_to_remove == {"w": "w"} + assert shared_keys == {"w": "w"} diff --git a/tests/unit/torch/puzzletron/test_hydra_utils.py b/tests/unit/torch/puzzletron/test_hydra_utils.py new file mode 100644 index 00000000000..4b84dc08812 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_hydra_utils.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from modelopt.torch.puzzletron.tools.hydra_utils import warmup_steps + + +def test_warmup_steps_casts_inputs_before_computing(): + assert warmup_steps("100", "10", "2", "5", "0.5") == 1 + + +@pytest.mark.parametrize( + ("kwargs", "message"), + [ + ({"tokens": -1, "block": 1, "mbs": 1, "grad_accum": 1, "pct": 0.1}, "tokens"), + ({"tokens": 1, "block": 0, "mbs": 1, "grad_accum": 1, "pct": 0.1}, "block"), + ({"tokens": 1, "block": 1, "mbs": 0, "grad_accum": 1, "pct": 0.1}, "mbs"), + ({"tokens": 1, "block": 1, "mbs": 1, "grad_accum": 0, "pct": 0.1}, "grad_accum"), + ({"tokens": 1, "block": 1, "mbs": 1, "grad_accum": 1, "pct": 1.1}, "pct"), + ], +) +def test_warmup_steps_rejects_invalid_inputs(kwargs, message): + with pytest.raises(ValueError, match=message): + warmup_steps(**kwargs) diff --git a/tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py b/tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py new file mode 100644 index 00000000000..421ec4304bb --- /dev/null +++ b/tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace + +from modelopt.torch.puzzletron.pruning.pruning_utils import _lm_head_dim + + +def test_lm_head_dim_uses_explicit_nested_head_dim(): + cfg = SimpleNamespace( + text_config=SimpleNamespace(head_dim=96, hidden_size=3072, num_attention_heads=32) + ) + assert _lm_head_dim(cfg) == 96 + + +def test_lm_head_dim_falls_back_to_hidden_size_over_heads(): + cfg = SimpleNamespace(text_config=SimpleNamespace(hidden_size=3072, num_attention_heads=32)) + assert _lm_head_dim(cfg) == 96 diff --git a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py new file mode 100644 index 00000000000..f591cc68916 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py @@ -0,0 +1,156 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``launch_bypass_distillation`` (sweep dispatcher). + +The dispatcher's job is to iterate over ``bypass.configs``, apply each override +to the live ``hydra_cfg``, reset the per-run state machine, and invoke +``run_bypassed_training``. 
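+The reset restores ``iter_num``/``step_num`` to 1, zeroes ``token_count`` and the
+clipping counter, resets ``best_val_loss`` to 1e9, and regenerates ``experiment_id``
+(pinned below in ``test_per_run_state_reset_before_each_call``).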
Reordering or dropping a reset would silently make +the second sweep entry resume from the first entry's iter counter — a bug +that would only surface as wasted compute and confused checkpoint dirs. + +We patch ``run_bypassed_training`` to a recorder so this stays a pure-Python +test (no GPU, no real training). +""" + +from omegaconf import OmegaConf + +import modelopt.torch.puzzletron.bypass_distillation.training_loop as tl + + +def _base_cfg(tmp_path, configs=None): + """Build a minimal cfg shape that ``launch_bypass_distillation`` reads. + + Includes only the keys touched by the dispatcher itself; ``run_bypassed_training`` + is mocked so its richer requirements are irrelevant here. + """ + cfg = { + "puzzle_dir": str(tmp_path / "puzzletron_bypass_unit"), + "descriptor": "test_descriptor", + "bypass": { + "model": {"model_config_overrides": {"intermediate_size": 1024}}, + "model_factory": {"keys_to_learn": "subblock_ffn"}, + "experiment_id": "stale-id", + "iter_num": 999, + "step_num": 999, + "token_count": 999_999, + "best_val_loss": 0.0, + "training": {"clipping_count": 42}, + }, + } + if configs is not None: + cfg["bypass"]["configs"] = configs + return OmegaConf.create(cfg) + + +def _record_calls(monkeypatch): + """Patch ``run_bypassed_training`` to capture deep-copied cfg snapshots.""" + snapshots = [] + + def _recorder(cfg): + # Deep-copy via container conversion; the live cfg is mutated between calls. + snapshots.append(OmegaConf.to_container(cfg, resolve=True)) + + monkeypatch.setattr(tl, "run_bypassed_training", _recorder) + return snapshots + + +def test_no_configs_key_runs_once(monkeypatch, tmp_path): + """Absent ``bypass.configs`` is the single-config path — one call, no resets.""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(tmp_path, configs=None) + tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 1 + # Single-config path doesn't touch the state machine — values remain as supplied. 
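+    # (999 and 42 are the sentinel values from _base_cfg; a reset would have
+    # replaced them with 1 and 0 respectively.)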
+ assert snapshots[0]["bypass"]["iter_num"] == 999 + assert snapshots[0]["bypass"]["training"]["clipping_count"] == 42 + + +def test_empty_configs_list_runs_once(monkeypatch, tmp_path): + """``configs: []`` must hit the same branch as missing — the truthiness + check on line 85 of training_loop.py treats both as 'no sweep'.""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(tmp_path, configs=[]) + tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 1 + + +def test_two_configs_run_twice_with_distinct_overrides(monkeypatch, tmp_path): + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg( + tmp_path, + configs=[ + {"model_config_overrides": {"intermediate_size": 256}}, + {"model_config_overrides": {"intermediate_size": 128}}, + ], + ) + tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 2 + assert snapshots[0]["bypass"]["model"]["model_config_overrides"] == {"intermediate_size": 256} + assert snapshots[1]["bypass"]["model"]["model_config_overrides"] == {"intermediate_size": 128} + + +def test_keys_to_learn_override_applied(monkeypatch, tmp_path): + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(tmp_path, configs=[{"keys_to_learn": "subblock_attention"}]) + tl.launch_bypass_distillation(cfg) + assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_attention" + + +def test_per_run_state_reset_before_each_call(monkeypatch, tmp_path): + """Every sweep entry must see iter_num=1, step_num=1, token_count=0, + best_val_loss=1e9, clipping_count=0, and a fresh experiment_id even when the + previous entry left the cfg in some other state.""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg( + tmp_path, + configs=[ + {"model_config_overrides": {"intermediate_size": 256}}, + {"model_config_overrides": {"intermediate_size": 128}}, + ], + ) + tl.launch_bypass_distillation(cfg) + for snap in snapshots: + assert snap["bypass"]["experiment_id"].startswith("bypass_ffn_") + assert snap["bypass"]["iter_num"] == 1 + assert snap["bypass"]["step_num"] == 1 + assert snap["bypass"]["token_count"] == 0 + assert snap["bypass"]["best_val_loss"] == 1e9 + assert snap["bypass"]["training"]["clipping_count"] == 0 + + +def test_override_without_keys_to_learn_leaves_cfg_value_untouched(monkeypatch, tmp_path): + """A sweep entry that only sets ``model_config_overrides`` must not clobber + the inherited ``keys_to_learn`` (the dispatcher's `if "keys_to_learn" in override` + guard, line 99).""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(tmp_path, configs=[{"model_config_overrides": {"intermediate_size": 256}}]) + tl.launch_bypass_distillation(cfg) + # keys_to_learn was set to "subblock_ffn" in _base_cfg — must survive. 
+ assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_ffn" + + +def test_sweep_entry_without_keys_to_learn_uses_base_not_previous_override(monkeypatch, tmp_path): + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg( + tmp_path, + configs=[ + {"keys_to_learn": "subblock_attention"}, + {"model_config_overrides": {"intermediate_size": 256}}, + ], + ) + tl.launch_bypass_distillation(cfg) + assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_attention" + assert snapshots[1]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_ffn" diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py b/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py new file mode 100644 index 00000000000..58df5ffe327 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``sewing_kit.utils.ActivityContext``. + +``ActivityContext`` is the stack the ``Passage`` machinery uses to track which +passages are currently active inside a ``StitchedModule.forward`` call. A bug +in push/pop ordering or in the exception-safe cleanup would leak state across +forward passes — every subsequent block would see a stale "active passage" +and route inputs/outputs to the wrong module. +""" + +import pytest + +from modelopt.torch.puzzletron.sewing_kit.utils import ( + ActivityContext, + ActivityContextDuplicateException, + ActivityContextMaxDepthException, + is_submodule_of, + is_submodule_or_same, +) + +# --------------------------------------------------------------------------- +# Basic push/pop semantics via the ``with ctx(value):`` form +# --------------------------------------------------------------------------- + + +def test_starts_empty_and_inactive(): + ctx: ActivityContext[str] = ActivityContext() + assert len(ctx) == 0 + assert not ctx.is_active() + assert ctx.get_active() is None + + +def test_with_block_pushes_and_pops_value(): + ctx: ActivityContext[str] = ActivityContext() + with ctx("a"): + assert ctx.is_active() + assert ctx.get_active() == "a" + assert "a" in ctx + assert len(ctx) == 1 + # After the block: stack must be back to empty. + assert len(ctx) == 0 + assert ctx.get_active() is None + + +def test_nested_pushes_track_lifo_order(): + """``get_active`` returns the *most recent* push (LIFO) — Passage relies on + this to find the innermost active passage during forward.""" + ctx: ActivityContext[str] = ActivityContext() + with ctx("outer"): + assert ctx.get_active() == "outer" + with ctx("inner"): + assert ctx.get_active() == "inner" + assert ctx[0] == "outer" + assert ctx[1] == "inner" + # Inner pop returns to outer. 
+ assert ctx.get_active() == "outer" + + +# --------------------------------------------------------------------------- +# max_depth: limits stack height +# --------------------------------------------------------------------------- + + +def test_max_depth_one_allows_single_push(): + ctx: ActivityContext[str] = ActivityContext(max_depth=1) + with ctx("a"): + assert ctx.get_active() == "a" + + +def test_max_depth_one_rejects_second_push(): + ctx: ActivityContext[str] = ActivityContext(max_depth=1) + with ctx("a"), pytest.raises(ActivityContextMaxDepthException), ctx("b"): + pass + # Stack must have unwound to empty even after the exception. + assert len(ctx) == 0 + + +# --------------------------------------------------------------------------- +# no_duplicates: same value can't appear twice +# --------------------------------------------------------------------------- + + +def test_no_duplicates_rejects_repeat_value(): + ctx: ActivityContext[str] = ActivityContext(no_duplicates=True) + with ctx("x"), pytest.raises(ActivityContextDuplicateException), ctx("x"): + pass + # Stack unwound; the still-active "x" was preserved through the failed push. + assert len(ctx) == 0 + + +def test_no_duplicates_allows_distinct_values(): + ctx: ActivityContext[str] = ActivityContext(no_duplicates=True) + with ctx("x"), ctx("y"): + assert "x" in ctx and "y" in ctx + + +# --------------------------------------------------------------------------- +# reversed=True: insert at front, pop from front +# --------------------------------------------------------------------------- + + +def test_reversed_pushes_to_front_and_pops_from_front(): + """``Passage.active_passages_context`` uses ``reversed=True`` so the + *first* active passage in iteration order is the innermost. Pin both + insert position and pop position.""" + ctx: ActivityContext[str] = ActivityContext(reversed=True) + with ctx("a"): + with ctx("b"): + # b inserted at front of stack. + assert ctx[0] == "b" + assert ctx[1] == "a" + # Pop from front: only "a" left — runs between the inner and outer + # exits, which is why these withs can't be combined. + assert list(ctx[:]) == ["a"] + + +# --------------------------------------------------------------------------- +# Exception safety: stack unwinds even if the caller's body raises +# --------------------------------------------------------------------------- + + +def test_stack_unwinds_when_body_raises(): + """A bug here would leak stack frames — the next forward pass would see + a stale active passage. This is the silent-failure scenario.""" + ctx: ActivityContext[str] = ActivityContext() + with pytest.raises(ValueError, match="boom"), ctx("a"): + assert ctx.get_active() == "a" + raise ValueError("boom") + assert len(ctx) == 0 + + +# --------------------------------------------------------------------------- +# is_submodule_of / is_submodule_or_same — string predicates used by passage.py +# --------------------------------------------------------------------------- + + +def test_is_submodule_of_proper_descendant(): + assert is_submodule_of("model.layers.0.self_attn", "model.layers.0") + assert is_submodule_of("model.layers.0", "model") + # Empty string parent matches any non-empty name (root-of-everything case). + assert is_submodule_of("model", "") + + +def test_is_submodule_of_rejects_self_and_unrelated(): + assert not is_submodule_of("model.layers.0", "model.layers.0") + assert not is_submodule_of("model.layers.0", "model.layers.1") + # Empty == empty is not a submodule relationship. 
+ assert not is_submodule_of("", "") + # Prefix collision: "model.layers" is NOT a submodule of "model.lay" — the + # predicate requires a literal "." separator after the parent. + assert not is_submodule_of("model.layers", "model.lay") + + +def test_is_submodule_or_same_includes_equality(): + assert is_submodule_or_same("model.layers.0", "model.layers.0") + assert is_submodule_or_same("model.layers.0.attn", "model.layers.0") + assert not is_submodule_or_same("model.layers.0", "model.layers.1") diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_function_target_kwargs.py b/tests/unit/torch/puzzletron/test_sewing_kit_function_target_kwargs.py new file mode 100644 index 00000000000..1e412605435 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_function_target_kwargs.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Regression test for ``FunctionTarget`` kwargs dispatch. + +The bypass-distillation factory stitches teacher and student block outputs into +a per-block loss function using ``InputArgs(target=...)`` and ``InputArgs(input=...)`` +adapters (see ``stitched_model_factory.py:~545``). The loss function is then +invoked by ``StitchedModule.forward`` at ``core.py:600`` as +``node.target.function(*input_args.args, **input_args.kwargs)`` — i.e. with +**named kwargs**. + +If sewing_kit ever switched to positional dispatch in stitch-declaration order, +asymmetric losses (KL divergence, relative-L2, anything where ``f(a, b) != f(b, a)``) +would silently swap their arguments. MSE-shaped losses would hide the regression +because they're symmetric. This test pins the contract. +""" + +import torch + +from modelopt.torch.puzzletron.sewing_kit.core import ExternalTarget, FunctionTarget, Needle +from modelopt.torch.puzzletron.sewing_kit.passage import InputArgs + + +def test_function_target_invoked_with_kwargs_not_positional(): + """The function callable must receive only kwargs (no positional args).""" + received: dict[str, object] = {} + + def record_call(*args, **kwargs): + received["args"] = args + received["kwargs"] = dict(kwargs) + # The output stitch needs *something* to carry — return a sentinel scalar. + return torch.tensor(0.0) + + loss_target = FunctionTarget("loss_fn", record_call) + teacher_value = torch.full((2, 3), 7.0) + student_value = torch.full((2, 3), 11.0) + + # Stitch order is intentionally reversed from the real factory: declare + # student-first, teacher-second. If dispatch were positional-in-declaration- + # order, ``input`` would receive the teacher value and ``target`` the student + # value — which the assertions below would catch. 
+ stitched = ( + Needle() + .stitch( + ExternalTarget().output( + name="student_act", + adapter=lambda v: InputArgs(input=v), + ), + loss_target.input(), + ) + .stitch( + ExternalTarget().output( + name="teacher_act", + adapter=lambda v: InputArgs(target=v), + ), + loss_target.input(), + ) + .stitch( + loss_target.output(), + ExternalTarget().output(name="loss"), + ) + .knot() + ) + + stitched( + {}, + {"student_act": student_value, "teacher_act": teacher_value}, + ) + + assert received["args"] == (), ( + f"FunctionTarget called with positional args {received['args']!r}. " + f"Sewing-kit must dispatch with kwargs only; positional dispatch would " + f"silently swap input/target for asymmetric losses." + ) + assert set(received["kwargs"].keys()) == {"input", "target"} + assert torch.equal(received["kwargs"]["input"], student_value) + assert torch.equal(received["kwargs"]["target"], teacher_value) + + +def test_function_target_kwargs_independent_of_stitch_order(): + """Same as the test above, but with the *real factory's* stitch order + (teacher first, student second). Both orders must produce identical kwargs + — the InputArgs.__add__ kwargs merge is order-independent for distinct + keys.""" + received: dict[str, object] = {} + + def record_call(*args, **kwargs): + received["args"] = args + received["kwargs"] = dict(kwargs) + return torch.tensor(0.0) + + loss_target = FunctionTarget("loss_fn", record_call) + teacher_value = torch.full((2, 3), 13.0) + student_value = torch.full((2, 3), 17.0) + + stitched = ( + Needle() + .stitch( + ExternalTarget().output( + name="teacher_act", + adapter=lambda v: InputArgs(target=v), + ), + loss_target.input(), + ) + .stitch( + ExternalTarget().output( + name="student_act", + adapter=lambda v: InputArgs(input=v), + ), + loss_target.input(), + ) + .stitch( + loss_target.output(), + ExternalTarget().output(name="loss"), + ) + .knot() + ) + + stitched( + {}, + {"teacher_act": teacher_value, "student_act": student_value}, + ) + + assert received["args"] == () + assert set(received["kwargs"].keys()) == {"input", "target"} + assert torch.equal(received["kwargs"]["input"], student_value) + assert torch.equal(received["kwargs"]["target"], teacher_value) diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py b/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py new file mode 100644 index 00000000000..a568fadc07b --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``sewing_kit.passage.InputArgs``. + +``InputArgs`` is the workhorse args/kwargs container the bypass distillation +factory uses inside its stitching reducers — see ``bypass_factory_fn`` calls +like ``lambda acc, override, orig, *args: override + orig.drop_args(0)``. 
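+In that reducer the ``override`` supplies the leading positional argument(s), while
+the original call's remaining positional args and kwargs pass through untouched.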
+A regression in ``__add__`` or ``drop_args`` would silently corrupt the +inputs passed into per-block forward passes, producing wrong loss values +without any loud failure. +""" + +import pytest + +from modelopt.torch.puzzletron.sewing_kit.passage import InputArgs + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_init_accepts_positional_and_keyword_args(): + ia = InputArgs(1, 2, foo="bar") + assert ia.args == [1, 2] + assert ia.kwargs == {"foo": "bar"} + + +def test_init_with_no_args_is_empty(): + ia = InputArgs() + assert ia.args == [] + assert ia.kwargs == {} + + +# --------------------------------------------------------------------------- +# __add__: concatenates args, merges kwargs (right wins on collision) +# --------------------------------------------------------------------------- + + +def test_add_concatenates_positional_args_in_order(): + a = InputArgs(1, 2) + b = InputArgs(3, 4) + result = a + b + assert result.args == [1, 2, 3, 4] + assert result.kwargs == {} + + +def test_add_merges_kwargs_with_right_winning(): + """Bypass reducers chain ``override + orig.drop_args(0)`` — when both sides + happen to set the same kwarg, the right-side value (the original input) + must win, otherwise the override silently displaces the original kwarg.""" + a = InputArgs(foo="from_a", bar="only_a") + b = InputArgs(foo="from_b", baz="only_b") + result = a + b + assert result.kwargs == {"foo": "from_b", "bar": "only_a", "baz": "only_b"} + + +def test_add_does_not_mutate_operands(): + a = InputArgs(1, 2, x="a") + b = InputArgs(3, y="b") + _ = a + b + assert a.args == [1, 2] and a.kwargs == {"x": "a"} + assert b.args == [3] and b.kwargs == {"y": "b"} + + +def test_add_rejects_non_input_args(): + # ``__add__`` enforces InputArgs+InputArgs only via an internal assert. + # ruff's RUF005 auto-fix to ``[*InputArgs(1), 2]`` would silently replace + # the operator call we're testing — keep the explicit ``+`` form. + with pytest.raises(AssertionError): + InputArgs(1) + [2] # type: ignore[operator] # noqa: RUF005 + + +# --------------------------------------------------------------------------- +# drop_args: clears all positional args (default) or one by index/slice +# --------------------------------------------------------------------------- + + +def test_drop_args_default_clears_all_positional(): + """The ``drop_args(0)`` and ``drop_args()`` forms are both used by bypass + stitches — the default-no-arg form must wipe the entire positional tuple + (kwargs untouched).""" + ia = InputArgs(1, 2, 3, foo="bar") + out = ia.drop_args() + assert out.args == [] + assert out.kwargs == {"foo": "bar"} + # And the original is unmodified. + assert ia.args == [1, 2, 3] + + +def test_drop_args_with_index_drops_one(): + ia = InputArgs(10, 20, 30) + out = ia.drop_args(0) + assert out.args == [20, 30] + # Source preserved. 
+ assert ia.args == [10, 20, 30] + + +def test_drop_args_with_slice_drops_range(): + ia = InputArgs(10, 20, 30, 40) + out = ia.drop_args(slice(1, 3)) + assert out.args == [10, 40] + + +# --------------------------------------------------------------------------- +# drop_kwargs: clears all kwargs (default) or specific keys +# --------------------------------------------------------------------------- + + +def test_drop_kwargs_default_clears_all(): + ia = InputArgs(1, foo="bar", baz="qux") + out = ia.drop_kwargs() + assert out.args == [1] + assert out.kwargs == {} + + +def test_drop_kwargs_with_keys_drops_only_those(): + ia = InputArgs(1, foo="bar", baz="qux", keep="this") + out = ia.drop_kwargs(["foo", "baz"]) + assert out.kwargs == {"keep": "this"} + + +def test_drop_kwargs_silently_ignores_missing_keys(): + """A key listed in ``drop_kwargs`` that isn't present must not raise — + bypass calls this against args from arbitrary upstream stitches and may + pass keys that only some sources produce.""" + ia = InputArgs(foo="bar") + out = ia.drop_kwargs(["nonexistent"]) # must not KeyError + assert out.kwargs == {"foo": "bar"} + + +# --------------------------------------------------------------------------- +# from_value: lifts assorted values into InputArgs +# --------------------------------------------------------------------------- + + +def test_from_value_passes_through_existing_input_args(): + ia = InputArgs(1, foo="bar") + out = InputArgs.from_value(ia) + assert out is ia + + +def test_from_value_lifts_sequence_to_positional_args(): + out = InputArgs.from_value([1, 2, 3]) + assert out.args == [1, 2, 3] + assert out.kwargs == {} + + +def test_from_value_lifts_scalar_to_single_positional(): + out = InputArgs.from_value(42) + assert out.args == [42] + assert out.kwargs == {} diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_needle.py b/tests/unit/torch/puzzletron/test_sewing_kit_needle.py new file mode 100644 index 00000000000..a3db5ef30b8 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_needle.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``sewing_kit.core.Needle`` graph construction and validation. + +The bypass factory builds three ``Needle``\\s per rank (teacher train, teacher +val, student val) and calls ``Needle.knot()`` on each. ``knot()`` runs +``_validate_nodes`` first; a regression in that validation would either crash +with an opaque NoneType error during forward, or — worse — silently allow a +malformed graph that produces incorrect activations. + +We test the validation contract on CPU without instantiating ``StitchedModule`` +itself (which requires Module patching). ``_validate_nodes`` is a private +method but it's the unit of behavior worth pinning; ``knot()`` is essentially +``_validate_nodes() + StitchedModule(...)``. 
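+
+The graphs below are deliberately tiny: the happy path is
+external -> module A -> module B -> external; the failure paths either omit the
+external anchor entirely or close a cycle among the internal nodes; loops that
+pass only through the external endpoint are accepted, since that is the normal
+stitching shape.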
+""" + +import pytest +import torch.nn as nn + +from modelopt.torch.puzzletron.sewing_kit.core import ( + ExternalTarget, + InputsLoopFoundException, + ModuleTarget, + Needle, + Node, + OnlyInternalNodesException, + StitchDescriptor, +) + +# --------------------------------------------------------------------------- +# get_node_for_target: lazy creation, cached lookup +# --------------------------------------------------------------------------- + + +def test_get_node_for_target_creates_node_on_first_call(): + needle = Needle() + target = ModuleTarget("a", nn.Linear(2, 2)) + node = needle.get_node_for_target(target) + assert isinstance(node, Node) + assert node.target is target + assert needle.nodes[target] is node + + +def test_get_node_for_target_returns_same_node_on_repeat_call(): + """Re-getting the same target must NOT create a duplicate node — every + stitch involving that target must funnel into a single Node, otherwise + the validation/forward graph fragments.""" + needle = Needle() + target = ModuleTarget("a", nn.Linear(2, 2)) + node1 = needle.get_node_for_target(target) + node2 = needle.get_node_for_target(target) + assert node1 is node2 + assert len(needle.nodes) == 1 + + +# --------------------------------------------------------------------------- +# stitch: adds StitchDescriptor to source.stitches_from and dest.stitches_to +# --------------------------------------------------------------------------- + + +def test_stitch_records_descriptor_on_both_endpoints(): + needle = Needle() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + + needle.stitch(target_a.output("x"), target_b.input("y")) + + node_a = needle.nodes[target_a] + node_b = needle.nodes[target_b] + # Source endpoint: A has one outgoing stitch; B has one incoming stitch. + assert len(node_a.stitches_from) == 1 + assert len(node_a.stitches_to) == 0 + assert len(node_b.stitches_from) == 0 + assert len(node_b.stitches_to) == 1 + # Same StitchDescriptor object on both lists. + assert node_a.stitches_from[0] is node_b.stitches_to[0] + assert isinstance(node_a.stitches_from[0], StitchDescriptor) + + +def test_stitch_returns_self_for_chaining(): + """Bypass factory chains ``.stitch(...).stitch(...)`` — the return type + must be the Needle itself so the second call sees the same graph.""" + needle = Needle() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + + out = needle.stitch(target_a.output("x"), target_b.input("y")) + assert out is needle + + +# --------------------------------------------------------------------------- +# _validate_nodes: contract checks before knot() builds the StitchedModule +# --------------------------------------------------------------------------- + + +def test_validate_raises_when_only_internal_nodes_present(): + """A graph with no External and no Remote target has nothing for the + runtime to feed inputs through — must raise loudly rather than build a + dead StitchedModule.""" + needle = Needle() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + needle.stitch(target_a.output("x"), target_b.input("y")) + + with pytest.raises(OnlyInternalNodesException): + needle._validate_nodes() + + +def test_validate_passes_with_external_plus_dag(): + """Happy path: ExternalTarget + a small linear DAG. 
Must not raise."""
+    needle = Needle()
+    ext = ExternalTarget()
+    target_a = ModuleTarget("a", nn.Linear(2, 2))
+    target_b = ModuleTarget("b", nn.Linear(2, 2))
+
+    needle.stitch(ext.output("init"), target_a.input("entry"))
+    needle.stitch(target_a.output("x"), target_b.input("y"))
+    needle.stitch(target_b.output("z"), ext.input("final"))
+
+    # No raise.
+    needle._validate_nodes()
+
+
+def test_validate_raises_on_input_cycle_among_internal_nodes():
+    """Detect a 2-node cycle A→B→A among internal nodes.
+
+    The validation uses ``_search_loops`` walking ``stitches_to`` (incoming
+    edges); ExternalTarget short-circuits the recursion, so we add an
+    external feed to A so ``_validate_nodes`` doesn't bail out early on the
+    'no external' check.
+    """
+    needle = Needle()
+    ext = ExternalTarget()
+    target_a = ModuleTarget("a", nn.Linear(2, 2))
+    target_b = ModuleTarget("b", nn.Linear(2, 2))
+
+    # Anchor an external feed so we get past the OnlyInternalNodes check.
+    needle.stitch(ext.output("init"), target_a.input("entry"))
+    # Cycle: A -> B -> A.
+    needle.stitch(target_a.output("x"), target_b.input("y"))
+    needle.stitch(target_b.output("p"), target_a.input("q"))
+
+    with pytest.raises(InputsLoopFoundException):
+        needle._validate_nodes()
+
+
+def test_validate_passes_when_external_node_has_self_referential_loop_via_external():
+    """``_search_loops`` short-circuits at ExternalTarget. So a 'loop' that
+    only goes through external (e.g. external→A and A→external) is fine —
+    and indeed required for normal stitching, where external is both the
+    input and output endpoint.
+    """
+    needle = Needle()
+    ext = ExternalTarget()
+    target_a = ModuleTarget("a", nn.Linear(2, 2))
+
+    needle.stitch(ext.output("in"), target_a.input("entry"))
+    needle.stitch(target_a.output("x"), ext.input("out"))
+
+    # Despite the external→A→external pattern, this is the canonical bypass
+    # shape and must validate clean.
+    needle._validate_nodes()
+
+
+# ---------------------------------------------------------------------------
+# Sanity: ModuleTarget.input()/output() builds correctly typed descriptors
+# ---------------------------------------------------------------------------
+
+
+def test_module_target_descriptors_carry_target_and_name():
+    """The ``.input("foo")`` and ``.output("bar")`` builders are what the
+    bypass factory uses to construct stitches. They must propagate the
+    target reference and the name into the resulting descriptor so the
+    runtime can route values correctly."""
+    target = ModuleTarget("a", nn.Linear(2, 2))
+    in_desc = target.input("foo")
+    out_desc = target.output("bar")
+    assert in_desc.target is target
+    assert in_desc.input_name == "foo"
+    assert out_desc.target is target
+    assert out_desc.output_name == "bar"
diff --git a/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py b/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
new file mode 100644
index 00000000000..5fab764b565
--- /dev/null
+++ b/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
@@ -0,0 +1,76 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``_get_all_non_persistent_buffers_set``. + +This helper is what ``bypass_factory_fn`` uses to decide which buffers belong +to ``owned_buffers`` (and therefore get checkpointed) versus which are +recomputed on every forward (RoPE caches, attention masks, etc.). A regression +that drops the module-name prefix would cause the post-resume model to silently +load buffers under wrong names. +""" + +import torch +import torch.nn as nn + +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import ( + _get_all_non_persistent_buffers_set, +) + + +def test_module_with_no_buffers_returns_empty_set(): + assert _get_all_non_persistent_buffers_set(nn.Module()) == set() + + +def test_persistent_buffer_excluded_non_persistent_included(): + m = nn.Module() + m.register_buffer("p", torch.zeros(1), persistent=True) + m.register_buffer("np", torch.zeros(1), persistent=False) + out = _get_all_non_persistent_buffers_set(m) + assert out == {"np"} + + +def test_nested_submodule_paths_are_fully_qualified(): + """Sub-module non-persistent buffers must surface as ``submodule_name.buffer_name`` + so the matching key in ``state_dict()`` and the bypass save/restore code agree.""" + outer = nn.Module() + inner = nn.Module() + inner.register_buffer("nb", torch.zeros(1), persistent=False) + outer.add_module("inner", inner) + out = _get_all_non_persistent_buffers_set(outer) + assert out == {"inner.nb"} + + +def test_top_level_buffer_has_no_leading_dot(): + """Module name is "" at the root — fully-qualified name must not start + with a dot, otherwise it won't match any state_dict key.""" + m = nn.Module() + m.register_buffer("x", torch.zeros(1), persistent=False) + out = _get_all_non_persistent_buffers_set(m) + assert out == {"x"} + assert not any(name.startswith(".") for name in out) + + +def test_mix_of_persistent_and_non_persistent_in_nested_module(): + """The full discrimination: only the nested non-persistent buffer should + appear, with its fully-qualified path.""" + outer = nn.Module() + inner = nn.Module() + inner.register_buffer("keep", torch.zeros(1), persistent=True) # persistent → excluded + inner.register_buffer("rope_cache", torch.zeros(1), persistent=False) + outer.add_module("attn", inner) + outer.register_buffer("global_keep", torch.zeros(1), persistent=True) # → excluded + out = _get_all_non_persistent_buffers_set(outer) + assert out == {"attn.rope_cache"}
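+
+
+def test_non_persistent_set_is_consistent_with_named_buffers():
+    """Illustrative sanity cross-check using only the stock ``nn.Module``
+    buffer API: ``named_buffers()`` yields persistent and non-persistent
+    buffers alike, so the helper's result must be a subset of those names,
+    with exactly the persistent buffers left over."""
+    outer = nn.Module()
+    inner = nn.Module()
+    inner.register_buffer("cache", torch.zeros(1), persistent=False)
+    inner.register_buffer("scale", torch.ones(1), persistent=True)
+    outer.add_module("block", inner)
+
+    non_persistent = _get_all_non_persistent_buffers_set(outer)
+    all_buffer_names = {name for name, _ in outer.named_buffers()}
+
+    assert non_persistent == {"block.cache"}
+    assert non_persistent <= all_buffer_names
+    assert all_buffer_names - non_persistent == {"block.scale"}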