From 04f918e865209488f43fed7a8eebc26c23b3c26d Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Wed, 6 May 2026 15:37:46 -0700 Subject: [PATCH 01/14] Type YAML config loading with Pydantic schemas Have load_config return Pydantic-normalized values when schema_type or modelopt-schema is present, including typed recipe metadata and quantization config entries. Update recipe loading, docs, and unit tests for typed config objects and normalized quant_cfg handling. Signed-off-by: Shengliang Xu --- docs/source/guides/_quant_cfg.rst | 23 ++- modelopt/recipe/config.py | 36 ++-- modelopt/recipe/loader.py | 29 +--- modelopt/torch/opt/config_loader.py | 84 +++++++--- modelopt/torch/quantization/config.py | 157 +++++++++++------- modelopt/torch/quantization/conversion.py | 37 +++-- modelopt/torch/quantization/model_quant.py | 8 +- tests/unit/recipe/test_loader.py | 99 ++++++++--- .../quantization/test_config_validation.py | 42 +++++ 9 files changed, 345 insertions(+), 170 deletions(-) diff --git a/docs/source/guides/_quant_cfg.rst b/docs/source/guides/_quant_cfg.rst index 6a027b74b00..4223f29e610 100644 --- a/docs/source/guides/_quant_cfg.rst +++ b/docs/source/guides/_quant_cfg.rst @@ -55,9 +55,12 @@ Each entry in the list is a dictionary with the following fields: (e.g. ``"nn.Linear"``). If omitted, all modules are targeted regardless of class. * - ``cfg`` - No - - A dict of quantizer attributes as defined by :class:`QuantizerAttributeConfig - `, or a list of such dicts - for sequential quantization (see :ref:`sequential-quantizers`). + - A :class:`QuantizerAttributeConfig + `, or a list of + ``QuantizerAttributeConfig`` objects for sequential quantization (see + :ref:`sequential-quantizers`). Equivalent Python dicts, YAML mappings, and lists of + dicts are still accepted for backward compatibility, but those weakly schematized forms + are deprecated. * - ``enable`` - No - ``True`` or ``False``. Toggles matched quantizers on or off, independently of ``cfg``. @@ -74,6 +77,11 @@ Each entry in the list is a dictionary with the following fields: a bare ``{"quantizer_name": "*"}`` would silently behave as ``enable=True`` for all quantizers. + Schema-backed YAML loading parses ``cfg`` mappings into + :class:`QuantizerAttributeConfig ` + values. Plain Python dicts and lists of dicts are accepted only as a backward-compatible, + weakly schematized input format. + ---------- Default Quantizer Configuration @@ -278,7 +286,9 @@ For entirely custom recipes, compose the list directly: Sequential Quantization ======================= -When ``cfg`` is a **list** of attribute dicts, the matched +When ``cfg`` is a **list** of +:class:`QuantizerAttributeConfig ` +values, the matched :class:`TensorQuantizer ` is replaced with a :class:`SequentialQuantizer ` @@ -295,6 +305,11 @@ are quantized first in INT4 and then in FP8: ], } +The list-of-dict spelling shown above remains accepted for existing Python configs and is the +natural YAML spelling, but it is a deprecated weakly schematized input form. After schema-backed +loading or :class:`QuantizeConfig ` parsing, +each element is a ``QuantizerAttributeConfig``. + ---------- .. 
_migrating-from-dict-format: diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 96f33012afd..59a0c69c6de 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -19,9 +19,6 @@ from enum import Enum -from pydantic import field_validator -from typing_extensions import NotRequired, TypedDict - from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.quantization.config import QuantizeConfig @@ -33,14 +30,18 @@ class RecipeType(str, Enum): # QAT = "qat" # Not implemented yet, will be added in the future. -class RecipeMetadataConfig(TypedDict): - """YAML shape of the recipe metadata section.""" +_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe." - recipe_type: RecipeType - description: NotRequired[str] +class RecipeMetadataConfig(ModeloptBaseConfig): + """YAML shape of the recipe metadata section.""" -_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe." + recipe_type: RecipeType + description: str = ModeloptField( + default=_DEFAULT_RECIPE_DESCRIPTION, + title="Description", + description="Human-readable recipe description.", + ) class ModelOptRecipeBase(ModeloptBaseConfig): @@ -50,32 +51,23 @@ class ModelOptRecipeBase(ModeloptBaseConfig): """ metadata: RecipeMetadataConfig = ModeloptField( - default={"recipe_type": RecipeType.PTQ, "description": _DEFAULT_RECIPE_DESCRIPTION}, + default=RecipeMetadataConfig( + recipe_type=RecipeType.PTQ, description=_DEFAULT_RECIPE_DESCRIPTION + ), title="Metadata", description="Recipe metadata containing the recipe type and description.", validate_default=True, ) - @field_validator("metadata") - @classmethod - def validate_metadata(cls, metadata: RecipeMetadataConfig) -> RecipeMetadataConfig: - """Validate recipe metadata and fill defaults for optional fields.""" - if metadata["recipe_type"] not in RecipeType: - raise ValueError( - f"Unsupported recipe type: {metadata['recipe_type']}. " - f"Only {list(RecipeType)} are currently supported." - ) - return {"description": _DEFAULT_RECIPE_DESCRIPTION, **metadata} - @property def recipe_type(self) -> RecipeType: """Return the recipe type from metadata.""" - return self.metadata["recipe_type"] + return self.metadata.recipe_type @property def description(self) -> str: """Return the recipe description from metadata.""" - return self.metadata.get("description", _DEFAULT_RECIPE_DESCRIPTION) + return self.metadata.description class ModelOptPTQRecipe(ModelOptRecipeBase): diff --git a/modelopt/recipe/loader.py b/modelopt/recipe/loader.py index 9c3c40856d2..e77419e4a24 100644 --- a/modelopt/recipe/loader.py +++ b/modelopt/recipe/loader.py @@ -22,7 +22,7 @@ from pathlib import Path from modelopt.torch.opt.config_loader import BUILTIN_CONFIG_ROOT as BUILTIN_RECIPES_LIB -from modelopt.torch.opt.config_loader import load_config +from modelopt.torch.opt.config_loader import _load_raw_config_with_schema, load_config from modelopt.torch.quantization.config import QuantizeConfig from .config import ModelOptPTQRecipe, ModelOptRecipeBase, RecipeMetadataConfig, RecipeType @@ -89,13 +89,13 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas The file must contain a ``metadata`` section with at least ``recipe_type``, plus a ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes. 
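    A minimal PTQ recipe file might look like the following sketch (the quantizer
    name and ``num_bits`` value are illustrative, not a shipped recipe)::

        metadata:
          recipe_type: ptq
          description: Example FP8 PTQ recipe.
        quantize:
          algorithm: max
          quant_cfg:
            - quantizer_name: '*weight_quantizer'
              cfg:
                num_bits: [4, 3]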
""" - data = load_config(recipe_file, schema_type=ModelOptPTQRecipe) - if not isinstance(data, dict): + raw_data = _load_raw_config_with_schema(recipe_file).data + if not isinstance(raw_data, dict): raise ValueError( - f"Recipe file {recipe_file} must be a YAML mapping, got {type(data).__name__}." + f"Recipe file {recipe_file} must be a YAML mapping, got {type(raw_data).__name__}." ) - metadata = data.get("metadata", {}) + metadata = raw_data.get("metadata", {}) if not isinstance(metadata, dict): raise ValueError( f"Recipe file {recipe_file} field 'metadata' must be a mapping, " @@ -106,12 +106,9 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.") if recipe_type == RecipeType.PTQ: - if "quantize" not in data: + if "quantize" not in raw_data: raise ValueError(f"PTQ recipe file {recipe_file} must contain 'quantize'.") - return ModelOptPTQRecipe( - metadata=metadata, - quantize=data["quantize"], - ) + return load_config(recipe_file, schema_type=ModelOptPTQRecipe) raise ValueError(f"Unsupported recipe type: {recipe_type!r}") @@ -139,21 +136,11 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase: metadata_file = _find_recipe_section_file(recipe_dir, "metadata") metadata = load_config(metadata_file, schema_type=RecipeMetadataConfig) - if not isinstance(metadata, dict): - raise ValueError( - f"Metadata file {metadata_file} must be a YAML mapping, got {type(metadata).__name__}." - ) - recipe_type = metadata.get("recipe_type") - if recipe_type is None: - raise ValueError(f"Metadata file {metadata_file} must contain a 'recipe_type' field.") + recipe_type = metadata.recipe_type if recipe_type == RecipeType.PTQ: quantize_file = _find_recipe_section_file(recipe_dir, "quantize") quantize_data = load_config(quantize_file, schema_type=QuantizeConfig) - if not isinstance(quantize_data, dict): - raise ValueError( - f"{quantize_file} must be a YAML mapping, got {type(quantize_data).__name__}." - ) return ModelOptPTQRecipe( metadata=metadata, quantize=quantize_data, diff --git a/modelopt/torch/opt/config_loader.py b/modelopt/torch/opt/config_loader.py index 43231c90995..f0c24c7c510 100644 --- a/modelopt/torch/opt/config_loader.py +++ b/modelopt/torch/opt/config_loader.py @@ -33,12 +33,26 @@ import re import sys from pathlib import Path -from typing import Any, Union, get_args, get_origin, get_type_hints +from typing import ( + TYPE_CHECKING, + Any, + TypeVar, + Union, + get_args, + get_origin, + get_type_hints, + overload, +) import yaml from pydantic import TypeAdapter from typing_extensions import NotRequired, Required, is_typeddict +if TYPE_CHECKING: + from modelopt.torch.opt.config import ModeloptBaseConfig + +_ModeloptConfigT = TypeVar("_ModeloptConfigT", bound="ModeloptBaseConfig") + @dataclass class _ListSnippet: @@ -47,7 +61,7 @@ class _ListSnippet: YAML requires one root node per document, so a file that is "a list with an ``imports`` section" has to use two documents separated by ``---``. This wrapper is the internal transport carrying both pieces from - :func:`_load_raw_config` to :func:`_resolve_imports` without smuggling them + :func:`_load_raw_config_with_schema` to :func:`_resolve_imports` without smuggling them through a sentinel dict key (which would collide if a user happened to choose the same key name). 
""" @@ -113,7 +127,7 @@ def _resolve_config_path(config_file: str | Path | Traversable) -> Path | Traver built-in package resources return a ``Traversable``. Raises ``ValueError`` if no candidate exists. - Factored out of :func:`_load_raw_config` so :func:`_resolve_imports` can + Factored out of :func:`_load_raw_config_with_schema` so :func:`_resolve_imports` can compute a canonical cycle-detection key without reading the file twice. """ # Probe order: filesystem first, then built-in library. @@ -243,13 +257,6 @@ def _load_raw_config_with_schema(config_file: str | Path | Traversable) -> _RawC ) -def _load_raw_config( - config_file: str | Path | Traversable, -) -> dict[str, Any] | list[Any] | _ListSnippet: - """Load a config YAML without resolving ``$import`` references.""" - return _load_raw_config_with_schema(config_file).data - - _IMPORT_KEY = "$import" @@ -373,10 +380,10 @@ def _validate_modelopt_schema( data: Any, config_path: Any, schema_type: Any | None = None, -) -> None: - """Validate resolved config content against the requested schema without mutating it.""" +) -> Any: + """Validate resolved config content and return the Pydantic-normalized value.""" if schema_type is None and not schema_path: - return + return data if schema_type is None: assert schema_path is not None schema_type = _schema_type(schema_path) @@ -384,7 +391,7 @@ def _validate_modelopt_schema( # TypeAdapter validates the schema types we allow here: BaseModel classes # plus regular typing constructs such as TypedDict, list[TypedDict], unions, # and aliases. Schema comments are not treated as arbitrary validators. - TypeAdapter(schema_type).validate_python(data) + return TypeAdapter(schema_type).validate_python(data) except Exception as exc: raise ValueError( f"Config file {config_path} does not match modelopt-schema " @@ -592,22 +599,52 @@ def _find_import_marker(obj: Any, context: str = "root") -> tuple[Any, str] | No return None +# Concrete ModeloptBaseConfig subclasses are returned as that exact parsed type. +@overload def load_config( config_path: str | Path | Traversable, *, - schema_type: Any | None = None, -) -> dict[str, Any] | list[Any]: + schema_type: type[_ModeloptConfigT], +) -> _ModeloptConfigT: ... + + +# Typed list schemas, such as list[SomeModeloptConfig], validate each element +# and return a list of parsed config objects. +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: type[list[_ModeloptConfigT]], +) -> list[_ModeloptConfigT]: ... + + +# Without an explicit schema_type, untyped files return raw dict/list payloads; +# files with modelopt-schema comments still return the validated schema value. +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: None = None, +) -> "ModeloptBaseConfig | dict[str, Any] | list[Any]": ... + + +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: object | None = None, +) -> "ModeloptBaseConfig | dict[str, Any] | list[Any]": """Load a YAML config and resolve all ``$import`` references. This is the primary config loading entry point. It loads the YAML file, resolves any ``imports`` / ``$import`` directives, and returns the final - config dict or list. + config. ``schema_type`` supplies a typing context for import resolution when the - file itself has no ``modelopt-schema`` comment. It is intentionally not a - request to validate the top-level file. 
Top-level files are validated only - when they declare ``modelopt-schema``; imported snippets are stricter and - must always declare ``modelopt-schema``. + file itself has no ``modelopt-schema`` comment. If either ``schema_type`` + or a ``modelopt-schema`` comment is present, the resolved top-level payload + is returned as the Pydantic-normalized value for that schema. Untyped files + return the resolved dict or list. Imported snippets are stricter and must + always declare ``modelopt-schema``. """ raw = _load_raw_config_with_schema(config_path) data = raw.data @@ -616,5 +653,8 @@ def load_config( if isinstance(data, (_ListSnippet, dict)): data = _resolve_imports(data, schema_type=resolver_schema_type) - _validate_modelopt_schema(raw.schema, data, raw.path, schema_type=declared_schema_type) + if declared_schema_type is not None or schema_type is not None: + return _validate_modelopt_schema( + raw.schema, data, raw.path, schema_type=declared_schema_type or schema_type + ) return data diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 3adb70cf6b7..13df5b19eee 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -71,8 +71,10 @@ - ``parent_class`` *(optional)*: restricts matching to quantizers whose immediate parent module is of this PyTorch class (e.g. ``"nn.Linear"``). If omitted, all matching quantizers are targeted regardless of their parent class. -- ``cfg`` *(optional)*: a dict of quantizer attributes as defined by - :class:`QuantizerAttributeConfig`, or a list of such dicts. When a list is given, the matched +- ``cfg`` *(optional)*: a :class:`QuantizerAttributeConfig`, or a list of + :class:`QuantizerAttributeConfig` objects. Equivalent dict and list-of-dict inputs are accepted + for backward compatibility, but are deprecated weakly schematized forms. When a list is given, + the matched :class:`TensorQuantizer ` is replaced with a :class:`SequentialQuantizer ` @@ -152,6 +154,7 @@ import copy import warnings +from collections.abc import Mapping, Sequence from typing import Any, Literal, cast from pydantic import ValidationInfo, field_validator, model_validator @@ -161,49 +164,6 @@ from modelopt.torch.opt.config_loader import load_config from modelopt.torch.utils.network import ConstructorLike - -class QuantizerCfgEntry(TypedDict, total=False): - """A single entry in a ``quant_cfg`` list.""" - - quantizer_name: Required[str] # matched against quantizer module names - parent_class: str | None # optional; filters by pytorch module class name (e.g. "nn.Linear") - cfg: dict[str, Any] | list[dict[str, Any]] | None # quantizer attribute config(s) - enable: bool | None # toggles matched quantizers on/off; independent of cfg - - -def find_quant_cfg_entry_by_path( - quant_cfg_list: list[QuantizerCfgEntry], quantizer_name: str -) -> QuantizerCfgEntry: - """Find the last entry in a ``quant_cfg`` list whose ``quantizer_name`` key equals the query. - - This performs an **exact string comparison** against the ``quantizer_name`` field of each - entry — it does *not* apply ``fnmatch`` pattern matching. For example, passing - ``"*input_quantizer"`` will only match entries whose ``quantizer_name`` is literally - ``"*input_quantizer"``, not entries with a different wildcard that would match the same - module names at apply time. - - Returns the *last* match because entries are applied in list order and later entries - override earlier ones, so the last match represents the effective configuration. 
- - Args: - quant_cfg_list: A list of :class:`QuantizerCfgEntry` dicts. - quantizer_name: The exact ``quantizer_name`` string to search for. - - Returns: - The last entry whose ``quantizer_name`` equals *quantizer_name*. - - Raises: - KeyError: If no entry with the given ``quantizer_name`` is found. - """ - result = None - for entry in quant_cfg_list: - if isinstance(entry, dict) and entry.get("quantizer_name") == quantizer_name: - result = entry - if result is None: - raise KeyError(f"No quant_cfg entry with quantizer_name={quantizer_name!r}") - return result - - BiasType = Literal["static", "dynamic"] BiasMethod = Literal["mean", "max_min"] @@ -563,6 +523,48 @@ def validate_calibrator(cls, v, info: ValidationInfo): ) +class QuantizerCfgEntry(TypedDict, total=False): + """A single entry in a ``quant_cfg`` list.""" + + quantizer_name: Required[str] # matched against quantizer module names + parent_class: str | None # optional; filters by pytorch module class name (e.g. "nn.Linear") + cfg: QuantizerAttributeConfig | list[QuantizerAttributeConfig] | None + enable: bool | None # toggles matched quantizers on/off; independent of cfg + + +def find_quant_cfg_entry_by_path( + quant_cfg_list: Sequence[Any], quantizer_name: str +) -> dict[str, Any]: + """Find the last entry in a ``quant_cfg`` list whose ``quantizer_name`` key equals the query. + + This performs an **exact string comparison** against the ``quantizer_name`` field of each + entry — it does *not* apply ``fnmatch`` pattern matching. For example, passing + ``"*input_quantizer"`` will only match entries whose ``quantizer_name`` is literally + ``"*input_quantizer"``, not entries with a different wildcard that would match the same + module names at apply time. + + Returns the *last* match because entries are applied in list order and later entries + override earlier ones, so the last match represents the effective configuration. + + Args: + quant_cfg_list: A list of :class:`QuantizerCfgEntry` dicts. + quantizer_name: The exact ``quantizer_name`` string to search for. + + Returns: + The last entry whose ``quantizer_name`` equals *quantizer_name*. + + Raises: + KeyError: If no entry with the given ``quantizer_name`` is found. + """ + result: dict[str, Any] | None = None + for entry in quant_cfg_list: + if isinstance(entry, dict) and entry.get("quantizer_name") == quantizer_name: + result = entry + if result is None: + raise KeyError(f"No quant_cfg entry with quantizer_name={quantizer_name!r}") + return result + + class QuantizeAlgorithmConfig(ModeloptBaseConfig): """Calibration algorithm config base.""" @@ -928,6 +930,7 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): QuantizeQuantCfgType = list[QuantizerCfgEntry] QuantizerCfgListConfig = QuantizeQuantCfgType +QuantizeQuantCfgInputType = Sequence[Mapping[str, Any]] _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None @@ -1008,9 +1011,13 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: # Support multi-key nn.*-scoped dicts by emitting one entry per sub-key. 
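        # Illustrative expansion (hypothetical values): {"nn.Linear":
        # {"*input_quantizer": {"num_bits": 8}}} becomes roughly
        # {"parent_class": "nn.Linear", "quantizer_name": "*input_quantizer",
        # "cfg": {"num_bits": 8}}, one entry per quantizer sub-key.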
entries: list[QuantizerCfgEntry] = [] for q_path, sub_cfg in value.items(): - sub_cfg = dict(sub_cfg) - enable = sub_cfg.pop("enable", None) - cfg = sub_cfg or None + if isinstance(sub_cfg, QuantizerAttributeConfig): + enable = None + cfg = sub_cfg + else: + sub_cfg = dict(sub_cfg) + enable = sub_cfg.pop("enable", None) + cfg = sub_cfg or None entry: QuantizerCfgEntry = { "parent_class": key, "quantizer_name": q_path, @@ -1069,20 +1076,26 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: cfg = entry.get("cfg") enable = entry.get("enable", True) if enable and cfg is not None: - if isinstance(cfg, dict): + if isinstance(cfg, QuantizerAttributeConfig): + is_invalid = False + elif isinstance(cfg, dict): is_invalid = len(cfg) == 0 elif isinstance(cfg, list): is_invalid = len(cfg) == 0 or any( - not isinstance(item, dict) or len(item) == 0 for item in cfg + not isinstance(item, (dict, QuantizerAttributeConfig)) + or (isinstance(item, dict) and len(item) == 0) + for item in cfg ) else: is_invalid = True if is_invalid: raise ValueError( - f"Invalid quant_cfg entry: {raw!r} — 'cfg' must be a non-empty dict " - f"or a non-empty list of non-empty dicts when enabling a quantizer " - f"(got {type(cfg).__name__}: {cfg!r}). Either provide quantizer " - "attributes in 'cfg' or remove 'cfg' and set 'enable' explicitly." + f"Invalid quant_cfg entry: {raw!r} — 'cfg' must be a " + "QuantizerAttributeConfig, a non-empty dict, or a non-empty list of " + "QuantizerAttributeConfig/non-empty dict entries when enabling a " + f"quantizer (got {type(cfg).__name__}: {cfg!r}). Either provide " + "quantizer attributes in 'cfg' or remove 'cfg' and set 'enable' " + "explicitly." ) # Normalize: make enable and cfg always explicit. @@ -1156,6 +1169,13 @@ class _QuantizeExportConfig(ModeloptBaseConfig): """An empty config.""" +def _load_quantize_config_dict(config_path: str) -> dict[str, Any]: + """Load a schema-backed QuantizeConfig YAML while preserving public dict constants.""" + config = load_config(config_path) + assert isinstance(config, QuantizeConfig), f"{config_path} must declare QuantizeConfig schema." + return config.model_dump(exclude_unset=True) + + _base_disable_all: list[QuantizerCfgEntry] = [ cast("QuantizerCfgEntry", load_config("configs/ptq/units/base_disable_all")) ] @@ -1211,7 +1231,7 @@ class _QuantizeExportConfig(ModeloptBaseConfig): "algorithm": "max", } -FP8_DEFAULT_CFG: dict[str, Any] = load_config("configs/ptq/presets/model/fp8") +FP8_DEFAULT_CFG: dict[str, Any] = _load_quantize_config_dict("configs/ptq/presets/model/fp8") MAMBA_MOE_FP8_AGGRESSIVE_CFG = { "quant_cfg": [ @@ -1456,7 +1476,7 @@ class _QuantizeExportConfig(ModeloptBaseConfig): # KV-cache configs are designed to be merged with a primary quantization config (e.g. # FP8_DEFAULT_CFG) that already contains _base_disable_all. They intentionally omit both # _base_disable_all and "algorithm" because these are provided by the primary config. -FP8_KV_CFG: dict[str, Any] = load_config("configs/ptq/presets/kv/fp8") +FP8_KV_CFG: dict[str, Any] = _load_quantize_config_dict("configs/ptq/presets/kv/fp8") FP8_AFFINE_KV_CFG = { "quant_cfg": [ @@ -1494,11 +1514,20 @@ def _nvfp4_selective_quant_cfg( for pattern in layer_patterns: # Deep-copy the quantizer dict so each config constant gets its own instance. 
quant_cfg.append( - {"quantizer_name": f"{pattern}weight_quantizer", "cfg": copy.deepcopy(quantizer)} + cast( + "QuantizerCfgEntry", + {"quantizer_name": f"{pattern}weight_quantizer", "cfg": copy.deepcopy(quantizer)}, + ) ) if not weight_only: quant_cfg.append( - {"quantizer_name": f"{pattern}input_quantizer", "cfg": copy.deepcopy(quantizer)} + cast( + "QuantizerCfgEntry", + { + "quantizer_name": f"{pattern}input_quantizer", + "cfg": copy.deepcopy(quantizer), + }, + ) ) quant_cfg.extend(_default_disabled_quantizer_cfg) return {"quant_cfg": quant_cfg, "algorithm": algorithm} @@ -1764,6 +1793,11 @@ def need_calibration(config): def _not_dynamic(cfg): return cfg.get("enable", True) and cfg.get("type", "") != "dynamic" + def _cfg_to_dict(cfg): + if isinstance(cfg, QuantizerAttributeConfig): + return cfg.model_dump(exclude_unset=True) + return dict(cfg or {}) + quant_cfg: list = config.get("quant_cfg") or [] quant_cfg = normalize_quant_cfg_list(quant_cfg) for entry in quant_cfg: @@ -1775,10 +1809,13 @@ def _not_dynamic(cfg): # Sequential quantizers (e.g. W4A8) have a list of cfg dicts if isinstance(raw_cfg, list): for _config in raw_cfg: - if _not_dynamic(_config): + cfg = _cfg_to_dict(_config) + if "enable" in entry: + cfg["enable"] = entry["enable"] + if _not_dynamic(cfg): return True continue - cfg = dict(raw_cfg or {}) + cfg = _cfg_to_dict(raw_cfg) if "enable" in entry: cfg["enable"] = entry["enable"] if _not_dynamic(cfg): diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 3f97f8380be..535e6d6f508 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -31,7 +31,7 @@ from .config import ( QuantizeConfig, - QuantizeQuantCfgType, + QuantizeQuantCfgInputType, QuantizerAttributeConfig, _QuantizeExportConfig, normalize_quant_cfg_list, @@ -215,7 +215,7 @@ def _replace_quant_module(model: nn.Module, version=None, registry=QuantModuleRe _replace_quant_module(getattr(model, name), version=version, registry=registry) -def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): +def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInputType): """Apply a quantization config list to the quantizers in ``quant_model``. ``quant_cfg`` is an **ordered list** of :class:`QuantizerCfgEntry <.config.QuantizerCfgEntry>` @@ -223,8 +223,9 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType - ``quantizer_name`` *(required)*: wildcard matched against quantizer module names via :func:`fnmatch`. - - ``cfg`` *(optional)*: a dict of :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` - fields, or a list of such dicts for sequential quantization. + - ``cfg`` *(optional)*: a :class:`QuantizerAttributeConfig <.config.QuantizerAttributeConfig>` + or a list of them for sequential quantization. Equivalent dict and list-of-dict inputs are + accepted for backward compatibility. - ``enable`` *(optional)*: ``True`` or ``False`` to toggle matched quantizers on or off. When omitted but ``cfg`` is present, defaults to ``True``. Every entry must specify at least one of ``cfg`` or ``enable`` — an entry with only ``quantizer_name`` is invalid. @@ -248,7 +249,7 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType See :ref:`quant-cfg` for the full format reference and common patterns. 
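    A short usage sketch (the ``model`` variable and wildcard patterns are
    illustrative)::

        set_quantizer_by_cfg(
            model,
            [
                {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}},
                {"quantizer_name": "*lm_head*", "enable": False},
            ],
        )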
""" - quant_cfg = normalize_quant_cfg_list(quant_cfg) + quant_cfg = normalize_quant_cfg_list(list(quant_cfg)) for entry in quant_cfg: quantizer_name: str = entry["quantizer_name"] @@ -277,13 +278,23 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType attributes = cfg.model_copy(update={"enable": enable}) elif isinstance(cfg, dict): attributes = QuantizerAttributeConfig(**cfg, enable=enable) + elif isinstance(cfg, list): + attributes = [] + for c in cfg: + if isinstance(c, QuantizerAttributeConfig): + attributes.append(c.model_copy(update={"enable": enable})) + elif isinstance(c, dict): + attributes.append(QuantizerAttributeConfig(**c, enable=enable)) + else: + raise ValueError( + f"Invalid cfg element for quantizer {quantizer_name!r}: expected " + "QuantizerAttributeConfig or dict." + ) else: - attributes = [ - c.model_copy(update={"enable": enable}) - if isinstance(c, QuantizerAttributeConfig) - else QuantizerAttributeConfig(**c, enable=enable) - for c in cfg - ] + raise ValueError( + f"Invalid cfg for quantizer {quantizer_name!r}: expected " + "QuantizerAttributeConfig, dict, or list." + ) set_quantizer_attributes_full(quant_model, quantizer_name, attributes, parent_class) @@ -477,7 +488,7 @@ def set_quantizer_attributes_partial( @contextmanager -def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): +def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInputType): """Context manager that temporarily applies a quantization config and restores the original state on exit. Calls :func:`set_quantizer_by_cfg` on entry and reverts every @@ -497,7 +508,7 @@ def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuan Yields: None — the context body runs with the new quantizer attributes active. """ - quant_cfg = normalize_quant_cfg_list(quant_cfg) + quant_cfg = normalize_quant_cfg_list(list(quant_cfg)) for entry in quant_cfg: if isinstance(entry.get("cfg"), list): diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 5e65f9cc1d4..589ec477786 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -167,9 +167,11 @@ def quantize( :meth:`calibrate `. Each entry in the ``"quant_cfg"`` list has a ``"quantizer_name"`` wildcard matched - against quantizer module names, an optional ``"cfg"`` dict of quantizer attributes, - and an optional ``"enable"`` toggle. Entries are applied in list order; later entries - override earlier ones. The quantizer modules have names ending with + against quantizer module names, an optional ``"cfg"`` + :class:`QuantizerAttributeConfig ` + (or equivalent backward-compatible dict input), and an optional ``"enable"`` toggle. + Entries are applied in list order; later entries override earlier ones. The quantizer + modules have names ending with ``weight_quantizer`` and ``input_quantizer`` and they perform weight quantization and input quantization (or activation quantization) respectively. 
The quantizer modules are instances of diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index a5c8ccaf479..19463493b8f 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -19,8 +19,9 @@ import pytest -from modelopt.recipe.config import ModelOptPTQRecipe, RecipeType +from modelopt.recipe.config import ModelOptPTQRecipe, RecipeMetadataConfig, RecipeType from modelopt.recipe.loader import load_config, load_recipe +from modelopt.torch.quantization.config import QuantizeConfig, QuantizerAttributeConfig # --------------------------------------------------------------------------- # Static YAML fixtures @@ -75,6 +76,14 @@ def _write_quantizer_cfg_list(path, body: str): path.write_text(QUANTIZER_CFG_LIST_SCHEMA + body) +def _cfg_to_dict(cfg): + if isinstance(cfg, QuantizerAttributeConfig): + return cfg.model_dump(exclude_unset=True) + if isinstance(cfg, list): + return [_cfg_to_dict(item) for item in cfg] + return cfg + + # --------------------------------------------------------------------------- # Directory-format YAML fixtures # --------------------------------------------------------------------------- @@ -96,6 +105,32 @@ def test_load_config_suffix_probe(tmp_path): assert load_config(str(tmp_path / "mycfg")) == {"key": "val"} +def test_load_config_schema_type_returns_validated_type(tmp_path): + """schema_type validates and returns the parsed schema value.""" + cfg_file = tmp_path / "quantize.yml" + cfg_file.write_text( + "algorithm: max\n" + "quant_cfg:\n" + " - quantizer_name: '*weight_quantizer'\n" + " cfg:\n" + " num_bits: 8\n" + " axis: 0\n" + ) + data = load_config(cfg_file, schema_type=QuantizeConfig) + assert isinstance(data, QuantizeConfig) + assert _cfg_to_dict(data.quant_cfg[0]["cfg"]) == {"num_bits": 8, "axis": 0} + + +def test_load_config_recipe_metadata_returns_validated_type(tmp_path): + """Recipe metadata schema validates and returns the parsed metadata model.""" + cfg_file = tmp_path / "metadata.yml" + cfg_file.write_text("recipe_type: ptq\n") + data = load_config(cfg_file, schema_type=RecipeMetadataConfig) + assert isinstance(data, RecipeMetadataConfig) + assert data.recipe_type == RecipeType.PTQ + assert data.description == "Model optimization recipe." + + def test_load_config_missing_file_raises(tmp_path): """load_config raises ValueError for a path that does not exist.""" with pytest.raises(ValueError, match="Cannot find config file"): @@ -248,6 +283,7 @@ def _normalize_fpx(val): YAML always uses the string form. Both are converted to ``[E, M]`` so the comparison is representation-agnostic. 
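    For example, both ``"e4m3"`` and ``(4, 3)`` normalize to ``[4, 3]``.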
""" + val = _cfg_to_dict(val) if isinstance(val, str): m = re.fullmatch(r"e(\d+)m(\d+)", val) if m: @@ -256,6 +292,8 @@ def _normalize_fpx(val): return list(val) if isinstance(val, dict): return {str(k): _normalize_fpx(v) for k, v in val.items()} + if isinstance(val, list): + return [_normalize_fpx(v) for v in val] return val def _normalize_entries(raw_entries): @@ -302,7 +340,7 @@ def test_import_resolves_cfg_reference(tmp_path): ) recipe = load_recipe(recipe_file) entry = recipe.quantize["quant_cfg"][0] - assert entry["cfg"] == {"num_bits": (4, 3), "axis": None} + assert _cfg_to_dict(entry["cfg"]) == {"num_bits": (4, 3), "axis": None} def test_import_same_name_used_twice(tmp_path): @@ -325,7 +363,9 @@ def test_import_same_name_used_twice(tmp_path): f" $import: fp8\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"][0]["cfg"] == recipe.quantize["quant_cfg"][1]["cfg"] + assert _cfg_to_dict(recipe.quantize["quant_cfg"][0]["cfg"]) == _cfg_to_dict( + recipe.quantize["quant_cfg"][1]["cfg"] + ) def test_import_multiple_snippets(tmp_path): @@ -375,7 +415,7 @@ def test_import_inline_cfg_not_affected(tmp_path): f" axis: 0\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": 8, "axis": 0} + assert _cfg_to_dict(recipe.quantize["quant_cfg"][1]["cfg"]) == {"num_bits": 8, "axis": 0} def test_import_unknown_reference_raises(tmp_path): @@ -596,7 +636,7 @@ def test_import_cfg_extend(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "axis": 0} + assert _cfg_to_dict(cfg) == {"num_bits": (4, 3), "axis": 0} def test_import_cfg_inline_overrides_import(tmp_path): @@ -659,7 +699,7 @@ def test_import_in_multiple_dict_values(tmp_path): ) data = load_config(config_file) entry = data["quant_cfg"][0] - assert entry["cfg"] == {"num_bits": (4, 3)} + assert _cfg_to_dict(entry["cfg"]) == {"num_bits": (4, 3)} assert entry["my_field"] == {"fake_quant": False} @@ -683,7 +723,7 @@ def test_import_cfg_multi_import(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "axis": 0} + assert _cfg_to_dict(cfg) == {"num_bits": (4, 3), "axis": 0} def test_import_cfg_multi_import_later_overrides_earlier(tmp_path): @@ -732,7 +772,7 @@ def test_import_cfg_multi_import_with_extend(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "fake_quant": False, "axis": 0} + assert _cfg_to_dict(cfg) == {"num_bits": (4, 3), "fake_quant": False, "axis": 0} def test_import_dir_format(tmp_path): @@ -749,7 +789,10 @@ def test_import_dir_format(tmp_path): " $import: fp8\n" ) recipe = load_recipe(tmp_path) - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3), "axis": None} + assert _cfg_to_dict(recipe.quantize["quant_cfg"][0]["cfg"]) == { + "num_bits": (4, 3), + "axis": None, + } def test_import_dir_format_metadata_imports_do_not_apply_to_quantize(tmp_path): @@ -803,7 +846,7 @@ def test_import_multi_document_list_snippet(tmp_path): recipe = load_recipe(recipe_file) assert len(recipe.quantize["quant_cfg"]) == 1 assert recipe.quantize["quant_cfg"][0]["quantizer_name"] == "*[kv]_bmm_quantizer" - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert _cfg_to_dict(recipe.quantize["quant_cfg"][0]["cfg"]) == {"num_bits": (4, 3)} def test_import_builtin_kv_fp8_snippet(): @@ -914,9 +957,9 @@ def 
test_import_mixed_tree(tmp_path): ) data = load_config(config_file) # Dict import inside list entry - assert data["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert _cfg_to_dict(data["quant_cfg"][0]["cfg"]) == {"num_bits": (4, 3)} # List splice - assert data["quant_cfg"][1] == {"quantizer_name": "*lm_head*", "enable": False} + assert data["quant_cfg"][1] == {"quantizer_name": "*lm_head*", "enable": False, "cfg": None} # --------------------------------------------------------------------------- @@ -955,7 +998,7 @@ def test_import_recursive(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3)} + assert _cfg_to_dict(cfg) == {"num_bits": (4, 3)} def test_import_circular_raises(tmp_path): @@ -1055,9 +1098,12 @@ def test_import_cross_file_same_name_no_conflict(tmp_path): ) recipe = load_recipe(recipe_file) # Parent's "fmt" resolves to fp8 (e4m3), not child's nvfp4. - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert _cfg_to_dict(recipe.quantize["quant_cfg"][0]["cfg"]) == {"num_bits": (4, 3)} # Child's "fmt" resolves to nvfp4 (e2m1), not parent's fp8. - assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": (2, 1), "axis": 0} + assert _cfg_to_dict(recipe.quantize["quant_cfg"][1]["cfg"]) == { + "num_bits": (2, 1), + "axis": 0, + } # --------------------------------------------------------------------------- @@ -1088,8 +1134,8 @@ def test_builtin_config_snippets_with_modelopt_schema(config_path): assert data -def test_modelopt_schema_comment_validates_without_changing_payload(tmp_path): - """modelopt-schema validates the resolved payload but load_config still returns a plain dict.""" +def test_modelopt_schema_comment_returns_validated_type(tmp_path): + """modelopt-schema validates and returns the parsed schema value.""" config_file = tmp_path / "fp8.yaml" config_file.write_text( "# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig\n" @@ -1097,7 +1143,8 @@ def test_modelopt_schema_comment_validates_without_changing_payload(tmp_path): "axis:\n" ) data = load_config(config_file) - assert data == {"num_bits": (4, 3), "axis": None} + assert isinstance(data, QuantizerAttributeConfig) + assert data.model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": None} def test_modelopt_schema_comment_validation_error(tmp_path): @@ -1144,11 +1191,12 @@ def test_modelopt_schema_comment_validates_after_import_resolution(tmp_path): f" $import: fp8\n" ) data = load_config(config_file) - assert data == [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": (4, 3)}}] + assert data[0]["quantizer_name"] == "*weight_quantizer" + assert _cfg_to_dict(data[0]["cfg"]) == {"num_bits": (4, 3)} # --------------------------------------------------------------------------- -# Coverage: _load_raw_config edge cases +# Coverage: _load_raw_config_with_schema edge cases # --------------------------------------------------------------------------- @@ -1188,9 +1236,9 @@ def test_load_config_multi_doc_dict_dict(tmp_path): """Multi-document YAML with two dicts merges them.""" cfg_file = tmp_path / "multi.yaml" cfg_file.write_text("imports:\n fp8: some/path\n---\nalgorithm: max\n") - from modelopt.torch.opt.config_loader import _load_raw_config + from modelopt.torch.opt.config_loader import _load_raw_config_with_schema - data = _load_raw_config(cfg_file) + data = _load_raw_config_with_schema(cfg_file).data assert data["imports"] == {"fp8": "some/path"} assert data["algorithm"] == "max" @@ 
-1199,9 +1247,9 @@ def test_load_config_multi_doc_null_content(tmp_path): """Multi-document YAML where second doc is null treats content as empty dict.""" cfg_file = tmp_path / "multi_null.yaml" cfg_file.write_text("key: value\n---\n") - from modelopt.torch.opt.config_loader import _load_raw_config + from modelopt.torch.opt.config_loader import _load_raw_config_with_schema - data = _load_raw_config(cfg_file) + data = _load_raw_config_with_schema(cfg_file).data assert data == {"key": "value"} @@ -1249,7 +1297,8 @@ def test_load_config_list_valued_yaml(tmp_path): data = load_config(cfg_file) assert isinstance(data, list) assert len(data) == 2 - assert data[0] == {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}} + assert data[0]["quantizer_name"] == "*weight_quantizer" + assert _cfg_to_dict(data[0]["cfg"]) == {"num_bits": 8} # --------------------------------------------------------------------------- diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index f5b1e576f5e..1d7b8607ddd 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -26,6 +26,7 @@ NVFP4_DEFAULT_CFG, W4A8_AWQ_BETA_CFG, QuantizeConfig, + QuantizerAttributeConfig, find_quant_cfg_entry_by_path, need_calibration, normalize_quant_cfg_list, @@ -525,3 +526,44 @@ def test_validate_quant_cfg_entries_accepts_valid_cfg(self): algorithm="max", ) assert len(cfg.quant_cfg) == 2 + + def test_quant_cfg_parses_dict_cfg_to_pydantic_type(self): + """Python dict cfg input is accepted and parsed to QuantizerAttributeConfig.""" + cfg = QuantizeConfig( + quant_cfg=[ + {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + ], + algorithm="max", + ) + attr_cfg = cfg.quant_cfg[0]["cfg"] + assert isinstance(attr_cfg, QuantizerAttributeConfig) + assert attr_cfg.model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} + + def test_quant_cfg_parses_list_of_dict_cfg_to_pydantic_type(self): + """Python list-of-dict cfg input is accepted and parsed to QuantizerAttributeConfig.""" + cfg = QuantizeConfig( + quant_cfg=[ + { + "quantizer_name": "*weight_quantizer", + "cfg": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": 0}, + ], + }, + ], + algorithm="max", + ) + attr_cfgs = cfg.quant_cfg[0]["cfg"] + assert isinstance(attr_cfgs, list) + assert all(isinstance(attr_cfg, QuantizerAttributeConfig) for attr_cfg in attr_cfgs) + assert attr_cfgs[0].num_bits == 4 + assert attr_cfgs[1].axis == 0 + + def test_quant_cfg_accepts_pydantic_cfg_instances(self): + """Already-parsed QuantizerAttributeConfig input remains valid.""" + attr_cfg = QuantizerAttributeConfig(num_bits=8, axis=0) + cfg = QuantizeConfig( + quant_cfg=[{"quantizer_name": "*weight_quantizer", "cfg": attr_cfg}], + algorithm="max", + ) + assert cfg.quant_cfg[0]["cfg"] == attr_cfg From fdad74cdc8d87b9607218c857096c55ab28730a5 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Wed, 6 May 2026 16:06:47 -0700 Subject: [PATCH 02/14] simplify quantize config loading Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/config.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 13df5b19eee..7031da5917b 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1171,9 +1171,7 @@ class _QuantizeExportConfig(ModeloptBaseConfig): def 
_load_quantize_config_dict(config_path: str) -> dict[str, Any]: """Load a schema-backed QuantizeConfig YAML while preserving public dict constants.""" - config = load_config(config_path) - assert isinstance(config, QuantizeConfig), f"{config_path} must declare QuantizeConfig schema." - return config.model_dump(exclude_unset=True) + return load_config(config_path, schema_type=QuantizeConfig).model_dump(exclude_unset=True) _base_disable_all: list[QuantizerCfgEntry] = [ From b38703a9653f65febaba1f7fff0068781c3760e9 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Wed, 6 May 2026 17:34:29 -0700 Subject: [PATCH 03/14] Schematize quantizer config entries Convert QuantizerCfgEntry into a ModeloptBaseConfig-backed Pydantic model with validation while preserving dict-style access for callers. Normalize schema-loaded quant_cfg snippets through model_dump, simplify quantizer cfg handling, and cover both dict and QuantizeConfig need_calibration inputs. Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/algorithms.py | 4 +- modelopt/torch/quantization/config.py | 220 ++++++++++++------ modelopt/torch/quantization/conversion.py | 18 +- tests/unit/recipe/test_loader.py | 30 ++- .../quantization/test_config_validation.py | 83 ++++++- 5 files changed, 251 insertions(+), 104 deletions(-) diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index f1db2df9e84..d9f7cd1b8fe 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -129,7 +129,9 @@ def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | No # Disable KV Cache quantization # Currently KV Cache quantization is enabled for some quantization formats and disabled for others # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy - self.config.quant_cfg.append({"quantizer_name": "*output_quantizer", "enable": False}) + self.config.quant_cfg.append( + mtq_config.QuantizerCfgEntry(quantizer_name="*output_quantizer", enable=False) + ) self.compression = estimate_quant_compression(self.config) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 7031da5917b..f5f17582e0b 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -155,10 +155,9 @@ import copy import warnings from collections.abc import Mapping, Sequence -from typing import Any, Literal, cast +from typing import Any, Literal from pydantic import ValidationInfo, field_validator, model_validator -from typing_extensions import Required, TypedDict from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.config_loader import load_config @@ -523,22 +522,80 @@ def validate_calibrator(cls, v, info: ValidationInfo): ) -class QuantizerCfgEntry(TypedDict, total=False): +class QuantizerCfgEntry(ModeloptBaseConfig): """A single entry in a ``quant_cfg`` list.""" - quantizer_name: Required[str] # matched against quantizer module names - parent_class: str | None # optional; filters by pytorch module class name (e.g. "nn.Linear") - cfg: QuantizerAttributeConfig | list[QuantizerAttributeConfig] | None - enable: bool | None # toggles matched quantizers on/off; independent of cfg + quantizer_name: str # matched against quantizer module names + parent_class: str | None = None # filters by PyTorch module class name, e.g. 
"nn.Linear" + cfg: QuantizerAttributeConfig | list[QuantizerAttributeConfig] | None = None + enable: bool | None = None # toggles matched quantizers on/off; independent of cfg + + @model_validator(mode="before") + @classmethod + def validate_quantizer_cfg_entry(cls, values): + """Validate raw quant_cfg entry semantics before cfg is parsed.""" + if not isinstance(values, Mapping): + return values + + if "cfg" not in values and "enable" not in values: + raise ValueError( + "Each quant_cfg entry must specify 'cfg', 'enable', or both. " + "An entry with only 'quantizer_name' has no effect." + ) + + cfg = values.get("cfg") + enable = values.get("enable", True) + if enable is False and cfg in ({}, []): + values = dict(values) + values["cfg"] = None + return values + + if enable and cfg is not None: + cls._validate_enabled_cfg(cfg) + return values + + @field_validator("quantizer_name") + @classmethod + def validate_quantizer_name(cls, v): + """Validate quantizer_name is non-empty.""" + if not v: + raise ValueError("quantizer_name must be a non-empty string.") + return v + + @staticmethod + def _validate_enabled_cfg(cfg): + """Validate cfg has real quantizer attributes when enabling a quantizer.""" + if isinstance(cfg, QuantizerAttributeConfig): + return + if isinstance(cfg, Mapping): + if len(cfg) == 0: + raise ValueError("cfg must be a non-empty dict when enabling a quantizer.") + return + if isinstance(cfg, list): + if len(cfg) == 0: + raise ValueError("cfg must be a non-empty list when enabling a quantizer.") + for item in cfg: + if isinstance(item, QuantizerAttributeConfig): + continue + if not isinstance(item, Mapping) or len(item) == 0: + raise ValueError( + "cfg list entries must be QuantizerAttributeConfig or non-empty dicts " + "when enabling a quantizer." + ) + return + raise ValueError( + "cfg must be QuantizerAttributeConfig, a non-empty dict, or a non-empty list " + "when enabling a quantizer." + ) def find_quant_cfg_entry_by_path( quant_cfg_list: Sequence[Any], quantizer_name: str -) -> dict[str, Any]: +) -> QuantizerCfgEntry | Mapping[str, Any]: """Find the last entry in a ``quant_cfg`` list whose ``quantizer_name`` key equals the query. This performs an **exact string comparison** against the ``quantizer_name`` field of each - entry — it does *not* apply ``fnmatch`` pattern matching. For example, passing + entry - it does *not* apply ``fnmatch`` pattern matching. For example, passing ``"*input_quantizer"`` will only match entries whose ``quantizer_name`` is literally ``"*input_quantizer"``, not entries with a different wildcard that would match the same module names at apply time. @@ -547,7 +604,7 @@ def find_quant_cfg_entry_by_path( override earlier ones, so the last match represents the effective configuration. Args: - quant_cfg_list: A list of :class:`QuantizerCfgEntry` dicts. + quant_cfg_list: A list of :class:`QuantizerCfgEntry` objects or legacy dicts. quantizer_name: The exact ``quantizer_name`` string to search for. Returns: @@ -556,9 +613,12 @@ def find_quant_cfg_entry_by_path( Raises: KeyError: If no entry with the given ``quantizer_name`` is found. 
""" - result: dict[str, Any] | None = None + result: QuantizerCfgEntry | Mapping[str, Any] | None = None for entry in quant_cfg_list: - if isinstance(entry, dict) and entry.get("quantizer_name") == quantizer_name: + if isinstance(entry, QuantizerCfgEntry): + if entry.get("quantizer_name") == quantizer_name: + result = entry + elif isinstance(entry, Mapping) and entry.get("quantizer_name") == quantizer_name: result = entry if result is None: raise KeyError(f"No quant_cfg entry with quantizer_name={quantizer_name!r}") @@ -930,53 +990,54 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): QuantizeQuantCfgType = list[QuantizerCfgEntry] QuantizerCfgListConfig = QuantizeQuantCfgType -QuantizeQuantCfgInputType = Sequence[Mapping[str, Any]] +QuantizeQuantCfgInputType = Sequence[QuantizerCfgEntry | Mapping[str, Any]] _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None -def normalize_quant_cfg_list(v: dict | list) -> list[QuantizerCfgEntry]: - """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` dicts. +def normalize_quant_cfg_list(v: Mapping[str, Any] | list) -> list[QuantizerCfgEntry]: + """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` objects. Supports the following input forms: - A ``list`` of entries in any of the per-entry forms below. - - A legacy flat ``dict`` (``{"*": ..., "*weight_quantizer": ...}``) — each key/value pair is + - A legacy flat ``dict`` (``{"*": ..., "*weight_quantizer": ...}``) - each key/value pair is converted to a single-key dict entry and then normalized. Per-entry forms (when input is a list): - - New format: ``{"quantizer_name": ..., "enable": ..., "cfg": ...}`` — passed through. - - Legacy single-key format: ``{"": }`` — converted to new format. - - Legacy ``nn.*``-scoped format: ``{"nn.": {"": }}`` — converted + - New format: ``{"quantizer_name": ..., "enable": ..., "cfg": ...}`` - passed through. + - Legacy single-key format: ``{"": }`` - converted to new format. + - Legacy ``nn.*``-scoped format: ``{"nn.": {"": }}`` - converted to a new-format entry with ``parent_class`` set. - **Validation** — an entry is rejected if it carries no instruction, i.e. it specifies neither + **Validation** - an entry is rejected if it carries no instruction, i.e. it specifies neither ``cfg`` nor ``enable``. Concretely, the following are invalid: - An empty entry ``{}``. - - An entry with only ``quantizer_name`` and no other keys — the only effect would be an + - An entry with only ``quantizer_name`` and no other keys - the only effect would be an implicit ``enable=True``, which must be stated explicitly. - An entry with ``enable=True`` (explicit or implicit) whose ``cfg`` is not a non-empty - ``dict`` or ``list`` — e.g. ``{"quantizer_name": "*", "cfg": {}}`` or + ``dict`` or ``list`` - e.g. ``{"quantizer_name": "*", "cfg": {}}`` or ``{"quantizer_name": "*", "cfg": 42}``. An enabled quantizer must have a valid configuration. - **Normalization** — after conversion and validation every entry is put into canonical form: + **Normalization** - after conversion and validation every entry is put into canonical form: - ``enable`` is set to ``True`` if not explicitly specified. - ``cfg`` is set to ``None`` if not present in the entry. - Every returned entry is therefore guaranteed to have the keys ``quantizer_name``, ``enable``, - and ``cfg`` (plus optionally ``parent_class``). 
+ Every returned entry is therefore guaranteed to have ``quantizer_name``, ``enable``, and + ``cfg`` set (plus optionally ``parent_class``). The entries remain dict-like for backward + compatibility while also being Pydantic models. Args: v: A list of raw quant_cfg entries in any supported format, or a legacy flat dict. Returns: - A list of :class:`QuantizerCfgEntry` dicts in canonical normalized form. + A list of :class:`QuantizerCfgEntry` objects in canonical normalized form. Raises: ValueError: If any entry has only ``quantizer_name`` with neither ``cfg`` nor ``enable``, @@ -994,12 +1055,12 @@ def _warn_legacy(): stacklevel=4, ) - # Legacy flat-dict format: {"*": {...}, "*weight_quantizer": {...}} → list of single-key dicts. - if isinstance(v, dict): + # Legacy flat-dict format: {"*": {...}, "*weight_quantizer": {...}} -> list of single-key dicts. + if isinstance(v, Mapping): _warn_legacy() v = [{k: val} for k, val in v.items()] - def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: + def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: """Convert a single legacy key-value pair to one or more QuantizerCfgEntry dicts.""" # Legacy "default" key was a catch-all applied as "*" in the old conversion code. if key == "default": @@ -1009,7 +1070,7 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: if not isinstance(value, dict): raise ValueError(f"For 'nn.*' scoped format, value must be a dict, got {value!r}") # Support multi-key nn.*-scoped dicts by emitting one entry per sub-key. - entries: list[QuantizerCfgEntry] = [] + entries: list[dict[str, Any]] = [] for q_path, sub_cfg in value.items(): if isinstance(sub_cfg, QuantizerAttributeConfig): enable = None @@ -1018,7 +1079,7 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: sub_cfg = dict(sub_cfg) enable = sub_cfg.pop("enable", None) cfg = sub_cfg or None - entry: QuantizerCfgEntry = { + entry: dict[str, Any] = { "parent_class": key, "quantizer_name": q_path, "cfg": cfg, @@ -1042,16 +1103,22 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: result: list[QuantizerCfgEntry] = [] _warned_legacy = False for raw in v: - if isinstance(raw, dict) and "quantizer_name" in raw: + if isinstance(raw, QuantizerCfgEntry): + entries = [raw.model_dump(exclude_unset=True)] + elif isinstance(raw, Mapping) and "quantizer_name" in raw: entries = [dict(raw)] # copy to avoid mutating caller's data - elif isinstance(raw, dict) and len(raw) == 1: + elif isinstance(raw, Mapping) and len(raw) == 1: key, val = next(iter(raw.items())) entries = [dict(e) for e in _dict_to_entry(key, val)] if not _warned_legacy: _warn_legacy() _warned_legacy = True - elif isinstance(raw, dict) and len(raw) > 1 and any(k.startswith("nn.") for k in raw): - # Legacy flat dict with nn.*-scoped keys mixed with other keys — expand all pairs. + elif ( + isinstance(raw, Mapping) + and len(raw) > 1 + and any(isinstance(k, str) and k.startswith("nn.") for k in raw) + ): + # Legacy flat dict with nn.*-scoped keys mixed with other keys - expand all pairs. entries = [] for k, val in raw.items(): entries.extend(dict(e) for e in _dict_to_entry(k, val)) @@ -1065,7 +1132,7 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: # Validate: must carry at least one instruction beyond the path selector. 
if "cfg" not in entry and "enable" not in entry: raise ValueError( - f"Invalid quant_cfg entry: {raw!r} — each entry must specify 'cfg', 'enable', " + f"Invalid quant_cfg entry: {raw!r} - each entry must specify 'cfg', 'enable', " "or both. An entry with only 'quantizer_name' has no effect (implicit " "enable=True is not allowed; set it explicitly)." ) @@ -1090,7 +1157,7 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: is_invalid = True if is_invalid: raise ValueError( - f"Invalid quant_cfg entry: {raw!r} — 'cfg' must be a " + f"Invalid quant_cfg entry: {raw!r} - 'cfg' must be a " "QuantizerAttributeConfig, a non-empty dict, or a non-empty list of " "QuantizerAttributeConfig/non-empty dict entries when enabling a " f"quantizer (got {type(cfg).__name__}: {cfg!r}). Either provide " @@ -1102,7 +1169,7 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: entry.setdefault("enable", True) entry.setdefault("cfg", None) - result.append(cast("QuantizerCfgEntry", entry)) + result.append(QuantizerCfgEntry.model_validate(entry)) return result @@ -1125,26 +1192,11 @@ class QuantizeConfig(ModeloptBaseConfig): @field_validator("quant_cfg", mode="before") @classmethod def normalize_quant_cfg(cls, v): - """Normalize quant_cfg entries: convert dict and tuple forms to QuantizerCfgEntry dicts.""" - if not isinstance(v, (list, dict)): + """Normalize quant_cfg entries into QuantizerCfgEntry objects.""" + if not isinstance(v, (list, Mapping)): return v return normalize_quant_cfg_list(v) - @field_validator("quant_cfg", mode="after") - @classmethod - def validate_quant_cfg_entries(cls, v): - """Validate quantizer attribute configs to surface errors (e.g. invalid axis/block_sizes).""" - qac_fields = set(QuantizerAttributeConfig.model_fields.keys()) - for entry in v: - cfg = entry.get("cfg") - if cfg is None: - continue - cfgs = cfg if isinstance(cfg, list) else [cfg] - for c in cfgs: - if isinstance(c, dict) and qac_fields & set(c.keys()): - QuantizerAttributeConfig.model_validate(c) - return v - class CompressConfig(ModeloptBaseConfig): """Default configuration for ``compress`` mode.""" @@ -1174,15 +1226,43 @@ def _load_quantize_config_dict(config_path: str) -> dict[str, Any]: return load_config(config_path, schema_type=QuantizeConfig).model_dump(exclude_unset=True) -_base_disable_all: list[QuantizerCfgEntry] = [ - cast("QuantizerCfgEntry", load_config("configs/ptq/units/base_disable_all")) -] +def _quantizer_cfg_entry_to_dict(entry: QuantizerCfgEntry | Mapping[str, Any]) -> dict[str, Any]: + """Dump a typed quant_cfg entry back to the public legacy dict shape.""" + if isinstance(entry, QuantizerCfgEntry): + return entry.model_dump(exclude_unset=True) + if isinstance(entry, Mapping): + return dict(entry) + raise TypeError(f"Expected QuantizerCfgEntry or mapping, got {type(entry).__name__}.") + + +def _load_quantizer_cfg_dict_list(config_path: str) -> list[dict[str, Any]]: + """Load a QuantizerCfgEntry or QuantizerCfgListConfig snippet as public dict entries.""" + config = load_config(config_path) + if isinstance(config, QuantizerCfgEntry): + return [_quantizer_cfg_entry_to_dict(config)] + if isinstance(config, list): + entries = [] + for entry in config: + if not isinstance(entry, (QuantizerCfgEntry, Mapping)): + raise TypeError( + f"Expected QuantizerCfgEntry or mapping, got {type(entry).__name__}." 
+ ) + entries.append(_quantizer_cfg_entry_to_dict(entry)) + return entries + if isinstance(config, Mapping): + return [_quantizer_cfg_entry_to_dict(config)] + raise TypeError(f"{config_path} must declare QuantizerCfgEntry or QuantizerCfgListConfig.") + + +_base_disable_all: list[dict[str, Any]] = _load_quantizer_cfg_dict_list( + "configs/ptq/units/base_disable_all" +) -_default_disabled_quantizer_cfg: list[QuantizerCfgEntry] = load_config( +_default_disabled_quantizer_cfg: list[dict[str, Any]] = _load_quantizer_cfg_dict_list( "configs/ptq/units/default_disabled_quantizers" ) -_mamba_moe_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ +_mamba_moe_disabled_quantizer_cfg: list[dict[str, Any]] = [ {"quantizer_name": "*fc1_latent_proj*", "enable": False}, # Skip Latent MOE {"quantizer_name": "*fc2_latent_proj*", "enable": False}, # Skip Latent MOE {"quantizer_name": "*q_proj*", "enable": False}, # Skip QKV Linear (HF naming) @@ -1507,25 +1587,19 @@ def _nvfp4_selective_quant_cfg( algorithm: str | dict = "max", ) -> dict: """Build an NVFP4 config that quantizes only the specified layer patterns.""" - quant_cfg: list[QuantizerCfgEntry] = [] + quant_cfg: list[dict[str, Any]] = [] quant_cfg.extend(_base_disable_all) for pattern in layer_patterns: # Deep-copy the quantizer dict so each config constant gets its own instance. quant_cfg.append( - cast( - "QuantizerCfgEntry", - {"quantizer_name": f"{pattern}weight_quantizer", "cfg": copy.deepcopy(quantizer)}, - ) + {"quantizer_name": f"{pattern}weight_quantizer", "cfg": copy.deepcopy(quantizer)} ) if not weight_only: quant_cfg.append( - cast( - "QuantizerCfgEntry", - { - "quantizer_name": f"{pattern}input_quantizer", - "cfg": copy.deepcopy(quantizer), - }, - ) + { + "quantizer_name": f"{pattern}input_quantizer", + "cfg": copy.deepcopy(quantizer), + } ) quant_cfg.extend(_default_disabled_quantizer_cfg) return {"quant_cfg": quant_cfg, "algorithm": algorithm} @@ -1783,7 +1857,7 @@ def _nvfp4_selective_quant_cfg( } -def need_calibration(config): +def need_calibration(config: QuantizeConfig | Mapping[str, Any]) -> bool: """Check if calibration is needed for the given config.""" if config["algorithm"] is not None and config["algorithm"] != "max": return True diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 535e6d6f508..808688796d4 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -253,7 +253,7 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInpu for entry in quant_cfg: quantizer_name: str = entry["quantizer_name"] - cfg = entry["cfg"] # None, dict, or list — always explicit after normalization + cfg = entry["cfg"] # None, QuantizerAttributeConfig, or list after normalization enable: bool = entry["enable"] # always explicit after normalization parent_class_name = entry.get("parent_class") if parent_class_name: @@ -276,24 +276,12 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInpu # Has cfg: apply full replacement with the explicit enable value. 
if isinstance(cfg, QuantizerAttributeConfig): attributes = cfg.model_copy(update={"enable": enable}) - elif isinstance(cfg, dict): - attributes = QuantizerAttributeConfig(**cfg, enable=enable) elif isinstance(cfg, list): - attributes = [] - for c in cfg: - if isinstance(c, QuantizerAttributeConfig): - attributes.append(c.model_copy(update={"enable": enable})) - elif isinstance(c, dict): - attributes.append(QuantizerAttributeConfig(**c, enable=enable)) - else: - raise ValueError( - f"Invalid cfg element for quantizer {quantizer_name!r}: expected " - "QuantizerAttributeConfig or dict." - ) + attributes = [c.model_copy(update={"enable": enable}) for c in cfg] else: raise ValueError( f"Invalid cfg for quantizer {quantizer_name!r}: expected " - "QuantizerAttributeConfig, dict, or list." + "QuantizerAttributeConfig or list." ) set_quantizer_attributes_full(quant_model, quantizer_name, attributes, parent_class) diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 19463493b8f..fa8e11aea50 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -21,7 +21,11 @@ from modelopt.recipe.config import ModelOptPTQRecipe, RecipeMetadataConfig, RecipeType from modelopt.recipe.loader import load_config, load_recipe -from modelopt.torch.quantization.config import QuantizeConfig, QuantizerAttributeConfig +from modelopt.torch.quantization.config import ( + QuantizeConfig, + QuantizerAttributeConfig, + QuantizerCfgEntry, +) # --------------------------------------------------------------------------- # Static YAML fixtures @@ -84,6 +88,12 @@ def _cfg_to_dict(cfg): return cfg +def _entry_to_dict(entry): + if isinstance(entry, QuantizerCfgEntry): + return entry.model_dump(exclude_unset=True) + return dict(entry) + + # --------------------------------------------------------------------------- # Directory-format YAML fixtures # --------------------------------------------------------------------------- @@ -551,7 +561,11 @@ def test_import_entry_element_schema_appends(tmp_path): f" - $import: disable_all\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"] == [{"quantizer_name": "*", "cfg": None, "enable": False}] + assert _entry_to_dict(recipe.quantize["quant_cfg"][0]) == { + "quantizer_name": "*", + "cfg": None, + "enable": False, + } def test_import_entry_wrong_schema_raises(tmp_path): @@ -895,7 +909,8 @@ def test_import_list_splice_outside_typed_list_raises(tmp_path): """A bare $import in an untyped list is rejected.""" _write_quantizer_cfg_list( tmp_path / "extra_tasks.yml", - "- quantizer_name: '*weight_quantizer'\n- quantizer_name: '*input_quantizer'\n", + "- quantizer_name: '*weight_quantizer'\n enable: false\n" + "- quantizer_name: '*input_quantizer'\n enable: false\n", ) config_file = tmp_path / "config.yml" config_file.write_text( @@ -959,7 +974,11 @@ def test_import_mixed_tree(tmp_path): # Dict import inside list entry assert _cfg_to_dict(data["quant_cfg"][0]["cfg"]) == {"num_bits": (4, 3)} # List splice - assert data["quant_cfg"][1] == {"quantizer_name": "*lm_head*", "enable": False, "cfg": None} + assert _entry_to_dict(data["quant_cfg"][1]) == { + "quantizer_name": "*lm_head*", + "enable": False, + "cfg": None, + } # --------------------------------------------------------------------------- @@ -1310,7 +1329,8 @@ def test_import_dict_value_resolves_to_list_raises(tmp_path): """$import in dict value position raises when snippet is a list.""" _write_quantizer_cfg_list( tmp_path / "entries.yml", - "- quantizer_name: 
'*weight_quantizer'\n- quantizer_name: '*input_quantizer'\n", + "- quantizer_name: '*weight_quantizer'\n enable: false\n" + "- quantizer_name: '*input_quantizer'\n enable: false\n", ) config_file = tmp_path / "config.yml" config_file.write_text( diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 1d7b8607ddd..ba49d863a2c 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -18,6 +18,7 @@ import pytest from pydantic import ValidationError +from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.quantization.config import ( FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, FP8_DEFAULT_CFG, @@ -27,12 +28,21 @@ W4A8_AWQ_BETA_CFG, QuantizeConfig, QuantizerAttributeConfig, + QuantizerCfgEntry, find_quant_cfg_entry_by_path, need_calibration, normalize_quant_cfg_list, ) +def _cfg_to_dict(cfg): + if isinstance(cfg, QuantizerAttributeConfig): + return cfg.model_dump(exclude_unset=True) + if isinstance(cfg, list): + return [_cfg_to_dict(item) for item in cfg] + return cfg + + def test_need_calibration(): assert need_calibration(FP8_DEFAULT_CFG) assert not need_calibration(FP8_PER_CHANNEL_PER_TOKEN_CFG) @@ -42,6 +52,12 @@ def test_need_calibration(): assert need_calibration(NVFP4_DEFAULT_CFG) +def test_need_calibration_with_quantize_config_type(): + """need_calibration accepts schema-backed QuantizeConfig objects.""" + assert need_calibration(QuantizeConfig.model_validate(FP8_DEFAULT_CFG)) + assert not need_calibration(QuantizeConfig.model_validate(FP8_PER_CHANNEL_PER_TOKEN_CFG)) + + def test_need_calibration_with_list_cfg(): """need_calibration must handle sequential (list) cfg entries without crashing.""" # Static list-cfg on a non-weight quantizer → needs calibration @@ -74,6 +90,50 @@ def test_need_calibration_with_list_cfg(): assert not need_calibration(cfg_dynamic) +def test_quantizer_cfg_entry_is_pydantic_and_dict_like(): + """QuantizerCfgEntry is typed but keeps the dict-style access used by callers.""" + entry = QuantizerCfgEntry(quantizer_name="*", enable=False) + assert isinstance(entry, ModeloptBaseConfig) + assert entry["quantizer_name"] == "*" + assert entry.get("cfg") is None + assert entry.model_dump(exclude_unset=True) == {"quantizer_name": "*", "enable": False} + + cfg_entry = QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={"num_bits": 8}) + assert isinstance(cfg_entry["cfg"], QuantizerAttributeConfig) + assert _cfg_to_dict(cfg_entry["cfg"]) == {"num_bits": 8} + + +def test_public_preset_quant_cfg_entries_remain_dicts(): + """Public preset constants keep legacy dict entries for downstream compatibility.""" + assert all(isinstance(entry, dict) for entry in FP8_DEFAULT_CFG["quant_cfg"]) + assert all(isinstance(entry, dict) for entry in NVFP4_DEFAULT_CFG["quant_cfg"]) + + +def test_quantizer_cfg_entry_rejects_no_effect_entry(): + """Direct QuantizerCfgEntry construction rejects entries with no cfg or enable.""" + with pytest.raises(ValidationError, match="must specify 'cfg', 'enable'"): + QuantizerCfgEntry(quantizer_name="*") + + +def test_quantizer_cfg_entry_rejects_empty_name(): + """Direct QuantizerCfgEntry construction rejects empty quantizer names.""" + with pytest.raises(ValidationError, match="non-empty string"): + QuantizerCfgEntry(quantizer_name="", enable=False) + + +def test_quantizer_cfg_entry_rejects_empty_cfg_when_enabled(): + """Direct QuantizerCfgEntry construction rejects empty enabled cfg 
values.""" + with pytest.raises(ValidationError, match="non-empty dict"): + QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={}) + + +def test_quantizer_cfg_entry_treats_empty_disabled_cfg_as_disable_only(): + """Empty cfg with enable=False remains a disable-only entry.""" + entry = QuantizerCfgEntry(quantizer_name="*input_quantizer", cfg={}, enable=False) + assert entry["cfg"] is None + assert entry["enable"] is False + + class TestNormalizeQuantCfgList: def test_new_format_passthrough(self): """New-format entries are returned unchanged (only canonical defaults added).""" @@ -81,7 +141,8 @@ def test_new_format_passthrough(self): result = normalize_quant_cfg_list(raw) assert len(result) == 1 assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert isinstance(result[0]["cfg"], QuantizerAttributeConfig) + assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": 0} assert result[0]["enable"] is True # defaulted def test_new_format_enable_false(self): @@ -103,7 +164,7 @@ def test_legacy_single_key_dict(self): raw = [{"*weight_quantizer": {"num_bits": 8, "axis": 0}}] result = normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": 0} assert result[0]["enable"] is True # defaulted def test_legacy_single_key_dict_with_enable(self): @@ -231,7 +292,9 @@ def test_new_format_with_list_cfg(self): ] result = normalize_quant_cfg_list(raw) assert len(result) == 1 - assert result[0]["cfg"] == raw[0]["cfg"] + assert isinstance(result[0]["cfg"], list) + assert all(isinstance(cfg, QuantizerAttributeConfig) for cfg in result[0]["cfg"]) + assert _cfg_to_dict(result[0]["cfg"]) == raw[0]["cfg"] assert result[0]["enable"] is True def test_legacy_flat_dict_conversion(self): @@ -243,7 +306,7 @@ def test_legacy_flat_dict_conversion(self): assert result[0]["enable"] is False assert result[0]["cfg"] is None assert result[1]["quantizer_name"] == "*weight_quantizer" - assert result[1]["cfg"] == {"num_bits": 8, "axis": 0} + assert _cfg_to_dict(result[1]["cfg"]) == {"num_bits": 8, "axis": 0} assert result[1]["enable"] is True def test_legacy_enable_only_produces_cfg_none(self): @@ -274,7 +337,7 @@ def test_legacy_default_key_with_cfg(self): raw = [{"default": {"num_bits": 8, "axis": None}}] result = normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*" - assert result[0]["cfg"] == {"num_bits": 8, "axis": None} + assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": None} assert result[0]["enable"] is True def test_legacy_flat_dict_with_default_key(self): @@ -309,7 +372,7 @@ def test_legacy_nn_class_with_cfg(self): assert len(result) == 1 assert result[0]["parent_class"] == "nn.Linear" assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 4, "axis": 0} + assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 4, "axis": 0} assert result[0]["enable"] is True def test_legacy_list_valued_cfg(self): @@ -343,7 +406,7 @@ def test_finds_last_match(self): ] ) result = find_quant_cfg_entry_by_path(entries, "*weight_quantizer") - assert result["cfg"] == {"num_bits": 4} + assert _cfg_to_dict(result["cfg"]) == {"num_bits": 4} def test_exact_match_only(self): """Does not do fnmatch — only exact string equality on quantizer_name.""" @@ -400,7 +463,7 @@ def test_wildcard_matches_bare_name(self): [{"quantizer_name": 
"*weight_quantizer", "cfg": {"num_bits": 8}}] ) matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 8} + assert _cfg_to_dict(matched) == {"num_bits": 8} assert enable is True def test_star_matches_any_bare_name(self): @@ -420,7 +483,7 @@ def test_path_scoped_pattern_matches_matching_suffix(self): [{"quantizer_name": "*mlp*weight_quantizer", "cfg": {"num_bits": 4}}] ) matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 4} + assert _cfg_to_dict(matched) == {"num_bits": 4} def test_path_scoped_pattern_does_not_match_different_suffix(self): """'*mlp*weight_quantizer' does NOT match bare 'input_quantizer'.""" @@ -444,7 +507,7 @@ def test_last_match_wins(self): ] ) matched, _ = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 4} + assert _cfg_to_dict(matched) == {"num_bits": 4} def test_no_match_returns_none(self): """No matching entry returns (None, None).""" From b7c9359fadd0e3f5caff1646bc36b2bddb3e26a7 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Wed, 6 May 2026 17:46:09 -0700 Subject: [PATCH 04/14] use ModeloptField Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/config.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index f5f17582e0b..481870a0496 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -525,10 +525,29 @@ def validate_calibrator(cls, v, info: ValidationInfo): class QuantizerCfgEntry(ModeloptBaseConfig): """A single entry in a ``quant_cfg`` list.""" - quantizer_name: str # matched against quantizer module names - parent_class: str | None = None # filters by PyTorch module class name, e.g. "nn.Linear" - cfg: QuantizerAttributeConfig | list[QuantizerAttributeConfig] | None = None - enable: bool | None = None # toggles matched quantizers on/off; independent of cfg + quantizer_name: str = ModeloptField( + default="", + title="Quantizer name pattern.", + description="Wildcard pattern matched against quantizer module names.", + validate_default=True, + ) + parent_class: str | None = ModeloptField( + default=None, + title="Parent module class filter.", + description='Optional PyTorch module class name filter, e.g. "nn.Linear".', + ) + cfg: QuantizerAttributeConfig | list[QuantizerAttributeConfig] | None = ModeloptField( + default=None, + title="Quantizer attributes.", + description=( + "Attributes to apply to matched quantizers. A list configures a sequential quantizer." + ), + ) + enable: bool | None = ModeloptField( + default=None, + title="Quantizer enable flag.", + description="Optional on/off toggle for matched quantizers, independent of cfg.", + ) @model_validator(mode="before") @classmethod From b7df9d2846cf8ff4dc630f16f7aba3bdc9ea1003 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 7 May 2026 01:05:15 -0700 Subject: [PATCH 05/14] Return typed normalized quantizer config entries Update normalize_quant_cfg_list to accept dict entries, typed entries, and legacy dict formats while returning QuantizerCfgEntry objects. Preserve already parsed entries, handle implicit enable values in consumers, and cover mixed typed/dict inputs in tests. 
Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/config.py | 32 ++++++++++++------- modelopt/torch/quantization/conversion.py | 2 +- .../quantization/test_config_validation.py | 28 +++++++++++++++- 3 files changed, 48 insertions(+), 14 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 481870a0496..27edc7d9c92 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -155,7 +155,7 @@ import copy import warnings from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any, Literal, cast from pydantic import ValidationInfo, field_validator, model_validator @@ -1016,12 +1016,15 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None -def normalize_quant_cfg_list(v: Mapping[str, Any] | list) -> list[QuantizerCfgEntry]: +def normalize_quant_cfg_list( + v: Mapping[str, Any] | list[QuantizerCfgEntry | Mapping[str, Any]], +) -> list[QuantizerCfgEntry]: """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` objects. Supports the following input forms: - A ``list`` of entries in any of the per-entry forms below. + - A ``list`` containing :class:`QuantizerCfgEntry` objects, which are preserved as-is. - A legacy flat ``dict`` (``{"*": ..., "*weight_quantizer": ...}``) - each key/value pair is converted to a single-key dict entry and then normalized. @@ -1048,21 +1051,24 @@ def normalize_quant_cfg_list(v: Mapping[str, Any] | list) -> list[QuantizerCfgEn - ``enable`` is set to ``True`` if not explicitly specified. - ``cfg`` is set to ``None`` if not present in the entry. - Every returned entry is therefore guaranteed to have ``quantizer_name``, ``enable``, and - ``cfg`` set (plus optionally ``parent_class``). The entries remain dict-like for backward - compatibility while also being Pydantic models. + For dict and legacy inputs, every returned entry is guaranteed to have + ``quantizer_name``, ``enable``, and ``cfg`` set (plus optionally ``parent_class``). Typed + :class:`QuantizerCfgEntry` inputs are assumed to be already parsed and are preserved. Args: v: A list of raw quant_cfg entries in any supported format, or a legacy flat dict. Returns: - A list of :class:`QuantizerCfgEntry` objects in canonical normalized form. + A list of :class:`QuantizerCfgEntry` objects in canonical normalized form. Existing + typed entries are preserved. Raises: ValueError: If any entry has only ``quantizer_name`` with neither ``cfg`` nor ``enable``, if ``enable=True`` with an empty or non-dict/list ``cfg``, or if the entry format is not recognized. 
""" + if isinstance(v, list) and all(isinstance(raw, QuantizerCfgEntry) for raw in v): + return cast("list[QuantizerCfgEntry]", v) def _warn_legacy(): warnings.warn( @@ -1123,8 +1129,9 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: _warned_legacy = False for raw in v: if isinstance(raw, QuantizerCfgEntry): - entries = [raw.model_dump(exclude_unset=True)] - elif isinstance(raw, Mapping) and "quantizer_name" in raw: + result.append(raw) + continue + if isinstance(raw, Mapping) and "quantizer_name" in raw: entries = [dict(raw)] # copy to avoid mutating caller's data elif isinstance(raw, Mapping) and len(raw) == 1: key, val = next(iter(raw.items())) @@ -1894,6 +1901,7 @@ def _cfg_to_dict(cfg): for entry in quant_cfg: name = entry["quantizer_name"] raw_cfg = entry.get("cfg") + enable = entry.get("enable") if "weight_quantizer" in name: # We don't calibrate weight quantizer continue @@ -1901,14 +1909,14 @@ def _cfg_to_dict(cfg): if isinstance(raw_cfg, list): for _config in raw_cfg: cfg = _cfg_to_dict(_config) - if "enable" in entry: - cfg["enable"] = entry["enable"] + if enable is not None: + cfg["enable"] = enable if _not_dynamic(cfg): return True continue cfg = _cfg_to_dict(raw_cfg) - if "enable" in entry: - cfg["enable"] = entry["enable"] + if enable is not None: + cfg["enable"] = enable if _not_dynamic(cfg): return True diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 808688796d4..55d48691de0 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -254,7 +254,7 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInpu for entry in quant_cfg: quantizer_name: str = entry["quantizer_name"] cfg = entry["cfg"] # None, QuantizerAttributeConfig, or list after normalization - enable: bool = entry["enable"] # always explicit after normalization + enable = entry["enable"] if entry["enable"] is not None else True parent_class_name = entry.get("parent_class") if parent_class_name: try: diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index ba49d863a2c..3bcb687c1af 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -136,15 +136,40 @@ def test_quantizer_cfg_entry_treats_empty_disabled_cfg_as_disable_only(): class TestNormalizeQuantCfgList: def test_new_format_passthrough(self): - """New-format entries are returned unchanged (only canonical defaults added).""" + """New-format dict entries are normalized into QuantizerCfgEntry objects.""" raw = [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}] result = normalize_quant_cfg_list(raw) assert len(result) == 1 + assert isinstance(result[0], QuantizerCfgEntry) assert result[0]["quantizer_name"] == "*weight_quantizer" assert isinstance(result[0]["cfg"], QuantizerAttributeConfig) assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": 0} assert result[0]["enable"] is True # defaulted + def test_typed_entry_list_passthrough(self): + """Already-parsed QuantizerCfgEntry lists are returned unchanged.""" + raw = [ + QuantizerCfgEntry( + quantizer_name="*weight_quantizer", + cfg=QuantizerAttributeConfig(num_bits=8, axis=0), + enable=True, + ) + ] + result = normalize_quant_cfg_list(raw) + assert result is raw + assert result[0] is raw[0] + + def test_mixed_typed_and_dict_entries_normalize_to_typed_entries(self): + """Mixed 
QuantizerCfgEntry/dict input lists normalize dicts and preserve typed entries.""" + typed_entry = QuantizerCfgEntry(quantizer_name="*", enable=False) + result = normalize_quant_cfg_list( + [typed_entry, {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}}] + ) + assert result[0] is typed_entry + assert isinstance(result[1], QuantizerCfgEntry) + assert _cfg_to_dict(result[1]["cfg"]) == {"num_bits": 8} + assert result[1]["enable"] is True + def test_new_format_enable_false(self): """Explicit enable=False is preserved.""" raw = [{"quantizer_name": "*", "enable": False}] @@ -292,6 +317,7 @@ def test_new_format_with_list_cfg(self): ] result = normalize_quant_cfg_list(raw) assert len(result) == 1 + assert isinstance(result[0], QuantizerCfgEntry) assert isinstance(result[0]["cfg"], list) assert all(isinstance(cfg, QuantizerAttributeConfig) for cfg in result[0]["cfg"]) assert _cfg_to_dict(result[0]["cfg"]) == raw[0]["cfg"] From 53fcd04ab2d1295930935733ac7ab0e109cdf206 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 7 May 2026 13:19:30 -0700 Subject: [PATCH 06/14] Use typed mapping quant configs Make ModeloptBaseConfig a MutableMapping and use Mapping/MutableMapping protocol checks for typed quantizer config entries and attributes. Convert predefined quantization recipes to QuantizeConfig objects while preserving dict-style callers and compatibility paths. Signed-off-by: Shengliang Xu --- examples/diffusers/quantization/config.py | 10 +- examples/diffusers/quantization/quantize.py | 9 +- examples/llm_autodeploy/run_auto_quantize.py | 6 +- examples/llm_ptq/cast_mxfp4_to_nvfp4.py | 3 +- examples/llm_ptq/example_utils.py | 41 +- examples/llm_ptq/hf_ptq.py | 16 +- examples/llm_ptq/multinode_ptq.py | 5 +- examples/vllm_serve/vllm_ptq_utils.py | 6 +- .../llm_export_utils/quantization_utils.py | 4 +- modelopt/torch/opt/config.py | 20 +- modelopt/torch/quantization/algorithms.py | 47 +- .../backends/fp8_per_tensor_gemm.py | 4 +- .../torch/quantization/backends/nvfp4_gemm.py | 4 +- modelopt/torch/quantization/config.py | 896 +++++++++--------- modelopt/torch/quantization/conversion.py | 12 +- modelopt/torch/quantization/mode.py | 4 +- modelopt/torch/quantization/model_quant.py | 18 +- .../nn/modules/tensor_quantizer.py | 12 +- .../torch/quantization/utils/core_utils.py | 28 +- .../quantization/test_config_validation.py | 34 +- 20 files changed, 654 insertions(+), 525 deletions(-) diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py index e15b8c7ba3c..967802ffa97 100644 --- a/examples/diffusers/quantization/config.py +++ b/examples/diffusers/quantization/config.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections.abc import Mapping, MutableMapping + import torch.nn as nn from calib.plugin_calib import PercentileCalibrator @@ -104,8 +106,12 @@ def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, ** quant_config["algorithm"] = algo_cfg for entry in quant_config["quant_cfg"]: - p = entry.get("cfg", {}) - if isinstance(p, dict) and "num_bits" in p and "trt_high_precision_dtype" not in p: + p = entry.get("cfg", {}) if isinstance(entry, Mapping) else {} + if ( + isinstance(p, MutableMapping) + and "num_bits" in p + and "trt_high_precision_dtype" not in p + ): p["trt_high_precision_dtype"] = trt_high_precision_dtype diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 2a3c947a2d6..a087d6f1d69 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -17,6 +17,7 @@ import logging import sys import time as time +from collections.abc import Mapping from pathlib import Path from typing import Any @@ -142,11 +143,13 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: if self.config.format == QuantFormat.FP4: for i, entry in enumerate(quant_cfg_list): - if isinstance(entry, dict) and "block_sizes" in entry.get("cfg", {}): - new_block_sizes = {**entry["cfg"]["block_sizes"], -1: self.config.block_size} + cfg = entry.get("cfg", {}) if isinstance(entry, Mapping) else {} + block_sizes = cfg.get("block_sizes") if isinstance(cfg, Mapping) else None + if isinstance(block_sizes, Mapping): + new_block_sizes = {**block_sizes, -1: self.config.block_size} quant_cfg_list[i] = { **entry, - "cfg": {**entry["cfg"], "block_sizes": new_block_sizes}, + "cfg": {**cfg, "block_sizes": new_block_sizes}, } if self.config.quantize_mha: diff --git a/examples/llm_autodeploy/run_auto_quantize.py b/examples/llm_autodeploy/run_auto_quantize.py index ebd7c1090bb..e20bea5a1f4 100644 --- a/examples/llm_autodeploy/run_auto_quantize.py +++ b/examples/llm_autodeploy/run_auto_quantize.py @@ -24,9 +24,9 @@ from modelopt.torch.utils import create_forward_loop from modelopt.torch.utils.dataset_utils import get_dataset_dataloader -SUPPORT_QUANT_FORMAT = { - "fp8": mtq.FP8_DEFAULT_CFG, - "nvfp4": mtq.NVFP4_DEFAULT_CFG, +SUPPORT_QUANT_FORMAT: dict[str, str] = { + "fp8": "FP8_DEFAULT_CFG", + "nvfp4": "NVFP4_DEFAULT_CFG", } diff --git a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py index 26f3c9f8258..2223f3cf055 100644 --- a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py +++ b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py @@ -34,6 +34,7 @@ """ import json +from collections.abc import Mapping from contextlib import ExitStack, contextmanager from pathlib import Path @@ -304,7 +305,7 @@ def force_weight_quantizers_static(quant_cfg: list) -> None: qname = entry.get("quantizer_name", "") cfg = entry.get("cfg") or {} bs = cfg.get("block_sizes") - if "weight_quantizer" in qname and isinstance(bs, dict): + if "weight_quantizer" in qname and isinstance(bs, Mapping): quant_cfg[i] = {**entry, "cfg": {**cfg, "block_sizes": {**bs, "type": "static"}}} diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 9455157645c..ba3a878bfc9 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -23,6 +23,7 @@ import shutil import sys import warnings +from collections.abc import Mapping, MutableMapping from pathlib import Path from typing import Any @@ -41,6 +42,8 @@ ProcessorMixin, ) +from 
modelopt.torch.quantization.config import QuantizeConfig + try: from huggingface_hub import snapshot_download except ImportError: @@ -203,17 +206,17 @@ def calibrate_loop(_model): def build_quant_cfg( qformat, - quant_cfg, + quant_cfg: QuantizeConfig | Mapping[str, Any], awq_block_size, model_type, moe_calib_experts_ratio: float | None = None, -) -> dict[str, Any]: - quant_cfg = copy.deepcopy(quant_cfg) - if "awq" in str(quant_cfg.get("algorithm")): +) -> QuantizeConfig: + quant_cfg_obj: QuantizeConfig = QuantizeConfig.model_validate(copy.deepcopy(quant_cfg)) + if "awq" in str(quant_cfg_obj.get("algorithm")): from modelopt.torch.quantization.config import find_quant_cfg_entry_by_path weight_quantizer_entry = find_quant_cfg_entry_by_path( - quant_cfg["quant_cfg"], "*weight_quantizer" + quant_cfg_obj["quant_cfg"], "*weight_quantizer" ) weight_quantizer = weight_quantizer_entry.get("cfg") or {} if isinstance(weight_quantizer, list): @@ -224,34 +227,34 @@ def build_quant_cfg( # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]: - quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} + quant_cfg_obj["algorithm"] = {"method": "awq_lite", "alpha_step": 1} if moe_calib_experts_ratio: assert 0 < moe_calib_experts_ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1" - if isinstance(quant_cfg["algorithm"], str): - quant_cfg["algorithm"] = { - "method": quant_cfg["algorithm"], + if isinstance(quant_cfg_obj["algorithm"], str): + quant_cfg_obj["algorithm"] = { + "method": quant_cfg_obj["algorithm"], "moe_calib_experts_ratio": moe_calib_experts_ratio, } - elif isinstance(quant_cfg["algorithm"], dict): - quant_cfg["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio + elif isinstance(quant_cfg_obj["algorithm"], MutableMapping): + quant_cfg_obj["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio else: warnings.warn( - f"Quantization algorithm: {quant_cfg['algorithm']} does not support setting moe_calib_experts_ratio" + f"Quantization algorithm: {quant_cfg_obj['algorithm']} does not support setting moe_calib_experts_ratio" ) # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. 
if model_type == "gemma" and "int8_sq" in qformat: - quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} + quant_cfg_obj["algorithm"] = {"method": "smoothquant", "alpha": 0.5} if model_type == "phi4mm": # Only quantize the language model - quant_cfg["quant_cfg"].append({"quantizer_name": "*speech*", "enable": False}) - quant_cfg["quant_cfg"].append({"quantizer_name": "*audio*", "enable": False}) - quant_cfg["quant_cfg"].append({"quantizer_name": "*image*", "enable": False}) - quant_cfg["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False}) + quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*speech*", "enable": False}) + quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*audio*", "enable": False}) + quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*image*", "enable": False}) + quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False}) - return quant_cfg + return quant_cfg_obj def is_speculative(hf_config): @@ -842,7 +845,7 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod def needs_checkpoint_path_update(quant_cfg: dict) -> bool: """Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath.""" algorithm = quant_cfg.get("algorithm") - if not isinstance(algorithm, dict): + if not isinstance(algorithm, Mapping): return False return algorithm.get("layerwise_checkpoint_dir") is not None diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 758ed75aeed..ca75974d6d3 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -66,7 +66,11 @@ save_expert_token_count_table, ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model -from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration +from modelopt.torch.quantization.config import ( + QuantizeConfig, + _default_disabled_quantizer_cfg, + need_calibration, +) from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from modelopt.torch.quantization.utils import is_quantized from modelopt.torch.speculative.eagle.utils import ( @@ -89,18 +93,18 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: """Set use_constant_amax on KV cache quantizers. - Creates a new dict for the KV bmm quantizer config to avoid mutating shared references. + Updates the matched KV bmm quantizer entry in place. 
""" - for i, entry in enumerate(quant_cfg): + for entry in quant_cfg: if entry.get("quantizer_name") != "*[kv]_bmm_quantizer": continue cfg = entry.get("cfg") or {} - assert isinstance(cfg, dict) - quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}} + cfg["use_constant_amax"] = True + entry["cfg"] = cfg break -QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { +QUANT_CFG_CHOICES: dict[str, QuantizeConfig] = { "int8": mtq.INT8_DEFAULT_CFG, "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG, diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py index 93ef21ea4d4..66c38268d2f 100644 --- a/examples/llm_ptq/multinode_ptq.py +++ b/examples/llm_ptq/multinode_ptq.py @@ -22,7 +22,6 @@ import time import warnings from pathlib import Path -from typing import Any import numpy as np import torch @@ -37,14 +36,14 @@ from modelopt.torch.export import get_model_type from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint -from modelopt.torch.quantization.config import need_calibration +from modelopt.torch.quantization.config import QuantizeConfig, need_calibration from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets # Constants RAND_SEED = 1234 -QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { +QUANT_CFG_CHOICES: dict[str, QuantizeConfig] = { "int8": mtq.INT8_DEFAULT_CFG, "int4_awq": mtq.INT4_AWQ_CFG, "fp8": mtq.FP8_DEFAULT_CFG, diff --git a/examples/vllm_serve/vllm_ptq_utils.py b/examples/vllm_serve/vllm_ptq_utils.py index 88b31d54a70..1c7c6c734ef 100644 --- a/examples/vllm_serve/vllm_ptq_utils.py +++ b/examples/vllm_serve/vllm_ptq_utils.py @@ -119,11 +119,7 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: list) -> list: return kv_quant_cfg kv_entry = next( - ( - e - for e in kv_quant_cfg - if isinstance(e, dict) and e.get("quantizer_name") == "*[kv]_bmm_quantizer" - ), + (e for e in kv_quant_cfg if e.get("quantizer_name") == "*[kv]_bmm_quantizer"), None, ) if kv_entry is not None: diff --git a/modelopt/onnx/llm_export_utils/quantization_utils.py b/modelopt/onnx/llm_export_utils/quantization_utils.py index 54ca93d5388..3f5aa1d910f 100644 --- a/modelopt/onnx/llm_export_utils/quantization_utils.py +++ b/modelopt/onnx/llm_export_utils/quantization_utils.py @@ -69,9 +69,7 @@ def get_quant_config(precision, lm_head_precision="fp16"): else: raise ValueError(f"Unsupported precision: {precision}") - quant_cfg_list: list = [ - e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_name" in e - ] + quant_cfg_list: list = list(quant_cfg["quant_cfg"]) if lm_head_precision == "fp8": quant_cfg_list.append( diff --git a/modelopt/torch/opt/config.py b/modelopt/torch/opt/config.py index 62f7b7e16a2..874fe145a57 100644 --- a/modelopt/torch/opt/config.py +++ b/modelopt/torch/opt/config.py @@ -17,7 +17,7 @@ import fnmatch import json -from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView +from collections.abc import Callable, ItemsView, Iterator, KeysView, MutableMapping, ValuesView from typing import Any, TypeAlias import torch @@ -57,7 +57,7 @@ def ModeloptField(default: Any = PydanticUndefined, **kwargs): # noqa: N802 # TODO: expand config classes to searcher -class ModeloptBaseConfig(BaseModel): +class ModeloptBaseConfig(BaseModel, MutableMapping[str, Any]): """Our config base 
class for mode configuration. The base class extends the capabilities of pydantic's BaseModel to provide additional methods @@ -117,6 +117,22 @@ def __setitem__(self, key: str, value: Any) -> None: """Set the value for the given key (can be name or alias of field).""" setattr(self, self.get_field_name_from_key(key), value) + def __delitem__(self, key: str) -> None: + """Unset the given key so exclude_unset dumps omit it.""" + field_name = self.get_field_name_from_key(key) + if field_name in self._iterable_model_extra: + assert self.model_extra is not None + del self.model_extra[field_name] + self.model_fields_set.discard(field_name) + return + + field_info = type(self).model_fields[field_name] + default = field_info.get_default(call_default_factory=True) + if default is PydanticUndefined: + raise AttributeError(f"Key {key} cannot be unset because it has no default.") + setattr(self, field_name, default) + self.model_fields_set.discard(field_name) + def get(self, key: str, default: Any = None) -> Any: """Get the value for the given key (can be name or alias) or default if not found.""" try: diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index d9f7cd1b8fe..c67bd14a218 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -21,7 +21,7 @@ import warnings from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from contextlib import nullcontext from typing import Any @@ -65,7 +65,7 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg): if not quantizer_attr_cfg: return 1.0 return min(estimate_quant_compression_for_quantizer(q) for q in quantizer_attr_cfg) - if isinstance(quantizer_attr_cfg, dict): + if isinstance(quantizer_attr_cfg, Mapping): # Handle raw quantizer cfg dicts (e.g. {"num_bits": (4, 3), "axis": None}) if not quantizer_attr_cfg.get("enable", True): return 1.0 @@ -103,28 +103,38 @@ def estimate_quant_compression_for_quantizer(quantizer_attr_cfg): return estimate_quant_compression_for_quantizer(cfgs) if cfgs else 1.0 +QuantRecipeConfig = str | Mapping[str, Any] | QuantizeConfig | None + + class QuantRecipe(CustomHPType): """A subclass of QuantizeConfig enabling auto_quantize specific configurations. Args: - quant_cfg: str or dict or None. dict is used for custom quantization formats. + quant_cfg: str, QuantizeConfig, mapping, or None. A mapping is used for custom quantization formats. name: name for custom quantization formats. Only used if quantization format is a custom format not available in :mod:`modelopt.torch.quantization.config`. 
""" - def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | None = None): + def __init__(self, quant_cfg: QuantRecipeConfig = None, name: str | None = None): """Initialize the QuantRecipe with the quantization configuration.""" name = self.get_auto_name_for_config(quant_cfg) or name if quant_cfg is None: - quant_cfg = {"quant_cfg": [{"quantizer_name": "*", "enable": False}]} - elif isinstance(quant_cfg, str): - assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}" - quant_cfg = getattr(mtq_config, quant_cfg) + self.config = mtq_config.QuantizeConfig( + quant_cfg=[mtq_config.QuantizerCfgEntry(quantizer_name="*", enable=False)] + ) else: - assert name is not None, "name must be provided for custom quantization formats" - - self.config = mtq_config.QuantizeConfig(**quant_cfg) # type: ignore [arg-type] + if isinstance(quant_cfg, str): + assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}" + quant_cfg = getattr(mtq_config, quant_cfg) + elif not isinstance(quant_cfg, QuantizeConfig): + assert name is not None, "name must be provided for custom quantization formats" + + self.config = ( + quant_cfg.model_copy(deep=True) + if isinstance(quant_cfg, QuantizeConfig) + else mtq_config.QuantizeConfig.model_validate(quant_cfg) + ) # Disable KV Cache quantization # Currently KV Cache quantization is enabled for some quantization formats and disabled for others @@ -138,14 +148,25 @@ def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | No self._str_repr: str = f"{name}(effective-bits: {self.compression * 16})" @staticmethod - def get_auto_name_for_config(quant_cfg: str | dict[str, Any] | None) -> str | None: + def get_auto_name_for_config(quant_cfg: QuantRecipeConfig) -> str | None: """Get a name for the quantization configuration.""" if quant_cfg is None: return "NONE" if isinstance(quant_cfg, str): return quant_cfg + + candidate = ( + quant_cfg + if isinstance(quant_cfg, QuantizeConfig) + else mtq_config.QuantizeConfig.model_validate(quant_cfg) + ) for quant_cfg_name in mtq_config.choices: - if quant_cfg == getattr(mtq_config, quant_cfg_name): + choice = getattr(mtq_config, quant_cfg_name) + try: + choice = mtq_config.QuantizeConfig.model_validate(choice) + except Exception: + continue + if candidate == choice: return quant_cfg_name return None diff --git a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py index d89ed35c6ca..c45da9d4ff7 100644 --- a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py +++ b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py @@ -15,6 +15,8 @@ """This module provides a GEMM function for fp8 per tensor quantization.""" +from collections.abc import Mapping + import torch from torch.autograd import Function @@ -105,7 +107,7 @@ def _fp8_availability_check(module, input, args, kwargs): input_cfg = input_cfg[0] if isinstance(weight_cfg, list): weight_cfg = weight_cfg[0] - if not isinstance(input_cfg, dict) or not isinstance(weight_cfg, dict): + if not isinstance(input_cfg, Mapping) or not isinstance(weight_cfg, Mapping): return False # Check hardware support diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py index fdf6babb695..299f5851938 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -15,6 +15,8 @@ """This module provides a GEMM function for nvfp4 
quantization.""" +from collections.abc import Mapping + import torch from torch.autograd import Function @@ -224,7 +226,7 @@ def _nvfp4_availability_check(module, input, args, kwargs): input_cfg = input_cfg[0] if isinstance(weight_cfg, list): weight_cfg = weight_cfg[0] - if not isinstance(input_cfg, dict) or not isinstance(weight_cfg, dict): + if not isinstance(input_cfg, Mapping) or not isinstance(weight_cfg, Mapping): return False # Check input quantizer config diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 27edc7d9c92..c15a38dbe88 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -224,7 +224,7 @@ def _validate_recursive(value): if isinstance(value, list): for item in value: _validate_recursive(item) - elif isinstance(value, dict): + elif isinstance(value, Mapping): if len(value) == 1 and "enable" in value and value["enable"] is True: raise ValueError( "Invalid quantizer config: Cannot specify only {'enable': True}. " @@ -1007,8 +1007,7 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): ) -QuantizeQuantCfgType = list[QuantizerCfgEntry] -QuantizerCfgListConfig = QuantizeQuantCfgType +QuantizerCfgListConfig = list[QuantizerCfgEntry] QuantizeQuantCfgInputType = Sequence[QuantizerCfgEntry | Mapping[str, Any]] _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None @@ -1092,7 +1091,7 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: key = "*" if isinstance(key, str) and key.startswith("nn."): - if not isinstance(value, dict): + if not isinstance(value, Mapping): raise ValueError(f"For 'nn.*' scoped format, value must be a dict, got {value!r}") # Support multi-key nn.*-scoped dicts by emitting one entry per sub-key. entries: list[dict[str, Any]] = [] @@ -1114,7 +1113,7 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: entries.append(entry) return entries else: - if isinstance(value, dict): + if isinstance(value, Mapping): cfg = {k: val for k, val in value.items() if k != "enable"} or None enable = value.get("enable") else: @@ -1171,12 +1170,12 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: if enable and cfg is not None: if isinstance(cfg, QuantizerAttributeConfig): is_invalid = False - elif isinstance(cfg, dict): + elif isinstance(cfg, Mapping): is_invalid = len(cfg) == 0 elif isinstance(cfg, list): is_invalid = len(cfg) == 0 or any( - not isinstance(item, (dict, QuantizerAttributeConfig)) - or (isinstance(item, dict) and len(item) == 0) + not isinstance(item, (Mapping, QuantizerAttributeConfig)) + or (isinstance(item, Mapping) and len(item) == 0) for item in cfg ) else: @@ -1202,7 +1201,7 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: class QuantizeConfig(ModeloptBaseConfig): """Default configuration for ``quantize`` mode.""" - quant_cfg: QuantizeQuantCfgType = ModeloptField( + quant_cfg: QuantizerCfgListConfig = ModeloptField( default=[{"quantizer_name": "*", "cfg": {"num_bits": 8, "axis": None}}], title="Quantization configuration", validate_default=True, @@ -1247,11 +1246,6 @@ class _QuantizeExportConfig(ModeloptBaseConfig): """An empty config.""" -def _load_quantize_config_dict(config_path: str) -> dict[str, Any]: - """Load a schema-backed QuantizeConfig YAML while preserving public dict constants.""" - return load_config(config_path, schema_type=QuantizeConfig).model_dump(exclude_unset=True) - - def _quantizer_cfg_entry_to_dict(entry: QuantizerCfgEntry | Mapping[str, Any]) -> dict[str, Any]: """Dump 
a typed quant_cfg entry back to the public legacy dict shape.""" if isinstance(entry, QuantizerCfgEntry): @@ -1305,294 +1299,328 @@ def _load_quantizer_cfg_dict_list(config_path: str) -> list[dict[str, Any]]: }, # Skip QKV Output Projection (Mcore naming) ] -INT8_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, - {"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} +INT8_DEFAULT_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *_default_disabled_quantizer_cfg, + ], + "algorithm": "max", + } +) -INT8_SMOOTHQUANT_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, - {"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "smoothquant", -} +INT8_SMOOTHQUANT_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_name": "*input_quantizer", "cfg": {"num_bits": 8, "axis": None}}, + *_default_disabled_quantizer_cfg, + ], + "algorithm": "smoothquant", + } +) -INT8_WEIGHT_ONLY_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, - {"quantizer_name": "*input_quantizer", "enable": False}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} +INT8_WEIGHT_ONLY_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8, "axis": 0}}, + {"quantizer_name": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], + "algorithm": "max", + } +) -FP8_DEFAULT_CFG: dict[str, Any] = _load_quantize_config_dict("configs/ptq/presets/model/fp8") +FP8_DEFAULT_CFG: QuantizeConfig = load_config( + "configs/ptq/presets/model/fp8", schema_type=QuantizeConfig +) -MAMBA_MOE_FP8_AGGRESSIVE_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - *_default_disabled_quantizer_cfg, - *_mamba_moe_disabled_quantizer_cfg, - ], - "algorithm": "max", -} +MAMBA_MOE_FP8_AGGRESSIVE_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_name": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + ], + "algorithm": "max", + } +) -MAMBA_MOE_FP8_CONSERVATIVE_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - *_default_disabled_quantizer_cfg, - *_mamba_moe_disabled_quantizer_cfg, - {"quantizer_name": "*mixer.in_proj*", "enable": False}, # Skip mamba linear - {"quantizer_name": "*mixer.out_proj*", "enable": False}, # Skip mamba linear - ], - 
"algorithm": "max", -} +MAMBA_MOE_FP8_CONSERVATIVE_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + { + "quantizer_name": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + {"quantizer_name": "*mixer.in_proj*", "enable": False}, # Skip mamba linear + {"quantizer_name": "*mixer.out_proj*", "enable": False}, # Skip mamba linear + ], + "algorithm": "max", + } +) -FP8_PER_CHANNEL_PER_TOKEN_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": 0}}, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (4, 3), - "type": "dynamic", - "block_sizes": {-1: None}, +FP8_PER_CHANNEL_PER_TOKEN_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": (4, 3), "axis": 0}}, + { + "quantizer_name": "*input_quantizer", + "cfg": { + "num_bits": (4, 3), + "type": "dynamic", + "block_sizes": {-1: None}, + }, }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} + *_default_disabled_quantizer_cfg, + ], + "algorithm": "max", + } +) # FP8 2D blockwise fake quantization config for deepseek models -FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (4, 3), - "block_sizes": {-1: 128, -2: 128}, +FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (4, 3), + "block_sizes": {-1: 128, -2: 128}, + }, }, - }, - {"quantizer_name": "*input_quantizer", "enable": False}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} + {"quantizer_name": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], + "algorithm": "max", + } +) -INT4_BLOCKWISE_WEIGHT_ONLY_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": 4, - "block_sizes": {-1: 128}, +INT4_BLOCKWISE_WEIGHT_ONLY_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": 4, + "block_sizes": {-1: 128}, + }, }, - }, - {"quantizer_name": "*input_quantizer", "enable": False}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} + {"quantizer_name": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], + "algorithm": "max", + } +) -INT4_AWQ_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": 4, - "block_sizes": {-1: 128, "type": "static"}, +INT4_AWQ_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": 4, + "block_sizes": {-1: 128, "type": "static"}, + }, }, - }, - {"quantizer_name": "*input_quantizer", "enable": False}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": {"method": "awq_lite", "alpha_step": 0.1}, - # "algorithm": {"method": "awq_full", "alpha_step": 0.1, "max_co_batch_size": 1024}, - # "algorithm": {"method": "awq_clip", "max_co_batch_size": 2048}, -} + {"quantizer_name": "*input_quantizer", "enable": False}, + *_default_disabled_quantizer_cfg, + ], + 
"algorithm": {"method": "awq_lite", "alpha_step": 0.1}, + # "algorithm": {"method": "awq_full", "alpha_step": 0.1, "max_co_batch_size": 1024}, + # "algorithm": {"method": "awq_clip", "max_co_batch_size": 2048}, + } +) # W4A8 currently uses INT4 blockwise quantization (block size = 128) followed by FP8 quantization # for weights. This could change in the future -W4A8_AWQ_BETA_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": [ - { - "num_bits": 4, - "block_sizes": {-1: 128, "type": "static"}, - }, - { +W4A8_AWQ_BETA_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": [ + { + "num_bits": 4, + "block_sizes": {-1: 128, "type": "static"}, + }, + { + "num_bits": (4, 3), + }, + ], + }, + { + "quantizer_name": "*input_quantizer", + "cfg": { "num_bits": (4, 3), }, - ], - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (4, 3), }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "awq_lite", -} + *_default_disabled_quantizer_cfg, + ], + "algorithm": "awq_lite", + } +) -MXFP8_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (4, 3), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, +MXFP8_DEFAULT_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (4, 3), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (4, 3), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + { + "quantizer_name": "*input_quantizer", + "cfg": { + "num_bits": (4, 3), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} + *_default_disabled_quantizer_cfg, + ], + "algorithm": None, + } +) -MXFP6_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (3, 2), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, +MXFP6_DEFAULT_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (3, 2), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (3, 2), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + { + "quantizer_name": "*input_quantizer", + "cfg": { + "num_bits": (3, 2), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} + *_default_disabled_quantizer_cfg, + ], + "algorithm": None, + } +) -MXFP4_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, +MXFP4_DEFAULT_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + { + 
"quantizer_name": "*input_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} + *_default_disabled_quantizer_cfg, + ], + "algorithm": None, + } +) -W4A8_MXFP4_FP8_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, +W4A8_MXFP4_FP8_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": {"num_bits": (4, 3), "axis": None}, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} + { + "quantizer_name": "*input_quantizer", + "cfg": {"num_bits": (4, 3), "axis": None}, + }, + *_default_disabled_quantizer_cfg, + ], + "algorithm": None, + } +) -MXINT8_DEFAULT_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": 8, - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, +MXINT8_DEFAULT_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": 8, + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": 8, - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + { + "quantizer_name": "*input_quantizer", + "cfg": { + "num_bits": 8, + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} + *_default_disabled_quantizer_cfg, + ], + "algorithm": None, + } +) # KV-cache configs are designed to be merged with a primary quantization config (e.g. # FP8_DEFAULT_CFG) that already contains _base_disable_all. They intentionally omit both # _base_disable_all and "algorithm" because these are provided by the primary config. 
-FP8_KV_CFG: dict[str, Any] = _load_quantize_config_dict("configs/ptq/presets/kv/fp8") - -FP8_AFFINE_KV_CFG = { - "quant_cfg": [ - { - "quantizer_name": "*[kv]_bmm_quantizer", - "cfg": { - "num_bits": (4, 3), - "bias": {-2: None, -4: None, "type": "static"}, +FP8_KV_CFG: QuantizeConfig = load_config("configs/ptq/presets/kv/fp8", schema_type=QuantizeConfig) + +FP8_AFFINE_KV_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + { + "quantizer_name": "*[kv]_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + "bias": {-2: None, -4: None, "type": "static"}, + }, }, - }, - ] -} + ] + } +) _nvfp4_cfg = { "num_bits": (2, 1), @@ -1611,7 +1639,7 @@ def _nvfp4_selective_quant_cfg( quantizer: dict = _nvfp4_cfg, weight_only: bool = False, algorithm: str | dict = "max", -) -> dict: +) -> QuantizeConfig: """Build an NVFP4 config that quantizes only the specified layer patterns.""" quant_cfg: list[dict[str, Any]] = [] quant_cfg.extend(_base_disable_all) @@ -1628,71 +1656,79 @@ def _nvfp4_selective_quant_cfg( } ) quant_cfg.extend(_default_disabled_quantizer_cfg) - return {"quant_cfg": quant_cfg, "algorithm": algorithm} + return QuantizeConfig.model_validate({"quant_cfg": quant_cfg, "algorithm": algorithm}) NVFP4_DEFAULT_CFG = _nvfp4_selective_quant_cfg(["*"]) -NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, +NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + }, }, + {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, + *_default_disabled_quantizer_cfg, + ], + "algorithm": { + "method": "mse", + "fp8_scale_sweep": True, }, - {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": { - "method": "mse", - "fp8_scale_sweep": True, - }, -} + } +) -NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, +NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 16, "type": "static", "scale_bits": (4, 3)}, + }, }, + {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, + *_default_disabled_quantizer_cfg, + ], + "algorithm": { + "method": "local_hessian", + "fp8_scale_sweep": True, }, - {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, - *_default_disabled_quantizer_cfg, - ], - "algorithm": { - "method": "local_hessian", - "fp8_scale_sweep": True, - }, -} + } +) -MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": _nvfp4_cfg}, - {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, - *_default_disabled_quantizer_cfg, - *_mamba_moe_disabled_quantizer_cfg, - ], - "algorithm": "max", -} -MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": _nvfp4_cfg}, - {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, - *_default_disabled_quantizer_cfg, - *_mamba_moe_disabled_quantizer_cfg, - 
{"quantizer_name": "*mixer.in_proj*", "enable": False}, # Skip mamba linear - {"quantizer_name": "*mixer.out_proj*", "enable": False}, # Skip mamba linear - ], - "algorithm": "max", -} +MAMBA_MOE_NVFP4_AGGRESSIVE_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + {"quantizer_name": "*weight_quantizer", "cfg": _nvfp4_cfg}, + {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + ], + "algorithm": "max", + } +) +MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + {"quantizer_name": "*weight_quantizer", "cfg": _nvfp4_cfg}, + {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, + *_default_disabled_quantizer_cfg, + *_mamba_moe_disabled_quantizer_cfg, + {"quantizer_name": "*mixer.in_proj*", "enable": False}, # Skip mamba linear + {"quantizer_name": "*mixer.out_proj*", "enable": False}, # Skip mamba linear + ], + "algorithm": "max", + } +) NVFP4_AWQ_LITE_CFG = _nvfp4_selective_quant_cfg(["*"], algorithm="awq_lite") @@ -1703,135 +1739,147 @@ def _nvfp4_selective_quant_cfg( ) # See comment above FP8_KV_CFG — KV-cache configs omit _base_disable_all and "algorithm". -NVFP4_AFFINE_KV_CFG = { - "quant_cfg": [ - { - "quantizer_name": "*[kv]_bmm_quantizer", - "cfg": { - **_nvfp4_cfg, - "bias": {-2: None, -4: None, "type": "static"}, +NVFP4_AFFINE_KV_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + { + "quantizer_name": "*[kv]_bmm_quantizer", + "cfg": { + **_nvfp4_cfg, + "bias": {-2: None, -4: None, "type": "static"}, + }, }, - }, - ] -} + ] + } +) -NVFP4_KV_CFG = { - "quant_cfg": [ - {"quantizer_name": "*[kv]_bmm_quantizer", "cfg": _nvfp4_cfg}, - ] -} +NVFP4_KV_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + {"quantizer_name": "*[kv]_bmm_quantizer", "cfg": _nvfp4_cfg}, + ] + } +) # Moved from examples/diffusers/quantization/config.py to here -NVFP4_FP8_MHA_CONFIG = { - "quant_cfg": [ - *_base_disable_all, - {"quantizer_name": "*weight_quantizer", "cfg": _nvfp4_cfg}, - {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, - {"quantizer_name": "*output_quantizer", "enable": False}, - { - "quantizer_name": "*q_bmm_quantizer", - "cfg": { - "num_bits": (4, 3), +NVFP4_FP8_MHA_CONFIG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + {"quantizer_name": "*weight_quantizer", "cfg": _nvfp4_cfg}, + {"quantizer_name": "*input_quantizer", "cfg": _nvfp4_cfg}, + {"quantizer_name": "*output_quantizer", "enable": False}, + { + "quantizer_name": "*q_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - }, - { - "quantizer_name": "*k_bmm_quantizer", - "cfg": { - "num_bits": (4, 3), + { + "quantizer_name": "*k_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - }, - { - "quantizer_name": "*v_bmm_quantizer", - "cfg": { - "num_bits": (4, 3), + { + "quantizer_name": "*v_bmm_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - }, - { - "quantizer_name": "*softmax_quantizer", - "cfg": { - "num_bits": (4, 3), + { + "quantizer_name": "*softmax_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - }, - { - "quantizer_name": "transformer_blocks*bmm2_output_quantizer", - "cfg": { - "num_bits": (4, 3), + { + "quantizer_name": "transformer_blocks*bmm2_output_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - }, - ], - "algorithm": "max", -} + ], + "algorithm": "max", + } +) # See comment above FP8_KV_CFG — KV-cache configs omit _base_disable_all and "algorithm". 
-NVFP4_KV_ROTATE_CFG = { - "quant_cfg": [ - { - # q_bmm is disabled but pre-configured with rotate=True so that downstream - # code can inspect the rotate flag even while the quantizer is off. - "quantizer_name": "*q_bmm_quantizer", - "cfg": { - "rotate": True, +NVFP4_KV_ROTATE_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + { + # q_bmm is disabled but pre-configured with rotate=True so that downstream + # code can inspect the rotate flag even while the quantizer is off. + "quantizer_name": "*q_bmm_quantizer", + "cfg": { + "rotate": True, + }, + "enable": False, }, - "enable": False, - }, - { - "quantizer_name": "*k_bmm_quantizer", - "cfg": { - **_nvfp4_cfg, - "rotate": True, + { + "quantizer_name": "*k_bmm_quantizer", + "cfg": { + **_nvfp4_cfg, + "rotate": True, + }, }, - }, - {"quantizer_name": "*v_bmm_quantizer", "cfg": _nvfp4_cfg}, - ], - "algorithm": "max", -} + {"quantizer_name": "*v_bmm_quantizer", "cfg": _nvfp4_cfg}, + ], + "algorithm": "max", + } +) NVFP4_SVDQUANT_DEFAULT_CFG = _nvfp4_selective_quant_cfg( ["*"], algorithm={"method": "svdquant", "lowrank": 32} ) -W4A8_NVFP4_FP8_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, +W4A8_NVFP4_FP8_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (4, 3)}, + }, }, - }, - { - "quantizer_name": "*input_quantizer", - "cfg": { - "num_bits": (4, 3), + { + "quantizer_name": "*input_quantizer", + "cfg": { + "num_bits": (4, 3), + }, }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": "max", -} + *_default_disabled_quantizer_cfg, + ], + "algorithm": "max", + } +) -MXFP4_MLP_WEIGHT_ONLY_CFG = { - "quant_cfg": [ - *_base_disable_all, - { - "quantizer_name": "*mlp*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, +MXFP4_MLP_WEIGHT_ONLY_CFG = QuantizeConfig.model_validate( + { + "quant_cfg": [ + *_base_disable_all, + { + "quantizer_name": "*mlp*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - { - "quantizer_name": "*block_sparse_moe*weight_quantizer", - "cfg": { - "num_bits": (2, 1), - "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + { + "quantizer_name": "*block_sparse_moe*weight_quantizer", + "cfg": { + "num_bits": (2, 1), + "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}, + }, }, - }, - *_default_disabled_quantizer_cfg, - ], - "algorithm": None, -} + *_default_disabled_quantizer_cfg, + ], + "algorithm": None, + } +) NVFP4_MLP_WEIGHT_ONLY_CFG = _nvfp4_selective_quant_cfg( ["*mlp*", "*block_sparse_moe*"], quantizer=_nvfp4_cfg_bs32, weight_only=True diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 55d48691de0..ffade36c411 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -18,7 +18,7 @@ import fnmatch import re import warnings -from collections.abc import Callable +from collections.abc import Callable, Mapping from contextlib import contextmanager from typing import Any, cast @@ -415,7 +415,7 @@ def set_quantizer_attributes_full( def set_quantizer_attributes_partial( quant_model: nn.Module, wildcard_or_filter_func: str | 
Callable, - partial_attributes: dict[str, Any] | list[dict[str, Any]], + partial_attributes: Mapping[str, Any] | list[Mapping[str, Any]], parent_class: type[nn.Module] | None = None, ): """Update a subset of quantizer attributes by wildcard or filter function, merging with existing attributes. @@ -450,14 +450,14 @@ def set_quantizer_attributes_partial( an instance of this class. If ``None``, all quantizers matching ``wildcard_or_filter_func`` are adjusted. """ - if not isinstance(partial_attributes, (dict, list)): + if not isinstance(partial_attributes, (Mapping, list)): raise ValueError( - f"Invalid type for attributes: {type(partial_attributes)}, expected dictionary or list of dict." + f"Invalid type for attributes: {type(partial_attributes)}, expected mapping or list of mappings." ) if isinstance(partial_attributes, list) and not all( - isinstance(attr, dict) for attr in partial_attributes + isinstance(attr, Mapping) for attr in partial_attributes ): - raise ValueError("All elements in attributes list must be of type dict.") + raise ValueError("All elements in attributes list must be mappings.") for name, module in quant_model.named_modules(): if _match_quantizer(wildcard_or_filter_func, name, module, parent_class, quant_model): diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 713cdd7373c..9c41664ba5d 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -16,7 +16,7 @@ """This module contains the mode descriptor for the quantization mode.""" from abc import abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Mapping from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.opt.conversion import ModelLikeModule @@ -374,7 +374,7 @@ def get_modelike_from_algo_cfg(algo_cfg: QuantizeAlgoCfgType) -> ModeConfigList: return [get_modelike_from_algo_cfg(c)[0] for c in algo_cfg] if algo_cfg is None or isinstance(algo_cfg, str): algo_name, algo_cfg = algo_cfg, {} - elif isinstance(algo_cfg, dict): + elif isinstance(algo_cfg, Mapping): algo_name = algo_cfg["method"] else: raise ValueError(f"Invalid config type: {type(algo_cfg)}") diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 589ec477786..0e8f34f101d 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -19,7 +19,7 @@ import inspect import os import warnings -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Mapping from typing import Any import torch @@ -143,7 +143,7 @@ def postprocess_amax(model: nn.Module, key: str, post_process_fn) -> nn.Module: def quantize( model: nn.Module, - config: dict[str, Any | QuantizeConfig], + config: QuantizeConfig | Mapping[str, Any], forward_loop: ForwardLoop | None = None, ) -> nn.Module: """Quantizes and calibrates the model in-place. @@ -240,13 +240,15 @@ def forward_loop(model) -> None: Returns: A pytorch model which has been quantized and calibrated. 
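+
+    Example (a minimal sketch; ``config`` may be a typed preset or an equivalent mapping)::
+
+        import modelopt.torch.quantization as mtq
+
+        model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)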
""" + validated_config: QuantizeConfig = QuantizeConfig.model_validate(config) if not is_quantized(model): - model = apply_mode(model, mode=[("quantize", dict(config))], registry=QuantizeModeRegistry) + model = apply_mode( + model, mode=[("quantize", dict(validated_config))], registry=QuantizeModeRegistry + ) else: # Already quantized, so lets apply the quant_cfg from the config - quant_cfg = QuantizeConfig(**dict(config)).quant_cfg - set_quantizer_by_cfg(model, quant_cfg) - return calibrate(model, config.get("algorithm"), forward_loop=forward_loop) + set_quantizer_by_cfg(model, validated_config.quant_cfg) + return calibrate(model, validated_config.algorithm, forward_loop=forward_loop) # TODO: create a config interface for auto_quantize and expose setting @@ -271,7 +273,7 @@ def forward_loop(model) -> None: def auto_quantize( model: nn.Module, constraints: dict[str, float | str] = {"effective_bits": 4.8}, - quantization_formats: list[dict[str, Any] | str] = [ + quantization_formats: list[QuantizeConfig | Mapping[str, Any] | str | None] = [ mtq.NVFP4_AWQ_LITE_CFG, mtq.FP8_DEFAULT_CFG, ], @@ -500,7 +502,7 @@ def forward_backward_step(model, batch) -> None: for quant_cfg, name in processed_quantization_formats: algo = QuantRecipe(quant_cfg, name=name).config.algorithm - algo_method = algo["method"] if isinstance(algo, dict) else algo + algo_method = algo["method"] if isinstance(algo, Mapping) else algo if algo_method not in _AUTO_QUANTIZE_SUPPORTED_ALGORITHMS: raise ValueError( f"Algorithm '{algo_method}' in '{name}' is not supported by auto_quantize yet. " diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3ff7401ec3e..f88ebb09594 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -18,7 +18,7 @@ import contextlib import math import warnings -from collections.abc import Callable +from collections.abc import Callable, Mapping from typing import Any, Protocol import torch @@ -203,7 +203,9 @@ def __init__( # Optional quantizer cache for caching quantizer related encoding or tensors. self._quantizer_cache = None - def set_from_attribute_config(self, attribute_cfg: QuantizerAttributeConfig | dict[str, Any]): + def set_from_attribute_config( + self, attribute_cfg: QuantizerAttributeConfig | Mapping[str, Any] + ): """Set quantizer attributes from attribute_cfg. The attributes are defined in @@ -1423,13 +1425,11 @@ def get_modelopt_state(self) -> dict[str, Any]: return {"num_quantizers": len(self), "is_sequential_quantizer": True} def set_from_attribute_config( - self, attributes: list[QuantizerAttributeConfig] | list[dict[str, Any]] + self, attributes: list[QuantizerAttributeConfig] | list[Mapping[str, Any]] ): """Set the attributes of contained quantizers from a list of attribute_dicts.""" if not isinstance(attributes, (list, tuple)): - assert isinstance(attributes, (dict, QuantizerAttributeConfig)), ( - "attributes must be a list or a dict." - ) + assert isinstance(attributes, Mapping), "attributes must be a list or a mapping." 
             attributes = [attributes] * len(self)
 
         for attribute, quantizer in zip(attributes, self):
diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py
index 1a177e04dc8..22525c414c3 100644
--- a/modelopt/torch/quantization/utils/core_utils.py
+++ b/modelopt/torch/quantization/utils/core_utils.py
@@ -17,6 +17,7 @@
 
 import copy
 from collections import namedtuple
+from collections.abc import Mapping, Sequence
 from contextlib import ExitStack, contextmanager, nullcontext
 from typing import TYPE_CHECKING, Any
 
@@ -27,7 +28,7 @@
 from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
 from torch.distributed.tensor import Replicate
 
-from modelopt.torch.quantization.config import QuantizerCfgEntry
+from modelopt.torch.quantization.config import QuantizeConfig, QuantizerCfgEntry
 from modelopt.torch.utils import get_unwrapped_name, print_rank_0
 
 if TYPE_CHECKING:
@@ -915,12 +916,13 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
 
 
 def update_quant_cfg_with_kv_cache_quant(
-    quant_cfg: dict[str, Any], kv_cache_quant_cfg: list[QuantizerCfgEntry]
-) -> dict[str, Any]:
+    quant_cfg: QuantizeConfig | Mapping[str, Any],
+    kv_cache_quant_cfg: Sequence[QuantizerCfgEntry | Mapping[str, Any]],
+) -> QuantizeConfig:
     """Update the quant_cfg with the kv cache quant_cfg.
 
     Args:
-        quant_cfg: The outer quantization config dict (with ``"quant_cfg"`` and ``"algorithm"`` keys).
+        quant_cfg: The outer quantization config (with ``"quant_cfg"`` and ``"algorithm"`` keys).
         kv_cache_quant_cfg: A list of :class:`QuantizerCfgEntry
             <modelopt.torch.quantization.config.QuantizerCfgEntry>` dicts for KV cache quantization, typically ``some_kv_cfg["quant_cfg"]``.
@@ -929,17 +931,17 @@ def update_quant_cfg_with_kv_cache_quant(
-        A deep copy of ``quant_cfg`` with the KV cache entries appended to
-        ``quant_cfg["quant_cfg"]``.
+        A validated :class:`QuantizeConfig` deep copy of ``quant_cfg``, with the KV cache
+        entries appended to its ``quant_cfg`` list.
     """
     # If quant_cfg["quant_cfg"] is None, it corresponds to only kv cache quantization case
-    quant_cfg = copy.deepcopy(quant_cfg)
-    inner: list[QuantizerCfgEntry] = quant_cfg.get("quant_cfg") or [
-        {"quantizer_name": "*", "enable": False}
-    ]
-    quant_cfg["quant_cfg"] = inner + list(kv_cache_quant_cfg)
+    updated_quant_cfg: QuantizeConfig = QuantizeConfig.model_validate(copy.deepcopy(quant_cfg))
+    inner = list(
+        updated_quant_cfg.get("quant_cfg") or [QuantizerCfgEntry(quantizer_name="*", enable=False)]
+    )
+    updated_quant_cfg["quant_cfg"] = inner + list(kv_cache_quant_cfg)
     # Set default algorithm for kv cache quantization if not provided.
- if not quant_cfg.get("algorithm"): - quant_cfg["algorithm"] = "max" - print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}") - return quant_cfg + if not updated_quant_cfg.get("algorithm"): + updated_quant_cfg["algorithm"] = "max" + print_rank_0(f"Updated quant_cfg with KV cache quantization: {updated_quant_cfg}") + return updated_quant_cfg def promote_nvfp4_static_quantizers(model: nn.Module) -> int: diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 3bcb687c1af..015c8b12845 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -15,6 +15,8 @@ """Test of quantization config validations.""" +from collections.abc import MutableMapping + import pytest from pydantic import ValidationError @@ -103,10 +105,34 @@ def test_quantizer_cfg_entry_is_pydantic_and_dict_like(): assert _cfg_to_dict(cfg_entry["cfg"]) == {"num_bits": 8} -def test_public_preset_quant_cfg_entries_remain_dicts(): - """Public preset constants keep legacy dict entries for downstream compatibility.""" - assert all(isinstance(entry, dict) for entry in FP8_DEFAULT_CFG["quant_cfg"]) - assert all(isinstance(entry, dict) for entry in NVFP4_DEFAULT_CFG["quant_cfg"]) +def test_quantizer_cfg_entry_mutable_mapping_delitem_unsets_field(): + """Deleting a config key resets it to unset for exclude_unset dumps.""" + entry = QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={"num_bits": 8}, enable=True) + assert isinstance(entry, MutableMapping) + assert entry.model_dump(exclude_unset=True) == { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": 8}, + "enable": True, + } + + del entry["cfg"] + + assert entry["cfg"] is None + assert entry.model_dump(exclude_unset=True) == { + "quantizer_name": "*weight_quantizer", + "enable": True, + } + + +def test_public_preset_quant_cfg_entries_are_typed_and_dict_like(): + """Public preset constants are typed but keep dict-style entry access.""" + assert isinstance(FP8_DEFAULT_CFG, QuantizeConfig) + assert isinstance(NVFP4_DEFAULT_CFG, QuantizeConfig) + for preset in (FP8_DEFAULT_CFG, NVFP4_DEFAULT_CFG): + assert all(isinstance(entry, QuantizerCfgEntry) for entry in preset["quant_cfg"]) + for entry in preset["quant_cfg"]: + assert entry["quantizer_name"] == entry.quantizer_name + assert dict(entry.items())["quantizer_name"] == entry.quantizer_name def test_quantizer_cfg_entry_rejects_no_effect_entry(): From b23e3a9d58cd68d7205a8c17813220d7af91e940 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 7 May 2026 15:13:42 -0700 Subject: [PATCH 07/14] Add mixed quant config normalization test Cover normalization after mutating raw dict quantizer entries and schema-backed ModeloptBaseConfig entries. 
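An illustrative shape of the covered flow (names as used in the test body):

    entries = normalize_quant_cfg_list(config["quant_cfg"])
    raw = find_quant_cfg_entry_by_path(config["quant_cfg"], "*weight_quantizer")
    raw["cfg"]["num_bits"] = "e2m1"  # mutate the raw dict entry in place
    entries = normalize_quant_cfg_list(config["quant_cfg"])  # re-normalization picks it up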
Signed-off-by: Shengliang Xu --- .../quantization/test_config_validation.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 015c8b12845..0db4bf241e7 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -15,6 +15,7 @@ """Test of quantization config validations.""" +import copy from collections.abc import MutableMapping import pytest @@ -31,6 +32,8 @@ QuantizeConfig, QuantizerAttributeConfig, QuantizerCfgEntry, + _base_disable_all, + _default_disabled_quantizer_cfg, find_quant_cfg_entry_by_path, need_calibration, normalize_quant_cfg_list, @@ -135,6 +138,45 @@ def test_public_preset_quant_cfg_entries_are_typed_and_dict_like(): assert dict(entry.items())["quantizer_name"] == entry.quantizer_name +def test_mixed_raw_dict_and_modelopt_config_entries_normalize_after_mutation(): + """Mixed raw dict and ModeloptBaseConfig entries normalize after mutation.""" + config = { + "quant_cfg": [ + *copy.deepcopy(_base_disable_all), + { + "quantizer_name": "*weight_quantizer", + "cfg": { + "backend": "custom_backend", + "num_bits": "e5m2", + "pass_through_bwd": True, + "backend_extra_args": { + "format": "scalar", + "block_sizes": 16, + }, + }, + }, + {"quantizer_name": "*input_quantizer", "enable": False}, + *copy.deepcopy(_default_disabled_quantizer_cfg), + ], + "algorithm": "max", + } + + normalized = normalize_quant_cfg_list(config["quant_cfg"]) + weight_entry = find_quant_cfg_entry_by_path(normalized, "*weight_quantizer") + assert weight_entry["cfg"]["num_bits"] == "e5m2" + + raw_weight_entry = find_quant_cfg_entry_by_path(config["quant_cfg"], "*weight_quantizer") + raw_weight_entry["cfg"]["num_bits"] = "e2m1" + normalized = normalize_quant_cfg_list(config["quant_cfg"]) + weight_entry = find_quant_cfg_entry_by_path(normalized, "*weight_quantizer") + assert weight_entry["cfg"]["num_bits"] == "e2m1" + + weight_entry["cfg"]["num_bits"] = "e4m3" + renormalized = normalize_quant_cfg_list(normalized) + weight_entry = find_quant_cfg_entry_by_path(renormalized, "*weight_quantizer") + assert weight_entry["cfg"]["num_bits"] == "e4m3" + + def test_quantizer_cfg_entry_rejects_no_effect_entry(): """Direct QuantizerCfgEntry construction rejects entries with no cfg or enable.""" with pytest.raises(ValidationError, match="must specify 'cfg', 'enable'"): From 5969cb3799032b6af09fd1325f717747309756e8 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 7 May 2026 18:02:20 -0700 Subject: [PATCH 08/14] Address quant config review feedback Signed-off-by: Shengliang Xu --- examples/llm_ptq/example_utils.py | 14 ++++-- examples/vllm_serve/vllm_ptq_utils.py | 8 +++- modelopt/torch/opt/config.py | 9 ++-- modelopt/torch/quantization/config.py | 46 ++++++++++++++++--- modelopt/torch/quantization/conversion.py | 4 +- .../nn/modules/tensor_quantizer.py | 2 + .../torch/quantization/utils/core_utils.py | 2 +- .../quantization/test_config_validation.py | 39 ++++++++++++++++ .../torch/quantization/test_quantize_cpu.py | 6 +++ 9 files changed, 110 insertions(+), 20 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index ba3a878bfc9..99bdee253ef 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -42,7 +42,7 @@ ProcessorMixin, ) -from modelopt.torch.quantization.config import QuantizeConfig +from modelopt.torch.quantization.config 
import QuantizeConfig, QuantizerCfgEntry try: from huggingface_hub import snapshot_download @@ -249,10 +249,14 @@ def build_quant_cfg( if model_type == "phi4mm": # Only quantize the language model - quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*speech*", "enable": False}) - quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*audio*", "enable": False}) - quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*image*", "enable": False}) - quant_cfg_obj["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False}) + quant_cfg_obj["quant_cfg"].extend( + [ + QuantizerCfgEntry(quantizer_name="*speech*", enable=False), + QuantizerCfgEntry(quantizer_name="*audio*", enable=False), + QuantizerCfgEntry(quantizer_name="*image*", enable=False), + QuantizerCfgEntry(quantizer_name="*vision*", enable=False), + ] + ) return quant_cfg_obj diff --git a/examples/vllm_serve/vllm_ptq_utils.py b/examples/vllm_serve/vllm_ptq_utils.py index 1c7c6c734ef..048b1e8722e 100644 --- a/examples/vllm_serve/vllm_ptq_utils.py +++ b/examples/vllm_serve/vllm_ptq_utils.py @@ -14,7 +14,7 @@ # limitations under the License. import dataclasses -from collections.abc import Callable +from collections.abc import Callable, Mapping from typing import Any import torch @@ -119,7 +119,11 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: list) -> list: return kv_quant_cfg kv_entry = next( - (e for e in kv_quant_cfg if e.get("quantizer_name") == "*[kv]_bmm_quantizer"), + ( + e + for e in kv_quant_cfg + if isinstance(e, Mapping) and e.get("quantizer_name") == "*[kv]_bmm_quantizer" + ), None, ) if kv_entry is not None: diff --git a/modelopt/torch/opt/config.py b/modelopt/torch/opt/config.py index 874fe145a57..0ef1f70a838 100644 --- a/modelopt/torch/opt/config.py +++ b/modelopt/torch/opt/config.py @@ -119,7 +119,10 @@ def __setitem__(self, key: str, value: Any) -> None: def __delitem__(self, key: str) -> None: """Unset the given key so exclude_unset dumps omit it.""" - field_name = self.get_field_name_from_key(key) + try: + field_name = self.get_field_name_from_key(key) + except AttributeError as e: + raise KeyError(key) from e if field_name in self._iterable_model_extra: assert self.model_extra is not None del self.model_extra[field_name] @@ -129,8 +132,8 @@ def __delitem__(self, key: str) -> None: field_info = type(self).model_fields[field_name] default = field_info.get_default(call_default_factory=True) if default is PydanticUndefined: - raise AttributeError(f"Key {key} cannot be unset because it has no default.") - setattr(self, field_name, default) + raise KeyError(f"Key {key} cannot be unset because it has no default.") + self.__dict__[field_name] = default self.model_fields_set.discard(field_name) def get(self, key: str, default: Any = None) -> Any: diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index c15a38dbe88..87f3a1f8e85 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -561,6 +561,10 @@ def validate_quantizer_cfg_entry(cls, values): "Each quant_cfg entry must specify 'cfg', 'enable', or both. " "An entry with only 'quantizer_name' has no effect." 
) + if "cfg" in values and values["cfg"] is None: + raise ValueError("cfg must be omitted or a valid mapping/list, not null.") + if "enable" in values and values["enable"] is None: + raise ValueError("enable must be a boolean when provided, not null.") cfg = values.get("cfg") enable = values.get("enable", True) @@ -1008,7 +1012,7 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): QuantizerCfgListConfig = list[QuantizerCfgEntry] -QuantizeQuantCfgInputType = Sequence[QuantizerCfgEntry | Mapping[str, Any]] +QuantizeQuantCfgInputType = Mapping[str, Any] | Sequence[QuantizerCfgEntry | Mapping[str, Any]] _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None @@ -1016,7 +1020,7 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): def normalize_quant_cfg_list( - v: Mapping[str, Any] | list[QuantizerCfgEntry | Mapping[str, Any]], + v: Mapping[str, Any] | Sequence[QuantizerCfgEntry | Mapping[str, Any]], ) -> list[QuantizerCfgEntry]: """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` objects. @@ -1099,15 +1103,19 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: if isinstance(sub_cfg, QuantizerAttributeConfig): enable = None cfg = sub_cfg - else: + elif isinstance(sub_cfg, Mapping): sub_cfg = dict(sub_cfg) enable = sub_cfg.pop("enable", None) cfg = sub_cfg or None + else: + enable = None + cfg = sub_cfg entry: dict[str, Any] = { "parent_class": key, "quantizer_name": q_path, - "cfg": cfg, } + if cfg is not None: + entry["cfg"] = cfg if enable is not None: entry["enable"] = enable entries.append(entry) @@ -1119,7 +1127,9 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: else: cfg = value enable = None - entry = {"quantizer_name": key, "cfg": cfg} + entry = {"quantizer_name": key} + if cfg is not None: + entry["cfg"] = cfg if enable is not None: entry["enable"] = enable return [entry] @@ -1165,6 +1175,17 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: # Validate: when cfg is present and enable=True, cfg must be a non-empty # dict or list. An empty cfg would attempt to create a # QuantizerAttributeConfig with no actual configuration. + if "cfg" in entry and entry["cfg"] is None: + raise ValueError( + f"Invalid quant_cfg entry: {raw!r} - 'cfg' must be omitted or a " + "valid mapping/list, not null." + ) + if "enable" in entry and entry["enable"] is None: + raise ValueError( + f"Invalid quant_cfg entry: {raw!r} - 'enable' must be a boolean " + "when provided, not null." + ) + cfg = entry.get("cfg") enable = entry.get("enable", True) if enable and cfg is not None: @@ -1190,9 +1211,8 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: "explicitly." ) - # Normalize: make enable and cfg always explicit. + # Normalize: make enable explicit. cfg remains omitted when it is intentionally unset. 
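+            # Leaving "cfg" unset (instead of defaulting it to None) keeps
+            # model_dump(exclude_unset=True) faithful to the caller's sparse input.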
entry.setdefault("enable", True) - entry.setdefault("cfg", None) result.append(QuantizerCfgEntry.model_validate(entry)) return result @@ -1201,6 +1221,18 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: class QuantizeConfig(ModeloptBaseConfig): """Default configuration for ``quantize`` mode.""" + def model_dump(self, **kwargs): + """Dump quant_cfg entries without unset optional fields.""" + data = super().model_dump(**kwargs) + if "quant_cfg" in data: + data["quant_cfg"] = [ + entry.model_dump(exclude_unset=True) + if isinstance(entry, QuantizerCfgEntry) + else {k: v for k, v in entry.items() if v is not None} + for entry in self.quant_cfg + ] + return data + quant_cfg: QuantizerCfgListConfig = ModeloptField( default=[{"quantizer_name": "*", "cfg": {"num_bits": 8, "axis": None}}], title="Quantization configuration", diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index ffade36c411..1177a342c6e 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -249,7 +249,7 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInpu See :ref:`quant-cfg` for the full format reference and common patterns. """ - quant_cfg = normalize_quant_cfg_list(list(quant_cfg)) + quant_cfg = normalize_quant_cfg_list(quant_cfg) for entry in quant_cfg: quantizer_name: str = entry["quantizer_name"] @@ -496,7 +496,7 @@ def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuan Yields: None — the context body runs with the new quantizer attributes active. """ - quant_cfg = normalize_quant_cfg_list(list(quant_cfg)) + quant_cfg = normalize_quant_cfg_list(quant_cfg) for entry in quant_cfg: if isinstance(entry.get("cfg"), list): diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index f88ebb09594..f0871f3d291 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1431,6 +1431,8 @@ def set_from_attribute_config( if not isinstance(attributes, (list, tuple)): assert isinstance(attributes, Mapping), "attributes must be a list or a mapping." attributes = [attributes] * len(self) + elif len(attributes) != len(self): + raise ValueError(f"Expected {len(self)} attribute configs, but got {len(attributes)}.") for attribute, quantizer in zip(attributes, self): quantizer.set_from_attribute_config(attribute) diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 22525c414c3..ab4c11b388f 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -935,7 +935,7 @@ def update_quant_cfg_with_kv_cache_quant( inner = list( updated_quant_cfg.get("quant_cfg") or [QuantizerCfgEntry(quantizer_name="*", enable=False)] ) - updated_quant_cfg["quant_cfg"] = inner + list(kv_cache_quant_cfg) + updated_quant_cfg["quant_cfg"] = inner + copy.deepcopy(list(kv_cache_quant_cfg)) # Set default algorithm for kv cache quantization if not provided. 
if not updated_quant_cfg.get("algorithm"): diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 0db4bf241e7..1690801dc1d 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -126,6 +126,9 @@ def test_quantizer_cfg_entry_mutable_mapping_delitem_unsets_field(): "enable": True, } + with pytest.raises(KeyError): + del entry["missing"] + def test_public_preset_quant_cfg_entries_are_typed_and_dict_like(): """Public preset constants are typed but keep dict-style entry access.""" @@ -177,6 +180,22 @@ def test_mixed_raw_dict_and_modelopt_config_entries_normalize_after_mutation(): assert weight_entry["cfg"]["num_bits"] == "e4m3" +@pytest.mark.parametrize( + ("raw", "match"), + [ + ({"quantizer_name": "*", "cfg": None}, "'?cfg'? must be omitted"), + ({"quantizer_name": "*", "enable": None}, "'?enable'? must be a boolean"), + ], +) +def test_quantizer_cfg_entry_rejects_explicit_null_values(raw, match): + """Explicit null cfg/enable values are rejected instead of treated as omitted.""" + with pytest.raises(ValidationError, match=match): + QuantizerCfgEntry.model_validate(raw) + + with pytest.raises(ValueError, match=match): + normalize_quant_cfg_list([raw]) + + def test_quantizer_cfg_entry_rejects_no_effect_entry(): """Direct QuantizerCfgEntry construction rejects entries with no cfg or enable.""" with pytest.raises(ValidationError, match="must specify 'cfg', 'enable'"): @@ -469,6 +488,26 @@ def test_legacy_nn_class_with_cfg(self): assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 4, "axis": 0} assert result[0]["enable"] is True + def test_legacy_nn_class_with_list_valued_cfg(self): + """Legacy nn.* scoped format preserves list-valued SequentialQuantizer cfg.""" + raw = [ + { + "nn.Linear": { + "*weight_quantizer": [ + {"num_bits": 4, "block_sizes": {-1: 128, "type": "static"}}, + {"num_bits": 8, "axis": 0}, + ] + } + } + ] + result = normalize_quant_cfg_list(raw) + assert len(result) == 1 + assert result[0]["parent_class"] == "nn.Linear" + assert result[0]["quantizer_name"] == "*weight_quantizer" + assert isinstance(result[0]["cfg"], list) + assert _cfg_to_dict(result[0]["cfg"]) == raw[0]["nn.Linear"]["*weight_quantizer"] + assert result[0]["enable"] is True + def test_legacy_list_valued_cfg(self): """Legacy dict format with list-valued cfg (SequentialQuantizer) normalizes correctly.""" raw = [ diff --git a/tests/unit/torch/quantization/test_quantize_cpu.py b/tests/unit/torch/quantization/test_quantize_cpu.py index 301f4cdab1e..efb8627d448 100644 --- a/tests/unit/torch/quantization/test_quantize_cpu.py +++ b/tests/unit/torch/quantization/test_quantize_cpu.py @@ -401,6 +401,12 @@ def test_list_attributes_creates_sequential_quantizer(self): assert isinstance(module, SequentialQuantizer) assert len(module) == 2 + def test_sequential_quantizer_rejects_mismatched_attribute_list_length(self): + """SequentialQuantizer rejects partial list configs instead of silently zipping.""" + quantizer = SequentialQuantizer(TensorQuantizer(), TensorQuantizer()) + with pytest.raises(ValueError, match="Expected 2 attribute configs, but got 1"): + quantizer.set_from_attribute_config([QuantizerAttributeConfig(num_bits=8)]) + def test_ordering_later_entry_overrides_earlier(): """Later entries in quant_cfg override earlier ones for the same quantizer.""" From 0d31d46f8ca38c0f571762c3b28ce8cc4d935bcc Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 7 May 2026 
18:41:49 -0700 Subject: [PATCH 09/14] Update recipe loader schema expectations Signed-off-by: Shengliang Xu --- tests/unit/recipe/test_loader.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index fa8e11aea50..91496a2ecd7 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -563,7 +563,6 @@ def test_import_entry_element_schema_appends(tmp_path): recipe = load_recipe(recipe_file) assert _entry_to_dict(recipe.quantize["quant_cfg"][0]) == { "quantizer_name": "*", - "cfg": None, "enable": False, } @@ -977,7 +976,6 @@ def test_import_mixed_tree(tmp_path): assert _entry_to_dict(data["quant_cfg"][1]) == { "quantizer_name": "*lm_head*", "enable": False, - "cfg": None, } From d33ee36945ae8d739bfde5c5a4d42b90bafec4f8 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Fri, 8 May 2026 11:38:48 -0700 Subject: [PATCH 10/14] Tighten ModeloptBaseConfig mapping semantics Signed-off-by: Shengliang Xu --- modelopt/torch/opt/config.py | 74 +++++++++++++------ modelopt/torch/quantization/config.py | 53 ++++--------- modelopt/torch/quantization/conversion.py | 2 +- tests/unit/torch/opt/test_config.py | 2 +- .../quantization/test_config_validation.py | 61 ++++++++++----- 5 files changed, 112 insertions(+), 80 deletions(-) diff --git a/modelopt/torch/opt/config.py b/modelopt/torch/opt/config.py index 0ef1f70a838..00279dd10d2 100644 --- a/modelopt/torch/opt/config.py +++ b/modelopt/torch/opt/config.py @@ -60,8 +60,15 @@ def ModeloptField(default: Any = PydanticUndefined, **kwargs): # noqa: N802 class ModeloptBaseConfig(BaseModel, MutableMapping[str, Any]): """Our config base class for mode configuration. - The base class extends the capabilities of pydantic's BaseModel to provide additional methods - and properties for easier access and manipulation of the configuration. + The base class extends pydantic's BaseModel with a mapping interface so schema-backed + config objects can keep the dict-style access patterns used by older ModelOpt code. + + This is intentionally a fixed-key mutable mapping instead of a general dict. The mapping + keys are the model fields exposed through their aliases when present, and lookups accept + either a field name or its alias. Values are read and written through the pydantic model, so + assignment validation still applies. New keys cannot be inserted, and existing keys cannot + be deleted because the schema defines the complete key set; callers that need omission + semantics should use model_dump(exclude_unset=True) or the explicit_* helpers. 
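+
+    A minimal sketch of the mapping semantics (``MyConfig`` is a hypothetical subclass with
+    an ``algorithm`` field)::
+
+        cfg = MyConfig()
+        cfg["algorithm"] = "max"   # validated assignment to an existing field
+        cfg.get("unknown", 0)      # missing keys fall back to the supplied default
+        del cfg["algorithm"]       # raises TypeError: the key set is fixed by the schema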
""" model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True) @@ -111,44 +118,42 @@ def __contains__(self, key: str) -> bool: def __getitem__(self, key: str) -> Any: """Get the value for the given key (can be name or alias of field).""" - return getattr(self, self.get_field_name_from_key(key)) + try: + return getattr(self, self.get_field_name_from_key(key)) + except AttributeError as e: + raise KeyError(key) from e def __setitem__(self, key: str, value: Any) -> None: - """Set the value for the given key (can be name or alias of field).""" - setattr(self, self.get_field_name_from_key(key), value) + """Set an existing field by name or alias, preserving pydantic assignment validation.""" + try: + field_name = self.get_field_name_from_key(key) + except AttributeError as e: + raise KeyError(key) from e + if field_name not in type(self).model_fields: + raise KeyError(key) + setattr(self, field_name, value) def __delitem__(self, key: str) -> None: - """Unset the given key so exclude_unset dumps omit it.""" + """Reject deletion because ModeloptBaseConfig exposes a fixed schema key set.""" try: - field_name = self.get_field_name_from_key(key) + self.get_field_name_from_key(key) except AttributeError as e: raise KeyError(key) from e - if field_name in self._iterable_model_extra: - assert self.model_extra is not None - del self.model_extra[field_name] - self.model_fields_set.discard(field_name) - return - - field_info = type(self).model_fields[field_name] - default = field_info.get_default(call_default_factory=True) - if default is PydanticUndefined: - raise KeyError(f"Key {key} cannot be unset because it has no default.") - self.__dict__[field_name] = default - self.model_fields_set.discard(field_name) + raise TypeError("Config mapping keys are fixed and cannot be deleted.") def get(self, key: str, default: Any = None) -> Any: """Get the value for the given key (can be name or alias) or default if not found.""" try: return self[key] - except AttributeError: + except KeyError: return default def __len__(self) -> int: - """Return the length of the config.""" - return len(self.model_fields) + len(self._iterable_model_extra) + """Return the number of schema and extra keys exposed by the mapping.""" + return len(type(self).model_fields) + len(self._iterable_model_extra) def __iter__(self) -> Iterator[str]: - """Iterate over aliases (or name if alias is not defined) of fields.""" + """Iterate over schema keys, preferring aliases over field names.""" for field_name, field_info in type(self).model_fields.items(): yield field_info.alias or field_name yield from self._iterable_model_extra @@ -157,6 +162,29 @@ def _get_kv_dict(self) -> dict[str, Any]: """Return a dictionary with keys as aliases if possible.""" return {k: self[k] for k in self} + def iter_explicit_keys(self) -> Iterator[str]: + """Iterate over explicitly set schema keys, preferring aliases over field names.""" + for field_name, field_info in type(self).model_fields.items(): + if field_name in self.model_fields_set: + yield field_info.alias or field_name + yield from self._iterable_model_extra + + def _get_explicit_kv_dict(self) -> dict[str, Any]: + """Return explicitly set key-value pairs with keys as aliases if possible.""" + return {k: self[k] for k in self.iter_explicit_keys()} + + def explicit_keys(self) -> KeysView[str]: + """Return the explicitly set keys of the config.""" + return self._get_explicit_kv_dict().keys() + + def explicit_values(self) -> ValuesView[Any]: + """Return the explicitly set values of the config.""" 
+ return self._get_explicit_kv_dict().values() + + def explicit_items(self) -> ItemsView[str, Any]: + """Return the explicitly set items of the config with keys as aliases if possible.""" + return self._get_explicit_kv_dict().items() + def keys(self) -> KeysView[str]: """Return the keys (aliases prioritized over names) of the config.""" return self._get_kv_dict().keys() diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 87f3a1f8e85..aa31adfd3ca 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -543,8 +543,8 @@ class QuantizerCfgEntry(ModeloptBaseConfig): "Attributes to apply to matched quantizers. A list configures a sequential quantizer." ), ) - enable: bool | None = ModeloptField( - default=None, + enable: bool = ModeloptField( + default=True, title="Quantizer enable flag.", description="Optional on/off toggle for matched quantizers, independent of cfg.", ) @@ -556,11 +556,6 @@ def validate_quantizer_cfg_entry(cls, values): if not isinstance(values, Mapping): return values - if "cfg" not in values and "enable" not in values: - raise ValueError( - "Each quant_cfg entry must specify 'cfg', 'enable', or both. " - "An entry with only 'quantizer_name' has no effect." - ) if "cfg" in values and values["cfg"] is None: raise ValueError("cfg must be omitted or a valid mapping/list, not null.") if "enable" in values and values["enable"] is None: @@ -1038,25 +1033,21 @@ def normalize_quant_cfg_list( - Legacy ``nn.*``-scoped format: ``{"nn.": {"": }}`` - converted to a new-format entry with ``parent_class`` set. - **Validation** - an entry is rejected if it carries no instruction, i.e. it specifies neither - ``cfg`` nor ``enable``. Concretely, the following are invalid: + **Validation** - an entry is rejected if its shape is invalid. Concretely, the following + are invalid: - An empty entry ``{}``. - - An entry with only ``quantizer_name`` and no other keys - the only effect would be an - implicit ``enable=True``, which must be stated explicitly. - An entry with ``enable=True`` (explicit or implicit) whose ``cfg`` is not a non-empty ``dict`` or ``list`` - e.g. ``{"quantizer_name": "*", "cfg": {}}`` or ``{"quantizer_name": "*", "cfg": 42}``. An enabled quantizer must have a valid configuration. - **Normalization** - after conversion and validation every entry is put into canonical form: - - - ``enable`` is set to ``True`` if not explicitly specified. - - ``cfg`` is set to ``None`` if not present in the entry. - - For dict and legacy inputs, every returned entry is guaranteed to have - ``quantizer_name``, ``enable``, and ``cfg`` set (plus optionally ``parent_class``). Typed - :class:`QuantizerCfgEntry` inputs are assumed to be already parsed and are preserved. + **Normalization** - after conversion and validation every entry is parsed as a + :class:`QuantizerCfgEntry`. Schema defaults are available through mapping access, so ``enable`` + defaults to ``True`` and ``cfg`` defaults to ``None`` when omitted. Omitted defaults are not + marked as explicitly set, so ``model_dump(exclude_unset=True)`` preserves the user's sparse + input shape. Typed :class:`QuantizerCfgEntry` inputs are assumed to be already parsed and are + preserved. Args: v: A list of raw quant_cfg entries in any supported format, or a legacy flat dict. @@ -1066,9 +1057,8 @@ def normalize_quant_cfg_list( typed entries are preserved. 
Raises: - ValueError: If any entry has only ``quantizer_name`` with neither ``cfg`` nor ``enable``, - if ``enable=True`` with an empty or non-dict/list ``cfg``, or if the entry format - is not recognized. + ValueError: If ``enable=True`` with an empty or non-dict/list ``cfg``, or if the entry + format is not recognized. """ if isinstance(v, list) and all(isinstance(raw, QuantizerCfgEntry) for raw in v): return cast("list[QuantizerCfgEntry]", v) @@ -1164,14 +1154,6 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: raise ValueError(f"Invalid quant_cfg entry: {raw!r}.") for entry in entries: - # Validate: must carry at least one instruction beyond the path selector. - if "cfg" not in entry and "enable" not in entry: - raise ValueError( - f"Invalid quant_cfg entry: {raw!r} - each entry must specify 'cfg', 'enable', " - "or both. An entry with only 'quantizer_name' has no effect (implicit " - "enable=True is not allowed; set it explicitly)." - ) - # Validate: when cfg is present and enable=True, cfg must be a non-empty # dict or list. An empty cfg would attempt to create a # QuantizerAttributeConfig with no actual configuration. @@ -1211,9 +1193,6 @@ def _dict_to_entry(key: str, value: Any) -> list[dict[str, Any]]: "explicitly." ) - # Normalize: make enable explicit. cfg remains omitted when it is intentionally unset. - entry.setdefault("enable", True) - result.append(QuantizerCfgEntry.model_validate(entry)) return result @@ -1981,7 +1960,7 @@ def _cfg_to_dict(cfg): for entry in quant_cfg: name = entry["quantizer_name"] raw_cfg = entry.get("cfg") - enable = entry.get("enable") + enable = entry["enable"] if "weight_quantizer" in name: # We don't calibrate weight quantizer continue @@ -1989,14 +1968,12 @@ def _cfg_to_dict(cfg): if isinstance(raw_cfg, list): for _config in raw_cfg: cfg = _cfg_to_dict(_config) - if enable is not None: - cfg["enable"] = enable + cfg["enable"] = enable if _not_dynamic(cfg): return True continue cfg = _cfg_to_dict(raw_cfg) - if enable is not None: - cfg["enable"] = enable + cfg["enable"] = enable if _not_dynamic(cfg): return True diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 1177a342c6e..17bc3f7ddc7 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -254,7 +254,7 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgInpu for entry in quant_cfg: quantizer_name: str = entry["quantizer_name"] cfg = entry["cfg"] # None, QuantizerAttributeConfig, or list after normalization - enable = entry["enable"] if entry["enable"] is not None else True + enable = entry["enable"] parent_class_name = entry.get("parent_class") if parent_class_name: try: diff --git a/tests/unit/torch/opt/test_config.py b/tests/unit/torch/opt/test_config.py index b2ffadb1a78..e0c5993a51a 100644 --- a/tests/unit/torch/opt/test_config.py +++ b/tests/unit/torch/opt/test_config.py @@ -72,7 +72,7 @@ def _run_test(is_new_registered): assert config[lin_name] == lin_expected_value assert config[lin_alias] == lin_expected_value assert getattr(config, lin_name) == lin_expected_value - with nullcontext() if is_new_registered else pytest.raises(AttributeError): + with nullcontext() if is_new_registered else pytest.raises(KeyError): config[new_name] # get diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 1690801dc1d..a86b046455e 100644 --- 
a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -59,6 +59,7 @@ def test_need_calibration(): def test_need_calibration_with_quantize_config_type(): """need_calibration accepts schema-backed QuantizeConfig objects.""" + assert need_calibration(QuantizeConfig()) assert need_calibration(QuantizeConfig.model_validate(FP8_DEFAULT_CFG)) assert not need_calibration(QuantizeConfig.model_validate(FP8_PER_CHANNEL_PER_TOKEN_CFG)) @@ -100,7 +101,18 @@ def test_quantizer_cfg_entry_is_pydantic_and_dict_like(): entry = QuantizerCfgEntry(quantizer_name="*", enable=False) assert isinstance(entry, ModeloptBaseConfig) assert entry["quantizer_name"] == "*" - assert entry.get("cfg") is None + assert entry["cfg"] is None + assert "cfg" in entry + assert list(entry) == ["quantizer_name", "parent_class", "cfg", "enable"] + assert dict(entry.items()) == { + "quantizer_name": "*", + "parent_class": None, + "cfg": None, + "enable": False, + } + assert dict(entry.explicit_items()) == {"quantizer_name": "*", "enable": False} + with pytest.raises(KeyError): + entry["unknown"] = 1 assert entry.model_dump(exclude_unset=True) == {"quantizer_name": "*", "enable": False} cfg_entry = QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={"num_bits": 8}) @@ -108,8 +120,8 @@ def test_quantizer_cfg_entry_is_pydantic_and_dict_like(): assert _cfg_to_dict(cfg_entry["cfg"]) == {"num_bits": 8} -def test_quantizer_cfg_entry_mutable_mapping_delitem_unsets_field(): - """Deleting a config key resets it to unset for exclude_unset dumps.""" +def test_quantizer_cfg_entry_mutable_mapping_rejects_key_deletion(): + """ModeloptBaseConfig mappings have a fixed key set and reject deletion.""" entry = QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={"num_bits": 8}, enable=True) assert isinstance(entry, MutableMapping) assert entry.model_dump(exclude_unset=True) == { @@ -118,11 +130,14 @@ def test_quantizer_cfg_entry_mutable_mapping_delitem_unsets_field(): "enable": True, } - del entry["cfg"] + with pytest.raises(TypeError): + del entry["cfg"] - assert entry["cfg"] is None + assert "cfg" in entry + assert entry["cfg"] is not None assert entry.model_dump(exclude_unset=True) == { "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": 8}, "enable": True, } @@ -196,10 +211,13 @@ def test_quantizer_cfg_entry_rejects_explicit_null_values(raw, match): normalize_quant_cfg_list([raw]) -def test_quantizer_cfg_entry_rejects_no_effect_entry(): - """Direct QuantizerCfgEntry construction rejects entries with no cfg or enable.""" - with pytest.raises(ValidationError, match="must specify 'cfg', 'enable'"): - QuantizerCfgEntry(quantizer_name="*") +def test_quantizer_cfg_entry_defaults_enable_true(): + """Direct QuantizerCfgEntry construction uses enable=True when omitted.""" + entry = QuantizerCfgEntry(quantizer_name="*") + assert entry["enable"] is True + assert entry["cfg"] is None + assert dict(entry.explicit_items()) == {"quantizer_name": "*"} + assert entry.model_dump(exclude_unset=True) == {"quantizer_name": "*"} def test_quantizer_cfg_entry_rejects_empty_name(): @@ -231,7 +249,8 @@ def test_new_format_passthrough(self): assert result[0]["quantizer_name"] == "*weight_quantizer" assert isinstance(result[0]["cfg"], QuantizerAttributeConfig) assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": 0} - assert result[0]["enable"] is True # defaulted + assert result[0]["enable"] is True # schema default + assert "enable" not in dict(result[0].explicit_items()) 
def test_typed_entry_list_passthrough(self): """Already-parsed QuantizerCfgEntry lists are returned unchanged.""" @@ -256,6 +275,7 @@ def test_mixed_typed_and_dict_entries_normalize_to_typed_entries(self): assert isinstance(result[1], QuantizerCfgEntry) assert _cfg_to_dict(result[1]["cfg"]) == {"num_bits": 8} assert result[1]["enable"] is True + assert "enable" not in dict(result[1].explicit_items()) def test_new_format_enable_false(self): """Explicit enable=False is preserved.""" @@ -277,7 +297,8 @@ def test_legacy_single_key_dict(self): result = normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*weight_quantizer" assert _cfg_to_dict(result[0]["cfg"]) == {"num_bits": 8, "axis": 0} - assert result[0]["enable"] is True # defaulted + assert result[0]["enable"] is True # schema default + assert "enable" not in dict(result[0].explicit_items()) def test_legacy_single_key_dict_with_enable(self): """Legacy {'*path': {'enable': False}} splits enable out from cfg.""" @@ -296,17 +317,19 @@ def test_legacy_nn_class_scoped(self): assert result[0]["enable"] is False def test_normalization_cfg_defaults_to_none(self): - """Entries without cfg get cfg=None after normalization.""" + """Entries without cfg expose the default mapping key but keep it unset.""" raw = [{"quantizer_name": "*lm_head*", "enable": False}] result = normalize_quant_cfg_list(raw) assert "cfg" in result[0] assert result[0]["cfg"] is None + assert "cfg" not in dict(result[0].explicit_items()) def test_normalization_enable_defaults_to_true(self): - """Entries with cfg but no enable get enable=True after normalization.""" + """Entries with cfg but no enable read as enable=True without marking it explicit.""" raw = [{"quantizer_name": "*", "cfg": {"num_bits": 4}}] result = normalize_quant_cfg_list(raw) assert result[0]["enable"] is True + assert "enable" not in dict(result[0].explicit_items()) def test_empty_list(self): """Empty list is returned unchanged.""" @@ -322,10 +345,13 @@ def test_multiple_entries_order_preserved(self): assert result[0]["quantizer_name"] == "*" assert result[1]["quantizer_name"] == "*weight_quantizer" - def test_error_on_quantizer_name_only(self): - """Entry with only quantizer_name and no cfg or enable is rejected.""" - with pytest.raises(ValueError, match="must specify 'cfg', 'enable'"): - normalize_quant_cfg_list([{"quantizer_name": "*"}]) + def test_quantizer_name_only_defaults_enable_true(self): + """Entry with only quantizer_name uses enable=True from the schema default.""" + result = normalize_quant_cfg_list([{"quantizer_name": "*"}]) + assert result[0]["enable"] is True + assert result[0]["cfg"] is None + assert dict(result[0].explicit_items()) == {"quantizer_name": "*"} + assert result[0].model_dump(exclude_unset=True) == {"quantizer_name": "*"} def test_error_on_empty_dict(self): """An empty dict entry is rejected.""" @@ -421,6 +447,7 @@ def test_legacy_flat_dict_conversion(self): assert result[1]["quantizer_name"] == "*weight_quantizer" assert _cfg_to_dict(result[1]["cfg"]) == {"num_bits": 8, "axis": 0} assert result[1]["enable"] is True + assert "enable" not in dict(result[1].explicit_items()) def test_legacy_enable_only_produces_cfg_none(self): """Legacy {'*': {'enable': False}} should produce cfg=None, not cfg={}.""" From 0917ab83dfb8aead3638b79b0fef9c30bf7e32e5 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Fri, 8 May 2026 13:20:48 -0700 Subject: [PATCH 11/14] fix test errors Signed-off-by: Shengliang Xu --- .../backends/fp8_per_tensor_gemm.py | 21 +++++++++++++++++-- 
.../torch/quantization/backends/nvfp4_gemm.py | 18 ++++++++++++++-- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py index c45da9d4ff7..b2bdb3323b5 100644 --- a/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py +++ b/modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py @@ -20,6 +20,7 @@ import torch from torch.autograd import Function +from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.quantization.backends.gemm_registry import gemm_registry from modelopt.torch.quantization.config import FP8_DEFAULT_CFG, find_quant_cfg_entry_by_path from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear @@ -121,9 +122,18 @@ def _fp8_availability_check(module, input, args, kwargs): # Check quantizer presence and configuration if not hasattr(module, "input_quantizer") or not hasattr(module, "weight_quantizer"): return False + if not module.input_quantizer.is_enabled or not module.weight_quantizer.is_enabled: + return False # Check input quantizer config - for key, value in input_cfg.items(): + # TODO: Move this compatibility check inside the quantizer; matching config items here + # is fragile and easy to break as config semantics evolve. + input_items = input_cfg.items + if isinstance(input_cfg, ModeloptBaseConfig): + input_items = input_cfg.explicit_items + for key, value in input_items(): + if key == "enable": + continue if ( not hasattr(module.input_quantizer, key) or getattr(module.input_quantizer, key) != value @@ -131,7 +141,14 @@ def _fp8_availability_check(module, input, args, kwargs): return False # Check weight quantizer config - for key, value in weight_cfg.items(): + # TODO: Move this compatibility check inside the quantizer; matching config items here + # is fragile and easy to break as config semantics evolve. + weight_items = weight_cfg.items + if isinstance(weight_cfg, ModeloptBaseConfig): + weight_items = weight_cfg.explicit_items + for key, value in weight_items(): + if key == "enable": + continue if ( not hasattr(module.weight_quantizer, key) or getattr(module.weight_quantizer, key) != value diff --git a/modelopt/torch/quantization/backends/nvfp4_gemm.py b/modelopt/torch/quantization/backends/nvfp4_gemm.py index 299f5851938..f431ad63c1e 100644 --- a/modelopt/torch/quantization/backends/nvfp4_gemm.py +++ b/modelopt/torch/quantization/backends/nvfp4_gemm.py @@ -21,6 +21,7 @@ from torch.autograd import Function import modelopt.torch.quantization as mtq +from modelopt.torch.opt.config import ModeloptBaseConfig from modelopt.torch.quantization.backends.gemm_registry import gemm_registry from modelopt.torch.quantization.backends.utils import fp4_compatible from modelopt.torch.quantization.nn.modules.quant_linear import RealQuantLinear @@ -230,7 +231,15 @@ def _nvfp4_availability_check(module, input, args, kwargs): return False # Check input quantizer config - for key, value in input_cfg.items(): + # TODO: Move this compatibility check inside the quantizer; matching config items here + # is fragile and easy to break as config semantics evolve. 
+ if not module.input_quantizer.is_enabled or not module.weight_quantizer.is_enabled: + return False + + input_items = input_cfg.items + if isinstance(input_cfg, ModeloptBaseConfig): + input_items = input_cfg.explicit_items + for key, value in input_items(): if key == "enable": continue if ( @@ -240,7 +249,12 @@ def _nvfp4_availability_check(module, input, args, kwargs): return False # Check weight quantizer config - for key, value in weight_cfg.items(): + # TODO: Move this compatibility check inside the quantizer; matching config items here + # is fragile and easy to break as config semantics evolve. + weight_items = weight_cfg.items + if isinstance(weight_cfg, ModeloptBaseConfig): + weight_items = weight_cfg.explicit_items + for key, value in weight_items(): if key == "enable": continue if ( From b5b45b9640d583a5c7af951087dbc971a204ec2a Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Fri, 8 May 2026 15:43:14 -0700 Subject: [PATCH 12/14] Fix diffusers quant config explicit key handling Signed-off-by: Shengliang Xu --- examples/diffusers/quantization/config.py | 13 ++++--- examples/diffusers/quantization/quantize.py | 26 +++++++++---- tests/examples/diffusers/test_diffusers.py | 41 +++++++++++++++++++++ 3 files changed, 67 insertions(+), 13 deletions(-) diff --git a/examples/diffusers/quantization/config.py b/examples/diffusers/quantization/config.py index 967802ffa97..fada8bc8cd9 100644 --- a/examples/diffusers/quantization/config.py +++ b/examples/diffusers/quantization/config.py @@ -18,6 +18,8 @@ import torch.nn as nn from calib.plugin_calib import PercentileCalibrator +from modelopt.torch.quantization.config import QuantizerAttributeConfig + FP8_DEFAULT_CONFIG = { "quant_cfg": [ {"quantizer_name": "*", "enable": False}, @@ -107,11 +109,12 @@ def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, ** for entry in quant_config["quant_cfg"]: p = entry.get("cfg", {}) if isinstance(entry, Mapping) else {} - if ( - isinstance(p, MutableMapping) - and "num_bits" in p - and "trt_high_precision_dtype" not in p - ): + if not isinstance(p, MutableMapping): + continue + keys = p.explicit_keys() if isinstance(p, QuantizerAttributeConfig) else p.keys() + # TODO: Replace this membership-based config patching with a better config API; + # ``in``/``not in`` checks are fragile with schema-backed defaults. + if "num_bits" in keys and "trt_high_precision_dtype" not in keys: p["trt_high_precision_dtype"] = trt_high_precision_dtype diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index a087d6f1d69..07178f22f50 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -14,6 +14,7 @@ # limitations under the License. 
import argparse +import copy import logging import sys import time as time @@ -50,6 +51,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.export import export_hf_checkpoint +from modelopt.torch.opt.config import ModeloptBaseConfig def setup_logging(verbose: bool = False) -> logging.Logger: @@ -120,14 +122,6 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: base_cfg = mtq.INT8_SMOOTHQUANT_CFG else: base_cfg = INT8_DEFAULT_CONFIG - if self.config.collect_method != CollectMethod.DEFAULT: - reset_set_int8_config( - base_cfg, - self.config.percentile, - n_steps, - collect_method=self.config.collect_method.value, - backbone=backbone, - ) elif self.config.format == QuantFormat.FP8: base_cfg = FP8_DEFAULT_CONFIG elif self.config.format == QuantFormat.FP4: @@ -139,6 +133,22 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any: raise NotImplementedError(f"Unknown format {self.config.format}") # Build a fresh config dict so we never mutate the global constants. + if isinstance(base_cfg, ModeloptBaseConfig): + base_cfg = base_cfg.model_dump(exclude_unset=True) + base_cfg = copy.deepcopy(base_cfg) + + if ( + self.config.format == QuantFormat.INT8 + and self.config.collect_method != CollectMethod.DEFAULT + ): + reset_set_int8_config( + base_cfg, + self.config.percentile, + n_steps, + collect_method=self.config.collect_method.value, + backbone=backbone, + ) + quant_cfg_list = list(base_cfg["quant_cfg"]) if self.config.format == QuantFormat.FP4: diff --git a/tests/examples/diffusers/test_diffusers.py b/tests/examples/diffusers/test_diffusers.py index 5b117b41b3f..637214f736f 100644 --- a/tests/examples/diffusers/test_diffusers.py +++ b/tests/examples/diffusers/test_diffusers.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy +import importlib.util +import sys from pathlib import Path from typing import NamedTuple @@ -22,6 +25,44 @@ from _test_utils.torch.misc import minimum_sm +def _load_diffusers_quantization_config_module(): + quantization_dir = ( + Path(__file__).resolve().parents[3] / "examples" / "diffusers" / "quantization" + ) + sys.path.insert(0, str(quantization_dir)) + try: + spec = importlib.util.spec_from_file_location( + "diffusers_quantization_config", quantization_dir / "config.py" + ) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + finally: + sys.path.remove(str(quantization_dir)) + return module + + +def test_diffusers_quant_config_attr_uses_explicit_schema_keys() -> None: + import modelopt.torch.quantization as mtq + + config_module = _load_diffusers_quantization_config_module() + quant_config = copy.deepcopy(mtq.INT8_SMOOTHQUANT_CFG) + input_cfg = next( + entry["cfg"] + for entry in quant_config["quant_cfg"] + if entry["quantizer_name"] == "*input_quantizer" + ) + + assert "trt_high_precision_dtype" in input_cfg + assert "trt_high_precision_dtype" not in input_cfg.explicit_keys() + + config_module.set_quant_config_attr(quant_config, "Half", "smoothquant", alpha=0.8) + + assert input_cfg["trt_high_precision_dtype"] == "Half" + assert quant_config["algorithm"] == {"method": "smoothquant", "alpha": 0.8} + + class DiffuserModel(NamedTuple): dtype: str name: str From d9eccf311b950b7a4a7c8d588eff34956632dc41 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Fri, 8 May 2026 18:00:53 -0700 Subject: [PATCH 13/14] fix review comments Signed-off-by: Shengliang Xu --- examples/llm_ptq/hf_ptq.py | 4 +++- modelopt/torch/quantization/algorithms.py | 6 ++++-- .../quantization/nn/modules/tensor_quantizer.py | 14 ++++++++++---- tests/unit/torch/quantization/test_autoquant.py | 16 ++++++++++++++++ .../unit/torch/quantization/test_quantize_cpu.py | 6 ++++++ 5 files changed, 39 insertions(+), 7 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index ca75974d6d3..76c8521e093 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -98,7 +98,9 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None: for entry in quant_cfg: if entry.get("quantizer_name") != "*[kv]_bmm_quantizer": continue - cfg = entry.get("cfg") or {} + cfg = entry.get("cfg") + if cfg is None: + cfg = {} cfg["use_constant_amax"] = True entry["cfg"] = cfg break diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index c67bd14a218..b9b1c9f8726 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -127,14 +127,16 @@ def __init__(self, quant_cfg: QuantRecipeConfig = None, name: str | None = None) if isinstance(quant_cfg, str): assert hasattr(mtq_config, quant_cfg), f"Unknown quantization format {quant_cfg}" quant_cfg = getattr(mtq_config, quant_cfg) - elif not isinstance(quant_cfg, QuantizeConfig): - assert name is not None, "name must be provided for custom quantization formats" + elif not isinstance(quant_cfg, QuantizeConfig) and name is None: + raise ValueError("name must be provided for custom quantization formats") self.config = ( quant_cfg.model_copy(deep=True) if isinstance(quant_cfg, QuantizeConfig) else mtq_config.QuantizeConfig.model_validate(quant_cfg) ) + if name is None: + raise ValueError("name must be provided for custom quantization formats") # Disable KV Cache quantization # 
Currently KV Cache quantization is enabled for some quantization formats and disabled for others diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index f0871f3d291..dd7b66dc3cd 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -18,7 +18,7 @@ import contextlib import math import warnings -from collections.abc import Callable, Mapping +from collections.abc import Callable, Mapping, Sequence from typing import Any, Protocol import torch @@ -1425,11 +1425,17 @@ def get_modelopt_state(self) -> dict[str, Any]: return {"num_quantizers": len(self), "is_sequential_quantizer": True} def set_from_attribute_config( - self, attributes: list[QuantizerAttributeConfig] | list[Mapping[str, Any]] + self, + attributes: ( + QuantizerAttributeConfig + | Mapping[str, Any] + | Sequence[QuantizerAttributeConfig | Mapping[str, Any]] + ), ): - """Set the attributes of contained quantizers from a list of attribute_dicts.""" + """Set the attributes of contained quantizers from attribute configs.""" if not isinstance(attributes, (list, tuple)): - assert isinstance(attributes, Mapping), "attributes must be a list or a mapping." + if not isinstance(attributes, Mapping): + raise TypeError("attributes must be a list/tuple or a mapping.") attributes = [attributes] * len(self) elif len(attributes) != len(self): raise ValueError(f"Expected {len(self)} attribute configs, but got {len(attributes)}.") diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 87ec73291e7..2c8ae9250e7 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -79,6 +79,22 @@ def test_quant_recipe(quant_cfg, other_quant_cfg, is_less_than): assert qr_this_duplicate in {qr_this} +def test_quant_recipe_custom_quantize_config_requires_name(): + custom_cfg = mtq.QuantizeConfig( + quant_cfg=[ + mtq.QuantizerCfgEntry( + quantizer_name="*weight_quantizer", + cfg=mtq.QuantizerAttributeConfig(num_bits=8, axis=None), + ) + ] + ) + + with pytest.raises(ValueError, match="name must be provided"): + QuantRecipe(custom_cfg) + + assert str(QuantRecipe(custom_cfg, name="custom_cfg")).startswith("custom_cfg(") + + def test_quant_recipe_hparam(): model_test = torch.nn.Linear(4, 16) model_ref = torch.nn.Linear(4, 16) diff --git a/tests/unit/torch/quantization/test_quantize_cpu.py b/tests/unit/torch/quantization/test_quantize_cpu.py index efb8627d448..0fcac237ccb 100644 --- a/tests/unit/torch/quantization/test_quantize_cpu.py +++ b/tests/unit/torch/quantization/test_quantize_cpu.py @@ -407,6 +407,12 @@ def test_sequential_quantizer_rejects_mismatched_attribute_list_length(self): with pytest.raises(ValueError, match="Expected 2 attribute configs, but got 1"): quantizer.set_from_attribute_config([QuantizerAttributeConfig(num_bits=8)]) + def test_sequential_quantizer_rejects_non_mapping_attribute_config(self): + """SequentialQuantizer rejects invalid scalar attribute configs at runtime.""" + quantizer = SequentialQuantizer(TensorQuantizer(), TensorQuantizer()) + with pytest.raises(TypeError, match="attributes must be a list/tuple or a mapping"): + quantizer.set_from_attribute_config(object()) # type: ignore[arg-type] + def test_ordering_later_entry_overrides_earlier(): """Later entries in quant_cfg override earlier ones for the same quantizer.""" From 
a5e5062e9fd34323b3136e0227a71653 Mon Sep 17 00:00:00 2001
From: Shengliang Xu
Date: Tue, 12 May 2026 01:59:03 -0700
Subject: [PATCH 14/14] docs: add modelopt config system guide

Signed-off-by: Shengliang Xu
---
 docs/source/guides/11_config_system.rst | 499 ++++++++++++++++++++++++
 1 file changed, 499 insertions(+)
 create mode 100644 docs/source/guides/11_config_system.rst

diff --git a/docs/source/guides/11_config_system.rst b/docs/source/guides/11_config_system.rst
new file mode 100644
index 00000000000..54816c1c828
--- /dev/null
+++ b/docs/source/guides/11_config_system.rst
@@ -0,0 +1,499 @@
.. _modelopt-config-system:

ModelOpt Config System
######################

ModelOpt configs use Python types as the contract and YAML as the portable data
representation. A YAML file is loaded into ordinary Python ``dict``/``list``
data, optional YAML composition is resolved, and the result is validated by the
owning Pydantic-compatible schema.

The config system is intentionally general. Quantization configs, reusable YAML
snippets, and recipes are all consumers of the same lower-level semantics.
Recipes are one of the main applications; for the recipe-specific authoring
workflow, see :ref:`recipes`.

.. contents:: On this page
   :local:
   :depth: 2


Requirements
============

The core configuration system has four required properties and one optional
authoring feature:

* **Typed / schematized**: each config surface has an explicit Python type
  contract. Concrete model configs inherit from
  :class:`~modelopt.torch.opt.config.ModeloptBaseConfig`; reusable container
  shapes can use Pydantic-compatible type aliases such as
  ``list[QuantizerCfgEntry]``.
* **Validated**: invalid values fail at load or schema-construction time.
  Type errors, range violations, and unknown fields surface as Pydantic
  validation errors instead of being silently ignored.
* **Persistent**: a resolved config can be serialized as plain YAML/JSON data,
  and the same plain data can be embedded in a PyTorch checkpoint and restored
  against the schema.
* **Backward compatible**: schemas evolve over time. Loading older persisted
  configs against newer schemas must remain deliberate and testable. ModelOpt
  does not yet have a formal compatibility window, but config authors should
  treat compatibility as a schema-design requirement.
* **Composable YAML**: shared fragments such as numeric formats and list units
  can be defined once and referenced from multiple YAML files. This is an
  optional authoring convenience, not a correctness requirement.

These requirements split the system into three layers:

* Python/Pydantic-compatible schemas define what is valid.
* YAML stores the user-facing config data.
* The loader resolves YAML conveniences, returns plain data, and invokes schema
  validation where the file itself declares a schema.


Schema layer
============

``ModeloptBaseConfig`` is the common base class for structured ModelOpt config
objects:

.. code-block:: python

    class ModeloptBaseConfig(BaseModel):
        model_config = ConfigDict(extra="forbid", validate_assignment=True)

The base class adds ModelOpt-specific behavior on top of Pydantic:

* ``extra="forbid"`` rejects unknown keys by default.
* ``validate_assignment=True`` revalidates field updates after construction.
* ``ModeloptField(...)`` requires every field to declare a default value.
* ``model_dump()`` and ``model_dump_json()`` default to aliases and suppress
  Pydantic serialization warnings.
* Mapping-style access, such as ``cfg["field"]``, ``cfg.get("field")``,
  ``cfg.items()``, and ``cfg.update({...})``, keeps config objects compatible
  with existing dict-oriented code.
* ``__init_subclass__`` registers each config subclass with PyTorch safe
  globals so config objects can be deserialized by ``torch.load`` with
  ``weights_only=True``.
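The mapping-style surface can be exercised directly. A minimal sketch using the
``QuantizerCfgEntry`` schema, mirroring behaviors asserted by the unit tests in
this series:

.. code-block:: python

    from modelopt.torch.quantization.config import QuantizerCfgEntry

    entry = QuantizerCfgEntry(quantizer_name="*weight_quantizer", cfg={"num_bits": 8})
    assert entry["quantizer_name"] == "*weight_quantizer"  # __getitem__ over fields
    assert entry.get("parent_class") is None               # dict-style get()
    entry.update({"enable": False})                        # revalidated on assignment

    # Defaulted fields are readable but stay out of explicit_items() until set.
    assert "parent_class" not in dict(entry.explicit_items())
    assert entry.model_dump(exclude_unset=True)["cfg"] == {"num_bits": 8}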
A typical config schema is a regular Pydantic model with field validators:

.. code-block:: python

    class QuantizeConfig(ModeloptBaseConfig):
        quant_cfg: QuantizeQuantCfgType = ModeloptField(
            default=[{"quantizer_name": "*", "cfg": {"num_bits": 8, "axis": None}}],
            title="Quantization configuration",
            validate_default=True,
        )
        algorithm: QuantizeAlgoCfgType = ModeloptField(
            default="max",
            title="Calibration algorithm",
            validate_default=True,
        )

        @field_validator("quant_cfg", mode="before")
        @classmethod
        def normalize_quant_cfg(cls, v):
            return normalize_quant_cfg_list(v) if isinstance(v, (list, dict)) else v

Not every reusable config shape needs its own top-level config class. Some
YAML fragments are validated by narrower schema contracts:

* Pydantic model classes work for object snippets such as one quantizer rule.
* ``list[T]`` aliases work for list snippets such as a group of quantizer rules.
* Unions and other Pydantic ``TypeAdapter``-compatible annotations can be used
  when the reusable data shape is a typed container rather than a standalone
  model class.

The important invariant is that the schema lives in Python, while YAML remains
data.


Validation model
================

Validation happens at two boundaries.

Imported snippets
-----------------

Every file referenced by a YAML ``imports`` block is a reusable snippet. It must
include a ``# modelopt-schema: ...`` comment in the initial comment preamble:

.. code-block:: yaml

    # modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig
    num_bits: e4m3
    axis:

The loader resolves the schema path, validates the resolved snippet payload with
Pydantic ``TypeAdapter``, and only then exposes that snippet to the importing
file. This makes snippets independently reviewable and prevents a malformed
shared fragment from being silently copied into many configs.

Schema paths are intentionally restricted:

* they must resolve under the ``modelopt.`` package;
* they must point at a Pydantic-compatible type;
* they are validation contracts, not arbitrary Python execution hooks.

Top-level configs
-----------------

Top-level user configs do not always need a ``modelopt-schema`` comment. The
owning API often supplies schema context directly:

.. code-block:: python

    from modelopt.recipe import load_config
    from modelopt.torch.quantization.config import QuantizeConfig

    data = load_config("configs/ptq/presets/model/fp8", schema_type=QuantizeConfig)
    cfg = QuantizeConfig.model_validate(data)

``schema_type`` has one narrow loader responsibility: it provides typing context
for import resolution, especially for deciding whether a list import should
append one element or splice several elements. It is not a blanket request to
validate a top-level file. Top-level validation is performed by the owning
config object, or by ``load_config()`` when the top-level YAML file itself
contains ``# modelopt-schema: ...``.
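Both boundaries ultimately reduce to Pydantic validation, so a snippet can be
checked by hand when debugging a fragment. A minimal sketch, assuming the
``e4m3`` shorthand has already been converted to a ``(4, 3)`` tuple:

.. code-block:: python

    from pydantic import TypeAdapter

    from modelopt.torch.quantization.config import QuantizerAttributeConfig

    # Resolved snippet payload as plain data, after eXmY conversion.
    payload = {"num_bits": (4, 3), "axis": None}

    # The loader performs an equivalent check before exposing the snippet
    # to the importing file.
    attr_cfg = TypeAdapter(QuantizerAttributeConfig).validate_python(payload)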
YAML loading
============

The general loader lives in ``modelopt.torch.opt.config_loader`` and is exported
through ``modelopt.recipe.load_config``. It is intentionally below the recipe
layer so quantization and other core config modules can use it without depending
on recipes.

``load_config(path, schema_type=...)`` performs this flow:

1. Locate the YAML file. Filesystem paths are checked first; if the path is
   relative and not found locally, the built-in ``modelopt_recipes`` package is
   checked. ``.yml`` and ``.yaml`` suffixes may be omitted.
2. Read the optional ``# modelopt-schema: ...`` comment preamble.
3. Parse one YAML document, or two documents when a list-valued snippet also
   needs an ``imports`` declaration.
4. Convert ``eXmY`` strings in ``num_bits`` and ``scale_bits`` fields into
   ``(X, Y)`` tuples.
5. Resolve a file-local ``imports`` mapping.
6. Recursively resolve nested imports, detect circular imports, and validate
   imported snippets against their declared schemas.
7. Walk the YAML tree and replace ``$import`` references.
8. Validate the top-level payload if the file declares ``modelopt-schema``.
9. Return resolved plain Python ``dict`` or ``list`` data.

The loader is not a general templating engine. It only understands YAML data,
``imports``, ``$import``, schema comments, and the ``eXmY`` numeric shorthand.
Application-specific CLI or environment overrides should be applied by the
caller before final schema validation.


Self-contained YAML
===================

The simplest YAML config is self-contained and has no cross-file composition:

.. code-block:: yaml

    algorithm: max
    quant_cfg:
      - quantizer_name: '*'
        enable: false
      - quantizer_name: '*weight_quantizer'
        cfg:
          num_bits: e2m1
          block_sizes:
            -1: 16
            type: dynamic
            scale_bits: e4m3

This is the baseline format. YAML stores values; Python schemas define and
validate the allowed shape.

Self-contained YAML is the right choice when a config is small, used once, or
clearer without indirection. Composable YAML is for repeated fragments and large
families of related configs.


YAML persistence
================

A loaded config should round-trip through plain data. After loading and
validation, serialize the resolved config rather than the authoring-time YAML:

.. code-block:: python

    import yaml

    from modelopt.recipe import load_config
    from modelopt.torch.quantization.config import QuantizeConfig

    data = load_config("configs/ptq/presets/model/fp8", schema_type=QuantizeConfig)
    cfg = QuantizeConfig.model_validate(data)

    with open("resolved_quantize.yaml", "w", encoding="utf-8") as f:
        yaml.safe_dump(cfg.model_dump(), f)

The output is fully materialized plain data. YAML comments, ``imports`` blocks,
``$import`` markers, and schema comments are authoring metadata; they do not
survive in the resolved dump. This is intentional. Resolved dumps are suitable
for bug reports, reproducibility artifacts, and diffs across runs.

Reloading a resolved dump is the same operation as any other load: parse plain
YAML data and validate it against the schema.
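Continuing the example, reloading the resolved dump needs no composition
machinery, because nothing in the file refers to other files anymore. A minimal
sketch, reusing the file written above:

.. code-block:: python

    import yaml

    from modelopt.torch.quantization.config import QuantizeConfig

    with open("resolved_quantize.yaml", encoding="utf-8") as f:
        reloaded = QuantizeConfig.model_validate(yaml.safe_load(f))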
Checkpoint persistence
======================

Configs embedded in checkpoints should use the same plain-data contract. Store
``cfg.model_dump()`` in the checkpoint and restore it with the owning schema:

.. code-block:: python

    import torch

    state = {
        "model": model.state_dict(),
        "modelopt_state": {
            "quantize_config": cfg.model_dump(),
        },
    }
    torch.save(state, "checkpoint.pt")

    loaded = torch.load("checkpoint.pt", weights_only=True)
    restored_cfg = QuantizeConfig.model_validate(
        loaded["modelopt_state"]["quantize_config"]
    )

Persisting plain data keeps checkpoints independent of the original YAML files
and of the authoring-time import graph. Future readers need the schema, not the
source snippets.

``ModeloptBaseConfig`` also registers subclasses as PyTorch safe globals, which
allows config objects to participate in safe deserialization. Plain-data
persistence remains the most portable form because it is easy to inspect, diff,
and migrate.


Schema evolution
================

Backward compatibility is a schema concern. When a persisted config outlives the
code version that produced it, a newer schema must either accept it or reject it
with a clear migration path.

Use these rules when evolving config schemas:

* Prefer additive fields with defaults over required fields with no default.
* Keep validators tolerant of older spellings while a rename is in flight.
* Normalize legacy forms in ``mode="before"`` validators, then store the
  canonical form in ``model_dump()`` output (see the sketch below).
* Avoid changing the meaning of an existing key. Add a new key when semantics
  change materially.
* Add tests that load representative old plain-data configs against the new
  schema.

ModelOpt does not yet define a formal compatibility window for every config
surface, so schema authors should document compatibility-sensitive changes in
the owning feature area.
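A minimal sketch of the ``mode="before"`` normalization rule, using a
hypothetical field and legacy spelling purely for illustration:

.. code-block:: python

    from pydantic import field_validator

    from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField


    class ExampleSectionConfig(ModeloptBaseConfig):
        # Hypothetical field; suppose older persisted configs spelled the
        # value "maximum" while the canonical form is now "max".
        method: str = ModeloptField(default="max", title="Method")

        @field_validator("method", mode="before")
        @classmethod
        def _accept_legacy_spelling(cls, v):
            # Accept the legacy form on input; model_dump() then emits the
            # canonical spelling.
            return "max" if v == "maximum" else v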
Composable YAML
===============

Python already has composition through variables, functions, imports, and
mutation. YAML does not. ModelOpt's YAML composition layer exists so repeated
YAML fragments can be shared without moving the canonical config into Python.

Typical repeated fragments include:

* one numeric format used by several quantizer entries;
* one complete quantizer-entry snippet reused in many configs;
* a list of quantizer entries reused as a unit;
* a snippet that depends on another snippet;
* related variants such as dynamic and static numeric formats.

The chosen design is a small YAML-native DSL: a file-local ``imports`` mapping
binds names to YAML files, and inline ``$import`` references insert those
resolved snippets into the data tree. Python remains responsible for schema
validation; YAML remains data.


Alternatives considered
-----------------------

The main alternative is to move more composition knowledge into Python, either
through hard-coded fragment registries, Python-owned name-to-file mappings, or
factory-style configs. Those approaches are useful for object construction, but
they make ordinary YAML reuse depend on Python edits or make Python callables
part of the canonical config representation.

ModelOpt uses a small YAML DSL instead: each file declares its own imports,
references them with ``$import``, and resolves to plain data before validation.
This keeps the import graph self-describing, lets config authors add reusable
fragments as YAML, and still validates every resolved value against Python
schemas.


Import declarations
-------------------

Imports are declared once per YAML file:

.. code-block:: yaml

    imports:
      nvfp4: configs/numerics/nvfp4
      kv_fp8: configs/ptq/units/kv_fp8

The names are scoped to that file. An imported snippet may declare its own
``imports`` block, and those names are scoped to the snippet file. Recursive
imports are resolved depth-first. Circular imports are detected using canonical
resolved paths and fail with ``ValueError``.

A file that declares no ``imports`` may not contain ``$import`` markers. This
keeps authoring mistakes explicit: an unknown reference fails instead of being
left as literal data.


Dict imports
------------

When ``$import`` appears inside a mapping, the imported mapping is copied into
the current mapping. Inline keys override imported keys at that same mapping
level:

.. code-block:: yaml

    cfg:
      $import: nvfp4
      block_sizes:
        -1: 16
        type: static
        scale_bits: e4m3

Multiple imports are applied in order, then inline keys are applied last:

.. code-block:: yaml

    cfg:
      $import: [base_format, override_format]
      axis: 0

The merge is shallow at the mapping where ``$import`` appears. If one nested
leaf changes, provide the complete nested value inline or define a named snippet
for that variant. This avoids hidden deep-merge rules that are hard to review.


List imports
------------

List imports are type-directed. For a containing list with schema ``list[T]``:

* importing a snippet with schema ``list[T]`` splices all imported entries into
  the containing list;
* importing a snippet with schema ``T`` appends the imported object as a single
  list element;
* importing any other schema raises an error;
* importing into an untyped list raises an error.

Example:

.. code-block:: yaml

    quant_cfg:
      - $import: base_disable_all    # QuantizerCfgEntry, appended
      - quantizer_name: '*weight_quantizer'
        cfg:
          $import: nvfp4             # QuantizerAttributeConfig, dict import
      - $import: kv_fp8              # QuantizerCfgListConfig, spliced

A list-entry import must be a mapping whose only key is ``$import``. If an entry
needs local changes, either write that entry inline or create a snippet for the
variant.


Multi-document list snippets
----------------------------

A YAML file has one root node per document. A list-valued snippet that also
needs an ``imports`` block therefore uses two YAML documents: the first document
holds import declarations, and the second document holds the list payload.

.. code-block:: yaml

    # modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig
    imports:
      fp8: configs/numerics/fp8
    ---
    - quantizer_name: '*[kv]_bmm_quantizer'
      cfg:
        $import: fp8

Only ``imports`` from the first document is meaningful for a list snippet. The
loader resolves imports in the second document and returns the resolved list.


Composition error model
-----------------------

The loader raises ``ValueError`` for invalid composition, including:

* ``imports`` is not a mapping;
* an import path is empty or cannot be resolved;
* a ``$import`` reference is not listed in the file-local ``imports`` mapping;
* an imported snippet does not declare ``modelopt-schema``;
* a schema path does not resolve under ``modelopt.``;
* an imported snippet does not validate against its declared schema;
* a list import has no typed containing list;
* a list import schema is neither the containing list schema nor its element
  schema;
* a circular import is detected.

These failures are load-time errors by design. A composed config should either
resolve to valid plain data or fail before the owning optimization pass starts.
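Because all of these surface as ``ValueError`` at load time, call sites can fail
fast before any model work begins. A minimal sketch, reusing the illustrative
preset path from earlier:

.. code-block:: python

    from modelopt.recipe import load_config
    from modelopt.torch.quantization.config import QuantizeConfig

    try:
        data = load_config("configs/ptq/presets/model/fp8", schema_type=QuantizeConfig)
    except ValueError as err:
        # Unresolvable imports, schema mismatches, and circular imports all land here.
        raise SystemExit(f"Config composition failed: {err}") from err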
Consumers of the config system
==============================

The config system is shared infrastructure. Current consumers include:

* lower-level optimization configs such as PTQ ``QuantizeConfig``;
* built-in YAML config snippets under ``modelopt_recipes/configs``;
* higher-level recipes, which package metadata together with one or more
  type-specific config sections.

Recipes do not define separate config semantics. ``load_recipe()`` is a
consumer-specific wrapper: it uses ``load_config()`` to resolve YAML, supplies
schema context for each recipe section, and then constructs a typed recipe
object. The general contract remains the same: YAML authoring data resolves to
plain Python data, and Python schemas validate the result.


Authoring guidelines
====================

When adding config schemas or YAML files:

* Put the canonical schema in Python, not in YAML comments or loader logic.
* Use ``ModeloptBaseConfig`` for structured config objects that need methods,
  defaults, and validators.
* Use ``ModeloptBaseConfig`` subclasses or typed aliases for reusable snippets.
* Prefer self-contained YAML unless a fragment is reused or factoring materially
  improves reviewability.
* Add ``# modelopt-schema: ...`` to every file that can be referenced from an
  ``imports`` block.
* Keep top-level user config files free of schema comments unless they are also
  intended to be imported as snippets.
* Use a concrete typed list schema for list snippets so append-vs-splice
  behavior is unambiguous.
* Serialize resolved configs with ``model_dump()`` for long-term artifacts.
* Store plain config data, not authoring-time YAML paths, in checkpoints.
* Do not parse ModelOpt config YAML with raw YAML APIs in application code. Use
  ``load_config()`` or a higher-level API built on it so imports, schema checks,
  and ``eXmY`` conversion are applied consistently, as shown in the sketch
  below.
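A minimal sketch of that last guideline; the file name is illustrative:

.. code-block:: python

    import yaml

    from modelopt.recipe import load_config
    from modelopt.torch.quantization.config import QuantizeConfig

    # Discouraged: a raw parse skips import resolution, modelopt-schema checks,
    # and eXmY conversion, so a value like "e4m3" stays a plain string.
    with open("my_quant.yaml", encoding="utf-8") as f:
        raw = yaml.safe_load(f)

    # Preferred: resolve through the loader, then validate with the owning schema.
    data = load_config("my_quant.yaml", schema_type=QuantizeConfig)
    cfg = QuantizeConfig.model_validate(data)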