From 287d7cc46fa57dbb1de75cc0003962a0c0a5a0b0 Mon Sep 17 00:00:00 2001 From: Mindev27 Date: Sat, 14 Mar 2026 11:06:33 +0900 Subject: [PATCH] fix: prevent int16 overflow in NEON non-dotprod fallback path --- src/ggml-bitnet-mad.cpp | 156 +++++++++++++++++++++------------------- 1 file changed, 84 insertions(+), 72 deletions(-) diff --git a/src/ggml-bitnet-mad.cpp b/src/ggml-bitnet-mad.cpp index 4ba9d6509..6c0999010 100644 --- a/src/ggml-bitnet-mad.cpp +++ b/src/ggml-bitnet-mad.cpp @@ -317,7 +317,7 @@ void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size #if defined(__ARM_FEATURE_DOTPROD) #else - int16x8_t accu32 = vdupq_n_s16(0); + int32x4_t accu32 = vdupq_n_s32(0); #endif for (int j=0; j < 32; j++) { uint8x16_t xq8_3 = vld1q_u8(x_row + i * 32 * 16 + j * 16); @@ -341,22 +341,24 @@ void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size accu = vdotq_s32(accu, q8_2, yq8_2); accu = vdotq_s32(accu, q8_3, yq8_3); #else - accu32 = vmlal_s8(accu32, vget_low_s8(q8_0), vget_low_s8(yq8_0)); - accu32 = vmlal_s8(accu32, vget_high_s8(q8_0), vget_high_s8(yq8_0)); - accu32 = vmlal_s8(accu32, vget_low_s8(q8_1), vget_low_s8(yq8_1)); - accu32 = vmlal_s8(accu32, vget_high_s8(q8_1), vget_high_s8(yq8_1)); - accu32 = vmlal_s8(accu32, vget_low_s8(q8_2), vget_low_s8(yq8_2)); - accu32 = vmlal_s8(accu32, vget_high_s8(q8_2), vget_high_s8(yq8_2)); - accu32 = vmlal_s8(accu32, vget_low_s8(q8_3), vget_low_s8(yq8_3)); - accu32 = vmlal_s8(accu32, vget_high_s8(q8_3), vget_high_s8(yq8_3)); + int16x8_t tmp = vdupq_n_s16(0); + tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3)); + accu32 = vaddq_s32(accu32, vmovl_s16(vget_low_s16(tmp))); + accu32 = vaddq_s32(accu32, vmovl_high_s16(tmp)); #endif } #if defined(__ARM_FEATURE_DOTPROD) #else - accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accu32))); - accu = vaddq_s32(accu, vmovl_high_s16(accu32)); + accu = vaddq_s32(accu, accu32); #endif } @@ -364,7 +366,7 @@ void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size #if defined(__ARM_FEATURE_DOTPROD) #else - int16x8_t accula = vdupq_n_s16(0); + int32x4_t accula = vdupq_n_s32(0); #endif for (int j = 0; j < la_num; j++) { uint8x16_t xq8_3 = vld1q_u8(x_row + group32_num * 32 * 16 + j * 16); @@ -388,21 +390,23 @@ void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size accu = vdotq_s32(accu, q8_2, yq8_2); accu = vdotq_s32(accu, q8_3, yq8_3); #else - accula = vmlal_s8(accula, vget_low_s8(q8_0), vget_low_s8(yq8_0)); - accula = vmlal_s8(accula, vget_high_s8(q8_0), vget_high_s8(yq8_0)); - accula = vmlal_s8(accula, vget_low_s8(q8_1), vget_low_s8(yq8_1)); - accula = vmlal_s8(accula, vget_high_s8(q8_1), vget_high_s8(yq8_1)); - accula = vmlal_s8(accula, vget_low_s8(q8_2), vget_low_s8(yq8_2)); - accula = vmlal_s8(accula, vget_high_s8(q8_2), vget_high_s8(yq8_2)); - accula = vmlal_s8(accula, vget_low_s8(q8_3), vget_low_s8(yq8_3)); - accula = vmlal_s8(accula, vget_high_s8(q8_3), vget_high_s8(yq8_3)); + int16x8_t tmp = vdupq_n_s16(0); + tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3)); + accula = vaddq_s32(accula, vmovl_s16(vget_low_s16(tmp))); + accula = vaddq_s32(accula, vmovl_high_s16(tmp)); #endif } #if defined(__ARM_FEATURE_DOTPROD) #else - accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accula))); - accu = vaddq_s32(accu, vmovl_high_s16(accula)); + accu = vaddq_s32(accu, accula); #endif } int sumi = vaddlvq_s32(accu); @@ -657,9 +661,9 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size #if defined(__ARM_FEATURE_DOTPROD) #else - int16x8_t accu32[PARALLEL_SIZE]; + int32x4_t accu32[PARALLEL_SIZE]; for (int rb = 0; rb < PARALLEL_SIZE; rb++) { - accu32[rb] = vdupq_n_s16(0); + accu32[rb] = vdupq_n_s32(0); } #endif const uint8_t * px[PARALLEL_SIZE]; @@ -692,15 +696,17 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2); accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3); #else - accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3)); - accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3)); - accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2)); - accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2)); - accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1)); - accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1)); - accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0)); - accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0)); - + int16x8_t tmp = vdupq_n_s16(0); + tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0)); + accu32[rb] = vaddq_s32(accu32[rb], vmovl_s16(vget_low_s16(tmp))); + accu32[rb] = vaddq_s32(accu32[rb], vmovl_high_s16(tmp)); #endif px[rb] += 16; } @@ -710,8 +716,7 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size #else for (int rb = 0; rb < PARALLEL_SIZE; rb++) { - accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accu32[rb]))); - accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accu32[rb])); + accu[rb] = vaddq_s32(accu[rb], accu32[rb]); } #endif } @@ -720,9 +725,9 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size #if defined(__ARM_FEATURE_DOTPROD) #else - int16x8_t accula[PARALLEL_SIZE]; + int32x4_t accula[PARALLEL_SIZE]; for (int rb = 0; rb < PARALLEL_SIZE; rb++) { - accula[rb] = vdupq_n_s16(0); + accula[rb] = vdupq_n_s32(0); } #endif const uint8_t * px[PARALLEL_SIZE]; @@ -748,22 +753,24 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); - + #if defined(__ARM_FEATURE_DOTPROD) accu[rb] = vdotq_s32(accu[rb], q8_0, yq8_0); accu[rb] = vdotq_s32(accu[rb], q8_1, yq8_1); accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2); accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3); #else - accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3)); - accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3)); - accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2)); - accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2)); - accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1)); - accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1)); - accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0)); - accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0)); - + int16x8_t tmp = vdupq_n_s16(0); + tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0)); + accula[rb] = vaddq_s32(accula[rb], vmovl_s16(vget_low_s16(tmp))); + accula[rb] = vaddq_s32(accula[rb], vmovl_high_s16(tmp)); #endif px[rb] += 16; } @@ -773,8 +780,7 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size #else for (int rb = 0; rb < PARALLEL_SIZE; rb++) { - accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accula[rb]))); - accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accula[rb])); + accu[rb] = vaddq_s32(accu[rb], accula[rb]); } #endif } @@ -912,10 +918,10 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size #if defined(__ARM_FEATURE_DOTPROD) #else - int16x8_t accu32[PARALLEL_SIZE]; + int32x4_t accu32[PARALLEL_SIZE]; for (int iy = 0; iy < PARALLEL_SIZE; iy++) { - accu32[iy] = vdupq_n_s16(0); + accu32[iy] = vdupq_n_s32(0); } #endif for (int j = 0; j < 32; j++) { @@ -943,14 +949,17 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2); accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3); #else - accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0)); - accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0)); - accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1)); - accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1)); - accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2)); - accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2)); - accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3)); - accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3)); + int16x8_t tmp = vdupq_n_s16(0); + tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3)); + accu32[iy] = vaddq_s32(accu32[iy], vmovl_s16(vget_low_s16(tmp))); + accu32[iy] = vaddq_s32(accu32[iy], vmovl_high_s16(tmp)); #endif } @@ -962,7 +971,7 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size #else for (int iy = 0; iy < PARALLEL_SIZE; iy++) { - accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accu32[iy]), vmovl_s16(vget_low_s16(accu32[iy])))); + accu[iy] = vaddq_s32(accu[iy], accu32[iy]); } #endif } @@ -974,13 +983,13 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size #if defined(__ARM_FEATURE_DOTPROD) #else - int16x8_t accula[PARALLEL_SIZE]; + int32x4_t accula[PARALLEL_SIZE]; for (int iy = 0; iy < PARALLEL_SIZE; iy++) { - accula[iy] = vdupq_n_s16(0); + accula[iy] = vdupq_n_s32(0); } #endif - + for (int j = 0; j < la_num; j++) { // 加载并解包 x 数据(对所有列共享) uint8x16_t xq8_3 = vld1q_u8(px + 0); @@ -1006,14 +1015,17 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2); accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3); #else - accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0)); - accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0)); - accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1)); - accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1)); - accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2)); - accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2)); - accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3)); - accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3)); + int16x8_t tmp = vdupq_n_s16(0); + tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2)); + tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3)); + tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3)); + accula[iy] = vaddq_s32(accula[iy], vmovl_s16(vget_low_s16(tmp))); + accula[iy] = vaddq_s32(accula[iy], vmovl_high_s16(tmp)); #endif } @@ -1025,7 +1037,7 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size #else for (int iy = 0; iy < PARALLEL_SIZE; iy++) { - accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accula[iy]), vmovl_s16(vget_low_s16(accula[iy])))); + accu[iy] = vaddq_s32(accu[iy], accula[iy]); } #endif }