@@ -3286,130 +3286,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor
 }
 
 /**
- * @brief Performs expert-specific matrix multiplication (MoE) with
- * quantized precision using the CANN backend.
- *
- * This function executes a matrix multiplication operation tailored for
- * Mixture of Experts (MoE) models, where the input tensor is multiplied
- * with expert-specific quantized weight matrices. It leverages the CANN
- * backend to perform efficient low-precision computations and stores the
- * quantized result in the destination tensor `dst`.
- *
- * Quantization techniques reduce memory footprint and improve performance
- * by using lower-bit representations (e.g., int8) instead of floating-point.
- * This function is designed to work with such formats and may incorporate
- * optimizations like identity-based fast paths or routing masks for sparse
- * expert selection.
- *
- * @param ctx The context for executing CANN backend operations.
- * @param dst The destination tensor where the quantized MoE multiplication result
- * will be stored.
- *
- * @note This function assumes quantized data types and is designed for
- * MoE architectures with potential sparse expert routing.
+ * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE)
+ * models using the CANN backend.
+ *
+ * This function implements the MUL_MAT_ID operation for quantized weight matrices
+ * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on
+ * the provided expert indices and computes the matrix multiplication using CANN's
+ * WeightQuantBatchMatmulV2 operator.
+ *
+ * The function performs the following steps:
+ * 1. Converts input/output tensors to F16 format if necessary
+ * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices
+ * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2
+ * 4. Converts the output back to the target type if needed
+ *
+ * Tensor shapes:
+ * - dst:  [M, K, N, 1] - output tensor
+ * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0)
+ * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast)
+ * - ids:  [K, N]       - expert indices for routing
+ *
+ * @param ctx The CANN backend context for operation execution.
+ * @param dst The destination tensor where the multiplication result will be stored.
+ *
+ * @note Only the Q4_0 and Q8_0 quantization formats are supported.
+ * @note The function handles automatic type conversion to/from F16 as needed by the hardware.
 */
 static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
-    // TODO: Use aclnnGroupedMatMul
-    // dst   [M, K, N, 1]
-    ggml_tensor * src0 = dst->src[0];  // src0 [D, M, A, 1]
-    ggml_tensor * src1 = dst->src[1];  // src1 [D, B, N, 1], B = K or B = 1
-    ggml_tensor * ids  = dst->src[2];  // ids  [K, N]
+    // dst:  [M, K, N, 1]
+    // src0: [D, M, A, 1] - quantized weights
+    // src1: [D, B, N, 1] - input activations, B = K or B = 1
+    // ids:  [K, N]       - expert indices
+    ggml_tensor * src0 = dst->src[0];
+    ggml_tensor * src1 = dst->src[1];
+    ggml_tensor * ids  = dst->src[2];
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
+    GGML_ASSERT(dst->ne[3] == 1);
+    GGML_ASSERT(src1->ne[2] == ids->ne[1]);
+
+    const int64_t n_batches        = ids->ne[1];
+    const int64_t n_select_experts = ids->ne[0];
+    const enum ggml_type type      = src0->type;
+
+    const int32_t group_size = QK8_0;  // Both Q4_0 and Q8_0 use group size of 32
+    GGML_ASSERT(group_size == QK4_0);
+
+    // Calculate element size for quantized weights
+    const float weight_elem_size =
+        (type == GGML_TYPE_Q4_0) ? 0.5f :
+        (type == GGML_TYPE_Q8_0) ? 1.0f :
+        (GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f);
+
+    // Calculate scale offset in memory
+    const size_t weight_size     = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size;
+    const size_t scale_elem_size = sizeof(uint16_t);
+    char * scale_data            = (char *) src0->data + weight_size;
+
+    // Allocate buffers for selected expert weights and scales
+    const size_t selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size;
+    ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size);
+    void * selected_weight_buffer = selected_weight_alloc.get();
+
+    const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size;
+    ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size);
+    void * selected_scale_buffer = selected_scale_alloc.get();
+
+    // Helper lambda to allocate and cast tensor to F16 if needed
+    constexpr size_t f16_elem_size = sizeof(uint16_t);
+    auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator,
+                                  bool need_cast = false) -> void * {
+        if (tensor->type == GGML_TYPE_F16) {
+            return tensor->data;
+        }
 
-    // copy index from npu to cpu
-    int64_t n_as  = ne02;        // A
-    int64_t n_ids = ids->ne[0];  // K
+        size_t total_size = f16_elem_size;
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            total_size *= tensor->ne[i];
+        }
+        void * buffer = allocator.alloc(total_size);
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids),
-                               ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
-    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+        if (need_cast == false) {
+            return buffer;
+        }
 
-    char * src0_original = (char *) src0->data;
-    char * src1_original = (char *) src1->data;
-    char * dst_original  = (char *) dst->data;
+        int64_t ne[GGML_MAX_DIMS];
+        size_t  nb[GGML_MAX_DIMS] = { f16_elem_size };
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            ne[i] = tensor->ne[i];
+            if (i > 0) {
+                nb[i] = nb[i - 1] * ne[i - 1];
+            }
+        }
 
-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row  = *dst;
+        acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor);
+        acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
+        aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16);
 
-    const enum ggml_type type = dst->src[0]->type;
-    float weight_elem_size;
-    if (type == GGML_TYPE_Q4_0) {
-        weight_elem_size = float(sizeof(uint8_t)) / 2;
-    } else if (type == GGML_TYPE_Q8_0) {
-        weight_elem_size = float(sizeof(uint8_t));
-    } else {
-        GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0");
-    }
+        return buffer;
+    };
 
-    // src0_row [D, M, 1, 1] weight without permute
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[0] = weight_elem_size;
-    src0_row.nb[1] = weight_elem_size * ne00;
-    src0_row.nb[2] = weight_elem_size * ne00;
-    src0_row.nb[3] = weight_elem_size * ne00;
-    size_t weight_stride = ne00 * ne01 * weight_elem_size;
-    size_t weight_size   = weight_stride * ne02 * ne03;
+    // Prepare input and output buffers
+    ggml_cann_pool_alloc input_alloc(ctx.pool());
+    void * input_buffer = prepare_f16_buffer(src1, input_alloc, true);
 
-    // scale [D, M, 1, 1] -> scale && permute
-    size_t scale_elem_size = sizeof(uint16_t);
-    size_t scale_stride    = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
+    ggml_cann_pool_alloc output_alloc(ctx.pool());
+    void * output_buffer = prepare_f16_buffer(dst, output_alloc, false);
+
+    // Process each batch
+    for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
+        // Create index tensor for current batch
+        const size_t index_offset = batch_idx * ids->nb[1];
+        acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset);
+
+        // Select quantized weights using expert indices
+        // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte
+        const int64_t weight_d         = (type == GGML_TYPE_Q4_0) ? src0->ne[0] / 2 : src0->ne[0];
+        const int64_t weight_m         = src0->ne[1];
+        const int64_t weight_n_experts = src0->ne[2];
+
+        int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts };
+        size_t  weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) };
+
+        acl_tensor_ptr all_weights =
+            ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3);
+
+        int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts };
+        size_t  selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t),
+                                          weight_d * weight_m * sizeof(int8_t) };
+
+        acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t),
+                                                                  selected_weight_ne, selected_weight_nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get());
+
+        // Select scales using the same expert indices
+        const int64_t scale_d = src0->ne[0] / group_size;
+        int64_t scale_ne[3] = { scale_d, weight_m, weight_n_experts };
+        size_t  scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size };
+
+        acl_tensor_ptr all_scales =
+            ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3);
+
+        int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts };
+        size_t  selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size,
+                                         scale_d * weight_m * scale_elem_size };
+
+        acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size,
+                                                                 selected_scale_ne, selected_scale_nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get());
+
+        // Process each expert for current batch
+        // IndexSelect output layout: [D, M, K] in contiguous format
+        // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride
+        for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) {
+            // Determine input offset: broadcast if src1->ne[1] == 1, otherwise use per-expert input
+            const size_t input_offset =
+                (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size;
+            const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size;
+
+            // Create weight view for current expert: [D, M, K] -> [M, D]
+            int64_t weight_view_ne[2] = { weight_m, src0->ne[0] };
+            float   weight_view_nb[2] = { src0->ne[0] * weight_elem_size, weight_elem_size };
+            const size_t weight_view_offset = expert_idx * selected_weight_nb[2];
 
-    // src1_row [D, 1, 1, 1] -> input
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
-
-    // dst_row [M, 1, 1, 1] -> out
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
-
-    // create weight for one row
-    ggml_cann_pool_alloc weight_allocator(ctx.pool());
-    void * weight_buffer = weight_allocator.alloc(nb02);
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            // expert index
-            int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]);
-            GGML_ASSERT(i02 >= 0 && i02 < n_as);
-
-            // If B = 1 (broadcast), always use 0; otherwise, use id.
-            int64_t i11 = (ne11 == 1 ? 0 : id);
-            int64_t i12 = iid1;
-
-            int64_t i1 = id;
-            int64_t i2 = i12;
-
-            void * src0_tmp_ptr  = src0_original + i02 * weight_stride;
-            void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride;
-            void * src1_tmp_ptr  = src1_original + i11 * nb11 + i12 * nb12;
-            void * dst_tmp_ptr   = dst_original + i1 * nb1 + i2 * nb2;
-
-            // mem cpy
-            ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-            void * scale_buffer = (char *) weight_buffer + weight_stride;
-            ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-
-            src0_row.data  = weight_buffer;
-            src1_row.data  = src1_tmp_ptr;
-            dst_row.data   = dst_tmp_ptr;
-            dst_row.src[0] = &src0_row;
-            dst_row.src[1] = &src1_row;
-
-            ggml_cann_mul_mat(ctx, &dst_row);
+            acl_tensor_ptr weight_view =
+                ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size,
+                                        weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset);
+
+            // Create scale view for current expert: [D, M, K] -> [M, D]
+            int64_t scale_view_ne[2] = { weight_m, scale_d };
+            size_t  scale_view_nb[2] = { selected_scale_nb[1], selected_scale_nb[0] };
+            const size_t scale_view_offset = expert_idx * selected_scale_nb[2];
+
+            acl_tensor_ptr scale_view =
+                ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne,
+                                        scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset);
+
+            // Create input activation tensor [D, 1]
+            int64_t input_ne[2] = { src1->ne[0], 1 };
+            size_t  input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size };
+
+            acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne,
+                                                                  input_nb, 2, ACL_FORMAT_ND, input_offset);
+
+            // Create output tensor [M, 1]
+            int64_t output_ne[2] = { dst->ne[0], 1 };
+            size_t  output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size };
+
+            acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne,
+                                                                   output_nb, 2, ACL_FORMAT_ND, output_offset);
+
+            // Perform quantized matrix multiplication
+            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(),
+                                    scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size,
+                                    output_tensor.get());
         }
     }
-    return;
+
+    // Cast output back to original type if we used a temporary F16 buffer
+    if (dst->type != GGML_TYPE_F16) {
+        int64_t ne[GGML_MAX_DIMS];
+        size_t  nb[GGML_MAX_DIMS] = { f16_elem_size };
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            ne[i] = dst->ne[i];
+            if (i > 0) {
+                nb[i] = nb[i - 1] * ne[i - 1];
+            }
+        }
+
+        acl_tensor_ptr f16_output =
+            ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
+        acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst);
+
+        aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type));
+    }
 }
 
 void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
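
Note on the memory layout the new code relies on: the quantized src0 buffer stores all Q4_0/Q8_0 weight bytes first and the per-group F16 scales in a contiguous region immediately after them, which is why `scale_data` is computed as `(char *) src0->data + weight_size`. The standalone sketch below (not part of the patch; shapes are hypothetical, illustrative values) reproduces that offset arithmetic.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical shapes, only for illustration
    const int64_t D          = 4096;  // src0->ne[0]: elements per row
    const int64_t M          = 4096;  // src0->ne[1]: rows per expert
    const int64_t n_experts  = 8;     // src0->ne[2]: number of experts
    const int     group_size = 32;    // QK4_0 == QK8_0 == 32
    const float   elem_size  = 0.5f;  // Q4_0 packs two 4-bit values per byte (use 1.0f for Q8_0)

    // All quantized weight bytes come first; the F16 scale region starts right after them.
    const size_t weight_size = (size_t) (D * M * n_experts * elem_size);
    // One F16 scale per group of 32 elements.
    const size_t scale_size  = (size_t) (D / group_size) * M * n_experts * sizeof(uint16_t);

    printf("scales start at byte offset %zu, scale region is %zu bytes\n", weight_size, scale_size);
    return 0;
}
```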