1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -28,6 +28,7 @@ Changelog

**New Features**

- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, and no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``; a usage sketch follows below. vLLM deployment support is in progress.
- Support the full Transformer Engine spec for Minitron pruning (``mcore_minitron``), so the custom ModelOpt spec is no longer needed. This does not change how the pruning workflow is used, but it makes pruning slightly faster and may yield a slightly different pruned model due to different kernels and numerics.
- Add an end-to-end tutorial for Minitron pruning + distillation + quantization + evaluation + vLLM deployment of Nemotron-Nano-9B-v2 → pruned 7B, along with data-blend preparation steps and an ablation study. See `examples/pruning/minitron/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning/minitron/>`_ for details.
- Add Puzzletron, a new algorithm for heterogeneous pruning of LLMs and VLMs. See `examples/puzzletron/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/puzzletron>`_ for more details.
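A minimal usage sketch for the new w4a16_nvfp4 entry above, assuming the standard mtq.quantize API; the model name is a placeholder. Because the format is weight-only with BF16 activations, no calibration forward_loop is needed.

import torch
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM

# Placeholder model; any Hugging Face causal LM works the same way.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16
)

# Weight-only FP4 with group_size=16: activations stay in BF16, so "max"
# calibration of the weights needs no calibration forward pass.
model = mtq.quantize(model, mtq.W4A16_NVFP4_CFG, forward_loop=None)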
7 changes: 7 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -113,6 +113,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
"fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
"w4a16_nvfp4": mtq.W4A16_NVFP4_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG,
@@ -785,6 +786,12 @@ def export_quantized(
extra_state_dict=mtp_state_dict,
)

if args.qformat == "w4a16_nvfp4":
warnings.warn(
"TensorRT-LLM and SGLang do not support this format. "
"vLLM deployment support is in progress."
)

# Restore default padding and export the tokenizer as well.
if tokenizer is not None:
tokenizer.padding_side = default_padding_side
14 changes: 12 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian | w4a16_nvfp4) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian, w4a16_nvfp4]" >&2
exit 1
;;
esac
@@ -127,6 +127,10 @@ if $TRUST_REMOTE_CODE; then
PTQ_ARGS+=" --trust_remote_code "
fi

if [ -n "${EXCLUDE_MODULES:-}" ]; then
PTQ_ARGS+=" --exclude_modules ${EXCLUDE_MODULES} "
fi

if $USE_SEQ_DEVICE_MAP; then
PTQ_ARGS+=" --use_seq_device_map "
fi
@@ -203,6 +207,12 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
exit 0
fi

if [ "$QFORMAT" = "w4a16_nvfp4" ]; then
echo "w4a16_nvfp4 checkpoint exported to $SAVE_PATH"
echo "To serve on vLLM, convert to compressed-tensors"
exit 0
fi

if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then
cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1)

13 changes: 13 additions & 0 deletions modelopt/torch/export/convert_hf_config.py
@@ -57,6 +57,11 @@ def _quant_algo_to_group_config(quant_algo: str, group_size: int | None = None)
return {
"weights": {"dynamic": False, "num_bits": 4, "type": "int", "group_size": gs},
}
elif quant_algo == "W4A16_NVFP4":
gs = group_size or 16
return {
"weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs},
}
elif quant_algo in ("NVFP4_AWQ", "W4A8_AWQ"):
gs = group_size or 128
return {
@@ -183,6 +188,14 @@ def convert_hf_quant_config_format(input_config: dict[str, Any]) -> dict[str, An
"targets": ["Linear"],
}
new_config["config_groups"] = {"group_0": config_group_details}
elif quant_algo_value == "W4A16_NVFP4":
# Weight-only FP4
group_size = original_quantization_details.get("group_size", 16)
config_group_details = {
"weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": group_size},
"targets": ["Linear"],
}
new_config["config_groups"] = {"group_0": config_group_details}
elif quant_algo_value == "MIXED_PRECISION":
quantized_layers = original_quantization_details.get("quantized_layers", {})

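A sketch of what the new W4A16_NVFP4 branch above is expected to produce. The input dict below is a hypothetical ModelOpt-style hf_quant_config whose exact key layout is an assumption, but the resulting config_groups entry mirrors the code added in this file.

from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format

# Hypothetical legacy-style config for a w4a16_nvfp4 export (key layout assumed).
legacy_cfg = {
    "quantization": {
        "quant_algo": "W4A16_NVFP4",
        "group_size": 16,
    }
}

new_cfg = convert_hf_quant_config_format(legacy_cfg)
# Per the branch added above, the converted config should contain:
# new_cfg["config_groups"]["group_0"] == {
#     "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": 16},
#     "targets": ["Linear"],
# }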
1 change: 1 addition & 0 deletions modelopt/torch/export/model_config.py
@@ -38,6 +38,7 @@
QUANTIZATION_MXFP4 = "mxfp4"
QUANTIZATION_MXFP8 = "mxfp8"
QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8"
QUANTIZATION_W4A16_NVFP4 = "w4a16_nvfp4" # weight-only FP4
QUANTIZATION_NVFP4_AWQ = "nvfp4_awq"
QUANTIZATION_FP8_PB_REAL = "fp8_pb_real"
QUANTIZATION_FP8_PB_WO = "fp8_pb_wo"
13 changes: 13 additions & 0 deletions modelopt/torch/export/quant_utils.py
@@ -69,6 +69,7 @@
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_MXFP4_FP8,
QUANTIZATION_W4A8_NVFP4_FP8,
QUANTIZATION_W4A16_NVFP4,
)

logger = logging.getLogger(__name__)
@@ -359,6 +360,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
# Calibrate weight quantizer if amax is not set
@@ -403,6 +405,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
# Calibrate weight quantizer if amax is not set
@@ -641,6 +644,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
return QUANTIZATION_NVFP4_AWQ
if getattr(layer, "fused_with_prequant", False):
return QUANTIZATION_NVFP4_AWQ
# W4A16 weight-only: input_quantizer absent or disabled
if input_quantizer is None or not input_quantizer.is_enabled:
if scale_bits == (4, 3):
return QUANTIZATION_W4A16_NVFP4
assert input_quantizer is not None, (
f"input_quantizer is None for {quantizer_attr_names}"
)
@@ -808,6 +815,11 @@ def process_layer_quant_config(layer_config_dict):
"quant_algo": "NVFP4",
"group_size": block_size_value,
}
elif v == "w4a16_nvfp4":
layer_config = {
"quant_algo": "W4A16_NVFP4",
"group_size": block_size_value,
}
elif v == "nvfp4_awq":
layer_config = {
"quant_algo": "NVFP4_AWQ",
@@ -985,6 +997,7 @@ def to_quantized_weight(
if quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_NVFP4_FP8,
QUANTIZATION_NVFP4_SVDQUANT,
]:
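An illustrative sketch of the detection rule added to _get_quantization_from_layer: a layer counts as w4a16_nvfp4 when its input quantizer is absent or disabled and its weight quantizer uses FP8 (4, 3) per-block scales. The tiny model below is a placeholder, and the quantizer attribute names are assumptions based on the code in this file.

import modelopt.torch.quantization as mtq
from torch import nn

# Placeholder model just to inspect the quantizer layout after applying
# the weight-only config; no calibration data is needed.
model = nn.Sequential(nn.Linear(64, 64))
model = mtq.quantize(model, mtq.W4A16_NVFP4_CFG, forward_loop=None)

layer = model[0]
# Expected for w4a16_nvfp4 (mirrors the export-time detection above):
#   * input quantizer disabled      -> weight-only, BF16 activations
#   * weight scale_bits == (4, 3)   -> FP8 per-block scales, i.e. NVFP4
print(layer.input_quantizer.is_enabled)    # expected: False
print(layer.weight_quantizer.block_sizes)  # expected: blocks of 16 on the last dim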
3 changes: 3 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -82,6 +82,7 @@
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
QUANTIZATION_W4A16_NVFP4,
)
from .model_utils import get_language_model_from_vl, is_multimodal_model
from .moe_utils import _export_fused_experts
@@ -517,6 +518,7 @@ def _export_quantized_weight(
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_NVFP4,
QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
@@ -546,6 +548,7 @@
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A16_NVFP4,
]:
# Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
# for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
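An illustrative sketch (plain torch, not the export code) of why the expert-weight transpose mentioned above is needed: NVFP4 groups values along the last dimension, so the input dimension must be moved there before per-block scales are computed. The shapes are arbitrary examples.

import torch

num_experts, input_dim, output_dim, block = 4, 128, 256, 16
w = torch.randn(num_experts, input_dim, output_dim, dtype=torch.bfloat16)

# Move the contraction (input) dimension last, then split it into blocks
# of 16 and take a per-block amax, which is what the per-block FP8 scales
# are derived from.
w_t = w.transpose(1, 2).contiguous()
blocks = w_t.reshape(num_experts, output_dim, input_dim // block, block)
per_block_amax = blocks.abs().amax(dim=-1)
print(per_block_amax.shape)  # torch.Size([4, 256, 8])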
2 changes: 2 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -1683,6 +1683,7 @@ def _nvfp4_selective_quant_cfg(
],
"algorithm": "max",
}
W4A16_NVFP4_CFG = _nvfp4_selective_quant_cfg(["*"], weight_only=True)

MXFP4_MLP_WEIGHT_ONLY_CFG = {
"quant_cfg": [
@@ -1739,6 +1740,7 @@
"NVFP4_FP8_MHA_CONFIG",
"NVFP4_KV_CFG",
"NVFP4_KV_ROTATE_CFG",
"W4A16_NVFP4_CFG",
"W4A8_NVFP4_FP8_CFG",
"NVFP4_SVDQUANT_DEFAULT_CFG",
"W4A8_AWQ_BETA_CFG",
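The one-liner above defines the weight-only config by passing the match-all pattern to the same helper used for the other selective NVFP4 configs. As an assumption based only on that call (not an API documented in this PR), the helper could presumably scope weight-only FP4 to a subset of layers as well:

from modelopt.torch.quantization.config import _nvfp4_selective_quant_cfg

# Hypothetical narrower variant (assumed from the call pattern above):
# weight-only FP4 applied to MLP projections only.
W4A16_NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*"], weight_only=True)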
@@ -47,6 +47,7 @@
("w4a8_awq", "tiny_llama-w4a8-awq", True, False, True, True, False),
("int8_wo", "tiny_llama-int8-wo", False, False, False, False, False),
("nvfp4_svdquant", "tiny_llama-nvfp4-svdquant", True, False, True, True, True),
("w4a16_nvfp4", "tiny_llama-w4a16-nvfp4", False, False, False, False, False),
# MoE models (fused experts: Qwen3 MoE, GPT-OSS)
("nvfp4", "tiny_qwen3_moe-nvfp4", True, False, True, True, False),
("fp8", "tiny_gpt_oss-fp8", True, False, True, True, False),
3 changes: 3 additions & 0 deletions tools/launcher/core.py
@@ -50,6 +50,9 @@ def get_default_env(experiment_title=None):
"HF_HOME": f"/{title}/hf-cache",
"HF_TOKEN": os.getenv("HF_TOKEN", ""),
"MLM_SKIP_INSTALL": "1",
# DockerExecutor runs as the host UID, which may not be in the container's
# /etc/passwd; setting USER prevents getpass.getuser() from calling pwd.getpwuid().
"USER": os.getenv("USER", "docker"),
}
return slurm_env, local_env

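For context on the USER default added above: getpass.getuser() consults the environment before touching the password database, so exporting USER avoids the KeyError that pwd.getpwuid() raises for a UID with no /etc/passwd entry. A small standalone sketch:

import getpass
import os

# getpass.getuser() checks LOGNAME, USER, LNAME and USERNAME (in that order)
# and only falls back to pwd.getpwuid(os.getuid()) when none are set. Inside
# the container the host UID has no /etc/passwd entry, so the pwd lookup
# would raise KeyError; setting USER short-circuits it.
os.environ.setdefault("USER", "docker")
print(getpass.getuser())  # "docker" unless LOGNAME etc. were already set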