Skip to content

Commit 9614e74

Browse files
committed
[SIMD nf4] update w/ mod_mul_safe
1 parent a2e5cbd commit 9614e74

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

src/simd_nf4.h

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -198,7 +198,7 @@ inline __uint128_t mul(__uint128_t a, __uint128_t b)
198 198
HalfVecType res;
199 199
VecType vec_a = load_to_reg(a);
200 200
VecType vec_b = load_to_reg(b);
201-
store_low_half_to_mem(&res, mod_mul_safe(vec_a, vec_b, F4));
201+
store_low_half_to_mem(&res, mod_mul_safe<uint32_t>(vec_a, vec_b));
202 202
return reinterpret_cast<__uint128_t>(res);
203 203
}
204 204

@@ -230,7 +230,7 @@ inline void hadamard_mul_rem(unsigned n, __uint128_t* x, __uint128_t* y)
230 230
VecType _x_p = load_to_reg(_x[i]);
231 231
VecType _y_p = load_to_reg(_y[i]);
232 232

233-
store_low_half_to_mem(_x + i, mod_mul_safe(_x_p, _y_p, F4));
233+
store_low_half_to_mem(_x + i, mod_mul_safe<uint32_t>(_x_p, _y_p));
234 234
}
235 235
}
236 236

@@ -248,8 +248,9 @@ inline void hadamard_mul_doubled_rem(
248 248
VecType _x_next_p = load_to_reg(_x_half[i]);
249 249
VecType _y_p = load_to_reg(_y[i]);
250 250

251-
store_low_half_to_mem(_x + i, mod_mul_safe(_x_p, _y_p, F4));
252-
store_low_half_to_mem(_x_half + i, mod_mul_safe(_x_next_p, _y_p, F4));
251+
store_low_half_to_mem(_x + i, mod_mul_safe<uint32_t>(_x_p, _y_p));
252+
store_low_half_to_mem(
253+
_x_half + i, mod_mul_safe<uint32_t>(_x_next_p, _y_p));
253 254
}
254 255
}
255 256

@@ -284,7 +285,7 @@ inline __uint128_t mul(__uint128_t a, __uint128_t b)
284 285
VecType res;
285 286
VecType vec_a = load_to_reg(a);
286 287
VecType vec_b = load_to_reg(b);
287-
store_to_mem(&res, mod_mul_safe(vec_a, vec_b, F4));
288+
store_to_mem(&res, mod_mul_safe<uint32_t>(vec_a, vec_b));
288 289
return reinterpret_cast<__uint128_t>(res);
289 290
}
290 291

@@ -354,7 +355,7 @@ inline void hadamard_mul(unsigned n, __uint128_t* _x, __uint128_t* _y)
354 355

355 356
// multiply y to the first half of `x`
356 357
for (i = 0; i < vec_len; ++i) {
357-
x[i] = mod_mul_safe(x[i], y[i], F4);
358+
x[i] = mod_mul_safe<uint32_t>(x[i], y[i]);
358 359
}
359 360

360 361
if (rem_len > 0) {

0 commit comments

Comments
 (0)