vuiseng9/faster-qat
Faster than FP4-QAT? Use actual NVFP4 forward + MXFP8 backward.

This repo explores how modern hardware could change the landscape of quantization-aware training (QAT).

For years, QAT primarily served to map model inference and deployment onto high-throughput INT8 hardware, which is widely available across CPUs and edge devices. It is effective: by simulating quantization effects on full-precision accelerators (i.e., fake quantization), models can be fine-tuned to recover the accuracy lost to the rounding/truncation of lower precision.
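The fake-quantization step above can be sketched in a few lines. Below is a minimal NumPy illustration of symmetric per-tensor INT8 quantize-dequantize (our simplification — real QAT also applies a straight-through estimator so gradients pass through the rounding):

```python
import numpy as np

def fake_quant_int8(x: np.ndarray) -> np.ndarray:
    """Quantize-dequantize: simulate INT8 rounding while staying in fp32."""
    scale = np.abs(x).max() / 127.0           # symmetric per-tensor scale
    q = np.clip(np.round(x / scale), -128, 127)
    return (q * scale).astype(x.dtype)        # carries the rounding error forward

x = np.array([0.503, -1.0, 1.27], dtype=np.float32)
print(fake_quant_int8(x))  # values snap to multiples of scale ~= 0.01
```

The GEMM that consumes this output still runs at full precision; only the rounding error is simulated, which is exactly why QAT adds cost without adding speed.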

That approach made sense back when accelerators topped out at FP16/BF16. But deep learning hardware has moved on. Starting with NVIDIA Hopper and Habana Gaudi2, native FP8 training became viable. Following further hardware precision innovation and the OCP standardization of the Microscaling (MX) formats, NVIDIA's B200/B300 and AMD's MI355X ship in 2025 with granular MXFP8/6/4 support.
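The key idea of the MX formats is a shared power-of-two scale per small block of elements (32 for MXFP8), rather than one scale per tensor. A rough NumPy sketch of the scale selection (our illustration, not the exact OCP algorithm; 448 is the FP8 E4M3 maximum, and real MXFP8 stores each scale as an E8M0 exponent byte):

```python
import numpy as np

E4M3_MAX = 448.0  # largest representable FP8 E4M3 magnitude

def mx_block_scales(x: np.ndarray, block: int = 32) -> np.ndarray:
    """One power-of-two scale per `block` contiguous elements."""
    amax = np.abs(x.reshape(-1, block)).max(axis=1)
    # pick 2^e so that amax / 2^e fits within the E4M3 range
    return 2.0 ** np.ceil(np.log2(amax / E4M3_MAX))
```

Because each group of 32 elements gets its own scale, one outlier only degrades its own block, not the whole tensor — the intuition behind MXFP8's quality edge over per-tensor FP8 later in the results.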

Naturally, this invites a fresh look at QAT. IMHO, QAT will become less relevant going forward; here are my two cents:

  1. The overhead of QAT. Without optimized implementations such as fused quantizer kernels, quantizers add extra ops to the training graph; QAT variants that also learn quantization parameters cost even more memory and training time.

  2. Today's models are simply too large for older accelerators with limited memory, so modern hardware is required anyway. Why fake-quantize and run 16/32-bit GEMMs when we can execute FP8/FP4 GEMMs that are, by design, 2-4× faster in theory? Sure, there will always be niches: small models fine-tuned on A100s, or compact discriminative models for on-device deployment. But for large LMs, the momentum is clearly toward modern GPUs.

  3. New models are "born" low-precision, e.g., DeepSeek-V3 is trained in FP8. With today's hardware, new foundation models will soon be inherently low-precision from day one. Post-hoc quantization is only relevant when lower-than-trained precision is required. Besides, we wouldn't want to upcast an FP8-trained model to BF16/FP16 just to simulate FP4 quantization effects, would we? 🤣
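To make point 1 concrete, here is a minimal sketch (ours, LSQ-style) of a fake quantizer with a learnable scale. Every quantized tensor gains extra elementwise ops on the training graph plus an additional trainable parameter — work that native low-precision GEMMs absorb inside the matmul kernel:

```python
import numpy as np

def learnable_fake_quant(x: np.ndarray, scale: float, qmax: int = 127) -> np.ndarray:
    # Four extra elementwise ops per quantized tensor: div, round, clip, mul.
    q = np.clip(np.round(x / scale), -qmax - 1, qmax)
    return q * scale

# A linear layer quantizing both weights and activations runs two such
# quantizers per forward pass; each `scale` is an extra trainable parameter
# whose gradient and optimizer state must also be computed and stored.
w_q = learnable_fake_quant(np.ones((4, 4)), scale=0.5)
```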

To substantiate this, we can run a quick study:

  1. Baseline: the nvfp4-qat example from TensorRT-Model-Optimizer. It fine-tunes (SFT) Qwen3-8B on the Nvidia/OpenScience dataset using a QAT flow that targets NVFP4 inference deployment, with the backward pass in BF16.
  2. SFT in native BF16, to observe the overhead of QAT.
  3. SFT in per-tensor FP8.
  4. SFT in MXFP8. Experiments 3 and 4 naturally use the corresponding precision for inference.
  5. Finally, a hybrid NVFP4-forward + MXFP8-backward training with our PoC Transformer Engine (TE) recipe.

All experiments run on an identical setup: 8× Blackwell B200, batch size 512, and identical learning rate and schedule. Experiments 2-5 are carried out with our custom fine-tuning script developed in NeMo, which integrates our PoC TE build.

| # | Backend | Recipe | Inference | Train Step Time (s) | Val Loss |
|---|---------|--------|-----------|---------------------|----------|
| 1 | TensorRT Model Optimizer | nvfp4-qat | nvfp4 | 28.28 | 0.60351 |
| 2 | Transformer Engine | bf16 | bf16 | 14.00 | 0.57617* |
| 3 | Transformer Engine | fp8 | fp8 | 10.97* | 0.58386 |
| 4 | Transformer Engine | mxfp8 | mxfp8 | 11.93 | 0.57907 |
| 5 | Transformer Engine (ours) | nv4f-mx8b | nvfp4 | 13.18 | 0.60783 |

See run logs at wandb.

On speed (train step time):

  • 1 vs 2: QAT incurs roughly 2× overhead compared to native BF16.
  • 3-5: Per-tensor FP8 is the fastest (expected); more granular MXFP8 quantization is costlier.
  • Note that ours (nv4f-mx8b) should in theory be faster than pure MXFP8, since 1 out of 3 GEMMs in each linear layer runs in NVFP4 instead of MXFP8. The reason it is not: each quantization point currently carries both an NVFP4 and an MXFP8 quantizer, causing two quantizer kernel launches instead of one. This can be optimized. Still, it is slightly faster than BF16 and ~2× faster than QAT, which is the key takeaway.
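For reference, "3 GEMMs in linear layers" refers to the three matmuls each linear layer executes per training step. A shape-only NumPy sketch of which GEMM the hybrid recipe runs in which precision (the precision labels are our description of the nv4f-mx8b recipe, not a TE API):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 16))   # activations [tokens, in_features]
W = rng.standard_normal((4, 16))   # weights     [out_features, in_features]

y  = x @ W.T                       # 1) fprop GEMM -> NVFP4 under nv4f-mx8b
dy = rng.standard_normal(y.shape)  # incoming gradient
dx = dy @ W                        # 2) dgrad GEMM -> MXFP8
dW = dy.T @ x                      # 3) wgrad GEMM -> MXFP8
```

Only the fprop GEMM (1 of 3) moves to NVFP4, which bounds the theoretical gain over pure MXFP8 — hence the modest expected speedup even before the double-quantizer launch issue.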

On quality (val loss):

  • 2: As expected, native BF16 gives the best quality, though BF16 deployment is not our target.
  • 4 vs 3: MXFP8 quantizes per group of 32 elements, which is less prone to drastic dynamic-range swings, hence better quality than per-tensor FP8. Good to know, but not aligned with our target of NVFP4 inference.
  • 5 vs 1: Our hybrid NVFP4-forward + MXFP8-backward approach is slightly worse than QAT but within ~1% val-loss, while being around 2× faster.

In short, on modern hardware with native low-precision support, QAT becomes less attractive: its overhead buys only marginally better quality, a gap that better algorithms can likely close.

How to Reproduce

  1. Baseline nvfp4-qat with TensorRT-Model-Optimizer

    On Host

    export HOSTWD=~/work/dev/
    mkdir ${HOSTWD}/trtmo/ && cd $_
    git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git && git -C TensorRT-Model-Optimizer checkout 615f3c01
    git clone https://github.com/NVIDIA-NeMo/NeMo.git && git -C NeMo checkout 676ed1a
    
    docker run -d --gpus all -it --rm --network=host --ipc=host \
      --shm-size=16g --ulimit memlock=-1 --ulimit stack=67108864 \
      -v ${HOSTWD}/trtmo/NeMo:/opt/NeMo \
      -v ${HOSTWD}/trtmo/TensorRT-Model-Optimizer/modelopt:/usr/local/lib/python3.12/dist-packages/modelopt \
      -v ${HOSTWD}/trtmo/TensorRT-Model-Optimizer:/workspace/TensorRT-Model-Optimizer \
      -v /root/work:/root/work \
      nvcr.io/nvidia/nemo:25.07

    Inside Container

    # need to login to download models from HF
    huggingface-cli login
    cd /workspace/TensorRT-Model-Optimizer/examples/nemo_run/
    python qat/nemo_qat_flow.py  --log-dir /root/work/run --experiment trt_qat
  2. For experiment 2-5

    On Host (Dockerfile)

    docker run -d --gpus all -it --rm --network=host --ipc=host \
      --shm-size=16g --ulimit memlock=-1 --ulimit stack=67108864 \
      -v /root/work:/root/work \
      vuiseng9/faster-qat:nv4f_mx8b

    Inside Container

    source /opt/venv/bin/activate
    huggingface-cli login
    wandb login
    export HF_TOKEN=your_hf_token # for first time download
    
    export NEMO_HOME=/root/work/nemo # where nemo models stored
    export NEMORUN_HOME=/root/work/run/faster-qat # run outputs
    export openscience_ds=/root/work/run/trt_qat_experiment/openscience_proc #preprocessed openscience dataset, piggyback trtmo qat runs
    export gpu_type=b200 # we assume 8x
    export wandb_project=faster-qat
    
    # bf16 SFT (qwen3-8b)
    python -m scripts.nv4f_mx8b.finetune_qwen3_8b \
            -g $gpu_type  -gn 8 -ng 8 \
            -f sft \
            -wdp $wandb_project \
            -dp $openscience_ds  \
            -mb 1 -gb 512 -tp 1 -pp 1 -cp 1 -vp 1 -ep 1 -et 1 \
            --compute_dtype bf16
    
    # fp8 current scaling SFT (qwen3-8b): repeat the command above with
            --compute_dtype fp8 --fp8_recipe ds
    
    # mxfp8 SFT (qwen3-8b): repeat the command above with
            --compute_dtype fp8 --fp8_recipe mxfp8
    
    # nvfp4 forward, mxfp8 backward SFT (qwen3-8b): repeat the command above with
            --compute_dtype fp8 --fp8_recipe nv4f_mx8b

About

Revisiting QAT: QAT vs. native NVFP4/MXFP8 fine-tuning.
