Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 84 additions & 72 deletions src/ggml-bitnet-mad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size
#if defined(__ARM_FEATURE_DOTPROD)

#else
int16x8_t accu32 = vdupq_n_s16(0);
int32x4_t accu32 = vdupq_n_s32(0);
#endif
for (int j=0; j < 32; j++) {
uint8x16_t xq8_3 = vld1q_u8(x_row + i * 32 * 16 + j * 16);
Expand All @@ -341,30 +341,32 @@ void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size
accu = vdotq_s32(accu, q8_2, yq8_2);
accu = vdotq_s32(accu, q8_3, yq8_3);
#else
accu32 = vmlal_s8(accu32, vget_low_s8(q8_0), vget_low_s8(yq8_0));
accu32 = vmlal_s8(accu32, vget_high_s8(q8_0), vget_high_s8(yq8_0));
accu32 = vmlal_s8(accu32, vget_low_s8(q8_1), vget_low_s8(yq8_1));
accu32 = vmlal_s8(accu32, vget_high_s8(q8_1), vget_high_s8(yq8_1));
accu32 = vmlal_s8(accu32, vget_low_s8(q8_2), vget_low_s8(yq8_2));
accu32 = vmlal_s8(accu32, vget_high_s8(q8_2), vget_high_s8(yq8_2));
accu32 = vmlal_s8(accu32, vget_low_s8(q8_3), vget_low_s8(yq8_3));
accu32 = vmlal_s8(accu32, vget_high_s8(q8_3), vget_high_s8(yq8_3));
int16x8_t tmp = vdupq_n_s16(0);
tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0));
tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0));
tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3));
tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3));
accu32 = vaddq_s32(accu32, vmovl_s16(vget_low_s16(tmp)));
accu32 = vaddq_s32(accu32, vmovl_high_s16(tmp));
#endif
}

#if defined(__ARM_FEATURE_DOTPROD)

#else
accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accu32)));
accu = vaddq_s32(accu, vmovl_high_s16(accu32));
accu = vaddq_s32(accu, accu32);
#endif
}

for (int i = 0; i < groupla_num; i++){
#if defined(__ARM_FEATURE_DOTPROD)

#else
int16x8_t accula = vdupq_n_s16(0);
int32x4_t accula = vdupq_n_s32(0);
#endif
for (int j = 0; j < la_num; j++) {
uint8x16_t xq8_3 = vld1q_u8(x_row + group32_num * 32 * 16 + j * 16);
Expand All @@ -388,21 +390,23 @@ void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size
accu = vdotq_s32(accu, q8_2, yq8_2);
accu = vdotq_s32(accu, q8_3, yq8_3);
#else
accula = vmlal_s8(accula, vget_low_s8(q8_0), vget_low_s8(yq8_0));
accula = vmlal_s8(accula, vget_high_s8(q8_0), vget_high_s8(yq8_0));
accula = vmlal_s8(accula, vget_low_s8(q8_1), vget_low_s8(yq8_1));
accula = vmlal_s8(accula, vget_high_s8(q8_1), vget_high_s8(yq8_1));
accula = vmlal_s8(accula, vget_low_s8(q8_2), vget_low_s8(yq8_2));
accula = vmlal_s8(accula, vget_high_s8(q8_2), vget_high_s8(yq8_2));
accula = vmlal_s8(accula, vget_low_s8(q8_3), vget_low_s8(yq8_3));
accula = vmlal_s8(accula, vget_high_s8(q8_3), vget_high_s8(yq8_3));
int16x8_t tmp = vdupq_n_s16(0);
tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0));
tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0));
tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3));
tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3));
accula = vaddq_s32(accula, vmovl_s16(vget_low_s16(tmp)));
accula = vaddq_s32(accula, vmovl_high_s16(tmp));
#endif
}
#if defined(__ARM_FEATURE_DOTPROD)

#else
accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accula)));
accu = vaddq_s32(accu, vmovl_high_s16(accula));
accu = vaddq_s32(accu, accula);
#endif
}
int sumi = vaddlvq_s32(accu);
Expand Down Expand Up @@ -657,9 +661,9 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size
#if defined(__ARM_FEATURE_DOTPROD)

#else
int16x8_t accu32[PARALLEL_SIZE];
int32x4_t accu32[PARALLEL_SIZE];
for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
accu32[rb] = vdupq_n_s16(0);
accu32[rb] = vdupq_n_s32(0);
}
#endif
const uint8_t * px[PARALLEL_SIZE];
Expand Down Expand Up @@ -692,15 +696,17 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size
accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2);
accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3);
#else
accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3));
accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3));
accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2));
accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2));
accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1));
accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1));
accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0));
accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0));

int16x8_t tmp = vdupq_n_s16(0);
tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3));
tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3));
tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0));
tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0));
accu32[rb] = vaddq_s32(accu32[rb], vmovl_s16(vget_low_s16(tmp)));
accu32[rb] = vaddq_s32(accu32[rb], vmovl_high_s16(tmp));
#endif
px[rb] += 16;
}
Expand All @@ -710,8 +716,7 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size

