Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
e21a9ba
feat(opt): make load_config return validated schema instances
shengliangxu May 14, 2026
198a305
feat(quant): make QuantizerCfgEntry a ModeloptBaseConfig pydantic type
shengliangxu May 14, 2026
058234a
refactor(quant): tighten type hints on QuantizeConfig field validators
shengliangxu May 14, 2026
02513b6
refactor(quant): name the legacy flat-dict quant_cfg input shape
shengliangxu May 14, 2026
c855d2a
refactor(quant): widen quant_cfg input types to Mapping/Sequence
shengliangxu May 14, 2026
0b6b2f0
need to have model_dump for explicitly set k/v
shengliangxu May 14, 2026
c7ab593
refactor(recipe): make RecipeMetadataConfig a ModeloptBaseConfig
shengliangxu May 14, 2026
df4ffb4
fix(recipe): require metadata and quantize sections
shengliangxu May 15, 2026
0d087e0
feat(opt): make ModeloptBaseConfig a real MutableMapping
shengliangxu May 15, 2026
77c8e67
fix(quant): normalize empty cfg to None when disabling a quantizer
shengliangxu May 15, 2026
c43c0f0
fix(quant): keep shared cfg snippets as dicts in public constants
shengliangxu May 15, 2026
a63d420
fix(opt): __setitem__ raises KeyError for unknown keys
shengliangxu May 15, 2026
d7b6e0a
test(quant): tighten cfg-shape rejection assertions
shengliangxu May 15, 2026
251ea06
refactor(quant): schematize QuantizerCfgEntry.cfg as QuantizerAttribu…
shengliangxu May 15, 2026
21d2b35
Merge branch 'main' into shengliangx/schematize-cfg
shengliangxu May 15, 2026
0b788e4
Merge branch 'main' into shengliangx/schematize-cfg
shengliangxu May 15, 2026
11ebae0
Merge branch 'main' into shengliangx/schematize-cfg
shengliangxu May 15, 2026
23c9060
Merge branch 'main' into shengliangx/schematize-cfg
shengliangxu May 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions modelopt/onnx/llm_export_utils/quantization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def get_quant_config(precision, lm_head_precision="fp16"):
else:
raise ValueError(f"Unsupported precision: {precision}")

quant_cfg_list: list = [
e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_name" in e
]
quant_cfg_list: list = [e for e in quant_cfg["quant_cfg"] if "quantizer_name" in e]

