From e21a9baab1fb6e36c0e4c680d0c8dcbff83030f8 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Wed, 13 May 2026 18:00:49 -0700 Subject: [PATCH 01/14] feat(opt): make load_config return validated schema instances load_config now returns a validated instance of the effective schema when one is known. The schema_type argument takes priority; otherwise the file's `# modelopt-schema:` comment is used. Without either, the raw resolved dict/list is returned unchanged. Imported snippets are still strictly required to declare modelopt-schema. Simplify modelopt.recipe.loader to consume the returned ModelOptPTQRecipe / QuantizeConfig instance directly, and update the recipe-loader tests for the new behavior (schema-comment returns an instance; missing `quantize` uses the schema default; unknown recipe_type surfaces via Pydantic's validation message; quant_cfg entries are normalized at load time). Signed-off-by: Shengliang Xu --- modelopt/recipe/loader.py | 35 ++++++++--------------- modelopt/torch/opt/config_loader.py | 44 ++++++++++++++++++++--------- tests/unit/recipe/test_loader.py | 36 +++++++++++++++-------- 3 files changed, 66 insertions(+), 49 deletions(-) diff --git a/modelopt/recipe/loader.py b/modelopt/recipe/loader.py index 9c3c40856d2..8608e3fcbb8 100644 --- a/modelopt/recipe/loader.py +++ b/modelopt/recipe/loader.py @@ -89,29 +89,15 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas The file must contain a ``metadata`` section with at least ``recipe_type``, plus a ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes. """ - data = load_config(recipe_file, schema_type=ModelOptPTQRecipe) - if not isinstance(data, dict): + recipe = load_config(recipe_file, schema_type=ModelOptPTQRecipe) + if not isinstance(recipe, ModelOptPTQRecipe): raise ValueError( - f"Recipe file {recipe_file} must be a YAML mapping, got {type(data).__name__}." + f"Recipe file {recipe_file} must produce a {ModelOptPTQRecipe.__name__}, " + f"got {type(recipe).__name__}." ) - - metadata = data.get("metadata", {}) - if not isinstance(metadata, dict): - raise ValueError( - f"Recipe file {recipe_file} field 'metadata' must be a mapping, " - f"got {type(metadata).__name__}." - ) - recipe_type = metadata.get("recipe_type") - if recipe_type is None: - raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.") - + recipe_type = recipe.recipe_type if recipe_type == RecipeType.PTQ: - if "quantize" not in data: - raise ValueError(f"PTQ recipe file {recipe_file} must contain 'quantize'.") - return ModelOptPTQRecipe( - metadata=metadata, - quantize=data["quantize"], - ) + return recipe raise ValueError(f"Unsupported recipe type: {recipe_type!r}") @@ -149,13 +135,14 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase: if recipe_type == RecipeType.PTQ: quantize_file = _find_recipe_section_file(recipe_dir, "quantize") - quantize_data = load_config(quantize_file, schema_type=QuantizeConfig) - if not isinstance(quantize_data, dict): + quantize_cfg = load_config(quantize_file, schema_type=QuantizeConfig) + if not isinstance(quantize_cfg, QuantizeConfig): raise ValueError( - f"{quantize_file} must be a YAML mapping, got {type(quantize_data).__name__}." + f"{quantize_file} must produce a {QuantizeConfig.__name__}, " + f"got {type(quantize_cfg).__name__}." ) return ModelOptPTQRecipe( metadata=metadata, - quantize=quantize_data, + quantize=quantize_cfg, ) raise ValueError(f"Unsupported recipe type: {recipe_type!r}") diff --git a/modelopt/torch/opt/config_loader.py b/modelopt/torch/opt/config_loader.py index 43231c90995..5dbf0ad5bf1 100644 --- a/modelopt/torch/opt/config_loader.py +++ b/modelopt/torch/opt/config_loader.py @@ -596,25 +596,43 @@ def load_config( config_path: str | Path | Traversable, *, schema_type: Any | None = None, -) -> dict[str, Any] | list[Any]: +) -> Any: """Load a YAML config and resolve all ``$import`` references. This is the primary config loading entry point. It loads the YAML file, - resolves any ``imports`` / ``$import`` directives, and returns the final - config dict or list. - - ``schema_type`` supplies a typing context for import resolution when the - file itself has no ``modelopt-schema`` comment. It is intentionally not a - request to validate the top-level file. Top-level files are validated only - when they declare ``modelopt-schema``; imported snippets are stricter and - must always declare ``modelopt-schema``. + resolves any ``imports`` / ``$import`` directives, and returns either a + validated instance of the schema (when one is known) or the raw resolved + payload. + + The effective schema is selected as follows: + + 1. If ``schema_type`` is provided, it is used. + 2. Otherwise, the schema declared by the file's ``# modelopt-schema:`` + comment (if any) is used. + + When an effective schema is selected, the resolved payload is validated + and returned as an instance of that schema — e.g., a Pydantic model + instance for ``BaseModel`` schemas, or a validated dict / list for + ``TypedDict`` / ``list[TypedDict]`` schemas. If neither source supplies a + schema, the raw resolved dict or list is returned unchanged. + + Imported snippets are stricter and must always declare ``modelopt-schema``; + they are validated during import resolution regardless of the top-level + selection above. """ raw = _load_raw_config_with_schema(config_path) data = raw.data declared_schema_type = _schema_type(raw.schema) if raw.schema else None - resolver_schema_type = declared_schema_type or schema_type + effective_schema_type = schema_type if schema_type is not None else declared_schema_type if isinstance(data, (_ListSnippet, dict)): - data = _resolve_imports(data, schema_type=resolver_schema_type) - _validate_modelopt_schema(raw.schema, data, raw.path, schema_type=declared_schema_type) - return data + data = _resolve_imports(data, schema_type=effective_schema_type) + if effective_schema_type is None: + return data + try: + return TypeAdapter(effective_schema_type).validate_python(data) + except Exception as exc: + raise ValueError( + f"Config file {raw.path} does not match modelopt-schema " + f"{_schema_label(effective_schema_type, raw.schema)!r}: {exc}" + ) from exc diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 738dfc268ca..bfbce21ed55 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -170,19 +170,22 @@ def test_load_recipe_missing_recipe_type_raises(tmp_path): load_recipe(bad) -def test_load_recipe_missing_quantize_raises(tmp_path): - """load_recipe raises ValueError when quantize is absent for a PTQ recipe.""" - bad = tmp_path / "bad.yml" - bad.write_text(CFG_RECIPE_MISSING_quantize) - with pytest.raises(ValueError, match="quantize"): - load_recipe(bad) +def test_load_recipe_missing_quantize_uses_default(tmp_path): + """``quantize`` is optional in a PTQ recipe; absence yields an empty default config.""" + from modelopt.torch.quantization.config import QuantizeConfig + + good = tmp_path / "good.yml" + good.write_text(CFG_RECIPE_MISSING_quantize) + recipe = load_recipe(good) + assert isinstance(recipe.quantize, QuantizeConfig) def test_load_recipe_unsupported_type_raises(tmp_path): """load_recipe raises ValueError for an unknown recipe_type.""" bad = tmp_path / "bad.yml" bad.write_text(CFG_RECIPE_UNSUPPORTED_TYPE) - with pytest.raises(ValueError, match="Unsupported recipe type"): + # Schema-driven validation reports the failure via the TypedDict's enum check. + with pytest.raises(ValueError, match="recipe_type"): load_recipe(bad) @@ -916,8 +919,13 @@ def test_import_mixed_tree(tmp_path): data = load_config(config_file) # Dict import inside list entry assert data["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} - # List splice - assert data["quant_cfg"][1] == {"quantizer_name": "*lm_head*", "enable": False} + # List splice — entries are normalized by QuantizeConfig.quant_cfg's validator, + # which fills in defaults for missing ``enable`` / ``cfg`` keys. + assert data["quant_cfg"][1] == { + "quantizer_name": "*lm_head*", + "enable": False, + "cfg": None, + } # --------------------------------------------------------------------------- @@ -1089,8 +1097,10 @@ def test_builtin_config_snippets_with_modelopt_schema(config_path): assert data -def test_modelopt_schema_comment_validates_without_changing_payload(tmp_path): - """modelopt-schema validates the resolved payload but load_config still returns a plain dict.""" +def test_modelopt_schema_comment_returns_instance(tmp_path): + """A ``modelopt-schema`` comment makes load_config return an instance of that schema.""" + from modelopt.torch.quantization.config import QuantizerAttributeConfig + config_file = tmp_path / "fp8.yaml" config_file.write_text( "# modelopt-schema: modelopt.torch.quantization.config.QuantizerAttributeConfig\n" @@ -1098,7 +1108,9 @@ def test_modelopt_schema_comment_validates_without_changing_payload(tmp_path): "axis:\n" ) data = load_config(config_file) - assert data == {"num_bits": (4, 3), "axis": None} + assert isinstance(data, QuantizerAttributeConfig) + assert data.num_bits == (4, 3) + assert data.axis is None def test_modelopt_schema_comment_validation_error(tmp_path): From 198a305b45dabc7d8202e7817004f8d77caade36 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 10:46:03 -0700 Subject: [PATCH 02/14] feat(quant): make QuantizerCfgEntry a ModeloptBaseConfig pydantic type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Convert QuantizerCfgEntry from a TypedDict to a ModeloptBaseConfig subclass so entries are validated at construction (cfg/enable shape rules now run via a model_validator on the schema itself) and load_config returns proper pydantic instances for snippets and lists schematized against QuantizerCfgEntry / QuantizerCfgListConfig. normalize_quant_cfg_list now passes already-validated QuantizerCfgEntry instances through unchanged, so the constants loaded from YAML (e.g. _base_disable_all, _default_disabled_quantizer_cfg) can be spread into preset configs and re-validated by QuantizeConfig without round-tripping through dicts. Consumers needed no functional changes because ModeloptBaseConfig already implements __getitem__, get, __contains__, items, keys, and update — covering every dict-shaped access site in modelopt/, examples/, and tests/. Type plumbing: - Introduce RawQuantizeQuantCfgType (a covariant Sequence) for input positions that get normalized — keeps set_quantizer_by_cfg and set_quantizer_by_cfg_context callable with both raw dict literals and pre-validated entries without invariance errors. - algorithms.AutoQuantizeSearcher now constructs a QuantizerCfgEntry instance instead of appending a bare dict to QuantizeConfig.quant_cfg. Tests: - test_config_validation: relaxed two match=non-empty dict assertions where pydantic's field-type check now fires before the model validator. - test_loader: dict-equality assertions on schema-loaded entries now use model_dump(); two YAML fixtures with bare quantizer_name entries (now rejected by the model validator at load time) had enable: false added. Signed-off-by: Shengliang Xu --- .../llm_export_utils/quantization_utils.py | 4 +- modelopt/torch/quantization/algorithms.py | 6 +- modelopt/torch/quantization/config.py | 169 ++++++++++-------- modelopt/torch/quantization/conversion.py | 6 +- tests/unit/recipe/test_loader.py | 39 +++- .../quantization/test_config_validation.py | 18 +- 6 files changed, 145 insertions(+), 97 deletions(-) diff --git a/modelopt/onnx/llm_export_utils/quantization_utils.py b/modelopt/onnx/llm_export_utils/quantization_utils.py index 54ca93d5388..d6c8c4c1e9a 100644 --- a/modelopt/onnx/llm_export_utils/quantization_utils.py +++ b/modelopt/onnx/llm_export_utils/quantization_utils.py @@ -69,9 +69,7 @@ def get_quant_config(precision, lm_head_precision="fp16"): else: raise ValueError(f"Unsupported precision: {precision}") - quant_cfg_list: list = [ - e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_name" in e - ] + quant_cfg_list: list = [e for e in quant_cfg["quant_cfg"] if "quantizer_name" in e] if lm_head_precision == "fp8": quant_cfg_list.append( diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 992717983db..e4e633e36ae 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -40,7 +40,7 @@ from . import config as mtq_config from . import model_calib -from .config import QuantizeConfig, QuantizerAttributeConfig +from .config import QuantizeConfig, QuantizerAttributeConfig, QuantizerCfgEntry from .conversion import set_quantizer_by_cfg from .nn import QuantLinearConvBase, QuantModule, SequentialQuantizer, TensorQuantizer from .utils import is_quantized_linear @@ -129,7 +129,9 @@ def __init__(self, quant_cfg: str | dict[str, Any] | None = None, name: str | No # Disable KV Cache quantization # Currently KV Cache quantization is enabled for some quantization formats and disabled for others # This breaks the monotonicity of the quantization formats in terms of weight compression Vs accuracy - self.config.quant_cfg.append({"quantizer_name": "*output_quantizer", "enable": False}) + self.config.quant_cfg.append( + QuantizerCfgEntry(quantizer_name="*output_quantizer", enable=False) + ) self.compression = estimate_quant_compression(self.config) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index dfed54cc991..6e0a54dd4b4 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -152,23 +152,70 @@ import copy import warnings -from typing import Any, Literal, cast +from collections.abc import Sequence +from typing import Any, Literal from pydantic import AliasChoices, ValidationInfo, field_validator, model_validator -from typing_extensions import Required, TypedDict from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.opt.config_loader import load_config from modelopt.torch.utils.network import ConstructorLike -class QuantizerCfgEntry(TypedDict, total=False): +class QuantizerCfgEntry(ModeloptBaseConfig): """A single entry in a ``quant_cfg`` list.""" - quantizer_name: Required[str] # matched against quantizer module names - parent_class: str | None # optional; filters by pytorch module class name (e.g. "nn.Linear") - cfg: dict[str, Any] | list[dict[str, Any]] | None # quantizer attribute config(s) - enable: bool | None # toggles matched quantizers on/off; independent of cfg + quantizer_name: str = ModeloptField( + default=..., + title="Quantizer name pattern.", + description="Glob pattern matched against quantizer module names.", + ) + parent_class: str | None = ModeloptField( + default=None, + title="Optional parent-class filter.", + description="If provided, only quantizers whose parent module matches this PyTorch class " + "name (e.g. ``'nn.Linear'``) are affected.", + ) + cfg: dict[str, Any] | list[dict[str, Any]] | None = ModeloptField( + default=None, + title="Quantizer attribute config.", + description="A ``QuantizerAttributeConfig``-shaped dict, or a list of such dicts for " + "sequential quantizers. ``None`` leaves the existing attribute config untouched.", + ) + enable: bool = ModeloptField( + default=True, + title="Enable the quantizer.", + description="Toggle matched quantizers on/off; independent of ``cfg``.", + ) + + @model_validator(mode="after") + def _validate_instruction(self): + """Reject entries that carry no instruction beyond the path selector.""" + fields_set = self.model_fields_set + if "cfg" not in fields_set and "enable" not in fields_set: + raise ValueError( + f"QuantizerCfgEntry must specify 'cfg', 'enable', or both. An entry with only " + f"'quantizer_name'={self.quantizer_name!r} has no effect (implicit enable=True " + "is not allowed; set it explicitly)." + ) + + if self.enable and self.cfg is not None: + if isinstance(self.cfg, dict): + is_invalid = len(self.cfg) == 0 + elif isinstance(self.cfg, list): + is_invalid = len(self.cfg) == 0 or any( + not isinstance(item, dict) or len(item) == 0 for item in self.cfg + ) + else: + is_invalid = True + if is_invalid: + raise ValueError( + f"QuantizerCfgEntry 'cfg' must be a non-empty dict or a non-empty list of " + f"non-empty dicts when enabling quantizer {self.quantizer_name!r}, got " + f"{type(self.cfg).__name__}: {self.cfg!r}. Either provide quantizer " + "attributes in 'cfg' or remove 'cfg' and set 'enable' explicitly." + ) + return self def find_quant_cfg_entry_by_path( @@ -197,7 +244,7 @@ def find_quant_cfg_entry_by_path( """ result = None for entry in quant_cfg_list: - if isinstance(entry, dict) and entry.get("quantizer_name") == quantizer_name: + if entry.get("quantizer_name") == quantizer_name: result = entry if result is None: raise KeyError(f"No quant_cfg entry with quantizer_name={quantizer_name!r}") @@ -930,13 +977,23 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): QuantizeQuantCfgType = list[QuantizerCfgEntry] QuantizerCfgListConfig = QuantizeQuantCfgType +# Pre-normalization input shape: a sequence whose entries can be raw dicts (any of the +# legacy / new dict forms) or already-validated QuantizerCfgEntry instances. +# ``Sequence`` (rather than ``list``) keeps the alias covariant so callers can pass +# ``list[QuantizerCfgEntry]`` without an invariance error. +# ``normalize_quant_cfg_list`` additionally accepts a single legacy flat ``dict`` for the +# whole list, but that path is deprecated and not surfaced in this alias. +RawQuantizeQuantCfgType = Sequence[QuantizerCfgEntry | dict[str, Any]] + _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None -def normalize_quant_cfg_list(v: dict | list) -> list[QuantizerCfgEntry]: - """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` dicts. +def normalize_quant_cfg_list( + v: RawQuantizeQuantCfgType | dict[str, Any], +) -> list[QuantizerCfgEntry]: + """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` instances. Supports the following input forms: @@ -951,35 +1008,19 @@ def normalize_quant_cfg_list(v: dict | list) -> list[QuantizerCfgEntry]: - Legacy ``nn.*``-scoped format: ``{"nn.": {"": }}`` — converted to a new-format entry with ``parent_class`` set. - **Validation** — an entry is rejected if it carries no instruction, i.e. it specifies neither - ``cfg`` nor ``enable``. Concretely, the following are invalid: - - - An empty entry ``{}``. - - An entry with only ``quantizer_name`` and no other keys — the only effect would be an - implicit ``enable=True``, which must be stated explicitly. - - An entry with ``enable=True`` (explicit or implicit) whose ``cfg`` is not a non-empty - ``dict`` or ``list`` — e.g. ``{"quantizer_name": "*", "cfg": {}}`` or - ``{"quantizer_name": "*", "cfg": 42}``. An enabled quantizer must have a valid - configuration. - - **Normalization** — after conversion and validation every entry is put into canonical form: - - - ``enable`` is set to ``True`` if not explicitly specified. - - ``cfg`` is set to ``None`` if not present in the entry. - - Every returned entry is therefore guaranteed to have the keys ``quantizer_name``, ``enable``, - and ``cfg`` (plus optionally ``parent_class``). + Each normalized dict is then constructed into a :class:`QuantizerCfgEntry`, whose own + validator enforces that every entry specifies ``cfg``, ``enable``, or both, and that any + ``cfg`` for an enabled quantizer is a non-empty dict or non-empty list of non-empty dicts. Args: v: A list of raw quant_cfg entries in any supported format, or a legacy flat dict. Returns: - A list of :class:`QuantizerCfgEntry` dicts in canonical normalized form. + A list of validated :class:`QuantizerCfgEntry` instances. Raises: - ValueError: If any entry has only ``quantizer_name`` with neither ``cfg`` nor ``enable``, - if ``enable=True`` with an empty or non-dict/list ``cfg``, or if the entry format - is not recognized. + ValueError: If any entry's shape is not recognized, or if it fails + :class:`QuantizerCfgEntry` validation (missing instruction or invalid ``cfg``). """ def _warn_legacy(): @@ -997,8 +1038,8 @@ def _warn_legacy(): _warn_legacy() v = [{k: val} for k, val in v.items()] - def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: - """Convert a single legacy key-value pair to one or more QuantizerCfgEntry dicts.""" + def _dict_to_entry(key: str, value) -> list[dict[str, Any]]: + """Convert a single legacy key-value pair to one or more entry dicts.""" # Legacy "default" key was a catch-all applied as "*" in the old conversion code. if key == "default": key = "*" @@ -1007,12 +1048,12 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: if not isinstance(value, dict): raise ValueError(f"For 'nn.*' scoped format, value must be a dict, got {value!r}") # Support multi-key nn.*-scoped dicts by emitting one entry per sub-key. - entries: list[QuantizerCfgEntry] = [] + entries: list[dict[str, Any]] = [] for q_path, sub_cfg in value.items(): sub_cfg = dict(sub_cfg) enable = sub_cfg.pop("enable", None) cfg = sub_cfg or None - entry: QuantizerCfgEntry = { + entry: dict[str, Any] = { "parent_class": key, "quantizer_name": q_path, "cfg": cfg, @@ -1036,8 +1077,14 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: result: list[QuantizerCfgEntry] = [] _warned_legacy = False for raw in v: + # Already-validated QuantizerCfgEntry instances (e.g. produced by load_config on a + # snippet schematized with `# modelopt-schema: QuantizerCfgEntry`, then spread into + # a quant_cfg list) are passed through unchanged. + if isinstance(raw, QuantizerCfgEntry): + result.append(raw) + continue if isinstance(raw, dict) and "quantizer_name" in raw: - entries = [dict(raw)] # copy to avoid mutating caller's data + entries: list[dict[str, Any]] = [dict(raw)] # copy to avoid mutating caller's data elif isinstance(raw, dict) and len(raw) == 1: key, val = next(iter(raw.items())) entries = [dict(e) for e in _dict_to_entry(key, val)] @@ -1055,42 +1102,10 @@ def _dict_to_entry(key: str, value) -> list[QuantizerCfgEntry]: else: raise ValueError(f"Invalid quant_cfg entry: {raw!r}.") - for entry in entries: - # Validate: must carry at least one instruction beyond the path selector. - if "cfg" not in entry and "enable" not in entry: - raise ValueError( - f"Invalid quant_cfg entry: {raw!r} — each entry must specify 'cfg', 'enable', " - "or both. An entry with only 'quantizer_name' has no effect (implicit " - "enable=True is not allowed; set it explicitly)." - ) - - # Validate: when cfg is present and enable=True, cfg must be a non-empty - # dict or list. An empty cfg would attempt to create a - # QuantizerAttributeConfig with no actual configuration. - cfg = entry.get("cfg") - enable = entry.get("enable", True) - if enable and cfg is not None: - if isinstance(cfg, dict): - is_invalid = len(cfg) == 0 - elif isinstance(cfg, list): - is_invalid = len(cfg) == 0 or any( - not isinstance(item, dict) or len(item) == 0 for item in cfg - ) - else: - is_invalid = True - if is_invalid: - raise ValueError( - f"Invalid quant_cfg entry: {raw!r} — 'cfg' must be a non-empty dict " - f"or a non-empty list of non-empty dicts when enabling a quantizer " - f"(got {type(cfg).__name__}: {cfg!r}). Either provide quantizer " - "attributes in 'cfg' or remove 'cfg' and set 'enable' explicitly." - ) - - # Normalize: make enable and cfg always explicit. - entry.setdefault("enable", True) - entry.setdefault("cfg", None) - - result.append(cast("QuantizerCfgEntry", entry)) + # Constructing each QuantizerCfgEntry runs its model_validator, which enforces the + # at-least-one-of('cfg', 'enable') and cfg-shape constraints. Defaults for absent + # 'cfg' / 'enable' are filled by the pydantic field defaults. + result.extend(QuantizerCfgEntry(**entry) for entry in entries) return result @@ -1157,15 +1172,13 @@ class _QuantizeExportConfig(ModeloptBaseConfig): """An empty config.""" -_base_disable_all: list[QuantizerCfgEntry] = [ - cast("QuantizerCfgEntry", load_config("configs/ptq/units/base_disable_all")) -] +_base_disable_all: list[QuantizerCfgEntry] = [load_config("configs/ptq/units/base_disable_all")] _default_disabled_quantizer_cfg: list[QuantizerCfgEntry] = load_config( "configs/ptq/units/default_disabled_quantizers" ) -_mamba_moe_disabled_quantizer_cfg: list[QuantizerCfgEntry] = [ +_mamba_moe_disabled_quantizer_cfg: list[QuantizerCfgEntry | dict[str, Any]] = [ {"quantizer_name": "*fc1_latent_proj*", "enable": False}, # Skip Latent MOE {"quantizer_name": "*fc2_latent_proj*", "enable": False}, # Skip Latent MOE {"quantizer_name": "*q_proj*", "enable": False}, # Skip QKV Linear (HF naming) @@ -1490,7 +1503,7 @@ def _nvfp4_selective_quant_cfg( algorithm: str | dict = "max", ) -> dict: """Build an NVFP4 config that quantizes only the specified layer patterns.""" - quant_cfg: list[QuantizerCfgEntry] = [] + quant_cfg: list[QuantizerCfgEntry | dict[str, Any]] = [] quant_cfg.extend(_base_disable_all) for pattern in layer_patterns: # Deep-copy the quantizer dict so each config constant gets its own instance. diff --git a/modelopt/torch/quantization/conversion.py b/modelopt/torch/quantization/conversion.py index 3f97f8380be..40c2b8dbc7e 100644 --- a/modelopt/torch/quantization/conversion.py +++ b/modelopt/torch/quantization/conversion.py @@ -31,8 +31,8 @@ from .config import ( QuantizeConfig, - QuantizeQuantCfgType, QuantizerAttributeConfig, + RawQuantizeQuantCfgType, _QuantizeExportConfig, normalize_quant_cfg_list, ) @@ -215,7 +215,7 @@ def _replace_quant_module(model: nn.Module, version=None, registry=QuantModuleRe _replace_quant_module(getattr(model, name), version=version, registry=registry) -def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): +def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: RawQuantizeQuantCfgType): """Apply a quantization config list to the quantizers in ``quant_model``. ``quant_cfg`` is an **ordered list** of :class:`QuantizerCfgEntry <.config.QuantizerCfgEntry>` @@ -477,7 +477,7 @@ def set_quantizer_attributes_partial( @contextmanager -def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType): +def set_quantizer_by_cfg_context(quant_model: nn.Module, quant_cfg: RawQuantizeQuantCfgType): """Context manager that temporarily applies a quantization config and restores the original state on exit. Calls :func:`set_quantizer_by_cfg` on entry and reverts every diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index bfbce21ed55..271fefd8965 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -515,7 +515,15 @@ def test_import_entry_element_schema_appends(tmp_path): f" - $import: disable_all\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"] == [{"quantizer_name": "*", "cfg": None, "enable": False}] + # Entry was loaded against the QuantizerCfgEntry pydantic schema, so it is now a + # model instance — compare via model_dump for the dict-shape check. + assert len(recipe.quantize["quant_cfg"]) == 1 + assert recipe.quantize["quant_cfg"][0].model_dump() == { + "quantizer_name": "*", + "parent_class": None, + "cfg": None, + "enable": False, + } def test_import_entry_wrong_schema_raises(tmp_path): @@ -856,7 +864,8 @@ def test_import_list_splice_outside_typed_list_raises(tmp_path): """A bare $import in an untyped list is rejected.""" _write_quantizer_cfg_list( tmp_path / "extra_tasks.yml", - "- quantizer_name: '*weight_quantizer'\n- quantizer_name: '*input_quantizer'\n", + "- quantizer_name: '*weight_quantizer'\n enable: false\n" + "- quantizer_name: '*input_quantizer'\n enable: false\n", ) config_file = tmp_path / "config.yml" config_file.write_text( @@ -920,9 +929,11 @@ def test_import_mixed_tree(tmp_path): # Dict import inside list entry assert data["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} # List splice — entries are normalized by QuantizeConfig.quant_cfg's validator, - # which fills in defaults for missing ``enable`` / ``cfg`` keys. - assert data["quant_cfg"][1] == { + # which fills in defaults for missing ``enable`` / ``cfg`` keys. Entries are now + # QuantizerCfgEntry pydantic instances, so compare via model_dump. + assert data["quant_cfg"][1].model_dump() == { "quantizer_name": "*lm_head*", + "parent_class": None, "enable": False, "cfg": None, } @@ -1157,7 +1168,14 @@ def test_modelopt_schema_comment_validates_after_import_resolution(tmp_path): f" $import: fp8\n" ) data = load_config(config_file) - assert data == [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": (4, 3)}}] + # data is a list of QuantizerCfgEntry pydantic instances, not raw dicts. + assert len(data) == 1 + assert data[0].model_dump() == { + "quantizer_name": "*weight_quantizer", + "parent_class": None, + "cfg": {"num_bits": (4, 3)}, + "enable": True, + } # --------------------------------------------------------------------------- @@ -1262,7 +1280,13 @@ def test_load_config_list_valued_yaml(tmp_path): data = load_config(cfg_file) assert isinstance(data, list) assert len(data) == 2 - assert data[0] == {"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}} + # Entries are QuantizerCfgEntry pydantic instances after schema validation. + assert data[0].model_dump() == { + "quantizer_name": "*weight_quantizer", + "parent_class": None, + "cfg": {"num_bits": 8}, + "enable": True, + } # --------------------------------------------------------------------------- @@ -1274,7 +1298,8 @@ def test_import_dict_value_resolves_to_list_raises(tmp_path): """$import in dict value position raises when snippet is a list.""" _write_quantizer_cfg_list( tmp_path / "entries.yml", - "- quantizer_name: '*weight_quantizer'\n- quantizer_name: '*input_quantizer'\n", + "- quantizer_name: '*weight_quantizer'\n enable: false\n" + "- quantizer_name: '*input_quantizer'\n enable: false\n", ) config_file = tmp_path / "config.yml" config_file.write_text( diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 84306dc5116..88ef7faa375 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -184,8 +184,13 @@ def test_error_on_empty_cfg_list_enable_true(self): ) def test_error_on_non_dict_non_list_cfg_enable_true(self): - """Entry with cfg of invalid type (e.g. int) and enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + """Entry with cfg of invalid type (e.g. int) and enable=True is rejected. + + Pydantic's field-type check fires before the QuantizerCfgEntry model validator, + so this surfaces as a type error rather than the 'non-empty dict' message — + either is acceptable here as long as the entry is rejected. + """ + with pytest.raises(ValueError): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": 42, "enable": True}] ) @@ -198,8 +203,13 @@ def test_error_on_cfg_list_with_empty_dict_enable_true(self): ) def test_error_on_cfg_list_with_non_dict_element_enable_true(self): - """Entry with cfg=[42] and enable=True is rejected (non-dict element).""" - with pytest.raises(ValueError, match="non-empty dict"): + """Entry with cfg=[42] and enable=True is rejected. + + Pydantic's field-type check fires before the QuantizerCfgEntry model validator, + so the message may report a type error instead of 'non-empty dict' — either is + acceptable, as long as the entry is rejected. + """ + with pytest.raises(ValueError): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": [42], "enable": True}] ) From 058234a001fabc0adffccb7849906e365650376a Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 10:54:30 -0700 Subject: [PATCH 03/14] refactor(quant): tighten type hints on QuantizeConfig field validators - normalize_quant_cfg: annotate as ``-> QuantizeQuantCfgType`` and always delegate to normalize_quant_cfg_list; remove the silent passthrough for non-list/non-dict input. - normalize_quant_cfg_list: explicitly reject non-list/non-dict input with a clear ValueError so the field-validator's contract is honored (no more TypeError trickling out of the for-loop). - validate_quant_cfg_entries: annotate as ``(QuantizeQuantCfgType) -> QuantizeQuantCfgType``; switch ``entry.get("cfg")`` to ``entry.cfg`` since by mode="after" each element is guaranteed to be a QuantizerCfgEntry instance. - Refresh stale docstrings that still referred to "QuantizerCfgEntry dicts" from the pre-pydantic TypedDict era. Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/config.py | 31 ++++++++++++++++++++------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 6e0a54dd4b4..918dabe3d41 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1037,6 +1037,11 @@ def _warn_legacy(): if isinstance(v, dict): _warn_legacy() v = [{k: val} for k, val in v.items()] + elif not isinstance(v, list): + raise ValueError( + f"quant_cfg must be a list of entries (or a legacy flat dict), got " + f"{type(v).__name__}: {v!r}." + ) def _dict_to_entry(key: str, value) -> list[dict[str, Any]]: """Convert a single legacy key-value pair to one or more entry dicts.""" @@ -1127,22 +1132,32 @@ class QuantizeConfig(ModeloptBaseConfig): @field_validator("quant_cfg", mode="before") @classmethod - def normalize_quant_cfg(cls, v): - """Normalize quant_cfg entries: convert dict and tuple forms to QuantizerCfgEntry dicts.""" - if not isinstance(v, (list, dict)): - return v + def normalize_quant_cfg(cls, v: Any) -> QuantizeQuantCfgType: + """Normalize raw quant_cfg input into a ``list[QuantizerCfgEntry]``. + + Delegates to :func:`normalize_quant_cfg_list`, which accepts every supported input + shape (new-format list, legacy single-key-dict list, legacy flat dict, and lists + containing already-validated ``QuantizerCfgEntry`` instances) and rejects anything + else with a clear ``ValueError`` before pydantic's field-type check would see it. + """ return normalize_quant_cfg_list(v) @field_validator("quant_cfg", mode="after") @classmethod - def validate_quant_cfg_entries(cls, v): - """Validate quantizer attribute configs to surface errors (e.g. invalid axis/block_sizes).""" + def validate_quant_cfg_entries(cls, v: QuantizeQuantCfgType) -> QuantizeQuantCfgType: + """Validate each entry's ``cfg`` against :class:`QuantizerAttributeConfig`. + + Runs after the ``mode="before"`` normalizer and pydantic's field-type check, so + every element here is already a :class:`QuantizerCfgEntry`. This second pass + surfaces attribute-level errors (e.g. invalid ``axis`` / ``block_sizes``) that the + per-entry ``QuantizerCfgEntry`` validator doesn't inspect. + """ qac_fields = set(QuantizerAttributeConfig.model_fields.keys()) for entry in v: - cfg = entry.get("cfg") + cfg = entry.cfg if cfg is None: continue - cfgs = cfg if isinstance(cfg, list) else [cfg] + cfgs: list[dict[str, Any]] = cfg if isinstance(cfg, list) else [cfg] for c in cfgs: if isinstance(c, dict) and qac_fields & set(c.keys()): QuantizerAttributeConfig.model_validate(c) From 02513b6760a985a69e8081224491717788d96851 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 11:09:32 -0700 Subject: [PATCH 04/14] refactor(quant): name the legacy flat-dict quant_cfg input shape Introduce ``DeprecatedQuantCfgType = dict[str, Any]`` and use it in ``normalize_quant_cfg_list``'s signature so the legacy flat-dict input form is explicitly labeled deprecated at the type level. ``RawQuantizeQuantCfgType`` continues to describe only the supported list-shaped input. Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/config.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 918dabe3d41..9b2dad1b3cf 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -981,17 +981,20 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): # legacy / new dict forms) or already-validated QuantizerCfgEntry instances. # ``Sequence`` (rather than ``list``) keeps the alias covariant so callers can pass # ``list[QuantizerCfgEntry]`` without an invariance error. -# ``normalize_quant_cfg_list`` additionally accepts a single legacy flat ``dict`` for the -# whole list, but that path is deprecated and not surfaced in this alias. RawQuantizeQuantCfgType = Sequence[QuantizerCfgEntry | dict[str, Any]] +# Legacy flat-dict input shape (``{"*": ..., "*weight_quantizer": ...}``). Accepted by +# ``normalize_quant_cfg_list`` for backward compatibility but emits a DeprecationWarning; +# new code should use a list of :class:`QuantizerCfgEntry`-shaped entries instead. +DeprecatedQuantCfgType = dict[str, Any] + _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None QuantizeAlgoCfgType = _QuantizeAlgoCfgType | list[_QuantizeAlgoCfgType] | None def normalize_quant_cfg_list( - v: RawQuantizeQuantCfgType | dict[str, Any], + v: RawQuantizeQuantCfgType | DeprecatedQuantCfgType, ) -> list[QuantizerCfgEntry]: """Normalize a raw quant_cfg into a list of :class:`QuantizerCfgEntry` instances. From c855d2a4ce87d5c3b3a2fb32adf55a3508f9ca5f Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 13:22:02 -0700 Subject: [PATCH 05/14] refactor(quant): widen quant_cfg input types to Mapping/Sequence Relax the pre-normalization type aliases so callers aren't forced to pass concrete ``dict``/``list`` types: - ``RawQuantizeQuantCfgType`` becomes ``Sequence[QuantizerCfgEntry] | Sequence[Mapping[str, Any]]`` (two covariant arms instead of one with a union element). - ``DeprecatedQuantCfgType`` becomes ``Mapping[str, Any]``. - ``normalize_quant_cfg_list`` and its inner isinstance dispatch use ``Mapping``/``Sequence`` throughout (excluding ``str``/``bytes`` from the Sequence arm). Also tighten internal call-site annotations: the ``QuantizeConfig`` mode="before" field validator and ``need_calibration`` now declare the real accepted input union rather than ``Any``/bare ``list``. Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/config.py | 46 +++++++++++++++------------ 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 9b2dad1b3cf..ddd1ca37657 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -152,7 +152,7 @@ import copy import warnings -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Any, Literal from pydantic import AliasChoices, ValidationInfo, field_validator, model_validator @@ -977,16 +977,18 @@ class GPTQCalibConfig(QuantizeAlgorithmConfig): QuantizeQuantCfgType = list[QuantizerCfgEntry] QuantizerCfgListConfig = QuantizeQuantCfgType -# Pre-normalization input shape: a sequence whose entries can be raw dicts (any of the -# legacy / new dict forms) or already-validated QuantizerCfgEntry instances. -# ``Sequence`` (rather than ``list``) keeps the alias covariant so callers can pass -# ``list[QuantizerCfgEntry]`` without an invariance error. -RawQuantizeQuantCfgType = Sequence[QuantizerCfgEntry | dict[str, Any]] +# Pre-normalization input shape: either a sequence of already-validated +# :class:`QuantizerCfgEntry` instances, or a sequence of raw mappings (any of the legacy / +# new dict forms). Splitting the union into two ``Sequence[...]`` arms — rather than +# ``Sequence[QuantizerCfgEntry | Mapping[str, Any]]`` — keeps each arm covariant in its +# element type, so callers can pass ``list[QuantizerCfgEntry]`` or ``list[dict]`` without +# tripping invariance. +RawQuantizeQuantCfgType = Sequence[QuantizerCfgEntry] | Sequence[Mapping[str, Any]] # Legacy flat-dict input shape (``{"*": ..., "*weight_quantizer": ...}``). Accepted by # ``normalize_quant_cfg_list`` for backward compatibility but emits a DeprecationWarning; # new code should use a list of :class:`QuantizerCfgEntry`-shaped entries instead. -DeprecatedQuantCfgType = dict[str, Any] +DeprecatedQuantCfgType = Mapping[str, Any] _QuantizeAlgoCfgType = str | dict | QuantizeAlgorithmConfig | None @@ -1037,12 +1039,12 @@ def _warn_legacy(): ) # Legacy flat-dict format: {"*": {...}, "*weight_quantizer": {...}} → list of single-key dicts. - if isinstance(v, dict): + if isinstance(v, Mapping): _warn_legacy() v = [{k: val} for k, val in v.items()] - elif not isinstance(v, list): + elif not isinstance(v, Sequence) or isinstance(v, (str, bytes)): raise ValueError( - f"quant_cfg must be a list of entries (or a legacy flat dict), got " + f"quant_cfg must be a sequence of entries (or a legacy flat mapping), got " f"{type(v).__name__}: {v!r}." ) @@ -1053,8 +1055,10 @@ def _dict_to_entry(key: str, value) -> list[dict[str, Any]]: key = "*" if isinstance(key, str) and key.startswith("nn."): - if not isinstance(value, dict): - raise ValueError(f"For 'nn.*' scoped format, value must be a dict, got {value!r}") + if not isinstance(value, Mapping): + raise ValueError( + f"For 'nn.*' scoped format, value must be a mapping, got {value!r}" + ) # Support multi-key nn.*-scoped dicts by emitting one entry per sub-key. entries: list[dict[str, Any]] = [] for q_path, sub_cfg in value.items(): @@ -1071,7 +1075,7 @@ def _dict_to_entry(key: str, value) -> list[dict[str, Any]]: entries.append(entry) return entries else: - if isinstance(value, dict): + if isinstance(value, Mapping): cfg = {k: val for k, val in value.items() if k != "enable"} or None enable = value.get("enable") else: @@ -1091,15 +1095,15 @@ def _dict_to_entry(key: str, value) -> list[dict[str, Any]]: if isinstance(raw, QuantizerCfgEntry): result.append(raw) continue - if isinstance(raw, dict) and "quantizer_name" in raw: + if isinstance(raw, Mapping) and "quantizer_name" in raw: entries: list[dict[str, Any]] = [dict(raw)] # copy to avoid mutating caller's data - elif isinstance(raw, dict) and len(raw) == 1: + elif isinstance(raw, Mapping) and len(raw) == 1: key, val = next(iter(raw.items())) entries = [dict(e) for e in _dict_to_entry(key, val)] if not _warned_legacy: _warn_legacy() _warned_legacy = True - elif isinstance(raw, dict) and len(raw) > 1 and any(k.startswith("nn.") for k in raw): + elif isinstance(raw, Mapping) and len(raw) > 1 and any(k.startswith("nn.") for k in raw): # Legacy flat dict with nn.*-scoped keys mixed with other keys — expand all pairs. entries = [] for k, val in raw.items(): @@ -1135,7 +1139,9 @@ class QuantizeConfig(ModeloptBaseConfig): @field_validator("quant_cfg", mode="before") @classmethod - def normalize_quant_cfg(cls, v: Any) -> QuantizeQuantCfgType: + def normalize_quant_cfg( + cls, v: RawQuantizeQuantCfgType | DeprecatedQuantCfgType + ) -> QuantizeQuantCfgType: """Normalize raw quant_cfg input into a ``list[QuantizerCfgEntry]``. Delegates to :func:`normalize_quant_cfg_list`, which accepts every supported input @@ -1788,7 +1794,7 @@ def _nvfp4_selective_quant_cfg( } -def need_calibration(config): +def need_calibration(config: QuantizeConfig | Mapping[str, Any]) -> bool: """Check if calibration is needed for the given config.""" if config["algorithm"] is not None and config["algorithm"] != "max": return True @@ -1796,8 +1802,8 @@ def need_calibration(config): def _not_dynamic(cfg): return cfg.get("enable", True) and cfg.get("type", "") != "dynamic" - quant_cfg: list = config.get("quant_cfg") or [] - quant_cfg = normalize_quant_cfg_list(quant_cfg) + raw_quant_cfg: RawQuantizeQuantCfgType | DeprecatedQuantCfgType = config.get("quant_cfg") or [] + quant_cfg: list[QuantizerCfgEntry] = normalize_quant_cfg_list(raw_quant_cfg) for entry in quant_cfg: name = entry["quantizer_name"] raw_cfg = entry.get("cfg") From 0b6b2f0c37d367e6201b99752bc59a1bd0bcc2ac Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 14:15:20 -0700 Subject: [PATCH 06/14] need to have model_dump for explicitly set k/v Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/config.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index ddd1ca37657..9c3595526fc 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1249,7 +1249,9 @@ class _QuantizeExportConfig(ModeloptBaseConfig): "algorithm": "max", } -FP8_DEFAULT_CFG: dict[str, Any] = load_config("configs/ptq/presets/model/fp8") +FP8_DEFAULT_CFG: dict[str, Any] = load_config("configs/ptq/presets/model/fp8").model_dump( + exclude_unset=True +) MAMBA_MOE_FP8_AGGRESSIVE_CFG = { "quant_cfg": [ @@ -1494,7 +1496,9 @@ class _QuantizeExportConfig(ModeloptBaseConfig): # KV-cache configs are designed to be merged with a primary quantization config (e.g. # FP8_DEFAULT_CFG) that already contains _base_disable_all. They intentionally omit both # _base_disable_all and "algorithm" because these are provided by the primary config. -FP8_KV_CFG: dict[str, Any] = load_config("configs/ptq/presets/kv/fp8") +FP8_KV_CFG: dict[str, Any] = load_config("configs/ptq/presets/kv/fp8").model_dump( + exclude_unset=True +) FP8_AFFINE_KV_CFG = { "quant_cfg": [ From c7ab5931ec70c0ee2c3ed052f85113471bb3cc7f Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 14:59:24 -0700 Subject: [PATCH 07/14] refactor(recipe): make RecipeMetadataConfig a ModeloptBaseConfig MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Convert ``RecipeMetadataConfig`` from a ``TypedDict`` into a ``ModeloptBaseConfig`` so every schema accepted by ``load_config`` is a pydantic model. ``recipe_type`` is required (plain pydantic ``Field``), ``description`` keeps its default via ``ModeloptField``, and the now- redundant ``validate_metadata`` field validator is dropped — pydantic's native enum + required-field checks cover the same ground. ``ModelOptRecipeBase`` switches to ``default_factory`` for the ``metadata`` field (``ModeloptField`` only supports literal defaults). Convenience properties move from ``metadata["..."]`` to attribute access. ``_load_recipe_from_dir`` loses three defensive isinstance/None checks that are now unreachable: pydantic validation in ``load_config`` rejects malformed inputs upstream. With every schema now a ``ModeloptBaseConfig`` subclass, tighten ``load_config``'s overloads: - ``type[_SchemaT]`` → ``_SchemaT`` - ``type[list[_SchemaT]]`` → ``list[_SchemaT]`` - ``None`` → ``Any`` ``_SchemaT`` is bound to ``ModeloptBaseConfig``, so mypy now enforces the invariant that ``schema_type`` is a ``ModeloptBaseConfig`` subclass (or ``list`` of one) at every call site. The previous ``schema_type: Any`` catch-all is dropped. Signed-off-by: Shengliang Xu --- modelopt/recipe/config.py | 40 ++++++++++++----------------- modelopt/recipe/loader.py | 22 +++------------- modelopt/torch/opt/config_loader.py | 31 +++++++++++++++++++++- tests/unit/recipe/test_loader.py | 2 +- 4 files changed, 51 insertions(+), 44 deletions(-) diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 96f33012afd..58ac425e0b5 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -19,8 +19,7 @@ from enum import Enum -from pydantic import field_validator -from typing_extensions import NotRequired, TypedDict +from pydantic import Field from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.quantization.config import QuantizeConfig @@ -33,14 +32,21 @@ class RecipeType(str, Enum): # QAT = "qat" # Not implemented yet, will be added in the future. -class RecipeMetadataConfig(TypedDict): - """YAML shape of the recipe metadata section.""" +_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe." - recipe_type: RecipeType - description: NotRequired[str] +class RecipeMetadataConfig(ModeloptBaseConfig): + """YAML shape of the recipe metadata section.""" -_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe." + recipe_type: RecipeType = Field( + title="Recipe type", + description="The type of the recipe (e.g. PTQ).", + ) + description: str = ModeloptField( + default=_DEFAULT_RECIPE_DESCRIPTION, + title="Description", + description="Human-readable description of the recipe.", + ) class ModelOptRecipeBase(ModeloptBaseConfig): @@ -49,33 +55,21 @@ class ModelOptRecipeBase(ModeloptBaseConfig): If a layer name matches ``"*output_layer*"``, the attributes will be replaced with ``{"enable": False}``. """ - metadata: RecipeMetadataConfig = ModeloptField( - default={"recipe_type": RecipeType.PTQ, "description": _DEFAULT_RECIPE_DESCRIPTION}, + metadata: RecipeMetadataConfig = Field( + default_factory=lambda: RecipeMetadataConfig(recipe_type=RecipeType.PTQ), title="Metadata", description="Recipe metadata containing the recipe type and description.", - validate_default=True, ) - @field_validator("metadata") - @classmethod - def validate_metadata(cls, metadata: RecipeMetadataConfig) -> RecipeMetadataConfig: - """Validate recipe metadata and fill defaults for optional fields.""" - if metadata["recipe_type"] not in RecipeType: - raise ValueError( - f"Unsupported recipe type: {metadata['recipe_type']}. " - f"Only {list(RecipeType)} are currently supported." - ) - return {"description": _DEFAULT_RECIPE_DESCRIPTION, **metadata} - @property def recipe_type(self) -> RecipeType: """Return the recipe type from metadata.""" - return self.metadata["recipe_type"] + return self.metadata.recipe_type @property def description(self) -> str: """Return the recipe description from metadata.""" - return self.metadata.get("description", _DEFAULT_RECIPE_DESCRIPTION) + return self.metadata.description class ModelOptPTQRecipe(ModelOptRecipeBase): diff --git a/modelopt/recipe/loader.py b/modelopt/recipe/loader.py index 8608e3fcbb8..919c4e0379f 100644 --- a/modelopt/recipe/loader.py +++ b/modelopt/recipe/loader.py @@ -123,26 +123,10 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase: quantize. """ metadata_file = _find_recipe_section_file(recipe_dir, "metadata") - metadata = load_config(metadata_file, schema_type=RecipeMetadataConfig) - if not isinstance(metadata, dict): - raise ValueError( - f"Metadata file {metadata_file} must be a YAML mapping, got {type(metadata).__name__}." - ) - recipe_type = metadata.get("recipe_type") - if recipe_type is None: - raise ValueError(f"Metadata file {metadata_file} must contain a 'recipe_type' field.") - if recipe_type == RecipeType.PTQ: + if metadata.recipe_type == RecipeType.PTQ: quantize_file = _find_recipe_section_file(recipe_dir, "quantize") quantize_cfg = load_config(quantize_file, schema_type=QuantizeConfig) - if not isinstance(quantize_cfg, QuantizeConfig): - raise ValueError( - f"{quantize_file} must produce a {QuantizeConfig.__name__}, " - f"got {type(quantize_cfg).__name__}." - ) - return ModelOptPTQRecipe( - metadata=metadata, - quantize=quantize_cfg, - ) - raise ValueError(f"Unsupported recipe type: {recipe_type!r}") + return ModelOptPTQRecipe(metadata=metadata, quantize=quantize_cfg) + raise ValueError(f"Unsupported recipe type: {metadata.recipe_type!r}") diff --git a/modelopt/torch/opt/config_loader.py b/modelopt/torch/opt/config_loader.py index 5dbf0ad5bf1..76ed2bb6503 100644 --- a/modelopt/torch/opt/config_loader.py +++ b/modelopt/torch/opt/config_loader.py @@ -33,12 +33,14 @@ import re import sys from pathlib import Path -from typing import Any, Union, get_args, get_origin, get_type_hints +from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints, overload import yaml from pydantic import TypeAdapter from typing_extensions import NotRequired, Required, is_typeddict +from modelopt.torch.opt.config import ModeloptBaseConfig + @dataclass class _ListSnippet: @@ -592,6 +594,33 @@ def _find_import_marker(obj: Any, context: str = "root") -> tuple[Any, str] | No return None +_SchemaT = TypeVar("_SchemaT", bound=ModeloptBaseConfig) + + +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: type[_SchemaT], +) -> _SchemaT: ... + + +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: type[list[_SchemaT]], +) -> list[_SchemaT]: ... + + +@overload +def load_config( + config_path: str | Path | Traversable, + *, + schema_type: None = None, +) -> Any: ... + + def load_config( config_path: str | Path | Traversable, *, diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 271fefd8965..eeb039f97e2 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -184,7 +184,7 @@ def test_load_recipe_unsupported_type_raises(tmp_path): """load_recipe raises ValueError for an unknown recipe_type.""" bad = tmp_path / "bad.yml" bad.write_text(CFG_RECIPE_UNSUPPORTED_TYPE) - # Schema-driven validation reports the failure via the TypedDict's enum check. + # Schema-driven validation reports the failure via the metadata schema's enum check. with pytest.raises(ValueError, match="recipe_type"): load_recipe(bad) From df4ffb4b0a530c25355345f9c0838f407670410c Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 17:20:37 -0700 Subject: [PATCH 08/14] fix(recipe): require metadata and quantize sections MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both fields were silently defaulted by the schema, so a malformed recipe file missing either section would still load successfully — a PTQ recipe without quantize would quietly fall back to QuantizeConfig() (the default INT8 config), masking the user's mistake. Drop the defaults so pydantic rejects recipes with missing sections at validation time. description stays optional via its own field default. Signed-off-by: Shengliang Xu --- modelopt/recipe/config.py | 15 ++++++++------- tests/unit/recipe/test_loader.py | 25 ++++++++++++++++++------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 58ac425e0b5..8a2007f55b0 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -22,7 +22,7 @@ from pydantic import Field from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField -from modelopt.torch.quantization.config import QuantizeConfig +from modelopt.torch.quantization.config import QuantizeConfig # noqa: TC001 class RecipeType(str, Enum): @@ -56,9 +56,10 @@ class ModelOptRecipeBase(ModeloptBaseConfig): """ metadata: RecipeMetadataConfig = Field( - default_factory=lambda: RecipeMetadataConfig(recipe_type=RecipeType.PTQ), title="Metadata", - description="Recipe metadata containing the recipe type and description.", + description="Recipe metadata containing the recipe type and description. " + "Required: a recipe without a ``metadata`` section is rejected so that a " + "missing section can't silently fall back to a default recipe type.", ) @property @@ -75,9 +76,9 @@ def description(self) -> str: class ModelOptPTQRecipe(ModelOptRecipeBase): """Our config class for PTQ recipes.""" - quantize: QuantizeConfig = ModeloptField( - default=QuantizeConfig(), + quantize: QuantizeConfig = Field( title="PTQ config", - description="PTQ config containing quant_cfg and algorithm.", - validate_default=True, + description="PTQ config containing quant_cfg and algorithm. Required: a PTQ " + "recipe without a ``quantize`` section is rejected so that a missing section " + "can't silently fall back to the default INT8 config.", ) diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index eeb039f97e2..759b629f2db 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -41,6 +41,10 @@ quantize: {} """ +CFG_RECIPE_MISSING_METADATA = """\ +quantize: {} +""" + CFG_RECIPE_MISSING_quantize = """\ metadata: recipe_type: ptq @@ -49,6 +53,7 @@ CFG_RECIPE_UNSUPPORTED_TYPE = """\ metadata: recipe_type: unknown_type +quantize: {} """ QUANTIZER_ATTRIBUTE_SCHEMA = ( @@ -170,14 +175,20 @@ def test_load_recipe_missing_recipe_type_raises(tmp_path): load_recipe(bad) -def test_load_recipe_missing_quantize_uses_default(tmp_path): - """``quantize`` is optional in a PTQ recipe; absence yields an empty default config.""" - from modelopt.torch.quantization.config import QuantizeConfig +def test_load_recipe_missing_quantize_raises(tmp_path): + """A PTQ recipe missing the ``quantize`` section is rejected (no silent default).""" + bad = tmp_path / "bad.yml" + bad.write_text(CFG_RECIPE_MISSING_quantize) + with pytest.raises(ValueError, match="quantize"): + load_recipe(bad) + - good = tmp_path / "good.yml" - good.write_text(CFG_RECIPE_MISSING_quantize) - recipe = load_recipe(good) - assert isinstance(recipe.quantize, QuantizeConfig) +def test_load_recipe_missing_metadata_raises(tmp_path): + """A recipe missing the ``metadata`` section is rejected (no silent default).""" + bad = tmp_path / "bad.yml" + bad.write_text(CFG_RECIPE_MISSING_METADATA) + with pytest.raises(ValueError, match="metadata"): + load_recipe(bad) def test_load_recipe_unsupported_type_raises(tmp_path): From 0d087e0508783a7dc695ae5e44b05051e31b038b Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 17:20:51 -0700 Subject: [PATCH 09/14] feat(opt): make ModeloptBaseConfig a real MutableMapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Until now ModeloptBaseConfig had the mapping-shaped methods (__getitem__/__setitem__/__iter__/__len__/get/keys/values/items) but did not subclass or register collections.abc.Mapping, so isinstance(cfg, Mapping) returned False. Callers migrating off isinstance(cfg, dict) hit this surprise. Inherit from MutableMapping so the isinstance check works and the ABC mixin methods (pop, popitem, setdefault, clear) come along for free. The schema is fixed, so __delitem__ raises TypeError; pop/popitem/clear inherit that failure on existing keys, while pop(key, default) for a missing key still returns the default. The ABC mixins require __getitem__ to raise KeyError on missing keys, not AttributeError — translate at the __getitem__ boundary and update get() to catch KeyError. One test that previously asserted AttributeError on cfg["missing"] is updated to expect KeyError. Signed-off-by: Shengliang Xu --- modelopt/torch/opt/config.py | 39 +++++++++++++++++++++++++---- tests/unit/torch/opt/test_config.py | 2 +- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/opt/config.py b/modelopt/torch/opt/config.py index 62f7b7e16a2..f033596b242 100644 --- a/modelopt/torch/opt/config.py +++ b/modelopt/torch/opt/config.py @@ -17,7 +17,7 @@ import fnmatch import json -from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView +from collections.abc import Callable, ItemsView, Iterator, KeysView, MutableMapping, ValuesView from typing import Any, TypeAlias import torch @@ -57,11 +57,18 @@ def ModeloptField(default: Any = PydanticUndefined, **kwargs): # noqa: N802 # TODO: expand config classes to searcher -class ModeloptBaseConfig(BaseModel): +class ModeloptBaseConfig(BaseModel, MutableMapping): """Our config base class for mode configuration. The base class extends the capabilities of pydantic's BaseModel to provide additional methods and properties for easier access and manipulation of the configuration. + + Inherits from :class:`collections.abc.MutableMapping` so instances satisfy + ``isinstance(cfg, Mapping)`` / ``isinstance(cfg, MutableMapping)`` checks and pick up the + mixin methods (``pop``, ``popitem``, ``setdefault``, ``clear``). Schema fields are fixed, + so ``__delitem__`` raises :class:`TypeError`; the inherited ``pop`` / ``clear`` / + ``popitem`` therefore also raise on any existing key, while ``pop(key, default)`` for a + missing key still returns the default normally. """ model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True) @@ -110,18 +117,40 @@ def __contains__(self, key: str) -> bool: return False def __getitem__(self, key: str) -> Any: - """Get the value for the given key (can be name or alias of field).""" - return getattr(self, self.get_field_name_from_key(key)) + """Get the value for the given key (can be name or alias of field). + + Raises :class:`KeyError` for missing keys so the class behaves like a regular + :class:`Mapping` — required for the inherited ``MutableMapping`` mixin methods + (``pop``, ``setdefault``, ...) to dispatch correctly. + """ + try: + return getattr(self, self.get_field_name_from_key(key)) + except AttributeError: + raise KeyError(key) from None def __setitem__(self, key: str, value: Any) -> None: """Set the value for the given key (can be name or alias of field).""" setattr(self, self.get_field_name_from_key(key), value) + def __delitem__(self, key: str) -> None: + """Reject key deletion. + + ``ModeloptBaseConfig`` exposes a fixed pydantic schema, so removing a key is + ill-defined: schema fields can't disappear, and silently resetting them to their + defaults would surprise callers. Raise ``TypeError`` instead. Defined so the + class fully satisfies the ``MutableMapping`` protocol (``__delitem__`` is + required), without committing to actual deletion semantics. + """ + raise TypeError( + f"{type(self).__name__} does not support key deletion; schema fields are " + f"fixed (attempted to delete {key!r})." + ) + def get(self, key: str, default: Any = None) -> Any: """Get the value for the given key (can be name or alias) or default if not found.""" try: return self[key] - except AttributeError: + except KeyError: return default def __len__(self) -> int: diff --git a/tests/unit/torch/opt/test_config.py b/tests/unit/torch/opt/test_config.py index b2ffadb1a78..e0c5993a51a 100644 --- a/tests/unit/torch/opt/test_config.py +++ b/tests/unit/torch/opt/test_config.py @@ -72,7 +72,7 @@ def _run_test(is_new_registered): assert config[lin_name] == lin_expected_value assert config[lin_alias] == lin_expected_value assert getattr(config, lin_name) == lin_expected_value - with nullcontext() if is_new_registered else pytest.raises(AttributeError): + with nullcontext() if is_new_registered else pytest.raises(KeyError): config[new_name] # get From 77c8e675c7ccada27a77438592ac8a97670e135c Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 17:21:03 -0700 Subject: [PATCH 10/14] fix(quant): normalize empty cfg to None when disabling a quantizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit QuantizerCfgEntry accepted {cfg: {}, enable: False} and kept cfg as the empty dict. Downstream, any non-None cfg is applied as a full quantizer-attribute replacement, so an empty cfg on a disable entry silently resets the quantizer's attributes back to schema defaults — and if a later rule re-enables the quantizer, it comes back with defaults instead of the config it originally carried. Add a model_validator(mode="before") that rewrites cfg to None when enable=False and cfg is empty (empty dict, empty list, or list of empty dicts), so disable-only entries actually behave like disable-only. A non-empty cfg with enable=False is preserved (deliberate disable+replace). Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/config.py | 27 +++++++++++++ .../quantization/test_config_validation.py | 38 +++++++++++++++++-- 2 files changed, 61 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 9c3595526fc..beaf3cf8648 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -188,6 +188,33 @@ class QuantizerCfgEntry(ModeloptBaseConfig): description="Toggle matched quantizers on/off; independent of ``cfg``.", ) + @model_validator(mode="before") + @classmethod + def _drop_empty_cfg_when_disabled(cls, values): + """Treat ``enable=False`` with an empty ``cfg`` as a pure disable. + + Downstream, any non-``None`` ``cfg`` is applied as a full quantizer-attribute + replacement. An entry like ``{cfg: {}, enable: False}`` would therefore reset + the quantizer's attributes back to schema defaults — and if a later rule + re-enables the quantizer, it would come back with defaults rather than the + config it originally carried. Normalise an empty ``cfg`` (empty dict, empty + list, or a list of empty dicts) to ``None`` so a disable-only entry behaves + like one. + """ + if not isinstance(values, dict): + return values + if values.get("enable") is False: + cfg = values.get("cfg") + cfg_is_empty = (isinstance(cfg, dict) and len(cfg) == 0) or ( + isinstance(cfg, list) + and ( + len(cfg) == 0 or all(isinstance(item, dict) and len(item) == 0 for item in cfg) + ) + ) + if cfg_is_empty: + values = {**values, "cfg": None} + return values + @model_validator(mode="after") def _validate_instruction(self): """Reject entries that carry no instruction beyond the path selector.""" diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 88ef7faa375..6670ae983fb 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -214,19 +214,49 @@ def test_error_on_cfg_list_with_non_dict_element_enable_true(self): [{"quantizer_name": "*weight_quantizer", "cfg": [42], "enable": True}] ) - def test_empty_cfg_dict_enable_false_accepted(self): - """Entry with cfg={} and enable=False is allowed (disable-only entry).""" + def test_empty_cfg_dict_enable_false_normalized_to_none(self): + """Entry with cfg={} and enable=False is normalised to cfg=None (disable-only). + + A non-``None`` cfg is applied as a full quantizer-attribute replacement, so an + empty cfg paired with enable=False would silently reset the quantizer's + attributes. Normalisation to ``None`` makes the entry behave like a pure + disable, preserving the existing attribute config. + """ result = normalize_quant_cfg_list( [{"quantizer_name": "*input_quantizer", "cfg": {}, "enable": False}] ) assert result[0]["enable"] is False + assert result[0]["cfg"] is None - def test_empty_cfg_list_enable_false_accepted(self): - """Entry with cfg=[] and enable=False is allowed (disable-only entry).""" + def test_empty_cfg_list_enable_false_normalized_to_none(self): + """Entry with cfg=[] and enable=False is normalised to cfg=None.""" result = normalize_quant_cfg_list( [{"quantizer_name": "*input_quantizer", "cfg": [], "enable": False}] ) assert result[0]["enable"] is False + assert result[0]["cfg"] is None + + def test_cfg_list_of_empty_dicts_enable_false_normalized_to_none(self): + """Entry with cfg=[{}] and enable=False is normalised to cfg=None.""" + result = normalize_quant_cfg_list( + [{"quantizer_name": "*input_quantizer", "cfg": [{}], "enable": False}] + ) + assert result[0]["enable"] is False + assert result[0]["cfg"] is None + + def test_nonempty_cfg_enable_false_preserved(self): + """Entry with a non-empty cfg and enable=False keeps the cfg (disable+replace).""" + result = normalize_quant_cfg_list( + [ + { + "quantizer_name": "*input_quantizer", + "cfg": {"num_bits": 4}, + "enable": False, + } + ] + ) + assert result[0]["enable"] is False + assert result[0]["cfg"] == {"num_bits": 4} def test_new_format_with_list_cfg(self): """cfg can be a list of dicts for SequentialQuantizer.""" From c43c0f0f6dc57a660ab57ed334053754524115bd Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 17:55:26 -0700 Subject: [PATCH 11/14] fix(quant): keep shared cfg snippets as dicts in public constants After load_config() started returning schema instances, _base_disable_all and _default_disabled_quantizer_cfg held QuantizerCfgEntry objects, and splatting them into the public dict configs (INT4_AWQ_CFG, NVFP4_DEFAULT_CFG, INT8_DEFAULT_CFG, etc.) leaked schema instances into trees that have always been raw dict/list. Callers serialising those constants or doing isinstance(entry, dict) saw the difference. Dump each entry back to a plain dict with exclude_unset=True (matching the existing treatment of FP8_DEFAULT_CFG / FP8_KV_CFG) so the public constants stay raw and the dumped shape is byte-identical to the YAML snippet. Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/config.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index beaf3cf8648..bad3a055308 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1223,13 +1223,24 @@ class _QuantizeExportConfig(ModeloptBaseConfig): """An empty config.""" -_base_disable_all: list[QuantizerCfgEntry] = [load_config("configs/ptq/units/base_disable_all")] +# Shared snippet constants are dumped back to plain dicts before being spliced into +# the public quant config constants below. ``load_config`` returns validated +# ``QuantizerCfgEntry`` instances for schema-tagged files, but the public constants +# (``INT4_AWQ_CFG``, ``NVFP4_DEFAULT_CFG``, etc.) have always been raw dict/list trees; +# splatting schema instances into them would surprise callers that serialise the +# constants or do ``isinstance(entry, dict)`` checks. ``exclude_unset=True`` keeps the +# sparse YAML shape (only the explicitly set fields) so the dumped dicts are +# byte-identical to what authors wrote in the YAML snippets. +_base_disable_all: list[dict[str, Any]] = [ + load_config("configs/ptq/units/base_disable_all").model_dump(exclude_unset=True) +] -_default_disabled_quantizer_cfg: list[QuantizerCfgEntry] = load_config( - "configs/ptq/units/default_disabled_quantizers" -) +_default_disabled_quantizer_cfg: list[dict[str, Any]] = [ + entry.model_dump(exclude_unset=True) + for entry in load_config("configs/ptq/units/default_disabled_quantizers") +] -_mamba_moe_disabled_quantizer_cfg: list[QuantizerCfgEntry | dict[str, Any]] = [ +_mamba_moe_disabled_quantizer_cfg: list[dict[str, Any]] = [ {"quantizer_name": "*fc1_latent_proj*", "enable": False}, # Skip Latent MOE {"quantizer_name": "*fc2_latent_proj*", "enable": False}, # Skip Latent MOE {"quantizer_name": "*q_proj*", "enable": False}, # Skip QKV Linear (HF naming) @@ -1558,7 +1569,7 @@ def _nvfp4_selective_quant_cfg( algorithm: str | dict = "max", ) -> dict: """Build an NVFP4 config that quantizes only the specified layer patterns.""" - quant_cfg: list[QuantizerCfgEntry | dict[str, Any]] = [] + quant_cfg: list[dict[str, Any]] = [] quant_cfg.extend(_base_disable_all) for pattern in layer_patterns: # Deep-copy the quantizer dict so each config constant gets its own instance. From a63d42016a6c7ce879c0a5f5a27836337cc3e725 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Thu, 14 May 2026 17:55:31 -0700 Subject: [PATCH 12/14] fix(opt): __setitem__ raises KeyError for unknown keys When ModeloptBaseConfig started inheriting from MutableMapping, __getitem__ was updated to raise KeyError for missing keys, but __setitem__ still propagated AttributeError from get_field_name_from_key. Direct writes like cfg["unknown"] = value, and inherited mixin helpers like setdefault that route through __setitem__, leaked AttributeError instead of mapping-style KeyError. Translate at the boundary so both read and write halves of the protocol agree. Signed-off-by: Shengliang Xu --- modelopt/torch/opt/config.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/opt/config.py b/modelopt/torch/opt/config.py index f033596b242..fce2eb36f6b 100644 --- a/modelopt/torch/opt/config.py +++ b/modelopt/torch/opt/config.py @@ -129,8 +129,17 @@ def __getitem__(self, key: str) -> Any: raise KeyError(key) from None def __setitem__(self, key: str, value: Any) -> None: - """Set the value for the given key (can be name or alias of field).""" - setattr(self, self.get_field_name_from_key(key), value) + """Set the value for the given key (can be name or alias of field). + + Raises :class:`KeyError` (not :class:`AttributeError`) for unknown keys so the + class matches the :class:`MutableMapping` protocol — both for direct + ``cfg["unknown"] = value`` writes and for inherited mixin helpers like + ``setdefault`` that write through ``__setitem__``. + """ + try: + setattr(self, self.get_field_name_from_key(key), value) + except AttributeError: + raise KeyError(key) from None def __delitem__(self, key: str) -> None: """Reject key deletion. From d7b6e0a217661a6410285fa4d38035a86a84c90f Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Fri, 15 May 2026 08:04:04 -0700 Subject: [PATCH 13/14] test(quant): tighten cfg-shape rejection assertions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two tests previously asserted only that ValueError was raised when cfg had an invalid python type (cfg=42, cfg=[42]) — pydantic's field-type check fires before QuantizerCfgEntry's model validator, so the older match="non-empty dict" pattern stopped matching and was dropped. A bare pytest.raises(ValueError) accepts any ValueError anywhere in the call path, which is weaker than the intent. Restore a regex that requires the error message to implicate the cfg field and identify a type/shape problem. Accepts either path: - pydantic: "cfg.dict[...]\n Input should be a valid dictionary" - validator: "'cfg' must be a non-empty dict ..." Signed-off-by: Shengliang Xu --- .../quantization/test_config_validation.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 6670ae983fb..00f7cb3605e 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -186,11 +186,14 @@ def test_error_on_empty_cfg_list_enable_true(self): def test_error_on_non_dict_non_list_cfg_enable_true(self): """Entry with cfg of invalid type (e.g. int) and enable=True is rejected. - Pydantic's field-type check fires before the QuantizerCfgEntry model validator, - so this surfaces as a type error rather than the 'non-empty dict' message — - either is acceptable here as long as the entry is rejected. + Two error paths are acceptable here, and the assertion accepts either: + pydantic's field-type check (``cfg`` must be a dict or list) fires first when + ``cfg`` is the wrong python type, while ``QuantizerCfgEntry``'s model validator + emits the "non-empty dict" message when ``cfg`` is the right type but empty. + Either way the message must implicate the ``cfg`` field, not just any + ``ValueError``. """ - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"(?s)cfg.*(non-empty|valid dictionary|valid list)"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": 42, "enable": True}] ) @@ -205,11 +208,12 @@ def test_error_on_cfg_list_with_empty_dict_enable_true(self): def test_error_on_cfg_list_with_non_dict_element_enable_true(self): """Entry with cfg=[42] and enable=True is rejected. - Pydantic's field-type check fires before the QuantizerCfgEntry model validator, - so the message may report a type error instead of 'non-empty dict' — either is - acceptable, as long as the entry is rejected. + Same dual-path acceptance as :meth:`test_error_on_non_dict_non_list_cfg_enable_true`: + pydantic may report a list-element type error, or the model validator may report + "non-empty dict"; the assertion accepts either as long as the message names the + ``cfg`` field. """ - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"(?s)cfg.*(non-empty|valid dictionary|valid list)"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": [42], "enable": True}] ) From 251ea06728312f0cfe9246e87328ce889fd332cc Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Fri, 15 May 2026 09:58:21 -0700 Subject: [PATCH 14/14] refactor(quant): schematize QuantizerCfgEntry.cfg as QuantizerAttributeConfig MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit QuantizerCfgEntry.cfg was typed as dict[str, Any] | list[dict[str, Any]] | None even though every value is supposed to be QuantizerAttributeConfig- shaped. Schematize the field so the schema layer matches the intent: cfg: QuantizerAttributeConfig | list[QuantizerAttributeConfig] | None Pydantic auto-coerces dicts, so YAML loading, public dict-form CFG constants, and dict-input callers keep working untouched. Knock-on cleanups in the same file: - _drop_empty_cfg_when_disabled is folded into a single mode="before" validator (_normalize_cfg_shape) that handles both empty-cfg rules: disabled+empty → cfg=None (normalize), enabled+empty → ValueError with a clearer message. Both rules must run on the raw input before pydantic coerces {} into a default-filled QAC, otherwise "user gave no attributes (typo)" and "user wants schema defaults" become indistinguishable. - The duplicate "non-empty dict" check in _validate_instruction is removed; the responsibility lives entirely in _normalize_cfg_shape now. - QuantizeConfig.validate_quant_cfg_entries is deleted. It used to do a second-pass QuantizerAttributeConfig.model_validate(c); now that cfg *is* QuantizerAttributeConfig, pydantic's field-type validation catches the same attribute-level errors on the first pass. Tests: - Direct dict comparisons of entry["cfg"] become entry["cfg"].model_dump(exclude_unset=True) == {...} (cfg is a model now, not a raw dict). Nested cases like result[0]["cfg"][0]["num_bits"] keep working via the MutableMapping interface inherited from ModeloptBaseConfig. - Error-message regex updated from "non-empty dict" to "at least one quantizer attribute" to match the rewritten message. - One test using load_config without a schema (raw dict tree) is left comparing against a plain dict, with a comment explaining why. Signed-off-by: Shengliang Xu --- modelopt/torch/quantization/config.py | 96 +++++++------------ tests/unit/recipe/test_loader.py | 55 +++++++---- .../quantization/test_config_validation.py | 30 +++--- 3 files changed, 87 insertions(+), 94 deletions(-) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index bad3a055308..acee22fa993 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -176,11 +176,12 @@ class QuantizerCfgEntry(ModeloptBaseConfig): description="If provided, only quantizers whose parent module matches this PyTorch class " "name (e.g. ``'nn.Linear'``) are affected.", ) - cfg: dict[str, Any] | list[dict[str, Any]] | None = ModeloptField( + cfg: "QuantizerAttributeConfig | list[QuantizerAttributeConfig] | None" = ModeloptField( default=None, title="Quantizer attribute config.", - description="A ``QuantizerAttributeConfig``-shaped dict, or a list of such dicts for " - "sequential quantizers. ``None`` leaves the existing attribute config untouched.", + description="A :class:`QuantizerAttributeConfig` (or a mapping that validates as one), " + "or a list of such for sequential quantizers. ``None`` leaves the existing attribute " + "config untouched.", ) enable: bool = ModeloptField( default=True, @@ -190,29 +191,42 @@ class QuantizerCfgEntry(ModeloptBaseConfig): @model_validator(mode="before") @classmethod - def _drop_empty_cfg_when_disabled(cls, values): - """Treat ``enable=False`` with an empty ``cfg`` as a pure disable. - - Downstream, any non-``None`` ``cfg`` is applied as a full quantizer-attribute - replacement. An entry like ``{cfg: {}, enable: False}`` would therefore reset - the quantizer's attributes back to schema defaults — and if a later rule - re-enables the quantizer, it would come back with defaults rather than the - config it originally carried. Normalise an empty ``cfg`` (empty dict, empty - list, or a list of empty dicts) to ``None`` so a disable-only entry behaves - like one. + def _normalize_cfg_shape(cls, values): + """Pre-validation shape rules for ``cfg``. + + Runs against the raw input mapping, before pydantic coerces ``cfg`` into a + :class:`QuantizerAttributeConfig` (which would fill in schema defaults and erase the + distinction between "user typed nothing" and "user typed `{}`"). Two rules: + + 1. ``enable=False`` with an empty ``cfg`` — empty dict, empty list, or list of empty + dicts — is normalized to ``cfg=None``. Downstream applies any non-``None`` ``cfg`` + as a full quantizer-attribute replacement, so without this an entry like + ``{cfg: {}, enable: False}`` would reset attributes to schema defaults and a later + re-enable would bring the quantizer back with defaults instead of its original config. + + 2. ``enable=True`` (explicit or implicit) with an empty ``cfg`` — same shapes — is + rejected. Pydantic would otherwise coerce ``{}`` into ``QuantizerAttributeConfig()`` + with all defaults, silently turning a likely typo (``cfg: {}``) into "quantize with + schema defaults." Callers who really want defaults should drop ``cfg`` entirely and + rely on ``enable=True``; an empty ``cfg`` always indicates missing input. """ if not isinstance(values, dict): return values - if values.get("enable") is False: - cfg = values.get("cfg") - cfg_is_empty = (isinstance(cfg, dict) and len(cfg) == 0) or ( - isinstance(cfg, list) - and ( - len(cfg) == 0 or all(isinstance(item, dict) and len(item) == 0 for item in cfg) - ) - ) - if cfg_is_empty: + cfg = values.get("cfg") + cfg_is_empty = (isinstance(cfg, dict) and len(cfg) == 0) or ( + isinstance(cfg, list) + and (len(cfg) == 0 or all(isinstance(item, dict) and len(item) == 0 for item in cfg)) + ) + if cfg_is_empty: + if values.get("enable") is False: values = {**values, "cfg": None} + else: + raise ValueError( + f"QuantizerCfgEntry 'cfg' must specify at least one quantizer attribute; " + f"got an empty mapping/list for quantizer " + f"{values.get('quantizer_name')!r}. To keep existing attributes, drop " + f"'cfg' and rely on 'enable=True'; to disable, set 'enable=False'." + ) return values @model_validator(mode="after") @@ -225,23 +239,6 @@ def _validate_instruction(self): f"'quantizer_name'={self.quantizer_name!r} has no effect (implicit enable=True " "is not allowed; set it explicitly)." ) - - if self.enable and self.cfg is not None: - if isinstance(self.cfg, dict): - is_invalid = len(self.cfg) == 0 - elif isinstance(self.cfg, list): - is_invalid = len(self.cfg) == 0 or any( - not isinstance(item, dict) or len(item) == 0 for item in self.cfg - ) - else: - is_invalid = True - if is_invalid: - raise ValueError( - f"QuantizerCfgEntry 'cfg' must be a non-empty dict or a non-empty list of " - f"non-empty dicts when enabling quantizer {self.quantizer_name!r}, got " - f"{type(self.cfg).__name__}: {self.cfg!r}. Either provide quantizer " - "attributes in 'cfg' or remove 'cfg' and set 'enable' explicitly." - ) return self @@ -1178,27 +1175,6 @@ def normalize_quant_cfg( """ return normalize_quant_cfg_list(v) - @field_validator("quant_cfg", mode="after") - @classmethod - def validate_quant_cfg_entries(cls, v: QuantizeQuantCfgType) -> QuantizeQuantCfgType: - """Validate each entry's ``cfg`` against :class:`QuantizerAttributeConfig`. - - Runs after the ``mode="before"`` normalizer and pydantic's field-type check, so - every element here is already a :class:`QuantizerCfgEntry`. This second pass - surfaces attribute-level errors (e.g. invalid ``axis`` / ``block_sizes``) that the - per-entry ``QuantizerCfgEntry`` validator doesn't inspect. - """ - qac_fields = set(QuantizerAttributeConfig.model_fields.keys()) - for entry in v: - cfg = entry.cfg - if cfg is None: - continue - cfgs: list[dict[str, Any]] = cfg if isinstance(cfg, list) else [cfg] - for c in cfgs: - if isinstance(c, dict) and qac_fields & set(c.keys()): - QuantizerAttributeConfig.model_validate(c) - return v - class CompressConfig(ModeloptBaseConfig): """Default configuration for ``compress`` mode.""" diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 759b629f2db..e8d0d33b6c7 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -317,7 +317,7 @@ def test_import_resolves_cfg_reference(tmp_path): ) recipe = load_recipe(recipe_file) entry = recipe.quantize["quant_cfg"][0] - assert entry["cfg"] == {"num_bits": (4, 3), "axis": None} + assert entry["cfg"].model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": None} def test_import_same_name_used_twice(tmp_path): @@ -390,7 +390,10 @@ def test_import_inline_cfg_not_affected(tmp_path): f" axis: 0\n" ) recipe = load_recipe(recipe_file) - assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": 8, "axis": 0} + assert recipe.quantize["quant_cfg"][1]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": 8, + "axis": 0, + } def test_import_unknown_reference_raises(tmp_path): @@ -619,7 +622,7 @@ def test_import_cfg_extend(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "axis": 0} + assert cfg.model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": 0} def test_import_cfg_inline_overrides_import(tmp_path): @@ -682,6 +685,7 @@ def test_import_in_multiple_dict_values(tmp_path): ) data = load_config(config_file) entry = data["quant_cfg"][0] + # load_config has no schema here — data is a raw dict tree, so entry["cfg"] is a dict. assert entry["cfg"] == {"num_bits": (4, 3)} assert entry["my_field"] == {"fake_quant": False} @@ -706,7 +710,7 @@ def test_import_cfg_multi_import(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "axis": 0} + assert cfg.model_dump(exclude_unset=True) == {"num_bits": (4, 3), "axis": 0} def test_import_cfg_multi_import_later_overrides_earlier(tmp_path): @@ -755,7 +759,11 @@ def test_import_cfg_multi_import_with_extend(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3), "fake_quant": False, "axis": 0} + assert cfg.model_dump(exclude_unset=True) == { + "num_bits": (4, 3), + "fake_quant": False, + "axis": 0, + } def test_import_dir_format(tmp_path): @@ -772,7 +780,10 @@ def test_import_dir_format(tmp_path): " $import: fp8\n" ) recipe = load_recipe(tmp_path) - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3), "axis": None} + assert recipe.quantize["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (4, 3), + "axis": None, + } def test_import_dir_format_metadata_imports_do_not_apply_to_quantize(tmp_path): @@ -826,7 +837,9 @@ def test_import_multi_document_list_snippet(tmp_path): recipe = load_recipe(recipe_file) assert len(recipe.quantize["quant_cfg"]) == 1 assert recipe.quantize["quant_cfg"][0]["quantizer_name"] == "*[kv]_bmm_quantizer" - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert recipe.quantize["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (4, 3) + } def test_import_builtin_kv_fp8_snippet(): @@ -938,7 +951,7 @@ def test_import_mixed_tree(tmp_path): ) data = load_config(config_file) # Dict import inside list entry - assert data["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert data["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": (4, 3)} # List splice — entries are normalized by QuantizeConfig.quant_cfg's validator, # which fills in defaults for missing ``enable`` / ``cfg`` keys. Entries are now # QuantizerCfgEntry pydantic instances, so compare via model_dump. @@ -986,7 +999,7 @@ def test_import_recursive(tmp_path): ) recipe = load_recipe(recipe_file) cfg = recipe.quantize["quant_cfg"][0]["cfg"] - assert cfg == {"num_bits": (4, 3)} + assert cfg.model_dump(exclude_unset=True) == {"num_bits": (4, 3)} def test_import_circular_raises(tmp_path): @@ -1086,9 +1099,14 @@ def test_import_cross_file_same_name_no_conflict(tmp_path): ) recipe = load_recipe(recipe_file) # Parent's "fmt" resolves to fp8 (e4m3), not child's nvfp4. - assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)} + assert recipe.quantize["quant_cfg"][0]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (4, 3) + } # Child's "fmt" resolves to nvfp4 (e2m1), not parent's fp8. - assert recipe.quantize["quant_cfg"][1]["cfg"] == {"num_bits": (2, 1), "axis": 0} + assert recipe.quantize["quant_cfg"][1]["cfg"].model_dump(exclude_unset=True) == { + "num_bits": (2, 1), + "axis": 0, + } # --------------------------------------------------------------------------- @@ -1179,13 +1197,12 @@ def test_modelopt_schema_comment_validates_after_import_resolution(tmp_path): f" $import: fp8\n" ) data = load_config(config_file) - # data is a list of QuantizerCfgEntry pydantic instances, not raw dicts. + # data is a list of QuantizerCfgEntry pydantic instances, not raw dicts. Dump with + # exclude_unset=True so the inner QuantizerAttributeConfig stays sparse (cascades). assert len(data) == 1 - assert data[0].model_dump() == { + assert data[0].model_dump(exclude_unset=True) == { "quantizer_name": "*weight_quantizer", - "parent_class": None, "cfg": {"num_bits": (4, 3)}, - "enable": True, } @@ -1291,12 +1308,12 @@ def test_load_config_list_valued_yaml(tmp_path): data = load_config(cfg_file) assert isinstance(data, list) assert len(data) == 2 - # Entries are QuantizerCfgEntry pydantic instances after schema validation. - assert data[0].model_dump() == { + # Entries are QuantizerCfgEntry pydantic instances after schema validation; dump + # with exclude_unset=True so the inner QuantizerAttributeConfig stays in sparse + # form (pydantic cascades exclude_unset to nested models). + assert data[0].model_dump(exclude_unset=True) == { "quantizer_name": "*weight_quantizer", - "parent_class": None, "cfg": {"num_bits": 8}, - "enable": True, } diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py index 00f7cb3605e..ce98f989f51 100644 --- a/tests/unit/torch/quantization/test_config_validation.py +++ b/tests/unit/torch/quantization/test_config_validation.py @@ -81,7 +81,7 @@ def test_new_format_passthrough(self): result = normalize_quant_cfg_list(raw) assert len(result) == 1 assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} assert result[0]["enable"] is True # defaulted def test_new_format_enable_false(self): @@ -103,7 +103,7 @@ def test_legacy_single_key_dict(self): raw = [{"*weight_quantizer": {"num_bits": 8, "axis": 0}}] result = normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} assert result[0]["enable"] is True # defaulted def test_legacy_single_key_dict_with_enable(self): @@ -166,19 +166,19 @@ def test_error_on_multi_key_legacy_dict(self): def test_error_on_empty_cfg_dict_implicit_enable(self): """Entry with cfg={} and implicit enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list([{"quantizer_name": "*weight_quantizer", "cfg": {}}]) def test_error_on_empty_cfg_dict_explicit_enable_true(self): """Entry with cfg={} and explicit enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": {}, "enable": True}] ) def test_error_on_empty_cfg_list_enable_true(self): """Entry with cfg=[] and enable=True is rejected.""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": [], "enable": True}] ) @@ -200,7 +200,7 @@ def test_error_on_non_dict_non_list_cfg_enable_true(self): def test_error_on_cfg_list_with_empty_dict_enable_true(self): """Entry with cfg=[{}] and enable=True is rejected (empty dict element).""" - with pytest.raises(ValueError, match="non-empty dict"): + with pytest.raises(ValueError, match=r"at least one quantizer attribute"): normalize_quant_cfg_list( [{"quantizer_name": "*weight_quantizer", "cfg": [{}], "enable": True}] ) @@ -260,7 +260,7 @@ def test_nonempty_cfg_enable_false_preserved(self): ] ) assert result[0]["enable"] is False - assert result[0]["cfg"] == {"num_bits": 4} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 4} def test_new_format_with_list_cfg(self): """cfg can be a list of dicts for SequentialQuantizer.""" @@ -275,7 +275,7 @@ def test_new_format_with_list_cfg(self): ] result = normalize_quant_cfg_list(raw) assert len(result) == 1 - assert result[0]["cfg"] == raw[0]["cfg"] + assert [c.model_dump(exclude_unset=True) for c in result[0]["cfg"]] == raw[0]["cfg"] assert result[0]["enable"] is True def test_legacy_flat_dict_conversion(self): @@ -287,7 +287,7 @@ def test_legacy_flat_dict_conversion(self): assert result[0]["enable"] is False assert result[0]["cfg"] is None assert result[1]["quantizer_name"] == "*weight_quantizer" - assert result[1]["cfg"] == {"num_bits": 8, "axis": 0} + assert result[1]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": 0} assert result[1]["enable"] is True def test_legacy_enable_only_produces_cfg_none(self): @@ -318,7 +318,7 @@ def test_legacy_default_key_with_cfg(self): raw = [{"default": {"num_bits": 8, "axis": None}}] result = normalize_quant_cfg_list(raw) assert result[0]["quantizer_name"] == "*" - assert result[0]["cfg"] == {"num_bits": 8, "axis": None} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 8, "axis": None} assert result[0]["enable"] is True def test_legacy_flat_dict_with_default_key(self): @@ -353,7 +353,7 @@ def test_legacy_nn_class_with_cfg(self): assert len(result) == 1 assert result[0]["parent_class"] == "nn.Linear" assert result[0]["quantizer_name"] == "*weight_quantizer" - assert result[0]["cfg"] == {"num_bits": 4, "axis": 0} + assert result[0]["cfg"].model_dump(exclude_unset=True) == {"num_bits": 4, "axis": 0} assert result[0]["enable"] is True def test_legacy_list_valued_cfg(self): @@ -387,7 +387,7 @@ def test_finds_last_match(self): ] ) result = find_quant_cfg_entry_by_path(entries, "*weight_quantizer") - assert result["cfg"] == {"num_bits": 4} + assert result["cfg"].model_dump(exclude_unset=True) == {"num_bits": 4} def test_exact_match_only(self): """Does not do fnmatch — only exact string equality on quantizer_name.""" @@ -444,7 +444,7 @@ def test_wildcard_matches_bare_name(self): [{"quantizer_name": "*weight_quantizer", "cfg": {"num_bits": 8}}] ) matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 8} + assert matched.model_dump(exclude_unset=True) == {"num_bits": 8} assert enable is True def test_star_matches_any_bare_name(self): @@ -464,7 +464,7 @@ def test_path_scoped_pattern_matches_matching_suffix(self): [{"quantizer_name": "*mlp*weight_quantizer", "cfg": {"num_bits": 4}}] ) matched, enable = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 4} + assert matched.model_dump(exclude_unset=True) == {"num_bits": 4} def test_path_scoped_pattern_does_not_match_different_suffix(self): """'*mlp*weight_quantizer' does NOT match bare 'input_quantizer'.""" @@ -488,7 +488,7 @@ def test_last_match_wins(self): ] ) matched, _ = _match_quantizer_cfg(quant_cfg, "weight_quantizer") - assert matched == {"num_bits": 4} + assert matched.model_dump(exclude_unset=True) == {"num_bits": 4} def test_no_match_returns_none(self): """No matching entry returns (None, None)."""