From 83703c6297d9de54867bbac307410f983fda67d7 Mon Sep 17 00:00:00 2001 From: Asher Feldman Date: Sat, 23 May 2026 15:26:37 -0700 Subject: [PATCH] Add KQuant quantization mode with 10 codec encode/decode kernels --- benchmarks/python/gather_qmm_bench.py | 33 +- mlx/backend/cpu/quantized.cpp | 596 ++ mlx/backend/cuda/quantized/qqmm.cpp | 4 + mlx/backend/cuda/quantized/quantized.cpp | 12 + mlx/backend/metal/CMakeLists.txt | 6 + mlx/backend/metal/jit/includes.h | 3 + mlx/backend/metal/jit_kernels.cpp | 31 +- mlx/backend/metal/kernels.h | 4 +- mlx/backend/metal/kernels/CMakeLists.txt | 7 + mlx/backend/metal/kernels/fp_quantized.h | 20 +- mlx/backend/metal/kernels/kq_quantized.h | 5224 +++++++++++++++++ mlx/backend/metal/kernels/kq_quantized.metal | 323 + .../metal/kernels/kq_quantized_encode.h | 1239 ++++ .../metal/kernels/kq_quantized_encode.metal | 86 + .../metal/kernels/kq_quantized_legacy.h | 1627 +++++ mlx/backend/metal/kernels/kq_quantized_nax.h | 2560 ++++++++ .../metal/kernels/kq_quantized_nax.metal | 106 + mlx/backend/metal/kernels/quantized_nax.h | 1 - mlx/backend/metal/kernels/quantized_utils.h | 23 + mlx/backend/metal/nojit_kernels.cpp | 4 +- mlx/backend/metal/quantized.cpp | 591 +- mlx/fast.cpp | 30 +- mlx/fast_primitives.h | 10 +- mlx/io/CMakeLists.txt | 2 +- mlx/io/gguf.cpp | 25 +- mlx/io/gguf.h | 11 +- mlx/io/gguf_quants.cpp | 204 +- mlx/ops.cpp | 312 +- mlx/ops.h | 5 + mlx/primitives.cpp | 63 +- mlx/primitives.h | 46 +- python/mlx/nn/layers/distributed.py | 14 +- python/mlx/nn/layers/embedding.py | 5 +- python/mlx/nn/layers/linear.py | 7 +- python/mlx/nn/layers/quantized.py | 126 +- python/src/ops.cpp | 46 +- python/tests/test_gguf_kquant.py | 155 + python/tests/test_kquant.py | 3041 ++++++++++ python/tests/test_quantized.py | 37 + 39 files changed, 16301 insertions(+), 338 deletions(-) create mode 100644 mlx/backend/metal/kernels/kq_quantized.h create mode 100644 mlx/backend/metal/kernels/kq_quantized.metal create mode 100644 mlx/backend/metal/kernels/kq_quantized_encode.h create mode 100644 mlx/backend/metal/kernels/kq_quantized_encode.metal create mode 100644 mlx/backend/metal/kernels/kq_quantized_legacy.h create mode 100644 mlx/backend/metal/kernels/kq_quantized_nax.h create mode 100644 mlx/backend/metal/kernels/kq_quantized_nax.metal create mode 100644 python/tests/test_gguf_kquant.py create mode 100644 python/tests/test_kquant.py diff --git a/benchmarks/python/gather_qmm_bench.py b/benchmarks/python/gather_qmm_bench.py index 17c06d57d2..e5183b1c4b 100644 --- a/benchmarks/python/gather_qmm_bench.py +++ b/benchmarks/python/gather_qmm_bench.py @@ -40,12 +40,21 @@ def gather_mm_simulate(x, w, indices): return x -def time_gather_qmm(): +def time_gather_qmm(mode="affine", kquant_type=""): + label = kquant_type if mode == "kquant" else mode + print(f"\n--- gather_qmm ({label}) ---") + + quantize_kwargs = {"mode": mode} + qmm_kwargs = {"mode": mode} + if mode == "kquant": + quantize_kwargs["kquant_type"] = kquant_type + qmm_kwargs["kquant_type"] = kquant_type + x = mx.random.normal((N, 1, 1, D)) / 1024**0.5 w1 = mx.random.normal((E, M, D)) / 1024**0.5 w2 = mx.random.normal((E, D, M)) / 1024**0.5 - w1 = mx.quantize(w1) - w2 = mx.quantize(w2) + w1 = mx.quantize(w1, **quantize_kwargs) + w2 = mx.quantize(w2, **quantize_kwargs) indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32) sorted_indices = mx.sort(indices.flatten()).reshape(N, I) mx.eval(x, w1, w2, indices, sorted_indices) @@ -55,8 +64,12 @@ def gather_mm(x, w1, w2, indices, sort): inv_order = None if sort: x, idx, inv_order = gather_sort(x, indices) - x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort) - x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort) + x = mx.gather_qmm( + x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort, **qmm_kwargs + ) + x = mx.gather_qmm( + x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort, **qmm_kwargs + ) if sort: x = scatter_unsort(x, inv_order, indices.shape) return x @@ -68,13 +81,13 @@ def gather_mm(x, w1, w2, indices, sort): x = mx.random.normal((N * I, D)) / 1024**0.5 w1 = mx.random.normal((M, D)) / 1024**0.5 w2 = mx.random.normal((D, M)) / 1024**0.5 - w1 = mx.quantize(w1) - w2 = mx.quantize(w2) + w1 = mx.quantize(w1, **quantize_kwargs) + w2 = mx.quantize(w2, **quantize_kwargs) mx.eval(x, w1, w2) def equivalent_matmul(x, w1, w2): - x = mx.quantized_matmul(x, *w1, transpose=True) - x = mx.quantized_matmul(x, *w2, transpose=True) + x = mx.quantized_matmul(x, *w1, transpose=True, **qmm_kwargs) + x = mx.quantized_matmul(x, *w2, transpose=True, **qmm_kwargs) return x time_fn(equivalent_matmul, x, w1, w2) @@ -82,3 +95,5 @@ def equivalent_matmul(x, w1, w2): if __name__ == "__main__": time_gather_qmm() + for codec in ("q8_0", "q4_k", "q6_k"): + time_gather_qmm(mode="kquant", kquant_type=codec) diff --git a/mlx/backend/cpu/quantized.cpp b/mlx/backend/cpu/quantized.cpp index c0f1a3c315..380abbf8e9 100644 --- a/mlx/backend/cpu/quantized.cpp +++ b/mlx/backend/cpu/quantized.cpp @@ -921,6 +921,28 @@ void fp_bs_qmm_dispatch( } // namespace +namespace { + +template +void kquant_dequantize_dispatch( + const uint8_t* w, + T* out, + size_t num_weights, + const std::string& kquant_type); + +template +void kquant_qmm_cpu( + T* result, + const T* x, + const uint8_t* w, + int M, + int N, + int K, + bool transpose_w, + const std::string& kquant_type); + +} // namespace + void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; @@ -950,6 +972,68 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { transpose_ = transpose_]() mutable { _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_); }); + } else if (mode_ == QuantizationMode::KQuant) { + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + transpose_ = transpose_, + kquant_type = kquant_type_]() mutable { + int K = x.shape(-1); + int M = x.ndim() > 1 ? x.shape(-2) : 1; + int N = out.shape(-1); + int batch_size = x.size() / (K * M); + size_t w_batch_els = w.ndim() > 2 ? w.shape(-1) * w.shape(-2) : 0; + switch (x.dtype()) { + case float32: + for (int i = 0; i < batch_size; i++) { + kquant_qmm_cpu( + out.data() + i * M * N, + x.data() + + elem_to_loc(i * M * K, x.shape(), x.strides()), + w.data() + + elem_to_loc(i * w_batch_els, w.shape(), w.strides()), + M, + N, + K, + transpose_, + kquant_type); + } + break; + case float16: + for (int i = 0; i < batch_size; i++) { + kquant_qmm_cpu( + out.data() + i * M * N, + x.data() + + elem_to_loc(i * M * K, x.shape(), x.strides()), + w.data() + + elem_to_loc(i * w_batch_els, w.shape(), w.strides()), + M, + N, + K, + transpose_, + kquant_type); + } + break; + case bfloat16: + for (int i = 0; i < batch_size; i++) { + kquant_qmm_cpu( + out.data() + i * M * N, + x.data() + + elem_to_loc(i * M * K, x.shape(), x.strides()), + w.data() + + elem_to_loc(i * w_batch_els, w.shape(), w.strides()), + M, + N, + K, + transpose_, + kquant_type); + } + break; + default: + throw std::invalid_argument( + "[quantized_matmul] only floating types are supported"); + } + }); } else { encoder.dispatch([out = array::unsafe_weak_copy(out), x = array::unsafe_weak_copy(x), @@ -964,6 +1048,69 @@ void QuantizedMatmul::eval_cpu(const std::vector& inputs, array& out) { } void GatherQMM::eval_cpu(const std::vector& inputs, array& out) { + if (mode_ == QuantizationMode::KQuant) { + auto& encoder = cpu::get_command_encoder(stream()); + auto x = ensure_row_contiguous(inputs[0], encoder, stream()); + auto w = ensure_row_contiguous(inputs[1], encoder, stream()); + auto& lhs_indices = inputs[inputs.size() - 2]; + auto& rhs_indices = inputs[inputs.size() - 1]; + + out.set_data(allocator::malloc(out.nbytes())); + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(lhs_indices); + encoder.set_input_array(rhs_indices); + encoder.set_output_array(out); + encoder.dispatch([out = array::unsafe_weak_copy(out), + x = array::unsafe_weak_copy(x), + w = array::unsafe_weak_copy(w), + lhs_indices = array::unsafe_weak_copy(lhs_indices), + rhs_indices = array::unsafe_weak_copy(rhs_indices), + transpose_ = transpose_, + kquant_type = kquant_type_]() mutable { + int K = x.shape(-1); + int M = x.shape(-2); + int N = out.shape(-1); + int w_els = w.shape(-1) * w.shape(-2); + auto lhs_ptr = lhs_indices.data(); + auto rhs_ptr = rhs_indices.data(); + + auto gather_loop = [&](auto* tag) { + using T = std::remove_pointer_t; + for (int i = 0; i < lhs_indices.size(); i++) { + int x_idx = lhs_ptr[elem_to_loc( + i, lhs_indices.shape(), lhs_indices.strides())]; + int w_idx = rhs_ptr[elem_to_loc( + i, rhs_indices.shape(), rhs_indices.strides())]; + kquant_qmm_cpu( + out.data() + i * M * N, + x.data() + elem_to_loc(x_idx * M * K, x.shape(), x.strides()), + w.data() + + elem_to_loc(w_idx * w_els, w.shape(), w.strides()), + M, + N, + K, + transpose_, + kquant_type); + } + }; + switch (x.dtype()) { + case float32: + gather_loop(static_cast(nullptr)); + break; + case float16: + gather_loop(static_cast(nullptr)); + break; + case bfloat16: + gather_loop(static_cast(nullptr)); + break; + default: + throw std::invalid_argument( + "[quantized_matmul] only floating types are supported"); + } + }); + return; + } auto& x_pre = inputs[0]; auto& w_pre = inputs[1]; auto& scales_pre = inputs[2]; @@ -1226,9 +1373,458 @@ void dispatch_quantize( w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size()); } +namespace { + +inline float read_f16(const uint8_t* ptr) { + _Float16 tmp; + std::memcpy(&tmp, ptr, sizeof(_Float16)); + return static_cast(tmp); +} + +template +void kquant_dequantize_q8_0(const uint8_t* w, T* out, size_t num_weights) { + constexpr int block_weights = 32; + constexpr int block_bytes = 34; + size_t num_blocks = num_weights / block_weights; + for (size_t b = 0; b < num_blocks; b++) { + const uint8_t* block = w + b * block_bytes; + float d = read_f16(block); + const int8_t* qs = reinterpret_cast(block + 2); + T* dst = out + b * block_weights; + for (int i = 0; i < block_weights; i++) { + dst[i] = static_cast(d * static_cast(qs[i])); + } + } +} + +template +void kquant_dequantize_q4_0(const uint8_t* w, T* out, size_t num_weights) { + constexpr int block_weights = 32; + constexpr int block_bytes = 18; + size_t num_blocks = num_weights / block_weights; + for (size_t b = 0; b < num_blocks; b++) { + const uint8_t* block = w + b * block_bytes; + float d = read_f16(block); + const uint8_t* qs = block + 2; + T* dst = out + b * block_weights; + for (int j = 0; j < 16; j++) { + int x0 = (qs[j] & 0x0F) - 8; + int x1 = (qs[j] >> 4) - 8; + dst[j] = static_cast(d * static_cast(x0)); + dst[j + 16] = static_cast(d * static_cast(x1)); + } + } +} + +template +void kquant_dequantize_q4_1(const uint8_t* w, T* out, size_t num_weights) { + constexpr int block_weights = 32; + constexpr int block_bytes = 20; + size_t num_blocks = num_weights / block_weights; + for (size_t b = 0; b < num_blocks; b++) { + const uint8_t* block = w + b * block_bytes; + float d = read_f16(block); + float m = read_f16(block + 2); + const uint8_t* qs = block + 4; + T* dst = out + b * block_weights; + for (int j = 0; j < 16; j++) { + int x0 = qs[j] & 0x0F; + int x1 = qs[j] >> 4; + dst[j] = static_cast(d * static_cast(x0) + m); + dst[j + 16] = static_cast(d * static_cast(x1) + m); + } + } +} + +template +void kquant_dequantize_q5_0(const uint8_t* w, T* out, size_t num_weights) { + constexpr int block_weights = 32; + constexpr int block_bytes = 22; + size_t num_blocks = num_weights / block_weights; + for (size_t b = 0; b < num_blocks; b++) { + const uint8_t* block = w + b * block_bytes; + float d = read_f16(block); + const uint8_t* qh_bytes = block + 2; + uint32_t qh = static_cast(qh_bytes[0]) | + (static_cast(qh_bytes[1]) << 8) | + (static_cast(qh_bytes[2]) << 16) | + (static_cast(qh_bytes[3]) << 24); + const uint8_t* qs = block + 6; + T* dst = out + b * block_weights; + for (int j = 0; j < 16; j++) { + int xh_0 = ((qh >> j) << 4) & 0x10; + int xh_1 = (qh >> (j + 12)) & 0x10; + int x0 = (qs[j] & 0x0F) | xh_0; + int x1 = (qs[j] >> 4) | xh_1; + dst[j] = static_cast(d * static_cast(x0 - 16)); + dst[j + 16] = static_cast(d * static_cast(x1 - 16)); + } + } +} + +template +void kquant_dequantize_q5_1(const uint8_t* w, T* out, size_t num_weights) { + constexpr int block_weights = 32; + constexpr int block_bytes = 24; + size_t num_blocks = num_weights / block_weights; + for (size_t b = 0; b < num_blocks; b++) { + const uint8_t* block = w + b * block_bytes; + float d = read_f16(block); + float m = read_f16(block + 2); + const uint8_t* qh_bytes = block + 4; + const uint8_t* qs = block + 8; + uint32_t qh; + std::memcpy(&qh, qh_bytes, 4); + T* dst = out + b * block_weights; + for (int j = 0; j < 16; j++) { + uint8_t xh_0 = ((qh >> j) << 4) & 0x10; + uint8_t xh_1 = ((qh >> (j + 12))) & 0x10; + uint8_t x0 = (qs[j] & 0x0F) | xh_0; + uint8_t x1 = (qs[j] >> 4) | xh_1; + dst[j] = static_cast(d * static_cast(x0) + m); + dst[j + 16] = static_cast(d * static_cast(x1) + m); + } + } +} + +inline void kquant_unpack_q4k_scales( + const uint8_t* scales_packed, + float* sc, + float* mn, + float d, + float dmin) { + for (int i = 0; i < 8; i++) { + uint8_t raw_sc, raw_m; + if (i < 4) { + raw_sc = scales_packed[i] & 0x3F; + raw_m = scales_packed[i + 4] & 0x3F; + } else { + raw_sc = + (scales_packed[i + 4] & 0x0F) | ((scales_packed[i - 4] >> 6) << 4); + raw_m = (scales_packed[i + 4] >> 4) | ((scales_packed[i] >> 6) << 4); + } + sc[i] = d * static_cast(raw_sc); + mn[i] = dmin * static_cast(raw_m); + } +} + +template +void kquant_dequantize_q4_k(const uint8_t* w, T* out, size_t num_weights) { + constexpr int block_weights = 256; + constexpr int block_bytes = 144; + size_t num_blocks = num_weights / block_weights; + for (size_t b = 0; b < num_blocks; b++) { + const uint8_t* block = w + b * block_bytes; + float d = read_f16(block); + float dmin = read_f16(block + 2); + const uint8_t* scales_packed = block + 4; + const uint8_t* qs = block + 16; + + float sc[8], mn[8]; + kquant_unpack_q4k_scales(scales_packed, sc, mn, d, dmin); + + T* dst = out + b * block_weights; + for (int g = 0; g < 4; g++) { + for (int i = 0; i < 32; i++) { + dst[(2 * g) * 32 + i] = static_cast( + sc[2 * g] * static_cast(qs[g * 32 + i] & 0x0F) - mn[2 * g]); + dst[(2 * g + 1) * 32 + i] = static_cast( + sc[2 * g + 1] * static_cast(qs[g * 32 + i] >> 4) - + mn[2 * g + 1]); + } + } + } +} + +template +void kquant_dequantize_q5_k(const uint8_t* w, T* out, size_t num_weights) { + constexpr int block_weights = 256; + constexpr int block_bytes = 176; + size_t num_blocks = num_weights / block_weights; + for (size_t b = 0; b < num_blocks; b++) { + const uint8_t* block = w + b * block_bytes; + float d = read_f16(block); + float dmin = read_f16(block + 2); + const uint8_t* scales_packed = block + 4; + const uint8_t* qh = block + 16; + const uint8_t* qs = block + 48; + + float sc[8], mn[8]; + kquant_unpack_q4k_scales(scales_packed, sc, mn, d, dmin); + + T* dst = out + b * block_weights; + for (int g = 0; g < 4; g++) { + for (int i = 0; i < 32; i++) { + uint8_t lo0 = qs[g * 32 + i] & 0x0F; + uint8_t lo1 = qs[g * 32 + i] >> 4; + uint8_t hi0 = (qh[i] >> (2 * g)) & 1; + uint8_t hi1 = (qh[i] >> (2 * g + 1)) & 1; + dst[(2 * g) * 32 + i] = static_cast( + sc[2 * g] * static_cast(lo0 | (hi0 << 4)) - mn[2 * g]); + dst[(2 * g + 1) * 32 + i] = static_cast( + sc[2 * g + 1] * static_cast(lo1 | (hi1 << 4)) - + mn[2 * g + 1]); + } + } + } +} + +template +void kquant_dequantize_q6_k(const uint8_t* w, T* out, size_t num_weights) { + constexpr int block_weights = 256; + constexpr int block_bytes = 210; + size_t num_blocks = num_weights / block_weights; + for (size_t b = 0; b < num_blocks; b++) { + const uint8_t* block = w + b * block_bytes; + const uint8_t* ql_base = block; + const uint8_t* qh_base = block + 128; + const int8_t* scales = reinterpret_cast(block + 192); + float d = read_f16(block + 208); + + T* dst = out + b * block_weights; + for (int half = 0; half < 2; half++) { + const uint8_t* ql = ql_base + half * 64; + const uint8_t* qh = qh_base + half * 32; + const int8_t* sc = scales + half * 8; + T* out_half = dst + half * 128; + + for (int l = 0; l < 32; l++) { + int is0 = l / 16; + int8_t q1 = + static_cast((ql[l] & 0x0F) | (((qh[l] >> 0) & 3) << 4)) - + 32; + int8_t q2 = static_cast( + (ql[l + 32] & 0x0F) | (((qh[l] >> 2) & 3) << 4)) - + 32; + int8_t q3 = + static_cast((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + int8_t q4 = + static_cast((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - + 32; + out_half[l] = static_cast( + d * static_cast(sc[is0]) * static_cast(q1)); + out_half[l + 32] = static_cast( + d * static_cast(sc[is0 + 2]) * static_cast(q2)); + out_half[l + 64] = static_cast( + d * static_cast(sc[is0 + 4]) * static_cast(q3)); + out_half[l + 96] = static_cast( + d * static_cast(sc[is0 + 6]) * static_cast(q4)); + } + } + } +} + +inline void kquant_unpack_q3k_scales(const uint8_t* s, int32_t* sc) { + for (int k = 0; k < 4; k++) { + sc[k] = static_cast(s[k] & 0x0F) | + (static_cast((s[8 + k]) & 0x03) << 4); + sc[k + 4] = static_cast(s[k + 4] & 0x0F) | + (static_cast((s[8 + k] >> 2) & 0x03) << 4); + sc[k + 8] = static_cast((s[k] >> 4) & 0x0F) | + (static_cast((s[8 + k] >> 4) & 0x03) << 4); + sc[k + 12] = static_cast((s[k + 4] >> 4) & 0x0F) | + (static_cast((s[8 + k] >> 6) & 0x03) << 4); + } + for (int i = 0; i < 16; i++) { + sc[i] -= 32; + } +} + +template +void kquant_dequantize_q3_k(const uint8_t* w, T* out, size_t num_weights) { + constexpr int block_weights = 256; + constexpr int block_bytes = 110; + size_t num_blocks = num_weights / block_weights; + for (size_t b = 0; b < num_blocks; b++) { + const uint8_t* block = w + b * block_bytes; + const uint8_t* hmask = block; + const uint8_t* qs_full = block + 32; + const uint8_t* scales_packed = block + 96; + float d = read_f16(block + 108); + + int32_t sc[16]; + kquant_unpack_q3k_scales(scales_packed, sc); + + T* dst = out + b * block_weights; + int out_idx = 0; + for (int outer_half = 0; outer_half < 2; outer_half++) { + const uint8_t* qs_chunk = qs_full + outer_half * 32; + for (int shift_idx = 0; shift_idx < 4; shift_idx++) { + int shift = shift_idx * 2; + uint8_t m = 1 << (outer_half * 4 + shift_idx); + int is_left = outer_half * 8 + shift_idx * 2; + float dl_left = d * static_cast(sc[is_left]); + for (int l = 0; l < 16; l++) { + int q2 = (qs_chunk[l] >> shift) & 3; + int h = (hmask[l] & m) ? 0 : 4; + dst[out_idx++] = static_cast(dl_left * static_cast(q2 - h)); + } + float dl_right = d * static_cast(sc[is_left + 1]); + for (int l = 0; l < 16; l++) { + int q2 = (qs_chunk[l + 16] >> shift) & 3; + int h = (hmask[l + 16] & m) ? 0 : 4; + dst[out_idx++] = + static_cast(dl_right * static_cast(q2 - h)); + } + } + } + } +} + +template +void kquant_dequantize_q2_k(const uint8_t* w, T* out, size_t num_weights) { + constexpr int block_weights = 256; + constexpr int block_bytes = 84; + size_t num_blocks = num_weights / block_weights; + for (size_t b = 0; b < num_blocks; b++) { + const uint8_t* block = w + b * block_bytes; + const uint8_t* scales_raw = block; + const uint8_t* qs_full = block + 16; + float d = read_f16(block + 80); + float dmin = read_f16(block + 82); + + T* dst = out + b * block_weights; + int out_idx = 0; + int is_idx = 0; + for (int outer_half = 0; outer_half < 2; outer_half++) { + const uint8_t* qs_chunk = qs_full + outer_half * 32; + for (int shift_idx = 0; shift_idx < 4; shift_idx++) { + int shift = shift_idx * 2; + uint8_t sc_byte_left = scales_raw[is_idx++]; + float dl_left = d * static_cast(sc_byte_left & 0x0F); + float ml_left = dmin * static_cast(sc_byte_left >> 4); + for (int l = 0; l < 16; l++) { + int q2 = (qs_chunk[l] >> shift) & 3; + dst[out_idx++] = + static_cast(dl_left * static_cast(q2) - ml_left); + } + uint8_t sc_byte_right = scales_raw[is_idx++]; + float dl_right = d * static_cast(sc_byte_right & 0x0F); + float ml_right = dmin * static_cast(sc_byte_right >> 4); + for (int l = 0; l < 16; l++) { + int q2 = (qs_chunk[l + 16] >> shift) & 3; + dst[out_idx++] = + static_cast(dl_right * static_cast(q2) - ml_right); + } + } + } + } +} + +template +void kquant_dequantize_dispatch( + const uint8_t* w, + T* out, + size_t num_weights, + const std::string& kquant_type) { + if (kquant_type == "q8_0") { + kquant_dequantize_q8_0(w, out, num_weights); + } else if (kquant_type == "q4_0") { + kquant_dequantize_q4_0(w, out, num_weights); + } else if (kquant_type == "q4_1") { + kquant_dequantize_q4_1(w, out, num_weights); + } else if (kquant_type == "q5_0") { + kquant_dequantize_q5_0(w, out, num_weights); + } else if (kquant_type == "q5_1") { + kquant_dequantize_q5_1(w, out, num_weights); + } else if (kquant_type == "q4_k") { + kquant_dequantize_q4_k(w, out, num_weights); + } else if (kquant_type == "q5_k") { + kquant_dequantize_q5_k(w, out, num_weights); + } else if (kquant_type == "q6_k") { + kquant_dequantize_q6_k(w, out, num_weights); + } else if (kquant_type == "q3_k") { + kquant_dequantize_q3_k(w, out, num_weights); + } else if (kquant_type == "q2_k") { + kquant_dequantize_q2_k(w, out, num_weights); + } else { + throw std::runtime_error( + "[kquant_dequantize] Unsupported codec: " + kquant_type); + } +} + +template +void kquant_qmm_cpu( + T* result, + const T* x, + const uint8_t* w, + int M, + int N, + int K, + bool transpose_w, + const std::string& kquant_type) { + const auto* codec = kquant_codec_by_name(kquant_type); + int w_rows = transpose_w ? N : K; + int w_cols = transpose_w ? K : N; + size_t weights_per_row = static_cast(w_cols); + size_t row_bytes = + (weights_per_row / codec->weights_per_block) * codec->bytes_per_block; + + std::vector w_dec(static_cast(w_rows) * w_cols); + for (int r = 0; r < w_rows; r++) { + kquant_dequantize_dispatch( + w + r * row_bytes, w_dec.data() + r * w_cols, w_cols, kquant_type); + } + + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + float acc = 0.0f; + if (transpose_w) { + for (int k = 0; k < K; k++) { + acc += static_cast(x[m * K + k]) * w_dec[n * K + k]; + } + } else { + for (int k = 0; k < K; k++) { + acc += static_cast(x[m * K + k]) * w_dec[k * N + n]; + } + } + result[m * N + n] = static_cast(acc); + } + } +} + +} // namespace + void fast::Quantize::eval_cpu( const std::vector& inputs, std::vector& outputs) { + if (mode_ == QuantizationMode::KQuant) { + if (!dequantize_) { + throw std::runtime_error( + "[fast::Quantize::eval_cpu] KQuant encode is GPU-only."); + } + auto& encoder = cpu::get_command_encoder(stream()); + auto w = ensure_row_contiguous(inputs[0], encoder, stream()); + auto& out = outputs[0]; + out.set_data(allocator::malloc(out.nbytes())); + encoder.set_input_array(w); + encoder.set_output_array(out); + size_t num_weights = out.size(); + encoder.dispatch([w = array::unsafe_weak_copy(w), + out = array::unsafe_weak_copy(out), + num_weights, + kquant_type = kquant_type_]() mutable { + auto w_ptr = w.data(); + switch (out.dtype()) { + case float32: + kquant_dequantize_dispatch( + w_ptr, out.data(), num_weights, kquant_type); + break; + case float16: + kquant_dequantize_dispatch( + w_ptr, out.data(), num_weights, kquant_type); + break; + case bfloat16: + kquant_dequantize_dispatch( + w_ptr, out.data(), num_weights, kquant_type); + break; + default: + throw std::runtime_error( + "[fast::Quantize::eval_cpu] KQuant dequantize only supports float types."); + } + }); + return; + } auto& encoder = cpu::get_command_encoder(stream()); auto w = ensure_row_contiguous(inputs[0], encoder, stream()); auto& out = outputs[0]; diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index a4e019d662..3d34ad7b87 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -71,6 +71,10 @@ GemmScalars create_nvfp4_scalars( void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("QQMatmul::eval_gpu"); + if (mode_ == QuantizationMode::KQuant) { + throw std::runtime_error( + "[QQMatmul::eval_gpu] KQuant CUDA not implemented."); + } auto& s = stream(); auto& encoder = cu::get_command_encoder(s); diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 645b24cca6..b4e50fb465 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -14,6 +14,10 @@ namespace mlx::core { void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("QuantizedMatmul::eval_gpu"); + if (mode_ == QuantizationMode::KQuant) { + throw std::runtime_error( + "[QuantizedMatmul::eval_gpu] KQuant CUDA not implemented."); + } auto& s = stream(); auto& encoder = cu::get_command_encoder(s); @@ -143,6 +147,10 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("GatherQMM::eval_gpu"); + if (mode_ == QuantizationMode::KQuant) { + throw std::runtime_error( + "[GatherQMM::eval_gpu] KQuant CUDA not implemented."); + } auto& s = stream(); auto& encoder = cu::get_command_encoder(s); @@ -270,6 +278,10 @@ void fast::Quantize::eval_gpu( const std::vector& inputs, std::vector& outputs) { nvtx3::scoped_range r("Quantize::eval_gpu"); + if (mode_ == QuantizationMode::KQuant) { + throw std::runtime_error( + "[fast::Quantize::eval_gpu] KQuant encode/decode CUDA is NYI."); + } auto& s = stream(); auto& enc = cu::get_command_encoder(s); if (dequantize_) { diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index bbb08137f6..4a74f9d0b9 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -79,6 +79,10 @@ if(MLX_METAL_JIT) make_jit_source(quantized kernels/quantized_utils.h) make_jit_source(fp_quantized kernels/quantized_utils.h kernels/fp8.h kernels/fp4.h) + make_jit_source(kq_quantized kernels/quantized_utils.h + kernels/kq_quantized_legacy.h) + make_jit_source(kq_quantized_encode kernels/quantized_utils.h + kernels/kq_quantized.h kernels/kq_quantized_legacy.h) make_jit_source(gemv_masked) make_jit_source(steel/attn/kernels/steel_attention) @@ -94,6 +98,8 @@ if(MLX_METAL_JIT) make_jit_source(quantized_nax kernels/quantized_utils.h) make_jit_source(fp_quantized_nax kernels/quantized_utils.h kernels/fp8.h kernels/fp4.h) + make_jit_source(kq_quantized_nax kernels/quantized_utils.h + kernels/kq_quantized.h kernels/kq_quantized_legacy.h) make_jit_source(steel/attn/kernels/steel_attention_nax) diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index e22efa96d0..6e6b41b93c 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -26,6 +26,8 @@ const char* logsumexp(); const char* quantized_utils(); const char* quantized(); const char* fp_quantized(); +const char* kq_quantized(); +const char* kq_quantized_encode(); const char* ternary(); const char* scan(); const char* scatter_axis(); @@ -54,6 +56,7 @@ const char* steel_gemm_segmented_nax(); const char* quantized_nax(); const char* fp_quantized_nax(); +const char* kq_quantized_nax(); const char* steel_attention_nax(); diff --git a/mlx/backend/metal/jit_kernels.cpp b/mlx/backend/metal/jit_kernels.cpp index 9c47b53b40..76f2e021ca 100644 --- a/mlx/backend/metal/jit_kernels.cpp +++ b/mlx/backend/metal/jit_kernels.cpp @@ -849,16 +849,26 @@ MTL::ComputePipelineState* get_quantized_kernel( metal::Device& d, const std::string& kernel_name, const std::string& template_def, - const std::string& mode) { + const std::string& mode, + bool is_encode) { const auto& lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { std::string kernel_source; + const char* family_source; + if (mode == "affine") { + family_source = metal::quantized(); + } else if (mode == "kquant") { + family_source = + is_encode ? metal::kq_quantized_encode() : metal::kq_quantized(); + } else { + family_source = metal::fp_quantized(); + } concatenate( kernel_source, metal::utils(), metal::gemm(), metal::quantized_utils(), - (mode == "affine") ? metal::quantized() : metal::fp_quantized(), + family_source, template_def); return kernel_source; }); @@ -885,6 +895,11 @@ MTL::ComputePipelineState* get_gather_qmm_kernel( std::string kernel_source; concatenate( kernel_source, metal::utils(), metal::quantized_utils(), metal::gemm()); + if (mode == "kquant") { + throw std::runtime_error( + "[GatherQMM] KQuant gather uses the NAX path; " + "non-NAX JIT should be unreachable."); + } bool is_affine = mode == "affine"; concatenate( kernel_source, @@ -1059,7 +1074,9 @@ MTL::ComputePipelineState* get_qmm_nax_kernel( metal::utils(), metal::gemm_nax(), metal::quantized_utils(), - (mode == "affine") ? metal::quantized_nax() : metal::fp_quantized_nax(), + (mode == "affine") ? metal::quantized_nax() + : (mode == "kquant") ? metal::kq_quantized_nax() + : metal::fp_quantized_nax(), template_def); return kernel_source; }); @@ -1075,6 +1092,7 @@ MTL::ComputePipelineState* get_gather_qmm_nax_kernel( int group_size, int bits, const std::string& mode, + const std::string& func_name, int bm, int bn, int bk, @@ -1089,13 +1107,14 @@ MTL::ComputePipelineState* get_gather_qmm_nax_kernel( metal::utils(), metal::gemm_nax(), metal::quantized_utils()); - bool is_affine = mode == "affine"; concatenate( kernel_source, - is_affine ? metal::quantized_nax() : metal::fp_quantized_nax(), + (mode == "affine") ? metal::quantized_nax() + : (mode == "kquant") ? metal::kq_quantized_nax() + : metal::fp_quantized_nax(), get_template_definition( lib_name, - (is_affine ? "affine" : "fp") + std::string("_gather_qmm_rhs_nax"), + func_name, get_type_string(x.dtype()), group_size, bits, diff --git a/mlx/backend/metal/kernels.h b/mlx/backend/metal/kernels.h index dc0dab970d..846c589ee6 100644 --- a/mlx/backend/metal/kernels.h +++ b/mlx/backend/metal/kernels.h @@ -250,7 +250,8 @@ MTL::ComputePipelineState* get_quantized_kernel( metal::Device& d, const std::string& kernel_name, const std::string& template_def, - const std::string& mode); + const std::string& mode, + bool is_encode = false); MTL::ComputePipelineState* get_gather_qmm_kernel( metal::Device& d, @@ -340,6 +341,7 @@ MTL::ComputePipelineState* get_gather_qmm_nax_kernel( int group_size, int bits, const std::string& mode, + const std::string& func_name, int bm, int bn, int bk, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 4b13e3ec57..0892ec7253 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -101,6 +101,7 @@ set(STEEL_ATTN_HEADERS set(STEEL_NAX_HEADERS steel/defines.h steel/utils.h + steel/gemm/loader.h steel/gemm/params.h steel/gemm/transforms.h steel/gemm/nax.h @@ -138,6 +139,10 @@ if(NOT MLX_METAL_JIT) build_kernel(quantized quantized.h quantized_utils.h ${STEEL_HEADERS}) build_kernel(fp_quantized fp4.h fp8.h fp_quantized.h quantized_utils.h ${STEEL_HEADERS}) + build_kernel(kq_quantized kq_quantized.h kq_quantized_legacy.h + quantized_utils.h ${STEEL_HEADERS}) + build_kernel(kq_quantized_encode kq_quantized.h kq_quantized_legacy.h + kq_quantized_encode.h quantized_utils.h ${STEEL_HEADERS}) build_kernel(scan scan.h) build_kernel(softmax softmax.h) build_kernel(logsumexp logsumexp.h) @@ -167,6 +172,8 @@ if(NOT MLX_METAL_JIT) build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS}) build_kernel(fp_quantized_nax fp4.h fp8.h fp_quantized_nax.h ${STEEL_NAX_HEADERS}) + build_kernel(kq_quantized_nax kq_quantized.h kq_quantized_legacy.h + kq_quantized_nax.h quantized_utils.h ${STEEL_NAX_HEADERS}) build_kernel(steel/attn/kernels/steel_attention_nax ${STEEL_NAX_ATTN_HEADERS}) diff --git a/mlx/backend/metal/kernels/fp_quantized.h b/mlx/backend/metal/kernels/fp_quantized.h index f4bf438df2..23a0340cb2 100644 --- a/mlx/backend/metal/kernels/fp_quantized.h +++ b/mlx/backend/metal/kernels/fp_quantized.h @@ -59,24 +59,8 @@ struct Dequantize { } }; -template -inline void load_vector(const device T* x, thread U* x_thread) { -#pragma unroll - for (int i = 0; i < values_per_thread; i++) { - x_thread[i] = x[i]; - } -} - -template -inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { - for (int i = 0; i < N; i++) { - x_thread[i] = x[i]; - } - - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; - } -} +// load_vector / load_vector_safe live in quantized_utils.h (shared with +// kq_quantized.h). template inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) { diff --git a/mlx/backend/metal/kernels/kq_quantized.h b/mlx/backend/metal/kernels/kq_quantized.h new file mode 100644 index 0000000000..54c7d6a92b --- /dev/null +++ b/mlx/backend/metal/kernels/kq_quantized.h @@ -0,0 +1,5224 @@ +// Copyright © 2026 Apple Inc. + +#include +#include + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; + +struct kq_empty {}; + +template < + typename T, + typename LoaderW, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void kq_qmm_t_impl( + const device uint8_t* w, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + const int K_eff, + uint3 tid, + uint lid, + uint simd_gid, + uint simd_lid) { + static_assert(BK >= SIMD_SIZE, "BK should be >= SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be a multiple of SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + /*transpose_a=*/false, + /*transpose_b=*/true, + BK_padded, + BK_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + + const int K_w = (K / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = w; + + x += y_row * static_cast(K); + wl += static_cast(y_col) * K_w; + y += y_row * static_cast(N) + y_col; + + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + LoaderW loader_w(wl, K, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K_eff; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K_eff; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K_eff; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K_eff; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + typename LoaderW, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void kq_qmm_n_impl( + const device uint8_t* w, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid, + uint lid, + uint simd_gid, + uint simd_lid) { + static_assert(BK >= SIMD_SIZE, "BK should be >= SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be a multiple of SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + /*transpose_a=*/false, + /*transpose_b=*/false, + BK_padded, + BN_padded>; + using loader_x_t = mlx::steel:: + BlockLoader; + + auto wl = w; + + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += (y_col / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += y_row * static_cast(N) + y_col; + + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + LoaderW loader_w( + wl, N, Ws, simd_gid, simd_lid, y_col % LoaderW::weights_per_block); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void kq_adjust_matrix_offsets( + const device T*& x, + const device uint8_t*& w, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + } else { + w += elem_to_loc(w_idx, w_shape, w_strides, w_batch_ndims); + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void kq_adjust_matrix_offsets( + const device T*& x, + const device uint8_t*& w, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + } else { + w += elem_to_loc(w_idx, w_shape, w_strides, w_batch_ndims); + } + y += tid.z * output_stride; +} + +// Q8_0: 34 bytes/32 weights. [fp16 d][int8 q[32]]. w[i] = d * q[i]. + +MLX_MTL_CONST int KQ_Q8_0_GROUP = 32; +MLX_MTL_CONST int KQ_Q8_0_BLOCK_BYTES = 34; +MLX_MTL_CONST int KQ_Q8_0_D_OFFSET = 0; +MLX_MTL_CONST int KQ_Q8_0_Q_OFFSET = 2; + +inline float kq_q8_0_d(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q8_0_D_OFFSET)); +} + +inline const device int8_t* kq_q8_0_q_ptr(const device uint8_t* block_addr) { + return (const device int8_t*)(block_addr + KQ_Q8_0_Q_OFFSET); +} + +template +METAL_FUNC void kq_q8_0_dequantize_impl( + const device uint8_t* w, + device T* out, + const constant uint& num_weights, + uint gid) { + if (gid >= num_weights) { + return; + } + const int block_id = gid / KQ_Q8_0_GROUP; + const int within = gid % KQ_Q8_0_GROUP; + const device uint8_t* block_addr = w + block_id * KQ_Q8_0_BLOCK_BYTES; + const float d = kq_q8_0_d(block_addr); + const int8_t q = kq_q8_0_q_ptr(block_addr)[within]; + out[gid] = T(d * float(q)); +} + +template +METAL_FUNC void kq_q8_0_qmv_fast_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q8_0_GROUP, "Q8_0 kernel requires group_size=32"); + static_assert(bits == 8, "Q8_0 kernel requires bits=8"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int values_per_thread = 8; + constexpr int block_size = values_per_thread * SIMD_SIZE; + + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q8_0_BLOCK_BYTES / KQ_Q8_0_GROUP; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + const int lane_k_offset = simd_lid * values_per_thread; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int k = 0; k < in_vec_size; k += block_size) { + load_vector(x + k + lane_k_offset, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + const int row_idx = out_row + row; + const device uint8_t* row_base = w + row_idx * row_bytes; + + const int k_global = k + lane_k_offset; + const int block_id = k_global / KQ_Q8_0_GROUP; + const int within = k_global - block_id * KQ_Q8_0_GROUP; + const device uint8_t* block_addr = + row_base + block_id * KQ_Q8_0_BLOCK_BYTES; + + const U d = U(kq_q8_0_d(block_addr)); + const device int8_t* q_ptr = kq_q8_0_q_ptr(block_addr) + within; + + U partial = 0; +#pragma unroll + for (int i = 0; i < values_per_thread; i++) { + partial += x_thread[i] * U(q_ptr[i]); + } + result[row] += d * partial; + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void kq_q8_0_qmv_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q8_0_GROUP, "Q8_0 kernel requires group_size=32"); + static_assert(bits == 8, "Q8_0 kernel requires bits=8"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int values_per_thread = 8; + constexpr int block_size = values_per_thread * SIMD_SIZE; + + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q8_0_BLOCK_BYTES / KQ_Q8_0_GROUP; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + if (out_row >= out_vec_size) { + return; + } + const int max_row = min(out_vec_size, out_row + results_per_simdgroup); + const int active_rows = max_row - out_row; + + const int lane_k_offset = simd_lid * values_per_thread; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int k = 0; k < in_vec_size; k += block_size) { + const int k_remaining = in_vec_size - k - lane_k_offset; + if (k_remaining >= values_per_thread) { + load_vector(x + k + lane_k_offset, x_thread); + } else if (k_remaining > 0) { + load_vector_safe( + x + k + lane_k_offset, x_thread, k_remaining); + } else { +#pragma unroll + for (int i = 0; i < values_per_thread; i++) { + x_thread[i] = 0; + } + } + + const int n_inner = k_remaining >= values_per_thread + ? values_per_thread + : (k_remaining > 0 ? k_remaining : 0); + + if (n_inner == 0) { + continue; + } + + const int k_global = k + lane_k_offset; + const int block_id = k_global / KQ_Q8_0_GROUP; + const int within = k_global - block_id * KQ_Q8_0_GROUP; + + for (int row = 0; row < active_rows; row++) { + const int row_idx = out_row + row; + const device uint8_t* row_base = w + row_idx * row_bytes; + const device uint8_t* block_addr = + row_base + block_id * KQ_Q8_0_BLOCK_BYTES; + + const U d = U(kq_q8_0_d(block_addr)); + const device int8_t* q_ptr = kq_q8_0_q_ptr(block_addr) + within; + + U partial = 0; +#pragma unroll + for (int i = 0; i < values_per_thread; i++) { + if (i < n_inner) { + partial += x_thread[i] * U(q_ptr[i]); + } + } + result[row] += d * partial; + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0 && row < active_rows) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqQ8_0BlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q8_0_GROUP; + MLX_MTL_CONST int bytes_per_block = KQ_Q8_0_BLOCK_BYTES; + + static_assert( + BCOLS == weights_per_block, + "Q8_0 loader requires BCOLS == 32 (one block per K-tile)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + + const int src_ld; + const int row_bytes; + const int tile_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + + KqQ8_0BlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int /* col_in_block */ = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? bytes_per_block + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)) {} + + void load_unsafe() const { + const float d = float(*(const device half*)src); + const device int8_t* q = + (const device int8_t*)(src + KQ_Q8_0_Q_OFFSET + bj); +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(d * float(q[i])); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + const float d = float(*(const device half*)src); + const device int8_t* q = + (const device int8_t*)(src + KQ_Q8_0_Q_OFFSET + bj); +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(d * float(q[i])); + } + } + + void next() { + src += tile_stride; + } +}; + +template +[[kernel]] void kq_q8_0_qmm_t( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q8_0_GROUP, "Q8_0 kernel requires group_size=32"); + static_assert(bits == 8, "Q8_0 kernel requires bits=8"); + constexpr int BM = 64, BK = 32, BN = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ8_0BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_t_impl( + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q8_0_qmm_t_splitk( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& k_partition_size, + const constant int& split_k_partition_stride, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q8_0_GROUP, "Q8_0 kernel requires group_size=32"); + static_assert(bits == 8, "Q8_0 kernel requires bits=8"); + constexpr int BM = 32, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ8_0BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + + const int k_start = tid.z * k_partition_size; + x += k_start; + auto wl = w; + wl += (k_start / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += tid.z * static_cast(split_k_partition_stride); + + kq_qmm_t_impl( + wl, + x, + y, + Xs, + Ws, + K, + N, + M, + k_partition_size, + tid, + lid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void kq_q8_0_qmm_n( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q8_0_GROUP, "Q8_0 kernel requires group_size=32"); + static_assert(bits == 8, "Q8_0 kernel requires bits=8"); + constexpr int BM = 64, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + using LoaderW = KqQ8_0BlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_n_impl( + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q8_0_qmv_fast( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q8_0_qmv_fast_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q8_0_qmv( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q8_0_qmv_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q8_0_dequantize( + const device uint8_t* w, + const device uint8_t* /* scales */, + device T* out, + const constant uint& num_weights, + uint gid [[thread_position_in_grid]]) { + static_assert( + group_size == KQ_Q8_0_GROUP, "Q8_0 kernel requires group_size=32"); + static_assert(bits == 8, "Q8_0 kernel requires bits=8"); + kq_q8_0_dequantize_impl(w, out, num_weights, gid); +} + +#include "mlx/backend/metal/kernels/kq_quantized_legacy.h" + +// Q5_1: 24 bytes/32 weights. [fp16 d][fp16 m][uint32 qh][uint8 qs[16]]. +// q5 = (low4 | high_bit<<4); w[i] = d * q5[i] + m. + +MLX_MTL_CONST int KQ_Q5_1_GROUP = 32; +MLX_MTL_CONST int KQ_Q5_1_BLOCK_BYTES = 24; +MLX_MTL_CONST int KQ_Q5_1_D_OFFSET = 0; +MLX_MTL_CONST int KQ_Q5_1_M_OFFSET = 2; +MLX_MTL_CONST int KQ_Q5_1_QH_OFFSET = 4; +MLX_MTL_CONST int KQ_Q5_1_QS_OFFSET = 8; + +inline float kq_q5_1_d(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q5_1_D_OFFSET)); +} +inline float kq_q5_1_m(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q5_1_M_OFFSET)); +} +inline uint32_t kq_q5_1_qh(const device uint8_t* block_addr) { + return *(const device uint32_t*)(block_addr + KQ_Q5_1_QH_OFFSET); +} +inline const device uint8_t* kq_q5_1_qs_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q5_1_QS_OFFSET; +} + +template +METAL_FUNC void kq_q5_1_dequantize_impl( + const device uint8_t* w, + device T* out, + const constant uint& num_weights, + uint gid) { + if (gid >= num_weights) { + return; + } + const int block_id = gid / KQ_Q5_1_GROUP; + const int within = gid % KQ_Q5_1_GROUP; + const device uint8_t* block_addr = w + block_id * KQ_Q5_1_BLOCK_BYTES; + const float d = kq_q5_1_d(block_addr); + const float m = kq_q5_1_m(block_addr); + const uint32_t qh = kq_q5_1_qh(block_addr); + const device uint8_t* qs = kq_q5_1_qs_ptr(block_addr); + const uint32_t hi = ((qh >> within) << 4) & 0x10u; + const uint8_t lo = + (within < 16) ? (qs[within] & 0x0Fu) : (qs[within - 16] >> 4); + const float q5 = float(uint32_t(lo) | hi); + out[gid] = T(d * q5 + m); +} + +template +METAL_FUNC void kq_q5_1_qmv_fast_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q5_1_GROUP, "Q5_1 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_1 kernel requires bits=5"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int block_stride = 16; + + typedef float U; + thread U yl[16]; + thread U result[results_per_simdgroup] = {0}; + + const int ix = simd_lid / 2; + const int il = (simd_lid % 2) * 8; + + const int row_bytes = in_vec_size * KQ_Q5_1_BLOCK_BYTES / KQ_Q5_1_GROUP; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + const int nb = in_vec_size / KQ_Q5_1_GROUP; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += block_stride) { + const int x_base = ib * KQ_Q5_1_GROUP + il; + U sumy = U(0); +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const U a0 = U(x[x_base + i + 0]); + const U a1 = U(x[x_base + i + 1]); + const U b0 = U(x[x_base + i + 16]); + const U b1 = U(x[x_base + i + 17]); + sumy += a0 + a1 + b0 + b1; + yl[i + 0] = a0; + yl[i + 1] = a1 * (U(1) / U(256)); + yl[i + 8] = b0 * (U(1) / U(16)); + yl[i + 9] = b1 * (U(1) / U(4096)); + } + + for (int row = 0; row < results_per_simdgroup; row++) { + const int row_idx = out_row + row; + const device uint8_t* block_addr = + w + row_idx * row_bytes + ib * KQ_Q5_1_BLOCK_BYTES; + const U d = U(kq_q5_1_d(block_addr)); + const U m = U(kq_q5_1_m(block_addr)); + const uint32_t qh = kq_q5_1_qh(block_addr); + const device uint16_t* qs = + reinterpret_cast(kq_q5_1_qs_ptr(block_addr)) + + il / 2; + + U acc[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const uint16_t qi = qs[i / 2]; + acc[0] += yl[i + 0] * + U((qi & 0x000F) | (((qh >> (i + 0 + il)) << 4) & 0x00010)); + acc[1] += yl[i + 1] * + U((qi & 0x0F00) | (((qh >> (i + 1 + il)) << 12) & 0x01000)); + acc[2] += yl[i + 8] * + U((qi & 0x00F0) | (((qh >> (i + 0 + il + 16)) << 8) & 0x00100)); + acc[3] += yl[i + 9] * + U((qi & 0xF000) | (((qh >> (i + 1 + il + 16)) << 16) & 0x10000)); + } + result[row] += d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void kq_q5_1_qmv_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q5_1_GROUP, "Q5_1 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_1 kernel requires bits=5"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int block_stride = 16; + + typedef float U; + thread U yl[16]; + thread U result[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q5_1_BLOCK_BYTES / KQ_Q5_1_GROUP; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + if (out_row >= out_vec_size) { + return; + } + const int max_row = min(out_vec_size, out_row + results_per_simdgroup); + const int active_rows = max_row - out_row; + + const int ix = simd_lid / 2; + const int il = (simd_lid % 2) * 8; + + const int nb = in_vec_size / KQ_Q5_1_GROUP; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += block_stride) { + const int x_base = ib * KQ_Q5_1_GROUP + il; + U sumy = U(0); +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const U a0 = U(x[x_base + i + 0]); + const U a1 = U(x[x_base + i + 1]); + const U b0 = U(x[x_base + i + 16]); + const U b1 = U(x[x_base + i + 17]); + sumy += a0 + a1 + b0 + b1; + yl[i + 0] = a0; + yl[i + 1] = a1 * (U(1) / U(256)); + yl[i + 8] = b0 * (U(1) / U(16)); + yl[i + 9] = b1 * (U(1) / U(4096)); + } + + for (int row = 0; row < active_rows; row++) { + const int row_idx = out_row + row; + const device uint8_t* block_addr = + w + row_idx * row_bytes + ib * KQ_Q5_1_BLOCK_BYTES; + const U d = U(kq_q5_1_d(block_addr)); + const U m = U(kq_q5_1_m(block_addr)); + const uint32_t qh = kq_q5_1_qh(block_addr); + const device uint16_t* qs = + reinterpret_cast(kq_q5_1_qs_ptr(block_addr)) + + il / 2; + + U acc[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const uint16_t qi = qs[i / 2]; + acc[0] += yl[i + 0] * + U((qi & 0x000F) | (((qh >> (i + 0 + il)) << 4) & 0x00010)); + acc[1] += yl[i + 1] * + U((qi & 0x0F00) | (((qh >> (i + 1 + il)) << 12) & 0x01000)); + acc[2] += yl[i + 8] * + U((qi & 0x00F0) | (((qh >> (i + 0 + il + 16)) << 8) & 0x00100)); + acc[3] += yl[i + 9] * + U((qi & 0xF000) | (((qh >> (i + 1 + il + 16)) << 16) & 0x10000)); + } + result[row] += d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0 && row < active_rows) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqQ5_1BlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q5_1_GROUP; + MLX_MTL_CONST int bytes_per_block = KQ_Q5_1_BLOCK_BYTES; + + static_assert( + BCOLS == weights_per_block, + "Q5_1 loader requires BCOLS == 32 (one block per K-tile)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + MLX_MTL_CONST short bytes_per_thread = n_reads / 2; + MLX_MTL_CONST short half_block = weights_per_block / 2; + static_assert(n_reads >= 2 && n_reads % 2 == 0, "Q5_1 needs even n_reads."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + + const short thread_idx; + const short bi; + const short bj_byte; + + threadgroup T* dst; + const device uint8_t* src; + + KqQ5_1BlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int /* col_in_block */ = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? bytes_per_block + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj_byte((thread_idx % TCOLS) * bytes_per_thread), + dst(dst_ + bi * dst_ld + bj_byte), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)) {} + + void load_unsafe() const { + const float d = float(*(const device half*)(src + KQ_Q5_1_D_OFFSET)); + const float m = float(*(const device half*)(src + KQ_Q5_1_M_OFFSET)); + const uint32_t qh = *(const device uint32_t*)(src + KQ_Q5_1_QH_OFFSET); + const device uint8_t* qs = src + KQ_Q5_1_QS_OFFSET + bj_byte; + static_assert( + bytes_per_thread == 4 || bytes_per_thread == 8, + "Q5_1 ALU vector load supports bytes_per_thread=4 or 8 (uint)."); + uint8_t qs_b[bytes_per_thread]; +#pragma unroll + for (short v = 0; v < bytes_per_thread / 4; v++) { + const uint qs_v = *reinterpret_cast(qs + v * 4); + *reinterpret_cast(&qs_b[v * 4]) = qs_v; + } +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t b = qs_b[i]; + const int j_lo = bj_byte + i; + const int j_hi = bj_byte + half_block + i; + const uint32_t hi_lo = ((qh >> j_lo) << 4) & 0x10u; + const uint32_t hi_hi = ((qh >> j_hi) << 4) & 0x10u; + const float q5_lo = float(uint32_t(b & 0x0F) | hi_lo); + const float q5_hi = float(uint32_t(b >> 4) | hi_hi); + dst[i] = T(d * q5_lo + m); + dst[half_block + i] = T(d * q5_hi + m); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + dst[i] = T(0); + dst[half_block + i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + src += tile_stride; + } +}; + +template +[[kernel]] void kq_q5_1_qmm_t( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q5_1_GROUP, "Q5_1 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_1 kernel requires bits=5"); + constexpr int BM = 64, BK = 32, BN = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ5_1BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_t_impl( + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_1_qmm_t_splitk( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& k_partition_size, + const constant int& split_k_partition_stride, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q5_1_GROUP, "Q5_1 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_1 kernel requires bits=5"); + constexpr int BM = 32, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ5_1BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + + const int k_start = tid.z * k_partition_size; + x += k_start; + auto wl = w; + wl += (k_start / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += tid.z * static_cast(split_k_partition_stride); + + kq_qmm_t_impl( + wl, + x, + y, + Xs, + Ws, + K, + N, + M, + k_partition_size, + tid, + lid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void kq_q5_1_qmm_n( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q5_1_GROUP, "Q5_1 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_1 kernel requires bits=5"); + constexpr int BM = 64, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + using LoaderW = KqQ5_1BlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_n_impl( + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_1_qmv_fast( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q5_1_qmv_fast_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_1_qmv( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q5_1_qmv_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_1_dequantize( + const device uint8_t* w, + const device uint8_t* /* scales */, + device T* out, + const constant uint& num_weights, + uint gid [[thread_position_in_grid]]) { + static_assert( + group_size == KQ_Q5_1_GROUP, "Q5_1 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_1 kernel requires bits=5"); + kq_q5_1_dequantize_impl(w, out, num_weights, gid); +} + +inline void kq_get_scale_min_k4( + int j, + const device uint8_t* q, + thread uint8_t& d_out, + thread uint8_t& m_out) { + const int j_lo = j & 3; + const bool j_high = (j & 4) != 0; + const uint8_t a = q[j_lo]; + const uint8_t b = q[j_lo + 4]; + const uint8_t c = q[j_lo + 8]; + const uint8_t d_low = a & 0x3F; + const uint8_t m_low = b & 0x3F; + const uint8_t d_high = (c & 0x0F) | ((a >> 6) << 4); + const uint8_t m_high = (c >> 4) | ((b >> 6) << 4); + d_out = j_high ? d_high : d_low; + m_out = j_high ? m_high : m_low; +} + +// Q4_K: 144 bytes/256 weights. [fp16 d][fp16 dmin][scales[12]][qs[128]]. +// w[i] = d * sub_scale * q4 - dmin * sub_min. Nibble-packed, low=even sb. + +MLX_MTL_CONST int KQ_Q4_K_SUPERBLOCK = 256; +MLX_MTL_CONST int KQ_Q4_K_BLOCK_BYTES = 144; +MLX_MTL_CONST int KQ_Q4_K_D_OFFSET = 0; +MLX_MTL_CONST int KQ_Q4_K_DMIN_OFFSET = 2; +MLX_MTL_CONST int KQ_Q4_K_SCALES_OFFSET = 4; +MLX_MTL_CONST int KQ_Q4_K_QS_OFFSET = 16; + +inline float kq_q4_k_d(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q4_K_D_OFFSET)); +} +inline float kq_q4_k_dmin(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q4_K_DMIN_OFFSET)); +} +inline const device uint8_t* kq_q4_k_scales12_ptr( + const device uint8_t* block_addr) { + return block_addr + KQ_Q4_K_SCALES_OFFSET; +} +inline const device uint8_t* kq_q4_k_qs_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q4_K_QS_OFFSET; +} + +template +METAL_FUNC void kq_q4_k_dequantize_impl( + const device uint8_t* w, + device T* out, + const constant uint& num_weights, + uint gid) { + if (gid >= num_weights) { + return; + } + const int sb_id = gid / KQ_Q4_K_SUPERBLOCK; + const int within_sb_total = gid - sb_id * KQ_Q4_K_SUPERBLOCK; + const int sub_block = within_sb_total / 32; + const int within_sb = within_sb_total - sub_block * 32; + const int pair = sub_block / 2; + const bool is_high = (sub_block & 1) != 0; + const int qs_byte_idx = pair * 32 + within_sb; + + const device uint8_t* sb_addr = w + sb_id * KQ_Q4_K_BLOCK_BYTES; + const float d = kq_q4_k_d(sb_addr); + const float dmin = kq_q4_k_dmin(sb_addr); + uint8_t sc6, mn6; + kq_get_scale_min_k4(sub_block, kq_q4_k_scales12_ptr(sb_addr), sc6, mn6); + + const uint8_t byte = kq_q4_k_qs_ptr(sb_addr)[qs_byte_idx]; + const uint8_t q4 = is_high ? (byte >> 4) : (byte & 0x0F); + out[gid] = T(d * float(sc6) * float(q4) - dmin * float(mn6)); +} + +template +METAL_FUNC void kq_q4_k_qmv_fast_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q4_K_SUPERBLOCK, "Q4_K kernel requires group_size=256"); + static_assert(bits == 4, "Q4_K kernel requires bits=4"); + + constexpr int num_simdgroups = 2; + // 2 (vs 4 for flat codecs), super-block scale unpacking needs more registers. + constexpr int results_per_simdgroup = 2; + constexpr int sb_stride = 4; + constexpr uint16_t kmask1 = 0x3f3f; + constexpr uint16_t kmask2 = 0x0f0f; + constexpr uint16_t kmask3 = 0xc0c0; + + typedef float U; + thread U yl[16]; + thread U yh[16]; + thread U result[results_per_simdgroup] = {0}; + + const int ix = simd_lid / 8; + const int it = simd_lid % 8; + const int iq = it / 4; + const int ir = it % 4; + + const int row_bytes = in_vec_size * KQ_Q4_K_BLOCK_BYTES / KQ_Q4_K_SUPERBLOCK; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + const int nb = in_vec_size / KQ_Q4_K_SUPERBLOCK; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += sb_stride) { + const int x_base = ib * KQ_Q4_K_SUPERBLOCK + 64 * iq + 8 * ir; + U sumy[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i++) { + yl[i + 0] = U(x[x_base + i + 0]); + sumy[0] += yl[i + 0]; + yl[i + 8] = U(x[x_base + i + 32]); + sumy[1] += yl[i + 8]; + yh[i + 0] = U(x[x_base + i + 128]); + sumy[2] += yh[i + 0]; + yh[i + 8] = U(x[x_base + i + 160]); + sumy[3] += yh[i + 8]; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + const int row_idx = out_row + row; + const device uint8_t* sb_addr = + w + row_idx * row_bytes + ib * KQ_Q4_K_BLOCK_BYTES; + + const device uint16_t* sc16_src = + reinterpret_cast( + kq_q4_k_scales12_ptr(sb_addr)) + + iq; + uint16_t sc16[4]; + sc16[0] = sc16_src[0] & kmask1; + sc16[1] = sc16_src[2] & kmask1; + sc16[2] = ((sc16_src[4] >> 0) & kmask2) | ((sc16_src[0] & kmask3) >> 2); + sc16[3] = ((sc16_src[4] >> 4) & kmask2) | ((sc16_src[2] & kmask3) >> 2); + thread const uint8_t* sc8 = reinterpret_cast(sc16); + + const device uint16_t* q1 = + reinterpret_cast(kq_q4_k_qs_ptr(sb_addr)) + + 16 * iq + 4 * ir; + const device uint16_t* q2 = q1 + 32; + + U acc1[4] = {U(0), U(0), U(0), U(0)}; + U acc2[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 4; i++) { + const uint16_t q1i = q1[i]; + const uint16_t q2i = q2[i]; + acc1[0] += yl[2 * i + 0] * U(q1i & 0x000F); + acc1[1] += yl[2 * i + 1] * U(q1i & 0x0F00); + acc1[2] += yl[2 * i + 8] * U(q1i & 0x00F0); + acc1[3] += yl[2 * i + 9] * U(q1i & 0xF000); + acc2[0] += yh[2 * i + 0] * U(q2i & 0x000F); + acc2[1] += yh[2 * i + 1] * U(q2i & 0x0F00); + acc2[2] += yh[2 * i + 8] * U(q2i & 0x00F0); + acc2[3] += yh[2 * i + 9] * U(q2i & 0xF000); + } + + const U d = U(kq_q4_k_d(sb_addr)); + const U dmin = U(kq_q4_k_dmin(sb_addr)); + result[row] += d * + ((acc1[0] + acc1[1] * (U(1) / U(256))) * U(sc8[0]) + + (acc1[2] + acc1[3] * (U(1) / U(256))) * U(sc8[1]) * + (U(1) / U(16)) + + (acc2[0] + acc2[1] * (U(1) / U(256))) * U(sc8[4]) + + (acc2[2] + acc2[3] * (U(1) / U(256))) * U(sc8[5]) * + (U(1) / U(16))) - + dmin * + (sumy[0] * U(sc8[2]) + sumy[1] * U(sc8[3]) + sumy[2] * U(sc8[6]) + + sumy[3] * U(sc8[7])); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void kq_q4_k_qmv_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q4_K_SUPERBLOCK, "Q4_K kernel requires group_size=256"); + static_assert(bits == 4, "Q4_K kernel requires bits=4"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 2; + constexpr int values_per_thread = 8; + constexpr int block_size = values_per_thread * SIMD_SIZE; + + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q4_K_BLOCK_BYTES / KQ_Q4_K_SUPERBLOCK; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + if (out_row >= out_vec_size) { + return; + } + const int max_row = min(out_vec_size, out_row + results_per_simdgroup); + const int active_rows = max_row - out_row; + + const int lane_k_offset = simd_lid * values_per_thread; + const int sub_block = simd_lid / 4; + const int within_sb = (simd_lid % 4) * values_per_thread; + const int pair = sub_block / 2; + const bool is_high = (sub_block & 1) != 0; + const int qs_byte_idx = pair * 32 + within_sb; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int k = 0; k < in_vec_size; k += block_size) { + load_vector(x + k + lane_k_offset, x_thread); + + U partial_x = 0; +#pragma unroll + for (int i = 0; i < values_per_thread; i++) { + partial_x += x_thread[i]; + } + + const int sb_id = k / KQ_Q4_K_SUPERBLOCK; + for (int row = 0; row < active_rows; row++) { + const int row_idx = out_row + row; + const device uint8_t* sb_addr = + w + row_idx * row_bytes + sb_id * KQ_Q4_K_BLOCK_BYTES; + + const U d = U(kq_q4_k_d(sb_addr)); + const U dmin = U(kq_q4_k_dmin(sb_addr)); + uint8_t sc6, mn6; + kq_get_scale_min_k4(sub_block, kq_q4_k_scales12_ptr(sb_addr), sc6, mn6); + const U eff_scale = d * U(sc6); + const U eff_min = dmin * U(mn6); + + const device uint8_t* qs = kq_q4_k_qs_ptr(sb_addr) + qs_byte_idx; + U partial_q = 0; +#pragma unroll + for (int i = 0; i < values_per_thread; i++) { + const uint8_t q4 = is_high ? (qs[i] >> 4) : (qs[i] & 0x0F); + partial_q += x_thread[i] * U(q4); + } + result[row] += eff_scale * partial_q - eff_min * partial_x; + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0 && row < active_rows) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqQ4_KBlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q4_K_SUPERBLOCK; + MLX_MTL_CONST int bytes_per_block = KQ_Q4_K_BLOCK_BYTES; + MLX_MTL_CONST int sub_block_size = 32; + MLX_MTL_CONST int sub_blocks_per_block = weights_per_block / sub_block_size; + + static_assert( + BCOLS == sub_block_size, + "Q4_K loader requires BCOLS == 32 (one sub-block per K-tile)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + + const int src_ld; + const int row_bytes; + const int tile_stride; + const short fixed_sub_block_idx; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + short sub_block_idx; + struct Cache { + T vals[n_reads]; + }; + metal::conditional_t cached; + + KqQ4_KBlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int col_in_block = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? 0 + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + fixed_sub_block_idx( + reduction_dim == 0 ? (col_in_block / sub_block_size) : 0), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)), + sub_block_idx(0) {} + + void load_unsafe() { + if constexpr (reduction_dim == 1) { + if (sub_block_idx & 1) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = cached.vals[i]; + } + return; + } + } + + const short sb = (reduction_dim == 0) ? fixed_sub_block_idx : sub_block_idx; + + const float d = float(*(const device half*)(src + KQ_Q4_K_D_OFFSET)); + const float dmin = float(*(const device half*)(src + KQ_Q4_K_DMIN_OFFSET)); + const device uint8_t* scales12 = src + KQ_Q4_K_SCALES_OFFSET; + + uint8_t sc6, mn6; + kq_get_scale_min_k4(sb, scales12, sc6, mn6); + const float eff_scale = d * float(sc6); + const float eff_min = dmin * float(mn6); + + const short pair = sb / 2; + const device uint8_t* qs = src + KQ_Q4_K_QS_OFFSET + pair * 32 + bj; + + static_assert( + n_reads == 8 || n_reads == 16, + "Q4_K ALU vector load supports n_reads=8 (uint2) or 16 (uint4)."); + uint8_t qs_b[n_reads]; + if constexpr (n_reads == 8) { + const uint2 qs_v = *reinterpret_cast(qs); + *reinterpret_cast(&qs_b[0]) = qs_v; + } else { + const uint4 qs_v = *reinterpret_cast(qs); + *reinterpret_cast(&qs_b[0]) = qs_v; + } + + if constexpr (reduction_dim == 1) { + uint8_t sc6_hi, mn6_hi; + kq_get_scale_min_k4(sb + 1, scales12, sc6_hi, mn6_hi); + const float eff_scale_hi = d * float(sc6_hi); + const float eff_min_hi = dmin * float(mn6_hi); + +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t b = qs_b[i]; + const uint8_t q4_lo = b & 0x0F; + const uint8_t q4_hi = b >> 4; + dst[i] = T(eff_scale * float(q4_lo) - eff_min); + cached.vals[i] = T(eff_scale_hi * float(q4_hi) - eff_min_hi); + } + } else { + const bool is_high = (sb & 1) != 0; +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t q4 = is_high ? (qs_b[i] >> 4) : (qs_b[i] & 0x0F); + dst[i] = T(eff_scale * float(q4) - eff_min); + } + } + } + + void load_safe(short2 src_tile_dim) { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + if (reduction_dim == 1) { + sub_block_idx++; + if (sub_block_idx == sub_blocks_per_block) { + sub_block_idx = 0; + src += bytes_per_block; + } + } else { + src += tile_stride; + } + } +}; + +template +[[kernel]] void kq_q4_k_qmm_t( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q4_K_SUPERBLOCK, "Q4_K kernel requires group_size=256"); + static_assert(bits == 4, "Q4_K kernel requires bits=4"); + constexpr int BM = 64, BK = 32, BN = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ4_KBlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_t_impl( + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_k_qmm_t_splitk( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& k_partition_size, + const constant int& split_k_partition_stride, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q4_K_SUPERBLOCK, "Q4_K kernel requires group_size=256"); + static_assert(bits == 4, "Q4_K kernel requires bits=4"); + constexpr int BM = 32, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ4_KBlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + + const int k_start = tid.z * k_partition_size; + x += k_start; + auto wl = w; + wl += (k_start / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += tid.z * static_cast(split_k_partition_stride); + + kq_qmm_t_impl( + wl, + x, + y, + Xs, + Ws, + K, + N, + M, + k_partition_size, + tid, + lid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void kq_q4_k_qmm_n( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q4_K_SUPERBLOCK, "Q4_K kernel requires group_size=256"); + static_assert(bits == 4, "Q4_K kernel requires bits=4"); + constexpr int BM = 64, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + using LoaderW = KqQ4_KBlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_n_impl( + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_k_qmv_fast( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q4_k_qmv_fast_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_k_qmv( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q4_k_qmv_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_k_dequantize( + const device uint8_t* w, + const device uint8_t* /* scales */, + device T* out, + const constant uint& num_weights, + uint gid [[thread_position_in_grid]]) { + static_assert( + group_size == KQ_Q4_K_SUPERBLOCK, "Q4_K kernel requires group_size=256"); + static_assert(bits == 4, "Q4_K kernel requires bits=4"); + kq_q4_k_dequantize_impl(w, out, num_weights, gid); +} + +// Q5_K: 176 bytes/256 weights. [fp16 d][fp16 +// dmin][scales[12]][qh[32]][qs[128]]. q5 = q4 | (high_bit<<4); w[i] = d * +// sub_scale * q5 - dmin * sub_min. + +MLX_MTL_CONST int KQ_Q5_K_SUPERBLOCK = 256; +MLX_MTL_CONST int KQ_Q5_K_BLOCK_BYTES = 176; +MLX_MTL_CONST int KQ_Q5_K_D_OFFSET = 0; +MLX_MTL_CONST int KQ_Q5_K_DMIN_OFFSET = 2; +MLX_MTL_CONST int KQ_Q5_K_SCALES_OFFSET = 4; +MLX_MTL_CONST int KQ_Q5_K_QH_OFFSET = 16; +MLX_MTL_CONST int KQ_Q5_K_QS_OFFSET = 48; + +inline float kq_q5_k_d(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q5_K_D_OFFSET)); +} +inline float kq_q5_k_dmin(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q5_K_DMIN_OFFSET)); +} +inline const device uint8_t* kq_q5_k_scales12_ptr( + const device uint8_t* block_addr) { + return block_addr + KQ_Q5_K_SCALES_OFFSET; +} +inline const device uint8_t* kq_q5_k_qh_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q5_K_QH_OFFSET; +} +inline const device uint8_t* kq_q5_k_qs_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q5_K_QS_OFFSET; +} + +template +METAL_FUNC void kq_q5_k_dequantize_impl( + const device uint8_t* w, + device T* out, + const constant uint& num_weights, + uint gid) { + if (gid >= num_weights) { + return; + } + const int sb_id = gid / KQ_Q5_K_SUPERBLOCK; + const int within_sb_total = gid - sb_id * KQ_Q5_K_SUPERBLOCK; + const int sub_block = within_sb_total / 32; + const int within_sb = within_sb_total - sub_block * 32; + const int pair = sub_block / 2; + const bool is_high = (sub_block & 1) != 0; + const int qs_byte_idx = pair * 32 + within_sb; + + const device uint8_t* sb_addr = w + sb_id * KQ_Q5_K_BLOCK_BYTES; + const float d = kq_q5_k_d(sb_addr); + const float dmin = kq_q5_k_dmin(sb_addr); + uint8_t sc6, mn6; + kq_get_scale_min_k4(sub_block, kq_q5_k_scales12_ptr(sb_addr), sc6, mn6); + + const uint8_t qs_byte = kq_q5_k_qs_ptr(sb_addr)[qs_byte_idx]; + const uint8_t q4 = is_high ? (qs_byte >> 4) : (qs_byte & 0x0F); + const uint8_t qh_byte = kq_q5_k_qh_ptr(sb_addr)[within_sb]; + const uint8_t high_bit = (qh_byte >> sub_block) & 1u; + const uint8_t q5 = q4 | (high_bit << 4); + out[gid] = T(d * float(sc6) * float(q5) - dmin * float(mn6)); +} + +template +METAL_FUNC void kq_q5_k_qmv_fast_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q5_K_SUPERBLOCK, "Q5_K kernel requires group_size=256"); + static_assert(bits == 5, "Q5_K kernel requires bits=5"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 2; + constexpr int sb_stride = 4; + constexpr uint16_t kmask1 = 0x3f3f; + constexpr uint16_t kmask2 = 0x0f0f; + constexpr uint16_t kmask3 = 0xc0c0; + + typedef float U; + thread U yl[16]; + thread U yh[16]; + thread U result[results_per_simdgroup] = {0}; + + const int tid_lane = simd_lid / 4; + const int ix = simd_lid % 4; + const int iq = tid_lane / 4; + const int ir = tid_lane % 4; + + const uint8_t hm1 = uint8_t(1u << (2 * iq)); + const uint8_t hm2 = uint8_t(hm1 << 1); + const uint8_t hm3 = uint8_t(hm1 << 4); + const uint8_t hm4 = uint8_t(hm2 << 4); + + const int row_bytes = in_vec_size * KQ_Q5_K_BLOCK_BYTES / KQ_Q5_K_SUPERBLOCK; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + const int nb = in_vec_size / KQ_Q5_K_SUPERBLOCK; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += sb_stride) { + const int x_base = ib * KQ_Q5_K_SUPERBLOCK + 64 * iq + 8 * ir; + U sumy[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i++) { + yl[i + 0] = U(x[x_base + i + 0]); + sumy[0] += yl[i + 0]; + yl[i + 8] = U(x[x_base + i + 32]); + sumy[1] += yl[i + 8]; + yh[i + 0] = U(x[x_base + i + 128]); + sumy[2] += yh[i + 0]; + yh[i + 8] = U(x[x_base + i + 160]); + sumy[3] += yh[i + 8]; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + const int row_idx = out_row + row; + const device uint8_t* sb_addr = + w + row_idx * row_bytes + ib * KQ_Q5_K_BLOCK_BYTES; + + const device uint16_t* sc16_src = + reinterpret_cast( + kq_q5_k_scales12_ptr(sb_addr)) + + iq; + uint16_t sc16[4]; + sc16[0] = sc16_src[0] & kmask1; + sc16[1] = sc16_src[2] & kmask1; + sc16[2] = ((sc16_src[4] >> 0) & kmask2) | ((sc16_src[0] & kmask3) >> 2); + sc16[3] = ((sc16_src[4] >> 4) & kmask2) | ((sc16_src[2] & kmask3) >> 2); + thread const uint8_t* sc8 = reinterpret_cast(sc16); + + const device uint16_t* q1 = + reinterpret_cast(kq_q5_k_qs_ptr(sb_addr)) + + 16 * iq + 4 * ir; + const device uint16_t* q2 = q1 + 32; + const device uint8_t* qh = kq_q5_k_qh_ptr(sb_addr) + 8 * ir; + + U acc1[4] = {U(0), U(0), U(0), U(0)}; + U acc2[4] = {U(0), U(0), U(0), U(0)}; + U accH[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 4; i++) { + const uint16_t q1i = q1[i]; + const uint16_t q2i = q2[i]; + const uint8_t h0 = qh[2 * i + 0]; + const uint8_t h1 = qh[2 * i + 1]; + acc1[0] += yl[2 * i + 0] * U(q1i & 0x000F); + acc1[1] += yl[2 * i + 1] * U(q1i & 0x0F00); + acc1[2] += yl[2 * i + 8] * U(q1i & 0x00F0); + acc1[3] += yl[2 * i + 9] * U(q1i & 0xF000); + acc2[0] += yh[2 * i + 0] * U(q2i & 0x000F); + acc2[1] += yh[2 * i + 1] * U(q2i & 0x0F00); + acc2[2] += yh[2 * i + 8] * U(q2i & 0x00F0); + acc2[3] += yh[2 * i + 9] * U(q2i & 0xF000); + accH[0] += ((h0 & hm1) ? yl[2 * i + 0] : U(0)) + + ((h1 & hm1) ? yl[2 * i + 1] : U(0)); + accH[1] += ((h0 & hm2) ? yl[2 * i + 8] : U(0)) + + ((h1 & hm2) ? yl[2 * i + 9] : U(0)); + accH[2] += ((h0 & hm3) ? yh[2 * i + 0] : U(0)) + + ((h1 & hm3) ? yh[2 * i + 1] : U(0)); + accH[3] += ((h0 & hm4) ? yh[2 * i + 8] : U(0)) + + ((h1 & hm4) ? yh[2 * i + 9] : U(0)); + } + + const U d = U(kq_q5_k_d(sb_addr)); + const U dmin = U(kq_q5_k_dmin(sb_addr)); + result[row] += d * + (U(sc8[0]) * + ((acc1[0] + acc1[1] * (U(1) / U(256))) + U(16) * accH[0]) + + U(sc8[1]) * + ((acc1[2] + acc1[3] * (U(1) / U(256))) * (U(1) / U(16)) + + U(16) * accH[1]) + + U(sc8[4]) * + ((acc2[0] + acc2[1] * (U(1) / U(256))) + U(16) * accH[2]) + + U(sc8[5]) * + ((acc2[2] + acc2[3] * (U(1) / U(256))) * (U(1) / U(16)) + + U(16) * accH[3])) - + dmin * + (sumy[0] * U(sc8[2]) + sumy[1] * U(sc8[3]) + sumy[2] * U(sc8[6]) + + sumy[3] * U(sc8[7])); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void kq_q5_k_qmv_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q5_K_SUPERBLOCK, "Q5_K kernel requires group_size=256"); + static_assert(bits == 5, "Q5_K kernel requires bits=5"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 2; + constexpr int sb_stride = 4; + constexpr uint16_t kmask1 = 0x3f3f; + constexpr uint16_t kmask2 = 0x0f0f; + constexpr uint16_t kmask3 = 0xc0c0; + + typedef float U; + thread U yl[16]; + thread U yh[16]; + thread U result[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q5_K_BLOCK_BYTES / KQ_Q5_K_SUPERBLOCK; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + if (out_row >= out_vec_size) { + return; + } + const int max_row = min(out_vec_size, out_row + results_per_simdgroup); + const int active_rows = max_row - out_row; + + const int tid_lane = simd_lid / 4; + const int ix = simd_lid % 4; + const int iq = tid_lane / 4; + const int ir = tid_lane % 4; + + const uint8_t hm1 = uint8_t(1u << (2 * iq)); + const uint8_t hm2 = uint8_t(hm1 << 1); + const uint8_t hm3 = uint8_t(hm1 << 4); + const uint8_t hm4 = uint8_t(hm2 << 4); + + const int nb = in_vec_size / KQ_Q5_K_SUPERBLOCK; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += sb_stride) { + const int x_base = ib * KQ_Q5_K_SUPERBLOCK + 64 * iq + 8 * ir; + U sumy[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i++) { + yl[i + 0] = U(x[x_base + i + 0]); + sumy[0] += yl[i + 0]; + yl[i + 8] = U(x[x_base + i + 32]); + sumy[1] += yl[i + 8]; + yh[i + 0] = U(x[x_base + i + 128]); + sumy[2] += yh[i + 0]; + yh[i + 8] = U(x[x_base + i + 160]); + sumy[3] += yh[i + 8]; + } + + for (int row = 0; row < active_rows; row++) { + const int row_idx = out_row + row; + const device uint8_t* sb_addr = + w + row_idx * row_bytes + ib * KQ_Q5_K_BLOCK_BYTES; + + const device uint16_t* sc16_src = + reinterpret_cast( + kq_q5_k_scales12_ptr(sb_addr)) + + iq; + uint16_t sc16[4]; + sc16[0] = sc16_src[0] & kmask1; + sc16[1] = sc16_src[2] & kmask1; + sc16[2] = ((sc16_src[4] >> 0) & kmask2) | ((sc16_src[0] & kmask3) >> 2); + sc16[3] = ((sc16_src[4] >> 4) & kmask2) | ((sc16_src[2] & kmask3) >> 2); + thread const uint8_t* sc8 = reinterpret_cast(sc16); + + const device uint16_t* q1 = + reinterpret_cast(kq_q5_k_qs_ptr(sb_addr)) + + 16 * iq + 4 * ir; + const device uint16_t* q2 = q1 + 32; + const device uint8_t* qh = kq_q5_k_qh_ptr(sb_addr) + 8 * ir; + + U acc1[4] = {U(0), U(0), U(0), U(0)}; + U acc2[4] = {U(0), U(0), U(0), U(0)}; + U accH[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 4; i++) { + const uint16_t q1i = q1[i]; + const uint16_t q2i = q2[i]; + const uint8_t h0 = qh[2 * i + 0]; + const uint8_t h1 = qh[2 * i + 1]; + acc1[0] += yl[2 * i + 0] * U(q1i & 0x000F); + acc1[1] += yl[2 * i + 1] * U(q1i & 0x0F00); + acc1[2] += yl[2 * i + 8] * U(q1i & 0x00F0); + acc1[3] += yl[2 * i + 9] * U(q1i & 0xF000); + acc2[0] += yh[2 * i + 0] * U(q2i & 0x000F); + acc2[1] += yh[2 * i + 1] * U(q2i & 0x0F00); + acc2[2] += yh[2 * i + 8] * U(q2i & 0x00F0); + acc2[3] += yh[2 * i + 9] * U(q2i & 0xF000); + accH[0] += ((h0 & hm1) ? yl[2 * i + 0] : U(0)) + + ((h1 & hm1) ? yl[2 * i + 1] : U(0)); + accH[1] += ((h0 & hm2) ? yl[2 * i + 8] : U(0)) + + ((h1 & hm2) ? yl[2 * i + 9] : U(0)); + accH[2] += ((h0 & hm3) ? yh[2 * i + 0] : U(0)) + + ((h1 & hm3) ? yh[2 * i + 1] : U(0)); + accH[3] += ((h0 & hm4) ? yh[2 * i + 8] : U(0)) + + ((h1 & hm4) ? yh[2 * i + 9] : U(0)); + } + + const U d = U(kq_q5_k_d(sb_addr)); + const U dmin = U(kq_q5_k_dmin(sb_addr)); + result[row] += d * + (U(sc8[0]) * + ((acc1[0] + acc1[1] * (U(1) / U(256))) + U(16) * accH[0]) + + U(sc8[1]) * + ((acc1[2] + acc1[3] * (U(1) / U(256))) * (U(1) / U(16)) + + U(16) * accH[1]) + + U(sc8[4]) * + ((acc2[0] + acc2[1] * (U(1) / U(256))) + U(16) * accH[2]) + + U(sc8[5]) * + ((acc2[2] + acc2[3] * (U(1) / U(256))) * (U(1) / U(16)) + + U(16) * accH[3])) - + dmin * + (sumy[0] * U(sc8[2]) + sumy[1] * U(sc8[3]) + sumy[2] * U(sc8[6]) + + sumy[3] * U(sc8[7])); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0 && row < active_rows) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqQ5_KBlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q5_K_SUPERBLOCK; + MLX_MTL_CONST int bytes_per_block = KQ_Q5_K_BLOCK_BYTES; + MLX_MTL_CONST int sub_block_size = 32; + MLX_MTL_CONST int sub_blocks_per_block = weights_per_block / sub_block_size; + + static_assert( + BCOLS == sub_block_size, + "Q5_K loader requires BCOLS == 32 (one sub-block per K-tile)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + + const int src_ld; + const int row_bytes; + const int tile_stride; + const short fixed_sub_block_idx; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + short sub_block_idx; + struct Cache { + T vals[n_reads]; + uint8_t qh[n_reads]; + }; + metal::conditional_t cached; + + KqQ5_KBlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int col_in_block = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? 0 + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + fixed_sub_block_idx( + reduction_dim == 0 ? (col_in_block / sub_block_size) : 0), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)), + sub_block_idx(0) {} + + void load_unsafe() { + if constexpr (reduction_dim == 1) { + if (sub_block_idx & 1) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = cached.vals[i]; + } + return; + } + } + + const short sb = (reduction_dim == 0) ? fixed_sub_block_idx : sub_block_idx; + + const float d = float(*(const device half*)(src + KQ_Q5_K_D_OFFSET)); + const float dmin = float(*(const device half*)(src + KQ_Q5_K_DMIN_OFFSET)); + const device uint8_t* scales12 = src + KQ_Q5_K_SCALES_OFFSET; + + uint8_t sc6, mn6; + kq_get_scale_min_k4(sb, scales12, sc6, mn6); + const float eff_scale = d * float(sc6); + const float eff_min = dmin * float(mn6); + + const short pair = sb / 2; + const device uint8_t* qs = src + KQ_Q5_K_QS_OFFSET + pair * 32 + bj; + const device uint8_t* qh = src + KQ_Q5_K_QH_OFFSET + bj; + + static_assert( + n_reads == 8 || n_reads == 16, + "Q5_K ALU vector load supports n_reads=8 (uint2) or 16 (uint4)."); + uint8_t qs_b[n_reads]; + if constexpr (n_reads == 8) { + const uint2 qs_v = *reinterpret_cast(qs); + *reinterpret_cast(&qs_b[0]) = qs_v; + } else { + const uint4 qs_v = *reinterpret_cast(qs); + *reinterpret_cast(&qs_b[0]) = qs_v; + } + + if constexpr (reduction_dim == 1) { + uint8_t sc6_hi, mn6_hi; + kq_get_scale_min_k4(sb + 1, scales12, sc6_hi, mn6_hi); + const float eff_scale_hi = d * float(sc6_hi); + const float eff_min_hi = dmin * float(mn6_hi); + + if (sub_block_idx == 0) { + if constexpr (n_reads == 8) { + const uint2 qh_v = *reinterpret_cast(qh); + *reinterpret_cast(&cached.qh[0]) = qh_v; + } else { + const uint4 qh_v = *reinterpret_cast(qh); + *reinterpret_cast(&cached.qh[0]) = qh_v; + } + } + +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t b = qs_b[i]; + const uint8_t h = cached.qh[i]; + const uint8_t q4_lo = b & 0x0F; + const uint8_t q4_hi = b >> 4; + const uint8_t hi_lo = (h >> sb) & 1u; + const uint8_t hi_hi = (h >> (sb + 1)) & 1u; + const uint8_t q5_lo = q4_lo | (hi_lo << 4); + const uint8_t q5_hi = q4_hi | (hi_hi << 4); + dst[i] = T(eff_scale * float(q5_lo) - eff_min); + cached.vals[i] = T(eff_scale_hi * float(q5_hi) - eff_min_hi); + } + } else { + uint8_t qh_b[n_reads]; + if constexpr (n_reads == 8) { + const uint2 qh_v = *reinterpret_cast(qh); + *reinterpret_cast(&qh_b[0]) = qh_v; + } else { + const uint4 qh_v = *reinterpret_cast(qh); + *reinterpret_cast(&qh_b[0]) = qh_v; + } + const bool is_high = (sb & 1) != 0; +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t q4 = is_high ? (qs_b[i] >> 4) : (qs_b[i] & 0x0F); + const uint8_t hi = (qh_b[i] >> sb) & 1u; + const uint8_t q5 = q4 | (hi << 4); + dst[i] = T(eff_scale * float(q5) - eff_min); + } + } + } + + void load_safe(short2 src_tile_dim) { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + if (reduction_dim == 1) { + sub_block_idx++; + if (sub_block_idx == sub_blocks_per_block) { + sub_block_idx = 0; + src += bytes_per_block; + } + } else { + src += tile_stride; + } + } +}; + +template +[[kernel]] void kq_q5_k_qmm_t( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q5_K_SUPERBLOCK, "Q5_K kernel requires group_size=256"); + static_assert(bits == 5, "Q5_K kernel requires bits=5"); + constexpr int BM = 64, BK = 32, BN = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ5_KBlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_t_impl( + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_k_qmm_t_splitk( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& k_partition_size, + const constant int& split_k_partition_stride, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q5_K_SUPERBLOCK, "Q5_K kernel requires group_size=256"); + static_assert(bits == 5, "Q5_K kernel requires bits=5"); + constexpr int BM = 32, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ5_KBlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + + const int k_start = tid.z * k_partition_size; + x += k_start; + auto wl = w; + wl += (k_start / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += tid.z * static_cast(split_k_partition_stride); + + kq_qmm_t_impl( + wl, + x, + y, + Xs, + Ws, + K, + N, + M, + k_partition_size, + tid, + lid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void kq_q5_k_qmm_n( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q5_K_SUPERBLOCK, "Q5_K kernel requires group_size=256"); + static_assert(bits == 5, "Q5_K kernel requires bits=5"); + constexpr int BM = 64, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + using LoaderW = KqQ5_KBlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_n_impl( + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_k_qmv_fast( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q5_k_qmv_fast_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_k_qmv( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q5_k_qmv_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_k_dequantize( + const device uint8_t* w, + const device uint8_t* /* scales */, + device T* out, + const constant uint& num_weights, + uint gid [[thread_position_in_grid]]) { + static_assert( + group_size == KQ_Q5_K_SUPERBLOCK, "Q5_K kernel requires group_size=256"); + static_assert(bits == 5, "Q5_K kernel requires bits=5"); + kq_q5_k_dequantize_impl(w, out, num_weights, gid); +} + +// Q6_K: 210 bytes/256 weights. REVERSED field order: [ql[128]][qh[64]][int8 +// scales[16]][fp16 d]. q6 = (low4 | (high2<<4)) - 32; w[i] = d * sc * q6. + +MLX_MTL_CONST int KQ_Q6_K_SUPERBLOCK = 256; +MLX_MTL_CONST int KQ_Q6_K_BLOCK_BYTES = 210; +MLX_MTL_CONST int KQ_Q6_K_QL_OFFSET = 0; +MLX_MTL_CONST int KQ_Q6_K_QH_OFFSET = 128; +MLX_MTL_CONST int KQ_Q6_K_SCALES_OFFSET = 192; +MLX_MTL_CONST int KQ_Q6_K_D_OFFSET = 208; + +inline float kq_q6_k_d(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q6_K_D_OFFSET)); +} +inline const device uint8_t* kq_q6_k_ql_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q6_K_QL_OFFSET; +} +inline const device uint8_t* kq_q6_k_qh_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q6_K_QH_OFFSET; +} +inline const device int8_t* kq_q6_k_scales_ptr( + const device uint8_t* block_addr) { + return (const device int8_t*)(block_addr + KQ_Q6_K_SCALES_OFFSET); +} + +template +METAL_FUNC void kq_q6_k_dequantize_impl( + const device uint8_t* w, + device T* out, + const constant uint& num_weights, + uint gid) { + if (gid >= num_weights) { + return; + } + const int sb_id = gid / KQ_Q6_K_SUPERBLOCK; + const int within_sb = gid - sb_id * KQ_Q6_K_SUPERBLOCK; + + const int half_idx = within_sb / 128; + const int within_half = within_sb - half_idx * 128; + const int quadrant = within_half / 32; + const int l = within_half - quadrant * 32; + + const device uint8_t* sb_addr = w + sb_id * KQ_Q6_K_BLOCK_BYTES; + const float d = kq_q6_k_d(sb_addr); + const device uint8_t* ql = kq_q6_k_ql_ptr(sb_addr) + half_idx * 64; + const device uint8_t* qh = kq_q6_k_qh_ptr(sb_addr) + half_idx * 32; + const device int8_t* sc = kq_q6_k_scales_ptr(sb_addr) + half_idx * 8; + + const int ql_idx = (quadrant & 1) * 32 + l; + const bool is_high_nibble = (quadrant >= 2); + const uint8_t low4 = is_high_nibble ? (uint8_t)(ql[ql_idx] >> 4) + : (uint8_t)(ql[ql_idx] & 0x0F); + const uint8_t high2 = (uint8_t)((qh[l] >> (quadrant * 2)) & 0x03); + const int8_t q6 = (int8_t)(low4 | (high2 << 4)) - (int8_t)32; + const int is_off = l / 16; + const int8_t scale_i8 = sc[is_off + 2 * quadrant]; + out[gid] = T(d * float(scale_i8) * float(q6)); +} + +template +METAL_FUNC void kq_q6_k_qmv_fast_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q6_K_SUPERBLOCK, "Q6_K kernel requires group_size=256"); + static_assert(bits == 6, "Q6_K kernel requires bits=6"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int sb_stride = 2; + + typedef float U; + thread U yl[16]; + thread U result[results_per_simdgroup] = {0}; + + const int tid_lane = simd_lid / 2; + const int ix = simd_lid % 2; + const int ip = tid_lane / 8; + const int il = tid_lane % 8; + const int l0 = 4 * il; + const int is = 8 * ip + l0 / 16; + + const int row_bytes = in_vec_size * KQ_Q6_K_BLOCK_BYTES / KQ_Q6_K_SUPERBLOCK; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + const int nb = in_vec_size / KQ_Q6_K_SUPERBLOCK; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += sb_stride) { + const int x_base = ib * KQ_Q6_K_SUPERBLOCK + 128 * ip + l0; +#pragma unroll + for (int l = 0; l < 4; l++) { + yl[4 * l + 0] = U(x[x_base + l + 0]); + yl[4 * l + 1] = U(x[x_base + l + 32]); + yl[4 * l + 2] = U(x[x_base + l + 64]); + yl[4 * l + 3] = U(x[x_base + l + 96]); + } + + for (int row = 0; row < results_per_simdgroup; row++) { + const int row_idx = out_row + row; + const device uint8_t* sb_addr = + w + row_idx * row_bytes + ib * KQ_Q6_K_BLOCK_BYTES; + + const device uint8_t* q1 = kq_q6_k_ql_ptr(sb_addr) + 64 * ip + l0; + const device uint8_t* q2 = q1 + 32; + const device uint8_t* qh = kq_q6_k_qh_ptr(sb_addr) + 32 * ip + l0; + const device int8_t* sc = kq_q6_k_scales_ptr(sb_addr) + is; + + U sums[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int l = 0; l < 4; l++) { + const uint8_t q1l = q1[l]; + const uint8_t q2l = q2[l]; + const uint8_t qhl = qh[l]; + const int8_t v0 = + int8_t((q1l & 0x0F) | ((qhl & 0x03) << 4)) - int8_t(32); + const int8_t v1 = + int8_t((q2l & 0x0F) | ((qhl & 0x0C) << 2)) - int8_t(32); + const int8_t v2 = int8_t((q1l >> 4) | ((qhl & 0x30) << 0)) - int8_t(32); + const int8_t v3 = int8_t((q2l >> 4) | ((qhl & 0xC0) >> 2)) - int8_t(32); + sums[0] += yl[4 * l + 0] * U(v0); + sums[1] += yl[4 * l + 1] * U(v1); + sums[2] += yl[4 * l + 2] * U(v2); + sums[3] += yl[4 * l + 3] * U(v3); + } + + const U d = U(kq_q6_k_d(sb_addr)); + result[row] += d * + (sums[0] * U(sc[0]) + sums[1] * U(sc[2]) + sums[2] * U(sc[4]) + + sums[3] * U(sc[6])); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void kq_q6_k_qmv_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q6_K_SUPERBLOCK, "Q6_K kernel requires group_size=256"); + static_assert(bits == 6, "Q6_K kernel requires bits=6"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int sb_stride = 2; + + typedef float U; + thread U yl[16]; + thread U result[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q6_K_BLOCK_BYTES / KQ_Q6_K_SUPERBLOCK; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + if (out_row >= out_vec_size) { + return; + } + const int max_row = min(out_vec_size, out_row + results_per_simdgroup); + const int active_rows = max_row - out_row; + + const int tid_lane = simd_lid / 2; + const int ix = simd_lid % 2; + const int ip = tid_lane / 8; + const int il = tid_lane % 8; + const int l0 = 4 * il; + const int is = 8 * ip + l0 / 16; + + const int nb = in_vec_size / KQ_Q6_K_SUPERBLOCK; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += sb_stride) { + const int x_base = ib * KQ_Q6_K_SUPERBLOCK + 128 * ip + l0; +#pragma unroll + for (int l = 0; l < 4; l++) { + yl[4 * l + 0] = U(x[x_base + l + 0]); + yl[4 * l + 1] = U(x[x_base + l + 32]); + yl[4 * l + 2] = U(x[x_base + l + 64]); + yl[4 * l + 3] = U(x[x_base + l + 96]); + } + + for (int row = 0; row < active_rows; row++) { + const int row_idx = out_row + row; + const device uint8_t* sb_addr = + w + row_idx * row_bytes + ib * KQ_Q6_K_BLOCK_BYTES; + + const device uint8_t* q1 = kq_q6_k_ql_ptr(sb_addr) + 64 * ip + l0; + const device uint8_t* q2 = q1 + 32; + const device uint8_t* qh = kq_q6_k_qh_ptr(sb_addr) + 32 * ip + l0; + const device int8_t* sc = kq_q6_k_scales_ptr(sb_addr) + is; + + U sums[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int l = 0; l < 4; l++) { + const uint8_t q1l = q1[l]; + const uint8_t q2l = q2[l]; + const uint8_t qhl = qh[l]; + const int8_t v0 = + int8_t((q1l & 0x0F) | ((qhl & 0x03) << 4)) - int8_t(32); + const int8_t v1 = + int8_t((q2l & 0x0F) | ((qhl & 0x0C) << 2)) - int8_t(32); + const int8_t v2 = int8_t((q1l >> 4) | ((qhl & 0x30) << 0)) - int8_t(32); + const int8_t v3 = int8_t((q2l >> 4) | ((qhl & 0xC0) >> 2)) - int8_t(32); + sums[0] += yl[4 * l + 0] * U(v0); + sums[1] += yl[4 * l + 1] * U(v1); + sums[2] += yl[4 * l + 2] * U(v2); + sums[3] += yl[4 * l + 3] * U(v3); + } + + const U d = U(kq_q6_k_d(sb_addr)); + result[row] += d * + (sums[0] * U(sc[0]) + sums[1] * U(sc[2]) + sums[2] * U(sc[4]) + + sums[3] * U(sc[6])); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0 && row < active_rows) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqQ6_KBlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q6_K_SUPERBLOCK; + MLX_MTL_CONST int bytes_per_block = KQ_Q6_K_BLOCK_BYTES; + MLX_MTL_CONST int k_tile_size = 32; + MLX_MTL_CONST int k_tiles_per_block = weights_per_block / k_tile_size; + + static_assert( + BCOLS == k_tile_size, + "Q6_K loader requires BCOLS == 32 (one K-tile per iteration)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + + const int src_ld; + const int row_bytes; + const int tile_stride; + const short fixed_kt; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + short kt; + struct Caches { + T q1[n_reads]; + T q2[n_reads]; + T q3[n_reads]; + }; + metal::conditional_t cached; + + KqQ6_KBlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int col_in_block = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? 0 + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + fixed_kt(reduction_dim == 0 ? (col_in_block / k_tile_size) : 0), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)), + kt(0) {} + + void load_unsafe() { + if constexpr (reduction_dim == 1) { + const short q = kt & 3; + if (q == 1) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.q1[i]; + return; + } + if (q == 2) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.q2[i]; + return; + } + if (q == 3) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.q3[i]; + return; + } + const short half_idx = kt / 4; + const short scale_off = (bj >= 16) ? 1 : 0; + const float d = float(*(const device half*)(src + KQ_Q6_K_D_OFFSET)); + const device int8_t* scales = + (const device int8_t*)(src + KQ_Q6_K_SCALES_OFFSET); + const float es0 = d * float(scales[(kt + 0) * 2 + scale_off]); + const float es1 = d * float(scales[(kt + 1) * 2 + scale_off]); + const float es2 = d * float(scales[(kt + 2) * 2 + scale_off]); + const float es3 = d * float(scales[(kt + 3) * 2 + scale_off]); + + const device uint8_t* ql_a = + src + KQ_Q6_K_QL_OFFSET + half_idx * 64 + bj; // q=0 lo, q=2 hi + const device uint8_t* ql_b = + src + KQ_Q6_K_QL_OFFSET + half_idx * 64 + 32 + bj; // q=1 lo, q=3 hi + const device uint8_t* qh = src + KQ_Q6_K_QH_OFFSET + half_idx * 32 + bj; + +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t a = ql_a[i]; + const uint8_t b = ql_b[i]; + const uint8_t h = qh[i]; + const int8_t q6_0 = + (int8_t)((a & 0x0F) | ((h & 0x03) << 4)) - (int8_t)32; + const int8_t q6_1 = + (int8_t)((b & 0x0F) | (((h >> 2) & 0x03) << 4)) - (int8_t)32; + const int8_t q6_2 = + (int8_t)((a >> 4) | (((h >> 4) & 0x03) << 4)) - (int8_t)32; + const int8_t q6_3 = + (int8_t)((b >> 4) | (((h >> 6) & 0x03) << 4)) - (int8_t)32; + dst[i] = T(es0 * float(q6_0)); + cached.q1[i] = T(es1 * float(q6_1)); + cached.q2[i] = T(es2 * float(q6_2)); + cached.q3[i] = T(es3 * float(q6_3)); + } + return; + } + + const short kt_use = fixed_kt; + const short half_idx = kt_use / 4; + const short quadrant = kt_use - half_idx * 4; + const bool is_high_nibble = (quadrant >= 2); + const short qh_shift = quadrant * 2; + const short scale_idx = kt_use * 2 + (bj >= 16 ? 1 : 0); + + const float d = float(*(const device half*)(src + KQ_Q6_K_D_OFFSET)); + const int8_t scale_i8 = + ((const device int8_t*)(src + KQ_Q6_K_SCALES_OFFSET))[scale_idx]; + const float eff_scale = d * float(scale_i8); + + const device uint8_t* ql_base = + src + KQ_Q6_K_QL_OFFSET + half_idx * 64 + (quadrant & 1) * 32 + bj; + const device uint8_t* qh_base = + src + KQ_Q6_K_QH_OFFSET + half_idx * 32 + bj; + +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t low4 = + is_high_nibble ? (ql_base[i] >> 4) : (ql_base[i] & 0x0F); + const uint8_t high2 = (uint8_t)((qh_base[i] >> qh_shift) & 0x03); + const int8_t q6 = (int8_t)(low4 | (high2 << 4)) - (int8_t)32; + dst[i] = T(eff_scale * float(q6)); + } + } + + void load_safe(short2 src_tile_dim) { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + if (reduction_dim == 1) { + kt++; + if (kt == k_tiles_per_block) { + kt = 0; + src += bytes_per_block; + } + } else { + src += tile_stride; + } + } +}; + +template +[[kernel]] void kq_q6_k_qmm_t( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q6_K_SUPERBLOCK, "Q6_K kernel requires group_size=256"); + static_assert(bits == 6, "Q6_K kernel requires bits=6"); + constexpr int BM = 64, BK = 32, BN = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ6_KBlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_t_impl( + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q6_k_qmm_t_splitk( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& k_partition_size, + const constant int& split_k_partition_stride, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q6_K_SUPERBLOCK, "Q6_K kernel requires group_size=256"); + static_assert(bits == 6, "Q6_K kernel requires bits=6"); + constexpr int BM = 32, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ6_KBlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + + const int k_start = tid.z * k_partition_size; + x += k_start; + auto wl = w; + wl += (k_start / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += tid.z * static_cast(split_k_partition_stride); + + kq_qmm_t_impl( + wl, + x, + y, + Xs, + Ws, + K, + N, + M, + k_partition_size, + tid, + lid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void kq_q6_k_qmm_n( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q6_K_SUPERBLOCK, "Q6_K kernel requires group_size=256"); + static_assert(bits == 6, "Q6_K kernel requires bits=6"); + constexpr int BM = 64, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + using LoaderW = KqQ6_KBlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_n_impl( + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q6_k_qmv_fast( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q6_k_qmv_fast_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q6_k_qmv( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q6_k_qmv_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q6_k_dequantize( + const device uint8_t* w, + const device uint8_t* /* scales */, + device T* out, + const constant uint& num_weights, + uint gid [[thread_position_in_grid]]) { + static_assert( + group_size == KQ_Q6_K_SUPERBLOCK, "Q6_K kernel requires group_size=256"); + static_assert(bits == 6, "Q6_K kernel requires bits=6"); + kq_q6_k_dequantize_impl(w, out, num_weights, gid); +} + +// Q3_K: 110 bytes/256 weights. [hmask[32]][qs[64]][scales[12]][fp16 d]. +// q3 = q2 - h; hmask SET means h=0, CLEAR means h=4. +// w[i] = d * (scale - 32) * q3. Symmetric. + +MLX_MTL_CONST int KQ_Q3_K_SUPERBLOCK = 256; +MLX_MTL_CONST int KQ_Q3_K_BLOCK_BYTES = 110; +MLX_MTL_CONST int KQ_Q3_K_HMASK_OFFSET = 0; +MLX_MTL_CONST int KQ_Q3_K_QS_OFFSET = 32; +MLX_MTL_CONST int KQ_Q3_K_SCALES_OFFSET = 96; +MLX_MTL_CONST int KQ_Q3_K_D_OFFSET = 108; + +inline float kq_q3_k_d(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q3_K_D_OFFSET)); +} +inline const device uint8_t* kq_q3_k_hmask_ptr( + const device uint8_t* block_addr) { + return block_addr + KQ_Q3_K_HMASK_OFFSET; +} +inline const device uint8_t* kq_q3_k_qs_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q3_K_QS_OFFSET; +} +inline const device uint8_t* kq_q3_k_scales12_ptr( + const device uint8_t* block_addr) { + return block_addr + KQ_Q3_K_SCALES_OFFSET; +} + +inline uint8_t kq_q3_k_unpack_scale(int j, const device uint8_t* q12) { + const int quad = j / 4; + const int byte = j & 3; + const uint8_t low4 = (q12[(quad & 1) * 4 + byte] >> ((quad >> 1) * 4)) & 0x0F; + const uint8_t high2 = (q12[8 + byte] >> (quad * 2)) & 0x03; + return (uint8_t)(low4 | (high2 << 4)); +} + +template +METAL_FUNC void kq_q3_k_dequantize_impl( + const device uint8_t* w, + device T* out, + const constant uint& num_weights, + uint gid) { + if (gid >= num_weights) { + return; + } + const int sb_id = gid / KQ_Q3_K_SUPERBLOCK; + const int within_sb = gid - sb_id * KQ_Q3_K_SUPERBLOCK; + + const int outer_half = within_sb / 128; + const int within_outer = within_sb - outer_half * 128; + const int shift_idx = within_outer / 32; + const int within_shift = within_outer - shift_idx * 32; + + const device uint8_t* sb_addr = w + sb_id * KQ_Q3_K_BLOCK_BYTES; + const float d = kq_q3_k_d(sb_addr); + const int scale_idx = within_sb / 16; + const uint8_t sc_unsigned = + kq_q3_k_unpack_scale(scale_idx, kq_q3_k_scales12_ptr(sb_addr)); + const float eff_scale = d * float((int)sc_unsigned - 32); + + const uint8_t qs_byte = + kq_q3_k_qs_ptr(sb_addr)[outer_half * 32 + within_shift]; + const uint8_t q2 = (qs_byte >> (shift_idx * 2)) & 0x03; + const int hmask_bit = outer_half * 4 + shift_idx; + const uint8_t hmask_byte = kq_q3_k_hmask_ptr(sb_addr)[within_shift]; + const bool hbit_set = ((hmask_byte >> hmask_bit) & 1) != 0; + const int q3 = (int)q2 - (hbit_set ? 0 : 4); + out[gid] = T(eff_scale * float(q3)); +} + +template +METAL_FUNC void kq_q3_k_qmv_fast_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q3_K_SUPERBLOCK, "Q3_K kernel requires group_size=256"); + static_assert(bits == 3, "Q3_K kernel requires bits=3"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 2; + constexpr int sb_stride = 4; + + const ushort4 mm_table[4] = { + {0x0001, 0x0100, 0x0002, 0x0200}, + {0x0004, 0x0400, 0x0008, 0x0800}, + {0x0010, 0x1000, 0x0020, 0x2000}, + {0x0040, 0x4000, 0x0080, 0x8000}, + }; + const ushort4 qm_table[2] = { + {0x0003, 0x0300, 0x000c, 0x0c00}, + {0x0030, 0x3000, 0x00c0, 0xc000}, + }; + + typedef float U; + thread U yl[32]; + thread U sumf1[results_per_simdgroup] = {0}; + thread U sumf2[results_per_simdgroup] = {0}; + + const int tid_lane = simd_lid / 4; + const int ix = simd_lid % 4; + const int ip = tid_lane / 4; + const int il = 2 * ((tid_lane % 4) / 2); + const int ir = tid_lane % 2; + const int l0 = 8 * ir; + const int tid_group = 2 * ip + il / 2; + + const ushort4 hm = mm_table[tid_group]; + const ushort4 qm = qm_table[il / 2]; + const int shift = 2 * il; + const U v1 = (il == 0) ? U(4) : U(64); + const U v2 = U(4) * v1; + const uint16_t s_shift1 = uint16_t(4 * ip); + const uint16_t s_shift2 = uint16_t(s_shift1 + il); + + const int row_bytes = in_vec_size * KQ_Q3_K_BLOCK_BYTES / KQ_Q3_K_SUPERBLOCK; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + const int q_offset_bytes = 32 * ip + l0; + const int y_offset = 128 * ip + 32 * il + l0; + const int nb = in_vec_size / KQ_Q3_K_SUPERBLOCK; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += sb_stride) { + const int x_base = ib * KQ_Q3_K_SUPERBLOCK + y_offset; +#pragma unroll + for (int l = 0; l < 8; l++) { + yl[l + 0] = U(x[x_base + l + 0]); + yl[l + 8] = U(x[x_base + l + 16]); + yl[l + 16] = U(x[x_base + l + 32]); + yl[l + 24] = U(x[x_base + l + 48]); + } + + for (int row = 0; row < results_per_simdgroup; row++) { + const int row_idx = out_row + row; + const device uint8_t* sb_addr = + w + row_idx * row_bytes + ib * KQ_Q3_K_BLOCK_BYTES; + + const device uint16_t* q = reinterpret_cast( + kq_q3_k_qs_ptr(sb_addr) + q_offset_bytes); + const device uint16_t* h = reinterpret_cast( + kq_q3_k_hmask_ptr(sb_addr) + l0); + const device uint16_t* a = reinterpret_cast( + kq_q3_k_scales12_ptr(sb_addr)); + + uint32_t scales32, aux32; + thread uint16_t* scales16 = reinterpret_cast(&scales32); + thread const int8_t* scales = + reinterpret_cast(&scales32); + + scales16[0] = a[4]; + scales16[1] = a[5]; + aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030u; + scales16[0] = a[il + 0]; + scales16[1] = a[il + 1]; + scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0fu) | aux32; + + const U d_all = U(kq_q3_k_d(sb_addr)); + + U s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; +#pragma unroll + for (int l = 0; l < 8; l += 2) { + const uint16_t qs = q[l / 2]; + s1 += yl[l + 0] * U(qs & qm[0]); + s2 += yl[l + 1] * U(qs & qm[1]); + s3 += ((h[l / 2] & hm[0]) ? U(0) : yl[l + 0]) + + ((h[l / 2] & hm[1]) ? U(0) : yl[l + 1]); + s4 += yl[l + 16] * U(qs & qm[2]); + s5 += yl[l + 17] * U(qs & qm[3]); + s6 += ((h[l / 2] & hm[2]) ? U(0) : yl[l + 16]) + + ((h[l / 2] & hm[3]) ? U(0) : yl[l + 17]); + } + U d1 = d_all * (s1 + s2 * (U(1) / U(256)) - s3 * v1); + U d2 = d_all * (s4 + s5 * (U(1) / U(256)) - s6 * v2); + sumf1[row] += d1 * U(int(scales[0]) - 32); + sumf2[row] += d2 * U(int(scales[2]) - 32); + + s1 = s2 = s3 = s4 = s5 = s6 = U(0); +#pragma unroll + for (int l = 0; l < 8; l += 2) { + const uint16_t qs = q[l / 2 + 8]; + s1 += yl[l + 8] * U(qs & qm[0]); + s2 += yl[l + 9] * U(qs & qm[1]); + s3 += ((h[l / 2 + 8] & hm[0]) ? U(0) : yl[l + 8]) + + ((h[l / 2 + 8] & hm[1]) ? U(0) : yl[l + 9]); + s4 += yl[l + 24] * U(qs & qm[2]); + s5 += yl[l + 25] * U(qs & qm[3]); + s6 += ((h[l / 2 + 8] & hm[2]) ? U(0) : yl[l + 24]) + + ((h[l / 2 + 8] & hm[3]) ? U(0) : yl[l + 25]); + } + d1 = d_all * (s1 + s2 * (U(1) / U(256)) - s3 * v1); + d2 = d_all * (s4 + s5 * (U(1) / U(256)) - s6 * v2); + sumf1[row] += d1 * U(int(scales[1]) - 32); + sumf2[row] += d2 * U(int(scales[3]) - 32); + } + } + + const U shift_div = U(1) / U(1u << shift); + for (int row = 0; row < results_per_simdgroup; row++) { + const U combined = (sumf1[row] + U(0.25) * sumf2[row]) * shift_div; + const U reduced = simd_sum(combined); + if (simd_lid == 0) { + y[out_row + row] = static_cast(reduced); + } + } +} + +template +METAL_FUNC void kq_q3_k_qmv_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q3_K_SUPERBLOCK, "Q3_K kernel requires group_size=256"); + static_assert(bits == 3, "Q3_K kernel requires bits=3"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 2; + constexpr int sb_stride = 4; + + const ushort4 mm_table[4] = { + {0x0001, 0x0100, 0x0002, 0x0200}, + {0x0004, 0x0400, 0x0008, 0x0800}, + {0x0010, 0x1000, 0x0020, 0x2000}, + {0x0040, 0x4000, 0x0080, 0x8000}, + }; + const ushort4 qm_table[2] = { + {0x0003, 0x0300, 0x000c, 0x0c00}, + {0x0030, 0x3000, 0x00c0, 0xc000}, + }; + + typedef float U; + thread U yl[32]; + thread U sumf1[results_per_simdgroup] = {0}; + thread U sumf2[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q3_K_BLOCK_BYTES / KQ_Q3_K_SUPERBLOCK; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + if (out_row >= out_vec_size) { + return; + } + const int max_row = min(out_vec_size, out_row + results_per_simdgroup); + const int active_rows = max_row - out_row; + + const int tid_lane = simd_lid / 4; + const int ix = simd_lid % 4; + const int ip = tid_lane / 4; + const int il = 2 * ((tid_lane % 4) / 2); + const int ir = tid_lane % 2; + const int l0 = 8 * ir; + const int tid_group = 2 * ip + il / 2; + + const ushort4 hm = mm_table[tid_group]; + const ushort4 qm = qm_table[il / 2]; + const int shift = 2 * il; + const U v1 = (il == 0) ? U(4) : U(64); + const U v2 = U(4) * v1; + const uint16_t s_shift1 = uint16_t(4 * ip); + const uint16_t s_shift2 = uint16_t(s_shift1 + il); + + const int q_offset_bytes = 32 * ip + l0; + const int y_offset = 128 * ip + 32 * il + l0; + const int nb = in_vec_size / KQ_Q3_K_SUPERBLOCK; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += sb_stride) { + const int x_base = ib * KQ_Q3_K_SUPERBLOCK + y_offset; +#pragma unroll + for (int l = 0; l < 8; l++) { + yl[l + 0] = U(x[x_base + l + 0]); + yl[l + 8] = U(x[x_base + l + 16]); + yl[l + 16] = U(x[x_base + l + 32]); + yl[l + 24] = U(x[x_base + l + 48]); + } + + for (int row = 0; row < active_rows; row++) { + const int row_idx = out_row + row; + const device uint8_t* sb_addr = + w + row_idx * row_bytes + ib * KQ_Q3_K_BLOCK_BYTES; + + const device uint16_t* q = reinterpret_cast( + kq_q3_k_qs_ptr(sb_addr) + q_offset_bytes); + const device uint16_t* h = reinterpret_cast( + kq_q3_k_hmask_ptr(sb_addr) + l0); + const device uint16_t* a = reinterpret_cast( + kq_q3_k_scales12_ptr(sb_addr)); + + uint32_t scales32, aux32; + thread uint16_t* scales16 = reinterpret_cast(&scales32); + thread const int8_t* scales = + reinterpret_cast(&scales32); + + scales16[0] = a[4]; + scales16[1] = a[5]; + aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030u; + scales16[0] = a[il + 0]; + scales16[1] = a[il + 1]; + scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0fu) | aux32; + + const U d_all = U(kq_q3_k_d(sb_addr)); + + U s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; +#pragma unroll + for (int l = 0; l < 8; l += 2) { + const uint16_t qs = q[l / 2]; + s1 += yl[l + 0] * U(qs & qm[0]); + s2 += yl[l + 1] * U(qs & qm[1]); + s3 += ((h[l / 2] & hm[0]) ? U(0) : yl[l + 0]) + + ((h[l / 2] & hm[1]) ? U(0) : yl[l + 1]); + s4 += yl[l + 16] * U(qs & qm[2]); + s5 += yl[l + 17] * U(qs & qm[3]); + s6 += ((h[l / 2] & hm[2]) ? U(0) : yl[l + 16]) + + ((h[l / 2] & hm[3]) ? U(0) : yl[l + 17]); + } + U d1 = d_all * (s1 + s2 * (U(1) / U(256)) - s3 * v1); + U d2 = d_all * (s4 + s5 * (U(1) / U(256)) - s6 * v2); + sumf1[row] += d1 * U(int(scales[0]) - 32); + sumf2[row] += d2 * U(int(scales[2]) - 32); + + s1 = s2 = s3 = s4 = s5 = s6 = U(0); +#pragma unroll + for (int l = 0; l < 8; l += 2) { + const uint16_t qs = q[l / 2 + 8]; + s1 += yl[l + 8] * U(qs & qm[0]); + s2 += yl[l + 9] * U(qs & qm[1]); + s3 += ((h[l / 2 + 8] & hm[0]) ? U(0) : yl[l + 8]) + + ((h[l / 2 + 8] & hm[1]) ? U(0) : yl[l + 9]); + s4 += yl[l + 24] * U(qs & qm[2]); + s5 += yl[l + 25] * U(qs & qm[3]); + s6 += ((h[l / 2 + 8] & hm[2]) ? U(0) : yl[l + 24]) + + ((h[l / 2 + 8] & hm[3]) ? U(0) : yl[l + 25]); + } + d1 = d_all * (s1 + s2 * (U(1) / U(256)) - s3 * v1); + d2 = d_all * (s4 + s5 * (U(1) / U(256)) - s6 * v2); + sumf1[row] += d1 * U(int(scales[1]) - 32); + sumf2[row] += d2 * U(int(scales[3]) - 32); + } + } + + const U shift_div = U(1) / U(1u << shift); + for (int row = 0; row < results_per_simdgroup; row++) { + const U combined = (sumf1[row] + U(0.25) * sumf2[row]) * shift_div; + const U reduced = simd_sum(combined); + if (simd_lid == 0 && row < active_rows) { + y[out_row + row] = static_cast(reduced); + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqQ3_KBlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q3_K_SUPERBLOCK; + MLX_MTL_CONST int bytes_per_block = KQ_Q3_K_BLOCK_BYTES; + MLX_MTL_CONST int k_tile_size = 32; + MLX_MTL_CONST int k_tiles_per_block = weights_per_block / k_tile_size; + + static_assert( + BCOLS == k_tile_size, + "Q3_K loader requires BCOLS == 32 (one K-tile per iteration)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + + const int src_ld; + const int row_bytes; + const int tile_stride; + const short fixed_kt; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + short kt; + struct Caches { + T c1[n_reads]; + T c2[n_reads]; + T c3[n_reads]; + T c4[n_reads]; + T c5[n_reads]; + T c6[n_reads]; + T c7[n_reads]; + }; + metal::conditional_t cached; + + KqQ3_KBlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int col_in_block = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? 0 + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + fixed_kt(reduction_dim == 0 ? (col_in_block / k_tile_size) : 0), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)), + kt(0) {} + + void load_unsafe() { + if constexpr (reduction_dim == 1) { + if (kt != 0) { + if (kt == 1) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c1[i]; + } else if (kt == 2) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c2[i]; + } else if (kt == 3) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c3[i]; + } else if (kt == 4) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c4[i]; + } else if (kt == 5) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c5[i]; + } else if (kt == 6) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c6[i]; + } else { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c7[i]; + } + return; + } + + const float d = float(*(const device half*)(src + KQ_Q3_K_D_OFFSET)); + const short scale_off = (bj >= 16) ? 1 : 0; + float es[8]; +#pragma unroll + for (short k = 0; k < 8; k++) { + const uint8_t sc = kq_q3_k_unpack_scale( + k * 2 + scale_off, src + KQ_Q3_K_SCALES_OFFSET); + es[k] = d * float((int)sc - 32); + } + + const device uint8_t* qs_a = src + KQ_Q3_K_QS_OFFSET + bj; + const device uint8_t* qs_b = src + KQ_Q3_K_QS_OFFSET + 32 + bj; + const device uint8_t* hm = src + KQ_Q3_K_HMASK_OFFSET + bj; + +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t qa = qs_a[i]; + const uint8_t qb = qs_b[i]; + const uint8_t h = hm[i]; + const uint8_t q2_0 = qa & 0x03; + const uint8_t q2_1 = (qa >> 2) & 0x03; + const uint8_t q2_2 = (qa >> 4) & 0x03; + const uint8_t q2_3 = (qa >> 6) & 0x03; + const uint8_t q2_4 = qb & 0x03; + const uint8_t q2_5 = (qb >> 2) & 0x03; + const uint8_t q2_6 = (qb >> 4) & 0x03; + const uint8_t q2_7 = (qb >> 6) & 0x03; + const int q3_0 = (int)q2_0 - (((h >> 0) & 1) ? 0 : 4); + const int q3_1 = (int)q2_1 - (((h >> 1) & 1) ? 0 : 4); + const int q3_2 = (int)q2_2 - (((h >> 2) & 1) ? 0 : 4); + const int q3_3 = (int)q2_3 - (((h >> 3) & 1) ? 0 : 4); + const int q3_4 = (int)q2_4 - (((h >> 4) & 1) ? 0 : 4); + const int q3_5 = (int)q2_5 - (((h >> 5) & 1) ? 0 : 4); + const int q3_6 = (int)q2_6 - (((h >> 6) & 1) ? 0 : 4); + const int q3_7 = (int)q2_7 - (((h >> 7) & 1) ? 0 : 4); + dst[i] = T(es[0] * float(q3_0)); + cached.c1[i] = T(es[1] * float(q3_1)); + cached.c2[i] = T(es[2] * float(q3_2)); + cached.c3[i] = T(es[3] * float(q3_3)); + cached.c4[i] = T(es[4] * float(q3_4)); + cached.c5[i] = T(es[5] * float(q3_5)); + cached.c6[i] = T(es[6] * float(q3_6)); + cached.c7[i] = T(es[7] * float(q3_7)); + } + return; + } + + const short kt_use = fixed_kt; + const short outer_half = kt_use / 4; + const short qs_shift = (kt_use & 3) * 2; + const short hmask_bit = kt_use; + const short scale_idx = kt_use * 2 + (bj >= 16 ? 1 : 0); + + const float d = float(*(const device half*)(src + KQ_Q3_K_D_OFFSET)); + const uint8_t sc_unsigned = + kq_q3_k_unpack_scale(scale_idx, src + KQ_Q3_K_SCALES_OFFSET); + const float eff_scale = d * float((int)sc_unsigned - 32); + + const device uint8_t* qs = src + KQ_Q3_K_QS_OFFSET + outer_half * 32 + bj; + const device uint8_t* hm = src + KQ_Q3_K_HMASK_OFFSET + bj; + +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t q2 = (qs[i] >> qs_shift) & 0x03; + const bool hbit_set = ((hm[i] >> hmask_bit) & 1) != 0; + const int q3 = (int)q2 - (hbit_set ? 0 : 4); + dst[i] = T(eff_scale * float(q3)); + } + } + + void load_safe(short2 src_tile_dim) { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + if (reduction_dim == 1) { + kt++; + if (kt == k_tiles_per_block) { + kt = 0; + src += bytes_per_block; + } + } else { + src += tile_stride; + } + } +}; + +template +[[kernel]] void kq_q3_k_qmm_t( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q3_K_SUPERBLOCK, "Q3_K kernel requires group_size=256"); + static_assert(bits == 3, "Q3_K kernel requires bits=3"); + constexpr int BM = 64, BK = 32, BN = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ3_KBlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_t_impl( + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q3_k_qmm_t_splitk( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& k_partition_size, + const constant int& split_k_partition_stride, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q3_K_SUPERBLOCK, "Q3_K kernel requires group_size=256"); + static_assert(bits == 3, "Q3_K kernel requires bits=3"); + constexpr int BM = 32, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ3_KBlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + + const int k_start = tid.z * k_partition_size; + x += k_start; + auto wl = w; + wl += (k_start / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += tid.z * static_cast(split_k_partition_stride); + + kq_qmm_t_impl( + wl, + x, + y, + Xs, + Ws, + K, + N, + M, + k_partition_size, + tid, + lid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void kq_q3_k_qmm_n( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q3_K_SUPERBLOCK, "Q3_K kernel requires group_size=256"); + static_assert(bits == 3, "Q3_K kernel requires bits=3"); + constexpr int BM = 64, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + using LoaderW = KqQ3_KBlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_n_impl( + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q3_k_qmv_fast( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q3_k_qmv_fast_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q3_k_qmv( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q3_k_qmv_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q3_k_dequantize( + const device uint8_t* w, + const device uint8_t* /* scales */, + device T* out, + const constant uint& num_weights, + uint gid [[thread_position_in_grid]]) { + static_assert( + group_size == KQ_Q3_K_SUPERBLOCK, "Q3_K kernel requires group_size=256"); + static_assert(bits == 3, "Q3_K kernel requires bits=3"); + kq_q3_k_dequantize_impl(w, out, num_weights, gid); +} + +// Q2_K: 84 bytes/256 weights. [scales[16]][qs[64]][fp16 d][fp16 dmin]. +// w[i] = d * (sc & 0xF) * q2 - dmin * (sc >> 4). Asymmetric. + +MLX_MTL_CONST int KQ_Q2_K_SUPERBLOCK = 256; +MLX_MTL_CONST int KQ_Q2_K_BLOCK_BYTES = 84; +MLX_MTL_CONST int KQ_Q2_K_SCALES_OFFSET = 0; +MLX_MTL_CONST int KQ_Q2_K_QS_OFFSET = 16; +MLX_MTL_CONST int KQ_Q2_K_D_OFFSET = 80; +MLX_MTL_CONST int KQ_Q2_K_DMIN_OFFSET = 82; + +inline float kq_q2_k_d(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q2_K_D_OFFSET)); +} +inline float kq_q2_k_dmin(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q2_K_DMIN_OFFSET)); +} +inline const device uint8_t* kq_q2_k_scales_ptr( + const device uint8_t* block_addr) { + return block_addr + KQ_Q2_K_SCALES_OFFSET; +} +inline const device uint8_t* kq_q2_k_qs_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q2_K_QS_OFFSET; +} + +template +METAL_FUNC void kq_q2_k_dequantize_impl( + const device uint8_t* w, + device T* out, + const constant uint& num_weights, + uint gid) { + if (gid >= num_weights) { + return; + } + const int sb_id = gid / KQ_Q2_K_SUPERBLOCK; + const int within_sb = gid - sb_id * KQ_Q2_K_SUPERBLOCK; + + const int outer_half = within_sb / 128; + const int within_outer = within_sb - outer_half * 128; + const int shift_idx = within_outer / 32; + const int within_shift = within_outer - shift_idx * 32; + + const device uint8_t* sb_addr = w + sb_id * KQ_Q2_K_BLOCK_BYTES; + const float d = kq_q2_k_d(sb_addr); + const float dmin = kq_q2_k_dmin(sb_addr); + const int scale_idx = within_sb / 16; + const uint8_t sc_byte = kq_q2_k_scales_ptr(sb_addr)[scale_idx]; + const float eff_scale = d * float(sc_byte & 0x0F); + const float eff_min = dmin * float(sc_byte >> 4); + + const uint8_t qs_byte = + kq_q2_k_qs_ptr(sb_addr)[outer_half * 32 + within_shift]; + const uint8_t q2 = (qs_byte >> (shift_idx * 2)) & 0x03; + out[gid] = T(eff_scale * float(q2) - eff_min); +} + +template +METAL_FUNC void kq_q2_k_qmv_fast_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q2_K_SUPERBLOCK, "Q2_K kernel requires group_size=256"); + static_assert(bits == 2, "Q2_K kernel requires bits=2"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 2; + constexpr int sb_stride = 4; + + typedef float U; + thread U yl[32]; + thread U result[results_per_simdgroup] = {0}; + + const int ix = simd_lid / 8; + const int it = simd_lid % 8; + const int iq = it / 4; + const int ir = it % 4; + const int is = (8 * ir) / 16; + + const int row_bytes = in_vec_size * KQ_Q2_K_BLOCK_BYTES / KQ_Q2_K_SUPERBLOCK; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + const int nb = in_vec_size / KQ_Q2_K_SUPERBLOCK; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += sb_stride) { + const int x_base = ib * KQ_Q2_K_SUPERBLOCK + 128 * iq + 8 * ir; + U sumy[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i++) { + yl[i + 0] = U(x[x_base + i + 0]); + sumy[0] += yl[i + 0]; + yl[i + 8] = U(x[x_base + i + 32]); + sumy[1] += yl[i + 8]; + yl[i + 16] = U(x[x_base + i + 64]); + sumy[2] += yl[i + 16]; + yl[i + 24] = U(x[x_base + i + 96]); + sumy[3] += yl[i + 24]; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + const int row_idx = out_row + row; + const device uint8_t* sb_addr = + w + row_idx * row_bytes + ib * KQ_Q2_K_BLOCK_BYTES; + + const device uint8_t* sc = kq_q2_k_scales_ptr(sb_addr) + 8 * iq + is; + const device uint16_t* qs = + reinterpret_cast(kq_q2_k_qs_ptr(sb_addr)) + + 16 * iq + 4 * ir; + + U acc1[4] = {U(0), U(0), U(0), U(0)}; + U acc2[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const uint16_t qs_i = qs[i / 2]; + acc1[0] += yl[i + 0] * U(qs_i & 0x0003); + acc2[0] += yl[i + 1] * U(qs_i & 0x0300); + acc1[1] += yl[i + 8] * U(qs_i & 0x000c); + acc2[1] += yl[i + 9] * U(qs_i & 0x0c00); + acc1[2] += yl[i + 16] * U(qs_i & 0x0030); + acc2[2] += yl[i + 17] * U(qs_i & 0x3000); + acc1[3] += yl[i + 24] * U(qs_i & 0x00c0); + acc2[3] += yl[i + 25] * U(qs_i & 0xc000); + } + + const U d = U(kq_q2_k_d(sb_addr)); + const U dmin = U(kq_q2_k_dmin(sb_addr)); + result[row] += d * + ((acc1[0] + acc2[0] * (U(1) / U(256))) * U(sc[0] & 0x0F) + + (acc1[1] + acc2[1] * (U(1) / U(256))) * U(sc[2] & 0x0F) * + (U(1) / U(4)) + + (acc1[2] + acc2[2] * (U(1) / U(256))) * U(sc[4] & 0x0F) * + (U(1) / U(16)) + + (acc1[3] + acc2[3] * (U(1) / U(256))) * U(sc[6] & 0x0F) * + (U(1) / U(64))) - + dmin * (U(1) / U(16)) * + (sumy[0] * U(sc[0] & 0xF0) + sumy[1] * U(sc[2] & 0xF0) + + sumy[2] * U(sc[4] & 0xF0) + sumy[3] * U(sc[6] & 0xF0)); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void kq_q2_k_qmv_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q2_K_SUPERBLOCK, "Q2_K kernel requires group_size=256"); + static_assert(bits == 2, "Q2_K kernel requires bits=2"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 2; + constexpr int sb_stride = 4; + + typedef float U; + thread U yl[32]; + thread U result[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q2_K_BLOCK_BYTES / KQ_Q2_K_SUPERBLOCK; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + if (out_row >= out_vec_size) { + return; + } + const int max_row = min(out_vec_size, out_row + results_per_simdgroup); + const int active_rows = max_row - out_row; + + const int ix = simd_lid / 8; + const int it = simd_lid % 8; + const int iq = it / 4; + const int ir = it % 4; + const int is = (8 * ir) / 16; + + const int nb = in_vec_size / KQ_Q2_K_SUPERBLOCK; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += sb_stride) { + const int x_base = ib * KQ_Q2_K_SUPERBLOCK + 128 * iq + 8 * ir; + U sumy[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i++) { + yl[i + 0] = U(x[x_base + i + 0]); + sumy[0] += yl[i + 0]; + yl[i + 8] = U(x[x_base + i + 32]); + sumy[1] += yl[i + 8]; + yl[i + 16] = U(x[x_base + i + 64]); + sumy[2] += yl[i + 16]; + yl[i + 24] = U(x[x_base + i + 96]); + sumy[3] += yl[i + 24]; + } + + for (int row = 0; row < active_rows; row++) { + const int row_idx = out_row + row; + const device uint8_t* sb_addr = + w + row_idx * row_bytes + ib * KQ_Q2_K_BLOCK_BYTES; + + const device uint8_t* sc = kq_q2_k_scales_ptr(sb_addr) + 8 * iq + is; + const device uint16_t* qs = + reinterpret_cast(kq_q2_k_qs_ptr(sb_addr)) + + 16 * iq + 4 * ir; + + U acc1[4] = {U(0), U(0), U(0), U(0)}; + U acc2[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const uint16_t qs_i = qs[i / 2]; + acc1[0] += yl[i + 0] * U(qs_i & 0x0003); + acc2[0] += yl[i + 1] * U(qs_i & 0x0300); + acc1[1] += yl[i + 8] * U(qs_i & 0x000c); + acc2[1] += yl[i + 9] * U(qs_i & 0x0c00); + acc1[2] += yl[i + 16] * U(qs_i & 0x0030); + acc2[2] += yl[i + 17] * U(qs_i & 0x3000); + acc1[3] += yl[i + 24] * U(qs_i & 0x00c0); + acc2[3] += yl[i + 25] * U(qs_i & 0xc000); + } + + const U d = U(kq_q2_k_d(sb_addr)); + const U dmin = U(kq_q2_k_dmin(sb_addr)); + result[row] += d * + ((acc1[0] + acc2[0] * (U(1) / U(256))) * U(sc[0] & 0x0F) + + (acc1[1] + acc2[1] * (U(1) / U(256))) * U(sc[2] & 0x0F) * + (U(1) / U(4)) + + (acc1[2] + acc2[2] * (U(1) / U(256))) * U(sc[4] & 0x0F) * + (U(1) / U(16)) + + (acc1[3] + acc2[3] * (U(1) / U(256))) * U(sc[6] & 0x0F) * + (U(1) / U(64))) - + dmin * (U(1) / U(16)) * + (sumy[0] * U(sc[0] & 0xF0) + sumy[1] * U(sc[2] & 0xF0) + + sumy[2] * U(sc[4] & 0xF0) + sumy[3] * U(sc[6] & 0xF0)); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0 && row < active_rows) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqQ2_KBlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q2_K_SUPERBLOCK; + MLX_MTL_CONST int bytes_per_block = KQ_Q2_K_BLOCK_BYTES; + MLX_MTL_CONST int k_tile_size = 32; + MLX_MTL_CONST int k_tiles_per_block = weights_per_block / k_tile_size; + + static_assert( + BCOLS == k_tile_size, + "Q2_K loader requires BCOLS == 32 (one K-tile per iteration)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + + const int src_ld; + const int row_bytes; + const int tile_stride; + const short fixed_kt; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + short kt; + struct Caches { + T c1[n_reads]; + T c2[n_reads]; + T c3[n_reads]; + T c4[n_reads]; + T c5[n_reads]; + T c6[n_reads]; + T c7[n_reads]; + }; + metal::conditional_t cached; + + KqQ2_KBlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int col_in_block = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? 0 + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + fixed_kt(reduction_dim == 0 ? (col_in_block / k_tile_size) : 0), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)), + kt(0) {} + + void load_unsafe() { + if constexpr (reduction_dim == 1) { + if (kt != 0) { + if (kt == 1) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c1[i]; + } else if (kt == 2) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c2[i]; + } else if (kt == 3) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c3[i]; + } else if (kt == 4) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c4[i]; + } else if (kt == 5) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c5[i]; + } else if (kt == 6) { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c6[i]; + } else { +#pragma unroll + for (short i = 0; i < n_reads; i++) + dst[i] = cached.c7[i]; + } + return; + } + + const float d = float(*(const device half*)(src + KQ_Q2_K_D_OFFSET)); + const float dmin = + float(*(const device half*)(src + KQ_Q2_K_DMIN_OFFSET)); + const short scale_off = (bj >= 16) ? 1 : 0; + float es[8]; + float em[8]; +#pragma unroll + for (short k = 0; k < 8; k++) { + const uint8_t sc_byte = src[KQ_Q2_K_SCALES_OFFSET + k * 2 + scale_off]; + es[k] = d * float(sc_byte & 0x0F); + em[k] = dmin * float(sc_byte >> 4); + } + + static_assert( + n_reads == 8 || n_reads == 16, + "Q2_K ALU vector load supports n_reads=8 or 16 (uint)."); + const device uint8_t* qs_a = src + KQ_Q2_K_QS_OFFSET + bj; + const device uint8_t* qs_b = src + KQ_Q2_K_QS_OFFSET + 32 + bj; + uint8_t qa_b[n_reads]; + uint8_t qb_b[n_reads]; +#pragma unroll + for (short v = 0; v < n_reads / 4; v++) { + const uint qs_a_v = *reinterpret_cast(qs_a + v * 4); + const uint qs_b_v = *reinterpret_cast(qs_b + v * 4); + *reinterpret_cast(&qa_b[v * 4]) = qs_a_v; + *reinterpret_cast(&qb_b[v * 4]) = qs_b_v; + } + +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t qa = qa_b[i]; + const uint8_t qb = qb_b[i]; + const uint8_t q2_0 = qa & 0x03; + const uint8_t q2_1 = (qa >> 2) & 0x03; + const uint8_t q2_2 = (qa >> 4) & 0x03; + const uint8_t q2_3 = (qa >> 6) & 0x03; + const uint8_t q2_4 = qb & 0x03; + const uint8_t q2_5 = (qb >> 2) & 0x03; + const uint8_t q2_6 = (qb >> 4) & 0x03; + const uint8_t q2_7 = (qb >> 6) & 0x03; + dst[i] = T(es[0] * float(q2_0) - em[0]); + cached.c1[i] = T(es[1] * float(q2_1) - em[1]); + cached.c2[i] = T(es[2] * float(q2_2) - em[2]); + cached.c3[i] = T(es[3] * float(q2_3) - em[3]); + cached.c4[i] = T(es[4] * float(q2_4) - em[4]); + cached.c5[i] = T(es[5] * float(q2_5) - em[5]); + cached.c6[i] = T(es[6] * float(q2_6) - em[6]); + cached.c7[i] = T(es[7] * float(q2_7) - em[7]); + } + return; + } + + const short kt_use = fixed_kt; + const short outer_half = kt_use / 4; + const short qs_shift = (kt_use & 3) * 2; + const short scale_idx = kt_use * 2 + (bj >= 16 ? 1 : 0); + + const float d = float(*(const device half*)(src + KQ_Q2_K_D_OFFSET)); + const float dmin = float(*(const device half*)(src + KQ_Q2_K_DMIN_OFFSET)); + const uint8_t sc_byte = src[KQ_Q2_K_SCALES_OFFSET + scale_idx]; + const float eff_scale = d * float(sc_byte & 0x0F); + const float eff_min = dmin * float(sc_byte >> 4); + + const device uint8_t* qs = src + KQ_Q2_K_QS_OFFSET + outer_half * 32 + bj; + +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t q2 = (qs[i] >> qs_shift) & 0x03; + dst[i] = T(eff_scale * float(q2) - eff_min); + } + } + + void load_safe(short2 src_tile_dim) { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + if (reduction_dim == 1) { + kt++; + if (kt == k_tiles_per_block) { + kt = 0; + src += bytes_per_block; + } + } else { + src += tile_stride; + } + } +}; + +template +[[kernel]] void kq_q2_k_qmm_t( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q2_K_SUPERBLOCK, "Q2_K kernel requires group_size=256"); + static_assert(bits == 2, "Q2_K kernel requires bits=2"); + constexpr int BM = 64, BK = 32, BN = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ2_KBlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_t_impl( + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q2_k_qmm_t_splitk( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& k_partition_size, + const constant int& split_k_partition_stride, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q2_K_SUPERBLOCK, "Q2_K kernel requires group_size=256"); + static_assert(bits == 2, "Q2_K kernel requires bits=2"); + constexpr int BM = 32, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ2_KBlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + + const int k_start = tid.z * k_partition_size; + x += k_start; + auto wl = w; + wl += (k_start / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += tid.z * static_cast(split_k_partition_stride); + + kq_qmm_t_impl( + wl, + x, + y, + Xs, + Ws, + K, + N, + M, + k_partition_size, + tid, + lid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void kq_q2_k_qmm_n( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q2_K_SUPERBLOCK, "Q2_K kernel requires group_size=256"); + static_assert(bits == 2, "Q2_K kernel requires bits=2"); + constexpr int BM = 64, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + using LoaderW = KqQ2_KBlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_n_impl( + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q2_k_qmv_fast( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q2_k_qmv_fast_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q2_k_qmv( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q2_k_qmv_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q2_k_dequantize( + const device uint8_t* w, + const device uint8_t* /* scales */, + device T* out, + const constant uint& num_weights, + uint gid [[thread_position_in_grid]]) { + static_assert( + group_size == KQ_Q2_K_SUPERBLOCK, "Q2_K kernel requires group_size=256"); + static_assert(bits == 2, "Q2_K kernel requires bits=2"); + kq_q2_k_dequantize_impl(w, out, num_weights, gid); +} + +#define KQUANT_DEFINE_GATHER_KERNELS(CODEC, LOADER) \ + template \ + [[kernel]] void kq_##CODEC##_gather_qmv_fast( \ + const device uint8_t* w, \ + const device uint8_t* /* scales */, \ + const device T* x, \ + const device uint32_t* lhs_indices, \ + const device uint32_t* rhs_indices, \ + device T* y, \ + const constant int& in_vec_size, \ + const constant int& out_vec_size, \ + const constant int& x_batch_ndims, \ + const constant int* x_shape, \ + const constant int64_t* x_strides, \ + const constant int& w_batch_ndims, \ + const constant int* w_shape, \ + const constant int64_t* w_strides, \ + const constant int64_t* /* s_strides */, \ + const constant int& batch_ndims, \ + const constant int* batch_shape, \ + const constant int64_t* lhs_strides, \ + const constant int64_t* rhs_strides, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]) { \ + int M = x_shape[x_batch_ndims]; \ + kq_adjust_matrix_offsets( \ + x, \ + w, \ + lhs_indices, \ + rhs_indices, \ + y, \ + out_vec_size * M, \ + batch_ndims, \ + batch_shape, \ + lhs_strides, \ + rhs_strides, \ + x_batch_ndims, \ + x_shape, \ + x_strides, \ + w_batch_ndims, \ + w_shape, \ + w_strides, \ + tid); \ + kq_##CODEC##_qmv_fast_impl( \ + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); \ + } \ + \ + template \ + [[kernel]] void kq_##CODEC##_gather_qmv( \ + const device uint8_t* w, \ + const device uint8_t* /* scales */, \ + const device T* x, \ + const device uint32_t* lhs_indices, \ + const device uint32_t* rhs_indices, \ + device T* y, \ + const constant int& in_vec_size, \ + const constant int& out_vec_size, \ + const constant int& x_batch_ndims, \ + const constant int* x_shape, \ + const constant int64_t* x_strides, \ + const constant int& w_batch_ndims, \ + const constant int* w_shape, \ + const constant int64_t* w_strides, \ + const constant int64_t* /* s_strides */, \ + const constant int& batch_ndims, \ + const constant int* batch_shape, \ + const constant int64_t* lhs_strides, \ + const constant int64_t* rhs_strides, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]) { \ + int M = x_shape[x_batch_ndims]; \ + kq_adjust_matrix_offsets( \ + x, \ + w, \ + lhs_indices, \ + rhs_indices, \ + y, \ + out_vec_size * M, \ + batch_ndims, \ + batch_shape, \ + lhs_strides, \ + rhs_strides, \ + x_batch_ndims, \ + x_shape, \ + x_strides, \ + w_batch_ndims, \ + w_shape, \ + w_strides, \ + tid); \ + kq_##CODEC##_qmv_impl( \ + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); \ + } \ + \ + template \ + [[kernel]] void kq_##CODEC##_gather_qmm_t( \ + const device uint8_t* w, \ + const device uint8_t* /* scales */, \ + const device T* x, \ + const device uint32_t* lhs_indices, \ + const device uint32_t* rhs_indices, \ + device T* y, \ + const constant int& K, \ + const constant int& N, \ + const constant int& M, \ + const constant int& x_batch_ndims, \ + const constant int* x_shape, \ + const constant int64_t* x_strides, \ + const constant int& w_batch_ndims, \ + const constant int* w_shape, \ + const constant int64_t* w_strides, \ + const constant int64_t* /* s_strides */, \ + const constant int& batch_ndims, \ + const constant int* batch_shape, \ + const constant int64_t* lhs_strides, \ + const constant int64_t* rhs_strides, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]) { \ + kq_adjust_matrix_offsets( \ + x, \ + w, \ + lhs_indices, \ + rhs_indices, \ + y, \ + M * N, \ + batch_ndims, \ + batch_shape, \ + lhs_strides, \ + rhs_strides, \ + x_batch_ndims, \ + x_shape, \ + x_strides, \ + w_batch_ndims, \ + w_shape, \ + w_strides, \ + tid); \ + constexpr int BM = 32, BK = 32, BN = 32; \ + constexpr int BK_padded = (BK + 16 / sizeof(T)); \ + threadgroup T Xs[BM * BK_padded]; \ + threadgroup T Ws[BN * BK_padded]; \ + using LoaderW = LOADER< \ + T, \ + BN, \ + BK, \ + BK_padded, \ + /*reduction_dim=*/1, \ + /*tgp_size=*/2 * 2 * SIMD_SIZE>; \ + kq_qmm_t_impl( \ + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); \ + } \ + \ + template \ + [[kernel]] void kq_##CODEC##_gather_qmm_n( \ + const device uint8_t* w, \ + const device uint8_t* /* scales */, \ + const device T* x, \ + const device uint32_t* lhs_indices, \ + const device uint32_t* rhs_indices, \ + device T* y, \ + const constant int& K, \ + const constant int& N, \ + const constant int& M, \ + const constant int& x_batch_ndims, \ + const constant int* x_shape, \ + const constant int64_t* x_strides, \ + const constant int& w_batch_ndims, \ + const constant int* w_shape, \ + const constant int64_t* w_strides, \ + const constant int64_t* /* s_strides */, \ + const constant int& batch_ndims, \ + const constant int* batch_shape, \ + const constant int64_t* lhs_strides, \ + const constant int64_t* rhs_strides, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]) { \ + kq_adjust_matrix_offsets( \ + x, \ + w, \ + lhs_indices, \ + rhs_indices, \ + y, \ + M * N, \ + batch_ndims, \ + batch_shape, \ + lhs_strides, \ + rhs_strides, \ + x_batch_ndims, \ + x_shape, \ + x_strides, \ + w_batch_ndims, \ + w_shape, \ + w_strides, \ + tid); \ + constexpr int BM = 32, BK = 32, BN = 32; \ + constexpr int BK_padded = (BK + 16 / sizeof(T)); \ + constexpr int BN_padded = (BN + 16 / sizeof(T)); \ + threadgroup T Xs[BM * BK_padded]; \ + threadgroup T Ws[BK * BN_padded]; \ + using LoaderW = LOADER< \ + T, \ + BK, \ + BN, \ + BN_padded, \ + /*reduction_dim=*/0, \ + /*tgp_size=*/2 * 2 * SIMD_SIZE>; \ + kq_qmm_n_impl( \ + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); \ + } + +KQUANT_DEFINE_GATHER_KERNELS(q8_0, KqQ8_0BlockLoader) +KQUANT_DEFINE_GATHER_KERNELS(q4_0, KqQ4_0BlockLoader) +KQUANT_DEFINE_GATHER_KERNELS(q4_1, KqQ4_1BlockLoader) +KQUANT_DEFINE_GATHER_KERNELS(q5_0, KqQ5_0BlockLoader) +KQUANT_DEFINE_GATHER_KERNELS(q5_1, KqQ5_1BlockLoader) +KQUANT_DEFINE_GATHER_KERNELS(q4_k, KqQ4_KBlockLoader) +KQUANT_DEFINE_GATHER_KERNELS(q5_k, KqQ5_KBlockLoader) +KQUANT_DEFINE_GATHER_KERNELS(q6_k, KqQ6_KBlockLoader) +KQUANT_DEFINE_GATHER_KERNELS(q3_k, KqQ3_KBlockLoader) +KQUANT_DEFINE_GATHER_KERNELS(q2_k, KqQ2_KBlockLoader) + +#undef KQUANT_DEFINE_GATHER_KERNELS diff --git a/mlx/backend/metal/kernels/kq_quantized.metal b/mlx/backend/metal/kernels/kq_quantized.metal new file mode 100644 index 0000000000..362afb2057 --- /dev/null +++ b/mlx/backend/metal/kernels/kq_quantized.metal @@ -0,0 +1,323 @@ +// Copyright © 2026 Apple Inc. + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/quantized_utils.h" +#include "mlx/backend/metal/kernels/kq_quantized.h" + +#define instantiate_kquant_batched(func, type, gs, bits, batched, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_" #func "_" #type "_gs_" #gs "_b_" #bits \ + "_batch_" #batched, \ + kq_ ## codec ## _ ## func, \ + type, \ + gs, \ + bits, \ + batched) + +#define instantiate_kquant_qmm_t(type, gs, bits, aligned_N, batched, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_qmm_t_" #type "_gs_" #gs "_b_" #bits \ + "_alN_" #aligned_N "_batch_" #batched, \ + kq_ ## codec ## _qmm_t, \ + type, \ + gs, \ + bits, \ + aligned_N, \ + batched) + +#define instantiate_kquant_qmm_t_splitk(type, gs, bits, aligned_N, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_qmm_t_splitk_" #type "_gs_" #gs "_b_" #bits \ + "_alN_" #aligned_N, \ + kq_ ## codec ## _qmm_t_splitk, \ + type, \ + gs, \ + bits, \ + aligned_N) + +#define instantiate_kquant_qmm_n(type, gs, bits, batched, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_qmm_n_" #type "_gs_" #gs "_b_" #bits \ + "_batch_" #batched, \ + kq_ ## codec ## _qmm_n, \ + type, \ + gs, \ + bits, \ + batched) + +#define instantiate_kquant_gather_qmv(func, type, gs, bits, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_" #func "_" #type "_gs_" #gs "_b_" #bits, \ + kq_ ## codec ## _ ## func, \ + type, \ + gs, \ + bits) + +#define instantiate_kquant_gather_qmm_t(type, gs, bits, aligned_N, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_gather_qmm_t_" #type "_gs_" #gs "_b_" #bits \ + "_alN_" #aligned_N, \ + kq_ ## codec ## _gather_qmm_t, \ + type, \ + gs, \ + bits, \ + aligned_N) + +#define instantiate_kquant_gather_qmm_n(type, gs, bits, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_gather_qmm_n_" #type "_gs_" #gs "_b_" #bits, \ + kq_ ## codec ## _gather_qmm_n, \ + type, \ + gs, \ + bits) + +#define instantiate_kquant_dequantize(type, gs, bits, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_dequantize_" #type "_gs_" #gs "_b_" #bits, \ + kq_ ## codec ## _dequantize, \ + type, \ + gs, \ + bits) + +#define instantiate_kquant_q8_0_for_type(type) \ + instantiate_kquant_batched(qmv_fast, type, 32, 8, 0, q8_0) \ + instantiate_kquant_batched(qmv_fast, type, 32, 8, 1, q8_0) \ + instantiate_kquant_batched(qmv, type, 32, 8, 0, q8_0) \ + instantiate_kquant_batched(qmv, type, 32, 8, 1, q8_0) \ + instantiate_kquant_qmm_t(type, 32, 8, true, 0, q8_0) \ + instantiate_kquant_qmm_t(type, 32, 8, true, 1, q8_0) \ + instantiate_kquant_qmm_t(type, 32, 8, false, 0, q8_0) \ + instantiate_kquant_qmm_t(type, 32, 8, false, 1, q8_0) \ + instantiate_kquant_qmm_t_splitk(type, 32, 8, true, q8_0) \ + instantiate_kquant_qmm_t_splitk(type, 32, 8, false, q8_0) \ + instantiate_kquant_qmm_n(type, 32, 8, 0, q8_0) \ + instantiate_kquant_qmm_n(type, 32, 8, 1, q8_0) \ + instantiate_kquant_gather_qmv(gather_qmv_fast, type, 32, 8, q8_0) \ + instantiate_kquant_gather_qmv(gather_qmv, type, 32, 8, q8_0) \ + instantiate_kquant_gather_qmm_t(type, 32, 8, true, q8_0) \ + instantiate_kquant_gather_qmm_t(type, 32, 8, false, q8_0) \ + instantiate_kquant_gather_qmm_n(type, 32, 8, q8_0) \ + instantiate_kquant_dequantize(type, 32, 8, q8_0) + +instantiate_kquant_q8_0_for_type(float) +instantiate_kquant_q8_0_for_type(bfloat16_t) +instantiate_kquant_q8_0_for_type(float16_t) + +#define instantiate_kquant_q5_1_for_type(type) \ + instantiate_kquant_batched(qmv_fast, type, 32, 5, 0, q5_1) \ + instantiate_kquant_batched(qmv_fast, type, 32, 5, 1, q5_1) \ + instantiate_kquant_batched(qmv, type, 32, 5, 0, q5_1) \ + instantiate_kquant_batched(qmv, type, 32, 5, 1, q5_1) \ + instantiate_kquant_qmm_t(type, 32, 5, true, 0, q5_1) \ + instantiate_kquant_qmm_t(type, 32, 5, true, 1, q5_1) \ + instantiate_kquant_qmm_t(type, 32, 5, false, 0, q5_1) \ + instantiate_kquant_qmm_t(type, 32, 5, false, 1, q5_1) \ + instantiate_kquant_qmm_t_splitk(type, 32, 5, true, q5_1) \ + instantiate_kquant_qmm_t_splitk(type, 32, 5, false, q5_1) \ + instantiate_kquant_qmm_n(type, 32, 5, 0, q5_1) \ + instantiate_kquant_qmm_n(type, 32, 5, 1, q5_1) \ + instantiate_kquant_gather_qmv(gather_qmv_fast, type, 32, 5, q5_1) \ + instantiate_kquant_gather_qmv(gather_qmv, type, 32, 5, q5_1) \ + instantiate_kquant_gather_qmm_t(type, 32, 5, true, q5_1) \ + instantiate_kquant_gather_qmm_t(type, 32, 5, false, q5_1) \ + instantiate_kquant_gather_qmm_n(type, 32, 5, q5_1) \ + instantiate_kquant_dequantize(type, 32, 5, q5_1) + +instantiate_kquant_q5_1_for_type(float) +instantiate_kquant_q5_1_for_type(bfloat16_t) +instantiate_kquant_q5_1_for_type(float16_t) + +#define instantiate_kquant_q4_0_for_type(type) \ + instantiate_kquant_batched(qmv_fast, type, 32, 4, 0, q4_0) \ + instantiate_kquant_batched(qmv_fast, type, 32, 4, 1, q4_0) \ + instantiate_kquant_batched(qmv, type, 32, 4, 0, q4_0) \ + instantiate_kquant_batched(qmv, type, 32, 4, 1, q4_0) \ + instantiate_kquant_qmm_t(type, 32, 4, true, 0, q4_0) \ + instantiate_kquant_qmm_t(type, 32, 4, true, 1, q4_0) \ + instantiate_kquant_qmm_t(type, 32, 4, false, 0, q4_0) \ + instantiate_kquant_qmm_t(type, 32, 4, false, 1, q4_0) \ + instantiate_kquant_qmm_t_splitk(type, 32, 4, true, q4_0) \ + instantiate_kquant_qmm_t_splitk(type, 32, 4, false, q4_0) \ + instantiate_kquant_qmm_n(type, 32, 4, 0, q4_0) \ + instantiate_kquant_qmm_n(type, 32, 4, 1, q4_0) \ + instantiate_kquant_gather_qmv(gather_qmv_fast, type, 32, 4, q4_0) \ + instantiate_kquant_gather_qmv(gather_qmv, type, 32, 4, q4_0) \ + instantiate_kquant_gather_qmm_t(type, 32, 4, true, q4_0) \ + instantiate_kquant_gather_qmm_t(type, 32, 4, false, q4_0) \ + instantiate_kquant_gather_qmm_n(type, 32, 4, q4_0) \ + instantiate_kquant_dequantize(type, 32, 4, q4_0) + +instantiate_kquant_q4_0_for_type(float) +instantiate_kquant_q4_0_for_type(bfloat16_t) +instantiate_kquant_q4_0_for_type(float16_t) + +#define instantiate_kquant_q4_1_for_type(type) \ + instantiate_kquant_batched(qmv_fast, type, 32, 4, 0, q4_1) \ + instantiate_kquant_batched(qmv_fast, type, 32, 4, 1, q4_1) \ + instantiate_kquant_batched(qmv, type, 32, 4, 0, q4_1) \ + instantiate_kquant_batched(qmv, type, 32, 4, 1, q4_1) \ + instantiate_kquant_qmm_t(type, 32, 4, true, 0, q4_1) \ + instantiate_kquant_qmm_t(type, 32, 4, true, 1, q4_1) \ + instantiate_kquant_qmm_t(type, 32, 4, false, 0, q4_1) \ + instantiate_kquant_qmm_t(type, 32, 4, false, 1, q4_1) \ + instantiate_kquant_qmm_t_splitk(type, 32, 4, true, q4_1) \ + instantiate_kquant_qmm_t_splitk(type, 32, 4, false, q4_1) \ + instantiate_kquant_qmm_n(type, 32, 4, 0, q4_1) \ + instantiate_kquant_qmm_n(type, 32, 4, 1, q4_1) \ + instantiate_kquant_gather_qmv(gather_qmv_fast, type, 32, 4, q4_1) \ + instantiate_kquant_gather_qmv(gather_qmv, type, 32, 4, q4_1) \ + instantiate_kquant_gather_qmm_t(type, 32, 4, true, q4_1) \ + instantiate_kquant_gather_qmm_t(type, 32, 4, false, q4_1) \ + instantiate_kquant_gather_qmm_n(type, 32, 4, q4_1) \ + instantiate_kquant_dequantize(type, 32, 4, q4_1) + +instantiate_kquant_q4_1_for_type(float) +instantiate_kquant_q4_1_for_type(bfloat16_t) +instantiate_kquant_q4_1_for_type(float16_t) + +#define instantiate_kquant_q5_0_for_type(type) \ + instantiate_kquant_batched(qmv_fast, type, 32, 5, 0, q5_0) \ + instantiate_kquant_batched(qmv_fast, type, 32, 5, 1, q5_0) \ + instantiate_kquant_batched(qmv, type, 32, 5, 0, q5_0) \ + instantiate_kquant_batched(qmv, type, 32, 5, 1, q5_0) \ + instantiate_kquant_qmm_t(type, 32, 5, true, 0, q5_0) \ + instantiate_kquant_qmm_t(type, 32, 5, true, 1, q5_0) \ + instantiate_kquant_qmm_t(type, 32, 5, false, 0, q5_0) \ + instantiate_kquant_qmm_t(type, 32, 5, false, 1, q5_0) \ + instantiate_kquant_qmm_t_splitk(type, 32, 5, true, q5_0) \ + instantiate_kquant_qmm_t_splitk(type, 32, 5, false, q5_0) \ + instantiate_kquant_qmm_n(type, 32, 5, 0, q5_0) \ + instantiate_kquant_qmm_n(type, 32, 5, 1, q5_0) \ + instantiate_kquant_gather_qmv(gather_qmv_fast, type, 32, 5, q5_0) \ + instantiate_kquant_gather_qmv(gather_qmv, type, 32, 5, q5_0) \ + instantiate_kquant_gather_qmm_t(type, 32, 5, true, q5_0) \ + instantiate_kquant_gather_qmm_t(type, 32, 5, false, q5_0) \ + instantiate_kquant_gather_qmm_n(type, 32, 5, q5_0) \ + instantiate_kquant_dequantize(type, 32, 5, q5_0) + +instantiate_kquant_q5_0_for_type(float) +instantiate_kquant_q5_0_for_type(bfloat16_t) +instantiate_kquant_q5_0_for_type(float16_t) + +#define instantiate_kquant_q4_k_for_type(type) \ + instantiate_kquant_batched(qmv_fast, type, 256, 4, 0, q4_k) \ + instantiate_kquant_batched(qmv_fast, type, 256, 4, 1, q4_k) \ + instantiate_kquant_batched(qmv, type, 256, 4, 0, q4_k) \ + instantiate_kquant_batched(qmv, type, 256, 4, 1, q4_k) \ + instantiate_kquant_qmm_t(type, 256, 4, true, 0, q4_k) \ + instantiate_kquant_qmm_t(type, 256, 4, true, 1, q4_k) \ + instantiate_kquant_qmm_t(type, 256, 4, false, 0, q4_k) \ + instantiate_kquant_qmm_t(type, 256, 4, false, 1, q4_k) \ + instantiate_kquant_qmm_t_splitk(type, 256, 4, true, q4_k) \ + instantiate_kquant_qmm_t_splitk(type, 256, 4, false, q4_k) \ + instantiate_kquant_qmm_n(type, 256, 4, 0, q4_k) \ + instantiate_kquant_qmm_n(type, 256, 4, 1, q4_k) \ + instantiate_kquant_gather_qmv(gather_qmv_fast, type, 256, 4, q4_k) \ + instantiate_kquant_gather_qmv(gather_qmv, type, 256, 4, q4_k) \ + instantiate_kquant_gather_qmm_t(type, 256, 4, true, q4_k) \ + instantiate_kquant_gather_qmm_t(type, 256, 4, false, q4_k) \ + instantiate_kquant_gather_qmm_n(type, 256, 4, q4_k) \ + instantiate_kquant_dequantize(type, 256, 4, q4_k) + +instantiate_kquant_q4_k_for_type(float) +instantiate_kquant_q4_k_for_type(bfloat16_t) +instantiate_kquant_q4_k_for_type(float16_t) + +#define instantiate_kquant_q5_k_for_type(type) \ + instantiate_kquant_batched(qmv_fast, type, 256, 5, 0, q5_k) \ + instantiate_kquant_batched(qmv_fast, type, 256, 5, 1, q5_k) \ + instantiate_kquant_batched(qmv, type, 256, 5, 0, q5_k) \ + instantiate_kquant_batched(qmv, type, 256, 5, 1, q5_k) \ + instantiate_kquant_qmm_t(type, 256, 5, true, 0, q5_k) \ + instantiate_kquant_qmm_t(type, 256, 5, true, 1, q5_k) \ + instantiate_kquant_qmm_t(type, 256, 5, false, 0, q5_k) \ + instantiate_kquant_qmm_t(type, 256, 5, false, 1, q5_k) \ + instantiate_kquant_qmm_t_splitk(type, 256, 5, true, q5_k) \ + instantiate_kquant_qmm_t_splitk(type, 256, 5, false, q5_k) \ + instantiate_kquant_qmm_n(type, 256, 5, 0, q5_k) \ + instantiate_kquant_qmm_n(type, 256, 5, 1, q5_k) \ + instantiate_kquant_gather_qmv(gather_qmv_fast, type, 256, 5, q5_k) \ + instantiate_kquant_gather_qmv(gather_qmv, type, 256, 5, q5_k) \ + instantiate_kquant_gather_qmm_t(type, 256, 5, true, q5_k) \ + instantiate_kquant_gather_qmm_t(type, 256, 5, false, q5_k) \ + instantiate_kquant_gather_qmm_n(type, 256, 5, q5_k) \ + instantiate_kquant_dequantize(type, 256, 5, q5_k) + +instantiate_kquant_q5_k_for_type(float) +instantiate_kquant_q5_k_for_type(bfloat16_t) +instantiate_kquant_q5_k_for_type(float16_t) + +#define instantiate_kquant_q6_k_for_type(type) \ + instantiate_kquant_batched(qmv_fast, type, 256, 6, 0, q6_k) \ + instantiate_kquant_batched(qmv_fast, type, 256, 6, 1, q6_k) \ + instantiate_kquant_batched(qmv, type, 256, 6, 0, q6_k) \ + instantiate_kquant_batched(qmv, type, 256, 6, 1, q6_k) \ + instantiate_kquant_qmm_t(type, 256, 6, true, 0, q6_k) \ + instantiate_kquant_qmm_t(type, 256, 6, true, 1, q6_k) \ + instantiate_kquant_qmm_t(type, 256, 6, false, 0, q6_k) \ + instantiate_kquant_qmm_t(type, 256, 6, false, 1, q6_k) \ + instantiate_kquant_qmm_t_splitk(type, 256, 6, true, q6_k) \ + instantiate_kquant_qmm_t_splitk(type, 256, 6, false, q6_k) \ + instantiate_kquant_qmm_n(type, 256, 6, 0, q6_k) \ + instantiate_kquant_qmm_n(type, 256, 6, 1, q6_k) \ + instantiate_kquant_gather_qmv(gather_qmv_fast, type, 256, 6, q6_k) \ + instantiate_kquant_gather_qmv(gather_qmv, type, 256, 6, q6_k) \ + instantiate_kquant_gather_qmm_t(type, 256, 6, true, q6_k) \ + instantiate_kquant_gather_qmm_t(type, 256, 6, false, q6_k) \ + instantiate_kquant_gather_qmm_n(type, 256, 6, q6_k) \ + instantiate_kquant_dequantize(type, 256, 6, q6_k) + +instantiate_kquant_q6_k_for_type(float) +instantiate_kquant_q6_k_for_type(bfloat16_t) +instantiate_kquant_q6_k_for_type(float16_t) + +#define instantiate_kquant_q3_k_for_type(type) \ + instantiate_kquant_batched(qmv_fast, type, 256, 3, 0, q3_k) \ + instantiate_kquant_batched(qmv_fast, type, 256, 3, 1, q3_k) \ + instantiate_kquant_batched(qmv, type, 256, 3, 0, q3_k) \ + instantiate_kquant_batched(qmv, type, 256, 3, 1, q3_k) \ + instantiate_kquant_qmm_t(type, 256, 3, true, 0, q3_k) \ + instantiate_kquant_qmm_t(type, 256, 3, true, 1, q3_k) \ + instantiate_kquant_qmm_t(type, 256, 3, false, 0, q3_k) \ + instantiate_kquant_qmm_t(type, 256, 3, false, 1, q3_k) \ + instantiate_kquant_qmm_t_splitk(type, 256, 3, true, q3_k) \ + instantiate_kquant_qmm_t_splitk(type, 256, 3, false, q3_k) \ + instantiate_kquant_qmm_n(type, 256, 3, 0, q3_k) \ + instantiate_kquant_qmm_n(type, 256, 3, 1, q3_k) \ + instantiate_kquant_gather_qmv(gather_qmv_fast, type, 256, 3, q3_k) \ + instantiate_kquant_gather_qmv(gather_qmv, type, 256, 3, q3_k) \ + instantiate_kquant_gather_qmm_t(type, 256, 3, true, q3_k) \ + instantiate_kquant_gather_qmm_t(type, 256, 3, false, q3_k) \ + instantiate_kquant_gather_qmm_n(type, 256, 3, q3_k) \ + instantiate_kquant_dequantize(type, 256, 3, q3_k) + +instantiate_kquant_q3_k_for_type(float) +instantiate_kquant_q3_k_for_type(bfloat16_t) +instantiate_kquant_q3_k_for_type(float16_t) + +#define instantiate_kquant_q2_k_for_type(type) \ + instantiate_kquant_batched(qmv_fast, type, 256, 2, 0, q2_k) \ + instantiate_kquant_batched(qmv_fast, type, 256, 2, 1, q2_k) \ + instantiate_kquant_batched(qmv, type, 256, 2, 0, q2_k) \ + instantiate_kquant_batched(qmv, type, 256, 2, 1, q2_k) \ + instantiate_kquant_qmm_t(type, 256, 2, true, 0, q2_k) \ + instantiate_kquant_qmm_t(type, 256, 2, true, 1, q2_k) \ + instantiate_kquant_qmm_t(type, 256, 2, false, 0, q2_k) \ + instantiate_kquant_qmm_t(type, 256, 2, false, 1, q2_k) \ + instantiate_kquant_qmm_t_splitk(type, 256, 2, true, q2_k) \ + instantiate_kquant_qmm_t_splitk(type, 256, 2, false, q2_k) \ + instantiate_kquant_qmm_n(type, 256, 2, 0, q2_k) \ + instantiate_kquant_qmm_n(type, 256, 2, 1, q2_k) \ + instantiate_kquant_gather_qmv(gather_qmv_fast, type, 256, 2, q2_k) \ + instantiate_kquant_gather_qmv(gather_qmv, type, 256, 2, q2_k) \ + instantiate_kquant_gather_qmm_t(type, 256, 2, true, q2_k) \ + instantiate_kquant_gather_qmm_t(type, 256, 2, false, q2_k) \ + instantiate_kquant_gather_qmm_n(type, 256, 2, q2_k) \ + instantiate_kquant_dequantize(type, 256, 2, q2_k) + +instantiate_kquant_q2_k_for_type(float) +instantiate_kquant_q2_k_for_type(bfloat16_t) +instantiate_kquant_q2_k_for_type(float16_t) + // clang-format on diff --git a/mlx/backend/metal/kernels/kq_quantized_encode.h b/mlx/backend/metal/kernels/kq_quantized_encode.h new file mode 100644 index 0000000000..624a60fa7c --- /dev/null +++ b/mlx/backend/metal/kernels/kq_quantized_encode.h @@ -0,0 +1,1239 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/metal/kernels/kq_quantized.h" + +template +METAL_FUNC void kq_q8_0_quantize_impl( + const device T* w, + device uint8_t* out, + const constant uint& num_blocks, + uint gid) { + if (gid >= num_blocks) { + return; + } + const device T* x = w + gid * KQ_Q8_0_GROUP; + device uint8_t* block_addr = out + gid * KQ_Q8_0_BLOCK_BYTES; + + float amax = 0.0f; + for (int j = 0; j < KQ_Q8_0_GROUP; j++) { + amax = max(amax, fabs(float(x[j]))); + } + + const float d = amax / 127.0f; + const float id = (d != 0.0f) ? (1.0f / d) : 0.0f; + + *(device half*)(block_addr + KQ_Q8_0_D_OFFSET) = half(d); + + device int8_t* qs = (device int8_t*)(block_addr + KQ_Q8_0_Q_OFFSET); + for (int j = 0; j < KQ_Q8_0_GROUP; j++) { + float v = float(x[j]) * id; + qs[j] = int8_t(clamp(round(v), -127.0f, 127.0f)); + } +} + +template +[[kernel]] void kq_q8_0_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + const constant uint& num_blocks [[buffer(2)]], + const device float* imatrix [[buffer(3)]], + const constant uint& has_imatrix [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + static_assert(group_size == KQ_Q8_0_GROUP, "Q8_0 requires group_size=32"); + static_assert(bits == 8, "Q8_0 requires bits=8"); + (void)imatrix; + (void)has_imatrix; + kq_q8_0_quantize_impl(w, out, num_blocks, gid); +} + +inline int kq_nearest_int(float v) { + return int(rint(v)); +} + +template +inline float kq_make_qp_quants( + const threadgroup float* x, + const threadgroup float* qw, + int nmax, + thread uint8_t* L_out) { + float max_v = 0.0f; + for (int i = 0; i < N; i++) { + if (x[i] > max_v) + max_v = x[i]; + } + if (max_v < 1e-30f) { // GROUP_MAX_EPS + for (int i = 0; i < N; i++) + L_out[i] = 0; + return 0.0f; + } + float iscale = float(nmax) / max_v; + uint8_t L[N]; + for (int i = 0; i < N; i++) { + int l = kq_nearest_int(iscale * x[i]); + L[i] = uint8_t(max(0, min(nmax, l))); + } + float scale = 1.0f / iscale; + float best_mse = 0.0f; + for (int i = 0; i < N; i++) { + float diff = x[i] - scale * float(L[i]); + best_mse += qw[i] * diff * diff; + } + for (int is = -4; is <= 4; is++) { + if (is == 0) + continue; + float iscale_is = (0.1f * float(is) + float(nmax)) / max_v; + float scale_is = 1.0f / iscale_is; + float mse = 0.0f; + for (int i = 0; i < N; i++) { + int l = kq_nearest_int(iscale_is * x[i]); + l = min(nmax, l); + float diff = x[i] - scale_is * float(l); + mse += qw[i] * diff * diff; + } + if (mse < best_mse) { + best_mse = mse; + iscale = iscale_is; + } + } + float sumlx = 0.0f, suml2 = 0.0f; + for (int i = 0; i < N; i++) { + int l = kq_nearest_int(iscale * x[i]); + l = min(nmax, l); + L[i] = uint8_t(l); + sumlx += qw[i] * x[i] * float(l); + suml2 += qw[i] * float(l) * float(l); + } + for (int itry = 0; itry < 5; itry++) { + int n_changed = 0; + for (int i = 0; i < N; i++) { + float w = qw[i]; + float slx = sumlx - w * x[i] * float(L[i]); + float sl2 = suml2 - w * float(L[i]) * float(L[i]); + if (slx > 0.0f && sl2 > 0.0f) { + int new_l = kq_nearest_int(x[i] * sl2 / slx); + new_l = min(nmax, new_l); + if (new_l != int(L[i])) { + slx += w * x[i] * float(new_l); + sl2 += w * float(new_l) * float(new_l); + if (slx * slx * suml2 > sumlx * sumlx * sl2) { + L[i] = uint8_t(new_l); + sumlx = slx; + suml2 = sl2; + n_changed++; + } + } + } + } + if (n_changed == 0) + break; + } + for (int i = 0; i < N; i++) + L_out[i] = L[i]; + return (suml2 > 0.0f) ? (sumlx / suml2) : 0.0f; +} + +template +inline float kq_make_qkx3_quants( + const threadgroup float* x, + const threadgroup float* qw, + int nmax, + thread float& the_min) { + const float rmin = -0.9f; + const float rdelta = 0.05f; + const int nstep = 36; + + float min_v = x[0]; + float max_v = x[0]; + float sum_w = qw[0]; + float sum_x = sum_w * x[0]; + for (int i = 1; i < N; i++) { + if (x[i] < min_v) + min_v = x[i]; + if (x[i] > max_v) + max_v = x[i]; + float w = qw[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min_v > 0.0f) + min_v = 0.0f; + if (max_v <= min_v) { + the_min = -min_v; + return 0.0f; + } + + float iscale = float(nmax) / (max_v - min_v); + float scale = 1.0f / iscale; + uint8_t L[N]; + float best_mad = 0.0f; + for (int i = 0; i < N; i++) { + int l = kq_nearest_int(iscale * (x[i] - min_v)); + L[i] = uint8_t(max(0, min(nmax, l))); + float diff = scale * float(L[i]) + min_v - x[i]; + best_mad += qw[i] * diff * diff; + } + float best_min = min_v; + + uint8_t Laux[N]; + for (int is = 0; is <= nstep; is++) { + float iscale_is = + (rmin + rdelta * float(is) + float(nmax)) / (max_v - min_v); + float sum_l = 0.0f, sum_l2 = 0.0f, sum_xl = 0.0f; + for (int i = 0; i < N; i++) { + int l = kq_nearest_int(iscale_is * (x[i] - min_v)); + l = max(0, min(nmax, l)); + Laux[i] = uint8_t(l); + float w = qw[i]; + sum_l += w * float(l); + sum_l2 += w * float(l) * float(l); + sum_xl += w * float(l) * x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0.0f) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l) / D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D; + if (this_min > 0.0f) { + this_min = 0.0f; + this_scale = sum_xl / sum_l2; + } + float mad = 0.0f; + for (int i = 0; i < N; i++) { + float diff = this_scale * float(Laux[i]) + this_min - x[i]; + mad += qw[i] * diff * diff; + } + if (mad < best_mad) { + for (int i = 0; i < N; i++) + L[i] = Laux[i]; + best_mad = mad; + scale = this_scale; + best_min = this_min; + } + } + } + the_min = -best_min; + return scale; +} + +inline void kq_compute_sigma2_av_x( + threadgroup const float* Xs, + threadgroup float* scratch, + uint lid, + uint simd_id, + uint lane_id, + float factor = 2.0f) { + float my_x = Xs[lid]; + float simd_x2 = simd_sum(my_x * my_x); + if (lane_id == 0) + scratch[simd_id] = simd_x2; + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_id == 0) { + float v = (lane_id < 8) ? scratch[lane_id] : 0.0f; + float total = simd_sum(v); + if (lane_id == 0) { + float sigma2 = factor * total / 256.0f; + scratch[8] = sigma2; + scratch[9] = sqrt(sigma2); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +inline void kq_pack_scales12( + device uint8_t* scales12, + const thread uint8_t* Ls, + const thread uint8_t* Lm) { + for (int i = 0; i < 12; i++) + scales12[i] = 0; + for (int j = 0; j < 8; j++) { + uint8_t ls = Ls[j]; + uint8_t lm = Lm[j]; + if (j < 4) { + scales12[j] = ls; + scales12[j + 4] = lm; + } else { + scales12[j + 4] = (ls & 0x0F) | ((lm & 0x0F) << 4); + scales12[j - 4] |= ((ls >> 4) << 6); + scales12[j] |= ((lm >> 4) << 6); + } + } +} + +inline void kq_unpack_scale_min_k4( + int j, + const device uint8_t* scales12, + thread uint8_t& sc_out, + thread uint8_t& mn_out) { + if (j < 4) { + sc_out = scales12[j] & 0x3F; + mn_out = scales12[j + 4] & 0x3F; + } else { + sc_out = (scales12[j + 4] & 0x0F) | ((scales12[j - 4] >> 6) << 4); + mn_out = (scales12[j + 4] >> 4) | ((scales12[j] >> 6) << 4); + } +} + +template +METAL_FUNC void kq_q45_k_quantize_impl( + const device T* w, + device uint8_t* out, + const constant uint& num_blocks, + const device float* imatrix, + bool has_imatrix, + uint K, + uint tg_id, + uint lid, + uint simd_id, + uint lane_id, + threadgroup float* Xs, + threadgroup float* QWs, + threadgroup uint8_t* L_tgm, + threadgroup float* scales_sb, + threadgroup float* mins_sb, + threadgroup float* sw_sb, + threadgroup uint8_t* Ls, + threadgroup uint8_t* Lm, + threadgroup float* scratch) { + static_assert(bits == 4 || bits == 5, "shared Q4_K/Q5_K only"); + constexpr int nmax = (bits == 4) ? 15 : 31; + constexpr int block_bytes = + (bits == 4) ? KQ_Q4_K_BLOCK_BYTES : KQ_Q5_K_BLOCK_BYTES; + constexpr int d_off = (bits == 4) ? KQ_Q4_K_D_OFFSET : KQ_Q5_K_D_OFFSET; + constexpr int dmin_off = + (bits == 4) ? KQ_Q4_K_DMIN_OFFSET : KQ_Q5_K_DMIN_OFFSET; + constexpr int sc_off = + (bits == 4) ? KQ_Q4_K_SCALES_OFFSET : KQ_Q5_K_SCALES_OFFSET; + constexpr int qs_off = (bits == 4) ? KQ_Q4_K_QS_OFFSET : KQ_Q5_K_QS_OFFSET; + constexpr int superblock = 256; + + if (tg_id >= num_blocks) + return; + + device uint8_t* block_addr = out + tg_id * block_bytes; + const device T* x_global = w + tg_id * superblock; + + // -- Phase 1: Load Xs[256] -- + Xs[lid] = float(x_global[lid]); + threadgroup_barrier(mem_flags::mem_threadgroup); + kq_compute_sigma2_av_x(Xs, scratch, lid, simd_id, lane_id); + float sigma2 = scratch[8]; + float av_x = scratch[9]; + + // -- Phase 2: weights -- + if (has_imatrix) { + uint k_off = (tg_id * superblock) % K; + QWs[lid] = imatrix[k_off + lid] * sqrt(sigma2 + Xs[lid] * Xs[lid]); + } else { + QWs[lid] = av_x + abs(Xs[lid]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 3: per-sub-block fit -- + if (lane_id == 0) { + int sb_off = simd_id * 32; + float sumw = 0.0f; + for (int l = 0; l < 32; l++) + sumw += QWs[sb_off + l]; + sw_sb[simd_id] = sumw; + float the_min; + float scale = + kq_make_qkx3_quants<32>(&Xs[sb_off], &QWs[sb_off], nmax, the_min); + scales_sb[simd_id] = scale; + mins_sb[simd_id] = the_min; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 4: super-scale fit + pack scales[12] -- + if (simd_id == 0 && lane_id == 0) { + uint8_t Ls_local[8]; + uint8_t Lm_local[8]; + float d_block = + kq_make_qp_quants<8>(&scales_sb[0], &sw_sb[0], 63, Ls_local); + float m_block = kq_make_qp_quants<8>(&mins_sb[0], &sw_sb[0], 63, Lm_local); + for (int i = 0; i < 8; i++) { + Ls[i] = Ls_local[i]; + Lm[i] = Lm_local[i]; + } + *(device half*)(block_addr + d_off) = half(d_block); + *(device half*)(block_addr + dmin_off) = half(m_block); + kq_pack_scales12(block_addr + sc_off, Ls_local, Lm_local); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 5: re-quantize -- + float d_wire = float(*(device half*)(block_addr + d_off)); + float dmin_wire = float(*(device half*)(block_addr + dmin_off)); + int my_sb = int(lid) / 32; + uint8_t sc, mn; + kq_unpack_scale_min_k4(my_sb, block_addr + sc_off, sc, mn); + float d_final = d_wire * float(sc); + float dm_final = dmin_wire * float(mn); + uint8_t my_L; + if (d_final == 0.0f) { + my_L = 0; + } else { + int l = kq_nearest_int((Xs[lid] + dm_final) / d_final); + l = max(0, min(nmax, l)); + my_L = uint8_t(l); + } + L_tgm[lid] = my_L; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 6: pack qs + qh -- + if (lid < 128) { + int stride = int(lid) / 32; + int l = int(lid) % 32; + uint8_t lo = L_tgm[64 * stride + l] & 0x0F; + uint8_t hi = L_tgm[64 * stride + l + 32] & 0x0F; + device uint8_t* qs = block_addr + qs_off; + qs[32 * stride + l] = lo | (hi << 4); + } + if (bits == 5 && lid >= 128 && lid < 160) { + int j = int(lid) - 128; + uint8_t b = 0; + for (int block_idx = 0; block_idx < 8; block_idx++) { + if (L_tgm[block_idx * 32 + j] > 15) { + b |= uint8_t(1 << block_idx); + } + } + device uint8_t* qh = block_addr + KQ_Q5_K_QH_OFFSET; + qh[j] = b; + } +} + +template +[[kernel]] void kq_q4_k_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + const constant uint& num_blocks [[buffer(2)]], + const device float* imatrix [[buffer(3)]], + const constant uint& has_imatrix [[buffer(4)]], + const constant uint& K [[buffer(5)]], + uint tg_id [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]], + uint lane_id [[thread_index_in_simdgroup]]) { + static_assert(group_size == 256, "Q4_K requires group_size=256"); + static_assert(bits == 4, "Q4_K requires bits=4"); + threadgroup float Xs[256]; + threadgroup float QWs[256]; + threadgroup uint8_t L_tgm[256]; + threadgroup float scales_sb[8]; + threadgroup float mins_sb[8]; + threadgroup float sw_sb[8]; + threadgroup uint8_t Ls[8]; + threadgroup uint8_t Lm[8]; + threadgroup float scratch[16]; + kq_q45_k_quantize_impl( + w, + out, + num_blocks, + imatrix, + has_imatrix != 0, + K, + tg_id, + lid, + simd_id, + lane_id, + Xs, + QWs, + L_tgm, + scales_sb, + mins_sb, + sw_sb, + Ls, + Lm, + scratch); +} + +inline float kq_make_qx_quants_16( + const threadgroup float* x, + const threadgroup float* qw, + int nmax, + thread uint8_t* L_out) { + const int n = 16; + float max_v = 0.0f; + float amax = 0.0f; + for (int i = 0; i < n; i++) { + float ax = abs(x[i]); + if (ax > amax) { + amax = ax; + max_v = x[i]; + } + } + if (amax < 1e-30f) { + for (int i = 0; i < n; i++) + L_out[i] = uint8_t(nmax); + return 0.0f; + } + float iscale = -float(nmax) / max_v; + float sumlx = 0.0f, suml2 = 0.0f; + for (int i = 0; i < n; i++) { + int l = kq_nearest_int(iscale * x[i]); + l = max(-nmax, min(nmax - 1, l)); + L_out[i] = uint8_t(l + nmax); + float w = (qw != nullptr) ? qw[i] : x[i] * x[i]; + sumlx += w * x[i] * float(l); + suml2 += w * float(l) * float(l); + } + float scale = (suml2 > 0.0f) ? (sumlx / suml2) : 0.0f; + float best = scale * sumlx; + for (int is = -9; is <= 9; is++) { + if (is == 0) + continue; + float iscale_is = -(float(nmax) + 0.1f * float(is)) / max_v; + float slx = 0.0f, sl2 = 0.0f; + for (int i = 0; i < n; i++) { + int l = kq_nearest_int(iscale_is * x[i]); + l = max(-nmax, min(nmax - 1, l)); + float w = (qw != nullptr) ? qw[i] : x[i] * x[i]; + slx += w * x[i] * float(l); + sl2 += w * float(l) * float(l); + } + if (sl2 > 0.0f && slx * slx > best * sl2) { + for (int i = 0; i < n; i++) { + int l = kq_nearest_int(iscale_is * x[i]); + l = max(-nmax, min(nmax - 1, l)); + L_out[i] = uint8_t(l + nmax); + } + scale = slx / sl2; + best = scale * slx; + } + } + return scale; +} + +template +METAL_FUNC void kq_q6_k_quantize_impl( + const device T* w, + device uint8_t* out, + const constant uint& num_blocks, + const device float* imatrix, + bool has_imatrix, + uint K, + uint tg_id, + uint lid, + uint simd_id, + uint lane_id, + threadgroup float* Xs, + threadgroup float* QWs, + threadgroup uint8_t* L_tgm, + threadgroup float* scales_sb, + threadgroup float* scratch) { + if (tg_id >= num_blocks) { + return; + } + + device uint8_t* block_addr = out + tg_id * KQ_Q6_K_BLOCK_BYTES; + const device T* x_global = w + tg_id * KQ_Q6_K_SUPERBLOCK; + + // -- Phase 1: Load Xs[256] -- + Xs[lid] = float(x_global[lid]); + if (has_imatrix) { + uint k_off = (tg_id * KQ_Q6_K_SUPERBLOCK) % K; + QWs[lid] = imatrix[k_off + lid]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 2: per-sub-block fit -- + if ((lid % 16) == 0) { + int sb = int(lid) / 16; + int sb_off = sb * 16; + uint8_t L_local[16]; + const threadgroup float* qw_ptr = + has_imatrix ? &QWs[sb_off] : (const threadgroup float*)nullptr; + float scale = kq_make_qx_quants_16(&Xs[sb_off], qw_ptr, 32, L_local); + scales_sb[sb] = scale; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 3: super-scale + write d, scales[16] -- + if (simd_id == 0 && lane_id == 0) { + float max_scale = 0.0f; + float max_abs_scale = 0.0f; + for (int sb = 0; sb < 16; sb++) { + float s = scales_sb[sb]; + float a = abs(s); + if (a > max_abs_scale) { + max_abs_scale = a; + max_scale = s; + } + } + if (max_abs_scale < 1e-30f) { + scratch[0] = -1.0f; + } else { + float iscale = -128.0f / max_scale; + *(device half*)(block_addr + KQ_Q6_K_D_OFFSET) = half(1.0f / iscale); + device int8_t* scales_out = + (device int8_t*)(block_addr + KQ_Q6_K_SCALES_OFFSET); + for (int sb = 0; sb < 16; sb++) { + int s = kq_nearest_int(iscale * scales_sb[sb]); + s = min(127, s); + scales_out[sb] = int8_t(s); + } + scratch[0] = 1.0f; + scratch[1] = float(*(device half*)(block_addr + KQ_Q6_K_D_OFFSET)); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + bool all_zero = scratch[0] < 0.0f; + if (all_zero) { + if (lid < uint(KQ_Q6_K_BLOCK_BYTES)) { + block_addr[lid] = 0; + } + return; + } + + float d_wire = scratch[1]; + + // -- Phase 4: re-quantize -- + int my_sb = int(lid) / 16; + const device int8_t* scales_out = + (const device int8_t*)(block_addr + KQ_Q6_K_SCALES_OFFSET); + int8_t s_int8 = scales_out[my_sb]; + float d_eff = d_wire * float(s_int8); + + uint8_t my_L; + if (d_eff == 0.0f) { + my_L = 32; + } else { + float xv = Xs[lid]; + int l = kq_nearest_int(xv / d_eff); + l = max(-32, min(31, l)); + my_L = uint8_t(l + 32); + } + L_tgm[lid] = my_L; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 5: pack ql[128] + qh[64] -- + if (lid < 64) { + int stride = int(lid) / 32; + int l = int(lid) % 32; + int base = stride * 128; + uint8_t La = L_tgm[base + l]; + uint8_t Lb = L_tgm[base + l + 32]; + uint8_t Lc = L_tgm[base + l + 64]; + uint8_t Ld = L_tgm[base + l + 96]; + device uint8_t* ql_out = block_addr + KQ_Q6_K_QL_OFFSET + stride * 64; + device uint8_t* qh_out = block_addr + KQ_Q6_K_QH_OFFSET + stride * 32; + ql_out[l] = (La & 0x0F) | ((Lc & 0x0F) << 4); + ql_out[l + 32] = (Lb & 0x0F) | ((Ld & 0x0F) << 4); + qh_out[l] = + (La >> 4) | ((Lb >> 4) << 2) | ((Lc >> 4) << 4) | ((Ld >> 4) << 6); + } +} + +template +[[kernel]] void kq_q6_k_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + const constant uint& num_blocks [[buffer(2)]], + const device float* imatrix [[buffer(3)]], + const constant uint& has_imatrix [[buffer(4)]], + const constant uint& K [[buffer(5)]], + uint tg_id [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]], + uint lane_id [[thread_index_in_simdgroup]]) { + static_assert(group_size == 256, "Q6_K requires group_size=256"); + static_assert(bits == 6, "Q6_K requires bits=6"); + threadgroup float Xs[256]; + threadgroup float QWs[256]; + threadgroup uint8_t L_tgm[256]; + threadgroup float scales_sb[16]; + threadgroup float scratch[16]; + kq_q6_k_quantize_impl( + w, + out, + num_blocks, + imatrix, + has_imatrix != 0, + K, + tg_id, + lid, + simd_id, + lane_id, + Xs, + QWs, + L_tgm, + scales_sb, + scratch); +} + +template +[[kernel]] void kq_q5_k_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + const constant uint& num_blocks [[buffer(2)]], + const device float* imatrix [[buffer(3)]], + const constant uint& has_imatrix [[buffer(4)]], + const constant uint& K [[buffer(5)]], + uint tg_id [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]], + uint lane_id [[thread_index_in_simdgroup]]) { + static_assert(group_size == 256, "Q5_K requires group_size=256"); + static_assert(bits == 5, "Q5_K requires bits=5"); + threadgroup float Xs[256]; + threadgroup float QWs[256]; + threadgroup uint8_t L_tgm[256]; + threadgroup float scales_sb[8]; + threadgroup float mins_sb[8]; + threadgroup float sw_sb[8]; + threadgroup uint8_t Ls[8]; + threadgroup uint8_t Lm[8]; + threadgroup float scratch[16]; + kq_q45_k_quantize_impl( + w, + out, + num_blocks, + imatrix, + has_imatrix != 0, + K, + tg_id, + lid, + simd_id, + lane_id, + Xs, + QWs, + L_tgm, + scales_sb, + mins_sb, + sw_sb, + Ls, + Lm, + scratch); +} + +template +METAL_FUNC void kq_q3_k_quantize_impl( + const device T* w, + device uint8_t* out, + const constant uint& num_blocks, + const device float* imatrix, + bool has_imatrix, + uint K, + uint tg_id, + uint lid, + uint simd_id, + uint lane_id, + threadgroup float* Xs, + threadgroup float* QWs, + threadgroup uint8_t* L_tgm, + threadgroup float* scales_sb, + threadgroup float* scratch) { + if (tg_id >= num_blocks) { + return; + } + + device uint8_t* block_addr = out + tg_id * KQ_Q3_K_BLOCK_BYTES; + const device T* x_global = w + tg_id * KQ_Q3_K_SUPERBLOCK; + + // -- Phase 1: Load Xs[256] -- + Xs[lid] = float(x_global[lid]); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 2: imatrix weights -- + if (has_imatrix) { + kq_compute_sigma2_av_x(Xs, scratch, lid, simd_id, lane_id); + float sigma2 = scratch[8]; + uint k_off = (tg_id * KQ_Q3_K_SUPERBLOCK) % K; + QWs[lid] = imatrix[k_off + lid] * sqrt(sigma2 + Xs[lid] * Xs[lid]); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // -- Phase 3: per-sub-block fit -- + if ((lid % 16) == 0) { + int sb = int(lid) / 16; + int sb_off = sb * 16; + uint8_t L_local[16]; + const threadgroup float* qw_ptr = + has_imatrix ? &QWs[sb_off] : (const threadgroup float*)nullptr; + float scale = kq_make_qx_quants_16(&Xs[sb_off], qw_ptr, 4, L_local); + scales_sb[sb] = scale; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 4: super-scale + pack scales[12] -- + if (simd_id == 0 && lane_id == 0) { + float amax_sc = 0.0f; + float max_sc = 0.0f; + for (int sb = 0; sb < 16; sb++) { + float v = scales_sb[sb]; + float av = abs(v); + if (av > amax_sc) { + amax_sc = av; + max_sc = v; + } + } + device uint8_t* scales12 = block_addr + KQ_Q3_K_SCALES_OFFSET; + for (int i = 0; i < 12; i++) + scales12[i] = 0; + + float d_block = 0.0f; + if (max_sc != 0.0f) { + float iscale = -32.0f / max_sc; + for (int sb = 0; sb < 16; sb++) { + int l = kq_nearest_int(iscale * scales_sb[sb]); + l = max(-32, min(31, l)) + 32; // biased [0, 63] + if (sb < 8) { + scales12[sb] = uint8_t(l & 0x0F); + } else { + scales12[sb - 8] |= uint8_t((l & 0x0F) << 4); + } + uint8_t lh = uint8_t((l >> 4) & 0x03); + scales12[8 + (sb % 4)] |= uint8_t(lh << (2 * (sb / 4))); + } + d_block = 1.0f / iscale; + } + *(device half*)(block_addr + KQ_Q3_K_D_OFFSET) = half(d_block); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 5: re-quantize -- + int my_sb = int(lid) / 16; + uint8_t sc_unsigned = + kq_q3_k_unpack_scale(my_sb, block_addr + KQ_Q3_K_SCALES_OFFSET); + int sc_signed = int(sc_unsigned) - 32; + float d_fp16 = float(*(device half*)(block_addr + KQ_Q3_K_D_OFFSET)); + float d_eff = d_fp16 * float(sc_signed); + + uint8_t my_u3; + if (d_eff == 0.0f) { + my_u3 = 4; + } else { + int l = kq_nearest_int(Xs[lid] / d_eff); + l = max(-4, min(3, l)); + my_u3 = uint8_t(l + 4); + } + L_tgm[lid] = my_u3; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 6: pack qs[64] + hmask[32] -- + if (lid < 64) { + int outer_half = int(lid) / 32; + int within_shift = int(lid) % 32; + uint8_t byte = 0; + for (int shift_idx = 0; shift_idx < 4; shift_idx++) { + int w_idx = outer_half * 128 + shift_idx * 32 + within_shift; + uint8_t q2 = L_tgm[w_idx] & 0x03; + byte |= uint8_t(q2 << (shift_idx * 2)); + } + device uint8_t* qs = block_addr + KQ_Q3_K_QS_OFFSET; + qs[lid] = byte; + } else if (lid < 96) { + int within_shift = int(lid) - 64; + uint8_t byte = 0; + for (int b = 0; b < 8; b++) { + int outer_half = b / 4; + int shift_idx = b % 4; + int w_idx = outer_half * 128 + shift_idx * 32 + within_shift; + uint8_t hbit = (L_tgm[w_idx] >> 2) & 0x01; + byte |= uint8_t(hbit << b); + } + device uint8_t* hmask = block_addr + KQ_Q3_K_HMASK_OFFSET; + hmask[within_shift] = byte; + } +} + +template +[[kernel]] void kq_q3_k_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + const constant uint& num_blocks [[buffer(2)]], + const device float* imatrix [[buffer(3)]], + const constant uint& has_imatrix [[buffer(4)]], + const constant uint& K [[buffer(5)]], + uint tg_id [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]], + uint lane_id [[thread_index_in_simdgroup]]) { + static_assert(group_size == 256, "Q3_K requires group_size=256"); + static_assert(bits == 3, "Q3_K requires bits=3"); + threadgroup float Xs[256]; + threadgroup float QWs[256]; + threadgroup uint8_t L_tgm[256]; + threadgroup float scales_sb[16]; + threadgroup float scratch[16]; + kq_q3_k_quantize_impl( + w, + out, + num_blocks, + imatrix, + has_imatrix != 0, + K, + tg_id, + lid, + simd_id, + lane_id, + Xs, + QWs, + L_tgm, + scales_sb, + scratch); +} + +template +METAL_FUNC void kq_q2_k_quantize_impl( + const device T* w, + device uint8_t* out, + const constant uint& num_blocks, + const device float* imatrix, + bool has_imatrix, + uint K, + uint tg_id, + uint lid, + uint simd_id, + uint lane_id, + threadgroup float* Xs, + threadgroup float* QWs, + threadgroup uint8_t* L_tgm, + threadgroup float* scales_sb, + threadgroup float* mins_sb, + threadgroup float* sw_sb, + threadgroup uint8_t* Ls, + threadgroup uint8_t* Lm, + threadgroup float* scratch) { + if (tg_id >= num_blocks) { + return; + } + + device uint8_t* block_addr = out + tg_id * KQ_Q2_K_BLOCK_BYTES; + const device T* x_global = w + tg_id * KQ_Q2_K_SUPERBLOCK; + + // -- Phase 1: Load Xs[256] -- + Xs[lid] = float(x_global[lid]); + threadgroup_barrier(mem_flags::mem_threadgroup); + kq_compute_sigma2_av_x(Xs, scratch, lid, simd_id, lane_id, 1.0f); + float sigma2 = scratch[8]; + float av_x = scratch[9]; + + // -- Phase 2: weights -- + if (has_imatrix) { + uint k_off = (tg_id * KQ_Q2_K_SUPERBLOCK) % K; + QWs[lid] = imatrix[k_off + lid] * sqrt(sigma2 + Xs[lid] * Xs[lid]); + } else { + QWs[lid] = av_x + abs(Xs[lid]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 3: per-sub-block fit -- + if ((lid % 16) == 0) { + int sb = int(lid) / 16; + int sb_off = sb * 16; + float sumw = 0.0f; + for (int l = 0; l < 16; l++) + sumw += QWs[sb_off + l]; + sw_sb[sb] = sumw; + float the_min; + float scale = + kq_make_qkx3_quants<16>(&Xs[sb_off], &QWs[sb_off], 3, the_min); + scales_sb[sb] = scale; + mins_sb[sb] = the_min; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 4: super-scale + pack scales[16] -- + if (simd_id == 0 && lane_id == 0) { + uint8_t Ls_local[16]; + uint8_t Lm_local[16]; + float d_block = + kq_make_qp_quants<16>(&scales_sb[0], &sw_sb[0], 15, Ls_local); + float m_block = kq_make_qp_quants<16>(&mins_sb[0], &sw_sb[0], 15, Lm_local); + for (int i = 0; i < 16; i++) { + Ls[i] = Ls_local[i]; + Lm[i] = Lm_local[i]; + } + *(device half*)(block_addr + KQ_Q2_K_D_OFFSET) = half(d_block); + *(device half*)(block_addr + KQ_Q2_K_DMIN_OFFSET) = half(m_block); + + device uint8_t* scales16 = block_addr + KQ_Q2_K_SCALES_OFFSET; + for (int j = 0; j < 16; j++) { + scales16[j] = uint8_t((Ls_local[j] & 0x0F) | ((Lm_local[j] & 0x0F) << 4)); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 5: re-quantize -- + float d_wire = float(*(device half*)(block_addr + KQ_Q2_K_D_OFFSET)); + float dmin_wire = float(*(device half*)(block_addr + KQ_Q2_K_DMIN_OFFSET)); + device uint8_t* scales16 = block_addr + KQ_Q2_K_SCALES_OFFSET; + int my_sb = int(lid) / 16; + uint8_t sc_byte = scales16[my_sb]; + uint8_t sc = sc_byte & 0x0F; + uint8_t mn = sc_byte >> 4; + float d_eff = d_wire * float(sc); + float m_eff = dmin_wire * float(mn); + + uint8_t my_L; + if (d_eff == 0.0f) { + my_L = 0; + } else { + int l = kq_nearest_int((Xs[lid] + m_eff) / d_eff); + l = max(0, min(3, l)); + my_L = uint8_t(l); + } + L_tgm[lid] = my_L; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // -- Phase 6: pack qs[64] -- + if (lid < 64) { + int outer_half = int(lid) / 32; + int within_shift = int(lid) % 32; + uint8_t byte = 0; + for (int shift_idx = 0; shift_idx < 4; shift_idx++) { + int w_idx = outer_half * 128 + shift_idx * 32 + within_shift; + uint8_t q2 = L_tgm[w_idx] & 0x03; + byte |= uint8_t(q2 << (shift_idx * 2)); + } + device uint8_t* qs = block_addr + KQ_Q2_K_QS_OFFSET; + qs[lid] = byte; + } +} + +template +[[kernel]] void kq_q2_k_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + const constant uint& num_blocks [[buffer(2)]], + const device float* imatrix [[buffer(3)]], + const constant uint& has_imatrix [[buffer(4)]], + const constant uint& K [[buffer(5)]], + uint tg_id [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_id [[simdgroup_index_in_threadgroup]], + uint lane_id [[thread_index_in_simdgroup]]) { + static_assert(group_size == 256, "Q2_K requires group_size=256"); + static_assert(bits == 2, "Q2_K requires bits=2"); + threadgroup float Xs[256]; + threadgroup float QWs[256]; + threadgroup uint8_t L_tgm[256]; + threadgroup float scales_sb[16]; + threadgroup float mins_sb[16]; + threadgroup float sw_sb[16]; + threadgroup uint8_t Ls[16]; + threadgroup uint8_t Lm[16]; + threadgroup float scratch[16]; + kq_q2_k_quantize_impl( + w, + out, + num_blocks, + imatrix, + has_imatrix != 0, + K, + tg_id, + lid, + simd_id, + lane_id, + Xs, + QWs, + L_tgm, + scales_sb, + mins_sb, + sw_sb, + Ls, + Lm, + scratch); +} + +template +METAL_FUNC void kq_q4_0_quantize_impl( + const device T* w, + device uint8_t* out, + const constant uint& num_blocks, + uint gid) { + if (gid >= num_blocks) + return; + const device T* x = w + gid * KQ_Q4_0_GROUP; + device uint8_t* block_addr = out + gid * KQ_Q4_0_BLOCK_BYTES; + + float amax = 0.0f; + for (int j = 0; j < KQ_Q4_0_GROUP; j++) { + amax = max(amax, fabs(float(x[j]))); + } + const float d = amax / 7.0f; + const float id = (d != 0.0f) ? (1.0f / d) : 0.0f; + + *(device half*)(block_addr + KQ_Q4_0_D_OFFSET) = half(d); + device uint8_t* qs = block_addr + KQ_Q4_0_QS_OFFSET; + for (int j = 0; j < 16; j++) { + float x0 = float(x[j]) * id; + float x1 = float(x[j + 16]) * id; + uint8_t q0 = uint8_t(clamp(round(x0) + 8.0f, 0.0f, 15.0f)); + uint8_t q1 = uint8_t(clamp(round(x1) + 8.0f, 0.0f, 15.0f)); + qs[j] = q0 | (q1 << 4); + } +} + +template +[[kernel]] void kq_q4_0_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + const constant uint& num_blocks [[buffer(2)]], + const device float* imatrix [[buffer(3)]], + const constant uint& has_imatrix [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + static_assert(group_size == KQ_Q4_0_GROUP, "Q4_0 requires group_size=32"); + static_assert(bits == 4, "Q4_0 requires bits=4"); + (void)imatrix; + (void)has_imatrix; + kq_q4_0_quantize_impl(w, out, num_blocks, gid); +} + +template +METAL_FUNC void kq_q4_1_quantize_impl( + const device T* w, + device uint8_t* out, + const constant uint& num_blocks, + uint gid) { + if (gid >= num_blocks) + return; + const device T* x = w + gid * KQ_Q4_1_GROUP; + device uint8_t* block_addr = out + gid * KQ_Q4_1_BLOCK_BYTES; + + float vmin = float(x[0]); + float vmax = float(x[0]); + for (int j = 1; j < KQ_Q4_1_GROUP; j++) { + float v = float(x[j]); + vmin = min(vmin, v); + vmax = max(vmax, v); + } + const float d = (vmax - vmin) / 15.0f; + const float id = (d != 0.0f) ? (1.0f / d) : 0.0f; + + *(device half*)(block_addr + KQ_Q4_1_D_OFFSET) = half(d); + *(device half*)(block_addr + KQ_Q4_1_M_OFFSET) = half(vmin); + device uint8_t* qs = block_addr + KQ_Q4_1_QS_OFFSET; + for (int j = 0; j < 16; j++) { + float x0 = (float(x[j]) - vmin) * id; + float x1 = (float(x[j + 16]) - vmin) * id; + uint8_t q0 = uint8_t(clamp(round(x0), 0.0f, 15.0f)); + uint8_t q1 = uint8_t(clamp(round(x1), 0.0f, 15.0f)); + qs[j] = q0 | (q1 << 4); + } +} + +template +[[kernel]] void kq_q4_1_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + const constant uint& num_blocks [[buffer(2)]], + const device float* imatrix [[buffer(3)]], + const constant uint& has_imatrix [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + static_assert(group_size == KQ_Q4_1_GROUP, "Q4_1 requires group_size=32"); + static_assert(bits == 4, "Q4_1 requires bits=4"); + (void)imatrix; + (void)has_imatrix; + kq_q4_1_quantize_impl(w, out, num_blocks, gid); +} + +template +METAL_FUNC void kq_q5_0_quantize_impl( + const device T* w, + device uint8_t* out, + const constant uint& num_blocks, + uint gid) { + if (gid >= num_blocks) + return; + const device T* x = w + gid * KQ_Q5_0_GROUP; + device uint8_t* block_addr = out + gid * KQ_Q5_0_BLOCK_BYTES; + + float amax = 0.0f; + for (int j = 0; j < KQ_Q5_0_GROUP; j++) { + amax = max(amax, fabs(float(x[j]))); + } + const float d = amax / 15.0f; + const float id = (d != 0.0f) ? (1.0f / d) : 0.0f; + + *(device half*)(block_addr + KQ_Q5_0_D_OFFSET) = half(d); + device uint8_t* qh_p = block_addr + KQ_Q5_0_QH_OFFSET; + device uint8_t* qs = block_addr + KQ_Q5_0_QS_OFFSET; + + uint32_t qh = 0; + for (int j = 0; j < 16; j++) { + float v0 = float(x[j]) * id; + float v1 = float(x[j + 16]) * id; + uint8_t q0 = uint8_t(clamp(round(v0) + 16.0f, 0.0f, 31.0f)); + uint8_t q1 = uint8_t(clamp(round(v1) + 16.0f, 0.0f, 31.0f)); + qs[j] = (q0 & 0x0Fu) | ((q1 & 0x0Fu) << 4); + qh |= (uint32_t(q0 >> 4) << j); + qh |= (uint32_t(q1 >> 4) << (j + 16)); + } + qh_p[0] = uint8_t(qh & 0xFF); + qh_p[1] = uint8_t((qh >> 8) & 0xFF); + qh_p[2] = uint8_t((qh >> 16) & 0xFF); + qh_p[3] = uint8_t((qh >> 24) & 0xFF); +} + +template +[[kernel]] void kq_q5_0_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + const constant uint& num_blocks [[buffer(2)]], + const device float* imatrix [[buffer(3)]], + const constant uint& has_imatrix [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + static_assert(group_size == KQ_Q5_0_GROUP, "Q5_0 requires group_size=32"); + static_assert(bits == 5, "Q5_0 requires bits=5"); + (void)imatrix; + (void)has_imatrix; + kq_q5_0_quantize_impl(w, out, num_blocks, gid); +} + +template +METAL_FUNC void kq_q5_1_quantize_impl( + const device T* w, + device uint8_t* out, + const constant uint& num_blocks, + uint gid) { + if (gid >= num_blocks) + return; + const device T* x = w + gid * KQ_Q5_1_GROUP; + device uint8_t* block_addr = out + gid * KQ_Q5_1_BLOCK_BYTES; + + float vmin = float(x[0]); + float vmax = float(x[0]); + for (int j = 1; j < KQ_Q5_1_GROUP; j++) { + float v = float(x[j]); + vmin = min(vmin, v); + vmax = max(vmax, v); + } + const float d = (vmax - vmin) / 31.0f; + const float id = (d != 0.0f) ? (1.0f / d) : 0.0f; + + *(device half*)(block_addr + KQ_Q5_1_D_OFFSET) = half(d); + *(device half*)(block_addr + KQ_Q5_1_M_OFFSET) = half(vmin); + device uint8_t* qh_p = block_addr + KQ_Q5_1_QH_OFFSET; + device uint8_t* qs = block_addr + KQ_Q5_1_QS_OFFSET; + + uint32_t qh = 0; + for (int j = 0; j < 16; j++) { + float v0 = (float(x[j]) - vmin) * id; + float v1 = (float(x[j + 16]) - vmin) * id; + uint8_t q0 = uint8_t(clamp(round(v0), 0.0f, 31.0f)); + uint8_t q1 = uint8_t(clamp(round(v1), 0.0f, 31.0f)); + qs[j] = (q0 & 0x0Fu) | ((q1 & 0x0Fu) << 4); + qh |= (uint32_t(q0 >> 4) << j); + qh |= (uint32_t(q1 >> 4) << (j + 16)); + } + qh_p[0] = uint8_t(qh & 0xFF); + qh_p[1] = uint8_t((qh >> 8) & 0xFF); + qh_p[2] = uint8_t((qh >> 16) & 0xFF); + qh_p[3] = uint8_t((qh >> 24) & 0xFF); +} + +template +[[kernel]] void kq_q5_1_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + const constant uint& num_blocks [[buffer(2)]], + const device float* imatrix [[buffer(3)]], + const constant uint& has_imatrix [[buffer(4)]], + uint gid [[thread_position_in_grid]]) { + static_assert(group_size == KQ_Q5_1_GROUP, "Q5_1 requires group_size=32"); + static_assert(bits == 5, "Q5_1 requires bits=5"); + (void)imatrix; + (void)has_imatrix; + kq_q5_1_quantize_impl(w, out, num_blocks, gid); +} diff --git a/mlx/backend/metal/kernels/kq_quantized_encode.metal b/mlx/backend/metal/kernels/kq_quantized_encode.metal new file mode 100644 index 0000000000..a42adaf92f --- /dev/null +++ b/mlx/backend/metal/kernels/kq_quantized_encode.metal @@ -0,0 +1,86 @@ +// Copyright © 2026 Apple Inc. + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/quantized_utils.h" +#include "mlx/backend/metal/kernels/kq_quantized_encode.h" + +#define instantiate_kquant_quantize(type, gs, bits, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_quantize_" #type "_gs_" #gs "_b_" #bits, \ + kq_ ## codec ## _quantize, \ + type, \ + gs, \ + bits) + +#define instantiate_kquant_q8_0_quantize_for_type(type) \ + instantiate_kquant_quantize(type, 32, 8, q8_0) + +instantiate_kquant_q8_0_quantize_for_type(float) +instantiate_kquant_q8_0_quantize_for_type(float16_t) +instantiate_kquant_q8_0_quantize_for_type(bfloat16_t) + +#define instantiate_kquant_q4_k_quantize_for_type(type) \ + instantiate_kquant_quantize(type, 256, 4, q4_k) + +instantiate_kquant_q4_k_quantize_for_type(float) +instantiate_kquant_q4_k_quantize_for_type(float16_t) +instantiate_kquant_q4_k_quantize_for_type(bfloat16_t) + +#define instantiate_kquant_q6_k_quantize_for_type(type) \ + instantiate_kquant_quantize(type, 256, 6, q6_k) + +instantiate_kquant_q6_k_quantize_for_type(float) +instantiate_kquant_q6_k_quantize_for_type(float16_t) +instantiate_kquant_q6_k_quantize_for_type(bfloat16_t) + +#define instantiate_kquant_q5_k_quantize_for_type(type) \ + instantiate_kquant_quantize(type, 256, 5, q5_k) + +instantiate_kquant_q5_k_quantize_for_type(float) +instantiate_kquant_q5_k_quantize_for_type(float16_t) +instantiate_kquant_q5_k_quantize_for_type(bfloat16_t) + +#define instantiate_kquant_q3_k_quantize_for_type(type) \ + instantiate_kquant_quantize(type, 256, 3, q3_k) + +instantiate_kquant_q3_k_quantize_for_type(float) +instantiate_kquant_q3_k_quantize_for_type(float16_t) +instantiate_kquant_q3_k_quantize_for_type(bfloat16_t) + +#define instantiate_kquant_q2_k_quantize_for_type(type) \ + instantiate_kquant_quantize(type, 256, 2, q2_k) + +instantiate_kquant_q2_k_quantize_for_type(float) +instantiate_kquant_q2_k_quantize_for_type(float16_t) +instantiate_kquant_q2_k_quantize_for_type(bfloat16_t) + +#define instantiate_kquant_q4_0_quantize_for_type(type) \ + instantiate_kquant_quantize(type, 32, 4, q4_0) + +instantiate_kquant_q4_0_quantize_for_type(float) +instantiate_kquant_q4_0_quantize_for_type(float16_t) +instantiate_kquant_q4_0_quantize_for_type(bfloat16_t) + +#define instantiate_kquant_q4_1_quantize_for_type(type) \ + instantiate_kquant_quantize(type, 32, 4, q4_1) + +instantiate_kquant_q4_1_quantize_for_type(float) +instantiate_kquant_q4_1_quantize_for_type(float16_t) +instantiate_kquant_q4_1_quantize_for_type(bfloat16_t) + +#define instantiate_kquant_q5_0_quantize_for_type(type) \ + instantiate_kquant_quantize(type, 32, 5, q5_0) + +instantiate_kquant_q5_0_quantize_for_type(float) +instantiate_kquant_q5_0_quantize_for_type(float16_t) +instantiate_kquant_q5_0_quantize_for_type(bfloat16_t) + +#define instantiate_kquant_q5_1_quantize_for_type(type) \ + instantiate_kquant_quantize(type, 32, 5, q5_1) + +instantiate_kquant_q5_1_quantize_for_type(float) +instantiate_kquant_q5_1_quantize_for_type(float16_t) +instantiate_kquant_q5_1_quantize_for_type(bfloat16_t) + // clang-format on diff --git a/mlx/backend/metal/kernels/kq_quantized_legacy.h b/mlx/backend/metal/kernels/kq_quantized_legacy.h new file mode 100644 index 0000000000..e835a92e0b --- /dev/null +++ b/mlx/backend/metal/kernels/kq_quantized_legacy.h @@ -0,0 +1,1627 @@ +// Copyright © 2026 Apple Inc. + +// Q4_0: 18 bytes/32 weights. [fp16 d][uint8 qs[16]]. w[i] = d * (q4 - 8). + +MLX_MTL_CONST int KQ_Q4_0_GROUP = 32; +MLX_MTL_CONST int KQ_Q4_0_BLOCK_BYTES = 18; +MLX_MTL_CONST int KQ_Q4_0_D_OFFSET = 0; +MLX_MTL_CONST int KQ_Q4_0_QS_OFFSET = 2; + +inline float kq_q4_0_d(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q4_0_D_OFFSET)); +} +inline const device uint8_t* kq_q4_0_qs_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q4_0_QS_OFFSET; +} + +template +METAL_FUNC void kq_q4_0_dequantize_impl( + const device uint8_t* w, + device T* out, + const constant uint& num_weights, + uint gid) { + if (gid >= num_weights) { + return; + } + const int block_id = gid / KQ_Q4_0_GROUP; + const int within = gid % KQ_Q4_0_GROUP; + const device uint8_t* block_addr = w + block_id * KQ_Q4_0_BLOCK_BYTES; + const float d = kq_q4_0_d(block_addr); + const device uint8_t* qs = kq_q4_0_qs_ptr(block_addr); + const int q4 = + (within < 16) ? (int(qs[within]) & 0x0F) : (int(qs[within - 16]) >> 4); + out[gid] = T(d * float(q4 - 8)); +} + +template +METAL_FUNC void kq_q4_0_qmv_fast_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q4_0_GROUP, "Q4_0 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_0 kernel requires bits=4"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int block_stride = 16; + + typedef float U; + thread U yl[16]; + thread U result[results_per_simdgroup] = {0}; + + const int ix = simd_lid / 2; + const int il = (simd_lid % 2) * 8; + + const int row_bytes = in_vec_size * KQ_Q4_0_BLOCK_BYTES / KQ_Q4_0_GROUP; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + const int nb = in_vec_size / KQ_Q4_0_GROUP; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += block_stride) { + const int x_base = ib * KQ_Q4_0_GROUP + il; + U sumy = U(0); +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const U a0 = U(x[x_base + i + 0]); + const U a1 = U(x[x_base + i + 1]); + const U b0 = U(x[x_base + i + 16]); + const U b1 = U(x[x_base + i + 17]); + sumy += a0 + a1 + b0 + b1; + yl[i + 0] = a0; + yl[i + 1] = a1 * (U(1) / U(256)); + yl[i + 8] = b0 * (U(1) / U(16)); + yl[i + 9] = b1 * (U(1) / U(4096)); + } + + for (int row = 0; row < results_per_simdgroup; row++) { + const int row_idx = out_row + row; + const device uint8_t* block_addr = + w + row_idx * row_bytes + ib * KQ_Q4_0_BLOCK_BYTES; + const U d = U(kq_q4_0_d(block_addr)); + const device uint16_t* qs = + reinterpret_cast(kq_q4_0_qs_ptr(block_addr)) + + il / 2; + + U acc[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const uint16_t qi = qs[i / 2]; + acc[0] += yl[i + 0] * U(qi & 0x000F); + acc[1] += yl[i + 1] * U(qi & 0x0F00); + acc[2] += yl[i + 8] * U(qi & 0x00F0); + acc[3] += yl[i + 9] * U(qi & 0xF000); + } + result[row] += d * (acc[0] + acc[1] + acc[2] + acc[3] + sumy * U(-8)); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void kq_q4_0_qmv_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q4_0_GROUP, "Q4_0 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_0 kernel requires bits=4"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int block_stride = 16; + + typedef float U; + thread U yl[16]; + thread U result[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q4_0_BLOCK_BYTES / KQ_Q4_0_GROUP; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + if (out_row >= out_vec_size) { + return; + } + const int max_row = min(out_vec_size, out_row + results_per_simdgroup); + const int active_rows = max_row - out_row; + + const int ix = simd_lid / 2; + const int il = (simd_lid % 2) * 8; + + const int nb = in_vec_size / KQ_Q4_0_GROUP; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += block_stride) { + const int x_base = ib * KQ_Q4_0_GROUP + il; + U sumy = U(0); +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const U a0 = U(x[x_base + i + 0]); + const U a1 = U(x[x_base + i + 1]); + const U b0 = U(x[x_base + i + 16]); + const U b1 = U(x[x_base + i + 17]); + sumy += a0 + a1 + b0 + b1; + yl[i + 0] = a0; + yl[i + 1] = a1 * (U(1) / U(256)); + yl[i + 8] = b0 * (U(1) / U(16)); + yl[i + 9] = b1 * (U(1) / U(4096)); + } + + for (int row = 0; row < active_rows; row++) { + const int row_idx = out_row + row; + const device uint8_t* block_addr = + w + row_idx * row_bytes + ib * KQ_Q4_0_BLOCK_BYTES; + const U d = U(kq_q4_0_d(block_addr)); + const device uint16_t* qs = + reinterpret_cast(kq_q4_0_qs_ptr(block_addr)) + + il / 2; + + U acc[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const uint16_t qi = qs[i / 2]; + acc[0] += yl[i + 0] * U(qi & 0x000F); + acc[1] += yl[i + 1] * U(qi & 0x0F00); + acc[2] += yl[i + 8] * U(qi & 0x00F0); + acc[3] += yl[i + 9] * U(qi & 0xF000); + } + result[row] += d * (acc[0] + acc[1] + acc[2] + acc[3] + sumy * U(-8)); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0 && row < active_rows) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqQ4_0BlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q4_0_GROUP; + MLX_MTL_CONST int bytes_per_block = KQ_Q4_0_BLOCK_BYTES; + + static_assert( + BCOLS == weights_per_block, + "Q4_0 loader requires BCOLS == 32 (one block per K-tile)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + MLX_MTL_CONST short bytes_per_thread = n_reads / 2; + MLX_MTL_CONST short half_block = weights_per_block / 2; + static_assert(n_reads >= 2 && n_reads % 2 == 0, "Q4_0 needs even n_reads."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + + const short thread_idx; + const short bi; + const short bj_byte; + + threadgroup T* dst; + const device uint8_t* src; + + KqQ4_0BlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int /* col_in_block */ = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? bytes_per_block + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj_byte((thread_idx % TCOLS) * bytes_per_thread), + dst(dst_ + bi * dst_ld + bj_byte), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)) {} + + void load_unsafe() const { + const float d = float(*(const device half*)(src + KQ_Q4_0_D_OFFSET)); + const device uint8_t* qs = src + KQ_Q4_0_QS_OFFSET + bj_byte; +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t b = qs[i]; + const int q4_lo = int(b & 0x0F); + const int q4_hi = int(b >> 4); + dst[i] = T(d * float(q4_lo - 8)); + dst[half_block + i] = T(d * float(q4_hi - 8)); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + dst[i] = T(0); + dst[half_block + i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + src += tile_stride; + } +}; + +template +[[kernel]] void kq_q4_0_qmm_t( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q4_0_GROUP, "Q4_0 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_0 kernel requires bits=4"); + constexpr int BM = 64, BK = 32, BN = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ4_0BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_t_impl( + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_0_qmm_t_splitk( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& k_partition_size, + const constant int& split_k_partition_stride, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q4_0_GROUP, "Q4_0 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_0 kernel requires bits=4"); + constexpr int BM = 32, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ4_0BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + + const int k_start = tid.z * k_partition_size; + x += k_start; + auto wl = w; + wl += (k_start / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += tid.z * static_cast(split_k_partition_stride); + + kq_qmm_t_impl( + wl, + x, + y, + Xs, + Ws, + K, + N, + M, + k_partition_size, + tid, + lid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void kq_q4_0_qmm_n( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q4_0_GROUP, "Q4_0 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_0 kernel requires bits=4"); + constexpr int BM = 64, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + using LoaderW = KqQ4_0BlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_n_impl( + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_0_qmv_fast( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q4_0_qmv_fast_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_0_qmv( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q4_0_qmv_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_0_dequantize( + const device uint8_t* w, + const device uint8_t* /* scales */, + device T* out, + const constant uint& num_weights, + uint gid [[thread_position_in_grid]]) { + static_assert( + group_size == KQ_Q4_0_GROUP, "Q4_0 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_0 kernel requires bits=4"); + kq_q4_0_dequantize_impl(w, out, num_weights, gid); +} + +// Q4_1: 20 bytes/32 weights. [fp16 d][fp16 m][uint8 qs[16]]. w[i] = d * q4 + m. + +MLX_MTL_CONST int KQ_Q4_1_GROUP = 32; +MLX_MTL_CONST int KQ_Q4_1_BLOCK_BYTES = 20; +MLX_MTL_CONST int KQ_Q4_1_D_OFFSET = 0; +MLX_MTL_CONST int KQ_Q4_1_M_OFFSET = 2; +MLX_MTL_CONST int KQ_Q4_1_QS_OFFSET = 4; + +inline float kq_q4_1_d(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q4_1_D_OFFSET)); +} +inline float kq_q4_1_m(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q4_1_M_OFFSET)); +} +inline const device uint8_t* kq_q4_1_qs_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q4_1_QS_OFFSET; +} + +template +METAL_FUNC void kq_q4_1_dequantize_impl( + const device uint8_t* w, + device T* out, + const constant uint& num_weights, + uint gid) { + if (gid >= num_weights) { + return; + } + const int block_id = gid / KQ_Q4_1_GROUP; + const int within = gid % KQ_Q4_1_GROUP; + const device uint8_t* block_addr = w + block_id * KQ_Q4_1_BLOCK_BYTES; + const float d = kq_q4_1_d(block_addr); + const float m = kq_q4_1_m(block_addr); + const device uint8_t* qs = kq_q4_1_qs_ptr(block_addr); + const int q4 = + (within < 16) ? (int(qs[within]) & 0x0F) : (int(qs[within - 16]) >> 4); + out[gid] = T(d * float(q4) + m); +} + +template +METAL_FUNC void kq_q4_1_qmv_fast_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q4_1_GROUP, "Q4_1 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_1 kernel requires bits=4"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int block_stride = 16; + + typedef float U; + thread U yl[16]; + thread U result[results_per_simdgroup] = {0}; + + const int ix = simd_lid / 2; + const int il = (simd_lid % 2) * 8; + + const int row_bytes = in_vec_size * KQ_Q4_1_BLOCK_BYTES / KQ_Q4_1_GROUP; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + const int nb = in_vec_size / KQ_Q4_1_GROUP; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += block_stride) { + const int x_base = ib * KQ_Q4_1_GROUP + il; + U sumy = U(0); +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const U a0 = U(x[x_base + i + 0]); + const U a1 = U(x[x_base + i + 1]); + const U b0 = U(x[x_base + i + 16]); + const U b1 = U(x[x_base + i + 17]); + sumy += a0 + a1 + b0 + b1; + yl[i + 0] = a0; + yl[i + 1] = a1 * (U(1) / U(256)); + yl[i + 8] = b0 * (U(1) / U(16)); + yl[i + 9] = b1 * (U(1) / U(4096)); + } + + for (int row = 0; row < results_per_simdgroup; row++) { + const int row_idx = out_row + row; + const device uint8_t* block_addr = + w + row_idx * row_bytes + ib * KQ_Q4_1_BLOCK_BYTES; + const U d = U(kq_q4_1_d(block_addr)); + const U m = U(kq_q4_1_m(block_addr)); + const device uint16_t* qs = + reinterpret_cast(kq_q4_1_qs_ptr(block_addr)) + + il / 2; + + U acc[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const uint16_t qi = qs[i / 2]; + acc[0] += yl[i + 0] * U(qi & 0x000F); + acc[1] += yl[i + 1] * U(qi & 0x0F00); + acc[2] += yl[i + 8] * U(qi & 0x00F0); + acc[3] += yl[i + 9] * U(qi & 0xF000); + } + result[row] += d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void kq_q4_1_qmv_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q4_1_GROUP, "Q4_1 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_1 kernel requires bits=4"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int block_stride = 16; + + typedef float U; + thread U yl[16]; + thread U result[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q4_1_BLOCK_BYTES / KQ_Q4_1_GROUP; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + if (out_row >= out_vec_size) { + return; + } + const int max_row = min(out_vec_size, out_row + results_per_simdgroup); + const int active_rows = max_row - out_row; + + const int ix = simd_lid / 2; + const int il = (simd_lid % 2) * 8; + + const int nb = in_vec_size / KQ_Q4_1_GROUP; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += block_stride) { + const int x_base = ib * KQ_Q4_1_GROUP + il; + U sumy = U(0); +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const U a0 = U(x[x_base + i + 0]); + const U a1 = U(x[x_base + i + 1]); + const U b0 = U(x[x_base + i + 16]); + const U b1 = U(x[x_base + i + 17]); + sumy += a0 + a1 + b0 + b1; + yl[i + 0] = a0; + yl[i + 1] = a1 * (U(1) / U(256)); + yl[i + 8] = b0 * (U(1) / U(16)); + yl[i + 9] = b1 * (U(1) / U(4096)); + } + + for (int row = 0; row < active_rows; row++) { + const int row_idx = out_row + row; + const device uint8_t* block_addr = + w + row_idx * row_bytes + ib * KQ_Q4_1_BLOCK_BYTES; + const U d = U(kq_q4_1_d(block_addr)); + const U m = U(kq_q4_1_m(block_addr)); + const device uint16_t* qs = + reinterpret_cast(kq_q4_1_qs_ptr(block_addr)) + + il / 2; + + U acc[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const uint16_t qi = qs[i / 2]; + acc[0] += yl[i + 0] * U(qi & 0x000F); + acc[1] += yl[i + 1] * U(qi & 0x0F00); + acc[2] += yl[i + 8] * U(qi & 0x00F0); + acc[3] += yl[i + 9] * U(qi & 0xF000); + } + result[row] += d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0 && row < active_rows) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqQ4_1BlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q4_1_GROUP; + MLX_MTL_CONST int bytes_per_block = KQ_Q4_1_BLOCK_BYTES; + + static_assert( + BCOLS == weights_per_block, + "Q4_1 loader requires BCOLS == 32 (one block per K-tile)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + MLX_MTL_CONST short bytes_per_thread = n_reads / 2; + MLX_MTL_CONST short half_block = weights_per_block / 2; + static_assert(n_reads >= 2 && n_reads % 2 == 0, "Q4_1 needs even n_reads."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + + const short thread_idx; + const short bi; + const short bj_byte; + + threadgroup T* dst; + const device uint8_t* src; + + KqQ4_1BlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int /* col_in_block */ = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? bytes_per_block + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj_byte((thread_idx % TCOLS) * bytes_per_thread), + dst(dst_ + bi * dst_ld + bj_byte), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)) {} + + void load_unsafe() const { + const float d = float(*(const device half*)(src + KQ_Q4_1_D_OFFSET)); + const float m = float(*(const device half*)(src + KQ_Q4_1_M_OFFSET)); + const device uint8_t* qs = src + KQ_Q4_1_QS_OFFSET + bj_byte; + static_assert( + bytes_per_thread == 4 || bytes_per_thread == 8, + "Q4_1 ALU vector load supports bytes_per_thread=4 or 8 (uint)."); + uint8_t qs_b[bytes_per_thread]; +#pragma unroll + for (short v = 0; v < bytes_per_thread / 4; v++) { + const uint qs_v = *reinterpret_cast(qs + v * 4); + *reinterpret_cast(&qs_b[v * 4]) = qs_v; + } +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t b = qs_b[i]; + const int q4_lo = int(b & 0x0F); + const int q4_hi = int(b >> 4); + dst[i] = T(d * float(q4_lo) + m); + dst[half_block + i] = T(d * float(q4_hi) + m); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + dst[i] = T(0); + dst[half_block + i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + src += tile_stride; + } +}; + +template +[[kernel]] void kq_q4_1_qmm_t( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q4_1_GROUP, "Q4_1 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_1 kernel requires bits=4"); + constexpr int BM = 64, BK = 32, BN = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ4_1BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_t_impl( + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_1_qmm_t_splitk( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& k_partition_size, + const constant int& split_k_partition_stride, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q4_1_GROUP, "Q4_1 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_1 kernel requires bits=4"); + constexpr int BM = 32, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ4_1BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + + const int k_start = tid.z * k_partition_size; + x += k_start; + auto wl = w; + wl += (k_start / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += tid.z * static_cast(split_k_partition_stride); + + kq_qmm_t_impl( + wl, + x, + y, + Xs, + Ws, + K, + N, + M, + k_partition_size, + tid, + lid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void kq_q4_1_qmm_n( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q4_1_GROUP, "Q4_1 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_1 kernel requires bits=4"); + constexpr int BM = 64, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + using LoaderW = KqQ4_1BlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_n_impl( + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_1_qmv_fast( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q4_1_qmv_fast_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_1_qmv( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q4_1_qmv_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q4_1_dequantize( + const device uint8_t* w, + const device uint8_t* /* scales */, + device T* out, + const constant uint& num_weights, + uint gid [[thread_position_in_grid]]) { + static_assert( + group_size == KQ_Q4_1_GROUP, "Q4_1 kernel requires group_size=32"); + static_assert(bits == 4, "Q4_1 kernel requires bits=4"); + kq_q4_1_dequantize_impl(w, out, num_weights, gid); +} + +// Q5_0: 22 bytes/32 weights. [fp16 d][uint8 qh[4]][uint8 qs[16]]. w[i] = d * +// (q5 - 16). + +MLX_MTL_CONST int KQ_Q5_0_GROUP = 32; +MLX_MTL_CONST int KQ_Q5_0_BLOCK_BYTES = 22; +MLX_MTL_CONST int KQ_Q5_0_D_OFFSET = 0; +MLX_MTL_CONST int KQ_Q5_0_QH_OFFSET = 2; +MLX_MTL_CONST int KQ_Q5_0_QS_OFFSET = 6; + +inline float kq_q5_0_d(const device uint8_t* block_addr) { + return float(*(const device half*)(block_addr + KQ_Q5_0_D_OFFSET)); +} +// qh at 22N+2 is not uint32-aligned; assemble from byte loads. +inline uint32_t kq_q5_0_qh(const device uint8_t* block_addr) { + const device uint8_t* p = block_addr + KQ_Q5_0_QH_OFFSET; + return uint32_t(p[0]) | (uint32_t(p[1]) << 8) | (uint32_t(p[2]) << 16) | + (uint32_t(p[3]) << 24); +} +inline const device uint8_t* kq_q5_0_qs_ptr(const device uint8_t* block_addr) { + return block_addr + KQ_Q5_0_QS_OFFSET; +} + +template +METAL_FUNC void kq_q5_0_dequantize_impl( + const device uint8_t* w, + device T* out, + const constant uint& num_weights, + uint gid) { + if (gid >= num_weights) { + return; + } + const int block_id = gid / KQ_Q5_0_GROUP; + const int within = gid % KQ_Q5_0_GROUP; + const device uint8_t* block_addr = w + block_id * KQ_Q5_0_BLOCK_BYTES; + const float d = kq_q5_0_d(block_addr); + const uint32_t qh = kq_q5_0_qh(block_addr); + const device uint8_t* qs = kq_q5_0_qs_ptr(block_addr); + const uint32_t hi = ((qh >> within) << 4) & 0x10u; + const uint8_t lo = + (within < 16) ? (qs[within] & 0x0Fu) : (qs[within - 16] >> 4); + const int q5 = int(uint32_t(lo) | hi); + out[gid] = T(d * float(q5 - 16)); +} + +template +METAL_FUNC void kq_q5_0_qmv_fast_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q5_0_GROUP, "Q5_0 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_0 kernel requires bits=5"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int block_stride = 16; + + typedef float U; + thread U yl[16]; + thread U result[results_per_simdgroup] = {0}; + + const int ix = simd_lid / 2; + const int il = (simd_lid % 2) * 8; + + const int row_bytes = in_vec_size * KQ_Q5_0_BLOCK_BYTES / KQ_Q5_0_GROUP; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + const int nb = in_vec_size / KQ_Q5_0_GROUP; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += block_stride) { + const int x_base = ib * KQ_Q5_0_GROUP + il; + U sumy = U(0); +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const U a0 = U(x[x_base + i + 0]); + const U a1 = U(x[x_base + i + 1]); + const U b0 = U(x[x_base + i + 16]); + const U b1 = U(x[x_base + i + 17]); + sumy += a0 + a1 + b0 + b1; + yl[i + 0] = a0; + yl[i + 1] = a1 * (U(1) / U(256)); + yl[i + 8] = b0 * (U(1) / U(16)); + yl[i + 9] = b1 * (U(1) / U(4096)); + } + + for (int row = 0; row < results_per_simdgroup; row++) { + const int row_idx = out_row + row; + const device uint8_t* block_addr = + w + row_idx * row_bytes + ib * KQ_Q5_0_BLOCK_BYTES; + const U d = U(kq_q5_0_d(block_addr)); + const uint32_t qh = kq_q5_0_qh(block_addr); + const device uint16_t* qs = + reinterpret_cast(kq_q5_0_qs_ptr(block_addr)) + + il / 2; + + U acc[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const uint16_t qi = qs[i / 2]; + acc[0] += yl[i + 0] * + U((qi & 0x000F) | (((qh >> (i + 0 + il)) << 4) & 0x00010)); + acc[1] += yl[i + 1] * + U((qi & 0x0F00) | (((qh >> (i + 1 + il)) << 12) & 0x01000)); + acc[2] += yl[i + 8] * + U((qi & 0x00F0) | (((qh >> (i + 0 + il + 16)) << 8) & 0x00100)); + acc[3] += yl[i + 9] * + U((qi & 0xF000) | (((qh >> (i + 1 + il + 16)) << 16) & 0x10000)); + } + result[row] += d * (acc[0] + acc[1] + acc[2] + acc[3] + sumy * U(-16)); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void kq_q5_0_qmv_impl( + const device uint8_t* w, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid, + uint simd_gid, + uint simd_lid) { + static_assert( + group_size == KQ_Q5_0_GROUP, "Q5_0 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_0 kernel requires bits=5"); + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int block_stride = 16; + + typedef float U; + thread U yl[16]; + thread U result[results_per_simdgroup] = {0}; + + const int row_bytes = in_vec_size * KQ_Q5_0_BLOCK_BYTES / KQ_Q5_0_GROUP; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + if (out_row >= out_vec_size) { + return; + } + const int max_row = min(out_vec_size, out_row + results_per_simdgroup); + const int active_rows = max_row - out_row; + + const int ix = simd_lid / 2; + const int il = (simd_lid % 2) * 8; + + const int nb = in_vec_size / KQ_Q5_0_GROUP; + + x += tid.x * in_vec_size; + y += tid.x * out_vec_size; + + for (int ib = ix; ib < nb; ib += block_stride) { + const int x_base = ib * KQ_Q5_0_GROUP + il; + U sumy = U(0); +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const U a0 = U(x[x_base + i + 0]); + const U a1 = U(x[x_base + i + 1]); + const U b0 = U(x[x_base + i + 16]); + const U b1 = U(x[x_base + i + 17]); + sumy += a0 + a1 + b0 + b1; + yl[i + 0] = a0; + yl[i + 1] = a1 * (U(1) / U(256)); + yl[i + 8] = b0 * (U(1) / U(16)); + yl[i + 9] = b1 * (U(1) / U(4096)); + } + + for (int row = 0; row < active_rows; row++) { + const int row_idx = out_row + row; + const device uint8_t* block_addr = + w + row_idx * row_bytes + ib * KQ_Q5_0_BLOCK_BYTES; + const U d = U(kq_q5_0_d(block_addr)); + const uint32_t qh = kq_q5_0_qh(block_addr); + const device uint16_t* qs = + reinterpret_cast(kq_q5_0_qs_ptr(block_addr)) + + il / 2; + + U acc[4] = {U(0), U(0), U(0), U(0)}; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + const uint16_t qi = qs[i / 2]; + acc[0] += yl[i + 0] * + U((qi & 0x000F) | (((qh >> (i + 0 + il)) << 4) & 0x00010)); + acc[1] += yl[i + 1] * + U((qi & 0x0F00) | (((qh >> (i + 1 + il)) << 12) & 0x01000)); + acc[2] += yl[i + 8] * + U((qi & 0x00F0) | (((qh >> (i + 0 + il + 16)) << 8) & 0x00100)); + acc[3] += yl[i + 9] * + U((qi & 0xF000) | (((qh >> (i + 1 + il + 16)) << 16) & 0x10000)); + } + result[row] += d * (acc[0] + acc[1] + acc[2] + acc[3] + sumy * U(-16)); + } + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0 && row < active_rows) { + y[out_row + row] = static_cast(result[row]); + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqQ5_0BlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q5_0_GROUP; + MLX_MTL_CONST int bytes_per_block = KQ_Q5_0_BLOCK_BYTES; + + static_assert( + BCOLS == weights_per_block, + "Q5_0 loader requires BCOLS == 32 (one block per K-tile)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + MLX_MTL_CONST short bytes_per_thread = n_reads / 2; + MLX_MTL_CONST short half_block = weights_per_block / 2; + static_assert(n_reads >= 2 && n_reads % 2 == 0, "Q5_0 needs even n_reads."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + + const short thread_idx; + const short bi; + const short bj_byte; + + threadgroup T* dst; + const device uint8_t* src; + + KqQ5_0BlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int /* col_in_block */ = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? bytes_per_block + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj_byte((thread_idx % TCOLS) * bytes_per_thread), + dst(dst_ + bi * dst_ld + bj_byte), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)) {} + + void load_unsafe() const { + const float d = float(*(const device half*)(src + KQ_Q5_0_D_OFFSET)); + const device uint8_t* qh_p = src + KQ_Q5_0_QH_OFFSET; + const uint32_t qh = uint32_t(qh_p[0]) | (uint32_t(qh_p[1]) << 8) | + (uint32_t(qh_p[2]) << 16) | (uint32_t(qh_p[3]) << 24); + const device uint8_t* qs = src + KQ_Q5_0_QS_OFFSET + bj_byte; +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t b = qs[i]; + const int j_lo = bj_byte + i; + const int j_hi = bj_byte + half_block + i; + const uint32_t hi_lo = ((qh >> j_lo) << 4) & 0x10u; + const uint32_t hi_hi = ((qh >> j_hi) << 4) & 0x10u; + const int q5_lo = int(uint32_t(b & 0x0F) | hi_lo); + const int q5_hi = int(uint32_t(b >> 4) | hi_hi); + dst[i] = T(d * float(q5_lo - 16)); + dst[half_block + i] = T(d * float(q5_hi - 16)); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + dst[i] = T(0); + dst[half_block + i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + src += tile_stride; + } +}; + +template +[[kernel]] void kq_q5_0_qmm_t( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q5_0_GROUP, "Q5_0 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_0 kernel requires bits=5"); + constexpr int BM = 64, BK = 32, BN = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ5_0BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_t_impl( + w, x, y, Xs, Ws, K, N, M, K, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_0_qmm_t_splitk( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& k_partition_size, + const constant int& split_k_partition_stride, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q5_0_GROUP, "Q5_0 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_0 kernel requires bits=5"); + constexpr int BM = 32, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + using LoaderW = KqQ5_0BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + + const int k_start = tid.z * k_partition_size; + x += k_start; + auto wl = w; + wl += (k_start / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += tid.z * static_cast(split_k_partition_stride); + + kq_qmm_t_impl( + wl, + x, + y, + Xs, + Ws, + K, + N, + M, + k_partition_size, + tid, + lid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void kq_q5_0_qmm_n( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + static_assert( + group_size == KQ_Q5_0_GROUP, "Q5_0 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_0 kernel requires bits=5"); + constexpr int BM = 64, BK = 32, BN = 32; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + using LoaderW = KqQ5_0BlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/2 * 2 * SIMD_SIZE>; + kq_qmm_n_impl( + w, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_0_qmv_fast( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q5_0_qmv_fast_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_0_qmv( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if constexpr (batched) { + int batch_M = x_shape[x_batch_ndims]; + kq_adjust_matrix_offsets( + x, + w, + y, + out_vec_size * batch_M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + kq_q5_0_qmv_impl( + w, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_0_dequantize( + const device uint8_t* w, + const device uint8_t* /* scales */, + device T* out, + const constant uint& num_weights, + uint gid [[thread_position_in_grid]]) { + static_assert( + group_size == KQ_Q5_0_GROUP, "Q5_0 kernel requires group_size=32"); + static_assert(bits == 5, "Q5_0 kernel requires bits=5"); + kq_q5_0_dequantize_impl(w, out, num_weights, gid); +} diff --git a/mlx/backend/metal/kernels/kq_quantized_nax.h b/mlx/backend/metal/kernels/kq_quantized_nax.h new file mode 100644 index 0000000000..c8bb740baa --- /dev/null +++ b/mlx/backend/metal/kernels/kq_quantized_nax.h @@ -0,0 +1,2560 @@ +// Copyright © 2026 Apple Inc. + +#include +#include + +#include "mlx/backend/metal/kernels/kq_quantized.h" + +using namespace metal; +using namespace mlx::steel; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +template < + typename T, + typename LoaderW, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +METAL_FUNC void kq_qmm_t_nax_tgp_impl( + const device uint8_t* w, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + const int K_w = (K / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = w; + + x += y_row * static_cast(K); + wl += static_cast(y_col) * K_w; + y += y_row * static_cast(N) + y_col; + + LoaderW loader_w(wl, K, Ws, simd_gid, simd_lid); + + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + constexpr short TK = SK / 16; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + + const short sgp_sm = min(SM, short(M - (y_row + tm))); + const bool is_unaligned_sm = (sgp_sm != SM); + + const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); + + const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); + const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN); + + using AccumType = float; + + NAXTile Dtile; + Dtile.clear(); + + x += tm * K; + + dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe(short2(BK, tgp_bn)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(x + kk1, K); + } else { + Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); + } + + Btile.template load(Ws + tn * BK_padded + kk1); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if constexpr (kAlignedM.value && kAlignedN.value) { + Dtile.store(y + tm * N + tn, N); + } else if (kAlignedM.value && sgp_sn == SN) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); + } + }); + }); +} + +template < + typename T, + typename LoaderW, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +METAL_FUNC void kq_qmm_n_nax_tgp_impl( + const device uint8_t* w, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + (void)M; + + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + auto wl = w; + + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += (y_col / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + y += y_row * static_cast(N) + y_col; + + LoaderW loader_w( + wl, N, Ws, simd_gid, simd_lid, y_col % LoaderW::weights_per_block); + + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + constexpr short TK = SK / 16; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + const short ldb_tgp = BN_padded; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = false; + + using AccumType = float; + + NAXTile Dtile; + Dtile.clear(); + + x += tm * K; + + // Dispatch gates NAX entry on K%BK==0. + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + Atile.load(x + kk1, K); + Btile.template load(Ws + tn + kk1 * ldb_tgp); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + Dtile.store(y + tm * N + tn, N); +} + +// Q8_0: 34 bytes per 32-weight block. [fp16 d][int8 q[32]] + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqNaxQ8_0BlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q8_0_GROUP; // 32 + MLX_MTL_CONST int bytes_per_block = KQ_Q8_0_BLOCK_BYTES; // 34 + + static_assert( + BCOLS % weights_per_block == 0, + "Q8_0 NAX loader requires BCOLS to be a multiple of 32."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + static_assert( + n_reads <= weights_per_block, + "Q8_0 NAX loader: n_reads must not exceed 32 (one block per thread)."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + + KqNaxQ8_0BlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int /* col_in_block */ = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? (BCOLS / weights_per_block) * bytes_per_block + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)) {} + + void load_unsafe() const { + const short block_idx = bj / weights_per_block; + const short within = bj % weights_per_block; + const device uint8_t* block_addr = src + block_idx * bytes_per_block; + const float d = float(*(const device half*)(block_addr + KQ_Q8_0_D_OFFSET)); + const device int8_t* q = + (const device int8_t*)(block_addr + KQ_Q8_0_Q_OFFSET + within); +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(d * float(q[i])); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + src += tile_stride; + } +}; + +template < + typename T, + int group_size, + int bits, + bool aligned_N, + bool batched, + int BM, + int BN, + int WM, + int WN> +[[kernel]] void kq_q8_0_qmm_t_nax( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q8_0_GROUP, "Q8_0 NAX kernel requires group_size=32"); + static_assert(bits == 8, "Q8_0 NAX kernel requires bits=8"); + + constexpr int BK = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Ws[BN * BK_padded]; + + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + + using LoaderW = KqNaxQ8_0BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/WM * WN * SIMD_SIZE>; + kq_qmm_t_nax_tgp_impl( + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q8_0_qmm_n_nax( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q8_0_GROUP, "Q8_0 NAX kernel requires group_size=32"); + static_assert(bits == 8, "Q8_0 NAX kernel requires bits=8"); + + constexpr int BM = 64, BK = 64, BN = 64, WM = 2, WN = 2; + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Ws[BK * BN_padded]; + + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + + using LoaderW = KqNaxQ8_0BlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/WM * WN * SIMD_SIZE>; + kq_qmm_n_nax_tgp_impl( + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + int group_size, + int bits, + bool aligned_N, + int BM, + int BN, + int WM, + int WN> +[[kernel]] void kq_q8_0_gather_qmm_t_nax( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q8_0_GROUP, "Q8_0 NAX kernel requires group_size=32"); + static_assert(bits == 8, "Q8_0 NAX kernel requires bits=8"); + + constexpr int BK = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Ws[BN * BK_padded]; + + kq_adjust_matrix_offsets( + x, + w, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + + using LoaderW = KqNaxQ8_0BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/WM * WN * SIMD_SIZE>; + kq_qmm_t_nax_tgp_impl( + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q8_0_gather_qmm_n_nax( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q8_0_GROUP, "Q8_0 NAX kernel requires group_size=32"); + static_assert(bits == 8, "Q8_0 NAX kernel requires bits=8"); + + constexpr int BM = 64, BK = 64, BN = 64, WM = 2, WN = 2; + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Ws[BK * BN_padded]; + + kq_adjust_matrix_offsets( + x, + w, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + + using LoaderW = KqNaxQ8_0BlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/WM * WN * SIMD_SIZE>; + kq_qmm_n_nax_tgp_impl( + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +// Q5_1: 24 bytes per 32-weight block. [fp16 d][fp16 m][qh[4]][qs[16]] + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqNaxQ5_1BlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q5_1_GROUP; // 32 + MLX_MTL_CONST int bytes_per_block = KQ_Q5_1_BLOCK_BYTES; // 24 + + static_assert( + BCOLS % weights_per_block == 0, + "Q5_1 NAX loader requires BCOLS to be a multiple of 32."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + static_assert( + n_reads <= weights_per_block, + "Q5_1 NAX loader: n_reads must not exceed 32 (one block per thread)."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + + KqNaxQ5_1BlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int /* col_in_block */ = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? (BCOLS / weights_per_block) * bytes_per_block + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)) {} + + void load_unsafe() const { + const short block_idx = bj / weights_per_block; + const short within = bj % weights_per_block; + const device uint8_t* block_addr = src + block_idx * bytes_per_block; + const float d = float(*(const device half*)(block_addr + KQ_Q5_1_D_OFFSET)); + const float m = float(*(const device half*)(block_addr + KQ_Q5_1_M_OFFSET)); + const uint32_t qh = + *(const device uint32_t*)(block_addr + KQ_Q5_1_QH_OFFSET); + const device uint8_t* qs = block_addr + KQ_Q5_1_QS_OFFSET; +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const int j = within + i; + const uint32_t hi = ((qh >> j) << 4) & 0x10u; + const uint8_t lo = (j < 16) ? (qs[j] & 0x0Fu) : (qs[j - 16] >> 4); + const float q5 = float(uint32_t(lo) | hi); + dst[i] = T(d * q5 + m); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + src += tile_stride; + } +}; + +template < + typename T, + int group_size, + int bits, + bool aligned_N, + bool batched, + int BM, + int BN, + int WM, + int WN> +[[kernel]] void kq_q5_1_qmm_t_nax( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q5_1_GROUP, "Q5_1 NAX kernel requires group_size=32"); + static_assert(bits == 5, "Q5_1 NAX kernel requires bits=5"); + + constexpr int BK = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Ws[BN * BK_padded]; + + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + + using LoaderW = KqNaxQ5_1BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/WM * WN * SIMD_SIZE>; + kq_qmm_t_nax_tgp_impl( + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_1_qmm_n_nax( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q5_1_GROUP, "Q5_1 NAX kernel requires group_size=32"); + static_assert(bits == 5, "Q5_1 NAX kernel requires bits=5"); + + constexpr int BM = 64, BK = 64, BN = 64, WM = 2, WN = 2; + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Ws[BK * BN_padded]; + + if constexpr (batched) { + kq_adjust_matrix_offsets( + x, + w, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + } + + using LoaderW = KqNaxQ5_1BlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/WM * WN * SIMD_SIZE>; + kq_qmm_n_nax_tgp_impl( + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + int group_size, + int bits, + bool aligned_N, + int BM, + int BN, + int WM, + int WN> +[[kernel]] void kq_q5_1_gather_qmm_t_nax( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q5_1_GROUP, "Q5_1 NAX kernel requires group_size=32"); + static_assert(bits == 5, "Q5_1 NAX kernel requires bits=5"); + + constexpr int BK = 64; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Ws[BN * BK_padded]; + + kq_adjust_matrix_offsets( + x, + w, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + + using LoaderW = KqNaxQ5_1BlockLoader< + T, + BN, + BK, + BK_padded, + /*reduction_dim=*/1, + /*tgp_size=*/WM * WN * SIMD_SIZE>; + kq_qmm_t_nax_tgp_impl( + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void kq_q5_1_gather_qmm_n_nax( + const device uint8_t* w, + const device uint8_t* /* scales */, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* /* s_strides */, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert( + group_size == KQ_Q5_1_GROUP, "Q5_1 NAX kernel requires group_size=32"); + static_assert(bits == 5, "Q5_1 NAX kernel requires bits=5"); + + constexpr int BM = 64, BK = 64, BN = 64, WM = 2, WN = 2; + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Ws[BK * BN_padded]; + + kq_adjust_matrix_offsets( + x, + w, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + tid); + + using LoaderW = KqNaxQ5_1BlockLoader< + T, + BK, + BN, + BN_padded, + /*reduction_dim=*/0, + /*tgp_size=*/WM * WN * SIMD_SIZE>; + kq_qmm_n_nax_tgp_impl( + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqNaxQ4_0BlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q4_0_GROUP; + MLX_MTL_CONST int bytes_per_block = KQ_Q4_0_BLOCK_BYTES; + + static_assert( + BCOLS % weights_per_block == 0, + "Q4_0 NAX loader requires BCOLS to be a multiple of 32."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + static_assert( + n_reads <= weights_per_block, + "Q4_0 NAX loader: n_reads must not exceed 32 (one block per thread)."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + + KqNaxQ4_0BlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int /* col_in_block */ = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? (BCOLS / weights_per_block) * bytes_per_block + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)) {} + + void load_unsafe() const { + const short block_idx = bj / weights_per_block; + const short within = bj % weights_per_block; + const device uint8_t* block_addr = src + block_idx * bytes_per_block; + const float d = float(*(const device half*)(block_addr + KQ_Q4_0_D_OFFSET)); + const device uint8_t* qs = block_addr + KQ_Q4_0_QS_OFFSET; +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const int j = within + i; + const int x = (j < 16) ? (qs[j] & 0x0Fu) : (qs[j - 16] >> 4); + dst[i] = T(d * float(int(x) - 8)); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + src += tile_stride; + } +}; + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqNaxQ4_1BlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q4_1_GROUP; + MLX_MTL_CONST int bytes_per_block = KQ_Q4_1_BLOCK_BYTES; + + static_assert( + BCOLS % weights_per_block == 0, + "Q4_1 NAX loader requires BCOLS to be a multiple of 32."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + static_assert( + n_reads <= weights_per_block, + "Q4_1 NAX loader: n_reads must not exceed 32 (one block per thread)."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + + KqNaxQ4_1BlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int /* col_in_block */ = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? (BCOLS / weights_per_block) * bytes_per_block + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)) {} + + void load_unsafe() const { + const short block_idx = bj / weights_per_block; + const short within = bj % weights_per_block; + const device uint8_t* block_addr = src + block_idx * bytes_per_block; + const float d = float(*(const device half*)(block_addr + KQ_Q4_1_D_OFFSET)); + const float m = float(*(const device half*)(block_addr + KQ_Q4_1_M_OFFSET)); + const device uint8_t* qs = block_addr + KQ_Q4_1_QS_OFFSET; +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const int j = within + i; + const int x = (j < 16) ? (qs[j] & 0x0Fu) : (qs[j - 16] >> 4); + dst[i] = T(d * float(x) + m); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + src += tile_stride; + } +}; + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqNaxQ5_0BlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q5_0_GROUP; + MLX_MTL_CONST int bytes_per_block = KQ_Q5_0_BLOCK_BYTES; + + static_assert( + BCOLS % weights_per_block == 0, + "Q5_0 NAX loader requires BCOLS to be a multiple of 32."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + static_assert( + n_reads <= weights_per_block, + "Q5_0 NAX loader: n_reads must not exceed 32 (one block per thread)."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + + KqNaxQ5_0BlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int /* col_in_block */ = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? (BCOLS / weights_per_block) * bytes_per_block + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)) {} + + void load_unsafe() const { + const short block_idx = bj / weights_per_block; + const short within = bj % weights_per_block; + const device uint8_t* block_addr = src + block_idx * bytes_per_block; + const float d = float(*(const device half*)(block_addr + KQ_Q5_0_D_OFFSET)); + const device uint8_t* qh_p = block_addr + KQ_Q5_0_QH_OFFSET; + const uint32_t qh = uint32_t(qh_p[0]) | (uint32_t(qh_p[1]) << 8) | + (uint32_t(qh_p[2]) << 16) | (uint32_t(qh_p[3]) << 24); + const device uint8_t* qs = block_addr + KQ_Q5_0_QS_OFFSET; +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const int j = within + i; + const uint32_t hi = ((qh >> j) << 4) & 0x10u; + const uint8_t lo = (j < 16) ? (qs[j] & 0x0Fu) : (qs[j - 16] >> 4); + const float q5 = float(uint32_t(lo) | hi); + dst[i] = T(d * (q5 - 16.0f)); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + src += tile_stride; + } +}; + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqNaxQ4_KBlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q4_K_SUPERBLOCK; + MLX_MTL_CONST int bytes_per_block = KQ_Q4_K_BLOCK_BYTES; + MLX_MTL_CONST int sub_block_size = 32; + MLX_MTL_CONST int sub_blocks_per_block = weights_per_block / sub_block_size; + + static_assert( + BCOLS == 64, + "Q4_K NAX loader requires BCOLS == 64 (two sub-blocks per K-tile)."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + MLX_MTL_CONST short bytes_per_thread = n_reads / 2; + static_assert( + n_reads == sub_block_size, + "Q4_K NAX loader expects n_reads == 32 (half pair per thread)."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + const short fixed_sb_base; + + const short thread_idx; + const short bi; + const short bj_byte; + + threadgroup T* dst; + const device uint8_t* src; + short sb_base; + + KqNaxQ4_KBlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int col_in_block = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? 0 + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + fixed_sb_base(reduction_dim == 0 ? (col_in_block / sub_block_size) : 0), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj_byte((thread_idx % TCOLS) * bytes_per_thread), + dst(dst_ + bi * dst_ld + bj_byte), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)), + sb_base(0) {} + + void load_unsafe() const { + const short pair_base = (reduction_dim == 0) ? fixed_sb_base : sb_base; + + const float d = float(*(const device half*)(src + KQ_Q4_K_D_OFFSET)); + const float dmin = float(*(const device half*)(src + KQ_Q4_K_DMIN_OFFSET)); + const device uint8_t* scales12 = src + KQ_Q4_K_SCALES_OFFSET; + + uint8_t sc6_lo, mn6_lo, sc6_hi, mn6_hi; + kq_get_scale_min_k4(pair_base + 0, scales12, sc6_lo, mn6_lo); + kq_get_scale_min_k4(pair_base + 1, scales12, sc6_hi, mn6_hi); + const float eff_scale_lo = d * float(sc6_lo); + const float eff_min_lo = dmin * float(mn6_lo); + const float eff_scale_hi = d * float(sc6_hi); + const float eff_min_hi = dmin * float(mn6_hi); + + const short pair = pair_base / 2; + const device uint8_t* qs = src + KQ_Q4_K_QS_OFFSET + pair * 32 + bj_byte; + + static_assert( + bytes_per_thread == 16, + "Q4_K NAX vector load assumes bytes_per_thread == 16 (uint4)."); + const uint4 qs_v = *reinterpret_cast(qs); + const thread uint8_t* qs_b = reinterpret_cast(&qs_v); + +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t b = qs_b[i]; + dst[i] = T(eff_scale_lo * float(b & 0x0F) - eff_min_lo); + dst[sub_block_size + i] = T(eff_scale_hi * float(b >> 4) - eff_min_hi); + } + } + + void load_safe(short2 src_tile_dim) const { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + dst[i] = T(0); + dst[sub_block_size + i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + if (reduction_dim == 1) { + sb_base += 2; + if (sb_base == sub_blocks_per_block) { + sb_base = 0; + src += bytes_per_block; + } + } else { + src += tile_stride; + } + } +}; + +// Q5_K: 176 bytes per 256-weight super-block. Q4_K layout + qh[32] high bits. + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqNaxQ5_KBlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q5_K_SUPERBLOCK; + MLX_MTL_CONST int bytes_per_block = KQ_Q5_K_BLOCK_BYTES; + MLX_MTL_CONST int sub_block_size = 32; + MLX_MTL_CONST int sub_blocks_per_block = weights_per_block / sub_block_size; + + static_assert(BCOLS == 64, "Q5_K NAX loader requires BCOLS == 64."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + MLX_MTL_CONST short bytes_per_thread = n_reads / 2; + static_assert(n_reads == sub_block_size, "Q5_K NAX expects n_reads == 32."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + const short fixed_sb_base; + + const short thread_idx; + const short bi; + const short bj_byte; + + threadgroup T* dst; + const device uint8_t* src; + short sb_base; + // qh cached on sb_base==0; reduction_dim==0 reads per-call instead. + struct Caches { + uint8_t qh_cache[bytes_per_thread]; + }; + metal::conditional_t cached; + + KqNaxQ5_KBlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int col_in_block = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? 0 + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + fixed_sb_base(reduction_dim == 0 ? (col_in_block / sub_block_size) : 0), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj_byte((thread_idx % TCOLS) * bytes_per_thread), + dst(dst_ + bi * dst_ld + bj_byte), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)), + sb_base(0) {} + + void load_unsafe() { + static_assert( + bytes_per_thread == 16, + "Q5_K NAX vector load assumes bytes_per_thread == 16 (uint4)."); + + if constexpr (reduction_dim == 1) { + const short pair_base = sb_base; + + const float d = float(*(const device half*)(src + KQ_Q5_K_D_OFFSET)); + const float dmin = + float(*(const device half*)(src + KQ_Q5_K_DMIN_OFFSET)); + const device uint8_t* scales12 = src + KQ_Q5_K_SCALES_OFFSET; + + uint8_t sc6_lo, mn6_lo, sc6_hi, mn6_hi; + kq_get_scale_min_k4(pair_base + 0, scales12, sc6_lo, mn6_lo); + kq_get_scale_min_k4(pair_base + 1, scales12, sc6_hi, mn6_hi); + const float eff_scale_lo = d * float(sc6_lo); + const float eff_min_lo = dmin * float(mn6_lo); + const float eff_scale_hi = d * float(sc6_hi); + const float eff_min_hi = dmin * float(mn6_hi); + + const short pair = pair_base / 2; + const device uint8_t* qs = src + KQ_Q5_K_QS_OFFSET + pair * 32 + bj_byte; + const device uint8_t* qh = src + KQ_Q5_K_QH_OFFSET + bj_byte; + + const uint4 qs_v = *reinterpret_cast(qs); + const thread uint8_t* qs_b = + reinterpret_cast(&qs_v); + + if (sb_base == 0) { + const uint4 qh_v = *reinterpret_cast(qh); + const thread uint8_t* qh_b = + reinterpret_cast(&qh_v); +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + cached.qh_cache[i] = qh_b[i]; + } + } + +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t b = qs_b[i]; + const uint8_t h = cached.qh_cache[i]; + const uint8_t q4_lo = b & 0x0F; + const uint8_t q4_hi = b >> 4; + const uint8_t hi_lo = (h >> pair_base) & 1u; + const uint8_t hi_hi = (h >> (pair_base + 1)) & 1u; + const uint8_t q5_lo = q4_lo | (hi_lo << 4); + const uint8_t q5_hi = q4_hi | (hi_hi << 4); + dst[i] = T(eff_scale_lo * float(q5_lo) - eff_min_lo); + dst[sub_block_size + i] = T(eff_scale_hi * float(q5_hi) - eff_min_hi); + } + return; + } + + const short pair_base = fixed_sb_base; + + const float d = float(*(const device half*)(src + KQ_Q5_K_D_OFFSET)); + const float dmin = float(*(const device half*)(src + KQ_Q5_K_DMIN_OFFSET)); + const device uint8_t* scales12 = src + KQ_Q5_K_SCALES_OFFSET; + + uint8_t sc6_lo, mn6_lo, sc6_hi, mn6_hi; + kq_get_scale_min_k4(pair_base + 0, scales12, sc6_lo, mn6_lo); + kq_get_scale_min_k4(pair_base + 1, scales12, sc6_hi, mn6_hi); + const float eff_scale_lo = d * float(sc6_lo); + const float eff_min_lo = dmin * float(mn6_lo); + const float eff_scale_hi = d * float(sc6_hi); + const float eff_min_hi = dmin * float(mn6_hi); + + const short pair = pair_base / 2; + const device uint8_t* qs = src + KQ_Q5_K_QS_OFFSET + pair * 32 + bj_byte; + const device uint8_t* qh = src + KQ_Q5_K_QH_OFFSET + bj_byte; + + const uint4 qs_v = *reinterpret_cast(qs); + const thread uint8_t* qs_b = reinterpret_cast(&qs_v); + + const uint4 qh_v = *reinterpret_cast(qh); + const thread uint8_t* qh_b = reinterpret_cast(&qh_v); + +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t b = qs_b[i]; + const uint8_t h = qh_b[i]; + const uint8_t q4_lo = b & 0x0F; + const uint8_t q4_hi = b >> 4; + const uint8_t hi_lo = (h >> pair_base) & 1u; + const uint8_t hi_hi = (h >> (pair_base + 1)) & 1u; + const uint8_t q5_lo = q4_lo | (hi_lo << 4); + const uint8_t q5_hi = q4_hi | (hi_hi << 4); + dst[i] = T(eff_scale_lo * float(q5_lo) - eff_min_lo); + dst[sub_block_size + i] = T(eff_scale_hi * float(q5_hi) - eff_min_hi); + } + } + + void load_safe(short2 src_tile_dim) { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + dst[i] = T(0); + dst[sub_block_size + i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + if (reduction_dim == 1) { + sb_base += 2; + if (sb_base == sub_blocks_per_block) { + sb_base = 0; + src += bytes_per_block; + } + } else { + src += tile_stride; + } + } +}; + +// Q6_K: 210 bytes per 256-weight super-block. Reversed field order: +// [ql[128]][qh[64]][int8 scales[16]][fp16 d] + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqNaxQ6_KBlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q6_K_SUPERBLOCK; + MLX_MTL_CONST int bytes_per_block = KQ_Q6_K_BLOCK_BYTES; + MLX_MTL_CONST int k_tile_size = 32; + MLX_MTL_CONST int k_tiles_per_block = weights_per_block / k_tile_size; + + static_assert(BCOLS == 64, "Q6_K NAX loader requires BCOLS == 64."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + static_assert(n_reads == k_tile_size, "Q6_K NAX expects n_reads == 32."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + const short fixed_kt_base; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + short kt_base; + // Pair-cache: kt_base & 2 == 0 computes both pairs; reduction_dim==0 has no + // cache. + struct Caches { + T cached[n_reads]; + }; + metal::conditional_t cached; + + KqNaxQ6_KBlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int col_in_block = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? 0 + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + fixed_kt_base(reduction_dim == 0 ? (col_in_block / k_tile_size) : 0), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj((thread_idx % TCOLS) * n_reads), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)), + kt_base(0) {} + + void load_unsafe() { + if constexpr (reduction_dim == 1) { + if (kt_base & 2) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = cached.cached[i]; + } + return; + } + + const short base = kt_base; + const short kt = base + (bj / k_tile_size); + const short half_idx = kt / 4; + const short quadrant = kt - half_idx * 4; + + const float d = float(*(const device half*)(src + KQ_Q6_K_D_OFFSET)); + const device int8_t* scales = + (const device int8_t*)(src + KQ_Q6_K_SCALES_OFFSET); + + const device uint8_t* ql_base = + src + KQ_Q6_K_QL_OFFSET + half_idx * 64 + (quadrant & 1) * 32; + const device uint8_t* qh_base = src + KQ_Q6_K_QH_OFFSET + half_idx * 32; + + const float es_lo_a = d * float(scales[kt * 2 + 0]); + const float es_hi_a = d * float(scales[kt * 2 + 1]); + const float es_lo_b = d * float(scales[(kt + 2) * 2 + 0]); + const float es_hi_b = d * float(scales[(kt + 2) * 2 + 1]); + const short qh_shift_a = quadrant * 2; + const short qh_shift_b = qh_shift_a + 4; + +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const uint8_t ql_byte = ql_base[i]; + const uint8_t h = qh_base[i]; + const float es_a = (i >= 16) ? es_hi_a : es_lo_a; + const float es_b = (i >= 16) ? es_hi_b : es_lo_b; + const uint8_t low4_a = ql_byte & 0x0F; + const uint8_t low4_b = ql_byte >> 4; + const uint8_t high2_a = (uint8_t)((h >> qh_shift_a) & 0x03); + const uint8_t high2_b = (uint8_t)((h >> qh_shift_b) & 0x03); + const int8_t q6_a = (int8_t)(low4_a | (high2_a << 4)) - (int8_t)32; + const int8_t q6_b = (int8_t)(low4_b | (high2_b << 4)) - (int8_t)32; + dst[i] = T(es_a * float(q6_a)); + cached.cached[i] = T(es_b * float(q6_b)); + } + return; + } + + const short base = fixed_kt_base; + const short kt = base + (bj / k_tile_size); + const short half_idx = kt / 4; + const short quadrant = kt - half_idx * 4; + + const float d = float(*(const device half*)(src + KQ_Q6_K_D_OFFSET)); + const device int8_t* scales = + (const device int8_t*)(src + KQ_Q6_K_SCALES_OFFSET); + + const device uint8_t* ql_base = + src + KQ_Q6_K_QL_OFFSET + half_idx * 64 + (quadrant & 1) * 32; + const device uint8_t* qh_base = src + KQ_Q6_K_QH_OFFSET + half_idx * 32; + + const bool is_high_nibble = (quadrant >= 2); + const short qh_shift = quadrant * 2; + const float eff_scale_lo = d * float(scales[kt * 2 + 0]); + const float eff_scale_hi = d * float(scales[kt * 2 + 1]); + +#pragma unroll + for (short i = 0; i < n_reads; i++) { + const float eff_scale = (i >= 16) ? eff_scale_hi : eff_scale_lo; + const uint8_t low4 = + is_high_nibble ? (ql_base[i] >> 4) : (ql_base[i] & 0x0F); + const uint8_t high2 = (uint8_t)((qh_base[i] >> qh_shift) & 0x03); + const int8_t q6 = (int8_t)(low4 | (high2 << 4)) - (int8_t)32; + dst[i] = T(eff_scale * float(q6)); + } + } + + void load_safe(short2 src_tile_dim) { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < n_reads; i++) { + dst[i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + if (reduction_dim == 1) { + kt_base += 2; + if (kt_base == k_tiles_per_block) { + kt_base = 0; + src += bytes_per_block; + } + } else { + src += tile_stride; + } + } +}; + +// Q3_K: 110 bytes per 256-weight super-block. +// [hmask[32]][qs[64]][scales[12]][fp16 d] + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqNaxQ3_KBlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q3_K_SUPERBLOCK; + MLX_MTL_CONST int bytes_per_block = KQ_Q3_K_BLOCK_BYTES; + MLX_MTL_CONST int k_tile_size = 32; + MLX_MTL_CONST int k_tiles_per_block = weights_per_block / k_tile_size; + + static_assert(BCOLS == 64, "Q3_K NAX loader requires BCOLS == 64."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + MLX_MTL_CONST short bytes_per_thread = n_reads / 2; + static_assert(n_reads == k_tile_size, "Q3_K NAX expects n_reads == 32."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + const short fixed_kt_base; + + const short thread_idx; + const short bi; + const short bj_byte; + + threadgroup T* dst; + const device uint8_t* src; + short kt_base; + // Pair-cache + hmask cache; reduction_dim==0 has no register storage. + struct Caches { + T cached[n_reads]; + uint8_t hmask_cache[bytes_per_thread]; + }; + metal::conditional_t cached; + + KqNaxQ3_KBlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int col_in_block = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? 0 + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + fixed_kt_base(reduction_dim == 0 ? (col_in_block / k_tile_size) : 0), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj_byte((thread_idx % TCOLS) * bytes_per_thread), + dst(dst_ + bi * dst_ld + bj_byte), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)), + kt_base(0) {} + + void load_unsafe() { + if constexpr (reduction_dim == 1) { + if (kt_base & 2) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + dst[i] = cached.cached[i]; + dst[k_tile_size + i] = cached.cached[bytes_per_thread + i]; + } + return; + } + + const short base = kt_base; + const short outer_half = base / 4; + const short scale_off = (bj_byte >= 16) ? 1 : 0; + + const float d = float(*(const device half*)(src + KQ_Q3_K_D_OFFSET)); + const device uint8_t* qs = + src + KQ_Q3_K_QS_OFFSET + outer_half * 32 + bj_byte; + const device uint8_t* hm = src + KQ_Q3_K_HMASK_OFFSET + bj_byte; + + const float es_a = d * + float((int)kq_q3_k_unpack_scale( + base * 2 + scale_off, src + KQ_Q3_K_SCALES_OFFSET) - + 32); + const float es_b = d * + float((int)kq_q3_k_unpack_scale( + (base + 1) * 2 + scale_off, src + KQ_Q3_K_SCALES_OFFSET) - + 32); + const float es_c = d * + float((int)kq_q3_k_unpack_scale( + (base + 2) * 2 + scale_off, src + KQ_Q3_K_SCALES_OFFSET) - + 32); + const float es_d = d * + float((int)kq_q3_k_unpack_scale( + (base + 3) * 2 + scale_off, src + KQ_Q3_K_SCALES_OFFSET) - + 32); + + const short shift_a = (base & 3) * 2; + const short shift_b = ((base + 1) & 3) * 2; + const short shift_c = ((base + 2) & 3) * 2; + const short shift_d = ((base + 3) & 3) * 2; + + if (kt_base == 0) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + cached.hmask_cache[i] = hm[i]; + } + } + +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t q = qs[i]; + const uint8_t h = cached.hmask_cache[i]; + const uint8_t q2_a = (q >> shift_a) & 0x03; + const uint8_t q2_b = (q >> shift_b) & 0x03; + const uint8_t q2_c = (q >> shift_c) & 0x03; + const uint8_t q2_d = (q >> shift_d) & 0x03; + const int q3_a = (int)q2_a - (((h >> base) & 1) ? 0 : 4); + const int q3_b = (int)q2_b - (((h >> (base + 1)) & 1) ? 0 : 4); + const int q3_c = (int)q2_c - (((h >> (base + 2)) & 1) ? 0 : 4); + const int q3_d = (int)q2_d - (((h >> (base + 3)) & 1) ? 0 : 4); + dst[i] = T(es_a * float(q3_a)); + dst[k_tile_size + i] = T(es_b * float(q3_b)); + cached.cached[i] = T(es_c * float(q3_c)); + cached.cached[bytes_per_thread + i] = T(es_d * float(q3_d)); + } + return; + } + + const short base = fixed_kt_base; + const short outer_half = base / 4; + const short scale_off = (bj_byte >= 16) ? 1 : 0; + + const float d = float(*(const device half*)(src + KQ_Q3_K_D_OFFSET)); + const device uint8_t* qs = + src + KQ_Q3_K_QS_OFFSET + outer_half * 32 + bj_byte; + const device uint8_t* hm = src + KQ_Q3_K_HMASK_OFFSET + bj_byte; + + const short kt = base; + const float es = d * + float((int)kq_q3_k_unpack_scale( + kt * 2 + scale_off, src + KQ_Q3_K_SCALES_OFFSET) - + 32); + const short shift = (kt & 3) * 2; + const short hbit = kt; + + const float es_b = d * + float((int)kq_q3_k_unpack_scale( + (kt + 1) * 2 + scale_off, src + KQ_Q3_K_SCALES_OFFSET) - + 32); + const short shift_b = ((kt + 1) & 3) * 2; + const short hbit_b = kt + 1; + +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t q = qs[i]; + const uint8_t h = hm[i]; + const uint8_t q2_a = (q >> shift) & 0x03; + const uint8_t q2_b = (q >> shift_b) & 0x03; + const int q3_a = (int)q2_a - (((h >> hbit) & 1) ? 0 : 4); + const int q3_b = (int)q2_b - (((h >> hbit_b) & 1) ? 0 : 4); + dst[i] = T(es * float(q3_a)); + dst[k_tile_size + i] = T(es_b * float(q3_b)); + } + } + + void load_safe(short2 src_tile_dim) { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + dst[i] = T(0); + dst[k_tile_size + i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + if (reduction_dim == 1) { + kt_base += 2; + if (kt_base == k_tiles_per_block) { + kt_base = 0; + src += bytes_per_block; + } + } else { + src += tile_stride; + } + } +}; + +// Q2_K: 84 bytes per 256-weight super-block. [scales[16]][qs[64]][fp16 d][fp16 +// dmin] + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size> +struct KqNaxQ2_KBlockLoader { + MLX_MTL_CONST int weights_per_block = KQ_Q2_K_SUPERBLOCK; + MLX_MTL_CONST int bytes_per_block = KQ_Q2_K_BLOCK_BYTES; + MLX_MTL_CONST int k_tile_size = 32; + MLX_MTL_CONST int k_tiles_per_block = weights_per_block / k_tile_size; + + static_assert(BCOLS == 64, "Q2_K NAX loader requires BCOLS == 64."); + static_assert( + (BCOLS * BROWS) % tgp_size == 0, + "tgp_size must evenly divide BCOLS * BROWS."); + + MLX_MTL_CONST short n_reads = (BCOLS * BROWS) / tgp_size; + MLX_MTL_CONST short TCOLS = BCOLS / n_reads; + MLX_MTL_CONST short bytes_per_thread = n_reads / 2; + static_assert(n_reads == k_tile_size, "Q2_K NAX expects n_reads == 32."); + + const int src_ld; + const int row_bytes; + const int tile_stride; + const short fixed_kt_base; + + const short thread_idx; + const short bi; + const short bj_byte; + + threadgroup T* dst; + const device uint8_t* src; + short kt_base; + struct Caches { + T cached[n_reads]; + }; + metal::conditional_t cached; + + KqNaxQ2_KBlockLoader( + const device uint8_t* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]], + int col_in_block = 0) + : src_ld(src_ld_), + row_bytes(src_ld_ * bytes_per_block / weights_per_block), + tile_stride( + reduction_dim + ? 0 + : BROWS * (src_ld_ * bytes_per_block / weights_per_block)), + fixed_kt_base(reduction_dim == 0 ? (col_in_block / k_tile_size) : 0), + thread_idx(simd_group_id * SIMD_SIZE + simd_lane_id), + bi(thread_idx / TCOLS), + bj_byte((thread_idx % TCOLS) * bytes_per_thread), + dst(dst_ + bi * dst_ld + bj_byte), + src(src_ + bi * (src_ld_ * bytes_per_block / weights_per_block)), + kt_base(0) {} + + void load_unsafe() { + if constexpr (reduction_dim == 1) { + if (kt_base & 2) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + dst[i] = cached.cached[i]; + dst[k_tile_size + i] = cached.cached[bytes_per_thread + i]; + } + return; + } + + const short base = kt_base; + const short outer_half = base / 4; + const short scale_off = (bj_byte >= 16) ? 1 : 0; + + const float d = float(*(const device half*)(src + KQ_Q2_K_D_OFFSET)); + const float dmin = + float(*(const device half*)(src + KQ_Q2_K_DMIN_OFFSET)); + const device uint8_t* qs = + src + KQ_Q2_K_QS_OFFSET + outer_half * 32 + bj_byte; + + static_assert( + bytes_per_thread == 16, + "Q2_K NAX vector load assumes bytes_per_thread == 16."); + uint8_t qs_b[bytes_per_thread]; +#pragma unroll + for (short v = 0; v < bytes_per_thread / 4; v++) { + const uint qs_v = *reinterpret_cast(qs + v * 4); + *reinterpret_cast(&qs_b[v * 4]) = qs_v; + } + + const uint8_t sc_a = src[KQ_Q2_K_SCALES_OFFSET + base * 2 + scale_off]; + const uint8_t sc_b = + src[KQ_Q2_K_SCALES_OFFSET + (base + 1) * 2 + scale_off]; + const uint8_t sc_c = + src[KQ_Q2_K_SCALES_OFFSET + (base + 2) * 2 + scale_off]; + const uint8_t sc_d = + src[KQ_Q2_K_SCALES_OFFSET + (base + 3) * 2 + scale_off]; + const float es_a = d * float(sc_a & 0x0F); + const float em_a = dmin * float(sc_a >> 4); + const float es_b = d * float(sc_b & 0x0F); + const float em_b = dmin * float(sc_b >> 4); + const float es_c = d * float(sc_c & 0x0F); + const float em_c = dmin * float(sc_c >> 4); + const float es_d = d * float(sc_d & 0x0F); + const float em_d = dmin * float(sc_d >> 4); + + const short shift_a = (base & 3) * 2; + const short shift_b = ((base + 1) & 3) * 2; + const short shift_c = ((base + 2) & 3) * 2; + const short shift_d = ((base + 3) & 3) * 2; + +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t q = qs_b[i]; + const uint8_t q2_a = (q >> shift_a) & 0x03; + const uint8_t q2_b = (q >> shift_b) & 0x03; + const uint8_t q2_c = (q >> shift_c) & 0x03; + const uint8_t q2_d = (q >> shift_d) & 0x03; + dst[i] = T(es_a * float(q2_a) - em_a); + dst[k_tile_size + i] = T(es_b * float(q2_b) - em_b); + cached.cached[i] = T(es_c * float(q2_c) - em_c); + cached.cached[bytes_per_thread + i] = T(es_d * float(q2_d) - em_d); + } + return; + } + + const short base = fixed_kt_base; + const short outer_half = base / 4; + const short scale_off = (bj_byte >= 16) ? 1 : 0; + + const float d = float(*(const device half*)(src + KQ_Q2_K_D_OFFSET)); + const float dmin = float(*(const device half*)(src + KQ_Q2_K_DMIN_OFFSET)); + const device uint8_t* qs = + src + KQ_Q2_K_QS_OFFSET + outer_half * 32 + bj_byte; + + static_assert( + bytes_per_thread == 16, + "Q2_K NAX vector load assumes bytes_per_thread == 16."); + uint8_t qs_b[bytes_per_thread]; +#pragma unroll + for (short v = 0; v < bytes_per_thread / 4; v++) { + const uint qs_v = *reinterpret_cast(qs + v * 4); + *reinterpret_cast(&qs_b[v * 4]) = qs_v; + } + + const short kt = base; + const uint8_t sc_a = src[KQ_Q2_K_SCALES_OFFSET + kt * 2 + scale_off]; + const uint8_t sc_b = src[KQ_Q2_K_SCALES_OFFSET + (kt + 1) * 2 + scale_off]; + const float es_a = d * float(sc_a & 0x0F); + const float em_a = dmin * float(sc_a >> 4); + const float es_b = d * float(sc_b & 0x0F); + const float em_b = dmin * float(sc_b >> 4); + const short shift_a = (kt & 3) * 2; + const short shift_b = ((kt + 1) & 3) * 2; + +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + const uint8_t q = qs_b[i]; + const uint8_t q2_a = (q >> shift_a) & 0x03; + const uint8_t q2_b = (q >> shift_b) & 0x03; + dst[i] = T(es_a * float(q2_a) - em_a); + dst[k_tile_size + i] = T(es_b * float(q2_b) - em_b); + } + } + + void load_safe(short2 src_tile_dim) { + if (bi >= src_tile_dim.y) { +#pragma unroll + for (short i = 0; i < bytes_per_thread; i++) { + dst[i] = T(0); + dst[k_tile_size + i] = T(0); + } + return; + } + load_unsafe(); + } + + void next() { + if (reduction_dim == 1) { + kt_base += 2; + if (kt_base == k_tiles_per_block) { + kt_base = 0; + src += bytes_per_block; + } + } else { + src += tile_stride; + } + } +}; + +#define KQ_NAX_DEFINE_KERNELS(codec, GROUP_CONST, bits_val, LOADER) \ + template < \ + typename T, \ + int group_size, \ + int bits, \ + bool aligned_N, \ + bool batched, \ + int BM, \ + int BN, \ + int WM, \ + int WN> \ + [[kernel]] void kq_##codec##_qmm_t_nax( \ + const device uint8_t* w, \ + const device uint8_t* /* scales */, \ + const device T* x, \ + device T* y, \ + const constant int& K, \ + const constant int& N, \ + const constant int& M, \ + const constant int& x_batch_ndims, \ + const constant int* x_shape, \ + const constant int64_t* x_strides, \ + const constant int& w_batch_ndims, \ + const constant int* w_shape, \ + const constant int64_t* w_strides, \ + const constant int64_t* /* s_strides */, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]) { \ + static_assert( \ + group_size == GROUP_CONST, \ + #codec " NAX kernel requires group_size=" #GROUP_CONST); \ + static_assert( \ + bits == bits_val, #codec " NAX kernel requires bits=" #bits_val); \ + constexpr int BK = 64; \ + constexpr int BK_padded = (BK + 16 / sizeof(T)); \ + threadgroup T Ws[BN * BK_padded]; \ + if constexpr (batched) { \ + kq_adjust_matrix_offsets( \ + x, \ + w, \ + y, \ + M * N, \ + x_batch_ndims, \ + x_shape, \ + x_strides, \ + w_batch_ndims, \ + w_shape, \ + w_strides, \ + tid); \ + } \ + using LoaderW = LOADER< \ + T, \ + BN, \ + BK, \ + BK_padded, \ + /*reduction_dim=*/1, \ + /*tgp_size=*/WM * WN * SIMD_SIZE>; \ + kq_qmm_t_nax_tgp_impl( \ + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); \ + } \ + \ + template \ + [[kernel]] void kq_##codec##_qmm_n_nax( \ + const device uint8_t* w, \ + const device uint8_t* /* scales */, \ + const device T* x, \ + device T* y, \ + const constant int& K, \ + const constant int& N, \ + const constant int& M, \ + const constant int& x_batch_ndims, \ + const constant int* x_shape, \ + const constant int64_t* x_strides, \ + const constant int& w_batch_ndims, \ + const constant int* w_shape, \ + const constant int64_t* w_strides, \ + const constant int64_t* /* s_strides */, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]) { \ + static_assert( \ + group_size == GROUP_CONST, \ + #codec " NAX kernel requires group_size=" #GROUP_CONST); \ + static_assert( \ + bits == bits_val, #codec " NAX kernel requires bits=" #bits_val); \ + constexpr int BM = 64, BK = 64, BN = 64, WM = 2, WN = 2; \ + constexpr int BN_padded = (BN + 16 / sizeof(T)); \ + threadgroup T Ws[BK * BN_padded]; \ + if constexpr (batched) { \ + kq_adjust_matrix_offsets( \ + x, \ + w, \ + y, \ + M * N, \ + x_batch_ndims, \ + x_shape, \ + x_strides, \ + w_batch_ndims, \ + w_shape, \ + w_strides, \ + tid); \ + } \ + using LoaderW = LOADER< \ + T, \ + BK, \ + BN, \ + BN_padded, \ + /*reduction_dim=*/0, \ + /*tgp_size=*/WM * WN * SIMD_SIZE>; \ + kq_qmm_n_nax_tgp_impl( \ + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); \ + } \ + \ + template < \ + typename T, \ + int group_size, \ + int bits, \ + bool aligned_N, \ + int BM, \ + int BN, \ + int WM, \ + int WN> \ + [[kernel]] void kq_##codec##_gather_qmm_t_nax( \ + const device uint8_t* w, \ + const device uint8_t* /* scales */, \ + const device T* x, \ + const device uint32_t* lhs_indices, \ + const device uint32_t* rhs_indices, \ + device T* y, \ + const constant int& K, \ + const constant int& N, \ + const constant int& M, \ + const constant int& x_batch_ndims, \ + const constant int* x_shape, \ + const constant int64_t* x_strides, \ + const constant int& w_batch_ndims, \ + const constant int* w_shape, \ + const constant int64_t* w_strides, \ + const constant int64_t* /* s_strides */, \ + const constant int& batch_ndims, \ + const constant int* batch_shape, \ + const constant int64_t* lhs_strides, \ + const constant int64_t* rhs_strides, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]) { \ + static_assert( \ + group_size == GROUP_CONST, \ + #codec " NAX kernel requires group_size=" #GROUP_CONST); \ + static_assert( \ + bits == bits_val, #codec " NAX kernel requires bits=" #bits_val); \ + constexpr int BK = 64; \ + constexpr int BK_padded = (BK + 16 / sizeof(T)); \ + threadgroup T Ws[BN * BK_padded]; \ + kq_adjust_matrix_offsets( \ + x, \ + w, \ + lhs_indices, \ + rhs_indices, \ + y, \ + M * N, \ + batch_ndims, \ + batch_shape, \ + lhs_strides, \ + rhs_strides, \ + x_batch_ndims, \ + x_shape, \ + x_strides, \ + w_batch_ndims, \ + w_shape, \ + w_strides, \ + tid); \ + using LoaderW = LOADER< \ + T, \ + BN, \ + BK, \ + BK_padded, \ + /*reduction_dim=*/1, \ + /*tgp_size=*/WM * WN * SIMD_SIZE>; \ + kq_qmm_t_nax_tgp_impl( \ + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); \ + } \ + \ + template \ + [[kernel]] void kq_##codec##_gather_qmm_n_nax( \ + const device uint8_t* w, \ + const device uint8_t* /* scales */, \ + const device T* x, \ + const device uint32_t* lhs_indices, \ + const device uint32_t* rhs_indices, \ + device T* y, \ + const constant int& K, \ + const constant int& N, \ + const constant int& M, \ + const constant int& x_batch_ndims, \ + const constant int* x_shape, \ + const constant int64_t* x_strides, \ + const constant int& w_batch_ndims, \ + const constant int* w_shape, \ + const constant int64_t* w_strides, \ + const constant int64_t* /* s_strides */, \ + const constant int& batch_ndims, \ + const constant int* batch_shape, \ + const constant int64_t* lhs_strides, \ + const constant int64_t* rhs_strides, \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint lid [[thread_index_in_threadgroup]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]) { \ + static_assert( \ + group_size == GROUP_CONST, \ + #codec " NAX kernel requires group_size=" #GROUP_CONST); \ + static_assert( \ + bits == bits_val, #codec " NAX kernel requires bits=" #bits_val); \ + constexpr int BM = 64, BK = 64, BN = 64, WM = 2, WN = 2; \ + constexpr int BN_padded = (BN + 16 / sizeof(T)); \ + threadgroup T Ws[BK * BN_padded]; \ + kq_adjust_matrix_offsets( \ + x, \ + w, \ + lhs_indices, \ + rhs_indices, \ + y, \ + M * N, \ + batch_ndims, \ + batch_shape, \ + lhs_strides, \ + rhs_strides, \ + x_batch_ndims, \ + x_shape, \ + x_strides, \ + w_batch_ndims, \ + w_shape, \ + w_strides, \ + tid); \ + using LoaderW = LOADER< \ + T, \ + BK, \ + BN, \ + BN_padded, \ + /*reduction_dim=*/0, \ + /*tgp_size=*/WM * WN * SIMD_SIZE>; \ + kq_qmm_n_nax_tgp_impl( \ + w, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); \ + } + +KQ_NAX_DEFINE_KERNELS(q4_0, 32, 4, KqNaxQ4_0BlockLoader) +KQ_NAX_DEFINE_KERNELS(q4_1, 32, 4, KqNaxQ4_1BlockLoader) +KQ_NAX_DEFINE_KERNELS(q5_0, 32, 5, KqNaxQ5_0BlockLoader) +KQ_NAX_DEFINE_KERNELS(q4_k, 256, 4, KqNaxQ4_KBlockLoader) +KQ_NAX_DEFINE_KERNELS(q5_k, 256, 5, KqNaxQ5_KBlockLoader) +KQ_NAX_DEFINE_KERNELS(q6_k, 256, 6, KqNaxQ6_KBlockLoader) +KQ_NAX_DEFINE_KERNELS(q3_k, 256, 3, KqNaxQ3_KBlockLoader) +KQ_NAX_DEFINE_KERNELS(q2_k, 256, 2, KqNaxQ2_KBlockLoader) + +template < + typename T, + typename LoaderW, + bool transpose, + int BM = 64, + int BN = 64, + int BK = 64, + int WM = 2, + int WN = 2> +METAL_FUNC void kq_gather_qmm_rhs_nax_tgp_impl( + const device T* x, + const device uint8_t* w, + const device uint32_t* indices, + device T* y, + const constant int& M, + const constant int& N, + const constant int& K, + threadgroup T* Ws, + uint3 tid, + uint simd_group_id, + uint simd_lane_id) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + const int K_w = (K / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + const int N_w = (N / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + const int K_it = K / BK; + const size_t stride_w = transpose ? size_t(N) * K_w : size_t(K) * N_w; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + const int k_remain = K - K_it * BK; + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + auto wl = w; + x += y_row_long * static_cast(K); + y += y_row_long * static_cast(N) + y_col_long; + if (transpose) { + wl += y_col_long * K_w; + } else { + wl += (y_col_long / LoaderW::weights_per_block) * LoaderW::bytes_per_block; + } + + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / 16; + constexpr short TN = SN / 16; + constexpr short TK = SK / 16; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = + align_M ? SM : min(SM, short(max(0, (M - (y_row + tm))))); + const short sgp_sn = + align_N ? SN : min(SN, short(max(0, (N - (y_col + tn))))); + + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN); + + constexpr short BR = transpose ? TN : TK; + constexpr short BC = transpose ? TK : TN; + + using AccumType = float; + + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + NAXTile Dtile; + Dtile.clear(); + + const device T* xn = x + tm * K; + + thread LoaderW loader_w( + wl + index * stride_w, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id, + transpose ? 0 : int(y_col_long % LoaderW::weights_per_block)); + + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe( + transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + // Prevents the Metal compiler from reordering loads across + // iterations. + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(xn + kk1, K); + } else { + Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm)); + } + + if constexpr (transpose) { + Btile.template load(Ws + tn * BK_padded + kk1); + } else { + Btile.template load(Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + xn += BK; + loader_w.next(); + } + + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_safe(tile_w); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + // Prevents the Metal compiler from reordering loads across + // iterations. + volatile int compiler_barrier; + + const short psk = min(int(SK), max(0, (BK - kk1))); + Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm)); + + if constexpr (transpose) { + Btile.template load(Ws + tn * BK_padded + kk1); + } else { + Btile.template load(Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm)); + const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm)); + + if constexpr (kAlignedN.value) { + if (m_lo_lim == 0 && m_hi_lim == SM) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_slice( + y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim)); + } + } else { + Dtile.store_slice( + y + tm * N + tn, + N, + short2(0, m_lo_lim), + short2(sgp_sn, m_hi_lim)); + } + }); + }); + } +} + +#define KQ_NAX_DEFINE_GATHER_RHS(codec, GROUP_CONST, bits_val, LOADER) \ + template < \ + typename T, \ + int group_size, \ + int bits, \ + int BM, \ + int BN, \ + int BK, \ + int WM, \ + int WN, \ + bool transpose> \ + [[kernel]] void kq_##codec##_gather_qmm_rhs_nax( \ + const device T* x [[buffer(0)]], \ + const device uint8_t* w [[buffer(1)]], \ + const device uint8_t* scales [[buffer(2)]], \ + const device uint32_t* indices [[buffer(3)]], \ + device T* y [[buffer(4)]], \ + const constant int& M [[buffer(5)]], \ + const constant int& N [[buffer(6)]], \ + const constant int& K [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]]) { \ + static_assert( \ + group_size == GROUP_CONST, \ + #codec " NAX kernel requires group_size=" #GROUP_CONST); \ + static_assert( \ + bits == bits_val, #codec " NAX kernel requires bits=" #bits_val); \ + constexpr int BK_padded = (BK + 16 / sizeof(T)); \ + constexpr int BN_padded = (BN + 16 / sizeof(T)); \ + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; \ + using LoaderW = LOADER< \ + T, \ + transpose ? BN : BK, \ + transpose ? BK : BN, \ + transpose ? BK_padded : BN_padded, \ + /*reduction_dim=*/(transpose ? 1 : 0), \ + /*tgp_size=*/WM * WN * SIMD_SIZE>; \ + kq_gather_qmm_rhs_nax_tgp_impl( \ + x, w, indices, y, M, N, K, Ws, tid, simd_group_id, simd_lane_id); \ + } + +KQ_NAX_DEFINE_GATHER_RHS(q8_0, 32, 8, KqNaxQ8_0BlockLoader) +KQ_NAX_DEFINE_GATHER_RHS(q5_1, 32, 5, KqNaxQ5_1BlockLoader) +KQ_NAX_DEFINE_GATHER_RHS(q4_0, 32, 4, KqNaxQ4_0BlockLoader) +KQ_NAX_DEFINE_GATHER_RHS(q4_1, 32, 4, KqNaxQ4_1BlockLoader) +KQ_NAX_DEFINE_GATHER_RHS(q5_0, 32, 5, KqNaxQ5_0BlockLoader) +KQ_NAX_DEFINE_GATHER_RHS(q4_k, 256, 4, KqNaxQ4_KBlockLoader) +KQ_NAX_DEFINE_GATHER_RHS(q5_k, 256, 5, KqNaxQ5_KBlockLoader) +KQ_NAX_DEFINE_GATHER_RHS(q6_k, 256, 6, KqNaxQ6_KBlockLoader) +KQ_NAX_DEFINE_GATHER_RHS(q3_k, 256, 3, KqNaxQ3_KBlockLoader) +KQ_NAX_DEFINE_GATHER_RHS(q2_k, 256, 2, KqNaxQ2_KBlockLoader) + +#undef KQ_NAX_DEFINE_KERNELS +#undef KQ_NAX_DEFINE_GATHER_RHS diff --git a/mlx/backend/metal/kernels/kq_quantized_nax.metal b/mlx/backend/metal/kernels/kq_quantized_nax.metal new file mode 100644 index 0000000000..70cab500d9 --- /dev/null +++ b/mlx/backend/metal/kernels/kq_quantized_nax.metal @@ -0,0 +1,106 @@ +// Copyright © 2026 Apple Inc. + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/quantized_utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/loader.h" +#include "mlx/backend/metal/kernels/kq_quantized_nax.h" + +#define instantiate_kquant_nax_qmm_t( \ + type, gs, bits, aligned_N, batched, bm, bn, wm, wn, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_qmm_t_nax_" #type "_gs_" #gs "_b_" #bits \ + "_bm" #bm "_bn" #bn "_bk64_wm" #wm "_wn" #wn \ + "_alN_" #aligned_N "_batch_" #batched, \ + kq_ ## codec ## _qmm_t_nax, \ + type, \ + gs, \ + bits, \ + aligned_N, \ + batched, \ + bm, \ + bn, \ + wm, \ + wn) + +#define instantiate_kquant_nax_qmm_n(type, gs, bits, batched, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_qmm_n_nax_" #type "_gs_" #gs "_b_" #bits \ + "_bm64_bn64_bk64_wm2_wn2_batch_" #batched, \ + kq_ ## codec ## _qmm_n_nax, \ + type, \ + gs, \ + bits, \ + batched) + +#define instantiate_kquant_nax_gather_qmm_t( \ + type, gs, bits, aligned_N, bm, bn, wm, wn, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_gather_qmm_t_nax_" #type "_gs_" #gs "_b_" #bits \ + "_bm" #bm "_bn" #bn "_bk64_wm" #wm "_wn" #wn "_alN_" #aligned_N, \ + kq_ ## codec ## _gather_qmm_t_nax, \ + type, \ + gs, \ + bits, \ + aligned_N, \ + bm, \ + bn, \ + wm, \ + wn) + +#define instantiate_kquant_nax_gather_qmm_n(type, gs, bits, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_gather_qmm_n_nax_" #type "_gs_" #gs "_b_" #bits \ + "_bm64_bn64_bk64_wm2_wn2", \ + kq_ ## codec ## _gather_qmm_n_nax, \ + type, \ + gs, \ + bits) + +#define instantiate_kquant_nax_gather_qmm_rhs( \ + type, gs, bits, transpose, suffix, bm, bn, wm, wn, codec) \ + instantiate_kernel( \ + "kquant_" #codec "_gather_qmm_rhs_nax_" #suffix "_" #type \ + "_gs_" #gs "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_64_wm_" #wm "_wn_" #wn, \ + kq_ ## codec ## _gather_qmm_rhs_nax, \ + type, \ + gs, \ + bits, \ + bm, \ + bn, \ + 64, \ + wm, \ + wn, \ + transpose) + +#define instantiate_kquant_nax_codec_for_type(codec, type, gs, bits) \ + instantiate_kquant_nax_qmm_t(type, gs, bits, true, 1, 64, 64, 2, 2, codec) \ + instantiate_kquant_nax_qmm_t(type, gs, bits, true, 0, 64, 64, 2, 2, codec) \ + instantiate_kquant_nax_qmm_t(type, gs, bits, false, 1, 64, 64, 2, 2, codec) \ + instantiate_kquant_nax_qmm_t(type, gs, bits, false, 0, 64, 64, 2, 2, codec) \ + instantiate_kquant_nax_qmm_n(type, gs, bits, 1, codec) \ + instantiate_kquant_nax_qmm_n(type, gs, bits, 0, codec) \ + instantiate_kquant_nax_gather_qmm_t(type, gs, bits, true, 64, 64, 2, 2, codec) \ + instantiate_kquant_nax_gather_qmm_t(type, gs, bits, false, 64, 64, 2, 2, codec) \ + instantiate_kquant_nax_gather_qmm_n(type, gs, bits, codec) \ + instantiate_kquant_nax_gather_qmm_rhs(type, gs, bits, true, nt, 64, 64, 2, 2, codec) \ + instantiate_kquant_nax_gather_qmm_rhs(type, gs, bits, false, nn, 64, 64, 2, 2, codec) + +#define instantiate_kquant_nax_codec(codec, gs, bits) \ + instantiate_kquant_nax_codec_for_type(codec, float, gs, bits) \ + instantiate_kquant_nax_codec_for_type(codec, float16_t, gs, bits) \ + instantiate_kquant_nax_codec_for_type(codec, bfloat16_t, gs, bits) + +instantiate_kquant_nax_codec(q8_0, 32, 8) +instantiate_kquant_nax_codec(q5_1, 32, 5) +instantiate_kquant_nax_codec(q4_0, 32, 4) +instantiate_kquant_nax_codec(q4_1, 32, 4) +instantiate_kquant_nax_codec(q5_0, 32, 5) +instantiate_kquant_nax_codec(q4_k, 256, 4) +instantiate_kquant_nax_codec(q5_k, 256, 5) +instantiate_kquant_nax_codec(q6_k, 256, 6) +instantiate_kquant_nax_codec(q3_k, 256, 3) +instantiate_kquant_nax_codec(q2_k, 256, 2) + // clang-format on diff --git a/mlx/backend/metal/kernels/quantized_nax.h b/mlx/backend/metal/kernels/quantized_nax.h index 8814fafa33..8e6cc9e7d6 100644 --- a/mlx/backend/metal/kernels/quantized_nax.h +++ b/mlx/backend/metal/kernels/quantized_nax.h @@ -1056,7 +1056,6 @@ METAL_FUNC void qmm_t_nax_tgp_impl( loader_w.next(); } - // Store results to device memory threadgroup_barrier(mem_flags::mem_threadgroup); if constexpr (kAlignedM.value && kAlignedN.value) { diff --git a/mlx/backend/metal/kernels/quantized_utils.h b/mlx/backend/metal/kernels/quantized_utils.h index 38253f8fe9..542d0055d7 100644 --- a/mlx/backend/metal/kernels/quantized_utils.h +++ b/mlx/backend/metal/kernels/quantized_utils.h @@ -3,6 +3,29 @@ #include #include +// Load `values_per_thread` consecutive T-values from device memory into a +// thread-local U-buffer (e.g., T=half, U=float for fp16 input + fp32 +// accumulation). Hot path -- caller guarantees in-bounds. +template +inline void load_vector(const device T* x, thread U* x_thread) { +#pragma unroll + for (int i = 0; i < values_per_thread; i++) { + x_thread[i] = x[i]; + } +} + +// Boundary-aware variant: load N values then zero-fill the remaining +// values_per_thread - N slots. +template +inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { + for (int i = 0; i < N; i++) { + x_thread[i] = x[i]; + } + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } +} + template METAL_FUNC void gemm_loop_aligned( threadgroup T* As, diff --git a/mlx/backend/metal/nojit_kernels.cpp b/mlx/backend/metal/nojit_kernels.cpp index 2ed74f470a..c966884a74 100644 --- a/mlx/backend/metal/nojit_kernels.cpp +++ b/mlx/backend/metal/nojit_kernels.cpp @@ -297,7 +297,8 @@ MTL::ComputePipelineState* get_quantized_kernel( metal::Device& d, const std::string& kernel_name, const std::string&, - const std::string&) { + const std::string&, + bool) { return d.get_kernel(kernel_name); } @@ -401,6 +402,7 @@ MTL::ComputePipelineState* get_gather_qmm_nax_kernel( int, int, const std::string&, + const std::string&, int, int, int, diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index c8d5a31cb4..9778d69a01 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -16,21 +16,70 @@ namespace mlx::core { namespace { +inline std::string quantized_kname_prefix( + const std::string& mode, + const std::string& kquant_type) { + if (mode == "kquant") { + if (kquant_type.empty()) { + throw std::runtime_error( + "[QuantizedMatmul] kquant mode requires kquant_type " + "(e.g. \"q4_k\"); (group_size, bits) does not uniquely " + "identify a codec."); + } + if (kquant_codec_by_name(kquant_type) == nullptr) { + throw std::runtime_error( + "[QuantizedMatmul] Unknown kquant_type \"" + kquant_type + "\"."); + } + return "kquant_" + kquant_type + "_"; + } + return mode + "_"; +} + +inline std::string quantized_source_prefix( + const std::string& mode, + const std::string& kquant_type) { + if (mode == "kquant") { + return "kq_" + kquant_type + "_"; + } + return mode + "_"; +} + +inline bool kquant_codec_has_matmul(const std::string& kquant_type) { + const auto* codec = kquant_codec_by_name(kquant_type); + return codec != nullptr && codec->has_matmul_kernel; +} + +inline int kquant_qmv_bn(const std::string& kquant_type) { + if (kquant_type == "q2_k" || kquant_type == "q3_k" || kquant_type == "q4_k" || + kquant_type == "q5_k") { + return 4; + } + return 8; +} + +// Minimum K alignment for the qmv_fast kernel path. +// KQuant block sizes are 32 or 256 -- 256 covers both. +// Affine/FP families require 512 (16 groups of 32). +inline int qmv_fast_k_align(const std::string& mode) { + return (mode == "kquant") ? 256 : 512; +} + template auto get_quantized_kernel_wrapped( metal::Device& d, const std::string& name, const std::string& func, const std::string& mode, + const std::string& kquant_type, const std::string& type, int group_size, int bits, Args... args) { - std::string template_def; - std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func; - template_def = get_template_definition( + std::string fname = quantized_source_prefix(mode, kquant_type) + func; + std::string template_def = get_template_definition( name, fname, type, group_size, bits, std::forward(args)...); - return get_quantized_kernel(d, name, template_def, mode); + bool is_encode = (func == "quantize"); + return get_quantized_kernel(d, name, template_def, mode, is_encode); } template @@ -39,13 +88,13 @@ auto get_qmm_nax_kernel_wrapped( const std::string& name, const std::string& func, const std::string& mode, + const std::string& kquant_type, const std::string& type, int group_size, int bits, Args... args) { - std::string template_def; - std::string fname = ((mode == "affine") ? "affine_" : "fp_") + func; - template_def = get_template_definition( + std::string fname = quantized_source_prefix(mode, kquant_type) + func; + std::string template_def = get_template_definition( name, fname, type, group_size, bits, std::forward(args)...); return get_qmm_nax_kernel(d, name, template_def, mode); } @@ -187,7 +236,8 @@ void qmv_quad( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { int B = out.size() / M / N; constexpr int quads_per_simd = 8; @@ -203,7 +253,7 @@ void qmv_quad( concatenate( kname, - mode + "_qmv_quad_", + quantized_kname_prefix(mode, kquant_type) + "qmv_quad_", type_string, "_gs_", group_size, @@ -213,7 +263,16 @@ void qmv_quad( K, B > 1 ? "_batch_1" : "_batch_0"); auto kernel = get_quantized_kernel_wrapped( - d, kname, "qmv_quad", mode, type_string, group_size, bits, K, B > 1); + d, + kname, + "qmv_quad", + mode, + kquant_type, + type_string, + group_size, + bits, + K, + B > 1); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); @@ -245,10 +304,11 @@ void qmv( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { int B = out.size() / M / N; - int bn = 8; + int bn = (mode == "kquant") ? kquant_qmv_bn(kquant_type) : 8; int bk = 32; MTL::Size group_dims(bk, 2, 1); MTL::Size grid_dims(M, (N + bn - 1) / bn, B); @@ -256,11 +316,12 @@ void qmv( std::string kname; kname.reserve(64); std::string type_string = get_type_string(x.dtype()); - bool fast = N % bn == 0 && K % 512 == 0; + int k_align = qmv_fast_k_align(mode); + bool fast = N % bn == 0 && K % k_align == 0; concatenate( kname, - mode + (fast ? "_qmv_fast_" : "_qmv_"), + quantized_kname_prefix(mode, kquant_type) + (fast ? "qmv_fast_" : "qmv_"), type_string, "_gs_", group_size, @@ -272,6 +333,7 @@ void qmv( kname, (fast ? "qmv_fast" : "qmv"), mode, + kquant_type, type_string, group_size, bits, @@ -308,7 +370,8 @@ void qvm_split_k( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { auto& compute_encoder = metal::get_command_encoder(s); int split_k = K > 8192 ? 32 : 8; @@ -356,7 +419,7 @@ void qvm_split_k( temp_shape.insert(temp_shape.begin(), 1); } temp_shape.insert(temp_shape.end() - 2, split_k); - array intermediate(temp_shape, x.dtype(), nullptr, {}); + array intermediate(temp_shape, out.dtype(), nullptr, {}); intermediate.set_data(allocator::malloc(intermediate.nbytes())); compute_encoder.add_temporary(intermediate); @@ -365,7 +428,7 @@ void qvm_split_k( kname.reserve(64); concatenate( kname, - mode + "_qvm_split_k_", + quantized_kname_prefix(mode, kquant_type) + "qvm_split_k_", type_string, "_gs_", group_size, @@ -376,7 +439,15 @@ void qvm_split_k( // Encode and dispatch kernel auto kernel = get_quantized_kernel_wrapped( - d, kname, "qvm_split_k", mode, type_string, group_size, bits, split_k); + d, + kname, + "qvm_split_k", + mode, + kquant_type, + type_string, + group_size, + bits, + split_k); compute_encoder.set_compute_pipeline_state(kernel); @@ -429,7 +500,8 @@ void qvm( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { int B = out.size() / M / N; constexpr int num_simdgroups = 2; @@ -443,7 +515,7 @@ void qvm( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - mode + "_qvm_", + quantized_kname_prefix(mode, kquant_type) + "qvm_", type_string, "_gs_", group_size, @@ -451,7 +523,7 @@ void qvm( bits, B > 1 ? "_batch_1" : "_batch_0"); auto kernel = get_quantized_kernel_wrapped( - d, kname, "qvm", mode, type_string, group_size, bits, B > 1); + d, kname, "qvm", mode, kquant_type, type_string, group_size, bits, B > 1); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); @@ -484,7 +556,8 @@ void qmm_nax( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { int B = out.size() / M / N; int wm = 2; @@ -497,12 +570,13 @@ void qmm_nax( std::string kname; kname.reserve(64); - bool aligned = N % 64 == 0; + bool aligned = N % bn == 0; bool batched = B > 1; std::string type_string = get_type_string(x.dtype()); concatenate( kname, - mode + (transpose ? "_qmm_t_nax_" : "_qmm_n_nax_"), + quantized_kname_prefix(mode, kquant_type) + + (transpose ? "qmm_t_nax_" : "qmm_n_nax_"), type_string, "_gs_", group_size, @@ -528,6 +602,7 @@ void qmm_nax( kname, "qmm_t_nax", mode, + kquant_type, type_string, group_size, bits, @@ -544,6 +619,7 @@ void qmm_nax( kname, "qmm_n_nax", mode, + kquant_type, type_string, group_size, bits, @@ -589,24 +665,26 @@ void gather_qmm_nax( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { int B = out.size() / M / N; int wm = 2; int wn = 2; int bm = 64; int bn = 64; - int bk = 32; + int bk = 64; MTL::Size group_dims(32, wn, wm); MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); std::string kname; kname.reserve(64); - bool aligned = N % 64 == 0; + bool aligned = N % bn == 0; std::string type_string = get_type_string(x.dtype()); concatenate( kname, - mode + (transpose ? "_gather_qmm_t_nax_" : "_gather_qmm_n_nax_"), + quantized_kname_prefix(mode, kquant_type) + + (transpose ? "gather_qmm_t_nax_" : "gather_qmm_n_nax_"), type_string, "_gs_", group_size, @@ -628,8 +706,9 @@ void gather_qmm_nax( kernel = get_qmm_nax_kernel_wrapped( d, kname, - "gather_qmm_t_nax_", + "gather_qmm_t_nax", mode, + kquant_type, type_string, group_size, bits, @@ -643,8 +722,9 @@ void gather_qmm_nax( kernel = get_qmm_nax_kernel_wrapped( d, kname, - "gather_qmm_n_nax_", + "gather_qmm_n_nax", mode, + kquant_type, type_string, group_size, bits, @@ -691,9 +771,11 @@ void qmm( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { if (metal::is_nax_available() && transpose && (K % 64 == 0) && - (env::enable_tf32() || x.dtype() != float32)) { + (env::enable_tf32() || x.dtype() != float32) && + (mode != "kquant" || kquant_codec_has_matmul(kquant_type))) { return qmm_nax( /* const array& x = */ x, /* const array& w = */ w, @@ -708,26 +790,29 @@ void qmm( /* int K = */ K, /* metal::Device& d = */ d, /* const Stream& s = */ s, - /* const std::string& mode = */ mode); + /* const std::string& mode = */ mode, + /* const std::string& kquant_type = */ kquant_type); } int B = out.size() / M / N; + bool kquant = mode == "kquant"; int wm = 2; int wn = 2; - int bm = 32; - int bn = 32; + int bm = kquant ? 64 : 32; + int bn = (kquant && transpose) ? 64 : 32; MTL::Size group_dims(32, wn, wm); MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); std::string kname; kname.reserve(64); - bool aligned = N % 32 == 0; + bool aligned = N % bn == 0; bool batched = B > 1; std::string type_string = get_type_string(x.dtype()); concatenate( kname, - mode + (transpose ? "_qmm_t_" : "_qmm_n_"), + quantized_kname_prefix(mode, kquant_type) + + (transpose ? "qmm_t_" : "qmm_n_"), type_string, "_gs_", group_size, @@ -743,6 +828,7 @@ void qmm( kname, "qmm_t", mode, + kquant_type, type_string, group_size, bits, @@ -750,7 +836,15 @@ void qmm( batched); } else { kernel = get_quantized_kernel_wrapped( - d, kname, "qmm_n", mode, type_string, group_size, bits, batched); + d, + kname, + "qmm_n", + mode, + kquant_type, + type_string, + group_size, + bits, + batched); } auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); @@ -784,7 +878,8 @@ void qmm_splitk( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { // Choose split_k to target ~512 threadgroups int bm = 32, bn = 32; int n_tiles = (N + bn - 1) / bn; @@ -801,7 +896,21 @@ void qmm_splitk( } if (split_k <= 1) { return qmm( - x, w, scales, biases, out, true, group_size, bits, M, N, K, d, s, mode); + x, + w, + scales, + biases, + out, + true, + group_size, + bits, + M, + N, + K, + d, + s, + mode, + kquant_type); } int k_partition_size = K / split_k; @@ -815,7 +924,7 @@ void qmm_splitk( temp_shape.insert(temp_shape.begin(), 1); } temp_shape.insert(temp_shape.begin(), split_k); - array intermediate(temp_shape, x.dtype(), nullptr, {}); + array intermediate(temp_shape, out.dtype(), nullptr, {}); intermediate.set_data(allocator::malloc(intermediate.nbytes())); compute_encoder.add_temporary(intermediate); @@ -829,7 +938,7 @@ void qmm_splitk( kname.reserve(64); concatenate( kname, - mode + "_qmm_t_splitk_", + quantized_kname_prefix(mode, kquant_type) + "qmm_t_splitk_", type_string, "_gs_", group_size, @@ -837,7 +946,15 @@ void qmm_splitk( bits, aligned ? "_alN_true" : "_alN_false"); auto kernel = get_quantized_kernel_wrapped( - d, kname, "qmm_t_splitk", mode, type_string, group_size, bits, aligned); + d, + kname, + "qmm_t_splitk", + mode, + kquant_type, + type_string, + group_size, + bits, + aligned); compute_encoder.set_compute_pipeline_state(kernel); @@ -882,9 +999,11 @@ void gather_qmm( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { if (metal::is_nax_available() && transpose && (K % 64 == 0) && - (env::enable_tf32() || x.dtype() != float32)) { + (env::enable_tf32() || x.dtype() != float32) && + (mode != "kquant" || kquant_codec_has_matmul(kquant_type))) { return gather_qmm_nax( /* const array& x = */ x, /* const array& w = */ w, @@ -901,7 +1020,8 @@ void gather_qmm( /* int K = */ K, /* metal::Device& d = */ d, /* const Stream& s = */ s, - /* const std::string& mode = */ mode); + /* const std::string& mode = */ mode, + /* const std::string& kquant_type = */ kquant_type); } int B = out.size() / M / N; @@ -919,7 +1039,8 @@ void gather_qmm( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - mode + (transpose ? "_gather_qmm_t_" : "_gather_qmm_n_"), + quantized_kname_prefix(mode, kquant_type) + + (transpose ? "gather_qmm_t_" : "gather_qmm_n_"), type_string, "_gs_", group_size, @@ -929,10 +1050,25 @@ void gather_qmm( MTL::ComputePipelineState* kernel; if (transpose) { kernel = get_quantized_kernel_wrapped( - d, kname, "gather_qmm_t", mode, type_string, group_size, bits, aligned); + d, + kname, + "gather_qmm_t", + mode, + kquant_type, + type_string, + group_size, + bits, + aligned); } else { kernel = get_quantized_kernel_wrapped( - d, kname, "gather_qmm_n", mode, type_string, group_size, bits); + d, + kname, + "gather_qmm_n", + mode, + kquant_type, + type_string, + group_size, + bits); } auto& compute_encoder = metal::get_command_encoder(s); @@ -972,10 +1108,11 @@ void gather_qmv( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { int B = out.size() / M / N; - int bn = 8; + int bn = (mode == "kquant") ? kquant_qmv_bn(kquant_type) : 8; int bk = 32; MTL::Size group_dims(bk, 2, 1); MTL::Size grid_dims(M, (N + bn - 1) / bn, B); @@ -983,10 +1120,12 @@ void gather_qmv( std::string kname; kname.reserve(64); std::string type_string = get_type_string(x.dtype()); - bool fast = N % bn == 0 && K % 512 == 0; + int k_align = qmv_fast_k_align(mode); + bool fast = N % bn == 0 && K % k_align == 0; concatenate( kname, - mode + (fast ? "_gather_qmv_fast_" : "_gather_qmv_"), + quantized_kname_prefix(mode, kquant_type) + + (fast ? "gather_qmv_fast_" : "gather_qmv_"), type_string, "_gs_", group_size, @@ -998,6 +1137,7 @@ void gather_qmv( kname, (fast ? "gather_qmv_fast" : "gather_qmv"), mode, + kquant_type, type_string, group_size, bits); @@ -1038,7 +1178,8 @@ void gather_qvm( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { int B = out.size() / M / N; constexpr int num_simdgroups = 2; @@ -1052,14 +1193,14 @@ void gather_qvm( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - mode + "_gather_qvm_", + quantized_kname_prefix(mode, kquant_type) + "gather_qvm_", type_string, "_gs_", group_size, "_b_", bits); auto kernel = get_quantized_kernel_wrapped( - d, kname, "gather_qvm", mode, type_string, group_size, bits); + d, kname, "gather_qvm", mode, kquant_type, type_string, group_size, bits); auto& compute_encoder = metal::get_command_encoder(s); compute_encoder.set_compute_pipeline_state(kernel); @@ -1096,7 +1237,8 @@ void gather_qmm_rhs_nax( int K, metal::Device& d, const Stream& s, - const std::string mode) { + const std::string& mode, + const std::string& kquant_type) { // Start by normalizing the indices array indices = ensure_row_contiguous(indices_, d, s); @@ -1122,9 +1264,11 @@ void gather_qmm_rhs_nax( array w = ensure_row_contiguous(w_, d, s); array scales = ensure_row_contiguous(scales_, d, s); - // TODO: Tune the block sizes - int bm = 64, bn = 64, bk = 64; - int wm = 2, wn = 2; + int bm = 64; + int bn = 64; + int bk = 64; + int wm = 2; + int wn = 2; const bool align_M = (M % bm) == 0; const bool align_N = (N % bn) == 0; @@ -1136,8 +1280,8 @@ void gather_qmm_rhs_nax( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - mode + - (transpose ? "_gather_qmm_rhs_nax_nt_" : "_gather_qmm_rhs_nax_nn_"), + quantized_kname_prefix(mode, kquant_type) + + (transpose ? "gather_qmm_rhs_nax_nt_" : "gather_qmm_rhs_nax_nn_"), type_string, "_gs_", group_size, @@ -1175,6 +1319,8 @@ void gather_qmm_rhs_nax( // Get and set the kernel auto& compute_encoder = metal::get_command_encoder(s); + std::string func_name = + quantized_source_prefix(mode, kquant_type) + "gather_qmm_rhs_nax"; auto kernel = get_gather_qmm_nax_kernel( d, kname, @@ -1184,6 +1330,7 @@ void gather_qmm_rhs_nax( group_size, bits, mode, + func_name, bm, bn, bk, @@ -1227,9 +1374,11 @@ void gather_qmm_rhs( int K, metal::Device& d, const Stream& s, - const std::string mode) { + const std::string& mode, + const std::string& kquant_type) { if (metal::is_nax_available() && transpose && - (env::enable_tf32() || x_.dtype() != float32)) { + (env::enable_tf32() || x_.dtype() != float32) && + (mode != "kquant" || kquant_codec_has_matmul(kquant_type))) { return gather_qmm_rhs_nax( /* const array& x_ = */ x_, /* const array& w_ = */ w_, @@ -1245,7 +1394,8 @@ void gather_qmm_rhs( /* int K = */ K, /* metal::Device& d = */ d, /* const Stream& s = */ s, - /* const std::string mode = */ mode); + /* const std::string mode = */ mode, + /* const std::string& kquant_type = */ kquant_type); } // Start by normalizing the indices @@ -1287,7 +1437,8 @@ void gather_qmm_rhs( std::string type_string = get_type_string(x.dtype()); concatenate( kname, - mode + (transpose ? "_gather_qmm_rhs_nt_" : "_gather_qmm_rhs_nn_"), + quantized_kname_prefix(mode, kquant_type) + + (transpose ? "gather_qmm_rhs_nt_" : "gather_qmm_rhs_nn_"), type_string, "_gs_", group_size, @@ -1375,13 +1526,41 @@ void dispatch_qmv( int K, metal::Device& d, const Stream& s, - const std::string& mode) { + const std::string& mode, + const std::string& kquant_type) { // It is a qmv with a small inner dimension so route to qmv_quad kernel if ((K == 128 || K == 64) && is_power_of_2(bits)) { - qmv_quad(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode); + qmv_quad( + x, + w, + scales, + biases, + out, + group_size, + bits, + M, + N, + K, + d, + s, + mode, + kquant_type); return; } - qmv(x, w, scales, biases, out, group_size, bits, M, N, K, d, s, mode); + qmv(x, + w, + scales, + biases, + out, + group_size, + bits, + M, + N, + K, + d, + s, + mode, + kquant_type); } void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { @@ -1408,13 +1587,53 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4; auto mode = quantization_mode_to_string(mode_); + + if (mode_ == QuantizationMode::KQuant) { + if (!transpose_ && M < vector_limit) { + qmm(x, + w, + scales, + biases, + out, + transpose_, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode, + kquant_type_); + return; + } + if ((K == 64 || K == 128) && M < vector_limit && transpose_) { + throw std::runtime_error( + "[QuantizedMatmul::eval_gpu] KQuant qmv_quad is not implemented " + "for K=64 or K=128."); + } + } + // It is a matrix matrix product. if (M >= vector_limit) { // Use split-K qmm for small M with transposed weights (non-batched only) int B = out.size() / M / N; if (transpose_ && B == 1) { qmm_splitk( - x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode); + x, + w, + scales, + biases, + out, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode, + kquant_type_); return; } qmm(x, @@ -1430,26 +1649,66 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { K, d, s, - mode); + mode, + kquant_type_); return; } // Run of the mill qmv if (transpose_) { dispatch_qmv( - x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode); + x, + w, + scales, + biases, + out, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode, + kquant_type_); return; } // Run of the mill qvm if (K < 1024) { - qvm(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode); + qvm(x, + w, + scales, + biases, + out, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode, + kquant_type_); return; } // Qvm with large dimension so route to a split K kernel for more parallelism qvm_split_k( - x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s, mode); + x, + w, + scales, + biases, + out, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode, + kquant_type_); return; } @@ -1481,7 +1740,13 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { // matmuls and reuse reading x and w. // // TODO: Tune 16 and 4 here a bit better. - if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 4) { + // KQuant rhs-only path requires NAX; mirrors the gate in gather_qmm_rhs(). + bool kquant_rhs_ok = mode_ != QuantizationMode::KQuant || + (metal::is_nax_available() && transpose_ && (K % 64 == 0) && + (env::enable_tf32() || x.dtype() != float32) && + kquant_codec_has_matmul(kquant_type_)); + if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 4 && + kquant_rhs_ok) { gather_qmm_rhs( x, w, @@ -1497,7 +1762,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { K, d, s, - mode); + mode, + kquant_type_); return; } @@ -1519,7 +1785,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { K, d, s, - mode); + mode, + kquant_type_); return; } @@ -1539,7 +1806,31 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { K, d, s, - mode); + mode, + kquant_type_); + return; + } + + // KQuant has no dedicated gather_qvm kernel; route through gather_qmm_n. + if (mode_ == QuantizationMode::KQuant) { + gather_qmm( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + transpose_, + group_size_, + bits_, + M, + N, + K, + d, + s, + mode, + kquant_type_); return; } @@ -1558,7 +1849,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { K, d, s, - mode); + mode, + kquant_type_); } void quantize_dequantize( @@ -1568,7 +1860,8 @@ void quantize_dequantize( int group_size, int bits, metal::Device& d, - const Stream& s) { + const Stream& s, + const std::string& kquant_type) { auto& compute_encoder = metal::get_command_encoder(s); auto w = ensure_row_contiguous(in, d, s); @@ -1578,14 +1871,21 @@ void quantize_dequantize( std::string kname; concatenate( kname, - mode + "_quantize_dequantize_", + quantized_kname_prefix(mode, kquant_type) + "quantize_dequantize_", type_string, "_gs_", group_size, "_b_", bits); auto kernel = get_quantized_kernel_wrapped( - d, kname, "quantize_dequantize", mode, type_string, group_size, bits); + d, + kname, + "quantize_dequantize", + mode, + kquant_type, + type_string, + group_size, + bits); compute_encoder.set_compute_pipeline_state(kernel); @@ -1609,6 +1909,9 @@ void quantize_dequantize( } void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { + if (mode_ == QuantizationMode::KQuant) { + throw std::runtime_error("[QQMatmul::eval_gpu] KQuant not implemented."); + } auto& s = stream(); auto& d = metal::device(s.device); @@ -1624,7 +1927,7 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { auto xhat = donate_x ? x : array(allocator::malloc(x.nbytes()), x.shape(), x.dtype()); - quantize_dequantize(x, xhat, mode, group_size_, bits_, d, s); + quantize_dequantize(x, xhat, mode, group_size_, bits_, d, s, kquant_type_); // Make sure the last two dims of w and s are contiguous array w = ensure_row_contiguous_matrix(inputs[1], d, s); @@ -1647,9 +1950,12 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { K, d, s, - mode); + mode, + kquant_type_); return; } else { + // General QQMatmul (both operands quantized, arbitrary layout) would + // require a dual-dequant kernel. No known inference workload needs this. throw std::runtime_error("[QQMatmul] NYI for the general case"); } } @@ -1665,6 +1971,118 @@ void fast::Quantize::eval_gpu( auto& d = metal::device(s.device); auto& compute_encoder = metal::get_command_encoder(s); + if (dequantize_ && mode_ == QuantizationMode::KQuant) { + auto w = ensure_row_contiguous(w_pre, d, s); + auto scales = ensure_row_contiguous(inputs[1], d, s); + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_output_array(out, 2); + uint32_t num_weights = static_cast(out.size()); + compute_encoder.set_bytes(num_weights, 3); + + auto type_string = get_type_string(out.dtype()); + auto mode = quantization_mode_to_string(mode_); + std::string kname; + concatenate( + kname, + quantized_kname_prefix(mode, kquant_type_), + "dequantize_", + type_string, + "_gs_", + group_size_, + "_b_", + bits_); + auto kernel = get_quantized_kernel_wrapped( + d, + kname, + "dequantize", + mode, + kquant_type_, + type_string, + group_size_, + bits_); + compute_encoder.set_compute_pipeline_state(kernel); + + NS::UInteger tg = kernel->maxTotalThreadsPerThreadgroup(); + if (tg > static_cast(num_weights)) { + tg = num_weights; + } + auto group_dims = MTL::Size(tg, 1, 1); + auto grid_dims = MTL::Size(num_weights, 1, 1); + compute_encoder.dispatch_threads(grid_dims, group_dims); + return; + } + + if (!dequantize_ && mode_ == QuantizationMode::KQuant) { + auto w_contig = ensure_row_contiguous(w_pre, d, s); + auto& scales_placeholder = outputs[1]; + scales_placeholder.set_data(allocator::malloc(scales_placeholder.nbytes())); + + const KQuantCodec* codec = kquant_codec_by_name(kquant_type_); + if (codec == nullptr) { + throw std::runtime_error( + "[fast::Quantize::eval_gpu] Unknown kquant_type: '" + kquant_type_ + + "'."); + } + uint32_t num_blocks = + static_cast(out.size() / codec->bytes_per_block); + + compute_encoder.set_input_array(w_contig, 0); + compute_encoder.set_output_array(out, 1); + compute_encoder.set_bytes(num_blocks, 2); + + uint32_t has_imatrix = 0; + uint32_t K = static_cast(w_pre.shape(-1)); + if (inputs.size() >= 2) { + auto imatrix_contig = ensure_row_contiguous(inputs[1], d, s); + compute_encoder.set_input_array(imatrix_contig, 3); + has_imatrix = 1; + } else { + compute_encoder.set_input_array(w_contig, 3); + } + compute_encoder.set_bytes(has_imatrix, 4); + compute_encoder.set_bytes(K, 5); + + auto type_string = get_type_string(w_pre.dtype()); + auto mode = quantization_mode_to_string(mode_); + std::string kname; + concatenate( + kname, + quantized_kname_prefix(mode, kquant_type_), + "quantize_", + type_string, + "_gs_", + group_size_, + "_b_", + bits_); + auto kernel = get_quantized_kernel_wrapped( + d, + kname, + "quantize", + mode, + kquant_type_, + type_string, + group_size_, + bits_); + compute_encoder.set_compute_pipeline_state(kernel); + + // K-codecs: 256 threads per super-block; flat codecs: 1 thread per block. + if (group_size_ >= 256) { + auto group_dims = MTL::Size(256, 1, 1); + auto grid_dims = MTL::Size(num_blocks, 1, 1); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + } else { + NS::UInteger tg = kernel->maxTotalThreadsPerThreadgroup(); + if (tg > static_cast(num_blocks)) { + tg = num_blocks; + } + auto group_dims = MTL::Size(tg, 1, 1); + auto grid_dims = MTL::Size(num_blocks, 1, 1); + compute_encoder.dispatch_threads(grid_dims, group_dims); + } + return; + } + auto w = ensure_row_contiguous(w_pre, d, s); if (dequantize_) { auto scales = ensure_row_contiguous(inputs[1], d, s); @@ -1691,10 +2109,12 @@ void fast::Quantize::eval_gpu( auto type_string = dequantize_ ? get_type_string(out.dtype()) : get_type_string(w_pre.dtype()); auto mode = quantization_mode_to_string(mode_); + const std::string kquant_type; std::string kname; concatenate( kname, - mode + (dequantize_ ? "_dequantize" : "_quantize"), + quantized_kname_prefix(mode, kquant_type) + + (dequantize_ ? "dequantize" : "quantize"), "_", type_string, "_gs_", @@ -1706,6 +2126,7 @@ void fast::Quantize::eval_gpu( kname, dequantize_ ? "dequantize" : "quantize", mode, + kquant_type, type_string, group_size_, bits_); diff --git a/mlx/fast.cpp b/mlx/fast.cpp index a668fe9abd..c13d82b3e2 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -926,17 +926,43 @@ bool Quantize::is_equivalent(const Primitive& other) const { const Quantize& p_other = static_cast(other); return ( p_other.group_size_ == group_size_ && p_other.bits_ == bits_ && - p_other.mode_ == mode_ && p_other.dequantize_ == dequantize_); + p_other.mode_ == mode_ && p_other.dequantize_ == dequantize_ && + p_other.kquant_type_ == kquant_type_); } std::vector Quantize::output_shapes(const std::vector& inputs) { auto& w = inputs[0]; if (dequantize_) { - auto out_size = w.shape(-1) * 32 / bits_; + size_t out_size = 0; + if (mode_ == QuantizationMode::KQuant) { + const auto* codec = kquant_codec_by_name(kquant_type_); + if (codec == nullptr) { + throw std::invalid_argument( + "[Quantize::output_shapes] Unknown kquant_type: '" + kquant_type_ + + "'."); + } + out_size = + (w.shape(-1) / codec->bytes_per_block) * codec->weights_per_block; + } else { + out_size = w.shape(-1) * 32 / bits_; + } auto out_shape = w.shape(); out_shape.back() = out_size; return {std::move(out_shape)}; } else { + if (mode_ == QuantizationMode::KQuant) { + const auto* codec = kquant_codec_by_name(kquant_type_); + if (codec == nullptr) { + throw std::invalid_argument( + "[Quantize::output_shapes] Unknown kquant_type: '" + kquant_type_ + + "'."); + } + auto wq_shape = w.shape(); + wq_shape.back() = + (w.shape(-1) / codec->weights_per_block) * codec->bytes_per_block; + Shape s_shape = {1}; + return {std::move(wq_shape), std::move(s_shape)}; + } auto wq_shape = w.shape(); wq_shape.back() = w.shape(-1) * bits_ / 32; auto sshape = w.shape(); diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..a4b2643f4a 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -332,12 +332,14 @@ class Quantize : public Custom { int group_size, int bits, QuantizationMode mode, - bool dequantize) + bool dequantize, + std::string kquant_type = "") : Custom(stream, std::move(fallback)), group_size_(group_size), bits_(bits), mode_(mode), - dequantize_(dequantize) {} + dequantize_(dequantize), + kquant_type_(std::move(kquant_type)) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; @@ -350,7 +352,8 @@ class Quantize : public Custom { bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { - return std::make_tuple(nullptr, group_size_, bits_, mode_, dequantize_); + return std::make_tuple( + nullptr, group_size_, bits_, mode_, dequantize_, kquant_type_); } private: @@ -358,6 +361,7 @@ class Quantize : public Custom { int bits_; QuantizationMode mode_; bool dequantize_; + std::string kquant_type_; }; using ScalarArg = std::variant; diff --git a/mlx/io/CMakeLists.txt b/mlx/io/CMakeLists.txt index 1fedcd46b4..2d0e13f77c 100644 --- a/mlx/io/CMakeLists.txt +++ b/mlx/io/CMakeLists.txt @@ -11,7 +11,7 @@ if(MLX_BUILD_GGUF) FetchContent_Declare( gguflib GIT_REPOSITORY https://github.com/antirez/gguf-tools/ - GIT_TAG 8fa6eb65236618e28fd7710a0fba565f7faa1848) + GIT_TAG fdfafbed766db0a1e9019b07994cd88f133d1aab) FetchContent_MakeAvailable(gguflib) target_include_directories(mlx PRIVATE $) diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp index 206f6fb31f..062aa3077d 100644 --- a/mlx/io/gguf.cpp +++ b/mlx/io/gguf.cpp @@ -211,7 +211,10 @@ std::unordered_map load_metadata(gguf_ctx* ctx) { return metadata; } -std::unordered_map load_arrays(gguf_ctx* ctx) { +// May insert kKQuantTypesKey into metadata for kquant tensors. +std::unordered_map load_arrays( + gguf_ctx* ctx, + std::unordered_map& metadata) { std::unordered_map array_map; gguf_tensor tensor; @@ -219,15 +222,16 @@ std::unordered_map load_arrays(gguf_ctx* ctx) { if (!inserted.second) { std::ostringstream msg; msg << "[load_gguf] Duplicate parameter name " << inserted.first->second - << " this can happend when loading quantized tensors."; + << " this can happen when loading quantized tensors."; throw std::runtime_error(msg.str()); } }; + std::vector kquant_entries; + while (gguf_get_tensor(ctx, &tensor)) { - if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 || - tensor.type == GGUF_TYPE_Q8_0) { - gguf_load_quantized(array_map, tensor); + if (const auto* kqc = gguf_type_to_kquant_codec(tensor.type)) { + gguf_load_kquant(array_map, tensor, *kqc, kquant_entries); } else { std::string name(tensor.name, tensor.namelen); const auto& [data, dtype] = extract_tensor_data(&tensor); @@ -235,6 +239,10 @@ std::unordered_map load_arrays(gguf_ctx* ctx) { check_insert(array_map.insert({name, loaded_array})); } } + + if (!kquant_entries.empty()) { + metadata[kKQuantTypesKey] = std::move(kquant_entries); + } return array_map; } @@ -254,7 +262,7 @@ GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) { throw std::runtime_error("[load_gguf] gguf_init failed"); } auto metadata = load_metadata(ctx.get()); - auto arrays = load_arrays(ctx.get()); + auto arrays = load_arrays(ctx.get(), metadata); return {arrays, metadata}; } @@ -312,8 +320,11 @@ void save_gguf( memcpy(val->string, src.c_str(), src.length()); }; - // Save any meta data + // Save any meta data (skip synthetic keys injected by the loader) for (auto& [key, value] : metadata) { + if (key == kKQuantTypesKey) { + continue; + } if (auto pv = std::get_if(&value); pv) { const std::string& str = *pv; size_t size = sizeof(gguf_string) + str.length(); diff --git a/mlx/io/gguf.h b/mlx/io/gguf.h index fa5bc458de..91495d39b4 100644 --- a/mlx/io/gguf.h +++ b/mlx/io/gguf.h @@ -12,9 +12,16 @@ extern "C" { namespace mlx::core { +constexpr const char* kKQuantTypesKey = "__kquant_types__"; + Shape get_shape(const gguf_tensor& tensor); -void gguf_load_quantized( + +const KQuantCodec* gguf_type_to_kquant_codec(uint32_t gguf_type); + +void gguf_load_kquant( std::unordered_map& a, - const gguf_tensor& tensor); + const gguf_tensor& tensor, + const KQuantCodec& codec, + std::vector& kquant_entries); } // namespace mlx::core diff --git a/mlx/io/gguf_quants.cpp b/mlx/io/gguf_quants.cpp index 148ed6c479..0b3b06b642 100644 --- a/mlx/io/gguf_quants.cpp +++ b/mlx/io/gguf_quants.cpp @@ -8,157 +8,105 @@ namespace mlx::core { -void unpack_32_4(uint8_t* data, int8_t* dst) { - std::fill_n(dst, 16, 0); - for (int j = 0; j < 16; ++j) { - uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes. - if (j % 2 != 0) { - x <<= 4; - } - dst[j / 2] += x; - } - // Last 16 weights are in the higher bits - for (int j = 0; j < 16; ++j) { - uint8_t x = (data[j + 2] >> 4); - if (j % 2 != 0) { - x <<= 4; - } - dst[8 + j / 2] += x; - } -} - -// Extracts (weight, scales, biases) from Q4_0 tensors. -// Data layout is: |16 bit scale|32 x 4bit weights|. -void extract_q4_0_data( - const gguf_tensor& tensor, - array& weights_arr, - array& scales_arr, - array& biases_arr) { - const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights - auto data = static_cast(tensor.weights_data); - auto weights = weights_arr.data(); - auto scales = scales_arr.data(); - auto biases = biases_arr.data(); - for (int64_t i = 0; i < scales_arr.size(); i++) { - scales[i] = *((float16_t*)data); - biases[i] = -8 * scales[i]; - unpack_32_4(data, weights); - weights += 16; - data += bytes_per_block; +const KQuantCodec* gguf_type_to_kquant_codec(uint32_t gguf_type) { + switch (gguf_type) { + case GGUF_TYPE_Q4_0: + return kquant_codec_by_name("q4_0"); + case GGUF_TYPE_Q4_1: + return kquant_codec_by_name("q4_1"); + case GGUF_TYPE_Q5_0: + return kquant_codec_by_name("q5_0"); + case GGUF_TYPE_Q8_0: + return kquant_codec_by_name("q8_0"); + case GGUF_TYPE_Q5_1: + return kquant_codec_by_name("q5_1"); + case GGUF_TYPE_Q2_K: + return kquant_codec_by_name("q2_k"); + case GGUF_TYPE_Q3_K: + return kquant_codec_by_name("q3_k"); + case GGUF_TYPE_Q4_K: + return kquant_codec_by_name("q4_k"); + case GGUF_TYPE_Q5_K: + return kquant_codec_by_name("q5_k"); + case GGUF_TYPE_Q6_K: + return kquant_codec_by_name("q6_k"); + default: + return nullptr; } } -// Extracts (weight, scales, biases) from Q4_1 tensors. -// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|. -void extract_q4_1_data( +void gguf_load_kquant( + std::unordered_map& a, const gguf_tensor& tensor, - array& weights_arr, - array& scales_arr, - array& biases_arr) { - const uint64_t bytes_per_block = - 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights - auto data = static_cast(tensor.weights_data); - auto weights = weights_arr.data(); - auto scales = scales_arr.data(); - auto biases = biases_arr.data(); - for (int64_t i = 0; i < scales_arr.size(); i++) { - scales[i] = *((float16_t*)data); - biases[i] = *((float16_t*)(data) + 1); - unpack_32_4(data, weights); - weights += 16; - data += bytes_per_block; - } -} + const KQuantCodec& codec, + std::vector& kquant_entries) { + std::string name(tensor.name, tensor.namelen); -// Extracts (weight, scales, biases) from Q8_0 tensors. -// Data layout is: |16 bit scale|32 x 8bit weights|. -void extract_q8_0_data( - const gguf_tensor& tensor, - array& weights_arr, - array& scales_arr, - array& biases_arr) { - const uint64_t weights_per_block = 32; - const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights - auto data = static_cast(tensor.weights_data); - auto weights = weights_arr.data(); - auto scales = scales_arr.data(); - auto biases = biases_arr.data(); - for (int64_t i = 0; i < scales_arr.size(); i++) { - uint8_t* block_data = data + i * bytes_per_block; - scales[i] = *((float16_t*)block_data); - biases[i] = -128 * scales[i]; - for (int64_t j = 0; j < weights_per_block; ++j) { - uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. - // Original data is in int8_t, so we add a bias of -128 and invert the - // first bit. - x ^= 1 << 7; - weights[i * weights_per_block + j] = x; - } + auto logical_shape = get_shape(tensor); + if (logical_shape.empty()) { + std::ostringstream msg; + msg << "[load_gguf] kquant tensor " << name << " has no dimensions"; + throw std::runtime_error(msg.str()); } -} - -void gguf_load_quantized( - std::unordered_map& a, - const gguf_tensor& tensor) { - uint64_t weights_per_byte; - if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1) { - weights_per_byte = 2; - } else { // tensor.type == GGUF_TYPE_Q8_0 - weights_per_byte = 1; + auto last_dim = logical_shape.back(); + if (last_dim % codec.weights_per_block != 0) { + std::ostringstream msg; + msg << "[load_gguf] kquant tensor " << name << " last dim " << last_dim + << " is not divisible by weights_per_block " << codec.weights_per_block + << " for codec " << codec.name; + throw std::runtime_error(msg.str()); } - - std::string name(tensor.name, tensor.namelen); - - auto shape = get_shape(tensor); - const uint64_t weights_per_block = 32; - if (shape[shape.size() - 1] % weights_per_block != 0) { + auto bytes_per_row = + (last_dim / codec.weights_per_block) * codec.bytes_per_block; + auto packed_shape = logical_shape; + packed_shape.back() = bytes_per_row; + + size_t total_bytes = std::accumulate( + packed_shape.begin(), + packed_shape.end(), + static_cast(1), + std::multiplies()); + if (total_bytes != tensor.bsize) { std::ostringstream msg; - msg << "[load_gguf] tensor " << name - << "has incompatible last dim shape: " << shape[shape.size() - 1]; + msg << "[load_gguf] kquant tensor " << name << " (" << codec.name + << ") computed byte size " << total_bytes + << " does not match tensor.bsize " << tensor.bsize; throw std::runtime_error(msg.str()); } - auto weights_shape = shape; - weights_shape.back() /= (weights_per_byte * 4); - auto w_nbytes = uint32.size() * - std::accumulate(weights_shape.begin(), - weights_shape.end(), - 1, - std::multiplies()); - - array weights(allocator::malloc(w_nbytes), std::move(weights_shape), uint32); - - // For scales and bias - shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block; - auto sb_nbytes = float16.size() * - std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + auto buf = allocator::malloc(total_bytes); + std::memcpy(buf.raw_ptr(), tensor.weights_data, total_bytes); + array weight(buf, std::move(packed_shape), uint8); - array scales(allocator::malloc(sb_nbytes), shape, float16); - array biases(allocator::malloc(sb_nbytes), std::move(shape), float16); - if (tensor.type == GGUF_TYPE_Q4_0) { - extract_q4_0_data(tensor, weights, scales, biases); - } else if (tensor.type == GGUF_TYPE_Q4_1) { - extract_q4_1_data(tensor, weights, scales, biases); - } else if (tensor.type == GGUF_TYPE_Q8_0) { - extract_q8_0_data(tensor, weights, scales, biases); + constexpr std::string_view weight_suffix = ".weight"; + std::string name_prefix; + if (name.size() > weight_suffix.size() && + name.compare( + name.size() - weight_suffix.size(), + weight_suffix.size(), + weight_suffix) == 0) { + name_prefix = name.substr(0, name.size() - weight_suffix.size()); + } else { + name_prefix = name; } - a.emplace(name, std::move(weights)); + auto sb = allocator::malloc(uint8.size()); + *static_cast(sb.raw_ptr()) = 0; + array scales_ph(sb, Shape{1}, uint8); auto check_insert = [](const auto& inserted) { if (!inserted.second) { std::ostringstream msg; msg << "[load_gguf] Duplicate parameter name " << inserted.first->second - << " this can happend when loading quantized tensors."; + << " this can happen when loading quantized tensors."; throw std::runtime_error(msg.str()); } }; - constexpr std::string_view weight_suffix = ".weight"; - const std::string name_prefix = - name.substr(0, name.length() - weight_suffix.length()); - check_insert(a.emplace(name_prefix + ".scales", std::move(scales))); - check_insert(a.emplace(name_prefix + ".biases", std::move(biases))); + check_insert(a.emplace(name, std::move(weight))); + check_insert(a.emplace(name_prefix + ".scales", std::move(scales_ph))); + + kquant_entries.push_back(name + ":" + std::string(codec.name)); } } // namespace mlx::core diff --git a/mlx/ops.cpp b/mlx/ops.cpp index defcc2f6e0..3d96e8ac0d 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -122,7 +122,65 @@ std::pair extract_quantized_matmul_dims( const std::optional& biases, bool transpose, int group_size, - int bits) { + int bits, + QuantizationMode mode = QuantizationMode::Affine, + const std::string& kquant_type = "") { + if (mode == QuantizationMode::KQuant) { + if (w.dtype() != uint8) { + std::ostringstream msg; + msg << "[" << tag << "] KQuant weight tensor must be uint8 " + << "but received " << w.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + if (biases) { + std::ostringstream msg; + msg << "[" << tag << "] Biases must be null for kquant mode."; + throw std::invalid_argument(msg.str()); + } + if (kquant_type.empty()) { + std::ostringstream msg; + msg << "[" << tag + << "] kquant mode requires kquant_type (e.g. \"q4_k\")."; + throw std::invalid_argument(msg.str()); + } + const KQuantCodec* codec = kquant_codec_by_name(kquant_type); + if (codec == nullptr) { + std::ostringstream msg; + msg << "[" << tag << "] Unknown kquant_type \"" << kquant_type << "\"."; + throw std::invalid_argument(msg.str()); + } + int w_bytes_per_row = w.shape(-1); + // Each row of `w` must be a whole number of wire-format blocks. A + // weaker check like `(bytes * wpb) % bpb == 0` is wrong: e.g. for + // Q8_0 (wpb=32, bpb=34), 17 bytes pass it but represent half a block. + if (w_bytes_per_row % codec->bytes_per_block != 0) { + std::ostringstream msg; + msg << "[" << tag << "] KQuant weight last dim " << w_bytes_per_row + << " bytes is not a whole number of " << codec->bytes_per_block + << "-byte " << codec->name << " blocks " + << "(group_size=" << group_size << ", bits=" << bits << ")."; + throw std::invalid_argument(msg.str()); + } + int n_blocks_per_row = w_bytes_per_row / codec->bytes_per_block; + int weights_per_row = n_blocks_per_row * codec->weights_per_block; + int w_inner_dims = transpose ? weights_per_row : w.shape(-2); + int w_outer_dims = transpose ? w.shape(-2) : weights_per_row; + int x_inner_dims = x.shape(-1); + if (w_inner_dims != x_inner_dims) { + std::ostringstream msg; + msg << "[" << tag << "] Last dimension of first input with " + << "shape (..., " << x_inner_dims << ") does not match " + << "the expanded quantized matrix (" << w_inner_dims << ", " + << w_outer_dims << ") computed from shape " << w.shape() + << " with kquant codec " << codec->name << " (gs=" << group_size + << ", bits=" << bits << ", " << codec->bytes_per_block << " bytes/" + << codec->weights_per_block + << " weights) and transpose=" << std::boolalpha << transpose; + throw std::invalid_argument(msg.str()); + } + return {w_inner_dims, w_outer_dims}; + } + validate_quantized_input(tag, w, scales, group_size, bits, biases); int x_inner_dims = x.shape(-1); @@ -4368,26 +4426,42 @@ array conv_general( std::pair quantization_params_from_mode( QuantizationMode mode, std::optional group_size_, - std::optional bits_) { + std::optional bits_, + const std::string& kquant_type = "") { int default_group_size; int default_bits; - switch (mode) { - case QuantizationMode::Affine: - default_group_size = 64; - default_bits = 4; - break; - case QuantizationMode::Nvfp4: - default_group_size = 16; - default_bits = 4; - break; - case QuantizationMode::Mxfp4: - default_group_size = 32; - default_bits = 4; - break; - case QuantizationMode::Mxfp8: + if (mode == QuantizationMode::KQuant && !kquant_type.empty()) { + const auto* codec = kquant_codec_by_name(kquant_type); + if (codec) { + default_group_size = codec->weights_per_block; + default_bits = codec->bits; + } else { default_group_size = 32; default_bits = 8; - break; + } + } else { + switch (mode) { + case QuantizationMode::Affine: + default_group_size = 64; + default_bits = 4; + break; + case QuantizationMode::Nvfp4: + default_group_size = 16; + default_bits = 4; + break; + case QuantizationMode::Mxfp4: + default_group_size = 32; + default_bits = 4; + break; + case QuantizationMode::Mxfp8: + default_group_size = 32; + default_bits = 8; + break; + case QuantizationMode::KQuant: + default_group_size = 32; + default_bits = 8; + break; + } } return { group_size_.has_value() ? *group_size_ : default_group_size, @@ -4427,6 +4501,18 @@ std::pair validate_mode_with_type( } else { return {dtype, qmode}; } + } else if (qmode == QuantizationMode::KQuant) { + if (biases) { + std::ostringstream msg; + msg << "[" << tag << "] Biases must be null for quantization mode '" + << mode << "'."; + throw std::invalid_argument(msg.str()); + } + if (out_type.has_value()) { + return {*out_type, qmode}; + } else { + return {float16, qmode}; + } } else if (scales.dtype() != uint8) { std::ostringstream msg; msg << "[" << tag << "] Scale type must be uint8 but received type " @@ -4483,18 +4569,30 @@ array quantized_matmul( std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "affine" */, + const std::string& kquant_type /* = "" */, StreamOrDevice s /* = {} */) { auto [dtype, qmode] = validate_mode_with_type( "quantized_matmul", scales, biases, std::nullopt, mode); auto [group_size, bits] = - quantization_params_from_mode(qmode, group_size_, bits_); + quantization_params_from_mode(qmode, group_size_, bits_, kquant_type); // Check and extract the quantized matrix shape against x auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( - "quantized_matmul", x, w, scales, biases, transpose, group_size, bits); + "quantized_matmul", + x, + w, + scales, + biases, + transpose, + group_size, + bits, + qmode, + kquant_type); if (qmode == QuantizationMode::Affine) { dtype = promote_types(x.dtype(), dtype); + } else if (qmode == QuantizationMode::KQuant) { + dtype = x.dtype() == float32 ? bfloat16 : x.dtype(); } else { dtype = x.dtype(); } @@ -4510,7 +4608,7 @@ array quantized_matmul( inputs = { astype(x, dtype), w, astype(scales, dtype), astype(*biases, dtype)}; } else { - inputs = {x, w, scales}; + inputs = {astype(x, dtype, s), w, scales}; } if (x.ndim() > 2 && w.ndim() > 2) { @@ -4522,7 +4620,7 @@ array quantized_matmul( std::move(out_shape), dtype, std::make_shared( - to_stream(s), group_size, bits, qmode, transpose), + to_stream(s), group_size, bits, qmode, transpose, kquant_type), std::move(inputs)); } @@ -4905,16 +5003,87 @@ std::vector fp_quantize( return fallback(inputs); } +std::vector kquant_quantize( + const array& w, + int group_size, + int bits, + const std::string& kquant_type, + const std::optional& imatrix, + Stream s) { + if (kquant_type.empty()) { + throw std::invalid_argument( + "[kquant_quantize] kquant_type is required for mode='kquant'."); + } + const auto* codec = kquant_codec_by_name(kquant_type); + if (codec == nullptr) { + throw std::invalid_argument( + "[kquant_quantize] Unknown kquant_type: '" + kquant_type + "'."); + } + if (!codec->has_encode) { + throw std::invalid_argument( + "[kquant_quantize] Quantize (encode) is not supported for codec '" + + kquant_type + "'."); + } + if (w.shape(-1) % codec->weights_per_block != 0) { + std::ostringstream msg; + msg << "[kquant_quantize] Last dim (" << w.shape(-1) + << ") must be a multiple of weights_per_block (" + << codec->weights_per_block << ") for codec '" << kquant_type << "'."; + throw std::invalid_argument(msg.str()); + } + if (imatrix.has_value()) { + const auto& im = *imatrix; + if (im.dtype() != float32) { + throw std::invalid_argument("[kquant_quantize] imatrix must be float32."); + } + if (im.ndim() != 1 || im.shape(-1) != w.shape(-1)) { + std::ostringstream msg; + msg << "[kquant_quantize] imatrix shape must be [K]=(" << w.shape(-1) + << ",) but got " << im.shape() << "."; + throw std::invalid_argument(msg.str()); + } + } + + auto wq_shape = w.shape(); + wq_shape.back() = + (w.shape(-1) / codec->weights_per_block) * codec->bytes_per_block; + Shape s_shape = {1}; + + auto fallback = [](const std::vector&) -> std::vector { + throw std::runtime_error("[kquant_quantize] Gradients are not supported."); + }; + + std::vector inputs = {w}; + if (imatrix.has_value()) { + inputs.push_back(*imatrix); + } + + return array::make_arrays( + {std::move(wq_shape), std::move(s_shape)}, + {uint8, uint8}, + std::make_shared( + s, + fallback, + group_size, + bits, + QuantizationMode::KQuant, + /* dequantize= */ false, + kquant_type), + inputs); +} + std::vector quantize( const array& w, std::optional group_size_ /* = std::nullopt */, std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "affine" */, const std::optional& global_scale /* = std::nullopt */, + const std::string& kquant_type /* = "" */, + const std::optional& imatrix /* = std::nullopt */, StreamOrDevice s /* = {} */) { auto qmode = string_to_quantization_mode(mode, "quantize"); auto [group_size, bits] = - quantization_params_from_mode(qmode, group_size_, bits_); + quantization_params_from_mode(qmode, group_size_, bits_, kquant_type); if (!issubdtype(w.dtype(), floating)) { std::ostringstream msg; msg << "[quantize] Only real floating types can be quantized " @@ -4946,6 +5115,9 @@ std::vector quantize( validate_global_scale("quantize", qmode, global_scale); if (qmode == QuantizationMode::Affine) { return affine_quantize(w, group_size, bits, s); + } else if (qmode == QuantizationMode::KQuant) { + return kquant_quantize( + w, group_size, bits, kquant_type, imatrix, to_stream(s)); } else { return fp_quantize(w, group_size, bits, qmode, global_scale, to_stream(s)); } @@ -5168,6 +5340,58 @@ array fp_dequantize( return fallback(inputs)[0]; } +array kquant_dequantize( + const array& w, + const array& scales, + int group_size, + int bits, + Dtype out_type, + const std::string& kquant_type, + Stream s) { + if (w.dtype() != uint8) { + throw std::invalid_argument( + "[kquant_dequantize] KQuant weights must be uint8."); + } + if (w.ndim() < 1) { + throw std::invalid_argument( + "[kquant_dequantize] w must have at least 1 dimension."); + } + const auto* codec = kquant_codec_by_name(kquant_type); + if (codec == nullptr) { + throw std::invalid_argument( + "[kquant_dequantize] Unknown kquant_type: '" + kquant_type + "'."); + } + if (w.shape(-1) % codec->bytes_per_block != 0) { + std::ostringstream msg; + msg << "[kquant_dequantize] Last dim (" << w.shape(-1) + << ") must be a multiple of bytes_per_block (" << codec->bytes_per_block + << ") for codec '" << kquant_type << "'."; + throw std::invalid_argument(msg.str()); + } + + auto out_shape = w.shape(); + out_shape.back() = + (w.shape(-1) / codec->bytes_per_block) * codec->weights_per_block; + + auto fallback = [](const std::vector&) -> std::vector { + throw std::runtime_error( + "[kquant_dequantize] Gradients are not supported."); + }; + + return array( + std::move(out_shape), + out_type, + std::make_shared( + s, + fallback, + group_size, + bits, + QuantizationMode::KQuant, + /* dequantize= */ true, + kquant_type), + {w, scales}); +} + array dequantize( const array& w, const array& scales, @@ -5177,11 +5401,16 @@ array dequantize( const std::string& mode /* = "affine" */, const std::optional& global_scale /* = std::nullopt */, std::optional dtype /* = std::nullopt */, + const std::string& kquant_type /* = "" */, StreamOrDevice s /* = {} */) { auto [out_type, qmode] = validate_mode_with_type("dequantize", scales, biases, dtype, mode); auto [group_size, bits] = - quantization_params_from_mode(qmode, group_size_, bits_); + quantization_params_from_mode(qmode, group_size_, bits_, kquant_type); + if (qmode == QuantizationMode::KQuant) { + return kquant_dequantize( + w, scales, group_size, bits, out_type, kquant_type, to_stream(s)); + } if (bits <= 0) { std::ostringstream msg; msg << "[dequantize] Invalid value for bits: " << bits; @@ -5275,20 +5504,41 @@ array gather_qmm( std::optional bits_ /* = std::nullopt */, const std::string& mode /* = "affine" */, bool sorted_indices /* = false */, + const std::string& kquant_type /* = "" */, StreamOrDevice s /* = {} */) { if (!lhs_indices_ && !rhs_indices_) { return quantized_matmul( - x, w, scales, biases, transpose, group_size_, bits_, mode, s); + x, + w, + scales, + biases, + transpose, + group_size_, + bits_, + mode, + kquant_type, + s); } auto [out_type, qmode] = validate_mode_with_type("gather_qmm", scales, biases, std::nullopt, mode); auto [group_size, bits] = - quantization_params_from_mode(qmode, group_size_, bits_); + quantization_params_from_mode(qmode, group_size_, bits_, kquant_type); auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( - "gather_qmm", x, w, scales, biases, transpose, group_size, bits); + "gather_qmm", + x, + w, + scales, + biases, + transpose, + group_size, + bits, + qmode, + kquant_type); if (qmode == QuantizationMode::Affine) { out_type = promote_types(x.dtype(), out_type); + } else if (qmode == QuantizationMode::KQuant) { + out_type = x.dtype() == float32 ? bfloat16 : x.dtype(); } else { out_type = x.dtype(); } @@ -5338,6 +5588,13 @@ array gather_qmm( astype(*biases, out_type, s), std::move(lhs_indices), std::move(rhs_indices)}; + } else if (qmode == QuantizationMode::KQuant) { + inputs = { + astype(x, out_type, s), + std::move(w), + std::move(scales), + std::move(lhs_indices), + std::move(rhs_indices)}; } else { inputs = { astype(x, out_type, s), @@ -5356,7 +5613,8 @@ array gather_qmm( qmode, transpose, sorted_indices && !rhs_indices_, - sorted_indices && !lhs_indices_), + sorted_indices && !lhs_indices_, + kquant_type), std::move(inputs)); } diff --git a/mlx/ops.h b/mlx/ops.h index 208964d1aa..cc5abbae6f 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1479,6 +1479,7 @@ MLX_API array quantized_matmul( std::optional group_size = std::nullopt, std::optional bits = std::nullopt, const std::string& mode = "affine", + const std::string& kquant_type = "", StreamOrDevice s = {}); /** Quantize a matrix along its last axis */ @@ -1488,6 +1489,8 @@ MLX_API std::vector quantize( std::optional bits = std::nullopt, const std::string& mode = "affine", const std::optional& global_scale = std::nullopt, + const std::string& kquant_type = "", + const std::optional& imatrix = std::nullopt, StreamOrDevice s = {}); /** Dequantize a matrix produced by quantize() */ @@ -1500,6 +1503,7 @@ MLX_API array dequantize( const std::string& mode = "affine", const std::optional& global_scale = std::nullopt, std::optional dtype = std::nullopt, + const std::string& kquant_type = "", StreamOrDevice s = {}); MLX_API array qqmm( @@ -1533,6 +1537,7 @@ MLX_API array gather_qmm( std::optional bits = std::nullopt, const std::string& mode = "affine", bool sorted_indices = false, + const std::string& kquant_type = "", StreamOrDevice s = {}); /** Returns a contraction of a and b over multiple dimensions. */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index f3acec574b..6e2333f1d7 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "mlx/backend/common/utils.h" #include "mlx/fft.h" @@ -3424,6 +3425,8 @@ std::string quantization_mode_to_string(QuantizationMode mode) { return "mxfp4"; case QuantizationMode::Mxfp8: return "mxfp8"; + case QuantizationMode::KQuant: + return "kquant"; case QuantizationMode::Nvfp4: default: return "nvfp4"; @@ -3441,6 +3444,8 @@ QuantizationMode string_to_quantization_mode( return QuantizationMode::Mxfp8; } else if (mode == "nvfp4") { return QuantizationMode::Nvfp4; + } else if (mode == "kquant") { + return QuantizationMode::KQuant; } std::string msg; if (!tag.empty()) { @@ -3450,6 +3455,23 @@ QuantizationMode string_to_quantization_mode( throw std::invalid_argument(msg); } +const KQuantCodec* kquant_codec_by_name(const std::string& name) { + static const std::unordered_map codecs = { + {"q4_0", {"q4_0", 32, 18, 4, true, true}}, + {"q4_1", {"q4_1", 32, 20, 4, true, true}}, + {"q5_0", {"q5_0", 32, 22, 5, true, true}}, + {"q5_1", {"q5_1", 32, 24, 5, true, true}}, + {"q8_0", {"q8_0", 32, 34, 8, true, true}}, + {"q2_k", {"q2_k", 256, 84, 2, true, true}}, + {"q3_k", {"q3_k", 256, 110, 3, true, true}}, + {"q4_k", {"q4_k", 256, 144, 4, true, true}}, + {"q5_k", {"q5_k", 256, 176, 5, true, true}}, + {"q6_k", {"q6_k", 256, 210, 6, true, true}}, + }; + auto it = codecs.find(name); + return it != codecs.end() ? &it->second : nullptr; +} + std::pair, std::vector> QuantizedMatmul::vmap( const std::vector& inputs, const std::vector& axes) { @@ -3478,6 +3500,7 @@ std::vector QuantizedMatmul::vjp( group_size_, bits_, quantization_mode_to_string(mode_), + kquant_type_, stream())); } @@ -3513,8 +3536,9 @@ std::vector QuantizedMatmul::vjp( group_size_, bits_, quantization_mode_to_string(mode_), - {}, // placeholder for amax std::nullopt, + std::nullopt, + kquant_type_, stream()); wq = unflatten(wq, -1, {-1, group_size_}, stream()); vjps.push_back(sum(multiply(*dsb, wq, stream()), -1, false, stream())); @@ -3542,19 +3566,37 @@ std::vector QuantizedMatmul::jvp( group_size_, bits_, quantization_mode_to_string(mode_), + kquant_type_, stream())}; } bool QuantizedMatmul::is_equivalent(const Primitive& other) const { const QuantizedMatmul& qm_other = static_cast(other); return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && - mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_; + mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_ && + kquant_type_ == qm_other.kquant_type_; } std::vector QuantizedMatmul::output_shapes( const std::vector& inputs) { auto& w = inputs[1]; - int w_outer_dims = (transpose_) ? w.shape(-2) : w.shape(-1) * 32 / bits_; + int w_outer_dims; + if (mode_ == QuantizationMode::KQuant) { + // For kquant the weight tensor is uint8 packed wire-format bytes; the + // expansion factor is weights_per_block / bytes_per_block, not the + // bits-per-uint32 formula used by affine/fp. + const KQuantCodec* codec = kquant_codec_by_name(kquant_type_); + if (codec == nullptr) { + throw std::invalid_argument( + "[QuantizedMatmul::output_shapes] Unknown kquant_type \"" + + kquant_type_ + "\"."); + } + int weights_per_row = + w.shape(-1) / codec->bytes_per_block * codec->weights_per_block; + w_outer_dims = transpose_ ? w.shape(-2) : weights_per_row; + } else { + w_outer_dims = transpose_ ? w.shape(-2) : w.shape(-1) * 32 / bits_; + } auto out_shape = inputs[0].shape(); out_shape.back() = w_outer_dims; return {std::move(out_shape)}; @@ -3563,7 +3605,7 @@ std::vector QuantizedMatmul::output_shapes( bool QQMatmul::is_equivalent(const Primitive& other) const { const QQMatmul& qm_other = static_cast(other); return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && - mode_ == qm_other.mode_; + mode_ == qm_other.mode_ && kquant_type_ == qm_other.kquant_type_; } std::vector QQMatmul::output_shapes(const std::vector& inputs) { @@ -3682,6 +3724,7 @@ std::vector GatherQMM::vjp( bits_, quantization_mode_to_string(mode_), sorted, + kquant_type_, stream()); if (sorted && no_broadcast) { vjps.push_back(g); @@ -3699,7 +3742,9 @@ std::vector GatherQMM::vjp( } // gradient wrt to the indices is undefined - else if (arg > 3) { + // Affine inputs: [x, w, scales, biases, lhs_idx, rhs_idx] + // KQuant/FP inputs: [x, w, scales, lhs_idx, rhs_idx] + else if (arg >= (mode_ == QuantizationMode::Affine ? 4 : 3)) { throw std::runtime_error( "[GatherQMM::vjp] cannot compute the gradient wrt the indices."); } @@ -3748,7 +3793,8 @@ std::vector GatherQMM::vjp( bits_, quantization_mode_to_string(mode_), std::nullopt, - std::nullopt, // amax placeholder + std::nullopt, + kquant_type_, stream()), -1, {-1, group_size_}, @@ -3773,7 +3819,10 @@ std::vector GatherQMM::jvp( bool GatherQMM::is_equivalent(const Primitive& other) const { const GatherQMM& qm_other = static_cast(other); return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ && - mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_; + mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_ && + left_sorted_ == qm_other.left_sorted_ && + right_sorted_ == qm_other.right_sorted_ && + kquant_type_ == qm_other.kquant_type_; } std::pair, std::vector> RandomBits::vmap( diff --git a/mlx/primitives.h b/mlx/primitives.h index 75fb978dce..795749b91e 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -152,13 +152,24 @@ class MLX_API UnaryPrimitive : public Primitive { UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; }; -enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 }; +enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4, KQuant }; std::string quantization_mode_to_string(QuantizationMode mode); QuantizationMode string_to_quantization_mode( const std::string& mode, std::string_view error_tag = ""); +struct KQuantCodec { + std::string_view name; + int weights_per_block; + int bytes_per_block; + int bits; + bool has_matmul_kernel; + bool has_encode; +}; + +const KQuantCodec* kquant_codec_by_name(const std::string& name); + class Abs : public UnaryPrimitive { public: explicit Abs(Stream stream) : UnaryPrimitive(stream) {} @@ -1619,12 +1630,14 @@ class QuantizedMatmul : public UnaryPrimitive { int group_size, int bits, QuantizationMode mode, - bool transpose) + bool transpose, + const std::string& kquant_type = "") : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), mode_(mode), - transpose_(transpose) {} + transpose_(transpose), + kquant_type_(kquant_type) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1635,7 +1648,7 @@ class QuantizedMatmul : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { - return std::make_tuple(group_size_, bits_, mode_, transpose_); + return std::make_tuple(group_size_, bits_, mode_, transpose_, kquant_type_); } private: @@ -1643,6 +1656,7 @@ class QuantizedMatmul : public UnaryPrimitive { int bits_; QuantizationMode mode_; bool transpose_; + std::string kquant_type_; }; class QQMatmul : public UnaryPrimitive { @@ -1651,11 +1665,13 @@ class QQMatmul : public UnaryPrimitive { Stream stream, int group_size, int bits, - QuantizationMode mode) + QuantizationMode mode, + const std::string& kquant_type = "") : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), - mode_(mode) {} + mode_(mode), + kquant_type_(kquant_type) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1666,13 +1682,14 @@ class QQMatmul : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; std::vector output_shapes(const std::vector& inputs) override; auto state() const { - return std::make_tuple(group_size_, bits_, mode_); + return std::make_tuple(group_size_, bits_, mode_, kquant_type_); } private: int group_size_; int bits_; QuantizationMode mode_; + std::string kquant_type_; }; class GatherQMM : public UnaryPrimitive { @@ -1684,14 +1701,16 @@ class GatherQMM : public UnaryPrimitive { QuantizationMode mode, bool transpose, bool left_sorted = false, - bool right_sorted = false) + bool right_sorted = false, + const std::string& kquant_type = "") : UnaryPrimitive(stream), group_size_(group_size), bits_(bits), mode_(mode), transpose_(transpose), left_sorted_(left_sorted), - right_sorted_(right_sorted) {} + right_sorted_(right_sorted), + kquant_type_(kquant_type) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -1702,7 +1721,13 @@ class GatherQMM : public UnaryPrimitive { bool is_equivalent(const Primitive& other) const override; auto state() const { return std::make_tuple( - group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_); + group_size_, + bits_, + mode_, + transpose_, + left_sorted_, + right_sorted_, + kquant_type_); } private: @@ -1712,6 +1737,7 @@ class GatherQMM : public UnaryPrimitive { bool transpose_; bool left_sorted_; bool right_sorted_; + std::string kquant_type_; }; class RandomBits : public UnaryPrimitive { diff --git a/python/mlx/nn/layers/distributed.py b/python/mlx/nn/layers/distributed.py index 8604047954..3b3da1b87c 100644 --- a/python/mlx/nn/layers/distributed.py +++ b/python/mlx/nn/layers/distributed.py @@ -463,6 +463,11 @@ def from_quantized_linear( segments: Union[int, list] = 1, group: Optional[mx.distributed.Group] = None, ): + mode = getattr(quantized_linear_layer, "mode", "affine") + if mode == "kquant": + raise NotImplementedError( + "Distributed quantized layers do not support mode='kquant'." + ) group = group or mx.distributed.init() output_dims, input_dims = quantized_linear_layer.weight.shape input_dims = (input_dims * 32) // quantized_linear_layer.bits @@ -473,7 +478,7 @@ def from_quantized_linear( hasattr(quantized_linear_layer, "bias"), group_size=quantized_linear_layer.group_size, bits=quantized_linear_layer.bits, - mode=getattr(quantized_linear_layer, "mode", "affine"), + mode=mode, group=group, ) sl.update( @@ -595,6 +600,11 @@ def from_quantized_linear( segments: Union[int, list] = 1, group: Optional[mx.distributed.Group] = None, ): + mode = getattr(quantized_linear_layer, "mode", "affine") + if mode == "kquant": + raise NotImplementedError( + "Distributed quantized layers do not support mode='kquant'." + ) group = group or mx.distributed.init() output_dims, input_dims = quantized_linear_layer.weight.shape input_dims = (input_dims * 32) // quantized_linear_layer.bits @@ -605,7 +615,7 @@ def from_quantized_linear( hasattr(quantized_linear_layer, "bias"), group_size=quantized_linear_layer.group_size, bits=quantized_linear_layer.bits, - mode=getattr(quantized_linear_layer, "mode", "affine"), + mode=mode, group=group, ) sl.update( diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py index e05cfb5f01..3fe2574122 100644 --- a/python/mlx/nn/layers/embedding.py +++ b/python/mlx/nn/layers/embedding.py @@ -46,8 +46,11 @@ def to_quantized( bits: Optional[int] = None, mode: str = "affine", quantize_input: bool = False, + kquant_type: str = "", ): """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer.""" if quantize_input: raise ValueError("Quantized input is not supported.") - return QuantizedEmbedding.from_embedding(self, group_size, bits, mode) + return QuantizedEmbedding.from_embedding( + self, group_size, bits, mode, kquant_type=kquant_type + ) diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 7b868b93ee..ad7204bb73 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -76,6 +76,7 @@ def to_quantized( bits: Optional[int] = None, mode: str = "affine", quantize_input: bool = False, + kquant_type: str = "", ): """Return a quantized approximation of this layer. @@ -91,6 +92,8 @@ def to_quantized( mode (str): The quantization method to use (see :func:`mlx.core.quantize`). Default: ``"affine"``. quantize_input (bool): Whether to quantize input. Default: ``False``. + kquant_type (str): For ``mode="kquant"``, selects the codec + (e.g. ``"q4_k"``). Default: ``""``. Returns: QuantizedLinear or QQLinear: A quantized version of this layer. @@ -105,7 +108,9 @@ def to_quantized( f"Quantized activations are only supported for 'nvfp4' and 'mxfp8' modes, got {mode}." ) return QQLinear.from_linear(self, group_size, bits, mode) - return QuantizedLinear.from_linear(self, group_size, bits, mode) + return QuantizedLinear.from_linear( + self, group_size, bits, mode, kquant_type=kquant_type + ) class Bilinear(Module): diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py index 57e7c88898..106ba96006 100644 --- a/python/mlx/nn/layers/quantized.py +++ b/python/mlx/nn/layers/quantized.py @@ -7,24 +7,59 @@ from mlx.nn.layers.base import Module from mlx.utils import tree_map_with_path - -def _defaults_for_mode(mode, group_size, bits): - mode_defaults = { - "affine": (64, 4), - "mxfp4": (32, 4), - "nvfp4": (16, 4), - "mxfp8": (32, 8), - } - default_group_size, default_bits = mode_defaults[mode] +_KQUANT_CODEC_PARAMS = { + "q4_0": (32, 4), + "q4_1": (32, 4), + "q5_0": (32, 5), + "q5_1": (32, 5), + "q8_0": (32, 8), + "q2_k": (256, 2), + "q3_k": (256, 3), + "q4_k": (256, 4), + "q5_k": (256, 5), + "q6_k": (256, 6), +} + + +def _defaults_for_mode(mode, group_size, bits, kquant_type=""): + if mode == "kquant" and kquant_type in _KQUANT_CODEC_PARAMS: + default_group_size, default_bits = _KQUANT_CODEC_PARAMS[kquant_type] + else: + mode_defaults = { + "affine": (64, 4), + "mxfp4": (32, 4), + "nvfp4": (16, 4), + "mxfp8": (32, 8), + "kquant": (32, 8), + } + default_group_size, default_bits = mode_defaults[mode] return group_size or default_group_size, bits or default_bits +# (weights_per_block, bytes_per_block) per codec. +# Must match the C++ kquant_codec_by_name registry in primitives.cpp. +# Verified by test_kquant.py::TestKQuantRegistryConsistency. +_KQUANT_CODEC_GEOMETRY = { + "q4_0": (32, 18), + "q4_1": (32, 20), + "q5_0": (32, 22), + "q5_1": (32, 24), + "q8_0": (32, 34), + "q2_k": (256, 84), + "q3_k": (256, 110), + "q4_k": (256, 144), + "q5_k": (256, 176), + "q6_k": (256, 210), +} + + def quantize( model: Module, group_size: int = None, bits: int = None, *, mode: str = "affine", + kquant_type: str = "", quantize_input: bool = False, class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None, ): @@ -46,6 +81,8 @@ def quantize( :func:`mlx.core.quantize`). Default: ``None``. mode (str): The quantization method to use (see :func:`mlx.core.quantize`). Default: ``"affine"``. + kquant_type (str): For ``mode="kquant"``, selects the codec + (e.g. ``"q4_k"``, ``"q8_0"``). Default: ``""``. quantize_input (bool): Whether to quantize activations. Default: ``False``. class_predicate (Optional[Callable]): A callable which receives the :obj:`Module` path and :obj:`Module` itself and returns ``True`` or a @@ -70,7 +107,12 @@ def _maybe_quantize(path, m): if bool_or_params := class_predicate(path, m): if hasattr(m, "to_quantized"): if isinstance(bool_or_params, bool): - kwargs = {"group_size": group_size, "bits": bits, "mode": mode} + kwargs = { + "group_size": group_size, + "bits": bits, + "mode": mode, + "kquant_type": kquant_type, + } if quantize_input: kwargs["quantize_input"] = quantize_input return m.to_quantized(**kwargs) @@ -112,6 +154,8 @@ class QuantizedEmbedding(Module): See :func:`~mlx.core.quantize`. Default: ``None``. mode (str): The quantization method to use (see :func:`mlx.core.quantize`). Default: ``"affine"``. + kquant_type (str): For ``mode="kquant"``, selects the codec + (e.g. ``"q4_k"``, ``"q8_0"``). Default: ``""``. """ def __init__( @@ -121,18 +165,27 @@ def __init__( group_size: int = None, bits: int = None, mode: str = "affine", + kquant_type: str = "", ): super().__init__() # Quantization config - self.group_size, self.bits = _defaults_for_mode(mode, group_size, bits) + if mode == "kquant" and not kquant_type: + raise ValueError( + "kquant_type is required when mode='kquant'. " + "Valid codecs: " + ", ".join(sorted(_KQUANT_CODEC_GEOMETRY)) + ) + self.group_size, self.bits = _defaults_for_mode( + mode, group_size, bits, kquant_type + ) self.mode = mode + self.kquant_type = kquant_type # Initialize the quantized weight scale = math.sqrt(1 / dims) weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale) self.weight, self.scales, *biases = mx.quantize( - weight, group_size, bits, mode=mode + weight, group_size, bits, mode=mode, kquant_type=kquant_type ) self.biases = biases[0] if biases else None self.num_embeddings = num_embeddings @@ -150,6 +203,7 @@ def __call__(self, x): group_size=self.group_size, bits=self.bits, mode=self.mode, + kquant_type=self.kquant_type, ) def as_linear(self, x): @@ -168,12 +222,14 @@ def as_linear(self, x): group_size=self.group_size, bits=self.bits, mode=self.mode, + kquant_type=self.kquant_type, ) def _extra_repr(self): + kq = f", kquant_type={self.kquant_type}" if self.kquant_type else "" return ( f"{self.num_embeddings}, {self.dims}, " - f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}" + f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}{kq}" ) @classmethod @@ -183,15 +239,19 @@ def from_embedding( group_size: int = None, bits: int = None, mode: str = "affine", + kquant_type: str = "", ): """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer.""" embedding_dims, dims = embedding_layer.weight.shape - ql = cls(embedding_dims, dims, group_size, bits, mode=mode) + ql = cls( + embedding_dims, dims, group_size, bits, mode=mode, kquant_type=kquant_type + ) ql.weight, ql.scales, *biases = mx.quantize( embedding_layer.weight, group_size, bits, mode=mode, + kquant_type=kquant_type, ) ql.biases = biases[0] if biases else None return ql @@ -218,6 +278,8 @@ class QuantizedLinear(Module): See :func:`~mlx.core.quantize`. Default: ``None``. mode (str): The quantization method to use (see :func:`mlx.core.quantize`). Default: ``"affine"``. + kquant_type (str): For ``mode="kquant"``, selects the codec + (e.g. ``"q4_k"``, ``"q8_0"``). Default: ``""``. """ def __init__( @@ -228,12 +290,21 @@ def __init__( group_size: int = None, bits: int = None, mode: str = "affine", + kquant_type: str = "", ): super().__init__() # Quantization config - self.group_size, self.bits = _defaults_for_mode(mode, group_size, bits) + if mode == "kquant" and not kquant_type: + raise ValueError( + "kquant_type is required when mode='kquant'. " + "Valid codecs: " + ", ".join(sorted(_KQUANT_CODEC_GEOMETRY)) + ) + self.group_size, self.bits = _defaults_for_mode( + mode, group_size, bits, kquant_type + ) self.mode = mode + self.kquant_type = kquant_type # Initialize the quantized weight scale = math.sqrt(1 / input_dims) @@ -243,7 +314,7 @@ def __init__( shape=(output_dims, input_dims), ) self.weight, self.scales, *biases = mx.quantize( - weight, group_size, bits, mode=mode + weight, group_size, bits, mode=mode, kquant_type=kquant_type ) self.biases = biases[0] if biases else None @@ -256,10 +327,16 @@ def __init__( def _extra_repr(self): out_dims, in_dims = self.weight.shape - in_dims = (in_dims * 32) // self.bits + if self.mode == "kquant" and self.kquant_type in _KQUANT_CODEC_GEOMETRY: + # uint8 packing: in_dims is bytes-per-row, expand by codec ratio. + wpb, bpb = _KQUANT_CODEC_GEOMETRY[self.kquant_type] + in_dims = (in_dims // bpb) * wpb + else: + in_dims = (in_dims * 32) // self.bits + kq = f", kquant_type={self.kquant_type}" if self.kquant_type else "" return ( f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " - f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}" + f"group_size={self.group_size}, bits={self.bits}, mode={self.mode}{kq}" ) def __call__(self, x): @@ -272,6 +349,7 @@ def __call__(self, x): group_size=self.group_size, bits=self.bits, mode=self.mode, + kquant_type=self.kquant_type, ) if "bias" in self: x = x + self["bias"] @@ -284,15 +362,25 @@ def from_linear( group_size: int = None, bits: int = None, mode: str = "affine", + kquant_type: str = "", ): """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" output_dims, input_dims = linear_layer.weight.shape - ql = cls(input_dims, output_dims, False, group_size, bits, mode=mode) + ql = cls( + input_dims, + output_dims, + False, + group_size, + bits, + mode=mode, + kquant_type=kquant_type, + ) ql.weight, ql.scales, *biases = mx.quantize( linear_layer.weight, group_size, bits, mode=mode, + kquant_type=kquant_type, ) ql.biases = biases[0] if biases else None diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 9a48b37afe..3629a5c930 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -4316,9 +4316,10 @@ void init_ops(nb::module_& m) { "bits"_a = nb::none(), "mode"_a = "affine", nb::kw_only(), + "kquant_type"_a = "", "stream"_a = nb::none(), nb::sig( - "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"), + "def quantized_matmul(x: array, w: array, /, scales: array, biases: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, kquant_type: str = '', stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform the matrix multiplication with the quantized matrix ``w``. The quantization uses one floating point scale and bias per ``group_size`` of @@ -4341,9 +4342,14 @@ void init_ops(nb::module_& m) { ``w`` in the quantized array. See supported values and defaults in the :ref:`table of quantization modes `. Default: ``None``. mode (str, optional): The quantization mode. Default: ``"affine"``. + kquant_type (str, optional): For ``mode="kquant"``, selects the codec + (e.g. ``"q4_0"``, ``"q4_k"``). Required for kquant mode; ignored for + other modes. Default: ``""``. Returns: array: The result of the multiplication of ``x`` with ``w``. + For ``mode="kquant"``, the output dtype is ``bfloat16`` when the + input is ``float32``; otherwise the output matches the input dtype. )pbdoc"); m.def( "quantize", @@ -4354,9 +4360,11 @@ void init_ops(nb::module_& m) { "mode"_a = "affine", "global_scale"_a = nb::none(), nb::kw_only(), + "kquant_type"_a = "", + "imatrix"_a = nb::none(), "stream"_a = nb::none(), nb::sig( - "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, global_scale: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"), + "def quantize(w: array, /, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', global_scale: Optional[array] = None, *, kquant_type: str = '', imatrix: Optional[array] = None, stream: Union[None, Stream, Device] = None) -> tuple[array, ...]"), R"pbdoc( Quantize the array ``w``. @@ -4370,7 +4378,8 @@ void init_ops(nb::module_& m) { the last dimension divisible by ``group_size`` The supported quantization modes are ``"affine"``, ``"mxfp4"``, - ``"mxfp8"``, and ``"nvfp4"``. They are described in more detail below. + ``"mxfp8"``, ``"nvfp4"``, and ``"kquant"``. They are described in + more detail below. Args: w (array): Array to be quantized @@ -4383,6 +4392,12 @@ void init_ops(nb::module_& m) { mode (str, optional): The quantization mode. Default: ``"affine"``. global_scale (array, optional): The per-input float32 scale used for ``"nvfp4"`` quantization if provided. Default: ``None``. + kquant_type (str, optional): For ``mode="kquant"``, selects the codec + (e.g. ``"q4_k"``, ``"q8_0"``). Required when mode is ``"kquant"``. + Default: ``""``. + imatrix (array, optional): For ``mode="kquant"``, a float32 importance + matrix of shape ``[K]`` for importance-weighted quantization. + Default: ``None``. Returns: tuple: A tuple with either two or three elements containing: @@ -4403,6 +4418,7 @@ void init_ops(nb::module_& m) { mxfp4 32\ :sup:`*` 4\ :sup:`*` e8m0 no mxfp8 32\ :sup:`*` 8\ :sup:`*` e8m0 no nvfp4 16\ :sup:`*` 4\ :sup:`*` e4m3 no + kquant 32\ :sup:`*` 8\ :sup:`*` per-codec no ====== ====================== ========================== ============= ===== :sup:`*` indicates the default value when unspecified. @@ -4442,6 +4458,16 @@ void init_ops(nb::module_& m) { More details on the ``"mx"`` formats can be found in the `specification `_. + + The ``"kquant"`` mode uses per-codec block quantization formats + with hierarchical scales. Each codec has a fixed block geometry + (e.g. 256 weights per super-block for ``"q4_k"``). The ``kquant_type`` + argument selects the codec. Available codecs: ``"q2_k"``, + ``"q3_k"``, ``"q4_k"``, ``"q5_k"``, ``"q6_k"``, ``"q4_0"``, + ``"q4_1"``, ``"q5_0"``, ``"q5_1"``, ``"q8_0"``. Weights are + stored as packed ``uint8`` bytes in the codec's wire format. + Unlike ``affine`` quantization, ``kquant`` does not return a + separate bias array. )pbdoc"); m.def( "dequantize", @@ -4455,9 +4481,10 @@ void init_ops(nb::module_& m) { "global_scale"_a = nb::none(), "dtype"_a = nb::none(), nb::kw_only(), + "kquant_type"_a = "", "stream"_a = nb::none(), nb::sig( - "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', global_scale: Optional[array] = None, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array"), + "def dequantize(w: array, /, scales: array, biases: Optional[array] = None, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', global_scale: Optional[array] = None, dtype: Optional[Dtype] = None, *, kquant_type: str = '', stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Dequantize the matrix ``w`` using quantization parameters. @@ -4472,20 +4499,22 @@ void init_ops(nb::module_& m) { bits (int, optional): The number of bits occupied by each element of ``w`` in the quantized array. See supported values and defaults in the :ref:`table of quantization modes `. Default: ``None``. + mode (str, optional): The quantization mode. Default: ``"affine"``. global_scale (array, optional): The per-input float32 scale used for ``"nvfp4"`` quantization if provided. Default: ``None``. dtype (Dtype, optional): The data type of the dequantized output. If ``None`` the return type is inferred from the scales and biases when possible and otherwise defaults to ``bfloat16``. Default: ``None``. - mode (str, optional): The quantization mode. Default: ``"affine"``. + kquant_type (str, optional): For ``mode="kquant"``, selects the codec + (e.g. ``"q4_k"``, ``"q8_0"``). Default: ``""``. Returns: array: The dequantized version of ``w`` Notes: The currently supported quantization modes are ``"affine"``, - ``"mxfp4``, ``"mxfp8"``, and ``"nvfp4"``. + ``"mxfp4"``, ``"mxfp8"``, ``"nvfp4"``, and ``"kquant"``. For ``affine`` quantization, given the notation in :func:`quantize`, we compute :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` @@ -4510,9 +4539,10 @@ void init_ops(nb::module_& m) { "mode"_a = "affine", nb::kw_only(), "sorted_indices"_a = false, + "kquant_type"_a = "", "stream"_a = nb::none(), nb::sig( - "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, sorted_indices: bool = False, stream: Union[None, Stream, Device] = None) -> array"), + "def gather_qmm(x: array, w: array, /, scales: array, biases: Optional[array] = None, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: Optional[int] = None, bits: Optional[int] = None, mode: str = 'affine', *, sorted_indices: bool = False, kquant_type: str = '', stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Perform quantized matrix multiplication with matrix-level gather. @@ -4544,6 +4574,8 @@ void init_ops(nb::module_& m) { mode (str, optional): The quantization mode. Default: ``"affine"``. sorted_indices (bool, optional): May allow a faster implementation if the passed indices are sorted. Default: ``False``. + kquant_type (str, optional): For ``mode="kquant"``, selects the codec + (e.g. ``"q4_k"``). Default: ``""``. Returns: array: The result of the multiplication of ``x`` with ``w`` diff --git a/python/tests/test_gguf_kquant.py b/python/tests/test_gguf_kquant.py new file mode 100644 index 0000000000..32b9b179dd --- /dev/null +++ b/python/tests/test_gguf_kquant.py @@ -0,0 +1,155 @@ +# Copyright © 2026 Apple Inc. + +"""Tests for mx.load() raw K-quant tensor loading from GGUF files. + +All 10 K-quant codecs route through the kquant loader, which loads raw wire +bytes as a uint8 array with codec metadata for dispatch. +""" + +import os +import tempfile +import unittest + +import mlx.core as mx +import mlx_tests +import numpy as np + +try: + from gguf import GGMLQuantizationType, GGUFWriter + + HAS_GGUF = True +except ImportError: + HAS_GGUF = False + + +# Codec geometry. Source of truth: mlx/primitives.cpp kquant_codec_by_name. +KQUANT_CODECS = { + "q4_0": (lambda: GGMLQuantizationType.Q4_0, 32, 18), + "q4_1": (lambda: GGMLQuantizationType.Q4_1, 32, 20), + "q5_0": (lambda: GGMLQuantizationType.Q5_0, 32, 22), + "q5_1": (lambda: GGMLQuantizationType.Q5_1, 32, 24), + "q8_0": (lambda: GGMLQuantizationType.Q8_0, 32, 34), + "q2_k": (lambda: GGMLQuantizationType.Q2_K, 256, 84), + "q3_k": (lambda: GGMLQuantizationType.Q3_K, 256, 110), + "q4_k": (lambda: GGMLQuantizationType.Q4_K, 256, 144), + "q5_k": (lambda: GGMLQuantizationType.Q5_K, 256, 176), + "q6_k": (lambda: GGMLQuantizationType.Q6_K, 256, 210), +} + +# All K-quant codecs route through the raw kquant loader. +KQUANT_RAW_TYPES = ( + "q4_0", + "q4_1", + "q5_0", + "q5_1", + "q8_0", + "q2_k", + "q3_k", + "q4_k", + "q5_k", + "q6_k", +) + + +@unittest.skipUnless(HAS_GGUF, "gguf package not installed") +class TestGGUFKQuantLoad(mlx_tests.MLXTestCase): + def _write_gguf(self, path, codec_name, N, K, tensor_name="blk.0.attn_q.weight"): + type_factory, wpb, bpb = KQUANT_CODECS[codec_name] + gguf_type = type_factory() + assert K % wpb == 0 + bytes_per_row = (K // wpb) * bpb + total_bytes = N * bytes_per_row + rng = np.random.default_rng(42) + raw = rng.integers(0, 256, size=total_bytes, dtype=np.uint8) + + writer = GGUFWriter(path, arch="test") + # raw_shape uses numpy convention: last dim is innermost (in bytes when + # raw_dtype is a quantized type). Library reverses to GGML order on disk. + writer.add_tensor( + tensor_name, + raw, + raw_shape=(N, bytes_per_row), + raw_dtype=gguf_type, + ) + writer.write_header_to_file() + writer.write_kv_data_to_file() + writer.write_tensors_to_file() + writer.close() + return raw + + def test_kquant_raw_shape_dtype_bytes(self): + for codec in KQUANT_RAW_TYPES: + with self.subTest(codec=codec): + _, wpb, bpb = KQUANT_CODECS[codec] + N, K = 4, wpb * 2 # 2 blocks per row + with tempfile.NamedTemporaryFile(suffix=".gguf", delete=False) as f: + path = f.name + try: + raw = self._write_gguf(path, codec, N, K) + arrays, _ = mx.load(path, return_metadata=True) + w = arrays["blk.0.attn_q.weight"] + self.assertEqual(w.dtype, mx.uint8) + expected_shape = [N, (K // wpb) * bpb] + self.assertEqual(list(w.shape), expected_shape) + np.testing.assert_array_equal(np.array(w).flatten(), raw) + finally: + os.unlink(path) + + def test_kquant_metadata(self): + codec = "q4_k" + _, wpb, _ = KQUANT_CODECS[codec] + N, K = 4, wpb * 2 + with tempfile.NamedTemporaryFile(suffix=".gguf", delete=False) as f: + path = f.name + try: + self._write_gguf(path, codec, N, K) + _, metadata = mx.load(path, return_metadata=True) + self.assertIn("__kquant_types__", metadata) + kq_types = metadata["__kquant_types__"] + self.assertIsInstance(kq_types, list) + self.assertEqual(len(kq_types), 1) + self.assertEqual(kq_types[0], "blk.0.attn_q.weight:q4_k") + finally: + os.unlink(path) + + def test_placeholder_scales(self): + codec = "q6_k" + _, wpb, _ = KQUANT_CODECS[codec] + N, K = 4, wpb * 2 + with tempfile.NamedTemporaryFile(suffix=".gguf", delete=False) as f: + path = f.name + try: + self._write_gguf(path, codec, N, K) + arrays, _ = mx.load(path, return_metadata=True) + self.assertIn("blk.0.attn_q.scales", arrays) + s = arrays["blk.0.attn_q.scales"] + self.assertEqual(s.dtype, mx.uint8) + self.assertEqual(list(s.shape), [1]) + self.assertEqual(s.size, 1) + finally: + os.unlink(path) + + def test_q4_0_loads_as_kquant(self): + """Q4_0 routes through the kquant raw loader (not legacy affine).""" + codec = "q4_0" + _, wpb, bpb = KQUANT_CODECS[codec] + N, K = 4, 32 + with tempfile.NamedTemporaryFile(suffix=".gguf", delete=False) as f: + path = f.name + try: + raw = self._write_gguf(path, codec, N, K, "blk.0.ffn.weight") + arrays, metadata = mx.load(path, return_metadata=True) + w = arrays["blk.0.ffn.weight"] + self.assertEqual(w.dtype, mx.uint8) + expected_shape = [N, (K // wpb) * bpb] + self.assertEqual(list(w.shape), expected_shape) + np.testing.assert_array_equal(np.array(w).flatten(), raw) + kq = metadata.get("__kquant_types__", []) + names = [e.split(":")[0] for e in kq] if isinstance(kq, list) else [] + self.assertIn("blk.0.ffn.weight", names) + finally: + os.unlink(path) + + +if __name__ == "__main__": + mlx_tests.MLXTestRunner() diff --git a/python/tests/test_kquant.py b/python/tests/test_kquant.py new file mode 100644 index 0000000000..a09134ca94 --- /dev/null +++ b/python/tests/test_kquant.py @@ -0,0 +1,3041 @@ +# Copyright © 2026 Apple Inc. + +import os +import tempfile + +import mlx.core as mx +import mlx.nn as nn +import mlx_tests +import numpy as np + +Q8_0_GROUP = 32 +Q8_0_BLOCK_BYTES = 34 # fp16 d (2) + int8 q[32] (32) +Q8_0_D_OFFSET = 0 +Q8_0_Q_OFFSET = 2 + + +def _quantize_q8_0_row(row: np.ndarray) -> np.ndarray: + """Quantize a 1D fp32 array (length must be a multiple of 32) to Q8_0 + packed wire bytes. Returns uint8 array of length len(row) * 34/32.""" + assert row.ndim == 1 + assert row.size % Q8_0_GROUP == 0, "Q8_0 requires K % 32 == 0" + n_blocks = row.size // Q8_0_GROUP + out = np.zeros(n_blocks * Q8_0_BLOCK_BYTES, dtype=np.uint8) + blocks = row.reshape(n_blocks, Q8_0_GROUP).astype(np.float32) + for b in range(n_blocks): + block = blocks[b] + amax = float(np.max(np.abs(block))) + if amax == 0.0: + d = np.float32(0.0) + q = np.zeros(Q8_0_GROUP, dtype=np.int8) + else: + d = np.float32(amax / 127.0) + q = np.clip(np.round(block / d), -127, 127).astype(np.int8) + d_fp16 = np.float16(d) + base = b * Q8_0_BLOCK_BYTES + out[base + Q8_0_D_OFFSET : base + Q8_0_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q8_0_Q_OFFSET : base + Q8_0_Q_OFFSET + Q8_0_GROUP] = np.frombuffer( + q.tobytes(), dtype=np.uint8 + ) + return out + + +def _quantize_q8_0_matrix(W: np.ndarray) -> np.ndarray: + assert W.ndim == 2 + assert W.shape[1] % Q8_0_GROUP == 0 + out_dim, in_dim = W.shape + bytes_per_row = in_dim * Q8_0_BLOCK_BYTES // Q8_0_GROUP + out = np.zeros((out_dim, bytes_per_row), dtype=np.uint8) + for i in range(out_dim): + out[i] = _quantize_q8_0_row(W[i]) + return out + + +def _dequantize_q8_0_matrix(W_q: np.ndarray, in_dim: int) -> np.ndarray: + assert W_q.dtype == np.uint8 + out_dim = W_q.shape[0] + bytes_per_row = W_q.shape[1] + assert bytes_per_row == in_dim * Q8_0_BLOCK_BYTES // Q8_0_GROUP + n_blocks_per_row = in_dim // Q8_0_GROUP + out = np.zeros((out_dim, in_dim), dtype=np.float32) + for i in range(out_dim): + row = W_q[i] + for b in range(n_blocks_per_row): + base = b * Q8_0_BLOCK_BYTES + d_fp16 = np.frombuffer( + row[base + Q8_0_D_OFFSET : base + Q8_0_D_OFFSET + 2].tobytes(), + dtype=np.float16, + )[0] + q_int8 = np.frombuffer( + row[base + Q8_0_Q_OFFSET : base + Q8_0_Q_OFFSET + Q8_0_GROUP].tobytes(), + dtype=np.int8, + ) + out[i, b * Q8_0_GROUP : (b + 1) * Q8_0_GROUP] = float( + d_fp16 + ) * q_int8.astype(np.float32) + return out + + +def _scales_placeholder() -> mx.array: + """KQuant accepts a placeholder scales tensor; pass a 1-byte dummy.""" + return mx.zeros((1,), dtype=mx.uint8) + + +def _kquant_matmul(x: mx.array, w_packed_np: np.ndarray) -> mx.array: + """Wrap mx.quantized_matmul with mode='kquant' and Q8_0 (gs=32, bits=8).""" + return mx.quantized_matmul( + x, + mx.array(w_packed_np), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=32, + bits=8, + mode="kquant", + kquant_type="q8_0", + ) + + +Q4_0_GROUP = 32 +Q4_0_BLOCK_BYTES = 18 # fp16 d (2) + uint8 qs[16] +Q4_0_D_OFFSET = 0 +Q4_0_QS_OFFSET = 2 + + +def _quantize_q4_0_row(row: np.ndarray) -> np.ndarray: + """Quantize a 1D fp32 array (length must be a multiple of 32) to Q4_0 + packed wire bytes. Returns uint8 array of length len(row) * 18/32. + + Symmetric quantization: scale = amax / -8.0; + q = round(weight / scale) + 8 clipped to [0, 15]; pack split-half. + """ + assert row.ndim == 1 + assert row.size % Q4_0_GROUP == 0, "Q4_0 requires K % 32 == 0" + n_blocks = row.size // Q4_0_GROUP + out = np.zeros(n_blocks * Q4_0_BLOCK_BYTES, dtype=np.uint8) + blocks = row.reshape(n_blocks, Q4_0_GROUP).astype(np.float32) + for b in range(n_blocks): + block = blocks[b] + # Use the value with largest magnitude (signed) -- pick the entry + # with greatest |x|, keep its sign for the scale. + idx = int(np.argmax(np.abs(block))) + amax_signed = float(block[idx]) + if amax_signed == 0.0: + d = np.float32(0.0) + q = np.full(Q4_0_GROUP, 8, dtype=np.uint8) + else: + d = np.float32(amax_signed / -8.0) + inv_d = 1.0 / float(d) + q = np.clip(np.round(block * inv_d).astype(np.int32) + 8, 0, 15).astype( + np.uint8 + ) + # Split-half pack: qs[j] = q[j] | (q[j+16] << 4) for j in [0, 16). + qs = np.zeros(16, dtype=np.uint8) + for j in range(16): + qs[j] = np.uint8((int(q[j]) & 0x0F) | ((int(q[j + 16]) & 0x0F) << 4)) + d_fp16 = np.float16(d) + base = b * Q4_0_BLOCK_BYTES + out[base + Q4_0_D_OFFSET : base + Q4_0_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q4_0_QS_OFFSET : base + Q4_0_QS_OFFSET + 16] = qs + return out + + +def _quantize_q4_0_matrix(W: np.ndarray) -> np.ndarray: + assert W.ndim == 2 + assert W.shape[1] % Q4_0_GROUP == 0 + out_dim, in_dim = W.shape + bytes_per_row = in_dim * Q4_0_BLOCK_BYTES // Q4_0_GROUP + out = np.zeros((out_dim, bytes_per_row), dtype=np.uint8) + for i in range(out_dim): + out[i] = _quantize_q4_0_row(W[i]) + return out + + +def _dequantize_q4_0_matrix(W_q: np.ndarray, in_dim: int) -> np.ndarray: + """Reference Q4_0 dequantization.""" + assert W_q.dtype == np.uint8 + out_dim = W_q.shape[0] + bytes_per_row = W_q.shape[1] + assert bytes_per_row == in_dim * Q4_0_BLOCK_BYTES // Q4_0_GROUP + n_blocks_per_row = in_dim // Q4_0_GROUP + out = np.zeros((out_dim, in_dim), dtype=np.float32) + for i in range(out_dim): + row = W_q[i] + for b in range(n_blocks_per_row): + base = b * Q4_0_BLOCK_BYTES + d = float( + np.frombuffer( + row[base + Q4_0_D_OFFSET : base + Q4_0_D_OFFSET + 2].tobytes(), + dtype=np.float16, + )[0] + ) + qs = np.frombuffer( + row[base + Q4_0_QS_OFFSET : base + Q4_0_QS_OFFSET + 16].tobytes(), + dtype=np.uint8, + ) + for j in range(16): + x0 = (int(qs[j]) & 0x0F) - 8 + x1 = (int(qs[j]) >> 4) - 8 + out[i, b * Q4_0_GROUP + j] = d * x0 + out[i, b * Q4_0_GROUP + j + 16] = d * x1 + return out + + +def _kquant_matmul_q4_0(x: mx.array, w_packed_np: np.ndarray) -> mx.array: + """Wrap mx.quantized_matmul with mode='kquant' and Q4_0 (gs=32, bits=4).""" + return mx.quantized_matmul( + x, + mx.array(w_packed_np), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=32, + bits=4, + mode="kquant", + kquant_type="q4_0", + ) + + +Q4_1_GROUP = 32 +Q4_1_BLOCK_BYTES = 20 # fp16 d (2) + fp16 m (2) + uint8 qs[16] +Q4_1_D_OFFSET = 0 +Q4_1_M_OFFSET = 2 +Q4_1_QS_OFFSET = 4 + + +def _quantize_q4_1_row(row: np.ndarray) -> np.ndarray: + """Quantize a 1D fp32 array to Q4_1 packed wire bytes (asymmetric, 4-bit).""" + assert row.ndim == 1 + assert row.size % Q4_1_GROUP == 0, "Q4_1 requires K % 32 == 0" + n_blocks = row.size // Q4_1_GROUP + out = np.zeros(n_blocks * Q4_1_BLOCK_BYTES, dtype=np.uint8) + blocks = row.reshape(n_blocks, Q4_1_GROUP).astype(np.float32) + for b in range(n_blocks): + block = blocks[b] + mn = float(block.min()) + mx_ = float(block.max()) + if mx_ == mn: + d = np.float32(0.0) + m = np.float32(mn) + q = np.zeros(Q4_1_GROUP, dtype=np.uint8) + else: + d = np.float32((mx_ - mn) / 15.0) + m = np.float32(mn) + q = np.clip(np.round((block - mn) / d), 0, 15).astype(np.uint8) + qs = np.zeros(16, dtype=np.uint8) + for j in range(16): + qs[j] = np.uint8((int(q[j]) & 0x0F) | ((int(q[j + 16]) & 0x0F) << 4)) + d_fp16 = np.float16(d) + m_fp16 = np.float16(m) + base = b * Q4_1_BLOCK_BYTES + out[base + Q4_1_D_OFFSET : base + Q4_1_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q4_1_M_OFFSET : base + Q4_1_M_OFFSET + 2] = np.frombuffer( + m_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q4_1_QS_OFFSET : base + Q4_1_QS_OFFSET + 16] = qs + return out + + +def _quantize_q4_1_matrix(W: np.ndarray) -> np.ndarray: + assert W.ndim == 2 + assert W.shape[1] % Q4_1_GROUP == 0 + out_dim, in_dim = W.shape + bytes_per_row = in_dim * Q4_1_BLOCK_BYTES // Q4_1_GROUP + out = np.zeros((out_dim, bytes_per_row), dtype=np.uint8) + for i in range(out_dim): + out[i] = _quantize_q4_1_row(W[i]) + return out + + +def _dequantize_q4_1_matrix(W_q: np.ndarray, in_dim: int) -> np.ndarray: + """Reference Q4_1 dequantization.""" + assert W_q.dtype == np.uint8 + out_dim = W_q.shape[0] + bytes_per_row = W_q.shape[1] + assert bytes_per_row == in_dim * Q4_1_BLOCK_BYTES // Q4_1_GROUP + n_blocks_per_row = in_dim // Q4_1_GROUP + out = np.zeros((out_dim, in_dim), dtype=np.float32) + for i in range(out_dim): + row = W_q[i] + for b in range(n_blocks_per_row): + base = b * Q4_1_BLOCK_BYTES + d = float( + np.frombuffer( + row[base + Q4_1_D_OFFSET : base + Q4_1_D_OFFSET + 2].tobytes(), + dtype=np.float16, + )[0] + ) + m = float( + np.frombuffer( + row[base + Q4_1_M_OFFSET : base + Q4_1_M_OFFSET + 2].tobytes(), + dtype=np.float16, + )[0] + ) + qs = np.frombuffer( + row[base + Q4_1_QS_OFFSET : base + Q4_1_QS_OFFSET + 16].tobytes(), + dtype=np.uint8, + ) + for j in range(16): + x0 = int(qs[j]) & 0x0F + x1 = int(qs[j]) >> 4 + out[i, b * Q4_1_GROUP + j] = d * x0 + m + out[i, b * Q4_1_GROUP + j + 16] = d * x1 + m + return out + + +def _kquant_matmul_q4_1(x: mx.array, w_packed_np: np.ndarray) -> mx.array: + return mx.quantized_matmul( + x, + mx.array(w_packed_np), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=32, + bits=4, + mode="kquant", + kquant_type="q4_1", + ) + + +Q5_0_GROUP = 32 +Q5_0_BLOCK_BYTES = 22 # fp16 d (2) + uint8 qh[4] + uint8 qs[16] +Q5_0_D_OFFSET = 0 +Q5_0_QH_OFFSET = 2 +Q5_0_QS_OFFSET = 6 + + +def _quantize_q5_0_row(row: np.ndarray) -> np.ndarray: + """Quantize a 1D fp32 array to Q5_0 packed wire bytes (symmetric, 5-bit).""" + assert row.ndim == 1 + assert row.size % Q5_0_GROUP == 0, "Q5_0 requires K % 32 == 0" + n_blocks = row.size // Q5_0_GROUP + out = np.zeros(n_blocks * Q5_0_BLOCK_BYTES, dtype=np.uint8) + blocks = row.reshape(n_blocks, Q5_0_GROUP).astype(np.float32) + for b in range(n_blocks): + block = blocks[b] + # Scale from the entry with greatest |x|; range is [-16, 15]. + idx = int(np.argmax(np.abs(block))) + amax_signed = float(block[idx]) + if amax_signed == 0.0: + d = np.float32(0.0) + q = np.full(Q5_0_GROUP, 16, dtype=np.uint8) + else: + d = np.float32(amax_signed / -16.0) + inv_d = 1.0 / float(d) + q = np.clip(np.round(block * inv_d).astype(np.int32) + 16, 0, 31).astype( + np.uint8 + ) + # qh: bit j (j<32) = (q[j] >> 4) & 1, but packed into the 4-byte + # layout shared with Q5_1: low 16 bits hold high-bits of weights + # 0..15 (bit j); high 16 bits hold high-bits of weights 16..31 + # (bit (j-16)+16 = j matches the dequant ((qh >> j) << 4) & 0x10 + # for the low half and ((qh >> (j+12))) & 0x10 for the high half). + qh = np.uint32(0) + for j in range(16): + qh |= np.uint32((int(q[j]) >> 4) & 1) << j + for j in range(16): + qh |= np.uint32((int(q[j + 16]) >> 4) & 1) << (j + 16) + # qs: low 4 bits of q[j] in low nibble of qs[j]; q[j+16] in high nibble. + qs = np.zeros(16, dtype=np.uint8) + for j in range(16): + qs[j] = np.uint8((int(q[j]) & 0x0F) | ((int(q[j + 16]) & 0x0F) << 4)) + d_fp16 = np.float16(d) + base = b * Q5_0_BLOCK_BYTES + out[base + Q5_0_D_OFFSET : base + Q5_0_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q5_0_QH_OFFSET : base + Q5_0_QH_OFFSET + 4] = np.frombuffer( + qh.tobytes(), dtype=np.uint8 + ) + out[base + Q5_0_QS_OFFSET : base + Q5_0_QS_OFFSET + 16] = qs + return out + + +def _quantize_q5_0_matrix(W: np.ndarray) -> np.ndarray: + assert W.ndim == 2 + assert W.shape[1] % Q5_0_GROUP == 0 + out_dim, in_dim = W.shape + bytes_per_row = in_dim * Q5_0_BLOCK_BYTES // Q5_0_GROUP + out = np.zeros((out_dim, bytes_per_row), dtype=np.uint8) + for i in range(out_dim): + out[i] = _quantize_q5_0_row(W[i]) + return out + + +def _dequantize_q5_0_matrix(W_q: np.ndarray, in_dim: int) -> np.ndarray: + """Reference Q5_0 dequantization.""" + assert W_q.dtype == np.uint8 + out_dim = W_q.shape[0] + bytes_per_row = W_q.shape[1] + assert bytes_per_row == in_dim * Q5_0_BLOCK_BYTES // Q5_0_GROUP + n_blocks_per_row = in_dim // Q5_0_GROUP + out = np.zeros((out_dim, in_dim), dtype=np.float32) + for i in range(out_dim): + row = W_q[i] + for b in range(n_blocks_per_row): + base = b * Q5_0_BLOCK_BYTES + d = float( + np.frombuffer( + row[base + Q5_0_D_OFFSET : base + Q5_0_D_OFFSET + 2].tobytes(), + dtype=np.float16, + )[0] + ) + qh_bytes = row[base + Q5_0_QH_OFFSET : base + Q5_0_QH_OFFSET + 4] + qh = ( + int(qh_bytes[0]) + | (int(qh_bytes[1]) << 8) + | (int(qh_bytes[2]) << 16) + | (int(qh_bytes[3]) << 24) + ) + qs = np.frombuffer( + row[base + Q5_0_QS_OFFSET : base + Q5_0_QS_OFFSET + 16].tobytes(), + dtype=np.uint8, + ) + for j in range(16): + xh_0 = ((qh >> j) << 4) & 0x10 + xh_1 = ((qh >> (j + 12))) & 0x10 + x0 = (int(qs[j]) & 0x0F) | xh_0 + x1 = (int(qs[j]) >> 4) | xh_1 + out[i, b * Q5_0_GROUP + j] = d * (x0 - 16) + out[i, b * Q5_0_GROUP + j + 16] = d * (x1 - 16) + return out + + +def _kquant_matmul_q5_0(x: mx.array, w_packed_np: np.ndarray) -> mx.array: + return mx.quantized_matmul( + x, + mx.array(w_packed_np), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=32, + bits=5, + mode="kquant", + kquant_type="q5_0", + ) + + +Q5_1_GROUP = 32 +Q5_1_BLOCK_BYTES = 24 # fp16 d (2) + fp16 m (2) + uint32 qh (4) + uint8 qs[16] +Q5_1_D_OFFSET = 0 +Q5_1_M_OFFSET = 2 +Q5_1_QH_OFFSET = 4 +Q5_1_QS_OFFSET = 8 + + +def _quantize_q5_1_row(row: np.ndarray) -> np.ndarray: + """Quantize a 1D fp32 array (length must be a multiple of 32) to Q5_1 + packed wire bytes. Returns uint8 array of length len(row) * 24/32.""" + assert row.ndim == 1 + assert row.size % Q5_1_GROUP == 0, "Q5_1 requires K % 32 == 0" + n_blocks = row.size // Q5_1_GROUP + out = np.zeros(n_blocks * Q5_1_BLOCK_BYTES, dtype=np.uint8) + blocks = row.reshape(n_blocks, Q5_1_GROUP).astype(np.float32) + for b in range(n_blocks): + block = blocks[b] + mn = float(block.min()) + mx_ = float(block.max()) + if mx_ == mn: + d = np.float32(0.0) + m = np.float32(mn) + q = np.zeros(Q5_1_GROUP, dtype=np.uint8) + else: + d = np.float32((mx_ - mn) / 31.0) + m = np.float32(mn) + q = np.clip(np.round((block - mn) / d), 0, 31).astype(np.uint8) + # Pack 5th bit of each weight into qh (uint32, bit j = (q[j] >> 4) & 1). + qh = np.uint32(0) + for j in range(Q5_1_GROUP): + qh |= np.uint32((int(q[j]) >> 4) & 1) << j + # Pack low 4 bits: qs[j] = (q[j] & 0xF) | ((q[j+16] & 0xF) << 4). + qs = np.zeros(16, dtype=np.uint8) + for j in range(16): + qs[j] = np.uint8((int(q[j]) & 0x0F) | ((int(q[j + 16]) & 0x0F) << 4)) + d_fp16 = np.float16(d) + m_fp16 = np.float16(m) + base = b * Q5_1_BLOCK_BYTES + out[base + Q5_1_D_OFFSET : base + Q5_1_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q5_1_M_OFFSET : base + Q5_1_M_OFFSET + 2] = np.frombuffer( + m_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q5_1_QH_OFFSET : base + Q5_1_QH_OFFSET + 4] = np.frombuffer( + qh.tobytes(), dtype=np.uint8 + ) + out[base + Q5_1_QS_OFFSET : base + Q5_1_QS_OFFSET + 16] = qs + return out + + +def _quantize_q5_1_matrix(W: np.ndarray) -> np.ndarray: + assert W.ndim == 2 + assert W.shape[1] % Q5_1_GROUP == 0 + out_dim, in_dim = W.shape + bytes_per_row = in_dim * Q5_1_BLOCK_BYTES // Q5_1_GROUP + out = np.zeros((out_dim, bytes_per_row), dtype=np.uint8) + for i in range(out_dim): + out[i] = _quantize_q5_1_row(W[i]) + return out + + +def _dequantize_q5_1_matrix(W_q: np.ndarray, in_dim: int) -> np.ndarray: + """Reference Q5_1 dequantization.""" + assert W_q.dtype == np.uint8 + out_dim = W_q.shape[0] + bytes_per_row = W_q.shape[1] + assert bytes_per_row == in_dim * Q5_1_BLOCK_BYTES // Q5_1_GROUP + n_blocks_per_row = in_dim // Q5_1_GROUP + out = np.zeros((out_dim, in_dim), dtype=np.float32) + for i in range(out_dim): + row = W_q[i] + for b in range(n_blocks_per_row): + base = b * Q5_1_BLOCK_BYTES + d = float( + np.frombuffer( + row[base + Q5_1_D_OFFSET : base + Q5_1_D_OFFSET + 2].tobytes(), + dtype=np.float16, + )[0] + ) + m = float( + np.frombuffer( + row[base + Q5_1_M_OFFSET : base + Q5_1_M_OFFSET + 2].tobytes(), + dtype=np.float16, + )[0] + ) + qh = int( + np.frombuffer( + row[base + Q5_1_QH_OFFSET : base + Q5_1_QH_OFFSET + 4].tobytes(), + dtype=np.uint32, + )[0] + ) + qs = np.frombuffer( + row[base + Q5_1_QS_OFFSET : base + Q5_1_QS_OFFSET + 16].tobytes(), + dtype=np.uint8, + ) + for j in range(16): + xh_0 = ((qh >> j) << 4) & 0x10 + xh_1 = ((qh >> (j + 12))) & 0x10 + x0 = (int(qs[j]) & 0x0F) | xh_0 + x1 = (int(qs[j]) >> 4) | xh_1 + out[i, b * Q5_1_GROUP + j] = d * x0 + m + out[i, b * Q5_1_GROUP + j + 16] = d * x1 + m + return out + + +def _kquant_matmul_q5_1(x: mx.array, w_packed_np: np.ndarray) -> mx.array: + """Wrap mx.quantized_matmul with mode='kquant' and Q5_1 (gs=32, bits=5).""" + return mx.quantized_matmul( + x, + mx.array(w_packed_np), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=32, + bits=5, + mode="kquant", + kquant_type="q5_1", + ) + + +# Reasonable per-dtype tolerances for Q8_0 matmul vs fp32 reference. +# Q8_0 itself is <=1/127 relative quantization error per block; fp16 +# accumulation in the kernel adds another ~1e-4 relative. +_MATMUL_REL_TOL = { + mx.float32: 1e-4, + mx.float16: 5e-3, + mx.bfloat16: 5e-2, # bf16 has ~7-bit mantissa -> looser tolerance +} + + +class _KQuantCodecTestMixin: + """Shared test helpers and standard test methods for K-quant codecs. + + Concrete test classes inherit from this mixin and set the following + class attributes to wire up codec-specific quantize/dequantize/matmul: + + quantize_matrix -- staticmethod: (W: np.ndarray) -> np.ndarray + dequantize_matrix -- staticmethod: (W_q: np.ndarray, in_dim: int) -> np.ndarray + matmul_fn -- staticmethod: (x: mx.array, W_q: np.ndarray) -> mx.array + group_size -- int (32 or 256) + bits -- int + block_bytes -- int (wire-format bytes per block) + kquant_type -- str (e.g. "q8_0", "q4_k") + """ + + # Default dimensions -- overridden by gs=256 codecs. + general_out_dim = 64 + general_in_dim = 768 + qmm_t_shapes = ((4, 64, 1024), (17, 48, 1024), (64, 64, 2048)) + qmm_n_shapes = ((8, 64, 1024), (17, 64, 1024)) + + def _check_dequant_via_one_hot(self, out_dim, in_dim, dtype): + rng = np.random.default_rng(42) + W = rng.standard_normal((out_dim, in_dim)).astype(np.float32) * 0.5 + W_q = self.quantize_matrix(W) + W_ref = self.dequantize_matrix(W_q, in_dim) + cols = [0, 1, in_dim // 2, in_dim - 1] + for col in cols: + x_np = np.zeros(in_dim, dtype=np.float32) + x_np[col] = 1.0 + x = mx.array(x_np) + if dtype != mx.float32: + x = x.astype(dtype) + y = self.matmul_fn(x, W_q) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + ref = W_ref[:, col] + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + self.assertLess( + rel, + 1e-2, + msg=f"dtype={dtype} col={col}: rel={rel:.3e}", + ) + + def _check_random_matmul(self, out_dim, in_dim, dtype): + rng = np.random.default_rng(7) + W = rng.standard_normal((out_dim, in_dim)).astype(np.float32) * 0.3 + W_q = self.quantize_matrix(W) + W_ref = self.dequantize_matrix(W_q, in_dim) + x_np = rng.standard_normal((in_dim,)).astype(np.float32) + ref = W_ref @ x_np + + x = mx.array(x_np) + if dtype != mx.float32: + x = x.astype(dtype) + y = self.matmul_fn(x, W_q) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + out_dtype = mx.bfloat16 if dtype == mx.float32 else dtype + tol = _MATMUL_REL_TOL[out_dtype] + self.assertLess( + rel, + tol, + msg=f"out={out_dim} in={in_dim} dtype={dtype}: rel={rel:.3e} tol={tol:.0e}", + ) + + def _check_qmm_t(self, M, N, K, dtype): + rng = np.random.default_rng(7) + W = rng.standard_normal((N, K)).astype(np.float32) * 0.3 + W_q = self.quantize_matrix(W) + W_ref = self.dequantize_matrix(W_q, K) + X_np = rng.standard_normal((M, K)).astype(np.float32) + ref = X_np @ W_ref.T + x = mx.array(X_np) + if dtype != mx.float32: + x = x.astype(dtype) + y = self.matmul_fn(x, W_q) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + out_dtype = mx.bfloat16 if dtype == mx.float32 else dtype + tol = _MATMUL_REL_TOL[out_dtype] + self.assertLess( + rel, + tol, + msg=f"M={M} N={N} K={K} dtype={dtype}: rel={rel:.3e} tol={tol:.0e}", + ) + + def _check_qmm_n(self, M, N, K, dtype): + rng = np.random.default_rng(7) + W = rng.standard_normal((K, N)).astype(np.float32) * 0.3 + W_q = self.quantize_matrix(W) + W_ref = self.dequantize_matrix(W_q, N) + X_np = rng.standard_normal((M, K)).astype(np.float32) + ref = X_np @ W_ref + x = mx.array(X_np) + if dtype != mx.float32: + x = x.astype(dtype) + y = mx.quantized_matmul( + x, + mx.array(W_q), + scales=_scales_placeholder(), + biases=None, + transpose=False, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + out_dtype = mx.bfloat16 if dtype == mx.float32 else dtype + tol = _MATMUL_REL_TOL[out_dtype] + self.assertLess( + rel, + tol, + msg=f"M={M} N={N} K={K} dtype={dtype}: rel={rel:.3e} tol={tol:.0e}", + ) + + def test_dequantize_via_one_hot(self): + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(dtype=dtype): + self._check_dequant_via_one_hot(8, 1024, dtype) + + def test_dequantize(self): + """mx.dequantize(mode='kquant') reproduces the codec reference.""" + rng = np.random.default_rng(42) + N, K = 64, 1024 + W = rng.standard_normal((N, K)).astype(np.float32) * 0.3 + W_q = self.quantize_matrix(W) + W_ref = self.dequantize_matrix(W_q, K) + # Per-element cast precision: fp32 is exact, fp16 keeps ~10 bits, + # bf16 keeps ~7 bits of mantissa. + tol_by_dtype = { + mx.float32: 1e-4, + mx.float16: 1e-3, + mx.bfloat16: 5e-3, + } + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(dtype=dtype): + y = mx.dequantize( + mx.array(W_q), + scales=_scales_placeholder(), + biases=None, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + dtype=dtype, + ) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(W_ref)))) + rel = float(np.max(np.abs(y_np - W_ref))) / denom + tol = tol_by_dtype[dtype] + self.assertLess( + rel, + tol, + msg=f"dtype={dtype}: rel={rel:.3e} tol={tol:.0e}", + ) + + def test_random_matmul_fast_path(self): + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(dtype=dtype): + self._check_random_matmul(64, 1024, dtype) + + def test_random_matmul_general_path(self): + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(dtype=dtype): + self._check_random_matmul( + self.general_out_dim, self.general_in_dim, dtype + ) + + def test_random_matmul_small_n(self): + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(dtype=dtype): + self._check_random_matmul(8, 1024, dtype) + + def test_qmm_t(self): + for M, N, K in self.qmm_t_shapes: + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(M=M, N=N, K=K, dtype=dtype): + self._check_qmm_t(M, N, K, dtype) + + def test_qmm_n(self): + for M, N, K in self.qmm_n_shapes: + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(M=M, N=N, K=K, dtype=dtype): + self._check_qmm_n(M, N, K, dtype) + + def test_qmm_n_large_k(self): + """transpose=False with K>=2048 exercises splitk/large-reduction paths.""" + K = 2048 + N = max(256, self.group_size) + for M in (1, 8): + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(M=M, N=N, K=K, dtype=dtype): + self._check_qmm_n(M, N, K, dtype) + + def test_non_multiple_of_8_output_dim(self): + """Non-aligned N (not divisible by 8) exercises edge-tile handling. + + Only tests transpose=True (qmv and qmm_t) where N is the output dim. + For transpose=False, N is the quantized axis and must be block-aligned. + """ + K = 1024 + for N in (13, 33): + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(N=N, dtype=dtype): + self._check_qmm_t(1, N, K, dtype) + self._check_qmm_t(10, N, K, dtype) + + def test_gather_qmm(self): + """Indexed transpose=True matmul (MoE): y[b] = x[lhs[b]] @ W[rhs[b]].T. + + Exercises both gather_qmv (M=1) and gather_qmm_t (M=32) paths. + """ + rng = np.random.default_rng(42) + E = 4 # number of experts + B = 8 # batch positions + K = 1024 # multiple of both 32 and 256 + N = 64 + Ws_q = [] + Ws_ref = [] + for _ in range(E): + W = rng.standard_normal((N, K)).astype(np.float32) * 0.3 + W_q = self.quantize_matrix(W) + Ws_q.append(W_q) + Ws_ref.append(self.dequantize_matrix(W_q, K)) + w_stacked = np.stack(Ws_q, axis=0) # (E, N, bytes_per_row) + rhs_idx = rng.integers(0, E, size=B).astype(np.uint32) + lhs_idx = np.arange(B, dtype=np.uint32) + for M in (1, 32): + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(M=M, dtype=dtype): + X = rng.standard_normal((B, M, K)).astype(np.float32) + ref = np.zeros((B, M, N), dtype=np.float32) + for b in range(B): + ref[b] = X[b] @ Ws_ref[rhs_idx[b]].T + x = mx.array(X) + if dtype != mx.float32: + x = x.astype(dtype) + y = mx.gather_qmm( + x, + mx.array(w_stacked), + scales=mx.zeros((E, 1, 1), dtype=mx.uint8), + biases=None, + lhs_indices=mx.array(lhs_idx), + rhs_indices=mx.array(rhs_idx), + transpose=True, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + out_dtype = mx.bfloat16 if dtype == mx.float32 else dtype + tol = _MATMUL_REL_TOL[out_dtype] + self.assertLess( + rel, + tol, + msg=f"M={M} dtype={dtype}: rel={rel:.3e} tol={tol:.0e}", + ) + + def test_gather_qvm(self): + """Indexed transpose=False matmul (MoE): y[b] = x[lhs[b]] @ W[rhs[b]]. + + Exercises the gather_qmm_n path for KQuant (transpose=False, small M). + """ + rng = np.random.default_rng(43) + E = 4 + B = 8 + K = 256 + N = 512 + Ws_q = [] + Ws_ref = [] + for _ in range(E): + W = rng.standard_normal((K, N)).astype(np.float32) * 0.3 + W_q = self.quantize_matrix(W) + Ws_q.append(W_q) + Ws_ref.append(self.dequantize_matrix(W_q, N)) + w_stacked = np.stack(Ws_q, axis=0) + rhs_idx = rng.integers(0, E, size=B).astype(np.uint32) + lhs_idx = np.arange(B, dtype=np.uint32) + for M in (1, 2): + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(M=M, dtype=dtype): + X = rng.standard_normal((B, M, K)).astype(np.float32) + ref = np.zeros((B, M, N), dtype=np.float32) + for b in range(B): + ref[b] = X[b] @ Ws_ref[rhs_idx[b]] + x = mx.array(X) + if dtype != mx.float32: + x = x.astype(dtype) + y = mx.gather_qmm( + x, + mx.array(w_stacked), + scales=mx.zeros((E, 1, 1), dtype=mx.uint8), + biases=None, + lhs_indices=mx.array(lhs_idx), + rhs_indices=mx.array(rhs_idx), + transpose=False, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + out_dtype = mx.bfloat16 if dtype == mx.float32 else dtype + tol = _MATMUL_REL_TOL[out_dtype] + self.assertLess( + rel, + tol, + msg=f"M={M} dtype={dtype}: rel={rel:.3e} tol={tol:.0e}", + ) + + def test_batched_qmm(self): + """Batched (B>1) matmul with 3D weight: out[b] = x[b] @ W[b].T (or @ W[b]). + + Exercises the batched dispatch path (non_batched=false, B>1) across + the qmv decode (M=1), qmm_t prefill (M=64), and qmm_n (transpose=False) + kernels. K=1024 is divisible by every codec's group size. + """ + rng = np.random.default_rng(13) + B = 3 + K = 1024 + # For transpose=False the inner dim of the call is the second axis of + # the (out, in) quantized weight; that must be a multiple of group_size. + N_t = 128 if self.group_size <= 32 else self.group_size + N_n = max(128, self.group_size) + + cases = [ + (1, True, N_t), # qmv decode path + (64, True, N_t), # qmm_t prefill path (M >= vector_limit) + (4, False, N_n), # qmm_n path (transpose=False, small M) + ] + + for M, transpose, N in cases: + for dtype in (mx.float32, mx.bfloat16): + with self.subTest(M=M, transpose=transpose, N=N, dtype=dtype): + Ws_q = [] + Ws_ref = [] + for _ in range(B): + if transpose: + W = rng.standard_normal((N, K)).astype(np.float32) * 0.3 + inner = K + else: + W = rng.standard_normal((K, N)).astype(np.float32) * 0.3 + inner = N + W_q = self.quantize_matrix(W) + Ws_q.append(W_q) + Ws_ref.append(self.dequantize_matrix(W_q, inner)) + w_stacked = np.stack(Ws_q, axis=0) + + X = rng.standard_normal((B, M, K)).astype(np.float32) + ref = np.zeros((B, M, N), dtype=np.float32) + for b in range(B): + if transpose: + ref[b] = X[b] @ Ws_ref[b].T + else: + ref[b] = X[b] @ Ws_ref[b] + + x = mx.array(X) + if dtype != mx.float32: + x = x.astype(dtype) + y = mx.quantized_matmul( + x, + mx.array(w_stacked), + scales=mx.zeros((B, 1, 1), dtype=mx.uint8), + biases=None, + transpose=transpose, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + out_dtype = mx.bfloat16 if dtype == mx.float32 else dtype + tol = _MATMUL_REL_TOL[out_dtype] + self.assertLess( + rel, + tol, + msg=( + f"M={M} transpose={transpose} N={N} dtype={dtype}: " + f"rel={rel:.3e} tol={tol:.0e}" + ), + ) + + def test_partial_block_weight_rejected(self): + bad_w = mx.zeros((4, self.block_bytes // 2), dtype=mx.uint8) + x = mx.zeros((1, self.group_size // 2), dtype=mx.float16) + scales = _scales_placeholder() + with self.assertRaisesRegex((RuntimeError, ValueError), "whole number of"): + mx.quantized_matmul( + x, + bad_w, + scales=scales, + biases=None, + transpose=True, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + + def test_wrong_weight_dtype_rejected(self): + w = mx.zeros((4, self.block_bytes), dtype=mx.uint32) + x = mx.zeros((1, self.group_size), dtype=mx.float16) + scales = _scales_placeholder() + with self.assertRaisesRegex((RuntimeError, ValueError), "must be uint8"): + mx.quantized_matmul( + x, + w, + scales=scales, + biases=None, + transpose=True, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + + # Tolerance for round-trip relative error per codec. Codecs absent + # from this dict have no MLX encoder yet and skip the round-trip test. + # K-codec entries are added as each encoder lands. Tolerances reflect + # the codec's inherent quantization error on N(0, 0.3^2) input. + _ENCODE_ROUND_TRIP_TOL = { + "q8_0": 5e-3, + "q4_0": 8e-2, + "q4_1": 8e-2, + "q5_0": 4e-2, + "q5_1": 4e-2, + "q6_k": 2e-2, + "q4_k": 8e-2, + "q5_k": 4e-2, + "q3_k": 1.3e-1, + "q2_k": 2.6e-1, + } + + def test_quantize_round_trip(self): + """mx.quantize(mode='kquant') -> mx.dequantize round-trip within tol.""" + if self.kquant_type not in self._ENCODE_ROUND_TRIP_TOL: + self.skipTest(f"MLX encoder for {self.kquant_type} not implemented") + rng = np.random.default_rng(42) + N, K = 8, 1024 + W_np = rng.standard_normal((N, K)).astype(np.float32) * 0.3 + tol = self._ENCODE_ROUND_TRIP_TOL[self.kquant_type] + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(dtype=dtype): + w = mx.array(W_np).astype(dtype) + wq, scales = mx.quantize( + w, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + mx.eval(wq, scales) + self.assertEqual(wq.dtype, mx.uint8) + bytes_per_row = K * self.block_bytes // self.group_size + self.assertEqual(wq.shape, (N, bytes_per_row)) + w_dec = mx.dequantize( + wq, + scales=scales, + biases=None, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + dtype=mx.float32, + ) + mx.eval(w_dec) + y_np = np.asarray(w_dec).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(W_np)))) + rel = float(np.max(np.abs(y_np - W_np))) / denom + self.assertLess( + rel, + tol, + msg=f"dtype={dtype}: rel={rel:.3e} tol={tol:.0e}", + ) + + def test_quantize_missing_type_rejected(self): + """mx.quantize(mode='kquant') without kquant_type must raise.""" + w = mx.zeros((4, 64), dtype=mx.float32) + with self.assertRaisesRegex((RuntimeError, ValueError), "kquant_type"): + mx.quantize(w, group_size=32, bits=8, mode="kquant") + + # Codecs that consume an imatrix during encoding. Q8_0 is symmetric + # amax-only and ignores imatrix; that's tested by the identity test + # below (with-imatrix output equals without-imatrix output). + _IMATRIX_AWARE_CODECS = {"q2_k", "q3_k", "q4_k", "q5_k", "q6_k"} + + def test_quantize_imatrix_none_matches_no_arg(self): + """imatrix=None must produce byte-identical output to omitting the arg. + Both code paths must converge.""" + if self.kquant_type not in self._ENCODE_ROUND_TRIP_TOL: + self.skipTest(f"MLX encoder for {self.kquant_type} not implemented") + rng = np.random.default_rng(11) + K = max(self.group_size, 256) + W = rng.standard_normal((4, K)).astype(np.float32) * 0.3 + wq_a, _ = mx.quantize( + mx.array(W), + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + wq_b, _ = mx.quantize( + mx.array(W), + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + imatrix=None, + ) + mx.eval(wq_a, wq_b) + np.testing.assert_array_equal( + np.asarray(wq_a), + np.asarray(wq_b), + err_msg=f"{self.kquant_type}: imatrix=None must equal no-imatrix path", + ) + + def test_quantize_imatrix_improves_high_importance(self): + """A spiky imatrix should lower error on the high-importance range. + + The spike spans a super-block boundary (cols 200..400 crosses the + 256 boundary), so the imatrix path is exercised across multiple + super-blocks per row rather than only sub-blocks within one. + """ + if self.kquant_type not in self._IMATRIX_AWARE_CODECS: + self.skipTest(f"{self.kquant_type} does not consume imatrix during encode") + rng = np.random.default_rng(7) + K = 1024 + W = rng.standard_normal((8, K)).astype(np.float32) * 0.3 + imat_np = np.ones(K, dtype=np.float32) + imat_np[200:400] = 10.0 # crosses super-block boundary at 256 + imat = mx.array(imat_np) + wq_no, _ = mx.quantize( + mx.array(W), + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + wq_im, _ = mx.quantize( + mx.array(W), + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + imatrix=imat, + ) + mx.eval(wq_no, wq_im) + w_no = mx.dequantize( + wq_no, + scales=mx.zeros((1,), dtype=mx.uint8), + biases=None, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + dtype=mx.float32, + ) + w_im = mx.dequantize( + wq_im, + scales=mx.zeros((1,), dtype=mx.uint8), + biases=None, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + dtype=mx.float32, + ) + mx.eval(w_no, w_im) + w_no_np = np.asarray(w_no) + w_im_np = np.asarray(w_im) + err_no_hi = float(np.mean(np.abs(W[:, 200:400] - w_no_np[:, 200:400]))) + err_im_hi = float(np.mean(np.abs(W[:, 200:400] - w_im_np[:, 200:400]))) + self.assertLess( + err_im_hi, + err_no_hi, + msg=( + f"{self.kquant_type}: imatrix should improve high-importance " + f"err: no-imatrix={err_no_hi:.4e}, with-imatrix={err_im_hi:.4e}" + ), + ) + + def test_quantize_imatrix_wrong_shape_rejected(self): + if self.kquant_type not in self._ENCODE_ROUND_TRIP_TOL: + self.skipTest(f"MLX encoder for {self.kquant_type} not implemented") + K = max(self.group_size, 256) + w = mx.zeros((4, K), dtype=mx.float32) + bad_im = mx.zeros((K - 1,), dtype=mx.float32) + with self.assertRaisesRegex((RuntimeError, ValueError), "imatrix shape"): + wq, _ = mx.quantize( + w, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + imatrix=bad_im, + ) + mx.eval(wq) + + def test_quantize_imatrix_wrong_dtype_rejected(self): + if self.kquant_type not in self._ENCODE_ROUND_TRIP_TOL: + self.skipTest(f"MLX encoder for {self.kquant_type} not implemented") + K = max(self.group_size, 256) + w = mx.zeros((4, K), dtype=mx.float32) + bad_im = mx.zeros((K,), dtype=mx.float16) + with self.assertRaisesRegex( + (RuntimeError, ValueError), "imatrix must be float32" + ): + wq, _ = mx.quantize( + w, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + imatrix=bad_im, + ) + mx.eval(wq) + + +class TestKQuant(_KQuantCodecTestMixin, mlx_tests.MLXTestCase): + """End-to-end tests for K-quant Q8_0 (mode='kquant', gs=32, bits=8). + + Validates the full dispatch chain Python -> ops.cpp -> Metal kernel + (kq_q8_0_qmv / kq_q8_0_qmv_fast). + """ + + quantize_matrix = staticmethod(_quantize_q8_0_matrix) + dequantize_matrix = staticmethod(_dequantize_q8_0_matrix) + matmul_fn = staticmethod(_kquant_matmul) + group_size = 32 + bits = 8 + block_bytes = Q8_0_BLOCK_BYTES + kquant_type = "q8_0" + + def test_random_matmul_large(self): + # 4096 x 2048 fp16 -- closer to a real generation-time linear layer. + self._check_random_matmul(4096, 2048, mx.float16) + + def test_quantize_bit_exact(self): + """Q8_0 has no importance weighting (_ref == _impl), so the Metal + encoder must produce byte-identical output to the Python reference + encoder for fp32 inputs.""" + rng = np.random.default_rng(42) + N, K = 8, 1024 + W = rng.standard_normal((N, K)).astype(np.float32) * 0.3 + wq_ref = _quantize_q8_0_matrix(W) + wq, _ = mx.quantize( + mx.array(W), + group_size=32, + bits=8, + mode="kquant", + kquant_type="q8_0", + ) + mx.eval(wq) + wq_metal = np.asarray(wq) + np.testing.assert_array_equal( + wq_metal, + wq_ref, + err_msg="Q8_0 Metal encoder must match Python reference bit-exactly", + ) + + def test_quantize_all_zero_block(self): + """A block of all zeros must produce d=0 and qs=0 (no NaN).""" + W = np.zeros((1, 64), dtype=np.float32) + wq, _ = mx.quantize( + mx.array(W), + group_size=32, + bits=8, + mode="kquant", + kquant_type="q8_0", + ) + mx.eval(wq) + wq_bytes = np.asarray(wq).flatten() + # 2 blocks x 34 bytes = 68 bytes; all should be zero. + self.assertEqual(wq_bytes.shape, (68,)) + self.assertTrue((wq_bytes == 0).all()) + + def test_qvm_routes_to_qmm_n(self): + """transpose=False with small M previously threw NYI; should now route + through qmm_n and produce correct results matching the qmm_n path. + """ + rng = np.random.default_rng(7) + K, N = 1024, 64 + for M in (1, 2, 3): + W = rng.standard_normal((K, N)).astype(np.float32) * 0.3 + W_q = self.quantize_matrix(W) + W_ref = self.dequantize_matrix(W_q, N) + X_np = rng.standard_normal((M, K)).astype(np.float32) + ref = X_np @ W_ref + x = mx.array(X_np).astype(mx.float16) + y = mx.quantized_matmul( + x, + mx.array(W_q), + scales=_scales_placeholder(), + biases=None, + transpose=False, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + self.assertLess( + rel, _MATMUL_REL_TOL[mx.float16], msg=f"M={M}: rel={rel:.3e}" + ) + + def test_unknown_codec_rejected(self): + # An unknown kquant_type string must be rejected -- this is the + # post-refactor analog of the old (gs, bits) lookup miss. + w = mx.zeros((4, 34), dtype=mx.uint8) + x = mx.zeros((1, 32), dtype=mx.float16) + scales = _scales_placeholder() + with self.assertRaisesRegex((RuntimeError, ValueError), "Unknown kquant_type"): + mx.quantized_matmul( + x, + w, + scales=scales, + biases=None, + transpose=True, + group_size=32, + bits=8, + mode="kquant", + kquant_type="q9_z", # not a real codec + ) + + def test_missing_kquant_type_rejected(self): + # mode='kquant' without a kquant_type string must error -- there is + # no fallback because (gs, bits) does not uniquely identify a codec. + w = mx.zeros((4, 34), dtype=mx.uint8) + x = mx.zeros((1, 32), dtype=mx.float16) + scales = _scales_placeholder() + with self.assertRaisesRegex( + (RuntimeError, ValueError), "kquant mode requires kquant_type" + ): + mx.quantized_matmul( + x, + w, + scales=scales, + biases=None, + transpose=True, + group_size=32, + bits=8, + mode="kquant", + ) + + def test_nax_k_not_aligned_falls_back_to_alu(self): + """NAX requires K % 64 == 0. With K=96 (multiple of 32 but not 64) the + kquant dispatch must take the ALU path and still produce correct + results. Catches the case where the NAX gate is missing the K%64 + guard or where the ALU fallback regresses.""" + rng = np.random.default_rng(0) + K = 96 + M, N = 4, 32 + W = rng.standard_normal((N, K)).astype(np.float32) * 0.2 + W_q = self.quantize_matrix(W) + W_ref = self.dequantize_matrix(W_q, K) + x_np = rng.standard_normal((M, K)).astype(np.float32) + ref = x_np @ W_ref.T + x = mx.array(x_np).astype(mx.float16) + y = mx.quantized_matmul( + x, + mx.array(W_q), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=self.group_size, + bits=self.bits, + mode="kquant", + kquant_type=self.kquant_type, + ) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + self.assertLess(rel, _MATMUL_REL_TOL[mx.float16]) + + +class TestKQuantQ4_0(_KQuantCodecTestMixin, mlx_tests.MLXTestCase): + """End-to-end tests for K-quant Q4_0 (mode='kquant', gs=32, bits=4). + + Q4_0 is the simplest 4-bit codec: per-block fp16 scale, symmetric + nibble pack with -8 centering, no min. Same threading geometry as + Q8_0 -- split-half nibble layout means each thread's 8-weight slice + falls entirely in either the low or high nibble half. + """ + + quantize_matrix = staticmethod(_quantize_q4_0_matrix) + dequantize_matrix = staticmethod(_dequantize_q4_0_matrix) + matmul_fn = staticmethod(_kquant_matmul_q4_0) + group_size = 32 + bits = 4 + block_bytes = Q4_0_BLOCK_BYTES + kquant_type = "q4_0" + + +class TestKQuantQ4_1(_KQuantCodecTestMixin, mlx_tests.MLXTestCase): + """End-to-end tests for K-quant Q4_1 (mode='kquant', gs=32, bits=4). + + Q4_1 is asymmetric: per-block fp16 (d, m) envelope and a 4-bit nibble + pack with split-half layout. Shares (group_size=32, bits=4) with Q4_0; + the codec-name dispatch is the only thing distinguishing them. + """ + + quantize_matrix = staticmethod(_quantize_q4_1_matrix) + dequantize_matrix = staticmethod(_dequantize_q4_1_matrix) + matmul_fn = staticmethod(_kquant_matmul_q4_1) + group_size = 32 + bits = 4 + block_bytes = Q4_1_BLOCK_BYTES + kquant_type = "q4_1" + + +class TestKQuantQ5_0(_KQuantCodecTestMixin, mlx_tests.MLXTestCase): + """End-to-end tests for K-quant Q5_0 (mode='kquant', gs=32, bits=5). + + Q5_0 is symmetric: per-block fp16 d, no min, 5-bit weights centered + at 16. qh extraction is bit-identical to Q5_1's; the (gs, bits) tuple + collides with Q5_1 -- only the codec name distinguishes them. + """ + + quantize_matrix = staticmethod(_quantize_q5_0_matrix) + dequantize_matrix = staticmethod(_dequantize_q5_0_matrix) + matmul_fn = staticmethod(_kquant_matmul_q5_0) + group_size = 32 + bits = 5 + block_bytes = Q5_0_BLOCK_BYTES + kquant_type = "q5_0" + + +class TestKQuantQ5_1(_KQuantCodecTestMixin, mlx_tests.MLXTestCase): + """End-to-end tests for K-quant Q5_1 (mode='kquant', gs=32, bits=5). + + Validates the full dispatch chain Python -> ops.cpp -> Metal kernel + (kq_q5_1_qmv / kq_q5_1_qmv_fast). Q5_1 is asymmetric (d, m) with + 5-bit unsigned weights; the kernel uses factored accumulation + (d * sum(x*q) + m * sum(x)) so the m-term is row-independent. + """ + + quantize_matrix = staticmethod(_quantize_q5_1_matrix) + dequantize_matrix = staticmethod(_dequantize_q5_1_matrix) + matmul_fn = staticmethod(_kquant_matmul_q5_1) + group_size = 32 + bits = 5 + block_bytes = Q5_1_BLOCK_BYTES + kquant_type = "q5_1" + + def test_random_matmul_large(self): + # 4096 x 2048 fp16 -- closer to a real generation-time linear layer. + self._check_random_matmul(4096, 2048, mx.float16) + + +QK_K = 256 +K_SCALE_SIZE = 12 +Q4_K_BLOCK_BYTES = 144 # fp16 d (2) + fp16 dmin (2) + scales[12] + qs[128] +Q4_K_D_OFFSET = 0 +Q4_K_DMIN_OFFSET = 2 +Q4_K_SCALES_OFFSET = 4 +Q4_K_QS_OFFSET = 16 + + +def _unpack_scale_min_q4k(scales12: np.ndarray): + """Unpack 6-bit sub-scale/min pairs from the 12-byte field. Returns (sc8, mn8), + each shape (n_blocks, 8), uint8 in [0, 63].""" + s = scales12.astype(np.uint8) + n = s.shape[0] + sc = np.empty((n, 8), dtype=np.uint8) + mn = np.empty((n, 8), dtype=np.uint8) + sc[:, 0:4] = s[:, 0:4] & 0x3F + mn[:, 0:4] = s[:, 4:8] & 0x3F + sc[:, 4:8] = (s[:, 8:12] & 0x0F) | ((s[:, 0:4] >> 6) << 4) + mn[:, 4:8] = (s[:, 8:12] >> 4) | ((s[:, 4:8] >> 6) << 4) + return sc, mn + + +def _pack_scale_min_q4k(sc6: np.ndarray, mn6: np.ndarray) -> np.ndarray: + """Inverse of _unpack_scale_min_q4k for a single super-block. sc6/mn6 are + length-8 arrays in [0, 63]; returns 12-byte uint8 array.""" + out = np.zeros(12, dtype=np.uint8) + for j in range(4): + out[j] = (sc6[j] & 0x3F) | (((sc6[j + 4] >> 4) & 0x3) << 6) + out[j + 4] = (mn6[j] & 0x3F) | (((mn6[j + 4] >> 4) & 0x3) << 6) + out[j + 8] = (sc6[j + 4] & 0x0F) | ((mn6[j + 4] & 0x0F) << 4) + return out + + +def _quantize_q4_k_row(row: np.ndarray) -> np.ndarray: + """Quantize a 1D fp32 array (length must be a multiple of 256) to Q4_K + packed wire bytes. Simple non-optimal scheme: per sub-block, sc = (max-min)/15 + and the_min = -min (>= 0 since q is non-negative). Quantize sc and the_min + to 6 bits via a per-super-block envelope (d, dmin). Sufficient for tests + because the kernel is compared against a dequant of the same wire bytes.""" + assert row.ndim == 1 + assert row.size % QK_K == 0, "Q4_K requires K % 256 == 0" + n_super = row.size // QK_K + out = np.zeros(n_super * Q4_K_BLOCK_BYTES, dtype=np.uint8) + for sb_idx in range(n_super): + super_block = ( + row[sb_idx * QK_K : (sb_idx + 1) * QK_K].astype(np.float32).reshape(8, 32) + ) + sc_local = np.zeros(8, dtype=np.float32) + mn_local = np.zeros(8, dtype=np.float32) + q4 = np.zeros((8, 32), dtype=np.uint8) + for j in range(8): + block = super_block[j] + mn_j = float(block.min()) + mx_j = float(block.max()) + if mx_j == mn_j: + sc_local[j] = 0.0 + mn_local[j] = -mn_j # the_min = -block.min() >= 0 if min <= 0 + q4[j, :] = 0 + else: + sc_local[j] = (mx_j - mn_j) / 15.0 + mn_local[j] = -mn_j + q4[j, :] = np.clip( + np.round((block - mn_j) / sc_local[j]), 0, 15 + ).astype(np.uint8) + # Quantize sc and mn to 6-bit. mn_local can be negative if block.min > 0; + # in that case the encoding cannot represent it (dmin_mn6 >= 0), so + # dequant will reconstruct y ~= d*sc*q + 0 which still spans [min..max] + # because q in [0, 15] and sc covers the full range. Edge cases (all + # negative blocks with positive mn_local clamped) are handled below. + max_sc = float(sc_local.max()) + # Clamp mn_local to >= 0 (Q4_K's dmin*mn6 term is non-negative); blocks + # whose values are entirely > 0 get mn_local < 0, which we round to 0. + mn_clamped = np.maximum(mn_local, 0.0) + max_mn = float(mn_clamped.max()) + if max_sc <= 0.0: + d = np.float32(0.0) + sc6 = np.zeros(8, dtype=np.uint8) + else: + d = np.float32(max_sc / 63.0) + sc6 = np.clip(np.round(sc_local / d), 0, 63).astype(np.uint8) + if max_mn <= 0.0: + dmin = np.float32(0.0) + mn6 = np.zeros(8, dtype=np.uint8) + else: + dmin = np.float32(max_mn / 63.0) + mn6 = np.clip(np.round(mn_clamped / dmin), 0, 63).astype(np.uint8) + + scales12 = _pack_scale_min_q4k(sc6, mn6) + + # Pack qs[128]: qs[p*32 + l] = q4[2p, l] | (q4[2p+1, l] << 4) + qs = np.zeros(128, dtype=np.uint8) + for p in range(4): + for l in range(32): + qs[p * 32 + l] = (q4[2 * p, l] & 0x0F) | ( + (q4[2 * p + 1, l] & 0x0F) << 4 + ) + + d_fp16 = np.float16(d) + dmin_fp16 = np.float16(dmin) + base = sb_idx * Q4_K_BLOCK_BYTES + out[base + Q4_K_D_OFFSET : base + Q4_K_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q4_K_DMIN_OFFSET : base + Q4_K_DMIN_OFFSET + 2] = np.frombuffer( + dmin_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q4_K_SCALES_OFFSET : base + Q4_K_SCALES_OFFSET + K_SCALE_SIZE] = ( + scales12 + ) + out[base + Q4_K_QS_OFFSET : base + Q4_K_QS_OFFSET + 128] = qs + return out + + +def _quantize_q4_k_matrix(W: np.ndarray) -> np.ndarray: + assert W.ndim == 2 + assert W.shape[1] % QK_K == 0 + out_dim, in_dim = W.shape + bytes_per_row = in_dim * Q4_K_BLOCK_BYTES // QK_K + out = np.zeros((out_dim, bytes_per_row), dtype=np.uint8) + for i in range(out_dim): + out[i] = _quantize_q4_k_row(W[i]) + return out + + +def _dequantize_q4_k_matrix(W_q: np.ndarray, in_dim: int) -> np.ndarray: + """Reference Q4_K dequantization.""" + assert W_q.dtype == np.uint8 + out_dim = W_q.shape[0] + bytes_per_row = W_q.shape[1] + assert bytes_per_row == in_dim * Q4_K_BLOCK_BYTES // QK_K + n_blocks_per_row = in_dim // QK_K + blocks = W_q.reshape(out_dim * n_blocks_per_row, Q4_K_BLOCK_BYTES) + n_blocks = blocks.shape[0] + d = blocks[:, 0:2].copy().view(np.float16).astype(np.float32).reshape(n_blocks) + dmin = blocks[:, 2:4].copy().view(np.float16).astype(np.float32).reshape(n_blocks) + sc8, mn8 = _unpack_scale_min_q4k(blocks[:, 4:16]) + qs = blocks[:, 16 : 16 + 128] + + sub_scale = d[:, None] * sc8.astype(np.float32) + sub_min = dmin[:, None] * mn8.astype(np.float32) + + qs_g = qs.reshape(n_blocks, 4, 32) + low_nib = (qs_g & 0x0F).astype(np.float32) + high_nib = (qs_g >> 4).astype(np.float32) + sub_q = np.stack([low_nib, high_nib], axis=2).reshape(n_blocks, 8, 32) + + out_flat = ( + (sub_scale[:, :, None] * sub_q - sub_min[:, :, None]) + .reshape(n_blocks * QK_K) + .astype(np.float32) + ) + return out_flat.reshape(out_dim, in_dim) + + +def _kquant_matmul_q4_k(x: mx.array, w_packed_np: np.ndarray) -> mx.array: + """Wrap mx.quantized_matmul with mode='kquant' and Q4_K (gs=256, bits=4).""" + return mx.quantized_matmul( + x, + mx.array(w_packed_np), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=256, + bits=4, + mode="kquant", + kquant_type="q4_k", + ) + + +Q5_K_BLOCK_BYTES = 176 # Q4_K + qh[32] +Q5_K_D_OFFSET = 0 +Q5_K_DMIN_OFFSET = 2 +Q5_K_SCALES_OFFSET = 4 +Q5_K_QH_OFFSET = 16 +Q5_K_QS_OFFSET = 48 + + +def _quantize_q5_k_row(row: np.ndarray) -> np.ndarray: + """Quantize fp32 row to Q5_K wire bytes. Same scheme as Q4_K but with q + in [0, 31] instead of [0, 15]; the high bit goes into qh[32] under the + transposed scheme: bit `sb` of qh[l] = high bit of weight `sb*32 + l`.""" + assert row.ndim == 1 + assert row.size % QK_K == 0, "Q5_K requires K % 256 == 0" + n_super = row.size // QK_K + out = np.zeros(n_super * Q5_K_BLOCK_BYTES, dtype=np.uint8) + for sb_idx in range(n_super): + super_block = ( + row[sb_idx * QK_K : (sb_idx + 1) * QK_K].astype(np.float32).reshape(8, 32) + ) + sc_local = np.zeros(8, dtype=np.float32) + mn_local = np.zeros(8, dtype=np.float32) + q5 = np.zeros((8, 32), dtype=np.uint8) + for j in range(8): + block = super_block[j] + mn_j = float(block.min()) + mx_j = float(block.max()) + if mx_j == mn_j: + sc_local[j] = 0.0 + mn_local[j] = -mn_j + q5[j, :] = 0 + else: + sc_local[j] = (mx_j - mn_j) / 31.0 + mn_local[j] = -mn_j + q5[j, :] = np.clip( + np.round((block - mn_j) / sc_local[j]), 0, 31 + ).astype(np.uint8) + + max_sc = float(sc_local.max()) + mn_clamped = np.maximum(mn_local, 0.0) + max_mn = float(mn_clamped.max()) + if max_sc <= 0.0: + d = np.float32(0.0) + sc6 = np.zeros(8, dtype=np.uint8) + else: + d = np.float32(max_sc / 63.0) + sc6 = np.clip(np.round(sc_local / d), 0, 63).astype(np.uint8) + if max_mn <= 0.0: + dmin = np.float32(0.0) + mn6 = np.zeros(8, dtype=np.uint8) + else: + dmin = np.float32(max_mn / 63.0) + mn6 = np.clip(np.round(mn_clamped / dmin), 0, 63).astype(np.uint8) + + scales12 = _pack_scale_min_q4k(sc6, mn6) + + # Low 4 bits -> qs[128] nibble pairs (same packing as Q4_K). + # High bit -> qh[32] under the transposed scheme: for weight at + # position [sb, l] (sub-block sb, within-sb l), bit sb of qh[l] is set. + qs = np.zeros(128, dtype=np.uint8) + qh = np.zeros(32, dtype=np.uint8) + for p in range(4): + for l in range(32): + low_even = q5[2 * p, l] & 0x0F + low_odd = q5[2 * p + 1, l] & 0x0F + qs[p * 32 + l] = low_even | (low_odd << 4) + for sb in range(8): + for l in range(32): + bit = (int(q5[sb, l]) >> 4) & 1 + qh[l] |= bit << sb + + d_fp16 = np.float16(d) + dmin_fp16 = np.float16(dmin) + base = sb_idx * Q5_K_BLOCK_BYTES + out[base + Q5_K_D_OFFSET : base + Q5_K_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q5_K_DMIN_OFFSET : base + Q5_K_DMIN_OFFSET + 2] = np.frombuffer( + dmin_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q5_K_SCALES_OFFSET : base + Q5_K_SCALES_OFFSET + K_SCALE_SIZE] = ( + scales12 + ) + out[base + Q5_K_QH_OFFSET : base + Q5_K_QH_OFFSET + 32] = qh + out[base + Q5_K_QS_OFFSET : base + Q5_K_QS_OFFSET + 128] = qs + return out + + +def _quantize_q5_k_matrix(W: np.ndarray) -> np.ndarray: + assert W.ndim == 2 + assert W.shape[1] % QK_K == 0 + out_dim, in_dim = W.shape + bytes_per_row = in_dim * Q5_K_BLOCK_BYTES // QK_K + out = np.zeros((out_dim, bytes_per_row), dtype=np.uint8) + for i in range(out_dim): + out[i] = _quantize_q5_k_row(W[i]) + return out + + +def _dequantize_q5_k_matrix(W_q: np.ndarray, in_dim: int) -> np.ndarray: + """Reference Q5_K dequantization.""" + assert W_q.dtype == np.uint8 + out_dim = W_q.shape[0] + bytes_per_row = W_q.shape[1] + assert bytes_per_row == in_dim * Q5_K_BLOCK_BYTES // QK_K + n_blocks_per_row = in_dim // QK_K + blocks = W_q.reshape(out_dim * n_blocks_per_row, Q5_K_BLOCK_BYTES) + n_blocks = blocks.shape[0] + d = blocks[:, 0:2].copy().view(np.float16).astype(np.float32).reshape(n_blocks) + dmin = blocks[:, 2:4].copy().view(np.float16).astype(np.float32).reshape(n_blocks) + sc8, mn8 = _unpack_scale_min_q4k(blocks[:, 4:16]) + qh = blocks[:, 16:48] + qs = blocks[:, 48 : 48 + 128] + + sub_scale = d[:, None] * sc8.astype(np.float32) + sub_min = dmin[:, None] * mn8.astype(np.float32) + + qs_g = qs.reshape(n_blocks, 4, 32) + low_nib = (qs_g & 0x0F).astype(np.uint8) + high_nib = (qs_g >> 4).astype(np.uint8) + low_bits = np.stack([low_nib, high_nib], axis=2).reshape(n_blocks, 8, 32) + + bit_sel = np.arange(8, dtype=np.uint8).reshape(1, 8, 1) + high_bit = (qh[:, None, :] >> bit_sel) & 0x01 + + q5 = (low_bits | (high_bit << 4)).astype(np.float32) + out_flat = ( + (sub_scale[:, :, None] * q5 - sub_min[:, :, None]) + .reshape(n_blocks * QK_K) + .astype(np.float32) + ) + return out_flat.reshape(out_dim, in_dim) + + +def _kquant_matmul_q5_k(x: mx.array, w_packed_np: np.ndarray) -> mx.array: + """Wrap mx.quantized_matmul with mode='kquant' and Q5_K (gs=256, bits=5).""" + return mx.quantized_matmul( + x, + mx.array(w_packed_np), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=256, + bits=5, + mode="kquant", + kquant_type="q5_k", + ) + + +# REVERSED field order vs Q4_K/Q5_K: payload first, envelope last. +Q6_K_BLOCK_BYTES = 210 +Q6_K_QL_OFFSET = 0 +Q6_K_QH_OFFSET = 128 +Q6_K_SCALES_OFFSET = 192 +Q6_K_D_OFFSET = 208 + + +def _quantize_q6_k_row(row: np.ndarray) -> np.ndarray: + """Quantize fp32 row to Q6_K wire bytes. Symmetric codec: 16 sub-blocks + of 16 weights, signed int8 sub-block scales, no dmin. Simplified + quantizer (positive sub-scales only) -- produces valid q6_K bytes that + round-trip through the reference dequantizer.""" + assert row.ndim == 1 + assert row.size % QK_K == 0, "Q6_K requires K % 256 == 0" + n_super = row.size // QK_K + out = np.zeros(n_super * Q6_K_BLOCK_BYTES, dtype=np.uint8) + for sb_idx in range(n_super): + super_block = ( + row[sb_idx * QK_K : (sb_idx + 1) * QK_K] + .astype(np.float32) + .reshape(16, 16) # 16 sub-blocks of 16 weights each + ) + scale_local = np.zeros(16, dtype=np.float32) + L_signed = np.zeros((16, 16), dtype=np.int32) # in [-32, 31] + for j in range(16): + block = super_block[j] + amax = float(np.max(np.abs(block))) + if amax == 0.0: + scale_local[j] = 0.0 + L_signed[j, :] = 0 + else: + scale_local[j] = amax / 31.0 + L_signed[j, :] = np.clip( + np.round(block / scale_local[j]), -32, 31 + ).astype(np.int32) + + max_sc = float(np.max(np.abs(scale_local))) + if max_sc == 0.0: + d = np.float32(0.0) + scales_i8 = np.zeros(16, dtype=np.int8) + else: + d = np.float32(max_sc / 127.0) + scales_i8 = np.clip(np.round(scale_local / d), -127, 127).astype(np.int8) + + # Encode L as unsigned 6-bit: q6_unsigned = L_signed + 32 in [0, 63]. + # Layout: super_block[j, i] = row[j*16 + i], so weight at global + # within-superblock index g = j*16 + i uses scales_i8[j] = scales_i8[g/16] + # -- matches the dequant scale assignment. + L_flat = (L_signed + 32).astype(np.uint8).reshape(QK_K) + + ql_out = np.zeros(128, dtype=np.uint8) + qh_out = np.zeros(64, dtype=np.uint8) + # Pack ql/qh: + # 2 halves x 32 columns; quadrants 0/2 share ql[0..31] (low/high + # nibble), quadrants 1/3 share ql[32..63]; all 4 share qh[0..31] + # at shifts {0, 2, 4, 6}. + for half in range(2): + h = half * 128 + for l in range(32): + ql_out[half * 64 + l + 0] = (L_flat[h + l + 0] & 0xF) | ( + (L_flat[h + l + 64] & 0xF) << 4 + ) + ql_out[half * 64 + l + 32] = (L_flat[h + l + 32] & 0xF) | ( + (L_flat[h + l + 96] & 0xF) << 4 + ) + qh_out[half * 32 + l] = ( + (L_flat[h + l + 0] >> 4) + | ((L_flat[h + l + 32] >> 4) << 2) + | ((L_flat[h + l + 64] >> 4) << 4) + | ((L_flat[h + l + 96] >> 4) << 6) + ) + + d_fp16 = np.float16(d) + base = sb_idx * Q6_K_BLOCK_BYTES + out[base + Q6_K_QL_OFFSET : base + Q6_K_QL_OFFSET + 128] = ql_out + out[base + Q6_K_QH_OFFSET : base + Q6_K_QH_OFFSET + 64] = qh_out + out[base + Q6_K_SCALES_OFFSET : base + Q6_K_SCALES_OFFSET + 16] = np.frombuffer( + scales_i8.tobytes(), dtype=np.uint8 + ) + out[base + Q6_K_D_OFFSET : base + Q6_K_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + return out + + +def _quantize_q6_k_matrix(W: np.ndarray) -> np.ndarray: + assert W.ndim == 2 + assert W.shape[1] % QK_K == 0 + out_dim, in_dim = W.shape + bytes_per_row = in_dim * Q6_K_BLOCK_BYTES // QK_K + out = np.zeros((out_dim, bytes_per_row), dtype=np.uint8) + for i in range(out_dim): + out[i] = _quantize_q6_k_row(W[i]) + return out + + +def _dequantize_q6_k_matrix(W_q: np.ndarray, in_dim: int) -> np.ndarray: + """Reference Q6_K dequantization.""" + assert W_q.dtype == np.uint8 + out_dim = W_q.shape[0] + bytes_per_row = W_q.shape[1] + assert bytes_per_row == in_dim * Q6_K_BLOCK_BYTES // QK_K + n_blocks_per_row = in_dim // QK_K + blocks = W_q.reshape(out_dim * n_blocks_per_row, Q6_K_BLOCK_BYTES) + n_blocks = blocks.shape[0] + + ql = blocks[:, 0:128] + qh = blocks[:, 128:192] + scales = blocks[:, 192:208] + d = blocks[:, 208:210].copy().view(np.float16).astype(np.float32).reshape(n_blocks) + sc16 = scales.view(np.int8).astype(np.float32) + + ql_h = ql.reshape(n_blocks, 2, 64) + qh_h = qh.reshape(n_blocks, 2, 32) + sc_h = sc16.reshape(n_blocks, 2, 8) + + out = np.empty((n_blocks, 2, 128), dtype=np.float32) + is_idx = np.arange(32) // 16 + for half_idx in range(2): + ql_half = ql_h[:, half_idx, :] + qh_half = qh_h[:, half_idx, :] + sc_half = sc_h[:, half_idx, :] + ql_lo = ql_half[:, 0:32] + ql_lo32 = ql_half[:, 32:64] + q1 = ((ql_lo & 0x0F) | (((qh_half >> 0) & 0x03) << 4)).astype( + np.int8 + ) - np.int8(32) + q2 = ((ql_lo32 & 0x0F) | (((qh_half >> 2) & 0x03) << 4)).astype( + np.int8 + ) - np.int8(32) + q3 = ((ql_lo >> 4) | (((qh_half >> 4) & 0x03) << 4)).astype(np.int8) - np.int8( + 32 + ) + q4 = ((ql_lo32 >> 4) | (((qh_half >> 6) & 0x03) << 4)).astype( + np.int8 + ) - np.int8(32) + for is_off, qq, out_slice in ( + (0, q1, slice(0, 32)), + (2, q2, slice(32, 64)), + (4, q3, slice(64, 96)), + (6, q4, slice(96, 128)), + ): + scl = sc_half[:, is_off + is_idx] + d_eff = d[:, None] * scl + out[:, half_idx, out_slice] = d_eff * qq.astype(np.float32) + + return out.reshape(n_blocks * QK_K).astype(np.float32).reshape(out_dim, in_dim) + + +def _kquant_matmul_q6_k(x: mx.array, w_packed_np: np.ndarray) -> mx.array: + """Wrap mx.quantized_matmul with mode='kquant' and Q6_K (gs=256, bits=6).""" + return mx.quantized_matmul( + x, + mx.array(w_packed_np), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=256, + bits=6, + mode="kquant", + kquant_type="q6_k", + ) + + +Q3_K_BLOCK_BYTES = 110 # hmask[32] + qs[64] + scales[12] + fp16 d (2) +Q3_K_HMASK_OFFSET = 0 +Q3_K_QS_OFFSET = 32 +Q3_K_SCALES_OFFSET = 96 +Q3_K_D_OFFSET = 108 + + +def _unpack_scale_q3k(scales12: np.ndarray) -> np.ndarray: + """Unpack 4-word bit-shuffled Q3_K scales. Given the 12-byte `scales` field, + returns a length-16 uint8 array of the per-sub-block scales in [0, 63] + (subtract 32 for the signed [-32, 31] effective scale).""" + s = scales12.astype(np.uint8) + assert s.shape == (12,) + out = np.zeros(16, dtype=np.uint8) + for k in range(4): + # quad 0: scales[0..3] -> low 4 from s[0..3], high 2 from s[8..11] bits 0-1 + out[k] = (s[k] & 0x0F) | ((s[8 + k] & 0x03) << 4) + # quad 1: scales[4..7] -> low 4 from s[4..7], high 2 from s[8..11] bits 2-3 + out[k + 4] = (s[k + 4] & 0x0F) | (((s[8 + k] >> 2) & 0x03) << 4) + # quad 2: scales[8..11] -> low 4 from s[0..3] high nibble, high 2 from s[8..11] bits 4-5 + out[k + 8] = ((s[k] >> 4) & 0x0F) | (((s[8 + k] >> 4) & 0x03) << 4) + # quad 3: scales[12..15] -> low 4 from s[4..7] high nibble, high 2 from s[8..11] bits 6-7 + out[k + 12] = ((s[k + 4] >> 4) & 0x0F) | (((s[8 + k] >> 6) & 0x03) << 4) + return out + + +def _pack_scale_q3k(sc6: np.ndarray) -> np.ndarray: + """Inverse of _unpack_scale_q3k. sc6 is a length-16 uint8 array in [0, 63]; + returns a 12-byte uint8 array of packed scales. Bijective: pack(unpack(x))==x.""" + sc = sc6.astype(np.uint8) + assert sc.shape == (16,) + out = np.zeros(12, dtype=np.uint8) + for k in range(4): + out[k] = (sc[k] & 0x0F) | ((sc[k + 8] & 0x0F) << 4) + out[k + 4] = (sc[k + 4] & 0x0F) | ((sc[k + 12] & 0x0F) << 4) + out[k + 8] = ( + ((sc[k] >> 4) & 0x03) + | (((sc[k + 4] >> 4) & 0x03) << 2) + | (((sc[k + 8] >> 4) & 0x03) << 4) + | (((sc[k + 12] >> 4) & 0x03) << 6) + ) + return out + + +def _quantize_q3_k_row(row: np.ndarray) -> np.ndarray: + """Quantize fp32 row to Q3_K wire bytes. Symmetric codec: 16 sub-blocks of + 16 weights, signed 6-bit-biased per-sub-block scales (encoded via the + bit-shuffled 12-byte `scales` field), single fp16 super-block envelope `d`. + q3 in [-4, 3] split as 2-bit qs payload + 1-bit hmask (hmask SET when q3 >= 0). + Simplified non-optimal quantizer (positive sub-scales only) -- produces valid + Q3_K bytes that round-trip through the reference dequantizer.""" + assert row.ndim == 1 + assert row.size % QK_K == 0, "Q3_K requires K % 256 == 0" + n_super = row.size // QK_K + out = np.zeros(n_super * Q3_K_BLOCK_BYTES, dtype=np.uint8) + for sb_idx in range(n_super): + super_block = ( + row[sb_idx * QK_K : (sb_idx + 1) * QK_K] + .astype(np.float32) + .reshape(16, 16) # 16 sub-blocks of 16 weights each + ) + scale_local = np.zeros(16, dtype=np.float32) + L_signed = np.zeros((16, 16), dtype=np.int32) # in [-4, 3] + for j in range(16): + block = super_block[j] + amax = float(np.max(np.abs(block))) + if amax == 0.0: + scale_local[j] = 0.0 + L_signed[j, :] = 0 + else: + scale_local[j] = amax / 4.0 # q in [-4, 3] + L_signed[j, :] = np.clip( + np.round(block / scale_local[j]), -4, 3 + ).astype(np.int32) + + max_sc = float(scale_local.max()) + if max_sc <= 0.0: + d = np.float32(0.0) + sc6 = np.full(16, 32, dtype=np.uint8) # encodes 0 + else: + d = np.float32(max_sc / 31.0) + sc_unsigned = np.clip(np.round(scale_local / d), 0, 31).astype(np.int32) + sc6 = (sc_unsigned + 32).astype(np.uint8) # bias by +32 + + scales12 = _pack_scale_q3k(sc6) + + # Pack qs[64] (2-bit payload) and hmask[32] (1-bit high mask). + # Linear weight w_idx = j*16 + i in [0, 256): + # qs_byte_idx = (w_idx // 128) * 32 + (w_idx & 31) + # qs_shift = ((w_idx // 32) & 3) * 2 + # hmask_byte = w_idx & 31 + # hmask_bit = (w_idx >> 5) & 7 + # h_bit = 1 iff q3_signed >= 0; q2 = q3_signed & 3 (covers both ranges). + L_flat = L_signed.reshape(QK_K) + q2_arr = (L_flat & 3).astype(np.uint8) + h_bit_arr = (L_flat >= 0).astype(np.uint8) + + qs = np.zeros(64, dtype=np.uint8) + hmask = np.zeros(32, dtype=np.uint8) + for w_idx in range(QK_K): + outer_half = w_idx // 128 + shift_idx = (w_idx // 32) & 3 + qs_byte_idx = outer_half * 32 + (w_idx & 31) + qs[qs_byte_idx] |= int(q2_arr[w_idx]) << (shift_idx * 2) + hmask_byte = w_idx & 31 + hmask_bit = (w_idx >> 5) & 7 + hmask[hmask_byte] |= int(h_bit_arr[w_idx]) << hmask_bit + + d_fp16 = np.float16(d) + base = sb_idx * Q3_K_BLOCK_BYTES + out[base + Q3_K_HMASK_OFFSET : base + Q3_K_HMASK_OFFSET + 32] = hmask + out[base + Q3_K_QS_OFFSET : base + Q3_K_QS_OFFSET + 64] = qs + out[base + Q3_K_SCALES_OFFSET : base + Q3_K_SCALES_OFFSET + 12] = scales12 + out[base + Q3_K_D_OFFSET : base + Q3_K_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + return out + + +def _quantize_q3_k_matrix(W: np.ndarray) -> np.ndarray: + assert W.ndim == 2 + assert W.shape[1] % QK_K == 0 + out_dim, in_dim = W.shape + bytes_per_row = in_dim * Q3_K_BLOCK_BYTES // QK_K + out = np.zeros((out_dim, bytes_per_row), dtype=np.uint8) + for i in range(out_dim): + out[i] = _quantize_q3_k_row(W[i]) + return out + + +def _dequantize_q3_k_matrix(W_q: np.ndarray, in_dim: int) -> np.ndarray: + """Reference Q3_K dequantization. w[i] = d * (scales[is] - 32) * (q2 - h), + where q2 in [0, 3] from qs and h in {0, 4} (h=0 if hmask bit SET, else 4).""" + assert W_q.dtype == np.uint8 + out_dim = W_q.shape[0] + bytes_per_row = W_q.shape[1] + assert bytes_per_row == in_dim * Q3_K_BLOCK_BYTES // QK_K + n_blocks_per_row = in_dim // QK_K + blocks = W_q.reshape(out_dim * n_blocks_per_row, Q3_K_BLOCK_BYTES) + n_blocks = blocks.shape[0] + + out = np.zeros((n_blocks, QK_K), dtype=np.float32) + for b in range(n_blocks): + hmask = blocks[b, Q3_K_HMASK_OFFSET : Q3_K_HMASK_OFFSET + 32] + qs_full = blocks[b, Q3_K_QS_OFFSET : Q3_K_QS_OFFSET + 64] + scales12 = blocks[b, Q3_K_SCALES_OFFSET : Q3_K_SCALES_OFFSET + 12] + d = float( + np.frombuffer( + blocks[b, Q3_K_D_OFFSET : Q3_K_D_OFFSET + 2].tobytes(), + dtype=np.float16, + )[0] + ) + sc16 = _unpack_scale_q3k(scales12).astype(np.int32) - 32 # signed [-32, 31] + + out_idx = 0 + for outer_half in range(2): + qs_chunk = qs_full[outer_half * 32 : (outer_half + 1) * 32] + for shift_idx in range(4): + shift = shift_idx * 2 + # m = 1 << (outer_half * 4 + shift_idx) -- same as 1 << hmask_bit + m = 1 << (outer_half * 4 + shift_idx) + is_left = outer_half * 8 + shift_idx * 2 + dl_left = d * float(sc16[is_left]) + for l in range(16): + q2 = (int(qs_chunk[l]) >> shift) & 3 + h = 0 if (int(hmask[l]) & m) else 4 + out[b, out_idx] = dl_left * (q2 - h) + out_idx += 1 + is_right = is_left + 1 + dl_right = d * float(sc16[is_right]) + for l in range(16): + q2 = (int(qs_chunk[l + 16]) >> shift) & 3 + h = 0 if (int(hmask[l + 16]) & m) else 4 + out[b, out_idx] = dl_right * (q2 - h) + out_idx += 1 + + return out.reshape(out_dim, in_dim) + + +def _kquant_matmul_q3_k(x: mx.array, w_packed_np: np.ndarray) -> mx.array: + """Wrap mx.quantized_matmul with mode='kquant' and Q3_K (gs=256, bits=3).""" + return mx.quantized_matmul( + x, + mx.array(w_packed_np), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=256, + bits=3, + mode="kquant", + kquant_type="q3_k", + ) + + +Q2_K_BLOCK_BYTES = 84 # scales[16] + qs[64] + fp16 d + fp16 dmin +Q2_K_SCALES_OFFSET = 0 +Q2_K_QS_OFFSET = 16 +Q2_K_D_OFFSET = 80 +Q2_K_DMIN_OFFSET = 82 + + +def _quantize_q2_k_row(row: np.ndarray) -> np.ndarray: + """Quantize fp32 row to Q2_K wire bytes. Asymmetric codec: 16 sub-blocks + of 16 weights, 4-bit (scale, min) nibble pair per sub-block in scales[16], + fp16 (d, dmin) super-block envelope. q2 in [0, 3] in qs (2 bits/weight, + 4 weights/byte). Simplified non-optimal quantizer -- produces valid Q2_K + bytes that round-trip through the reference dequantizer.""" + assert row.ndim == 1 + assert row.size % QK_K == 0, "Q2_K requires K % 256 == 0" + n_super = row.size // QK_K + out = np.zeros(n_super * Q2_K_BLOCK_BYTES, dtype=np.uint8) + for sb_idx in range(n_super): + super_block = ( + row[sb_idx * QK_K : (sb_idx + 1) * QK_K] + .astype(np.float32) + .reshape(16, 16) # 16 sub-blocks of 16 weights each + ) + sc_local = np.zeros(16, dtype=np.float32) + mn_local = np.zeros(16, dtype=np.float32) + q2_arr = np.zeros((16, 16), dtype=np.uint8) + for j in range(16): + block = super_block[j] + mn_j = float(block.min()) + mx_j = float(block.max()) + if mx_j == mn_j: + sc_local[j] = 0.0 + mn_local[j] = -mn_j + q2_arr[j, :] = 0 + else: + sc_local[j] = (mx_j - mn_j) / 3.0 # q in [0, 3] + mn_local[j] = -mn_j + q2_arr[j, :] = np.clip( + np.round((block - mn_j) / sc_local[j]), 0, 3 + ).astype(np.uint8) + + # Quantize sub-scales and sub-mins to 4-bit unsigned. mn_local can be + # negative if all values in block are positive; the encoded min is + # non-negative so we clamp >= 0 (dequant reconstructs values in + # [0, d*sc*3] = [0, max_j], slight rounding loss for entirely positive + # blocks but acceptable for the test reference). + max_sc = float(sc_local.max()) + mn_clamped = np.maximum(mn_local, 0.0) + max_mn = float(mn_clamped.max()) + if max_sc <= 0.0: + d = np.float32(0.0) + sc4 = np.zeros(16, dtype=np.uint8) + else: + d = np.float32(max_sc / 15.0) + sc4 = np.clip(np.round(sc_local / d), 0, 15).astype(np.uint8) + if max_mn <= 0.0: + dmin = np.float32(0.0) + mn4 = np.zeros(16, dtype=np.uint8) + else: + dmin = np.float32(max_mn / 15.0) + mn4 = np.clip(np.round(mn_clamped / dmin), 0, 15).astype(np.uint8) + + scales = (sc4 | (mn4 << 4)).astype(np.uint8) # length 16 + + # Pack qs[64]: same access pattern as Q3_K (no hmask). + L_flat = q2_arr.reshape(QK_K) + qs = np.zeros(64, dtype=np.uint8) + for w_idx in range(QK_K): + outer_half = w_idx // 128 + shift_idx = (w_idx // 32) & 3 + qs_byte_idx = outer_half * 32 + (w_idx & 31) + qs[qs_byte_idx] |= int(L_flat[w_idx]) << (shift_idx * 2) + + d_fp16 = np.float16(d) + dmin_fp16 = np.float16(dmin) + base = sb_idx * Q2_K_BLOCK_BYTES + out[base + Q2_K_SCALES_OFFSET : base + Q2_K_SCALES_OFFSET + 16] = scales + out[base + Q2_K_QS_OFFSET : base + Q2_K_QS_OFFSET + 64] = qs + out[base + Q2_K_D_OFFSET : base + Q2_K_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + out[base + Q2_K_DMIN_OFFSET : base + Q2_K_DMIN_OFFSET + 2] = np.frombuffer( + dmin_fp16.tobytes(), dtype=np.uint8 + ) + return out + + +def _quantize_q2_k_matrix(W: np.ndarray) -> np.ndarray: + assert W.ndim == 2 + assert W.shape[1] % QK_K == 0 + out_dim, in_dim = W.shape + bytes_per_row = in_dim * Q2_K_BLOCK_BYTES // QK_K + out = np.zeros((out_dim, bytes_per_row), dtype=np.uint8) + for i in range(out_dim): + out[i] = _quantize_q2_k_row(W[i]) + return out + + +def _dequantize_q2_k_matrix(W_q: np.ndarray, in_dim: int) -> np.ndarray: + """Reference Q2_K dequantization. + w[i] = d * (sc & 0xF) * q2 - dmin * (sc >> 4).""" + assert W_q.dtype == np.uint8 + out_dim = W_q.shape[0] + bytes_per_row = W_q.shape[1] + assert bytes_per_row == in_dim * Q2_K_BLOCK_BYTES // QK_K + n_blocks_per_row = in_dim // QK_K + blocks = W_q.reshape(out_dim * n_blocks_per_row, Q2_K_BLOCK_BYTES) + n_blocks = blocks.shape[0] + + out = np.zeros((n_blocks, QK_K), dtype=np.float32) + for b in range(n_blocks): + scales = blocks[b, Q2_K_SCALES_OFFSET : Q2_K_SCALES_OFFSET + 16] + qs_full = blocks[b, Q2_K_QS_OFFSET : Q2_K_QS_OFFSET + 64] + d = float( + np.frombuffer( + blocks[b, Q2_K_D_OFFSET : Q2_K_D_OFFSET + 2].tobytes(), + dtype=np.float16, + )[0] + ) + dmin = float( + np.frombuffer( + blocks[b, Q2_K_DMIN_OFFSET : Q2_K_DMIN_OFFSET + 2].tobytes(), + dtype=np.float16, + )[0] + ) + + out_idx = 0 + is_idx = 0 + for outer_half in range(2): + qs_chunk = qs_full[outer_half * 32 : (outer_half + 1) * 32] + for shift_idx in range(4): + shift = shift_idx * 2 + sc_byte_left = int(scales[is_idx]) + is_idx += 1 + dl_left = d * float(sc_byte_left & 0x0F) + ml_left = dmin * float(sc_byte_left >> 4) + for l in range(16): + q2 = (int(qs_chunk[l]) >> shift) & 3 + out[b, out_idx] = dl_left * q2 - ml_left + out_idx += 1 + sc_byte_right = int(scales[is_idx]) + is_idx += 1 + dl_right = d * float(sc_byte_right & 0x0F) + ml_right = dmin * float(sc_byte_right >> 4) + for l in range(16): + q2 = (int(qs_chunk[l + 16]) >> shift) & 3 + out[b, out_idx] = dl_right * q2 - ml_right + out_idx += 1 + + return out.reshape(out_dim, in_dim) + + +def _kquant_matmul_q2_k(x: mx.array, w_packed_np: np.ndarray) -> mx.array: + """Wrap mx.quantized_matmul with mode='kquant' and Q2_K (gs=256, bits=2).""" + return mx.quantized_matmul( + x, + mx.array(w_packed_np), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=256, + bits=2, + mode="kquant", + kquant_type="q2_k", + ) + + +class TestKQuantQ4_K(_KQuantCodecTestMixin, mlx_tests.MLXTestCase): + """End-to-end tests for K-quant Q4_K (mode='kquant', gs=256, bits=4). + + Q4_K has hierarchical scales: a per-super-block (d, dmin) envelope and + 8 sub-blocks of 32 weights each, each with a 6-bit (sub-scale, sub-min) + pair packed into the 12-byte `scales` field. Validates the dispatch + chain and the kq_q4_k_qmv / kq_q4_k_qmv_fast Metal kernels. + """ + + quantize_matrix = staticmethod(_quantize_q4_k_matrix) + dequantize_matrix = staticmethod(_dequantize_q4_k_matrix) + matmul_fn = staticmethod(_kquant_matmul_q4_k) + group_size = 256 + bits = 4 + block_bytes = Q4_K_BLOCK_BYTES + kquant_type = "q4_k" + general_out_dim = 13 + general_in_dim = 256 + qmm_n_shapes = ((8, 256, 1024), (17, 256, 1024)) + + def test_random_matmul_large(self): + # 4096 x 2048 fp16 -- closer to a real generation-time linear layer. + self._check_random_matmul(4096, 2048, mx.float16) + + +class TestKQuantQ5_K(_KQuantCodecTestMixin, mlx_tests.MLXTestCase): + """End-to-end tests for K-quant Q5_K (mode='kquant', gs=256, bits=5). + + Q5_K is Q4_K plus a high-bit array `qh[32]` between the scales and the + qs payload. Each weight is 5 bits (low 4 from qs nibble, high 1 from qh). + The qh layout is transposed: bit `sb` of qh[l] is the high bit of weight + `sb*32 + l`. Validates the dispatch chain and the kq_q5_k_qmv / + kq_q5_k_qmv_fast Metal kernels. + """ + + quantize_matrix = staticmethod(_quantize_q5_k_matrix) + dequantize_matrix = staticmethod(_dequantize_q5_k_matrix) + matmul_fn = staticmethod(_kquant_matmul_q5_k) + group_size = 256 + bits = 5 + block_bytes = Q5_K_BLOCK_BYTES + kquant_type = "q5_k" + general_out_dim = 13 + general_in_dim = 256 + qmm_n_shapes = ((8, 256, 1024), (17, 256, 1024)) + + def test_random_matmul_large(self): + self._check_random_matmul(4096, 2048, mx.float16) + + +class TestKQuantQ6_K(_KQuantCodecTestMixin, mlx_tests.MLXTestCase): + """End-to-end tests for K-quant Q6_K (mode='kquant', gs=256, bits=6). + + Q6_K has 16 sub-blocks of 16 weights each, signed int8 per-sub-block + scales, and a single fp16 super-block envelope `d`. Dequant is symmetric + (no dmin): `w = d * scale[j] * (q6 - 32)`. Wire format reverses Q4_K/Q5_K + field order: payload (`ql`/`qh`) comes first, envelope (`d`) is at the + end of the 210-byte block. Validates the dispatch chain and the + kq_q6_k_qmv / kq_q6_k_qmv_fast Metal kernels. + """ + + quantize_matrix = staticmethod(_quantize_q6_k_matrix) + dequantize_matrix = staticmethod(_dequantize_q6_k_matrix) + matmul_fn = staticmethod(_kquant_matmul_q6_k) + group_size = 256 + bits = 6 + block_bytes = Q6_K_BLOCK_BYTES + kquant_type = "q6_k" + general_out_dim = 13 + general_in_dim = 256 + qmm_n_shapes = ((8, 256, 1024), (17, 256, 1024)) + + def test_random_matmul_large(self): + self._check_random_matmul(4096, 2048, mx.float16) + + +class TestKQuantQ3_K(_KQuantCodecTestMixin, mlx_tests.MLXTestCase): + """End-to-end tests for K-quant Q3_K (mode='kquant', gs=256, bits=3). + + Q3_K is symmetric (no dmin) with 16 sub-blocks of 16 weights, signed + 6-bit-biased per-sub-block scales packed via a 4-word bit-shuffle into + a 12-byte field, a 2-bit qs payload (64 bytes), and a 1-bit hmask + (32 bytes) selecting between q3 ranges [0, 3] (hmask SET) and [-4, -1] + (hmask CLEAR). Validates the dispatch chain and the kq_q3_k_qmv / + kq_q3_k_qmv_fast Metal kernels. + """ + + quantize_matrix = staticmethod(_quantize_q3_k_matrix) + dequantize_matrix = staticmethod(_dequantize_q3_k_matrix) + matmul_fn = staticmethod(_kquant_matmul_q3_k) + group_size = 256 + bits = 3 + block_bytes = Q3_K_BLOCK_BYTES + kquant_type = "q3_k" + general_out_dim = 13 + general_in_dim = 256 + qmm_t_shapes = ((4, 64, 1024), (17, 48, 1024), (64, 64, 2048), (8, 256, 1024)) + qmm_n_shapes = ((8, 256, 1024), (17, 256, 1024)) + + def test_random_matmul_large(self): + self._check_random_matmul(4096, 2048, mx.float16) + + def test_q3_k_scale_unpack_fixture(self): + """Bit-exact validation of the 4-word scale unpack vs. a hand-computed + expected output. De-risks the trickiest piece in this codec.""" + scales12 = np.array( + [0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0, 0x55, 0xAA, 0x33, 0xCC], + dtype=np.uint8, + ) + expected = np.array( + [ + 0x12, + 0x24, + 0x36, + 0x08, + 0x1A, + 0x2C, + 0x0E, + 0x30, + 0x11, + 0x23, + 0x35, + 0x07, + 0x19, + 0x2B, + 0x0D, + 0x3F, + ], + dtype=np.uint8, + ) + actual = _unpack_scale_q3k(scales12) + np.testing.assert_array_equal(actual, expected) + # Bijective round-trip: pack(unpack(x)) == x for any 12-byte input. + np.testing.assert_array_equal(_pack_scale_q3k(actual), scales12) + + def test_q3_k_hmask_inversion(self): + """Hand-crafted block: hmask all-set vs all-clear with a fixed qs/scale + pattern. Catches Metal kernel `(hm & m) ? 0 : 4` inversion bugs.""" + # d=1.0 fp16, all 16 scales = 33 (encodes effective scale = 33-32 = 1), + # qs = 0x55 (each byte holds 4 weights with q2=1 at every shift), + # hmask = 0xFF (h=0, q3=q2-0=1) vs 0x00 (h=4, q3=q2-4=-3). + sc6 = np.full(16, 33, dtype=np.uint8) + scales12 = _pack_scale_q3k(sc6) + for hmask_val, expected_w in ((0xFF, 1.0), (0x00, -3.0)): + block = np.zeros(Q3_K_BLOCK_BYTES, dtype=np.uint8) + block[Q3_K_HMASK_OFFSET : Q3_K_HMASK_OFFSET + 32] = hmask_val + block[Q3_K_QS_OFFSET : Q3_K_QS_OFFSET + 64] = 0x55 + block[Q3_K_SCALES_OFFSET : Q3_K_SCALES_OFFSET + 12] = scales12 + d_fp16 = np.float16(1.0) + block[Q3_K_D_OFFSET : Q3_K_D_OFFSET + 2] = np.frombuffer( + d_fp16.tobytes(), dtype=np.uint8 + ) + W_q = block.reshape(1, Q3_K_BLOCK_BYTES) + # Reference round-trip. + W_ref = _dequantize_q3_k_matrix(W_q, 256) + np.testing.assert_allclose(W_ref, expected_w, atol=1e-4) + # Metal kernel via qmv with x = ones (sum should be 256 * expected). + x_np = np.ones(256, dtype=np.float32) + y = _kquant_matmul_q3_k(mx.array(x_np), W_q) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + np.testing.assert_allclose( + y_np, + np.array([256.0 * expected_w]), + rtol=1e-3, + err_msg=f"hmask=0x{hmask_val:02X}", + ) + + +class TestKQuantQ2_K(_KQuantCodecTestMixin, mlx_tests.MLXTestCase): + """End-to-end tests for K-quant Q2_K (mode='kquant', gs=256, bits=2). + + Q2_K is asymmetric with 16 sub-blocks of 16 weights, 4-bit (scale, min) + nibble pairs per sub-block, fp16 (d, dmin) super-block envelope, and a + 2-bit qs payload (64 bytes). Trivial nibble-split scales -- no + bit-shuffle. Dequant: w = d * (sc & 0xF) * q2 - dmin * (sc >> 4). + Validates the dispatch chain and the kq_q2_k_qmv / kq_q2_k_qmv_fast + Metal kernels. + """ + + quantize_matrix = staticmethod(_quantize_q2_k_matrix) + dequantize_matrix = staticmethod(_dequantize_q2_k_matrix) + matmul_fn = staticmethod(_kquant_matmul_q2_k) + group_size = 256 + bits = 2 + block_bytes = Q2_K_BLOCK_BYTES + kquant_type = "q2_k" + general_out_dim = 13 + general_in_dim = 256 + qmm_t_shapes = ((4, 64, 1024), (17, 48, 1024), (64, 64, 2048), (8, 256, 1024)) + qmm_n_shapes = ((8, 256, 1024), (17, 256, 1024)) + + def test_random_matmul_large(self): + self._check_random_matmul(4096, 2048, mx.float16) + + def test_q2_k_zero_min(self): + """Hand-crafted block: dmin=0 collapses asymmetric -> symmetric. Catches + Metal kernel mishandling of the (sc >> 4) min term separate from sc.""" + # d=1, dmin=0, scales[i] = 0x01 (sc=1, mn=0), qs all 0x55 -> q2=1. + # Expected per-weight: 1 * 1 * 1 - 0 * 0 = 1.0. + block = np.zeros(Q2_K_BLOCK_BYTES, dtype=np.uint8) + block[Q2_K_SCALES_OFFSET : Q2_K_SCALES_OFFSET + 16] = 0x01 + block[Q2_K_QS_OFFSET : Q2_K_QS_OFFSET + 64] = 0x55 + block[Q2_K_D_OFFSET : Q2_K_D_OFFSET + 2] = np.frombuffer( + np.float16(1.0).tobytes(), dtype=np.uint8 + ) + block[Q2_K_DMIN_OFFSET : Q2_K_DMIN_OFFSET + 2] = np.frombuffer( + np.float16(0.0).tobytes(), dtype=np.uint8 + ) + W_q = block.reshape(1, Q2_K_BLOCK_BYTES) + W_ref = _dequantize_q2_k_matrix(W_q, 256) + np.testing.assert_allclose(W_ref, 1.0, atol=1e-4) + x_np = np.ones(256, dtype=np.float32) + y = _kquant_matmul_q2_k(mx.array(x_np), W_q) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + np.testing.assert_allclose(y_np, np.array([256.0]), rtol=1e-3) + + +class TestKQuantNax(mlx_tests.MLXTestCase): + """End-to-end correctness checks for the kquant NAX path. + + Each codec's existing per-class tests already exercise a mix of NAX- + eligible and ALU-eligible shapes (large K -> NAX, K=96 -> ALU); these + tests target NAX-specific concerns: cross-codec correctness and the + `gather_qmm_rhs_nax` route (rhs-only gather). + """ + + # (codec, group_size, bits, quantize_matrix, dequantize_matrix) + NAX_CODECS = [ + ("q8_0", 32, 8, _quantize_q8_0_matrix, _dequantize_q8_0_matrix), + ("q5_1", 32, 5, _quantize_q5_1_matrix, _dequantize_q5_1_matrix), + ("q4_k", 256, 4, _quantize_q4_k_matrix, _dequantize_q4_k_matrix), + ("q5_k", 256, 5, _quantize_q5_k_matrix, _dequantize_q5_k_matrix), + ("q6_k", 256, 6, _quantize_q6_k_matrix, _dequantize_q6_k_matrix), + ("q3_k", 256, 3, _quantize_q3_k_matrix, _dequantize_q3_k_matrix), + ("q2_k", 256, 2, _quantize_q2_k_matrix, _dequantize_q2_k_matrix), + ] + + def test_nax_matmul_correctness(self): + """For each NAX-supported codec run a NAX-eligible quantized_matmul + (K % 64 == 0, transpose=true, fp16) and compare against a numpy + reference computed from the dequantized weights. Tolerances follow + _MATMUL_REL_TOL[fp16] (5e-3) -- NAX MMA rounds slightly differently + from ALU simdgroup MMA but stays well within bound. + """ + rng = np.random.default_rng(11) + # K=512 is divisible by both 64 (NAX) and 256 (super-block codecs). + K, N, M = 512, 64, 64 + for codec, gs, bits, qfn, dqfn in self.NAX_CODECS: + W = rng.standard_normal((N, K)).astype(np.float32) * 0.25 + W_q = qfn(W) + W_ref = dqfn(W_q, K) + x_np = rng.standard_normal((M, K)).astype(np.float32) + ref = x_np @ W_ref.T + for dtype in (mx.float16, mx.bfloat16): + with self.subTest(codec=codec, dtype=dtype): + x = mx.array(x_np).astype(dtype) + y = mx.quantized_matmul( + x, + mx.array(W_q), + scales=_scales_placeholder(), + biases=None, + transpose=True, + group_size=gs, + bits=bits, + mode="kquant", + kquant_type=codec, + ) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + self.assertLess( + rel, + _MATMUL_REL_TOL[dtype], + msg=f"{codec} {dtype}: rel={rel:.3e}", + ) + + def test_nax_gather_qmm(self): + """Gather variants. `gather_qmm` with both lhs+rhs indices exercises + `gather_qmm_nax` (T variant). With only rhs_indices via mx.gather_qmm + + sorted_indices=True + M=1 the GatherQMM dispatch routes through + `gather_qmm_rhs_nax`.""" + rng = np.random.default_rng(13) + # E experts, B routed positions, K NAX-aligned, N tile-aligned. + E, B, K, N = 4, 32, 256, 64 + codec, gs, bits, qfn, dqfn = ( + "q4_k", + 256, + 4, + _quantize_q4_k_matrix, + _dequantize_q4_k_matrix, + ) + Ws_q, Ws_ref = [], [] + for _ in range(E): + W = rng.standard_normal((N, K)).astype(np.float32) * 0.3 + wq = qfn(W) + Ws_q.append(wq) + Ws_ref.append(dqfn(wq, K)) + w_stacked = mx.array(np.stack(Ws_q, axis=0)) + + # -- A. gather_qmm with both indices (M=32 -> gather_qmm_nax) -- + rhs_idx = rng.integers(0, E, size=B).astype(np.uint32) + lhs_idx = np.arange(B, dtype=np.uint32) + M = 32 + X = rng.standard_normal((B, M, K)).astype(np.float32) + ref = np.stack([X[b] @ Ws_ref[rhs_idx[b]].T for b in range(B)]) + x = mx.array(X).astype(mx.float16) + y = mx.gather_qmm( + x, + w_stacked, + scales=mx.zeros((E, 1, 1), dtype=mx.uint8), + biases=None, + lhs_indices=mx.array(lhs_idx), + rhs_indices=mx.array(rhs_idx), + transpose=True, + group_size=gs, + bits=bits, + mode="kquant", + kquant_type=codec, + ) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + self.assertLess( + rel, _MATMUL_REL_TOL[mx.float16] * 5, msg=f"gather_qmm rel={rel:.3e}" + ) + + # -- B. gather_qmm with only rhs (M=1, sorted) -> gather_qmm_rhs_nax -- + # The eval_gpu fast path requires B>=16, B/E>=4, right_sorted=True. + rhs_sorted = np.sort(rng.integers(0, E, size=B).astype(np.uint32)) + M = 1 + X = rng.standard_normal((B, M, K)).astype(np.float32) + ref_rhs = np.stack([X[b] @ Ws_ref[rhs_sorted[b]].T for b in range(B)]) + x = mx.array(X).astype(mx.float16) + y = mx.gather_qmm( + x, + w_stacked, + scales=mx.zeros((E, 1, 1), dtype=mx.uint8), + biases=None, + lhs_indices=None, + rhs_indices=mx.array(rhs_sorted), + transpose=True, + group_size=gs, + bits=bits, + mode="kquant", + kquant_type=codec, + sorted_indices=True, + ) + mx.eval(y) + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref_rhs)))) + rel = float(np.max(np.abs(y_np - ref_rhs))) / denom + self.assertLess( + rel, _MATMUL_REL_TOL[mx.float16] * 5, msg=f"gather_qmm_rhs rel={rel:.3e}" + ) + + +def _make_kquant_linear( + K, N, codec, group_size, bits, *, weight_packed_np, bias_np=None +): + """Construct an nn.QuantizedLinear with pre-built kquant wire-format weights. + + Constructs in affine mode and rewrites the layer state to kquant, since + __init__ generates random weights but tests need specific byte patterns. + """ + ql = nn.QuantizedLinear(K, N, bias=(bias_np is not None), mode="affine") + ql.mode = "kquant" + ql.kquant_type = codec + ql.group_size = group_size + ql.bits = bits + ql.weight = mx.array(weight_packed_np) + ql.scales = _scales_placeholder() + ql.biases = None + if bias_np is not None: + ql.bias = mx.array(bias_np) + return ql + + +class TestKQuantQuantizedLinear(mlx_tests.MLXTestCase): + """End-to-end tests for nn.QuantizedLinear under mode="kquant". + + The mx.* op level is already covered by per-codec test classes above; + these tests guard the layer wrapper itself so a future change to + QuantizedLinear's storage or dispatch can't silently break kquant mode. + """ + + _CODEC_MATRIX = ( + # (codec, group_size, bits, quantize_fn, dequantize_fn) + ("q8_0", 32, 8, _quantize_q8_0_matrix, _dequantize_q8_0_matrix), + ("q4_k", 256, 4, _quantize_q4_k_matrix, _dequantize_q4_k_matrix), + ("q6_k", 256, 6, _quantize_q6_k_matrix, _dequantize_q6_k_matrix), + ) + + def test_quantized_linear_forward(self): + K, N, M = 1024, 64, 4 + rng = np.random.default_rng(0) + W = rng.standard_normal((N, K)).astype(np.float32) * 0.3 + x_np = rng.standard_normal((M, K)).astype(np.float32) + + for codec, gs, bits, quantize_fn, dequantize_fn in self._CODEC_MATRIX: + W_q = quantize_fn(W) + W_ref = dequantize_fn(W_q, K) + ref = x_np @ W_ref.T # (M, N) + + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(codec=codec, dtype=dtype): + ql = _make_kquant_linear( + K, N, codec, gs, bits, weight_packed_np=W_q + ) + + x = mx.array(x_np) + if dtype != mx.float32: + x = x.astype(dtype) + y = ql(x) + mx.eval(y) + + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + out_dtype = mx.bfloat16 if dtype == mx.float32 else dtype + tol = _MATMUL_REL_TOL[out_dtype] + self.assertLess( + rel, + tol, + msg=f"codec={codec} dtype={dtype}: " + f"rel={rel:.3e} tol={tol:.0e}", + ) + + def test_quantized_linear_with_bias(self): + K, N, M = 1024, 64, 4 + rng = np.random.default_rng(1) + W = rng.standard_normal((N, K)).astype(np.float32) * 0.3 + x_np = rng.standard_normal((M, K)).astype(np.float32) + bias_np = rng.standard_normal((N,)).astype(np.float32) * 0.5 + + W_q = _quantize_q8_0_matrix(W) + W_ref = _dequantize_q8_0_matrix(W_q, K) + ref = x_np @ W_ref.T + bias_np + + for dtype in (mx.float32, mx.float16, mx.bfloat16): + with self.subTest(dtype=dtype): + ql = _make_kquant_linear( + K, N, "q8_0", 32, 8, weight_packed_np=W_q, bias_np=bias_np + ) + self.assertIn("bias", ql) + + x = mx.array(x_np) + if dtype != mx.float32: + x = x.astype(dtype) + y = ql(x) + mx.eval(y) + + y_np = np.asarray(y.astype(mx.float32)).astype(np.float32) + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + out_dtype = mx.bfloat16 if dtype == mx.float32 else dtype + tol = _MATMUL_REL_TOL[out_dtype] + self.assertLess( + rel, tol, msg=f"dtype={dtype} bias: rel={rel:.3e} tol={tol:.0e}" + ) + + def test_quantized_linear_save_load_roundtrip(self): + K, N, M = 1024, 64, 4 + rng = np.random.default_rng(2) + W = rng.standard_normal((N, K)).astype(np.float32) * 0.3 + x_np = rng.standard_normal((M, K)).astype(np.float32) + W_q = _quantize_q8_0_matrix(W) + + ql = _make_kquant_linear(K, N, "q8_0", 32, 8, weight_packed_np=W_q) + x = mx.array(x_np).astype(mx.float16) + y_orig = ql(x) + mx.eval(y_orig) + + tdir = tempfile.TemporaryDirectory() + try: + path = os.path.join(tdir.name, "ql.safetensors") + ql.save_weights(path) + + ql2 = _make_kquant_linear( + K, N, "q8_0", 32, 8, weight_packed_np=np.zeros_like(W_q) + ) + ql2.load_weights(path) + + y_reload = ql2(x) + mx.eval(y_reload) + self.assertTrue(mx.array_equal(y_orig, y_reload).item()) + finally: + tdir.cleanup() + + def test_quantized_linear_repr(self): + K, N = 1024, 64 + gs, bits = 256, 4 + # Q4_K: 144 bytes per 256-weight super-block. + bytes_per_row = K * 144 // 256 + ql = _make_kquant_linear( + K, + N, + "q4_k", + gs, + bits, + weight_packed_np=np.zeros((N, bytes_per_row), dtype=np.uint8), + ) + + r = repr(ql) + for needle in ( + "mode=kquant", + "kquant_type=q4_k", + f"input_dims={K}", + f"output_dims={N}", + f"group_size={gs}", + f"bits={bits}", + ): + self.assertIn(needle, r, msg=f"missing {needle!r} in repr: {r}") + + +class TestKQuantVJP(mlx_tests.MLXTestCase): + """Gradient tests for kquant quantized_matmul and gather_qmm.""" + + def test_qmm_vjp_x(self): + """VJP wrt x: d/dx (x @ dequant(w).T) cotangent should match.""" + rng = np.random.default_rng(42) + M, K, N = 4, 256, 32 + W_fp = rng.standard_normal((N, K)).astype(np.float32) * 0.1 + W_q_np = _quantize_q8_0_matrix(W_fp) + w_q = mx.array(W_q_np) + x = mx.random.normal((M, K)) + c = mx.ones((M, N)) + + def fn(x_): + return mx.quantized_matmul( + x_, + w_q, + scales=_scales_placeholder(), + transpose=True, + group_size=32, + bits=8, + mode="kquant", + kquant_type="q8_0", + ) + + _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,)) + + expected = mx.quantized_matmul( + c, + w_q, + scales=_scales_placeholder(), + transpose=False, + group_size=32, + bits=8, + mode="kquant", + kquant_type="q8_0", + ) + self.assertTrue(mx.allclose(vjp_out[0], expected, atol=1e-4)) + + def test_qmm_jvp_x(self): + """JVP wrt x: tangent through quantized_matmul for both transposes.""" + rng = np.random.default_rng(42) + M, K, N = 4, 256, 32 + W_fp = rng.standard_normal((N, K)).astype(np.float32) * 0.1 + W_q_np = _quantize_q8_0_matrix(W_fp) + w_q = mx.array(W_q_np) + x = mx.random.normal((M, K)) + x_tan = mx.ones((M, K)) + + for transpose in (True, False): + with self.subTest(transpose=transpose): + W_shape = (N, K) if transpose else (K, N) + W_local = rng.standard_normal(W_shape).astype(np.float32) * 0.1 + W_q_local = _quantize_q8_0_matrix(W_local) + w_q_local = mx.array(W_q_local) + + def fn(x_): + return mx.quantized_matmul( + x_, + w_q_local, + scales=_scales_placeholder(), + transpose=transpose, + group_size=32, + bits=8, + mode="kquant", + kquant_type="q8_0", + ) + + _, jvp_out = mx.jvp(fn, primals=(x,), tangents=(x_tan,)) + expected = mx.quantized_matmul( + x_tan, + w_q_local, + scales=_scales_placeholder(), + transpose=transpose, + group_size=32, + bits=8, + mode="kquant", + kquant_type="q8_0", + ) + self.assertTrue(mx.allclose(jvp_out[0], expected, atol=1e-4)) + + def test_gather_qmm_vjp_x(self): + """VJP wrt x through gather_qmm.""" + rng = np.random.default_rng(7) + M, K, N = 4, 256, 32 + B = 2 + W_fp = rng.standard_normal((N, K)).astype(np.float32) * 0.1 + W_q_np = _quantize_q8_0_matrix(W_fp) + w_q = mx.array(W_q_np) + w_q = mx.broadcast_to(w_q[None], (B, N, W_q_np.shape[1])) + x = mx.random.normal((B, M, K)) + rhs_indices = mx.array([[0], [1]]) + + def fn(x_): + return mx.gather_qmm( + x_, + w_q, + scales=_scales_placeholder(), + rhs_indices=rhs_indices, + transpose=True, + group_size=32, + bits=8, + mode="kquant", + kquant_type="q8_0", + ) + + out = fn(x) + c = mx.ones_like(out) + _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,)) + mx.eval(vjp_out[0]) + self.assertEqual(vjp_out[0].shape, x.shape) + + +class TestKQuantQuantizedEmbedding(mlx_tests.MLXTestCase): + """Tests for nn.QuantizedEmbedding under mode='kquant'.""" + + def test_embedding_forward(self): + """Construct a kquant QuantizedEmbedding and verify lookup.""" + num_embeddings, dims = 16, 256 + rng = np.random.default_rng(99) + W = rng.standard_normal((num_embeddings, dims)).astype(np.float32) * 0.3 + W_q = _quantize_q8_0_matrix(W) + W_ref = _dequantize_q8_0_matrix(W_q, dims) + + qe = nn.QuantizedEmbedding(num_embeddings, dims, mode="affine") + qe.mode = "kquant" + qe.kquant_type = "q8_0" + qe.group_size = 32 + qe.bits = 8 + qe.weight = mx.array(W_q) + qe.scales = _scales_placeholder() + qe.biases = None + + indices = mx.array([0, 3, 7, 15]) + y = qe(indices) + mx.eval(y) + + y_np = np.asarray(y.astype(mx.float32)) + ref = W_ref[np.array([0, 3, 7, 15])] + denom = max(1e-8, float(np.max(np.abs(ref)))) + rel = float(np.max(np.abs(y_np - ref))) / denom + self.assertLess( + rel, _MATMUL_REL_TOL[mx.bfloat16], msg=f"embedding rel={rel:.3e}" + ) + + +class TestKQuantEdgeCases(mlx_tests.MLXTestCase): + """Edge case tests for kquant dispatch paths.""" + + def test_codec_geometry_consistency(self): + """Python _KQUANT_CODEC_GEOMETRY must match C++ kquant_codec_by_name.""" + from mlx.nn.layers.quantized import _KQUANT_CODEC_GEOMETRY + + for codec, (wpb, bpb) in _KQUANT_CODEC_GEOMETRY.items(): + with self.subTest(codec=codec): + K = wpb * 4 + packed = mx.zeros((1, (K // wpb) * bpb), dtype=mx.uint8) + out = mx.dequantize( + packed, + mx.zeros((1,), dtype=mx.uint8), + None, + mode="kquant", + kquant_type=codec, + ) + self.assertEqual(out.shape[-1], K) + + def test_qmv_quad_nyi_raises(self): + """K=64 or K=128 with M