fix: Q1_0_g128 CPU dot product int truncation#4

Closed
Marxist-Leninist wants to merge 5 commits into PrismML-Eng:prism from Marxist-Leninist:fix/q1_0_g128-cpu-kernel-int-truncation

Conversation

@Marxist-Leninist

Summary

  • The accumulator `sumi` in `ggml_vec_dot_q1_0_g128_q8_0` was declared as `int` but accumulates the `float` product `d1 * sumi_block`, so the float result was truncated to an integer on every loop iteration
  • This produced garbage output for Q1_0_g128 quantized models when running on CPU (both the x86 and generic/portable paths)
  • Fix: change `int sumi = 0` to `float sumi = 0` in both kernels
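
The failure mode can be reproduced with a minimal sketch (a hypothetical simplified loop, not the actual ggml kernel; `d1` and `sumi_block` stand in for the per-block scale and integer dot product):

```c
#include <assert.h>

/* BUG: an int accumulator truncates each float contribution toward zero. */
float dot_accumulate_buggy(const float *d1, const int *sumi_block, int nb) {
    int sumi = 0;
    for (int i = 0; i < nb; ++i) {
        sumi += d1[i] * sumi_block[i]; /* e.g. 0.01f * 30 = 0.3 -> stored as 0 */
    }
    return (float) sumi;
}

/* FIX: accumulate in float, as in the patched kernels. */
float dot_accumulate_fixed(const float *d1, const int *sumi_block, int nb) {
    float sumi = 0;
    for (int i = 0; i < nb; ++i) {
        sumi += d1[i] * sumi_block[i];
    }
    return sumi;
}
```

With small per-block scales (typical for quantized weights), every contribution rounds to zero in the buggy version, which matches the garbled output observed.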

Files changed

  • ggml/src/ggml-cpu/arch/x86/quants.c — x86 kernel
  • ggml/src/ggml-cpu/quants.c — generic/portable kernel

Test

Verified locally on Windows (MinGW, CPU-only) with Bonsai-8B (Q1_0_g128):

  • Before fix: garbled/random token output
  • After fix: coherent, correct responses

The accumulator `sumi` in ggml_vec_dot_q1_0_g128_q8_0 was declared
as `int` but accumulates `float d1 * int sumi_block`, causing the
float result to be truncated to integer on each iteration. This
produced garbage output for Q1_0_g128 models on CPU.

Fix: change `int sumi = 0` to `float sumi = 0` in both the x86
and generic (portable) kernels.
@github-actions github-actions bot added the ggml label Apr 2, 2026
noreply added 4 commits April 2, 2026 16:30
The existing scalar fallback runs at ~0.2 t/s on CPUs without AVX-512
(Ryzen, Intel 12th+ gen consumer). This adds an AVX2 path that:
- Sign-extends int8->int16 in two 16-element passes per Q8_0 block
- Expands 1-bit weights to 16-bit masks via broadcast+AND+cmpeq
- Uses blendv to negate activations where weight bit=0
- Accumulates via madd_epi16 -> cvtepi32_ps -> fmadd_ps

AVX2 is supported on virtually all x86-64 CPUs from 2013+.
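
The semantics the v1 AVX2 path vectorizes can be written as a scalar reference (a hypothetical helper, not the actual kernel): each 1-bit weight selects the sign of the matching int8 activation, with bit 1 adding it and bit 0 subtracting it.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar reference for the 1-bit x int8 sub-block dot product.
 * The AVX2 v1 path does the same via mask expansion + blendv. */
int q1_dot_block_ref(uint32_t weight_bits, const int8_t *q8, int n /* <= 32 */) {
    int sum = 0;
    for (int i = 0; i < n; ++i) {
        int bit = (weight_bits >> i) & 1;
        sum += bit ? q8[i] : -q8[i];
    }
    return sum;
}
```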
…sing

Replace two-pass int16 blendv approach with:
- Single-pass byte-level bit expansion (shuffle+AND+cmpeq)
- XOR+SUB negate trick (replaces slow blendv, 2-3 cyc -> 1 cyc each)
- maddubs+madd accumulation (stays in int8 longer)
- Fully unrolled k-loop (eliminates loop overhead + branch)

Benchmark on i7-10510U (AVX2+FMA, 4T):
  Scalar:  0.2 t/s prompt, 0.2 t/s gen
  AVX2 v1: 2.4 t/s prompt, 2.1 t/s gen (two-pass blendv)
  AVX2 v3: 4.7 t/s prompt, 3.1 t/s gen (this commit)

~15x faster than scalar, ~50% faster than v1.
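
The XOR+SUB conditional negate used in v3 is the classic two's-complement identity, shown here on scalars (an illustrative helper): with `m = 0` the value passes through, and with `m = -1` (all ones), `(x ^ m) - m == ~x + 1 == -x`. The vector version applies the same identity per lane with XOR and SUB instructions, avoiding the slower blendv.

```c
#include <assert.h>
#include <stdint.h>

/* Branch-free conditional negate: m must be 0 (keep) or -1 (negate). */
int32_t cond_negate(int32_t x, int32_t m) {
    return (x ^ m) - m;
}
```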
Apply cache-blocking and prefetch optimizations from the COM6 matrix
multiplication library (github.com/Marxist-Leninist/COM6):

- Increase weight row block size from 16 to 64 for Q1_0_g128
  (1-bit rows are ~576 bytes at K=4096, 64 rows = 36KB fits in L1d)
- Add software prefetch of weight rows 4 iterations ahead,
  mirroring COM6 distributed prefetch strategy
- Enlarge tmp accumulator buffer to match larger block size

Benchmark on i7-10510U (4T, Bonsai-8B Q1_0_g128):
  Before: 3.14 t/s generation
  After:  3.43 t/s generation (+9%)
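
The blocking + prefetch pattern can be sketched as follows (a hypothetical function with simplified shapes; assumes GCC/Clang for `__builtin_prefetch`). Rows are walked in blocks of 64 so a block of 1-bit weight rows stays resident in L1d, and the row 4 iterations ahead is prefetched while the current row is consumed.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define ROW_BLOCK     64
#define PREFETCH_DIST 4

/* Byte-sum over rows, standing in for the per-row vec_dot call. */
long blocked_row_sum(const uint8_t *rows, size_t row_size, size_t nrows) {
    long sum = 0;
    for (size_t ib = 0; ib < nrows; ib += ROW_BLOCK) {
        const size_t end = ib + ROW_BLOCK < nrows ? ib + ROW_BLOCK : nrows;
        for (size_t ir = ib; ir < end; ++ir) {
            if (ir + PREFETCH_DIST < end) {
                /* read prefetch (0), low temporal-locality hint (1) */
                __builtin_prefetch(rows + (ir + PREFETCH_DIST) * row_size, 0, 1);
            }
            for (size_t k = 0; k < row_size; ++k) {
                sum += rows[ir * row_size + k];
            }
        }
    }
    return sum;
}
```

Prefetching is only a hint, so the function computes the same result with or without it; the payoff is hiding the memory latency of the next weight rows behind the current dot products.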
@khosravipasha
Collaborator

This looks great, thanks. There were a few CPU kernel fixes I did not see until after I pushed my changes. For now I have removed the buggy x86 kernel and will merge one of the correct AVX ones.

Could you run the KL divergence tests described here: #8


Copilot AI left a comment


Pull request overview

Fixes incorrect CPU dot-product accumulation for Q1_0_g128 quantized models by preventing float-to-int truncation during accumulation, improving correctness of inference outputs on CPU backends.

Changes:

  • Change sumi accumulator type from int to float in the generic ggml_vec_dot_q1_0_g128_q8_0 kernel.
  • Change sumi accumulator type from int to float in the x86 ggml_vec_dot_q1_0_g128_q8_0 scalar fallback.
  • Additional unrelated changes: mul_mat block-tiling/prefetch tweaks and a new .gitattributes entry for .gguf via LFS.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.

  • ggml/src/ggml-cpu/quants.c — Fixes Q1_0_g128 x Q8_0 generic dot-product accumulation truncation by using a float accumulator.
  • ggml/src/ggml-cpu/arch/x86/quants.c — Updates the Q1_0_g128 x Q8_0 accumulation type; also introduces AVX2-related changes that currently appear to break compilation.
  • ggml/src/ggml-cpu/ggml-cpu.c — Changes the tiling block size for Q1_0_g128, increases the tmp buffer, and adds prefetching in the matmul inner loop.
  • .gitattributes — Adds Git LFS handling for *.gguf files.


Comment on lines +68 to +84
#if defined(__AVX2__)
// AVX2: single-pass byte-level processing, fully unrolled k-loop.
// Pipeline: broadcast+shuffle -> AND+cmpeq -> XOR+SUB -> maddubs+madd -> cvt+fma
const __m256i ones_8 = _mm256_set1_epi8(1);
const __m256i ones_16 = _mm256_set1_epi16(1);
const __m256i byte_shuf = _mm256_setr_epi8(
0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1,
2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3);
const __m256i bit_masks = _mm256_setr_epi8(
1,2,4,8,16,32,64,-128, 1,2,4,8,16,32,64,-128,
1,2,4,8,16,32,64,-128, 1,2,4,8,16,32,64,-128);
const __m256i zero = _mm256_setzero_si256();
__m256 acc = _mm256_setzero_ps();

// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
const __m256i lowMask = _mm256_set1_epi8( 0xF );
return _mm256_and_si256(lowMask, bytes);
}
for (int ib = 0; ib < nb; ++ib) {
const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
const uint32_t * qs32 = (const uint32_t *)x[ib].qs;

Copilot AI Apr 6, 2026


The #if defined(__AVX2__) block here contains executable statements that reference nb, x, y_ptr, and s, but this code is located at file scope (inside the helper section) rather than inside a function. This will not compile and also appears to have overwritten the intended mul_sum_us8_pairs_float(...) implementation. Please move this logic into the appropriate vec_dot function (or restore the helper to only operate on its parameters and return a value).

Comment on lines +1188 to +1197
// COM6-inspired block-tiling: larger blocks for Q1_0_g128 (1-bit weights are tiny,
// so we can fit more rows in L1). Prefetch next weight block while processing current.
const int64_t blck_0 = (type == GGML_TYPE_Q1_0_g128) ? 64 : 16;
const int64_t blck_1 = 16;

const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;

// attempt to reduce false-sharing (does not seem to make a difference)
// 16 * 2, accounting for mmla kernels
float tmp[32];
// Size: blck_0 * 2 (accounting for mmla kernels that compute 2 rows at once)
float tmp[128];

Copilot AI Apr 6, 2026


This PR description says the fix is limited to changing int sumi -> float sumi in two dot-product kernels, but this file now includes a new block-tiling size change, a larger tmp buffer, and a new prefetch loop. Please update the PR description/title to reflect these additional behavioral/performance changes or split them into a separate PR to keep the correctness fix isolated.

Comment on lines +1232 to 1236
for (int64_t ir0 = iir0; ir0 < ir0_max; ir0 += num_rows_per_vec_dot) {
if (ir0 + 4 * num_rows_per_vec_dot < ir0_max) {
__builtin_prefetch(src0_row + (ir0 + 4 * num_rows_per_vec_dot) * nb01, 0, 1);
}
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);

Copilot AI Apr 6, 2026


__builtin_prefetch is a GCC/Clang builtin and is not available on all supported toolchains (notably MSVC). Consider guarding this with compiler checks (e.g., __GNUC__/__clang__) and/or using an existing cross-platform prefetch abstraction (or _mm_prefetch on x86) so CPU builds remain portable.
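
One way to guard this is a small macro along the lines of the following (a sketch; `PRISM_PREFETCH` is an illustrative name, not an existing ggml macro). GCC/Clang get `__builtin_prefetch`, x86 MSVC builds get `_mm_prefetch`, and everything else falls back to a no-op.

```c
#include <assert.h>

#if defined(__GNUC__) || defined(__clang__)
#  define PRISM_PREFETCH(p) __builtin_prefetch((p), 0, 1)
#elif defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))
#  include <xmmintrin.h>
#  define PRISM_PREFETCH(p) _mm_prefetch((const char *)(p), _MM_HINT_T2)
#else
#  define PRISM_PREFETCH(p) ((void)0)  /* no-op on unknown toolchains */
#endif
```

Because a prefetch is purely a performance hint, the no-op fallback keeps behavior identical on toolchains without a prefetch intrinsic.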

Comment on lines +1196 to +1197
// Size: blck_0 * 2 (accounting for mmla kernels that compute 2 rows at once)
float tmp[128];

Copilot AI Apr 6, 2026


tmp is documented as blck_0 * 2, but it is hard-coded to float tmp[128] while blck_0 is now dynamic. This makes the comment inaccurate and risks future buffer overruns if blck_0 is tuned again. Prefer sizing tmp from blck_0/num_rows_per_vec_dot (e.g., VLA or max-constant with a static assert) and updating the memcpy offsets accordingly.

Suggested change
// Size: blck_0 * 2 (accounting for mmla kernels that compute 2 rows at once)
float tmp[128];
// Size: blck_0 * num_rows_per_vec_dot
const int64_t tmp_size = blck_0 * num_rows_per_vec_dot;
float tmp[tmp_size];

Comment on lines 668 to +672

float sumf = 0;

// Each Q1_0_g128 block has 128 elements
// Each Q8_0 block has 32 elements
// So we need 4 Q8_0 blocks per Q1_0_g128 block
#if defined(__AVX2__)
// AVX2: process 32 Q8_0 values per sub-block in two 16-element passes.

Copilot AI Apr 6, 2026


sumf is declared before the #if defined(__AVX2__) block, but in AVX2 builds the function assigns *s directly and never reads sumf, which will trigger an unused-variable warning. Consider moving the sumf declaration into the scalar #else path or using GGML_UNUSED(sumf) when AVX2 is enabled.

@khosravipasha
Collaborator

Good news: our first CPU PR just got merged into the llama.cpp master branch. If you are still working on this, please rebase onto PrismML's master (it just pulled in the main llama.cpp).

Changes: the Q1_0_g128 naming is gone. The original Q1_0 with group size 32 was deleted, and Q1_0_g128 was renamed to Q1_0, which now has group size 128 by default.

https://github.com/PrismML-Eng/llama.cpp/tree/master

That branch only has the generic CPU path (slow) and an ARM NEON path. The plan is to gather the best x86 kernels from here and send a PR there (and tag all the contributors).

@khosravipasha
Collaborator

There are a lot of CPU PRs; the plan is to gather them all into one and then send it to the main llama.cpp.
Going to close this one and mention the people who helped in a thread there. If you think your solution is better, please comment there:
#10
