diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py index d451758c8..bcfd9de40 100644 --- a/examples/deepseek/ptq.py +++ b/examples/deepseek/ptq.py @@ -99,7 +99,7 @@ def linear( weight = weight_quantizer(weight) return F.linear(x, weight, bias) elif gemm_impl == "bf16": - weight = weight_dequant(weight, weight.scale) + weight = weight_dequant(weight, weight.scale, dtype=torch.bfloat16) if act_quantizer is not None: x = act_quantizer(x) if weight_quantizer is not None: diff --git a/examples/deepseek/quantize_to_nvfp4.py b/examples/deepseek/quantize_to_nvfp4.py index a18cbbc16..af387fce5 100644 --- a/examples/deepseek/quantize_to_nvfp4.py +++ b/examples/deepseek/quantize_to_nvfp4.py @@ -44,11 +44,11 @@ from typing import Any import torch -from ds_kernel import weight_dequant from safetensors.torch import load_file, save_file from tqdm import tqdm from modelopt.torch.quantization.qtensor import NVFP4QTensor +from modelopt.torch.quantization.triton import weight_dequant def _remap_key(key_dict: dict[str, Any]):