diff --git a/docs/source/guides/_quant_cfg.rst b/docs/source/guides/_quant_cfg.rst
index 6a027b74b00..4223f29e610 100644
--- a/docs/source/guides/_quant_cfg.rst
+++ b/docs/source/guides/_quant_cfg.rst
@@ -55,9 +55,12 @@ Each entry in the list is a dictionary with the following fields:
         (e.g. ``"nn.Linear"``). If omitted, all modules are targeted regardless of class.
    * - ``cfg``
      - No
-     - A dict of quantizer attributes as defined by :class:`QuantizerAttributeConfig
-       <modelopt.torch.quantization.config.QuantizerAttributeConfig>`, or a list of such dicts
-       for sequential quantization (see :ref:`sequential-quantizers`).
+     - A :class:`QuantizerAttributeConfig
+       <modelopt.torch.quantization.config.QuantizerAttributeConfig>`, or a list of
+       ``QuantizerAttributeConfig`` objects for sequential quantization (see
+       :ref:`sequential-quantizers`). Equivalent Python dicts, YAML mappings, and lists of
+       dicts are still accepted for backward compatibility, but those weakly schematized forms
+       are deprecated.
    * - ``enable``
      - No
      - ``True`` or ``False``. Toggles matched quantizers on or off, independently of ``cfg``.
@@ -74,6 +77,11 @@ Each entry in the list is a dictionary with the following fields:
    a bare ``{"quantizer_name": "*"}`` would silently behave as ``enable=True`` for all
    quantizers.
 
+   Schema-backed YAML loading parses ``cfg`` mappings into
+   :class:`QuantizerAttributeConfig <modelopt.torch.quantization.config.QuantizerAttributeConfig>`
+   values. Plain Python dicts and lists of dicts are accepted only as a backward-compatible,
+   weakly schematized input format.
+
 ----------
 
 Default Quantizer Configuration
@@ -278,7 +286,9 @@ For entirely custom recipes, compose the list directly:
 Sequential Quantization
 =======================
 
-When ``cfg`` is a **list** of attribute dicts, the matched
+When ``cfg`` is a **list** of
+:class:`QuantizerAttributeConfig <modelopt.torch.quantization.config.QuantizerAttributeConfig>`
+values, the matched
 :class:`TensorQuantizer <modelopt.torch.quantization.nn.TensorQuantizer>`
 is replaced with a
 :class:`SequentialQuantizer <modelopt.torch.quantization.nn.SequentialQuantizer>`
@@ -295,6 +305,11 @@ are quantized first in INT4 and then in FP8:
     ],
 }
 
+The list-of-dict spelling shown above remains accepted for existing Python configs and is the
+natural YAML spelling, but it is a deprecated, weakly schematized input form. After schema-backed
+loading or :class:`QuantizeConfig <modelopt.torch.quantization.config.QuantizeConfig>` parsing,
+each element is a ``QuantizerAttributeConfig``.
+
 ----------
 
 .. _migrating-from-dict-format:
diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py
index e15b8c7ba3c..fada8bc8cd9 100644
--- a/examples/diffusers/quantization/config.py
+++ b/examples/diffusers/quantization/config.py
@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Mapping, MutableMapping
+
 import torch.nn as nn
 from calib.plugin_calib import PercentileCalibrator
 
+from modelopt.torch.quantization.config import QuantizerAttributeConfig
+
 FP8_DEFAULT_CONFIG = {
     "quant_cfg": [
         {"quantizer_name": "*", "enable": False},
@@ -104,8 +108,13 @@ def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, **
         quant_config["algorithm"] = algo_cfg
 
     for entry in quant_config["quant_cfg"]:
-        p = entry.get("cfg", {})
-        if isinstance(p, dict) and "num_bits" in p and "trt_high_precision_dtype" not in p:
+        p = entry.get("cfg", {}) if isinstance(entry, Mapping) else {}
+        if not isinstance(p, MutableMapping):
+            continue
+        keys = p.explicit_keys() if isinstance(p, QuantizerAttributeConfig) else p.keys()
+        # TODO: Replace this membership-based config patching with a better config API;
+        # ``in``/``not in`` checks are fragile with schema-backed defaults.
+        if "num_bits" in keys and "trt_high_precision_dtype" not in keys:
             p["trt_high_precision_dtype"] = trt_high_precision_dtype
 
 
diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py
index 2a3c947a2d6..07178f22f50 100644
--- a/examples/diffusers/quantization/quantize.py
+++ b/examples/diffusers/quantization/quantize.py
@@ -14,9 +14,11 @@
 # limitations under the License.
 
 import argparse
+import copy
 import logging
 import sys
 import time as time
+from collections.abc import Mapping
 from pathlib import Path
 from typing import Any
 
@@ -49,6 +51,7 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
 from modelopt.torch.export import export_hf_checkpoint
+from modelopt.torch.opt.config import ModeloptBaseConfig
 
 
 def setup_logging(verbose: bool = False) -> logging.Logger:
@@ -119,14 +122,6 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
                 base_cfg = mtq.INT8_SMOOTHQUANT_CFG
             else:
                 base_cfg = INT8_DEFAULT_CONFIG
-            if self.config.collect_method != CollectMethod.DEFAULT:
-                reset_set_int8_config(
-                    base_cfg,
-                    self.config.percentile,
-                    n_steps,
-                    collect_method=self.config.collect_method.value,
-                    backbone=backbone,
-                )
         elif self.config.format == QuantFormat.FP8:
             base_cfg = FP8_DEFAULT_CONFIG
         elif self.config.format == QuantFormat.FP4:
@@ -138,15 +133,33 @@
             raise NotImplementedError(f"Unknown format {self.config.format}")
 
         # Build a fresh config dict so we never mutate the global constants.
+        if isinstance(base_cfg, ModeloptBaseConfig):
+            base_cfg = base_cfg.model_dump(exclude_unset=True)
+
         base_cfg = copy.deepcopy(base_cfg)
+
+        if (
+            self.config.format == QuantFormat.INT8
+            and self.config.collect_method != CollectMethod.DEFAULT
+        ):
+            reset_set_int8_config(
+                base_cfg,
+                self.config.percentile,
+                n_steps,
+                collect_method=self.config.collect_method.value,
+                backbone=backbone,
+            )
+
         quant_cfg_list = list(base_cfg["quant_cfg"])
 
         if self.config.format == QuantFormat.FP4:
             for i, entry in enumerate(quant_cfg_list):
-                if isinstance(entry, dict) and "block_sizes" in entry.get("cfg", {}):
-                    new_block_sizes = {**entry["cfg"]["block_sizes"], -1: self.config.block_size}
+                cfg = entry.get("cfg", {}) if isinstance(entry, Mapping) else {}
+                block_sizes = cfg.get("block_sizes") if isinstance(cfg, Mapping) else None
+                if isinstance(block_sizes, Mapping):
+                    new_block_sizes = {**block_sizes, -1: self.config.block_size}
                     quant_cfg_list[i] = {
                         **entry,
-                        "cfg": {**entry["cfg"], "block_sizes": new_block_sizes},
+                        "cfg": {**cfg, "block_sizes": new_block_sizes},
                     }
 
         if self.config.quantize_mha:
diff --git a/examples/llm_autodeploy/run_auto_quantize.py b/examples/llm_autodeploy/run_auto_quantize.py
index ebd7c1090bb..f899300192a 100644
--- a/examples/llm_autodeploy/run_auto_quantize.py
+++ b/examples/llm_autodeploy/run_auto_quantize.py
@@ -21,10 +21,11 @@
 
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
+from modelopt.torch.quantization.config import QuantizeConfig
 from modelopt.torch.utils import create_forward_loop
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
 
-SUPPORT_QUANT_FORMAT = {
+SUPPORT_QUANT_FORMAT: dict[str, QuantizeConfig] = {
     "fp8": mtq.FP8_DEFAULT_CFG,
     "nvfp4": mtq.NVFP4_DEFAULT_CFG,
 }
@@ -87,7 +88,7 @@ def loss_func(output, data):
         data_loader=calib_dataloader,
         forward_step=lambda model, batch: model(**batch),
         loss_func=loss_func,
-        quantization_formats=[SUPPORT_QUANT_FORMAT[format] for format in qformat_list],
+        quantization_formats=[SUPPORT_QUANT_FORMAT[quant_format] for quant_format in qformat_list],
         num_calib_steps=len(calib_dataloader),
         num_score_steps=min(
             len(calib_dataloader), 128 // batch_size
diff --git a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py
index 26f3c9f8258..2223f3cf055 100644
--- a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py
+++ b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py
@@ -34,6 +34,7 @@
 """
 
 import json
+from collections.abc import Mapping
 from contextlib import ExitStack, contextmanager
 from pathlib import Path
 
@@ -304,7 +305,7 @@ def force_weight_quantizers_static(quant_cfg: list) -> None:
         qname = entry.get("quantizer_name", "")
         cfg = entry.get("cfg") or {}
         bs = cfg.get("block_sizes")
-        if "weight_quantizer" in qname and isinstance(bs, dict):
+        if "weight_quantizer" in qname and isinstance(bs, Mapping):
             quant_cfg[i] = {**entry, "cfg": {**cfg, "block_sizes": {**bs, "type": "static"}}}
 
 
diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index 9455157645c..99bdee253ef 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -23,6 +23,7 @@
 import shutil
 import sys
 import warnings
+from collections.abc import Mapping, MutableMapping
 from pathlib import Path
 from typing import Any
 
@@ -41,6 +42,8 @@
     ProcessorMixin,
 )
 
+from modelopt.torch.quantization.config import QuantizeConfig, QuantizerCfgEntry
+
 try:
     from huggingface_hub import snapshot_download
 except ImportError:
@@ -203,17 +206,17 @@ def calibrate_loop(_model):
 
 def build_quant_cfg(
     qformat,
-    quant_cfg,
+    quant_cfg: QuantizeConfig | Mapping[str, Any],
     awq_block_size,
     model_type,
    moe_calib_experts_ratio: float | None = None,
-) -> dict[str, Any]:
-    quant_cfg = copy.deepcopy(quant_cfg)
-    if "awq" in str(quant_cfg.get("algorithm")):
+) -> QuantizeConfig:
+    quant_cfg_obj: QuantizeConfig = QuantizeConfig.model_validate(copy.deepcopy(quant_cfg))
+    if "awq" in str(quant_cfg_obj.get("algorithm")):
         from modelopt.torch.quantization.config import find_quant_cfg_entry_by_path
 
         weight_quantizer_entry = find_quant_cfg_entry_by_path(
-            quant_cfg["quant_cfg"], "*weight_quantizer"
+            quant_cfg_obj["quant_cfg"], "*weight_quantizer"
         )
         weight_quantizer = weight_quantizer_entry.get("cfg") or {}
         if isinstance(weight_quantizer, list):
@@ -224,34 +227,38 @@ def build_quant_cfg(
 
     # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
     if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
-        quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
+        quant_cfg_obj["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
 
     if moe_calib_experts_ratio:
         assert 0 < moe_calib_experts_ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
-        if isinstance(quant_cfg["algorithm"], str):
-            quant_cfg["algorithm"] = {
-                "method": quant_cfg["algorithm"],
+        if isinstance(quant_cfg_obj["algorithm"], str):
+            quant_cfg_obj["algorithm"] = {
+                "method": quant_cfg_obj["algorithm"],
                 "moe_calib_experts_ratio": moe_calib_experts_ratio,
             }
-        elif isinstance(quant_cfg["algorithm"], dict):
-            quant_cfg["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio
+        elif isinstance(quant_cfg_obj["algorithm"], MutableMapping):
+            quant_cfg_obj["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio
         else:
             warnings.warn(
-                f"Quantization algorithm: {quant_cfg['algorithm']} does not support setting moe_calib_experts_ratio"
+                f"Quantization algorithm: {quant_cfg_obj['algorithm']} does not support setting moe_calib_experts_ratio"
             )
 
     # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
     if model_type == "gemma" and "int8_sq" in qformat:
-        quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
+        quant_cfg_obj["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
 
     if model_type == "phi4mm":
         # Only quantize the language model
-        quant_cfg["quant_cfg"].append({"quantizer_name": "*speech*", "enable": False})
-        quant_cfg["quant_cfg"].append({"quantizer_name": "*audio*", "enable": False})
-        quant_cfg["quant_cfg"].append({"quantizer_name": "*image*", "enable": False})
-        quant_cfg["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False})
+        quant_cfg_obj["quant_cfg"].extend(
+            [
+                QuantizerCfgEntry(quantizer_name="*speech*", enable=False),
+                QuantizerCfgEntry(quantizer_name="*audio*", enable=False),
+                QuantizerCfgEntry(quantizer_name="*image*", enable=False),
+                QuantizerCfgEntry(quantizer_name="*vision*", enable=False),
+            ]
+        )
 
-    return quant_cfg
+    return quant_cfg_obj
 
 
 def is_speculative(hf_config):
@@ -842,7 +849,7 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
 def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
     """Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath."""
     algorithm = quant_cfg.get("algorithm")
-    if not isinstance(algorithm, dict):
+    if not isinstance(algorithm, Mapping):
         return False
     return algorithm.get("layerwise_checkpoint_dir") is not None
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 758ed75aeed..76c8521e093 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -66,7 +66,11 @@
     save_expert_token_count_table,
 )
 from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
-from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
+from modelopt.torch.quantization.config import (
+    QuantizeConfig,
+    _default_disabled_quantizer_cfg,
+    need_calibration,
+)
 from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
 from modelopt.torch.quantization.utils import is_quantized
 from modelopt.torch.speculative.eagle.utils import (
@@ -89,18 +93,20 @@
 
 def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
     """Set use_constant_amax on KV cache quantizers.
 
-    Creates a new dict for the KV bmm quantizer config to avoid mutating shared references.
+    Updates the matched KV bmm quantizer entry in place.
""" - for i, entry in enumerate(quant_cfg): + for entry in quant_cfg: if entry.get("quantizer_name") != "*[kv]_bmm_quantizer": continue - cfg = entry.get("cfg") or {} - assert isinstance(cfg, dict) - quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}} + cfg = entry.get("cfg") + if cfg is None: + cfg = {} + cfg["use_constant_amax"] = True + entry["cfg"] = cfg break -QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { +QUANT_CFG_CHOICES: dict[str, QuantizeConfig] = { "int8": mtq.INT8_DEFAULT_CFG, "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG, diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py index 93ef21ea4d4..66c38268d2f 100644 --- a/examples/llm_ptq/multinode_ptq.py +++ b/examples/llm_ptq/multinode_ptq.py @@ -22,7 +22,6 @@ import time import warnings from pathlib import Path -from typing import Any import numpy as np import torch @@ -37,14 +36,14 @@ from modelopt.torch.export import get_model_type from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint -from modelopt.torch.quantization.config import need_calibration +from modelopt.torch.quantization.config import QuantizeConfig, need_calibration from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets # Constants RAND_SEED = 1234 -QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { +QUANT_CFG_CHOICES: dict[str, QuantizeConfig] = { "int8": mtq.INT8_DEFAULT_CFG, "int4_awq": mtq.INT4_AWQ_CFG, "fp8": mtq.FP8_DEFAULT_CFG, diff --git a/examples/vllm_serve/vllm_ptq_utils.py b/examples/vllm_serve/vllm_ptq_utils.py index 88b31d54a70..048b1e8722e 100644 --- a/examples/vllm_serve/vllm_ptq_utils.py +++ b/examples/vllm_serve/vllm_ptq_utils.py @@ -14,7 +14,7 @@ # limitations under the License. 
 
 import dataclasses
-from collections.abc import Callable
+from collections.abc import Callable, Mapping
 from typing import Any
 
 import torch
@@ -122,7 +122,7 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: list) -> list:
         (
             e
             for e in kv_quant_cfg
-            if isinstance(e, dict) and e.get("quantizer_name") == "*[kv]_bmm_quantizer"
+            if isinstance(e, Mapping) and e.get("quantizer_name") == "*[kv]_bmm_quantizer"
         ),
         None,
     )
diff --git a/modelopt/onnx/llm_export_utils/quantization_utils.py b/modelopt/onnx/llm_export_utils/quantization_utils.py
index 54ca93d5388..3f5aa1d910f 100644
--- a/modelopt/onnx/llm_export_utils/quantization_utils.py
+++ b/modelopt/onnx/llm_export_utils/quantization_utils.py
@@ -69,9 +69,7 @@ def get_quant_config(precision, lm_head_precision="fp16"):
     else:
         raise ValueError(f"Unsupported precision: {precision}")
 
-    quant_cfg_list: list = [
-        e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_name" in e
-    ]
+    quant_cfg_list: list = list(quant_cfg["quant_cfg"])
 
     if lm_head_precision == "fp8":
         quant_cfg_list.append(
diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py
index 96f33012afd..59a0c69c6de 100644
--- a/modelopt/recipe/config.py
+++ b/modelopt/recipe/config.py
@@ -19,9 +19,6 @@
 
 from enum import Enum
 
-from pydantic import field_validator
-from typing_extensions import NotRequired, TypedDict
-
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
 from modelopt.torch.quantization.config import QuantizeConfig
 
@@ -33,14 +30,18 @@ class RecipeType(str, Enum):
     # QAT = "qat"  # Not implemented yet, will be added in the future.
 
 
-class RecipeMetadataConfig(TypedDict):
-    """YAML shape of the recipe metadata section."""
+_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe."
 
-    recipe_type: RecipeType
-    description: NotRequired[str]
 
+class RecipeMetadataConfig(ModeloptBaseConfig):
+    """YAML shape of the recipe metadata section."""
 
-_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe."
+    recipe_type: RecipeType
+    description: str = ModeloptField(
+        default=_DEFAULT_RECIPE_DESCRIPTION,
+        title="Description",
+        description="Human-readable recipe description.",
+    )
 
 
 class ModelOptRecipeBase(ModeloptBaseConfig):
@@ -50,32 +51,23 @@ class ModelOptRecipeBase(ModeloptBaseConfig):
     """
 
     metadata: RecipeMetadataConfig = ModeloptField(
-        default={"recipe_type": RecipeType.PTQ, "description": _DEFAULT_RECIPE_DESCRIPTION},
+        default=RecipeMetadataConfig(
+            recipe_type=RecipeType.PTQ, description=_DEFAULT_RECIPE_DESCRIPTION
+        ),
         title="Metadata",
         description="Recipe metadata containing the recipe type and description.",
         validate_default=True,
     )
 
-    @field_validator("metadata")
-    @classmethod
-    def validate_metadata(cls, metadata: RecipeMetadataConfig) -> RecipeMetadataConfig:
-        """Validate recipe metadata and fill defaults for optional fields."""
-        if metadata["recipe_type"] not in RecipeType:
-            raise ValueError(
-                f"Unsupported recipe type: {metadata['recipe_type']}. "
-                f"Only {list(RecipeType)} are currently supported."
-            )
-        return {"description": _DEFAULT_RECIPE_DESCRIPTION, **metadata}
-
     @property
     def recipe_type(self) -> RecipeType:
         """Return the recipe type from metadata."""
-        return self.metadata["recipe_type"]
+        return self.metadata.recipe_type
 
     @property
     def description(self) -> str:
         """Return the recipe description from metadata."""
-        return self.metadata.get("description", _DEFAULT_RECIPE_DESCRIPTION)
+        return self.metadata.description
 
 
 class ModelOptPTQRecipe(ModelOptRecipeBase):
diff --git a/modelopt/recipe/loader.py b/modelopt/recipe/loader.py
index 9c3c40856d2..e77419e4a24 100644
--- a/modelopt/recipe/loader.py
+++ b/modelopt/recipe/loader.py
@@ -22,7 +22,7 @@
 from pathlib import Path
 
 from modelopt.torch.opt.config_loader import BUILTIN_CONFIG_ROOT as BUILTIN_RECIPES_LIB
-from modelopt.torch.opt.config_loader import load_config
+from modelopt.torch.opt.config_loader import _load_raw_config_with_schema, load_config
 from modelopt.torch.quantization.config import QuantizeConfig
 
 from .config import ModelOptPTQRecipe, ModelOptRecipeBase, RecipeMetadataConfig, RecipeType
@@ -89,13 +89,13 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas
     The file must contain a ``metadata`` section with at least ``recipe_type``, plus a
     ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes.
     """
-    data = load_config(recipe_file, schema_type=ModelOptPTQRecipe)
-    if not isinstance(data, dict):
+    raw_data = _load_raw_config_with_schema(recipe_file).data
+    if not isinstance(raw_data, dict):
         raise ValueError(
-            f"Recipe file {recipe_file} must be a YAML mapping, got {type(data).__name__}."
+            f"Recipe file {recipe_file} must be a YAML mapping, got {type(raw_data).__name__}."
         )
 
-    metadata = data.get("metadata", {})
+    metadata = raw_data.get("metadata", {})
     if not isinstance(metadata, dict):
         raise ValueError(
             f"Recipe file {recipe_file} field 'metadata' must be a mapping, "
@@ -106,12 +106,9 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas
         raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.")
 
     if recipe_type == RecipeType.PTQ:
-        if "quantize" not in data:
+        if "quantize" not in raw_data:
             raise ValueError(f"PTQ recipe file {recipe_file} must contain 'quantize'.")
-        return ModelOptPTQRecipe(
-            metadata=metadata,
-            quantize=data["quantize"],
-        )
+        return load_config(recipe_file, schema_type=ModelOptPTQRecipe)
 
     raise ValueError(f"Unsupported recipe type: {recipe_type!r}")
 
@@ -139,21 +136,11 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase:
 
     metadata_file = _find_recipe_section_file(recipe_dir, "metadata")
     metadata = load_config(metadata_file, schema_type=RecipeMetadataConfig)
-    if not isinstance(metadata, dict):
-        raise ValueError(
-            f"Metadata file {metadata_file} must be a YAML mapping, got {type(metadata).__name__}."
-        )
-    recipe_type = metadata.get("recipe_type")
-    if recipe_type is None:
-        raise ValueError(f"Metadata file {metadata_file} must contain a 'recipe_type' field.")
+    recipe_type = metadata.recipe_type
 
     if recipe_type == RecipeType.PTQ:
         quantize_file = _find_recipe_section_file(recipe_dir, "quantize")
         quantize_data = load_config(quantize_file, schema_type=QuantizeConfig)
-        if not isinstance(quantize_data, dict):
-            raise ValueError(
-                f"{quantize_file} must be a YAML mapping, got {type(quantize_data).__name__}."
-            )
         return ModelOptPTQRecipe(
             metadata=metadata,
             quantize=quantize_data,
diff --git a/modelopt/torch/opt/config.py b/modelopt/torch/opt/config.py
index 62f7b7e16a2..00279dd10d2 100644
--- a/modelopt/torch/opt/config.py
+++ b/modelopt/torch/opt/config.py
@@ -17,7 +17,7 @@
 
 import fnmatch
 import json
-from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView
+from collections.abc import Callable, ItemsView, Iterator, KeysView, MutableMapping, ValuesView
 from typing import Any, TypeAlias
 
 import torch
@@ -57,11 +57,18 @@ def ModeloptField(default: Any = PydanticUndefined, **kwargs):  # noqa: N802
 
 
 # TODO: expand config classes to searcher
-class ModeloptBaseConfig(BaseModel):
+class ModeloptBaseConfig(BaseModel, MutableMapping[str, Any]):
     """Our config base class for mode configuration.
 
-    The base class extends the capabilities of pydantic's BaseModel to provide additional methods
-    and properties for easier access and manipulation of the configuration.
+    The base class extends pydantic's BaseModel with a mapping interface so schema-backed
+    config objects can keep the dict-style access patterns used by older ModelOpt code.
+
+    This is intentionally a fixed-key mutable mapping instead of a general dict. The mapping
+    keys are the model fields exposed through their aliases when present, and lookups accept
+    either a field name or its alias. Values are read and written through the pydantic model, so
+    assignment validation still applies. New keys cannot be inserted, and existing keys cannot
+    be deleted because the schema defines the complete key set; callers that need omission
+    semantics should use model_dump(exclude_unset=True) or the explicit_* helpers.
     """
 
     model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True)
@@ -111,25 +118,42 @@ def __contains__(self, key: str) -> bool:
 
     def __getitem__(self, key: str) -> Any:
         """Get the value for the given key (can be name or alias of field)."""
-        return getattr(self, self.get_field_name_from_key(key))
+        try:
+            return getattr(self, self.get_field_name_from_key(key))
+        except AttributeError as e:
+            raise KeyError(key) from e
 
     def __setitem__(self, key: str, value: Any) -> None:
-        """Set the value for the given key (can be name or alias of field)."""
-        setattr(self, self.get_field_name_from_key(key), value)
+        """Set an existing field by name or alias, preserving pydantic assignment validation."""
+        try:
+            field_name = self.get_field_name_from_key(key)
+        except AttributeError as e:
+            raise KeyError(key) from e
+        if field_name not in type(self).model_fields:
+            raise KeyError(key)
+        setattr(self, field_name, value)
+
+    def __delitem__(self, key: str) -> None:
+        """Reject deletion because ModeloptBaseConfig exposes a fixed schema key set."""
+        try:
+            self.get_field_name_from_key(key)
+        except AttributeError as e:
+            raise KeyError(key) from e
+        raise TypeError("Config mapping keys are fixed and cannot be deleted.")
 
     def get(self, key: str, default: Any = None) -> Any:
         """Get the value for the given key (can be name or alias) or default if not found."""
         try:
             return self[key]
-        except AttributeError:
+        except KeyError:
             return default
 
     def __len__(self) -> int:
-        """Return the length of the config."""
-        return len(self.model_fields) + len(self._iterable_model_extra)
+        """Return the number of schema and extra keys exposed by the mapping."""
+        return len(type(self).model_fields) + len(self._iterable_model_extra)
 
     def __iter__(self) -> Iterator[str]:
-        """Iterate over aliases (or name if alias is not defined) of fields."""
+        """Iterate over schema keys, preferring aliases over field names."""
         for field_name, field_info in type(self).model_fields.items():
             yield field_info.alias or field_name
         yield from self._iterable_model_extra
@@ -138,6 +162,29 @@ def _get_kv_dict(self) -> dict[str, Any]:
         """Return a dictionary with keys as aliases if possible."""
         return {k: self[k] for k in self}
 
+    def iter_explicit_keys(self) -> Iterator[str]:
+        """Iterate over explicitly set schema keys, preferring aliases over field names."""
+        for field_name, field_info in type(self).model_fields.items():
+            if field_name in self.model_fields_set:
+                yield field_info.alias or field_name
+        yield from self._iterable_model_extra
+
+    def _get_explicit_kv_dict(self) -> dict[str, Any]:
+        """Return explicitly set key-value pairs with keys as aliases if possible."""
+        return {k: self[k] for k in self.iter_explicit_keys()}
+
+    def explicit_keys(self) -> KeysView[str]:
+        """Return the explicitly set keys of the config."""
+        return self._get_explicit_kv_dict().keys()
+
+    def explicit_values(self) -> ValuesView[Any]:
+        """Return the explicitly set values of the config."""
+        return self._get_explicit_kv_dict().values()
+
+    def explicit_items(self) -> ItemsView[str, Any]:
+        """Return the explicitly set items of the config with keys as aliases if possible."""
+        return self._get_explicit_kv_dict().items()
+
     def keys(self) -> KeysView[str]:
         """Return the keys (aliases prioritized over names) of the config."""
         return self._get_kv_dict().keys()
diff --git a/modelopt/torch/opt/config_loader.py b/modelopt/torch/opt/config_loader.py
index 43231c90995..e6c3389d1b8 100644
--- a/modelopt/torch/opt/config_loader.py
+++ b/modelopt/torch/opt/config_loader.py
@@ -33,12 +33,26 @@
 import re
 import sys
 from pathlib import Path
-from typing import Any, Union, get_args, get_origin, get_type_hints
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    TypeVar,
+    Union,
+    get_args,
+    get_origin,
+    get_type_hints,
+    overload,
+)
 
 import yaml
 from pydantic import TypeAdapter
 from typing_extensions import NotRequired, Required, is_typeddict
 
+if TYPE_CHECKING:
+    from modelopt.torch.opt.config import ModeloptBaseConfig
+
+_ModeloptConfigT = TypeVar("_ModeloptConfigT", bound="ModeloptBaseConfig")
+
 
 @dataclass
 class _ListSnippet:
@@ -47,7 +61,7 @@ class _ListSnippet:
     YAML requires one root node per document, so a file that is "a list with
     an ``imports`` section" has to use two documents separated by ``---``.
     This wrapper is the internal transport carrying both pieces from
-    :func:`_load_raw_config` to :func:`_resolve_imports` without smuggling them
+    :func:`_load_raw_config_with_schema` to :func:`_resolve_imports` without smuggling them
     through a sentinel dict key (which would collide if a user happened to
     choose the same key name).
     """
@@ -113,7 +127,7 @@ def _resolve_config_path(config_file: str | Path | Traversable) -> Path | Traver
     built-in package resources return a ``Traversable``. Raises
     ``ValueError`` if no candidate exists.
 
-    Factored out of :func:`_load_raw_config` so :func:`_resolve_imports` can
+    Factored out of :func:`_load_raw_config_with_schema` so :func:`_resolve_imports` can
     compute a canonical cycle-detection key without reading the file twice.
     """
     # Probe order: filesystem first, then built-in library.
@@ -243,13 +257,6 @@ def _load_raw_config_with_schema(config_file: str | Path | Traversable) -> _RawC
     )
 
 
-def _load_raw_config(
-    config_file: str | Path | Traversable,
-) -> dict[str, Any] | list[Any] | _ListSnippet:
-    """Load a config YAML without resolving ``$import`` references."""
-    return _load_raw_config_with_schema(config_file).data
-
-
 _IMPORT_KEY = "$import"
 
 
@@ -334,7 +341,19 @@ def _schema_equal(left: Any | None, right: Any | None) -> bool:
 def _list_element_schema(schema_type: Any | None) -> Any | None:
     """Return the element schema for a typed ``list[T]`` annotation."""
     schema_type = _unwrap_schema_type(schema_type)
-    if get_origin(schema_type) is not list:
+    origin = get_origin(schema_type)
+    if origin in (UnionType, Union):
+        element_schemas = []
+        for arg in get_args(schema_type):
+            if arg is NoneType:
+                continue
+            element_schema = _list_element_schema(arg)
+            if element_schema is None:
+                continue
+            if not any(_schema_equal(element_schema, seen) for seen in element_schemas):
+                element_schemas.append(element_schema)
+        return element_schemas[0] if len(element_schemas) == 1 else None
+    if origin is not list:
         return None
     args = get_args(schema_type)
     if len(args) != 1 or args[0] is Any:
@@ -373,10 +392,10 @@ def _validate_modelopt_schema(
     data: Any,
     config_path: Any,
     schema_type: Any | None = None,
-) -> None:
-    """Validate resolved config content against the requested schema without mutating it."""
+) -> Any:
+    """Validate resolved config content and return the Pydantic-normalized value."""
     if schema_type is None and not schema_path:
-        return
+        return data
     if schema_type is None:
         assert schema_path is not None
         schema_type = _schema_type(schema_path)
@@ -384,7 +403,7 @@ def _validate_modelopt_schema(
         # TypeAdapter validates the schema types we allow here: BaseModel classes
         # plus regular typing constructs such as TypedDict, list[TypedDict], unions,
         # and aliases. Schema comments are not treated as arbitrary validators.
-        TypeAdapter(schema_type).validate_python(data)
+        return TypeAdapter(schema_type).validate_python(data)
     except Exception as exc:
         raise ValueError(
             f"Config file {config_path} does not match modelopt-schema "
@@ -508,6 +527,12 @@ def _resolve_list_import(
     if _schema_equal(imported.schema_type, element_schema):
         return [imported.data]
 
+    element_schema_unwrapped = _unwrap_schema_type(element_schema)
+    if isinstance(imported.data, dict) and (
+        element_schema_unwrapped is dict or get_origin(element_schema_unwrapped) is dict
+    ):
+        return [imported.data]
+
     raise ValueError(
         f"$import {ref_name!r} in list at {context} has schema "
         f"{_schema_label(imported.schema_type, imported.schema)!r}; expected either "
@@ -592,22 +617,52 @@ def _find_import_marker(obj: Any, context: str = "root") -> tuple[Any, str] | No
     return None
 
 
+# Concrete ModeloptBaseConfig subclasses are returned as that exact parsed type.
+@overload
 def load_config(
     config_path: str | Path | Traversable,
     *,
-    schema_type: Any | None = None,
-) -> dict[str, Any] | list[Any]:
+    schema_type: type[_ModeloptConfigT],
+) -> _ModeloptConfigT: ...
+
+
+# Typed list schemas, such as list[SomeModeloptConfig], validate each element
+# and return a list of parsed config objects.
+@overload
+def load_config(
+    config_path: str | Path | Traversable,
+    *,
+    schema_type: type[list[_ModeloptConfigT]],
+) -> list[_ModeloptConfigT]: ...
+
+
+# Without an explicit schema_type, untyped files return raw dict/list payloads;
+# files with modelopt-schema comments still return the validated schema value.
+@overload
+def load_config(
+    config_path: str | Path | Traversable,
+    *,
+    schema_type: None = None,
+) -> "ModeloptBaseConfig | dict[str, Any] | list[Any]": ...
+
+
+def load_config(
+    config_path: str | Path | Traversable,
+    *,
+    schema_type: object | None = None,
+) -> "ModeloptBaseConfig | dict[str, Any] | list[Any]":
     """Load a YAML config and resolve all ``$import`` references.
 
     This is the primary config loading entry point. It loads the YAML file,
     resolves any ``imports`` / ``$import`` directives, and returns the final
-    config dict or list.
+    config.
 
     ``schema_type`` supplies a typing context for import resolution when the
-    file itself has no ``modelopt-schema`` comment. It is intentionally not a
-    request to validate the top-level file. Top-level files are validated only
-    when they declare ``modelopt-schema``; imported snippets are stricter and
-    must always declare ``modelopt-schema``.
+    file itself has no ``modelopt-schema`` comment. If either ``schema_type``
+    or a ``modelopt-schema`` comment is present, the resolved top-level payload
+    is returned as the Pydantic-normalized value for that schema. Untyped files
+    return the resolved dict or list. Imported snippets are stricter and must
+    always declare ``modelopt-schema``.
     """
     raw = _load_raw_config_with_schema(config_path)
     data = raw.data
@@ -616,5 +671,8 @@
     if isinstance(data, (_ListSnippet, dict)):
         data = _resolve_imports(data, schema_type=resolver_schema_type)
 
-    _validate_modelopt_schema(raw.schema, data, raw.path, schema_type=declared_schema_type)
+    if declared_schema_type is not None or schema_type is not None:
+        return _validate_modelopt_schema(
+            raw.schema, data, raw.path, schema_type=declared_schema_type or schema_type
+        )
     return data
diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py
index f1db2df9e84..b9b1c9f8726 100644
--- a/modelopt/torch/quantization/algorithms.py
+++ b/modelopt/torch/quantization/algorithms.py
@@ -21,7 +21,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Mapping, Sequence
 from contextlib import nullcontext
 from typing import Any
 
@@ -65,7 +65,7 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):
         if not quantizer_attr_cfg:
             return 1.0
         return min(estimate_quant_compression_for_quantizer(q) for q in quantizer_attr_cfg)
-    if isinstance(quantizer_attr_cfg, dict):
+    if isinstance(quantizer_attr_cfg, Mapping):
         # Handle raw quantizer cfg dicts (e.g. {"num_bits": (4, 3), "axis": None})
         if not quantizer_attr_cfg.get("enable", True):
             return 1.0
@@ -103,47 +103,72 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg):
     return estimate_quant_compression_for_quantizer(cfgs) if cfgs else 1.0
 
 
+QuantRecipeConfig = str | Mapping[str, Any] | QuantizeConfig | None
+
+
 class QuantRecipe(CustomHPType):
     """A subclass of QuantizeConfig enabling auto_quantize specific configurations.
 
     Args:
-        quant_cfg: str or dict or None. dict is used for custom quantization formats.
+        quant_cfg: str, QuantizeConfig, mapping, or None. A mapping is used for custom quantization formats.
         name: name for custom quantization formats. Only used if quantization format is a
             custom format not available in :mod:`modelopt.torch.quantization.config`.
""" - def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | None = None): + def __init__(self, quant_cfg: QuantRecipeConfig = None, name: str | None = None): """Initialize the QuantRecipe with the quantization configuration.""" name = self.get_auto_name_for_config(quant_cfg) or name if quant_cfg is None: - quant_cfg = {"quant_cfg": [{"quantizer_name": "*", "enable": False}]} - elif isinstance(quant_cfg, str): - assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}" - quant_cfg = getattr(mtq_config, quant_cfg) + self.config = mtq_config.QuantizeConfig( + quant_cfg=[mtq_config.QuantizerCfgEntry(quantizer_name="*", enable=False)] + ) else: - assert name is not None, "name must be provided for custom quantization formats" - - self.config = mtq_config.QuantizeConfig(**quant_cfg) # type: ignore [arg-type] + if isinstance(quant_cfg, str): + assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}" + quant_cfg = getattr(mtq_config, quant_cfg) + elif not isinstance(quant_cfg, QuantizeConfig) and name is None: + raise ValueError("name must be provided for custom quantization formats") + + self.config = ( + quant_cfg.model_copy(deep=True) + if isinstance(quant_cfg, QuantizeConfig) + else mtq_config.QuantizeConfig.model_validate(quant_cfg) + ) + if name is None: + raise ValueError("name must be provided for custom quantization formats") # Disable KV Cache quantization # Currently KV Cache quantization is enabled for some quantization formats and disabled for others # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy - self.config.quant_cfg.append({"quantizer_name": "*output_quantizer", "enable": False}) + self.config.quant_cfg.append( + mtq_config.QuantizerCfgEntry(quantizer_name="*output_quantizer", enable=False) + ) self.compression = estimate_quant_compression(self.config) self._str_repr: str = f"{name}(effective-bits: {self.compression * 16})" @staticmethod - def get_auto_name_for_config(quant_cfg: str | dict[str, Any] | None) -> str | None: + def get_auto_name_for_config(quant_cfg: QuantRecipeConfig) -> str | None: """Get a name for the quantization configuration.""" if quant_cfg is None: return "NONE" if isinstance(quant_cfg, str): return quant_cfg + + candidate = ( + quant_cfg + if isinstance(quant_cfg, QuantizeConfig) + else mtq_config.QuantizeConfig.model_validate(quant_cfg) + ) for quant_cfg_name in mtq_config.choices: - if quant_cfg == getattr(mtq_config, quant_cfg_name): + choice = getattr(mtq_config, quant_cfg_name) + try: + choice = mtq_config.QuantizeConfig.model_validate(choice) + except Exception: + continue + if candidate == choice: return quant_cfg_name return None diff --git a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py index d89ed35c6ca..b2bdb3323b5 100644 --- a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py +++ b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py @@ -15,9 +15,12 @@ """This module provides a GEMM function for fp8 per tensor quantization.""" +from collections.abc import Mapping + import torch from torch.autograd import Function +from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.quantization.backends.gemm_registry import gemm_registry from modelopt.torch.quantization.config import FP8_DEFAULT_CFG, find_quant_cfg_entry_by_path from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear @@ -105,7 
+108,7 @@ def _fp8_availability_check(module, input, args, kwargs): input_cfg = input_cfg[0] if isinstance(weight_cfg, list): weight_cfg = weight_cfg[0] - if not isinstance(input_cfg, dict) or not isinstance(weight_cfg, dict): + if not isinstance(input_cfg, Mapping) or not isinstance(weight_cfg, Mapping): return False # Check hardware support @@ -119,9 +122,18 @@ def _fp8_availability_check(module, input, args, kwargs): # Check quantizer presence and configuration if not hasattr(module, "input_quantizer") or not hasattr(module, "weight_quantizer"): return False + if not module.input_quantizer.is_enabled or not module.weight_quantizer.is_enabled: + return False # Check input quantizer config - for key, value in input_cfg.items(): + # TODO: Move this compatibility check inside the quantizer; matching config items here + # is fragile and easy to break as config semantics evolve. + input_items = input_cfg.items + if isinstance(input_cfg, ModeloptBaseConfig): + input_items = input_cfg.explicit_items + for key, value in input_items(): + if key == "enable": + continue if ( not hasattr(module.input_quantizer, key) or getattr(module.input_quantizer, key) != value @@ -129,7 +141,14 @@ def _fp8_availability_check(module, input, args, kwargs): return False # Check weight quantizer config - for key, value in weight_cfg.items(): + # TODO: Move this compatibility check inside the quantizer; matching config items here + # is fragile and easy to break as config semantics evolve. + weight_items = weight_cfg.items + if isinstance(weight_cfg, ModeloptBaseConfig): + weight_items = weight_cfg.explicit_items + for key, value in weight_items(): + if key == "enable": + continue if ( not hasattr(module.weight_quantizer, key) or getattr(module.weight_quantizer, key) != value diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py index fdf6babb695..f431ad63c1e 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -15,10 +15,13 @@ """This module provides a GEMM function for nvfp4 quantization.""" +from collections.abc import Mapping + import torch from torch.autograd import Function import modelopt.torch.quantization as mtq +from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.quantization.backends.gemm_registry import gemm_registry from modelopt.torch.quantization.backends.utils import fp4_compatible from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear @@ -224,11 +227,19 @@ def _nvfp4_availability_check(module, input, args, kwargs): input_cfg = input_cfg[0] if isinstance(weight_cfg, list): weight_cfg = weight_cfg[0] - if not isinstance(input_cfg, dict) or not isinstance(weight_cfg, dict): + if not isinstance(input_cfg, Mapping) or not isinstance(weight_cfg, Mapping): return False # Check input quantizer config - for key, value in input_cfg.items(): + # TODO: Move this compatibility check inside the quantizer; matching config items here + # is fragile and easy to break as config semantics evolve. 
+    if not module.input_quantizer.is_enabled or not module.weight_quantizer.is_enabled:
+        return False
+
+    input_items = input_cfg.items
+    if isinstance(input_cfg, ModeloptBaseConfig):
+        input_items = input_cfg.explicit_items
+    for key, value in input_items():
         if key == "enable":
             continue
         if (
@@ -238,7 +249,12 @@ def _nvfp4_availability_check(module, input, args, kwargs):
             return False
 
     # Check weight quantizer config
-    for key, value in weight_cfg.items():
+    # TODO: Move this compatibility check inside the quantizer; matching config items here
+    # is fragile and easy to break as config semantics evolve.
+    weight_items = weight_cfg.items
+    if isinstance(weight_cfg, ModeloptBaseConfig):
+        weight_items = weight_cfg.explicit_items
+    for key, value in weight_items():
         if key == "enable":
             continue
         if (
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 3adb70cf6b7..34cce5c1bed 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -71,8 +71,10 @@
 - ``parent_class`` *(optional)*: restricts matching to quantizers whose immediate parent module
   is of this PyTorch class (e.g. ``"nn.Linear"``). If omitted, all matching quantizers are
   targeted regardless of their parent class.
-- ``cfg`` *(optional)*: a dict of quantizer attributes as defined by
-  :class:`QuantizerAttributeConfig`, or a list of such dicts. When a list is given, the matched
+- ``cfg`` *(optional)*: a :class:`QuantizerAttributeConfig`, or a list of
+  :class:`QuantizerAttributeConfig` objects. Equivalent dict and list-of-dict inputs are accepted
+  for backward compatibility, but are deprecated weakly schematized forms. When a list is given,
+  the matched
   :class:`TensorQuantizer <modelopt.torch.quantization.nn.TensorQuantizer>`
   is replaced with a
   :class:`SequentialQuantizer <modelopt.torch.quantization.nn.SequentialQuantizer>`
@@ -150,60 +152,16 @@
 """
 
-import copy
 import warnings
+from collections.abc import Mapping, Sequence
 from typing import Any, Literal, cast
 
 from pydantic import ValidationInfo, field_validator, model_validator
-from typing_extensions import Required, TypedDict
 
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
 from modelopt.torch.opt.config_loader import load_config
 from modelopt.torch.utils.network import ConstructorLike
 
-
-class QuantizerCfgEntry(TypedDict, total=False):
-    """A single entry in a ``quant_cfg`` list."""
-
-    quantizer_name: Required[str]  # matched against quantizer module names
-    parent_class: str | None  # optional; filters by pytorch module class name (e.g. "nn.Linear")
-    cfg: dict[str, Any] | list[dict[str, Any]] | None  # quantizer attribute config(s)
-    enable: bool | None  # toggles matched quantizers on/off; independent of cfg
-
-
-def find_quant_cfg_entry_by_path(
-    quant_cfg_list: list[QuantizerCfgEntry], quantizer_name: str
-) -> QuantizerCfgEntry:
-    """Find the last entry in a ``quant_cfg`` list whose ``quantizer_name`` key equals the query.
-
-    This performs an **exact string comparison** against the ``quantizer_name`` field of each
-    entry — it does *not* apply ``fnmatch`` pattern matching. For example, passing
-    ``"*input_quantizer"`` will only match entries whose ``quantizer_name`` is literally
-    ``"*input_quantizer"``, not entries with a different wildcard that would match the same
-    module names at apply time.
-
-    Returns the *last* match because entries are applied in list order and later entries
-    override earlier ones, so the last match represents the effective configuration.
-
-    Args:
-        quant_cfg_list: A list of :class:`QuantizerCfgEntry` dicts.
-        quantizer_name: The exact ``quantizer_name`` string to search for.
-
-    Returns:
-        The last entry whose ``quantizer_name`` equals *quantizer_name*.
-
-    Raises:
-        KeyError: If no entry with the given ``quantizer_name`` is found.
-    """
-    result = None
-    for entry in quant_cfg_list:
-        if isinstance(entry, dict) and entry.get("quantizer_name") == quantizer_name:
-            result = entry
-    if result is None:
-        raise KeyError(f"No quant_cfg entry with quantizer_name={quantizer_name!r}")
-    return result
-
-
 BiasType = Literal["static", "dynamic"]
 BiasMethod = Literal["mean", "max_min"]
 
@@ -265,7 +223,7 @@ def _validate_recursive(value):
         if isinstance(value, list):
             for item in value:
                 _validate_recursive(item)
-        elif isinstance(value, dict):
+        elif isinstance(value, Mapping):
             if len(value) == 1 and "enable" in value and value["enable"] is True:
                 raise ValueError(
                     "Invalid quantizer config: Cannot specify only {'enable': True}. "
@@ -563,6 +521,127 @@ def validate_calibrator(cls, v, info: ValidationInfo):
         )
 
 
+class QuantizerCfgEntry(ModeloptBaseConfig):
+    """A single entry in a ``quant_cfg`` list."""
+
+    quantizer_name: str = ModeloptField(
+        default="",
+        title="Quantizer name pattern.",
+        description="Wildcard pattern matched against quantizer module names.",
+        validate_default=True,
+    )
+    parent_class: str | None = ModeloptField(
+        default=None,
+        title="Parent module class filter.",
+        description='Optional PyTorch module class name filter, e.g. "nn.Linear".',
+    )
+    cfg: QuantizerAttributeConfig | list[QuantizerAttributeConfig] | None = ModeloptField(
+        default=None,
+        title="Quantizer attributes.",
+        description=(
+            "Attributes to apply to matched quantizers. A list configures a sequential quantizer."
+        ),
+    )
+    enable: bool = ModeloptField(
+        default=True,
+        title="Quantizer enable flag.",
+        description="Optional on/off toggle for matched quantizers, independent of cfg.",
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_quantizer_cfg_entry(cls, values):
+        """Validate raw quant_cfg entry semantics before cfg is parsed."""
+        if not isinstance(values, Mapping):
+            return values
+
+        if "cfg" in values and values["cfg"] is None:
+            raise ValueError("cfg must be omitted or a valid mapping/list, not null.")
+        if "enable" in values and values["enable"] is None:
+            raise ValueError("enable must be a boolean when provided, not null.")
+
+        cfg = values.get("cfg")
+        enable = values.get("enable", True)
+        if enable is False and cfg in ({}, []):
+            values = dict(values)
+            values["cfg"] = None
+            return values
+
+        if enable and cfg is not None:
+            cls._validate_enabled_cfg(cfg)
+        return values
+
+    @field_validator("quantizer_name")
+    @classmethod
+    def validate_quantizer_name(cls, v):
+        """Validate quantizer_name is non-empty."""
+        if not v:
+            raise ValueError("quantizer_name must be a non-empty string.")
+        return v
+
+    @staticmethod
+    def _validate_enabled_cfg(cfg):
+        """Validate cfg has real quantizer attributes when enabling a quantizer."""
+        if isinstance(cfg, QuantizerAttributeConfig):
+            return
+        if isinstance(cfg, Mapping):
+            if len(cfg) == 0:
+                raise ValueError("cfg must be a non-empty dict when enabling a quantizer.")
+            return
+        if isinstance(cfg, list):
+            if len(cfg) == 0:
+                raise ValueError("cfg must be a non-empty list when enabling a quantizer.")
+            for item in cfg:
+                if isinstance(item, QuantizerAttributeConfig):
+                    continue
+                if not isinstance(item, Mapping) or len(item) == 0:
+                    raise ValueError(
+                        "cfg list entries must be QuantizerAttributeConfig or non-empty dicts "
+                        "when enabling a quantizer."
+                    )
+            return
+        raise ValueError(
+            "cfg must be QuantizerAttributeConfig, a non-empty dict, or a non-empty list "
+            "when enabling a quantizer."
+        )
+
+
+def find_quant_cfg_entry_by_path(
+    quant_cfg_list: Sequence[Any], quantizer_name: str
+) -> QuantizerCfgEntry | Mapping[str, Any]:
+    """Find the last entry in a ``quant_cfg`` list whose ``quantizer_name`` key equals the query.
+
+    This performs an **exact string comparison** against the ``quantizer_name`` field of each
+    entry - it does *not* apply ``fnmatch`` pattern matching. For example, passing
+    ``"*input_quantizer"`` will only match entries whose ``quantizer_name`` is literally
+    ``"*input_quantizer"``, not entries with a different wildcard that would match the same
+    module names at apply time.
+
+    Returns the *last* match because entries are applied in list order and later entries
+    override earlier ones, so the last match represents the effective configuration.
+
+    Args:
+        quant_cfg_list: A list of :class:`QuantizerCfgEntry` objects or legacy dicts.
+        quantizer_name: The exact ``quantizer_name`` string to search for.
+
+    Returns:
+        The last entry whose ``quantizer_name`` equals *quantizer_name*.
+
+    Raises:
+        KeyError: If no entry with the given ``quantizer_name`` is found.
+    """
+    result: QuantizerCfgEntry | Mapping[str, Any] | None = None
+    for entry in quant_cfg_list:
+        if isinstance(entry, QuantizerCfgEntry):
+            if entry.get("quantizer_name") == quantizer_name:
+                result = entry
+        elif isinstance(entry, Mapping) and entry.get("quantizer_name") == quantizer_name:
+            result = entry
+    if result is None:
+        raise KeyError(f"No quant_cfg entry with quantizer_name={quantizer_name!r}")
+    return result
+
+
 class QuantizeAlgorithmConfig(ModeloptBaseConfig):
     """Calibration algorithm config base."""
 
@@ -926,60 +1005,62 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig):
     )
 
 
-QuantizeQuantCfgType = list[QuantizerCfgEntry]
-QuantizerCfgListConfig = QuantizeQuantCfgType
+QuantizerCfgListConfig = list[QuantizerCfgEntry]
+QuantizeQuantCfgInputType = Mapping[str, Any] | Sequence[QuantizerCfgEntry | Mapping[str, Any]]
 _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None
 QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None
 
 
-def normalize_quant_cfg_list(v: dict | list) -> list[QuantizerCfgEntry]:
-    """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` dicts.
+def normalize_quant_cfg_list(
+    v: Mapping[str, Any] | Sequence[QuantizerCfgEntry | Mapping[str, Any]],
+) -> list[QuantizerCfgEntry]:
+    """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` objects.
 
     Supports the following input forms:
 
     - A ``list`` of entries in any of the per-entry forms below.
-    - A legacy flat ``dict`` (``{"*": ..., "*weight_quantizer": ...}``) — each key/value pair is
+    - A ``list`` containing :class:`QuantizerCfgEntry` objects, which are preserved as-is.
+    - A legacy flat ``dict`` (``{"*": ..., "*weight_quantizer": ...}``) - each key/value pair is
       converted to a single-key dict entry and then normalized.
 
     Per-entry forms (when input is a list):
 
-    - New format: ``{"quantizer_name": ..., "enable": ..., "cfg": ...}`` — passed through.
-    - Legacy single-key format: ``{"<pattern>": <cfg>}`` — converted to new format.
-    - Legacy ``nn.*``-scoped format: ``{"nn.<ClassName>": {"<pattern>": <cfg>}}`` — converted
+    - New format: ``{"quantizer_name": ..., "enable": ..., "cfg": ...}`` - passed through.
+    - Legacy single-key format: ``{"<pattern>": <cfg>}`` - converted to new format.
+    - Legacy ``nn.*``-scoped format: ``{"nn.<ClassName>": {"<pattern>": <cfg>}}`` - converted
       to a new-format entry with ``parent_class`` set.
 
-    **Validation** — an entry is rejected if it carries no instruction, i.e. it specifies neither
-    ``cfg`` nor ``enable``. Concretely, the following are invalid:
+    **Validation** - an entry is rejected if its shape is invalid. Concretely, the following
+    are invalid:
 
     - An empty entry ``{}``.
-    - An entry with only ``quantizer_name`` and no other keys — the only effect would be an
-      implicit ``enable=True``, which must be stated explicitly.
     - An entry with ``enable=True`` (explicit or implicit) whose ``cfg`` is not a non-empty
-      ``dict`` or ``list`` — e.g. ``{"quantizer_name": "*", "cfg": {}}`` or
+      ``dict`` or ``list`` - e.g. ``{"quantizer_name": "*", "cfg": {}}`` or
       ``{"quantizer_name": "*", "cfg": 42}``. An enabled quantizer must have a valid
       configuration.
 
-    **Normalization** — after conversion and validation every entry is put into canonical form:
-
-    - ``enable`` is set to ``True`` if not explicitly specified.
-    - ``cfg`` is set to ``None`` if not present in the entry.
-
-    Every returned entry is therefore guaranteed to have the keys ``quantizer_name``, ``enable``,
-    and ``cfg`` (plus optionally ``parent_class``).
+    **Normalization** - after conversion and validation every entry is parsed as a
+    :class:`QuantizerCfgEntry`. Schema defaults are available through mapping access, so ``enable``
+    defaults to ``True`` and ``cfg`` defaults to ``None`` when omitted. Omitted defaults are not
+    marked as explicitly set, so ``model_dump(exclude_unset=True)`` preserves the user's sparse
+    input shape. Typed :class:`QuantizerCfgEntry` inputs are assumed to be already parsed and are
+    preserved.
 
     Args:
         v: A list of raw quant_cfg entries in any supported format, or a legacy flat dict.
 
     Returns:
-        A list of :class:`QuantizerCfgEntry` dicts in canonical normalized form.
+        A list of :class:`QuantizerCfgEntry` objects in canonical normalized form. Existing
+        typed entries are preserved.
 
     Raises:
-        ValueError: If any entry has only ``quantizer_name`` with neither ``cfg`` nor ``enable``,
-            if ``enable=True`` with an empty or non-dict/list ``cfg``, or if the entry format
-            is not recognized.
+        ValueError: If ``enable=True`` with an empty or non-dict/list ``cfg``, or if the entry
+            format is not recognized.
     """
+    if isinstance(v, list) and all(isinstance(raw, QuantizerCfgEntry) for raw in v):
+        return cast("list[QuantizerCfgEntry]", v)
 
     def _warn_legacy():
         warnings.warn(
@@ -991,43 +1072,53 @@ def _warn_legacy():
             stacklevel=4,
         )
 
-    # Legacy flat-dict format: {"*": {...}, "*weight_quantizer": {...}} → list of single-key dicts.
-    if isinstance(v, dict):
+    # Legacy flat-dict format: {"*": {...}, "*weight_quantizer": {...}} -> list of single-key dicts.
+    if isinstance(v, Mapping):
         _warn_legacy()
         v = [{k: val} for k, val in v.items()]
 
-    def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]:
+    def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]:
         """Convert a single legacy key-value pair to one or more QuantizerCfgEntry dicts."""
         # Legacy "default" key was a catch-all applied as "*" in the old conversion code.
         if key == "default":
             key = "*"
 
         if isinstance(key, str) and key.startswith("nn."):
-            if not isinstance(value, dict):
+            if not isinstance(value, Mapping):
                 raise ValueError(f"For 'nn.*' scoped format, value must be a dict, got {value!r}")
             # Support multi-key nn.*-scoped dicts by emitting one entry per sub-key.
-            entries: list[QuantizerCfgEntry] = []
+            entries: list[dict[str, Any]] = []
             for q_path, sub_cfg in value.items():
-                sub_cfg = dict(sub_cfg)
-                enable = sub_cfg.pop("enable", None)
-                cfg = sub_cfg or None
-                entry: QuantizerCfgEntry = {
+                if isinstance(sub_cfg, QuantizerAttributeConfig):
+                    enable = None
+                    cfg = sub_cfg
+                elif isinstance(sub_cfg, Mapping):
+                    sub_cfg = dict(sub_cfg)
+                    enable = sub_cfg.pop("enable", None)
+                    cfg = sub_cfg or None
+                else:
+                    enable = None
+                    cfg = sub_cfg
+                entry: dict[str, Any] = {
                     "parent_class": key,
                     "quantizer_name": q_path,
-                    "cfg": cfg,
                 }
+                if cfg is not None:
+                    entry["cfg"] = cfg
                 if enable is not None:
                     entry["enable"] = enable
                 entries.append(entry)
             return entries
         else:
-            if isinstance(value, dict):
+            if isinstance(value, Mapping):
                 cfg = {k: val for k, val in value.items() if k != "enable"} or None
                 enable = value.get("enable")
             else:
                 cfg = value
                 enable = None
-            entry = {"quantizer_name": key, "cfg": cfg}
+            entry = {"quantizer_name": key}
+            if cfg is not None:
+                entry["cfg"] = cfg
             if enable is not None:
                 entry["enable"] = enable
             return [entry]
@@ -1035,16 +1126,23 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]:
     result: list[QuantizerCfgEntry] = []
     _warned_legacy = False
     for raw in v:
-        if isinstance(raw, dict) and "quantizer_name" in raw:
+        if isinstance(raw, QuantizerCfgEntry):
+            result.append(raw)
+            continue
+        if isinstance(raw, Mapping) and "quantizer_name" in raw:
             entries = [dict(raw)]  # copy to avoid mutating caller's data
-        elif isinstance(raw, dict) and len(raw) == 1:
+        elif isinstance(raw, Mapping) and len(raw) == 1:
             key, val = next(iter(raw.items()))
             entries = [dict(e) for e in _dict_to_entry(key, val)]
             if not _warned_legacy:
                 _warn_legacy()
                 _warned_legacy = True
-        elif isinstance(raw, dict) and len(raw) > 1 and any(k.startswith("nn.") for k in raw):
-            # Legacy flat dict with nn.*-scoped keys mixed with other keys — expand all pairs.
+        elif (
+            isinstance(raw, Mapping)
+            and len(raw) > 1
+            and any(isinstance(k, str) and k.startswith("nn.") for k in raw)
+        ):
+            # Legacy flat dict with nn.*-scoped keys mixed with other keys - expand all pairs.
             entries = []
             for k, val in raw.items():
                 entries.extend(dict(e) for e in _dict_to_entry(k, val))
@@ -1055,48 +1153,65 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]:
             raise ValueError(f"Invalid quant_cfg entry: {raw!r}.")
 
         for entry in entries:
-            # Validate: must carry at least one instruction beyond the path selector.
-            if "cfg" not in entry and "enable" not in entry:
-                raise ValueError(
-                    f"Invalid quant_cfg entry: {raw!r} — each entry must specify 'cfg', 'enable', "
-                    "or both. An entry with only 'quantizer_name' has no effect (implicit "
-                    "enable=True is not allowed; set it explicitly)."
-                )
-
             # Validate: when cfg is present and enable=True, cfg must be a non-empty
             # dict or list. An empty cfg would attempt to create a
             # QuantizerAttributeConfig with no actual configuration.
+            if "cfg" in entry and entry["cfg"] is None:
+                raise ValueError(
+                    f"Invalid quant_cfg entry: {raw!r} - 'cfg' must be omitted or a "
+                    "valid mapping/list, not null."
+                )
+            if "enable" in entry and entry["enable"] is None:
+                raise ValueError(
+                    f"Invalid quant_cfg entry: {raw!r} - 'enable' must be a boolean "
+                    "when provided, not null."
+ ) + cfg = entry.get("cfg") enable = entry.get("enable", True) if enable and cfg is not None: - if isinstance(cfg, dict): + if isinstance(cfg, QuantizerAttributeConfig): + is_invalid = False + elif isinstance(cfg, Mapping): is_invalid = len(cfg) == 0 elif isinstance(cfg, list): is_invalid = len(cfg) == 0 or any( - not isinstance(item, dict) or len(item) == 0 for item in cfg + not isinstance(item, (Mapping, QuantizerAttributeConfig)) + or (isinstance(item, Mapping) and len(item) == 0) + for item in cfg ) else: is_invalid = True if is_invalid: raise ValueError( - f"Invalid quant_cfg entry: {raw!r} — 'cfg' must be a non-empty dict " - f"or a non-empty list of non-empty dicts when enabling a quantizer " - f"(got {type(cfg).__name__}: {cfg!r}). Either provide quantizer " - "attributes in 'cfg' or remove 'cfg' and set 'enable' explicitly." + f"Invalid quant_cfg entry: {raw!r} - 'cfg' must be a " + "QuantizerAttributeConfig, a non-empty dict, or a non-empty list of " + "QuantizerAttributeConfig/non-empty dict entries when enabling a " + f"quantizer (got {type(cfg).__name__}: {cfg!r}). Either provide " + "quantizer attributes in 'cfg' or remove 'cfg' and set 'enable' " + "explicitly." ) - # Normalize: make enable and cfg always explicit. - entry.setdefault("enable", True) - entry.setdefault("cfg", None) - - result.append(cast("QuantizerCfgEntry", entry)) + result.append(QuantizerCfgEntry.model_validate(entry)) return result class QuantizeConfig(ModeloptBaseConfig): """Default configuration for ``quantize`` mode.""" - quant_cfg: QuantizeQuantCfgType = ModeloptField( + def model_dump(self, **kwargs): + """Dump quant_cfg entries without unset optional fields.""" + data = super().model_dump(**kwargs) + if "quant_cfg" in data: + data["quant_cfg"] = [ + entry.model_dump(exclude_unset=True) + if isinstance(entry, QuantizerCfgEntry) + else {k: v for k, v in entry.items() if v is not None} + for entry in self.quant_cfg + ] + return data + + quant_cfg: QuantizerCfgListConfig = ModeloptField( default=[{"quantizer_name": "*", "cfg": {"num_bits": 8, "axis": None}}], title="Quantization configuration", validate_default=True, @@ -1112,26 +1227,11 @@ class QuantizeConfig(ModeloptBaseConfig): @field_validator("quant_cfg", mode="before") @classmethod def normalize_quant_cfg(cls, v): - """Normalize quant_cfg entries: convert dict and tuple forms to QuantizerCfgEntry dicts.""" - if not isinstance(v, (list, dict)): + """Normalize quant_cfg entries into QuantizerCfgEntry objects.""" + if not isinstance(v, (list, Mapping)): return v return normalize_quant_cfg_list(v) - @field_validator("quant_cfg", mode="after") - @classmethod - def validate_quant_cfg_entries(cls, v): - """Validate quantizer attribute configs to surface errors (e.g. 
invalid axis/block_sizes).""" - qac_fields = set(QuantizerAttributeConfig.model_fields.keys()) - for entry in v: - cfg = entry.get("cfg") - if cfg is None: - continue - cfgs = cfg if isinstance(cfg, list) else [cfg] - for c in cfgs: - if isinstance(c, dict) and qac_fields & set(c.keys()): - QuantizerAttributeConfig.model_validate(c) - return v - class CompressConfig(ModeloptBaseConfig): """Default configuration for ``compress`` mode.""" @@ -1156,564 +1256,213 @@ class _QuantizeExportConfig(ModeloptBaseConfig): """An empty config.""" -_base_disable_all: list[QuantizerCfgEntry] = [ - cast("QuantizerCfgEntry", load_config("configs/ptq/units/base_disable_all")) -] - -_default_disabled_quantizer_cfg: list[QuantizerCfgEntry] = load_config( - "configs/ptq/units/default_disabled_quantizers" -) - -_mamba_moe_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ - {"quantizer_name": "*fc1_latent_proj*", "enable": False}, # Skip Latent MOE - {"quantizer_name": "*fc2_latent_proj*", "enable": False}, # Skip Latent MOE - {"quantizer_name": "*q_proj*", "enable": False}, # Skip QKV Linear (HF naming) - {"quantizer_name": "*k_proj*", "enable": False}, # Skip QKV Linear (HF naming) - {"quantizer_name": "*v_proj*", "enable": False}, # Skip QKV Linear (HF naming) - {"quantizer_name": "*o_proj*", "enable": False}, # Skip QKV Output Projection (HF naming) - { - "quantizer_name": "*self_attention.linear_qkv*", - "enable": False, - }, # Skip QKV Linear (Mcore naming) - { - "quantizer_name": "*self_attention.linear_proj*", - "enable": False, - }, # Skip QKV Output Projection (Mcore naming) -] - -INT8_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, - {"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} - -INT8_SMOOTHQUANT_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, - {"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "smoothquant", -} - -INT8_WEIGHT_ONLY_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, - {"quantizer_name": "*input_quantizer", "enable": False}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} - -FP8_DEFAULT_CFG: dict[str, Any] = load_config("configs/ptq/presets/model/fp8") - -MAMBA_MOE_FP8_AGGRESSIVE_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - *_default_disabled_quantizer_cfg, - *_mamba_moe_disabled_quantizer_cfg, - ], - "algorithm": "max", -} - -MAMBA_MOE_FP8_CONSERVATIVE_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - *_default_disabled_quantizer_cfg, - *_mamba_moe_disabled_quantizer_cfg, - {"quantizer_name": "*mixer.in_proj*", "enable": False}, # Skip mamba linear - {"quantizer_name": "*mixer.out_proj*", "enable": False}, # Skip mamba linear - ], - "algorithm": "max", -} - -FP8_PER_CHANNEL_PER_TOKEN_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": 
{"num_bits": (4, 3), "axis": 0}}, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (4, 3), - "type": "dynamic", - "block_sizes": {-1: None}, - }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} - -# FP8 2D blockwise fake quantization config for deepseek models -FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (4, 3), - "block_sizes": {-1: 128, -2: 128}, - }, - }, - {"quantizer_name": "*input_quantizer", "enable": False}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} - -INT4_BLOCKWISE_WEIGHT_ONLY_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": 4, - "block_sizes": {-1: 128}, - }, - }, - {"quantizer_name": "*input_quantizer", "enable": False}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} - +def _load_quantizer_attribute_dict(config_path: str) -> dict[str, Any]: + """Load a schema-backed QuantizerAttributeConfig YAML as a public dict.""" + config = load_config(config_path, schema_type=QuantizerAttributeConfig) + if isinstance(config, QuantizerAttributeConfig): + return config.model_dump(exclude_unset=True) + if isinstance(config, Mapping): + return dict(config) + raise TypeError(f"{config_path} must declare QuantizerAttributeConfig.") + + +def _quantizer_cfg_entry_to_dict(entry: QuantizerCfgEntry | Mapping[str, Any]) -> dict[str, Any]: + """Dump a typed quant_cfg entry back to the public legacy dict shape.""" + if isinstance(entry, QuantizerCfgEntry): + return entry.model_dump(exclude_unset=True) + if isinstance(entry, Mapping): + return dict(entry) + raise TypeError(f"Expected QuantizerCfgEntry or mapping, got {type(entry).__name__}.") + + +def _load_quantizer_cfg_dict_list(config_path: str) -> list[dict[str, Any]]: + """Load a QuantizerCfgEntry or QuantizerCfgListConfig snippet as public dict entries.""" + config = load_config(config_path) + if isinstance(config, QuantizerCfgEntry): + return [_quantizer_cfg_entry_to_dict(config)] + if isinstance(config, list): + entries = [] + for entry in config: + if not isinstance(entry, (QuantizerCfgEntry, Mapping)): + raise TypeError( + f"Expected QuantizerCfgEntry or mapping, got {type(entry).__name__}." + ) + entries.append(_quantizer_cfg_entry_to_dict(entry)) + return entries + if isinstance(config, Mapping): + return [_quantizer_cfg_entry_to_dict(config)] + raise TypeError(f"{config_path} must declare QuantizerCfgEntry or QuantizerCfgListConfig.") -INT4_AWQ_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": 4, - "block_sizes": {-1: 128, "type": "static"}, - }, - }, - {"quantizer_name": "*input_quantizer", "enable": False}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": {"method": "awq_lite", "alpha_step": 0.1}, - # "algorithm": {"method": "awq_full", "alpha_step": 0.1, "max_co_batch_size": 1024}, - # "algorithm": {"method": "awq_clip", "max_co_batch_size": 2048}, -} -# W4A8 currently uses INT4 blockwise quantization (block size = 128) followed by FP8 quantization -# for weights. 
This could change in the future -W4A8_AWQ_BETA_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": [ - { - "num_bits": 4, - "block_sizes": {-1: 128, "type": "static"}, - }, - { - "num_bits": (4, 3), - }, - ], - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (4, 3), - }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "awq_lite", -} +_base_disable_all: list[dict[str, Any]] = _load_quantizer_cfg_dict_list( + "configs/ptq/units/base_disable_all" +) -MXFP8_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (4, 3), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (4, 3), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} +_default_disabled_quantizer_cfg: list[dict[str, Any]] = _load_quantizer_cfg_dict_list( + "configs/ptq/units/default_disabled_quantizers" +) -MXFP6_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (3, 2), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (3, 2), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} +_mamba_moe_disabled_quantizer_cfg: list[dict[str, Any]] = _load_quantizer_cfg_dict_list( + "configs/ptq/units/mamba_moe_disabled_quantizers" +) -MXFP4_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} +_nvfp4_cfg: dict[str, Any] = _load_quantizer_attribute_dict("configs/numerics/nvfp4") -W4A8_MXFP4_FP8_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} +_nvfp4_cfg_bs32: dict[str, Any] = _load_quantizer_attribute_dict("configs/numerics/nvfp4_bs32") -MXINT8_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": 8, - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": 8, - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} +INT8_DEFAULT_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/int8", + schema_type=QuantizeConfig, +) +INT8_SMOOTHQUANT_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/int8_smoothquant", + schema_type=QuantizeConfig, +) +INT8_WEIGHT_ONLY_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/int8_weight_only", + schema_type=QuantizeConfig, +) +FP8_DEFAULT_CFG: QuantizeConfig = 
load_config( + "configs/ptq/presets/model/fp8", + schema_type=QuantizeConfig, +) +MAMBA_MOE_FP8_AGGRESSIVE_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/mamba_moe_fp8_aggressive", + schema_type=QuantizeConfig, +) +MAMBA_MOE_FP8_CONSERVATIVE_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/mamba_moe_fp8_conservative", + schema_type=QuantizeConfig, +) +FP8_PER_CHANNEL_PER_TOKEN_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/fp8_per_channel_per_token", + schema_type=QuantizeConfig, +) +FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/fp8_2d_blockwise_weight_only", + schema_type=QuantizeConfig, +) +INT4_BLOCKWISE_WEIGHT_ONLY_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/int4_blockwise_weight_only", + schema_type=QuantizeConfig, +) +INT4_AWQ_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/int4_awq", + schema_type=QuantizeConfig, +) +W4A8_AWQ_BETA_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/w4a8_awq_beta", + schema_type=QuantizeConfig, +) +MXFP8_DEFAULT_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/mxfp8", + schema_type=QuantizeConfig, +) +MXFP6_DEFAULT_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/mxfp6", + schema_type=QuantizeConfig, +) +MXFP4_DEFAULT_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/mxfp4", + schema_type=QuantizeConfig, +) +W4A8_MXFP4_FP8_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/w4a8_mxfp4_fp8", + schema_type=QuantizeConfig, +) +MXINT8_DEFAULT_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/mxint8", + schema_type=QuantizeConfig, +) # KV-cache configs are designed to be merged with a primary quantization config (e.g. # FP8_DEFAULT_CFG) that already contains _base_disable_all. They intentionally omit both # _base_disable_all and "algorithm" because these are provided by the primary config. -FP8_KV_CFG: dict[str, Any] = load_config("configs/ptq/presets/kv/fp8") - -FP8_AFFINE_KV_CFG = { - "quant_cfg": [ - { - "quantizer_name": "*[kv]_bmm_quantizer", - "cfg": { - "num_bits": (4, 3), - "bias": {-2: None, -4: None, "type": "static"}, - }, - }, - ] -} - -_nvfp4_cfg = { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, -} - -_nvfp4_cfg_bs32 = { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, -} - - -def _nvfp4_selective_quant_cfg( - layer_patterns: list[str], - *, - quantizer: dict = _nvfp4_cfg, - weight_only: bool = False, - algorithm: str | dict = "max", -) -> dict: - """Build an NVFP4 config that quantizes only the specified layer patterns.""" - quant_cfg: list[QuantizerCfgEntry] = [] - quant_cfg.extend(_base_disable_all) - for pattern in layer_patterns: - # Deep-copy the quantizer dict so each config constant gets its own instance. 
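A hedged sketch of what the preset migration above means for callers (the deleted
``_nvfp4_selective_quant_cfg`` helper resumes below): every ``*_CFG`` constant is now a
``QuantizeConfig`` instance rather than a plain dict, but the schema's mapping access keeps
dict-style call sites working, and the ``model_dump`` override earlier in this diff recovers the
sparse legacy dict spelling:

    import modelopt.torch.quantization as mtq

    preset = mtq.FP8_DEFAULT_CFG
    print(type(preset).__name__)  # QuantizeConfig, not a plain dict anymore

    # Mapping-style reads keep legacy call sites working.
    print(preset["algorithm"])
    for entry in preset["quant_cfg"]:
        print(entry["quantizer_name"], entry["enable"])

    # Round-trip to the legacy dict spelling when a plain dict is required.
    legacy_dict = preset.model_dump(exclude_unset=True)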
- quant_cfg.append( - {"quantizer_name": f"{pattern}weight_quantizer", "cfg": copy.deepcopy(quantizer)} - ) - if not weight_only: - quant_cfg.append( - {"quantizer_name": f"{pattern}input_quantizer", "cfg": copy.deepcopy(quantizer)} - ) - quant_cfg.extend(_default_disabled_quantizer_cfg) - return {"quant_cfg": quant_cfg, "algorithm": algorithm} - - -NVFP4_DEFAULT_CFG = _nvfp4_selective_quant_cfg(["*"]) - -NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - }, - }, - {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": { - "method": "mse", - "fp8_scale_sweep": True, - }, -} - -NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, - }, - }, - {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": { - "method": "local_hessian", - "fp8_scale_sweep": True, - }, -} - -MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": _nvfp4_cfg}, - {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, - *_default_disabled_quantizer_cfg, - *_mamba_moe_disabled_quantizer_cfg, - ], - "algorithm": "max", -} -MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": _nvfp4_cfg}, - {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, - *_default_disabled_quantizer_cfg, - *_mamba_moe_disabled_quantizer_cfg, - {"quantizer_name": "*mixer.in_proj*", "enable": False}, # Skip mamba linear - {"quantizer_name": "*mixer.out_proj*", "enable": False}, # Skip mamba linear - ], - "algorithm": "max", -} - -NVFP4_AWQ_LITE_CFG = _nvfp4_selective_quant_cfg(["*"], algorithm="awq_lite") - -NVFP4_AWQ_CLIP_CFG = _nvfp4_selective_quant_cfg(["*"], algorithm={"method": "awq_clip"}) - -NVFP4_AWQ_FULL_CFG = _nvfp4_selective_quant_cfg( - ["*"], algorithm={"method": "awq_full", "alpha_step": 0.1} +FP8_KV_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/kv/fp8", + schema_type=QuantizeConfig, ) - -# See comment above FP8_KV_CFG — KV-cache configs omit _base_disable_all and "algorithm". 
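A hedged sketch of the merge pattern the KV-cache comment above refers to, using the
``update_quant_cfg_with_kv_cache_quant`` helper updated later in this diff (the deleted NVFP4
constants resume below); whether ``FP8_KV_CFG`` is re-exported on the ``mtq`` package is an
assumption here:

    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.utils.core_utils import (
        update_quant_cfg_with_kv_cache_quant,
    )

    # Append the FP8 KV-cache entries to the FP8 model preset. The helper
    # deep-copies its inputs and defaults "algorithm" to "max" when unset.
    merged = update_quant_cfg_with_kv_cache_quant(
        mtq.FP8_DEFAULT_CFG, mtq.FP8_KV_CFG["quant_cfg"]
    )
    print(merged["algorithm"])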
-NVFP4_AFFINE_KV_CFG = { - "quant_cfg": [ - { - "quantizer_name": "*[kv]_bmm_quantizer", - "cfg": { - **_nvfp4_cfg, - "bias": {-2: None, -4: None, "type": "static"}, - }, - }, - ] -} - -NVFP4_KV_CFG = { - "quant_cfg": [ - {"quantizer_name": "*[kv]_bmm_quantizer", "cfg": _nvfp4_cfg}, - ] -} - -# Moved from examples/diffusers/quantization/config.py to here -NVFP4_FP8_MHA_CONFIG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": _nvfp4_cfg}, - {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, - {"quantizer_name": "*output_quantizer", "enable": False}, - { - "quantizer_name": "*q_bmm_quantizer", - "cfg": { - "num_bits": (4, 3), - }, - }, - { - "quantizer_name": "*k_bmm_quantizer", - "cfg": { - "num_bits": (4, 3), - }, - }, - { - "quantizer_name": "*v_bmm_quantizer", - "cfg": { - "num_bits": (4, 3), - }, - }, - { - "quantizer_name": "*softmax_quantizer", - "cfg": { - "num_bits": (4, 3), - }, - }, - { - "quantizer_name": "transformer_blocks*bmm2_output_quantizer", - "cfg": { - "num_bits": (4, 3), - }, - }, - ], - "algorithm": "max", -} - -# See comment above FP8_KV_CFG — KV-cache configs omit _base_disable_all and "algorithm". -NVFP4_KV_ROTATE_CFG = { - "quant_cfg": [ - { - # q_bmm is disabled but pre-configured with rotate=True so that downstream - # code can inspect the rotate flag even while the quantizer is off. - "quantizer_name": "*q_bmm_quantizer", - "cfg": { - "rotate": True, - }, - "enable": False, - }, - { - "quantizer_name": "*k_bmm_quantizer", - "cfg": { - **_nvfp4_cfg, - "rotate": True, - }, - }, - {"quantizer_name": "*v_bmm_quantizer", "cfg": _nvfp4_cfg}, - ], - "algorithm": "max", -} - -NVFP4_SVDQUANT_DEFAULT_CFG = _nvfp4_selective_quant_cfg( - ["*"], algorithm={"method": "svdquant", "lowrank": 32} +FP8_AFFINE_KV_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/kv/fp8_affine", + schema_type=QuantizeConfig, ) -W4A8_NVFP4_FP8_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, - }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (4, 3), - }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} - -MXFP4_MLP_WEIGHT_ONLY_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*mlp*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - { - "quantizer_name": "*block_sparse_moe*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, - }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} - -NVFP4_MLP_WEIGHT_ONLY_CFG = _nvfp4_selective_quant_cfg( - ["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_cfg_bs32, weight_only=True +NVFP4_DEFAULT_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4", + schema_type=QuantizeConfig, +) +NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_w4a4_weight_mse_fp8_sweep", + schema_type=QuantizeConfig, +) +NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_w4a4_weight_local_hessian", + schema_type=QuantizeConfig, +) +MAMBA_MOE_NVFP4_AGGRESSIVE_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/mamba_moe_nvfp4_aggressive", + schema_type=QuantizeConfig, +) +MAMBA_MOE_NVFP4_CONSERVATIVE_CFG: QuantizeConfig = load_config( 
+ "configs/ptq/presets/model/mamba_moe_nvfp4_conservative", + schema_type=QuantizeConfig, +) +NVFP4_AWQ_LITE_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_awq_lite", + schema_type=QuantizeConfig, +) +NVFP4_AWQ_CLIP_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_awq_clip", + schema_type=QuantizeConfig, ) -NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg( - ["*mlp.experts*", "*block_sparse_moe*", "*.experts.*"] +NVFP4_AWQ_FULL_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_awq_full", + schema_type=QuantizeConfig, +) +NVFP4_AFFINE_KV_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/kv/nvfp4_affine", + schema_type=QuantizeConfig, +) +NVFP4_KV_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/kv/nvfp4", + schema_type=QuantizeConfig, +) +NVFP4_FP8_MHA_CONFIG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_fp8_mha", + schema_type=QuantizeConfig, +) +NVFP4_KV_ROTATE_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/kv/nvfp4_rotate", + schema_type=QuantizeConfig, +) +NVFP4_SVDQUANT_DEFAULT_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_svdquant", + schema_type=QuantizeConfig, +) +W4A8_NVFP4_FP8_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/w4a8_nvfp4_fp8", + schema_type=QuantizeConfig, +) +MXFP4_MLP_WEIGHT_ONLY_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/mxfp4_mlp_weight_only", + schema_type=QuantizeConfig, +) +NVFP4_MLP_WEIGHT_ONLY_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_mlp_weight_only", + schema_type=QuantizeConfig, +) +NVFP4_EXPERTS_ONLY_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_experts_only", + schema_type=QuantizeConfig, +) +NVFP4_MLP_ONLY_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_mlp_only", + schema_type=QuantizeConfig, +) +NVFP4_OMLP_ONLY_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/nvfp4_omlp_only", + schema_type=QuantizeConfig, ) -NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*", "*.experts.*"]) -NVFP4_OMLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*o_proj*", "*mlp*", "*block_sparse_moe*"]) # DO NOT ADD NEW CONFIGS HERE. 
If you want to add a new general recipe, add it to # modelopt_recipes/general/ptq/ as a yaml file @@ -1729,6 +1478,7 @@ def _nvfp4_selective_quant_cfg( "INT8_SMOOTHQUANT_CFG", "INT8_WEIGHT_ONLY_CFG", "MXFP4_DEFAULT_CFG", + "MXFP6_DEFAULT_CFG", "MXFP8_DEFAULT_CFG", "MXINT8_DEFAULT_CFG", "NVFP4_AFFINE_KV_CFG", @@ -1752,11 +1502,12 @@ def _nvfp4_selective_quant_cfg( "MAMBA_MOE_NVFP4_AGGRESSIVE_CFG", "MAMBA_MOE_FP8_CONSERVATIVE_CFG", "MAMBA_MOE_FP8_AGGRESSIVE_CFG", + "NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG", "NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG", } -def need_calibration(config): +def need_calibration(config: QuantizeConfig | Mapping[str, Any]) -> bool: """Check if calibration is needed for the given config.""" if config["algorithm"] is not None and config["algorithm"] != "max": return True @@ -1764,23 +1515,30 @@ def need_calibration(config): def _not_dynamic(cfg): return cfg.get("enable", True) and cfg.get("type", "") != "dynamic" + def _cfg_to_dict(cfg): + if isinstance(cfg, QuantizerAttributeConfig): + return cfg.model_dump(exclude_unset=True) + return dict(cfg or {}) + quant_cfg: list = config.get("quant_cfg") or [] quant_cfg = normalize_quant_cfg_list(quant_cfg) for entry in quant_cfg: name = entry["quantizer_name"] raw_cfg = entry.get("cfg") + enable = entry["enable"] if "weight_quantizer" in name: # We don't calibrate weight quantizer continue # Sequential quantizers (e.g. W4A8) have a list of cfg dicts if isinstance(raw_cfg, list): for _config in raw_cfg: - if _not_dynamic(_config): + cfg = _cfg_to_dict(_config) + cfg["enable"] = enable + if _not_dynamic(cfg): return True continue - cfg = dict(raw_cfg or {}) - if "enable" in entry: - cfg["enable"] = entry["enable"] + cfg = _cfg_to_dict(raw_cfg) + cfg["enable"] = enable if _not_dynamic(cfg): return True diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 3f97f8380be..17bc3f7ddc7 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -18,7 +18,7 @@ import fnmatch import re import warnings -from collections.abc import Callable +from collections.abc import Callable, Mapping from contextlib import contextmanager from typing import Any, cast @@ -31,7 +31,7 @@ from .config import ( QuantizeConfig, - QuantizeQuantCfgType, + QuantizeQuantCfgInputType, QuantizerAttributeConfig, _QuantizeExportConfig, normalize_quant_cfg_list, @@ -215,7 +215,7 @@ def _replace_quant_module(model: nn.Module, version=None, registry=QuantModuleRe _replace_quant_module(getattr(model, name), version=version, registry=registry) -def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): +def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInputType): """Apply a quantization config list to the quantizers in ``quant_model``. ``quant_cfg`` is an **ordered list** of :class:`QuantizerCfgEntry <.config.QuantizerCfgEntry>` @@ -223,8 +223,9 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType - ``quantizer_name`` *(required)*: wildcard matched against quantizer module names via :func:`fnmatch`. - - ``cfg`` *(optional)*: a dict of :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` - fields, or a list of such dicts for sequential quantization. + - ``cfg`` *(optional)*: a :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` + or a list of them for sequential quantization. Equivalent dict and list-of-dict inputs are + accepted for backward compatibility. 
    - ``enable`` *(optional)*: ``True`` or ``False`` to toggle matched quantizers on or off.
-      When omitted but ``cfg`` is present, defaults to ``True``. Every entry must specify at
-      least one of ``cfg`` or ``enable`` — an entry with only ``quantizer_name`` is invalid.
+      When omitted, defaults to ``True``, so an entry with only ``quantizer_name`` simply
+      enables the matched quantizers with their current attributes.
@@ -252,8 +253,8 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType

     for entry in quant_cfg:
         quantizer_name: str = entry["quantizer_name"]
-        cfg = entry["cfg"]  # None, dict, or list — always explicit after normalization
-        enable: bool = entry["enable"]  # always explicit after normalization
+        cfg = entry["cfg"]  # None, QuantizerAttributeConfig, or list after normalization
+        enable = entry["enable"]
         parent_class_name = entry.get("parent_class")
         if parent_class_name:
             try:
@@ -275,15 +276,13 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType
         # Has cfg: apply full replacement with the explicit enable value.
         if isinstance(cfg, QuantizerAttributeConfig):
             attributes = cfg.model_copy(update={"enable": enable})
-        elif isinstance(cfg, dict):
-            attributes = QuantizerAttributeConfig(**cfg, enable=enable)
+        elif isinstance(cfg, list):
+            attributes = [c.model_copy(update={"enable": enable}) for c in cfg]
         else:
-            attributes = [
-                c.model_copy(update={"enable": enable})
-                if isinstance(c, QuantizerAttributeConfig)
-                else QuantizerAttributeConfig(**c, enable=enable)
-                for c in cfg
-            ]
+            raise ValueError(
+                f"Invalid cfg for quantizer {quantizer_name!r}: expected "
+                "QuantizerAttributeConfig or list."
+            )

         set_quantizer_attributes_full(quant_model, quantizer_name, attributes, parent_class)

@@ -416,7 +415,7 @@ def set_quantizer_attributes_full(

 def set_quantizer_attributes_partial(
     quant_model: nn.Module,
     wildcard_or_filter_func: str | Callable,
-    partial_attributes: dict[str, Any] | list[dict[str, Any]],
+    partial_attributes: Mapping[str, Any] | list[Mapping[str, Any]],
     parent_class: type[nn.Module] | None = None,
 ):
     """Update a subset of quantizer attributes by wildcard or filter function, merging with existing attributes.
@@ -451,14 +450,14 @@ def set_quantizer_attributes_partial(
             an instance of this class. If ``None``, all quantizers matching
             ``wildcard_or_filter_func`` are adjusted.
     """
-    if not isinstance(partial_attributes, (dict, list)):
+    if not isinstance(partial_attributes, (Mapping, list)):
         raise ValueError(
-            f"Invalid type for attributes: {type(partial_attributes)}, expected dictionary or list of dict."
+            f"Invalid type for attributes: {type(partial_attributes)}, expected mapping or list of mappings."
         )
     if isinstance(partial_attributes, list) and not all(
-        isinstance(attr, dict) for attr in partial_attributes
+        isinstance(attr, Mapping) for attr in partial_attributes
     ):
-        raise ValueError("All elements in attributes list must be of type dict.")
+        raise ValueError("All elements in attributes list must be mappings.")

     for name, module in quant_model.named_modules():
         if _match_quantizer(wildcard_or_filter_func, name, module, parent_class, quant_model):
@@ -477,7 +476,7 @@


 @contextmanager
-def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType):
+def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInputType):
     """Context manager that temporarily applies a quantization config and restores the original state on exit.
Calls :func:`set_quantizer_by_cfg` on entry and reverts every diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 713cdd7373c..9c41664ba5d 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -16,7 +16,7 @@ """This module contains the mode descriptor for the quantization mode.""" from abc import abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Mapping from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.opt.conversion import ModelLikeModule @@ -374,7 +374,7 @@ def get_modelike_from_algo_cfg(algo_cfg: QuantizeAlgoCfgType) -> ModeConfigList: return [get_modelike_from_algo_cfg(c)[0] for c in algo_cfg] if algo_cfg is None or isinstance(algo_cfg, str): algo_name, algo_cfg = algo_cfg, {} - elif isinstance(algo_cfg, dict): + elif isinstance(algo_cfg, Mapping): algo_name = algo_cfg["method"] else: raise ValueError(f"Invalid config type: {type(algo_cfg)}") diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 5e65f9cc1d4..0e8f34f101d 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -19,7 +19,7 @@ import inspect import os import warnings -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Mapping from typing import Any import torch @@ -143,7 +143,7 @@ def postprocess_amax(model: nn.Module, key: str, post_process_fn) -> nn.Module: def quantize( model: nn.Module, - config: dict[str, Any | QuantizeConfig], + config: QuantizeConfig | Mapping[str, Any], forward_loop: ForwardLoop | None = None, ) -> nn.Module: """Quantizes and calibrates the model in-place. @@ -167,9 +167,11 @@ def quantize( :meth:`calibrate `. Each entry in the ``"quant_cfg"`` list has a ``"quantizer_name"`` wildcard matched - against quantizer module names, an optional ``"cfg"`` dict of quantizer attributes, - and an optional ``"enable"`` toggle. Entries are applied in list order; later entries - override earlier ones. The quantizer modules have names ending with + against quantizer module names, an optional ``"cfg"`` + :class:`QuantizerAttributeConfig ` + (or equivalent backward-compatible dict input), and an optional ``"enable"`` toggle. + Entries are applied in list order; later entries override earlier ones. The quantizer + modules have names ending with ``weight_quantizer`` and ``input_quantizer`` and they perform weight quantization and input quantization (or activation quantization) respectively. The quantizer modules are instances of @@ -238,13 +240,15 @@ def forward_loop(model) -> None: Returns: A pytorch model which has been quantized and calibrated. 
""" + validated_config: QuantizeConfig = QuantizeConfig.model_validate(config) if not is_quantized(model): - model = apply_mode(model, mode=[("quantize", dict(config))], registry=QuantizeModeRegistry) + model = apply_mode( + model, mode=[("quantize", dict(validated_config))], registry=QuantizeModeRegistry + ) else: # Already quantized, so lets apply the quant_cfg from the config - quant_cfg = QuantizeConfig(**dict(config)).quant_cfg - set_quantizer_by_cfg(model, quant_cfg) - return calibrate(model, config.get("algorithm"), forward_loop=forward_loop) + set_quantizer_by_cfg(model, validated_config.quant_cfg) + return calibrate(model, validated_config.algorithm, forward_loop=forward_loop) # TODO: create a config interface for auto_quantize and expose setting @@ -269,7 +273,7 @@ def forward_loop(model) -> None: def auto_quantize( model: nn.Module, constraints: dict[str, float | str] = {"effective_bits": 4.8}, - quantization_formats: list[dict[str, Any] | str] = [ + quantization_formats: list[QuantizeConfig | Mapping[str, Any] | str | None] = [ mtq.NVFP4_AWQ_LITE_CFG, mtq.FP8_DEFAULT_CFG, ], @@ -498,7 +502,7 @@ def forward_backward_step(model, batch) -> None: for quant_cfg, name in processed_quantization_formats: algo = QuantRecipe(quant_cfg, name=name).config.algorithm - algo_method = algo["method"] if isinstance(algo, dict) else algo + algo_method = algo["method"] if isinstance(algo, Mapping) else algo if algo_method not in _AUTO_QUANTIZE_SUPPORTED_ALGORITHMS: raise ValueError( f"Algorithm '{algo_method}' in '{name}' is not supported by auto_quantize yet. " diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3ff7401ec3e..dd7b66dc3cd 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -18,7 +18,7 @@ import contextlib import math import warnings -from collections.abc import Callable +from collections.abc import Callable, Mapping, Sequence from typing import Any, Protocol import torch @@ -203,7 +203,9 @@ def __init__( # Optional quantizer cache for caching quantizer related encoding or tensors. self._quantizer_cache = None - def set_from_attribute_config(self, attribute_cfg: QuantizerAttributeConfig | dict[str, Any]): + def set_from_attribute_config( + self, attribute_cfg: QuantizerAttributeConfig | Mapping[str, Any] + ): """Set quantizer attributes from attribute_cfg. The attributes are defined in @@ -1423,14 +1425,20 @@ def get_modelopt_state(self) -> dict[str, Any]: return {"num_quantizers": len(self), "is_sequential_quantizer": True} def set_from_attribute_config( - self, attributes: list[QuantizerAttributeConfig] | list[dict[str, Any]] + self, + attributes: ( + QuantizerAttributeConfig + | Mapping[str, Any] + | Sequence[QuantizerAttributeConfig | Mapping[str, Any]] + ), ): - """Set the attributes of contained quantizers from a list of attribute_dicts.""" + """Set the attributes of contained quantizers from attribute configs.""" if not isinstance(attributes, (list, tuple)): - assert isinstance(attributes, (dict, QuantizerAttributeConfig)), ( - "attributes must be a list or a dict." 
-            )
+            if not isinstance(attributes, (Mapping, QuantizerAttributeConfig)):
+                raise TypeError(
+                    "attributes must be a list/tuple, a mapping, or a QuantizerAttributeConfig."
+                )
             attributes = [attributes] * len(self)
+        elif len(attributes) != len(self):
+            raise ValueError(f"Expected {len(self)} attribute configs, but got {len(attributes)}.")

         for attribute, quantizer in zip(attributes, self):
             quantizer.set_from_attribute_config(attribute)
diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py
index 1a177e04dc8..ab4c11b388f 100644
--- a/modelopt/torch/quantization/utils/core_utils.py
+++ b/modelopt/torch/quantization/utils/core_utils.py
@@ -17,6 +17,7 @@

 import copy
 from collections import namedtuple
+from collections.abc import Mapping, Sequence
 from contextlib import ExitStack, contextmanager, nullcontext
 from typing import TYPE_CHECKING, Any
@@ -27,7 +28,7 @@
 from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
 from torch.distributed.tensor import Replicate

-from modelopt.torch.quantization.config import QuantizerCfgEntry
+from modelopt.torch.quantization.config import QuantizeConfig, QuantizerCfgEntry
 from modelopt.torch.utils import get_unwrapped_name, print_rank_0

 if TYPE_CHECKING:
@@ -915,12 +916,13 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):

 def update_quant_cfg_with_kv_cache_quant(
-    quant_cfg: dict[str, Any], kv_cache_quant_cfg: list[QuantizerCfgEntry]
-) -> dict[str, Any]:
+    quant_cfg: QuantizeConfig | Mapping[str, Any],
+    kv_cache_quant_cfg: Sequence[QuantizerCfgEntry | Mapping[str, Any]],
+) -> QuantizeConfig:
     """Update the quant_cfg with the kv cache quant_cfg.

     Args:
-        quant_cfg: The outer quantization config dict (with ``"quant_cfg"`` and ``"algorithm"`` keys).
+        quant_cfg: The outer quantization config (with ``"quant_cfg"`` and ``"algorithm"`` keys).
-        kv_cache_quant_cfg: A list of :class:`QuantizerCfgEntry
-            <modelopt.torch.quantization.config.QuantizerCfgEntry>` dicts for KV cache quantization,
-            typically ``some_kv_cfg["quant_cfg"]``.
+        kv_cache_quant_cfg: A sequence of :class:`QuantizerCfgEntry
+            <modelopt.torch.quantization.config.QuantizerCfgEntry>` entries (or equivalent
+            mappings) for KV cache quantization, typically ``some_kv_cfg["quant_cfg"]``.

     Returns:
-        A deep copy of ``quant_cfg`` with the KV cache entries appended to ``quant_cfg["quant_cfg"]``.
+        A validated :class:`QuantizeConfig` deep copy of ``quant_cfg`` with the KV cache
+        entries appended to its ``quant_cfg`` list.
     """
     # If quant_cfg["quant_cfg"] is None, it corresponds to only kv cache quantization case
-    quant_cfg = copy.deepcopy(quant_cfg)
-    inner: list[QuantizerCfgEntry] = quant_cfg.get("quant_cfg") or [
-        {"quantizer_name": "*", "enable": False}
-    ]
-    quant_cfg["quant_cfg"] = inner + list(kv_cache_quant_cfg)
+    updated_quant_cfg: QuantizeConfig = QuantizeConfig.model_validate(copy.deepcopy(quant_cfg))
+    inner = list(
+        updated_quant_cfg.get("quant_cfg") or [QuantizerCfgEntry(quantizer_name="*", enable=False)]
+    )
+    updated_quant_cfg["quant_cfg"] = inner + copy.deepcopy(list(kv_cache_quant_cfg))

     # Set default algorithm for kv cache quantization if not provided.
- if not quant_cfg.get("algorithm"): - quant_cfg["algorithm"] = "max" - print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}") - return quant_cfg + if not updated_quant_cfg.get("algorithm"): + updated_quant_cfg["algorithm"] = "max" + print_rank_0(f"Updated quant_cfg with KV cache quantization: {updated_quant_cfg}") + return updated_quant_cfg def promote_nvfp4_static_quantizers(model: nn.Module) -> int: diff --git a/modelopt_recipes/configs/numerics/fp8.yaml b/modelopt_recipes/configs/numerics/fp8.yaml index ab1da6fad5f..7761dd106c0 100644 --- a/modelopt_recipes/configs/numerics/fp8.yaml +++ b/modelopt_recipes/configs/numerics/fp8.yaml @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# FP8 E4M3 quantizer attributes (per-tensor; used for weight/activation/KV). -# ``axis: null`` is explicit to match the hardcoded ``FP8_DEFAULT_CFG`` shape — -# downstream code that keys on ``"axis" in cfg`` sees the same dict layout. +# Per-tensor FP8 E4M3 quantizer attributes. # modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig num_bits: e4m3 diff --git a/modelopt_recipes/configs/numerics/int4_per_block.yaml b/modelopt_recipes/configs/numerics/int4_per_block.yaml new file mode 100644 index 00000000000..35d9f53a17a --- /dev/null +++ b/modelopt_recipes/configs/numerics/int4_per_block.yaml @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Static INT4 quantizer attributes with 128-value blocks on the last dimension. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig +num_bits: 4 +block_sizes: + -1: 128 + type: static diff --git a/modelopt_recipes/configs/numerics/int8_per_channel.yaml b/modelopt_recipes/configs/numerics/int8_per_channel.yaml new file mode 100644 index 00000000000..31c10635fc4 --- /dev/null +++ b/modelopt_recipes/configs/numerics/int8_per_channel.yaml @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Per-channel INT8 quantizer attributes with axis 0. 
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig +num_bits: 8 +axis: 0 diff --git a/modelopt_recipes/configs/numerics/mxfp4.yaml b/modelopt_recipes/configs/numerics/mxfp4.yaml new file mode 100644 index 00000000000..f32fde304f2 --- /dev/null +++ b/modelopt_recipes/configs/numerics/mxfp4.yaml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dynamic MXFP4 E2M1 block quantizer attributes with E8M0 scales. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig +num_bits: e2m1 +block_sizes: + -1: 32 + type: dynamic + scale_bits: e8m0 diff --git a/modelopt_recipes/configs/numerics/mxfp6.yaml b/modelopt_recipes/configs/numerics/mxfp6.yaml new file mode 100644 index 00000000000..f8849edd294 --- /dev/null +++ b/modelopt_recipes/configs/numerics/mxfp6.yaml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dynamic MXFP6 E3M2 block quantizer attributes with E8M0 scales. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig +num_bits: e3m2 +block_sizes: + -1: 32 + type: dynamic + scale_bits: e8m0 diff --git a/modelopt_recipes/configs/numerics/mxfp8.yaml b/modelopt_recipes/configs/numerics/mxfp8.yaml new file mode 100644 index 00000000000..46cb3d9f7c7 --- /dev/null +++ b/modelopt_recipes/configs/numerics/mxfp8.yaml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dynamic MXFP8 E4M3 block quantizer attributes with E8M0 scales. 
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig +num_bits: e4m3 +block_sizes: + -1: 32 + type: dynamic + scale_bits: e8m0 diff --git a/modelopt_recipes/configs/numerics/mxint8.yaml b/modelopt_recipes/configs/numerics/mxint8.yaml new file mode 100644 index 00000000000..388b251de67 --- /dev/null +++ b/modelopt_recipes/configs/numerics/mxint8.yaml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dynamic MXINT8 block quantizer attributes with E8M0 scales. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig +num_bits: 8 +block_sizes: + -1: 32 + type: dynamic + scale_bits: e8m0 diff --git a/modelopt_recipes/configs/numerics/nvfp4.yaml b/modelopt_recipes/configs/numerics/nvfp4.yaml index 68629c009fb..88598e36e85 100644 --- a/modelopt_recipes/configs/numerics/nvfp4.yaml +++ b/modelopt_recipes/configs/numerics/nvfp4.yaml @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# NVFP4 E2M1 blockwise quantizer attributes with FP8 E4M3 scales (dynamic calibration, the default). +# Dynamic NVFP4 E2M1 block quantizer attributes with FP8 E4M3 scales. # modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig num_bits: e2m1 diff --git a/modelopt_recipes/configs/numerics/nvfp4_bs32.yaml b/modelopt_recipes/configs/numerics/nvfp4_bs32.yaml new file mode 100644 index 00000000000..a84b63a91d3 --- /dev/null +++ b/modelopt_recipes/configs/numerics/nvfp4_bs32.yaml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dynamic NVFP4 E2M1 block quantizer attributes with FP8 E4M3 scales and block size 32. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig +num_bits: e2m1 +block_sizes: + -1: 32 + type: dynamic + scale_bits: e4m3 diff --git a/modelopt_recipes/configs/numerics/nvfp4_static.yaml b/modelopt_recipes/configs/numerics/nvfp4_static.yaml index 32bd247b79a..9f6ac62e11e 100644 --- a/modelopt_recipes/configs/numerics/nvfp4_static.yaml +++ b/modelopt_recipes/configs/numerics/nvfp4_static.yaml @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# NVFP4 E2M1 blockwise quantizer attributes with FP8 E4M3 scales (used for NVFP4 weights since weight scales can be static). +# Static NVFP4 E2M1 block quantizer attributes with FP8 E4M3 scales. # modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig num_bits: e2m1 diff --git a/modelopt_recipes/configs/ptq/presets/README.md b/modelopt_recipes/configs/ptq/presets/README.md index 3ab307fe453..b07f989ffe0 100644 --- a/modelopt_recipes/configs/ptq/presets/README.md +++ b/modelopt_recipes/configs/ptq/presets/README.md @@ -1,7 +1,7 @@ # PTQ Preset Configs This directory holds preset quantization configurations that serve as the -YAML source of truth for the hardcoded `*_CFG` dicts in +YAML source of truth for the `*_CFG` dicts exposed from `modelopt.torch.quantization.config` (e.g., `FP8_DEFAULT_CFG`, `FP8_KV_CFG`). diff --git a/modelopt_recipes/configs/ptq/presets/kv/fp8.yaml b/modelopt_recipes/configs/ptq/presets/kv/fp8.yaml index 7e97f0bc77b..21894ef9c01 100644 --- a/modelopt_recipes/configs/ptq/presets/kv/fp8.yaml +++ b/modelopt_recipes/configs/ptq/presets/kv/fp8.yaml @@ -13,10 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -# FP8 E4M3 KV cache quantization preset. -# Equivalent to the hardcoded FP8_KV_CFG in config.py. -# This is a partial config (no algorithm, no base_disable_all) — designed -# to be merged with a primary model quantization config. +# Partial QuantizeConfig that enables FP8 E4M3 KV-cache quantizers. +# Merge this fragment with a primary model quantization preset. # modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig imports: diff --git a/modelopt_recipes/configs/ptq/presets/kv/fp8_affine.yaml b/modelopt_recipes/configs/ptq/presets/kv/fp8_affine.yaml new file mode 100644 index 00000000000..4540df34ea9 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/kv/fp8_affine.yaml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Partial QuantizeConfig that enables affine FP8 E4M3 KV-cache quantizers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + kv_fp8_affine: configs/ptq/units/kv_fp8_affine + +quant_cfg: + - $import: kv_fp8_affine diff --git a/modelopt_recipes/configs/ptq/presets/kv/nvfp4.yaml b/modelopt_recipes/configs/ptq/presets/kv/nvfp4.yaml new file mode 100644 index 00000000000..6d759e2c115 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/kv/nvfp4.yaml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Partial QuantizeConfig that enables NVFP4 KV-cache quantizers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + kv_nvfp4: configs/ptq/units/kv_nvfp4 + +quant_cfg: + - $import: kv_nvfp4 diff --git a/modelopt_recipes/configs/ptq/presets/kv/nvfp4_affine.yaml b/modelopt_recipes/configs/ptq/presets/kv/nvfp4_affine.yaml new file mode 100644 index 00000000000..1f2a871010b --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/kv/nvfp4_affine.yaml @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Partial QuantizeConfig that enables affine NVFP4 KV-cache quantizers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + kv_nvfp4_affine: configs/ptq/units/kv_nvfp4_affine + +quant_cfg: + - $import: kv_nvfp4_affine diff --git a/modelopt_recipes/configs/ptq/presets/kv/nvfp4_rotate.yaml b/modelopt_recipes/configs/ptq/presets/kv/nvfp4_rotate.yaml new file mode 100644 index 00000000000..2451ee1a359 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/kv/nvfp4_rotate.yaml @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Partial QuantizeConfig that enables rotated NVFP4 KV-cache quantizers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + kv_nvfp4_rotate: configs/ptq/units/kv_nvfp4_rotate + +algorithm: max +quant_cfg: + - $import: kv_nvfp4_rotate diff --git a/modelopt_recipes/configs/ptq/presets/model/fp8.yaml b/modelopt_recipes/configs/ptq/presets/model/fp8.yaml index af80b57fe48..423904a6e18 100644 --- a/modelopt_recipes/configs/ptq/presets/model/fp8.yaml +++ b/modelopt_recipes/configs/ptq/presets/model/fp8.yaml @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-# FP8 per-tensor weight and activation (W8A8), max calibration. -# Equivalent to the hardcoded FP8_DEFAULT_CFG in config.py. +# QuantizeConfig preset for W8A8 FP8 E4M3 with per-tensor weights and inputs. # modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig imports: diff --git a/modelopt_recipes/configs/ptq/presets/model/fp8_2d_blockwise_weight_only.yaml b/modelopt_recipes/configs/ptq/presets/model/fp8_2d_blockwise_weight_only.yaml new file mode 100644 index 00000000000..136f956288d --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/fp8_2d_blockwise_weight_only.yaml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for FP8 E4M3 2D blockwise weight-only quantization. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + num_bits: e4m3 + block_sizes: + -1: 128 + -2: 128 + - quantizer_name: '*input_quantizer' + enable: false + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/fp8_per_channel_per_token.yaml b/modelopt_recipes/configs/ptq/presets/model/fp8_per_channel_per_token.yaml new file mode 100644 index 00000000000..8c3f1d78ccb --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/fp8_per_channel_per_token.yaml @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for FP8 E4M3 per-channel weights and per-token dynamic inputs. 
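+#
+# A reading aid for the input entry below (a sketch, not authoritative): the
+# bare `-1:` key is YAML for a null block size, i.e. one dynamic scale spanning
+# the whole last dimension, which amounts to per-token scaling:
+#
+#   cfg:
+#     num_bits: e4m3
+#     type: dynamic
+#     block_sizes: {-1: null}   # flow-style spelling of the `-1:` entry below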
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + num_bits: e4m3 + axis: 0 + - quantizer_name: '*input_quantizer' + cfg: + num_bits: e4m3 + type: dynamic + block_sizes: + -1: + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/int4_awq.yaml b/modelopt_recipes/configs/ptq/presets/model/int4_awq.yaml new file mode 100644 index 00000000000..828aef7d06f --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/int4_awq.yaml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for AWQ-lite INT4 weight-only quantization. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + int4_per_block: configs/numerics/int4_per_block + +algorithm: + method: awq_lite + alpha_step: 0.1 +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: int4_per_block + - quantizer_name: '*input_quantizer' + enable: false + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/int4_blockwise_weight_only.yaml b/modelopt_recipes/configs/ptq/presets/model/int4_blockwise_weight_only.yaml new file mode 100644 index 00000000000..beb3f20718f --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/int4_blockwise_weight_only.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for INT4 blockwise weight-only quantization. 
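+#
+# Worked example of the block layout (illustrative shapes): with
+# `block_sizes: {-1: 128}`, a [4096, 4096] Linear weight is split along its
+# last (input-channel) dimension into 4096 / 128 = 32 blocks per output row,
+# each quantized to INT4 with its own scale.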
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + num_bits: 4 + block_sizes: + -1: 128 + - quantizer_name: '*input_quantizer' + enable: false + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/int8.yaml b/modelopt_recipes/configs/ptq/presets/model/int8.yaml new file mode 100644 index 00000000000..7610d74a0df --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/int8.yaml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for INT8 per-channel weights and per-tensor inputs. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + int8_per_channel: configs/numerics/int8_per_channel + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: int8_per_channel + - quantizer_name: '*input_quantizer' + cfg: + num_bits: 8 + axis: + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/int8_smoothquant.yaml b/modelopt_recipes/configs/ptq/presets/model/int8_smoothquant.yaml new file mode 100644 index 00000000000..e560a623914 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/int8_smoothquant.yaml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for SmoothQuant INT8 per-channel weights and per-tensor inputs. 
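+#
+# `algorithm: smoothquant` below is the scalar shorthand; assuming the schema
+# treats scalar and mapping spellings equivalently, it could also be written
+# in the mapping style used by the AWQ presets:
+#
+#   algorithm:
+#     method: smoothquant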
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + int8_per_channel: configs/numerics/int8_per_channel + +algorithm: smoothquant +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: int8_per_channel + - quantizer_name: '*input_quantizer' + cfg: + num_bits: 8 + axis: + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/int8_weight_only.yaml b/modelopt_recipes/configs/ptq/presets/model/int8_weight_only.yaml new file mode 100644 index 00000000000..cc475ab6103 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/int8_weight_only.yaml @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for INT8 per-channel weight-only quantization. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + int8_per_channel: configs/numerics/int8_per_channel + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: int8_per_channel + - quantizer_name: '*input_quantizer' + enable: false + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/mamba_moe_fp8_aggressive.yaml b/modelopt_recipes/configs/ptq/presets/model/mamba_moe_fp8_aggressive.yaml new file mode 100644 index 00000000000..24fb95897ad --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/mamba_moe_fp8_aggressive.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for FP8 W8A8 Mamba-MoE quantization with shared exclusions. 
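+#
+# Entry order matters: later quant_cfg entries override earlier matches, which
+# is why the broad enabling wildcards come first and the exclusion units are
+# imported last. Sketch of the resolution for one quantizer:
+#
+#   '*weight_quantizer'            -> enabled with fp8   (matched earlier)
+#   '*self_attention.linear_qkv*'  -> disabled again     (matched later, wins)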
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + fp8: configs/numerics/fp8 + mamba_moe_disabled_quantizers: configs/ptq/units/mamba_moe_disabled_quantizers + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: fp8 + - quantizer_name: '*input_quantizer' + cfg: + $import: fp8 + - $import: default_disabled_quantizers + - $import: mamba_moe_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/mamba_moe_fp8_conservative.yaml b/modelopt_recipes/configs/ptq/presets/model/mamba_moe_fp8_conservative.yaml new file mode 100644 index 00000000000..b943b31dcde --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/mamba_moe_fp8_conservative.yaml @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for FP8 W8A8 Mamba-MoE quantization with mixer projections disabled. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + fp8: configs/numerics/fp8 + mamba_moe_disabled_quantizers: configs/ptq/units/mamba_moe_disabled_quantizers + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: fp8 + - quantizer_name: '*input_quantizer' + cfg: + $import: fp8 + - $import: default_disabled_quantizers + - $import: mamba_moe_disabled_quantizers + - quantizer_name: '*mixer.in_proj*' + enable: false + - quantizer_name: '*mixer.out_proj*' + enable: false diff --git a/modelopt_recipes/configs/ptq/presets/model/mamba_moe_nvfp4_aggressive.yaml b/modelopt_recipes/configs/ptq/presets/model/mamba_moe_nvfp4_aggressive.yaml new file mode 100644 index 00000000000..6346548eb84 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/mamba_moe_nvfp4_aggressive.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for NVFP4 W4A4 Mamba-MoE quantization with shared exclusions. 
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + mamba_moe_disabled_quantizers: configs/ptq/units/mamba_moe_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers + - $import: mamba_moe_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/mamba_moe_nvfp4_conservative.yaml b/modelopt_recipes/configs/ptq/presets/model/mamba_moe_nvfp4_conservative.yaml new file mode 100644 index 00000000000..f94a4b1fc4d --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/mamba_moe_nvfp4_conservative.yaml @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for NVFP4 W4A4 Mamba-MoE quantization with mixer projections disabled. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + mamba_moe_disabled_quantizers: configs/ptq/units/mamba_moe_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers + - $import: mamba_moe_disabled_quantizers + - quantizer_name: '*mixer.in_proj*' + enable: false + - quantizer_name: '*mixer.out_proj*' + enable: false diff --git a/modelopt_recipes/configs/ptq/presets/model/mxfp4.yaml b/modelopt_recipes/configs/ptq/presets/model/mxfp4.yaml new file mode 100644 index 00000000000..982e22144ec --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/mxfp4.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for dynamic MXFP4 block quantization on weights and inputs. 
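+#
+# The empty `algorithm:` below deliberately resolves to YAML null: MX block
+# scales are computed dynamically at run time, so no calibration pass is
+# configured (assuming the schema accepts null to mean "no algorithm"). The
+# explicit spelling would be:
+#
+#   algorithm: null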
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + mxfp4: configs/numerics/mxfp4 + +algorithm: +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: mxfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: mxfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/mxfp4_mlp_weight_only.yaml b/modelopt_recipes/configs/ptq/presets/model/mxfp4_mlp_weight_only.yaml new file mode 100644 index 00000000000..8d03600e872 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/mxfp4_mlp_weight_only.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for dynamic MXFP4 block weight-only quantization on MLP/MoE layers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + mxfp4: configs/numerics/mxfp4 + +algorithm: +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*mlp*weight_quantizer' + cfg: + $import: mxfp4 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + cfg: + $import: mxfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/mxfp6.yaml b/modelopt_recipes/configs/ptq/presets/model/mxfp6.yaml new file mode 100644 index 00000000000..e8d590f3848 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/mxfp6.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for dynamic MXFP6 block quantization on weights and inputs. 
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + mxfp6: configs/numerics/mxfp6 + +algorithm: +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: mxfp6 + - quantizer_name: '*input_quantizer' + cfg: + $import: mxfp6 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/mxfp8.yaml b/modelopt_recipes/configs/ptq/presets/model/mxfp8.yaml new file mode 100644 index 00000000000..7cf2832311c --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/mxfp8.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for dynamic MXFP8 block quantization on weights and inputs. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + mxfp8: configs/numerics/mxfp8 + +algorithm: +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: mxfp8 + - quantizer_name: '*input_quantizer' + cfg: + $import: mxfp8 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/mxint8.yaml b/modelopt_recipes/configs/ptq/presets/model/mxint8.yaml new file mode 100644 index 00000000000..e6ef1ca3d06 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/mxint8.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for dynamic MXINT8 block quantization on weights and inputs. 
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + mxint8: configs/numerics/mxint8 + +algorithm: +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: mxint8 + - quantizer_name: '*input_quantizer' + cfg: + $import: mxint8 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4.yaml new file mode 100644 index 00000000000..f569f501433 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for dynamic NVFP4 W4A4 quantization on weights and inputs. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_awq_clip.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_awq_clip.yaml new file mode 100644 index 00000000000..d3cce284196 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_awq_clip.yaml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for NVFP4 W4A4 quantization with AWQ clip calibration. 
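+#
+# High-level note (a summary, not a statement of the exact implementation):
+# awq_clip searches per-block weight clipping thresholds that minimize output
+# quantization error, whereas awq_lite searches per-channel smoothing scales
+# between activations and weights.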
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + +algorithm: + method: awq_clip +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_awq_full.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_awq_full.yaml new file mode 100644 index 00000000000..38934b9c05f --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_awq_full.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for NVFP4 W4A4 quantization with full AWQ calibration. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + +algorithm: + method: awq_full + alpha_step: 0.1 +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_awq_lite.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_awq_lite.yaml new file mode 100644 index 00000000000..e69daf39e57 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_awq_lite.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for NVFP4 W4A4 quantization with AWQ-lite calibration. 
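+#
+# Unlike the int4_awq preset, which spells the algorithm as a mapping with an
+# explicit `alpha_step: 0.1`, this preset uses the scalar shorthand and relies
+# on the method's default search step (an assumption about schema defaults):
+#
+#   algorithm:
+#     method: awq_lite   # mapping form; add alpha_step to override the default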
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + +algorithm: awq_lite +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_experts_only.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_experts_only.yaml new file mode 100644 index 00000000000..e2b4a956b6c --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_experts_only.yaml @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for dynamic NVFP4 W4A4 quantization on expert layers only. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*mlp.experts*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*mlp.experts*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*.experts.*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*.experts.*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_fp8_mha.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_fp8_mha.yaml new file mode 100644 index 00000000000..fadb7b9bbc4 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_fp8_mha.yaml @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for Diffusers NVFP4 with FP8 attention quantizers. 
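+#
+# The `*_bmm_quantizer` and `*softmax_quantizer` patterns below target the
+# attention batched-matmul inputs inserted by the Diffusers quantization
+# plugins (an assumption about how those quantizers are named); they stay at
+# FP8 while Linear weights and inputs use NVFP4.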
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + nvfp4: configs/numerics/nvfp4 + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*output_quantizer' + enable: false + - quantizer_name: '*q_bmm_quantizer' + cfg: + num_bits: e4m3 + - quantizer_name: '*k_bmm_quantizer' + cfg: + num_bits: e4m3 + - quantizer_name: '*v_bmm_quantizer' + cfg: + num_bits: e4m3 + - quantizer_name: '*softmax_quantizer' + cfg: + num_bits: e4m3 + - quantizer_name: 'transformer_blocks*bmm2_output_quantizer' + cfg: + num_bits: e4m3 diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_mlp_only.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_mlp_only.yaml new file mode 100644 index 00000000000..dbe32c0b3a4 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_mlp_only.yaml @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for dynamic NVFP4 W4A4 quantization on MLP/MoE layers only. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*mlp*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*mlp*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*.experts.*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*.experts.*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_mlp_weight_only.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_mlp_weight_only.yaml new file mode 100644 index 00000000000..952ea3a90db --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_mlp_weight_only.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for NVFP4 block-size-32 weight-only quantization on MLP/MoE layers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4_bs32: configs/numerics/nvfp4_bs32 + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*mlp*weight_quantizer' + cfg: + $import: nvfp4_bs32 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + cfg: + $import: nvfp4_bs32 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_omlp_only.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_omlp_only.yaml new file mode 100644 index 00000000000..1b7e1cbd7c7 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_omlp_only.yaml @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for dynamic NVFP4 W4A4 quantization on output projections and MLP/MoE layers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*o_proj*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*o_proj*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*mlp*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*mlp*input_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*block_sparse_moe*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_svdquant.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_svdquant.yaml new file mode 100644 index 00000000000..8101d666217 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_svdquant.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for NVFP4 W4A4 quantization with SVDQuant low-rank calibration. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + +algorithm: + method: svdquant + lowrank: 32 +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_w4a4_weight_local_hessian.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_w4a4_weight_local_hessian.yaml new file mode 100644 index 00000000000..ac6a3094b7c --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_w4a4_weight_local_hessian.yaml @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for NVFP4 W4A4 with static weight scales from local-Hessian calibration. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + nvfp4_static: configs/numerics/nvfp4_static + +algorithm: + method: local_hessian + fp8_scale_sweep: true +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4_static + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/nvfp4_w4a4_weight_mse_fp8_sweep.yaml b/modelopt_recipes/configs/ptq/presets/model/nvfp4_w4a4_weight_mse_fp8_sweep.yaml new file mode 100644 index 00000000000..3ae22dbc3a6 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/nvfp4_w4a4_weight_mse_fp8_sweep.yaml @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for NVFP4 W4A4 with static weight scales from MSE FP8-scale sweep. 
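+#
+# Note the asymmetry below: weights import the static NVFP4 numerics unit, so
+# their block scales are fixed by the MSE sweep at calibration time, while
+# inputs keep the dynamic NVFP4 unit and are rescaled at run time.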
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4: configs/numerics/nvfp4 + nvfp4_static: configs/numerics/nvfp4_static + +algorithm: + method: mse + fp8_scale_sweep: true +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4_static + - quantizer_name: '*input_quantizer' + cfg: + $import: nvfp4 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/w4a8_awq_beta.yaml b/modelopt_recipes/configs/ptq/presets/model/w4a8_awq_beta.yaml new file mode 100644 index 00000000000..12073e14601 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/w4a8_awq_beta.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for W4A8 AWQ-lite with INT4 block weights and FP8 inputs. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + fp8: configs/numerics/fp8 + int4_per_block: configs/numerics/int4_per_block + +algorithm: awq_lite +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + - $import: int4_per_block + - $import: fp8 + - quantizer_name: '*input_quantizer' + cfg: + $import: fp8 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/w4a8_mxfp4_fp8.yaml b/modelopt_recipes/configs/ptq/presets/model/w4a8_mxfp4_fp8.yaml new file mode 100644 index 00000000000..428cb659da5 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/w4a8_mxfp4_fp8.yaml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for W4A8 with MXFP4 block weights and FP8 inputs. 
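+#
+# Like w4a8_awq_beta this pairs 4-bit block weights with FP8 inputs, but here
+# the MXFP4 weight scales are computed dynamically and `algorithm:` is left
+# null, so no calibration search (e.g. AWQ) is configured.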
+ +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + fp8: configs/numerics/fp8 + mxfp4: configs/numerics/mxfp4 + +algorithm: +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: mxfp4 + - quantizer_name: '*input_quantizer' + cfg: + $import: fp8 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/presets/model/w4a8_nvfp4_fp8.yaml b/modelopt_recipes/configs/ptq/presets/model/w4a8_nvfp4_fp8.yaml new file mode 100644 index 00000000000..86b335cbc11 --- /dev/null +++ b/modelopt_recipes/configs/ptq/presets/model/w4a8_nvfp4_fp8.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizeConfig preset for W4A8 with NVFP4 block-size-32 weights and FP8 inputs. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig +imports: + base_disable_all: configs/ptq/units/base_disable_all + default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers + nvfp4_bs32: configs/numerics/nvfp4_bs32 + +algorithm: max +quant_cfg: + - $import: base_disable_all + - quantizer_name: '*weight_quantizer' + cfg: + $import: nvfp4_bs32 + - quantizer_name: '*input_quantizer' + cfg: + num_bits: e4m3 + - $import: default_disabled_quantizers diff --git a/modelopt_recipes/configs/ptq/units/README.md b/modelopt_recipes/configs/ptq/units/README.md index b7a7421f9fc..91e3dab973a 100644 --- a/modelopt_recipes/configs/ptq/units/README.md +++ b/modelopt_recipes/configs/ptq/units/README.md @@ -19,7 +19,12 @@ recipes (under `general/` or `models/`) or presets (under `presets/`). | `base_disable_all.yaml` | Deny-all entry: disables all quantizers as the first step | | `default_disabled_quantizers.yaml` | Standard exclusions (LM head, routers, BatchNorm, etc.) 
|
| `kv_fp8.yaml` | FP8 E4M3 KV cache quantizer entry; supported on Hopper+ GPUs |
+| `kv_fp8_affine.yaml` | FP8 E4M3 affine KV cache quantizer entries; supported on Hopper+ GPUs |
| `kv_fp8_cast.yaml` | FP8 E4M3 KV cache with constant amax (skips KV calibration); supported on Hopper+ GPUs |
+| `kv_nvfp4.yaml` | NVFP4 KV cache quantizer entry; supported on Blackwell+ GPUs |
+| `kv_nvfp4_affine.yaml` | NVFP4 affine KV cache quantizer entries; supported on Blackwell+ GPUs |
| `kv_nvfp4_cast.yaml` | NVFP4 KV cache with constant amax (skips KV calibration); supported on Blackwell+ GPUs |
+| `kv_nvfp4_rotate.yaml` | NVFP4 rotated KV cache quantizer entries; supported on Blackwell+ GPUs |
+| `mamba_moe_disabled_quantizers.yaml` | Shared Mamba-MoE quantizer exclusions |
| `w8a8_fp8_fp8.yaml` | FP8 weight + activation quantizer entries (W8A8); supported on Hopper+ GPUs |
| `w4a4_nvfp4_nvfp4.yaml` | NVFP4 weight + activation quantizer entries (W4A4); supported on Blackwell+ GPUs |
diff --git a/modelopt_recipes/configs/ptq/units/base_disable_all.yaml b/modelopt_recipes/configs/ptq/units/base_disable_all.yaml
index 9a520ee207f..ee96d00411c 100644
--- a/modelopt_recipes/configs/ptq/units/base_disable_all.yaml
+++ b/modelopt_recipes/configs/ptq/units/base_disable_all.yaml
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-# Disable all quantizers by default (deny-all-then-configure pattern).
+# QuantizerCfgEntry snippet that disables every quantizer before selective re-enabling.

# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgEntry
quantizer_name: '*'
diff --git a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml
index 1508f942776..86d5a64c673 100644
--- a/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml
+++ b/modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-# Standard quantizer exclusions: layers that should not be quantized.
+# QuantizerCfgList snippet for standard module patterns that should remain unquantized.

# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig
- quantizer_name: '*block_sparse_moe.gate*'
diff --git a/modelopt_recipes/configs/ptq/units/kv_fp8.yaml b/modelopt_recipes/configs/ptq/units/kv_fp8.yaml
index 646be96709f..86156e5e95c 100644
--- a/modelopt_recipes/configs/ptq/units/kv_fp8.yaml
+++ b/modelopt_recipes/configs/ptq/units/kv_fp8.yaml
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-# FP8 E4M3 KV cache quantization.
+# QuantizerCfgList snippet that enables FP8 E4M3 KV-cache quantizers.
#
# This snippet uses multi-document YAML (separated by ---) because it is a
# list-valued snippet that also needs to $import another snippet. YAML only
diff --git a/modelopt_recipes/configs/ptq/units/kv_fp8_affine.yaml b/modelopt_recipes/configs/ptq/units/kv_fp8_affine.yaml
new file mode 100644
index 00000000000..5458e5511ca
--- /dev/null
+++ b/modelopt_recipes/configs/ptq/units/kv_fp8_affine.yaml
@@ -0,0 +1,30 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizerCfgList snippet that enables affine FP8 E4M3 KV-cache quantizers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig +imports: + kv_fp8: configs/ptq/units/kv_fp8 +--- + - $import: kv_fp8 + - quantizer_name: '*[kv]_bmm_quantizer' + cfg: + num_bits: e4m3 + axis: + bias: + -2: + -4: + type: static diff --git a/modelopt_recipes/configs/ptq/units/kv_fp8_cast.yaml b/modelopt_recipes/configs/ptq/units/kv_fp8_cast.yaml index 64cfbd47bc7..606c969ab37 100644 --- a/modelopt_recipes/configs/ptq/units/kv_fp8_cast.yaml +++ b/modelopt_recipes/configs/ptq/units/kv_fp8_cast.yaml @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# FP8 E4M3 KV cache quantization with constant amax. +# QuantizerCfgList snippet that enables FP8 E4M3 KV-cache quantizers with constant amax. # modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig imports: diff --git a/modelopt_recipes/configs/ptq/units/kv_nvfp4.yaml b/modelopt_recipes/configs/ptq/units/kv_nvfp4.yaml new file mode 100644 index 00000000000..a95b854a0aa --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/kv_nvfp4.yaml @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizerCfgList snippet that enables NVFP4 KV-cache quantizers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig +imports: + nvfp4: configs/numerics/nvfp4 +--- + - quantizer_name: '*[kv]_bmm_quantizer' + cfg: + $import: nvfp4 diff --git a/modelopt_recipes/configs/ptq/units/kv_nvfp4_affine.yaml b/modelopt_recipes/configs/ptq/units/kv_nvfp4_affine.yaml new file mode 100644 index 00000000000..2122e8b3431 --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/kv_nvfp4_affine.yaml @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizerCfgList snippet that enables affine NVFP4 KV-cache quantizers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig +imports: + kv_nvfp4: configs/ptq/units/kv_nvfp4 + nvfp4: configs/numerics/nvfp4 +--- + - $import: kv_nvfp4 + - quantizer_name: '*[kv]_bmm_quantizer' + cfg: + $import: nvfp4 + bias: + -2: + -4: + type: static diff --git a/modelopt_recipes/configs/ptq/units/kv_nvfp4_cast.yaml b/modelopt_recipes/configs/ptq/units/kv_nvfp4_cast.yaml index 3fc5d597aa8..b5658c2ff11 100644 --- a/modelopt_recipes/configs/ptq/units/kv_nvfp4_cast.yaml +++ b/modelopt_recipes/configs/ptq/units/kv_nvfp4_cast.yaml @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# NVFP4 KV cache quantization with constant amax. +# QuantizerCfgList snippet that enables NVFP4 KV-cache quantizers with constant amax. # # The deployment kernel upcasts NVFP4 KV values to FP8 before attention, so the # scale must land in the FP8 range. diff --git a/modelopt_recipes/configs/ptq/units/kv_nvfp4_rotate.yaml b/modelopt_recipes/configs/ptq/units/kv_nvfp4_rotate.yaml new file mode 100644 index 00000000000..b117edbf1be --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/kv_nvfp4_rotate.yaml @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizerCfgList snippet that enables rotated NVFP4 KV-cache quantizers. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig +imports: + nvfp4: configs/numerics/nvfp4 +--- + - quantizer_name: '*q_bmm_quantizer' + cfg: + rotate: true + enable: false + - quantizer_name: '*k_bmm_quantizer' + cfg: + $import: nvfp4 + rotate: true + - quantizer_name: '*v_bmm_quantizer' + cfg: + $import: nvfp4 diff --git a/modelopt_recipes/configs/ptq/units/mamba_moe_disabled_quantizers.yaml b/modelopt_recipes/configs/ptq/units/mamba_moe_disabled_quantizers.yaml new file mode 100644 index 00000000000..c9b87f8d212 --- /dev/null +++ b/modelopt_recipes/configs/ptq/units/mamba_moe_disabled_quantizers.yaml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# QuantizerCfgList snippet with Mamba/MoE-specific exclusion patterns. + +# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig + - quantizer_name: '*fc1_latent_proj*' + enable: false + - quantizer_name: '*fc2_latent_proj*' + enable: false + - quantizer_name: '*q_proj*' + enable: false + - quantizer_name: '*k_proj*' + enable: false + - quantizer_name: '*v_proj*' + enable: false + - quantizer_name: '*o_proj*' + enable: false + - quantizer_name: '*self_attention.linear_qkv*' + enable: false + - quantizer_name: '*self_attention.linear_proj*' + enable: false diff --git a/modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4.yaml b/modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4.yaml index 033cdf76697..010d81ab621 100644 --- a/modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4.yaml +++ b/modelopt_recipes/configs/ptq/units/w4a4_nvfp4_nvfp4.yaml @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# W4A4 NVFP4: NVFP4 E2M1 dynamic weight and activation quantizers. +# QuantizerCfgList snippet that enables dynamic NVFP4 on weight and input quantizers. # modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig imports: diff --git a/modelopt_recipes/configs/ptq/units/w8a8_fp8_fp8.yaml b/modelopt_recipes/configs/ptq/units/w8a8_fp8_fp8.yaml index 07db59ff3b0..068f38d1497 100644 --- a/modelopt_recipes/configs/ptq/units/w8a8_fp8_fp8.yaml +++ b/modelopt_recipes/configs/ptq/units/w8a8_fp8_fp8.yaml @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# W8A8 FP8: FP8 E4M3 weight and activation quantizers. +# QuantizerCfgList snippet that enables per-tensor FP8 E4M3 on weight and input quantizers. # modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig imports: diff --git a/modelopt_recipes/general/ptq/fp8_default-kv_fp8.yaml b/modelopt_recipes/general/ptq/fp8_default-kv_fp8.yaml index 4c6ba99e11f..ea2ac567290 100644 --- a/modelopt_recipes/general/ptq/fp8_default-kv_fp8.yaml +++ b/modelopt_recipes/general/ptq/fp8_default-kv_fp8.yaml @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Composed PTQ recipe for W8A8 FP8 E4M3 model quantization with FP8 KV-cache quantization. + imports: base_disable_all: configs/ptq/units/base_disable_all default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers @@ -21,7 +23,8 @@ imports: metadata: recipe_type: ptq - description: FP8 per-tensor weight and activation (W8A8), FP8 KV cache, max calibration. + description: >- + Composes W8A8 FP8 E4M3 model quantization with FP8 KV-cache quantization; uses max calibration. 
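+
+# NOTE (illustrative; composition semantics per the snippet comments and loader tests):
+# imported snippets are spliced into quant_cfg below in order, and later entries win.
+# base_disable_all contributes the leading deny-all entry ('*' with enable: false); the
+# W8A8 FP8 and KV-cache snippets then selectively re-enable the quantizers they target.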
quantize: algorithm: max quant_cfg: diff --git a/modelopt_recipes/general/ptq/fp8_default-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/fp8_default-kv_fp8_cast.yaml index f99a716ced5..4e24bf53274 100644 --- a/modelopt_recipes/general/ptq/fp8_default-kv_fp8_cast.yaml +++ b/modelopt_recipes/general/ptq/fp8_default-kv_fp8_cast.yaml @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Composed PTQ recipe for W8A8 FP8 E4M3 model quantization with FP8 KV-cache cast mode. + imports: base_disable_all: configs/ptq/units/base_disable_all default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers @@ -22,8 +24,8 @@ imports: metadata: recipe_type: ptq description: >- - FP8 per-tensor weight and activation (W8A8), FP8 KV cache with constant amax - (skips KV calibration; amax hardcoded to FP8 E4M3 max 448.0), max calibration. + Composes W8A8 FP8 E4M3 model quantization with FP8 KV-cache cast mode using constant amax; uses + max calibration. quantize: algorithm: max quant_cfg: diff --git a/modelopt_recipes/general/ptq/nvfp4_default-kv_fp8.yaml b/modelopt_recipes/general/ptq/nvfp4_default-kv_fp8.yaml index 63b6d673b94..6a65efef57a 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-kv_fp8.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-kv_fp8.yaml @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Composed PTQ recipe for dynamic NVFP4 W4A4 model quantization with FP8 KV-cache quantization. + imports: base_disable_all: configs/ptq/units/base_disable_all default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers @@ -21,7 +23,9 @@ imports: metadata: recipe_type: ptq - description: NVFP4 W4A4, FP8 KV cache, max calibration. + description: >- + Composes dynamic NVFP4 W4A4 model quantization with FP8 KV-cache quantization; uses max + calibration. quantize: algorithm: max quant_cfg: diff --git a/modelopt_recipes/general/ptq/nvfp4_default-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_default-kv_fp8_cast.yaml index 1504f33d3cc..312cdd16c8d 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-kv_fp8_cast.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-kv_fp8_cast.yaml @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Composed PTQ recipe for dynamic NVFP4 W4A4 model quantization with FP8 KV-cache cast mode. + imports: base_disable_all: configs/ptq/units/base_disable_all default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers @@ -22,8 +24,8 @@ imports: metadata: recipe_type: ptq description: >- - NVFP4 W4A4, FP8 KV cache with constant amax (skips KV calibration; amax - hardcoded to FP8 E4M3 max 448.0), max calibration. + Composes dynamic NVFP4 W4A4 model quantization with FP8 KV-cache cast mode using constant amax; + uses max calibration. quantize: algorithm: max quant_cfg: diff --git a/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml b/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml index 6aabb04a150..6dee51857c8 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Composed PTQ recipe for NVFP4 W4A4 model quantization with KV quantizers disabled and GPTQ calibration. 
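+#
+# Static weight scales mean this recipe requires calibration data. A sketch of checking that
+# programmatically (mirrors the unit tests later in this change; all names are from
+# modelopt.torch.quantization.config):
+#
+#     from modelopt.torch.quantization.config import (
+#         FP8_DEFAULT_CFG, FP8_PER_CHANNEL_PER_TOKEN_CFG, QuantizeConfig, need_calibration,
+#     )
+#     assert need_calibration(QuantizeConfig.model_validate(FP8_DEFAULT_CFG))
+#     assert not need_calibration(QuantizeConfig.model_validate(FP8_PER_CHANNEL_PER_TOKEN_CFG))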
+ imports: base_disable_all: configs/ptq/units/base_disable_all default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers @@ -21,7 +23,9 @@ imports: metadata: recipe_type: ptq - description: NVFP4 weight and activation (W4A4), gptq layerwise calibration. + description: >- + Applies NVFP4 W4A4 with static weight scales, dynamic inputs, KV quantizers disabled, and GPTQ + layerwise calibration. quantize: algorithm: method: gptq diff --git a/modelopt_recipes/general/ptq/nvfp4_default-kv_nvfp4_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_default-kv_nvfp4_cast.yaml index d9991e0b9c3..0acdf6050db 100644 --- a/modelopt_recipes/general/ptq/nvfp4_default-kv_nvfp4_cast.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_default-kv_nvfp4_cast.yaml @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Composed PTQ recipe for dynamic NVFP4 W4A4 model quantization with NVFP4 KV-cache cast mode. + imports: base_disable_all: configs/ptq/units/base_disable_all default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers @@ -22,10 +24,8 @@ imports: metadata: recipe_type: ptq description: >- - NVFP4 W4A4, NVFP4 KV cache with constant amax (skips KV calibration; amax - hardcoded to FP8 E4M3 max 448.0 — the deployment kernel upcasts NVFP4 KV - values to FP8 before attention, so the scale must land in the FP8 range), - max calibration. + Composes dynamic NVFP4 W4A4 model quantization with NVFP4 KV-cache cast mode using constant + amax; uses max calibration. quantize: algorithm: max quant_cfg: diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml index 6222ab39e3a..08864c8a50d 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Composed PTQ recipe for expert-only dynamic NVFP4 quantization with FP8 KV-cache quantization. + imports: base_disable_all: configs/ptq/units/base_disable_all default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers @@ -21,7 +23,9 @@ imports: metadata: recipe_type: ptq - description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max layerwise calibration. + description: >- + Applies dynamic NVFP4 only to expert-layer weight and input quantizers, plus FP8 KV-cache + quantization; uses max calibration. quantize: algorithm: method: max diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml index 5db1666402d..5bf9a36dc31 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only_mse-kv_fp8_cast.yaml @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Composed PTQ recipe for expert-only NVFP4 quantization with MSE weight calibration and FP8 KV-cache cast mode. + imports: base_disable_all: configs/ptq/units/base_disable_all default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers @@ -22,7 +24,9 @@ imports: metadata: recipe_type: ptq - description: NVFP4 static weight (MSE FP8-scale sweep) and dynamic activation for expert layers only (W4A4), FP8 KV cache with constant amax. 
+  description: >-
+    Applies static NVFP4 weight scales from an MSE FP8-scale sweep and dynamic NVFP4 inputs to
+    expert layers only, plus FP8 KV-cache cast mode.
 quantize:
   algorithm:
     method: mse
diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only-kv_fp8.yaml b/modelopt_recipes/general/ptq/nvfp4_mlp_only-kv_fp8.yaml
index 60cba464e0c..a4cf71a1dbd 100644
--- a/modelopt_recipes/general/ptq/nvfp4_mlp_only-kv_fp8.yaml
+++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only-kv_fp8.yaml
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Composed PTQ recipe for MLP/MoE-only dynamic NVFP4 quantization with FP8 KV-cache quantization.
+
 imports:
   base_disable_all: configs/ptq/units/base_disable_all
   default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers
@@ -21,7 +23,9 @@ imports:
 
 metadata:
   recipe_type: ptq
-  description: NVFP4 static weight and dynamic activation for all linear layers (W4A4), FP8 KV cache, max calibration.
+  description: >-
+    Applies dynamic NVFP4 only to MLP/MoE weight and input quantizers, plus FP8 KV-cache
+    quantization; uses max calibration.
 quantize:
   algorithm: max
   quant_cfg:
diff --git a/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml
index 875fb47c9b3..2ea2c0ab13e 100644
--- a/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml
+++ b/modelopt_recipes/general/ptq/nvfp4_mlp_only_mse-kv_fp8_cast.yaml
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Composed PTQ recipe for MLP/MoE-only NVFP4 quantization with MSE weight calibration and FP8 KV-cache cast mode.
+
 imports:
   base_disable_all: configs/ptq/units/base_disable_all
   default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers
@@ -22,7 +24,9 @@ imports:
 
 metadata:
   recipe_type: ptq
-  description: NVFP4 static weight (MSE FP8-scale sweep) and dynamic activation for MLP/MoE linear layers (W4A4), FP8 KV cache with constant amax.
+  description: >-
+    Applies static NVFP4 weight scales from an MSE FP8-scale sweep and dynamic NVFP4 inputs to
+    MLP/MoE layers, plus FP8 KV-cache cast mode.
 quantize:
   algorithm:
     method: mse
diff --git a/modelopt_recipes/general/ptq/nvfp4_omlp_only-kv_fp8.yaml b/modelopt_recipes/general/ptq/nvfp4_omlp_only-kv_fp8.yaml
index 13c7cac0797..5348e8c7123 100644
--- a/modelopt_recipes/general/ptq/nvfp4_omlp_only-kv_fp8.yaml
+++ b/modelopt_recipes/general/ptq/nvfp4_omlp_only-kv_fp8.yaml
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Composed PTQ recipe for output-projection and MLP/MoE dynamic NVFP4 quantization with FP8 KV-cache quantization.
+
 imports:
   base_disable_all: configs/ptq/units/base_disable_all
   default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers
@@ -21,7 +23,9 @@ imports:
 
 metadata:
   recipe_type: ptq
-  description: NVFP4 static weight and dynamic activation for all linear layers including output projections, FP8 KV cache, max calibration.
+  description: >-
+    Applies dynamic NVFP4 to output-projection and MLP/MoE weight and input quantizers, plus
+    FP8 KV-cache quantization; uses max calibration.
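+
+# Illustrative shape of one resolved entry after schema parsing (values abbreviated; see
+# modelopt.torch.quantization.config.QuantizerCfgEntry):
+#   quantizer_name: '*weight_quantizer'  # wildcard pattern matched against quantizer paths
+#   cfg: {num_bits: [2, 1], ...}         # parsed into a QuantizerAttributeConfig
+#   enable: true                         # schema default when omitted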
quantize: algorithm: max quant_cfg: diff --git a/modelopt_recipes/general/speculative_decoding/dflash.yaml b/modelopt_recipes/general/speculative_decoding/dflash.yaml index 3d43e0fe1d4..d6458a9b26e 100644 --- a/modelopt_recipes/general/speculative_decoding/dflash.yaml +++ b/modelopt_recipes/general/speculative_decoding/dflash.yaml @@ -1,4 +1,4 @@ -# Base config for DFlash training. Override fields via OmegaConf dotlist on the CLI. +# DFlash speculative-decoding training recipe. Override fields via OmegaConf dotlist on the CLI. # maps to ModelArguments (main.py) model: diff --git a/modelopt_recipes/general/speculative_decoding/eagle3.yaml b/modelopt_recipes/general/speculative_decoding/eagle3.yaml index a1b7ff77708..fb9484a909f 100644 --- a/modelopt_recipes/general/speculative_decoding/eagle3.yaml +++ b/modelopt_recipes/general/speculative_decoding/eagle3.yaml @@ -1,4 +1,4 @@ -# Base config for EAGLE3 training. Override fields via OmegaConf dotlist on the CLI. +# EAGLE3 speculative-decoding training recipe. Override fields via OmegaConf dotlist on the CLI. # maps to ModelArguments (main.py) model: diff --git a/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml b/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml index c00aff7d44f..17eb0d7a716 100644 --- a/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml +++ b/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml @@ -13,9 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Model-specific PTQ recipe for Step3.5-Flash NVFP4 MLP/MoE quantization with FP8 KV cache. + metadata: recipe_type: ptq - description: NVFP4 static weight and dynamic activation for MoE/MLP projections (W4A4), FP8 KV cache, max calibration. + description: >- + Step3.5-Flash PTQ recipe that enables dynamic NVFP4 on MoE/MLP weight and input quantizers, + enables FP8 KV-cache quantizers, and leaves other quantizers disabled. quantize: algorithm: max quant_cfg: diff --git a/tests/examples/diffusers/test_diffusers.py b/tests/examples/diffusers/test_diffusers.py index 5b117b41b3f..637214f736f 100644 --- a/tests/examples/diffusers/test_diffusers.py +++ b/tests/examples/diffusers/test_diffusers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
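+# Note on the loader below: the examples tree is not an installed package, and the diffusers
+# config.py imports its sibling calib package, so the module is executed from source with its
+# directory temporarily placed on sys.path.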
+import copy +import importlib.util +import sys from pathlib import Path from typing import NamedTuple @@ -22,6 +25,44 @@ from _test_utils.torch.misc import minimum_sm +def _load_diffusers_quantization_config_module(): + quantization_dir = ( + Path(__file__).resolve().parents[3] / "examples" / "diffusers" / "quantization" + ) + sys.path.insert(0, str(quantization_dir)) + try: + spec = importlib.util.spec_from_file_location( + "diffusers_quantization_config", quantization_dir / "config.py" + ) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + finally: + sys.path.remove(str(quantization_dir)) + return module + + +def test_diffusers_quant_config_attr_uses_explicit_schema_keys() -> None: + import modelopt.torch.quantization as mtq + + config_module = _load_diffusers_quantization_config_module() + quant_config = copy.deepcopy(mtq.INT8_SMOOTHQUANT_CFG) + input_cfg = next( + entry["cfg"] + for entry in quant_config["quant_cfg"] + if entry["quantizer_name"] == "*input_quantizer" + ) + + assert "trt_high_precision_dtype" in input_cfg + assert "trt_high_precision_dtype" not in input_cfg.explicit_keys() + + config_module.set_quant_config_attr(quant_config, "Half", "smoothquant", alpha=0.8) + + assert input_cfg["trt_high_precision_dtype"] == "Half" + assert quant_config["algorithm"] == {"method": "smoothquant", "alpha": 0.8} + + class DiffuserModel(NamedTuple): dtype: str name: str diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index a5c8ccaf479..382805ac388 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -16,11 +16,17 @@ """Unit tests for modelopt.recipe.loader and modelopt.recipe.loader.load_config.""" import re +from importlib.resources import files import pytest -from modelopt.recipe.config import ModelOptPTQRecipe, RecipeType +from modelopt.recipe.config import ModelOptPTQRecipe, RecipeMetadataConfig, RecipeType from modelopt.recipe.loader import load_config, load_recipe +from modelopt.torch.quantization.config import ( + QuantizeConfig, + QuantizerAttributeConfig, + QuantizerCfgEntry, +) # --------------------------------------------------------------------------- # Static YAML fixtures @@ -75,6 +81,20 @@ def _write_quantizer_cfg_list(path, body: str): path.write_text(QUANTIZER_CFG_LIST_SCHEMA + body) +def _cfg_to_dict(cfg): + if isinstance(cfg, QuantizerAttributeConfig): + return cfg.model_dump(exclude_unset=True) + if isinstance(cfg, list): + return [_cfg_to_dict(item) for item in cfg] + return cfg + + +def _entry_to_dict(entry): + if isinstance(entry, QuantizerCfgEntry): + return entry.model_dump(exclude_unset=True) + return dict(entry) + + # --------------------------------------------------------------------------- # Directory-format YAML fixtures # --------------------------------------------------------------------------- @@ -96,6 +116,32 @@ def test_load_config_suffix_probe(tmp_path): assert load_config(str(tmp_path / "mycfg")) == {"key": "val"} +def test_load_config_schema_type_returns_validated_type(tmp_path): + """schema_type validates and returns the parsed schema value.""" + cfg_file = tmp_path / "quantize.yml" + cfg_file.write_text( + "algorithm: max\n" + "quant_cfg:\n" + " - quantizer_name: '*weight_quantizer'\n" + " cfg:\n" + " num_bits: 8\n" + " axis: 0\n" + ) + data = load_config(cfg_file, schema_type=QuantizeConfig) + assert isinstance(data, QuantizeConfig) + assert _cfg_to_dict(data.quant_cfg[0]["cfg"]) == 
{"num_bits": 8, "axis": 0} + + +def test_load_config_recipe_metadata_returns_validated_type(tmp_path): + """Recipe metadata schema validates and returns the parsed metadata model.""" + cfg_file = tmp_path / "metadata.yml" + cfg_file.write_text("recipe_type: ptq\n") + data = load_config(cfg_file, schema_type=RecipeMetadataConfig) + assert isinstance(data, RecipeMetadataConfig) + assert data.recipe_type == RecipeType.PTQ + assert data.description == "Model optimization recipe." + + def test_load_config_missing_file_raises(tmp_path): """load_config raises ValueError for a path that does not exist.""" with pytest.raises(ValueError, match="Cannot find config file"): @@ -248,6 +294,7 @@ def _normalize_fpx(val): YAML always uses the string form. Both are converted to ``[E, M]`` so the comparison is representation-agnostic. """ + val = _cfg_to_dict(val) if isinstance(val, str): m = re.fullmatch(r"e(\d+)m(\d+)", val) if m: @@ -256,6 +303,8 @@ def _normalize_fpx(val): return list(val) if isinstance(val, dict): return {str(k): _normalize_fpx(v) for k, v in val.items()} + if isinstance(val, list): + return [_normalize_fpx(v) for v in val] return val def _normalize_entries(raw_entries): @@ -302,7 +351,7 @@ def test_import_resolves_cfg_reference(tmp_path): ) recipe = load_recipe(recipe_file) entry = recipe.quantize["quant_cfg"][0] - assert entry["cfg"] == {"num_bits": (4, 3), "axis": None} + assert _cfg_to_dict(entry["cfg"]) == {"num_bits": (4, 3), "axis": None} def test_import_same_name_used_twice(tmp_path): @@ -325,7 +374,9 @@ def test_import_same_name_used_twice(tmp_path): f" $import: fp8\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"][0]["cfg"] == recipe.quantize["quant_cfg"][1]["cfg"] + assert _cfg_to_dict(recipe.quantize["quant_cfg"][0]["cfg"]) == _cfg_to_dict( + recipe.quantize["quant_cfg"][1]["cfg"] + ) def test_import_multiple_snippets(tmp_path): @@ -375,7 +426,7 @@ def test_import_inline_cfg_not_affected(tmp_path): f" axis: 0\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": 8, "axis": 0} + assert _cfg_to_dict(recipe.quantize["quant_cfg"][1]["cfg"]) == {"num_bits": 8, "axis": 0} def test_import_unknown_reference_raises(tmp_path): @@ -511,7 +562,10 @@ def test_import_entry_element_schema_appends(tmp_path): f" - $import: disable_all\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"] == [{"quantizer_name": "*", "cfg": None, "enable": False}] + assert _entry_to_dict(recipe.quantize["quant_cfg"][0]) == { + "quantizer_name": "*", + "enable": False, + } def test_import_entry_wrong_schema_raises(tmp_path): @@ -596,7 +650,7 @@ def test_import_cfg_extend(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "axis": 0} + assert _cfg_to_dict(cfg) == {"num_bits": (4, 3), "axis": 0} def test_import_cfg_inline_overrides_import(tmp_path): @@ -659,7 +713,7 @@ def test_import_in_multiple_dict_values(tmp_path): ) data = load_config(config_file) entry = data["quant_cfg"][0] - assert entry["cfg"] == {"num_bits": (4, 3)} + assert _cfg_to_dict(entry["cfg"]) == {"num_bits": (4, 3)} assert entry["my_field"] == {"fake_quant": False} @@ -683,7 +737,7 @@ def test_import_cfg_multi_import(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "axis": 0} + assert _cfg_to_dict(cfg) == {"num_bits": (4, 3), "axis": 0} def 
test_import_cfg_multi_import_later_overrides_earlier(tmp_path): @@ -732,7 +786,7 @@ def test_import_cfg_multi_import_with_extend(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "fake_quant": False, "axis": 0} + assert _cfg_to_dict(cfg) == {"num_bits": (4, 3), "fake_quant": False, "axis": 0} def test_import_dir_format(tmp_path): @@ -749,7 +803,10 @@ def test_import_dir_format(tmp_path): " $import: fp8\n" ) recipe = load_recipe(tmp_path) - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3), "axis": None} + assert _cfg_to_dict(recipe.quantize["quant_cfg"][0]["cfg"]) == { + "num_bits": (4, 3), + "axis": None, + } def test_import_dir_format_metadata_imports_do_not_apply_to_quantize(tmp_path): @@ -803,7 +860,7 @@ def test_import_multi_document_list_snippet(tmp_path): recipe = load_recipe(recipe_file) assert len(recipe.quantize["quant_cfg"]) == 1 assert recipe.quantize["quant_cfg"][0]["quantizer_name"] == "*[kv]_bmm_quantizer" - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert _cfg_to_dict(recipe.quantize["quant_cfg"][0]["cfg"]) == {"num_bits": (4, 3)} def test_import_builtin_kv_fp8_snippet(): @@ -852,7 +909,8 @@ def test_import_list_splice_outside_typed_list_raises(tmp_path): """A bare $import in an untyped list is rejected.""" _write_quantizer_cfg_list( tmp_path / "extra_tasks.yml", - "- quantizer_name: '*weight_quantizer'\n- quantizer_name: '*input_quantizer'\n", + "- quantizer_name: '*weight_quantizer'\n enable: false\n" + "- quantizer_name: '*input_quantizer'\n enable: false\n", ) config_file = tmp_path / "config.yml" config_file.write_text( @@ -914,9 +972,12 @@ def test_import_mixed_tree(tmp_path): ) data = load_config(config_file) # Dict import inside list entry - assert data["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert _cfg_to_dict(data["quant_cfg"][0]["cfg"]) == {"num_bits": (4, 3)} # List splice - assert data["quant_cfg"][1] == {"quantizer_name": "*lm_head*", "enable": False} + assert _entry_to_dict(data["quant_cfg"][1]) == { + "quantizer_name": "*lm_head*", + "enable": False, + } # --------------------------------------------------------------------------- @@ -955,7 +1016,7 @@ def test_import_recursive(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3)} + assert _cfg_to_dict(cfg) == {"num_bits": (4, 3)} def test_import_circular_raises(tmp_path): @@ -1055,9 +1116,12 @@ def test_import_cross_file_same_name_no_conflict(tmp_path): ) recipe = load_recipe(recipe_file) # Parent's "fmt" resolves to fp8 (e4m3), not child's nvfp4. - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert _cfg_to_dict(recipe.quantize["quant_cfg"][0]["cfg"]) == {"num_bits": (4, 3)} # Child's "fmt" resolves to nvfp4 (e2m1), not parent's fp8. 
- assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": (2, 1), "axis": 0} + assert _cfg_to_dict(recipe.quantize["quant_cfg"][1]["cfg"]) == { + "num_bits": (2, 1), + "axis": 0, + } # --------------------------------------------------------------------------- @@ -1065,20 +1129,20 @@ def test_import_cross_file_same_name_no_conflict(tmp_path): # --------------------------------------------------------------------------- -_BUILTIN_CONFIG_SNIPPETS = [ - "configs/numerics/fp8", - "configs/numerics/nvfp4", - "configs/numerics/nvfp4_static", - "configs/ptq/units/base_disable_all", - "configs/ptq/units/default_disabled_quantizers", - "configs/ptq/units/kv_fp8", - "configs/ptq/units/kv_fp8_cast", - "configs/ptq/units/kv_nvfp4_cast", - "configs/ptq/units/w4a4_nvfp4_nvfp4", - "configs/ptq/units/w8a8_fp8_fp8", - "configs/ptq/presets/kv/fp8", - "configs/ptq/presets/model/fp8", -] +def _iter_builtin_config_snippets(root): + """Yield built-in config YAML files that declare a modelopt schema.""" + for child in sorted(root.iterdir(), key=lambda path: path.name): + if child.is_dir(): + yield from _iter_builtin_config_snippets(child) + elif child.name.endswith((".yaml", ".yml")) and "modelopt-schema:" in child.read_text( + encoding="utf-8" + ): + yield child + + +_BUILTIN_CONFIG_SNIPPETS = list( + _iter_builtin_config_snippets(files("modelopt_recipes").joinpath("configs")) +) @pytest.mark.parametrize("config_path", _BUILTIN_CONFIG_SNIPPETS) @@ -1088,8 +1152,8 @@ def test_builtin_config_snippets_with_modelopt_schema(config_path): assert data -def test_modelopt_schema_comment_validates_without_changing_payload(tmp_path): - """modelopt-schema validates the resolved payload but load_config still returns a plain dict.""" +def test_modelopt_schema_comment_returns_validated_type(tmp_path): + """modelopt-schema validates and returns the parsed schema value.""" config_file = tmp_path / "fp8.yaml" config_file.write_text( "# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig\n" @@ -1097,7 +1161,8 @@ def test_modelopt_schema_comment_validates_without_changing_payload(tmp_path): "axis:\n" ) data = load_config(config_file) - assert data == {"num_bits": (4, 3), "axis": None} + assert isinstance(data, QuantizerAttributeConfig) + assert data.model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": None} def test_modelopt_schema_comment_validation_error(tmp_path): @@ -1144,11 +1209,72 @@ def test_modelopt_schema_comment_validates_after_import_resolution(tmp_path): f" $import: fp8\n" ) data = load_config(config_file) - assert data == [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": (4, 3)}}] + assert data[0]["quantizer_name"] == "*weight_quantizer" + assert _cfg_to_dict(data[0]["cfg"]) == {"num_bits": (4, 3)} + + +def test_import_dict_snippet_imports_in_union_typed_list_field(tmp_path): + """A bare import can append into QuantizerCfgEntry.cfg's list branch.""" + (tmp_path / "int4.yaml").write_text( + "# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig\n" + "num_bits: 4\n" + "block_sizes:\n" + " -1: 128\n" + " type: static\n" + ) + (tmp_path / "fp8.yaml").write_text( + "# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig\n" + "num_bits: e4m3\n" + ) + config_file = tmp_path / "config.yaml" + config_file.write_text( + f"# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig\n" + f"imports:\n" + f" int4: {tmp_path / 'int4.yaml'}\n" + f" fp8: {tmp_path / 'fp8.yaml'}\n" + f"algorithm: awq_lite\n" + f"quant_cfg:\n" 
+ f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" - $import: int4\n" + f" - $import: fp8\n" + ) + + data = load_config(config_file) + + assert _cfg_to_dict(data["quant_cfg"][0]["cfg"]) == [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": (4, 3)}, + ] + + +def test_import_dict_snippet_in_union_typed_list_field_with_inline_item(tmp_path): + """A dict snippet can be imported as one item inside QuantizerCfgEntry.cfg list.""" + _write_quantizer_attribute( + tmp_path / "int4.yaml", + "num_bits: 4\nblock_sizes:\n -1: 128\n type: static\n", + ) + config_file = tmp_path / "config.yaml" + config_file.write_text( + f"# modelopt-schema: modelopt.torch.quantization.config.QuantizeConfig\n" + f"imports:\n" + f" int4: {tmp_path / 'int4.yaml'}\n" + f"algorithm: awq_lite\n" + f"quant_cfg:\n" + f" - quantizer_name: '*weight_quantizer'\n" + f" cfg:\n" + f" - $import: int4\n" + f" - num_bits: e4m3\n" + ) + data = load_config(config_file) + assert _cfg_to_dict(data["quant_cfg"][0]["cfg"]) == [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": (4, 3)}, + ] # --------------------------------------------------------------------------- -# Coverage: _load_raw_config edge cases +# Coverage: _load_raw_config_with_schema edge cases # --------------------------------------------------------------------------- @@ -1188,9 +1314,9 @@ def test_load_config_multi_doc_dict_dict(tmp_path): """Multi-document YAML with two dicts merges them.""" cfg_file = tmp_path / "multi.yaml" cfg_file.write_text("imports:\n fp8: some/path\n---\nalgorithm: max\n") - from modelopt.torch.opt.config_loader import _load_raw_config + from modelopt.torch.opt.config_loader import _load_raw_config_with_schema - data = _load_raw_config(cfg_file) + data = _load_raw_config_with_schema(cfg_file).data assert data["imports"] == {"fp8": "some/path"} assert data["algorithm"] == "max" @@ -1199,9 +1325,9 @@ def test_load_config_multi_doc_null_content(tmp_path): """Multi-document YAML where second doc is null treats content as empty dict.""" cfg_file = tmp_path / "multi_null.yaml" cfg_file.write_text("key: value\n---\n") - from modelopt.torch.opt.config_loader import _load_raw_config + from modelopt.torch.opt.config_loader import _load_raw_config_with_schema - data = _load_raw_config(cfg_file) + data = _load_raw_config_with_schema(cfg_file).data assert data == {"key": "value"} @@ -1249,7 +1375,8 @@ def test_load_config_list_valued_yaml(tmp_path): data = load_config(cfg_file) assert isinstance(data, list) assert len(data) == 2 - assert data[0] == {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}} + assert data[0]["quantizer_name"] == "*weight_quantizer" + assert _cfg_to_dict(data[0]["cfg"]) == {"num_bits": 8} # --------------------------------------------------------------------------- @@ -1261,7 +1388,8 @@ def test_import_dict_value_resolves_to_list_raises(tmp_path): """$import in dict value position raises when snippet is a list.""" _write_quantizer_cfg_list( tmp_path / "entries.yml", - "- quantizer_name: '*weight_quantizer'\n- quantizer_name: '*input_quantizer'\n", + "- quantizer_name: '*weight_quantizer'\n enable: false\n" + "- quantizer_name: '*input_quantizer'\n enable: false\n", ) config_file = tmp_path / "config.yml" config_file.write_text( diff --git a/tests/unit/torch/opt/test_config.py b/tests/unit/torch/opt/test_config.py index b2ffadb1a78..e0c5993a51a 100644 --- a/tests/unit/torch/opt/test_config.py +++ b/tests/unit/torch/opt/test_config.py @@ -72,7 +72,7 @@ def 
_run_test(is_new_registered): assert config[lin_name] == lin_expected_value assert config[lin_alias] == lin_expected_value assert getattr(config, lin_name) == lin_expected_value - with nullcontext() if is_new_registered else pytest.raises(AttributeError): + with nullcontext() if is_new_registered else pytest.raises(KeyError): config[new_name] # get diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 87ec73291e7..2c8ae9250e7 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -79,6 +79,22 @@ def test_quant_recipe(quant_cfg, other_quant_cfg, is_less_than): assert qr_this_duplicate in {qr_this} +def test_quant_recipe_custom_quantize_config_requires_name(): + custom_cfg = mtq.QuantizeConfig( + quant_cfg=[ + mtq.QuantizerCfgEntry( + quantizer_name="*weight_quantizer", + cfg=mtq.QuantizerAttributeConfig(num_bits=8, axis=None), + ) + ] + ) + + with pytest.raises(ValueError, match="name must be provided"): + QuantRecipe(custom_cfg) + + assert str(QuantRecipe(custom_cfg, name="custom_cfg")).startswith("custom_cfg(") + + def test_quant_recipe_hparam(): model_test = torch.nn.Linear(4, 16) model_ref = torch.nn.Linear(4, 16) diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index f5b1e576f5e..a86b046455e 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -15,9 +15,13 @@ """Test of quantization config validations.""" +import copy +from collections.abc import MutableMapping + import pytest from pydantic import ValidationError +from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.quantization.config import ( FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, FP8_DEFAULT_CFG, @@ -26,12 +30,24 @@ NVFP4_DEFAULT_CFG, W4A8_AWQ_BETA_CFG, QuantizeConfig, + QuantizerAttributeConfig, + QuantizerCfgEntry, + _base_disable_all, + _default_disabled_quantizer_cfg, find_quant_cfg_entry_by_path, need_calibration, normalize_quant_cfg_list, ) +def _cfg_to_dict(cfg): + if isinstance(cfg, QuantizerAttributeConfig): + return cfg.model_dump(exclude_unset=True) + if isinstance(cfg, list): + return [_cfg_to_dict(item) for item in cfg] + return cfg + + def test_need_calibration(): assert need_calibration(FP8_DEFAULT_CFG) assert not need_calibration(FP8_PER_CHANNEL_PER_TOKEN_CFG) @@ -41,6 +57,13 @@ def test_need_calibration(): assert need_calibration(NVFP4_DEFAULT_CFG) +def test_need_calibration_with_quantize_config_type(): + """need_calibration accepts schema-backed QuantizeConfig objects.""" + assert need_calibration(QuantizeConfig()) + assert need_calibration(QuantizeConfig.model_validate(FP8_DEFAULT_CFG)) + assert not need_calibration(QuantizeConfig.model_validate(FP8_PER_CHANNEL_PER_TOKEN_CFG)) + + def test_need_calibration_with_list_cfg(): """need_calibration must handle sequential (list) cfg entries without crashing.""" # Static list-cfg on a non-weight quantizer → needs calibration @@ -73,15 +96,186 @@ def test_need_calibration_with_list_cfg(): assert not need_calibration(cfg_dynamic) +def test_quantizer_cfg_entry_is_pydantic_and_dict_like(): + """QuantizerCfgEntry is typed but keeps the dict-style access used by callers.""" + entry = QuantizerCfgEntry(quantizer_name="*", enable=False) + assert isinstance(entry, ModeloptBaseConfig) + assert entry["quantizer_name"] == "*" + assert entry["cfg"] is None + assert "cfg" in entry + 
assert list(entry) == ["quantizer_name", "parent_class", "cfg", "enable"] + assert dict(entry.items()) == { + "quantizer_name": "*", + "parent_class": None, + "cfg": None, + "enable": False, + } + assert dict(entry.explicit_items()) == {"quantizer_name": "*", "enable": False} + with pytest.raises(KeyError): + entry["unknown"] = 1 + assert entry.model_dump(exclude_unset=True) == {"quantizer_name": "*", "enable": False} + + cfg_entry = QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={"num_bits": 8}) + assert isinstance(cfg_entry["cfg"], QuantizerAttributeConfig) + assert _cfg_to_dict(cfg_entry["cfg"]) == {"num_bits": 8} + + +def test_quantizer_cfg_entry_mutable_mapping_rejects_key_deletion(): + """ModeloptBaseConfig mappings have a fixed key set and reject deletion.""" + entry = QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={"num_bits": 8}, enable=True) + assert isinstance(entry, MutableMapping) + assert entry.model_dump(exclude_unset=True) == { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": 8}, + "enable": True, + } + + with pytest.raises(TypeError): + del entry["cfg"] + + assert "cfg" in entry + assert entry["cfg"] is not None + assert entry.model_dump(exclude_unset=True) == { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": 8}, + "enable": True, + } + + with pytest.raises(KeyError): + del entry["missing"] + + +def test_public_preset_quant_cfg_entries_are_typed_and_dict_like(): + """Public preset constants are typed but keep dict-style entry access.""" + assert isinstance(FP8_DEFAULT_CFG, QuantizeConfig) + assert isinstance(NVFP4_DEFAULT_CFG, QuantizeConfig) + for preset in (FP8_DEFAULT_CFG, NVFP4_DEFAULT_CFG): + assert all(isinstance(entry, QuantizerCfgEntry) for entry in preset["quant_cfg"]) + for entry in preset["quant_cfg"]: + assert entry["quantizer_name"] == entry.quantizer_name + assert dict(entry.items())["quantizer_name"] == entry.quantizer_name + + +def test_mixed_raw_dict_and_modelopt_config_entries_normalize_after_mutation(): + """Mixed raw dict and ModeloptBaseConfig entries normalize after mutation.""" + config = { + "quant_cfg": [ + *copy.deepcopy(_base_disable_all), + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "backend": "custom_backend", + "num_bits": "e5m2", + "pass_through_bwd": True, + "backend_extra_args": { + "format": "scalar", + "block_sizes": 16, + }, + }, + }, + {"quantizer_name": "*input_quantizer", "enable": False}, + *copy.deepcopy(_default_disabled_quantizer_cfg), + ], + "algorithm": "max", + } + + normalized = normalize_quant_cfg_list(config["quant_cfg"]) + weight_entry = find_quant_cfg_entry_by_path(normalized, "*weight_quantizer") + assert weight_entry["cfg"]["num_bits"] == "e5m2" + + raw_weight_entry = find_quant_cfg_entry_by_path(config["quant_cfg"], "*weight_quantizer") + raw_weight_entry["cfg"]["num_bits"] = "e2m1" + normalized = normalize_quant_cfg_list(config["quant_cfg"]) + weight_entry = find_quant_cfg_entry_by_path(normalized, "*weight_quantizer") + assert weight_entry["cfg"]["num_bits"] == "e2m1" + + weight_entry["cfg"]["num_bits"] = "e4m3" + renormalized = normalize_quant_cfg_list(normalized) + weight_entry = find_quant_cfg_entry_by_path(renormalized, "*weight_quantizer") + assert weight_entry["cfg"]["num_bits"] == "e4m3" + + +@pytest.mark.parametrize( + ("raw", "match"), + [ + ({"quantizer_name": "*", "cfg": None}, "'?cfg'? must be omitted"), + ({"quantizer_name": "*", "enable": None}, "'?enable'? 
must be a boolean"), + ], +) +def test_quantizer_cfg_entry_rejects_explicit_null_values(raw, match): + """Explicit null cfg/enable values are rejected instead of treated as omitted.""" + with pytest.raises(ValidationError, match=match): + QuantizerCfgEntry.model_validate(raw) + + with pytest.raises(ValueError, match=match): + normalize_quant_cfg_list([raw]) + + +def test_quantizer_cfg_entry_defaults_enable_true(): + """Direct QuantizerCfgEntry construction uses enable=True when omitted.""" + entry = QuantizerCfgEntry(quantizer_name="*") + assert entry["enable"] is True + assert entry["cfg"] is None + assert dict(entry.explicit_items()) == {"quantizer_name": "*"} + assert entry.model_dump(exclude_unset=True) == {"quantizer_name": "*"} + + +def test_quantizer_cfg_entry_rejects_empty_name(): + """Direct QuantizerCfgEntry construction rejects empty quantizer names.""" + with pytest.raises(ValidationError, match="non-empty string"): + QuantizerCfgEntry(quantizer_name="", enable=False) + + +def test_quantizer_cfg_entry_rejects_empty_cfg_when_enabled(): + """Direct QuantizerCfgEntry construction rejects empty enabled cfg values.""" + with pytest.raises(ValidationError, match="non-empty dict"): + QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={}) + + +def test_quantizer_cfg_entry_treats_empty_disabled_cfg_as_disable_only(): + """Empty cfg with enable=False remains a disable-only entry.""" + entry = QuantizerCfgEntry(quantizer_name="*input_quantizer", cfg={}, enable=False) + assert entry["cfg"] is None + assert entry["enable"] is False + + class TestNormalizeQuantCfgList: def test_new_format_passthrough(self): - """New-format entries are returned unchanged (only canonical defaults added).""" + """New-format dict entries are normalized into QuantizerCfgEntry objects.""" raw = [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}] result = normalize_quant_cfg_list(raw) assert len(result) == 1 + assert isinstance(result[0], QuantizerCfgEntry) assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} - assert result[0]["enable"] is True # defaulted + assert isinstance(result[0]["cfg"], QuantizerAttributeConfig) + assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": 0} + assert result[0]["enable"] is True # schema default + assert "enable" not in dict(result[0].explicit_items()) + + def test_typed_entry_list_passthrough(self): + """Already-parsed QuantizerCfgEntry lists are returned unchanged.""" + raw = [ + QuantizerCfgEntry( + quantizer_name="*weight_quantizer", + cfg=QuantizerAttributeConfig(num_bits=8, axis=0), + enable=True, + ) + ] + result = normalize_quant_cfg_list(raw) + assert result is raw + assert result[0] is raw[0] + + def test_mixed_typed_and_dict_entries_normalize_to_typed_entries(self): + """Mixed QuantizerCfgEntry/dict input lists normalize dicts and preserve typed entries.""" + typed_entry = QuantizerCfgEntry(quantizer_name="*", enable=False) + result = normalize_quant_cfg_list( + [typed_entry, {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}}] + ) + assert result[0] is typed_entry + assert isinstance(result[1], QuantizerCfgEntry) + assert _cfg_to_dict(result[1]["cfg"]) == {"num_bits": 8} + assert result[1]["enable"] is True + assert "enable" not in dict(result[1].explicit_items()) def test_new_format_enable_false(self): """Explicit enable=False is preserved.""" @@ -102,8 +296,9 @@ def test_legacy_single_key_dict(self): raw = [{"*weight_quantizer": {"num_bits": 8, "axis": 
0}}] result = normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} - assert result[0]["enable"] is True # defaulted + assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": 0} + assert result[0]["enable"] is True # schema default + assert "enable" not in dict(result[0].explicit_items()) def test_legacy_single_key_dict_with_enable(self): """Legacy {'*path': {'enable': False}} splits enable out from cfg.""" @@ -122,17 +317,19 @@ def test_legacy_nn_class_scoped(self): assert result[0]["enable"] is False def test_normalization_cfg_defaults_to_none(self): - """Entries without cfg get cfg=None after normalization.""" + """Entries without cfg expose the default mapping key but keep it unset.""" raw = [{"quantizer_name": "*lm_head*", "enable": False}] result = normalize_quant_cfg_list(raw) assert "cfg" in result[0] assert result[0]["cfg"] is None + assert "cfg" not in dict(result[0].explicit_items()) def test_normalization_enable_defaults_to_true(self): - """Entries with cfg but no enable get enable=True after normalization.""" + """Entries with cfg but no enable read as enable=True without marking it explicit.""" raw = [{"quantizer_name": "*", "cfg": {"num_bits": 4}}] result = normalize_quant_cfg_list(raw) assert result[0]["enable"] is True + assert "enable" not in dict(result[0].explicit_items()) def test_empty_list(self): """Empty list is returned unchanged.""" @@ -148,10 +345,13 @@ def test_multiple_entries_order_preserved(self): assert result[0]["quantizer_name"] == "*" assert result[1]["quantizer_name"] == "*weight_quantizer" - def test_error_on_quantizer_name_only(self): - """Entry with only quantizer_name and no cfg or enable is rejected.""" - with pytest.raises(ValueError, match="must specify 'cfg', 'enable'"): - normalize_quant_cfg_list([{"quantizer_name": "*"}]) + def test_quantizer_name_only_defaults_enable_true(self): + """Entry with only quantizer_name uses enable=True from the schema default.""" + result = normalize_quant_cfg_list([{"quantizer_name": "*"}]) + assert result[0]["enable"] is True + assert result[0]["cfg"] is None + assert dict(result[0].explicit_items()) == {"quantizer_name": "*"} + assert result[0].model_dump(exclude_unset=True) == {"quantizer_name": "*"} def test_error_on_empty_dict(self): """An empty dict entry is rejected.""" @@ -230,7 +430,10 @@ def test_new_format_with_list_cfg(self): ] result = normalize_quant_cfg_list(raw) assert len(result) == 1 - assert result[0]["cfg"] == raw[0]["cfg"] + assert isinstance(result[0], QuantizerCfgEntry) + assert isinstance(result[0]["cfg"], list) + assert all(isinstance(cfg, QuantizerAttributeConfig) for cfg in result[0]["cfg"]) + assert _cfg_to_dict(result[0]["cfg"]) == raw[0]["cfg"] assert result[0]["enable"] is True def test_legacy_flat_dict_conversion(self): @@ -242,8 +445,9 @@ def test_legacy_flat_dict_conversion(self): assert result[0]["enable"] is False assert result[0]["cfg"] is None assert result[1]["quantizer_name"] == "*weight_quantizer" - assert result[1]["cfg"] == {"num_bits": 8, "axis": 0} + assert _cfg_to_dict(result[1]["cfg"]) == {"num_bits": 8, "axis": 0} assert result[1]["enable"] is True + assert "enable" not in dict(result[1].explicit_items()) def test_legacy_enable_only_produces_cfg_none(self): """Legacy {'*': {'enable': False}} should produce cfg=None, not cfg={}.""" @@ -273,7 +477,7 @@ def test_legacy_default_key_with_cfg(self): raw = [{"default": {"num_bits": 8, "axis": None}}] result = 
normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*" - assert result[0]["cfg"] == {"num_bits": 8, "axis": None} + assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": None} assert result[0]["enable"] is True def test_legacy_flat_dict_with_default_key(self): @@ -308,7 +512,27 @@ def test_legacy_nn_class_with_cfg(self): assert len(result) == 1 assert result[0]["parent_class"] == "nn.Linear" assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 4, "axis": 0} + assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 4, "axis": 0} + assert result[0]["enable"] is True + + def test_legacy_nn_class_with_list_valued_cfg(self): + """Legacy nn.* scoped format preserves list-valued SequentialQuantizer cfg.""" + raw = [ + { + "nn.Linear": { + "*weight_quantizer": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": 0}, + ] + } + } + ] + result = normalize_quant_cfg_list(raw) + assert len(result) == 1 + assert result[0]["parent_class"] == "nn.Linear" + assert result[0]["quantizer_name"] == "*weight_quantizer" + assert isinstance(result[0]["cfg"], list) + assert _cfg_to_dict(result[0]["cfg"]) == raw[0]["nn.Linear"]["*weight_quantizer"] assert result[0]["enable"] is True def test_legacy_list_valued_cfg(self): @@ -342,7 +566,7 @@ def test_finds_last_match(self): ] ) result = find_quant_cfg_entry_by_path(entries, "*weight_quantizer") - assert result["cfg"] == {"num_bits": 4} + assert _cfg_to_dict(result["cfg"]) == {"num_bits": 4} def test_exact_match_only(self): """Does not do fnmatch — only exact string equality on quantizer_name.""" @@ -399,7 +623,7 @@ def test_wildcard_matches_bare_name(self): [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}}] ) matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 8} + assert _cfg_to_dict(matched) == {"num_bits": 8} assert enable is True def test_star_matches_any_bare_name(self): @@ -419,7 +643,7 @@ def test_path_scoped_pattern_matches_matching_suffix(self): [{"quantizer_name": "*mlp*weight_quantizer", "cfg": {"num_bits": 4}}] ) matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 4} + assert _cfg_to_dict(matched) == {"num_bits": 4} def test_path_scoped_pattern_does_not_match_different_suffix(self): """'*mlp*weight_quantizer' does NOT match bare 'input_quantizer'.""" @@ -443,7 +667,7 @@ def test_last_match_wins(self): ] ) matched, _ = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 4} + assert _cfg_to_dict(matched) == {"num_bits": 4} def test_no_match_returns_none(self): """No matching entry returns (None, None).""" @@ -525,3 +749,44 @@ def test_validate_quant_cfg_entries_accepts_valid_cfg(self): algorithm="max", ) assert len(cfg.quant_cfg) == 2 + + def test_quant_cfg_parses_dict_cfg_to_pydantic_type(self): + """Python dict cfg input is accepted and parsed to QuantizerAttributeConfig.""" + cfg = QuantizeConfig( + quant_cfg=[ + {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + ], + algorithm="max", + ) + attr_cfg = cfg.quant_cfg[0]["cfg"] + assert isinstance(attr_cfg, QuantizerAttributeConfig) + assert attr_cfg.model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} + + def test_quant_cfg_parses_list_of_dict_cfg_to_pydantic_type(self): + """Python list-of-dict cfg input is accepted and parsed to QuantizerAttributeConfig.""" + cfg = QuantizeConfig( + quant_cfg=[ + { + 
"quantizer_name": "*weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": 0}, + ], + }, + ], + algorithm="max", + ) + attr_cfgs = cfg.quant_cfg[0]["cfg"] + assert isinstance(attr_cfgs, list) + assert all(isinstance(attr_cfg, QuantizerAttributeConfig) for attr_cfg in attr_cfgs) + assert attr_cfgs[0].num_bits == 4 + assert attr_cfgs[1].axis == 0 + + def test_quant_cfg_accepts_pydantic_cfg_instances(self): + """Already-parsed QuantizerAttributeConfig input remains valid.""" + attr_cfg = QuantizerAttributeConfig(num_bits=8, axis=0) + cfg = QuantizeConfig( + quant_cfg=[{"quantizer_name": "*weight_quantizer", "cfg": attr_cfg}], + algorithm="max", + ) + assert cfg.quant_cfg[0]["cfg"] == attr_cfg diff --git a/tests/unit/torch/quantization/test_quantize_cpu.py b/tests/unit/torch/quantization/test_quantize_cpu.py index 301f4cdab1e..0fcac237ccb 100644 --- a/tests/unit/torch/quantization/test_quantize_cpu.py +++ b/tests/unit/torch/quantization/test_quantize_cpu.py @@ -401,6 +401,18 @@ def test_list_attributes_creates_sequential_quantizer(self): assert isinstance(module, SequentialQuantizer) assert len(module) == 2 + def test_sequential_quantizer_rejects_mismatched_attribute_list_length(self): + """SequentialQuantizer rejects partial list configs instead of silently zipping.""" + quantizer = SequentialQuantizer(TensorQuantizer(), TensorQuantizer()) + with pytest.raises(ValueError, match="Expected 2 attribute configs, but got 1"): + quantizer.set_from_attribute_config([QuantizerAttributeConfig(num_bits=8)]) + + def test_sequential_quantizer_rejects_non_mapping_attribute_config(self): + """SequentialQuantizer rejects invalid scalar attribute configs at runtime.""" + quantizer = SequentialQuantizer(TensorQuantizer(), TensorQuantizer()) + with pytest.raises(TypeError, match="attributes must be a list/tuple or a mapping"): + quantizer.set_from_attribute_config(object()) # type: ignore[arg-type] + def test_ordering_later_entry_overrides_earlier(): """Later entries in quant_cfg override earlier ones for the same quantizer."""