Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
04f918e
Type YAML config loading with Pydantic schemas
shengliangxu May 6, 2026
fdad74c
simplify quantize config loading
shengliangxu May 6, 2026
b38703a
Schematize quantizer config entries
shengliangxu May 7, 2026
b7c9359
use ModeloptField
shengliangxu May 7, 2026
b7df9d2
Return typed normalized quantizer config entries
shengliangxu May 7, 2026
53fcd04
Use typed mapping quant configs
shengliangxu May 7, 2026
b23e3a9
Add mixed quant config normalization test
shengliangxu May 7, 2026
5969cb3
Address quant config review feedback
shengliangxu May 8, 2026
0d31d46
Update recipe loader schema expectations
shengliangxu May 8, 2026
d33ee36
Tighten ModeloptBaseConfig mapping semantics
shengliangxu May 8, 2026
0917ab8
fix test errors
shengliangxu May 8, 2026
b5b45b9
Fix diffusers quant config explicit key handling
shengliangxu May 8, 2026
3f533fa
Merge remote-tracking branch 'origin/main' into shengliangx/schematiz…
shengliangxu May 8, 2026
164cccd
Merge remote-tracking branch 'origin/main' into shengliangx/schematiz…
shengliangxu May 9, 2026
d9eccf3
fix review comments
shengliangxu May 9, 2026
c69f86b
yaml for all hard coded PTQ configs
shengliangxu May 6, 2026
a2e033a
numerics yaml
shengliangxu May 7, 2026
8481658
Remove quantize config loader wrapper
shengliangxu May 8, 2026
b65f89d
Add KV quantization config units
shengliangxu May 8, 2026
abec47d
Remove stale FP8 config comments
shengliangxu May 9, 2026
e90653b
update int4 int8
shengliangxu May 9, 2026
df8c002
update descriptions
shengliangxu May 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions docs/source/guides/_quant_cfg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@ Each entry in the list is a dictionary with the following fields:
(e.g. ``"nn.Linear"``). If omitted, all modules are targeted regardless of class.
* - ``cfg``
- No
- A dict of quantizer attributes as defined by :class:`QuantizerAttributeConfig
<modelopt.torch.quantization.config.QuantizerAttributeConfig>`, or a list of such dicts
for sequential quantization (see :ref:`sequential-quantizers`).
- A :class:`QuantizerAttributeConfig
<modelopt.torch.quantization.config.QuantizerAttributeConfig>`, or a list of
``QuantizerAttributeConfig`` objects for sequential quantization (see
:ref:`sequential-quantizers`). Equivalent Python dicts, YAML mappings, and lists of
dicts are still accepted for backward compatibility, but those weakly schematized forms
are deprecated.
* - ``enable``
- No
- ``True`` or ``False``. Toggles matched quantizers on or off, independently of ``cfg``.
Expand All @@ -74,6 +77,11 @@ Each entry in the list is a dictionary with the following fields:
a bare ``{"quantizer_name": "*"}`` would silently behave as ``enable=True`` for all
quantizers.

Schema-backed YAML loading parses ``cfg`` mappings into
:class:`QuantizerAttributeConfig <modelopt.torch.quantization.config.QuantizerAttributeConfig>`
values. Plain Python dicts and lists of dicts are accepted only as a backward-compatible,
weakly schematized input format.

----------

Default Quantizer Configuration
Expand Down Expand Up @@ -278,7 +286,9 @@ For entirely custom recipes, compose the list directly:
Sequential Quantization
=======================

When ``cfg`` is a **list** of attribute dicts, the matched
When ``cfg`` is a **list** of
:class:`QuantizerAttributeConfig <modelopt.torch.quantization.config.QuantizerAttributeConfig>`
values, the matched
:class:`TensorQuantizer <modelopt.torch.quantization.nn.modules.tensor_quantizer.TensorQuantizer>`
is replaced with a
:class:`SequentialQuantizer <modelopt.torch.quantization.nn.modules.tensor_quantizer.SequentialQuantizer>`
Expand All @@ -295,6 +305,11 @@ are quantized first in INT4 and then in FP8:
],
}

