This repo explores how modern hardware could change the landscape of Quantization-Aware Training (QAT).
For years, QAT has primarily served to map model inference and deployment onto high-throughput INT8 hardware, which is widely available across CPUs and edge devices. It is effective: by simulating quantization effects on full-precision accelerators (i.e., fake quantization), models can be fine-tuned to recover the accuracy lost to the rounding/truncation of lower precision.
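As a refresher, below is a minimal sketch of fake quantization in plain PyTorch (illustrative only, not tied to any particular QAT toolkit): tensors are scaled to the INT8 grid, rounded, clamped, and immediately dequantized, so the GEMM still runs in full precision but the model sees the rounding error; a straight-through estimator (STE) lets gradients pass through the non-differentiable rounding.

```python
import torch

def fake_quant_int8(x: torch.Tensor) -> torch.Tensor:
    """Simulate symmetric per-tensor INT8 quantization of a full-precision tensor."""
    scale = x.abs().max().clamp(min=1e-8) / 127.0        # per-tensor symmetric scale
    q = torch.clamp(torch.round(x / scale), -127, 127)    # quantize: scale, round, saturate
    x_fq = q * scale                                       # dequantize back to full precision
    # Straight-through estimator: forward sees the quantized values,
    # backward treats the quantizer as identity.
    return x + (x_fq - x).detach()

w = torch.randn(256, 256, requires_grad=True)
a = torch.randn(32, 256)
out = a @ fake_quant_int8(w).t()   # the GEMM itself still runs in fp32/bf16
out.sum().backward()               # gradients reach w through the STE
```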
That approach made sense back when accelerators topped out at FP16/BF16, but deep learning hardware has moved on. Starting with NVIDIA Hopper and Habana Gaudi2, native FP8 training became viable. Following further hardware precision innovation and the OCP standardization of the Microscaling (MX) formats, NVIDIA's B200/B300 and AMD's MI355X ship in 2025 with granular MXFP8/6/4 support.
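For readers unfamiliar with the MX formats: each contiguous block of 32 elements shares one power-of-two scale, and the element payloads are stored in a narrow FP type (E4M3 for MXFP8). Here is a schematic sketch of the idea in PyTorch; it deliberately ignores the spec's exact E8M0 scale encoding and rounding rules.

```python
import torch

BLOCK = 32  # MX formats share one scale per 32 contiguous elements

def mxfp8_like_quantize(x: torch.Tensor):
    """Schematic MXFP8-style quantization: a power-of-two scale per 32-element
    block, element payloads cast to FP8 E4M3. Not the exact OCP encoding."""
    x = x.reshape(-1, BLOCK)
    amax = x.abs().amax(dim=1, keepdim=True).clamp(min=2**-126)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8)   # 8 ~ floor(log2(448)), E4M3 max
    q = (x / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return q, scale

q, scale = mxfp8_like_quantize(torch.randn(4096))
x_hat = q.float() * scale   # dequantized view, scaled block-wise
```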
Naturally, this calls for a fresh look at QAT. IMHO, QAT will become less relevant going forward; here are my two cents:
- The overhead of QAT. Without an optimized implementation (e.g., fused quantizer kernels), quantizers add extra ops to the training graph; QAT variants that also learn quantization parameters cost even more memory and training time (see the sketch after this list).
- Today's models are simply too large for older accelerators with limited memory; modern hardware is required anyway. Why fake-quantize and run 16/32-bit GEMMs when we can execute FP8/FP4 GEMMs that are, in theory, 2-4× faster by design? Sure, there will always be niches: small models fine-tuned on A100s, or compact discriminative models for on-device deployment. But for large LMs, the momentum is clearly toward modern GPUs.
- New models are "born" low-precision, e.g. DeepSeek-V3 is trained in FP8. With today's hardware, new foundation models will soon be inherently low-precision from day one. Post-hoc quantization is only relevant when a lower-than-trained precision is required. Besides, we wouldn't want to upcast an FP8-trained model to BF16/FP16 just to simulate the FP4 quantization effect, would we? 🤣
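To make the overhead point (first bullet) concrete, here is a sketch of an LSQ-style fake quantizer with a learnable step size, the kind of QAT variant mentioned above. The class name and constants are illustrative, not taken from any specific library: every quantized tensor adds a trainable scale (plus its optimizer state) and several elementwise ops, while the surrounding GEMM still executes in bf16/fp32.

```python
import torch
import torch.nn as nn

def _round_ste(x: torch.Tensor) -> torch.Tensor:
    # Straight-through estimator: round in forward, identity gradient in backward.
    return x + (torch.round(x) - x).detach()

class LearnedScaleFakeQuant(nn.Module):
    """LSQ-style fake quantizer with a trainable step size (illustrative sketch).

    Each instance adds a parameter (plus optimizer state), several elementwise
    ops per forward, and step-size gradient math per backward -- while the
    surrounding GEMM still runs in bf16/fp32.
    """
    def __init__(self, num_bits: int = 8):
        super().__init__()
        self.qmax = 2 ** (num_bits - 1) - 1
        self.step = nn.Parameter(torch.tensor(0.01))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = torch.clamp(_round_ste(x / self.step), -self.qmax, self.qmax)
        return q * self.step

# One quantizer per weight and per activation of every linear layer adds up quickly.
layer = nn.Linear(4096, 4096)
wq, aq = LearnedScaleFakeQuant(), LearnedScaleFakeQuant()
x = torch.randn(8, 4096)
y = torch.nn.functional.linear(aq(x), wq(layer.weight), layer.bias)
```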
To substantiate this, we perform a quick study:
- Baseline: the nvfp4-qat example from TensorRT-Model-Optimizer. It fine-tunes (SFT) Qwen3-8B on the nvidia/OpenScience dataset with a QAT flow that targets NVFP4 inference; the backward pass stays in BF16.
- SFT in native BF16, to observe the overhead of QAT.
- SFT in per-tensor FP8.
- SFT in MXFP8. Experiments 3 and 4 naturally use their corresponding precision for inference.
- Finally, a hybrid NVFP4-forward + MXFP8-backward training run with our proof-of-concept (PoC) Transformer Engine (TE) recipe.
All experiments run on an identical setup: 8× Blackwell B200, global batch size 512, and identical learning rate and schedule. Experiments 2-5 are carried out with our custom fine-tuning script developed in NeMo, which integrates our PoC TE build.
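For experiments 2-5 the training loop is identical and only the TE recipe changes. Below is a minimal sketch of how a recipe is applied to a TE layer; the class names follow recent TE releases and may differ by version, and the hybrid nv4f_mx8b recipe exists only in our PoC TE build, not stock TE.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format, MXFP8BlockScaling

# Pick the recipe per experiment; exp. 5's nv4f_mx8b recipe is only in our PoC TE build.
recipe = MXFP8BlockScaling()                           # exp. 4 (MXFP8)
# recipe = DelayedScaling(fp8_format=Format.HYBRID)    # a per-tensor FP8 variant

layer = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16)
x = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)          # forward GEMM runs in the recipe's precision
y.sum().backward()        # backward GEMMs follow the recipe as well
```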
| # | Backend | Recipe | Inference | Train Step Time (s) | Val Loss |
|---|---|---|---|---|---|
| 1 | TensorRT Model Optimizer | nvfp4-qat | nvfp4 | 28.28 | 0.60351 |
| 2 | Transformer Engine | bf16 | bf16 | 14.00 | 0.57617* |
| 3 | Transformer Engine | fp8 | fp8 | 10.97* | 0.58386 |
| 4 | Transformer Engine | mxfp8 | mxfp8 | 11.93 | 0.57907 |
| 5 | Transformer Engine (ours) | nv4f-mx8b | nvfp4 | 13.18 | 0.60783 |
See run logs at wandb.
On speed (train step time):
- 1 vs 2: QAT incurs roughly 2× overhead compared to native BF16.
- 3-5: Per-tensor FP8 is the fastest (expected); more granular MXFP8 quantization is costlier.
- Note that in theory ours (nv4f-mx8b) should be faster than pure MXFP8, since 1 of the 3 GEMMs in each linear layer runs in NVFP4 instead of MXFP8, yet it measures slower. The reason is that each quantization point currently carries both an NVFP4 and an MXFP8 quantizer, causing two quantizer kernel launches instead of one (sketched below); this can be optimized. Still, it is slightly faster than BF16 and ~2× faster than QAT, which is the key takeaway.
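For clarity, here is a conceptual, pure-PyTorch sketch of what the nv4f-mx8b recipe does per linear layer and why the double quantization shows up. `quantize_nvfp4`/`quantize_mxfp8` are stand-ins for the real TE kernels, not actual APIs.

```python
import torch

# Stand-ins for the real quantizer kernels (illustrative only; return fp tensors here).
def quantize_nvfp4(t): return t   # would produce an NVFP4 copy + block scales
def quantize_mxfp8(t): return t   # would produce an MXFP8 copy + block scales

class HybridLinearFn(torch.autograd.Function):
    """One linear layer under the nv4f_mx8b recipe:
       forward GEMM (1 of 3) in NVFP4, dgrad + wgrad GEMMs (2 of 3) in MXFP8."""

    @staticmethod
    def forward(ctx, x, w):
        x_fp4, w_fp4 = quantize_nvfp4(x), quantize_nvfp4(w)   # forward-path copies
        x_fp8, w_fp8 = quantize_mxfp8(x), quantize_mxfp8(w)   # backward-path copies
        # ^ two quantizer launches per tensor today -- the overhead noted above.
        ctx.save_for_backward(x_fp8, w_fp8)
        return x_fp4 @ w_fp4.t()                               # NVFP4 GEMM

    @staticmethod
    def backward(ctx, grad_out):
        x_fp8, w_fp8 = ctx.saved_tensors
        g_fp8 = quantize_mxfp8(grad_out)
        dgrad = g_fp8 @ w_fp8                                  # MXFP8 GEMM
        wgrad = g_fp8.t() @ x_fp8                              # MXFP8 GEMM
        return dgrad, wgrad

x = torch.randn(32, 4096, requires_grad=True)
w = torch.randn(4096, 4096, requires_grad=True)
HybridLinearFn.apply(x, w).sum().backward()
```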
On quality (val loss):
- 2: As expected, native BF16 gives the best quality, though BF16 deployment is not our target.
- 4 vs 3: MXFP8 quantizes per group of 32 elements, so it is less exposed to drastic dynamic range and yields better quality than per-tensor FP8. Good to know, but not aligned with our target of NVFP4 inference.
- 5 vs 1: Our hybrid NVFP4-forward + MXFP8-backward approach is slightly worse than QAT but within ~1% val-loss, while being around 2× faster.
In short, on modern hardware with native low-precision support, QAT becomes less attractive: its overhead is substantial, and its marginal quality advantage can likely be closed by better low-precision training algorithms.
### Baseline nvfp4-qat with TensorRT-Model-Optimizer
#### On Host
```bash
export HOSTWD=~/work/dev/
mkdir -p ${HOSTWD}/trtmo/ && cd $_

git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git
(cd TensorRT-Model-Optimizer && git checkout 615f3c01)

git clone https://github.com/NVIDIA-NeMo/NeMo.git
(cd NeMo && git checkout 676ed1a)

docker run -d --gpus all -it --rm --network=host --ipc=host \
    --shm-size=16g --ulimit memlock=-1 --ulimit stack=67108864 \
    -v ${HOSTWD}/trtmo/NeMo:/opt/NeMo \
    -v ${HOSTWD}/trtmo/TensorRT-Model-Optimizer/modelopt:/usr/local/lib/python3.12/dist-packages/modelopt \
    -v ${HOSTWD}/trtmo/TensorRT-Model-Optimizer:/workspace/TensorRT-Model-Optimizer \
    -v /root/work:/root/work \
    nvcr.io/nvidia/nemo:25.07
```
#### Inside Container
```bash
# need to login to download models from HF
huggingface-cli login
cd /workspace/TensorRT-Model-Optimizer/examples/nemo_run/
python qat/nemo_qat_flow.py --log-dir /root/work/run --experiment trt_qat
```
### For experiments 2-5
#### On Host (Dockerfile)
```bash
docker run -d --gpus all -it --rm --network=host --ipc=host \
    --shm-size=16g --ulimit memlock=-1 --ulimit stack=67108864 \
    -v /root/work:/root/work \
    vuiseng9/faster-qat:nv4f_mx8b
```
#### Inside Container
```bash
source /opt/venv/bin/activate
huggingface-cli login
wandb login
export HF_TOKEN=your_hf_token   # for first time download

export NEMO_HOME=/root/work/nemo                # where nemo models are stored
export NEMORUN_HOME=/root/work/run/faster-qat   # run outputs
export openscience_ds=/root/work/run/trt_qat_experiment/openscience_proc  # preprocessed OpenScience dataset, piggybacking the trtmo qat run
export gpu_type=b200                            # we assume 8x
export wandb_project=faster-qat

# bf16 SFT (qwen3-8b)
python -m scripts.nv4f_mx8b.finetune_qwen3_8b \
    -g $gpu_type -gn 8 -ng 8 \
    -f sft \
    -wdp $wandb_project \
    -dp $openscience_ds \
    -mb 1 -gb 512 -tp 1 -pp 1 -cp 1 -vp 1 -ep 1 -et 1 \
    --compute_dtype bf16

# For the other recipes, rerun the same command, replacing the last flag with:
# fp8 current scaling SFT (qwen3-8b)
#   --compute_dtype fp8 --fp8_recipe ds
# mxfp8 SFT (qwen3-8b)
#   --compute_dtype fp8 --fp8_recipe mxfp8
# nvfp4 forward, mxfp8 backward SFT (qwen3-8b)
#   --compute_dtype fp8 --fp8_recipe nv4f_mx8b
```