diff --git a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py index 007cc0fddd2..a86170e0727 100644 --- a/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py +++ b/fastdeploy/model_executor/layers/quantization/block_wise_fp8.py @@ -67,7 +67,7 @@ def __init__(self, weight_block_size: list = [-1, -1], is_checkpoint_bf16: bool self.quant_round_type = 1 self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM) self.is_checkpoint_bf16 = is_checkpoint_bf16 - self.deepgemm_scale_ue8m0 = True if get_sm_version() == 100 else False + self.deepgemm_scale_ue8m0 = True if get_sm_version() >= 100 else False def name(self) -> str: return "block_wise_fp8" @@ -125,7 +125,8 @@ def deep_gemm_fp8_gemm_nt( layer_output_size: int, bias: paddle.Tensor = None, ): - if get_sm_version() == 100 and current_platform.is_cuda(): + sm_version = get_sm_version() + if sm_version >= 100 and current_platform.is_cuda(): # disable_ue8m0_cast is default False for SM100 fp8_gemm_nt( (x, x_scale_tensor), diff --git a/fastdeploy/model_executor/layers/quantization/fp8_utils.py b/fastdeploy/model_executor/layers/quantization/fp8_utils.py index 65d30d4004d..a5cd230f601 100644 --- a/fastdeploy/model_executor/layers/quantization/fp8_utils.py +++ b/fastdeploy/model_executor/layers/quantization/fp8_utils.py @@ -65,7 +65,7 @@ def load_deep_gemm(): """ if current_platform.is_cuda(): - if get_sm_version() == 100: + if get_sm_version() >= 100: # SM100 should use PFCC DeepGemm paddle.compat.enable_torch_proxy(scope={"deep_gemm"}) try: @@ -245,7 +245,7 @@ def fused_stack_transpose_quant(expert_weight_list, use_ue8m0=False): # Blackwell (SM100) GPUs require pow2_scale quantization. # Guard with is_cuda() so non-CUDA environments do not call into # paddle.device.cuda.* and cause a crash. - use_pow2_scale = current_platform.is_cuda() and get_sm_version() == 100 + use_pow2_scale = current_platform.is_cuda() and get_sm_version() >= 100 w, scale = paddlefleet_ops.fuse_stack_transpose_fp8_quant( expert_weight_list,