
Commit 5624c99

CANN: implement quantized MUL_MAT_ID for MoE models
Implement ggml_cann_mul_mat_id_quant function to support quantized matrix multiplication for Mixture of Experts (MoE) architectures on the CANN backend.

Key features:
- Support Q4_0 and Q8_0 quantized weight formats
- Use IndexSelect to dynamically route expert-specific weights based on indices
- Leverage WeightQuantBatchMatmulV2 for efficient quantized computation
- Handle automatic F16 type conversion for hardware compatibility
- Support both per-expert and broadcast input modes

Implementation details:
- Extract expert weights and scales using CANN IndexSelect operation
- Process each batch and expert combination independently
- Create proper tensor views with correct stride for matmul operations
- Automatic input/output type casting to/from F16 as needed

Testing: All test cases passed for supported types (F32, F16, Q4_0, Q8_0).
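For context, the sketch below (a standalone illustration with hypothetical names, not code from this commit) spells out what MUL_MAT_ID computes in plain float arithmetic, using the same shape letters as the diff: dst is [M, K, N], src0 is [D, M, A] with A experts, src1 is [D, B, N] where B = K or B = 1, and ids is [K, N].

#include <cstdint>
#include <vector>

// Reference semantics of MUL_MAT_ID in plain float math (hypothetical helper).
// For each token n and selected-expert slot k, multiply the input row by the
// weight matrix of expert ids[k, n]; when B == 1 the input row is broadcast.
static void mul_mat_id_ref(const std::vector<float> & src0,   // weights, D*M*A
                           const std::vector<float> & src1,   // inputs,  D*B*N
                           const std::vector<int32_t> & ids,  // indices, K*N
                           std::vector<float> & dst,          // outputs, M*K*N
                           int64_t D, int64_t M, int64_t B, int64_t K, int64_t N) {
    for (int64_t n = 0; n < N; n++) {
        for (int64_t k = 0; k < K; k++) {
            const int32_t e = ids[n * K + k];    // expert chosen for this slot
            const int64_t b = (B == 1) ? 0 : k;  // broadcast vs per-expert input
            const float * w = src0.data() + (int64_t) e * M * D;
            const float * x = src1.data() + (n * B + b) * D;
            float *       y = dst.data() + (n * K + k) * M;
            for (int64_t m = 0; m < M; m++) {
                float acc = 0.0f;
                for (int64_t d = 0; d < D; d++) {
                    acc += w[m * D + d] * x[d];
                }
                y[m] = acc;
            }
        }
    }
}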
1 parent: 25f40ca

1 file changed: ggml/src/ggml-cann/aclnn_ops.cpp (204 additions, 111 deletions)
@@ -3286,130 +3286,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor
 }

 /**
- * @brief Performs expert-specific matrix multiplication (MoE) with
- * quantized precision using the CANN backend.
- *
- * This function executes a matrix multiplication operation tailored for
- * Mixture of Experts (MoE) models, where the input tensor is multiplied
- * with expert-specific quantized weight matrices. It leverages the CANN
- * backend to perform efficient low-precision computations and stores the
- * quantized result in the destination tensor `dst`.
- *
- * Quantization techniques reduce memory footprint and improve performance
- * by using lower-bit representations (e.g., int8) instead of floating-point.
- * This function is designed to work with such formats and may incorporate
- * optimizations like identity-based fast paths or routing masks for sparse
- * expert selection.
- *
- * @param ctx The context for executing CANN backend operations.
- * @param dst The destination tensor where the quantized MoE multiplication result
- * will be stored.
- *
- * @note This function assumes quantized data types and is designed for
- * MoE architectures with potential sparse expert routing.
+ * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE)
+ * models using the CANN backend.
+ *
+ * This function implements MUL_MAT_ID operation for quantized weight matrices
+ * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on
+ * the provided expert indices, and computes matrix multiplication using CANN's
+ * WeightQuantBatchMatmulV2 operator.
+ *
+ * The function performs the following steps:
+ * 1. Converts input/output tensors to F16 format if necessary
+ * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices
+ * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2
+ * 4. Converts output back to the target type if needed
+ *
+ * Tensor shapes:
+ * - dst:  [M, K, N, 1] - output tensor
+ * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0)
+ * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast)
+ * - ids:  [K, N] - expert indices for routing
+ *
+ * @param ctx The CANN backend context for operation execution.
+ * @param dst The destination tensor where the multiplication result will be stored.
+ *
+ * @note Only Q4_0 and Q8_0 quantization formats are supported.
+ * @note The function handles automatic type conversion to/from F16 as needed by the hardware.
  */
 static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
-    // TODO: Use aclnnGroupedMatMul
-    //dst   [M, K, N, 1]
-    ggml_tensor * src0 = dst->src[0];  //src0  [D, M, A, 1]
-    ggml_tensor * src1 = dst->src[1];  //src1  [D, B, N, 1], B = K or B = 1
-    ggml_tensor * ids  = dst->src[2];  //ids   [K, N]
+    // dst:  [M, K, N, 1]
+    // src0: [D, M, A, 1] - quantized weights
+    // src1: [D, B, N, 1] - input activations, B = K or B = 1
+    // ids:  [K, N] - expert indices
+    ggml_tensor * src0 = dst->src[0];
+    ggml_tensor * src1 = dst->src[1];
+    ggml_tensor * ids  = dst->src[2];

-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
+    GGML_ASSERT(dst->ne[3] == 1);
+    GGML_ASSERT(src1->ne[2] == ids->ne[1]);
+
+    const int64_t n_batches = ids->ne[1];
+    const int64_t n_select_experts = ids->ne[0];
+    const enum ggml_type type = src0->type;
+
+    const int32_t group_size = QK8_0;  // Both Q4_0 and Q8_0 use group size of 32
+    GGML_ASSERT(group_size == QK4_0);
+
+    // Calculate element size for quantized weights
+    const float weight_elem_size =
+        (type == GGML_TYPE_Q4_0) ? 0.5f :
+        (type == GGML_TYPE_Q8_0) ? 1.0f :
+        (GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f);
+
+    // Calculate scale offset in memory
+    const size_t weight_size = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size;
+    const size_t scale_elem_size = sizeof(uint16_t);
+    char * scale_data = (char *) src0->data + weight_size;
+
+    // Allocate buffers for selected expert weights and scales
+    const size_t selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size;
+    ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size);
+    void * selected_weight_buffer = selected_weight_alloc.get();
+
+    const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size;
+    ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size);
+    void * selected_scale_buffer = selected_scale_alloc.get();
+
+    // Helper lambda to allocate and cast tensor to F16 if needed
+    constexpr size_t f16_elem_size = sizeof(uint16_t);
+    auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator,
+                                  bool need_cast = false) -> void * {
+        if (tensor->type == GGML_TYPE_F16) {
+            return tensor->data;
+        }

-    // copy index from npu to cpu
-    int64_t n_as  = ne02;        // A
-    int64_t n_ids = ids->ne[0];  // K
+        size_t total_size = f16_elem_size;
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            total_size *= tensor->ne[i];
+        }
+        void * buffer = allocator.alloc(total_size);

-    std::vector<char> ids_host(ggml_nbytes(ids));
-    ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids),
-                               ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
-    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+        if (need_cast == false) {
+            return buffer;
+        }

-    char * src0_original = (char *) src0->data;
-    char * src1_original = (char *) src1->data;
-    char * dst_original  = (char *) dst->data;
+        int64_t ne[GGML_MAX_DIMS];
+        size_t  nb[GGML_MAX_DIMS] = { f16_elem_size };
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            ne[i] = tensor->ne[i];
+            if (i > 0) {
+                nb[i] = nb[i - 1] * ne[i - 1];
+            }
+        }

-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row  = *dst;
+        acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor);
+        acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
+        aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16);

-    const enum ggml_type type = dst->src[0]->type;
-    float weight_elem_size;
-    if (type == GGML_TYPE_Q4_0) {
-        weight_elem_size = float(sizeof(uint8_t)) / 2;
-    } else if (type == GGML_TYPE_Q8_0) {
-        weight_elem_size = float(sizeof(uint8_t));
-    } else {
-        GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
-    }
+        return buffer;
+    };

-    // src0_row [D, M, 1, 1] weight without permute
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[0] = weight_elem_size;
-    src0_row.nb[1] = weight_elem_size * ne00;
-    src0_row.nb[2] = weight_elem_size * ne00;
-    src0_row.nb[3] = weight_elem_size * ne00;
-    size_t weight_stride = ne00 * ne01 * weight_elem_size;
-    size_t weight_size   = weight_stride * ne02 * ne03;
+    // Prepare input and output buffers
+    ggml_cann_pool_alloc input_alloc(ctx.pool());
+    void * input_buffer = prepare_f16_buffer(src1, input_alloc, true);

-    // scale [D, M, 1, 1] -> scale && permute
-    size_t scale_elem_size = sizeof(uint16_t);
-    size_t scale_stride    = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
+    ggml_cann_pool_alloc output_alloc(ctx.pool());
+    void * output_buffer = prepare_f16_buffer(dst, output_alloc, false);
+
+    // Process each batch
+    for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
+        // Create index tensor for current batch
+        const size_t index_offset = batch_idx * ids->nb[1];
+        acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset);
+
+        // Select quantized weights using expert indices
+        // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte
+        const int64_t weight_d = (type == GGML_TYPE_Q4_0) ? src0->ne[0] / 2 : src0->ne[0];
+        const int64_t weight_m = src0->ne[1];
+        const int64_t weight_n_experts = src0->ne[2];
+
+        int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts };
+        size_t  weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) };
+
+        acl_tensor_ptr all_weights =
+            ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3);
+
+        int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts };
+        size_t  selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t),
+                                          weight_d * weight_m * sizeof(int8_t) };
+
+        acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t),
+                                                                  selected_weight_ne, selected_weight_nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get());
+
+        // Select scales using the same expert indices
+        const int64_t scale_d = src0->ne[0] / group_size;
+        int64_t scale_ne[3] = { scale_d, weight_m, weight_n_experts };
+        size_t  scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size };
+
+        acl_tensor_ptr all_scales =
+            ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3);
+
+        int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts };
+        size_t  selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size,
+                                         scale_d * weight_m * scale_elem_size };
+
+        acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size,
+                                                                 selected_scale_ne, selected_scale_nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get());
+
+        // Process each expert for current batch
+        // IndexSelect output layout: [D, M, K] in contiguous format
+        // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride
+        for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) {
+            // Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input
+            const size_t input_offset =
+                (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size;
+            const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size;
+
+            // Create weight view for current expert: [D, M, K] -> [M, D]
+            int64_t weight_view_ne[2] = { weight_m, src0->ne[0] };
+            float   weight_view_nb[2] = { src0->ne[0] * weight_elem_size, weight_elem_size };
+            const size_t weight_view_offset = expert_idx * selected_weight_nb[2];

-    // src1_row [D, 1, 1, 1] -> input
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
-
-    // dst_row [M, 1, 1, 1] -> out
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
-
-    //create weight for one row
-    ggml_cann_pool_alloc weight_allocator(ctx.pool());
-    void * weight_buffer = weight_allocator.alloc(nb02);
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            // expert index
-            int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]);
-            GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
-            // If B = 1 (broadcast), always use 0; otherwise, use id.
-            int64_t i11 = (ne11 == 1 ? 0 : id);
-            int64_t i12 = iid1;
-
-            int64_t i1 = id;
-            int64_t i2 = i12;
-
-            void * src0_tmp_ptr  = src0_original + i02 * weight_stride;
-            void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride;
-            void * src1_tmp_ptr  = src1_original + i11 * nb11 + i12 * nb12;
-            void * dst_tmp_ptr   = dst_original + i1 * nb1 + i2 * nb2;
-
-            // mem cpy
-            ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-            void * scale_buffer = (char *) weight_buffer + weight_stride;
-            ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-
-            src0_row.data  = weight_buffer;
-            src1_row.data  = src1_tmp_ptr;
-            dst_row.data   = dst_tmp_ptr;
-            dst_row.src[0] = &src0_row;
-            dst_row.src[1] = &src1_row;
-
-            ggml_cann_mul_mat(ctx, &dst_row);
+            acl_tensor_ptr weight_view =
+                ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size,
+                                        weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset);
+
+            // Create scale view for current expert: [D, M, K] -> [M, D]
+            int64_t scale_view_ne[2] = { weight_m, scale_d };
+            size_t  scale_view_nb[2] = { selected_scale_nb[1], selected_scale_nb[0] };
+            const size_t scale_view_offset = expert_idx * selected_scale_nb[2];
+
+            acl_tensor_ptr scale_view =
+                ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne,
+                                        scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset);
+
+            // Create input activation tensor [D, 1]
+            int64_t input_ne[2] = { src1->ne[0], 1 };
+            size_t  input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size };
+
+            acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne,
+                                                                  input_nb, 2, ACL_FORMAT_ND, input_offset);
+
+            // Create output tensor [M, 1]
+            int64_t output_ne[2] = { dst->ne[0], 1 };
+            size_t  output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size };
+
+            acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne,
+                                                                   output_nb, 2, ACL_FORMAT_ND, output_offset);
+
+            // Perform quantized matrix multiplication
+            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(),
+                                    scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size,
+                                    output_tensor.get());
         }
     }
-    return;
+
+    // Cast output back to original type if we used a temporary F16 buffer
+    if (dst->type != GGML_TYPE_F16) {
+        int64_t ne[GGML_MAX_DIMS];
+        size_t  nb[GGML_MAX_DIMS] = { f16_elem_size };
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            ne[i] = dst->ne[i];
+            if (i > 0) {
+                nb[i] = nb[i - 1] * ne[i - 1];
+            }
+        }
+
+        acl_tensor_ptr f16_output =
+            ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
+        acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst);
+
+        aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type));
+    }
 }

 void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
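One thing the diff leans on without stating: `scale_data = (char *) src0->data + weight_size` assumes the CANN buffer stores all packed weight bytes first, with the per-group F16 scales appended directly after the weights. A standalone sketch of that offset arithmetic, with made-up dimensions, is below.

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    // Made-up dimensions for illustration only.
    const size_t D = 4096;        // elements per weight row (src0->ne[0])
    const size_t M = 1024;        // rows per expert         (src0->ne[1])
    const size_t A = 8;           // number of experts       (src0->ne[2])
    const size_t group_size = 32; // QK4_0 == QK8_0

    // Q4_0 packs two 4-bit values per byte; Q8_0 stores one byte per value.
    const size_t q4_weight_bytes = D * M * A / 2;
    const size_t q8_weight_bytes = D * M * A;

    // One F16 scale (2 bytes) per group of 32 values, appended after the weights.
    const size_t scale_bytes = (D / group_size) * M * A * sizeof(uint16_t);

    std::printf("Q4_0: %zu weight bytes, scales begin at offset %zu (%zu bytes)\n",
                q4_weight_bytes, q4_weight_bytes, scale_bytes);
    std::printf("Q8_0: %zu weight bytes, scales begin at offset %zu (%zu bytes)\n",
                q8_weight_bytes, q8_weight_bytes, scale_bytes);
    return 0;
}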

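The per-expert weight view is likewise pure stride relabeling: IndexSelect leaves each selected expert's weights as one contiguous [D, M] block (D innermost), which the matmul then reads as an [M, D] row-major matrix without any transpose or copy. A toy, illustrative-only demonstration of that index arithmetic:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t D = 4, M = 3, K = 2;   // toy sizes
    std::vector<int8_t> buf(D * M * K);  // IndexSelect-style output: [D, M, K], D innermost
    for (size_t i = 0; i < buf.size(); i++) {
        buf[i] = (int8_t) i;
    }

    const int64_t k = 1;                             // expert slot to inspect
    const int8_t * expert = buf.data() + k * D * M;  // offset = expert_idx * block size

    // Read the same bytes as an [M, D] row-major matrix: row stride D, column stride 1.
    for (int64_t m = 0; m < M; m++) {
        for (int64_t d = 0; d < D; d++) {
            std::printf("%4d", expert[m * D + d]);
        }
        std::printf("\n");
    }
    return 0;
}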