Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def moe_align1_kernel(
TOKEN_BLOCK_SIZE: tl.constexpr,
NUM_STAGE: tl.constexpr,
):

expert_id = tl.program_id(axis=0)

off_n = tl.arange(0, TOKEN_BLOCK_SIZE)
Expand Down Expand Up @@ -308,7 +307,6 @@ def moe_align2_kernel(
BLOCK_M: tl.constexpr,
BLOCK_EXPERT: tl.constexpr,
):

expert_id = tl.program_id(axis=0)
off_expert = tl.arange(0, BLOCK_EXPERT)
expert_to_token_num = tl.load(experts_token_num_ptr + off_expert, mask=off_expert < expert_num, other=0)
Expand Down Expand Up @@ -417,6 +415,7 @@ def grouped_matmul_kernel(
n_block_num, # int
compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
WEIGHT_SCALE_PER_TENSOR: tl.constexpr,
block_size_n: tl.constexpr,
block_size_k: tl.constexpr,
# tile sizes
Expand Down Expand Up @@ -504,10 +503,13 @@ def grouped_matmul_kernel(
a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num)[:, None]

a_scale = tl.load(a_scale_ptrs, mask=token_mask[:, None], other=0.0, eviction_policy="evict_last")
b_scale = tl.load(
weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bn[None, :] * weight_scale_stride1,
eviction_policy="evict_last",
)
if WEIGHT_SCALE_PER_TENSOR:
b_scale = tl.load(weight_scale_ptr + expert_id * weight_scale_stride0, eviction_policy="evict_last")
else:
b_scale = tl.load(
weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bn[None, :] * weight_scale_stride1,
eviction_policy="evict_last",
)
ab_scale = a_scale * b_scale

if NEED_TRANS:
Expand Down Expand Up @@ -702,7 +704,7 @@ def grouped_matmul(
expert_to_token_num is tensor shape [expert_num],
expert_to_token_index is tensor shape [expert_num, token_num * topk_num],
expert_weights is tensor shape [expert_num, out_dim, hidden_dim]
expert_to_weights_scale is tensor shape [expert_num] or
expert_to_weights_scale is tensor shape [expert_num], [expert_num, 1], [expert_num, out_dim] or
[expert_num, out_dim // block_size_n, hidden_dim // block_size_k],
when use_fp8_w8a8 is False, it must be None
out is tensor shape [token_num * topk_num, out_dim]
Expand All @@ -723,6 +725,9 @@ def grouped_matmul(
if expert_to_weights_scale.ndim == 3:
block_size_n = expert_weights.shape[1] // expert_to_weights_scale.shape[1]
block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2]
weight_scale_per_tensor = (
use_fp8_w8a8 and expert_to_weights_scale is not None and expert_to_weights_scale.numel() == expert_num
)

if run_config is None:
if token_inputs.shape[0] <= expert_num:
Expand Down Expand Up @@ -872,6 +877,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
n_block_num=triton.cdiv(n, BLOCK_SIZE_N),
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
WEIGHT_SCALE_PER_TENSOR=weight_scale_per_tensor,
block_size_n=block_size_n,
block_size_k=block_size_k,
BLOCK_SIZE_M=BLOCK_SIZE_M,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

@triton.jit
def grouped_launch(pid, m_block_num, n_block_num, group_m: tl.constexpr):

num_pid_in_group = group_m * n_block_num
group_id = pid // num_pid_in_group
first_pid_m = group_id * group_m
Expand Down Expand Up @@ -45,6 +44,7 @@ def _scaled_mm_per_token(
stride_cn,
USE_TMA: tl.constexpr,
B_IS_TRANS: tl.constexpr,
B_SCALE_IS_TENSOR: tl.constexpr,
NEED_N_MASK: tl.constexpr,
NEED_K_MASK: tl.constexpr,
BLOCK_M: tl.constexpr,
Expand Down Expand Up @@ -77,9 +77,11 @@ def _scaled_mm_per_token(
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

Ascale_ptrs = Ascale + offs_am
Bscale_ptrs = Bscale + offs_bn
a_s = tl.load(Ascale_ptrs)
b_s = tl.load(Bscale_ptrs)
if B_SCALE_IS_TENSOR:
b_s = tl.load(Bscale)
else:
b_s = tl.load(Bscale + offs_bn)

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_DTYPE)

Expand Down Expand Up @@ -143,12 +145,13 @@ def get_test_configs():
return fp8_gemm_configs


def _get_static_key(A, B, out_dtype):
def _get_static_key(A, B, Bscale, out_dtype):
    """Build the static key dict for a scaled matmul call.

    The key records N (columns of ``B``), K (columns of ``A``), whether
    ``Bscale`` is a single-element ("tensor") or multi-element ("channel")
    scale, and the string form of ``out_dtype``. The row count of ``A`` is
    not part of the key.

    Both ``A`` and ``B`` must be 2-D; the shape unpacking raises otherwise.
    """
    _, k_dim = A.shape
    _, n_dim = B.shape
    scale_kind = "tensor" if Bscale.numel() == 1 else "channel"
    return dict(N=n_dim, K=k_dim, b_scale_kind=scale_kind, out_dtype=str(out_dtype))

Expand All @@ -175,7 +178,8 @@ def scaled_mm_per_token(
A: Matrix A with shape of [M, K].
B: Matrix B with shape of [K, N].
Ascale: per-token Quantization scale for A: [M] or [M, 1].
Bscale: per-channel Quantization scale for B: [N] or [1, N].
Bscale: per-channel Quantization scale for B: [N] or [1, N],
or per-tensor scale [1].
out_dtype: The data type of out.
out: The output matrix with the shape of [M, N].
Returns:
Expand Down Expand Up @@ -239,6 +243,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
out_desc = None

ACC_DTYPE = tl.int32 if A.dtype == torch.int8 else tl.float32
B_SCALE_IS_TENSOR = Bscale.numel() == 1

_scaled_mm_per_token[grid](
A=A,
Expand All @@ -260,6 +265,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
stride_cn=out.stride(1),
USE_TMA=support_tma,
B_IS_TRANS=B_is_trans,
B_SCALE_IS_TENSOR=B_SCALE_IS_TENSOR,
NEED_N_MASK=NEED_N_MASK,
NEED_K_MASK=NEED_K_MASK,
ACC_DTYPE=ACC_DTYPE,
Expand Down
174 changes: 172 additions & 2 deletions lightllm/common/quantization/w8a8.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,41 @@
"1",
]

FP8_E4M3_MAX = 448.0


def _fp8_per_tensor_quant(weight: torch.Tensor, device_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``weight`` to FP8 (e4m3) with one scale per tensor.

    The weight is converted to float32 on CUDA device ``device_id``. A 3-D
    weight gets one scale per index of its leading dimension (abs-max over
    the last two dimensions); any other rank gets a single scale computed
    from the global abs-max. Each scale is ``abs_max / FP8_E4M3_MAX``.

    Returns:
        ``(qweight, scale)`` where ``qweight`` is the float8_e4m3fn tensor
        and ``scale`` is the float32 scale flattened to 1-D.
    """
    fp32_weight = weight.float().cuda(device_id)
    per_leading_index = fp32_weight.ndim == 3

    magnitude = fp32_weight.abs()
    abs_max = magnitude.amax(dim=(-1, -2)) if per_leading_index else magnitude.max()

    # Clamp away from zero so an all-zero weight does not divide by zero.
    scale = torch.clamp(abs_max / FP8_E4M3_MAX, min=torch.finfo(torch.float32).tiny)

    # Give the scale trailing singleton dims so it broadcasts over each matrix.
    broadcast_scale = scale.reshape(-1, 1, 1) if per_leading_index else scale
    return _fp8_quant_with_scale(fp32_weight, broadcast_scale), scale.reshape(-1)


def _fp8_quant_with_scale(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Divide ``weight`` by ``scale``, saturate to ±FP8_E4M3_MAX and cast to float8_e4m3fn.

    ``scale`` must be broadcastable against ``weight``.
    """
    rescaled = weight / scale
    saturated = rescaled.clamp(min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
    return saturated.to(dtype=torch.float8_e4m3fn)


def _copy_scale_with_broadcast(dst: torch.Tensor, src: torch.Tensor) -> None:
    """Copy the scale values of ``src`` into ``dst`` in place.

    Two layouts are accepted:

    * ``src`` has the same number of elements as ``dst``: it is reshaped to
      ``dst``'s shape and copied element for element.
    * ``src`` holds a single element: that value is broadcast to every
      element of ``dst``.

    Args:
        dst: Destination tensor, modified in place.
        src: Source scale tensor.

    Raises:
        ValueError: If ``src`` neither matches ``dst`` in element count nor
            holds exactly one element.
    """
    if dst.numel() == src.numel():
        dst.copy_(src.reshape_as(dst))
    elif src.numel() == 1:
        # A 0-dim dst has numel() == 1 and is always handled by the branch
        # above, so dst has at least one dimension here.
        dst.copy_(src.reshape(1).expand_as(dst))
    else:
        raise ValueError(f"can not copy scale with shape {tuple(src.shape)} to {tuple(dst.shape)}")


class BaseQuantizationMethod(QuantizationMethod):
def __init__(self):
super().__init__()
assert HAS_VLLM, "vllm are not installed, you can't use quant api of them."
# assert HAS_VLLM, "vllm are not installed, you can't use quant api of them."
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

self.cache_manager = g_cache_manager
Expand Down Expand Up @@ -124,7 +154,6 @@ def __init__(self):
self.has_weight_zero_point = False

def quantize(self, weight: torch.Tensor, output: WeightPack) -> None:

qweight, weight_scale = scaled_fp8_quant(
weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True
)
Expand Down Expand Up @@ -181,6 +210,147 @@ def _create_weight(
return mm_param, mm_param_list


@QUANTMETHODS.register(
    ["triton-fp8w8a8-pertensor", "fp8w8a8-pertensor", "triton-fp8w8a8-pt", "fp8w8a8-pt"],
    platform="cuda",
)
class TritonFP8w8a8PerTensorQuantizationMethod(BaseQuantizationMethod):
    """FP8 (e4m3) weight quantization with one float32 scale per weight tensor.

    Stacked expert weights (3-D) carry one scale per expert. Activations are
    quantized at ``apply`` time through ``scaled_fp8_quant`` and multiplied
    with ``fp8_scaled_mm_per_token``.

    When ``_create_weight`` is asked for several fused outputs, all child
    packs share the parent's single scale tensor. ``load_weight`` therefore
    stages the unquantized pieces in a CPU buffer and quantizes the fused
    weight once, after every piece has arrived. The bookkeeping lives in
    ``_fp8_pt_*`` attributes attached to the ``WeightPack`` objects.
    """

    def __init__(self):
        super().__init__()
        self.has_weight_scale = True
        self.has_weight_zero_point = False

    def quantize(self, weight: torch.Tensor, output: WeightPack) -> None:
        """Quantize ``weight`` and write the fp8 data and scale(s) into ``output``.

        A 3-D ``weight`` whose leading dimension matches the number of
        elements in ``output.weight_scale`` is quantized one expert at a
        time; anything else is quantized in a single call and its scale is
        copied (or broadcast) into ``output.weight_scale``.
        """
        per_expert = (
            weight.ndim == 3
            and output.weight_scale is not None
            and output.weight_scale.numel() == weight.shape[0]
        )
        if per_expert:
            for expert_idx in range(weight.shape[0]):
                qweight, weight_scale = _fp8_per_tensor_quant(weight[expert_idx], self.device_id_)
                output.weight[expert_idx].copy_(qweight)
                output.weight_scale[expert_idx].copy_(weight_scale.reshape(()))
            return

        qweight, weight_scale = _fp8_per_tensor_quant(weight, self.device_id_)
        output.weight.copy_(qweight)
        _copy_scale_with_broadcast(output.weight_scale, weight_scale)

    def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack) -> None:
        """Load ``weight`` into ``weight_pack``.

        Packs without a ``_fp8_pt_parent_pack`` attribute are handled by the
        base class. Packs that belong to a fused parent only stage their data
        and mark themselves as loaded; the parent is quantized once all of
        its pieces have been staged.
        """
        parent_pack = getattr(weight_pack, "_fp8_pt_parent_pack", None)
        if parent_pack is None:
            super().load_weight(weight, weight_pack)
            return

        staged_weight = weight_pack._fp8_pt_staged_weight
        # copy_ converts dtype and device itself and is synchronous by default.
        # The staging buffer is read back on the host as soon as the last piece
        # lands, so an unsynchronised non_blocking device-to-host hop must not
        # be used here.
        staged_weight.copy_(weight)

        loaded_index = weight_pack._fp8_pt_child_index
        if hasattr(weight_pack, "_fp8_pt_expert_index"):
            # Expert packs are tracked in a (num_experts, num_children) table.
            loaded_index = (weight_pack._fp8_pt_expert_index, loaded_index)
        parent_pack._fp8_pt_staged_loaded[loaded_index] = True
        self._try_finalize_deferred_weight(parent_pack)

    def _try_finalize_deferred_weight(self, parent_pack: WeightPack) -> bool:
        """Quantize the staged fused weight once every piece has been loaded.

        Returns:
            True if the parent is (or already was) finalized, False if some
            pieces are still missing.
        """
        if getattr(parent_pack, "_fp8_pt_finalized", False):
            return True

        staged_loaded = parent_pack._fp8_pt_staged_loaded
        if isinstance(staged_loaded, torch.Tensor):
            all_loaded = bool(staged_loaded.all().item())
        else:
            all_loaded = all(staged_loaded)
        if not all_loaded:
            return False

        self.quantize(parent_pack._fp8_pt_staged_weight, parent_pack)
        # NOTE(review): the three load_ok slots presumably stand for weight,
        # scale and zero point; confirm against WeightPack.
        parent_pack.load_ok = [True, True, True]
        parent_pack._fp8_pt_finalized = True
        # Drop the references to the staging buffer now that it has been consumed.
        parent_pack._fp8_pt_staged_weight = None
        for child_pack in parent_pack._fp8_pt_child_packs:
            child_pack.load_ok = [True, True, True]
            child_pack._fp8_pt_staged_weight = None
            for expert_child_pack in getattr(child_pack, "_fp8_pt_expert_child_packs", []):
                expert_child_pack.load_ok = [True, True, True]
                expert_child_pack._fp8_pt_staged_weight = None
        return True

    def apply(
        self,
        input_tensor: torch.Tensor,
        weight_pack: WeightPack,
        out: Optional[torch.Tensor] = None,
        workspace: Optional[torch.Tensor] = None,
        use_custom_tensor_mananger: bool = True,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the quantized matmul of ``input_tensor`` with the packed weight.

        Args:
            input_tensor: Activation whose first dimension is the row count.
            weight_pack: Pack holding the fp8 weight and its scale.
            out: Optional preallocated output of shape ``(m, n)``.
            workspace: Unused by this method.
            use_custom_tensor_mananger: When ``out`` is None, allocate it from
                ``self.cache_manager`` instead of ``torch.empty``.
            bias: Must be None; bias addition is not supported.

        Returns:
            The result of ``fp8_scaled_mm_per_token``.
        """
        # Reject unsupported arguments before doing any quantization or allocation.
        assert bias is None, "Bias addition is not supported in triton-fp8w8a8-pertensor for now"

        qweight = weight_pack.weight.t()
        weight_scale = weight_pack.weight_scale
        x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
        m = input_tensor.shape[0]
        n = qweight.shape[1]
        if out is None:
            if use_custom_tensor_mananger:
                out = self.cache_manager.alloc_tensor((m, n), input_tensor.dtype, device=input_tensor.device)
            else:
                out = torch.empty((m, n), dtype=input_tensor.dtype, device=input_tensor.device)
        return fp8_scaled_mm_per_token(
            x_q,
            qweight,
            x_scale,
            weight_scale,
            input_tensor.dtype,
            out,
        )

    @property
    def method_name(self):
        return "triton-fp8w8a8-pertensor"

    def _create_weight(
        self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1
    ) -> Tuple[WeightPack, List[WeightPack]]:
        """Allocate the fp8 weight, its scale, and one child pack per output split.

        Args:
            out_dims: Output size, or list of sizes for fused outputs.
            in_dim: Input (reduction) size.
            dtype: Dtype of the unquantized weights; used for the staging buffer.
            device_id: CUDA device index for the quantized weight and scale.
            num_experts: When greater than 1, a leading expert dimension is added.

        Returns:
            ``(mm_param, mm_param_list)``: the fused parent pack and the
            per-split child packs, which view into the parent's weight and
            share its scale tensor.
        """
        if isinstance(out_dims, int):
            out_dims = [out_dims]
        out_dim = sum(out_dims)
        expert_prefix = (num_experts,) if num_experts > 1 else ()
        device = f"cuda:{device_id}"

        # Allocate directly on the target device instead of on the host first.
        weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn, device=device)
        # One scale per expert, or a single scale when there is no expert dimension.
        weight_scale = torch.empty(expert_prefix or (1,), dtype=torch.float32, device=device)
        mm_param = WeightPack(weight=weight, weight_scale=weight_scale)
        weight_splits = torch.split(weight, out_dims, dim=-2)
        mm_param_list = [WeightPack(weight=split, weight_scale=weight_scale) for split in weight_splits]

        if len(out_dims) > 1:
            # Fused outputs share one scale, so the unquantized pieces are
            # staged on the CPU until all of them have been loaded.
            staged_weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=dtype, device="cpu")
            staged_splits = torch.split(staged_weight, out_dims, dim=-2)
            mm_param._fp8_pt_staged_weight = staged_weight
            if num_experts > 1:
                mm_param._fp8_pt_staged_loaded = torch.zeros(
                    (num_experts, len(mm_param_list)), dtype=torch.bool, device="cpu"
                )
            else:
                mm_param._fp8_pt_staged_loaded = [False] * len(mm_param_list)
            mm_param._fp8_pt_child_packs = mm_param_list
            mm_param._fp8_pt_finalized = False
            for idx, (child_pack, staged_split) in enumerate(zip(mm_param_list, staged_splits)):
                child_pack._fp8_pt_parent_pack = mm_param
                child_pack._fp8_pt_child_index = idx
                child_pack._fp8_pt_staged_weight = staged_split
                if num_experts > 1:
                    child_pack._fp8_pt_expert_child_packs = []
                    # Keep the original accessor so the wrapper can delegate to it.
                    child_pack._fp8_pt_get_expert = child_pack.get_expert

                    # _child_pack is bound as a default argument so each wrapper
                    # keeps its own child rather than the loop's last one.
                    def _get_deferred_expert(expert_idx, _child_pack=child_pack):
                        expert_child_pack = _child_pack._fp8_pt_get_expert(expert_idx)
                        expert_child_pack._fp8_pt_parent_pack = _child_pack._fp8_pt_parent_pack
                        expert_child_pack._fp8_pt_child_index = _child_pack._fp8_pt_child_index
                        expert_child_pack._fp8_pt_expert_index = expert_idx
                        expert_child_pack._fp8_pt_staged_weight = _child_pack._fp8_pt_staged_weight[expert_idx]
                        _child_pack._fp8_pt_expert_child_packs.append(expert_child_pack)
                        return expert_child_pack

                    child_pack.get_expert = _get_deferred_expert
        return mm_param, mm_param_list


@QUANTMETHODS.register(["vllm-fp8w8a8-b128", "fp8w8a8-b128"], platform="cuda")
class FP8w8a8B128QuantizationMethod(BaseQuantizationMethod):
def __init__(self):
Expand Down
2 changes: 1 addition & 1 deletion lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
"--quant_type",
type=str,
default="none",
help="""Quantization method: vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128
help="""Quantization method: vllm-w8a8 | vllm-fp8w8a8 | triton-fp8w8a8-pertensor | vllm-fp8w8a8-b128
| deepgemm-fp8w8a8-b128 | triton-fp8w8a8-block128 | awq | awq_marlin |
| triton-fp8w8a8g128 (weight perchannel quant and act per group quant) |
triton-fp8w8a8g64 (weight perchannel quantization with group size 64)""",
Expand Down