
[cuda] initial Q1_0 backend#18

Draft

khosravipasha wants to merge 1 commit into master from q1-cuda

Conversation

@khosravipasha
Collaborator

DRAFT for running CIs etc.; the main PR will go to mainline llama.cpp.

@khosravipasha
Collaborator Author

  Benchmark Results — RTX 5090, Q1_0 CUDA, Flash Attention

  Speed (llama-bench -fa 1)

  ┌─────────────┬──────────┬─────────────┬─────────────┐
  │    Model    │   Size   │ pp512 (t/s) │ tg128 (t/s) │
  ├─────────────┼──────────┼─────────────┼─────────────┤
  │ Bonsai-1.7B │ 231 MiB  │ 29,815      │ 627         │
  ├─────────────┼──────────┼─────────────┼─────────────┤
  │ Bonsai-4B   │ 540 MiB  │ 18,585      │ 486         │
  ├─────────────┼──────────┼─────────────┼─────────────┤
  │ Bonsai-8B   │ 1.07 GiB │ 12,238      │ 373         │
  └─────────────┴──────────┴─────────────┴─────────────┘

  KL Validation (Q1_0 vs dequantized F16, WikiText-2, 20 chunks)

  ┌─────────────┬──────────┬────────────┬────────┬────────┐
  │    Model    │ Mean KLD │ Same Top p │ RMS Δp │ Status │
  ├─────────────┼──────────┼────────────┼────────┼────────┤
  │ Bonsai-1.7B │ 0.000419 │ 98.94%     │ 0.555% │ Pass   │
  ├─────────────┼──────────┼────────────┼────────┼────────┤
  │ Bonsai-4B   │ 0.000429 │ 98.51%     │ 0.593% │ Pass   │
  ├─────────────┼──────────┼────────────┼────────┼────────┤
  │ Bonsai-8B   │ 0.000514 │ 98.71%     │ 0.635% │ Pass   │
  └─────────────┴──────────┴────────────┴────────┴────────┘
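The KL metrics above compare the quantized model's per-token distributions against the dequantized-F16 baseline. As a minimal host-side sketch of how such statistics can be computed (plain C++, hypothetical helper names, not the actual llama.cpp tooling):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Softmax over raw logits (max-subtracted for numerical stability).
static std::vector<double> softmax(const std::vector<double> & logits) {
    const double mx = *std::max_element(logits.begin(), logits.end());
    std::vector<double> p(logits.size());
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        p[i] = std::exp(logits[i] - mx);
        sum += p[i];
    }
    for (double & x : p) {
        x /= sum;
    }
    return p;
}

// KL(base || quant) at one token position: sum_i p_i * log(p_i / q_i).
// "Mean KLD" in the table is this value averaged over all evaluated tokens.
static double kl_divergence(const std::vector<double> & p, const std::vector<double> & q) {
    double kld = 0.0;
    for (size_t i = 0; i < p.size(); ++i) {
        if (p[i] > 0.0) {
            kld += p[i] * std::log(p[i] / q[i]);
        }
    }
    return kld;
}

// "Same Top p": fraction of positions where both models agree on the argmax token.
static bool same_top(const std::vector<double> & p, const std::vector<double> & q) {
    const size_t ip = std::max_element(p.begin(), p.end()) - p.begin();
    const size_t iq = std::max_element(q.begin(), q.end()) - q.begin();
    return ip == iq;
}
```

Identical distributions give a KLD of zero, so a mean KLD around 5e-4 as reported above indicates the Q1_0 kernels track the F16 reference very closely.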


Copilot AI left a comment


Pull request overview

Adds initial CUDA backend support for the GGML_TYPE_Q1_0 quantization format, wiring it into CUDA dequantization/conversion paths and the quantized matmul kernels (MMVQ/MMQ).

Changes:

  • Implemented Q1_0×Q8_1 dot-product support and hooked it into mul_mat_vec_q (MMVQ).
  • Added MMQ (mul_mat_q) support for Q1_0, including MMA-tile loading and template instantiation generation.
  • Enabled Q1_0 for CUDA get-rows and type conversion/dequantization helpers.
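Conceptually, the dequantization these changes enable maps each Q1_0 weight from a single sign bit (0 -> -1, 1 -> +1) scaled by a per-block scale. The struct below is a simplified, hypothetical host-side reference: the scale is stored as a float for clarity (the real block would store fp16), and the bit order and QK1_0 = 128 block size are assumptions taken from the constants in this PR, to be checked against ggml's actual block_q1_0 definition.

```cpp
#include <cassert>
#include <cstdint>

constexpr int QK1_0 = 128; // weights per block (4 x 32-bit words of sign bits)

// Hypothetical reference layout; the real block stores the scale as fp16.
struct block_q1_0_ref {
    float   d;              // per-block scale
    uint8_t qs[QK1_0 / 8];  // one sign bit per weight: 0 -> -1, 1 -> +1
};

// Host-side reference dequantization, LSB-first within each byte.
static void dequantize_q1_0_ref(const block_q1_0_ref & b, float * out) {
    for (int i = 0; i < QK1_0; ++i) {
        const int bit = (b.qs[i / 8] >> (i % 8)) & 1;
        out[i] = b.d * (bit ? 1.0f : -1.0f);
    }
}
```

Because the format is symmetric (no zero-point offset), dequantization is just a sign flip and a scale, which is what makes the dp4a/MMA dot-product paths in this PR so cheap.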

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
ggml/src/ggml-cuda/vecdotq.cuh Adds Q1_0 dot-product helpers (Q1_0 × Q8_1) and VDR constants.
ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu Adds generated MMQ instantiation unit for Q1_0.
ggml/src/ggml-cuda/template-instances/generate_cu_files.py Ensures MMQ instantiation files are generated for Q1_0.
ggml/src/ggml-cuda/mmvq.cu Routes Q1_0 through MMVQ vector-dot and VDR selection, and adds type dispatch.
ggml/src/ggml-cuda/mmq.cuh Adds Q1_0 MMQ traits + tile loader and ds-layout selection.
ggml/src/ggml-cuda/mmq.cu Adds Q1_0 MMQ type dispatch and selection logic gating.
ggml/src/ggml-cuda/ggml-cuda.cu Advertises Q1_0 support in CUDA backend op capability checks.
ggml/src/ggml-cuda/getrows.cu Enables get_rows for Q1_0 via dequantize kernel.
ggml/src/ggml-cuda/dequantize.cuh Adds device-side dequantization routine for Q1_0.
ggml/src/ggml-cuda/convert.cu Enables Q1_0 conversions to FP16/FP32 (contiguous and non-contiguous).
ggml/src/ggml-cuda/common.cuh Adds CUDA type traits for Q1_0 (qk/qr/qi).


Comment on lines +308 to 313
    // Q1_0 requires MMA (Turing+) — no DP4A fallback path
    if (type == GGML_TYPE_Q1_0 && !turing_mma_available(cc)) {
        return false;
    }

    if (turing_mma_available(cc)) {

Copilot AI Apr 8, 2026


In ggml_cuda_should_use_mmq(), the Q1_0 early-return only checks turing_mma_available(cc). That disables the Q1_0 MMQ path on AMD even when MFMA/WMMA is available (and load_tiles_q1_0 is guarded for AMD_MFMA_AVAILABLE). Consider allowing Q1_0 when any MMA backend is available (e.g., turing_mma_available || amd_mfma_available || amd_wmma_available), and keep the ‘no DP4A fallback’ constraint as the absence of all MMA paths.

Suggested change
-    // Q1_0 requires MMA (Turing+) — no DP4A fallback path
-    if (type == GGML_TYPE_Q1_0 && !turing_mma_available(cc)) {
-        return false;
-    }
-    if (turing_mma_available(cc)) {
+    const bool has_mma = turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc);
+    // Q1_0 requires MMA — no DP4A fallback path
+    if (type == GGML_TYPE_Q1_0 && !has_mma) {
+        return false;
+    }
+    if (has_mma) {


#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_ITER_K_Q1_0 128 // For Q1_0: QK1_0=128, QI1_0=4, so threads_per_row = 128/(4*4) = 8

Copilot AI Apr 8, 2026


MMQ_ITER_K_Q1_0 is defined but not referenced anywhere. Either wire it into the Q1_0 tile loader / iter-k selection logic or remove it to avoid stale configuration knobs.

Suggested change
-#define MMQ_ITER_K_Q1_0 128 // For Q1_0: QK1_0=128, QI1_0=4, so threads_per_row = 128/(4*4) = 8

Comment on lines +109 to +149
#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism
#define VDR_Q1_0_Q8_1_MMQ  4 // Q1_0 has 128 bits (4 ints) per block

template <int vdr> static __device__ __forceinline__ float vec_dot_q1_0_q8_1_impl(
    const int * v, const int * u, const float & d1, const half2 & ds8) {

    int sumi = 0;

#pragma unroll
    for (int i = 0; i < vdr; ++i) {
        const int vi = v[i];

        // Unpack 32 bits into 32 signed values (-1 or +1)
        // Each bit: 0 -> -1, 1 -> +1
        int vi_bytes[8];

#pragma unroll
        for (int j = 0; j < 8; ++j) {
            const int shift = j * 4;
            const int bits4 = (vi >> shift) & 0x0F;

            const int b0 = (bits4 & 0x01) ? 1 : -1;
            const int b1 = (bits4 & 0x02) ? 1 : -1;
            const int b2 = (bits4 & 0x04) ? 1 : -1;
            const int b3 = (bits4 & 0x08) ? 1 : -1;

            vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
        }

#pragma unroll
        for (int j = 0; j < 8; ++j) {
            sumi = ggml_cuda_dp4a(vi_bytes[j], u[8*i + j], sumi);
        }
    }

    const float2 ds8f = __half22float2(ds8);

    // Q1_0 is symmetric (no offset), so we just multiply by scales
    return d1 * ds8f.x * sumi;
}


Copilot AI Apr 8, 2026


vec_dot_q1_0_q8_1_impl() is currently not referenced/instantiated anywhere in the CUDA codebase (only the non-template vec_dot_q1_0_q8_1 wrapper is used). If this is not intended for near-term use, removing it will reduce header complexity/compile cost; otherwise add a call site consistent with the other q_q8_1_impl patterns.

Suggested change
(deletion of the unused vec_dot_q1_0_q8_1_impl template and its VDR_Q1_0_Q8_1_* defines quoted above)

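The nibble-to-signed-bytes packing in vec_dot_q1_0_q8_1_impl above can be sanity-checked on the host. The sketch below re-implements the unpack step verbatim, plus a scalar stand-in for the dp4a intrinsic (dp4a_ref is an assumed helper name for illustration, not a real API):

```cpp
#include <cassert>
#include <cstdint>

// Host copy of the kernel's unpack: nibble j of vi expands to four signed
// bytes (bit set -> +1, bit clear -> -1) packed into one 32-bit int.
static int unpack_nibble_to_bytes(int vi, int j) {
    const int bits4 = (vi >> (j * 4)) & 0x0F;
    const int b0 = (bits4 & 0x01) ? 1 : -1;
    const int b1 = (bits4 & 0x02) ? 1 : -1;
    const int b2 = (bits4 & 0x04) ? 1 : -1;
    const int b3 = (bits4 & 0x08) ? 1 : -1;
    return (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
}

// Scalar stand-in for ggml_cuda_dp4a: four signed-byte products, accumulated.
static int dp4a_ref(int a, int b, int acc) {
    for (int k = 0; k < 4; ++k) {
        acc += (int)(int8_t)(((unsigned)a >> (8 * k)) & 0xFF)
             * (int)(int8_t)(((unsigned)b >> (8 * k)) & 0xFF);
    }
    return acc;
}
```

An all-clear nibble unpacks to four -1 bytes (the int value -1, i.e. 0xFFFFFFFF), and an all-set nibble to 0x01010101, so dotting against a Q8_1 word reduces to signed byte products, exactly what the dp4a path exploits.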
