diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py index 3c1749d46ec..58b045bd21c 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py @@ -169,6 +169,19 @@ def uses_autocast() -> bool: """ return True + @staticmethod + def pruning_mixins() -> Dict[str, Any]: + """Return available pruning mixins for bypass distillation. + + Override in subclasses to provide model-specific pruning mixins, e.g. + ``{"kv_heads": KVHeadsPruningMixIn(...), "experts_removal": ExpertRemovalPruningMixIn(...)}``. + + Returns an empty dict by default so that descriptors that do not need + model-specific weight-slicing (e.g. Llama with standard FFN truncation) + can rely on the generic ``create_child_state_dict`` fallback path. + """ + return {} + @staticmethod def get_language_model_config(config): """Get the language model config from a PretrainedConfig. diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py index c8fd86b4bb6..1abecdec0c2 100644 --- a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py @@ -28,6 +28,7 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn # Expert removal is supported for unquantized models (test models). # Production models use MXFP4 quantized MoE with combined tensors @@ -37,7 +38,11 @@ from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size -__all__ = ["GptOssModelDescriptor", "GptOssExpertRemovalLayerDescriptor"] +__all__ = [ + "GptOssExpertRemovalLayerDescriptor", + "GptOssKVHeadsLayerDescriptor", + "GptOssModelDescriptor", +] @ModelDescriptorFactory.register_decorator("gpt_oss") @@ -173,7 +178,29 @@ def pruning_mixins() -> Dict[str, PruningMixIn]: Note: Expert removal works for unquantized models (test models). Production models use MXFP4 quantization which is not yet supported. """ - return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())} + # Single instance shared between the canonical key and the legacy alias + # so resolve_pruning_mixin returns the same object regardless of which + # name a caller uses. + expert_mixin = ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor()) + return { + "experts_removal": expert_mixin, + # Backward-compat alias: this key was "expert_removal" before the + # bypass branch standardised on "experts_removal" (matching the + # NemotronH descriptor). Kept so external scripts that still call + # `resolve_pruning_mixin("expert_removal", GptOssModelDescriptor)` + # continue to work. Remove after a deprecation cycle. 
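+            # Within a single returned mapping both keys alias one object, e.g.:
+            #   mixins = GptOssModelDescriptor.pruning_mixins()
+            #   assert mixins["experts_removal"] is mixins["expert_removal"]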
+ "expert_removal": expert_mixin, + "kv_heads": KVHeadsPruningMixIn(GptOssKVHeadsLayerDescriptor()), + } + + +@dataclass +class GptOssKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) @dataclass diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py index 1c5706d1944..b3f33887367 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py @@ -29,11 +29,16 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same -__all__ = ["NemotronHExpertRemovalLayerDescriptor", "NemotronHModelDescriptor"] +__all__ = [ + "NemotronHExpertRemovalLayerDescriptor", + "NemotronHKVHeadsLayerDescriptor", + "NemotronHModelDescriptor", +] def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: @@ -51,6 +56,15 @@ def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: return matches +@dataclass +class NemotronHKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "mixer.o_proj" + attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + @dataclass class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): target_name: str = "mixer.gate" @@ -251,4 +265,5 @@ def build_attention_predicates() -> Dict[str, re.Pattern]: def pruning_mixins() -> Dict[str, PruningMixIn]: return { "experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(NemotronHKVHeadsLayerDescriptor()), } diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py index a1e326f2357..0c677f67542 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py @@ -29,11 +29,16 @@ FFNIntermediateLayerDescriptor, FFNIntermediatePruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same -__all__ = ["NemotronHV2FFNIntermediateLayerDescriptor", "NemotronHV2ModelDescriptor"] +__all__ = [ + "NemotronHV2FFNIntermediateLayerDescriptor", + "NemotronHV2KVHeadsLayerDescriptor", + "NemotronHV2ModelDescriptor", +] def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: @@ -69,6 +74,15 @@ class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"]) +@dataclass +class 
NemotronHV2KVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "mixer.o_proj" + attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + @ModelDescriptorFactory.register_decorator("nemotron_h_v2") class NemotronHV2ModelDescriptor(ModelDescriptor): _DECODER_LAYER_CLS: Type[nn.Module] = None @@ -251,5 +265,6 @@ def pruning_mixins() -> Dict[str, PruningMixIn]: "ffn_intermediate": FFNIntermediatePruningMixIn( NemotronHV2FFNIntermediateLayerDescriptor() ), + "kv_heads": KVHeadsPruningMixIn(NemotronHV2KVHeadsLayerDescriptor()), # TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated } diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py index aeedd419923..a0f9c95c6ce 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py @@ -26,9 +26,13 @@ ) from ....block_config import BlockConfig -from ....pruning.expert_removal_pruning_mixin import ExpertRemovalLayerDescriptor +from ....pruning.expert_removal_pruning_mixin import ( + ExpertRemovalLayerDescriptor, + ExpertRemovalPruningMixIn, +) from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor -from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn +from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size @@ -56,6 +60,13 @@ def get_language_model_config(config): """Qwen3-VL has nested text_config for language model parameters.""" return config.text_config if hasattr(config, "text_config") else config + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + return { + "experts_removal": ExpertRemovalPruningMixIn(Qwen3VLExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(Qwen3VLKVHeadsLayerDescriptor()), + } + @staticmethod def decoder_layer_cls(): return Qwen3VLMoeTextDecoderLayer diff --git a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py index 740d1fada3c..e46f615f6f6 100644 --- a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py +++ b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py @@ -24,7 +24,12 @@ ) from .pruning_mixin import LayerDescriptor, PruningMixIn -from .pruning_utils import GQAInitMode, _init_attention_biases, _init_attention_weights +from .pruning_utils import ( + GQAInitMode, + _init_attention_biases, + _init_attention_weights, + _lm_head_dim, +) __all__ = [ "KVHeadsLayerDescriptor", @@ -74,7 +79,7 @@ def prune_single_layer( f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names ] - head_size = new_config.head_dim + head_size = _lm_head_dim(new_config) for part in ["weight", "bias"]: attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]] q_key, k_key, v_key, o_key = attn_keys diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py index c600e119cfa..3b8e94347cb 100644 --- a/modelopt/torch/puzzletron/pruning/pruning_utils.py +++ 
b/modelopt/torch/puzzletron/pruning/pruning_utils.py
@@ -52,6 +52,7 @@ class MlpInitMode(Enum):
     PruneByActivationsLog = "PruneByActivationsLog"
     ExpertRemoval = "ExpertRemoval"
     ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN"
+    MoEChannelPruning = "MoEChannelPruning"
 
 
 class LinearInitMode(Enum):
@@ -66,6 +67,32 @@ class HiddenSizeInitMode(Enum):
     CopyAsIs = "CopyAsIs"
 
 
+def _lm_attrs(config):
+    """Return the language-model sub-config for VL configs, else the config itself.
+
+    VL configs nest language-model fields like ``num_attention_heads``, ``head_dim``,
+    and ``hidden_size`` under a sub-config. The attribute name varies by family —
+    ``text_config`` (Qwen3-VL, Llava, Idefics), ``language_config``, and ``llm_config``
+    (InternVL and friends) all occur in the wild. Probe each before falling back to
+    the raw config.
+    """
+    for attr in ("text_config", "language_config", "llm_config"):
+        sub = getattr(config, attr, None)
+        if sub is not None:
+            return sub
+    return config
+
+
+def _lm_head_dim(config) -> int:
+    """Return the attention head size: explicit ``head_dim`` when the language-model
+    config defines one, else ``hidden_size // num_attention_heads``."""
+    lm_config = _lm_attrs(config)
+    head_dim = getattr(lm_config, "head_dim", None)
+    if head_dim is not None:
+        return head_dim
+    return lm_config.hidden_size // lm_config.num_attention_heads
+
+
 def resolve_pruning_mixin(
     pruning_mixin, descriptor: Type[ModelDescriptor]
 ) -> PruningMixIn | List[PruningMixIn]:
@@ -224,10 +249,13 @@ def _init_attention_weights(
     head_size,
     mlp_init_config,
 ):
-    assert new_config.num_attention_heads == original_config.num_attention_heads, (
-        f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})"
+    new_lm = _lm_attrs(new_config)
+    orig_lm = _lm_attrs(original_config)
+    assert new_lm.num_attention_heads == orig_lm.num_attention_heads, (
+        f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})"
     )
-    num_q_heads = new_config.num_attention_heads
+    num_q_heads = new_lm.num_attention_heads
+    # block_configs lives on the outer puzzletron-converted config, not on text_config.
     num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads
     orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads
 
@@ -372,17 +400,29 @@ def _init_attention_biases(
     head_size,
     mlp_init_config,
 ):
-    assert new_config.num_attention_heads == original_config.num_attention_heads, (
-        f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})"
+    new_lm = _lm_attrs(new_config)
+    orig_lm = _lm_attrs(original_config)
+    assert new_lm.num_attention_heads == orig_lm.num_attention_heads, (
+        f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})"
     )
-    num_q_heads = new_config.num_attention_heads
+    num_q_heads = new_lm.num_attention_heads
+    # block_configs lives on the outer puzzletron-converted config, not on text_config.
    num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads
    orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads
 
     n_heads_in_group = num_q_heads // num_kv_heads
     orig_n_heads_in_group = num_q_heads // orig_num_kv_heads
 
-    o_proj_bias = new_config.o_proj_bias
-    attention_bias = new_config.attention_bias
+    # Some HF native configs (e.g. GptOssConfig) don't expose o_proj_bias / attention_bias as
+    # top-level attributes the way puzzletron's DeciLM-style configs do. Fall back to probing
+    # the new state dict for the actual bias keys when the attribute is missing.
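+    # For a GPT-OSS layer, for example, that probe amounts to checking whether
+    # "model.layers.{layer_idx}.self_attn.o_proj.bias" survived into new_state_dict.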
+ # KVHeadsPruningMixIn only calls this helper after filtering to keys present in + # new_state_dict, so the probe mirrors the caller's already-selected bias tensors. + o_proj_bias = getattr(new_config, "o_proj_bias", None) + if o_proj_bias is None: + o_proj_bias = o_key in new_state_dict + attention_bias = getattr(new_config, "attention_bias", None) + if attention_bias is None: + attention_bias = q_key in new_state_dict # If no biases if not (o_proj_bias or attention_bias): @@ -438,8 +478,8 @@ def _init_attention_biases( assert not is_original_mha, ( "Degrouping can only be done on original models that are GQA themselves." ) - n_groups = new_config.num_attention_heads // n_heads_in_group - orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group + n_groups = new_lm.num_attention_heads // n_heads_in_group + orig_n_groups = orig_lm.num_attention_heads // orig_n_heads_in_group assert n_groups % orig_n_groups == 0, f"{n_groups=} must be a divisor of {orig_n_groups=}" n_repeats = n_groups // orig_n_groups if n_repeats > 1: diff --git a/modelopt/torch/puzzletron/sewing_kit/passage.py b/modelopt/torch/puzzletron/sewing_kit/passage.py index d8fa1f51cf9..c77b9dd41cd 100644 --- a/modelopt/torch/puzzletron/sewing_kit/passage.py +++ b/modelopt/torch/puzzletron/sewing_kit/passage.py @@ -45,6 +45,7 @@ "PassageOutput", "Predicate", "always_false_predicate", + "always_true_predicate", "Passage", "patch_module", ] diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py index 3db63f60013..106b0b3e4c3 100644 --- a/modelopt/torch/puzzletron/sewing_kit/utils.py +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -16,6 +16,7 @@ from __future__ import annotations import inspect +import operator from contextlib import contextmanager from typing import ( TYPE_CHECKING, @@ -451,3 +452,95 @@ def _get_group_kwarg_if_necessary() -> dict: torch.distributed.distributed_c10d._object_to_tensor ).parameters.keys() return dict(group=None) if "group" in arg_names else dict() + + +# ────────────────────────────────────────────────────────────────────────────── +# Loss functions for bypass distillation (blockwise local knowledge distillation) +# ────────────────────────────────────────────────────────────────────────────── + +# `normalized_mse_loss` already lives in tools.kd_model — re-export it here so +# bypass-distillation imports stay co-located with the per-vector / per-batch +# variants below, without duplicating the implementation. The `as +# normalized_mse_loss` form is PEP 484's explicit re-export (mypy treats +# `from X import Y` as a private import otherwise). +from modelopt.torch.puzzletron.tools.kd_model import ( # noqa: E402 + normalized_mse_loss as normalized_mse_loss, +) + + +def vectorwise_normalized_mse_loss( + input: torch.Tensor, + target: torch.Tensor, + epsilon: float = 1e-6, +) -> torch.Tensor: + """Like normalized_mse_loss, but normalization is done per-vector (last dim), then averaged.""" + return batched_normalized_mse_loss(input, target, epsilon, batch_dims=range(input.ndim - 1)) + + +def batched_normalized_mse_loss( + input: torch.Tensor, + target: torch.Tensor, + epsilon: float = 1e-6, + batch_dims: Sequence[int] = (0,), +) -> torch.Tensor: + """Per-batch-element relative-L2 loss. + + For each batch element, computes ``||input - target||^2 / (||target||^2 + eps)`` + over the non-batch dims, then averages across batch elements. 
The additive + ``epsilon`` in the denominator handles all-zero target slices without a hard + clamp and makes the loss scale-invariant when ``||target||^2 >> eps``. + """ + input_shape = tuple(input.shape) + target_shape = tuple(target.shape) + + if epsilon <= 0: + raise ValueError(f"epsilon must be strictly positive, got {epsilon!r}") + + try: + raw_batch_dims = tuple(operator.index(dim) for dim in batch_dims) + except TypeError as exc: + raise ValueError( + f"batch_dims must be an iterable of integer dimensions; got {batch_dims!r} " + f"for input shape {input_shape} and target shape {target_shape}" + ) from exc + + resolved_batch_dims = [] + for dim in raw_batch_dims: + if dim < -input.ndim or dim >= input.ndim: + raise ValueError( + f"batch_dims contains invalid dimension {dim} for input.ndim={input.ndim}; " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={raw_batch_dims}, norm_dims=None" + ) + resolved_batch_dims.append(dim % input.ndim) + + if len(set(resolved_batch_dims)) != len(resolved_batch_dims): + raise ValueError( + f"batch_dims contains duplicate dimensions after normalization; " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={tuple(resolved_batch_dims)}, norm_dims=None" + ) + + norm_dims = tuple(d for d in range(input.ndim) if d not in set(resolved_batch_dims)) + + if input.ndim != target.ndim: + raise ValueError( + f"input and target must have the same number of dimensions; " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={tuple(resolved_batch_dims)}, norm_dims={norm_dims}" + ) + if input_shape != target_shape: + mismatched_dims = tuple( + dim + for dim, (input_size, target_size) in enumerate(zip(input_shape, target_shape)) + if input_size != target_size + ) + raise ValueError( + f"input and target shapes must match exactly; mismatched_dims={mismatched_dims}, " + f"input shape={input_shape}, target shape={target_shape}, " + f"batch_dims={tuple(resolved_batch_dims)}, norm_dims={norm_dims}" + ) + + num = ((input - target) ** 2).sum(dim=norm_dims) + den = (target**2).sum(dim=norm_dims) + epsilon + return (num / den).mean() diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py index b242c7d48ac..e041d884b0d 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -22,6 +22,8 @@ import os import re import time +from collections import ChainMap +from collections.abc import Iterator, MutableMapping from copy import deepcopy from functools import partial from pathlib import Path @@ -52,6 +54,49 @@ default_ignore_fn: IgnoreFn = lambda _: False +class _PerLayerKeysView(MutableMapping[str, str]): + def __init__(self, base: dict[str, str]) -> None: + self._base = base + self._overrides: dict[str, str] = {} + self._removed: dict[str, str] = {} + + def __getitem__(self, key: str) -> str: + if key in self._removed: + raise KeyError(key) + if key in self._overrides: + return self._overrides[key] + return self._base[key] + + def __setitem__(self, key: str, value: str) -> None: + self._removed.pop(key, None) + self._overrides[key] = value + + def __delitem__(self, key: str) -> None: + if key in self._removed: + raise KeyError(key) + if key in self._overrides: + self._removed[key] = self._overrides.pop(key) + elif key in self._base: + self._removed[key] = self._base[key] + else: + raise KeyError(key) + + def __iter__(self) -> 
Iterator[str]: + yield from self._overrides.keys() + for key in self._base: + if key not in self._overrides and key not in self._removed: + yield key + + def __len__(self) -> int: + return sum(1 for _ in self) + + def __contains__(self, key: object) -> bool: + return key not in self._removed and (key in self._overrides or key in self._base) + + def removed_items(self) -> dict[str, str]: + return dict(self._removed) + + class Printer: @staticmethod def print(s: str) -> None: @@ -83,27 +128,43 @@ def _process_single_layer( keys_to_remove = {} layer_out_state_dict = {} - # Delegate to pruning_mixin if available + # Delegate to pruning_mixin if available (supports a single mixin or a list of mixins). + # Mixins run sequentially. Each mixin sees the state dict produced by earlier mixins, + # which lets independent pruning methods compose on the same tensor (for example one + # pruning FFN channels and another pruning hidden-size dimensions). if pruning_mixin is not None: - _layer_out = pruning_mixin.prune_single_layer( - layer_idx=layer_idx, - parent_state_dict=parent_state_dict, - new_state_dict=new_state_dict, - original_config=original_config, - new_config=new_config, - gqa_init_mode=gqa_init_mode, - mlp_init_mode=mlp_init_mode, - mlp_init_config=mlp_init_config, - linear_init_mode=linear_init_mode, - ignored_keys=ignored_keys, - keys=keys, - is_original_mha=is_original_mha, - head_size=head_size, - hidden_size=hidden_size, - keys_to_remove=keys_to_remove, - ) - layer_out_state_dict.update(_layer_out) - return layer_out_state_dict, keys_to_remove + _mixins = pruning_mixin if isinstance(pruning_mixin, list) else [pruning_mixin] + merged_keys_to_remove = {} + parent_layer_updates = {} + new_layer_updates = {} + current_parent_state_dict = ChainMap(parent_layer_updates, parent_state_dict) + current_new_state_dict = ChainMap(new_layer_updates, new_state_dict) + current_keys = _PerLayerKeysView(keys) + for _mixin in _mixins: + mixin_keys_to_remove = {} + _layer_out = _mixin.prune_single_layer( + layer_idx=layer_idx, + parent_state_dict=current_parent_state_dict, + new_state_dict=current_new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=current_keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + keys_to_remove=mixin_keys_to_remove, + ) + layer_out_state_dict.update(_layer_out) + parent_layer_updates.update(_layer_out) + new_layer_updates.update(_layer_out) + merged_keys_to_remove.update(current_keys.removed_items()) + merged_keys_to_remove.update(mixin_keys_to_remove) + return layer_out_state_dict, merged_keys_to_remove # Legacy inline processing (fallback when no pruning_mixin) @@ -791,7 +852,10 @@ def update_model_config( def override(item, item_overrides): if item_overrides is None: - return item_overrides + # Hydra/OmegaConf ``null`` means "leave this field unchanged" in + # model_config_overrides. This lets compact overrides update only one + # sibling field without clearing the rest of the dataclass. 
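+            # For example, a hypothetical YAML override
+            #   attention: {num_key_value_heads: 4, replace_with_linear: null}
+            # updates num_key_value_heads and leaves replace_with_linear as-is
+            # instead of wiping it.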
+ return item if dataclasses.is_dataclass(item): assert isinstance(item_overrides, dict) return dataclass_override(item, item_overrides) diff --git a/modelopt/torch/puzzletron/tools/hydra_utils.py b/modelopt/torch/puzzletron/tools/hydra_utils.py index c30be4efde8..c3e282d5e2b 100644 --- a/modelopt/torch/puzzletron/tools/hydra_utils.py +++ b/modelopt/torch/puzzletron/tools/hydra_utils.py @@ -32,16 +32,57 @@ ] -def warmup_steps(tokens: int, block: int, mbs: int, pct: float = 0.05) -> int: +def warmup_steps(tokens: int, block: int, mbs: int, grad_accum: int, pct: float) -> int: """ - Calculate warmup steps based on total tokens, block size, micro batch size, and warmup percentage. - Used as a resolver in hydra configs. + Calculate warmup steps in optimizer-step units. + + total_iters = tokens / (block * mbs) gives micro-batches; one optimizer step + consumes ``grad_accum`` micro-batches, so total optimizer steps = total_iters + / grad_accum. The LR scheduler in ``_get_lr`` is indexed by ``step_num`` + (optimizer steps), so warmup must be in the same units. """ - steps = (int(tokens) // int(block)) // int(mbs) + try: + tokens = int(tokens) + block = int(block) + mbs = int(mbs) + grad_accum = int(grad_accum) + except (TypeError, ValueError) as exc: + raise ValueError( + "tokens, block, mbs, and grad_accum must be integers or castable to int; " + f"got tokens={tokens!r}, block={block!r}, mbs={mbs!r}, grad_accum={grad_accum!r}" + ) from exc + + try: + pct = float(pct) + except (TypeError, ValueError) as exc: + raise ValueError(f"pct must be a float or castable to float, got {pct!r}") from exc + + if tokens < 0: + raise ValueError(f"tokens must be >= 0, got {tokens!r}") + if block <= 0: + raise ValueError(f"block must be > 0, got {block!r}") + if mbs <= 0: + raise ValueError(f"mbs must be > 0, got {mbs!r}") + if grad_accum < 1: + raise ValueError(f"grad_accum must be >= 1, got {grad_accum!r}") + if not 0.0 <= pct <= 1.0: + raise ValueError(f"pct must be between 0.0 and 1.0 inclusive, got {pct!r}") + + iters = (tokens // block) // mbs + steps = max(1, iters // grad_accum) w = pct * steps return max(1, round(w)) +def _warmup_steps_resolver(*args): + if len(args) != 5: + raise ValueError( + "warmup_steps resolver expects exactly 5 arguments: " + "(tokens, block, micro_batch_size, grad_accumulation_steps, warmup_ratio)" + ) + return warmup_steps(*args) + + def register_hydra_resolvers(): OmegaConf.register_new_resolver("to_path", lambda x: Path(x)) OmegaConf.register_new_resolver( @@ -50,7 +91,7 @@ def register_hydra_resolvers(): OmegaConf.register_new_resolver( "timedelta_minutes", lambda x: datetime.timedelta(minutes=x) if x is not None else None ) - OmegaConf.register_new_resolver("warmup_steps", lambda t, b, m, p: warmup_steps(t, b, m, p)) + OmegaConf.register_new_resolver("warmup_steps", _warmup_steps_resolver) OmegaConf.register_new_resolver("get_object", lambda x: get_object(x)) diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py index f4046531491..3d8b94c82cc 100644 --- a/modelopt/torch/puzzletron/utils/data/dataloaders.py +++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py @@ -31,7 +31,7 @@ from ...tools.logger import mprint from .dataset import ConstantLengthDataset -__all__ = ["create_validation_dataloader", "create_padded_tensor"] +__all__ = ["create_train_dataloader", "create_validation_dataloader", "create_padded_tensor"] def collate_none_fn( @@ -73,6 +73,74 @@ def load_streaming_fn( return dataset +def 
create_train_dataloader( + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset_path: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name: str = "train", + keep_in_memory: bool = False, + shuffle_seed: int | None = None, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + num_workers: int = 0, +) -> DataLoader: + """Create an infinite training DataLoader over ConstantLengthDataset.""" + # ConstantLengthDataset.__iter__ does not consult torch.utils.data.get_worker_info() + # to shard work across DataLoader workers, so num_workers > 0 would have every + # worker iterate the full dataset and emit duplicate samples. Reject explicitly + # until ConstantLengthDataset gains worker-aware iteration; the guard can then + # be removed. + if num_workers > 0: + raise ValueError( + f"create_train_dataloader: num_workers={num_workers} is not supported " + f"because ConstantLengthDataset.__iter__ does not shard via " + f"torch.utils.data.get_worker_info(). Use num_workers=0 (the default) " + f"or add worker-aware sharding to ConstantLengthDataset.__iter__." + ) + + if isinstance(dataset_path, str): + dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory) + else: + dataset = dataset_path + + train_data = dataset[dataset_name] + if shuffle_seed is not None: + # `keep_in_memory` is only valid on map-style HF Datasets; streaming + # `IterableDataset.shuffle()` only accepts `seed` (and an optional + # `buffer_size`). Branch on the dataset type so streaming users + # (`load_from_disk: false`) don't crash on this call. + if isinstance(train_data, datasets.IterableDataset): + train_data = train_data.shuffle(seed=shuffle_seed) + else: + train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=keep_in_memory) + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + infinite=True, + seq_length=block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + ) + + return DataLoader( + train_dataset, + batch_size=micro_batch_size, + pin_memory=True, + num_workers=num_workers, + ) + + def create_validation_dataloader( accelerator: Accelerator | None, seed: int, diff --git a/modelopt/torch/puzzletron/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py index f88e44a234b..01422e5a4b7 100644 --- a/modelopt/torch/puzzletron/utils/data/dataset.py +++ b/modelopt/torch/puzzletron/utils/data/dataset.py @@ -35,6 +35,14 @@ CODEGEN_FIM_TOKENS = ["", "<|endoftext|>", ""] +def _message_content_to_text(content) -> str: + if isinstance(content, str): + return content + if isinstance(content, dict) and "text" in content: + return str(content["text"]) + return str(content) + + class ConstantLengthDataset(IterableDataset): """Iterable dataset that returns constant length chunks of tokens from stream of text files. @@ -128,9 +136,18 @@ def __iter__(self) -> dict[str, torch.Tensor]: and {"content", "role"}.issubset(sample[0]) ): if len(sample) > 1: - sample = self.tokenizer.apply_chat_template(sample, tokenize=False) + if getattr(self.tokenizer, "chat_template", None) is not None: + sample = self.tokenizer.apply_chat_template( + sample, tokenize=False + ) + else: + # Base models have no chat template — concatenate message + # contents separated by newlines as plain text. 
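+                            # For example, a two-message sample such as
+                            #   [{"role": "user", "content": "hi"},
+                            #    {"role": "assistant", "content": {"text": "ok"}}]
+                            # is flattened to the plain string "hi\nok".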
+ sample = "\n".join( + _message_content_to_text(m["content"]) for m in sample + ) else: - sample = sample[0]["content"] + sample = _message_content_to_text(sample[0]["content"]) else: sample = sample[self.tokens_field] sample = sample[: self.max_sample_length] diff --git a/modelopt/torch/puzzletron/utils/parsing.py b/modelopt/torch/puzzletron/utils/parsing.py index 149563b4321..69e21e0599b 100644 --- a/modelopt/torch/puzzletron/utils/parsing.py +++ b/modelopt/torch/puzzletron/utils/parsing.py @@ -24,6 +24,7 @@ # mypy: ignore-errors import json +import math from pathlib import Path from typing import Any @@ -116,7 +117,7 @@ def format_block_configs(config) -> str: ╭─────────────────────── Model Architecture ────────────────────────╮ │ Layer 1 │ Attention: no_op │ FFN: mult = 4.95 │ │ Layer 2 │ Attention: 4 heads in group │ FFN: mult = 4.95 │ - │ Layer 3 │ Attention: 4 heads in group │ FFN: no_op │ + │ Layer 3 │ Attention: no_op │ FFN: no_op │ ╰────────────────────────────────────────────────────────────────────╯ """ if not hasattr(config, "block_configs") or not config.block_configs: @@ -158,7 +159,7 @@ def _format_attention_config(attention_config) -> str: num_kv_heads = attention_config.num_key_value_heads if num_kv_heads is not None: - return f"{num_kv_heads} kv heads" + return f"🐙 {num_kv_heads} kv heads" if attention_config.replace_with_linear: return "linear replacement" @@ -192,12 +193,12 @@ def _format_ffn_config(ffn_config) -> str: ffn_intermediate = ffn_config.intermediate_size if ffn_intermediate is not None: - return f"ffn_intermediate = {ffn_intermediate}" + return f"🧱 ffn_dim = {ffn_intermediate}" # Check for MoE configuration moe_config = ffn_config.moe if moe_config: - return "MoE" + return "🔀 MoE" if ffn_config.sparsify: return "sparse" @@ -287,7 +288,7 @@ def _add_config_section(cfg: DictConfig, section_name: str = "", indent: int = 0 # Regular key-value pair indent_str = " " * (indent + 1) value_str = _format_value(value).replace(" " * 0, "").strip() - line = f"│ {indent_str} {key}: {value_str}" + line = f"│ {indent_str} • {key}: {value_str}" # Pad to box width if len(line) >= box_width - 1: # Truncate long lines @@ -310,6 +311,8 @@ def format_stitched_losses( losses_dict: dict[str, float], best_steps_dict: dict[str, int] | None = None, best_values_dict: dict[str, float] | None = None, + initial_values_dict: dict[str, float] | None = None, + not_trainable_names: set[str] | None = None, step_number: int | None = None, title: str = "Stitched Module Losses", ) -> str: @@ -320,6 +323,9 @@ def format_stitched_losses( losses_dict: Dictionary with block names as keys and current loss values as floats best_steps_dict: Optional dictionary with block names as keys and best step numbers as values best_values_dict: Optional dictionary with block names as keys and best loss values as floats + initial_values_dict: Optional dictionary with block names as keys and initial loss values + (from the first log chunk) as floats. Used to render the "Δ from initial" column as + a per-block training-progress signal. 
step_number: Optional current step number to include in summary title: Title to display at the top of the formatted output @@ -328,23 +334,39 @@ def format_stitched_losses( Example output: ╭─────────────────── Stitched Module Losses ──────────────────╮ - │ Block │ Loss Value │ Best Step │ Best Value │ Change from avg │ - │───────┼────────────┼───────────┼────────────┼──────────────────│ - │ 00 │ 6.21e-03 │ Step 5 │ 5.95e-03 │ ↑ +2.6e-04 │ - │ 01 │ 5.14e-04 │ Step 12 │ 5.14e-04 │ ↓ -1.2e-04 │ - │ 02 │ 9.84e-05 │ Step 15 │ 9.84e-05 │ ↓ -3.1e-04 │ + │ Block │ Loss Value │ Δ from initial │ Best Value │ Best Step │ + │───────┼────────────┼──────────────────┼────────────┼───────────│ + │ 00 │ 6.21e-03 │ ↓ -3.2e-04 (-5%) │ 5.95e-03 │ Step 5 │ + │ 01 │ 5.14e-04 │ ↓ -1.8e-03 (-78%)│ 5.14e-04 │ Step 12 │ + │ 02 │ 9.84e-05 │ ↓ -4.1e-04 (-81%)│ 9.84e-05 │ Step 15 │ ╰──────────────────────────────────────────────────────────────╯ """ if not losses_dict: + if not_trainable_names: + return ( + "No trainable losses found; " + f"skipped {len(not_trainable_names)} non-trainable blocks" + ) return "❌ No losses found" + if best_steps_dict: + best_steps_dict = {k: v for k, v in best_steps_dict.items() if k in losses_dict} + if best_values_dict: + best_values_dict = {k: v for k, v in best_values_dict.items() if k in losses_dict} + if initial_values_dict: + initial_values_dict = {k: v for k, v in initial_values_dict.items() if k in losses_dict} + lines = [] # Calculate statistics loss_values = list(losses_dict.values()) - max_loss = max(loss_values) - min_loss = min(loss_values) - avg_loss = sum(loss_values) / len(loss_values) + finite_loss_values = [value for value in loss_values if math.isfinite(value)] + if finite_loss_values: + max_loss = max(finite_loss_values) + min_loss = min(finite_loss_values) + avg_loss = sum(finite_loss_values) / len(finite_loss_values) + else: + max_loss = min_loss = avg_loss = float("nan") # Calculate box width for new layout (removed Bar column) box_width = 74 @@ -356,10 +378,10 @@ def format_stitched_losses( f"│{' ' * title_padding}{title}{' ' * (box_width - 2 - title_padding - len(title))}│" ) separator = ( - f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Best Step':<10} │ " - f"{'Best Value':<12} │ {'Change from avg':<18} │" + f"│ {'Block':<5} │ {'Loss Value':<12} │ {'Δ from initial':<18} │ " + f"{'Best Value':<12} │ {'Best Step':<10} │" ) - divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 12}┼{'─' * 14}┼{'─' * 20}│" + divider = f"│{'─' * 7}┼{'─' * 14}┼{'─' * 20}┼{'─' * 14}┼{'─' * 12}│" lines.extend([header, title_line, separator, divider]) @@ -382,26 +404,36 @@ def format_stitched_losses( best_value = loss_value # Assume current is best if no history best_value_str = f"{best_value:.2e}" - # Calculate change from average - change_from_avg = loss_value - avg_loss - if abs(change_from_avg) > 1e-8: # Only show if meaningful - change_str = f"{abs(change_from_avg):.1e}" - if change_from_avg > 0: - # Current is above average (worse for loss) - change_display = f"↑ +{change_str}" + # Calculate change from initial: current loss minus the block's loss in the + # first log chunk we saw. Per-block training-progress signal — answers "is + # bypass distillation actually reducing this block's loss?" and stays + # apples-to-apples even when blocks have very different intrinsic loss scales. 
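+            # Worked example: initial=2.0e-03, current=1.5e-03 gives delta=-5.0e-04
+            # and pct=-25, rendered as "↓ -5.0e-04 (-25%)".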
+            if initial_values_dict and block_name in initial_values_dict:
+                initial_value = initial_values_dict[block_name]
+                if not math.isfinite(loss_value) or not math.isfinite(initial_value):
+                    change_display = "non-finite"
                 else:
-                    # Current is below average (better for loss)
-                    change_display = f"↓ -{change_str}"
+                    # Both values are finite here; the branch above caught non-finite ones.
+                    delta = loss_value - initial_value
+                    if abs(delta) > 1e-8:
+                        pct = (delta / initial_value * 100.0) if initial_value != 0.0 else 0.0
+                        # Clamp percentage display to keep the cell within the 18-char column
+                        # even on pathological divergence (e.g. a block whose loss 10x'd).
+                        pct_clamped = max(-999.0, min(999.0, pct))
+                        arrow = "↓" if delta < 0 else "↑"
+                        sign = "-" if delta < 0 else "+"
+                        change_display = f"{arrow} {sign}{abs(delta):.1e} ({pct_clamped:+.0f}%)"
+                    else:
+                        change_display = "↔ 0.0e+00"
             else:
-                # At average value
-                change_display = "↔ 0.0e+00"
+                # No baseline supplied (callers may omit initial_values_dict).
+                change_display = "    --"
 
         # Format the line
         block_display = block_name.replace("block_", "").zfill(2)
         line = (
-            f"│ {block_display:<5} │ {loss_str:<12} │ {best_step_str:<10} │ "
-            f"{best_value_str:<12} │ {change_display:<18} │"
+            f"│ {block_display:<5} │ {loss_str:<12} │ {change_display:<18} │ "
+            f"{best_value_str:<12} │ {best_step_str:<10} │"
         )
         lines.append(line)
 
@@ -413,6 +445,8 @@ def format_stitched_losses(
     if step_number is not None:
         summary_parts.append(f"Step {step_number}")
     summary_parts.extend([f"Avg={avg_loss:.2e}", f"Max={max_loss:.2e}", f"Min={min_loss:.2e}"])
+    if not_trainable_names:
+        summary_parts.append(f"Skipped={len(not_trainable_names)}")
     summary_text = ", ".join(summary_parts)
 
     summary = f"│ Summary: {summary_text}"
@@ -436,7 +470,9 @@ def format_stitched_losses(
         best_step_values = []
         for block_name, best_step in best_steps_dict.items():
             if best_step == modal_best_step and block_name in best_values_dict:
-                best_step_values.append(best_values_dict[block_name])
+                best_value = best_values_dict[block_name]
+                if math.isfinite(best_value):
+                    best_step_values.append(best_value)
 
         if best_step_values:
             best_step_avg = sum(best_step_values) / len(best_step_values)
diff --git a/tests/unit/torch/puzzletron/test_bypass_dataloaders.py b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py
new file mode 100644
index 00000000000..1bcea14633e
--- /dev/null
+++ b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py
@@ -0,0 +1,328 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for bypass-distillation dataloader utilities.
+ +Covers the pure-Python branches of ``utils/data/dataloaders.py`` that don't +need a real tokenizer / GPU / distributed init: the validation-split +auto-detect rules, the ``num_workers`` guard rail, the dataset-loader +delegators, the ``Printer`` fake accelerator, and the small numeric helpers +(``create_padded_tensor``, ``realize_dataset_in_memory``, ``collate_none_fn``). +""" + +import datasets +import pytest +import torch +from datasets import Dataset, DatasetDict + +import modelopt.torch.puzzletron.utils.data.dataloaders as dl +from modelopt.torch.puzzletron.utils.data.dataloaders import ( + Printer, + collate_fn_with_none_support, + collate_none_fn, + create_padded_tensor, + create_train_dataloader, + create_validation_dataloader, + load_from_disk_fn, + load_streaming_fn, + realize_dataset_in_memory, +) +from modelopt.torch.puzzletron.utils.data.dataset import ConstantLengthDataset + +# --------------------------------------------------------------------------- +# realize_dataset_in_memory: pure list materialisation with optional cap +# --------------------------------------------------------------------------- + + +def test_realize_dataset_in_memory_full(): + items = [{"a": 1}, {"a": 2}, {"a": 3}] + out = realize_dataset_in_memory(iter(items), eval_samples=None) + assert out == items + + +def test_realize_dataset_in_memory_capped(): + items = [{"a": 1}, {"a": 2}, {"a": 3}] + out = realize_dataset_in_memory(iter(items), eval_samples=2) + assert out == [{"a": 1}, {"a": 2}] + + +# --------------------------------------------------------------------------- +# create_padded_tensor: identity, 1D pad, 2D pad with non-zero pad value +# --------------------------------------------------------------------------- + + +def test_create_padded_tensor_identity(): + t = torch.arange(6, dtype=torch.float32).reshape(2, 3) + out = create_padded_tensor(t, desired_shape=(2, 3)) + assert out is t # short-circuit, no copy + + +def test_create_padded_tensor_pads_1d_with_default_zero(): + t = torch.tensor([1, 2, 3], dtype=torch.int32) + out = create_padded_tensor(t, desired_shape=(5,)) + assert out.tolist() == [1, 2, 3, 0, 0] + assert out.dtype == torch.int32 + + +def test_create_padded_tensor_pads_2d_with_custom_value(): + t = torch.tensor([[1.0, 2.0]]) + out = create_padded_tensor(t, desired_shape=(2, 3), padding_value=-100.0) + assert out.tolist() == [[1.0, 2.0, -100.0], [-100.0, -100.0, -100.0]] + + +# --------------------------------------------------------------------------- +# Collate helpers: None-aware default collator +# --------------------------------------------------------------------------- + + +def test_collate_none_fn_returns_none(): + assert collate_none_fn([None, None]) is None + assert collate_none_fn([1, 2, 3]) is None # unconditional + + +def test_collate_fn_with_none_support_passes_none_through(): + """A label tensor of None should not be coerced to ``[None, None]`` — the + bypass val loop expects a single ``None`` so it can short-circuit loss + computation. 
This pins the ``type(None) -> collate_none_fn`` registration.""" + batch = [{"x": torch.tensor([1.0]), "y": None}, {"x": torch.tensor([2.0]), "y": None}] + out = collate_fn_with_none_support(batch) + assert out["y"] is None + assert torch.equal(out["x"], torch.tensor([[1.0], [2.0]])) + + +# --------------------------------------------------------------------------- +# Printer: degenerate "main process" stand-in for Accelerator +# --------------------------------------------------------------------------- + + +def test_printer_attributes_match_main_process_contract(): + assert Printer.is_main_process is True + assert Printer.process_index is None + Printer.print("hello world") # must not raise + + +# --------------------------------------------------------------------------- +# load_from_disk_fn / load_streaming_fn: thin wrappers around datasets.* +# --------------------------------------------------------------------------- + + +def test_load_from_disk_fn_delegates_to_datasets(monkeypatch): + captured = {} + + def fake_load_from_disk(path, keep_in_memory=False): + captured["path"] = path + captured["keep_in_memory"] = keep_in_memory + return "sentinel" + + monkeypatch.setattr(datasets, "load_from_disk", fake_load_from_disk) + out = load_from_disk_fn("/some/path", content_field="conversation", keep_in_memory=True) + assert out == "sentinel" + assert captured == {"path": "/some/path", "keep_in_memory": True} + + +def test_load_streaming_fn_uses_streaming_with_features(monkeypatch): + """``load_streaming_fn`` must request streaming and pin the content field's + feature schema — without ``features=`` HuggingFace would auto-infer types + per-shard, which has caused bypass jobs to crash on schema drift in the past. + """ + captured = {} + + def fake_load_dataset(path, streaming, features, keep_in_memory): + captured["path"] = path + captured["streaming"] = streaming + captured["features"] = features + captured["keep_in_memory"] = keep_in_memory + return "stream-sentinel" + + monkeypatch.setattr(datasets, "load_dataset", fake_load_dataset) + out = load_streaming_fn("hf-org/dataset", content_field="text", keep_in_memory=False) + assert out == "stream-sentinel" + assert captured["path"] == "hf-org/dataset" + assert captured["streaming"] is True + assert captured["keep_in_memory"] is False + # features must be a Features object keyed by the requested content_field + # with a string Value — schema-drift protection is the whole point of this fn. + assert isinstance(captured["features"], datasets.Features) + assert "text" in captured["features"] + assert captured["features"]["text"].dtype == "string" + + +# --------------------------------------------------------------------------- +# create_train_dataloader: ``num_workers > 0`` is a configuration error +# --------------------------------------------------------------------------- + + +def test_create_train_dataloader_rejects_num_workers_gt_zero(): + """ConstantLengthDataset doesn't shard work via ``get_worker_info`` — every + worker would emit the same samples. 
The guard fires before tokenizer or + dataset are touched, so bare-bones args are enough.""" + with pytest.raises(ValueError, match="num_workers"): + create_train_dataloader( + seed=0, + tokenizer=None, + block_size=8, + dataset_path={"train": []}, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + micro_batch_size=1, + num_workers=2, + ) + + +class _NoChatTemplateTokenizer: + eos_token_id = 1 + bos_token_id = None + + def __init__(self): + self.seen_texts = None + self.vocab = {} + + def __call__(self, texts, truncation=False): + self.seen_texts = texts + return {"input_ids": [[0] for _ in texts]} + + +class _ConversationDataset: + column_names = ("text",) + + def __iter__(self): + yield { + "text": [ + {"role": "user", "content": {"text": "hello"}}, + {"role": "assistant", "content": {"value": 3}}, + ] + } + + +def test_constant_length_dataset_no_chat_template_normalizes_message_content(): + tokenizer = _NoChatTemplateTokenizer() + dataset = ConstantLengthDataset( + tokenizer, + _ConversationDataset(), + infinite=False, + seq_length=2, + num_of_sequences=1, + chars_per_token=100, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + label_shift=False, + ) + + realized = list(dataset) + + assert tokenizer.seen_texts == ["hello\n{'value': 3}"] + assert len(realized) == 1 + assert torch.equal(realized[0]["input_ids"], torch.tensor([0, 1])) + assert torch.equal(realized[0]["targets"], torch.tensor([0, 1])) + + +# --------------------------------------------------------------------------- +# create_validation_dataloader: split auto-detect + explicit override +# --------------------------------------------------------------------------- + + +class _FakeConstantLengthDataset: + """Stub for ``ConstantLengthDataset`` that records its ``dataset`` arg. + + Yields one trivial item so ``realize_dataset_in_memory`` can iterate over + it without touching a tokenizer. + """ + + last_dataset = None # class-level capture so tests can read after construction + + def __init__(self, tokenizer, dataset, **kwargs): + type(self).last_dataset = dataset + self._dataset = dataset + + def __iter__(self): + yield {"input_ids": torch.tensor([0])} + + +@pytest.fixture +def patched_dataloader(monkeypatch): + """Replace the heavy bits inside ``create_validation_dataloader`` so the + function exercises only its pure split-selection logic + DataLoader build.""" + monkeypatch.setattr(dl, "ConstantLengthDataset", _FakeConstantLengthDataset) + # Force a tiny in-memory list so we don't drain a real iterable. + monkeypatch.setattr( + dl, + "realize_dataset_in_memory", + lambda dataset, eval_samples: [{"input_ids": torch.tensor([0])}], + ) + _FakeConstantLengthDataset.last_dataset = None + return _FakeConstantLengthDataset + + +def _make_dict_dataset(splits: dict[str, list]) -> DatasetDict: + return DatasetDict({k: Dataset.from_list(v) for k, v in splits.items()}) + + +def _kwargs(): + return { + "accelerator": None, # → Printer (single-process path) + "seed": 0, + "tokenizer": None, + "block_size": 4, + "content_field": "text", + "fim_rate": 0.0, + "fim_spm_rate": 0.0, + "micro_batch_size": 1, + } + + +def test_validation_split_auto_picks_validation_when_present(patched_dataloader): + dd = _make_dict_dataset({"train": [{"text": "t"}], "validation": [{"text": "v"}]}) + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + # The "validation" split must have been the one passed to ConstantLengthDataset. 
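+    # Identity (`is`) rather than equality: DatasetDict hands back the stored
+    # split object itself, so this also pins that no copy or re-wrap happened.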
+ assert patched_dataloader.last_dataset is dd["validation"] + + +def test_validation_split_auto_falls_back_to_test_when_no_val(patched_dataloader): + dd = _make_dict_dataset({"train": [{"text": "t"}], "test": [{"text": "te"}]}) + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + assert patched_dataloader.last_dataset is dd["test"] + + +def test_validation_split_auto_prefers_val_over_test(patched_dataloader): + """If both ``validation`` and ``test`` exist, the val* prefix must win — + bypass relies on this to score against held-out data, not test data.""" + dd = _make_dict_dataset( + {"train": [{"text": "t"}], "validation": [{"text": "v"}], "test": [{"text": "te"}]} + ) + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + assert patched_dataloader.last_dataset is dd["validation"] + + +def test_validation_split_auto_assertion_on_multiple_val_options(patched_dataloader): + """Ambiguity must fail loudly — silently picking one would be a footgun.""" + dd = _make_dict_dataset({"validation": [{"text": "a"}], "valtest": [{"text": "b"}]}) + with pytest.raises(AssertionError, match="exactly one validation split"): + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + + +def test_validation_split_auto_assertion_on_no_val_or_test(patched_dataloader): + dd = _make_dict_dataset({"train": [{"text": "t"}], "extra": [{"text": "e"}]}) + with pytest.raises(AssertionError, match="exactly one validation split"): + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + + +def test_validation_split_explicit_override_bypasses_auto(patched_dataloader): + """Explicit ``dataset_name`` must skip the auto-detect, even when the + chosen name doesn't match val* / test* prefixes.""" + dd = _make_dict_dataset({"my_eval": [{"text": "x"}]}) + create_validation_dataloader(dataset=dd, dataset_name="my_eval", **_kwargs()) + assert patched_dataloader.last_dataset is dd["my_eval"] diff --git a/tests/unit/torch/puzzletron/test_bypass_losses.py b/tests/unit/torch/puzzletron/test_bypass_losses.py new file mode 100644 index 00000000000..2d59b25716a --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_losses.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for normalized MSE loss functions in sewing_kit/utils.py.""" + +import pytest +import torch + +from modelopt.torch.puzzletron.sewing_kit.utils import ( + batched_normalized_mse_loss, + normalized_mse_loss, + vectorwise_normalized_mse_loss, +) +from modelopt.torch.puzzletron.utils.parsing import format_stitched_losses + +# --------------------------------------------------------------------------- +# normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_normalized_mse_loss_identical_tensors(): + """Identical input and target should produce a loss of approximately 0.""" + torch.manual_seed(42) + x = torch.randn(4, 8) + loss = normalized_mse_loss(x, x) + assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6) + + +def test_normalized_mse_loss_basic(): + """Loss should be positive and finite for random, non-identical tensors.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target) + assert loss.item() > 0.0 + assert torch.isfinite(loss) + + +def test_normalized_mse_loss_reduction_none(): + """With reduction='none' the output shape should match the input shape.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target, reduction="none") + assert loss.shape == input_.shape + + +def test_normalized_mse_loss_reduction_sum(): + """With reduction='sum' the output should be a scalar tensor.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target, reduction="sum") + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + + +# --------------------------------------------------------------------------- +# vectorwise_normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_vectorwise_normalized_mse_loss_shape(): + """vectorwise_normalized_mse_loss should return a scalar for any 2-D input.""" + torch.manual_seed(42) + input_ = torch.randn(4, 16) + target = torch.randn(4, 16) + loss = vectorwise_normalized_mse_loss(input_, target) + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + + +def test_vectorwise_normalized_mse_loss_identical(): + """Identical input and target should give a loss of approximately 0.""" + torch.manual_seed(42) + x = torch.randn(4, 16) + loss = vectorwise_normalized_mse_loss(x, x) + assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6) + + +# --------------------------------------------------------------------------- +# batched_normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_batched_normalized_mse_loss_basic(): + """Should return a scalar with a positive, finite value for random tensors.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = batched_normalized_mse_loss(input_, target) + assert loss.ndim == 0 # scalar + assert loss.item() > 0.0 + assert torch.isfinite(loss) + + +def test_batched_normalized_mse_loss_custom_dims(): + """Custom batch_dims=(0, 1) on a 3-D tensor should still return a scalar.""" + torch.manual_seed(42) + input_ = torch.randn(2, 3, 8) + target = torch.randn(2, 3, 8) + loss = batched_normalized_mse_loss(input_, target, batch_dims=(0, 1)) + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + assert loss.item() > 0.0 + + +def 
test_batched_normalized_mse_loss_zero_target_is_finite():
+    """All-zero target slice must not produce NaN/Inf.
+
+    With the relative-L2 formula ``sum((x-t)^2) / (sum(t^2) + eps)``, an all-zero
+    target reduces the denominator to exactly ``eps`` — finite, no division by
+    zero — so the loss equals ``||input||^2 / eps``. The numeric value is large
+    by construction (that's what zero-magnitude targets mean), but the test
+    pins the property we actually care about: finiteness, not magnitude.
+    """
+    input_ = torch.full((1, 8), 1.0)
+    target = torch.zeros(1, 8)
+    loss = batched_normalized_mse_loss(input_, target)
+    assert torch.isfinite(loss)
+    assert not torch.isnan(loss)
+
+
+def test_batched_normalized_mse_loss_zero_input_and_target():
+    """Both zero should give exactly 0.0 — numerator is zero, denominator is eps."""
+    input_ = torch.zeros(2, 4)
+    target = torch.zeros(2, 4)
+    loss = batched_normalized_mse_loss(input_, target)
+    assert loss.item() == 0.0
+
+
+def test_batched_normalized_mse_loss_scale_invariance():
+    """Scaling both input and target by the same constant must leave the loss
+    unchanged for non-tiny targets — the defining property of relative-L2."""
+    torch.manual_seed(0)
+    input_ = torch.randn(4, 8)
+    target = torch.randn(4, 8)
+    baseline = batched_normalized_mse_loss(input_, target)
+    scaled = batched_normalized_mse_loss(10.0 * input_, 10.0 * target)
+    assert torch.allclose(baseline, scaled, rtol=1e-4, atol=1e-6)
+
+
+def test_batched_normalized_mse_loss_rejects_shape_mismatch():
+    input_ = torch.randn(2, 3)
+    target = torch.randn(2, 1)
+
+    with pytest.raises(ValueError, match="input and target shapes must match exactly"):
+        batched_normalized_mse_loss(input_, target)
+
+
+def test_batched_normalized_mse_loss_rejects_invalid_batch_dim():
+    input_ = torch.randn(2, 3)
+    target = torch.randn(2, 3)
+
+    with pytest.raises(ValueError, match="batch_dims contains invalid dimension"):
+        batched_normalized_mse_loss(input_, target, batch_dims=(2,))
+
+
+def test_format_stitched_losses_keeps_trainable_nan_visible():
+    out = format_stitched_losses(
+        {"block_0": float("nan"), "block_1": 1.0},
+        initial_values_dict={"block_0": 0.5, "block_1": 2.0},
+        not_trainable_names={"block_2"},
+        step_number=3,
+    )
+
+    assert "nan" in out
+    assert "non-finite" in out
+    assert "Skipped=1" in out
+    # The full table path must have been taken, not the empty-losses early return.
+    assert "No trainable losses found" not in out
diff --git a/tests/unit/torch/puzzletron/test_child_init_mixins.py b/tests/unit/torch/puzzletron/test_child_init_mixins.py
new file mode 100644
index 00000000000..b68313245e4
--- /dev/null
+++ b/tests/unit/torch/puzzletron/test_child_init_mixins.py
@@ -0,0 +1,77 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from types import SimpleNamespace + +import torch + +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import _process_single_layer + + +class _AddOneMixin: + def prune_single_layer(self, parent_state_dict, keys_to_remove, **kwargs): + keys_to_remove["w"] = "w" + return {"w": parent_state_dict["w"] + 1} + + +class _TimesTwoMixin: + def prune_single_layer(self, parent_state_dict, keys_to_remove, **kwargs): + keys_to_remove["w"] = "w" + return {"w": parent_state_dict["w"] * 2} + + +class _PopKeyMixin: + def prune_single_layer(self, parent_state_dict, keys, **kwargs): + keys.pop("w") + return {"w": parent_state_dict["w"]} + + +def _process_with_mixins(mixins, keys): + return _process_single_layer( + layer_idx=0, + pruning_mixin=mixins, + descriptor=None, + parent_state_dict={"w": torch.tensor([1.0])}, + new_state_dict={"w": torch.tensor([0.0])}, + original_config=SimpleNamespace(), + new_config=SimpleNamespace(), + gqa_init_mode=None, + mlp_init_mode=None, + mlp_init_config=None, + linear_init_mode=None, + ignored_keys=set(), + keys=keys, + is_original_mha=False, + head_size=1, + hidden_size=1, + ) + + +def test_pruning_mixins_compose_overlapping_outputs_sequentially(): + layer_state_dict, keys_to_remove = _process_with_mixins( + [_AddOneMixin(), _TimesTwoMixin()], {"w": "w"} + ) + + assert torch.equal(layer_state_dict["w"], torch.tensor([4.0])) + assert keys_to_remove == {"w": "w"} + + +def test_pruning_mixin_key_mutation_is_tracked_without_mutating_shared_keys(): + shared_keys = {"w": "w"} + + _, keys_to_remove = _process_with_mixins([_PopKeyMixin()], shared_keys) + + assert keys_to_remove == {"w": "w"} + assert shared_keys == {"w": "w"} diff --git a/tests/unit/torch/puzzletron/test_hydra_utils.py b/tests/unit/torch/puzzletron/test_hydra_utils.py new file mode 100644 index 00000000000..4b84dc08812 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_hydra_utils.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
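+
+"""Unit tests for the ``warmup_steps`` helper in ``tools/hydra_utils``.
+
+The helper is fed string-typed config values, so it must cast before
+computing. Assuming arithmetic of roughly this shape (a sketch inferred from
+the passing case below, not the verbatim implementation)::
+
+    total_steps = tokens // (block * mbs * grad_accum)
+    warmup = max(1, int(total_steps * pct))
+
+the case ``("100", "10", "2", "5", "0.5")`` casts to
+``100 // (10 * 2 * 5) == 1`` step and a warmup of ``1``.
+"""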
+ +import pytest + +from modelopt.torch.puzzletron.tools.hydra_utils import warmup_steps + + +def test_warmup_steps_casts_inputs_before_computing(): + assert warmup_steps("100", "10", "2", "5", "0.5") == 1 + + +@pytest.mark.parametrize( + ("kwargs", "message"), + [ + ({"tokens": -1, "block": 1, "mbs": 1, "grad_accum": 1, "pct": 0.1}, "tokens"), + ({"tokens": 1, "block": 0, "mbs": 1, "grad_accum": 1, "pct": 0.1}, "block"), + ({"tokens": 1, "block": 1, "mbs": 0, "grad_accum": 1, "pct": 0.1}, "mbs"), + ({"tokens": 1, "block": 1, "mbs": 1, "grad_accum": 0, "pct": 0.1}, "grad_accum"), + ({"tokens": 1, "block": 1, "mbs": 1, "grad_accum": 1, "pct": 1.1}, "pct"), + ], +) +def test_warmup_steps_rejects_invalid_inputs(kwargs, message): + with pytest.raises(ValueError, match=message): + warmup_steps(**kwargs) diff --git a/tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py b/tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py new file mode 100644 index 00000000000..421ec4304bb --- /dev/null +++ b/tests/unit/torch/puzzletron/test_kv_heads_pruning_utils.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace + +from modelopt.torch.puzzletron.pruning.pruning_utils import _lm_head_dim + + +def test_lm_head_dim_uses_explicit_nested_head_dim(): + cfg = SimpleNamespace( + text_config=SimpleNamespace(head_dim=96, hidden_size=3072, num_attention_heads=32) + ) + assert _lm_head_dim(cfg) == 96 + + +def test_lm_head_dim_falls_back_to_hidden_size_over_heads(): + cfg = SimpleNamespace(text_config=SimpleNamespace(hidden_size=3072, num_attention_heads=32)) + assert _lm_head_dim(cfg) == 96 diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py b/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py new file mode 100644 index 00000000000..58df5ffe327 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``sewing_kit.utils.ActivityContext``. + +``ActivityContext`` is the stack the ``Passage`` machinery uses to track which +passages are currently active inside a ``StitchedModule.forward`` call. 
A bug +in push/pop ordering or in the exception-safe cleanup would leak state across +forward passes — every subsequent block would see a stale "active passage" +and route inputs/outputs to the wrong module. +""" + +import pytest + +from modelopt.torch.puzzletron.sewing_kit.utils import ( + ActivityContext, + ActivityContextDuplicateException, + ActivityContextMaxDepthException, + is_submodule_of, + is_submodule_or_same, +) + +# --------------------------------------------------------------------------- +# Basic push/pop semantics via the ``with ctx(value):`` form +# --------------------------------------------------------------------------- + + +def test_starts_empty_and_inactive(): + ctx: ActivityContext[str] = ActivityContext() + assert len(ctx) == 0 + assert not ctx.is_active() + assert ctx.get_active() is None + + +def test_with_block_pushes_and_pops_value(): + ctx: ActivityContext[str] = ActivityContext() + with ctx("a"): + assert ctx.is_active() + assert ctx.get_active() == "a" + assert "a" in ctx + assert len(ctx) == 1 + # After the block: stack must be back to empty. + assert len(ctx) == 0 + assert ctx.get_active() is None + + +def test_nested_pushes_track_lifo_order(): + """``get_active`` returns the *most recent* push (LIFO) — Passage relies on + this to find the innermost active passage during forward.""" + ctx: ActivityContext[str] = ActivityContext() + with ctx("outer"): + assert ctx.get_active() == "outer" + with ctx("inner"): + assert ctx.get_active() == "inner" + assert ctx[0] == "outer" + assert ctx[1] == "inner" + # Inner pop returns to outer. + assert ctx.get_active() == "outer" + + +# --------------------------------------------------------------------------- +# max_depth: limits stack height +# --------------------------------------------------------------------------- + + +def test_max_depth_one_allows_single_push(): + ctx: ActivityContext[str] = ActivityContext(max_depth=1) + with ctx("a"): + assert ctx.get_active() == "a" + + +def test_max_depth_one_rejects_second_push(): + ctx: ActivityContext[str] = ActivityContext(max_depth=1) + with ctx("a"), pytest.raises(ActivityContextMaxDepthException), ctx("b"): + pass + # Stack must have unwound to empty even after the exception. + assert len(ctx) == 0 + + +# --------------------------------------------------------------------------- +# no_duplicates: same value can't appear twice +# --------------------------------------------------------------------------- + + +def test_no_duplicates_rejects_repeat_value(): + ctx: ActivityContext[str] = ActivityContext(no_duplicates=True) + with ctx("x"), pytest.raises(ActivityContextDuplicateException), ctx("x"): + pass + # Stack unwound; the still-active "x" was preserved through the failed push. + assert len(ctx) == 0 + + +def test_no_duplicates_allows_distinct_values(): + ctx: ActivityContext[str] = ActivityContext(no_duplicates=True) + with ctx("x"), ctx("y"): + assert "x" in ctx and "y" in ctx + + +# --------------------------------------------------------------------------- +# reversed=True: insert at front, pop from front +# --------------------------------------------------------------------------- + + +def test_reversed_pushes_to_front_and_pops_from_front(): + """``Passage.active_passages_context`` uses ``reversed=True`` so the + *first* active passage in iteration order is the innermost. Pin both + insert position and pop position.""" + ctx: ActivityContext[str] = ActivityContext(reversed=True) + with ctx("a"): + with ctx("b"): + # b inserted at front of stack. 
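+            # (Shape assumption pinned by these index asserts: with
+            # reversed=True, push behaves like ``list.insert(0, v)`` and
+            # pop like ``list.pop(0)``.)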
+ assert ctx[0] == "b" + assert ctx[1] == "a" + # Pop from front: only "a" left — runs between the inner and outer + # exits, which is why these withs can't be combined. + assert list(ctx[:]) == ["a"] + + +# --------------------------------------------------------------------------- +# Exception safety: stack unwinds even if the caller's body raises +# --------------------------------------------------------------------------- + + +def test_stack_unwinds_when_body_raises(): + """A bug here would leak stack frames — the next forward pass would see + a stale active passage. This is the silent-failure scenario.""" + ctx: ActivityContext[str] = ActivityContext() + with pytest.raises(ValueError, match="boom"), ctx("a"): + assert ctx.get_active() == "a" + raise ValueError("boom") + assert len(ctx) == 0 + + +# --------------------------------------------------------------------------- +# is_submodule_of / is_submodule_or_same — string predicates used by passage.py +# --------------------------------------------------------------------------- + + +def test_is_submodule_of_proper_descendant(): + assert is_submodule_of("model.layers.0.self_attn", "model.layers.0") + assert is_submodule_of("model.layers.0", "model") + # Empty string parent matches any non-empty name (root-of-everything case). + assert is_submodule_of("model", "") + + +def test_is_submodule_of_rejects_self_and_unrelated(): + assert not is_submodule_of("model.layers.0", "model.layers.0") + assert not is_submodule_of("model.layers.0", "model.layers.1") + # Empty == empty is not a submodule relationship. + assert not is_submodule_of("", "") + # Prefix collision: "model.layers" is NOT a submodule of "model.lay" — the + # predicate requires a literal "." separator after the parent. + assert not is_submodule_of("model.layers", "model.lay") + + +def test_is_submodule_or_same_includes_equality(): + assert is_submodule_or_same("model.layers.0", "model.layers.0") + assert is_submodule_or_same("model.layers.0.attn", "model.layers.0") + assert not is_submodule_or_same("model.layers.0", "model.layers.1") diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_function_target_kwargs.py b/tests/unit/torch/puzzletron/test_sewing_kit_function_target_kwargs.py new file mode 100644 index 00000000000..1e412605435 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_function_target_kwargs.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Regression test for ``FunctionTarget`` kwargs dispatch. + +The bypass-distillation factory stitches teacher and student block outputs into +a per-block loss function using ``InputArgs(target=...)`` and ``InputArgs(input=...)`` +adapters (see ``stitched_model_factory.py:~545``). The loss function is then +invoked by ``StitchedModule.forward`` at ``core.py:600`` as +``node.target.function(*input_args.args, **input_args.kwargs)`` — i.e. 
with
+**named kwargs**.
+
+If sewing_kit ever switched to positional dispatch in stitch-declaration order,
+asymmetric losses (KL divergence, relative-L2, anything where ``f(a, b) != f(b, a)``)
+would silently swap their arguments. MSE-shaped losses would hide the regression
+because they're symmetric. This test pins the contract.
+"""
+
+import torch
+
+from modelopt.torch.puzzletron.sewing_kit.core import ExternalTarget, FunctionTarget, Needle
+from modelopt.torch.puzzletron.sewing_kit.passage import InputArgs
+
+
+def test_function_target_invoked_with_kwargs_not_positional():
+    """The function callable must receive only kwargs (no positional args)."""
+    received: dict[str, object] = {}
+
+    def record_call(*args, **kwargs):
+        received["args"] = args
+        received["kwargs"] = dict(kwargs)
+        # The output stitch needs *something* to carry — return a sentinel scalar.
+        return torch.tensor(0.0)
+
+    loss_target = FunctionTarget("loss_fn", record_call)
+    teacher_value = torch.full((2, 3), 7.0)
+    student_value = torch.full((2, 3), 11.0)
+
+    # Stitch order is intentionally reversed from the real factory: declare
+    # student-first, teacher-second. Note that positional-in-declaration-order
+    # dispatch would happen to pair (student, teacher) correctly in this order,
+    # so the kwargs assertions alone cannot expose it; the explicit
+    # ``received["args"] == ()`` assertion below catches any positional
+    # dispatch, and the companion test repeats the check in the factory's real
+    # teacher-first order, where positional dispatch *would* swap the pair.
+    stitched = (
+        Needle()
+        .stitch(
+            ExternalTarget().output(
+                name="student_act",
+                adapter=lambda v: InputArgs(input=v),
+            ),
+            loss_target.input(),
+        )
+        .stitch(
+            ExternalTarget().output(
+                name="teacher_act",
+                adapter=lambda v: InputArgs(target=v),
+            ),
+            loss_target.input(),
+        )
+        .stitch(
+            loss_target.output(),
+            ExternalTarget().output(name="loss"),
+        )
+        .knot()
+    )
+
+    stitched(
+        {},
+        {"student_act": student_value, "teacher_act": teacher_value},
+    )
+
+    assert received["args"] == (), (
+        f"FunctionTarget called with positional args {received['args']!r}. "
+        "Sewing-kit must dispatch with kwargs only; positional dispatch would "
+        "silently swap input/target for asymmetric losses."
+    )
+    assert set(received["kwargs"].keys()) == {"input", "target"}
+    assert torch.equal(received["kwargs"]["input"], student_value)
+    assert torch.equal(received["kwargs"]["target"], teacher_value)
+
+
+def test_function_target_kwargs_independent_of_stitch_order():
+    """Same as the test above, but with the *real factory's* stitch order
+    (teacher first, student second).
Both orders must produce identical kwargs + — the InputArgs.__add__ kwargs merge is order-independent for distinct + keys.""" + received: dict[str, object] = {} + + def record_call(*args, **kwargs): + received["args"] = args + received["kwargs"] = dict(kwargs) + return torch.tensor(0.0) + + loss_target = FunctionTarget("loss_fn", record_call) + teacher_value = torch.full((2, 3), 13.0) + student_value = torch.full((2, 3), 17.0) + + stitched = ( + Needle() + .stitch( + ExternalTarget().output( + name="teacher_act", + adapter=lambda v: InputArgs(target=v), + ), + loss_target.input(), + ) + .stitch( + ExternalTarget().output( + name="student_act", + adapter=lambda v: InputArgs(input=v), + ), + loss_target.input(), + ) + .stitch( + loss_target.output(), + ExternalTarget().output(name="loss"), + ) + .knot() + ) + + stitched( + {}, + {"teacher_act": teacher_value, "student_act": student_value}, + ) + + assert received["args"] == () + assert set(received["kwargs"].keys()) == {"input", "target"} + assert torch.equal(received["kwargs"]["input"], student_value) + assert torch.equal(received["kwargs"]["target"], teacher_value) diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py b/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py new file mode 100644 index 00000000000..a568fadc07b --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py @@ -0,0 +1,164 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``sewing_kit.passage.InputArgs``. + +``InputArgs`` is the workhorse args/kwargs container the bypass distillation +factory uses inside its stitching reducers — see ``bypass_factory_fn`` calls +like ``lambda acc, override, orig, *args: override + orig.drop_args(0)``. +A regression in ``__add__`` or ``drop_args`` would silently corrupt the +inputs passed into per-block forward passes, producing wrong loss values +without any loud failure. 
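+
+The reducer pattern in miniature (variable names are illustrative)::
+
+    override = InputArgs(new_hidden, attention_mask=mask)
+    orig = InputArgs(old_hidden, position_ids=pos)
+    merged = override + orig.drop_args(0)
+    # merged.args == [new_hidden]; merged.kwargs keeps both sides' keys,
+    # with the right-hand operand winning on any collision.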
+""" + +import pytest + +from modelopt.torch.puzzletron.sewing_kit.passage import InputArgs + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_init_accepts_positional_and_keyword_args(): + ia = InputArgs(1, 2, foo="bar") + assert ia.args == [1, 2] + assert ia.kwargs == {"foo": "bar"} + + +def test_init_with_no_args_is_empty(): + ia = InputArgs() + assert ia.args == [] + assert ia.kwargs == {} + + +# --------------------------------------------------------------------------- +# __add__: concatenates args, merges kwargs (right wins on collision) +# --------------------------------------------------------------------------- + + +def test_add_concatenates_positional_args_in_order(): + a = InputArgs(1, 2) + b = InputArgs(3, 4) + result = a + b + assert result.args == [1, 2, 3, 4] + assert result.kwargs == {} + + +def test_add_merges_kwargs_with_right_winning(): + """Bypass reducers chain ``override + orig.drop_args(0)`` — when both sides + happen to set the same kwarg, the right-side value (the original input) + must win, otherwise the override silently displaces the original kwarg.""" + a = InputArgs(foo="from_a", bar="only_a") + b = InputArgs(foo="from_b", baz="only_b") + result = a + b + assert result.kwargs == {"foo": "from_b", "bar": "only_a", "baz": "only_b"} + + +def test_add_does_not_mutate_operands(): + a = InputArgs(1, 2, x="a") + b = InputArgs(3, y="b") + _ = a + b + assert a.args == [1, 2] and a.kwargs == {"x": "a"} + assert b.args == [3] and b.kwargs == {"y": "b"} + + +def test_add_rejects_non_input_args(): + # ``__add__`` enforces InputArgs+InputArgs only via an internal assert. + # ruff's RUF005 auto-fix to ``[*InputArgs(1), 2]`` would silently replace + # the operator call we're testing — keep the explicit ``+`` form. + with pytest.raises(AssertionError): + InputArgs(1) + [2] # type: ignore[operator] # noqa: RUF005 + + +# --------------------------------------------------------------------------- +# drop_args: clears all positional args (default) or one by index/slice +# --------------------------------------------------------------------------- + + +def test_drop_args_default_clears_all_positional(): + """The ``drop_args(0)`` and ``drop_args()`` forms are both used by bypass + stitches — the default-no-arg form must wipe the entire positional tuple + (kwargs untouched).""" + ia = InputArgs(1, 2, 3, foo="bar") + out = ia.drop_args() + assert out.args == [] + assert out.kwargs == {"foo": "bar"} + # And the original is unmodified. + assert ia.args == [1, 2, 3] + + +def test_drop_args_with_index_drops_one(): + ia = InputArgs(10, 20, 30) + out = ia.drop_args(0) + assert out.args == [20, 30] + # Source preserved. 
+ assert ia.args == [10, 20, 30] + + +def test_drop_args_with_slice_drops_range(): + ia = InputArgs(10, 20, 30, 40) + out = ia.drop_args(slice(1, 3)) + assert out.args == [10, 40] + + +# --------------------------------------------------------------------------- +# drop_kwargs: clears all kwargs (default) or specific keys +# --------------------------------------------------------------------------- + + +def test_drop_kwargs_default_clears_all(): + ia = InputArgs(1, foo="bar", baz="qux") + out = ia.drop_kwargs() + assert out.args == [1] + assert out.kwargs == {} + + +def test_drop_kwargs_with_keys_drops_only_those(): + ia = InputArgs(1, foo="bar", baz="qux", keep="this") + out = ia.drop_kwargs(["foo", "baz"]) + assert out.kwargs == {"keep": "this"} + + +def test_drop_kwargs_silently_ignores_missing_keys(): + """A key listed in ``drop_kwargs`` that isn't present must not raise — + bypass calls this against args from arbitrary upstream stitches and may + pass keys that only some sources produce.""" + ia = InputArgs(foo="bar") + out = ia.drop_kwargs(["nonexistent"]) # must not KeyError + assert out.kwargs == {"foo": "bar"} + + +# --------------------------------------------------------------------------- +# from_value: lifts assorted values into InputArgs +# --------------------------------------------------------------------------- + + +def test_from_value_passes_through_existing_input_args(): + ia = InputArgs(1, foo="bar") + out = InputArgs.from_value(ia) + assert out is ia + + +def test_from_value_lifts_sequence_to_positional_args(): + out = InputArgs.from_value([1, 2, 3]) + assert out.args == [1, 2, 3] + assert out.kwargs == {} + + +def test_from_value_lifts_scalar_to_single_positional(): + out = InputArgs.from_value(42) + assert out.args == [42] + assert out.kwargs == {} diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_needle.py b/tests/unit/torch/puzzletron/test_sewing_kit_needle.py new file mode 100644 index 00000000000..a3db5ef30b8 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_needle.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``sewing_kit.core.Needle`` graph construction and validation. + +The bypass factory builds three ``Needle``\\s per rank (teacher train, teacher +val, student val) and calls ``Needle.knot()`` on each. ``knot()`` runs +``_validate_nodes`` first; a regression in that validation would either crash +with an opaque NoneType error during forward, or — worse — silently allow a +malformed graph that produces incorrect activations. + +We test the validation contract on CPU without instantiating ``StitchedModule`` +itself (which requires Module patching). ``_validate_nodes`` is a private +method but it's the unit of behavior worth pinning; ``knot()`` is essentially +``_validate_nodes() + StitchedModule(...)``. 
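+
+The canonical valid shape, in miniature (mirroring the external-anchored
+happy-path tests below)::
+
+    needle = Needle()
+    ext = ExternalTarget()
+    a = ModuleTarget("a", nn.Linear(2, 2))
+    needle.stitch(ext.output("in"), a.input("entry"))
+    needle.stitch(a.output("x"), ext.input("out"))
+    needle._validate_nodes()  # must not raise for this shape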
+""" + +import pytest +import torch.nn as nn + +from modelopt.torch.puzzletron.sewing_kit.core import ( + ExternalTarget, + InputsLoopFoundException, + ModuleTarget, + Needle, + Node, + OnlyInternalNodesException, + StitchDescriptor, +) + +# --------------------------------------------------------------------------- +# get_node_for_target: lazy creation, cached lookup +# --------------------------------------------------------------------------- + + +def test_get_node_for_target_creates_node_on_first_call(): + needle = Needle() + target = ModuleTarget("a", nn.Linear(2, 2)) + node = needle.get_node_for_target(target) + assert isinstance(node, Node) + assert node.target is target + assert needle.nodes[target] is node + + +def test_get_node_for_target_returns_same_node_on_repeat_call(): + """Re-getting the same target must NOT create a duplicate node — every + stitch involving that target must funnel into a single Node, otherwise + the validation/forward graph fragments.""" + needle = Needle() + target = ModuleTarget("a", nn.Linear(2, 2)) + node1 = needle.get_node_for_target(target) + node2 = needle.get_node_for_target(target) + assert node1 is node2 + assert len(needle.nodes) == 1 + + +# --------------------------------------------------------------------------- +# stitch: adds StitchDescriptor to source.stitches_from and dest.stitches_to +# --------------------------------------------------------------------------- + + +def test_stitch_records_descriptor_on_both_endpoints(): + needle = Needle() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + + needle.stitch(target_a.output("x"), target_b.input("y")) + + node_a = needle.nodes[target_a] + node_b = needle.nodes[target_b] + # Source endpoint: A has one outgoing stitch; B has one incoming stitch. + assert len(node_a.stitches_from) == 1 + assert len(node_a.stitches_to) == 0 + assert len(node_b.stitches_from) == 0 + assert len(node_b.stitches_to) == 1 + # Same StitchDescriptor object on both lists. + assert node_a.stitches_from[0] is node_b.stitches_to[0] + assert isinstance(node_a.stitches_from[0], StitchDescriptor) + + +def test_stitch_returns_self_for_chaining(): + """Bypass factory chains ``.stitch(...).stitch(...)`` — the return type + must be the Needle itself so the second call sees the same graph.""" + needle = Needle() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + + out = needle.stitch(target_a.output("x"), target_b.input("y")) + assert out is needle + + +# --------------------------------------------------------------------------- +# _validate_nodes: contract checks before knot() builds the StitchedModule +# --------------------------------------------------------------------------- + + +def test_validate_raises_when_only_internal_nodes_present(): + """A graph with no External and no Remote target has nothing for the + runtime to feed inputs through — must raise loudly rather than build a + dead StitchedModule.""" + needle = Needle() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + needle.stitch(target_a.output("x"), target_b.input("y")) + + with pytest.raises(OnlyInternalNodesException): + needle._validate_nodes() + + +def test_validate_passes_with_external_plus_dag(): + """Happy path: ExternalTarget + a small linear DAG. 
Must not raise."""
+    needle = Needle()
+    ext = ExternalTarget()
+    target_a = ModuleTarget("a", nn.Linear(2, 2))
+    target_b = ModuleTarget("b", nn.Linear(2, 2))
+
+    needle.stitch(ext.output("init"), target_a.input("entry"))
+    needle.stitch(target_a.output("x"), target_b.input("y"))
+    needle.stitch(target_b.output("z"), ext.input("final"))
+
+    # No raise.
+    needle._validate_nodes()
+
+
+def test_validate_raises_on_input_cycle_among_internal_nodes():
+    """Detect a 2-node cycle A→B→A among internal nodes.
+
+    The validation uses ``_search_loops`` walking ``stitches_to`` (incoming
+    edges); ExternalTarget short-circuits the recursion, so we add an
+    external feed to A so ``_validate_nodes`` doesn't bail out early on the
+    'no external' check.
+    """
+    needle = Needle()
+    ext = ExternalTarget()
+    target_a = ModuleTarget("a", nn.Linear(2, 2))
+    target_b = ModuleTarget("b", nn.Linear(2, 2))
+
+    # Anchor an external feed so we get past the OnlyInternalNodes check.
+    needle.stitch(ext.output("init"), target_a.input("entry"))
+    # Cycle: A -> B -> A.
+    needle.stitch(target_a.output("x"), target_b.input("y"))
+    needle.stitch(target_b.output("p"), target_a.input("q"))
+
+    with pytest.raises(InputsLoopFoundException):
+        needle._validate_nodes()
+
+
+def test_validate_passes_when_external_node_has_self_referential_loop_via_external():
+    """``_search_loops`` short-circuits at ExternalTarget. So a 'loop' that
+    only goes through external (e.g. external→A and A→external) is fine —
+    and indeed required for normal stitching, where external is both the
+    input and output endpoint.
+    """
+    needle = Needle()
+    ext = ExternalTarget()
+    target_a = ModuleTarget("a", nn.Linear(2, 2))
+
+    needle.stitch(ext.output("in"), target_a.input("entry"))
+    needle.stitch(target_a.output("x"), ext.input("out"))
+
+    # Despite the external→A→external pattern, this is the canonical bypass
+    # shape and must validate clean.
+    needle._validate_nodes()
+
+
+# ---------------------------------------------------------------------------
+# Sanity: ModuleTarget.input()/output() builds correctly typed descriptors
+# ---------------------------------------------------------------------------
+
+
+def test_module_target_descriptors_carry_target_and_name():
+    """The ``.input("foo")`` and ``.output("bar")`` builders are what the
+    bypass factory uses to construct stitches. They must propagate the
+    target reference and the name into the resulting descriptor so the
+    runtime can route values correctly."""
+    target = ModuleTarget("a", nn.Linear(2, 2))
+    in_desc = target.input("foo")
+    out_desc = target.output("bar")
+    assert in_desc.target is target
+    assert in_desc.input_name == "foo"
+    assert out_desc.target is target
+    assert out_desc.output_name == "bar"
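+
+
+# ---------------------------------------------------------------------------
+# Extra sanity sketch: validation on a longer external-anchored chain
+# ---------------------------------------------------------------------------
+
+
+def test_validate_passes_on_longer_chain_between_external_endpoints():
+    """Sketch extending the happy path above: a three-module chain anchored
+    by an external input and an external output must also validate clean."""
+    needle = Needle()
+    ext = ExternalTarget()
+    target_a = ModuleTarget("a", nn.Linear(2, 2))
+    target_b = ModuleTarget("b", nn.Linear(2, 2))
+    target_c = ModuleTarget("c", nn.Linear(2, 2))
+
+    needle.stitch(ext.output("init"), target_a.input("entry"))
+    needle.stitch(target_a.output("x"), target_b.input("y"))
+    needle.stitch(target_b.output("y"), target_c.input("z"))
+    needle.stitch(target_c.output("z"), ext.input("final"))
+
+    # Same contract as the external-plus-DAG happy path: no raise.
+    needle._validate_nodes()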