#else
for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accu32[rb])));
accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accu32[rb]));
accu[rb] = vaddq_s32(accu[rb], accu32[rb]);
}
#endif
}
Expand All @@ -720,9 +725,9 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size
#if defined(__ARM_FEATURE_DOTPROD)

#else
int16x8_t accula[PARALLEL_SIZE];
int32x4_t accula[PARALLEL_SIZE];
for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
accula[rb] = vdupq_n_s16(0);
accula[rb] = vdupq_n_s32(0);
}
#endif
const uint8_t * px[PARALLEL_SIZE];
Expand All @@ -748,22 +753,24 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size
int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask));
int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask));
int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask));

#if defined(__ARM_FEATURE_DOTPROD)
accu[rb] = vdotq_s32(accu[rb], q8_0, yq8_0);
accu[rb] = vdotq_s32(accu[rb], q8_1, yq8_1);
accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2);
accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3);
#else
accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3));
accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3));
accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2));
accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2));
accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1));
accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1));
accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0));
accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0));

int16x8_t tmp = vdupq_n_s16(0);
tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3));
tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3));
tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0));
tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0));
accula[rb] = vaddq_s32(accula[rb], vmovl_s16(vget_low_s16(tmp)));
accula[rb] = vaddq_s32(accula[rb], vmovl_high_s16(tmp));
#endif
px[rb] += 16;
}
Expand All @@ -773,8 +780,7 @@ void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size

#else
for (int rb = 0; rb < PARALLEL_SIZE; rb++) {
accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accula[rb])));
accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accula[rb]));
accu[rb] = vaddq_s32(accu[rb], accula[rb]);
}
#endif
}
Expand Down Expand Up @@ -912,10 +918,10 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size
#if defined(__ARM_FEATURE_DOTPROD)

#else
int16x8_t accu32[PARALLEL_SIZE];
int32x4_t accu32[PARALLEL_SIZE];

for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
accu32[iy] = vdupq_n_s16(0);
accu32[iy] = vdupq_n_s32(0);
}
#endif
for (int j = 0; j < 32; j++) {
Expand Down Expand Up @@ -943,14 +949,17 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size
accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2);
accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3);
#else
accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0));
accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0));
accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1));
accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1));
accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2));
accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2));
accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3));
accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3));
int16x8_t tmp = vdupq_n_s16(0);
tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0));
tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0));
tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3));
tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3));
accu32[iy] = vaddq_s32(accu32[iy], vmovl_s16(vget_low_s16(tmp)));
accu32[iy] = vaddq_s32(accu32[iy], vmovl_high_s16(tmp));
#endif
}

Expand All @@ -962,7 +971,7 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size

#else
for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accu32[iy]), vmovl_s16(vget_low_s16(accu32[iy]))));
accu[iy] = vaddq_s32(accu[iy], accu32[iy]);
}
#endif
}
Expand All @@ -974,13 +983,13 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size
#if defined(__ARM_FEATURE_DOTPROD)

#else
int16x8_t accula[PARALLEL_SIZE];
int32x4_t accula[PARALLEL_SIZE];

for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
accula[iy] = vdupq_n_s16(0);
accula[iy] = vdupq_n_s32(0);
}
#endif

for (int j = 0; j < la_num; j++) {
// 加载并解包 x 数据(对所有列共享)
uint8x16_t xq8_3 = vld1q_u8(px + 0);
Expand All @@ -1006,14 +1015,17 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size
accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2);
accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3);
#else
accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0));
accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0));
accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1));
accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1));
accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2));
accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2));
accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3));
accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3));
int16x8_t tmp = vdupq_n_s16(0);
tmp = vmlal_s8(tmp, vget_low_s8(q8_0), vget_low_s8(yq8_0));
tmp = vmlal_s8(tmp, vget_high_s8(q8_0), vget_high_s8(yq8_0));
tmp = vmlal_s8(tmp, vget_low_s8(q8_1), vget_low_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_high_s8(q8_1), vget_high_s8(yq8_1));
tmp = vmlal_s8(tmp, vget_low_s8(q8_2), vget_low_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_high_s8(q8_2), vget_high_s8(yq8_2));
tmp = vmlal_s8(tmp, vget_low_s8(q8_3), vget_low_s8(yq8_3));
tmp = vmlal_s8(tmp, vget_high_s8(q8_3), vget_high_s8(yq8_3));
accula[iy] = vaddq_s32(accula[iy], vmovl_s16(vget_low_s16(tmp)));
accula[iy] = vaddq_s32(accula[iy], vmovl_high_s16(tmp));
#endif
}

Expand All @@ -1025,7 +1037,7 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size

#else
for (int iy = 0; iy < PARALLEL_SIZE; iy++) {
accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accula[iy]), vmovl_s16(vget_low_s16(accula[iy]))));
accu[iy] = vaddq_s32(accu[iy], accula[iy]);
}
#endif
}
Expand Down