The list-of-dict spelling shown above remains accepted for existing Python configs and is the
natural YAML spelling, but it is a deprecated weakly schematized input form. After schema-backed
loading or :class:`QuantizeConfig <modelopt.torch.quantization.config.QuantizeConfig>` parsing,
each element is a ``QuantizerAttributeConfig``.

----------

.. _migrating-from-dict-format:
Expand Down
13 changes: 11 additions & 2 deletions examples/diffusers/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Mapping, MutableMapping

import torch.nn as nn
from calib.plugin_calib import PercentileCalibrator

from modelopt.torch.quantization.config import QuantizerAttributeConfig

FP8_DEFAULT_CONFIG = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
Expand Down Expand Up @@ -104,8 +108,13 @@ def set_quant_config_attr(quant_config, trt_high_precision_dtype, quant_algo, **
quant_config["algorithm"] = algo_cfg

for entry in quant_config["quant_cfg"]:
p = entry.get("cfg", {})
if isinstance(p, dict) and "num_bits" in p and "trt_high_precision_dtype" not in p:
p = entry.get("cfg", {}) if isinstance(entry, Mapping) else {}
if not isinstance(p, MutableMapping):
continue
keys = p.explicit_keys() if isinstance(p, QuantizerAttributeConfig) else p.keys()
# TODO: Replace this membership-based config patching with a better config API;
# ``in``/``not in`` checks are fragile with schema-backed defaults.
if "num_bits" in keys and "trt_high_precision_dtype" not in keys:
p["trt_high_precision_dtype"] = trt_high_precision_dtype


Expand Down
35 changes: 24 additions & 11 deletions examples/diffusers/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
# limitations under the License.

import argparse
import copy
import logging
import sys
import time as time
from collections.abc import Mapping
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -49,6 +51,7 @@
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.opt.config import ModeloptBaseConfig


def setup_logging(verbose: bool = False) -> logging.Logger:
Expand Down Expand Up @@ -119,14 +122,6 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
base_cfg = mtq.INT8_SMOOTHQUANT_CFG
else:
base_cfg = INT8_DEFAULT_CONFIG
if self.config.collect_method != CollectMethod.DEFAULT:
reset_set_int8_config(
base_cfg,
self.config.percentile,
n_steps,
collect_method=self.config.collect_method.value,
backbone=backbone,
)
elif self.config.format == QuantFormat.FP8:
base_cfg = FP8_DEFAULT_CONFIG
elif self.config.format == QuantFormat.FP4:
Expand All @@ -138,15 +133,33 @@ def get_quant_config(self, n_steps: int, backbone: torch.nn.Module) -> Any:
raise NotImplementedError(f"Unknown format {self.config.format}")

# Build a fresh config dict so we never mutate the global constants.
if isinstance(base_cfg, ModeloptBaseConfig):
base_cfg = base_cfg.model_dump(exclude_unset=True)
base_cfg = copy.deepcopy(base_cfg)

if (
self.config.format == QuantFormat.INT8
and self.config.collect_method != CollectMethod.DEFAULT
):
reset_set_int8_config(
base_cfg,
self.config.percentile,
n_steps,
collect_method=self.config.collect_method.value,
backbone=backbone,
)

quant_cfg_list = list(base_cfg["quant_cfg"])

if self.config.format == QuantFormat.FP4:
for i, entry in enumerate(quant_cfg_list):
if isinstance(entry, dict) and "block_sizes" in entry.get("cfg", {}):
new_block_sizes = {**entry["cfg"]["block_sizes"], -1: self.config.block_size}
cfg = entry.get("cfg", {}) if isinstance(entry, Mapping) else {}
block_sizes = cfg.get("block_sizes") if isinstance(cfg, Mapping) else None
if isinstance(block_sizes, Mapping):
new_block_sizes = {**block_sizes, -1: self.config.block_size}
quant_cfg_list[i] = {
**entry,
"cfg": {**entry["cfg"], "block_sizes": new_block_sizes},
"cfg": {**cfg, "block_sizes": new_block_sizes},
}

if self.config.quantize_mha:
Expand Down
5 changes: 3 additions & 2 deletions examples/llm_autodeploy/run_auto_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.config import QuantizeConfig
from modelopt.torch.utils import create_forward_loop
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader

SUPPORT_QUANT_FORMAT = {
SUPPORT_QUANT_FORMAT: dict[str, QuantizeConfig] = {
"fp8": mtq.FP8_DEFAULT_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
}
Expand Down Expand Up @@ -87,7 +88,7 @@ def loss_func(output, data):
data_loader=calib_dataloader,
forward_step=lambda model, batch: model(**batch),
loss_func=loss_func,
quantization_formats=[SUPPORT_QUANT_FORMAT[format] for format in qformat_list],
quantization_formats=[SUPPORT_QUANT_FORMAT[quant_format] for quant_format in qformat_list],
num_calib_steps=len(calib_dataloader),
num_score_steps=min(
len(calib_dataloader), 128 // batch_size
Expand Down
3 changes: 2 additions & 1 deletion examples/llm_ptq/cast_mxfp4_to_nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"""

import json
from collections.abc import Mapping
from contextlib import ExitStack, contextmanager
from pathlib import Path

Expand Down Expand Up @@ -304,7 +305,7 @@ def force_weight_quantizers_static(quant_cfg: list) -> None:
qname = entry.get("quantizer_name", "")
cfg = entry.get("cfg") or {}
bs = cfg.get("block_sizes")
if "weight_quantizer" in qname and isinstance(bs, dict):
if "weight_quantizer" in qname and isinstance(bs, Mapping):
quant_cfg[i] = {**entry, "cfg": {**cfg, "block_sizes": {**bs, "type": "static"}}}


Expand Down
45 changes: 26 additions & 19 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import shutil
import sys
import warnings
from collections.abc import Mapping, MutableMapping
from pathlib import Path
from typing import Any

Expand All @@ -41,6 +42,8 @@
ProcessorMixin,
)

from modelopt.torch.quantization.config import QuantizeConfig, QuantizerCfgEntry

try:
from huggingface_hub import snapshot_download
except ImportError:
Expand Down Expand Up @@ -203,17 +206,17 @@ def calibrate_loop(_model):

def build_quant_cfg(
qformat,
quant_cfg,
quant_cfg: QuantizeConfig | Mapping[str, Any],
awq_block_size,
model_type,
moe_calib_experts_ratio: float | None = None,
) -> dict[str, Any]:
quant_cfg = copy.deepcopy(quant_cfg)
if "awq" in str(quant_cfg.get("algorithm")):
) -> QuantizeConfig:
quant_cfg_obj: QuantizeConfig = QuantizeConfig.model_validate(copy.deepcopy(quant_cfg))
if "awq" in str(quant_cfg_obj.get("algorithm")):
from modelopt.torch.quantization.config import find_quant_cfg_entry_by_path

weight_quantizer_entry = find_quant_cfg_entry_by_path(
quant_cfg["quant_cfg"], "*weight_quantizer"
quant_cfg_obj["quant_cfg"], "*weight_quantizer"
)
weight_quantizer = weight_quantizer_entry.get("cfg") or {}
if isinstance(weight_quantizer, list):
Expand All @@ -224,34 +227,38 @@ def build_quant_cfg(

# Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
quant_cfg_obj["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

if moe_calib_experts_ratio:
assert 0 < moe_calib_experts_ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
if isinstance(quant_cfg["algorithm"], str):
quant_cfg["algorithm"] = {
"method": quant_cfg["algorithm"],
if isinstance(quant_cfg_obj["algorithm"], str):
quant_cfg_obj["algorithm"] = {
"method": quant_cfg_obj["algorithm"],
"moe_calib_experts_ratio": moe_calib_experts_ratio,
}
elif isinstance(quant_cfg["algorithm"], dict):
quant_cfg["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio
elif isinstance(quant_cfg_obj["algorithm"], MutableMapping):
quant_cfg_obj["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio
else:
warnings.warn(
f"Quantization algorithm: {quant_cfg['algorithm']} does not support setting moe_calib_experts_ratio"
f"Quantization algorithm: {quant_cfg_obj['algorithm']} does not support setting moe_calib_experts_ratio"
)

# Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
if model_type == "gemma" and "int8_sq" in qformat:
quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
quant_cfg_obj["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

if model_type == "phi4mm":
# Only quantize the language model
quant_cfg["quant_cfg"].append({"quantizer_name": "*speech*", "enable": False})
quant_cfg["quant_cfg"].append({"quantizer_name": "*audio*", "enable": False})
quant_cfg["quant_cfg"].append({"quantizer_name": "*image*", "enable": False})
quant_cfg["quant_cfg"].append({"quantizer_name": "*vision*", "enable": False})
quant_cfg_obj["quant_cfg"].extend(
[
QuantizerCfgEntry(quantizer_name="*speech*", enable=False),
QuantizerCfgEntry(quantizer_name="*audio*", enable=False),
QuantizerCfgEntry(quantizer_name="*image*", enable=False),
QuantizerCfgEntry(quantizer_name="*vision*", enable=False),
]
)

return quant_cfg
return quant_cfg_obj


def is_speculative(hf_config):
Expand Down Expand Up @@ -842,7 +849,7 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
def needs_checkpoint_path_update(quant_cfg: dict) -> bool:
"""Check if quant_cfg has a layerwise_checkpoint_dir that should be auto-resolved to a unique subpath."""
algorithm = quant_cfg.get("algorithm")
if not isinstance(algorithm, dict):
if not isinstance(algorithm, Mapping):
return False
return algorithm.get("layerwise_checkpoint_dir") is not None

Expand Down
20 changes: 13 additions & 7 deletions examples/llm_ptq/hf_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@
save_expert_token_count_table,
)
from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
from modelopt.torch.quantization.config import (
QuantizeConfig,
_default_disabled_quantizer_cfg,
need_calibration,
)
from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
from modelopt.torch.quantization.utils import is_quantized
from modelopt.torch.speculative.eagle.utils import (
Expand All @@ -89,18 +93,20 @@
def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
"""Set use_constant_amax on KV cache quantizers.

Creates a new dict for the KV bmm quantizer config to avoid mutating shared references.
Updates the matched KV bmm quantizer entry in place.
"""
for i, entry in enumerate(quant_cfg):
for entry in quant_cfg:
if entry.get("quantizer_name") != "*[kv]_bmm_quantizer":
continue
cfg = entry.get("cfg") or {}
assert isinstance(cfg, dict)
quant_cfg[i] = {**entry, "cfg": {**cfg, "use_constant_amax": True}}
cfg = entry.get("cfg")
if cfg is None:
cfg = {}
cfg["use_constant_amax"] = True
entry["cfg"] = cfg
break


QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
QUANT_CFG_CHOICES: dict[str, QuantizeConfig] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
Expand Down
5 changes: 2 additions & 3 deletions examples/llm_ptq/multinode_ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import time
import warnings
from pathlib import Path
from typing import Any

import numpy as np
import torch
Expand All @@ -37,14 +36,14 @@
from modelopt.torch.export import get_model_type
from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
from modelopt.torch.quantization.config import need_calibration
from modelopt.torch.quantization.config import QuantizeConfig, need_calibration
from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets

# Constants
RAND_SEED = 1234

QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
QUANT_CFG_CHOICES: dict[str, QuantizeConfig] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
Expand Down
4 changes: 2 additions & 2 deletions examples/vllm_serve/vllm_ptq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import dataclasses
from collections.abc import Callable
from collections.abc import Callable, Mapping
from typing import Any

import torch
Expand Down Expand Up @@ -122,7 +122,7 @@ def update_kv_cfg_for_mla(model: torch.nn.Module, kv_quant_cfg: list) -> list:
(
e
for e in kv_quant_cfg
if isinstance(e, dict) and e.get("quantizer_name") == "*[kv]_bmm_quantizer"
if isinstance(e, Mapping) and e.get("quantizer_name") == "*[kv]_bmm_quantizer"
),
None,
)
Expand Down
4 changes: 1 addition & 3 deletions modelopt/onnx/llm_export_utils/quantization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def get_quant_config(precision, lm_head_precision="fp16"):
else:
raise ValueError(f"Unsupported precision: {precision}")

quant_cfg_list: list = [
e for e in quant_cfg["quant_cfg"] if isinstance(e, dict) and "quantizer_name" in e
]
quant_cfg_list: list = list(quant_cfg["quant_cfg"])

if lm_head_precision == "fp8":
quant_cfg_list.append(
Expand Down
Loading
Loading