if lm_head_precision == "fp8":
quant_cfg_list.append(
Expand Down
55 changes: 24 additions & 31 deletions modelopt/recipe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
import warnings
from enum import Enum

from pydantic import field_validator, model_validator
from typing_extensions import NotRequired, TypedDict
from pydantic import Field, model_validator

from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
from modelopt.torch.quantization.config import QuantizeConfig
from modelopt.torch.quantization.config import QuantizeConfig # noqa: TC001
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig, MedusaConfig
from modelopt.torch.speculative.plugins.hf_training_args import DataArguments as SpecDataArgs
from modelopt.torch.speculative.plugins.hf_training_args import ModelArguments as SpecModelArgs
Expand All @@ -43,14 +42,21 @@ class RecipeType(str, Enum):
# QAT = "qat" # Not implemented yet, will be added in the future.


class RecipeMetadataConfig(TypedDict):
"""YAML shape of the recipe metadata section."""
_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe."

recipe_type: RecipeType
description: NotRequired[str]

class RecipeMetadataConfig(ModeloptBaseConfig):
"""YAML shape of the recipe metadata section."""

_DEFAULT_RECIPE_DESCRIPTION = "Model optimization recipe."
recipe_type: RecipeType = Field(
title="Recipe type",
description="The type of the recipe (e.g. PTQ).",
)
description: str = ModeloptField(
default=_DEFAULT_RECIPE_DESCRIPTION,
title="Description",
description="Human-readable description of the recipe.",
)


def _metadata_field(recipe_type: RecipeType):
Expand All @@ -69,45 +75,32 @@ class ModelOptRecipeBase(ModeloptBaseConfig):
If a layer name matches ``"*output_layer*"``, the attributes will be replaced with ``{"enable": False}``.
"""

metadata: RecipeMetadataConfig = ModeloptField(
default={"recipe_type": RecipeType.PTQ, "description": _DEFAULT_RECIPE_DESCRIPTION},
metadata: RecipeMetadataConfig = Field(
title="Metadata",
description="Recipe metadata containing the recipe type and description.",
validate_default=True,
description="Recipe metadata containing the recipe type and description. "
"Required: a recipe without a ``metadata`` section is rejected so that a "
"missing section can't silently fall back to a default recipe type.",
)

@field_validator("metadata")
@classmethod
def validate_metadata(cls, metadata: RecipeMetadataConfig) -> RecipeMetadataConfig:
"""Validate recipe metadata and fill defaults for optional fields."""
if metadata["recipe_type"] not in RecipeType:
raise ValueError(
f"Unsupported recipe type: {metadata['recipe_type']}. "
f"Only {list(RecipeType)} are currently supported."
)
return {"description": _DEFAULT_RECIPE_DESCRIPTION, **metadata}

@property
def recipe_type(self) -> RecipeType:
"""Return the recipe type from metadata."""
return self.metadata["recipe_type"]
return self.metadata.recipe_type

@property
def description(self) -> str:
"""Return the recipe description from metadata."""
return self.metadata.get("description", _DEFAULT_RECIPE_DESCRIPTION)
return self.metadata.description


class ModelOptPTQRecipe(ModelOptRecipeBase):
"""Our config class for PTQ recipes."""

metadata: RecipeMetadataConfig = _metadata_field(RecipeType.PTQ)

quantize: QuantizeConfig = ModeloptField(
default=QuantizeConfig(),
quantize: QuantizeConfig = Field(
title="PTQ config",
description="PTQ config containing quant_cfg and algorithm.",
validate_default=True,
description="PTQ config containing quant_cfg and algorithm. Required: a PTQ "
"recipe without a ``quantize`` section is rejected so that a missing section "
"can't silently fall back to the default INT8 config.",
)


Expand Down
125 changes: 51 additions & 74 deletions modelopt/recipe/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@

from .config import (
RECIPE_TYPE_TO_CLASS,
ModelOptDFlashRecipe,
ModelOptEagleRecipe,
ModelOptMedusaRecipe,
ModelOptPTQRecipe,
ModelOptRecipeBase,
RecipeMetadataConfig,
Expand All @@ -40,6 +37,16 @@

__all__ = ["load_config", "load_recipe"]

# Each recipe type's mandatory top-level body section. Checked at the loader level (on the
# raw YAML, before pydantic fills in defaults) so the user sees a clear "PTQ recipe file X
# must contain 'quantize'" instead of pydantic's generic missing-field error.
_REQUIRED_SECTION_PER_RECIPE_TYPE: dict[RecipeType, str] = {
RecipeType.PTQ: "quantize",
RecipeType.SPECULATIVE_EAGLE: "eagle",
RecipeType.SPECULATIVE_DFLASH: "dflash",
RecipeType.SPECULATIVE_MEDUSA: "medusa",
}


def _resolve_recipe_path(recipe_path: str | Path | Traversable) -> Path | Traversable:
"""Resolve a recipe path, checking the built-in library first then the filesystem.
Expand Down Expand Up @@ -148,63 +155,48 @@ def _load_recipe_from_file(
plus the algorithm-specific section (``quantize`` / ``eagle`` / ``dflash`` / ``medusa``).
"""
rtype = _peek_recipe_type(recipe_file)
schema_type = RECIPE_TYPE_TO_CLASS.get(rtype) if rtype is not None else None
data = load_config(recipe_file, schema_type=schema_type)
if not isinstance(data, dict):
raise ValueError(
f"Recipe file {recipe_file} must be a YAML mapping, got {type(data).__name__}."
)
if rtype is None:
raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.")
schema_class = RECIPE_TYPE_TO_CLASS.get(rtype)
if schema_class is None:
raise ValueError(f"Unsupported recipe type: {rtype!r}")

# Pre-flight check on the *raw* YAML so the user sees a clear loader-level error
# rather than a generic pydantic missing-field error. Speculative recipes' body
# sections have field-level defaults, so this check is what keeps their loader
# semantics consistent with PTQ.
required_section = _REQUIRED_SECTION_PER_RECIPE_TYPE.get(rtype)
if required_section is not None:
import yaml

raw = yaml.safe_load(recipe_file.read_text()) or {}
if not isinstance(raw, dict) or required_section not in raw:
kind = (
rtype.value.split("_", 1)[-1].upper() if "_" in rtype.value else rtype.value.upper()
)
raise ValueError(f"{kind} recipe file {recipe_file} must contain {required_section!r}.")

# Passing ``schema_type=schema_class`` to ``load_config`` enables typed-list
# ``$import`` resolution (e.g. ``$import: disable_all`` spliced into
# ``quantize.quant_cfg`` needs to know the list's element schema is
# :class:`QuantizerCfgEntry`). The return value is already a validated schema
# instance.
if overrides:
# Overrides have to be applied before pydantic validation. Round-trip through
# ``model_dump()`` so $imports are resolved and the dict has the resolved shape;
# then splice the dotlist values and re-validate.
recipe = load_config(recipe_file, schema_type=schema_class)
data = recipe.model_dump()
data = _apply_dotlist(data, overrides)
return schema_class.model_validate(data)

metadata = data.get("metadata", {})
if not isinstance(metadata, dict):
recipe = load_config(recipe_file, schema_type=schema_class)
if not isinstance(recipe, schema_class):
raise ValueError(
f"Recipe file {recipe_file} field 'metadata' must be a mapping, "
f"got {type(metadata).__name__}."
)
recipe_type = metadata.get("recipe_type")
if recipe_type is None:
raise ValueError(f"Recipe file {recipe_file} must contain a 'metadata.recipe_type' field.")

if recipe_type == RecipeType.PTQ:
if "quantize" not in data:
raise ValueError(f"PTQ recipe file {recipe_file} must contain 'quantize'.")
return ModelOptPTQRecipe(
metadata=metadata,
quantize=data["quantize"],
f"Recipe file {recipe_file} must produce a {schema_class.__name__}, "
f"got {type(recipe).__name__}."
)
if recipe_type == RecipeType.SPECULATIVE_EAGLE:
if "eagle" not in data:
raise ValueError(f"EAGLE recipe file {recipe_file} must contain 'eagle'.")
return ModelOptEagleRecipe(
metadata=metadata,
model=data.get("model") or {},
data=data.get("data") or {},
training=data.get("training") or {},
eagle=data["eagle"],
)
if recipe_type == RecipeType.SPECULATIVE_DFLASH:
if "dflash" not in data:
raise ValueError(f"DFlash recipe file {recipe_file} must contain 'dflash'.")
return ModelOptDFlashRecipe(
metadata=metadata,
model=data.get("model") or {},
data=data.get("data") or {},
training=data.get("training") or {},
dflash=data["dflash"],
)
if recipe_type == RecipeType.SPECULATIVE_MEDUSA:
if "medusa" not in data:
raise ValueError(f"Medusa recipe file {recipe_file} must contain 'medusa'.")
return ModelOptMedusaRecipe(
metadata=metadata,
model=data.get("model") or {},
data=data.get("data") or {},
training=data.get("training") or {},
medusa=data["medusa"],
)
raise ValueError(f"Unsupported recipe type: {recipe_type!r}")
return recipe


def _find_recipe_section_file(
Expand All @@ -229,25 +221,10 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase:
quantize.
"""
metadata_file = _find_recipe_section_file(recipe_dir, "metadata")

metadata = load_config(metadata_file, schema_type=RecipeMetadataConfig)
if not isinstance(metadata, dict):
raise ValueError(
f"Metadata file {metadata_file} must be a YAML mapping, got {type(metadata).__name__}."
)
recipe_type = metadata.get("recipe_type")
if recipe_type is None:
raise ValueError(f"Metadata file {metadata_file} must contain a 'recipe_type' field.")

if recipe_type == RecipeType.PTQ:
if metadata.recipe_type == RecipeType.PTQ:
quantize_file = _find_recipe_section_file(recipe_dir, "quantize")
quantize_data = load_config(quantize_file, schema_type=QuantizeConfig)
if not isinstance(quantize_data, dict):
raise ValueError(
f"{quantize_file} must be a YAML mapping, got {type(quantize_data).__name__}."
)
return ModelOptPTQRecipe(
metadata=metadata,
quantize=quantize_data,
)
raise ValueError(f"Unsupported recipe type: {recipe_type!r}")
quantize_cfg = load_config(quantize_file, schema_type=QuantizeConfig)
return ModelOptPTQRecipe(metadata=metadata, quantize=quantize_cfg)
raise ValueError(f"Unsupported recipe type: {metadata.recipe_type!r}")
52 changes: 45 additions & 7 deletions modelopt/torch/opt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import fnmatch
import json
from collections.abc import Callable, ItemsView, Iterator, KeysView, ValuesView
from collections.abc import Callable, ItemsView, Iterator, KeysView, MutableMapping, ValuesView
from typing import Any, TypeAlias

import torch
Expand Down Expand Up @@ -57,11 +57,18 @@ def ModeloptField(default: Any = PydanticUndefined, **kwargs): # noqa: N802
# TODO: expand config classes to searcher


class ModeloptBaseConfig(BaseModel):
class ModeloptBaseConfig(BaseModel, MutableMapping):
"""Our config base class for mode configuration.

The base class extends the capabilities of pydantic's BaseModel to provide additional methods
and properties for easier access and manipulation of the configuration.

Inherits from :class:`collections.abc.MutableMapping` so instances satisfy
``isinstance(cfg, Mapping)`` / ``isinstance(cfg, MutableMapping)`` checks and pick up the
mixin methods (``pop``, ``popitem``, ``setdefault``, ``clear``). Schema fields are fixed,
so ``__delitem__`` raises :class:`TypeError`; the inherited ``pop`` / ``clear`` /
``popitem`` therefore also raise on any existing key, while ``pop(key, default)`` for a
missing key still returns the default normally.
"""

model_config = PyDanticConfigDict(extra="forbid", validate_assignment=True)
Expand Down Expand Up @@ -110,18 +117,49 @@ def __contains__(self, key: str) -> bool:
return False

def __getitem__(self, key: str) -> Any:
"""Get the value for the given key (can be name or alias of field)."""
return getattr(self, self.get_field_name_from_key(key))
"""Get the value for the given key (can be name or alias of field).

Raises :class:`KeyError` for missing keys so the class behaves like a regular
:class:`Mapping` — required for the inherited ``MutableMapping`` mixin methods
(``pop``, ``setdefault``, ...) to dispatch correctly.
"""
try:
return getattr(self, self.get_field_name_from_key(key))
except AttributeError:
raise KeyError(key) from None

def __setitem__(self, key: str, value: Any) -> None:
"""Set the value for the given key (can be name or alias of field)."""
setattr(self, self.get_field_name_from_key(key), value)
"""Set the value for the given key (can be name or alias of field).

Raises :class:`KeyError` (not :class:`AttributeError`) for unknown keys so the
class matches the :class:`MutableMapping` protocol — both for direct
``cfg["unknown"] = value`` writes and for inherited mixin helpers like
``setdefault`` that write through ``__setitem__``.
"""
try:
setattr(self, self.get_field_name_from_key(key), value)
except AttributeError:
raise KeyError(key) from None

def __delitem__(self, key: str) -> None:
"""Reject key deletion.

``ModeloptBaseConfig`` exposes a fixed pydantic schema, so removing a key is
ill-defined: schema fields can't disappear, and silently resetting them to their
defaults would surprise callers. Raise ``TypeError`` instead. Defined so the
class fully satisfies the ``MutableMapping`` protocol (``__delitem__`` is
required), without committing to actual deletion semantics.
"""
raise TypeError(
f"{type(self).__name__} does not support key deletion; schema fields are "
f"fixed (attempted to delete {key!r})."
)

def get(self, key: str, default: Any = None) -> Any:
"""Get the value for the given key (can be name or alias) or default if not found."""
Comment thread
shengliangxu marked this conversation as resolved.
try:
return self[key]
except AttributeError:
except KeyError:
return default

def __len__(self) -> int:
Expand Down
Loading
Loading