Skip to content
204 changes: 204 additions & 0 deletions auto_round_extension/ark/auto_round_kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,210 @@ def moe_gemm(
)
return outputs

def moe_gemm_decode(
    self,
    activations: torch.Tensor,
    weights: torch.Tensor,
    num_tokens_per_expert: torch.Tensor,
    *,
    scales: Optional[torch.Tensor] = None,
    zeros: Optional[torch.Tensor] = None,
    weight_bits: int = 4,
    group_size: int = 128,
    asym: bool = False,
) -> torch.Tensor:
    """MoE GEMV optimized for the decode phase.

    Each expert typically processes only 1-2 tokens (top-k routing with
    small batch). Activations must already be gathered/sorted by expert
    (same convention as ``moe_gemm``).

    All tensor arguments must live on the same XPU device as ``activations``:
    their raw ``data_ptr()`` values are handed directly to the native kernel.

    Args:
        activations: ``[total_tokens, K]`` in fp16 or bf16.
        weights: 3-D tensor ``[E, N, K_packed]``. The accepted layouts are:

            * Unquantized (``weight_bits=16``): ``torch.float16`` / ``torch.bfloat16``
              matching the activations dtype, ``K_packed == K``.
            * Int8 (``weight_bits=8``): ``torch.uint8``, ``K_packed == K``.
              Sym (``asym=False``) reinterprets each byte as signed int8;
              asym (``asym=True``) treats each byte as ``uint8`` with a
              per-group zero-point.
            * Int4 (``weight_bits=4``): ``torch.uint8`` packed,
              ``K_packed == K // 2`` (two 4-bit values per byte; low nibble
              at the lower K index).
            * Int2 (``weight_bits=2``): ``torch.uint8`` packed,
              ``K_packed == K // 4`` (four 2-bit values per byte; field j at
              K index ``4*i + j`` occupies bits 2j and 2j+1 of byte i).
            * FP8 (``torch.float8_e4m3fn`` / ``torch.float8_e5m2``):
              ``K_packed == K``. ``weight_bits`` is ignored; ``asym`` must
              be ``False`` (no zero-points for FP8).
        num_tokens_per_expert: ``[E]`` integer tensor (converted to int32 if
            needed). Sum must equal ``activations.shape[0]``.
        scales: ``[E, N, K // group_size]`` in activations dtype. Required
            for all quantized paths (int8/int4/int2/fp8); must be ``None``
            for unquantized weights.
        zeros: ``[E, N, K // group_size]`` in activations dtype. Required
            when ``asym=True`` (int8/int4/int2 only); otherwise ``None``.
        weight_bits: 2, 4, 8, or 16. Ignored when ``weights`` is an FP8
            tensor (the FP8 sub-format is taken from ``weights.dtype``).
        group_size: group along K for quantized weights (default 128).
            Must be positive on every quantized path.
        asym: if ``True``, weights use unsigned encoding and ``zeros`` must
            be provided. Not supported for FP8.

    Returns:
        outputs: ``[total_tokens, N]`` in the same dtype as activations.

    Raises:
        NotImplementedError: if ``activations`` is not on an XPU device.
        ValueError: if any dtype, shape, device, ``weight_bits``,
            ``group_size`` or ``asym`` combination violates the contract
            described above.
    """
    if activations.device.type != "xpu":
        raise NotImplementedError("moe_gemm_decode is only supported on XPU")

    if activations.dtype not in (torch.float16, torch.bfloat16):
        raise ValueError(f"activations must be fp16/bf16, got {activations.dtype}")

    if activations.ndim != 2:
        raise ValueError("activations must be 2D [total_tokens, K]")
    if weights.ndim != 3:
        raise ValueError("weights must be 3D [E, N, K_packed]")
    # Guard the ``shape[0]`` lookup below: a 0-D tensor would raise IndexError.
    if num_tokens_per_expert.ndim != 1:
        raise ValueError("num_tokens_per_expert must be 1D [E]")

    # Every tensor is passed to the native library as a bare address, so a
    # tensor living on another device (e.g. CPU) would give the kernel a
    # pointer it cannot dereference. Reject the mismatch up front.
    for tensor_name, tensor in (
        ("weights", weights),
        ("num_tokens_per_expert", num_tokens_per_expert),
        ("scales", scales),
        ("zeros", zeros),
    ):
        if tensor is not None and tensor.device != activations.device:
            raise ValueError(
                f"{tensor_name} must be on the same device as activations "
                f"({activations.device}), got {tensor.device}"
            )

    # The kernel receives raw pointers and therefore needs dense layouts.
    if not activations.is_contiguous():
        activations = activations.contiguous()
    if not weights.is_contiguous():
        weights = weights.contiguous()

    if num_tokens_per_expert.dtype != torch.int32:
        num_tokens_per_expert = num_tokens_per_expert.to(torch.int32)
    if not num_tokens_per_expert.is_contiguous():
        num_tokens_per_expert = num_tokens_per_expert.contiguous()

    total_tokens, K = activations.shape
    num_experts = weights.shape[0]
    N = weights.shape[1]

    if num_tokens_per_expert.shape[0] != num_experts:
        raise ValueError(
            f"num_tokens_per_expert length {num_tokens_per_expert.shape[0]} != num_experts {num_experts}"
        )

    # Detect FP8 weight dtype first (overrides weight_bits).
    is_fp8 = weights.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)

    # Validate weight layout / dtype combination.
    if is_fp8:
        if asym:
            raise ValueError("FP8 weights do not support asym=True")
        if weights.shape[2] != K:
            raise ValueError(f"FP8 weights K dim {weights.shape[2]} != activations K {K}")
        if scales is None:
            raise ValueError("scales is required for FP8 weights")
        if scales.dtype != activations.dtype:
            raise ValueError("scales dtype must match activations dtype")
        # Checked before the modulo so group_size=0 is a ValueError rather
        # than a ZeroDivisionError.
        if group_size <= 0:
            raise ValueError(f"group_size must be a positive integer, got {group_size}")
        if K % group_size != 0:
            raise ValueError("K must be a multiple of group_size")
        expected_scale_shape = (num_experts, N, K // group_size)
        if tuple(scales.shape) != expected_scale_shape:
            raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}")
        if zeros is not None:
            raise ValueError("zeros must be None for FP8 weights")
        weight_dtype = (
            ARK_DT.float8_e4m3 if weights.dtype == torch.float8_e4m3fn else ARK_DT.float8_e5m2
        )
        if not scales.is_contiguous():
            scales = scales.contiguous()
    elif weight_bits == 16:
        if weights.dtype != activations.dtype:
            raise ValueError("Unquantized weights must match activations dtype")
        if weights.shape[2] != K:
            raise ValueError(f"Unquantized weights K dim {weights.shape[2]} != activations K {K}")
        weight_dtype = cvt_dtype(activations.dtype)
        if scales is not None or zeros is not None:
            raise ValueError("scales/zeros must be None when weight_bits=16")
    elif weight_bits in (8, 4, 2):
        if weights.dtype != torch.uint8:
            raise ValueError(f"Int{weight_bits} packed weights must be torch.uint8")
        # Number of quantized values stored in each uint8 byte.
        k_div = 8 // weight_bits
        if K % k_div != 0:
            raise ValueError(f"K must be a multiple of {k_div} for weight_bits={weight_bits}")
        k_packed_expected = K // k_div
        if weights.shape[2] != k_packed_expected:
            raise ValueError(
                f"Int{weight_bits} packed weights last dim {weights.shape[2]} must equal K/{k_div} "
                f"({k_packed_expected})"
            )
        if scales is None:
            raise ValueError(f"scales is required for int{weight_bits} weights")
        if scales.dtype != activations.dtype:
            raise ValueError("scales dtype must match activations dtype")
        # Checked before the modulo so group_size=0 is a ValueError rather
        # than a ZeroDivisionError.
        if group_size <= 0:
            raise ValueError(f"group_size must be a positive integer, got {group_size}")
        if K % group_size != 0:
            raise ValueError("K must be a multiple of group_size")
        # Group_size constraints per dtype.
        if weight_bits == 4 and (group_size & 1) != 0:
            raise ValueError("group_size must be even for int4 weights")
        if weight_bits == 2 and (group_size & 3) != 0:
            raise ValueError("group_size must be a multiple of 4 for int2 weights")
        expected_scale_shape = (num_experts, N, K // group_size)
        if tuple(scales.shape) != expected_scale_shape:
            raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}")
        if asym:
            if zeros is None:
                raise ValueError("zeros is required when asym=True")
            if zeros.dtype != activations.dtype:
                raise ValueError("zeros dtype must match activations dtype")
            if tuple(zeros.shape) != expected_scale_shape:
                raise ValueError(f"zeros shape {tuple(zeros.shape)} != expected {expected_scale_shape}")
        else:
            if zeros is not None:
                raise ValueError("zeros must be None when asym=False")
        weight_dtype = {8: ARK_DT.int8, 4: ARK_DT.int4, 2: ARK_DT.int2}[weight_bits]
        if not scales.is_contiguous():
            scales = scales.contiguous()
        if asym and not zeros.is_contiguous():
            zeros = zeros.contiguous()
    else:
        raise ValueError(f"Unsupported weight_bits={weight_bits} (supported: 2, 4, 8, 16)")

    if N % 16 != 0:
        raise ValueError(f"N must be a multiple of 16 (got {N})")

    # NOTE: ``.item()`` reads the reduced value back on the host, so this
    # consistency check costs one device-to-host transfer per call.
    expected_total = int(num_tokens_per_expert.sum().item())
    if expected_total != total_tokens:
        raise ValueError(f"Sum of num_tokens_per_expert ({expected_total}) != total_tokens ({total_tokens})")

    lib = self.get_lib(activations)
    stream = get_stream(activations)
    outputs = torch.empty((total_tokens, N), device=activations.device, dtype=activations.dtype)
    # Scratch buffer mapping each token to its expert id. It is only allocated
    # here; the native wrapper is expected to fill it on-device, so no
    # host-side expansion of num_tokens_per_expert is needed.
    expert_id_per_token = torch.empty((total_tokens,), device=activations.device, dtype=torch.int32)

    # The native side treats a zero address as "tensor not provided".
    scales_ptr = scales.data_ptr() if scales is not None else 0
    zeros_ptr = zeros.data_ptr() if zeros is not None else 0

    lib.moe_gemm_decode(
        stream,
        activations.data_ptr(),
        weights.data_ptr(),
        scales_ptr,
        zeros_ptr,
        outputs.data_ptr(),
        expert_id_per_token.data_ptr(),
        cvt_dtype(activations.dtype),
        weight_dtype,
        N,
        K,
        group_size,
        num_tokens_per_expert.data_ptr(),
        num_experts,
        total_tokens,
        bool(asym),
    )
    return outputs


