diff --git a/modelopt/onnx/llm_export_utils/quantization_utils.py b/modelopt/onnx/llm_export_utils/quantization_utils.py index 54ca93d5388..d6c8c4c1e9a 100644 --- a/modelopt/onnx/llm_export_utils/quantization_utils.py +++ b/modelopt/onnx/llm_export_utils/quantization_utils.py @@ -69,9 +69,7 @@ def get_quant_config(precision, lm_head_precision="fp16"): else: raise ValueError(f"Unsupported precision: {precision}") - quant_cfg_list: list = [ - e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_name" in e - ] + quant_cfg_list: list = [e for e in quant_cfg["quant_cfg"] if "quantizer_name" in e] if lm_head_precision == "fp8": quant_cfg_list.append( diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index a3af43a914b..749d80a933d 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -20,11 +20,10 @@ import warnings from enum import Enum -from pydantic import field_validator, model_validator -from typing_extensions import NotRequired, TypedDict +from pydantic import Field, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField -from modelopt.torch.quantization.config import QuantizeConfig +from modelopt.torch.quantization.config import QuantizeConfig # noqa: TC001 from modelopt.torch.speculative.config import DFlashConfig, EagleConfig, MedusaConfig from modelopt.torch.speculative.plugins.hf_training_args import DataArguments as SpecDataArgs from modelopt.torch.speculative.plugins.hf_training_args import ModelArguments as SpecModelArgs @@ -43,14 +42,21 @@ class RecipeType(str, Enum): # QAT = "qat" # Not implemented yet, will be added in the future. -class RecipeMetadataConfig(TypedDict): - """YAML shape of the recipe metadata section.""" +_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe." - recipe_type: RecipeType - description: NotRequired[str] +class RecipeMetadataConfig(ModeloptBaseConfig): + """YAML shape of the recipe metadata section.""" -_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe." + recipe_type: RecipeType = Field( + title="Recipe type", + description="The type of the recipe (e.g. PTQ).", + ) + description: str = ModeloptField( + default=_DEFAULT_RECIPE_DESCRIPTION, + title="Description", + description="Human-readable description of the recipe.", + ) def _metadata_field(recipe_type: RecipeType): @@ -69,45 +75,32 @@ class ModelOptRecipeBase(ModeloptBaseConfig): If a layer name matches ``"*output_layer*"``, the attributes will be replaced with ``{"enable": False}``. """ - metadata: RecipeMetadataConfig = ModeloptField( - default={"recipe_type": RecipeType.PTQ, "description": _DEFAULT_RECIPE_DESCRIPTION}, + metadata: RecipeMetadataConfig = Field( title="Metadata", - description="Recipe metadata containing the recipe type and description.", - validate_default=True, + description="Recipe metadata containing the recipe type and description. " + "Required: a recipe without a ``metadata`` section is rejected so that a " + "missing section can't silently fall back to a default recipe type.", ) - @field_validator("metadata") - @classmethod - def validate_metadata(cls, metadata: RecipeMetadataConfig) -> RecipeMetadataConfig: - """Validate recipe metadata and fill defaults for optional fields.""" - if metadata["recipe_type"] not in RecipeType: - raise ValueError( - f"Unsupported recipe type: {metadata['recipe_type']}. " - f"Only {list(RecipeType)} are currently supported." - ) - return {"description": _DEFAULT_RECIPE_DESCRIPTION, **metadata} - @property def recipe_type(self) -> RecipeType: """Return the recipe type from metadata.""" - return self.metadata["recipe_type"] + return self.metadata.recipe_type @property def description(self) -> str: """Return the recipe description from metadata.""" - return self.metadata.get("description", _DEFAULT_RECIPE_DESCRIPTION) + return self.metadata.description class ModelOptPTQRecipe(ModelOptRecipeBase): """Our config class for PTQ recipes.""" - metadata: RecipeMetadataConfig = _metadata_field(RecipeType.PTQ) - - quantize: QuantizeConfig = ModeloptField( - default=QuantizeConfig(), + quantize: QuantizeConfig = Field( title="PTQ config", - description="PTQ config containing quant_cfg and algorithm.", - validate_default=True, + description="PTQ config containing quant_cfg and algorithm. Required: a PTQ " + "recipe without a ``quantize`` section is rejected so that a missing section " + "can't silently fall back to the default INT8 config.", ) diff --git a/modelopt/recipe/loader.py b/modelopt/recipe/loader.py index 7be2abac19b..0a9218ff7d0 100644 --- a/modelopt/recipe/loader.py +++ b/modelopt/recipe/loader.py @@ -29,9 +29,6 @@ from .config import ( RECIPE_TYPE_TO_CLASS, - ModelOptDFlashRecipe, - ModelOptEagleRecipe, - ModelOptMedusaRecipe, ModelOptPTQRecipe, ModelOptRecipeBase, RecipeMetadataConfig, @@ -40,6 +37,16 @@ __all__ = ["load_config", "load_recipe"] +# Each recipe type's mandatory top-level body section. Checked at the loader level (on the +# raw YAML, before pydantic fills in defaults) so the user sees a clear "PTQ recipe file X +# must contain 'quantize'" instead of pydantic's generic missing-field error. +_REQUIRED_SECTION_PER_RECIPE_TYPE: dict[RecipeType, str] = { + RecipeType.PTQ: "quantize", + RecipeType.SPECULATIVE_EAGLE: "eagle", + RecipeType.SPECULATIVE_DFLASH: "dflash", + RecipeType.SPECULATIVE_MEDUSA: "medusa", +} + def _resolve_recipe_path(recipe_path: str | Path | Traversable) -> Path | Traversable: """Resolve a recipe path, checking the built-in library first then the filesystem. @@ -148,63 +155,48 @@ def _load_recipe_from_file( plus the algorithm-specific section (``quantize`` / ``eagle`` / ``dflash`` / ``medusa``). """ rtype = _peek_recipe_type(recipe_file) - schema_type = RECIPE_TYPE_TO_CLASS.get(rtype) if rtype is not None else None - data = load_config(recipe_file, schema_type=schema_type) - if not isinstance(data, dict): - raise ValueError( - f"Recipe file {recipe_file} must be a YAML mapping, got {type(data).__name__}." - ) + if rtype is None: + raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.") + schema_class = RECIPE_TYPE_TO_CLASS.get(rtype) + if schema_class is None: + raise ValueError(f"Unsupported recipe type: {rtype!r}") + + # Pre-flight check on the *raw* YAML so the user sees a clear loader-level error + # rather than a generic pydantic missing-field error. Speculative recipes' body + # sections have field-level defaults, so this check is what keeps their loader + # semantics consistent with PTQ. + required_section = _REQUIRED_SECTION_PER_RECIPE_TYPE.get(rtype) + if required_section is not None: + import yaml + + raw = yaml.safe_load(recipe_file.read_text()) or {} + if not isinstance(raw, dict) or required_section not in raw: + kind = ( + rtype.value.split("_", 1)[-1].upper() if "_" in rtype.value else rtype.value.upper() + ) + raise ValueError(f"{kind} recipe file {recipe_file} must contain {required_section!r}.") + + # Passing ``schema_type=schema_class`` to ``load_config`` enables typed-list + # ``$import`` resolution (e.g. ``$import: disable_all`` spliced into + # ``quantize.quant_cfg`` needs to know the list's element schema is + # :class:`QuantizerCfgEntry`). The return value is already a validated schema + # instance. if overrides: + # Overrides have to be applied before pydantic validation. Round-trip through + # ``model_dump()`` so $imports are resolved and the dict has the resolved shape; + # then splice the dotlist values and re-validate. + recipe = load_config(recipe_file, schema_type=schema_class) + data = recipe.model_dump() data = _apply_dotlist(data, overrides) + return schema_class.model_validate(data) - metadata = data.get("metadata", {}) - if not isinstance(metadata, dict): + recipe = load_config(recipe_file, schema_type=schema_class) + if not isinstance(recipe, schema_class): raise ValueError( - f"Recipe file {recipe_file} field 'metadata' must be a mapping, " - f"got {type(metadata).__name__}." - ) - recipe_type = metadata.get("recipe_type") - if recipe_type is None: - raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.") - - if recipe_type == RecipeType.PTQ: - if "quantize" not in data: - raise ValueError(f"PTQ recipe file {recipe_file} must contain 'quantize'.") - return ModelOptPTQRecipe( - metadata=metadata, - quantize=data["quantize"], + f"Recipe file {recipe_file} must produce a {schema_class.__name__}, " + f"got {type(recipe).__name__}." ) - if recipe_type == RecipeType.SPECULATIVE_EAGLE: - if "eagle" not in data: - raise ValueError(f"EAGLE recipe file {recipe_file} must contain 'eagle'.") - return ModelOptEagleRecipe( - metadata=metadata, - model=data.get("model") or {}, - data=data.get("data") or {}, - training=data.get("training") or {}, - eagle=data["eagle"], - ) - if recipe_type == RecipeType.SPECULATIVE_DFLASH: - if "dflash" not in data: - raise ValueError(f"DFlash recipe file {recipe_file} must contain 'dflash'.") - return ModelOptDFlashRecipe( - metadata=metadata, - model=data.get("model") or {}, - data=data.get("data") or {}, - training=data.get("training") or {}, - dflash=data["dflash"], - ) - if recipe_type == RecipeType.SPECULATIVE_MEDUSA: - if "medusa" not in data: - raise ValueError(f"Medusa recipe file {recipe_file} must contain 'medusa'.") - return ModelOptMedusaRecipe( - metadata=metadata, - model=data.get("model") or {}, - data=data.get("data") or {}, - training=data.get("training") or {}, - medusa=data["medusa"], - ) - raise ValueError(f"Unsupported recipe type: {recipe_type!r}") + return recipe def _find_recipe_section_file( @@ -229,25 +221,10 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase: quantize. """ metadata_file = _find_recipe_section_file(recipe_dir, "metadata") - metadata = load_config(metadata_file, schema_type=RecipeMetadataConfig) - if not isinstance(metadata, dict): - raise ValueError( - f"Metadata file {metadata_file} must be a YAML mapping, got {type(metadata).__name__}." - ) - recipe_type = metadata.get("recipe_type") - if recipe_type is None: - raise ValueError(f"Metadata file {metadata_file} must contain a 'recipe_type' field.") - if recipe_type == RecipeType.PTQ: + if metadata.recipe_type == RecipeType.PTQ: quantize_file = _find_recipe_section_file(recipe_dir, "quantize") - quantize_data = load_config(quantize_file, schema_type=QuantizeConfig) - if not isinstance(quantize_data, dict): - raise ValueError( - f"{quantize_file} must be a YAML mapping, got {type(quantize_data).__name__}." - ) - return ModelOptPTQRecipe( - metadata=metadata, - quantize=quantize_data, - ) - raise ValueError(f"Unsupported recipe type: {recipe_type!r}") + quantize_cfg = load_config(quantize_file, schema_type=QuantizeConfig) + return ModelOptPTQRecipe(metadata=metadata, quantize=quantize_cfg) + raise ValueError(f"Unsupported recipe type: {metadata.recipe_type!r}") diff --git a/modelopt/torch/opt/config.py b/modelopt/torch/opt/config.py index 62f7b7e16a2..fce2eb36f6b 100644 --- a/modelopt/torch/opt/config.py +++ b/modelopt/torch/opt/config.py @@ -17,7 +17,7 @@ import fnmatch import json -from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView +from collections.abc import Callable, ItemsView, Iterator, KeysView, MutableMapping, ValuesView from typing import Any, TypeAlias import torch @@ -57,11 +57,18 @@ def ModeloptField(default: Any = PydanticUndefined, **kwargs): # noqa: N802 # TODO: expand config classes to searcher -class ModeloptBaseConfig(BaseModel): +class ModeloptBaseConfig(BaseModel, MutableMapping): """Our config base class for mode configuration. The base class extends the capabilities of pydantic's BaseModel to provide additional methods and properties for easier access and manipulation of the configuration. + + Inherits from :class:`collections.abc.MutableMapping` so instances satisfy + ``isinstance(cfg, Mapping)`` / ``isinstance(cfg, MutableMapping)`` checks and pick up the + mixin methods (``pop``, ``popitem``, ``setdefault``, ``clear``). Schema fields are fixed, + so ``__delitem__`` raises :class:`TypeError`; the inherited ``pop`` / ``clear`` / + ``popitem`` therefore also raise on any existing key, while ``pop(key, default)`` for a + missing key still returns the default normally. """ model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True) @@ -110,18 +117,49 @@ def __contains__(self, key: str) -> bool: return False def __getitem__(self, key: str) -> Any: - """Get the value for the given key (can be name or alias of field).""" - return getattr(self, self.get_field_name_from_key(key)) + """Get the value for the given key (can be name or alias of field). + + Raises :class:`KeyError` for missing keys so the class behaves like a regular + :class:`Mapping` — required for the inherited ``MutableMapping`` mixin methods + (``pop``, ``setdefault``, ...) to dispatch correctly. + """ + try: + return getattr(self, self.get_field_name_from_key(key)) + except AttributeError: + raise KeyError(key) from None def __setitem__(self, key: str, value: Any) -> None: - """Set the value for the given key (can be name or alias of field).""" - setattr(self, self.get_field_name_from_key(key), value) + """Set the value for the given key (can be name or alias of field). + + Raises :class:`KeyError` (not :class:`AttributeError`) for unknown keys so the + class matches the :class:`MutableMapping` protocol — both for direct + ``cfg["unknown"] = value`` writes and for inherited mixin helpers like + ``setdefault`` that write through ``__setitem__``. + """ + try: + setattr(self, self.get_field_name_from_key(key), value) + except AttributeError: + raise KeyError(key) from None + + def __delitem__(self, key: str) -> None: + """Reject key deletion. + + ``ModeloptBaseConfig`` exposes a fixed pydantic schema, so removing a key is + ill-defined: schema fields can't disappear, and silently resetting them to their + defaults would surprise callers. Raise ``TypeError`` instead. Defined so the + class fully satisfies the ``MutableMapping`` protocol (``__delitem__`` is + required), without committing to actual deletion semantics. + """ + raise TypeError( + f"{type(self).__name__} does not support key deletion; schema fields are " + f"fixed (attempted to delete {key!r})." + ) def get(self, key: str, default: Any = None) -> Any: """Get the value for the given key (can be name or alias) or default if not found.""" try: return self[key] - except AttributeError: + except KeyError: return default def __len__(self) -> int: diff --git a/modelopt/torch/opt/config_loader.py b/modelopt/torch/opt/config_loader.py index 43231c90995..76ed2bb6503 100644 --- a/modelopt/torch/opt/config_loader.py +++ b/modelopt/torch/opt/config_loader.py @@ -33,12 +33,14 @@ import re import sys from pathlib import Path -from typing import Any, Union, get_args, get_origin, get_type_hints +from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints, overload import yaml from pydantic import TypeAdapter from typing_extensions import NotRequired, Required, is_typeddict +from modelopt.torch.opt.config import ModeloptBaseConfig + @dataclass class _ListSnippet: @@ -592,29 +594,74 @@ def _find_import_marker(obj: Any, context: str = "root") -> tuple[Any, str] | No return None +_SchemaT = TypeVar("_SchemaT", bound=ModeloptBaseConfig) + + +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: type[_SchemaT], +) -> _SchemaT: ... + + +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: type[list[_SchemaT]], +) -> list[_SchemaT]: ... + + +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: None = None, +) -> Any: ... + + def load_config( config_path: str | Path | Traversable, *, schema_type: Any | None = None, -) -> dict[str, Any] | list[Any]: +) -> Any: """Load a YAML config and resolve all ``$import`` references. This is the primary config loading entry point. It loads the YAML file, - resolves any ``imports`` / ``$import`` directives, and returns the final - config dict or list. - - ``schema_type`` supplies a typing context for import resolution when the - file itself has no ``modelopt-schema`` comment. It is intentionally not a - request to validate the top-level file. Top-level files are validated only - when they declare ``modelopt-schema``; imported snippets are stricter and - must always declare ``modelopt-schema``. + resolves any ``imports`` / ``$import`` directives, and returns either a + validated instance of the schema (when one is known) or the raw resolved + payload. + + The effective schema is selected as follows: + + 1. If ``schema_type`` is provided, it is used. + 2. Otherwise, the schema declared by the file's ``# modelopt-schema:`` + comment (if any) is used. + + When an effective schema is selected, the resolved payload is validated + and returned as an instance of that schema — e.g., a Pydantic model + instance for ``BaseModel`` schemas, or a validated dict / list for + ``TypedDict`` / ``list[TypedDict]`` schemas. If neither source supplies a + schema, the raw resolved dict or list is returned unchanged. + + Imported snippets are stricter and must always declare ``modelopt-schema``; + they are validated during import resolution regardless of the top-level + selection above. """ raw = _load_raw_config_with_schema(config_path) data = raw.data declared_schema_type = _schema_type(raw.schema) if raw.schema else None - resolver_schema_type = declared_schema_type or schema_type + effective_schema_type = schema_type if schema_type is not None else declared_schema_type if isinstance(data, (_ListSnippet, dict)): - data = _resolve_imports(data, schema_type=resolver_schema_type) - _validate_modelopt_schema(raw.schema, data, raw.path, schema_type=declared_schema_type) - return data + data = _resolve_imports(data, schema_type=effective_schema_type) + if effective_schema_type is None: + return data + try: + return TypeAdapter(effective_schema_type).validate_python(data) + except Exception as exc: + raise ValueError( + f"Config file {raw.path} does not match modelopt-schema " + f"{_schema_label(effective_schema_type, raw.schema)!r}: {exc}" + ) from exc diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 992717983db..e4e633e36ae 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -40,7 +40,7 @@ from . import config as mtq_config from . import model_calib -from .config import QuantizeConfig, QuantizerAttributeConfig +from .config import QuantizeConfig, QuantizerAttributeConfig, QuantizerCfgEntry from .conversion import set_quantizer_by_cfg from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import is_quantized_linear @@ -129,7 +129,9 @@ def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | No # Disable KV Cache quantization # Currently KV Cache quantization is enabled for some quantization formats and disabled for others # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy - self.config.quant_cfg.append({"quantizer_name": "*output_quantizer", "enable": False}) + self.config.quant_cfg.append( + QuantizerCfgEntry(quantizer_name="*output_quantizer", enable=False) + ) self.compression = estimate_quant_compression(self.config) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index b450eb5fa0d..b0c3fb859b2 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -152,23 +152,94 @@ import copy import warnings -from typing import Any, Literal, cast +from collections.abc import Mapping, Sequence +from typing import Any, Literal from pydantic import AliasChoices, ValidationInfo, field_validator, model_validator -from typing_extensions import Required, TypedDict from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.config_loader import load_config from modelopt.torch.utils.network import ConstructorLike -class QuantizerCfgEntry(TypedDict, total=False): +class QuantizerCfgEntry(ModeloptBaseConfig): """A single entry in a ``quant_cfg`` list.""" - quantizer_name: Required[str] # matched against quantizer module names - parent_class: str | None # optional; filters by pytorch module class name (e.g. "nn.Linear") - cfg: dict[str, Any] | list[dict[str, Any]] | None # quantizer attribute config(s) - enable: bool | None # toggles matched quantizers on/off; independent of cfg + quantizer_name: str = ModeloptField( + default=..., + title="Quantizer name pattern.", + description="Glob pattern matched against quantizer module names.", + ) + parent_class: str | None = ModeloptField( + default=None, + title="Optional parent-class filter.", + description="If provided, only quantizers whose parent module matches this PyTorch class " + "name (e.g. ``'nn.Linear'``) are affected.", + ) + cfg: "QuantizerAttributeConfig | list[QuantizerAttributeConfig] | None" = ModeloptField( + default=None, + title="Quantizer attribute config.", + description="A :class:`QuantizerAttributeConfig` (or a mapping that validates as one), " + "or a list of such for sequential quantizers. ``None`` leaves the existing attribute " + "config untouched.", + ) + enable: bool = ModeloptField( + default=True, + title="Enable the quantizer.", + description="Toggle matched quantizers on/off; independent of ``cfg``.", + ) + + @model_validator(mode="before") + @classmethod + def _normalize_cfg_shape(cls, values): + """Pre-validation shape rules for ``cfg``. + + Runs against the raw input mapping, before pydantic coerces ``cfg`` into a + :class:`QuantizerAttributeConfig` (which would fill in schema defaults and erase the + distinction between "user typed nothing" and "user typed `{}`"). Two rules: + + 1. ``enable=False`` with an empty ``cfg`` — empty dict, empty list, or list of empty + dicts — is normalized to ``cfg=None``. Downstream applies any non-``None`` ``cfg`` + as a full quantizer-attribute replacement, so without this an entry like + ``{cfg: {}, enable: False}`` would reset attributes to schema defaults and a later + re-enable would bring the quantizer back with defaults instead of its original config. + + 2. ``enable=True`` (explicit or implicit) with an empty ``cfg`` — same shapes — is + rejected. Pydantic would otherwise coerce ``{}`` into ``QuantizerAttributeConfig()`` + with all defaults, silently turning a likely typo (``cfg: {}``) into "quantize with + schema defaults." Callers who really want defaults should drop ``cfg`` entirely and + rely on ``enable=True``; an empty ``cfg`` always indicates missing input. + """ + if not isinstance(values, dict): + return values + cfg = values.get("cfg") + cfg_is_empty = (isinstance(cfg, dict) and len(cfg) == 0) or ( + isinstance(cfg, list) + and (len(cfg) == 0 or all(isinstance(item, dict) and len(item) == 0 for item in cfg)) + ) + if cfg_is_empty: + if values.get("enable") is False: + values = {**values, "cfg": None} + else: + raise ValueError( + f"QuantizerCfgEntry 'cfg' must specify at least one quantizer attribute; " + f"got an empty mapping/list for quantizer " + f"{values.get('quantizer_name')!r}. To keep existing attributes, drop " + f"'cfg' and rely on 'enable=True'; to disable, set 'enable=False'." + ) + return values + + @model_validator(mode="after") + def _validate_instruction(self): + """Reject entries that carry no instruction beyond the path selector.""" + fields_set = self.model_fields_set + if "cfg" not in fields_set and "enable" not in fields_set: + raise ValueError( + f"QuantizerCfgEntry must specify 'cfg', 'enable', or both. An entry with only " + f"'quantizer_name'={self.quantizer_name!r} has no effect (implicit enable=True " + "is not allowed; set it explicitly)." + ) + return self def find_quant_cfg_entry_by_path( @@ -197,7 +268,7 @@ def find_quant_cfg_entry_by_path( """ result = None for entry in quant_cfg_list: - if isinstance(entry, dict) and entry.get("quantizer_name") == quantizer_name: + if entry.get("quantizer_name") == quantizer_name: result = entry if result is None: raise KeyError(f"No quant_cfg entry with quantizer_name={quantizer_name!r}") @@ -930,13 +1001,28 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): QuantizeQuantCfgType = list[QuantizerCfgEntry] QuantizerCfgListConfig = QuantizeQuantCfgType +# Pre-normalization input shape: either a sequence of already-validated +# :class:`QuantizerCfgEntry` instances, or a sequence of raw mappings (any of the legacy / +# new dict forms). Splitting the union into two ``Sequence[...]`` arms — rather than +# ``Sequence[QuantizerCfgEntry | Mapping[str, Any]]`` — keeps each arm covariant in its +# element type, so callers can pass ``list[QuantizerCfgEntry]`` or ``list[dict]`` without +# tripping invariance. +RawQuantizeQuantCfgType = Sequence[QuantizerCfgEntry] | Sequence[Mapping[str, Any]] + +# Legacy flat-dict input shape (``{"*": ..., "*weight_quantizer": ...}``). Accepted by +# ``normalize_quant_cfg_list`` for backward compatibility but emits a DeprecationWarning; +# new code should use a list of :class:`QuantizerCfgEntry`-shaped entries instead. +DeprecatedQuantCfgType = Mapping[str, Any] + _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None -def normalize_quant_cfg_list(v: dict | list) -> list[QuantizerCfgEntry]: - """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` dicts. +def normalize_quant_cfg_list( + v: RawQuantizeQuantCfgType | DeprecatedQuantCfgType, +) -> list[QuantizerCfgEntry]: + """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` instances. Supports the following input forms: @@ -951,35 +1037,19 @@ def normalize_quant_cfg_list(v: dict | list) -> list[QuantizerCfgEntry]: - Legacy ``nn.*``-scoped format: ``{"nn.": {"": }}`` — converted to a new-format entry with ``parent_class`` set. - **Validation** — an entry is rejected if it carries no instruction, i.e. it specifies neither - ``cfg`` nor ``enable``. Concretely, the following are invalid: - - - An empty entry ``{}``. - - An entry with only ``quantizer_name`` and no other keys — the only effect would be an - implicit ``enable=True``, which must be stated explicitly. - - An entry with ``enable=True`` (explicit or implicit) whose ``cfg`` is not a non-empty - ``dict`` or ``list`` — e.g. ``{"quantizer_name": "*", "cfg": {}}`` or - ``{"quantizer_name": "*", "cfg": 42}``. An enabled quantizer must have a valid - configuration. - - **Normalization** — after conversion and validation every entry is put into canonical form: - - - ``enable`` is set to ``True`` if not explicitly specified. - - ``cfg`` is set to ``None`` if not present in the entry. - - Every returned entry is therefore guaranteed to have the keys ``quantizer_name``, ``enable``, - and ``cfg`` (plus optionally ``parent_class``). + Each normalized dict is then constructed into a :class:`QuantizerCfgEntry`, whose own + validator enforces that every entry specifies ``cfg``, ``enable``, or both, and that any + ``cfg`` for an enabled quantizer is a non-empty dict or non-empty list of non-empty dicts. Args: v: A list of raw quant_cfg entries in any supported format, or a legacy flat dict. Returns: - A list of :class:`QuantizerCfgEntry` dicts in canonical normalized form. + A list of validated :class:`QuantizerCfgEntry` instances. Raises: - ValueError: If any entry has only ``quantizer_name`` with neither ``cfg`` nor ``enable``, - if ``enable=True`` with an empty or non-dict/list ``cfg``, or if the entry format - is not recognized. + ValueError: If any entry's shape is not recognized, or if it fails + :class:`QuantizerCfgEntry` validation (missing instruction or invalid ``cfg``). """ def _warn_legacy(): @@ -993,26 +1063,33 @@ def _warn_legacy(): ) # Legacy flat-dict format: {"*": {...}, "*weight_quantizer": {...}} → list of single-key dicts. - if isinstance(v, dict): + if isinstance(v, Mapping): _warn_legacy() v = [{k: val} for k, val in v.items()] + elif not isinstance(v, Sequence) or isinstance(v, (str, bytes)): + raise ValueError( + f"quant_cfg must be a sequence of entries (or a legacy flat mapping), got " + f"{type(v).__name__}: {v!r}." + ) - def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: - """Convert a single legacy key-value pair to one or more QuantizerCfgEntry dicts.""" + def _dict_to_entry(key: str, value) -> list[dict[str, Any]]: + """Convert a single legacy key-value pair to one or more entry dicts.""" # Legacy "default" key was a catch-all applied as "*" in the old conversion code. if key == "default": key = "*" if isinstance(key, str) and key.startswith("nn."): - if not isinstance(value, dict): - raise ValueError(f"For 'nn.*' scoped format, value must be a dict, got {value!r}") + if not isinstance(value, Mapping): + raise ValueError( + f"For 'nn.*' scoped format, value must be a mapping, got {value!r}" + ) # Support multi-key nn.*-scoped dicts by emitting one entry per sub-key. - entries: list[QuantizerCfgEntry] = [] + entries: list[dict[str, Any]] = [] for q_path, sub_cfg in value.items(): sub_cfg = dict(sub_cfg) enable = sub_cfg.pop("enable", None) cfg = sub_cfg or None - entry: QuantizerCfgEntry = { + entry: dict[str, Any] = { "parent_class": key, "quantizer_name": q_path, "cfg": cfg, @@ -1022,7 +1099,7 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: entries.append(entry) return entries else: - if isinstance(value, dict): + if isinstance(value, Mapping): cfg = {k: val for k, val in value.items() if k != "enable"} or None enable = value.get("enable") else: @@ -1036,15 +1113,21 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: result: list[QuantizerCfgEntry] = [] _warned_legacy = False for raw in v: - if isinstance(raw, dict) and "quantizer_name" in raw: - entries = [dict(raw)] # copy to avoid mutating caller's data - elif isinstance(raw, dict) and len(raw) == 1: + # Already-validated QuantizerCfgEntry instances (e.g. produced by load_config on a + # snippet schematized with `# modelopt-schema: QuantizerCfgEntry`, then spread into + # a quant_cfg list) are passed through unchanged. + if isinstance(raw, QuantizerCfgEntry): + result.append(raw) + continue + if isinstance(raw, Mapping) and "quantizer_name" in raw: + entries: list[dict[str, Any]] = [dict(raw)] # copy to avoid mutating caller's data + elif isinstance(raw, Mapping) and len(raw) == 1: key, val = next(iter(raw.items())) entries = [dict(e) for e in _dict_to_entry(key, val)] if not _warned_legacy: _warn_legacy() _warned_legacy = True - elif isinstance(raw, dict) and len(raw) > 1 and any(k.startswith("nn.") for k in raw): + elif isinstance(raw, Mapping) and len(raw) > 1 and any(k.startswith("nn.") for k in raw): # Legacy flat dict with nn.*-scoped keys mixed with other keys — expand all pairs. entries = [] for k, val in raw.items(): @@ -1055,42 +1138,10 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: else: raise ValueError(f"Invalid quant_cfg entry: {raw!r}.") - for entry in entries: - # Validate: must carry at least one instruction beyond the path selector. - if "cfg" not in entry and "enable" not in entry: - raise ValueError( - f"Invalid quant_cfg entry: {raw!r} — each entry must specify 'cfg', 'enable', " - "or both. An entry with only 'quantizer_name' has no effect (implicit " - "enable=True is not allowed; set it explicitly)." - ) - - # Validate: when cfg is present and enable=True, cfg must be a non-empty - # dict or list. An empty cfg would attempt to create a - # QuantizerAttributeConfig with no actual configuration. - cfg = entry.get("cfg") - enable = entry.get("enable", True) - if enable and cfg is not None: - if isinstance(cfg, dict): - is_invalid = len(cfg) == 0 - elif isinstance(cfg, list): - is_invalid = len(cfg) == 0 or any( - not isinstance(item, dict) or len(item) == 0 for item in cfg - ) - else: - is_invalid = True - if is_invalid: - raise ValueError( - f"Invalid quant_cfg entry: {raw!r} — 'cfg' must be a non-empty dict " - f"or a non-empty list of non-empty dicts when enabling a quantizer " - f"(got {type(cfg).__name__}: {cfg!r}). Either provide quantizer " - "attributes in 'cfg' or remove 'cfg' and set 'enable' explicitly." - ) - - # Normalize: make enable and cfg always explicit. - entry.setdefault("enable", True) - entry.setdefault("cfg", None) - - result.append(cast("QuantizerCfgEntry", entry)) + # Constructing each QuantizerCfgEntry runs its model_validator, which enforces the + # at-least-one-of('cfg', 'enable') and cfg-shape constraints. Defaults for absent + # 'cfg' / 'enable' are filled by the pydantic field defaults. + result.extend(QuantizerCfgEntry(**entry) for entry in entries) return result @@ -1112,27 +1163,18 @@ class QuantizeConfig(ModeloptBaseConfig): @field_validator("quant_cfg", mode="before") @classmethod - def normalize_quant_cfg(cls, v): - """Normalize quant_cfg entries: convert dict and tuple forms to QuantizerCfgEntry dicts.""" - if not isinstance(v, (list, dict)): - return v + def normalize_quant_cfg( + cls, v: RawQuantizeQuantCfgType | DeprecatedQuantCfgType + ) -> QuantizeQuantCfgType: + """Normalize raw quant_cfg input into a ``list[QuantizerCfgEntry]``. + + Delegates to :func:`normalize_quant_cfg_list`, which accepts every supported input + shape (new-format list, legacy single-key-dict list, legacy flat dict, and lists + containing already-validated ``QuantizerCfgEntry`` instances) and rejects anything + else with a clear ``ValueError`` before pydantic's field-type check would see it. + """ return normalize_quant_cfg_list(v) - @field_validator("quant_cfg", mode="after") - @classmethod - def validate_quant_cfg_entries(cls, v): - """Validate quantizer attribute configs to surface errors (e.g. invalid axis/block_sizes).""" - qac_fields = set(QuantizerAttributeConfig.model_fields.keys()) - for entry in v: - cfg = entry.get("cfg") - if cfg is None: - continue - cfgs = cfg if isinstance(cfg, list) else [cfg] - for c in cfgs: - if isinstance(c, dict) and qac_fields & set(c.keys()): - QuantizerAttributeConfig.model_validate(c) - return v - class CompressConfig(ModeloptBaseConfig): """Default configuration for ``compress`` mode.""" @@ -1157,15 +1199,24 @@ class _QuantizeExportConfig(ModeloptBaseConfig): """An empty config.""" -_base_disable_all: list[QuantizerCfgEntry] = [ - cast("QuantizerCfgEntry", load_config("configs/ptq/units/base_disable_all")) +# Shared snippet constants are dumped back to plain dicts before being spliced into +# the public quant config constants below. ``load_config`` returns validated +# ``QuantizerCfgEntry`` instances for schema-tagged files, but the public constants +# (``INT4_AWQ_CFG``, ``NVFP4_DEFAULT_CFG``, etc.) have always been raw dict/list trees; +# splatting schema instances into them would surprise callers that serialise the +# constants or do ``isinstance(entry, dict)`` checks. ``exclude_unset=True`` keeps the +# sparse YAML shape (only the explicitly set fields) so the dumped dicts are +# byte-identical to what authors wrote in the YAML snippets. +_base_disable_all: list[dict[str, Any]] = [ + load_config("configs/ptq/units/base_disable_all").model_dump(exclude_unset=True) ] -_default_disabled_quantizer_cfg: list[QuantizerCfgEntry] = load_config( - "configs/ptq/units/default_disabled_quantizers" -) +_default_disabled_quantizer_cfg: list[dict[str, Any]] = [ + entry.model_dump(exclude_unset=True) + for entry in load_config("configs/ptq/units/default_disabled_quantizers") +] -_mamba_moe_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ +_mamba_moe_disabled_quantizer_cfg: list[dict[str, Any]] = [ {"quantizer_name": "*fc1_latent_proj*", "enable": False}, # Skip Latent MOE {"quantizer_name": "*fc2_latent_proj*", "enable": False}, # Skip Latent MOE {"quantizer_name": "*q_proj*", "enable": False}, # Skip QKV Linear (HF naming) @@ -1212,7 +1263,9 @@ class _QuantizeExportConfig(ModeloptBaseConfig): "algorithm": "max", } -FP8_DEFAULT_CFG: dict[str, Any] = load_config("configs/ptq/presets/model/fp8") +FP8_DEFAULT_CFG: dict[str, Any] = load_config("configs/ptq/presets/model/fp8").model_dump( + exclude_unset=True +) MAMBA_MOE_FP8_AGGRESSIVE_CFG = { "quant_cfg": [ @@ -1457,7 +1510,9 @@ class _QuantizeExportConfig(ModeloptBaseConfig): # KV-cache configs are designed to be merged with a primary quantization config (e.g. # FP8_DEFAULT_CFG) that already contains _base_disable_all. They intentionally omit both # _base_disable_all and "algorithm" because these are provided by the primary config. -FP8_KV_CFG: dict[str, Any] = load_config("configs/ptq/presets/kv/fp8") +FP8_KV_CFG: dict[str, Any] = load_config("configs/ptq/presets/kv/fp8").model_dump( + exclude_unset=True +) FP8_AFFINE_KV_CFG = { "quant_cfg": [ @@ -1490,7 +1545,7 @@ def _nvfp4_selective_quant_cfg( algorithm: str | dict = "max", ) -> dict: """Build an NVFP4 config that quantizes only the specified layer patterns.""" - quant_cfg: list[QuantizerCfgEntry] = [] + quant_cfg: list[dict[str, Any]] = [] quant_cfg.extend(_base_disable_all) for pattern in layer_patterns: # Deep-copy the quantizer dict so each config constant gets its own instance. @@ -1759,7 +1814,7 @@ def _nvfp4_selective_quant_cfg( } -def need_calibration(config): +def need_calibration(config: QuantizeConfig | Mapping[str, Any]) -> bool: """Check if calibration is needed for the given config.""" if config["algorithm"] is not None and config["algorithm"] != "max": return True @@ -1767,8 +1822,8 @@ def need_calibration(config): def _not_dynamic(cfg): return cfg.get("enable", True) and cfg.get("type", "") != "dynamic" - quant_cfg: list = config.get("quant_cfg") or [] - quant_cfg = normalize_quant_cfg_list(quant_cfg) + raw_quant_cfg: RawQuantizeQuantCfgType | DeprecatedQuantCfgType = config.get("quant_cfg") or [] + quant_cfg: list[QuantizerCfgEntry] = normalize_quant_cfg_list(raw_quant_cfg) for entry in quant_cfg: name = entry["quantizer_name"] raw_cfg = entry.get("cfg") diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 3f97f8380be..40c2b8dbc7e 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -31,8 +31,8 @@ from .config import ( QuantizeConfig, - QuantizeQuantCfgType, QuantizerAttributeConfig, + RawQuantizeQuantCfgType, _QuantizeExportConfig, normalize_quant_cfg_list, ) @@ -215,7 +215,7 @@ def _replace_quant_module(model: nn.Module, version=None, registry=QuantModuleRe _replace_quant_module(getattr(model, name), version=version, registry=registry) -def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): +def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: RawQuantizeQuantCfgType): """Apply a quantization config list to the quantizers in ``quant_model``. ``quant_cfg`` is an **ordered list** of :class:`QuantizerCfgEntry <.config.QuantizerCfgEntry>` @@ -477,7 +477,7 @@ def set_quantizer_attributes_partial( @contextmanager -def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): +def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: RawQuantizeQuantCfgType): """Context manager that temporarily applies a quantization config and restores the original state on exit. Calls :func:`set_quantizer_by_cfg` on entry and reverts every diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index b6c872a7164..ce241150a3b 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -46,6 +46,10 @@ quantize: {} """ +CFG_RECIPE_MISSING_METADATA = """\ +quantize: {} +""" + CFG_RECIPE_MISSING_quantize = """\ metadata: recipe_type: ptq @@ -54,6 +58,7 @@ CFG_RECIPE_UNSUPPORTED_TYPE = """\ metadata: recipe_type: unknown_type +quantize: {} """ QUANTIZER_ATTRIBUTE_SCHEMA = ( @@ -176,18 +181,27 @@ def test_load_recipe_missing_recipe_type_raises(tmp_path): def test_load_recipe_missing_quantize_raises(tmp_path): - """load_recipe raises ValueError when quantize is absent for a PTQ recipe.""" + """A PTQ recipe missing the ``quantize`` section is rejected (no silent default).""" bad = tmp_path / "bad.yml" bad.write_text(CFG_RECIPE_MISSING_quantize) with pytest.raises(ValueError, match="quantize"): load_recipe(bad) +def test_load_recipe_missing_metadata_raises(tmp_path): + """A recipe missing the ``metadata`` section is rejected (no silent default).""" + bad = tmp_path / "bad.yml" + bad.write_text(CFG_RECIPE_MISSING_METADATA) + with pytest.raises(ValueError, match="metadata"): + load_recipe(bad) + + def test_load_recipe_unsupported_type_raises(tmp_path): """load_recipe raises ValueError for an unknown recipe_type.""" bad = tmp_path / "bad.yml" bad.write_text(CFG_RECIPE_UNSUPPORTED_TYPE) - with pytest.raises(ValueError, match="Unsupported recipe type"): + # Schema-driven validation reports the failure via the metadata schema's enum check. + with pytest.raises(ValueError, match="recipe_type"): load_recipe(bad) @@ -525,7 +539,7 @@ def test_import_resolves_cfg_reference(tmp_path): ) recipe = load_recipe(recipe_file) entry = recipe.quantize["quant_cfg"][0] - assert entry["cfg"] == {"num_bits": (4, 3), "axis": None} + assert entry["cfg"].model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": None} def test_import_same_name_used_twice(tmp_path): @@ -598,7 +612,10 @@ def test_import_inline_cfg_not_affected(tmp_path): f" axis: 0\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": 8, "axis": 0} + assert recipe.quantize["quant_cfg"][1]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": 8, + "axis": 0, + } def test_import_unknown_reference_raises(tmp_path): @@ -734,7 +751,15 @@ def test_import_entry_element_schema_appends(tmp_path): f" - $import: disable_all\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"] == [{"quantizer_name": "*", "cfg": None, "enable": False}] + # Entry was loaded against the QuantizerCfgEntry pydantic schema, so it is now a + # model instance — compare via model_dump for the dict-shape check. + assert len(recipe.quantize["quant_cfg"]) == 1 + assert recipe.quantize["quant_cfg"][0].model_dump() == { + "quantizer_name": "*", + "parent_class": None, + "cfg": None, + "enable": False, + } def test_import_entry_wrong_schema_raises(tmp_path): @@ -819,7 +844,7 @@ def test_import_cfg_extend(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "axis": 0} + assert cfg.model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": 0} def test_import_cfg_inline_overrides_import(tmp_path): @@ -882,6 +907,7 @@ def test_import_in_multiple_dict_values(tmp_path): ) data = load_config(config_file) entry = data["quant_cfg"][0] + # load_config has no schema here — data is a raw dict tree, so entry["cfg"] is a dict. assert entry["cfg"] == {"num_bits": (4, 3)} assert entry["my_field"] == {"fake_quant": False} @@ -906,7 +932,7 @@ def test_import_cfg_multi_import(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "axis": 0} + assert cfg.model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": 0} def test_import_cfg_multi_import_later_overrides_earlier(tmp_path): @@ -955,7 +981,11 @@ def test_import_cfg_multi_import_with_extend(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "fake_quant": False, "axis": 0} + assert cfg.model_dump(exclude_unset=True) == { + "num_bits": (4, 3), + "fake_quant": False, + "axis": 0, + } def test_import_dir_format(tmp_path): @@ -972,7 +1002,10 @@ def test_import_dir_format(tmp_path): " $import: fp8\n" ) recipe = load_recipe(tmp_path) - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3), "axis": None} + assert recipe.quantize["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (4, 3), + "axis": None, + } def test_import_dir_format_metadata_imports_do_not_apply_to_quantize(tmp_path): @@ -1026,7 +1059,9 @@ def test_import_multi_document_list_snippet(tmp_path): recipe = load_recipe(recipe_file) assert len(recipe.quantize["quant_cfg"]) == 1 assert recipe.quantize["quant_cfg"][0]["quantizer_name"] == "*[kv]_bmm_quantizer" - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert recipe.quantize["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (4, 3) + } def test_import_builtin_kv_fp8_snippet(): @@ -1075,7 +1110,8 @@ def test_import_list_splice_outside_typed_list_raises(tmp_path): """A bare $import in an untyped list is rejected.""" _write_quantizer_cfg_list( tmp_path / "extra_tasks.yml", - "- quantizer_name: '*weight_quantizer'\n- quantizer_name: '*input_quantizer'\n", + "- quantizer_name: '*weight_quantizer'\n enable: false\n" + "- quantizer_name: '*input_quantizer'\n enable: false\n", ) config_file = tmp_path / "config.yml" config_file.write_text( @@ -1137,9 +1173,16 @@ def test_import_mixed_tree(tmp_path): ) data = load_config(config_file) # Dict import inside list entry - assert data["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} - # List splice - assert data["quant_cfg"][1] == {"quantizer_name": "*lm_head*", "enable": False} + assert data["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": (4, 3)} + # List splice — entries are normalized by QuantizeConfig.quant_cfg's validator, + # which fills in defaults for missing ``enable`` / ``cfg`` keys. Entries are now + # QuantizerCfgEntry pydantic instances, so compare via model_dump. + assert data["quant_cfg"][1].model_dump() == { + "quantizer_name": "*lm_head*", + "parent_class": None, + "enable": False, + "cfg": None, + } # --------------------------------------------------------------------------- @@ -1178,7 +1221,7 @@ def test_import_recursive(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3)} + assert cfg.model_dump(exclude_unset=True) == {"num_bits": (4, 3)} def test_import_circular_raises(tmp_path): @@ -1278,9 +1321,14 @@ def test_import_cross_file_same_name_no_conflict(tmp_path): ) recipe = load_recipe(recipe_file) # Parent's "fmt" resolves to fp8 (e4m3), not child's nvfp4. - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert recipe.quantize["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (4, 3) + } # Child's "fmt" resolves to nvfp4 (e2m1), not parent's fp8. - assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": (2, 1), "axis": 0} + assert recipe.quantize["quant_cfg"][1]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (2, 1), + "axis": 0, + } # --------------------------------------------------------------------------- @@ -1311,8 +1359,10 @@ def test_builtin_config_snippets_with_modelopt_schema(config_path): assert data -def test_modelopt_schema_comment_validates_without_changing_payload(tmp_path): - """modelopt-schema validates the resolved payload but load_config still returns a plain dict.""" +def test_modelopt_schema_comment_returns_instance(tmp_path): + """A ``modelopt-schema`` comment makes load_config return an instance of that schema.""" + from modelopt.torch.quantization.config import QuantizerAttributeConfig + config_file = tmp_path / "fp8.yaml" config_file.write_text( "# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig\n" @@ -1320,7 +1370,9 @@ def test_modelopt_schema_comment_validates_without_changing_payload(tmp_path): "axis:\n" ) data = load_config(config_file) - assert data == {"num_bits": (4, 3), "axis": None} + assert isinstance(data, QuantizerAttributeConfig) + assert data.num_bits == (4, 3) + assert data.axis is None def test_modelopt_schema_comment_validation_error(tmp_path): @@ -1367,7 +1419,13 @@ def test_modelopt_schema_comment_validates_after_import_resolution(tmp_path): f" $import: fp8\n" ) data = load_config(config_file) - assert data == [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": (4, 3)}}] + # data is a list of QuantizerCfgEntry pydantic instances, not raw dicts. Dump with + # exclude_unset=True so the inner QuantizerAttributeConfig stays sparse (cascades). + assert len(data) == 1 + assert data[0].model_dump(exclude_unset=True) == { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": (4, 3)}, + } # --------------------------------------------------------------------------- @@ -1472,7 +1530,13 @@ def test_load_config_list_valued_yaml(tmp_path): data = load_config(cfg_file) assert isinstance(data, list) assert len(data) == 2 - assert data[0] == {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}} + # Entries are QuantizerCfgEntry pydantic instances after schema validation; dump + # with exclude_unset=True so the inner QuantizerAttributeConfig stays in sparse + # form (pydantic cascades exclude_unset to nested models). + assert data[0].model_dump(exclude_unset=True) == { + "quantizer_name": "*weight_quantizer", + "cfg": {"num_bits": 8}, + } # --------------------------------------------------------------------------- @@ -1484,7 +1548,8 @@ def test_import_dict_value_resolves_to_list_raises(tmp_path): """$import in dict value position raises when snippet is a list.""" _write_quantizer_cfg_list( tmp_path / "entries.yml", - "- quantizer_name: '*weight_quantizer'\n- quantizer_name: '*input_quantizer'\n", + "- quantizer_name: '*weight_quantizer'\n enable: false\n" + "- quantizer_name: '*input_quantizer'\n enable: false\n", ) config_file = tmp_path / "config.yml" config_file.write_text( diff --git a/tests/unit/torch/opt/test_config.py b/tests/unit/torch/opt/test_config.py index b2ffadb1a78..e0c5993a51a 100644 --- a/tests/unit/torch/opt/test_config.py +++ b/tests/unit/torch/opt/test_config.py @@ -72,7 +72,7 @@ def _run_test(is_new_registered): assert config[lin_name] == lin_expected_value assert config[lin_alias] == lin_expected_value assert getattr(config, lin_name) == lin_expected_value - with nullcontext() if is_new_registered else pytest.raises(AttributeError): + with nullcontext() if is_new_registered else pytest.raises(KeyError): config[new_name] # get diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 84306dc5116..ce98f989f51 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -81,7 +81,7 @@ def test_new_format_passthrough(self): result = normalize_quant_cfg_list(raw) assert len(result) == 1 assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} assert result[0]["enable"] is True # defaulted def test_new_format_enable_false(self): @@ -103,7 +103,7 @@ def test_legacy_single_key_dict(self): raw = [{"*weight_quantizer": {"num_bits": 8, "axis": 0}}] result = normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} assert result[0]["enable"] is True # defaulted def test_legacy_single_key_dict_with_enable(self): @@ -166,57 +166,101 @@ def test_error_on_multi_key_legacy_dict(self): def test_error_on_empty_cfg_dict_implicit_enable(self): """Entry with cfg={} and implicit enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list([{"quantizer_name": "*weight_quantizer", "cfg": {}}]) def test_error_on_empty_cfg_dict_explicit_enable_true(self): """Entry with cfg={} and explicit enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": {}, "enable": True}] ) def test_error_on_empty_cfg_list_enable_true(self): """Entry with cfg=[] and enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": [], "enable": True}] ) def test_error_on_non_dict_non_list_cfg_enable_true(self): - """Entry with cfg of invalid type (e.g. int) and enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + """Entry with cfg of invalid type (e.g. int) and enable=True is rejected. + + Two error paths are acceptable here, and the assertion accepts either: + pydantic's field-type check (``cfg`` must be a dict or list) fires first when + ``cfg`` is the wrong python type, while ``QuantizerCfgEntry``'s model validator + emits the "non-empty dict" message when ``cfg`` is the right type but empty. + Either way the message must implicate the ``cfg`` field, not just any + ``ValueError``. + """ + with pytest.raises(ValueError, match=r"(?s)cfg.*(non-empty|valid dictionary|valid list)"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": 42, "enable": True}] ) def test_error_on_cfg_list_with_empty_dict_enable_true(self): """Entry with cfg=[{}] and enable=True is rejected (empty dict element).""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": [{}], "enable": True}] ) def test_error_on_cfg_list_with_non_dict_element_enable_true(self): - """Entry with cfg=[42] and enable=True is rejected (non-dict element).""" - with pytest.raises(ValueError, match="non-empty dict"): + """Entry with cfg=[42] and enable=True is rejected. + + Same dual-path acceptance as :meth:`test_error_on_non_dict_non_list_cfg_enable_true`: + pydantic may report a list-element type error, or the model validator may report + "non-empty dict"; the assertion accepts either as long as the message names the + ``cfg`` field. + """ + with pytest.raises(ValueError, match=r"(?s)cfg.*(non-empty|valid dictionary|valid list)"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": [42], "enable": True}] ) - def test_empty_cfg_dict_enable_false_accepted(self): - """Entry with cfg={} and enable=False is allowed (disable-only entry).""" + def test_empty_cfg_dict_enable_false_normalized_to_none(self): + """Entry with cfg={} and enable=False is normalised to cfg=None (disable-only). + + A non-``None`` cfg is applied as a full quantizer-attribute replacement, so an + empty cfg paired with enable=False would silently reset the quantizer's + attributes. Normalisation to ``None`` makes the entry behave like a pure + disable, preserving the existing attribute config. + """ result = normalize_quant_cfg_list( [{"quantizer_name": "*input_quantizer", "cfg": {}, "enable": False}] ) assert result[0]["enable"] is False + assert result[0]["cfg"] is None - def test_empty_cfg_list_enable_false_accepted(self): - """Entry with cfg=[] and enable=False is allowed (disable-only entry).""" + def test_empty_cfg_list_enable_false_normalized_to_none(self): + """Entry with cfg=[] and enable=False is normalised to cfg=None.""" result = normalize_quant_cfg_list( [{"quantizer_name": "*input_quantizer", "cfg": [], "enable": False}] ) assert result[0]["enable"] is False + assert result[0]["cfg"] is None + + def test_cfg_list_of_empty_dicts_enable_false_normalized_to_none(self): + """Entry with cfg=[{}] and enable=False is normalised to cfg=None.""" + result = normalize_quant_cfg_list( + [{"quantizer_name": "*input_quantizer", "cfg": [{}], "enable": False}] + ) + assert result[0]["enable"] is False + assert result[0]["cfg"] is None + + def test_nonempty_cfg_enable_false_preserved(self): + """Entry with a non-empty cfg and enable=False keeps the cfg (disable+replace).""" + result = normalize_quant_cfg_list( + [ + { + "quantizer_name": "*input_quantizer", + "cfg": {"num_bits": 4}, + "enable": False, + } + ] + ) + assert result[0]["enable"] is False + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 4} def test_new_format_with_list_cfg(self): """cfg can be a list of dicts for SequentialQuantizer.""" @@ -231,7 +275,7 @@ def test_new_format_with_list_cfg(self): ] result = normalize_quant_cfg_list(raw) assert len(result) == 1 - assert result[0]["cfg"] == raw[0]["cfg"] + assert [c.model_dump(exclude_unset=True) for c in result[0]["cfg"]] == raw[0]["cfg"] assert result[0]["enable"] is True def test_legacy_flat_dict_conversion(self): @@ -243,7 +287,7 @@ def test_legacy_flat_dict_conversion(self): assert result[0]["enable"] is False assert result[0]["cfg"] is None assert result[1]["quantizer_name"] == "*weight_quantizer" - assert result[1]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[1]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} assert result[1]["enable"] is True def test_legacy_enable_only_produces_cfg_none(self): @@ -274,7 +318,7 @@ def test_legacy_default_key_with_cfg(self): raw = [{"default": {"num_bits": 8, "axis": None}}] result = normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*" - assert result[0]["cfg"] == {"num_bits": 8, "axis": None} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": None} assert result[0]["enable"] is True def test_legacy_flat_dict_with_default_key(self): @@ -309,7 +353,7 @@ def test_legacy_nn_class_with_cfg(self): assert len(result) == 1 assert result[0]["parent_class"] == "nn.Linear" assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 4, "axis": 0} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 4, "axis": 0} assert result[0]["enable"] is True def test_legacy_list_valued_cfg(self): @@ -343,7 +387,7 @@ def test_finds_last_match(self): ] ) result = find_quant_cfg_entry_by_path(entries, "*weight_quantizer") - assert result["cfg"] == {"num_bits": 4} + assert result["cfg"].model_dump(exclude_unset=True) == {"num_bits": 4} def test_exact_match_only(self): """Does not do fnmatch — only exact string equality on quantizer_name.""" @@ -400,7 +444,7 @@ def test_wildcard_matches_bare_name(self): [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}}] ) matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 8} + assert matched.model_dump(exclude_unset=True) == {"num_bits": 8} assert enable is True def test_star_matches_any_bare_name(self): @@ -420,7 +464,7 @@ def test_path_scoped_pattern_matches_matching_suffix(self): [{"quantizer_name": "*mlp*weight_quantizer", "cfg": {"num_bits": 4}}] ) matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 4} + assert matched.model_dump(exclude_unset=True) == {"num_bits": 4} def test_path_scoped_pattern_does_not_match_different_suffix(self): """'*mlp*weight_quantizer' does NOT match bare 'input_quantizer'.""" @@ -444,7 +488,7 @@ def test_last_match_wins(self): ] ) matched, _ = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 4} + assert matched.model_dump(exclude_unset=True) == {"num_bits": 4} def test_no_match_returns_none(self): """No matching entry returns (None, None)."""