diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index ad200093904..05be4dbf01e 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -28,6 +28,7 @@ Changelog
**New Features**
+- Add NVFP4 W4A16 weight-only quantization (``w4a16_nvfp4``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.W4A16_NVFP4_CFG`` or ``--qformat w4a16_nvfp4`` in ``hf_ptq.py``. vLLM deployment support is in progress.
- Support the full Transformer Engine spec for Minitron pruning (``mcore_minitron``), so the custom ModelOpt spec is no longer needed. This does not affect how the pruning workflow is used, but it makes pruning slightly faster and may produce a slightly different pruned model due to different kernels and numerics.
- Add an end-to-end tutorial for Minitron pruning + distillation + quantization + evaluation + vLLM deployment for Nemotron-Nano-9B-v2 → pruned 7B, along with data blend preparation steps and an ablation study. See ``examples/pruning/minitron/README.md`` for details.
- Add Puzzletron, a new algorithm for heterogeneous pruning of LLM and VLM models. See ``examples/puzzletron/README.md`` for more details.
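
A minimal usage sketch for the new config (weight-only, so no calibration loop is passed to ``mtq.quantize``; the model ID below is only an example):

```python
import torch
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM

# W4A16 NVFP4 is weight-only, so no forward_loop / calibration data is needed.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16
)
model = mtq.quantize(model, mtq.W4A16_NVFP4_CFG)
```
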
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 758ed75aeed..44eb05c9358 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -113,6 +113,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
"fp8_pc_pt": mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
"w4a8_nvfp4_fp8": mtq.W4A8_NVFP4_FP8_CFG,
+ "w4a16_nvfp4": mtq.W4A16_NVFP4_CFG,
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_experts_only": mtq.NVFP4_EXPERTS_ONLY_CFG,
@@ -785,6 +786,12 @@ def export_quantized(
extra_state_dict=mtp_state_dict,
)
+ if args.qformat == "w4a16_nvfp4":
+ warnings.warn(
+ "TensorRT-LLM and SGLang do not support this format. "
+ "vLLM deployment support is in progress."
+ )
+
# Restore default padding and export the tokenizer as well.
if tokenizer is not None:
tokenizer.padding_side = default_padding_side
@@ -1128,7 +1135,7 @@ def quantize_main(
quant_cfg = copy.deepcopy(quant_cfg)
force_weight_quantizers_static(quant_cfg["quant_cfg"])
- if args.qformat in QUANT_CFG_CHOICES:
+ if args.recipe is not None or args.qformat in QUANT_CFG_CHOICES:
mono_quantize(
args,
quant_cfg,
diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh
index 6ca99c7f963..f0730f66434 100755
--- a/examples/llm_ptq/scripts/huggingface_example.sh
+++ b/examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,15 +53,20 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
- fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian) ;;
+ fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian | w4a16_nvfp4) ;;
*)
- echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian]" >&2
+ echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian, w4a16_nvfp4]" >&2
exit 1
;;
esac
done
IFS=" "
+if [ -n "$RECIPE" ] && [ -n "$QFORMAT" ]; then
+ echo "Error: --recipe and --quant are mutually exclusive." >&2
+ exit 1
+fi
+
script_dir="$(dirname "$(readlink -f "$0")")"
pushd $script_dir/..
@@ -72,7 +77,12 @@ fi
QFORMAT_MODIFIED="${QFORMAT//,/_}"
-MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
+if [ -n "$RECIPE" ]; then
+ RECIPE_LABEL=$(basename "$RECIPE" .yaml | sed 's/[^0-9a-zA-Z\-]/_/g')
+ MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${RECIPE_LABEL}
+else
+ MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT_MODIFIED}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}}
+fi
SAVE_PATH=${ROOT_SAVE_PATH}/saved_models_${MODEL_NAME}
@@ -127,6 +137,10 @@ if $TRUST_REMOTE_CODE; then
PTQ_ARGS+=" --trust_remote_code "
fi
+if [ -n "${EXCLUDE_MODULES:-}" ]; then
+ PTQ_ARGS+=" --exclude_modules ${EXCLUDE_MODULES} "
+fi
+
if $USE_SEQ_DEVICE_MAP; then
PTQ_ARGS+=" --use_seq_device_map "
fi
@@ -177,11 +191,16 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
if [[ "$MODEL_CONFIG_EXIST" == false ]]; then
echo "Quantizing original model..."
+ if [ -n "$RECIPE" ]; then
+ QUANT_ARG="--recipe=$RECIPE"
+ else
+ QUANT_ARG="--qformat=${QFORMAT// /,}"
+ fi
python hf_ptq.py \
--pyt_ckpt_path=$MODEL_PATH \
--export_path=$SAVE_PATH \
--sparsity_fmt=$SPARSITY_FMT \
- --qformat="${QFORMAT// /,}" \
+ $QUANT_ARG \
--calib_size=$CALIB_SIZE \
--batch_size=$CALIB_BATCH_SIZE \
--inference_tensor_parallel=$TP \
@@ -203,6 +222,12 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
exit 0
fi
+ if [ "$QFORMAT" = "w4a16_nvfp4" ]; then
+ echo "w4a16_nvfp4 checkpoint exported to $SAVE_PATH"
+ echo "To serve on vLLM, convert to compressed-tensors"
+ exit 0
+ fi
+
if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then
cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1)
diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh
index 3817c1dee7c..09896bef7fe 100644
--- a/examples/llm_ptq/scripts/parser.sh
+++ b/examples/llm_ptq/scripts/parser.sh
@@ -99,8 +99,8 @@ parse_options() {
fi
# Verify required options are provided
- if [ -z "$MODEL_PATH" ] || [ -z "$QFORMAT" ] || [ -z "$TASKS" ]; then
- echo "Usage: $0 --model= --quant= --tasks="
+ if [ -z "$MODEL_PATH" ] || [ -z "$TASKS" ] || { [ -z "$QFORMAT" ] && [ -z "$RECIPE" ]; }; then
+ echo "Usage: $0 --model= (--quant= | RECIPE=) --tasks="
echo "Optional args: --sparsity= --awq_block_size= --calib="
exit 1
fi
diff --git a/modelopt/torch/export/convert_hf_config.py b/modelopt/torch/export/convert_hf_config.py
index 5f8c3f3b55c..06e5923a30f 100644
--- a/modelopt/torch/export/convert_hf_config.py
+++ b/modelopt/torch/export/convert_hf_config.py
@@ -57,6 +57,11 @@ def _quant_algo_to_group_config(quant_algo: str, group_size: int | None = None)
return {
"weights": {"dynamic": False, "num_bits": 4, "type": "int", "group_size": gs},
}
+ elif quant_algo == "W4A16_NVFP4":
+ gs = group_size or 16
+ return {
+ "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs},
+ }
elif quant_algo in ("NVFP4_AWQ", "W4A8_AWQ"):
gs = group_size or 128
return {
@@ -183,6 +188,14 @@ def convert_hf_quant_config_format(input_config: dict[str, Any]) -> dict[str, An
"targets": ["Linear"],
}
new_config["config_groups"] = {"group_0": config_group_details}
+ elif quant_algo_value == "W4A16_NVFP4":
+ # Weight-only FP4
+ group_size = original_quantization_details.get("group_size", 16)
+ config_group_details = {
+ "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": group_size},
+ "targets": ["Linear"],
+ }
+ new_config["config_groups"] = {"group_0": config_group_details}
elif quant_algo_value == "MIXED_PRECISION":
quantized_layers = original_quantization_details.get("quantized_layers", {})
diff --git a/modelopt/torch/export/model_config.py b/modelopt/torch/export/model_config.py
index dce39767c76..308a18daee0 100755
--- a/modelopt/torch/export/model_config.py
+++ b/modelopt/torch/export/model_config.py
@@ -38,6 +38,7 @@
QUANTIZATION_MXFP4 = "mxfp4"
QUANTIZATION_MXFP8 = "mxfp8"
QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8"
+QUANTIZATION_W4A16_NVFP4 = "w4a16_nvfp4" # weight-only FP4
QUANTIZATION_NVFP4_AWQ = "nvfp4_awq"
QUANTIZATION_FP8_PB_REAL = "fp8_pb_real"
QUANTIZATION_FP8_PB_WO = "fp8_pb_wo"
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 76f304a478a..95d9e288a27 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -69,6 +69,7 @@
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_MXFP4_FP8,
QUANTIZATION_W4A8_NVFP4_FP8,
+ QUANTIZATION_W4A16_NVFP4,
)
logger = logging.getLogger(__name__)
@@ -359,6 +360,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
+ QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
# Calibrate weight quantizer if amax is not set
@@ -403,6 +405,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
+ QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
# Calibrate weight quantizer if amax is not set
@@ -641,6 +644,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
return QUANTIZATION_NVFP4_AWQ
if getattr(layer, "fused_with_prequant", False):
return QUANTIZATION_NVFP4_AWQ
+ # W4A16 weight-only: input_quantizer absent or disabled
+ if input_quantizer is None or not input_quantizer.is_enabled:
+ if scale_bits == (4, 3):
+ return QUANTIZATION_W4A16_NVFP4
assert input_quantizer is not None, (
f"input_quantizer is None for {quantizer_attr_names}"
)
@@ -808,6 +815,11 @@ def process_layer_quant_config(layer_config_dict):
"quant_algo": "NVFP4",
"group_size": block_size_value,
}
+ elif v == "w4a16_nvfp4":
+ layer_config = {
+ "quant_algo": "W4A16_NVFP4",
+ "group_size": block_size_value,
+ }
elif v == "nvfp4_awq":
layer_config = {
"quant_algo": "NVFP4_AWQ",
@@ -985,6 +997,7 @@ def to_quantized_weight(
if quantization in [
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
+ QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_NVFP4_FP8,
QUANTIZATION_NVFP4_SVDQUANT,
]:
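
For readers new to NVFP4, a back-of-the-envelope sketch of the two-level scaling that ``get_weight_scaling_factor`` / ``get_weight_scaling_factor_2`` feed into (a simplified illustration only; the real rounding and packing are handled by the export utilities above):

```python
import torch

w = torch.randn(64)                              # one weight row, group_size = 16 -> 4 blocks
s2 = w.abs().amax() / (6.0 * 448.0)              # global scale: FP4 (E2M1) max 6, FP8 (E4M3) max 448
blocks = w.view(-1, 16)
s_blk = blocks.abs().amax(dim=-1) / 6.0 / s2     # per-block scale, stored as FP8-E4M3 at export
q = (blocks / (s_blk * s2).unsqueeze(-1)).clamp(-6.0, 6.0)  # values then snap to the E2M1 grid
```
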
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 7585d4b108a..2068c24438d 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -82,6 +82,7 @@
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
+ QUANTIZATION_W4A16_NVFP4,
)
from .model_utils import get_language_model_from_vl, is_multimodal_model
from .moe_utils import _export_fused_experts
@@ -517,6 +518,7 @@ def _export_quantized_weight(
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
QUANTIZATION_NVFP4,
+ QUANTIZATION_W4A16_NVFP4,
QUANTIZATION_W4A8_AWQ,
QUANTIZATION_W4A8_NVFP4_FP8,
]:
@@ -546,6 +548,7 @@ def _export_quantized_weight(
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
QUANTIZATION_NVFP4_SVDQUANT,
+ QUANTIZATION_W4A16_NVFP4,
]:
# Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
# for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 3adb70cf6b7..62c7cda50b3 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -1683,6 +1683,7 @@ def _nvfp4_selective_quant_cfg(
],
"algorithm": "max",
}
+W4A16_NVFP4_CFG = _nvfp4_selective_quant_cfg(["*"], weight_only=True)
MXFP4_MLP_WEIGHT_ONLY_CFG = {
"quant_cfg": [
@@ -1739,6 +1740,7 @@ def _nvfp4_selective_quant_cfg(
"NVFP4_FP8_MHA_CONFIG",
"NVFP4_KV_CFG",
"NVFP4_KV_ROTATE_CFG",
+ "W4A16_NVFP4_CFG",
"W4A8_NVFP4_FP8_CFG",
"NVFP4_SVDQUANT_DEFAULT_CFG",
"W4A8_AWQ_BETA_CFG",
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index 77f26b20602..9e85fc2d6e4 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -900,6 +900,26 @@ def forward(self, *args, **kwargs):
self._down_proj_linear = False
return super().forward(*args, **kwargs)
+ def iter_weights_for_calibration(self):
+ """Yield ``(weight_slice, quantizer)`` pairs for each expert and weight type.
+
+ The base implementation resolves singular ``*_weight_quantizer`` names via
+ ``quantizer_attr_names``, but fused experts store per-expert quantizers as
+ ``nn.ModuleList`` attributes (``gate_up_proj_weight_quantizers``,
+ ``down_proj_weight_quantizers``). Override to yield the per-expert slice
+ and its corresponding quantizer directly.
+ """
+ for weight_name, quantizers_name in (
+ ("gate_up_proj", "gate_up_proj_weight_quantizers"),
+ ("down_proj", "down_proj_weight_quantizers"),
+ ):
+ weight = getattr(self, weight_name, None)
+ quantizers = getattr(self, quantizers_name, None)
+ if weight is None or quantizers is None:
+ continue
+ for idx, q in enumerate(quantizers):
+ yield weight[idx], q
+
def fold_weight(self, keep_attrs: bool = False):
"""Fold per-expert weight quantizers into the fused 3-D weights.
diff --git a/modelopt_recipes/configs/ptq/units/w4a16_nvfp4.yaml b/modelopt_recipes/configs/ptq/units/w4a16_nvfp4.yaml
new file mode 100644
index 00000000000..b4676dbff34
--- /dev/null
+++ b/modelopt_recipes/configs/ptq/units/w4a16_nvfp4.yaml
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# W4A16 NVFP4: NVFP4 E2M1 dynamic weight quantizer only; activations remain in BF16.
+
+# modelopt-schema: modelopt.torch.quantization.config.QuantizerCfgListConfig
+imports:
+ nvfp4: configs/numerics/nvfp4
+---
+ - quantizer_name: '*weight_quantizer'
+ cfg:
+ $import: nvfp4
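
For readers who think in ModelOpt's classic dict configs, the unit above is roughly the following; the exact keys live in ``configs/numerics/nvfp4``, so the dict is an approximation based on ModelOpt's usual NVFP4 defaults, not a copy of that file:

```python
# Approximate dict-config equivalent of the w4a16_nvfp4 YAML unit (assumption:
# configs/numerics/nvfp4 matches ModelOpt's standard NVFP4 numerics).
W4A16_NVFP4_LIKE_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "num_bits": (2, 1),                                                 # FP4 E2M1
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},   # 16-wide blocks, FP8 scales
            "enable": True,
        },
        "*input_quantizer": {"enable": False},                                  # activations stay in BF16
    },
    "algorithm": "max",
}
```
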
diff --git a/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp16.yaml b/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp16.yaml
new file mode 100644
index 00000000000..03ee1b2236e
--- /dev/null
+++ b/modelopt_recipes/general/ptq/nvfp4_weight_only-kv_fp16.yaml
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+imports:
+ base_disable_all: configs/ptq/units/base_disable_all
+ default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers
+ w4a16_nvfp4: configs/ptq/units/w4a16_nvfp4
+
+metadata:
+ recipe_type: ptq
+ description: NVFP4 W4A16 weight-only, BF16 activations, max calibration. No calibration forward pass required.
+quantize:
+ algorithm: max
+ quant_cfg:
+ - $import: base_disable_all
+ - $import: w4a16_nvfp4
+ - $import: default_disabled_quantizers
diff --git a/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py b/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py
index 8bdf3f5e659..6e0c56bfd1d 100644
--- a/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py
+++ b/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py
@@ -47,6 +47,7 @@
("w4a8_awq", "tiny_llama-w4a8-awq", True, False, True, True, False),
("int8_wo", "tiny_llama-int8-wo", False, False, False, False, False),
("nvfp4_svdquant", "tiny_llama-nvfp4-svdquant", True, False, True, True, True),
+ ("w4a16_nvfp4", "tiny_llama-w4a16-nvfp4", False, False, False, False, False),
# MoE models (fused experts: Qwen3 MoE, GPT-OSS)
("nvfp4", "tiny_qwen3_moe-nvfp4", True, False, True, True, False),
("fp8", "tiny_gpt_oss-fp8", True, False, True, True, False),
diff --git a/tools/launcher/core.py b/tools/launcher/core.py
index 8fd4e25ee79..bcade6e7508 100644
--- a/tools/launcher/core.py
+++ b/tools/launcher/core.py
@@ -50,6 +50,9 @@ def get_default_env(experiment_title=None):
"HF_HOME": f"/{title}/hf-cache",
"HF_TOKEN": os.getenv("HF_TOKEN", ""),
"MLM_SKIP_INSTALL": "1",
+ # DockerExecutor runs as the host UID, which may not be in the container's
+ # /etc/passwd; setting USER prevents getpass.getuser() from calling pwd.getpwuid().
+ "USER": os.getenv("USER", "docker"),
}
return slurm_env, local_env
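
The ``USER`` workaround is easy to reproduce outside the launcher; ``getpass.getuser()`` only falls back to a ``pwd`` lookup when none of its environment variables are set:

```python
import getpass
import os

# getpass.getuser() checks LOGNAME, USER, LNAME and USERNAME before falling
# back to pwd.getpwuid(os.getuid()), which raises KeyError when the container's
# /etc/passwd has no entry for the host UID. Pre-setting USER avoids that path.
os.environ.setdefault("USER", "docker")
print(getpass.getuser())
```
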