if __name__ == "__main__":
ark = ARK()
Expand Down
12 changes: 12 additions & 0 deletions auto_round_extension/ark/auto_round_kernel/ark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ typedef uintptr_t torch_ptr;
// Only include declarations, implementations are in separate .cpp files
#include "sycl_tla_common.hpp"
#include "sycl_tla_moe.hpp"
#include "sycl_tla_moe_decode.hpp"
#include "sycl_tla_sdpa.hpp"
#endif
#else
Expand Down Expand Up @@ -153,6 +154,16 @@ static void moe_gemm_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr
(void*)outputs, (BTLA_DTYPE)(dtype), N, K, (int*)num_tokens_per_expert, num_experts);
}

// Python-facing shim around ark::moe_gemm_decode. Every torch_ptr argument is
// an integer address supplied by the caller; this function only converts those
// integers (and the integer dtype codes) to the typed arguments the kernel
// entry point expects and forwards the call unchanged.
static void moe_gemm_decode_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr weights, torch_ptr scales,
                                    torch_ptr zeros, torch_ptr outputs, torch_ptr expert_id_per_token_buf,
                                    int act_dtype, int weight_dtype, int N, int K, int group_size,
                                    torch_ptr num_tokens_per_expert, int num_experts, int total_tokens, bool asym) {
  auto* queue = (sycl::queue*)stream;

  // A zero address means the optional tensor was not supplied.
  void* scales_ptr = scales ? (void*)scales : nullptr;
  void* zeros_ptr = zeros ? (void*)zeros : nullptr;

  // Dtype codes cross the binding as plain ints.
  const auto act_type = (BTLA_DTYPE)(act_dtype);
  const auto weight_type = (BTLA_DTYPE)(weight_dtype);

  ark::moe_gemm_decode(queue, (void*)activations, (void*)weights, scales_ptr, zeros_ptr, (void*)outputs,
                       (int*)expert_id_per_token_buf, act_type, weight_type, N, K, group_size,
                       (int*)num_tokens_per_expert, num_experts, total_tokens, asym);
}

