
vulkan: add Q1_0_g128 (1-bit ternary) shader support #9

Open
claudlos wants to merge 1 commit into PrismML-Eng:master from claudlos:prism


claudlos commented Apr 3, 2026

Add Vulkan shader support for Q1_0_g128 quantization

Summary

This PR adds Vulkan shader support for the GGML_TYPE_Q1_0_g128 (1-bit sign / binary quantization, group size 128) format. The primary validated paths today are dequantization, get_rows, and fused mul_mat_vec. Without these shaders, Q1_0_g128 models fall back to CPU dequantization on Vulkan devices, resulting in ~291 graph splits and extremely poor performance. With this PR, the tested inference path runs almost entirely on GPU with only 2 graph splits.

Performance Results

| Metric | Before (CPU fallback) | After (Vulkan shaders) | Speedup |
| --- | ---: | ---: | ---: |
| Eval (token gen) | 0.28 t/s | 23.4 t/s | 84x |
| Prompt eval (135 tokens) | 0.31 t/s | 38.3 t/s | 124x |
| Graph splits | 291 | 2 | -99.3% |

Comparison: Qwen2.5-3B Q4_K_M on the same hardware achieves 27.8 t/s. Bonsai 8B Q1_0_g128 reaches 84% of that speed (23.4 t/s) despite having about 2.7x as many parameters, which supports the efficiency of 1-bit inference.
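
For anyone who wants to re-derive the ratios quoted above, here is a small sanity-check sketch in C++. The helper names `speedup` and `percent_change` are made up for illustration and are not part of the PR.

```cpp
#include <cassert>

// Hypothetical helpers (not from the PR) to re-derive the reported ratios.

// How many times faster `after` is than `before` (both in tokens/second).
double speedup(double before, double after) {
    assert(before > 0.0);
    return after / before;
}

// Relative change in percent, e.g. for the graph split count.
double percent_change(double before, double after) {
    assert(before != 0.0);
    return (after - before) / before * 100.0;
}
```

With the numbers above, `speedup(0.28, 23.4)` is about 83.6, `speedup(0.31, 38.3)` is about 123.5, `percent_change(291, 2)` is about -99.3, and `speedup(27.8, 23.4)` is about 0.84, which round to the reported 84x, 124x, -99.3%, and 84%.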

Files Changed

New Files

  • ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0_g128.comp — Standalone dequantization compute shader. Each thread processes one 128-element block (16 bytes of packed sign bits + fp16 scale), used for get_rows and general dequantization pipelines.

  • ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q1_0_g128.comp — Custom fused matrix-vector multiply shader. Uses 4 threads per 128-element block (32 elements/thread = one uint32 of packed sign bits). Maps each bit to ±1.0, which compiles to efficient v_cndmask-style code on RDNA GPUs. Accumulates via 8 dot(vec4, vec4) operations per thread with fma for the final scale multiply.
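
To make the work split of the fused shader easier to follow, here is a rough CPU-side sketch in C++ of the same idea: 4 "threads" per 128-element block, each consuming one uint32 of sign bits as 8 four-wide dot products, with the scale applied once per block. This is not the shader itself. `BlockQ1`, `partial_dot`, and `row_dot` are hypothetical names, the scale is kept as a plain float instead of fp16, and the bit order (element i in bit i % 32 of word i / 32, with 1 meaning +1.0) is an assumption.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU stand-in for one block: 128 sign bits plus one scale.
// The real block stores the scale as fp16; float keeps the sketch short.
struct BlockQ1 {
    float d;                      // per-block scale
    std::array<uint32_t, 4> qs;   // 4 x 32 packed sign bits (assumed bit order)
};

constexpr int QK = 128;               // elements per block
constexpr int THREADS_PER_BLOCK = 4;  // one uint32 of sign bits per "thread"

// Work of one "thread": 32 elements, accumulated as 8 four-wide dot products.
float partial_dot(uint32_t bits, const float* x) {
    float acc = 0.0f;
    for (int v = 0; v < 8; ++v) {      // 8 x vec4
        for (int k = 0; k < 4; ++k) {
            const int i = 4 * v + k;
            const float s = ((bits >> i) & 1u) ? 1.0f : -1.0f;  // bit -> +/-1
            acc += s * x[i];
        }
    }
    return acc;
}

// Dot product of one quantized row with a dense activation vector.
float row_dot(const std::vector<BlockQ1>& row, const std::vector<float>& x) {
    assert(x.size() == row.size() * QK);
    float sum = 0.0f;
    for (std::size_t b = 0; b < row.size(); ++b) {
        float block_acc = 0.0f;
        for (int t = 0; t < THREADS_PER_BLOCK; ++t) {
            block_acc += partial_dot(row[b].qs[t], &x[b * QK + 32 * t]);
        }
        sum = std::fma(row[b].d, block_acc, sum);  // scale applied once per block
    }
    return sum;
}
```

The point of the structure is that no dequantized weights are ever written out: the signs are expanded in registers and the only multiply by `d` happens after the block sum.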

Modified Files

  • ggml/src/ggml-vulkan/vulkan-shaders/types.glsl — Added block_q1_0_g128 struct definition (fp16 scale d + 16-byte qs array) and DATA_A_Q1_0_G128 preprocessor configuration with QUANT_K=128, QUANT_R=1, QUANT_AUXF=1.

  • ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl — Added dequantize() and dequantize4() functions for Q1_0_g128 (bit extraction → sign mapping). Added Q1_0_g128 to the single-scale get_dm() path (returns vec2(d, 0)).

  • ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl — Added Q1_0_g128 matrix-matrix multiply load path in load_a_to_shmem. Uses branchless FMA dequantization: d*(2*bit - 1) = fma(2d, bit_float, -d) for efficient SIMD utilization.

  • ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp — Registered q1_0_g128 in the type list, marked as legacy quant with LOAD_VEC_A=4. Excluded from coopmat2 flash attention paths (no dequantFunc defined), excluded from integer dot product q8_1 MMQ paths (no Q8_1 quantization mapping exists for 1-bit types).

  • ggml/src/ggml-vulkan/ggml-vulkan.cpp — Registered Q1_0_g128 pipelines across all shader pipeline arrays: dequant, mul_mat_vec (f32 and f16 B-type variants), mul_mat_vec_id, get_rows, get_rows_f32. Added Q1_0_g128 to supports_op switch cases for GGML_OP_MUL_MAT and GGML_OP_GET_ROWS. Uses rm_iq row multiplier and subgroup16 configuration (matching IQ-type pipeline parameters).
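
As a rough illustration of the block layout described for types.glsl (fp16 scale plus a 16-byte qs array), the following C++ sketch checks the storage cost per weight. `BlockQ1Layout`, `bits_per_weight`, and `model_bytes` are hypothetical names, and the size estimate assumes every tensor is stored in this format, which a real GGUF file may not do.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr int QK = 128;  // elements per block (QUANT_K in the PR description)

// Hypothetical host-side mirror of the block: raw fp16 scale bits + sign bits.
struct BlockQ1Layout {
    uint16_t d;            // fp16 scale, stored as raw bits in this sketch
    uint8_t  qs[QK / 8];   // 128 sign bits = 16 bytes
};

static_assert(sizeof(BlockQ1Layout) == 18, "2-byte scale + 16 bytes of sign bits");

// Effective storage cost per weight, in bits.
double bits_per_weight() {
    return 8.0 * sizeof(BlockQ1Layout) / QK;
}

// Estimated weight storage in bytes if all n_params weights used this format.
double model_bytes(double n_params) {
    assert(n_params >= 0.0);
    return n_params * sizeof(BlockQ1Layout) / QK;
}
```

Under these assumptions a block costs 1.125 bits per weight, and 8.19 B parameters come to roughly 1.07 GiB, which is in line with the model size that llama-bench prints later in this thread.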

Technical Design Decisions

  1. Custom fused mul_mat_vec shader rather than generic path: The 1-bit format has a unique structure (128 elements packed as 16 bytes of sign bits + single fp16 scale) that doesn't fit the standard mul_mat_vec.comp template. The fused shader avoids intermediate dequantization and directly computes dot products from packed bits.

  2. 4 threads per block: Each thread handles 32 elements (one uint32 worth of bits), loading 4 bytes and expanding to 8 vec4 dot products. This maps well to GPU wavefronts.

  3. Excluded from coopmat2/flash attention: Q1_0_g128 is a weights-only quantization (KV cache uses f16). It does not define the dequantFunc symbol that the cooperative matrix path requires, and flash attention operates on KV cache types, not weight types.

  4. Excluded from integer dot MMQ: The Q8_1 quantized matmul path requires a compatible requantization scheme that doesn't exist for 1-bit types.

  5. Branchless FMA in mul_mm_funcs: The matrix-matrix path uses fma(2d, bit_float, -d) instead of conditional selection, which is more efficient for the wider SIMD paths used in batched matmul.
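
The identity behind decision 5, d*(2*bit - 1) = fma(2d, bit_float, -d), can be checked on the CPU with a short C++ sketch. The function names are hypothetical, and the least-significant-bit-first order in `decode_byte` is an assumption for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Reference form: pick +d or -d with a conditional select.
float decode_select(float d, uint32_t bit) {
    return bit != 0u ? d : -d;
}

// Branchless form: d * (2*bit - 1) rewritten as fma(2d, bit, -d).
float decode_fma(float d, uint32_t bit) {
    assert(bit <= 1u);
    return std::fma(2.0f * d, static_cast<float>(bit), -d);
}

// Decode the 8 sign bits of one byte (least significant bit first, assumed).
void decode_byte(float d, uint8_t byte, float out[8]) {
    for (int i = 0; i < 8; ++i) {
        out[i] = decode_fma(d, (byte >> i) & 1u);
    }
}
```

For bit = 1 the FMA evaluates 2d - d = d, and for bit = 0 it evaluates 0 - d = -d, so both forms give the same value as long as 2d does not overflow, which is not a concern for scales that originate from fp16.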

Testing

  • Model: Bonsai-8B Q1_0_g128 GGUF (1-bit sign-quantized 8B-parameter model)
  • Hardware: AMD Radeon 680M (RDNA2 integrated GPU, 12 Compute Units, 18GB shared VRAM)
  • Driver: AMD Adrenalin on Windows 11
  • Test methodology: Interactive text generation via llama-cli with Vulkan backend (-ngl 99)
  • Correctness validation: Output text is coherent and contextually appropriate
  • 12 shader variants were tested during optimization to arrive at the final design
  • Prompt evaluation: 135-token prompt, measured throughput
  • Token generation: Measured sustained generation rate

Known Limitations

  1. No cooperative matrix (coopmat2) support: Q1_0_g128 does not participate in cooperative matrix matmul or flash attention paths. This is by design — these paths require a dequantFunc symbol and Q1_0_g128 is weights-only.

  2. No integer dot product (MMQ) support: The q8_1 integer dot product optimization path is excluded for Q1_0_g128 since no compatible requantization scheme exists.

  3. No F16 B-type mul_mat support: Q1_0_g128 with F16 input tensors is explicitly blocked in supports_op. Only F32 B-type is supported. This avoids the complexity of F16 pipeline wiring for a weights-only quantization format (KV cache uses f16, but input activations are f32).

  4. Tested on RDNA2 only: While the shaders use standard Vulkan compute (no vendor-specific extensions), they have only been validated on AMD RDNA2. Testing on NVIDIA and Intel GPUs is recommended.

  5. Flash Attention: N/A for weight quantization types. Q1_0_g128 models use f16 KV cache, which already has full Vulkan support.

Hardware Tested

  • GPU: AMD Radeon 680M (RDNA2 iGPU)
  • Compute Units: 12
  • VRAM: 18GB shared (system memory)
  • OS: Windows 11
  • Vulkan API: 1.3

@soyelmismo

soyelmismo commented Apr 3, 2026

Fully working on an RX 470, thanks so much!
(image attached)

@khosravipasha
Collaborator

khosravipasha commented Apr 3, 2026

Nice, thanks, this is very good. We had Vulkan and OpenCL backends too but did not have time to test and benchmark them properly, so we did not release them. I will try to put them in a branch for people to try. This looks great too.

Curious, which phone was this?

Also, the x86 CPU giving wrong output should be fixed now (still not optimized to be fast, but it's correct now). Please check PR #8.

@claudlos
Author

claudlos commented Apr 4, 2026

> Nice, thanks, this is very good. We had Vulkan and OpenCL backends too but did not have time to test and benchmark them properly, so we did not release them. I will try to put them in a branch for people to try. This looks great too.
>
> Curious, which phone was this?
>
> Also, the x86 CPU giving wrong output should be fixed now (still not optimized to be fast, but it's correct now). Please check PR #8.

Thanks @khosravipasha, I'm happy to help.

This was done on a mini PC, not a phone: Ryzen 9 6900HX with AMD Radeon 680M (RDNA2 iGPU).

I edited the PR and removed the notes about x86 CPU dequant.

@harish2704

It is working on my Vega 56 GPU. Thank you very much, @claudlos.

@khosravipasha
Collaborator

@claudlos Awesome, thanks. I had an older implementation, rebased it onto the latest branch, and dumped it here. It used to be much more stable; I think it works now, but I did not have time to test it thoroughly. It has both OpenCL and Vulkan.
Yours might be better than the one I have :)

https://github.com/PrismML-Eng/llama.cpp/tree/prism-android

@AlexRednic

Works well on my 6800HS / 680m. Good job!


==================================
Testing GPU layers
===================================
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Ryzen 7 6800HS Creator Edition (RADV REMBRANDT) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none
| model                          |       size |     params | backend    | ngl | n_batch | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | --------------: | -------------------: |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | Vulkan     |   0 |    1024 |  1 |          pp1024 |        120.02 ± 5.21 |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | Vulkan     |   0 |    1024 |  1 |           tg256 |         14.05 ± 0.18 |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | --------------: | -------------------: |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | Vulkan     |   9 |    1024 |  1 |          pp1024 |        125.24 ± 0.26 |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | Vulkan     |   9 |    1024 |  1 |           tg256 |         14.81 ± 0.57 |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | --------------: | -------------------: |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | Vulkan     |  18 |    1024 |  1 |          pp1024 |        110.44 ± 0.28 |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | Vulkan     |  18 |    1024 |  1 |           tg256 |         19.44 ± 0.91 |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | --------------: | -------------------: |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | Vulkan     |  27 |    1024 |  1 |          pp1024 |        123.67 ± 0.76 |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | Vulkan     |  27 |    1024 |  1 |           tg256 |         26.99 ± 0.21 |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | --------------: | -------------------: |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | Vulkan     |  99 |    1024 |  1 |          pp1024 |        111.68 ± 0.14 |
| qwen3 8B Q1_0                  |   1.07 GiB |     8.19 B | Vulkan     |  99 |    1024 |  1 |           tg256 |         37.39 ± 0.04 |


Copilot AI left a comment


Pull request overview

Adds Vulkan backend support for the GGML_TYPE_Q1_0_g128 (1-bit, group size 128) quantization format to reduce CPU fallback and graph splits by introducing dedicated shaders and wiring them into the Vulkan pipeline selection logic.

Changes:

  • Introduces new Vulkan compute shaders for Q1_0_g128 dequantization and fused mul_mat_vec.
  • Extends shared GLSL/Vulkan shader infrastructure (type definitions, dequant helpers, matmul loaders, shader generation) to recognize and compile Q1_0_g128 variants.
  • Registers the new pipelines in the Vulkan backend and adds backend-op tests covering get_rows, mul_mat, and mul_mat_id for Q1_0_g128.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| tests/test-backend-ops.cpp | Adds new test cases exercising Q1_0_g128 get_rows, mul_mat, and mul_mat_id. |
| ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | Adds q1_0_g128 to shader generation/type lists and excludes unsupported paths (e.g., integer-dot variants). |
| ggml/src/ggml-vulkan/vulkan-shaders/types.glsl | Defines block_q1_0_g128 + preprocessor config and adds an 8-bit storage extension requirement. |
| ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl | Adds Q1_0_g128 load/dequant path for matrix-matrix shaders (load_a_to_shmem). |
| ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q1_0_g128.comp | New fused Q1_0_g128 matrix-vector shader operating directly on packed sign bits. |
| ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0_g128.comp | New standalone dequant shader for Q1_0_g128 blocks. |
| ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl | Adds Q1_0_g128 dequant helpers and routes it through the single-scale get_dm() path. |
| ggml/src/ggml-vulkan/ggml-vulkan.cpp | Wires Q1_0_g128 pipelines into Vulkan backend and extends supports_op to cover relevant ops. |


@Benjamin-Wegener

Would TurboQuant also work with the Bonsai-8B 1-bit GGUF model weights on Vulkan? I know that TurboQuant does not change the 1-bit model weights — it only changes the KV cache format at runtime, which reduces VRAM usage. I was wondering whether the TurboQuant-enabled version of llama.cpp would also allow for efficient inference with the Bonsai-8B GGUF model. https://huggingface.co/apothic/bonsai-8B-1bit-turboquant

@khosravipasha
Collaborator

@Benjamin-Wegener Eventually, probably yes, but whatever TurboQuant does would also need Vulkan support. That is more of a question for the main llama.cpp repository.

@khosravipasha
Collaborator

Good news: our first CPU PR just got merged into the llama.cpp master branch. If you are still working on this, please rebase onto PrismML's master (I just pulled in the main llama.cpp).

Changes: the Q1_0_g128 name is gone. The original Q1_0 with group size 32 was deleted, and Q1_0_g128 was renamed to Q1_0, which now has group size 128 by default.

https://github.com/PrismML-Eng/llama.cpp/tree/master

This one only has the generic CPU path (slow) and the ARM NEON path. I am planning to gather the best x86 kernels from here and send a PR there (and tag all the contributors).

Adds Vulkan shader support for the GGML_TYPE_Q1_0 (1-bit sign quantization,
group size 128) format. Without these shaders, Q1_0 models fall back to CPU
dequantization on Vulkan devices, resulting in ~291 graph splits and extremely
poor performance.

New files:
- dequant_q1_0.comp: standalone dequantization shader
- mul_mat_vec_q1_0.comp: fused matrix-vector multiply (4 threads/block)

Changes:
- types.glsl: block_q1_0 struct + GL_EXT_shader_8bit_storage
- dequant_funcs.glsl: dequantize/dequantize4/get_dm for Q1_0
- mul_mm_funcs.glsl: load_a_to_shmem with branchless FMA decoding
- vulkan-shaders-gen.cpp: Q1_0 type registration, exclusions for
  coopmat2/flash_attn/integer_dot (no dequantFunc/Q8_1 mapping)
- ggml-vulkan.cpp: pipeline registration, supports_op, f32acc forced
- test-backend-ops.cpp: get_rows/mul_mat/mul_mat_id tests

Does NOT force Q1_0 through dequant fallback in mul_mat_q_f16 -
the fused matmul path is used when available, giving 2x+ prefill
speedup on AMD RDNA2 Vulkan vs the dequant path.

Tested on: AMD Radeon 680M (RDNA2), AMD RX 470, AMD Vega 56
Rebased onto PrismML master (Q1_0_g128 renamed to Q1_0)
@claudlos claudlos changed the base branch from prism to master April 7, 2026 00:50
@claudlos
Author

claudlos commented Apr 7, 2026

Rebased onto master with the Q1_0 rename (Q1_0_g128 is now just Q1_0).

Changes from original PR:

  • All Q1_0_g128 references renamed to Q1_0
  • Shader files renamed: dequant_q1_0.comp, mul_mat_vec_q1_0.comp
  • Removed force_q1_0_dequant (addresses Copilot review comment #1): prefill went from 63 t/s → 149 t/s on Radeon 680M (2.3x improvement), generation speed unchanged
  • Did NOT address Copilot comment #2 (8bit_storage gating): it needs a deeper refactor of all 8-bit type declarations in types.glsl and is not safe as a one-line fix

Benchmark results (Radeon 680M RDNA2 Vulkan, llama-bench -ngl 99 -p 512 -n 128 -r 5):

| Model | Before fix (forced dequant) | After fix (fused path) | Speedup |
| --- | ---: | ---: | ---: |
| Bonsai-8B prompt | 63 t/s | 149 t/s | 2.3x |
| Bonsai-8B gen | 23 t/s | 23 t/s | |
| Bonsai-4B prompt | | 309 t/s | |
| Bonsai-4B gen | | 33 t/s | |
| Bonsai-1.7B prompt | | 742 t/s | |
| Bonsai-1.7B gen | | 69 t/s | |

NVIDIA Vulkan benchmarks in progress — will post results when ready.

Re @Benjamin-Wegener's TurboQuant question: I've been prototyping TurboQuant KV cache quantization on Vulkan in a separate fork. The plumbing works (Vulkan shaders compile, TURBO2/3/4 pipeline runs to completion) but output correctness isn't there yet (WHT rotation bug). It's orthogonal to this PR — TurboQuant only changes KV cache format at runtime, not model weights.

@khosravipasha
Collaborator

khosravipasha commented Apr 7, 2026

Awesome, thanks. I'm not too familiar with the Radeon 680M's RDNA2 integrated GPU; are the speeds around what you were expecting? Curious how a 4-bit 8B model performs on that device (maybe try Q4_0), or a 4-bit model that ends up being 1 GB. That's how I used to gauge how close to optimal my kernels were for other backends.

Yeah, you don't have to listen to everything the Copilot reviewer says.

}
#endif

#if defined(DATA_A_Q1_0)vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint byte_idx = iqs / 8; const uint bit_idx = iqs % 8; const uint bits = uint(data_a[a_offset + ib].qs[byte_idx]); const float sign0 = ((bits >> bit_idx) & 1) == 1 ? 1.0f : -1.0f; const uint byte_idx2 = (iqs + 1) / 8; const uint bit_idx2 = (iqs + 1) % 8; const uint bits2 = uint(data_a[a_offset + ib].qs[byte_idx2]); const float sign1 = ((bits2 >> bit_idx2) & 1) == 1 ? 1.0f : -1.0f; return vec2(sign0, sign1);}vec4 dequantize4(uint ib, uint iqs, uint a_offset) { const uint byte_idx0 = iqs / 8; const uint bit_idx0 = iqs % 8; const uint bits0 = uint(data_a[a_offset + ib].qs[byte_idx0]); const float s0 = ((bits0 >> bit_idx0) & 1) == 1 ? 1.0f : -1.0f; const uint byte_idx1 = (iqs + 1) / 8; const uint bit_idx1 = (iqs + 1) % 8; const uint bits1 = uint(data_a[a_offset + ib].qs[byte_idx1]); const float s1 = ((bits1 >> bit_idx1) & 1) == 1 ? 1.0f : -1.0f; const uint byte_idx2 = (iqs + 2) / 8; const uint bit_idx2 = (iqs + 2) % 8; const uint bits2 = uint(data_a[a_offset + ib].qs[byte_idx2]); const float s2 = ((bits2 >> bit_idx2) & 1) == 1 ? 1.0f : -1.0f; const uint byte_idx3 = (iqs + 3) / 8; const uint bit_idx3 = (iqs + 3) % 8; const uint bits3 = uint(data_a[a_offset + ib].qs[byte_idx3]); const float s3 = ((bits3 >> bit_idx3) & 1) == 1 ? 1.0f : -1.0f; return vec4(s0, s1, s2, s3);}#endif
Collaborator


was this meant to be minimized?

Author


No, not intentional — that wasn’t meant to be minimized; it should be formatted normally like the surrounding
shader.

#endif

#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_Q1_0)vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint byte_idx = iqs / 8; const uint bit_idx = iqs % 8; const uint bits = uint(data_a[a_offset + ib].qs[byte_idx]); const float sign0 = ((bits >> bit_idx) & 1) == 1 ? 1.0f : -1.0f; const uint byte_idx2 = (iqs + 1) / 8; const uint bit_idx2 = (iqs + 1) % 8; const uint bits2 = uint(data_a[a_offset + ib].qs[byte_idx2]); const float sign1 = ((bits2 >> bit_idx2) & 1) == 1 ? 1.0f : -1.0f; return vec2(sign0, sign1);}vec4 dequantize4(uint ib, uint iqs, uint a_offset) { const uint byte_idx0 = iqs / 8; const uint bit_idx0 = iqs % 8; const uint bits0 = uint(data_a[a_offset + ib].qs[byte_idx0]); const float s0 = ((bits0 >> bit_idx0) & 1) == 1 ? 1.0f : -1.0f; const uint byte_idx1 = (iqs + 1) / 8; const uint bit_idx1 = (iqs + 1) % 8; const uint bits1 = uint(data_a[a_offset + ib].qs[byte_idx1]); const float s1 = ((bits1 >> bit_idx1) & 1) == 1 ? 1.0f : -1.0f; const uint byte_idx2 = (iqs + 2) / 8; const uint bit_idx2 = (iqs + 2) % 8; const uint bits2 = uint(data_a[a_offset + ib].qs[byte_idx2]); const float s2 = ((bits2 >> bit_idx2) & 1) == 1 ? 1.0f : -1.0f; const uint byte_idx3 = (iqs + 3) / 8; const uint bit_idx3 = (iqs + 3) % 8; const uint bits3 = uint(data_a[a_offset + ib].qs[byte_idx3]); const float s3 = ((bits3 >> bit_idx3) & 1) == 1 ? 1.0f : -1.0f; return vec4(s0, s1, s2, s3);}#endif
Collaborator


similar here, was this meant to be minimized?

Author


No
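
For readers trying to follow the one-line snippets in the two review threads above, here is the same bit-extraction logic laid out normally, ported to C++ so it can be compiled and checked on the CPU. It is a sketch, not the shader source: `SignBits`, `Vec2`, `Vec4`, and `sign_at` are hypothetical stand-ins for the GLSL buffer access and vector types. As in the snippet, only the signs are returned; per the PR description the scale is supplied separately through get_dm().

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using SignBits = std::array<uint8_t, 16>;  // 128 packed sign bits of one block

struct Vec2 { float x, y; };
struct Vec4 { float x, y, z, w; };

// Sign of element i: bit (i % 8) of byte (i / 8); 1 -> +1.0, 0 -> -1.0.
float sign_at(const SignBits& qs, uint32_t i) {
    assert(i < 128);
    const uint32_t byte_idx = i / 8;
    const uint32_t bit_idx  = i % 8;
    const uint32_t bits     = qs[byte_idx];
    return ((bits >> bit_idx) & 1u) == 1u ? 1.0f : -1.0f;
}

// Two consecutive signs starting at element iqs.
Vec2 dequantize(const SignBits& qs, uint32_t iqs) {
    return { sign_at(qs, iqs), sign_at(qs, iqs + 1) };
}

// Four consecutive signs starting at element iqs.
Vec4 dequantize4(const SignBits& qs, uint32_t iqs) {
    return { sign_at(qs, iqs),     sign_at(qs, iqs + 1),
             sign_at(qs, iqs + 2), sign_at(qs, iqs + 3) };
}
```

Factoring the repeated byte/bit index computation into one helper is a choice made for this sketch; the snippet in the PR spells it out per element.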
