diff --git a/auto_round_extension/ark/auto_round_kernel/__init__.py b/auto_round_extension/ark/auto_round_kernel/__init__.py index 831d3714a..731cdf37d 100644 --- a/auto_round_extension/ark/auto_round_kernel/__init__.py +++ b/auto_round_extension/ark/auto_round_kernel/__init__.py @@ -797,6 +797,210 @@ def moe_gemm( ) return outputs + def moe_gemm_decode( + self, + activations: torch.Tensor, + weights: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + *, + scales: Optional[torch.Tensor] = None, + zeros: Optional[torch.Tensor] = None, + weight_bits: int = 4, + group_size: int = 128, + asym: bool = False, + ) -> torch.Tensor: + """MoE GEMV optimized for the decode phase. + + Each expert typically processes only 1-2 tokens (top-k routing with + small batch). Activations must already be gathered/sorted by expert + (same convention as ``moe_gemm``). + + Args: + activations: ``[total_tokens, K]`` in fp16 or bf16. + weights: 3-D tensor ``[E, N, K_packed]``. The accepted layouts are: + + * Unquantized (``weight_bits=16``): ``torch.float16`` / ``torch.bfloat16`` + matching the activations dtype, ``K_packed == K``. + * Int8 (``weight_bits=8``): ``torch.uint8``, ``K_packed == K``. + Sym (``asym=False``) reinterprets each byte as signed int8; + asym (``asym=True``) treats each byte as ``uint8`` with a + per-group zero-point. + * Int4 (``weight_bits=4``): ``torch.uint8`` packed, + ``K_packed == K // 2`` (two 4-bit values per byte; low nibble + at the lower K index). + * Int2 (``weight_bits=2``): ``torch.uint8`` packed, + ``K_packed == K // 4`` (four 2-bit values per byte; field j at + K index ``4*i + j`` occupies bits 2j and 2j+1 of byte i). + * FP8 (``torch.float8_e4m3fn`` / ``torch.float8_e5m2``): + ``K_packed == K``. ``weight_bits`` is ignored; ``asym`` must + be ``False`` (no zero-points for FP8). + num_tokens_per_expert: ``[E]`` int32. Sum must equal + ``activations.shape[0]``. + scales: ``[E, N, K // group_size]`` in activations dtype. Required + for all quantized paths (int8/int4/int2/fp8); must be ``None`` + for unquantized weights. + zeros: ``[E, N, K // group_size]`` in activations dtype. Required + when ``asym=True`` (int8/int4/int2 only); otherwise ``None``. + weight_bits: 2, 4, 8, or 16. Ignored when ``weights`` is an FP8 + tensor (the FP8 sub-format is taken from ``weights.dtype``). + group_size: group along K for quantized weights (default 128). + asym: if ``True``, weights use unsigned encoding and ``zeros`` must + be provided. Not supported for FP8. + + Returns: + outputs: ``[total_tokens, N]`` in the same dtype as activations. + """ + if activations.device.type != "xpu": + raise NotImplementedError("moe_gemm_decode is only supported on XPU") + + if activations.dtype not in (torch.float16, torch.bfloat16): + raise ValueError(f"activations must be fp16/bf16, got {activations.dtype}") + + if activations.ndim != 2: + raise ValueError("activations must be 2D [total_tokens, K]") + if weights.ndim != 3: + raise ValueError("weights must be 3D [E, N, K_packed]") + + if not activations.is_contiguous(): + activations = activations.contiguous() + if not weights.is_contiguous(): + weights = weights.contiguous() + + if num_tokens_per_expert.dtype != torch.int32: + num_tokens_per_expert = num_tokens_per_expert.to(torch.int32) + if not num_tokens_per_expert.is_contiguous(): + num_tokens_per_expert = num_tokens_per_expert.contiguous() + + total_tokens, K = activations.shape + num_experts = weights.shape[0] + N = weights.shape[1] + + if num_tokens_per_expert.shape[0] != num_experts: + raise ValueError( + f"num_tokens_per_expert length {num_tokens_per_expert.shape[0]} != num_experts {num_experts}" + ) + + # Detect FP8 weight dtype first (overrides weight_bits). + is_fp8 = weights.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + + # Validate weight layout / dtype combination. + if is_fp8: + if asym: + raise ValueError("FP8 weights do not support asym=True") + if weights.shape[2] != K: + raise ValueError(f"FP8 weights K dim {weights.shape[2]} != activations K {K}") + if scales is None: + raise ValueError("scales is required for FP8 weights") + if scales.dtype != activations.dtype: + raise ValueError("scales dtype must match activations dtype") + if K % group_size != 0: + raise ValueError("K must be a multiple of group_size") + expected_scale_shape = (num_experts, N, K // group_size) + if tuple(scales.shape) != expected_scale_shape: + raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") + if zeros is not None: + raise ValueError("zeros must be None for FP8 weights") + weight_dtype = ( + ARK_DT.float8_e4m3 if weights.dtype == torch.float8_e4m3fn else ARK_DT.float8_e5m2 + ) + if not scales.is_contiguous(): + scales = scales.contiguous() + elif weight_bits == 16: + if weights.dtype != activations.dtype: + raise ValueError("Unquantized weights must match activations dtype") + if weights.shape[2] != K: + raise ValueError(f"Unquantized weights K dim {weights.shape[2]} != activations K {K}") + weight_dtype = cvt_dtype(activations.dtype) + if scales is not None or zeros is not None: + raise ValueError("scales/zeros must be None when weight_bits=16") + elif weight_bits in (8, 4, 2): + if weights.dtype != torch.uint8: + raise ValueError(f"Int{weight_bits} packed weights must be torch.uint8") + if weight_bits == 8: + k_packed_expected = K + k_div = 1 + elif weight_bits == 4: + k_packed_expected = K // 2 + k_div = 2 + else: # weight_bits == 2 + k_packed_expected = K // 4 + k_div = 4 + if K % k_div != 0: + raise ValueError(f"K must be a multiple of {k_div} for weight_bits={weight_bits}") + if weights.shape[2] != k_packed_expected: + raise ValueError( + f"Int{weight_bits} packed weights last dim {weights.shape[2]} must equal K/{k_div} " + f"({k_packed_expected})" + ) + if scales is None: + raise ValueError(f"scales is required for int{weight_bits} weights") + if scales.dtype != activations.dtype: + raise ValueError("scales dtype must match activations dtype") + if K % group_size != 0: + raise ValueError("K must be a multiple of group_size") + # Group_size constraints per dtype. + if weight_bits == 4 and (group_size & 1) != 0: + raise ValueError("group_size must be even for int4 weights") + if weight_bits == 2 and (group_size & 3) != 0: + raise ValueError("group_size must be a multiple of 4 for int2 weights") + expected_scale_shape = (num_experts, N, K // group_size) + if tuple(scales.shape) != expected_scale_shape: + raise ValueError(f"scales shape {tuple(scales.shape)} != expected {expected_scale_shape}") + if asym: + if zeros is None: + raise ValueError("zeros is required when asym=True") + if zeros.dtype != activations.dtype: + raise ValueError("zeros dtype must match activations dtype") + if tuple(zeros.shape) != expected_scale_shape: + raise ValueError(f"zeros shape {tuple(zeros.shape)} != expected {expected_scale_shape}") + else: + if zeros is not None: + raise ValueError("zeros must be None when asym=False") + weight_dtype = {8: ARK_DT.int8, 4: ARK_DT.int4, 2: ARK_DT.int2}[weight_bits] + if not scales.is_contiguous(): + scales = scales.contiguous() + if asym and not zeros.is_contiguous(): + zeros = zeros.contiguous() + else: + raise ValueError(f"Unsupported weight_bits={weight_bits} (supported: 2, 4, 8, 16)") + + if N % 16 != 0: + raise ValueError(f"N must be a multiple of 16 (got {N})") + + expected_total = int(num_tokens_per_expert.sum().item()) + if expected_total != total_tokens: + raise ValueError(f"Sum of num_tokens_per_expert ({expected_total}) != total_tokens ({total_tokens})") + + lib = self.get_lib(activations) + stream = get_stream(activations) + outputs = torch.empty((total_tokens, N), device=activations.device, dtype=activations.dtype) + # Scratch buffer mapping each token to its expert id; filled on-device + # inside the kernel wrapper so we avoid host-device sync. + expert_id_per_token = torch.empty((total_tokens,), device=activations.device, dtype=torch.int32) + + scales_ptr = scales.data_ptr() if scales is not None else 0 + zeros_ptr = zeros.data_ptr() if zeros is not None else 0 + + lib.moe_gemm_decode( + stream, + activations.data_ptr(), + weights.data_ptr(), + scales_ptr, + zeros_ptr, + outputs.data_ptr(), + expert_id_per_token.data_ptr(), + cvt_dtype(activations.dtype), + weight_dtype, + N, + K, + group_size, + num_tokens_per_expert.data_ptr(), + num_experts, + total_tokens, + bool(asym), + ) + return outputs + if __name__ == "__main__": ark = ARK() diff --git a/auto_round_extension/ark/auto_round_kernel/ark.cpp b/auto_round_extension/ark/auto_round_kernel/ark.cpp index 86d983f90..6bb5ad040 100755 --- a/auto_round_extension/ark/auto_round_kernel/ark.cpp +++ b/auto_round_extension/ark/auto_round_kernel/ark.cpp @@ -25,6 +25,7 @@ typedef uintptr_t torch_ptr; // Only include declarations, implementations are in separate .cpp files #include "sycl_tla_common.hpp" #include "sycl_tla_moe.hpp" +#include "sycl_tla_moe_decode.hpp" #include "sycl_tla_sdpa.hpp" #endif #else @@ -153,6 +154,16 @@ static void moe_gemm_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr (void*)outputs, (BTLA_DTYPE)(dtype), N, K, (int*)num_tokens_per_expert, num_experts); } +static void moe_gemm_decode_wrapper(torch_ptr stream, torch_ptr activations, torch_ptr weights, torch_ptr scales, + torch_ptr zeros, torch_ptr outputs, torch_ptr expert_id_per_token_buf, + int act_dtype, int weight_dtype, int N, int K, int group_size, + torch_ptr num_tokens_per_expert, int num_experts, int total_tokens, bool asym) { + ark::moe_gemm_decode((sycl::queue*)stream, (void*)activations, (void*)weights, scales ? (void*)scales : nullptr, + zeros ? (void*)zeros : nullptr, (void*)outputs, (int*)expert_id_per_token_buf, + (BTLA_DTYPE)(act_dtype), (BTLA_DTYPE)(weight_dtype), N, K, group_size, + (int*)num_tokens_per_expert, num_experts, total_tokens, asym); +} + static void sage_dynamic_quant(torch_ptr stream, torch_ptr input, torch_ptr bias, torch_ptr output, torch_ptr scale_out, int num_rows, int head_dim, int block_size) { auto* q = (sycl::queue*)stream; @@ -292,5 +303,6 @@ PYBIND11_MODULE(PY_NAME, m) { m.def("sage", &ark::sage); m.def("sage_dynamic_quant", &ark::sage_dynamic_quant); m.def("moe_gemm", &ark::moe_gemm_wrapper); + m.def("moe_gemm_decode", &ark::moe_gemm_decode_wrapper); #endif } \ No newline at end of file diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp index 37752f79a..31b3994ce 100644 --- a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp @@ -35,6 +35,40 @@ namespace ark { void moe_gemm(sycl::queue* q, void* activations, void* weights, void* scales, void* outputs, BTLA_DTYPE dtype, int N, int K, int* num_tokens_per_expert, int num_experts); +/** + * @brief MoE GEMV optimized for the decode phase (M per expert is typically + * 1-2 tokens). Supports unquantized FP16/BF16 weights and int4 (S4_CLIP) + * weights with group-wise scales and optional zero-points. + * + * Implementation is header-only in `sycl_tla_moe_decode.hpp`. + * + * @param q SYCL queue + * @param activations [total_tokens, K] in `act_dtype` + * @param weights Unquantized: [num_experts, N, K] in act_dtype + * Int4: packed [num_experts, N, K/2] uint8 + * @param scales [num_experts, N, K/group_size] (act_dtype), + * ignored when weight_dtype is FP16/BF16 + * @param zeros [num_experts, N, K/group_size] (act_dtype) or + * nullptr; required when asym==true + * @param outputs [total_tokens, N] in act_dtype + * @param expert_id_per_token_buf [total_tokens] int32 scratch buffer (device) + * @param act_dtype BTLA_DTYPE::F16 or BTLA_DTYPE::BF16 + * @param weight_dtype BTLA_DTYPE::F16/BF16/S4_CLIP + * @param N Output feature dim (must be multiple of 16) + * @param K Input feature dim + * @param group_size Quantization group along K (int4 only); must + * divide K and be even. Default 128. + * @param num_tokens_per_expert [num_experts] int32 + * @param num_experts Number of experts + * @param total_tokens Sum of num_tokens_per_expert (== rows of + * activations / outputs) + * @param asym Whether int4 weights are asymmetric + * (zeros required when true). + */ +void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, void* outputs, + int* expert_id_per_token_buf, BTLA_DTYPE act_dtype, BTLA_DTYPE weight_dtype, int N, int K, + int group_size, int* num_tokens_per_expert, int num_experts, int total_tokens, bool asym); + // ======================================================================== // Public API // ======================================================================== diff --git a/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp new file mode 100644 index 000000000..343752633 --- /dev/null +++ b/auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp @@ -0,0 +1,701 @@ +// SYCL MoE Decode Kernel +// +// GEMV-style MoE kernel optimized for the decode phase, where each expert +// typically processes only 1-2 tokens (top-k routing with batch size 1). +// +// Layout convention (caller already sorted activations per expert, +// identical to the prefill `moe_gemm` interface): +// - activations: [total_tokens, K] row-major +// - weights (fp/bf16): [num_experts, N, K] row-major +// - weights (int8): [num_experts, N, K] row-major, one +// int8 per byte (sym: signed -128..127; +// asym: unsigned 0..255 with zero-point) +// - weights (int4 packed): [num_experts, N, K/2] row-major, two +// 4-bit values per byte (low nibble at lower K) +// - weights (int2 packed): [num_experts, N, K/4] row-major, four +// 2-bit values per byte (field j at K index +// 4*i+j is bits [2j+1:2j]) +// - weights (fp8): [num_experts, N, K] row-major, one +// FP8 byte per weight (E4M3 / E5M2); scales +// applied per-group, no zero-points +// - scales: [num_experts, N, K/group_size] +// - zeros (asym only): [num_experts, N, K/group_size] +// - num_tokens_per_expert: [num_experts] int32 +// - outputs: [total_tokens, N] +// +// Target: Intel BMG (Xe2), sub_group_size = 16. One sub-group per (token, N-tile) +// with N_TILE == SG_SIZE: each lane independently computes one output element, +// so no cross-lane reduction is needed and activation reads are coalesced across +// the sub-group through the L1 cache. +// +// Copyright (C) 2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +#include "bestla/bestla/bestla.h" + +#ifdef ARK_XPU +#include +#endif + +#if defined(ARK_XPU) && defined(ARK_SYCL_TLA) + +namespace ark { +namespace moe_decode_detail { + +constexpr int SG_SIZE = 16; +constexpr int N_TILE = SG_SIZE; // one output element per sub-group lane + +// ---------------------------------------------------------------------------- +// Kernel name tags (one per specialization, required for SYCL kernel naming) +// ---------------------------------------------------------------------------- +template +class MoEDecodeKernelFP; + +template +class MoEDecodeKernelInt4; + +template +class MoEDecodeKernelInt8; + +template +class MoEDecodeKernelInt2; + +template +class MoEDecodeKernelFP8; + +// ---------------------------------------------------------------------------- +// FP8 byte -> float decode (device-side, no LUT / SLM required). +// Matches IEEE-style layout used by torch.float8_e4m3fn / torch.float8_e5m2: +// E4M3 (finite-only): 1 sign, 4 exp (bias 7), 3 mantissa; 0x7F/0xFF = NaN. +// E5M2 (IEEE-like): 1 sign, 5 exp (bias 15), 2 mantissa; exp==31 -> Inf/NaN. +// We keep these inline rather than using a SLM LUT because the decode kernel +// runs only one sub-group per workgroup and per-lane bit ops are cheap relative +// to the global memory loads. +// ---------------------------------------------------------------------------- +inline float decode_fp8_e4m3(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const uint32_t sign = byte >> 7; + float v; + if (mag == 0u) { + v = 0.0f; + } else if (mag == 0x7Fu) { + v = sycl::nan(0u); + } else { + const int exp = static_cast((mag >> 3) & 0xFu); + const int man = static_cast(mag & 0x7u); + if (exp == 0) { + // subnormal: value = man * 2^(1 - bias - mbits) = man / 512 + v = static_cast(man) * (1.0f / 512.0f); + } else { + // normal: (1 + man/8) * 2^(exp - bias), bias = 7 + v = (1.0f + static_cast(man) * 0.125f) * sycl::ldexp(1.0f, exp - 7); + } + } + return sign ? -v : v; +} + +inline float decode_fp8_e5m2(uint8_t byte) { + const uint32_t mag = byte & 0x7Fu; + const uint32_t sign = byte >> 7; + const int exp = static_cast((mag >> 2) & 0x1Fu); + const int man = static_cast(mag & 0x3u); + float v; + if (exp == 0) { + // subnormal (incl. zero): value = man * 2^(1 - bias - mbits) = man / 65536 + v = static_cast(man) * (1.0f / 65536.0f); + } else if (exp == 31) { + v = (man == 0) ? std::numeric_limits::infinity() : sycl::nan(0u); + } else { + // normal: (1 + man/4) * 2^(exp - bias), bias = 15 + v = (1.0f + static_cast(man) * 0.25f) * sycl::ldexp(1.0f, exp - 15); + } + return sign ? -v : v; +} + +// ---------------------------------------------------------------------------- +// Build a [total_tokens] -> expert_id mapping from num_tokens_per_expert. +// Runs on host (num_experts is small, total_tokens is small in decode). +// Caller-managed buffer (USM device allocation) keeps host noise out of the +// hot path; here we just fill it via a tiny SYCL kernel for simplicity. +// ---------------------------------------------------------------------------- +inline void fill_expert_id_per_token(sycl::queue* q, int* expert_id_per_token, + const int* num_tokens_per_expert, int num_experts, + int total_tokens) { + // Sequential prefix-scan on a single thread; cheap because num_experts is + // small (typ. <= 256) and we avoid host-device sync entirely. + q->single_task([=]() { + int offset = 0; + for (int e = 0; e < num_experts; ++e) { + int n = num_tokens_per_expert[e]; + for (int i = 0; i < n; ++i) { + if (offset + i < total_tokens) { + expert_id_per_token[offset + i] = e; + } + } + offset += n; + } + }).wait(); +} + +// ---------------------------------------------------------------------------- +// FP16 / BF16 baseline GEMV (no quantization). +// ---------------------------------------------------------------------------- +template +void launch_fp(sycl::queue* q, const ScalarT* activations, const ScalarT* weights, ScalarT* outputs, + const int* expert_id_per_token, int total_tokens, int N, int K) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode: N must be a multiple of 16"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + const ScalarT* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * K; + + float acc = 0.0f; + // Unroll by 8 to hide latency; arbitrary K (any multiple of 8). + int k = 0; + constexpr int UNROLL = 8; + for (; k + UNROLL <= K; k += UNROLL) { +#pragma unroll + for (int u = 0; u < UNROLL; ++u) { + acc += static_cast(act_row[k + u]) * static_cast(w_row[k + u]); + } + } + for (; k < K; ++k) { + acc += static_cast(act_row[k]) * static_cast(w_row[k]); + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }) + .wait(); +} + +// ---------------------------------------------------------------------------- +// INT4 (S4_CLIP) GEMV with group-wise dequantization. +// +// Asym=false: signed nibble in [-8, 7], dequant = q * scale +// Asym=true : unsigned nibble in [0, 15], dequant = (q - zero) * scale +// +// Packing: two 4-bit values per byte; the value at k = 2*i is the LOW nibble +// of byte i, the value at k = 2*i+1 is the HIGH nibble. This matches the +// existing CPU/XPU `packq` layout for S4_CLIP weights. +// ---------------------------------------------------------------------------- +template +void launch_int4(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + const ScalarT* zeros, ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, + int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(int4): N must be a multiple of 16"); + } + if (K % group_size != 0 || (group_size & 1) != 0) { + throw std::invalid_argument("moe_gemm_decode(int4): K must be a multiple of group_size and group_size must be even"); + } + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_decode(int4): zeros pointer required when asym=true"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + const int k_packed = K / 2; // bytes of packed weight per (expert, n) + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * k_packed; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + const ScalarT* z_row = Asym + ? zeros + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k + : nullptr; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + float zero = 0.0f; + if constexpr (Asym) { + zero = static_cast(z_row[g]); + } + const int k_base = g * group_size; + // Two nibbles per byte; iterate in pairs. + for (int kk = 0; kk < group_size; kk += 2) { + const uint8_t packed = w_row[(k_base + kk) / 2]; + float w0, w1; + if constexpr (Asym) { + const int q0 = static_cast(packed & 0x0F); + const int q1 = static_cast((packed >> 4) & 0x0F); + w0 = (static_cast(q0) - zero) * scale; + w1 = (static_cast(q1) - zero) * scale; + } else { + // Sign-extend each 4-bit signed nibble to 8-bit signed: + // low nibble: shift left by 4 to move bit[3] into bit[7], + // then arithmetic right-shift by 4 replicates the sign bit. + // high nibble: same trick after masking off the low nibble. + const int q0 = static_cast(static_cast(packed << 4) >> 4); + const int q1 = static_cast(static_cast(packed & 0xF0) >> 4); + w0 = static_cast(q0) * scale; + w1 = static_cast(q1) * scale; + } + acc += static_cast(act_row[k_base + kk]) * w0; + acc += static_cast(act_row[k_base + kk + 1]) * w1; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }) + .wait(); +} + +// ---------------------------------------------------------------------------- +// INT8 (S8) GEMV with group-wise dequantization. +// +// Asym=false: signed byte in [-128, 127], dequant = q * scale +// Asym=true : unsigned byte in [0, 255], dequant = (q - zero) * scale +// +// Weights are stored as raw uint8 bytes (1 byte per weight). The same buffer +// type is used for sym and asym; the only difference is the sign interpretation +// performed at decode time. +// ---------------------------------------------------------------------------- +template +void launch_int8(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + const ScalarT* zeros, ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, + int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(int8): N must be a multiple of 16"); + } + if (K % group_size != 0) { + throw std::invalid_argument("moe_gemm_decode(int8): K must be a multiple of group_size"); + } + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_decode(int8): zeros pointer required when asym=true"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * K; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + const ScalarT* z_row = Asym + ? zeros + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k + : nullptr; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + float zero = 0.0f; + if constexpr (Asym) { + zero = static_cast(z_row[g]); + } + const int k_base = g * group_size; + for (int kk = 0; kk < group_size; ++kk) { + const uint8_t raw = w_row[k_base + kk]; + float w; + if constexpr (Asym) { + w = (static_cast(raw) - zero) * scale; + } else { + // Reinterpret as signed int8. + w = static_cast(static_cast(raw)) * scale; + } + acc += static_cast(act_row[k_base + kk]) * w; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }) + .wait(); +} + +// ---------------------------------------------------------------------------- +// INT2 (S2_CLIP) GEMV with group-wise dequantization. +// +// Packing: 4 values per byte. The value at K index 4*i + j is stored in +// bits [2j+1 : 2j] of byte i (i.e. byte = q0 | (q1<<2) | (q2<<4) | (q3<<6)). +// +// Asym=false: signed 2-bit value in [-2, 1]; dequant = q * scale +// Asym=true : unsigned 2-bit value in [0, 3]; dequant = (q - zero) * scale +// ---------------------------------------------------------------------------- +template +void launch_int2(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + const ScalarT* zeros, ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, + int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(int2): N must be a multiple of 16"); + } + if ((K & 0x3) != 0) { + throw std::invalid_argument("moe_gemm_decode(int2): K must be a multiple of 4"); + } + if (K % group_size != 0 || (group_size & 0x3) != 0) { + throw std::invalid_argument( + "moe_gemm_decode(int2): K must be a multiple of group_size and group_size must be a multiple of 4"); + } + if (Asym && zeros == nullptr) { + throw std::invalid_argument("moe_gemm_decode(int2): zeros pointer required when asym=true"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + const int k_packed = K / 4; // bytes of packed weight per (expert, n) + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * k_packed; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + const ScalarT* z_row = Asym + ? zeros + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k + : nullptr; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + float zero = 0.0f; + if constexpr (Asym) { + zero = static_cast(z_row[g]); + } + const int k_base = g * group_size; + // 4 values per byte; iterate in quads. + for (int kk = 0; kk < group_size; kk += 4) { + const uint8_t packed = w_row[(k_base + kk) / 4]; + float w[4]; + if constexpr (Asym) { + const int q0 = static_cast(packed & 0x3); + const int q1 = static_cast((packed >> 2) & 0x3); + const int q2 = static_cast((packed >> 4) & 0x3); + const int q3 = static_cast((packed >> 6) & 0x3); + w[0] = (static_cast(q0) - zero) * scale; + w[1] = (static_cast(q1) - zero) * scale; + w[2] = (static_cast(q2) - zero) * scale; + w[3] = (static_cast(q3) - zero) * scale; + } else { + // Sign-extend each 2-bit signed value via shift-left-then-arith-shift-right. + // After placing the 2 bits in the high 2 of an int8, >>6 replicates the sign. + const int q0 = static_cast(static_cast(packed << 6) >> 6); + const int q1 = static_cast(static_cast((packed << 4) & 0xC0) >> 6); + const int q2 = static_cast(static_cast((packed << 2) & 0xC0) >> 6); + const int q3 = static_cast(static_cast(packed & 0xC0) >> 6); + w[0] = static_cast(q0) * scale; + w[1] = static_cast(q1) * scale; + w[2] = static_cast(q2) * scale; + w[3] = static_cast(q3) * scale; + } + acc += static_cast(act_row[k_base + kk + 0]) * w[0]; + acc += static_cast(act_row[k_base + kk + 1]) * w[1]; + acc += static_cast(act_row[k_base + kk + 2]) * w[2]; + acc += static_cast(act_row[k_base + kk + 3]) * w[3]; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }) + .wait(); +} + +// ---------------------------------------------------------------------------- +// FP8 (E4M3 / E5M2) GEMV with group-wise scale (no zero-point). +// +// Weights are 1 FP8 byte per element [E, N, K]. The byte is decoded inline by +// bit manipulation; the LUT in fp8_lut.h would also work but inline decode +// keeps this kernel self-contained and avoids touching SLM. +// ---------------------------------------------------------------------------- +template +void launch_fp8(sycl::queue* q, const ScalarT* activations, const uint8_t* weights, const ScalarT* scales, + ScalarT* outputs, const int* expert_id_per_token, int total_tokens, int N, int K, int group_size) { + if (N % N_TILE != 0) { + throw std::invalid_argument("moe_gemm_decode(fp8): N must be a multiple of 16"); + } + if (K % group_size != 0) { + throw std::invalid_argument("moe_gemm_decode(fp8): K must be a multiple of group_size"); + } + if (total_tokens == 0) return; + + const int n_tiles = N / N_TILE; + const int num_groups_k = K / group_size; + + sycl::range<2> global{static_cast(total_tokens), static_cast(n_tiles * SG_SIZE)}; + sycl::range<2> local{1, static_cast(SG_SIZE)}; + + q->parallel_for>( + sycl::nd_range<2>(global, local), + [=](sycl::nd_item<2> it) [[intel::reqd_sub_group_size(SG_SIZE)]] { + const int token = static_cast(it.get_global_id(0)); + const int n_tile = static_cast(it.get_group(1)); + const int lane = static_cast(it.get_local_id(1)); + const int n_global = n_tile * N_TILE + lane; + + const int expert = expert_id_per_token[token]; + const ScalarT* act_row = activations + static_cast(token) * K; + + const uint8_t* w_row = + weights + (static_cast(expert) * N + static_cast(n_global)) * K; + const ScalarT* s_row = + scales + (static_cast(expert) * N + static_cast(n_global)) * num_groups_k; + + float acc = 0.0f; + for (int g = 0; g < num_groups_k; ++g) { + const float scale = static_cast(s_row[g]); + const int k_base = g * group_size; + for (int kk = 0; kk < group_size; ++kk) { + const uint8_t raw = w_row[k_base + kk]; + float w; + if constexpr (IsE4M3) { + w = decode_fp8_e4m3(raw) * scale; + } else { + w = decode_fp8_e5m2(raw) * scale; + } + acc += static_cast(act_row[k_base + kk]) * w; + } + } + + outputs[static_cast(token) * N + n_global] = static_cast(acc); + }) + .wait(); +} + +} // namespace moe_decode_detail + +// ---------------------------------------------------------------------------- +// Public API +// +// weight_dtype: +// BTLA_DTYPE::F16 / BF16 : weights stored as [E, N, K] in matching +// floating dtype, no scales/zeros needed +// BTLA_DTYPE::S8 : int8 weights [E, N, K] (uint8 buffer, +// interpreted as signed when asym==false, +// unsigned with zero-points when asym==true) +// BTLA_DTYPE::S4_CLIP : packed int4 weights [E, N, K/2] (uint8), +// scales [E, N, K/group_size] in act dtype, +// zeros optional (asym==true requires it) +// BTLA_DTYPE::S2_CLIP : packed int2 weights [E, N, K/4] (uint8), +// 4 values per byte, sym/asym like int4 +// BTLA_DTYPE::F8_E4M3 / F8_E5M2 : FP8 weights [E, N, K] (uint8 buffer), +// group-wise scales, no zero-points +// act_dtype: F16 or BF16 (must match scales/outputs dtype) +// ---------------------------------------------------------------------------- +inline void moe_gemm_decode(sycl::queue* q, void* activations, void* weights, void* scales, void* zeros, + void* outputs, int* expert_id_per_token_buf, BTLA_DTYPE act_dtype, + BTLA_DTYPE weight_dtype, int N, int K, int group_size, int* num_tokens_per_expert, + int num_experts, int total_tokens, bool asym) { + moe_decode_detail::fill_expert_id_per_token(q, expert_id_per_token_buf, num_tokens_per_expert, num_experts, + total_tokens); + + if (weight_dtype == BTLA_DTYPE::F16 || weight_dtype == BTLA_DTYPE::BF16) { + if (weight_dtype != act_dtype) { + throw std::invalid_argument("moe_gemm_decode: unquantized weight_dtype must match act_dtype"); + } + if (act_dtype == BTLA_DTYPE::F16) { + moe_decode_detail::launch_fp(q, static_cast(activations), + static_cast(weights), + static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K); + } else { + moe_decode_detail::launch_fp( + q, static_cast(activations), + static_cast(weights), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::S4_CLIP) { + if (act_dtype == BTLA_DTYPE::F16) { + if (asym) { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (asym) { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int4( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(int4): act_dtype must be FP16 or BF16"); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::S8) { + if (act_dtype == BTLA_DTYPE::F16) { + if (asym) { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (asym) { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(int8): act_dtype must be FP16 or BF16"); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::S2_CLIP) { + if (act_dtype == BTLA_DTYPE::F16) { + if (asym) { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), + static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (asym) { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_int2( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(zeros), static_cast(outputs), + expert_id_per_token_buf, total_tokens, N, K, group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(int2): act_dtype must be FP16 or BF16"); + } + return; + } + + if (weight_dtype == BTLA_DTYPE::F8_E4M3 || weight_dtype == BTLA_DTYPE::F8_E5M2) { + if (asym) { + throw std::invalid_argument("moe_gemm_decode(fp8): asym mode is not supported"); + } + const bool is_e4m3 = (weight_dtype == BTLA_DTYPE::F8_E4M3); + if (act_dtype == BTLA_DTYPE::F16) { + if (is_e4m3) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, + total_tokens, N, K, group_size); + } + } else if (act_dtype == BTLA_DTYPE::BF16) { + using BF = sycl::ext::oneapi::bfloat16; + if (is_e4m3) { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } else { + moe_decode_detail::launch_fp8( + q, static_cast(activations), static_cast(weights), + static_cast(scales), static_cast(outputs), expert_id_per_token_buf, total_tokens, N, K, + group_size); + } + } else { + throw std::invalid_argument("moe_gemm_decode(fp8): act_dtype must be FP16 or BF16"); + } + return; + } + + throw std::invalid_argument( + "moe_gemm_decode: unsupported weight_dtype (supported: F16, BF16, S8, S4_CLIP, S2_CLIP, F8_E4M3, F8_E5M2)"); +} + +} // namespace ark + +#endif // ARK_XPU && ARK_SYCL_TLA diff --git a/auto_round_extension/ark/test/test_moe.py b/auto_round_extension/ark/test/test_moe.py index 1134f7b4d..bca8d84be 100644 --- a/auto_round_extension/ark/test/test_moe.py +++ b/auto_round_extension/ark/test/test_moe.py @@ -39,6 +39,13 @@ def has_moe_gemm(): return hasattr(ark.xpu_lib, "moe_gemm") +def has_moe_gemm_decode(): + """Check if MoE decode GEMV kernel is available.""" + if ark.xpu_lib is None: + return False + return hasattr(ark.xpu_lib, "moe_gemm_decode") + + @pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") @pytest.mark.skipif(not has_moe_gemm(), reason="MOE GEMM kernel not built (need ARK_SYCL_TLA=ON)") class TestMoEGemm: @@ -166,5 +173,545 @@ def test_moe_gemm_various_sizes(self, N, K): print(f"MOE GEMM test passed for N={N}, K={K}") +# --------------------------------------------------------------------------- +# Decode-path tests (M per expert is typically 1-2, mirrors top-k routing +# after the activations have been gathered/sorted by the upper layer). +# --------------------------------------------------------------------------- + + +def _pack_int4_sym(w_float, scales, group_size): + """Quantize a [E, N, K] fp tensor to symmetric int4 packed [E, N, K/2]. + + scales is filled in-place with [E, N, K/group_size] values. + """ + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + s = (absmax / 7.0).squeeze(-1).to(scales.dtype) + scales.copy_(s) + q = torch.clamp(torch.round(w / (s.to(w.dtype).unsqueeze(-1))), -8, 7).to(torch.int8) + q = q.reshape(E, N, K) + # Pack two nibbles per byte: low nibble at lower K, high nibble at higher K. + q_low = q[..., 0::2] & 0x0F + q_high = q[..., 1::2] & 0x0F + packed = (q_low | (q_high << 4)).to(torch.uint8) + return packed + + +def _pack_int4_asym(w_float, scales, zeros, group_size): + """Quantize to asymmetric int4 (range [0, 15]); returns packed weights.""" + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + wmin = w.amin(dim=-1, keepdim=True) + wmax = w.amax(dim=-1, keepdim=True) + s = ((wmax - wmin) / 15.0).clamp(min=1e-8) + z = torch.round(-wmin / s).clamp(0, 15) + scales.copy_(s.squeeze(-1).to(scales.dtype)) + zeros.copy_(z.squeeze(-1).to(zeros.dtype)) + q = torch.clamp(torch.round(w / s + z), 0, 15).to(torch.int32) + q = q.reshape(E, N, K) + q_low = q[..., 0::2] & 0x0F + q_high = q[..., 1::2] & 0x0F + packed = (q_low | (q_high << 4)).to(torch.uint8) + return packed + + +def _dequant_int4_sym(packed, scales, group_size): + """Inverse of _pack_int4_sym. Returns [E, N, K] in scales.dtype.""" + E, N, K_half = packed.shape + K = K_half * 2 + low = (packed & 0x0F).to(torch.int8) + high = ((packed >> 4) & 0x0F).to(torch.int8) + # Sign extend 4-bit -> 8-bit + low = torch.where(low >= 8, low - 16, low) + high = torch.where(high >= 8, high - 16, high) + q = torch.empty(E, N, K, dtype=torch.int8, device=packed.device) + q[..., 0::2] = low + q[..., 1::2] = high + q = q.reshape(E, N, K // group_size, group_size).to(scales.dtype) + return (q * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _dequant_int4_asym(packed, scales, zeros, group_size): + E, N, K_half = packed.shape + K = K_half * 2 + low = (packed & 0x0F).to(torch.int32) + high = ((packed >> 4) & 0x0F).to(torch.int32) + q = torch.empty(E, N, K, dtype=torch.int32, device=packed.device) + q[..., 0::2] = low + q[..., 1::2] = high + q = q.reshape(E, N, K // group_size, group_size).to(scales.dtype) + deq = (q - zeros.to(scales.dtype).unsqueeze(-1)) * scales.unsqueeze(-1) + return deq.reshape(E, N, K) + + +# --------------------------------------------------------------------------- +# Int8 / Int2 / FP8 helpers (decode-path). +# --------------------------------------------------------------------------- + + +def _pack_int8_sym(w_float, scales, group_size): + """Quantize [E, N, K] fp -> int8 (signed, [-127, 127]); fills scales in place.""" + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + s = (absmax / 127.0).squeeze(-1).to(scales.dtype) + scales.copy_(s) + q = torch.clamp(torch.round(w / s.to(w.dtype).unsqueeze(-1)), -127, 127).to(torch.int8) + # Reinterpret as uint8 with no value change. + return q.reshape(E, N, K).view(torch.uint8).contiguous() + + +def _pack_int8_asym(w_float, scales, zeros, group_size): + """Quantize [E, N, K] fp -> uint8 ([0, 255]); fills scales/zeros in place.""" + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + wmin = w.amin(dim=-1, keepdim=True) + wmax = w.amax(dim=-1, keepdim=True) + s = ((wmax - wmin) / 255.0).clamp(min=1e-8) + z = torch.round(-wmin / s).clamp(0, 255) + scales.copy_(s.squeeze(-1).to(scales.dtype)) + zeros.copy_(z.squeeze(-1).to(zeros.dtype)) + q = torch.clamp(torch.round(w / s + z), 0, 255).to(torch.int32) + return q.reshape(E, N, K).to(torch.uint8).contiguous() + + +def _dequant_int8_sym(packed_u8, scales, group_size): + # Reinterpret uint8 bytes as signed int8. + q = packed_u8.view(torch.int8).to(scales.dtype) + E, N, K = q.shape + q = q.reshape(E, N, K // group_size, group_size) + return (q * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _dequant_int8_asym(packed_u8, scales, zeros, group_size): + q = packed_u8.to(torch.int32).to(scales.dtype) + E, N, K = q.shape + q = q.reshape(E, N, K // group_size, group_size) + deq = (q - zeros.to(scales.dtype).unsqueeze(-1)) * scales.unsqueeze(-1) + return deq.reshape(E, N, K) + + +def _pack_int2_sym(w_float, scales, group_size): + """Quantize [E, N, K] fp -> packed int2 (signed [-2, 1]); shape [E, N, K/4]. + + Packing: byte = q0 | (q1<<2) | (q2<<4) | (q3<<6), where the j-th 2-bit + field corresponds to K index 4*i + j. + """ + E, N, K = w_float.shape + assert K % 4 == 0, "K must be a multiple of 4 for int2 packing" + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + # Symmetric int2 has signed range [-2, 1] (i.e. clip at 2 and -2 but -2 inclusive). + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + s = (absmax / 2.0).squeeze(-1).to(scales.dtype) + scales.copy_(s) + q = torch.clamp(torch.round(w / s.to(w.dtype).unsqueeze(-1)), -2, 1).to(torch.int32) + q = q.reshape(E, N, K) + # Pack 4 values per byte. + q0 = q[..., 0::4] & 0x3 + q1 = q[..., 1::4] & 0x3 + q2 = q[..., 2::4] & 0x3 + q3 = q[..., 3::4] & 0x3 + packed = (q0 | (q1 << 2) | (q2 << 4) | (q3 << 6)).to(torch.uint8) + return packed + + +def _pack_int2_asym(w_float, scales, zeros, group_size): + """Quantize [E, N, K] fp -> packed int2 (unsigned [0, 3]); shape [E, N, K/4].""" + E, N, K = w_float.shape + assert K % 4 == 0 + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + wmin = w.amin(dim=-1, keepdim=True) + wmax = w.amax(dim=-1, keepdim=True) + s = ((wmax - wmin) / 3.0).clamp(min=1e-8) + z = torch.round(-wmin / s).clamp(0, 3) + scales.copy_(s.squeeze(-1).to(scales.dtype)) + zeros.copy_(z.squeeze(-1).to(zeros.dtype)) + q = torch.clamp(torch.round(w / s + z), 0, 3).to(torch.int32) + q = q.reshape(E, N, K) + q0 = q[..., 0::4] & 0x3 + q1 = q[..., 1::4] & 0x3 + q2 = q[..., 2::4] & 0x3 + q3 = q[..., 3::4] & 0x3 + packed = (q0 | (q1 << 2) | (q2 << 4) | (q3 << 6)).to(torch.uint8) + return packed + + +def _dequant_int2_sym(packed, scales, group_size): + E, N, K_q = packed.shape + K = K_q * 4 + p = packed.to(torch.int32) + fields = torch.empty(E, N, K, dtype=torch.int32, device=packed.device) + fields[..., 0::4] = p & 0x3 + fields[..., 1::4] = (p >> 2) & 0x3 + fields[..., 2::4] = (p >> 4) & 0x3 + fields[..., 3::4] = (p >> 6) & 0x3 + # Sign-extend 2-bit (>=2 means negative). + fields = torch.where(fields >= 2, fields - 4, fields).to(scales.dtype) + fields = fields.reshape(E, N, K // group_size, group_size) + return (fields * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _dequant_int2_asym(packed, scales, zeros, group_size): + E, N, K_q = packed.shape + K = K_q * 4 + p = packed.to(torch.int32) + fields = torch.empty(E, N, K, dtype=torch.int32, device=packed.device) + fields[..., 0::4] = p & 0x3 + fields[..., 1::4] = (p >> 2) & 0x3 + fields[..., 2::4] = (p >> 4) & 0x3 + fields[..., 3::4] = (p >> 6) & 0x3 + fields = fields.to(scales.dtype) + fields = fields.reshape(E, N, K // group_size, group_size) + deq = (fields - zeros.to(scales.dtype).unsqueeze(-1)) * scales.unsqueeze(-1) + return deq.reshape(E, N, K) + + +def _pack_fp8(w_float, scales, group_size, fp8_dtype): + """Quantize [E, N, K] fp -> FP8 (e4m3fn/e5m2) with per-group scale. + + Scales are filled in-place; ``w_float`` is divided by the scale, then cast + to ``fp8_dtype`` (rounding handled by torch). Returns the FP8 tensor. + """ + assert fp8_dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + E, N, K = w_float.shape + G = K // group_size + w = w_float.reshape(E, N, G, group_size) + absmax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + # Pick fp8 max representable magnitude for the chosen format. + if fp8_dtype == torch.float8_e4m3fn: + fp8_max = 448.0 + else: # e5m2 + fp8_max = 57344.0 + s = (absmax / fp8_max).squeeze(-1).to(scales.dtype) + scales.copy_(s) + scaled = (w / s.to(w.dtype).unsqueeze(-1)).reshape(E, N, K) + # Clamp to fp8 representable range before cast to avoid Inf/NaN. + scaled = scaled.clamp(-fp8_max, fp8_max) + return scaled.to(fp8_dtype).contiguous() + + +def _dequant_fp8(packed_fp8, scales, group_size, out_dtype): + """Reference dequant: cast fp8 -> out_dtype and multiply per-group scale.""" + E, N, K = packed_fp8.shape + w = packed_fp8.to(out_dtype).reshape(E, N, K // group_size, group_size) + return (w * scales.unsqueeze(-1)).reshape(E, N, K) + + +def _moe_decode_reference(activations, dequant_weights, num_tokens_per_expert): + """Reference: each token is matmul'd against its routed expert's weights.""" + total_tokens, K = activations.shape + E, N, _ = dequant_weights.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] # [n_tokens, K] + w = dequant_weights[e] # [N, K] + out[offset : offset + n_tokens] = a @ w.T + offset += n_tokens + return out + + +@pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") +@pytest.mark.skipif(not has_moe_gemm_decode(), reason="MoE decode GEMV kernel not built (need ARK_SYCL_TLA=ON)") +class TestMoEGemmDecode: + """Unit tests for the MoE decode GEMV kernel. + + The activations layout follows the same convention as ``moe_gemm``: the + upper layer has already gathered/sorted tokens per expert, so the kernel + only needs ``num_tokens_per_expert`` (no top-k indices). + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_fp_basic(self, dtype): + num_experts = 4 + # One token per expert with one zero-token expert -> typical top-k=3 + # decode pattern after gather. + tokens_per_expert = [1, 0, 1, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 128 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + weights = torch.randn(num_experts, N, K, dtype=dtype, device="xpu") * 0.1 + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode(activations, weights, num_tokens_per_expert, weight_bits=16) + + ref = _moe_decode_reference(activations, weights, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + assert out.dtype == dtype + torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int4_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=4, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int4_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=4, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + def test_decode_validation_errors(self): + """Sanity-check that Python-side validation catches misuse.""" + num_experts = 2 + activations = torch.randn(2, 128, dtype=torch.float16, device="xpu") + num_tokens_per_expert = torch.tensor([1, 1], dtype=torch.int32, device="xpu") + + # N must be a multiple of 16 + bad_weights = torch.randn(num_experts, 17, 128, dtype=torch.float16, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode(activations, bad_weights, num_tokens_per_expert, weight_bits=16) + + # weight_bits=4 requires uint8 packed weights + bad_packed = torch.randn(num_experts, 64, 64, dtype=torch.float16, device="xpu") + scales = torch.empty(num_experts, 64, 1, dtype=torch.float16, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode( + activations, bad_packed, num_tokens_per_expert, scales=scales, weight_bits=4, group_size=128 + ) + + # asym=True without zeros must error + packed = torch.zeros(num_experts, 64, 64, dtype=torch.uint8, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=4, + group_size=128, + asym=True, + ) + + # FP8 + asym is rejected + fp8_w = torch.zeros(num_experts, 64, 128, dtype=torch.float8_e4m3fn, device="xpu") + zeros = torch.empty(num_experts, 64, 1, dtype=torch.float16, device="xpu") + with pytest.raises(ValueError): + ark.moe_gemm_decode( + activations, + fp8_w, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + group_size=128, + asym=True, + ) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int8_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=8, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int8_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=8, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=5e-2, atol=5e-2) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_int2_sym(self, dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 1, 0, 2] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_sym(w_float, scales, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + weight_bits=2, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + # Int2 has much higher quant error; relax tolerance vs int4. + torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_decode_int2_asym(self, dtype): + num_experts = 4 + group_size = 128 + tokens_per_expert = [0, 1, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + zeros = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + zeros=zeros, + weight_bits=2, + group_size=group_size, + asym=True, + ) + + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize( + "fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2] + ) + @pytest.mark.parametrize("group_size", [32, 128]) + def test_decode_fp8(self, dtype, fp8_dtype, group_size): + num_experts = 4 + tokens_per_expert = [1, 0, 2, 1] + total_tokens = sum(tokens_per_expert) + N, K = 256, 256 + + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(num_experts, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(num_experts, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + num_tokens_per_expert = torch.tensor(tokens_per_expert, dtype=torch.int32, device="xpu") + + out = ark.moe_gemm_decode( + activations, + packed, + num_tokens_per_expert, + scales=scales, + group_size=group_size, + asym=False, + ) + + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ref = _moe_decode_reference(activations, dequant, num_tokens_per_expert) + assert out.shape == (total_tokens, N) + # E5M2 has only 2 mantissa bits -> coarser; relax tolerance for both. + rtol = 1e-1 if fp8_dtype == torch.float8_e5m2 else 5e-2 + atol = 1e-1 if fp8_dtype == torch.float8_e5m2 else 5e-2 + torch.testing.assert_close(out, ref, rtol=rtol, atol=atol) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/auto_round_extension/ark/test/test_moe_decode_perf.py b/auto_round_extension/ark/test/test_moe_decode_perf.py new file mode 100644 index 000000000..44bbb7fd0 --- /dev/null +++ b/auto_round_extension/ark/test/test_moe_decode_perf.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Performance comparison: ``ark.moe_gemm_decode`` vs default XPU MoE. + +The "default XPU MoE implementation" used as the baseline is the standard +per-expert PyTorch matmul loop (the same approach ``_moe_decode_reference`` +uses in ``test_moe.py``). For quantized formats the weights are dequantized +once up-front (outside the timed region), so the baseline measures only the +matmul cost on XPU. This is what models fall back to when no fused decode +kernel is available. + +How to run:: + + pytest -v -s auto_round_extension/ark/test/test_moe_decode_perf.py + +The ``-s`` flag is required to see the printed timing tables. +""" + +import auto_round_kernel +import pytest +import torch + +# Reuse the existing pack/dequant helpers from the correctness tests so that +# the benchmarked path matches what the unit tests already validate. +from test_moe import ( # noqa: E402 + _dequant_fp8, + _dequant_int2_asym, + _dequant_int2_sym, + _dequant_int4_asym, + _dequant_int4_sym, + _dequant_int8_asym, + _dequant_int8_sym, + _pack_fp8, + _pack_int2_asym, + _pack_int2_sym, + _pack_int4_asym, + _pack_int4_sym, + _pack_int8_asym, + _pack_int8_sym, +) + +ark = auto_round_kernel.ARK() + + +# --------------------------------------------------------------------------- +# Skip reasons. +# +# The original test_moe.py collapses several different failure modes into one +# generic "kernel not built" message which makes it impossible to tell whether +# the build is missing the kernel or whether XPU itself didn't come up. The +# helpers below distinguish those cases so a skipped run is actually +# actionable. +# --------------------------------------------------------------------------- + + +def _xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _xpu_skip_reason() -> str: + if not hasattr(torch, "xpu"): + return "torch has no xpu submodule (need an Intel XPU build of torch)" + if not torch.xpu.is_available(): + return "torch.xpu.is_available() == False (no XPU device or driver visible)" + return "" + + +def _decode_skip_reason() -> str: + """Return non-empty string if the decode kernel can't be exercised.""" + reason = _xpu_skip_reason() + if reason: + return reason + if ark.xpu_lib is None: + return ( + "ark.xpu_lib is None -- the XPU extension module " + "(auto_round_kernel_xpu) failed to import; check that auto_round_kernel " + "was installed for THIS Python env with XPU support enabled" + ) + if not hasattr(ark.xpu_lib, "moe_gemm_decode"): + return ( + "ark.xpu_lib loaded but has no moe_gemm_decode symbol -- " + "rebuild with ARK_SYCL_TLA=ON to compile the MoE decode GEMV kernel" + ) + return "" + + +_DECODE_SKIP = _decode_skip_reason() + +# Surface diagnostics on collection so the user always sees why the suite +# would skip, without having to add extra flags. +print( + "[moe-decode-perf] xpu_available=%s xpu_lib=%s has_moe_gemm_decode=%s" + % ( + _xpu_available(), + "loaded" if ark.xpu_lib is not None else "None", + hasattr(ark.xpu_lib, "moe_gemm_decode") if ark.xpu_lib is not None else False, + ) +) +if _DECODE_SKIP: + print("[moe-decode-perf] suite will SKIP. reason: %s" % _DECODE_SKIP) + + +# --------------------------------------------------------------------------- +# Timing utilities. +# --------------------------------------------------------------------------- + +# Warmup / iteration counts kept modest so the suite is still UT-shaped +# (finishes in seconds) but large enough for stable medians. +WARMUP = 5 +ITERS = 30 + + +def _xpu_time_ms(fn, warmup: int = WARMUP, iters: int = ITERS) -> float: + """Time ``fn`` on XPU using device events; returns median ms per call.""" + for _ in range(warmup): + fn() + torch.xpu.synchronize() + + timings = [] + for _ in range(iters): + start = torch.xpu.Event(enable_timing=True) + end = torch.xpu.Event(enable_timing=True) + start.record() + fn() + end.record() + end.synchronize() + timings.append(start.elapsed_time(end)) + timings.sort() + return timings[len(timings) // 2] + + +def _default_moe_decode(activations, dequant_weights, num_tokens_per_expert): + """Default XPU MoE decode baseline: per-expert torch matmul loop. + + This mirrors the path a model would take when no fused MoE decode kernel + is available: gather/sort tokens by expert (done by the caller), then + iterate over experts and do a plain ``A @ W.T`` on each slice. + """ + total_tokens, _ = activations.shape + E, N, _ = dequant_weights.shape + out = torch.empty(total_tokens, N, dtype=activations.dtype, device=activations.device) + offset = 0 + for e in range(E): + n_tokens = int(num_tokens_per_expert[e].item()) + if n_tokens == 0: + continue + a = activations[offset : offset + n_tokens] + out[offset : offset + n_tokens] = a @ dequant_weights[e].T + offset += n_tokens + return out + + +# --------------------------------------------------------------------------- +# Shape matrix. +# +# Picked to cover small / medium / large MoE expert FFNs (Mixtral-style +# 4096x14336 down-projection is the upper bound; smaller shapes catch +# launch-overhead-dominated cases). ``tokens_per_expert`` follows the +# expected decode-phase pattern (top-k routing with batch=1: each active +# expert sees one token). +# --------------------------------------------------------------------------- + +DECODE_SHAPES = [ + # (label, num_experts, tokens_per_expert, N, K) + ("small E=4 ", 4, [1, 0, 1, 1], 1024, 1024), + ("medium E=8 ", 8, [1, 1, 0, 1, 1, 0, 1, 1], 2048, 2048), + ("large E=8 ", 8, [1, 0, 1, 1, 0, 1, 1, 1], 4096, 4096), + ("ffn-up E=8 ", 8, [1, 1, 0, 1, 1, 1, 0, 1], 14336, 4096), + ("ffn-dn E=8 ", 8, [1, 1, 0, 1, 1, 1, 0, 1], 4096, 14336), +] + + +def _print_header(title: str) -> None: + print() + print("=" * 96) + print(title) + print( + f"{'shape':<14}{'N':>7}{'K':>7}{'tokens':>8}" + f"{'baseline(ms)':>16}{'ark(ms)':>14}{'speedup':>12}" + ) + print("-" * 96) + + +def _print_row(label, N, K, total_tokens, base_ms, ark_ms): + speedup = base_ms / ark_ms if ark_ms > 0 else float("nan") + print( + f"{label:<14}{N:>7}{K:>7}{total_tokens:>8}" + f"{base_ms:>16.4f}{ark_ms:>14.4f}{speedup:>11.2f}x" + ) + + +# --------------------------------------------------------------------------- +# Benchmark cases. +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(bool(_DECODE_SKIP), reason=_DECODE_SKIP or "ok") +class TestMoEGemmDecodePerf: + """Median XPU-event timings of ``moe_gemm_decode`` vs per-expert ``A @ W.T``. + + The baseline uses *already-dequantized* weights, so quantized cases only + pay the matmul cost in the timed region (no per-iteration dequant). This + is the most favorable apples-to-apples comparison for the baseline; the + fused decode kernel must beat that to be worth using. + """ + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_perf_fp(self, dtype): + _print_header(f"FP weights ({str(dtype).split('.')[-1]}) -- ark.moe_gemm_decode vs per-expert A @ W.T") + for label, E, tpe, N, K in DECODE_SHAPES: + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + weights = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, weights, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode(activations, weights, ntpe, weight_bits=16) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int4(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT4 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int4_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int4_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int4_sym(w_float, scales, group_size) + dequant = _dequant_int4_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, zeros=zeros, + weight_bits=4, group_size=group_size, asym=asym, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int8(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT8 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int8_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int8_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int8_sym(w_float, scales, group_size) + dequant = _dequant_int8_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, zeros=zeros, + weight_bits=8, group_size=group_size, asym=asym, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("asym", [False, True]) + def test_perf_int2(self, dtype, asym): + group_size = 128 + kind = "asym" if asym else "sym" + _print_header( + f"INT2 {kind} (group_size={group_size}, act={str(dtype).split('.')[-1]}) " + f"-- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0 or K % 4 != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + if asym: + zeros = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_int2_asym(w_float, scales, zeros, group_size) + dequant = _dequant_int2_asym(packed, scales, zeros, group_size).to(dtype) + else: + zeros = None + packed = _pack_int2_sym(w_float, scales, group_size) + dequant = _dequant_int2_sym(packed, scales, group_size).to(dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, zeros=zeros, + weight_bits=2, group_size=group_size, asym=asym, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) + def test_perf_fp8(self, dtype, fp8_dtype): + group_size = 128 + _print_header( + f"FP8 {str(fp8_dtype).split('.')[-1]} (group_size={group_size}, " + f"act={str(dtype).split('.')[-1]}) -- ark.moe_gemm_decode vs dequant + per-expert A @ W.T" + ) + for label, E, tpe, N, K in DECODE_SHAPES: + if K % group_size != 0: + continue + total_tokens = sum(tpe) + activations = torch.randn(total_tokens, K, dtype=dtype, device="xpu") + w_float = (torch.randn(E, N, K, dtype=torch.float32, device="xpu") * 0.1).to(dtype) + scales = torch.empty(E, N, K // group_size, dtype=dtype, device="xpu") + packed = _pack_fp8(w_float, scales, group_size, fp8_dtype) + dequant = _dequant_fp8(packed, scales, group_size, dtype) + ntpe = torch.tensor(tpe, dtype=torch.int32, device="xpu") + + base_ms = _xpu_time_ms(lambda: _default_moe_decode(activations, dequant, ntpe)) + ark_ms = _xpu_time_ms( + lambda: ark.moe_gemm_decode( + activations, packed, ntpe, + scales=scales, + group_size=group_size, asym=False, + ) + ) + _print_row(label, N, K, total_tokens, base_ms, ark_ms) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])