13 changes: 13 additions & 0 deletions modelopt/torch/puzzletron/anymodel/model_descriptor/base.py
@@ -169,6 +169,19 @@ def uses_autocast() -> bool:
"""
return True

@staticmethod
def pruning_mixins() -> Dict[str, Any]:
"""Return available pruning mixins for bypass distillation.

Override in subclasses to provide model-specific pruning mixins, e.g.
``{"kv_heads": KVHeadsPruningMixIn(...), "experts_removal": ExpertRemovalPruningMixIn(...)}``.

Returns an empty dict by default so that descriptors that do not need
model-specific weight-slicing (e.g. Llama with standard FFN truncation)
can rely on the generic ``create_child_state_dict`` fallback path.
"""
return {}

@staticmethod
def get_language_model_config(config):
"""Get the language model config from a PretrainedConfig.
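A descriptor that does need model-specific weight-slicing overrides this hook to map mixin names to configured instances. A minimal sketch, using a hypothetical MyModelDescriptor and hypothetical layer descriptors, with import paths assumed from the file headers and relative imports in this diff:

from typing import Dict

from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor  # path assumed
from modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin import ExpertRemovalPruningMixIn
from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import KVHeadsPruningMixIn
from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn


class MyModelDescriptor(ModelDescriptor):
    @staticmethod
    def pruning_mixins() -> Dict[str, PruningMixIn]:
        # Keys are the names callers pass to resolve_pruning_mixin();
        # values are mixin instances bound to model-specific layer descriptors.
        # MyKVHeadsLayerDescriptor / MyExpertRemovalLayerDescriptor are hypothetical.
        return {
            "kv_heads": KVHeadsPruningMixIn(MyKVHeadsLayerDescriptor()),
            "experts_removal": ExpertRemovalPruningMixIn(MyExpertRemovalLayerDescriptor()),
        }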
@@ -28,6 +28,7 @@
ExpertRemovalLayerDescriptor,
ExpertRemovalPruningMixIn,
)
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn

# Expert removal is supported for unquantized models (test models).
# Production models use MXFP4 quantized MoE with combined tensors
@@ -37,7 +38,11 @@
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size

__all__ = ["GptOssModelDescriptor", "GptOssExpertRemovalLayerDescriptor"]
__all__ = [
"GptOssExpertRemovalLayerDescriptor",
"GptOssKVHeadsLayerDescriptor",
"GptOssModelDescriptor",
]


@ModelDescriptorFactory.register_decorator("gpt_oss")
@@ -173,7 +178,29 @@ def pruning_mixins() -> Dict[str, PruningMixIn]:
Note: Expert removal works for unquantized models (test models).
Production models use MXFP4 quantization which is not yet supported.
"""
return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())}
# Single instance shared between the canonical key and the legacy alias
# so resolve_pruning_mixin returns the same object regardless of which
# name a caller uses.
expert_mixin = ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())
return {
"experts_removal": expert_mixin,
# Backward-compat alias: this key was "expert_removal" before the
# bypass branch standardised on "experts_removal" (matching the
# NemotronH descriptor). Kept so external scripts that still call
# `resolve_pruning_mixin("expert_removal", GptOssModelDescriptor)`
# continue to work. Remove after a deprecation cycle.
"expert_removal": expert_mixin,
"kv_heads": KVHeadsPruningMixIn(GptOssKVHeadsLayerDescriptor()),
}


@dataclass
class GptOssKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
o_proj_name: str = "self_attn.o_proj"
attn_prefix_name: str = "model.layers.{layer_idx}.self_attn"
qkvo_weight_names: List[str] = field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
)


@dataclass
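Because the canonical key and the legacy alias point at one shared instance, callers get the same mixin object under either name. A usage sketch, assuming resolve_pruning_mixin returns the descriptor's registered instance unchanged:

from modelopt.torch.puzzletron.pruning.pruning_utils import resolve_pruning_mixin

canonical = resolve_pruning_mixin("experts_removal", GptOssModelDescriptor)
legacy = resolve_pruning_mixin("expert_removal", GptOssModelDescriptor)  # deprecated spelling
assert canonical is legacy  # one shared ExpertRemovalPruningMixIn instance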
@@ -29,11 +29,16 @@
ExpertRemovalLayerDescriptor,
ExpertRemovalPruningMixIn,
)
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn
from ....pruning.pruning_mixin import PruningMixIn
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same

__all__ = ["NemotronHExpertRemovalLayerDescriptor", "NemotronHModelDescriptor"]
__all__ = [
"NemotronHExpertRemovalLayerDescriptor",
"NemotronHKVHeadsLayerDescriptor",
"NemotronHModelDescriptor",
]


def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
@@ -51,6 +56,15 @@ def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
return matches


@dataclass
class NemotronHKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
o_proj_name: str = "mixer.o_proj"
attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
qkvo_weight_names: List[str] = field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
)


@dataclass
class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
target_name: str = "mixer.gate"
@@ -251,4 +265,5 @@ def build_attention_predicates() -> Dict[str, re.Pattern]:
def pruning_mixins() -> Dict[str, PruningMixIn]:
return {
"experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()),
"kv_heads": KVHeadsPruningMixIn(NemotronHKVHeadsLayerDescriptor()),
}
@@ -29,11 +29,16 @@
FFNIntermediateLayerDescriptor,
FFNIntermediatePruningMixIn,
)
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn
from ....pruning.pruning_mixin import PruningMixIn
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same

__all__ = ["NemotronHV2FFNIntermediateLayerDescriptor", "NemotronHV2ModelDescriptor"]
__all__ = [
"NemotronHV2FFNIntermediateLayerDescriptor",
"NemotronHV2KVHeadsLayerDescriptor",
"NemotronHV2ModelDescriptor",
]


def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
@@ -69,6 +74,15 @@ class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor):
linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"])


@dataclass
class NemotronHV2KVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
o_proj_name: str = "mixer.o_proj"
attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
qkvo_weight_names: List[str] = field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
)


@ModelDescriptorFactory.register_decorator("nemotron_h_v2")
class NemotronHV2ModelDescriptor(ModelDescriptor):
_DECODER_LAYER_CLS: Type[nn.Module] = None
@@ -251,5 +265,6 @@ def pruning_mixins() -> Dict[str, PruningMixIn]:
"ffn_intermediate": FFNIntermediatePruningMixIn(
NemotronHV2FFNIntermediateLayerDescriptor()
),
"kv_heads": KVHeadsPruningMixIn(NemotronHV2KVHeadsLayerDescriptor()),
# TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated
}
@@ -26,9 +26,13 @@
)

from ....block_config import BlockConfig
from ....pruning.expert_removal_pruning_mixin import ExpertRemovalLayerDescriptor
from ....pruning.expert_removal_pruning_mixin import (
ExpertRemovalLayerDescriptor,
ExpertRemovalPruningMixIn,
)
from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor
from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn
from ....pruning.pruning_mixin import PruningMixIn
from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory
from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size

@@ -56,6 +60,13 @@ def get_language_model_config(config):
"""Qwen3-VL has nested text_config for language model parameters."""
return config.text_config if hasattr(config, "text_config") else config

@staticmethod
def pruning_mixins() -> Dict[str, PruningMixIn]:
return {
"experts_removal": ExpertRemovalPruningMixIn(Qwen3VLExpertRemovalLayerDescriptor()),
"kv_heads": KVHeadsPruningMixIn(Qwen3VLKVHeadsLayerDescriptor()),
}

@staticmethod
def decoder_layer_cls():
return Qwen3VLMoeTextDecoderLayer
9 changes: 7 additions & 2 deletions modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py
@@ -24,7 +24,12 @@
)

from .pruning_mixin import LayerDescriptor, PruningMixIn
from .pruning_utils import GQAInitMode, _init_attention_biases, _init_attention_weights
from .pruning_utils import (
GQAInitMode,
_init_attention_biases,
_init_attention_weights,
_lm_head_dim,
)

__all__ = [
"KVHeadsLayerDescriptor",
@@ -74,7 +79,7 @@ def prune_single_layer(
f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names
]

head_size = new_config.head_dim
head_size = _lm_head_dim(new_config)
for part in ["weight", "bias"]:
attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]]
q_key, k_key, v_key, o_key = attn_keys
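For one GPT-OSS layer, the attn_prefix_name template and qkvo_weight_names from the descriptor earlier in this diff expand to concrete state-dict keys roughly like this, assuming the mixin fills the template with the layer index:

desc = GptOssKVHeadsLayerDescriptor()
attn_prefix = desc.attn_prefix_name.format(layer_idx=3)  # "model.layers.3.self_attn"
weight_keys = [f"{attn_prefix}.{name}.weight" for name in desc.qkvo_weight_names]
# ["model.layers.3.self_attn.q_proj.weight", "model.layers.3.self_attn.k_proj.weight",
#  "model.layers.3.self_attn.v_proj.weight", "model.layers.3.self_attn.o_proj.weight"]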
60 changes: 50 additions & 10 deletions modelopt/torch/puzzletron/pruning/pruning_utils.py
@@ -52,6 +52,7 @@ class MlpInitMode(Enum):
PruneByActivationsLog = "PruneByActivationsLog"
ExpertRemoval = "ExpertRemoval"
ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN"
MoEChannelPruning = "MoEChannelPruning"


class LinearInitMode(Enum):
@@ -66,6 +67,30 @@ class HiddenSizeInitMode(Enum):
CopyAsIs = "CopyAsIs"


def _lm_attrs(config):
"""Return the language-model sub-config for VL configs, else the config itself.

VL configs nest language-model fields like ``num_attention_heads``, ``head_dim``,
and ``hidden_size`` under a sub-config. The attribute name varies by family —
``text_config`` (Qwen3-VL, Llava, Idefics), ``language_config`` (Llama-4 and a
handful of others), and ``llm_config`` (InternVL and friends) are all common.
Probe each before falling back to the raw config.
"""
for attr in ("text_config", "language_config", "llm_config"):
sub = getattr(config, attr, None)
if sub is not None:
return sub
return config


def _lm_head_dim(config) -> int:
lm_config = _lm_attrs(config)
head_dim = getattr(lm_config, "head_dim", None)
if head_dim is not None:
return head_dim
return lm_config.hidden_size // lm_config.num_attention_heads


def resolve_pruning_mixin(
pruning_mixin, descriptor: Type[ModelDescriptor]
) -> PruningMixIn | List[PruningMixIn]:
@@ -224,10 +249,13 @@ def _init_attention_weights(
head_size,
mlp_init_config,
):
assert new_config.num_attention_heads == original_config.num_attention_heads, (
f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})"
new_lm = _lm_attrs(new_config)
orig_lm = _lm_attrs(original_config)
assert new_lm.num_attention_heads == orig_lm.num_attention_heads, (
f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})"
)
num_q_heads = new_config.num_attention_heads
num_q_heads = new_lm.num_attention_heads
# block_configs lives on the outer puzzletron-converted config, not on text_config.
num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads
orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads

@@ -372,17 +400,29 @@ def _init_attention_biases(
head_size,
mlp_init_config,
):
assert new_config.num_attention_heads == original_config.num_attention_heads, (
f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})"
new_lm = _lm_attrs(new_config)
orig_lm = _lm_attrs(original_config)
assert new_lm.num_attention_heads == orig_lm.num_attention_heads, (
f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})"
)
num_q_heads = new_config.num_attention_heads
num_q_heads = new_lm.num_attention_heads
# block_configs lives on the outer puzzletron-converted config, not on text_config.
num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads
orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads
n_heads_in_group = num_q_heads // num_kv_heads
orig_n_heads_in_group = num_q_heads // orig_num_kv_heads

o_proj_bias = new_config.o_proj_bias
attention_bias = new_config.attention_bias
# Some HF native configs (e.g. GptOssConfig) don't expose o_proj_bias / attention_bias as
# top-level attributes the way puzzletron's DeciLM-style configs do. Fall back to probing
# the new state dict for the actual bias keys when the attribute is missing.
# KVHeadsPruningMixIn only calls this helper after filtering to keys present in
# new_state_dict, so the probe mirrors the caller's already-selected bias tensors.
o_proj_bias = getattr(new_config, "o_proj_bias", None)
if o_proj_bias is None:
o_proj_bias = o_key in new_state_dict
attention_bias = getattr(new_config, "attention_bias", None)
if attention_bias is None:
attention_bias = q_key in new_state_dict

# If no biases
if not (o_proj_bias or attention_bias):
Expand Down Expand Up @@ -438,8 +478,8 @@ def _init_attention_biases(
assert not is_original_mha, (
"Degrouping can only be done on original models that are GQA themselves."
)
n_groups = new_config.num_attention_heads // n_heads_in_group
orig_n_groups = original_config.num_attention_heads // orig_n_heads_in_group
n_groups = new_lm.num_attention_heads // n_heads_in_group
orig_n_groups = orig_lm.num_attention_heads // orig_n_heads_in_group
assert n_groups % orig_n_groups == 0, f"{orig_n_groups=} must divide {n_groups=}"
n_repeats = n_groups // orig_n_groups
if n_repeats > 1:
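The new head-dim helper reads the language-model sub-config when one exists and otherwise derives the value from hidden_size; a small illustration with ad-hoc stand-in configs rather than real HF config classes:

from types import SimpleNamespace

from modelopt.torch.puzzletron.pruning.pruning_utils import _lm_head_dim

# Flat config with no explicit head_dim: falls back to hidden_size // num_attention_heads.
flat = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
assert _lm_head_dim(flat) == 128

# VL-style config nesting the language model under text_config: head_dim read from the sub-config.
vl = SimpleNamespace(
    text_config=SimpleNamespace(head_dim=64, hidden_size=2048, num_attention_heads=16)
)
assert _lm_head_dim(vl) == 64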
1 change: 1 addition & 0 deletions modelopt/torch/puzzletron/sewing_kit/passage.py
@@ -45,6 +45,7 @@
"PassageOutput",
"Predicate",
"always_false_predicate",
"always_true_predicate",
"Passage",
"patch_module",
]