static void sage_dynamic_quant(torch_ptr stream, torch_ptr input, torch_ptr bias, torch_ptr output, torch_ptr scale_out,
int num_rows, int head_dim, int block_size) {
auto* q = (sycl::queue*)stream;
Expand Down Expand Up @@ -292,5 +303,6 @@ PYBIND11_MODULE(PY_NAME, m) {
m.def("sage", &ark::sage);
m.def("sage_dynamic_quant", &ark::sage_dynamic_quant);
m.def("moe_gemm", &ark::moe_gemm_wrapper);
m.def("moe_gemm_decode", &ark::moe_gemm_decode_wrapper);
#endif
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,40 @@ namespace ark {
void moe_gemm(sycl::queue* q, void* activations, void* weights, void* scales, void* outputs, BTLA_DTYPE dtype, int N,
int K, int* num_tokens_per_expert, int num_experts);

/**
 * @brief MoE GEMV optimized for the decode phase (M per expert is typically
 *        1-2 tokens). Supports unquantized FP16/BF16 weights and int4 (S4_CLIP)
 *        weights with group-wise scales and optional zero-points.
 *
 * Implementation is header-only in `sycl_tla_moe_decode.hpp`.
 *
 * NOTE(review): the Python wrapper `moe_gemm_decode` also forwards int8, int2
 * and FP8 weight dtypes to this entry point, which this comment does not
 * list. Confirm the supported set against the implementation and extend the
 * `weights` / `weight_dtype` / `group_size` descriptions accordingly.
 *
 * @param q                       SYCL queue
 * @param activations             [total_tokens, K] in `act_dtype`
 * @param weights                 Unquantized: [num_experts, N, K] in act_dtype
 *                                Int4: packed [num_experts, N, K/2] uint8
 * @param scales                  [num_experts, N, K/group_size] (act_dtype),
 *                                ignored when weight_dtype is FP16/BF16
 * @param zeros                   [num_experts, N, K/group_size] (act_dtype) or
 *                                nullptr; required when asym==true
 * @param outputs                 [total_tokens, N] in act_dtype
 * @param expert_id_per_token_buf [total_tokens] int32 scratch buffer (device)
 * @param act_dtype               BTLA_DTYPE::F16 or BTLA_DTYPE::BF16
 * @param weight_dtype            BTLA_DTYPE::F16/BF16/S4_CLIP
 * @param N                       Output feature dim (must be multiple of 16)
 * @param K                       Input feature dim
 * @param group_size              Quantization group along K (int4 only); must
 *                                divide K and be even. This declaration has no
 *                                default value; the Python wrapper passes 128
 *                                unless the caller overrides it.
 * @param num_tokens_per_expert   [num_experts] int32
 * @param num_experts             Number of experts
 * @param total_tokens            Sum of num_tokens_per_expert (== rows of
 *                                activations / outputs)
 * @param asym                    Whether int4 weights are asymmetric
 *                                (zeros required when true).
 */
void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, void* outputs,
int* expert_id_per_token_buf, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, int N, int K,
int group_size, int* num_tokens_per_expert, int num_experts, int total_tokens, bool asym);

// ========================================================================
// Public API
// ========================================================================
Expand Down
Loading
Loading